torch.where 与 torch.cond 的区别与异同

本文基于 PyTorch2.x 撰写,示例代码可直接复制运行。


短路求值

短路求值(short-circuit evaluation)指在布尔运算中,一旦结果已提前确定,就略过剩下步骤不再计算的求值策略。

TLDR

  • torch.where(condition, x, y)
    立即返回一个与输入同形的张量,元素由 conditionTrue/False 时对应取 x/y 的值。
    它只是一个张量级操作引入控制流,支持短路求值。

  • torch.cond(pred, true_fn, false_fn, operands)
    惰性地根据 pred 的布尔值,只执行 true_fnfalse_fn 中的一个分支,并返回其输出。
    它属于控制流原语,支持短路求值,常用于 torch.compiletorch.fx 图捕获场景。


相同点

维度torch.wheretorch.cond
语义按条件二选一按条件二选一
可微支持自动求导支持自动求导
设备CPU / CUDA / Meta …CPU / CUDA / Meta …

不同点

维度torch.wheretorch.cond
执行时机立即执行两个分支,再按条件 mask 合并仅执行被选中的分支
短路求值不支持支持
分支函数无,直接给张量需要封装成 Callable
图捕获直接支持torch._dynamo / torch.compile
性能两个分支都计算,可能浪费只计算一个分支,节省算力
典型用途逐元素选择、mask 填充动态形状、早停、复杂控制流

代码示例

import torch

def demo_where():
    print("=== torch.where: elementwise selection (no short-circuit) ===")
    x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
    branch_counts = {"true": 0, "false": 0}

    def true_branch(val):
        branch_counts["true"] += 1
        return val * val

    def false_branch(val):
        branch_counts["false"] += 1
        return -val

    mask = x > 0
    out = torch.where(mask, true_branch(x), false_branch(x))
    print("input :", x)
    print("mask  :", mask)
    print("output:", out)
    print("branch_counts:", branch_counts)


def demo_cond():
    print("=== torch.cond: control flow (short-circuit) ===")
    y = torch.tensor([-1.0, 2.0])

    def true_branch(val):
        return val * 10, torch.tensor(1, device=val.device)

    def false_branch(val):
        return val - 10, torch.tensor(0, device=val.device)

    pred = y.sum() > 0
    out, flag = torch.cond(pred, true_branch, false_branch, [y])
    print("input :", y)
    print("pred  :", pred)
    print("output:", out)
    print("branch:", "TRUE" if flag.item() == 1 else "FALSE")


if __name__ == "__main__":
    demo_where()
    demo_cond()

预期输出:

=== torch.where: elementwise selection (no short-circuit) ===
input : tensor([-2., -1.,  0.,  1.,  2.])
mask  : tensor([False, False, False,  True,  True])
output: tensor([ 4.,  1.,  0.,  1.,  4.])
branch_counts: {'true': 1, 'false': 1}

=== torch.cond: control flow (short-circuit) ===
input : tensor([-1.,  2.])
pred  : tensor(True)
output: tensor([-10.,  20.])
branch: TRUE