Пример #1
0
def main():
    """Train a 2D segmentation model on one cross-validation fold.

    Parses the command line, derives the HDF5 data path and the checkpoint
    directory from ``args.coarse_identifier`` / ``args.fold_ind``, builds a
    run identifier and the training configuration, instantiates the requested
    model and runs the trainer.

    Raises:
        ValueError: if ``args.model`` is not one of the models this script
            knows how to construct.
    """
    parser = get_parser()
    args = parser.parse_args()
    data_dir = args.data_dir
    fold_ind = args.fold_ind
    batch_size = args.batch_size

    # The default coarse run uses the un-suffixed data file and model
    # directory; any other coarse run gets its identifier appended to both.
    default_coarse_identifier = (
        'DeepLabv3_plus_gcn_skipconnection_3d_gcn_mode_2_ds_weight_0.3_loss_'
        'CrossEntropyLoss_Adam_lr_0.001_pretrained')
    fold_name = 'fold' + str(fold_ind)
    if args.coarse_identifier == default_coarse_identifier:
        data_path = os.path.join(data_dir,
                                 'in/h5py/' + fold_name + '_data.h5')
        fold_dir = os.path.join(data_dir, 'model', fold_name)
    else:
        data_path = os.path.join(
            data_dir,
            'in/h5py/' + fold_name + '_data_' + args.coarse_identifier + '.h5')
        fold_dir = os.path.join(data_dir, 'model',
                                fold_name + '_' + args.coarse_identifier)
    os.makedirs(fold_dir, exist_ok=True)

    # Build the run identifier (used as the checkpoint sub-directory name).
    if args.model == 'non_local_crf':
        identifier = (args.model + '_' + args.optimizer + '_lr_' +
                      str(args.learning_rate) + '_num_iter_' +
                      str(args.num_iter))
    else:
        if args.model == 'DeepLabv3_plus_gcn_skipconnection_2d':
            identifier = args.model + '_gcn_mode_' + str(args.gcn_mode)
            if args.gcn_mode == 2:
                identifier += '_ds_weight_' + str(args.ds_weight)
            identifier += '_' + args.optimizer
        else:
            identifier = args.model + '_' + args.optimizer
        identifier += ('_lr_' + str(args.learning_rate) + '_weight_decay_' +
                       str(args.weight_decay))
        # use_unary=False -> single-channel input; otherwise 21 channels
        # (presumably image + 20 class maps, n_classes=20 below — TODO confirm).
        if args.use_unary is False:
            identifier += '_noUnary'
            in_channels = 1
        else:
            in_channels = 21
    if args.augment:
        identifier += '_augment'
    if args.loss != 'CrossEntropyLoss':
        identifier += '_loss_' + args.loss

    conf = get_default_conf()
    conf['device'] = args.device
    conf['manual_seed'] = args.manual_seed

    conf['loss'] = {'name': args.loss}

    conf['eval_metric'] = {
        'name': args.eval_metric,
        'skip_channels': args.skip_channels,
    }

    conf['optimizer'] = {
        'name': args.optimizer,
        'learning_rate': args.learning_rate,
        'weight_decay': args.weight_decay,
        'momentum': args.momentum,
        'nesterov': args.nesterov,
    }

    # `epochs // 1.5` yields a float; cast so MultiStepLR gets integer epochs.
    conf['lr_scheduler'] = {
        'name': 'MultiStepLR',
        'milestones': [args.epochs // 3, int(args.epochs // 1.5)],
        'gamma': args.gamma,
    }

    conf['trainer'] = {
        'batch_size': batch_size,
        'epochs': args.epochs,
        'iters': args.iters,
        'eval_score_higher_is_better': args.eval_score_higher_is_better,
        'ds_weight': args.ds_weight,
    }

    # Only the pixel-wise weighted loss needs per-pixel weight maps from the loader.
    return_weight = args.loss == 'PixelWiseCrossEntropyLoss'

    if args.resume or args.pre_trained:
        if args.resume:
            identifier += '_resume'
            checkpoint_key = 'resume'
        else:
            identifier += '_pretrained'
            checkpoint_key = 'pre_trained'
        run_dir = os.path.join(fold_dir, identifier)
        conf['trainer'][checkpoint_key] = os.path.join(
            run_dir, 'best_checkpoint.pytorch')
        if not os.path.exists(run_dir):
            # Seed the new run directory with a copy of the plain DeepLabv3+
            # run trained with the same learning rate / weight decay.
            src = os.path.join(
                fold_dir, 'DeepLabv3_plus_skipconnection_2d' + '_Adam_lr_' +
                str(args.learning_rate) + '_weight_decay_' +
                str(args.weight_decay) + '_noUnary_augment')
            shutil.copytree(src=src, dst=run_dir)
    checkpoint_dir = os.path.join(fold_dir, identifier)
    conf['trainer']['checkpoint_dir'] = checkpoint_dir

    logger = get_logger('Trainer')

    # Load and log experiment configuration
    logger.info('The configurations: ')
    for k, v in conf.items():
        print('%s: %s' % (k, v))

    manual_seed = conf.get('manual_seed', None)
    if manual_seed is not None:
        logger.info(f'Seed the RNG for all devices with {manual_seed}')
        torch.manual_seed(manual_seed)
        # see https://pytorch.org/docs/stable/notes/randomness.html
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Create the model
    if args.model == 'UNet2D':
        model = UNet2D(in_channels=in_channels,
                       out_channels=20,
                       final_sigmoid=False,
                       f_maps=32,
                       layer_order='cbr',
                       num_groups=8)
    elif args.model == 'ResidualUNet2D':
        model = ResidualUNet2D(in_channels=in_channels,
                               out_channels=20,
                               final_sigmoid=False,
                               f_maps=32,
                               conv_layer_order='cbr',
                               num_groups=8)
    elif args.model == 'DeepLabv3_plus_skipconnection_2d':
        model = DeepLabv3_plus_skipconnection_2d(nInputChannels=in_channels,
                                                 n_classes=20,
                                                 os=16,
                                                 pretrained=False,
                                                 _print=True,
                                                 final_sigmoid=False)
    elif args.model == 'DeepLabv3_plus_gcn_skipconnection_2d':
        model = DeepLabv3_plus_gcn_skipconnection_2d(
            nInputChannels=in_channels,
            n_classes=20,
            os=16,
            pretrained=False,
            _print=True,
            final_sigmoid=False,
            hidden_layers=128,
            gcn_mode=args.gcn_mode,
            device=conf['device'])
    else:
        # Previously fell through and failed later with an unbound `model`.
        raise ValueError(f'Unsupported model: {args.model}')

    # put the model on GPUs
    logger.info(f"Sending the model to '{conf['device']}'")
    model = model.to(conf['device'])
    # Log the number of learnable parameters
    logger.info(
        f'Number of learnable params {get_number_of_learnable_parameters(model)}'
    )

    # Create loss criterion
    loss_criterion = get_loss_criterion(conf)
    # Create evaluation metric
    eval_criterion = get_evaluation_metric(conf)

    loader_kwargs = {'batch_size': batch_size, 'return_weight': return_weight}
    if args.augment:
        loader_kwargs['transformer_config'] = conf['transformer']

    # `f` is the data file handle returned by load_data; it stays None if
    # load_data itself fails so the `finally` never masks that error.
    f = None
    try:
        train_data_loader, val_data_loader, test_data_loader, f = load_data(
            data_path, **loader_kwargs)
        # Validate and log once per pass over the training loader.
        conf['trainer']['validate_after_iters'] = len(train_data_loader)
        conf['trainer']['log_after_iters'] = len(train_data_loader)
        # Create the optimizer
        optimizer = _create_optimizer(conf, model)

        # Create learning rate adjustment strategy
        lr_scheduler = _create_lr_scheduler(conf, optimizer)

        # Create model trainer
        trainer = _create_trainer(conf,
                                  model=model,
                                  optimizer=optimizer,
                                  lr_scheduler=lr_scheduler,
                                  loss_criterion=loss_criterion,
                                  eval_criterion=eval_criterion,
                                  train_loader=train_data_loader,
                                  val_loader=val_data_loader,
                                  logger=logger)
        # Start training
        trainer.fit()
    finally:
        if f is not None:
            f.close()
Пример #2
0
def main():
    """Segment the test cases of one fold with a trained 3D model and score them.

    For every index in ``test_ind`` of ``split_ind_fold<k>.npz`` the raw MR
    volume is loaded, scaled/cropped with ``transform_volume``, normalised,
    run through the model restored from ``last_checkpoint.pytorch``, mapped
    back with ``inv_transform_volume_tensor`` and scored against the mask with
    the configured metric. The argmax segmentation of each case is saved as
    NIfTI in ``out_dir``; the per-case scores and their mean go to
    ``eval_scores.npz``.

    Raises:
        ValueError: if ``args.model`` is not one of the models this script
            knows how to construct.
    """
    start_time = time.time()
    # Intensity normalisation constants applied to every raw volume.
    mean = 466.0
    std = 379.0
    parser = get_parser()
    args = parser.parse_args()
    foldInd = args.fold_ind
    dataDir = args.data_dir
    # Use the NpzFile as a context manager so the underlying file is closed.
    with np.load(os.path.join(dataDir, 'split_ind_fold' + str(foldInd) + '.npz')) as foldIndData:
        test_ind = foldIndData['test_ind']

    conf = get_default_conf()
    conf['manual_seed'] = args.seed
    conf['device'] = args.device

    conf['model']['name'] = args.model
    conf['model']['gcn_mode'] = args.gcn_mode

    conf['loss'] = {'name': args.loss, 'lamda': args.lamda}

    conf['eval_metric'] = {
        'name': args.eval_metric,
        'skip_channels': args.skip_channels,
    }

    conf['optimizer']['name'] = args.optimizer
    conf['optimizer']['learning_rate'] = args.learning_rate

    # `epochs // 1.5` yields a float; cast so milestones are integer epochs.
    conf['lr_scheduler']['milestones'] = [args.epochs // 3, int(args.epochs // 1.5)]

    conf['trainer']['resume'] = args.resume
    conf['trainer']['pre_trained'] = args.pre_trained
    conf['trainer']['validate_after_iters'] = args.validate_after_iters
    conf['trainer']['log_after_iters'] = args.log_after_iters
    conf['trainer']['epochs'] = args.epochs
    conf['trainer']['ds_weight'] = args.ds_weight

    train_foldDir = os.path.join(dataDir, 'model', 'fold' + str(foldInd))
    os.makedirs(train_foldDir, exist_ok=True)

    # Run identifier: must match the one the training script used, since it
    # names the checkpoint directory that is loaded below.
    identifier = args.model
    if 'gcn' in args.model:
        identifier += '_gcn_mode_' + str(args.gcn_mode)
        if args.gcn_mode == 2:
            identifier += '_ds_weight_' + str(args.ds_weight)
    identifier += '_loss_' + args.loss
    if args.loss == 'FPFNLoss':
        identifier += '_lamda_' + str(args.lamda)
    identifier += '_' + args.optimizer + '_lr_' + str(args.learning_rate)

    if args.seed != 0:
        identifier += '_seed_' + str(args.seed)

    if args.resume:
        identifier += '_resume'
    elif args.pre_trained:
        identifier += '_pretrained'

    out_dir = os.path.join(dataDir, 'out', 'fold' + str(foldInd), identifier)
    os.makedirs(out_dir, exist_ok=True)

    checkpoint_dir = os.path.join(train_foldDir, identifier)
    # The last (not the best) checkpoint is evaluated.
    model_path = os.path.join(checkpoint_dir, 'last_checkpoint.pytorch')
    conf['trainer']['checkpoint_dir'] = checkpoint_dir

    # Create main logger
    logger = get_logger('UNet3DTrainer')

    # Load and log experiment configuration
    logger.info('The configurations: ')
    for k, v in conf.items():
        print('%s: %s' % (k, v))

    manual_seed = conf.get('manual_seed', None)
    if manual_seed is not None:
        logger.info(f'Seed the RNG for all devices with {manual_seed}')
        torch.manual_seed(manual_seed)
        # see https://pytorch.org/docs/stable/notes/randomness.html
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Create the model
    model_conf = conf['model']
    if args.model == 'DeepLabv3_plus_skipconnection_3d':
        model = DeepLabv3_plus_skipconnection_3d(nInputChannels=model_conf['in_channels'],
                                                 n_classes=model_conf['out_channels'],
                                                 os=16, pretrained=False, _print=True,
                                                 final_sigmoid=model_conf['final_sigmoid'],
                                                 normalization='bn',
                                                 num_groups=8, devices=conf['device'])
    elif args.model == 'DeepLabv3_plus_gcn_skipconnection_3d':
        model = DeepLabv3_plus_gcn_skipconnection_3d(nInputChannels=model_conf['in_channels'],
                                                     n_classes=model_conf['out_channels'],
                                                     os=16, pretrained=False, _print=True,
                                                     final_sigmoid=model_conf['final_sigmoid'],
                                                     hidden_layers=model_conf['hidden_layers'],
                                                     devices=conf['device'],
                                                     gcn_mode=model_conf['gcn_mode'])
    elif args.model == 'UNet3D':
        model = UNet3D(in_channels=model_conf['in_channels'], out_channels=model_conf['out_channels'],
                       final_sigmoid=model_conf['final_sigmoid'], f_maps=32, layer_order='cbr')
    elif args.model == 'ResidualUNet3D':
        model = ResidualUNet3D(in_channels=model_conf['in_channels'], out_channels=model_conf['out_channels'],
                               final_sigmoid=model_conf['final_sigmoid'], f_maps=32, conv_layer_order='cbr')
    else:
        # Previously fell through and failed later with an unbound `model`.
        raise ValueError(f'Unsupported model: {args.model}')

    utils.load_checkpoint(model_path, model)
    # NOTE(review): the model is not moved here. The DeepLab variants receive
    # `devices` in their constructor; UNet3D/ResidualUNet3D do not, yet inputs
    # are sent to conf['device'][0] below — confirm they end up on that device.
    logger.info(f"Sending the model to '{conf['device']}'")

    # Log the number of learnable parameters
    logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')

    # Create evaluation metric
    eval_criterion = get_evaluation_metric(conf)
    model.eval()
    eval_scores = []
    for i in test_ind:
        raw_path = os.path.join(dataDir, 'in', 'nii', 'original_mr', 'Case' + str(i) + '.nii.gz')
        label_path = os.path.join(dataDir, 'in', 'nii', 'mask', 'mask_case' + str(i) + '.nii.gz')
        raw_nii = nib.load(raw_path)
        # get_data() is deprecated (removed in nibabel 5); np.asanyarray(dataobj)
        # is its documented replacement and, unlike get_fdata(), does not
        # force a float64 result.
        raw = np.asanyarray(raw_nii.dataobj)
        H, W, D = raw.shape  # [H, W, D]
        transform_scale = [256. / (H / 2.), 512. / W, 1.0]

        transform_raw, crop_h_start, crop_h_end = transform_volume(raw, is_crop=True, is_expand_dim=False,
                                                                   scale=transform_scale, interpolation='cubic')

        # Volumes with fewer than `new_d` slices are padded along depth; the
        # padding is removed again by inv_transform_volume_tensor.
        flag = False
        new_d = 18
        if D < new_d:
            transform_raw, start_d, end_d = pad_d(transform_raw, new_d=new_d)
            D = new_d
            flag = True
        else:
            start_d = 0
            end_d = 0
        original_shape = [H, W, D]

        transform_raw -= mean
        transform_raw /= std
        # Add batch and channel axes.
        transform_raw = np.expand_dims(transform_raw, axis=0)
        transform_raw = np.expand_dims(transform_raw, axis=0)

        transform_raw = torch.from_numpy(transform_raw.astype(np.float32))

        label_nii = nib.load(label_path)
        label = np.asanyarray(label_nii.dataobj)  # numpy
        label = label.transpose((2, 0, 1))  # [D, H, W]

        label = np.expand_dims(label, axis=0)  # [1, D, H, W]
        label_tensor = torch.from_numpy(label.astype(np.int64))

        with torch.no_grad():
            data = transform_raw.to(conf['device'][0])

            output = model(data)  # batch, classes, d, h, w

            prob = output[0]

            prob_tensor = inv_transform_volume_tensor(input=prob, crop_h_start=crop_h_start, crop_h_end=crop_h_end,
                                                      original_shape=original_shape,
                                                      is_crop_d=flag, start_d=start_d, end_d=end_d)

            eval_score = eval_criterion(prob_tensor, label_tensor)

            if args.eval_metric == 'DiceCoefficient':
                print('Case %d, dice = %f' % (i, eval_score))
            elif args.eval_metric == 'MeanIoU':
                print('Case %d, IoU = %f' % (i, eval_score))

            eval_scores.append(eval_score)

            prob = prob_tensor.to('cpu').numpy()
            prob = np.squeeze(prob, axis=0)  # numpy, [C, D, H, W]
            seg = np.argmax(prob, axis=0).astype(np.uint8)  # numpy, [D, H, W]
            seg = seg.transpose((1, 2, 0))  # numpy, [H, W, D]

            seg_path = os.path.join(out_dir, 'seg_' + os.path.basename(raw_path))

            segNii = nib.Nifti1Image(seg, affine=raw_nii.affine)
            nib.save(segNii, seg_path)

    if args.eval_metric == 'DiceCoefficient':
        print('mean dice = %f' % np.mean(eval_scores))
        np.savez(os.path.join(out_dir, 'eval_scores.npz'), dice=eval_scores, mean_dice=np.mean(eval_scores))
    elif args.eval_metric == 'MeanIoU':
        print('mean IOU = %f' % np.mean(eval_scores))
        np.savez(os.path.join(out_dir, 'eval_scores.npz'), iou=eval_scores, mean_iou=np.mean(eval_scores))

    total_time = time.time() - start_time
    # An empty test split would otherwise raise ZeroDivisionError here.
    if test_ind.size > 0:
        mean_time = total_time / test_ind.size
        print('mean time for segmenting one volume: %.2f seconds' % mean_time)
Пример #3
0
def main():
    """Train a 3D segmentation model on one cross-validation fold.

    Parses the command line, builds the training configuration and a run
    identifier (which names the checkpoint directory), optionally seeds that
    directory from a baseline DeepLabv3+ run for resume/pre-training,
    constructs the model and data loaders and runs the trainer.

    Raises:
        ValueError: if ``args.model`` is not one of the models this script
            knows how to construct.
    """
    parser = get_parser()
    args = parser.parse_args()
    foldInd = args.fold_ind
    dataDir = args.data_dir
    filePath = os.path.join(dataDir, 'in', 'h5py', 'data_fold' + str(foldInd) + '.h5')
    batch_size = args.batch_size

    conf = get_default_conf()
    conf['manual_seed'] = args.seed
    conf['device'] = args.device

    conf['model']['name'] = args.model
    conf['model']['gcn_mode'] = args.gcn_mode

    conf['loss'] = {'name': args.loss, 'lamda': args.lamda}

    conf['eval_metric'] = {
        'name': args.eval_metric,
        'skip_channels': args.skip_channels,
    }

    conf['optimizer']['name'] = args.optimizer
    conf['optimizer']['learning_rate'] = args.learning_rate

    # `epochs // 1.5` yields a float; cast so MultiStepLR gets integer epochs.
    conf['lr_scheduler']['milestones'] = [args.epochs // 3, int(args.epochs // 1.5)]

    # The resume / pre_trained checkpoint paths are set further below, once
    # the run identifier (and hence the checkpoint directory) is known.
    conf['trainer']['validate_after_iters'] = args.validate_after_iters
    conf['trainer']['log_after_iters'] = args.log_after_iters
    conf['trainer']['epochs'] = args.epochs
    conf['trainer']['ds_weight'] = args.ds_weight

    foldDir = os.path.join(dataDir, 'model', 'fold' + str(foldInd))
    os.makedirs(foldDir, exist_ok=True)

    # Weighted losses need per-voxel weight maps from the data loader.
    return_weight = args.loss in ('FPFNLoss', 'PixelWiseCrossEntropyLoss')

    # Run identifier; it names the checkpoint directory.
    identifier = args.model
    if 'gcn' in args.model:
        identifier += '_gcn_mode_' + str(args.gcn_mode)
        if args.gcn_mode == 2:
            identifier += '_ds_weight_' + str(args.ds_weight)
    identifier += '_loss_' + args.loss
    if args.loss == 'FPFNLoss':
        identifier += '_lamda_' + str(args.lamda)
    identifier += '_' + args.optimizer + '_lr_' + str(args.learning_rate)

    # Seed 0 is the default and is not encoded in directory names.
    seed_suffix = '' if args.seed == 0 else '_seed_' + str(args.seed)
    identifier += seed_suffix

    if args.resume or args.pre_trained:
        if args.resume:
            identifier += '_resume'
            checkpoint_key = 'resume'
            src_lr = str(args.learning_rate)
        else:
            identifier += '_pretrained'
            checkpoint_key = 'pre_trained'
            # Pre-training always starts from the baseline run trained with
            # lr 0.001, regardless of this run's learning rate.
            src_lr = '0.001'
        run_dir = os.path.join(foldDir, identifier)
        conf['trainer'][checkpoint_key] = os.path.join(run_dir, 'best_checkpoint.pytorch')
        if not os.path.exists(run_dir):
            # Seed the new run directory with a copy of the plain DeepLabv3+
            # 3D run that used the same loss / optimizer / seed.
            src = os.path.join(foldDir, 'DeepLabv3_plus_skipconnection_3d' +
                               '_loss_' + args.loss + '_' + args.optimizer + '_lr_' + src_lr
                               + seed_suffix)
            shutil.copytree(src=src, dst=run_dir)

    checkpoint_dir = os.path.join(foldDir, identifier)
    conf['trainer']['checkpoint_dir'] = checkpoint_dir

    # Create main logger
    logger = get_logger('UNet3DTrainer')

    # Load and log experiment configuration
    logger.info('The configurations: ')
    for k, v in conf.items():
        print('%s: %s' % (k, v))

    manual_seed = conf.get('manual_seed', None)
    if manual_seed is not None:
        logger.info(f'Seed the RNG for all devices with {manual_seed}')
        torch.manual_seed(manual_seed)
        # see https://pytorch.org/docs/stable/notes/randomness.html
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Create the model
    model_conf = conf['model']
    if args.model == 'DeepLabv3_plus_skipconnection_3d':
        model = DeepLabv3_plus_skipconnection_3d(nInputChannels=model_conf['in_channels'],
                                                 n_classes=model_conf['out_channels'],
                                                 os=16, pretrained=False, _print=True,
                                                 final_sigmoid=model_conf['final_sigmoid'],
                                                 normalization='bn',
                                                 num_groups=8, devices=conf['device'])
    elif args.model == 'DeepLabv3_plus_gcn_skipconnection_3d':
        model = DeepLabv3_plus_gcn_skipconnection_3d(nInputChannels=model_conf['in_channels'],
                                                     n_classes=model_conf['out_channels'],
                                                     os=16, pretrained=False, _print=True,
                                                     final_sigmoid=model_conf['final_sigmoid'],
                                                     hidden_layers=model_conf['hidden_layers'],
                                                     devices=conf['device'],
                                                     gcn_mode=model_conf['gcn_mode'])
    elif args.model == 'UNet3D':
        model = UNet3D(in_channels=model_conf['in_channels'], out_channels=model_conf['out_channels'],
                       final_sigmoid=model_conf['final_sigmoid'], f_maps=32, layer_order='cbr')
    elif args.model == 'ResidualUNet3D':
        model = ResidualUNet3D(in_channels=model_conf['in_channels'], out_channels=model_conf['out_channels'],
                               final_sigmoid=model_conf['final_sigmoid'], f_maps=32, conv_layer_order='cbr')
    else:
        # Previously fell through and failed later with an unbound `model`.
        raise ValueError(f'Unsupported model: {args.model}')
    # NOTE(review): the model is not moved here (the `.to()` call was disabled
    # in the original). The DeepLab variants receive `devices` in their
    # constructor; confirm UNet3D/ResidualUNet3D are placed by the trainer.
    logger.info(f"Sending the model to '{conf['device']}'")
    # Log the number of learnable parameters
    logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')

    # Create loss criterion
    loss_criterion = get_loss_criterion(conf)
    # Create evaluation metric
    eval_criterion = get_evaluation_metric(conf)

    # Create data loaders; `f` is the open data file, closed after training
    # (it was previously leaked).
    train_data_loader, val_data_loader, _test_data_loader, f = load_data(filePath=filePath,
                                                                         return_weight=return_weight,
                                                                         transformer_config=conf['transformer'],
                                                                         batch_size=batch_size)
    try:
        loaders = {'train': train_data_loader, 'val': val_data_loader}

        # Create the optimizer
        optimizer = _create_optimizer(conf, model)

        # Create learning rate adjustment strategy
        lr_scheduler = _create_lr_scheduler(conf, optimizer)

        # Create model trainer
        trainer = _create_trainer(conf, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler,
                                  loss_criterion=loss_criterion, eval_criterion=eval_criterion, loaders=loaders,
                                  logger=logger)
        # Start training
        trainer.fit()
    finally:
        f.close()