Exemple #1
0
def main():
    """Train the mini-ImageNet relation network with a random-forest auxiliary loss.

    Each training episode samples a ``CLASS_NUM``-way task, embeds the support
    and query images with the feature encoder, scores every (query, class)
    pair with the relation network and minimises::

        cross_entropy(relations, labels)
        + 0.7 * cross_entropy(relations, random_forest_predictions)

    The random forest is refitted on the detached relation scores every
    ``RFT_fit_index`` episodes. Every 1000 episodes the networks and the
    forest are evaluated on ``TEST_EPISODE`` meta-test tasks, the metrics are
    logged through ``writer`` and the best models so far are saved under
    ``./models``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Checkpoint locations.
    # NOTE(review): the networks are restored from "*shot.pkl" but saved to
    # "*shot_exp.pkl", so a run never overwrites the checkpoint it resumed
    # from -- confirm this is intended.
    task_tag = f"{CLASS_NUM}way_{SAMPLE_NUM_PER_CLASS}shot"
    encoder_load_path = f"./models/miniimagenet_feature_encoder_{task_tag}.pkl"
    relation_load_path = f"./models/miniimagenet_relation_network_{task_tag}.pkl"
    encoder_save_path = f"./models/miniimagenet_feature_encoder_{task_tag}_exp.pkl"
    relation_save_path = f"./models/miniimagenet_relation_network_{task_tag}_exp.pkl"
    forest_path = f"./models/miniimagenet_random_forest_{task_tag}.pkl"

    # * Step 1: init data folders
    print("init data folders")

    # * Init character folders for dataset construction
    metatrain_character_folders, metatest_character_folders = (
        tg.mini_imagenet_folders())

    # * Step 2: init neural networks
    print("init neural networks")

    feature_encoder = CNNEncoder()
    # warm_start=True lets later fit() calls keep the existing trees and only
    # add the newly requested ones.
    RFT = RandomForestClassifier(n_estimators=100,
                                 random_state=1,
                                 n_jobs=-1,
                                 warm_start=True)
    relation_network = RelationNetwork(FEATURE_DIM, RELATION_DIM)

    feature_encoder.apply(weights_init)
    relation_network.apply(weights_init)

    feature_encoder.to(device)
    relation_network.to(device)

    cross_entropy = nn.CrossEntropyLoss()

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,
                                       step_size=100000,
                                       gamma=0.5)

    relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                              lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,
                                        step_size=100000,
                                        gamma=0.5)

    # map_location keeps checkpoints loadable when the current device differs
    # from the one they were saved on (e.g. CPU fallback).
    if os.path.exists(encoder_load_path):
        feature_encoder.load_state_dict(
            torch.load(encoder_load_path, map_location=device))
        print("load feature encoder success")

    if os.path.exists(forest_path):
        # SECURITY: unpickling executes arbitrary code; only load forests
        # written by this script.
        with open(forest_path, 'rb') as forest_file:
            RFT = pickle.load(forest_file)
        print("load random forest success")

    if os.path.exists(relation_load_path):
        relation_network.load_state_dict(
            torch.load(relation_load_path, map_location=device))
        print("load relation network success")

    # Make sure checkpoints can be written before spending time on training.
    os.makedirs("./models", exist_ok=True)

    # * Step 3: build graph
    print("Training...")

    last_accuracy = 0.0
    last_RFT_accuracy = 0
    RFT_loss_list = []
    relation_loss_list = []
    loss_list = []
    RFT_fit_index = 100

    for episode in range(EPISODE):

        # * init dataset
        # * sample_dataloader is to obtain previous samples for compare
        # * batch_dataloader is to batch samples for training
        number_of_query_image = 10
        task = tg.MiniImagenetTask(metatrain_character_folders, CLASS_NUM,
                                   SAMPLE_NUM_PER_CLASS, number_of_query_image)
        sample_dataloader = tg.get_mini_imagenet_data_loader(
            task,
            num_per_class=SAMPLE_NUM_PER_CLASS,
            split="train",
            shuffle=False)
        # * num_per_class : number of query images
        batch_dataloader = tg.get_mini_imagenet_data_loader(
            task,
            num_per_class=number_of_query_image,
            split="test",
            shuffle=True)

        # * sample datas
        samples, sample_labels = next(iter(sample_dataloader))
        batches, batch_labels = next(iter(batch_dataloader))

        # Keep a reference to the labels as delivered by the loader (before
        # the move to `device`) for fitting the random forest.
        RFT_batch_labels = batch_labels

        samples, sample_labels = samples.to(device), sample_labels.to(device)
        batches, batch_labels = batches.to(device), batch_labels.to(device)

        # * calculates features
        sample_features = feature_encoder(samples)
        sample_features = sample_features.view(CLASS_NUM, SAMPLE_NUM_PER_CLASS,
                                               FEATURE_DIM, 19, 19)
        # * Testing : mean vs sum
        # Collapse the shots of each class into one class-level feature map.
        sample_features = torch.sum(sample_features, 1).squeeze(1)
        batch_features = feature_encoder(batches)

        # * calculate relations
        # * each batch sample link to every samples to calculate relations
        sample_features_ext = sample_features.unsqueeze(0).repeat(
            number_of_query_image * CLASS_NUM, 1, 1, 1, 1)
        batch_features_ext = batch_features.unsqueeze(0).repeat(
            CLASS_NUM, 1, 1, 1, 1)
        batch_features_ext = torch.transpose(batch_features_ext, 0, 1)

        relation_pairs = torch.cat((sample_features_ext, batch_features_ext),
                                   2).view(-1, FEATURE_DIM * 2, 19, 19)
        relations = relation_network(relation_pairs).view(-1, CLASS_NUM)

        # Refit the forest less often once training is well under way.
        if episode > 30000:
            RFT_fit_index = 1000

        if episode % RFT_fit_index == 0:
            RFT.fit(relations.detach().cpu(), RFT_batch_labels)
            # Request one more tree so the next warm-start fit has work to do.
            RFT.n_estimators += 1

        # The forest's most probable class per query is used as a pseudo
        # label for the auxiliary loss.
        RFT_prob = torch.tensor(RFT.predict_proba(
            relations.detach().cpu())).to(device)
        _, RFT_labels = torch.max(RFT_prob, 1)
        RFT_loss = cross_entropy(relations, RFT_labels) * 0.7

        relation_loss = cross_entropy(relations, batch_labels)
        loss = relation_loss + RFT_loss

        feature_encoder.zero_grad()
        relation_network.zero_grad()

        loss.backward()

        torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)

        feature_encoder_optim.step()
        relation_network_optim.step()

        # The `epoch` argument of scheduler.step() is deprecated; one step()
        # per optimizer step halves the learning rate every 100000 episodes.
        feature_encoder_scheduler.step()
        relation_network_scheduler.step()

        if (episode + 1) % 100 == 0:
            print(
                f"episode : {episode+1}, loss : {loss.cpu().detach().numpy()}")
            loss_list.append(loss.cpu().detach().numpy())
            RFT_loss_list.append(RFT_loss.cpu().detach().numpy())
            relation_loss_list.append(relation_loss.cpu().detach().numpy())

        if (episode + 1) % 1000 == 0:
            print("Testing...")
            total_reward = 0
            test_RFT_accuracy = 0.0
            RFT_eval_count = 0

            # No gradients are needed for evaluation.
            with torch.no_grad():
                for i in range(TEST_EPISODE):
                    number_of_query_image = 10
                    task = tg.MiniImagenetTask(metatest_character_folders,
                                               CLASS_NUM, SAMPLE_NUM_PER_CLASS,
                                               number_of_query_image)
                    sample_dataloader = tg.get_mini_imagenet_data_loader(
                        task,
                        num_per_class=SAMPLE_NUM_PER_CLASS,
                        split="train",
                        shuffle=False)
                    test_dataloader = tg.get_mini_imagenet_data_loader(
                        task,
                        num_per_class=number_of_query_image,
                        split="test",
                        shuffle=True)

                    sample_images, sample_labels = next(
                        iter(sample_dataloader))
                    test_images, test_labels = next(iter(test_dataloader))

                    sample_images = sample_images.to(device)
                    sample_labels = sample_labels.to(device)
                    test_images = test_images.to(device)
                    test_labels = test_labels.to(device)

                    # * calculate features
                    sample_features = feature_encoder(sample_images)
                    sample_features = sample_features.view(
                        CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM, 19, 19)
                    sample_features = torch.sum(sample_features, 1).squeeze(1)
                    test_features = feature_encoder(test_images)

                    # * calculate relations
                    # * each batch sample link to every samples to calculate
                    # * relations
                    sample_features_ext = sample_features.unsqueeze(0).repeat(
                        number_of_query_image * CLASS_NUM, 1, 1, 1, 1)
                    test_features_ext = test_features.unsqueeze(0).repeat(
                        CLASS_NUM, 1, 1, 1, 1)
                    test_features_ext = torch.transpose(
                        test_features_ext, 0, 1)

                    relation_pairs = torch.cat(
                        (sample_features_ext, test_features_ext),
                        2).view(-1, FEATURE_DIM * 2, 19, 19)
                    relations = relation_network(relation_pairs).view(
                        -1, CLASS_NUM)

                    _, predict_labels = torch.max(relations, 1)

                    # Score every query of the task, not only the first
                    # CLASS_NUM * SAMPLE_NUM_PER_CLASS of them.
                    total_reward += (
                        predict_labels == test_labels.view(-1)).sum().item()

                    # The forest is only evaluated on every 200th test task.
                    if i % 200 == 0:
                        test_labels_cpu = test_labels.cpu()
                        RFT_predict = RFT.predict(relations.cpu())
                        assert RFT_predict.shape == test_labels_cpu.shape
                        RFT_task_accuracy = accuracy_score(
                            RFT_predict, test_labels_cpu)
                        print(RFT_task_accuracy)
                        test_RFT_accuracy += RFT_task_accuracy
                        RFT_eval_count += 1

            test_accuracy = total_reward / (
                1.0 * CLASS_NUM * number_of_query_image * TEST_EPISODE)
            # Average over the tasks actually scored; TEST_EPISODE // 200 is
            # wrong unless TEST_EPISODE is a multiple of 200 and is zero for
            # TEST_EPISODE < 200.
            if RFT_eval_count:
                test_RFT_accuracy /= RFT_eval_count
            print(f"{test_RFT_accuracy:.3f} %")
            mean_loss = np.mean(loss_list)
            mean_RFT_loss = np.mean(RFT_loss_list)
            mean_relation_loss = np.mean(relation_loss_list)

            print(f'mean loss : {mean_loss}')
            print(f'RFT loss : {mean_RFT_loss}')
            writer.add_scalar('1.RFT loss', mean_RFT_loss, episode + 1)
            writer.add_scalar('RFT_accuracy', test_RFT_accuracy, episode + 1)
            writer.add_scalar('2.relation loss', mean_relation_loss,
                              episode + 1)
            writer.add_scalar('loss', mean_loss, episode + 1)
            writer.add_scalar('test accuracy', test_accuracy, episode + 1)

            # Start fresh running averages for the next 1000 episodes.
            loss_list = []
            relation_loss_list = []
            RFT_loss_list = []

            if test_RFT_accuracy > last_RFT_accuracy:
                with open(forest_path, 'wb') as forest_file:
                    pickle.dump(RFT, forest_file)
                last_RFT_accuracy = test_RFT_accuracy
                print("save random forest for episode:", episode)

            print("test accuracy : ", test_accuracy)

            if test_accuracy > last_accuracy:
                # save networks
                torch.save(feature_encoder.state_dict(), encoder_save_path)
                torch.save(relation_network.state_dict(), relation_save_path)

                print("save networks for episode:", episode)

                last_accuracy = test_accuracy
Exemple #2
0
def main():
    """Train and periodically evaluate a relation network on mini-ImageNet.

    Requires CUDA: both networks and all batches are moved to device ``GPU``.
    Each episode samples a ``CLASS_NUM``-way task, embeds support and query
    images, scores every (query, support) pair with the relation network and
    regresses the scores onto one-hot labels with an MSE loss. Every 5000
    episodes (including episode 0) the model is evaluated on ``TEST_EPISODE``
    meta-test tasks and the networks are saved under ``./models`` when the
    mean accuracy improves.

    NOTE(review): training reshapes the relations to
    ``(-1, CLASS_NUM * SAMPLE_NUM_PER_CLASS)`` while the one-hot targets have
    width ``CLASS_NUM``, and evaluation hard-codes one support image per
    class, so this routine appears to assume ``SAMPLE_NUM_PER_CLASS == 1`` --
    confirm before using it for other shot counts.
    """
    task_tag = f"{CLASS_NUM}way_{SAMPLE_NUM_PER_CLASS}shot"
    encoder_path = f"./models/miniimagenet_feature_encoder_{task_tag}.pkl"
    relation_path = f"./models/miniimagenet_relation_network_{task_tag}.pkl"

    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction. Author's note: the
    # helper walks all train/val classes and shuffles them; the two returned
    # lists hold the training and validation folders respectively.
    metatrain_folders, metatest_folders = tg.mini_imagenet_folders()

    # Step 2: init neural networks (feature encoder + relation network)
    print("init neural networks")

    feature_encoder = CNNEncoder()
    relation_network = RelationNetwork(FEATURE_DIM, RELATION_DIM)

    # PyTorch weight initialisation, see
    # https://blog.csdn.net/dss_dssssd/article/details/83990511
    feature_encoder.apply(weights_init)
    relation_network.apply(weights_init)

    feature_encoder.cuda(GPU)
    relation_network.cuda(GPU)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,
                                       step_size=100000,
                                       gamma=0.5)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                              lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,
                                        step_size=100000,
                                        gamma=0.5)

    if os.path.exists(encoder_path):
        feature_encoder.load_state_dict(torch.load(encoder_path))
        print("load feature encoder success")
    if os.path.exists(relation_path):
        relation_network.load_state_dict(torch.load(relation_path))
        print("load relation network success")

    # Make sure checkpoints can be written before spending time on training.
    os.makedirs("./models", exist_ok=True)

    # The loss module is stateless, so build it once instead of per episode.
    # For scatter_ (used for the one-hot targets below) see
    # https://www.cnblogs.com/shiyublog/p/10924287.html
    mse = nn.MSELoss().cuda(GPU)

    # Step 3: build graph
    print("Training...")

    last_accuracy = 0.0

    for episode in range(EPISODE):

        # init dataset
        # sample_dataloader is to obtain previous samples for compare
        # batch_dataloader is to batch samples for training
        # Author's notes (5-way, 1-shot, 15 queries per class): the dataset
        # has 100 classes with 600 images each; 5 class folders are drawn at
        # random, with 1 support and 15 query images per folder, and labels
        # in 0..4. Some classes may never be drawn for training -- this
        # could be improved.
        task = tg.MiniImagenetTask(metatrain_folders, CLASS_NUM,
                                   SAMPLE_NUM_PER_CLASS, BATCH_NUM_PER_CLASS)
        # Support set, in class order (labels [[0],[1],[2],[3],[4]]).
        # Author's open question: how were the normalisation mean/std used
        # inside the loader obtained?
        sample_dataloader = tg.get_mini_imagenet_data_loader(
            task,
            num_per_class=SAMPLE_NUM_PER_CLASS,
            split="train",
            shuffle=False)
        # Query set, shuffled.
        batch_dataloader = tg.get_mini_imagenet_data_loader(
            task,
            num_per_class=BATCH_NUM_PER_CLASS,
            split="test",
            shuffle=True)

        # sample datas, see https://blog.csdn.net/g11d111/article/details/81504637
        # next(iter(...)) fetches the first (only needed) batch; the old
        # `.__iter__().next()` spelling is Python-2 style and is not
        # supported by current DataLoader iterators.
        samples, sample_labels = next(iter(sample_dataloader))
        batches, batch_labels = next(iter(batch_dataloader))

        # calculate features
        # Author's shape notes: 5x64x19x19 and 75x64x19x19 (the upstream
        # comments claiming 5x5 feature maps were wrong).
        sample_features = feature_encoder(samples.cuda(GPU))
        batch_features = feature_encoder(batches.cuda(GPU))

        # calculate relations
        # each batch sample link to every samples to calculate relations
        # repeat() tiles the tensor: author's note 75x5x64x19x19
        sample_features_ext = sample_features.unsqueeze(0).repeat(
            BATCH_NUM_PER_CLASS * CLASS_NUM, 1, 1, 1, 1)
        # author's note 5x75x64x19x19, transposed to 75x5x64x19x19
        batch_features_ext = batch_features.unsqueeze(0).repeat(
            SAMPLE_NUM_PER_CLASS * CLASS_NUM, 1, 1, 1, 1)
        batch_features_ext = torch.transpose(batch_features_ext, 0, 1)
        # author's note 375x128x19x19
        relation_pairs = torch.cat((sample_features_ext, batch_features_ext),
                                   2).view(-1, FEATURE_DIM * 2, 19, 19)
        # author's note 75x5
        relations = relation_network(relation_pairs).view(
            -1, CLASS_NUM * SAMPLE_NUM_PER_CLASS)

        one_hot_labels = torch.zeros(
            BATCH_NUM_PER_CLASS * CLASS_NUM,
            CLASS_NUM).scatter_(1, batch_labels.view(-1, 1), 1).cuda(GPU)
        loss = mse(relations, one_hot_labels)

        # training

        feature_encoder.zero_grad()
        relation_network.zero_grad()

        loss.backward()

        # Gradient clipping: rescales the gradients in place so that their
        # total norm does not exceed 0.5 (a guard against exploding
        # gradients). `clip_grad_norm` is the deprecated name of this call.
        torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)

        feature_encoder_optim.step()
        relation_network_optim.step()

        # Schedulers must be stepped after the optimizers, and the `epoch`
        # argument is deprecated. Stepping once per episode here yields the
        # same learning rate for every episode as the former
        # `step(episode)` call at the top of the loop.
        feature_encoder_scheduler.step()
        relation_network_scheduler.step()

        if (episode + 1) % 100 == 0:
            print("episode:", episode + 1, "loss", loss.item())

        if episode % 5000 == 0:

            # test
            print("Testing...")
            accuracies = []
            # No gradients are needed for evaluation.
            with torch.no_grad():
                for _ in range(TEST_EPISODE):
                    total_rewards = 0
                    counter = 0
                    task = tg.MiniImagenetTask(metatest_folders, CLASS_NUM, 1,
                                               15)
                    sample_dataloader = tg.get_mini_imagenet_data_loader(
                        task, num_per_class=1, split="train", shuffle=False)

                    num_per_class = 3
                    test_dataloader = tg.get_mini_imagenet_data_loader(
                        task,
                        num_per_class=num_per_class,
                        split="test",
                        shuffle=True)
                    sample_images, sample_labels = next(
                        iter(sample_dataloader))
                    for test_images, test_labels in test_dataloader:
                        # Use the same device as the networks; a bare
                        # .cuda() would pick the default GPU instead.
                        test_labels = test_labels.cuda(GPU)
                        batch_size = test_labels.shape[0]
                        # calculate features
                        sample_features = feature_encoder(
                            sample_images.cuda(GPU))
                        test_features = feature_encoder(
                            test_images.cuda(GPU))

                        # calculate relations
                        # each batch sample link to every samples to
                        # calculate relations
                        sample_features_ext = sample_features.unsqueeze(
                            0).repeat(batch_size, 1, 1, 1, 1)
                        test_features_ext = test_features.unsqueeze(0).repeat(
                            1 * CLASS_NUM, 1, 1, 1, 1)
                        test_features_ext = torch.transpose(
                            test_features_ext, 0, 1)
                        relation_pairs = torch.cat(
                            (sample_features_ext, test_features_ext),
                            2).view(-1, FEATURE_DIM * 2, 19, 19)
                        relations = relation_network(relation_pairs).view(
                            -1, CLASS_NUM)

                        _, predict_labels = torch.max(relations, 1)

                        # Count the hits in one tensor op instead of one
                        # device sync per sample.
                        total_rewards += (predict_labels == test_labels.view(
                            -1)).sum().item()
                        counter += batch_size
                    accuracy = total_rewards / 1.0 / counter
                    accuracies.append(accuracy)

            test_accuracy, h = mean_confidence_interval(accuracies)

            print("test accuracy:", test_accuracy, "h:", h)

            if test_accuracy > last_accuracy:

                # save networks
                torch.save(feature_encoder.state_dict(), encoder_path)
                torch.save(relation_network.state_dict(), relation_path)

                print("save networks for episode:", episode)

                last_accuracy = test_accuracy
Exemple #3
0
def main():
    """Train and periodically evaluate the TensorFlow relation network.

    Each episode samples a mini-ImageNet task, draws one support batch and
    one query batch, and runs ``train_one_step``. Every 5000 episodes
    (including episode 0) the model is scored with ``test`` on
    ``TEST_EPISODE`` meta-test tasks; both networks are saved under
    ``models/`` whenever the mean accuracy improves.
    """
    task_tag = f"{CLASS_NUM}way_{SAMPLE_NUM_PER_CLASS}shot"
    encoder_path = f"models/miniimagenet_feature_encoder_{task_tag}"
    relation_path = f"models/miniimagenet_relation_network_{task_tag}"

    def first_batch(dataset, batch_size):
        """Return the first ``batch_size``-element batch produced by ``dataset.generator``."""
        loader = tf.data.Dataset.from_generator(
            dataset.generator,
            output_types=(tf.float32, tf.float32),
            output_shapes=((84, 84, 3), (5, 1))).batch(batch_size).take(1)
        return next(iter(loader))

    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    metatrain_folders, metatest_folders = tg.mini_imagenet_folders()

    # Step 2: init neural networks
    print("init neural networks")
    feature_encoder = CNNEncoder()
    relation_network = RelationNetwork()

    # The decay schedules were previously created but never used: the
    # optimizers hard-coded learning_rate=0.001, so LEARNING_RATE and the
    # halving every 100000 steps were silently ignored. Passing the schedule
    # as the learning rate makes them take effect.
    feature_encoder_scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
        LEARNING_RATE, 100000, 0.5, staircase=True)
    feature_encoder_optim = tf.keras.optimizers.Adam(
        learning_rate=feature_encoder_scheduler, epsilon=1e-08)
    relation_network_scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
        LEARNING_RATE, 100000, 0.5, staircase=True)
    relation_network_optim = tf.keras.optimizers.Adam(
        learning_rate=relation_network_scheduler, epsilon=1e-08)

    if os.path.exists(encoder_path):
        feature_encoder = tf.keras.models.load_model(encoder_path)
        print("load feature encoder success")
    if os.path.exists(relation_path):
        relation_network = tf.keras.models.load_model(relation_path)
        print("load relation network success")

    # Step 3: build graph
    print("Training...")

    last_accuracy = 0.0
    # Query images drawn per class for each evaluation task.
    test_queries_per_class = 5

    for episode in range(EPISODE):
        task = tg.MiniImagenetTask(metatrain_folders, CLASS_NUM,
                                   SAMPLE_NUM_PER_CLASS, BATCH_NUM_PER_CLASS)
        sample_dataset = tg.dataset(task, SAMPLE_NUM_PER_CLASS,
                                    split='train', shuffle=False)
        batch_dataset = tg.dataset(task, BATCH_NUM_PER_CLASS,
                                   split='test', shuffle=True)

        samples, sample_labels = first_batch(
            sample_dataset, SAMPLE_NUM_PER_CLASS * CLASS_NUM)
        batches, batch_labels = first_batch(
            batch_dataset, BATCH_NUM_PER_CLASS * CLASS_NUM)

        loss = train_one_step(feature_encoder, relation_network,
                              feature_encoder_optim, relation_network_optim,
                              samples, sample_labels, batches,
                              batch_labels).numpy()

        if (episode + 1) % 100 == 0:
            print("episode:", episode + 1, "loss", loss)

        if episode % 5000 == 0:
            # test
            print("Testing...")
            accuracies = []
            for _ in range(TEST_EPISODE):
                task = tg.MiniImagenetTask(metatest_folders, CLASS_NUM,
                                           SAMPLE_NUM_PER_CLASS,
                                           BATCH_NUM_PER_CLASS)
                sample_dataset = tg.dataset(task, SAMPLE_NUM_PER_CLASS,
                                            split='train', shuffle=False)
                batch_dataset = tg.dataset(task, test_queries_per_class,
                                           split='test', shuffle=True)

                samples, sample_labels = first_batch(
                    sample_dataset, SAMPLE_NUM_PER_CLASS * CLASS_NUM)
                batches, batch_labels = first_batch(
                    batch_dataset, test_queries_per_class * CLASS_NUM)

                accuracies.append(
                    test(feature_encoder, relation_network, samples,
                         sample_labels, batches, batch_labels))

            test_accuracy, h = mean_confidence_interval(accuracies)

            print("test accuracy:", test_accuracy, "h:", h)

            if test_accuracy > last_accuracy:

                # save networks
                feature_encoder.save(encoder_path, save_format='tf')
                relation_network.save(relation_path, save_format='tf')

                print("save networks for episode:", episode)

                last_accuracy = test_accuracy
def main():
    """Train the foreground/background few-shot model and test it periodically.

    Builds two feature encoders (foreground and background), a mixture
    network and a relation network, plus a second "vanilla" copy of the
    encoders and mixture network that is only used as a fixed teacher.
    Checkpoints found under ``METHOD`` (and ``./vanilla_models`` for the
    teacher) are loaded when present.

    Each episode samples a task from the meta-train folders and minimises
    ``mse(relations, one_hot_labels) + args.beta * mse(student, teacher)``.
    Every 2500 episodes the model is scored on ``TEST_EPISODE`` tasks drawn
    from the meta-test folders; the four trained networks are saved to
    ``METHOD`` whenever the mean test accuracy improves.

    All configuration is read from module-level names (``GPU``,
    ``LEARNING_RATE``, ``METHOD``, ``CLASS_NUM``, ``SUPPORT_NUM_PER_CLASS``,
    ``BATCH_NUM_PER_CLASS``, ``EPISODE``, ``TEST_EPISODE``, ``SIGMA``,
    ``args.beta``). Returns nothing.
    """
    print("init data folders")
    metatrain_folders, metaquery_folders = tg.mini_imagenet_folders()

    print("init neural networks")
    foreground_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    background_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    mixture_network = models.MixtureNetwork().apply(weights_init).cuda(GPU)
    relation_network = models.SimilarityNetwork(
        FEATURE_DIM, RELATION_DIM).apply(weights_init).cuda(GPU)

    vanilla_foreground_encoder = models.FeatureEncoder().apply(
        weights_init).cuda(GPU)
    vanilla_background_encoder = models.FeatureEncoder().apply(
        weights_init).cuda(GPU)
    vanilla_mixture_network = models.MixtureNetwork().apply(weights_init).cuda(
        GPU)

    foreground_encoder_optim = torch.optim.Adam(
        foreground_encoder.parameters(), lr=LEARNING_RATE)
    foreground_encoder_scheduler = StepLR(foreground_encoder_optim,
                                          step_size=100000,
                                          gamma=0.5)
    background_encoder_optim = torch.optim.Adam(
        background_encoder.parameters(), lr=LEARNING_RATE)
    background_encoder_scheduler = StepLR(background_encoder_optim,
                                          step_size=100000,
                                          gamma=0.5)
    mixture_network_optim = torch.optim.Adam(mixture_network.parameters(),
                                             lr=LEARNING_RATE)
    mixture_network_scheduler = StepLR(mixture_network_optim,
                                       step_size=100000,
                                       gamma=0.5)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                              lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,
                                        step_size=100000,
                                        gamma=0.5)

    def checkpoint_path(directory, component):
        """Return the checkpoint file name for one network."""
        return (directory + "/miniImagenet_" + component + "_" +
                str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) +
                "shot.pkl")

    def load_if_present(network, directory, component, label):
        """Load a saved state dict into ``network`` when the file exists."""
        path = checkpoint_path(directory, component)
        if os.path.exists(path):
            network.load_state_dict(torch.load(path))
            print("load " + label + " success")

    def hallucinate_support(foreground, background):
        """Mix every support foreground with every support background.

        Both inputs are (CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64, 19, 19). The
        mixed features are weighted by ``similarity_func`` and summed over
        the background axis and the shot axis, giving (CLASS_NUM, 64, 361).
        """
        pairs = CLASS_NUM * SUPPORT_NUM_PER_CLASS
        foreground = foreground.unsqueeze(2).repeat(1, 1, pairs, 1, 1, 1)
        background = background.view(1, 1, pairs, 64, 19, 19).repeat(
            CLASS_NUM, SUPPORT_NUM_PER_CLASS, 1, 1, 1, 1)
        weights = similarity_func(background, CLASS_NUM,
                                  SUPPORT_NUM_PER_CLASS).view(
                                      CLASS_NUM, SUPPORT_NUM_PER_CLASS, -1, 1,
                                      1)
        mixed = mixture_network(
            (foreground + background).view(pairs * pairs, 64, 19, 19)).view(
                CLASS_NUM, SUPPORT_NUM_PER_CLASS, -1, 64, 19**2)
        return (mixed * weights).sum(2).sum(1)

    def second_order(features):
        """Turn (N, C, P) features into (N, 1, C, C) second-order features.

        Each sample becomes ``power_norm(S / trace(S), SIGMA)`` with
        ``S = (1 / P) * s @ s.T``.
        """
        count, channels, positions = features.size()
        # Every slice is written in the loop, so no initialisation is needed.
        pooled = features.new_empty((count, 1, channels, channels))
        for d in range(count):
            s = features[d]
            s = (1.0 / positions) * s.mm(s.transpose(0, 1))
            pooled[d, :, :, :] = power_norm(s / s.trace(), SIGMA)
        return pooled

    def relation_scores(so_support, so_query):
        """Score every (query, class) pair; returns (num_query, CLASS_NUM)."""
        num_query = so_query.size(0)
        support_ext = so_support.unsqueeze(0).repeat(num_query, 1, 1, 1, 1)
        query_ext = so_query.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1)
        query_ext = torch.transpose(query_ext, 0, 1)
        relation_pairs = torch.cat((support_ext, query_ext),
                                   2).view(-1, 2, 64, 64)
        return relation_network(relation_pairs).view(-1, CLASS_NUM)

    # Loading models
    load_if_present(foreground_encoder, METHOD, "foreground_encoder",
                    "foreground encoder")
    load_if_present(background_encoder, METHOD, "background_encoder",
                    "background encoder")
    load_if_present(mixture_network, METHOD, "mixture_network",
                    "mixture network")
    load_if_present(relation_network, METHOD, "relation_network",
                    "relation network")

    # Loading vanilla models
    load_if_present(vanilla_foreground_encoder, "./vanilla_models",
                    "foreground_encoder", "vanilla foreground encoder")
    load_if_present(vanilla_background_encoder, "./vanilla_models",
                    "background_encoder", "vanilla background encoder")
    load_if_present(vanilla_mixture_network, "./vanilla_models",
                    "mixture_network", "vanilla mixture network")

    # Create the output directory without going through a shell.
    os.makedirs(METHOD, exist_ok=True)

    print("Training...")

    # The loss module holds no state, so one instance serves every episode.
    mse = nn.MSELoss().cuda(GPU)
    best_accuracy = 0.0
    for episode in range(EPISODE):
        foreground_encoder_scheduler.step(episode)
        background_encoder_scheduler.step(episode)
        mixture_network_scheduler.step(episode)
        relation_network_scheduler.step(episode)

        # init dataset
        task = tg.MiniImagenetTask(metatrain_folders, CLASS_NUM,
                                   SUPPORT_NUM_PER_CLASS, BATCH_NUM_PER_CLASS)
        support_dataloader = tg.get_mini_imagenet_data_loader(
            task,
            num_per_class=SUPPORT_NUM_PER_CLASS,
            split="train",
            shuffle=False)
        query_dataloader = tg.get_mini_imagenet_data_loader(
            task,
            num_per_class=BATCH_NUM_PER_CLASS,
            split="test",
            shuffle=True)

        # One batch from each loader; next(iter(...)) is the Python 3 form.
        support_img, support_sal, _ = next(iter(support_dataloader))
        query_img, query_sal, query_labels = next(iter(query_dataloader))

        # The saliency map splits each image into foreground and background.
        support_fg_input = (support_img * support_sal).cuda(GPU)
        support_bg_input = (support_img * (1 - support_sal)).cuda(GPU)

        # calculate features
        support_foreground_features = foreground_encoder(
            support_fg_input).view(CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64, 19,
                                   19)
        support_background_features = background_encoder(
            support_bg_input).view(CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64, 19,
                                   19)
        query_foreground_features = foreground_encoder(
            (query_img * query_sal).cuda(GPU))
        query_background_features = background_encoder(
            (query_img * (1 - query_sal)).cuda(GPU))

        # Real-Representation Regularization (TriR): the vanilla networks
        # act as a fixed teacher, so no gradient graph is built for them.
        with torch.no_grad():
            teacher_mix_features = vanilla_mixture_network(
                vanilla_foreground_encoder(support_fg_input) +
                vanilla_background_encoder(support_bg_input))
        student_mix_features = mixture_network(
            (support_foreground_features + support_background_features).view(
                -1, 64, 19, 19))
        TriR = args.beta * mse(student_mix_features, teacher_mix_features)

        # Inter-class Hallucination
        support_mix_features = hallucinate_support(
            support_foreground_features, support_background_features)
        query_mix_features = mixture_network(query_foreground_features +
                                             query_background_features).view(
                                                 -1, 64, 19**2)

        # calculate relations with 64x64 second-order features
        relations = relation_scores(second_order(support_mix_features),
                                    second_order(query_mix_features))

        one_hot_labels = torch.zeros(BATCH_NUM_PER_CLASS * CLASS_NUM,
                                     CLASS_NUM).scatter_(
                                         1, query_labels.view(-1, 1),
                                         1).cuda(GPU)
        loss = mse(relations, one_hot_labels) + TriR

        # update network parameters
        foreground_encoder.zero_grad()
        background_encoder.zero_grad()
        mixture_network.zero_grad()
        relation_network.zero_grad()

        loss.backward()

        torch.nn.utils.clip_grad_norm_(foreground_encoder.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(background_encoder.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(mixture_network.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)

        foreground_encoder_optim.step()
        background_encoder_optim.step()
        mixture_network_optim.step()
        relation_network_optim.step()

        if (episode + 1) % 100 == 0:
            print("episode:", episode + 1, "loss", loss.item())

        if episode % 2500 == 0:
            # test
            print("Testing...")
            accuracies = []
            # NOTE(review): the networks are never switched to eval() here;
            # confirm that evaluating in training mode is intended.
            with torch.no_grad():
                for _ in range(TEST_EPISODE):
                    total_rewards = 0
                    counter = 0
                    task = tg.MiniImagenetTask(metaquery_folders, CLASS_NUM,
                                               SUPPORT_NUM_PER_CLASS, 15)
                    support_dataloader = tg.get_mini_imagenet_data_loader(
                        task,
                        num_per_class=SUPPORT_NUM_PER_CLASS,
                        split="train",
                        shuffle=False)
                    num_per_class = 2
                    query_dataloader = tg.get_mini_imagenet_data_loader(
                        task,
                        num_per_class=num_per_class,
                        split="test",
                        shuffle=True)
                    support_img, support_sal, _ = next(
                        iter(support_dataloader))
                    support_fg_input = (support_img * support_sal).cuda(GPU)
                    support_bg_input = (support_img *
                                        (1 - support_sal)).cuda(GPU)
                    for query_img, query_sal, query_labels in query_dataloader:
                        query_size = query_labels.shape[0]
                        # The support set is re-encoded for every query
                        # batch, exactly as before, so the number and order
                        # of forward passes per network is unchanged.
                        support_foreground_features = foreground_encoder(
                            support_fg_input).view(CLASS_NUM,
                                                   SUPPORT_NUM_PER_CLASS, 64,
                                                   19, 19)
                        support_background_features = background_encoder(
                            support_bg_input).view(CLASS_NUM,
                                                   SUPPORT_NUM_PER_CLASS, 64,
                                                   19, 19)
                        query_foreground_features = foreground_encoder(
                            (query_img * query_sal).cuda(GPU))
                        query_background_features = background_encoder(
                            (query_img * (1 - query_sal)).cuda(GPU))

                        # Inter-class Hallucination
                        support_mix_features = hallucinate_support(
                            support_foreground_features,
                            support_background_features)
                        query_mix_features = mixture_network(
                            query_foreground_features +
                            query_background_features).view(-1, 64, 19**2)

                        # calculate relations with 64x64 second-order features
                        relations = relation_scores(
                            second_order(support_mix_features),
                            second_order(query_mix_features))
                        _, predict_labels = torch.max(relations.data, 1)

                        # Count correct predictions in one tensor comparison.
                        targets = query_labels.view(-1).cuda(GPU)
                        total_rewards += (
                            predict_labels == targets).sum().item()
                        counter += query_size

                    accuracies.append(total_rewards / 1.0 / counter)
            test_accuracy, h = mean_confidence_interval(accuracies)
            print("test accuracy:", test_accuracy, "h:", h)

            if test_accuracy > best_accuracy:
                # save networks
                torch.save(foreground_encoder.state_dict(),
                           checkpoint_path(METHOD, "foreground_encoder"))
                torch.save(background_encoder.state_dict(),
                           checkpoint_path(METHOD, "background_encoder"))
                torch.save(mixture_network.state_dict(),
                           checkpoint_path(METHOD, "mixture_network"))
                torch.save(relation_network.state_dict(),
                           checkpoint_path(METHOD, "relation_network"))
                print("save networks for episode:", episode)
                best_accuracy = test_accuracy
            print("best accuracy:", best_accuracy)
from __future__ import division

import torch

import task_generator as tg
import train

# Class folders for the meta-train and meta-test splits.
(metatrain_character_folders,
 metatest_character_folders) = tg.mini_imagenet_folders()

# Network sizes handed to the trainer.
FEATURE_DIM, RELATION_DIM = 64, 8
# Episode layout handed to the trainer.
CLASS_NUM = 5
SAMPLE_NUM_PER_CLASS = 1
BATCH_NUM_PER_CLASS = 15
# EPISODE validation rounds are run, each called with TEST_EPISODE.
EPISODE, TEST_EPISODE = 10, 600

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')

rn_trainer = train.RNTrainer(
    metatrain_character_folders,
    metatest_character_folders,
    FEATURE_DIM,
    RELATION_DIM,
    CLASS_NUM,
    SAMPLE_NUM_PER_CLASS,
    BATCH_NUM_PER_CLASS,
    DEVICE,
)

rn_trainer.load_models()
last_accuracy = 0.0
# Only the result of the final validation round stays bound afterwards.
for r in range(EPISODE):
    test_accuracy = rn_trainer.validate(TEST_EPISODE)

print('Completed episodes')
Exemple #6
0
def main():
    """Evaluate the feature encoder and relation network on meta-test tasks.

    Builds both networks, loads their checkpoints from ``METHOD`` when the
    files exist, then runs ``EPISODE`` evaluation rounds. Each round scores
    ``TEST_EPISODE`` freshly sampled tasks, prints the mean accuracy with
    the value ``h`` returned by ``mean_confidence_interval``, and tracks the
    best round seen so far. No parameters are updated and nothing is saved.

    Configuration is read from module-level names (``GPU``, ``METHOD``,
    ``FEATURE_DIM``, ``RELATION_DIM``, ``CLASS_NUM``,
    ``SUPPORT_NUM_PER_CLASS``, ``EPISODE``, ``TEST_EPISODE``, ``LAMBDA``,
    ``SIGMA``). Returns nothing.
    """
    _, metaquery_folders = tg.mini_imagenet_folders()

    print("init neural networks")

    feature_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    relation_network = models.SimilarityNetwork(
        FEATURE_DIM, RELATION_DIM).apply(weights_init).cuda(GPU)

    def checkpoint_path(component):
        """Return the checkpoint file name for one network."""
        return (METHOD + "/" + component + "_" + str(CLASS_NUM) + "way_" +
                str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")

    def hop_features(features):
        """Turn (N, C, P) features into (N, 1, C, C) second-order features.

        Each sample is shifted by ``LAMBDA`` times its mean term, reduced to
        ``S = (1 / P) * s @ s.T`` and stored as
        ``power_norm(S / trace(S), SIGMA)``.
        """
        count, channels, positions = features.size()
        # Every slice is written in the loop, so no initialisation is needed.
        pooled = features.new_empty((count, 1, channels, channels))
        for d in range(count):
            s = features[d]
            # NOTE(review): where mean(1) drops the reduced dimension, this
            # repeat/view tiling does not line each row up with its own
            # mean. Left unchanged so evaluation matches however the
            # checkpoints were trained - confirm before "fixing".
            s = s - LAMBDA * s.mean(1).repeat(1, s.size()[1]).view(s.size())
            s = (1.0 / positions) * s.mm(s.transpose(0, 1))
            pooled[d, :, :, :] = power_norm(s / s.trace(), SIGMA)
        return pooled

    feature_encoder_path = checkpoint_path("feature_encoder")
    if os.path.exists(feature_encoder_path):
        feature_encoder.load_state_dict(torch.load(feature_encoder_path))
        print("load feature encoder success")
    relation_network_path = checkpoint_path("relation_network")
    if os.path.exists(relation_network_path):
        relation_network.load_state_dict(torch.load(relation_network_path))
        print("load relation network success")

    # Create the output directory without going through a shell.
    os.makedirs(METHOD, exist_ok=True)

    # Step 3: build graph
    print("Training...")

    best_accuracy = 0.0
    best_h = 0.0

    for _ in range(EPISODE):
        with torch.no_grad():
            print("Testing...")
            accuracies = []
            for _ in range(TEST_EPISODE):
                total_rewards = 0
                counter = 0
                # The support count follows SUPPORT_NUM_PER_CLASS so the
                # sampled task matches the view() below and the checkpoint
                # names; it was hard-coded to 1 before.
                task = tg.MiniImagenetTask(metaquery_folders, CLASS_NUM,
                                           SUPPORT_NUM_PER_CLASS, 2)
                support_dataloader = tg.get_mini_imagenet_data_loader(
                    task,
                    num_per_class=SUPPORT_NUM_PER_CLASS,
                    split="train",
                    shuffle=False)
                num_per_class = 2
                query_dataloader = tg.get_mini_imagenet_data_loader(
                    task,
                    num_per_class=num_per_class,
                    split="query",
                    shuffle=True)
                # One support batch; next(iter(...)) is the Python 3 form.
                support_images, _ = next(iter(support_dataloader))
                support_images = support_images.cuda(GPU)
                for query_images, query_labels in query_dataloader:
                    query_size = query_labels.shape[0]
                    # calculate features; the support set is re-encoded for
                    # every query batch, exactly as before.
                    support_features = feature_encoder(support_images).view(
                        CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64, 19**2).sum(1)
                    # -1 lets a short final batch through unchanged.
                    query_features = feature_encoder(
                        query_images.cuda(GPU)).view(-1, 64, 19**2)

                    # HOP features
                    H_support_features = hop_features(support_features)
                    H_query_features = hop_features(query_features)

                    # form relation pairs
                    support_features_ext = H_support_features.unsqueeze(
                        0).repeat(query_size, 1, 1, 1, 1)
                    query_features_ext = H_query_features.unsqueeze(0).repeat(
                        CLASS_NUM, 1, 1, 1, 1)
                    query_features_ext = torch.transpose(
                        query_features_ext, 0, 1)
                    relation_pairs = torch.cat(
                        (support_features_ext, query_features_ext),
                        2).view(-1, 2, 64, 64)
                    # calculate relation scores
                    relations = relation_network(relation_pairs).view(
                        -1, CLASS_NUM)

                    _, predict_labels = torch.max(relations.data, 1)

                    # Count correct predictions in one tensor comparison.
                    targets = query_labels.view(-1).cuda(GPU)
                    total_rewards += (predict_labels == targets).sum().item()
                    counter += query_size
                accuracies.append(total_rewards / 1.0 / counter)

            test_accuracy, h = mean_confidence_interval(accuracies)

            print("Test accuracy:", test_accuracy, "h:", h)
            print("Best accuracy: ", best_accuracy, "h:", best_h)

            if test_accuracy > best_accuracy:
                best_accuracy = test_accuracy
                best_h = h
def main():
    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    metatrain_folders, metaquery_folders = tg.mini_imagenet_folders()

    # Step 2: init neural networks
    print("init neural networks")

    feature_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    relation_network = models.SimilarityNetwork(
        FEATURE_DIM, RELATION_DIM).apply(weights_init).cuda(GPU)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,
                                       step_size=50000,
                                       gamma=0.5)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                              lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,
                                        step_size=50000,
                                        gamma=0.5)

    if os.path.exists(
            str(METHOD + "/feature_encoder_" + str(CLASS_NUM) + "way_" +
                str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")):
        feature_encoder.load_state_dict(
            torch.load(
                str(METHOD + "/feature_encoder_" + str(CLASS_NUM) + "way_" +
                    str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")))
        print("load feature encoder success")
    if os.path.exists(
            str(METHOD + "/relation_network_" + str(CLASS_NUM) + "way_" +
                str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")):
        relation_network.load_state_dict(
            torch.load(
                str(METHOD + "/relation_network_" + str(CLASS_NUM) + "way_" +
                    str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")))
        print("load relation network success")
    if os.path.exists(METHOD) == False:
        os.system('mkdir ' + METHOD)

    # Step 3: build graph
    print("Training...")

    best_accuracy = 0.0
    best_h = 0.0

    for episode in range(EPISODE):
        feature_encoder_scheduler.step(episode)
        relation_network_scheduler.step(episode)

        # init dataset
        # support_dataloader is to obtain previous supports for compare
        # query_dataloader is to query supports for training
        task = tg.MiniImagenetTask(metatrain_folders, CLASS_NUM,
                                   SUPPORT_NUM_PER_CLASS, QUERY_NUM_PER_CLASS)
        support_dataloader = tg.get_mini_imagenet_data_loader(
            task,
            num_per_class=SUPPORT_NUM_PER_CLASS,
            split="train",
            shuffle=False)
        query_dataloader = tg.get_mini_imagenet_data_loader(
            task,
            num_per_class=QUERY_NUM_PER_CLASS,
            split="test",
            shuffle=True)

        # support datas
        supports, support_labels = support_dataloader.__iter__().next()
        queries, query_labels = query_dataloader.__iter__().next()

        # calculate features
        support_features = feature_encoder(Variable(supports).cuda(GPU)).view(
            CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64, 19**2).sum(1)  # 5x64*19*19
        query_features = feature_encoder(Variable(queries).cuda(GPU)).view(
            QUERY_NUM_PER_CLASS * CLASS_NUM, 64, 19**2)  # 20x64*19*19
        H_support_features = Variable(
            torch.Tensor(SUPPORT_NUM_PER_CLASS * CLASS_NUM, 1, 64,
                         64)).cuda(GPU)
        H_query_features = Variable(
            torch.Tensor(QUERY_NUM_PER_CLASS * CLASS_NUM, 1, 64, 64)).cuda(GPU)
        # HOP features
        for d in range(support_features.size()[0]):
            s = support_features[d, :, :].squeeze(0)
            s = s - LAMBDA * s.mean(1).repeat(1, s.size()[1]).view(s.size())
            s = (1.0 / support_features.size()[2]) * s.mm(s.transpose(0, 1))
            H_support_features[d, :, :, :] = power_norm(s / s.trace(), SIGMA)
        for d in range(query_features.size()[0]):
            s = query_features[d, :, :].squeeze(0)
            s = s - LAMBDA * s.mean(1).repeat(1, s.size()[1]).view(s.size())
            s = (1.0 / query_features.size()[2]) * s.mm(s.transpose(0, 1))
            H_query_features[d, :, :, :] = power_norm(s / s.trace(), SIGMA)

        # form relation pairs
        support_features_ext = H_support_features.unsqueeze(0).repeat(
            QUERY_NUM_PER_CLASS * CLASS_NUM, 1, 1, 1, 1)
        query_features_ext = H_query_features.unsqueeze(0).repeat(
            CLASS_NUM, 1, 1, 1, 1)
        query_features_ext = torch.transpose(query_features_ext, 0, 1)
        relation_pairs = torch.cat((support_features_ext, query_features_ext),
                                   2).view(-1, 2, 64, 64)
        # calculate relation scores
        relations = relation_network(relation_pairs).view(
            -1, CLASS_NUM * SUPPORT_NUM_PER_CLASS)

        # define loss function
        mse = nn.MSELoss().cuda(GPU)
        one_hot_labels = Variable(
            torch.zeros(QUERY_NUM_PER_CLASS * CLASS_NUM,
                        CLASS_NUM).scatter_(1, query_labels.view(-1, 1),
                                            1)).cuda(GPU)
        loss = mse(relations, one_hot_labels)

        # updating network parameters with their gradients
        feature_encoder.zero_grad()
        relation_network.zero_grad()

        loss.backward()

        feature_encoder_optim.step()
        relation_network_optim.step()

        if (episode + 1) % 100 == 0:
            print("episode:", episode + 1, "loss", loss.data[0])

        if episode % 500 == 0:
            # query
            print("Testing...")

            accuracies = []
            for i in range(TEST_EPISODE):
                with torch.no_grad():
                    total_rewards = 0
                    counter = 0
                    task = tg.MiniImagenetTask(metaquery_folders, CLASS_NUM, 1,
                                               2)
                    support_dataloader = tg.get_mini_imagenet_data_loader(
                        task, num_per_class=1, split="train", shuffle=False)
                    num_per_class = 2
                    query_dataloader = tg.get_mini_imagenet_data_loader(
                        task,
                        num_per_class=num_per_class,
                        split="query",
                        shuffle=True)
                    support_images, support_labels = support_dataloader.__iter__(
                    ).next()
                    for query_images, query_labels in query_dataloader:
                        query_size = query_labels.shape[0]
                        # calculate features
                        support_features = feature_encoder(
                            Variable(support_images).cuda(GPU)).view(
                                CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64,
                                19**2).sum(1)
                        query_features = feature_encoder(
                            Variable(query_images).cuda(GPU)).view(
                                num_per_class * CLASS_NUM, 64, 19**2)

                        H_support_features = Variable(
                            torch.Tensor(SUPPORT_NUM_PER_CLASS * CLASS_NUM, 1,
                                         64, 64)).cuda(GPU)
                        H_query_features = Variable(
                            torch.Tensor(num_per_class * CLASS_NUM, 1, 64,
                                         64)).cuda(GPU)
                        # HOP features
                        for d in range(support_features.size()[0]):
                            s = support_features[d, :, :].squeeze(0)
                            s = s - LAMBDA * s.mean(1).repeat(
                                1,
                                s.size()[1]).view(s.size())
                            s = (1.0 / support_features.size()[2]) * s.mm(
                                s.transpose(0, 1))
                            H_support_features[d, :, :, :] = power_norm(
                                s / s.trace(), SIGMA)
                        for d in range(query_features.size()[0]):
                            s = query_features[d, :, :].squeeze(0)
                            s = s - LAMBDA * s.mean(1).repeat(
                                1,
                                s.size()[1]).view(s.size())
                            s = (1.0 / query_features.size()[2]) * s.mm(
                                s.transpose(0, 1))
                            H_query_features[d, :, :, :] = power_norm(
                                s / s.trace(), SIGMA)

                        # form relation pairs
                        support_features_ext = H_support_features.unsqueeze(
                            0).repeat(query_size, 1, 1, 1, 1)
                        query_features_ext = H_query_features.unsqueeze(
                            0).repeat(1 * CLASS_NUM, 1, 1, 1, 1)
                        query_features_ext = torch.transpose(
                            query_features_ext, 0, 1)
                        relation_pairs = torch.cat(
                            (support_features_ext, query_features_ext),
                            2).view(-1, 2, 64, 64)
                        # calculate relation scores
                        relations = relation_network(relation_pairs).view(
                            -1, CLASS_NUM)

                        _, predict_labels = torch.max(relations.data, 1)

                        rewards = [
                            1 if predict_labels[j] == query_labels[j].cuda(GPU)
                            else 0 for j in range(query_size)
                        ]

                        total_rewards += np.sum(rewards)
                        counter += query_size
                    accuracy = total_rewards / 1.0 / counter
                    accuracies.append(accuracy)

            test_accuracy, h = mean_confidence_interval(accuracies)

            print("Test accuracy:", test_accuracy, "h:", h)
            print("Best accuracy: ", best_accuracy, "h:", best_h)

            if test_accuracy > best_accuracy:
                # save networks
                torch.save(
                    feature_encoder.state_dict(),
                    str(METHOD + "/feature_encoder_" + str(CLASS_NUM) +
                        "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl"))
                torch.save(
                    relation_network.state_dict(),
                    str(METHOD + "/relation_network_" + str(CLASS_NUM) +
                        "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl"))
                print("save networks for episode:", episode)

                best_accuracy = test_accuracy
                best_h = h
def main():
    """Meta-train the CNN feature encoder together with the A2C actor/critic.

    Every episode runs ``META_BATCH_RANGE`` meta-batches.  Inside a meta-batch,
    ``INNER_BATCH_RANGE`` inner environments (each assembled from
    ``ENV_LENGTH`` sampled tasks) are handed to
    ``agent.train(..., inner_update=True)``; afterwards ``META_ENV_LENGTH``
    tasks form a meta environment whose policy/value losses are appended to
    per-episode lists by ``agent.train``.  The mean of those losses is
    back-propagated, gradients are clipped, and the three optimizers and
    schedulers are stepped.

    Every 100 episodes the losses are printed and recorded; every 500 episodes
    the model is scored on ``TEST_EPISODE`` meta-test tasks, scalars are
    written to the module-level ``writer`` and the networks are checkpointed
    under ``./models`` when the accuracy improves.

    NOTE(review): what ``agent.train`` does when ``inner_update=True`` is not
    visible from this function -- confirm against ``a2cAgent``.
    NOTE(review): the networks are left in ``train()`` mode while testing, as
    in the original code -- confirm this is intentional.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # * Step 1: init data folders
    print("init data folders")

    # * Init character folders for dataset construction
    metatrain_character_folders, metatest_character_folders = (
        tg.mini_imagenet_folders())

    # * Step 2: init neural networks
    print("init neural networks")

    feature_encoder = models.CNNEncoder()
    actor = models.Actor(FEATURE_DIM, RELATION_DIM, CLASS_NUM)
    critic = models.Critic(FEATURE_DIM, RELATION_DIM)

    feature_encoder.train()
    actor.train()
    critic.train()

    feature_encoder.apply(models.weights_init)
    actor.apply(models.weights_init)
    critic.apply(models.weights_init)

    feature_encoder.to(device)
    actor.to(device)
    critic.to(device)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,
                                       step_size=10000,
                                       gamma=0.5)

    actor_optim = torch.optim.Adam(actor.parameters(), lr=2.5 * LEARNING_RATE)
    actor_scheduler = StepLR(actor_optim, step_size=10000, gamma=0.5)

    critic_optim = torch.optim.Adam(critic.parameters(),
                                    lr=2.5 * LEARNING_RATE * 10)
    critic_scheduler = StepLR(critic_optim, step_size=10000, gamma=0.5)

    agent = a2cAgent.A2CAgent(actor, critic, GAMMA, ENTROPY_WEIGHT, CLASS_NUM,
                              device)

    # Checkpoint locations are shared by the restore step below and the save
    # step at the end of each evaluation.
    checkpoint_suffix = (str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) +
                         "shot.pkl")
    feature_encoder_path = ("./models/miniimagenet_feature_encoder_" +
                            checkpoint_suffix)
    actor_path = "./models/miniimagenet_actor_network_" + checkpoint_suffix
    critic_path = "./models/miniimagenet_critic_network_" + checkpoint_suffix

    for network, path, label in (
        (feature_encoder, feature_encoder_path, "feature encoder"),
        (actor, actor_path, "actor network"),
        (critic, critic_path, "critic network"),
    ):
        if os.path.exists(path):
            # map_location lets a checkpoint written on a GPU machine be
            # restored when this run fell back to the CPU.
            network.load_state_dict(torch.load(path, map_location=device))
            print("load " + label + " success")

    def build_relation_pairs(support_images, query_images):
        """Encode both image sets and pair every query with every class.

        The support features are summed over the shots of each class.  Each
        query feature map is concatenated channel-wise with each class
        feature map, giving ``num_queries * CLASS_NUM`` stacked pairs.
        """
        support_features = feature_encoder(support_images).view(
            CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM, 19, 19)
        support_features = torch.sum(support_features, 1).squeeze(1)
        query_features = feature_encoder(query_images)

        # Use the real batch size so a short batch still pairs up correctly.
        num_queries = query_features.size(0)
        support_ext = support_features.unsqueeze(0).repeat(
            num_queries, 1, 1, 1, 1)
        query_ext = query_features.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1)
        query_ext = torch.transpose(query_ext, 0, 1)
        return torch.cat((support_ext, query_ext),
                         2).view(-1, FEATURE_DIM * 2, 19, 19)

    # * Step 3: build graph
    print("Training...")

    last_accuracy = 0.0
    mbal_loss_list = []
    mbcl_loss_list = []
    loss_list = []

    # Query images per class.  Training and testing use separate names so the
    # evaluation setting can no longer overwrite the training setting.
    train_query_num = 15
    inner_query_num = 5
    test_query_num = 10

    for episode in range(EPISODE):
        policy_losses = []
        value_losses = []

        for _ in range(META_BATCH_RANGE):
            meta_env_states_list = []
            meta_env_labels_list = []
            for _ in range(INNER_BATCH_RANGE):
                # * Generate environment
                env_states_list = []
                env_labels_list = []
                for _ in range(ENV_LENGTH):
                    task = tg.MiniImagenetTask(metatrain_character_folders,
                                               CLASS_NUM, SAMPLE_NUM_PER_CLASS,
                                               train_query_num)
                    sample_dataloader = tg.get_mini_imagenet_data_loader(
                        task,
                        num_per_class=SAMPLE_NUM_PER_CLASS,
                        split="train",
                        shuffle=False)
                    batch_dataloader = tg.get_mini_imagenet_data_loader(
                        task,
                        num_per_class=inner_query_num,
                        split="test",
                        shuffle=True)

                    samples, _ = next(iter(sample_dataloader))
                    samples = samples.to(device)
                    for batches, batch_labels in batch_dataloader:
                        batches = batches.to(device)
                        batch_labels = batch_labels.to(device)
                        env_states_list.append(
                            build_relation_pairs(samples, batches))
                        env_labels_list.append(batch_labels)

                inner_env = a2cAgent.env(env_states_list, env_labels_list)
                agent.train(inner_env, inner_update=True)

            for _ in range(META_ENV_LENGTH):
                task = tg.MiniImagenetTask(metatrain_character_folders,
                                           CLASS_NUM, SAMPLE_NUM_PER_CLASS,
                                           train_query_num)
                # * sample_dataloader yields the support samples to compare to
                sample_dataloader = tg.get_mini_imagenet_data_loader(
                    task,
                    num_per_class=SAMPLE_NUM_PER_CLASS,
                    split="train",
                    shuffle=False)
                # * batch_dataloader yields the query samples for training
                batch_dataloader = tg.get_mini_imagenet_data_loader(
                    task,
                    num_per_class=train_query_num,
                    split="test",
                    shuffle=True)

                samples, _ = next(iter(sample_dataloader))
                samples = samples.to(device)
                # * Generate env for meta update
                batches, batch_labels = next(iter(batch_dataloader))
                batches = batches.to(device)
                batch_labels = batch_labels.to(device)

                meta_env_states_list.append(
                    build_relation_pairs(samples, batches))
                meta_env_labels_list.append(batch_labels)

            meta_env = a2cAgent.env(meta_env_states_list, meta_env_labels_list)
            agent.train(meta_env,
                        policy_loss_list=policy_losses,
                        value_loss_list=value_losses)

        feature_encoder_optim.zero_grad()
        actor_optim.zero_grad()
        critic_optim.zero_grad()

        meta_batch_actor_loss = torch.stack(policy_losses).mean()
        meta_batch_critic_loss = torch.stack(value_losses).mean()

        meta_batch_actor_loss.backward(retain_graph=True)
        meta_batch_critic_loss.backward()

        # Clipping must follow backward(): it rescales the gradients that are
        # currently stored, so calling it right after zero_grad() did nothing.
        torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(actor.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(critic.parameters(), 0.5)

        feature_encoder_optim.step()
        actor_optim.step()
        critic_optim.step()

        feature_encoder_scheduler.step()
        actor_scheduler.step()
        critic_scheduler.step()

        if (episode + 1) % 100 == 0:
            mbal = meta_batch_actor_loss.item()
            mbcl = meta_batch_critic_loss.item()
            print(
                f"episode : {episode+1}, meta_batch_actor_loss : {mbal:.4f}, meta_batch_critic_loss : {mbcl:.4f}"
            )

            mbal_loss_list.append(mbal)
            mbcl_loss_list.append(mbcl)
            loss_list.append(mbal + mbcl)

        if (episode + 1) % 500 == 0:
            print("Testing...")
            total_reward = 0
            total_num_of_test_samples = 0

            for _ in range(TEST_EPISODE):
                # * Generate env
                task = tg.MiniImagenetTask(metatest_character_folders,
                                           CLASS_NUM, SAMPLE_NUM_PER_CLASS,
                                           test_query_num)
                sample_dataloader = tg.get_mini_imagenet_data_loader(
                    task,
                    num_per_class=SAMPLE_NUM_PER_CLASS,
                    split="train",
                    shuffle=False)
                test_dataloader = tg.get_mini_imagenet_data_loader(
                    task,
                    num_per_class=test_query_num,
                    split="test",
                    shuffle=True)

                sample_images, _ = next(iter(sample_dataloader))
                sample_images = sample_images.to(device)

                test_images, test_labels = next(iter(test_dataloader))
                total_num_of_test_samples += len(test_labels)
                test_images = test_images.to(device)
                test_labels = test_labels.to(device)

                relation_pairs = build_relation_pairs(sample_images,
                                                      test_images)

                test_env = a2cAgent.env([relation_pairs], [test_labels])
                total_reward += agent.test(test_env)

            test_accuracy = total_reward / (1.0 * total_num_of_test_samples)

            # Averages cover the losses recorded since the last evaluation.
            mean_loss = np.mean(loss_list)
            mean_actor_loss = np.mean(mbal_loss_list)
            mean_critic_loss = np.mean(mbcl_loss_list)

            print(f'mean loss : {mean_loss}')
            print("test accuracy : ", test_accuracy)

            # `writer` is a module-level object (not defined in this function).
            writer.add_scalar('1.loss', mean_loss, episode + 1)
            writer.add_scalar('2.mean_actor_loss', mean_actor_loss,
                              episode + 1)
            writer.add_scalar('3.mean_critic_loss', mean_critic_loss,
                              episode + 1)
            writer.add_scalar('4.test accuracy', test_accuracy, episode + 1)

            loss_list = []
            mbal_loss_list = []
            mbcl_loss_list = []

            if test_accuracy > last_accuracy:
                # save networks
                torch.save(feature_encoder.state_dict(), feature_encoder_path)
                torch.save(actor.state_dict(), actor_path)
                torch.save(critic.state_dict(), critic_path)
                print("save networks for episode:", episode)
                last_accuracy = test_accuracy
Exemple #9
0
def main():
    """Score the saved feature encoder and A2C actor/critic on meta-test tasks.

    Builds the three networks, restores whichever checkpoints exist under
    ``./models``, then runs ``TEST_EPISODE`` freshly sampled tasks.  For each
    task the per-task accuracy is printed; at the end the mean accuracy and
    the value returned alongside it by ``mean_confidence_interval`` are
    printed.

    NOTE(review): the networks stay in ``train()`` mode during evaluation, as
    written in the original -- confirm this is intentional.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # * Step 1: init data folders
    print("init data folders")

    # * Init character folders for dataset construction
    folders = tg.mini_imagenet_folders()
    metatrain_character_folders, metatest_character_folders = folders

    # * Step 2: init neural networks
    print("init neural networks")

    feature_encoder = models.CNNEncoder()
    actor = models.Actor(FEATURE_DIM, RELATION_DIM, CLASS_NUM)
    critic = models.Critic(FEATURE_DIM, RELATION_DIM)

    # Same per-network sequence as before: train mode, weight init, move.
    networks = (feature_encoder, actor, critic)
    for net in networks:
        net.train()
    for net in networks:
        net.apply(models.weights_init)
    for net in networks:
        net.to(device)

    agent = a2cAgent.A2CAgent(actor, critic, GAMMA, ENTROPY_WEIGHT,
                              FEATURE_DIM, RELATION_DIM, CLASS_NUM, device)

    # Restore each network from its checkpoint when the file is present.
    tail = str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"
    restore_plan = (
        (feature_encoder, "./models/miniimagenet_feature_encoder_" + tail,
         "load feature encoder success"),
        (actor, "./models/miniimagenet_actor_network_" + tail,
         "load actor network success"),
        (critic, "./models/miniimagenet_critic_network_" + tail,
         "load critic network success"),
    )
    for net, checkpoint, message in restore_plan:
        if os.path.exists(checkpoint):
            net.load_state_dict(torch.load(checkpoint))
            print(message)

    # Neither list is read afterwards; retained from the original code.
    max_accuracy_list = []
    mean_accuracy_list = []

    number_of_query_image = 15
    total_accuracy = []
    for _ in range(TEST_EPISODE):
        # * Generate env
        task = tg.MiniImagenetTask(metatest_character_folders, CLASS_NUM,
                                   SAMPLE_NUM_PER_CLASS,
                                   number_of_query_image)
        support_loader = tg.get_mini_imagenet_data_loader(
            task,
            num_per_class=SAMPLE_NUM_PER_CLASS,
            split="train",
            shuffle=False)
        query_loader = tg.get_mini_imagenet_data_loader(
            task,
            num_per_class=number_of_query_image,
            split="test",
            shuffle=True)

        support_images, support_labels = next(iter(support_loader))
        query_images, query_labels = next(iter(query_loader))

        support_images = support_images.to(device)
        support_labels = support_labels.to(device)
        query_images = query_images.to(device)
        query_labels = query_labels.to(device)

        # * calculate features: support features are summed over the shots
        support_features = feature_encoder(support_images).view(
            CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM, 19, 19)
        support_features = torch.sum(support_features, 1).squeeze(1)
        query_features = feature_encoder(query_images)

        # * calculate relations: pair every query with every class feature
        support_ext = support_features.unsqueeze(0).repeat(
            number_of_query_image * CLASS_NUM, 1, 1, 1, 1)
        query_ext = torch.transpose(
            query_features.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1), 0, 1)
        relation_pairs = torch.cat((support_ext, query_ext),
                                   2).view(-1, FEATURE_DIM * 2, 19, 19)

        test_env = a2cAgent.env([relation_pairs], [query_labels])
        rewards = agent.test(test_env)
        test_accuracy = rewards / len(query_labels)
        print(test_accuracy)
        total_accuracy.append(test_accuracy)

    mean_accuracy, conf_int = mean_confidence_interval(total_accuracy)
    print(f"Total accuracy : {mean_accuracy:.4f}")
    print(f"confidence interval : {conf_int:.4f}")
Exemple #10
0
def main():
    """Evaluate the saved feature encoder and relation network on meta-test tasks.

    Restores whichever checkpoints exist under ``./models``, then runs five
    test rounds of 600 sampled tasks each.  A task's accuracy is the fraction
    of its query images whose highest relation score matches the label.  Each
    round prints its mean accuracy (with the value returned alongside it by
    ``mean_confidence_interval``) and its best single-task accuracy; finally
    the statistics over the per-round best accuracies and the sorted
    per-round means are printed.

    ``device`` is read from module scope; it is not defined in this function.

    NOTE(review): the random forest is restored but never queried here; the
    soft-voting experiment that used ``RFT.predict_proba`` was commented out
    in the original code.
    NOTE(review): the networks are not switched to ``eval()`` (the calls were
    commented out in the original) -- confirm this is intentional.
    """
    # * Step 1: init data folders
    print("init data folders")

    # * init character folders for dataset construction
    _, metatest_character_folders = tg.mini_imagenet_folders()

    # * Step 2: init neural networks
    print("init neural networks")

    feature_encoder = ot.CNNEncoder().to(device)
    RFT = RandomForestClassifier(n_estimators=100,
                                 random_state=1,
                                 warm_start=True)
    relation_network = ot.RelationNetwork(FEATURE_DIM, RELATION_DIM).to(device)

    way_shot = str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS)
    feature_encoder_path = ("./models/miniimagenet_feature_encoder_" +
                            way_shot + "shot_exp.pkl")
    random_forest_path = ("./models/miniimagenet_random_forest_" + way_shot +
                          "shot.pkl")
    relation_network_path = ("./models/miniimagenet_relation_network_" +
                             way_shot + "shot_exp.pkl")

    if os.path.exists(feature_encoder_path):
        # map_location lets a GPU-written checkpoint load on another device.
        feature_encoder.load_state_dict(
            torch.load(feature_encoder_path, map_location=device))
        print("load feature encoder success")

    if os.path.exists(random_forest_path):
        # NOTE: pickle executes arbitrary code on load; only load checkpoint
        # files produced by this project.
        with open(random_forest_path, 'rb') as forest_file:
            RFT = pickle.load(forest_file)
        print("load random forest success")

    if os.path.exists(relation_network_path):
        relation_network.load_state_dict(
            torch.load(relation_network_path, map_location=device))
        print("load relation network success")

    max_accuracy_list = []
    mean_accuracy_list = []
    number_of_query_image = 15
    for test in range(5):
        print("Testing...")
        max_accuracy = 0
        total_accuracy = []
        print(f"Test {test}")
        for _ in range(600):
            total_reward = 0
            num_scored = 0
            task = tg.MiniImagenetTask(metatest_character_folders, CLASS_NUM,
                                       SAMPLE_NUM_PER_CLASS,
                                       number_of_query_image)
            sample_dataloader = tg.get_mini_imagenet_data_loader(
                task,
                num_per_class=SAMPLE_NUM_PER_CLASS,
                split="train",
                shuffle=False)
            test_dataloader = tg.get_mini_imagenet_data_loader(
                task,
                num_per_class=number_of_query_image,
                split="test",
                shuffle=False)

            sample_images, _ = next(iter(sample_dataloader))
            sample_images = sample_images.to(device)

            # Pure inference: no autograd graph is needed.
            with torch.no_grad():
                # * calculate support features once per task; they do not
                # * depend on the query batch.
                sample_features = feature_encoder(sample_images).view(
                    CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM, 19, 19)
                sample_features = torch.sum(sample_features, 1).squeeze(1)

                for test_images, test_labels in test_dataloader:
                    test_images = test_images.to(device)
                    test_labels = test_labels.to(device)
                    batch_size = test_labels.shape[0]

                    test_features = feature_encoder(test_images)

                    # * calculate relations
                    # * each query sample is paired with every class feature
                    sample_features_ext = sample_features.unsqueeze(0).repeat(
                        batch_size, 1, 1, 1, 1)
                    test_features_ext = test_features.unsqueeze(0).repeat(
                        CLASS_NUM, 1, 1, 1, 1)
                    test_features_ext = torch.transpose(
                        test_features_ext, 0, 1)

                    relation_pairs = torch.cat(
                        (sample_features_ext, test_features_ext),
                        2).view(-1, FEATURE_DIM * 2, 19, 19)
                    relations = relation_network(relation_pairs).view(
                        -1, CLASS_NUM)

                    _, predict_labels = torch.max(relations, 1)

                    # Score every sample of the batch.  The previous fixed
                    # range(CLASS_NUM * 5) ignored the rest of a larger batch
                    # and would overrun a smaller one.
                    hits = predict_labels == test_labels.view(-1)
                    total_reward += int(hits.sum().item())
                    num_scored += batch_size

            # Divide by the number of samples actually scored.
            test_accuracy = total_reward / num_scored if num_scored else 0.0
            total_accuracy.append(test_accuracy)
            if test_accuracy > max_accuracy:
                max_accuracy = test_accuracy

        test_accuracy, h = mean_confidence_interval(total_accuracy)
        print(f"Final result : {test_accuracy:.4f}, h : {h:.4f} ")

        mean_accuracy = np.mean(total_accuracy)
        mean_accuracy_list.append(mean_accuracy)
        print(f"Total accuracy : {mean_accuracy:.4f}")
        print(f"max accuracy : {max_accuracy:.4f}")
        max_accuracy_list.append(max_accuracy)

    # Summary over the best single-task accuracy of each round.
    final_accuracy, h = mean_confidence_interval(max_accuracy_list)
    print(f"Final result : {final_accuracy:.4f}, h : {h:.4f} ")
    print(np.sort(mean_accuracy_list))
Exemple #11
0
def main():
    """Train the feature encoder together with the RP and RM networks.

    Every episode samples one few-shot task from the meta-train folders,
    scores each (query, class) pair and minimises the MSE between those
    scores and one-hot labels.  ``RP`` produces ``Att``, which re-weights the
    query features before ``RM`` scores them against the support features.

    Every 5000 episodes the validation accuracy is fed to the three
    ``ReduceLROnPlateau`` schedulers.  Every 500 episodes the meta-test
    accuracy is measured, appended to ``acc.txt`` / ``episode.txt`` under
    ``save_path``, and the three networks are saved whenever the accuracy
    beats the best value seen so far.
    """
    query_per_class = 15  # query images per class in validation/test tasks
    eval_batch_per_class = 5  # query images per class in one evaluation batch

    save_path = make_dir(way=CLASS_NUM,
                         shot=SAMPLE_NUM_PER_CLASS,
                         batch_size=BATCH_NUM_PER_CLASS)
    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    # NOTE(review): the unpacking order here is (train, test, val); confirm it
    # matches what tg.mini_imagenet_folders() actually returns.
    metatrain_folders, metatest_folders, metaval_folders = (
        tg.mini_imagenet_folders())

    # Step 2: init neural networks
    print("init neural networks")

    feature_encoder = CNNEncoder()
    RP = RP_Network(GPU=GPU)
    RM = RM_Network(64, HIDDEN_UNIT, GPU=GPU)

    RP.apply(weights_init_kaiming)
    RM.apply(weights_init_kaiming)

    feature_encoder.cuda(GPU)
    RP.cuda(GPU)
    RM.cuda(GPU)

    # Implement Optimizer and Scheduler
    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=LEARNING_RATE)
    feature_encoder_scheduler = ReduceLROnPlateau(feature_encoder_optim,
                                                  mode="max",
                                                  factor=0.5,
                                                  patience=2,
                                                  verbose=True)

    RP_optim = torch.optim.Adam(RP.parameters(), lr=LEARNING_RATE)
    RP_scheduler = ReduceLROnPlateau(RP_optim,
                                     mode="max",
                                     factor=0.5,
                                     patience=2,
                                     verbose=True)

    RM_optim = torch.optim.Adam(RM.parameters(), lr=LEARNING_RATE)
    RM_optim_scheduler = ReduceLROnPlateau(RM_optim,
                                           mode='max',
                                           factor=0.5,
                                           patience=2,
                                           verbose=True)

    def encode_support(images):
        """Encode the support images and sum the features of each class."""
        features = feature_encoder(images.cuda(GPU))
        features = features.view(CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM,
                                 19, 19)
        return torch.sum(features, 1).squeeze(1)

    def score_queries(support_features, query_features):
        """Return relation scores reshaped to (-1, CLASS_NUM)."""
        num_queries = query_features.shape[0]
        # Pair every query with every class: both tensors become
        # (num_queries, CLASS_NUM, FEATURE_DIM, 19, 19).
        support_ext = support_features.unsqueeze(0).repeat(
            num_queries, 1, 1, 1, 1)
        query_ext = query_features.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1)
        query_ext = torch.transpose(query_ext, 0, 1)

        # RP stage: score the concatenated pair to obtain Att.
        relation_pairs = torch.cat((support_ext, query_ext),
                                   2).view(-1, FEATURE_DIM * 2, 19, 19)
        att = RP(relation_pairs)

        # RM stage: re-weight the query features with Att, then compare.
        support_flat = support_ext.contiguous().view(-1, FEATURE_DIM, 19, 19)
        query_flat = query_ext.contiguous().view(-1, FEATURE_DIM, 19, 19)
        query_att = query_flat + query_flat * att.expand_as(query_flat)
        return RM(support_flat, query_att).view(-1, CLASS_NUM)

    def evaluate(folders, num_episodes):
        """Run `num_episodes` tasks from `folders`; return (mean accuracy, h)."""
        accuracies = []
        # No gradients are needed for evaluation; skipping graph construction
        # saves memory without changing the predictions.
        with torch.no_grad():
            for _ in range(num_episodes):
                total_rewards = 0
                task = tg.MiniImagenetTask(folders, CLASS_NUM,
                                           SAMPLE_NUM_PER_CLASS,
                                           query_per_class)
                sample_dataloader = tg.get_mini_imagenet_data_loader(
                    task,
                    num_per_class=SAMPLE_NUM_PER_CLASS,
                    split="train",
                    shuffle=False)
                query_dataloader = tg.get_mini_imagenet_data_loader(
                    task,
                    num_per_class=eval_batch_per_class,
                    split="test",
                    shuffle=False)

                sample_images, _ = next(iter(sample_dataloader))
                for query_images, query_labels in query_dataloader:
                    query_labels = query_labels.cuda(GPU)
                    # Recomputed per batch on purpose: the networks stay in
                    # training mode here, exactly as before this refactor.
                    support_features = encode_support(sample_images)
                    query_features = feature_encoder(query_images.cuda(GPU))
                    relations = score_queries(support_features,
                                              query_features)

                    _, predict_labels = torch.max(relations, 1)
                    total_rewards += (
                        predict_labels == query_labels.view(-1)).sum().item()

                accuracies.append(total_rewards / 1.0 / CLASS_NUM /
                                  query_per_class)
        return mean_confidence_interval(accuracies)

    # somethings to save the data
    acc_list = []
    episode_list = []

    print("Training...")

    last_accuracy = 0.0
    mse = nn.MSELoss().cuda(GPU)  # stateless, so built once

    for episode in range(EPISODE):

        # init dataset
        # sample_dataloader is to obtain previous samples for compare
        # batch_dataloader is to batch samples for training
        task = tg.MiniImagenetTask(metatrain_folders, CLASS_NUM,
                                   SAMPLE_NUM_PER_CLASS, BATCH_NUM_PER_CLASS)
        sample_dataloader = tg.get_mini_imagenet_data_loader(
            task,
            num_per_class=SAMPLE_NUM_PER_CLASS,
            split="train",
            shuffle=False)
        batch_dataloader = tg.get_mini_imagenet_data_loader(
            task,
            num_per_class=BATCH_NUM_PER_CLASS,
            split="test",
            shuffle=True,
            train_query_argue=True)

        # sample datas (next(iter(...)) replaces the Python 2 `.next()` call)
        samples, _ = next(iter(sample_dataloader))
        batches, batch_labels = next(iter(batch_dataloader))

        # calculate features and relation scores
        sample_features = encode_support(samples)
        batch_features = feature_encoder(batches.cuda(GPU))
        relations = score_queries(sample_features, batch_features)

        # BP and Optimize
        one_hot_labels = torch.zeros(BATCH_NUM_PER_CLASS * CLASS_NUM,
                                     CLASS_NUM).scatter_(
                                         1, batch_labels.view(-1, 1),
                                         1).cuda(GPU)
        loss = mse(relations, one_hot_labels)

        # training
        feature_encoder.zero_grad()
        RP.zero_grad()
        RM.zero_grad()

        loss.backward()

        # In-place variant; the un-suffixed clip_grad_norm is deprecated.
        torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(RP.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(RM.parameters(), 0.5)

        feature_encoder_optim.step()
        RP_optim.step()
        RM_optim.step()

        if (episode + 1) % 100 == 0:
            print("episode:", episode + 1, "loss", loss.item())

        # Validation
        if episode % 5000 == 0:
            print("validation...")
            val_accuracy, _ = evaluate(metaval_folders, VAL_EPISODE)
            feature_encoder_scheduler.step(val_accuracy)
            RP_scheduler.step(val_accuracy)
            RM_optim_scheduler.step(val_accuracy)
            print("Acc_Val", val_accuracy, 'Episode', episode)

        if episode % 500 == 0:

            # test
            print("Testing...")
            test_accuracy, h = evaluate(metatest_folders, TEST_EPISODE)

            print("test accuracy:", test_accuracy, "h:", h)

            acc_list.append(test_accuracy)
            episode_list.append(episode)

            if test_accuracy > last_accuracy:

                # save networks
                torch.save(feature_encoder.state_dict(),
                           save_path + "encoder.pkl")
                torch.save(RP.state_dict(), save_path + "RP.pkl")
                torch.save(RM.state_dict(), save_path + "RM.pkl")
                print("save networks for episode:", episode)

                last_accuracy = test_accuracy

            # save data: the lists only change in this branch, so writing here
            # keeps the files identical while avoiding a rewrite every episode
            np.savetxt(save_path + "acc.txt", acc_list)
            np.savetxt(save_path + "episode.txt", episode_list)
def main():
    """Train the feature encoder and relation network on mini-ImageNet tasks.

    Existing checkpoints under ``./models`` are loaded when present.  Every
    episode samples one few-shot task, scores each (query, class) pair with
    the relation network and minimises the MSE against one-hot labels.  The
    training loss is written to TensorBoard every episode; every 5000
    episodes the meta-test accuracy is measured and both networks are saved
    whenever it beats the best value seen so far.
    """
    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    metatrain_folders, metatest_folders = tg.mini_imagenet_folders()

    # Step 2: init neural networks
    print("init neural networks")

    feature_encoder = CNNEncoder()
    relation_network = RelationNetwork(FEATURE_DIM, RELATION_DIM)

    feature_encoder.apply(weights_init)
    relation_network.apply(weights_init)

    feature_encoder.cuda()
    relation_network.cuda()

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,
                                       step_size=100000,
                                       gamma=0.5)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                              lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,
                                        step_size=100000,
                                        gamma=0.5)

    # Checkpoint and log locations, built once instead of at every use.
    run_tag = str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot"
    model_dir = "./models"
    encoder_path = "./models/miniimagenet_feature_encoder_" + run_tag + ".pkl"
    relation_path = ("./models/miniimagenet_relation_network_" + run_tag +
                     ".pkl")

    if os.path.exists(encoder_path):
        feature_encoder.load_state_dict(torch.load(encoder_path))
        print("load feature encoder success")
    if os.path.exists(relation_path):
        relation_network.load_state_dict(torch.load(relation_path))
        print("load relation network success")

    log_dir = os.path.join('logs', run_tag)
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir)

    # Step 3: build graph
    print("Training...")

    last_accuracy = 0.0
    query_per_class = 15  # query images per class in each test task
    mse = nn.MSELoss().cuda()  # stateless, so built once

    try:
        for episode in range(EPISODE):

            # init dataset
            # sample_dataloader is to obtain previous samples for compare
            # batch_dataloader is to batch samples for training
            task = tg.MiniImagenetTask(metatrain_folders, CLASS_NUM,
                                       SAMPLE_NUM_PER_CLASS,
                                       BATCH_NUM_PER_CLASS)
            sample_dataloader = tg.get_mini_imagenet_data_loader(
                task,
                num_per_class=SAMPLE_NUM_PER_CLASS,
                split="train",
                shuffle=False)
            batch_dataloader = tg.get_mini_imagenet_data_loader(
                task,
                num_per_class=BATCH_NUM_PER_CLASS,
                split="test",
                shuffle=True)

            # sample datas (next(iter(...)) replaces the Python 2 `.next()`)
            samples, _ = next(iter(sample_dataloader))
            batches, batch_labels = next(iter(batch_dataloader))

            # calculate features; support features are summed per class
            sample_features = feature_encoder(samples.cuda())
            sample_features = sample_features.view(CLASS_NUM,
                                                   SAMPLE_NUM_PER_CLASS,
                                                   FEATURE_DIM, 19, 19)
            sample_features = torch.sum(sample_features, 1).squeeze(1)
            batch_features = feature_encoder(batches.cuda())

            # calculate relations
            # each batch sample is paired with every class feature
            sample_features_ext = sample_features.unsqueeze(0).repeat(
                BATCH_NUM_PER_CLASS * CLASS_NUM, 1, 1, 1, 1)
            batch_features_ext = batch_features.unsqueeze(0).repeat(
                CLASS_NUM, 1, 1, 1, 1)
            batch_features_ext = torch.transpose(batch_features_ext, 0, 1)
            relation_pairs = torch.cat(
                (sample_features_ext, batch_features_ext),
                2).view(-1, FEATURE_DIM * 2, 19, 19)
            relations = relation_network(relation_pairs).view(-1, CLASS_NUM)

            one_hot_labels = torch.zeros(BATCH_NUM_PER_CLASS * CLASS_NUM,
                                         CLASS_NUM).scatter_(
                                             1, batch_labels.view(-1, 1),
                                             1).cuda()
            loss = mse(relations, one_hot_labels)

            # training
            feature_encoder.zero_grad()
            relation_network.zero_grad()

            loss.backward()

            torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
            torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)

            feature_encoder_optim.step()
            relation_network_optim.step()

            # Schedulers are stepped after the optimizers and without the
            # deprecated epoch argument.  The learning rate used in each
            # episode is the same as with `step(episode)` at the loop start.
            feature_encoder_scheduler.step()
            relation_network_scheduler.step()

            writer.add_scalar('training loss', loss.item(), episode + 1)
            if (episode + 1) % 100 == 0:
                print("episode:", episode + 1, "loss", loss.item())

            if episode % 5000 == 0:

                # test
                print("Testing...")
                accuracies = []
                # Gradients are not needed while testing; skipping graph
                # construction saves memory without changing predictions.
                # NOTE(review): the networks are left in training mode here,
                # as before; confirm that is intended if they contain
                # BatchNorm or Dropout layers.
                with torch.no_grad():
                    for _ in range(TEST_EPISODE):
                        total_rewards = 0
                        task = tg.MiniImagenetTask(metatest_folders,
                                                   CLASS_NUM,
                                                   SAMPLE_NUM_PER_CLASS,
                                                   query_per_class)
                        sample_dataloader = tg.get_mini_imagenet_data_loader(
                            task,
                            num_per_class=SAMPLE_NUM_PER_CLASS,
                            split="train",
                            shuffle=False)
                        test_dataloader = tg.get_mini_imagenet_data_loader(
                            task,
                            num_per_class=5,
                            split="test",
                            shuffle=False)

                        sample_images, _ = next(iter(sample_dataloader))
                        for test_images, test_labels in test_dataloader:
                            batch_size = test_labels.shape[0]
                            # calculate features
                            sample_features = feature_encoder(
                                sample_images.cuda())
                            sample_features = sample_features.view(
                                CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM,
                                19, 19)
                            sample_features = torch.sum(sample_features,
                                                        1).squeeze(1)
                            test_features = feature_encoder(
                                test_images.cuda())

                            # pair every test image with every class feature
                            sample_features_ext = sample_features.unsqueeze(
                                0).repeat(batch_size, 1, 1, 1, 1)
                            test_features_ext = test_features.unsqueeze(
                                0).repeat(CLASS_NUM, 1, 1, 1, 1)
                            test_features_ext = torch.transpose(
                                test_features_ext, 0, 1)
                            relation_pairs = torch.cat(
                                (sample_features_ext, test_features_ext),
                                2).view(-1, FEATURE_DIM * 2, 19, 19)
                            relations = relation_network(relation_pairs).view(
                                -1, CLASS_NUM)

                            _, predict_labels = torch.max(relations, 1)
                            # test_labels stay on the CPU, so compare there
                            predict_labels = predict_labels.cpu()
                            total_rewards += (
                                predict_labels == test_labels.view(-1)
                            ).sum().item()

                        accuracy = (total_rewards / 1.0 / CLASS_NUM /
                                    query_per_class)
                        accuracies.append(accuracy)

                test_accuracy, h = mean_confidence_interval(accuracies)

                writer.add_scalar('test accuracy', test_accuracy, episode + 1)
                print("test accuracy:", test_accuracy, "h:", h)

                if test_accuracy > last_accuracy:

                    # save networks; torch.save does not create directories
                    os.makedirs(model_dir, exist_ok=True)
                    torch.save(feature_encoder.state_dict(), encoder_path)
                    torch.save(relation_network.state_dict(), relation_path)

                    print("save networks for episode:", episode)

                    last_accuracy = test_accuracy
    finally:
        # Flush and close the TensorBoard writer even if training aborts.
        writer.close()
Exemple #13
0
def main():
    """Evaluate the saved encoder and relation network on meta-test tasks.

    Loads the checkpoints under ``./models`` when they exist, then runs 10
    rounds of 100 few-shot tasks each.  For every round the mean accuracy and
    the ``h`` value returned by ``mean_confidence_interval`` are printed; the
    average of the 10 round accuracies is printed at the end.
    """
    num_rounds = 10
    episodes_per_round = 100

    # * Step 1: init data folders
    print("init data folders")

    # * init character folders for dataset construction
    _, metatest_folders = tg.mini_imagenet_folders()

    # * Step 2: init neural networks
    print("init neural networks")

    feature_encoder = ot.CNNEncoder().to(device)
    relation_network = ot.RelationNetwork(FEATURE_DIM, RELATION_DIM).to(device)
    # NOTE: the random-forest classifier variant is not loaded or used here.

    feature_encoder.eval()
    relation_network.eval()

    run_tag = str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot"
    encoder_path = "./models/miniimagenet_feature_encoder_" + run_tag + ".pkl"
    relation_path = ("./models/miniimagenet_relation_network_" + run_tag +
                     ".pkl")

    # map_location lets checkpoints saved from GPU tensors load when `device`
    # is the CPU.
    if os.path.exists(encoder_path):
        feature_encoder.load_state_dict(
            torch.load(encoder_path, map_location=device))
        print("load feature encoder success")
    else:
        print("warning: " + encoder_path +
              " not found, evaluating an untrained feature encoder")
    if os.path.exists(relation_path):
        relation_network.load_state_dict(
            torch.load(relation_path, map_location=device))
        print("load relation network success")
    else:
        print("warning: " + relation_path +
              " not found, evaluating an untrained relation network")

    total_accuracy = 0.0

    # Pure inference: disabling autograd avoids building graphs for every
    # forward pass.
    with torch.no_grad():
        for _ in range(num_rounds):
            # * test
            print("Testing...")

            accuracies = []

            for _ in range(episodes_per_round):
                total_rewards = 0
                total_queries = 0
                task = tg.MiniImagenetTask(metatest_folders, CLASS_NUM,
                                           SAMPLE_NUM_PER_CLASS,
                                           BATCH_NUM_PER_CLASS)
                sample_dataloader = tg.get_mini_imagenet_data_loader(
                    task,
                    num_per_class=SAMPLE_NUM_PER_CLASS,
                    split="train",
                    shuffle=False)
                test_dataloader = tg.get_mini_imagenet_data_loader(
                    task, num_per_class=5, split="test", shuffle=False)

                sample_images, _ = next(iter(sample_dataloader))
                sample_images = sample_images.to(device)

                # * Calculate support features once per task: both networks
                # * are in eval mode, so the result is the same for every
                # * query batch.  Features are summed per class.
                sample_features = feature_encoder(sample_images)
                sample_features = sample_features.view(CLASS_NUM,
                                                       SAMPLE_NUM_PER_CLASS,
                                                       FEATURE_DIM, 19, 19)
                sample_features = torch.sum(sample_features, 1).squeeze(1)

                for test_images, test_labels in test_dataloader:
                    batch_size = test_labels.shape[0]
                    test_images = test_images.to(device)
                    test_labels = test_labels.to(device)
                    test_features = feature_encoder(test_images)

                    # Pair every query with every class feature.
                    sample_features_ext = sample_features.unsqueeze(0).repeat(
                        batch_size, 1, 1, 1, 1)
                    test_features_ext = test_features.unsqueeze(0).repeat(
                        CLASS_NUM, 1, 1, 1, 1)
                    test_features_ext = torch.transpose(
                        test_features_ext, 0, 1)

                    relation_pairs = torch.cat(
                        (sample_features_ext, test_features_ext),
                        2).view(-1, FEATURE_DIM * 2, 19, 19)
                    relations = relation_network(relation_pairs).view(
                        -1, CLASS_NUM)

                    _, predict_labels = torch.max(relations, 1)
                    total_rewards += (
                        predict_labels == test_labels.view(-1)).sum().item()
                    total_queries += batch_size

                # Divide by the number of queries actually scored.  The old
                # fixed denominator of CLASS_NUM * 15 was only right when
                # BATCH_NUM_PER_CLASS happened to be 15.
                accuracy = (total_rewards / total_queries
                            if total_queries else 0.0)
                accuracies.append(accuracy)

            test_accuracy, h = mean_confidence_interval(accuracies)
            print(f'test accuracy : {test_accuracy:.4f}, h : {h:.4f}')
            total_accuracy += test_accuracy

    print(f"average accuracy : {total_accuracy / num_rounds:.4f}")
def main():
    """Train the feature encoder and explain network on mini-ImageNet episodes.

    Each episode samples one task from the meta-train folders, encodes the
    support and query images, feeds the paired features through the middle
    layer and the explain network, and minimises the MSE between the predicted
    scores and the one-hot query labels.

    * Every 5000 episodes the validation accuracy is computed and passed to
      both ``ReduceLROnPlateau`` schedulers.
    * Every 500 episodes (episode 0 included) the meta-test accuracy is
      computed, the metric histories are written as text files under
      ``path``, and both networks are checkpointed when the test accuracy
      beats the best value seen so far.
    """
    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    (metatrain_folders, metaval_folders,
     metatest_folders) = tg.mini_imagenet_folders()

    # Step 2: init neural networks and results dir
    print("init neural networks")

    path = makedir(name)  # the path to save the results

    feature_encoder = FeatureEncoder()
    # NOTE(review): middle_layer is never moved to the GPU nor handed to an
    # optimizer; that is only correct if it holds no parameters or buffers -
    # confirm against Middle_Moudle_v3.
    middle_layer = Middle_Moudle_v3(width=WIDTH)
    explain_network = ExplainModule_v6_meta_pair(
        n_way=CLASS_NUM,
        k_shot=SAMPLE_NUM_PER_CLASS,
        width=WIDTH,
        batch_size=BATCH_NUM_PER_CLASS)

    feature_encoder.apply(weights_init_FE)

    feature_encoder.cuda(GPU)
    explain_network.cuda(GPU)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=LEARNING_RATE)
    feature_encoder_scheduler = ReduceLROnPlateau(feature_encoder_optim,
                                                  mode="max",
                                                  factor=0.5,
                                                  patience=2,
                                                  verbose=True,
                                                  threshold=1e-2)
    explain_network_optim = torch.optim.Adam(explain_network.parameters(),
                                             lr=LEARNING_RATE)
    explain_network_scheduler = ReduceLROnPlateau(explain_network_optim,
                                                  mode="max",
                                                  factor=0.5,
                                                  patience=2,
                                                  verbose=True,
                                                  threshold=1e-2)

    # The loss module is stateless, so build it once instead of per episode.
    mse = nn.MSELoss().cuda(GPU)

    def evaluate(folders):
        """Return the per-episode accuracies of TEST_EPISODE tasks from *folders*."""
        num_per_class = 5  # query images per class in each evaluation batch
        explain_network.batch_size = num_per_class
        episode_accuracies = []
        with torch.no_grad():
            for _ in range(TEST_EPISODE):
                total_rewards = 0
                task = tg.MiniImagenetTask(folders, CLASS_NUM,
                                           SAMPLE_NUM_PER_CLASS, 15)
                sample_dataloader = tg.get_mini_imagenet_data_loader(
                    task,
                    num_per_class=SAMPLE_NUM_PER_CLASS,
                    split="train",
                    shuffle=False)
                query_dataloader = tg.get_mini_imagenet_data_loader(
                    task,
                    num_per_class=num_per_class,
                    split="test",
                    shuffle=False)
                sample_images, _ = next(iter(sample_dataloader))
                sample_images = sample_images.cuda(GPU)
                for query_images, query_labels in query_dataloader:
                    query_images = query_images.cuda(GPU)
                    query_labels = query_labels.cuda(GPU)
                    # The support set is re-encoded for every query batch, as
                    # before, so the encoder sees the same number of forward
                    # passes. NOTE(review): hoist if the encoder is stateless.
                    sample_features = feature_encoder(
                        sample_images)  # (n*s)*64*19*19
                    sample_features_ext = sample_features.unsqueeze(0).repeat(
                        num_per_class * CLASS_NUM, 1, 1, 1,
                        1)  # [n*b,n*s,64,19,19]
                    query_features = feature_encoder(
                        query_images)  # (n*b)*64*19*19
                    query_features_ext = query_features.unsqueeze(1).repeat(
                        1, CLASS_NUM * SAMPLE_NUM_PER_CLASS, 1, 1,
                        1)  # [n*b,n*s,64,19,19]
                    distance_feature = middle_layer.forward(
                        support_x=sample_features_ext,
                        query_x=query_features_ext)  # [n*b,n*s,19*19]
                    pre = explain_network(
                        distance_feature,
                        torch.cat([sample_features_ext, query_features_ext],
                                  dim=2))  # [n*b,n]

                    _, predict_labels = torch.max(pre.data, 1)
                    # Count the correct predictions in one tensor operation.
                    total_rewards += predict_labels.eq(
                        query_labels.view(-1)).sum().item()

                # Each task holds 15 query images per class.
                episode_accuracies.append(
                    total_rewards / 1.0 / CLASS_NUM / 15)
        return episode_accuracies

    last_accuracy = 0.0  # best test accuracy seen so far
    # Create some lists to save data
    loss_list = []
    acc_list = []
    episode_list = []
    h_list = []
    for episode in range(EPISODE):
        # Evaluation overrides this attribute, so restore it every episode.
        explain_network.batch_size = BATCH_NUM_PER_CLASS

        # init dataset
        # sample_dataloader is to obtain previous samples for compare
        # batch_dataloader is to batch samples for training
        task = tg.MiniImagenetTask(metatrain_folders, CLASS_NUM,
                                   SAMPLE_NUM_PER_CLASS, BATCH_NUM_PER_CLASS)
        sample_dataloader = tg.get_mini_imagenet_data_loader(
            task,
            num_per_class=SAMPLE_NUM_PER_CLASS,
            split="train",
            shuffle=False)
        batch_dataloader = tg.get_mini_imagenet_data_loader(
            task,
            num_per_class=BATCH_NUM_PER_CLASS,
            split="test",
            shuffle=True,
            train_query_argue=True)

        # sample datas; next(iter(...)) replaces the legacy `.next()` alias
        samples, _ = next(iter(sample_dataloader))  # (n*s)*3*84*84
        batches, batch_labels = next(iter(batch_dataloader))

        # calculate features
        sample_features = feature_encoder(samples.cuda(GPU))  # (n*s)*64*19*19
        sample_features_ext = sample_features.unsqueeze(0).repeat(
            BATCH_NUM_PER_CLASS * CLASS_NUM, 1, 1, 1, 1)  # [n*b,n*s,64,19,19]

        batch_features = feature_encoder(batches.cuda(GPU))  # (n*b)*64*19*19
        batch_features_ext = batch_features.unsqueeze(1).repeat(
            1, CLASS_NUM * SAMPLE_NUM_PER_CLASS, 1, 1, 1)  # [n*b,n*s,64,19,19]
        distance_feature = middle_layer.forward(
            support_x=sample_features_ext,
            query_x=batch_features_ext)  # [n*b,n*s,19*19]
        pre = explain_network(distance_feature,
                              torch.cat(
                                  [sample_features_ext, batch_features_ext],
                                  dim=2))  # [n*b,n]

        one_hot_labels = torch.zeros(BATCH_NUM_PER_CLASS * CLASS_NUM,
                                     CLASS_NUM).scatter_(
                                         1, batch_labels.view(-1, 1),
                                         1).cuda(GPU)
        loss = mse(pre, one_hot_labels)
        feature_encoder_optim.zero_grad()
        explain_network_optim.zero_grad()

        loss.backward()

        # clip_grad_norm_ is the in-place successor of the deprecated
        # clip_grad_norm alias.
        torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(explain_network.parameters(), 0.5)
        feature_encoder_optim.step()
        explain_network_optim.step()

        if (episode + 1) % 50 == 0:
            print("episode:", episode + 1, "loss", loss.item())

        if (episode + 1) % 5000 == 0:
            print("val")
            val_accuracies = evaluate(metaval_folders)
            accuracy_val, h = mean_confidence_interval(val_accuracies)
            # Let the plateau schedulers react to the validation accuracy.
            feature_encoder_scheduler.step(accuracy_val)
            explain_network_scheduler.step(accuracy_val)

        if episode % 500 == 0:
            # test
            print("Testing...")
            accuracies = evaluate(metatest_folders)

            test_accuracy, h = mean_confidence_interval(accuracies)
            loss_list, acc_list, h_list = save_data(
                loss_list, acc_list, h_list, loss, test_accuracy, h)
            episode_list.append(episode)
            print("test accuracy:", test_accuracy, "h:", h)
            if test_accuracy > last_accuracy:
                torch.save(feature_encoder.state_dict(),
                           path + "feature_encoder.pkl")
                torch.save(explain_network.state_dict(),
                           path + "explain_network.pkl")
                # Remember the best score so that later, worse models do not
                # overwrite the best checkpoint.
                last_accuracy = test_accuracy

            np.savetxt(path + "acc.txt", acc_list)
            np.savetxt(path + "eposide.txt", episode_list)
            np.savetxt(path + "h.txt", h_list)
            np.savetxt(path + "loss.txt", loss_list)
Exemple #15
0
def main():
    """Train a relation network on mini-ImageNet and keep the best checkpoint.

    Each episode samples one task from the meta-train folders, sums the
    support features per class, pairs every query feature with every class
    feature, and minimises the MSE between the relation scores and the
    one-hot query labels. Every 1000 episodes the accuracy over TEST_EPISODE
    meta-test tasks is printed, and both networks are saved under
    ``./models`` when that accuracy improves on the best value so far.
    Existing checkpoints at those same paths are loaded before training.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # * Step 1: init data folders
    print("init data folders")

    # * Init character folders for dataset construction
    metatrain_character_folders, metatest_character_folders = (
        tg.mini_imagenet_folders())

    # * Step 2: init neural networks
    print("init neural networks")

    feature_encoder = CNNEncoder()
    relation_network = RelationNetwork(FEATURE_DIM, RELATION_DIM)

    feature_encoder.train()
    relation_network.train()

    feature_encoder.apply(weights_init)
    relation_network.apply(weights_init)

    feature_encoder.to(device)
    relation_network.to(device)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,
                                       step_size=100000,
                                       gamma=0.5)

    relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                              lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,
                                        step_size=100000,
                                        gamma=0.5)

    # The same paths are used for the existence check, loading and saving.
    # The check previously looked for "omniglot_*" files while loading and
    # saving "miniimagenet_*" files, so resuming could never work correctly.
    checkpoint_suffix = (
        f"{CLASS_NUM}way_{SAMPLE_NUM_PER_CLASS}shot_args.pkl")
    feature_encoder_path = (
        "./models/miniimagenet_feature_encoder_" + checkpoint_suffix)
    relation_network_path = (
        "./models/miniimagenet_relation_network_" + checkpoint_suffix)

    # map_location lets a checkpoint written on a GPU load on a CPU-only host.
    if os.path.exists(feature_encoder_path):
        feature_encoder.load_state_dict(
            torch.load(feature_encoder_path, map_location=device))
        print("load feature encoder success")
    if os.path.exists(relation_network_path):
        relation_network.load_state_dict(
            torch.load(relation_network_path, map_location=device))
        print("load relation network success")

    # * Step 3: build graph
    print("Training...")

    last_accuracy = 0.0  # best test accuracy seen so far
    mse = nn.MSELoss()  # stateless, so one instance serves every episode

    for episode in range(EPISODE):

        # * init dataset
        # * sample_dataloader is to obtain previous samples for compare
        # * batch_dataloader is to batch samples for training
        task = tg.MiniImagenetTask(metatrain_character_folders, CLASS_NUM,
                                   SAMPLE_NUM_PER_CLASS, BATCH_NUM_PER_CLASS)
        sample_dataloader = tg.get_mini_imagenet_data_loader(
            task,
            num_per_class=SAMPLE_NUM_PER_CLASS,
            split="train",
            shuffle=False)
        batch_dataloader = tg.get_mini_imagenet_data_loader(
            task,
            num_per_class=BATCH_NUM_PER_CLASS,
            split="test",
            shuffle=True)

        # * sample datas
        samples, _ = next(iter(sample_dataloader))
        batches, batch_labels = next(iter(batch_dataloader))

        samples = samples.to(device)
        batches, batch_labels = batches.to(device), batch_labels.to(device)

        # * calculates features
        sample_features = feature_encoder(samples)
        sample_features = sample_features.view(CLASS_NUM, SAMPLE_NUM_PER_CLASS,
                                               FEATURE_DIM, 19, 19)
        # Sum the shots of each class into a single class feature map.
        sample_features = torch.sum(sample_features, 1).squeeze(1)
        batch_features = feature_encoder(batches)

        # * calculate relations
        # * each batch sample link to every samples to calculate relations
        sample_features_ext = sample_features.unsqueeze(0).repeat(
            BATCH_NUM_PER_CLASS * CLASS_NUM, 1, 1, 1, 1)
        batch_features_ext = batch_features.unsqueeze(0).repeat(
            CLASS_NUM, 1, 1, 1, 1)
        batch_features_ext = torch.transpose(batch_features_ext, 0, 1)

        relation_pairs = torch.cat((sample_features_ext, batch_features_ext),
                                   2).view(-1, FEATURE_DIM * 2, 19, 19)
        relations = relation_network(relation_pairs).view(-1, CLASS_NUM)

        one_hot_labels = torch.zeros(BATCH_NUM_PER_CLASS * CLASS_NUM,
                                     CLASS_NUM,
                                     device=device).scatter_(
                                         1, batch_labels.view(-1, 1), 1)
        loss = mse(relations, one_hot_labels)

        feature_encoder.zero_grad()
        relation_network.zero_grad()

        loss.backward()

        torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)

        feature_encoder_optim.step()
        relation_network_optim.step()

        # Step once per episode after the optimizer; passing the epoch index
        # to step() is deprecated.
        feature_encoder_scheduler.step()
        relation_network_scheduler.step()

        if (episode + 1) % 100 == 0:
            print(
                f"episode : {episode+1}, loss : {loss.cpu().detach().numpy()}")

        if (episode + 1) % 1000 == 0:
            print("Testing...")
            total_reward = 0

            # No gradients are needed while scoring the test tasks.
            with torch.no_grad():
                for _ in range(TEST_EPISODE):
                    task = tg.MiniImagenetTask(metatest_character_folders,
                                               CLASS_NUM, SAMPLE_NUM_PER_CLASS,
                                               SAMPLE_NUM_PER_CLASS)
                    sample_dataloader = tg.get_mini_imagenet_data_loader(
                        task,
                        num_per_class=SAMPLE_NUM_PER_CLASS,
                        split="train",
                        shuffle=False)
                    test_dataloader = tg.get_mini_imagenet_data_loader(
                        task,
                        num_per_class=SAMPLE_NUM_PER_CLASS,
                        split="test",
                        shuffle=True)

                    sample_images, _ = next(iter(sample_dataloader))
                    test_images, test_labels = next(iter(test_dataloader))

                    sample_images = sample_images.to(device)
                    test_images = test_images.to(device)
                    test_labels = test_labels.to(device)

                    # * calculate features
                    sample_features = feature_encoder(sample_images)
                    sample_features = sample_features.view(
                        CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM, 19, 19)
                    sample_features = torch.sum(sample_features, 1).squeeze(1)
                    test_features = feature_encoder(test_images)

                    # * calculate relations
                    # * each test sample is paired with every class feature
                    sample_features_ext = sample_features.unsqueeze(0).repeat(
                        SAMPLE_NUM_PER_CLASS * CLASS_NUM, 1, 1, 1, 1)
                    test_features_ext = test_features.unsqueeze(0).repeat(
                        CLASS_NUM, 1, 1, 1, 1)
                    test_features_ext = torch.transpose(
                        test_features_ext, 0, 1)

                    relation_pairs = torch.cat(
                        (sample_features_ext, test_features_ext),
                        2).view(-1, FEATURE_DIM * 2, 19, 19)
                    relations = relation_network(relation_pairs).view(
                        -1, CLASS_NUM)

                    _, predict_labels = torch.max(relations.data, 1)

                    # Count the correct predictions in one tensor operation.
                    total_reward += predict_labels.eq(
                        test_labels.view(-1)).sum().item()

            test_accuracy = total_reward / (
                1.0 * CLASS_NUM * SAMPLE_NUM_PER_CLASS * TEST_EPISODE)

            print("test accuracy : ", test_accuracy)

            if test_accuracy > last_accuracy:
                # save networks
                os.makedirs("./models", exist_ok=True)
                torch.save(feature_encoder.state_dict(), feature_encoder_path)
                torch.save(relation_network.state_dict(),
                           relation_network_path)

                print("save networks for episode:", episode)

                last_accuracy = test_accuracy
def main():
    """Train a DataParallel relation network on mini-ImageNet and log progress.

    Each episode samples one task from the meta-train folders, pairs every
    query feature with every support feature, and minimises the MSE between
    the relation scores and the one-hot query labels. The loss is appended to
    a dated text file under ``models/`` every 100 episodes; every 5000
    episodes (episode 0 included) the accuracy over TEST_EPISODE one-shot
    meta-test tasks is printed and logged as well.

    NOTE(review): the relation scores are reshaped to
    ``CLASS_NUM * SAMPLE_NUM_PER_CLASS`` columns while the targets have
    ``CLASS_NUM`` columns, and evaluation hard-codes one shot, so this
    routine appears to assume SAMPLE_NUM_PER_CLASS == 1 - confirm.
    """
    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    metatrain_folders, metatest_folders = tg.mini_imagenet_folders()

    # Step 2: init neural networks
    print("init neural networks")

    feature_encoder = CNNEncoder()
    relation_network = RelationNetwork(FEATURE_DIM, RELATION_DIM)
    feature_encoder = nn.DataParallel(feature_encoder)
    relation_network = nn.DataParallel(relation_network)
    feature_encoder.apply(weights_init)
    relation_network.apply(weights_init)

    feature_encoder.cuda(GPU)
    relation_network.cuda(GPU)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,
                                       step_size=100000,
                                       gamma=0.5)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                              lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,
                                        step_size=100000,
                                        gamma=0.5)

    checkpoint_suffix = (str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) +
                         "shot.pkl")
    feature_encoder_path = (
        "./models/miniimagenet_feature_encoder2" + checkpoint_suffix)
    relation_network_path = (
        "./models/miniimagenet_relation_network2" + checkpoint_suffix)

    if os.path.exists(feature_encoder_path):
        feature_encoder.load_state_dict(torch.load(feature_encoder_path))
        print("load feature encoder success")
    if os.path.exists(relation_network_path):
        relation_network.load_state_dict(torch.load(relation_network_path))
        print("load relation network success")

    # Step 3: build graph
    print("Training...")

    last_accuracy = 0.0  # best test accuracy seen so far
    mse = nn.MSELoss().cuda(GPU)  # stateless, so build it once

    # Read the clock once so the three date fields cannot straddle midnight.
    today = datetime.datetime.now()
    filename = "miniimagenet_train_oneshot_" + str(today.year) + '_' + str(
        today.month) + '_' + str(today.day) + ".txt"
    # The log lives in models/; create the directory if it is missing.
    os.makedirs("models", exist_ok=True)
    with open("models/" + filename, "w") as f:

        for episode in range(EPISODE):

            # init dataset
            # sample_dataloader is to obtain previous samples for compare
            # batch_dataloader is to batch samples for training
            task = tg.MiniImagenetTask(metatrain_folders, CLASS_NUM,
                                       SAMPLE_NUM_PER_CLASS,
                                       BATCH_NUM_PER_CLASS)

            sample_dataloader = tg.get_mini_imagenet_data_loader(
                task,
                num_per_class=SAMPLE_NUM_PER_CLASS,
                split="train",
                shuffle=False)
            batch_dataloader = tg.get_mini_imagenet_data_loader(
                task,
                num_per_class=BATCH_NUM_PER_CLASS,
                split="test",
                shuffle=True)

            # sample datas; next(iter(...)) replaces the legacy `.next()` alias
            samples, _ = next(iter(sample_dataloader))
            batches, batch_labels = next(iter(batch_dataloader))

            # calculate features
            sample_features = feature_encoder(samples.cuda(GPU))
            batch_features = feature_encoder(batches.cuda(GPU))

            # calculate relations
            # each batch sample link to every samples to calculate relations
            sample_features_ext = sample_features.unsqueeze(0).repeat(
                BATCH_NUM_PER_CLASS * CLASS_NUM, 1, 1, 1, 1)
            batch_features_ext = batch_features.unsqueeze(0).repeat(
                SAMPLE_NUM_PER_CLASS * CLASS_NUM, 1, 1, 1, 1)
            batch_features_ext = torch.transpose(batch_features_ext, 0, 1)
            relation_pairs = torch.cat(
                (sample_features_ext, batch_features_ext),
                2).view(-1, FEATURE_DIM * 2, 19, 19)
            relations = relation_network(relation_pairs).view(
                -1, CLASS_NUM * SAMPLE_NUM_PER_CLASS)

            one_hot_labels = torch.zeros(BATCH_NUM_PER_CLASS * CLASS_NUM,
                                         CLASS_NUM).scatter_(
                                             1, batch_labels.view(-1, 1),
                                             1).cuda(GPU)
            loss = mse(relations, one_hot_labels)

            # training

            feature_encoder.zero_grad()
            relation_network.zero_grad()

            loss.backward()

            # clip_grad_norm_ is the in-place successor of the deprecated
            # clip_grad_norm alias.
            torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
            torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)

            feature_encoder_optim.step()
            relation_network_optim.step()

            # Step the schedulers after the optimizers and without the
            # deprecated epoch argument; the resulting schedule is the same
            # as the old step(episode) call made at the top of the loop.
            feature_encoder_scheduler.step()
            relation_network_scheduler.step()

            if (episode + 1) % 100 == 0:
                loss_value = loss.item()
                print("episode:", episode + 1, "loss", loss_value)
                newcontext = "episode:    " + str(
                    episode + 1) + "  loss    " + str(loss_value) + '\n'
                f.write(newcontext)

            if episode % 5000 == 0:

                # test
                print("Testing...")
                accuracies = []
                # No gradients are needed while scoring the test tasks.
                with torch.no_grad():
                    for _ in range(TEST_EPISODE):
                        total_rewards = 0
                        counter = 0
                        task = tg.MiniImagenetTask(metatest_folders,
                                                   CLASS_NUM, 1, 15)
                        sample_dataloader = tg.get_mini_imagenet_data_loader(
                            task,
                            num_per_class=1,
                            split="train",
                            shuffle=False)

                        num_per_class = 3
                        test_dataloader = tg.get_mini_imagenet_data_loader(
                            task,
                            num_per_class=num_per_class,
                            split="test",
                            shuffle=True)
                        sample_images, _ = next(iter(sample_dataloader))
                        for test_images, test_labels in test_dataloader:
                            batch_size = test_labels.shape[0]
                            # calculate features
                            sample_features = feature_encoder(
                                sample_images.cuda(GPU))
                            test_features = feature_encoder(
                                test_images.cuda(GPU))

                            # calculate relations
                            # each test sample is paired with every support
                            # sample
                            sample_features_ext = sample_features.unsqueeze(
                                0).repeat(batch_size, 1, 1, 1, 1)
                            test_features_ext = test_features.unsqueeze(
                                0).repeat(1 * CLASS_NUM, 1, 1, 1, 1)
                            test_features_ext = torch.transpose(
                                test_features_ext, 0, 1)
                            relation_pairs = torch.cat(
                                (sample_features_ext, test_features_ext),
                                2).view(-1, FEATURE_DIM * 2, 19, 19)
                            relations = relation_network(relation_pairs).view(
                                -1, CLASS_NUM)

                            _, predict_labels = torch.max(relations.data, 1)

                            # Compare on the predictions' device and count the
                            # matches in one tensor operation.
                            total_rewards += predict_labels.eq(
                                test_labels.view(-1).to(
                                    predict_labels.device)).sum().item()
                            counter += batch_size
                        accuracy = total_rewards / 1.0 / counter
                        accuracies.append(accuracy)

                test_accuracy, h = mean_confidence_interval(accuracies)

                print("test accuracy:", test_accuracy, "h:", h)
                newcontext = "episode:    " + str(
                    episode +
                    1) + "test accuracy:    " + str(test_accuracy) + '\n'
                f.write(newcontext)

                if test_accuracy > last_accuracy:

                    # NOTE(review): checkpoint writing was disabled in the
                    # original and stays disabled here, so nothing is
                    # persisted despite the message below. To re-enable it,
                    # torch.save both state_dicts to feature_encoder_path and
                    # relation_network_path.

                    print("save networks for episode:", episode)

                    last_accuracy = test_accuracy
def main():
    """Train a feature encoder and relation network on mini-ImageNet episodes.

    Each training episode samples a task from the meta-train folders, scores
    every (support, query) feature pair with the relation network and
    regresses the scores onto one-hot labels with an MSE loss.  Every 5000
    episodes the model is evaluated on ``TEST_EPISODE`` tasks drawn from the
    meta-test folders; both networks are checkpointed under ``./models``
    whenever the mean test accuracy improves.

    Relies on module-level configuration (``CLASS_NUM``,
    ``SAMPLE_NUM_PER_CLASS``, ``BATCH_NUM_PER_CLASS``, ``FEATURE_DIM``,
    ``RELATION_DIM``, ``LEARNING_RATE``, ``EPISODE``, ``TEST_EPISODE``,
    ``GPU``) and on the ``tg`` task-generator module.
    """
    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    metatrain_folders, metatest_folders = tg.mini_imagenet_folders()

    # Step 2: init neural networks
    print("init neural networks")

    feature_encoder = CNNEncoder()
    relation_network = RelationNetwork(FEATURE_DIM, RELATION_DIM)
    # Module.apply() runs weights_init on every sub-module.
    feature_encoder.apply(weights_init)
    relation_network.apply(weights_init)

    feature_encoder.cuda(GPU)
    relation_network.cuda(GPU)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,
                                       step_size=100000,
                                       gamma=0.5)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                              lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,
                                        step_size=100000,
                                        gamma=0.5)

    # Checkpoint locations are built once and shared by load and save.
    model_dir = "./models"
    checkpoint_suffix = (str(CLASS_NUM) + "way_" +
                         str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")
    feature_encoder_path = (model_dir + "/miniimagenet_feature_encoder_" +
                            checkpoint_suffix)
    relation_network_path = (model_dir + "/miniimagenet_relation_network_" +
                             checkpoint_suffix)

    if os.path.exists(feature_encoder_path):
        feature_encoder.load_state_dict(torch.load(feature_encoder_path))
        print("load feature encoder success")
    if os.path.exists(relation_network_path):
        relation_network.load_state_dict(torch.load(relation_network_path))
        print("load relation network success")

    # Step 3: build graph
    print("Training...")

    # The loss module is stateless, so build it once instead of per episode.
    mse = nn.MSELoss().cuda(GPU)
    last_accuracy = 0.0

    for episode in range(EPISODE):
        # init dataset
        # sample_dataloader is to obtain previous samples for compare
        # batch_dataloader is to batch samples for training
        task = tg.MiniImagenetTask(metatrain_folders, CLASS_NUM,
                                   SAMPLE_NUM_PER_CLASS, BATCH_NUM_PER_CLASS)
        sample_dataloader = tg.get_mini_imagenet_data_loader(
            task,
            num_per_class=SAMPLE_NUM_PER_CLASS,
            split="train",
            shuffle=False)
        batch_dataloader = tg.get_mini_imagenet_data_loader(
            task,
            num_per_class=BATCH_NUM_PER_CLASS,
            split="test",
            shuffle=True)

        # sample datas (next(iter(...)) works on Python 3 iterators, which
        # have no .next() method)
        samples, _ = next(iter(sample_dataloader))
        batches, batch_labels = next(iter(batch_dataloader))

        # calculate features
        sample_features = feature_encoder(samples.cuda(GPU))
        batch_features = feature_encoder(batches.cuda(GPU))

        # calculate relations
        # each batch sample is paired with every support sample; the view
        # below assumes encoder features are FEATURE_DIM x 19 x 19
        sample_features_ext = sample_features.unsqueeze(0).repeat(
            BATCH_NUM_PER_CLASS * CLASS_NUM, 1, 1, 1, 1)
        batch_features_ext = batch_features.unsqueeze(0).repeat(
            SAMPLE_NUM_PER_CLASS * CLASS_NUM, 1, 1, 1, 1)
        batch_features_ext = torch.transpose(batch_features_ext, 0, 1)
        relation_pairs = torch.cat((sample_features_ext, batch_features_ext),
                                   2).view(-1, FEATURE_DIM * 2, 19, 19)
        relations = relation_network(relation_pairs).view(
            -1, CLASS_NUM * SAMPLE_NUM_PER_CLASS)

        # NOTE(review): the target has CLASS_NUM columns while `relations`
        # has CLASS_NUM * SAMPLE_NUM_PER_CLASS, so the shapes only agree when
        # SAMPLE_NUM_PER_CLASS == 1 -- confirm this script is 1-shot only.
        one_hot_labels = torch.zeros(
            BATCH_NUM_PER_CLASS * CLASS_NUM,
            CLASS_NUM).scatter_(1, batch_labels.view(-1, 1), 1).cuda(GPU)
        loss = mse(relations, one_hot_labels)
        # Study notes (translated): the meaning of each variable follows the
        # algorithm designed in the paper (see its hyper-parameter section).
        # Read the code coarse-to-fine -- modules, data flow, tensor shapes,
        # combinations of common operations -- and be patient with the
        # details, since everything else builds on them.  Then move on to
        # improving and implementing the key parts.

        # training
        feature_encoder.zero_grad()
        relation_network.zero_grad()

        loss.backward()

        # In-place variant; the un-suffixed clip_grad_norm is deprecated.
        torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)

        feature_encoder_optim.step()
        relation_network_optim.step()

        # Step the schedulers after the optimizers (required ordering since
        # PyTorch 1.1); the learning rate seen by episode k is unchanged:
        # halved once per 100000 completed episodes.
        feature_encoder_scheduler.step()
        relation_network_scheduler.step()

        if (episode + 1) % 100 == 0:
            # .item() replaces loss.data[0], which fails on 0-dim tensors.
            print("episode:", episode + 1, "loss", loss.item())

        if episode % 5000 == 0:

            # test
            print("Testing...")
            accuracies = []
            # No gradients are needed while scoring test episodes.
            with torch.no_grad():
                for _ in range(TEST_EPISODE):
                    total_rewards = 0
                    counter = 0
                    # Evaluation tasks are fixed at 1 support and 15 query
                    # images per class, independent of SAMPLE_NUM_PER_CLASS.
                    task = tg.MiniImagenetTask(metatest_folders, CLASS_NUM,
                                               1, 15)
                    sample_dataloader = tg.get_mini_imagenet_data_loader(
                        task, num_per_class=1, split="train", shuffle=False)

                    num_per_class = 3
                    test_dataloader = tg.get_mini_imagenet_data_loader(
                        task,
                        num_per_class=num_per_class,
                        split="test",
                        shuffle=True)
                    sample_images, _ = next(iter(sample_dataloader))
                    for test_images, test_labels in test_dataloader:
                        batch_size = test_labels.shape[0]
                        # calculate features
                        sample_features = feature_encoder(
                            sample_images.cuda(GPU))
                        test_features = feature_encoder(
                            test_images.cuda(GPU))

                        # calculate relations
                        # each test sample is paired with every support
                        # sample before being scored
                        sample_features_ext = sample_features.unsqueeze(
                            0).repeat(batch_size, 1, 1, 1, 1)
                        test_features_ext = test_features.unsqueeze(0).repeat(
                            1 * CLASS_NUM, 1, 1, 1, 1)
                        test_features_ext = torch.transpose(
                            test_features_ext, 0, 1)
                        relation_pairs = torch.cat(
                            (sample_features_ext, test_features_ext),
                            2).view(-1, FEATURE_DIM * 2, 19, 19)
                        relations = relation_network(relation_pairs).view(
                            -1, CLASS_NUM)

                        # Highest relation score wins.
                        _, predict_labels = torch.max(relations, 1)

                        # Count correct predictions with one tensor
                        # comparison on the CPU, where test_labels lives.
                        total_rewards += (
                            predict_labels.cpu() == test_labels).sum().item()
                        counter += batch_size
                    accuracy = total_rewards / 1.0 / counter
                    accuracies.append(accuracy)

            test_accuracy, h = mean_confidence_interval(accuracies)

            print("test accuracy:", test_accuracy, "h:", h)

            if test_accuracy > last_accuracy:

                # save networks (create the directory on first use so that
                # torch.save does not fail on a fresh checkout)
                os.makedirs(model_dir, exist_ok=True)
                torch.save(feature_encoder.state_dict(), feature_encoder_path)
                torch.save(relation_network.state_dict(),
                           relation_network_path)

                print("save networks for episode:", episode)

                last_accuracy = test_accuracy
Exemple #18
0
def main():
    """Meta-train a feature encoder and actor-critic model with a PPO agent.

    Every episode runs ``META_BATCH_RANGE`` meta-batches.  A meta-batch first
    performs ``INNER_BATCH_RANGE`` inner adaptation steps on freshly sampled
    mini-ImageNet environments, then collects a meta-loss on
    ``META_ENV_LENGTH`` further tasks.  The mean meta-loss drives one
    optimizer step per episode.  Every 500 episodes the model is evaluated on
    ``TEST_EPISODE`` meta-test tasks and both networks are checkpointed under
    ``./models`` when the accuracy improves.

    Relies on module-level configuration (``CLASS_NUM``,
    ``SAMPLE_NUM_PER_CLASS``, ``FEATURE_DIM``, ``RELATION_DIM``,
    ``LEARNING_RATE``, ``GAMMA``, ``ENTROPY_WEIGHT``, ``CLIP_DECREASE``,
    ``INNER_LR``, ``EPISODE``, ``TEST_EPISODE``, ...), on the ``tg``,
    ``models`` and ``ppoAgent`` modules and on a module-level ``writer``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # * Step 1: init data folders
    print("init data folders")

    # * Init character folders for dataset construction
    metatrain_character_folders, metatest_character_folders = tg.mini_imagenet_folders()

    # * Step 2: init neural networks
    print("init neural networks")

    feature_encoder = models.CNNEncoder()
    model = models.ActorCritic(FEATURE_DIM, RELATION_DIM, CLASS_NUM)

    feature_encoder.train()
    model.train()

    feature_encoder.apply(models.weights_init)
    model.apply(models.weights_init)

    feature_encoder.to(device)
    model.to(device)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(), lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim, step_size=10000, gamma=0.5)

    model_optim = torch.optim.Adam(model.parameters(), lr=2.5 * LEARNING_RATE)
    model_scheduler = StepLR(model_optim, step_size=10000, gamma=0.5)

    agent = ppoAgent.PPOAgent(GAMMA, ENTROPY_WEIGHT, CLASS_NUM, device)

    # Checkpoint locations are built once and shared by load and save.
    model_dir = "./models"
    checkpoint_suffix = str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"
    feature_encoder_path = model_dir + "/miniimagenet_feature_encoder_" + checkpoint_suffix
    actor_network_path = model_dir + "/miniimagenet_actor_network_" + checkpoint_suffix

    # map_location lets a checkpoint written on a GPU machine be restored
    # when `device` falls back to the CPU.
    if os.path.exists(feature_encoder_path):
        feature_encoder.load_state_dict(torch.load(feature_encoder_path, map_location=device))
        print("load feature encoder success")

    if os.path.exists(actor_network_path):
        model.load_state_dict(torch.load(actor_network_path, map_location=device))
        print("load model network success")

    def build_relation_pairs(support_images, query_images, queries_per_class):
        """Encode both image sets and pair every query with every class.

        Support features are summed over the shots of each class, then
        concatenated channel-wise with each of the
        ``queries_per_class * CLASS_NUM`` query features.  The views assume
        encoder features are FEATURE_DIM x 19 x 19.
        """
        support_features = feature_encoder(support_images)
        support_features = support_features.view(CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM, 19, 19)
        support_features = torch.sum(support_features, 1).squeeze(1)
        query_features = feature_encoder(query_images)

        support_features_ext = support_features.unsqueeze(0).repeat(queries_per_class * CLASS_NUM, 1, 1, 1, 1)
        query_features_ext = query_features.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1)
        query_features_ext = torch.transpose(query_features_ext, 0, 1)
        return torch.cat((support_features_ext, query_features_ext), 2).view(-1, FEATURE_DIM * 2, 19, 19)

    # * Step 3: build graph
    print("Training...")
    loss_list = []
    last_accuracy = 0.0
    # Query images per class.  Training and testing use separate names so
    # that the test setting can never overwrite the training one.
    number_of_query_image = 15
    inner_query_per_batch = 5
    test_query_image = 10
    clip_param = 0.1
    for episode in range(EPISODE):
        # Halve the PPO clip range on the episode counter; testing the float
        # clip_param itself against CLIP_DECREASE would (almost) never fire.
        # NOTE(review): assumes CLIP_DECREASE is an episode interval -- confirm.
        if episode > 0 and episode % CLIP_DECREASE == 0:
            clip_param *= 0.5

        losses = []
        for meta_batch in range(META_BATCH_RANGE):
            meta_env_states_list = []
            meta_env_labels_list = []
            model_fast_weight = OrderedDict(model.named_parameters())
            for inner_batch in range(INNER_BATCH_RANGE):
                # * Generate environment
                env_states_list = []
                env_labels_list = []
                inner_loss_list = []
                for _ in range(ENV_LENGTH):
                    task = tg.MiniImagenetTask(metatrain_character_folders, CLASS_NUM, SAMPLE_NUM_PER_CLASS, number_of_query_image)
                    sample_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=SAMPLE_NUM_PER_CLASS, split="train", shuffle=False)
                    batch_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=inner_query_per_batch, split="test", shuffle=True)

                    samples, _ = next(iter(sample_dataloader))
                    samples = samples.to(device)
                    for batches, batch_labels in batch_dataloader:
                        batches, batch_labels = batches.to(device), batch_labels.to(device)

                        inner_relation_pairs = build_relation_pairs(samples, batches, inner_query_per_batch)
                        env_states_list.append(inner_relation_pairs)
                        env_labels_list.append(batch_labels)

                inner_env = ppoAgent.env(env_states_list, env_labels_list)
                agent.train(inner_env, model, loss_list=inner_loss_list)
                inner_loss = torch.stack(inner_loss_list).mean()
                inner_gradients = torch.autograd.grad(inner_loss.mean(), model_fast_weight.values(), create_graph=True, allow_unused=True)

                # One gradient step on the fast weights; parameters that did
                # not take part in the loss (grad is None) are left as-is.
                model_fast_weight = OrderedDict(
                    (name, param - INNER_LR * (0 if grad is None else grad))
                    for ((name, param), grad) in zip(model_fast_weight.items(), inner_gradients)
                )

            # NOTE(review): this only sets an attribute named `weight`;
            # whether ActorCritic reads it in forward() is not visible here.
            model.weight = model_fast_weight
            # * Generate env for meta update
            for _ in range(META_ENV_LENGTH):
                # * init dataset
                # * sample_dataloader is to obtain previous samples for compare
                # * batch_dataloader is to batch samples for training
                task = tg.MiniImagenetTask(metatrain_character_folders, CLASS_NUM, SAMPLE_NUM_PER_CLASS, number_of_query_image)
                sample_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=SAMPLE_NUM_PER_CLASS, split="train", shuffle=False)
                # * num_per_class : number of query images
                batch_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=number_of_query_image, split="test", shuffle=True)

                # * sample datas
                samples, _ = next(iter(sample_dataloader))
                batches, batch_labels = next(iter(batch_dataloader))

                samples = samples.to(device)
                batches, batch_labels = batches.to(device), batch_labels.to(device)

                # * calculate relations
                # * each batch sample link to every samples to calculate relations
                relation_pairs = build_relation_pairs(samples, batches, number_of_query_image)

                meta_env_states_list.append(relation_pairs)
                meta_env_labels_list.append(batch_labels)

            meta_env = ppoAgent.env(meta_env_states_list, meta_env_labels_list)
            agent.train(meta_env, model, loss_list=losses, clip_param=clip_param)

        feature_encoder_optim.zero_grad()
        model_optim.zero_grad()

        meta_batch_loss = torch.stack(losses).mean()
        meta_batch_loss.backward()

        # Clip only once gradients exist: clipping between zero_grad() and
        # backward() has nothing to act on.
        torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)

        feature_encoder_optim.step()
        model_optim.step()

        feature_encoder_scheduler.step()
        model_scheduler.step()

        mean_loss = None
        if (episode + 1) % 100 == 0:
            mean_loss = meta_batch_loss.detach().cpu().numpy()
            print(f"episode : {episode+1}, meta_loss : {mean_loss:.4f}")
            loss_list.append(mean_loss)

        # Every multiple of 500 is also a multiple of 100, so mean_loss is
        # always populated when this branch runs.
        if (episode + 1) % 500 == 0:
            print("Testing...")
            total_reward = 0

            total_test_samples = 0
            for i in range(TEST_EPISODE):
                # * Generate env
                env_states_list = []
                env_labels_list = []
                task = tg.MiniImagenetTask(metatest_character_folders, CLASS_NUM, SAMPLE_NUM_PER_CLASS, test_query_image)
                sample_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=SAMPLE_NUM_PER_CLASS, split="train", shuffle=False)
                test_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=test_query_image, split="test", shuffle=True)

                sample_images, _ = next(iter(sample_dataloader))
                test_images, test_labels = next(iter(test_dataloader))

                total_test_samples += len(test_labels)

                sample_images = sample_images.to(device)
                test_images, test_labels = test_images.to(device), test_labels.to(device)

                # * calculate features and relations; evaluation never
                # * backpropagates into the encoder, so skip graph building
                with torch.no_grad():
                    relation_pairs = build_relation_pairs(sample_images, test_images, test_query_image)
                env_states_list.append(relation_pairs)
                env_labels_list.append(test_labels)

                test_env = ppoAgent.env(env_states_list, env_labels_list)
                rewards = agent.test(test_env, model)
                total_reward += rewards

            test_accuracy = total_reward / (1.0 * total_test_samples)

            print(f'mean loss : {mean_loss}')
            print("test accuracy : ", test_accuracy)

            writer.add_scalar('1.loss', mean_loss, episode + 1)
            writer.add_scalar('4.test accuracy', test_accuracy, episode + 1)

            loss_list = []

            if test_accuracy > last_accuracy:
                # save networks (create the directory on first use so that
                # torch.save does not fail on a fresh checkout)
                os.makedirs(model_dir, exist_ok=True)
                torch.save(feature_encoder.state_dict(), feature_encoder_path)
                torch.save(model.state_dict(), actor_network_path)

                print("save networks for episode:", episode)
                last_accuracy = test_accuracy