colossalai.trainer
- class colossalai.trainer.Trainer(engine, schedule=None, timer=None, logger=None)
This class is intended for easy deployment of users’ training and evaluation, so that users do not have to write their own scripts. It is similar to ignite.engine and keras.engine, but is called Trainer.
- Parameters
engine (Engine) – Engine responsible for the process function.
schedule (BaseSchedule, optional) – Schedule responsible for forward and backward steps.
timer (MultiTimer, optional) – Timer used to monitor the whole training.
logger (colossalai.logging.DistributedLogger, optional) – Logger used to record the whole training log.
Note
When schedule is None, NonPipelineSchedule is used. To use pipeline parallelism, pass PipelineSchedule or InterleavedPipelineSchedule as the schedule.
Examples
>>> import colossalai
>>> from colossalai.core import global_context as gpc
>>> from colossalai.trainer import Trainer
>>> # define model, criterion, optimizer and train_dataloader for your training
>>> model = ...
>>> criterion = ...
>>> optimizer = ...
>>> train_dataloader = ...
>>> # initialize the engine; initialize also returns train_dataloader, test_dataloader and lr_scheduler
>>> engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, train_dataloader=train_dataloader)
>>> # begin training
>>> timer = ...
>>> logger = ...
>>> schedule = None  # None falls back to NonPipelineSchedule
>>> trainer = Trainer(engine=engine, schedule=schedule, timer=timer, logger=logger)
>>> # add the hooks you would like to use here
>>> hook_list = []
>>> trainer.fit(
...     train_dataloader=train_dataloader,
...     epochs=gpc.config.NUM_EPOCHS,
...     test_interval=1,
...     hooks=hook_list,
...     display_progress=True,
...     return_output_label=False
... )
More examples and details can be found in the tutorial Training with engine and trainer and in ColossalAI-Examples.
- property cur_epoch
Returns the index of the current epoch.
- property cur_step
Returns how many iteration steps have been processed.
- fit(train_dataloader, epochs, max_steps=None, test_dataloader=None, test_interval=1, hooks=None, display_progress=False, return_output_label=True)
Trains the model to fit training data.
- Parameters
train_dataloader (torch.utils.data.DataLoader) – DataLoader for training.
epochs (int) – Maximum number of epochs.
max_steps (int, optional) – Maximum number of training iterations. Defaults to None.
test_dataloader (torch.utils.data.DataLoader, optional) – DataLoader for validation. Defaults to None.
test_interval (int, optional) – Interval, in epochs, between validation runs. Defaults to 1.
hooks (list[BaseHook], optional) – A list of hooks used in training. Defaults to None.
display_progress (bool, optional) – If True, a progress bar will be displayed. Defaults to False.
return_output_label (bool, optional) – If True, the output of the model and the label will be returned. Defaults to True.
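The interaction between epochs, max_steps and test_interval, and how they drive the cur_epoch and cur_step counters, can be illustrated with a toy loop. This is a hypothetical pure-Python sketch, not ColossalAI code: ToyTrainer and eval_epochs are invented for illustration, and two rules are assumptions about fit() that should be checked against the Trainer source: validation runs after zero-based epochs where epoch % test_interval == 0, and training stops as soon as cur_step reaches max_steps.

```python
from typing import Iterable, List, Optional


class ToyTrainer:
    """Hypothetical stand-in that only mimics the counters of Trainer.fit().

    Not colossalai.trainer.Trainer: there is no model, engine or schedule here.
    """

    def __init__(self) -> None:
        self.cur_epoch = 0                # analogous to Trainer.cur_epoch
        self.cur_step = 0                 # analogous to Trainer.cur_step
        self.eval_epochs: List[int] = []  # epochs after which validation ran

    def _exceed_max_steps(self, max_steps: Optional[int]) -> bool:
        return max_steps is not None and self.cur_step >= max_steps

    def fit(self, train_dataloader: Iterable, epochs: int,
            max_steps: Optional[int] = None,
            test_dataloader: Optional[Iterable] = None,
            test_interval: int = 1) -> None:
        for epoch in range(epochs):
            for _batch in train_dataloader:
                self.cur_step += 1  # stands in for one forward/backward/optimizer step
                if self._exceed_max_steps(max_steps):
                    break  # assumption: stop mid-epoch once max_steps is reached
            # Assumption: validate on zero-based epochs divisible by test_interval.
            if test_dataloader is not None and epoch % test_interval == 0:
                self.eval_epochs.append(epoch)
            self.cur_epoch += 1
            if self._exceed_max_steps(max_steps):
                break  # do not start another epoch


# 10 batches per epoch, at most 5 epochs, but capped at 25 iterations.
trainer = ToyTrainer()
trainer.fit(range(10), epochs=5, max_steps=25, test_dataloader=[0], test_interval=2)
print(trainer.cur_epoch, trainer.cur_step, trainer.eval_epochs)  # 3 25 [0, 2]
```

With these numbers the step cap ends training during the third epoch, so the toy reports 25 steps instead of the 50 that five full epochs would take.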
- evaluate(test_dataloader, hooks=None, display_progress=False, return_output_label=True)
Evaluates the model with testing data.
- Parameters
test_dataloader (torch.utils.data.DataLoader) – DataLoader for testing.
hooks (list, optional) – A list of hooks used in evaluation. Defaults to None.
display_progress (bool, optional) – If True, the evaluation progress will be printed. Defaults to False.
return_output_label (bool, optional) – If True, the output of the model and the label will be returned. Defaults to True.
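return_output_label matters mainly when something downstream needs the raw outputs and labels, for example an accuracy metric collected by a hook during evaluation. The sketch below is hypothetical pure Python: toy_evaluate and ToyAccuracyHook are invented names, the hook's update(output, label) method is a simplification rather than ColossalAI's hook interface, and passing None for output and label when return_output_label is False is an assumption about the behavior.

```python
from typing import Iterable, List, Optional, Sequence, Tuple

# One batch = (predicted class ids, true class ids); stands in for model output and label.
Batch = Tuple[Sequence[int], Sequence[int]]


class ToyAccuracyHook:
    """Hypothetical hook that accumulates accuracy over evaluation batches."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, output: Optional[Sequence[int]],
               label: Optional[Sequence[int]]) -> None:
        if output is None or label is None:
            raise RuntimeError("accuracy needs outputs and labels; "
                               "use return_output_label=True")
        self.correct += sum(int(o == t) for o, t in zip(output, label))
        self.total += len(label)

    @property
    def accuracy(self) -> float:
        return self.correct / self.total if self.total else 0.0


def toy_evaluate(test_dataloader: Iterable[Batch],
                 hooks: Optional[List[ToyAccuracyHook]] = None,
                 return_output_label: bool = True) -> None:
    """Hypothetical evaluate(): forwards each batch's output and label to the hooks."""
    for preds, labels in test_dataloader:
        # Assumption: without return_output_label, hooks only receive None.
        output, label = (preds, labels) if return_output_label else (None, None)
        for hook in hooks or []:
            hook.update(output, label)


batches = [([1, 0, 1], [1, 1, 1]), ([0, 0], [0, 1])]
hook = ToyAccuracyHook()
toy_evaluate(batches, hooks=[hook], return_output_label=True)
print(hook.accuracy)  # 0.6
```

In this sketch, return_output_label=False only works when no hook consumes outputs or labels; an accuracy-style hook raises instead of silently reporting a wrong value.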