def set_up_training(project_directory, config, data_config, load_pretrained_model):
    # Get model
    if load_pretrained_model:
        model = Trainer().load(from_directory=project_directory,
                               filename='Weights/checkpoint.pytorch').model
    else:
        model_name = config.get('model_name')
        model = getattr(models, model_name)(**config.get('model_kwargs'))

    criterion = SorensenDiceLoss()
    loss_train = LossWrapper(criterion=criterion,
                             transforms=Compose(ApplyAndRemoveMask(), InvertTarget()))
    loss_val = LossWrapper(criterion=criterion,
                           transforms=Compose(RemoveSegmentationFromTarget(),
                                              ApplyAndRemoveMask(),
                                              InvertTarget()))

    # Build trainer and validation metric
    logger.info("Building trainer.")
    smoothness = 0.95
    offsets = data_config['volume_config']['segmentation']['affinity_config']['offsets']
    metric = ArandErrorFromMulticut(average_slices=False,
                                    use_2d_ws=True,
                                    n_threads=8,
                                    weight_edges=True,
                                    offsets=offsets)

    trainer = Trainer(model)\
        .save_every((1000, 'iterations'),
                    to_directory=os.path.join(project_directory, 'Weights'))\
        .build_criterion(loss_train)\
        .build_validation_criterion(loss_val)\
        .build_optimizer(**config.get('training_optimizer_kwargs'))\
        .evaluate_metric_every('never')\
        .validate_every((100, 'iterations'), for_num_iterations=1)\
        .register_callback(SaveAtBestValidationScore(smoothness=smoothness, verbose=True))\
        .build_metric(metric)\
        .register_callback(AutoLR(factor=0.98,
                                  patience='100 iterations',
                                  monitor_while='validating',
                                  monitor_momentum=smoothness,
                                  consider_improvement_with_respect_to='previous'))\
        .register_callback(GarbageCollection())

    logger.info("Building logger.")
    # Build logger
    tensorboard = TensorboardLogger(log_scalars_every=(1, 'iteration'),
                                    log_images_every=(100, 'iterations'),
                                    log_histograms_every='never')\
        .observe_states(['validation_input', 'validation_prediction', 'validation_target'],
                        observe_while='validating')
    trainer.build_logger(tensorboard,
                         log_directory=os.path.join(project_directory, 'Logs'))
    return trainer

def dice_loss():
    trafos = [
        SemanticTargetTrafo(class_ids=[1, 2, 3],
                            dtype=torch.float32,
                            ignore_label=-1),
        ApplyAndRemoveMask()
    ]
    trafos = Compose(*trafos)
    return LossWrapper(criterion=SorensenDiceLoss(), transforms=trafos)

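# Usage sketch (an assumption, not part of the original code): the LossWrapper
# returned by dice_loss() plugs into an inferno Trainer via build_criterion, in the
# same way the set_up_training functions in this file wire up their losses.
# `model`, `train_loader`, `save_dir` and the optimizer settings are hypothetical
# placeholders; module-level imports (e.g. of Trainer) are assumed, as elsewhere here.
def example_train_with_dice_loss(model, train_loader, save_dir, max_iters=10000):
    trainer = Trainer(model)\
        .build_criterion(dice_loss())\
        .build_optimizer('Adam', lr=1e-4)\
        .save_every((1000, 'iterations'), to_directory=save_dir)
    trainer.set_max_num_iterations(max_iters)
    trainer.bind_loader('train', train_loader)
    trainer.fit()
    return trainer
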
def set_up_training(project_directory, config, data_config, load_pretrained_model):
    # Get model
    if load_pretrained_model:
        model = Trainer().load(from_directory=project_directory,
                               filename='Weights/checkpoint.pytorch').model
    else:
        model_name = config.get('model_name')
        model = getattr(models, model_name)(**config.get('model_kwargs'))

    affinity_offsets = data_config['volume_config']['segmentation']['affinity_offsets']
    loss = MultiOutputLossWrapper(criterion=SorensenDiceLoss(),
                                  transforms=Compose(MaskTransitionToIgnoreLabel(affinity_offsets),
                                                     RemoveSegmentationFromTarget(),
                                                     InvertTarget()))

    # Build trainer and validation metric
    logger.info("Building trainer.")
    smoothness = 0.95

    # use multicut pipeline for validation
    # metric = ArandErrorFromSegmentationPipeline(local_affinity_multicut_from_wsdt2d(n_threads=10,
    #                                                                                 time_limit=120))
    # use damws for validation
    stride = [2, 10, 10]
    metric = ArandErrorFromSegmentationPipeline(DamWatershed(affinity_offsets,
                                                             stride,
                                                             randomize_bounds=False))

    trainer = Trainer(model)\
        .save_every((1000, 'iterations'),
                    to_directory=os.path.join(project_directory, 'Weights'))\
        .build_criterion(loss)\
        .build_optimizer(**config.get('training_optimizer_kwargs'))\
        .evaluate_metric_every('never')\
        .validate_every((100, 'iterations'), for_num_iterations=1)\
        .register_callback(SaveAtBestValidationScore(smoothness=smoothness, verbose=True))\
        .build_metric(metric)\
        .register_callback(AutoLR(factor=0.98,
                                  patience='100 iterations',
                                  monitor_while='validating',
                                  monitor_momentum=smoothness,
                                  consider_improvement_with_respect_to='previous'))

    # FIXME some issues with conda tf for torch 0.3 env
    # logger.info("Building logger.")
    # # Build logger
    # tensorboard = TensorboardLogger(log_scalars_every=(1, 'iteration'),
    #                                 log_images_every=(100, 'iterations')).observe_states(
    #     ['validation_input', 'validation_prediction', 'validation_target'],
    #     observe_while='validating'
    # )
    # trainer.build_logger(tensorboard, log_directory=os.path.join(project_directory, 'Logs'))
    return trainer

def set_up_training(project_directory, config, data_config):
    # Get model
    model_name = config.get('model_name')
    model = getattr(models, model_name)(**config.get('model_kwargs'))

    criterion = SorensenDiceLoss()
    loss_train = LossWrapper(criterion=criterion,
                             transforms=InvertTarget())
    loss_val = LossWrapper(criterion=criterion,
                           transforms=Compose(RemoveSegmentationFromTarget(),
                                              InvertTarget()))

    # Build trainer and validation metric
    logger.info("Building trainer.")
    smoothness = 0.75
    offsets = data_config['volume_config']['segmentation']['affinity_config']['offsets']
    strides = [1, 10, 10]
    metric = ArandErrorFromMWS(average_slices=False,
                               offsets=offsets,
                               strides=strides,
                               randomize_strides=False)

    trainer = Trainer(model)\
        .save_every((1000, 'iterations'),
                    to_directory=os.path.join(project_directory, 'Weights'))\
        .build_criterion(loss_train)\
        .build_validation_criterion(loss_val)\
        .build_optimizer(**config.get('training_optimizer_kwargs'))\
        .evaluate_metric_every('never')\
        .validate_every((100, 'iterations'), for_num_iterations=1)\
        .register_callback(SaveAtBestValidationScore(smoothness=smoothness, verbose=True))\
        .build_metric(metric)\
        .register_callback(AutoLR(factor=0.99,
                                  patience='100 iterations',
                                  monitor_while='validating',
                                  monitor_momentum=smoothness,
                                  consider_improvement_with_respect_to='previous'))

    logger.info("Building logger.")
    # Build logger
    tensorboard = TensorboardLogger(log_scalars_every=(1, 'iteration'),
                                    log_images_every=(100, 'iterations'),
                                    log_histograms_every='never')\
        .observe_states(['validation_input', 'validation_prediction', 'validation_target'],
                        observe_while='validating')
    trainer.build_logger(tensorboard,
                         log_directory=os.path.join(project_directory, 'Logs'))
    return trainer

def dice_loss(is_val=False):
    print("Build Dice loss")
    if is_val:
        trafos = [
            RemoveSegmentationFromTarget(),
            ApplyAndRemoveMask(),
            InvertTarget()
        ]
    else:
        trafos = [ApplyAndRemoveMask(), InvertTarget()]
    trafos = Compose(*trafos)
    return LossWrapper(criterion=SorensenDiceLoss(), transforms=trafos)

def set_up_training(project_directory, config, data_config, load_pretrained_model, max_iters):
    # Get model
    if load_pretrained_model:
        model = Trainer().load(from_directory=project_directory,
                               filename='Weights/checkpoint.pytorch').model
    else:
        model_name = config.get('model_name')
        model = getattr(models, model_name)(**config.get('model_kwargs'))

    loss = LossWrapper(criterion=SorensenDiceLoss(),
                       transforms=Compose(MaskIgnoreLabel(),
                                          RemoveSegmentationFromTarget()))
    # TODO loss transforms:
    # - Invert Target ???

    # Build trainer and validation metric
    logger.info("Building trainer.")
    # smoothness = 0.95

    # TODO set up validation ?!
    trainer = Trainer(model)\
        .save_every((1000, 'iterations'),
                    to_directory=os.path.join(project_directory, 'Weights'))\
        .build_criterion(loss)\
        .build_optimizer(**config.get('training_optimizer_kwargs'))\
        .evaluate_metric_every('never')\
        .register_callback(ManualLR(decay_specs=[((k * 100, 'iterations'), 0.99)
                                                 for k in range(1, max_iters // 100)]))
    # .validate_every((100, 'iterations'), for_num_iterations=1)\
    # .register_callback(SaveAtBestValidationScore(smoothness=smoothness, verbose=True))\
    # .build_metric(metric)\
    # .register_callback(AutoLR(factor=0.98,
    #                           patience='100 iterations',
    #                           monitor_while='validating',
    #                           monitor_momentum=smoothness,
    #                           consider_improvement_with_respect_to='previous'))

    logger.info("Building logger.")
    # Build logger
    tensorboard = TensorboardLogger(log_scalars_every=(1, 'iteration'),
                                    log_images_every=(100, 'iterations'))
    # .observe_states(
    #     ['validation_input', 'validation_prediction', 'validation_target'],
    #     observe_while='validating'
    # )
    trainer.build_logger(tensorboard,
                         log_directory=os.path.join(project_directory, 'Logs'))
    return trainer

def _test_maxpool_loss_retain_segmentation(self):
    from neurofire.criteria.loss_wrapper import LossWrapper
    from neurofire.criteria.multi_scale_loss import MultiScaleLossMaxPool
    from neurofire.transform.segmentation import Segmentation2AffinitiesFromOffsets
    from neurofire.criteria.loss_transforms import MaskTransitionToIgnoreLabel
    from neurofire.criteria.loss_transforms import RemoveSegmentationFromTarget

    offsets = [(1, 0, 0), (0, 1, 0), (0, 0, 1),
               (9, 0, 0), (0, 9, 0), (0, 0, 9)]
    shape = (128, 128, 128)
    aff_trafo = Segmentation2AffinitiesFromOffsets(3, offsets,
                                                   retain_segmentation=True,
                                                   add_singleton_channel_dimension=True)
    seg = self.make_segmentation_with_ignore(shape)
    target = Variable(torch.Tensor(aff_trafo(seg.astype('float32'))[None]),
                      requires_grad=False)
    tshape = target.size()

    # make all scale predictions
    predictions = []
    for scale in range(4):
        pshape = (tshape[0], tshape[1] - 1) + shape
        predictions.append(Variable(torch.Tensor(*pshape).uniform_(0, 1),
                                    requires_grad=True))
        shape = tuple(sh // 2 for sh in shape)

    trafos = Compose(MaskTransitionToIgnoreLabel(offsets, ignore_label=0),
                     RemoveSegmentationFromTarget())
    criterion = LossWrapper(SorensenDiceLoss(), trafos)
    ms_loss = MultiScaleLossMaxPool(criterion, 2, retain_segmentation=True)
    loss = ms_loss.forward(predictions, target)
    loss.backward()

    for prediction in predictions:
        grads = prediction.grad.data
        # check for the correct gradient size
        self.assertEqual(grads.size(), prediction.size())
        # check that gradients are not trivial
        self.assertNotEqual(grads.sum(), 0)

def set_up_training(project_directory, config, data_config, n_iters):
    model_name = config.get('model_name')
    model = getattr(models, model_name)(**config.get('model_kwargs'))
    loss = SorensenDiceLoss()

    # Build trainer and validation metric
    logger.info("Building trainer.")
    log_file = os.path.join(project_directory, 'tmp_log.txt')
    trainer = Trainer(model)\
        .save_every((n_iters, 'iterations'),
                    to_directory=os.path.join(project_directory, 'Weights'))\
        .build_criterion(loss)\
        .build_optimizer(**config.get('training_optimizer_kwargs'))\
        .evaluate_metric_every('never')\
        .register_callback(TimeTrainingIters(log_file))
    return trainer

def __init__(self, n_channels, n_directions, weights=[1., 1.], log=True,
             exclude_borders=[0, 0, 0], max_dist=100):
    """
    :param n_channels:
    :param n_directions:
    :param weights: first entry weights the Sorensen Dice loss, the second the L1 loss
    :param log:
    """
    super().__init__()
    self.n_channels = n_channels
    self.n_directions = n_directions
    self.l1 = nn.L1Loss()
    self.sd = SorensenDiceLoss()
    self.weights = weights
    self.log = log
    self.exclude_borders = exclude_borders
    self.max_dist = max_dist
    self.log_counter = 0

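# Sketch of a matching forward (an assumption; the original forward is not shown in
# this snippet): combine the two criteria with self.weights, the first entry scaling
# the Sorensen Dice term on the semantic channels and the second the L1 term on the
# direction channels. The channel split is hypothetical, and the border-exclusion /
# max_dist handling is omitted.
def forward(self, prediction, target):
    # Dice on the first n_channels channels, L1 on the remaining direction channels
    dice_term = self.sd(prediction[:, :self.n_channels], target[:, :self.n_channels])
    l1_term = self.l1(prediction[:, self.n_channels:], target[:, self.n_channels:])
    return self.weights[0] * dice_term + self.weights[1] * l1_term
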
def test_maxpool_loss(self):
    from neurofire.criteria.loss_wrapper import LossWrapper
    from neurofire.criteria.multi_scale_loss import MultiScaleLossMaxPool
    from neurofire.transform.segmentation import Segmentation2Affinities

    offsets = [(1, 0, 0), (0, 1, 0), (0, 0, 1),
               (9, 0, 0), (0, 9, 0), (0, 0, 9)]
    shape = (128, 128, 128)
    aff_trafo = Segmentation2Affinities(offsets, retain_segmentation=False)
    seg = self.make_segmentation_with_ignore(shape)
    target = Variable(torch.Tensor(aff_trafo(seg.astype('float32'))[None]),
                      requires_grad=False)
    tshape = target.size()

    # make all scale predictions
    predictions = []
    for scale in range(4):
        pshape = tuple(tshape[:2]) + shape
        predictions.append(Variable(torch.Tensor(*pshape).uniform_(0, 1),
                                    requires_grad=True))
        shape = tuple(sh // 2 for sh in shape)

    criterion = LossWrapper(SorensenDiceLoss())
    ms_loss = MultiScaleLossMaxPool(criterion, 2)
    loss = ms_loss.forward(predictions, target)
    loss.backward()

    for prediction in predictions:
        grads = prediction.grad.data
        # check for the correct gradient size
        self.assertEqual(grads.size(), prediction.size())
        # check that gradients are not trivial
        self.assertNotEqual(grads.sum(), 0)

def __init__(self, alpha=1., beta=1.):
    super().__init__()
    self.alpha = alpha
    self.beta = beta
    self.bce = nn.BCELoss()
    self.dice = SorensenDiceLoss()

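# Sketch of the corresponding forward (an assumption; not part of the original
# snippet): a weighted sum of the two criteria, with alpha scaling the BCE term
# and beta the Sorensen Dice term.
def forward(self, prediction, target):
    return self.alpha * self.bce(prediction, target) \
        + self.beta * self.dice(prediction, target)
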
def set_up_training(project_directory, config, data_config, load_pretrained_model):
    # Get model
    if load_pretrained_model:
        model = Trainer().load(from_directory=project_directory,
                               filename='Weights/checkpoint.pytorch').model
    else:
        model_name = config.get('model_name')
        model = getattr(models, model_name)(**config.get('model_kwargs'))

    affinity_offsets = data_config['volume_config']['segmentation']['affinity_offsets']
    # NOTE invert target is done in the multiscale loss
    loss = LossWrapper(criterion=SorensenDiceLoss(),
                       transforms=Compose(MaskTransitionToIgnoreLabel(affinity_offsets),
                                          RemoveSegmentationFromTarget()))
    scaling_factors = 3 * [(1, 3, 3)]
    multiscale_loss = MultiScaleLossMaxPool(loss,
                                            scaling_factors,
                                            invert_target=True,
                                            retain_segmentation=True)

    # Build trainer and validation metric
    logger.info("Building trainer.")
    smoothness = 0.95

    # use multicut pipeline for validation
    # TODO fix nifty weighting schemes
    metric = ArandErrorFromSegmentationPipeline(
        local_affinity_multicut_from_wsdt2d(n_threads=10,
                                            weighting_scheme=None,
                                            time_limit=120),
        is_multiscale=True)

    trainer = Trainer(model)\
        .save_every((1000, 'iterations'),
                    to_directory=os.path.join(project_directory, 'Weights'))\
        .build_criterion(multiscale_loss)\
        .build_optimizer(**config.get('training_optimizer_kwargs'))\
        .evaluate_metric_every('never')\
        .validate_every((100, 'iterations'), for_num_iterations=1)\
        .register_callback(SaveAtBestValidationScore(smoothness=smoothness, verbose=True))\
        .build_metric(metric)\
        .register_callback(AutoLR(factor=0.98,
                                  patience='100 iterations',
                                  monitor_while='validating',
                                  monitor_momentum=smoothness,
                                  consider_improvement_with_respect_to='previous'))

    logger.info("Building logger.")
    # Build logger
    tensorboard = TensorboardLogger(log_scalars_every=(1, 'iteration'),
                                    log_images_every=(100, 'iterations'))\
        .observe_states(['validation_input', 'validation_prediction', 'validation_target'],
                        observe_while='validating')
    trainer.build_logger(tensorboard,
                         log_directory=os.path.join(project_directory, 'Logs'))
    return trainer

def __init__(self):
    super().__init__()
    self.criterion = SorensenDiceLoss()

def set_up_training(project_directory, config):
    # Load the model to train from the configuration file ('./config/train_config.yml')
    model_name = config.get('model_name')
    model = getattr(models, model_name)(**config.get('model_kwargs'))

    # Initialize the loss: we use the SorensenDiceLoss, which has the nice property
    # of being fairly robust for unbalanced targets
    criterion = SorensenDiceLoss()

    # Wrap the loss to apply additional transformations before the actual loss is
    # applied. During training, we apply the mask to the target and invert the target
    # (necessary for the Sorensen Dice loss). During validation, the target
    # additionally carries the segmentation (needed for the validation metric), so it
    # must be removed before the loss is applied.
    loss_train = LossWrapper(criterion=criterion,
                             transforms=Compose(ApplyAndRemoveMask(), InvertTarget()))
    loss_val = LossWrapper(criterion=criterion,
                           transforms=Compose(RemoveSegmentationFromTarget(),
                                              ApplyAndRemoveMask(),
                                              InvertTarget()))

    # Build the validation metric: we validate by running connected components on
    # the affinities for several thresholds
    # metric = ArandErrorFromConnectedComponentsOnAffinities(thresholds=[.5, .6, .7, .8, .9],
    #                                                        invert_affinities=True)
    metric = ArandErrorFromConnectedComponents(thresholds=[.5, .6, .7, .8, .9],
                                               invert_input=True,
                                               average_input=True)

    logger.info("Building trainer.")
    smoothness = 0.95

    # Build the trainer object
    trainer = Trainer(model)\
        .save_every((1000, 'iterations'),
                    to_directory=os.path.join(project_directory, 'Weights'))\
        .build_criterion(loss_train)\
        .build_validation_criterion(loss_val)\
        .build_optimizer(**config.get('training_optimizer_kwargs'))\
        .evaluate_metric_every('never')\
        .validate_every((100, 'iterations'), for_num_iterations=1)\
        .register_callback(SaveAtBestValidationScore(smoothness=smoothness, verbose=True))\
        .build_metric(metric)\
        .register_callback(AutoLR(factor=0.98,
                                  patience='100 iterations',
                                  monitor_while='validating',
                                  monitor_momentum=smoothness,
                                  consider_improvement_with_respect_to='previous'))
    # .register_callback(DumpHDF5Every(frequency='99 iterations',
    #                                  to_directory=os.path.join(project_directory, 'debug')))

    logger.info("Building logger.")
    # Build tensorboard logger
    tensorboard = TensorboardLogger(log_scalars_every=(1, 'iteration'),
                                    log_images_every=(100, 'iterations'))\
        .observe_states(['validation_input', 'validation_prediction', 'validation_target'],
                        observe_while='validating')
    trainer.build_logger(tensorboard,
                         log_directory=os.path.join(project_directory, 'Logs'))
    return trainer

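# End-to-end usage sketch (an assumption, not from the original source): load the
# train config referenced above and run the returned trainer. The loader arguments
# and `max_iters` are hypothetical; module-level `import yaml` and `import torch`
# are assumed, in addition to the imports the snippets above rely on.
def example_main(project_directory, train_loader, validation_loader, max_iters=100000):
    with open('./config/train_config.yml') as f:
        config = yaml.safe_load(f)
    trainer = set_up_training(project_directory, config)
    trainer.set_max_num_iterations(max_iters)
    trainer.bind_loader('train', train_loader)
    trainer.bind_loader('validate', validation_loader)
    if torch.cuda.is_available():
        trainer.cuda()
    trainer.fit()
    return trainer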