def main(): device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # * Step 1: init data folders print("init data folders") # * Init character folders for dataset construction metatrain_character_folders, metatest_character_folders = tg.omniglot_character_folders() # * Step 2: init neural networks print("init neural networks") feature_encoder = models.CNNEncoder() actor = models.Actor(FEATURE_DIM, RELATION_DIM, CLASS_NUM) critic = models.Critic(FEATURE_DIM, RELATION_DIM) #feature_encoder = torch.nn.DataParallel(feature_encoder) #actor = torch.nn.DataParallel(actor) #critic = torch.nn.DataParallel(critic) feature_encoder.train() actor.train() critic.train() feature_encoder.apply(models.weights_init) actor.apply(models.weights_init) critic.apply(models.weights_init) feature_encoder.to(device) actor.to(device) critic.to(device) cross_entropy = nn.CrossEntropyLoss() feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(), lr=LEARNING_RATE) feature_encoder_scheduler = StepLR(feature_encoder_optim, step_size=10000, gamma=0.5) actor_optim = torch.optim.Adam(actor.parameters(), lr=LEARNING_RATE) actor_scheduler = StepLR(actor_optim, step_size=10000, gamma=0.5) critic_optim = torch.optim.Adam(critic.parameters(), lr=LEARNING_RATE * 10) critic_scheduler = StepLR(critic_optim, step_size=10000, gamma=0.5) agent = a2cAgent.A2CAgent(actor, critic, GAMMA, ENTROPY_WEIGHT, FEATURE_DIM, RELATION_DIM, CLASS_NUM, device) if os.path.exists(str("./models/miniimagenet_feature_encoder_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")): feature_encoder.load_state_dict(torch.load(str("./models/miniimagenet_feature_encoder_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"))) print("load feature encoder success") if os.path.exists(str("./models/miniimagenet_actor_network_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")): actor.load_state_dict(torch.load(str("./models/miniimagenet_actor_network_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"))) print("load actor network success") if os.path.exists(str("./models/miniimagenet_critic_network_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")): critic.load_state_dict(torch.load(str("./models/miniimagenet_critic_network_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"))) print("load critic network success") # * Step 3: build graph print("Training...") last_accuracy = 0.0 mbal_loss_list = [] mbcl_loss_list = [] loss_list = [] number_of_query_image = 15 for episode in range(EPISODE): #print(f"EPISODE : {episode}") policy_losses = [] value_losses = [] for meta_batch in range(META_BATCH_RANGE): meta_env_states_list = [] meta_env_labels_list = [] for inner_batch in range(INNER_BATCH_RANGE): # * Generate environment env_states_list = [] env_labels_list = [] for env_num in range(ENV_LENGTH): task = tg.OmniglotTask(metatrain_character_folders, CLASS_NUM, SAMPLE_NUM_PER_CLASS, number_of_query_image) sample_dataloader = tg.get_data_loader(task, num_per_class=SAMPLE_NUM_PER_CLASS, split="train", shuffle=False) batch_dataloader = tg.get_data_loader(task, num_per_class=5, split="test", shuffle=True) samples, sample_labels = next(iter(sample_dataloader)) samples, sample_labels = samples.to(device), sample_labels.to(device) for batches, batch_labels in batch_dataloader: batches, batch_labels = batches.to(device), batch_labels.to(device) inner_sample_features = feature_encoder(samples) inner_sample_features = inner_sample_features.view(CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM, 5, 5) inner_sample_features = torch.sum(inner_sample_features, 1).squeeze(1) inner_batch_features = feature_encoder(batches) inner_sample_feature_ext = inner_sample_features.unsqueeze(0).repeat(5 * CLASS_NUM, 1, 1, 1, 1) inner_batch_features_ext = inner_batch_features.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1) inner_batch_features_ext = torch.transpose(inner_batch_features_ext, 0, 1) inner_relation_pairs = torch.cat((inner_sample_feature_ext, inner_batch_features_ext), 2).view(-1, FEATURE_DIM * 2, 5, 5) env_states_list.append(inner_relation_pairs) env_labels_list.append(batch_labels) inner_env = a2cAgent.env(env_states_list, env_labels_list) agent.train(inner_env, inner_update=True) task = tg.OmniglotTask(metatrain_character_folders, CLASS_NUM, SAMPLE_NUM_PER_CLASS, number_of_query_image) sample_dataloader = tg.get_data_loader(task, num_per_class=SAMPLE_NUM_PER_CLASS, split="train", shuffle=False) batch_dataloader = tg.get_data_loader(task, num_per_class=number_of_query_image, split="test", shuffle=True) samples, sample_labels = next(iter(sample_dataloader)) samples, sample_labels = samples.to(device), sample_labels.to(device) # * Generate env for meta update # * init dataset # * sample_dataloader is to obtain previous samples for compare # * batch_dataloader is to batch samples for training # * num_per_class : number of query images # * sample datas batches, batch_labels = next(iter(batch_dataloader)) batches, batch_labels = batches.to(device), batch_labels.to(device) # * calculates features #feature_encoder.weight = feature_fast_weights sample_features = feature_encoder(samples) sample_features = sample_features.view(CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM, 5, 5) sample_features = torch.sum(sample_features, 1).squeeze(1) batch_features = feature_encoder(batches) # * calculate relations # * each batch sample link to every samples to calculate relations # * to form a 100 * 128 matrix for relation network sample_features_ext = sample_features.unsqueeze(0).repeat(number_of_query_image * CLASS_NUM, 1, 1, 1, 1) batch_features_ext = batch_features.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1) batch_features_ext = torch.transpose(batch_features_ext, 0, 1) relation_pairs = torch.cat((sample_features_ext, batch_features_ext), 2).view(-1, FEATURE_DIM * 2, 5, 5) meta_env_states_list.append(relation_pairs) meta_env_labels_list.append(batch_labels) meta_env = a2cAgent.env(meta_env_states_list, meta_env_labels_list) agent.train(meta_env, policy_loss_list=policy_losses, value_loss_list=value_losses) feature_encoder_optim.zero_grad() actor_optim.zero_grad() critic_optim.zero_grad() torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5) torch.nn.utils.clip_grad_norm_(actor.parameters(), 0.5) torch.nn.utils.clip_grad_norm_(critic.parameters(), 0.5) meta_batch_actor_loss = torch.stack(policy_losses).mean() meta_batch_critic_loss = torch.stack(value_losses).mean() meta_batch_actor_loss.backward(retain_graph=True) meta_batch_critic_loss.backward() feature_encoder_optim.step() actor_optim.step() critic_optim.step() feature_encoder_scheduler.step(episode) actor_scheduler.step(episode) critic_scheduler.step(episode) if (episode + 1) % 100 == 0: mbal = meta_batch_actor_loss.cpu().detach().numpy() mbcl = meta_batch_critic_loss.cpu().detach().numpy() print(f"episode : {episode+1}, meta_batch_actor_loss : {mbal:.4f}, meta_batch_critic_loss : {mbcl:.4f}") mbal_loss_list.append(mbal) mbcl_loss_list.append(mbcl) loss_list.append(mbal + mbcl) if (episode + 1) % 500 == 0: print("Testing...") total_reward = 0 total_test_samples = 0 for i in range(TEST_EPISODE): # * Generate env env_states_list = [] env_labels_list = [] number_of_query_image = 10 task = tg.OmniglotTask(metatest_character_folders, CLASS_NUM, SAMPLE_NUM_PER_CLASS, number_of_query_image) sample_dataloader = tg.get_data_loader(task, num_per_class=SAMPLE_NUM_PER_CLASS, split="train", shuffle=False) test_dataloader = tg.get_data_loader(task, num_per_class=number_of_query_image, split="test", shuffle=True) sample_images, sample_labels = next(iter(sample_dataloader)) sample_images, sample_labels = sample_images.to(device), sample_labels.to(device) test_images, test_labels = next(iter(test_dataloader)) total_test_samples += len(test_labels) test_images, test_labels = test_images.to(device), test_labels.to(device) # * calculate features sample_features = feature_encoder(sample_images) sample_features = sample_features.view(CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM, 5, 5) sample_features = torch.sum(sample_features, 1).squeeze(1) test_features = feature_encoder(test_images) # * calculate relations # * each batch sample link to every samples to calculate relations # * to form a 100x128 matrix for relation network sample_features_ext = sample_features.unsqueeze(0).repeat(number_of_query_image * CLASS_NUM, 1, 1, 1, 1) test_features_ext = test_features.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1) test_features_ext = torch.transpose(test_features_ext, 0, 1) relation_pairs = torch.cat((sample_features_ext, test_features_ext), 2).view(-1, FEATURE_DIM * 2, 5, 5) env_states_list.append(relation_pairs) env_labels_list.append(test_labels) test_env = a2cAgent.env(env_states_list, env_labels_list) rewards = agent.test(test_env) total_reward += rewards test_accuracy = total_reward / (1.0 * total_test_samples) mean_loss = np.mean(loss_list) mean_actor_loss = np.mean(mbal_loss_list) mean_critic_loss = np.mean(mbcl_loss_list) print(f'mean loss : {mean_loss}') print("test accuracy : ", test_accuracy) writer.add_scalar('1.loss', mean_loss, episode + 1) writer.add_scalar('2.mean_actor_loss', mean_actor_loss, episode + 1) writer.add_scalar('3.mean_critic_loss', mean_critic_loss, episode + 1) writer.add_scalar('4.test accuracy', test_accuracy, episode + 1) loss_list = [] mbal_loss_list = [] mbcl_loss_list = [] if test_accuracy > last_accuracy: # save networks torch.save( feature_encoder.state_dict(), str("./models/miniimagenet_feature_encoder_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl") ) torch.save( actor.state_dict(), str("./models/miniimagenet_actor_network_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl") ) torch.save( critic.state_dict(), str("./models/miniimagenet_critic_network_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl") ) print("save networks for episode:", episode) last_accuracy = test_accuracy
def main(): device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # * Step 1: init data folders print("init data folders") # * Init character folders for dataset construction metatrain_character_folders, metatest_character_folders = tg.mini_imagenet_folders( ) # * Step 2: init neural networks print("init neural networks") feature_encoder = models.CNNEncoder() actor = models.Actor(FEATURE_DIM, RELATION_DIM, CLASS_NUM) critic = models.Critic(FEATURE_DIM, RELATION_DIM) #feature_encoder = torch.nn.DataParallel(feature_encoder) #actor = torch.nn.DataParallel(actor) #critic = torch.nn.DataParallel(critic) feature_encoder.train() actor.train() critic.train() feature_encoder.apply(models.weights_init) actor.apply(models.weights_init) critic.apply(models.weights_init) feature_encoder.to(device) actor.to(device) critic.to(device) agent = a2cAgent.A2CAgent(actor, critic, GAMMA, ENTROPY_WEIGHT, FEATURE_DIM, RELATION_DIM, CLASS_NUM, device) #feature_encoder.eval() #relation_network.eval() if os.path.exists( str("./models/miniimagenet_feature_encoder_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")): feature_encoder.load_state_dict( torch.load( str("./models/miniimagenet_feature_encoder_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"))) print("load feature encoder success") if os.path.exists( str("./models/miniimagenet_actor_network_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")): actor.load_state_dict( torch.load( str("./models/miniimagenet_actor_network_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"))) print("load actor network success") if os.path.exists( str("./models/miniimagenet_critic_network_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")): critic.load_state_dict( torch.load( str("./models/miniimagenet_critic_network_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"))) print("load critic network success") max_accuracy_list = [] mean_accuracy_list = [] for episode in range(1): total_accuracy = [] for i in range(TEST_EPISODE): # * Generate env env_states_list = [] env_labels_list = [] number_of_query_image = 15 task = tg.MiniImagenetTask(metatest_character_folders, CLASS_NUM, SAMPLE_NUM_PER_CLASS, number_of_query_image) sample_dataloader = tg.get_mini_imagenet_data_loader( task, num_per_class=SAMPLE_NUM_PER_CLASS, split="train", shuffle=False) test_dataloader = tg.get_mini_imagenet_data_loader( task, num_per_class=number_of_query_image, split="test", shuffle=True) sample_images, sample_labels = next(iter(sample_dataloader)) test_images, test_labels = next(iter(test_dataloader)) sample_images, sample_labels = sample_images.to( device), sample_labels.to(device) test_images, test_labels = test_images.to(device), test_labels.to( device) # * calculate features sample_features = feature_encoder(sample_images) sample_features = sample_features.view(CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM, 19, 19) sample_features = torch.sum(sample_features, 1).squeeze(1) test_features = feature_encoder(test_images) # * calculate relations # * each batch sample link to every samples to calculate relations # * to form a 100x128 matrix for relation network sample_features_ext = sample_features.unsqueeze(0).repeat( number_of_query_image * CLASS_NUM, 1, 1, 1, 1) test_features_ext = test_features.unsqueeze(0).repeat( CLASS_NUM, 1, 1, 1, 1) test_features_ext = torch.transpose(test_features_ext, 0, 1) relation_pairs = torch.cat( (sample_features_ext, test_features_ext), 2).view(-1, FEATURE_DIM * 2, 19, 19) env_states_list.append(relation_pairs) env_labels_list.append(test_labels) test_env = a2cAgent.env(env_states_list, env_labels_list) rewards = agent.test(test_env) test_accuracy = rewards / len(test_labels) print(test_accuracy) total_accuracy.append(test_accuracy) mean_accuracy, conf_int = mean_confidence_interval(total_accuracy) print(f"Total accuracy : {mean_accuracy:.4f}") print(f"confidence interval : {conf_int:.4f}")