Deep Learning: A Primer on Distributed Training — Part 1

Shivam Bharuka
Jan 11, 2022

Introduction

Deep learning has gained tremendous popularity in the last decade due to breakthroughs in a wide range of tasks such as language translation, image classification, and speech recognition. In deep learning, a learning model, designed roughly like our brain, uses an iterative process to learn how to perform these complex tasks well. Artificial neural networks are a popular technique used to build these deep learning models.

An artificial neural network consists of a collection of nodes that aim to capture the relationships in the input data and combine them to produce a desired output. Since these relationships are difficult to model by hand, the network learns them from data. It goes through a training phase where it uses a set of known input and output pairs to learn the relationship between them. Once a neural network has learned this relationship, it can process unseen inputs and make accurate output predictions. For instance, an artificial neural network can be trained on a large set of images labeled as containing a cat or not (Figure 1), so that it learns the relationship between the features of an image and whether a cat is present in it.

Figure 1: A Simple Neural Network to determine whether a cat is present in an image or not.

A powerful attribute of artificial neural networks is that they can provide better predictions with increasing training data and network size. However, these gains do not come for free: larger datasets and networks demand more hardware resources during training to keep execution time manageable. This calls for advanced parallelization techniques that use the allocated resources efficiently and train quickly.

Over the past few years, we have seen an explosion of work from both industry and academia on how to parallelize training for these large models. In this note, I aim to detail the different scaling techniques and help you develop an intuition for the problem space. Before we jump into it, let us quickly look at what a typical machine learning workload looks like and the challenges we face in training it efficiently.

Artificial Neural Network

Most commonly, artificial neural networks are composed of layers of nodes connected such that the output of one node is sent as input to nodes in the next layer. A neural network contains an input layer, multiple “hidden” layers, and an output layer. As shown in Figure 2, the input layer extracts features from the input data, which are passed through the hidden layers, and finally the output layer provides the target prediction. Hidden layers are responsible for deriving the relationships between the input features. As a matter of fact, deep learning refers to neural networks that contain multiple (i.e. deep) hidden layers and are able to model complex relationships in the input data.

Figure 2: Illustration of a node in an Artificial Neural Network

A single node in a neural network contains an activation function, a bias, and a weight for each of its incoming connections from other nodes. During execution, it takes a set of inputs, multiplies each by the corresponding weight, sums them, adds the bias, and runs the result through the activation function to calculate the output. A neural network needs to tailor its parameters (weights and biases) so that it provides accurate predictions for a given input. To achieve this, it goes through the training phase, where it adjusts the model parameters based on a dataset that contains samples with known input and output pairs. During training, each sample is fed through the neural network layers and the difference between the predicted and desired output is calculated. Based on this difference, the neural network adjusts its parameters using a learning rule so that the next time the model sees the same input, it produces a more accurate output. We will examine the learning rule used during the training phase in the next section.

Model Training

Neural networks are trained over multiple passes through the dataset called epochs. In each epoch, the model is trained on all the samples in the training dataset. An epoch contains multiple training steps, each of which is broken down into three phases: (i) Forward Pass, (ii) Backward Pass, and (iii) Optimizer Pass.

Figure 3: Representation of a training step

Training Step. The forward pass takes a training sample, feeds it through the model, and generates predictions. Since the training dataset also contains the expected output (called the label), the forward pass calculates the loss between the predicted and expected output using a cost function. The cost function measures the efficacy of the model, that is, how much the predicted value differs from the expected value. The backward pass traverses the layers in reverse order and uses the loss to compute a gradient for each model parameter. The gradient indicates how the loss changes as that parameter changes, so adjusting the parameter in the opposite direction reduces the loss between the expected and the predicted output. The optimizer pass uses the mean gradient as the delta to update the model parameters. Iterating over these three phases refines the parameters of the model so that it provides accurate predictions. Figure 3 shows this combination of phases in a training step.
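The three phases can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not how any framework implements a training step: it assumes a toy linear model `w * x + b`, a mean-squared-error cost function, and a plain gradient-descent update with learning rate `lr`. The name `training_step` is hypothetical.

```python
from typing import List, Tuple


def training_step(w: float, b: float, xs: List[float], ys: List[float],
                  lr: float) -> Tuple[float, float, float]:
    """One training step for a hypothetical toy model y_hat = w * x + b."""
    n = len(xs)
    # Forward pass: predictions, then the mean-squared-error loss against the labels.
    preds = [w * x + b for x in xs]
    loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / n
    # Backward pass: gradient of the loss w.r.t. each parameter, averaged over samples.
    grad_w = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / n
    grad_b = sum(2 * (p - y) for p, y in zip(preds, ys)) / n
    # Optimizer pass: move each parameter against its gradient.
    return w - lr * grad_w, b - lr * grad_b, loss


xs, ys = [0.0, 1.0, 2.0, 3.0], [1.0, 3.0, 5.0, 7.0]  # labels follow y = 2x + 1
w, b = 0.0, 0.0
for _ in range(500):
    w, b, loss = training_step(w, b, xs, ys, lr=0.05)
print(round(w, 3), round(b, 3))  # close to 2.0 and 1.0
```

Repeating the step drives the loss down; with Mini-Batch GD, `xs` and `ys` would hold one mini-batch rather than the whole dataset.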

Figure 4: Gradient Descent Algorithm

Gradient Descent. Let us demystify what it means to calculate the gradient and how it helps minimize the loss between the predicted and the actual output of a neural network. Gradient Descent (GD) is an optimization algorithm that finds a (possibly local) minimum of a function. As shown in Figure 4, we calculate the slope (also called the gradient) of the function at the current point and then take a step in the opposite direction, moving toward a minimum. In the case of neural networks, we use gradient descent to minimize the cost function by incrementally updating the model parameters with gradients. Once the loss is within an error range around the minimum, we can say that the model has converged and assume that additional training won’t improve its accuracy. There are multiple ways in which the gradient can be applied to the model parameters, including:

  • Batch Gradient Descent: Each epoch contains a single training step where the loss with respect to every sample in the input dataset is calculated. The gradients for all the samples are averaged together and used to update the model parameters. Though this approach requires few updates to the model parameters, it requires us to iterate through all the samples in the dataset before taking a single training step. This can be very expensive for large datasets and requires accumulating gradients across all the samples. Moreover, infrequent updates can also lead to sub-optimal convergence.
  • Stochastic Gradient Descent: Unlike Batch GD, where all the samples are processed in a single training step, Stochastic GD considers only a single sample in each training step. This means that each epoch contains as many steps as there are samples in the training dataset, and the model parameters are updated after every sample. This approach is computationally expensive due to the frequent updates and adds a lot of noise to the learning curve.
  • Mini-Batch Gradient Descent: As we have seen, both Batch GD and Stochastic GD have disadvantages in terms of efficiency and model quality. We find a sweet spot by splitting the training dataset into batches of samples (each batch is called a “mini-batch”) which are processed together in a single training step. The gradients for the samples in the mini-batch are averaged and used as the delta to update the model parameters. It converges faster than Batch GD due to more frequent parameter updates and strikes a better balance between memory use and computational efficiency than the other two variants. Batch size is an important hyperparameter that needs to be tuned, since a large value can lead to slow convergence whereas a small value can lead to inefficient resource utilization. We will be using Mini-Batch GD based training in this note due to its dominant use in deep learning models.
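The three variants differ only in how many samples contribute to each update. The sketch below, which assumes a toy model `y_hat = w * x` with noiseless labels and a hand-picked learning rate, runs the same gradient-descent loop with three batch sizes: the full dataset (Batch GD), one sample (Stochastic GD), and something in between (Mini-Batch GD). `gd_epoch` is a hypothetical helper, not a library function.

```python
import random
from typing import List, Tuple


def gd_epoch(w: float, data: List[Tuple[float, float]], batch_size: int,
             lr: float, rng: random.Random) -> Tuple[float, int]:
    """One epoch of gradient descent for the toy model y_hat = w * x under MSE.

    batch_size == len(data) gives Batch GD, 1 gives Stochastic GD, and anything
    in between gives Mini-Batch GD. Returns the new weight and the step count.
    """
    order = list(range(len(data)))
    rng.shuffle(order)  # visit samples in a fresh random order every epoch
    steps = 0
    for start in range(0, len(order), batch_size):
        batch = [data[i] for i in order[start:start + batch_size]]
        # Mean gradient of (w*x - y)^2 over the batch.
        grad = sum(2 * (w * x - y) * x for x, y in batch) / len(batch)
        w -= lr * grad  # one parameter update per batch
        steps += 1
    return w, steps


data = [(i / 100, 3.0 * i / 100) for i in range(100)]  # y = 3x on [0, 1)
results = {}
for name, bs in [("Batch GD", len(data)), ("Stochastic GD", 1), ("Mini-Batch GD", 10)]:
    rng, w = random.Random(0), 0.0
    for _ in range(20):  # 20 epochs
        w, steps = gd_epoch(w, data, bs, lr=0.5, rng=rng)
    results[name] = (steps, w)
    print(f"{name}: {steps} steps per epoch, w = {w:.3f}")
```

With this setup, Batch GD takes one update per epoch and ends the 20 epochs slightly further from w = 3 than the other two variants, which take 10 and 100 updates per epoch.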

Resource Requirement.

Before we jump into scalability and its associated challenges, it is important to understand the characteristics of training workloads in terms of the hardware resources required for execution:

Compute: The total computation required for model training is the cost of running a training step for each mini-batch of data across all epochs. This includes the cost of sequentially running the forward, backward, and optimizer passes for every mini-batch. Within a training step, the forward and backward passes for the samples inside a mini-batch can be parallelized, since the samples are processed independently until their gradients are averaged before updating the model parameters. To accelerate computation, the input data, weights, and biases for each layer are arranged in matrices. Instead of performing the computation for each node in a neural network layer individually, matrix multiplication lets us compute an entire layer for many samples at once, increasing training throughput. Moreover, specialized accelerators and compute units such as GPUs are designed to perform fast matrix computations.

Figure 5 shows an example matrix representation of a 4-layer neural network composed of an input layer, two hidden layers, and an output layer. Each layer is represented by a matrix where the number of rows equals the number of input samples being trained in parallel and the number of columns equals the number of nodes in that layer. The dimensions of the weight matrix between two layers are determined by the number of nodes in the two connected layers. Finally, each row of the output layer matrix holds the prediction for the corresponding input sample from the training dataset.

Figure 5: Matrix Representation of a 4-Layer Neural Network
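A minimal sketch of the matrix form of one layer, in plain Python so that it runs without dependencies: a batch matrix `x` (samples × input nodes) is multiplied by a weight matrix `w` (input nodes × output nodes), the bias is added to every row, and an activation is applied. The choice of ReLU as the activation and the helper names `matmul` and `dense_layer` are assumptions for illustration; in practice this multiplication is handed to optimized GPU kernels.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """(rows x k) @ (k x cols) -> (rows x cols)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def dense_layer(x: Matrix, w: Matrix, bias: List[float]) -> Matrix:
    """One layer for a whole batch: relu(X @ W + b).

    x: batch_size x n_in, w: n_in x n_out, bias: n_out.
    """
    z = matmul(x, w)
    return [[max(0.0, v + bj) for v, bj in zip(row, bias)] for row in z]


# Two samples with three input features, feeding a layer of two nodes.
x = [[1.0, 2.0, 3.0],
     [0.0, 1.0, 0.0]]
w = [[0.1, -0.2],
     [0.0, 0.5],
     [0.2, 0.1]]
out = dense_layer(x, w, bias=[0.0, -0.5])
print(len(out), len(out[0]))  # 2 rows (samples) x 2 columns (nodes)
```

Each output row corresponds to one input sample, which is why stacking more samples into `x` lets one multiplication process the whole micro-batch.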

Memory: The memory required for training includes storage cost of both (i) permanent states which includes model parameters and optimizer states and (ii) intermediate states which includes gradients and activations.

  • Model Parameters. As mentioned before, model parameters are the learnable weights assigned to each incoming edge and the bias associated with each neural network node. They are updated at the end of every training step in the optimizer pass. After training finishes, the model parameters are used to predict outputs for unseen inputs.
  • Optimizer States. Many optimizers used in deep learning maintain additional per-parameter state (e.g. the momentum and variance estimates kept by Adam) to improve model convergence. During the optimizer pass, these states are updated using the gradients from the backward pass and then used to calculate the delta for the model parameters. Like the model parameters, optimizer states must be stored throughout the training lifecycle.
  • Gradients. Gradients are calculated for each model parameter during the backward pass of a training step and can be discarded after the optimizer states and parameters are updated.
  • Activations. Activations are the intermediate output values of each node, which serve as inputs to the nodes in the next layer. They are calculated one layer at a time during the forward pass and must be kept until the backward pass, where they are needed to calculate the gradients for each model parameter.
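To get a feel for how these four categories add up, here is a back-of-the-envelope estimator. It assumes 4-byte (fp32) values, an Adam-style optimizer with two states per parameter, and one stored activation per node per sample; real frameworks usually keep more intermediate values than this, so treat the activation figure as a lower bound. The function name and the fully connected layer widths in the example are hypothetical.

```python
from typing import Dict, List


def training_memory_bytes(layer_widths: List[int], batch_size: int,
                          bytes_per_value: int = 4,
                          optimizer_states_per_param: int = 2) -> Dict[str, int]:
    """Rough memory estimate for training a fully connected network."""
    # Weights between consecutive layers plus one bias per non-input node.
    n_params = sum(a * b + b for a, b in zip(layer_widths, layer_widths[1:]))
    return {
        # Permanent state: kept for the whole training run.
        "parameters": n_params * bytes_per_value,
        "optimizer_states": n_params * optimizer_states_per_param * bytes_per_value,
        # Intermediate state: recreated every training step.
        "gradients": n_params * bytes_per_value,
        "activations": batch_size * sum(layer_widths) * bytes_per_value,
    }


for key, size in training_memory_bytes([784, 512, 512, 10], batch_size=256).items():
    print(f"{key}: {size / 2**20:.2f} MiB")
```

Activations scale with the batch size, while the other three categories scale with the parameter count.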

Scaling and Challenges

Recent work has shown that increasing the size of the training dataset and the model complexity helps capture complex relationships in data. In other words, training on more data or using a more complex model tends to increase accuracy. But scaling up either one also increases the execution time of the training phase. Thus, it is desirable to improve training efficiency in order to support this growth and achieve higher accuracy.

There are multiple approaches at the ML modeling, software, infrastructure, and hardware levels that aim to make training faster. Modeling techniques such as compression and approximate mathematics are used to reduce the complexity of the training workload. Software approaches take inspiration from traditional distributed systems and use multiple machines to parallelize the workload. Lastly, hardware approaches focus on deploying more powerful hardware to decrease execution latency. In this note, we will focus on software-level approaches that use advanced parallelization and resource optimization techniques to scale model training. Let’s start by listing the challenges that come with growth in input data and model size.

Growth in Training Dataset.

Input data is split into batches which are fed through the model iteratively during training. Growth in input data means that the model needs to perform more training steps per epoch, each with its own forward and backward pass, proportionally increasing the training time. A common technique to accelerate training is to replicate the model across multiple machines and train on multiple batches of data in parallel. This technique is called data parallelism.

Growth in Model Complexity.

Growth in model complexity corresponds to an increase in the number of model parameters and in computational complexity. The number of hidden layers and the number of nodes in each layer determine the number of parameters in a model. Similarly, the number of floating point operations required to run a training step for a single sample defines its computational complexity. Research has shown tremendous gains in accuracy with increasing model size and computational complexity.

An increase in model complexity in turn increases the computational resources needed to train on a single batch of data. Large image recognition and language translation models have grown beyond the memory of a single machine and require multiple machines for training. The technique of splitting a model across multiple machines is called model parallelism.
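The link between layer sizes and model complexity can be made concrete for a fully connected network. The sketch counts parameters and estimates forward-pass FLOPs per sample as two operations (a multiply and an add) per weight. The factor of 3 for a full training step reflects the common rule of thumb that the backward pass costs about twice the forward pass, which is an approximation rather than an exact count. The function name and layer widths are hypothetical.

```python
from typing import List, Tuple


def mlp_cost(widths: List[int]) -> Tuple[int, int]:
    """Parameter count and approximate forward FLOPs per sample for a fully
    connected network whose layer widths are listed input-first."""
    pairs = list(zip(widths, widths[1:]))
    params = sum(n_in * n_out + n_out for n_in, n_out in pairs)  # weights + biases
    fwd_flops = sum(2 * n_in * n_out for n_in, n_out in pairs)   # multiply + add per weight
    return params, fwd_flops


for widths in ([784, 512, 10], [784, 512, 512, 10], [784, 2048, 2048, 10]):
    params, fwd = mlp_cost(widths)
    # Rule of thumb: forward + backward is roughly 3x the forward pass.
    print(f"{widths}: {params:,} params, ~{3 * fwd:,} training FLOPs per sample")
```

Adding a hidden layer or widening existing ones grows both numbers, and for fully connected layers the compute per sample tracks the parameter count closely.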

Just show me the numbers.

Before we jump into the parallelization techniques in detail, let’s look at some numbers to get an idea of the workload scale. Here are the characteristics of the largest GPT-3 language model, which recently became very popular for generating human-like text. We compare it with the resources available in a single NVIDIA V100 GPU.

  • Training Dataset. Its largest data source was filtered down from 45 TB of compressed plaintext to about 570 GB, roughly 400 billion tokens.
  • Number of Parameters. GPT-3 contains 175 billion parameters, which would take more than 1 TB of memory to store the model parameters, optimizer states, and intermediate states required during training. In comparison, a single V100 contains only 32 GB of high-bandwidth memory.
  • Number of FLOPs. The total compute required to train GPT-3 is about 3640 petaflop/s-days. Even at the V100’s theoretical throughput of 28 TFLOPS, it would take roughly 355 years to train the model on a single GPU.
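The single-GPU estimate follows from simple arithmetic on the figures quoted above. A petaflop/s-day is 10^15 floating point operations per second sustained for one day; the sketch converts the total to FLOPs, divides by the quoted 28 TFLOPS, and also checks how much memory the 175 billion parameters need on their own at 4 bytes each (fp32, an assumption), before any optimizer or intermediate state.

```python
SECONDS_PER_DAY = 86_400
DAYS_PER_YEAR = 365.25

params = 175e9                 # GPT-3 parameter count
pf_days = 3640                 # total training compute in petaflop/s-days
v100_flops_per_s = 28e12       # quoted V100 throughput
v100_memory_bytes = 32e9       # V100 high-bandwidth memory

total_flops = pf_days * 1e15 * SECONDS_PER_DAY
years = total_flops / v100_flops_per_s / SECONDS_PER_DAY / DAYS_PER_YEAR
fp32_weights_bytes = params * 4  # 4 bytes per fp32 parameter

print(f"total compute: {total_flops:.2e} FLOPs")
print(f"one V100: about {years:.1f} years")
print(f"fp32 weights alone: {fp32_weights_bytes / 1e9:.0f} GB, "
      f"{fp32_weights_bytes / v100_memory_bytes:.1f}x a V100's memory")
```

Even the raw weights need about 22 V100s' worth of memory, before counting optimizer states, gradients, or activations.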

Given the scale of GPT-3, you can rightly guess that training it requires many GPUs and smart parallelization techniques to finish in a reasonable time. Moreover, this is not the end of the journey: we are already exploring models that are many times bigger than GPT-3 and contain more than a trillion parameters. Let’s take a deep dive into the techniques used to scale training. We will use GPUs as our compute unit due to their widespread adoption for neural network training, but many of these methods apply to other hardware types as well.

Data Parallelism

Data parallelism is the most common technique for parallelizing training due to its simplicity. Multiple workers (a worker is a compute unit ranging from a full-fledged GPU to a single CPU core) train on batches of data in parallel, and each worker maintains an exact copy of the model. As shown in Figure 6, a mini-batch of training samples is split into N micro-batches, where N is the number of available workers (N is also called the degree of data parallelism). During a training step, each worker executes the forward and backward pass on its micro-batch. Since multiple copies of the model are trained in parallel, the workers must synchronize what they have learned before training on the next batch of data.

Figure 6: Data Parallelism using three workers where each worker trains on a batch of data in parallel
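Why is it safe to split a mini-batch across workers? For a loss defined as the mean over samples and equal-sized micro-batches, averaging the workers’ micro-batch gradients gives exactly the mini-batch gradient. The sketch below checks this for a toy model `y_hat = w * x`; the helper names and the in-process loop standing in for parallel workers are assumptions for illustration.

```python
from typing import List, Tuple

Sample = Tuple[float, float]  # (input x, label y)


def mean_grad(w: float, batch: List[Sample]) -> float:
    """Mean gradient of the squared error for the toy model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def split(batch: List[Sample], n_workers: int) -> List[List[Sample]]:
    """Split a mini-batch into equal micro-batches, one per worker."""
    assert len(batch) % n_workers == 0, "sketch assumes an even split"
    size = len(batch) // n_workers
    return [batch[i * size:(i + 1) * size] for i in range(n_workers)]


mini_batch = [(float(x), 3.0 * x + 0.5) for x in range(12)]
w = 0.0
# Each worker computes a gradient on its own micro-batch (in parallel on real hardware).
worker_grads = [mean_grad(w, mb) for mb in split(mini_batch, n_workers=3)]
# Synchronization: average the per-worker gradients.
synced_grad = sum(worker_grads) / len(worker_grads)
full_batch_grad = mean_grad(w, mini_batch)
print(synced_grad, full_batch_grad)  # equal up to floating-point rounding
```

With unequal micro-batch sizes, a plain average would over-weight the smaller micro-batches; weighting each worker’s gradient by its micro-batch size restores the equality.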

A model’s learning is captured during the optimizer pass, when the gradients are used to update the model parameters. So, we can synchronize either the gradients or the updated model parameters themselves. Most commonly, these values are shared using one of two paradigms: (i) Parameter Server and (ii) All-Reduce. Let us discuss each of them in detail and then talk about their advantages and disadvantages.

Parameter Server (PS)

The PS architecture splits the responsibilities of training and synchronization into two separate worker roles: trainer and parameter server. The parameter server uses separate resources (e.g. a CPU machine) to store the model parameters, which trainers pull to execute a training step on a micro-batch of samples. In gradient-based averaging, each trainer performs the forward and backward pass and pushes its gradients to the parameter server. The PS runs the optimizer pass and updates the stored model parameters. The trainers fetch the updated parameters when training on the next batch of data. In the case of model-parameter-based averaging, the optimizer pass is also performed on the trainers and only the updated parameters are pushed to the PS for synchronization. In both cases, the PS can update the model parameters synchronously or asynchronously, as described below:

Synchronous PS Training

During synchronous training, the parameter server waits for a response from all trainers before applying the updates to the model parameters. This technique provides strong consistency guarantees for model convergence, since it is approximately equivalent to a single trainer training on a larger batch of data. Figure 7 shows how synchronous PS training is used to implement gradient-based averaging. Each trainer runs the forward and backward pass on a batch of data. During the backward pass, a gradient is calculated for each model parameter and communicated to the parameter server. The parameter server aggregates the gradients from all trainers and updates the underlying model parameters. The updated parameters are communicated back to the trainers to train on the next batch of data.

Figure 7: Synchronous Parameter Server based Data Parallelism with Gradient Averaging
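The synchronous protocol can be sketched as a tiny in-process class: trainers push gradients, and the server applies a single averaged update only once all of them have arrived. Everything here is hypothetical (the `SyncParameterServer` class, its `push`/`pull` methods, plain SGD as the optimizer, a single unsharded server); a real deployment would run the server on separate machines and communicate over the network.

```python
from typing import Dict, List


class SyncParameterServer:
    """Synchronous PS with gradient averaging (single shard, plain SGD)."""

    def __init__(self, params: Dict[str, float], n_trainers: int, lr: float):
        self.params = dict(params)
        self.n_trainers = n_trainers
        self.lr = lr
        self.pending: List[Dict[str, float]] = []

    def pull(self) -> Dict[str, float]:
        return dict(self.params)

    def push(self, grads: Dict[str, float]) -> bool:
        """Buffer a trainer's gradients; apply the update only once every
        trainer has reported. Returns True when the update was applied."""
        self.pending.append(grads)
        if len(self.pending) < self.n_trainers:
            return False  # barrier: keep waiting for slower trainers
        for name in self.params:
            avg = sum(g[name] for g in self.pending) / self.n_trainers
            self.params[name] -= self.lr * avg  # optimizer pass runs on the PS
        self.pending.clear()
        return True


ps = SyncParameterServer({"w": 0.0}, n_trainers=2, lr=0.1)
print(ps.push({"w": 1.0}))  # False: still waiting for the second trainer
print(ps.push({"w": 3.0}))  # True: averaged gradient 2.0 applied
print(ps.pull())            # {'w': -0.2}
```

The barrier in `push` is exactly what makes the result match single-trainer training on the combined batch, and also what exposes the server to slow trainers.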

Asynchronous PS Training

Synchronous PS training requires a response from each trainer before the parameter server can update the model parameters. This can become a bottleneck and lead to significant hardware under-utilization if trainers take unequal time to respond, since every trainer ends up waiting for the slowest one. Research has shown that for some models, we can relax the synchronization requirement and let trainers proceed independently of each other without impacting convergence. As shown in Figure 8, trainers send gradients after running the forward and backward pass, but instead of waiting for responses from all trainers and aggregating their gradients, the parameter server runs the optimizer pass immediately and updates the model parameters. The trainer then immediately pulls the updated parameters to train on the next batch of data.

Asynchronous PS training introduces staleness, since a trainer can be training with parameters that have since been updated by other trainers, and similarly, the parameter server can overwrite updates sent by other trainers. Moreover, as the degree of data parallelism increases, staleness increases and convergence can become very slow. There are multiple variants of asynchronous PS, such as HogWild!, EASGD, and DistBelief, which modify the optimizer algorithms to reduce staleness and increase convergence speed.

Figure 8: Asynchronous Parameter Server based Data Parallelism with Gradient Averaging
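An asynchronous server drops the barrier: each push is applied immediately. The sketch below (a hypothetical class, plain SGD, single process) also keeps a version counter so that staleness becomes visible: it is the number of updates other trainers applied between a trainer’s pull and its push.

```python
from typing import Dict, Tuple


class AsyncParameterServer:
    """Asynchronous PS sketch: every pushed gradient is applied immediately."""

    def __init__(self, params: Dict[str, float], lr: float):
        self.params = dict(params)
        self.lr = lr
        self.version = 0  # number of updates applied so far

    def pull(self) -> Tuple[Dict[str, float], int]:
        return dict(self.params), self.version

    def push(self, grads: Dict[str, float], pulled_version: int) -> int:
        """Apply the gradients right away and return their staleness."""
        staleness = self.version - pulled_version
        for name, g in grads.items():
            self.params[name] -= self.lr * g  # no waiting, no averaging
        self.version += 1
        return staleness


ps = AsyncParameterServer({"w": 0.0}, lr=0.1)
_, version_a = ps.pull()  # trainer A pulls version 0
_, version_b = ps.pull()  # trainer B also pulls version 0
stale_a = ps.push({"w": 1.0}, version_a)  # A finishes first
stale_b = ps.push({"w": 3.0}, version_b)  # B's gradient was computed on old params
print(stale_a, stale_b)  # 0 1
```

With more trainers pushing concurrently, typical staleness grows with the degree of data parallelism, which is the effect the variants above try to mitigate.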

Stale-Synchronous PS Training

Hybrid techniques between synchronous and asynchronous PS training have also gained some popularity. Instead of enforcing strong consistency or removing synchronization between trainers completely, they use relaxed-consistency recipes, such as introducing global synchronization points into asynchronous PS training or limiting how long synchronous PS training waits for gradients before updating the parameters.
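One of these recipes, limiting how long a synchronous server waits, can be sketched as a deadline: the server averages whichever gradients arrived in time and proceeds without the stragglers. The function name, the arrival-time representation, and the choice to simply drop late gradients for that step are assumptions; a real system might instead fold late gradients into a later step.

```python
from typing import Dict, List, Tuple


def bounded_wait_update(params: Dict[str, float],
                        arrivals: List[Tuple[float, Dict[str, float]]],
                        deadline: float, lr: float) -> Tuple[Dict[str, float], int]:
    """Apply one SGD update using only gradients that arrived by `deadline`.

    `arrivals` holds (arrival_time, grads) per trainer; stragglers are dropped
    for this step. Returns the new parameters and how many gradients were used.
    """
    on_time = [g for t, g in arrivals if t <= deadline]
    if not on_time:
        return dict(params), 0  # nothing arrived in time; skip this step
    new = dict(params)
    for name in new:
        new[name] -= lr * sum(g[name] for g in on_time) / len(on_time)
    return new, len(on_time)


arrivals = [(1.0, {"w": 1.0}), (1.2, {"w": 3.0}), (9.5, {"w": 100.0})]  # third is a straggler
params, used = bounded_wait_update({"w": 0.0}, arrivals, deadline=2.0, lr=0.1)
print(params, used)  # {'w': -0.2} 2
```

The deadline trades a little gradient information per step for never blocking on the slowest trainer.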

PS Scalability

In the previous section, we talked about storing all the model parameters in a single parameter server (PS) that is responsible for synchronizing updates across the trainers. In that case, increasing the number of trainers can congest the network, since all trainer connections share the bandwidth of the single PS. Since the communication between trainers and the parameter server is on the critical path for training, the limited bandwidth of a single PS can significantly slow down both synchronous and asynchronous PS training. In order to avoid this, the model parameters can be sharded and stored across multiple parameter servers. This splits the network traffic and increases the aggregate bandwidth available for synchronizing model parameters. Moreover, adding parameter servers also scales the storage capacity for model parameters. Figure 9 shows an example of a sharded parameter server architecture where the model parameters are sharded layer-by-layer across four parameter servers.

Figure 9: Sharded Parameter Server

Since gradients for each layer are computed sequentially during the backward pass, the layer-by-layer sharding strategy can lead to under-utilization of resources. For example, in Figure 9, only parameter server 4 will be receiving data from all workers when the backward pass computes the gradients for output layer. There are multiple factors such as model size and network topology which can help determine the sharding and placement strategy for parameter servers to avoid such inefficiencies. The goal is to distribute the load evenly amongst the parameter servers and use the network cost of communication links in a cluster to design the placement strategy.
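To see why layer-by-layer sharding leaves servers idle, the sketch compares two hypothetical placements for a model with four parameterized layers on four servers: whole layers assigned round-robin (as in Figure 9) versus every layer sliced evenly across all servers. It prints how the output layer’s gradient traffic, which the backward pass produces first, is spread over the servers. The layer sizes and function names are made up for illustration, and slicing is only one possible strategy; it balances load at the cost of more, smaller messages.

```python
from typing import Dict, List


def layerwise(layers: Dict[str, int], n_servers: int) -> Dict[str, List[int]]:
    """Each whole layer lives on one server (round-robin by layer index)."""
    return {name: [size if s == i % n_servers else 0 for s in range(n_servers)]
            for i, (name, size) in enumerate(layers.items())}


def sliced(layers: Dict[str, int], n_servers: int) -> Dict[str, List[int]]:
    """Each layer is split into near-equal slices, one slice per server."""
    placement = {}
    for name, size in layers.items():
        base, extra = divmod(size, n_servers)
        placement[name] = [base + (1 if s < extra else 0) for s in range(n_servers)]
    return placement


# Hypothetical parameter counts per layer (dicts keep insertion order).
layers = {"layer1": 4000, "layer2": 1000, "layer3": 1000, "output": 40}
for scheme in (layerwise, sliced):
    # Values each server receives while the output layer's gradients are pushed.
    print(scheme.__name__, scheme(layers, 4)["output"])
```

Under the layer-wise scheme only the last server is busy while the output layer’s gradients arrive, mirroring the Figure 9 example.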

All-Reduce

Instead of using separate resources to synchronize model parameters, another popular technique is to use the all-reduce collective operation to aggregate parameter updates across multiple trainers. The all-reduce operation contains two steps: (i) Reduce-Scatter aggregates the data spread across multiple workers using a reduction function (e.g. sum or min) such that each worker holds a portion of the final result, and (ii) All-Gather shares the reduced portions spread across the workers such that every worker has the entire result. There are multiple algorithms to implement the all-reduce operation, such as ring, hierarchical, and butterfly. They aim to minimize the required communication bandwidth and the overall latency of the operation. Ring all-reduce is one of the most widely used algorithms today in neural network training.

Figure 10 shows an example of the all-reduce operation for four workers implemented using the ring algorithm. Initially, each worker holds a vector of data which needs to be aggregated and distributed to all the workers. In the reduce-scatter step, the vector is split into N chunks, where N is the total number of workers. The workers are organized in a logical ring such that, for N-1 steps, each worker sends a chunk of the vector to the neighbor on its right and receives a chunk from the neighbor on its left. At each step, the received chunk is aggregated with the corresponding local chunk using the reduction function, and the result is sent on to the next worker. In the end, each worker holds the final aggregated value for one chunk of the vector. Then, in the all-gather step, the workers pass the aggregated chunks around the ring for another N-1 steps, so that all workers end up with the entire all-reduced vector.

Figure 10: All-Reduce using Ring Algorithm
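The ring algorithm can be simulated in a single process to check the bookkeeping. In the sketch below, each list in `vectors` plays the role of one worker’s buffer, sum is the reduction function, and in every step all workers “send” simultaneously (messages are collected before any buffer is modified). It follows the schedule described above: after N-1 reduce-scatter steps worker i owns the fully reduced chunk (i + 1) mod N, and N-1 all-gather steps circulate those chunks. The function names are hypothetical; libraries such as NCCL implement the same idea over real interconnects.

```python
from typing import List, Tuple


def chunk_bounds(length: int, n: int) -> List[Tuple[int, int]]:
    """Split [0, length) into n contiguous, near-equal [start, end) ranges."""
    base, extra = divmod(length, n)
    bounds, start = [], 0
    for c in range(n):
        end = start + base + (1 if c < extra else 0)
        bounds.append((start, end))
        start = end
    return bounds


def ring_all_reduce(vectors: List[List[float]]) -> List[List[float]]:
    """Sum-all-reduce over a logical ring; vectors[i] is worker i's data."""
    n = len(vectors)
    bufs = [list(v) for v in vectors]  # each worker's local buffer
    bounds = chunk_bounds(len(bufs[0]), n)

    # Reduce-scatter: in step s, worker i sends chunk (i - s) mod n to its right
    # neighbor, which adds it into its own copy of that chunk.
    for s in range(n - 1):
        msgs = [((i + 1) % n, (i - s) % n) for i in range(n)]
        payloads = [bufs[i][slice(*bounds[c])] for i, (_, c) in enumerate(msgs)]
        for (dst, c), data in zip(msgs, payloads):
            lo, _ = bounds[c]
            for k, value in enumerate(data):
                bufs[dst][lo + k] += value
    # Now worker i holds the fully reduced chunk (i + 1) mod n.

    # All-gather: in step s, worker i forwards chunk (i + 1 - s) mod n to its
    # right neighbor, which overwrites its stale copy.
    for s in range(n - 1):
        msgs = [((i + 1) % n, (i + 1 - s) % n) for i in range(n)]
        payloads = [bufs[i][slice(*bounds[c])] for i, (_, c) in enumerate(msgs)]
        for (dst, c), data in zip(msgs, payloads):
            lo, hi = bounds[c]
            bufs[dst][lo:hi] = data
    return bufs


workers = [[1, 2, 3, 4], [10, 20, 30, 40], [100, 200, 300, 400], [1000, 2000, 3000, 4000]]
print(ring_all_reduce(workers)[0])  # [1111, 2222, 3333, 4444] on every worker
```

Each worker sends N-1 chunks in each phase, i.e. 2(N-1) chunks of size M/N in total for a vector of M values.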

Figure 11 shows an example of how both gradient and model parameter based averaging utilize all-reduce for synchronization during execution of a training step. In gradient based averaging, all-reduce is used before the optimizer pass to aggregate and distribute gradients across all workers. Similarly, in the case of model parameter based averaging, the updated model parameters are all-reduced across the workers at the end of the training step. All-reduce based data parallelism has strong consistency guarantees since each worker sees exactly the same aggregated value after synchronization.

Figure 11: (Left) All-Reduce based data parallelism with Gradient Averaging; (Right) All-Reduce based data parallelism with Parameter Averaging

What is the preferred technique?

All-reduce with gradient averaging is the common choice for implementing data parallelism for large models, while the parameter server approach is used for models with a limited number of workers, where asynchronous and stale-synchronous training can provide acceptable convergence. The following properties of all-reduce contribute to this:

  • Network Congestion. The communication cost of the parameter server strategy depends on the number of trainers, since each trainer establishes a connection with every parameter server and the bandwidth per communication channel is limited. Therefore, the connections between the PS and the trainers can become a bottleneck as trainers are added. In contrast, with the ring algorithm each worker sends and receives 2(N-1)/N times the model size per all-reduce, which stays nearly constant (below twice the model size) as more trainers are added to the system.
  • Compute Efficiency. Parameter servers require separate resources for synchronization. These resources are severely under-utilized when the model is not running the optimizer pass. Even when the backward and optimizer passes are overlapped, the PS can only update the parameters whose gradients have already arrived, so the work available to it at any moment is limited. On the other hand, all-reduce does not require any additional resources and uses the existing resources of the trainers.
  • Cluster Design. Modern architectures have introduced high-bandwidth connections between workers, such as NVLink for GPU-to-GPU communication, which make all-reduce among GPUs much faster than synchronizing through CPU-hosted parameter servers over PCIe.

Similarly, gradient averaging is preferred over parameter averaging to synchronize model replicas for the following reasons:

  • Training Result. Synchronous gradient averaging is equivalent to single-node training with a large batch size, since the same averaged gradient is applied to each model replica during the optimizer pass. This is not true in general for model parameter averaging, since each replica updates its parameters independently: with stateful or adaptive optimizers (e.g. momentum or Adam), or when replicas take several local steps between synchronizations, the average of the updated parameters differs from the large-batch update. This can make it harder to find the optimal parameter values during gradient descent.
  • Overlap. Synchronization in model parameter averaging is a separate step performed after the optimizer pass. In gradient averaging, the gradients of later layers can be communicated while the backward pass is still computing the gradients of earlier layers; parameter averaging prevents this overlap of communication and computation. This leads to worker under-utilization, since no compute is performed while the model parameters are being communicated.
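The difference between the two synchronization schemes can be seen in a single step. With plain SGD the update is linear in the gradient, so averaging parameters after one local step gives the same result as averaging gradients; with a nonlinear update it does not. The sketch uses a sign-based step as a stand-in for an adaptive optimizer (Adam’s first bias-corrected step is approximately of this form); the gradients and learning rate are made up.

```python
def sgd_step(w: float, g: float, lr: float) -> float:
    return w - lr * g  # linear in g


def sign_step(w: float, g: float, lr: float) -> float:
    # Move by lr against the sign of the gradient (nonlinear in g).
    return w - lr * ((g > 0) - (g < 0))


w0, lr = 0.0, 0.1
grads = [0.5, 0.2, -3.0]  # one gradient per data-parallel replica

for name, step in (("SGD", sgd_step), ("sign-based", sign_step)):
    # Gradient averaging: all replicas apply the same averaged gradient.
    grad_avg = step(w0, sum(grads) / len(grads), lr)
    # Parameter averaging: each replica steps on its own gradient, then weights are averaged.
    param_avg = sum(step(w0, g, lr) for g in grads) / len(grads)
    print(f"{name}: gradient averaging -> {grad_avg:+.4f}, "
          f"parameter averaging -> {param_avg:+.4f}")
```

For the sign-based step the two schemes even move the weight in opposite directions; momentum, Adam state, and multiple local steps introduce similar nonlinearities.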

Despite the popularity of all-reduce algorithms, there are some inherent benefits of parameter server strategy over all-reduce:

  • Bandwidth Usage. Parameter server has lower bandwidth usage than all-reduce. For a model with M parameters, each trainer in the PS strategy sends M values and receives M values per step, whereas in ring all-reduce each worker sends and receives 2(N-1)/N × M values, almost 2M.
  • Consistency. All-reduce only supports synchronous training and cannot start until every worker joins, whereas a parameter server supports relaxed consistency, which can hide the delay caused by a straggler worker. However, this technique is only adopted at small scale due to its impact on model convergence.
  • Fault Tolerance. Independent recovery of workers during failures is critical in large-scale model training, since increasing the number of trainers also increases the surface area for failure (for instance, failures due to bad hardware or routine maintenance). The parameter server architecture provides better support for fault tolerance by splitting the roles into trainers and parameter servers. Trainers can be made stateless and recovered on the fly, and parameter servers can be replicated to support hot failover. All-reduce doesn’t provide such guarantees: since all-reduce synchronization cannot proceed while a worker has failed, the entire workload needs to be rescheduled.
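The traffic claims in the network-congestion and bandwidth-usage points can be checked with a little arithmetic. The sketch counts values sent per training step for a model with `m` parameters, assuming gradient averaging, a single unsharded parameter server, and ring all-reduce; received volumes are symmetric. The function name is hypothetical.

```python
from typing import Dict


def traffic_per_step(n_workers: int, m: float) -> Dict[str, float]:
    """Values sent per training step for a model with m parameters."""
    return {
        # PS: each trainer pushes m gradients (and pulls m parameters back).
        "ps_trainer_sends": m,
        # ...but a single PS must absorb m values from every trainer.
        "ps_server_receives": n_workers * m,
        # Ring: n-1 reduce-scatter plus n-1 all-gather sends of m/n values each.
        "ring_worker_sends": 2 * (n_workers - 1) / n_workers * m,
    }


for n in (2, 8, 64):
    print(n, traffic_per_step(n, m=1.0))  # volumes in units of the model size
```

Per trainer, the PS strategy sends half as much as ring all-reduce, but the single server’s load grows linearly with the number of trainers, whereas each ring worker’s load stays below 2m; sharding the PS divides the server-side figure across shards.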

Analysis and Optimizations

Due to its intuitive design and simple implementation, data parallelism is one of the most popular techniques to scale training. We present a breakdown of its runtime cost and share several optimizations which are used to make it more efficient:

Compute

Data parallelism is equivalent to single-node training with a larger batch size and more aggregate resources. Since each training worker can contain many cores (GPUs contain thousands of small cores), each worker needs to train on enough samples in parallel to saturate its compute capability. Since a mini-batch is evenly split into micro-batches and sent to the workers, the mini-batch size must grow as workers are added to keep the resources efficiently utilized.

Even so, increasing the batch size does not come for free due to its impact on convergence speed. Studies [Shallue et al., 2019] show that increasing the batch size initially reduces the number of training steps needed roughly proportionally, then shows diminishing returns, and eventually provides no further gains. This places a hard limit on the degree of data parallelism, since increasing the batch size beyond a threshold no longer speeds up convergence. The threshold needs to be carefully determined based on the properties of the model, optimizer algorithm, and training dataset.

Memory

In data parallelism, each worker stores an entire replica of the model. This means that adding workers does not increase the memory available for the model: it remains limited to the memory of a single worker. This puts a hard limit on the supported model size and wastes memory on redundant copies. Recent works have explored several memory optimizations that vastly increase the model size supported by data parallelism:

Deduplicate Model States. Zero Redundancy Optimizer (ZeRO) [Rajbhandari et al., 2020] eliminates the memory redundancy in all-reduce based data parallel training and makes the entire aggregate memory of resources available for storing the model. Instead of replicating the optimizer, gradient, and parameter states, ZeRO partitions them across the data parallel workers.

Figure 12 demonstrates the training loop used in data parallelism. As mentioned above, all-reduce algorithm is used to perform the synchronization step using two operations: reduce-scatter and all-gather. Reduce-scatter reduces different partitions of gradients on different workers and all-gather shares the reduced data amongst all the workers.

Figure 12: Training loop in data parallelism using all-reduce synchronization

Optimizer and Gradient Partitioning: ZeRO splits the optimizer states into N parts, where N is the number of data parallel workers. Each worker stores a single partition of the optimizer state and only updates the parameters corresponding to it. Furthermore, since each worker only holds a subset of the optimizer state, it does not need to receive the gradients for all the parameters: the gradients for the parameters whose optimizer state it stores are sufficient for running its optimizer pass. Therefore, during gradient synchronization, workers only perform the reduce-scatter operation and store their partition of the reduced gradients.

After reduce-scatter, each worker uses its partition of reduced gradients to update the corresponding optimizer states and model parameters. At the end, since each worker holds only a partition of the updated parameters, the workers run an all-gather operation to share the full parameter state amongst all of them. Figure 13 shows the new training loop with optimizer and gradient partitioning. In Step 4, using reduce-scatter, each worker aggregates gradients only for a portion of the model parameters, which are then used to update the corresponding parameters in Step 5. In Step 6, workers perform an all-gather operation to share the updated parameters amongst all the workers.

This technique removes the replication of optimizer state completely without increasing the communication volume, because a reduce-scatter followed by an all-gather moves the same amount of data as the all-reduce it replaces. Moreover, the memory used to store gradients can also be reduced, since after reduce-scatter each worker frees the memory for gradients whose optimizer partition does not belong to it.

Figure 13: Training Loop w/ ZeRO Optimizer and Gradient Partitioning
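The following sketch checks, on a toy problem, that optimizer and gradient partitioning produces the same update as plain data parallelism while each worker keeps only 1/N of the optimizer state. It assumes SGD with momentum as the optimizer (ZeRO is usually described with Adam) and simulates the collectives in one process; `baseline_step` and `zero_step` are hypothetical names, and the step numbers in the comments refer to Figure 13.

```python
def baseline_step(params, worker_grads, momentum, lr=0.1, beta=0.9):
    """Plain data parallelism: all-reduce full gradients; every worker keeps the full
    momentum buffer and updates the full replica (computed once here)."""
    n = len(worker_grads)
    avg = [sum(g[j] for g in worker_grads) / n for j in range(len(params))]  # all-reduce
    new_m = [beta * m + g for m, g in zip(momentum, avg)]
    new_p = [p - lr * m for p, m in zip(params, new_m)]
    return new_p, new_m


def zero_step(params, worker_grads, momentum_shards, lr=0.1, beta=0.9):
    """Optimizer + gradient partitioning: worker i keeps only momentum shard i."""
    n = len(worker_grads)
    size = len(params) // n  # toy restriction: len(params) divisible by n
    new_p_shards, new_m_shards = [], []
    for i in range(n):
        idx = range(i * size, (i + 1) * size)
        # Step 4 (reduce-scatter): worker i receives only the averaged gradients of shard i.
        g_shard = [sum(g[j] for g in worker_grads) / n for j in idx]
        # Step 5: update only the optimizer state and parameters that worker i owns.
        m_shard = [beta * m + g for m, g in zip(momentum_shards[i], g_shard)]
        p_shard = [params[j] - lr * m for j, m in zip(idx, m_shard)]
        new_m_shards.append(m_shard)
        new_p_shards.append(p_shard)
    # Step 6 (all-gather): every worker reassembles the full, updated parameter vector.
    full_params = [v for shard in new_p_shards for v in shard]
    return full_params, new_m_shards


if __name__ == "__main__":
    n, dim = 4, 8
    params = [0.5 * j for j in range(dim)]
    grads = [[(w + 1) * (j - 3) for j in range(dim)] for w in range(n)]
    p_ref, m_ref = baseline_step(params, grads, [0.0] * dim)
    p_zero, m_shards = zero_step(params, grads, [[0.0] * (dim // n) for _ in range(n)])
    print(p_ref == p_zero)               # True
    print(len(m_shards[0]), len(m_ref))  # 2 8
```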

Parameter Partitioning: With optimizer and gradient partitioning, each worker only updates a partition of the model parameters. We can therefore save further memory by partitioning the model parameters so that each worker only stores the parameters corresponding to the optimizer states it holds. Since model parameters are required during the forward and backward passes, the parameters for a layer are broadcast from the worker that owns them, and once the computation for that layer is done, they are discarded by all workers that do not own them. This technique completely removes the replication of model parameters but trades communication for memory. Since the broadcast is required during both the forward and the backward pass, it increases the overall communication volume. However, the all-gather operation at the end of a training step is no longer required, so the overall communication volume grows only to 1.5x that of standard data parallelism.

As shown in Figure 14, a broadcast operation runs before the forward and backward pass of each layer begins. The model parameters are broadcast from the worker holding that partition of the parameters to all the remaining workers. Once the backward pass for a layer completes, its gradients are shared using reduce-scatter so that each worker aggregates the gradients corresponding to the partition of parameters it holds. Each worker then uses the aggregated gradients to update the model parameters assigned to it.

Figure 14: Training Loop w/ ZeRO Optimizer, Gradient, and Parameter Partitioning
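A back-of-the-envelope calculation shows where the memory savings and the 1.5x figure come from. This sketch assumes the mixed-precision Adam accounting used in the ZeRO paper (2 bytes per parameter for fp16 weights, 2 for fp16 gradients, and K = 12 bytes for fp32 optimizer state), counts only model states (activations and temporary buffers are ignored), and uses a hypothetical 7.5-billion-parameter model on 64 workers.

```python
def zero_memory_per_worker(psi, n, k=12):
    """Bytes of model state per worker for psi parameters and n data parallel workers.

    Assumes 2 bytes/param for fp16 parameters, 2 for fp16 gradients,
    and k bytes/param for fp32 optimizer state (k = 12 for Adam).
    """
    return {
        "replicated": (2 + 2 + k) * psi,
        "partition optimizer": (2 + 2) * psi + k * psi / n,
        "partition optimizer + gradients": 2 * psi + (2 + k) * psi / n,
        "partition optimizer + gradients + parameters": (2 + 2 + k) * psi / n,
    }


def comm_volume_per_worker(psi, partition_params):
    """Data moved per worker per step, in parameter counts, ignoring the (N - 1) / N factor."""
    if partition_params:
        # gather parameters for the forward pass, again for the backward pass,
        # and reduce-scatter the gradients
        return 3 * psi
    # reduce-scatter gradients + all-gather updated parameters (same as one all-reduce)
    return 2 * psi


if __name__ == "__main__":
    for name, b in zero_memory_per_worker(7.5e9, 64).items():
        print(f"{name}: {b / 1e9:.1f} GB")
    print(comm_volume_per_worker(1.0, True) / comm_volume_per_worker(1.0, False))  # 1.5
```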

Activation Checkpointing. Activations are intermediate values computed for each neural network node during the forward pass, and they occupy significant memory during training. We can trade extra computation for this memory. The key insight is that once an activation has been relayed to the next node, it is no longer needed during the forward pass; it is only needed during the backward pass for gradient calculation. We can exploit this to free activation memory and re-compute the activations during the backward pass. This approach reduces overall training throughput due to the extra computation and should only be used when memory is the bottleneck.

Since computing an activation requires the activation of the previous node, and the backward pass runs in reverse order, dropping all activations would be expensive: re-computing them would require re-running the forward pass from the input, potentially once for every layer. To avoid this, the neural network is split into segments. The activations at segment boundaries (called checkpoints) are retained, and all the intermediate activations within each segment are dropped. As shown in Figure 15, during the backward pass the dropped activations are re-computed one segment at a time, starting from that segment's checkpoint. The segment size is chosen carefully based on model characteristics, and further optimizations, such as placing checkpoints based on the computation cost of each operation, are used to find the optimal configuration.

Figure 15: Backward Pass with Activation Checkpointing
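Below is a minimal sketch of segment-based checkpointing on a toy chain of scalar tanh layers, a hypothetical model chosen so that gradients can be written by hand. It keeps only the input of each segment during the forward pass, recomputes the rest one segment at a time during the backward pass, and checks that the gradients match the store-everything baseline. In PyTorch, `torch.utils.checkpoint.checkpoint_sequential` applies the same idea to an `nn.Sequential` model.

```python
import math


def layer(w, x):
    """Toy layer: x_next = tanh(w * x)."""
    return math.tanh(w * x)


def layer_backward(w, x, upstream):
    """Given the layer input x and dL/d(output), return (dL/dx, dL/dw)."""
    y = math.tanh(w * x)
    local = upstream * (1.0 - y * y)  # derivative of tanh
    return local * w, local * x


def grads_store_all(weights, x0):
    """Baseline: keep every activation from the forward pass."""
    acts = [x0]
    for w in weights:
        acts.append(layer(w, acts[-1]))
    upstream, grads = 1.0, [0.0] * len(weights)  # loss = final output, so dL/dout = 1
    for i in reversed(range(len(weights))):
        upstream, grads[i] = layer_backward(weights[i], acts[i], upstream)
    return grads


def grads_checkpointed(weights, x0, segment_size):
    """Keep only each segment's input (the checkpoint); recompute the rest per segment."""
    n = len(weights)
    checkpoints, x = {}, x0
    for i, w in enumerate(weights):
        if i % segment_size == 0:
            checkpoints[i] = x  # activation at a segment boundary is retained
        x = layer(w, x)         # intermediate activations are dropped
    upstream, grads, recomputed = 1.0, [0.0] * n, 0
    for start in sorted(checkpoints, reverse=True):  # backward visits segments in reverse
        end = min(start + segment_size, n)
        acts = [checkpoints[start]]
        for i in range(start, end - 1):  # re-run the forward pass inside this segment only
            acts.append(layer(weights[i], acts[-1]))
            recomputed += 1
        for i in reversed(range(start, end)):
            upstream, grads[i] = layer_backward(weights[i], acts[i - start], upstream)
    return grads, len(checkpoints), recomputed


if __name__ == "__main__":
    weights = [0.9, -1.1, 0.7, 1.3, -0.8, 1.05, 0.6, -1.2]
    ref = grads_store_all(weights, 0.5)
    got, stored, recomputed = grads_checkpointed(weights, 0.5, segment_size=3)
    print(max(abs(a - b) for a, b in zip(ref, got)))  # 0.0
    print(stored, recomputed)                         # 3 5
```

With n layers and segment size M, roughly n/M checkpoints plus the M activations of one segment are alive at a time, at the cost of about one extra forward pass; under the simplifying assumption of identical layers, M ≈ √n balances the two terms.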

Memory Offloading. Beyond the limited GPU memory, modern clusters have additional tiers of memory with larger capacity, such as CPU DRAM and SSDs. This memory hierarchy can be leveraged to offload the storage of model states and increase the memory capacity available for data parallel training. We will look into this technique in detail later, since it can also be used on its own to scale training.

Network

Data parallelism relies on a synchronization phase to share updates across model replicas. Communication between workers is latency sensitive, since training cannot proceed until synchronization completes. In addition, since the volume of data to synchronize grows with model size, it also demands high bandwidth. Here are some optimizations that reduce the communication cost of data parallelism:

Overlap Communication and Compute. The most popular technique to amortize the cost of communication is to overlap it with computation. Gradient synchronization can be overlapped with the backward pass, since the gradients of layers that have already been computed can be synchronized independently while the backward pass continues through the remaining layers. However, there are two performance concerns to keep in mind when using this technique:

  1. The achievable overlap depends on the relative amount of compute and communication. Since the communication volume of the all-reduce algorithm increases with model size, the amount of compute needs to grow proportionally to hide it. But this is tricky, because the amount of computation per step is limited by the batch size and, as mentioned above, increasing the batch size has diminishing returns beyond a certain threshold and can lead to slower convergence.
  2. Large models contain many small parameter tensors. Synchronizing each layer's gradients as soon as they become available leads to many small all-reduce messages, which can make communication latency bound.

Gradient bucketing is used to address these concerns. Instead of running an all-reduce after each layer is processed, the gradients are grouped into buckets and each bucket is synchronized as a whole. The bucket size is carefully chosen based on the batch size and the characteristics of the model parameters.
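The sketch below illustrates why bucketing helps, using a simplified alpha-beta cost model (each message costs a fixed latency α plus β per unit of data) and hypothetical sizes. `make_buckets` is an illustrative helper that groups gradients in the order the backward pass produces them and closes a bucket once it reaches `bucket_cap`, similar in spirit to the `bucket_cap_mb` option of PyTorch's `DistributedDataParallel`; DDP's actual bucketing logic differs in details.

```python
def make_buckets(grad_sizes, bucket_cap):
    """Group gradients into buckets in the order the backward pass produces them.

    grad_sizes[i] is the size (e.g. in bytes) of the i-th gradient to become ready.
    A bucket is closed as soon as its total size reaches bucket_cap.
    """
    buckets, current, current_size = [], [], 0
    for idx, size in enumerate(grad_sizes):
        current.append(idx)
        current_size += size
        if current_size >= bucket_cap:
            buckets.append(current)
            current, current_size = [], 0
    if current:  # flush the last, partially filled bucket
        buckets.append(current)
    return buckets


def allreduce_cost(message_sizes, alpha, beta):
    """Simplified alpha-beta model: each message costs latency alpha plus beta per unit."""
    return sum(alpha + beta * size for size in message_sizes)


if __name__ == "__main__":
    # Hypothetical model: 40 small gradients of 1 unit each, in backward order.
    sizes = [1] * 40
    per_layer = allreduce_cost(sizes, alpha=10.0, beta=1.0)
    buckets = make_buckets(sizes, bucket_cap=8)
    bucketed = allreduce_cost([sum(sizes[i] for i in b) for b in buckets], alpha=10.0, beta=1.0)
    print(len(buckets), per_layer, bucketed)  # 5 440.0 90.0
```

Larger buckets pay the latency term fewer times, but a bucket can only be sent once its last gradient is ready, so very large buckets shrink the window for overlapping communication with the rest of the backward pass.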

Synchronization Frequency. Another technique commonly used to reduce the communication overhead of data parallelism is to decrease the synchronization frequency. Instead of synchronizing gradients at every training step, workers accumulate them over multiple steps and then synchronize them together. This is essentially equivalent to training with a larger batch size. Since the batch size has an upper bound in terms of convergence speed, the synchronization frequency is bounded too: gradients can only be accumulated until the number of accumulated samples reaches the batch size threshold.
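As a sanity check of the equivalence claim, the sketch below uses a hypothetical one-parameter linear model with mean squared error. Averaging the gradients of k equal-sized micro-steps gives exactly the gradient of one batch k times larger, so workers can accumulate locally and synchronize once every k steps. `max_accumulation_steps` is an illustrative helper expressing the bound from the paragraph, with the batch size threshold taken as a given input.

```python
def mse_grad(w, xs, ys):
    """Gradient of the mean squared error for the toy model y_hat = w * x."""
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, micro_batches):
    """Accumulate gradients over equal-sized micro-steps, then average once."""
    total = 0.0
    for xs, ys in micro_batches:
        total += mse_grad(w, xs, ys)   # local accumulation, no communication
    return total / len(micro_batches)  # a single synchronization would happen here


def max_accumulation_steps(batch_threshold, per_step_global_batch):
    """Largest k such that k * (samples per step across all workers) stays within the threshold."""
    return max(1, batch_threshold // per_step_global_batch)


if __name__ == "__main__":
    xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
    ys = [2.0, 4.0, 5.0, 9.0, 10.0, 13.0]
    micro = [(xs[i:i + 2], ys[i:i + 2]) for i in range(0, 6, 2)]  # 3 micro-steps of 2 samples
    print(accumulated_grad(1.5, micro), mse_grad(1.5, xs, ys))      # -17.5 -17.5
    print(max_accumulation_steps(4096, 1024))                       # 4
```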

Hybrid PS — All-reduce. Hybrid techniques using both parameter server and all-reduce based synchronization have been explored in the context of multi-node multi-GPU data parallel training. The goal is to utilize all-reduce for synchronization amongst workers connected using high-bandwidth interconnects while reducing the communication overhead of doing all-reduce across low-bandwidth connections by leveraging parameter servers which require less data transfer.

Herring [Thangakrishnan et al., 2020] by Amazon proposed using parameter server based data parallelism to aggregate gradients globally across nodes, while using reduce-scatter and all-reduce operations to average gradients locally amongst the GPUs in a single node. Figure 16 shows how Herring designs the training strategy for a cluster with two nodes of 8 GPUs each, where each GPU holds a replica of the model and trains on a micro-batch of input samples. Note that the parameter servers are only used for aggregating gradients across nodes and do not store model parameters. Model parameters are replicated on each GPU worker, and once the globally averaged gradients are sent to the GPUs, they perform the optimizer pass and update their stored parameters.

Figure 16: Herring based Data Parallelism Training
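The single-process sketch below mimics the general hybrid scheme on the two-node, 8-GPU-per-node layout of Figure 16: a reduce-scatter within each node, a summation on per-shard parameter servers across nodes, and an all-gather within each node to rebuild the averaged gradient. It is a toy model of the idea, not Herring's implementation; the one-server-per-shard layout, the local all-gather step, and the function names are assumptions.

```python
def shard_bounds(dim, parts):
    """Split range(dim) into `parts` contiguous shards (trailing shards may be shorter or empty)."""
    step = -(-dim // parts)  # ceiling division
    return [(min(i * step, dim), min((i + 1) * step, dim)) for i in range(parts)]


def hybrid_average(grads):
    """grads[node][gpu] is a gradient vector; returns the averaged gradient on every GPU."""
    nodes, gpus, dim = len(grads), len(grads[0]), len(grads[0][0])
    bounds = shard_bounds(dim, gpus)

    # 1. Intra-node reduce-scatter over fast links: GPU g of node n ends up with
    #    the node-local sum of shard g.
    local = [[[sum(grads[n][k][j] for k in range(gpus)) for j in range(lo, hi)]
              for lo, hi in bounds]
             for n in range(nodes)]

    # 2. Inter-node aggregation over slow links: GPU g of every node pushes its shard
    #    to a (hypothetical) server responsible for shard g. The server only sums
    #    gradients; it stores no model parameters.
    server = [[sum(local[n][g][i] for n in range(nodes)) for i in range(hi - lo)]
              for g, (lo, hi) in enumerate(bounds)]

    # 3. Each server returns its globally averaged shard, and an intra-node
    #    all-gather rebuilds the full gradient on every GPU.
    total = nodes * gpus
    full = [v / total for shard in server for v in shard]
    return [[list(full) for _ in range(gpus)] for _ in range(nodes)]


if __name__ == "__main__":
    g = [[[n * 100 + k * 10 + j for j in range(4)] for k in range(8)] for n in range(2)]
    print(hybrid_average(g)[1][7])  # [85.0, 86.0, 87.0, 88.0]
```

Only node-local sums cross the slow inter-node links, so each node pushes one gradient's worth of data instead of one per GPU.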

By now you may have realized that even though the naive version of data parallelism is easy to implement and requires little change to the training loop, it needs careful tuning to obtain high performance and resource efficiency. Here is the overall flow of resource-optimized all-reduce data parallelism using a few of the optimizations described above:

  1. The model is loaded on N workers, but each worker initializes only its partition of the model parameters and optimizer state.
  2. For each epoch, the training dataset is split into mini-batches of size B. Each mini-batch is further split into N equal micro-batches, which are sent to the workers.
  3. Each worker runs the forward pass for every sample in its micro-batch in parallel. It extracts the input features from the sample, feeds them through the layers of nodes while calculating the intermediate values (a.k.a. activations), and outputs the final result.
  4. Before running the forward pass for each segment of M layers, the workers run an all-gather so that the parameters for those M layers, which are partitioned across workers, become available to every worker. Each worker then feeds its micro-batch through the M layers. Afterwards, each worker discards the gathered parameters it does not own, as well as the activations of every layer in the segment except the last one, which serves as the checkpoint.
  5. After feeding the inputs through all the layers, each worker calculates the loss between the actual and expected output using the cost function.
  6. Each worker then runs the backward pass to calculate the gradient for each model parameter. Before running the backward pass for each segment of M layers, the activations within the segment are re-computed by running the forward pass from the preceding checkpoint (which requires another all-gather to share the model parameters).
  7. After the gradients have been calculated for every L layers, the workers perform a reduce-scatter operation that delivers the aggregated gradients to the workers holding the optimizer states for those L layers, and the remaining workers free their gradient memory.
  8. Each worker that receives aggregated gradients through the reduce-scatter operation runs the optimizer pass and updates the model parameters it holds. The backward pass finishes once the parameters for all the layers have been updated.
  9. Training continues with the next mini-batch of data. The total number of training steps depends on the number of epochs E and the mini-batch size B: for a dataset of D samples, it is E × ⌈D / B⌉.

Multiple factors, such as model size, network topology, and hardware capacity, play an important role in finding the ideal values for the training parameters used in the above algorithm: number of epochs E, mini-batch size B, activation checkpoint segment size M, and gradient bucket size L.
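To tie the steps together, the sketch below emits the sequence of operations one worker performs during a single training step, parameterized by the segment size M and the bucket size L from the list above, plus the total step count for E epochs. It is an illustrative schedule generator with hypothetical operation names; it performs no real computation or communication, and a real system would also overlap these operations.

```python
import math


def num_training_steps(dataset_size, B, E):
    """Optimizer steps for E epochs over dataset_size samples with mini-batch size B."""
    return E * math.ceil(dataset_size / B)


def step_schedule(num_layers, M, L):
    """Ordered (operation, layers) events one worker executes in a single training step."""
    segments = [list(range(s, min(s + M, num_layers))) for s in range(0, num_layers, M)]
    events = []
    for seg in segments:  # steps 3-4: segment-wise forward pass
        events.append(("all_gather_params", seg))
        # forward through the segment, then drop non-owned params and all but the last activation
        events.append(("forward_keep_checkpoint", seg))
    events.append(("loss", []))  # step 5
    pending = []  # layers whose gradients are not yet reduce-scattered
    for seg in reversed(segments):  # step 6: backward pass, one segment at a time
        events.append(("all_gather_params", seg))
        events.append(("recompute_activations", seg))
        for layer in reversed(seg):
            events.append(("backward", [layer]))
            pending.append(layer)
            if len(pending) == L:  # step 7: a bucket of L layers is ready
                events.append(("reduce_scatter_grads", pending))
                events.append(("update_owned_params", pending))  # step 8
                pending = []
    if pending:  # flush a final, partially filled bucket
        events.append(("reduce_scatter_grads", pending))
        events.append(("update_owned_params", pending))
    return events


if __name__ == "__main__":
    for op, layers in step_schedule(num_layers=4, M=2, L=2):
        print(op, layers)
    print(num_training_steps(dataset_size=1000, B=64, E=3))  # 48
```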

Conclusion

I hope this gave you an intuition for the complex problem of scaling neural network training and for how to approach it using data parallelism. In the next part, I will talk about the other popular technique for scaling training, called model parallelism. We will also look at recent work on hybrid parallelism techniques and on using hierarchical memory to scale training.

Thanks for reading and please reach out to me if you have any questions or would like to discuss any of this.

Shivam Bharuka

Computer Architect. Currently @facebookai. Computer Engineering @Illinois_Alma