def main():
    n_way = 5
    k_shot = 1
    k_query = 15
    batchsz = 5

    mdfile1 = './ckpy/feature-%d-way-%d-shot.pkl' % (n_way, k_shot)
    mdfile2 = './ckpy/relation-%d-way-%d-shot.pkl' % (n_way, k_shot)

    mini = MiniImagenet('./mini-imagenet/',
                        mode='test',
                        n_way=n_way,
                        k_shot=k_shot,
                        k_query=k_query,
                        batchsz=2000,
                        resize=84)  # during training, batchsz = 200
    db = DataLoader(mini, batch_size=batchsz, num_workers=0, pin_memory=False)

    feature_embed = CNNEncoder().cuda()
    Relation_score = RelationNetWork(64, 8).cuda()  # relation_dim == 8 ??

    if os.path.exists(mdfile1):
        print("file1-feature exit...")
        feature_embed.load_state_dict(torch.load(mdfile1))
    if os.path.exists(mdfile2):
        print("f2-relation exit...")
        Relation_score.load_state_dict(torch.load(mdfile2))

    accuracies = []
    for ts in range(3):
        total_correct = 0
        total_num = 0
        accuracy = 0

        for i, batch in enumerate(db):
            support_x = Variable(
                batch[0]).cuda()  # [batch_size, n_way*k_shot, c , h , w]
            support_y = Variable(batch[1]).cuda()
            query_x = Variable(batch[2]).cuda()
            query_y = Variable(batch[3]).cuda()  # [b, n_way * q ]

            bh, set1, c, h, w = support_x.size()
            set2 = query_x.size(1)

            support_xf = feature_embed(support_x.view(bh * set1, c, h, w)).view(bh, set1, 64, 19, 19)
            query_xf = feature_embed(query_x.view(bh * set2, c, h, w)).view(bh, set2, 64, 19, 19)

            support_xf = support_xf.unsqueeze(1).expand(
                bh, set2, set1, 64, 19, 19)
            query_xf = query_xf.unsqueeze(2).expand(bh, set2, set1, 64, 19, 19)

            comb = torch.cat((support_xf, query_xf), dim=3)

            score = Relation_score(comb.view(bh * set2 * set1, 64 * 2, 19,
                                             19)).view(bh, set2, set1)

            # score_np = score.cpu().data.numpy()

            support_y_np = support_y.cpu().data.numpy()
            rn_score_np = score.cpu().data.numpy()  # move scores to CPU and convert to numpy
            pred = []
            for ii, tb in enumerate(rn_score_np):
                for jj, tset in enumerate(tb):
                    sim = []
                    for way in range(n_way):
                        sim.append(
                            np.sum(tset[way * k_shot:(way + 1) * k_shot]))

                    idx = np.array(sim).argmax()
                    pred.append(support_y_np[ii, idx * k_shot])
                    # samples of the same class share one label; note the batch dimension.
                    # idx is multiplied by k_shot because the sum above collapsed the k_shot dimension.

                # at this point pred holds bh * set2 predictions
                # print("pred.size=", np.array(pred).shape)
            pred = Variable(torch.from_numpy(np.array(pred).reshape(
                bh, set2))).cuda()
            correct = torch.eq(pred, query_y).sum()

            total_correct += correct.data[0]
            total_num += query_y.size(0) * query_y.size(1)

        accuracy = total_correct / total_num
        print("epoch", ts, "acc:", accuracy)
        accuracies.append(accuracy)

    test_accuracy, h = mean_confidence_interval(accuracies)
    print("test accuracy:", test_accuracy, "h:", h)
def main():

    n_way = 5
    k_shot = 1
    k_query = 15
    batchsz = 5
    best_acc = 0
    mdfile1 = './ckpy/feature-%d-way-%d-shot.pkl' %(n_way,k_shot)
    mdfile2 = './ckpy/relation-%d-way-%d-shot.pkl' %(n_way,k_shot)
    feature_embed = CNNEncoder().cuda()
    Relation_score = RelationNetWork(64, 8).cuda()  # relation_dim == 8 ??

    feature_embed.apply(weight_init)
    Relation_score.apply(weight_init)

    feature_optim = torch.optim.Adam(feature_embed.parameters(), lr=0.001)
    relation_optim = torch.optim.Adam(Relation_score.parameters(), lr=0.001)

    loss_fn = torch.nn.MSELoss().cuda()

    if os.path.exists(mdfile1):
        print("load mdfile1...")
        feature_embed.load_state_dict(torch.load(mdfile1))
    if os.path.exists(mdfile2):
        print("load mdfile2...")
        Relation_score.load_state_dict(torch.load(mdfile2))

    for epoch in range(1000):
        mini = MiniImagenet('./mini-imagenet/', mode='train', n_way=n_way, k_shot=k_shot, k_query=k_query, batchsz=1000, resize=84)  #38400
        db = DataLoader(mini,batch_size=batchsz,shuffle=True,num_workers=4,pin_memory=True)  # 64 , 5*(1+15) , c, h, w
        mini_val = MiniImagenet('./mini-imagenet/', mode='val', n_way=n_way, k_shot=k_shot, k_query=k_query, batchsz=200, resize=84)   #9600
        db_val = DataLoader(mini_val,batch_size=batchsz,shuffle=True,num_workers=4,pin_memory=True)


        for step,batch in enumerate(db):
            support_x = Variable(batch[0]).cuda()   # [batch_size, n_way*k_shot, c, h, w]
            support_y = Variable(batch[1]).cuda()
            query_x = Variable(batch[2]).cuda()
            query_y = Variable(batch[3]).cuda()

            bh,set1,c,h,w = support_x.size()
            set2 = query_x.size(1)

            feature_embed.train()
            Relation_score.train()

            support_xf = feature_embed(support_x.view(bh*set1,c,h,w)).view(bh,set1,64,19,19)  # repeated at test time
            query_xf = feature_embed(query_x.view(bh*set2,c,h,w)).view(bh,set2,64,19,19)

            # print("query_f:", query_xf.size())

            support_xf = support_xf.unsqueeze(1).expand(bh,set2,set1,64,19,19)
            query_xf = query_xf.unsqueeze(2).expand(bh,set2,set1,64,19,19)

            comb = torch.cat((support_xf,query_xf),dim=3)       # bh,set2,set1,2c,h,w
            # print(comb.is_cuda)
            # print(comb.view(bh*set2*set1,2*64,19,19).is_cuda)
            score = Relation_score(comb.view(bh*set2*set1,2*64,19,19)).view(bh,set2,set1,1).squeeze(3)

            support_yf = support_y.unsqueeze(1).expand(bh,set2,set1)
            query_yf = query_y.unsqueeze(2).expand(bh,set2,set1)
            label = torch.eq(support_yf,query_yf).float()

            feature_optim.zero_grad()
            relation_optim.zero_grad()

            loss = loss_fn(score,label)
            loss.backward()

            #torch.nn.utils.clip_grad_norm(feature_embed.parameters(),0.5)  # gradient clipping? or lower the learning rate?
            #torch.nn.utils.clip_grad_norm(Relation_score.parameters(),0.5)

            feature_optim.step()
            relation_optim.step()

            # if step%100==0:
            #     print("step:",epoch+1,"train_loss: ",loss.data[0])
            logger.log_value('{}-way-{}-shot loss:'.format(n_way, k_shot),loss.data[0])

            if step%200==0:
                print("---------test--------")

                total_correct = 0
                total_num = 0
                accuracy = 0
                for j,batch_test in enumerate(db_val):
                    # if (j%100==0):
                    #     print(j,'-------------')
                    support_x = Variable(batch_test[0]).cuda()
                    support_y = Variable(batch_test[1]).cuda()
                    query_x = Variable(batch_test[2]).cuda()
                    query_y = Variable(batch_test[3]).cuda()

                    bh,set1,c,h,w = support_x.size()
                    set2 = query_x.size(1)

                    feature_embed.eval()
                    Relation_score.eval()

                    support_xf = feature_embed(support_x.view(bh*set1,c,h,w)).view(bh,set1,64,19,19)  # repeated at test time
                    query_xf = feature_embed(query_x.view(bh*set2,c,h,w)).view(bh,set2,64,19,19)

                    support_xf = support_xf.unsqueeze(1).expand(bh,set2,set1,64,19,19)
                    query_xf = query_xf.unsqueeze(2).expand(bh,set2,set1,64,19,19)

                    comb = torch.cat((support_xf,query_xf),dim=3)       # bh,set2,set1,2c,h,w
                    score = Relation_score(comb.view(bh*set2*set1,2*64,19,19)).view(bh,set2,set1,1).squeeze(3)

                    rn_score_np = score.cpu().data.numpy()  # move scores to CPU and convert to numpy
                    pred = []
                    support_y_np = support_y.cpu().data.numpy()

                    for ii,tb in enumerate(rn_score_np):
                        for jj,tset in enumerate(tb):
                            sim = []
                            for way in range(n_way):
                                sim.append(np.sum(tset[way*k_shot:(way+1)*k_shot]))

                            idx = np.array(sim).argmax()
                            pred.append(support_y_np[ii, idx*k_shot])
                            # samples of the same class share one label; note the batch dimension.
                            # idx is multiplied by k_shot because the sum above collapsed the k_shot dimension.

                    # at this point pred holds bh * set2 predictions
                    # print("pred.size=", np.array(pred).shape)
                    pred = Variable(torch.from_numpy(np.array(pred).reshape(bh,set2))).cuda()
                    correct = torch.eq(pred,query_y).sum()

                    total_correct += correct.data[0]
                    total_num += query_y.size(0)*query_y.size(1)

                accuracy = total_correct/total_num
                logger.log_value('acc : ',accuracy)
                print("epoch:",epoch,"acc:",accuracy)
                if accuracy>best_acc:
                    print("-------------------epoch",epoch,"step:",step,"acc:",accuracy,"---------------------------------------")
                    best_acc = accuracy
                    torch.save(feature_embed.state_dict(),mdfile1)
                    torch.save(Relation_score.state_dict(),mdfile2)

            #if step% == 0 and step != 0:
             #   print("%d-way %d-shot %d batch | epoch:%d step:%d, loss:%f" %(n_way,k_shot,batchsz,epoch,step,loss.cpu().data[0]))
    logger.step()
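
# The training script above applies a weight_init function that is not part of
# this listing. A minimal sketch of a typical initializer for the Conv/BatchNorm/
# Linear layers used here (an assumption, not the original helper):
import math

import torch.nn as nn


def weight_init(m):
    if isinstance(m, nn.Conv2d):
        # He initialization for conv layers
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2.0 / n))
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        m.weight.data.normal_(0, 0.01)
        m.bias.data.zero_()
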
def main():
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    #np.random.seed(222)
    test_result = {}
    best_acc = 0.0

    maml = Meta(args, Param.config).to(Param.device)
    if len(args.gpu.split(',')) > 1:
        maml = torch.nn.DataParallel(maml)
    state_dict = torch.load(Param.out_path + args.ckpt)
    print(state_dict.keys())
    pretrained_dict = OrderedDict()
    for k in state_dict.keys():
        if n_gpus == -1:
            pretrained_dict[k[7:]] = deepcopy(state_dict[k])
        else:
            pretrained_dict[k[0:]] = deepcopy(state_dict[k])
    maml.load_state_dict(pretrained_dict)
    print("Load from ckpt:", Param.out_path + args.ckpt)

    #opt = optim.Adam(maml.parameters(), lr=args.meta_lr)
    #opt = optim.SGD(maml.parameters(), lr=args.meta_lr, momentum=0.9, weight_decay=args.weight_decay)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    if args.loader in [0, 1]:  # default loader
        if args.loader == 1:
            #from dataloader.mini_imagenet import MiniImageNet as MiniImagenet
            from MiniImagenet2 import MiniImagenet
        else:
            from MiniImagenet import MiniImagenet

        testset = MiniImagenet(Param.root,
                               mode='test',
                               n_way=args.n_way,
                               k_shot=args.k_spt,
                               k_query=args.k_qry,
                               resize=args.imgsz)
        testloader = DataLoader(testset,
                                batch_size=1,
                                shuffle=False,
                                num_workers=4,
                                worker_init_fn=worker_init_fn,
                                drop_last=True)
        test_data = inf_get(testloader)

    elif args.loader == 2:  # pkl loader
        args_data = {}
        args_data['x_dim'] = "84,84,3"
        args_data['ratio'] = 1.0
        args_data['seed'] = 222
        loader_test = dataset_mini(600, 100, 'test', args_data)
        loader_test.load_data_pkl()
    """Test for 600 epochs (each has 4 tasks)"""
    ans = None
    maml_clone = deepcopy(maml)
    for itr in range(600):  # 600x4 test tasks
        if args.loader in [0, 1]:
            support_x, support_y, qx, qy = test_data.__next__()
            support_x, support_y, qx, qy = support_x.to(
                Param.device), support_y.to(Param.device), qx.to(
                    Param.device), qy.to(Param.device)
        elif args.loader == 2:
            support_x, support_y, qx, qy = get_data(loader_test)
            support_x, support_y, qx, qy = support_x.to(
                Param.device), support_y.to(Param.device), qx.to(
                    Param.device), qy.to(Param.device)

        temp = maml_clone(support_x, support_y, qx, qy, meta_train=False)
        if (ans is None):
            ans = temp
        else:
            ans = torch.cat([ans, temp], dim=0)
        if itr % 100 == 0:
            print(itr, ans.mean(dim=0).tolist())
    meanacc = np.array(ans.mean(dim=0).tolist())
    stdacc = np.array(ans.std(dim=0).tolist())
    ci95 = 1.96 * stdacc / np.sqrt(600)
    print(f'Acc: {meanacc[-1]:.4f}, ci95: {ci95[-1]:.4f}')
    with open(Param.out_path + 'test.txt', 'w') as f:
        print(f'Acc: {meanacc[-1]:.4f}, ci95: {ci95[-1]:.4f}', file=f)
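
# The DataLoaders in these MAML examples pass a worker_init_fn that is not shown
# in this listing. A common pattern (an assumption) reseeds numpy inside each
# worker so that episodes sampled in parallel workers differ:
import numpy as np
import torch


def worker_init_fn(worker_id):
    np.random.seed((torch.initial_seed() + worker_id) % (2 ** 32))
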
    # CNNEncoder.forward (the class header and the layer1-layer4 definitions are not shown in this listing)
    def forward(self, x):
        x = self.layer1(x)
        # print("layer 1",x.size())
        x = self.layer2(x)
        # print("layer 2",x.size())
        x = self.layer3(x)
        # print("layer 3",x.size())

        x = self.layer4(x)

        return x  # [bz * (way*(s+q)), 64, 19,19]


from torch.autograd import Variable

if __name__ == "__main__":
    mini = MiniImagenet(root='./mini-imagenet/',
                        mode='train',
                        batchsz=100,
                        n_way=5,
                        k_shot=5,
                        k_query=5,
                        resize=84,
                        startidx=0)
    for i, m in enumerate(mini):
        support_x, support_y, query_x, query_y = m
        print(i, support_x.size())
        support_x = Variable(support_x).cuda()
        net = CNNEncoder().cuda()
        ans = net(support_x)
        print(ans.size())
        print("--------")
def main():

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    #np.random.seed(222)

    config = [('conv2d', [32, 3, 3, 3, 1, 1]), ('bn', [32]), ('relu', [True]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 1]),
              ('bn', [32]), ('relu', [True]), ('max_pool2d', [2, 2, 0]),
              ('conv2d', [32, 32, 3, 3, 1, 1]), ('bn', [32]), ('relu', [True]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [32, 32, 3, 3, 1, 1]),
              ('bn', [32]), ('relu', [True]), ('max_pool2d', [2, 2, 0]),
              ('flatten', []), ('linear', [args.n_way, 32 * 5 * 5])]

    device = torch.device('cuda')
    maml = Meta(args, config).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    root = '/mnt/aitrics_ext/ext01/yanbin/MAML-Pytorch/data/miniImagenet'
    trainset = MiniImagenet(root,
                            mode='train',
                            n_way=args.n_way,
                            k_shot=args.k_spt,
                            k_query=args.k_qry,
                            resize=args.imgsz)
    testset = MiniImagenet(root,
                           mode='test',
                           n_way=args.n_way,
                           k_shot=args.k_spt,
                           k_query=args.k_qry,
                           resize=args.imgsz)
    trainloader = DataLoader(trainset,
                             batch_size=args.task_num,
                             shuffle=True,
                             num_workers=4,
                             worker_init_fn=worker_init_fn,
                             drop_last=True)
    testloader = DataLoader(testset,
                            batch_size=1,
                            shuffle=True,
                            num_workers=1,
                            worker_init_fn=worker_init_fn,
                            drop_last=True)
    train_data = inf_get(trainloader)
    test_data = inf_get(testloader)

    best_acc = 0.0
    if not os.path.exists('ckpt/{}'.format(args.exp)):
        os.mkdir('ckpt/{}'.format(args.exp))
    for epoch in range(args.epoch):
        np.random.seed()
        x_spt, y_spt, x_qry, y_qry = train_data.__next__()
        x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(
            device), x_qry.to(device), y_qry.to(device)

        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if epoch % 100 == 0:
            print('epoch:', epoch, '\ttraining acc:', accs)

        if epoch % 2500 == 0:  # evaluation
            # save checkpoint
            torch.save(maml.state_dict(),
                       'ckpt/{}/model_{}.pkl'.format(args.exp, epoch))
            accs_all_test = []
            for _ in range(600):
                x_spt, y_spt, x_qry, y_qry = test_data.__next__()
                x_spt, y_spt, x_qry, y_qry = x_spt.squeeze().to(
                    device), y_spt.squeeze().to(device), x_qry.squeeze().to(
                        device), y_qry.squeeze().to(device)
                accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                accs_all_test.append(accs)

            # [b, update_step+1]
            accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
            with open('ckpt/' + args.exp + '/test.txt', 'a') as f:
                print('test epoch {}: acc:{:.4f}'.format(epoch, accs[-1]),
                      file=f)
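
# inf_get(), used by the MAML examples to draw batches from a DataLoader
# indefinitely, is not defined in this listing. A minimal sketch (an assumption):
def inf_get(loader):
    while True:
        for batch in loader:
            yield batch
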
def main():
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    #np.random.seed(222)
    test_result = {}
    best_acc = 0.0

    maml = Meta(args, Param.config).to(Param.device)
    if len(args.gpu.split(',')) > 1:
        maml = torch.nn.DataParallel(maml)
    opt = optim.Adam(maml.parameters(), lr=args.meta_lr)
    #opt = optim.SGD(maml.parameters(), lr=args.meta_lr, momentum=0.9, weight_decay=args.weight_decay)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    if args.loader in [0, 1]:  # default loader
        if args.loader == 1:
            #from dataloader.mini_imagenet import MiniImageNet as MiniImagenet
            from MiniImagenet2 import MiniImagenet
        else:
            from MiniImagenet import MiniImagenet

        trainset = MiniImagenet(Param.root,
                                mode='train',
                                n_way=args.n_way,
                                k_shot=args.k_spt,
                                k_query=args.k_qry,
                                resize=args.imgsz)
        testset = MiniImagenet(Param.root,
                               mode='test',
                               n_way=args.n_way,
                               k_shot=args.k_spt,
                               k_query=args.k_qry,
                               resize=args.imgsz)
        trainloader = DataLoader(trainset,
                                 batch_size=args.task_num,
                                 shuffle=True,
                                 num_workers=4,
                                 worker_init_fn=worker_init_fn,
                                 drop_last=True)
        testloader = DataLoader(testset,
                                batch_size=1,
                                shuffle=True,
                                num_workers=1,
                                worker_init_fn=worker_init_fn,
                                drop_last=True)
        train_data = inf_get(trainloader)
        test_data = inf_get(testloader)

    elif args.loader == 2:  # pkl loader
        args_data = {}
        args_data['x_dim'] = "84,84,3"
        args_data['ratio'] = 1.0
        args_data['seed'] = 222
        loader_train = dataset_mini(600, 100, 'train', args_data)
        #loader_val   = dataset_mini(600, 100, 'val', args_data)
        loader_test = dataset_mini(600, 100, 'test', args_data)
        loader_train.load_data_pkl()
        #loader_val.load_data_pkl()
        loader_test.load_data_pkl()

    for epoch in range(args.epoch):
        np.random.seed()
        if args.loader in [0, 1]:
            support_x, support_y, meta_x, meta_y = train_data.__next__()
            support_x, support_y, meta_x, meta_y = support_x.to(
                Param.device), support_y.to(Param.device), meta_x.to(
                    Param.device), meta_y.to(Param.device)
        elif args.loader == 2:
            support_x, support_y, meta_x, meta_y = get_data(loader_train)
            support_x, support_y, meta_x, meta_y = support_x.to(
                Param.device), support_y.to(Param.device), meta_x.to(
                    Param.device), meta_y.to(Param.device)

        meta_loss = maml(support_x, support_y, meta_x, meta_y).mean()
        opt.zero_grad()
        meta_loss.backward()
        torch.nn.utils.clip_grad_value_(maml.parameters(), clip_value=10.0)
        opt.step()
        plot.plot('meta_loss', meta_loss.item())

        if (epoch % 2500 == 0):
            ans = None
            maml_clone = deepcopy(maml)
            for _ in range(600):
                if args.loader in [0, 1]:
                    support_x, support_y, qx, qy = test_data.__next__()
                    support_x, support_y, qx, qy = support_x.to(
                        Param.device), support_y.to(Param.device), qx.to(
                            Param.device), qy.to(Param.device)
                elif args.loader == 2:
                    support_x, support_y, qx, qy = get_data(loader_test)
                    support_x, support_y, qx, qy = support_x.to(
                        Param.device), support_y.to(Param.device), qx.to(
                            Param.device), qy.to(Param.device)

                temp = maml_clone(support_x,
                                  support_y,
                                  qx,
                                  qy,
                                  meta_train=False)
                if (ans is None):
                    ans = temp
                else:
                    ans = torch.cat([ans, temp], dim=0)
            ans = ans.mean(dim=0).tolist()
            test_result[epoch] = ans
            if (ans[-1] > best_acc):
                best_acc = ans[-1]
                torch.save(
                    maml.state_dict(), Param.out_path + 'net_' + str(epoch) +
                    '_' + str(best_acc) + '.pkl')
            del maml_clone
            print(str(epoch) + ': ' + str(ans))
            with open(Param.out_path + 'test.json', 'w') as f:
                json.dump(test_result, f)
        if (epoch < 5) or (epoch % 100 == 0):
            plot.flush()
        plot.tick()
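
# The last training loop logs through a small `plot` helper (plot.plot, plot.tick,
# plot.flush) that is not part of this listing. A minimal stand-in (an assumption)
# that buffers values per iteration and prints averages on flush:
import collections


class _Plot(object):
    def __init__(self):
        self.iteration = 0
        self.since_flush = collections.defaultdict(list)

    def plot(self, name, value):
        self.since_flush[name].append(value)

    def tick(self):
        self.iteration += 1

    def flush(self):
        for name, values in self.since_flush.items():
            print('iter {}\t{}\t{:.5f}'.format(
                self.iteration, name, sum(values) / len(values)))
        self.since_flush.clear()


plot = _Plot()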