コード例 #1
0
ファイル: SVHNtrain.py プロジェクト: BarcaD12/svhn-maml
def main(args):
    """Meta-train MAML on SVHN N-way K-shot tasks, evaluating every 500 steps.

    Args:
        args: parsed options. This function reads ``n_way``, ``task_num``,
            ``k_spt``, ``k_qry``, ``imgsz`` and ``epoch``, and hands the whole
            object to ``Meta``.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seed_fn(222)

    print(args)

    # Four conv -> relu -> batchnorm stages (the first takes 1 input channel),
    # followed by flatten and a linear classifier over n_way classes.
    config = [('conv2d', [64, 1, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64])]
    for kernel, stride in ((3, 2), (3, 2), (2, 1)):
        config += [('conv2d', [64, 64, kernel, kernel, stride, 0]),
                   ('relu', [True]),
                   ('bn', [64])]
    config += [('flatten', []), ('linear', [args.n_way, 256])]

    # Runs on CPU; the CUDA device was deliberately disabled here.
    device = torch.device('cpu')
    maml = Meta(args, config).to(device)

    num = sum(np.prod(p.shape) for p in maml.parameters() if p.requires_grad)
    print(maml)
    print('Total trainable tensors:', num)

    db_train = SvhnNShot(batchsz=args.task_num,
                         n_way=args.n_way,
                         k_shot=args.k_spt,
                         k_query=args.k_qry,
                         imgsz=args.imgsz)

    def as_tensors(*arrays):
        """Convert numpy arrays to tensors on ``device``, preserving order."""
        return tuple(torch.from_numpy(arr).to(device) for arr in arrays)

    for step in range(args.epoch):
        x_spt, y_spt, x_qry, y_qry = as_tensors(*db_train.next())

        # One meta-training step over the sampled batch of tasks.
        accs = maml(x_spt, y_spt.long(), x_qry, y_qry.long())

        if step % 50 == 0:
            print('step:', step, '\ttraining acc:', accs)

        if step % 500 != 0:
            continue

        test_accs = []
        for _ in range(1000 // args.task_num):
            xs, ys, xq, yq = as_tensors(*db_train.next('test'))
            # Fine-tune and evaluate one task at a time.
            for xs_one, ys_one, xq_one, yq_one in zip(xs, ys, xq, yq):
                test_accs.append(maml.finetunning(xs_one, ys_one.long(),
                                                  xq_one, yq_one.long()))

        # Mean over tasks: one accuracy per fine-tuning update step.
        mean_accs = np.array(test_accs).mean(axis=0).astype(np.float16)
        print('Test acc:', mean_accs)
コード例 #2
0
def main():
    """Evaluate a trained MAML checkpoint on the mini-ImageNet test split.

    Uses the module-level ``args`` (``n_way``, ``k_spt``, ``k_qry``,
    ``imgsz``). Prints the accuracy per fine-tuning step, averaged over all
    test episodes.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    print(args)
    # 4 x (conv -> relu -> bn -> max-pool) feature extractor + linear head.
    config = [('conv2d', [32, 3, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 1, 0]),
              ('flatten', []), ('linear', [args.n_way, 32 * 5 * 5])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    ckpt_path = "./checkpoint_miniimage.pth"
    print("Load trained model")
    # map_location lets a checkpoint saved on another device load onto `device`.
    ckpt = torch.load(ckpt_path, map_location=device)
    maml.load_state_dict(ckpt['model'])

    mini_test = MiniImagenet("F:\\ACV_project\\MAML-Pytorch\\miniimagenet\\",
                             mode='test',
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batchsz=1,
                             resize=args.imgsz)

    db_test = DataLoader(mini_test,
                         1,
                         shuffle=True,
                         num_workers=1,
                         pin_memory=True)
    accs_all_test = []

    for x_spt, y_spt, x_qry, y_qry in db_test:
        # Drop the size-1 batch dimension added by the DataLoader.
        x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
            x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

        accs_all_test.append(maml.finetunning(x_spt, y_spt, x_qry, y_qry))

    if not accs_all_test:
        print('No test episodes were produced; nothing to evaluate.')
        return

    # Aggregate once after all episodes: [episodes, update_step+1] -> mean per step.
    accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
    print('Test acc:', accs)
コード例 #3
0
def main():
    """Meta-train a sine-wave regression MAML, reporting test loss every 500 steps.

    All hyper-parameters are fixed in this function; ``SinwaveNShot`` and
    ``Meta`` are defined elsewhere in the project.
    """
    torch.manual_seed(121)
    torch.cuda.manual_seed_all(121)
    np.random.seed(121)

    # Tasks requested per nshot.next() call; the test loop draws
    # 1000 // batch_size such batches.
    batch_size = 20
    nshot = SinwaveNShot(all_numbers_class=2000,
                         batch_size=batch_size,
                         n_way=5,
                         k_shot=5,
                         k_query=15,
                         root='data')
    maml = Meta(hid_dim=64, meta_lr=1e-3, update_lr=0.004)

    for step in range(10000):
        # Training does not use the per-task sine parameters.
        x_spt, y_spt, x_qry, y_qry, _, _ = nshot.next('train')
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt), torch.from_numpy(
            y_spt), torch.from_numpy(x_qry), torch.from_numpy(y_qry)

        loss = maml(x_spt, y_spt, x_qry, y_qry)
        if step % 20 == 0:
            print('step:', step, '\ttraining loss:', loss)

        if step % 500 == 0:
            test_losses = []
            for _ in range(1000 // batch_size):
                x_spt, y_spt, x_qry, y_qry, param_spt, param_qry = nshot.next(
                    'test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt), torch.from_numpy(y_spt), \
                    torch.from_numpy(x_qry), torch.from_numpy(y_qry)

                # Fine-tune and evaluate one task at a time.
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one, param_spt_one, param_qry_one in \
                        zip(x_spt, y_spt, x_qry, y_qry, param_spt, param_qry):
                    test_loss = maml.finetunning(x_spt_one, y_spt_one,
                                                 x_qry_one, y_qry_one,
                                                 param_spt_one, param_qry_one)
                    test_losses.append(test_loss)

            # Mean over tasks: [tasks, update_step+1] -> one loss per update step.
            mean_losses = np.array(test_losses).mean(axis=0).astype(np.float16)
            print('Test loss:', mean_losses)
コード例 #4
0
ファイル: main.py プロジェクト: YinjieJ/MaskMAML
def main():
    """Meta-train the MaskMAML model on synthetic ``Gen`` tasks.

    Every epoch runs one meta-update on a shuffled copy of the training tasks.
    Every 50th epoch it saves ``save.pt`` and fine-tunes on one random test
    task. Uses the module-level ``args``.
    """
    print(args)
    device = torch.device('cuda')
    maml = Meta(args).to(device)
    num = sum(np.prod(p.shape) for p in maml.parameters() if p.requires_grad)
    print(maml)
    print('Total trainable tensors:', num)
    trainset = Gen(args.task_num, args.k_spt, args.k_qry)
    testset = Gen(args.task_num, args.k_spt, args.k_qry * 10)

    def on_device(array):
        """Wrap a numpy array as a float tensor on ``device``."""
        return torch.Tensor(array).to(device)

    for epoch in range(args.epoch):
        order = list(range(trainset.xs.shape[0]))
        np.random.shuffle(order)
        xs, ys = on_device(trainset.xs[order]), on_device(trainset.ys[order])
        xq, yq = on_device(trainset.xq[order]), on_device(trainset.yq[order])

        maml.train()
        loss = maml(xs, ys, xq, yq, epoch)
        print('Epoch: {} Initial loss: {} Train loss: {}'.format(
            epoch, loss[0] / args.task_num, loss[-1] / args.task_num))

        if (epoch + 1) % 50 != 0:
            continue

        print("Evaling the model...")
        torch.save(maml.state_dict(), 'save.pt')
        maml.eval()
        pick = random.randint(0, testset.xs.shape[0] - 1)
        xs, ys = on_device(testset.xs[pick]), on_device(testset.ys[pick])
        xq, yq = on_device(testset.xq[pick]), on_device(testset.yq[pick])
        _, losses_q, _, _ = maml.finetunning(xs, ys, xq, yq)
        print('Epoch: {} Initial loss: {} Test loss: {}'.format(
            epoch, losses_q[0], losses_q[-1]))
コード例 #5
0
def main():
    """Meta-train MAML on mini-ImageNet, evaluating and checkpointing every 500 steps.

    Uses the module-level ``args`` (``n_way``, ``k_spt``, ``k_qry``, ``imgsz``,
    ``task_num``, ``epoch``). The checkpoint is overwritten at
    ``./model/checkpoint.pth``.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seed_fn(222)

    print(args)

    # Four conv -> relu -> bn -> max-pool stages, then flatten + linear head.
    config = []
    for in_ch, pool_stride in ((3, 2), (32, 2), (32, 2), (32, 1)):
        config.extend([
            ('conv2d', [32, in_ch, 3, 3, 1, 0]),
            ('relu', [True]),
            ('bn', [32]),
            ('max_pool2d', [2, pool_stride, 0]),
        ])
    config.extend([('flatten', []), ('linear', [args.n_way, 32 * 5 * 5])])

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    num = sum(np.prod(p.shape) for p in maml.parameters() if p.requires_grad)
    print(maml)
    print('Total trainable tensors:', num)

    data_root = 'F:\\ACV_project\\MAML-Pytorch\\miniimagenet\\'
    split_kwargs = dict(n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry, resize=args.imgsz)
    # batchsz is the total number of episodes each split yields.
    mini = MiniImagenet(data_root, mode='train', batchsz=10000, **split_kwargs)
    mini_test = MiniImagenet(data_root, mode='test', batchsz=100, **split_kwargs)

    ckpt_dir = "./model/"

    for epoch in range(args.epoch // 10000):
        # Each batch holds task_num episodes.
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)

        for step, batch in enumerate(db):
            x_spt, y_spt, x_qry, y_qry = (t.to(device) for t in batch)

            accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                print('step:', step, '\ttraining acc:', accs)

            if step % 500 != 0:
                continue

            # Evaluation: fine-tune on each test episode individually.
            db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
            per_episode = []
            for test_batch in db_test:
                xs, ys, xq, yq = (t.squeeze(0).to(device) for t in test_batch)
                per_episode.append(maml.finetunning(xs, ys, xq, yq))

            # [episodes, update_step+1] -> mean accuracy per update step.
            accs = np.array(per_episode).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)

            os.makedirs(ckpt_dir, exist_ok=True)
            print('Saving the model as a checkpoint...')
            state = {'epoch': epoch, 'Steps': step, 'model': maml.state_dict()}
            torch.save(state, os.path.join(ckpt_dir, 'checkpoint.pth'))
コード例 #6
0
ファイル: train.py プロジェクト: ziqiaomeng/G-Meta
def main():
    """Meta-train G-Meta on subgraph tasks, validate every epoch, then test.

    All settings come from the module-level ``args``. After training, the
    final model and the best-on-validation model (early stopping) are each
    evaluated on the test split with their own, independent result lists.

    Raises:
        ValueError: if ``args.task_setup`` is neither 'Disjoint' nor 'Shared'.
    """
    # NOTE(review): the result of this call was never used; it only samples
    # this process's memory for about one second.
    memory_usage(-1, interval=.5, timeout=1)
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    if args.task_setup not in ('Disjoint', 'Shared'):
        # Previously an unknown value left `info`/`labels_num` unbound (NameError).
        raise ValueError("Unknown task_setup {!r}; expected 'Disjoint' or 'Shared'"
                         .format(args.task_setup))

    root = args.data_dir

    # NOTE(security): allow_pickle=True and pickle.load can execute arbitrary
    # code; only point data_dir at trusted files.
    feat = np.load(root + 'features.npy', allow_pickle=True)

    with open(root + '/graph_dgl.pkl', 'rb') as f:
        dgl_graph = pickle.load(f)

    if args.task_setup == 'Shared' and args.task_mode == 'True':
        # Labels (and the Subgraphs root below) come from <data_dir>/task<N>/.
        root = root + '/task' + str(args.task_n) + '/'
    with open(root + 'label.pkl', 'rb') as f:
        info = pickle.load(f)

    total_class = len(np.unique(np.array(list(info.values()))))
    print('There are {} classes '.format(total_class))

    # Disjoint: the classifier head has n_way outputs; Shared: one per label.
    labels_num = args.n_way if args.task_setup == 'Disjoint' else total_class

    if len(feat.shape) == 2:
        # single graph, to make it compatible to multiple graph retrieval.
        feat = [feat]

    config = [('GraphConv', [feat[0].shape[1], args.hidden_dim])]
    if args.h > 1:
        config += [('GraphConv', [args.hidden_dim, args.hidden_dim])] * (args.h - 1)
    config += [('Linear', [args.hidden_dim, labels_num])]

    if args.link_pred_mode == 'True':
        config.append(('LinkPred', [True]))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    maml = Meta(args, config).to(device)

    num = sum(np.prod(p.shape) for p in maml.parameters() if p.requires_grad)
    print(maml)
    print('Total trainable tensors:', num)

    max_acc = 0
    model_max = copy.deepcopy(maml)

    def make_split(mode, batchsz):
        """Build the Subgraphs task sampler for one data split."""
        return Subgraphs(root,
                         mode,
                         info,
                         n_way=args.n_way,
                         k_shot=args.k_spt,
                         k_query=args.k_qry,
                         batchsz=batchsz,
                         args=args,
                         adjs=dgl_graph,
                         h=args.h)

    def evaluate(model, dataset):
        """Fine-tune ``model`` on each task of ``dataset`` (one task per batch)
        and return the results averaged over tasks (axis 0)."""
        loader = DataLoader(dataset,
                            1,
                            shuffle=True,
                            num_workers=args.num_workers,
                            pin_memory=True,
                            collate_fn=collate)
        accs_all = []
        for x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry, g_spt, g_qry in loader:
            accs_all.append(model.finetunning(x_spt, y_spt, x_qry, y_qry, c_spt, c_qry,
                                              n_spt, n_qry, g_spt, g_qry, feat))
        return np.array(accs_all).mean(axis=0).astype(np.float16)

    db_train = make_split('train', args.batchsz)
    db_val = make_split('val', 100)
    db_test = make_split('test', 100)

    print('------ Start Training ------')
    s_start = time.time()
    max_memory = 0
    for epoch in range(args.epoch):
        db = DataLoader(db_train,
                        args.task_num,
                        shuffle=True,
                        num_workers=args.num_workers,
                        pin_memory=True,
                        collate_fn=collate)
        # End time of the previous step (or loader creation, for step 0);
        # the gap until the next batch arrives is the data-loading time.
        s_r = time.time()
        for step, (x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry,
                   g_spt, g_qry) in enumerate(db):
            data_loading_time = time.time() - s_r
            s = time.time()
            # x_spt: a list of #task_num tasks, where each task is a mini-batch of k-shot * n_way subgraphs
            # y_spt: a list of #task_num lists of labels. Each list is of length k-shot * n_way int.
            accs = maml(x_spt, y_spt, x_qry, y_qry, c_spt, c_qry, n_spt, n_qry,
                        g_spt, g_qry, feat)
            # System-wide used RAM in GiB (psutil.virtual_memory().used).
            mem_gb = float(psutil.virtual_memory().used / (1024**3))
            max_memory = max(max_memory, mem_gb)
            if step % args.train_result_report_steps == 0:
                print('Epoch:', epoch + 1, ' Step:', step, ' training acc:',
                      str(accs[-1])[:5], ' time elapsed:',
                      str(time.time() - s)[:5], ' data loading takes:',
                      str(data_loading_time)[:5], ' Memory usage:',
                      str(mem_gb)[:5])
            s_r = time.time()

        # validation per epoch
        accs = evaluate(maml, db_val)
        print('Epoch:', epoch + 1, ' Val acc:', str(accs[-1])[:5])
        if accs[-1] > max_acc:
            max_acc = accs[-1]
            model_max = copy.deepcopy(maml)

    # Final-step accuracy of the last-epoch model (was wrongly accs[1]).
    accs = evaluate(maml, db_test)
    print('Test acc:', str(accs[-1])[:5])

    # Evaluated with a fresh result list so it is not mixed with the run above.
    accs = evaluate(model_max, db_test)
    print('Early Stopped Test acc:', str(accs[-1])[:5])
    print('Total Time:', str(time.time() - s_start)[:5])
    print('Max Momory:', str(max_memory)[:5])
コード例 #7
0
def main():
    """Meta-train on per-model MNIST gradient datasets; periodically test on TARGET_MODEL.

    Uses the module-level ``args``, ``MODELS``, ``Meta``, ``mnist`` and
    ``get_target_model``. Whenever the end-of-epoch ``accs[0]`` reaches a new
    minimum, a checkpoint is written under ``./checkpoint/mnist``.
    """
    TARGET_MODEL = 3

    config = [
        ('conv2d', [16, 1, 3, 3, 1, 1]),
        ('relu', [True]),
        ('bn', [16]),
        ('conv2d', [32, 16, 4, 4, 2, 1]),
        ('relu', [True]),
        ('bn', [32]),
        ('conv2d', [64, 32, 4, 4, 2, 1]),
        ('relu', [True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 4, 4, 2, 1]),
        ('relu', [True]),
        ('bn', [64]),
        ('convt2d', [64, 32, 3, 3, 2, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('convt2d', [32, 16, 4, 4, 2, 1]),
        ('relu', [True]),
        ('bn', [16]),
        ('convt2d', [16, 8, 4, 4, 2, 1]),
        ('relu', [True]),
        ('bn', [8]),
        ('convt2d', [8, 1, 3, 3, 1, 1]),
    ]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    # One training DataLoader per source model MODELS[0 .. task_num-1].
    minis = []
    for i in range(args.task_num):
        path = osp.join("./zoo_cw_grad_mnist/train", MODELS[i] + "_mnist.npy")
        mini = mnist(path,
                     mode='train',
                     n_way=args.n_way,
                     k_shot=args.k_spt,
                     k_query=args.k_qry,
                     batchsz=100,
                     resize=args.imgsz)
        minis.append(DataLoader(mini, args.batchsize, shuffle=True, num_workers=0, pin_memory=True))

    path_test = osp.join("./zoo_cw_grad_mnist/test", MODELS[TARGET_MODEL] + "_mnist.npy")
    mini_test = mnist(path_test,
                      mode='test',
                      n_way=1,
                      k_shot=args.k_spt,
                      k_query=args.k_qry,
                      batchsz=100,
                      resize=args.imgsz)

    mini_test = DataLoader(mini_test, 10, shuffle=True, num_workers=0, pin_memory=True)

    # start training
    step_number = len(minis[0])
    # Tracks the smallest accs[0] seen so far; a checkpoint is saved on improvement.
    BEST_ACC = 1.0
    target_model = get_target_model(TARGET_MODEL).to(device)

    def save_model(model, acc):
        """Save ``model``'s state_dict as ./checkpoint/mnist/<acc>mnist_<target>.pt."""
        model_file_path = './checkpoint/mnist'
        os.makedirs(model_file_path, exist_ok=True)
        file_name = str(acc) + 'mnist_' + MODELS[TARGET_MODEL] + '.pt'
        torch.save(model.state_dict(), os.path.join(model_file_path, file_name))

    def load_model(model, acc):
        """Load the checkpoint written by ``save_model`` for ``acc`` into ``model``."""
        model_checkpoint_path = './checkpoint/mnist/' + str(acc) + 'mnist_' + MODELS[TARGET_MODEL] + '.pt'
        if not os.path.exists(model_checkpoint_path):
            raise FileNotFoundError(model_checkpoint_path)
        model.load_state_dict(torch.load(model_checkpoint_path, map_location=device))
        return model

    # Resume once before training; reloading inside the epoch loop would
    # discard every epoch's progress.
    if args.resume:
        maml = load_model(maml, "0.7231071")

    for epoch in range(args.epoch // 100):
        minis_iter = [iter(db) for db in minis]
        mini_test_iter = iter(mini_test)
        for step in range(step_number):
            # Builtin next(): DataLoader iterators no longer provide .next().
            batch_data = [next(it) for it in minis_iter]
            accs = maml(batch_data, device)

            if step == step_number - 1:  # last step of the epoch
                print('Training acc:', accs)
                if accs[0] < BEST_ACC:
                    BEST_ACC = accs[0]
                    save_model(maml, BEST_ACC)

            if (epoch + 1) % 15 == 0 and step == 0:  # evaluation
                accs_all_test = []
                for _ in range(3):
                    test_data = next(mini_test_iter)
                    accs_all_test.append(maml.finetunning(test_data, target_model, device))

                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                print('Test acc:', accs)
コード例 #8
0
def main():
    """Meta-train MAML with an STSTNet backbone on mini-ImageNet.

    Evaluates on the test split every 500 steps, and at the end pickles the
    whole ``maml.net`` module to ``maml_ststnet.pth``. Uses the module-level
    ``args``.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    maml = Meta(args, STSTNet()).to(device)

    num = sum(np.prod(p.shape) for p in maml.parameters() if p.requires_grad)
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    mini = MiniImagenet('/home/lf/miniImagenet/',
                        mode='train',
                        n_way=args.n_way,
                        k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batchsz=10000,
                        resize=args.imgsz)
    mini_test = MiniImagenet('/home/lf/miniImagenet/',
                             mode='test',
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batchsz=100,
                             resize=args.imgsz)

    for epoch in range(args.epoch // 10000):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini,
                        args.task_num,
                        shuffle=True,
                        num_workers=1,
                        pin_memory=True)

        # Wrap the sized DataLoader, not enumerate(), so tqdm knows the total.
        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(tqdm(db)):

            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(
                device), x_qry.to(device), y_qry.to(device)

            accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                print('step:', step, '\ttraining acc:', accs)

            if step % 500 == 0:  # evaluation
                db_test = DataLoader(mini_test,
                                     1,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
                accs_all_test = []

                for x_spt, y_spt, x_qry, y_qry in db_test:
                    # Drop the size-1 batch dimension added by the DataLoader.
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

                    accs_all_test.append(maml.finetunning(x_spt, y_spt, x_qry, y_qry))

                # [b, update_step+1]
                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                print('Test acc:', accs)
    # Saves the whole module object (pickled), not just its state_dict.
    torch.save(maml.net, 'maml_ststnet.pth')
コード例 #9
0
def main():
    """Meta-train MAML on the flower dataset and record the final accuracy.

    Uses the module-level ``args``. ``args.mode == 0`` trains on 'train' and
    evaluates on 'val'; otherwise it trains on 'train_ts' and evaluates on
    'test'. After the very last training step, the last-update-step accuracy
    is appended to a text file under ``acc_cv2/``.
    """
    import os

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    config = [
        ('conv2d', [32, 3, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 1, 0]),
        ('flatten', []),
        ('linear', [args.n_way, 32 * 5 * 5])
    ]
    # format() accepts gpu_index as either str or int.
    device = torch.device('cuda:{}'.format(args.gpu_index))
    maml = Meta(args, config).to(device)

    num = sum(np.prod(p.shape) for p in maml.parameters() if p.requires_grad)
    print('Total trainable tensors:', num)
    if args.mode == 0:
        mode_val_test = 'val'
        train = 'train'
    else:
        mode_val_test = 'test'
        train = 'train_ts'

    # batchsz here means total episode number
    mini = MiniImagenet('../flower/', mode=train, n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batchsz=10000, resize=args.imgsz, cross_val_idx=args.cross_val_idx)
    mini_test = MiniImagenet('../flower/', mode=mode_val_test, n_way=args.n_way, k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batchsz=100, resize=args.imgsz, cross_val_idx=args.cross_val_idx)
    accs_list_tr = []
    accs_list_ts = []
    num_epochs = args.epoch // 10000
    for epoch in range(num_epochs):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)
        # len(db) also counts a trailing partial batch, so this is the true
        # last step even when the episode count is not divisible by task_num.
        last_step = len(db) - 1
        is_last_epoch = epoch == num_epochs - 1

        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

            accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                print('step:', step, '\ttraining acc:', accs)
                accs_list_tr.append(accs)

            is_final_step = is_last_epoch and step == last_step
            if step % 500 == 0 or is_final_step:  # evaluation
                db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
                accs_all_test = []

                for x_spt, y_spt, x_qry, y_qry in db_test:
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

                    accs_all_test.append(maml.finetunning(x_spt, y_spt, x_qry, y_qry))

                # [b, update_step+1]
                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                print('step:', step, '\ttest acc:', accs)
                accs_list_ts.append(accs)
                if is_final_step:
                    os.makedirs('acc_cv2', exist_ok=True)
                    with open('acc_cv2/natural2_' + mode_val_test + '(task_num_' + str(args.task_num) + ').txt',
                              mode='a', encoding='utf-8') as f:
                        f.write(str(accs[-1]) + ',')
コード例 #10
0
def main(args):
    """Meta-train MAML on Omniglot, reporting test accuracy every 500 steps.

    Args:
        args: parsed options. This function reads ``n_way``, ``k_spt``,
            ``k_qry``, ``task_num``, ``imgsz``, ``epoch`` and ``write_log``,
            and hands the whole object to ``Meta``.
    """
    os.makedirs('./logs', exist_ok=True)
    logfile = os.path.join('.', 'logs', f'omniglot_way[{args.n_way}]_shot[{args.k_spt}].json')
    # NOTE(review): nothing is written to this file yet; it is only created
    # (truncated) when write_log is set. It is now always closed on exit.
    log_fp = open(logfile, 'w', encoding='utf-8') if args.write_log else None
    try:
        torch.manual_seed(222)
        torch.cuda.manual_seed_all(222)
        np.random.seed(222)

        print(args)

        config = [
            ('conv2d', [64, 1, 3, 3, 2, 0]),
            ('relu', [True]),
            ('bn', [64]),
            ('conv2d', [64, 64, 3, 3, 2, 0]),
            ('relu', [True]),
            ('bn', [64]),
            ('conv2d', [64, 64, 3, 3, 2, 0]),
            ('relu', [True]),
            ('bn', [64]),
            ('conv2d', [64, 64, 2, 2, 1, 0]),
            ('relu', [True]),
            ('bn', [64]),
            ('flatten', []),
            ('linear', [args.n_way, 64])
        ]

        device = torch.device('cuda')
        maml = Meta(args, config).to(device)

        num = sum(np.prod(p.shape) for p in maml.parameters() if p.requires_grad)
        print(maml)
        print('Total trainable tensors:', num)

        # abspath: dirname(__file__) may be '' when run from the script's own
        # directory, which os.sep.join turned into the absolute '/dataset/...'.
        path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'dataset', 'omniglot')
        db_train = OmniglotNShot(path,
                                 batchsz=args.task_num,
                                 n_way=args.n_way,
                                 k_shot=args.k_spt,
                                 k_query=args.k_qry,
                                 imgsz=args.imgsz)

        for step in range(args.epoch):
            # Fetch the next batch of training tasks (sampled by OmniglotNShot.next()).
            x_spt, y_spt, x_qry, y_qry = db_train.next()
            x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                         torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

            # One meta-training step over the task batch.
            accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 50 == 0:
                print('step:', step, '\ttraining acc:', accs)

            if step % 500 == 0:
                test_accs = []
                for _ in range(1000 // args.task_num):
                    # test
                    x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                    x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                                 torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

                    # split to single task each time
                    for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                        test_accs.append(maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one))

                # [b, update_step+1]
                accs = np.array(test_accs).mean(axis=0).astype(np.float16)
                print('Test acc:', accs)
    finally:
        if log_fp is not None:
            log_fp.close()
コード例 #11
0
def _build_unet_config(args):
    """Return the layer configuration list for the U-Net architecture.

    Each entry is a ``(layer_type, params)`` tuple. For ``conv2d`` the params
    are ``[out_c, in_c, k_h, k_w, stride, padding]``; these convolutions carry
    no bias. The down path has ``args.NUM_DOWN_CONV`` blocks of three
    conv/leakyrelu/bn triples followed by a 2x2 max-pool. The up path has
    ``args.NUM_DOWN_CONV - 1`` blocks of an upsample followed by three triples.
    A final biased ``conv2d_b`` maps ``args.HIDDEN_DIM`` channels to
    ``args.outc``.
    """

    def conv_triple(out_c, in_c):
        # Bias-free 3x3 conv (stride 1, padding 1), then leaky ReLU
        # (alpha 0.2, not in place), then batch norm.
        return [('conv2d', [out_c, in_c, 3, 3, 1, 1]),
                ('leakyrelu', [0.2, False]),
                ('bn', [out_c])]

    config = []
    for block in range(args.NUM_DOWN_CONV):
        out_channels = (2**block) * args.HIDDEN_DIM
        # The first block reads the raw image channels; each later block reads
        # the previous block's output, which has half as many channels.
        in_channels = args.imgc if block == 0 else out_channels // 2
        config += conv_triple(out_channels, in_channels)
        config += conv_triple(out_channels, out_channels)
        config += conv_triple(out_channels, out_channels)
        config += [('max_pool2d', [2, 2, 0])]  # kernel_size, stride, padding

    for block in range(args.NUM_DOWN_CONV - 1):
        out_channels = (2**(args.NUM_DOWN_CONV - block - 2)) * args.HIDDEN_DIM
        # NOTE(review): presumably upsampled features (2*out) concatenated with
        # a skip connection (out); confirm against the learner's forward pass.
        in_channels = out_channels * 3
        config += [('upsample', [2])]
        config += conv_triple(out_channels, in_channels)
        config += conv_triple(out_channels, out_channels)
        config += conv_triple(out_channels, out_channels)

    # All conv2d layers above are bias-free; this final conv2d_b has a bias.
    config += [('conv2d_b', [args.outc, args.HIDDEN_DIM, 3, 3, 1, 1])]
    return config


def main(args):
    """Meta-train a U-Net with MAML on RCWA data, checkpointing periodically.

    Args:
        args: options namespace (presumably from argparse). The fields read
            here include arch, model_name, k_spt, k_qry, task_num, meta_lr,
            update_lr, update_step, continue_train, continue_epoch, n_way,
            imgsz, data_folder and epoch, plus the U-Net fields used by
            ``_build_unet_config``.

    Raises:
        NotImplementedError: if ``args.arch`` is anything other than "Unet".
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    device = torch.device('cuda')

    if args.arch != "Unet":
        # Raising a bare string is itself a TypeError in Python 3, so raise a
        # real exception type that callers can catch.
        raise NotImplementedError(
            "architectures other than Unet hasn't been added!!")
    config = _build_unet_config(args)

    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    for name, param in maml.named_parameters():
        print(name, param.size())
    print('Total trainable tensors:', num)

    SUMMARY_INTERVAL = 5  # iterations between training-accuracy prints
    TEST_PRINT_INTERVAL = SUMMARY_INTERVAL * 5  # iterations between validations
    EPOCH_SAVE_INTERVAL = 5  # epochs between checkpoints

    model_path = ("/scratch/users/chenkaim/pytorch-models/pytorch_" +
                  args.model_name + "_k_shot_" + str(args.k_spt) +
                  "_task_num_" + str(args.task_num) + "_meta_lr_" +
                  str(args.meta_lr) + "_inner_lr_" + str(args.update_lr) +
                  "_num_inner_updates_" + str(args.update_step))
    # makedirs also creates missing parent directories and tolerates reruns.
    os.makedirs(model_path, exist_ok=True)

    def checkpoint_file(epoch):
        """Path of the checkpoint written at the start of ``epoch``."""
        return model_path + "/epoch_" + str(epoch) + ".pt"

    start_epoch = 0
    if args.continue_train:
        restore_path = checkpoint_file(args.continue_epoch)
        print("Restoring weights from ", restore_path)
        # The checkpoint pickles whole objects (model, optimizer, scheduler),
        # not just state dicts.
        checkpoint = torch.load(restore_path)
        maml = checkpoint['model']
        maml.lr_scheduler = checkpoint['lr_scheduler']
        maml.meta_optim = checkpoint['optimizer']
        start_epoch = args.continue_epoch

    db = RCWA_data_loader(batchsz=args.task_num,
                          n_way=args.n_way,
                          k_shot=args.k_spt,
                          k_query=args.k_qry,
                          imgsz=args.imgsz,
                          data_folder=args.data_folder)

    for step in range(start_epoch, args.epoch):
        print("epoch: ", step)
        if step % EPOCH_SAVE_INTERVAL == 0:
            checkpoint = {
                'epoch': step,
                'model': maml,
                'optimizer': maml.meta_optim,
                'lr_scheduler': maml.lr_scheduler
            }
            torch.save(checkpoint, checkpoint_file(step))

        # Each epoch walks about 70% of the samples in meta-batches of
        # task_num tasks with (k_spt + k_qry) samples per task.
        iters_per_epoch = int(0.7 * db.total_data_samples /
                              ((args.k_spt + args.k_qry) * args.task_num))
        for itr in range(iters_per_epoch):
            x_spt, y_spt, x_qry, y_qry = db.next()
            x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                         torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

            # set traning=True to update running_mean, running_variance, bn_weights, bn_bias
            accs, loss_q = maml(x_spt, y_spt, x_qry, y_qry)

            if itr % SUMMARY_INTERVAL == 0:
                print_str = "Iteration %d: pre-inner-loop train accuracy: %.5f, post-iner-loop test accuracy: %.5f, train_loss: %.5f" % (
                    itr, accs[0], accs[-1], loss_q)
                print(print_str)

            if itr % TEST_PRINT_INTERVAL == 0:
                val_accs = []
                for _ in range(10):
                    x_spt, y_spt, x_qry, y_qry = db.next('test')
                    x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                                 torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

                    # Fine-tune on one task at a time.
                    for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                            x_spt, y_spt, x_qry, y_qry):
                        val_accs.append(
                            maml.finetunning(x_spt_one, y_spt_one, x_qry_one,
                                             y_qry_one))

                # [b, update_step+1] -> mean over tasks per update step
                val_accs = np.array(val_accs).mean(axis=0).astype(np.float16)
                print(
                    'Meta-validation pre-inner-loop train accuracy: %.5f, meta-validation post-inner-loop test accuracy: %.5f'
                    % (val_accs[0], val_accs[-1]))

        maml.lr_scheduler.step()
コード例 #12
0
def main():
    """Meta-train MAML on miniImagenet with periodic validation and testing.

    Hyperparameters come from the module-level ``args``. Every 1000 steps the
    model is fine-tuned on each validation episode. When the final-step
    validation accuracy improves on the best seen so far, or on every 5000th
    step, the following happens:

    * the state_dict is saved under ``ckpt/<args.exp>/``;
    * the validation result is appended to ``val.txt``;
    * a meta-test pass is run and appended to ``test.txt``.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    config = [('conv2d', [32, 3, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 1, 0]),
              ('flatten', []), ('linear', [args.n_way, 32 * 5 * 5])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    data_root = '/mnt/aitrics_ext/ext01/yanbin/MAML-Pytorch/data/miniImagenet/'
    # batchsz here means total episode number
    mini = MiniImagenet(data_root,
                        mode='train',
                        n_way=args.n_way,
                        k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batchsz=10000,
                        resize=args.imgsz)
    mini_val = MiniImagenet(data_root,
                            mode='val',
                            n_way=args.n_way,
                            k_shot=args.k_spt,
                            k_query=args.k_qry,
                            batchsz=600,
                            resize=args.imgsz)
    mini_test = MiniImagenet(data_root,
                             mode='test',
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batchsz=600,
                             resize=args.imgsz)

    def evaluate(dataset):
        """Fine-tune on every episode of ``dataset``; return ``cal_conf`` of the accuracies."""
        loader = DataLoader(dataset,
                            1,
                            shuffle=True,
                            num_workers=1,
                            pin_memory=True)
        accs_all = []
        for x_spt, y_spt, x_qry, y_qry in loader:
            # Drop the loader's leading batch dimension (batch size 1).
            x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                         x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
            accs_all.append(maml.finetunning(x_spt, y_spt, x_qry, y_qry))
        return cal_conf(np.array(accs_all))

    best_acc = 0.0
    # makedirs also creates the parent 'ckpt' directory when it is missing.
    os.makedirs('ckpt/{}'.format(args.exp), exist_ok=True)
    for epoch in range(args.epoch // 10000):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini,
                        args.task_num,
                        shuffle=True,
                        num_workers=1,
                        pin_memory=True)

        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):

            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(
                device), x_qry.to(device), y_qry.to(device)

            accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 500 == 0:
                print('step:', step, '\ttraining acc:', accs)
            if step % 1000 == 0:  # evaluation
                val_mean, val_std, val_ci95 = evaluate(mini_val)
                print('Val acc:{}, std:{}. ci95:{}'.format(
                    val_mean[-1], val_std[-1], val_ci95[-1]))
                val_acc = val_mean[-1]
                improved = val_acc > best_acc
                # Only raise the best-so-far on a real improvement. A periodic
                # (step % 5000) save must not lower the bar for later checkpoints.
                if improved:
                    best_acc = val_acc
                if improved or step % 5000 == 0:
                    torch.save(
                        maml.state_dict(),
                        'ckpt/{}/model_e{}s{}_{:.4f}.pkl'.format(
                            args.exp, epoch, step, val_acc))
                    with open('ckpt/' + args.exp + '/val.txt', 'a') as f:
                        print(
                            'val epoch {}, step {}: acc_val:{:.4f}, ci95:{:.4f}'
                            .format(epoch, step, val_acc, val_ci95[-1]),
                            file=f)

                    ## Test
                    test_mean, test_std, test_ci95 = evaluate(mini_test)
                    print('Test acc:{}, std:{}, ci95:{}'.format(
                        test_mean[-1], test_std[-1], test_ci95[-1]))
                    with open('ckpt/' + args.exp + '/test.txt', 'a') as f:
                        print(
                            'test epoch {}, step {}: acc_test:{:.4f}, ci95:{:.4f}'
                            .format(epoch, step, test_mean[-1], test_ci95[-1]),
                            file=f)
コード例 #13
0
ファイル: train_prune.py プロジェクト: sjtu-cs222/Group_15
def main(args):
    """Meta-train MAML on Omniglot, then alternate pruning and fine-tuning.

    Phase 1 runs ``args.epoch`` meta-training steps and reports meta-test
    accuracy every 500 steps. Phase 2 runs ``args.prune_iteration`` rounds.
    Each round calls ``maml.prune`` on the data returned by
    ``db_train.getHoleTrain()``, fine-tunes for ``args.finetune_epoch`` steps
    (twice that on the final round), then reports meta-test accuracy.
    """
    config = [('conv2d', [64, 1, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 2, 2, 1, 0]), ('relu', [True]), ('bn', [64]),
              ('flatten', []), ('linear', [args.n_way, 64])]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    maml = Meta(args, config, device).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))

    print('Total trainable tensors:', num)

    db_train = OmniglotNShot('./',
                             batchsz=args.task_num,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             imgsz=args.imgsz)

    def to_device(x_spt, y_spt, x_qry, y_qry):
        """Convert one numpy meta-batch to tensors on ``device``; labels become int64."""
        return (torch.from_numpy(x_spt).to(device),
                torch.from_numpy(y_spt).long().to(device),
                torch.from_numpy(x_qry).to(device),
                torch.from_numpy(y_qry).long().to(device))

    def evaluate():
        """Fine-tune on 1000 // task_num test meta-batches; return mean accuracy per update step."""
        accs = []
        for _ in range(1000 // args.task_num):
            x_spt, y_spt, x_qry, y_qry = to_device(*db_train.next('test'))
            # split to single task each time
            for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                    x_spt, y_spt, x_qry, y_qry):
                accs.append(
                    maml.finetunning(x_spt_one, y_spt_one, x_qry_one,
                                     y_qry_one))
        # [b, update_step+1] -> mean over tasks per update step
        return np.array(accs).mean(axis=0).astype(np.float16)

    for step in range(args.epoch):
        x_spt, y_spt, x_qry, y_qry = to_device(*db_train.next())

        # set traning=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)
        print('trainstep:', step, '\ttraining acc:', accs)

        if (step + 1) % 500 == 0:
            print('Test acc:', evaluate())

    ##############################
    for i in range(args.prune_iteration):
        # prune
        print("the {}th prune step".format(i))
        x_spt, y_spt, x_qry, y_qry = to_device(*db_train.getHoleTrain())
        maml.prune(x_spt, y_spt, x_qry, y_qry, args.prune_number_one_epoch,
                   args.max_prune_number)

        # fine-tuning
        print("start finetuning....")
        # The final prune round fine-tunes twice as long. Iterate over this
        # local (not args.finetune_epoch) so the doubling actually applies.
        is_last_round = i == args.prune_iteration - 1
        finetune_epoch = args.finetune_epoch * (2 if is_last_round else 1)

        for step in range(finetune_epoch):
            x_spt, y_spt, x_qry, y_qry = to_device(*db_train.next())
            accs = maml(x_spt, y_spt, x_qry, y_qry, finetune=True)

            print('finetune step:', step, '\ttraining acc:', accs)

        # print the test accuracy after pruning
        print("start testing....")
        print('Test acc:', evaluate())
コード例 #14
0
def main():
    """Meta-train MAML on miniImagenet using never-ending data iterators.

    Each of the ``args.epoch`` iterations consumes one meta-batch from
    ``inf_get(trainloader)``. Every 2500 iterations the state_dict is saved to
    ``ckpt/<args.exp>/model_<iter>.pkl``. The model is then fine-tuned on 600
    test episodes, and the mean accuracy per update step is printed and
    appended to ``ckpt/<args.exp>/test.txt``.
    """
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    # numpy is deliberately not seeded with a fixed value: it is re-seeded
    # from fresh entropy at every iteration below.

    config = [('conv2d', [32, 3, 3, 3, 1, 1]), ('bn', [32]), ('relu', [True]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 1]),
              ('bn', [32]), ('relu', [True]), ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 1]), ('bn', [32]), ('relu', [True]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 1]),
              ('bn', [32]), ('relu', [True]), ('max_pool2d', [2, 2, 0]),
              ('flatten', []), ('linear', [args.n_way, 32 * 5 * 5])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    root = '/mnt/aitrics_ext/ext01/yanbin/MAML-Pytorch/data/miniImagenet'
    trainset = MiniImagenet(root,
                            mode='train',
                            n_way=args.n_way,
                            k_shot=args.k_spt,
                            k_query=args.k_qry,
                            resize=args.imgsz)
    testset = MiniImagenet(root,
                           mode='test',
                           n_way=args.n_way,
                           k_shot=args.k_spt,
                           k_query=args.k_qry,
                           resize=args.imgsz)
    trainloader = DataLoader(trainset,
                             batch_size=args.task_num,
                             shuffle=True,
                             num_workers=4,
                             worker_init_fn=worker_init_fn,
                             drop_last=True)
    testloader = DataLoader(testset,
                            batch_size=1,
                            shuffle=True,
                            num_workers=1,
                            worker_init_fn=worker_init_fn,
                            drop_last=True)
    train_data = inf_get(trainloader)
    test_data = inf_get(testloader)

    # makedirs also creates the parent 'ckpt' directory when it is missing.
    os.makedirs('ckpt/{}'.format(args.exp), exist_ok=True)
    for epoch in range(args.epoch):
        np.random.seed()  # re-seed numpy from fresh entropy
        x_spt, y_spt, x_qry, y_qry = next(train_data)
        x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(
            device), x_qry.to(device), y_qry.to(device)

        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if epoch % 100 == 0:
            print('epoch:', epoch, '\ttraining acc:', accs)

        if epoch % 2500 == 0:  # evaluation
            # save checkpoint
            torch.save(maml.state_dict(),
                       'ckpt/{}/model_{}.pkl'.format(args.exp, epoch))
            accs_all_test = []
            for _ in range(600):
                x_spt, y_spt, x_qry, y_qry = next(test_data)
                # Drop only the loader's leading batch dimension (batch_size=1).
                # A bare squeeze() would also drop any other size-1 dimension.
                x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(
                    device), y_spt.squeeze(0).to(device), x_qry.squeeze(0).to(
                        device), y_qry.squeeze(0).to(device)
                accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                accs_all_test.append(accs)

            # [b, update_step+1] -> mean over episodes per update step
            accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
            with open('ckpt/' + args.exp + '/test.txt', 'a') as f:
                print('test epoch {}: acc:{:.4f}'.format(epoch, accs[-1]),
                      file=f)
コード例 #15
0
def main():
    """Meta-train MAML on miniImagenet, keep the best model, and plot curves.

    Every 30 steps the training loss and final-step training accuracy are
    recorded. Every 50 steps the model is fine-tuned on each test episode.
    Mean test loss and final-step test accuracy are then recorded, and the
    state_dict is saved whenever the final-step test accuracy improves. At
    the end, test and training curves are written to ``./drawing/`` and shown.
    """
    torch.manual_seed(222)  # seed the CPU RNG so results are reproducible
    torch.cuda.manual_seed_all(222)  # seed all GPU RNGs so results are reproducible
    np.random.seed(222)

    print(args)

    config = [
        ('conv2d', [32, 1, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 1, 0]),
        ('flatten', []),
        ('linear', [args.n_way, 7040])
    ]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    mini = MiniImagenet("./miniimagenet", mode='train', n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batchsz=10000)
    mini_test = MiniImagenet("./miniimagenet", mode='test', n_way=args.n_way, k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batchsz=100)

    model_path = "./models/miniimagenet_maml" + str(args.n_way) + "way_" + str(
        args.k_spt) + "shot.pkl"
    # Output directories are not guaranteed to exist on a fresh checkout.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    os.makedirs('./drawing', exist_ok=True)

    best_accuracy = 0
    plt_train_loss = []
    plt_train_acc = []

    plt_test_loss = []
    plt_test_acc = []
    for epoch in range(args.epoch // 10000):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)

        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):

            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

            accs, loss_q = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                plt_train_loss.append(loss_q.detach().cpu().numpy())
                plt_train_acc.append(accs[-1])
                print('step:', step, '\ttraining acc:', accs)

            if step % 50 == 0:  # evaluation
                db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
                accs_all_test = []
                loss_all_test = []

                for x_spt, y_spt, x_qry, y_qry in db_test:
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

                    accs, loss_test = maml.finetunning(x_spt, y_spt, x_qry, y_qry)

                    loss_all_test.append(loss_test)
                    accs_all_test.append(accs)

                # [b, update_step+1] -> mean over episodes per update step
                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                plt_test_acc.append(accs[-1])
                avg_loss = np.mean(np.array(loss_all_test))
                plt_test_loss.append(avg_loss)

                print('Test acc:', accs)
                # Select checkpoints on the post-adaptation (final-step)
                # accuracy, the same metric that is plotted. Averaging over all
                # steps would mix in the pre-adaptation accuracy at step 0.
                test_accuracy = float(accs[-1])
                if test_accuracy > best_accuracy:
                    # save networks
                    torch.save(maml.state_dict(), model_path)
                    best_accuracy = test_accuracy

    def save_plot(title, losses, accuracies, path):
        """Plot loss and accuracy curves on one figure, save it to ``path``, and show it."""
        plt.figure()
        plt.title(title)
        plt.xlabel("episode")
        plt.ylabel("Acc/loss")
        plt.plot(losses, label='Loss')
        plt.plot(accuracies, label='Acc')
        plt.legend(loc='upper right')
        plt.savefig(path)
        plt.show()

    save_plot("testing info", plt_test_loss, plt_test_acc, './drawing/test.png')
    save_plot("training info", plt_train_loss, plt_train_acc, './drawing/train.png')
コード例 #16
0
def main():
    """Meta-train MAML on the miniImagenet data under ``./data/``.

    Hyperparameters are read from the module-level ``args``. Training accuracy
    is printed every 30 steps. Every 500 steps the model is fine-tuned on
    each meta-test episode, and the mean accuracy per update step is printed.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all,
                    np.random.seed):
        seed_fn(222)

    print(args)

    # Four conv -> relu -> bn -> max-pool stages. The first reads 3 input
    # channels, and only the last pool uses stride 1.
    config = []
    for stage, in_ch in enumerate((3, 32, 32, 32)):
        pool_stride = 1 if stage == 3 else 2
        config.extend([
            ('conv2d', [32, in_ch, 3, 3, 1, 0]),
            ('relu', [True]),
            ('bn', [32]),
            ('max_pool2d', [2, pool_stride, 0]),
        ])
    config.extend([('flatten', []), ('linear', [args.n_way, 32 * 5 * 5])])

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    maml = Meta(args, config).to(device)

    num = sum(np.prod(p.shape) for p in maml.parameters() if p.requires_grad)
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz is the total number of episodes each dataset generates.
    shared = dict(n_way=args.n_way,
                  k_shot=args.k_spt,
                  k_query=args.k_qry)
    mini = MiniImagenet('./data/', mode='train', batchsz=10000,
                        resize=args.img_sz, **shared)
    mini_test = MiniImagenet('./data/', mode='test', batchsz=100,
                             resize=args.img_sz, **shared)

    def move(batch, drop_batch_dim):
        """Send the four episode tensors to ``device``, optionally dropping dim 0."""
        return [(t.squeeze(0) if drop_batch_dim else t).to(device)
                for t in batch]

    for _ in range(args.epoch // 10000):
        # Each element yielded is a meta-batch of task_num episodes.
        train_loader = DataLoader(mini,
                                  args.task_num,
                                  shuffle=True,
                                  num_workers=1,
                                  pin_memory=True)

        for step, batch in enumerate(train_loader):
            x_spt, y_spt, x_qry, y_qry = move(batch, drop_batch_dim=False)

            train_accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                print('step:', step, '\ttraining acc:', train_accs)

            if step % 500 != 0:
                continue

            test_loader = DataLoader(mini_test,
                                     1,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
            per_episode = []
            for test_batch in test_loader:
                sx, sy, qx, qy = move(test_batch, drop_batch_dim=True)
                per_episode.append(maml.finetunning(sx, sy, qx, qy))

            # [b, update_step+1] -> mean over episodes per update step
            test_accs = np.array(per_episode).mean(axis=0).astype(np.float16)
            print('Test acc:', test_accs)
コード例 #17
0
def main():
    """Meta-train MAML on miniImagenet, or evaluate a saved model on one domain.

    Configuration comes from the module-level ``args``; ``argv`` is printed
    for logging. ``args.domain`` selects the meta-test dataset: ``mini``,
    ``cub``, ``traffic`` or ``flower``. An unknown domain prints
    "Dataset Error" and returns.

    When ``args.mode == 'test'``, the model is fine-tuned on every meta-test
    episode. The mean accuracy per update step and the running time are
    printed, and the function returns.

    Otherwise the model is meta-trained on the miniImagenet train split.
    Every 200 steps it is validated on the miniImagenet val split and tested
    on the chosen domain. The whole model is saved to ``args.modelfile`` and
    to a copy whose name records the epoch, the step and the final-step
    val/test accuracies.
    """
    start_time = time.time()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    print(args)
    print(argv)
    # Create every missing directory on the path to args.modelfile. A bare
    # filename needs no directory. The first path component alone would be
    # wrong for nested paths, bare filenames and absolute paths.
    model_dir = os.path.dirname(args.modelfile)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    config = [
        ('conv2d', [32, 3, 3, 3, 1, 1]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 1]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 1]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 1]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('flatten', []),
        ('linear', [args.n_way, 32 * 5 * 5])
    ]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    def load_split(root, mode, batchsz):
        """Build an episode dataset for ``mode`` from the MiniImagenet-format data under ``root``."""
        return MiniImagenet(root, mode=mode, n_way=args.n_way, k_shot=args.k_spt,
                            k_query=args.k_qry,
                            batchsz=batchsz, resize=args.imgsz)

    mini_root = './dataset/mini-imagenet/'
    mini = load_split(mini_root, 'train', 10000)

    # domain -> (banner printed before loading or None, meta-test data root)
    domains = {
        'mini': (None, mini_root),
        'cub': ("CUB dataset", './dataset/CUB_200_2011/'),
        'traffic': ("Traffic dataset", './dataset/GTSRB/'),
        'flower': ("flower dataset", './dataset/102flowers/'),
    }
    if args.domain not in domains:
        print("Dataset Error")
        return
    banner, test_root = domains[args.domain]
    if banner is not None:
        print(banner)
    mini_test = load_split(test_root, 'test', args.test_iter)
    mini_val = load_split(mini_root, 'val', args.test_iter) if args.domain == 'mini' else None

    def evaluate(dataset, *ft_args, num_workers=4, show_progress=False, **ft_kwargs):
        """Fine-tune on every episode of ``dataset``; return mean accuracy per update step.

        Extra positional/keyword arguments are forwarded to ``maml.finetunning``.
        """
        loader = DataLoader(dataset, 1, shuffle=True, num_workers=num_workers, pin_memory=True)
        accs_all = []
        for count, (x_spt, y_spt, x_qry, y_qry) in enumerate(loader):
            if show_progress:
                print(count)
            # Drop the loader's leading batch dimension (batch size 1).
            x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                         x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
            accs_all.append(maml.finetunning(x_spt, y_spt, x_qry, y_qry, *ft_args, **ft_kwargs))
        # [b, update_step+1] -> mean over episodes per update step
        return np.array(accs_all).mean(axis=0).astype(np.float16)

    if args.mode == 'test':
        accs = evaluate(mini_test, 'test', args.modelfile,
                        num_workers=6, show_progress=True,
                        pertub_scale=args.pertub_scale,
                        num_ensemble=args.num_ensemble,
                        fgsm_epsilon=args.fgsm_epsilon)
        np.set_printoptions(linewidth=1000)
        print("Running Time:", time.time()-start_time)
        print(accs)
        return

    if mini_val is None:
        # Meta-training always uses the miniImagenet train split, so validate
        # on its val split even when meta-testing on another domain.
        # Otherwise the validation step below would have no dataset.
        mini_val = load_split(mini_root, 'val', args.test_iter)

    for epoch in range(args.epoch//10000):
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=4, pin_memory=True)

        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):

            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

            accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                print('epoch:', epoch, 'step:', step, '\ttraining acc:', accs)

            if step % 200 == 0:
                print("Save model", args.modelfile)
                torch.save(maml, args.modelfile)
                accs_val = evaluate(mini_val, 'train_test')
                accs_test = evaluate(mini_test)

                save_modelfile = "{}_{}_{}_{:0.4f}_{:0.4f}.pth".format(args.modelfile, epoch, step, accs_val[-1], accs_test[-1])
                print(save_modelfile)
                torch.save(maml, save_modelfile)
                print("Val:", accs_val)
                print("Test:", accs_test)
コード例 #18
0
def main():
    """Evaluate a MAML model on the validation split of ./miniimagenet.

    Seeds the RNGs, builds the network described by ``config``, restores saved
    weights if a checkpoint file exists, then makes 10 passes over the
    validation episodes. For each pass it prints the mean per-update-step query
    accuracy. At the end it prints the average final-step accuracy across passes.

    Reads the module-level ``args`` namespace (``n_way``, ``k_spt``, ``k_qry``
    and whatever ``Meta`` consumes); it is not passed in as a parameter.
    """
    torch.manual_seed(222)  # seed the CPU RNG so results are reproducible
    torch.cuda.manual_seed_all(222)  # seed every GPU RNG for the same reason
    np.random.seed(222)

    print(args)

    # Layer spec consumed by Meta. conv2d params are presumably
    # [out_ch, in_ch, k_h, k_w, stride, padding] -- confirm against Meta.
    # NOTE(review): 7040 must equal the flattened feature size for the
    # dataset's image size; it does not match a square 84x84 input -- confirm.
    config = [
        ('conv2d', [32, 1, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 1, 0]),
        ('flatten', []),
        ('linear', [args.n_way, 7040])
    ]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    # NOTE(review): the "./models/" prefix appears twice, so this resolves to
    # ./models/models/miniimagenet_maml<n>way_<k>shot.pkl. The value is kept
    # unchanged so the same file is looked up as before -- confirm against
    # where checkpoints are actually saved.
    path = f"./models/./models/miniimagenet_maml{args.n_way}way_{args.k_spt}shot.pkl"
    if os.path.exists(path):
        # load_state_dict() expects a state dict, not a file path, so the file
        # must be deserialized first. Assumes the file holds a state_dict
        # (torch.save(model.state_dict(), ...)) -- TODO confirm.
        maml.load_state_dict(torch.load(path, map_location=device))
        print("load model success")

    # Counts scalar parameters (sum of element counts) of trainable tensors.
    num = sum(np.prod(p.shape) for p in maml.parameters() if p.requires_grad)
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    # NOTE(review): the training split is never used below. It is kept only
    # because its constructor may consume the seeded RNG while sampling
    # episodes; removing it could change which validation episodes are drawn.
    # Confirm before deleting.
    _mini_train = MiniImagenet("./miniimagenet", mode='train', n_way=args.n_way, k_shot=args.k_spt,
                               k_query=args.k_qry,
                               batchsz=10000)
    mini_test = MiniImagenet("./miniimagenet", mode='val', n_way=args.n_way, k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batchsz=100)

    # One episode per batch. shuffle=True re-shuffles on every iteration, so a
    # single loader serves all evaluation passes.
    db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)

    test_accuracy = []
    for _ in range(10):
        accs_all_test = []

        for x_spt, y_spt, x_qry, y_qry in db_test:
            # Drop the DataLoader's leading batch dim of 1 and move to device.
            x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                         x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

            # finetunning returns (accuracies, loss) here; the loss is unused.
            accs, _loss = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
            accs_all_test.append(accs)

        # [b, update_step+1] -> mean over episodes: one accuracy per fine-tuning step
        accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
        print('Test acc:', accs)
        test_accuracy.append(accs[-1])  # accuracy after the final fine-tuning step
    average_accuracy = sum(test_accuracy) / len(test_accuracy)
    print("average accuracy:{}".format(average_accuracy))
コード例 #19
0
def main(args):
    """Meta-train a MAML model on Omniglot, periodically evaluating on test tasks.

    Every step draws one meta-batch of ``args.task_num`` tasks and runs a
    meta-training step, for ``args.epoch`` steps in total. Training
    accuracies are printed every 50 steps. Every 500 steps, roughly 1000 test
    tasks are fine-tuned one at a time and the mean per-update-step accuracy
    is printed.

    Args:
        args: namespace providing at least ``n_way``, ``k_spt``, ``k_qry``,
            ``task_num``, ``imgsz`` and ``epoch`` (plus whatever ``Meta`` reads).
    """
    # Fixed seeds for reproducibility.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    # NOTE(review): per the PyTorch reproducibility notes, cudnn benchmark mode
    # may select non-deterministic algorithms despite the seeds above.
    torch.backends.cudnn.benchmark=True
    print(args)

    # Layer spec consumed by Meta. conv2d params are presumably
    # [out_ch, in_ch, k_h, k_w, stride, padding] -- confirm against Meta.
    # Assuming imgsz=28 and no padding, the spatial size goes 28 -> 13 -> 6 -> 2 -> 1,
    # so flatten yields 64 features, matching the linear input size -- TODO confirm.
    config = [
        ("conv2d", [64, 1, 3, 3, 2, 0]),
        ("relu", [True]),
        ("bn", [64]),
        ("conv2d", [64, 64, 3, 3, 2, 0]),
        ("relu", [True]),
        ("bn", [64]),
        ("conv2d", [64, 64, 3, 3, 2, 0]),
        ("relu", [True]),
        ("bn", [64]),
        ("conv2d", [64, 64, 2, 2, 1, 0]),
        ("relu", [True]),
        ("bn", [64]),
        ("flatten", []),
        ("linear", [args.n_way, 64]),
    ]

    device = torch.device("cuda")
    maml = Meta(args, config).to(device)

    # Episodic sampler; each .next() call returns numpy arrays
    # (converted with torch.from_numpy below) for a batch of task_num tasks.
    db_train = OmniglotNShot(
        "omniglot",
        batchsz=args.task_num,
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        imgsz=args.imgsz,
    )

    # Counts scalar parameters (sum of element counts) of trainable tensors,
    # despite the "tensors" wording in the message.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print("Total trainable tensors:", num)

    # Each iteration is one meta-batch step, not a full pass over the dataset.
    for step in range(args.epoch):

        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = (
            torch.from_numpy(x_spt).to(device),
            torch.from_numpy(y_spt).to(device),
            torch.from_numpy(x_qry).to(device),
            torch.from_numpy(y_qry).to(device),
        )

        # Example shapes (presumably task_num=32, n_way=5, k_spt=1, k_qry=15, imgsz=28):
        # x_spt: shape: 32, 5, 1, 28, 28
        # y_spt: shape: 32, 5
        # x_qry: 32, 75, 1, 28, 28
        # y_qry: 32, 75

        # Meta-training step; returns training accuracies (printed below).
        # Original note: set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 50 == 0:
            print("step:", step, "\ttraining acc:", accs)

        # Periodic evaluation on about 1000 test tasks (1000 // task_num batches).
        if step % 500 == 0:
            accs = []  # reused name: now collects per-task test accuracies
            for _ in range(1000 // args.task_num):
                # test: draw a batch of tasks from the "test" split
                x_spt, y_spt, x_qry, y_qry = db_train.next("test")
                x_spt, y_spt, x_qry, y_qry = (
                    torch.from_numpy(x_spt).to(device),
                    torch.from_numpy(y_spt).to(device),
                    torch.from_numpy(x_qry).to(device),
                    torch.from_numpy(y_qry).to(device),
                )

                # split to single task each time: iterate over the leading
                # (task) dimension and fine-tune on each task separately
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                    accs.append(test_acc)

            # [b, update_step+1] -> mean over tasks: one accuracy per fine-tuning step
            # NOTE(review): if task_num > 1000 the list is empty and the mean is NaN.
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print("Test acc:", accs)