colossalai.trainer

class colossalai.trainer.Trainer(engine, schedule=None, timer=None, logger=None)

Trainer provides a ready-made training and evaluation loop, so that users do not have to write their own scripts. It is similar to ignite.engine and keras.engine, but is called Trainer.

Parameters
  • engine (Engine) – Engine responsible for the process function.

  • schedule (BaseSchedule, optional) – Schedule responsible for forward and backward steps.

  • timer (MultiTimer, optional) – Timer used to monitor the whole training.

  • logger (colossalai.logging.DistributedLogger, optional) – Logger used to record the whole training log.

Note

When schedule is None, NonPipelineSchedule is used. To use pipeline parallelism, pass PipelineSchedule or InterleavedPipelineSchedule as the schedule.
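
The fallback described in this note can be pictured with a minimal, self-contained sketch. The classes below are hypothetical stand-ins, not the ColossalAI implementations: only the names NonPipelineSchedule and PipelineSchedule come from the text above, while the num_microbatches parameter and the resolve_schedule helper are assumptions made for illustration.

```python
# Stand-in classes for illustration only; they are NOT the ColossalAI classes.
class NonPipelineSchedule:
    """Placeholder for the default, non-pipelined schedule."""


class PipelineSchedule:
    """Placeholder for a pipelined schedule (the parameter name is assumed)."""

    def __init__(self, num_microbatches):
        self.num_microbatches = num_microbatches


def resolve_schedule(schedule=None):
    """Hypothetical helper: fall back to NonPipelineSchedule when none is given."""
    if schedule is None:
        return NonPipelineSchedule()
    return schedule


default_schedule = resolve_schedule()
pipeline_schedule = resolve_schedule(PipelineSchedule(num_microbatches=4))
print(type(default_schedule).__name__)   # NonPipelineSchedule
print(type(pipeline_schedule).__name__)  # PipelineSchedule
```

Using None as the default keeps the common, non-pipelined case free of extra setup, while an explicit schedule object opts in to pipelining.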

Examples

>>> import colossalai
>>> from colossalai.core import global_context as gpc
>>> from colossalai.trainer import Trainer
>>> # Define the model, criterion, optimizer and train_dataloader for your training
>>> model = ...
>>> criterion = ...
>>> optimizer = ...
>>> train_dataloader = ...
>>> # Initialize returns engine, train_dataloader, test_dataloader, lr_scheduler
>>> engine, train_dataloader, _, _ = colossalai.initialize(
...     model, optimizer, criterion, train_dataloader=train_dataloader
... )
>>> # Begin the training process
>>> schedule = ...
>>> timer = ...
>>> logger = ...
>>> trainer = Trainer(engine=engine, logger=logger, schedule=schedule, timer=timer)
>>> # Add the hooks you would like to use here.
>>> hook_list = []
>>> trainer.fit(
...     train_dataloader=train_dataloader,
...     epochs=gpc.config.NUM_EPOCHS,
...     test_interval=1,
...     hooks=hook_list,
...     display_progress=True,
...     return_output_label=False
... )

More examples and details can be found in "Training with engine and trainer" and in ColossalAI-Examples.

property cur_epoch

Returns the index of the current epoch.

property cur_step

Returns how many iteration steps have been processed.
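
As a rough illustration of how these two read-only properties might evolve during training, here is a minimal sketch. ProgressCounter and its run method are hypothetical stand-ins, not part of ColossalAI, and the sketch assumes that the epoch index is zero-based and that the step counter keeps accumulating across epochs.

```python
class ProgressCounter:
    """Hypothetical stand-in exposing read-only progress counters."""

    def __init__(self):
        self._cur_epoch = 0
        self._cur_step = 0

    @property
    def cur_epoch(self):
        # Index of the current epoch (assumed to be zero-based).
        return self._cur_epoch

    @property
    def cur_step(self):
        # Number of iteration steps processed so far.
        return self._cur_step

    def run(self, epochs, steps_per_epoch):
        for epoch in range(epochs):
            self._cur_epoch = epoch
            for _ in range(steps_per_epoch):
                self._cur_step += 1  # assumed: not reset between epochs


counter = ProgressCounter()
counter.run(epochs=3, steps_per_epoch=5)
print(counter.cur_epoch, counter.cur_step)  # 2 15
```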

fit(train_dataloader, epochs, max_steps=None, test_dataloader=None, test_interval=1, hooks=None, display_progress=False, return_output_label=True)

Trains the model to fit training data.

Parameters
  • train_dataloader (torch.utils.data.DataLoader) – DataLoader for training.

  • epochs (int) – Maximum number of epochs.

  • max_steps (int, optional) – Maximum number of running iterations.

  • test_dataloader (torch.utils.data.DataLoader, optional) – DataLoader for validation.

  • test_interval (int, optional) – Interval between validation runs. Defaults to 1.

  • hooks (list[BaseHook], optional) – A list of hooks used in training.

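  • display_progress (bool, optional) – If True, a progress bar will be displayed. Defaults to False.

  • return_output_label (bool, optional) – If True, the output of model and the label will be returned. Defaults to True.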
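
How epochs, max_steps and test_interval could interact is easiest to see in a simplified loop. The sketch below is not the ColossalAI implementation: plan_fit, steps_per_epoch and has_test_data are hypothetical names, and the rules it encodes (stop as soon as max_steps iterations have run, validate whenever the zero-based epoch index is a multiple of test_interval) are assumptions.

```python
def plan_fit(epochs, steps_per_epoch, max_steps=None, test_interval=1,
             has_test_data=True):
    """Return (steps_run, validated_epochs) for a simplified training loop."""
    steps_run = 0
    validated_epochs = []
    for epoch in range(epochs):
        for _ in range(steps_per_epoch):
            if max_steps is not None and steps_run >= max_steps:
                break  # step budget exhausted mid-epoch
            steps_run += 1
        # Assumed rule: validate when the epoch index is a multiple of test_interval.
        if has_test_data and epoch % test_interval == 0:
            validated_epochs.append(epoch)
        if max_steps is not None and steps_run >= max_steps:
            break  # do not start another epoch
    return steps_run, validated_epochs


print(plan_fit(epochs=4, steps_per_epoch=10, test_interval=2))  # (40, [0, 2])
print(plan_fit(epochs=4, steps_per_epoch=10, max_steps=25))     # (25, [0, 1, 2])
```

In this sketch epochs is an upper bound, and max_steps, when given, can end training earlier.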

evaluate(test_dataloader, hooks=None, display_progress=False, return_output_label=True)

Evaluates the model with testing data.

Parameters
  • test_dataloader (torch.utils.data.DataLoader) – DataLoader for testing.

  • hooks (list, optional) – A list of hooks used in evaluation. Defaults to None.

  • display_progress (bool, optional) – If True, the evaluation progress will be printed. Defaults to False.

  • return_output_label (bool, optional) – If True, the output of model and the label will be returned. Defaults to True.
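
The interplay between hooks and return_output_label can be sketched with plain Python objects. Everything below is a hypothetical stand-in: CollectingHook, its after_test_iter method and evaluate_sketch are illustrative names, and the sketch assumes that hooks receive the model output and label for each batch, or None for both when return_output_label is False. The documentation above does not specify how the real method hands these values over.

```python
class CollectingHook:
    """Hypothetical hook that records what it receives after each test batch."""

    def __init__(self):
        self.seen = []

    def after_test_iter(self, output, label):
        self.seen.append((output, label))


def evaluate_sketch(model_fn, test_dataloader, hooks=None,
                    return_output_label=True):
    """Run model_fn over (data, label) batches and notify hooks per batch."""
    hooks = hooks if hooks is not None else []
    num_batches = 0
    for data, label in test_dataloader:
        output = model_fn(data)
        if not return_output_label:
            # Assumed behavior: outputs and labels are withheld from hooks.
            output, label = None, None
        for hook in hooks:
            hook.after_test_iter(output, label)
        num_batches += 1
    return num_batches


batches = [(1, 2), (3, 6)]  # (data, label) pairs standing in for a DataLoader

with_outputs = CollectingHook()
evaluate_sketch(lambda x: 2 * x, batches, hooks=[with_outputs])
print(with_outputs.seen)     # [(2, 2), (6, 6)]

without_outputs = CollectingHook()
evaluate_sketch(lambda x: 2 * x, batches, hooks=[without_outputs],
                return_output_label=False)
print(without_outputs.seen)  # [(None, None), (None, None)]
```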

predict(data)

Uses the trained model to make a prediction for a tensor or a list of tensors.

Parameters

data (Union[torch.Tensor, List[torch.Tensor]]) – Data as the input.

Returns

The output of the model as the prediction.

Return type

torch.Tensor