コード例 #1
0
def main():
    """Meta-train a MAML model and periodically evaluate it on the test split.

    Relies on module-level names defined elsewhere in this file:
    ``args`` (hyper-parameters), ``Param`` (paths / device / net config),
    ``inf_get`` (wraps a DataLoader into an endless iterator -- TODO confirm),
    and ``plot`` (scalar-logging helper).
    """
    # Fix all RNG seeds for reproducibility.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    test_result = {}  # epoch -> per-adaptation-step accuracy list, dumped to test.json
    best_acc = 0.0    # best last-step test accuracy seen so far

    maml = Meta(args, Param.config).to(Param.device)
    maml = torch.nn.DataParallel(maml)
    opt = optim.Adam(maml.parameters(), lr=args.meta_lr)
    #opt = optim.SGD(maml.parameters(), lr=args.meta_lr, momentum=0.9, weight_decay=args.weight_decay)  

    # Count trainable parameters (product of each tensor's shape).
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    trainset = MiniImagenet(Param.root, mode='train', n_way=args.n_way, k_shot=args.k_spt, k_query=args.k_qry, resize=args.imgsz)
    testset = MiniImagenet(Param.root, mode='test', n_way=args.n_way, k_shot=args.k_spt, k_query=args.k_qry, resize=args.imgsz)
    trainloader = DataLoader(trainset, batch_size=args.task_num, shuffle=True, num_workers=4, drop_last=True)
    testloader = DataLoader(testset, batch_size=4, shuffle=True, num_workers=4, drop_last=True)
    train_data = inf_get(trainloader)
    test_data = inf_get(testloader)

    for epoch in range(args.epoch):  # one meta-batch of tasks per "epoch"
        support_x, support_y, meta_x, meta_y = train_data.__next__()
        support_x, support_y, meta_x, meta_y = support_x.to(Param.device), support_y.to(Param.device), meta_x.to(Param.device), meta_y.to(Param.device)
        # .mean() reduces the per-replica losses returned through DataParallel.
        meta_loss = maml(support_x, support_y, meta_x, meta_y).mean()
        opt.zero_grad()
        meta_loss.backward()
        torch.nn.utils.clip_grad_value_(maml.parameters(), clip_value = 10.0)
        opt.step()
        plot.plot('meta_loss', meta_loss.item())

        # Evaluate every 2000 epochs (first at epoch 999, then 2999, ...).
        if(epoch % 2000 == 999):
            ans = None
            # Clone so test-time adaptation cannot disturb the training weights.
            maml_clone = deepcopy(maml)
            for _ in range(600):
                support_x, support_y, qx, qy = test_data.__next__()
                support_x, support_y, qx, qy = support_x.to(Param.device), support_y.to(Param.device), qx.to(Param.device), qy.to(Param.device)
                temp = maml_clone(support_x, support_y, qx, qy, meta_train = False)
                if(ans is None):
                    ans = temp
                else:
                    ans = torch.cat([ans, temp], dim = 0)
            # Average over all test tasks; one accuracy per adaptation step.
            ans = ans.mean(dim = 0).tolist()
            test_result[epoch] = ans
            # ans[-1] is the accuracy after the final inner-loop update.
            if (ans[-1] > best_acc):
                best_acc = ans[-1]
                torch.save(maml.state_dict(), Param.out_path + 'net_'+ str(epoch) + '_' + str(best_acc) + '.pkl') 
            del maml_clone
            print(str(epoch) + ': '+str(ans))
            with open(Param.out_path+'test.json','w') as f:
                json.dump(test_result,f)
        if (epoch < 5) or (epoch % 100 == 99):
            plot.flush()
        plot.tick()
コード例 #2
0
def main():
    """Evaluate a cloudpickled MAML model on the test split of ./flower/.

    Loads the model from ``model.pkl``, finetunes it on every test episode's
    support set via ``finetunning``, and prints the query accuracy averaged
    over all episodes. Relies on module-level ``args``, ``cloudpickle``,
    ``MiniImagenet`` and ``DataLoader``.
    """
    # Fix RNG seeds for reproducibility.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    device = torch.device('cuda:' + str(args.gpu_index))
    with open('model.pkl', 'rb') as f:
        # NOTE(review): unpickling executes arbitrary code -- only load trusted files.
        maml = cloudpickle.load(f).to(device)

    # Count trainable parameters (product of each tensor's shape).
    trainable = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), trainable))
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    mini_test = MiniImagenet('./flower/',
                             mode='test',
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batchsz=100,
                             resize=args.imgsz)
    db_test = DataLoader(mini_test,
                         1,
                         shuffle=True,
                         num_workers=1,
                         pin_memory=True)
    accs_all_test = []

    for x_spt, y_spt, x_qry, y_qry in db_test:
        # The loader batches one episode at a time; squeeze the batch dim away.
        x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                     x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

        accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
        accs_all_test.append(accs)

    # [b, update_step+1] -> mean accuracy per adaptation step.
    accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
    # Bug fix: the original printed an undefined name ``step`` here (NameError).
    print('test acc:', accs)
コード例 #3
0
def main():
    """Evaluate a trained MAML checkpoint on the mini-Imagenet test split.

    Loads ``args.ckpt`` from ``Param.out_path``, runs 600 meta-test batches
    (4 tasks each) and writes the per-adaptation-step accuracies to
    ``test.json``. Relies on module-level ``args``, ``Param``, ``n_gpus``
    and ``inf_get`` defined elsewhere in this file.
    """
    # Fix RNG seeds for reproducibility.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    test_result = {}  # NOTE(review): unused in this function
    best_acc = 0.0    # NOTE(review): unused in this function

    maml = Meta(args, Param.config).to(Param.device)
    if n_gpus>1:
        maml = torch.nn.DataParallel(maml)
    state_dict = torch.load(Param.out_path+args.ckpt)
    print(state_dict.keys())
    # Checkpoints were presumably saved from a DataParallel model, whose keys
    # carry a 'module.' prefix (7 chars) -- TODO confirm. On a single GPU the
    # prefix is stripped; with DataParallel keys are kept (k[0:] is just k).
    pretrained_dict = OrderedDict()
    for k in state_dict.keys():
        if n_gpus==1:
            pretrained_dict[k[7:]] = deepcopy(state_dict[k])
        else:
            pretrained_dict[k[0:]] = deepcopy(state_dict[k])
    maml.load_state_dict(pretrained_dict)
    print("Load from ckpt:", Param.out_path+args.ckpt)
    
    #opt = optim.Adam(maml.parameters(), lr=args.meta_lr)
    #opt = optim.SGD(maml.parameters(), lr=args.meta_lr, momentum=0.9, weight_decay=args.weight_decay)  

    # Count trainable parameters (product of each tensor's shape).
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    #trainset = MiniImagenet(Param.root, mode='train', n_way=args.n_way, k_shot=args.k_spt, k_query=args.k_qry, resize=args.imgsz)
    #valset = MiniImagenet(Param.root, mode='val', n_way=args.n_way, k_shot=args.k_spt, k_query=args.k_qry, resize=args.imgsz)
    testset = MiniImagenet(Param.root, mode='test', n_way=args.n_way, k_shot=args.k_spt, k_query=args.k_qry, resize=args.imgsz)
    #trainloader = DataLoader(trainset, batch_size=args.task_num, shuffle=True, num_workers=4, drop_last=True)
    #valloader = DataLoader(valset, batch_size=4, shuffle=True, num_workers=4, drop_last=True)
    testloader = DataLoader(testset, batch_size=4, shuffle=True, num_workers=4, drop_last=True)
    #train_data = inf_get(trainloader)
    #val_data = inf_get(valloader)
    test_data = inf_get(testloader)
    
    """Test for 600 epochs (each has 4 tasks)"""
    ans = None
    # Clone so test-time adaptation cannot disturb the loaded weights.
    maml_clone = deepcopy(maml)
    for itr in range(600): # 600x4 test tasks
        support_x, support_y, qx, qy = test_data.__next__()
        support_x, support_y, qx, qy = support_x.to(Param.device), support_y.to(Param.device), qx.to(Param.device), qy.to(Param.device)
        temp = maml_clone(support_x, support_y, qx, qy, meta_train = False)
        if(ans is None):
            ans = temp
        else:
            ans = torch.cat([ans, temp], dim = 0)
        # Print a running average every 100 iterations.
        if itr%100==0:
            print(itr,ans.mean(dim = 0).tolist())
    # Final average over all tasks; one accuracy per adaptation step.
    ans = ans.mean(dim = 0).tolist()
    print('Acc: '+str(ans))
    with open(Param.out_path+'test.json','w') as f:
        json.dump(ans,f)
コード例 #4
0
ファイル: maml_train.py プロジェクト: ml-lab/MAML-Pytorch
def main():
    """Meta-train a MAML network on mini-Imagenet.

    Parses -n/-k/-b/-l from the command line, resumes from an existing
    checkpoint for this (n_way, k_shot) configuration when one is found,
    and prints training accuracies every 10 meta-steps.
    """
    argparser = argparse.ArgumentParser()
    # Defaults are raw values; explicit int()/float() conversion happens below.
    for flag, desc, dflt in (('-n', 'n way', 5),
                             ('-k', 'k shot', 1),
                             ('-b', 'batch size', 4),
                             ('-l', 'learning rate', 1e-3)):
        argparser.add_argument(flag, help=desc, default=dflt)
    args = argparser.parse_args()

    n_way, k_shot = int(args.n), int(args.k)
    meta_batchsz, lr = int(args.b), float(args.l)

    k_query = 1
    imgsz = 84
    # threshold for when to test full version of episode
    threhold = 0.699 if k_shot == 5 else 0.584
    mdl_file = 'ckpt/maml%d%d.mdl' % (n_way, k_shot)
    print('mini-imagnet: %d-way %d-shot lr:%f, threshold:%f' % (n_way, k_shot, lr, threhold))

    device = torch.device('cuda')
    net = MAML(n_way, k_shot, k_query, meta_batchsz=meta_batchsz, K=5, device=device)
    print(net)

    # Resume from a previous run when a checkpoint exists.
    if os.path.exists(mdl_file):
        print('load from checkpoint ...', mdl_file)
        net.load_state_dict(torch.load(mdl_file))
    else:
        print('training from scratch.')

    # Report the total number of trainable parameters.
    params = sum(np.prod(p.size()) for p in net.parameters() if p.requires_grad)
    print('Total params:', params)

    for epoch in range(1000):
        # batchsz here means total episode number
        mini = MiniImagenet('/hdd1/liangqu/datasets/miniimagenet/', mode='train',
                            n_way=n_way, k_shot=k_shot, k_query=k_query,
                            batchsz=10000, resize=imgsz)
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, meta_batchsz, shuffle=True, num_workers=8, pin_memory=True)

        for step, batch in enumerate(db):
            # 2. train
            support_x, support_y, query_x, query_y = (t.to(device) for t in batch)
            accs = net(support_x, support_y, query_x, query_y, training=True)
            if step % 10 == 0:
                print(accs)
コード例 #5
0
def main():
    """Evaluate a trained MAML checkpoint on a mini-Imagenet test episode.

    Builds a fresh ``Meta`` learner, loads ``./checkpoint_miniimage.pth``,
    finetunes on each test episode's support set, and prints the query
    accuracy per adaptation step. Relies on module-level ``args``, ``Meta``,
    ``MiniImagenet`` and ``DataLoader``.
    """
    # Fix RNG seeds for reproducibility.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)
    print(args)

    # 4-conv backbone + linear head; conv2d entries are
    # [out_ch, in_ch, kh, kw, stride, padding].
    config = [('conv2d', [32, 3, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 1, 0]),
              ('flatten', []), ('linear', [args.n_way, 32 * 5 * 5])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    ckpt_dir = "./checkpoint_miniimage.pth"
    print("Load trained model")
    ckpt = torch.load(ckpt_dir)
    maml.load_state_dict(ckpt['model'])

    # batchsz=1: the dataset provides a single test episode.
    mini_test = MiniImagenet("F:\\ACV_project\\MAML-Pytorch\\miniimagenet\\",
                             mode='test',
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batchsz=1,
                             resize=args.imgsz)

    db_test = DataLoader(mini_test,
                         1,
                         shuffle=True,
                         num_workers=1,
                         pin_memory=True)
    accs_all_test = []

    for x_spt, y_spt, x_qry, y_qry in db_test:
        # Drop the loader's batch dim (one episode per batch).
        x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
        x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

        accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
        accs_all_test.append(accs)

    # Bug fix: the averaging/print below used to sit inside the loop, printing
    # a partial average after every episode. Compute the final
    # [update_step+1] average once, after all episodes have run.
    accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
    print('Test acc:', accs)
コード例 #6
0
def main():
    """Search/train a NAS-based MAML model (DARTS-style) on mini-Imagenet.

    Sets up logging, tensorboard and checkpointing, then alternates
    meta-training (with an ``Architect`` updating architecture weights) and
    meta-testing each epoch, checkpointing the best model. Relies on
    module-level ``args`` and project classes/functions ``Saver``,
    ``TensorboardSummary``, ``Meta``, ``MiniImagenet``, ``Architect``,
    ``meta_train`` and ``meta_test`` defined elsewhere.
    """
    # Fix RNG seeds for reproducibility.
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    cudnn.enabled = True

    saver = Saver(args)
    # set log: file + console handlers at INFO level
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p',
                        filename=os.path.join(saver.experiment_dir, 'log.txt'),
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger().addHandler(console)

    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # Snapshot source/config files alongside the experiment for provenance.
    saver.create_exp_dir(scripts_to_save=glob.glob('*.py') +
                         glob.glob('*.sh') + glob.glob('*.yml'))
    saver.save_experiment_config()
    summary = TensorboardSummary(saver.experiment_dir)
    writer = summary.create_summary()
    best_pred = 0  # best last-step test accuracy so far

    logging.info(args)

    device = torch.device('cuda')
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device)
    maml = Meta(args, criterion).to(device)

    # Count trainable parameters (product of each tensor's shape).
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    logging.info(maml)
    logging.info('Total trainable tensors: {}'.format(num))

    # batch_size here means total episode number
    # NOTE(review): all three datasets use mode='train'; train vs. valid/test
    # are separated only by the `split` portions -- confirm this is intended.
    mini = MiniImagenet(args.data_path,
                        mode='train',
                        n_way=args.n_way,
                        k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batch_size=args.batch_size,
                        resize=args.img_size,
                        split=[0, args.train_portion])
    mini_valid = MiniImagenet(args.data_path,
                              mode='train',
                              n_way=args.n_way,
                              k_shot=args.k_spt,
                              k_query=args.k_qry,
                              batch_size=args.batch_size,
                              resize=args.img_size,
                              split=[args.train_portion, 1])
    mini_test = MiniImagenet(args.data_path,
                             mode='train',
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batch_size=args.test_batch_size,
                             resize=args.img_size,
                             split=[args.train_portion, 1])
    train_queue = DataLoader(mini,
                             args.meta_batch_size,
                             shuffle=True,
                             num_workers=args.num_workers,
                             pin_memory=True)
    valid_queue = DataLoader(mini_valid,
                             args.meta_batch_size,
                             shuffle=True,
                             num_workers=args.num_workers,
                             pin_memory=True)
    test_queue = DataLoader(mini_test,
                            args.meta_test_batch_size,
                            shuffle=True,
                            num_workers=args.num_workers,
                            pin_memory=True)
    # Architect performs the bilevel architecture-parameter updates.
    architect = Architect(maml.model, args)

    for epoch in range(args.epoch):
        # fetch batch_size num of episode each time
        logging.info('--------- Epoch: {} ----------'.format(epoch))

        train_accs = meta_train(train_queue, valid_queue, maml, architect,
                                device, criterion, epoch, writer)
        logging.info('[Epoch: {}]\t Train acc: {}'.format(epoch, train_accs))
        valid_accs = meta_test(test_queue, maml, device, epoch, writer)
        logging.info('[Epoch: {}]\t Test acc: {}'.format(epoch, valid_accs))

        # Log the currently derived architecture.
        genotype = maml.model.genotype()
        logging.info('genotype = %s', genotype)

        # logging.info(F.softmax(maml.model.alphas_normal, dim=-1))
        logging.info(F.softmax(maml.model.alphas_reduce, dim=-1))

        # Save the best meta model.
        # valid_accs[-1]: accuracy after the final adaptation step.
        new_pred = valid_accs[-1]
        if new_pred > best_pred:
            is_best = True
            best_pred = new_pred
        else:
            is_best = False
        # Unwrap DataParallel before saving so keys have no 'module.' prefix.
        saver.save_checkpoint(
            {
                'epoch':
                epoch,
                'state_dict':
                maml.module.state_dict()
                if isinstance(maml, nn.DataParallel) else maml.state_dict(),
                'best_pred':
                best_pred,
            }, is_best)
コード例 #7
0
def main():
    """Train or evaluate a MAML model, with optional cross-domain testing.

    ``args.mode == 'test'`` runs (ensemble/adversarial) finetuning evaluation
    on the chosen domain and returns; otherwise it meta-trains on
    mini-Imagenet, validating and checkpointing every 200 steps. Relies on
    module-level ``args``, ``argv``, ``Meta``, ``MiniImagenet`` and
    ``DataLoader`` defined elsewhere in this file.
    """

    start_time = time.time()
    # Fix RNG seeds for reproducibility.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    print(args)
    print(argv)
    # Ensure the checkpoint directory (first path component) exists.
    os.makedirs(args.modelfile.split('/')[0], exist_ok=True)

    # 4-conv backbone + linear head; conv2d entries are
    # [out_ch, in_ch, kh, kw, stride, padding].
    config = [
        ('conv2d', [32, 3, 3, 3, 1, 1]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 1]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 1]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 1]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('flatten', []),
        ('linear', [args.n_way, 32 * 5 * 5])
    ]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    # Count trainable parameters (product of each tensor's shape).
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # Training always draws from mini-Imagenet; the test/val domain varies.
    mini = MiniImagenet('./dataset/mini-imagenet/', mode='train', n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batchsz=10000, resize=args.imgsz)
    # NOTE(review): mini_val is only bound in the 'mini' branch, but the
    # training loop below uses it for every domain -- training with another
    # domain would raise NameError at the validation step. Confirm intended.
    if args.domain == 'mini':
        mini_test = MiniImagenet('./dataset/mini-imagenet/', mode='test', n_way=args.n_way, k_shot=args.k_spt,
                                 k_query=args.k_qry,
                                 batchsz=args.test_iter, resize=args.imgsz)
        mini_val = MiniImagenet('./dataset/mini-imagenet/', mode='val', n_way=args.n_way, k_shot=args.k_spt,
                                 k_query=args.k_qry,
                                 batchsz=args.test_iter, resize=args.imgsz)
    elif args.domain == 'cub':
        print("CUB dataset")
        mini_test = MiniImagenet('./dataset/CUB_200_2011/', mode='test', n_way=args.n_way, k_shot=args.k_spt,
                                 k_query=args.k_qry,
                                 batchsz=args.test_iter, resize=args.imgsz)
    elif args.domain == 'traffic':
        print("Traffic dataset")
        mini_test = MiniImagenet('./dataset/GTSRB/', mode='test', n_way=args.n_way, k_shot=args.k_spt,
                                 k_query=args.k_qry,
                                 batchsz=args.test_iter, resize=args.imgsz)
    elif args.domain == 'flower':
        print("flower dataset")
        mini_test = MiniImagenet('./dataset/102flowers/', mode='test', n_way=args.n_way, k_shot=args.k_spt,
                                 k_query=args.k_qry,
                                 batchsz=args.test_iter, resize=args.imgsz)
    else:
        print("Dataset Error")
        return

    # Test-only path: evaluate each test episode and report the average.
    if args.mode == 'test':
        count = 0
        db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=6, pin_memory=True)
        accs_all_test = []

        for x_spt, y_spt, x_qry, y_qry in db_test:
            print(count)
            count += 1
            # Drop the loader's batch dim (one episode per batch).
            x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                         x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
            accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry, 'test', args.modelfile, pertub_scale=args.pertub_scale, num_ensemble=args.num_ensemble, fgsm_epsilon=args.fgsm_epsilon)
            accs_all_test.append(accs)

        # [b, update_step+1]
        accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
        np.set_printoptions(linewidth=1000)
        print("Running Time:", time.time()-start_time)
        print(accs)
        return

    # Training path. The dataset yields 10000 episodes per epoch, so
    # args.epoch//10000 outer iterations cover args.epoch episodes total.
    for epoch in range(args.epoch//10000):
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=4, pin_memory=True)

        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):

            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

            accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                print('epoch:', epoch, 'step:', step, '\ttraining acc:', accs)

            # Periodically save and run validation + test evaluation.
            if step % 200 == 0:
                print("Save model", args.modelfile)
                torch.save(maml, args.modelfile)
                db_test = DataLoader(mini_val, 1, shuffle=True, num_workers=4, pin_memory=True)
                accs_all_val = []
                for x_spt, y_spt, x_qry, y_qry in db_test:
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
                    accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry, 'train_test')
                    accs_all_val.append(accs)
                
                db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=4, pin_memory=True)
                accs_all_test = []
                for x_spt, y_spt, x_qry, y_qry in db_test:
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
                    accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                    accs_all_test.append(accs)
                    
                # [b, update_step+1]
                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                accs_val = np.array(accs_all_val).mean(axis=0).astype(np.float16)

                # Checkpoint name embeds epoch/step and val/test last-step accuracy.
                save_modelfile = "{}_{}_{}_{:0.4f}_{:0.4f}.pth".format(args.modelfile, epoch, step, accs_val[-1], accs[-1])
                print(save_modelfile)
                torch.save(maml, save_modelfile) 
                print("Val:", accs_val)
                print("Test:", accs)
コード例 #8
0
ファイル: meta.py プロジェクト: dragen1860/Meta-Relation
        print('load pretrained mdl ...', pretrain_mdl_file)
        meta.load_state_dict(torch.load(pretrain_mdl_file))

    model_parameters = filter(lambda p: p.requires_grad, meta.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print('total params:', params)

    optimizer = optim.Adam(meta.parameters(), lr=1e-5, weight_decay=1e-6)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose=True)

    best_train_acc = 0
    for epoch in range(1000):
        mini = MiniImagenet('../mini-imagenet/',
                            mode='train',
                            n_way=n_way,
                            k_shot=k_shot,
                            k_query=k_query,
                            batchsz=500,
                            resize=resize)
        db = DataLoader(mini, batchsz, shuffle=True, num_workers=6)

        for step, batch in enumerate(db):
            # batch : ([10,10,3,84,84], [10,10], [10,75,3,84,84], [10,75])
            support_x = Variable(batch[0]).cuda()
            support_y = Variable(batch[1]).cuda()
            query_x = Variable(batch[2]).cuda()
            query_y = Variable(batch[3]).cuda()

            meta.train()
            cls_loss, train_acc = meta.pretrain(support_x, support_y, query_x,
                                                query_y)
コード例 #9
0
ファイル: mainv0.py プロジェクト: jizongFox/MAML-Pytorch
def main():
    """Meta-train a CSML net on omniglot or mini-Imagenet with periodic tests.

    Relies on module-level ``OmniglotNShot``, ``MiniImagenet``, ``CSML``,
    ``DataLoader``, ``Variable`` and ``SummaryWriter`` defined elsewhere.

    NOTE(review): several latent issues are flagged inline below -- the test
    branch is unreachable, ``test_acc`` has no producer, and the Python 2
    ``iter(...).next()`` idiom is used.
    """
    meta_batchsz = 32 * 3
    n_way = 5
    k_shot = 5
    k_query = k_shot
    meta_lr = 1e-3       # NOTE(review): unused in this function
    num_updates = 5      # NOTE(review): unused in this function
    dataset = "mini-imagenet"

    if dataset == "omniglot":
        imgsz = 28
        db = OmniglotNShot(
            "dataset",
            batchsz=meta_batchsz,
            n_way=n_way,
            k_shot=k_shot,
            k_query=k_query,
            imgsz=imgsz,
        )

    elif dataset == "mini-imagenet":
        imgsz = 84
        # the dataset loaders are different from omniglot to mini-imagenet. for omniglot, it just has one loader to use
        # get_batch(train or test) to get different batch.
        # for mini-imagenet, it should have two dataloader, one is train_loader and another is test_loader.
        mini = MiniImagenet(
            "../../hdd1/meta/mini-imagenet/",
            mode="train",
            n_way=n_way,
            k_shot=k_shot,
            k_query=k_query,
            batchsz=10000,
            resize=imgsz,
        )
        db = DataLoader(mini,
                        meta_batchsz,
                        shuffle=True,
                        num_workers=4,
                        pin_memory=True)
        mini_test = MiniImagenet(
            "../../hdd1/meta/mini-imagenet/",
            mode="test",
            n_way=n_way,
            k_shot=k_shot,
            k_query=k_query,
            batchsz=1000,
            resize=imgsz,
        )
        db_test = DataLoader(mini_test,
                             meta_batchsz,
                             shuffle=True,
                             num_workers=2,
                             pin_memory=True)

    else:
        raise NotImplementedError

    # do NOT call .cuda() implicitly
    net = CSML()
    net.deploy()

    tb = SummaryWriter("runs")

    # main loop
    for episode_num in range(200000):

        # 1. train
        if dataset == "omniglot":
            # NHWC numpy batches -> NCHW float tensors, grayscale repeated to 3 channels.
            support_x, support_y, query_x, query_y = db.get_batch("test")
            support_x = Variable(
                torch.from_numpy(support_x).float().transpose(2, 4).transpose(
                    3, 4).repeat(1, 1, 3, 1, 1)).cuda()
            query_x = Variable(
                torch.from_numpy(query_x).float().transpose(2, 4).transpose(
                    3, 4).repeat(1, 1, 3, 1, 1)).cuda()
            support_y = Variable(torch.from_numpy(support_y).long()).cuda()
            query_y = Variable(torch.from_numpy(query_y).long()).cuda()

        elif dataset == "mini-imagenet":
            # NOTE(review): ``iter(db).next()`` is Python 2 iterator syntax;
            # Python 3 requires ``next(iter(db))``. Also, building a fresh
            # iterator each episode restarts the loader every time.
            try:
                batch_train = iter(db).next()
            except StopIteration as err:
                # Loader exhausted: rebuild the dataset and loader.
                mini = MiniImagenet(
                    "../../hdd1/meta/mini-imagenet/",
                    mode="train",
                    n_way=n_way,
                    k_shot=k_shot,
                    k_query=k_query,
                    batchsz=10000,
                    resize=imgsz,
                )
                db = DataLoader(mini,
                                meta_batchsz,
                                shuffle=True,
                                num_workers=4,
                                pin_memory=True)
                batch_train = iter(db).next()

            support_x = Variable(batch_train[0])
            support_y = Variable(batch_train[1])
            query_x = Variable(batch_train[2])
            query_y = Variable(batch_train[3])
            print(support_x.size(), support_y.size())

        # backprop has been embeded in forward func.
        accs = net.train(support_x, support_y, query_x, query_y)
        train_acc = np.array(accs).mean()

        # 2. test
        # NOTE(review): ``episode_num % 30`` is always < 30, so this condition
        # can never be true -- the whole test branch is dead code as written.
        if episode_num % 30 == 220:
            test_accs = []
            for i in range(min(episode_num // 5000 + 3,
                               10)):  # get average acc.
                if dataset == "omniglot":
                    support_x, support_y, query_x, query_y = db.get_batch(
                        "test")
                    support_x = Variable(
                        torch.from_numpy(support_x).float().transpose(
                            2, 4).transpose(3, 4).repeat(1, 1, 3, 1,
                                                         1)).cuda()
                    query_x = Variable(
                        torch.from_numpy(query_x).float().transpose(
                            2, 4).transpose(3, 4).repeat(1, 1, 3, 1,
                                                         1)).cuda()
                    support_y = Variable(
                        torch.from_numpy(support_y).long()).cuda()
                    query_y = Variable(torch.from_numpy(query_y).long()).cuda()

                elif dataset == "mini-imagenet":
                    try:
                        batch_test = iter(db_test).next()
                    except StopIteration as err:
                        mini_test = MiniImagenet(
                            "../../hdd1/meta/mini-imagenet/",
                            mode="test",
                            n_way=n_way,
                            k_shot=k_shot,
                            k_query=k_query,
                            batchsz=1000,
                            resize=imgsz,
                        )
                        db_test = DataLoader(
                            mini_test,
                            meta_batchsz,
                            shuffle=True,
                            num_workers=2,
                            pin_memory=True,
                        )
                        # NOTE(review): pulls from the *train* loader ``db``,
                        # not the freshly rebuilt ``db_test`` -- likely a bug.
                        batch_test = iter(db).next()

                    support_x = Variable(batch_test[0])
                    support_y = Variable(batch_test[1])
                    query_x = Variable(batch_test[2])
                    query_y = Variable(batch_test[3])

                # get accuracy
                # NOTE(review): the line producing ``test_acc`` is commented
                # out, so the append below would raise NameError if this
                # (currently unreachable) branch ever ran.
                # test_acc = net.train(support_x, support_y, query_x, query_y, train=False)
                test_accs.append(test_acc)

            test_acc = np.array(test_accs).mean()
            print(
                "episode:",
                episode_num,
                "\tfinetune acc:%.6f" % train_acc,
                "\t\ttest acc:%.6f" % test_acc,
            )
            tb.add_scalar("test-acc", test_acc)
            tb.add_scalar("finetune-acc", train_acc)
コード例 #10
0
def evaluation(net, batchsz, n_way, k_shot, imgsz, episodesz, threhold, mdl_file):
    """Evaluate ``net`` on randomly sampled mini-Imagenet test episodes.

    Follows the MAML / Learning-to-Compare protocol: up to 600 episodes with
    15 query images per class. After ``episodesz`` episodes, testing only
    continues to the full 600 if the running accuracy reaches ``threhold``.
    Saves ``net`` to ``mdl_file`` when a new best accuracy is reached, and
    publishes loss/accuracy through module-level globals.

    :param net: model whose forward returns (loss, pred, correct)
    :param batchsz: DataLoader batch size (episodes per batch)
    :param n_way: classes per episode
    :param k_shot: support images per class
    :param imgsz: image resize dimension
    :param episodesz: minimum number of episodes to test
    :param threhold: accuracy needed to continue to the full 600 episodes
                     (name typo kept: part of the public interface)
    :param mdl_file: checkpoint path for the best model
    :return: (accuracy, sem) mean accuracy and its confidence half-width

    NOTE(review): uses legacy PyTorch idioms (``Variable``, ``.data[0]``);
    on PyTorch >= 0.5 ``.data[0]`` on a 0-dim tensor raises -- would need
    ``.item()``.
    """
    k_query = 15
    mini_val = MiniImagenet(
        "../mini-imagenet/",
        mode="test",
        n_way=n_way,
        k_shot=k_shot,
        k_query=k_query,
        batchsz=600,
        resize=imgsz,
    )
    db_val = DataLoader(mini_val, batchsz, shuffle=True, num_workers=2, pin_memory=True)

    accs = []
    episode_num = 0  # record tested num of episodes

    for batch_test in db_val:
        # [60, setsz, c_, h, w]
        # setsz = (5 + 15) * 5
        support_x = Variable(batch_test[0]).cuda()
        support_y = Variable(batch_test[1]).cuda()
        query_x = Variable(batch_test[2]).cuda()
        query_y = Variable(batch_test[3]).cuda()

        # we will split query set into 15 splits.
        # query_x : [batch, 15*way, c_, h, w]
        # query_x_b : tuple, 15 * [b, way, c_, h, w]
        query_x_b = torch.chunk(query_x, k_query, dim=1)
        # query_y : [batch, 15*way]
        # query_y_b: 15* [b, way]
        query_y_b = torch.chunk(query_y, k_query, dim=1)
        preds = []
        net.eval()
        # we don't need the total acc on 600 episodes, but we need the acc per sets of 15*nway setsz.
        total_correct = 0
        total_num = 0
        total_loss = 0
        for query_x_mini, query_y_mini in zip(query_x_b, query_y_b):
            # print('query_x_mini', query_x_mini.size(), 'query_y_mini', query_y_mini.size())
            loss, pred, correct = net(
                support_x, support_y, query_x_mini.contiguous(), query_y_mini, False
            )
            correct = correct.sum()  # multi-gpu
            # pred: [b, nway]
            preds.append(pred)
            total_correct += correct.data[0]
            total_num += query_y_mini.size(0) * query_y_mini.size(1)

            total_loss += loss.data[0]

        # # 15 * [b, nway] => [b, 15*nway]
        # preds = torch.cat(preds, dim= 1)
        acc = total_correct / total_num
        print("%.5f," % acc, end=" ")
        sys.stdout.flush()
        accs.append(acc)

        # update tested episode number
        episode_num += query_y.size(0)
        if episode_num > episodesz:
            # test current tested episodes acc.
            acc = np.array(accs).mean()
            if acc >= threhold:
                # if current acc is very high, we conduct all 600 episodes testing.
                continue
            else:
                # current acc is low, just conduct `episodesz` num of episodes.
                break

    # compute the distribution of 600/episodesz episodes acc.
    global best_accuracy
    accs = np.array(accs)
    accuracy, sem = mean_confidence_interval(accs)
    print("\naccuracy:", accuracy, "sem:", sem)
    print("<<<<<<<<< accuracy:", accuracy, "best accuracy:", best_accuracy, ">>>>>>>>")

    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(net.state_dict(), mdl_file)
        print("Saved to checkpoint:", mdl_file)

    # we only take the last one batch as avg_loss
    total_loss = total_loss / n_way / k_query

    # Publish results for the external logger (write2file reads these).
    global global_test_loss_buff, global_test_acc_buff
    global_test_loss_buff = total_loss
    global_test_acc_buff = accuracy
    write2file(n_way, k_shot)

    return accuracy, sem
コード例 #11
0
def main():
    """Train an active-learning (AL) learner on top of a frozen, pre-trained
    MAML task learner using mini-ImageNet episodes.

    Relies on module-level ``args`` plus the project classes ``Meta``,
    ``AL_Learner`` and ``MiniImagenet``.  NOTE(review): this example is
    truncated at the end — the final triple-quoted string is never closed.
    """

    # Fix all RNG seeds so runs are reproducible.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    ## Task Learner Setup
    # Standard 4-block MAML backbone (conv -> relu -> bn -> max_pool, x4),
    # then flatten and a linear head with args.n_way outputs.
    task_config = [('conv2d', [32, 3, 3, 3, 1, 0]), ('relu', [True]),
                   ('bn', [32]), ('max_pool2d', [2, 2, 0]),
                   ('conv2d', [32, 32, 3, 3, 1, 0]), ('relu', [True]),
                   ('bn', [32]), ('max_pool2d', [2, 2, 0]),
                   ('conv2d', [32, 32, 3, 3, 1, 0]), ('relu', [True]),
                   ('bn', [32]), ('max_pool2d', [2, 2, 0]),
                   ('conv2d', [32, 32, 3, 3, 1, 0]), ('relu', [True]),
                   ('bn', [32]), ('max_pool2d', [2, 1, 0]), ('flatten', []),
                   ('linear', [args.n_way, 32 * 5 * 5])]
    # Scan checkpoints saved at 500-epoch intervals and remember the latest
    # one that exists on disk.
    # NOTE(review): if no epoch-0 checkpoint exists the while loop never runs
    # and `valid_epoch` stays unbound -> NameError below.  Verify a checkpoint
    # is always present before running this script.
    last_epoch = 0
    suffix = "_v0"
    save_path = os.getcwd() + '/data/model_batchsz' + str(
        args.k_model) + '_stepsz' + str(
            args.update_lr) + '_epoch' + str(last_epoch) + suffix + '.pt'
    while os.path.isfile(save_path):
        valid_epoch = last_epoch
        last_epoch += 500
        save_path = os.getcwd() + '/data/model_batchsz' + str(
            args.k_model) + '_stepsz' + str(
                args.update_lr) + '_epoch' + str(last_epoch) + suffix + '.pt'
    # Rebuild the path of the last checkpoint that actually exists.
    save_path = os.getcwd() + '/data/model_batchsz' + str(
        args.k_model) + '_stepsz' + str(
            args.update_lr) + '_epoch' + str(valid_epoch) + suffix + '.pt'

    device = torch.device('cuda')
    # Load the task learner and freeze it (eval mode); only the AL learner
    # below gets trained.
    task_mod = Meta(args, task_config).to(device)
    task_mod.load_state_dict(torch.load(save_path))
    task_mod.eval()

    ## AL Learner Setup
    print(args)

    # The AL learner is a single linear scoring layer over 32*5*5 features.
    al_config = [('linear', [1, 32 * 5 * 5])]

    device = torch.device('cuda')
    maml = AL_Learner(args, al_config, task_mod).to(device)
    # Log trainable-parameter count for a quick sanity check.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    mini = MiniImagenet('/home/tesca/data/miniimagenet/',
                        mode='train',
                        n_way=args.n_way,
                        k_shot=1,
                        k_query=args.k_qry,
                        batchsz=10000,
                        resize=args.imgsz)
    mini_test = MiniImagenet('/home/tesca/data/miniimagenet/',
                             mode='test',
                             n_way=args.n_way,
                             k_shot=1,
                             k_query=args.k_qry,
                             batchsz=100,
                             resize=args.imgsz)
    # Prefix reused when checkpointing the AL net below.
    save_path = os.getcwd() + '/data/model_batchsz' + str(
        args.k_model) + '_stepsz' + str(args.update_lr) + '_epoch'

    for epoch in range(args.epoch // 10000):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini,
                        args.task_num,
                        shuffle=True,
                        num_workers=1,
                        pin_memory=True)

        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):

            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(
                device), x_qry.to(device), y_qry.to(device)

            # Forward pass trains the AL learner on this meta-batch.
            al_accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                print('step:', step, '\tAL acc:', al_accs)

            if step % 500 == 0:  # evaluation
                torch.save(maml.state_dict(),
                           save_path + str(step) + "_al_net.pt")
                # NOTE(review): the scraped example is cut off here — the
                # string below is an unterminated triple-quoted block.
                '''db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
コード例 #12
0
def main():
    """Meta-train MAML on mini-ImageNet, evaluating and checkpointing every
    500 steps.

    Uses module-level ``args`` and the project classes ``Meta`` and
    ``MiniImagenet``.
    """

    # Deterministic runs: seed CPU RNG, every GPU RNG, and numpy.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # Backbone: four conv->relu->bn->pool stages, then flatten + linear head.
    config = []
    for in_channels, pool in ((3, [2, 2, 0]), (32, [2, 2, 0]),
                              (32, [2, 2, 0]), (32, [2, 1, 0])):
        config.extend([('conv2d', [32, in_channels, 3, 3, 1, 0]),
                       ('relu', [True]),
                       ('bn', [32]),
                       ('max_pool2d', pool)])
    config.extend([('flatten', []), ('linear', [args.n_way, 32 * 5 * 5])])

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    # Log the model and its trainable-parameter count.
    trainable = (p for p in maml.parameters() if p.requires_grad)
    num = sum(np.prod(p.shape) for p in trainable)
    print(maml)
    print('Total trainable tensors:', num)

    data_root = 'F:\\ACV_project\\MAML-Pytorch\\miniimagenet\\'
    # batchsz here means total episode number
    mini = MiniImagenet(data_root, mode='train', n_way=args.n_way,
                        k_shot=args.k_spt, k_query=args.k_qry,
                        batchsz=10000, resize=args.imgsz)
    mini_test = MiniImagenet(data_root, mode='test', n_way=args.n_way,
                             k_shot=args.k_spt, k_query=args.k_qry,
                             batchsz=100, resize=args.imgsz)

    ckpt_dir = "./model/"

    for epoch in range(args.epoch // 10000):
        # A fresh loader each epoch; every batch holds args.task_num episodes.
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1,
                        pin_memory=True)

        for step, batch in enumerate(db):
            x_spt, y_spt, x_qry, y_qry = (t.to(device) for t in batch)

            accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                print('step:', step, '\ttraining acc:', accs)

            if step % 500 == 0:  # evaluation
                db_test = DataLoader(mini_test, 1, shuffle=True,
                                     num_workers=1, pin_memory=True)
                accs_all_test = []

                for episode in db_test:
                    # Each loader batch is a single episode; drop the batch dim.
                    s_x, s_y, q_x, q_y = (t.squeeze(0).to(device)
                                          for t in episode)
                    accs_all_test.append(maml.finetunning(s_x, s_y, q_x, q_y))

                # [b, update_step+1] -> mean accuracy per inner update step
                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                print('Test acc:', accs)

                # save checkpoints
                os.makedirs(ckpt_dir, exist_ok=True)
                print('Saving the model as a checkpoint...')
                torch.save({'epoch': epoch, 'Steps': step,
                            'model': maml.state_dict()},
                           os.path.join(ckpt_dir, 'checkpoint.pth'))
コード例 #13
0
def main():
    """Evaluate a (optionally pre-trained) MAML model over 10 rounds of
    fine-tuning episodes on the validation split and report the average
    final-step accuracy.

    Uses module-level ``args`` and the project classes ``Meta`` and
    ``MiniImagenet``.
    """
    torch.manual_seed(222)  # seed CPU RNG so results are deterministic
    torch.cuda.manual_seed_all(222)  # seed every GPU RNG as well
    np.random.seed(222)

    print(args)

    # Four conv->relu->bn->pool blocks, then a linear head.  Input is
    # single-channel (conv2d in_channels == 1); flattened feature size 7040.
    config = [
        ('conv2d', [32, 1, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 1, 0]),
        ('flatten', []),
        ('linear', [args.n_way, 7040])
    ]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    # NOTE(review): the original concatenated "./models/" twice, producing
    # "./models/./models/miniimagenet_maml...".  Kept verbatim so any
    # existing checkpoints are still found — confirm against the trainer.
    ckpt_path = "./models/" + str(
        "./models/miniimagenet_maml" + str(args.n_way) + "way_" + str(args.k_spt) + "shot.pkl")
    if os.path.exists(ckpt_path):
        # BUG FIX: load_state_dict() expects a state dict, not a file path;
        # the checkpoint must first be deserialized with torch.load().
        maml.load_state_dict(torch.load(ckpt_path, map_location=device))
        print("load model success")

    # Log trainable-parameter count for a quick sanity check.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    # NOTE(review): `mini` (the train split) is created but never used here;
    # kept because building it may have side effects in the project class.
    mini = MiniImagenet("./miniimagenet", mode='train', n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batchsz=10000)
    mini_test = MiniImagenet("./miniimagenet", mode='val', n_way=args.n_way, k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batchsz=100)
    test_accuracy = []
    for epoch in range(10):
        # fetch meta_batchsz num of episode each time
        db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
        accs_all_test = []

        for x_spt, y_spt, x_qry, y_qry in db_test:
            # Each loader batch is one episode; drop the batch dimension.
            x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                         x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

            accs, loss_t = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
            accs_all_test.append(accs)

        # [b, update_step+1]: mean accuracy per inner update step
        accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
        print('Test acc:', accs)
        test_accuracy.append(accs[-1])  # keep only the final-step accuracy
    average_accuracy = sum(test_accuracy) / len(test_accuracy)
    print("average accuracy:{}".format(average_accuracy))
コード例 #14
0
def main():
    """Meta-train a MAML model on mini-ImageNet with periodic 600-episode
    evaluation; hyper-parameters come from the command line.

    Uses the project classes ``MAML`` and ``MiniImagenet`` and the helper
    ``mean_confidence_interval``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-n', help='n way', default=5)
    parser.add_argument('-k', help='k shot', default=1)
    parser.add_argument('-b', help='meta batch size', default=4)
    parser.add_argument('-l', help='meta learning rate', default=1e-3)
    args = parser.parse_args()
    n_way, k_shot = int(args.n), int(args.k)
    meta_batchsz = int(args.b)
    meta_lr = float(args.l)

    # Fixed (non-CLI) hyper-parameters.
    train_lr = 1e-2
    k_query = 15
    imgsz = 84
    K = 5  # update steps
    mdl_file = 'ckpt/miniimagenet%d%d.mdl' % (n_way, k_shot)
    print('mini-imagenet: %d-way %d-shot meta-lr:%f, train-lr:%f K-steps:%d' %
          (n_way, k_shot, meta_lr, train_lr, K))

    device = torch.device('cuda:2')
    net = MAML(n_way, k_shot, k_query, meta_batchsz, K, meta_lr,
               train_lr).to(device)
    print(net)

    for epoch in range(1000):
        # batchsz here means total episode number
        mini = MiniImagenet('/data/miniimagenet/', mode='train', n_way=n_way,
                            k_shot=k_shot, k_query=k_query, batchsz=10000,
                            resize=imgsz)
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, meta_batchsz, shuffle=True,
                        num_workers=meta_batchsz, pin_memory=True)

        for step, batch in enumerate(db):
            support_x, support_y, query_x, query_y = (
                t.to(device) for t in batch)

            accs = net(support_x, support_y, query_x, query_y, training=True)

            if step % 50 == 0:
                print("epoch: {}, step: {}, {}accuracy: {}".format(
                    epoch, step, '\t', accs))

            if step % 500 == 0 and step != 0:  # evaluation
                # test for 600 episodes
                mini_test = MiniImagenet('/data/miniimagenet/', mode='test',
                                         n_way=n_way, k_shot=k_shot,
                                         k_query=k_query, batchsz=600,
                                         resize=imgsz)
                db_test = DataLoader(mini_test, meta_batchsz, shuffle=True,
                                     num_workers=meta_batchsz,
                                     pin_memory=True)
                accs_all_test = []
                for test_batch in db_test:
                    sx, sy, qx, qy = (t.to(device) for t in test_batch)
                    accs_all_test.append(
                        net(sx, sy, qx, qy, training=False))
                # [600, K+1]
                accs_all_test = np.array(accs_all_test)
                # [600, K+1] => [K+1]: mean accuracy per update step
                means = accs_all_test.mean(axis=0)
                # compute variance for last step K
                m, h = mean_confidence_interval(accs_all_test[:, K])
                print('>>Test:\t', means, 'variance[K]: %.4f' % h, '<<')
コード例 #15
0
def main():
    """Meta-train MAML on the './flower/' dataset (mini-ImageNet format),
    evaluating every 500 steps and, at the very last training step, appending
    the results to a text file and pickling the model with cloudpickle.

    Uses module-level ``args`` plus the project classes ``Meta`` and
    ``MiniImagenet``.
    """

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # Four conv->relu->bn->pool blocks, then flatten + n_way linear head.
    config = [
        ('conv2d', [32, 3, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 1, 0]),
        ('flatten', []),
        ('linear', [args.n_way, 32 * 5 * 5])
    ]
    cuda = 'cuda:' + str(args.gpu_index)
    device = torch.device(cuda)
    maml = Meta(args, config).to(device)

    # Log trainable-parameter count for a quick sanity check.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    mini = MiniImagenet('./flower/', mode='train', n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batchsz=10000, resize=args.imgsz)
    mini_test = MiniImagenet('./flower/', mode='test', n_way=args.n_way, k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batchsz=100, resize=args.imgsz)
    accs_list_tr = []
    accs_list_ts = []
    n_epochs = args.epoch // 10000
    last_step = 10000 // args.task_num - 1  # final batch index within an epoch
    for epoch in range(n_epochs):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)

        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

            accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                print('step:', step, '\ttraining acc:', accs)
                accs_list_tr.append(accs)

            # IDIOM FIX: use logical `and` instead of bitwise `&` on booleans,
            # and compute the last epoch index directly instead of indexing
            # `range(...)[-1]`.  The condition itself is unchanged.
            is_final_step = (epoch == n_epochs - 1) and (step == last_step)
            if step % 500 == 0 or is_final_step:  # evaluation
                db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
                accs_all_test = []

                for x_spt, y_spt, x_qry, y_qry in db_test:
                    # Each loader batch is one episode; drop the batch dim.
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

                    accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                    accs_all_test.append(accs)

                # [b, update_step+1]: mean accuracy per inner update step
                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                print('step:', step, '\ttest acc:', accs)
                accs_list_ts.append(accs)
                if is_final_step:
                    # Record final accuracies and serialize the whole model
                    # (cloudpickle handles closures the stdlib pickle cannot).
                    with open('data/result_natural(acc).txt', mode='a') as f:
                        f.write(str(accs) + str(args.task_num) + '\n')
                    with open('model.pkl', 'wb') as f:
                        cloudpickle.dump(maml, f)
コード例 #16
0
def main():
    """Meta-train MAML on mini-ImageNet; validate every 1000 steps and, when
    validation improves (or every 5000 steps), checkpoint the model and run a
    full test evaluation, logging both to text files under ckpt/<exp>/.

    Uses module-level ``args`` plus the project classes ``Meta`` and
    ``MiniImagenet`` and the helper ``cal_conf``.
    """

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # Four conv->relu->bn->pool blocks, then flatten + n_way linear head.
    config = [('conv2d', [32, 3, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 1, 0]),
              ('flatten', []), ('linear', [args.n_way, 32 * 5 * 5])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    # Log trainable-parameter count for a quick sanity check.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    mini = MiniImagenet(
        '/mnt/aitrics_ext/ext01/yanbin/MAML-Pytorch/data/miniImagenet/',
        mode='train',
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        batchsz=10000,
        resize=args.imgsz)
    mini_val = MiniImagenet(
        '/mnt/aitrics_ext/ext01/yanbin/MAML-Pytorch/data/miniImagenet/',
        mode='val',
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        batchsz=600,
        resize=args.imgsz)
    mini_test = MiniImagenet(
        '/mnt/aitrics_ext/ext01/yanbin/MAML-Pytorch/data/miniImagenet/',
        mode='test',
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        batchsz=600,
        resize=args.imgsz)

    best_acc = 0.0
    # BUG FIX: os.mkdir raised FileNotFoundError when the parent 'ckpt/'
    # directory was missing; makedirs with exist_ok=True handles both the
    # missing parent and an already-existing experiment dir.
    os.makedirs('ckpt/{}'.format(args.exp), exist_ok=True)
    for epoch in range(args.epoch // 10000):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini,
                        args.task_num,
                        shuffle=True,
                        num_workers=1,
                        pin_memory=True)

        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):

            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(
                device), x_qry.to(device), y_qry.to(device)

            accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 500 == 0:
                print('step:', step, '\ttraining acc:', accs)
            if step % 1000 == 0:  # evaluation
                db_val = DataLoader(mini_val,
                                    1,
                                    shuffle=True,
                                    num_workers=1,
                                    pin_memory=True)
                accs_all_val = []
                for x_spt, y_spt, x_qry, y_qry in db_val:
                    # Each loader batch is one episode; drop the batch dim.
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

                    accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                    accs_all_val.append(accs)
                mean, std, ci95 = cal_conf(np.array(accs_all_val))
                print('Val acc:{}, std:{}. ci95:{}'.format(
                    mean[-1], std[-1], ci95[-1]))
                improved = mean[-1] > best_acc
                if improved or step % 5000 == 0:
                    # BUG FIX: best_acc is only raised on genuine improvement;
                    # previously the periodic (step % 5000) branch overwrote
                    # it with a possibly lower value.  The checkpoint name and
                    # log use the current validation accuracy, as before.
                    if improved:
                        best_acc = mean[-1]
                    torch.save(
                        maml.state_dict(),
                        'ckpt/{}/model_e{}s{}_{:.4f}.pkl'.format(
                            args.exp, epoch, step, mean[-1]))
                    with open('ckpt/' + args.exp + '/val.txt', 'a') as f:
                        print(
                            'val epoch {}, step {}: acc_val:{:.4f}, ci95:{:.4f}'
                            .format(epoch, step, mean[-1], ci95[-1]),
                            file=f)

                    ## Test
                    db_test = DataLoader(mini_test,
                                         1,
                                         shuffle=True,
                                         num_workers=1,
                                         pin_memory=True)
                    accs_all_test = []
                    for x_spt, y_spt, x_qry, y_qry in db_test:
                        x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                     x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

                        accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                        accs_all_test.append(accs)
                    mean, std, ci95 = cal_conf(np.array(accs_all_test))
                    print('Test acc:{}, std:{}, ci95:{}'.format(
                        mean[-1], std[-1], ci95[-1]))
                    with open('ckpt/' + args.exp + '/test.txt', 'a') as f:
                        print(
                            'test epoch {}, step {}: acc_test:{:.4f}, ci95:{:.4f}'
                            .format(epoch, step, mean[-1], ci95[-1]),
                            file=f)
コード例 #17
0
def main():
    """Evaluate active-learning query-selection strategies on top of a frozen
    pre-trained MAML model.

    For each test episode, a ``QuerySelector`` repeatedly picks support
    examples; a fresh copy of the model is fine-tuned on the growing query
    set and per-step query accuracies are recorded over all candidate
    orderings.  Uses module-level ``args`` plus the project classes ``Meta``,
    ``MiniImagenet`` and ``QuerySelector``.
    """

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    #print(args)

    # The AL variant adds one extra output class to the linear head.
    if args.train_al == 1:
        last_layer = ('linear', [args.n_way + 1, 32 * 5 * 5])
    else:
        last_layer = ('linear', [args.n_way, 32 * 5 * 5])

    config = [('conv2d', [32, 3, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 0]),
              ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 1, 0]),
              ('flatten', []), last_layer]

    # Scan checkpoints saved at 500-epoch intervals and keep the newest one.
    # NOTE(review): if no epoch-0 checkpoint exists, `valid_epoch` is never
    # assigned and the final path rebuild raises NameError — confirm a
    # checkpoint is present before running.
    last_epoch = 0
    if args.train_al == 1:
        suffix = "_al"
        #suffix = "-trained_al"
    else:
        suffix = "_og"
    save_path = os.getcwd() + '/data/model_batchsz' + str(
        args.k_model) + '_stepsz' + str(
            args.update_lr) + '_epoch' + str(last_epoch) + suffix + '.pt'
    while os.path.isfile(save_path):
        valid_epoch = last_epoch
        last_epoch += 500
        save_path = os.getcwd() + '/data/model_batchsz' + str(
            args.k_model) + '_stepsz' + str(
                args.update_lr) + '_epoch' + str(last_epoch) + suffix + '.pt'
    # Rebuild the path of the last checkpoint that actually exists.
    save_path = os.getcwd() + '/data/model_batchsz' + str(
        args.k_model) + '_stepsz' + str(
            args.update_lr) + '_epoch' + str(valid_epoch) + suffix + '.pt'

    device = torch.device('cuda')
    # Load the pre-trained model and freeze it; it is only ever used through
    # deepcopy below, so `mod` itself stays untouched.
    mod = Meta(args, config).to(device)
    mod.load_state_dict(torch.load(save_path))
    mod.eval()

    #tmp = filter(lambda x: x.requires_grad, maml.parameters())
    #num = sum(map(lambda x: np.prod(x.shape), tmp))
    #print(maml)
    #print('Total trainable tensors:', num)

    # batchsz here means total episode number
    if args.merge_spt_qry == 1:
        # Support set is replaced by the query set in the loop below.
        mini_test = MiniImagenet('/home/tesca/data/miniimagenet/',
                                 mode='test',
                                 n_way=args.n_way,
                                 k_shot=1,
                                 k_query=args.k_qry,
                                 batchsz=10,
                                 resize=args.imgsz)
    else:
        mini_test = MiniImagenet('/home/tesca/data/miniimagenet/',
                                 mode='test',
                                 n_way=args.n_way,
                                 k_shot=args.k_spt,
                                 k_query=args.k_qry,
                                 batchsz=10,
                                 resize=args.imgsz)
    db_test = DataLoader(mini_test,
                         1,
                         shuffle=True,
                         num_workers=1,
                         pin_memory=True)
    accs_all_test = []
    total_accs = []
    best_accs = []

    it = 0
    for x_spt, y_spt, x_qry, y_qry in db_test:
        if args.merge_spt_qry == 1:
            x_spt = x_qry
            y_spt = y_qry
        # Progress indicator (overwrites the same console line).
        sys.stdout.write("\rTest %i" % it)
        sys.stdout.flush()
        it += 1
        # CPU copies are kept for numpy slicing; *_pt copies live on the GPU.
        x_spt, y_spt = x_spt.squeeze(0), y_spt.squeeze(0)
        x_qry_pt, y_qry_pt = x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(
            device)
        x_spt_pt, y_spt_pt = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(
            device)
        #x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
        #                             x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
        qs = QuerySelector()
        finished = False
        all_queries, all_orderings = [], []
        # Evaluate every candidate ordering the selector produces.
        while not finished:
            # Fresh copy so fine-tuning never contaminates the base model.
            maml = deepcopy(mod)
            queries = None  #[[], []]
            avail_cands = list(range(x_spt.shape[0]))
            test_accs = []
            q_idx = []
            w = None
            # Greedily add one support example at a time and re-finetune.
            for i in range(x_spt.shape[0]):
                q = qs.query(args,
                             maml.net,
                             w=w,
                             cands=x_spt_pt,
                             avail_cands=avail_cands,
                             targets=x_qry_pt,
                             k=1,
                             method=args.al_method,
                             classification=True)
                q_idx.append(q)
                avail_cands.remove(q)
                # Accumulate the selected (x, y) pairs as numpy arrays.
                if queries is None:
                    queries = [
                        np.array(x_spt[q].unsqueeze(0)),
                        np.array(y_spt[q].unsqueeze(0))
                    ]
                else:
                    queries = [
                        np.concatenate(
                            [queries[0],
                             np.array(x_spt[q].unsqueeze(0))]),
                        np.concatenate(
                            [queries[1],
                             np.array(y_spt[q].unsqueeze(0))])
                    ]
                xs, ys = torch.from_numpy(
                    queries[0]).to(device), torch.from_numpy(
                        queries[1]).to(device)
                #accs,w = maml.finetunning(x_spt_pt, y_spt_pt, x_qry_pt, y_qry_pt)
                accs, w = maml.finetunning(xs, ys, x_qry_pt, y_qry_pt)
                accs_all_test.append(accs)
                # Record the pre-update accuracy once, then each final acc.
                if len(test_accs) == 0:
                    test_accs.append(accs[0])
                test_accs.append(accs[-1])
            all_orderings.append(test_accs)
            all_queries.append(q_idx)
            del maml
            finished = not qs.next_order()
        # Mean accuracy curve across orderings for this episode.
        total_accs.append(np.mean(np.array(all_orderings), axis=0))
        # Query order that achieved the best summed accuracy.
        best_accs.append(all_queries[np.argmax(
            np.sum(np.array(all_orderings), axis=1))])
        # NOTE(review): `oq` is computed twice and never used afterwards —
        # looks like leftover debugging code.
        oq = [
            all_queries[i]
            for i in np.argsort(np.sum(np.array(all_orderings), axis=1))
        ]
        #pdb.set_trace()
        oq = all_orderings[np.argsort(np.sum(np.array(all_orderings),
                                             axis=1))[-1]]
        #total_accs.append(oq)

    # [b, update_step+1]
    accs = np.array(total_accs).mean(axis=0).astype(np.float16)
    #accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
    print('Test acc:', accs)
コード例 #18
0
ファイル: main.py プロジェクト: purbayankar/Reptile-Pytorch-1
def main():
    """Meta-learning training loop on Omniglot or mini-ImageNet.

    Trains a ``MetaLearner`` episode by episode and logs train/test accuracy
    to TensorBoard.  Uses the project classes ``OmniglotNShot``,
    ``MiniImagenet``, ``MetaLearner`` and ``Naive``.
    """
    meta_batchsz = 32
    n_way = 20
    k_shot = 1
    k_query = k_shot
    meta_lr = 1e-3
    num_updates = 5
    dataset = 'omniglot'

    if dataset == 'omniglot':
        imgsz = 28
        db = OmniglotNShot('dataset', batchsz=meta_batchsz, n_way=n_way,
                           k_shot=k_shot, k_query=k_query, imgsz=imgsz)
    elif dataset == 'mini-imagenet':
        imgsz = 84
        # The dataset loaders differ: Omniglot uses one helper with
        # get_batch(mode), mini-imagenet needs separate train/test loaders.
        mini = MiniImagenet('../mini-imagenet/', mode='train', n_way=n_way,
                            k_shot=k_shot, k_query=k_query,
                            batchsz=10000, resize=imgsz)
        db = DataLoader(mini, meta_batchsz, shuffle=True, num_workers=4,
                        pin_memory=True)
        mini_test = MiniImagenet('../mini-imagenet/', mode='test',
                                 n_way=n_way, k_shot=k_shot, k_query=k_query,
                                 batchsz=1000, resize=imgsz)
        db_test = DataLoader(mini_test, meta_batchsz, shuffle=True,
                             num_workers=2, pin_memory=True)
    else:
        raise NotImplementedError

    meta = MetaLearner(Naive, (n_way, imgsz), n_way=n_way, k_shot=k_shot,
                       meta_batchsz=meta_batchsz, beta=meta_lr,
                       num_updates=num_updates).cuda()

    tb = SummaryWriter('runs')

    # main loop
    for episode_num in range(1500):

        # 1. train
        if dataset == 'omniglot':
            support_x, support_y, query_x, query_y = db.get_batch('test')
            # Batches arrive as numpy arrays; moved channel-first and the
            # single channel is tiled to 3 (presumably (b, set, H, W, C) —
            # TODO confirm against OmniglotNShot).
            support_x = Variable(torch.from_numpy(support_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
            query_x = Variable(torch.from_numpy(query_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
            support_y = Variable(torch.from_numpy(support_y).long()).cuda()
            query_y = Variable(torch.from_numpy(query_y).long()).cuda()
        elif dataset == 'mini-imagenet':
            try:
                # BUG FIX: Python 3 iterators have no .next() method; use the
                # builtin next() instead of iter(db).next().
                batch_test = next(iter(db))
            except StopIteration:
                # Loader exhausted: rebuild it and fetch a fresh batch.
                # BUG FIX: the original left batch_test stale (or undefined
                # on the first episode) after rebuilding the loader.
                mini = MiniImagenet('../mini-imagenet/', mode='train', n_way=n_way, k_shot=k_shot, k_query=k_query,
                                    batchsz=10000, resize=imgsz)
                db = DataLoader(mini, meta_batchsz, shuffle=True, num_workers=4, pin_memory=True)
                batch_test = next(iter(db))

            support_x = Variable(batch_test[0]).cuda()
            support_y = Variable(batch_test[1]).cuda()
            query_x = Variable(batch_test[2]).cuda()
            query_y = Variable(batch_test[3]).cuda()

        # backprop has been embeded in forward func.
        # BUG FIX: the original used assignment `=` in the condition
        # (`episode_num % 100 = 0`), which is a syntax error; must be `==`.
        if episode_num % 100 == 0:
            meta.prv_angle = 0
        accs = meta(support_x, support_y, query_x, query_y)
        train_acc = np.array(accs).mean()

        # 2. test
        if episode_num % 30 == 0:
            test_accs = []
            for i in range(min(episode_num // 5000 + 3, 10)):  # get average acc.
                if dataset == 'omniglot':
                    support_x, support_y, query_x, query_y = db.get_batch('test')
                    support_x = Variable(torch.from_numpy(support_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
                    query_x = Variable(torch.from_numpy(query_x).float().transpose(2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1)).cuda()
                    support_y = Variable(torch.from_numpy(support_y).long()).cuda()
                    query_y = Variable(torch.from_numpy(query_y).long()).cuda()
                elif dataset == 'mini-imagenet':
                    try:
                        batch_test = next(iter(db_test))
                    except StopIteration:
                        mini_test = MiniImagenet('../mini-imagenet/', mode='test', n_way=n_way, k_shot=k_shot,
                                                 k_query=k_query,
                                                 batchsz=1000, resize=imgsz)
                        db_test = DataLoader(mini_test, meta_batchsz, shuffle=True, num_workers=2, pin_memory=True)
                        batch_test = next(iter(db_test))
                    support_x = Variable(batch_test[0]).cuda()
                    support_y = Variable(batch_test[1]).cuda()
                    query_x = Variable(batch_test[2]).cuda()
                    query_y = Variable(batch_test[3]).cuda()

                # get accuracy
                test_acc = meta.pred(support_x, support_y, query_x, query_y)
                test_accs.append(test_acc)

            test_acc = np.array(test_accs).mean()
            print('episode:', episode_num, '\tfinetune acc:%.6f' % train_acc, '\t\ttest acc:%.6f' % test_acc)
            tb.add_scalar('test-acc', test_acc)
            tb.add_scalar('finetune-acc', train_acc)
コード例 #19
0
def main():
    """Meta-train a MAML model on miniImagenet, track train/test loss and
    accuracy curves, checkpoint on best test accuracy, and save plots.

    Depends on module-level ``args`` (n_way, k_spt, k_qry, task_num, epoch)
    and the project classes ``Meta`` and ``MiniImagenet``. Plots are written
    to ./drawing/ and the best model to ./models/.
    """
    torch.manual_seed(222)  # seed the CPU RNG so results are deterministic
    torch.cuda.manual_seed_all(222)  # seed all GPU RNGs so results are deterministic
    np.random.seed(222)

    print(args)

    # Backbone: 4 x (conv -> relu -> bn -> max-pool) + linear classifier.
    # NOTE(review): the first conv expects 1 input channel and the linear
    # layer expects 7040 input features -- confirm against the actual input
    # channels/resolution of this MiniImagenet variant.
    config = [
        ('conv2d', [32, 1, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 1, 0]),
        ('flatten', []),
        ('linear', [args.n_way, 7040])
    ]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    # Count trainable tensors for a quick sanity log.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    mini = MiniImagenet("./miniimagenet", mode='train', n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batchsz=10000)
    mini_test = MiniImagenet("./miniimagenet", mode='test', n_way=args.n_way, k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batchsz=100)
    last_accuracy = 0  # best mean test accuracy seen so far (gates checkpointing)
    plt_train_loss = []
    plt_train_acc = []

    plt_test_loss = []
    plt_test_acc =[]
    for epoch in range(args.epoch // 10000):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)

        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):

            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

            accs, loss_q = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                # detach the meta loss from the autograd graph and log it on CPU
                d = loss_q.cpu()
                dd = d.detach().numpy()
                plt_train_loss.append(dd)
                plt_train_acc.append(accs[-1])
                print('step:', step, '\ttraining acc:', accs)

            if step % 50 == 0:  # evaluation
                db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
                accs_all_test = []
                loss_all_test = []

                for x_spt, y_spt, x_qry, y_qry in db_test:
                    # each test batch is a single episode; drop the batch dim
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

                    accs, loss_test= maml.finetunning(x_spt, y_spt, x_qry, y_qry)

                    loss_all_test.append(loss_test)
                    accs_all_test.append(accs)

                # [b, update_step+1]
                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                plt_test_acc.append(accs[-1])
                avg_loss = np.mean(np.array(loss_all_test))
                plt_test_loss.append(avg_loss)

                print('Test acc:', accs)
                # mean over all inner-update steps, not just the final one
                test_accuracy = np.mean(np.array(accs))
                if test_accuracy > last_accuracy:
                    # save networks
                    torch.save(maml.state_dict(), str(
                        "./models/miniimagenet_maml" + str(args.n_way) + "way_" + str(
                            args.k_spt) + "shot.pkl"))
                    last_accuracy = test_accuracy
    plt.figure()
    plt.title("testing info")
    plt.xlabel("episode")
    plt.ylabel("Acc/loss")
    plt.plot(plt_test_loss, label='Loss')
    plt.plot(plt_test_acc, label='Acc')
    plt.legend(loc='upper right')
    plt.savefig('./drawing/test.png')
    plt.show()

    plt.figure()
    plt.title("training info")
    plt.xlabel("episode")
    plt.ylabel("Acc/loss")
    plt.plot(plt_train_loss, label='Loss')
    plt.plot(plt_train_acc, label='Acc')
    plt.legend(loc='upper right')
    plt.savefig('./drawing/train.png')
    plt.show()
コード例 #20
0
        net.load_state_dict(torch.load(mdl_file))

    model_parameters = filter(lambda p: p.requires_grad, net.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print('total params:', params)

    optimizer = optim.Adam(net.parameters(), lr=1e-3)
    tb = SummaryWriter('runs', str(datetime.now()))

    best_accuracy = 0
    for epoch in range(1000):

        mini = MiniImagenet('../mini-imagenet/',
                            mode='train',
                            n_way=n_way,
                            k_shot=k_shot,
                            k_query=k_query,
                            batchsz=10000,
                            resize=224)
        db = DataLoader(mini,
                        batchsz,
                        shuffle=True,
                        num_workers=8,
                        pin_memory=True)
        mini_val = MiniImagenet('../mini-imagenet/',
                                mode='val',
                                n_way=n_way,
                                k_shot=k_shot,
                                k_query=k_query,
                                batchsz=200,
                                resize=224)
コード例 #21
0
def main():
    """Train the Naive5 baseline on miniImagenet, running a 600-episode
    evaluation every 300 steps and logging running loss/accuracy every 20.

    Depends on module-level ``Naive5``, ``MiniImagenet``, ``evaluation``,
    ``write2file`` and the globals they share. Expects at least one CUDA
    device (the model is wrapped in DataParallel over GPUs 0-2).
    """
    argparser = argparse.ArgumentParser()
    # required=True gives a clear argparse error instead of the TypeError
    # that int(None) raised when a flag was omitted.
    argparser.add_argument("-n", help="n way", required=True)
    argparser.add_argument("-k", help="k shot", required=True)
    argparser.add_argument("-b", help="batch size", required=True)
    argparser.add_argument("-l", help="learning rate", default=1e-3)
    args = argparser.parse_args()
    n_way = int(args.n)
    k_shot = int(args.k)
    batchsz = int(args.b)
    lr = float(args.l)

    k_query = 1
    imgsz = 224
    # Accuracy threshold at which the full version of an episode is tested.
    threshold = (
        0.699 if k_shot == 5 else 0.584
    )  # threshold for when to test full version of episode
    mdl_file = "ckpt/naive5_3x3%d%d.mdl" % (n_way, k_shot)
    print(
        "mini-imagnet: %d-way %d-shot lr:%f, threshold:%f"
        % (n_way, k_shot, lr, threshold)
    )

    global global_buff
    if os.path.exists("mini%d%d.pkl" % (n_way, k_shot)):
        # NOTE(review): pickle.load is only safe on a trusted local cache file.
        global_buff = pickle.load(open("mini%d%d.pkl" % (n_way, k_shot), "rb"))
        print("load pkl buff:", len(global_buff))

    net = nn.DataParallel(Naive5(n_way, k_shot, imgsz), device_ids=[0, 1, 2]).cuda()
    print(net)

    if os.path.exists(mdl_file):
        print("load from checkpoint ...", mdl_file)
        net.load_state_dict(torch.load(mdl_file))
    else:
        print("training from scratch.")

    # whole parameters number
    model_parameters = filter(lambda p: p.requires_grad, net.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print("Total params:", params)

    # build optimizer and lr scheduler; LR halves when eval accuracy plateaus
    optimizer = optim.Adam(net.parameters(), lr=lr)
    scheduler = lr_scheduler.ReduceLROnPlateau(
        optimizer, "max", factor=0.5, patience=25, verbose=True
    )

    for epoch in range(1000):
        mini = MiniImagenet(
            "../mini-imagenet/",
            mode="train",
            n_way=n_way,
            k_shot=k_shot,
            k_query=k_query,
            batchsz=10000,
            resize=imgsz,
        )
        db = DataLoader(mini, batchsz, shuffle=True, num_workers=8, pin_memory=True)
        total_train_loss = 0
        total_train_correct = 0
        total_train_num = 0

        for step, batch in enumerate(db):
            # 1. test (evaluation also handles checkpointing via mdl_file)
            if step % 300 == 0:
                accuracy, sem = evaluation(
                    net, batchsz, n_way, k_shot, imgsz, 600, threshold, mdl_file
                )
                scheduler.step(accuracy)

            # 2. train
            support_x = Variable(batch[0]).cuda()
            support_y = Variable(batch[1]).cuda()
            query_x = Variable(batch[2]).cuda()
            query_y = Variable(batch[3]).cuda()

            net.train()
            loss, pred, correct = net(support_x, support_y, query_x, query_y)
            loss = loss.sum() / support_x.size(0)  # multi-gpu, divide by total batchsz
            # .item() replaces Tensor.data[0], which was removed in PyTorch >= 0.4
            total_train_loss += loss.item()
            total_train_correct += correct.item()
            total_train_num += support_y.size(0) * n_way  # k_query = 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # 3. print running statistics every 20 steps
            if step % 20 == 0 and step != 0:
                acc = total_train_correct / total_train_num
                total_train_correct = 0
                total_train_num = 0

                print(
                    "%d-way %d-shot %d batch> epoch:%d step:%d, loss:%.4f, train acc:%.4f"
                    % (n_way, k_shot, batchsz, epoch, step, total_train_loss, acc)
                )
                total_train_loss = 0

                global global_train_loss_buff, global_train_acc_buff
                global_train_loss_buff = loss.item() / (n_way * k_shot)
                global_train_acc_buff = acc
                write2file(n_way, k_shot)
コード例 #22
0
def main():
    """Meta-train MAML on miniImagenet, printing training accuracy every
    30 steps and running a full test-split evaluation every 500 steps.

    Depends on module-level ``args`` and the project classes ``Meta`` and
    ``MiniImagenet``.
    """
    # Fix all RNG seeds for reproducibility.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # Backbone: 4 x (conv -> relu -> bn -> max-pool), then a linear head.
    # The first conv takes 3 input channels; the last pool uses stride 1.
    config = []
    for in_ch, pool_stride in ((3, 2), (32, 2), (32, 2), (32, 1)):
        config += [('conv2d', [32, in_ch, 3, 3, 1, 0]),
                   ('relu', [True]),
                   ('bn', [32]),
                   ('max_pool2d', [2, pool_stride, 0])]
    config += [('flatten', []), ('linear', [args.n_way, 32 * 5 * 5])]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    maml = Meta(args, config).to(device)

    # Report the number of trainable tensors.
    trainable = [p for p in maml.parameters() if p.requires_grad]
    num = sum(np.prod(p.shape) for p in trainable)
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number
    mini = MiniImagenet('./data/', mode='train', n_way=args.n_way,
                        k_shot=args.k_spt, k_query=args.k_qry,
                        batchsz=10000, resize=args.img_sz)
    mini_test = MiniImagenet('./data/', mode='test', n_way=args.n_way,
                             k_shot=args.k_spt, k_query=args.k_qry,
                             batchsz=100, resize=args.img_sz)

    for epoch in range(args.epoch // 10000):
        # A fresh loader yields args.task_num episodes per step.
        db = DataLoader(mini, args.task_num, shuffle=True,
                        num_workers=1, pin_memory=True)

        for step, episode in enumerate(db):
            x_spt, y_spt, x_qry, y_qry = (t.to(device) for t in episode)

            accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                print('step:', step, '\ttraining acc:', accs)

            if step % 500 == 0:  # evaluation
                db_test = DataLoader(mini_test, 1, shuffle=True,
                                     num_workers=1, pin_memory=True)
                accs_all_test = []

                for task in db_test:
                    # One episode per batch: drop the leading batch dim.
                    x_spt, y_spt, x_qry, y_qry = (
                        t.squeeze(0).to(device) for t in task)
                    accs_all_test.append(
                        maml.finetunning(x_spt, y_spt, x_qry, y_qry))

                # [b, update_step+1] -> mean accuracy per inner-update step
                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                print('Test acc:', accs)
コード例 #23
0
def main():
    """Configure logging/seeds/saver, build a Meta model (reporting FLOPs and
    parameter counts), create miniImagenet train/val loaders, optionally load
    a pretrained checkpoint (and evaluate-only if requested), then run the
    meta-training loop via ``meta_train``.

    Depends on module-level ``args`` and the project helpers ``Saver``,
    ``TensorboardSummary``, ``Meta``, ``MiniImagenet``, ``meta_train`` and
    ``meta_test``. Requires a CUDA device.
    """
    saver = Saver(args)
    # set log
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p',
                        filename=os.path.join(saver.experiment_dir, 'log.txt'),
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger().addHandler(console)

    # this script requires a GPU; bail out early otherwise
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # set seed
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    cudnn.enabled = True

    # snapshot the experiment's source files and config for reproducibility
    saver.create_exp_dir(scripts_to_save=glob.glob('*.py') +
                         glob.glob('*.sh') + glob.glob('*.yml'))
    saver.save_experiment_config()
    summary = TensorboardSummary(saver.experiment_dir)
    writer = summary.create_summary()

    logging.info(args)

    device = torch.device('cuda')
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device)
    ''' Compute FLOPs and Params '''
    # a throwaway Meta instance is built solely for complexity analysis
    maml = Meta(args, criterion)
    flops, params = get_model_complexity_info(maml.model, (3, 84, 84),
                                              as_strings=False,
                                              print_per_layer_stat=True,
                                              verbose=True)
    logging.info('FLOPs: {} MMac Params: {}'.format(flops / 10**6, params))

    # the model actually trained lives on the GPU
    maml = Meta(args, criterion).to(device)
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    #logging.info(maml)
    logging.info('Total trainable tensors: {}'.format(num))

    # batch_size here means total episode number
    mini = MiniImagenet(args.data_path,
                        mode='train',
                        n_way=args.n_way,
                        k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batch_size=args.batch_size,
                        resize=args.img_size)
    # NOTE: evaluation uses the 'val' split here, not 'test'
    mini_test = MiniImagenet(args.data_path,
                             mode='val',
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batch_size=args.test_batch_size,
                             resize=args.img_size)
    train_loader = DataLoader(mini,
                              args.meta_batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              pin_memory=True)
    test_loader = DataLoader(mini_test,
                             args.meta_test_batch_size,
                             shuffle=True,
                             num_workers=args.num_workers,
                             pin_memory=True)

    # load pretrained model and inference
    if args.pretrained_model:
        checkpoint = torch.load(args.pretrained_model)
        # NOTE(review): the isinstance check inspects maml.model, but the
        # DataParallel branch calls maml.module.load_state_dict -- these look
        # inconsistent (maml itself is never wrapped in DataParallel here).
        # Confirm which object actually holds the wrapped module.
        if isinstance(maml.model, torch.nn.DataParallel):
            maml.module.load_state_dict(checkpoint['state_dict'])
        else:
            maml.load_state_dict(checkpoint['state_dict'])

        if args.evaluate:
            test_accs = meta_test(test_loader, maml, device,
                                  checkpoint['epoch'])
            logging.info('[Epoch: {}]\t Test acc: {}'.format(
                checkpoint['epoch'], test_accs))
            return

    # Start training
    for epoch in range(args.epoch):
        # fetch batch_size num of episode each time
        logging.info('--------- Epoch: {} ----------'.format(epoch))

        train_accs = meta_train(train_loader, maml, device, epoch, writer,
                                test_loader, saver)
        logging.info('[Epoch: {}]\t Train acc: {}'.format(epoch, train_accs))
コード例 #24
0
def main():
    """Train a MAML network on miniImagenet; every 1000 steps run a
    600-episode test evaluation and print the per-update-step accuracies.

    Depends on the project classes ``MAML`` and ``MiniImagenet`` and a CUDA
    device at index 0.
    """
    argparser = argparse.ArgumentParser()
    argparser.add_argument('-n', help='n way', default=5)
    argparser.add_argument('-k', help='k shot', default=1)
    argparser.add_argument('-b', help='batch size', default=4)
    argparser.add_argument('-l', help='meta learning rate', default=1e-3)
    args = argparser.parse_args()
    n_way, k_shot = int(args.n), int(args.k)
    meta_batchsz = int(args.b)
    meta_lr = float(args.l)
    train_lr = 1e-2  # inner-loop (task-level) learning rate

    k_query = 15
    imgsz = 84
    mdl_file = 'ckpt/miniimagenet%d%d.mdl' % (n_way, k_shot)
    print('mini-imagnet: %d-way %d-shot meta-lr:%f, train-lr:%f' %
          (n_way, k_shot, meta_lr, train_lr))

    device = torch.device('cuda:0')
    # 5 is the number of inner-loop update steps.
    net = MAML(n_way, k_shot, k_query, meta_batchsz, 5, meta_lr, train_lr,
               device)
    print(net)

    for epoch in range(1000):
        # batchsz here means total episode number
        mini = MiniImagenet('/hdd1/liangqu/datasets/miniimagenet/',
                            mode='train', n_way=n_way, k_shot=k_shot,
                            k_query=k_query, batchsz=10000, resize=imgsz)
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, meta_batchsz, shuffle=True,
                        num_workers=4, pin_memory=True)

        for step, episode in enumerate(db):

            # 2. train: move the episode tensors onto the GPU
            support_x, support_y, query_x, query_y = (
                t.to(device) for t in episode)

            accs = net(support_x, support_y, query_x, query_y, training=True)

            if step % 50 == 0:
                print(epoch, step, '\t', accs)

            if step % 1000 == 0 and step != 0:  # batchsz here means total episode number
                mini_test = MiniImagenet('/hdd1/liangqu/datasets/miniimagenet/',
                                         mode='test', n_way=n_way,
                                         k_shot=k_shot, k_query=k_query,
                                         batchsz=600, resize=imgsz)
                # fetch meta_batchsz num of episode each time
                db_test = DataLoader(mini_test, meta_batchsz, shuffle=True,
                                     num_workers=4, pin_memory=True)
                accs_all_test = []
                for test_episode in db_test:
                    support_x, support_y, query_x, query_y = (
                        t.to(device) for t in test_episode)

                    # NOTE(review): training=True is kept from the original --
                    # confirm whether evaluation should pass training=False.
                    accs_all_test.append(
                        net(support_x, support_y, query_x, query_y,
                            training=True))

                # [600, K+1] => [K+1]: mean accuracy per inner-update step
                print('>>Test:\t', np.array(accs_all_test).mean(axis=0), '<<')
コード例 #25
0
def main():
    """Meta-train a Meta-SGD model on miniImagenet, logging metrics to
    Weights & Biases and results.csv, with validation-based early stopping
    ("pruning") when validation accuracy stops improving.

    Depends on module-level ``args`` and the project helpers ``Meta``,
    ``defineModel``, ``evaluate`` and ``MiniImagenet``.
    """
    # Manually seed torch and numpy for reproducible results
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    # Open csv file for metric logging. Mode "w" creates the file when it is
    # missing, so no fallback is needed: the original
    # try/except FileNotFoundError -> mode "x" branch was dead code ("w" only
    # raises FileNotFoundError when the parent directory is missing, in which
    # case "x" would fail identically).
    f = open("results.csv", "w")
    f.write("Steps,tr_loss,tr_acc,val_loss,val_acc,te_loss,te_acc\n")

    # Choose PyTorch device and create the model
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = Meta(args, defineModel(args)).to(device)

    # Setup Weights and Biases logger, config hyperparams and watch model
    wandb.init(project="Meta-SGD")
    name = f"N{args.n_way}K{args.k_spt}"
    if args.update_step > 0:
        name += "Meta"
    name += f"Ret{args.ret_channels}VVS{args.vvs_depth}KS{args.kernel_size}"
    wandb.run.name = name
    wandb.config.update(args)
    wandb.watch(model)
    print(f"RUN NAME: {name}")

    # Print additional information on the model
    if args.verbose:
        tmp = filter(lambda x: x.requires_grad, model.parameters())
        num = sum(map(lambda x: np.prod(x.shape), tmp))
        print(args)
        print(model)
        print('Total trainable tensors:', num)

    # Create datasets
    # batchsz here means total episode number
    mini = MiniImagenet('miniimagenet/',
                        mode='train',
                        n_way=args.n_way,
                        k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batchsz=10000,
                        resize=args.imgsz)
    # Separate, smaller split used to estimate training accuracy during eval.
    mini_train_eval = MiniImagenet('miniimagenet/',
                                   mode='train',
                                   n_way=args.n_way,
                                   k_shot=args.k_spt,
                                   k_query=args.k_qry,
                                   batchsz=args.eval_steps,
                                   resize=args.imgsz)
    mini_val = MiniImagenet('miniimagenet/',
                            mode='val',
                            n_way=args.n_way,
                            k_shot=args.k_spt,
                            k_query=args.k_qry,
                            batchsz=args.eval_steps,
                            resize=args.imgsz)
    mini_test = MiniImagenet('miniimagenet/',
                             mode='test',
                             n_way=args.n_way,
                             k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batchsz=args.eval_steps,
                             resize=args.imgsz)

    print("\nMeta-Training:")
    # pruning_factor counts consecutive evaluations without a validation
    # improvement; training stops when it reaches zero.
    pruning_factor = args.pruning
    best_tr_acc, best_val_acc, best_te_acc = 0, 0, 0
    epoch_bar = tqdm(range(args.epoch // 10000),
                     desc="Training",
                     total=len(range(args.epoch // 10000)))
    for epoch in epoch_bar:
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini,
                        args.task_num,
                        shuffle=True,
                        num_workers=1,
                        pin_memory=True)

        task_bar = tqdm(enumerate(db),
                        desc=f"Epoch {epoch}",
                        total=len(db),
                        leave=False)
        for step, (x_spt, y_spt, x_qry, y_qry) in task_bar:
            total_steps = len(db) * epoch + step + 1

            # Perform training for each task
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(
                device), x_qry.to(device), y_qry.to(device)
            model(x_spt, y_spt, x_qry, y_qry)

            if total_steps % args.save_summary_steps == 0:  # evaluation
                # Get evaluation metrics on train/val/test splits
                tr_acc, tr_loss = evaluate(model, mini_train_eval, device,
                                           "Eval Train")
                val_acc, val_loss = evaluate(model, mini_val, device,
                                             "Eval Val")
                te_acc, te_loss = evaluate(model, mini_test, device)

                # Update Task tqdm bar, then log everything to W&B and CSV
                metrics = {
                    'tr acc': tr_acc,
                    'val_acc': val_acc,
                    'te_acc': te_acc
                }
                task_bar.set_postfix(metrics)
                metrics['tr_loss'] = tr_loss
                metrics['val_loss'] = val_loss
                metrics['te_loss'] = te_loss
                wandb.log(metrics)
                f.write(
                    f"{total_steps},{tr_loss},{tr_acc},{val_loss},{val_acc},{te_loss},{te_acc}\n"
                )

                # Update best metrics; reset the pruning counter only when
                # validation accuracy improves.
                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    pruning_factor = args.pruning
                else:
                    pruning_factor -= 1
                best_te_acc = max(te_acc, best_te_acc)
                best_tr_acc = max(tr_acc, best_tr_acc)

                if pruning_factor == 0: break

                # Update tqdm
                epoch_bar.set_postfix({
                    'b_tr_acc': best_tr_acc,
                    'b_val_acc': best_val_acc,
                    'b_te_acc': best_te_acc,
                    'prune': pruning_factor
                })

        # Propagate the early stop out of the epoch loop as well.
        if pruning_factor == 0: break

    f.close()