Example #1
def createNirDataloader(nir_path,
                        rgb_path,
                        label_path,
                        batch_size=16,
                        use_cuda=True):
    '''
    Create the NIR/RGB dataloader and pair generator.
    :param nir_path: path to the NIR images
    :param rgb_path: path to the RGB images
    :param label_path: path to the label file
    :param batch_size: mini-batch size
    :param use_cuda: whether to move tensors to the GPU
    :return: (dataloader, pair_generator)
    '''
    dataset = NirRgbData(nir_path,
                         rgb_path,
                         label_path,
                         transform=NormalizeImageDict(
                             ["nir_image", "rgb_image"]))
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=4,
                            pin_memory=True)
    pair_generator = NirRgbTnsPair(use_cuda=use_cuda)

    return dataloader, pair_generator
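

# Usage sketch (illustrative only, not part of the original script). The exact
# interface of NirRgbTnsPair is an assumption here: it is treated as a callable
# that turns a raw dataloader batch into the source/target training pair.
def demo_nir_dataloader_usage(nir_dir, rgb_dir, label_csv):
    dataloader, pair_generator = createNirDataloader(nir_dir, rgb_dir, label_csv,
                                                     batch_size=4, use_cuda=False)
    for batch in dataloader:
        pair = pair_generator(batch)  # assumed: returns the training pair for the model
        print(type(pair))
        break
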
def createDataloader(image_path, label_path, batch_size=16, use_cuda=True):
    '''
    Create the test dataloader and pair generator.
    :param image_path: path to the images
    :param label_path: path to the label file
    :param batch_size: mini-batch size
    :param use_cuda: whether to move tensors to the GPU
    :return: (dataloader, pair_generator)
    '''
    #dataset = SinglechannelData(image_path,label_path,transform=NormalizeImage(normalize_range=True, normalize_img=False))
    dataset = TestDataset(image_path,
                          label_path,
                          transform=NormalizeImageDict(["image"]))
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=4,
                            pin_memory=True)
    #pair_generator = SingleChannelPairTnf(use_cuda=use_cuda)
    pair_generator = RandomTnsPair(use_cuda=use_cuda)

    return dataloader, pair_generator
def main(args):

    # checkpoint_path = "/home/zale/project/registration_cnn_ntg/trained_weight/voc2011_multi_gpu/checkpoint_voc2011_multi_gpu_paper_NTG_resnet101.pth.tar"
    # checkpoint_path = "/home/zale/project/registration_cnn_ntg/trained_weight/coco2017_multi_gpu/checkpoint_coco2017_multi_gpu_paper30_NTG_resnet101.pth.tar"
    #args.training_image_path = '/home/zale/datasets/vocdata/VOC_train_2011/VOCdevkit/VOC2011/JPEGImages'
    # args.training_image_path = '/media/disk2/zale/datasets/coco2017/train2017'

    checkpoint_path = "/home/zlk/project/registration_cnn_ntg/trained_weight/voc2011_multi_gpu/checkpoint_voc2011_multi_gpu_three_channel_paper_origin_NTG_resnet101.pth.tar"
    args.training_image_path = '/home/zlk/datasets/vocdata/VOC_train_2011/VOCdevkit/VOC2011/JPEGImages'

    random_seed = 10021
    init_seeds(random_seed + random.randint(0, 10000))
    mixed_precision = True

    utils.init_distributed_mode(args)
    print(args)

    #device,local_rank = torch_util.select_device(multi_process =True,apex=mixed_precision)

    device = torch.device(args.device)
    use_cuda = True
    # Data loading code
    print("Loading data")
    RandomTnsDataset = RandomTnsData(args.training_image_path,
                                     cache_images=False,
                                     paper_affine_generator=True,
                                     transform=NormalizeImageDict(["image"]))
    # train_dataloader = DataLoader(RandomTnsDataset, batch_size=args.batch_size, shuffle=True, num_workers=4,
    #                               pin_memory=True)

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            RandomTnsDataset)
        # test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(RandomTnsDataset)
        # test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    # train_batch_sampler = torch.utils.data.BatchSampler(
    #     train_sampler, args.batch_size, drop_last=True)

    data_loader = DataLoader(RandomTnsDataset,
                             sampler=train_sampler,
                             num_workers=4,
                             shuffle=(train_sampler is None),
                             pin_memory=False,
                             batch_size=args.batch_size)

    # data_loader_test = torch.utils.data.DataLoader(
    #     dataset_test, batch_size=1,
    #     sampler=test_sampler, num_workers=args.workers,
    #     collate_fn=utils.collate_fn)

    print("Creating model")
    model = CNNRegistration(use_cuda=use_cuda)

    model.to(device)

    # Optimizer and scheduler
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=args.lr)

    # With a learning rate below 1e-6 the NTG loss decreases very slowly, so eta_min is set to 1e-6
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.lr_max_iter, eta_min=1e-6)

    # if mixed_precision:
    #     model,optimizer = amp.initialize(model,optimizer,opt_level='O1',verbosity=0)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    minium_loss, saved_epoch = load_checkpoint(model_without_ddp, optimizer,
                                               lr_scheduler, checkpoint_path,
                                               args.rank)

    vis_env = "multi_gpu_rgb_train_paper_30"
    loss = NTGLoss()
    pair_generator = RandomTnsPair(use_cuda=use_cuda)
    gridGen = AffineGridGen()
    vis = VisdomHelper(env_name=vis_env)

    print('Starting training...')
    start_time = time.time()
    draw_test_loss = False
    log_interval = 20
    for epoch in range(saved_epoch, args.num_epochs):
        start_time = time.time()

        if args.distributed:
            train_sampler.set_epoch(epoch)

        train_loss = train(epoch,
                           model,
                           loss,
                           optimizer,
                           data_loader,
                           pair_generator,
                           gridGen,
                           vis,
                           use_cuda=use_cuda,
                           log_interval=log_interval,
                           lr_scheduler=lr_scheduler,
                           rank=args.rank)

        if draw_test_loss:
            #test_loss = test(model,loss,test_dataloader,pair_generator,gridGen,use_cuda=use_cuda)
            #vis.drawBothLoss(epoch,train_loss,test_loss,'loss_table')
            pass
        else:
            vis.drawLoss(epoch, train_loss)

        end_time = time.time()
        print("epoch:", str(end_time - start_time), '秒')

        is_best = train_loss < minium_loss
        minium_loss = min(train_loss, minium_loss)

        state_dict = model_without_ddp.state_dict()
        if is_main_process():
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'args': args,
                    # 'state_dict': model.state_dict(),
                    'state_dict': state_dict,
                    'minium_loss': minium_loss,
                    'model_loss': train_loss,
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                },
                is_best,
                checkpoint_path)
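

# For reference: NTGLoss above is assumed to implement the normalized total
# gradient (NTG) measure, NTG(A, B) = ||grad(A - B)||_1 / (||grad A||_1 + ||grad B||_1).
# The sketch below is an illustrative re-implementation using simple finite
# differences, not the project's own loss.
def ntg_measure_sketch(img_a, img_b, eps=1e-8):
    # img_a, img_b: tensors of shape (B, C, H, W)
    def total_gradient(x):
        gx = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1]).sum()
        gy = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :]).sum()
        return gx + gy
    return total_gradient(img_a - img_b) / (total_gradient(img_a) + total_gradient(img_b) + eps)
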
Example #4
def start_train(training_path, test_image_path, load_from, out_path, vis_env,
                paper_affine_generator=False, random_seed=666, log_interval=100,
                multi_gpu=True, use_cuda=True):

    init_seeds(random_seed + random.randint(0, 10000))

    # Note: `mixed_precision` and `args` are used below but are not defined in this
    # function; they are expected to exist as module-level globals.
    device, local_rank = torch_util.select_device(multi_process=multi_gpu,
                                                  apex=mixed_precision)

    # args.batch_size = args.batch_size * torch.cuda.device_count()
    args.batch_size = 16
    args.lr_scheduler = True
    draw_test_loss = False
    print(args.batch_size)


    print("创建模型中")
    model = CNNRegistration(use_cuda=use_cuda)

    model = model.to(device)

    # Optimizer and scheduler
    optimizer = optim.Adam(model.FeatureRegression.parameters(), lr=args.lr)

    if args.lr_scheduler:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                               T_max=args.lr_max_iter,
                                                               eta_min=1e-7)
    else:
        scheduler = False

    print("加载权重")
    minium_loss, saved_epoch = load_checkpoint(model, optimizer, load_from, 0)

    # Mixed precision training https://github.com/NVIDIA/apex
    if mixed_precision:
        # opt_level must be 'O1' (capital letter O), not '01'
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)

    if multi_gpu:
        model = nn.DataParallel(model)

    loss = NTGLoss()
    pair_generator = RandomTnsPair(use_cuda=use_cuda)
    gridGen = AffineGridGen()
    vis = VisdomHelper(env_name=vis_env)

    print("创建dataloader")
    RandomTnsDataset = RandomTnsData(training_path, cache_images=False,
                                     paper_affine_generator=paper_affine_generator,
                                     transform=NormalizeImageDict(["image"]))
    train_dataloader = DataLoader(RandomTnsDataset, batch_size=args.batch_size,
                                  shuffle=True, num_workers=4, pin_memory=True)

    if draw_test_loss:
        testDataset = RandomTnsData(test_image_path, cache_images=False,
                                    paper_affine_generator=paper_affine_generator,
                                    transform=NormalizeImageDict(["image"]))
        test_dataloader = DataLoader(testDataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=4, pin_memory=False)

    print('Starting training...')

    for epoch in range(saved_epoch, args.num_epochs):
        start_time = time.time()

        train_loss = train(epoch, model, loss, optimizer, train_dataloader,
                           pair_generator, gridGen, vis, use_cuda=use_cuda,
                           log_interval=log_interval, scheduler=scheduler)

        if draw_test_loss:
            test_loss = test(model, loss, test_dataloader, pair_generator, gridGen,
                             use_cuda=use_cuda)
            vis.drawBothLoss(epoch, train_loss, test_loss, 'loss_table')
        else:
            vis.drawLoss(epoch, train_loss)

        end_time = time.time()
        print("epoch:", str(end_time - start_time),'秒')

        is_best = train_loss < minium_loss
        minium_loss = min(train_loss, minium_loss)

        state_dict = model.module.state_dict() if multi_gpu else model.state_dict()
        save_checkpoint({
            'epoch': epoch + 1,
            'args': args,
            #'state_dict': model.state_dict(),
            'state_dict': state_dict,
            'minium_loss': minium_loss,
            'model_loss': train_loss,
            'optimizer': optimizer.state_dict(),
        }, is_best, out_path)
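

# Illustrative call into start_train (the paths and the Visdom environment name are
# placeholders, not the authors' configuration; the module-level `args` and
# `mixed_precision` globals must be defined beforehand, as noted above).
#
# start_train(training_path='/path/to/train_images',
#             test_image_path='/path/to/test_images',
#             load_from='/path/to/checkpoint_NTG_resnet101.pth.tar',
#             out_path='/path/to/output_checkpoint.pth.tar',
#             vis_env='ntg_train',
#             paper_affine_generator=True)
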
Example #5
def main():
    print("eval pf dataset")
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    # ntg_checkpoint_path = "/home/zlk/project/registration_cnn_ntg/trained_weight/output/voc2012_coco2014_NTG_resnet101.pth.tar"
    # ntg_checkpoint_path = "/home/zlk/project/registration_cnn_ntg/trained_weight/voc2011/checkpoint_voc2011_NTG_resnet101.pth.tar"
    # ntg_checkpoint_path = "/home/zlk/project/registration_cnn_ntg/trained_weight/voc2011/checkpoint_voc2011_20r_NTG_resnet101.pth.tar"
    # ntg_checkpoint_path = '/home/zlk/project/registration_cnn_ntg/trained_weight/three_channel/checkpoint_NTG_resnet101.pth.tar'
    small_aff_ntg_checkpoint_path = '/home/zlk/project/registration_cnn_ntg/trained_weight/three_channel/coco2014_small_aff_checkpoint_NTG_resnet101.pth.tar'
    ntg_checkpoint_path = '/home/zlk/project/registration_cnn_ntg/trained_weight/voc2011/best_checkpoint_voc2011_three_channel_paper_NTG_resnet101.pth.tar'
    # ntg_checkpoint_path = '/home/zlk/project/registration_cnn_ntg/trained_weight/voc2011_paper_affine/best_checkpoint_voc2011_NTG_resnet101.pth.tar'

    #ntg_checkpoint_path = "/home/zlk/project/registration_cnn_ntg/trained_weight/voc2011/checkpoint_voc2011_30r_NTG_resnet101.pth.tar"
    # image_path = '../datasets/row_data/VOC/3
    # label_path = '../datasets/row_data/label_file/aff_param2.csv'
    #image_path = '../datasets/row_data/COCO/'
    #label_path = '../datasets/row_data/label_file/aff_param_coco.csv'

    pf_data_path = 'datasets/row_data/pf_data'

    batch_size = 128
    # Load the models
    use_cuda = torch.cuda.is_available()

    ntg_model = createModel(ntg_checkpoint_path, use_cuda=use_cuda)
    small_aff_ntg_model = createModel(small_aff_ntg_checkpoint_path,
                                      use_cuda=use_cuda)

    dataset = PFDataset(
        csv_file=os.path.join(pf_data_path, 'test_pairs_pf.csv'),
        training_image_path=pf_data_path,
        transform=NormalizeImageDict(['source_image', 'target_image']))

    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=4)

    batchTensorToVars = BatchTensorToVars(use_cuda=use_cuda)

    pt = PointTnf(use_cuda=use_cuda)

    print('Computing PCK...')
    total_correct_points_aff = 0
    ntg_total_correct_points_aff = 0
    cnn_ntg_total_correct_points_aff = 0
    total_correct_points_tps = 0
    total_correct_points_aff_tps = 0
    total_points = 0
    ntg_total_points = 0
    cnn_ntg_total_points = 0

    for i, batch in enumerate(dataloader):
        batch = batchTensorToVars(batch)
        source_im_size = batch['source_im_size']
        target_im_size = batch['target_im_size']

        source_points = batch['source_points']
        target_points = batch['target_points']

        source_image_batch = batch['source_image']
        target_image_batch = batch['target_image']

        # warp points with estimated transformations
        target_points_norm = PointsToUnitCoords(target_points, target_im_size)

        theta_estimate_batch = ntg_model(batch)

        #warped_image_batch = affine_transform_pytorch(source_image_batch, theta_estimate_batch)
        #batch['source_image'] = warped_image_batch
        #theta_estimate_batch = small_aff_ntg_model(batch)

        # Convert the PyTorch transform parameters to OpenCV transform parameters
        #theta_opencv = theta2param(theta_estimate_batch.view(-1, 2, 3), 240, 240, use_cuda=use_cuda)

        # P5: refine the CNN result with the traditional (iterative) NTG method
        #cnn_ntg_param_batch = estimate_param_batch(source_image_batch, target_image_batch, theta_opencv,itermax = 600)
        #theta_pytorch = param2theta(cnn_ntg_param_batch.view(-1, 2, 3),240,240,use_cuda=use_cuda)

        # theta_opencv = theta2param(theta_estimate_batch.view(-1, 2, 3), 240, 240, use_cuda=use_cuda)
        # with torch.no_grad():
        #     ntg_param_batch = estimate_aff_param_iterator(source_image_batch[:, 0, :, :].unsqueeze(1),
        #                                                   target_image_batch[:, 0, :, :].unsqueeze(1),
        #                                                   None, use_cuda=use_cuda, itermax=600)
        #
        #     cnn_ntg_param_batch = estimate_aff_param_iterator(source_image_batch[:, 0, :, :].unsqueeze(1),
        #                                                       target_image_batch[:, 0, :, :].unsqueeze(1),
        #                                                       theta_opencv, use_cuda=use_cuda, itermax=600)
        #
        #     ntg_param_pytorch_batch = param2theta(ntg_param_batch,240, 240, use_cuda=use_cuda)
        #     cnn_ntg_param_pytorch_batch = param2theta(cnn_ntg_param_batch,240, 240, use_cuda=use_cuda)

        warped_points_aff_norm = pt.affPointTnf(theta_estimate_batch,
                                                target_points_norm)
        warped_points_aff = PointsToPixelCoords(warped_points_aff_norm,
                                                source_im_size)

        # ntg_warped_points_aff_norm = pt.affPointTnf(ntg_param_pytorch_batch, target_points_norm)
        # ntg_warped_points_aff = PointsToPixelCoords(ntg_warped_points_aff_norm, source_im_size)
        #
        # cnn_ntg_warped_points_aff_norm = pt.affPointTnf(cnn_ntg_param_pytorch_batch, target_points_norm)
        # cnn_ntg_warped_points_aff = PointsToPixelCoords(cnn_ntg_warped_points_aff_norm, source_im_size)

        L_pck = batch['L_pck'].data

        correct_points_aff, num_points = correct_keypoints(
            source_points.data, warped_points_aff.data, L_pck)
        # ntg_correct_points_aff, ntg_num_points = correct_keypoints(source_points.data,
        #                                                    ntg_warped_points_aff.data, L_pck)
        # cnn_ntg_correct_points_aff, cnn_ntg_num_points = correct_keypoints(source_points.data,
        #                                                    cnn_ntg_warped_points_aff.data, L_pck)

        total_correct_points_aff += correct_points_aff
        total_points += num_points

        # ntg_total_correct_points_aff += ntg_correct_points_aff
        # ntg_total_points += ntg_num_points
        #
        # cnn_ntg_total_correct_points_aff += cnn_ntg_correct_points_aff
        # cnn_ntg_total_points += cnn_ntg_num_points

        print('Batch: [{}/{} ({:.0f}%)]'.format(i, len(dataloader),
                                                100. * i / len(dataloader)))

    total_correct_points_aff = float(total_correct_points_aff)
    # ntg_total_correct_points_aff = ntg_total_correct_points_aff.__float__()
    # cnn_ntg_total_correct_points_aff = cnn_ntg_total_correct_points_aff.__float__()

    PCK_aff = total_correct_points_aff / total_points
    # ntg_PCK_aff=ntg_total_correct_points_aff/ntg_total_points
    # cnn_ntg_PCK_aff=cnn_ntg_total_correct_points_aff/cnn_ntg_total_points
    print('PCK affine:', PCK_aff)
    # print('ntg_PCK affine:',ntg_PCK_aff)
    # print('cnn_ntg_PCK affine:',cnn_ntg_PCK_aff)
    print('Done!')
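

# For reference: a minimal sketch of the PCK criterion that correct_keypoints is
# assumed to implement for the PF dataset (a warped keypoint counts as correct when
# its distance to the ground-truth point is within alpha * L_pck, with alpha = 0.1
# being the common choice). This is an illustrative re-implementation, not the
# project's own function.
def pck_sketch(gt_points, warped_points, L_pck, alpha=0.1):
    # gt_points, warped_points: tensors of shape (B, 2, N); L_pck: per-image reference length
    dist = torch.sqrt(torch.sum((gt_points - warped_points) ** 2, dim=1))  # (B, N)
    correct = dist <= alpha * L_pck.view(-1, 1).expand_as(dist)
    return correct.sum().item(), dist.numel()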