def main(): print("init data folders") metatrain_folders, metaquery_folders = tg.mini_imagenet_folders() print("init neural networks") foreground_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU) background_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU) mixture_network = models.MixtureNetwork().apply(weights_init).cuda(GPU) relation_network = models.SimilarityNetwork( FEATURE_DIM, RELATION_DIM).apply(weights_init).cuda(GPU) vanilla_foreground_encoder = models.FeatureEncoder().apply( weights_init).cuda(GPU) vanilla_background_encoder = models.FeatureEncoder().apply( weights_init).cuda(GPU) vanilla_mixture_network = models.MixtureNetwork().apply(weights_init).cuda( GPU) foreground_encoder_optim = torch.optim.Adam( foreground_encoder.parameters(), lr=LEARNING_RATE) foreground_encoder_scheduler = StepLR(foreground_encoder_optim, step_size=100000, gamma=0.5) background_encoder_optim = torch.optim.Adam( background_encoder.parameters(), lr=LEARNING_RATE) background_encoder_scheduler = StepLR(background_encoder_optim, step_size=100000, gamma=0.5) mixture_network_optim = torch.optim.Adam(mixture_network.parameters(), lr=LEARNING_RATE) mixture_network_scheduler = StepLR(mixture_network_optim, step_size=100000, gamma=0.5) relation_network_optim = torch.optim.Adam(relation_network.parameters(), lr=LEARNING_RATE) relation_network_scheduler = StepLR(relation_network_optim, step_size=100000, gamma=0.5) # Loading models if os.path.exists( str(METHOD + "/miniImagenet_foreground_encoder_" + str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")): foreground_encoder.load_state_dict( torch.load( str(METHOD + "/miniImagenet_foreground_encoder_" + str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl"))) print("load foreground encoder success") if os.path.exists( str(METHOD + "/miniImagenet_background_encoder_" + str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")): background_encoder.load_state_dict( torch.load( 
str(METHOD + "/miniImagenet_background_encoder_" + str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl"))) print("load background encoder success") if os.path.exists( str(METHOD + "/miniImagenet_mixture_network_" + str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")): mixture_network.load_state_dict( torch.load( str(METHOD + "/miniImagenet_mixture_network_" + str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl"))) print("load mixture network success") if os.path.exists( str(METHOD + "/miniImagenet_relation_network_" + str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")): relation_network.load_state_dict( torch.load( str(METHOD + "/miniImagenet_relation_network_" + str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl"))) print("load relation network success") # Loading vanilla models if os.path.exists( str("./vanilla_models/miniImagenet_foreground_encoder_" + str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")): vanilla_foreground_encoder.load_state_dict( torch.load( str("./vanilla_models/miniImagenet_foreground_encoder_" + str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl"))) print("load vanilla foreground encoder success") if os.path.exists( str("./vanilla_models/miniImagenet_background_encoder_" + str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")): vanilla_background_encoder.load_state_dict( torch.load( str("./vanilla_models/miniImagenet_background_encoder_" + str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl"))) print("load vanilla background encoder success") if os.path.exists( str("./vanilla_models/miniImagenet_mixture_network_" + str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")): vanilla_mixture_network.load_state_dict( torch.load( str("./vanilla_models/miniImagenet_mixture_network_" + str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl"))) print("load vanilla mixture network success") if 
os.path.exists(METHOD) == False: os.system('mkdir ' + METHOD) print("Training...") best_accuracy = 0.0 start = time.time() for episode in range(EPISODE): mse = nn.MSELoss().cuda(GPU) foreground_encoder_scheduler.step(episode) background_encoder_scheduler.step(episode) mixture_network_scheduler.step(episode) relation_network_scheduler.step(episode) # init dataset task = tg.MiniImagenetTask(metatrain_folders, CLASS_NUM, SUPPORT_NUM_PER_CLASS, BATCH_NUM_PER_CLASS) support_dataloader = tg.get_mini_imagenet_data_loader( task, num_per_class=SUPPORT_NUM_PER_CLASS, split="train", shuffle=False) query_dataloader = tg.get_mini_imagenet_data_loader( task, num_per_class=BATCH_NUM_PER_CLASS, split="test", shuffle=True) # support datas support_img, support_sal, support_labels = support_dataloader.__iter__( ).next() query_img, query_sal, query_labels = query_dataloader.__iter__().next() # calculate features support_foreground_features = foreground_encoder( Variable(support_img * support_sal).cuda(GPU)).view( CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64, 19, 19) support_background_features = background_encoder( Variable(support_img * (1 - support_sal)).cuda(GPU)).view( CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64, 19, 19) query_foreground_features = foreground_encoder( Variable(query_img * query_sal).cuda(GPU)) query_background_features = background_encoder( Variable(query_img * (1 - query_sal)).cuda(GPU)) # Real-Representation Regularization (TriR), teacher network support_foreground_features_ = vanilla_foreground_encoder( Variable(support_img * support_sal).cuda(GPU)) support_background_features_ = vanilla_background_encoder( Variable(support_img * (1 - support_sal)).cuda(GPU)) support_mix_features_ = vanilla_mixture_network( support_foreground_features_ + support_background_features_) support_mix_features__ = mixture_network( (support_foreground_features + support_background_features).view( -1, 64, 19, 19)) TriR = args.beta * mse( support_mix_features__, Variable(support_mix_features_, 
requires_grad=False)) # Inter-class Hallucination support_foreground_features = support_foreground_features.unsqueeze( 2).repeat(1, 1, CLASS_NUM * SUPPORT_NUM_PER_CLASS, 1, 1, 1) support_background_features = support_background_features.view( 1, 1, CLASS_NUM * SUPPORT_NUM_PER_CLASS, 64, 19, 19).repeat(CLASS_NUM, SUPPORT_NUM_PER_CLASS, 1, 1, 1, 1) similarity_measure = similarity_func( support_background_features, CLASS_NUM, SUPPORT_NUM_PER_CLASS).view(CLASS_NUM, SUPPORT_NUM_PER_CLASS, -1, 1, 1) support_mix_features = mixture_network( (support_foreground_features + support_background_features).view( (CLASS_NUM**2) * (SUPPORT_NUM_PER_CLASS**2), 64, 19, 19)).view(CLASS_NUM, SUPPORT_NUM_PER_CLASS, -1, 64, 19**2) support_mix_features = (support_mix_features * similarity_measure).sum(2).sum(1) query_mix_features = mixture_network(query_foreground_features + query_background_features).view( -1, 64, 19**2) so_support_features = Variable(torch.Tensor(CLASS_NUM, 1, 64, 64)).cuda(GPU) so_query_features = Variable( torch.Tensor(BATCH_NUM_PER_CLASS * CLASS_NUM, 1, 64, 64)).cuda(GPU) # second-order features for d in range(support_mix_features.size()[0]): s = support_mix_features[d, :, :].squeeze(0) s = (1.0 / support_mix_features.size()[2]) * s.mm(s.transpose( 0, 1)) so_support_features[d, :, :, :] = power_norm(s / s.trace(), SIGMA) for d in range(query_mix_features.size()[0]): s = query_mix_features[d, :, :].squeeze(0) s = (1.0 / query_mix_features.size()[2]) * s.mm(s.transpose(0, 1)) so_query_features[d, :, :, :] = power_norm(s / s.trace(), SIGMA) # calculate relations with 64x64 second-order features support_features_ext = so_support_features.unsqueeze(0).repeat( BATCH_NUM_PER_CLASS * CLASS_NUM, 1, 1, 1, 1) query_features_ext = so_query_features.unsqueeze(0).repeat( CLASS_NUM, 1, 1, 1, 1) query_features_ext = torch.transpose(query_features_ext, 0, 1) relation_pairs = torch.cat((support_features_ext, query_features_ext), 2).view(-1, 2, 64, 64) relations = 
relation_network(relation_pairs).view(-1, CLASS_NUM) one_hot_labels = Variable( torch.zeros(BATCH_NUM_PER_CLASS * CLASS_NUM, CLASS_NUM).scatter_(1, query_labels.view(-1, 1), 1)).cuda(GPU) loss = mse(relations, one_hot_labels) + TriR # update network parameters foreground_encoder.zero_grad() background_encoder.zero_grad() mixture_network.zero_grad() relation_network.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(foreground_encoder.parameters(), 0.5) torch.nn.utils.clip_grad_norm_(background_encoder.parameters(), 0.5) torch.nn.utils.clip_grad_norm_(mixture_network.parameters(), 0.5) torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5) foreground_encoder_optim.step() background_encoder_optim.step() mixture_network_optim.step() relation_network_optim.step() if np.mod(episode + 1, 100) == 0: print("episode:", episode + 1, "loss", loss.item()) if np.mod(episode, 2500) == 0: # test print("Testing...") accuracies = [] for i in range(TEST_EPISODE): total_rewards = 0 counter = 0 task = tg.MiniImagenetTask(metaquery_folders, CLASS_NUM, SUPPORT_NUM_PER_CLASS, 15) support_dataloader = tg.get_mini_imagenet_data_loader( task, num_per_class=SUPPORT_NUM_PER_CLASS, split="train", shuffle=False) num_per_class = 2 query_dataloader = tg.get_mini_imagenet_data_loader( task, num_per_class=num_per_class, split="test", shuffle=True) support_img, support_sal, support_labels = support_dataloader.__iter__( ).next() for query_img, query_sal, query_labels in query_dataloader: query_size = query_labels.shape[0] # calculate foreground and background features support_foreground_features = foreground_encoder( Variable(support_img * support_sal).cuda(GPU)).view( CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64, 19, 19) support_background_features = background_encoder( Variable( support_img * (1 - support_sal)).cuda(GPU)).view( CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64, 19, 19) query_foreground_features = foreground_encoder( Variable(query_img * query_sal).cuda(GPU)) 
query_background_features = background_encoder( Variable(query_img * (1 - query_sal)).cuda(GPU)) # Inter-class Hallucination support_foreground_features = support_foreground_features.unsqueeze( 2).repeat(1, 1, CLASS_NUM * SUPPORT_NUM_PER_CLASS, 1, 1, 1) support_background_features = support_background_features.view( 1, 1, CLASS_NUM * SUPPORT_NUM_PER_CLASS, 64, 19, 19).repeat(CLASS_NUM, SUPPORT_NUM_PER_CLASS, 1, 1, 1, 1) similarity_measure = similarity_func( support_background_features, CLASS_NUM, SUPPORT_NUM_PER_CLASS).view(CLASS_NUM, SUPPORT_NUM_PER_CLASS, -1, 1, 1) support_mix_features = mixture_network( (support_foreground_features + support_background_features).view( (CLASS_NUM * SUPPORT_NUM_PER_CLASS)**2, 64, 19, 19)).view(CLASS_NUM, SUPPORT_NUM_PER_CLASS, -1, 64, 19**2) support_mix_features = (support_mix_features * similarity_measure).sum(2).sum(1) query_mix_features = mixture_network( query_foreground_features + query_background_features).view(-1, 64, 19**2) so_support_features = Variable( torch.Tensor(CLASS_NUM, 1, 64, 64)).cuda(GPU) so_query_features = Variable( torch.Tensor(query_size, 1, 64, 64)).cuda(GPU) # second-order features for d in range(support_mix_features.size()[0]): s = support_mix_features[d, :, :].squeeze(0) s = (1.0 / support_mix_features.size()[2]) * s.mm( s.transpose(0, 1)) so_support_features[d, :, :, :] = power_norm( s / s.trace(), SIGMA) for d in range(query_mix_features.size()[0]): s = query_mix_features[d, :, :].squeeze(0) s = (1.0 / query_mix_features.size()[2]) * s.mm( s.transpose(0, 1)) so_query_features[d, :, :, :] = power_norm( s / s.trace(), SIGMA) # calculate relations with 64x64 second-order features support_features_ext = so_support_features.unsqueeze( 0).repeat(query_size, 1, 1, 1, 1) query_features_ext = so_query_features.unsqueeze(0).repeat( 1 * CLASS_NUM, 1, 1, 1, 1) query_features_ext = torch.transpose( query_features_ext, 0, 1) relation_pairs = torch.cat( (support_features_ext, query_features_ext), 2).view(-1, 2, 64, 
64) relations = relation_network(relation_pairs).view( -1, CLASS_NUM) _, predict_labels = torch.max(relations.data, 1) rewards = [ 1 if predict_labels[j] == query_labels[j].cuda(GPU) else 0 for j in range(query_size) ] total_rewards += np.sum(rewards) counter += query_size accuracy = total_rewards / 1.0 / counter accuracies.append(accuracy) test_accuracy, h = mean_confidence_interval(accuracies) print("test accuracy:", test_accuracy, "h:", h) if test_accuracy > best_accuracy: # save networks torch.save( foreground_encoder.state_dict(), str(METHOD + "/miniImagenet_foreground_encoder_" + str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")) torch.save( background_encoder.state_dict(), str(METHOD + "/miniImagenet_background_encoder_" + str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")) torch.save( mixture_network.state_dict(), str(METHOD + "/miniImagenet_mixture_network_" + str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")) torch.save( relation_network.state_dict(), str(METHOD + "/miniImagenet_relation_network_" + str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")) print("save networks for episode:", episode) best_accuracy = test_accuracy print("best accuracy:", best_accuracy)
def main():
    """Meta-train the HOP (second-order feature) relation network on miniImagenet.

    Episodic training of a feature encoder + similarity network over
    trace-normalised, power-normalised autocorrelation ("HOP") features.
    Every 500 episodes the model is evaluated on the meta-query split and the
    best weights are checkpointed under ``METHOD/``.

    NOTE(review): the ``relations`` reshape uses CLASS_NUM *
    SUPPORT_NUM_PER_CLASS columns while ``one_hot_labels`` has CLASS_NUM
    columns, and ``H_support_features`` is sized for all shots although the
    fill loop runs over the shot-summed CLASS_NUM rows -- these are only
    mutually consistent when SUPPORT_NUM_PER_CLASS == 1.  Confirm before
    running multi-shot.
    """
    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    metatrain_folders, metaquery_folders = tg.mini_imagenet_folders()

    # Step 2: init neural networks
    print("init neural networks")
    feature_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    relation_network = models.SimilarityNetwork(
        FEATURE_DIM, RELATION_DIM).apply(weights_init).cuda(GPU)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,
                                       step_size=50000,
                                       gamma=0.5)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                              lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,
                                        step_size=50000,
                                        gamma=0.5)

    # Checkpoint paths are built once and reused for load and save.
    feature_encoder_path = str(METHOD + "/feature_encoder_" + str(CLASS_NUM) +
                               "way_" + str(SUPPORT_NUM_PER_CLASS) +
                               "shot.pkl")
    relation_network_path = str(METHOD + "/relation_network_" +
                                str(CLASS_NUM) + "way_" +
                                str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")
    if os.path.exists(feature_encoder_path):
        feature_encoder.load_state_dict(torch.load(feature_encoder_path))
        print("load feature encoder success")
    if os.path.exists(relation_network_path):
        relation_network.load_state_dict(torch.load(relation_network_path))
        print("load relation network success")
    if not os.path.exists(METHOD):
        # os.makedirs replaces the former os.system('mkdir ...'):
        # portable and not routed through a shell.
        os.makedirs(METHOD)

    # Step 3: build graph
    print("Training...")
    best_accuracy = 0.0
    best_h = 0.0
    for episode in range(EPISODE):
        feature_encoder_scheduler.step(episode)
        relation_network_scheduler.step(episode)

        # init dataset
        # support_dataloader is to obtain previous supports for compare
        # query_dataloader is to query supports for training
        task = tg.MiniImagenetTask(metatrain_folders, CLASS_NUM,
                                   SUPPORT_NUM_PER_CLASS,
                                   QUERY_NUM_PER_CLASS)
        support_dataloader = tg.get_mini_imagenet_data_loader(
            task, num_per_class=SUPPORT_NUM_PER_CLASS, split="train",
            shuffle=False)
        query_dataloader = tg.get_mini_imagenet_data_loader(
            task, num_per_class=QUERY_NUM_PER_CLASS, split="test",
            shuffle=True)
        # support datas -- next(iter(...)) is the Python-3 form of the old
        # .__iter__().next(), which raises AttributeError on Python 3.
        supports, support_labels = next(iter(support_dataloader))
        queries, query_labels = next(iter(query_dataloader))

        # calculate features
        support_features = feature_encoder(Variable(supports).cuda(GPU)).view(
            CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64, 19**2).sum(1)  # 5x64*19*19
        query_features = feature_encoder(Variable(queries).cuda(GPU)).view(
            QUERY_NUM_PER_CLASS * CLASS_NUM, 64, 19**2)  # 20x64*19*19
        H_support_features = Variable(
            torch.Tensor(SUPPORT_NUM_PER_CLASS * CLASS_NUM, 1, 64,
                         64)).cuda(GPU)
        H_query_features = Variable(
            torch.Tensor(QUERY_NUM_PER_CLASS * CLASS_NUM, 1, 64,
                         64)).cuda(GPU)
        # HOP features: LAMBDA-centred, trace-normalised autocorrelation
        # matrices passed through power_norm.
        for d in range(support_features.size()[0]):
            s = support_features[d, :, :].squeeze(0)
            s = s - LAMBDA * s.mean(1).repeat(1, s.size()[1]).view(s.size())
            s = (1.0 / support_features.size()[2]) * s.mm(s.transpose(0, 1))
            H_support_features[d, :, :, :] = power_norm(s / s.trace(), SIGMA)
        for d in range(query_features.size()[0]):
            s = query_features[d, :, :].squeeze(0)
            s = s - LAMBDA * s.mean(1).repeat(1, s.size()[1]).view(s.size())
            s = (1.0 / query_features.size()[2]) * s.mm(s.transpose(0, 1))
            H_query_features[d, :, :, :] = power_norm(s / s.trace(), SIGMA)

        # form relation pairs
        support_features_ext = H_support_features.unsqueeze(0).repeat(
            QUERY_NUM_PER_CLASS * CLASS_NUM, 1, 1, 1, 1)
        query_features_ext = H_query_features.unsqueeze(0).repeat(
            CLASS_NUM, 1, 1, 1, 1)
        query_features_ext = torch.transpose(query_features_ext, 0, 1)
        relation_pairs = torch.cat((support_features_ext, query_features_ext),
                                   2).view(-1, 2, 64, 64)
        # calculate relation scores
        relations = relation_network(relation_pairs).view(
            -1, CLASS_NUM * SUPPORT_NUM_PER_CLASS)

        # define loss function
        mse = nn.MSELoss().cuda(GPU)
        one_hot_labels = Variable(
            torch.zeros(QUERY_NUM_PER_CLASS * CLASS_NUM,
                        CLASS_NUM).scatter_(1, query_labels.view(-1, 1),
                                            1)).cuda(GPU)
        loss = mse(relations, one_hot_labels)

        # updating network parameters with their gradients
        feature_encoder.zero_grad()
        relation_network.zero_grad()
        loss.backward()
        feature_encoder_optim.step()
        relation_network_optim.step()

        if (episode + 1) % 100 == 0:
            # loss.item() replaces the PyTorch-0.4-era loss.data[0],
            # which raises on current PyTorch versions.
            print("episode:", episode + 1, "loss", loss.item())
        if episode % 500 == 0:
            # query
            print("Testing...")
            accuracies = []
            for i in range(TEST_EPISODE):
                with torch.no_grad():
                    total_rewards = 0
                    counter = 0
                    task = tg.MiniImagenetTask(metaquery_folders, CLASS_NUM,
                                               1, 2)
                    support_dataloader = tg.get_mini_imagenet_data_loader(
                        task, num_per_class=1, split="train", shuffle=False)
                    num_per_class = 2
                    query_dataloader = tg.get_mini_imagenet_data_loader(
                        task, num_per_class=num_per_class, split="query",
                        shuffle=True)
                    support_images, support_labels = next(
                        iter(support_dataloader))
                    for query_images, query_labels in query_dataloader:
                        query_size = query_labels.shape[0]
                        # calculate features
                        support_features = feature_encoder(
                            Variable(support_images).cuda(GPU)).view(
                                CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64,
                                19**2).sum(1)
                        query_features = feature_encoder(
                            Variable(query_images).cuda(GPU)).view(
                                num_per_class * CLASS_NUM, 64, 19**2)
                        H_support_features = Variable(
                            torch.Tensor(SUPPORT_NUM_PER_CLASS * CLASS_NUM,
                                         1, 64, 64)).cuda(GPU)
                        H_query_features = Variable(
                            torch.Tensor(num_per_class * CLASS_NUM, 1, 64,
                                         64)).cuda(GPU)
                        # HOP features
                        for d in range(support_features.size()[0]):
                            s = support_features[d, :, :].squeeze(0)
                            s = s - LAMBDA * s.mean(1).repeat(
                                1, s.size()[1]).view(s.size())
                            s = (1.0 / support_features.size()[2]) * s.mm(
                                s.transpose(0, 1))
                            H_support_features[d, :, :, :] = power_norm(
                                s / s.trace(), SIGMA)
                        for d in range(query_features.size()[0]):
                            s = query_features[d, :, :].squeeze(0)
                            s = s - LAMBDA * s.mean(1).repeat(
                                1, s.size()[1]).view(s.size())
                            s = (1.0 / query_features.size()[2]) * s.mm(
                                s.transpose(0, 1))
                            H_query_features[d, :, :, :] = power_norm(
                                s / s.trace(), SIGMA)
                        # form relation pairs
                        support_features_ext = H_support_features.unsqueeze(
                            0).repeat(query_size, 1, 1, 1, 1)
                        query_features_ext = H_query_features.unsqueeze(
                            0).repeat(1 * CLASS_NUM, 1, 1, 1, 1)
                        query_features_ext = torch.transpose(
                            query_features_ext, 0, 1)
                        relation_pairs = torch.cat(
                            (support_features_ext, query_features_ext),
                            2).view(-1, 2, 64, 64)
                        # calculate relation scores
                        relations = relation_network(relation_pairs).view(
                            -1, CLASS_NUM)
                        _, predict_labels = torch.max(relations.data, 1)
                        rewards = [
                            1 if predict_labels[j] == query_labels[j].cuda(
                                GPU) else 0 for j in range(query_size)
                        ]
                        total_rewards += np.sum(rewards)
                        counter += query_size
                    accuracy = total_rewards / 1.0 / counter
                    accuracies.append(accuracy)
            test_accuracy, h = mean_confidence_interval(accuracies)
            print("Test accuracy:", test_accuracy, "h:", h)
            print("Best accuracy: ", best_accuracy, "h:", best_h)
            if test_accuracy > best_accuracy:
                # save networks
                torch.save(feature_encoder.state_dict(), feature_encoder_path)
                torch.save(relation_network.state_dict(),
                           relation_network_path)
                print("save networks for episode:", episode)
                best_accuracy = test_accuracy
                best_h = h
def main():
    """Evaluate the Omniglot HOP relation network on the meta-query split.

    Loads previously trained encoder/similarity-network weights (if present)
    and repeatedly measures episodic test accuracy with random 90-degree
    rotation augmentation, tracking the best accuracy seen.
    """
    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    metatrain_character_folders, metaquery_character_folders = tg.omniglot_character_folders()

    # Step 2: init neural networks
    print("init neural networks")
    feature_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    relation_network = models.SimilarityNetwork(
        FEATURE_DIM, RELATION_DIM).apply(weights_init).cuda(GPU)

    # NOTE(review): these optimizers/schedulers are constructed but never
    # stepped in this evaluation-only loop; kept for parity with the
    # training scripts in this file.
    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,
                                       step_size=50000,
                                       gamma=0.1)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                              lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,
                                        step_size=50000,
                                        gamma=0.1)

    feature_encoder_path = str(METHOD + "/omniglot_feature_encoder_" +
                               str(CLASS_NUM) + "way_" +
                               str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")
    relation_network_path = str(METHOD + "/omniglot_similarity_network_" +
                                str(CLASS_NUM) + "way_" +
                                str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")
    if os.path.exists(feature_encoder_path):
        feature_encoder.load_state_dict(torch.load(feature_encoder_path))
        print("load feature encoder success")
    if os.path.exists(relation_network_path):
        relation_network.load_state_dict(torch.load(relation_network_path))
        print("load similarity network success")
    if not os.path.exists(METHOD):
        # os.makedirs replaces the former os.system('mkdir ...'):
        # portable and not routed through a shell.
        os.makedirs(METHOD)

    # Step 3: build graph
    print("Training...")
    best_accuracy = 0.0
    best_h = 0.0
    for episode in range(EPISODE):
        with torch.no_grad():
            # query
            print("Testing...")
            total_rewards = 0
            for i in range(TEST_EPISODE):
                # Random rotation applied consistently to support and query.
                degrees = random.choice([0, 90, 180, 270])
                task = tg.OmniglotTask(metaquery_character_folders, CLASS_NUM,
                                       SUPPORT_NUM_PER_CLASS,
                                       TEST_NUM_PER_CLASS)
                support_dataloader = tg.get_data_loader(
                    task, num_per_class=SUPPORT_NUM_PER_CLASS, split="train",
                    shuffle=False, rotation=degrees)
                query_dataloader = tg.get_data_loader(
                    task, num_per_class=TEST_NUM_PER_CLASS, split="query",
                    shuffle=True, rotation=degrees)
                # next(iter(...)) is the Python-3 form of the old
                # .__iter__().next(), which raises AttributeError on Python 3.
                support_images, support_labels = next(
                    iter(support_dataloader))
                query_images, query_labels = next(iter(query_dataloader))

                # calculate features
                support_features = feature_encoder(
                    Variable(support_images).cuda(GPU))  # 5x64
                support_features = support_features.view(
                    CLASS_NUM, SUPPORT_NUM_PER_CLASS, FEATURE_DIM,
                    25).sum(1)
                query_features = feature_encoder(
                    Variable(query_images).cuda(GPU)).view(
                        TEST_NUM_PER_CLASS * CLASS_NUM, 64, 25)
                H_support_features = Variable(
                    torch.Tensor(CLASS_NUM, 1, 64, 64)).cuda(GPU)
                H_query_features = Variable(
                    torch.Tensor(TEST_NUM_PER_CLASS * CLASS_NUM, 1, 64,
                                 64)).cuda(GPU)
                # HOP features: trace-normalised autocorrelation + power norm
                for d in range(support_features.size(0)):
                    s = support_features[d, :, :].squeeze(0)
                    s = (1.0 / support_features.size(2)) * s.mm(s.t())
                    H_support_features[d, :, :, :] = power_norm(
                        s / s.trace(), SIGMA)
                for d in range(query_features.size(0)):
                    s = query_features[d, :, :].squeeze(0)
                    s = (1.0 / query_features.size(2)) * s.mm(s.t())
                    H_query_features[d, :, :, :] = power_norm(
                        s / s.trace(), SIGMA)

                # calculate relations
                # each query support link to every supports to calculate relations
                # to form a 100x128 matrix for relation network
                support_features_ext = H_support_features.unsqueeze(0).repeat(
                    TEST_NUM_PER_CLASS * CLASS_NUM, 1, 1, 1, 1)
                query_features_ext = H_query_features.unsqueeze(0).repeat(
                    CLASS_NUM, 1, 1, 1, 1)
                query_features_ext = torch.transpose(query_features_ext, 0, 1)
                relation_pairs = torch.cat(
                    (support_features_ext, query_features_ext),
                    2).view(-1, 2, 64, 64)
                relations = relation_network(relation_pairs).view(
                    -1, CLASS_NUM)

                _, predict_labels = torch.max(relations.data, 1)
                rewards = [
                    1 if predict_labels[j] == query_labels[j].cuda(GPU) else 0
                    for j in range(CLASS_NUM * TEST_NUM_PER_CLASS)
                ]
                total_rewards += np.sum(rewards)

            test_accuracy = total_rewards / 1.0 / CLASS_NUM / TEST_NUM_PER_CLASS / TEST_EPISODE
            print("query accuracy:", test_accuracy)
            print("best accuracy:", best_accuracy)
            if test_accuracy > best_accuracy:
                best_accuracy = test_accuracy
def main(): print("init data folders") metatrain_folders, metaquery_folders = tg.mini_imagenet_folders() print("init neural networks") foreground_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU) background_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU) mixture_network = models.MixtureNetwork().apply(weights_init).cuda(GPU) relation_network = models.SimilarityNetwork( FEATURE_DIM, RELATION_DIM).apply(weights_init).cuda(GPU) # Loading models if os.path.exists( str(METHOD + "/miniImagenet_foreground_encoder_" + str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")): foreground_encoder.load_state_dict( torch.load( str(METHOD + "/miniImagenet_foreground_encoder_" + str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl"))) print("load foreground encoder success") if os.path.exists( str(METHOD + "/miniImagenet_background_encoder_" + str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")): background_encoder.load_state_dict( torch.load( str(METHOD + "/miniImagenet_background_encoder_" + str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl"))) print("load background encoder success") if os.path.exists( str(METHOD + "/miniImagenet_mixture_network_" + str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")): mixture_network.load_state_dict( torch.load( str(METHOD + "/miniImagenet_mixture_network_" + str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl"))) print("load mixture network success") if os.path.exists( str(METHOD + "/miniImagenet_relation_network_" + str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")): relation_network.load_state_dict( torch.load( str(METHOD + "/miniImagenet_relation_network_" + str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl"))) print("load relation network success") best_accuracy = 0.0 best_h = 0.0 for episode in range(EPISODE): with torch.no_grad(): # test print("Testing...") accuracies = [] for i in 
range(TEST_EPISODE): total_rewards = 0 counter = 0 task = tg.MiniImagenetTask(metaquery_folders, CLASS_NUM, SUPPORT_NUM_PER_CLASS, 15) support_dataloader = tg.get_mini_imagenet_data_loader( task, num_per_class=SUPPORT_NUM_PER_CLASS, split="train", shuffle=False) num_per_class = 2 query_dataloader = tg.get_mini_imagenet_data_loader( task, num_per_class=num_per_class, split="test", shuffle=True) support_img, support_sal, support_labels = support_dataloader.__iter__( ).next() for query_img, query_sal, query_labels in query_dataloader: query_size = query_labels.shape[0] # calculate foreground and background features support_foreground_features = foreground_encoder( Variable(support_img * support_sal).cuda(GPU)).view( CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64, 19, 19) support_background_features = background_encoder( Variable( support_img * (1 - support_sal)).cuda(GPU)).view( CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64, 19, 19) query_foreground_features = foreground_encoder( Variable(query_img * query_sal).cuda(GPU)) query_background_features = background_encoder( Variable(query_img * (1 - query_sal)).cuda(GPU)) # Inter-class Hallucination support_foreground_features = support_foreground_features.unsqueeze( 2).repeat(1, 1, CLASS_NUM * SUPPORT_NUM_PER_CLASS, 1, 1, 1) support_background_features = support_background_features.view( 1, 1, CLASS_NUM * SUPPORT_NUM_PER_CLASS, 64, 19, 19).repeat(CLASS_NUM, SUPPORT_NUM_PER_CLASS, 1, 1, 1, 1) similarity_measure = similarity_func( support_background_features, CLASS_NUM, SUPPORT_NUM_PER_CLASS).view(CLASS_NUM, SUPPORT_NUM_PER_CLASS, -1, 1, 1) support_mix_features = mixture_network( (support_foreground_features + support_background_features).view( (CLASS_NUM * SUPPORT_NUM_PER_CLASS)**2, 64, 19, 19)).view(CLASS_NUM, SUPPORT_NUM_PER_CLASS, -1, 64, 19**2) support_mix_features = (support_mix_features * similarity_measure).sum(2).sum(1) query_mix_features = mixture_network( query_foreground_features + query_background_features).view(-1, 64, 19**2) 
so_support_features = Variable( torch.Tensor(CLASS_NUM, 1, 64, 64)).cuda(GPU) so_query_features = Variable( torch.Tensor(query_size, 1, 64, 64)).cuda(GPU) # second-order features for d in range(support_mix_features.size()[0]): s = support_mix_features[d, :, :].squeeze(0) s = (1.0 / support_mix_features.size()[2]) * s.mm( s.transpose(0, 1)) so_support_features[d, :, :, :] = power_norm( s / s.trace(), SIGMA) for d in range(query_mix_features.size()[0]): s = query_mix_features[d, :, :].squeeze(0) s = (1.0 / query_mix_features.size()[2]) * s.mm( s.transpose(0, 1)) so_query_features[d, :, :, :] = power_norm( s / s.trace(), SIGMA) # calculate relations with 64x64 second-order features support_features_ext = so_support_features.unsqueeze( 0).repeat(query_size, 1, 1, 1, 1) query_features_ext = so_query_features.unsqueeze(0).repeat( 1 * CLASS_NUM, 1, 1, 1, 1) query_features_ext = torch.transpose( query_features_ext, 0, 1) relation_pairs = torch.cat( (support_features_ext, query_features_ext), 2).view(-1, 2, 64, 64) relations = relation_network(relation_pairs).view( -1, CLASS_NUM) _, predict_labels = torch.max(relations.data, 1) rewards = [ 1 if predict_labels[j] == query_labels[j].cuda(GPU) else 0 for j in range(query_size) ] total_rewards += np.sum(rewards) counter += query_size accuracy = total_rewards / 1.0 / counter accuracies.append(accuracy) test_accuracy, h = mean_confidence_interval(accuracies) print("test accuracy:", test_accuracy, "h:", h) if test_accuracy > best_accuracy: best_accuracy = test_accuracy best_h = h print("best accuracy:", best_accuracy, "h:", best_h)
def main():
    """Evaluate the miniImagenet HOP relation network (test only).

    Loads previously trained feature-encoder/relation-network weights from
    ``METHOD/`` when checkpoints exist, then repeatedly measures episodic test
    accuracy on the meta-query split, tracking the best accuracy and its
    confidence interval.

    NOTE(review): test tasks are built with 1 support per class while the
    reshapes use SUPPORT_NUM_PER_CLASS -- consistent only when
    SUPPORT_NUM_PER_CLASS == 1.  Confirm before running multi-shot.
    """
    metatrain_folders, metaquery_folders = tg.mini_imagenet_folders()

    print("init neural networks")
    feature_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    relation_network = models.SimilarityNetwork(
        FEATURE_DIM, RELATION_DIM).apply(weights_init).cuda(GPU)

    # NOTE(review): these optimizers/schedulers are constructed but never
    # stepped in this evaluation-only loop; kept for parity with the
    # training scripts in this file.
    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,
                                       step_size=50000,
                                       gamma=0.5)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                              lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,
                                        step_size=50000,
                                        gamma=0.5)

    feature_encoder_path = str(METHOD + "/feature_encoder_" + str(CLASS_NUM) +
                               "way_" + str(SUPPORT_NUM_PER_CLASS) +
                               "shot.pkl")
    relation_network_path = str(METHOD + "/relation_network_" +
                                str(CLASS_NUM) + "way_" +
                                str(SUPPORT_NUM_PER_CLASS) + "shot.pkl")
    if os.path.exists(feature_encoder_path):
        feature_encoder.load_state_dict(torch.load(feature_encoder_path))
        print("load feature encoder success")
    if os.path.exists(relation_network_path):
        relation_network.load_state_dict(torch.load(relation_network_path))
        print("load relation network success")
    if not os.path.exists(METHOD):
        # os.makedirs replaces the former os.system('mkdir ...'):
        # portable and not routed through a shell.
        os.makedirs(METHOD)

    # Step 3: build graph
    print("Training...")
    best_accuracy = 0.0
    best_h = 0.0
    for episode in range(EPISODE):
        with torch.no_grad():
            print("Testing...")
            accuracies = []
            for i in range(TEST_EPISODE):
                total_rewards = 0
                counter = 0
                task = tg.MiniImagenetTask(metaquery_folders, CLASS_NUM, 1, 2)
                support_dataloader = tg.get_mini_imagenet_data_loader(
                    task, num_per_class=1, split="train", shuffle=False)
                num_per_class = 2
                query_dataloader = tg.get_mini_imagenet_data_loader(
                    task, num_per_class=num_per_class, split="query",
                    shuffle=True)
                # next(iter(...)) is the Python-3 form of the old
                # .__iter__().next(), which raises AttributeError on Python 3.
                support_images, support_labels = next(
                    iter(support_dataloader))
                for query_images, query_labels in query_dataloader:
                    query_size = query_labels.shape[0]
                    # calculate features
                    support_features = feature_encoder(
                        Variable(support_images).cuda(GPU)).view(
                            CLASS_NUM, SUPPORT_NUM_PER_CLASS, 64,
                            19**2).sum(1)
                    query_features = feature_encoder(
                        Variable(query_images).cuda(GPU)).view(
                            num_per_class * CLASS_NUM, 64, 19**2)
                    H_support_features = Variable(
                        torch.Tensor(CLASS_NUM, 1, 64, 64)).cuda(GPU)
                    H_query_features = Variable(
                        torch.Tensor(num_per_class * CLASS_NUM, 1, 64,
                                     64)).cuda(GPU)
                    # HOP features: LAMBDA-centred, trace-normalised
                    # autocorrelation matrices through power_norm
                    for d in range(support_features.size()[0]):
                        s = support_features[d, :, :].squeeze(0)
                        s = s - LAMBDA * s.mean(1).repeat(
                            1, s.size()[1]).view(s.size())
                        s = (1.0 / support_features.size()[2]) * s.mm(
                            s.transpose(0, 1))
                        H_support_features[d, :, :, :] = power_norm(
                            s / s.trace(), SIGMA)
                    for d in range(query_features.size()[0]):
                        s = query_features[d, :, :].squeeze(0)
                        s = s - LAMBDA * s.mean(1).repeat(
                            1, s.size()[1]).view(s.size())
                        s = (1.0 / query_features.size()[2]) * s.mm(
                            s.transpose(0, 1))
                        H_query_features[d, :, :, :] = power_norm(
                            s / s.trace(), SIGMA)
                    # form relation pairs
                    support_features_ext = H_support_features.unsqueeze(
                        0).repeat(query_size, 1, 1, 1, 1)
                    query_features_ext = H_query_features.unsqueeze(0).repeat(
                        1 * CLASS_NUM, 1, 1, 1, 1)
                    query_features_ext = torch.transpose(
                        query_features_ext, 0, 1)
                    relation_pairs = torch.cat(
                        (support_features_ext, query_features_ext),
                        2).view(-1, 2, 64, 64)
                    # calculate relation scores
                    relations = relation_network(relation_pairs).view(
                        -1, CLASS_NUM)
                    _, predict_labels = torch.max(relations.data, 1)
                    rewards = [
                        1 if predict_labels[j] == query_labels[j].cuda(GPU)
                        else 0 for j in range(query_size)
                    ]
                    total_rewards += np.sum(rewards)
                    counter += query_size
                accuracy = total_rewards / 1.0 / counter
                accuracies.append(accuracy)
            test_accuracy, h = mean_confidence_interval(accuracies)
            print("Test accuracy:", test_accuracy, "h:", h)
            print("Best accuracy: ", best_accuracy, "h:", best_h)
            if test_accuracy > best_accuracy:
                best_accuracy = test_accuracy
                best_h = h