def main():
    """Train the Omniglot feature encoder and relation network episodically.

    Each episode samples one task from the meta-training character folders,
    scores every query image against the per-class summed support features,
    and fits those scores to one-hot labels with an MSE loss.  Every 500
    episodes both networks are scored on ``TEST_EPISODE`` tasks drawn from
    the meta-test folders; their state dicts are written to ``./models``
    whenever the test accuracy beats the best value seen in this run.

    Existing checkpoints for the current ``CLASS_NUM``/``SAMPLE_NUM_PER_CLASS``
    setting are loaded before training starts.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # * Step 1: init data folders
    print("init data folders")

    # * Init character folders for dataset construction
    (metatrain_character_folders,
     metatest_character_folders) = tg.omniglot_character_folders()

    # * Step 2: init neural networks
    print("init neural networks")

    feature_encoder = CNNEncoder()
    relation_network = RelationNetwork(FEATURE_DIM, RELATION_DIM)

    feature_encoder.apply(weights_init)
    relation_network.apply(weights_init)

    feature_encoder.to(device)
    relation_network.to(device)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,
                                       step_size=100000,
                                       gamma=0.5)

    relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                              lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,
                                        step_size=100000,
                                        gamma=0.5)

    # Build every checkpoint path once so load and save cannot drift apart.
    model_dir = "./models"
    task_tag = (str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) +
                "shot.pkl")
    feature_encoder_path = model_dir + "/omniglot_feature_encoder_" + task_tag
    relation_network_path = (model_dir + "/omniglot_relation_network_" +
                             task_tag)

    # map_location lets a checkpoint written on a GPU host be restored on a
    # CPU-only host (and vice versa).
    if os.path.exists(feature_encoder_path):
        feature_encoder.load_state_dict(
            torch.load(feature_encoder_path, map_location=device))
        print("load feature encoder success")
    if os.path.exists(relation_network_path):
        relation_network.load_state_dict(
            torch.load(relation_network_path, map_location=device))
        print("load relation network success")

    def relation_scores(support_images, query_images, num_queries):
        """Return the relation scores reshaped to ``(-1, CLASS_NUM)``.

        The support set is encoded first and the query set second, matching
        the order of the forward passes in the training loop.
        """
        support_features = feature_encoder(support_images)
        support_features = support_features.view(CLASS_NUM,
                                                 SAMPLE_NUM_PER_CLASS,
                                                 FEATURE_DIM, 5, 5)
        # Sum the shots of each class into a single class representation.
        support_features = torch.sum(support_features, 1).squeeze(1)
        query_features = feature_encoder(query_images)

        # Pair every query with every class representation.
        support_ext = support_features.unsqueeze(0).repeat(
            num_queries, 1, 1, 1, 1)
        query_ext = query_features.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1)
        query_ext = torch.transpose(query_ext, 0, 1)

        pairs = torch.cat((support_ext, query_ext),
                          2).view(-1, FEATURE_DIM * 2, 5, 5)
        return relation_network(pairs).view(-1, CLASS_NUM)

    # * Step 3: build graph
    print("Training...")

    mse = nn.MSELoss()
    num_train_queries = BATCH_NUM_PER_CLASS * CLASS_NUM
    num_test_queries = SAMPLE_NUM_PER_CLASS * CLASS_NUM
    last_accuracy = 0.0

    for episode in range(EPISODE):

        # * init dataset
        # * sample_dataloader is to obtain previous samples for compare
        # * batch_dataloader is to batch samples for training
        degrees = random.choice([0, 90, 180, 270])
        task = tg.OmniglotTask(metatrain_character_folders, CLASS_NUM,
                               SAMPLE_NUM_PER_CLASS, BATCH_NUM_PER_CLASS)
        sample_dataloader = tg.get_data_loader(
            task,
            num_per_class=SAMPLE_NUM_PER_CLASS,
            split="train",
            shuffle=False,
            rotation=degrees)
        batch_dataloader = tg.get_data_loader(
            task,
            num_per_class=BATCH_NUM_PER_CLASS,
            split="test",
            shuffle=True,
            rotation=degrees)

        # * sample datas (support labels are not needed for the loss)
        samples, _ = next(iter(sample_dataloader))
        batches, batch_labels = next(iter(batch_dataloader))

        samples = samples.to(device)
        batches, batch_labels = batches.to(device), batch_labels.to(device)

        # * calculate relations
        relations = relation_scores(samples, batches, num_train_queries)

        one_hot_labels = torch.zeros(num_train_queries,
                                     CLASS_NUM).to(device).scatter_(
                                         1, batch_labels.view(-1, 1), 1)
        loss = mse(relations, one_hot_labels)

        feature_encoder.zero_grad()
        relation_network.zero_grad()

        loss.backward()

        torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)

        feature_encoder_optim.step()
        relation_network_optim.step()

        # NOTE(review): passing the epoch to step() is deprecated in recent
        # PyTorch releases; it is kept so the decay stays tied to the episode
        # index exactly as before.
        feature_encoder_scheduler.step(episode)
        relation_network_scheduler.step(episode)

        if (episode + 1) % 100 == 0:
            print(
                f"episode : {episode+1}, loss : {loss.cpu().detach().numpy()}")

        if (episode + 1) % 500 == 0:
            print("Testing...")
            total_reward = 0

            # No gradients are needed while scoring, so skip building the
            # autograd graph for the test episodes.
            with torch.no_grad():
                for _ in range(TEST_EPISODE):
                    degrees = random.choice([0, 90, 180, 270])
                    task = tg.OmniglotTask(metatest_character_folders,
                                           CLASS_NUM, SAMPLE_NUM_PER_CLASS,
                                           SAMPLE_NUM_PER_CLASS)
                    sample_dataloader = tg.get_data_loader(
                        task,
                        num_per_class=SAMPLE_NUM_PER_CLASS,
                        split="train",
                        shuffle=False,
                        rotation=degrees)
                    test_dataloader = tg.get_data_loader(
                        task,
                        num_per_class=SAMPLE_NUM_PER_CLASS,
                        split="test",
                        shuffle=True,
                        rotation=degrees)

                    sample_images, _ = next(iter(sample_dataloader))
                    test_images, test_labels = next(iter(test_dataloader))

                    sample_images = sample_images.to(device)
                    test_images = test_images.to(device)
                    test_labels = test_labels.to(device)

                    relations = relation_scores(sample_images, test_images,
                                                num_test_queries)

                    _, predict_labels = torch.max(relations.data, 1)

                    # One vectorised comparison instead of a Python loop that
                    # indexes the tensors element by element.
                    total_reward += (
                        predict_labels == test_labels.view(-1)).sum().item()

            test_accuracy = total_reward / (
                1.0 * CLASS_NUM * SAMPLE_NUM_PER_CLASS * TEST_EPISODE)

            print("test accuracy : ", test_accuracy)

            if test_accuracy > last_accuracy:
                # save networks (create the folder on first use)
                os.makedirs(model_dir, exist_ok=True)
                torch.save(feature_encoder.state_dict(), feature_encoder_path)
                torch.save(relation_network.state_dict(),
                           relation_network_path)

                print("save networks for episode:", episode)

                last_accuracy = test_accuracy
def main():
    """Train the relation network with a random-forest auxiliary loss.

    Each episode samples one meta-training task and computes relation scores
    between every query image and the per-class summed support features.  A
    warm-started ``RandomForestClassifier`` is refitted on the detached
    scores every 10th episode; its predicted labels give an auxiliary
    cross-entropy term (weighted 0.5) that is added to the cross-entropy
    against the true labels.

    Every 500 episodes the networks are scored on ``TEST_EPISODE`` meta-test
    tasks, the forest is scored on every 200th of those tasks, the averaged
    losses and accuracies are logged through ``writer``, and the forest and
    the networks are saved to ``./models`` whenever their respective accuracy
    improves on the best value seen in this run.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # * Step 1: init data folders
    print("init data folders")

    # * Init character folders for dataset construction
    (metatrain_character_folders,
     metatest_character_folders) = tg.omniglot_character_folders()

    # * Step 2: init neural networks
    print("init neural networks")

    feature_encoder = CNNEncoder()

    # warm_start=True makes later fit() calls keep the existing trees and
    # only add the estimators that n_estimators has grown by.
    # Author's note: when the forest has high bias and low variance it
    # dominates the accuracy once combined with the relation network; a
    # min_samples_leaf=5 variant was tried as an alternative.
    RFT = RandomForestClassifier(n_estimators=100, random_state=1, n_jobs=-1,
                                 warm_start=True)

    relation_network = RelationNetwork(FEATURE_DIM, RELATION_DIM)

    feature_encoder.apply(weights_init)
    relation_network.apply(weights_init)

    feature_encoder.to(device)
    relation_network.to(device)

    cross_entropy = nn.CrossEntropyLoss()

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,
                                       step_size=100000, gamma=0.5)

    relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                              lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,
                                        step_size=100000, gamma=0.5)

    # Build every checkpoint path once so load and save cannot drift apart.
    model_dir = "./models"
    task_tag = (str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) +
                "shot.pkl")
    feature_encoder_path = model_dir + "/omniglot_feature_encoder_" + task_tag
    random_forest_path = model_dir + "/omniglot_random_forest_" + task_tag
    relation_network_path = (model_dir + "/omniglot_relation_network_" +
                             task_tag)

    # map_location lets a checkpoint written on a GPU host be restored on a
    # CPU-only host (and vice versa).
    if os.path.exists(feature_encoder_path):
        feature_encoder.load_state_dict(
            torch.load(feature_encoder_path, map_location=device))
        print("load feature encoder success")

    if os.path.exists(random_forest_path):
        # SECURITY: pickle executes arbitrary code on load; only use forest
        # files produced by this script.
        with open(random_forest_path, 'rb') as forest_file:
            RFT = pickle.load(forest_file)
        print("load random forest success")

    if os.path.exists(relation_network_path):
        relation_network.load_state_dict(
            torch.load(relation_network_path, map_location=device))
        print("load relation network success")

    def relation_scores(support_images, query_images, num_queries):
        """Return the relation scores reshaped to ``(-1, CLASS_NUM)``.

        The encoder returns a pair; only its second element (the feature
        map) is used.  The support set is encoded before the query set.
        """
        _, support_features = feature_encoder(support_images)
        support_features = support_features.view(CLASS_NUM,
                                                 SAMPLE_NUM_PER_CLASS,
                                                 FEATURE_DIM, 5, 5)
        # Sum the shots of each class into a single class representation.
        support_features = torch.sum(support_features, 1).squeeze(1)
        _, query_features = feature_encoder(query_images)

        # Pair every query with every class representation.
        support_ext = support_features.unsqueeze(0).repeat(
            num_queries, 1, 1, 1, 1)
        query_ext = query_features.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1)
        query_ext = torch.transpose(query_ext, 0, 1)

        pairs = torch.cat((support_ext, query_ext),
                          2).view(-1, FEATURE_DIM * 2, 5, 5)
        return relation_network(pairs).view(-1, CLASS_NUM)

    # * Step 3: build graph
    print("Training...")

    num_train_queries = BATCH_NUM_PER_CLASS * CLASS_NUM
    num_test_queries = SAMPLE_NUM_PER_CLASS * CLASS_NUM

    last_accuracy = 0.0
    last_RFT_accuracy = 0
    RFT_loss_list = []
    relation_loss_list = []
    loss_list = []

    for episode in range(EPISODE):

        # * init dataset
        # * sample_dataloader is to obtain previous samples for compare
        # * batch_dataloader is to batch samples for training
        degrees = random.choice([0, 90, 180, 270])
        task = tg.OmniglotTask(metatrain_character_folders, CLASS_NUM,
                               SAMPLE_NUM_PER_CLASS, BATCH_NUM_PER_CLASS)
        sample_dataloader = tg.get_data_loader(
            task, num_per_class=SAMPLE_NUM_PER_CLASS, split="train",
            shuffle=False, rotation=degrees)
        batch_dataloader = tg.get_data_loader(
            task, num_per_class=BATCH_NUM_PER_CLASS, split="test",
            shuffle=True, rotation=degrees)

        # * sample datas (support labels are not needed for the loss)
        samples, _ = next(iter(sample_dataloader))
        batches, batch_labels = next(iter(batch_dataloader))

        # The forest is fitted on the labels as delivered by the loader,
        # before they are moved to the training device.
        RFT_batch_labels = batch_labels

        samples = samples.to(device)
        batches, batch_labels = batches.to(device), batch_labels.to(device)

        # * calculate relations
        relations = relation_scores(samples, batches, num_train_queries)

        # scikit-learn consumes the detached scores on the CPU.
        relations_cpu = relations.detach().cpu()
        if episode % 10 == 0:
            RFT.fit(relations_cpu, RFT_batch_labels)
            # Grow the ensemble so the next warm-started fit adds a tree.
            RFT.n_estimators += 1

        RFT_prob = torch.tensor(RFT.predict_proba(relations_cpu)).to(device)
        _, RFT_labels = torch.max(RFT_prob, 1)

        RFT_loss = cross_entropy(relations, RFT_labels) * 0.5
        relation_loss = cross_entropy(relations, batch_labels)
        loss = relation_loss + RFT_loss

        feature_encoder_optim.zero_grad()
        relation_network_optim.zero_grad()

        loss.backward()

        torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)

        feature_encoder_optim.step()
        relation_network_optim.step()

        # NOTE(review): passing the epoch to step() is deprecated in recent
        # PyTorch releases; it is kept so the decay stays tied to the episode
        # index exactly as before.
        feature_encoder_scheduler.step(episode)
        relation_network_scheduler.step(episode)

        if (episode + 1) % 100 == 0:
            print(f"episode : {episode+1}, loss : {loss.cpu().detach().numpy()}")
            loss_list.append(loss.cpu().detach().numpy())
            RFT_loss_list.append(RFT_loss.cpu().detach().numpy())
            relation_loss_list.append(relation_loss.cpu().detach().numpy())

        if (episode + 1) % 500 == 0:
            print("Testing...")
            total_reward = 0
            test_RFT_accuracy = 0
            RFT_eval_count = 0

            # No gradients are needed while scoring, so skip building the
            # autograd graph for the test episodes.
            with torch.no_grad():
                for i in range(TEST_EPISODE):
                    degrees = random.choice([0, 90, 180, 270])
                    task = tg.OmniglotTask(metatest_character_folders,
                                           CLASS_NUM, SAMPLE_NUM_PER_CLASS,
                                           SAMPLE_NUM_PER_CLASS)
                    sample_dataloader = tg.get_data_loader(
                        task, num_per_class=SAMPLE_NUM_PER_CLASS,
                        split="train", shuffle=False, rotation=degrees)
                    test_dataloader = tg.get_data_loader(
                        task, num_per_class=SAMPLE_NUM_PER_CLASS,
                        split="test", shuffle=True, rotation=degrees)

                    sample_images, _ = next(iter(sample_dataloader))
                    test_images, test_labels = next(iter(test_dataloader))

                    sample_images = sample_images.to(device)
                    test_images = test_images.to(device)
                    test_labels = test_labels.to(device)

                    relations = relation_scores(sample_images, test_images,
                                                num_test_queries)

                    _, predict_labels = torch.max(relations.data, 1)

                    # One vectorised comparison instead of a Python loop that
                    # indexes the tensors element by element.
                    total_reward += (
                        predict_labels == test_labels.view(-1)).sum().item()

                    if i % 200 == 0:
                        RFT_predict = RFT.predict(relations.detach().cpu())
                        test_labels_cpu = test_labels.detach().cpu()
                        assert RFT_predict.shape == test_labels_cpu.shape
                        RFT_score = accuracy_score(RFT_predict,
                                                   test_labels_cpu)
                        print(RFT_score)
                        test_RFT_accuracy += RFT_score
                        RFT_eval_count += 1

            test_accuracy = total_reward / (
                1.0 * CLASS_NUM * SAMPLE_NUM_PER_CLASS * TEST_EPISODE)
            # Average over the evaluations that actually ran.  The previous
            # divisor, TEST_EPISODE // 200, is zero when TEST_EPISODE < 200
            # and too small when TEST_EPISODE is not a multiple of 200.
            if RFT_eval_count:
                test_RFT_accuracy /= RFT_eval_count
            print("test accuracy : ", test_accuracy)
            print(f"{test_RFT_accuracy:.3f} %")
            mean_loss = np.mean(loss_list)
            mean_RFT_loss = np.mean(RFT_loss_list)
            mean_relation_loss = np.mean(relation_loss_list)

            print(f'mean loss : {mean_loss}')
            print(f'RFT loss : {mean_RFT_loss}')
            writer.add_scalar('1.RFT loss', mean_RFT_loss, episode + 1)
            writer.add_scalar('RFT_accuracy', test_RFT_accuracy, episode + 1)
            writer.add_scalar('2.relation loss', mean_relation_loss,
                              episode + 1)
            writer.add_scalar('loss', mean_loss, episode + 1)
            writer.add_scalar('test accuracy', test_accuracy, episode + 1)

            # Start fresh loss windows for the next 500 episodes.
            loss_list = []
            relation_loss_list = []
            RFT_loss_list = []

            if test_RFT_accuracy > last_RFT_accuracy:
                os.makedirs(model_dir, exist_ok=True)
                with open(random_forest_path, 'wb') as forest_file:
                    pickle.dump(RFT, forest_file)
                last_RFT_accuracy = test_RFT_accuracy

            if test_accuracy > last_accuracy:
                # save networks (create the folder on first use)
                os.makedirs(model_dir, exist_ok=True)
                torch.save(feature_encoder.state_dict(), feature_encoder_path)
                torch.save(relation_network.state_dict(),
                           relation_network_path)

                print("save networks for episode:", episode)

                last_accuracy = test_accuracy
def main():
    """Evaluate the saved relation network, random forest and their combination.

    Loads the feature encoder, relation network and pickled random forest for
    the current ``CLASS_NUM``/``SAMPLE_NUM_PER_CLASS`` setting from
    ``./models`` (when the files exist), then scores 100 meta-test tasks with
    three predictors: the argmax of the relation scores, the forest's
    prediction on those scores, and the argmax of the forest probabilities
    divided by the softmax of the relation scores.  Per-task accuracies are
    printed, followed by the mean and confidence half-width returned by
    ``mean_confidence_interval`` for each predictor.

    Relies on the module-level ``device``.
    """
    # * Step 1: init data folders
    print("init data folders")

    # * init character folders for dataset construction
    # Only the meta-test folders are used by this evaluation.
    _, metatest_character_folders = tg.omniglot_character_folders()

    # * Step 2: init neural networks
    print("init neural networks")

    feature_encoder = ot.CNNEncoder().to(device)
    # NOTE(review): if no pickled forest exists below, this fresh classifier
    # is never fitted and RFT.predict() will fail - confirm that is intended.
    RFT = RandomForestClassifier(n_estimators=100, random_state=1,
                                 warm_start=True)
    relation_network = ot.RelationNetwork(FEATURE_DIM, RELATION_DIM).to(device)

    feature_encoder.eval()
    relation_network.eval()

    # Build every checkpoint path once.
    model_dir = "./models"
    task_tag = (str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) +
                "shot.pkl")
    feature_encoder_path = model_dir + "/omniglot_feature_encoder_" + task_tag
    random_forest_path = model_dir + "/omniglot_random_forest_" + task_tag
    relation_network_path = (model_dir + "/omniglot_relation_network_" +
                             task_tag)

    # map_location lets a checkpoint written on a GPU host be restored on a
    # CPU-only host (and vice versa).
    if os.path.exists(feature_encoder_path):
        feature_encoder.load_state_dict(
            torch.load(feature_encoder_path, map_location=device))
        print("load feature encoder success")

    if os.path.exists(random_forest_path):
        # SECURITY: pickle executes arbitrary code on load; only use forest
        # files produced by the training script.
        with open(random_forest_path, 'rb') as forest_file:
            RFT = pickle.load(forest_file)
        print("load random forest success")

    if os.path.exists(relation_network_path):
        relation_network.load_state_dict(
            torch.load(relation_network_path, map_location=device))
        print("load relation network success")

    num_queries = CLASS_NUM * SAMPLE_NUM_PER_CLASS

    RFT_total_accuracy = 0.0
    total_accuracy = 0.0
    soft_voting_total_accuracy = 0.0

    for episode in range(1):
        # * test
        print("Testing...")

        accuracies = []
        RFT_accuracies = []
        soft_voting_accuracies = []

        for i in range(100):
            degrees = random.choice([0, 90, 180, 270])
            task = tg.OmniglotTask(metatest_character_folders, CLASS_NUM,
                                   SAMPLE_NUM_PER_CLASS, SAMPLE_NUM_PER_CLASS)

            sample_dataloader = tg.get_data_loader(
                task, num_per_class=SAMPLE_NUM_PER_CLASS, split="train",
                shuffle=False, rotation=degrees)
            test_dataloader = tg.get_data_loader(
                task, num_per_class=SAMPLE_NUM_PER_CLASS, split="test",
                shuffle=True, rotation=degrees)

            sample_images, _ = next(iter(sample_dataloader))
            test_images, test_labels = next(iter(test_dataloader))

            sample_images = sample_images.to(device)
            test_images = test_images.to(device)
            test_labels = test_labels.to(device)

            # Pure inference: skip building the autograd graph.
            with torch.no_grad():
                # * Calculate features
                _, sample_features = feature_encoder(sample_images)
                sample_features = sample_features.view(
                    CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM, 5, 5)
                # Sum the shots of each class into one representation.
                sample_features = torch.sum(sample_features, 1).squeeze(1)
                _, test_features = feature_encoder(test_images)

                # Pair every test image with every class representation.
                sample_features_ext = sample_features.unsqueeze(0).repeat(
                    num_queries, 1, 1, 1, 1)
                test_features_ext = test_features.unsqueeze(0).repeat(
                    CLASS_NUM, 1, 1, 1, 1)
                test_features_ext = torch.transpose(test_features_ext, 0, 1)

                relation_pairs = torch.cat(
                    (sample_features_ext, test_features_ext),
                    2).view(-1, FEATURE_DIM * 2, 5, 5)
                relations = relation_network(relation_pairs).view(
                    -1, CLASS_NUM)

            # scikit-learn consumes the scores and labels on the CPU.
            relations_cpu = relations.detach().cpu()
            test_labels_cpu = test_labels.detach().cpu()

            # Random forest on top of the relation scores.
            RFT_predict = RFT.predict(relations_cpu)
            assert RFT_predict.shape == test_labels_cpu.shape
            RFT_score = accuracy_score(RFT_predict, test_labels_cpu)

            # Relation network alone.
            _, predict_labels = torch.max(relations.data, 1)
            correct = (predict_labels == test_labels).sum().item()
            accuracy = correct / (1.0 * CLASS_NUM * SAMPLE_NUM_PER_CLASS)

            # Combination of both predictors.
            RFT_prob = RFT.predict_proba(relations_cpu)
            relation_prob = torch.softmax(relations.data, dim=1)
            RFT_prob_tensor = torch.tensor(RFT_prob).to(device)
            # NOTE(review): despite the name this is a ratio, not an average
            # of the two probability vectors; the averaging variant
            # ((RFT_prob_tensor + relation_prob) / 2) was disabled by the
            # author.  Kept as-is - confirm which rule is intended.
            soft_voting = (RFT_prob_tensor / relation_prob)
            _, soft_voting_predicted_labels = torch.max(soft_voting, 1)

            soft_voting_correct = (
                soft_voting_predicted_labels == test_labels).sum().item()
            soft_voting_accuracy = soft_voting_correct / (
                1.0 * CLASS_NUM * SAMPLE_NUM_PER_CLASS)

            print(f"{i+1}th RFT accuracy : {RFT_score}, CNN accuracy : {accuracy}, combined_accuracy : {soft_voting_accuracy}")

            RFT_accuracies.append(RFT_score)
            accuracies.append(accuracy)
            soft_voting_accuracies.append(soft_voting_accuracy)

        RFT_test_accuracy, RFT_h = mean_confidence_interval(RFT_accuracies)
        test_accuracy, h = mean_confidence_interval(accuracies)
        soft_voting_test_accuracy, soft_voting_h = mean_confidence_interval(
            soft_voting_accuracies)

        print(f'RFT_test_accuracy accuracy : {RFT_test_accuracy:.4f}, h : {RFT_h:.4f}')
        print(f'test accuracy : {test_accuracy:.4f}, h : {h:.4f}')
        print(f'test soft_voting_test_accuracy : {soft_voting_test_accuracy:.4f}, h : {soft_voting_h:.4f}')

        RFT_total_accuracy += RFT_test_accuracy
        total_accuracy += test_accuracy
        soft_voting_total_accuracy += soft_voting_test_accuracy

    print(f"average RFT_total_accuracy : {RFT_total_accuracy :.4f}")
    print(f"average accuracy : {total_accuracy :.4f}")
    print(f"soft_voting_total_accuracy : {soft_voting_total_accuracy :.4f}")