Example #1
def main(args):
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    config = [('conv2d', [64, 1, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 2, 2, 1, 0]), ('relu', [True]), ('bn', [64]),
              ('flatten', []), ('linear', [args.n_way, 64])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    db_train = OmniglotNShot(
        'omniglot',
        batchsz=args.task_num,  # meta-batch size, 32
        n_way=args.n_way,  # n-way, 5
        k_shot=args.k_spt,  # k-shot for support set, 1
        k_query=args.k_qry,  # k-shot for query set, 15
        imgsz=args.imgsz)  # image size, 28 (28x28)

    for step in range(args.epoch):

        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                     torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 50 == 0:
            print('step:', step, '\ttraining acc:', accs)

        if step % 500 == 0:
            accs = []
            for _ in range(1000 // args.task_num):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                             torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

                # split to single task each time
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                        x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one,
                                                x_qry_one, y_qry_one)
                    accs.append(test_acc)

            # [b, update_step+1]
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
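
The `config` list above encodes the 4-conv Omniglot backbone as `(name, params)` tuples; for `'conv2d'` the parameter order is `[ch_out, ch_in, kernel_h, kernel_w, stride, padding]` (matching the layout the `convt2d` comment in Example #6 spells out), and for `'linear'` it is `[out_features, in_features]`. A minimal sketch of what such a list describes follows; the real `Meta`/`Learner` class manages its weights manually for the inner-loop updates, so this builder is only an illustration of the tuple layout:

import torch.nn as nn

def build_from_config(config):
    # illustrative only: expand (name, params) tuples into nn modules
    layers = []
    for name, params in config:
        if name == 'conv2d':
            ch_out, ch_in, k_h, k_w, stride, pad = params
            layers.append(nn.Conv2d(ch_in, ch_out, (k_h, k_w), stride, pad))
        elif name == 'relu':
            layers.append(nn.ReLU(inplace=params[0]))
        elif name == 'bn':
            layers.append(nn.BatchNorm2d(params[0]))
        elif name == 'flatten':
            layers.append(nn.Flatten())
        elif name == 'linear':
            out_f, in_f = params
            layers.append(nn.Linear(in_f, out_f))
    return nn.Sequential(*layers)
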
Example #2
def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument('-n', help='n way', default=5)
    argparser.add_argument('-k', help='k shot', default=1)
    argparser.add_argument('-b', help='batch size', default=32)
    argparser.add_argument('-l', help='meta learning rate', default=1e-3)
    args = argparser.parse_args()
    n_way = int(args.n)
    k_shot = int(args.k)
    meta_batchsz = int(args.b)
    meta_lr = float(args.l)
    train_lr = 0.4

    k_query = 15
    imgsz = 84
    mdl_file = 'ckpt/omniglot%d%d.mdl' % (n_way, k_shot)
    print('omniglot: %d-way %d-shot meta-lr:%f, train-lr:%f' %
          (n_way, k_shot, meta_lr, train_lr))

    device = torch.device('cuda:0')
    net = MAML(n_way, k_shot, k_query, meta_batchsz, 5, meta_lr, train_lr,
               device)
    print(net)

    # batchsz here means total episode number
    db = OmniglotNShot('omniglot',
                       batchsz=meta_batchsz,
                       n_way=n_way,
                       k_shot=k_shot,
                       k_query=k_query,
                       imgsz=imgsz)

    for step in range(10000000):

        # train
        support_x, support_y, query_x, query_y = db.get_batch('train')
        support_x = torch.from_numpy(support_x).float().transpose(
            2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1).to(device)
        query_x = torch.from_numpy(query_x).float().transpose(2, 4).transpose(
            3, 4).repeat(1, 1, 3, 1, 1).to(device)
        support_y = torch.from_numpy(support_y).long().to(device)
        query_y = torch.from_numpy(query_y).long().to(device)

        accs = net(support_x, support_y, query_x, query_y, training=True)

        if step % 20 == 0:
            print(step, '\t', accs)

        if step % 1000 == 0:
            # test
            pass
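
The `transpose(2, 4).transpose(3, 4)` chain above converts the channels-last batches returned by `get_batch` into the channels-first layout PyTorch expects, and `repeat(1, 1, 3, 1, 1)` tiles the single gray channel to three. A standalone sketch, assuming `get_batch` returns `[meta_batch, setsz, h, w, c]` arrays:

import torch

x = torch.rand(32, 5, 28, 28, 1)       # [b, setsz, h, w, c]
x = x.transpose(2, 4).transpose(3, 4)  # -> [b, setsz, c, h, w]
x = x.repeat(1, 1, 3, 1, 1)            # gray -> 3 channels
print(x.shape)                         # torch.Size([32, 5, 3, 28, 28])
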
Example #3
def main(args):

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    device = torch.device('cuda')
    maml = Meta(args).to(device)

    db_train = OmniglotNShot(root='E:/meta_learning',
                             batchsz=args.task_num,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             imgsz=args.imgsz)

    for step in range(args.epoch):

        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.FloatTensor(x_spt).to(device), torch.LongTensor(y_spt).to(device), \
                                     torch.FloatTensor(x_qry).to(device), torch.LongTensor(y_qry).to(device)

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)  # task_batch=20

        if step % 50 == 0:
            print('step:', step, '\t training acc:', accs)

        if step % 500 == 0:
            accs = []
            for _ in range(1000 // args.task_num):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.FloatTensor(x_spt).to(device), torch.LongTensor(y_spt).to(device), \
                                             torch.FloatTensor(x_qry).to(device), torch.LongTensor(y_qry).to(device)

                # split to single task each time
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                        x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one,
                                                x_qry_one, y_qry_one)
                    accs.append(test_acc)

            # [b, update_step+1] -> [update_step+1,]
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
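
Unlike the `torch.from_numpy` calls in Example #1, the `torch.FloatTensor`/`torch.LongTensor` constructors used here also cast the dtype, which is why no separate `.long()` conversion is needed for the labels. Note that `torch.from_numpy` instead shares memory with the source array rather than copying:

import numpy as np
import torch

a = np.zeros(3, dtype=np.float32)
t = torch.from_numpy(a)   # zero-copy view of the numpy buffer
a[0] = 1.0
print(t[0].item())        # 1.0 -- the tensor sees the in-place change
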
Example #4
def main():
	meta_batchsz = 32
	n_way = 20
	k_shot = 1
	k_query = k_shot
	meta_lr = 1e-3
# 	meta_lr = 1
	num_updates = 5
	dataset = 'omniglot'



	if dataset == 'omniglot':
		imgsz = 28
		db = OmniglotNShot('dataset', batchsz=meta_batchsz, n_way=n_way, k_shot=k_shot, k_query=k_query, imgsz=imgsz)

	elif dataset == 'mini-imagenet':
		imgsz = 84
		# the dataset loaders are different from omniglot to mini-imagenet. for omniglot, it just has one loader to use
		# get_batch(train or test) to get different batch.
		# for mini-imagenet, it should have two dataloader, one is train_loader and another is test_loader.
		mini = MiniImagenet('../mini-imagenet/', mode='train', n_way=n_way, k_shot=k_shot, k_query=k_query,
		                    batchsz=10000, resize=imgsz)
		db = DataLoader(mini, meta_batchsz, shuffle=True, num_workers=4, pin_memory=True)
		mini_test = MiniImagenet('../mini-imagenet/', mode='test', n_way=n_way, k_shot=k_shot, k_query=k_query,
		                    batchsz=1000, resize=imgsz)
		db_test = DataLoader(mini_test, meta_batchsz, shuffle=True, num_workers=2, pin_memory=True)

	else:
		raise NotImplementedError


	meta = MetaLearner(Naive, (n_way, imgsz), n_way=n_way, k_shot=k_shot, meta_batchsz=meta_batchsz, beta=meta_lr,
	                   num_updates=num_updates).cuda()

	tb = SummaryWriter('runs')


	# main loop
	for episode_num in range(1500):

		# 1. train
		if dataset == 'omniglot':
			support_x, support_y, query_x, query_y = db.get_batch('train')
			support_x = Variable( torch.from_numpy(support_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
			query_x = Variable( torch.from_numpy(query_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
			support_y = Variable(torch.from_numpy(support_y).long()).cuda()
			query_y = Variable(torch.from_numpy(query_y).long()).cuda()
		elif dataset == 'mini-imagenet':
			try:
				batch_test = next(iter(db))
			except StopIteration:
				mini = MiniImagenet('../mini-imagenet/', mode='train', n_way=n_way, k_shot=k_shot, k_query=k_query,
				                    batchsz=10000, resize=imgsz)
				db = DataLoader(mini, meta_batchsz, shuffle=True, num_workers=4, pin_memory=True)
				batch_test = next(iter(db))

			support_x = Variable(batch_test[0]).cuda()
			support_y = Variable(batch_test[1]).cuda()
			query_x = Variable(batch_test[2]).cuda()
			query_y = Variable(batch_test[3]).cuda()

		# backprop has been embeded in forward func.
		if episode_num % 100 == 0:
			meta.prv_angle = 0
		accs = meta(support_x, support_y, query_x, query_y)
		train_acc = np.array(accs).mean()

		# 2. test
		if episode_num % 30 == 0:
			test_accs = []
			for i in range(min(episode_num // 5000 + 3, 10)): # get average acc.
				if dataset == 'omniglot':
					support_x, support_y, query_x, query_y = db.get_batch('test')
					support_x = Variable( torch.from_numpy(support_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
					query_x = Variable( torch.from_numpy(query_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
					support_y = Variable(torch.from_numpy(support_y).long()).cuda()
					query_y = Variable(torch.from_numpy(query_y).long()).cuda()
				elif dataset == 'mini-imagenet':
					try:
						batch_test = next(iter(db_test))
					except StopIteration:
						mini_test = MiniImagenet('../mini-imagenet/', mode='test', n_way=n_way, k_shot=k_shot,
						                         k_query=k_query,
						                         batchsz=1000, resize=imgsz)
						db_test = DataLoader(mini_test, meta_batchsz, shuffle=True, num_workers=2, pin_memory=True)
						batch_test = next(iter(db_test))
					support_x = Variable(batch_test[0]).cuda()
					support_y = Variable(batch_test[1]).cuda()
					query_x = Variable(batch_test[2]).cuda()
					query_y = Variable(batch_test[3]).cuda()

				# get accuracy
				test_acc = meta.pred(support_x, support_y, query_x, query_y)
				test_accs.append(test_acc)

			test_acc = np.array(test_accs).mean()
			print('episode:', episode_num, '\tfinetune acc:%.6f' % train_acc, '\t\ttest acc:%.6f' % test_acc)
			tb.add_scalar('test-acc', test_acc, episode_num)
			tb.add_scalar('finetune-acc', train_acc, episode_num)
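
The `try/except StopIteration` blocks above rebuild the whole `DataLoader` whenever an epoch is exhausted (and, as written, also build a fresh iterator each step). A persistent generator is a simpler way to get an endless stream of episode batches; a minimal sketch, where the `infinite_batches` helper is hypothetical and not part of the original code:

def infinite_batches(loader):
    # cycle over a DataLoader forever, starting a new pass when one ends
    it = iter(loader)
    while True:
        try:
            yield next(it)
        except StopIteration:
            it = iter(loader)

# usage: batch_gen = infinite_batches(db); batch_test = next(batch_gen)
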
Example #5
def main():
    meta_batchsz = 32 * 3
    n_way = 5
    k_shot = 5
    k_query = k_shot
    meta_lr = 1e-3
    num_updates = 5
    dataset = "mini-imagenet"

    if dataset == "omniglot":
        imgsz = 28
        db = OmniglotNShot(
            "dataset",
            batchsz=meta_batchsz,
            n_way=n_way,
            k_shot=k_shot,
            k_query=k_query,
            imgsz=imgsz,
        )

    elif dataset == "mini-imagenet":
        imgsz = 84
        # the dataset loaders are different from omniglot to mini-imagenet. for omniglot, it just has one loader to use
        # get_batch(train or test) to get different batch.
        # for mini-imagenet, it should have two dataloader, one is train_loader and another is test_loader.
        mini = MiniImagenet(
            "../../hdd1/meta/mini-imagenet/",
            mode="train",
            n_way=n_way,
            k_shot=k_shot,
            k_query=k_query,
            batchsz=10000,
            resize=imgsz,
        )
        db = DataLoader(mini,
                        meta_batchsz,
                        shuffle=True,
                        num_workers=4,
                        pin_memory=True)
        mini_test = MiniImagenet(
            "../../hdd1/meta/mini-imagenet/",
            mode="test",
            n_way=n_way,
            k_shot=k_shot,
            k_query=k_query,
            batchsz=1000,
            resize=imgsz,
        )
        db_test = DataLoader(mini_test,
                             meta_batchsz,
                             shuffle=True,
                             num_workers=2,
                             pin_memory=True)

    else:
        raise NotImplementedError

    # do NOT call .cuda() implicitly
    net = CSML()
    net.deploy()

    tb = SummaryWriter("runs")

    # main loop
    for episode_num in range(200000):

        # 1. train
        if dataset == "omniglot":
            support_x, support_y, query_x, query_y = db.get_batch("train")
            support_x = Variable(
                torch.from_numpy(support_x).float().transpose(2, 4).transpose(
                    3, 4).repeat(1, 1, 3, 1, 1)).cuda()
            query_x = Variable(
                torch.from_numpy(query_x).float().transpose(2, 4).transpose(
                    3, 4).repeat(1, 1, 3, 1, 1)).cuda()
            support_y = Variable(torch.from_numpy(support_y).long()).cuda()
            query_y = Variable(torch.from_numpy(query_y).long()).cuda()

        elif dataset == "mini-imagenet":
            try:
                batch_train = next(iter(db))
            except StopIteration:
                mini = MiniImagenet(
                    "../../hdd1/meta/mini-imagenet/",
                    mode="train",
                    n_way=n_way,
                    k_shot=k_shot,
                    k_query=k_query,
                    batchsz=10000,
                    resize=imgsz,
                )
                db = DataLoader(mini,
                                meta_batchsz,
                                shuffle=True,
                                num_workers=4,
                                pin_memory=True)
                batch_train = next(iter(db))

            support_x = Variable(batch_train[0])
            support_y = Variable(batch_train[1])
            query_x = Variable(batch_train[2])
            query_y = Variable(batch_train[3])
            print(support_x.size(), support_y.size())

        # backprop has been embeded in forward func.
        accs = net.train(support_x, support_y, query_x, query_y)
        train_acc = np.array(accs).mean()

        # 2. test
        if episode_num % 30 == 0:
            test_accs = []
            for i in range(min(episode_num // 5000 + 3,
                               10)):  # get average acc.
                if dataset == "omniglot":
                    support_x, support_y, query_x, query_y = db.get_batch(
                        "test")
                    support_x = Variable(
                        torch.from_numpy(support_x).float().transpose(
                            2, 4).transpose(3, 4).repeat(1, 1, 3, 1,
                                                         1)).cuda()
                    query_x = Variable(
                        torch.from_numpy(query_x).float().transpose(
                            2, 4).transpose(3, 4).repeat(1, 1, 3, 1,
                                                         1)).cuda()
                    support_y = Variable(
                        torch.from_numpy(support_y).long()).cuda()
                    query_y = Variable(torch.from_numpy(query_y).long()).cuda()

                elif dataset == "mini-imagenet":
                    try:
                        batch_test = next(iter(db_test))
                    except StopIteration:
                        mini_test = MiniImagenet(
                            "../../hdd1/meta/mini-imagenet/",
                            mode="test",
                            n_way=n_way,
                            k_shot=k_shot,
                            k_query=k_query,
                            batchsz=1000,
                            resize=imgsz,
                        )
                        db_test = DataLoader(
                            mini_test,
                            meta_batchsz,
                            shuffle=True,
                            num_workers=2,
                            pin_memory=True,
                        )
                        batch_test = next(iter(db_test))

                    support_x = Variable(batch_test[0])
                    support_y = Variable(batch_test[1])
                    query_x = Variable(batch_test[2])
                    query_y = Variable(batch_test[3])

                # get accuracy
                test_acc = net.train(support_x, support_y, query_x, query_y, train=False)
                test_accs.append(test_acc)

            test_acc = np.array(test_accs).mean()
            print(
                "episode:",
                episode_num,
                "\tfinetune acc:%.6f" % train_acc,
                "\t\ttest acc:%.6f" % test_acc,
            )
            tb.add_scalar("test-acc", test_acc)
            tb.add_scalar("finetune-acc", train_acc)
Example #6
def main(args):

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    shared_config = [
        ('conv2d', [64, 1, 3, 3, 2, 0]),
        ('leakyrelu', [.2, True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 3, 3, 2, 0]),
        ('leakyrelu', [.2, True]),
        ('bn', [64]),
    ]

    nway_config = [('conv2d', [64, 1, 3, 3, 2, 0]), ('leakyrelu', [.2, True]),
                   ('bn', [64]), ('conv2d', [64, 64, 3, 3, 2, 0]),
                   ('leakyrelu', [.2, True]), ('bn', [64]),
                   ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]),
                   ('bn', [64]), ('conv2d', [64, 64, 2, 2, 1, 0]),
                   ('relu', [True]), ('bn', [64]), ('flatten', []),
                   ('linear', [args.n_way, 64])]

    # reads in image
    discriminator_config = [('conv2d', [64, 1, 3, 3, 2, 0]),
                            ('leakyrelu', [.2, True]), ('bn', [64]),
                            ('conv2d', [64, 64, 3, 3, 2, 0]),
                            ('leakyrelu', [.2, True]), ('bn', [64]),
                            ('conv2d', [64, 64, 3, 3, 2, 0]),
                            ('leakyrelu', [.2, True]), ('bn', [64]),
                            ('conv2d', [64, 64, 2, 2, 1, 0]),
                            ('leakyrelu', [.2, True]), ('bn', [64]),
                            ('flatten', []), ('linear', [1, 64])
                            # don't use a sigmoid at the end
                            ]

    # new gen_config
    # starts from image and convolves it into new ones
    gen_config = [
        ('convt2d', [1, 64, 3, 3, 1, 1]),
        ('leakyrelu', [.2, True]),
        ('bn', [64]),
        ('random_proj', [100, 28, 64]),
        ('convt2d', [128, 64, 3, 3, 1, 1]),
        #('convt2d', [1, 128, 4, 4, 2, 1]), # [ch_in, ch_out, kernel_sz, kernel_sz, stride, padding]
        ('relu', [.2, True]),
        ('bn', [64]),
        # ('encode', [1024, 64*28*28]),
        # ('decode', [64*28*28, 1024]),
        ('relu', [.2, True]),
        ('conv2d', [64, 64, 3, 3, 1, 1]),
        #('convt2d', [1, 128, 4, 4, 2, 1]), # [ch_in, ch_out, kernel_sz, kernel_sz, stride, padding]
        ('relu', [.2, True]),
        ('bn', [64]),
        ('conv2d', [1, 64, 3, 3, 1, 1]),
        ('sigmoid', [True])
    ]

    # old gen_config
    # gen_config = [
    #     ('random_proj', [100, 512, 64, 7]), # [latent_dim, emb_size, ch_out, h_out/w_out]
    #     # img: (64, 7, 7)
    #     ('convt2d', [64, 32, 4, 4, 2, 1]), # [ch_in, ch_out, kernel_sz, kernel_sz, stride, padding]
    #     ('bn', [32]),
    #     ('relu', [True]),
    #     # img: (32, 14, 14)
    #     ('convt2d', [32, 1, 4, 4, 2, 1]),
    #     # img: (1, 28, 28)
    #     ('sigmoid', [True])
    # ]

    # if args.condition_discrim:
    #     discriminator_config = [
    #         ('condition', [512, 1, 6]), # [emb_dim, emb_ch_out, h_out/w_out]
    #         ('conv2d', [128, 65, 2, 2, 1, 0]),
    #         ('leakyrelu', [.2, True]),
    #         ('bn', [128]),
    #         ('conv2d', [128, 128, 2, 2, 1, 0]),
    #         ('leakyrelu', [.2, True]),
    #         ('bn', [128]),
    #         ('flatten', []),
    #         ('linear', [1, 2048])
    #     ]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mamlGAN = MetaGAN(args, shared_config, nway_config, discriminator_config,
                      gen_config).to(device)

    tmp = filter(lambda x: x.requires_grad, mamlGAN.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(mamlGAN)
    print('Total trainable tensors:', num)

    db_train = OmniglotNShot('omniglot',
                             batchsz=args.tasks_per_batch,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             img_sz=args.img_sz)

    save_model = not args.no_save
    if save_model:
        now = datetime.now().replace(second=0, microsecond=0)
        path = "results/" + str(now) + "_omni"
        mkdir_p(path)
        file = open(path + '/architecture.txt', 'w+')
        file.write("shared_config = " + json.dumps(shared_config) + "\n" +
                   "nway_config = " + json.dumps(nway_config) + "\n" +
                   "discriminator_config = " +
                   json.dumps(discriminator_config) + "\n" + "gen_config = " +
                   json.dumps(gen_config) + "\n" + "learn_inner_lr = " +
                   str(args.learn_inner_lr) + "\n" + "condition_discrim = " +
                   str(args.condition_discrim))
        file.close()
    for step in range(args.epoch):
        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                     torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = mamlGAN(x_spt, y_spt, x_qry, y_qry)

        if step % 30 == 0:
            print("step " + str(step))
            for key in accs.keys():
                print(key + ": " + str(accs[key]))
            if save_model:
                save_train_accs(path, accs, int(step))
        if step % 500 == 0:
            print("testing")
            accs = []
            imgs = []
            for _ in range(1000 // args.tasks_per_batch):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                             torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

                # split to single task each time
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                        x_spt, y_spt, x_qry, y_qry):
                    test_acc, ims = mamlGAN.finetunning(
                        x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                    accs.append(test_acc)
                    imgs.append(x_spt_one.cpu().detach().numpy())
                    imgs.append(ims.cpu().detach().numpy())
                    if args.single_fast_test:
                        break
                if args.single_fast_test:
                    break

            accs = np.array(accs).mean(axis=0).astype(np.float16)
            if save_model:
                save_test_accs(path, accs, int(step))
                imgs = np.array(imgs)
                save_imgs(path, imgs, step)

                torch.save({'model_state_dict': mamlGAN.state_dict()},
                           path + "/model_step" + str(step))
                # to load, do this:
                # checkpoint = torch.load(path + "/model_step" + str(step))
                # mamlGAN.load_state_dict(checkpoint['model_state_dict'])

            # accs was already averaged over tasks above: [update_steps+1]
            print('Test acc:', accs)
Example #7
def evaluation(net, batchsz, n_way, k_shot, imgsz, episodesz, threshold,
               mdl_file):
    """
    Following the experiment setting of MAML and Learning2Compare, we randomly sample 600 episodes
    with 15 query images per query set.
    :param net:
    :param batchsz:
    :return:
    """
    k_query = 15
    db = OmniglotNShot('dataset',
                       batchsz=batchsz,
                       n_way=n_way,
                       k_shot=k_shot,
                       k_query=k_query,
                       imgsz=imgsz)

    accs = []
    episode_num = 0  # record tested num of episodes

    for i in range(600 // batchsz):
        support_x, support_y, query_x, query_y = db.get_batch('test')
        support_x = Variable(
            torch.from_numpy(support_x).float().transpose(2, 4).transpose(
                3, 4).repeat(1, 1, 3, 1, 1)).cuda()
        query_x = Variable(
            torch.from_numpy(query_x).float().transpose(2, 4).transpose(
                3, 4).repeat(1, 1, 3, 1, 1)).cuda()
        support_y = Variable(torch.from_numpy(support_y).int()).cuda()
        query_y = Variable(torch.from_numpy(query_y).int()).cuda()

        # we will split query set into 15 splits.
        # query_x : [batch, 15*way, c_, h, w]
        # query_x_b : tuple, 15 * [b, way, c_, h, w]
        query_x_b = torch.chunk(query_x, k_query, dim=1)
        # query_y : [batch, 15*way]
        # query_y_b: 15* [b, way]
        query_y_b = torch.chunk(query_y, k_query, dim=1)
        preds = []
        net.eval()
        # we don't just need the total acc over 600 episodes; we need the acc per split of 15*n_way query images.
        total_correct = 0
        total_num = 0
        for query_x_mini, query_y_mini in zip(query_x_b, query_y_b):
            # print('query_x_mini', query_x_mini.size(), 'query_y_mini', query_y_mini.size())
            pred, correct = net(support_x, support_y,
                                query_x_mini.contiguous(), query_y_mini, False)
            correct = correct.sum()  # multi-gpu
            # pred: [b, nway]
            preds.append(pred)
            total_correct += correct.item()
            total_num += query_y_mini.size(0) * query_y_mini.size(1)
        # # 15 * [b, nway] => [b, 15*nway]
        # preds = torch.cat(preds, dim= 1)
        acc = total_correct / total_num
        print('%.5f,' % acc, end=' ')
        sys.stdout.flush()
        accs.append(acc)

        # update tested episode number
        episode_num += query_y.size(0)
        if episode_num > episodesz:
            # test current tested episodes acc.
            acc = np.array(accs).mean()
            if acc >= threshold:
                # if current acc is very high, we conduct all 600 episodes testing.
                continue
            else:
                # current acc is low, just conduct `episodesz` num of episodes.
                break

    # compute the distribution of 600/episodesz episodes acc.
    global best_accuracy
    accs = np.array(accs)
    accuracy, sem = mean_confidence_interval(accs)
    print('\naccuracy:', accuracy, 'sem:', sem)
    print('<<<<<<<<< accuracy:', accuracy, 'best accuracy:', best_accuracy,
          '>>>>>>>>')

    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(net.state_dict(), mdl_file)
        print('Saved to checkpoint:', mdl_file)

    return accuracy, sem
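
`mean_confidence_interval` is referenced above but not defined in this example. A common definition (an assumption; the original helper is not shown here) computes the mean and the 95% confidence half-width via the Student t distribution:

import numpy as np
import scipy.stats

def mean_confidence_interval(accs, confidence=0.95):
    # mean and confidence half-width of a 1-D array of accuracies
    n = accs.shape[0]
    m, se = np.mean(accs), scipy.stats.sem(accs)
    h = se * scipy.stats.t.ppf((1 + confidence) / 2, n - 1)
    return m, h
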
Example #8
def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument('-n', help='n way')
    argparser.add_argument('-k', help='k shot')
    argparser.add_argument('-b', help='batch size')
    argparser.add_argument('-l', help='learning rate', default=1e-3)
    argparser.add_argument('-t',
                           help='threshold to test all episodes',
                           default=0.97)
    args = argparser.parse_args()
    n_way = int(args.n)
    k_shot = int(args.k)
    k_query = 1
    batchsz = int(args.b)
    imgsz = 84
    lr = float(args.l)
    threshold = float(args.t)

    db = OmniglotNShot('dataset',
                       batchsz=batchsz,
                       n_way=n_way,
                       k_shot=k_shot,
                       k_query=k_query,
                       imgsz=imgsz)
    print('Omniglot: no rotate!  %d-way %d-shot  lr:%f' % (n_way, k_shot, lr))
    net = NaiveRN(n_way, k_shot, imgsz).cuda()
    print(net)
    mdl_file = 'ckpt/omni%d%d.mdl' % (n_way, k_shot)

    if os.path.exists(mdl_file):
        print('recover from state: ', mdl_file)
        net.load_state_dict(torch.load(mdl_file))
    else:
        print('training from scratch.')

    model_parameters = filter(lambda p: p.requires_grad, net.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print('Total params:', params)

    input, input_y, query, query_y = db.get_batch(
        'train')  # (batch, n_way*k_shot, img)
    print('get batch:', input.shape, query.shape, input_y.shape, query_y.shape)

    optimizer = optim.Adam(net.parameters(), lr=lr)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               'max',
                                               factor=0.5,
                                               patience=15,
                                               verbose=True)

    total_train_loss = 0
    for step in range(100000000):

        # 1. test
        if step % 400 == 0:
            accuracy, _ = 0, 0  # evaluation(net, batchsz, n_way, k_shot, imgsz, 300, threshold, mdl_file)
            scheduler.step(accuracy)

        # 2. train
        support_x, support_y, query_x, query_y = db.get_batch('train')
        support_x = Variable(
            torch.from_numpy(support_x).float().transpose(2, 4).transpose(
                3, 4).repeat(1, 1, 3, 1, 1)).cuda()
        query_x = Variable(
            torch.from_numpy(query_x).float().transpose(2, 4).transpose(
                3, 4).repeat(1, 1, 3, 1, 1)).cuda()
        support_y = Variable(torch.from_numpy(support_y).int()).cuda()
        query_y = Variable(torch.from_numpy(query_y).int()).cuda()

        loss = net(support_x, support_y, query_x, query_y)
        total_train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 3. print
        if step % 20 == 0 and step != 0:
            print('%d-way %d-shot %d batch> step:%d, loss:%f' %
                  (n_way, k_shot, batchsz, step, total_train_loss))
            total_train_loss = 0
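
With `mode='max'`, the `ReduceLROnPlateau` scheduler above halves the learning rate (`factor=0.5`) once the monitored test accuracy has stopped improving for `patience` evaluation rounds. A self-contained sketch of that behavior (with a shorter patience so the reduction is visible):

import torch
from torch import optim
from torch.optim import lr_scheduler

w = torch.nn.Parameter(torch.zeros(1))
opt = optim.Adam([w], lr=1e-3)
sched = lr_scheduler.ReduceLROnPlateau(opt, 'max', factor=0.5, patience=2)
for acc in [0.9, 0.9, 0.9, 0.9, 0.9]:   # plateauing metric
    sched.step(acc)
print(opt.param_groups[0]['lr'])        # 0.0005 -- halved after the plateau
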
Example #9
def main():

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    config = [('conv2d', [64, 1, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 2, 2, 1, 0]), ('relu', [True]), ('bn', [64]),
              ('flatten', []), ('linear', [args.n_way + 1, 64])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    db_train = OmniglotNShot('omniglot',
                             batchsz=args.task_num,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             imgsz=args.imgsz)

    save_path = os.getcwd() + '/data/omniglot/model_batchsz' + str(
        args.k_spt) + '_stepsz' + str(args.update_lr) + '_epoch'

    for step in range(args.epoch):

        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                     torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs, al_accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 50 == 0:
            print('step:', step, '\ttraining acc:', accs, '\tAL acc:', al_accs)

        if step % 500 == 0:
            al_accs, accs = [], []
            for _ in range(1000 // args.task_num):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                             torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

                # split to single task each time
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                        x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one,
                                                x_qry_one, y_qry_one)
                    al_accs.append(maml.al_test(x_qry_one, y_qry_one))
                    accs.append(test_acc)

            # [b, update_step+1]
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
            al_accs = np.array(al_accs).mean(axis=0).astype(np.float16)
            print('AL acc:', al_accs)
            torch.save(maml.state_dict(), save_path + str(step) + "_al.pt")
Example #10
def main(args):

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    config = [('conv2d', [64, 1, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 2, 2, 1, 0]), ('relu', [True]), ('bn', [64]),
              ('flatten', []), ('linear', [args.n_way, 64])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    db_train = OmniglotNShot('omniglot',
                             batchsz=args.task_num,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             imgsz=args.imgsz)

    print(args.task_num)  # 32
    print(args.n_way)  # 5
    print(args.k_spt)  # 1
    print(args.k_qry)  # 15
    print(args.imgsz)  # 28

    for step in range(args.epoch):
        #if step % 4000 == 0:
        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                    torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        #print("TYPE: " + y_spt.type())
        #pdb.set_trace()
        y_spt = torch.tensor(y_spt, dtype=torch.int64,
                             device=device)  # diff syntax ?
        y_qry = torch.tensor(y_qry, dtype=torch.int64, device=device)
        #print("TYPE: "+y_spt.type())
        #pdb.set_trace()
        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 50 == 0:
            print('step:', step, '\ttraining acc:', accs)

        if step % 500 == 0:
            accs = []
            for _ in range(1000 // args.task_num):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                            torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

                y_spt = torch.tensor(y_spt, dtype=torch.int64,
                                     device=device)  # diff syntax ?
                y_qry = torch.tensor(y_qry, dtype=torch.int64, device=device)

                # split to single task each time
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                        x_spt, y_spt, x_qry, y_qry):
                    print(x_spt_one.size())
                    print(y_spt_one.size())
                    print(x_qry_one.size())
                    print(y_qry_one.size())
                    test_acc = maml.finetunning(x_spt_one, y_spt_one,
                                                x_qry_one, y_qry_one)
                    accs.append(test_acc)

            # [b, update_step+1]
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
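
The `torch.tensor(y_spt, ...)` calls above (flagged by the author with `# diff syntax ?`) do work, but constructing a tensor from an existing tensor raises a `UserWarning` in recent PyTorch. The warning-free equivalent of the same cast-and-move is a plain `.to()` call:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
y = torch.zeros(32, 5)                      # e.g. labels as produced by from_numpy
y = y.to(device=device, dtype=torch.int64)  # cast + move without the UserWarning
print(y.dtype)                              # torch.int64
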
Example #11
def main(args):

    config = [('conv2d', [64, 1, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 3, 3, 2, 0]), ('relu', [True]), ('bn', [64]),
              ('conv2d', [64, 64, 2, 2, 1, 0]), ('relu', [True]), ('bn', [64]),
              ('flatten', []), ('linear', [args.n_way, 64])]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    maml = Meta(args, config, device).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))

    print('Total trainable tensors:', num)

    db_train = OmniglotNShot('./',
                             batchsz=args.task_num,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             imgsz=args.imgsz)

    for step in range(args.epoch):

        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).long().to(device), \
                                     torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).long().to(device)

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)
        print('trainstep:', step, '\ttraining acc:', accs)

        if (step + 1) % 500 == 0:
            accs = []
            for _ in range(1000 // args.task_num):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).long().to(device), \
                                             torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).long().to(device)

                # split to single task each time
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                        x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one,
                                                x_qry_one, y_qry_one)
                    accs.append(test_acc)

            # [b, update_step+1]
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)

    ##############################
    for i in range(args.prune_iteration):
        # prune
        print("the {}th prune step".format(i))
        x_spt, y_spt, x_qry, y_qry = db_train.getHoleTrain()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).long().to(device), \
                                 torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).long().to(device)
        maml.prune(x_spt, y_spt, x_qry, y_qry, args.prune_number_one_epoch,
                   args.max_prune_number)

        # fine-tuning
        print("start finetuning....")
        # double the fine-tuning budget on the last prune iteration
        finetune_epoch = args.finetune_epoch * (2 if i == args.prune_iteration - 1 else 1)

        for step in range(finetune_epoch):
            x_spt, y_spt, x_qry, y_qry = db_train.next()
            x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).long().to(device), \
                                         torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).long().to(device)
            accs = maml(x_spt, y_spt, x_qry, y_qry, finetune=True)

            print('finetune step:', step, '\ttraining acc:', accs)

        # print the test accuracy after pruning
        print("start testing....")
        accs = []
        for _ in range(1000 // args.task_num):
            # test
            x_spt, y_spt, x_qry, y_qry = db_train.next('test')
            x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).long().to(device), \
                                         torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).long().to(device)

            # split to single task each time
            for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(
                    x_spt, y_spt, x_qry, y_qry):
                test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one,
                                            y_qry_one)
                accs.append(test_acc)

        # [b, update_step+1]
        accs = np.array(accs).mean(axis=0).astype(np.float16)
        print('Test acc:', accs)
Example #12
def test_progress(args, net, device, viz=None, global_step=0):
    """
    to plot ani/ars/acc with respect to training epochs.
    :param args:
    :param net:
    :param device:
    :param viz:
    :return:
    """
    if args.resume is None:
        print('No ckpt file specified! Make sure you are training!')

    exp = args.exp

    if viz is None:
        viz = visdom.Visdom(env='test')
    visualh = VisualH(viz)

    print('Testing now...')
    output_dir = os.path.join(args.test_dir, args.exp)
    # create test_dir
    if not os.path.exists(args.test_dir):
        os.makedirs(args.test_dir)
    # create test_dir/exp
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # clustering, visualization and classification
    db_test = OmniglotNShot('db/omniglot',
                            batchsz=1,
                            n_way=args.n_way,
                            k_shot=args.k_spt,
                            k_query=args.k_qry,
                            imgsz=args.imgsz)

    h_qry0_ami, h_qry0_ars, h_qry1_ami, h_qry1_ars = 0, 0, 0, 0
    acc0, acc1 = [], []

    for batchidx in range(args.test_episode_num):
        spt_x, spt_y, qry_x, qry_y = db_test.next('test')
        spt_x, spt_y, qry_x, qry_y = torch.from_numpy(spt_x).to(device), torch.from_numpy(spt_y).to(device), \
                                     torch.from_numpy(qry_x).to(device), torch.from_numpy(qry_y).to(device)
        assert spt_x.size(0) == 1
        spt_x, spt_y, qry_x, qry_y = spt_x.squeeze(0), spt_y.squeeze(
            0), qry_x.squeeze(0), qry_y.squeeze(0)

        # we can get the representation before first update, after k update
        # and test the representation on merged(test_spt, test_qry) set
        h_spt0, h_spt1, h_qry0, h_qry1, _, new_net = net.finetuning(
            spt_x, spt_y, qry_x, qry_y, args.finetuning_steps, None)

        if batchidx == 0:
            visualh.update(h_spt0, h_spt1, h_qry0, h_qry1, spt_y, qry_y,
                           global_step)

        # we will use the acquired representation to cluster.
        # h_spt: [sptsz, h_dim]
        # h_qry: [qrysz, h_dim]
        h_qry0_np = h_qry0.detach().cpu().numpy()
        h_qry1_np = h_qry1.detach().cpu().numpy()
        qry_y_np = qry_y.detach().cpu().numpy()
        h_qry0_pred = cluster.KMeans(n_clusters=args.n_way,
                                     random_state=0).fit(h_qry0_np).labels_
        h_qry1_pred = cluster.KMeans(n_clusters=args.n_way,
                                     random_state=0).fit(h_qry1_np).labels_
        h_qry0_ami += metrics.adjusted_mutual_info_score(qry_y_np, h_qry0_pred)
        h_qry0_ars += metrics.adjusted_rand_score(qry_y_np, h_qry0_pred)
        h_qry1_ami += metrics.adjusted_mutual_info_score(qry_y_np, h_qry1_pred)
        h_qry1_ars += metrics.adjusted_rand_score(qry_y_np, h_qry1_pred)

        h_qry0_cm = metrics.cluster.contingency_matrix(h_qry0_pred, qry_y_np)
        h_qry1_cm = metrics.cluster.contingency_matrix(h_qry1_pred, qry_y_np)
        # viz.heatmap(X=h_qry0_cm, win=args.exp+' h_qry0_cm', opts=dict(title=args.exp+' h_qry0_cm:%d'%batchidx,
        #                                                               colormap='Electric'))
        # viz.heatmap(X=h_qry1_cm, win=args.exp+' h_qry1_cm', opts=dict(title=args.exp+' h_qry1_cm:%d'%batchidx,
        #                                                               colormap='Electric'))

        # return is a list of [acc_step0, acc_step1 ,...]
        acc0.append(
            net.classify_train(h_spt0,
                               spt_y,
                               h_qry0,
                               qry_y,
                               use_h=True,
                               train_step=args.classify_steps))
        acc1.append(
            net.classify_train(h_spt1,
                               spt_y,
                               h_qry1,
                               qry_y,
                               use_h=True,
                               train_step=args.classify_steps))

        if batchidx == 0:
            spt_x_hat0 = net.forward_ae(spt_x[:64])
            qry_x_hat0 = net.forward_ae(qry_x[:64])
            spt_x_hat1 = new_net.forward_ae(spt_x[:64])
            qry_x_hat1 = new_net.forward_ae(qry_x[:64])
            viz.images(qry_x[:64],
                       nrow=8,
                       win=exp + 'qry_x',
                       opts=dict(title=exp + 'qry_x'))
            # viz.images(spt_x_hat0, nrow=8, win=exp+'spt_x_hat0', opts=dict(title=exp+'spt_x_hat0'))
            viz.images(qry_x_hat0,
                       nrow=8,
                       win=exp + 'qry_x_hat0',
                       opts=dict(title=exp + 'qry_x_hat0'))
            # viz.images(spt_x_hat1, nrow=8, win=exp+'spt_x_hat1', opts=dict(title=exp+'spt_x_hat1'))
            viz.images(qry_x_hat1,
                       nrow=8,
                       win=exp + 'qry_x_hat1',
                       opts=dict(title=exp + 'qry_x_hat1'))

        if batchidx > 0:
            break


    h_qry0_ami, h_qry0_ars, h_qry1_ami, h_qry1_ars = h_qry0_ami / (batchidx + 1), h_qry0_ars / (batchidx + 1), \
                                                     h_qry1_ami / (batchidx + 1), h_qry1_ars / (batchidx + 1)
    # [[epsode1], [episode2],...] = [N, steps] => [steps]
    acc0, acc1 = np.array(acc0).mean(axis=0), np.array(acc1).mean(axis=0)

    print('ami:', h_qry0_ami, h_qry1_ami)
    print('ars:', h_qry0_ars, h_qry1_ars)
    viz.line([[h_qry0_ami, h_qry1_ami]], [global_step],
             win=exp + 'ami_on_qry01',
             update='append')
    viz.line([[h_qry0_ars, h_qry1_ars]], [global_step],
             win=exp + 'ars_on_qry01',
             update='append')
    print('acc:\n', acc0, '\n', acc1)
    viz.line([[acc0[-1], acc1[-1]]], [global_step],
             win=exp + 'acc_on_qry01',
             update='append')
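
The test loop above scores the learned representations by clustering them with KMeans and comparing the cluster assignments to the true episode labels via adjusted mutual information and adjusted Rand score. A minimal standalone sketch of that metric pipeline:

import numpy as np
from sklearn import cluster, metrics

h = np.random.rand(20, 8)                    # 20 query embeddings, dim 8
y = np.repeat(np.arange(5), 4)               # 5-way ground-truth labels
pred = cluster.KMeans(n_clusters=5, random_state=0).fit(h).labels_
print(metrics.adjusted_mutual_info_score(y, pred))
print(metrics.adjusted_rand_score(y, pred))
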
Example #13
def main(args):

    args = update_args(args)

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    device = torch.device('cuda')
    if args.is_meta:
        # optimizer has been embedded in model.
        net = MetaAE(args)
        # model_parameters = filter(lambda p: p.requires_grad, net.learner.parameters())
        # params = sum([np.prod(p.size()) for p in model_parameters])
        # print('Total params:', params)
        tmp = filter(lambda x: x.requires_grad, net.learner.parameters())
        num = sum(map(lambda x: np.prod(x.shape), tmp))
        print('Total trainable variables:', num)
    else:
        net = AE(args, use_logits=True)
        optimizer = optim.Adam(list(net.encoder.parameters()) +
                               list(net.decoder.parameters()),
                               lr=args.meta_lr)

        tmp = filter(
            lambda x: x.requires_grad,
            list(net.encoder.parameters()) + list(net.decoder.parameters()))
        num = sum(map(lambda x: np.prod(x.shape), tmp))
        print('Total trainable variables:', num)

    net.to(device)
    print(net)

    print('=' * 15, 'Experiment:', args.exp, '=' * 15)
    print(args)

    if args.h_dim == 2:
        # borrowed from https://github.com/fastforwardlabs/vae-tf/blob/master/plot.py
        h_range = np.rollaxis(
            np.mgrid[args.h_range:-args.h_range:args.h_nrow * 1j,
                     args.h_range:-args.h_range:args.h_nrow * 1j], 0, 3)
        # [b, q_h]
        h_manifold = torch.from_numpy(h_range.reshape([-1,
                                                       2])).to(device).float()
        print('h_manifold:', h_manifold.shape)
    else:
        h_manifold = None

    # try to resume from ckpt.mdl file
    epoch_start = 0
    if args.resume is not None:
        # ckpt/normal-fc-vae_640_2018-11-20_09:58:58.mdl
        mdl_file = args.resume
        epoch_start = int(mdl_file.split('_')[-3])
        net.load_state_dict(torch.load(mdl_file))
        print('Resume from:', args.resume, 'epoch/batches:', epoch_start)
    else:
        print('Training/test from scratch...')

    if args.test:
        assert args.resume is not None
        test.test_ft_steps(args, net, device)
        return

    vis = visdom.Visdom(env=args.exp)
    visualh = VisualH(vis)
    vis.line([[0, 0, 0]], [epoch_start],
             win=args.exp + 'train_loss',
             opts=dict(title=args.exp + 'train_qloss',
                       legend=['loss', '-lklh', 'kld'],
                       xlabel='global_step'))

    # for test_progress
    vis.line([[0, 0]], [epoch_start],
             win=args.exp + 'acc_on_qry01',
             opts=dict(title=args.exp + 'acc_on_qry01',
                       legend=['h_qry0', 'h_qry1'],
                       xlabel='global_step'))
    vis.line([[0, 0]], [epoch_start],
             win=args.exp + 'ami_on_qry01',
             opts=dict(title=args.exp + 'ami_on_qry01',
                       legend=['h_qry0', 'h_qry1'],
                       xlabel='global_step'))
    vis.line([[0, 0]], [epoch_start],
             win=args.exp + 'ars_on_qry01',
             opts=dict(title=args.exp + 'ars_on_qry01',
                       legend=['h_qry0', 'h_qry1'],
                       xlabel='global_step'))

    # 1. train
    db_train = OmniglotNShot('db/omniglot',
                             batchsz=args.task_num,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             imgsz=args.imgsz)

    # epoch = batch number here.
    for epoch in range(epoch_start, args.train_episode_num):
        spt_x, spt_y, qry_x, qry_y = db_train.next()
        spt_x, spt_y, qry_x, qry_y = torch.from_numpy(spt_x).to(device), torch.from_numpy(spt_y).to(device), \
                                     torch.from_numpy(qry_x).to(device), torch.from_numpy(qry_y).to(device)

        if args.is_meta:  # for meta
            loss_optim, losses_q, likelihoods_q, klds_q = net(
                spt_x, spt_y, qry_x, qry_y)

            if epoch % 300 == 0:

                if args.is_vae:
                    # print(losses_q, likelihoods_q, klds_q)
                    vis.line([[
                        losses_q[-1].item(), -likelihoods_q[-1].item(),
                        klds_q[-1].item()
                    ]], [epoch],
                             win=args.exp + 'train_loss',
                             update='append')
                    print(epoch)
                    print(
                        'loss_q:',
                        torch.stack(losses_q).detach().cpu().numpy().astype(
                            np.float16))
                    print(
                        'lkhd_q:',
                        torch.stack(
                            likelihoods_q).detach().cpu().numpy().astype(
                                np.float16))
                    print(
                        'klds_q:',
                        torch.stack(klds_q).cpu().detach().numpy().astype(
                            np.float16))
                else:
                    # print(losses_q, likelihoods_q, klds_q)
                    vis.line([[losses_q[-1].item(), 0, 0]], [epoch],
                             win=args.exp + 'train_loss',
                             update='append')
                    print(
                        epoch,
                        torch.stack(losses_q).detach().cpu().numpy().astype(
                            np.float16))

        else:  # for normal vae/ae

            loss_optim, _, likelihood, kld = net(spt_x, spt_y, qry_x, qry_y)
            optimizer.zero_grad()
            loss_optim.backward()
            torch.nn.utils.clip_grad_norm_(
                list(net.encoder.parameters()) +
                list(net.decoder.parameters()), 10)
            optimizer.step()

            if epoch % 300 == 0:

                print(epoch, loss_optim.item())
                if not args.is_vae:
                    vis.line([[loss_optim.item(), 0, 0]], [epoch],
                             win='train_loss',
                             update='append')
                else:
                    vis.line(
                        [[loss_optim.item(), -likelihood.item(),
                          kld.item()]], [epoch],
                        win='train_loss',
                        update='append')

        if epoch % 3000 == 0:
            # [qrysz, 1, 64, 64] => [qrysz, 1, 64, 64]
            x_hat = net.forward_ae(qry_x[0])
            vis.images(x_hat,
                       nrow=args.k_qry,
                       win='train_x_hat',
                       opts=dict(title='train_x_hat'))
            test.test_progress(args, net, device, vis, epoch)

        # save checkpoint.
        if epoch % 10000 == 0:
            date_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            mdl_file = os.path.join(
                args.ckpt_dir,
                args.exp + '_%d' % epoch + '_' + date_str + '.mdl')
            torch.save(net.state_dict(), mdl_file)
            print('Saved into ckpt file:', mdl_file)

    # save checkpoint.
    date_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    mdl_file = os.path.join(
        args.ckpt_dir, args.exp + '_%d' % args.epoch + '_' + date_str + '.mdl')
    torch.save(net.state_dict(), mdl_file)
    print('Saved Last state ckpt file:', mdl_file)
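
The resume branch above recovers the starting epoch purely from the checkpoint filename: with names like `ckpt/normal-fc-vae_640_2018-11-20_09:58:58.mdl` (from the comment in the resume branch), splitting on `'_'` leaves the epoch at index -3. A quick check:

mdl_file = 'ckpt/normal-fc-vae_640_2018-11-20_09:58:58.mdl'
print(mdl_file.split('_'))           # ['ckpt/normal-fc-vae', '640', '2018-11-20', '09:58:58.mdl']
print(int(mdl_file.split('_')[-3]))  # 640
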
Example #14
def main(args):
    if not os.path.exists('./logs'):
        os.mkdir('./logs')
    logfile = os.path.join('.', 'logs', f'omniglot_way[{args.n_way}]_shot[{args.k_spt}].json')
    if args.write_log:
        log_fp = open(logfile, 'w')
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    config = [
        ('conv2d', [64, 1, 3, 3, 2, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 3, 3, 2, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 3, 3, 2, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('conv2d', [64, 64, 2, 2, 1, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('flatten', []),
        ('linear', [args.n_way, 64])
    ]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    path = os.path.join(os.path.dirname(__file__), 'dataset', 'omniglot')
    db_train = OmniglotNShot(path,
                             batchsz=args.task_num,
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             imgsz=args.imgsz)

    for step in range(args.epoch):
        # fetch one batch of episode data (the sampling logic lives in the OmniglotNShot class)
        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                     torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 50 == 0:
            print('step:', step, '\ttraining acc:', accs)

        if step % 500 == 0:
            accs = []
            for _ in range(1000 // args.task_num):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next('test')
                x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                             torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

                # split to single task each time
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                    accs.append(test_acc)

            # [b, update_step+1]
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
Example #15
def main(args):
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    torch.backends.cudnn.benchmark = True
    print(args)

    config = [
        ("conv2d", [64, 1, 3, 3, 2, 0]),
        ("relu", [True]),
        ("bn", [64]),
        ("conv2d", [64, 64, 3, 3, 2, 0]),
        ("relu", [True]),
        ("bn", [64]),
        ("conv2d", [64, 64, 3, 3, 2, 0]),
        ("relu", [True]),
        ("bn", [64]),
        ("conv2d", [64, 64, 2, 2, 1, 0]),
        ("relu", [True]),
        ("bn", [64]),
        ("flatten", []),
        ("linear", [args.n_way, 64]),
    ]

    device = torch.device("cuda")
    maml = Meta(args, config).to(device)

    db_train = OmniglotNShot(
        "omniglot",
        batchsz=args.task_num,
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        imgsz=args.imgsz,
    )

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print("Total trainable tensors:", num)

    for step in range(args.epoch):

        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = (
            torch.from_numpy(x_spt).to(device),
            torch.from_numpy(y_spt).to(device),
            torch.from_numpy(x_qry).to(device),
            torch.from_numpy(y_qry).to(device),
        )

        # x_spt: shape: 32, 5, 1, 28, 28
        # y_spt: shape: 32, 5
        # x_qry: 32, 75, 1, 28, 28
        # y_qry: 32, 75

        # set training=True to update running_mean, running_variance, bn_weights, bn_bias
        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 50 == 0:
            print("step:", step, "\ttraining acc:", accs)

        if step % 500 == 0:
            accs = []
            for _ in range(1000 // args.task_num):
                # test
                x_spt, y_spt, x_qry, y_qry = db_train.next("test")
                x_spt, y_spt, x_qry, y_qry = (
                    torch.from_numpy(x_spt).to(device),
                    torch.from_numpy(y_spt).to(device),
                    torch.from_numpy(x_qry).to(device),
                    torch.from_numpy(y_qry).to(device),
                )

                # split to single task each time
                for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                    test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                    accs.append(test_acc)

            # [b, update_step+1]
            accs = np.array(accs).mean(axis=0).astype(np.float16)
            print("Test acc:", accs)