from __future__ import print_function, division
from typing import Optional
import warnings
import os
import time
import math
import GPUtil
import numpy as np
from yacs.config import CfgNode
import torch
from torch.cuda.amp import autocast, GradScaler
from .base import TrainerBase
from .solver import *
from ..model import *
from ..utils.monitor import build_monitor
from ..data.augmentation import build_train_augmentor, TestAugmentor
from ..data.dataset import build_dataloader, get_dataset
from ..data.dataset.build import _get_file_list
from ..data.utils import build_blending_matrix, writeh5
from ..data.utils import get_padsize, array_unpad
class Trainer(TrainerBase):
    r"""Trainer class for supervised learning.

    Args:
        cfg (yacs.config.CfgNode): YACS configuration options.
        device (torch.device): model running device. GPUs are recommended for model training and inference.
        mode (str): running mode of the trainer (``'train'`` or ``'test'``). Default: ``'train'``
        rank (int, optional): node rank for distributed training. Default: `None`
        checkpoint (str, optional): the checkpoint file to be loaded. Default: `None`
    """

    def __init__(self,
                 cfg: CfgNode,
                 device: torch.device,
                 mode: str = 'train',
                 rank: Optional[int] = None,
                 checkpoint: Optional[str] = None):
        self.init_basics(cfg, device, mode, rank)

        self.model = build_model(self.cfg, self.device, rank)
        if self.mode == 'train':
            self.optimizer = build_optimizer(self.cfg, self.model)
            self.lr_scheduler = build_lr_scheduler(self.cfg, self.optimizer)
            # GradScaler is only needed for automatic mixed-precision training.
            self.scaler = GradScaler() if cfg.MODEL.MIXED_PRECESION else None
            self.start_iter = self.cfg.MODEL.PRE_MODEL_ITER
            self.update_checkpoint(checkpoint)

            # stochastic weight averaging
            if self.cfg.SOLVER.SWA.ENABLED:
                self.swa_model, self.swa_scheduler = build_swa_model(
                    self.cfg, self.model, self.optimizer)

            self.augmentor = build_train_augmentor(self.cfg)
            self.criterion = Criterion.build_from_cfg(self.cfg, self.device)
            # the monitor (logging/visualization) only exists on the main process
            if self.is_main_process:
                self.monitor = build_monitor(self.cfg)
                self.monitor.load_info(self.cfg, self.model)

            self.total_iter_nums = self.cfg.SOLVER.ITERATION_TOTAL - self.start_iter
            self.total_time = 0
        else:
            self.update_checkpoint(checkpoint)
            # build test-time augmentor and update output filename
            self.augmentor = TestAugmentor.build_from_cfg(cfg, activation=True)
            if not self.cfg.DATASET.DO_CHUNK_TITLE and not self.inference_singly:
                self.test_filename = self.cfg.INFERENCE.OUTPUT_NAME
                self.test_filename = self.augmentor.update_name(self.test_filename)

        # chunk-based and singly inference build their dataloaders lazily
        self.dataset, self.dataloader = None, None
        if not self.cfg.DATASET.DO_CHUNK_TITLE and not self.inference_singly:
            self.dataloader = build_dataloader(
                self.cfg, self.augmentor, self.mode, rank=rank)
            self.dataloader = iter(self.dataloader)
            if self.mode == 'train' and cfg.DATASET.VAL_IMAGE_NAME is not None:
                self.val_loader = build_dataloader(
                    self.cfg, None, mode='val', rank=rank)
def init_basics(self, *args):
# This function is used for classes that inherit Trainer but only
# need to initialize basic attributes in TrainerBase.
super().__init__(*args)
[docs] def train(self):
r"""Training function of the trainer class.
"""
self.model.train()
for i in range(self.total_iter_nums):
iter_total = self.start_iter + i
self.start_time = time.perf_counter()
self.optimizer.zero_grad()
# load data
sample = next(self.dataloader)
volume = sample.out_input
target, weight = sample.out_target_l, sample.out_weight_l
self.data_time = time.perf_counter() - self.start_time
# prediction
volume = volume.to(self.device, non_blocking=True)
with autocast(enabled=self.cfg.MODEL.MIXED_PRECESION):
pred = self.model(volume)
loss, losses_vis = self.criterion(pred, target, weight)
self._train_misc(loss, pred, volume, target, weight,
iter_total, losses_vis)
self.maybe_save_swa_model()
def _train_misc(self, loss, pred, volume, target, weight,
iter_total, losses_vis):
self.backward_pass(loss) # backward pass
# logging and update record
if hasattr(self, 'monitor'):
do_vis = self.monitor.update(iter_total, loss, losses_vis,
self.optimizer.param_groups[0]['lr'])
if do_vis:
self.monitor.visualize(
volume, target, pred, weight, iter_total)
if torch.cuda.is_available():
GPUtil.showUtilization(all=True)
# Save model
if (iter_total+1) % self.cfg.SOLVER.ITERATION_SAVE == 0:
self.save_checkpoint(iter_total)
if (iter_total+1) % self.cfg.SOLVER.ITERATION_VAL == 0:
self.validate(iter_total)
# update learning rate
self.maybe_update_swa_model(iter_total)
self.scheduler_step(iter_total, loss)
if self.is_main_process:
self.iter_time = time.perf_counter() - self.start_time
self.total_time += self.iter_time
avg_iter_time = self.total_time / (iter_total+1-self.start_iter)
est_time_left = avg_iter_time * \
(self.total_iter_nums+self.start_iter-iter_total-1) / 3600.0
info = [
'[Iteration %05d]' % iter_total, 'Data time: %.4fs,' % self.data_time,
'Iter time: %.4fs,' % self.iter_time, 'Avg iter time: %.4fs,' % avg_iter_time,
'Time Left %.2fh.' % est_time_left]
print(' '.join(info))
# Release some GPU memory and ensure same GPU usage in the consecutive iterations according to
# https://discuss.pytorch.org/t/gpu-memory-consumption-increases-while-training/2770
del volume, target, pred, weight, loss, losses_vis
[docs] def validate(self, iter_total):
r"""Validation function of the trainer class.
"""
if not hasattr(self, 'val_loader'):
return
self.model.eval()
with torch.no_grad():
val_loss = 0.0
for i, sample in enumerate(self.val_loader):
volume = sample.out_input
target, weight = sample.out_target_l, sample.out_weight_l
# prediction
volume = volume.to(self.device, non_blocking=True)
with autocast(enabled=self.cfg.MODEL.MIXED_PRECESION):
pred = self.model(volume)
loss, _ = self.criterion(pred, target, weight)
val_loss += loss.data
if hasattr(self, 'monitor'):
self.monitor.logger.log_tb.add_scalar(
'Validation_Loss', val_loss, iter_total)
self.monitor.visualize(volume, target, pred,
weight, iter_total, suffix='Val')
if not hasattr(self, 'best_val_loss'):
self.best_val_loss = val_loss
if val_loss < self.best_val_loss:
self.best_val_loss = val_loss
self.save_checkpoint(iter_total, is_best=True)
# Release some GPU memory and ensure same GPU usage in the consecutive iterations according to
# https://discuss.pytorch.org/t/gpu-memory-consumption-increases-while-training/2770
del pred, loss, val_loss
# model.train() only called at the beginning of Trainer.train().
self.model.train()
[docs] def test(self):
r"""Inference function of the trainer class.
"""
# with batchnorm, train mode use the current batch statistics
self.model.eval() if self.cfg.INFERENCE.DO_EVAL else self.model.train()
output_scale = self.cfg.INFERENCE.OUTPUT_SCALE
spatial_size = list(np.ceil(
np.array(self.cfg.MODEL.OUTPUT_SIZE) *
np.array(output_scale)).astype(int))
channel_size = self.cfg.MODEL.OUT_PLANES
sz = tuple([channel_size] + spatial_size)
ww = build_blending_matrix(spatial_size, self.cfg.INFERENCE.BLENDING)
output_size = [tuple(np.ceil(np.array(x) * np.array(output_scale)).astype(int))
for x in self.dataloader._dataset.volume_size]
result = [np.stack([np.zeros(x, dtype=np.float32)
for _ in range(channel_size)]) for x in output_size]
weight = [np.zeros(x, dtype=np.float32) for x in output_size]
print("Total number of batches: ", len(self.dataloader))
start = time.perf_counter()
with torch.no_grad():
for i, sample in enumerate(self.dataloader):
print('progress: %d/%d batches, total time %.2fs' %
(i+1, len(self.dataloader), time.perf_counter()-start))
pos, volume = sample.pos, sample.out_input
volume = volume.to(self.device, non_blocking=True)
output = self.augmentor(self.model, volume)
if torch.cuda.is_available() and i % 50 == 0:
GPUtil.showUtilization(all=True)
for idx in range(output.shape[0]):
st = pos[idx]
st = (np.array(st) *
np.array([1]+output_scale)).astype(int).tolist()
out_block = output[idx]
if result[st[0]].ndim - out_block.ndim == 1: # 2d model
out_block = out_block[:, np.newaxis, :]
result[st[0]][:, st[1]:st[1]+sz[1], st[2]:st[2]+sz[2],
st[3]:st[3]+sz[3]] += out_block * ww[np.newaxis, :]
weight[st[0]][st[1]:st[1]+sz[1], st[2]:st[2]+sz[2],
st[3]:st[3]+sz[3]] += ww
end = time.perf_counter()
print("Prediction time: %.2fs" % (end-start))
for vol_id in range(len(result)):
if result[vol_id].ndim > weight[vol_id].ndim:
weight[vol_id] = np.expand_dims(weight[vol_id], axis=0)
result[vol_id] /= weight[vol_id] # in-place to save memory
result[vol_id] *= 255
result[vol_id] = result[vol_id].astype(np.uint8)
if self.cfg.INFERENCE.UNPAD:
pad_size = (np.array(self.cfg.DATASET.PAD_SIZE) *
np.array(output_scale)).astype(int).tolist()
if self.cfg.DATASET.DO_CHUNK_TITLE != 0:
# In chunk-based inference using TileDataset, padding is applied
# before resizing, while in normal inference using VolumeDataset,
# padding is after resizing. Thus we adjust pad_size accordingly.
pad_size = (np.array(self.cfg.DATASET.DATA_SCALE) *
np.array(pad_size)).astype(int).tolist()
pad_size = get_padsize(pad_size)
result[vol_id] = array_unpad(result[vol_id], pad_size)
if self.output_dir is None:
return result
else:
print('Final prediction shapes are:')
for k in range(len(result)):
print(result[k].shape)
save_path = os.path.join(self.output_dir, self.test_filename)
writeh5(save_path, result, ['vol%d' % (x) for x in range(len(result))])
print('Prediction saved as: ', save_path)
def test_singly(self):
dir_name = _get_file_list(self.cfg.DATASET.INPUT_PATH)
img_name = _get_file_list(self.cfg.DATASET.IMAGE_NAME, prefix=dir_name[0])
assert len(dir_name) == 1 # avoid ambiguity when DO_SINGLY is True
# save input image names for further reference
fw = open(os.path.join(self.output_dir, "images.txt"), "w")
fw.write('\n'.join(img_name))
fw.close()
num_file = len(img_name)
start_idx = self.cfg.INFERENCE.DO_SINGLY_START_INDEX
for i in range(start_idx, num_file):
dataset = get_dataset(
self.cfg, self.augmentor, self.mode, self.rank,
dir_name_init=dir_name, img_name_init=[img_name[i]])
self.dataloader = build_dataloader(
self.cfg, self.augmentor, self.mode, dataset, self.rank)
self.dataloader = iter(self.dataloader)
digits = int(math.log10(num_file))+1
self.test_filename = self.cfg.INFERENCE.OUTPUT_NAME + \
'_' + str(i).zfill(digits) + '.h5'
self.test_filename = self.augmentor.update_name(
self.test_filename)
self.test()
# -----------------------------------------------------------------------------
# Misc functions
# -----------------------------------------------------------------------------
def backward_pass(self, loss):
if self.cfg.MODEL.MIXED_PRECESION:
# Scales loss. Calls backward() on scaled loss to create scaled gradients.
# Backward passes under autocast are not recommended.
# Backward ops run in the same dtype autocast chose for corresponding forward ops.
self.scaler.scale(loss).backward()
# scaler.step() first unscales the gradients of the optimizer's assigned params.
# If these gradients do not contain infs or NaNs, optimizer.step() is then called,
# otherwise, optimizer.step() is skipped.
self.scaler.step(self.optimizer)
# Updates the scale for next iteration.
self.scaler.update()
else: # standard backward pass
loss.backward()
self.optimizer.step()
[docs] def save_checkpoint(self, iteration: int, is_best: bool = False):
r"""Save the model checkpoint.
"""
if self.is_main_process:
print("Save model checkpoint at iteration ", iteration)
state = {'iteration': iteration + 1,
# Saving DataParallel or DistributedDataParallel models
'state_dict': self.model.module.state_dict(),
'optimizer': self.optimizer.state_dict(),
'lr_scheduler': self.lr_scheduler.state_dict()}
# Saves checkpoint to experiment directory
filename = 'checkpoint_%05d.pth.tar' % (iteration + 1)
if is_best:
filename = 'checkpoint_best.pth.tar'
filename = os.path.join(self.output_dir, filename)
torch.save(state, filename)
[docs] def update_checkpoint(self, checkpoint: Optional[str] = None):
r"""Update the model with the specified checkpoint file path.
"""
if checkpoint is None:
if self.mode == 'test':
warnings.warn("Test mode without specified checkpoint!")
return # nothing to load
# load pre-trained model
print('Load pretrained checkpoint: ', checkpoint)
checkpoint = torch.load(checkpoint, map_location=self.device)
print('checkpoints: ', checkpoint.keys())
# update model weights
if 'state_dict' in checkpoint.keys():
pretrained_dict = checkpoint['state_dict']
pretrained_dict = update_state_dict(
self.cfg, pretrained_dict, mode=self.mode)
model_dict = self.model.module.state_dict() # nn.DataParallel
# show model keys that do not match pretrained_dict
if not model_dict.keys() == pretrained_dict.keys():
warnings.warn("Module keys in model.state_dict() do not exactly "
"match the keys in pretrained_dict!")
for key in model_dict.keys():
if not key in pretrained_dict:
print(key)
# 1. filter out unnecessary keys by name
pretrained_dict = {k: v for k,
v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict (if size match)
for param_tensor in pretrained_dict:
if model_dict[param_tensor].size() == pretrained_dict[param_tensor].size():
model_dict[param_tensor] = pretrained_dict[param_tensor]
# 3. load the new state dict
self.model.module.load_state_dict(model_dict) # nn.DataParallel
if self.mode == 'train' and not self.cfg.SOLVER.ITERATION_RESTART:
if hasattr(self, 'optimizer') and 'optimizer' in checkpoint.keys():
self.optimizer.load_state_dict(checkpoint['optimizer'])
if hasattr(self, 'lr_scheduler') and 'lr_scheduler' in checkpoint.keys():
self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
if hasattr(self, 'start_iter') and 'iteration' in checkpoint.keys():
self.start_iter = checkpoint['iteration']
def maybe_save_swa_model(self):
if not hasattr(self, 'swa_model'):
return
if self.cfg.MODEL.NORM_MODE in ['bn', 'sync_bn']: # update bn statistics
for _ in range(self.cfg.SOLVER.SWA.BN_UPDATE_ITER):
sample = next(self.dataloader)
volume = sample.out_input
volume = volume.to(self.device, non_blocking=True)
with autocast(enabled=self.cfg.MODEL.MIXED_PRECESION):
pred = self.swa_model(volume)
# save swa model
if self.is_main_process:
print("Save SWA model checkpoint.")
state = {'state_dict': self.swa_model.module.state_dict()}
filename = 'checkpoint_swa.pth.tar'
filename = os.path.join(self.output_dir, filename)
torch.save(state, filename)
def maybe_update_swa_model(self, iter_total):
if not hasattr(self, 'swa_model'):
return
swa_start = self.cfg.SOLVER.SWA.START_ITER
swa_merge = self.cfg.SOLVER.SWA.MERGE_ITER
if iter_total >= swa_start and iter_total % swa_merge == 0:
self.swa_model.update_parameters(self.model)
def scheduler_step(self, iter_total, loss):
if hasattr(self, 'swa_scheduler') and iter_total >= self.cfg.SOLVER.SWA.START_ITER:
self.swa_scheduler.step()
return
if self.cfg.SOLVER.LR_SCHEDULER_NAME == 'ReduceLROnPlateau':
self.lr_scheduler.step(loss)
else:
self.lr_scheduler.step()
# -----------------------------------------------------------------------------
# Chunk processing for TileDataset
# -----------------------------------------------------------------------------
[docs] def run_chunk(self, mode: str):
r"""Run chunk-based training and inference for large-scale datasets.
"""
self.dataset = get_dataset(self.cfg, self.augmentor, mode)
if mode == 'train':
num_chunk = self.total_iter_nums // self.cfg.DATASET.DATA_CHUNK_ITER
self.total_iter_nums = self.cfg.DATASET.DATA_CHUNK_ITER
for chunk in range(num_chunk):
self.dataset.updatechunk()
self.dataloader = build_dataloader(self.cfg, self.augmentor, mode,
dataset=self.dataset.dataset)
self.dataloader = iter(self.dataloader)
print('start train for chunk %d' % chunk)
self.train()
print('finished train for chunk %d' % chunk)
self.start_iter += self.cfg.DATASET.DATA_CHUNK_ITER
del self.dataloader
return
# inference mode
num_chunk = len(self.dataset.chunk_ind)
print("Total number of chunks: ", num_chunk)
for chunk in range(num_chunk):
self.dataset.updatechunk(do_load=False)
self.test_filename = self.cfg.INFERENCE.OUTPUT_NAME + \
'_' + self.dataset.get_coord_name() + '.h5'
self.test_filename = self.augmentor.update_name(
self.test_filename)
if not os.path.exists(os.path.join(self.output_dir, self.test_filename)):
self.dataset.loadchunk()
self.dataloader = build_dataloader(self.cfg, self.augmentor, mode,
dataset=self.dataset.dataset)
self.dataloader = iter(self.dataloader)
self.test()