Ejemplo n.º 1
0
    def prep_penalties(self):
        """Compute per-class dice weights and misclassification penalties.

        Iterates over every subject yielded by ``self.data.get_penalty_loader()``
        (the full data, which may differ from the training data, e.g. by not
        being cropped), one-hot encodes its label mask against
        ``self.data.class_list`` and counts the non-zero voxels of each
        foreground channel ``1 .. self.n_classes - 1``. Channel 0 (background)
        is not considered.

        Returns:
            tuple(dict, dict): ``(dice_weights_dict, dice_penalty_dict)``, both
            keyed by foreground class index.

            * ``dice_weights_dict[i]`` is the fraction of all counted
              foreground voxels that belong to class ``i``.
            * ``dice_penalty_dict[i]`` is ``1 - dice_weights_dict[i]``,
              normalized so that the penalties sum to 1. With exactly one
              foreground class the un-normalized penalty is 0, so that class
              gets a penalty of 1.0 instead of a division by zero.

        Raises:
            RuntimeError: if no foreground voxel is found in any label mask.
        """
        foreground_classes = range(1, self.n_classes)

        # class-specific count of non-zero voxels (background excluded)
        dice_weights_dict = {i: 0 for i in foreground_classes}
        # total number of non-zero voxels over all foreground classes
        total_nonZeroVoxels = 0

        penalty_loader = self.data.get_penalty_loader()
        for subject in penalty_loader:
            mask = subject['label'][torchio.DATA]
            one_hot_mask = one_hot(mask, self.data.class_list)
            for i in foreground_classes:
                currentNumber = torch.nonzero(one_hot_mask[:, i, :, :, :],
                                              as_tuple=False).size(0)
                dice_weights_dict[i] += currentNumber
                total_nonZeroVoxels += currentNumber

        if total_nonZeroVoxels == 0:
            raise RuntimeError(
                'Trying to train on data where every label mask is background class only.'
            )

        # fraction of foreground voxels per class ("weighted averaging" weights)
        dice_weights_dict = {
            k: v / total_nonZeroVoxels
            for k, v in dice_weights_dict.items()
        }

        # rarer classes receive a larger penalty
        dice_penalty_dict = {k: 1 - v for k, v in dice_weights_dict.items()}
        total = sum(dice_penalty_dict.values())
        if total == 0:
            # Only one foreground class: its weight is exactly 1.0, so its raw
            # penalty is 0. Give it the whole (normalized) penalty instead of
            # dividing by zero.
            dice_penalty_dict = {k: 1.0 for k in dice_penalty_dict}
        else:
            # normalize penalty to ensure sum of 1
            dice_penalty_dict = {
                k: v / total
                for k, v in dice_penalty_dict.items()
            }

        return dice_weights_dict, dice_penalty_dict
Ejemplo n.º 2
0
def CCE_Generic(out, target, params, CCE_Type):
    """
    Apply a cross-entropy style loss channel by channel and combine the results.

    ``target`` is one-hot encoded against ``params["model"]["class_list"]`` and
    cast to ``out.dtype``. ``CCE_Type`` is then evaluated on each class channel
    ``[:, i, ...]`` of ``out`` and of the encoded target.

    Args:
        out (torch.tensor): The predicted output value for each pixel; indexed here
            as [batch, class, ...].
        target (torch.tensor): The ground truth; it is one-hot encoded inside this
            function, so it presumably holds label values (confirm shape with callers).
        params (dict): The parameter dictionary. Reads ``params["model"]["class_list"]``
            and ``params["weights"]``.
        CCE_Type (torch.nn): Callable loss taking (prediction channel, target channel).

    Returns:
        torch.tensor: When ``params["weights"]`` is None, the mean of the per-class
        losses. Otherwise the sum of the per-class losses, each multiplied by
        ``params["weights"][i]``.
    """
    class_list = params["model"]["class_list"]
    encoded_target = one_hot(target, class_list).type(out.dtype)
    class_weights = params["weights"]

    def _single_class_loss(channel):
        # loss for one class channel, optionally scaled by its weight
        channel_loss = CCE_Type(out[:, channel, ...], encoded_target[:, channel, ...])
        if class_weights is None:
            return channel_loss
        return channel_loss * class_weights[channel]

    summed_loss = sum(_single_class_loss(channel) for channel in range(len(class_list)))

    # unweighted losses are averaged; weighted losses are left as a weighted sum
    if class_weights is None:
        return summed_loss / len(class_list)
    return summed_loss
Ejemplo n.º 3
0
def main(data_path,
         plan_path,
         model_weights_path,
         output_pardir,
         model_output_tag,
         device,
         legacy_model_flag=False):
    """Run inference on every validation subject and save segmentations and DICE scores.

    For each subject yielded by the data object's validation loader this:
      * copies the subject's label file into ``output_pardir/<subject folder>/``,
      * writes the predicted segmentation there as
        ``<subject folder><model_output_tag>_seg.nii.gz`` (image information is
        copied from the first modality image), and
      * records the ``clinical_dice`` scores under the label's parent folder name.
    All DICE scores are pickled to
    ``output_pardir/<model_output_tag>_subdirs_to_DICE.pkl``.

    Args:
        data_path: Data location passed to ``create_data_object_with_explicit_data_path``.
        plan_path: FL plan file, parsed with ``parse_fl_plan``.
        model_weights_path: Location of the saved model weights.
        output_pardir: Parent output directory (created, including parents, if missing).
        model_output_tag: String used in the output file names.
        device: Device passed to ``create_model_object``.
        legacy_model_flag: If True, weights are read with
            ``load_legacy_model_protobuf`` (legacy single-file format);
            otherwise with ``load_model`` (per-layer files in a folder).

    Raises:
        ValueError: if the plan's class list is not the supported module-level
            ``class_list``, if the data or model object is not the expected
            GANDLFData / BrainMaGeModel subclass, or if a subject has no label.
    """
    # TODO: We do not currently make use of the ability for brainmage to infer by first cropping external
    #       zero planes, or inference by patching and fusing.

    flplan = parse_fl_plan(plan_path)

    # make sure the class list we are using is compatible with the hard-coded class_label_map
    if flplan['data_object_init']['init_kwargs']['class_list'] != class_list:
        raise ValueError(
            'We currently only support class_list={}'.format(class_list))

    # construct the data object
    data = create_data_object_with_explicit_data_path(flplan=flplan,
                                                      data_path=data_path)

    # code is written with assumption we are using the gandlf data object
    if not isinstance(data, GANDLFData):
        raise ValueError(
            'This script is currently assumed to be using a child of '
            'fets.data.pytorch.gandlf_data.GANDLFData, you are using: {}'.format(
                data.__class__.__name__))

    # construct the model object on the requested device; whole [padded] brains
    # are passed through it (the original note recommended cpu for this reason)
    model = create_model_object(flplan=flplan,
                                data_object=data,
                                model_device=device)

    # code is written with assumption we are using the brainmage object
    if not isinstance(model, BrainMaGeModel):
        raise ValueError(
            'This script is currently assumed to be using a child of '
            'fets.models.pytorch.brainmage.BrainMaGeModel, you are using: {}'.format(
                model.__class__.__name__))

    # legacy models are defined in a single file, newer ones have a folder that holds per-layer files
    if legacy_model_flag:
        tensor_dict_from_proto = load_legacy_model_protobuf(model_weights_path)
    else:
        tensor_dict_from_proto = load_model(model_weights_path)

    # restore any tensors held out from the proto
    _, holdout_tensors = split_tensor_dict_for_holdouts(
        None, model.get_tensor_dict())
    tensor_dict = {**tensor_dict_from_proto, **holdout_tensors}
    model.set_tensor_dict(tensor_dict, with_opt_vars=False)

    print("\nWill be running inference on {} validation samples.\n".format(
        model.get_validation_data_size()))

    os.makedirs(output_pardir, exist_ok=True)

    # the DICE pickle path does not depend on the subject, so it is fixed up
    # front (this also keeps it defined when the validation loader is empty)
    dice_outpath = os.path.join(output_pardir,
                                model_output_tag + '_subdirs_to_DICE.pkl')
    subdir_to_DICE = {}

    for subject in data.get_val_loader():
        # using this because this is only one that's always defined
        first_mode_path = subject['1']['path'][0]
        subfolder = os.path.basename(os.path.dirname(first_mode_path))

        # prep the path for the output files
        output_subdir = os.path.join(output_pardir, subfolder)
        os.makedirs(output_subdir, exist_ok=True)
        inference_outpath = os.path.join(
            output_subdir, subfolder + model_output_tag + '_seg.nii.gz')

        if not is_mask_present(subject):
            raise ValueError(
                'We are expecting to run this on subjects that have labels.')

        label_path = subject['label']['path'][0]
        label_file = os.path.basename(label_path)
        subdir_name = os.path.basename(os.path.dirname(label_path))

        # copy the label file over to the output subdir
        copy_label_path = os.path.join(output_subdir, label_file)
        shutil.copyfile(label_path, copy_label_path)

        features, ground_truth = subject_to_feature_and_label(subject=subject,
                                                              pad_z=pad_z)

        output = infer(model, features)

        # FIXME: Find a better solution
        # crop away the z padding we put in (hard-coded depth of 155 slices;
        # NOTE(review): confirm this matches the padding applied via pad_z)
        output = output[:, :, :, :, :155]

        one_hot_ground_truth = one_hot(segmask_array=ground_truth,
                                       class_list=class_list)
        print(one_hot_ground_truth.shape, output.shape)

        # get the DICE score
        dice_dict = clinical_dice(output=output,
                                  target=one_hot_ground_truth,
                                  class_list=class_list,
                                  to_scalar=True)

        subdir_to_DICE[subdir_name] = dice_dict

        output = np.squeeze(output.cpu().numpy())

        # GANDLFData loader produces transposed output from what sitk gets from file, so transposing here.
        output = np.transpose(output, [0, 3, 2, 1])

        # process float outputs (across output channels), providing labels as defined in values of class_label_map
        output = new_labels_from_float_output(array=output,
                                              class_label_map=class_label_map,
                                              binary_classification=False)

        # convert array to SimpleITK image, taking geometry from the first modality
        image = sitk.GetImageFromArray(output)
        image.CopyInformation(sitk.ReadImage(first_mode_path))

        print("\nWriting inference NIfTI image of shape {} to {}".format(
            output.shape, inference_outpath))
        sitk.WriteImage(image, inference_outpath)
        print("\nCorresponding DICE scores were: ")
        print("{}\n\n".format(dice_dict))

    print("Saving subdir_name_to_DICE at: ", dice_outpath)
    with open(dice_outpath, 'wb') as _file:
        pkl.dump(subdir_to_DICE, _file)
Ejemplo n.º 4
0
def get_loss_and_metrics(image, ground_truth, predicted, params):
    """
    Compute the loss and every requested metric for one batch.

    Parameters
    ----------
    image : torch.Tensor
        The input image stack according to requirements. Only used by the sdnet
        branch, where its first channel is the reconstruction target.
    ground_truth : torch.Tensor
        The input ground truth for the corresponding image label. It is one-hot
        encoded here against ``params["model"]["class_list"]`` when
        ``params["problem_type"] == "segmentation"``.
    predicted : torch.Tensor or sequence of torch.Tensor
        The model output. Several outputs are expected for sdnet
        (``params["model"]["architecture"] == "sdnet"``) and for deep-supervision
        architectures (architecture name containing ``"deep"``).
    params : dict
        The parameters passed by the user yaml. Reads ``loss_function``,
        ``metrics``, ``problem_type`` and ``model`` (``architecture``, ``class_list``).

    Returns
    -------
    loss : torch.Tensor
        The computed loss from the label and the output.
    metric_output : dict
        Maps each metric name in ``params["metrics"]`` to its computed value.
        Unknown metrics map to 0 and a warning is printed to stderr.
    """
    # a dict-valued loss_function is currently only used for mse_torch
    if isinstance(params["loss_function"], dict):
        loss_function = global_losses_dict["mse"]
    else:
        loss_str_lower = params["loss_function"].lower()
        if loss_str_lower in global_losses_dict:
            loss_function = global_losses_dict[loss_str_lower]
        else:
            sys.exit(
                "WARNING: Could not find the requested loss function '"
                + params["loss_function"]
                + "'"
            )

    loss = 0
    # specialized loss function for sdnet
    sdnet_check = (len(predicted) > 1) and (params["model"]["architecture"] == "sdnet")

    if params["problem_type"] == "segmentation":
        ground_truth = one_hot(ground_truth, params["model"]["class_list"])

    deep_supervision_model = (
        (len(predicted) > 1)
        and not sdnet_check
        and ("deep" in params["model"]["architecture"])
    )
    if deep_supervision_model:
        # this case is for models that have deep-supervision - currently only used for segmentation models
        # these weights are taken from previous publication (https://arxiv.org/pdf/2103.03759.pdf)
        loss_weights = [0.5, 0.25, 0.175, 0.075]

        assert len(predicted) == len(
            loss_weights
        ), "Loss weights must be same length as number of outputs."

        # Each level is resampled from the previous (possibly already resampled)
        # ground truth whenever its spatial shape differs from that output's.
        ground_truth_resampled = []
        ground_truth_prev = ground_truth.detach()
        for current_prediction in predicted:
            if ground_truth_prev[0].shape != current_prediction[0].shape:
                # we get the expected shape of resampled ground truth
                expected_shape = reverse_one_hot(
                    current_prediction[0].detach(), params["model"]["class_list"]
                ).shape

                # linear interpolation is needed because we want "soft" images for resampled ground truth
                ground_truth_prev = nnf.interpolate(
                    ground_truth_prev,
                    size=expected_shape,
                    mode=get_linear_interpolation_mode(len(expected_shape)),
                    align_corners=False,
                )
            ground_truth_resampled.append(ground_truth_prev)

    if sdnet_check:
        # this is specific for sdnet-style archs
        loss_seg = loss_function(predicted[0], ground_truth.squeeze(-1), params)
        loss_reco = global_losses_dict["l1"](predicted[1], image[:, :1, ...], None)
        loss_kld = global_losses_dict["kld"](predicted[2], predicted[3])
        loss_cycle = global_losses_dict["mse"](predicted[2], predicted[4], None)
        loss = 0.01 * loss_kld + loss_reco + 10 * loss_seg + loss_cycle
    elif deep_supervision_model:
        # loss is calculated based on resampled "soft" labels using a pre-defined weights array
        for i, current_prediction in enumerate(predicted):
            loss += (
                loss_function(current_prediction, ground_truth_resampled[i], params)
                * loss_weights[i]
            )
    else:
        loss = loss_function(predicted, ground_truth, params)

    metric_output = {}

    # Metrics should be a list
    for metric in params["metrics"]:
        metric_lower = metric.lower()
        metric_output[metric] = 0
        if metric_lower not in global_metrics_dict:
            print(
                "WARNING: Could not find the requested metric '" + metric + "'",
                file=sys.stderr,
            )
            continue

        metric_function = global_metrics_dict[metric_lower]
        if sdnet_check:
            metric_output[metric] = get_metric_output(
                metric_function, predicted[0], ground_truth.squeeze(-1), params
            )
        elif deep_supervision_model:
            # NOTE(review): metric values are summed (unweighted) over all
            # deep-supervision outputs, unlike the loss; confirm this is intended.
            for i, current_prediction in enumerate(predicted):
                metric_output[metric] += get_metric_output(
                    metric_function,
                    current_prediction,
                    ground_truth_resampled[i],
                    params,
                )
        else:
            metric_output[metric] = get_metric_output(
                metric_function, predicted, ground_truth, params
            )
    return loss, metric_output
Ejemplo n.º 5
0
    def validate(self, use_tqdm=False):
        """Compute the average clinical DICE score over the validation loader.

        Each subject is run through the model in one of two ways:

        * Batches with ``'features'`` and ``'gt'`` keys (pt_brainmagedata) go
          directly to ``infer_batch_with_no_numpy_conversion``.
        * Other batches (gandlf loader) are concatenated over
          ``self.channel_keys`` and inferred with ``self.data.infer_with_crop``
          or ``self.data.infer_with_crop_and_patches``, depending on
          ``self.infer_gandlf_images_with_cropping``.

        The ground truth is one-hot encoded with ``self.data.class_list`` and
        scored with ``clinical_dice``.

        Args:
            use_tqdm (bool): wrap the loader in a tqdm progress bar.

        Returns:
            The mean of the per-subject ``clinical_dice`` values.

        Raises:
            RuntimeError: if the validation loader yields no subjects.
            ValueError: if a model output and its one-hot mask differ in shape.
        """
        empty_loader_message = "Attempting to run validation with an empty val loader."

        val_loader = self.data.get_val_loader()

        # fast path for a plain empty list; other empty iterables are caught below
        if val_loader == []:
            raise RuntimeError(empty_loader_message)

        if use_tqdm:
            val_loader = tqdm.tqdm(val_loader, desc="validate")

        total_dice = 0
        num_subjects = 0
        for subject in val_loader:
            # this is when we are using pt_brainmagedata
            if ('features' in subject.keys()) and ('gt' in subject.keys()):
                features = subject['features']
                mask = subject['gt']

                output = self.infer_batch_with_no_numpy_conversion(
                    features=features)

            # using the gandlf loader
            else:
                features = torch.cat(
                    [subject[key][torchio.DATA] for key in self.channel_keys],
                    dim=1)
                mask = subject['label'][torchio.DATA]

                if self.infer_gandlf_images_with_cropping:
                    output = self.data.infer_with_crop(
                        model_inference_function=[
                            self.infer_batch_with_no_numpy_conversion
                        ],
                        features=features)
                else:
                    output = self.data.infer_with_crop_and_patches(
                        model_inference_function=[
                            self.infer_batch_with_no_numpy_conversion
                        ],
                        features=features)

            # one-hot encoding of ground truth
            mask = one_hot(mask, self.data.class_list)

            # sanity check that the output and mask have the same shape
            if output.shape != mask.shape:
                raise ValueError(
                    'Model output and ground truth mask are not the same shape.'
                )

            curr_dice = clinical_dice(
                output.float(), mask.float(),
                class_list=self.data.class_list).cpu().data.item()
            total_dice += curr_dice
            num_subjects += 1

        # `val_loader == []` only detects plain lists; an empty DataLoader would
        # otherwise end in a division by zero here
        if num_subjects == 0:
            raise RuntimeError(empty_loader_message)

        # average over the subjects actually evaluated
        average_dice = total_dice / num_subjects

        return average_dice
Ejemplo n.º 6
0
    def train_batches(self, num_batches, use_tqdm=False):
        """Train for ``num_batches`` batches, re-iterating the train loader as needed.

        Each batch is one-hot encoded with ``self.data.class_list`` and moved to
        ``self.device``, then goes through one forward pass and a loss computation
        with ``self.loss_fn``. Backpropagation and an optimizer step run only
        when the loss is not NaN. ``self.lr_scheduler.step()`` is called after
        every batch.

        Args:
            num_batches (int): number of batches to process this round.
            use_tqdm (bool): wrap the loader in a tqdm progress bar.

        Returns:
            dict:
              * ``"loss"``: the mean loss over the batches whose loss was not
                NaN (a detached tensor), or ``float('nan')`` if no batch
                produced a usable loss.
              * ``"num_nan_losses"``: number of batches skipped because of a
                NaN loss.
              * ``"num_samples_used"``: ``num_batches``.

        Raises:
            RuntimeError: if the training loader yields no batches.
        """
        num_subjects = num_batches
        empty_loader_message = "Attempting to run training with an empty training loader."

        device = torch.device(self.device)

        ################################ PRINTING SOME STUFF ######################
        print("\nHostname   :" + str(os.getenv("HOSTNAME")))
        sys.stdout.flush()

        print("Training Data Samples: ", len(self.data.train_loader.dataset))
        sys.stdout.flush()

        print('Using device:', device)
        if device.type == 'cuda':
            print('Memory Usage:')
            print('Allocated:',
                  round(torch.cuda.memory_allocated(0) / 1024**3, 1), 'GB')
            # memory_cached is a deprecated alias of memory_reserved
            print('Cached: ', round(torch.cuda.memory_reserved(0) / 1024**3, 1),
                  'GB')

        sys.stdout.flush()

        train_loader = self.data.get_train_loader()

        # fast path for a plain empty list; other empty iterables are caught in the loop
        if train_loader == []:
            raise RuntimeError(empty_loader_message)

        if use_tqdm:
            train_loader = tqdm.tqdm(train_loader,
                                     desc="training for this round")

        total_loss = 0
        subject_num = 0
        num_nan_losses = 0

        # set to "training" mode
        self.train()
        while subject_num < num_subjects:
            batches_this_pass = 0

            for subject in train_loader:
                if subject_num >= num_subjects:
                    break
                batches_this_pass += 1

                if device.type == 'cuda':
                    print(
                        '=== Memory (allocated; cached) : ',
                        round(torch.cuda.memory_allocated(0) / 1024**3, 1),
                        '; ',
                        round(torch.cuda.memory_reserved(0) / 1024**3, 1))

                # Load the subject and its ground truth
                if ('features' in subject.keys()) and ('gt' in subject.keys()):
                    # this is when we are using pt_brainmagedata
                    features = subject['features']
                    mask = subject['gt']
                else:
                    # this is when we are using gandlf loader
                    features = torch.cat(
                        [subject[key][torchio.DATA] for key in self.channel_keys],
                        dim=1)
                    mask = subject['label'][torchio.DATA]

                print("\n\nTrain features with shape: {}\n".format(
                    features.shape))

                mask = one_hot(mask, self.data.class_list)

                # Loading features into device
                features = features.float().to(device)
                mask = mask.float().to(device)

                # Making sure that the optimizer has been reset
                self.optimizer.zero_grad()

                # Forward Propagation to get the output from the models
                output = self(features.float())

                # Computing the loss
                loss = self.loss_fn(output.float(),
                                    mask.float(),
                                    num_classes=self.label_channels,
                                    weights=self.dice_penalty_dict,
                                    class_list=self.data.class_list)

                # Back Propagation for model to learn (unless loss is nan)
                if torch.isnan(loss):
                    num_nan_losses += 1
                else:
                    loss.backward()
                    # Updating the weight values
                    self.optimizer.step()
                    # detach so the running total does not keep every batch's
                    # autograd graph (and its activations) alive
                    total_loss += loss.detach()
                self.lr_scheduler.step()

                subject_num += 1

            # an empty iterable that is not a plain list would otherwise make
            # the while-loop spin forever
            if batches_this_pass == 0:
                raise RuntimeError(empty_loader_message)

        num_subject_grads = num_subjects - num_nan_losses
        if num_subject_grads > 0:
            average_loss = total_loss / num_subject_grads
        else:
            # every batch produced a NaN loss (or no batch was requested)
            average_loss = float('nan')

        # we return the average batch loss over all epochs trained this round (excluding the nan results)
        # we also return the number of samples that produced nan losses, as well as total samples used
        # FIXME: In a federation we may want the collaborators data size to be modified when backprop is skipped.
        return {
            "loss": average_loss,
            "num_nan_losses": num_nan_losses,
            "num_samples_used": num_subjects
        }
Ejemplo n.º 7
0
    def validate(self, use_tqdm=False):
        """Compute the average clinical DICE scores over the validation loader.

        Each subject is run through the model in one of two ways:

        * Batches with ``'features'`` and ``'gt'`` keys (pt_brainmagedata) go
          directly to ``infer_batch_with_no_numpy_conversion``.
        * Other batches (gandlf loader) are inferred with
          ``self.data.infer_with_crop`` or
          ``self.data.infer_with_crop_and_patches``, depending on
          ``self.validate_without_patches``.

        The ground truth is one-hot encoded with ``self.data.class_list`` and
        scored with ``clinical_dice(..., to_scalar=True)``.

        Args:
            use_tqdm (bool): wrap the loader in a tqdm progress bar.

        Returns:
            dict: mean score per key. The keys are ``'ET'``, ``'WT'`` and ``'TC'``
            when ``self.validate_with_fine_grained_dice`` is set, otherwise the
            single key ``'AVG(ET,WT,TC)'``.

        Raises:
            RuntimeError: if the validation loader yields no subjects.
            ValueError: if a model output and its one-hot mask differ in shape.
        """
        empty_loader_message = "Attempting to run validation with an empty val loader."

        # dice results are dictionaries
        if self.validate_with_fine_grained_dice:
            # here keys will be: 'ET', 'WT', and 'TC'
            total_dice = {'ET': 0, 'WT': 0, 'TC': 0}
        else:
            # here we only have one key: 'AVG(ET,WT,TC)'
            total_dice = {'AVG(ET,WT,TC)': 0}

        val_loader = self.data.get_val_loader()

        # fast path for a plain empty list; other empty iterables are caught below
        if val_loader == []:
            raise RuntimeError(empty_loader_message)

        if use_tqdm:
            val_loader = tqdm.tqdm(val_loader, desc="validate")

        num_subjects = 0
        for subject in val_loader:
            # this is when we are using pt_brainmagedata
            if ('features' in subject.keys()) and ('gt' in subject.keys()):
                features = subject['features']
                mask = subject['gt']

                output = self.infer_batch_with_no_numpy_conversion(
                    features=features)

            # using the gandlf loader
            else:
                features = torch.cat(
                    [subject[key][torchio.DATA] for key in self.channel_keys],
                    dim=1)
                mask = subject['label'][torchio.DATA]

                if self.validate_without_patches:
                    output = self.data.infer_with_crop(
                        model_inference_function=[
                            self.infer_batch_with_no_numpy_conversion
                        ],
                        features=features)
                else:
                    output = self.data.infer_with_crop_and_patches(
                        model_inference_function=[
                            self.infer_batch_with_no_numpy_conversion
                        ],
                        features=features)

            # one-hot encoding of ground truth
            mask = one_hot(mask, self.data.class_list)

            # sanity check that the output and mask have the same shape
            if output.shape != mask.shape:
                raise ValueError(
                    'Model output and ground truth mask are not the same shape.'
                )

            current_dice = clinical_dice(
                output=output.float(),
                target=mask.float(),
                class_list=self.data.class_list,
                fine_grained=self.validate_with_fine_grained_dice,
                to_scalar=True)
            # the dice results here are dictionaries (sum up the totals)
            for key in total_dice:
                total_dice[key] = total_dice[key] + current_dice[key]
            num_subjects += 1

        # `val_loader == []` only detects plain lists; an empty DataLoader would
        # otherwise end in a division by zero here
        if num_subjects == 0:
            raise RuntimeError(empty_loader_message)

        # Computing the average dice for all values of total_dice dict
        average_dice = {
            key: value / num_subjects
            for key, value in total_dice.items()
        }

        return average_dice