Code example #1
File: train_strong.py  Project: nyummvc/Arbicon-Net
# NOTE: the opening of this call is truncated in the source; reconstructed
# from context (building the test split of the SynthDataset used in example #2):
dataset_test = SynthDataset(geometric_model=args.geometric_model,
                            dataset_csv_file='test.csv',
                            **arg_groups['dataset'])

dataloader_test = DataLoader(dataset_test,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=4)

cnn_image_size = (args.image_size, args.image_size)

pair_generation_tnf = SynthPairTnf(geometric_model=args.geometric_model,
                                   output_size=cnn_image_size,
                                   use_cuda=use_cuda)
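# SynthPairTnf turns a batch of single images into (source_image, target_image,
# theta_GT) training triplets by warping each image with its sampled transform,
# so the ground-truth parameters are known by construction.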

# Optimizer
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                       lr=args.lr)

# Define checkpoint name
checkpoint_suffix = '_strong'
checkpoint_suffix += '_' + str(args.num_epochs)
checkpoint_suffix += '_' + args.training_dataset
checkpoint_suffix += '_' + args.geometric_model
checkpoint_suffix += '_' + args.feature_extraction_cnn
if args.use_mse_loss:
    checkpoint_suffix += '_mse_loss'
else:
    checkpoint_suffix += '_grid_loss'

checkpoint_name = os.path.join(
    args.result_model_dir,
    # NOTE: truncated in the source; the filename presumably combines the
    # trained-model name with the suffix assembled above:
    args.trained_model_fn + checkpoint_suffix + '.pth.tar')
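
For orientation, here is a minimal sketch of how the test loader and pair
generator above are typically driven at evaluation time. The batch-dict keys
and the loss call are assumptions based on the interfaces visible in example
#2, not code from this file:

model.eval()
with torch.no_grad():
    for batch in dataloader_test:
        tnf_batch = pair_generation_tnf(batch)   # synthesize (source, target, theta_GT)
        theta = model(tnf_batch)                 # predicted transform parameters
        batch_loss = loss(theta, tnf_batch['theta_GT'])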
Code example #2
def main():
    warnings.simplefilter(action='ignore', category=FutureWarning)

    args, arg_groups = ArgumentParser(mode='train').parse()
    print(args)

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda') if use_cuda else torch.device('cpu')
    # Seed
    torch.manual_seed(args.seed)
    if use_cuda:
        torch.cuda.manual_seed(args.seed)

    # Download dataset if needed and set paths
    if args.training_dataset == 'pascal':

        if args.dataset_image_path == '' and not os.path.exists('datasets/pascal-voc11/TrainVal'):
            download_pascal('datasets/pascal-voc11/')

        if args.dataset_image_path == '':
            args.dataset_image_path = 'datasets/pascal-voc11/'

        args.dataset_csv_path = 'training_data/pascal-random'

    # RGB 512 (augmented)
    if args.training_dataset == 'rgb512_aug':

        if args.dataset_image_path == '':
            args.dataset_image_path = 'datasets/rgb512_augmented/'

        args.dataset_csv_path = 'training_data/rgb512_augmented-random'

    # RGB 240 (augmented)
    if args.training_dataset == 'rgb240_aug':

        if args.dataset_image_path == '':
            args.dataset_image_path = 'datasets/rgb240_augmented/'

        args.dataset_csv_path = 'training_data/rgb240_augmented-random'

    # Red 240 (augmented)
    if args.training_dataset == 'red240_aug':

        if args.dataset_image_path == '':
            args.dataset_image_path = 'datasets/red240_augmented/'

        args.dataset_csv_path = 'training_data/red240_augmented-random'

    # Small set
    if args.training_dataset == 'smallset':
        if args.dataset_image_path == '':
            args.dataset_image_path = 'datasets/smallset/'

        args.dataset_csv_path = 'training_data/smallset-random'

    # CNN model and loss
    print('Creating CNN model...')
    if args.geometric_model == 'affine':
        cnn_output_dim = 6   # 2x3 affine matrix
    elif args.geometric_model == 'hom' and args.four_point_hom:
        cnn_output_dim = 8   # (x, y) coordinates of four corner points
    elif args.geometric_model == 'hom' and not args.four_point_hom:
        cnn_output_dim = 9   # full 3x3 homography matrix
    elif args.geometric_model == 'tps':
        cnn_output_dim = 18  # (x, y) offsets of a 3x3 grid of TPS control points
    else:
        # without this guard, cnn_output_dim would be unbound below
        raise ValueError('Unknown geometric model: {}'.format(args.geometric_model))

    model = CNNGeometric(use_cuda=use_cuda,
                         output_dim=cnn_output_dim,
                         **arg_groups['model'])

    # Bias the regressor toward the identity transform: the flattened 3x3
    # identity matrix for the full homography, or the canonical corner
    # coordinates of the [-1, 1] square for the four-point parametrization.
    if args.geometric_model == 'hom' and not args.four_point_hom:
        init_theta = torch.tensor([1, 0, 0, 0, 1, 0, 0, 0, 1], device=device)
        model.FeatureRegression.linear.bias.data += init_theta

    if args.geometric_model == 'hom' and args.four_point_hom:
        init_theta = torch.tensor([-1, -1, 1, 1, -1, 1, -1, 1], device=device)
        model.FeatureRegression.linear.bias.data += init_theta
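
    # The grid loss below (TransformedGridLoss) penalizes the displacement
    # between a grid of points warped by the predicted and by the ground-truth
    # transforms, rather than comparing raw parameter vectors as MSE does.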

    if args.use_mse_loss:
        print('Using MSE loss...')
        loss = nn.MSELoss()
    else:
        print('Using grid loss...')
        loss = TransformedGridLoss(use_cuda=use_cuda,
                                   geometric_model=args.geometric_model)

    # Initialize Dataset objects
    dataset = SynthDataset(geometric_model=args.geometric_model,
                           dataset_csv_path=args.dataset_csv_path,
                           dataset_csv_file='train.csv',
                           dataset_image_path=args.dataset_image_path,
                           transform=NormalizeImageDict(['image']),
                           random_sample=args.random_sample)

    dataset_val = SynthDataset(geometric_model=args.geometric_model,
                               dataset_csv_path=args.dataset_csv_path,
                               dataset_csv_file='val.csv',
                               dataset_image_path=args.dataset_image_path,
                               transform=NormalizeImageDict(['image']),
                               random_sample=args.random_sample)

    # Set Tnf pair generation func
    pair_generation_tnf = SynthPairTnf(geometric_model=args.geometric_model,
                                       use_cuda=use_cuda)

    # Initialize DataLoaders
    dataloader = DataLoader(dataset, batch_size=args.batch_size,
                            shuffle=True, num_workers=4)

    dataloader_val = DataLoader(dataset_val, batch_size=args.batch_size,
                                shuffle=True, num_workers=4)

    # Optimizer and eventual scheduler
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    if args.lr_scheduler:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                               T_max=args.lr_max_iter,
                                                               eta_min=5e-7)
    else:
        scheduler = False

    # Train

    # Set up names for checkpoints
    if args.use_mse_loss:
        ckpt = (args.training_dataset + '_' + args.trained_model_fn + '_'
                + args.geometric_model + '_mse_loss_' + args.feature_extraction_cnn)
        checkpoint_path = os.path.join(args.trained_model_dir,
                                       args.trained_model_fn,
                                       ckpt + '.pth.tar')
    else:
        ckpt = (args.trained_model_fn + '_' + args.geometric_model + '_'
                + args.feature_regression + '_' + args.feature_extraction_cnn + '_'
                + args.training_dataset + '_grid_loss')
        checkpoint_path = os.path.join(args.trained_model_dir,
                                       args.trained_model_fn,
                                       ckpt + '.pth.tar')
    # checkpoint_path (and the .csv written below) live in a per-model
    # subdirectory, so create the whole tree, not just the top-level directory
    os.makedirs(os.path.join(args.trained_model_dir, args.trained_model_fn),
                exist_ok=True)

    # Set up TensorBoard writer
    if not args.log_dir:
        tb_dir = os.path.join(args.trained_model_dir, args.trained_model_fn + '_tb_logs')
    else:
        tb_dir = os.path.join(args.log_dir, args.trained_model_fn + '_tb_logs')

    logs_writer = SummaryWriter(tb_dir)
    # add graph; to do so we have to generate a dummy input to pass along with it
    # (theta_GT shaped for the affine case; batch size must match the images)
    dummy_input = {'source_image': torch.rand([args.batch_size, 3, 240, 240], device=device),
                   'target_image': torch.rand([args.batch_size, 3, 240, 240], device=device),
                   'theta_GT': torch.rand([args.batch_size, 2, 3], device=device)}

    logs_writer.add_graph(model, dummy_input)
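    # add_graph traces the model by calling model(dummy_input); this assumes
    # CNNGeometric.forward accepts the same batch dict that
    # pair_generation_tnf produces during training.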

    # Start of training
    print('Starting training...')

    best_val_loss = float("inf")
    df = pd.DataFrame()
    for epoch in range(1, args.num_epochs+1):

        train_loss = train(epoch, model, loss, optimizer,
                           dataloader, pair_generation_tnf,
                           log_interval=args.log_interval,
                           scheduler=scheduler,
                           tb_writer=logs_writer)

        val_loss = validate_model(model, loss,
                                  dataloader_val, pair_generation_tnf,
                                  epoch, logs_writer)
        
        # Log losses to .csv so we can re-use them later for graphs
        # (DataFrame.append was removed in pandas 2.0, so use pd.concat)
        df = pd.concat([df,
                        pd.DataFrame([{'epoch': epoch,
                                       'train_loss': train_loss,
                                       'val_loss': val_loss}])],
                       ignore_index=True)
        # remember best loss
        is_best = val_loss < best_val_loss
        best_val_loss = min(val_loss, best_val_loss)
        save_checkpoint({
                         'epoch': epoch + 1,
                         'args': args,
                         'state_dict': model.state_dict(),
                         'best_val_loss': best_val_loss,
                         'optimizer': optimizer.state_dict(),
                         },
                        is_best, checkpoint_path)
        
    name = (args.geometric_model + '_' + args.feature_extraction_cnn + '_'
            + args.feature_regression + '_dropout_' + str(args.fr_dropout) + '.csv')
    csv_path = os.path.join(args.trained_model_dir,
                            args.trained_model_fn,
                            name)
    df.to_csv(csv_path, index=False)
    logs_writer.close()
    print('Done!')
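
save_checkpoint is imported from the project's utilities and does not appear on
this page. For reference, here is a minimal sketch of what such a helper
typically looks like in this code family; it is an assumption, not the
project's actual implementation:

import os
import shutil

import torch


def save_checkpoint(state, is_best, file):
    # Hypothetical helper: persist the full training state, and copy it to
    # 'best_<name>' whenever the validation loss improves.
    model_dir = os.path.dirname(file)
    model_fn = os.path.basename(file)
    if model_dir != '' and not os.path.exists(model_dir):
        os.makedirs(model_dir)
    torch.save(state, file)
    if is_best:
        shutil.copyfile(file, os.path.join(model_dir, 'best_' + model_fn))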