Example #1
    def test_loss_wrapper_affinity_masking(self):
        from neurofire.criteria.loss_transforms import MaskTransitionToIgnoreLabel
        from neurofire.criteria.loss_transforms import RemoveSegmentationFromTarget
        from neurofire.criteria.loss_wrapper import LossWrapper
        from neurofire.transform.affinities import Segmentation2Affinities

        offsets = [(1, 0, 0), (0, 1, 0), (0, 0, 1), (9, 0, 0), (0, 9, 0),
                   (0, 0, 9), (9, 4, 0), (4, 9, 0), (9, 0, 9)]

        trafos = Compose(MaskTransitionToIgnoreLabel(offsets, ignore_label=0),
                         RemoveSegmentationFromTarget())

        aff_trafo = Segmentation2Affinities(offsets, retain_segmentation=True)

        seg = self.make_segmentation_with_ignore(self.shape)
        ignore_mask = self.brute_force_transition_masking(seg, offsets)
        target = Variable(torch.Tensor(aff_trafo(seg.astype('float32'))[None]),
                          requires_grad=False)

        tshape = target.size()
        pshape = (tshape[0], tshape[1] - 1) + tshape[2:]
        prediction = Variable(torch.Tensor(*pshape).uniform_(0, 1),
                              requires_grad=True)

        # apply binary cross-entropy loss
        criterion = BCELoss()
        # criterion = SorensenDiceLoss()
        wrapper = LossWrapper(criterion, trafos)
        loss = wrapper.forward(prediction, target)
        loss.backward()

        grads = prediction.grad.data.numpy().squeeze()
        self.assertEqual(grads.shape, ignore_mask.shape)
        self.assertTrue((grads[ignore_mask] == 0).all())
        self.assertFalse(np.sum(grads[np.logical_not(ignore_mask)]) == 0)
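The call pattern shared by all of the examples below is a criterion wrapped together with target transforms. A minimal self-contained sketch, assuming neurofire and inferno are installed; the tensor shapes and random target are purely illustrative, and the transforms are assumed to behave as their names suggest (stripping the segmentation channel from the target and inverting it):

import torch
from inferno.io.transform.base import Compose
from inferno.extensions.criteria.set_similarity_measures import SorensenDiceLoss
from neurofire.criteria.loss_wrapper import LossWrapper
from neurofire.criteria.loss_transforms import (RemoveSegmentationFromTarget,
                                                InvertTarget)

# wrap the criterion with transforms applied to (prediction, target) before the loss
wrapper = LossWrapper(criterion=SorensenDiceLoss(),
                      transforms=Compose(RemoveSegmentationFromTarget(),
                                         InvertTarget()))

# prediction: (batch, channels, z, y, x) affinities in [0, 1]
prediction = torch.rand(1, 3, 8, 32, 32, requires_grad=True)
# target: channel 0 carries the raw segmentation, the remaining channels the affinities
target = torch.cat([torch.randint(0, 5, (1, 1, 8, 32, 32)).float(),
                    torch.rand(1, 3, 8, 32, 32)], dim=1)

loss = wrapper(prediction, target)
loss.backward()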
Example #2
def set_up_training(project_directory, config, data_config,
                    load_pretrained_model):
    # Get model
    if load_pretrained_model:
        model = Trainer().load(from_directory=project_directory,
                               filename='Weights/checkpoint.pytorch').model
    else:
        model_name = config.get('model_name')
        model = getattr(models, model_name)(**config.get('model_kwargs'))

    criterion = SorensenDiceLoss()
    loss_train = LossWrapper(criterion=criterion,
                             transforms=Compose(ApplyAndRemoveMask(),
                                                InvertTarget()))
    loss_val = LossWrapper(criterion=criterion,
                           transforms=Compose(RemoveSegmentationFromTarget(),
                                              ApplyAndRemoveMask(),
                                              InvertTarget()))

    # Build trainer and validation metric
    logger.info("Building trainer.")
    smoothness = 0.95

    offsets = data_config['volume_config']['segmentation']['affinity_config'][
        'offsets']
    metric = ArandErrorFromMulticut(average_slices=False,
                                    use_2d_ws=True,
                                    n_threads=8,
                                    weight_edges=True,
                                    offsets=offsets)

    trainer = Trainer(model)\
        .save_every((1000, 'iterations'),
                    to_directory=os.path.join(project_directory, 'Weights'))\
        .build_criterion(loss_train)\
        .build_validation_criterion(loss_val)\
        .build_optimizer(**config.get('training_optimizer_kwargs'))\
        .evaluate_metric_every('never')\
        .validate_every((100, 'iterations'), for_num_iterations=1)\
        .register_callback(SaveAtBestValidationScore(smoothness=smoothness, verbose=True))\
        .build_metric(metric)\
        .register_callback(AutoLR(factor=0.98,
                                  patience='100 iterations',
                                  monitor_while='validating',
                                  monitor_momentum=smoothness,
                                  consider_improvement_with_respect_to='previous'))\
        .register_callback(GarbageCollection())

    logger.info("Building logger.")
    # Build logger
    tensorboard = TensorboardLogger(
        log_scalars_every=(1, 'iteration'),
        log_images_every=(100, 'iterations'),
        log_histograms_every='never').observe_states(
            ['validation_input', 'validation_prediction', 'validation_target'],
            observe_while='validating')

    trainer.build_logger(tensorboard,
                         log_directory=os.path.join(project_directory, 'Logs'))
    return trainer
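For reference, the only keys this set_up_training reads are config['model_name'], config['model_kwargs'], config['training_optimizer_kwargs'] and the affinity offsets nested inside data_config. A purely illustrative sketch of that structure; the model name, kwargs and optimizer settings are placeholders, and the exact optimizer keys depend on what inferno's Trainer.build_optimizer accepts:

config = {
    'model_name': 'UNet3D',                                  # placeholder name in `models`
    'model_kwargs': {'in_channels': 1, 'out_channels': 9},   # placeholder kwargs
    'training_optimizer_kwargs': {'method': 'Adam', 'lr': 1e-4},
}

data_config = {
    'volume_config': {
        'segmentation': {
            'affinity_config': {
                'offsets': [[-1, 0, 0], [0, -1, 0], [0, 0, -1]]
            }
        }
    }
}

trainer = set_up_training('/path/to/project', config, data_config,
                          load_pretrained_model=False)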
Example #3
    def test_loss_wrapper_with_balancing(self):
        from neurofire.criteria.loss_transforms import RemoveSegmentationFromTarget
        from neurofire.criteria.loss_wrapper import LossWrapper, BalanceAffinities
        from neurofire.transform.segmentation import Segmentation2Affinities

        offsets = [(1, 0, 0), (0, 1, 0), (0, 0, 1), (9, 0, 0), (0, 9, 0),
                   (0, 0, 9), (9, 4, 0), (4, 9, 0), (9, 0, 9)]

        trafos = RemoveSegmentationFromTarget()
        balance = BalanceAffinities(ignore_label=0, offsets=offsets)

        aff_trafo = Segmentation2Affinities(offsets, retain_segmentation=True)

        seg = self.make_segmentation_with_ignore(self.shape)
        target = Variable(torch.Tensor(aff_trafo(seg.astype('float32'))[None]),
                          requires_grad=False)

        tshape = target.size()
        pshape = (tshape[0], tshape[1] - 1) + tshape[2:]
        prediction = Variable(torch.Tensor(*pshape).uniform_(0, 1),
                              requires_grad=True)

        # apply weighted MSE loss
        criterion = WeightedMSELoss()
        wrapper = LossWrapper(criterion, trafos, balance)
        loss = wrapper.forward(prediction, target)
        loss.backward()

        grads = prediction.grad.data
        # check for the correct gradient size
        self.assertEqual(grads.size(), prediction.size())
        # check that gradients are not trivial
        self.assertGreater(grads.sum(), 0)
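Note that the wrapper here receives the balancing object as a third positional argument; in the other examples the same slot is passed as weight_function=. Assuming the positional order matches that keyword, the equivalent keyword form of the construction above (reusing the objects defined in this test) would be:

wrapper = LossWrapper(criterion=WeightedMSELoss(),
                      transforms=RemoveSegmentationFromTarget(),
                      weight_function=BalanceAffinities(ignore_label=0, offsets=offsets))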
Example #4
    def inferno_build_criterion(self):
        print("Building criterion")
        loss_config = self.get('trainer/criterion/losses')

        criterion = SorensenDiceLoss()
        loss_train = LossWrapper(criterion=criterion, transforms=None)
        loss_val = LossWrapper(criterion=criterion, transforms=None)
        self._trainer.build_criterion(loss_train)
        self._trainer.build_validation_criterion(loss_val)
Example #5
    def inferno_build_criterion(self):
        print("Building criterion")
        loss_config = self.get('trainer/criterion/losses')

        criterion = loss_config.get('loss')
        transforms = None  # loss_config.get('transforms', None)

        loss_train = LossWrapper(criterion=criterion, transforms=transforms)
        loss_val = LossWrapper(criterion=criterion, transforms=transforms)
        self._trainer.build_criterion(loss_train)
        self._trainer.build_validation_criterion(loss_val)
Example #6
def set_up_training(project_directory, config, data_config):

    # Get model
    model_name = config.get('model_name')
    model = getattr(models, model_name)(**config.get('model_kwargs'))

    criterion = SorensenDiceLoss()
    loss_train = LossWrapper(criterion=criterion, transforms=InvertTarget())
    loss_val = LossWrapper(criterion=criterion,
                           transforms=Compose(RemoveSegmentationFromTarget(),
                                              InvertTarget()))

    # Build trainer and validation metric
    logger.info("Building trainer.")
    smoothness = 0.75

    offsets = data_config['volume_config']['segmentation']['affinity_config'][
        'offsets']
    strides = [1, 10, 10]
    metric = ArandErrorFromMWS(average_slices=False,
                               offsets=offsets,
                               strides=strides,
                               randomize_strides=False)

    trainer = Trainer(model)\
        .save_every((1000, 'iterations'),
                    to_directory=os.path.join(project_directory, 'Weights'))\
        .build_criterion(loss_train)\
        .build_validation_criterion(loss_val)\
        .build_optimizer(**config.get('training_optimizer_kwargs'))\
        .evaluate_metric_every('never')\
        .validate_every((100, 'iterations'), for_num_iterations=1)\
        .register_callback(SaveAtBestValidationScore(smoothness=smoothness,
                                                     verbose=True))\
        .build_metric(metric)\
        .register_callback(AutoLR(factor=0.99,
                                  patience='100 iterations',
                                  monitor_while='validating',
                                  monitor_momentum=smoothness,
                                  consider_improvement_with_respect_to='previous'))

    logger.info("Building logger.")
    # Build logger
    tensorboard = TensorboardLogger(
        log_scalars_every=(1, 'iteration'),
        log_images_every=(100, 'iterations'),
        log_histograms_every='never').observe_states(
            ['validation_input', 'validation_prediction', 'validation_target'],
            observe_while='validating')

    trainer.build_logger(tensorboard,
                         log_directory=os.path.join(project_directory, 'Logs'))
    return trainer
Example #7
    def inferno_build_criterion(self):
        print("Building criterion")
        loss_config = self.get('trainer/criterion/losses')

        criterion = SorensenDiceLoss()
        loss_train = LossWrapper(criterion=criterion,
                                 transforms=Compose(ApplyAndRemoveMask(), InvertTarget()))
        loss_val = LossWrapper(criterion=criterion,
                               transforms=Compose(RemoveSegmentationFromTarget(),
                                                  ApplyAndRemoveMask(), InvertTarget()))
        self._trainer.build_criterion(loss_train)
        self._trainer.build_validation_criterion(loss_val)
Example #8
def set_up_training(project_directory, config, data_config, criterion, balance,
                    load_pretrained_model):
    # Get model
    if load_pretrained_model:
        model = Trainer().load(from_directory=project_directory,
                               filename='Weights/checkpoint.pytorch').model
    else:
        model_name = config.get('model_name')
        model = getattr(models, model_name)(**config.get('model_kwargs'))

    # TODO
    logger.info("Using criterion: %s" % criterion)

    # TODO this should go somewhere more prominent
    affinity_offsets = data_config['volume_config']['segmentation'][
        'affinity_offsets']

    # TODO implement affinities on gpu again ?!
    criterion = CRITERIA[criterion]
    loss = LossWrapper(
        criterion=criterion(),
        transforms=Compose(MaskTransitionToIgnoreLabel(affinity_offsets),
                           RemoveSegmentationFromTarget(), InvertTarget()),
        weight_function=BalanceAffinities(
            ignore_label=0, offsets=affinity_offsets) if balance else None)

    # Build trainer and validation metric
    logger.info("Building trainer.")
    smoothness = 0.95

    # use multicut pipeline for validation
    metric = ArandErrorFromSegmentationPipeline(
        local_affinity_multicut_from_wsdt2d(n_threads=10, time_limit=120))
    trainer = Trainer(model)\
        .save_every((1000, 'iterations'), to_directory=os.path.join(project_directory, 'Weights'))\
        .build_criterion(loss)\
        .build_optimizer(**config.get('training_optimizer_kwargs'))\
        .evaluate_metric_every('never')\
        .validate_every((100, 'iterations'), for_num_iterations=1)\
        .register_callback(SaveAtBestValidationScore(smoothness=smoothness, verbose=True))\
        .build_metric(metric)\
        .register_callback(AutoLR(factor=0.98,
                                  patience='100 iterations',
                                  monitor_while='validating',
                                  monitor_momentum=smoothness,
                                  consider_improvement_with_respect_to='previous'))

    logger.info("Building logger.")
    # Build logger
    tensorboard = TensorboardLogger(
        log_scalars_every=(1, 'iteration'),
        log_images_every=(100, 'iterations')).observe_states(
            ['validation_input', 'validation_prediction', 'validation_target'],
            observe_while='validating')

    trainer.build_logger(tensorboard,
                         log_directory=os.path.join(project_directory, 'Logs'))
    return trainer
Example #9
def dice_loss():
    trafos = [
        SemanticTargetTrafo(class_ids=[1, 2, 3],
                            dtype=torch.float32,
                            ignore_label=-1),
        ApplyAndRemoveMask()
    ]
    trafos = Compose(*trafos)
    return LossWrapper(criterion=SorensenDiceLoss(), transforms=trafos)
Example #10
 def inferno_build_criterion(self):
     print("Building criterion")
     # path = self.get("autoencoder/path")
     # loss_kwargs = self.get("trainer/criterion/kwargs")
     # from vaeAffs.models.modified_unet import EncodingLoss, PatchLoss, AffLoss
     from vaeAffs.transforms import ApplyIgnoreMask
     # loss = AffLoss(**loss_kwargs)
     from neurofire.criteria.loss_wrapper import LossWrapper
     loss = LossWrapper(SorensenDiceLoss(), transforms=ApplyIgnoreMask())
     self._trainer.build_criterion(loss)
     self._trainer.build_validation_criterion(loss)
Example #11
def set_up_training(project_directory, config, data_config,
                    load_pretrained_model):
    # Get model
    if load_pretrained_model:
        model = Trainer().load(from_directory=project_directory,
                               filename='Weights/checkpoint.pytorch').model
    else:
        model_name = config.get('model_name')
        model = getattr(models, model_name)(**config.get('model_kwargs'))

    affinity_offsets = data_config['volume_config']['segmentation'][
        'affinity_offsets']
    loss = LossWrapper(criterion=SorensenDiceLoss(),
                       transforms=Compose(
                           MaskTransitionToIgnoreLabel(affinity_offsets),
                           RemoveSegmentationFromTarget(), InvertTarget()))

    # Build trainer and validation metric
    logger.info("Building trainer.")
    smoothness = 0.95

    # use multicut pipeline for validation
    # metric = ArandErrorFromSegmentationPipeline(local_affinity_multicut_from_wsdt2d(n_threads=10,
    #                                                                                 time_limit=120))

    # use damws for validation
    stride = [2, 10, 10]
    metric = ArandErrorFromSegmentationPipeline(
        DamWatershed(affinity_offsets, stride, randomize_bounds=False))
    trainer = Trainer(model)\
        .save_every((1000, 'iterations'), to_directory=os.path.join(project_directory, 'Weights'))\
        .build_criterion(loss)\
        .build_optimizer(**config.get('training_optimizer_kwargs'))\
        .evaluate_metric_every('never')\
        .validate_every((100, 'iterations'), for_num_iterations=1)\
        .register_callback(SaveAtBestValidationScore(smoothness=smoothness, verbose=True))\
        .build_metric(metric)\
        .register_callback(AutoLR(factor=0.98,
                                  patience='100 iterations',
                                  monitor_while='validating',
                                  monitor_momentum=smoothness,
                                  consider_improvement_with_respect_to='previous'))

    logger.info("Building logger.")
    # Build logger
    tensorboard = TensorboardLogger(
        log_scalars_every=(1, 'iteration'),
        log_images_every=(100, 'iterations')).observe_states(
            ['validation_input', 'validation_prediction', 'validation_target'],
            observe_while='validating')

    trainer.build_logger(tensorboard,
                         log_directory=os.path.join(project_directory, 'Logs'))
    return trainer
Example #12
def dice_loss(is_val=False):
    print("Build Dice loss")
    if is_val:
        trafos = [
            RemoveSegmentationFromTarget(),
            ApplyAndRemoveMask(),
            InvertTarget()
        ]
    else:
        trafos = [ApplyAndRemoveMask(), InvertTarget()]
    trafos = Compose(*trafos)
    return LossWrapper(criterion=SorensenDiceLoss(), transforms=trafos)
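The only difference between the two branches is RemoveSegmentationFromTarget: the validation target still carries the raw segmentation (needed for the validation metric), while the training target does not. A short sketch of how the two variants are typically wired into an inferno Trainer, following the build_criterion / build_validation_criterion pattern of the other examples (the model object is assumed to exist):

trainer = Trainer(model)\
    .build_criterion(dice_loss(is_val=False))\
    .build_validation_criterion(dice_loss(is_val=True))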
Example #13
def set_up_training(project_directory, config, data_config,
                    load_pretrained_model, max_iters):
    # Get model
    if load_pretrained_model:
        model = Trainer().load(from_directory=project_directory,
                               filename='Weights/checkpoint.pytorch').model
    else:
        model_name = config.get('model_name')
        model = getattr(models, model_name)(**config.get('model_kwargs'))

    loss = LossWrapper(criterion=SorensenDiceLoss(),
                       transforms=Compose(MaskIgnoreLabel(),
                                          RemoveSegmentationFromTarget()))
    # TODO loss transforms:
    # - Invert Target ???

    # Build trainer and validation metric
    logger.info("Building trainer.")
    # smoothness = 0.95

    # TODO set up validation ?!
    trainer = Trainer(model)\
        .save_every((1000, 'iterations'), to_directory=os.path.join(project_directory, 'Weights'))\
        .build_criterion(loss)\
        .build_optimizer(**config.get('training_optimizer_kwargs'))\
        .evaluate_metric_every('never')\
        .register_callback(ManualLR(decay_specs=[((k * 100, 'iterations'), 0.99)
                                                 for k in range(1, max_iters // 100)]))
    # .validate_every((100, 'iterations'), for_num_iterations=1)\
    # .register_callback(SaveAtBestValidationScore(smoothness=smoothness, verbose=True))\
    # .build_metric(metric)\
    # .register_callback(AutoLR(factor=0.98,
    #                           patience='100 iterations',
    #                           monitor_while='validating',
    #                           monitor_momentum=smoothness,
    #                           consider_improvement_with_respect_to='previous'))

    logger.info("Building logger.")
    # Build logger
    tensorboard = TensorboardLogger(
        log_scalars_every=(1, 'iteration'),
        log_images_every=(100, 'iterations'))  # .observe_states(
    #     ['validation_input', 'validation_prediction, validation_target'],
    #     observe_while='validating'
    # )

    trainer.build_logger(tensorboard,
                         log_directory=os.path.join(project_directory, 'Logs'))
    return trainer
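The ManualLR callback above is fed an explicit decay schedule. For illustration, with max_iters = 500 the list comprehension expands to a 0.99 learning-rate decay every 100 iterations:

decay_specs = [((100, 'iterations'), 0.99),
               ((200, 'iterations'), 0.99),
               ((300, 'iterations'), 0.99),
               ((400, 'iterations'), 0.99)]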
Example #14
 def parse_and_wrap_losses(self, config, transforms, losses, weights,
                           loss_names):
     default_weight = config.pop('weight', 1)
     for class_name, kwargs in config.items():
         loss_names.append(kwargs.pop('name', class_name))
         weights.append(kwargs.pop('weight', default_weight))
         print(f'Adding {loss_names[-1]} with weight {weights[-1]}')
         loss_class = locate(class_name, [
             'embeddingutils.loss', 'SegTags.loss',
             'inferno.extensions.criteria.set_similarity_measures',
             'torch.nn'
         ])
         if issubclass(loss_class, WeightedLoss):
             kwargs['trainer'] = self.trainer
         losses.append(
             LossWrapper(criterion=loss_class(**kwargs),
                         transforms=transforms))
Example #15
    def _test_maxpool_loss_retain_segmentation(self):
        from neurofire.criteria.loss_wrapper import LossWrapper
        from neurofire.criteria.multi_scale_loss import MultiScaleLossMaxPool
        from neurofire.transform.segmentation import Segmentation2AffinitiesFromOffsets
        from neurofire.criteria.loss_transforms import MaskTransitionToIgnoreLabel
        from neurofire.criteria.loss_transforms import RemoveSegmentationFromTarget

        offsets = [(1, 0, 0), (0, 1, 0), (0, 0, 1), (9, 0, 0), (0, 9, 0),
                   (0, 0, 9)]

        shape = (128, 128, 128)
        aff_trafo = Segmentation2AffinitiesFromOffsets(
            3,
            offsets,
            retain_segmentation=True,
            add_singleton_channel_dimension=True)
        seg = self.make_segmentation_with_ignore(shape)

        target = Variable(torch.Tensor(aff_trafo(seg.astype('float32'))[None]),
                          requires_grad=False)

        tshape = target.size()
        # make all scale predictions
        predictions = []
        for scale in range(4):
            pshape = (tshape[0], tshape[1] - 1) + shape
            predictions.append(
                Variable(torch.Tensor(*pshape).uniform_(0, 1),
                         requires_grad=True))
            shape = tuple(sh // 2 for sh in shape)

        trafos = Compose(MaskTransitionToIgnoreLabel(offsets, ignore_label=0),
                         RemoveSegmentationFromTarget())
        criterion = LossWrapper(SorensenDiceLoss(), trafos)
        ms_loss = MultiScaleLossMaxPool(criterion, 2, retain_segmentation=True)
        loss = ms_loss.forward(predictions, target)
        loss.backward()

        for prediction in predictions:
            grads = prediction.grad.data
            # check for the correct gradient size
            self.assertEqual(grads.size(), prediction.size())
            # check that gradients are not trivial
            self.assertNotEqual(grads.sum(), 0)
Example #16
    def inferno_build_criterion(self):
        print("Building criterion")
        loss_kwargs = self.get("trainer/criterion/kwargs", {})
        # from vaeAffs.models.losses import EncodingLoss, PatchLoss, PatchBasedLoss, StackedAffinityLoss
        loss_name = self.get("trainer/criterion/loss_name",
                             "inferno.extensions.criteria.set_similarity_measures.SorensenDiceLoss")
        loss_config = {loss_name: loss_kwargs}

        criterion = create_instance(loss_config, self.CRITERION_LOCATIONS)
        transforms = self.get("trainer/criterion/transforms")
        if transforms is not None:
            assert isinstance(transforms, list)
            transforms_instances = []
            # Build transforms:
            for transf in transforms:
                transforms_instances.append(create_instance(transf, []))
            # Wrap criterion:
            criterion = LossWrapper(criterion, transforms=Compose(*transforms_instances))

        self._trainer.build_criterion(criterion)
        self._trainer.build_validation_criterion(criterion)
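This variant builds the criterion entirely from the configuration. A hedged illustration of the nested keys it reads (normally stored in a YAML file); the assumption that each transform entry follows the same {class_path: kwargs} convention as loss_config is mine and is not confirmed by the snippet:

trainer_config = {
    'trainer': {
        'criterion': {
            'loss_name': 'inferno.extensions.criteria.set_similarity_measures.SorensenDiceLoss',
            'kwargs': {},
            'transforms': [
                {'neurofire.criteria.loss_transforms.RemoveSegmentationFromTarget': {}},
                {'neurofire.criteria.loss_transforms.InvertTarget': {}},
            ],
        }
    }
}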
Example #17
    def test_maxpool_loss(self):
        from neurofire.criteria.loss_wrapper import LossWrapper
        from neurofire.criteria.multi_scale_loss import MultiScaleLossMaxPool
        from neurofire.transform.segmentation import Segmentation2Affinities

        offsets = [(1, 0, 0), (0, 1, 0), (0, 0, 1), (9, 0, 0), (0, 9, 0),
                   (0, 0, 9)]

        shape = (128, 128, 128)
        aff_trafo = Segmentation2Affinities(offsets, retain_segmentation=False)
        seg = self.make_segmentation_with_ignore(shape)

        target = Variable(torch.Tensor(aff_trafo(seg.astype('float32'))[None]),
                          requires_grad=False)

        tshape = target.size()
        # make all scale predictions
        predictions = []
        for scale in range(4):
            pshape = tuple(tshape[:2]) + shape
            predictions.append(
                Variable(torch.Tensor(*pshape).uniform_(0, 1),
                         requires_grad=True))
            shape = tuple(sh // 2 for sh in shape)

        criterion = LossWrapper(SorensenDiceLoss())
        ms_loss = MultiScaleLossMaxPool(criterion, 2)
        loss = ms_loss.forward(predictions, target)
        loss.backward()

        for prediction in predictions:
            grads = prediction.grad.data
            # check for the correct gradient size
            self.assertEqual(grads.size(), prediction.size())
            # check that gradients are not trivial
            self.assertNotEqual(grads.sum(), 0)
Example #18
train_labels = HDF5VolumeLoader(path='labeled_segmentation.h5', path_in_h5_dataset='data',
                                transforms=tosignedint, **yaml2dict('config_train.yml')['slicing_config'])
trainset = Zip(train_images, train_labels)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCHSIZE,
                                          shuffle=True, num_workers=2)


val_images = HDF5VolumeLoader(path='./val-volume.h5', path_in_h5_dataset='data',
                              **yaml2dict('config_val.yml')['slicing_config'])
val_labels = HDF5VolumeLoader(path='labeled_segmentation_validation.h5', path_in_h5_dataset='data',
                              transforms=tosignedint, **yaml2dict('config_val.yml')['slicing_config'])
valset = Zip(val_images, val_labels)
valloader = torch.utils.data.DataLoader(valset, batch_size=BATCHSIZE,
                                        shuffle=True, num_workers=2)

criterion = LossWrapper(criterion=nn.L1Loss(),
                        transforms=LabelToTarget())


net = torch.nn.Sequential(
    ConvReLU2D(in_channels=1, out_channels=3, kernel_size=3),
    UNet(in_channels=3, out_channels=N_DIRECTIONS, dim=2, final_activation='ReLU')
    )

trainer = Trainer(net)

trainer.bind_loader('train', trainloader)
trainer.bind_loader('validate', valloader)

trainer.save_to_directory('./checkpoints')
trainer.save_every((200, 'iterations'))
trainer.build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'),
                                       log_images_every=(100, 'iterations')))
Example #19
def dice_loss():
    print("Build Dice loss")
    trafos = [ApplyAndRemoveMask(), InvertTarget()]
    trafos = Compose(*trafos)
    return LossWrapper(criterion=SorensenDiceLoss(),
                       transforms=trafos)
Example #20
def set_up_training(project_directory, config):

    # Load the model to train from the configuration file ('./config/train_config.yml')
    model_name = config.get('model_name')
    model = getattr(models, model_name)(**config.get('model_kwargs'))

    # Initialize the loss: we use the SorensenDiceLoss, which has the nice property
    # of being fairly robust to unbalanced targets
    criterion = SorensenDiceLoss()
    # Wrap the loss to apply additional transformations before the actual
    # loss is computed. During training we apply the mask to the target
    # and invert the target (necessary for the Sorensen-Dice loss).
    # During validation the target additionally retains the raw segmentation
    # (needed for the validation metric), so it has to be removed before the loss.
    loss_train = LossWrapper(criterion=criterion,
                             transforms=Compose(ApplyAndRemoveMask(),
                                                InvertTarget()))
    loss_val = LossWrapper(criterion=criterion,
                           transforms=Compose(RemoveSegmentationFromTarget(),
                                              ApplyAndRemoveMask(),
                                              InvertTarget()))

    # Build the validation metric: we validate by running connected components on
    # the affinities for several thresholds
    # metric = ArandErrorFromConnectedComponentsOnAffinities(thresholds=[.5, .6, .7, .8, .9],
    #                                                        invert_affinities=True)
    metric = ArandErrorFromConnectedComponents(thresholds=[.5, .6, .7, .8, .9],
                                               invert_input=True,
                                               average_input=True)

    logger.info("Building trainer.")
    smoothness = 0.95
    # Build the trainer object
    trainer = Trainer(model)\
        .save_every((1000, 'iterations'), to_directory=os.path.join(project_directory, 'Weights'))\
        .build_criterion(loss_train)\
        .build_validation_criterion(loss_val)\
        .build_optimizer(**config.get('training_optimizer_kwargs'))\
        .evaluate_metric_every('never')\
        .validate_every((100, 'iterations'), for_num_iterations=1)\
        .register_callback(SaveAtBestValidationScore(smoothness=smoothness, verbose=True))\
        .build_metric(metric)\
        .register_callback(AutoLR(factor=0.98,
                                  patience='100 iterations',
                                  monitor_while='validating',
                                  monitor_momentum=smoothness,
                                  consider_improvement_with_respect_to='previous'))
    # .register_callback(DumpHDF5Every(frequency='99 iterations',
    #                                  to_directory=os.path.join(project_directory, 'debug')))

    logger.info("Building logger.")
    # Build tensorboard logger
    tensorboard = TensorboardLogger(
        log_scalars_every=(1, 'iteration'),
        log_images_every=(100, 'iterations')).observe_states(
            ['validation_input', 'validation_prediction', 'validation_target'],
            observe_while='validating')

    trainer.build_logger(tensorboard,
                         log_directory=os.path.join(project_directory, 'Logs'))
    return trainer
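A small aside on InvertTarget in this setup: affinity targets are close to 1 inside objects and 0 across boundaries, so the rare class is the boundary transitions. Inverting the target (assumed here to mean 1 - target) turns those transitions into the positive class for the Dice overlap, which is consistent with the validation metric above being run with invert_input=True. A toy illustration:

import torch

target = torch.tensor([1., 1., 0., 1., 1.])   # toy affinity row, 0 marks a boundary
inverted = 1. - target                         # boundary transitions become the positive class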
Example #21
 def build_fgbg_metric(self):
     self.trainer.register_callback(
         ExtraMetric(LossWrapper(SorensenDiceLoss(channelwise=True),
                                 transforms=self.to_fgbg_loss_input),
                     frequency=self.get('trainer/metric/evaluate_every'),
                     name='error_semantic_dice'))