def main():
    """Meta-train the saliency-guided hallucination model on miniImagenet.

    Builds foreground/background feature encoders, a mixture network and a
    relation (similarity) network and restores any matching checkpoints found
    under ``METHOD``. It also loads "vanilla" teacher networks from
    ``./vanilla_models``; these feed the TriR regularizer and are never
    passed to an optimizer.

    It then runs ``EPISODE`` training episodes. At every episode divisible by
    2500 (including episode 0) the model is evaluated on ``TEST_EPISODE``
    meta-test tasks, and the four trainable networks are saved whenever the
    mean test accuracy beats the best seen so far.

    Relies on module-level configuration and helpers defined elsewhere in the
    file: ``GPU``, ``CLASS_NUM``, ``SUPPORT_NUM_PER_CLASS``,
    ``BATCH_NUM_PER_CLASS``, ``FEATURE_DIM``, ``RELATION_DIM``,
    ``LEARNING_RATE``, ``EPISODE``, ``TEST_EPISODE``, ``SIGMA``, ``METHOD``,
    ``args``, ``tg``, ``models``, ``weights_init``, ``power_norm``,
    ``similarity_func``, ``mean_confidence_interval``.
    """
    network_names = ("foreground_encoder", "background_encoder",
                     "mixture_network", "relation_network")

    def _checkpoint_path(directory, name):
        # e.g. "<directory>/miniImagenet_relation_network_5way_1shot.pkl"
        return (directory + "/miniImagenet_" + name + "_" + str(CLASS_NUM) +
                "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")

    def _load_if_exists(network, path, label):
        # Restore weights only when a checkpoint file is present.
        if os.path.exists(path):
            network.load_state_dict(torch.load(path))
            print("load " + label + " success")

    def _encode(images, saliency):
        # Foreground features come from the saliency-masked image, background
        # features from the inversely masked image.
        foreground = foreground_encoder(Variable(images * saliency).cuda(GPU))
        background = background_encoder(
            Variable(images * (1 - saliency)).cuda(GPU))
        return foreground, background

    def _hallucinate(support_fg, support_bg):
        # Inter-class hallucination: combine every support foreground with
        # every support background, mix each pair, and aggregate the mixtures
        # per class using the weights returned by similarity_func.
        # support_fg / support_bg: (CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64, 19, 19).
        n_support = CLASS_NUM * SUPPORT_NUM_PER_CLASS
        fg = support_fg.unsqueeze(2).repeat(1, 1, n_support, 1, 1, 1)
        bg = support_bg.view(1, 1, n_support, 64, 19, 19).repeat(
            CLASS_NUM, SUPPORT_NUM_PER_CLASS, 1, 1, 1, 1)
        weights = similarity_func(bg, CLASS_NUM, SUPPORT_NUM_PER_CLASS).view(
            CLASS_NUM, SUPPORT_NUM_PER_CLASS, -1, 1, 1)
        mixed = mixture_network((fg + bg).view(n_support ** 2, 64, 19, 19))
        mixed = mixed.view(CLASS_NUM, SUPPORT_NUM_PER_CLASS, -1, 64, 19 ** 2)
        return (mixed * weights).sum(2).sum(1)

    def _second_order(features, num_rows):
        # Per-sample 64x64 second-order matrix (s @ s^T / spatial size),
        # divided by its trace and passed through power_norm.
        # num_rows sizes the output buffer.
        pooled = Variable(torch.Tensor(num_rows, 1, 64, 64)).cuda(GPU)
        for d in range(features.size()[0]):
            s = features[d, :, :].squeeze(0)
            s = (1.0 / features.size()[2]) * s.mm(s.transpose(0, 1))
            pooled[d, :, :, :] = power_norm(s / s.trace(), SIGMA)
        return pooled

    def _relations(so_support, so_query, num_queries):
        # Pair every query descriptor with every class descriptor (stacked on
        # dim 2), score each pair, and view the scores as (-1, CLASS_NUM).
        support_ext = so_support.unsqueeze(0).repeat(num_queries, 1, 1, 1, 1)
        query_ext = torch.transpose(
            so_query.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1), 0, 1)
        pairs = torch.cat((support_ext, query_ext), 2).view(-1, 2, 64, 64)
        return relation_network(pairs).view(-1, CLASS_NUM)

    def _evaluate_task():
        # Accuracy on one freshly sampled meta-test task. The caller runs
        # this under torch.no_grad(), so no autograd graph is built.
        task = tg.MiniImagenetTask(metaquery_folders, CLASS_NUM,
                                   SUPPORT_NUM_PER_CLASS, 15)
        support_dataloader = tg.get_mini_imagenet_data_loader(
            task,
            num_per_class=SUPPORT_NUM_PER_CLASS,
            split="train",
            shuffle=False)
        query_dataloader = tg.get_mini_imagenet_data_loader(
            task, num_per_class=2, split="test", shuffle=True)
        support_img, support_sal, _ = next(iter(support_dataloader))

        total_rewards = 0
        counter = 0
        for query_img, query_sal, query_labels in query_dataloader:
            query_size = query_labels.shape[0]
            support_fg, support_bg = _encode(support_img, support_sal)
            support_fg = support_fg.view(CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64,
                                         19, 19)
            support_bg = support_bg.view(CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64,
                                         19, 19)
            query_fg, query_bg = _encode(query_img, query_sal)

            support_mix = _hallucinate(support_fg, support_bg)
            query_mix = mixture_network(query_fg + query_bg).view(
                -1, 64, 19 ** 2)
            so_support = _second_order(support_mix, CLASS_NUM)
            so_query = _second_order(query_mix, query_size)
            relations = _relations(so_support, so_query, query_size)

            _, predict_labels = torch.max(relations.data, 1)
            total_rewards += (predict_labels == query_labels.view(-1).cuda(
                GPU)).sum().item()
            counter += query_size
        return total_rewards / 1.0 / counter

    print("init data folders")
    metatrain_folders, metaquery_folders = tg.mini_imagenet_folders()

    print("init neural networks")
    foreground_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    background_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    mixture_network = models.MixtureNetwork().apply(weights_init).cuda(GPU)
    relation_network = models.SimilarityNetwork(
        FEATURE_DIM, RELATION_DIM).apply(weights_init).cuda(GPU)

    # Teacher networks for the TriR regularizer; never given an optimizer.
    vanilla_foreground_encoder = models.FeatureEncoder().apply(
        weights_init).cuda(GPU)
    vanilla_background_encoder = models.FeatureEncoder().apply(
        weights_init).cuda(GPU)
    vanilla_mixture_network = models.MixtureNetwork().apply(weights_init).cuda(
        GPU)

    trainable = (foreground_encoder, background_encoder, mixture_network,
                 relation_network)
    optimizers = [
        torch.optim.Adam(network.parameters(), lr=LEARNING_RATE)
        for network in trainable
    ]
    schedulers = [
        StepLR(optimizer, step_size=100000, gamma=0.5)
        for optimizer in optimizers
    ]

    # Loading models
    for network, name in zip(trainable, network_names):
        _load_if_exists(network, _checkpoint_path(METHOD, name),
                        name.replace("_", " "))

    # Loading vanilla (teacher) models
    vanilla = (vanilla_foreground_encoder, vanilla_background_encoder,
               vanilla_mixture_network)
    for network, name in zip(vanilla, network_names):
        _load_if_exists(network, _checkpoint_path("./vanilla_models", name),
                        "vanilla " + name.replace("_", " "))

    # Create the checkpoint directory without going through a shell.
    os.makedirs(METHOD, exist_ok=True)

    print("Training...")

    mse = nn.MSELoss().cuda(GPU)
    best_accuracy = 0.0
    num_queries = BATCH_NUM_PER_CLASS * CLASS_NUM
    for episode in range(EPISODE):
        for scheduler in schedulers:
            scheduler.step(episode)

        # init dataset
        task = tg.MiniImagenetTask(metatrain_folders, CLASS_NUM,
                                   SUPPORT_NUM_PER_CLASS, BATCH_NUM_PER_CLASS)
        support_dataloader = tg.get_mini_imagenet_data_loader(
            task,
            num_per_class=SUPPORT_NUM_PER_CLASS,
            split="train",
            shuffle=False)
        query_dataloader = tg.get_mini_imagenet_data_loader(
            task,
            num_per_class=BATCH_NUM_PER_CLASS,
            split="test",
            shuffle=True)

        # next(iter(...)) works on every DataLoader iterator; the
        # Python-2 style ``.next()`` method does not.
        support_img, support_sal, support_labels = next(
            iter(support_dataloader))
        query_img, query_sal, query_labels = next(iter(query_dataloader))

        # calculate features
        support_fg, support_bg = _encode(support_img, support_sal)
        support_fg = support_fg.view(CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64, 19,
                                     19)
        support_bg = support_bg.view(CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64, 19,
                                     19)
        query_fg, query_bg = _encode(query_img, query_sal)

        # Real-Representation Regularization (TriR). The teacher output is
        # only a regression target, so build no graph for it.
        with torch.no_grad():
            teacher_fg = vanilla_foreground_encoder(
                Variable(support_img * support_sal).cuda(GPU))
            teacher_bg = vanilla_background_encoder(
                Variable(support_img * (1 - support_sal)).cuda(GPU))
            teacher_mix = vanilla_mixture_network(teacher_fg + teacher_bg)
        student_mix = mixture_network(
            (support_fg + support_bg).view(-1, 64, 19, 19))
        TriR = args.beta * mse(student_mix, teacher_mix)

        # Inter-class hallucination and second-order pooling
        support_mix_features = _hallucinate(support_fg, support_bg)
        query_mix_features = mixture_network(query_fg + query_bg).view(
            -1, 64, 19 ** 2)
        so_support_features = _second_order(support_mix_features, CLASS_NUM)
        so_query_features = _second_order(query_mix_features, num_queries)

        # calculate relations with 64x64 second-order features
        relations = _relations(so_support_features, so_query_features,
                               num_queries)

        one_hot_labels = Variable(
            torch.zeros(num_queries,
                        CLASS_NUM).scatter_(1, query_labels.view(-1, 1),
                                            1)).cuda(GPU)
        loss = mse(relations, one_hot_labels) + TriR

        # update network parameters
        for network in trainable:
            network.zero_grad()
        loss.backward()
        for network in trainable:
            torch.nn.utils.clip_grad_norm_(network.parameters(), 0.5)
        for optimizer in optimizers:
            optimizer.step()

        if np.mod(episode + 1, 100) == 0:
            print("episode:", episode + 1, "loss", loss.item())

        if np.mod(episode, 2500) == 0:
            # test
            print("Testing...")
            with torch.no_grad():
                accuracies = [_evaluate_task() for _ in range(TEST_EPISODE)]
            test_accuracy, h = mean_confidence_interval(accuracies)
            print("test accuracy:", test_accuracy, "h:", h)

            if test_accuracy > best_accuracy:
                # save networks
                for network, name in zip(trainable, network_names):
                    torch.save(network.state_dict(),
                               _checkpoint_path(METHOD, name))
                print("save networks for episode:", episode)
                best_accuracy = test_accuracy
            print("best accuracy:", best_accuracy)
# Code example #2
# 0
def main():
    """Meta-train the HOP relation network on miniImagenet and evaluate it.

    Restores ``feature_encoder`` / ``relation_network`` checkpoints from
    ``METHOD`` when they exist. It then runs ``EPISODE`` training episodes
    with an MSE loss between relation scores and one-hot query labels. Every
    500 episodes (including episode 0) it evaluates on ``TEST_EPISODE``
    meta-test tasks and saves both networks when the mean accuracy beats the
    best seen so far.

    Relies on module-level configuration and helpers defined elsewhere in the
    file: ``GPU``, ``CLASS_NUM``, ``SUPPORT_NUM_PER_CLASS``,
    ``QUERY_NUM_PER_CLASS``, ``FEATURE_DIM``, ``RELATION_DIM``,
    ``LEARNING_RATE``, ``EPISODE``, ``TEST_EPISODE``, ``LAMBDA``, ``SIGMA``,
    ``METHOD``, ``tg``, ``models``, ``weights_init``, ``power_norm``,
    ``mean_confidence_interval``.
    """
    def _checkpoint_path(name):
        # e.g. "<METHOD>/feature_encoder_5way_1shot.pkl"
        return (METHOD + "/" + name + "_" + str(CLASS_NUM) + "way_" +
                str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")

    def _hop_features(features):
        # features: (N, 64, 19*19) -> (N, 1, 64, 64) descriptors.
        # Each descriptor is a Gram matrix of the partially centred features
        # (LAMBDA times the per-channel mean removed), divided by its trace
        # and passed through power_norm.
        num_rows = features.size()[0]
        pooled = Variable(torch.Tensor(num_rows, 1, 64, 64)).cuda(GPU)
        for d in range(num_rows):
            s = features[d, :, :].squeeze(0)
            # keepdim=True keeps shape (64, 1), so each channel's mean is
            # subtracted from that channel's own row. Repeating a 1-D mean
            # and reshaping would mix different channels' means together.
            s = s - LAMBDA * s.mean(1, keepdim=True)
            s = (1.0 / features.size()[2]) * s.mm(s.transpose(0, 1))
            pooled[d, :, :, :] = power_norm(s / s.trace(), SIGMA)
        return pooled

    def _relation_scores(h_support, h_query):
        # Pair each query descriptor with each per-class descriptor (stacked
        # on dim 2), score the pairs, and view as (-1, CLASS_NUM).
        num_queries = h_query.size(0)
        support_ext = h_support.unsqueeze(0).repeat(num_queries, 1, 1, 1, 1)
        query_ext = torch.transpose(
            h_query.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1), 0, 1)
        pairs = torch.cat((support_ext, query_ext), 2).view(-1, 2, 64, 64)
        return relation_network(pairs).view(-1, CLASS_NUM)

    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    metatrain_folders, metaquery_folders = tg.mini_imagenet_folders()

    # Step 2: init neural networks
    print("init neural networks")

    feature_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    relation_network = models.SimilarityNetwork(
        FEATURE_DIM, RELATION_DIM).apply(weights_init).cuda(GPU)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,
                                       step_size=50000,
                                       gamma=0.5)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                              lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,
                                        step_size=50000,
                                        gamma=0.5)

    for network, name, label in (
            (feature_encoder, "feature_encoder", "feature encoder"),
            (relation_network, "relation_network", "relation network")):
        path = _checkpoint_path(name)
        if os.path.exists(path):
            network.load_state_dict(torch.load(path))
            print("load " + label + " success")
    # Create the checkpoint directory without going through a shell.
    os.makedirs(METHOD, exist_ok=True)

    # Step 3: build graph
    print("Training...")

    best_accuracy = 0.0
    best_h = 0.0
    mse = nn.MSELoss().cuda(GPU)
    num_train_queries = QUERY_NUM_PER_CLASS * CLASS_NUM

    for episode in range(EPISODE):
        feature_encoder_scheduler.step(episode)
        relation_network_scheduler.step(episode)

        # init dataset
        # support_dataloader supplies the labelled supports to compare against;
        # query_dataloader supplies the queries to classify during training.
        task = tg.MiniImagenetTask(metatrain_folders, CLASS_NUM,
                                   SUPPORT_NUM_PER_CLASS, QUERY_NUM_PER_CLASS)
        support_dataloader = tg.get_mini_imagenet_data_loader(
            task,
            num_per_class=SUPPORT_NUM_PER_CLASS,
            split="train",
            shuffle=False)
        query_dataloader = tg.get_mini_imagenet_data_loader(
            task,
            num_per_class=QUERY_NUM_PER_CLASS,
            split="test",
            shuffle=True)

        # next(iter(...)) instead of the Python-2 style ``.next()`` method.
        supports, support_labels = next(iter(support_dataloader))
        queries, query_labels = next(iter(query_dataloader))

        # calculate features; supports are summed over shots, giving one
        # (64, 19*19) map per class.
        support_features = feature_encoder(Variable(supports).cuda(GPU)).view(
            CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64, 19 ** 2).sum(1)
        query_features = feature_encoder(Variable(queries).cuda(GPU)).view(
            num_train_queries, 64, 19 ** 2)

        # HOP features sized from the tensors they describe (CLASS_NUM rows
        # for supports), so relation pairs line up for any shot count.
        H_support_features = _hop_features(support_features)
        H_query_features = _hop_features(query_features)

        # calculate relation scores: one column per class, matching the
        # one-hot targets below.
        relations = _relation_scores(H_support_features, H_query_features)

        one_hot_labels = Variable(
            torch.zeros(num_train_queries,
                        CLASS_NUM).scatter_(1, query_labels.view(-1, 1),
                                            1)).cuda(GPU)
        loss = mse(relations, one_hot_labels)

        # updating network parameters with their gradients
        feature_encoder.zero_grad()
        relation_network.zero_grad()

        loss.backward()

        feature_encoder_optim.step()
        relation_network_optim.step()

        if (episode + 1) % 100 == 0:
            # .item() works on 0-dim tensors; loss.data[0] raises.
            print("episode:", episode + 1, "loss", loss.item())

        if episode % 500 == 0:
            # query
            print("Testing...")

            accuracies = []
            with torch.no_grad():
                for _ in range(TEST_EPISODE):
                    num_per_class = 2
                    task = tg.MiniImagenetTask(metaquery_folders, CLASS_NUM,
                                               SUPPORT_NUM_PER_CLASS,
                                               num_per_class)
                    support_dataloader = tg.get_mini_imagenet_data_loader(
                        task,
                        num_per_class=SUPPORT_NUM_PER_CLASS,
                        split="train",
                        shuffle=False)
                    query_dataloader = tg.get_mini_imagenet_data_loader(
                        task,
                        num_per_class=num_per_class,
                        split="query",
                        shuffle=True)
                    support_images, support_labels = next(
                        iter(support_dataloader))

                    total_rewards = 0
                    counter = 0
                    for query_images, query_labels in query_dataloader:
                        query_size = query_labels.shape[0]
                        support_features = feature_encoder(
                            Variable(support_images).cuda(GPU)).view(
                                CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64,
                                19 ** 2).sum(1)
                        query_features = feature_encoder(
                            Variable(query_images).cuda(GPU)).view(
                                query_size, 64, 19 ** 2)

                        relations = _relation_scores(
                            _hop_features(support_features),
                            _hop_features(query_features))
                        _, predict_labels = torch.max(relations.data, 1)

                        total_rewards += (predict_labels == query_labels.view(
                            -1).cuda(GPU)).sum().item()
                        counter += query_size
                    accuracies.append(total_rewards / 1.0 / counter)

            test_accuracy, h = mean_confidence_interval(accuracies)

            print("Test accuracy:", test_accuracy, "h:", h)
            print("Best accuracy: ", best_accuracy, "h:", best_h)

            if test_accuracy > best_accuracy:
                # save networks
                torch.save(feature_encoder.state_dict(),
                           _checkpoint_path("feature_encoder"))
                torch.save(relation_network.state_dict(),
                           _checkpoint_path("relation_network"))
                print("save networks for episode:", episode)

                best_accuracy = test_accuracy
                best_h = h
# Code example #3
# 0
def main():
    """Repeatedly evaluate the HOP relation network on Omniglot meta-test tasks.

    Restores the feature encoder and similarity network from ``METHOD`` when
    checkpoints exist. Despite the "Training..." banner, no parameters are
    updated here: each of the ``EPISODE`` iterations samples ``TEST_EPISODE``
    randomly rotated meta-test tasks, reports the query accuracy, and tracks
    the best accuracy seen so far.

    Relies on module-level configuration and helpers defined elsewhere in the
    file: ``GPU``, ``CLASS_NUM``, ``SUPPORT_NUM_PER_CLASS``,
    ``TEST_NUM_PER_CLASS``, ``FEATURE_DIM``, ``RELATION_DIM``, ``EPISODE``,
    ``TEST_EPISODE``, ``SIGMA``, ``METHOD``, ``tg``, ``models``,
    ``weights_init``, ``power_norm``.
    """
    def _hop_features(features):
        # features: (N, 64, 25) -> (N, 1, 64, 64) second-order descriptors
        # (s @ s^T / spatial size), divided by their trace and passed through
        # power_norm.
        num_rows = features.size(0)
        pooled = Variable(torch.Tensor(num_rows, 1, 64, 64)).cuda(GPU)
        for d in range(num_rows):
            s = features[d, :, :].squeeze(0)
            s = (1.0 / features.size(2)) * s.mm(s.t())
            pooled[d, :, :, :] = power_norm(s / s.trace(), SIGMA)
        return pooled

    # Step 1: init data folders
    print("init data folders")
    # Only the meta-test character folders are used below.
    _, metaquery_character_folders = tg.omniglot_character_folders()

    # Step 2: init neural networks
    print("init neural networks")

    feature_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    relation_network = models.SimilarityNetwork(
        FEATURE_DIM, RELATION_DIM).apply(weights_init).cuda(GPU)

    for network, prefix, label in (
            (feature_encoder, "/omniglot_feature_encoder_", "feature encoder"),
            (relation_network, "/omniglot_similarity_network_",
             "similarity network")):
        path = (METHOD + prefix + str(CLASS_NUM) + "way_" +
                str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")
        if os.path.exists(path):
            network.load_state_dict(torch.load(path))
            print("load " + label + " success")
    # Create the checkpoint directory without going through a shell.
    os.makedirs(METHOD, exist_ok=True)

    # Step 3: evaluation loop (the banner text is kept unchanged)
    print("Training...")

    best_accuracy = 0.0
    num_queries = TEST_NUM_PER_CLASS * CLASS_NUM

    for episode in range(EPISODE):
        with torch.no_grad():
            # query
            print("Testing...")
            total_rewards = 0

            for _ in range(TEST_EPISODE):
                degrees = random.choice([0, 90, 180, 270])
                task = tg.OmniglotTask(metaquery_character_folders, CLASS_NUM,
                                       SUPPORT_NUM_PER_CLASS,
                                       TEST_NUM_PER_CLASS)
                support_dataloader = tg.get_data_loader(
                    task,
                    num_per_class=SUPPORT_NUM_PER_CLASS,
                    split="train",
                    shuffle=False,
                    rotation=degrees)
                query_dataloader = tg.get_data_loader(
                    task,
                    num_per_class=TEST_NUM_PER_CLASS,
                    split="query",
                    shuffle=True,
                    rotation=degrees)

                # next(iter(...)) instead of the Python-2 style ``.next()``.
                support_images, support_labels = next(
                    iter(support_dataloader))
                query_images, query_labels = next(iter(query_dataloader))

                # calculate features; supports are summed over shots.
                support_features = feature_encoder(
                    Variable(support_images).cuda(GPU))
                support_features = support_features.view(
                    CLASS_NUM, SUPPORT_NUM_PER_CLASS, FEATURE_DIM, 25).sum(1)
                query_features = feature_encoder(
                    Variable(query_images).cuda(GPU)).view(num_queries, 64, 25)

                H_support_features = _hop_features(support_features)
                H_query_features = _hop_features(query_features)

                # Link each query to every class descriptor to form the
                # relation-network input pairs.
                support_features_ext = H_support_features.unsqueeze(0).repeat(
                    num_queries, 1, 1, 1, 1)
                query_features_ext = torch.transpose(
                    H_query_features.unsqueeze(0).repeat(
                        CLASS_NUM, 1, 1, 1, 1), 0, 1)
                relation_pairs = torch.cat(
                    (support_features_ext, query_features_ext),
                    2).view(-1, 2, 64, 64)
                relations = relation_network(relation_pairs).view(
                    -1, CLASS_NUM)

                _, predict_labels = torch.max(relations.data, 1)
                total_rewards += (predict_labels == query_labels.view(-1).cuda(
                    GPU)).sum().item()

            test_accuracy = total_rewards / 1.0 / CLASS_NUM / TEST_NUM_PER_CLASS / TEST_EPISODE

            print("query accuracy:", test_accuracy)
            print("best accuracy:", best_accuracy)

            if test_accuracy > best_accuracy:
                best_accuracy = test_accuracy
# Code example #4
# 0
def main():
    """Evaluate the foreground/background mixture model on miniImageNet.

    Builds the foreground and background encoders, the mixture network and
    the relation network, and restores any checkpoints found under
    ``METHOD``. Then, for each of ``EPISODE`` rounds, it samples
    ``TEST_EPISODE`` few-shot tasks from the meta-test folders. It prints
    the mean accuracy together with the half-width returned by
    ``mean_confidence_interval``, and tracks the best result seen so far.
    No training happens in this function.
    """
    print("init data folders")
    _, metaquery_folders = tg.mini_imagenet_folders()

    print("init neural networks")
    foreground_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    background_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    mixture_network = models.MixtureNetwork().apply(weights_init).cuda(GPU)
    relation_network = models.SimilarityNetwork(
        FEATURE_DIM, RELATION_DIM).apply(weights_init).cuda(GPU)

    # Loading models. Every checkpoint is optional: a missing file leaves
    # the freshly initialised weights in place.
    checkpoints = (
        (foreground_encoder, "foreground_encoder", "foreground encoder"),
        (background_encoder, "background_encoder", "background encoder"),
        (mixture_network, "mixture_network", "mixture network"),
        (relation_network, "relation_network", "relation network"),
    )
    for network, file_tag, label in checkpoints:
        path = str(METHOD + "/miniImagenet_" + file_tag + "_" +
                   str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) +
                   "shot.pkl")
        if os.path.exists(path):
            network.load_state_dict(torch.load(path))
            print("load " + label + " success")

    def second_order(features):
        """Turn (N, C, HW) features into (N, 1, C, C) normalised Gram matrices.

        Each sample's Gram matrix is averaged over the spatial axis, divided
        by its trace and passed through ``power_norm`` with ``SIGMA``.
        """
        count, channels, spatial = features.size()
        pooled = torch.empty(count, 1, channels, channels).cuda(GPU)
        for d in range(count):
            s = features[d]
            s = (1.0 / spatial) * s.mm(s.t())
            pooled[d] = power_norm(s / s.trace(), SIGMA)
        return pooled

    best_accuracy = 0.0
    best_h = 0.0

    for _ in range(EPISODE):
        with torch.no_grad():
            # test
            print("Testing...")
            accuracies = []
            for _ in range(TEST_EPISODE):
                total_rewards = 0
                counter = 0
                # 15 query candidates per class are sampled for each task.
                task = tg.MiniImagenetTask(metaquery_folders, CLASS_NUM,
                                           SUPPORT_NUM_PER_CLASS, 15)
                support_dataloader = tg.get_mini_imagenet_data_loader(
                    task,
                    num_per_class=SUPPORT_NUM_PER_CLASS,
                    split="train",
                    shuffle=False)
                num_per_class = 2
                query_dataloader = tg.get_mini_imagenet_data_loader(
                    task,
                    num_per_class=num_per_class,
                    split="test",
                    shuffle=True)
                # DataLoader iterators only implement __next__ on Python 3 /
                # recent PyTorch, so use the next() builtin, not .next().
                support_img, support_sal, _ = next(iter(support_dataloader))
                for query_img, query_sal, query_labels in query_dataloader:
                    query_size = query_labels.shape[0]
                    # Encode the masked (foreground) and inverse-masked
                    # (background) parts of each image separately; the mask
                    # is presumably a saliency map in [0, 1] -- confirm in tg.
                    support_foreground_features = foreground_encoder(
                        (support_img * support_sal).cuda(GPU)).view(
                            CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64, 19, 19)
                    support_background_features = background_encoder(
                        (support_img * (1 - support_sal)).cuda(GPU)).view(
                            CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64, 19, 19)
                    query_foreground_features = foreground_encoder(
                        (query_img * query_sal).cuda(GPU))
                    query_background_features = background_encoder(
                        (query_img * (1 - query_sal)).cuda(GPU))

                    # Inter-class hallucination: combine every support
                    # foreground with every support background.
                    support_foreground_features = support_foreground_features.unsqueeze(
                        2).repeat(1, 1, CLASS_NUM * SUPPORT_NUM_PER_CLASS, 1,
                                  1, 1)
                    support_background_features = support_background_features.view(
                        1, 1, CLASS_NUM * SUPPORT_NUM_PER_CLASS, 64, 19,
                        19).repeat(CLASS_NUM, SUPPORT_NUM_PER_CLASS, 1, 1, 1,
                                   1)
                    similarity_measure = similarity_func(
                        support_background_features, CLASS_NUM,
                        SUPPORT_NUM_PER_CLASS).view(CLASS_NUM,
                                                    SUPPORT_NUM_PER_CLASS, -1,
                                                    1, 1)
                    support_mix_features = mixture_network(
                        (support_foreground_features +
                         support_background_features).view(
                             (CLASS_NUM * SUPPORT_NUM_PER_CLASS)**2, 64, 19,
                             19)).view(CLASS_NUM, SUPPORT_NUM_PER_CLASS, -1,
                                       64, 19**2)
                    # Similarity-weighted sum over hallucinated backgrounds
                    # (dim 2) and support shots (dim 1): one map per class.
                    support_mix_features = (support_mix_features *
                                            similarity_measure).sum(2).sum(1)
                    query_mix_features = mixture_network(
                        query_foreground_features +
                        query_background_features).view(-1, 64, 19**2)

                    # second-order features
                    so_support_features = second_order(support_mix_features)
                    so_query_features = second_order(query_mix_features)

                    # calculate relations with 64x64 second-order features:
                    # every query is paired with every class.
                    support_features_ext = so_support_features.unsqueeze(
                        0).repeat(query_size, 1, 1, 1, 1)
                    query_features_ext = so_query_features.unsqueeze(0).repeat(
                        CLASS_NUM, 1, 1, 1, 1)
                    query_features_ext = torch.transpose(
                        query_features_ext, 0, 1)
                    relation_pairs = torch.cat(
                        (support_features_ext, query_features_ext),
                        2).view(-1, 2, 64, 64)
                    relations = relation_network(relation_pairs).view(
                        -1, CLASS_NUM)
                    _, predict_labels = torch.max(relations.data, 1)
                    # One vectorised comparison instead of a per-label
                    # host-to-device copy.
                    total_rewards += (
                        predict_labels == query_labels.cuda(GPU)).sum().item()
                    counter += query_size

                accuracy = float(total_rewards) / counter
                accuracies.append(accuracy)
            test_accuracy, h = mean_confidence_interval(accuracies)
            print("test accuracy:", test_accuracy, "h:", h)
            if test_accuracy > best_accuracy:
                best_accuracy = test_accuracy
                best_h = h
            print("best accuracy:", best_accuracy, "h:", best_h)
コード例 #5
0
def main():
    """Evaluate the second-order relation network on miniImageNet.

    Builds the feature encoder and relation network, restores any
    checkpoints found under ``METHOD``, and creates that directory if it
    does not exist. Then, for each of ``EPISODE`` rounds, it samples
    ``TEST_EPISODE`` few-shot tasks from the meta-test folders. It prints
    the mean accuracy with the half-width returned by
    ``mean_confidence_interval``, and tracks the best result seen so far.
    No training happens in this function.
    """
    _, metaquery_folders = tg.mini_imagenet_folders()

    print("init neural networks")

    feature_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    relation_network = models.SimilarityNetwork(
        FEATURE_DIM, RELATION_DIM).apply(weights_init).cuda(GPU)

    # Every checkpoint is optional: a missing file leaves the freshly
    # initialised weights in place.
    checkpoints = (
        (feature_encoder, "feature_encoder", "feature encoder"),
        (relation_network, "relation_network", "relation network"),
    )
    for network, file_tag, label in checkpoints:
        path = str(METHOD + "/" + file_tag + "_" + str(CLASS_NUM) + "way_" +
                   str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")
        if os.path.exists(path):
            network.load_state_dict(torch.load(path))
            print("load " + label + " success")
    # Create the directory directly instead of shelling out to `mkdir`,
    # which breaks on names with spaces or shell metacharacters.
    if not os.path.exists(METHOD):
        os.makedirs(METHOD)

    def second_order(features):
        """Turn (N, C, HW) features into (N, 1, C, C) HOP descriptors.

        Each channel is centred by subtracting ``LAMBDA`` times its spatial
        mean. The Gram matrix is then averaged over the spatial axis,
        divided by its trace and passed through ``power_norm`` with
        ``SIGMA``.
        """
        count, channels, spatial = features.size()
        pooled = torch.empty(count, 1, channels, channels).cuda(GPU)
        for d in range(count):
            s = features[d]
            # keepdim=True keeps the mean as (C, 1) so each channel is
            # centred by its own mean. The former `mean(1).repeat(1, HW)
            # .view(...)` scrambled the means across rows once `mean`
            # stopped keeping the reduced dimension.
            s = s - LAMBDA * s.mean(1, keepdim=True)
            s = (1.0 / spatial) * s.mm(s.t())
            pooled[d] = power_norm(s / s.trace(), SIGMA)
        return pooled

    print("Training...")

    best_accuracy = 0.0
    best_h = 0.0
    # Query images per class in each evaluation task.
    num_per_class = 2

    for _ in range(EPISODE):
        with torch.no_grad():
            print("Testing...")
            accuracies = []
            for _ in range(TEST_EPISODE):
                total_rewards = 0
                counter = 0
                # Use SUPPORT_NUM_PER_CLASS shots so the task agrees with the
                # support-feature reshape below (it was hard-coded to 1).
                task = tg.MiniImagenetTask(metaquery_folders, CLASS_NUM,
                                           SUPPORT_NUM_PER_CLASS,
                                           num_per_class)
                support_dataloader = tg.get_mini_imagenet_data_loader(
                    task,
                    num_per_class=SUPPORT_NUM_PER_CLASS,
                    split="train",
                    shuffle=False)
                query_dataloader = tg.get_mini_imagenet_data_loader(
                    task,
                    num_per_class=num_per_class,
                    split="query",
                    shuffle=True)
                # DataLoader iterators only implement __next__ on Python 3 /
                # recent PyTorch, so use the next() builtin, not .next().
                support_images, _ = next(iter(support_dataloader))
                for query_images, query_labels in query_dataloader:
                    query_size = query_labels.shape[0]
                    # calculate features; support shots of each class are
                    # summed into a single feature map.
                    support_features = feature_encoder(
                        support_images.cuda(GPU)).view(
                            CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64, 19**2).sum(1)
                    query_features = feature_encoder(
                        query_images.cuda(GPU)).view(query_size, 64, 19**2)

                    # HOP features
                    H_support_features = second_order(support_features)
                    H_query_features = second_order(query_features)

                    # form relation pairs: every query with every class
                    support_features_ext = H_support_features.unsqueeze(
                        0).repeat(query_size, 1, 1, 1, 1)
                    query_features_ext = H_query_features.unsqueeze(0).repeat(
                        CLASS_NUM, 1, 1, 1, 1)
                    query_features_ext = torch.transpose(
                        query_features_ext, 0, 1)
                    relation_pairs = torch.cat(
                        (support_features_ext, query_features_ext),
                        2).view(-1, 2, 64, 64)
                    # calculate relation scores
                    relations = relation_network(relation_pairs).view(
                        -1, CLASS_NUM)

                    _, predict_labels = torch.max(relations.data, 1)

                    # One vectorised comparison instead of a per-label
                    # host-to-device copy.
                    total_rewards += (
                        predict_labels == query_labels.cuda(GPU)).sum().item()
                    counter += query_size
                accuracy = float(total_rewards) / counter
                accuracies.append(accuracy)

            test_accuracy, h = mean_confidence_interval(accuracies)

            # Update the running best before reporting it, so the printed
            # best accounts for this round.
            if test_accuracy > best_accuracy:
                best_accuracy = test_accuracy
                best_h = h

            print("Test accuracy:", test_accuracy, "h:", h)
            print("Best accuracy: ", best_accuracy, "h:", best_h)