connectomics.engine

class connectomics.engine.Trainer(cfg, device, mode='train', rank=None, checkpoint=None)

Trainer class for supervised learning.

Parameters
  • cfg (yacs.config.CfgNode) – YACS configuration options.

  • device (torch.device) – the device on which the model runs. GPUs are recommended for both training and inference.

  • mode (str) – running mode of the trainer ('train' or 'test'). Default: 'train'

  • rank (int, optional) – node rank for distributed training. Default: None

  • checkpoint (str, optional) – the checkpoint file to be loaded. Default: None
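A widespread convention in distributed PyTorch code is that `rank=None` means a single-process run and that rank 0 is the process responsible for logging and writing checkpoints. That convention is an assumption here, not something this page confirms for the Trainer. The hypothetical helpers `validate_mode` and `is_main_process` below sketch how the `mode` and `rank` arguments could be checked under it.

```python
from typing import Optional

# Running modes documented for the Trainer.
VALID_MODES = ("train", "test")


def validate_mode(mode: str) -> str:
    """Hypothetical helper: reject modes other than 'train' or 'test'."""
    if mode not in VALID_MODES:
        raise ValueError(f"mode must be one of {VALID_MODES}, got {mode!r}")
    return mode


def is_main_process(rank: Optional[int] = None) -> bool:
    """Hypothetical helper, assuming rank=None means a non-distributed run
    and rank 0 is the main process of a distributed run."""
    if rank is not None and rank < 0:
        raise ValueError(f"rank must be non-negative, got {rank}")
    return rank is None or rank == 0


validate_mode("train")
print(is_main_process())         # True: single-process run
print(is_main_process(rank=0))   # True: main process
print(is_main_process(rank=3))   # False: worker process
```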

run_chunk(mode)

Run chunk-based training and inference for large-scale datasets.

Parameters

mode (str) – running mode, either 'train' or 'test'.
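Chunk-based processing splits a volume that is too large to handle at once into smaller blocks that are processed one after another. The sketch below is not the library's implementation; it only illustrates the idea by computing the start/end coordinates of overlapping chunks along each axis of a 3D volume. The functions `axis_starts` and `chunk_boxes` and their `chunk_size` and `overlap` parameters are hypothetical.

```python
from itertools import product
from typing import List, Sequence, Tuple

Box = Tuple[Tuple[int, ...], Tuple[int, ...]]


def axis_starts(length: int, chunk: int, overlap: int) -> List[int]:
    """Start offsets along one axis so that chunks of size `chunk`,
    overlapping by `overlap` voxels, cover the range [0, length)."""
    if chunk <= overlap:
        raise ValueError("chunk size must be larger than the overlap")
    if length <= chunk:
        return [0]  # a single chunk covers the whole axis
    stride = chunk - overlap
    starts = list(range(0, length - chunk, stride))
    starts.append(length - chunk)  # align the last chunk with the volume end
    return starts


def chunk_boxes(shape: Sequence[int], chunk_size: Sequence[int],
                overlap: Sequence[int]) -> List[Box]:
    """Return (start, end) coordinates for every chunk of a volume."""
    per_axis = [axis_starts(n, c, o) for n, c, o in zip(shape, chunk_size, overlap)]
    boxes = []
    for starts in product(*per_axis):
        ends = tuple(min(s + c, n) for s, c, n in zip(starts, chunk_size, shape))
        boxes.append((tuple(starts), ends))
    return boxes


# A (z, y, x) volume split into 40x40 tiles in y/x with a 10-voxel overlap.
boxes = chunk_boxes(shape=(32, 100, 100), chunk_size=(32, 40, 40), overlap=(0, 10, 10))
print(len(boxes))  # 9
print(boxes[0])    # ((0, 0, 0), (32, 40, 40))
```

Overlapping chunks are a common choice in volumetric inference because predictions near chunk borders can then be cropped or blended instead of being stitched at a hard seam.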

save_checkpoint(iteration, is_best=False)

Save the model checkpoint.

Parameters
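  • iteration (int) – the current training iteration.

  • is_best (bool) – whether to also mark this checkpoint as the best one so far. Default: False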
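A common pattern behind an `iteration`/`is_best` signature is to write one file per iteration and keep a separate copy of the best checkpoint. The sketch below illustrates that pattern with the standard library only: `pickle` stands in for `torch.save`, and the file names and the `output_dir` argument are hypothetical, not necessarily what the Trainer writes.

```python
import os
import pickle
import shutil
import tempfile


def save_checkpoint(state: dict, iteration: int, output_dir: str,
                    is_best: bool = False) -> str:
    """Hypothetical sketch: save `state` for `iteration`; if `is_best`,
    also copy it to a fixed 'best' file name."""
    os.makedirs(output_dir, exist_ok=True)
    # Zero-padded iteration numbers keep the files sorted by name.
    path = os.path.join(output_dir, f"checkpoint_{iteration:06d}.pkl")
    with open(path, "wb") as f:
        pickle.dump(state, f)  # a PyTorch trainer would typically call torch.save here
    if is_best:
        shutil.copyfile(path, os.path.join(output_dir, "checkpoint_best.pkl"))
    return path


with tempfile.TemporaryDirectory() as tmp:
    p = save_checkpoint({"iteration": 500, "weights": [0.1, 0.2]}, 500, tmp, is_best=True)
    print(os.path.basename(p))      # checkpoint_000500.pkl
    print(sorted(os.listdir(tmp)))  # ['checkpoint_000500.pkl', 'checkpoint_best.pkl']
```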

test()

Inference function of the trainer class.

train()

Training function of the trainer class.
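An iteration-based training function typically interleaves optimization steps with periodic validation and checkpointing. The sketch below shows that structure with hypothetical callbacks (`step_fn`, `validate_fn`, `save_fn`) and hypothetical interval parameters (`val_every`, `save_every`); it is a generic loop, not the Trainer's actual schedule.

```python
from typing import Callable, List


def train_loop(total_iters: int, step_fn: Callable[[int], float],
               validate_fn: Callable[[int], None],
               save_fn: Callable[[int], None],
               val_every: int = 1000, save_every: int = 5000) -> List[float]:
    """Hypothetical iteration-based training loop.

    step_fn(i) performs one optimization step and returns the training loss;
    validate_fn and save_fn receive the number of completed iterations.
    """
    losses = []
    for i in range(total_iters):
        losses.append(step_fn(i))
        done = i + 1  # iterations completed so far
        if done % val_every == 0:
            validate_fn(done)
        if done % save_every == 0:
            save_fn(done)
    return losses


events = []
train_loop(10, step_fn=lambda i: 1.0 / (i + 1),
           validate_fn=lambda n: events.append(("val", n)),
           save_fn=lambda n: events.append(("save", n)),
           val_every=3, save_every=5)
print(events)  # [('val', 3), ('save', 5), ('val', 6), ('val', 9), ('save', 10)]
```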

update_checkpoint(checkpoint=None)

Update the model with the weights stored in the checkpoint file at the specified path.

Parameters

checkpoint (Optional[str]) – path to the checkpoint file. Default: None
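When restoring weights, a frequent complication is that a model wrapped in `torch.nn.DataParallel` or `DistributedDataParallel` stores its parameters under keys prefixed with `module.`. The hypothetical helper `match_state_dict` below shows one way to normalize such keys and keep only the entries the target model expects. It is a general technique, not a description of what `update_checkpoint` does, and it uses plain dicts in place of tensor state dicts.

```python
from typing import Dict, Iterable, List, Tuple

PREFIX = "module."  # added by DataParallel / DistributedDataParallel wrappers


def match_state_dict(pretrained: Dict[str, object],
                     model_keys: Iterable[str]) -> Tuple[Dict[str, object], List[str]]:
    """Strip a leading 'module.' from checkpoint keys and keep only keys
    present in the target model. Returns (filtered, skipped_keys)."""
    expected = set(model_keys)
    filtered, skipped = {}, []
    for key, value in pretrained.items():
        name = key[len(PREFIX):] if key.startswith(PREFIX) else key
        if name in expected:
            filtered[name] = value
        else:
            skipped.append(key)  # report keys the model does not have
    return filtered, skipped


ckpt = {"module.conv1.weight": 1, "module.conv1.bias": 2, "module.head.weight": 3}
filtered, skipped = match_state_dict(ckpt, ["conv1.weight", "conv1.bias"])
print(filtered)  # {'conv1.weight': 1, 'conv1.bias': 2}
print(skipped)   # ['module.head.weight']
```

With real PyTorch models, `model_keys` would come from `model.state_dict().keys()`, and the filtered dict could be passed to `model.load_state_dict(filtered, strict=False)` so that missing entries keep their current values.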

validate(iter_total)

Validation function of the trainer class.
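The `is_best` flag of `save_checkpoint` implies a record of the best validation result seen so far. Assuming that a lower validation loss is better (the metric the Trainer actually uses is not documented here), the hypothetical `BestTracker` below sketches how a validation step could produce that flag for a given `iter_total`.

```python
import math


class BestTracker:
    """Hypothetical helper that remembers the lowest validation loss seen."""

    def __init__(self) -> None:
        self.best_loss = math.inf
        self.best_iter = None

    def update(self, val_loss: float, iter_total: int) -> bool:
        """Record `val_loss` at `iter_total`; return True if it is a new best."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_iter = iter_total
            return True
        return False


tracker = BestTracker()
for it, loss in [(1000, 0.52), (2000, 0.47), (3000, 0.49)]:
    print(it, tracker.update(loss, it))
# 1000 True
# 2000 True
# 3000 False
print(tracker.best_iter)  # 2000
```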