Example 1
def main():
    """Evaluate a trained MAML checkpoint on the mini-ImageNet test split.

    Loads the checkpoint named by ``args.ckpt``, runs 600 evaluation
    iterations (4 test tasks each), prints the mean accuracy per inner
    update step, and dumps the final result to ``test.json``.
    """
    # Fix every RNG so episode sampling is reproducible across runs.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    test_result = {}
    best_acc = 0.0

    maml = Meta(args, Param.config).to(Param.device)
    if n_gpus > 1:
        maml = torch.nn.DataParallel(maml)

    raw_state = torch.load(Param.out_path + args.ckpt)
    print(raw_state.keys())
    # Remap checkpoint keys: a DataParallel-saved checkpoint prefixes every
    # key with "module." (7 chars), which must be stripped when loading into
    # a plain (single-GPU) model.
    remapped = OrderedDict()
    for key in raw_state:
        target = key[7:] if n_gpus == 1 else key
        remapped[target] = deepcopy(raw_state[key])
    maml.load_state_dict(remapped)
    print("Load from ckpt:", Param.out_path + args.ckpt)

    trainable = (p for p in maml.parameters() if p.requires_grad)
    total = sum(np.prod(p.shape) for p in trainable)
    print(maml)
    print('Total trainable tensors:', total)

    testset = MiniImagenet(Param.root, mode='test', n_way=args.n_way,
                           k_shot=args.k_spt, k_query=args.k_qry,
                           resize=args.imgsz)
    testloader = DataLoader(testset, batch_size=4, shuffle=True,
                            num_workers=4, drop_last=True)
    test_data = inf_get(testloader)

    # Test for 600 iterations (each has 4 tasks).
    results = None
    maml_clone = deepcopy(maml)
    for episode in range(600):  # 600x4 test tasks
        sx, sy, qx, qy = next(test_data)
        sx = sx.to(Param.device)
        sy = sy.to(Param.device)
        qx = qx.to(Param.device)
        qy = qy.to(Param.device)
        batch_acc = maml_clone(sx, sy, qx, qy, meta_train=False)
        if results is None:
            results = batch_acc
        else:
            results = torch.cat([results, batch_acc], dim=0)
        if episode % 100 == 0:
            print(episode, results.mean(dim=0).tolist())

    final = results.mean(dim=0).tolist()
    print('Acc: ' + str(final))
    with open(Param.out_path + 'test.json', 'w') as f:
        json.dump(final, f)
Example 2
def main():
    """Meta-test a trained MAML model on the mini-ImageNet test split.

    Loads a fixed checkpoint, fine-tunes on each test episode's support
    set, and reports the mean query accuracy per inner-update step.
    """
    # Fix RNG seeds so episode sampling is reproducible.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    print(args)

    # conv2d entries: [out_ch, in_ch, k_h, k_w, stride, padding];
    # max_pool2d: [kernel, stride, padding]; linear: [out, in].
    config = [('conv2d', [32, 3, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 1, 0]),
              ('flatten', []), ('linear', [args.n_way, 32 * 5 * 5])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    ckpt_dir = "./checkpoint_miniimage.pth"
    print("Load trained model")
    ckpt = torch.load(ckpt_dir)
    maml.load_state_dict(ckpt['model'])

    mini_test = MiniImagenet("F:\\ACV_project\\MAML-Pytorch\\miniimagenet\\",
                             mode='test',
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batchsz=1,
                             resize=args.imgsz)

    db_test = DataLoader(mini_test,
                         1,
                         shuffle=True,
                         num_workers=1,
                         pin_memory=True)
    accs_all_test = []

    for x_spt, y_spt, x_qry, y_qry in db_test:
        # Each loader batch holds exactly one task; drop the batch dimension.
        x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
        x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

        accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
        accs_all_test.append(accs)

    # BUGFIX: the mean was recomputed and printed INSIDE the loop above
    # (quadratic work and one print per episode); aggregate once at the end,
    # as the "[b, update_step+1]" shape comment intended.
    accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
    print('Test acc:', accs)
Example 3
def main():
    """Meta-validate a (possibly pre-trained) MAML model on mini-ImageNet.

    Runs 10 rounds over 100 validation episodes each, fine-tuning per
    episode, and reports per-round and average final-step accuracy.
    """
    torch.manual_seed(222)  # seed the CPU RNG for deterministic results
    torch.cuda.manual_seed_all(222)  # seed every GPU RNG as well
    np.random.seed(222)

    print(args)

    # conv2d entries: [out_ch, in_ch, k_h, k_w, stride, padding];
    # max_pool2d: [kernel, stride, padding]; linear: [out, in].
    config = [
        ('conv2d', [32, 1, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 1, 0]),
        ('flatten', []),
        ('linear', [args.n_way, 7040])
    ]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    # NOTE(review): the checkpoint name embeds "./models/" twice
    # ("./models/./models/..."); preserved as-is so checkpoints already
    # saved under that path keep loading — confirm against the trainer.
    path = "./models/" + str("./models/miniimagenet_maml" + str(args.n_way) + "way_" + str(args.k_spt) + "shot.pkl")
    if os.path.exists(path):
        # BUGFIX: load_state_dict() expects a state dict, not a file path;
        # the original passed the path string and would fail at runtime.
        maml.load_state_dict(torch.load(path))
        print("load model success")

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number.
    # (An unused 10000-episode 'train' dataset build was removed: this
    # function only ever evaluates on the validation split.)
    mini_test = MiniImagenet("./miniimagenet", mode='val', n_way=args.n_way, k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batchsz=100)
    test_accuracy = []
    for epoch in range(10):
        # fetch meta_batchsz num of episode each time
        db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
        accs_all_test = []

        for x_spt, y_spt, x_qry, y_qry in db_test:
            # One task per batch; drop the leading batch dimension.
            x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                         x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

            accs, loss_t = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
            accs_all_test.append(accs)

        # [b, update_step+1]: mean over episodes at each inner-update step
        accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
        print('Test acc:', accs)
        test_accuracy.append(accs[-1])  # final-step accuracy for this round
    average_accuracy = sum(test_accuracy) / len(test_accuracy)
    print("average accuracy:{}".format(average_accuracy))
def main():
    """Evaluate a trained MAML checkpoint on mini-ImageNet (multi-loader).

    Supports three episode sources selected by ``args.loader`` (0/1 =
    image-folder loaders, 2 = pickle loader), runs 600 evaluation
    iterations, and writes the final accuracy + 95% CI to ``test.txt``.
    """
    # Fix torch RNGs; numpy seeding is deliberately left disabled here.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    #np.random.seed(222)
    test_result = {}
    best_acc = 0.0

    maml = Meta(args, Param.config).to(Param.device)
    if len(args.gpu.split(',')) > 1:
        maml = torch.nn.DataParallel(maml)
    state_dict = torch.load(Param.out_path + args.ckpt)
    print(state_dict.keys())
    # Remap checkpoint keys. k[7:] strips a leading "module." prefix
    # (added by DataParallel when the checkpoint was saved); k[0:] keeps
    # the key unchanged.
    # NOTE(review): `n_gpus == -1` never equals a real GPU count, so the
    # prefix-stripping branch appears dead, and it is inconsistent with
    # the `len(args.gpu.split(',')) > 1` check above — confirm intent.
    pretrained_dict = OrderedDict()
    for k in state_dict.keys():
        if n_gpus == -1:
            pretrained_dict[k[7:]] = deepcopy(state_dict[k])
        else:
            pretrained_dict[k[0:]] = deepcopy(state_dict[k])
    maml.load_state_dict(pretrained_dict)
    print("Load from ckpt:", Param.out_path + args.ckpt)

    #opt = optim.Adam(maml.parameters(), lr=args.meta_lr)
    #opt = optim.SGD(maml.parameters(), lr=args.meta_lr, momentum=0.9, weight_decay=args.weight_decay)

    # Count trainable parameters for the log.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    if args.loader in [0, 1]:  # default loader
        if args.loader == 1:
            #from dataloader.mini_imagenet import MiniImageNet as MiniImagenet
            from MiniImagenet2 import MiniImagenet
        else:
            from MiniImagenet import MiniImagenet

        testset = MiniImagenet(Param.root,
                               mode='test',
                               n_way=args.n_way,
                               k_shot=args.k_spt,
                               k_query=args.k_qry,
                               resize=args.imgsz)
        # shuffle=False + worker_init_fn: deterministic episode order.
        testloader = DataLoader(testset,
                                batch_size=1,
                                shuffle=False,
                                num_workers=4,
                                worker_init_fn=worker_init_fn,
                                drop_last=True)
        test_data = inf_get(testloader)

    elif args.loader == 2:  # pkl loader
        args_data = {}
        args_data['x_dim'] = "84,84,3"
        args_data['ratio'] = 1.0
        args_data['seed'] = 222
        loader_test = dataset_mini(600, 100, 'test', args_data)
        loader_test.load_data_pkl()
    """Test for 600 epochs (each has 4 tasks)"""
    ans = None
    # Evaluate on a clone so fine-tuning cannot perturb the loaded weights.
    maml_clone = deepcopy(maml)
    for itr in range(600):  # 600x4 test tasks
        if args.loader in [0, 1]:
            support_x, support_y, qx, qy = test_data.__next__()
            support_x, support_y, qx, qy = support_x.to(
                Param.device), support_y.to(Param.device), qx.to(
                    Param.device), qy.to(Param.device)
        elif args.loader == 2:
            support_x, support_y, qx, qy = get_data(loader_test)
            support_x, support_y, qx, qy = support_x.to(
                Param.device), support_y.to(Param.device), qx.to(
                    Param.device), qy.to(Param.device)

        temp = maml_clone(support_x, support_y, qx, qy, meta_train=False)
        if (ans is None):
            ans = temp
        else:
            # Accumulate per-task accuracy rows along dim 0.
            ans = torch.cat([ans, temp], dim=0)
        if itr % 100 == 0:
            print(itr, ans.mean(dim=0).tolist())
    # Mean and 95% confidence interval over the 600 iterations.
    meanacc = np.array(ans.mean(dim=0).tolist())
    stdacc = np.array(ans.std(dim=0).tolist())
    ci95 = 1.96 * stdacc / np.sqrt(600)
    print(f'Acc: {meanacc[-1]:.4f}, ci95: {ci95[-1]:.4f}')
    with open(Param.out_path + 'test.txt', 'w') as f:
        print(f'Acc: {meanacc[-1]:.4f}, ci95: {ci95[-1]:.4f}', file=f)
Example 5
def main():
    """Train an active-learning (AL) learner on top of a frozen MAML task model.

    NOTE(review): this function is truncated in the source — the trailing
    triple-quoted block at the end is unterminated (cut off by extraction),
    so the original evaluation code after it is not visible here.
    """

    # Fix RNG seeds for reproducibility.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    ## Task Learner Setup
    # conv2d entries: [out_ch, in_ch, k_h, k_w, stride, padding];
    # max_pool2d: [kernel, stride, padding]; linear: [out, in].
    task_config = [('conv2d', [32, 3, 3, 3, 1, 0]), ('relu', [True]),
                   ('bn', [32]), ('max_pool2d', [2, 2, 0]),
                   ('conv2d', [32, 32, 3, 3, 1, 0]), ('relu', [True]),
                   ('bn', [32]), ('max_pool2d', [2, 2, 0]),
                   ('conv2d', [32, 32, 3, 3, 1, 0]), ('relu', [True]),
                   ('bn', [32]), ('max_pool2d', [2, 2, 0]),
                   ('conv2d', [32, 32, 3, 3, 1, 0]), ('relu', [True]),
                   ('bn', [32]), ('max_pool2d', [2, 1, 0]), ('flatten', []),
                   ('linear', [args.n_way, 32 * 5 * 5])]
    # Probe checkpoints saved every 500 epochs and keep the last one that
    # actually exists on disk (valid_epoch).
    # NOTE(review): if the epoch-0 checkpoint is missing, the while loop
    # never runs and `valid_epoch` is unbound -> NameError below; confirm
    # a checkpoint is guaranteed to exist before this is called.
    last_epoch = 0
    suffix = "_v0"
    save_path = os.getcwd() + '/data/model_batchsz' + str(
        args.k_model) + '_stepsz' + str(
            args.update_lr) + '_epoch' + str(last_epoch) + suffix + '.pt'
    while os.path.isfile(save_path):
        valid_epoch = last_epoch
        last_epoch += 500
        save_path = os.getcwd() + '/data/model_batchsz' + str(
            args.k_model) + '_stepsz' + str(
                args.update_lr) + '_epoch' + str(last_epoch) + suffix + '.pt'
    save_path = os.getcwd() + '/data/model_batchsz' + str(
        args.k_model) + '_stepsz' + str(
            args.update_lr) + '_epoch' + str(valid_epoch) + suffix + '.pt'

    # Load the frozen task learner; eval() so its BN stats stay fixed.
    device = torch.device('cuda')
    task_mod = Meta(args, task_config).to(device)
    task_mod.load_state_dict(torch.load(save_path))
    task_mod.eval()

    ## AL Learner Setup
    print(args)

    # Single linear layer scoring the task model's 32*5*5 features.
    al_config = [('linear', [1, 32 * 5 * 5])]

    device = torch.device('cuda')
    maml = AL_Learner(args, al_config, task_mod).to(device)
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    mini = MiniImagenet('/home/tesca/data/miniimagenet/',
                        mode='train',
                        n_way=args.n_way,
                        k_shot=1,
                        k_query=args.k_qry,
                        batchsz=10000,
                        resize=args.imgsz)
    mini_test = MiniImagenet('/home/tesca/data/miniimagenet/',
                             mode='test',
                             n_way=args.n_way,
                             k_shot=1,
                             k_query=args.k_qry,
                             batchsz=100,
                             resize=args.imgsz)
    save_path = os.getcwd() + '/data/model_batchsz' + str(
        args.k_model) + '_stepsz' + str(args.update_lr) + '_epoch'

    for epoch in range(args.epoch // 10000):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini,
                        args.task_num,
                        shuffle=True,
                        num_workers=1,
                        pin_memory=True)

        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):

            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(
                device), x_qry.to(device), y_qry.to(device)

            al_accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                print('step:', step, '\tAL acc:', al_accs)

            if step % 500 == 0:  # evaluation
                torch.save(maml.state_dict(),
                           save_path + str(step) + "_al_net.pt")
                # NOTE(review): unterminated triple-quote below — the rest
                # of this function was lost in extraction.
                '''db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
Example 6
def main(args):
    """Cross-validated MAML training on SGC-precomputed citation graphs.

    For each 2-class held-out combination, meta-trains on the remaining
    classes and periodically meta-tests, appending results to
    '<dataset>.txt'.

    Raises:
        ValueError: if ``args.dataset`` is neither 'citeseer' nor 'cora'.
    """
    step = args.step
    set_seed(args.seed)

    adj, features, labels = load_citation(args.dataset, args.normalization)

    # Precompute the degree-th power propagation once (SGC trick).
    features = sgc_precompute(features, adj, args.degree)

    if args.dataset == 'citeseer':
        node_num = 3327
        class_label = [0, 1, 2, 3, 4, 5]
    elif args.dataset == 'cora':
        node_num = 2708
        class_label = [0, 1, 2, 3, 4, 5, 6]
    else:
        # BUGFIX: fail fast instead of hitting a NameError on
        # `combination`/`node_num` below for unknown datasets.
        raise ValueError('unsupported dataset: {}'.format(args.dataset))
    # Every pair of classes serves once as the held-out test pair.
    combination = list(combinations(class_label, 2))
    result_file = args.dataset + '.txt'  # 'citeseer.txt' or 'cora.txt'

    # Two-layer linear meta-learner over the precomputed features.
    config = [('linear', [args.hidden, features.size(1)]),
              ('linear', [args.n_way, args.hidden])]

    device = torch.device('cuda')

    for i in range(len(combination)):
        print("Cross Validation: {}".format((i + 1)))

        maml = Meta(args, config).to(device)

        test_label = list(combination[i])
        train_label = [n for n in class_label if n not in test_label]
        print('Cross Validation {} Train_Label_List: {} '.format(
            i + 1, train_label))
        print('Cross Validation {} Test_Label_List: {} '.format(
            i + 1, test_label))

        for j in range(args.epoch):
            x_spt, y_spt, x_qry, y_qry = sgc_data_generator(
                features, labels, node_num, train_label, args.task_num,
                args.n_way, args.k_spt, args.k_qry)
            accs = maml.forward(x_spt, y_spt, x_qry, y_qry)
            print('Step:', j, '\tMeta_Training_Accuracy:', accs)
            if j % 100 == 0:
                # Checkpoint, then meta-test a fresh copy loaded from disk.
                torch.save(maml.state_dict(), 'maml.pkl')
                meta_test_acc = []
                for k in range(step):
                    model_meta_trained = Meta(args, config).to(device)
                    model_meta_trained.load_state_dict(torch.load('maml.pkl'))
                    model_meta_trained.eval()
                    x_spt, y_spt, x_qry, y_qry = sgc_data_generator(
                        features, labels, node_num, test_label, args.task_num,
                        args.n_way, args.k_spt, args.k_qry)
                    accs = model_meta_trained.forward(x_spt, y_spt, x_qry,
                                                      y_qry)
                    meta_test_acc.append(accs)
                # De-duplicated: the two per-dataset branches only differed
                # in the output filename.
                with open(result_file, 'a') as f:
                    f.write(
                        'Cross Validation:{}, Step: {}, Meta-Test_Accuracy: {}'
                        .format(
                            i + 1, j,
                            np.array(meta_test_acc).mean(axis=0).astype(
                                np.float16)))
                    f.write('\n')
Example 7
def _build_unet_config(args):
    """Build the layer-config list for the U-Net meta-learner.

    conv2d entries are [out_ch, in_ch, k_h, k_w, stride, padding] (no bias);
    conv2d_b is the same with bias; leakyrelu is [alpha, inplace];
    max_pool2d is [kernel, stride, padding]; upsample is [scale_factor].
    """
    config = []
    # Encoder: NUM_DOWN_CONV stages, channels doubling each stage.
    for block in range(args.NUM_DOWN_CONV):
        out_channels = (2 ** block) * args.HIDDEN_DIM
        in_channels = args.imgc if block == 0 else out_channels // 2
        config += [('conv2d', [out_channels, in_channels, 3, 3, 1, 1]),
                   ('leakyrelu', [0.2, False]),
                   ('bn', [out_channels])]
        for _ in range(2):
            config += [('conv2d', [out_channels, out_channels, 3, 3, 1, 1]),
                       ('leakyrelu', [0.2, False]),
                       ('bn', [out_channels])]
        config += [('max_pool2d', [2, 2, 0])]

    # Decoder: upsample then convolve the (wider) concatenated features.
    for block in range(args.NUM_DOWN_CONV - 1):
        out_channels = (2 ** (args.NUM_DOWN_CONV - block - 2)) * args.HIDDEN_DIM
        # NOTE(review): x3 input width presumably accounts for the skip
        # connection concatenation — confirm against Meta's forward pass.
        in_channels = out_channels * 3
        config += [('upsample', [2])]
        config += [('conv2d', [out_channels, in_channels, 3, 3, 1, 1]),
                   ('leakyrelu', [0.2, False]),
                   ('bn', [out_channels])]
        for _ in range(2):
            config += [('conv2d', [out_channels, out_channels, 3, 3, 1, 1]),
                       ('leakyrelu', [0.2, False]),
                       ('bn', [out_channels])]
    # Final projection; the only conv with a bias term.
    config += [('conv2d_b', [args.outc, args.HIDDEN_DIM, 3, 3, 1, 1])]
    return config


def main(args):
    """Meta-train a MAML U-Net on RCWA simulation data.

    Builds the network config, optionally resumes from a checkpoint, then
    alternates meta-training steps with periodic meta-validation and
    per-epoch checkpointing.

    Raises:
        NotImplementedError: if ``args.arch`` is not 'Unet'.
    """
    # Fix RNG seeds for reproducibility.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    if args.arch == "Unet":
        config = _build_unet_config(args)
    else:
        # BUGFIX: the original `raise ("...")` raised a bare string, which
        # is itself a TypeError in Python 3; raise a real exception.
        raise NotImplementedError(
            "architectures other than Unet haven't been added!!")

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    for name, param in maml.named_parameters():
        print(name, param.size())
    print('Total trainable tensors:', num)

    SUMMARY_INTERVAL = 5
    TEST_PRINT_INTERVAL = SUMMARY_INTERVAL * 5
    ITER_SAVE_INTERVAL = 300
    EPOCH_SAVE_INTERVAL = 5

    model_path = "/scratch/users/chenkaim/pytorch-models/pytorch_" + args.model_name + "_k_shot_" + str(
        args.k_spt) + "_task_num_" + str(args.task_num) + "_meta_lr_" + str(
            args.meta_lr) + "_inner_lr_" + str(
                args.update_lr) + "_num_inner_updates_" + str(args.update_step)
    if not os.path.isdir(model_path):
        os.mkdir(model_path)

    start_epoch = 0
    if args.continue_train:
        # Resume model weights, LR scheduler, and optimizer state.
        ckpt_file = model_path + "/epoch_" + str(args.continue_epoch) + ".pt"
        print("Restoring weights from ", ckpt_file)
        checkpoint = torch.load(ckpt_file)
        print(checkpoint.keys())
        print(checkpoint.items())
        maml.load_state_dict(checkpoint['state_dict'])
        maml.lr_scheduler.load_state_dict(checkpoint['scheduler'])
        maml.meta_optim.load_state_dict(checkpoint['optimizer'])
        start_epoch = args.continue_epoch

    db = RCWA_data_loader(batchsz=args.task_num,
                          n_way=args.n_way,
                          k_shot=args.k_spt,
                          k_query=args.k_qry,
                          imgsz=args.imgsz,
                          data_folder=args.data_folder)

    for step in range(start_epoch, args.epoch):
        print("epoch: ", step)
        if step % EPOCH_SAVE_INTERVAL == 0:
            torch.save(maml.state_dict(),
                       model_path + "/epoch_" + str(step) + ".pt")
        # ~70% of the samples form this epoch's meta-training iterations.
        iters = int(0.7 * db.total_data_samples /
                    ((args.k_spt + args.k_qry) * args.task_num))
        for itr in range(iters):
            x_spt, y_spt, x_qry, y_qry = db.next()
            x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                         torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

            # set training=True to update running_mean, running_variance,
            # bn_weights, bn_bias
            accs, loss_q, ave_trans, min_trans = maml.transference(
                x_spt, y_spt, x_qry, y_qry)

            if itr % SUMMARY_INTERVAL == 0:
                # (typo fixed: "post-iner-loop" -> "post-inner-loop")
                print_str = "Iteration %d: pre-inner-loop train accuracy: %.5f, post-inner-loop test accuracy: %.5f, train_loss: %.5f, ave_trans: %.2f, min_trans: %.2f" % (
                    itr, accs[0], accs[-1], loss_q, ave_trans, min_trans)
                print(print_str)

            if itr % TEST_PRINT_INTERVAL == 0:
                # Meta-validation: fine-tune held-out tasks one at a time.
                accs = []
                for _ in range(10):
                    x_spt, y_spt, x_qry, y_qry = db.next('test')
                    x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                                 torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

                    # split to single task each time
                    for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                            x_spt, y_spt, x_qry, y_qry):
                        test_acc = maml.finetunning(x_spt_one, y_spt_one,
                                                    x_qry_one, y_qry_one)
                        accs.append(test_acc)

                # [b, update_step+1]
                accs = np.array(accs).mean(axis=0).astype(np.float16)
                print(
                    'Meta-validation pre-inner-loop train accuracy: %.5f, meta-validation post-inner-loop test accuracy: %.5f'
                    % (accs[0], accs[-1]))

        maml.lr_scheduler.step()