colossalai.initialize
- colossalai.initialize.get_default_parser()
Builds an argument parser for reading the user's command line. The default arguments cover the configuration, host, port, world size, local rank, and the backend for torch.distributed.
- Returns
The parser with the default arguments. Users may add custom arguments to this parser.
- Return type
argparse.ArgumentParser
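As a rough illustration of how such a parser can be extended, the sketch below builds a stand-in with the standard argparse module and adds one custom argument. The helper name build_default_like_parser, the flag spellings, and the --batch_size argument are assumptions made for this example; they are not taken from the library.

```python
import argparse


def build_default_like_parser() -> argparse.ArgumentParser:
    """Hypothetical stand-in for get_default_parser(); flag names are assumed."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, help="path to the config file")
    parser.add_argument("--host", type=str, help="master address")
    parser.add_argument("--port", type=int, help="master port")
    parser.add_argument("--world_size", type=int, help="number of processes")
    parser.add_argument("--local_rank", type=int, help="rank of the process on its node")
    parser.add_argument("--backend", type=str, default="nccl",
                        help="backend for torch.distributed")
    return parser


parser = build_default_like_parser()
# Users can register their own arguments on the returned parser.
parser.add_argument("--batch_size", type=int, default=32)

# An explicit argument list is passed here so the sketch does not depend on sys.argv.
args = parser.parse_args(
    ["--config", "./config.py", "--host", "127.0.0.1", "--port", "29500",
     "--batch_size", "64"]
)
```

With the real function, the same pattern should apply: call get_default_parser(), add custom arguments, then call parse_args() on the result.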
- colossalai.initialize.launch(config, rank, world_size, host, port, backend='nccl', local_rank=None, seed=1024, verbose=True)
This function first parses the configuration arguments, using parse_args() in case one of the input arguments is not given. It then initializes and sets up the distributed environment by calling global_context's functions.
- Parameters
config (Union[str, dict, Config]) – Either a config object or the path to a config file
rank (int) – Rank for the default process group
world_size (int) – World size of the default process group
host (str) – The master address for distributed training
port (str) – The master port for distributed training
backend (str, optional) – Backend for torch.distributed
local_rank (int, optional) – Rank of the process on its node, used to set the default CUDA device; defaults to None. If local_rank is None, the default device ordinal is calculated automatically
seed (int, optional) – Random seed for every process
verbose (bool, optional) – Whether to print logs
- Raises
Exception – Raised when the config type is wrong
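The config handling described above (a path or an in-memory object is accepted, anything else is rejected) can be pictured with a minimal sketch. The helper classify_config is hypothetical and only mirrors the documented behaviour; it leaves out the library's own Config class and raises TypeError, a subclass of Exception, as an assumed choice.

```python
from typing import Union


def classify_config(config: Union[str, dict]) -> str:
    """Hypothetical helper: report how a config argument would be interpreted."""
    if isinstance(config, dict):
        # An in-memory mapping is used directly.
        return "dict"
    if isinstance(config, str):
        # A string is treated as the path to a config file.
        return "path"
    # Any other type is rejected, mirroring the "Raises" entry.
    raise TypeError(
        f"config must be a str or dict, got {type(config).__name__}"
    )


kind_for_dict = classify_config({"num_epochs": 10})
kind_for_path = classify_config("./config.py")
```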
- colossalai.initialize.launch_from_slurm(config, host, port, backend='nccl', seed=1024, verbose=True)
A wrapper around colossalai.launch for the SLURM launcher. It reads the rank and world size from the environment variables set by SLURM
- Parameters
config (Union[str, dict, Config]) – Either a config object or the path to a config file
host (str) – The master address for distributed training
port (str) – The master port for distributed training
backend (str, optional) – Backend for torch.distributed
seed (int, optional) – Random seed for every process
verbose (bool, optional) – Whether to print logs
- colossalai.initialize.launch_from_openmpi(config, host, port, backend='nccl', seed=1024, verbose=True)
A wrapper around colossalai.launch for the OpenMPI launcher. It reads the rank and world size from the environment variables set by OpenMPI
- Parameters
config (Union[str, dict, Config]) – Either a config object or the path to a config file
host (str) – The master address for distributed training
port (str) – The master port for distributed training
backend (str, optional) – Backend for torch.distributed
seed (int, optional) – Random seed for every process
verbose (bool, optional) – Whether to print logs
- colossalai.initialize.launch_from_torch(config, backend='nccl', seed=1024, verbose=True)
A wrapper around colossalai.launch for torchrun or torch.distributed.launch. It reads the rank and world size from the environment variables set by PyTorch
- Parameters
config (Union[str, dict, Config]) – Either a config object or the path to a config file
backend (str, optional) – Backend for torch.distributed
seed (int, optional) – Random seed for every process
verbose (bool, optional) – Whether to print logs
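The launcher wrappers (SLURM, OpenMPI, torchrun) share one idea: look up the rank and world size in the environment instead of taking them as arguments. The sketch below shows that lookup in isolation. The variable names are the ones these launchers commonly export; which of them each wrapper actually reads is not stated here, so the table is an assumption, and read_launcher_env is a hypothetical helper, not a library function.

```python
import os
from typing import Dict, Mapping, Optional

# Assumed mapping from launcher to the environment variables it exports.
LAUNCHER_ENV = {
    "slurm": {"rank": "SLURM_PROCID", "world_size": "SLURM_NPROCS"},
    "openmpi": {
        "rank": "OMPI_COMM_WORLD_RANK",
        "local_rank": "OMPI_COMM_WORLD_LOCAL_RANK",
        "world_size": "OMPI_COMM_WORLD_SIZE",
    },
    "torch": {"rank": "RANK", "local_rank": "LOCAL_RANK", "world_size": "WORLD_SIZE"},
}


def read_launcher_env(
    launcher: str, env: Optional[Mapping[str, str]] = None
) -> Dict[str, int]:
    """Hypothetical helper: collect the integer settings a launcher exports."""
    # Fall back to the real process environment when no mapping is supplied.
    env = os.environ if env is None else env
    # A missing variable raises KeyError, which signals the wrong launcher was used.
    return {field: int(env[name]) for field, name in LAUNCHER_ENV[launcher].items()}


# A fake environment keeps the example independent of how the script is started.
torch_settings = read_launcher_env(
    "torch", {"RANK": "3", "LOCAL_RANK": "1", "WORLD_SIZE": "8"}
)
slurm_settings = read_launcher_env(
    "slurm", {"SLURM_PROCID": "0", "SLURM_NPROCS": "4"}
)
```

Passing the environment as a mapping makes the lookup easy to exercise without a cluster; in a real job the default os.environ would be used.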
- colossalai.initialize.initialize(model, optimizer, criterion=None, train_dataloader=None, test_dataloader=None, lr_scheduler=None, ophooks=None, verbose=True)
Core function to wrap the essential training components with our functionality based on the config which is loaded into gpc.config.
- Parameters
model (torch.nn.Module or Callable) – Your model instance or a function to build the model
optimizer (torch.optim.optimizer.Optimizer or Type[torch.optim.optimizer]) – Your optimizer instance
criterion (torch.nn.modules.loss._Loss, optional) – Your criterion instance
train_dataloader (torch.utils.data.DataLoader, optional) – Dataloader for training
test_dataloader (torch.utils.data.DataLoader, optional) – Dataloader for testing
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional) – Your lr scheduler instance
verbose (bool, optional) – Whether to print logs
- Returns
(engine, train_dataloader, test_dataloader, lr_scheduler)
- Return type
Tuple
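A minimal sketch of how the pieces might fit together in a script started with torchrun is given below. It uses only the signatures listed on this page; the toy model, the random data, the function name train, and the config path ./config.py are placeholders chosen for illustration, and the sketch assumes that colossalai, torch, and a suitable distributed environment are available when train() is called.

```python
def train(config_path: str = "./config.py"):
    """Hypothetical entry point; call it from a script started with torchrun."""
    # Imported inside the function so the sketch can be loaded on a machine
    # where the training stack is not installed.
    import colossalai
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Set up the distributed environment from the variables exported by torchrun.
    colossalai.launch_from_torch(config=config_path)

    # Toy components; a real script would build its own.
    model = torch.nn.Linear(16, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = torch.nn.CrossEntropyLoss()
    dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 4, (64,)))
    train_loader = DataLoader(dataset, batch_size=8)
    test_loader = DataLoader(dataset, batch_size=8)

    # The return values are unpacked in the documented order.
    engine, train_loader, test_loader, lr_scheduler = colossalai.initialize(
        model,
        optimizer,
        criterion=criterion,
        train_dataloader=train_loader,
        test_dataloader=test_loader,
    )
    return engine, train_loader, test_loader, lr_scheduler
```

Since lr_scheduler is optional and not passed here, the fourth returned value is not used further in this sketch.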