Ejemplo n.º 1
0
    def surface_distances(x, y):
        """Symmetric average boundary distance between two binary masks.

        Adapted from
        https://github.com/BBillot/SynthSeg/blob/master/SynthSeg/evaluate.py

        A voxel belongs to a mask's boundary when the mask's Euclidean
        distance transform equals exactly 1. For each boundary voxel of one
        mask, the distance to the nearest boundary voxel of the other mask is
        looked up. The result is the mean of the two directed averages.
        Distances are in index units, because no voxel spacing is passed to
        ``distance_transform_edt``.

        Args:
            x, y: masks of identical shape (boolean or 0/1 values), in any
                form accepted by ``to_numpy``.

        Returns:
            (mean x->y boundary distance + mean y->x boundary distance) / 2.
        """
        assert x.shape == y.shape, 'both inputs should have same size, ' \
                                   f'had {x.shape} and {y.shape}'

        masks = (to_numpy(x), to_numpy(y))

        # Boundary voxels: the distance transform is exactly 1 there.
        edges = [(distance_transform_edt(m * 1) == 1) * 1 for m in masks]
        # For every voxel, the distance to the closest boundary voxel of each mask.
        dist_to_edge = [distance_transform_edt(np.logical_not(e)) for e in edges]

        mean_x_to_y = np.mean(dist_to_edge[1][edges[0] == 1])
        mean_y_to_x = np.mean(dist_to_edge[0][edges[1] == 1])
        return (mean_x_to_y + mean_y_to_x) / 2
Ejemplo n.º 2
0
    def save_volume(self,
                    sample,
                    volume,
                    idx=0,
                    batch_idx=0,
                    affine=None,
                    volume_name=None,
                    apply_activation=True):
        """Write a predicted volume to NIfTI file(s) under ``self.eval_results_dir``.

        The caller's ``volume`` tensor is not modified. All activation and
        thresholding steps operate on a private copy.

        Args:
            sample: sample the prediction belongs to. Its ``'name'`` entry (if
                any) names the output sub-directory. It also supplies the
                affine when ``affine`` is None.
            volume: 5-D prediction tensor laid out as ``[batch, channel, ...]``.
                Only batch item 0 is written.
            idx: fallback sub-directory name (zero-padded to 6 digits) used
                when the sample has no name.
            batch_idx: which entry of a list-valued ``sample['name']`` to use.
                Also the prefix index when ``'list_idx'`` is in the sample.
            affine: affine stored in the NIfTI header. Defaults to
                ``self.get_affine(sample)``.
            volume_name: output file stem. Defaults to ``self.save_volume_name``.
            apply_activation: apply ``self.activation`` before saving. When
                the first criterion sets ``mixt_activation``, the last
                ``mixt_activation`` channels are left un-activated.
        """
        volume_name = volume_name or self.save_volume_name
        name = sample.get('name') or f'{idx:06d}'
        if affine is None:
            affine = self.get_affine(sample)

        if isinstance(name, list):
            name = name[batch_idx]
            name = name[0] if isinstance(name, list) else name

        # The activation and threshold steps below write into `volume` in
        # place. Work on a copy so the caller's tensor stays untouched.
        volume = volume.detach().clone()

        if apply_activation:
            # Apply the activation (softmax) only to segmentation channels,
            # not to the trailing regression channels.
            skip_vol = self.criteria[0]['criterion'].mixt_activation
            if skip_vol:
                vv = self.activation(volume[0, :-skip_vol, ...].unsqueeze(0))
                volume[0, :-skip_vol, ...] = vv[0]
            else:
                volume = self.activation(volume)

        resdir = f'{self.eval_results_dir}/{name}/'
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(resdir, exist_ok=True)
        if 'list_idx' in sample:
            volume_name = 'l{}_'.format(batch_idx) + volume_name

        if self.save_bin:
            # Hard label map: index of the highest-valued channel per voxel.
            bin_volume = torch.argmax(volume, dim=1)
            bin_volume = nib.Nifti1Image(
                to_numpy(bin_volume[0]).astype(np.uint8), affine)
            nib.save(bin_volume, f'{resdir}/bin_{volume_name}.nii.gz')

        # Zero out every value below the saving threshold.
        volume[volume < self.save_threshold] = 0.

        if self.save_channels is not None:
            channels = [self.labels.index(c) for c in self.save_channels]
            channel_names = self.save_channels
            volume = volume[:, channels, ...]
        else:
            channel_names = self.labels

        if self.split_channels:
            # One file per channel, named after its label.
            for channel in range(volume.shape[1]):
                label = channel_names[channel]
                v = nib.Nifti1Image(to_numpy(volume[0, channel, ...]), affine)
                nib.save(v, f'{resdir}/{label}.nii.gz')

        else:
            # Single file, channels moved last and singleton dims squeezed.
            volume = nib.Nifti1Image(
                to_numpy(volume.permute(0, 2, 3, 4, 1).squeeze()), affine)
            nib.save(volume, f'{resdir}/{volume_name}.nii.gz')
            self.debug('saving {}'.format(f'{resdir}/{volume_name}.nii.gz'))
Ejemplo n.º 3
0
def minimum_t_norm(prediction, target, background=False):
    """Element-wise minimum t-norm of ``prediction`` and ``target``.

    With ``background=True``, returns the positive part of
    ``prediction - target`` instead.

    The computation runs in NumPy via ``to_numpy``. The result is converted
    back with ``to_var`` onto ``prediction``'s device.
    NOTE(review): if ``to_numpy`` detaches, no gradient flows through this
    function. Confirm before using it inside a loss.
    """
    np_prediction, np_target = to_numpy(prediction), to_numpy(target)
    res = (np.maximum(np_prediction - np_target, 0) if background
           else np.minimum(np_prediction, np_target))
    return to_var(res, prediction.device)
Ejemplo n.º 4
0
    # Optionally crop/pad the input volume to a fixed spatial shape.
    if vol_crop_pad:
        tpad = torchio.CropOrPad(target_shape=vol_crop_pad)
        volume = tpad(volume)

    # Model description for the project's Config helper. parse_model_file
    # presumably completes/validates these settings (confirm). load_model
    # returns the model and the device it lives on.
    model_struct = {
        'module': args.model_module,
        'name': args.model_name,
        'last_one': False,
        'path': args.model,
        'device': args.device
    }
    config = Config(None, None, save_files=False)
    model_struct = config.parse_model_file(model_struct)
    model, device = config.load_model(model_struct)
    model.eval()

    # Inference only: add a batch dimension and run without autograd.
    with torch.no_grad():
        prediction = model(volume.data.unsqueeze(0).float().to(device))

    # Softmax over the channel dimension (dim=1). When nb_vol_exclude > 0,
    # the last nb_vol_exclude channels are excluded from the softmax and
    # left unchanged.
    if nb_vol_exclude > 0:
        pp = F.softmax(prediction[0, :-nb_vol_exclude, ...].unsqueeze(0),
                       dim=1)
        prediction[0, :-nb_vol_exclude, ...] = pp[0]
    else:
        prediction = F.softmax(prediction, dim=1)

    # Save the first batch item with channels moved last, reusing the
    # input volume's affine.
    image = nib.Nifti1Image(to_numpy(prediction[0].permute(1, 2, 3, 0)),
                            volume.affine)
    nib.save(image, args.filename)
Ejemplo n.º 5
0
    def record_segmentation_batch(self,
                                  df,
                                  sample,
                                  predictions,
                                  targets,
                                  batch_time,
                                  save=False,
                                  csv_name='eval',
                                  append_in_df=False):
        """
        Record information about the batches the model was trained or evaluated
        on during the segmentation task.
        At evaluation time, additional reporting metrics are recorded.

        Args:
            df: DataFrame the new rows are appended to. At evaluation time it
                is replaced by an empty one when results go to a separate
                directory, unless ``append_in_df`` is True or ``csv_name`` is
                not ``'eval'``.
            sample: a collated batch or a single ``torchio.Subject``.
            predictions: raw model outputs, before ``self.activation``.
            targets: ground truth. ``targets.shape[0]`` is the batch size.
            batch_time: time spent on the whole batch, split evenly per item.
            save: call ``self.save_info`` on the resulting DataFrame.
            csv_name: forwarded to ``self.save_info``. ``'patch_eval'`` also
                selects the ``'patch_eval'`` mode.
            append_in_df: keep appending to ``df`` even at evaluation time.

        Returns:
            (df, time_sum): ``df`` with one new row per batch item, and the
            total time spent building those rows.
        """
        start = time.time()

        is_batch = not isinstance(sample, torchio.Subject)
        if self.model.training:
            mode = 'Train'
        else:
            if self.eval_results_dir != self.results_dir and csv_name == 'eval' and append_in_df is False:
                df = pd.DataFrame()
            mode = 'Val' if is_batch else 'Whole_image'
            if csv_name == 'patch_eval':
                mode = 'patch_eval'

        shape = targets.shape
        location = sample.get(
            'index_ini') if 'index_ini' in sample else sample.get('location')
        if location is not None:
            location = location[:, :3]
        affine = self.get_affine(sample)
        # The columns of the linear part M of the affine are the voxel axes in
        # world space, so their norms are the voxel sizes (nibabel's
        # `voxel_sizes` convention). The former diag(sqrt(M @ M.T)) took row
        # norms, which is wrong for oblique affines with anisotropic voxels.
        M = affine[:3, :3]
        voxel_size = np.linalg.norm(M, axis=0).prod()

        batch_size = shape[0]

        sample_time = batch_time / batch_size

        # The last `mixt_activation` output channels (if any) are not
        # segmentation channels and are left out of the volume report.
        mixt_activation = self.criteria[0]['criterion'].mixt_activation
        if mixt_activation:
            max_chanel = shape[1] - mixt_activation
            seg_predictions = predictions[:, :-mixt_activation, ...]
        else:
            max_chanel = shape[1]
            seg_predictions = predictions
        # Activate once for the whole batch. Previously this was recomputed
        # for every item and channel. As in `save_volume`, only segmentation
        # channels are activated, so the extra channels cannot distort a
        # softmax over channels.
        activated = self.activation(seg_predictions)

        time_sum = 0
        rows = []

        for idx in range(batch_size):
            if is_batch:
                image_path = sample[self.image_key_name]['path'][idx]
            else:
                image_path = sample[self.image_key_name]['path']
            info = {
                'image_filename': image_path,
                'shape': to_numpy(shape[2:]),
                'sample_time': sample_time
            }
            if is_batch:
                info['batch_size'] = batch_size

            if 'name' in sample:
                # A single Subject holds one name. Indexing it with `idx`
                # would return the first character of a string name.
                info['name'] = sample['name'][idx] if is_batch else sample['name']

            if is_batch:
                info['label_filename'] = sample[
                    self.label_key_name]['path'][idx]
            else:
                info['label_filename'] = sample[self.label_key_name]['path']

            for channel in range(max_chanel):
                suffix = self.labels[channel]
                info[f'occupied_volume_{suffix}'] = to_numpy(
                    targets[idx, channel].sum() * voxel_size)
                info[f'predicted_occupied_volume_{suffix}'] = to_numpy(
                    activated[idx, channel].sum() * voxel_size)

            if location is not None:
                info['location'] = to_numpy(location[idx])

            loss = 0
            for criterion in self.criteria:
                one_loss = criterion['criterion'](
                    predictions[idx].unsqueeze(0), targets[idx].unsqueeze(0))
                # Multi-task losses may return a tuple (total, extra, ...) so
                # that each term can be reported separately.
                if isinstance(one_loss, tuple):
                    for i, extra in enumerate(one_loss[1:], start=1):
                        if isinstance(extra, dict):
                            # e.g. a loss returning (loss, dict_of_statistics)
                            info.update(extra)
                            continue
                        if extra.requires_grad:
                            extra = extra.detach()
                        info[f'loss_{i}'] = to_numpy(extra)

                    one_loss = one_loss[0]
                loss += criterion['weight'] * one_loss

            info['loss'] = to_numpy(loss)

            if not self.model.training:
                for metric in self.metrics:
                    name = f'metric_{metric["name"]}'

                    info[name] = json.dumps(to_numpy(
                        metric['criterion'](predictions[idx].unsqueeze(0),
                                            targets[idx].unsqueeze(0))),
                                            cls=ArrayTensorJSONEncoder)

            reporting_time = time.time() - start
            time_sum += reporting_time
            info['reporting_time'] = reporting_time
            start = time.time()

            self.record_history(info, sample, idx)

            rows.append(info)

        # DataFrame.append was removed in pandas 2.0. Add all rows at once.
        if rows:
            df = pd.concat([df, pd.DataFrame(rows)], ignore_index=True)

        if save:
            self.save_info(mode, df, sample, csv_name)

        return df, time_sum
Ejemplo n.º 6
0
    def record_simple(self,
                      df,
                      sample,
                      predictions,
                      targets,
                      batch_time,
                      save=False):
        """
        Record information about the batches the model was trained or evaluated
        on: file names, location, weighted loss, and any per-sample
        ``'metrics'`` carried by the image entry of ``sample``.
        At evaluation time, additional reporting metrics are recorded.

        Args:
            df: DataFrame the new rows are appended to. At evaluation time it
                is replaced by an empty one when results go to a separate
                directory.
            sample: a collated batch or a single ``torchio.Subject``.
            predictions: model outputs.
            targets: ground truth. ``targets.shape[0]`` is the batch size.
            batch_time: time spent on the whole batch, split evenly per item.
            save: call ``self.save_info`` on the resulting DataFrame.

        Returns:
            (df, time_sum): ``df`` with one new row per batch item, and the
            total time spent building those rows.
        """
        start = time.time()

        is_batch = not isinstance(sample, torchio.Subject)
        if self.model.training:
            mode = 'Train'
        else:
            if self.eval_results_dir != self.results_dir:
                df = pd.DataFrame()
            mode = 'Val' if is_batch else 'Whole_image'

        shape = targets.shape
        location = sample.get('index_ini')

        batch_size = shape[0]

        sample_time = batch_time / batch_size

        time_sum = 0
        rows = []

        for idx in range(batch_size):
            if is_batch:
                image_path = sample[self.image_key_name]['path'][idx]
            else:
                image_path = sample[self.image_key_name]['path']
            info = {
                'image_filename': image_path,
                'shape': to_numpy(shape[2:]),
                'sample_time': sample_time
            }

            if is_batch:
                info['label_filename'] = sample[
                    self.label_key_name]['path'][idx]
            else:
                info['label_filename'] = sample[self.label_key_name]['path']

            if location is not None:
                info['location'] = to_numpy(location[idx])

            loss = 0
            for criterion in self.criteria:
                loss += criterion['weight'] * criterion['criterion'](
                    predictions[idx].unsqueeze(0), targets[idx].unsqueeze(0))
            info['loss'] = to_numpy(loss)

            if not self.model.training:
                for metric in self.metrics:
                    name = f'metric_{metric["name"]}'

                    info[name] = json.dumps(to_numpy(
                        metric['criterion'](predictions[idx].unsqueeze(0),
                                            targets[idx].unsqueeze(0))),
                                            cls=ArrayTensorJSONEncoder)
            if 'metrics' in sample[self.image_key_name]:
                # One column per metric key, taking this item's value.
                dics = sample[self.image_key_name]['metrics']
                info.update({key: to_numpy(val[idx]) for key, val in dics.items()})

            reporting_time = time.time() - start
            time_sum += reporting_time
            info['reporting_time'] = reporting_time
            start = time.time()

            self.record_history(info, sample, idx)

            rows.append(info)

        # DataFrame.append was removed in pandas 2.0. Add all rows at once.
        if rows:
            df = pd.concat([df, pd.DataFrame(rows)], ignore_index=True)

        if save:
            self.save_info(mode, df, sample)

        return df, time_sum
Ejemplo n.º 7
0
 def get_affine(self, sample):
     """Return the affine stored under ``self.image_key_name`` in ``sample``.

     When the stored affine has ``ndim == 3`` (presumably one affine per
     batch item), only the first one is returned, converted with
     ``to_numpy``. Otherwise the stored object is returned unchanged.
     """
     affine = sample[self.image_key_name]['affine']
     if affine.ndim != 3:
         return affine
     return to_numpy(affine[0])
Ejemplo n.º 8
0
    def record_regression_batch(self,
                                df,
                                sample,
                                predictions,
                                targets,
                                batch_time,
                                save=False):
        """
        Record information about the batches the model was trained or evaluated
        on during the regression task.
        At evaluation time, additional reporting metrics are recorded.

        Args:
            df: DataFrame the new rows are appended to. When results go to a
                separate directory, it is replaced by an empty one and saving
                is forced.
            sample: a collated batch or a single ``torchio.Subject``. The batch
                size is read from its image ``'data'`` tensor.
            predictions: model outputs, indexed by batch item first.
            targets: regression targets, indexed by batch item first.
            batch_time: time spent on the whole batch, split evenly per item.
            save: call ``self.save_info`` on the resulting DataFrame.

        Returns:
            (df, time_sum): ``df`` with one new row per batch item, and the
            total time spent building those rows.
        """
        start = time.time()
        mode = 'Train' if self.model.training else 'Val'
        if self.eval_results_dir != self.results_dir:
            # Separate results directory: start a fresh table and always save.
            df = pd.DataFrame()
            save = True

        location = sample.get('index_ini')
        shape = sample[self.image_key_name]['data'].shape
        batch_size = shape[0]
        sample_time = batch_time / batch_size
        time_sum = 0
        is_batch = not isinstance(sample, torchio.Subject)

        rows = []
        for idx in range(batch_size):
            image_path = sample[self.image_key_name]['path']
            info = {
                'image_filename': image_path[idx] if is_batch else image_path,
                'shape': to_numpy(shape[2:]),
                'sample_time': sample_time,
                'batch_size': batch_size,
            }

            if location is not None:
                info['location'] = to_numpy(location[idx])

            if self.label_key_name in sample:
                label_path = sample[self.label_key_name]['path']
                info['label_filename'] = label_path[idx] if is_batch else label_path
            if 'name' in sample:
                info['subject_name'] = sample['name'][
                    idx] if is_batch else sample['name']

            with torch.no_grad():
                loss = 0
                for criterion in self.criteria:
                    loss += criterion['weight'] * criterion['criterion'](
                        predictions[idx].unsqueeze(0),
                        targets[idx].unsqueeze(0))
                info['loss'] = to_numpy(loss)
                info['prediction'] = to_numpy(predictions[idx])
                info['targets'] = to_numpy(targets[idx])
                if 'target_name' in self.__dict__.keys():
                    # One target/prediction column pair per named target.
                    for i_target, tgn in enumerate(self.target_name):
                        info['tar_' + tgn] = to_numpy(targets[idx][i_target])
                        info['pred_' + tgn] = to_numpy(
                            predictions[idx][i_target])
                if 'scale_label' in self.__dict__.keys():
                    for i_target, tgn in enumerate(self.target_name):
                        info['scale_' + tgn] = self.scale_label[i_target]

            if 'simu_param' in sample[self.image_key_name]:
                # One column per simulation parameter, taking this item's value.
                dics = sample[self.image_key_name]['simu_param']
                info.update({key: to_numpy(val[idx]) for key, val in dics.items()})

            if 'metrics' in sample[self.image_key_name]:
                dics = sample[self.image_key_name]['metrics']
                info['metrics'] = {self.image_key_name: dics[idx]}

            if not self.model.training:
                for metric in self.metrics:
                    name = f'metric_{metric["name"]}'

                    info[name] = json.dumps(to_numpy(
                        metric['criterion'](predictions[idx].unsqueeze(0),
                                            targets[idx].unsqueeze(0))),
                                            cls=ArrayTensorJSONEncoder)

            self.record_history(info, sample, idx)

            reporting_time = time.time() - start
            time_sum += reporting_time
            info['reporting_time'] = reporting_time
            start = time.time()

            rows.append(info)

        # DataFrame.append was removed in pandas 2.0. Add all rows at once.
        if rows:
            df = pd.concat([df, pd.DataFrame(rows)], ignore_index=True)

        if save:
            self.save_info(mode, df, sample)

        return df, time_sum
Ejemplo n.º 9
0
    def __call__(self, x, target):
        """Supervised loss weighted by a predicted variance map.

        Takes only two arguments, like any other loss: ``loss(prediction, target)``.
        The last ``self.sigma_prediction`` channels of ``x`` (dim 1) hold the
        variance map ``sigma2``, optionally constrained by logsigmoid/softplus.
        ``sigma2`` is interpreted as log(sigma**2) when ``apply_exp`` is set
        (except with softplus). These channels are removed before
        ``self.sup_loss`` is applied to the remaining channels.

        :param x: network output: prediction channels followed by
            ``self.sigma_prediction`` variance channels along dim 1.
        :param target: true segmentation/regression target (soft labels
            allowed), compared with the non-variance channels of ``x``.
        :return: mean of ``sup_loss / sigma2 + lamb * log(sigma2)``, or the
            plain mean supervised loss when ``self.fake`` is set. When
            ``self.return_loss_dict`` is set (and not ``fake``), a tuple
            ``(loss, dict_of_statistics)``.
        """
        # Slice (not index) so the channel dimension is kept.
        if self.sigma_constrain == 'logsigmoid':
            sigma2 = torch.nn.functional.logsigmoid(
                x[:, -self.sigma_prediction:, ...])
        elif self.sigma_constrain == "softplus":
            sigma2 = torch.nn.functional.softplus(
                x[:, -self.sigma_prediction:, ...])
        else:
            sigma2 = x[:, -self.sigma_prediction:, ...]
        x = x[:, :-self.sigma_prediction, ...]

        if self.fake:
            # Ignore the predicted variance: plain supervised loss.
            if isinstance(self.sup_loss, torch.nn.MSELoss):
                return self.sup_loss(x, target).sum(dim=1).mean()
            return self.sup_loss(x, target).mean()

        if self.apply_exp:
            if self.sigma_constrain == "softplus":
                # softplus already yields a positive variance: no exp needed.
                sigma2 = sigma2.squeeze(dim=1) + 1e-6
                log_sigma2 = torch.log(sigma2)
            else:
                # The network output is log(sigma**2).
                sigma2 = sigma2.squeeze(
                    dim=1)  # remove channel dim if only one sigma
                log_sigma2 = sigma2
                sigma2 = torch.exp(log_sigma2) + 1e-3

            if isinstance(self.sup_loss, torch.nn.MSELoss):
                # Per-element MSE summed over channels. It is stored in
                # `the_loss` so the statistics below can use it too.
                the_loss = self.sup_loss(x, target).sum(dim=1)
                print(
                    f'MSE/SIGMA min {the_loss.min():.4f} | {sigma2.min():.4f} max {the_loss.max():.4f} | {sigma2.max():.4f} mean {the_loss.mean():.4f} |  {sigma2.mean():.4f}'
                )
                res_loss = (1. / sigma2.squeeze(dim=1) * the_loss +
                            self.lamb * log_sigma2.squeeze(dim=1)).mean()
            else:
                the_loss = self.sup_loss(x, target)
                res_loss = (1. / sigma2 * the_loss +
                            self.lamb * log_sigma2).mean()

        else:
            # Compute the supervised loss once and reuse it.
            the_loss = self.sup_loss(x, target)
            print(
                f'BCE/SIGMA  max {the_loss.max():.4f} | {sigma2.max():.4f} mean {the_loss.mean():.4f} |  {sigma2.mean():.4f}'
            )
            res_loss = (1. / sigma2 * the_loss +
                        self.lamb * torch.log(sigma2)).mean()

        if not self.return_loss_dict:
            return res_loss

        # Reporting statistics only (e.g. when recording a single batch),
        # not needed for training, so no autograd graph is built.
        with torch.no_grad():
            # Per-sample L2 norm over all non-batch elements, then batch mean.
            lThnorm = torch.linalg.norm(
                the_loss.reshape(the_loss.shape[0], -1), ord=2, dim=1).mean()
            lSnorm = torch.linalg.norm(
                sigma2.reshape(sigma2.shape[0], -1), ord=2, dim=1).mean()
            dict_loss = {
                'loss_kll_norm': to_numpy(lThnorm),
                'loss_sigma_norm': to_numpy(lSnorm),
                'loss_kll_mean': to_numpy(the_loss.mean()),
                'loss_sigma_mean': to_numpy(sigma2.mean()),
                'loss_kll_max': to_numpy(the_loss.max()),
                'loss_sigma_max': to_numpy(sigma2.max())
            }
        return res_loss, dict_loss