colossalai.utils.checkpointing

colossalai.utils.checkpointing.get_checkpoint_path(checkpoint_dir, epoch, suffix='')[source]

Generates the checkpoint path from the tuple (checkpoint_dir, epoch, suffix, gpu_parallel_rank). It is useful both when saving a checkpoint and when restoring one.

Parameters
  • checkpoint_dir (str) – Directory in which checkpoints are saved.

  • epoch (int) – Epoch number (how many epochs the model has been trained for).

  • suffix (str, optional) – Additional notation to specify the model or checkpoint, defaults to ''.

Returns

The generated checkpoint path.

Return type

str
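As a rough illustration of how such a path could be assembled, the sketch below joins the directory with a file name built from the epoch, a rank label, and the suffix. The helper name `sketch_checkpoint_path`, the `rank_name` argument, and the `epoch<N>-<rank><suffix>.pt` layout are assumptions made for this example; the documentation above does not specify the actual file-name format.

```python
import os


def sketch_checkpoint_path(checkpoint_dir, epoch, suffix="", rank_name="rank0"):
    """Hypothetical stand-in for get_checkpoint_path; the file-name layout is assumed."""
    # Assumed layout: epoch<N>-<rank><suffix>.pt (not confirmed by the docs above).
    filename = f"epoch{epoch}-{rank_name}{suffix}.pt"
    return os.path.join(checkpoint_dir, filename)


path = sketch_checkpoint_path("ckpt", 3, suffix="-bert")
```

Because the rank label is part of the file name in this sketch, each parallel process would write to its own file inside the same directory.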

colossalai.utils.checkpointing.get_latest_checkpoint_path(checkpoint_dir, suffix='')[source]

Retrieves the latest checkpoint path from the tuple (checkpoint_dir, suffix, gpu_parallel_rank). It is useful when restoring a checkpoint, especially when the epoch number is not known.

Parameters
  • checkpoint_dir (str) – Directory in which checkpoints are saved.

  • suffix (str, optional) – Additional notation to specify the model or checkpoint, defaults to ''.

Returns

The latest retrieved checkpoint path.

Return type

str

Raises

FileNotFoundError – Raised when no latest checkpoint file can be found for the given inputs.
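One plausible way to implement this lookup is to scan the directory, keep the file with the highest epoch number, and raise FileNotFoundError when nothing matches. The sketch below assumes a hypothetical `epoch<N>-<rank><suffix>.pt` naming scheme and a hypothetical helper name; it is not the library's implementation.

```python
import os
import re
import tempfile


def sketch_latest_checkpoint_path(checkpoint_dir, suffix="", rank_name="rank0"):
    """Hypothetical stand-in; assumes files are named epoch<N>-<rank><suffix>.pt."""
    pattern = re.compile(rf"epoch(\d+)-{re.escape(rank_name)}{re.escape(suffix)}\.pt")
    best_epoch, best_name = -1, None
    if os.path.isdir(checkpoint_dir):
        for name in os.listdir(checkpoint_dir):
            match = pattern.fullmatch(name)
            # Compare epochs numerically so that epoch10 beats epoch9.
            if match and int(match.group(1)) > best_epoch:
                best_epoch, best_name = int(match.group(1)), name
    if best_name is None:
        raise FileNotFoundError(f"no checkpoint found in {checkpoint_dir!r}")
    return os.path.join(checkpoint_dir, best_name)


# Usage: create a few empty checkpoint files and pick the newest one.
with tempfile.TemporaryDirectory() as tmp:
    for epoch in (2, 9, 10):
        open(os.path.join(tmp, f"epoch{epoch}-rank0.pt"), "wb").close()
    latest = os.path.basename(sketch_latest_checkpoint_path(tmp))
```

Comparing the captured epoch as an integer matters here: a plain string sort would rank `epoch9` above `epoch10`.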

colossalai.utils.checkpointing.get_latest_checkpoint_pattern(suffix='')[source]

Generates a regular expression that matches the latest checkpoint's pattern.

Parameters

suffix (str, optional) – Additional notation to specify the model or checkpoint, defaults to ''.

Returns

The regular expression for the checkpoint pattern.

Return type

str
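To make the idea of a checkpoint pattern concrete, the sketch below builds a regular expression string with a capture group for the epoch number. The helper name, the `rank_name` argument, and the `epoch<N>-<rank><suffix>.pt` layout are all assumptions for illustration only.

```python
import re


def sketch_latest_checkpoint_pattern(suffix="", rank_name="rank0"):
    """Hypothetical stand-in; returns a regex string for epoch<N>-<rank><suffix>.pt."""
    # re.escape keeps characters such as '.' or '-' in the suffix literal.
    return rf"epoch(\d+)-{re.escape(rank_name)}{re.escape(suffix)}\.pt"


pattern = sketch_latest_checkpoint_pattern(suffix="-bert")
match = re.fullmatch(pattern, "epoch12-rank0-bert.pt")
epoch = int(match.group(1))
```

Returning the pattern as a string, as the documented return type suggests, lets the caller decide whether to compile it once or pass it straight to `re` functions.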

colossalai.utils.checkpointing.save_checkpoint(checkpoint_path, epoch, model, optimizer, lr_scheduler=None, **kwargs)[source]

Given a directory in which to store checkpoints, saves the parameters and buffers of all training components (such as model, optimizer, and lr_scheduler) into a checkpoint dictionary.

This function works with both colossalai.nn.BaseModel and a regular torch.nn.Module.

Parameters
  • checkpoint_path (str) – Directory in which checkpoints are saved.

  • epoch (int) – Epoch number (how many epochs the model has been trained for).

  • model (torch.nn.Module) – Model to be registered.

  • optimizer (Union[torch.optim.Optimizer, colossalai.nn.optimizer]) – Optimizer to be registered.

  • lr_scheduler (Union[torch.optim.lr_scheduler, colossalai.nn.lr_scheduler], optional) – lr_scheduler to be registered, defaults to None.

  • kwargs (dict) – Additional parameters to be saved.
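The checkpoint dictionary described above can be pictured with a dependency-free sketch. Everything here is an assumption made for illustration: the `Stateful` class stands in for any object exposing `state_dict()`, the dictionary keys are invented, and `pickle` replaces whatever serializer the library actually uses (a PyTorch-based version would more likely call `torch.save`).

```python
import os
import pickle
import tempfile


class Stateful:
    """Stand-in for a model, optimizer, or scheduler that exposes state_dict()."""

    def __init__(self, **state):
        self._state = dict(state)

    def state_dict(self):
        return dict(self._state)


def sketch_save_checkpoint(checkpoint_path, epoch, model, optimizer,
                           lr_scheduler=None, **kwargs):
    # Assumed dictionary keys; the real checkpoint layout may differ.
    checkpoint = {
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    if lr_scheduler is not None:
        checkpoint["lr_scheduler"] = lr_scheduler.state_dict()
    checkpoint.update(kwargs)  # extra entries supplied by the caller
    with open(checkpoint_path, "wb") as f:
        pickle.dump(checkpoint, f)


# Usage: save a checkpoint with one extra value, then read it back.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "epoch1.pt")
    sketch_save_checkpoint(path, 1, Stateful(weight=[0.5]), Stateful(lr=0.1), best_acc=0.9)
    with open(path, "rb") as f:
        saved = pickle.load(f)
```

In this sketch the scheduler entry is simply omitted when `lr_scheduler` is None, and any keyword arguments are stored next to the standard entries.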

colossalai.utils.checkpointing.load_checkpoint(checkpoint_path, model, optimizer, lr_scheduler=None, finetune=False, strict=True)[source]

Loads the checkpoint file.

If finetune is False, training resumes from the given checkpoint, so the parameters and buffers in state_dict are copied into the modules (model, optimizer, lr_scheduler) and their descendants.

If finetune is True, only the weights and buffers of the model are reloaded. If strict is True, the keys of state_dict must exactly match the keys returned by the model's state_dict() function.

Parameters
  • checkpoint_path (str) – The exact checkpoint_path directory from which to retrieve the matching state_dict.

  • model (torch.nn.Module) – Model whose parameters and buffers are reloaded.

  • optimizer (Union[torch.optim.Optimizer, colossalai.nn.optimizer]) – Optimizer to restore.

  • lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional) – lr_scheduler to restore, defaults to None.

  • finetune (bool, optional) – Whether to fine-tune the model on a new dataset or continue the pre-training, defaults to False.

  • strict (bool, optional) – Whether to strictly enforce that the keys in the checkpoint's state_dict match the names of the parameters and buffers in model, defaults to True.

Returns

A tuple containing the epoch number of the retrieved checkpoint and the retrieved checkpoint itself.

Return type

Tuple(int, checkpoint)

Raises

ValueError – Raised if the model or optimizer cannot be restored successfully.
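The difference between resuming and fine-tuning can be sketched without any framework. The code below is a simplified illustration under stated assumptions: `Stateful` is a stand-in for objects with `state_dict()` and `load_state_dict()`, the checkpoint keys are invented, and the function takes an already-loaded dictionary rather than a path.

```python
class Stateful:
    """Stand-in for a model, optimizer, or scheduler with state_dict()/load_state_dict()."""

    def __init__(self, **state):
        self._state = dict(state)

    def state_dict(self):
        return dict(self._state)

    def load_state_dict(self, state, strict=True):
        # With strict=True the key sets must match exactly.
        if strict and set(state) != set(self._state):
            raise ValueError("state_dict keys do not match")
        self._state.update({k: v for k, v in state.items() if k in self._state})


def sketch_load_checkpoint(checkpoint, model, optimizer, lr_scheduler=None,
                           finetune=False, strict=True):
    """Hypothetical stand-in; takes a loaded checkpoint dict with assumed keys."""
    model.load_state_dict(checkpoint["model"], strict=strict)
    if not finetune:
        # Resuming: also restore the optimizer and, if present, the scheduler.
        optimizer.load_state_dict(checkpoint["optimizer"])
        if lr_scheduler is not None and "lr_scheduler" in checkpoint:
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
    return checkpoint["epoch"], checkpoint


checkpoint = {"epoch": 4, "model": {"weight": [1.0]}, "optimizer": {"lr": 0.01}}

# Resume: model and optimizer state both come from the checkpoint.
model, optimizer = Stateful(weight=[0.0]), Stateful(lr=0.1)
epoch, _ = sketch_load_checkpoint(checkpoint, model, optimizer)

# Fine-tune: only the model is reloaded; the optimizer keeps its fresh state.
ft_model, ft_optimizer = Stateful(weight=[0.0]), Stateful(lr=0.1)
sketch_load_checkpoint(checkpoint, ft_model, ft_optimizer, finetune=True)
```

Leaving the optimizer untouched when fine-tuning means a new training run starts with its own learning rate and optimizer statistics while reusing the pretrained weights.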