Esempio n. 1
0
def main(opts):
    """Evaluate a pretrained ``Convnet`` encoder on episodes from the MNIST test split.

    Args:
        opts: Parsed options object. The attributes read here are
            ``device``, ``weights``, ``shot_k``, ``query_k`` and ``n_way``.

    The function prints the options, the selected device and, at the end,
    the two values returned by ``mean_confidence_interval`` over the
    per-episode accuracies (reported as ``acc`` and ``h``).
    """
    pprint(vars(opts))
    device = torch.device(opts.device)
    print('using device: {}'.format(device))

    encoder = Convnet().to(device)
    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # saved from a GPU run can still be loaded on a CPU-only machine.
    encoder.load_state_dict(torch.load(opts.weights, map_location=device))
    encoder.eval()
    pipeline = protoPipe(encoder, opts.shot_k, opts.query_k)

    samples_per_class = opts.shot_k + opts.query_k

    test_set = MNIST(mode='test')
    # 1000 is presumably the number of sampled episodes -- confirm against
    # TaskSampler's signature.
    test_sampler = TaskSampler(test_set.label, 1000, opts.n_way,
                               samples_per_class)
    test_loader = DataLoader(dataset=test_set,
                             batch_sampler=test_sampler,
                             num_workers=8,
                             pin_memory=True)

    test_acc = []
    # Pure evaluation: disabling autograd avoids building (and keeping alive,
    # through the stored accuracies) a computation graph per episode.
    with torch.no_grad():
        for batch in test_loader:
            task = batch[0].view(opts.n_way * samples_per_class, 3, 84, 84)
            _, acc = pipeline(task.to(device), opts.n_way)
            test_acc.append(acc)

    m, h = mean_confidence_interval(test_acc)
    print('TEST set  acc: {:.4f}, h: {:.4f}'.format(m, h))
Esempio n. 2
0
def main(opts):
    """Evaluate a pretrained ``Convnet`` encoder on episodes from a chosen test set.

    Args:
        opts: Parsed options object. The attributes read here are
            ``dataset``, ``device``, ``weights``, ``shot_k``, ``query_k``,
            ``n_way`` and ``seed``.

    Raises:
        ValueError: If ``opts.dataset`` is not one of ``'mini'``,
            ``'MNIST'``, ``'CIFAR'`` or ``'FashionMNIST'``.

    The function prints the options, the selected device, the seed (when
    given) and, at the end, the two values returned by
    ``mean_confidence_interval`` over the per-episode accuracies.
    """
    pprint(vars(opts))

    # Fail fast on an unknown dataset name. Without this check the if/elif
    # chain below would leave `test_set` unbound and the error would only
    # surface (as UnboundLocalError) after the weights had been loaded.
    supported_datasets = ('mini', 'MNIST', 'CIFAR', 'FashionMNIST')
    if opts.dataset not in supported_datasets:
        raise ValueError('unknown dataset {!r}; expected one of {}'.format(
            opts.dataset, ', '.join(supported_datasets)))

    device = torch.device(opts.device)
    print('using device: {}'.format(device))

    encoder = Convnet().to(device)
    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # saved from a GPU run can still be loaded on a CPU-only machine.
    encoder.load_state_dict(torch.load(opts.weights, map_location=device))
    encoder.eval()
    pipeline = protoPipe(encoder, opts.shot_k, opts.query_k)

    if opts.seed is not None:
        torch.manual_seed(opts.seed)
        torch.cuda.manual_seed_all(opts.seed)
        np.random.seed(opts.seed)
        print("\nrandom seed: {}".format(opts.seed))

    if opts.dataset == 'mini':
        test_set = MiniImageNet(mode='test')
    elif opts.dataset == 'MNIST':
        test_set = MNIST(mode='test')
    elif opts.dataset == 'CIFAR':
        test_set = CIFAR(mode='test')
    else:  # 'FashionMNIST' -- the only remaining validated name
        test_set = Fashion_MNIST(mode='test')

    samples_per_class = opts.shot_k + opts.query_k

    # 1000 is presumably the number of sampled episodes -- confirm against
    # TaskSampler's signature.
    test_sampler = TaskSampler(test_set.label, 1000, opts.n_way,
                               samples_per_class)
    test_loader = DataLoader(dataset=test_set,
                             batch_sampler=test_sampler,
                             num_workers=8,
                             pin_memory=True)

    test_acc = []
    # Pure evaluation: disabling autograd avoids building (and keeping alive,
    # through the stored accuracies) a computation graph per episode.
    with torch.no_grad():
        for batch in test_loader:
            task = batch[0].view(opts.n_way * samples_per_class, 3, 84, 84)
            _, acc = pipeline(task.to(device), opts.n_way)
            test_acc.append(acc)

    m, h = mean_confidence_interval(test_acc)
    print('TEST set  acc: {:.4f}, h: {:.4f}'.format(m, h))
    # NOTE(review): from here on the code uses `parser`/`args` instead of
    # `opts`; these lines look like the tail of a different example whose
    # function header and earlier arguments are not visible in this file.
    parser.add_argument('--shot', type=int, default=1)
    parser.add_argument('--query', type=int, default=30)
    parser.add_argument('--folds', type=int, default=2)
    args = parser.parse_args()
    pprint(vars(args))

    set_gpu(args.gpu)

    # Test-split loader. The last sampler argument is
    # `folds * shot + query`; presumably the number of samples drawn per
    # class in each batch -- confirm against CategoriesSampler.
    dataset = MiniImageNet('test')
    sampler = CategoriesSampler(dataset.label,
                                args.batch, args.way, args.folds * args.shot + args.query)
    loader = DataLoader(dataset, batch_sampler=sampler,
                        num_workers=8, pin_memory=True)

    # Load the trained encoder on the GPU and switch it to eval mode.
    model = Convnet().cuda()
    model.load_state_dict(torch.load(args.load))
    model.eval()

    ave_acc = Averager()
    # Support labels: 0..train_way-1 repeated `shot` times, then one-hot
    # encoded into a (shot * train_way, 20) matrix on the GPU.
    # NOTE(review): the one-hot width 20 is hard-coded; scatter_ will fail
    # if args.train_way exceeds 20 -- confirm this is intended.
    # NOTE(review): labels use args.train_way while the sampler and the
    # split below use args.way -- confirm the two are meant to differ.
    s_label = torch.arange(args.train_way).repeat(args.shot).view(args.shot * args.train_way)
    s_onehot = torch.zeros(s_label.size(0), 20)
    s_onehot = s_onehot.scatter_(1, s_label.unsqueeze(dim=1), 1).cuda()

    for i, batch in enumerate(loader, 1):
        # Move every element of the batch to the GPU; only the first one
        # (the data) is kept, the second is discarded.
        data, _ = [_.cuda() for _ in batch]
        k = args.way * args.shot
        # First k items, next k items, and the remainder of the batch.
        data_shot, meta_support, data_query = data[:k], data[k:2*k], data[2*k:]

        #p = inter_fold(model, args, data_shot)

        x = model(data_shot)
Esempio n. 4
0
                              pin_memory=True)

    # Validation data comes from the 'trainvaltest' split; the 'test' split
    # variant is kept below for reference.
    #valset = MiniImageNet('test')
    valset = MiniImageNet('trainvaltest')
    val_sampler = CategoriesSampler_val_100way(valset.label, 400,
                                               args.test_way, args.shot,
                                               args.query_val)
    val_loader = DataLoader(dataset=valset,
                            batch_sampler=val_sampler,
                            num_workers=8,
                            pin_memory=True)

    # Backbone weights are restored from a fixed checkpoint path.
    # NOTE(review): model_reg is created but no weights are loaded for it in
    # the visible part of this fragment -- confirm that is intended.
    model_cnn = Convnet().cuda()
    model_reg = Registrator().cuda()
    model_cnn.load_state_dict(
        torch.load(
            './iterative_G3_trainval_lr150epoch_dataaugumentation_2epoch-175_backbone.pth'
        ))

    # Hallucinator built with a noise dimension of 128 and restored from its
    # own checkpoint.
    noise_dim = 128
    model_gen = Hallucinator(noise_dim).cuda()
    model_gen.load_state_dict(
        torch.load(
            './iterative_G3_trainval_lr150epoch_dataaugumentation_2epoch-175_gen.pth'
        ))

    # Split the stored prototypes by row: the first `n_base_class` rows are
    # treated as base, the remaining rows as novel.
    global_proto = torch.load('./global_proto_all_new.pth')
    global_base = global_proto[:args.n_base_class, :]
    global_novel = global_proto[args.n_base_class:, :]

    # Each part is moved to the GPU, marked as requiring gradients and
    # wrapped in a single-element list.
    global_base = [Variable(global_base.cuda(), requires_grad=True)]
    global_novel = [Variable(global_novel.cuda(), requires_grad=True)]
    # NOTE(review): from here on the code reads `params`, `configs` and
    # `device`, none of which are defined in the visible lines; this looks
    # like the body of a different example whose header is missing.
    image_size = 84
    iter_num = 600

    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
    model = Convnet()
    model = model.to(device)

    # Checkpoint directory is derived from the save dir, dataset name and the
    # training way/shot configuration, with an '_aug' suffix when training
    # used augmentation.
    checkpoint_dir = '%s/checkpoints/%s_%dway_%dshot' % (
        configs.save_dir, params.dataset, params.train_n_way, params.n_shot)
    if params.train_aug:
        checkpoint_dir += '_aug'

    # NOTE(review): when get_best_file returns None nothing is loaded and
    # evaluation continues with the freshly constructed weights.
    modelfile = get_best_file(checkpoint_dir)
    if modelfile is not None:
        tmp = torch.load(modelfile)
        model.load_state_dict(tmp['state'])

    # split_str appends the save iteration to the split name unless
    # save_iter is -1; it is not used within the visible lines.
    split = params.split
    if params.save_iter != -1:
        split_str = split + "_" + str(params.save_iter)
    else:
        split_str = split

    # Episodic loader over the chosen split's JSON file, with 15 queries,
    # `iter_num` episodes and augmentation disabled.
    datamgr = SetDataManager(image_size,
                             n_eposide=iter_num,
                             n_query=15,
                             **few_shot_params)
    loadfile = configs.data_dir[params.dataset] + split + '.json'
    novel_loader = datamgr.get_data_loader(loadfile, aug=False)

    model.eval()
Esempio n. 6
0
    # NOTE(review): fragment of a larger function; its header, the argument
    # parsing and the definition of `logfile` are not visible here.
    pprint(vars(args))

    set_gpu(args.gpu)

    # Three plain (non-episodic) loaders with batch size 128, one per split:
    # 'trainvaltest', 'trainval' and 'test'.
    valset = MiniImageNet2('trainvaltest')
    val_loader = DataLoader(dataset=valset, batch_size = 128,
                            num_workers=8, pin_memory=True)
    valset2 = MiniImageNet2('trainval')
    val_loader2 = DataLoader(dataset=valset2, batch_size = 128,
                            num_workers=8, pin_memory=True)
    valset3 = MiniImageNet2('test')
    val_loader3 = DataLoader(dataset=valset3, batch_size = 128,
                            num_workers=8, pin_memory=True)

    # Restore the backbone and the stored prototypes, then split the
    # prototypes by row: first `n_base_class` rows vs. the remaining rows.
    model_cnn = Convnet().cuda()
    model_cnn.load_state_dict(torch.load('./100way_pn_basenovel.pth'))
    global_proto = torch.load('./global_proto_basenovel_PN_5shot_500.pth')
    global_base =global_proto[:args.n_base_class,:]
    global_novel = global_proto[args.n_base_class:,:]
    # Each part is moved to the GPU, marked as requiring gradients and
    # wrapped in a single-element list.
    global_base = [Variable(global_base.cuda(),requires_grad=True)]
    global_novel = [Variable(global_novel.cuda(),requires_grad=True)]

    def log(out_str):
        # Echo the message to stdout and append it to `logfile`, flushing
        # right away so the file is current after every call.
        # NOTE(review): `logfile` must come from an enclosing scope that is
        # not visible in this fragment.
        print(out_str)
        logfile.write(out_str+'\n')
        logfile.flush()

    model_cnn.eval()
    for epoch in range(1, args.max_epoch + 1):

        for i, batch in enumerate(val_loader, 1):