Example #1
def _create_trainer(config, model, optimizer, lr_scheduler, loss_criterion, eval_criterion, loaders, logger):
    assert 'trainer' in config, 'Could not find trainer configuration'
    trainer_config = config['trainer']

    resume = trainer_config.get('resume', None)
    pre_trained = trainer_config.get('pre_trained', None)

    if resume is not None:
        # continue training from a given checkpoint
        return UNet3DTrainer.from_checkpoint(resume, model,
                                             optimizer, lr_scheduler, loss_criterion,
                                             eval_criterion, loaders,
                                             logger=logger)
    elif pre_trained is not None:
        # fine-tune a given pre-trained model
        return UNet3DTrainer.from_pretrained(pre_trained, model, optimizer, lr_scheduler, loss_criterion,
                                             eval_criterion, device=config['device'], loaders=loaders,
                                             max_num_epochs=trainer_config['epochs'],
                                             max_num_iterations=trainer_config['iters'],
                                             validate_after_iters=trainer_config['validate_after_iters'],
                                             log_after_iters=trainer_config['log_after_iters'],
                                             eval_score_higher_is_better=trainer_config['eval_score_higher_is_better'],
                                             logger=logger)
    else:
        # start training from scratch
        return UNet3DTrainer(model, optimizer, lr_scheduler, loss_criterion, eval_criterion,
                             config['device'], loaders, trainer_config['checkpoint_dir'],
                             max_num_epochs=trainer_config['epochs'],
                             max_num_iterations=trainer_config['iters'],
                             validate_after_iters=trainer_config['validate_after_iters'],
                             log_after_iters=trainer_config['log_after_iters'],
                             eval_score_higher_is_better=trainer_config['eval_score_higher_is_better'],
                             logger=logger)
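
A minimal configuration accepted by this factory can be read off the keys it accesses. The sketch below is illustrative only; the values, and the model/loader/optimizer sections consumed elsewhere in the pipeline, are assumptions rather than defaults shipped with the trainer:

# Hypothetical config sketch for _create_trainer; every value is illustrative.
import torch

config = {
    'device': torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
    'trainer': {
        'resume': None,          # set to a checkpoint path to take the first branch
        'pre_trained': None,     # set to a model file to fine-tune instead
        'checkpoint_dir': 'checkpoints/',
        'epochs': 100,
        'iters': 150000,
        'validate_after_iters': 1000,
        'log_after_iters': 500,
        'eval_score_higher_is_better': True,
    }
}
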
Example #2
    def _train_save_load(self, tmpdir, loss, val_metric, model='UNet3D', max_num_epochs=1, log_after_iters=2,
                         validate_after_iters=2, max_num_iterations=4, weight_map=False):
        binary_loss = loss in ['BCEWithLogitsLoss', 'DiceLoss', 'GeneralizedDiceLoss']

        device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

        test_config = copy.deepcopy(CONFIG_BASE)
        test_config['model']['name'] = model
        test_config.update({
            # get device to train on
            'device': device,
            'loss': {'name': loss, 'weight': np.random.rand(2).astype(np.float32)},
            'eval_metric': {'name': val_metric}
        })
        test_config['model']['final_sigmoid'] = binary_loss

        if weight_map:
            test_config['loaders']['weight_internal_path'] = 'weight_map'

        loss_criterion = get_loss_criterion(test_config)
        eval_criterion = get_evaluation_metric(test_config)
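        # note: 'model' is rebound below from the model-name string to the instantiated network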
        model = get_model(test_config)
        model = model.to(device)

        if loss in ['BCEWithLogitsLoss']:
            label_dtype = 'float32'
        else:
            label_dtype = 'long'
        test_config['loaders']['transformer']['train']['label'][0]['dtype'] = label_dtype
        test_config['loaders']['transformer']['test']['label'][0]['dtype'] = label_dtype

        train, val = TestUNet3DTrainer._create_random_dataset((128, 128, 128), (64, 64, 64), binary_loss)
        test_config['loaders']['train_path'] = [train]
        test_config['loaders']['val_path'] = [val]

        loaders = get_train_loaders(test_config)

        optimizer = _create_optimizer(test_config, model)

        test_config['lr_scheduler']['name'] = 'MultiStepLR'
        lr_scheduler = _create_lr_scheduler(test_config, optimizer)

        logger = get_logger('UNet3DTrainer', logging.DEBUG)

        formatter = DefaultTensorboardFormatter()
        trainer = UNet3DTrainer(model, optimizer, lr_scheduler,
                                loss_criterion, eval_criterion,
                                device, loaders, tmpdir,
                                max_num_epochs=max_num_epochs,
                                log_after_iters=log_after_iters,
                                validate_after_iters=validate_after_iters,
                                max_num_iterations=max_num_iterations,
                                logger=logger, tensorboard_formatter=formatter)
        trainer.fit()
        # test loading the trainer from the checkpoint
        trainer = UNet3DTrainer.from_checkpoint(os.path.join(tmpdir, 'last_checkpoint.pytorch'),
                                                model, optimizer, lr_scheduler,
                                                loss_criterion, eval_criterion,
                                                loaders, logger=logger, tensorboard_formatter=formatter)
        return trainer
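
A parametrized test would drive this helper roughly as follows; the loss name comes from the binary losses checked above, but the metric name, test name, and assertion are assumptions:

    # Hypothetical caller inside the same test class; names are illustrative.
    def test_dice_loss_save_load(self, tmpdir):
        # trains for max_num_iterations, then restores from 'last_checkpoint.pytorch'
        trainer = self._train_save_load(tmpdir, loss='DiceLoss', val_metric='DiceCoefficient')
        assert trainer is not None
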
Example #3
def main():
    parser = _arg_parser()
    logger = get_logger('UNet3DTrainer')
    # Get device to train on
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

    args = parser.parse_args()

    logger.info(args)

    # Create loss criterion
    loss_criterion, final_sigmoid = _get_loss_criterion(args.loss)

    model = _create_model(args.in_channels,
                          args.out_channels,
                          layer_order=args.layer_order,
                          interpolate=args.interpolate,
                          final_sigmoid=final_sigmoid)

    model = model.to(device)

    # Log the number of learnable parameters
    logger.info(
        f'Number of learnable params {get_number_of_learnable_parameters(model)}'
    )

    # Create error criterion
    error_criterion = DiceCoefficient()

    # Get data loaders
    loaders = _get_loaders(args.config_dir, logger)

    # Create the optimizer
    optimizer = _create_optimizer(args, model)

    if args.resume:
        trainer = UNet3DTrainer.from_checkpoint(
            args.resume,
            model,
            optimizer,
            loss_criterion,
            error_criterion,
            loaders,
            validate_after_iters=args.validate_after_iters,
            log_after_iters=args.log_after_iters,
            logger=logger)
    else:
        trainer = UNet3DTrainer(model,
                                optimizer,
                                loss_criterion,
                                error_criterion,
                                device,
                                loaders,
                                args.checkpoint_dir,
                                validate_after_iters=args.validate_after_iters,
                                log_after_iters=args.log_after_iters,
                                logger=logger)

    trainer.fit()
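
The attribute accesses above pin down the CLI surface this main expects; a hedged equivalent namespace (flag spellings and values are assumptions, and _create_optimizer may read further attributes not shown):

# Hypothetical argparse.Namespace mirroring the attributes this main reads.
import argparse

args = argparse.Namespace(
    loss='dice',                  # consumed by _get_loss_criterion
    in_channels=1,
    out_channels=2,
    layer_order='crg',
    interpolate=True,
    config_dir='resources/',      # consumed by _get_loaders
    resume=None,                  # checkpoint path, or None to train from scratch
    checkpoint_dir='checkpoints/',
    validate_after_iters=100,
    log_after_iters=100,
)
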
Example #4
    def _train_save_load(self,
                         tmpdir,
                         loss,
                         max_num_epochs=1,
                         log_after_iters=2,
                         validate_after_iters=2,
                         max_num_iterations=4):
        # get device to train on
        device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
        # conv-relu-groupnorm
        conv_layer_order = 'crg'
        final_sigmoid = loss == 'bce'
        loss_criterion = get_loss_criterion(loss,
                                            final_sigmoid,
                                            weight=torch.rand(2).to(device))
        model = self._create_model(final_sigmoid, conv_layer_order)
        accuracy_criterion = DiceCoefficient()
        channel_per_class = loss == 'bce'
        if loss in ['bce', 'dice']:
            label_dtype = 'float32'
        else:
            label_dtype = 'long'
        pixel_wise_weight = loss == 'pce'
        loaders = self._get_loaders(channel_per_class=channel_per_class,
                                    label_dtype=label_dtype,
                                    pixel_wise_weight=pixel_wise_weight)
        learning_rate = 2e-4
        weight_decay = 0.0001
        optimizer = optim.Adam(model.parameters(),
                               lr=learning_rate,
                               weight_decay=weight_decay)
        logger = get_logger('UNet3DTrainer', logging.DEBUG)
        trainer = UNet3DTrainer(model,
                                optimizer,
                                loss_criterion,
                                accuracy_criterion,
                                device,
                                loaders,
                                tmpdir,
                                max_num_epochs=max_num_epochs,
                                log_after_iters=log_after_iters,
                                validate_after_iters=validate_after_iters,
                                max_num_iterations=max_num_iterations,
                                logger=logger)
        trainer.fit()
        # test loading the trainer from the checkpoint
        trainer = UNet3DTrainer.from_checkpoint(
            os.path.join(tmpdir, 'last_checkpoint.pytorch'),
            model,
            optimizer,
            loss_criterion,
            accuracy_criterion,
            loaders,
            logger=logger)
        return trainer
Example #5
def main():
    logger = get_logger('UNet3DTrainer')

    config = load_config()

    logger.info(config)

    # Create loss criterion
    loss_criterion = get_loss_criterion(config)

    # Create the model
    model = UNet3D(config['in_channels'], config['out_channels'],
                   final_sigmoid=config['final_sigmoid'],
                   init_channel_number=config['init_channel_number'],
                   conv_layer_order=config['layer_order'],
                   interpolate=config['interpolate'])

    model = model.to(config['device'])

    # Log the number of learnable parameters
    logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')

    # Create evaluation metric
    eval_criterion = get_evaluation_metric(config)

    loaders = get_train_loaders(config)

    # Create the optimizer
    optimizer = _create_optimizer(config, model)

    # Create learning rate adjustment strategy
    lr_scheduler = _create_lr_scheduler(config, optimizer)

    if config['resume'] is not None:
        trainer = UNet3DTrainer.from_checkpoint(config['resume'], model,
                                                optimizer, lr_scheduler, loss_criterion,
                                                eval_criterion, loaders,
                                                logger=logger)
    else:
        trainer = UNet3DTrainer(model, optimizer, lr_scheduler, loss_criterion, eval_criterion,
                                config['device'], loaders, config['checkpoint_dir'],
                                max_num_epochs=config['epochs'],
                                max_num_iterations=config['iters'],
                                validate_after_iters=config['validate_after_iters'],
                                log_after_iters=config['log_after_iters'],
                                logger=logger)

    trainer.fit()
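
This variant reads a flat config rather than a nested 'trainer' section. A sketch of the keys it touches directly follows; sub-keys consumed by get_loss_criterion, get_evaluation_metric, get_train_loaders, and the optimizer/scheduler helpers are omitted, and all values are illustrative:

# Hypothetical flat config for this main; values are illustrative assumptions.
import torch

config = {
    'in_channels': 1,
    'out_channels': 2,
    'final_sigmoid': True,
    'init_channel_number': 64,
    'layer_order': 'crg',
    'interpolate': True,
    'device': torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
    'resume': None,               # or a path to 'last_checkpoint.pytorch'
    'checkpoint_dir': 'checkpoints/',
    'epochs': 50,
    'iters': 100000,
    'validate_after_iters': 1000,
    'log_after_iters': 500,
}
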
Example #6
    def test_single_epoch(self, tmpdir, capsys):
        with capsys.disabled():
            # get device to train on
            device = torch.device(
                "cuda:0" if torch.cuda.is_available() else 'cpu')

            conv_layer_order = 'crg'

            loss_criterion, final_sigmoid = DiceLoss(), True

            model = self._load_model(final_sigmoid, conv_layer_order)

            error_criterion = DiceCoefficient()

            loaders = self._get_loaders()

            learning_rate = 1e-4
            weight_decay = 0.0005
            optimizer = optim.Adam(model.parameters(),
                                   lr=learning_rate,
                                   weight_decay=weight_decay)

            logger = get_logger('UNet3DTrainer', logging.DEBUG)
            trainer = UNet3DTrainer(model,
                                    optimizer,
                                    loss_criterion,
                                    error_criterion,
                                    device,
                                    loaders,
                                    tmpdir,
                                    max_num_epochs=1,
                                    log_after_iters=2,
                                    validate_after_iters=2,
                                    logger=logger)

            trainer.fit()

            # test loading the trainer from the checkpoint
            UNet3DTrainer.from_checkpoint(
                os.path.join(tmpdir, 'last_checkpoint.pytorch'),
                model,
                optimizer,
                loss_criterion,
                error_criterion,
                loaders,
                logger=logger)
Example #7
    def _train_save_load(self, tmpdir, loss, val_metric, max_num_epochs=1, log_after_iters=2, validate_after_iters=2,
                         max_num_iterations=4):
        # get device to train on
        device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
        # conv-relu-groupnorm
        conv_layer_order = 'crg'
        final_sigmoid = loss == 'bce'
        loss_criterion = get_loss_criterion(loss, weight=torch.rand(2).to(device))
        eval_criterion = get_evaluation_metric(val_metric)
        model = self._create_model(final_sigmoid, conv_layer_order)
        channel_per_class = loss == 'bce'
        if loss in ['bce']:
            label_dtype = 'float32'
        else:
            label_dtype = 'long'
        pixel_wise_weight = loss == 'pce'

        patch = (32, 64, 64)
        stride = (32, 64, 64)
        train, val = TestUNet3DTrainer._create_random_dataset((128, 128, 128), (64, 64, 64), channel_per_class)
        loaders = get_loaders([train], [val], 'raw', 'label', label_dtype=label_dtype, train_patch=patch,
                              train_stride=stride, val_patch=patch, val_stride=stride, transformer='BaseTransformer',
                              pixel_wise_weight=pixel_wise_weight)

        learning_rate = 2e-4
        weight_decay = 0.0001
        optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        logger = get_logger('UNet3DTrainer', logging.DEBUG)
        trainer = UNet3DTrainer(model, optimizer, loss_criterion,
                                eval_criterion,
                                device, loaders, tmpdir,
                                max_num_epochs=max_num_epochs,
                                log_after_iters=log_after_iters,
                                validate_after_iters=validate_after_iters,
                                max_num_iterations=max_num_iterations,
                                logger=logger)
        trainer.fit()
        # test loading the trainer from the checkpoint
        trainer = UNet3DTrainer.from_checkpoint(
            os.path.join(tmpdir, 'last_checkpoint.pytorch'),
            model, optimizer, loss_criterion, eval_criterion, loaders,
            logger=logger)
        return trainer
Example #8
def _create_trainer(config, model, optimizer, lr_scheduler, loss_criterion, eval_criterion, loaders, logger):
    assert 'trainer' in config, 'Could not find trainer configuration'
    trainer_config = config['trainer']

    skip_train_validation = trainer_config.get('skip_train_validation', False)

    # get tensorboard formatter
    tensorboard_formatter = get_tensorboard_formatter(trainer_config.get('tensorboard_formatter', None))

    # start training from scratch
    return UNet3DTrainer(model, optimizer, lr_scheduler, loss_criterion, eval_criterion,
                         config['device'], loaders, trainer_config['checkpoint_dir'],
                         max_num_epochs=trainer_config['epochs'],
                         max_num_iterations=trainer_config['iters'],
                         validate_after_iters=trainer_config['validate_after_iters'],
                         log_after_iters=trainer_config['log_after_iters'],
                         eval_score_higher_is_better=trainer_config['eval_score_higher_is_better'],
                         logger=logger, tensorboard_formatter=tensorboard_formatter,
                         skip_train_validation=skip_train_validation)
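
Relative to the first example, this factory always trains from scratch and adds two optional knobs; a hedged sketch of the corresponding 'trainer' section (all values illustrative):

# Hypothetical 'trainer' section for this variant; values are illustrative only.
trainer_section = {
    'checkpoint_dir': 'checkpoints/',
    'epochs': 100,
    'iters': 150000,
    'validate_after_iters': 1000,
    'log_after_iters': 500,
    'eval_score_higher_is_better': True,
    # optional: forwarded to get_tensorboard_formatter; None picks the default formatter
    'tensorboard_formatter': None,
    # optional: skip evaluating the metric on training batches
    'skip_train_validation': False,
}
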
Example #9
def main():
    logger = get_logger('UNet3DTrainer')
    # Get device to train on
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

    config = parse_train_config()

    logger.info(config)

    # Create loss criterion
    if config.loss_weight is not None:
        loss_weight = torch.tensor(config.loss_weight)
        loss_weight = loss_weight.to(device)
    else:
        loss_weight = None

    loss_criterion = get_loss_criterion(config.loss, loss_weight,
                                        config.ignore_index)

    model = UNet3D(config.in_channels,
                   config.out_channels,
                   init_channel_number=config.init_channel_number,
                   conv_layer_order=config.layer_order,
                   interpolate=config.interpolate,
                   final_sigmoid=config.final_sigmoid)

    model = model.to(device)

    # Log the number of learnable parameters
    logger.info(
        f'Number of learnable params {get_number_of_learnable_parameters(model)}'
    )

    # Create evaluation metric
    eval_criterion = get_evaluation_metric(config.eval_metric,
                                           ignore_index=config.ignore_index)

    # Get data loaders. If 'bce' or 'dice' loss is used, convert labels to float
    train_path, val_path = config.train_path, config.val_path
    if config.loss in ['bce']:
        label_dtype = 'float32'
    else:
        label_dtype = 'long'

    train_patch = tuple(config.train_patch)
    train_stride = tuple(config.train_stride)
    val_patch = tuple(config.val_patch)
    val_stride = tuple(config.val_stride)

    logger.info(f'Train patch/stride: {train_patch}/{train_stride}')
    logger.info(f'Val patch/stride: {val_patch}/{val_stride}')

    pixel_wise_weight = config.loss == 'pce'
    loaders = get_loaders(train_path,
                          val_path,
                          label_dtype=label_dtype,
                          raw_internal_path=config.raw_internal_path,
                          label_internal_path=config.label_internal_path,
                          train_patch=train_patch,
                          train_stride=train_stride,
                          val_patch=val_patch,
                          val_stride=val_stride,
                          transformer=config.transformer,
                          pixel_wise_weight=pixel_wise_weight,
                          curriculum_learning=config.curriculum,
                          ignore_index=config.ignore_index)

    # Create the optimizer
    optimizer = _create_optimizer(config, model)

    if config.resume:
        trainer = UNet3DTrainer.from_checkpoint(config.resume,
                                                model,
                                                optimizer,
                                                loss_criterion,
                                                eval_criterion,
                                                loaders,
                                                logger=logger)
    else:
        trainer = UNet3DTrainer(
            model,
            optimizer,
            loss_criterion,
            eval_criterion,
            device,
            loaders,
            config.checkpoint_dir,
            max_num_epochs=config.epochs,
            max_num_iterations=config.iters,
            max_patience=config.patience,
            validate_after_iters=config.validate_after_iters,
            log_after_iters=config.log_after_iters,
            logger=logger)

    trainer.fit()
Example #10
def main():
    parser = _arg_parser()
    logger = get_logger('UNet3DTrainer')
    # Get device to train on
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

    args = parser.parse_args()

    logger.info(args)

    # Create loss criterion
    if args.loss_weight is not None:
        loss_weight = torch.tensor(args.loss_weight)
        loss_weight = loss_weight.to(device)
    else:
        loss_weight = None

    loss_criterion, final_sigmoid = _get_loss_criterion(args.loss, loss_weight)

    model = _create_model(args.in_channels, args.out_channels,
                          layer_order=args.layer_order,
                          interpolate=args.interpolate,
                          final_sigmoid=final_sigmoid)

    model = model.to(device)

    # Log the number of learnable parameters
    #logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')

    # Create accuracy metric
    accuracy_criterion = _get_accuracy_criterion(not final_sigmoid)

    # Get data loaders. If 'bce' or 'dice' loss is used, convert labels to float
    train_path, val_path = args.train_path, args.val_path
    if args.loss in ['bce', 'dice']:
        label_dtype = 'float32'
    else:
        label_dtype = 'long'

    train_patch = tuple(args.train_patch)
    train_stride = tuple(args.train_stride)
    val_patch = tuple(args.val_patch)
    val_stride = tuple(args.val_stride)

    #logger.info(f'Train patch/stride: {train_patch}/{train_stride}')
    #logger.info(f'Val patch/stride: {val_patch}/{val_stride}')

    loaders = _get_loaders(train_path, val_path, label_dtype=label_dtype, train_patch=train_patch,
                           train_stride=train_stride, val_patch=val_patch, val_stride=val_stride)

    # Create the optimizer
    optimizer = _create_optimizer(args, model)

    if args.resume:
        trainer = UNet3DTrainer.from_checkpoint(args.resume, model,
                                                optimizer, loss_criterion,
                                                accuracy_criterion, loaders,
                                                logger=logger)
    else:
        trainer = UNet3DTrainer(model, optimizer, loss_criterion,
                                accuracy_criterion,
                                device, loaders, args.checkpoint_dir,
                                max_num_epochs=args.epochs,
                                max_num_iterations=args.iters,
                                max_patience=args.patience,
                                validate_after_iters=args.validate_after_iters,
                                log_after_iters=args.log_after_iters,
                                logger=logger)

    trainer.fit()
Example #11
    def _train_save_load(self,
                         tmpdir,
                         loss,
                         val_metric,
                         max_num_epochs=1,
                         log_after_iters=2,
                         validate_after_iters=2,
                         max_num_iterations=4):
        # conv-relu-groupnorm
        conv_layer_order = 'crg'
        final_sigmoid = loss in ['bce', 'dice']
        device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
        test_config = copy.deepcopy(CONFIG_BASE)  # deep copy: nested sections are mutated below
        test_config.update({
            # get device to train on
            'device': device,
            'loss': {
                'name': loss,
                'weight': np.random.rand(2).astype(np.float32)
            },
            'eval_metric': {
                'name': val_metric
            }
        })
        loss_criterion = get_loss_criterion(test_config)
        eval_criterion = get_evaluation_metric(test_config)
        model = self._create_model(final_sigmoid, conv_layer_order)
        channel_per_class = loss in ['bce', 'dice', 'gdl']
        if loss in ['bce']:
            label_dtype = 'float32'
        else:
            label_dtype = 'long'
        test_config['loaders']['transformer']['train']['label'][0]['dtype'] = label_dtype
        test_config['loaders']['transformer']['test']['label'][0]['dtype'] = label_dtype

        train, val = TestUNet3DTrainer._create_random_dataset(
            (128, 128, 128), (64, 64, 64), channel_per_class)
        test_config['loaders']['train_path'] = [train]
        test_config['loaders']['val_path'] = [val]

        loaders = get_train_loaders(test_config)

        learning_rate = 2e-4
        weight_decay = 0.0001
        optimizer = optim.Adam(model.parameters(),
                               lr=learning_rate,
                               weight_decay=weight_decay)
        lr_scheduler = MultiStepLR(optimizer, milestones=[2, 3], gamma=0.5)
        logger = get_logger('UNet3DTrainer', logging.DEBUG)
        trainer = UNet3DTrainer(model,
                                optimizer,
                                lr_scheduler,
                                loss_criterion,
                                eval_criterion,
                                device,
                                loaders,
                                tmpdir,
                                max_num_epochs=max_num_epochs,
                                log_after_iters=log_after_iters,
                                validate_after_iters=validate_after_iters,
                                max_num_iterations=max_num_iterations,
                                logger=logger)
        trainer.fit()
        # test loading the trainer from the checkpoint
        trainer = UNet3DTrainer.from_checkpoint(
            os.path.join(tmpdir, 'last_checkpoint.pytorch'),
            model,
            optimizer,
            lr_scheduler,
            loss_criterion,
            eval_criterion,
            loaders,
            logger=logger)
        return trainer
Example #12
def main():
    parser = _arg_parser()
    logger = get_logger('Trainer')
    # Get device to train on
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

    args = parser.parse_args()

    if args.loss_weight is not None:
        loss_weight = torch.tensor(args.loss_weight)
        loss_weight = loss_weight.to(device)
    else:
        loss_weight = None

    if args.network == 'cd':
        args.loss = 'mse'
        loss_criterion = get_loss_criterion('mse', loss_weight,
                                            args.ignore_index)

        model = CoorNet(args.in_channels)

        model = model.to(device)

        accuracy_criterion = PrecisionBasedAccuracy(30)

    elif args.network == 'seg':
        if not args.loss:
            raise ValueError("No loss specified; the 'seg' network requires a loss function.")
        loss_criterion = get_loss_criterion(args.loss, loss_weight,
                                            args.ignore_index)

        model = UNet3D(args.in_channels,
                       args.out_channels,
                       init_channel_number=args.init_channel_number,
                       conv_layer_order=args.layer_order,
                       interpolate=True,
                       final_sigmoid=args.final_sigmoid)

        model = model.to(device)

        accuracy_criterion = DiceCoefficient(ignore_index=args.ignore_index)

    else:
        raise ValueError(
            "Incorrect network type defined by the --network argument, either cd or seg."
        )

    # Get data loaders. If 'bce' or 'dice' loss is used, convert labels to float
    train_path = args.train_path
    if args.loss in ['bce', 'mse']:
        label_dtype = 'float32'
    else:
        label_dtype = 'long'

    train_patch = tuple(args.train_patch)
    train_stride = tuple(args.train_stride)

    pixel_wise_weight = args.loss == 'pce'

    loaders = get_loaders(train_path,
                          label_dtype=label_dtype,
                          raw_internal_path=args.raw_internal_path,
                          label_internal_path=args.label_internal_path,
                          train_patch=train_patch,
                          train_stride=train_stride,
                          transformer=args.transformer,
                          pixel_wise_weight=pixel_wise_weight,
                          curriculum_learning=args.curriculum,
                          ignore_index=args.ignore_index)

    # Create the optimizer
    optimizer = _create_optimizer(args, model)

    if args.resume:
        trainer = UNet3DTrainer.from_checkpoint(args.resume,
                                                model,
                                                optimizer,
                                                loss_criterion,
                                                accuracy_criterion,
                                                loaders,
                                                logger=logger)
    else:
        trainer = UNet3DTrainer(model,
                                optimizer,
                                loss_criterion,
                                accuracy_criterion,
                                device,
                                loaders,
                                args.checkpoint_dir,
                                max_num_epochs=args.epochs,
                                max_num_iterations=args.iters,
                                max_patience=args.patience,
                                validate_after_iters=args.validate_after_iters,
                                log_after_iters=args.log_after_iters,
                                logger=logger)

    trainer.fit()
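
Only the --network flag is confirmed by the error message above; an illustrative invocation follows, with every other flag spelling inferred from the attribute names (treat them as assumptions):

# Hypothetical CLI call; only --network is confirmed by the code above.
# python train.py --network seg --loss dice --in_channels 1 --out_channels 2 \
#     --checkpoint_dir checkpoints/ --epochs 50 --iters 100000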