Source code for pyblaze.optim.lamb

import math
import torch
import torch.optim as optim

class LAMB(optim.Optimizer):
    """
    Optimizer presented in "Large Batch Optimization for Deep Learning: Training BERT in 76
    Minutes" (You et al., 2019). The LAMB optimizer ("Layer-wise Adaptive Moments optimizer for
    Batch training") enables training on very large batches and provides an alternative to Adam,
    whose performance deteriorates for large batches.
    """
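
    # A sketch of the update implemented in ``step`` below (per parameter tensor, using the
    # paper's notation informally): with gradient g_t and running moments
    #   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    #   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t ** 2
    # the Adam-style step is u_t = m_t / (sqrt(v_t) + eps) + weight_decay * p, the step size
    # carries the bias corrections, eta_t = lr * sqrt(1 - beta2 ** t) / (1 - beta1 ** t), and
    # a layer-wise trust ratio scales the update:
    #   p <- p - eta_t * min(||p|| / ||u_t||, 10) * u_t
    # (the ratio defaults to 1 when either norm is zero).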

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0,
                 lr_decay=0):
        """
        Initializes a new LAMB optimizer.

        Parameters
        ----------
        params: iterable of torch.Tensor or dict of str -> torch.Tensor
            The parameters to optimize, optionally overriding default values for parameter
            groups.
        lr: float, default: 1e-3
            The learning rate.
        betas: tuple of (float, float), default: (0.9, 0.999)
            The betas used to compute the running averages of the gradient and its square.
        eps: float, default: 1e-8
            Epsilon parameter for numerical stability.
        weight_decay: float, default: 0
            L2 penalty to apply.
        lr_decay: float, default: 0
            Learning rate decay applied over the course of the updates.
        """
        assert lr > 0, "Learning rate must be greater than 0."
        assert betas[0] > 0 and betas[1] > 0, "Beta values must be greater than 0."
        assert betas[0] < 1 and betas[1] < 1, "Beta values must be smaller than 1."
        assert eps > 0, "Epsilon must be greater than 0."

        defaults = dict(
            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, lr_decay=lr_decay
        )
        super().__init__(params, defaults)

    def step(self, closure=None):
        """
        Performs a single optimization step.

        Parameters
        ----------
        closure: callable, default: None
            A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "LAMB does not support sparse gradients. " +
                        "Consider using SparseAdam instead."
                    )

                state = self.state[p]

                # Initialize state if needed
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                # Get relevant parameters
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # Next step
                state['step'] += 1

                # Update learning rate if needed (inverse time decay)
                if group['lr_decay'] > 0:
                    decay = group['lr_decay']
                    lr = group['lr'] / (1 + decay * state['step'])
                else:
                    lr = group['lr']

                # Decay the first and second moment running average coefficients
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = exp_avg_sq.sqrt().add_(group['eps'])

                # Fold the bias corrections into the step size as in Adam
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = lr * math.sqrt(bias_correction2) / bias_correction1

                adam_step = exp_avg / denom

                # L2 penalty
                if group['weight_decay'] != 0:
                    adam_step.add_(p.data, alpha=group['weight_decay'])

                # Compute the trust ratio (parameter norm over update norm) and cap it at 10
                r1 = p.data.norm()
                r2 = adam_step.norm()
                if r1 != 0 and r2 != 0:
                    trust_ratio = min((r1 / r2).item(), 10)
                else:
                    trust_ratio = 1

                # Update weights
                p.data.add_(adam_step, alpha=-step_size * trust_ratio)

        return loss
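
A minimal usage sketch (not part of the module): the optimizer is constructed like any other
torch.optim optimizer, optionally with per-group overrides of the defaults, and ``step`` can be
called after ``backward``. The model, data, and import path below are hypothetical placeholders;
the import path simply follows this page's module path.

import torch
import torch.nn as nn
from pyblaze.optim.lamb import LAMB

model = nn.Linear(32, 2)  # hypothetical model

# One parameter group with the defaults, one with a larger learning rate and no weight decay.
optimizer = LAMB(
    [
        {'params': model.weight},
        {'params': model.bias, 'lr': 1e-2, 'weight_decay': 0},
    ],
    lr=1e-3, weight_decay=1e-2,
)

inputs = torch.randn(8, 32)
targets = torch.randint(0, 2, (8,))
criterion = nn.CrossEntropyLoss()

optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()
optimizer.step()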