colossalai.amp.naive_amp
- colossalai.amp.naive_amp.convert_to_naive_amp(model, optimizer, amp_config)
A helper function to wrap training components with naive AMP modules. In this mode, we forcibly cast the model weights and inputs to FP16, and cast the model outputs to FP32 to calculate loss, which is equivalent to Apex O3.
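Why the outputs are cast back to FP32 before the loss is computed can be shown without a GPU. The sketch below illustrates the numerics only and is not ColossalAI code. It uses Python's `struct` module to round values to IEEE half (`"e"`) and single (`"f"`) precision, and the helpers `to_fp16`, `to_fp32` and `accumulate` are hypothetical. It sums 10,000 small per-sample losses whose exact total is 1.0, rounding the running total after every addition.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def accumulate(values, cast):
    """Sum values, rounding the running total with `cast` after every add."""
    total = cast(0.0)
    for v in values:
        total = cast(total + cast(v))
    return total


# 10,000 small per-sample losses whose exact sum is 1.0.
losses = [1e-4] * 10_000
fp16_sum = accumulate(losses, to_fp16)
fp32_sum = accumulate(losses, to_fp32)
print(f"fp16 running sum: {fp16_sum}")
print(f"fp32 running sum: {fp32_sum}")
```

The FP16 total stops growing (near 0.25) once half the gap between neighbouring FP16 values exceeds each increment, while the FP32 total stays close to 1.0.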
- Parameters
model (torch.nn.Module) – your model object.
optimizer (torch.optim.Optimizer) – your optimizer object.
amp_config (colossalai.context.Config or dict) – configuration for naive-mode AMP.
The amp_config should contain the parameters below:
verbose (bool, optional): if set to `True`, will print debug info (Default: False).
clip_grad_norm (float, optional): clip gradients with this global L2 norm (Default: 0). Note that clipping is ignored if clip_grad_norm == 0.
dynamic_grad_scale (bool): whether to use a dynamic grad scaler.
- Returns
A tuple (model, optimizer)
- Return type
Tuple
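A minimal sketch of an amp_config, assuming a plain dict (a colossalai.context.Config is accepted too, per the parameter list). `resolve_naive_amp_config` is a hypothetical helper, not part of ColossalAI. It fills in the defaults listed above for the optional keys. Because dynamic_grad_scale is not marked optional, the sketch treats it as required; that is an assumption, not documented library behaviour.

```python
from typing import Any, Dict, Mapping

# Defaults for the optional keys, taken from the amp_config list above.
OPTIONAL_DEFAULTS: Dict[str, Any] = {"verbose": False, "clip_grad_norm": 0}


def resolve_naive_amp_config(amp_config: Mapping[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper (not in ColossalAI): merge documented defaults into amp_config."""
    resolved = dict(OPTIONAL_DEFAULTS)
    resolved.update(amp_config)
    # Assumption: required because the docs do not mark it optional.
    if "dynamic_grad_scale" not in resolved:
        raise KeyError("amp_config needs a 'dynamic_grad_scale' entry")
    if resolved["clip_grad_norm"] < 0:
        raise ValueError("clip_grad_norm must be >= 0; 0 disables clipping")
    return resolved


config = resolve_naive_amp_config({"dynamic_grad_scale": True, "clip_grad_norm": 1.0})
print(config)

# With ColossalAI installed, the dict is passed as the third argument:
# model, optimizer = convert_to_naive_amp(model, optimizer, config)
```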
- class colossalai.amp.naive_amp.NaiveAMPOptimizer(optim, *args, **kwargs)
A wrapper class for an optimizer that casts all parameters to fp16.
- Parameters
optim (torch.optim.Optimizer) – A normal optimizer like Adam or SGD.
grad_scaler (BaseGradScaler) – grad scaler for gradients, chosen from constant_grad_scaler or dynamic_grad_scaler.
clip_grad_norm (float, optional) – clip gradients with this global L2 norm. Default 0.
verbose (bool, optional) – if set to True, will print debug info. Default False.
Note
Clipping is ignored if clip_grad_norm equals 0.
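The clipping rule above (one L2 norm over all gradients, disabled when clip_grad_norm is 0) can be sketched in plain Python. This is the generic global-norm clipping technique, not ColossalAI's code, and `clip_grads_by_global_norm` is a hypothetical name; each inner list stands in for one parameter's gradient. In a loss-scaled setup the norm is only meaningful after the gradients have been divided by the loss scale.

```python
import math
from typing import List


def clip_grads_by_global_norm(grads: List[List[float]], clip_grad_norm: float) -> float:
    """Rescale all gradients in place so their joint L2 norm is at most clip_grad_norm.

    Returns the global norm measured before clipping. clip_grad_norm == 0
    disables clipping, matching the note above.
    """
    # One norm over every gradient of every parameter, not one norm per parameter.
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    if clip_grad_norm > 0 and total_norm > clip_grad_norm:
        # total_norm > clip_grad_norm > 0 here, so the division is safe.
        coef = clip_grad_norm / total_norm
        for grad in grads:
            for i, g in enumerate(grad):
                grad[i] = g * coef
    return total_norm


grads = [[3.0], [4.0]]  # two hypothetical parameters' gradients, global norm 5
print(clip_grads_by_global_norm(grads, 1.0), grads)
unclipped = [[3.0], [4.0]]
clip_grads_by_global_norm(unclipped, 0)  # 0 disables clipping
print(unclipped)  # [[3.0], [4.0]]
```

Scaling every gradient by the same factor preserves the overall update direction, which is why a single global norm is used rather than clipping each parameter separately.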
- class colossalai.amp.naive_amp.FP16Optimizer(optimizer, grad_scaler, verbose=False, clip_grad_norm=0, dp_process_group=None, mp_process_group=None)
Float16 optimizer for fp16 and bf16 data types.
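The two 16-bit formats trade range against precision: fp16 has more mantissa bits but a largest finite value of 65504, while bf16 keeps float32's exponent range with only 8 significant bits. The sketch below illustrates this with the standard `struct` module and is not ColossalAI code. `to_fp16`, `to_bf16` and `fits_in_fp16` are hypothetical helpers, and `to_bf16` assumes finite inputs within float32 range.

```python
import struct


def to_fp16(x: float) -> float:
    """Round x to IEEE 754 half precision; struct raises OverflowError past the fp16 range."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    """Round a finite float32-range value to bfloat16 (upper 16 bits of float32), ties to even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Bias the 16 bits about to be dropped so the truncation rounds half to even.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits = (bits >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def fits_in_fp16(x: float) -> bool:
    """True if x rounds to a finite fp16 value instead of overflowing."""
    try:
        to_fp16(x)
    except OverflowError:
        return False
    return True


print(fits_in_fp16(70000.0))  # False: beyond fp16's largest finite value, 65504
print(to_bf16(70000.0))       # 70144.0: bf16 keeps the range but rounds coarsely
print(to_fp16(1 + 2 ** -10), to_bf16(1 + 2 ** -10))  # 1.0009765625 1.0
```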
- Parameters
optimizer (torch.optim.Optimizer) – base optimizer such as Adam or SGD
grad_scaler (BaseGradScaler) – grad scaler for gradients, chosen from constant_grad_scaler or dynamic_grad_scaler.
clip_grad_norm (float, optional) – clip gradients with this global L2 norm. Default 0. Note that clipping is ignored if clip_grad_norm == 0.
verbose (bool, optional) – if set to True, will print debug info. Default False.
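The grad_scaler choice above (constant or dynamic) can be illustrated with a simplified dynamic loss scaler. This is a generic sketch of the technique, not ColossalAI's dynamic_grad_scaler. The class name `SketchDynamicGradScaler`, its method and all default values (initial scale 2**16, growth factor 2, backoff 0.5, interval 1000, minimum scale 1) are illustrative assumptions. The loss is multiplied by `scale` before backward; any inf/NaN gradient means the step is skipped and the scale shrinks, and after `growth_interval` clean steps the scale grows.

```python
import math
from typing import List, Optional


class SketchDynamicGradScaler:
    """Hypothetical, simplified dynamic loss scaler; not ColossalAI's implementation."""

    def __init__(self, initial_scale: float = 2.0 ** 16, growth_factor: float = 2.0,
                 backoff_factor: float = 0.5, growth_interval: int = 1000,
                 min_scale: float = 1.0) -> None:
        self.scale = initial_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.min_scale = min_scale
        self._clean_steps = 0

    def unscale_and_update(self, scaled_grads: List[float]) -> Optional[List[float]]:
        """Return the unscaled gradients, or None if the optimizer step must be skipped."""
        if not all(math.isfinite(g) for g in scaled_grads):
            # Overflow: shrink the scale and restart the clean-step count.
            self.scale = max(self.scale * self.backoff_factor, self.min_scale)
            self._clean_steps = 0
            return None
        # Divide by the same scale that multiplied the loss during backward.
        grads = [g / self.scale for g in scaled_grads]
        self._clean_steps += 1
        if self._clean_steps == self.growth_interval:
            self.scale *= self.growth_factor
            self._clean_steps = 0
        return grads


scaler = SketchDynamicGradScaler(initial_scale=1024.0, growth_interval=2)
print(scaler.unscale_and_update([float("inf"), 1.0]), scaler.scale)  # None 512.0
print(scaler.unscale_and_update([512.0, 1024.0]), scaler.scale)      # [1.0, 2.0] 512.0
print(scaler.unscale_and_update([512.0]), scaler.scale)              # [1.0] 1024.0
```

A constant scaler corresponds to removing both scale updates so that `scale` never changes.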