Example #1
def main(args):
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    config = [('conv2d', [64, 1, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 2, 2, 1, 0]), ('relu', [True]), ('bn', [64]),
              ('flatten', []), ('linear', [args.n_way, 64])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    db_train = OmniglotNShot(
        'omniglot',
        batchsz=args.task_num,  # meta-batch size, 32
        n_way=args.n_way,  # n-way, 5
        k_shot=args.k_spt,  # k-shot for support set, 1
        k_query=args.k_qry,  # k-shot for query set, 15
        imgsz=args.imgsz)  # image size, 28 (28x28)

    for step in range(args.epoch):

        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                     torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 50 == 0:
            print('step:', step, '\ttraining acc:', accs)

        if step % 500 == 0:
            accs = []
            for _ in range(1000 // args.task_num):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                             torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

                # split the meta-batch into single tasks
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                        x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one,
                                                x_qry_one, y_qry_one)
                    accs.append(test_acc)

            # [b, update_step+1]
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
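
Note: every main() here reads its hyperparameters from an argparse namespace. A minimal sketch of the setup these examples assume (argument names follow the args.* attributes used above; the default values are illustrative, taken from the inline comments, and may differ per repo):

import argparse

def get_args():
    p = argparse.ArgumentParser()
    p.add_argument('--epoch', type=int, default=40000)      # number of meta-training steps
    p.add_argument('--n_way', type=int, default=5)          # classes per task
    p.add_argument('--k_spt', type=int, default=1)          # support shots per class
    p.add_argument('--k_qry', type=int, default=15)         # query shots per class
    p.add_argument('--imgsz', type=int, default=28)         # image side length
    p.add_argument('--task_num', type=int, default=32)      # tasks per meta-batch
    p.add_argument('--meta_lr', type=float, default=1e-3)   # outer-loop learning rate
    p.add_argument('--update_lr', type=float, default=0.4)  # inner-loop learning rate
    return p.parse_args()

if __name__ == '__main__':
    main(get_args())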
Example #2
def main(args):

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    device = torch.device('cuda')
    maml = Meta(args).to(device)

    db_train = OmniglotNShot(root='E:/meta_learning',
                             batchsz=args.task_num,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             imgsz=args.imgsz)

    for step in range(args.epoch):

        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.FloatTensor(x_spt).to(device), torch.LongTensor(y_spt).to(device), \
                                     torch.FloatTensor(x_qry).to(device), torch.LongTensor(y_qry).to(device)

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)  # task_batch=20

        if step % 50 == 0:
            print('step:', step, '\ttraining acc:', accs)

        if step % 500 == 0:
            accs = []
            for _ in range(1000 // args.task_num):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.FloatTensor(x_spt).to(device), torch.LongTensor(y_spt).to(device), \
                                             torch.FloatTensor(x_qry).to(device), torch.LongTensor(y_qry).to(device)

                # split the meta-batch into single tasks
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                        x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one,
                                                x_qry_one, y_qry_one)
                    accs.append(test_acc)

            # [b, update_step+1] -> [update_step+1,]
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
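
Note: torch.FloatTensor(arr) / torch.LongTensor(arr), as used in Example #2, always copy the NumPy buffer; torch.from_numpy shares it and makes the dtype cast explicit. A small equivalent sketch (the dummy shapes are assumptions matching the comments above):

import numpy as np
import torch

x_np = np.zeros((32, 5, 1, 28, 28), dtype=np.float32)  # dummy support images
y_np = np.zeros((32, 5), dtype=np.int64)               # dummy support labels

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x = torch.from_numpy(x_np).float().to(device)  # .float() is a no-op for float32 input
y = torch.from_numpy(y_np).long().to(device)   # cross-entropy expects int64 targets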
Example #3
def main(args):

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    shared_config = [
        ('conv2d', [64, 1, 3, 3, 2, 0]),
        ('leakyrelu', [.2, True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 3, 3, 2, 0]),
        ('leakyrelu', [.2, True]),
        ('bn', [64]),
    ]

    nway_config = [('conv2d', [64, 1, 3, 3, 2, 0]), ('leakyrelu', [.2, True]),
                   ('bn', [64]), ('conv2d', [64, 64, 3, 3, 2, 0]),
                   ('leakyrelu', [.2, True]), ('bn', [64]),
                   ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]),
                   ('bn', [64]), ('conv2d', [64, 64, 2, 2, 1, 0]),
                   ('relu', [True]), ('bn', [64]), ('flatten', []),
                   ('linear', [args.n_way, 64])]

    # discriminator: takes an image, outputs a single real/fake logit
    discriminator_config = [('conv2d', [64, 1, 3, 3, 2, 0]),
                            ('leakyrelu', [.2, True]), ('bn', [64]),
                            ('conv2d', [64, 64, 3, 3, 2, 0]),
                            ('leakyrelu', [.2, True]), ('bn', [64]),
                            ('conv2d', [64, 64, 3, 3, 2, 0]),
                            ('leakyrelu', [.2, True]), ('bn', [64]),
                            ('conv2d', [64, 64, 2, 2, 1, 0]),
                            ('leakyrelu', [.2, True]), ('bn', [64]),
                            ('flatten', []), ('linear', [1, 64])
                            # don't use a sigmoid at the end
                            ]

    # new gen_config:
    # starts from an image and convolves it into new ones
    gen_config = [
        ('convt2d', [1, 64, 3, 3, 1, 1]),
        ('leakyrelu', [.2, True]),
        ('bn', [64]),
        ('random_proj', [100, 28, 64]),
        ('convt2d', [128, 64, 3, 3, 1, 1]),
        #('convt2d', [1, 128, 4, 4, 2, 1]), # [ch_in, ch_out, kernel_sz, kernel_sz, stride, padding]
        ('leakyrelu', [.2, True]),  # slope argument implies leakyrelu was intended, not relu
        ('bn', [64]),
        # ('encode', [1024, 64*28*28]),
        # ('decode', [64*28*28, 1024]),
        ('leakyrelu', [.2, True]),  # see note above
        ('conv2d', [64, 64, 3, 3, 1, 1]),
        #('convt2d', [1, 128, 4, 4, 2, 1]), # [ch_in, ch_out, kernel_sz, kernel_sz, stride, padding]
        ('leakyrelu', [.2, True]),  # see note above
        ('bn', [64]),
        ('conv2d', [1, 64, 3, 3, 1, 1]),
        ('sigmoid', [True])
    ]

    # old gen_config
    # gen_config = [
    #     ('random_proj', [100, 512, 64, 7]), # [latent_dim, emb_size, ch_out, h_out/w_out]
    #     # img: (64, 7, 7)
    #     ('convt2d', [64, 32, 4, 4, 2, 1]), # [ch_in, ch_out, kernel_sz, kernel_sz, stride, padding]
    #     ('bn', [32]),
    #     ('relu', [True]),
    #     # img: (32, 14, 14)
    #     ('convt2d', [32, 1, 4, 4, 2, 1]),
    #     # img: (1, 28, 28)
    #     ('sigmoid', [True])
    # ]

    # if args.condition_discrim:
    #     discriminator_config = [
    #         ('condition', [512, 1, 6]), # [emb_dim, emb_ch_out, h_out/w_out]
    #         ('conv2d', [128, 65, 2, 2, 1, 0]),
    #         ('leakyrelu', [.2, True]),
    #         ('bn', [128]),
    #         ('conv2d', [128, 128, 2, 2, 1, 0]),
    #         ('leakyrelu', [.2, True]),
    #         ('bn', [128]),
    #         ('flatten', []),
    #         ('linear', [1, 2048])
    #     ]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mamlGAN = MetaGAN(args, shared_config, nway_config, discriminator_config,
                      gen_config).to(device)

    tmp = filter(lambda x: x.requires_grad, mamlGAN.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(mamlGAN)
    print('Total trainable tensors:', num)

    db_train = OmniglotNShot('omniglot',
                             batchsz=args.tasks_per_batch,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             img_sz=args.img_sz)

    save_model = not args.no_save
    if save_model:
        now = datetime.now().replace(second=0, microsecond=0)
        path = "results/" + str(now) + "_omni"
        mkdir_p(path)
        file = open(path + '/architecture.txt', 'w+')
        file.write("shared_config = " + json.dumps(shared_config) + "\n" +
                   "nway_config = " + json.dumps(nway_config) + "\n" +
                   "discriminator_config = " +
                   json.dumps(discriminator_config) + "\n" + "gen_config = " +
                   json.dumps(gen_config) + "\n" + "learn_inner_lr = " +
                   str(args.learn_inner_lr) + "\n" + "condition_discrim = " +
                   str(args.condition_discrim))
        file.close()
    for step in range(args.epoch):
        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                     torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = mamlGAN(x_spt, y_spt, x_qry, y_qry)

        if step % 30 == 0:
            print("step " + str(step))
            for key in accs.keys():
                print(key + ": " + str(accs[key]))
            if save_model:
                save_train_accs(path, accs, int(step))
        if step % 500 == 0:
            print("testing")
            accs = []
            imgs = []
            for _ in range(1000 // args.tasks_per_batch):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                             torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

                # split the meta-batch into single tasks
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                        x_spt, y_spt, x_qry, y_qry):
                    test_acc, ims = mamlGAN.finetunning(
                        x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                    accs.append(test_acc)
                    imgs.append(x_spt_one.cpu().detach().numpy())
                    imgs.append(ims.cpu().detach().numpy())
                    if args.single_fast_test:
                        break
                if args.single_fast_test:
                    break

            accs = np.array(accs).mean(axis=0).astype(np.float16)
            if save_model:
                save_test_accs(path, accs, int(step))
                imgs = np.array(imgs)
                save_imgs(path, imgs, step)

                torch.save({'model_state_dict': mamlGAN.state_dict()},
                           path + "/model_step" + str(step))
                # to load, do this:
                # checkpoint = torch.load(path + "/model_step" + str(step))
                # mamlGAN.load_state_dict(checkpoint['model_state_dict'])

            # accs was already reduced from [b, update_steps+1] to [update_steps+1] above
            print('Test acc:', accs)
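
Note: each (name, params) tuple in these config lists follows the layout given in the inline comments, e.g. conv2d -> [ch_out, ch_in, kernel, kernel, stride, padding] and linear -> [dim_out, dim_in]. The examples hand the config to a meta-learner that builds functional layers itself; purely as an illustration (an assumption, not the repo's Learner class, and custom ops such as random_proj are omitted), the same list could be materialized into plain nn modules:

import torch.nn as nn

def build_sequential(config):
    layers = []
    for name, p in config:
        if name == 'conv2d':                 # p = [ch_out, ch_in, k, k, stride, pad]
            layers.append(nn.Conv2d(p[1], p[0], p[2], stride=p[4], padding=p[5]))
        elif name == 'convt2d':              # p = [ch_in, ch_out, k, k, stride, pad]
            layers.append(nn.ConvTranspose2d(p[0], p[1], p[2], stride=p[4], padding=p[5]))
        elif name == 'bn':
            layers.append(nn.BatchNorm2d(p[0]))
        elif name == 'relu':
            layers.append(nn.ReLU(inplace=p[0]))
        elif name == 'leakyrelu':
            layers.append(nn.LeakyReLU(p[0], inplace=p[1]))
        elif name == 'flatten':
            layers.append(nn.Flatten())
        elif name == 'linear':               # p = [dim_out, dim_in]
            layers.append(nn.Linear(p[1], p[0]))
        elif name == 'sigmoid':
            layers.append(nn.Sigmoid())
        else:
            raise NotImplementedError(name)
    return nn.Sequential(*layers)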
Example #4
def main():

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    config = [('conv2d', [64, 1, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 2, 2, 1, 0]), ('relu', [True]), ('bn', [64]),
              ('flatten', []), ('linear', [args.n_way + 1, 64])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here is the meta-batch size (tasks per meta-update)
    db_train = OmniglotNShot('omniglot',
                             batchsz=args.task_num,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             imgsz=args.imgsz)

    save_path = os.getcwd() + '/data/omniglot/model_batchsz' + str(
        args.k_spt) + '_stepsz' + str(args.update_lr) + '_epoch'

    for step in range(args.epoch):

        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                     torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs, al_accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 50 == 0:
            print('step:', step, '\ttraining acc:', accs, '\tAL acc:', al_accs)

        if step % 500 == 0:
            al_accs, accs = [], []
            for _ in range(1000 // args.task_num):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                             torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

                # split the meta-batch into single tasks
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                        x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one,
                                                x_qry_one, y_qry_one)
                    al_accs.append(maml.al_test(x_qry_one, y_qry_one))
                    accs.append(test_acc)

            # [b, update_step+1]
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
            al_accs = np.array(al_accs).mean(axis=0).astype(np.float16)
            print('AL acc:', al_accs)
            torch.save(maml.state_dict(), save_path + str(step) + "_al.pt")
Example #5
def main(args):

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    config = [('conv2d', [64, 1, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 2, 2, 1, 0]), ('relu', [True]), ('bn', [64]),
              ('flatten', []), ('linear', [args.n_way, 64])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    db_train = OmniglotNShot('omniglot',
                             batchsz=args.task_num,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             imgsz=args.imgsz)

    print(args.task_num)  # 32
    print(args.n_way)  # 5
    print(args.k_spt)  # 1
    print(args.k_qry)  # 15
    print(args.imgsz)  # 28

    for step in range(args.epoch):
        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                    torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        # labels must be int64 (torch.long) for the cross-entropy loss;
        # .long() avoids re-wrapping an existing tensor with torch.tensor()
        y_spt = y_spt.long()
        y_qry = y_qry.long()
        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 50 == 0:
            print('step:', step, '\ttraining acc:', accs)

        if step % 500 == 0:
            accs = []
            for _ in range(1000 // args.task_num):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                            torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

                y_spt = y_spt.long()  # same int64 cast as in the training loop
                y_qry = y_qry.long()

                # split the meta-batch into single tasks
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                        x_spt, y_spt, x_qry, y_qry):
                    print(x_spt_one.size())
                    print(y_spt_one.size())
                    print(x_qry_one.size())
                    print(y_qry_one.size())
                    test_acc = maml.finetunning(x_spt_one, y_spt_one,
                                                x_qry_one, y_qry_one)
                    accs.append(test_acc)

            # [b, update_step+1]
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
Example #6
def main(args):

    config = [('conv2d', [64, 1, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 2, 2, 1, 0]), ('relu', [True]), ('bn', [64]),
              ('flatten', []), ('linear', [args.n_way, 64])]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    maml = Meta(args, config, device).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))

    print('Total trainable tensors:', num)

    db_train = OmniglotNShot('./',
                             batchsz=args.task_num,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             imgsz=args.imgsz)

    for step in range(args.epoch):

        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).long().to(device), \
                                     torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).long().to(device)

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)
        print('trainstep:', step, '\ttraining acc:', accs)

        if (step + 1) % 500 == 0:
            accs = []
            for _ in range(1000 // args.task_num):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).long().to(device), \
                                             torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).long().to(device)

                # split the meta-batch into single tasks
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                        x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one,
                                                x_qry_one, y_qry_one)
                    accs.append(test_acc)

            # [b, update_step+1]
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)

    ##############################
    for i in range(args.prune_iteration):
        # prune
        print("the {}th prune step".format(i))
        x_spt, y_spt, x_qry, y_qry = db_train.getHoleTrain()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).long().to(device), \
                                 torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).long().to(device)
        maml.prune(x_spt, y_spt, x_qry, y_qry, args.prune_number_one_epoch,
                   args.max_prune_number)

        # fine-tuning
        print("start finetuning....")
        finetune_epoch = args.finetune_epoch
        # double the finetuning budget on the final prune iteration
        if i == args.prune_iteration - 1:
            finetune_epoch *= 2

        for step in range(finetune_epoch):
            x_spt, y_spt, x_qry, y_qry = db_train.next()
            x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).long().to(device), \
                                         torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).long().to(device)
            accs = maml(x_spt, y_spt, x_qry, y_qry, finetune=True)

            print('finetune step:', step, '\ttraining acc:', accs)

        # print the test accuracy after pruning
        print("start testing....")
        accs = []
        for _ in range(1000 // args.task_num):
            # test
            x_spt, y_spt, x_qry, y_qry = db_train.next('test')
            x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).long().to(device), \
                                         torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).long().to(device)

            # split the meta-batch into single tasks
            for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                    x_spt, y_spt, x_qry, y_qry):
                test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one,
                                            y_qry_one)
                accs.append(test_acc)

        # [b, update_step+1]
        accs = np.array(accs).mean(axis=0).astype(np.float16)
        print('Test acc:', accs)
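
Note: maml.prune(...) in Example #6 is repo-specific and its implementation is not shown. As a rough, hedged analogue of what magnitude-based weight pruning looks like with PyTorch's built-in utilities (an assumption about the intent, not the repo's code):

import torch.nn as nn
import torch.nn.utils.prune as prune

def magnitude_prune(model: nn.Module, amount: float = 0.2):
    # collect every conv/linear weight, then prune the globally smallest
    # `amount` fraction of them by L1 magnitude
    params = [(m, 'weight') for m in model.modules()
              if isinstance(m, (nn.Conv2d, nn.Linear))]
    prune.global_unstructured(params,
                              pruning_method=prune.L1Unstructured,
                              amount=amount)
    # prune.remove(module, 'weight') would make the masks permanent
    # once finetuning is finished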
Example #7
def test_progress(args, net, device, viz=None, global_step=0):
    """
    Plot AMI/ARS/accuracy with respect to training epochs.
    :param args:
    :param net:
    :param device:
    :param viz:
    :return:
    """
    if args.resume is None:
        print('No ckpt file specified! Make sure you are training!')

    exp = args.exp

    if viz is None:
        viz = visdom.Visdom(env='test')
    visualh = VisualH(viz)

    print('Testing now...')
    output_dir = os.path.join(args.test_dir, args.exp)
    # create test_dir/exp (makedirs also creates test_dir itself if needed)
    os.makedirs(output_dir, exist_ok=True)

    # clustering, visualization and classification
    db_test = OmniglotNShot('db/omniglot',
                            batchsz=1,
                            n_way=args.n_way,
                            k_shot=args.k_spt,
                            k_query=args.k_qry,
                            imgsz=args.imgsz)

    h_qry0_ami, h_qry0_ars, h_qry1_ami, h_qry1_ars = 0, 0, 0, 0
    acc0, acc1 = [], []

    for batchidx in range(args.test_episode_num):
        spt_x, spt_y, qry_x, qry_y = db_test.next('test')
        spt_x, spt_y, qry_x, qry_y = torch.from_numpy(spt_x).to(device), torch.from_numpy(spt_y).to(device), \
                                     torch.from_numpy(qry_x).to(device), torch.from_numpy(qry_y).to(device)
        assert spt_x.size(0) == 1
        spt_x, spt_y, qry_x, qry_y = spt_x.squeeze(0), spt_y.squeeze(
            0), qry_x.squeeze(0), qry_y.squeeze(0)

        # we can get the representation before first update, after k update
        # and test the representation on merged(test_spt, test_qry) set
        h_spt0, h_spt1, h_qry0, h_qry1, _, new_net = net.finetuning(
            spt_x, spt_y, qry_x, qry_y, args.finetuning_steps, None)

        if batchidx == 0:
            visualh.update(h_spt0, h_spt1, h_qry0, h_qry1, spt_y, qry_y,
                           global_step)

        # we will use the acquired representation to cluster.
        # h_spt: [sptsz, h_dim]
        # h_qry: [qrysz, h_dim]
        h_qry0_np = h_qry0.detach().cpu().numpy()
        h_qry1_np = h_qry1.detach().cpu().numpy()
        qry_y_np = qry_y.detach().cpu().numpy()
        h_qry0_pred = cluster.KMeans(n_clusters=args.n_way,
                                     random_state=0).fit(h_qry0_np).labels_
        h_qry1_pred = cluster.KMeans(n_clusters=args.n_way,
                                     random_state=0).fit(h_qry1_np).labels_
        h_qry0_ami += metrics.adjusted_mutual_info_score(qry_y_np, h_qry0_pred)
        h_qry0_ars += metrics.adjusted_rand_score(qry_y_np, h_qry0_pred)
        h_qry1_ami += metrics.adjusted_mutual_info_score(qry_y_np, h_qry1_pred)
        h_qry1_ars += metrics.adjusted_rand_score(qry_y_np, h_qry1_pred)

        h_qry0_cm = metrics.cluster.contingency_matrix(h_qry0_pred, qry_y_np)
        h_qry1_cm = metrics.cluster.contingency_matrix(h_qry1_pred, qry_y_np)
        # viz.heatmap(X=h_qry0_cm, win=args.exp+' h_qry0_cm', opts=dict(title=args.exp+' h_qry0_cm:%d'%batchidx,
        #                                                               colormap='Electric'))
        # viz.heatmap(X=h_qry1_cm, win=args.exp+' h_qry1_cm', opts=dict(title=args.exp+' h_qry1_cm:%d'%batchidx,
        #                                                               colormap='Electric'))

        # return is a list of [acc_step0, acc_step1 ,...]
        acc0.append(
            net.classify_train(h_spt0,
                               spt_y,
                               h_qry0,
                               qry_y,
                               use_h=True,
                               train_step=args.classify_steps))
        acc1.append(
            net.classify_train(h_spt1,
                               spt_y,
                               h_qry1,
                               qry_y,
                               use_h=True,
                               train_step=args.classify_steps))

        if batchidx == 0:
            spt_x_hat0 = net.forward_ae(spt_x[:64])
            qry_x_hat0 = net.forward_ae(qry_x[:64])
            spt_x_hat1 = new_net.forward_ae(spt_x[:64])
            qry_x_hat1 = new_net.forward_ae(qry_x[:64])
            viz.images(qry_x[:64],
                       nrow=8,
                       win=exp + 'qry_x',
                       opts=dict(title=exp + 'qry_x'))
            # viz.images(spt_x_hat0, nrow=8, win=exp+'spt_x_hat0', opts=dict(title=exp+'spt_x_hat0'))
            viz.images(qry_x_hat0,
                       nrow=8,
                       win=exp + 'qry_x_hat0',
                       opts=dict(title=exp + 'qry_x_hat0'))
            # viz.images(spt_x_hat1, nrow=8, win=exp+'spt_x_hat1', opts=dict(title=exp+'spt_x_hat1'))
            viz.images(qry_x_hat1,
                       nrow=8,
                       win=exp + 'qry_x_hat1',
                       opts=dict(title=exp + 'qry_x_hat1'))

        if batchidx > 0:
            break  # NOTE: stops after two episodes (debug shortcut); remove to evaluate all episodes


    h_qry0_ami, h_qry0_ars, h_qry1_ami, h_qry1_ars = h_qry0_ami / (batchidx + 1), h_qry0_ars / (batchidx + 1), \
                                                     h_qry1_ami / (batchidx + 1), h_qry1_ars / (batchidx + 1)
    # [[episode1], [episode2], ...] = [N, steps] => [steps]
    acc0, acc1 = np.array(acc0).mean(axis=0), np.array(acc1).mean(axis=0)

    print('ami:', h_qry0_ami, h_qry1_ami)
    print('ars:', h_qry0_ars, h_qry1_ars)
    viz.line([[h_qry0_ami, h_qry1_ami]], [global_step],
             win=exp + 'ami_on_qry01',
             update='append')
    viz.line([[h_qry0_ars, h_qry1_ars]], [global_step],
             win=exp + 'ars_on_qry01',
             update='append')
    print('acc:\n', acc0, '\n', acc1)
    viz.line([[acc0[-1], acc1[-1]]], [global_step],
             win=exp + 'acc_on_qry01',
             update='append')
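
Note: AMI and ARS in Example #7 measure how well an unsupervised clustering of the query embeddings recovers the true class partition. A self-contained toy version of the bookkeeping above (random embeddings, so both scores land near 0):

import numpy as np
from sklearn import cluster, metrics

h = np.random.randn(75, 64)        # stand-in query embeddings, [qrysz, h_dim]
y = np.repeat(np.arange(5), 15)    # ground-truth 5-way labels, 15 per class
pred = cluster.KMeans(n_clusters=5, random_state=0).fit(h).labels_
# both scores are chance-corrected: ~0 for random embeddings, 1.0 for a perfect clustering
print(metrics.adjusted_mutual_info_score(y, pred))
print(metrics.adjusted_rand_score(y, pred))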
Example #8
def main(args):

    args = update_args(args)

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    device = torch.device('cuda')
    if args.is_meta:
        # optimizer has been embedded in model.
        net = MetaAE(args)
        # model_parameters = filter(lambda p: p.requires_grad, net.learner.parameters())
        # params = sum([np.prod(p.size()) for p in model_parameters])
        # print('Total params:', params)
        tmp = filter(lambda x: x.requires_grad, net.learner.parameters())
        num = sum(map(lambda x: np.prod(x.shape), tmp))
        print('Total trainable variables:', num)
    else:
        net = AE(args, use_logits=True)
        optimizer = optim.Adam(list(net.encoder.parameters()) +
                               list(net.decoder.parameters()),
                               lr=args.meta_lr)

        tmp = filter(
            lambda x: x.requires_grad,
            list(net.encoder.parameters()) + list(net.decoder.parameters()))
        num = sum(map(lambda x: np.prod(x.shape), tmp))
        print('Total trainable variables:', num)

    net.to(device)
    print(net)

    print('=' * 15, 'Experiment:', args.exp, '=' * 15)
    print(args)

    if args.h_dim == 2:
        # borrowed from https://github.com/fastforwardlabs/vae-tf/blob/master/plot.py
        h_range = np.rollaxis(
            np.mgrid[args.h_range:-args.h_range:args.h_nrow * 1j,
                     args.h_range:-args.h_range:args.h_nrow * 1j], 0, 3)
        # [b, q_h]
        h_manifold = torch.from_numpy(h_range.reshape([-1,
                                                       2])).to(device).float()
        print('h_manifold:', h_manifold.shape)
    else:
        h_manifold = None

    # try to resume from ckpt.mdl file
    epoch_start = 0
    if args.resume is not None:
        # ckpt/normal-fc-vae_640_2018-11-20_09:58:58.mdl
        mdl_file = args.resume
        epoch_start = int(mdl_file.split('_')[-3])
        net.load_state_dict(torch.load(mdl_file))
        print('Resume from:', args.resume, 'epoch/batches:', epoch_start)
    else:
        print('Training/test from scratch...')

    if args.test:
        assert args.resume is not None
        test.test_ft_steps(args, net, device)
        return

    vis = visdom.Visdom(env=args.exp)
    visualh = VisualH(vis)
    vis.line([[0, 0, 0]], [epoch_start],
             win=args.exp + 'train_loss',
             opts=dict(title=args.exp + 'train_qloss',
                       legend=['loss', '-lklh', 'kld'],
                       xlabel='global_step'))

    # for test_progress
    vis.line([[0, 0]], [epoch_start],
             win=args.exp + 'acc_on_qry01',
             opts=dict(title=args.exp + 'acc_on_qry01',
                       legend=['h_qry0', 'h_qry1'],
                       xlabel='global_step'))
    vis.line([[0, 0]], [epoch_start],
             win=args.exp + 'ami_on_qry01',
             opts=dict(title=args.exp + 'ami_on_qry01',
                       legend=['h_qry0', 'h_qry1'],
                       xlabel='global_step'))
    vis.line([[0, 0]], [epoch_start],
             win=args.exp + 'ars_on_qry01',
             opts=dict(title=args.exp + 'ars_on_qry01',
                       legend=['h_qry0', 'h_qry1'],
                       xlabel='global_step'))

    # 1. train
    db_train = OmniglotNShot('db/omniglot',
                             batchsz=args.task_num,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             imgsz=args.imgsz)

    # epoch = batch number here.
    for epoch in range(epoch_start, args.train_episode_num):
        spt_x, spt_y, qry_x, qry_y = db_train.next()
        spt_x, spt_y, qry_x, qry_y = torch.from_numpy(spt_x).to(device), torch.from_numpy(spt_y).to(device), \
                                     torch.from_numpy(qry_x).to(device), torch.from_numpy(qry_y).to(device)

        if args.is_meta:  # for meta
            loss_optim, losses_q, likelihoods_q, klds_q = net(
                spt_x, spt_y, qry_x, qry_y)

            if epoch % 300 == 0:

                if args.is_vae:
                    # print(losses_q, likelihoods_q, klds_q)
                    vis.line([[
                        losses_q[-1].item(), -likelihoods_q[-1].item(),
                        klds_q[-1].item()
                    ]], [epoch],
                             win=args.exp + 'train_loss',
                             update='append')
                    print(epoch)
                    print(
                        'loss_q:',
                        torch.stack(losses_q).detach().cpu().numpy().astype(
                            np.float16))
                    print(
                        'lkhd_q:',
                        torch.stack(
                            likelihoods_q).detach().cpu().numpy().astype(
                                np.float16))
                    print(
                        'klds_q:',
                        torch.stack(klds_q).cpu().detach().numpy().astype(
                            np.float16))
                else:
                    # print(losses_q, likelihoods_q, klds_q)
                    vis.line([[losses_q[-1].item(), 0, 0]], [epoch],
                             win=args.exp + 'train_loss',
                             update='append')
                    print(
                        epoch,
                        torch.stack(losses_q).detach().cpu().numpy().astype(
                            np.float16))

        else:  # for normal vae/ae

            loss_optim, _, likelihood, kld = net(spt_x, spt_y, qry_x, qry_y)
            optimizer.zero_grad()
            loss_optim.backward()
            torch.nn.utils.clip_grad_norm_(
                list(net.encoder.parameters()) +
                list(net.decoder.parameters()), 10)
            optimizer.step()

            if epoch % 300 == 0:

                print(epoch, loss_optim.item())
                if not args.is_vae:
                    vis.line([[loss_optim.item(), 0, 0]], [epoch],
                             win='train_loss',
                             update='append')
                else:
                    vis.line(
                        [[loss_optim.item(), -likelihood.item(),
                          kld.item()]], [epoch],
                        win='train_loss',
                        update='append')

        if epoch % 3000 == 0:
            # [qrysz, 1, 64, 64] => [qrysz, 1, 64, 64]
            x_hat = net.forward_ae(qry_x[0])
            vis.images(x_hat,
                       nrow=args.k_qry,
                       win='train_x_hat',
                       opts=dict(title='train_x_hat'))
            test.test_progress(args, net, device, vis, epoch)

        # save checkpoint.
        if epoch % 10000 == 0:
            date_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            mdl_file = os.path.join(
                args.ckpt_dir,
                args.exp + '_%d' % epoch + '_' + date_str + '.mdl')
            torch.save(net.state_dict(), mdl_file)
            print('Saved into ckpt file:', mdl_file)

    # save checkpoint.
    date_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    mdl_file = os.path.join(
        args.ckpt_dir,
        args.exp + '_%d' % args.train_episode_num + '_' + date_str + '.mdl')
    torch.save(net.state_dict(), mdl_file)
    print('Saved final ckpt file:', mdl_file)
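
Note: Example #8 recovers the resume epoch by parsing the checkpoint filename (mdl_file.split('_')[-3]), which breaks if the naming scheme changes. A hedged alternative (an assumed format, not the author's) stores the epoch inside the checkpoint dict, much as Example #3 does with model_state_dict:

import torch

def save_ckpt(net, epoch, path):
    torch.save({'epoch': epoch, 'state_dict': net.state_dict()}, path)

def load_ckpt(net, path):
    ckpt = torch.load(path)
    net.load_state_dict(ckpt['state_dict'])
    return ckpt['epoch']  # resume counter comes from the file contents, not its name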
Example #9
def main(args):
    if not os.path.exists('./logs'):
        os.mkdir('./logs')
    logfile = os.path.sep.join(('.', 'logs', f'omniglot_way[{args.n_way}]_shot[{args.k_spt}].json'))
    if args.write_log:
        log_fp = open(logfile, 'w')  # JSON logs are text, so open in text mode, not 'wb'
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    config = [
        ('conv2d', [64, 1, 3, 3, 2, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 3, 3, 2, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 3, 3, 2, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 2, 2, 1, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('flatten', []),
        ('linear', [args.n_way, 64])
    ]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    path = os.sep.join((os.path.dirname(__file__), 'dataset', 'omniglot'))
    db_train = OmniglotNShot(path,
                             batchsz=args.task_num,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             imgsz=args.imgsz)

    for step in range(args.epoch):
        # fetch one meta-batch of episodes; the sampling logic lives in the OmniglotNShot class
        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                     torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 50 == 0:
            print('step:', step, '\ttraining acc:', accs)

        if step % 500 == 0:
            accs = []
            for _ in range(1000 // args.task_num):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                             torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

                # split the meta-batch into single tasks
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                    accs.append(test_acc)

            # [b, update_step+1]
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
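
Note: Example #9 opens log_fp but the snippet never writes to it. One hedged way to use it (the record layout is an assumption, not the author's format; log_fp is the text-mode handle opened above) is to append one JSON line per evaluation round:

import json

record = {'step': 500, 'test_acc': [0.2, 0.9, 0.95]}  # illustrative values
log_fp.write(json.dumps(record) + '\n')
log_fp.flush()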
Example #10
def main(args):
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    torch.backends.cudnn.benchmark = True
    print(args)

    config = [
        ("conv2d", [64, 1, 3, 3, 2, 0]),
        ("relu", [True]),
        ("bn", [64]),
        ("conv2d", [64, 64, 3, 3, 2, 0]),
        ("relu", [True]),
        ("bn", [64]),
        ("conv2d", [64, 64, 3, 3, 2, 0]),
        ("relu", [True]),
        ("bn", [64]),
        ("conv2d", [64, 64, 2, 2, 1, 0]),
        ("relu", [True]),
        ("bn", [64]),
        ("flatten", []),
        ("linear", [args.n_way, 64]),
    ]

    device = torch.device("cuda")
    maml = Meta(args, config).to(device)

    db_train = OmniglotNShot(
        "omniglot",
        batchsz=args.task_num,
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        imgsz=args.imgsz,
    )

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print("Total trainable tensors:", num)

    for step in range(args.epoch):

        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = (
            torch.from_numpy(x_spt).to(device),
            torch.from_numpy(y_spt).to(device),
            torch.from_numpy(x_qry).to(device),
            torch.from_numpy(y_qry).to(device),
        )

        # x_spt: [32, 5, 1, 28, 28]   (task_num, n_way * k_spt, C, H, W)
        # y_spt: [32, 5]
        # x_qry: [32, 75, 1, 28, 28]  (task_num, n_way * k_qry, C, H, W)
        # y_qry: [32, 75]

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 50 == 0:
            print("step:", step, "\ttraining acc:", accs)

        if step % 500 == 0:
            accs = []
            for _ in range(1000 // args.task_num):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next("test")
                x_spt, y_spt, x_qry, y_qry = (
                    torch.from_numpy(x_spt).to(device),
                    torch.from_numpy(y_spt).to(device),
                    torch.from_numpy(x_qry).to(device),
                    torch.from_numpy(y_qry).to(device),
                )

                # split the meta-batch into single tasks
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                    accs.append(test_acc)

            # [b, update_step+1]
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print("Test acc:", accs)