torch.compile 到底优化了什么?从 Benchmark 看算子融合的本质

本文基于 PyTorch 2.x + CUDA (T4 GPU) 撰写,示例代码可直接复制运行。


TLDR

  • torch.compile 的核心优化是算子融合(Operator Fusion)——把多个小 CUDA kernel 合并成一个,减少显存的反复读写。
  • 它对逐元素操作密集(memory-bandwidth bound)的模型效果显著(2x-10x),对矩阵乘法密集(compute-bound)的模型几乎无效(~1.0x)。
  • 如果你的 benchmark 看不到加速,大概率是因为模型太小或计算被 matmul 主导。

背景:Eager 模式下发生了什么?

在 PyTorch 默认的 eager 模式下,每一个操作(如 torch.sintorch.cos+*)都会独立启动一个 CUDA kernel:

kernel 1: 从显存读 x → 计算 sin(x) → 写回显存
kernel 2: 从显存读 x → 计算 cos(x) → 写回显存
kernel 3: 从显存读 x → 计算 tanh(x) → 写回显存
kernel 4: 从显存读 cos(x), tanh(x) → 计算乘法 → 写回显存
kernel 5: 从显存读 sin(x), 乘法结果 → 计算加法 → 写回显存
...

每个 kernel 都要读写一次显存,而 GPU 显存带宽是有限的。当你有 20+ 个这样的小操作时,大部分时间浪费在了搬运数据上,而不是计算。


torch.compile 的优化:算子融合

算子融合的本质

算子融合优化的不是计算速度,而是显存带宽。只有当模型的瓶颈在于显存读写(memory-bound)时,融合才有效。如果瓶颈在于计算本身(compute-bound,如大矩阵乘法),融合帮不上忙。

torch.compile 会把这些小 kernel 融合成一个大 kernel

fused kernel: 从显存读 x → sin + cos + tanh + 乘 + 加 + sigmoid + ... → 写回显存

数据在 GPU 寄存器中流转,只需要一次读、一次写,这就是融合的本质。


Benchmark 对比

我们设计两个模型来验证这一点,第一个模型的瓶颈在于显存带宽,第二个模型的瓶颈在于计算。

Model 1: Pointwise-heavy(显存带宽瓶颈)

大量逐元素操作,几乎没有矩阵乘法。每一行在 eager 模式下都是一个独立的 CUDA kernel。

class PointwiseHeavyModel(nn.Module):
    """大量逐元素操作,显存带宽瓶颈"""
    def __init__(self, dim=4096):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        # 每一行 = eager 模式下的一个独立 CUDA kernel
        # compile 后全部融合为 ~1 个 kernel
        x = x * self.scale + self.bias
        x = torch.sin(x) + torch.cos(x) * torch.tanh(x)
        x = x * torch.sigmoid(x) + torch.relu(x)
        x = x / (1.0 + torch.abs(x))
        x = torch.sin(x) + torch.cos(x) * torch.tanh(x)
        x = x * torch.sigmoid(x) + torch.relu(x)
        x = x / (1.0 + torch.abs(x))
        x = torch.sin(x) + torch.cos(x) * torch.tanh(x)
        x = x * torch.sigmoid(x) + torch.relu(x)
        return x

Model 2: Matmul-heavy(计算瓶颈)

以大矩阵乘法为主,逐元素操作很少。cuBLAS 对 matmul 本身已经是最优实现。

class MatmulHeavyModel(nn.Module):
    """大矩阵乘法为主,计算瓶颈"""
    def __init__(self, dim=4096, num_layers=4):
        super().__init__()
        self.linears = nn.ModuleList(
            [nn.Linear(dim, dim) for _ in range(num_layers)]
        )

    def forward(self, x):
        for linear in self.linears:
            x = linear(x)
            x = torch.relu(x)  # 仅 1 个逐元素操作
        return x

完整 Benchmark 代码

import torch
import torch.nn as nn
import time

def benchmark(fn, x, warmup=50, steps=200, label=""):
    for _ in range(warmup):
        fn(x)
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    start = time.time()
    for _ in range(steps):
        fn(x)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    end = time.time()

    avg_time = (end - start) / steps * 1000
    print(f"  {label}: {avg_time:.4f} ms")
    return avg_time

def run_comparison(model, x, name):
    print(f"\n{'='*50}")
    print(f" {name}")
    print(f"{'='*50}")
    model.eval()

    def infer_eager(x):
        with torch.no_grad():
            return model(x)

    eager_time = benchmark(infer_eager, x, label="Eager")

    opt_model = torch.compile(model)
    def infer_compiled(x):
        with torch.no_grad():
            return opt_model(x)

    compiled_time = benchmark(infer_compiled, x, label="Compiled")
    print(f"  Speedup: {eager_time / compiled_time:.2f}x")

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Running on: {device}")
    x = torch.randn(1024, 4096, device=device)

    model1 = PointwiseHeavyModel(dim=4096).to(device)
    run_comparison(model1, x,
        "Pointwise-heavy (memory-bound) → fusion helps a lot")

    model2 = MatmulHeavyModel(dim=4096, num_layers=4).to(device)
    run_comparison(model2, x,
        "Matmul-heavy (compute-bound) → fusion helps little")

预期结果(T4 GPU)

模型EagerCompiledSpeedup
Pointwise-heavy~5.89 ms~0.54 ms10x
Matmul-heavy~45.57 ms~45.83 ms~1.0x

观察融合效果的其他方法

除了 benchmark 计时之外,还可以更直接地”看到”融合。

用 torch.profiler 数 kernel 数量

最直观的证据——对比 eager 和 compiled 模式下实际启动了多少个 CUDA kernel:

import torch
from torch.profiler import profile, ProfilerActivity

model = PointwiseHeavyModel(dim=4096).cuda().eval()
x = torch.randn(1024, 4096, device="cuda")

# --- Eager: 观察 kernel 列表 ---
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with torch.no_grad():
        model(x)
    torch.cuda.synchronize()

print("=== Eager kernels ===")
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))

# --- Compiled: 观察 kernel 列表 ---
opt_model = torch.compile(model)
for _ in range(3):  # warmup,确保编译完成
    with torch.no_grad():
        opt_model(x)
torch.cuda.synchronize()

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with torch.no_grad():
        opt_model(x)
    torch.cuda.synchronize()

print("=== Compiled kernels ===")
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))

实际输出(T4 GPU)

Eager kernels:

---------------------------  ----------  ----------  ----------  ----------  ----------  ----------
                       Name  Self CPU %    Self CPU  CPU total%   CPU total  CPU avg     # of Calls
---------------------------  ----------  ----------  ----------  ----------  ----------  ----------
                  aten::mul       1.83%   133.976us      25.20%    1.849ms   264.123us           7
           cudaLaunchKernel      15.65%     1.148ms      15.65%    1.148ms    32.808us          35
                  aten::add       2.06%   151.205us       3.33%  244.659us    27.184us           9
                  aten::sin       0.52%    37.887us       0.90%   65.946us    21.982us           3
                  aten::cos       0.48%    35.135us       0.79%   57.911us    19.304us           3
                 aten::tanh       0.46%    33.679us       0.99%   72.282us    24.094us           3
              aten::sigmoid       0.44%    32.593us       0.78%   57.208us    19.069us           3
                 aten::relu       1.70%   124.498us       3.05%  223.718us    74.573us           3
                  aten::abs       2.18%   160.038us       3.85%  282.461us    70.615us           4
                  aten::div       0.45%    33.018us       0.78%   57.118us    28.559us           2
---------------------------  ----------  ----------  ----------  ----------  ----------  ----------
Self CPU time total: 7.336ms

Compiled kernels:

------------------------------------------------------  ----------  ----------  ----------  ----------  ----------  ----------
                                                  Name  Self CPU %    Self CPU  CPU total%   CPU total  CPU avg     # of Calls
------------------------------------------------------  ----------  ----------  ----------  ----------  ----------  ----------
triton_poi_fused_abs_add_cos_div_mul_relu_sigmoid_si…       0.94%    32.374us      14.31%  490.925us   490.925us           1
                                        cuLaunchKernel      13.36%   458.551us      13.36%  458.551us   458.551us           1
------------------------------------------------------  ----------  ----------  ----------  ----------  ----------  ----------
Self CPU time total: 3.431ms

如何解读

重点关注以下几列:

  • # of Calls:Eager 下 cudaLaunchKernel 被调用了 35 次(每个 sin/cos/tanh/mul/add… 各自启动一个 kernel),Compiled 下 cuLaunchKernel 只有 1 次
  • Kernel 名称:Compiled 表格中 triton_poi_fused_abs_add_cos_div_mul_relu_sigmoid_si… 直接列出了所有被融合的算子名,一目了然。
  • Self CPU time total:从 7.336ms 降到 3.431ms,CPU 端调度 kernel 的开销减半——因为只需要调度 1 次而不是 35 次。

torch.compile 的编译模式

算子融合其实只是 torch.compile 的一部分。通过 mode 参数可以选择不同的优化策略:

torch.compile(model, mode="default/reduce-overhead/...")
mode做了什么适用场景
"default"算子融合,平衡编译时间与性能通用场景
"reduce-overhead"融合 + CUDA Graphs,减少 CPU 调度开销输入 shape 固定的推理
"max-autotune"融合 + 花更多时间 autotuning 找最优 kernel 配置离线优化,追求极致性能
"max-autotune-no-cudagraphs"max-autotune 但不用 CUDA Graphs输入 shape 动态变化

CUDA Graphs 是什么?

算子融合减少的是GPU 端的开销(减少显存读写)。而 CUDA Graphs 减少的是 CPU 端的开销:

没有 CUDA Graphs:CPU 逐个调度 → cudaLaunchKernel × N 次
有 CUDA Graphs: CPU 一次回放 → cudaGraphLaunch × 1 次

原理是第一次推理时”录制”整个 kernel 执行序列,之后每次推理只需一次 cudaGraphLaunch 回放,CPU 调度开销几乎归零。

LLM 能用 CUDA Graphs 吗?

CUDA Graphs 要求输入 shape 固定——记录的 tensor 形状和回放时必须一致,而 LLM 的序列长度是动态变化的,看起来不能直接用, 但是目前工业界往往采用 padding 绕过这个限制

  • Padding 到固定长度:把所有输入 pad 到同一个 seq_len,浪费一些计算但能用 CUDA Graphs
  • Bucketing:预设几个固定长度档位(如 128/256/512/1024),每个档位各录一份 graph,按输入长度选最近的档位

vLLM、TensorRT-LLM 等推理框架内部都使用了这种策略。


总结

什么时候该用 torch.compile?
  • 模型中有大量连续的逐元素操作(激活函数、归一化、残差连接等)
  • GPU 上运行(CPU 上收益有限)
  • 模型足够大,使得编译调度开销可以忽略不计

不适合的场景:模型以大矩阵乘法为主、模型非常小(< 1ms)、仅在 CPU 上运行。

Pointwise-heavyMatmul-heavy
瓶颈显存带宽(memory-bound)计算量(compute-bound)
Eager 的问题多个 kernel 反复读写显存cuBLAS matmul 已是最优
Compile 优化融合为 ~1 个 kernel,数据留在寄存器几乎无法优化
加速效果显著提升几乎没有(~1.0x)

延伸阅读

想了解 torch.compile 底层的编译链路(TorchDynamo → Inductor → Triton → PTX),以及为什么 FlashAttention 等手写优化能远超自动编译?请看下一篇:

torch.compile 底层在做什么?从 Triton 代码生成到手写 CUDA kernel →