def prep_penalties(self):
    # initialize without considering background
    dice_weights_dict = {}  # average for "weighted averaging"
    dice_penalty_dict = {}  # penalty for misclassification
    for i in range(1, self.n_classes):
        dice_weights_dict[i] = 0
        dice_penalty_dict[i] = 0

    penalty_loader = self.data.get_penalty_loader()

    # get the weights for use in the dice loss
    total_nonZeroVoxels = 0

    # the dice penalty is calculated from the masks (processed here) and predicted labels
    # iterate through the full data (it may differ from the training data, e.g. by not being cropped)
    for subject in penalty_loader:
        # accumulate dice weights for each label
        mask = subject['label'][torchio.DATA]
        one_hot_mask = one_hot(mask, self.data.class_list)
        for i in range(1, self.n_classes):
            currentNumber = torch.nonzero(one_hot_mask[:, i, :, :, :], as_tuple=False).size(0)
            dice_weights_dict[i] += currentNumber  # class-specific non-zero voxels
            total_nonZeroVoxels += currentNumber   # total number of non-zero voxels to be considered

    if total_nonZeroVoxels == 0:
        raise RuntimeError(
            'Trying to train on data where every label mask is background class only.')

    # divide each class count by the total non-zero count
    dice_weights_dict = {k: v / total_nonZeroVoxels for k, v in dice_weights_dict.items()}
    # subtract from 1 for the penalty
    dice_penalty_dict = {k: 1 - v for k, v in dice_weights_dict.items()}
    # normalize penalties to sum to 1
    total = sum(dice_penalty_dict.values())
    dice_penalty_dict = {k: v / total for k, v in dice_penalty_dict.items()}
    # NOTE: get_class_imbalance_weights(trainingDataFromPickle, parameters, headers,
    #       is_regression, class_list) could not be used here because ImagesFromDataFrame
    #       gets imported twice, causing a "'module' object is not callable" error

    return dice_weights_dict, dice_penalty_dict
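# Illustrative-only sketch of the weight/penalty arithmetic in prep_penalties,
# using made-up voxel counts for two foreground classes (names and numbers here
# are hypothetical, not part of the codebase): weights are class frequencies
# among non-zero voxels, and penalties invert and renormalize them so that
# rarer classes are weighted more heavily.
voxel_counts = {1: 900, 2: 100}  # hypothetical per-class non-zero voxel counts
total_voxels = sum(voxel_counts.values())
weights = {k: v / total_voxels for k, v in voxel_counts.items()}  # {1: 0.9, 2: 0.1}
penalties = {k: 1 - v for k, v in weights.items()}                # {1: 0.1, 2: 0.9}
norm = sum(penalties.values())
penalties = {k: v / norm for k, v in penalties.items()}           # sums to 1.0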
def CCE_Generic(out, target, params, CCE_Type):
    """
    Generic function to calculate CCE loss

    Args:
        out (torch.Tensor): The predicted output value for each pixel. dimension: [batch, class, x, y, z].
        target (torch.Tensor): The ground truth label for each pixel; one-hot encoded internally against params["model"]["class_list"].
        params (dict): The parameter dictionary.
        CCE_Type (torch.nn.Module): The CE loss function type.

    Returns:
        torch.Tensor: The final loss value after taking multiple classes into consideration
    """
    acc_ce_loss = 0
    target = one_hot(target, params["model"]["class_list"]).type(out.dtype)
    for i in range(0, len(params["model"]["class_list"])):
        curr_ce_loss = CCE_Type(out[:, i, ...], target[:, i, ...])
        if params["weights"] is not None:
            curr_ce_loss = curr_ce_loss * params["weights"][i]
        acc_ce_loss += curr_ce_loss
    # without explicit class weights, average over the classes instead of summing
    if params["weights"] is None:
        acc_ce_loss /= len(params["model"]["class_list"])
    return acc_ce_loss
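# Hedged usage sketch for CCE_Generic: the params dict below is minimal and
# hypothetical (only "model.class_list" and "weights" are keys the function
# reads), and it assumes the one_hot helper above accepts a [batch, 1, x, y, z]
# label mask. BCEWithLogitsLoss is a natural per-channel CCE_Type here, since
# each class channel is compared against a binary one-hot map.
import torch

params = {
    "model": {"class_list": [0, 1, 2, 4]},  # e.g. BraTS-style labels
    "weights": None,                        # or a per-class weight list
}
out = torch.randn(1, 4, 8, 8, 8)               # [batch, class, x, y, z] logits
target = torch.randint(0, 2, (1, 1, 8, 8, 8))  # raw label mask, one-hot encoded inside
loss = CCE_Generic(out, target, params, torch.nn.BCEWithLogitsLoss())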
def main(data_path, plan_path, model_weights_path, output_pardir, model_output_tag, device, legacy_model_flag=False):
    # TODO: We do not currently make use of the ability for brainmage to infer by first
    #       cropping external zero planes, or inference by patching and fusing.
    flplan = parse_fl_plan(plan_path)

    # make sure the class list we are using is compatible with the hard-coded class_label_map above
    if flplan['data_object_init']['init_kwargs']['class_list'] != class_list:
        raise ValueError('We currently only support class_list=', class_list)

    # construct the data object
    data = create_data_object_with_explicit_data_path(flplan=flplan, data_path=data_path)

    # code is written with the assumption that we are using the gandlf data object
    if not issubclass(data.__class__, GANDLFData):
        raise ValueError(
            'This script currently assumes a child of fets.data.pytorch.gandlf_data.GANDLFData; you are using: ',
            data.__class__.__name__)

    # construct the model object (requires cpu since we're passing [padded] whole brains)
    model = create_model_object(flplan=flplan, data_object=data, model_device=device)

    # code is written with the assumption that we are using the brainmage object
    if not issubclass(model.__class__, BrainMaGeModel):
        raise ValueError(
            'This script currently assumes a child of fets.models.pytorch.brainmage.BrainMaGeModel; you are using: ',
            model.__class__.__name__)

    # legacy models are defined in a single file; newer ones have a folder that holds per-layer files
    if legacy_model_flag:
        tensor_dict_from_proto = load_legacy_model_protobuf(model_weights_path)
    else:
        tensor_dict_from_proto = load_model(model_weights_path)

    # restore any tensors held out from the proto
    _, holdout_tensors = split_tensor_dict_for_holdouts(None, model.get_tensor_dict())
    tensor_dict = {**tensor_dict_from_proto, **holdout_tensors}
    model.set_tensor_dict(tensor_dict, with_opt_vars=False)

    print("\nWill be running inference on {} validation samples.\n".format(
        model.get_validation_data_size()))

    if not os.path.exists(output_pardir):
        os.mkdir(output_pardir)

    subdir_to_DICE = {}
    dice_outpath = None
    for subject in data.get_val_loader():
        # the first mode is the only one that is always defined
        first_mode_path = subject['1']['path'][0]
        subfolder = first_mode_path.split('/')[-2]

        # prep the path for the output files
        output_subdir = os.path.join(output_pardir, subfolder)
        if not os.path.exists(output_subdir):
            os.mkdir(output_subdir)
        inference_outpath = os.path.join(output_subdir, subfolder + model_output_tag + '_seg.nii.gz')
        if dice_outpath is None:
            dice_outpath = os.path.join(output_pardir, model_output_tag + '_subdirs_to_DICE.pkl')

        if not is_mask_present(subject):
            raise ValueError('We are expecting to run this on subjects that have labels.')

        label_path = subject['label']['path'][0]
        label_file = label_path.split('/')[-1]
        subdir_name = label_path.split('/')[-2]

        # copy the label file over to the output subdir
        copy_label_path = os.path.join(output_subdir, label_file)
        shutil.copyfile(label_path, copy_label_path)

        features, ground_truth = subject_to_feature_and_label(subject=subject, pad_z=pad_z)

        output = infer(model, features)

        # FIXME: Find a better solution
        # crop away the padding we put in
        output = output[:, :, :, :, :155]

        # sanity check: the one-hot ground truth and model output shapes should match
        print(one_hot(segmask_array=ground_truth, class_list=class_list).shape,
              output.shape)

        # get the DICE score
        dice_dict = clinical_dice(output=output,
                                  target=one_hot(segmask_array=ground_truth,
                                                 class_list=class_list),
                                  class_list=class_list,
                                  to_scalar=True)
        subdir_to_DICE[subdir_name] = dice_dict

        output = np.squeeze(output.cpu().numpy())

        # the GANDLFData loader produces output transposed relative to what sitk reads
        # from file, so transpose here
        output = np.transpose(output, [0, 3, 2, 1])

        # process float outputs (across output channels), providing labels as defined
        # in the values of class_label_map
        output = new_labels_from_float_output(array=output,
                                              class_label_map=class_label_map,
                                              binary_classification=False)

        # convert array to SimpleITK image
        image = sitk.GetImageFromArray(output)
        image.CopyInformation(sitk.ReadImage(first_mode_path))

        print("\nWriting inference NIfTI image of shape {} to {}".format(
            output.shape, inference_outpath))
        sitk.WriteImage(image, inference_outpath)
        print("\nCorresponding DICE scores were: ")
        print("{}\n\n".format(dice_dict))

        # save (and re-save) the running DICE results so partial progress persists
        print("Saving subdir_name_to_DICE at: ", dice_outpath)
        with open(dice_outpath, 'wb') as _file:
            pkl.dump(subdir_to_DICE, _file)
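# new_labels_from_float_output is defined elsewhere; the sketch below only
# illustrates the kind of channel-argmax relabeling such a helper typically
# performs (the function name, and the assumption that class_label_map maps
# channel index to output label value, are hypothetical).
import numpy as np

def argmax_relabel_sketch(array, class_label_map):
    # array: [class, x, y, z] float outputs, one channel per class
    channel_winners = np.argmax(array, axis=0)
    relabeled = np.zeros_like(channel_winners)
    for channel, label in class_label_map.items():
        relabeled[channel_winners == channel] = label
    return relabeled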
def get_loss_and_metrics(image, ground_truth, predicted, params):
    """
    Parameters
    ----------
    image : torch.Tensor
        The input image stack according to requirements
    ground_truth : torch.Tensor
        The input ground truth for the corresponding image label
    predicted : torch.Tensor
        The input predicted label for the corresponding image label
    params : dict
        The parameters passed by the user yaml

    Returns
    -------
    loss : torch.Tensor
        The computed loss from the label and the output
    metric_output : dict
        The computed metrics from the label and the output
    """
    # this is currently only happening for mse_torch
    if isinstance(params["loss_function"], dict):
        # check for mse_torch
        loss_function = global_losses_dict["mse"]
    else:
        loss_str_lower = params["loss_function"].lower()
        if loss_str_lower in global_losses_dict:
            loss_function = global_losses_dict[loss_str_lower]
        else:
            sys.exit(
                "WARNING: Could not find the requested loss function '"
                + params["loss_function"]
                + "'"
            )

    loss = 0
    # specialized loss function for sdnet
    sdnet_check = (len(predicted) > 1) and (params["model"]["architecture"] == "sdnet")

    if params["problem_type"] == "segmentation":
        ground_truth = one_hot(ground_truth, params["model"]["class_list"])

    deep_supervision_model = False
    if (
        (len(predicted) > 1)
        and not sdnet_check
        and ("deep" in params["model"]["architecture"])
    ):
        deep_supervision_model = True
        # this case is for models that have deep-supervision - currently only used for
        # segmentation models; these weights are taken from a previous publication
        # (https://arxiv.org/pdf/2103.03759.pdf)
        loss_weights = [0.5, 0.25, 0.175, 0.075]

        assert len(predicted) == len(
            loss_weights
        ), "Loss weights must be same length as number of outputs."

        ground_truth_resampled = []
        ground_truth_prev = ground_truth.detach()
        for i, _ in enumerate(predicted):
            if ground_truth_prev[0].shape != predicted[i][0].shape:
                # get the expected shape of the resampled ground truth
                expected_shape = reverse_one_hot(
                    predicted[i][0].detach(), params["model"]["class_list"]
                ).shape

                # linear interpolation is needed because we want "soft" images for the
                # resampled ground truth
                ground_truth_prev = nnf.interpolate(
                    ground_truth_prev,
                    size=expected_shape,
                    mode=get_linear_interpolation_mode(len(expected_shape)),
                    align_corners=False,
                )
            ground_truth_resampled.append(ground_truth_prev)

    if sdnet_check:
        # this is specific for sdnet-style archs
        loss_seg = loss_function(predicted[0], ground_truth.squeeze(-1), params)
        loss_reco = global_losses_dict["l1"](predicted[1], image[:, :1, ...], None)
        loss_kld = global_losses_dict["kld"](predicted[2], predicted[3])
        loss_cycle = global_losses_dict["mse"](predicted[2], predicted[4], None)
        loss = 0.01 * loss_kld + loss_reco + 10 * loss_seg + loss_cycle
    else:
        if deep_supervision_model:
            # this is for models that have deep-supervision
            for i, _ in enumerate(predicted):
                # loss is calculated based on resampled "soft" labels using the
                # pre-defined weights array
                loss += (
                    loss_function(predicted[i], ground_truth_resampled[i], params)
                    * loss_weights[i]
                )
        else:
            loss = loss_function(predicted, ground_truth, params)

    metric_output = {}
    # Metrics should be a list
    for metric in params["metrics"]:
        metric_lower = metric.lower()
        metric_output[metric] = 0
        if metric_lower in global_metrics_dict:
            metric_function = global_metrics_dict[metric_lower]
            if sdnet_check:
                metric_output[metric] = get_metric_output(
                    metric_function, predicted[0], ground_truth.squeeze(-1), params
                )
            else:
                if deep_supervision_model:
                    for i, _ in enumerate(predicted):
                        metric_output[metric] += get_metric_output(
                            metric_function,
                            predicted[i],
                            ground_truth_resampled[i],
                            params,
                        )
                else:
                    metric_output[metric] = get_metric_output(
                        metric_function, predicted, ground_truth, params
                    )
        else:
            print(
                "WARNING: Could not find the requested metric '" + metric + "'",
                file=sys.stderr,
            )
    return loss, metric_output
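# Toy illustration of the deep-supervision weighting above: four decoder
# outputs, each compared against a ground truth resampled to its scale, then
# combined with the fixed weights from the function. The MSE below is a
# stand-in loss and the tensors are random; only the weighting pattern is real.
import torch

loss_weights = [0.5, 0.25, 0.175, 0.075]
predicted = [torch.rand(1, 4, s, s, s) for s in (64, 32, 16, 8)]
resampled_gt = [torch.rand_like(p) for p in predicted]
loss = sum(
    torch.nn.functional.mse_loss(p, g) * w
    for p, g, w in zip(predicted, resampled_gt, loss_weights)
)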
def validate(self, use_tqdm=False):
    total_dice = 0

    val_loader = self.data.get_val_loader()
    if val_loader == []:
        raise RuntimeError("Attempting to run validation with an empty val loader.")

    if use_tqdm:
        val_loader = tqdm.tqdm(val_loader, desc="validate")

    for subject in val_loader:
        # this is when we are using pt_brainmagedata
        if ('features' in subject.keys()) and ('gt' in subject.keys()):
            features = subject['features']
            mask = subject['gt']
            output = self.infer_batch_with_no_numpy_conversion(features=features)
        # using the gandlf loader
        else:
            features = torch.cat(
                [subject[key][torchio.DATA] for key in self.channel_keys], dim=1)
            mask = subject['label'][torchio.DATA]
            if self.infer_gandlf_images_with_cropping:
                output = self.data.infer_with_crop(
                    model_inference_function=[self.infer_batch_with_no_numpy_conversion],
                    features=features)
            else:
                output = self.data.infer_with_crop_and_patches(
                    model_inference_function=[self.infer_batch_with_no_numpy_conversion],
                    features=features)

        # one-hot encoding of ground truth
        mask = one_hot(mask, self.data.class_list)

        # sanity check that the output and mask have the same shape
        if output.shape != mask.shape:
            raise ValueError('Model output and ground truth mask are not the same shape.')

        curr_dice = clinical_dice(
            output.float(), mask.float(),
            class_list=self.data.class_list).cpu().data.item()
        total_dice += curr_dice

    # compute the average dice over the validation set
    average_dice = total_dice / len(val_loader)
    return average_dice
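# clinical_dice is defined elsewhere; for reference, a minimal sketch of the
# standard soft Dice coefficient it is presumably built on (names and the
# smoothing constant are hypothetical):
import torch

def soft_dice_sketch(output, target, smooth=1e-7):
    # output and target are one-hot tensors of identical shape
    intersection = (output * target).sum()
    return (2.0 * intersection + smooth) / (output.sum() + target.sum() + smooth)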
def train_batches(self, num_batches, use_tqdm=False):
    num_subjects = num_batches
    device = torch.device(self.device)

    ################################ PRINTING SOME STUFF ######################
    print("\nHostname :" + str(os.getenv("HOSTNAME")))
    sys.stdout.flush()

    print("Training Data Samples: ", len(self.data.train_loader.dataset))
    sys.stdout.flush()

    print('Using device:', device)
    if device.type == 'cuda':
        print('Memory Usage:')
        print('Allocated:', round(torch.cuda.memory_allocated(0) / 1024**3, 1), 'GB')
        # memory_reserved supersedes the deprecated memory_cached
        print('Cached:   ', round(torch.cuda.memory_reserved(0) / 1024**3, 1), 'GB')
    sys.stdout.flush()

    train_loader = self.data.get_train_loader()
    if train_loader == []:
        raise RuntimeError("Attempting to run training with an empty training loader.")

    if use_tqdm:
        train_loader = tqdm.tqdm(train_loader, desc="training for this round")

    total_loss = 0
    subject_num = 0
    num_nan_losses = 0

    # set to "training" mode
    self.train()
    while subject_num < num_subjects:
        for subject in train_loader:
            if subject_num >= num_subjects:
                break

            if device.type == 'cuda':
                print('=== Memory (allocated; cached) : ',
                      round(torch.cuda.memory_allocated(0) / 1024**3, 1), '; ',
                      round(torch.cuda.memory_reserved(0) / 1024**3, 1))

            # Load the subject and its ground truth
            # this is when we are using pt_brainmagedata
            if ('features' in subject.keys()) and ('gt' in subject.keys()):
                features = subject['features']
                mask = subject['gt']
            # this is when we are using the gandlf loader
            else:
                features = torch.cat(
                    [subject[key][torchio.DATA] for key in self.channel_keys], dim=1)
                mask = subject['label'][torchio.DATA]

            print("\n\nTrain features with shape: {}\n".format(features.shape))

            mask = one_hot(mask, self.data.class_list)

            # Load features and mask onto the device
            features, mask = features.float().to(device), mask.float().to(device)

            # TODO: Variable class is deprecated - parameters to be given are the tensor,
            #       whether it requires grad, and the function that created it
            # features, mask = Variable(features, requires_grad=True), Variable(mask, requires_grad=True)

            # Make sure the optimizer's gradients have been reset
            self.optimizer.zero_grad()

            # Forward propagation to get the output from the model
            # TODO: Not recommended? (https://discuss.pytorch.org/t/about-torch-cuda-empty-cache/34232/6)
            #       will try without
            # torch.cuda.empty_cache()
            output = self(features.float())

            # Compute the loss
            loss = self.loss_fn(output.float(),
                                mask.float(),
                                num_classes=self.label_channels,
                                weights=self.dice_penalty_dict,
                                class_list=self.data.class_list)

            # Back-propagate so the model learns (unless the loss is nan)
            if torch.isnan(loss):
                num_nan_losses += 1
            else:
                loss.backward()
                # update the weight values
                self.optimizer.step()
                # pull the loss to the cpu and accumulate only its value; nan losses
                # are excluded so the returned average stays finite
                total_loss += loss.cpu().data.item()

            self.lr_scheduler.step()
            # TODO: Not recommended? (https://discuss.pytorch.org/t/about-torch-cuda-empty-cache/34232/6)
            #       will try without
            # torch.cuda.empty_cache()
            subject_num += 1

    num_subject_grads = num_subjects - num_nan_losses

    # we return the average batch loss over all subjects trained this round (excluding
    # the nan results), the number of samples that produced nan losses, and the total
    # samples used
    # FIXME: In a federation we may want the collaborator's data size to be modified
    #        when backprop is skipped.
    return {
        "loss": total_loss / num_subject_grads,
        "num_nan_losses": num_nan_losses,
        "num_samples_used": num_subjects
    }
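# The nan-skip accounting above, shown in isolation with toy values: losses
# that come back nan are counted but excluded from both backprop and the
# reported average, so the average stays finite.
import math

losses = [0.7, float('nan'), 0.5, 0.4]  # hypothetical per-subject losses
num_nan = sum(1 for l in losses if math.isnan(l))
total = sum(l for l in losses if not math.isnan(l))
average = total / (len(losses) - num_nan)  # 0.533..., not nan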
def validate(self, use_tqdm=False):
    # dice results are dictionaries
    if self.validate_with_fine_grained_dice:
        # here keys will be: 'ET', 'WT', and 'TC'
        total_dice = {'ET': 0, 'WT': 0, 'TC': 0}
    else:
        # here we only have one key: 'AVG(ET,WT,TC)'
        total_dice = {'AVG(ET,WT,TC)': 0}

    val_loader = self.data.get_val_loader()
    if val_loader == []:
        raise RuntimeError("Attempting to run validation with an empty val loader.")

    if use_tqdm:
        val_loader = tqdm.tqdm(val_loader, desc="validate")

    for subject in val_loader:
        # this is when we are using pt_brainmagedata
        if ('features' in subject.keys()) and ('gt' in subject.keys()):
            features = subject['features']
            mask = subject['gt']
            output = self.infer_batch_with_no_numpy_conversion(features=features)
        # using the gandlf loader
        else:
            features = torch.cat(
                [subject[key][torchio.DATA] for key in self.channel_keys], dim=1)
            mask = subject['label'][torchio.DATA]
            if self.validate_without_patches:
                output = self.data.infer_with_crop(
                    model_inference_function=[self.infer_batch_with_no_numpy_conversion],
                    features=features)
            else:
                output = self.data.infer_with_crop_and_patches(
                    model_inference_function=[self.infer_batch_with_no_numpy_conversion],
                    features=features)

        # one-hot encoding of ground truth
        mask = one_hot(mask, self.data.class_list)

        # sanity check that the output and mask have the same shape
        if output.shape != mask.shape:
            raise ValueError('Model output and ground truth mask are not the same shape.')

        current_dice = clinical_dice(
            output=output.float(),
            target=mask.float(),
            class_list=self.data.class_list,
            fine_grained=self.validate_with_fine_grained_dice,
            to_scalar=True)

        # the dice results here are dictionaries (sum up the totals)
        for key in total_dice:
            total_dice[key] = total_dice[key] + current_dice[key]

    # compute the average dice for all entries of the total_dice dict
    average_dice = {key: value / len(val_loader) for key, value in total_dice.items()}
    return average_dice
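# For reference, the 'ET'/'WT'/'TC' keys above are the standard BraTS clinical
# regions composed from the raw labels (1: necrotic/non-enhancing core,
# 2: edema, 4: enhancing tumor); a hedged sketch of that composition follows
# (function name hypothetical; clinical_dice's actual internals are not shown).
import numpy as np

def brats_regions_sketch(label_mask):
    return {
        'ET': label_mask == 4,
        'TC': np.isin(label_mask, [1, 4]),
        'WT': np.isin(label_mask, [1, 2, 4]),
    }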