def get_segmentation_data(self, sample):
    """Return the image volume and (optional) label map of ``sample``.

    Both entries are read from ``sample[key][torchio.DATA]``, cast to float
    and moved to ``self.device`` via ``to_var``.

    Args:
        sample: mapping holding at least ``self.image_key_name``; may also
            hold ``self.label_key_name``.

    Returns:
        ``(volumes, targets)`` where ``targets`` is ``None`` when the sample
        carries no label entry.
    """
    def _fetch(key):
        # Same extraction for image and label: DATA field -> float -> device.
        return to_var(sample[key][torchio.DATA].float(), self.device)

    volumes = _fetch(self.image_key_name)
    targets = _fetch(self.label_key_name) if self.label_key_name in sample else None
    return volumes, targets
def get_segmentation_data_and_regress_key(self, sample, regress_key):
    """Return the image volume and targets, optionally extended with an
    extra regression entry.

    The image, the label map (if present) and the ``regress_key`` entry (if
    present) are each read from ``sample[key][torchio.DATA]``, cast to float
    and moved to ``self.device`` via ``to_var``.

    Args:
        sample: mapping holding at least ``self.image_key_name``.
        regress_key: key of an additional entry to append to the targets.

    Returns:
        ``(volumes, targets)``. If both a label map and the regression entry
        exist, they are concatenated along dim 1 (presumably the channel
        axis of batched ``(B, C, ...)`` tensors -- TODO confirm). If only one
        of them exists it is returned alone. If neither exists, ``targets``
        is ``None``.
    """
    volumes = to_var(sample[self.image_key_name][torchio.DATA].float(),
                     self.device)

    targets = None
    if self.label_key_name in sample:
        targets = to_var(sample[self.label_key_name][torchio.DATA].float(),
                         self.device)

    if regress_key in sample:
        targets_to_regress = to_var(sample[regress_key][torchio.DATA].float(),
                                    self.device)
        # Without a label map the regression entry is the only target;
        # torch.cat((None, ...)) would raise.
        if targets is None:
            targets = targets_to_regress
        else:
            targets = torch.cat((targets, targets_to_regress), dim=1)

    return volumes, targets
def make_prediction_on_whole_volume(self, sample, df):
    """Predict a whole volume by patch-based inference and re-aggregation.

    ``sample`` is split into patches by ``torchio.inference.GridSampler``
    (patch size ``self.eval_patch_size`` falling back to
    ``self.patch_size``, overlap ``self.patch_overlap``, reflect padding).
    Patches are run through ``self.model`` in batches of
    ``self.batch_size``, and the outputs are stitched back by a
    ``torchio.inference.GridAggregator``.

    Args:
        sample: subject passed to the grid sampler.
        df: results DataFrame. It is replaced by an empty one when
            ``self.results_dir != self.eval_results_dir``.

    Returns:
        ``(predictions, df)``. ``predictions`` is the aggregated output
        moved to ``self.device``. ``df`` is updated by
        ``self.batch_recorder`` when dense patch evaluation is enabled and
        targets are available.
    """
    grid_sampler = torchio.inference.GridSampler(
        sample,
        self.eval_patch_size or self.patch_size,
        self.patch_overlap,
        padding_mode='reflect',
    )
    patch_loader = DataLoader(grid_sampler, batch_size=self.batch_size)
    aggregator = torchio.inference.GridAggregator(grid_sampler)

    if self.results_dir != self.eval_results_dir:
        df = pd.DataFrame()

    for batch in patch_loader:
        # Tensors on the right device, then the forward pass.
        batch_volumes, batch_targets = self.data_getter(batch)
        batch_locations = batch[torchio.LOCATION]
        batch_predictions = self.model(batch_volumes)
        aggregator.add_batch(batch_predictions, batch_locations)

        if self.dense_patch_eval and batch_targets is not None:
            df, _ = self.batch_recorder(df, batch, batch_predictions,
                                        batch_targets, 0, True,
                                        csv_name='patch_eval')

    whole_volume = aggregator.get_output_tensor()
    return to_var(whole_volume, self.device), df
def apply_post_transforms(self, tensors, sample):
    """Apply ``self.post_transforms`` to every tensor in ``tensors``.

    Each tensor is moved to CPU and wrapped in a ``torchio.Subject`` under
    the key ``'pred'`` as a ``ScalarImage`` with the affine from
    ``self.get_affine(sample)``. The post transforms are applied, and the
    transformed data is collected and stacked.

    Args:
        tensors: iterable of tensors. TorchIO images require 4D tensors, so
            each element is presumably ``(C, W, H, D)`` -- TODO confirm.
        sample: sample from which the affine is obtained.

    Returns:
        ``(tensors, affine)``. If there are no post transforms, or
        ``tensors`` is empty, the input tensors are returned unchanged with
        the original affine. Otherwise the result is the stacked transformed
        tensors on ``self.device`` and the affine reported by the last
        transformed image.
    """
    affine = self.get_affine(sample)
    # ``not`` covers both ``None`` and an empty container (it falls back to
    # ``__len__``). An extra ``len()`` check would raise TypeError for a
    # single transform object that defines no ``__len__``.
    if not self.post_transforms:
        return tensors, affine

    # Transforms apply on TorchIO subjects, and TorchIO images require
    # 4D tensors, so each tensor is wrapped individually.
    transformed_tensors = []
    new_affine = affine
    for tensor in tensors:
        subject = torchio.Subject(pred=torchio.ScalarImage(
            tensor=to_var(tensor, 'cpu'), affine=affine))
        transformed = self.post_transforms(subject)
        transformed_tensors.append(transformed['pred']['data'])
        new_affine = transformed['pred']['affine']

    if not transformed_tensors:
        # Nothing to stack: torch.stack rejects an empty list.
        return tensors, affine

    return to_var(torch.stack(transformed_tensors), self.device), new_affine
def minimum_t_norm(prediction, target, background=False):
    """Element-wise fuzzy intersection of ``prediction`` and ``target``.

    Args:
        prediction: tensor of predicted memberships.
        target: tensor of target memberships, broadcastable to
            ``prediction`` and on the same device.
        background: when ``False`` (default), return the minimum t-norm
            ``min(prediction, target)``. When ``True``, return
            ``max(prediction - target, 0)``, i.e. the part of
            ``prediction`` not covered by ``target``.

    Returns:
        A tensor on the same device as the inputs.

    The computation uses native torch ops instead of a NumPy round-trip.
    This avoids a device -> host -> device copy and keeps the result
    connected to the autograd graph, so gradients can flow through it.
    """
    if background:
        return torch.clamp(prediction - target, min=0)
    return torch.minimum(prediction, target)
def get_regression_data(self, data, target=None, scale_label=None,
                        default_missing_label=0):
    """Build ``(inputs, labels)`` for regression from a batch and its history.

    Args:
        data: batch mapping with ``self.image_key_name`` (a dict holding
            ``'data'``) and, when targets are requested, a ``'history'``
            entry with one list of transform records per batch element. A
            list of such batches (e.g. produced by a ListOf transform /
            custom collate_fn) is processed element by element and the
            results are concatenated along dim 0.
        target: target name or list of names. ``'random_noise'`` reads the
            ``std`` of ``tio`` ``Noise`` records. Any other name is looked up
            in ``record['_metrics'][self.image_key_name]`` of dict records.
        scale_label: factor (or list of per-target factors) applied to each
            extracted label. Defaults to ``[1]``. A single factor applies to
            every target.
        default_missing_label: value (or per-target list) used when no
            history record provides the label. A single value applies to
            every target.

    Returns:
        ``(inputs, labels)`` as float tensors on ``self.device``. ``labels``
        has one column per target.

    Side effects:
        Sets ``self.target_name`` and ``self.scale_label``.
    """
    if isinstance(data, list):
        # Forward every option so each element honours the caller's
        # scale/default settings instead of falling back to the defaults.
        input_list, labels_list = [], []
        for element in data:
            element_inputs, element_labels = self.get_regression_data(
                element, target, scale_label=scale_label,
                default_missing_label=default_missing_label)
            input_list.append(element_inputs)
            labels_list.append(element_labels)
        return torch.cat(input_list), torch.cat(labels_list)

    if scale_label is None:
        # Fresh list per call: the old shared ``[1]`` default was stored on
        # ``self`` and could be mutated through it.
        scale_label = [1]

    inputs = data[self.image_key_name]['data']
    targets = [target] if isinstance(target, str) else target
    if not isinstance(default_missing_label, list):
        default_missing_label = [default_missing_label]
    self.target_name = targets
    self.scale_label = scale_label if isinstance(scale_label, list) else [scale_label]

    # Broadcast a single default / scale value over all requested targets.
    n_targets = len(targets)
    defaults = default_missing_label
    if len(defaults) == 1 and n_targets > 1:
        defaults = defaults * n_targets
    scales = self.scale_label
    if len(scales) == 1 and n_targets > 1:
        scales = scales * n_targets

    # Every label starts at its "missing" default. Matches found in the
    # transform history overwrite it below.
    labels = torch.cat([
        torch.ones(inputs.shape[0], 1) * default_lab for default_lab in defaults
    ], dim=1)

    for target_idx, target_name in enumerate(targets):
        history = data['history']
        for batch_idx, sample_history in enumerate(history):  # one per batch element
            for record in sample_history:  # one per transform that logged history
                if target_name == 'random_noise':
                    if isinstance(record, tio.transforms.augmentation.intensity.random_noise.Noise):
                        labels[batch_idx, target_idx] = record.std[
                            self.image_key_name] * scales[target_idx]
                elif isinstance(record, dict) and '_metrics' in record:
                    dict_metrics = record['_metrics'][self.image_key_name]
                    labels[batch_idx, target_idx] = dict_metrics[
                        target_name] * scales[target_idx]

    inputs = to_var(inputs.float(), self.device)
    labels = to_var(labels.float(), self.device)
    return inputs, labels