def __call__(self, input, target):
    if isinstance(input, torch.Tensor):
        assert input.dim() == 5
        # convert to numpy array
        input = input[0].detach().cpu().numpy()  # 4D
    if isinstance(target, torch.Tensor):
        if not self.use_last_target:
            assert target.dim() == 4
            # convert to numpy array
            target = target[0].detach().cpu().numpy()  # 3D
        else:
            # if use_last_target == True the target must be 5D (NxCxDxHxW)
            assert target.dim() == 5
            target = target[0, -1].detach().cpu().numpy()  # 3D

    if isinstance(input, np.ndarray):
        assert input.ndim == 4
    if isinstance(target, np.ndarray):
        assert target.ndim == 3

    if self.use_first_input:
        # compute only on the first input channel
        n_channels = 1
    else:
        n_channels = input.shape[0]

    per_channel_arand = []
    for c in range(n_channels):
        predictions = input[c]
        # threshold probability maps
        predictions = predictions > self.threshold

        if self.invert_pmaps:
            # for connected component analysis we need to treat boundary signal as background
            # assign 0-label to boundary mask
            predictions = np.logical_not(predictions)

        predictions = predictions.astype(np.uint8)
        # run connected components on the predicted mask; consider only 1-connectivity
        predicted = measure.label(predictions, background=0, connectivity=1)
        # make sure that target is 'int' type as well
        target = target.astype(np.int64)
        # compute AdaptedRand error
        arand = adapted_rand(predicted, target)
        per_channel_arand.append(arand)

    # get minimum AdaptedRand error across channels
    min_arand, c_index = np.min(per_channel_arand), np.argmin(per_channel_arand)
    LOGGER.info(f'Min AdaptedRand error: {min_arand}, channel: {c_index}')
    return min_arand
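# Illustrative sketch (not part of the original module) of the
# threshold -> invert -> connected-components steps performed above.
# The array shape, the 0.4 threshold and the variable names are made up
# for the example; only the `measure.label(..., background=0, connectivity=1)`
# call mirrors the one used in `__call__`.
import numpy as np
from skimage import measure

pmap = np.random.rand(8, 64, 64)        # fake 3D boundary probability map
binary = pmap > 0.4                     # threshold the probability map
foreground = np.logical_not(binary)     # boundary signal becomes background (0)
segmentation = measure.label(foreground.astype(np.uint8), background=0, connectivity=1)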
def __call__(self, input, target):
    """
    Compute ARand Error for each input, target pair in the batch and return the mean value.

    Args:
        input (torch.tensor): 5D (NCDHW) output from the network
        target (torch.tensor): 4D (NDHW) ground truth segmentation

    Returns:
        average ARand Error across the batch
    """
    # convert input and target to numpy arrays
    input, target = self._convert_to_numpy(input, target)
    # ensure target is of integer type
    target = target.astype(np.int64)

    per_batch_arand = []
    for _batch_inst, (_input, _target) in enumerate(zip(input, target)):
        LOGGER.info(f'Number of ground truth clusters: {len(np.unique(_target))}')

        # convert _input to segmentation
        segm = self.input_to_segm(_input)

        # run connected components on the target if necessary
        if self.run_target_cc:
            _target = measure.label(_target, connectivity=1)

        if self.save_plots:
            # save predicted and ground truth segmentation
            plot_segm(segm, _target, self.plots_dir)

        assert segm.ndim == 4

        # compute per-channel arand and keep the minimum value
        per_channel_arand = []
        for channel_segm in segm:
            per_channel_arand.append(adapted_rand(channel_segm, _target))

        # get the min arand across channels
        min_arand, c_index = np.min(per_channel_arand), np.argmin(per_channel_arand)
        LOGGER.info(f'Batch: {_batch_inst}. Min AdaptedRand error: {min_arand}, channel: {c_index}')
        per_batch_arand.append(min_arand)

    # return mean arand error across the batch
    return torch.mean(torch.tensor(per_batch_arand))
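# `_convert_to_numpy` is defined elsewhere in the class. Based on the docstring
# above (5D NCDHW input, 4D NDHW target), a plausible sketch of that helper
# could look like the following; the actual implementation may differ:
import torch

def _convert_to_numpy(self, input, target):
    if isinstance(input, torch.Tensor):
        assert input.dim() == 5  # NCDHW
        input = input.detach().cpu().numpy()
    if isinstance(target, torch.Tensor):
        assert target.dim() == 4  # NDHW
        target = target.detach().cpu().numpy()
    return input, target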
def test_embeddings_predictor(self, tmpdir):
    config = {'model': {'output_heads': 1}, 'device': torch.device('cpu')}

    slice_builder_config = {
        'name': 'SliceBuilder',
        'patch_shape': (100, 200, 200),
        'stride_shape': (60, 150, 150)
    }

    transformer_config = {
        'raw': [
            {'name': 'ToTensor', 'expand_dims': False, 'dtype': 'long'}
        ]
    }

    gt_file = 'resources/sample_cells.h5'
    output_file = os.path.join(tmpdir, 'output_segmentation.h5')

    dataset = HDF5Dataset(gt_file, phase='test',
                          slice_builder_config=slice_builder_config,
                          transformer_config=transformer_config,
                          raw_internal_path='label')

    loader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=False,
                        collate_fn=prediction_collate)

    predictor = FakePredictor(FakeModel(), loader, output_file, config,
                              clustering='meanshift', bandwidth=0.5)
    predictor.predict()

    # compare the predicted segmentation with the ground truth labels
    with h5py.File(gt_file, 'r') as f, h5py.File(output_file, 'r') as g:
        gt = f['label'][...]
        segm = g['segmentation/meanshift'][...]
        arand_error = adapted_rand(segm, gt)
        assert arand_error < 0.1
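# `FakeModel` and `FakePredictor` are test doubles defined elsewhere in the
# test suite. A minimal stand-in for the model could look like this sketch
# (purely illustrative; the real FakeModel may behave differently):
import torch.nn as nn

class IdentityModel(nn.Module):
    """Passes the input through unchanged, standing in for a trained network."""
    def forward(self, x):
        return x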
def __call__(self, input, target):
    return adapted_rand(input, target, all_stats=self.all_stats)
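# Quick sanity check (illustrative only; assumes `adapted_rand` is importable
# from this module): identical label volumes should yield an adapted Rand
# error close to 0, since precision and recall over label pairs are both 1.
import numpy as np

seg = np.random.randint(1, 5, size=(4, 16, 16))
assert adapted_rand(seg, seg) < 1e-6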