Example #1
    conf['in_channels'] = args.in_channels
    conf['inter_channels'] = args.inter_channels
    conf['num_iter'] = args.num_iter
    conf['blur'] = args.blur
    conf['trainable'] = args.trainable
    conf['manual_seed'] = args.manual_seed

    conf['loss'] = {}
    conf['loss']['name'] = args.loss

    conf['eval_metric'] = {}
    conf['eval_metric']['name'] = args.eval_metric
    conf['eval_metric']['skip_channels'] = args.skip_channels

    logger = get_logger('Trainer')

    # Load and log experiment configuration
    logger.info('The configurations:')
    for k, v in conf.items():
        logger.info('%s: %s', k, v)

    manual_seed = conf.get('manual_seed', None)
    if manual_seed is not None:
        logger.info(f'Seed the RNG for all devices with {manual_seed}')
        torch.manual_seed(manual_seed)
        # see https://pytorch.org/docs/stable/notes/randomness.html
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
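        # deterministic cuDNN kernels make runs reproducible at some cost in speed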

    # Create the model
Example #2
def main():
    start_time = time.time()
    mean = 466.0
    std = 379.0
    parser = get_parser()
    args = parser.parse_args()
    foldInd = args.fold_ind
    dataDir = args.data_dir
    foldIndData = np.load(os.path.join(dataDir, 'split_ind_fold' + str(foldInd) + '.npz'))
    train_ind = foldIndData['train_ind']
    val_ind = foldIndData['val_ind']
    test_ind = foldIndData['test_ind']

    conf = get_default_conf()
    conf['manual_seed'] = args.seed
    conf['device'] = args.device

    conf['model']['name'] = args.model
    conf['model']['gcn_mode'] = args.gcn_mode

    conf['loss'] = {}
    conf['loss']['name'] = args.loss
    conf['loss']['lamda'] = args.lamda

    conf['eval_metric'] = {}
    conf['eval_metric']['name'] = args.eval_metric
    conf['eval_metric']['skip_channels'] = args.skip_channels

    conf['optimizer']['name'] = args.optimizer
    conf['optimizer']['learning_rate'] = args.learning_rate

    conf['lr_scheduler']['milestones'] = [args.epochs // 3, int(args.epochs / 1.5)]  # keep milestones integral

    conf['trainer']['resume'] = args.resume
    conf['trainer']['pre_trained'] = args.pre_trained
    conf['trainer']['validate_after_iters'] = args.validate_after_iters
    conf['trainer']['log_after_iters'] = args.log_after_iters
    conf['trainer']['epochs'] = args.epochs
    conf['trainer']['ds_weight'] = args.ds_weight

    train_foldDir = os.path.join(dataDir, 'model', 'fold' + str(foldInd))
    os.makedirs(train_foldDir, exist_ok=True)

    # only FPFNLoss and PixelWiseCrossEntropyLoss use the per-pixel weight maps
    return_weight = args.loss in ('FPFNLoss', 'PixelWiseCrossEntropyLoss')

    # build a unique experiment identifier from the hyper-parameters
    lamda_part = '_lamda_' + str(args.lamda) if args.loss == 'FPFNLoss' else ''
    gcn_part = ''
    if 'gcn' in args.model:
        gcn_part = '_gcn_mode_' + str(args.gcn_mode)
        if args.gcn_mode == 2:
            gcn_part += '_ds_weight_' + str(args.ds_weight)
    identifier = (args.model + gcn_part + '_loss_' + args.loss + lamda_part +
                  '_' + args.optimizer + '_lr_' + str(args.learning_rate))
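    # For illustration with hypothetical values (model='UNet3D', loss='CrossEntropyLoss',
    # optimizer='Adam', learning_rate=0.001) this yields
    # 'UNet3D_loss_CrossEntropyLoss_Adam_lr_0.001'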

    if args.seed != 0:
        identifier = identifier + '_seed_' + str(args.seed)

    if args.resume:
        identifier = identifier + '_resume'
    elif args.pre_trained:
        identifier = identifier + '_pretrained'

    out_dir = os.path.join(dataDir, 'out', 'fold' + str(foldInd), identifier)
    os.makedirs(out_dir, exist_ok=True)

    checkpoint_dir = os.path.join(train_foldDir, identifier)
    # model_path = os.path.join(checkpoint_dir, 'best_checkpoint.pytorch')
    model_path = os.path.join(checkpoint_dir, 'last_checkpoint.pytorch')
    conf['trainer']['checkpoint_dir'] = checkpoint_dir


    # Create main logger
    logger = get_logger('UNet3DTrainer')

    # Load and log experiment configuration
    logger.info('The configurations:')
    for k, v in conf.items():
        logger.info('%s: %s', k, v)

    manual_seed = conf.get('manual_seed', None)
    if manual_seed is not None:
        logger.info(f'Seed the RNG for all devices with {manual_seed}')
        torch.manual_seed(manual_seed)
        # see https://pytorch.org/docs/stable/notes/randomness.html
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Create the model
    if args.model == 'DeepLabv3_plus_skipconnection_3d':
        model = DeepLabv3_plus_skipconnection_3d(nInputChannels=conf['model']['in_channels'],
                                                 n_classes=conf['model']['out_channels'],
                                                 os=16, pretrained=False, _print=True,
                                                 final_sigmoid=conf['model']['final_sigmoid'],
                                                 normalization='bn',
                                                 num_groups=8, devices=conf['device'])
    elif args.model == 'DeepLabv3_plus_gcn_skipconnection_3d':
        model = DeepLabv3_plus_gcn_skipconnection_3d(nInputChannels=conf['model']['in_channels'],
                                                     n_classes=conf['model']['out_channels'],
                                                     os=16, pretrained=False, _print=True,
                                                     final_sigmoid=conf['model']['final_sigmoid'],
                                                     hidden_layers=conf['model']['hidden_layers'],
                                                     devices=conf['device'],
                                                     gcn_mode=conf['model']['gcn_mode'])
    elif args.model == 'UNet3D':
        model = UNet3D(in_channels=conf['model']['in_channels'], out_channels=conf['model']['out_channels'],
                       final_sigmoid=conf['model']['final_sigmoid'], f_maps=32, layer_order='cbr')
    elif args.model == 'ResidualUNet3D':
        model = ResidualUNet3D(in_channels=conf['model']['in_channels'], out_channels=conf['model']['out_channels'],
                               final_sigmoid=conf['model']['final_sigmoid'], f_maps=32, conv_layer_order='cbr')
    else:
        raise ValueError('Unsupported model: ' + args.model)

    utils.load_checkpoint(model_path, model)
    # put the model on GPUs
    logger.info(f"Sending the model to '{conf['device']}'")

    # Log the number of learnable parameters
    logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')

    # Create evaluation metric
    eval_criterion = get_evaluation_metric(conf)
    model.eval()
    eval_scores = []
    for i in test_ind:
        raw_path = os.path.join(dataDir,'in', 'nii', 'original_mr', 'Case' + str(i) + '.nii.gz')
        label_path = os.path.join(dataDir, 'in', 'nii', 'mask', 'mask_case' + str(i) + '.nii.gz')
        raw_nii = nib.load(raw_path)
        raw = raw_nii.get_fdata()  # get_data() is removed in recent nibabel; get_fdata() returns floats
        H, W, D = raw.shape  # [H, W, D]
        transform_scale = [256./(H/2.), 512./W, 1.0]
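        # the 256./(H/2.) factor suggests transform_volume keeps roughly half of the
        # height before resizing it to 256; width is resized to 512, depth is untouched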

        transform_raw, crop_h_start, crop_h_end = transform_volume(raw, is_crop=True, is_expand_dim=False,
                                                                   scale=transform_scale, interpolation='cubic')

        flag = False
        new_d = 18
        if D < new_d:
            transform_raw, start_d, end_d = pad_d(transform_raw, new_d=new_d)
            D = new_d
            flag = True
        else:
            start_d = 0
            end_d = 0
        original_shape = [H, W, D]

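        # standardize intensities with the fixed dataset statistics (mean/std set at the top of main)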
        transform_raw -= mean
        transform_raw /= std
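        # prepend singleton batch and channel axes for the network input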
        transform_raw = np.expand_dims(transform_raw, axis=0)
        transform_raw = np.expand_dims(transform_raw, axis=0)

        transform_raw = torch.from_numpy(transform_raw.astype(np.float32))

        label_nii = nib.load(label_path)
        label = np.asanyarray(label_nii.dataobj)  # keeps the integer label dtype; get_data() is deprecated
        label = label.transpose((2, 0, 1))  # [D, H, W]

        label = np.expand_dims(label, axis=0) # [1, D, H, W]
        label_tensor = torch.from_numpy(label.astype(np.int64))

        with torch.no_grad():
            data = transform_raw.to(conf['device'][0])

            output = model(data) # batch, classes, d, h, w

            prob = output[0]

            prob_tensor = inv_transform_volume_tensor(input=prob, crop_h_start=crop_h_start, crop_h_end=crop_h_end,
                                                      original_shape=original_shape,
                                                      is_crop_d=flag, start_d=start_d, end_d=end_d)
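            # undo the crop/scale so the prediction lies on the original voxel grid
            # and can be scored against the full-resolution label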

            eval_score = eval_criterion(prob_tensor, label_tensor)

            if args.eval_metric == 'DiceCoefficient':
                print('Case %d, dice = %f' % (i, eval_score))
            elif args.eval_metric == 'MeanIoU':
                print('Case %d, IoU = %f' % (i, eval_score))

            eval_scores.append(eval_score)

            prob = prob_tensor.to('cpu').numpy()
            prob = np.squeeze(prob, axis=0) # numpy, [C, D, H, W]
            seg = np.argmax(prob, axis=0).astype(np.uint8) # numpy, [D, H, W]
            seg = seg.transpose((1,2,0)) # numpy, [H, W, D]

            seg_path = os.path.join(out_dir, 'seg_' + os.path.basename(raw_path))

            segNii = nib.Nifti1Image(seg.astype(np.uint8), affine=raw_nii.affine)
            nib.save(segNii, seg_path)

    if args.eval_metric == 'DiceCoefficient':
        print('mean dice = %f' % np.mean(eval_scores))
        np.savez(os.path.join(out_dir, 'eval_scores.npz'), dice=eval_scores, mean_dice=np.mean(eval_scores))
    elif args.eval_metric == 'MeanIoU':
        print('mean IOU = %f' % np.mean(eval_scores))
        np.savez(os.path.join(out_dir, 'eval_scores.npz'), iou=eval_scores, mean_iou=np.mean(eval_scores))

    end_time = time.time()
    total_time = end_time - start_time
    mean_time = total_time / test_ind.size
    print('mean time for segmenting one volume: %.2f seconds' % mean_time)
Example #3
def main():
    parser = get_parser()
    args = parser.parse_args()
    foldInd = args.fold_ind
    dataDir = args.data_dir
    filePath = os.path.join(dataDir, 'in', 'h5py', 'data_fold' + str(foldInd) + '.h5')
    batch_size = args.batch_size

    conf = get_default_conf()
    conf['manual_seed'] = args.seed
    conf['device'] = args.device

    conf['model']['name'] = args.model
    conf['model']['gcn_mode'] = args.gcn_mode

    conf['loss'] = {}
    conf['loss']['name'] = args.loss
    conf['loss']['lamda'] = args.lamda

    conf['eval_metric'] = {}
    conf['eval_metric']['name'] = args.eval_metric
    conf['eval_metric']['skip_channels'] = args.skip_channels

    conf['optimizer']['name'] = args.optimizer
    conf['optimizer']['learning_rate'] = args.learning_rate

    conf['lr_scheduler']['milestones'] = [args.epochs // 3, int(args.epochs / 1.5)]  # keep milestones integral

    # conf['trainer']['resume'] = args.resume
    # conf['trainer']['pre_trained'] = args.pre_trained
    conf['trainer']['validate_after_iters'] = args.validate_after_iters
    conf['trainer']['log_after_iters'] = args.log_after_iters
    conf['trainer']['epochs'] = args.epochs
    conf['trainer']['ds_weight'] = args.ds_weight

    foldDir = os.path.join(dataDir, 'model', 'fold' + str(foldInd))
    os.makedirs(foldDir, exist_ok=True)

    # only FPFNLoss and PixelWiseCrossEntropyLoss use the per-pixel weight maps
    return_weight = args.loss in ('FPFNLoss', 'PixelWiseCrossEntropyLoss')

    # build a unique experiment identifier from the hyper-parameters
    lamda_part = '_lamda_' + str(args.lamda) if args.loss == 'FPFNLoss' else ''
    gcn_part = ''
    if 'gcn' in args.model:
        gcn_part = '_gcn_mode_' + str(args.gcn_mode)
        if args.gcn_mode == 2:
            gcn_part += '_ds_weight_' + str(args.ds_weight)
    identifier = (args.model + gcn_part + '_loss_' + args.loss + lamda_part +
                  '_' + args.optimizer + '_lr_' + str(args.learning_rate))

    if args.seed != 0:
        identifier = identifier + '_seed_' + str(args.seed)

    if args.resume:
        identifier = identifier + '_resume'
        conf['trainer']['resume'] = os.path.join(foldDir, identifier, 'best_checkpoint.pytorch')
        if not os.path.exists(os.path.join(foldDir, identifier)):
            if args.seed == 0:
                src = os.path.join(foldDir, 'DeepLabv3_plus_skipconnection_3d' +
                                   '_loss_' + args.loss + '_' + args.optimizer + '_lr_' + str(args.learning_rate))
            else:
                src = os.path.join(foldDir, 'DeepLabv3_plus_skipconnection_3d' +
                                   '_loss_' + args.loss + '_' + args.optimizer + '_lr_' + str(args.learning_rate)
                                   + '_seed_' + str(args.seed))
            shutil.copytree(src=src, dst=os.path.join(foldDir, identifier))
    elif args.pre_trained:
        identifier = identifier + '_pretrained'
        conf['trainer']['pre_trained'] = os.path.join(foldDir, identifier, 'best_checkpoint.pytorch')
        if not os.path.exists(os.path.join(foldDir, identifier)):
            if args.seed == 0:
                src = os.path.join(foldDir, 'DeepLabv3_plus_skipconnection_3d' +
                                   '_loss_' + args.loss + '_' + args.optimizer + '_lr_0.001')
            else:
                src = os.path.join(foldDir, 'DeepLabv3_plus_skipconnection_3d' +
                                   '_loss_' + args.loss + '_' + args.optimizer + '_lr_0.001'
                                   + '_seed_' + str(args.seed))
            shutil.copytree(src=src, dst=os.path.join(foldDir, identifier))


    checkpoint_dir = os.path.join(foldDir, identifier)

    conf['trainer']['checkpoint_dir'] = checkpoint_dir


    # Create main logger
    logger = get_logger('UNet3DTrainer')

    # Load and log experiment configuration
    logger.info('The configurations:')
    for k, v in conf.items():
        logger.info('%s: %s', k, v)

    manual_seed = conf.get('manual_seed', None)
    if manual_seed is not None:
        logger.info(f'Seed the RNG for all devices with {manual_seed}')
        torch.manual_seed(manual_seed)
        # see https://pytorch.org/docs/stable/notes/randomness.html
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Create the model
    if args.model == 'DeepLabv3_plus_skipconnection_3d':
        model = DeepLabv3_plus_skipconnection_3d(nInputChannels=conf['model']['in_channels'],
                                                 n_classes=conf['model']['out_channels'],
                                                 os=16, pretrained=False, _print=True,
                                                 final_sigmoid=conf['model']['final_sigmoid'],
                                                 normalization='bn',
                                                 num_groups=8, devices=conf['device'])
    elif args.model == 'DeepLabv3_plus_gcn_skipconnection_3d':
        model = DeepLabv3_plus_gcn_skipconnection_3d(nInputChannels=conf['model']['in_channels'],
                                                     n_classes=conf['model']['out_channels'],
                                                     os=16, pretrained=False, _print=True,
                                                     final_sigmoid=conf['model']['final_sigmoid'],
                                                     hidden_layers=conf['model']['hidden_layers'],
                                                     devices=conf['device'],
                                                     gcn_mode=conf['model']['gcn_mode'])
    elif args.model == 'UNet3D':
        model = UNet3D(in_channels=conf['model']['in_channels'], out_channels=conf['model']['out_channels'],
                       final_sigmoid=conf['model']['final_sigmoid'], f_maps=32, layer_order='cbr')
    elif args.model == 'ResidualUNet3D':
        model = ResidualUNet3D(in_channels=conf['model']['in_channels'], out_channels=conf['model']['out_channels'],
                               final_sigmoid=conf['model']['final_sigmoid'], f_maps=32, conv_layer_order='cbr')
    else:
        raise ValueError('Unsupported model: ' + args.model)
    # put the model on GPUs
    logger.info(f"Sending the model to '{conf['device']}'")
    # model = model.to(conf['device'])
    # Log the number of learnable parameters
    logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')

    # Create loss criterion
    loss_criterion = get_loss_criterion(conf)
    # Create evaluation metric
    eval_criterion = get_evaluation_metric(conf)

    # Create data loaders

    train_data_loader, val_data_loader, test_data_loader, f = load_data(filePath=filePath, return_weight=return_weight,
                                                                        transformer_config=conf['transformer'],
                                                                        batch_size=batch_size)
    loaders = {'train': train_data_loader, 'val': val_data_loader}
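    # load_data also returns a test loader, but only train/val are handed to the
    # trainer; the held-out split is evaluated by the separate test script (Example #2)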

    # Create the optimizer
    optimizer = _create_optimizer(conf, model)

    # Create learning rate adjustment strategy
    lr_scheduler = _create_lr_scheduler(conf, optimizer)

    # Create model trainer
    trainer = _create_trainer(conf, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler,
                              loss_criterion=loss_criterion, eval_criterion=eval_criterion, loaders=loaders,
                              logger=logger)
    # Start training
    trainer.fit()
Example #4
def main():
    parser = get_parser()
    args = parser.parse_args()
    data_dir = args.data_dir
    fold_ind = args.fold_ind
    batch_size = args.batch_size

    coarse_default = ('DeepLabv3_plus_gcn_skipconnection_3d_gcn_mode_2_ds_weight_0.3_loss_'
                      'CrossEntropyLoss_Adam_lr_0.001_pretrained')
    if args.coarse_identifier == coarse_default:
        data_path = os.path.join(data_dir, 'in/h5py/fold' + str(fold_ind) + '_data.h5')
        fold_dir = os.path.join(data_dir, 'model', 'fold' + str(fold_ind))
    else:
        data_path = os.path.join(data_dir, 'in/h5py/fold' + str(fold_ind) + '_data_' +
                                 args.coarse_identifier + '.h5')
        fold_dir = os.path.join(data_dir, 'model',
                                'fold' + str(fold_ind) + '_' + args.coarse_identifier)
    os.makedirs(fold_dir, exist_ok=True)
    # build a unique experiment identifier from the hyper-parameters
    if args.model == 'non_local_crf':
        identifier = (args.model + '_' + args.optimizer + '_lr_' + str(args.learning_rate) +
                      '_num_iter_' + str(args.num_iter))
    else:
        gcn_part = ''
        if args.model == 'DeepLabv3_plus_gcn_skipconnection_2d':
            gcn_part = '_gcn_mode_' + str(args.gcn_mode)
            if args.gcn_mode == 2:
                gcn_part += '_ds_weight_' + str(args.ds_weight)
        identifier = (args.model + gcn_part + '_' + args.optimizer + '_lr_' + str(args.learning_rate) +
                      '_weight_decay_' + str(args.weight_decay))
        if args.use_unary:
            in_channels = 21
        else:
            identifier += '_noUnary'
            in_channels = 1
    if args.augment:
        identifier = identifier + '_augment'
    if args.loss != 'CrossEntropyLoss':
        identifier = identifier + '_loss_' + args.loss

    conf = get_default_conf()
    conf['device'] = args.device
    conf['manual_seed'] = args.manual_seed

    conf['loss'] = {}
    conf['loss']['name'] = args.loss

    conf['eval_metric'] = {}
    conf['eval_metric']['name'] = args.eval_metric
    conf['eval_metric']['skip_channels'] = args.skip_channels

    conf['optimizer'] = {}
    conf['optimizer']['name'] = args.optimizer
    conf['optimizer']['learning_rate'] = args.learning_rate
    conf['optimizer']['weight_decay'] = args.weight_decay
    conf['optimizer']['momentum'] = args.momentum
    conf['optimizer']['nesterov'] = args.nesterov

    conf['lr_scheduler'] = {}
    conf['lr_scheduler']['name'] = 'MultiStepLR'
    conf['lr_scheduler']['milestones'] = [args.epochs // 3, int(args.epochs / 1.5)]  # keep milestones integral
    conf['lr_scheduler']['gamma'] = args.gamma
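    # e.g. with epochs=90 this gives milestones [30, 60]; MultiStepLR multiplies
    # the learning rate by gamma at each of those epochs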

    conf['trainer'] = {}

    conf['trainer']['batch_size'] = batch_size
    conf['trainer']['epochs'] = args.epochs
    conf['trainer']['iters'] = args.iters
    conf['trainer']['eval_score_higher_is_better'] = args.eval_score_higher_is_better
    conf['trainer']['ds_weight'] = args.ds_weight

    if args.loss == 'PixelWiseCrossEntropyLoss':
        return_weight = True
    else:
        return_weight = False

    if args.resume:
        identifier = identifier + '_resume'
        conf['trainer']['resume'] = os.path.join(fold_dir, identifier,
                                                 'best_checkpoint.pytorch')
        if not os.path.exists(os.path.join(fold_dir, identifier)):
            src = os.path.join(
                fold_dir, 'DeepLabv3_plus_skipconnection_2d' + '_Adam_lr_' +
                str(args.learning_rate) + '_weight_decay_' +
                str(args.weight_decay) + '_noUnary_augment')
            shutil.copytree(src=src, dst=os.path.join(fold_dir, identifier))
    elif args.pre_trained:
        identifier = identifier + '_pretrained'
        conf['trainer']['pre_trained'] = os.path.join(
            fold_dir, identifier, 'best_checkpoint.pytorch')
        if not os.path.exists(os.path.join(fold_dir, identifier)):
            src = os.path.join(
                fold_dir, 'DeepLabv3_plus_skipconnection_2d' + '_Adam_lr_' +
                str(args.learning_rate) + '_weight_decay_' +
                str(args.weight_decay) + '_noUnary_augment')
            shutil.copytree(src=src, dst=os.path.join(fold_dir, identifier))
    checkpoint_dir = os.path.join(fold_dir, identifier)
    conf['trainer']['checkpoint_dir'] = checkpoint_dir

    logger = get_logger('Trainer')

    # Load and log experiment configuration
    logger.info('The configurations:')
    for k, v in conf.items():
        logger.info('%s: %s', k, v)

    manual_seed = conf.get('manual_seed', None)
    if manual_seed is not None:
        logger.info(f'Seed the RNG for all devices with {manual_seed}')
        torch.manual_seed(manual_seed)
        # see https://pytorch.org/docs/stable/notes/randomness.html
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Create the model
    if args.model == 'UNet2D':
        model = UNet2D(in_channels=in_channels,
                       out_channels=20,
                       final_sigmoid=False,
                       f_maps=32,
                       layer_order='cbr',
                       num_groups=8)
    elif args.model == 'ResidualUNet2D':
        model = ResidualUNet2D(in_channels=in_channels,
                               out_channels=20,
                               final_sigmoid=False,
                               f_maps=32,
                               conv_layer_order='cbr',
                               num_groups=8)
    elif args.model == 'DeepLabv3_plus_skipconnection_2d':
        model = DeepLabv3_plus_skipconnection_2d(nInputChannels=in_channels,
                                                 n_classes=20,
                                                 os=16,
                                                 pretrained=False,
                                                 _print=True,
                                                 final_sigmoid=False)
    elif args.model == 'DeepLabv3_plus_gcn_skipconnection_2d':
        model = DeepLabv3_plus_gcn_skipconnection_2d(
            nInputChannels=in_channels,
            n_classes=20,
            os=16,
            pretrained=False,
            _print=True,
            final_sigmoid=False,
            hidden_layers=128,
            gcn_mode=args.gcn_mode,
            device=conf['device'])
    else:
        raise ValueError('Unsupported model: ' + args.model)

    # put the model on GPUs
    logger.info(f"Sending the model to '{conf['device']}'")
    model = model.to(conf['device'])
    # Log the number of learnable parameters
    logger.info(
        f'Number of learnable params {get_number_of_learnable_parameters(model)}'
    )

    # Create loss criterion
    loss_criterion = get_loss_criterion(conf)
    # Create evaluation metric
    eval_criterion = get_evaluation_metric(conf)

    f = None  # the h5py handle is created inside the try block; see the guarded close below
    try:
        if args.augment:
            train_data_loader, val_data_loader, test_data_loader, f = load_data(
                data_path,
                batch_size=batch_size,
                transformer_config=conf['transformer'],
                return_weight=return_weight)
        else:
            train_data_loader, val_data_loader, test_data_loader, f = load_data(
                data_path, batch_size=batch_size, return_weight=return_weight)
        conf['trainer']['validate_after_iters'] = len(train_data_loader)
        conf['trainer']['log_after_iters'] = len(train_data_loader)
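        # len(train_data_loader) iterations make up one epoch, so validation and
        # logging each happen exactly once per epoch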
        # Create the optimizer
        optimizer = _create_optimizer(conf, model)

        # Create learning rate adjustment strategy
        lr_scheduler = _create_lr_scheduler(conf, optimizer)

        # Create model trainer
        trainer = _create_trainer(conf,
                                  model=model,
                                  optimizer=optimizer,
                                  lr_scheduler=lr_scheduler,
                                  loss_criterion=loss_criterion,
                                  eval_criterion=eval_criterion,
                                  train_loader=train_data_loader,
                                  val_loader=val_data_loader,
                                  logger=logger)
        # Start training
        trainer.fit()
    finally:
        if f is not None:
            f.close()
Example #5
import importlib

import numpy as np
import torch
import torch.nn.functional as F
from skimage import measure

from networks.losses import compute_per_channel_dice, expand_as_one_hot
from networks.utils import get_logger, adapted_rand

LOGGER = get_logger('EvalMetric')

SUPPORTED_METRICS = [
    'dice', 'iou', 'boundary_ap', 'dt_ap', 'quantized_dt_ap', 'angle',
    'inverse_angular'
]


class DiceCoefficient:
    """Computes Dice Coefficient.
    Generalized to multiple channels by computing per-channel Dice Score
    (as described in https://arxiv.org/pdf/1707.03237.pdf) and then simply taking the average.
    Input is expected to be probabilities instead of logits.
    This metric is mostly useful when channels contain the same semantic class (e.g. affinities computed with different offsets).
    DO NOT USE this metric when training with DiceLoss, otherwise the results will be biased towards the loss.
    """
    def __init__(self,
                 epsilon=1e-5,
                 ignore_index=None,
                 skip_channels=None,
                 **kwargs):
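
# A minimal sketch of the per-channel Dice described in the docstring above. The
# repo's real implementation is networks.losses.compute_per_channel_dice; this
# stand-in only illustrates the idea and assumes `probs` and `targets` are both
# torch tensors of shape [N, C, ...] holding probabilities / one-hot labels.
def per_channel_dice_sketch(probs, targets, epsilon=1e-5):
    # flatten all dims except the channel one: [N, C, ...] -> [C, N * spatial]
    probs = probs.transpose(0, 1).contiguous().view(probs.size(1), -1)
    targets = targets.transpose(0, 1).contiguous().view(targets.size(1), -1).float()
    intersect = (probs * targets).sum(-1)
    denominator = probs.sum(-1) + targets.sum(-1)
    # 2|A∩B| / (|A| + |B|), one score per channel; DiceCoefficient then averages
    return 2 * intersect / denominator.clamp(min=epsilon)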