torch.compile 底层在做什么?从 Triton 代码生成到手写 CUDA kernel

本文是 torch.compile 到底优化了什么?从 Benchmark 看算子融合的本质 的续篇,建议先阅读上篇了解算子融合的基本概念。


TLDR

  • torch.compile 在 GPU 上最终生成的是 Triton 代码,再编译为 PTX 汇编,最终变成 GPU 机器码。
  • 它的自动优化局限于逐元素级别的融合,无法做算法级重构。
  • 手写 CUDA kernel(如 FlashAttention)可以远超自动优化,但两者并不冲突,实际项目中通常结合使用。

编译链路全景

torch.compile 有一条完整的编译链路,最终生成接近硬件的底层代码:

Python 模型代码
    ↓  TorchDynamo(字节码分析,捕获计算图)
FX Graph(中间表示)
    ↓  AOTAutograd(处理自动求导)
Normalized Graph
    ↓  Inductor(后端代码生成)
    ├── GPU → Triton 代码(Python-like DSL)
    │           ↓  Triton 编译器
    │         PTX(NVIDIA 中间汇编)
    │           ↓  NVIDIA 驱动
    │         SASS(GPU 机器码)

    └── CPU → C++ / OpenMP 代码
                ↓  GCC / Clang
              x86 机器码

每一层做的事情:

阶段工具做了什么
图捕获TorchDynamo分析 Python 字节码,把模型的 forward 捕获为计算图
自动求导AOTAutograd提前展开 autograd,生成前向 + 反向的完整图
代码生成Inductor对图做算子融合,生成优化后的 Triton 或 C++ 代码
底层编译Triton / GCC把生成的代码编译为 GPU/CPU 的机器码

生成的代码长什么样?

TORCH_LOGS="output_code" 可以看到 Inductor 生成的 Triton 代码:

TORCH_LOGS="output_code" python benchmark.py

输出类似:

@triton.jit
def fused_kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
    tmp0 = tl.load(in_ptr0 + xindex)       # 一次显存读取
    tmp1 = tl.sin(tmp0)
    tmp2 = tl.cos(tmp0)
    tmp3 = tmp2 * tl.sigmoid(tmp0)
    tmp4 = tmp1 + tmp3                      # 全在寄存器中计算
    tl.store(out_ptr0 + xindex, tmp4)       # 一次显存写入

关键观察:所有的 sin、cos、sigmoid、乘法、加法都在同一个函数体中,数据通过局部变量(GPU 寄存器)传递,没有中间的显存读写。这就是上篇文章中 profiler 看到 35 个 kernel 变成 1 个的底层原因。

常用的 TORCH_LOGS 选项
  • 看 graph break:TORCH_LOGS="graph_breaks" python script.py
  • 看生成代码(融合结果):TORCH_LOGS="output_code" python script.py
  • 看完整编译日志:TORCH_LOGS="dynamo,inductor" python script.py

自动优化的局限

torch.compile 的 Inductor 后端做的是通用的编译优化——它不理解你在算什么,只看到一堆算子,然后尝试把相邻的 pointwise 算子合并。

这意味着它做不到算法级别的优化

例子:Attention 的优化

标准 Attention 的计算流程:

Q·K^T → Scale → Mask → Softmax → ·V

torch.compile 能做的(逐元素融合):

Q·K^T → 写回显存 → [Scale + Mask 融合] → 写回显存 → Softmax → 写回显存 → ·V

Scale 和 Mask 可以融合,但 Softmax 需要全局归约(求 max 和 sum),无法与前后的操作融合。完整的 N×N attention matrix 必须存在显存中。

FlashAttention 的做法(算法重构):

利用 tiling 分块 + online softmax,整个 Attention 在一个 kernel 内完成
不需要把完整的 N×N attention matrix 写入显存
显存占用从 O(N²) 降到 O(N)

这种优化需要人理解”softmax 可以用 online 算法分块计算”——这是数学洞察,不是编译器能自动推导的代码变换。


自动 vs 手写优化

两者并不冲突

实际项目中通常两者结合:关键算子(如 Attention)用手写的 FlashAttention,其余部分(LayerNorm、残差、激活函数等)交给 torch.compile 自动融合。这也是目前主流 LLM 推理框架的做法。

自动优化(torch.compile)手写优化(CUDA/Triton kernel)
优化层级逐元素操作融合算法级重构
难度一行代码 torch.compile(model)需要深入理解 GPU 架构 + 算法
典型收益1.5x-4x5x-20x(如 FlashAttention)
适用性通用,任何模型都能用针对特定计算模式手工打造
代表Inductor 生成的 Triton kernelFlashAttention、cuDNN Conv、CUTLASS matmul