def main():
    """Evaluate the relation network on miniImagenet test tasks (5 rounds x 600
    episodes); the random-forest soft-voting path is kept but commented out."""
    # * Step 1: init data folders
    print("init data folders")

    # * init character folders for dataset construction
    metatrain_character_folders, metatest_character_folders = tg.mini_imagenet_folders()

    # * Step 2: init neural networks
    print("init neural networks")

    feature_encoder = ot.CNNEncoder().to(device)
    RFT = RandomForestClassifier(n_estimators=100, random_state=1, warm_start=True)
    relation_network = ot.RelationNetwork(FEATURE_DIM, RELATION_DIM).to(device)

    # feature_encoder.eval()
    # relation_network.eval()

    if os.path.exists(
            str("./models/miniimagenet_feature_encoder_" + str(CLASS_NUM) + "way_" +
                str(SAMPLE_NUM_PER_CLASS) + "shot_exp.pkl")):
        feature_encoder.load_state_dict(
            torch.load(
                str("./models/miniimagenet_feature_encoder_" + str(CLASS_NUM) + "way_" +
                    str(SAMPLE_NUM_PER_CLASS) + "shot_exp.pkl")))
        print("load feature encoder success")

    if os.path.exists(
            str("./models/miniimagenet_random_forest_" + str(CLASS_NUM) + "way_" +
                str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")):
        RFT = pickle.load(
            open(
                str("./models/miniimagenet_random_forest_" + str(CLASS_NUM) + "way_" +
                    str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"), 'rb'))
        print("load random forest success")

    if os.path.exists(
            str("./models/miniimagenet_relation_network_" + str(CLASS_NUM) + "way_" +
                str(SAMPLE_NUM_PER_CLASS) + "shot_exp.pkl")):
        relation_network.load_state_dict(
            torch.load(
                str("./models/miniimagenet_relation_network_" + str(CLASS_NUM) + "way_" +
                    str(SAMPLE_NUM_PER_CLASS) + "shot_exp.pkl")))
        print("load relation network success")

    total_accuracy = 0.0
    max_accuracy_list = []
    mean_accuracy_list = []

    for test in range(5):
        print("Testing...")
        max_accuracy = 0
        total_accuracy = []
        number_of_query_image = 15
        print(f"Test {test}")

        for i in range(600):
            total_reward = 0
            # * sample a new N-way K-shot test task per episode
            task = tg.MiniImagenetTask(metatest_character_folders, CLASS_NUM,
                                       SAMPLE_NUM_PER_CLASS, number_of_query_image)
            sample_dataloader = tg.get_mini_imagenet_data_loader(
                task,
                num_per_class=SAMPLE_NUM_PER_CLASS,
                split="train",
                shuffle=False)
            test_dataloader = tg.get_mini_imagenet_data_loader(
                task,
                num_per_class=number_of_query_image,
                split="test",
                shuffle=False)

            sample_images, sample_labels = next(iter(sample_dataloader))
            sample_images, sample_labels = sample_images.to(device), sample_labels.to(device)

            # print(f"Episode {i}")
            for test_images, test_labels in test_dataloader:
                # print(test_labels.shape)
                test_images, test_labels = test_images.to(device), test_labels.to(device)
                batch_size = test_labels.shape[0]

                # * calculate features
                sample_features = feature_encoder(sample_images)
                sample_features = sample_features.view(CLASS_NUM, SAMPLE_NUM_PER_CLASS,
                                                       FEATURE_DIM, 19, 19)
                sample_features = torch.sum(sample_features, 1).squeeze(1)
                test_features = feature_encoder(test_images)

                # * calculate relations:
                # * each query feature is paired with every class prototype, and the
                # * concatenated pairs are fed to the relation network
                sample_features_ext = sample_features.unsqueeze(0).repeat(
                    batch_size, 1, 1, 1, 1)
                test_features_ext = test_features.unsqueeze(0).repeat(
                    CLASS_NUM, 1, 1, 1, 1)
                test_features_ext = torch.transpose(test_features_ext, 0, 1)
                relation_pairs = torch.cat(
                    (sample_features_ext, test_features_ext),
                    2).view(-1, FEATURE_DIM * 2, 19, 19)
                relations = relation_network(relation_pairs).view(-1, CLASS_NUM)

                # RFT_prob = RFT.predict_proba(relations.detach().cpu())
                # relation_prob = torch.softmax(relations.data, dim=1)
                # RFT_prob_tensor = torch.tensor(RFT_prob).to(device)
                # soft_voting = (RFT_prob_tensor * 0.7) + relation_prob
                # soft_voting = (RFT_prob_tensor / relation_prob)
                # _, soft_voting_predicted_labels = torch.max(soft_voting, 1)

                _, predict_labels = torch.max(relations.data, 1)
                # print(predict_labels.item())
                # print(test_labels.item())

                # * score every query image in the batch
                rewards = [
                    1 if predict_labels[j] == test_labels[j] else 0
                    for j in range(batch_size)
                ]
                # rewards = [1 if soft_voting_predicted_labels[j] == test_labels[j] else 0
                #            for j in range(CLASS_NUM * number_of_query_image)]
                total_reward += np.sum(rewards)
                # print(total_reward)

            test_accuracy = total_reward / (1.0 * CLASS_NUM * number_of_query_image)
            # print(test_accuracy)
            total_accuracy.append(test_accuracy)
            if test_accuracy > max_accuracy:
                max_accuracy = test_accuracy

        test_accuracy, h = mean_confidence_interval(total_accuracy)
        print(f"Final result : {test_accuracy:.4f}, h : {h:.4f}")
        mean_accuracy = np.mean(total_accuracy)
        mean_accuracy_list.append(mean_accuracy)
        print(f"Total accuracy : {mean_accuracy:.4f}")
        print(f"max accuracy : {max_accuracy:.4f}")
        max_accuracy_list.append(max_accuracy)

    '''
    test_accuracy, h = mean_confidence_interval(accuracies)
    print(f'test accuracy : {test_accuracy:.4f}, h : {h:.4f}')
    total_accuracy += test_accuracy
    print(f"average accuracy : {total_accuracy/10 :.4f}")
    '''

    final_accuracy, h = mean_confidence_interval(max_accuracy_list)
    print(f"Final result : {final_accuracy:.4f}, h : {h:.4f}")
    print(np.sort(mean_accuracy_list))
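# ---------------------------------------------------------------------------
# Note: mean_confidence_interval() is called in both main() variants but is
# not defined in this section. The sketch below is an assumption: it is the
# usual scipy-based 95% confidence-interval helper common in few-shot
# evaluation scripts, not necessarily the project's own implementation.
# ---------------------------------------------------------------------------
import numpy as np
import scipy.stats


def mean_confidence_interval(data, confidence=0.95):
    # Mean of the per-episode accuracies and the half-width of the
    # Student-t confidence interval around that mean.
    a = 1.0 * np.array(data)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    h = se * scipy.stats.t.ppf((1 + confidence) / 2.0, n - 1)
    return m, h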
def main():
    # * Step 1: init data folders
    print("init data folders")

    # * init character folders for dataset construction
    metatrain_folders, metatest_folders = tg.mini_imagenet_folders()

    # * Step 2: init neural networks
    print("init neural networks")

    feature_encoder = ot.CNNEncoder().to(device)
    # RFT = RandomForestClassifier(n_estimators=100, random_state=1, warm_start=True)
    relation_network = ot.RelationNetwork(FEATURE_DIM, RELATION_DIM).to(device)

    feature_encoder.eval()
    relation_network.eval()

    if os.path.exists(
            str("./models/miniimagenet_feature_encoder_" + str(CLASS_NUM) + "way_" +
                str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")):
        feature_encoder.load_state_dict(
            torch.load(
                str("./models/miniimagenet_feature_encoder_" + str(CLASS_NUM) + "way_" +
                    str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")))
        print("load feature encoder success")
    '''
    if os.path.exists(str("./models/miniimagenet_random_forest_" + str(CLASS_NUM) + "way_" +
                          str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")):
        RFT = pickle.load(open(str("./models/miniimagenet_random_forest_" + str(CLASS_NUM) +
                                   "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"), 'rb'))
        print("load random forest success")
    '''
    if os.path.exists(
            str("./models/miniimagenet_relation_network_" + str(CLASS_NUM) + "way_" +
                str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")):
        relation_network.load_state_dict(
            torch.load(
                str("./models/miniimagenet_relation_network_" + str(CLASS_NUM) + "way_" +
                    str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")))
        print("load relation network success")

    total_accuracy = 0.0
    for episode in range(10):
        # * test
        print("Testing...")
        accuracies = []
        for i in range(100):
            total_rewards = 0
            # degrees = random.choice([0, 90, 180, 270])
            task = tg.MiniImagenetTask(metatest_folders, CLASS_NUM,
                                       SAMPLE_NUM_PER_CLASS, BATCH_NUM_PER_CLASS)
            sample_dataloader = tg.get_mini_imagenet_data_loader(
                task,
                num_per_class=SAMPLE_NUM_PER_CLASS,
                split="train",
                shuffle=False)
            test_dataloader = tg.get_mini_imagenet_data_loader(task,
                                                               num_per_class=5,
                                                               split="test",
                                                               shuffle=False)

            sample_images, sample_labels = next(iter(sample_dataloader))
            sample_images, sample_labels = sample_images.to(device), sample_labels.to(device)

            for test_images, test_labels in test_dataloader:
                batch_size = test_labels.shape[0]
                test_images, test_labels = test_images.to(device), test_labels.to(device)

                # * Calculate features
                sample_features = feature_encoder(sample_images)
                sample_features = sample_features.view(CLASS_NUM, SAMPLE_NUM_PER_CLASS,
                                                       FEATURE_DIM, 19, 19)
                sample_features = torch.sum(sample_features, 1).squeeze(1)
                test_features = feature_encoder(test_images)

                sample_features_ext = sample_features.unsqueeze(0).repeat(
                    batch_size, 1, 1, 1, 1)
                test_features_ext = test_features.unsqueeze(0).repeat(
                    CLASS_NUM, 1, 1, 1, 1)
                test_features_ext = torch.transpose(test_features_ext, 0, 1)
                relation_pairs = torch.cat(
                    (sample_features_ext, test_features_ext),
                    2).view(-1, FEATURE_DIM * 2, 19, 19)
                relations = relation_network(relation_pairs).view(-1, CLASS_NUM)

                _, predict_labels = torch.max(relations.data, 1)

                rewards = [
                    1 if predict_labels[j] == test_labels[j] else 0
                    for j in range(batch_size)
                ]
                total_rewards += np.sum(rewards)

            accuracy = total_rewards / (1.0 * CLASS_NUM * 15)
            accuracies.append(accuracy)

        test_accuracy, h = mean_confidence_interval(accuracies)
        print(f'test accuracy : {test_accuracy:.4f}, h : {h:.4f}')
        total_accuracy += test_accuracy

    print(f"average accuracy : {total_accuracy/10 :.4f}")
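# ---------------------------------------------------------------------------
# The commented-out block in the first main() above sketches a soft-voting
# ensemble between the pickled random forest and the relation network. The
# hypothetical helper below illustrates that idea in isolation. It is only a
# sketch: it assumes RFT is a fitted sklearn RandomForestClassifier whose
# predict_proba() was trained on the CLASS_NUM relation scores, and that the
# 0.7 weighting from the commented code is the intended mixing factor.
# ---------------------------------------------------------------------------
import torch


def soft_vote_predict(relations, RFT, device, forest_weight=0.7):
    # relations: (num_queries, CLASS_NUM) raw relation scores for one test batch.
    relation_prob = torch.softmax(relations.detach(), dim=1)
    # Random-forest class probabilities computed on the same relation scores.
    RFT_prob = RFT.predict_proba(relations.detach().cpu().numpy())
    RFT_prob_tensor = torch.tensor(RFT_prob, dtype=relation_prob.dtype).to(device)
    # Weighted soft vote: blend the two probability estimates, then take the
    # arg-max class per query image.
    soft_voting = forest_weight * RFT_prob_tensor + relation_prob
    _, predicted_labels = torch.max(soft_voting, 1)
    return predicted_labels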