torch.compile 到底优化了什么?从 Benchmark 看算子融合的本质
本文基于 PyTorch 2.x + CUDA (T4 GPU) 撰写,示例代码可直接复制运行。
TLDR
torch.compile的核心优化是算子融合(Operator Fusion)——把多个小 CUDA kernel 合并成一个,减少显存的反复读写。- 它对逐元素操作密集(memory-bandwidth bound)的模型效果显著(2x-10x),对矩阵乘法密集(compute-bound)的模型几乎无效(~1.0x)。
- 如果你的 benchmark 看不到加速,大概率是因为模型太小或计算被 matmul 主导。
背景:Eager 模式下发生了什么?
在 PyTorch 默认的 eager 模式下,每一个操作(如 torch.sin、torch.cos、+、*)都会独立启动一个 CUDA kernel:
kernel 1: 从显存读 x → 计算 sin(x) → 写回显存
kernel 2: 从显存读 x → 计算 cos(x) → 写回显存
kernel 3: 从显存读 x → 计算 tanh(x) → 写回显存
kernel 4: 从显存读 cos(x), tanh(x) → 计算乘法 → 写回显存
kernel 5: 从显存读 sin(x), 乘法结果 → 计算加法 → 写回显存
...
每个 kernel 都要读写一次显存,而 GPU 显存带宽是有限的。当你有 20+ 个这样的小操作时,大部分时间浪费在了搬运数据上,而不是计算。
torch.compile 的优化:算子融合
算子融合优化的不是计算速度,而是显存带宽。只有当模型的瓶颈在于显存读写(memory-bound)时,融合才有效。如果瓶颈在于计算本身(compute-bound,如大矩阵乘法),融合帮不上忙。
torch.compile 会把这些小 kernel 融合成一个大 kernel:
fused kernel: 从显存读 x → sin + cos + tanh + 乘 + 加 + sigmoid + ... → 写回显存
数据在 GPU 寄存器中流转,只需要一次读、一次写,这就是融合的本质。
Benchmark 对比
我们设计两个模型来验证这一点,第一个模型的瓶颈在于显存带宽,第二个模型的瓶颈在于计算。
Model 1: Pointwise-heavy(显存带宽瓶颈)
大量逐元素操作,几乎没有矩阵乘法。每一行在 eager 模式下都是一个独立的 CUDA kernel。
class PointwiseHeavyModel(nn.Module):
"""大量逐元素操作,显存带宽瓶颈"""
def __init__(self, dim=4096):
super().__init__()
self.scale = nn.Parameter(torch.ones(dim))
self.bias = nn.Parameter(torch.zeros(dim))
def forward(self, x):
# 每一行 = eager 模式下的一个独立 CUDA kernel
# compile 后全部融合为 ~1 个 kernel
x = x * self.scale + self.bias
x = torch.sin(x) + torch.cos(x) * torch.tanh(x)
x = x * torch.sigmoid(x) + torch.relu(x)
x = x / (1.0 + torch.abs(x))
x = torch.sin(x) + torch.cos(x) * torch.tanh(x)
x = x * torch.sigmoid(x) + torch.relu(x)
x = x / (1.0 + torch.abs(x))
x = torch.sin(x) + torch.cos(x) * torch.tanh(x)
x = x * torch.sigmoid(x) + torch.relu(x)
return x
Model 2: Matmul-heavy(计算瓶颈)
以大矩阵乘法为主,逐元素操作很少。cuBLAS 对 matmul 本身已经是最优实现。
class MatmulHeavyModel(nn.Module):
"""大矩阵乘法为主,计算瓶颈"""
def __init__(self, dim=4096, num_layers=4):
super().__init__()
self.linears = nn.ModuleList(
[nn.Linear(dim, dim) for _ in range(num_layers)]
)
def forward(self, x):
for linear in self.linears:
x = linear(x)
x = torch.relu(x) # 仅 1 个逐元素操作
return x
完整 Benchmark 代码
import torch
import torch.nn as nn
import time
def benchmark(fn, x, warmup=50, steps=200, label=""):
for _ in range(warmup):
fn(x)
if torch.cuda.is_available():
torch.cuda.synchronize()
start = time.time()
for _ in range(steps):
fn(x)
if torch.cuda.is_available():
torch.cuda.synchronize()
end = time.time()
avg_time = (end - start) / steps * 1000
print(f" {label}: {avg_time:.4f} ms")
return avg_time
def run_comparison(model, x, name):
print(f"\n{'='*50}")
print(f" {name}")
print(f"{'='*50}")
model.eval()
def infer_eager(x):
with torch.no_grad():
return model(x)
eager_time = benchmark(infer_eager, x, label="Eager")
opt_model = torch.compile(model)
def infer_compiled(x):
with torch.no_grad():
return opt_model(x)
compiled_time = benchmark(infer_compiled, x, label="Compiled")
print(f" Speedup: {eager_time / compiled_time:.2f}x")
if __name__ == "__main__":
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on: {device}")
x = torch.randn(1024, 4096, device=device)
model1 = PointwiseHeavyModel(dim=4096).to(device)
run_comparison(model1, x,
"Pointwise-heavy (memory-bound) → fusion helps a lot")
model2 = MatmulHeavyModel(dim=4096, num_layers=4).to(device)
run_comparison(model2, x,
"Matmul-heavy (compute-bound) → fusion helps little")
预期结果(T4 GPU)
| 模型 | Eager | Compiled | Speedup |
|---|---|---|---|
| Pointwise-heavy | ~5.89 ms | ~0.54 ms | 10x |
| Matmul-heavy | ~45.57 ms | ~45.83 ms | ~1.0x |
观察融合效果的其他方法
除了 benchmark 计时之外,还可以更直接地”看到”融合。
用 torch.profiler 数 kernel 数量
最直观的证据——对比 eager 和 compiled 模式下实际启动了多少个 CUDA kernel:
import torch
from torch.profiler import profile, ProfilerActivity
model = PointwiseHeavyModel(dim=4096).cuda().eval()
x = torch.randn(1024, 4096, device="cuda")
# --- Eager: 观察 kernel 列表 ---
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
with torch.no_grad():
model(x)
torch.cuda.synchronize()
print("=== Eager kernels ===")
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
# --- Compiled: 观察 kernel 列表 ---
opt_model = torch.compile(model)
for _ in range(3): # warmup,确保编译完成
with torch.no_grad():
opt_model(x)
torch.cuda.synchronize()
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
with torch.no_grad():
opt_model(x)
torch.cuda.synchronize()
print("=== Compiled kernels ===")
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
实际输出(T4 GPU)
Eager kernels:
--------------------------- ---------- ---------- ---------- ---------- ---------- ----------
Name Self CPU % Self CPU CPU total% CPU total CPU avg # of Calls
--------------------------- ---------- ---------- ---------- ---------- ---------- ----------
aten::mul 1.83% 133.976us 25.20% 1.849ms 264.123us 7
cudaLaunchKernel 15.65% 1.148ms 15.65% 1.148ms 32.808us 35
aten::add 2.06% 151.205us 3.33% 244.659us 27.184us 9
aten::sin 0.52% 37.887us 0.90% 65.946us 21.982us 3
aten::cos 0.48% 35.135us 0.79% 57.911us 19.304us 3
aten::tanh 0.46% 33.679us 0.99% 72.282us 24.094us 3
aten::sigmoid 0.44% 32.593us 0.78% 57.208us 19.069us 3
aten::relu 1.70% 124.498us 3.05% 223.718us 74.573us 3
aten::abs 2.18% 160.038us 3.85% 282.461us 70.615us 4
aten::div 0.45% 33.018us 0.78% 57.118us 28.559us 2
--------------------------- ---------- ---------- ---------- ---------- ---------- ----------
Self CPU time total: 7.336ms
Compiled kernels:
------------------------------------------------------ ---------- ---------- ---------- ---------- ---------- ----------
Name Self CPU % Self CPU CPU total% CPU total CPU avg # of Calls
------------------------------------------------------ ---------- ---------- ---------- ---------- ---------- ----------
triton_poi_fused_abs_add_cos_div_mul_relu_sigmoid_si… 0.94% 32.374us 14.31% 490.925us 490.925us 1
cuLaunchKernel 13.36% 458.551us 13.36% 458.551us 458.551us 1
------------------------------------------------------ ---------- ---------- ---------- ---------- ---------- ----------
Self CPU time total: 3.431ms
如何解读
重点关注以下几列:
# of Calls:Eager 下cudaLaunchKernel被调用了 35 次(每个 sin/cos/tanh/mul/add… 各自启动一个 kernel),Compiled 下cuLaunchKernel只有 1 次。- Kernel 名称:Compiled 表格中
triton_poi_fused_abs_add_cos_div_mul_relu_sigmoid_si…直接列出了所有被融合的算子名,一目了然。 Self CPU time total:从 7.336ms 降到 3.431ms,CPU 端调度 kernel 的开销减半——因为只需要调度 1 次而不是 35 次。
torch.compile 的编译模式
算子融合其实只是 torch.compile 的一部分。通过 mode 参数可以选择不同的优化策略:
torch.compile(model, mode="default/reduce-overhead/...")
| mode | 做了什么 | 适用场景 |
|---|---|---|
"default" | 算子融合,平衡编译时间与性能 | 通用场景 |
"reduce-overhead" | 融合 + CUDA Graphs,减少 CPU 调度开销 | 输入 shape 固定的推理 |
"max-autotune" | 融合 + 花更多时间 autotuning 找最优 kernel 配置 | 离线优化,追求极致性能 |
"max-autotune-no-cudagraphs" | max-autotune 但不用 CUDA Graphs | 输入 shape 动态变化 |
CUDA Graphs 是什么?
算子融合减少的是GPU 端的开销(减少显存读写)。而 CUDA Graphs 减少的是 CPU 端的开销:
没有 CUDA Graphs:CPU 逐个调度 → cudaLaunchKernel × N 次
有 CUDA Graphs: CPU 一次回放 → cudaGraphLaunch × 1 次
原理是第一次推理时”录制”整个 kernel 执行序列,之后每次推理只需一次 cudaGraphLaunch 回放,CPU 调度开销几乎归零。
LLM 能用 CUDA Graphs 吗?
CUDA Graphs 要求输入 shape 固定——记录的 tensor 形状和回放时必须一致,而 LLM 的序列长度是动态变化的,看起来不能直接用, 但是目前工业界往往采用 padding 绕过这个限制:
- Padding 到固定长度:把所有输入 pad 到同一个 seq_len,浪费一些计算但能用 CUDA Graphs
- Bucketing:预设几个固定长度档位(如 128/256/512/1024),每个档位各录一份 graph,按输入长度选最近的档位
vLLM、TensorRT-LLM 等推理框架内部都使用了这种策略。
总结
- 模型中有大量连续的逐元素操作(激活函数、归一化、残差连接等)
- 在 GPU 上运行(CPU 上收益有限)
- 模型足够大,使得编译调度开销可以忽略不计
不适合的场景:模型以大矩阵乘法为主、模型非常小(< 1ms)、仅在 CPU 上运行。
| Pointwise-heavy | Matmul-heavy | |
|---|---|---|
| 瓶颈 | 显存带宽(memory-bound) | 计算量(compute-bound) |
| Eager 的问题 | 多个 kernel 反复读写显存 | cuBLAS matmul 已是最优 |
| Compile 优化 | 融合为 ~1 个 kernel,数据留在寄存器 | 几乎无法优化 |
| 加速效果 | 显著提升 | 几乎没有(~1.0x) |
延伸阅读
想了解 torch.compile 底层的编译链路(TorchDynamo → Inductor → Triton → PTX),以及为什么 FlashAttention 等手写优化能远超自动编译?请看下一篇: