torch.where 与 torch.cond 的区别与异同
本文基于 PyTorch2.x 撰写,示例代码可直接复制运行。
短路求值
短路求值(short-circuit evaluation)指在布尔运算中,一旦结果已提前确定,就略过剩下步骤不再计算的求值策略。
TLDR
-
torch.where(condition, x, y):
立即返回一个与输入同形的张量,元素由condition为True/False时对应取x/y的值。
它只是一个张量级操作,不引入控制流,不支持短路求值。 -
torch.cond(pred, true_fn, false_fn, operands):
惰性地根据pred的布尔值,只执行true_fn或false_fn中的一个分支,并返回其输出。
它属于控制流原语,支持短路求值,常用于torch.compile或torch.fx图捕获场景。
相同点
| 维度 | torch.where | torch.cond |
|---|---|---|
| 语义 | 按条件二选一 | 按条件二选一 |
| 可微 | 支持自动求导 | 支持自动求导 |
| 设备 | CPU / CUDA / Meta … | CPU / CUDA / Meta … |
不同点
| 维度 | torch.where | torch.cond |
|---|---|---|
| 执行时机 | 立即执行两个分支,再按条件 mask 合并 | 仅执行被选中的分支 |
| 短路求值 | 不支持 | 支持 |
| 分支函数 | 无,直接给张量 | 需要封装成 Callable |
| 图捕获 | 直接支持 | 需 torch._dynamo / torch.compile |
| 性能 | 两个分支都计算,可能浪费 | 只计算一个分支,节省算力 |
| 典型用途 | 逐元素选择、mask 填充 | 动态形状、早停、复杂控制流 |
代码示例
import torch
def demo_where():
print("=== torch.where: elementwise selection (no short-circuit) ===")
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
branch_counts = {"true": 0, "false": 0}
def true_branch(val):
branch_counts["true"] += 1
return val * val
def false_branch(val):
branch_counts["false"] += 1
return -val
mask = x > 0
out = torch.where(mask, true_branch(x), false_branch(x))
print("input :", x)
print("mask :", mask)
print("output:", out)
print("branch_counts:", branch_counts)
def demo_cond():
print("=== torch.cond: control flow (short-circuit) ===")
y = torch.tensor([-1.0, 2.0])
def true_branch(val):
return val * 10, torch.tensor(1, device=val.device)
def false_branch(val):
return val - 10, torch.tensor(0, device=val.device)
pred = y.sum() > 0
out, flag = torch.cond(pred, true_branch, false_branch, [y])
print("input :", y)
print("pred :", pred)
print("output:", out)
print("branch:", "TRUE" if flag.item() == 1 else "FALSE")
if __name__ == "__main__":
demo_where()
demo_cond()
预期输出:
=== torch.where: elementwise selection (no short-circuit) ===
input : tensor([-2., -1., 0., 1., 2.])
mask : tensor([False, False, False, True, True])
output: tensor([ 4., 1., 0., 1., 4.])
branch_counts: {'true': 1, 'false': 1}
=== torch.cond: control flow (short-circuit) ===
input : tensor([-1., 2.])
pred : tensor(True)
output: tensor([-10., 20.])
branch: TRUE