Shortcuts

Source code for connectomics.utils.process

from __future__ import print_function, division
from typing import Optional, Union, List
import numpy as np

from scipy import ndimage
from skimage.measure import label
from skimage.transform import resize
from skimage.morphology import dilation, binary_dilation
from skimage.segmentation import watershed
from skimage.morphology import remove_small_objects
from skimage.feature import peak_local_max

from connectomics.data.utils import getSegType, bbox_ND, crop_ND, replace_ND


# Public names exported by a star-import of this module. Only the decoding
# entry points are listed; the helper utilities defined further down
# (e.g. ``remove_small_instances``, ``cast2dtype``) are not part of ``__all__``.
__all__ = ['binary_connected',
           'binary_watershed',
           'bc_connected',
           'bc_watershed',
           'bcd_watershed',
           'polarity2instance']


# Post-processing functions of mitochondria instance segmentation model outputs
# as described in "MitoEM Dataset: Large-scale 3D Mitochondria Instance Segmentation
# from EM Images (MICCAI 2020, https://donglaiw.github.io/page/mitoEM/index.html)".
def binary_connected(volume, thres=0.8, thres_small=128, scale_factors=(1.0, 1.0, 1.0),
                     remove_small_mode='background'):
    r"""Decode a foreground map into instance masks by connected-component labeling.

    Only the first channel of :attr:`volume` is used. It is binarized with
    ``volume[0] > int(255*thres)`` (so the map is presumably stored on a 0-255
    scale), labeled, cleaned of small objects and optionally resized.

    Args:
        volume (numpy.ndarray): foreground prediction of shape :math:`(C, Z, Y, X)`.
        thres (float): foreground threshold, as a fraction of 255. Default: 0.8
        thres_small (int): size threshold passed to ``remove_small_instances``. Default: 128
        scale_factors (tuple): output scale factors in :math:`(Z, Y, X)` order; no
            resizing happens when all of them equal 1.0. Default: (1.0, 1.0, 1.0)
        remove_small_mode (str): mode forwarded to ``remove_small_instances``
            (``'background'``, ``'neighbor'``, their ``'_2d'`` variants, or
            ``'none'``). Default: ``'background'``

    Returns:
        numpy.ndarray: label volume cast by ``cast2dtype``.
    """
    fg_prob = volume[0]
    fg_mask = fg_prob > int(255*thres)
    instances = remove_small_instances(label(fg_mask), thres_small, remove_small_mode)

    if any(factor != 1.0 for factor in scale_factors):
        # order=0 (nearest neighbor) keeps label ids intact while resizing
        out_shape = tuple(int(fg_prob.shape[axis]*scale_factors[axis]) for axis in range(3))
        instances = resize(instances, out_shape, order=0,
                           anti_aliasing=False, preserve_range=True)

    return cast2dtype(instances)
def binary_watershed(volume, thres1=0.98, thres2=0.85, thres_small=128,
                     scale_factors=(1.0, 1.0, 1.0), remove_small_mode='background',
                     seed_thres=32):
    r"""Decode a foreground map into instance masks with seeded watershed.

    Seeds are the connected components of ``volume[0] > int(255*thres1)`` after
    dropping components smaller than :attr:`seed_thres`; the watershed is run on
    the negated foreground map and restricted to ``volume[0] > int(255*thres2)``.

    Note:
        The energy map is explicitly converted to ``np.float64`` before being
        handed to ``skimage.segmentation.watershed``, so make sure enough memory
        is available when handling large arrays.

    Args:
        volume (numpy.ndarray): foreground prediction of shape :math:`(C, Z, Y, X)`.
        thres1 (float): seed threshold, as a fraction of 255. Default: 0.98
        thres2 (float): foreground threshold, as a fraction of 255. Default: 0.85
        thres_small (int): size threshold passed to ``remove_small_instances``. Default: 128
        scale_factors (tuple): output scale factors in :math:`(Z, Y, X)` order; no
            resizing happens when all of them equal 1.0. Default: (1.0, 1.0, 1.0)
        remove_small_mode (str): mode forwarded to ``remove_small_instances``
            (``'background'``, ``'neighbor'``, their ``'_2d'`` variants, or
            ``'none'``). Default: ``'background'``
        seed_thres (int): minimum size of a seed component. Default: 32

    Returns:
        numpy.ndarray: label volume cast by ``cast2dtype``.
    """
    fg_prob = volume[0]
    fg_mask = fg_prob > int(255*thres2)
    markers = remove_small_objects(label(fg_prob > int(255*thres1)), seed_thres)

    # flood from the most confident voxels outwards, staying inside fg_mask
    instances = watershed(-fg_prob.astype(np.float64), markers, mask=fg_mask)
    instances = remove_small_instances(instances, thres_small, remove_small_mode)

    if any(factor != 1.0 for factor in scale_factors):
        # order=0 (nearest neighbor) keeps label ids intact while resizing
        out_shape = tuple(int(fg_prob.shape[axis]*scale_factors[axis]) for axis in range(3))
        instances = resize(instances, out_shape, order=0,
                           anti_aliasing=False, preserve_range=True)

    return cast2dtype(instances)
def bc_connected(volume, thres1=0.8, thres2=0.5, thres_small=128,
                 scale_factors=(1.0, 1.0, 1.0), dilation_struct=(1,5,5),
                 remove_small_mode='background'):
    r"""Decode foreground and contour maps into instance masks by
    connected-component labeling.

    Note:
        Only voxels that are foreground *and* not on a contour are labeled, which
        systematically shrinks every object. The labeled volume is therefore
        enlarged with a morphological dilation using an all-ones structuring
        element of shape :attr:`dilation_struct`.

    Args:
        volume (numpy.ndarray): prediction of shape :math:`(C, Z, Y, X)`; channel 0
            is read as foreground and channel 1 as instance contour.
        thres1 (float): foreground threshold, as a fraction of 255. Default: 0.8
        thres2 (float): contour threshold, as a fraction of 255. Default: 0.5
        thres_small (int): size threshold passed to ``remove_small_instances``. Default: 128
        scale_factors (tuple): output scale factors in :math:`(Z, Y, X)` order; no
            resizing happens when all of them equal 1.0. Default: (1.0, 1.0, 1.0)
        dilation_struct (tuple): shape of the structuring element used for the
            dilation. Default: (1, 5, 5)
        remove_small_mode (str): mode forwarded to ``remove_small_instances``
            (``'background'``, ``'neighbor'``, their ``'_2d'`` variants, or
            ``'none'``). Default: ``'background'``

    Returns:
        numpy.ndarray: label volume cast by ``cast2dtype``.
    """
    fg_prob, contour_prob = volume[0], volume[1]
    interior = (fg_prob > int(255*thres1)) & (contour_prob < int(255*thres2))

    # grow the labeled interiors back towards the (removed) contour band
    instances = dilation(label(interior), np.ones(dilation_struct))
    instances = remove_small_instances(instances, thres_small, remove_small_mode)

    if any(factor != 1.0 for factor in scale_factors):
        # order=0 (nearest neighbor) keeps label ids intact while resizing
        out_shape = tuple(int(fg_prob.shape[axis]*scale_factors[axis]) for axis in range(3))
        instances = resize(instances, out_shape, order=0,
                           anti_aliasing=False, preserve_range=True)

    return cast2dtype(instances)
def bc_watershed(volume, thres1=0.9, thres2=0.8, thres3=0.85, thres_small=128,
                 scale_factors=(1.0, 1.0, 1.0), remove_small_mode='background',
                 seed_thres=32, return_seed=False, precomputed_seed=None):
    r"""Decode foreground and contour maps into instance masks with seeded watershed.

    Note:
        The energy map is explicitly converted to ``np.float64`` before being
        handed to ``skimage.segmentation.watershed``, so make sure enough memory
        is available when handling large arrays.

    Args:
        volume (numpy.ndarray): prediction of shape :math:`(2, Z, Y, X)`; channel 0
            is read as foreground and channel 1 as instance contour.
        thres1 (float): seed threshold on the foreground, as a fraction of 255. Default: 0.9
        thres2 (float): contour threshold for seeds, as a fraction of 255. Default: 0.8
        thres3 (float): foreground threshold, as a fraction of 255. Default: 0.85
        thres_small (int): size threshold passed to ``remove_small_instances``. Default: 128
        scale_factors (tuple): output scale factors in :math:`(Z, Y, X)` order; no
            resizing happens when all of them equal 1.0. Default: (1.0, 1.0, 1.0)
        remove_small_mode (str): mode forwarded to ``remove_small_instances``
            (``'background'``, ``'neighbor'``, their ``'_2d'`` variants, or
            ``'none'``). Default: ``'background'``
        seed_thres (int): minimum size of a seed component. Default: 32
        return_seed (bool): also return the seed volume that was used. Default: False
        precomputed_seed (numpy.ndarray, optional): seed labels to use as-is instead
            of deriving them from :attr:`volume`. Default: None

    Returns:
        The label volume cast by ``cast2dtype``; when :attr:`return_seed` is true,
        a tuple ``(labels, seed)``.
    """
    assert volume.shape[0] == 2
    fg_prob, contour_prob = volume[0], volume[1]
    fg_mask = fg_prob > int(255*thres3)

    if precomputed_seed is None:
        # seeds: confident foreground that is clearly off the contours
        core = (fg_prob > int(255*thres1)) & (contour_prob < int(255*thres2))
        seed = remove_small_objects(label(core), seed_thres)
    else:
        seed = precomputed_seed

    instances = watershed(-fg_prob.astype(np.float64), seed, mask=fg_mask)
    instances = remove_small_instances(instances, thres_small, remove_small_mode)

    if any(factor != 1.0 for factor in scale_factors):
        # order=0 (nearest neighbor) keeps label ids intact while resizing
        out_shape = tuple(int(fg_prob.shape[axis]*scale_factors[axis]) for axis in range(3))
        instances = resize(instances, out_shape, order=0,
                           anti_aliasing=False, preserve_range=True)

    if return_seed:
        return cast2dtype(instances), seed
    return cast2dtype(instances)
def bcd_watershed(volume, thres1=0.9, thres2=0.8, thres3=0.85, thres4=0.5, thres5=0.0,
                  thres_small=128, scale_factors=(1.0, 1.0, 1.0),
                  remove_small_mode='background', seed_thres=32, return_seed=False,
                  precomputed_seed=None):
    r"""Decode foreground, contour and signed-distance maps into instance masks
    with seeded watershed.

    The third channel is mapped from 0-255 to :math:`[-1, 1]` before it is
    compared against :attr:`thres4` and :attr:`thres5`.

    Note:
        The energy map is explicitly converted to ``np.float64`` before being
        handed to ``skimage.segmentation.watershed``, so make sure enough memory
        is available when handling large arrays.

    Args:
        volume (numpy.ndarray): prediction of shape :math:`(3, Z, Y, X)`; the
            channels are read as foreground, instance contour and distance.
        thres1 (float): seed threshold on the foreground, as a fraction of 255. Default: 0.9
        thres2 (float): contour threshold for seeds, as a fraction of 255. Default: 0.8
        thres3 (float): foreground threshold, as a fraction of 255. Default: 0.85
        thres4 (float): rescaled-distance threshold for seeds. Default: 0.5
        thres5 (float): rescaled-distance threshold for foreground. Default: 0.0
        thres_small (int): size threshold passed to ``remove_small_instances``. Default: 128
        scale_factors (tuple): output scale factors in :math:`(Z, Y, X)` order; no
            resizing happens when all of them equal 1.0. Default: (1.0, 1.0, 1.0)
        remove_small_mode (str): mode forwarded to ``remove_small_instances``
            (``'background'``, ``'neighbor'``, their ``'_2d'`` variants, or
            ``'none'``). Default: ``'background'``
        seed_thres (int): minimum size of a seed component. Default: 32
        return_seed (bool): also return the seed volume that was used. Default: False
        precomputed_seed (numpy.ndarray, optional): seed labels to use as-is instead
            of deriving them from :attr:`volume`. Default: None

    Returns:
        The label volume cast by ``cast2dtype``; when :attr:`return_seed` is true,
        a tuple ``(labels, seed)``.
    """
    assert volume.shape[0] == 3
    fg_prob, contour_prob = volume[0], volume[1]
    signed_dist = (volume[2] / 255.0) * 2.0 - 1.0  # 0..255 -> -1..1
    fg_mask = (fg_prob > int(255*thres3)) & (signed_dist > thres5)

    if precomputed_seed is None:
        # seeds: confident foreground, off the contours and deep inside an object
        core = ((fg_prob > int(255*thres1))
                & (contour_prob < int(255*thres2))
                & (signed_dist > thres4))
        seed = remove_small_objects(label(core), seed_thres)
    else:
        seed = precomputed_seed

    instances = watershed(-fg_prob.astype(np.float64), seed, mask=fg_mask)
    instances = remove_small_instances(instances, thres_small, remove_small_mode)

    if any(factor != 1.0 for factor in scale_factors):
        # order=0 (nearest neighbor) keeps label ids intact while resizing
        out_shape = tuple(int(fg_prob.shape[axis]*scale_factors[axis]) for axis in range(3))
        instances = resize(instances, out_shape, order=0,
                           anti_aliasing=False, preserve_range=True)

    if return_seed:
        return cast2dtype(instances), seed
    return cast2dtype(instances)
# Post-processing functions for synaptic polarity model outputs as described
# in "Two-Stream Active Query Suggestion for Active Learning in Connectomics
# (ECCV 2020, https://zudi-lin.github.io/projects/#two_stream_active)".
def polarity2instance(
    volume: np.ndarray,
    thres: float = 0.5,
    thres_small: int = 128,
    scale_factors: tuple = (1.0, 1.0, 1.0),
    semantic: bool = False,
    dilate_sz: int = 5,
    exclusive: bool = False,
) -> np.ndarray:
    r"""From synaptic polarity prediction to instance masks via connected-component
    labeling.

    In the default (non-exclusive) mode the input is a 3-channel map of shape
    :math:`(3, Z, Y, X)` whose channels are read as pre-synaptic region,
    post-synaptic region and their union, each binarized with
    ``volume > int(255*thres)``. With :attr:`exclusive=True` the masks come from
    an ``argmax`` over the channel axis instead: index 1 is taken as
    pre-synaptic, index 2 as post-synaptic, and every non-zero index as the union.

    Note:
        For each pair of pre- and post-synaptic segmentation, the pre-synaptic
        region is annotated as :math:`2n-1` and the post-synaptic region as
        :math:`2n`, for :math:`n>0`. If :attr:`semantic=True`, all pre-synaptic
        pixels are labeled with 1 while all post-synaptic pixels are labeled
        with 2. Both kinds of annotation are compatible with the
        ``TARGET_OPT: ['1']`` configuration in training.

    Note:
        The number of pre- and post-synaptic segments is printed when
        :attr:`semantic=False`. The two numbers can differ, either because of
        incomplete synapses touching the volume borders or because of prediction
        errors, so the smaller of the two is reported as a conservative estimate
        of the number of synapses.

    Args:
        volume (numpy.ndarray): 3-channel probability map of shape :math:`(3, Z, Y, X)`.
        thres (float): probability threshold of foreground (ignored when
            :attr:`exclusive` is true). Default: 0.5
        thres_small (int): size threshold of small objects to remove. Default: 128
        scale_factors (tuple): scale factors for resizing the output volume in
            :math:`(Z, Y, X)` order. Default: :math:`(1.0, 1.0, 1.0)`
        semantic (bool): return only the semantic mask of pre- and post-synaptic
            regions. Default: False
        dilate_sz (int): the union mask is dilated with a structure of size
            ``(1, dilate_sz, dilate_sz)`` before labeling. Default: 5
        exclusive (bool): whether the synaptic masks are exclusive (with softmax)
            or not. Default: False

    Examples::
        >>> from connectomics.data.utils import readvol, savevol
        >>> from connectomics.utils.process import polarity2instance
        >>> volume = readvol(input_name)
        >>> instances = polarity2instance(volume)
        >>> savevol(output_name, instances)
    """
    if exclusive:
        winner = np.argmax(volume, axis=0)
        masks = np.stack([
            winner == 1,
            winner == 2,
            winner != 0,  # union of pre- and post-synaptic masks
        ], axis=0)
    else:
        masks = volume > int(255.0 * thres)  # boolean array

    union = masks[2]
    syn_pre = remove_small_objects(
        np.logical_and(masks[0], union), min_size=thres_small, connectivity=1)
    syn_post = remove_small_objects(
        np.logical_and(masks[1], union), min_size=thres_small, connectivity=1)

    if semantic:
        # Semantic mask only: pre-synaptic voxels get 1, post-synaptic voxels get 2.
        segm = np.maximum(syn_pre.astype(np.uint8),
                          syn_post.astype(np.uint8) * 2)
    else:
        # Instance mask. The pre- and post-synaptic masks may not touch each
        # other, so the union mask is dilated first and each connected component
        # of the dilated union defines one synapse.
        footprint = np.ones((1, dilate_sz, dilate_sz), bool)
        synapse_id = label(binary_dilation(union, footprint))

        # syn_pre and syn_post are subsets of the union mask, hence also subsets
        # of the non-zero voxels of synapse_id.
        seg_pre = (synapse_id*2 - 1) * syn_pre.astype(synapse_id.dtype)
        seg_post = (synapse_id*2) * syn_post.astype(synapse_id.dtype)
        segm = np.maximum(seg_pre, seg_post)

        # Report the number of synapses (the "-1" discounts the background id).
        num_pre = len(np.unique(seg_pre)) - 1
        num_post = len(np.unique(seg_post)) - 1
        num_syn = min(num_pre, num_post)  # a conservative estimate
        print(f"Stats: found {num_pre} pre- and {num_post} post-synaptic segments.")
        print(f"There are {num_syn} synapses under a conservative estimate.")

    # resize the segmentation based on specified scale factors
    if any(factor != 1.0 for factor in scale_factors):
        out_shape = tuple(int(segm.shape[axis]*scale_factors[axis]) for axis in range(3))
        segm = resize(segm, out_shape, order=0,
                      anti_aliasing=False, preserve_range=True)

    return cast2dtype(segm)
# utils for post-processing

def binarize_and_median(pred, size=(7,7,7), thres=0.8):
    """Binarize a prediction and denoise it with a median filter.

    Args:
        pred (numpy.ndarray): predicted foreground probability within (0,1).
        size (tuple): kernel size of the median filter. Default: (7,7,7)
        thres (float): threshold for binarizing the prediction. Default: 0.8

    Returns:
        numpy.ndarray: the filtered binary mask as ``np.uint8``.
    """
    pred = (pred > thres).astype(np.uint8)
    pred = ndimage.median_filter(pred, size=size)
    return pred


def remove_small_instances(segm: np.ndarray,
                           thres_small: int = 128,
                           mode: str = 'background'):
    """Remove small spurious instances.

    Args:
        segm (numpy.ndarray): label volume.
        thres_small (int): size threshold below which an object counts as small.
            Default: 128
        mode (str): ``'none'`` returns the input untouched; ``'background'``
            removes small objects with ``remove_small_objects``; ``'neighbor'``
            relabels them with ``merge_small_objects``; the ``'_2d'`` variants
            apply the same operation to every slice along the first axis and
            stack the results. Default: ``'background'``

    Returns:
        numpy.ndarray: the cleaned label volume.
    """
    assert mode in ['none',
                    'background',
                    'background_2d',
                    'neighbor',
                    'neighbor_2d']

    if mode == 'none':
        return segm

    # The function remove_small_objects expects ar to be an array with labeled
    # objects, and removes objects smaller than min_size. If ar is bool, the image
    # is first labeled. This leads to potentially different behavior for bool and
    # 0-and-1 arrays. Reference:
    # https://scikit-image.org/docs/stable/api/skimage.morphology.html#remove-small-objects
    if mode == 'background':
        return remove_small_objects(segm, thres_small)
    elif mode == 'background_2d':
        temp = [remove_small_objects(segm[i], thres_small)
                for i in range(segm.shape[0])]
        return np.stack(temp, axis=0)

    if mode == 'neighbor':
        return merge_small_objects(segm, thres_small, do_3d=True)
    elif mode == 'neighbor_2d':
        temp = [merge_small_objects(segm[i], thres_small)
                for i in range(segm.shape[0])]
        return np.stack(temp, axis=0)


def merge_small_objects(segm, thres_small, do_3d=False):
    """Relabel every small object with the id that dominates its border ring.

    For each label whose voxel count is below :attr:`thres_small`, the object is
    dilated by one step (in-plane only: a ``(1,3,3)`` structure in 3D, ``(3,3)``
    in 2D) and the most frequent non-zero label inside the resulting ring
    replaces the object's id. If the ring holds only zeros the object becomes 0.
    Every id returned by ``np.unique`` is considered, including 0 when its count
    is below the threshold. The input array is modified in place and returned.

    Args:
        segm (numpy.ndarray): label array (2D, or 3D when :attr:`do_3d` is true).
        thres_small (int): size threshold below which an object is merged.
        do_3d (bool): use the 3D structuring element. Default: False
    """
    struct = np.ones((1,3,3)) if do_3d else np.ones((3,3))
    indices, counts = np.unique(segm, return_counts=True)

    for idx, count in zip(indices, counts):
        if count >= thres_small:
            continue
        temp = (segm == idx).astype(np.uint8)
        coord = bbox_ND(temp, relax=2)
        cropped = crop_ND(temp, coord)

        # one-voxel ring around the object
        diff = dilation(cropped, struct) - cropped
        # NOTE(review): whether crop_ND returns a view or a copy is not visible
        # here; copy defensively so zeroing below can never write into ``segm``.
        diff_segm = np.array(crop_ND(segm, coord), copy=True)
        diff_segm[diff == 0] = 0

        u, ct = np.unique(diff_segm, return_counts=True)
        if len(u) > 1 and u[0] == 0:
            u, ct = u[1:], ct[1:]  # ignore background when a neighbor exists

        segm[segm == idx] = u[np.argmax(ct)]

    return segm


def remove_large_instances(segm: np.ndarray,
                           max_size: int = 2000):
    """Remove large instances given a maximum size threshold.

    Works on a copy; every label whose voxel count exceeds :attr:`max_size` is
    set to 0. Labels must be non-negative integers (``np.bincount`` is used).
    """
    out = np.copy(segm)
    component_sizes = np.bincount(segm.ravel())
    too_large = component_sizes > max_size
    too_large_mask = too_large[segm]
    out[too_large_mask] = 0
    return out


def cast2dtype(segm):
    """Cast the segmentation mask to the best dtype to save storage.

    The dtype is chosen by ``getSegType`` from the largest id in :attr:`segm`.
    """
    # the maximum does not need the sorted unique values
    max_id = np.amax(segm)
    m_type = getSegType(int(max_id))
    return segm.astype(m_type)


def stitch_3d(masks, stitch_threshold=0.25):
    r"""Stitch a stack of 2D annotations into 3D annotations using IoU.

    Slices are processed in order; each slice is relabeled against the
    (already relabeled) previous one. :attr:`masks` is modified in place and
    returned.

    Args:
        masks (numpy.ndarray): 3D volume comprised of a 2D annotations stack of
            shape :math:`(Z, Y, X)`.
        stitch_threshold (float): threshold for joining 2D annotations via IoU.
            Default: 0.25
    """
    mmax = masks[0].max()
    empty = 0

    for i in range(len(masks)-1):
        # retrieve all intersecting pairs, discard background
        iou = intersection_over_union(masks[i+1], masks[i])[1:,1:]
        if not iou.size and empty == 0:
            mmax = masks[i+1].max()
        elif not iou.size and not empty == 0:
            # nothing to match against: give every label a brand-new id
            icount = masks[i+1].max()
            istitch = np.arange(mmax+1, mmax + icount+1, 1, int)
            mmax += icount
            istitch = np.append(np.array(0), istitch)
            masks[i+1] = istitch[masks[i+1]]
        else:
            # set all iou value that did not breach the threshold to zero
            iou[iou < stitch_threshold] = 0.0
            # we calculated the IoU for each possible masks pair
            # for each mask only consider the pairing with the greatest IoU
            iou[iou < iou.max(axis=0)] = 0.0
            istitch = iou.argmax(axis=1) + 1
            # labels without any surviving match receive new ids
            ino = np.nonzero(iou.max(axis=1)==0.0)[0]
            istitch[ino] = np.arange(mmax+1, mmax+len(ino)+1, 1, int)
            mmax += len(ino)
            istitch = np.append(np.array(0), istitch)
            masks[i+1] = istitch[masks[i+1]]
            empty = 1

    return masks


# Adapted from the cellpose repository
# (https://github.com/MouseLand/cellpose/blob/master/cellpose/metrics.py).
def intersection_over_union(masks_true, masks_pred):
    """Calculate the intersection over union for all mask pairs.

    Args:
        masks_true (numpy.ndarray): label array where 0=NO masks; 1,2... are
            mask labels.
        masks_pred (numpy.ndarray): label array of the same shape, where 0=NO
            masks; 1,2... are mask labels.

    Return:
        A float array with the IoU score of each label pair, of size
        [masks_true.max()+1, masks_pred.max()+1]. Pairs whose union is empty
        get a score of 0.
    """
    overlap = _label_overlap(masks_true, masks_pred)
    # number of pixels carrying each predicted label
    n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
    # number of pixels carrying each true label
    n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
    # label ids that occur in neither array give 0/0; silence the warning and
    # map the resulting NaN to 0 below
    with np.errstate(divide='ignore', invalid='ignore'):
        iou = overlap / (n_pixels_pred + n_pixels_true - overlap)
    iou[np.isnan(iou)] = 0.0
    return iou


def _label_overlap(x, y):
    """Create a look-up table recording the pixel overlap between two label arrays.

    Args:
        x (numpy.ndarray): label array where 0=NO masks; 1,2... are mask labels.
        y (numpy.ndarray): label array with the same number of elements, where
            0=NO masks; 1,2... are mask labels.

    Returns:
        A ``np.uint`` matrix of size :math:`[x.max()+1, y.max()+1]` whose entry
        ``[a, b]`` is the number of pixels labeled ``a`` in x and ``b`` in y.
    """
    # flatten the label arrays
    x = x.ravel()
    y = y.ravel()
    assert len(x) == len(y), "The label masks must have the same shape"

    # initialize the lookup table
    overlap = np.zeros((1+x.max(), 1+y.max()), dtype=np.uint)

    # Unbuffered scatter-add: repeated (x, y) pairs are all counted, which a
    # plain fancy-index ``+=`` would not do. Replaces a per-pixel Python loop.
    np.add.at(overlap, (x, y), overlap.dtype.type(1))
    return overlap


def remove_masks(vol: np.ndarray, indices: List[int]) -> np.ndarray:
    """Remove objects by indices from a segmentation volume.

    The volume is modified in place (matching ids are set to 0) and returned.
    """
    vol[np.isin(vol, indices)] = 0
    return vol


def add_masks(vol_base: np.ndarray, vol: np.ndarray, indices: List[int]) -> np.ndarray:
    """Add the instances in a new segmentation volume to the original one.

    Each id in :attr:`indices` is copied from :attr:`vol` into
    :attr:`vol_base` under a fresh id above the current maximum of
    :attr:`vol_base`. A new instance can overwrite existing object pixels if
    the corresponding region contains non-background. :attr:`vol_base` is
    modified in place and returned.
    """
    max_idx = vol_base.max()
    for i, idx in enumerate(indices):
        vol_base[vol == idx] = max_idx+i+1
    return vol_base


def merge_masks(vol: np.ndarray, indices: List[List[int]]) -> np.ndarray:
    """Merge two or more masks into a single one.

    Every group in :attr:`indices` is relabeled with the first id of that
    group. Groups are applied in order, empty groups are skipped, and the
    volume is modified in place and returned.
    """
    for merges in indices:
        if len(merges) == 0:
            continue  # nothing to merge; avoids reading an undefined main id
        main_idx = merges[0]
        vol[np.isin(vol, merges)] = main_idx
    return vol


def watershed_split(vol: np.ndarray, index: int, show_id: bool = False,
                    min_distance: int = 5) -> np.ndarray:
    """Apply watershed transform to split a 3D object into two or more parts
    based on the given index.

    The parts receive consecutive new ids above the current maximum of
    :attr:`vol`.

    Args:
        vol (numpy.ndarray): 3D label array.
        index (int): id of the object to split.
        show_id (bool): print the list of newly assigned ids. Default: False
        min_distance (int): minimum distance between the distance-transform
            peaks used as markers. Default: 5

    Returns:
        numpy.ndarray: the result of ``replace_ND`` applied to :attr:`vol`.
    """
    assert vol.ndim == 3  # 3D label array
    max_idx = int(vol.max())

    binary = (vol == index)
    bbox = bbox_ND(binary, relax=1)  # avoid cropped object touching borders
    cropped = crop_ND(binary, bbox, end_included=True)

    # see https://scikit-image.org/docs/stable/auto_examples/segmentation/plot_watershed.html
    distance = ndimage.distance_transform_edt(cropped)
    coords = peak_local_max(distance, min_distance=min_distance, labels=cropped)
    mask = np.zeros(distance.shape, dtype=bool)
    mask[tuple(coords.T)] = True
    markers = label(mask)
    split_objects = watershed(-distance, markers, mask=cropped)

    seg_id = np.unique(split_objects)
    if seg_id[0] == 0:
        seg_id = seg_id[1:]  # ignore background pixels
    new_id = [max_idx + i + 1 for i in range(len(seg_id))]

    # Relabel through a lookup table. Relabeling one id at a time in place can
    # collide: a freshly assigned id may equal a watershed label that has not
    # been processed yet, which silently merges parts back together.
    lut = np.zeros(int(split_objects.max()) + 1, dtype=split_objects.dtype)
    lut[seg_id] = new_id
    split_objects = lut[split_objects]

    if show_id:
        print(new_id)

    vol = replace_ND(vol, split_objects, bbox, end_included=True)
    return vol

© Copyright 2019-2023, PyTorch Connectomics Contributors. Revision 9f6fcf8b.

Built with Sphinx using a theme provided by Read the Docs.