torch.compile 底层在做什么?从 Triton 代码生成到手写 CUDA kernel
本文是 torch.compile 到底优化了什么?从 Benchmark 看算子融合的本质 的续篇,建议先阅读上篇了解算子融合的基本概念。
TLDR
torch.compile在 GPU 上最终生成的是 Triton 代码,再编译为 PTX 汇编,最终变成 GPU 机器码。- 它的自动优化局限于逐元素级别的融合,无法做算法级重构。
- 手写 CUDA kernel(如 FlashAttention)可以远超自动优化,但两者并不冲突,实际项目中通常结合使用。
编译链路全景
torch.compile 有一条完整的编译链路,最终生成接近硬件的底层代码:
Python 模型代码
↓ TorchDynamo(字节码分析,捕获计算图)
FX Graph(中间表示)
↓ AOTAutograd(处理自动求导)
Normalized Graph
↓ Inductor(后端代码生成)
├── GPU → Triton 代码(Python-like DSL)
│ ↓ Triton 编译器
│ PTX(NVIDIA 中间汇编)
│ ↓ NVIDIA 驱动
│ SASS(GPU 机器码)
│
└── CPU → C++ / OpenMP 代码
↓ GCC / Clang
x86 机器码
每一层做的事情:
| 阶段 | 工具 | 做了什么 |
|---|---|---|
| 图捕获 | TorchDynamo | 分析 Python 字节码,把模型的 forward 捕获为计算图 |
| 自动求导 | AOTAutograd | 提前展开 autograd,生成前向 + 反向的完整图 |
| 代码生成 | Inductor | 对图做算子融合,生成优化后的 Triton 或 C++ 代码 |
| 底层编译 | Triton / GCC | 把生成的代码编译为 GPU/CPU 的机器码 |
生成的代码长什么样?
用 TORCH_LOGS="output_code" 可以看到 Inductor 生成的 Triton 代码:
TORCH_LOGS="output_code" python benchmark.py
输出类似:
@triton.jit
def fused_kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
tmp0 = tl.load(in_ptr0 + xindex) # 一次显存读取
tmp1 = tl.sin(tmp0)
tmp2 = tl.cos(tmp0)
tmp3 = tmp2 * tl.sigmoid(tmp0)
tmp4 = tmp1 + tmp3 # 全在寄存器中计算
tl.store(out_ptr0 + xindex, tmp4) # 一次显存写入
关键观察:所有的 sin、cos、sigmoid、乘法、加法都在同一个函数体中,数据通过局部变量(GPU 寄存器)传递,没有中间的显存读写。这就是上篇文章中 profiler 看到 35 个 kernel 变成 1 个的底层原因。
常用的 TORCH_LOGS 选项
- 看 graph break:
TORCH_LOGS="graph_breaks" python script.py - 看生成代码(融合结果):
TORCH_LOGS="output_code" python script.py - 看完整编译日志:
TORCH_LOGS="dynamo,inductor" python script.py
自动优化的局限
torch.compile 的 Inductor 后端做的是通用的编译优化——它不理解你在算什么,只看到一堆算子,然后尝试把相邻的 pointwise 算子合并。
这意味着它做不到算法级别的优化。
例子:Attention 的优化
标准 Attention 的计算流程:
Q·K^T → Scale → Mask → Softmax → ·V
torch.compile 能做的(逐元素融合):
Q·K^T → 写回显存 → [Scale + Mask 融合] → 写回显存 → Softmax → 写回显存 → ·V
Scale 和 Mask 可以融合,但 Softmax 需要全局归约(求 max 和 sum),无法与前后的操作融合。完整的 N×N attention matrix 必须存在显存中。
FlashAttention 的做法(算法重构):
利用 tiling 分块 + online softmax,整个 Attention 在一个 kernel 内完成
不需要把完整的 N×N attention matrix 写入显存
显存占用从 O(N²) 降到 O(N)
这种优化需要人理解”softmax 可以用 online 算法分块计算”——这是数学洞察,不是编译器能自动推导的代码变换。
自动 vs 手写优化
两者并不冲突
实际项目中通常两者结合:关键算子(如 Attention)用手写的 FlashAttention,其余部分(LayerNorm、残差、激活函数等)交给 torch.compile 自动融合。这也是目前主流 LLM 推理框架的做法。
| 自动优化(torch.compile) | 手写优化(CUDA/Triton kernel) | |
|---|---|---|
| 优化层级 | 逐元素操作融合 | 算法级重构 |
| 难度 | 一行代码 torch.compile(model) | 需要深入理解 GPU 架构 + 算法 |
| 典型收益 | 1.5x-4x | 5x-20x(如 FlashAttention) |
| 适用性 | 通用,任何模型都能用 | 针对特定计算模式手工打造 |
| 代表 | Inductor 生成的 Triton kernel | FlashAttention、cuDNN Conv、CUTLASS matmul |