Example #1
    def __init__(self, model, loss_type="Dice", model_kwargs=None, devices=(0,1),
                 predictions_specs=None,
                 train_glia_mask=False,
                 target_has_label_segm=False,
                 target_has_various_masks=False,
                 precrop_pred=None):
        super(MultiLevelSparseAffinityLoss, self).__init__()
        if loss_type == "Dice":
            self.loss = SorensenDiceLoss()
        elif loss_type == "MSE":
            self.loss = nn.MSELoss()
        elif loss_type == "BCE":
            self.loss = nn.BCELoss()
        else:
            raise ValueError(f"Unknown loss_type: {loss_type}")

        self.devices = devices
        self.model_kwargs = model_kwargs
        self.MSE_loss = nn.MSELoss()
        self.smoothL1_loss = nn.SmoothL1Loss()
        # TODO: use nn.BCEWithLogitsLoss()
        self.BCE = nn.BCELoss()
        self.soresen_loss = SorensenDiceLoss()

        self.model = model
        assert predictions_specs is not None, "A dictionary should be passed"
        self.predictions_specs = predictions_specs
        self.precrop_pred = precrop_pred
        self.target_has_label_segm = target_has_label_segm
        self.target_has_various_masks = target_has_various_masks
        self.train_glia_mask = train_glia_mask
Example #2
    def __init__(self, model, apply_checkerboard=False, loss_type="Dice",
                 ignore_label=0,
                 train_glia_mask=False,
                 boundary_label=None,
                 glia_label=None,
                 train_patches_on_glia=False,
                 fix_bug_multiscale_patches=False,
                 defected_label=None,
                 IoU_loss_kwargs=None,
                 sparse_affs_loss_kwargs=None,
                 indx_trained_patchNets=None,
                 model_kwargs=None, devices=(0, 1)):
        super(LatentMaskLoss, self).__init__()
        if loss_type == "Dice":
            self.loss = SorensenDiceLoss()
        elif loss_type == "MSE":
            self.loss = nn.MSELoss()
        elif loss_type == "BCE":
            self.loss = nn.BCELoss()
        else:
            raise ValueError(f"Unknown loss_type: {loss_type}")

        self.apply_checkerboard = apply_checkerboard
        self.fix_bug_multiscale_patches = fix_bug_multiscale_patches
        self.ignore_label = ignore_label
        self.boundary_label = boundary_label
        self.glia_label = glia_label
        self.defected_label = defected_label
        self.train_glia_mask = train_glia_mask
        self.train_patches_on_glia = train_patches_on_glia
        self.indx_trained_patchNets = indx_trained_patchNets
        self.add_IoU_loss = False
        if IoU_loss_kwargs is not None:
            raise NotImplementedError()
            # self.add_IoU_loss = True
            # from .compute_IoU import IoULoss
            # self.IoU_loss = IoULoss(model, model_kwargs=model_kwargs, devices=devices, **IoU_loss_kwargs)

        self.devices = devices
        # TODO: get rid of kwargs
        self.model_kwargs = model_kwargs
        self.MSE_loss = nn.MSELoss()
        self.smoothL1_loss = nn.SmoothL1Loss()
        # TODO: use nn.BCEWithLogitsLoss()
        self.BCE = nn.BCELoss()
        self.soresen_loss = SorensenDiceLoss()

        self.model = model

        self.train_sparse_loss = False
        self.sparse_multilevelDiceLoss = None
        if sparse_affs_loss_kwargs is not None:
            self.train_sparse_loss = True
            self.sparse_multilevelDiceLoss = MultiLevelSparseAffinityLoss(model, model_kwargs=model_kwargs,
                                                                          devices=devices,
                                                                          **sparse_affs_loss_kwargs)
Example #3
 def test_channelwise(self):
     from inferno.extensions.criteria.set_similarity_measures import SorensenDiceLoss
     x, y = self.get_dummy_variables()
     channelwise = SorensenDiceLoss(channelwise=True)
     not_channelwise = SorensenDiceLoss(channelwise=False)
     # Compute expected channelwise loss
     expected_channelwise_loss = \
         not_channelwise(x[:, 0, ...], y[:, 0, ...]) + \
         not_channelwise(x[:, 1, ...], y[:, 1, ...])
     # Compute channelwise
     channelwise_loss = channelwise(x, y)
     # Compare
     self.assertAlmostEqual(expected_channelwise_loss.item(), channelwise_loss.item())
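For reference, a minimal standalone usage sketch of the same criterion (the tensor shapes and the sigmoid activation are illustrative assumptions, not part of the test above):

# Minimal usage sketch; shapes and the sigmoid are illustrative assumptions.
import torch
from inferno.extensions.criteria.set_similarity_measures import SorensenDiceLoss

criterion = SorensenDiceLoss(channelwise=True)
prediction = torch.sigmoid(torch.randn(2, 3, 32, 32))  # probabilities in [0, 1]
target = (torch.rand(2, 3, 32, 32) > 0.5).float()      # binary target of the same shape
loss = criterion(prediction, target)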
Example #4
    def test_mask_ignore_label(self):
        from neurofire.criteria.loss_transforms import MaskIgnoreLabel
        from neurofire.transform.segmentation import Segmentation2Membranes
        trafo = MaskIgnoreLabel(-1)
        seg_trafo = Segmentation2Membranes()

        seg = self.make_segmentation_with_ignore(self.shape)
        ignore_mask = seg == 0

        # make membrane target
        target = seg_trafo(seg)
        target[ignore_mask] = -1
        target = torch.from_numpy(target)
        target.requires_grad = False
        self.assertEqual(target.shape, seg.shape)

        # make dummy torch prediction
        prediction = torch.Tensor(*self.shape).uniform_(0, 1)
        prediction.requires_grad = True
        masked_prediction, _ = trafo(prediction, target)

        self.assertEqual(prediction.size(), masked_prediction.size())

        # apply a loss to the prediction and check that the
        # gradients in the masked parts are actually zero

        # apply the Sorensen-Dice loss
        criterium = SorensenDiceLoss()
        loss = criterium(masked_prediction, target)
        loss.backward()

        grads = prediction.grad.data.numpy()
        self.assertTrue((grads[ignore_mask] == 0).all())
        self.assertTrue((grads[np.logical_not(ignore_mask)] != 0).all())
Example #5
    def inferno_build_criterion(self):
        print("Building criterion")
        loss_config = self.get('trainer/criterion/losses')

        criterion = SorensenDiceLoss()
        loss_train = LossWrapper(criterion=criterion, transforms=None)
        loss_val = LossWrapper(criterion=criterion, transforms=None)
        self._trainer.build_criterion(loss_train)
        self._trainer.build_validation_criterion(loss_val)
Example #6
    def __init__(self, offsets='default-3D', loss_weights=None, ignore_label=None, margin=0,
                 use_cosine_distance=False, pull_weight=0, push_weight=0, affinity_weight=1,
                 affinities_direct=False, **super_kwargs):
        if callable(offsets):
            self.offset_sampler = offsets
            self.dynamic_offsets = True
            offsets = self.offset_sampler()
            loss_names = None
            loss_weights = 'average'
        else:
            self.offsets = get_offsets(offsets)
            self.dynamic_offsets = False
            if loss_weights is None:
                loss_weights = (1 / len(self.offsets),) * len(self.offsets)
            assert len(loss_weights) == len(self.offsets)
            loss_names = ["offset_" + '_'.join(str(o) for o in off) for off in self.offsets]
            print(loss_names)

        super(LossAffinitiesFromEmbedding, self).__init__(
            loss_weights=loss_weights,
            loss_names=loss_names,
            enable_logging=not self.dynamic_offsets,
            **super_kwargs)
        self.ignore_label = ignore_label
        self.use_cosine_distance = use_cosine_distance
        self.margin = margin
        self.push_weight = push_weight
        self.pull_weight = pull_weight
        self.affinity_weight = affinity_weight

        # initialize distance/affinity generating functions
        self.seg_to_aff = EmbeddingToAffinities(offsets=offsets,
                                                affinity_measure=label_equal_similarity,
                                                pass_offset=False)

        if self.affinity_weight != 0:
            self.emb_to_aff = EmbeddingToAffinities(offsets=offsets,
                                                    affinity_measure=self.affinity_measure,
                                                    pass_offset=True)
            self.aff_loss = SorensenDiceLoss(channelwise=False)

        if self.ignore_label is not None:
            self.seg_to_mask = EmbeddingToAffinities(offsets=offsets,
                                                     affinity_measure=ignore_label_mask_similarity,
                                                     pass_offset=False)

        if self.push_weight != 0 or self.pull_weight != 0:
            self.emb_to_dist = EmbeddingToAffinities(offsets=offsets,
                                                     affinity_measure=self.distance_measure,
                                                     pass_offset=False)

        self.affinities_direct = affinities_direct
        if affinities_direct:
            assert self.pull_weight == self.push_weight == 0 and self.affinity_weight != 0
        self.relu = torch.nn.ReLU()
Example #7
 def inferno_build_criterion(self):
     print("Building criterion")
     # path = self.get("autoencoder/path")
     # loss_kwargs = self.get("trainer/criterion/kwargs")
     # from vaeAffs.models.modified_unet import EncodingLoss, PatchLoss, AffLoss
     from vaeAffs.transforms import ApplyIgnoreMask
     # loss = AffLoss(**loss_kwargs)
     from neurofire.criteria.loss_wrapper import LossWrapper
     loss = LossWrapper(SorensenDiceLoss(), transforms=ApplyIgnoreMask())
     self._trainer.build_criterion(loss)
     self._trainer.build_validation_criterion(loss)
Example #8
    def inferno_build_criterion(self):
        print("Building criterion")
        loss_config = self.get('trainer/criterion/losses')

        criterion = SorensenDiceLoss()
        loss_train = LossWrapper(criterion=criterion,
                                 transforms=Compose(ApplyAndRemoveMask(), InvertTarget()))
        loss_val = LossWrapper(criterion=criterion,
                               transforms=Compose(RemoveSegmentationFromTarget(),
                                                  ApplyAndRemoveMask(), InvertTarget()))
        self._trainer.build_criterion(loss_train)
        self._trainer.build_validation_criterion(loss_val)
Example #9
    def __init__(self):
        # Hyper Parameters
        self.input_size = 48
        self.output_size = 48
        # self.learning_rate = 0.001
        self.learning_rate = 0.005

        # training_ratio = 1.
        #
        # all_images_paths = get_GMIS_dataset(partial=False, type="train")
        # print("Number of ROIs: ", len(all_images_paths))
        # nb_images_in_training = int(len(all_images_paths) * training_ratio)
        # print("Training ROIs: ", nb_images_in_training)

        model_path = os.path.join(
            get_trendytukan_drive_path(),
            "GMIS_predictions/logistic_regression_model/pyT_model_train_2.pkl")
        if os.path.exists(model_path):
            print("Model loaded from file!")
            model = torch.load(model_path)
        else:
            model = LogisticRegression(self.input_size, self.output_size)

        # Loss and Optimizer
        # Softmax is internally computed.
        # Set parameters to be updated.
        # criterion = nn.BCEWithLogitsLoss(reduction='none')

        criterion = SorensenDiceLoss()

        # optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=self.learning_rate,
                                     weight_decay=0.0005)

        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")

        model.to(self.device)

        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
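A sketch of how the model, criterion, optimizer, and device set up above could be used in one training step (the argument names and the sigmoid on the model output are assumptions for illustration, not part of the original class):

    # Hypothetical training step using the attributes built in __init__ above.
    def train_step(self, inputs, targets):
        inputs, targets = inputs.to(self.device), targets.to(self.device)
        self.optimizer.zero_grad()
        predictions = torch.sigmoid(self.model(inputs))  # assumed: SorensenDiceLoss is fed values in [0, 1]
        loss = self.criterion(predictions, targets)
        loss.backward()
        self.optimizer.step()
        return loss.item()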
Example #10
    def __init__(self,
                 SD_weight,
                 topo_weight,
                 topo_SD_weight,
                 enable_logging=True,
                 pretraining=False,
                 weight=None,
                 channelwise=True,
                 eps=1e-6):
        """
        Parameters
        ----------
        :param SD_weight: float
            Controls the weighting of the affinity-channel Sorensen-Dice loss vs. the boundary-branch losses.
        :param topo_weight: float
            Controls the weighting of the topological loss on the boundary branch.
        :param topo_SD_weight: float
            Controls the weighting of the Sorensen-Dice loss on the boundary branch.
        :param pretraining: bool
            If you want to pretrain your model on the affinity branch with the SD loss only, set this flag to True.
        :param weight: torch.FloatTensor or torch.cuda.FloatTensor
            Class weights: Applies only if `channelwise = True`.
        :param channelwise: bool
            Whether to apply the loss channelwise and sum the results (True)
            or to apply it on all channels jointly (False).
        """
        super(TopologicalLoss, self).__init__()
        self.register_buffer('weight', weight)
        self.channelwise = channelwise
        self.eps = eps

        self.SD_weight = SD_weight
        self.topo_weight = topo_weight
        self.topo_SD_weight = topo_SD_weight
        self.pretraining = pretraining

        self.SDLoss = SorensenDiceLoss()
        self.TopoLoss = TopologicalLossFunction.apply

        # initialise logging
        if enable_logging:
            self.log = logging.getLogger()
            self.log.setLevel(logging.INFO)
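The docstring describes three weighted terms; a sketch of how they could be combined, assuming affinity and boundary predictions/targets are passed in separately (an illustration only, not the actual forward of TopologicalLoss, and the argument names are assumptions):

    # Illustrative combination of the weighted terms described in the docstring above.
    def combined_loss(self, aff_pred, aff_target, boundary_pred, boundary_target):
        loss = self.SD_weight * self.SDLoss(aff_pred, aff_target)
        if not self.pretraining:
            loss = loss + self.topo_weight * self.TopoLoss(boundary_pred, boundary_target)
            loss = loss + self.topo_SD_weight * self.SDLoss(boundary_pred, boundary_target)
        return loss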
Example #11
    def test_transition_mask(self):
        from neurofire.criteria.loss_transforms import MaskTransitionToIgnoreLabel
        from neurofire.transform.affinities import Segmentation2Affinities

        offsets = [(1, 0, 0), (0, 1, 0), (0, 0, 1), (9, 0, 0), (0, 9, 0),
                   (0, 0, 9), (9, 4, 0), (4, 9, 0), (9, 0, 9)]
        trafo = MaskTransitionToIgnoreLabel(offsets, ignore_label=0)
        aff_trafo = Segmentation2Affinities(offsets=offsets,
                                            retain_segmentation=True,
                                            retain_mask=False)

        seg = self.make_segmentation_with_ignore(self.shape)
        ignore_mask = self.brute_force_transition_masking(seg, offsets)
        target = torch.from_numpy(aff_trafo(seg.astype('float32'))[None])
        target.requires_grad = False

        tshape = target.size()
        pshape = (tshape[0], tshape[1] - 1) + tshape[2:]
        prediction = torch.Tensor(*pshape).uniform_(0, 1)
        prediction.requires_grad = True
        masked_prediction, target_ = trafo(prediction, target)

        self.assertEqual(masked_prediction.size(), prediction.size())
        self.assertEqual(target.size(), target_.size())
        self.assertTrue(np.allclose(target.data.numpy(), target_.data.numpy()))

        # apply the Sorensen-Dice loss
        criterium = SorensenDiceLoss()
        target = target[:, 1:]
        self.assertEqual(target.size(), masked_prediction.size())
        loss = criterium(masked_prediction, target)
        loss.backward()

        grads = prediction.grad.data.numpy().squeeze()
        self.assertEqual(grads.shape, ignore_mask.shape)
        self.assertTrue((grads[ignore_mask] == 0).all())
        self.assertFalse(np.sum(grads[np.logical_not(ignore_mask)]) == 0)
Example #12
    model_path = os.path.join(
        get_trendytukan_drive_path(),
        "GMIS_predictions/logistic_regression_model/pyT_model_train_2.pkl")
    if os.path.exists(model_path):
        print("Model loaded from file!")
        model = torch.load(model_path)
    else:
        model = GMIS_utils.LogisticRegression(input_size, output_size)

    # Loss and Optimizer
    # Softmax is internally computed.
    # Set parameters to be updated.
    # criterion = nn.BCEWithLogitsLoss(reduction='none')

    criterion = SorensenDiceLoss()

    # optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learning_rate,
                                 weight_decay=0.0005)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model.to(device)

    # TODO: increase batch size! (if it is possible, sizewise...)

    # TODO: improving options
    #  - apply semantic combination to the input affs
    #  - re-create same two-layers network
Example #13
class NevenMaskLoss(WeightedLoss):
    """
    Generalization of loss from
    "Instance Segmentation by Jointly Optimizing Spatial Embeddings and Clustering Bandwidth" by Neven et al.

    Notes on the original code at https://github.com/davyneven/SpatialEmbeddings
        - Model is a multi branched variant with 3 independent decoders for embeddings, sigmas, seeds.
        - Last layer of model is initialized differently: sigmas to 0 weight, 1 bias, seeds to 0.
        - Sigma is made positive by sigma = exp(10 * sigma). Why 10?
        - Clustering:
            - They keep clustering until all but 128 (seems arbitrary) pixels are assigned.

    Ideas for generalization:
        - Free, non-offset embeddings
        - learned sigma -> learned, position dependent similarity measure
            - general gaussians
            - weights of a couple of 1x1 convolutions
    """
    BINARY_LOSSES = {
        'BCE': torch.nn.BCELoss(),
        'Dice': SorensenDiceLoss(channelwise=True),
        'LovaszHinge': lambda pred, label: lovasz_hinge(2 * pred - 1, label, per_image=True),
    }

    def __init__(self,
                 loss_weights=(1, 1, 1),
                 ignore_label=None,
                 log_masks=0,
                 binary_loss='BCE',
                 spatial_dim=3,
                 **super_kwargs):

        super_kwargs = dict(
            loss_weights=loss_weights,
            loss_names=['instance-loss', 'variance-loss', 'seed-loss'],
            **super_kwargs)

        super(NevenMaskLoss, self).__init__(**super_kwargs)
        assert binary_loss in self.BINARY_LOSSES, f'Please select one of {list(self.BINARY_LOSSES)}. Got {binary_loss}.'
        self.mask_comparison_loss = self.BINARY_LOSSES[binary_loss]

        self.ignore_label = ignore_label
        self.log_masks = log_masks
        self.predicted_masks = None
        self.gt_masks = None
        self.spatial_dim = spatial_dim

    def get_losses(self, preds, labels):
        embeddings, sigmas, seed_maps = preds
        gt_segs = labels[0]

        # for logging purposes only
        self.predicted_masks = []
        self.gt_masks = []

        losses = [
            self.get_losses_single_sample(*inputs)
            for inputs in zip(embeddings, sigmas, seed_maps, gt_segs)
        ]
        if self.log_masks:
            assert log_image is not None, f'need speedrun anywhere logging to log masks'
            log_image('pred_instance_masks',
                      torch.stack(self.predicted_masks, dim=1))
            log_image('gt_instance_masks', torch.stack(self.gt_masks, dim=1))
        return torch.stack(losses).mean(0)

    def get_losses_single_sample(self, embedding, sigmas, seed_map, gt_seg):
        instance_ids = gt_seg.unique()
        instance_ids = instance_ids[instance_ids != self.ignore_label]
        losses = [
            self.get_loss_single_instance(instance_id, embedding, sigmas,
                                          seed_map, gt_seg)
            for instance_id in instance_ids
        ]
        losses = torch.stack([torch.stack(loss) for loss in losses])

        # regress seeds to 0 on background
        if self.ignore_label is not None:
            losses[-1] += ((seed_map[gt_seg == self.ignore_label])**2).mean()
        return losses.mean(0)

    def predicted_instance_mask(self, embedding, target_embedding, sigma):
        return torch.exp(-(
            (embedding - target_embedding[(slice(None), ) + self.spatial_dim *
                                          (None, )])**2).sum(0, keepdim=True) /
                         sigma**2)

    def get_loss_single_instance(self, instance_id, embedding, sigmas,
                                 seed_maps, gt_seg):
        instance_mask = (gt_seg == instance_id)[0]  # (1, W, H)

        # the target embedding is the mean of embedding vectors of pixels in instance
        target_embedding = embedding[:, instance_mask].view(
            embedding.shape[0], -1).mean(1)  # (E)

        # we take as sigma the mean predicted sigma over the instance
        sigma = sigmas[:, instance_mask].view(sigmas.shape[0], -1).mean(1)

        predicted_mask = self.predicted_instance_mask(embedding,
                                                      target_embedding, sigma)
        if len(self.predicted_masks) < self.log_masks:
            self.predicted_masks.append(
                predicted_mask.detach()[None])  # B I D H W
            self.gt_masks.append(instance_mask.detach()[None,
                                                        None])  # B I D H W

        instance_loss = self.mask_comparison_loss(predicted_mask,
                                                  instance_mask.float()[None])

        sigma_loss = ((sigmas[:, instance_mask] -
                       sigma.detach().item())**2).sum(0).mean()

        seed_maps_loss = ((seed_maps[:, instance_mask] -
                           predicted_mask[:, instance_mask])**2).sum(0).mean()

        return instance_loss, sigma_loss, seed_maps_loss
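The predicted_instance_mask method above computes the Gaussian mask exp(-||e - mu||^2 / sigma^2) from Neven et al.; a self-contained sketch of the same computation on dummy tensors (the embedding dimension and spatial shape are illustrative):

# Standalone sketch of the Gaussian instance mask used by predicted_instance_mask above.
import torch

E, H, W = 8, 64, 64
embedding = torch.randn(E, H, W)        # per-pixel embedding vectors
target_embedding = torch.randn(E)       # mean embedding of one instance
sigma = torch.rand(1) + 0.5             # positive bandwidth

dist_sq = ((embedding - target_embedding[:, None, None]) ** 2).sum(0, keepdim=True)
predicted_mask = torch.exp(-dist_sq / sigma ** 2)  # shape (1, H, W), values in (0, 1]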
Example #14
    def __init__(self,
                 model,
                 apply_checkerboard=False,
                 loss_type="Dice",
                 ignore_label=0,
                 train_glia_mask=False,
                 boundary_label=None,
                 glia_label=None,
                 train_patches_on_glia=False,
                 fix_bug_multiscale_patches=False,
                 defected_label=None,
                 IoU_loss_kwargs=None,
                 sparse_affs_loss_kwargs=None,
                 indx_trained_patchNets=None,
                 model_kwargs=None,
                 devices=(0, 1)):
        super(PatchBasedLoss, self).__init__()
        if loss_type == "Dice":
            self.loss = SorensenDiceLoss()
        elif loss_type == "MSE":
            self.loss = nn.MSELoss()
        elif loss_type == "BCE":
            self.loss = nn.BCELoss()
        else:
            raise ValueError(f"Unknown loss_type: {loss_type}")

        self.apply_checkerboard = apply_checkerboard
        self.fix_bug_multiscale_patches = fix_bug_multiscale_patches
        self.ignore_label = ignore_label
        self.boundary_label = boundary_label
        self.glia_label = glia_label
        self.defected_label = defected_label
        self.train_glia_mask = train_glia_mask
        self.train_patches_on_glia = train_patches_on_glia
        self.indx_trained_patchNets = indx_trained_patchNets
        self.add_IoU_loss = False
        if IoU_loss_kwargs is not None:
            self.add_IoU_loss = True
            from .compute_IoU import IoULoss
            self.IoU_loss = IoULoss(model,
                                    model_kwargs=model_kwargs,
                                    devices=devices,
                                    **IoU_loss_kwargs)

        self.devices = devices
        self.model_kwargs = model_kwargs
        self.MSE_loss = nn.MSELoss()
        self.smoothL1_loss = nn.SmoothL1Loss()
        # TODO: use nn.BCEWithLogitsLoss()
        self.BCE = nn.BCELoss()
        self.soresen_loss = SorensenDiceLoss()

        from vaeAffs.models.vanilla_vae import VAE_loss
        self.VAE_loss = VAE_loss()

        self.model = model

        self.train_sparse_loss = False
        self.sparse_multilevelDiceLoss = None
        if sparse_affs_loss_kwargs is not None:
            self.train_sparse_loss = True
            self.sparse_multilevelDiceLoss = MultiLevelAffinityLoss(
                model,
                model_kwargs=model_kwargs,
                devices=devices,
                **sparse_affs_loss_kwargs)

        # TODO: hack to adapt to stacked model:
        self.downscale_and_crop_targets = {}
        if hasattr(self.model, "collected_patchNet_kwargs"):
            self.model_kwargs["patchNet_kwargs"] = [
                kwargs for i, kwargs in enumerate(
                    self.model.collected_patchNet_kwargs)
                if i in self.model.trained_patchNets
            ]

            # FIXME: generalize to the non-stacked model (there I also have global in the keys...)
            for nb, kwargs in enumerate(self.model_kwargs["patchNet_kwargs"]):
                if "downscale_and_crop_target" in kwargs:
                    self.downscale_and_crop_targets[nb] = DownsampleAndCrop3D(
                        **kwargs["downscale_and_crop_target"])
Example #15
 def build_fgbg_metric(self):
     self.trainer.register_callback(
         ExtraMetric(LossWrapper(SorensenDiceLoss(channelwise=True),
                                 transforms=self.to_fgbg_loss_input),
                     frequency=self.get('trainer/metric/evaluate_every'),
                     name='error_semantic_dice'))