Training large models with your GPU

Yuanhao
4 min read · Jan 6, 2023


In the last post, I shared my story of the Kaggle Jigsaw Multilingual Toxic Comment Classification competition. At that time, I only had a 1080Ti with 11 GB of VRAM, which made it impossible for me to train the SOTA XLM-RoBERTa large model, as it requires more VRAM than I had. In this post, I want to share some tips on how to reduce VRAM usage so that you can train larger deep neural networks with your GPU.

The first tip is to use mixed-precision training. When training a model, all parameters are generally stored in VRAM, so total VRAM usage is roughly the number of stored parameters times the memory footprint of a single parameter. A bigger model not only means better performance but also higher VRAM usage. Since performance is quite important, for example in Kaggle competitions, we do not want to reduce the model capacity, so the remaining lever is to reduce the memory footprint of every single variable. By default, 32-bit floating point is used, so one variable consumes 4 bytes. Fortunately, people found that we can use 16-bit floating point without much loss of accuracy. This means we can cut memory consumption in half! In addition, using lower precision also improves training speed, especially on GPUs with Tensor Core support.
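As a rough back-of-the-envelope illustration: XLM-RoBERTa large has roughly 550 million parameters, so its weights alone occupy about 2.2 GB in 32-bit precision versus about 1.1 GB in 16-bit, and that is before counting gradients, optimizer states, and activations.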

Since version 1.6, PyTorch has supported automatic mixed-precision training. The framework automatically decides which operations need full 32-bit precision and which can safely run in 16-bit. Below is a sample code from the official PyTorch documentation.

import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast

# Creates the model and optimizer in default (full) precision.
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        # Runs the forward pass with autocasting.
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)

        # Scales the loss and calls backward() on the scaled loss to create scaled gradients.
        # Backward passes under autocast are not recommended.
        # Backward ops run in the same dtype autocast chose for the corresponding forward ops.
        scaler.scale(loss).backward()

        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called;
        # otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)

        # Updates the scale for the next iteration.
        scaler.update()

The second tip is to use gradient accumulation. The idea behind gradient accumulation is simple: perform several backward passes with the same model parameters before the optimizer updates them; the gradients calculated in each backward pass are accumulated (summed). If your actual batch size is N and you accumulate gradients over M steps, your effective batch size is N*M. However, the training result will not be strictly identical, because some quantities, such as batch-normalization statistics, cannot be accumulated perfectly. A minimal sketch is shown below.
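Here is a minimal sketch of a plain accumulation loop (no mixed precision), reusing the model, optimizer, loss_fn, and data placeholders from the snippet above; accum_steps is a hypothetical name for M:

accum_steps = 4  # M: number of micro-batches to accumulate per optimizer step

optimizer.zero_grad()
for step, (input, target) in enumerate(data):
    output = model(input)
    # Divide by M so the summed gradients match the mean over the effective batch.
    loss = loss_fn(output, target) / accum_steps
    loss.backward()  # gradients accumulate (sum) into the .grad buffers

    if (step + 1) % accum_steps == 0:
        optimizer.step()       # update using the accumulated gradients
        optimizer.zero_grad()  # reset for the next effective batch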

There are a couple of things to note about gradient accumulation. When you combine it with mixed-precision training, the gradient scale should be calibrated for the effective batch, and scale updates (scaler.update()) should occur at effective-batch granularity, i.e. only on the iterations where the optimizer actually steps. When you combine it with distributed data-parallel (DDP) training, use the no_sync() context manager to disable the gradient all-reduce for the first M-1 micro-batches of each effective batch. Both caveats are illustrated in the sketch below.
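The following is a rough sketch of how both caveats look in code, reusing scaler, autocast, and accum_steps from the snippets above and assuming the model has been wrapped in DistributedDataParallel as a hypothetical ddp_model:

import contextlib

optimizer.zero_grad()
for step, (input, target) in enumerate(data):
    is_update_step = (step + 1) % accum_steps == 0
    # Skip the gradient all-reduce on the first M-1 micro-batches of each effective batch.
    sync_ctx = contextlib.nullcontext() if is_update_step else ddp_model.no_sync()
    with sync_ctx:
        with autocast():
            output = ddp_model(input)
            loss = loss_fn(output, target) / accum_steps
        scaler.scale(loss).backward()

    if is_update_step:
        # Unscale, step, and update the scale only once per effective batch.
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()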

Last but not least is gradient checkpointing. The basic idea of gradient checkpointing is to save the intermediate results of only some nodes as checkpoints and to re-compute the parts in between those nodes during backpropagation. According to the authors of gradient checkpointing, it lets them fit models 10x larger onto their GPUs at the cost of only a 20% increase in computation time. PyTorch has officially supported it since version 0.4.0, and some very commonly used libraries such as Hugging Face Transformers also support this feature.

Below is an example of how to turn on gradient checkpointing with a transformer model:

from transformers import AutoModel

bert = AutoModel.from_pretrained(pretrained_model_name)
bert.config.gradient_checkpointing = True  # newer transformers versions: bert.gradient_checkpointing_enable()
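For a plain PyTorch model (without the Transformers helper), the same idea is available through torch.utils.checkpoint. Below is a minimal, self-contained sketch with a hypothetical stack of linear layers:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedMLP(nn.Module):
    def __init__(self, dim=1024, depth=8):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, dim), nn.ReLU()) for _ in range(depth)
        )

    def forward(self, x):
        for layer in self.layers:
            # Only the inputs of each checkpointed segment are kept; the activations
            # inside are recomputed during backpropagation, trading compute for memory.
            x = checkpoint(layer, x)
        return x

model = CheckpointedMLP().cuda()
x = torch.randn(32, 1024, device="cuda", requires_grad=True)
model(x).sum().backward()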

To close this post, I want to share a simple benchmark I ran on the HP Z4 workstation. As I mentioned before, the workstation is equipped with two RTX 6000 GPUs with 24 GB of VRAM each, though in these experiments I only used one of them. I trained XLM-RoBERTa base/large with different configurations, using the training set of the Kaggle Jigsaw Multilingual Toxic Comment Classification competition.

As we can see, mixed-precision training not only reduces memory consumption but also provides a significant speedup. Gradient checkpointing is also very powerful: it reduces VRAM usage from 23.5 GB to 11.8 GB! I wish I had known about these techniques earlier, so that I could have trained the XLM-RoBERTa large model even on my previous 11 GB 1080Ti GPU :)
