Shortcuts

Source code for connectomics.data.augmentation.warp

from __future__ import print_function, division
from typing import Optional

import cv2
import numpy as np
from scipy.ndimage.filters import gaussian_filter

from .augmentor import DataAugment

class Elastic(DataAugment):
    r"""Elastic deformation of images as described in [Simard2003]_ (with modifications).

    The implementation is based on
    https://gist.github.com/erniejunior/601cdf56d2b424757de5.
    This augmentation is applied to both images and masks.

    .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
        Convolutional Neural Networks applied to Visual Document Analysis", in
        Proc. of the International Conference on Document Analysis and
        Recognition, 2003.

    Args:
        alpha (float): maximum pixel-moving distance of elastic deformation. Default: 16.0
        sigma (float): standard deviation of the Gaussian filter. Default: 4.0
        p (float): probability of applying the augmentation. Default: 0.5
        additional_targets(dict, optional): additional targets to augment. Default: None
        skip_targets (list, optional): keys of additional targets that are
            returned unchanged by this augmentation. Default: None
    """

    # Images are interpolated linearly; masks use nearest-neighbour sampling
    # so that no new label values are invented at object boundaries.
    interpolation = {'img': cv2.INTER_LINEAR,
                     'mask': cv2.INTER_NEAREST}
    border_mode = cv2.BORDER_CONSTANT

    def __init__(self,
                 alpha: float = 16.0,
                 sigma: float = 4.0,
                 p: float = 0.5,
                 additional_targets: Optional[dict] = None,
                 skip_targets: Optional[list] = None):
        # Build a fresh list per instance instead of sharing one mutable
        # default object between every Elastic augmentor.
        if skip_targets is None:
            skip_targets = []
        super(Elastic, self).__init__(p, additional_targets, skip_targets)
        self.alpha = alpha
        self.sigma = sigma
        self.set_params()

    def set_params(self):
        r"""The elastic deformation is only applied to the `xy`-plane. The
        required sample size before transformation needs to be larger, by a
        margin decided by the maximum pixel-moving distance (:attr:`self.alpha`).
        """
        max_margin = int(self.alpha) + 1
        # No extra margin along z; the same margin is requested for y and x.
        self.sample_params['add'] = [0, max_margin, max_margin]

    def elastic_wrap(self, images, mapx, mapy, target_type='img'):
        r"""Warp every `xy`-slice of a volume with the given remapping grids.

        Args:
            images (numpy.ndarray): volume in (z, y, x) or (c, z, y, x) format.
            mapx (numpy.ndarray): float32 map of source x-coordinates.
            mapy (numpy.ndarray): float32 map of source y-coordinates.
            target_type (str): key into :attr:`interpolation` selecting the
                interpolation mode (``'img'`` or ``'mask'``). Default: ``'img'``

        Returns:
            numpy.ndarray: the warped volume, with the same axis layout as
            the input.
        """
        assert images.ndim in [3, 4], \
            "expected a (z, y, x) or (c, z, y, x) volume, got ndim=%d" % images.ndim

        interp = self.interpolation[target_type]

        def warp_slice(plane):
            # Every slice is warped with the same displacement field.
            return cv2.remap(plane, mapx, mapy, interp,
                             borderMode=self.border_mode)

        depth = images.shape[-3]
        if images.ndim == 3:
            return np.stack([warp_slice(images[z]) for z in range(depth)], 0)

        # multi-channel images in (c,z,y,x) format: warp each channel of a
        # z-slice, then stack the slices back along the z axis (axis 1).
        num_channels = images.shape[0]
        warped = [
            np.stack([warp_slice(images[c, z]) for c in range(num_channels)], 0)
            for z in range(depth)
        ]
        return np.stack(warped, 1)

    def get_random_params(self, images, random_state):
        r"""Draw a random smooth displacement field for the `xy`-plane.

        Args:
            images (numpy.ndarray): volume in (c, z, y, x) or (z, y, x)
                format; only its last two dimensions are used.
            random_state (numpy.random.RandomState): random number generator.

        Returns:
            tuple: ``(mapx, mapy)`` float32 arrays of shape (height, width)
            holding the source coordinates sampled for each output pixel.
        """
        height, width = images.shape[-2:]  # (c, z, y, x) or (z, y, x)

        def displacement():
            # Uniform noise in [-1, 1), smoothed and then scaled by alpha.
            noise = random_state.rand(height, width) * 2 - 1
            return np.float32(gaussian_filter(noise, self.sigma) * self.alpha)

        # dx is drawn before dy so that a seeded generator reproduces the
        # same field as before.
        dx = displacement()
        dy = displacement()

        x, y = np.meshgrid(np.arange(width), np.arange(height))
        mapx, mapy = np.float32(x + dx), np.float32(y + dy)
        return mapx, mapy

    def __call__(self, sample, random_state=None):
        r"""Apply one random elastic deformation to a sample.

        Args:
            sample (dict): holds the volume under ``'image'`` plus any
                additional targets; it is updated in place and returned.
            random_state (numpy.random.RandomState, optional): random number
                generator. A new generator is created for the call when this
                is None. Default: None

        Returns:
            dict: the sample with the image and all non-skipped additional
            targets warped by the same displacement field.
        """
        # Creating the generator here, rather than as a default argument,
        # avoids one generator being built at definition time and then
        # shared by every call that omits the argument.
        if random_state is None:
            random_state = np.random.RandomState()

        images = sample['image'].copy()
        mapx, mapy = self.get_random_params(images, random_state)
        sample['image'] = self.elastic_wrap(images, mapx, mapy, 'img')

        for key, target_type in self.additional_targets.items():
            if key in self.skip_targets:
                continue
            sample[key] = self.elastic_wrap(sample[key].copy(), mapx, mapy,
                                            target_type=target_type)
        return sample

© Copyright 2019-2023, PyTorch Connectomics Contributors. Revision 9f6fcf8b.

Built with Sphinx using a theme provided by Read the Docs.