def _select_eval_indices(eval_data, categories):
    """Return the indices of evaluation pairs whose category is in *categories*.

    Args:
        eval_data: evaluation dataset supporting ``len()`` and exposing an
            indexable ``category`` attribute.
        categories: numpy array of category ids to keep.
    """
    eval_flag = np.zeros(len(eval_data))
    for i in range(len(eval_data)):
        eval_flag[i] = sum(categories == eval_data.category[i])
    return np.flatnonzero(eval_flag)


def _evaluate(model, eval_data, eval_loader, batch_tnf, eval_idx):
    """Put *model* in eval mode and return its mean score over *eval_idx*.

    The score is the ``['aff_tps'][args.eval_metric]`` entry returned by
    ``compute_metric``; ``args`` is the module-level argument namespace.
    """
    model.eval()
    eval_stats = compute_metric(args.eval_metric, model, eval_data,
                                eval_loader, batch_tnf, args)
    return np.mean(eval_stats['aff_tps'][args.eval_metric][eval_idx])


def main():
    """Train the model epoch by epoch, evaluating after every epoch.

    Uses the module-level ``args``, ``arg_groups`` and ``use_cuda``. The
    score of the initial (untrained) model is the baseline. After each
    epoch the mean score over categories 1..20 is compared with the best so
    far. Training stops early once the score drops more than 0.05 below the
    best; otherwise ``save_model`` is called with ``is_best`` marking a new
    best.
    """
    # Initialize model
    model = init_model(args, arg_groups, use_cuda)
    # Initialize dataloaders
    train_data, train_loader = init_train_data(args)
    eval_data, eval_loader = init_eval_data(args)
    # Initialize optimizer
    model_opt = init_model_optim(args, model)

    batch_tnf = BatchTensorToVars(use_cuda=use_cuda)

    # Evaluate the initial condition on categories 1..20.
    eval_idx = _select_eval_indices(eval_data, np.array(range(20)) + 1)
    best_eval_pck = _evaluate(model, eval_data, eval_loader, batch_tnf,
                              eval_idx)
    # Epoch 0 denotes the initial model, so a run in which no epoch beats
    # the baseline reports "best epoch: 0" rather than a misleading 1.
    best_epoch = 0

    # Start training
    for epoch in range(1, args.num_epochs + 1):
        # NOTE(review): training runs with the model in eval mode (perhaps to
        # keep BatchNorm statistics frozen). If that is not intended this
        # should be model.train() -- confirm against process_epoch.
        model.eval()
        process_epoch(epoch, model, model_opt, train_loader, batch_tnf)

        eval_pck = _evaluate(model, eval_data, eval_loader, batch_tnf,
                             eval_idx)

        is_best = eval_pck > best_eval_pck
        if is_best:
            best_eval_pck = eval_pck
            best_epoch = epoch

        print('eval: {:.3f}'.format(eval_pck),
              'best eval: {:.3f}'.format(best_eval_pck),
              'best epoch: {}'.format(best_epoch))

        # Early stopping: give up once the score falls more than 0.05 below
        # the best seen so far (the current epoch is not saved).
        if eval_pck < best_eval_pck - 0.05:
            break

        save_model(args, model, is_best)
# Example no. 2
dataset = Dataset(csv_file=os.path.join(args.csv_path, csv_file),
                  dataset_path=args.eval_dataset_path,
                  transform=NormalizeImageDict(['source_image', 'target_image']),
                  output_size=cnn_image_size)

# Batch several pairs on GPU; evaluate one pair at a time on CPU.
batch_size = 8 if use_cuda else 1

dataloader = DataLoader(dataset, batch_size=batch_size,
                        shuffle=False, num_workers=4,
                        collate_fn=collate_fn)

batch_tnf = BatchTensorToVars(use_cuda=use_cuda)

# Pick the evaluation metric for the benchmark. Fail fast on an unknown
# dataset instead of hitting a NameError on `metric` further down.
if args.eval_dataset in ('pf', 'pf_pascal', 'pf_willow', 'tss-pck'):
    metric = 'pck'
elif args.eval_dataset == 'caltech':
    metric = 'area'
elif args.eval_dataset == 'pascal-parts':
    metric = 'pascal_parts'
elif args.eval_dataset == 'tss':
    metric = 'flow'
else:
    raise NotImplementedError(
        'Unsupported eval dataset: {}'.format(args.eval_dataset))

model.eval()

stats = compute_metric(metric, model, dataset, dataloader, batch_tnf,
                       batch_size, two_stage, do_aff, do_tps, args)
# Example no. 3
def main():
    """Train CNNGeometric for the 'affine_simple' / 'affine_simple_4' models.

    Builds the model, loss, datasets and optimizer from the command-line
    arguments and resumes via ``load_checkpoint(checkpoint_path, ...)``. It
    then trains from ``start_epoch`` to ``args.num_epochs``. After every
    epoch the model is validated and a checkpoint is written with
    ``save_checkpoint``; ``is_best`` marks the lowest validation loss so far.
    TensorBoard logs go to ``<log_dir or trained_model_dir>/
    <trained_model_fn>_tb_logs``.

    Raises:
        NotImplementedError: for an unsupported geometric model or loss.
        ValueError: if the training dataloader yields no batches.
    """

    args, arg_groups = ArgumentParser(mode='train').parse()
    print(args)

    use_cuda = torch.cuda.is_available()
    use_me = args.use_me
    device = torch.device('cuda') if use_cuda else torch.device('cpu')
    # NOTE: seeding (torch.manual_seed / torch.cuda.manual_seed with
    # args.seed) is currently disabled, so runs are not reproducible.

    # CNN model and loss
    print('Creating CNN model...')
    # Initial regression-bias offset per supported geometric model; its
    # length is also the CNN output dimension.
    init_theta_by_model = {
        'affine_simple': [0.0, 1.0, 0.0],
        'affine_simple_4': [0.0, 1.0, 0.0, 0.0],
    }
    if args.geometric_model not in init_theta_by_model:
        raise NotImplementedError('Specified geometric model is unsupported')
    init_theta_values = init_theta_by_model[args.geometric_model]
    cnn_output_dim = len(init_theta_values)

    model = CNNGeometric(use_cuda=use_cuda,
                         output_dim=cnn_output_dim,
                         **arg_groups['model'])

    # Offset the final layer's bias by init_theta. The regressor exposes
    # either `linear` or `resnet.fc`; only a missing attribute may trigger
    # the fallback (a bare except would also hide e.g. device mismatches).
    init_theta = torch.tensor(init_theta_values, device=device)
    try:
        model.FeatureRegression.linear.bias.data += init_theta
    except AttributeError:
        model.FeatureRegression.resnet.fc.bias.data += init_theta

    args.load_images = False
    if args.loss == 'mse':
        print('Using MSE loss...')
        loss = nn.MSELoss()
    elif args.loss == 'weighted_mse':
        print('Using weighted MSE loss...')
        loss = WeightedMSELoss(use_cuda=use_cuda)
    elif args.loss == 'reconstruction':
        print('Using reconstruction loss...')
        loss = ReconstructionLoss(
            int(np.rint(args.input_width * (1 - args.crop_factor) / 16) * 16),
            int(np.rint(args.input_height * (1 - args.crop_factor) / 16) * 16),
            args.input_height,
            use_cuda=use_cuda)
        args.load_images = True
    elif args.loss == 'combined':
        print('Using combined loss...')
        loss = CombinedLoss(args, use_cuda=use_cuda)
        if args.use_reconstruction_loss:
            args.load_images = True
    elif args.loss == 'grid':
        print('Using grid loss...')
        loss = SequentialGridLoss(use_cuda=use_cuda)
    else:
        raise NotImplementedError('Specifyed loss %s is not supported' %
                                  args.loss)

    def make_dataset(csv_file):
        """Build the training or validation dataset backed by *csv_file*."""
        if use_me:
            return MEDataset(geometric_model=args.geometric_model,
                             dataset_csv_path=args.dataset_csv_path,
                             dataset_csv_file=csv_file,
                             dataset_image_path=args.dataset_image_path,
                             input_height=args.input_height,
                             input_width=args.input_width,
                             crop=args.crop_factor,
                             use_conf=args.use_conf,
                             use_random_patch=args.use_random_patch,
                             normalize_inputs=args.normalize_inputs,
                             random_sample=args.random_sample,
                             load_images=args.load_images)
        return SynthDataset(geometric_model=args.geometric_model,
                            dataset_csv_path=args.dataset_csv_path,
                            dataset_csv_file=csv_file,
                            dataset_image_path=args.dataset_image_path,
                            transform=NormalizeImageDict(['image']),
                            random_sample=args.random_sample)

    # Initialize Dataset objects
    dataset = make_dataset('train.csv')
    dataset_val = make_dataset('val.csv')

    # Set Tnf pair generation func
    if use_me:
        pair_generation_tnf = BatchTensorToVars(use_cuda=use_cuda)
    else:
        # Both supported geometric models are generated as plain affine pairs.
        pair_generation_tnf = SynthPairTnf(geometric_model='affine',
                                           use_cuda=use_cuda)

    # Initialize DataLoaders
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=4)

    dataloader_val = DataLoader(dataset_val,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=4)

    # Optimizer: only the regression head is optimized.
    optimizer = optim.Adam(model.FeatureRegression.parameters(), lr=args.lr)

    # Set up names for checkpoints
    ckpt = (args.trained_model_fn + '_' + args.geometric_model + '_' +
            args.loss + '_loss_')
    checkpoint_path = os.path.join(args.trained_model_dir,
                                   args.trained_model_fn, ckpt + '.pth.tar')
    # The checkpoint lives in a per-run subdirectory; create the whole chain
    # (os.mkdir created only trained_model_dir and failed on missing parents).
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    # Set up TensorBoard writer
    log_root = args.log_dir if args.log_dir else args.trained_model_dir
    tb_dir = os.path.join(log_root, args.trained_model_fn + '_tb_logs')
    logs_writer = SummaryWriter(tb_dir)

    def random_batch(*shape):
        """Uniform random tensor of shape (args.batch_size, *shape) on device."""
        return torch.rand([args.batch_size, *shape], device=device)

    # Dummy input used only to trace the model graph for TensorBoard.
    if use_me:
        dummy_input = {
            'mv_L2R': random_batch(2, 216, 384),
            'mv_R2L': random_batch(2, 216, 384),
            'grid_L2R': random_batch(2, 216, 384),
            'grid_R2L': random_batch(2, 216, 384),
            'grid': random_batch(2, 216, 384),
            'conf_L': random_batch(1, 216, 384),
            'conf_R': random_batch(1, 216, 384),
            'theta_GT': random_batch(4),
        }
        if args.load_images:
            dummy_input['img_R_orig'] = random_batch(1, 216, 384)
            dummy_input['img_R'] = random_batch(1, 216, 384)
    else:
        dummy_input = {
            'source_image': random_batch(3, 240, 240),
            'target_image': random_batch(3, 240, 240),
            'theta_GT': random_batch(2, 3),
        }

    logs_writer.add_graph(model, dummy_input)

    # Start of training
    print('Starting training...')

    max_batch_iters = len(dataloader)
    print('Iterations for one epoch:', max_batch_iters)
    if max_batch_iters == 0:
        raise ValueError('Training dataloader is empty; check train.csv')
    # Epoch period of the cosine base-LR decay. Clamped to 1 so the
    # `epoch % epoch_to_change_lr` tests below can never divide by zero.
    epoch_to_change_lr = max(
        1, int(args.lr_max_iter / max_batch_iters * 2 + 0.5))

    # Loading checkpoint
    model, optimizer, start_epoch, best_val_loss, last_epoch = load_checkpoint(
        checkpoint_path, model, optimizer, device)

    # Scheduler
    if args.lr_scheduler == 'cosine':
        is_cosine_scheduler = True
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=args.lr_max_iter,
            eta_min=1e-7,
            last_epoch=last_epoch)
    elif args.lr_scheduler == 'cosine_restarts':
        is_cosine_scheduler = True
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=args.lr_max_iter, T_mult=2, last_epoch=last_epoch)
    elif args.lr_scheduler == 'exp':
        is_cosine_scheduler = False
        # The checkpointed last_epoch is presumably a batch-iteration count
        # (it is rescaled by max_batch_iters); ExponentialLR is stepped once
        # per epoch. Integer division keeps last_epoch an int.
        if last_epoch > 0:
            last_epoch //= max_batch_iters
        scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=args.lr_decay, last_epoch=last_epoch)
    else:
        is_cosine_scheduler = False
        scheduler = False

    def decay_cosine_base_lr(epoch):
        """Every epoch_to_change_lr epochs, shrink the cosine peak LR by lr_decay."""
        if args.lr_scheduler == 'cosine' and epoch % epoch_to_change_lr == 0:
            scheduler.base_lrs[0] *= args.lr_decay

    # Re-apply the base-LR decays of the epochs completed before resuming.
    for epoch in range(1, start_epoch):
        decay_cosine_base_lr(epoch)

    # NOTE(review): anomaly detection slows every backward pass noticeably;
    # consider enabling it only while debugging.
    torch.autograd.set_detect_anomaly(True)
    for epoch in range(start_epoch, args.num_epochs + 1):
        print('Current epoch: ', epoch)

        # The average epoch loss returned by train() is not needed.
        train(epoch,
              model,
              loss,
              optimizer,
              dataloader,
              pair_generation_tnf,
              log_interval=args.log_interval,
              scheduler=scheduler,
              is_cosine_scheduler=is_cosine_scheduler,
              tb_writer=logs_writer)

        # Non-cosine schedulers are stepped once per epoch here (cosine ones
        # are handed to train(), presumably to be stepped per batch).
        if scheduler and not is_cosine_scheduler:
            scheduler.step()

        val_loss = validate_model(model, loss, dataloader_val,
                                  pair_generation_tnf, epoch, logs_writer)

        # Change lr_max in cosine annealing
        decay_cosine_base_lr(epoch)

        # Compute the 'absdiff' metric on the first epoch and then halfway
        # through every LR-decay period.
        if (epoch % epoch_to_change_lr == epoch_to_change_lr // 2
                or epoch == 1):
            compute_metric('absdiff', model, args.geometric_model, None, None,
                           dataset_val, dataloader_val, pair_generation_tnf,
                           args.batch_size, args)

        # remember best loss
        is_best = val_loss < best_val_loss
        best_val_loss = min(val_loss, best_val_loss)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': args,
                'state_dict': model.state_dict(),
                'best_val_loss': best_val_loss,
                'optimizer': optimizer.state_dict(),
            }, is_best, checkpoint_path)

    logs_writer.close()
    print('Done!')
# Example no. 4
                  transform=NormalizeImageDict(
                      ['source_image', 'target_image']),
                  output_size=cnn_image_size)

# Batch several pairs on GPU; evaluate one pair at a time on CPU.
batch_size = 8 if use_cuda else 1

dataloader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=collate_fn)

batch_tnf = BatchTensorToVars(use_cuda=use_cuda)

# Pick the evaluation metric for the benchmark. Fail fast on an unknown
# dataset instead of hitting a NameError on `metric` below.
if args.eval_dataset in ('pf', 'pf-pascal'):
    metric = 'pck'
elif args.eval_dataset == 'caltech':
    metric = 'area'
elif args.eval_dataset == 'pascal-parts':
    metric = 'pascal_parts'
elif args.eval_dataset == 'tss':
    metric = 'flow'
else:
    raise NotImplementedError(
        'Unsupported eval dataset: {}'.format(args.eval_dataset))

model.eval()

stats = compute_metric(metric, model, dataset, dataloader, batch_tnf,
                       batch_size, False, False, True, args)
# Example no. 5
# two_stage is set when a single --model is given, or when both the affine
# and the TPS stages are enabled.
two_stage = (args.model != '') or (do_aff and do_tps)

# Categories to evaluate: 1..20 by default, otherwise the user's selection.
eval_categories = (np.array(range(20)) + 1 if args.categories == 0
                   else np.array(args.categories))

# Indices of the evaluation pairs whose category is one of eval_categories.
eval_flag = np.array([sum(eval_categories == dataset_eval.category[k])
                      for k in range(len(dataset_eval))], dtype=float)
eval_idx = np.flatnonzero(eval_flag)

model.eval()

# Baseline score before training, averaged over the selected pairs.
stats = compute_metric(metric, model, dataset_eval, dataloader_eval,
                       batch_tnf, 8, two_stage, do_aff, do_tps, args)
eval_value = np.mean(stats['aff_tps'][metric][eval_idx])

print(eval_value)

# train
best_test_loss = float("inf")

train_loss, test_loss = np.zeros(args.num_epochs), np.zeros(args.num_epochs)

# Optimize only the parameters that are not frozen.
optimizer = optim.Adam((p for p in model.parameters() if p.requires_grad),
                       lr=args.lr, weight_decay=args.weight_decay)

print('Starting training...')

for epoch in range(1, args.num_epochs+1):
# Example no. 6
# TPS refinement is enabled whenever a TPS checkpoint path is supplied.
do_tps = not (args.model_tps == "")
# two_stage is set when a single --model is given, or when both the affine
# and the TPS stages are enabled.
two_stage = (args.model != '') or (do_aff and do_tps)

# Categories to evaluate: 1..20 by default, otherwise the user's selection.
eval_categories = (np.array(range(20)) + 1 if args.categories == 0
                   else np.array(args.categories))

# Indices of the evaluation pairs whose category is one of eval_categories.
eval_flag = np.array([sum(eval_categories == dataset_eval.category[k])
                      for k in range(len(dataset_eval))], dtype=float)
eval_idx = np.flatnonzero(eval_flag)

model.eval()

# Baseline score before training, averaged over the selected pairs.
stats = compute_metric(metric, model, dataset_eval, dataloader_eval,
                       batch_tnf, 2, two_stage, do_aff, do_tps, args)
eval_value = np.mean(stats['aff_tps'][metric][eval_idx])

print(eval_value)

# train
best_test_loss = float("inf")

train_loss, test_loss = np.zeros(args.num_epochs), np.zeros(args.num_epochs)

# Optimize only the parameters that are not frozen.
trainable_params = (p for p in model.parameters() if p.requires_grad)
optimizer = optim.Adam(trainable_params,
                       lr=args.lr,
                       weight_decay=args.weight_decay)
print('Starting training...')
def main():
    """Evaluate one (or, with --model_2, two chained) CNNGeometric models.

    Uses a default dataset path for the chosen benchmark when
    ``--eval_dataset_path`` is empty and downloads the data if it is missing.
    It then rebuilds each model from its checkpoint and runs
    ``compute_metric`` with the benchmark's metric.

    Returns:
        The statistics returned by ``compute_metric``.

    Raises:
        NotImplementedError: if ``args.eval_dataset`` is not one of 'pf',
            'pf-pascal', 'caltech' or 'tss'.
    """
    # Argument parsing
    args, arg_groups = ArgumentParser(mode='eval').parse()
    print(args)

    # Per-benchmark settings:
    # (default path, downloader, Dataset class, pairs csv, metric).
    benchmarks = {
        'pf': ('datasets/proposal-flow-willow/', download_PF_willow,
               PFDataset, 'test_pairs_pf.csv', 'pck'),
        'pf-pascal': ('datasets/proposal-flow-pascal/', download_PF_pascal,
                      PFPascalDataset, 'all_pairs_pf_pascal.csv', 'pck'),
        'caltech': ('datasets/caltech-101/', download_caltech,
                    CaltechDataset, 'test_pairs_caltech_with_category.csv',
                    'area'),
        'tss': ('datasets/tss/', download_TSS,
                TSSDataset, 'test_pairs_tss.csv', 'flow'),
    }
    # Fail before loading any model; previously an unknown dataset surfaced
    # later as a NameError on Dataset/csv_file/metric.
    if args.eval_dataset not in benchmarks:
        raise NotImplementedError(
            'Dataset {} is unsupported'.format(args.eval_dataset))
    default_path, download_fn, dataset_cls, csv_file, metric = \
        benchmarks[args.eval_dataset]

    # check provided models and deduce if single/two-stage model should be used
    two_stage = args.model_2 != ''

    if args.eval_dataset_path == '':
        args.eval_dataset_path = default_path

    use_cuda = torch.cuda.is_available()

    # Download dataset if needed
    if not exists(args.eval_dataset_path):
        download_fn(args.eval_dataset_path)

    print('Creating CNN model...')

    def create_model(model_filename):
        """Rebuild a CNNGeometric model from the checkpoint *model_filename*.

        The geometric model is deduced from the size of the regression
        layer's bias: 6 -> 'affine', 8 or 9 -> 'hom', anything else -> 'tps'.
        Checkpoint keys containing 'vgg' are renamed to 'model'.

        Returns:
            A ``(model, geometric_model)`` tuple.
        """
        checkpoint = torch.load(model_filename,
                                map_location=lambda storage, loc: storage)
        state_dict = OrderedDict(
            (k.replace('vgg', 'model'), v)
            for k, v in checkpoint['state_dict'].items())
        output_size = state_dict['FeatureRegression.linear.bias'].size()[0]

        if output_size == 6:
            geometric_model = 'affine'
        elif output_size in (8, 9):
            geometric_model = 'hom'
        else:
            geometric_model = 'tps'

        model = CNNGeometric(use_cuda=use_cuda,
                             output_dim=output_size,
                             **arg_groups['model'])

        # Copy weights in place (state_dict() tensors share storage with the
        # module), skipping BatchNorm's num_batches_tracked counters.
        for prefix, module in (('FeatureExtraction.', model.FeatureExtraction),
                               ('FeatureRegression.', model.FeatureRegression)):
            for name, tensor in module.state_dict().items():
                if not name.endswith('num_batches_tracked'):
                    tensor.copy_(state_dict[prefix + name])

        return (model, geometric_model)

    # Load model for stage 1
    model_1, geometric_model_1 = create_model(args.model_1)

    if two_stage:
        # Load model for stage 2
        model_2, geometric_model_2 = create_model(args.model_2)
    else:
        model_2, geometric_model_2 = None, None

    print('Creating dataset and dataloader...')

    # Dataset and dataloader
    cnn_image_size = (args.image_size, args.image_size)

    dataset = dataset_cls(
        csv_file=os.path.join(args.eval_dataset_path, csv_file),
        dataset_path=args.eval_dataset_path,
        transform=NormalizeImageDict(['source_image', 'target_image']),
        output_size=cnn_image_size)

    # Use the requested batch size on GPU; one pair at a time on CPU.
    batch_size = args.batch_size if use_cuda else 1

    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=0,
                            collate_fn=default_collate)

    batch_tnf = BatchTensorToVars(use_cuda=use_cuda)

    model_1.eval()

    if two_stage:
        model_2.eval()

    print('Starting evaluation...')

    stats = compute_metric(metric, model_1, geometric_model_1, model_2,
                           geometric_model_2, dataset, dataloader, batch_tnf,
                           batch_size, args)
    return stats
# Example no. 8
def main(passed_arguments=None):
    """Evaluate one or two CNNGeometric models on the '3d' dataset.

    Args:
        passed_arguments: optional argument list forwarded to
            ``ArgumentParser.parse`` (``None`` presumably means "parse
            sys.argv" -- confirm).

    Returns:
        The statistics from ``compute_metric`` with metric 'absdiff'. They
        are also saved next to ``args.model_1`` as ``stats.pkl``, or
        ``stats_merged.pkl`` when the dataset path contains 'merged'.

    Raises:
        NotImplementedError: if a checkpoint's output size is neither 3 nor
            4, or if ``args.eval_dataset`` is not '3d'.
    """

    # Argument parsing
    args, arg_groups = ArgumentParser(mode='eval').parse(passed_arguments)
    print(args)

    # check provided models and deduce if single/two-stage model should be used
    two_stage = args.model_2 != ''

    use_cuda = torch.cuda.is_available()
    # Use truthiness rather than `is True` / `is False`. Previously a
    # non-bool flag (e.g. 0/1) matched neither dataset branch and was
    # rejected as unsupported.
    use_me = bool(args.use_me)

    print('Creating CNN model...')

    def create_model(model_filename):
        """Rebuild a CNNGeometric model from the checkpoint *model_filename*.

        The geometric model is deduced from the regression head's output
        size: 4 -> 'affine_simple_4', 3 -> 'affine_simple'. Checkpoint keys
        containing 'vgg' are renamed to 'model'.

        Returns:
            A ``(model, geometric_model)`` tuple.
        """
        checkpoint = torch.load(model_filename,
                                map_location=lambda storage, loc: storage)
        state_dict = OrderedDict(
            (k.replace('vgg', 'model'), v)
            for k, v in checkpoint['state_dict'].items())

        # The regression head is stored either as `linear` or `resnet.fc`.
        # An explicit key check replaces the former bare `except:`, which
        # also swallowed unrelated errors.
        bias_key = 'FeatureRegression.linear.bias'
        if bias_key not in state_dict:
            bias_key = 'FeatureRegression.resnet.fc.bias'
        output_size = state_dict[bias_key].size()[0]

        if output_size == 4:
            geometric_model = 'affine_simple_4'
        elif output_size == 3:
            geometric_model = 'affine_simple'
        else:
            raise NotImplementedError(
                'Geometric model deducted from output layer is unsupported')

        model = CNNGeometric(use_cuda=use_cuda,
                             output_dim=output_size,
                             **arg_groups['model'])

        def copy_weights(module, prefix):
            # Copy in place (state_dict() tensors share storage with the
            # module), skipping BatchNorm's num_batches_tracked counters.
            for name, tensor in module.state_dict().items():
                if not name.endswith('num_batches_tracked'):
                    tensor.copy_(state_dict[prefix + name])

        # In ME mode the feature extractor is not restored from the checkpoint.
        if not use_me:
            copy_weights(model.FeatureExtraction, 'FeatureExtraction.')
        copy_weights(model.FeatureRegression, 'FeatureRegression.')

        return (model, geometric_model)

    # Load model for stage 1
    model_1, geometric_model_1 = create_model(args.model_1)

    if two_stage:
        # Load model for stage 2
        model_2, geometric_model_2 = create_model(args.model_2)
    else:
        model_2, geometric_model_2 = None, None

    print('Creating dataset and dataloader...')

    # Dataset and dataloader
    if args.eval_dataset != '3d':
        raise NotImplementedError('Dataset is unsupported')

    if use_me:
        dataset = MEDataset(dataset_csv_path=args.eval_dataset_path,
                            dataset_csv_file='all_pairs_3d.csv',
                            dataset_image_path=args.eval_dataset_path,
                            input_height=args.input_height,
                            input_width=args.input_width,
                            crop=args.crop_factor,
                            use_conf=args.use_conf,
                            use_random_patch=args.use_random_patch,
                            normalize_inputs=args.normalize_inputs,
                            geometric_model='EVAL',
                            random_sample=False)
    else:
        cnn_image_size = (args.image_size, args.image_size)
        dataset = Dataset3D(
            csv_file=os.path.join(args.eval_dataset_path, 'all_pairs.csv'),
            dataset_path=args.eval_dataset_path,
            transform=NormalizeImageDict(['source_image', 'target_image']),
            output_size=cnn_image_size)
    collate_fn = default_collate
    metric = 'absdiff'

    # Use the requested batch size on GPU; one pair at a time on CPU.
    batch_size = args.batch_size if use_cuda else 1

    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=0,
                            collate_fn=collate_fn)

    batch_tnf = BatchTensorToVars(use_cuda=use_cuda)

    model_1.eval()

    if two_stage:
        model_2.eval()

    print(os.path.basename(args.model_1))
    print('Starting evaluation...', flush=True)

    stats = compute_metric(metric,
                           model_1,
                           geometric_model_1,
                           model_2,
                           geometric_model_2,
                           dataset,
                           dataloader,
                           batch_tnf,
                           batch_size,
                           args)

    # Save the statistics next to the stage-1 checkpoint.
    stats_fn = ('stats_merged.pkl' if 'merged' in args.eval_dataset_path
                else 'stats.pkl')
    stats_fn = os.path.join(os.path.dirname(args.model_1), stats_fn)
    save_dict(stats_fn, stats)
    return stats
from util.util import init_model
from util.util import init_test_data

from util.eval_util import compute_metric
from util.torch_util import BatchTensorToVars

from parser.parser import ArgumentParser


# Parse the evaluation-mode command-line arguments.
args, arg_groups = ArgumentParser(mode='eval').parse()


# GPU selection via args.gpu is disabled and kept only for reference:
#torch.cuda.set_device(args.gpu)
use_cuda = torch.cuda.is_available()


""" Initialize model """
model = init_model(args, arg_groups, use_cuda, mode='eval')



""" Initialize dataloader """
test_data, test_loader = init_test_data(args)

batch_tnf = BatchTensorToVars(use_cuda=use_cuda)


model.eval()
    
# Compute args.eval_metric over the test set.
# NOTE(review): `stats` is neither printed nor saved here; presumably
# compute_metric reports its results itself -- confirm.
stats = compute_metric(args.eval_metric, model, test_data, test_loader, batch_tnf, args)