colossalai.zero.init_ctx
- class colossalai.zero.init_ctx.ZeroInitContext(convert_fp16, target_device, shard_strategy, shard_param=False, shard_grad=False, rm_torch_payload_on_the_fly=False, model_numel_tensor=tensor([0], dtype=torch.int32), dp_process_group=None)
A context for initializing a model. Inside it:
Convert the model's parameters to fp16 (when convert_fp16 is True).
Convert the parameters of each module to type ShardedParameter.
Shard the parameters and gradients according to the shard_param and shard_grad flags.
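The sketch below illustrates two of the transformations listed above without ColossalAI or torch: rounding values to fp16 and splitting a flattened parameter across data-parallel ranks. The helpers to_fp16, shard_flat and gather_flat are hypothetical, and the even split with zero padding is an assumption about how a shard strategy might behave, not the behaviour of any specific BaseShardStrategy subclass.

```python
import struct
from typing import List, Tuple


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    # struct's "e" format packs/unpacks a 16-bit float
    return struct.unpack("<e", struct.pack("<e", x))[0]


def shard_flat(values: List[float], rank: int, world_size: int) -> Tuple[List[float], int]:
    """Return this rank's slice of a flattened parameter plus the padding added.

    Hypothetical even-split strategy: pad with zeros so the length divides
    evenly by world_size, then give each rank one contiguous chunk.
    """
    numel = len(values)
    chunk = -(-numel // world_size)      # ceil(numel / world_size)
    pad = chunk * world_size - numel     # zeros needed for an even split
    padded = values + [0.0] * pad
    start = rank * chunk
    return padded[start:start + chunk], pad


def gather_flat(shards: List[List[float]], pad: int) -> List[float]:
    """Concatenate all ranks' shards and strip the padding."""
    full = [v for shard in shards for v in shard]
    return full[:len(full) - pad]


# Usage: a 5-element "weight" converted to fp16 and sharded over 2 ranks
weight = [to_fp16(v) for v in [0.1, 0.2, 0.3, 0.4, 0.5]]
shards = [shard_flat(weight, r, 2)[0] for r in range(2)]
restored = gather_flat(shards, pad=1)
```

With five values and world_size=2, each rank holds three values and the last rank carries one zero of padding, which gather_flat strips off again.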
- Parameters
convert_fp16 (bool) – Whether to convert params to fp16.
target_device (torch.device) – The device on which parameter data resides after exiting the context.
shard_strategy (BaseShardStrategy) – Shard strategy instance.
shard_param (bool, optional) – Whether parameters are sharded after exiting the context. Defaults to False.
shard_grad (bool, optional) – Whether gradients are sharded after exiting the context. Defaults to False.
rm_torch_payload_on_the_fly (bool, optional) – If set to True, remove the tensor payload on param.data as soon as module initialization finishes. This reduces memory usage while initializing the model, but it is not suitable for all models, especially those that perform weight initialization in __init__. If set to False, the tensor payload on param.data is removed only after the context exits; use this when the __init__ of a module contains logic that operates on tensors (see, for example, torchvision's resnet18). Defaults to False.
model_numel_tensor (torch.Tensor, optional) – A tensor that will store the total number of elements in the model's parameters. Defaults to torch.zeros(1, dtype=torch.int).
dp_process_group (Optional[ProcessGroup], optional) – Data parallel process group. Defaults to None.
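To make the difference between the two rm_torch_payload_on_the_fly settings concrete, and to show how a counter like model_numel_tensor accumulates, here is a toy, framework-free sketch. ToyParam, ToyInitContext, post_init, ToyLinear and ToyNet are all hypothetical names; the real ZeroInitContext hooks module construction differently and works on torch tensors, so treat this only as a model of the timing described in the parameter list.

```python
class ToyParam:
    """Hypothetical stand-in for a parameter: a payload plus its element count."""

    def __init__(self, numel):
        self.numel = numel
        self.data = [0.0] * numel  # the "tensor payload"


class ToyInitContext:
    """Hypothetical stand-in for ZeroInitContext (not its real implementation)."""

    def __init__(self, rm_payload_on_the_fly=False, numel_counter=None):
        self.rm_payload_on_the_fly = rm_payload_on_the_fly
        # One-element mutable counter playing the role of model_numel_tensor
        self.numel_counter = numel_counter if numel_counter is not None else [0]
        self._params = []

    def __enter__(self):
        return self

    def post_init(self, *params):
        # Mimics a hook that runs when a submodule's __init__ finishes
        for p in params:
            self.numel_counter[0] += p.numel
            self._params.append(p)
            if self.rm_payload_on_the_fly:
                p.data = None  # drop the payload immediately

    def __exit__(self, exc_type, exc, tb):
        # Payloads still present are dropped when the context exits
        for p in self._params:
            p.data = None
        return False  # never swallow exceptions


class ToyLinear:
    def __init__(self, ctx, n_in, n_out):
        self.weight = ToyParam(n_in * n_out)
        ctx.post_init(self.weight)


class ToyNet:
    def __init__(self, ctx):
        self.fc = ToyLinear(ctx, 2, 3)
        # Weight init in the parent's __init__, after the child module is built;
        # this fails if the child's payload has already been removed
        self.fc.weight.data[:] = [0.5] * self.fc.weight.numel


# Deferred removal: the parent's weight init succeeds
counter = [0]
with ToyInitContext(rm_payload_on_the_fly=False, numel_counter=counter) as ctx:
    net = ToyNet(ctx)
    init_ok = net.fc.weight.data == [0.5] * 6

# On-the-fly removal: the parent's weight init hits a missing payload
try:
    with ToyInitContext(rm_payload_on_the_fly=True) as ctx:
        ToyNet(ctx)
    on_the_fly_failed = False
except TypeError:
    on_the_fly_failed = True
```

In this toy, deferring removal to context exit lets ToyNet's weight initialization run, while removing the payload on the fly makes it fail, mirroring the trade-off described for rm_torch_payload_on_the_fly.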