A Complete Look at the FX Tools Used by Meta: Optimizing PyTorch Models With Graph Transformation

Graph mode in PyTorch generally delivers better performance than eager mode. This article introduces Torch.FX, a powerful tool that can capture and optimize the graphs of PyTorch programs.
1. Introduction
PyTorch supports two execution modes: eager mode and graph mode.
In eager mode, operators in the model are executed immediately as they are encountered. Eager mode is easy to use and friendlier to machine learning practitioners, which is why it is the default execution mode.
In graph mode, operators are first synthesized into a graph, which is then compiled and executed as a whole. Graph mode offers higher performance and is therefore widely used in production.
Specifically, graph mode enables operator fusion, in which one operator is merged with another to reduce or localize memory reads and to reduce the total kernel launch overhead.
Fusion can be horizontal: a single operation (such as BatchNorm) that is applied independently to many operands is merged into one operation over those operands combined into a single array.
Fusion can also be vertical: a kernel is merged with another kernel that consumes the output of the first one (such as a convolution followed by ReLU).
Torch.FX (abbreviated as FX) is a publicly available toolkit, shipped as part of the PyTorch package, that supports graph-mode execution. It can:
1. Capture the graph of a PyTorch program
2. Let developers write transformations on the captured graph
Meta has previously used FX to optimize the training throughput of production models. This article introduces FX-based optimizations developed at Meta to show how graph transformations can be used to optimize the performance of deployed PyTorch models.
2. Background
Embedding tables are widely used in recommendation systems. This section introduces background on FX and on embedding tables.
2.1 FX
Figure 1 is a simple example of transforming a PyTorch program with FX. It consists of three steps:
- Capture the graph from the program
- Modify the graph (in this case, replacing ReLU with GELU)
- Generate a new program from the modified graph
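The three steps can be sketched with the public `torch.fx` API. This is a minimal illustration under stated assumptions, not Meta's code: the toy module `M` and the helper name `relu_to_gelu` are hypothetical, and the pass only rewrites calls to `torch.relu` (a `torch.nn.ReLU` submodule would show up as a `call_module` node and would need separate handling).

```python
import torch
import torch.fx


class M(torch.nn.Module):
    """Hypothetical toy model: a linear layer followed by torch.relu."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


def relu_to_gelu(model: torch.nn.Module) -> torch.fx.GraphModule:
    # Step 1: capture the graph by symbolically tracing the module.
    gm = torch.fx.symbolic_trace(model)
    # Step 2: modify the graph by retargeting every torch.relu call to GELU.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch.relu:
            node.target = torch.nn.functional.gelu
    gm.graph.lint()  # sanity-check the edited graph
    # Step 3: regenerate the Python code (and forward()) from the modified graph.
    gm.recompile()
    return gm


model = M()
gm = relu_to_gelu(model)
print(gm.code)  # the generated forward() now calls gelu instead of relu
```

Because FX rewrites the graph rather than the source, the original module is left untouched and the returned `GraphModule` can be used as a drop-in replacement.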

The FX API provides many other functions for inspecting and transforming PyTorch program graphs.
2.2 Embedding Tables

In a recommendation system, sparse features (e.g., User ID, Story ID) are represented by embedding tables.
An embedding table E is an H×D matrix, where H is the hash size and D is the embedding dimension. Each row of E is a vector of floating-point numbers.
Feature hashing maps a sparse feature to a list of indices into E, such as [S1, S2, …, Sk], where 0 ≤ Si < H. The lookup output for the feature is obtained by pooling the rows E[S1], E[S2], …, E[Sk] into a single D-dimensional vector, typically by taking their sum, average, or maximum.
To fully utilize the GPU, sparse features are usually processed in batches. Each entity in a batch has its own list of indices, so a batch of B entities can naively be represented as B lists of indices.
A more compact representation merges the B lists into a single list of indices and adds a list of lengths (one length for each entity in the batch).
For example, suppose a batch contains 3 entities with the following index lists:
- Entity 1: indices = [10, 20]
- Entity 2: indices = [5, 9, 77, 81]
- Entity 3: indices = [15, 20, 45]
Then the indices and lengths for the whole batch are:
- Indices = [10, 20, 5, 9, 77, 81, 15, 20, 45]
- Lengths = [2, 4, 3]
The output of the embedding table lookup for the entire batch is a B×D matrix (one pooled D-dimensional vector per entity).
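The indices/lengths representation maps naturally onto PyTorch's `torch.nn.functional.embedding_bag`, which takes a flat index tensor plus the start offset of each bag. The sketch below assumes sum pooling and small made-up sizes (H = 100, D = 4); production systems typically use their own fused embedding operators, so treat this only as an illustration of the data layout.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
H, D = 100, 4          # illustrative hash size and embedding dimension
E = torch.randn(H, D)  # embedding table: one D-dim row per hashed index

# Batch of B = 3 entities in the merged indices/lengths form from the example.
indices = torch.tensor([10, 20, 5, 9, 77, 81, 15, 20, 45])
lengths = torch.tensor([2, 4, 3])

# embedding_bag expects the start offset of each bag: lengths [2, 4, 3] -> [0, 2, 6].
offsets = torch.cat([torch.zeros(1, dtype=torch.long), torch.cumsum(lengths, 0)[:-1]])

# Each entity's rows are pooled (summed here) into one D-dim vector, giving B x D.
out = F.embedding_bag(indices, E, offsets, mode="sum")
print(out.shape)  # torch.Size([3, 4])
```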
3. Three FX Transformations
Meta developed three FX transformations that accelerate accesses to embedding tables; this section introduces them one by one.
Section 3.1 covers a transformation that combines multiple small input tensors into a single large tensor, Section 3.2 covers fusing multiple parallel computation chains into a single chain, and Section 3.3 covers overlapping communication with computation.
3.1 Combining Input Sparse Features
Each input sparse feature in a batch can be represented as two lists: a list of indices and a list of B lengths, where B is the batch size.
In PyTorch, both lists are stored as tensors. When a PyTorch model runs on a GPU, the embedding tables are usually stored in GPU memory, which is closer to the GPU and offers higher read/write bandwidth than CPU memory.
To use an input sparse feature, both of its tensors must first be copied from the CPU to the GPU. However, each host-to-device memory copy requires a kernel launch, which is relatively expensive compared with the actual data transfer time.
If a model uses many input sparse features, these copies can become a performance bottleneck (e.g., 1,000 input sparse features require copying 2,000 tensors from host to device).
One optimization to reduce the number of host-to-device memcpys is to combine multiple input sparse features before sending them to the device.
For example, given the following three input features:
- Feature_A: indices = [106, 211, 7], lengths = [2, 1]
- Feature_B: indices = [52, 498, 616, 870, 1013], lengths = [3, 2]
- Feature_C: indices = [2011, 19, 351, 790], lengths = [1, 3]
The combined form is:
Features_A_B_C: indices = [106, 211, 7, 52, 498, 616, 870, 1013, 2011, 19, 351, 790], lengths = [2, 1, 3, 2, 1, 3]
So instead of copying 3×2=6 tensors from host to device, only 2 tensors need to be copied.
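On the CPU side, combining the features amounts to concatenation. The sketch below uses the three example features; the helper name `combine_features` is hypothetical, and it assumes the merged layout is feature by feature, exactly as in the combined form above.

```python
import torch


def combine_features(features):
    """Hypothetical helper: merge per-feature (indices, lengths) pairs into two tensors."""
    indices = torch.cat([idx for idx, _ in features])
    lengths = torch.cat([lens for _, lens in features])
    return indices, lengths


features = [
    (torch.tensor([106, 211, 7]), torch.tensor([2, 1])),              # Feature_A
    (torch.tensor([52, 498, 616, 870, 1013]), torch.tensor([3, 2])),  # Feature_B
    (torch.tensor([2011, 19, 351, 790]), torch.tensor([1, 3])),       # Feature_C
]
indices, lengths = combine_features(features)

# Two host-to-device copies instead of 2 * len(features) = 6.
device = "cuda" if torch.cuda.is_available() else "cpu"
indices_dev = indices.to(device)
lengths_dev = lengths.to(device)
```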
Figure 3(b) depicts the implementation of this optimization, which consists of two components:
- CPU side: The input pipeline is modified to combine the indices of all sparse features into one tensor and all the lengths into another tensor. These two tensors are then copied to the GPU.
- GPU side: Using FX, a Permute_and_Split operator is inserted into the model graph to recover the individual feature index and length tensors from the merged tensors and route them to the corresponding downstream nodes.
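The recovery step can be approximated with `torch.split`. This is a simplified stand-in for Meta's Permute_and_Split operator, whose implementation is not given here: the function name `split_combined_features` is hypothetical, and it assumes the merged tensors are laid out feature by feature, so no permutation is needed. Under that assumption, each feature owns exactly B lengths, and its index count is the sum of those lengths.

```python
import torch


def split_combined_features(indices, lengths, num_features, batch_size):
    """Hypothetical, simplified Permute_and_Split: undo a feature-major merge."""
    # Every feature contributes exactly batch_size lengths.
    per_feature_lengths = lengths.split(batch_size)
    # A feature's number of indices is the sum of its lengths.
    index_sizes = lengths.view(num_features, batch_size).sum(dim=1).tolist()
    per_feature_indices = indices.split(index_sizes)
    return list(zip(per_feature_indices, per_feature_lengths))


indices = torch.tensor([106, 211, 7, 52, 498, 616, 870, 1013, 2011, 19, 351, 790])
lengths = torch.tensor([2, 1, 3, 2, 1, 3])
features = split_combined_features(indices, lengths, num_features=3, batch_size=2)
for idx, lens in features:
    print(idx.tolist(), lens.tolist())
```

Note that `.tolist()` on a GPU tensor forces a device-to-host synchronization; a production operator would compute the split sizes without that round trip.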


3.2 Horizontally Fusing Computation Chains That Start with Embedding Table Accesses
In a production model, it is common for each GPU to hold tens of embedding tables. For performance reasons, lookups into these tables are grouped together so that their outputs are concatenated into one large tensor (see the red part in Figure 4(a)).
To apply computations to individual feature outputs, a Split operator divides the large tensor into N smaller tensors (where N is the number of features), and the desired computation is then applied to each of them.
As shown in Figure 4(a), the computation applied to each feature output O is Tanh(LayerNorm(O)). All results are concatenated into one large tensor, which is then passed to the downstream operator (Op1 in Figure 4(a)).
The main runtime cost here is GPU kernel launch overhead. For example, Figure 4(a) launches 2N + 3 GPU kernels (each ellipse in the figure represents a GPU kernel). This hurts performance because LayerNorm and Tanh execute on the GPU in a very short time compared with their kernel launch time.
Additionally, the Split operator may create an extra copy of the embedding output tensor, consuming additional GPU memory.
Using FX to implement an optimization called horizontal fusion greatly reduces the number of GPU kernel launches (in this example, only 5 kernels are launched after the optimization; see Figure 4(b)).
Instead of an explicit Split, the Add_middle_dim operator reshapes the 2D embedding tensor of shape (B, N×D) into a 3D tensor of shape (B, N, D). Next, a single LayerNorm is applied to its last dimension, and a Tanh is applied to the LayerNorm result. Finally, the Remove_middle_dim operator restores the Tanh output to a 2D tensor.
Since Add_middle_dim and Remove_middle_dim only reshape the tensor, no additional copies are created, so GPU memory consumption is also reduced.
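The equivalence behind this rewrite can be checked numerically in eager PyTorch. The sketch below compares the per-feature chain (Split, then LayerNorm and Tanh on each piece, then concatenation) with the reshape-based fused version. It is plain tensor code rather than the FX pass itself; the function names and sizes are illustrative, and it assumes each feature has its own LayerNorm affine parameters, which the fused version applies by broadcasting an (N, D) weight and bias over the batch.

```python
import torch
import torch.nn.functional as F


def unfused(x, weight, bias):
    n, d = weight.shape
    # Split into N pieces of shape (B, D); two small kernels per piece; then Cat.
    outs = [
        torch.tanh(F.layer_norm(chunk, (d,), weight[i], bias[i]))
        for i, chunk in enumerate(torch.split(x, d, dim=1))
    ]
    return torch.cat(outs, dim=1)


def fused(x, weight, bias):
    n, d = weight.shape
    b = x.shape[0]
    y = x.view(b, n, d)                # Add_middle_dim: (B, N*D) -> (B, N, D), no copy
    y = F.layer_norm(y, (d,))          # one LayerNorm over the last dim for all features
    # Per-feature affine params broadcast over B; the op count no longer grows with N.
    y = torch.tanh(y * weight + bias)
    return y.reshape(b, n * d)         # Remove_middle_dim: back to (B, N*D)


torch.manual_seed(0)
B, N, D = 8, 10, 16                    # batch size, number of features, embedding dim
x = torch.randn(B, N * D)              # concatenated embedding outputs
weight, bias = torch.randn(N, D), torch.randn(N, D)
print(torch.allclose(unfused(x, weight, bias), fused(x, weight, bias), atol=1e-5))  # True
```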


3.3 Overlapping Computation with Communication
Production recommendation models are usually trained on distributed GPU systems. Because the device memory of a single GPU is not large enough to hold all of a model's embedding tables, the tables must be distributed across multiple GPUs.
During a training step, a GPU needs to read and write feature values from and to embedding tables on other GPUs. This is known as all-to-all communication, and it can be a major performance bottleneck.
An FX transformation makes it possible to overlap computation with this all-to-all communication. Figure 5(a) shows an example model graph with embedding table accesses (EmbeddingAllToAll) and other operators. As shown in Figure 5(b), without any optimization they execute sequentially on a single GPU stream.
Using FX, EmbeddingAllToAll is split into EmbeddingAllToAll_Request and EmbeddingAllToAll_Wait, and independent operators are scheduled between the two so that they execute while the all-to-all communication is in progress.
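The reordering can be illustrated on a toy dependency graph. The sketch below is neither FX code nor Meta's implementation: the `Op` class, the `overlap_schedule` helper, and the operator names are hypothetical. Given operators in a valid topological order, it replaces the communication operator with a request/wait pair and moves every later operator that does not (transitively) depend on the communication result between them.

```python
from dataclasses import dataclass, field


@dataclass
class Op:
    """Hypothetical toy node: an operator name plus the names of its inputs."""
    name: str
    deps: list = field(default_factory=list)


def overlap_schedule(ops, comm):
    """Split `comm` into Request/Wait and hoist independent ops between them.

    `ops` must be in topological order.
    """
    names = [op.name for op in ops]
    pos = names.index(comm)
    # Collect ops that transitively depend on the communication result.
    dependent = {comm}
    for op in ops[pos + 1:]:
        if dependent.intersection(op.deps):
            dependent.add(op.name)
    later = ops[pos + 1:]
    independent = [op.name for op in later if op.name not in dependent]
    dependents = [op.name for op in later if op.name in dependent]
    return (
        names[:pos]
        + [f"{comm}_Request"]
        + independent        # these run while the all-to-all is in flight
        + [f"{comm}_Wait"]
        + dependents         # these need the communication result
    )


graph = [
    Op("Input"),
    Op("EmbeddingAllToAll", ["Input"]),
    Op("Op1", ["Input"]),
    Op("Op2", ["Op1"]),
    Op("Op3", ["EmbeddingAllToAll", "Op2"]),
]
print(overlap_schedule(graph, "EmbeddingAllToAll"))
# ['Input', 'EmbeddingAllToAll_Request', 'Op1', 'Op2', 'EmbeddingAllToAll_Wait', 'Op3']
```

Keeping the relative order within each group preserves a valid topological order: an operator classified as independent cannot consume the output of a dependent one. A real FX pass would run the same analysis over `gm.graph.nodes`, using each node's `all_input_nodes`.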

3.4 Summary

To find out which models would benefit from these transformations, the developers analyzed performance data collected by MAIProf from models running in Meta's data centers. Together, these transformations achieve a 2-3x speedup over eager mode on a set of production models.
4. Conclusion
From a performance perspective, graph mode in PyTorch is preferable to eager mode in production environments. FX is a powerful tool for capturing and optimizing the graphs of PyTorch programs. This article presented three FX transformations that are used to optimize a production recommendation model at Meta.
Finally, I hope more PyTorch developers will use graph transformations to improve the performance of their models.