def main():
    # Meta-train a MAML model on MiniImagenet; every 2000 epochs (when
    # epoch % 2000 == 999) evaluate on 600 test episodes and checkpoint
    # when the last metric improves.
    # Relies on module-level names: args, Param, Meta, MiniImagenet, inf_get, plot.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    test_result = {}  # epoch -> per-step test metrics (last entry used as accuracy)
    best_acc = 0.0

    maml = Meta(args, Param.config).to(Param.device)
    maml = torch.nn.DataParallel(maml)
    opt = optim.Adam(maml.parameters(), lr=args.meta_lr)
    #opt = optim.SGD(maml.parameters(), lr=args.meta_lr, momentum=0.9, weight_decay=args.weight_decay)  

    # Count trainable parameters for the printed model summary.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    trainset = MiniImagenet(Param.root, mode='train', n_way=args.n_way, k_shot=args.k_spt, k_query=args.k_qry, resize=args.imgsz)
    testset = MiniImagenet(Param.root, mode='test', n_way=args.n_way, k_shot=args.k_spt, k_query=args.k_qry, resize=args.imgsz)
    trainloader = DataLoader(trainset, batch_size=args.task_num, shuffle=True, num_workers=4, drop_last=True)
    testloader = DataLoader(testset, batch_size=4, shuffle=True, num_workers=4, drop_last=True)
    # inf_get presumably wraps each loader as an endless iterator -- TODO confirm.
    train_data = inf_get(trainloader)
    test_data = inf_get(testloader)

    for epoch in range(args.epoch):
        # One meta-update: support set drives inner-loop adaptation, the
        # meta (query) set yields the outer-loop loss.
        support_x, support_y, meta_x, meta_y = train_data.__next__()
        support_x, support_y, meta_x, meta_y = support_x.to(Param.device), support_y.to(Param.device), meta_x.to(Param.device), meta_y.to(Param.device)
        meta_loss = maml(support_x, support_y, meta_x, meta_y).mean()
        opt.zero_grad()
        meta_loss.backward()
        # Element-wise gradient clipping to stabilize meta-updates.
        torch.nn.utils.clip_grad_value_(maml.parameters(), clip_value = 10.0)
        opt.step()
        plot.plot('meta_loss', meta_loss.item())

        if(epoch % 2000 == 999):
            # Periodic evaluation on a deep copy so test-time adaptation
            # cannot disturb the training weights.
            ans = None
            maml_clone = deepcopy(maml)
            for _ in range(600):
                support_x, support_y, qx, qy = test_data.__next__()
                support_x, support_y, qx, qy = support_x.to(Param.device), support_y.to(Param.device), qx.to(Param.device), qy.to(Param.device)
                temp = maml_clone(support_x, support_y, qx, qy, meta_train = False)
                if(ans is None):
                    ans = temp
                else:
                    ans = torch.cat([ans, temp], dim = 0)
            # Average over all 600 episodes.
            ans = ans.mean(dim = 0).tolist()
            test_result[epoch] = ans
            if (ans[-1] > best_acc):
                best_acc = ans[-1]
                torch.save(maml.state_dict(), Param.out_path + 'net_'+ str(epoch) + '_' + str(best_acc) + '.pkl') 
            del maml_clone
            print(str(epoch) + ': '+str(ans))
            # Persist the full history of test results after every evaluation.
            with open(Param.out_path+'test.json','w') as f:
                json.dump(test_result,f)
        if (epoch < 5) or (epoch % 100 == 99):
            plot.flush()
        plot.tick()
Ejemplo n.º 2
0
def main():
    """Meta-train on generated tasks; every 50 epochs fine-tune on one
    randomly chosen held-out task and report its loss.

    Relies on module-level names: args, Meta, Gen.
    """
    print(args)
    device = torch.device('cuda')
    maml = Meta(args).to(device)

    # Report the number of trainable parameters.
    trainable = [p for p in maml.parameters() if p.requires_grad]
    num = sum(np.prod(p.shape) for p in trainable)
    print(maml)
    print('Total trainable tensors:', num)

    trainset = Gen(args.task_num, args.k_spt, args.k_qry)
    testset = Gen(args.task_num, args.k_spt, args.k_qry * 10)

    for epoch in range(args.epoch):
        # Shuffle the task order for this epoch.
        order = list(range(trainset.xs.shape[0]))
        np.random.shuffle(order)
        xs = torch.Tensor(trainset.xs[order]).to(device)
        ys = torch.Tensor(trainset.ys[order]).to(device)
        xq = torch.Tensor(trainset.xq[order]).to(device)
        yq = torch.Tensor(trainset.yq[order]).to(device)

        maml.train()
        loss = maml(xs, ys, xq, yq, epoch)
        print('Epoch: {} Initial loss: {} Train loss: {}'.format(
            epoch, loss[0] / args.task_num, loss[-1] / args.task_num))

        if (epoch + 1) % 50 == 0:
            print("Evaling the model...")
            torch.save(maml.state_dict(), 'save.pt')
            maml.eval()
            # Pick a single random test task and fine-tune on it.
            i = random.randint(0, testset.xs.shape[0] - 1)
            xs = torch.Tensor(testset.xs[i]).to(device)
            ys = torch.Tensor(testset.ys[i]).to(device)
            xq = torch.Tensor(testset.xq[i]).to(device)
            yq = torch.Tensor(testset.yq[i]).to(device)
            losses, losses_q, logits_q, _ = maml.finetunning(xs, ys, xq, yq)
            print('Epoch: {} Initial loss: {} Test loss: {}'.format(
                epoch, losses_q[0], losses_q[-1]))
Ejemplo n.º 3
0
                                                 query_y)
            optimizer.zero_grad()
            euc_loss.backward()
            optimizer.step()

            if step % 100 == 0:
                val_acc = eval(db_val, meta)

                tb.add_scalar('accuracy', val_acc)
                print('accuracy:', val_acc, 'best accuracy:', best_val_acc)
                # update learning rate per epoch
                # scheduler.step(total_val_loss)

                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    torch.save(meta.state_dict(), mdl_file)
                    print('saved to checkpoint:', mdl_file)

                    if val_acc > 0.4:
                        print('now conduct test performance...')
                        mini_test = MiniImagenet('../mini-imagenet/',
                                                 mode='test',
                                                 n_way=n_way,
                                                 k_shot=k_shot,
                                                 k_query=1,
                                                 batchsz=200,
                                                 resize=resize)
                        db_test = DataLoader(mini_test, batchsz, shuffle=True)
                        test_acc, _ = eval(db_test, meta)
                        print('>>>>>>>>>>>> test accuracy:', test_acc,
                              '<<<<<<<<<<<<<<')
Ejemplo n.º 4
0
def main():
    # Meta-train MAML on MiniImagenet with the standard 4-conv backbone.
    # Every 500 steps evaluate on the test split and write a checkpoint.
    # Relies on module-level names: args, Meta, MiniImagenet.

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # Layer spec consumed by Meta: 4 x (conv -> relu -> bn -> max_pool),
    # then flatten and a linear n_way classifier.
    config = [
        ('conv2d', [32, 3, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 1, 0]),
        ('flatten', []),
        ('linear', [args.n_way, 32 * 5 * 5])
    ]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    # Count trainable parameters for the printed summary.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    mini = MiniImagenet('F:\\ACV_project\\MAML-Pytorch\\miniimagenet\\', mode='train', n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batchsz=10000, resize=args.imgsz)
    mini_test = MiniImagenet('F:\\ACV_project\\MAML-Pytorch\\miniimagenet\\', mode='test', n_way=args.n_way, k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batchsz=100, resize=args.imgsz)


    ckpt_dir = "./model/"

    # Each outer epoch iterates all 10000 episodes once.
    for epoch in range(args.epoch//10000):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)


        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):

            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

            accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                print('step:', step, '\ttraining acc:', accs)

            if step % 500 == 0:  # evaluation
                db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
                accs_all_test = []

                for x_spt, y_spt, x_qry, y_qry in db_test:
                    # squeeze(0) drops the DataLoader batch dimension (batch_size=1).
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

                    accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                    accs_all_test.append(accs)

                # [b, update_step+1]
                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                print('Test acc:', accs)

                # save checkpoints
                os.makedirs(ckpt_dir, exist_ok=True)
                print('Saving the model as a checkpoint...')
                torch.save({'epoch': epoch, 'Steps': step, 'model': maml.state_dict()}, os.path.join(ckpt_dir, 'checkpoint.pth'))
Ejemplo n.º 5
0
def main():
    # Meta-train MAML on MiniImagenet; every 500 steps run the full test
    # split and save a checkpoint tagged with the step number.
    # Relies on module-level names: args, Meta, MiniImagenet.

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # Layer spec consumed by Meta: 4 x (conv -> relu -> bn -> max_pool),
    # then flatten and a linear n_way classifier.
    config = [('conv2d', [32, 3, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 1, 0]),
              ('flatten', []), ('linear', [args.n_way, 32 * 5 * 5])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    # Count trainable parameters for the printed summary.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    mini = MiniImagenet('/home/tesca/data/miniimagenet/',
                        mode='train',
                        n_way=args.n_way,
                        k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batchsz=10000,
                        resize=args.imgsz)
    mini_test = MiniImagenet('/home/tesca/data/miniimagenet/',
                             mode='test',
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batchsz=100,
                             resize=args.imgsz)
    # Checkpoint path prefix encodes the shot count and inner-loop step size.
    save_path = os.getcwd() + '/data/model_batchsz' + str(
        args.k_spt) + '_stepsz' + str(args.update_lr) + '_epoch'

    # Each outer epoch iterates all 10000 episodes once.
    for epoch in range(args.epoch // 10000):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini,
                        args.task_num,
                        shuffle=True,
                        num_workers=1,
                        pin_memory=True)

        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):

            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(
                device), x_qry.to(device), y_qry.to(device)

            accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                print('step:', step, '\ttraining acc:', accs)

            if step % 500 == 0:  # evaluation
                db_test = DataLoader(mini_test,
                                     1,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
                accs_all_test = []

                for x_spt, y_spt, x_qry, y_qry in db_test:
                    # squeeze(0) drops the DataLoader batch dimension (batch_size=1).
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

                    accs, _ = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                    accs_all_test.append(accs)

                # [b, update_step+1]
                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                print('Test acc:', accs)
                torch.save(maml.state_dict(), save_path + str(step) + "_og.pt")
Ejemplo n.º 6
0
def main():
    """Meta-train MAML on MiniImagenet with periodic validation and testing.

    Every 1000 steps the model is evaluated on 600 validation episodes; when
    validation accuracy improves (or every 5000 steps, as a periodic snapshot)
    a checkpoint is written, results are appended to ckpt/<exp>/val.txt, and
    the test split is evaluated and logged to ckpt/<exp>/test.txt.

    Relies on module-level names: args, Meta, MiniImagenet, cal_conf.
    """

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # Layer spec consumed by Meta: 4 x (conv -> relu -> bn -> max_pool),
    # then flatten and a linear n_way classifier.
    config = [('conv2d', [32, 3, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 1, 0]),
              ('flatten', []), ('linear', [args.n_way, 32 * 5 * 5])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    # Count trainable parameters for the printed summary.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    mini = MiniImagenet(
        '/mnt/aitrics_ext/ext01/yanbin/MAML-Pytorch/data/miniImagenet/',
        mode='train',
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        batchsz=10000,
        resize=args.imgsz)
    mini_val = MiniImagenet(
        '/mnt/aitrics_ext/ext01/yanbin/MAML-Pytorch/data/miniImagenet/',
        mode='val',
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        batchsz=600,
        resize=args.imgsz)
    mini_test = MiniImagenet(
        '/mnt/aitrics_ext/ext01/yanbin/MAML-Pytorch/data/miniImagenet/',
        mode='test',
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        batchsz=600,
        resize=args.imgsz)

    best_acc = 0.0
    # Bug fix: os.mkdir failed when the parent 'ckpt' directory did not yet
    # exist; makedirs creates the whole path and tolerates re-runs.
    os.makedirs('ckpt/{}'.format(args.exp), exist_ok=True)
    for epoch in range(args.epoch // 10000):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini,
                        args.task_num,
                        shuffle=True,
                        num_workers=1,
                        pin_memory=True)

        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):

            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(
                device), x_qry.to(device), y_qry.to(device)

            accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 500 == 0:
                print('step:', step, '\ttraining acc:', accs)
            if step % 1000 == 0:  # evaluation
                db_val = DataLoader(mini_val,
                                    1,
                                    shuffle=True,
                                    num_workers=1,
                                    pin_memory=True)
                accs_all_val = []
                for x_spt, y_spt, x_qry, y_qry in db_val:
                    # squeeze(0) drops the DataLoader batch dimension (batch_size=1).
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

                    accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                    accs_all_val.append(accs)
                mean, std, ci95 = cal_conf(np.array(accs_all_val))
                print('Val acc:{}, std:{}. ci95:{}'.format(
                    mean[-1], std[-1], ci95[-1]))
                if mean[-1] > best_acc or step % 5000 == 0:
                    # Bug fix: the periodic (step % 5000) save used to overwrite
                    # best_acc with a possibly worse value; keep the running max
                    # and tag the checkpoint/logs with this run's accuracy.
                    best_acc = max(best_acc, mean[-1])
                    torch.save(
                        maml.state_dict(),
                        'ckpt/{}/model_e{}s{}_{:.4f}.pkl'.format(
                            args.exp, epoch, step, mean[-1]))
                    with open('ckpt/' + args.exp + '/val.txt', 'a') as f:
                        print(
                            'val epoch {}, step {}: acc_val:{:.4f}, ci95:{:.4f}'
                            .format(epoch, step, mean[-1], ci95[-1]),
                            file=f)

                    ## Test
                    db_test = DataLoader(mini_test,
                                         1,
                                         shuffle=True,
                                         num_workers=1,
                                         pin_memory=True)
                    accs_all_test = []
                    for x_spt, y_spt, x_qry, y_qry in db_test:
                        x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                     x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

                        accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                        accs_all_test.append(accs)
                    mean, std, ci95 = cal_conf(np.array(accs_all_test))
                    print('Test acc:{}, std:{}, ci95:{}'.format(
                        mean[-1], std[-1], ci95[-1]))
                    with open('ckpt/' + args.exp + '/test.txt', 'a') as f:
                        print(
                            'test epoch {}, step {}: acc_test:{:.4f}, ci95:{:.4f}'
                            .format(epoch, step, mean[-1], ci95[-1]),
                            file=f)
Ejemplo n.º 7
0
def main(args):
    """Meta-train/test on SGC-precomputed citation-graph features with
    cross-validation over held-out class pairs.

    For every 2-class combination, that pair is reserved for meta-test and
    the remaining classes are used for meta-training. Every 100 training
    steps the current weights are evaluated for `args.step` meta-test
    batches and the mean accuracy is appended to '<dataset>.txt'.

    Relies on module-level names: set_seed, load_citation, sgc_precompute,
    sgc_data_generator, Meta, combinations.
    """
    step = args.step
    set_seed(args.seed)

    adj, features, labels = load_citation(args.dataset, args.normalization)

    features = sgc_precompute(features, adj, args.degree)

    if args.dataset == 'citeseer':
        node_num = 3327
        class_label = [0, 1, 2, 3, 4, 5]
    elif args.dataset == 'cora':
        node_num = 2708
        class_label = [0, 1, 2, 3, 4, 5, 6]
    else:
        # Bug fix: an unsupported dataset previously fell through and crashed
        # later with NameError on 'combination'; fail fast with a clear error.
        raise ValueError('Unsupported dataset: {}'.format(args.dataset))
    combination = list(combinations(class_label, 2))

    # Two-layer linear head operating on the precomputed SGC features.
    config = [('linear', [args.hidden, features.size(1)]),
              ('linear', [args.n_way, args.hidden])]

    device = torch.device('cuda')

    for i in range(len(combination)):
        print("Cross Validation: {}".format((i + 1)))

        # Fresh model per fold.
        maml = Meta(args, config).to(device)

        test_label = list(combination[i])
        train_label = [n for n in class_label if n not in test_label]
        print('Cross Validation {} Train_Label_List: {} '.format(
            i + 1, train_label))
        print('Cross Validation {} Test_Label_List: {} '.format(
            i + 1, test_label))

        for j in range(args.epoch):
            x_spt, y_spt, x_qry, y_qry = sgc_data_generator(
                features, labels, node_num, train_label, args.task_num,
                args.n_way, args.k_spt, args.k_qry)
            accs = maml.forward(x_spt, y_spt, x_qry, y_qry)
            print('Step:', j, '\tMeta_Training_Accuracy:', accs)
            if j % 100 == 0:
                # Snapshot the weights, reload them into a fresh eval-mode
                # model, and run `step` meta-test batches on the held-out pair.
                torch.save(maml.state_dict(), 'maml.pkl')
                meta_test_acc = []
                for k in range(step):
                    model_meta_trained = Meta(args, config).to(device)
                    model_meta_trained.load_state_dict(torch.load('maml.pkl'))
                    model_meta_trained.eval()
                    x_spt, y_spt, x_qry, y_qry = sgc_data_generator(
                        features, labels, node_num, test_label, args.task_num,
                        args.n_way, args.k_spt, args.k_qry)
                    accs = model_meta_trained.forward(x_spt, y_spt, x_qry,
                                                      y_qry)
                    meta_test_acc.append(accs)
                # The citeseer/cora branches wrote identical records to
                # '<dataset>.txt'; unified here (dataset validated above).
                with open('{}.txt'.format(args.dataset), 'a') as f:
                    f.write(
                        'Cross Validation:{}, Step: {}, Meta-Test_Accuracy: {}'
                        .format(
                            i + 1, j,
                            np.array(meta_test_acc).mean(axis=0).astype(
                                np.float16)))
                    f.write('\n')
Ejemplo n.º 8
0
def main():
    # Meta-train a MAML model on the pkl-based mini-ImageNet loaders;
    # every 2000 epochs (when epoch % 2000 == 999) evaluate on 600 test
    # episodes and checkpoint when the last metric improves.
    # Relies on module-level names: args, Param, Meta, dataset_mini, get_data, plot.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    test_result = {}  # epoch -> per-step test metrics (last entry used as accuracy)
    best_acc = 0.0

    maml = Meta(args, Param.config).to(Param.device)
    maml = torch.nn.DataParallel(maml)
    opt = optim.Adam(maml.parameters(), lr=args.meta_lr)
    #opt = optim.SGD(maml.parameters(), lr=args.meta_lr, momentum=0.9, weight_decay=args.weight_decay)

    # Count trainable parameters for the printed model summary.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # Load pkl dataset
    args_data = {}
    args_data['x_dim'] = "84,84,3"  # image dimensions expected by the loader
    args_data['ratio'] = 1.0
    args_data['seed'] = 222
    loader_train = dataset_mini(600, 100, 'train', args_data)
    #loader_val   = dataset_mini(600, 100, 'val', args_data)
    loader_test = dataset_mini(600, 100, 'test', args_data)

    loader_train.load_data_pkl()
    #loader_val.load_data_pkl()
    loader_test.load_data_pkl()

    for epoch in range(args.epoch):
        # One meta-update: support set drives inner-loop adaptation, the
        # meta (query) set yields the outer-loop loss.
        support_x, support_y, meta_x, meta_y = get_data(loader_train)
        support_x, support_y, meta_x, meta_y = support_x.to(
            Param.device), support_y.to(Param.device), meta_x.to(
                Param.device), meta_y.to(Param.device)
        meta_loss = maml(support_x, support_y, meta_x, meta_y).mean()
        opt.zero_grad()
        meta_loss.backward()
        # Element-wise gradient clipping to stabilize meta-updates.
        torch.nn.utils.clip_grad_value_(maml.parameters(), clip_value=10.0)
        opt.step()
        plot.plot('meta_loss', meta_loss.item())

        if (epoch % 2000 == 999):
            # Periodic evaluation on a deep copy so test-time adaptation
            # cannot disturb the training weights.
            ans = None
            maml_clone = deepcopy(maml)
            for _ in range(600):
                support_x, support_y, qx, qy = get_data(loader_test)
                support_x, support_y, qx, qy = support_x.to(
                    Param.device), support_y.to(Param.device), qx.to(
                        Param.device), qy.to(Param.device)
                temp = maml_clone(support_x,
                                  support_y,
                                  qx,
                                  qy,
                                  meta_train=False)
                if (ans is None):
                    ans = temp
                else:
                    ans = torch.cat([ans, temp], dim=0)
            # Average over all 600 episodes.
            ans = ans.mean(dim=0).tolist()
            test_result[epoch] = ans
            if (ans[-1] > best_acc):
                best_acc = ans[-1]
                torch.save(
                    maml.state_dict(), Param.out_path + 'net_' + str(epoch) +
                    '_' + str(best_acc) + '.pkl')
            del maml_clone
            print(str(epoch) + ': ' + str(ans))
            # Persist the full history of test results after every evaluation.
            with open(Param.out_path + 'test.json', 'w') as f:
                json.dump(test_result, f)
        if (epoch < 5) or (epoch % 100 == 99):
            plot.flush()
        plot.tick()
Ejemplo n.º 9
0
def main():
    # Meta-train MAML on MiniImagenet using infinite loader iterators;
    # every 2500 epochs checkpoint and evaluate on 600 test episodes.
    # Relies on module-level names: args, Meta, MiniImagenet, inf_get,
    # worker_init_fn.

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    #np.random.seed(222)

    # Layer spec consumed by Meta: 4 x (conv -> bn -> relu -> max_pool)
    # with padding 1, then flatten and a linear n_way classifier.
    config = [('conv2d', [32, 3, 3, 3, 1, 1]), ('bn', [32]), ('relu', [True]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 1]),
              ('bn', [32]), ('relu', [True]), ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 1]), ('bn', [32]), ('relu', [True]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 1]),
              ('bn', [32]), ('relu', [True]), ('max_pool2d', [2, 2, 0]),
              ('flatten', []), ('linear', [args.n_way, 32 * 5 * 5])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    # Count trainable parameters for the printed summary.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    root = '/mnt/aitrics_ext/ext01/yanbin/MAML-Pytorch/data/miniImagenet'
    trainset = MiniImagenet(root,
                            mode='train',
                            n_way=args.n_way,
                            k_shot=args.k_spt,
                            k_query=args.k_qry,
                            resize=args.imgsz)
    testset = MiniImagenet(root,
                           mode='test',
                           n_way=args.n_way,
                           k_shot=args.k_spt,
                           k_query=args.k_qry,
                           resize=args.imgsz)
    trainloader = DataLoader(trainset,
                             batch_size=args.task_num,
                             shuffle=True,
                             num_workers=4,
                             worker_init_fn=worker_init_fn,
                             drop_last=True)
    testloader = DataLoader(testset,
                            batch_size=1,
                            shuffle=True,
                            num_workers=1,
                            worker_init_fn=worker_init_fn,
                            drop_last=True)
    # inf_get presumably wraps each loader as an endless iterator -- TODO confirm.
    train_data = inf_get(trainloader)
    test_data = inf_get(testloader)

    best_acc = 0.0
    # NOTE(review): os.mkdir fails if the parent 'ckpt' directory is missing.
    if not os.path.exists('ckpt/{}'.format(args.exp)):
        os.mkdir('ckpt/{}'.format(args.exp))
    for epoch in range(args.epoch):
        # Re-seed numpy from OS entropy each step (seeding above is commented out).
        np.random.seed()
        x_spt, y_spt, x_qry, y_qry = train_data.__next__()
        x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(
            device), x_qry.to(device), y_qry.to(device)

        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if epoch % 100 == 0:
            print('epoch:', epoch, '\ttraining acc:', accs)

        if epoch % 2500 == 0:  # evaluation
            # save checkpoint
            torch.save(maml.state_dict(),
                       'ckpt/{}/model_{}.pkl'.format(args.exp, epoch))
            accs_all_test = []
            for _ in range(600):
                x_spt, y_spt, x_qry, y_qry = test_data.__next__()
                # squeeze() drops the batch dimension (test batch_size=1).
                x_spt, y_spt, x_qry, y_qry = x_spt.squeeze().to(
                    device), y_spt.squeeze().to(device), x_qry.squeeze().to(
                        device), y_qry.squeeze().to(device)
                accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                accs_all_test.append(accs)

            # [b, update_step+1]
            accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
            with open('ckpt/' + args.exp + '/test.txt', 'a') as f:
                print('test epoch {}: acc:{:.4f}'.format(epoch, accs[-1]),
                      file=f)
Ejemplo n.º 10
0
def main():
    # Meta-train an architecture-search model (Meta wraps a model exposing
    # genotype()/alphas) on MiniImagenet, logging to file + console and
    # TensorBoard, and checkpointing the best weights per epoch.
    # Relies on module-level names: args, Saver, TensorboardSummary, Meta,
    # MiniImagenet, meta_train, meta_test.

    saver = Saver(args)
    # set log
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p',
                        filename=os.path.join(saver.experiment_dir, 'log.txt'),
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger().addHandler(console)

    # A GPU is required; abort otherwise.
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    cudnn.enabled = True

    # Snapshot the experiment's source and config for reproducibility.
    saver.create_exp_dir(scripts_to_save=glob.glob('*.py') +
                         glob.glob('*.sh') + glob.glob('*.yml'))
    saver.save_experiment_config()
    summary = TensorboardSummary(saver.experiment_dir)
    writer = summary.create_summary()
    best_pred = 0

    logging.info(args)

    device = torch.device('cuda')
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device)
    maml = Meta(args, criterion).to(device)

    # Count trainable parameters for the logged model summary.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    logging.info(maml)
    logging.info('Total trainable tensors: {}'.format(num))

    # batch_size here means total episode number
    # NOTE(review): validation episodes also come from mode='train' -- confirm intended.
    mini = MiniImagenet(args.data_path,
                        mode='train',
                        n_way=args.n_way,
                        k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batch_size=args.batch_size,
                        resize=args.img_size)
    mini_valid = MiniImagenet(args.data_path,
                              mode='train',
                              n_way=args.n_way,
                              k_shot=args.k_spt,
                              k_query=args.k_qry,
                              batch_size=args.test_batch_size,
                              resize=args.img_size)

    train_loader = DataLoader(mini,
                              args.meta_batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              pin_memory=True)
    valid_loader = DataLoader(mini_valid,
                              args.meta_test_batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              pin_memory=True)

    for epoch in range(args.epoch):
        # fetch batch_size num of episode each time
        logging.info('--------- Epoch: {} ----------'.format(epoch))

        train_accs_theta, train_accs_w = meta_train(train_loader, maml, device,
                                                    epoch, writer)
        logging.info(
            '[Epoch: {}]\t Train acc_theta: {}\t Train acc_w: {}'.format(
                epoch, train_accs_theta, train_accs_w))
        test_accs_theta, test_accs_w = meta_test(valid_loader, maml, device,
                                                 epoch, writer)
        logging.info(
            '[Epoch: {}]\t Test acc_theta: {}\t Test acc_w: {}'.format(
                epoch, test_accs_theta, test_accs_w))

        # Log the currently derived architecture and its mixing weights.
        genotype = maml.model.genotype()
        logging.info('genotype = %s', genotype)

        logging.info(F.softmax(maml.model.alphas_normal, dim=-1))
        logging.info(F.softmax(maml.model.alphas_reduce, dim=-1))

        # Save the best meta model.
        new_pred = test_accs_w[-1]
        if new_pred > best_pred:
            is_best = True
            best_pred = new_pred
        else:
            is_best = False
        saver.save_checkpoint(
            {
                'epoch':
                epoch,
                # Unwrap DataParallel so the saved keys are portable.
                'state_dict_w':
                maml.module.state_dict()
                if isinstance(maml, nn.DataParallel) else maml.state_dict(),
                'state_dict_theta':
                maml.model.arch_parameters(),
                'best_pred':
                best_pred,
            }, is_best)
Ejemplo n.º 11
0
def main(args):
    """Meta-train a U-Net-style MAML model on RCWA simulation data.

    Builds the per-layer config for a U-Net encoder/decoder, instantiates
    the Meta learner, optionally restores a checkpoint, then runs the
    meta-training loop with periodic checkpointing and meta-validation.

    Args:
        args: parsed command-line namespace; must provide at least
            arch, NUM_DOWN_CONV, HIDDEN_DIM, imgc, outc, model_name,
            k_spt, k_qry, task_num, meta_lr, update_lr, update_step,
            n_way, imgsz, data_folder, epoch, continue_train,
            continue_epoch.

    Raises:
        NotImplementedError: if ``args.arch`` is not ``"Unet"``.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    config = []

    if args.arch == "Unet":
        # ----- Encoder: NUM_DOWN_CONV down-sampling blocks -----
        for block in range(args.NUM_DOWN_CONV):
            out_channels = (2**block) * args.HIDDEN_DIM
            if (block == 0):
                # First block maps the raw input channels (args.imgc).
                config += [(
                    'conv2d', [out_channels, args.imgc, 3, 3, 1, 1]
                )  # out_c, in_c, k_h, k_w, stride, padding, also only conv, without bias
                           ]
            else:
                config += [
                    ('conv2d', [out_channels, out_channels // 2, 3, 3, 1,
                                1]),  # out_c, in_c, k_h, k_w, stride, padding
                ]
            config += [
                ('leakyrelu',
                 [0.2, False]),  # alpha; if true then executes relu in place
                ('bn', [out_channels])
            ]

            # Two additional conv/leakyrelu/bn stages per encoder block.
            config += [('conv2d', [out_channels, out_channels, 3, 3, 1, 1]),
                       ('leakyrelu', [0.2, False]), ('bn', [out_channels])]
            config += [('conv2d', [out_channels, out_channels, 3, 3, 1, 1]),
                       ('leakyrelu', [0.2, False]), ('bn', [out_channels])]
            config += [('max_pool2d', [2, 2,
                                       0])]  # kernel_size, stride, padding

        # ----- Decoder: up-sampling blocks -----
        for block in range(args.NUM_DOWN_CONV - 1):
            out_channels = (2**(args.NUM_DOWN_CONV - block -
                                2)) * args.HIDDEN_DIM
            # 3x channels in — presumably upsampled features concatenated
            # with a skip connection; confirm against the Meta/Learner impl.
            in_channels = out_channels * 3
            config += [('upsample', [2])]
            config += [('conv2d', [out_channels, in_channels, 3, 3, 1, 1]),
                       ('leakyrelu', [0.2, False]), ('bn', [out_channels])]
            config += [('conv2d', [out_channels, out_channels, 3, 3, 1, 1]),
                       ('leakyrelu', [0.2, False]), ('bn', [out_channels])]
            config += [('conv2d', [out_channels, out_channels, 3, 3, 1, 1]),
                       ('leakyrelu', [0.2, False]), ('bn', [out_channels])]
        config += [
            ('conv2d_b', [args.outc, args.HIDDEN_DIM, 3, 3, 1, 1])
        ]  # all the conv2d before are without bias, and this conv_b is with bias
    else:
        # BUG FIX: the original `raise ("...")` raised a TypeError because
        # a str is not an exception; raise a proper exception type.
        raise NotImplementedError(
            "architectures other than Unet hasn't been added!!")
    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    # Count and report trainable parameters.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    # print(maml)
    for name, param in maml.named_parameters():
        print(name, param.size())
    print('Total trainable tensors:', num)

    SUMMARY_INTERVAL = 5
    TEST_PRINT_INTERVAL = SUMMARY_INTERVAL * 5
    ITER_SAVE_INTERVAL = 300
    EPOCH_SAVE_INTERVAL = 5

    model_path = "/scratch/users/chenkaim/pytorch-models/pytorch_" + args.model_name + "_k_shot_" + str(
        args.k_spt) + "_task_num_" + str(args.task_num) + "_meta_lr_" + str(
            args.meta_lr) + "_inner_lr_" + str(
                args.update_lr) + "_num_inner_updates_" + str(args.update_step)
    # FIX: makedirs(exist_ok=True) also creates missing parent directories
    # and avoids the race between the old isdir() check and mkdir().
    os.makedirs(model_path, exist_ok=True)

    start_epoch = 0
    if (args.continue_train):
        # Restore model weights, LR scheduler, and optimizer state so
        # training resumes from args.continue_epoch.
        print("Restoring weights from ",
              model_path + "/epoch_" + str(args.continue_epoch) + ".pt")
        checkpoint = torch.load(model_path + "/epoch_" +
                                str(args.continue_epoch) + ".pt")
        print(checkpoint.keys())
        print(checkpoint.items())
        maml.load_state_dict(checkpoint['state_dict'])
        maml.lr_scheduler.load_state_dict(checkpoint['scheduler'])
        maml.meta_optim.load_state_dict(checkpoint['optimizer'])
        start_epoch = args.continue_epoch

    db = RCWA_data_loader(batchsz=args.task_num,
                          n_way=args.n_way,
                          k_shot=args.k_spt,
                          k_query=args.k_qry,
                          imgsz=args.imgsz,
                          data_folder=args.data_folder)

    for step in range(start_epoch, args.epoch):
        print("epoch: ", step)
        if step % EPOCH_SAVE_INTERVAL == 0:
            torch.save(maml.state_dict(),
                       model_path + "/epoch_" + str(step) + ".pt")
        # 70% of the dataset is used per epoch for meta-training.
        for itr in range(
                int(0.7 * db.total_data_samples /
                    ((args.k_spt + args.k_qry) * args.task_num))):
            x_spt, y_spt, x_qry, y_qry = db.next()
            x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                         torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

            # set traning=True to update running_mean, running_variance, bn_weights, bn_bias
            accs, loss_q, ave_trans, min_trans = maml.transference(
                x_spt, y_spt, x_qry, y_qry)

            if itr % SUMMARY_INTERVAL == 0:
                print_str = "Iteration %d: pre-inner-loop train accuracy: %.5f, post-iner-loop test accuracy: %.5f, train_loss: %.5f, ave_trans: %.2f, min_trans: %.2f" % (
                    itr, accs[0], accs[-1], loss_q, ave_trans, min_trans)
                print(print_str)

            if itr % TEST_PRINT_INTERVAL == 0:
                # Meta-validation: fine-tune on each held-out task and
                # average the accuracy trajectories.
                accs = []
                for _ in range(10):
                    # test
                    x_spt, y_spt, x_qry, y_qry = db.next('test')
                    x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                                 torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

                    # split to single task each time
                    for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                            x_spt, y_spt, x_qry, y_qry):
                        test_acc = maml.finetunning(x_spt_one, y_spt_one,
                                                    x_qry_one, y_qry_one)
                        accs.append(test_acc)

                # [b, update_step+1]
                accs = np.array(accs).mean(axis=0).astype(np.float16)
                print(
                    'Meta-validation pre-inner-loop train accuracy: %.5f, meta-validation post-inner-loop test accuracy: %.5f'
                    % (accs[0], accs[-1]))

        maml.lr_scheduler.step()
Ejemplo n.º 12
0
def main():
    """Meta-train MAML on MiniImagenet, logging train/test loss and
    accuracy curves, saving the best model, and writing matplotlib plots.
    """
    torch.manual_seed(222)  # seed the CPU RNG so results are deterministic
    torch.cuda.manual_seed_all(222)  # seed all GPU RNGs so results are deterministic
    np.random.seed(222)

    print(args)

    # Backbone layer config: each tuple is (op_name, op_params).
    # conv2d params are [out_c, in_c, k_h, k_w, stride, padding];
    # presumably this follows the Meta/Learner config convention used
    # elsewhere in the project — confirm against the Learner implementation.
    config = [
        ('conv2d', [32, 1, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 1, 0]),
        ('flatten', []),
        ('linear', [args.n_way, 7040])
    ]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    # Count and report trainable parameters.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    mini = MiniImagenet("./miniimagenet", mode='train', n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batchsz=10000)
    mini_test = MiniImagenet("./miniimagenet", mode='test', n_way=args.n_way, k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batchsz=100)
    last_accuracy = 0
    # Curves collected for the train/test plots written at the end.
    plt_train_loss = []
    plt_train_acc = []

    plt_test_loss = []
    plt_test_acc =[]
    for epoch in range(args.epoch // 10000):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)

        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):

            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

            # One outer-loop meta-update over the task batch.
            accs, loss_q = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                # Detach the meta-loss to a numpy scalar for plotting.
                d = loss_q.cpu()
                dd = d.detach().numpy()
                plt_train_loss.append(dd)
                plt_train_acc.append(accs[-1])
                print('step:', step, '\ttraining acc:', accs)

            if step % 50 == 0:  # evaluation
                db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
                accs_all_test = []
                loss_all_test = []

                # NOTE: the loop variables below deliberately shadow the
                # training batch; they are reassigned on the next train step.
                for x_spt, y_spt, x_qry, y_qry in db_test:
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

                    accs, loss_test= maml.finetunning(x_spt, y_spt, x_qry, y_qry)

                    loss_all_test.append(loss_test)
                    accs_all_test.append(accs)

                # [b, update_step+1]
                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                plt_test_acc.append(accs[-1])
                avg_loss = np.mean(np.array(loss_all_test))
                plt_test_loss.append(avg_loss)

                print('Test acc:', accs)
                # Best-model criterion: mean over the whole accuracy
                # trajectory (all inner-loop steps), not just the final step.
                test_accuracy = np.mean(np.array(accs))
                if test_accuracy > last_accuracy:
                    # save networks
                    torch.save(maml.state_dict(), str(
                        "./models/miniimagenet_maml" + str(args.n_way) + "way_" + str(
                            args.k_spt) + "shot.pkl"))
                    last_accuracy = test_accuracy
    # Write the test-side curves.
    plt.figure()
    plt.title("testing info")
    plt.xlabel("episode")
    plt.ylabel("Acc/loss")
    plt.plot(plt_test_loss, label='Loss')
    plt.plot(plt_test_acc, label='Acc')
    plt.legend(loc='upper right')
    plt.savefig('./drawing/test.png')
    plt.show()

    # Write the train-side curves.
    plt.figure()
    plt.title("training info")
    plt.xlabel("episode")
    plt.ylabel("Acc/loss")
    plt.plot(plt_train_loss, label='Loss')
    plt.plot(plt_train_acc, label='Acc')
    plt.legend(loc='upper right')
    plt.savefig('./drawing/train.png')
    plt.show()
def main():
    """Meta-train on MiniImagenet with periodic meta-testing.

    Selects one of three data pipelines via ``args.loader`` (0/1:
    DataLoader-based MiniImagenet variants, 2: pkl-based dataset_mini),
    runs the outer-loop optimization, and every 2500 epochs evaluates a
    deep copy of the model on 600 test batches, saving the weights
    whenever the final-step accuracy improves.

    Raises:
        ValueError: if ``args.loader`` is not 0, 1, or 2.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    #np.random.seed(222)
    test_result = {}
    best_acc = 0.0

    maml = Meta(args, Param.config).to(Param.device)
    if len(args.gpu.split(',')) > 1:
        maml = torch.nn.DataParallel(maml)
    opt = optim.Adam(maml.parameters(), lr=args.meta_lr)
    #opt = optim.SGD(maml.parameters(), lr=args.meta_lr, momentum=0.9, weight_decay=args.weight_decay)

    # FIX: reject unknown loader ids up front. Previously an invalid
    # value silently skipped data-pipeline setup and the training loop
    # crashed later with a NameError on the undefined batch variables.
    if args.loader not in (0, 1, 2):
        raise ValueError('unsupported args.loader: {}'.format(args.loader))

    # Count and report trainable parameters.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    if args.loader in [0, 1]:  # default loader
        if args.loader == 1:
            #from dataloader.mini_imagenet import MiniImageNet as MiniImagenet
            from MiniImagenet2 import MiniImagenet
        else:
            from MiniImagenet import MiniImagenet

        trainset = MiniImagenet(Param.root,
                                mode='train',
                                n_way=args.n_way,
                                k_shot=args.k_spt,
                                k_query=args.k_qry,
                                resize=args.imgsz)
        testset = MiniImagenet(Param.root,
                               mode='test',
                               n_way=args.n_way,
                               k_shot=args.k_spt,
                               k_query=args.k_qry,
                               resize=args.imgsz)
        trainloader = DataLoader(trainset,
                                 batch_size=args.task_num,
                                 shuffle=True,
                                 num_workers=4,
                                 worker_init_fn=worker_init_fn,
                                 drop_last=True)
        testloader = DataLoader(testset,
                                batch_size=1,
                                shuffle=True,
                                num_workers=1,
                                worker_init_fn=worker_init_fn,
                                drop_last=True)
        # Wrap both loaders as infinite iterators.
        train_data = inf_get(trainloader)
        test_data = inf_get(testloader)

    elif args.loader == 2:  # pkl loader
        args_data = {}
        args_data['x_dim'] = "84,84,3"
        args_data['ratio'] = 1.0
        args_data['seed'] = 222
        loader_train = dataset_mini(600, 100, 'train', args_data)
        #loader_val   = dataset_mini(600, 100, 'val', args_data)
        loader_test = dataset_mini(600, 100, 'test', args_data)
        loader_train.load_data_pkl()
        #loader_val.load_data_pkl()
        loader_test.load_data_pkl()

    for epoch in range(args.epoch):
        # Re-seed numpy from OS entropy each epoch (sampling inside the
        # loaders is intentionally non-deterministic here).
        np.random.seed()
        if args.loader in [0, 1]:
            support_x, support_y, meta_x, meta_y = train_data.__next__()
            support_x, support_y, meta_x, meta_y = support_x.to(
                Param.device), support_y.to(Param.device), meta_x.to(
                    Param.device), meta_y.to(Param.device)
        elif args.loader == 2:
            support_x, support_y, meta_x, meta_y = get_data(loader_train)
            support_x, support_y, meta_x, meta_y = support_x.to(
                Param.device), support_y.to(Param.device), meta_x.to(
                    Param.device), meta_y.to(Param.device)

        # Outer-loop step: mean() reduces the per-GPU losses that
        # DataParallel returns when several GPUs are used.
        meta_loss = maml(support_x, support_y, meta_x, meta_y).mean()
        opt.zero_grad()
        meta_loss.backward()
        torch.nn.utils.clip_grad_value_(maml.parameters(), clip_value=10.0)
        opt.step()
        plot.plot('meta_loss', meta_loss.item())

        if (epoch % 2500 == 0):
            # Meta-test on a frozen copy so fine-tuning cannot leak into
            # the meta-trained weights.
            ans = None
            maml_clone = deepcopy(maml)
            for _ in range(600):
                if args.loader in [0, 1]:
                    support_x, support_y, qx, qy = test_data.__next__()
                    support_x, support_y, qx, qy = support_x.to(
                        Param.device), support_y.to(Param.device), qx.to(
                            Param.device), qy.to(Param.device)
                elif args.loader == 2:
                    support_x, support_y, qx, qy = get_data(loader_test)
                    support_x, support_y, qx, qy = support_x.to(
                        Param.device), support_y.to(Param.device), qx.to(
                            Param.device), qy.to(Param.device)

                temp = maml_clone(support_x,
                                  support_y,
                                  qx,
                                  qy,
                                  meta_train=False)
                if (ans is None):
                    ans = temp
                else:
                    ans = torch.cat([ans, temp], dim=0)
            # Mean accuracy trajectory over the 600 test batches.
            ans = ans.mean(dim=0).tolist()
            test_result[epoch] = ans
            if (ans[-1] > best_acc):
                best_acc = ans[-1]
                torch.save(
                    maml.state_dict(), Param.out_path + 'net_' + str(epoch) +
                    '_' + str(best_acc) + '.pkl')
            del maml_clone
            print(str(epoch) + ': ' + str(ans))
            with open(Param.out_path + 'test.json', 'w') as f:
                json.dump(test_result, f)
        if (epoch < 5) or (epoch % 100 == 0):
            plot.flush()
        plot.tick()