Ejemplo n.º 1
0
Archivo: test.py Proyecto: zzdxjtu/BEAL
def _str2bool(value):
    """Convert a command-line token such as ``true``/``no``/``0`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value -- including ``"False"`` -- would become ``True``.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got %r' % value)


def main():
    """Evaluate a DeepLab checkpoint on a test split and log the dice scores.

    Parses the command line, builds the ``FundusSegmentation`` test loader
    (batch size 1), restores ``checkpoint['model_state_dict']`` into a
    MobileNet-backbone ``DeepLab`` and, for every test sample, calls
    ``draw_ent`` / ``draw_mask`` / ``draw_boundary``, post-processes the
    prediction, accumulates the cup and disc dice and saves the prediction via
    ``save_per_img``.  The mean dice values are printed and appended as one
    comma-separated row to ``test_log.csv`` under
    ``--test-prediction-save-path``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model-file',
        type=str,
        default='./logs/train2/20181202_160326.365442/checkpoint_9.pth.tar',
        help='Model path')
    parser.add_argument('--dataset',
                        type=str,
                        default='Drishti-GS',
                        help='test folder id contain images ROIs to test')
    parser.add_argument('-g', '--gpu', type=int, default=0)

    parser.add_argument('--data-dir',
                        default='/home/sjwang/ssd1T/fundus/domain_adaptation/',
                        help='data root path')
    parser.add_argument(
        '--out-stride',
        type=int,
        default=16,
        help='out-stride of deeplabv3+',
    )
    parser.add_argument(
        '--save-root-ent',
        type=str,
        default='./results/ent/',
        help='path to save ent',
    )
    parser.add_argument(
        '--save-root-mask',
        type=str,
        default='./results/mask/',
        help='path to save mask',
    )
    parser.add_argument(
        '--sync-bn',
        type=_str2bool,
        default=True,
        help='sync-bn in deeplabv3+',
    )
    parser.add_argument(
        '--freeze-bn',
        type=_str2bool,
        default=False,
        help='freeze batch normalization of deeplabv3+',
    )
    parser.add_argument('--test-prediction-save-path',
                        type=str,
                        default='./results/baseline/',
                        help='Path root for test image and mask')
    args = parser.parse_args()

    # Must be set before the first CUDA query so the selection takes effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    model_file = args.model_file
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 1. dataset
    composed_transforms_test = transforms.Compose(
        [tr.Normalize_tf(), tr.ToTensor()])
    db_test = DL.FundusSegmentation(base_dir=args.data_dir,
                                    dataset=args.dataset,
                                    split='test',
                                    transform=composed_transforms_test)

    test_loader = DataLoader(db_test,
                             batch_size=1,
                             shuffle=False,
                             num_workers=1)

    # 2. model
    model = DeepLab(num_classes=2,
                    backbone='mobilenet',
                    output_stride=args.out_stride,
                    sync_bn=args.sync_bn,
                    freeze_bn=args.freeze_bn).to(device)

    print('==> Loading %s model file: %s' %
          (model.__class__.__name__, model_file))
    checkpoint = torch.load(model_file, map_location=device)
    # The previous try-block referenced undefined names (model_data,
    # model_gen), always raised NameError and fell through to this strict
    # load; keep that effective behaviour without the masked error.
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print('==> Evaluating with %s' % (args.dataset))

    val_cup_dice = 0.0
    val_disc_dice = 0.0
    timezone = pytz.timezone('Asia/Hong_Kong')
    timestamp_start = datetime.now(timezone)

    ent_dir = os.path.join(args.save_root_ent, args.dataset)
    mask_dir = os.path.join(args.save_root_mask, args.dataset)
    pred_dir = os.path.join(args.test_prediction_save_path, args.dataset)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for sample in tqdm.tqdm(test_loader,
                                total=len(test_loader),
                                ncols=80,
                                leave=False):
            data = sample['image'].to(device)
            target = sample['map'].to(device)
            img_name = sample['img_name']
            out_size = (target.size(2), target.size(3))

            prediction, boundary = model(data)
            # Resize network outputs and the input image to the label size.
            prediction = torch.nn.functional.interpolate(prediction,
                                                         size=out_size,
                                                         mode="bilinear")
            boundary = torch.nn.functional.interpolate(boundary,
                                                       size=out_size,
                                                       mode="bilinear")
            data = torch.nn.functional.interpolate(data,
                                                   size=out_size,
                                                   mode="bilinear")
            prediction = torch.sigmoid(prediction)
            boundary = torch.sigmoid(boundary)

            # Each helper receives its own host copy, as before.
            draw_ent(prediction.cpu()[0].numpy(), ent_dir, img_name[0])
            draw_mask(prediction.cpu()[0].numpy(), mask_dir, img_name[0])
            draw_boundary(boundary.cpu()[0].numpy(), mask_dir, img_name[0])

            prediction = postprocessing(prediction.cpu()[0],
                                        dataset=args.dataset)
            target_numpy = target.cpu()
            # Channel 0 is scored as the cup and channel 1 as the disc.
            val_cup_dice += dice_coefficient_numpy(prediction[0, ...],
                                                   target_numpy[0, 0, ...])
            val_disc_dice += dice_coefficient_numpy(prediction[1, ...],
                                                    target_numpy[0, 1, ...])

            # batch_size is 1, so the batch holds exactly one image/label.
            img, _ = untransform(data.cpu()[0], target_numpy[0])
            save_per_img(img.numpy().transpose(1, 2, 0),
                         pred_dir,
                         img_name[0],
                         prediction,
                         mask_path=None,
                         ext="bmp")

    val_cup_dice /= len(test_loader)
    val_disc_dice /= len(test_loader)

    print('''\n==>val_cup_dice : {0}'''.format(val_cup_dice))
    print('''\n==>val_disc_dice : {0}'''.format(val_disc_dice))

    elapsed_time = (datetime.now(timezone) - timestamp_start).total_seconds()
    # Flat list: the former nested list was stringified as a single Python
    # list repr instead of comma-separated fields.
    log = [args.model_file,
           'cup dice coefficence: ', val_cup_dice,
           'disc dice coefficence: ', val_disc_dice,
           elapsed_time]
    with open(osp.join(args.test_prediction_save_path, 'test_log.csv'),
              'a') as f:
        f.write(','.join(map(str, log)) + '\n')
Ejemplo n.º 2
0
def _str2bool(value):
    """Convert a command-line token such as ``true``/``no``/``0`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value -- including ``"False"`` -- would become ``True``.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got %r' % value)


def _surface_metrics(pred_mask, gt_mask):
    """Return ``(hd95, asd)`` for one structure, or ``(100, 100)`` if the
    predicted mask is empty (the distance functions need a non-empty object).
    """
    if np.sum(pred_mask) < 1e-4:
        return 100, 100
    # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
    pred_bool = np.asarray(pred_mask, dtype=bool)
    gt_bool = np.asarray(gt_mask, dtype=bool)
    return binary.hd95(pred_bool, gt_bool), binary.asd(pred_bool, gt_bool)


def main():
    """Evaluate a DeepLab checkpoint and report dice, HD95 and ASD per structure.

    Parses the command line, builds the ``FundusSegmentation`` test loader for
    the requested split ids, restores the matching entries of
    ``checkpoint['model_state_dict']`` into a MobileNet-backbone ``DeepLab``
    and, for every test image, post-processes the prediction, computes
    cup (OC) / disc (OD) dice plus ``binary.hd95`` / ``binary.asd`` and saves
    the result with ``save_per_img``.  Per-image dice values are appended to
    ``Dice_results.csv``; the averages are printed and appended as one
    comma-separated row to ``test<id>_log.csv`` next to the output folder.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-file', type=str, default='./logs/test1/20190506_221021.177567/checkpoint_200.pth.tar', help='Model path')
    # nargs='+' with type=int: ``type=list`` split the raw string into
    # single characters ("12" -> ['1', '2']).
    parser.add_argument('--datasetTest', nargs='+', type=int, default=[1], help='test folder id contain images ROIs to test')
    parser.add_argument('--dataset', type=str, default='test', help='test folder id contain images ROIs to test')
    parser.add_argument('-g', '--gpu', type=int, default=0)

    parser.add_argument('--data-dir', default='../../../../Dataset/Fundus/', help='data root path')
    parser.add_argument('--out-stride', type=int, default=16, help='out-stride of deeplabv3+',)
    parser.add_argument('--sync-bn', type=_str2bool, default=False, help='sync-bn in deeplabv3+')
    parser.add_argument('--freeze-bn', type=_str2bool, default=False, help='freeze batch normalization of deeplabv3+')
    parser.add_argument('--movingbn', type=_str2bool, default=False, help='moving batch normalization of deeplabv3+ in the test phase',)
    parser.add_argument('--test-prediction-save-path', type=str, default='./results/rebuttle-0401/', help='Path root for test image and mask')
    args = parser.parse_args()

    # Must be set before the first CUDA query so the selection takes effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    model_file = args.model_file
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Results go to <save-path>/test<first id>/<checkpoint's parent folder>.
    output_path = os.path.join(args.test_prediction_save_path, 'test' + str(args.datasetTest[0]), args.model_file.split('/')[-2])

    # 1. dataset
    composed_transforms_test = transforms.Compose([
        tr.Normalize_tf(),
        tr.ToTensor()
    ])
    db_test = DL.FundusSegmentation(base_dir=args.data_dir, phase='test', splitid=args.datasetTest,
                                    transform=composed_transforms_test, state='prediction')
    batch_size = 12
    test_loader = DataLoader(db_test, batch_size=batch_size, shuffle=False, num_workers=1)

    # 2. model
    model = DeepLab(num_classes=2, backbone='mobilenet', output_stride=args.out_stride,
                    sync_bn=args.sync_bn, freeze_bn=args.freeze_bn).to(device)

    print('==> Loading %s model file: %s' %
          (model.__class__.__name__, model_file))

    checkpoint = torch.load(model_file, map_location=device)
    # Keep only checkpoint entries the model knows, overlay them on the
    # model's current state and load the merged dict.
    model_dict = model.state_dict()
    model_dict.update({k: v for k, v in checkpoint['model_state_dict'].items() if k in model_dict})
    model.load_state_dict(model_dict)

    # --movingbn keeps the network in train() mode during testing.
    if args.movingbn:
        model.train()
    else:
        model.eval()

    val_cup_dice = 0.0
    val_disc_dice = 0.0
    total_hd_OC = 0.0
    total_hd_OD = 0.0
    total_asd_OC = 0.0
    total_asd_OD = 0.0
    timezone = pytz.timezone('Asia/Hong_Kong')
    timestamp_start = datetime.now(timezone)
    total_num = 0
    OC = []
    OD = []

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for sample in tqdm.tqdm(test_loader, total=len(test_loader), ncols=80, leave=False):
            data = sample['image'].to(device)
            target = sample['label'].to(device)
            img_name = sample['img_name']
            out_size = (target.size(2), target.size(3))

            # Only the first of the four model outputs is used here.
            prediction, _, _, _ = model(data)
            prediction = torch.nn.functional.interpolate(prediction, size=out_size, mode="bilinear")
            data = torch.nn.functional.interpolate(data, size=out_size, mode="bilinear")

            target_numpy = target.cpu()
            imgs = data.cpu()
            for i in range(prediction.shape[0]):
                prediction_post = postprocessing(prediction[i], dataset=args.dataset)
                cup_dice, disc_dice = dice_coeff_2label(prediction_post, target[i])
                OC.append(cup_dice)
                OD.append(disc_dice)

                # Channel 0 is the cup, channel 1 the disc.  The disc guard
                # previously tested channel 0, so an empty disc mask reached
                # the distance functions unguarded.
                hd_OC, asd_OC = _surface_metrics(prediction_post[0, ...], target_numpy[i, 0, ...])
                hd_OD, asd_OD = _surface_metrics(prediction_post[1, ...], target_numpy[i, 1, ...])

                val_cup_dice += cup_dice
                val_disc_dice += disc_dice
                total_hd_OC += hd_OC
                total_hd_OD += hd_OD
                total_asd_OC += asd_OC
                total_asd_OD += asd_OD
                total_num += 1

                img, lt = utils.untransform(imgs[i], target_numpy[i])
                save_per_img(img.numpy().transpose(1, 2, 0),
                             output_path,
                             img_name[i],
                             prediction_post, lt, mask_path=None, ext="bmp")

    print('OC:', OC)
    print('OD:', OD)
    import csv
    # newline='' is what the csv module expects; otherwise extra blank rows
    # can appear on platforms that translate line endings.
    with open('Dice_results.csv', 'a+', newline='') as result_file:
        wr = csv.writer(result_file, dialect='excel')
        for cup_score, disc_score in zip(OC, OD):
            wr.writerow([cup_score, disc_score])

    val_cup_dice /= total_num
    val_disc_dice /= total_num
    total_hd_OC /= total_num
    total_asd_OC /= total_num
    total_hd_OD /= total_num
    total_asd_OD /= total_num

    print('''\n==>val_cup_dice : {0}'''.format(val_cup_dice))
    print('''\n==>val_disc_dice : {0}'''.format(val_disc_dice))
    print('''\n==>average_hd_OC : {0}'''.format(total_hd_OC))
    print('''\n==>average_hd_OD : {0}'''.format(total_hd_OD))
    print('''\n==>ave_asd_OC : {0}'''.format(total_asd_OC))
    print('''\n==>average_asd_OD : {0}'''.format(total_asd_OD))

    elapsed_time = (datetime.now(timezone) - timestamp_start).total_seconds()
    # Flat list: the former nested list was stringified as a single Python
    # list repr instead of comma-separated fields.
    log = ['batch-size: ', batch_size, args.model_file,
           'cup dice coefficence: ', val_cup_dice,
           'disc dice coefficence: ', val_disc_dice,
           'average_hd_OC: ', total_hd_OC,
           'average_hd_OD: ', total_hd_OD,
           'ave_asd_OC: ', total_asd_OC,
           'average_asd_OD: ', total_asd_OD,
           elapsed_time]
    with open(osp.join(output_path, '../test' + str(args.datasetTest[0]) + '_log.csv'), 'a') as f:
        f.write(','.join(map(str, log)) + '\n')
Ejemplo n.º 3
0
def _str2bool(value):
    """Convert a command-line token such as ``true``/``no``/``0`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value -- including ``"False"`` -- would become ``True``.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got %r' % value)


def main():
    """Evaluate a DeepLab checkpoint on a test split and score the saved masks.

    Parses the command line, builds the ``FundusSegmentation`` test loader
    (batch size 1, images scaled to ``--resize``), restores
    ``checkpoint['model_state_dict']`` into a MobileNet-backbone ``DeepLab``,
    accumulates cup/disc dice with ``dice_coeff_2label`` and saves each
    post-processed prediction via ``save_per_img``.  Afterwards
    ``evaluate_segmentation_results`` is run on the saved ``pred_mask`` folder
    against ``--mask-dir`` and one comma-separated summary row is appended to
    ``test_log.csv`` under ``--test-prediction-save-path``.

    Raises:
        FileNotFoundError: if ``--model-file`` does not point to a file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-file', type=str, default='./logs/refuge_weights.tar',
                        help='Model path')
    parser.add_argument(
        '--dataset', type=str, default='Drishti-GS', help='test folder id contain images ROIs to test'
    )
    parser.add_argument('-g', '--gpu', type=int, default=0)
    parser.add_argument(
        '--resize', type=int, default=800, help='image resize')

    parser.add_argument(
        '--data-dir',
        default='./fundus/',
        help='data root path'
    )
    # No required=True: combined with a default it made the default dead.
    parser.add_argument(
        '--mask-dir',
        default='./fundus/Drishti-GS/test/mask',
        help='mask image path'
    )
    parser.add_argument(
        '--out-stride',
        type=int,
        default=16,
        help='out-stride of deeplabv3+',
    )
    parser.add_argument(
        '--save-root-ent',
        type=str,
        default='./results/ent/',
        help='path to save ent',
    )
    parser.add_argument(
        '--save-root-mask',
        type=str,
        default='./results/mask/',
        help='path to save mask',
    )
    parser.add_argument(
        '--sync-bn',
        type=_str2bool,
        default=False,
        help='sync-bn in deeplabv3+',
    )
    parser.add_argument(
        '--freeze-bn',
        type=_str2bool,
        default=False,
        help='freeze batch normalization of deeplabv3+',
    )
    parser.add_argument('--test-prediction-save-path', type=str,
                        default='./results/baseline/',
                        help='Path root for test image and mask')
    args = parser.parse_args()

    # Must be set before the first CUDA query so the selection takes effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    model_file = args.model_file
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 1. dataset
    composed_transforms_test = transforms.Compose([
        tr.Scale(args.resize),
        tr.Normalize_tf(),
        tr.ToTensor()
    ])
    db_test = DL.FundusSegmentation(base_dir=args.data_dir, dataset=args.dataset, split='test',
                                    transform=composed_transforms_test)

    test_loader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

    # 2. model
    model = DeepLab(num_classes=2, backbone='mobilenet', output_stride=args.out_stride,
                    sync_bn=args.sync_bn, freeze_bn=args.freeze_bn).to(device)

    print('==> Loading %s model file: %s' %
          (model.__class__.__name__, model_file))
    # Report a missing checkpoint as such, before reading it.  The former
    # bare ``except`` around load_state_dict re-labelled every failure (wrong
    # key, shape mismatch, ...) as a missing file and hid the real cause.
    if not osp.isfile(model_file):
        raise FileNotFoundError('No checkpoint file exist...')
    checkpoint = torch.load(model_file, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    model.eval()
    print('==> Evaluating with %s' % args.dataset)

    test_cup_dice = 0.0
    test_disc_dice = 0.0
    timezone = pytz.timezone('Asia/Hong_Kong')
    timestamp_start = datetime.now(timezone)
    pred_dir = os.path.join(args.test_prediction_save_path, args.dataset)

    with torch.no_grad():
        for sample in tqdm.tqdm(test_loader,
                                total=len(test_loader),
                                ncols=80, leave=False):
            data = sample['image'].to(device)
            target = sample['map'].to(device)
            img_name = sample['img_name']
            out_size = (target.size(2), target.size(3))

            # The second model output (boundary) is not used by this script.
            prediction, _ = model(data)
            prediction = torch.nn.functional.interpolate(prediction, size=out_size, mode="bilinear")
            data = torch.nn.functional.interpolate(data, size=out_size, mode="bilinear")

            # Dice is computed on the resized raw output, before the sigmoid
            # below -- presumably dice_coeff_2label handles that; TODO confirm.
            cup_dice, disc_dice = dice_coeff_2label(prediction, target)
            test_cup_dice += cup_dice
            test_disc_dice += disc_dice

            prediction, ROI_mask = postprocessing(torch.sigmoid(prediction).cpu()[0], dataset=args.dataset)

            # batch_size is 1, so the batch holds exactly one image/label.
            img, lt = untransform(data.cpu()[0], target.cpu().numpy()[0])
            save_per_img(img.numpy().transpose(1, 2, 0), pred_dir,
                         img_name[0], prediction, lt, ROI_mask)

    test_cup_dice /= len(test_loader)
    test_disc_dice /= len(test_loader)

    print("test_cup_dice = ", test_cup_dice)
    print("test_disc_dice = ", test_disc_dice)

    # submit script
    _, _, mae_cdr = evaluate_segmentation_results(osp.join(args.test_prediction_save_path, args.dataset, 'pred_mask'),
                                                  args.mask_dir, output_path="./", export_table=True)

    elapsed_time = (datetime.now(timezone) - timestamp_start).total_seconds()
    # Flat list: the former nested list was stringified as a single Python
    # list repr instead of comma-separated fields.
    log = [args.model_file,
           'cup dice: ', test_cup_dice,
           'disc dice: ', test_disc_dice,
           'cdr: ', mae_cdr,
           elapsed_time]
    with open(osp.join(args.test_prediction_save_path, 'test_log.csv'), 'a') as f:
        f.write(','.join(map(str, log)) + '\n')
Ejemplo n.º 4
0
def _str2bool(value):
    """Convert a command-line token such as ``true``/``no``/``0`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value -- including ``"False"`` -- would become ``True``.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got %r' % value)


def _load_matching_weights(module, saved_state):
    """Load into *module* only the checkpoint entries whose keys it already has."""
    merged = module.state_dict()
    merged.update({k: v for k, v in saved_state.items() if k in merged})
    module.load_state_dict(merged)


def main():
    """Configure and launch training of the generator and two discriminators.

    Parses the command line, creates a timestamped output folder under
    ``logs/<dataset>`` and dumps the settings to ``config.yaml``, seeds the
    torch / ``random`` / NumPy generators with 2020, builds the train and
    validation loaders, the MobileNet-backbone ``DeepLab`` generator, the
    ``BoundaryDiscriminator`` and ``MaskDiscriminator`` with their optimizers,
    optionally restores everything from ``--resume``, and hands control to
    ``Trainer.Trainer(...).train()``.
    """
    # Add default values to all parameters
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('-g', '--gpu', type=int, default=0, help='gpu id')
    parser.add_argument('--resume', default=None, help='checkpoint path')
    parser.add_argument(
        '--coefficient', type=float, default=0.01, help='balance coefficient'
    )
    parser.add_argument(
        '--boundary-exist', type=_str2bool, default=True, help='whether or not using boundary branch'
    )
    parser.add_argument(
        '--dataset', type=str, default='refuge', help='folder id contain images ROIs to train or validation'
    )
    parser.add_argument(
        '--batch-size', type=int, default=12, help='batch size for training the model'
    )
    parser.add_argument(
        '--max-epoch', type=int, default=300, help='max epoch'
    )
    parser.add_argument(
        '--stop-epoch', type=int, default=300, help='stop epoch'
    )
    parser.add_argument(
        '--warmup-epoch', type=int, default=-1, help='warmup epoch begin train GAN'
    )
    parser.add_argument(
        '--interval-validate', type=int, default=1, help='interval epoch number to valide the model'
    )
    parser.add_argument(
        '--lr-gen', type=float, default=1e-3, help='learning rate',
    )
    parser.add_argument(
        '--lr-dis', type=float, default=2.5e-5, help='learning rate',
    )
    parser.add_argument(
        '--lr-decrease-rate', type=float, default=0.2, help='ratio multiplied to initial lr',
    )
    parser.add_argument(
        '--weight-decay', type=float, default=0.0005, help='weight decay',
    )
    parser.add_argument(
        '--momentum', type=float, default=0.9, help='momentum',
    )
    parser.add_argument(
        '--data-dir',
        default='./fundus/',
        help='data root path'
    )
    parser.add_argument(
        '--out-stride',
        type=int,
        default=16,
        help='out-stride of deeplabv3+',
    )
    parser.add_argument(
        '--sync-bn',
        type=_str2bool,
        default=False,
        help='sync-bn in deeplabv3+',
    )
    parser.add_argument(
        '--freeze-bn',
        type=_str2bool,
        default=False,
        help='freeze batch normalization of deeplabv3+',
    )

    args = parser.parse_args()
    args.model = 'MobileNetV2'

    # Every run gets its own timestamped output directory.
    now = datetime.now()
    args.out = osp.join(here, 'logs', args.dataset, now.strftime('%Y%m%d_%H%M%S.%f'))
    os.makedirs(args.out)

    # save training hyperparameters or/and settings
    with open(osp.join(args.out, 'config.yaml'), 'w') as f:
        yaml.safe_dump(vars(args), f, default_flow_style=False)

    # Must be set before the first CUDA query so the selection takes effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    cuda = torch.cuda.is_available()
    device = torch.device('cuda' if cuda else 'cpu')

    torch.manual_seed(2020)
    if cuda:
        torch.cuda.manual_seed(2020)

    import random
    import numpy as np
    random.seed(2020)
    np.random.seed(2020)

    # 1. loading data
    composed_transforms_train = transforms.Compose([
        tr.RandomScaleCrop(512),
        tr.RandomRotate(),
        tr.RandomFlip(),
        tr.elastic_transform(),
        tr.add_salt_pepper_noise(),
        tr.adjust_light(),
        tr.eraser(),
        tr.Normalize_tf(),
        tr.ToTensor()
    ])

    composed_transforms_val = transforms.Compose([
        tr.RandomCrop(512),
        tr.Normalize_tf(),
        tr.ToTensor()
    ])

    data_train = DL.FundusSegmentation(base_dir=args.data_dir, dataset=args.dataset, split='train',
                                       transform=composed_transforms_train)
    dataloader_train = DataLoader(data_train, batch_size=args.batch_size, shuffle=True, num_workers=4,
                                  pin_memory=True)
    data_val = DL.FundusSegmentation(base_dir=args.data_dir, dataset=args.dataset, split='testval',
                                     transform=composed_transforms_val)
    dataloader_val = DataLoader(data_val, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)

    # 2. model
    # .to(device) instead of an unconditional .cuda(): the `cuda` flag is
    # computed above and passed to the trainer, so honour it here as well.
    model_gen = DeepLab(num_classes=2, backbone='mobilenet', output_stride=args.out_stride,
                        sync_bn=args.sync_bn, freeze_bn=args.freeze_bn).to(device)

    model_bd = BoundaryDiscriminator().to(device)
    model_mask = MaskDiscriminator().to(device)

    start_epoch = 0
    start_iteration = 0

    # 3. optimizer
    optim_gen = torch.optim.Adam(
        model_gen.parameters(),
        lr=args.lr_gen,
        betas=(0.9, 0.99)
    )
    optim_bd = torch.optim.SGD(
        model_bd.parameters(),
        lr=args.lr_dis,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )
    optim_mask = torch.optim.SGD(
        model_mask.parameters(),
        lr=args.lr_dis,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    # breakpoint recovery
    if args.resume:
        # map_location keeps the restore working when the checkpoint was
        # written on a device that is not visible in this process.
        checkpoint = torch.load(args.resume, map_location=device)
        _load_matching_weights(model_gen, checkpoint['model_state_dict'])
        _load_matching_weights(model_bd, checkpoint['model_bd_state_dict'])
        _load_matching_weights(model_mask, checkpoint['model_mask_state_dict'])

        # Continue with the epoch/iteration after the saved one.
        start_epoch = checkpoint['epoch'] + 1
        start_iteration = checkpoint['iteration'] + 1
        optim_gen.load_state_dict(checkpoint['optim_state_dict'])
        optim_bd.load_state_dict(checkpoint['optim_bd_state_dict'])
        optim_mask.load_state_dict(checkpoint['optim_mask_state_dict'])

    trainer = Trainer.Trainer(
        cuda=cuda,
        model_gen=model_gen,
        model_bd=model_bd,
        model_mask=model_mask,
        optimizer_gen=optim_gen,
        optim_bd=optim_bd,
        optim_mask=optim_mask,
        lr_gen=args.lr_gen,
        lr_dis=args.lr_dis,
        lr_decrease_rate=args.lr_decrease_rate,
        train_loader=dataloader_train,
        validation_loader=dataloader_val,
        out=args.out,
        max_epoch=args.max_epoch,
        stop_epoch=args.stop_epoch,
        interval_validate=args.interval_validate,
        batch_size=args.batch_size,
        warmup_epoch=args.warmup_epoch,
        coefficient=args.coefficient,
        boundary_exist=args.boundary_exist
    )
    trainer.epoch = start_epoch
    trainer.iteration = start_iteration
    trainer.train()
Ejemplo n.º 5
0
def main():
    """Train the generator and its two discriminators for domain adaptation.

    Steps, in order:

    1. Parse command-line options.
    2. Create a timestamped output directory under
       ``<here>/logs/<datasetT>/`` and dump the parsed options to
       ``config.yaml`` inside it.
    3. Build the source-domain, target-domain and validation data loaders.
    4. Build the DeepLab generator, the boundary discriminator and the
       uncertainty discriminator, plus one optimizer for each.
    5. If ``--resume`` is given, restore model weights, optimizer states and
       the epoch/iteration counters from that checkpoint.
    6. Hand everything to ``Trainer.Trainer`` and call ``train()``.

    Relies on module-level names defined elsewhere in the file (``here``,
    ``osp``, ``yaml``, ``DL``, ``tr``, ``Trainer``, ``DeepLab``,
    ``BoundaryDiscriminator``, ``UncertaintyDiscriminator``, ...).
    """

    def _str2bool(value):
        """Convert a command-line token to a real boolean.

        ``type=bool`` must not be used with argparse: ``bool('False')`` is
        ``True`` because any non-empty string is truthy, so the flag could
        never be switched off from the command line.
        """
        if isinstance(value, bool):
            return value
        token = value.strip().lower()
        if token in ('true', 't', 'yes', 'y', '1'):
            return True
        # The empty string is kept as False: it was the only token that the
        # previous ``type=bool`` converter mapped to False.
        if token in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(
            'boolean value expected, got %r' % (value, ))

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
    parser.add_argument('-g', '--gpu', type=int, default=0, help='gpu id')
    parser.add_argument('--resume', default=None, help='checkpoint path')

    # configurations (same configuration as original work)
    # https://github.com/shelhamer/fcn.berkeleyvision.org
    parser.add_argument('--datasetS',
                        type=str,
                        default='refuge',
                        help='test folder id contain images ROIs to test')
    parser.add_argument('--datasetT',
                        type=str,
                        default='Drishti-GS',
                        help='refuge / Drishti-GS/ RIM-ONE_r3')
    parser.add_argument('--batch-size',
                        type=int,
                        default=8,
                        help='batch size for training the model')
    parser.add_argument('--group-num',
                        type=int,
                        default=1,
                        help='group number for group normalization')
    parser.add_argument('--max-epoch', type=int, default=200, help='max epoch')
    parser.add_argument('--stop-epoch',
                        type=int,
                        default=200,
                        help='stop epoch')
    parser.add_argument('--warmup-epoch',
                        type=int,
                        default=-1,
                        help='warmup epoch begin train GAN')

    parser.add_argument('--interval-validate',
                        type=int,
                        default=10,
                        help='interval epoch number to valide the model')
    parser.add_argument(
        '--lr-gen',
        type=float,
        default=1e-3,
        help='learning rate',
    )
    parser.add_argument(
        '--lr-dis',
        type=float,
        default=2.5e-5,
        help='learning rate',
    )
    parser.add_argument(
        '--lr-decrease-rate',
        type=float,
        default=0.1,
        help='ratio multiplied to initial lr',
    )
    parser.add_argument(
        '--weight-decay',
        type=float,
        default=0.0005,
        help='weight decay',
    )
    parser.add_argument(
        '--momentum',
        type=float,
        default=0.99,
        help='momentum',
    )
    parser.add_argument('--data-dir',
                        default='/home/sjwang/ssd1T/fundus/domain_adaptation/',
                        help='data root path')
    parser.add_argument(
        '--pretrained-model',
        default='../../../models/pytorch/fcn16s_from_caffe.pth',
        help='pretrained model of FCN16s',
    )
    parser.add_argument(
        '--out-stride',
        type=int,
        default=16,
        help='out-stride of deeplabv3+',
    )
    parser.add_argument(
        '--sync-bn',
        type=_str2bool,
        default=True,
        help='sync-bn in deeplabv3+',
    )
    parser.add_argument(
        '--freeze-bn',
        type=_str2bool,
        default=False,
        help='freeze batch normalization of deeplabv3+',
    )

    args = parser.parse_args()

    # Only recorded in config.yaml; this function does not read it again and
    # the network built below is DeepLab.
    args.model = 'FCN8s'

    now = datetime.now()
    args.out = osp.join(here, 'logs', args.datasetT,
                        now.strftime('%Y%m%d_%H%M%S.%f'))

    os.makedirs(args.out)
    with open(osp.join(args.out, 'config.yaml'), 'w') as f:
        yaml.safe_dump(args.__dict__, f, default_flow_style=False)

    # Must be set before any CUDA call so that only the chosen GPU is visible.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    cuda = torch.cuda.is_available()

    torch.manual_seed(1337)
    if cuda:
        torch.cuda.manual_seed(1337)

    # 1. dataset
    composed_transforms_tr = transforms.Compose([
        tr.RandomScaleCrop(512),
        tr.RandomRotate(),
        tr.RandomFlip(),
        tr.elastic_transform(),
        tr.add_salt_pepper_noise(),
        tr.adjust_light(),
        tr.eraser(),
        tr.Normalize_tf(),
        tr.ToTensor()
    ])

    composed_transforms_ts = transforms.Compose(
        [tr.RandomCrop(512),
         tr.Normalize_tf(),
         tr.ToTensor()])

    domain = DL.FundusSegmentation(base_dir=args.data_dir,
                                   dataset=args.datasetS,
                                   split='train',
                                   transform=composed_transforms_tr)
    domain_loaderS = DataLoader(domain,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=2,
                                pin_memory=True)
    domain_T = DL.FundusSegmentation(base_dir=args.data_dir,
                                     dataset=args.datasetT,
                                     split='train',
                                     transform=composed_transforms_tr)
    domain_loaderT = DataLoader(domain_T,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=2,
                                pin_memory=True)
    # NOTE(review): validation reads the target 'train' split (with the
    # test-time transforms) - confirm this is intended rather than 'test'.
    domain_val = DL.FundusSegmentation(base_dir=args.data_dir,
                                       dataset=args.datasetT,
                                       split='train',
                                       transform=composed_transforms_ts)
    domain_loader_val = DataLoader(domain_val,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=2,
                                   pin_memory=True)

    # 2. model
    model_gen = DeepLab(num_classes=2,
                        backbone='mobilenet',
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn).cuda()

    model_dis = BoundaryDiscriminator().cuda()
    model_dis2 = UncertaintyDiscriminator().cuda()

    start_epoch = 0
    start_iteration = 0

    # 3. optimizer

    optim_gen = torch.optim.Adam(model_gen.parameters(),
                                 lr=args.lr_gen,
                                 betas=(0.9, 0.99))
    optim_dis = torch.optim.SGD(model_dis.parameters(),
                                lr=args.lr_dis,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    optim_dis2 = torch.optim.SGD(model_dis2.parameters(),
                                 lr=args.lr_dis,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)

    if args.resume:
        checkpoint = torch.load(args.resume)

        # For each network: keep only the checkpoint entries whose key exists
        # in the current model, overlay them on the model's own state dict,
        # and load the merged result.
        for ckpt_key, network in (
            ('model_state_dict', model_gen),
            ('model_dis_state_dict', model_dis),
            ('model_dis2_state_dict', model_dis2),
        ):
            model_dict = network.state_dict()
            model_dict.update({
                k: v
                for k, v in checkpoint[ckpt_key].items() if k in model_dict
            })
            network.load_state_dict(model_dict)

        start_epoch = checkpoint['epoch'] + 1
        start_iteration = checkpoint['iteration'] + 1
        optim_gen.load_state_dict(checkpoint['optim_state_dict'])
        optim_dis.load_state_dict(checkpoint['optim_dis_state_dict'])
        optim_dis2.load_state_dict(checkpoint['optim_dis2_state_dict'])
        # The former `optim_adv.load_state_dict(...)` call was removed: no
        # `optim_adv` optimizer is created in this function, so every resume
        # ended in a NameError.

    trainer = Trainer.Trainer(
        cuda=cuda,
        model_gen=model_gen,
        model_dis=model_dis,
        model_uncertainty_dis=model_dis2,
        optimizer_gen=optim_gen,
        optimizer_dis=optim_dis,
        optimizer_uncertainty_dis=optim_dis2,
        lr_gen=args.lr_gen,
        lr_dis=args.lr_dis,
        lr_decrease_rate=args.lr_decrease_rate,
        val_loader=domain_loader_val,
        domain_loaderS=domain_loaderS,
        domain_loaderT=domain_loaderT,
        out=args.out,
        max_epoch=args.max_epoch,
        stop_epoch=args.stop_epoch,
        interval_validate=args.interval_validate,
        batch_size=args.batch_size,
        warmup_epoch=args.warmup_epoch,
    )
    trainer.epoch = start_epoch
    trainer.iteration = start_iteration
    trainer.train()
Ejemplo n.º 6
0
def main():
    """Train a single DeepLab model on the selected fundus domains.

    Steps, in order:

    1. Parse command-line options.
    2. Create a timestamped output directory under
       ``<local_path>/logs/test<first test id>/lam<lam>/`` and dump the
       parsed options to ``config.yaml`` inside it.
    3. Build the training and validation data loaders.
    4. Build the DeepLab model; if ``--resume`` is given, load the matching
       checkpoint weights and re-initialise ``model.centroids`` through
       ``centroids_init``.
    5. Create the Adam optimizer, hand everything to ``Trainer.Trainer`` and
       call ``train()``. Epoch and iteration counters always start at 0 here,
       also when resuming.

    Relies on module-level names defined elsewhere in the file
    (``local_path``, ``osp``, ``yaml``, ``DL``, ``tr``, ``Trainer``,
    ``DeepLab``, ``centroids_init``, ...).
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
    parser.add_argument('-g', '--gpu', type=int, default=0, help='gpu id')
    parser.add_argument('--resume', default=None, help='checkpoint path')

    # With nargs='+', user-supplied values arrive as a list, but argparse
    # stores the default as-is. The defaults must therefore be lists too:
    # the scalar default 1 made `args.datasetTest[0]` below raise TypeError
    # whenever the option was omitted.
    parser.add_argument(
        '--datasetTrain',
        nargs='+',
        type=int,
        default=[1],
        help='train folder id contain images ROIs to train range from [1,2,3,4]'
    )
    parser.add_argument(
        '--datasetTest',
        nargs='+',
        type=int,
        default=[1],
        help='test folder id contain images ROIs to test one of [1,2,3,4]')
    parser.add_argument('--batch-size',
                        type=int,
                        default=8,
                        help='batch size for training the model')
    parser.add_argument('--group-num',
                        type=int,
                        default=1,
                        help='group number for group normalization')
    parser.add_argument('--max-epoch', type=int, default=120, help='max epoch')
    parser.add_argument('--stop-epoch',
                        type=int,
                        default=80,
                        help='stop epoch')
    parser.add_argument('--interval-validate',
                        type=int,
                        default=10,
                        help='interval epoch number to valide the model')
    parser.add_argument(
        '--lr',
        type=float,
        default=1e-3,
        help='learning rate',
    )
    parser.add_argument('--lr-decrease-rate',
                        type=float,
                        default=0.2,
                        help='ratio multiplied to initial lr')
    parser.add_argument(
        '--lam',
        type=float,
        default=0.9,
        help='momentum of memory update',
    )
    parser.add_argument('--data-dir',
                        default='../../../../Dataset/Fundus/',
                        help='data root path')
    parser.add_argument(
        '--pretrained-model',
        default='../../../models/pytorch/fcn16s_from_caffe.pth',
        help='pretrained model of FCN16s',
    )
    parser.add_argument(
        '--out-stride',
        type=int,
        default=16,
        help='out-stride of deeplabv3+',
    )
    args = parser.parse_args()

    now = datetime.now()
    # Only the first test id is used to name the log directory.
    args.out = osp.join(local_path, 'logs', 'test' + str(args.datasetTest[0]),
                        'lam' + str(args.lam),
                        now.strftime('%Y%m%d_%H%M%S.%f'))
    os.makedirs(args.out)
    with open(osp.join(args.out, 'config.yaml'), 'w') as f:
        yaml.safe_dump(args.__dict__, f, default_flow_style=False)

    # Must be set before any CUDA call so that only the chosen GPU is visible.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    cuda = torch.cuda.is_available()
    torch.cuda.manual_seed(1337)

    # 1. dataset
    composed_transforms_tr = transforms.Compose([
        tr.RandomScaleCrop(256),
        # tr.RandomCrop(512),
        # tr.RandomRotate(),
        # tr.RandomFlip(),
        # tr.elastic_transform(),
        # tr.add_salt_pepper_noise(),
        # tr.adjust_light(),
        # tr.eraser(),
        tr.Normalize_tf(),
        tr.ToTensor()
    ])

    composed_transforms_ts = transforms.Compose(
        [tr.RandomCrop(256),
         tr.Normalize_tf(),
         tr.ToTensor()])

    domain = DL.FundusSegmentation(base_dir=args.data_dir,
                                   phase='train',
                                   splitid=args.datasetTrain,
                                   transform=composed_transforms_tr)
    train_loader = DataLoader(domain,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=2,
                              pin_memory=True)

    domain_val = DL.FundusSegmentation(base_dir=args.data_dir,
                                       phase='test',
                                       splitid=args.datasetTest,
                                       transform=composed_transforms_ts)
    val_loader = DataLoader(domain_val,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=2,
                            pin_memory=True)

    # 2. model
    model = DeepLab(num_classes=2,
                    num_domain=3,
                    backbone='mobilenet',
                    output_stride=args.out_stride,
                    lam=args.lam).cuda()
    print('parameter numer:', sum([p.numel() for p in model.parameters()]))

    # load weights
    if args.resume:
        checkpoint = torch.load(args.resume)
        pretrained_dict = checkpoint['model_state_dict']
        model_dict = model.state_dict()
        # 1. filter out unnecessary keys
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model.load_state_dict(model_dict)

        # Show the centroids before and after they are recomputed from the
        # training domains (the second label used to read 'Before ' as well).
        print('Before ', model.centroids.data)
        model.centroids.data = centroids_init(model, args.data_dir,
                                              args.datasetTrain,
                                              composed_transforms_ts)
        print('After ', model.centroids.data)
        # model.freeze_para()

    start_epoch = 0
    start_iteration = 0

    # 3. optimizer
    optim = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.99))

    trainer = Trainer.Trainer(
        cuda=cuda,
        model=model,
        lr=args.lr,
        lr_decrease_rate=args.lr_decrease_rate,
        train_loader=train_loader,
        val_loader=val_loader,
        optim=optim,
        out=args.out,
        max_epoch=args.max_epoch,
        stop_epoch=args.stop_epoch,
        interval_validate=args.interval_validate,
        batch_size=args.batch_size,
    )
    trainer.epoch = start_epoch
    trainer.iteration = start_iteration
    trainer.train()