Example 1
import argparse

import torch

# MAML and OmniglotNShot are assumed to be provided by the surrounding
# repository.


def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument('-n', help='n way', default=5)
    argparser.add_argument('-k', help='k shot', default=1)
    argparser.add_argument('-b', help='batch size', default=32)
    argparser.add_argument('-l', help='meta learning rate', default=1e-3)
    args = argparser.parse_args()
    n_way = int(args.n)
    k_shot = int(args.k)
    meta_batchsz = int(args.b)
    meta_lr = float(args.l)
    train_lr = 0.4

    k_query = 15
    imgsz = 84
    mdl_file = 'ckpt/omniglot%d%d.mdl' % (n_way, k_shot)
    print('omniglot: %d-way %d-shot meta-lr:%f, train-lr:%f' %
          (n_way, k_shot, meta_lr, train_lr))

    device = torch.device('cuda:0')
    net = MAML(n_way, k_shot, k_query, meta_batchsz, 5, meta_lr, train_lr,
               device)
    print(net)

    # batchsz here means total episode number
    db = OmniglotNShot('omniglot',
                       batchsz=meta_batchsz,
                       n_way=n_way,
                       k_shot=k_shot,
                       k_query=k_query,
                       imgsz=imgsz)

    for step in range(10000000):

        # train
        support_x, support_y, query_x, query_y = db.get_batch('train')
        support_x = torch.from_numpy(support_x).float().transpose(
            2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1).to(device)
        query_x = torch.from_numpy(query_x).float().transpose(2, 4).transpose(
            3, 4).repeat(1, 1, 3, 1, 1).to(device)
        support_y = torch.from_numpy(support_y).long().to(device)
        query_y = torch.from_numpy(query_y).long().to(device)

        accs = net(support_x, support_y, query_x, query_y, training=True)

        if step % 20 == 0:
            print(step, '\t', accs)

        if step % 1000 == 0:
            # periodic test: evaluation is omitted in this example
            pass
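
The transpose/repeat chain above converts the channel-last grayscale batches returned by OmniglotNShot into the channel-first, 3-channel tensors the convolutional network expects. A minimal shape demo, assuming batches of shape [batch, setsz, h, w, 1]:

import numpy as np
import torch

# fake Omniglot batch: [batch, setsz, h, w, 1], grayscale, channel-last
x = np.random.rand(2, 5, 84, 84, 1).astype(np.float32)
t = torch.from_numpy(x).transpose(2, 4).transpose(3, 4)  # -> [2, 5, 1, 84, 84]
t = t.repeat(1, 1, 3, 1, 1)  # tile the gray channel -> [2, 5, 3, 84, 84]
print(t.shape)  # torch.Size([2, 5, 3, 84, 84])
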
Example 2
import numpy as np
import torch
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torch.utils.data import DataLoader

# MetaLearner, Naive, OmniglotNShot and MiniImagenet are assumed to be
# provided by the surrounding repository.


def main():
	meta_batchsz = 32
	n_way = 20
	k_shot = 1
	k_query = k_shot
	meta_lr = 1e-3
	num_updates = 5
	dataset = 'omniglot'



	if dataset == 'omniglot':
		imgsz = 28
		db = OmniglotNShot('dataset', batchsz=meta_batchsz, n_way=n_way, k_shot=k_shot, k_query=k_query, imgsz=imgsz)

	elif dataset == 'mini-imagenet':
		imgsz = 84
		# The data pipelines differ between the two datasets: Omniglot uses a
		# single loader whose get_batch('train'/'test') serves either split,
		# while mini-imagenet needs two DataLoaders, one for train and one for test.
		mini = MiniImagenet('../mini-imagenet/', mode='train', n_way=n_way, k_shot=k_shot, k_query=k_query,
		                    batchsz=10000, resize=imgsz)
		db = DataLoader(mini, meta_batchsz, shuffle=True, num_workers=4, pin_memory=True)
		mini_test = MiniImagenet('../mini-imagenet/', mode='test', n_way=n_way, k_shot=k_shot, k_query=k_query,
		                    batchsz=1000, resize=imgsz)
		db_test = DataLoader(mini_test, meta_batchsz, shuffle=True, num_workers=2, pin_memory=True)

	else:
		raise NotImplementedError


	meta = MetaLearner(Naive, (n_way, imgsz), n_way=n_way, k_shot=k_shot, meta_batchsz=meta_batchsz, beta=meta_lr,
	                   num_updates=num_updates).cuda()

	tb = SummaryWriter('runs')


	# main loop
	for episode_num in range(1500):

		# 1. train
		if dataset == 'omniglot':
			support_x, support_y, query_x, query_y = db.get_batch('train')
			support_x = Variable( torch.from_numpy(support_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
			query_x = Variable( torch.from_numpy(query_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
			support_y = Variable(torch.from_numpy(support_y).long()).cuda()
			query_y = Variable(torch.from_numpy(query_y).long()).cuda()
		elif dataset == 'mini-imagenet':
			try:
				batch_train = next(iter(db))
			except StopIteration:
				# the loader was exhausted: rebuild it and fetch again
				mini = MiniImagenet('../mini-imagenet/', mode='train', n_way=n_way, k_shot=k_shot, k_query=k_query,
				                    batchsz=10000, resize=imgsz)
				db = DataLoader(mini, meta_batchsz, shuffle=True, num_workers=4, pin_memory=True)
				batch_train = next(iter(db))

			support_x = Variable(batch_train[0]).cuda()
			support_y = Variable(batch_train[1]).cuda()
			query_x = Variable(batch_train[2]).cuda()
			query_y = Variable(batch_train[3]).cuda()

		# backprop is embedded in the forward pass of MetaLearner
		if episode_num % 100 == 0:
			meta.prv_angle = 0
		accs = meta(support_x, support_y, query_x, query_y)
		train_acc = np.array(accs).mean()

		# 2. test
		if episode_num % 30 == 0:
			test_accs = []
			for i in range(min(episode_num // 5000 + 3, 10)): # get average acc.
				if dataset == 'omniglot':
					support_x, support_y, query_x, query_y = db.get_batch('test')
					support_x = Variable( torch.from_numpy(support_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
					query_x = Variable( torch.from_numpy(query_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
					support_y = Variable(torch.from_numpy(support_y).long()).cuda()
					query_y = Variable(torch.from_numpy(query_y).long()).cuda()
				elif dataset == 'mini-imagenet':
					try:
						batch_test = next(iter(db_test))
					except StopIteration:
						# rebuild the exhausted test loader and fetch again
						mini_test = MiniImagenet('../mini-imagenet/', mode='test', n_way=n_way, k_shot=k_shot,
						                         k_query=k_query,
						                         batchsz=1000, resize=imgsz)
						db_test = DataLoader(mini_test, meta_batchsz, shuffle=True, num_workers=2, pin_memory=True)
						batch_test = next(iter(db_test))
					support_x = Variable(batch_test[0]).cuda()
					support_y = Variable(batch_test[1]).cuda()
					query_x = Variable(batch_test[2]).cuda()
					query_y = Variable(batch_test[3]).cuda()
 

				# get accuracy
				test_acc = meta.pred(support_x, support_y, query_x, query_y)
				test_accs.append(test_acc)

			test_acc = np.array(test_accs).mean()
			print('episode:', episode_num, '\tfinetune acc:%.6f' % train_acc, '\t\ttest acc:%.6f' % test_acc)
			tb.add_scalar('test-acc', test_acc, episode_num)
			tb.add_scalar('finetune-acc', train_acc, episode_num)
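
A note on the mini-imagenet branch: iter(db).next() is Python 2 syntax, and building a fresh iterator every episode respawns the DataLoader workers just to consume a single batch, so the StopIteration handler never actually fires. A small generator, sketched below under the assumption that db is the DataLoader from the example, yields an endless stream of batches instead:

def infinite_batches(loader):
    # a DataLoader can be iterated repeatedly; with shuffle=True each
    # pass over the dataset is reshuffled
    while True:
        for batch in loader:
            yield batch

# usage sketch:
# train_stream = infinite_batches(db)
# support_x, support_y, query_x, query_y = next(train_stream)
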
Example 3
import sys

import numpy as np
import torch
from torch.autograd import Variable

# OmniglotNShot and mean_confidence_interval are assumed to be provided by
# the surrounding repository; a possible implementation of the latter is
# sketched after this example.

best_accuracy = 0.0


def evaluation(net, batchsz, n_way, k_shot, imgsz, episodesz, threshold,
               mdl_file):
    """
    Following the experiment setting of MAML and Learning-to-Compare, we
    randomly sample 600 episodes with 15 query images per query set.
    :param net: the trained network to evaluate
    :param batchsz: number of episodes per batch
    :return: mean accuracy and its standard error
    """
    k_query = 15
    db = OmniglotNShot('dataset',
                       batchsz=batchsz,
                       n_way=n_way,
                       k_shot=k_shot,
                       k_query=k_query,
                       imgsz=imgsz)

    accs = []
    episode_num = 0  # record tested num of episodes

    for i in range(600 // batchsz):
        support_x, support_y, query_x, query_y = db.get_batch('test')
        support_x = Variable(
            torch.from_numpy(support_x).float().transpose(2, 4).transpose(
                3, 4).repeat(1, 1, 3, 1, 1)).cuda()
        query_x = Variable(
            torch.from_numpy(query_x).float().transpose(2, 4).transpose(
                3, 4).repeat(1, 1, 3, 1, 1)).cuda()
        support_y = Variable(torch.from_numpy(support_y).int()).cuda()
        query_y = Variable(torch.from_numpy(query_y).int()).cuda()

        # we will split query set into 15 splits.
        # query_x : [batch, 15*way, c_, h, w]
        # query_x_b : tuple, 15 * [b, way, c_, h, w]
        query_x_b = torch.chunk(query_x, k_query, dim=1)
        # query_y : [batch, 15*way]
        # query_y_b: 15* [b, way]
        query_y_b = torch.chunk(query_y, k_query, dim=1)
        preds = []
        net.eval()
        # accuracy is accumulated per batch over its 15 * n_way query images,
        # not only as a single total over all 600 episodes.
        total_correct = 0
        total_num = 0
        for query_x_mini, query_y_mini in zip(query_x_b, query_y_b):
            # print('query_x_mini', query_x_mini.size(), 'query_y_mini', query_y_mini.size())
            pred, correct = net(support_x, support_y,
                                query_x_mini.contiguous(), query_y_mini, False)
            correct = correct.sum()  # multi-gpu
            # pred: [b, nway]
            preds.append(pred)
            total_correct += correct.data[0]
            total_num += query_y_mini.size(0) * query_y_mini.size(1)
        # # 15 * [b, nway] => [b, 15*nway]
        # preds = torch.cat(preds, dim= 1)
        acc = total_correct / total_num
        print('%.5f,' % acc, end=' ')
        sys.stdout.flush()
        accs.append(acc)

        # update tested episode number
        episode_num += query_y.size(0)
        if episode_num > episodesz:
            # check the running accuracy over the episodes tested so far.
            acc = np.array(accs).mean()
            if acc >= threshold:
                # promising: continue and evaluate all 600 episodes.
                continue
            else:
                # not promising: stop after `episodesz` episodes.
                break

    # compute the mean accuracy and its confidence interval over the tested episodes.
    global best_accuracy
    accs = np.array(accs)
    accuracy, sem = mean_confidence_interval(accs)
    print('\naccuracy:', accuracy, 'sem:', sem)
    print('<<<<<<<<< accuracy:', accuracy, 'best accuracy:', best_accuracy,
          '>>>>>>>>')

    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(net.state_dict(), mdl_file)
        print('Saved to checkpoint:', mdl_file)

    return accuracy, sem
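
mean_confidence_interval is not shown above. A common implementation in few-shot repositories, and a plausible match for how it is used here (mean plus a 95% confidence half-width via the Student t distribution), is the following sketch:

import numpy as np
import scipy.stats


def mean_confidence_interval(data, confidence=0.95):
    # returns the sample mean and the half-width of the confidence interval
    a = np.asarray(data, dtype=np.float64)
    m, se = a.mean(), scipy.stats.sem(a)
    h = se * scipy.stats.t.ppf((1 + confidence) / 2.0, len(a) - 1)
    return m, h
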
Example 4
import numpy as np
import torch
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torch.utils.data import DataLoader

# CSML, OmniglotNShot and MiniImagenet are assumed to be provided by the
# surrounding repository.


def main():
    meta_batchsz = 32 * 3
    n_way = 5
    k_shot = 5
    k_query = k_shot
    meta_lr = 1e-3
    num_updates = 5
    dataset = "mini-imagenet"

    if dataset == "omniglot":
        imgsz = 28
        db = OmniglotNShot(
            "dataset",
            batchsz=meta_batchsz,
            n_way=n_way,
            k_shot=k_shot,
            k_query=k_query,
            imgsz=imgsz,
        )

    elif dataset == "mini-imagenet":
        imgsz = 84
        # The data pipelines differ between the two datasets: Omniglot uses a
        # single loader whose get_batch('train'/'test') serves either split,
        # while mini-imagenet needs two DataLoaders, one for train and one for test.
        mini = MiniImagenet(
            "../../hdd1/meta/mini-imagenet/",
            mode="train",
            n_way=n_way,
            k_shot=k_shot,
            k_query=k_query,
            batchsz=10000,
            resize=imgsz,
        )
        db = DataLoader(mini,
                        meta_batchsz,
                        shuffle=True,
                        num_workers=4,
                        pin_memory=True)
        mini_test = MiniImagenet(
            "../../hdd1/meta/mini-imagenet/",
            mode="test",
            n_way=n_way,
            k_shot=k_shot,
            k_query=k_query,
            batchsz=1000,
            resize=imgsz,
        )
        db_test = DataLoader(mini_test,
                             meta_batchsz,
                             shuffle=True,
                             num_workers=2,
                             pin_memory=True)

    else:
        raise NotImplementedError

    # do NOT call .cuda() implicitly
    net = CSML()
    net.deploy()

    tb = SummaryWriter("runs")

    # main loop
    for episode_num in range(200000):

        # 1. train
        if dataset == "omniglot":
            support_x, support_y, query_x, query_y = db.get_batch("train")
            support_x = Variable(
                torch.from_numpy(support_x).float().transpose(2, 4).transpose(
                    3, 4).repeat(1, 1, 3, 1, 1)).cuda()
            query_x = Variable(
                torch.from_numpy(query_x).float().transpose(2, 4).transpose(
                    3, 4).repeat(1, 1, 3, 1, 1)).cuda()
            support_y = Variable(torch.from_numpy(support_y).long()).cuda()
            query_y = Variable(torch.from_numpy(query_y).long()).cuda()

        elif dataset == "mini-imagenet":
            try:
                batch_train = next(iter(db))
            except StopIteration:
                mini = MiniImagenet(
                    "../../hdd1/meta/mini-imagenet/",
                    mode="train",
                    n_way=n_way,
                    k_shot=k_shot,
                    k_query=k_query,
                    batchsz=10000,
                    resize=imgsz,
                )
                db = DataLoader(mini,
                                meta_batchsz,
                                shuffle=True,
                                num_workers=4,
                                pin_memory=True)
                batch_train = next(iter(db))

            support_x = Variable(batch_train[0])
            support_y = Variable(batch_train[1])
            query_x = Variable(batch_train[2])
            query_y = Variable(batch_train[3])
            print(support_x.size(), support_y.size())

        # backprop is embedded in the forward pass of CSML.
        accs = net.train(support_x, support_y, query_x, query_y)
        train_acc = np.array(accs).mean()

        # 2. test
        if episode_num % 30 == 0:
            test_accs = []
            for i in range(min(episode_num // 5000 + 3,
                               10)):  # get average acc.
                if dataset == "omniglot":
                    support_x, support_y, query_x, query_y = db.get_batch(
                        "test")
                    support_x = Variable(
                        torch.from_numpy(support_x).float().transpose(
                            2, 4).transpose(3, 4).repeat(1, 1, 3, 1,
                                                         1)).cuda()
                    query_x = Variable(
                        torch.from_numpy(query_x).float().transpose(
                            2, 4).transpose(3, 4).repeat(1, 1, 3, 1,
                                                         1)).cuda()
                    support_y = Variable(
                        torch.from_numpy(support_y).long()).cuda()
                    query_y = Variable(torch.from_numpy(query_y).long()).cuda()

                elif dataset == "mini-imagenet":
                    try:
                        batch_test = next(iter(db_test))
                    except StopIteration:
                        mini_test = MiniImagenet(
                            "../../hdd1/meta/mini-imagenet/",
                            mode="test",
                            n_way=n_way,
                            k_shot=k_shot,
                            k_query=k_query,
                            batchsz=1000,
                            resize=imgsz,
                        )
                        db_test = DataLoader(
                            mini_test,
                            meta_batchsz,
                            shuffle=True,
                            num_workers=2,
                            pin_memory=True,
                        )
                        batch_test = next(iter(db_test))

                    support_x = Variable(batch_test[0])
                    support_y = Variable(batch_test[1])
                    query_x = Variable(batch_test[2])
                    query_y = Variable(batch_test[3])

                # get accuracy on this test batch
                test_acc = net.train(support_x, support_y, query_x, query_y,
                                     train=False)
                test_accs.append(test_acc)

            test_acc = np.array(test_accs).mean()
            print(
                "episode:",
                episode_num,
                "\tfinetune acc:%.6f" % train_acc,
                "\t\ttest acc:%.6f" % test_acc,
            )
            tb.add_scalar("test-acc", test_acc, episode_num)
            tb.add_scalar("finetune-acc", train_acc, episode_num)
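
Each Omniglot branch above repeats the same four-line numpy-to-CUDA conversion. A hypothetical helper (same transpose/repeat logic as the inline code, not part of the original repository) would remove that duplication:

import torch


def omniglot_batch_to_cuda(support_x, support_y, query_x, query_y):
    # channel-last grayscale numpy batches -> channel-first, 3-channel CUDA tensors
    def to_img(x):
        t = torch.from_numpy(x).float().transpose(2, 4).transpose(3, 4)
        return t.repeat(1, 1, 3, 1, 1).cuda()

    return (to_img(support_x), torch.from_numpy(support_y).long().cuda(),
            to_img(query_x), torch.from_numpy(query_y).long().cuda())
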
Example 5
import argparse
import os

import numpy as np
import torch
from torch import optim
from torch.autograd import Variable
from torch.optim import lr_scheduler

# NaiveRN and OmniglotNShot are assumed to be provided by the surrounding
# repository.


def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument('-n', help='n way', required=True)
    argparser.add_argument('-k', help='k shot', required=True)
    argparser.add_argument('-b', help='batch size', required=True)
    argparser.add_argument('-l', help='learning rate', default=1e-3)
    argparser.add_argument('-t',
                           help='threshold to test all episodes',
                           default=0.97)
    args = argparser.parse_args()
    n_way = int(args.n)
    k_shot = int(args.k)
    k_query = 1
    batchsz = int(args.b)
    imgsz = 84
    lr = float(args.l)
    threshold = float(args.t)

    db = OmniglotNShot('dataset',
                       batchsz=batchsz,
                       n_way=n_way,
                       k_shot=k_shot,
                       k_query=k_query,
                       imgsz=imgsz)
    print('Omniglot: no rotate!  %d-way %d-shot  lr:%f' % (n_way, k_shot, lr))
    net = NaiveRN(n_way, k_shot, imgsz).cuda()
    print(net)
    mdl_file = 'ckpt/omni%d%d.mdl' % (n_way, k_shot)

    if os.path.exists(mdl_file):
        print('recover from state: ', mdl_file)
        net.load_state_dict(torch.load(mdl_file))
    else:
        print('training from scratch.')

    model_parameters = filter(lambda p: p.requires_grad, net.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print('Total params:', params)

    support_x, support_y, query_x, query_y = db.get_batch(
        'train')  # sanity check: (batch, n_way*k_shot, img)
    print('get batch:', support_x.shape, query_x.shape, support_y.shape,
          query_y.shape)

    optimizer = optim.Adam(net.parameters(), lr=lr)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               'max',
                                               factor=0.5,
                                               patience=15,
                                               verbose=True)

    total_train_loss = 0
    for step in range(100000000):

        # 1. test
        if step % 400 == 0:
            # evaluation is disabled in this snippet; the full call would be:
            # accuracy, _ = evaluation(net, batchsz, n_way, k_shot, imgsz, 300, threshold, mdl_file)
            accuracy, _ = 0, 0
            scheduler.step(accuracy)

        # 2. train
        support_x, support_y, query_x, query_y = db.get_batch('train')
        support_x = Variable(
            torch.from_numpy(support_x).float().transpose(2, 4).transpose(
                3, 4).repeat(1, 1, 3, 1, 1)).cuda()
        query_x = Variable(
            torch.from_numpy(query_x).float().transpose(2, 4).transpose(
                3, 4).repeat(1, 1, 3, 1, 1)).cuda()
        support_y = Variable(torch.from_numpy(support_y).int()).cuda()
        query_y = Variable(torch.from_numpy(query_y).int()).cuda()

        loss = net(support_x, support_y, query_x, query_y)
        total_train_loss += loss.data[0]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 3. print
        if step % 20 == 0 and step != 0:
            print('%d-way %d-shot %d batch> step:%d, loss over last 20 steps:%f' %
                  (n_way, k_shot, batchsz, step, total_train_loss))
            total_train_loss = 0
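
The scheduler above runs in 'max' mode: it expects the tracked metric (here, accuracy) to increase, and multiplies the learning rate by factor once more than patience evaluations pass without improvement. A self-contained demo with a short patience so the reduction is visible:

from torch import nn, optim
from torch.optim import lr_scheduler

net = nn.Linear(4, 2)
opt = optim.Adam(net.parameters(), lr=1e-3)
sched = lr_scheduler.ReduceLROnPlateau(opt, 'max', factor=0.5, patience=1)

for acc in [0.50, 0.52, 0.52, 0.51]:  # two evaluations without improvement
    sched.step(acc)
print(opt.param_groups[0]['lr'])  # 0.0005: the rate was halved once
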