Shortcuts

Source code for connectomics.utils.visualizer

from __future__ import print_function, division
from typing import Optional, List, Union, Tuple

import torch
import torchvision.utils as vutils
import numpy as np
from ..data.utils import decode_quantize, dx_to_circ
from connectomics.model.utils import SplitActivation

__all__ = [
    'Visualizer'
]


class Visualizer(object):
    """TensorboardX visualizer for displaying loss, learning rate and
    predictions at training time.

    Args:
        cfg: configuration node. The methods below read
            ``cfg.MODEL.TARGET_OPT``, ``cfg.MODEL.WEIGHT_OPT``,
            ``cfg.DATASET.MATCH_ACT``, ``cfg.DATASET.MEAN`` and
            ``cfg.DATASET.STD``.
        vis_opt: stored on the instance; not read by any method of this class.
        N: maximum number of sections (leading-axis entries) shown per grid.
    """

    def __init__(self, cfg, vis_opt=0, N=16):
        self.cfg = cfg
        self.act = SplitActivation.build_from_cfg(cfg, do_cat=False)
        self.vis_opt = vis_opt
        self.N = N  # default maximum number of sections to show
        self.N_ind = None

        # Colour look-up tables (num_classes x 3), keyed by target option.
        self.semantic_colors = {}
        for topt in self.cfg.MODEL.TARGET_OPT:
            if topt[0] == '9':
                channels = int(topt.split('-')[1])
                colors = [torch.rand(3) for _ in range(channels)]
                colors[0] = torch.zeros(3)  # make background black
                self.semantic_colors[topt] = torch.stack(colors, dim=0)
            if topt[0] == '1' and len(topt) != 1:  # synaptic cleft (exclusive)
                _, exclusive = topt.split('-')
                assert int(exclusive), f"Option {topt} is not expected!"
                self.semantic_colors[topt] = torch.stack([
                    torch.tensor([0.0, 0.0, 0.0]),
                    torch.tensor([1.0, 0.0, 1.0]),
                    torch.tensor([0.0, 1.0, 1.0])], dim=0)

    def visualize(self, volume, label, output, weight, iter_total, writer,
                  suffix: Optional[str] = None,
                  additional_image_groups: Optional[dict] = None):
        """Write one image grid per learning target to ``writer``.

        Args:
            volume: input batch; passed through :meth:`_denormalize`.
            label: sequence with one target per entry of
                ``cfg.MODEL.TARGET_OPT``. It is copied, so the caller's
                sequence is not modified.
            output: raw model prediction; split per target by ``self.act``.
            weight: per-target sequences of weight maps, indexed in parallel
                with ``cfg.MODEL.WEIGHT_OPT``.
            iter_total: global step forwarded to the writer.
            writer: object exposing ``add_image(tag, image, step)``.
            suffix: optional string appended to every tag.
            additional_image_groups: optional ``{name: [tensor, ...]}`` mapping
                forwarded to :meth:`visualize_image_groups`.
        """
        self.visualize_image_groups(writer, iter_total, additional_image_groups)

        volume = self._denormalize(volume)
        # split the prediction into chunks along the channel dimension
        output = self.act(output)
        assert len(output) == len(label)
        # Work on a shallow copy: entries are replaced below and the caller's
        # target list must not be altered as a side effect of visualization.
        label = list(label)

        for idx in range(len(self.cfg.MODEL.TARGET_OPT)):
            topt = self.cfg.MODEL.TARGET_OPT[idx]
            if topt[0] == '9':  # semantic segmentation
                output[idx] = self.get_semantic_map(output[idx], topt)
                label[idx] = self.get_semantic_map(
                    label[idx], topt, argmax=False)

            if topt[0] == '1' and len(topt) != 1:  # synaptic cleft segmentation
                _, exclusive = topt.split('-')
                assert int(exclusive), f"Option {topt} is not expected!"
                output[idx] = self.get_semantic_map(output[idx], topt)
                label[idx] = self.get_semantic_map(
                    label[idx], topt, argmax=False)

            if topt[0] == '5':  # distance transform
                if len(topt) == 1:
                    topt = topt + '-2d-0-0-5.0'  # default
                # The unpacking also enforces that the option has five fields.
                _, _mode, _padding, quant, _z_res = topt.split('-')
                if bool(int(quant)):  # only the quantized version needs decoding
                    output[idx] = decode_quantize(
                        output[idx], mode='max').unsqueeze(1)
                    temp_label = label[idx].clone().float()[:, np.newaxis]
                    label[idx] = self._normalize_label(temp_label)

            if topt[0] == '7':  # diffusion gradient
                output[idx] = dx_to_circ(output[idx])
                label[idx] = dx_to_circ(label[idx])

            RGB = (topt[0] in ['1', '2', '7', '9'])
            # The tag uses the configured option, not the default-expanded one.
            vis_name = self.cfg.MODEL.TARGET_OPT[idx] + '_' + str(idx)
            if suffix is not None:
                vis_name = vis_name + '_' + suffix

            if isinstance(label[idx], (np.ndarray, np.generic)):
                label[idx] = torch.from_numpy(label[idx])

            weight_maps = {}
            for j, wopt in enumerate(self.cfg.MODEL.WEIGHT_OPT[idx]):
                if wopt != '0':
                    w_name = vis_name + '_' + wopt
                    weight_maps[w_name] = weight[idx][j]
                elif weight[idx][j].shape[-1] != 1:
                    # The weight map can be the binary valid mask.
                    weight_maps['valid_mask'] = weight[idx][j]

            self.visualize_consecutive(
                volume, label[idx], output[idx], weight_maps, iter_total,
                writer, RGB=RGB, vis_name=vis_name)

    def visualize_image_groups(self, writer, iteration,
                               image_groups: Optional[dict] = None,
                               is_3d: bool = True) -> None:
        """Write one grid per named group of images.

        Args:
            writer: object exposing ``add_image(tag, image, step)``.
            iteration: global step forwarded to the writer.
            image_groups: ``{name: [tensor, ...]}``; ``None`` is a no-op and
                groups with an empty list are skipped.
            is_3d: forwarded to :meth:`permute_truncate`.
        """
        if image_groups is None:
            return

        for name, image_list in image_groups.items():
            if not image_list:
                continue  # nothing to draw; avoids indexing an empty list
            image_list = [self._denormalize(x) for x in image_list]
            image_list = [self.permute_truncate(x, is_3d=is_3d)
                          for x in image_list]
            sz = image_list[0].size()
            canvas = [x.detach().cpu().expand(sz[0], 3, sz[2], sz[3])
                      for x in image_list]
            canvas_merge = torch.cat(canvas, 0)
            canvas_show = vutils.make_grid(
                canvas_merge, nrow=8, normalize=True, scale_each=True)

            writer.add_image('Image_Group_%s' % name, canvas_show, iteration)

    def visualize_consecutive(self, volume, label, output, weight_maps,
                              iteration, writer, RGB=False, vis_name='0_0'):
        """Stack input, prediction, label and weight maps into one grid.

        Args:
            volume, label, output: tensors for a single target.
            weight_maps: ``{name: tensor}`` of weight maps to append.
            iteration: global step forwarded to the writer.
            writer: object exposing ``add_image(tag, image, step)``.
            RGB: if true, ``output``/``label`` are shown as colour images
                (two-channel inputs are zero-padded to three channels);
                otherwise each channel index in ``range(volume channels)`` is
                shown as its own gray-scale image.
            vis_name: suffix of the ``Consecutive_<vis_name>`` tag.
        """
        volume, label, output, weight_maps = self.prepare_data(
            volume, label, output, weight_maps)
        sz = volume.size()  # z,c,y,x
        canvas = [volume.detach().cpu().expand(sz[0], 3, sz[2], sz[3])]

        if RGB:
            output_visual = [self._maybe2rgb(output.detach().cpu())]
            label_visual = [self._maybe2rgb(label.detach().cpu())]
        else:
            output_visual = [self.vol_reshape(output[:, i], sz)
                             for i in range(sz[1])]
            label_visual = [self.vol_reshape(label[:, i], sz)
                            for i in range(sz[1])]

        # Move to CPU *before* padding so the padded channel and the weight
        # map always live on the same device.
        weight_visual = [
            self._maybe2rgb(w.detach().cpu()).expand(sz[0], 3, sz[2], sz[3])
            for w in weight_maps.values()]

        canvas = canvas + output_visual + label_visual + weight_visual
        canvas_merge = torch.cat(canvas, 0)
        canvas_show = vutils.make_grid(
            canvas_merge, nrow=8, normalize=True, scale_each=True)

        writer.add_image('Consecutive_%s' % vis_name, canvas_show, iteration)

    def prepare_data(self, volume, label, output, weight_maps):
        """Apply :meth:`permute_truncate` to every input.

        ``volume`` must be 4D or 5D; 5D is treated as a 3D volume batch.
        Returns ``(volume, label, output, weight_maps)`` where the returned
        ``weight_maps`` is a new dict (the argument is left untouched).
        """
        ndim = volume.ndim
        assert ndim in [4, 5]
        is_3d = (ndim == 5)

        volume = self.permute_truncate(volume, is_3d)
        label = self.permute_truncate(label, is_3d)
        output = self.permute_truncate(output, is_3d)
        weight_maps = {key: self.permute_truncate(value, is_3d)
                       for key, value in weight_maps.items()}
        return volume, label, output, weight_maps

    def permute_truncate(self, data, is_3d=False):
        """Keep at most ``self.N`` entries along the leading axis.

        For ``is_3d`` the first sample of the batch is taken and its first two
        axes are swapped, so that consecutive slices become the leading axis
        (presumably (C, Z, Y, X) -> (Z, C, Y, X); confirm against the caller).
        """
        if is_3d:
            # show consecutive slices of a 3d volume
            data = data[0].permute(1, 0, 2, 3)
        high = min(data.size()[0], self.N)
        return data[:high]

    def get_semantic_map(self, output, topt, argmax=True):
        """Convert class scores or class indices into a colour image.

        Args:
            output: tensor or NumPy array. With ``argmax=True`` it holds
                per-class scores along dim 1; otherwise it already holds
                class indices.
            topt: key into ``self.semantic_colors``.
            argmax: whether to reduce dim 1 with ``torch.argmax`` first.

        Returns:
            Tensor with the colour axis moved to dim 1 (3D and 2D inputs).
        """
        if isinstance(output, (np.ndarray, np.generic)):
            output = torch.from_numpy(output)
        # output shape: BCDHW or BCHW
        if argmax:
            output = torch.argmax(output, 1)
        # Cast to int64 before indexing: a uint8 tensor would be interpreted
        # as a boolean mask rather than as class indices.
        index = output.cpu().long()
        pred = self.semantic_colors[topt][index]
        if len(pred.size()) == 4:  # 2D Inputs
            pred = pred.permute(0, 3, 1, 2)
        elif len(pred.size()) == 5:  # 3D Inputs
            pred = pred.permute(0, 4, 1, 2, 3)

        return pred

    def vol_reshape(self, vol, sz):
        """Insert a channel axis at dim 1 and expand it to three channels."""
        vol = vol.detach().cpu().unsqueeze(1)
        return vol.expand(sz[0], 3, sz[2], sz[3])

    def _denormalize(self, volume):
        """Undo the input normalisation selected by ``DATASET.MATCH_ACT``.

        ``'none'`` applies ``volume * STD + MEAN``; ``'tanh'`` maps [-1, 1] to
        [0, 1]; any other value returns ``volume`` unchanged.
        """
        match_act = self.cfg.DATASET.MATCH_ACT
        if match_act == 'none':
            volume = (volume * self.cfg.DATASET.STD) + self.cfg.DATASET.MEAN
        elif match_act == 'tanh':
            volume = (volume + 1.0) * 0.5
        return volume

    @staticmethod
    def _normalize_label(label, eps=1e-6):
        """Scale ``label`` by its maximum; ``eps`` guards an all-zero label."""
        return label / (label.max() + eps)

    @staticmethod
    def _maybe2rgb(temp):
        """Zero-pad a two-channel map (e.g. 2d affinity) to three channels."""
        if temp.shape[1] == 2:  # 2d affinity map has two channels
            pad_shape = (temp.shape[0], 1) + tuple(temp.shape[2:])
            # new_zeros keeps the dtype and device of ``temp``.
            temp = torch.cat([temp, temp.new_zeros(pad_shape)], dim=1)
        return temp

© Copyright 2019-2023, PyTorch Connectomics Contributors. Revision 9f6fcf8b.

Built with Sphinx using a theme provided by Read the Docs.