Ejemplo n.º 1
0
def train(feature_encoder, relation_network, train_data, config):
    """Compute the relation-network loss for one sampled episode and backprop it.

    A ``CLASS_NUM``-way episode is drawn from ``train_data``: an unshuffled
    support split with ``SAMPLE_NUM_PER_CLASS`` items per class and a shuffled
    query split with ``BATCH_NUM_PER_CLASS`` items per class. Support features
    are summed over the shots of each class, every query is paired with every
    class, and the relation scores are regressed onto one-hot targets with MSE.

    Gradients are accumulated into both networks' ``.grad`` buffers, but no
    optimizer step (and no gradient clipping) happens here -- the caller owns
    the optimizers.

    Args:
        feature_encoder: module producing one feature vector per input item.
        relation_network: module called as
            ``relation_network(class_features, query_features)``.
        train_data: dataset handed to ``ClassifyTask``.
        config: dict providing ``CLASS_NUM``, ``SAMPLE_NUM_PER_CLASS`` and
            ``BATCH_NUM_PER_CLASS``; it is also forwarded to
            ``get_data_loader``.

    Returns:
        The episode loss as a Python float.
    """
    feature_encoder.train()
    relation_network.train()

    class_num = config["CLASS_NUM"]
    shots = config["SAMPLE_NUM_PER_CLASS"]
    queries_per_class = config["BATCH_NUM_PER_CLASS"]

    task = ClassifyTask(train_data, class_num, shots, queries_per_class)
    sample_dataloader = get_data_loader(task,
                                        config,
                                        num_per_class=shots,
                                        split="train",
                                        shuffle=False)
    batch_dataloader = get_data_loader(task,
                                       config,
                                       num_per_class=queries_per_class,
                                       split="test",
                                       shuffle=True)

    # DataLoader iterators no longer provide the Python 2 style ``.next()``
    # method in recent PyTorch releases; the builtin ``next`` always works.
    samples, _sample_labels = next(iter(sample_dataloader))
    batches, batch_labels = next(iter(batch_dataloader))

    # Feed the inputs to the encoder on the device its weights live on;
    # CPU batches given to a GPU-resident encoder raise a device mismatch.
    first_param = next(feature_encoder.parameters(), None)
    encoder_device = (first_param.device
                      if first_param is not None else samples.device)

    # calculate features
    sample_features = feature_encoder(samples.to(encoder_device)).to(device)
    sample_features = sample_features.view(class_num, shots, -1)
    # Sum the shots of each class into one class representation.
    sample_features = torch.sum(sample_features, 1).squeeze(1)
    batch_features = feature_encoder(batches.to(encoder_device)).to(device)

    # calculate relations: pair every query with every class representation
    num_queries = queries_per_class * class_num
    sample_features_ext = sample_features.unsqueeze(0).repeat(
        num_queries, 1, 1)
    batch_features_ext = batch_features.unsqueeze(0).repeat(class_num, 1, 1)
    batch_features_ext = torch.transpose(batch_features_ext, 0, 1)

    relations = relation_network(sample_features_ext,
                                 batch_features_ext).view(-1, class_num)

    mse = nn.MSELoss().to(device)
    batch_labels = batch_labels.long()
    # Build the targets on the labels' own device so scatter_ never mixes
    # devices, then move them next to the relation scores.
    one_hot_labels = torch.zeros(num_queries,
                                 class_num,
                                 device=batch_labels.device)
    one_hot_labels = one_hot_labels.scatter_(1, batch_labels.view(-1, 1), 1)
    one_hot_labels = one_hot_labels.to(device)

    loss = mse(relations, one_hot_labels)
    feature_encoder.zero_grad()
    relation_network.zero_grad()
    loss.backward()

    return loss.item()
Ejemplo n.º 2
0
def train(episode):
    """Run one optimisation step of ``model`` on a freshly sampled episode.

    Samples a task from ``train_data`` using the ``class``/``support``/``query``
    sizes in ``config["model"]``, computes ``criterion`` on the model output,
    steps ``optimizer`` and ``scheduler``, and logs loss/accuracy to
    ``writer`` (printing them every ``log_interval`` episodes).

    Relies on module-level globals defined elsewhere: ``model``,
    ``optimizer``, ``scheduler``, ``criterion``, ``train_data``, ``config``,
    ``word2index``, ``device``, ``writer`` and ``log_interval``.
    """
    model.train()
    task = ClassifyTask(train_data, config["model"]["class"],
                        int(config["model"]["support"]),
                        int(config["model"]["query"]))
    sample_dataloader = get_data_loader(task,
                                        word2index,
                                        config,
                                        shuffle=False)
    # DataLoader iterators no longer expose a Python 2 style ``.next()``
    # method in recent PyTorch releases; use the builtin ``next`` instead.
    data, target = next(iter(sample_dataloader))
    data = data.to(device)
    target = target.to(device)
    optimizer.zero_grad()
    model.zero_grad()
    predict = model(data)
    loss, acc = criterion(predict, target)
    loss.backward()
    optimizer.step()
    scheduler.step()

    # Read the scalar once instead of syncing the device for every use.
    loss_value = loss.item()
    writer.add_scalar('train_loss', loss_value, episode)
    writer.add_scalar('train_acc', acc, episode)
    if episode % log_interval == 0:
        print('Train Episode: {} Loss: {} Acc: {}'.format(
            episode, loss_value, acc))
Ejemplo n.º 3
0
def dev(episode):
    """Measure ``model`` accuracy on one task sampled from ``dev_data``.

    Every batch's accuracy (second value returned by ``criterion``) is
    weighted by ``len(target) - support * 2`` before averaging. The weighted
    mean is logged to ``writer`` as ``dev_acc``, printed, and returned.

    Relies on module-level globals defined elsewhere: ``model``,
    ``dev_data``, ``config``, ``word2index``, ``device``, ``criterion``,
    ``support`` and ``writer``.
    """
    model.eval()
    n_class = config["model"]["class"]
    n_support = int(config["model"]["support"])
    n_query = int(config["model"]["query"])
    dev_task = ClassifyTask(dev_data, n_class, n_support, n_query)
    loader = get_data_loader(dev_task, word2index, config, shuffle=False)

    weighted_acc_sum = 0.
    weight_sum = 0.
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            _, batch_acc = criterion(model(inputs), labels)
            # NOTE(review): presumably discounts support items from the
            # batch size; ``support`` is a module-level value -- verify.
            weight = len(labels) - support * 2
            weighted_acc_sum += batch_acc * weight
            weight_sum += weight

    acc = weighted_acc_sum / weight_sum
    writer.add_scalar('dev_acc', acc, episode)
    print('Dev Episode: {} Acc: {}'.format(episode, acc))
    return acc
def main():
    """Evaluate saved Omniglot relation-network checkpoints.

    Loads the encoder/relation weights from ``./models`` when the files
    exist. Then, for each of ``EPISODE`` rounds, runs ``TEST_EPISODE``
    randomly rotated ``CLASS_NUM``-way, ``SAMPLE_NUM_PER_CLASS``-shot test
    tasks and prints the round's mean accuracy with its confidence
    half-width ``h``. Finally prints the accuracy averaged over all rounds.
    """
    # Step 1: init data folders
    print("init data folders")
    # Only the meta-test split is needed for evaluation.
    _, metatest_character_folders = tg.omniglot_character_folders()

    # Step 2: init neural networks
    print("init neural networks")

    feature_encoder = CNNEncoder()
    relation_network = RelationNetwork(FEATURE_DIM, RELATION_DIM)

    feature_encoder.cuda(GPU)
    relation_network.cuda(GPU)

    suffix = str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"
    encoder_path = "./models/omniglot_feature_encoder_" + suffix
    relation_path = "./models/omniglot_relation_network_" + suffix
    if os.path.exists(encoder_path):
        feature_encoder.load_state_dict(torch.load(encoder_path))
        print("load feature encoder success")
    if os.path.exists(relation_path):
        relation_network.load_state_dict(torch.load(relation_path))
        print("load relation network success")

    num_queries = CLASS_NUM * SAMPLE_NUM_PER_CLASS
    total_accuracy = 0.0
    for episode in range(EPISODE):
        # test
        print("Testing...")
        accuracies = []
        # Pure evaluation: no autograd graph is needed.
        with torch.no_grad():
            for _ in range(TEST_EPISODE):
                degrees = random.choice([0, 90, 180, 270])
                task = tg.OmniglotTask(metatest_character_folders, CLASS_NUM,
                                       SAMPLE_NUM_PER_CLASS,
                                       SAMPLE_NUM_PER_CLASS)
                sample_dataloader = tg.get_data_loader(
                    task,
                    num_per_class=SAMPLE_NUM_PER_CLASS,
                    split="train",
                    shuffle=False,
                    rotation=degrees)
                test_dataloader = tg.get_data_loader(
                    task,
                    num_per_class=SAMPLE_NUM_PER_CLASS,
                    split="test",
                    shuffle=True,
                    rotation=degrees)

                # Builtin next(): DataLoader iterators dropped ``.next()``.
                sample_images, _sample_labels = next(iter(sample_dataloader))
                test_images, test_labels = next(iter(test_dataloader))

                # calculate features
                sample_features = feature_encoder(sample_images.cuda(GPU))
                # Collapse the shots of each class into one feature map
                # (assumes the unshuffled support loader yields images grouped
                # by class). For 1-shot tasks this is an identity reshape.
                sample_features = sample_features.view(
                    CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM, 5, 5)
                sample_features = torch.sum(sample_features, 1)
                test_features = feature_encoder(test_images.cuda(GPU))

                # calculate relations: pair every test image with every class
                sample_features_ext = sample_features.unsqueeze(0).repeat(
                    num_queries, 1, 1, 1, 1)
                test_features_ext = test_features.unsqueeze(0).repeat(
                    CLASS_NUM, 1, 1, 1, 1)
                test_features_ext = torch.transpose(test_features_ext, 0, 1)

                relation_pairs = torch.cat(
                    (sample_features_ext, test_features_ext),
                    2).view(-1, FEATURE_DIM * 2, 5, 5)
                relations = relation_network(relation_pairs).view(
                    -1, CLASS_NUM)

                _, predict_labels = torch.max(relations, 1)
                # Compare on one device and score every test image, not just
                # the first CLASS_NUM of them.
                test_labels = test_labels.to(predict_labels.device)
                hits = (predict_labels == test_labels).sum().item()
                accuracies.append(hits / num_queries)

        test_accuracy, h = mean_confidence_interval(accuracies)

        print("test accuracy:", test_accuracy, "h:", h)
        total_accuracy += test_accuracy

    print("aver_accuracy:", total_accuracy / EPISODE)
def main():
    """Meta-train an Omniglot relation network and keep the best checkpoint.

    Each of ``EPISODE`` episodes samples a randomly rotated ``CLASS_NUM``-way
    task (``SAMPLE_NUM_PER_CLASS`` support and ``BATCH_NUM_PER_CLASS`` query
    images per class). It regresses the relation scores onto one-hot targets
    with MSE, clips gradients to norm 0.5 and steps both Adam optimizers,
    whose learning rates halve every 100000 episodes. The loss is printed
    every 100 episodes. Every 5000 episodes the model is evaluated on
    ``TEST_EPISODE`` meta-test tasks, and both networks are saved to the
    working directory whenever the test accuracy improves.
    """
    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    metatrain_character_folders, metatest_character_folders = (
        tg.omniglot_character_folders())

    # Step 2: init neural networks
    print("init neural networks")

    feature_encoder = CNNEncoder()
    relation_network = RelationNetwork(FEATURE_DIM, RELATION_DIM)

    feature_encoder.apply(weights_init)
    relation_network.apply(weights_init)

    feature_encoder.cuda(GPU)
    relation_network.cuda(GPU)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,
                                       step_size=100000,
                                       gamma=0.5)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                              lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,
                                        step_size=100000,
                                        gamma=0.5)

    def relation_scores(support_images, query_images):
        """Return the (num_queries, CLASS_NUM) relation-score matrix.

        Support feature maps are averaged over the SAMPLE_NUM_PER_CLASS
        shots of each class (assumes the unshuffled support loader yields
        images grouped by class). Every query feature map is then
        concatenated channel-wise with every class prototype and scored.
        """
        support = feature_encoder(support_images.cuda(GPU))
        # Average by the real shot count rather than a hard-coded 5, so
        # non-5-shot runs get a true per-class mean.
        support = support.view(CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM,
                               5, 5).mean(1)
        query = feature_encoder(query_images.cuda(GPU))
        support_ext = support.unsqueeze(0).repeat(query.size(0), 1, 1, 1, 1)
        query_ext = torch.transpose(
            query.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1), 0, 1)
        pairs = torch.cat((support_ext, query_ext),
                          2).view(-1, FEATURE_DIM * 2, 5, 5)
        return relation_network(pairs).view(-1, CLASS_NUM)

    mse = nn.MSELoss().cuda(GPU)

    # Step 3: build graph
    print("Training...")

    last_accuracy = 0.0

    for episode in range(EPISODE):
        # init dataset
        # sample_dataloader is to obtain previous samples for compare
        # batch_dataloader is to batch samples for training
        degrees = random.choice([0, 90, 180, 270])
        task = tg.OmniglotTask(metatrain_character_folders, CLASS_NUM,
                               SAMPLE_NUM_PER_CLASS, BATCH_NUM_PER_CLASS)
        sample_dataloader = tg.get_data_loader(
            task,
            num_per_class=SAMPLE_NUM_PER_CLASS,
            split="train",
            shuffle=False,
            rotation=degrees)
        batch_dataloader = tg.get_data_loader(
            task,
            num_per_class=BATCH_NUM_PER_CLASS,
            split="test",
            shuffle=True,
            rotation=degrees)

        # sample datas (builtin next(): DataLoader iterators dropped .next())
        samples, _sample_labels = next(iter(sample_dataloader))
        batches, batch_labels = next(iter(batch_dataloader))

        relations = relation_scores(samples, batches)

        one_hot_labels = torch.zeros(
            BATCH_NUM_PER_CLASS * CLASS_NUM, CLASS_NUM).scatter_(
                1, batch_labels.view(-1, 1).long(), 1).cuda(GPU)
        loss = mse(relations, one_hot_labels)

        # training
        feature_encoder.zero_grad()
        relation_network.zero_grad()

        loss.backward()

        torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)

        feature_encoder_optim.step()
        relation_network_optim.step()
        # Advance the LR schedules after the optimizer step, the order
        # PyTorch expects. With StepLR this gives episode e the same rate
        # as the deprecated ``scheduler.step(e)`` call at the top of the
        # episode did.
        feature_encoder_scheduler.step()
        relation_network_scheduler.step()

        if (episode + 1) % 100 == 0:
            print("episode:", episode + 1, "loss", loss.item())

        if (episode + 1) % 5000 == 0:

            # test
            print("Testing...")
            total_rewards = 0

            # Evaluation only: skip building autograd graphs.
            with torch.no_grad():
                for _ in range(TEST_EPISODE):
                    degrees = random.choice([0, 90, 180, 270])
                    task = tg.OmniglotTask(
                        metatest_character_folders,
                        CLASS_NUM,
                        SAMPLE_NUM_PER_CLASS,
                        SAMPLE_NUM_PER_CLASS,
                    )
                    sample_dataloader = tg.get_data_loader(
                        task,
                        num_per_class=SAMPLE_NUM_PER_CLASS,
                        split="train",
                        shuffle=False,
                        rotation=degrees)
                    test_dataloader = tg.get_data_loader(
                        task,
                        num_per_class=SAMPLE_NUM_PER_CLASS,
                        split="test",
                        shuffle=True,
                        rotation=degrees)

                    sample_images, _sample_labels = next(
                        iter(sample_dataloader))
                    test_images, test_labels = next(iter(test_dataloader))

                    relations = relation_scores(sample_images, test_images)
                    _, predict_labels = torch.max(relations, 1)
                    total_rewards += (predict_labels == test_labels.cuda(GPU)
                                      ).sum().item()

            test_accuracy = total_rewards / (CLASS_NUM * SAMPLE_NUM_PER_CLASS *
                                             TEST_EPISODE)

            print("test accuracy:", test_accuracy)

            if test_accuracy > last_accuracy:

                # save networks
                suffix = (str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) +
                          "shot.pkl")
                torch.save(feature_encoder.state_dict(),
                           "./omniglot_feature_encoder_" + suffix)
                torch.save(relation_network.state_dict(),
                           "./omniglot_relation_network_" + suffix)

                print("save networks for episode:", episode)

                last_accuracy = test_accuracy
Ejemplo n.º 6
0
def main():
    """Evaluate saved MSTAR relation-network checkpoints, overall and per class.

    Loads the encoder/relation weights from ``./models`` when the files
    exist. Then, for each of ``EPISODE`` rounds, runs ``TEST_EPISODE``
    randomly rotated ``CLASS_NUM``-way, ``SAMPLE_NUM_PER_CLASS``-shot test
    tasks. Per round it prints the overall mean accuracy with its confidence
    half-width ``h`` and each class's mean accuracy (``other<c>``). At the
    end it prints the averages of these over all rounds.
    """
    # Step 1: init data folders
    print("init data folders")
    # Only the meta-test split is used for evaluation.
    _, metatest_character_folders = tg.mstar_character_folders()

    # Step 2: init neural networks
    print("init neural networks")

    feature_encoder = CNNEncoder()
    relation_network = RelationNetwork(576, RELATION_DIM)

    feature_encoder.cuda(GPU)
    relation_network.cuda(GPU)

    suffix = str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"
    encoder_path = "./models/mstar_feature_encoder_" + suffix
    relation_path = "./models/mstar_relation_network_" + suffix
    if os.path.exists(encoder_path):
        feature_encoder.load_state_dict(torch.load(encoder_path))
        print("load feature encoder success")
    if os.path.exists(relation_path):
        relation_network.load_state_dict(torch.load(relation_path))
        print("load relation network success")

    num_queries = CLASS_NUM * SAMPLE_NUM_PER_CLASS
    total_accuracy = 0.0
    # Running sums of each class's per-round mean accuracy.
    total_class_accuracy = [0.0] * CLASS_NUM
    for episode in range(EPISODE):
        # test
        print("Testing...")
        accuracies = []
        class_accuracies = [[] for _ in range(CLASS_NUM)]
        # Pure evaluation: no autograd graph is needed.
        with torch.no_grad():
            for _ in range(TEST_EPISODE):
                degrees = random.choice([0, 90, 180, 270])
                task = tg.MstarTask(metatest_character_folders, CLASS_NUM,
                                    SAMPLE_NUM_PER_CLASS, SAMPLE_NUM_PER_CLASS,
                                    phase='testing')
                sample_dataloader = tg.get_data_loader(
                    task,
                    num_per_class=SAMPLE_NUM_PER_CLASS,
                    split="train",
                    shuffle=False,
                    rotation=degrees)
                test_dataloader = tg.get_data_loader(
                    task,
                    num_per_class=SAMPLE_NUM_PER_CLASS,
                    split="test",
                    shuffle=True,
                    rotation=degrees)

                # Builtin next(): DataLoader iterators dropped ``.next()``.
                sample_images, _sample_labels = next(iter(sample_dataloader))
                test_images, test_labels = next(iter(test_dataloader))

                test_labels = test_labels.cuda(GPU)

                # calculate features; feature maps are assumed square
                sample_features = feature_encoder(sample_images.cuda(GPU))
                side = sample_features.size(3)
                sample_features = sample_features.view(CLASS_NUM,
                                                       SAMPLE_NUM_PER_CLASS,
                                                       FEATURE_DIM, side, side)
                # Sum the shots of each class into one feature map.
                sample_features = torch.sum(sample_features, 1).squeeze(1)
                test_features = feature_encoder(test_images.cuda(GPU))

                # calculate relations: pair every test image with every class
                sample_features_ext = sample_features.unsqueeze(0).repeat(
                    num_queries, 1, 1, 1, 1)
                test_features_ext = test_features.unsqueeze(0).repeat(
                    CLASS_NUM, 1, 1, 1, 1)
                test_features_ext = torch.transpose(test_features_ext, 0, 1)

                relation_pairs = torch.cat(
                    (sample_features_ext, test_features_ext),
                    2).view(-1, FEATURE_DIM * 2, side, side)
                relations = relation_network(relation_pairs).view(
                    -1, CLASS_NUM)

                _, predict_labels = torch.max(relations, 1)

                correct = predict_labels == test_labels
                accuracies.append(correct.sum().item() / num_queries)
                for cls in range(CLASS_NUM):
                    # Assuming the test loader yields SAMPLE_NUM_PER_CLASS
                    # images per class, this is the fraction of class `cls`
                    # images classified correctly. The former hard-coded
                    # ``/ CLASS_NUM / SAMPLE_NUM_PER_CLASS * 4`` (classes 0-3
                    # only) equals it only for 4-way tasks.
                    class_hits = (correct &
                                  (predict_labels == cls)).sum().item()
                    class_accuracies[cls].append(class_hits /
                                                 SAMPLE_NUM_PER_CLASS)

        test_accuracy, h = mean_confidence_interval(accuracies)
        class_means = [
            mean_confidence_interval(accs)[0] for accs in class_accuracies
        ]

        print("test accuracy:", test_accuracy, "h:", h)
        print(*[
            field for cls, mean in enumerate(class_means)
            for field in ("other%d: " % cls, mean)
        ])
        total_accuracy += test_accuracy
        for cls, mean in enumerate(class_means):
            total_class_accuracy[cls] += mean

    print("aver_accuracy:", total_accuracy / EPISODE)

    for cls, total in enumerate(total_class_accuracy):
        print('ava_other%d:' % cls, total / EPISODE)
Ejemplo n.º 7
0
def main():
    """Meta-train an Omniglot relation network configured by ``args``.

    Each of ``args.episode`` episodes samples a randomly rotated
    ``args.class_num``-way task (``args.training_samples_per_class`` support
    and ``args.support_set_samples_per_class`` query images per class). It
    regresses the relation scores onto one-hot targets with MSE, clips
    gradients to norm 0.5 and steps both Adam optimizers. The training loss
    is logged to TensorBoard every episode and printed every 100. Every 5000
    episodes the model is validated on ``args.test_episode`` meta-test tasks,
    and both networks are saved under ``./models`` whenever validation
    accuracy improves.
    """
    writer = SummaryWriter('/home/caffe/achu/logs/pytorch_omniglot_FSL.log')
    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    metatrain_character_folders, metatest_character_folders = tg.omniglot_character_folders(
        data_folder=args.dataset_folder,
        no_of_training_samples=args.training_samples_per_class,
        no_of_validation_samples=args.support_set_samples_per_class)

    # Step 2: init neural networks
    print("init neural networks")

    cnn_output_dims = cnn_final_output_dims(args.image_size)
    rn_dims = rn_dims_before_FCN(cnn_output_dims)
    fcn_size = args.channel_dim * (rn_dims**2)

    feature_encoder = CNNEncoder()
    relation_network = RelationNetwork(fcn_size, args.hidden_unit)

    feature_encoder.apply(weights_init)
    relation_network.apply(weights_init)

    if torch.cuda.device_count() >= 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # nn.DataParallel requires the wrapped module to already sit on its
        # primary device (cuda:0 with the default device_ids, matching the
        # map_location used below). It only relocates the module itself in
        # the single-GPU case, so multi-GPU runs would fail on the first
        # forward pass without this move.
        feature_encoder = nn.DataParallel(feature_encoder.cuda(0))
        relation_network = nn.DataParallel(relation_network.cuda(0))
    else:
        # NOTE(review): with no visible GPU these .cuda() calls raise; the
        # script requires CUDA throughout.
        feature_encoder.cuda(args.gpu)
        relation_network.cuda(args.gpu)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=args.learning_rate)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,
                                       step_size=100000,
                                       gamma=0.5)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                              lr=args.learning_rate)
    relation_network_scheduler = StepLR(relation_network_optim,
                                        step_size=100000,
                                        gamma=0.5)

    suffix = (str(args.class_num) + "way_" +
              str(args.training_samples_per_class) + "shot.pkl")
    encoder_path = "./models/omniglot_feature_encoder_" + suffix
    relation_path = "./models/omniglot_relation_network_" + suffix
    if os.path.exists(encoder_path):
        feature_encoder.load_state_dict(
            torch.load(encoder_path, map_location='cuda:0'))
        print("load feature encoder success")
    if os.path.exists(relation_path):
        relation_network.load_state_dict(
            torch.load(relation_path, map_location='cuda:0'))
        print("load relation network success")

    def relation_scores(support_images, query_images):
        """Return the (num_queries, args.class_num) relation-score matrix.

        Support feature maps are summed over the shots of each class
        (assumes the unshuffled support loader yields images grouped by
        class). Every query feature map is then concatenated channel-wise
        with every class representation and scored.
        """
        support = feature_encoder(support_images.cuda(args.gpu))
        support = support.view(args.class_num,
                               args.training_samples_per_class,
                               args.feature_dim, cnn_output_dims,
                               cnn_output_dims)
        support = torch.sum(support, 1).squeeze(1)
        query = feature_encoder(query_images.cuda(args.gpu))
        support_ext = support.unsqueeze(0).repeat(query.size(0), 1, 1, 1, 1)
        query_ext = torch.transpose(
            query.unsqueeze(0).repeat(args.class_num, 1, 1, 1, 1), 0, 1)
        pairs = torch.cat((support_ext, query_ext),
                          2).view(-1, args.feature_dim * 2, cnn_output_dims,
                                  cnn_output_dims)
        return relation_network(pairs).view(-1, args.class_num)

    mse = nn.MSELoss().cuda(args.gpu)

    # Step 3: build graph
    print("Training...")

    last_accuracy = 0.0

    for episode in range(args.episode):

        # init dataset
        # sample_dataloader is to obtain previous samples for compare
        # batch_dataloader is to batch samples for training
        degrees = random.choice([0, 90, 180, 270])
        task = tg.OmniglotTask(metatrain_character_folders, args.class_num,
                               args.training_samples_per_class,
                               args.support_set_samples_per_class)
        sample_dataloader = tg.get_data_loader(
            task,
            num_per_class=args.training_samples_per_class,
            split="train",
            shuffle=False,
            rotation=degrees)
        batch_dataloader = tg.get_data_loader(
            task,
            num_per_class=args.support_set_samples_per_class,
            split="test",
            shuffle=True,
            rotation=degrees)

        # sample datas (builtin next(): DataLoader iterators dropped .next())
        samples, _sample_labels = next(iter(sample_dataloader))
        batches, batch_labels = next(iter(batch_dataloader))

        relations = relation_scores(samples, batches)

        # Place the targets wherever the scores ended up (DataParallel
        # gathers its output on cuda:0, which may differ from args.gpu).
        one_hot_labels = torch.zeros(
            args.support_set_samples_per_class * args.class_num,
            args.class_num).scatter_(1,
                                     batch_labels.view(-1, 1).long(),
                                     1).to(relations.device)
        loss = mse(relations, one_hot_labels)

        # training

        feature_encoder.zero_grad()
        relation_network.zero_grad()

        loss.backward()

        torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)

        feature_encoder_optim.step()
        relation_network_optim.step()
        # Advance the LR schedules after the optimizer step, the order
        # PyTorch expects. With StepLR this gives episode e the same rate
        # as the deprecated ``scheduler.step(e)`` call at the top of the
        # episode did.
        feature_encoder_scheduler.step()
        relation_network_scheduler.step()

        loss_value = loss.item()
        writer.add_scalar('Training loss', loss_value, episode + 1)

        if (episode + 1) % 100 == 0:
            print("episode:", episode + 1, "loss", loss_value)

        if (episode + 1) % 5000 == 0:

            # test
            print("Testing...")
            total_rewards = 0

            # Evaluation only: skip building autograd graphs.
            with torch.no_grad():
                for _ in range(args.test_episode):
                    degrees = random.choice([0, 90, 180, 270])
                    task = tg.OmniglotTask(
                        metatest_character_folders,
                        args.class_num,
                        args.training_samples_per_class,
                        args.training_samples_per_class,
                    )
                    sample_dataloader = tg.get_data_loader(
                        task,
                        num_per_class=args.training_samples_per_class,
                        split="train",
                        shuffle=False,
                        rotation=degrees)
                    # The task above holds training_samples_per_class test
                    # images per class, and scoring below assumes exactly
                    # that many queries; requesting
                    # support_set_samples_per_class here broke whenever the
                    # two settings differed.
                    test_dataloader = tg.get_data_loader(
                        task,
                        num_per_class=args.training_samples_per_class,
                        split="test",
                        shuffle=True,
                        rotation=degrees)

                    sample_images, _sample_labels = next(
                        iter(sample_dataloader))
                    test_images, test_labels = next(iter(test_dataloader))

                    relations = relation_scores(sample_images, test_images)
                    _, predict_labels = torch.max(relations, 1)
                    test_labels = test_labels.to(predict_labels.device)
                    total_rewards += (predict_labels ==
                                      test_labels).sum().item()

            test_accuracy = total_rewards / (args.class_num *
                                             args.training_samples_per_class *
                                             args.test_episode)

            print("validation accuracy:", test_accuracy)
            writer.add_scalar('Validation accuracy', test_accuracy,
                              episode + 1)

            if test_accuracy > last_accuracy:

                # save networks
                torch.save(feature_encoder.state_dict(), encoder_path)
                torch.save(relation_network.state_dict(), relation_path)

                print("save networks for episode:", episode)

                last_accuracy = test_accuracy

    # Flush any pending TensorBoard events.
    writer.close()
Ejemplo n.º 8
0
def test(feature_encoder, relation_network):
    """Evaluate ``feature_encoder``/``relation_network`` on meta-test tasks.

    Draws ``TEST_EPISODE`` Omniglot tasks from the meta-test character
    folders. For each task the support embeddings are reshaped to
    ``(CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM, 5, 5)`` and summed per
    class; every query embedding is concatenated (channel axis) with every
    class embedding and scored by ``relation_network``. The predicted class
    is the arg-max relation score.

    Returns:
        The fraction of correctly classified queries over all episodes (also
        printed). Returning it is new; callers that ignored the previous
        ``None`` are unaffected.
    """
    # Only the meta-test split is used here.
    _, metatest_character_folders = tg.omniglot_character_folders()
    print("Testing...")
    total_rewards = 0
    num_queries = CLASS_NUM * SAMPLE_NUM_PER_CLASS

    # Evaluation only: skip building the autograd graph.
    with torch.no_grad():
        for i in range(TEST_EPISODE):
            # Only 0 degrees is drawn, i.e. rotation augmentation is off here
            # (the draw is kept so the RNG sequence is unchanged).
            degrees = random.choice([0])
            task = tg.OmniglotTask(
                metatest_character_folders,
                CLASS_NUM,
                SAMPLE_NUM_PER_CLASS,
                SAMPLE_NUM_PER_CLASS,
            )
            sample_dataloader = tg.get_data_loader(
                task,
                num_per_class=SAMPLE_NUM_PER_CLASS,
                split="train",
                shuffle=False,
                rotation=degrees)
            test_dataloader = tg.get_data_loader(
                task,
                num_per_class=SAMPLE_NUM_PER_CLASS,
                split="test",
                shuffle=True,
                rotation=degrees)

            # next(iter(...)) rather than .__iter__().next(): DataLoader
            # iterators in current PyTorch have no `.next()` method.
            sample_images, _ = next(iter(sample_dataloader))
            test_images, test_labels = next(iter(test_dataloader))

            # Per-class embedding: sum over the support samples of each class.
            sample_features = feature_encoder(sample_images.to(device))
            sample_features = sample_features.view(CLASS_NUM,
                                                   SAMPLE_NUM_PER_CLASS,
                                                   FEATURE_DIM, 5, 5)
            sample_features = torch.sum(sample_features, 1).squeeze(1)
            test_features = feature_encoder(test_images.to(device))

            # Pair every query with every class embedding. Assumes the test
            # loader yields num_queries images of shape (FEATURE_DIM, 5, 5)
            # after encoding -- TODO confirm against the encoder.
            sample_features_ext = sample_features.unsqueeze(0).repeat(
                num_queries, 1, 1, 1, 1)
            test_features_ext = test_features.unsqueeze(0).repeat(
                CLASS_NUM, 1, 1, 1, 1)
            test_features_ext = torch.transpose(test_features_ext, 0, 1)

            relation_pairs = torch.cat((sample_features_ext, test_features_ext),
                                       2).view(-1, FEATURE_DIM * 2, 5, 5)
            relations = relation_network(relation_pairs).view(-1, CLASS_NUM)

            _, predict_labels = torch.max(relations, 1)

            rewards = [
                1 if predict_labels[j] == test_labels[j] else 0
                for j in range(num_queries)
            ]
            total_rewards += np.sum(rewards)

    test_accuracy = total_rewards / 1.0 / CLASS_NUM / SAMPLE_NUM_PER_CLASS / TEST_EPISODE

    print("test accuracy:", test_accuracy)
    return test_accuracy
Ejemplo n.º 9
0
def valid(feature_encoder, relation_network, test_data, config):
    """Return the mean query accuracy over ``config["TEST_EPISODE"]`` tasks.

    Both networks are switched to eval mode (and left that way). For each
    episode a ``ClassifyTask`` is drawn from ``test_data``; support embeddings
    are flattened per sample, summed per class, and every query embedding is
    paired with every class embedding via ``relation_network``. The predicted
    class is the arg-max relation score. The accuracy is also printed.
    """
    feature_encoder.eval()
    relation_network.eval()

    class_num = config["CLASS_NUM"]
    sample_num = config["SAMPLE_NUM_PER_CLASS"]
    total_rewards = 0

    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        # The number of episodes differs between training and testing.
        for i in range(config["TEST_EPISODE"]):
            task = ClassifyTask(test_data, class_num, sample_num,
                                config["BATCH_NUM_PER_CLASS"])
            sample_dataloader = get_data_loader(
                task,
                config,
                num_per_class=sample_num,
                split="train",
                shuffle=False)
            # NOTE(review): the query loader uses SAMPLE_NUM_PER_CLASS (not
            # BATCH_NUM_PER_CLASS); only class_num * sample_num predictions
            # are scored below.
            test_dataloader = get_data_loader(
                task,
                config,
                num_per_class=sample_num,
                split="test",
                shuffle=True)

            # next(iter(...)) rather than .__iter__().next(): DataLoader
            # iterators in current PyTorch have no `.next()` method.
            sample_images, _ = next(iter(sample_dataloader))
            test_images, test_labels = next(iter(test_dataloader))

            # Per-class embedding: flatten each support sample, sum per class.
            sample_features = feature_encoder(sample_images.to(device))
            sample_features = sample_features.view(class_num, sample_num, -1)
            sample_features = torch.sum(sample_features, 1).squeeze(1)
            test_features = feature_encoder(test_images.to(device))

            # Pair every query with every class embedding.
            sample_features_ext = sample_features.unsqueeze(0).repeat(
                sample_num * class_num, 1, 1)
            test_features_ext = test_features.unsqueeze(0).repeat(
                class_num, 1, 1)
            test_features_ext = torch.transpose(test_features_ext, 0, 1)

            relations = relation_network(sample_features_ext,
                                         test_features_ext)
            relations = relations.view(-1, class_num)

            _, predict_labels = torch.max(relations, 1)
            predict_labels = predict_labels.cpu()
            rewards = [
                1 if int(predict_labels[j]) == int(test_labels[j]) else 0
                for j in range(class_num * sample_num)
            ]

            total_rewards += np.sum(rewards)

    test_accuracy = total_rewards / 1.0 / class_num / sample_num / config[
        "TEST_EPISODE"]
    print("test accuracy:", test_accuracy)

    return test_accuracy
def main():
    """Load saved Omniglot models (if present) and report test accuracy.

    Runs ``EPISODE`` rounds. Each round evaluates ``TEST_EPISODE`` randomly
    rotated meta-test tasks and prints the mean accuracy together with the
    half-width ``h`` returned by ``mean_confidence_interval``. The average of
    the per-round means is printed at the end.
    """
    # * Step 1: init data folders
    print("init data folders")

    # * init character folders; only the meta-test split is used here
    _, metatest_character_folders = tg.omniglot_character_folders()

    # * Step 2: init neural networks
    print("init neural networks")

    feature_encoder = ot.CNNEncoder().to(device)
    relation_network = ot.RelationNetwork(FEATURE_DIM, RELATION_DIM).to(device)

    feature_encoder.eval()
    relation_network.eval()

    suffix = str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"
    encoder_path = "./models/omniglot_feature_encoder_" + suffix
    relation_path = "./models/omniglot_relation_network_" + suffix

    # * map_location lets checkpoints saved on a GPU load when `device` is CPU
    if os.path.exists(encoder_path):
        feature_encoder.load_state_dict(
            torch.load(encoder_path, map_location=device))
        print("load feature encoder success")

    if os.path.exists(relation_path):
        relation_network.load_state_dict(
            torch.load(relation_path, map_location=device))
        print("load relation network success")

    num_queries = CLASS_NUM * SAMPLE_NUM_PER_CLASS
    total_accuracy = 0.0
    # * Evaluation only: no autograd graph is needed
    with torch.no_grad():
        for episode in range(EPISODE):
            # * test
            print("Testing...")
            accuracies = []

            for i in range(TEST_EPISODE):
                degrees = random.choice([0, 90, 180, 270])
                task = tg.OmniglotTask(metatest_character_folders, CLASS_NUM,
                                       SAMPLE_NUM_PER_CLASS,
                                       SAMPLE_NUM_PER_CLASS)

                sample_dataloader = tg.get_data_loader(
                    task,
                    num_per_class=SAMPLE_NUM_PER_CLASS,
                    split="train",
                    shuffle=False,
                    rotation=degrees)
                test_dataloader = tg.get_data_loader(
                    task,
                    num_per_class=SAMPLE_NUM_PER_CLASS,
                    split="test",
                    shuffle=True,
                    rotation=degrees)

                sample_images, _ = next(iter(sample_dataloader))
                test_images, test_labels = next(iter(test_dataloader))

                sample_images = sample_images.to(device)
                test_images = test_images.to(device)
                test_labels = test_labels.to(device)

                # * Calculate features: per-class sum of support feature maps
                sample_features = feature_encoder(sample_images)
                sample_features = sample_features.view(CLASS_NUM,
                                                       SAMPLE_NUM_PER_CLASS,
                                                       FEATURE_DIM, 5, 5)
                sample_features = torch.sum(sample_features, 1).squeeze(1)
                test_features = feature_encoder(test_images)

                # * Pair every query with every class embedding
                sample_features_ext = sample_features.unsqueeze(0).repeat(
                    num_queries, 1, 1, 1, 1)
                test_features_ext = test_features.unsqueeze(0).repeat(
                    CLASS_NUM, 1, 1, 1, 1)
                test_features_ext = torch.transpose(test_features_ext, 0, 1)

                relation_pairs = torch.cat(
                    (sample_features_ext, test_features_ext),
                    2).view(-1, FEATURE_DIM * 2, 5, 5)
                relations = relation_network(relation_pairs).view(
                    -1, CLASS_NUM)

                _, predict_labels = torch.max(relations, 1)

                rewards = [
                    1 if predict_labels[j] == test_labels[j] else 0
                    for j in range(num_queries)
                ]

                accuracies.append(np.sum(rewards) / (1.0 * num_queries))

            test_accuracy, h = mean_confidence_interval(accuracies)

            print(f'test accuracy : {test_accuracy}, h : {h}')
            total_accuracy += test_accuracy

    print(f"average accuracy : {total_accuracy / EPISODE}")
def _hop_second_order_features(features):
    """Return trace- and power-normalized autocorrelation maps of ``features``.

    ``features`` is indexed as (N, D, M). For each item ``X = features[d]``
    this computes ``s = X @ X.T / M`` and stores
    ``power_norm(s / trace(s), SIGMA)`` in an (N, 1, D, D) tensor on the same
    device as ``features``.
    """
    count, dim = features.size(0), features.size(1)
    out = torch.empty(count, 1, dim, dim, device=features.device)
    for d in range(count):
        s = features[d, :, :].squeeze(0)
        s = (1.0 / features.size(2)) * s.mm(s.t())
        out[d, :, :, :] = power_norm(s / s.trace(), SIGMA)
    return out


def main():
    """Evaluate HOP-feature similarity networks on Omniglot query tasks.

    Builds the feature encoder and similarity network on ``GPU``, restores
    checkpoints from the ``METHOD`` directory when present (creating the
    directory if it does not exist), then for each of ``EPISODE`` rounds
    measures accuracy over ``TEST_EPISODE`` randomly rotated query tasks and
    tracks the best round's accuracy.
    """
    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    metatrain_character_folders, metaquery_character_folders = tg.omniglot_character_folders()

    # Step 2: init neural networks
    print("init neural networks")

    feature_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    relation_network = models.SimilarityNetwork(FEATURE_DIM, RELATION_DIM).apply(weights_init).cuda(GPU)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(), lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim, step_size=50000, gamma=0.1)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(), lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim, step_size=50000, gamma=0.1)

    suffix = str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl"
    encoder_path = str(METHOD + "/omniglot_feature_encoder_" + suffix)
    similarity_path = str(METHOD + "/omniglot_similarity_network_" + suffix)

    if os.path.exists(encoder_path):
        feature_encoder.load_state_dict(torch.load(encoder_path))
        print("load feature encoder success")
    if os.path.exists(similarity_path):
        relation_network.load_state_dict(torch.load(similarity_path))
        print("load similarity network success")
    # Create the checkpoint directory directly rather than shelling out to
    # `mkdir` (no shell parsing of METHOD; missing parents are created too).
    os.makedirs(METHOD, exist_ok=True)

    # Step 3: build graph
    print("Training...")

    best_accuracy = 0.0
    best_h = 0.0
    num_queries = CLASS_NUM * TEST_NUM_PER_CLASS

    for episode in range(EPISODE):
        with torch.no_grad():
            # query
            print("Testing...")
            total_rewards = 0

            for i in range(TEST_EPISODE):
                degrees = random.choice([0, 90, 180, 270])
                task = tg.OmniglotTask(metaquery_character_folders, CLASS_NUM, SUPPORT_NUM_PER_CLASS, TEST_NUM_PER_CLASS,)
                support_dataloader = tg.get_data_loader(task, num_per_class=SUPPORT_NUM_PER_CLASS, split="train", shuffle=False, rotation=degrees)
                query_dataloader = tg.get_data_loader(task, num_per_class=TEST_NUM_PER_CLASS, split="query", shuffle=True, rotation=degrees)

                # next(iter(...)): DataLoader iterators in current PyTorch
                # have no `.next()` method.
                support_images, _ = next(iter(support_dataloader))
                query_images, query_labels = next(iter(query_dataloader))
                # Move labels once instead of per comparison.
                query_labels = query_labels.cuda(GPU)

                # calculate features: per-class sum of (FEATURE_DIM, 25) maps.
                # FEATURE_DIM replaces the former hard-coded 64; the two had to
                # be equal for the original code to run at all.
                support_features = feature_encoder(support_images.cuda(GPU))
                support_features = support_features.view(CLASS_NUM, SUPPORT_NUM_PER_CLASS, FEATURE_DIM, 25).sum(1)
                query_features = feature_encoder(query_images.cuda(GPU)).view(num_queries, FEATURE_DIM, 25)

                # HOP (second-order) features: (N, 1, FEATURE_DIM, FEATURE_DIM)
                H_support_features = _hop_second_order_features(support_features)
                H_query_features = _hop_second_order_features(query_features)

                # calculate relations: pair each query with every class
                support_features_ext = H_support_features.unsqueeze(0).repeat(num_queries, 1, 1, 1, 1)
                query_features_ext = H_query_features.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1)
                query_features_ext = torch.transpose(query_features_ext, 0, 1)

                relation_pairs = torch.cat((support_features_ext, query_features_ext), 2).view(-1, 2, FEATURE_DIM, FEATURE_DIM)
                relations = relation_network(relation_pairs).view(-1, CLASS_NUM)

                _, predict_labels = torch.max(relations, 1)

                rewards = [1 if predict_labels[j] == query_labels[j] else 0 for j in range(num_queries)]

                total_rewards += np.sum(rewards)

            test_accuracy = total_rewards / 1.0 / CLASS_NUM / TEST_NUM_PER_CLASS / TEST_EPISODE

            print("query accuracy:", test_accuracy)
            print("best accuracy:", best_accuracy)

            if test_accuracy > best_accuracy:
                best_accuracy = test_accuracy
Ejemplo n.º 12
0
def _relation_predict_labels(feature_encoder, relation_network, support_inputs,
                             query_inputs, config):
    """Return CPU arg-max class indices for each query.

    Support embeddings are flattened per sample and summed per class
    (``view(CLASS_NUM, SAMPLE_NUM_PER_CLASS, -1)``). Every query embedding is
    paired with every class embedding, the relation scores are reshaped to
    ``(-1, CLASS_NUM)``, and the arg-max is taken along the class axis.
    """
    class_num = config["CLASS_NUM"]
    sample_num = config["SAMPLE_NUM_PER_CLASS"]

    sample_features = feature_encoder(support_inputs.to(device))
    sample_features = sample_features.view(class_num, sample_num, -1)
    sample_features = torch.sum(sample_features, 1).squeeze(1)
    test_features = feature_encoder(query_inputs.to(device))

    sample_features_ext = sample_features.unsqueeze(0).repeat(
        sample_num * class_num, 1, 1)
    test_features_ext = test_features.unsqueeze(0).repeat(class_num, 1, 1)
    test_features_ext = torch.transpose(test_features_ext, 0, 1)

    relations = relation_network(sample_features_ext, test_features_ext)
    relations = relations.view(-1, class_num)

    _, predict_labels = torch.max(relations, 1)
    return predict_labels.cpu()


def _print_probe_sentence_prediction(feature_encoder, relation_network,
                                     test_data, config, word2index):
    """Debug aid: classify a fixed probe sentence and print the result.

    Draws ``CLASS_NUM * SAMPLE_NUM_PER_CLASS // len(categories)`` support
    sentences from every category of ``test_data`` (category order follows
    ``test_data.keys()``), encodes the probe sentence character by character,
    and prints the predicted class indices.
    """
    sentence = "目的地改为哈尔滨"
    keys = test_data.keys()
    support_inputs = []
    choice_num = config["CLASS_NUM"] * config[
        "SAMPLE_NUM_PER_CLASS"] // len(keys)
    for category in keys:
        for support_sentence in random.sample(test_data[category], choice_num):
            support_inputs.append(
                sentence2indices(support_sentence, word2index,
                                 config["max_len"], Constants.PAD))
    support_inputs = torch.tensor(support_inputs)

    # Character-level ids for the probe, zero-padded to length 12.
    # NOTE(review): this reads config["word2index"] rather than the
    # `word2index` argument and pads to 12 rather than config["max_len"];
    # kept as-is -- confirm they are meant to agree.
    sentence_id = [
        config["word2index"].get(word, Constants.PAD) for word in sentence
    ]
    sentence_id += [0] * (12 - len(sentence_id))
    test_images = torch.tensor([sentence_id]).repeat(
        config["CLASS_NUM"] * config["SAMPLE_NUM_PER_CLASS"], 1)

    predict_labels = _relation_predict_labels(feature_encoder,
                                              relation_network,
                                              support_inputs, test_images,
                                              config)
    print("预测概率为:", predict_labels)
    print("预测值为", predict_labels, keys)


def valid(feature_encoder, relation_network, test_data, config, word2index):
    """Return the mean query accuracy over ``config["TEST_EPISODE"]`` tasks.

    Both networks are switched to eval mode (and left that way). Each episode
    draws an ``OmniglotTask`` in "test" mode from ``test_data`` and scores the
    first ``CLASS_NUM * SAMPLE_NUM_PER_CLASS`` query predictions. Every 100th
    episode also prints a prediction for a fixed probe sentence (debug
    output; it does not affect the returned accuracy). The accuracy is also
    printed.
    """
    feature_encoder.eval()
    relation_network.eval()

    class_num = config["CLASS_NUM"]
    sample_num = config["SAMPLE_NUM_PER_CLASS"]
    total_rewards = 0

    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        # The number of episodes differs between training and testing.
        for i in range(config["TEST_EPISODE"]):
            task = OmniglotTask(test_data, class_num, sample_num,
                                config["BATCH_NUM_PER_CLASS"], "test")
            sample_dataloader = get_data_loader(
                task,
                config,
                num_per_class=sample_num,
                split="train",
                shuffle=False)
            test_dataloader = get_data_loader(
                task,
                config,
                num_per_class=sample_num,
                split="test",
                shuffle=True)

            # next(iter(...)) rather than .__iter__().next(): DataLoader
            # iterators in current PyTorch have no `.next()` method.
            sample_images, _, _ = next(iter(sample_dataloader))
            test_images, test_labels, _ = next(iter(test_dataloader))

            predict_labels = _relation_predict_labels(feature_encoder,
                                                      relation_network,
                                                      sample_images,
                                                      test_images, config)
            rewards = [
                1 if int(predict_labels[j]) == int(test_labels[j]) else 0
                for j in range(class_num * sample_num)
            ]
            if i % 100 == 0:
                _print_probe_sentence_prediction(feature_encoder,
                                                 relation_network, test_data,
                                                 config, word2index)

            total_rewards += np.sum(rewards)

    test_accuracy = total_rewards / 1.0 / class_num / sample_num / config[
        "TEST_EPISODE"]
    print("test accuracy:", test_accuracy)

    return test_accuracy
def main():
    """Evaluate saved Omniglot models on meta-test tasks and print accuracy.

    Runs ``EPISODE`` rounds of ``TEST_EPISODE`` randomly rotated tasks. Each
    round prints its mean accuracy with the half-width ``h`` from
    ``mean_confidence_interval``; the average over rounds is printed last.
    """
    # Step 1: init data folders
    print("init data folders")
    # Only the meta-test split is needed for evaluation.
    _, metatest_character_folders = tg.omniglot_character_folders()

    # Step 2: init neural networks
    print("init neural networks")
    feature_encoder = CNNEncoder()
    relation_network = RelationNetwork(RELATION1_DIM, RELATION2_DIM,
                                       RELATION3_DIM)

    suffix = str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"
    encoder_path = "./models/omniglot_feature_encoder_" + suffix
    relation_path = "./models/omniglot_relation_network_" + suffix

    # The models stay on the CPU here; map_location lets checkpoints saved
    # from a GPU run load on a CPU-only machine.
    if os.path.exists(encoder_path):
        feature_encoder.load_state_dict(
            torch.load(encoder_path, map_location="cpu"))
        print("load feature encoder success")
    if os.path.exists(relation_path):
        relation_network.load_state_dict(
            torch.load(relation_path, map_location="cpu"))
        print("load relation network success")

    num_queries = CLASS_NUM * SAMPLE_NUM_PER_CLASS
    total_accuracy = 0.0
    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for episode in range(EPISODE):
            # test
            print("Testing...")
            accuracies = []
            for i in range(TEST_EPISODE):
                degrees = random.choice([0, 90, 180, 270])
                task = tg.OmniglotTask(metatest_character_folders, CLASS_NUM,
                                       SAMPLE_NUM_PER_CLASS,
                                       SAMPLE_NUM_PER_CLASS)
                sample_dataloader = tg.get_data_loader(
                    task,
                    num_per_class=SAMPLE_NUM_PER_CLASS,
                    split="train",
                    shuffle=False,
                    rotation=degrees)
                test_dataloader = tg.get_data_loader(
                    task,
                    num_per_class=SAMPLE_NUM_PER_CLASS,
                    split="test",
                    shuffle=True,
                    rotation=degrees)

                # next(iter(...)): DataLoader iterators in current PyTorch
                # have no `.next()` method.
                sample_images, _ = next(iter(sample_dataloader))
                test_images, test_labels = next(iter(test_dataloader))

                sample_features = feature_encoder(sample_images)
                test_features = feature_encoder(test_images)

                # NOTE(review): supports are not summed per class, and the
                # view(-1, 1600) assumes flat 1600-dim encoder outputs; the
                # (-1, CLASS_NUM) reshape below is only meaningful for
                # 1-shot tasks (SAMPLE_NUM_PER_CLASS == 1) -- confirm.
                sample_features_ext = sample_features.unsqueeze(0).repeat(
                    num_queries, 1, 1)
                test_features_ext = test_features.unsqueeze(0).repeat(
                    num_queries, 1, 1)
                test_features_ext = torch.transpose(test_features_ext, 0, 1)

                relation_pairs = torch.abs(
                    (sample_features_ext - test_features_ext)).view(-1, 1600)
                relations = relation_network(relation_pairs).view(
                    -1, CLASS_NUM)

                _, predict_labels = torch.max(relations, 1)
                test_labels = test_labels.long()
                # Score every query so the count matches the divisor below
                # (identical to the former range(CLASS_NUM) when 1-shot).
                rewards = [
                    1 if predict_labels[j] == test_labels[j] else 0
                    for j in range(num_queries)
                ]

                accuracy = np.sum(
                    rewards) / 1.0 / CLASS_NUM / SAMPLE_NUM_PER_CLASS
                accuracies.append(accuracy)

            test_accuracy, h = mean_confidence_interval(accuracies)

            print("test accuracy:", test_accuracy, "h:", h)
            total_accuracy += test_accuracy

    print("aver_accuracy:", total_accuracy / EPISODE)
def _omniglot_checkpoint_path(kind):
    """Return ``./models/omniglot_<kind>_<CLASS_NUM>way_<SAMPLE_NUM_PER_CLASS>shot.pkl``."""
    return str("./models/omniglot_" + kind + "_" + str(CLASS_NUM) + "way_" +
               str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")


def _omniglot_metatest_accuracy(feature_encoder, relation_network,
                                metatest_character_folders):
    """Return the query accuracy over ``TEST_EPISODE`` rotated meta-test tasks.

    Uses the same absolute-difference relation pairing as training. The
    networks' train/eval mode is not changed here.
    """
    num_queries = CLASS_NUM * SAMPLE_NUM_PER_CLASS
    total_rewards = 0
    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for i in range(TEST_EPISODE):
            degrees = random.choice([0, 90, 180, 270])
            task = tg.OmniglotTask(
                metatest_character_folders,
                CLASS_NUM,
                SAMPLE_NUM_PER_CLASS,
                SAMPLE_NUM_PER_CLASS,
            )
            sample_dataloader = tg.get_data_loader(
                task,
                num_per_class=SAMPLE_NUM_PER_CLASS,
                split="train",
                shuffle=False,
                rotation=degrees)
            test_dataloader = tg.get_data_loader(
                task,
                num_per_class=SAMPLE_NUM_PER_CLASS,
                split="test",
                shuffle=True,
                rotation=degrees)

            # next(iter(...)): DataLoader iterators in current PyTorch have
            # no `.next()` method.
            sample_images, _ = next(iter(sample_dataloader))
            test_images, test_labels = next(iter(test_dataloader))
            test_labels = test_labels.long()

            sample_features = feature_encoder(sample_images)
            test_features = feature_encoder(test_images)

            # Pair each query with each support sample.
            # NOTE(review): repeat(..., 1, 1) and view(-1, 1600) assume flat
            # (N, 1600) encoder outputs, and the (-1, CLASS_NUM) reshape is
            # only meaningful for 1-shot tasks -- confirm.
            sample_features_ext = sample_features.unsqueeze(0).repeat(
                num_queries, 1, 1)
            test_features_ext = test_features.unsqueeze(0).repeat(
                num_queries, 1, 1)
            test_features_ext = torch.transpose(test_features_ext, 0, 1)
            relation_pairs = torch.abs(
                (sample_features_ext - test_features_ext)).view(-1, 1600)
            relations = relation_network(relation_pairs).view(-1, CLASS_NUM)

            _, predict_labels = torch.max(relations, 1)
            # Score every query so the count matches the divisor below
            # (identical to the former range(CLASS_NUM) when 1-shot).
            rewards = [
                1 if predict_labels[j] == test_labels[j] else 0
                for j in range(num_queries)
            ]
            total_rewards += np.sum(rewards)

    return total_rewards / 1.0 / num_queries / TEST_EPISODE


def main():
    """Train an Omniglot relation network and checkpoint the best model.

    Each episode draws a rotated meta-train task, embeds supports and queries,
    regresses relation scores of absolute feature differences onto one-hot
    labels with MSE, and updates both networks (gradient norm clipped to 0.5).
    Every 100 episodes the meta-test accuracy is measured; whenever it
    improves, both state dicts are saved under ``./models``.
    """
    # Step 1: init data folders
    print("init data folders")
    metatrain_character_folders, metatest_character_folders = tg.omniglot_character_folders(
    )

    # Step 2: init neural networks
    print("init neural networks")
    if USE_INCEPTION_EMBEDDING:
        feature_encoder = Inception(1, 10)
    else:
        feature_encoder = CNNEncoder()
    relation_network = RelationNetwork(RELATION1_DIM, RELATION2_DIM,
                                       RELATION3_DIM)
    # Initialize weights via Module.apply().
    feature_encoder.apply(weights_init)
    relation_network.apply(weights_init)

    # Adam optimizers; StepLR multiplies the learning rate by gamma (0.5)
    # every step_size (100000) episodes.
    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,
                                       step_size=100000,
                                       gamma=0.5)

    relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                              lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,
                                        step_size=100000,
                                        gamma=0.5)

    encoder_path = _omniglot_checkpoint_path("feature_encoder")
    relation_path = _omniglot_checkpoint_path("relation_network")
    # The models stay on the CPU here; map_location lets checkpoints saved
    # from a GPU run load on a CPU-only machine.
    if os.path.exists(encoder_path):
        feature_encoder.load_state_dict(
            torch.load(encoder_path, map_location="cpu"))
        print("load feature encoder success")
    if os.path.exists(relation_path):
        relation_network.load_state_dict(
            torch.load(relation_path, map_location="cpu"))
        print("load relation network success")

    # Step 3: build graph
    print("Training...")
    last_accuracy = 0.0
    mse = nn.MSELoss()
    for episode in range(EPISODE):
        # Set this episode's learning rate from the StepLR schedule.
        feature_encoder_scheduler.step(episode)
        relation_network_scheduler.step(episode)

        # sample_dataloader yields the support set used for comparison;
        # batch_dataloader yields the query batch used for training.
        degrees = random.choice([0, 90, 180, 270])
        task = tg.OmniglotTask(metatrain_character_folders, CLASS_NUM,
                               SAMPLE_NUM_PER_CLASS, BATCH_NUM_PER_CLASS)
        sample_dataloader = tg.get_data_loader(
            task,
            num_per_class=SAMPLE_NUM_PER_CLASS,
            split="train",
            shuffle=False,
            rotation=degrees)

        batch_dataloader = tg.get_data_loader(
            task,
            num_per_class=BATCH_NUM_PER_CLASS,
            split="test",
            shuffle=True,
            rotation=degrees)

        # next(iter(...)): DataLoader iterators in current PyTorch have no
        # `.next()` method.
        samples, _ = next(iter(sample_dataloader))
        batches, batch_labels = next(iter(batch_dataloader))

        # Extract features.
        sample_features = feature_encoder(samples)
        batch_features = feature_encoder(batches)

        # Pair every query with every support sample.
        # NOTE(review): repeat(..., 1, 1) and view(-1, 1600) assume flat
        # (N, 1600) encoder outputs -- confirm against the encoder.
        sample_features_ext = sample_features.unsqueeze(0).repeat(
            BATCH_NUM_PER_CLASS * CLASS_NUM, 1, 1)
        batch_features_ext = batch_features.unsqueeze(0).repeat(
            SAMPLE_NUM_PER_CLASS * CLASS_NUM, 1, 1)
        batch_features_ext = torch.transpose(batch_features_ext, 0, 1)
        # Metric learning on the element-wise absolute difference.
        relation_pairs = torch.abs(
            (sample_features_ext - batch_features_ext)).view(-1, 1600)
        relations = relation_network(relation_pairs).view(-1, CLASS_NUM)

        # Objective: MSE between relation scores and one-hot query labels.
        change = batch_labels.view(-1, 1).long()
        one_hot_labels = torch.zeros(BATCH_NUM_PER_CLASS * CLASS_NUM,
                                     CLASS_NUM).scatter_(1, change, 1)

        loss = mse(relations, one_hot_labels)

        # Zero all gradients, backpropagate, clip, then update parameters.
        feature_encoder.zero_grad()
        relation_network.zero_grad()

        loss.backward()
        # clip_grad_norm_ is the non-deprecated (in-place) clipping API.
        torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)
        feature_encoder_optim.step()
        relation_network_optim.step()

        if (episode + 1) % 10 == 0:
            # loss is a 0-dim tensor; `loss.data[0]` raises IndexError in
            # current PyTorch, `.item()` returns the Python number.
            print("episode:", episode + 1, "loss", loss.item())

        if (episode + 1) % 100 == 0:

            # test
            print("Testing...")
            test_accuracy = _omniglot_metatest_accuracy(
                feature_encoder, relation_network, metatest_character_folders)
            print("test accuracy:", test_accuracy)

            if test_accuracy > last_accuracy:
                # save networks (torch.save does not create the directory)
                os.makedirs("./models", exist_ok=True)
                torch.save(feature_encoder.state_dict(), encoder_path)
                torch.save(relation_network.state_dict(), relation_path)
                print("save networks for episode:", episode)
                last_accuracy = test_accuracy
Ejemplo n.º 15
0
def _model_path(net_name):
    """Return the checkpoint path used for network ``net_name``.

    The path encodes the current CLASS_NUM-way / SAMPLE_NUM_PER_CLASS-shot
    setting, e.g. ``./models/omniglot_feature_encoder_5way_1shot.pkl``.
    """
    return "./models/omniglot_{}_{}way_{}shot.pkl".format(
        net_name, CLASS_NUM, SAMPLE_NUM_PER_CLASS)


def _load_if_exists(network, net_name):
    """Load ``network``'s weights from its checkpoint if that file exists."""
    path = _model_path(net_name)
    if os.path.exists(path):
        network.load_state_dict(torch.load(path))
        # "feature_encoder" -> "load feature encoder success"
        print("load " + net_name.replace("_", " ") + " success")


def _save_networks(feature_encoder, relation_network):
    """Save both networks' state dicts, creating the checkpoint directory if needed."""
    checkpoint_dir = os.path.dirname(_model_path("feature_encoder"))
    # torch.save does not create missing directories; without this the
    # first "best accuracy" save crashes on a fresh checkout.
    if checkpoint_dir and not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    torch.save(feature_encoder.state_dict(), _model_path("feature_encoder"))
    torch.save(relation_network.state_dict(), _model_path("relation_network"))


def _compute_relations(feature_encoder, relation_network, support_images,
                       query_images):
    """Score every query image against every class of the support set.

    Args:
        feature_encoder: network mapping an image batch to feature maps.
        relation_network: network scoring concatenated (support, query)
            feature-map pairs.
        support_images: CLASS_NUM * SAMPLE_NUM_PER_CLASS support images.
            Presumed grouped by class (the support loaders are built with
            shuffle=False) -- TODO confirm against tg.get_data_loader. For
            1-shot the grouping does not matter.
        query_images: the query images to classify.

    Returns:
        Tensor of shape (num_queries, CLASS_NUM) with one relation score per
        (query, class) pair.
    """
    support_features = feature_encoder(support_images.cuda(GPU))
    # Pool the K shots of each class into one class feature by summing them;
    # with SAMPLE_NUM_PER_CLASS == 1 this is an identity operation.
    support_features = support_features.view(
        CLASS_NUM, SAMPLE_NUM_PER_CLASS, *support_features.shape[1:]).sum(dim=1)
    query_features = feature_encoder(query_images.cuda(GPU))
    num_queries = query_features.size(0)

    # Pair every query with every class: both tensors become
    # [num_queries, CLASS_NUM, channels, H, W].
    support_ext = support_features.unsqueeze(0).repeat(num_queries, 1, 1, 1, 1)
    query_ext = query_features.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1)
    query_ext = torch.transpose(query_ext, 0, 1)

    # Concatenate along the channel (depth) axis, then flatten the pair axes:
    # [num_queries * CLASS_NUM, 2 * channels, H, W].
    relation_pairs = torch.cat((support_ext, query_ext), 2)
    relation_pairs = relation_pairs.view(-1, *relation_pairs.shape[2:])
    return relation_network(relation_pairs).view(-1, CLASS_NUM)


def _evaluate(feature_encoder, relation_network, metatest_character_folders):
    """Return the mean query accuracy over TEST_EPISODE random meta-test tasks.

    Each task uses SAMPLE_NUM_PER_CLASS support and SAMPLE_NUM_PER_CLASS query
    images per class, with one random rotation shared by the whole task.
    """
    correct = 0
    total = 0
    # No gradients are needed for evaluation; this avoids building an
    # autograd graph for every test episode.
    with torch.no_grad():
        for _ in range(TEST_EPISODE):
            degrees = random.choice([0, 90, 180, 270])
            task = tg.OmniglotTask(
                metatest_character_folders,
                CLASS_NUM,
                SAMPLE_NUM_PER_CLASS,
                SAMPLE_NUM_PER_CLASS,
            )
            sample_dataloader = tg.get_data_loader(
                task,
                num_per_class=SAMPLE_NUM_PER_CLASS,
                split="train",
                shuffle=False,
                rotation=degrees)
            test_dataloader = tg.get_data_loader(
                task,
                num_per_class=SAMPLE_NUM_PER_CLASS,
                split="test",
                shuffle=True,
                rotation=degrees)

            sample_images, _ = next(iter(sample_dataloader))
            test_images, test_labels = next(iter(test_dataloader))

            relations = _compute_relations(feature_encoder, relation_network,
                                           sample_images, test_images)
            # The predicted class is the one with the highest relation score.
            _, predict_labels = torch.max(relations, 1)

            # Count every query, not just the first CLASS_NUM of them.
            correct += (predict_labels.cpu() ==
                        test_labels.view(-1).long()).sum().item()
            total += test_labels.numel()
    return correct / float(total) if total else 0.0


def main():
    """Meta-train the feature encoder and relation network on Omniglot.

    Every 5000 episodes the networks are evaluated on meta-test tasks. They
    are checkpointed whenever test accuracy improves on the best value so far
    in this run. Existing checkpoints for the current way/shot setting are
    loaded before training starts.
    """
    # Step 1: init data folders
    print("init data folders")
    # Folder lists for meta-train and meta-test; each folder holds the
    # images of one character class.
    metatrain_character_folders, metatest_character_folders = (
        tg.omniglot_character_folders())

    # Step 2: init neural networks
    print("init neural networks")

    feature_encoder = CNNEncoder()
    relation_network = RelationNetwork(FEATURE_DIM, RELATION_DIM)

    feature_encoder.apply(weights_init)
    relation_network.apply(weights_init)

    feature_encoder.cuda(GPU)
    relation_network.cuda(GPU)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,
                                       step_size=100000,
                                       gamma=0.5)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                              lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,
                                        step_size=100000,
                                        gamma=0.5)

    _load_if_exists(feature_encoder, "feature_encoder")
    _load_if_exists(relation_network, "relation_network")

    # Step 3: build graph
    print("Training...")

    # The loss module is stateless, so create it once instead of per episode.
    mse = nn.MSELoss().cuda(GPU)
    last_accuracy = 0.0

    for episode in range(EPISODE):
        degrees = random.choice([0, 90, 180, 270])

        # One task per episode: CLASS_NUM classes drawn from the meta-train
        # folders, SAMPLE_NUM_PER_CLASS support and BATCH_NUM_PER_CLASS query
        # images per class.
        task = tg.OmniglotTask(metatrain_character_folders, CLASS_NUM,
                               SAMPLE_NUM_PER_CLASS, BATCH_NUM_PER_CLASS)

        sample_dataloader = tg.get_data_loader(
            task,
            num_per_class=SAMPLE_NUM_PER_CLASS,
            split="train",
            shuffle=False,
            rotation=degrees)
        batch_dataloader = tg.get_data_loader(
            task,
            num_per_class=BATCH_NUM_PER_CLASS,
            split="test",
            shuffle=True,
            rotation=degrees)

        # next(iter(...)) works on all PyTorch versions; the iterator's
        # .next() alias was removed in newer releases.
        samples, _ = next(iter(sample_dataloader))
        batches, batch_labels = next(iter(batch_dataloader))

        relations = _compute_relations(feature_encoder, relation_network,
                                       samples, batches)

        # Relation scores are regressed against one-hot targets with MSE.
        one_hot_labels = torch.zeros(batch_labels.numel(), CLASS_NUM).scatter_(
            1, batch_labels.view(-1, 1).long(), 1).cuda(GPU)
        loss = mse(relations, one_hot_labels)

        # training
        feature_encoder.zero_grad()
        relation_network.zero_grad()

        loss.backward()

        torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)

        feature_encoder_optim.step()
        relation_network_optim.step()
        # Step the schedulers after the optimizers. This yields the same
        # learning rate per episode as the old step(episode) call, without
        # the deprecated epoch argument or the ordering warning.
        feature_encoder_scheduler.step()
        relation_network_scheduler.step()

        if (episode + 1) % 100 == 0:
            print("episode:", episode + 1, "loss", loss.item())

        if (episode + 1) % 5000 == 0:
            # test
            print("Testing...")
            test_accuracy = _evaluate(feature_encoder, relation_network,
                                      metatest_character_folders)
            print("test accuracy:", test_accuracy)

            if test_accuracy > last_accuracy:
                # save networks
                _save_networks(feature_encoder, relation_network)
                print("save networks for episode:", episode)
                last_accuracy = test_accuracy