colossalai.amp.naive_amp.naive_amp
- class colossalai.amp.naive_amp.naive_amp.NaiveAMPOptimizer(optim, *args, **kwargs)
A wrapper class for an optimizer that casts all parameters to fp16.
- Parameters
optim (torch.optim.Optimizer) – A standard optimizer, such as Adam or SGD
args – Positional arguments used to initialize the FP16 optimizer
kwargs – Keyword arguments used to initialize the FP16 optimizer
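Why wrap the optimizer at all? A common reason in fp16 training is that very small updates get rounded away when applied directly to fp16 values. The sketch below is a hypothetical, pure-Python illustration of the usual remedy, an fp32 "master" copy of the weights. It is not ColossalAI's implementation. `ToyFP16SGD`, `to_fp16` and `to_fp32` are illustrative names, and the standard-library `struct` module (format codes `"e"` and `"f"`) stands in for real fp16/fp32 tensors.

```python
import struct


def to_fp16(x: float) -> float:
    # Round a Python float to the nearest IEEE 754 half-precision (fp16) value.
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    # Round a Python float to the nearest IEEE 754 single-precision (fp32) value.
    return struct.unpack("<f", struct.pack("<f", x))[0]


class ToyFP16SGD:
    """Hypothetical toy optimizer wrapper (not ColossalAI's implementation).

    Model parameters are kept in fp16, while an fp32 master copy receives
    the SGD updates, so steps smaller than fp16 resolution still accumulate.
    """

    def __init__(self, params, lr):
        self.lr = lr
        self.master = [to_fp32(p) for p in params]       # fp32 master weights
        self.params = [to_fp16(p) for p in self.master]  # fp16 model weights

    def step(self, grads):
        for i, g in enumerate(grads):
            # Update the fp32 master copy, then refresh the fp16 copy from it.
            self.master[i] = to_fp32(self.master[i] - self.lr * g)
            self.params[i] = to_fp16(self.master[i])


# Compare against applying the same updates directly to an fp16 value.
opt = ToyFP16SGD([1.0], lr=0.1)
naive = to_fp16(1.0)
for _ in range(10):
    # Each update is 1e-4, less than half the fp16 spacing just below 1.0.
    opt.step([0.001])
    naive = to_fp16(naive - 0.1 * 0.001)

print(opt.params[0], naive)
```

When the update goes straight into fp16, the value stays at 1.0. With the master copy, the accumulated change does eventually show up in the fp16 parameter.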
- class colossalai.amp.naive_amp.naive_amp.NaiveAMPModel(model, output_to_fp32=True, parallel_mode=ParallelMode.DATA, sync_buffer=True)
A wrapper class for a model that casts the model to fp16 and automatically casts its inputs and outputs.
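The input/output casting that such a wrapper performs can be pictured with a minimal pure-Python sketch. It assumes a toy "model" that computes a dot product. `ToyAMPModel`, `to_fp16` and `to_fp32` are hypothetical names, and only the `output_to_fp32` flag mirrors a parameter from the signature above. This is not how `NaiveAMPModel` is implemented.

```python
import struct


def to_fp16(x: float) -> float:
    # Round a Python float to the nearest IEEE 754 half-precision (fp16) value.
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    # Round a Python float to the nearest IEEE 754 single-precision (fp32) value.
    return struct.unpack("<f", struct.pack("<f", x))[0]


class ToyAMPModel:
    """Hypothetical toy model wrapper (not ColossalAI's NaiveAMPModel)."""

    def __init__(self, weights, output_to_fp32=True):
        self.weights = [to_fp16(w) for w in weights]  # parameters cast to fp16
        self.output_to_fp32 = output_to_fp32

    def __call__(self, inputs):
        x = [to_fp16(v) for v in inputs]  # cast incoming inputs to fp16
        # Dot product, rounding every intermediate result to fp16.
        acc = 0.0
        for w, v in zip(self.weights, x):
            acc = to_fp16(acc + to_fp16(w * v))
        # Optionally cast the output back to fp32 before returning it.
        return to_fp32(acc) if self.output_to_fp32 else acc


model = ToyAMPModel([0.1, 0.2])
print(model([1.0, 1.0]))  # close to, but not exactly, 0.3
```

Plain Python represents both precisions as `float`, so the final fp32 cast here only mirrors the dtype change a tensor library would make. Because every fp16 value is exactly representable in fp32, that cast does not change the value.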