示例#1
0
def main():
    """Train a MAML network on mini-ImageNet episodes using command-line settings.

    Options (all optional):
        -n  n way (default 5)
        -k  k shot (default 1)
        -b  batch size: number of episodes fetched per step, also passed to
            MAML as ``meta_batchsz`` (default 4)
        -l  learning rate (default 1e-3); within this function it is only
            echoed in the start-up banner, it is not handed to the model.

    If a checkpoint exists at ``ckpt/maml<n><k>.mdl`` it is loaded before
    training starts. Training then runs for 1000 epochs and prints the value
    returned by the network every 10 steps.
    """
    argparser = argparse.ArgumentParser()
    # type= makes argparse reject malformed values with a usage message
    # instead of failing later with a bare ValueError from int()/float().
    argparser.add_argument('-n', help='n way', type=int, default=5)
    argparser.add_argument('-k', help='k shot', type=int, default=1)
    argparser.add_argument('-b', help='batch size', type=int, default=4)
    argparser.add_argument('-l', help='learning rate', type=float, default=1e-3)
    args = argparser.parse_args()
    n_way = args.n
    k_shot = args.k
    meta_batchsz = args.b
    lr = args.l

    k_query = 1
    imgsz = 84
    # threshold for when to test full version of episode
    # (only reported in the banner below within this function)
    threshold = 0.699 if k_shot == 5 else 0.584
    mdl_file = 'ckpt/maml%d%d.mdl' % (n_way, k_shot)
    print('mini-imagnet: %d-way %d-shot lr:%f, threshold:%f' % (n_way, k_shot, lr, threshold))

    device = torch.device('cuda')
    net = MAML(n_way, k_shot, k_query, meta_batchsz=meta_batchsz, K=5, device=device)
    print(net)

    if os.path.exists(mdl_file):
        print('load from checkpoint ...', mdl_file)
        # map_location keeps the load working when the checkpoint was saved
        # from a different CUDA device index than the one available here.
        net.load_state_dict(torch.load(mdl_file, map_location=device))
    else:
        print('training from scratch.')

    # Number of trainable parameters. numel() always yields an exact int,
    # whereas np.prod(p.size()) returns the float 1.0 for 0-dim parameters.
    params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print('Total params:', params)

    for epoch in range(1000):
        # batchsz here means total episode number.
        # NOTE(review): the dataset is rebuilt every epoch, presumably so that
        # a fresh set of episodes is drawn -- confirm against MiniImagenet.
        mini = MiniImagenet('/hdd1/liangqu/datasets/miniimagenet/', mode='train',
                            n_way=n_way, k_shot=k_shot, k_query=k_query,
                            batchsz=10000, resize=imgsz)
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, meta_batchsz, shuffle=True, num_workers=8, pin_memory=True)

        for step, batch in enumerate(db):
            # Move the four tensors of the batch onto the training device.
            support_x = batch[0].to(device)
            support_y = batch[1].to(device)
            query_x = batch[2].to(device)
            query_y = batch[3].to(device)

            accs = net(support_x, support_y, query_x, query_y, training=True)

            if step % 10 == 0:
                print(accs)
示例#2
0
# Meta-test split of mini-ImageNet, configured with ways=5, shots=5 and
# test_shots=15; the data is downloaded into "data" if necessary.
testset = miniimagenet("data",
                       ways=5,
                       shots=5,
                       test_shots=15,
                       meta_test=True,
                       download=True)
# Two tasks per batch, shuffled, loaded by four worker processes.
testloader = BatchMetaDataLoader(testset,
                                 batch_size=2,
                                 num_workers=4,
                                 shuffle=True)
evaliter = iter(testloader)

# Restore the trained weights. map_location makes the load independent of the
# device the checkpoint was saved from (e.g. saved on GPU, evaluated on CPU).
model_path = './model/model.pth'
model = MAML().to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
loss_fn = torch.nn.CrossEntropyLoss().to(device)

test_loss_log = []
test_acc_log = []

# Evaluate on 1000 batches drawn from the meta-test loader.
for i in range(1000):
    # Use the built-in next(): Python 3 iterators implement __next__ and do
    # not provide the Python 2 style .next() method.
    evalbatch = next(evaliter)
    # NOTE(review): eval mode is re-applied every iteration, presumably
    # because test() may change the mode while adapting -- confirm.
    model.eval()
    testloss, testacc = test(model,
                             evalbatch,
                             loss_fn,
                             lr=0.01,
                             train_step=10,
                             device=device)