def main():
    """Meta-train the Omniglot relation network and checkpoint the best model.

    Relies on module-level configuration (``CLASS_NUM``, ``SAMPLE_NUM_PER_CLASS``,
    ``BATCH_NUM_PER_CLASS``, ``FEATURE_DIM``, ``RELATION_DIM``, ``LEARNING_RATE``,
    ``EPISODE``, ``TEST_EPISODE``, ``GPU``) and on the project's ``tg`` task
    generator, ``CNNEncoder``, ``RelationNetwork`` and ``weights_init``.

    Every 5000 episodes the networks are evaluated on ``TEST_EPISODE`` random
    meta-test tasks; whenever the accuracy improves, both state dicts are
    written to the current directory.
    """
    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    metatrain_character_folders, metatest_character_folders = tg.omniglot_character_folders()

    # Step 2: init neural networks
    print("init neural networks")

    feature_encoder = CNNEncoder()
    relation_network = RelationNetwork(FEATURE_DIM, RELATION_DIM)

    feature_encoder.apply(weights_init)
    relation_network.apply(weights_init)

    feature_encoder.cuda(GPU)
    relation_network.cuda(GPU)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,
                                       step_size=100000,
                                       gamma=0.5)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                              lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,
                                        step_size=100000,
                                        gamma=0.5)

    # The loss module is stateless; build it once instead of every episode.
    mse = nn.MSELoss().cuda(GPU)

    #     if os.path.exists(str("./models/omniglot_feature_encoder_" + str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl")):
    #         feature_encoder.load_state_dict(torch.load(str("./models/omniglot_feature_encoder_" + str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl")))
    #         print("load feature encoder success")
    #     if os.path.exists(str("./models/omniglot_relation_network_"+ str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl")):
    #         relation_network.load_state_dict(torch.load(str("./models/omniglot_relation_network_"+ str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl")))
    #         print("load relation network success")

    def relation_scores(support_images, query_images, num_queries):
        """Return a ``(num_queries, CLASS_NUM)`` tensor of relation scores.

        Support embeddings are averaged over the shots of each class to form
        one prototype per class; every query embedding is then concatenated
        with every prototype along the channel axis and scored by the
        relation network.
        """
        support = feature_encoder(support_images.cuda(GPU))
        support = support.view(CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM, 5, 5)
        # Mean over the shot axis. The former hard-coded "/ 5.0" was only a
        # mean when SAMPLE_NUM_PER_CLASS == 5 and otherwise rescaled support
        # embeddings relative to query embeddings.
        support = support.mean(1)
        query = feature_encoder(query_images.cuda(GPU))

        # Pair every query (dim 0) with every class prototype (dim 1).
        support_ext = support.unsqueeze(0).repeat(num_queries, 1, 1, 1, 1)
        query_ext = query.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1)
        query_ext = torch.transpose(query_ext, 0, 1)

        pairs = torch.cat((support_ext, query_ext), 2).view(-1, FEATURE_DIM * 2, 5, 5)
        return relation_network(pairs).view(-1, CLASS_NUM)

    # Step 3: build graph
    print("Training...")

    last_accuracy = 0.0
    num_train_queries = BATCH_NUM_PER_CLASS * CLASS_NUM
    num_test_queries = SAMPLE_NUM_PER_CLASS * CLASS_NUM

    for episode in range(EPISODE):

        # init dataset
        # sample_dataloader is to obtain previous samples for compare
        # batch_dataloader is to batch samples for training
        degrees = random.choice([0, 90, 180, 270])
        task = tg.OmniglotTask(metatrain_character_folders, CLASS_NUM,
                               SAMPLE_NUM_PER_CLASS, BATCH_NUM_PER_CLASS)
        sample_dataloader = tg.get_data_loader(
            task,
            num_per_class=SAMPLE_NUM_PER_CLASS,
            split="train",
            shuffle=False,
            rotation=degrees)
        batch_dataloader = tg.get_data_loader(
            task,
            num_per_class=BATCH_NUM_PER_CLASS,
            split="test",
            shuffle=True,
            rotation=degrees)

        # sample datas (Python 3 iterators have no ``.next()`` method)
        samples, _ = next(iter(sample_dataloader))
        batches, batch_labels = next(iter(batch_dataloader))

        relations = relation_scores(samples, batches, num_train_queries)

        one_hot_labels = torch.zeros(num_train_queries, CLASS_NUM).scatter_(
            1, batch_labels.view(-1, 1), 1).cuda(GPU)
        loss = mse(relations, one_hot_labels)

        # training
        feature_encoder.zero_grad()
        relation_network.zero_grad()

        loss.backward()

        torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)

        feature_encoder_optim.step()
        relation_network_optim.step()

        # Advance the LR schedules after the optimizer step (PyTorch >= 1.1
        # ordering). This yields the same learning rate per episode as the
        # former ``scheduler.step(episode)`` at the top of the loop.
        feature_encoder_scheduler.step()
        relation_network_scheduler.step()

        if (episode + 1) % 100 == 0:
            print("episode:", episode + 1, "loss", loss.item())

        if (episode + 1) % 5000 == 0:

            # test
            print("Testing...")
            total_rewards = 0

            # Evaluation needs no autograd graph.
            with torch.no_grad():
                for _ in range(TEST_EPISODE):
                    degrees = random.choice([0, 90, 180, 270])
                    task = tg.OmniglotTask(
                        metatest_character_folders,
                        CLASS_NUM,
                        SAMPLE_NUM_PER_CLASS,
                        SAMPLE_NUM_PER_CLASS,
                    )
                    sample_dataloader = tg.get_data_loader(
                        task,
                        num_per_class=SAMPLE_NUM_PER_CLASS,
                        split="train",
                        shuffle=False,
                        rotation=degrees)
                    test_dataloader = tg.get_data_loader(
                        task,
                        num_per_class=SAMPLE_NUM_PER_CLASS,
                        split="test",
                        shuffle=True,
                        rotation=degrees)

                    sample_images, _ = next(iter(sample_dataloader))
                    test_images, test_labels = next(iter(test_dataloader))

                    relations = relation_scores(sample_images, test_images,
                                                num_test_queries)
                    _, predict_labels = torch.max(relations, 1)

                    # Count correct predictions in one vectorised comparison.
                    total_rewards += (predict_labels == test_labels.to(
                        predict_labels.device)).sum().item()

            test_accuracy = total_rewards / (num_test_queries * TEST_EPISODE)

            print("test accuracy:", test_accuracy)

            if test_accuracy > last_accuracy:

                # save networks
                torch.save(
                    feature_encoder.state_dict(),
                    str("./omniglot_feature_encoder_" + str(CLASS_NUM) +
                        "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"))
                torch.save(
                    relation_network.state_dict(),
                    str("./omniglot_relation_network_" + str(CLASS_NUM) +
                        "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"))

                print("save networks for episode:", episode)

                last_accuracy = test_accuracy
def main():
    """Evaluate saved Omniglot relation-network checkpoints.

    Loads the feature encoder and relation network from ``./models`` when the
    matching checkpoint files exist. Then, for each of ``EPISODE`` rounds, it
    scores ``TEST_EPISODE`` random meta-test tasks and prints the mean
    accuracy with its half-width from ``mean_confidence_interval``. Finally it
    prints the accuracy averaged over all rounds.

    Works for any ``SAMPLE_NUM_PER_CLASS``. For 1-shot the results match the
    previous one-shot-only implementation.
    """
    # Step 1: init data folders
    print("init data folders")
    # only the meta-test split is needed for evaluation
    _, metatest_character_folders = tg.omniglot_character_folders()

    # Step 2: init neural networks
    print("init neural networks")

    feature_encoder = CNNEncoder()
    relation_network = RelationNetwork(FEATURE_DIM, RELATION_DIM)

    feature_encoder.cuda(GPU)
    relation_network.cuda(GPU)

    model_suffix = str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"
    encoder_path = "./models/omniglot_feature_encoder_" + model_suffix
    relation_path = "./models/omniglot_relation_network_" + model_suffix

    if os.path.exists(encoder_path):
        feature_encoder.load_state_dict(torch.load(encoder_path))
        print("load feature encoder success")
    if os.path.exists(relation_path):
        relation_network.load_state_dict(torch.load(relation_path))
        print("load relation network success")

    num_queries = CLASS_NUM * SAMPLE_NUM_PER_CLASS
    total_accuracy = 0.0
    for episode in range(EPISODE):
        # test
        print("Testing...")
        accuracies = []
        # Evaluation needs no autograd graph.
        with torch.no_grad():
            for _ in range(TEST_EPISODE):
                degrees = random.choice([0, 90, 180, 270])
                task = tg.OmniglotTask(metatest_character_folders, CLASS_NUM,
                                       SAMPLE_NUM_PER_CLASS, SAMPLE_NUM_PER_CLASS)
                sample_dataloader = tg.get_data_loader(
                    task, num_per_class=SAMPLE_NUM_PER_CLASS, split="train",
                    shuffle=False, rotation=degrees)
                test_dataloader = tg.get_data_loader(
                    task, num_per_class=SAMPLE_NUM_PER_CLASS, split="test",
                    shuffle=True, rotation=degrees)

                sample_images, _ = next(iter(sample_dataloader))
                test_images, test_labels = next(iter(test_dataloader))

                # calculate features; the shots of each class are summed into
                # one prototype (a no-op for 1-shot)
                sample_features = feature_encoder(sample_images.cuda(GPU))
                sample_features = sample_features.view(
                    CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM, 5, 5).sum(1)
                test_features = feature_encoder(test_images.cuda(GPU))

                # calculate relations: pair every query (dim 0) with every
                # class prototype (dim 1)
                sample_features_ext = sample_features.unsqueeze(0).repeat(
                    num_queries, 1, 1, 1, 1)
                test_features_ext = test_features.unsqueeze(0).repeat(
                    CLASS_NUM, 1, 1, 1, 1)
                test_features_ext = torch.transpose(test_features_ext, 0, 1)

                relation_pairs = torch.cat(
                    (sample_features_ext, test_features_ext),
                    2).view(-1, FEATURE_DIM * 2, 5, 5)
                relations = relation_network(relation_pairs).view(-1, CLASS_NUM)

                _, predict_labels = torch.max(relations, 1)

                # Compare on the predictions' device (GPU may not be cuda:0).
                correct = (predict_labels == test_labels.to(
                    predict_labels.device)).sum().item()
                accuracies.append(correct / num_queries)

        test_accuracy, h = mean_confidence_interval(accuracies)

        print("test accuracy:", test_accuracy, "h:", h)
        total_accuracy += test_accuracy

    print("aver_accuracy:", total_accuracy / EPISODE)
# Example #3
# 0
def test(feature_encoder, relation_network, rotations=(0,)):
    """Evaluate the networks on ``TEST_EPISODE`` random Omniglot meta-test tasks.

    Args:
        feature_encoder: embedding network applied to support and query images.
        relation_network: network scoring concatenated (support, query)
            embeddings.
        rotations: candidate rotation angles; one is drawn per task with
            ``random.choice``. Defaults to ``(0,)`` (no rotation), matching the
            previous hard-coded behaviour.

    Returns:
        The mean query accuracy over all tasks, which is also printed.
    """
    # only the meta-test split is used here
    _, metatest_character_folders = tg.omniglot_character_folders()
    # test
    print("Testing...")
    total_rewards = 0
    num_queries = CLASS_NUM * SAMPLE_NUM_PER_CLASS

    # Evaluation needs no autograd graph.
    with torch.no_grad():
        for _ in range(TEST_EPISODE):
            degrees = random.choice(rotations)
            task = tg.OmniglotTask(
                metatest_character_folders,
                CLASS_NUM,
                SAMPLE_NUM_PER_CLASS,
                SAMPLE_NUM_PER_CLASS,
            )
            sample_dataloader = tg.get_data_loader(
                task,
                num_per_class=SAMPLE_NUM_PER_CLASS,
                split="train",
                shuffle=False,
                rotation=degrees)
            test_dataloader = tg.get_data_loader(
                task,
                num_per_class=SAMPLE_NUM_PER_CLASS,
                split="test",
                shuffle=True,
                rotation=degrees)

            # Python 3 iterators have no ``.next()`` method
            sample_images, _ = next(iter(sample_dataloader))
            test_images, test_labels = next(iter(test_dataloader))

            # calculate features; the shots of each class are summed into one
            # prototype per class
            sample_features = feature_encoder(sample_images.to(device))
            sample_features = sample_features.view(CLASS_NUM, SAMPLE_NUM_PER_CLASS,
                                                   FEATURE_DIM, 5, 5).sum(1)
            test_features = feature_encoder(test_images.to(device))

            # calculate relations: pair every query (dim 0) with every class
            # prototype (dim 1)
            sample_features_ext = sample_features.unsqueeze(0).repeat(
                num_queries, 1, 1, 1, 1)
            test_features_ext = test_features.unsqueeze(0).repeat(
                CLASS_NUM, 1, 1, 1, 1)
            test_features_ext = torch.transpose(test_features_ext, 0, 1)

            relation_pairs = torch.cat((sample_features_ext, test_features_ext),
                                       2).view(-1, FEATURE_DIM * 2, 5, 5)
            relations = relation_network(relation_pairs).view(-1, CLASS_NUM)

            _, predict_labels = torch.max(relations, 1)

            # Labels come off the loader on the CPU; compare on the
            # predictions' device in one vectorised operation.
            total_rewards += (predict_labels == test_labels.to(
                predict_labels.device)).sum().item()

    test_accuracy = total_rewards / (num_queries * TEST_EPISODE)

    print("test accuracy:", test_accuracy)
    return test_accuracy
# Example #4
# 0
def main():
    """Meta-train the Omniglot relation network configured by ``args``.

    Hyper-parameters come from the module-level ``args`` namespace. The
    training loss (every episode) and validation accuracy (every 5000
    episodes) are logged to TensorBoard. Existing checkpoints in ``./models``
    are loaded before training, and both networks are saved there whenever
    the validation accuracy improves.
    """
    writer = SummaryWriter('/home/caffe/achu/logs/pytorch_omniglot_FSL.log')
    try:
        # Step 1: init data folders
        print("init data folders")
        # init character folders for dataset construction
        metatrain_character_folders, metatest_character_folders = tg.omniglot_character_folders(
            data_folder=args.dataset_folder,
            no_of_training_samples=args.training_samples_per_class,
            no_of_validation_samples=args.support_set_samples_per_class)

        # Step 2: init neural networks
        print("init neural networks")

        cnn_output_dims = cnn_final_output_dims(args.image_size)
        rn_dims = rn_dims_before_FCN(cnn_output_dims)
        fcn_size = args.channel_dim * (rn_dims**2)

        feature_encoder = CNNEncoder()
        relation_network = RelationNetwork(fcn_size, args.hidden_unit)

        feature_encoder.apply(weights_init)
        relation_network.apply(weights_init)

        # DataParallel requires the wrapped module to already live on
        # device_ids[0]. The old code wrapped CPU modules, which fails at the
        # first forward, and only called .cuda() when no GPU was present.
        feature_encoder.cuda(args.gpu)
        relation_network.cuda(args.gpu)
        gpu_count = torch.cuda.device_count()
        if gpu_count > 1:
            print("Let's use", gpu_count, "GPUs!")
            # Put args.gpu first so it is both the home and the output device
            # (assumes args.gpu is an integer device index).
            device_ids = [args.gpu] + [i for i in range(gpu_count) if i != args.gpu]
            feature_encoder = nn.DataParallel(feature_encoder, device_ids=device_ids)
            relation_network = nn.DataParallel(relation_network, device_ids=device_ids)

        feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                                 lr=args.learning_rate)
        feature_encoder_scheduler = StepLR(feature_encoder_optim,
                                           step_size=100000,
                                           gamma=0.5)
        relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                                  lr=args.learning_rate)
        relation_network_scheduler = StepLR(relation_network_optim,
                                            step_size=100000,
                                            gamma=0.5)

        model_dir = "./models"
        model_suffix = (str(args.class_num) + "way_" +
                        str(args.training_samples_per_class) + "shot.pkl")
        encoder_path = model_dir + "/omniglot_feature_encoder_" + model_suffix
        relation_path = model_dir + "/omniglot_relation_network_" + model_suffix

        if os.path.exists(encoder_path):
            feature_encoder.load_state_dict(
                torch.load(encoder_path, map_location='cuda:0'))
            print("load feature encoder success")
        if os.path.exists(relation_path):
            relation_network.load_state_dict(
                torch.load(relation_path, map_location='cuda:0'))
            print("load relation network success")

        # The loss module is stateless; build it once instead of every episode.
        mse = nn.MSELoss().cuda(args.gpu)

        # Step 3: build graph
        print("Training...")

        last_accuracy = 0.0
        num_train_queries = args.support_set_samples_per_class * args.class_num

        for episode in range(args.episode):

            # init dataset
            # sample_dataloader is to obtain previous samples for compare
            # batch_dataloader is to batch samples for training
            degrees = random.choice([0, 90, 180, 270])
            task = tg.OmniglotTask(metatrain_character_folders, args.class_num,
                                   args.training_samples_per_class,
                                   args.support_set_samples_per_class)
            sample_dataloader = tg.get_data_loader(
                task,
                num_per_class=args.training_samples_per_class,
                split="train",
                shuffle=False,
                rotation=degrees)
            batch_dataloader = tg.get_data_loader(
                task,
                num_per_class=args.support_set_samples_per_class,
                split="test",
                shuffle=True,
                rotation=degrees)

            # sample datas (Python 3 iterators have no ``.next()`` method)
            samples, _ = next(iter(sample_dataloader))
            batches, batch_labels = next(iter(batch_dataloader))

            # calculate features; shots of each class are summed into one prototype
            sample_features = feature_encoder(samples.cuda(args.gpu))
            sample_features = sample_features.view(args.class_num,
                                                   args.training_samples_per_class,
                                                   args.feature_dim,
                                                   cnn_output_dims,
                                                   cnn_output_dims)
            sample_features = torch.sum(sample_features, 1).squeeze(1)
            batch_features = feature_encoder(batches.cuda(args.gpu))

            # calculate relations: pair every query (dim 0) with every class
            # prototype (dim 1)
            sample_features_ext = sample_features.unsqueeze(0).repeat(
                num_train_queries, 1, 1, 1, 1)
            batch_features_ext = batch_features.unsqueeze(0).repeat(
                args.class_num, 1, 1, 1, 1)
            batch_features_ext = torch.transpose(batch_features_ext, 0, 1)

            relation_pairs = torch.cat((sample_features_ext, batch_features_ext),
                                       2).view(-1, args.feature_dim * 2,
                                               cnn_output_dims, cnn_output_dims)
            relations = relation_network(relation_pairs).view(-1, args.class_num)

            one_hot_labels = torch.zeros(num_train_queries, args.class_num).scatter_(
                1, batch_labels.view(-1, 1), 1).cuda(args.gpu)
            loss = mse(relations, one_hot_labels)

            # training
            feature_encoder.zero_grad()
            relation_network.zero_grad()

            loss.backward()

            torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
            torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)

            feature_encoder_optim.step()
            relation_network_optim.step()

            # Advance the LR schedules after the optimizer step (PyTorch >= 1.1
            # ordering). This gives the same per-episode learning rate as the
            # former ``scheduler.step(episode)`` at the top of the loop.
            feature_encoder_scheduler.step()
            relation_network_scheduler.step()

            loss_value = loss.item()
            writer.add_scalar('Training loss', loss_value, episode + 1)

            if (episode + 1) % 100 == 0:
                print("episode:", episode + 1, "loss", loss_value)

            if (episode + 1) % 5000 == 0:

                # test
                print("Testing...")
                total_rewards = 0
                num_test_queries = args.class_num * args.training_samples_per_class

                # Validation needs no autograd graph.
                with torch.no_grad():
                    for _ in range(args.test_episode):
                        degrees = random.choice([0, 90, 180, 270])
                        task = tg.OmniglotTask(
                            metatest_character_folders,
                            args.class_num,
                            args.training_samples_per_class,
                            args.training_samples_per_class,
                        )
                        sample_dataloader = tg.get_data_loader(
                            task,
                            num_per_class=args.training_samples_per_class,
                            split="train",
                            shuffle=False,
                            rotation=degrees)
                        # The task holds training_samples_per_class test images
                        # per class, and the scoring below assumes that many
                        # queries per class. The loader must request the same
                        # count (it previously used support_set_samples_per_class).
                        test_dataloader = tg.get_data_loader(
                            task,
                            num_per_class=args.training_samples_per_class,
                            split="test",
                            shuffle=True,
                            rotation=degrees)

                        sample_images, _ = next(iter(sample_dataloader))
                        test_images, test_labels = next(iter(test_dataloader))

                        # calculate features
                        sample_features = feature_encoder(
                            sample_images.cuda(args.gpu))
                        sample_features = sample_features.view(
                            args.class_num, args.training_samples_per_class,
                            args.feature_dim, cnn_output_dims, cnn_output_dims)
                        sample_features = torch.sum(sample_features, 1).squeeze(1)
                        test_features = feature_encoder(test_images.cuda(args.gpu))

                        # calculate relations
                        sample_features_ext = sample_features.unsqueeze(0).repeat(
                            num_test_queries, 1, 1, 1, 1)
                        test_features_ext = test_features.unsqueeze(0).repeat(
                            args.class_num, 1, 1, 1, 1)
                        test_features_ext = torch.transpose(test_features_ext, 0, 1)

                        relation_pairs = torch.cat(
                            (sample_features_ext, test_features_ext),
                            2).view(-1, args.feature_dim * 2, cnn_output_dims,
                                    cnn_output_dims)
                        relations = relation_network(relation_pairs).view(
                            -1, args.class_num)

                        _, predict_labels = torch.max(relations, 1)

                        # Compare on the predictions' device in one operation.
                        total_rewards += (predict_labels == test_labels.to(
                            predict_labels.device)).sum().item()

                test_accuracy = total_rewards / (num_test_queries * args.test_episode)

                print("validation accuracy:", test_accuracy)
                writer.add_scalar('Validation accuracy', test_accuracy,
                                  episode + 1)

                if test_accuracy > last_accuracy:

                    # save networks (torch.save does not create directories)
                    os.makedirs(model_dir, exist_ok=True)
                    torch.save(feature_encoder.state_dict(), encoder_path)
                    torch.save(relation_network.state_dict(), relation_path)

                    print("save networks for episode:", episode)

                    last_accuracy = test_accuracy
    finally:
        # Flush pending events and release the log file, even on error.
        writer.close()
def main():
    """Evaluate saved Omniglot checkpoints over repeated rounds of test tasks.

    Builds ``ot.CNNEncoder`` and ``ot.RelationNetwork`` on the module-level
    ``device`` in eval mode. It loads their weights from ``./models`` when the
    checkpoint files exist. For each of ``EPISODE`` rounds it scores
    ``TEST_EPISODE`` random meta-test tasks and prints the mean accuracy with
    its half-width from ``mean_confidence_interval``, then prints the average
    over all rounds.
    """
    # * Step 1: init data folders
    print("init data folders")

    # * init character folders for dataset construction (only the meta-test
    # * split is used for evaluation)
    _, metatest_character_folders = tg.omniglot_character_folders()

    # * Step 2: init neural networks
    print("init neural networks")

    feature_encoder = ot.CNNEncoder().to(device)
    relation_network = ot.RelationNetwork(FEATURE_DIM, RELATION_DIM).to(device)

    feature_encoder.eval()
    relation_network.eval()

    model_suffix = str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"
    encoder_path = "./models/omniglot_feature_encoder_" + model_suffix
    relation_path = "./models/omniglot_relation_network_" + model_suffix

    # map_location lets checkpoints saved on a GPU load when ``device`` is the CPU.
    if os.path.exists(encoder_path):
        feature_encoder.load_state_dict(
            torch.load(encoder_path, map_location=device))
        print("load feature encoder success")

    if os.path.exists(relation_path):
        relation_network.load_state_dict(
            torch.load(relation_path, map_location=device))
        print("load relation network success")

    num_queries = CLASS_NUM * SAMPLE_NUM_PER_CLASS
    total_accuracy = 0.0
    for episode in range(EPISODE):
        # * test
        print("Testing...")
        accuracies = []

        # * Evaluation needs no autograd graph.
        with torch.no_grad():
            for _ in range(TEST_EPISODE):
                degrees = random.choice([0, 90, 180, 270])
                task = tg.OmniglotTask(metatest_character_folders, CLASS_NUM,
                                       SAMPLE_NUM_PER_CLASS, SAMPLE_NUM_PER_CLASS)

                sample_dataloader = tg.get_data_loader(
                    task,
                    num_per_class=SAMPLE_NUM_PER_CLASS,
                    split="train",
                    shuffle=False,
                    rotation=degrees)
                test_dataloader = tg.get_data_loader(
                    task,
                    num_per_class=SAMPLE_NUM_PER_CLASS,
                    split="test",
                    shuffle=True,
                    rotation=degrees)

                sample_images, _ = next(iter(sample_dataloader))
                test_images, test_labels = next(iter(test_dataloader))

                sample_images = sample_images.to(device)
                test_images, test_labels = test_images.to(device), test_labels.to(device)

                # * Calculate features; shots of each class are summed into one prototype
                sample_features = feature_encoder(sample_images)
                sample_features = sample_features.view(CLASS_NUM,
                                                       SAMPLE_NUM_PER_CLASS,
                                                       FEATURE_DIM, 5, 5)
                sample_features = torch.sum(sample_features, 1).squeeze(1)
                test_features = feature_encoder(test_images)

                # * Pair every query (dim 0) with every class prototype (dim 1)
                sample_features_ext = sample_features.unsqueeze(0).repeat(
                    num_queries, 1, 1, 1, 1)
                test_features_ext = test_features.unsqueeze(0).repeat(
                    CLASS_NUM, 1, 1, 1, 1)
                test_features_ext = torch.transpose(test_features_ext, 0, 1)

                relation_pairs = torch.cat(
                    (sample_features_ext, test_features_ext),
                    2).view(-1, FEATURE_DIM * 2, 5, 5)
                relations = relation_network(relation_pairs).view(-1, CLASS_NUM)

                _, predict_labels = torch.max(relations, 1)

                correct = (predict_labels == test_labels).sum().item()
                accuracies.append(correct / num_queries)

        test_accuracy, h = mean_confidence_interval(accuracies)

        print(f'test accuracy : {test_accuracy}, h : {h}')
        total_accuracy += test_accuracy

    print(f"average accuracy : {total_accuracy / EPISODE}")
def main():
    """Evaluate the second-order (HOP) similarity model on Omniglot tasks.

    Builds ``models.FeatureEncoder`` and ``models.SimilarityNetwork`` on
    ``GPU``, restores their weights from the ``METHOD`` directory when the
    matching checkpoint files exist, and makes sure that directory exists.
    Then, for each of ``EPISODE`` rounds, it evaluates ``TEST_EPISODE``
    randomly rotated tasks drawn from the query character folders and prints
    the query accuracy together with the best accuracy of earlier rounds.

    No parameters are updated: the evaluation runs under ``torch.no_grad()``.
    """
    # Step 1: init data folders
    print("init data folders")
    # Only the query (meta-test) folders are used below.
    _, metaquery_character_folders = tg.omniglot_character_folders()

    # Step 2: init neural networks
    print("init neural networks")
    feature_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    relation_network = models.SimilarityNetwork(
        FEATURE_DIM, RELATION_DIM).apply(weights_init).cuda(GPU)

    shot_tag = (str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) +
                "shot.pkl")
    encoder_path = METHOD + "/omniglot_feature_encoder_" + shot_tag
    similarity_path = METHOD + "/omniglot_similarity_network_" + shot_tag
    if os.path.exists(encoder_path):
        feature_encoder.load_state_dict(torch.load(encoder_path))
        print("load feature encoder success")
    if os.path.exists(similarity_path):
        relation_network.load_state_dict(torch.load(similarity_path))
        print("load similarity network success")
    # Create the checkpoint directory directly rather than through a shell
    # command, which breaks on paths with spaces or shell metacharacters.
    os.makedirs(METHOD, exist_ok=True)

    num_query = TEST_NUM_PER_CLASS * CLASS_NUM

    def _second_order(features, out):
        # For every sample d: Gram matrix of its (FEATURE_DIM x 25) feature
        # map scaled by 1/features.size(2), normalised by its trace, passed
        # through power_norm with SIGMA, and written into out[d].
        for d in range(features.size(0)):
            s = features[d]
            s = (1.0 / features.size(2)) * s.mm(s.t())
            out[d] = power_norm(s / s.trace(), SIGMA)
        return out

    # Step 3: build graph
    print("Training...")

    best_accuracy = 0.0

    for episode in range(EPISODE):
        with torch.no_grad():
            # query
            print("Testing...")
            total_rewards = 0

            for i in range(TEST_EPISODE):
                # Support and query images of one task share one rotation.
                degrees = random.choice([0, 90, 180, 270])
                task = tg.OmniglotTask(metaquery_character_folders, CLASS_NUM,
                                       SUPPORT_NUM_PER_CLASS,
                                       TEST_NUM_PER_CLASS)
                support_dataloader = tg.get_data_loader(
                    task,
                    num_per_class=SUPPORT_NUM_PER_CLASS,
                    split="train",
                    shuffle=False,
                    rotation=degrees)
                query_dataloader = tg.get_data_loader(
                    task,
                    num_per_class=TEST_NUM_PER_CLASS,
                    split="query",
                    shuffle=True,
                    rotation=degrees)

                # DataLoader iterators expose __next__, not .next(), on
                # Python 3 / current PyTorch.
                support_images, _ = next(iter(support_dataloader))
                query_images, query_labels = next(iter(query_dataloader))

                # Sum the SUPPORT_NUM_PER_CLASS embeddings of each class into
                # one (FEATURE_DIM, 25) map per class.
                support_features = feature_encoder(support_images.cuda(GPU))
                support_features = support_features.view(
                    CLASS_NUM, SUPPORT_NUM_PER_CLASS, FEATURE_DIM, 25).sum(1)
                query_features = feature_encoder(
                    query_images.cuda(GPU)).view(num_query, FEATURE_DIM, 25)

                # HOP features: one (1, FEATURE_DIM, FEATURE_DIM) matrix per
                # class prototype / query image. new_empty keeps the device
                # and dtype of the source features; every slot is filled.
                H_support_features = _second_order(
                    support_features,
                    support_features.new_empty(
                        (CLASS_NUM, 1, FEATURE_DIM, FEATURE_DIM)))
                H_query_features = _second_order(
                    query_features,
                    query_features.new_empty(
                        (num_query, 1, FEATURE_DIM, FEATURE_DIM)))

                # Pair every query with every class prototype: both tensors
                # become (num_query, CLASS_NUM, 1, FEATURE_DIM, FEATURE_DIM).
                support_features_ext = H_support_features.unsqueeze(0).repeat(
                    num_query, 1, 1, 1, 1)
                query_features_ext = torch.transpose(
                    H_query_features.unsqueeze(0).repeat(
                        CLASS_NUM, 1, 1, 1, 1), 0, 1)

                relation_pairs = torch.cat(
                    (support_features_ext, query_features_ext),
                    2).view(-1, 2, FEATURE_DIM, FEATURE_DIM)
                relations = relation_network(relation_pairs).view(
                    -1, CLASS_NUM)

                _, predict_labels = torch.max(relations, 1)

                # Compare all predictions at once on the CPU instead of
                # copying each label to the GPU individually.
                total_rewards += int(
                    (predict_labels[:num_query].cpu() ==
                     query_labels[:num_query]).sum())

            test_accuracy = total_rewards / 1.0 / CLASS_NUM / TEST_NUM_PER_CLASS / TEST_EPISODE

            print("query accuracy:", test_accuracy)
            print("best accuracy:", best_accuracy)

            if test_accuracy > best_accuracy:
                best_accuracy = test_accuracy
def main():
    """Evaluate a saved relation model on randomly sampled Omniglot tasks.

    Restores the CNN encoder and the relation network from ``./models`` when
    the ``CLASS_NUM``-way ``SAMPLE_NUM_PER_CLASS``-shot checkpoints exist.
    For each of ``EPISODE`` rounds it evaluates ``TEST_EPISODE`` randomly
    rotated meta-test tasks. It prints the mean accuracy and the half-width
    ``h`` returned by ``mean_confidence_interval``. At the end it prints the
    average of the per-round means.

    The pairing layout (one embedding per support image, reshaped into
    ``CLASS_NUM`` columns) appears to assume ``SAMPLE_NUM_PER_CLASS == 1``.
    """
    # Step 1: init data folders
    print("init data folders")
    # Only the meta-test folders are needed for evaluation.
    _, metatest_character_folders = tg.omniglot_character_folders()

    # Step 2: init neural networks
    print("init neural networks")
    feature_encoder = CNNEncoder()
    relation_network = RelationNetwork(RELATION1_DIM, RELATION2_DIM,
                                       RELATION3_DIM)

    shot_tag = (str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) +
                "shot.pkl")
    encoder_path = "./models/omniglot_feature_encoder_" + shot_tag
    relation_path = "./models/omniglot_relation_network_" + shot_tag
    if os.path.exists(encoder_path):
        feature_encoder.load_state_dict(torch.load(encoder_path))
        print("load feature encoder success")
    if os.path.exists(relation_path):
        relation_network.load_state_dict(torch.load(relation_path))
        print("load relation network success")

    # Both the support ("train") and test splits draw SAMPLE_NUM_PER_CLASS
    # images per class.
    num_support = SAMPLE_NUM_PER_CLASS * CLASS_NUM
    num_test = SAMPLE_NUM_PER_CLASS * CLASS_NUM

    total_accuracy = 0.0
    for episode in range(EPISODE):
        # test
        print("Testing...")
        accuracies = []
        for i in range(TEST_EPISODE):
            degrees = random.choice([0, 90, 180, 270])
            task = tg.OmniglotTask(metatest_character_folders, CLASS_NUM,
                                   SAMPLE_NUM_PER_CLASS, SAMPLE_NUM_PER_CLASS)
            sample_dataloader = tg.get_data_loader(
                task,
                num_per_class=SAMPLE_NUM_PER_CLASS,
                split="train",
                shuffle=False,
                rotation=degrees)
            test_dataloader = tg.get_data_loader(
                task,
                num_per_class=SAMPLE_NUM_PER_CLASS,
                split="test",
                shuffle=True,
                rotation=degrees)

            # DataLoader iterators expose __next__, not .next(), on
            # Python 3 / current PyTorch.
            sample_images, _ = next(iter(sample_dataloader))
            test_images, test_labels = next(iter(test_dataloader))
            test_labels = test_labels.long()

            # Evaluation only: skip building autograd graphs.
            with torch.no_grad():
                sample_features = feature_encoder(sample_images)
                test_features = feature_encoder(test_images)
                # Pair every test embedding with every support embedding:
                # both tensors become (num_test, num_support, feature).
                sample_features_ext = sample_features.unsqueeze(0).repeat(
                    num_test, 1, 1)
                test_features_ext = torch.transpose(
                    test_features.unsqueeze(0).repeat(num_support, 1, 1), 0,
                    1)
                relation_pairs = torch.abs(
                    sample_features_ext - test_features_ext).view(-1, 1600)
                relations = relation_network(relation_pairs).view(
                    -1, CLASS_NUM)

            _, predict_labels = torch.max(relations, 1)
            # Score every test image so the numerator matches the
            # CLASS_NUM * SAMPLE_NUM_PER_CLASS denominator.
            correct = int(
                (predict_labels[:num_test] == test_labels[:num_test]).sum())
            accuracies.append(correct / num_test)

        test_accuracy, h = mean_confidence_interval(accuracies)

        print("test accuracy:", test_accuracy, "h:", h)
        total_accuracy += test_accuracy

    print("aver_accuracy:", total_accuracy / EPISODE)
def main():
    """Train the relation model on Omniglot and checkpoint the best weights.

    Each episode does the following:
    - samples a randomly rotated ``CLASS_NUM``-way task from the meta-train
      folders;
    - scores every batch (query) image against every support image by
      feeding ``|support - batch|`` features to ``RelationNetwork``;
    - regresses the scores onto one-hot targets with MSE;
    - takes one Adam step with gradient norms clipped to 0.5.

    Every 100 episodes it evaluates ``TEST_EPISODE`` meta-test tasks. When
    the accuracy improves, it saves both networks under ``./models``. The
    networks stay on the CPU.

    The pairing layout (one embedding per support image, reshaped into
    ``CLASS_NUM`` columns) appears to assume ``SAMPLE_NUM_PER_CLASS == 1``.
    """
    # Step 1: init data folders
    print("init data folders")
    metatrain_character_folders, metatest_character_folders = tg.omniglot_character_folders(
    )

    # Step 2: init neural networks
    print("init neural networks")
    if USE_INCEPTION_EMBEDDING:
        feature_encoder = Inception(1, 10)
    else:
        feature_encoder = CNNEncoder()
    relation_network = RelationNetwork(RELATION1_DIM, RELATION2_DIM,
                                       RELATION3_DIM)
    # Initialise every submodule's weights via Module.apply.
    feature_encoder.apply(weights_init)
    relation_network.apply(weights_init)

    # One Adam optimizer per network. StepLR multiplies the learning rate by
    # gamma=0.5 every step_size=100000 scheduler steps (one step per episode).
    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,
                                       step_size=100000,
                                       gamma=0.5)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                              lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,
                                        step_size=100000,
                                        gamma=0.5)

    shot_tag = (str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) +
                "shot.pkl")
    encoder_path = "./models/omniglot_feature_encoder_" + shot_tag
    relation_path = "./models/omniglot_relation_network_" + shot_tag
    if os.path.exists(encoder_path):
        feature_encoder.load_state_dict(torch.load(encoder_path))
        print("load feature encoder success")
    if os.path.exists(relation_path):
        relation_network.load_state_dict(torch.load(relation_path))
        print("load relation network success")

    mse = nn.MSELoss()
    num_support = SAMPLE_NUM_PER_CLASS * CLASS_NUM
    num_batch = BATCH_NUM_PER_CLASS * CLASS_NUM
    # The meta-test split draws SAMPLE_NUM_PER_CLASS images per class.
    num_test = SAMPLE_NUM_PER_CLASS * CLASS_NUM

    # Step 3: build graph
    print("Training...")
    last_accuracy = 0.0
    for episode in range(EPISODE):
        # Build the support set ("train" split) and the batch/query set
        # ("test" split) of one task; both share one random rotation.
        degrees = random.choice([0, 90, 180, 270])
        task = tg.OmniglotTask(metatrain_character_folders, CLASS_NUM,
                               SAMPLE_NUM_PER_CLASS, BATCH_NUM_PER_CLASS)
        sample_dataloader = tg.get_data_loader(
            task,
            num_per_class=SAMPLE_NUM_PER_CLASS,
            split="train",
            shuffle=False,
            rotation=degrees)
        batch_dataloader = tg.get_data_loader(
            task,
            num_per_class=BATCH_NUM_PER_CLASS,
            split="test",
            shuffle=True,
            rotation=degrees)

        # DataLoader iterators expose __next__, not .next(), on Python 3 /
        # current PyTorch.
        samples, _ = next(iter(sample_dataloader))
        batches, batch_labels = next(iter(batch_dataloader))

        # Extract features.
        sample_features = feature_encoder(samples)
        batch_features = feature_encoder(batches)

        # Pair every batch image with every support image: both tensors
        # become (num_batch, num_support, feature).
        sample_features_ext = sample_features.unsqueeze(0).repeat(
            num_batch, 1, 1)
        batch_features_ext = torch.transpose(
            batch_features.unsqueeze(0).repeat(num_support, 1, 1), 0, 1)
        # Metric learning: the relation network scores |support - batch|.
        relation_pairs = torch.abs(
            sample_features_ext - batch_features_ext).view(-1, 1600)
        relations = relation_network(relation_pairs).view(-1, CLASS_NUM)

        # Objective: MSE between relation scores and one-hot class targets.
        one_hot_labels = torch.zeros(num_batch, CLASS_NUM).scatter_(
            1, batch_labels.view(-1, 1).long(), 1)
        loss = mse(relations, one_hot_labels)

        # training: reset gradients, backpropagate, clip, update.
        feature_encoder.zero_grad()
        relation_network.zero_grad()
        loss.backward()
        # clip_grad_norm (no underscore) is deprecated in favour of the
        # in-place clip_grad_norm_.
        torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)
        feature_encoder_optim.step()
        relation_network_optim.step()
        # Advance the schedules after the optimizer step, as PyTorch expects.
        # This gives the same per-episode learning rate as the deprecated
        # scheduler.step(episode) call at the start of each episode.
        feature_encoder_scheduler.step()
        relation_network_scheduler.step()

        if (episode + 1) % 10 == 0:
            # loss.data[0] raises on 0-dim tensors; .item() returns a float.
            print("episode:", episode + 1, "loss", loss.item())

        if (episode + 1) % 100 == 0:

            # test
            print("Testing...")
            total_rewards = 0
            for i in range(TEST_EPISODE):
                degrees = random.choice([0, 90, 180, 270])
                task = tg.OmniglotTask(
                    metatest_character_folders,
                    CLASS_NUM,
                    SAMPLE_NUM_PER_CLASS,
                    SAMPLE_NUM_PER_CLASS,
                )
                sample_dataloader = tg.get_data_loader(
                    task,
                    num_per_class=SAMPLE_NUM_PER_CLASS,
                    split="train",
                    shuffle=False,
                    rotation=degrees)
                test_dataloader = tg.get_data_loader(
                    task,
                    num_per_class=SAMPLE_NUM_PER_CLASS,
                    split="test",
                    shuffle=True,
                    rotation=degrees)

                sample_images, _ = next(iter(sample_dataloader))
                test_images, test_labels = next(iter(test_dataloader))
                test_labels = test_labels.long()

                # Evaluation only: skip building autograd graphs.
                with torch.no_grad():
                    sample_features = feature_encoder(sample_images)
                    test_features = feature_encoder(test_images)
                    # Both tensors become (num_test, num_support, feature).
                    sample_features_ext = sample_features.unsqueeze(0).repeat(
                        num_test, 1, 1)
                    test_features_ext = torch.transpose(
                        test_features.unsqueeze(0).repeat(num_support, 1, 1),
                        0, 1)
                    relation_pairs = torch.abs(
                        sample_features_ext - test_features_ext).view(-1, 1600)
                    relations = relation_network(relation_pairs).view(
                        -1, CLASS_NUM)

                _, predict_labels = torch.max(relations, 1)
                # Score every test image so the numerator matches the
                # per-task denominator num_test below.
                total_rewards += int(
                    (predict_labels[:num_test] == test_labels[:num_test]).sum())

            test_accuracy = total_rewards / 1.0 / num_test / TEST_EPISODE
            print("test accuracy:", test_accuracy)

            if test_accuracy > last_accuracy:
                # save networks (torch.save does not create missing folders)
                os.makedirs("./models", exist_ok=True)
                torch.save(feature_encoder.state_dict(), encoder_path)
                torch.save(relation_network.state_dict(), relation_path)
                print("save networks for episode:", episode)
                last_accuracy = test_accuracy
Exemple #9
0
def main():
    """Train the convolutional relation network on Omniglot on ``GPU``.

    Each episode does the following:
    - samples a randomly rotated ``CLASS_NUM``-way task from the meta-train
      folders;
    - concatenates every support feature map with every batch (query)
      feature map along the channel axis;
    - scores the pairs with ``RelationNetwork`` and regresses the scores onto
      one-hot targets with MSE;
    - takes one Adam step with gradient norms clipped to 0.5.

    Every 5000 episodes it evaluates ``TEST_EPISODE`` meta-test tasks. When
    the test accuracy improves, it saves both networks under ``./models``.

    The pairing layout (one feature map per support image, reshaped into
    ``CLASS_NUM`` columns) appears to assume ``SAMPLE_NUM_PER_CLASS == 1``.
    """
    # Step 1: init data folders
    print("init data folders")
    # Meta-train and meta-test character folders (per the original author,
    # each folder holds the images of one class).
    metatrain_character_folders, metatest_character_folders = tg.omniglot_character_folders(
    )

    # Step 2: init neural networks
    print("init neural networks")

    feature_encoder = CNNEncoder()
    relation_network = RelationNetwork(FEATURE_DIM, RELATION_DIM)

    feature_encoder.apply(weights_init)
    relation_network.apply(weights_init)

    feature_encoder.cuda(GPU)
    relation_network.cuda(GPU)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,
                                       step_size=100000,
                                       gamma=0.5)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                              lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,
                                        step_size=100000,
                                        gamma=0.5)

    shot_tag = (str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) +
                "shot.pkl")
    encoder_path = "./models/omniglot_feature_encoder_" + shot_tag
    relation_path = "./models/omniglot_relation_network_" + shot_tag
    if os.path.exists(encoder_path):
        feature_encoder.load_state_dict(torch.load(encoder_path))
        print("load feature encoder success")
    if os.path.exists(relation_path):
        relation_network.load_state_dict(torch.load(relation_path))
        print("load relation network success")

    # The loss module is stateless, so build it once instead of per episode.
    mse = nn.MSELoss().cuda(GPU)
    num_support = SAMPLE_NUM_PER_CLASS * CLASS_NUM
    num_batch = BATCH_NUM_PER_CLASS * CLASS_NUM
    # The meta-test split draws SAMPLE_NUM_PER_CLASS images per class.
    num_test = SAMPLE_NUM_PER_CLASS * CLASS_NUM

    # Step 3: build graph
    print("Training...")

    last_accuracy = 0.0

    for episode in range(EPISODE):

        degrees = random.choice([0, 90, 180, 270])

        # One fresh task per episode: SAMPLE_NUM_PER_CLASS support and
        # BATCH_NUM_PER_CLASS query images for each of CLASS_NUM classes
        # drawn from the meta-train folders.
        task = tg.OmniglotTask(metatrain_character_folders, CLASS_NUM,
                               SAMPLE_NUM_PER_CLASS, BATCH_NUM_PER_CLASS)

        sample_dataloader = tg.get_data_loader(
            task,
            num_per_class=SAMPLE_NUM_PER_CLASS,
            split="train",
            shuffle=False,
            rotation=degrees)
        batch_dataloader = tg.get_data_loader(
            task,
            num_per_class=BATCH_NUM_PER_CLASS,
            split="test",
            shuffle=True,
            rotation=degrees)

        # DataLoader iterators expose __next__, not .next(), on Python 3 /
        # current PyTorch.
        samples, _ = next(iter(sample_dataloader))
        batches, batch_labels = next(iter(batch_dataloader))

        # Feature maps, presumably (N, FEATURE_DIM, 5, 5) as the view of
        # relation_pairs below requires.
        sample_features = feature_encoder(samples.cuda(GPU))
        batch_features = feature_encoder(batches.cuda(GPU))

        # Pair every batch image with every support image: both tensors
        # become (num_batch, num_support, FEATURE_DIM, 5, 5).
        sample_features_ext = sample_features.unsqueeze(0).repeat(
            num_batch, 1, 1, 1, 1)
        batch_features_ext = torch.transpose(
            batch_features.unsqueeze(0).repeat(num_support, 1, 1, 1, 1), 0, 1)

        # Concatenate along the channel axis, then flatten the pair grid:
        # (num_batch * num_support, 2 * FEATURE_DIM, 5, 5).
        relation_pairs = torch.cat(
            (sample_features_ext, batch_features_ext),
            2).view(-1, FEATURE_DIM * 2, 5, 5)
        # One relation score per (batch image, class).
        relations = relation_network(relation_pairs).view(-1, CLASS_NUM)

        # MSE between the relation scores and one-hot class targets.
        one_hot_labels = torch.zeros(num_batch, CLASS_NUM).scatter_(
            1, batch_labels.view(-1, 1).long(), 1).cuda(GPU)
        loss = mse(relations, one_hot_labels)

        # training
        feature_encoder.zero_grad()
        relation_network.zero_grad()

        loss.backward()

        torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)

        feature_encoder_optim.step()
        relation_network_optim.step()
        # Advance the schedules after the optimizer step, as PyTorch expects.
        # This gives the same per-episode learning rate as the deprecated
        # scheduler.step(episode) call at the start of each episode.
        feature_encoder_scheduler.step()
        relation_network_scheduler.step()

        if (episode + 1) % 100 == 0:
            # .item() logs the scalar value rather than a tensor repr.
            print("episode:", episode + 1, "loss", loss.item())

        if (episode + 1) % 5000 == 0:

            # test
            print("Testing...")
            total_rewards = 0

            for i in range(TEST_EPISODE):
                degrees = random.choice([0, 90, 180, 270])
                task = tg.OmniglotTask(
                    metatest_character_folders,
                    CLASS_NUM,
                    SAMPLE_NUM_PER_CLASS,
                    SAMPLE_NUM_PER_CLASS,
                )
                sample_dataloader = tg.get_data_loader(
                    task,
                    num_per_class=SAMPLE_NUM_PER_CLASS,
                    split="train",
                    shuffle=False,
                    rotation=degrees)
                test_dataloader = tg.get_data_loader(
                    task,
                    num_per_class=SAMPLE_NUM_PER_CLASS,
                    split="test",
                    shuffle=True,
                    rotation=degrees)

                sample_images, _ = next(iter(sample_dataloader))
                test_images, test_labels = next(iter(test_dataloader))

                # Evaluation only: skip building autograd graphs.
                with torch.no_grad():
                    sample_features = feature_encoder(
                        sample_images.cuda(GPU))
                    test_features = feature_encoder(test_images.cuda(GPU))

                    # Both become (num_test, num_support, FEATURE_DIM, 5, 5).
                    sample_features_ext = sample_features.unsqueeze(0).repeat(
                        num_test, 1, 1, 1, 1)
                    test_features_ext = torch.transpose(
                        test_features.unsqueeze(0).repeat(
                            num_support, 1, 1, 1, 1), 0, 1)

                    relation_pairs = torch.cat(
                        (sample_features_ext, test_features_ext),
                        2).view(-1, FEATURE_DIM * 2, 5, 5)
                    relations = relation_network(relation_pairs).view(
                        -1, CLASS_NUM)

                # Highest-scoring class for each test image.
                _, predict_labels = torch.max(relations, 1)

                # Score every test image (on the CPU) so the numerator
                # matches the per-task denominator num_test below.
                total_rewards += int(
                    (predict_labels[:num_test].cpu() ==
                     test_labels[:num_test]).sum())

            test_accuracy = total_rewards / 1.0 / num_test / TEST_EPISODE

            print("test accuracy:", test_accuracy)

            if test_accuracy > last_accuracy:

                # save networks (torch.save does not create missing folders)
                os.makedirs("./models", exist_ok=True)
                torch.save(feature_encoder.state_dict(), encoder_path)
                torch.save(relation_network.state_dict(), relation_path)

                print("save networks for episode:", episode)

                last_accuracy = test_accuracy