colossalai.trainer
- class colossalai.trainer.Trainer(engine, schedule=None, timer=None, logger=None)
This class is intended to make it easy for users to deploy training and evaluation without writing their own scripts. It is similar to ignite.engine and keras.engine, but is called Trainer.
- Parameters
engine (Engine) – Engine responsible for the process function
schedule (BaseSchedule, optional) – Schedule responsible for forward and backward steps
timer (MultiTimer, optional) – Timer used to monitor the whole training
logger (colossalai.logging.DistributedLogger, optional) – Logger used to record the whole training
- property cur_epoch
Returns the index of the current epoch.
- property cur_step
Returns how many iteration steps have been processed.
- fit(train_dataloader, epochs, max_steps=None, test_dataloader=None, test_interval=1, hooks=None, display_progress=False, return_output_label=True)
Trains the model to fit training data.
- Parameters
train_dataloader (DataLoader) – DataLoader in training
epochs (int) – Maximum number of epochs
max_steps (int, optional) – Maximum number of running iterations
test_dataloader (DataLoader, optional) – DataLoader in testing
test_interval (int, optional) – Interval of testing
hooks (list, optional) – A list of hooks used in training
display_progress (bool, optional) – If True, the training progress will be printed
return_output_label (bool, optional) – If True, the output of model and the label will be returned
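How epochs, max_steps, and test_interval might interact can be sketched with a toy loop. This is not ColossalAI's implementation: the function simulate_fit and its parameters steps_per_epoch and has_test_dataloader are hypothetical, and the sketch assumes that training stops at whichever of epochs or max_steps is reached first and that evaluation runs after every test_interval-th epoch.

```python
from typing import List, Optional, Tuple


def simulate_fit(
    steps_per_epoch: int,
    epochs: int,
    max_steps: Optional[int] = None,
    test_interval: int = 1,
    has_test_dataloader: bool = True,
) -> Tuple[int, int, List[int]]:
    """Toy model of a fit-style loop (hypothetical, not the library's code).

    Returns (epochs_started, cur_step, evaluated_after), where
    evaluated_after lists the 0-based epochs after which evaluation ran.
    """
    cur_step = 0
    epochs_started = 0
    evaluated_after: List[int] = []

    for epoch in range(epochs):
        # Assumption: no new epoch starts once the step budget is used up.
        if max_steps is not None and cur_step >= max_steps:
            break
        epochs_started += 1

        for _ in range(steps_per_epoch):
            cur_step += 1  # one training iteration
            if max_steps is not None and cur_step >= max_steps:
                break

        # Assumption: evaluate after every `test_interval`-th epoch,
        # and only when a test dataloader was supplied.
        if has_test_dataloader and (epoch + 1) % test_interval == 0:
            evaluated_after.append(epoch)

    return epochs_started, cur_step, evaluated_after
```

Under these assumptions, a run with 10 steps per epoch, epochs=5, max_steps=25, and test_interval=2 stops partway through the third epoch and evaluates only once, after the second epoch. The real Trainer may order these checks differently, so the source should be consulted for exact behaviour.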
- evaluate(test_dataloader, hooks=None, display_progress=False, return_output_label=True)
Evaluates the model with testing data.
- Parameters
test_dataloader (DataLoader) – DataLoader in testing
hooks (list, optional) – A list of hooks used in evaluation
display_progress (bool, optional) – If True, the evaluation progress will be printed
return_output_label (bool, optional) – If True, the output of model and the label will be returned
- predict(data)
Uses the trained model to make a prediction for a tensor or a list of tensors.
- Parameters
data (Union[Tensor, List[Tensor]]) – Data as the input
- Returns
The output of model as the prediction
- Return type
Tensor