Пример #1
0
def test(config):
    """Evaluate a saved multi-view snapshot on the ``*/test`` split.

    Builds an un-augmented, un-shuffled loader over ``<config.data>/*/test``,
    wraps an ``SVCNN`` in an ``MVCNN``, loads the snapshot named
    ``config.snapshot_prefix + str(config.weights)`` from the
    ``<config.name>_stage_2`` log directory, runs the trainer's validation
    pass with ``test=True`` and forwards the returned labels and predictions
    to ``Evaluation_tools``.

    Args:
        config: Settings object. The attributes read here are ``log_dir``,
            ``name``, ``data``, ``num_views``, ``stage2_batch_size``,
            ``no_pretraining``, ``num_classes``, ``cnn_name``,
            ``snapshot_prefix``, ``weights``, ``learning_rate`` and
            ``weight_decay``.

    Returns:
        None. Results are handed to ``Evaluation_tools.write_eval_file`` and
        ``Evaluation_tools.make_matrix``.
    """
    stage2_dir = os.path.join(config.log_dir, config.name + '_stage_2')
    test_glob = os.path.join(config.data, "*/test")

    # Evaluation data: augmentation off, fixed order, loaded in-process.
    test_set = MultiviewImgDataset(
        test_glob,
        scale_aug=False,
        rot_aug=False,
        num_views=config.num_views,
    )
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=config.stage2_batch_size,
        shuffle=False,
        num_workers=0,
    )

    # Single-view network wrapped by the multi-view network.
    single_view_net = SVCNN(
        config.name,
        nclasses=config.num_classes,
        cnn_name=config.cnn_name,
        pretraining=not config.no_pretraining,
    )
    multi_view_net = MVCNN(
        config.name,
        single_view_net,
        nclasses=config.num_classes,
        cnn_name=config.cnn_name,
        num_views=config.num_views,
    )

    snapshot_name = config.snapshot_prefix + str(config.weights)
    multi_view_net.load(os.path.join(stage2_dir, snapshot_name))

    # NOTE(review): the trainer is given an optimizer even though no training
    # loader is supplied; presumably it is unused during evaluation - confirm
    # against ModelNetTrainer.
    adam = optim.Adam(
        multi_view_net.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
        betas=(0.9, 0.999),
    )
    criterion = nn.CrossEntropyLoss()

    evaluator = ModelNetTrainer(
        multi_view_net,
        None,  # no training loader
        test_loader,
        adam,
        criterion,
        config,
        stage2_dir,
        num_views=config.num_views,
    )

    true_labels, predicted = evaluator.update_validation_accuracy(
        config.weights, test=True)

    # Imported only once evaluation has produced results, as in the original.
    import Evaluation_tools as et
    report_path = os.path.join(config.log_dir, f'{config.name}.txt')
    et.write_eval_file(config.data, report_path, predicted, true_labels,
                       config.name)
    et.make_matrix(config.data, report_path, config.log_dir)
Пример #2
0
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batchSize,
                                             shuffle=False,
                                             num_workers=0)
    print('num_train_files: ' + str(len(train_dataset.filepaths)))
    print('num_val_files: ' + str(len(val_dataset.filepaths)))
    optimizer = None
    trainer = ModelNetTrainer(cnet_2,
                              train_loader,
                              val_loader,
                              optimizer,
                              nn.CrossEntropyLoss(),
                              'mvcnn',
                              log_dir,
                              num_views=args.num_views)
    loss, val_overall_acc, val_mean_class_acc = trainer.update_validation_accuracy(
        None)

    # total_seen = 0
    # total_correct = 0
    # cnet_2.eval()
    # for _, data in enumerate(val_loader,0):
    #
    #     N, V, C, H, W = data[1].size()
    #     total_seen = total_seen + N
    #     in_data = Variable(data[1]).view(-1, C, H, W).cuda()
    #     target = Variable(data[0]).cuda()
    #
    #     out_data = cnet_2(in_data)
    #     pred = torch.max(out_data, 1)[1]
    #     correct = (pred==target).float().sum()
    #     total_correct = total_correct+correct