def iterDataset(dataloader,
                pair_generator,
                ntg_model,
                vis,
                threshold=10,
                use_cuda=True):
    '''
    Iterate over batches from the dataset, run the NTG registration model on
    each generated image pair, and report registration metrics.
    :param dataloader: yields batches of images
    :param pair_generator: builds source/target pairs and theta_GT from a batch
    :param ntg_model: CNN that regresses the affine parameters
    :param vis: Visdom helper for visualizing intermediate results
    :param threshold: grid-loss threshold used by compute_correct_rate
    :param use_cuda: run on GPU if True
    :return: None; metrics are printed
    '''

    grid_loss_hist = []
    grid_loss_traditional_hist = []

    loss_fn = NTGLoss()
    gridGen = AffineGridGen()

    # GridLoss: grid-point error metric (average displacement of fixed grid
    # points under the estimated vs. ground-truth transforms)
    grid_loss = GridLoss(use_cuda=use_cuda)
    grid_loss_list = []
    grid_loss_ntg_list = []
    grid_loss_comb_list = []

    ntg_loss_total = 0

    # Iterate over batches; tensor shapes are noted in inline comments below.
    for batch_idx, batch in enumerate(dataloader):
        #print("batch_id",batch_idx,'/',len(dataloader))

        # if batch_idx == 2:
        #     break

        if batch_idx % 5 == 0:
            print('test batch: [{}/{} ({:.0f}%)]'.format(
                batch_idx, len(dataloader),
                100. * batch_idx / len(dataloader)))

        pair_batch = pair_generator(
            batch)  # image[batch_size,1,w,h] theta_GT[batch_size,2,3]

        theta_estimate_batch = ntg_model(pair_batch)  # theta [batch_size,6]

        source_image_batch = pair_batch['source_image']
        target_image_batch = pair_batch['target_image']
        theta_GT_batch = pair_batch['theta_GT']

        # AffineGridGen turns each 2x3 theta into a dense sampling grid, and
        # grid_sample bilinearly resamples the source image at those locations.
        sampling_grid = gridGen(theta_estimate_batch.view(-1, 2, 3))
        warped_image_batch = F.grid_sample(source_image_batch, sampling_grid)

        # The NTG loss is computed on image gradient fields; g1xy/g2xy are
        # (presumably) the per-image gradients, returned for inspection.
        loss, g1xy, g2xy = loss_fn(target_image_batch, warped_image_batch)
        #print("one batch ntg:",loss.item())
        ntg_loss_total += loss.item()

        # Visualize the CNN registration result
        # print("showing images")
        visualize_cnn_result(source_image_batch, target_image_batch,
                             theta_estimate_batch, vis)
        # time.sleep(10)
        # Visualize the comparison result for one epoch
        #visualize_compare_result(source_image_batch,target_image_batch,theta_GT_batch,theta_estimate_batch,use_cuda=use_cuda)

        # Plot the line chart across multiple epochs
        #visualize_iter_result(source_image_batch,target_image_batch,theta_GT_batch,theta_estimate_batch,use_cuda=use_cuda)

        ## Compute the grid-point registration error
        # Convert the PyTorch transform parameters to OpenCV transform parameters
        #theta_opencv = theta2param(theta_estimate_batch.view(-1, 2, 3), 240, 240, use_cuda=use_cuda)

        # P5: refine the CNN result with the traditional NTG method
        #ntg_param = estimate_param_batch(source_image_batch,target_image_batch,None,itermax=600)
        #ntg_param_pytorch = param2theta(ntg_param,240,240,use_cuda=use_cuda)
        #cnn_ntg_param_batch = estimate_param_batch(source_image_batch, target_image_batch, theta_opencv,itermax=800)
        #cnn_ntg_param_pytorch_batch = param2theta(cnn_ntg_param_batch, 240, 240, use_cuda=use_cuda)

        loss_cnn = grid_loss.compute_grid_loss(theta_estimate_batch,
                                               theta_GT_batch)
        #loss_ntg = grid_loss.compute_grid_loss(ntg_param_pytorch,theta_GT_batch)
        #loss_cnn_ntg = grid_loss.compute_grid_loss(cnn_ntg_param_pytorch_batch,theta_GT_batch)

        grid_loss_list.append(loss_cnn.detach().cpu())
        #grid_loss_ntg_list.append(loss_ntg)
        #grid_loss_comb_list.append(loss_cnn_ntg)
        ##

        # Show the grid-loss histogram for a specific epoch
        # g_loss,g_trad_loss = visualize_spec_epoch_result(source_image_batch, target_image_batch, theta_GT_batch, theta_estimate_batch,
        #                             use_cuda=use_cuda)
        # grid_loss_hist.append(g_loss)
        # grid_loss_traditional_hist.append(g_trad_loss)

        # loss_cnn = grid_loss.compute_grid_loss(theta_estimate_batch,theta_GT_list)
        #
        # loss_cnn_ntg = grid_loss.compute_grid_loss(cnn_ntg_param,theta_GT_list)
    print("计算平均网格点损失")
    compute_average_grid_loss(grid_loss_list)
    print("计算平均NTG值", ntg_loss_total / len(dataloader))

    print("计算正确率")
    compute_correct_rate(grid_loss_list, threshold=threshold)
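

# Usage sketch for iterDataset (hedged: `eval_loader` and this wiring are
# illustrative, not the repo's entry point; see main()/start_train() below):
#
#   model = CNNRegistration(use_cuda=True)
#   pair_generator = RandomTnsPair(use_cuda=True)
#   vis = VisdomHelper(env_name='eval')
#   iterDataset(eval_loader, pair_generator, model, vis, threshold=10)
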
def main(args):

    # checkpoint_path = "/home/zale/project/registration_cnn_ntg/trained_weight/voc2011_multi_gpu/checkpoint_voc2011_multi_gpu_paper_NTG_resnet101.pth.tar"
    # checkpoint_path = "/home/zale/project/registration_cnn_ntg/trained_weight/coco2017_multi_gpu/checkpoint_coco2017_multi_gpu_paper30_NTG_resnet101.pth.tar"
    #args.training_image_path = '/home/zale/datasets/vocdata/VOC_train_2011/VOCdevkit/VOC2011/JPEGImages'
    # args.training_image_path = '/media/disk2/zale/datasets/coco2017/train2017'

    checkpoint_path = "/home/zlk/project/registration_cnn_ntg/trained_weight/voc2011_multi_gpu/checkpoint_voc2011_multi_gpu_three_channel_paper_origin_NTG_resnet101.pth.tar"
    args.training_image_path = '/home/zlk/datasets/vocdata/VOC_train_2011/VOCdevkit/VOC2011/JPEGImages'

    random_seed = 10021
    init_seeds(random_seed + random.randint(0, 10000))
    mixed_precision = True

    utils.init_distributed_mode(args)
    print(args)

    #device,local_rank = torch_util.select_device(multi_process =True,apex=mixed_precision)

    device = torch.device(args.device)
    use_cuda = True
    # Data loading code
    print("Loading data")
    RandomTnsDataset = RandomTnsData(args.training_image_path,
                                     cache_images=False,
                                     paper_affine_generator=True,
                                     transform=NormalizeImageDict(["image"]))
    # train_dataloader = DataLoader(RandomTnsDataset, batch_size=args.batch_size, shuffle=True, num_workers=4,
    #                               pin_memory=True)

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            RandomTnsDataset)
        # test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(RandomTnsDataset)
        # test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    # train_batch_sampler = torch.utils.data.BatchSampler(
    #     train_sampler, args.batch_size, drop_last=True)

    # A sampler is always constructed above, so `shuffle` evaluates to False
    # here and shuffling is delegated entirely to the sampler.
    data_loader = DataLoader(RandomTnsDataset,
                             sampler=train_sampler,
                             num_workers=4,
                             shuffle=(train_sampler is None),
                             pin_memory=False,
                             batch_size=args.batch_size)

    # data_loader_test = torch.utils.data.DataLoader(
    #     dataset_test, batch_size=1,
    #     sampler=test_sampler, num_workers=args.workers,
    #     collate_fn=utils.collate_fn)

    print("Creating model")
    model = CNNRegistration(use_cuda=use_cuda)

    model.to(device)

    # Optimizer and scheduler
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=args.lr)

    # The NTG loss decreases very slowly once the learning rate falls below 1e-6, so eta_min is set to 1e-6
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.lr_max_iter, eta_min=1e-6)

    # if mixed_precision:
    #     model,optimizer = amp.initialize(model,optimizer,opt_level='O1',verbosity=0)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    # load_checkpoint restores model/optimizer/scheduler state and returns the
    # best loss seen so far plus the epoch to resume from.
    minium_loss, saved_epoch = load_checkpoint(model_without_ddp, optimizer,
                                               lr_scheduler, checkpoint_path,
                                               args.rank)

    vis_env = "multi_gpu_rgb_train_paper_30"
    loss = NTGLoss()
    pair_generator = RandomTnsPair(use_cuda=use_cuda)
    gridGen = AffineGridGen()
    vis = VisdomHelper(env_name=vis_env)

    print('Starting training...')
    start_time = time.time()
    draw_test_loss = False
    log_interval = 20
    for epoch in range(saved_epoch, args.num_epochs):
        start_time = time.time()

        if args.distributed:
            train_sampler.set_epoch(epoch)

        train_loss = train(epoch,
                           model,
                           loss,
                           optimizer,
                           data_loader,
                           pair_generator,
                           gridGen,
                           vis,
                           use_cuda=use_cuda,
                           log_interval=log_interval,
                           lr_scheduler=lr_scheduler,
                           rank=args.rank)

        if draw_test_loss:
            #test_loss = test(model,loss,test_dataloader,pair_generator,gridGen,use_cuda=use_cuda)
            #vis.drawBothLoss(epoch,train_loss,test_loss,'loss_table')
            pass
        else:
            vis.drawLoss(epoch, train_loss)

        end_time = time.time()
        print("epoch:", str(end_time - start_time), '秒')

        is_best = train_loss < minium_loss
        minium_loss = min(train_loss, minium_loss)

        state_dict = model_without_ddp.state_dict()
        if is_main_process():
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'args': args,
                    # 'state_dict': model.state_dict(),
                    'state_dict': state_dict,
                    'minium_loss': minium_loss,
                    'model_loss': train_loss,
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                },
                is_best,
                checkpoint_path)
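

# Entry-point sketch (assumption: `args` comes from an argparse parser that
# exposes device, lr, batch_size, num_epochs, lr_max_iter, rank, etc.;
# `parse_args` is a hypothetical stand-in):
#
#   if __name__ == '__main__':
#       args = parse_args()
#       main(args)
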
def start_train(training_path, test_image_path, load_from, out_path, vis_env,
                paper_affine_generator=False, random_seed=666, log_interval=100,
                multi_gpu=True, use_cuda=True):

    init_seeds(random_seed + random.randint(0, 10000))

    device, local_rank = torch_util.select_device(multi_process=multi_gpu,
                                                  apex=mixed_precision)

    # args.batch_size = args.batch_size * torch.cuda.device_count()
    args.batch_size = 16
    args.lr_scheduler = True
    draw_test_loss = False
    print(args.batch_size)


    print("创建模型中")
    model = CNNRegistration(use_cuda=use_cuda)

    model = model.to(device)

    # Optimizer and scheduler
    optimizer = optim.Adam(model.FeatureRegression.parameters(), lr=args.lr)

    if args.lr_scheduler:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                               T_max=args.lr_max_iter,
                                                               eta_min=1e-7)
    else:
        scheduler = False

    print("加载权重")
    minium_loss,saved_epoch= load_checkpoint(model,optimizer,load_from,0)

    # Mixed precision training https://github.com/NVIDIA/apex
    if mixed_precision:
        # opt_level is the letter O followed by 1 ('O1'), not zero-one
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)

    if multi_gpu:
        model = nn.DataParallel(model)

    loss = NTGLoss()
    pair_generator = RandomTnsPair(use_cuda=use_cuda)
    gridGen = AffineGridGen()
    vis = VisdomHelper(env_name=vis_env)

    print("创建dataloader")
    RandomTnsDataset = RandomTnsData(training_path, cache_images=False,paper_affine_generator = paper_affine_generator,
                                     transform=NormalizeImageDict(["image"]))
    train_dataloader = DataLoader(RandomTnsDataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)

    if draw_test_loss:
        testDataset = RandomTnsData(test_image_path, cache_images=False, paper_affine_generator=paper_affine_generator,
                                     transform=NormalizeImageDict(["image"]))
        test_dataloader = DataLoader(testDataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=False)

    print('Starting training...')

    for epoch in range(saved_epoch, args.num_epochs):
        start_time = time.time()

        train_loss = train(epoch, model, loss, optimizer, train_dataloader,
                           pair_generator, gridGen, vis, use_cuda=use_cuda,
                           log_interval=log_interval, scheduler=scheduler)

        if draw_test_loss:
            test_loss = test(model, loss, test_dataloader, pair_generator,
                             gridGen, use_cuda=use_cuda)
            vis.drawBothLoss(epoch, train_loss, test_loss, 'loss_table')
        else:
            vis.drawLoss(epoch, train_loss)

        end_time = time.time()
        print("epoch:", str(end_time - start_time),'秒')

        is_best = train_loss < minium_loss
        minium_loss = min(train_loss, minium_loss)

        state_dict = model.module.state_dict() if multi_gpu else model.state_dict()
        save_checkpoint({
            'epoch': epoch + 1,
            'args': args,
            #'state_dict': model.state_dict(),
            'state_dict': state_dict,
            'minium_loss': minium_loss,
            'model_loss': train_loss,
            'optimizer': optimizer.state_dict(),
        }, is_best, out_path)
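

# Usage sketch for start_train (paths are illustrative placeholders; note that
# `args` and `mixed_precision` are read as module-level globals here):
#
#   start_train('/data/VOC2011/JPEGImages', '/data/VOC2011/JPEGImages_test',
#               load_from='trained_weight/checkpoint.pth.tar',
#               out_path='trained_weight/checkpoint.pth.tar',
#               vis_env='ntg_train', multi_gpu=True, use_cuda=True)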