Shortcuts

Source code for connectomics.data.dataset.dataset_tile

from __future__ import print_function, division
from typing import Optional, List, Union
import numpy as np
import json
import random

import torch
import torch.utils.data
from scipy.ndimage import zoom

from . import VolumeDataset
from ..utils import reduce_label, tile2volume


[docs]class TileDataset(torch.utils.data.Dataset): r"""Dataset class for large-scale tile-based datasets. Large-scale volumetric datasets are usually stored as individual tiles. Directly loading them as a single array for training and inference is infeasible. This class reads the paths of the tiles and construct smaller chunks for processing. Args: chunk_num (list): volume spliting parameters in :math:`(z, y, x)` order. Default: :math:`[2, 2, 2]` chunk_ind (list): predefined list of chunks. Default: `None` chunk_ind_split (list): rank and world_size for spliting chunk_ind in multi-processing. Default: `None` chunk_iter (int): number of iterations on each chunk. Default: -1 chunk_stride (bool): allow overlap between chunks. Default: `True` volume_json (str): json file for input image. Default: ``'path/to/image'`` label_json (str, optional): json file for label. Default: `None` valid_mask_json (str, optional): json file for valid mask. Default: `None` mode (str): ``'train'``, ``'val'`` or ``'test'``. Default: ``'train'`` pad_size (list): padding parameters in :math:`(z, y, x)` order. Default: :math:`[0, 0, 0]` data_scale (list): volume scaling factors in :math:`(z, y, x)` order. Default: :math:`[1.0, 1.0, 1.0]` coord_range (list): the valid coordinate range of volumes. Default: `None` do_relabel (bool): reduce the the mask indicies in a sampled label volume. This option be set to False for semantic segmentation, otherwise the classes can shift. Default: True Note: To run inference using multiple nodes in an asynchronous manner, ``chunk_ind_split`` specifies the number of parts to split the total number of chunks in inference, and which part should the current node/process see. For example, ``chunk_ind_split = "0-5"`` means the chunks are split into 5 parts (thus can be processed asynchronously using 5 nodes), and the current node/process is handling the first (0-base) part of the chunks. Note: The ``coord_range`` option specify the region of a volume to use. Suppose the fisrt input volume has a voxel size of (1000, 10000, 10000), and only the center subvolume of size (400, 2000, 2000) needs to be used for training or inference, then set ``coord_range=[[300, 700, 4000, 6000, 4000, 6000]]``. """ def __init__(self, chunk_num: List[int] = [2, 2, 2], chunk_ind: Optional[list] = None, chunk_ind_split: Optional[Union[List[int], str]] = None, chunk_iter: int = -1, chunk_stride: bool = True, volume_json: List[str] = ['path/to/image.json'], label_json: Optional[List[str]] = None, valid_mask_json: Optional[List[str]] = None, mode: str = 'train', pad_size: List[int] = [0, 0, 0], data_scale: List[float] = [1.0, 1.0, 1.0], coord_range: Optional[List[List[int]]] = None, do_relabel: bool = True, **kwargs): self.kwargs = kwargs self.mode = mode self.chunk_iter = chunk_iter self.pad_size = pad_size self.data_scale = data_scale self.do_relabel = do_relabel self.chunk_step = 1 if chunk_stride and self.mode == 'train': # 50% overlap between volumes during training self.chunk_step = 2 self.chunk_num = chunk_num self.chunk_ind = self.get_chunk_ind(chunk_ind, chunk_ind_split) self.chunk_id_done = [] self.num_volumes = len(volume_json) # number of volumes specified by json files if self.mode == 'test': assert self.num_volumes == 1, "Only one json file should be given in inference!" self.json_volume = [json.load(open(volume_json[i])) for i in range(self.num_volumes)] self.json_label = [json.load(open(label_json[i])) for i in range(self.num_volumes)] if ( label_json is not None) else None self.json_valid = [json.load(open(valid_mask_json[i])) for i in range(self.num_volumes)] if ( valid_mask_json is not None) else None self.json_size = [[ self.json_volume[i]['depth'], self.json_volume[i]['height'], self.json_volume[i]['width']] for i in range(self.num_volumes)] self.coord_m = np.array([[ 0, self.json_volume[i]['depth'], 0, self.json_volume[i]['height'], 0, self.json_volume[i]['width']] for i in range(self.num_volumes)], int) # specify the coordintate range of data to use self.get_coord_range(coord_range) def get_chunk_ind(self, chunk_ind, split_rule): if chunk_ind is None: chunk_ind = list( range(np.prod(self.chunk_num))) if split_rule is not None: # keep only the chunk indicies for current node/process if isinstance(split_rule, str): split_rule = split_rule.split('-') assert len(split_rule) == 2 rank, world_size = split_rule rank, world_size = int(rank), int(world_size) assert rank < world_size # rank is 0-base x = len(chunk_ind) // world_size low, high = rank * x, (rank + 1) * x # Last split needs to cover remaining chunks. if rank == world_size - 1: high = len(chunk_ind) chunk_ind = chunk_ind[low: high] return chunk_ind
[docs] def get_coord_name(self): r"""Return the filename suffix based on the chunk coordinates. """ # this function should only be called in test mode assert self.mode == 'test' and len(self.coord) == 1 return '-'.join([str(x) for x in self.coord[0]])
def get_coord_range(self, coord_range): if coord_range is not None: if isinstance(coord_range[0], int): assert len(coord_range) == 6 self.coord_range = [coord_range for _ in range(self.num_volumes)] elif isinstance(coord_range[0], list): self.coord_range = coord_range else: self.coord_range = self.coord_m # use all data assert len(self.coord_range) == self.num_volumes self.coord_range_l, self.coord_range_r = [], [] for temp in self.coord_range: # lower and higher boundaries self.coord_range_l.append([temp[2*i] for i in range(3)]) self.coord_range_r.append([temp[2*i+1] for i in range(3)]) def get_range_axis(self, axis_id, vol_id, axis: str='z'): axis_map = {'z': 0, 'y': 1, 'x': 2} l_bd = self.coord_range_l[vol_id][axis_map[axis]] r_bd = self.coord_range_r[vol_id][axis_map[axis]] assert r_bd > l_bd length = r_bd - l_bd steps = np.array([axis_id, axis_id + self.chunk_step]) axis_range = np.floor(steps / ( self.chunk_num[axis_map[axis]] + self.chunk_step-1) * length).astype(int) return axis_range + l_bd
[docs] def updatechunk(self, do_load=True): r"""Update the coordinates to a new chunk in the large volume. """ if len(self.chunk_id_done) == len(self.chunk_ind): self.chunk_id_done = [] id_rest = list(set(self.chunk_ind)-set(self.chunk_id_done)) if self.mode == 'train': id_sample = id_rest[int(np.floor(random.random()*len(id_rest)))] elif self.mode == 'test': id_sample = id_rest[0] self.chunk_id_done += [id_sample] zid = float(id_sample//(self.chunk_num[1]*self.chunk_num[2])) yid = float((id_sample//self.chunk_num[2]) % (self.chunk_num[1])) xid = float(id_sample % self.chunk_num[2]) self.coord = [] for i in range(self.num_volumes): z0, z1 = self.get_range_axis(zid, vol_id=i, axis='z') y0, y1 = self.get_range_axis(yid, vol_id=i, axis='y') x0, x1 = self.get_range_axis(xid, vol_id=i, axis='x') self.coord.append(np.array([z0, z1, y0, y1, x0, x1], int)) if do_load: self.loadchunk()
[docs] def loadchunk(self): r"""Load the chunk based on current coordinates and construct a VolumeDataset for processing. """ # assuming same padding for the list of volumes given padding = np.array([ -self.pad_size[0], self.pad_size[0], -self.pad_size[1], self.pad_size[1], -self.pad_size[2], self.pad_size[2]]) coord_p = [self.coord[i] + padding for i in range(self.num_volumes)] print('load chunk: ', coord_p) volume = [ tile2volume(self.json_volume[i]['image'], coord_p[i], self.coord_m[i], tile_sz=self.json_volume[i]['tile_size'], tile_st=self.json_volume[i]['tile_st'], tile_ratio=self.json_volume[i]['tile_ratio']) for i in range(self.num_volumes) ] volume = self.maybe_scale(volume, order=3) label = None if self.json_label is not None: dt = {'uint8': np.uint8, 'uint16': np.uint16, 'uint32': np.uint32, 'uint64': np.uint64} label = [ tile2volume(self.json_label[i]['image'], coord_p[i], self.coord_m[i], tile_sz=self.json_label[i]['tile_size'], tile_st=self.json_label[i]['tile_st'], tile_ratio=self.json_label[i]['tile_ratio'], dt=dt[self.json_label[i]['dtype']], do_im=False) for i in range(self.num_volumes) ] # float32 may misrepresent large uint32/uint64 numbers -> reduce the label indices if self.do_relabel: label = [reduce_label(x, do_type=True) for x in label] label = self.maybe_scale(label, order=0) valid_mask = None if self.json_valid is not None: valid_mask = [ tile2volume(self.json_valid[i]['image'], coord_p[i], self.coord_m[i], tile_sz=self.json_valid[i]['tile_size'], tile_st=self.json_valid[i]['tile_st'], tile_ratio=self.json_valid[i]['tile_ratio'], do_im=False) for i in range(self.num_volumes) ] valid_mask = self.maybe_scale(valid_mask, order=0) self.dataset = VolumeDataset(volume, label, valid_mask, mode=self.mode, do_relabel=self.do_relabel, # specify chunk iteration number for training (-1 for inference) iter_num=self.chunk_iter if self.mode == 'train' else -1, **self.kwargs)
def maybe_scale(self, data, order=0): if (np.array(self.data_scale) != 1).any(): for i in range(len(data)): dt = data[i].dtype data[i] = zoom(data[i], self.data_scale, order=order).astype(dt) return data

© Copyright 2019-2023, PyTorch Connectomics Contributors. Revision 9f6fcf8b.

Built with Sphinx using a theme provided by Read the Docs.