作者|strint

1

概要


torch.fx 是 PyTorch 官方发布的 Python 到 Python 的代码变换工具。如果你想做 Torch 代码变换,torch.fx 是首选工具。


torch.fx 会将 Torch 代码 trace 成 6 种基础的 node 组成的 graph,基于这个 graph 可以方便的做各种变换,变换后的 graph 可以再生成 torch 代码(一个 nn.Module),然后像普通的 nn.Module 一样去执行。


torch 2.0 新发布的 torch.compile(也即 TorchDynamo) 默认将代码转换成了 torch.fx 的 GraphModule,进一步加强了 torch.fx 的重要性。(相关文章:TorchDynamo初探:Python ByteCode的动态修改


关键词:PyTorch,图变换,编译


2

最小用例


torch.fx 有三块基础功能。基础功能一是将 torch nn.Module 转换成 fx.GraphModule,该转换被称为 symbolic trace;基础功能二是中间表达和图改写;基础功能三是 Python 代码生成。


首先定义一个有代表性的 nn.Module,包括了 fx 要处理的6种基础操作:

import torch# Simple module for demonstrationclass MyModule(torch.nn.Module):    def __init__(self):        super().__init__()        self.param = torch.nn.Parameter(torch.rand(3, 4))        self.linear = torch.nn.Linear(4, 5)
def forward(self, x): return self.linear(x + self.param).clamp(min=0.0, max=1.0)
module = MyModule()

然后使用 fx 的第一个基础功能 symbolic trace,它可以把 torch python 代码转换成符号化的(symbolic)表达,该表达的类型是 fx.GraphModule:

from torch.fx import symbolic_trace# Symbolic tracing frontend - captures the semantics of the modulesymbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

fx.GraphModule 的特点是它执行计算时的行为和 nn.Module 相同,但是同时又具备一个内含的计算图,而该计算图是可以用图遍历的方式去操作的,中间表达和图改写都是基于该计算图去做的。打印 fx.GraphModule 可以看到上面 module 的图 IR 表达:

# High-level intermediate representation (IR) - Graph representationprint(symbolic_traced.graph)"""graph():    %x : [#users=1] = placeholder[target=x]    %param : [#users=1] = get_attr[target=param]    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})    %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})    return clamp"""

fx.GraphModule 内含的计算图可以再被转成 torch python 代码(也可以把计算图转成自定义的代码),即代码生成功能,比如下面就是上面 module 对应的 python 代码:

# Code generation - valid Python codeprint(symbolic_traced.code)"""def forward(self, x):    param = self.param    add = x + param;  x = param = None    linear = self.linear(add);  add = None    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None    return clamp"""

后文再分别讨论这三块主要功能。


3

图生成(Symbolic Trace)


fx 构图的方法是 symbolic trace. 可以理解为把假的输入传入 nn.Module 或者函数,执行假的输入时,不是真的执行,而是记录执行的操作路径(Symbolic Trace),最后形成个完整执行记录就是一个图。


symbolic_trace函数的输入是 root 和 concrete_args. root 是要 trace 的代码。concrete_args 是可选地,可以传入一些假的输入以特化 trace。


trace 功能默认是用 Tracer 的 trace 方法实现的。它实现了 trace 功能,返回一个 fx.Graph,然后用 fx.Graph 和原 root 构造一个 fx.GraphModule 并返回。

def symbolic_trace(    root: Union[torch.nn.Module, Callable[..., Any]],    concrete_args: Optional[Dict[str, Any]] = None,) -> GraphModule:    tracer = Tracer()    graph = tracer.trace(root, concrete_args)    name = (        root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__    )    return GraphModule(tracer.root, graph, name)

所以其实可以不用 symbolic_trace ,而是自己直接去调用 Tracer。如果需要自定义 trace 逻辑,其实可以用继承和改写 Tracer 的方式来改写 Tracer 的行为。


Tracer 的功能


Tracer 的主要方法是 trace,用来把输入的 nn.Module 或者函数转换成符号化的计算图(IR),trace 的实质就是随着值的传递把对应的操作记录下来。


trace 的机制依赖于把输入转换成抽象的值 Proxy,Proxy 起到代理 tensor 执行的作用。trace 的过程,即把 tensor 都转成 Proxy 在代码中传递,且 Proxy 可以输入常规的 torch 操作。


Proxy 输入常规的 torch 操作之所以可以工作,是依赖了 torch 下发操作的__torch_function__协议](https://github.com/pytorch/rfcs/blob/master/RFC-0001-torch-function-for-methods.md)。可以认为一个类型支持了 __torch_function__,就可以传给 torch 的常规函数去执行,而执行时调用的逻辑就在__torch_function__中定义。如此给 Proxy 的 __torch_function__ 定义好记录操作到图的逻辑,就可以完成 trace 功能(https://github.com/pytorch/pytorch/blob/de586001269fa04fa76ccc64964f676a25e120b2/torch/fx/proxy.py#L449)。


可以利用这个机制,实现一个极简的跟踪和打印 torch 操作的 ProxyTensor。对于加法,会把加法符号化:


import torch
class ProxyTensor(torch.Tensor): @classmethod def __torch_function__(cls, func, types, args=None, kwargs=None): if func.__name__ == 'add': print("\n=> torch function call:") print(f"==> function name: {func.__name__}") print(f"==> function args: ({', '.join((str(type(arg)) for arg in args))})") # 自定义张量相加的行为 result = args[0].symbolic() + " + " + args[1].symbolic() return result else: # 对于其他运算,使用默认行为 return super().__torch_function__(func, types, args=args, kwargs=kwargs)
def symbolic(self): return "tensor(" + str(self.shape) + ", " + str(self.dtype) +")"
# 创建自定义张量x = ProxyTensor([4, 5, 6])y = ProxyTensor([1, 2, 3])
result = x - yprint(f"minus result: {result}")
result = x + yprint(f"add result: {result}")

执行一个减法,是常规的 torch tensor 运算:

minus result: ProxyTensor([3., 3., 3.])

执行加法时,则会对操作、输入、输出做自定义操作,而非执行 tensor 运算:

=> torch function call:==> function name: add==> function args: (<class '__main__.ProxyTensor'>, <class '__main__.ProxyTensor'>)add result: tensor(torch.Size([3]), torch.float32) + tensor(torch.Size([3]), torch.float32)

Tracer 的 trace 实际也是在做类似的事情。首先把 nn.Module 或者函数的输入转换成 graph 中的 Node,然后把 Node 包装到 Proxy 里面,作为新的输入。之后 torch 的操作执行 Proxy 时,就会触发自定义的__torch_function__ 函数。Proxy 的自定义行为是把执行的操作记录为图中的 Node,然后操作把 Node 包装为 Proxy 作为操作结果去继续传递。如此便构造出了计算图。


另外值得考虑的是非内置的 nn.Module 和函数的嵌套调用,fx 中忽略了嵌套,所以 trace 到的都是 torch 内置的操作。如果你希望一个自定义操作为当做内置操作被 trace,可以使用 torch.fx.wrap 注册一下。


另外对于控制流、非 torch 内置操作,可以发现 trace 机制的局限性,他们会被 python 执行,但是 trace 不知道他们存在。所以 if 循环可能只记录了一个分支的执行,for 循环被展开,一个 python 的计算结果被当做常量传入torch内置操作。这是 trace 机制的局限性。


fx.GraphModule


trace 的返回结果是 fx.Graph,然后用其构造一个 fx.GraphModule 并返回。fx.GraphModule 继承自 nn.Module 所以其主要行为和 nn.Module 一致,特别的地方是它的 forward 是从 fx.Graph 生成的。另外它带有一个 graph 属性,用于获取其内部包含的计算图。还有个 code 属性,code 是 str 类型,是从 graph 生成的 python 文本代码,且 forward 方法是该文本代码经过编译得到的 。


symbolic_trace生成的 fx.GraphModule 通常当做普通的 nn.Module 使用即可,在使用时这么理解就够了。这个设计体现了 fx 良好的易用性。


自定义 Tracer


torch fx 还给 trace 的过程提供了自定义的空间,方法是继承和覆盖 Tracer。下面做下简介,通常使用是涉及不到的,所以可以忽略这部分。


有如下几种方法可以自定义:


  • create_node:Tracer 往 graph 中插入一个节点时都会调用它,它会返回一个 node,有如下 6 种类型的 node,这也是 trace 过程记录的基本单位;

    • placeholder,一般是整个被 trace 的 Module 或者函数的输入;

    • call_function,函数调用;

    • call_method,对象上的方法调用;

    • call_module,nn.Module 的调用;

    • get_attr,nn.Module 上的属性的获取;

    • output,一般是整个被 trace 的 Module 或者函数的输出;

  • create_proxy:如前文所示,所有操作( Operation)调用的输入、输出都是 Proxy,所以输入输出都会被转为 Proxy,转 Proxy 的过程都会调用 create_proxy 来实现。Proxy 对应上面 Node 多对应操作的抽象的返回结果,所以 Proxy 构造时会输入对应的 Node。

  • create_args_for_root:创建被 trace 的 Module 或者函数的输入;

  • create_arg:创建内部的函数的输入;

  • call_module:遇到一个 nn.Module 时调用 call_module 来触发对应的 node 创建等行为;

  • getattr:当从 nn.Module 上获取属性是会调用 getattr 来触发对应的 node 创建等行为;

以上方法的自定义和 Tracer 的行为耦合比较紧密,所以需要小心处理,结合 Tracer 的代码实现来做自定义。


4

图中间表达和图改写


上文中 trace torch 代码执行的过程,按 python 执行序去记录了操作序列,对于每个操作会生成一个 node,node 的类型是 fx.Node. 这些 node 总体形成了一个图,图的类型是 fx.Graph。


fx.Node 和 fx.Graph


fx.Node 和 fx.Graph 是 fx 中间表达的核心数据结构。上文的的例子中打印了一下 graph,可以看到一个完整的 graph 的文本表达,即中间表达。每一行对应一个 node(return 对应 output 类型的 node):

```python# High-level intermediate representation (IR) - Graph representationprint(symbolic_traced.graph)"""graph():    %x : [#users=1] = placeholder[target=x]    %param : [#users=1] = get_attr[target=param]    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})    %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})    return clamp"""```

fx.Graph 主要支持了一些图上的增删查改操作。支持用 nodes 属性获取图中所有 Node 的列表,支持用 create_node 添加一个新的 Node(也支持支持用 call_module、call_method 等语法糖直接添加特定类型的 Node),用 erase_node 删除一个 Node,用 inserting_after 或者 inserting_before 设置新 Node 的插入点,用 eliminate_dead_code 删除没有被使用的 Node,用 lint 做图结构的检查,用 on_generate_code 在代码生成时插入一些自定义操作。


fx.Node 代表图中的节点。fx.Node 的 op 属性可以获取 Node 的类型,上文创建 node 的部分也提到,有如下 6 种类型的 Node:


  • placeholder,一般是整个被 trace 的 Module 或者函数的输入;

  • call_function,函数调用;

  • call_method,对象上的方法调用;

  • call_module,nn.Module 的调用;

  • get_attr,nn.Module 上的属性的获取;

  • output,一般是整个被 trace 的 Module 或者函数的输出;


fx.Node 支持用 append 方法在该节点后面插入一个新 Node,支持用 prepend 在该节点前面插入一个新 Node,支持用 replace_all_uses_with 来把图中所有对本 Node 的依赖替换为一个新 Node,还支持一些其它的替换操作,支持用 format_node 来格式化打印一个 Node。


另外比较有价值的是 fx.Node target 属性记录了 node 对应的操作。对于 placeholder、output、call_method,target 是个普通的字符串名字;对于 call_function,target 是函数本身;而对于 call_module 和 get_attr,target 也是字符串,但是该字符串是查找对应 module 或者属性对象的 key,这里的设计不太好需要适应下,假设gm是 GraphModule 的实例,如下方法才能通过 key 找到 call_module 和 get_attr Node 对应的实例:

# node 为 call_module 时,其 Module 实例查找方法modules = gm.named_modules()module = modules[node.target]
# node 为 get_attr 时,其 attr 实例查找方法getattr(gm, node.target)

fx.Node 的 meta 属性里面包含了 node 相关的对象信息、代码调用栈信息。对象信息可以帮助拿到对象实例的值而代码调用栈可以帮助确认当前 node 对应的代码位置。这两个信息对于 Debug 非常有帮助。


图遍历模式


图遍历模式是最典型的图改写模式。可以用 fx.Graph.nodes 来获取图中节点并加以改写。如下是一个把 add 操作替换成 bitwise_and 操作的图改写例子。
import torchfrom torch.fx import symbolic_traceimport operator
# 定义一个普通的 moduleclass M(torch.nn.Module): def forward(self, x, y): return x + y, torch.add(x, y), x.add(y)
# trace 一下traced = symbolic_trace(M())
# 要匹配的 target 列表patterns = set([operator.add, torch.add, "add"])
# 遍历 fx.Graph 的 Node 列表并修改for n in traced.graph.nodes: # 如果当前 Node 的 target 符合 add if any(n.target == pattern for pattern in patterns): # 在当前 Node 的后面插入 bitwise_and Node with traced.graph.inserting_after(n): new_node = traced.graph.call_function(torch.bitwise_and, n.args, n.kwargs) n.replace_all_uses_with(new_node) # 清理掉过时的 Node traced.graph.erase_node(n)# 重新编译下 GraphMoudle# 根据新的图做代码生成,这样就得到了新的 GraphModule 了traced.recompile()

上面注释了一个典型的 graph 图遍历修改图的模式。更多例子可以参考这个链接


另外如果通常需要做一些复杂输入的通用处理,这时 map_aggregate 函数提供了对参数的通用变换工具函数。对于一个由 node 组成的 tuple/list/dict 等类型的输入,你可以提供一个 node 处理函数 fn 给 map_aggregate,然后 map_aggregate 返回一个和原输入 tuple/list/dict 同结构的输入,这个新的输入中的每个 node 都是被 fn 变换过的。该功能类似 oneflow 中的 ArgsTree.


Interpreter 模式


Interpreter 模式提供了一种边执行边修改图的模式。其实质是我们可以遍历图中的节点,且同时挨个执行图中的节点。上文也提到了,可以通过 Node.target 属性获得节点的实例,比如获取 nn.Module,然后执行该实例即可。这里提供了一个通过执行 Node 来记录 Node 的实际输出 tensor 的 shape 和 dtype 的例子ShapeProp。可以看到其核心是遍历 Node 和 执行 Node:

for node in self.graph.nodes:    if node.op == 'placeholder':        result = next(args_iter)    elif node.op == 'get_attr':        result = fetch_attr(node.target)    elif node.op == 'call_function':        # load_arg 可以获取实际的 tensor,然后输入 target 做 operator 的执行        result = node.target(*load_arg(node.args), **load_arg(node.kwargs))    elif node.op == 'call_method':        self_obj, *args = load_arg(node.args)        kwargs = load_arg(node.kwargs)        result = getattr(self_obj, node.target)(*args, **kwargs)    elif node.op == 'call_module':        result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))        if isinstance(result, torch.Tensor):        # 记录执行结果的 shape 和 dtype        node.shape = result.shape        node.dtype = result.dtype

Interpreter 模式提供了一个语法糖fx.Interpreter,它实现了上面的图遍历过程,然后支持重载不同 Node 类型的行为,如此可以自定义执行一个 Node 的逻辑。


fx.Interpreter 接受输入一个 fx.GraphModuel,然后用 run 方法来执行 GraphModule:

def fn(x):    return torch.sigmoid(x).neg()
gm = torch.fx.symbolic_trace(fn)input = torch.randn(3, 4)
class MyInterpreter(fx.Interpreter): pass
result = MyInterpreter(gm).run(input)

run 实际在遍历图中的 Node,然后对该 Node 调用 run_node 方法,而 run_node 方法则调用了各种类型的 Node 的执行方法:

run()    +-- run_node()        +-- placeholder()        +-- get_attr()        +-- call_function()        +-- call_method()        +-- call_module()        +-- output()

run_node 和各种类型的 node 的执行方法都可以重载。使用 Interpreter 来实现 ShapeProp,可以看出不用自己写图遍历了:

class ShapePropInterpreter(fx.Interpreter):    def run_node(self, n : Node) -> Any:        result = super().run_node(n)        if isinstance(result, torch.Tensor):            # 记录执行结果的 shape 和 dtype            n.shape = result.shape            n.dtype = result.dtype        return result
result = ShapePropInterpreter(gm).run(input)

另外也可以用 Interpreter 实现图改写的效果,如下就是把原来的 sigmoid 改成了 neg 操作:

class NegSigmSwapInterpreter(fx.Interpreter):    def call_function(self, target : Target,                      args : Tuple, kwargs : Dict) -> Any:        if target == torch.sigmoid:            # 这里传入的参数是实际值            return torch.neg(*args, **kwargs)        return super().call_function(n)
# 执行 Interpreterresult = NegSigmSwapInterpreter(gm).run(input)

Interpreter 可以边执行,边操作图。但是它的缺点是可以修改图的实际执行,但是不能改图结构。


Transformer 模式


Interpreter 模式另外一个缺点是它是即时执行的,也没有改变图的结构。如果想修改图的结构,就可以使用 Transformer 模式。


fx.Transformer 继承自 Interpreter,所以支持的重载接口类似。不同的是,它实际在做符号化的执行,且再创建一个新的图。


传入一个原始的 GraphModule 到 Transformer,然后调用 transform 方法。会创建一个新的图出来了,然后按原来图的节点顺序执行,返回结果会用于在新的图中创建一个新节点。最后就得到了一个新的 GraphModule。

class NegSigmSwapXformer(fx.Transformer):    def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:        if target == torch.sigmoid:            # 这里传入的参数是 Proxy            return torch.neg(*args, **kwargs)        return super().call_function(n)
# 得到了一个 sigmoid 被替换为 neg 操作的 GraphModuletransformed : torch.nn.Module = NegSigmSwapXformer(gm).transform()

fx.Transformer 提供了一种便捷的 Node 到 Node 图改写方式。


5

Python 代码生成


Python 代码生成在 GraphModule 的 recompile 方法调用时触发。它是 fx 的内部行为,使用时通常不用关注,这里介绍下主要实现技巧。


python 代码生成在做的事情就是把 graph 转成 code:

# High-level intermediate representation (IR) - Graph representationprint(symbolic_traced.graph)"""graph():    %x : [#users=1] = placeholder[target=x]    %param : [#users=1] = get_attr[target=param]    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})    %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})    return clamp"""print(symbolic_traced.code)"""def forward(self, x):    param = self.param    add = x + param;  x = param = None    linear = self.linear(add);  add = None    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None    return clamp"""

这转换的过程是一对一的,一个 Node 会被转换成对应的 Python 代码。其核心函数是 fx.Graph 中的 emit_node 函数,以 call_method Node 为例:

elif node.op == 'call_method':    assert isinstance(node.target, str)    body.append(        f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'        f'({_format_args(node.args[1:], node.kwargs)})')

上面的方法根据 node 中信息,在 body 中添加了 python 文本代码,把 Node 信息:

%clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})

转换成了 Python 代码:

clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None

这部分 emit_node 逻辑是在 fx.Graph 的 python_code 函数中调用的,python_code 函数返回一个 python_code 对象,包含了 python 源代码和全局对象数据。然后执行如下操作,就把图生成的代码赋值给了 GraphModule。

# 生成 python 代码对象,graph 对应一个 fx.Graphpython_code = graph.python_code(root_module='self')# python 代码文本code = python_code.src# python 代码全局对象globals = python_code.globals# 使用 python 字节码编译器编译和加载 python 代码exec(compile(code, key, 'exec'), globals)# 从中获取编译好的总函数 forwardforward_fn = globals_copy['forward']

最后用 forward_fn 替换 GraphModule 的 forward 的方法,就得到了一个和 graph 中执行逻辑相同的 GraphModule 了。


6

torch.fx 和 torch.compile


在 torch 2.0 的 torch.compile (TorchDynamo) 功能下,一个函数或者 nn.Module输入 torch.compile 编译时,可以自定义一个编译器后端。


如下的custom_backend即自定义的编译逻辑。torch.compile 会把对应的 torch 代码 trace 成 fx.GraphModule 对象,然后传入 custom_backend 函数,这样你就可以根据 fx.GraphModule 自定义编译逻辑,生成一个自定义的函数,返回给 torch.compile。下面例子中的 opt_model 第一次执行时,会触发custom_backend 执行,获取一个自定义的函数(经过编译优化的函数)并缓存下来,后面执行时,就可以直接使用编译优化的函数做执行,达到优化执行的效果。

from typing import Listdef custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):    print("custom backend called with FX graph:")    print(gm.graph)    return gm.forward
opt_model = torch.compile(init_model(), backend=custom_backend)


7

总结


torch.fx 是 PyTorch 官方发布的 Python 到 Python 的代码变换工具。它提供了 trace 代码生成图、改写图、再生成新的 Python 代码的工具。灵活性和易用性都很高。本文介绍了其核心功能和一些实践技巧。


OneFlow 利用 torch.fx 和 torch.compile 做 Torch 代码到 OneFlow 代码的转换工作,以更简单的编译和加速 Torch 代码。


参考

[1]. torch fx 官方文档. https://pytorch.org/docs/stable/fx.html

[2]. torch.fx: Practical Program Capture and Transformation for Deep Learning in Python. https://arxiv.org/pdf/2112.08429.pdf

[3]. torch fx 应用于将 torch 转成 oneflow. https://github.com/Oneflow-Inc/diffusers/pull/237

[4]. 适配PyTorch FX,OneFlow让量化感知训练更简单


其他人都在看

试用OneFlow: github.com/Oneflow-Inc/oneflow/