colossalai.nn.optimizer.fused_sgd
- class colossalai.nn.optimizer.fused_sgd.FusedSGD(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False, wd_after_momentum=False, materialize_master_grads=True, set_grad_none=False)
Implements stochastic gradient descent (optionally with momentum).
Currently GPU-only. Requires ColossalAI to be installed via
pip install -v --no-cache-dir --global-option="--cuda_ext" ./
This version of fused SGD implements two fusions:
Fusion of the SGD update’s elementwise operations
A multi-tensor apply launch that batches the elementwise updates applied to all the model’s parameters into one or a few kernel launches.
colossalai.nn.optimizer.FusedSGD may be used as a drop-in replacement for torch.optim.SGD.
colossalai.nn.optimizer.FusedSGD may be used with or without Amp.
Nesterov momentum is based on the formula from "On the importance of initialization and momentum in deep learning".
- Parameters
params (iterable) – iterable of parameters to optimize or dicts defining parameter groups
lr (float) – learning rate
momentum (float, optional) – momentum factor (default: 0)
weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)
dampening (float, optional) – dampening for momentum (default: 0)
nesterov (bool, optional) – enables Nesterov momentum (default: False)
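How these parameters interact can be sketched with a scalar reference step. This is only an illustration, assuming the torch.optim.SGD semantics that FusedSGD is described as replacing, with weight decay applied before momentum (the wd_after_momentum=False default). The function name sgd_step is hypothetical and is not part of ColossalAI; the fused CUDA kernel itself is not shown.

```python
def sgd_step(p, g, buf, lr, momentum=0.0, dampening=0.0,
             weight_decay=0.0, nesterov=False):
    """One SGD step for a scalar parameter (assumed torch.optim.SGD semantics).

    p is the parameter, g its gradient, buf the momentum buffer (None before
    the first step). Returns the updated (p, buf) pair.
    """
    if weight_decay != 0:
        # L2 penalty: add weight_decay * p to the gradient.
        g = g + weight_decay * p
    if momentum != 0:
        if buf is None:
            # First step: the buffer starts as the gradient, without dampening.
            buf = g
        else:
            buf = momentum * buf + (1 - dampening) * g
        # Nesterov looks ahead along the buffer; plain momentum uses the buffer.
        g = g + momentum * buf if nesterov else buf
    return p - lr * g, buf


# Two consecutive steps with momentum, carrying the buffer between calls.
p, buf = sgd_step(1.0, 0.5, None, lr=0.1, momentum=0.9)
p, buf = sgd_step(p, 0.5, buf, lr=0.1, momentum=0.9)
```

The buffer is returned rather than stored so the sketch stays stateless; a real optimizer keeps one buffer per parameter tensor.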
Note
The implementation of SGD with Momentum/Nesterov subtly differs from Sutskever et al. and implementations in some other frameworks. Considering the specific case of momentum, the update can be written as
\[\begin{split}v = \rho * v + g \\ p = p - lr * v\end{split}\]
where p, g, v and \(\rho\) denote the parameters, gradient, velocity, and momentum respectively. This is in contrast to Sutskever et al. and other frameworks, which employ an update of the form
\[\begin{split}v = \rho * v + lr * g \\ p = p - v\end{split}\]
The Nesterov version is analogously modified.
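The practical effect of the two formulations can be checked with a small scalar sketch. With a constant learning rate the two velocities differ only by the factor lr, so the parameters follow the same trajectory; once the learning rate changes mid-run, they diverge. The function names and the sample gradients below are made up for illustration.

```python
def lr_outside_velocity(p, grads, lrs, rho):
    """v = rho * v + g; p = p - lr * v (the form described in this note)."""
    v = 0.0
    for g, lr in zip(grads, lrs):
        v = rho * v + g
        p = p - lr * v
    return p


def lr_inside_velocity(p, grads, lrs, rho):
    """v = rho * v + lr * g; p = p - v (the Sutskever et al. form)."""
    v = 0.0
    for g, lr in zip(grads, lrs):
        v = rho * v + lr * g
        p = p - v
    return p


grads = [1.0, 0.5, -0.25, 2.0]        # illustrative gradients
rho = 0.9
constant_lr = [0.1, 0.1, 0.1, 0.1]
decayed_lr = [0.1, 0.1, 0.01, 0.01]   # learning rate drops after two steps

# Constant learning rate: both forms agree up to floating-point rounding.
gap_constant = abs(lr_outside_velocity(1.0, grads, constant_lr, rho)
                   - lr_inside_velocity(1.0, grads, constant_lr, rho))

# Decayed learning rate: the forms give different parameters.
gap_decayed = abs(lr_outside_velocity(1.0, grads, decayed_lr, rho)
                  - lr_inside_velocity(1.0, grads, decayed_lr, rho))
print(gap_constant, gap_decayed)
```

When the learning rate sits outside the velocity, a schedule change rescales the whole accumulated velocity at once; when it sits inside, gradients accumulated earlier keep the step size they were added with.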
- step(closure=None)
Performs a single optimization step.
- Parameters
closure (callable, optional) – A closure that reevaluates the model and returns the loss.