def _create_trainer(config, model, optimizer, lr_scheduler, loss_criterion, eval_criterion, loaders, logger):
    """Instantiate a ``UNet3DTrainer`` according to ``config['trainer']``.

    The trainer section selects one of three start modes, tested in this order:

    * ``resume`` set      -> restore the trainer from that checkpoint;
    * ``pre_trained`` set -> fine-tune starting from those pre-trained weights;
    * neither             -> build a fresh trainer and train from scratch.

    Args:
        config: configuration dict with a ``'trainer'`` section; the fine-tune
            and from-scratch modes also read ``config['device']``.
        model, optimizer, lr_scheduler, loss_criterion, eval_criterion, loaders:
            training components handed to the trainer unchanged.
        logger: logger handed to the trainer.

    Returns:
        The constructed ``UNet3DTrainer``.

    Raises:
        AssertionError: when ``config`` has no ``'trainer'`` section.
        KeyError: when an option read by the selected mode is missing.
    """
    assert 'trainer' in config, 'Could not find trainer configuration'
    opts = config['trainer']
    checkpoint_path = opts.get('resume', None)
    pretrained_path = opts.get('pre_trained', None)

    if checkpoint_path is not None:
        # Continue an earlier run from its saved checkpoint.
        return UNet3DTrainer.from_checkpoint(
            checkpoint_path, model, optimizer, lr_scheduler, loss_criterion,
            eval_criterion, loaders, logger=logger)

    if pretrained_path is not None:
        # Fine-tune: load pre-trained weights, schedule taken from the trainer section.
        return UNet3DTrainer.from_pretrained(
            pretrained_path, model, optimizer, lr_scheduler, loss_criterion,
            eval_criterion,
            device=config['device'],
            loaders=loaders,
            max_num_epochs=opts['epochs'],
            max_num_iterations=opts['iters'],
            validate_after_iters=opts['validate_after_iters'],
            log_after_iters=opts['log_after_iters'],
            eval_score_higher_is_better=opts['eval_score_higher_is_better'],
            logger=logger)

    # Neither a checkpoint nor pre-trained weights were given: train from scratch.
    return UNet3DTrainer(
        model, optimizer, lr_scheduler, loss_criterion, eval_criterion,
        config['device'], loaders, opts['checkpoint_dir'],
        max_num_epochs=opts['epochs'],
        max_num_iterations=opts['iters'],
        validate_after_iters=opts['validate_after_iters'],
        log_after_iters=opts['log_after_iters'],
        eval_score_higher_is_better=opts['eval_score_higher_is_better'],
        logger=logger)
def _train_save_load(self, tmpdir, loss, val_metric, model='UNet3D', max_num_epochs=1, log_after_iters=2,
                     validate_after_iters=2, max_num_iterations=4, weight_map=False):
    """Train briefly on random data, then reload the trainer from its last checkpoint.

    Args:
        tmpdir: directory the trainer writes checkpoints to.
        loss: loss name stored in ``config['loss']['name']``.
        val_metric: metric name stored in ``config['eval_metric']['name']``.
        model: model name stored in ``config['model']['name']``.
        max_num_epochs, log_after_iters, validate_after_iters, max_num_iterations:
            forwarded to ``UNet3DTrainer``.
        weight_map: when true, sets ``config['loaders']['weight_internal_path']``
            to ``'weight_map'``.

    Returns:
        The ``UNet3DTrainer`` restored from ``<tmpdir>/last_checkpoint.pytorch``.
    """
    # These losses select a final sigmoid; the same flag is passed to the dataset factory.
    sigmoid_output = loss in ['BCEWithLogitsLoss', 'DiceLoss', 'GeneralizedDiceLoss']
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

    # Deep copy so the nested mutations below never touch the shared CONFIG_BASE.
    cfg = copy.deepcopy(CONFIG_BASE)
    cfg['model']['name'] = model
    cfg['device'] = device
    cfg['loss'] = {'name': loss, 'weight': np.random.rand(2).astype(np.float32)}
    cfg['eval_metric'] = {'name': val_metric}
    cfg['model']['final_sigmoid'] = sigmoid_output
    if weight_map:
        cfg['loaders']['weight_internal_path'] = 'weight_map'

    loss_criterion = get_loss_criterion(cfg)
    eval_criterion = get_evaluation_metric(cfg)
    net = get_model(cfg).to(device)

    # Only BCEWithLogitsLoss gets float labels; every other loss gets integer labels.
    label_dtype = 'float32' if loss == 'BCEWithLogitsLoss' else 'long'
    for phase in ('train', 'test'):
        cfg['loaders']['transformer'][phase]['label'][0]['dtype'] = label_dtype

    train_file, val_file = TestUNet3DTrainer._create_random_dataset((128, 128, 128), (64, 64, 64),
                                                                    sigmoid_output)
    cfg['loaders']['train_path'] = [train_file]
    cfg['loaders']['val_path'] = [val_file]
    loaders = get_train_loaders(cfg)

    optimizer = _create_optimizer(cfg, net)
    cfg['lr_scheduler']['name'] = 'MultiStepLR'
    lr_scheduler = _create_lr_scheduler(cfg, optimizer)

    logger = get_logger('UNet3DTrainer', logging.DEBUG)
    formatter = DefaultTensorboardFormatter()
    trainer = UNet3DTrainer(net, optimizer, lr_scheduler, loss_criterion, eval_criterion, device, loaders,
                            tmpdir,
                            max_num_epochs=max_num_epochs,
                            log_after_iters=log_after_iters,
                            validate_after_iters=validate_after_iters,
                            max_num_iterations=max_num_iterations,
                            logger=logger,
                            tensorboard_formatter=formatter)
    trainer.fit()

    # Round trip: rebuild the trainer from the checkpoint file in tmpdir.
    checkpoint_file = os.path.join(tmpdir, 'last_checkpoint.pytorch')
    return UNet3DTrainer.from_checkpoint(checkpoint_file, net, optimizer, lr_scheduler, loss_criterion,
                                         eval_criterion, loaders, logger=logger,
                                         tensorboard_formatter=formatter)
def main():
    """Command-line entry point: parse arguments, assemble the components and train.

    If ``--resume`` is given, the trainer is restored from that checkpoint.
    Otherwise a new trainer is created that writes to ``args.checkpoint_dir``.
    """
    parser = _arg_parser()
    logger = get_logger('UNet3DTrainer')
    # Prefer the first GPU, fall back to the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

    cli = parser.parse_args()
    logger.info(cli)

    # The loss helper also returns the final_sigmoid flag handed to the model.
    loss_criterion, final_sigmoid = _get_loss_criterion(cli.loss)
    net = _create_model(cli.in_channels, cli.out_channels,
                        layer_order=cli.layer_order,
                        interpolate=cli.interpolate,
                        final_sigmoid=final_sigmoid).to(device)
    logger.info(f'Number of learnable params {get_number_of_learnable_parameters(net)}')

    error_criterion = DiceCoefficient()
    loaders = _get_loaders(cli.config_dir, logger)
    optimizer = _create_optimizer(cli, net)

    # Options shared by both construction paths.
    common = dict(validate_after_iters=cli.validate_after_iters,
                  log_after_iters=cli.log_after_iters,
                  logger=logger)
    if cli.resume:
        trainer = UNet3DTrainer.from_checkpoint(cli.resume, net, optimizer, loss_criterion,
                                                error_criterion, loaders, **common)
    else:
        trainer = UNet3DTrainer(net, optimizer, loss_criterion, error_criterion, device,
                                loaders, cli.checkpoint_dir, **common)
    trainer.fit()
def _train_save_load(self, tmpdir, loss, max_num_epochs=1, log_after_iters=2, validate_after_iters=2,
                     max_num_iterations=4):
    """Run a short training with ``loss`` and reload the trainer from its last checkpoint.

    Args:
        tmpdir: checkpoint directory for the trainer.
        loss: loss name; ``'bce'``, ``'dice'`` and ``'pce'`` change the model/loader setup.
        max_num_epochs, log_after_iters, validate_after_iters, max_num_iterations:
            forwarded to ``UNet3DTrainer``.

    Returns:
        The ``UNet3DTrainer`` loaded from ``<tmpdir>/last_checkpoint.pytorch``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

    # For BCE, the flag selects a final sigmoid and channel_per_class loaders.
    is_bce = loss == 'bce'
    loss_criterion = get_loss_criterion(loss, is_bce, weight=torch.rand(2).to(device))
    net = self._create_model(is_bce, 'crg')  # 'crg' = conv-relu-groupnorm
    accuracy_criterion = DiceCoefficient()

    loaders = self._get_loaders(channel_per_class=is_bce,
                                label_dtype='float32' if loss in ('bce', 'dice') else 'long',
                                pixel_wise_weight=loss == 'pce')

    optimizer = optim.Adam(net.parameters(), lr=2e-4, weight_decay=0.0001)
    logger = get_logger('UNet3DTrainer', logging.DEBUG)
    trainer = UNet3DTrainer(net, optimizer, loss_criterion, accuracy_criterion, device, loaders, tmpdir,
                            max_num_epochs=max_num_epochs,
                            log_after_iters=log_after_iters,
                            validate_after_iters=validate_after_iters,
                            max_num_iterations=max_num_iterations,
                            logger=logger)
    trainer.fit()

    # Reload from last_checkpoint.pytorch inside tmpdir.
    return UNet3DTrainer.from_checkpoint(os.path.join(tmpdir, 'last_checkpoint.pytorch'), net, optimizer,
                                         loss_criterion, accuracy_criterion, loaders, logger=logger)
def main():
    """Train a ``UNet3D`` configured entirely by ``load_config()``.

    Builds the loss, model, evaluation metric, data loaders, optimizer and
    learning-rate scheduler from the config. It then either resumes from
    ``config['resume']`` or starts a new run that saves checkpoints to
    ``config['checkpoint_dir']``.

    ``'resume'`` is optional. A missing key is treated like ``None`` (train from
    scratch) instead of raising ``KeyError``.
    """
    logger = get_logger('UNet3DTrainer')

    config = load_config()
    logger.info(config)

    # Create loss criterion
    loss_criterion = get_loss_criterion(config)

    # Create the model and move it to the configured device
    model = UNet3D(config['in_channels'], config['out_channels'],
                   final_sigmoid=config['final_sigmoid'],
                   init_channel_number=config['init_channel_number'],
                   conv_layer_order=config['layer_order'],
                   interpolate=config['interpolate'])
    model = model.to(config['device'])

    # Log the number of learnable parameters
    logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')

    # Create evaluation metric
    eval_criterion = get_evaluation_metric(config)

    loaders = get_train_loaders(config)

    # Create the optimizer
    optimizer = _create_optimizer(config, model)

    # Create learning rate adjustment strategy
    lr_scheduler = _create_lr_scheduler(config, optimizer)

    # 'resume' is optional: .get() maps a missing key to None (fresh run)
    # instead of crashing with KeyError; an explicit value behaves as before.
    resume = config.get('resume')
    if resume is not None:
        trainer = UNet3DTrainer.from_checkpoint(resume, model, optimizer, lr_scheduler, loss_criterion,
                                                eval_criterion, loaders, logger=logger)
    else:
        trainer = UNet3DTrainer(model, optimizer, lr_scheduler, loss_criterion, eval_criterion,
                                config['device'], loaders, config['checkpoint_dir'],
                                max_num_epochs=config['epochs'],
                                max_num_iterations=config['iters'],
                                validate_after_iters=config['validate_after_iters'],
                                log_after_iters=config['log_after_iters'],
                                logger=logger)
    trainer.fit()
def test_single_epoch(self, tmpdir, capsys):
    """Train for one epoch with Dice loss, then reload the trainer from its checkpoint."""
    # Disable pytest's output capturing so training logs reach the terminal.
    with capsys.disabled():
        device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

        dice_loss = DiceLoss()
        # final_sigmoid=True, layer order 'crg' (conv-relu-groupnorm).
        net = self._load_model(True, 'crg')
        dice_metric = DiceCoefficient()
        loaders = self._get_loaders()
        optimizer = optim.Adam(net.parameters(), lr=1e-4, weight_decay=0.0005)
        logger = get_logger('UNet3DTrainer', logging.DEBUG)

        trainer = UNet3DTrainer(net, optimizer, dice_loss, dice_metric, device, loaders, tmpdir,
                                max_num_epochs=1, log_after_iters=2, validate_after_iters=2,
                                logger=logger)
        trainer.fit()

        # Loading must not raise; the restored trainer itself is not inspected.
        checkpoint = os.path.join(tmpdir, 'last_checkpoint.pytorch')
        UNet3DTrainer.from_checkpoint(checkpoint, net, optimizer, dice_loss, dice_metric, loaders,
                                      logger=logger)
def _train_save_load(self, tmpdir, loss, val_metric, max_num_epochs=1, log_after_iters=2,
                     validate_after_iters=2, max_num_iterations=4):
    """Train briefly on a random HDF5 dataset and reload the trainer from its checkpoint.

    Args:
        tmpdir: checkpoint directory for the trainer.
        loss: loss name; ``'bce'`` and ``'pce'`` change the model/loader setup.
        val_metric: evaluation metric name passed to ``get_evaluation_metric``.
        max_num_epochs, log_after_iters, validate_after_iters, max_num_iterations:
            forwarded to ``UNet3DTrainer``.

    Returns:
        The ``UNet3DTrainer`` loaded from ``<tmpdir>/last_checkpoint.pytorch``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

    # For BCE, the flag selects a final sigmoid, per-class channels and float labels.
    is_bce = loss == 'bce'
    loss_criterion = get_loss_criterion(loss, weight=torch.rand(2).to(device))
    eval_criterion = get_evaluation_metric(val_metric)
    net = self._create_model(is_bce, 'crg')  # 'crg' = conv-relu-groupnorm

    train_file, val_file = TestUNet3DTrainer._create_random_dataset((128, 128, 128), (64, 64, 64), is_bce)
    patch_shape = stride_shape = (32, 64, 64)
    loaders = get_loaders([train_file], [val_file], 'raw', 'label',
                          label_dtype='float32' if is_bce else 'long',
                          train_patch=patch_shape, train_stride=stride_shape,
                          val_patch=patch_shape, val_stride=stride_shape,
                          transformer='BaseTransformer',
                          pixel_wise_weight=loss == 'pce')

    optimizer = optim.Adam(net.parameters(), lr=2e-4, weight_decay=0.0001)
    logger = get_logger('UNet3DTrainer', logging.DEBUG)
    trainer = UNet3DTrainer(net, optimizer, loss_criterion, eval_criterion, device, loaders, tmpdir,
                            max_num_epochs=max_num_epochs,
                            log_after_iters=log_after_iters,
                            validate_after_iters=validate_after_iters,
                            max_num_iterations=max_num_iterations,
                            logger=logger)
    trainer.fit()

    # Reload from last_checkpoint.pytorch inside tmpdir.
    checkpoint_file = os.path.join(tmpdir, 'last_checkpoint.pytorch')
    return UNet3DTrainer.from_checkpoint(checkpoint_file, net, optimizer, loss_criterion, eval_criterion,
                                         loaders, logger=logger)
def _create_trainer(config, model, optimizer, lr_scheduler, loss_criterion, eval_criterion, loaders, logger):
    """Build a fresh ``UNet3DTrainer`` from ``config['trainer']`` (always trains from scratch).

    Optional trainer keys:
        ``skip_train_validation`` (default ``False``) and ``tensorboard_formatter``
        (default ``None``, handed to ``get_tensorboard_formatter``).

    Args:
        config: configuration dict with a ``'trainer'`` section and a ``'device'`` entry.
        model, optimizer, lr_scheduler, loss_criterion, eval_criterion, loaders:
            training components handed to the trainer unchanged.
        logger: logger handed to the trainer.

    Returns:
        The new ``UNet3DTrainer``.

    Raises:
        AssertionError: when ``config`` has no ``'trainer'`` section.
        KeyError: when a required trainer option is missing.
    """
    assert 'trainer' in config, 'Could not find trainer configuration'
    section = config['trainer']

    skip_validation = section.get('skip_train_validation', False)
    formatter = get_tensorboard_formatter(section.get('tensorboard_formatter', None))

    device = config['device']
    checkpoint_dir = section['checkpoint_dir']
    schedule = {
        'max_num_epochs': section['epochs'],
        'max_num_iterations': section['iters'],
        'validate_after_iters': section['validate_after_iters'],
        'log_after_iters': section['log_after_iters'],
        'eval_score_higher_is_better': section['eval_score_higher_is_better'],
    }
    return UNet3DTrainer(model, optimizer, lr_scheduler, loss_criterion, eval_criterion, device, loaders,
                         checkpoint_dir, logger=logger, tensorboard_formatter=formatter,
                         skip_train_validation=skip_validation, **schedule)
def main():
    """Train a ``UNet3D`` using the options returned by ``parse_train_config()``.

    If ``config.resume`` is truthy, training resumes from that checkpoint.
    Otherwise a new run starts and saves checkpoints to ``config.checkpoint_dir``.
    """
    logger = get_logger('UNet3DTrainer')
    # Prefer the first GPU, fall back to the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

    config = parse_train_config()
    logger.info(config)

    # Optional loss weights, moved to the training device.
    weight = None if config.loss_weight is None else torch.tensor(config.loss_weight).to(device)
    loss_criterion = get_loss_criterion(config.loss, weight, config.ignore_index)

    net = UNet3D(config.in_channels, config.out_channels,
                 init_channel_number=config.init_channel_number,
                 conv_layer_order=config.layer_order,
                 interpolate=config.interpolate,
                 final_sigmoid=config.final_sigmoid).to(device)
    logger.info(f'Number of learnable params {get_number_of_learnable_parameters(net)}')

    eval_criterion = get_evaluation_metric(config.eval_metric, ignore_index=config.ignore_index)

    # Only the 'bce' loss gets float32 labels; every other loss uses long labels.
    label_dtype = 'float32' if config.loss == 'bce' else 'long'
    train_patch, train_stride, val_patch, val_stride = (
        tuple(shape) for shape in (config.train_patch, config.train_stride,
                                   config.val_patch, config.val_stride))
    logger.info(f'Train patch/stride: {train_patch}/{train_stride}')
    logger.info(f'Val patch/stride: {val_patch}/{val_stride}')

    loaders = get_loaders(config.train_path, config.val_path,
                          label_dtype=label_dtype,
                          raw_internal_path=config.raw_internal_path,
                          label_internal_path=config.label_internal_path,
                          train_patch=train_patch, train_stride=train_stride,
                          val_patch=val_patch, val_stride=val_stride,
                          transformer=config.transformer,
                          pixel_wise_weight=config.loss == 'pce',
                          curriculum_learning=config.curriculum,
                          ignore_index=config.ignore_index)

    optimizer = _create_optimizer(config, net)

    if config.resume:
        trainer = UNet3DTrainer.from_checkpoint(config.resume, net, optimizer, loss_criterion, eval_criterion,
                                                loaders, logger=logger)
    else:
        trainer = UNet3DTrainer(net, optimizer, loss_criterion, eval_criterion, device, loaders,
                                config.checkpoint_dir,
                                max_num_epochs=config.epochs,
                                max_num_iterations=config.iters,
                                max_patience=config.patience,
                                validate_after_iters=config.validate_after_iters,
                                log_after_iters=config.log_after_iters,
                                logger=logger)
    trainer.fit()
def main():
    """Command-line entry point: train a 3D U-Net from parsed arguments.

    If ``--resume`` is given, the trainer is restored from that checkpoint.
    Otherwise a new trainer is created that writes to ``args.checkpoint_dir``.
    """
    parser = _arg_parser()
    logger = get_logger('UNet3DTrainer')
    # Prefer the first GPU, fall back to the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

    cli = parser.parse_args()
    logger.info(cli)

    # Optional loss weights, moved to the training device.
    weight = None if cli.loss_weight is None else torch.tensor(cli.loss_weight).to(device)
    # The loss helper also returns the final_sigmoid flag handed to the model.
    loss_criterion, final_sigmoid = _get_loss_criterion(cli.loss, weight)

    net = _create_model(cli.in_channels, cli.out_channels,
                        layer_order=cli.layer_order,
                        interpolate=cli.interpolate,
                        final_sigmoid=final_sigmoid).to(device)

    # The accuracy helper receives the negation of the model's final_sigmoid flag.
    accuracy_criterion = _get_accuracy_criterion(not final_sigmoid)

    # 'bce' and 'dice' losses get float32 labels; all other losses get long labels.
    label_dtype = 'float32' if cli.loss in ('bce', 'dice') else 'long'
    loaders = _get_loaders(cli.train_path, cli.val_path,
                           label_dtype=label_dtype,
                           train_patch=tuple(cli.train_patch),
                           train_stride=tuple(cli.train_stride),
                           val_patch=tuple(cli.val_patch),
                           val_stride=tuple(cli.val_stride))

    optimizer = _create_optimizer(cli, net)

    if cli.resume:
        trainer = UNet3DTrainer.from_checkpoint(cli.resume, net, optimizer, loss_criterion,
                                                accuracy_criterion, loaders, logger=logger)
    else:
        trainer = UNet3DTrainer(net, optimizer, loss_criterion, accuracy_criterion, device, loaders,
                                cli.checkpoint_dir,
                                max_num_epochs=cli.epochs,
                                max_num_iterations=cli.iters,
                                max_patience=cli.patience,
                                validate_after_iters=cli.validate_after_iters,
                                log_after_iters=cli.log_after_iters,
                                logger=logger)
    trainer.fit()
def _train_save_load(self, tmpdir, loss, val_metric, max_num_epochs=1, log_after_iters=2,
                     validate_after_iters=2, max_num_iterations=4):
    """Train briefly on random data, then restore the trainer from its last checkpoint.

    Args:
        tmpdir: directory the trainer writes checkpoints to.
        loss: loss name stored in ``config['loss']['name']``; ``'bce'``, ``'dice'``
            and ``'gdl'`` change the model/dataset setup.
        val_metric: metric name stored in ``config['eval_metric']['name']``.
        max_num_epochs, log_after_iters, validate_after_iters, max_num_iterations:
            forwarded to ``UNet3DTrainer``.

    Returns:
        The ``UNet3DTrainer`` restored from ``<tmpdir>/last_checkpoint.pytorch``.
    """
    import copy  # function-scope import: needed for the deepcopy of CONFIG_BASE below

    # conv-relu-groupnorm
    conv_layer_order = 'crg'
    final_sigmoid = loss in ['bce', 'dice']
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

    # Deep copy: the nested 'loaders' dicts are mutated below (label dtype,
    # train/val paths). A shallow dict(CONFIG_BASE) would share those nested
    # dicts with CONFIG_BASE and leak each call's settings into later tests.
    test_config = copy.deepcopy(CONFIG_BASE)
    test_config.update({
        # get device to train on
        'device': device,
        'loss': {
            'name': loss,
            'weight': np.random.rand(2).astype(np.float32)
        },
        'eval_metric': {
            'name': val_metric
        }
    })
    loss_criterion = get_loss_criterion(test_config)
    eval_criterion = get_evaluation_metric(test_config)
    model = self._create_model(final_sigmoid, conv_layer_order)

    channel_per_class = loss in ['bce', 'dice', 'gdl']
    # Only BCE gets float labels; all other losses get integer labels.
    if loss in ['bce']:
        label_dtype = 'float32'
    else:
        label_dtype = 'long'
    test_config['loaders']['transformer']['train']['label'][0]['dtype'] = label_dtype
    test_config['loaders']['transformer']['test']['label'][0]['dtype'] = label_dtype

    train, val = TestUNet3DTrainer._create_random_dataset((128, 128, 128), (64, 64, 64), channel_per_class)
    test_config['loaders']['train_path'] = [train]
    test_config['loaders']['val_path'] = [val]
    loaders = get_train_loaders(test_config)

    learning_rate = 2e-4
    weight_decay = 0.0001
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    lr_scheduler = MultiStepLR(optimizer, milestones=[2, 3], gamma=0.5)

    logger = get_logger('UNet3DTrainer', logging.DEBUG)
    trainer = UNet3DTrainer(model, optimizer, lr_scheduler, loss_criterion, eval_criterion, device, loaders,
                            tmpdir,
                            max_num_epochs=max_num_epochs,
                            log_after_iters=log_after_iters,
                            validate_after_iters=validate_after_iters,
                            max_num_iterations=max_num_iterations,
                            logger=logger)
    trainer.fit()
    # test loading the trainer from the checkpoint
    trainer = UNet3DTrainer.from_checkpoint(os.path.join(tmpdir, 'last_checkpoint.pytorch'), model, optimizer,
                                            lr_scheduler, loss_criterion, eval_criterion, loaders,
                                            logger=logger)
    return trainer
def main():
    """Command-line entry point that trains either ``CoorNet`` or ``UNet3D``.

    ``--network cd`` trains ``CoorNet`` with the ``'mse'`` loss. ``--network seg``
    trains ``UNet3D`` with the loss named by ``--loss``. Any other value raises
    ``ValueError``, as does ``seg`` without a loss.
    """
    parser = _arg_parser()
    logger = get_logger('Trainer')
    # Prefer the first GPU, fall back to the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

    cli = parser.parse_args()

    # Optional loss weights, moved to the training device.
    weight = None if cli.loss_weight is None else torch.tensor(cli.loss_weight).to(device)

    if cli.network not in ('cd', 'seg'):
        raise ValueError(
            "Incorrect network type defined by the --network argument, either cd or seg."
        )

    if cli.network == 'cd':
        # CoorNet always uses MSE; store it on the args so the label dtype below follows.
        cli.loss = 'mse'
        loss_criterion = get_loss_criterion('mse', weight, cli.ignore_index)
        net = CoorNet(cli.in_channels).to(device)
        accuracy_criterion = PrecisionBasedAccuracy(30)
    else:
        if not cli.loss:
            raise ValueError("Invalid loss assigned.")
        loss_criterion = get_loss_criterion(cli.loss, weight, cli.ignore_index)
        net = UNet3D(cli.in_channels, cli.out_channels,
                     init_channel_number=cli.init_channel_number,
                     conv_layer_order=cli.layer_order,
                     interpolate=True,
                     final_sigmoid=cli.final_sigmoid).to(device)
        accuracy_criterion = DiceCoefficient(ignore_index=cli.ignore_index)

    # 'bce' and 'mse' losses get float32 labels; all other losses get long labels.
    label_dtype = 'float32' if cli.loss in ('bce', 'mse') else 'long'
    loaders = get_loaders(cli.train_path,
                          label_dtype=label_dtype,
                          raw_internal_path=cli.raw_internal_path,
                          label_internal_path=cli.label_internal_path,
                          train_patch=tuple(cli.train_patch),
                          train_stride=tuple(cli.train_stride),
                          transformer=cli.transformer,
                          pixel_wise_weight=cli.loss == 'pce',
                          curriculum_learning=cli.curriculum,
                          ignore_index=cli.ignore_index)

    optimizer = _create_optimizer(cli, net)

    if cli.resume:
        trainer = UNet3DTrainer.from_checkpoint(cli.resume, net, optimizer, loss_criterion,
                                                accuracy_criterion, loaders, logger=logger)
    else:
        trainer = UNet3DTrainer(net, optimizer, loss_criterion, accuracy_criterion, device, loaders,
                                cli.checkpoint_dir,
                                max_num_epochs=cli.epochs,
                                max_num_iterations=cli.iters,
                                max_patience=cli.patience,
                                validate_after_iters=cli.validate_after_iters,
                                log_after_iters=cli.log_after_iters,
                                logger=logger)
    trainer.fit()