Shortcuts

Source code for connectomics.data.augmentation.augmentor

from __future__ import print_function, division
from abc import ABCMeta, abstractmethod
from typing import Optional
import numpy as np

class DataAugment(object, metaclass=ABCMeta):
    r"""
    DataAugment interface. A data augmentor needs to conduct the following steps:

    1. Set :attr:`sample_params` at initialization to compute required sample size.
    2. Randomly generate augmentation parameters for the current transform.
    3. Apply the transform to a pair of images and corresponding labels.

    All the real data augmentations (except mix-up augmentor and test-time augmentor)
    should be a subclass of this class.

    Args:
        p (float): probability of applying the augmentation. Must lie in
            ``[0, 1]``. Default: 0.5
        additional_targets (dict, optional): additional targets to augment.
            ``None`` is replaced by a fresh empty dict. Default: None
        skip_targets (list, optional): stored as :attr:`skip_targets` for use
            by subclasses (presumably keys of targets to leave unaugmented --
            confirm against the concrete augmentors). ``None`` is replaced by
            a fresh empty list. Default: None
    """

    def __init__(self,
                 p: float = 0.5,
                 additional_targets: Optional[dict] = None,
                 skip_targets: Optional[list] = None):
        super().__init__()
        assert 0.0 <= p <= 1.0, f"p must be in [0, 1], got {p}"
        self.p = p
        # Default sampling adjustment: per-axis "ratio" of 1 and "add" of 0.
        # NOTE(review): presumably the required sample size is derived as
        # size * ratio + add, and subclasses needing a larger input override
        # these values -- confirm against the sampler that reads them.
        self.sample_params = {
            'ratio': np.array([1.0, 1.0, 1.0]),
            'add': np.array([0, 0, 0])}

        # Build fresh containers per instance when none are given. A literal
        # `[]` / `{}` default would be created once and shared by every
        # instance, so mutating one augmentor's targets would affect all others.
        self.additional_targets = (
            additional_targets if additional_targets is not None else {})
        self.skip_targets = skip_targets if skip_targets is not None else []

    @abstractmethod
    def set_params(self):
        r"""
        Calculate the appropriate sample size with data augmentation.

        Some data augmentations (warp, misalignment, etc.) require a larger
        sample size than the original, depending on the augmentation parameters
        that are randomly chosen. This function takes the data augmentation
        parameters and returns an updated data sampling size accordingly.

        Raises:
            NotImplementedError: always; concrete subclasses must override it.
        """
        raise NotImplementedError

    @abstractmethod
    def __call__(self, sample, random_state=None):
        r"""
        Apply the data augmentation.

        For a multi-CPU dataloader, one may need to use a unique index to generate
        the random seed (:attr:`random_state`), otherwise different workers may generate
        the same pseudo-random number for augmentation and sampling.

        The only required key in :attr:`sample` is ``'image'``. The keys that are not
        specified in :attr:`additional_targets` will be ignored.

        Raises:
            NotImplementedError: always; concrete subclasses must override it.
        """
        raise NotImplementedError

© Copyright 2019-2023, PyTorch Connectomics Contributors. Revision 9f6fcf8b.

Built with Sphinx using a theme provided by Read the Docs.