Warm-up is a way to reduce the primacy effect of the early training examples on adaptive optimizers such as Adam or AdamW. It gives them time to build reliable estimates of the gradient statistics before large updates are applied. Without it, you may need to run a few extra epochs to reach the desired convergence.
Using too large a learning rate may cause numerical instability, especially at the very beginning of training, when the parameters are randomly initialized. The warmup strategy increases the learning rate linearly from 0 to the initial learning rate during the first N epochs or m batches.
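Written out, a linear warmup over N steps sets lr_t = lr_0 * t / N for t < N and lr_t = lr_0 afterwards, where a "step" is an epoch or a batch depending on how often you call it. The plain-Python sketch below only illustrates that formula. The function name `linear_warmup_lr` and its parameters are hypothetical, not taken from any library.

```python
def linear_warmup_lr(step: int, base_lr: float, warmup_steps: int) -> float:
    """Hypothetical helper: 0 at step 0, rising linearly to base_lr at step == warmup_steps."""
    if warmup_steps <= 0 or step >= warmup_steps:
        # No warmup configured, or warmup already finished: use the regular learning rate.
        return base_lr
    # Fraction of the warmup completed so far, scaled onto the target learning rate.
    return base_lr * step / warmup_steps


# Example: warm up to 0.1 over N = 5 epochs, then hold the rate constant.
schedule = [linear_warmup_lr(epoch, base_lr=0.1, warmup_steps=5) for epoch in range(8)]
for epoch, lr in enumerate(schedule):
    print(epoch, lr)
```

To warm up over m batches instead of N epochs, pass the global batch counter as `step` and m as `warmup_steps`.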
In some cases, careful parameter initialization alone is not sufficient to guarantee a good solution. This is particularly a problem for some advanced network designs, which can make the optimization unstable. We could address this by choosing a learning rate small enough to prevent divergence at the beginning. Unfortunately, this means that progress is slow. Conversely, a large learning rate tends to cause divergence early on.
A rather simple fix for this dilemma is to use a warmup period during which the learning rate increases to its initial maximum, and then to cool the rate down until the end of the optimization process. Warmup steps are just a few updates with a low learning rate at the beginning of training. After this warmup, you use the regular learning rate (schedule) to train your model to convergence.
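The text does not say which cool-down to use after the warmup. The sketch below assumes cosine decay, which is one common choice among several (step decay, as in the example that follows, is another). The function `warmup_cosine_lr` and its parameters are hypothetical names chosen for illustration.

```python
import math


def warmup_cosine_lr(step: int, max_lr: float, warmup_steps: int,
                     total_steps: int, min_lr: float = 0.0) -> float:
    """Hypothetical helper: linear warmup to max_lr, then cosine decay down to min_lr."""
    if step < warmup_steps:
        # Warmup phase: ramp linearly from 0 up to max_lr.
        return max_lr * step / warmup_steps
    # Cool-down phase: fraction of the post-warmup budget used so far, clamped to [0, 1].
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    # 0.5 * (1 + cos(pi * progress)) falls smoothly from 1 to 0 as progress goes 0 -> 1.
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Example: 10 warmup steps out of 100, peaking at 0.1.
schedule = [warmup_cosine_lr(s, max_lr=0.1, warmup_steps=10, total_steps=100) for s in range(101)]
peak = max(schedule)
print(f"peak lr {peak:.4f} at step {schedule.index(peak)}, final lr {schedule[-1]:.4f}")
```

The warmup keeps the first updates small while the parameters are still random; the cosine tail then lowers the rate gradually rather than in abrupt drops.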
```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

# Third-party package (not part of PyTorch) that provides GradualWarmupScheduler.
from warmup_scheduler import GradualWarmupScheduler

if __name__ == '__main__':
    model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
    optim = SGD(model, 0.1)

    # scheduler_warmup is chained with scheduler_steplr: after 5 warmup epochs,
    # StepLR takes over and multiplies the rate by 0.1 every 10 epochs.
    scheduler_steplr = StepLR(optim, step_size=10, gamma=0.1)
    scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=5, after_scheduler=scheduler_steplr)

    # This zero-gradient update is needed to avoid a warning message, issue #8.
    optim.zero_grad()
    optim.step()

    for epoch in range(1, 20):
        scheduler_warmup.step(epoch)
        # param_groups is a list of dicts; index the first (and only) group.
        print(epoch, optim.param_groups[0]['lr'])

        optim.step()  # parameter update (in real training, loss.backward() comes first)
```
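If you prefer not to depend on the third-party `warmup_scheduler` package, a similar schedule can be sketched with PyTorch's built-in `torch.optim.lr_scheduler.LambdaLR`, which multiplies the base learning rate by a factor returned from a function of the epoch. The function `warmup_then_step` and the constants below are assumptions chosen to mirror the example above (5 warmup epochs, then a decay by 0.1 every 10 epochs). This is a sketch, not a drop-in replacement for `GradualWarmupScheduler`.

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

WARMUP_EPOCHS = 5  # assumed, mirrors total_epoch=5 above
STEP_SIZE = 10     # assumed, mirrors StepLR(step_size=10)
GAMMA = 0.1        # assumed, mirrors StepLR(gamma=0.1)


def warmup_then_step(epoch: int) -> float:
    """Hypothetical factor applied to the base learning rate at the given epoch."""
    if epoch < WARMUP_EPOCHS:
        # Linear warmup: the factor grows from 0 towards 1.
        return epoch / WARMUP_EPOCHS
    # Step decay: multiply by GAMMA once every STEP_SIZE epochs after the warmup.
    return GAMMA ** ((epoch - WARMUP_EPOCHS) // STEP_SIZE)


params = [torch.nn.Parameter(torch.randn(2, 2))]
optimizer = SGD(params, lr=0.1)
scheduler = LambdaLR(optimizer, lr_lambda=warmup_then_step)

lr_history = []
for epoch in range(20):
    lr_history.append(optimizer.param_groups[0]["lr"])
    # Forward pass and loss.backward() would go here; step() skips params without grads.
    optimizer.step()
    scheduler.step()  # call after optimizer.step(), once per epoch

for epoch, lr in enumerate(lr_history):
    print(epoch, lr)
```

With this choice of factor, epoch 0 runs with a learning rate of 0, so in practice you may want to start the warmup from a small nonzero factor instead.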