Example no. 1
        model.FeatureRegression.train()
        train_loss[epoch - 1] = process_epoch('train',
                                              epoch,
                                              model,
                                              loss,
                                              optimizer,
                                              dataloader,
                                              pair_generation_tnf,
                                              log_interval=100)
        model.FeatureRegression.eval()
        test_loss[epoch - 1] = process_epoch('test',
                                             epoch,
                                             model,
                                             loss,
                                             optimizer,
                                             dataloader_test,
                                             pair_generation_tnf,
                                             log_interval=100)

        # remember best loss
        is_best = test_loss[epoch - 1] < best_test_loss
        best_test_loss = min(test_loss[epoch - 1], best_test_loss)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': args,
                'state_dict': model.state_dict(),
                'best_test_loss': best_test_loss,
                'optimizer': optimizer.state_dict(),
                'train_loss': train_loss,
                'test_loss': test_loss,
            }, is_best, checkpoint_name)

    print('Done!')
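
Most examples in this listing assume a save_checkpoint(state, is_best, path) helper (the third example uses a two-argument variant without is_best). A minimal sketch of the conventional PyTorch pattern is shown below; the real helper in each repository may differ, and the '_best' filename suffix is an assumption.

import shutil

import torch


def save_checkpoint(state, is_best, file_path):
    # Always persist the latest state (epoch, model, optimizer, losses, ...).
    torch.save(state, file_path)
    # Keep a separate copy of the best-scoring checkpoint so it is never
    # overwritten by a later, worse epoch.
    if is_best:
        shutil.copyfile(file_path, file_path.replace('.pth.tar', '_best.pth.tar'))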
Example no. 2
def main():

    args, arg_groups = ArgumentParser(mode='train').parse()
    print(args)

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda') if use_cuda else torch.device('cpu')
    # Seed
    torch.manual_seed(args.seed)
    if use_cuda:
        torch.cuda.manual_seed(args.seed)

    # Download dataset if needed and set paths
    if args.training_dataset == 'pascal':

        if args.dataset_image_path == '' and not os.path.exists(
                'datasets/pascal-voc11/TrainVal'):
            download_pascal('datasets/pascal-voc11/')

        if args.dataset_image_path == '':
            args.dataset_image_path = 'datasets/pascal-voc11/'

        args.dataset_csv_path = 'training_data/pascal-random'

    # CNN model and loss
    print('Creating CNN model...')
    if args.geometric_model == 'affine':
        cnn_output_dim = 6
    elif args.geometric_model == 'hom' and args.four_point_hom:
        cnn_output_dim = 8
    elif args.geometric_model == 'hom' and not args.four_point_hom:
        cnn_output_dim = 9
    elif args.geometric_model == 'tps':
        cnn_output_dim = 18
    else:
        raise NotImplementedError('Specified geometric model is unsupported')

    model = CNNGeometric(use_cuda=use_cuda,
                         output_dim=cnn_output_dim,
                         **arg_groups['model'])

    if args.geometric_model == 'hom' and not args.four_point_hom:
        init_theta = torch.tensor([1, 0, 0, 0, 1, 0, 0, 0, 1], device=device)
        model.FeatureRegression.linear.bias.data += init_theta

    if args.geometric_model == 'hom' and args.four_point_hom:
        init_theta = torch.tensor([-1, -1, 1, 1, -1, 1, -1, 1], device=device)
        model.FeatureRegression.linear.bias.data += init_theta

    if args.use_mse_loss:
        print('Using MSE loss...')
        loss = nn.MSELoss()
    else:
        print('Using grid loss...')
        loss = TransformedGridLoss(use_cuda=use_cuda,
                                   geometric_model=args.geometric_model)

    # Initialize Dataset objects
    dataset = SynthDataset(geometric_model=args.geometric_model,
                           dataset_csv_path=args.dataset_csv_path,
                           dataset_csv_file='train.csv',
                           dataset_image_path=args.dataset_image_path,
                           transform=NormalizeImageDict(['image']),
                           random_sample=args.random_sample)

    dataset_val = SynthDataset(geometric_model=args.geometric_model,
                               dataset_csv_path=args.dataset_csv_path,
                               dataset_csv_file='val.csv',
                               dataset_image_path=args.dataset_image_path,
                               transform=NormalizeImageDict(['image']),
                               random_sample=args.random_sample)

    # Set Tnf pair generation func
    pair_generation_tnf = SynthPairTnf(geometric_model=args.geometric_model,
                                       use_cuda=use_cuda)

    # Initialize DataLoaders
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=4)

    dataloader_val = DataLoader(dataset_val,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=4)

    # Optimizer and eventual scheduler
    optimizer = optim.Adam(model.FeatureRegression.parameters(), lr=args.lr)

    if args.lr_scheduler:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.lr_max_iter, eta_min=1e-6)
        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
    else:
        scheduler = False

    # Train

    # Set up names for checkpoints
    if args.use_mse_loss:
        ckpt = args.trained_model_fn + '_' + args.geometric_model + '_mse_loss_' + args.feature_extraction_cnn
    else:
        ckpt = args.trained_model_fn + '_' + args.geometric_model + '_grid_loss_' + args.feature_extraction_cnn
    checkpoint_path = os.path.join(args.trained_model_dir,
                                   args.trained_model_fn,
                                   ckpt + '.pth.tar')
    if not os.path.exists(args.trained_model_dir):
        os.mkdir(args.trained_model_dir)

    # Set up TensorBoard writer
    if not args.log_dir:
        tb_dir = os.path.join(args.trained_model_dir,
                              args.trained_model_fn + '_tb_logs')
    else:
        tb_dir = os.path.join(args.log_dir, args.trained_model_fn + '_tb_logs')

    logs_writer = SummaryWriter(tb_dir)
    # Add the model graph to TensorBoard; tracing it requires a dummy input
    dummy_input = {
        'source_image': torch.rand([args.batch_size, 3, 240, 240],
                                   device=device),
        'target_image': torch.rand([args.batch_size, 3, 240, 240],
                                   device=device),
        'theta_GT': torch.rand([args.batch_size, 2, 3], device=device)
    }

    logs_writer.add_graph(model, dummy_input)

    # Start of training
    print('Starting training...')

    best_val_loss = float("inf")

    for epoch in range(1, args.num_epochs + 1):

        # we don't need the average epoch loss so we assign it to _
        _ = train(epoch,
                  model,
                  loss,
                  optimizer,
                  dataloader,
                  pair_generation_tnf,
                  log_interval=args.log_interval,
                  scheduler=scheduler,
                  tb_writer=logs_writer)

        val_loss = validate_model(model, loss, dataloader_val,
                                  pair_generation_tnf, epoch, logs_writer)

        # remember best loss
        is_best = val_loss < best_val_loss
        best_val_loss = min(val_loss, best_val_loss)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': args,
                'state_dict': model.state_dict(),
                'best_val_loss': best_val_loss,
                'optimizer': optimizer.state_dict(),
            }, is_best, checkpoint_path)

    logs_writer.close()
    print('Done!')
Example no. 3

    if args.load_model:
        load_dir = 'trained_models/checkpoint_seresnext101.pth.tar'
        checkpoint = torch.load(load_dir, map_location=lambda storage, loc: storage)  # Load trained model

        # Load parameters of FeatureExtraction
        for name, param in model.FeatureExtraction.state_dict().items():
            model.FeatureExtraction.state_dict()[name].copy_(checkpoint['state_dict']['FeatureExtraction.' + name])
        # Load parameters of FeatureRegression (Affine)
        for name, param in model.FeatureRegression.state_dict().items():
            model.FeatureRegression.state_dict()[name].copy_(checkpoint['state_dict']['FeatureRegression.' + name])
        print("Reloading from--[%s]" % load_dir)

    for epoch in range(1, args.num_epochs + 1):
        # Call train, test function
        train_loss = train(epoch, model, loss, optimizer, dataloader_train,
                           pair_generation_tnf, log_interval=100)
        # test_loss = test(model, loss, dataloader_test, pair_generation_tnf)

        if args.use_mse_loss:
            checkpoint_name = os.path.join(
                args.trained_models_dir,
                args.geometric_model + '_mse_loss_' + args.feature_extraction_cnn +
                '_' + args.training_dataset + '_epoch_' + str(epoch) + '.pth.tar')
        else:
            checkpoint_name = os.path.join(
                args.trained_models_dir,
                args.geometric_model + '_grid_loss_' + args.feature_extraction_cnn +
                '_' + args.training_dataset + '_epoch_' + str(epoch) + '.pth.tar')
        # Save checkpoint
        save_checkpoint({
            'epoch': epoch + 1,
            'args': args,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, checkpoint_name)

    print('Done!')
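
The two copy loops above load a submodule's parameters one tensor at a time. An equivalent, hedged alternative is to filter the checkpoint by prefix and call load_state_dict once per submodule; load_submodule below is a hypothetical helper, not part of the original code.

def load_submodule(submodule, checkpoint_state, prefix):
    # Keep only the entries that belong to this submodule and strip the prefix.
    filtered = {key[len(prefix):]: value
                for key, value in checkpoint_state.items()
                if key.startswith(prefix)}
    submodule.load_state_dict(filtered)

# Usage, mirroring the loops above:
# load_submodule(model.FeatureExtraction, checkpoint['state_dict'], 'FeatureExtraction.')
# load_submodule(model.FeatureRegression, checkpoint['state_dict'], 'FeatureRegression.')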
Example no. 4

def main(args):

    # checkpoint_path = "/home/zale/project/registration_cnn_ntg/trained_weight/voc2011_multi_gpu/checkpoint_voc2011_multi_gpu_paper_NTG_resnet101.pth.tar"
    # checkpoint_path = "/home/zale/project/registration_cnn_ntg/trained_weight/coco2017_multi_gpu/checkpoint_coco2017_multi_gpu_paper30_NTG_resnet101.pth.tar"
    #args.training_image_path = '/home/zale/datasets/vocdata/VOC_train_2011/VOCdevkit/VOC2011/JPEGImages'
    # args.training_image_path = '/media/disk2/zale/datasets/coco2017/train2017'

    checkpoint_path = "/home/zlk/project/registration_cnn_ntg/trained_weight/voc2011_multi_gpu/checkpoint_voc2011_multi_gpu_three_channel_paper_origin_NTG_resnet101.pth.tar"
    args.training_image_path = '/home/zlk/datasets/vocdata/VOC_train_2011/VOCdevkit/VOC2011/JPEGImages'

    random_seed = 10021
    init_seeds(random_seed + random.randint(0, 10000))
    mixed_precision = True

    utils.init_distributed_mode(args)
    print(args)

    #device,local_rank = torch_util.select_device(multi_process =True,apex=mixed_precision)

    device = torch.device(args.device)
    use_cuda = True
    # Data loading code
    print("Loading data")
    RandomTnsDataset = RandomTnsData(args.training_image_path,
                                     cache_images=False,
                                     paper_affine_generator=True,
                                     transform=NormalizeImageDict(["image"]))
    # train_dataloader = DataLoader(RandomTnsDataset, batch_size=args.batch_size, shuffle=True, num_workers=4,
    #                               pin_memory=True)

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            RandomTnsDataset)
        # test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(RandomTnsDataset)
        # test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    # train_batch_sampler = torch.utils.data.BatchSampler(
    #     train_sampler, args.batch_size, drop_last=True)

    data_loader = DataLoader(RandomTnsDataset,
                             sampler=train_sampler,
                             num_workers=4,
                             shuffle=(train_sampler is None),
                             pin_memory=False,
                             batch_size=args.batch_size)

    # data_loader_test = torch.utils.data.DataLoader(
    #     dataset_test, batch_size=1,
    #     sampler=test_sampler, num_workers=args.workers,
    #     collate_fn=utils.collate_fn)

    print("Creating model")
    model = CNNRegistration(use_cuda=use_cuda)

    model.to(device)

    # Optimizer and scheduler
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=args.lr)

    # The NTG loss decreases very slowly once the learning rate drops below
    # 1e-6, so use 1e-6 as the minimum.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.lr_max_iter, eta_min=1e-6)

    # if mixed_precision:
    #     model,optimizer = amp.initialize(model,optimizer,opt_level='O1',verbosity=0)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    minium_loss, saved_epoch = load_checkpoint(model_without_ddp, optimizer,
                                               lr_scheduler, checkpoint_path,
                                               args.rank)

    vis_env = "multi_gpu_rgb_train_paper_30"
    loss = NTGLoss()
    pair_generator = RandomTnsPair(use_cuda=use_cuda)
    gridGen = AffineGridGen()
    vis = VisdomHelper(env_name=vis_env)

    print('Starting training...')
    start_time = time.time()
    draw_test_loss = False
    log_interval = 20
    for epoch in range(saved_epoch, args.num_epochs):
        start_time = time.time()

        if args.distributed:
            train_sampler.set_epoch(epoch)

        train_loss = train(epoch,
                           model,
                           loss,
                           optimizer,
                           data_loader,
                           pair_generator,
                           gridGen,
                           vis,
                           use_cuda=use_cuda,
                           log_interval=log_interval,
                           lr_scheduler=lr_scheduler,
                           rank=args.rank)

        if draw_test_loss:
            #test_loss = test(model,loss,test_dataloader,pair_generator,gridGen,use_cuda=use_cuda)
            #vis.drawBothLoss(epoch,train_loss,test_loss,'loss_table')
            pass
        else:
            vis.drawLoss(epoch, train_loss)

        end_time = time.time()
        print("epoch:", str(end_time - start_time), '秒')

        is_best = train_loss < minium_loss
        minium_loss = min(train_loss, minium_loss)

        state_dict = model_without_ddp.state_dict()
        if is_main_process():
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'args': args,
                    # 'state_dict': model.state_dict(),
                    'state_dict': state_dict,
                    'minium_loss': minium_loss,
                    'model_loss': train_loss,
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                },
                is_best,
                checkpoint_path)
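
Note the train_sampler.set_epoch(epoch) call in the loop above: DistributedSampler derives its shuffling from a seed combined with the epoch number, so omitting the call makes every epoch replay the same permutation on every rank. A self-contained illustration (num_replicas and rank are passed explicitly here so no process group is required):

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

for epoch in range(2):
    sampler.set_epoch(epoch)           # re-seed the shuffle with the epoch number
    print(epoch, list(iter(sampler)))  # a different permutation each epoch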
Example no. 5
def start_train(training_path, test_image_path, load_from, out_path, vis_env,
                paper_affine_generator=False, random_seed=666,
                log_interval=100, multi_gpu=True, use_cuda=True):

    init_seeds(random_seed + random.randint(0, 10000))

    device, local_rank = torch_util.select_device(multi_process=multi_gpu,
                                                  apex=mixed_precision)

    # args.batch_size = args.batch_size * torch.cuda.device_count()
    args.batch_size = 16
    args.lr_scheduler = True
    draw_test_loss = False
    print(args.batch_size)


    print("创建模型中")
    model = CNNRegistration(use_cuda=use_cuda)

    model = model.to(device)

    # Optimizer and scheduler
    optimizer = optim.Adam(model.FeatureRegression.parameters(), lr=args.lr)

    if args.lr_scheduler:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                               T_max=args.lr_max_iter,
                                                               eta_min=1e-7)
    else:
        scheduler = False

    print("加载权重")
    minium_loss,saved_epoch= load_checkpoint(model,optimizer,load_from,0)

    # Mixed precision training https://github.com/NVIDIA/apex
    if mixed_precision:
        # note: the opt_level is the letter O followed by 1 ('O1'), not zero-one
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)

    if multi_gpu:
        model = nn.DataParallel(model)

    loss = NTGLoss()
    pair_generator = RandomTnsPair(use_cuda=use_cuda)
    gridGen = AffineGridGen()
    vis = VisdomHelper(env_name=vis_env)

    print("创建dataloader")
    RandomTnsDataset = RandomTnsData(training_path, cache_images=False,paper_affine_generator = paper_affine_generator,
                                     transform=NormalizeImageDict(["image"]))
    train_dataloader = DataLoader(RandomTnsDataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)

    if draw_test_loss:
        testDataset = RandomTnsData(test_image_path,
                                    cache_images=False,
                                    paper_affine_generator=paper_affine_generator,
                                    transform=NormalizeImageDict(["image"]))
        test_dataloader = DataLoader(testDataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=4,
                                     pin_memory=False)

    print('Starting training...')

    for epoch in range(saved_epoch, args.num_epochs):
        start_time = time.time()

        train_loss = train(epoch, model, loss, optimizer, train_dataloader,
                           pair_generator, gridGen, vis,
                           use_cuda=use_cuda, log_interval=log_interval,
                           scheduler=scheduler)

        if draw_test_loss:
            test_loss = test(model, loss, test_dataloader, pair_generator,
                             gridGen, use_cuda=use_cuda)
            vis.drawBothLoss(epoch, train_loss, test_loss, 'loss_table')
        else:
            vis.drawLoss(epoch, train_loss)

        end_time = time.time()
        print("epoch:", str(end_time - start_time),'秒')

        is_best = train_loss < minium_loss
        minium_loss = min(train_loss, minium_loss)

        state_dict = model.module.state_dict() if multi_gpu else model.state_dict()
        save_checkpoint({
            'epoch': epoch + 1,
            'args': args,
            # 'state_dict': model.state_dict(),
            'state_dict': state_dict,
            'minium_loss': minium_loss,
            'model_loss': train_loss,
            'optimizer': optimizer.state_dict(),
        }, is_best, out_path)
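
The amp.initialize call above comes from NVIDIA apex, which predates PyTorch's built-in mixed precision. For reference, a rough equivalent of the 'O1' behaviour with the native torch.cuda.amp API, sketched around a hypothetical training step (model, loss_fn, optimizer, batch and target are placeholders, not names from the example):

import torch

scaler = torch.cuda.amp.GradScaler()

def training_step(model, loss_fn, optimizer, batch, target):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():      # run the forward pass in mixed precision
        output = model(batch)
        loss_value = loss_fn(output, target)
    scaler.scale(loss_value).backward()  # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)               # unscales gradients, then optimizer.step()
    scaler.update()
    return loss_value.item()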
Example no. 6
def main():

    args, arg_groups = ArgumentParser(mode='train').parse()
    print(args)

    use_cuda = torch.cuda.is_available()
    use_me = args.use_me
    device = torch.device('cuda') if use_cuda else torch.device('cpu')
    # Seed
    # torch.manual_seed(args.seed)
    # if use_cuda:
    #     torch.cuda.manual_seed(args.seed)

    # CNN model and loss
    print('Creating CNN model...')
    if args.geometric_model == 'affine_simple':
        cnn_output_dim = 3
    elif args.geometric_model == 'affine_simple_4':
        cnn_output_dim = 4
    else:
        raise NotImplementedError('Specified geometric model is unsupported')

    model = CNNGeometric(use_cuda=use_cuda,
                         output_dim=cnn_output_dim,
                         **arg_groups['model'])

    if args.geometric_model == 'affine_simple':
        init_theta = torch.tensor([0.0, 1.0, 0.0], device=device)
    elif args.geometric_model == 'affine_simple_4':
        init_theta = torch.tensor([0.0, 1.0, 0.0, 0.0], device=device)

    try:
        model.FeatureRegression.linear.bias.data += init_theta
    except AttributeError:
        # some regression heads use a ResNet fc layer instead of a plain linear
        model.FeatureRegression.resnet.fc.bias.data += init_theta

    args.load_images = False
    if args.loss == 'mse':
        print('Using MSE loss...')
        loss = nn.MSELoss()
    elif args.loss == 'weighted_mse':
        print('Using weighted MSE loss...')
        loss = WeightedMSELoss(use_cuda=use_cuda)
    elif args.loss == 'reconstruction':
        print('Using reconstruction loss...')
        loss = ReconstructionLoss(
            int(np.rint(args.input_width * (1 - args.crop_factor) / 16) * 16),
            int(np.rint(args.input_height * (1 - args.crop_factor) / 16) * 16),
            args.input_height,
            use_cuda=use_cuda)
        args.load_images = True
    elif args.loss == 'combined':
        print('Using combined loss...')
        loss = CombinedLoss(args, use_cuda=use_cuda)
        if args.use_reconstruction_loss:
            args.load_images = True
    elif args.loss == 'grid':
        print('Using grid loss...')
        loss = SequentialGridLoss(use_cuda=use_cuda)
    else:
        raise NotImplementedError('Specified loss %s is not supported' %
                                  args.loss)

    # Initialize Dataset objects
    if use_me:
        dataset = MEDataset(geometric_model=args.geometric_model,
                            dataset_csv_path=args.dataset_csv_path,
                            dataset_csv_file='train.csv',
                            dataset_image_path=args.dataset_image_path,
                            input_height=args.input_height,
                            input_width=args.input_width,
                            crop=args.crop_factor,
                            use_conf=args.use_conf,
                            use_random_patch=args.use_random_patch,
                            normalize_inputs=args.normalize_inputs,
                            random_sample=args.random_sample,
                            load_images=args.load_images)

        dataset_val = MEDataset(geometric_model=args.geometric_model,
                                dataset_csv_path=args.dataset_csv_path,
                                dataset_csv_file='val.csv',
                                dataset_image_path=args.dataset_image_path,
                                input_height=args.input_height,
                                input_width=args.input_width,
                                crop=args.crop_factor,
                                use_conf=args.use_conf,
                                use_random_patch=args.use_random_patch,
                                normalize_inputs=args.normalize_inputs,
                                random_sample=args.random_sample,
                                load_images=args.load_images)

    else:

        dataset = SynthDataset(geometric_model=args.geometric_model,
                               dataset_csv_path=args.dataset_csv_path,
                               dataset_csv_file='train.csv',
                               dataset_image_path=args.dataset_image_path,
                               transform=NormalizeImageDict(['image']),
                               random_sample=args.random_sample)

        dataset_val = SynthDataset(geometric_model=args.geometric_model,
                                   dataset_csv_path=args.dataset_csv_path,
                                   dataset_csv_file='val.csv',
                                   dataset_image_path=args.dataset_image_path,
                                   transform=NormalizeImageDict(['image']),
                                   random_sample=args.random_sample)

    # Set Tnf pair generation func
    if use_me:
        pair_generation_tnf = BatchTensorToVars(use_cuda=use_cuda)
    elif args.geometric_model in ('affine_simple', 'affine_simple_4'):
        pair_generation_tnf = SynthPairTnf(geometric_model='affine',
                                           use_cuda=use_cuda)
    else:
        raise NotImplementedError('Specified geometric model is unsupported')

    # Initialize DataLoaders
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=4)

    dataloader_val = DataLoader(dataset_val,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=4)

    # Optimizer
    optimizer = optim.Adam(model.FeatureRegression.parameters(), lr=args.lr)

    # Train

    # Set up names for checkpoints
    ckpt = args.trained_model_fn + '_' + args.geometric_model + '_' + args.loss + '_loss_'
    checkpoint_path = os.path.join(args.trained_model_dir,
                                   args.trained_model_fn, ckpt + '.pth.tar')
    if not os.path.exists(args.trained_model_dir):
        os.mkdir(args.trained_model_dir)

    # Set up TensorBoard writer
    if not args.log_dir:
        tb_dir = os.path.join(args.trained_model_dir,
                              args.trained_model_fn + '_tb_logs')
    else:
        tb_dir = os.path.join(args.log_dir, args.trained_model_fn + '_tb_logs')

    logs_writer = SummaryWriter(tb_dir)
    # Add the model graph to TensorBoard; tracing it requires a dummy input
    if use_me:
        dummy_input = {
            'mv_L2R': torch.rand([args.batch_size, 2, 216, 384],
                                 device=device),
            'mv_R2L': torch.rand([args.batch_size, 2, 216, 384],
                                 device=device),
            'grid_L2R': torch.rand([args.batch_size, 2, 216, 384],
                                   device=device),
            'grid_R2L': torch.rand([args.batch_size, 2, 216, 384],
                                   device=device),
            'grid': torch.rand([args.batch_size, 2, 216, 384], device=device),
            'conf_L': torch.rand([args.batch_size, 1, 216, 384],
                                 device=device),
            'conf_R': torch.rand([args.batch_size, 1, 216, 384],
                                 device=device),
            'theta_GT': torch.rand([args.batch_size, 4], device=device),
        }
        if args.load_images:
            dummy_input['img_R_orig'] = torch.rand(
                [args.batch_size, 1, 216, 384], device=device)
            dummy_input['img_R'] = torch.rand([args.batch_size, 1, 216, 384],
                                              device=device)
    else:
        dummy_input = {
            'source_image':
            torch.rand([args.batch_size, 3, 240, 240], device=device),
            'target_image':
            torch.rand([args.batch_size, 3, 240, 240], device=device),
            'theta_GT':
            torch.rand([args.batch_size, 2, 3], device=device)
        }

    logs_writer.add_graph(model, dummy_input)

    # Start of training
    print('Starting training...')

    best_val_loss = float("inf")

    max_batch_iters = len(dataloader)
    print('Iterations for one epoch:', max_batch_iters)
    epoch_to_change_lr = int(args.lr_max_iter / max_batch_iters * 2 + 0.5)

    # Loading checkpoint
    model, optimizer, start_epoch, best_val_loss, last_epoch = load_checkpoint(
        checkpoint_path, model, optimizer, device)

    # Scheduler
    if args.lr_scheduler == 'cosine':
        is_cosine_scheduler = True
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=args.lr_max_iter,
            eta_min=1e-7,
            last_epoch=last_epoch)
    elif args.lr_scheduler == 'cosine_restarts':
        is_cosine_scheduler = True
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=args.lr_max_iter, T_mult=2, last_epoch=last_epoch)

    elif args.lr_scheduler == 'exp':
        is_cosine_scheduler = False
        if last_epoch > 0:
            # the scheduler steps once per epoch, so convert iterations to epochs
            last_epoch //= max_batch_iters
        scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=args.lr_decay, last_epoch=last_epoch)
    # elif args.lr_scheduler == 'step':
    # step_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 10, gamma=0.1)
    # scheduler = False
    else:
        is_cosine_scheduler = False
        scheduler = False

    # Replay the periodic lr_max decay for already-completed epochs so that a
    # resumed cosine schedule continues from the correct base learning rate
    for epoch in range(1, start_epoch):
        if args.lr_scheduler == 'cosine' and (epoch % epoch_to_change_lr == 0):
            scheduler.state_dict()['base_lrs'][0] *= args.lr_decay

    # anomaly detection helps track down NaNs but noticeably slows training
    torch.autograd.set_detect_anomaly(True)
    for epoch in range(start_epoch, args.num_epochs + 1):
        print('Current epoch: ', epoch)

        # we don't need the average epoch loss so we assign it to _
        _ = train(epoch,
                  model,
                  loss,
                  optimizer,
                  dataloader,
                  pair_generation_tnf,
                  log_interval=args.log_interval,
                  scheduler=scheduler,
                  is_cosine_scheduler=is_cosine_scheduler,
                  tb_writer=logs_writer)

        # Step non-cosine scheduler
        if scheduler and not is_cosine_scheduler:
            scheduler.step()

        val_loss = validate_model(model, loss, dataloader_val,
                                  pair_generation_tnf, epoch, logs_writer)

        # Change lr_max in cosine annealing
        if args.lr_scheduler == 'cosine' and (epoch % epoch_to_change_lr == 0):
            scheduler.state_dict()['base_lrs'][0] *= args.lr_decay

        if epoch == 1 or epoch % epoch_to_change_lr == epoch_to_change_lr // 2:
            compute_metric('absdiff', model, args.geometric_model, None, None,
                           dataset_val, dataloader_val, pair_generation_tnf,
                           args.batch_size, args)

        # remember best loss
        is_best = val_loss < best_val_loss
        best_val_loss = min(val_loss, best_val_loss)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': args,
                'state_dict': model.state_dict(),
                'best_val_loss': best_val_loss,
                'optimizer': optimizer.state_dict(),
            }, is_best, checkpoint_path)

    logs_writer.close()
    print('Done!')
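
This example resumes training via load_checkpoint(checkpoint_path, model, optimizer, device), whose last return value feeds the schedulers' last_epoch argument. Below is a minimal sketch of what such a helper could look like; the signature is taken from the call site, and any checkpoint field not saved by save_checkpoint above is an assumption.

import os

import torch


def load_checkpoint(path, model, optimizer, device):
    # Defaults for a fresh run: start at epoch 1, no best loss, scheduler from scratch.
    start_epoch, best_val_loss, last_epoch = 1, float('inf'), -1
    if os.path.isfile(path):
        checkpoint = torch.load(path, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        best_val_loss = checkpoint.get('best_val_loss', best_val_loss)
        # Assumed field: the batch-iteration counter used by the cosine schedulers.
        last_epoch = checkpoint.get('last_epoch', -1)
    return model, optimizer, start_epoch, best_val_loss, last_epoch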
Example no. 7
def main():

    args = parse_flags()

    use_cuda = torch.cuda.is_available()

    # Seed
    torch.manual_seed(args.seed)
    if use_cuda:
        torch.cuda.manual_seed(args.seed)

    # Download dataset if needed and set paths
    if args.training_dataset == 'pascal':

        if args.training_image_path == '':

            download_pascal('datasets/pascal-voc11/')
            args.training_image_path = 'datasets/pascal-voc11/'

        if args.training_tnf_csv == '' and args.geometric_model == 'affine':

            args.training_tnf_csv = 'training_data/pascal-synth-aff'

        elif args.training_tnf_csv == '' and args.geometric_model == 'tps':

            args.training_tnf_csv = 'training_data/pascal-synth-tps'

    # CNN model and loss
    if not args.pretrained:
        if args.light_model:
            print('Creating light CNN model...')
            model = LightCNN(use_cuda=use_cuda,
                             geometric_model=args.geometric_model)
        else:
            print('Creating CNN model...')
            model = CNNGeometric(
                use_cuda=use_cuda,
                geometric_model=args.geometric_model,
                feature_extraction_cnn=args.feature_extraction_cnn)
    else:
        model = load_torch_model(args, use_cuda)

    if args.loss == 'mse':
        print('Using MSE loss...')
        loss = MSELoss()

    elif args.loss == 'sum':
        print('Using the sum of MSE and grid loss...')
        loss = GridLossWithMSE(use_cuda=use_cuda,
                               geometric_model=args.geometric_model)

    else:
        print('Using grid loss...')
        loss = TransformedGridLoss(use_cuda=use_cuda,
                                   geometric_model=args.geometric_model)

    # Initialize csv paths
    train_csv_path_list = glob(
        os.path.join(args.training_tnf_csv, '*train.csv'))
    if len(train_csv_path_list) > 1:
        print(
            "!!!!WARNING!!!! multiple train csv files found, using first in glob order"
        )
    elif not len(train_csv_path_list):
        raise FileNotFoundError(
            "No training csv files were found in the specified path!")

    train_csv_path = train_csv_path_list[0]

    val_csv_path_list = glob(os.path.join(args.training_tnf_csv, '*val.csv'))
    if len(val_csv_path_list) > 1:
        print(
            "!!!!WARNING!!!! multiple val csv files found, using first in glob order"
        )
    elif not len(val_csv_path_list):
        raise FileNotFoundError(
            "No validation csv files were found in the specified path!")

    val_csv_path = val_csv_path_list[0]

    # Initialize Dataset objects
    if args.coupled_dataset:
        # Dataset  for train and val if dataset is already coupled
        dataset = CoupledDataset(geometric_model=args.geometric_model,
                                 csv_file=train_csv_path,
                                 training_image_path=args.training_image_path,
                                 transform=NormalizeImageDict(
                                     ['image_a', 'image_b']))

        dataset_val = CoupledDataset(
            geometric_model=args.geometric_model,
            csv_file=val_csv_path,
            training_image_path=args.training_image_path,
            transform=NormalizeImageDict(['image_a', 'image_b']))

        # Set Tnf pair generation func
        pair_generation_tnf = CoupledPairTnf(use_cuda=use_cuda)

    else:
        # Standard Dataset for train and val
        dataset = SynthDataset(geometric_model=args.geometric_model,
                               csv_file=train_csv_path,
                               training_image_path=args.training_image_path,
                               transform=NormalizeImageDict(['image']),
                               random_sample=args.random_sample)

        dataset_val = SynthDataset(
            geometric_model=args.geometric_model,
            csv_file=val_csv_path,
            training_image_path=args.training_image_path,
            transform=NormalizeImageDict(['image']),
            random_sample=args.random_sample)

        # Set Tnf pair generation func
        pair_generation_tnf = SynthPairTnf(
            geometric_model=args.geometric_model, use_cuda=use_cuda)

    # Initialize DataLoaders
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=4)

    dataloader_val = DataLoader(dataset_val,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=4)

    # Optimizer and eventual scheduler
    optimizer = Adam(model.FeatureRegression.parameters(), lr=args.lr)

    if args.lr_scheduler:

        if args.scheduler_type == 'cosine':
            print('Using cosine learning rate scheduler')
            scheduler = CosineAnnealingLR(optimizer,
                                          T_max=args.lr_max_iter,
                                          eta_min=args.lr_min)

        elif args.scheduler_type == 'decay':
            print('Using decay learning rate scheduler')
            scheduler = ReduceLROnPlateau(optimizer, 'min')

        else:
            print(
                'Using truncated cosine with decay learning rate scheduler...')
            scheduler = TruncateCosineScheduler(optimizer, len(dataloader),
                                                args.num_epochs - 1)
    else:
        scheduler = False

    # Train

    # Set up names for checkpoints
    if args.loss == 'mse':
        ckpt = args.trained_models_fn + '_' + args.geometric_model + '_mse_loss_' + args.feature_extraction_cnn
    elif args.loss == 'sum':
        ckpt = args.trained_models_fn + '_' + args.geometric_model + '_sum_loss_' + args.feature_extraction_cnn
    else:
        ckpt = args.trained_models_fn + '_' + args.geometric_model + '_grid_loss_' + args.feature_extraction_cnn
    checkpoint_path = os.path.join(args.trained_models_dir,
                                   args.trained_models_fn,
                                   ckpt + '.pth.tar')
    if not os.path.exists(args.trained_models_dir):
        os.mkdir(args.trained_models_dir)

    # Set up TensorBoard writer
    if not args.log_dir:
        tb_dir = os.path.join(args.trained_models_dir,
                              args.trained_models_fn + '_tb_logs')
    else:
        tb_dir = os.path.join(args.log_dir,
                              args.trained_models_fn + '_tb_logs')

    logs_writer = SummaryWriter(tb_dir)
    # Add the model graph to TensorBoard; tracing it requires a dummy input
    dummy_input = {
        'source_image': torch.rand([args.batch_size, 3, 240, 240]),
        'target_image': torch.rand([args.batch_size, 3, 240, 240]),
        'theta_GT': torch.rand([args.batch_size, 2, 3])
    }

    logs_writer.add_graph(model, dummy_input)

    #                START OF TRAINING                 #
    print('Starting training...')

    best_val_loss = float("inf")

    for epoch in range(1, args.num_epochs + 1):

        # we don't need the average epoch loss so we assign it to _
        _ = train(epoch,
                  model,
                  loss,
                  optimizer,
                  dataloader,
                  pair_generation_tnf,
                  log_interval=args.log_interval,
                  scheduler=scheduler,
                  tb_writer=logs_writer)

        val_loss = validate_model(model,
                                  loss,
                                  dataloader_val,
                                  pair_generation_tnf,
                                  epoch,
                                  logs_writer,
                                  coupled=args.coupled_dataset)

        # remember best loss
        is_best = val_loss < best_val_loss
        best_val_loss = min(val_loss, best_val_loss)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': args,
                'state_dict': model.state_dict(),
                'best_val_loss': best_val_loss,
                'optimizer': optimizer.state_dict(),
            }, is_best, checkpoint_path)

    logs_writer.close()
    print('Done!')
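
One caveat about the 'decay' branch above: ReduceLROnPlateau must be stepped with the monitored metric, unlike the other schedulers, so whichever code steps it has to pass val_loss rather than calling a bare step(). A self-contained illustration of that contract:

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = Adam(params, lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=2)

# The lr is halved once the metric has stopped improving for `patience` epochs.
for epoch, val_loss in enumerate([1.0, 0.9, 0.9, 0.9, 0.9, 0.9]):
    scheduler.step(val_loss)  # pass the metric, not a bare step()
    print(epoch, optimizer.param_groups[0]['lr'])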
Example no. 8
for epoch in range(1, args.num_epochs + 1):
    if not args.update_bn_buffers:
        model.eval()
    else:
        model.train()
    train_loss[epoch - 1] = process_epoch('train', epoch, model, loss_fun,
                                          optimizer, dataloader, batch_tnf,
                                          log_interval=1)
    model.eval()
    stats = compute_metric(metric, model, dataset_eval, dataloader_eval,
                           batch_tnf, 8, two_stage, do_aff, do_tps, args)
    eval_value = np.mean(stats['aff_tps'][metric][eval_idx])
    print(eval_value)

    # negate PCK (higher is better) so the "lower is better" best-checkpoint
    # logic below still applies
    if args.eval_metric == 'pck':
        test_loss[epoch - 1] = -eval_value
    else:
        test_loss[epoch - 1] = eval_value

    # remember best loss
    is_best = test_loss[epoch - 1] < best_test_loss
    best_test_loss = min(test_loss[epoch - 1], best_test_loss)
    save_checkpoint({
        'epoch': epoch + 1,
        'args': args,
        'state_dict': model.state_dict(),
        'best_test_loss': best_test_loss,
        'optimizer': optimizer.state_dict(),
        'train_loss': train_loss,
        'test_loss': test_loss,
    }, is_best, checkpoint_name)

print('Done!')