Exemplo n.º 1
0
def main():
    """Evaluate a trained relation network on miniImageNet few-shot test tasks.

    Loads the feature encoder, the relation network and a pickled random
    forest from ``./models`` when the corresponding files exist. It then runs
    5 evaluation rounds of 600 randomly sampled test episodes each. For every
    round it prints the mean accuracy with its confidence half-width, the
    plain mean and the best single-episode accuracy. At the end it prints an
    aggregate over the per-round best accuracies and the sorted per-round
    means.
    """
    # * Step 1: init data folders
    print("init data folders")

    # * init character folders for dataset construction
    metatrain_character_folders, metatest_character_folders = tg.mini_imagenet_folders(
    )

    # * Step 2: init neural networks
    print("init neural networks")

    feature_encoder = ot.CNNEncoder().to(device)
    # NOTE(review): RFT is loaded below but not used in the evaluation
    # (the soft-voting experiment that consumed it was disabled).
    RFT = RandomForestClassifier(n_estimators=100,
                                 random_state=1,
                                 warm_start=True)
    relation_network = ot.RelationNetwork(FEATURE_DIM, RELATION_DIM).to(device)

    # NOTE(review): the networks are left in train mode here (no .eval()
    # call), so BatchNorm layers, if any, use per-batch statistics during
    # evaluation. Confirm that this is intended.

    # Build each checkpoint path once instead of repeating the concatenation.
    config_tag = str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot"
    feature_encoder_path = ("./models/miniimagenet_feature_encoder_" +
                            config_tag + "_exp.pkl")
    random_forest_path = ("./models/miniimagenet_random_forest_" +
                          config_tag + ".pkl")
    relation_network_path = ("./models/miniimagenet_relation_network_" +
                             config_tag + "_exp.pkl")

    if os.path.exists(feature_encoder_path):
        # map_location lets GPU-saved weights load on a CPU-only machine.
        feature_encoder.load_state_dict(
            torch.load(feature_encoder_path, map_location=device))
        print("load feature encoder success")

    if os.path.exists(random_forest_path):
        # pickle can execute arbitrary code; only load trusted model files.
        # The context manager ensures the file handle is closed.
        with open(random_forest_path, 'rb') as model_file:
            RFT = pickle.load(model_file)
        print("load random forest success")

    if os.path.exists(relation_network_path):
        relation_network.load_state_dict(
            torch.load(relation_network_path, map_location=device))
        print("load relation network success")

    number_of_query_image = 15  # query images per class in each task
    num_rounds = 5
    episodes_per_round = 600

    max_accuracy_list = []
    mean_accuracy_list = []
    for test in range(num_rounds):
        print("Testing...")
        max_accuracy = 0
        episode_accuracies = []
        print(f"Test {test}")
        for _ in range(episodes_per_round):
            task = tg.MiniImagenetTask(metatest_character_folders, CLASS_NUM,
                                       SAMPLE_NUM_PER_CLASS,
                                       number_of_query_image)
            sample_dataloader = tg.get_mini_imagenet_data_loader(
                task,
                num_per_class=SAMPLE_NUM_PER_CLASS,
                split="train",
                shuffle=False)
            test_dataloader = tg.get_mini_imagenet_data_loader(
                task,
                num_per_class=number_of_query_image,
                split="test",
                shuffle=False)

            sample_images, _ = next(iter(sample_dataloader))
            sample_images = sample_images.to(device)

            total_correct = 0
            total_evaluated = 0
            # Pure inference: no autograd graph is needed.
            with torch.no_grad():
                # * support features depend only on the task, so compute them
                # * once per episode rather than once per query batch.
                sample_features = feature_encoder(sample_images)
                sample_features = sample_features.view(CLASS_NUM,
                                                       SAMPLE_NUM_PER_CLASS,
                                                       FEATURE_DIM, 19, 19)
                # Sum the shots of each class into one prototype per class.
                sample_features = torch.sum(sample_features, 1).squeeze(1)

                for test_images, test_labels in test_dataloader:
                    test_images = test_images.to(device)
                    test_labels = test_labels.to(device)
                    batch_size = test_labels.shape[0]

                    test_features = feature_encoder(test_images)

                    # * pair every query with every class prototype:
                    # * (batch_size, CLASS_NUM, 2*FEATURE_DIM, 19, 19)
                    sample_features_ext = sample_features.unsqueeze(0).repeat(
                        batch_size, 1, 1, 1, 1)
                    test_features_ext = test_features.unsqueeze(0).repeat(
                        CLASS_NUM, 1, 1, 1, 1)
                    test_features_ext = torch.transpose(test_features_ext, 0, 1)

                    relation_pairs = torch.cat(
                        (sample_features_ext, test_features_ext),
                        2).view(-1, FEATURE_DIM * 2, 19, 19)
                    relations = relation_network(relation_pairs).view(
                        -1, CLASS_NUM)

                    _, predict_labels = torch.max(relations, 1)
                    # Score every query in the batch (previously only the
                    # first CLASS_NUM * 5 entries were counted).
                    total_correct += int(
                        (predict_labels == test_labels.view(-1)).sum().item())
                    total_evaluated += batch_size

            # Divide by the number of queries actually evaluated rather than a
            # hard-coded CLASS_NUM * 15.
            test_accuracy = (total_correct / total_evaluated
                             if total_evaluated else 0.0)
            episode_accuracies.append(test_accuracy)
            max_accuracy = max(max_accuracy, test_accuracy)

        test_accuracy, h = mean_confidence_interval(episode_accuracies)
        print(f"Final result : {test_accuracy:.4f}, h : {h:.4f} ")

        mean_accuracy = np.mean(episode_accuracies)
        mean_accuracy_list.append(mean_accuracy)
        print(f"Total accuracy : {mean_accuracy:.4f}")
        print(f"max accuracy : {max_accuracy:.4f}")
        max_accuracy_list.append(max_accuracy)

    # NOTE(review): this aggregates the *best single-episode* accuracy of each
    # round, not the mean; confirm that is the intended headline number
    # (the per-round means are printed below).
    final_accuracy, h = mean_confidence_interval(max_accuracy_list)
    print(f"Final result : {final_accuracy:.4f}, h : {h:.4f} ")
    print(np.sort(mean_accuracy_list))
Exemplo n.º 2
0
def main():
    """Evaluate a trained relation network on miniImageNet few-shot test tasks.

    Loads the feature encoder and relation network weights from ``./models``
    when the files exist and puts both networks in eval mode. It then runs 10
    evaluation rounds of 100 randomly sampled test episodes each. For every
    round it prints the mean accuracy and confidence half-width, and at the
    end it prints the average of the per-round means.
    """
    # * Step 1: init data folders
    print("init data folders")

    # * init character folders for dataset construction
    metatrain_folders, metatest_folders = tg.mini_imagenet_folders()

    # * Step 2: init neural networks
    print("init neural networks")

    feature_encoder = ot.CNNEncoder().to(device)
    relation_network = ot.RelationNetwork(FEATURE_DIM, RELATION_DIM).to(device)

    feature_encoder.eval()
    relation_network.eval()

    # Build each checkpoint path once instead of repeating the concatenation.
    config_tag = str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot"
    feature_encoder_path = ("./models/miniimagenet_feature_encoder_" +
                            config_tag + ".pkl")
    relation_network_path = ("./models/miniimagenet_relation_network_" +
                             config_tag + ".pkl")

    if os.path.exists(feature_encoder_path):
        # map_location lets GPU-saved weights load on a CPU-only machine.
        feature_encoder.load_state_dict(
            torch.load(feature_encoder_path, map_location=device))
        print("load feature encoder success")

    if os.path.exists(relation_network_path):
        relation_network.load_state_dict(
            torch.load(relation_network_path, map_location=device))
        print("load relation network success")

    num_rounds = 10
    episodes_per_round = 100
    queries_per_class_per_batch = 5  # passed as num_per_class to the test loader

    total_accuracy = 0.0

    for _ in range(num_rounds):
        # * test
        print("Testing...")

        accuracies = []

        for _ in range(episodes_per_round):
            task = tg.MiniImagenetTask(metatest_folders, CLASS_NUM,
                                       SAMPLE_NUM_PER_CLASS,
                                       BATCH_NUM_PER_CLASS)
            sample_dataloader = tg.get_mini_imagenet_data_loader(
                task,
                num_per_class=SAMPLE_NUM_PER_CLASS,
                split="train",
                shuffle=False)
            test_dataloader = tg.get_mini_imagenet_data_loader(
                task,
                num_per_class=queries_per_class_per_batch,
                split="test",
                shuffle=False)

            sample_images, _ = next(iter(sample_dataloader))
            sample_images = sample_images.to(device)

            total_rewards = 0
            total_evaluated = 0
            # Pure inference: no autograd graph is needed.
            with torch.no_grad():
                # * support features depend only on the task (and the models
                # * are in eval mode), so compute them once per episode.
                sample_features = feature_encoder(sample_images)
                sample_features = sample_features.view(CLASS_NUM,
                                                       SAMPLE_NUM_PER_CLASS,
                                                       FEATURE_DIM, 19, 19)
                # Sum the shots of each class into one prototype per class.
                sample_features = torch.sum(sample_features, 1).squeeze(1)

                for test_images, test_labels in test_dataloader:
                    batch_size = test_labels.shape[0]
                    test_images = test_images.to(device)
                    test_labels = test_labels.to(device)

                    test_features = feature_encoder(test_images)

                    # * pair every query with every class prototype:
                    # * (batch_size, CLASS_NUM, 2*FEATURE_DIM, 19, 19)
                    sample_features_ext = sample_features.unsqueeze(0).repeat(
                        batch_size, 1, 1, 1, 1)
                    test_features_ext = test_features.unsqueeze(0).repeat(
                        CLASS_NUM, 1, 1, 1, 1)
                    test_features_ext = torch.transpose(test_features_ext, 0, 1)

                    relation_pairs = torch.cat(
                        (sample_features_ext, test_features_ext),
                        2).view(-1, FEATURE_DIM * 2, 19, 19)
                    relations = relation_network(relation_pairs).view(
                        -1, CLASS_NUM)

                    _, predict_labels = torch.max(relations, 1)
                    total_rewards += int(
                        (predict_labels == test_labels.view(-1)).sum().item())
                    total_evaluated += batch_size

            # Divide by the number of queries actually evaluated instead of a
            # hard-coded CLASS_NUM * 15, which ignored BATCH_NUM_PER_CLASS and
            # the loader's real output.
            accuracy = (total_rewards / total_evaluated
                        if total_evaluated else 0.0)
            accuracies.append(accuracy)

        test_accuracy, h = mean_confidence_interval(accuracies)
        print(f'test accuracy : {test_accuracy:.4f}, h : {h:.4f}')
        total_accuracy += test_accuracy

    print(f"average accuracy : {total_accuracy / num_rounds:.4f}")