def main(argv):
    """Build an ActorCritic for the configured Atari game and run PPO training.

    Args:
      argv: positional command-line arguments (presumably supplied by the app
        runner); not inspected here.
    """
    # Hide all GPUs from TensorFlow so it does not grab GPU memory.
    tf.config.experimental.set_visible_devices([], 'GPU')
    cfg = FLAGS.config
    game_name = cfg.game + 'NoFrameskip-v4'
    action_count = env_utils.get_num_actions(game_name)
    print(f'Playing {game_name} with {action_count} actions')
    actor_critic = models.ActorCritic(num_outputs=action_count)
    ppo_lib.train(actor_critic, cfg, FLAGS.workdir)
def main(argv):
    """Create an ActorCritic plus its optimizer and hand both to PPO training.

    Args:
      argv: positional command-line arguments; not inspected here.
    """
    # Hide all GPUs from TensorFlow so it does not grab GPU memory.
    tf.config.experimental.set_visible_devices([], 'GPU')
    cfg = FLAGS.config
    game_name = cfg.game + 'NoFrameskip-v4'
    action_count = env_utils.get_num_actions(game_name)
    print(f'Playing {game_name} with {action_count} actions')
    actor_critic = models.ActorCritic(num_outputs=action_count)
    # Split the fixed seed once; the second key initializes the parameters.
    _, init_key = jax.random.split(jax.random.PRNGKey(0))
    params = models.get_initial_params(init_key, actor_critic)
    opt = models.create_optimizer(params, cfg.learning_rate)
    ppo_lib.train(actor_critic, opt, cfg, FLAGS.logdir)
def test_model(self):
    """Policy outputs: one value per example and rows of exp(log_probs) summing to 1."""
    num_outputs = self.choose_random_outputs()
    actor_critic = models.ActorCritic(num_outputs=num_outputs)
    params = ppo_lib.get_initial_params(jax.random.PRNGKey(0), actor_critic)
    batch = 10
    frame_shape = (84, 84, 4)
    observations = np.random.random(size=(batch, *frame_shape))
    log_probs, values = agent.policy_action(actor_critic.apply, params,
                                            observations)
    self.assertEqual(values.shape, (batch, 1))
    row_totals = np.sum(np.exp(log_probs), axis=1)
    self.assertEqual(row_totals.shape, (batch,))
    np_testing.assert_allclose(row_totals, np.ones((batch,)), atol=1e-6)
def test_optimization_step(self):
    """A single ppo_lib.train_step call should return a flax Optimizer."""
    num_outputs = 4
    trn_data = self.generate_random_data(num_actions=num_outputs)
    clip_param = 0.1
    vf_coeff = 0.5
    entropy_coeff = 0.01
    lr = 2.5e-4
    batch_size = 256
    # Only the second key of the split is needed (for parameter init).
    _, subkey = jax.random.split(jax.random.PRNGKey(0))
    module = models.ActorCritic(num_outputs)
    initial_params = models.get_initial_params(subkey, module)
    optimizer = models.create_optimizer(initial_params, lr)
    optimizer, _ = ppo_lib.train_step(module, optimizer, trn_data, clip_param,
                                      vf_coeff, entropy_coeff, lr, batch_size)
    # assertIsInstance reports the actual type on failure.
    self.assertIsInstance(optimizer, flax.optim.base.Optimizer)
def test_model(self):
    """Check optimizer creation and the shape/normalization of policy outputs."""
    # Only the second key of the split is needed (for parameter init).
    _, subkey = jax.random.split(jax.random.PRNGKey(0))
    outputs = self.choose_random_outputs()
    module = models.ActorCritic(num_outputs=outputs)
    initial_params = models.get_initial_params(subkey, module)
    lr = 2.5e-4
    optimizer = models.create_optimizer(initial_params, lr)
    # assertIsInstance reports the actual type on failure.
    self.assertIsInstance(optimizer, flax.optim.base.Optimizer)
    test_batch_size, obs_shape = 10, (84, 84, 4)
    random_input = np.random.random(size=(test_batch_size,) + obs_shape)
    log_probs, values = agent.policy_action(optimizer.target, module,
                                            random_input)
    self.assertEqual(values.shape, (test_batch_size, 1))
    # Each row of exp(log_probs) is expected to sum to 1.
    sum_probs = np.sum(np.exp(log_probs), axis=1)
    self.assertEqual(sum_probs.shape, (test_batch_size,))
    np_testing.assert_allclose(sum_probs, np.ones((test_batch_size,)),
                               atol=1e-6)
def main(argv):
    """Set up absl file logging, then build and train a PPO ActorCritic.

    Args:
      argv: positional command-line arguments; anything beyond the program
        name is rejected.

    Raises:
      app.UsageError: if extra positional arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Write absl logs into the work directory and start file logging.
    FLAGS.log_dir = FLAGS.workdir
    FLAGS.stderrthreshold = 'info'
    logging.get_absl_handler().start_logging_to_file()

    # Hide all GPUs from TensorFlow so it does not grab GPU memory.
    tf.config.experimental.set_visible_devices([], 'GPU')
    cfg = FLAGS.config
    game_name = cfg.game + 'NoFrameskip-v4'
    action_count = env_utils.get_num_actions(game_name)
    print(f'Playing {game_name} with {action_count} actions')
    actor_critic = models.ActorCritic(num_outputs=action_count)
    # Split the fixed seed once; the second key initializes the parameters.
    _, init_key = jax.random.split(jax.random.PRNGKey(0))
    params = models.get_initial_params(init_key, actor_critic)
    opt = models.create_optimizer(params, cfg.learning_rate)
    ppo_lib.train(actor_critic, opt, cfg, FLAGS.workdir)
def test_optimization_step(self):
    """A single ppo_lib.train_step call should yield a flax TrainState."""
    num_outputs = 4
    trn_data = self.generate_random_data(num_actions=num_outputs)
    actor_critic = models.ActorCritic(num_outputs)
    params = ppo_lib.get_initial_params(jax.random.PRNGKey(0), actor_critic)
    cfg = ml_collections.ConfigDict({
        'learning_rate': 2.5e-4,
        'decaying_lr_and_clip_param': True,
    })
    # NOTE(review): 1000 is presumably the total number of training steps used
    # by the decaying schedule — confirm against create_train_state.
    state = ppo_lib.create_train_state(params, actor_critic, cfg, 1000)
    ppo_kwargs = dict(clip_param=0.1, vf_coeff=0.5, entropy_coeff=0.01)
    state, _ = ppo_lib.train_step(state, trn_data, 256, **ppo_kwargs)
    self.assertIsInstance(state, train_state.TrainState)
def _checkpoint_path(network_name):
    """Return the ./models/ checkpoint path for ``network_name`` in the current N-way/K-shot setting."""
    return ("./models/miniimagenet_" + network_name + "_" + str(CLASS_NUM) +
            "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")


def _sample_task_tensors(character_folders, num_query, device):
    """Sample one few-shot task; return (support images, query images, query labels) on ``device``.

    Support images come from the unshuffled "train" split
    (SAMPLE_NUM_PER_CLASS per class); query images/labels from the shuffled
    "test" split (``num_query`` per class). Support labels are not returned
    because no caller uses them.
    """
    task = tg.MiniImagenetTask(character_folders, CLASS_NUM,
                               SAMPLE_NUM_PER_CLASS, num_query)
    sample_dataloader = tg.get_mini_imagenet_data_loader(
        task, num_per_class=SAMPLE_NUM_PER_CLASS, split="train", shuffle=False)
    query_dataloader = tg.get_mini_imagenet_data_loader(
        task, num_per_class=num_query, split="test", shuffle=True)
    samples, _ = next(iter(sample_dataloader))
    queries, query_labels = next(iter(query_dataloader))
    return samples.to(device), queries.to(device), query_labels.to(device)


def _relation_pairs(feature_encoder, samples, queries, num_query):
    """Pair every query's features with every class's summed support features.

    Assumes feature_encoder emits (N, FEATURE_DIM, 19, 19) maps, as the view
    below requires. Support features are summed over the shots of each class
    and concatenated channel-wise with each query's features; the result has
    shape (num_query * CLASS_NUM * CLASS_NUM, FEATURE_DIM * 2, 19, 19).
    """
    sample_features = feature_encoder(samples).view(
        CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM, 19, 19)
    sample_features = torch.sum(sample_features, 1).squeeze(1)
    query_features = feature_encoder(queries)
    sample_features_ext = sample_features.unsqueeze(0).repeat(
        num_query * CLASS_NUM, 1, 1, 1, 1)
    query_features_ext = torch.transpose(
        query_features.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1), 0, 1)
    return torch.cat((sample_features_ext, query_features_ext),
                     2).view(-1, FEATURE_DIM * 2, 19, 19)


def main():
    """Meta-train a feature encoder and an A2C ActorCritic on miniImageNet tasks.

    Every 100 episodes the meta loss is printed and recorded; every 500
    episodes the networks are evaluated on meta-test tasks and checkpointed
    under ./models/ whenever test accuracy improves on the best so far.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Query images per class. Kept in separate variables so that evaluation
    # never changes the query count used by subsequent training episodes.
    train_query_num = 15
    test_query_num = 10

    # * Step 1: init data folders
    print("init data folders")
    metatrain_character_folders, metatest_character_folders = (
        tg.mini_imagenet_folders())

    # * Step 2: init neural networks
    print("init neural networks")
    feature_encoder = models.CNNEncoder()
    model = models.ActorCritic(FEATURE_DIM, RELATION_DIM, CLASS_NUM)
    feature_encoder.train()
    model.train()
    feature_encoder.apply(models.weights_init)
    model.apply(models.weights_init)
    feature_encoder.to(device)
    model.to(device)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,
                                       step_size=10000, gamma=0.5)
    model_optim = torch.optim.Adam(model.parameters(), lr=2.5 * LEARNING_RATE)
    model_scheduler = StepLR(model_optim, step_size=10000, gamma=0.5)
    agent = a2cAgent.A2CAgent(GAMMA, ENTROPY_WEIGHT, CLASS_NUM, device)

    feature_encoder_path = _checkpoint_path("feature_encoder")
    actor_network_path = _checkpoint_path("actor_network")
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
    if os.path.exists(feature_encoder_path):
        feature_encoder.load_state_dict(
            torch.load(feature_encoder_path, map_location=device))
        print("load feature encoder success")
    if os.path.exists(actor_network_path):
        model.load_state_dict(
            torch.load(actor_network_path, map_location=device))
        print("load actor network success")

    # * Step 3: build graph
    print("Training...")
    last_accuracy = 0.0
    loss_list = []
    for episode in range(EPISODE):
        losses = []
        for _ in range(META_BATCH_RANGE):
            meta_env_states_list = []
            meta_env_labels_list = []
            model_fast_weight = OrderedDict(model.named_parameters())
            for _ in range(INNER_BATCH_RANGE):
                # * Generate environment
                env_states_list = []
                env_labels_list = []
                for _ in range(ENV_LENGTH):
                    samples, batches, batch_labels = _sample_task_tensors(
                        metatrain_character_folders, train_query_num, device)
                    env_states_list.append(
                        _relation_pairs(feature_encoder, samples, batches,
                                        train_query_num))
                    env_labels_list.append(batch_labels)
                inner_env = a2cAgent.env(env_states_list, env_labels_list)
                inner_loss = agent.train(inner_env, model)
                inner_gradients = torch.autograd.grad(
                    inner_loss.mean(), model_fast_weight.values(),
                    create_graph=True, allow_unused=True)
                model_fast_weight = OrderedDict(
                    (name, param - INNER_LR * (0 if grad is None else grad))
                    for ((name, param), grad) in zip(
                        model_fast_weight.items(), inner_gradients))
                # NOTE(review): an OrderedDict is neither a Parameter nor a
                # Module, so nn.Module stores this as a plain attribute; unless
                # ActorCritic.forward or agent.train reads `model.weight`, the
                # adapted fast weights are never used in a forward pass (and
                # later inner gradients come back as None). Confirm.
                model.weight = model_fast_weight

            for _ in range(META_ENV_LENGTH):
                samples, batches, batch_labels = _sample_task_tensors(
                    metatrain_character_folders, train_query_num, device)
                meta_env_states_list.append(
                    _relation_pairs(feature_encoder, samples, batches,
                                    train_query_num))
                meta_env_labels_list.append(batch_labels)
            meta_env = a2cAgent.env(meta_env_states_list, meta_env_labels_list)
            losses.append(agent.train(meta_env, model))

        feature_encoder_optim.zero_grad()
        model_optim.zero_grad()
        meta_batch_loss = torch.stack(losses).mean()
        meta_batch_loss.backward()
        # Clip only after backward() has populated .grad; clipping before it
        # (right after zero_grad) would have no effect on the update.
        torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        feature_encoder_optim.step()
        model_optim.step()
        feature_encoder_scheduler.step()
        model_scheduler.step()

        if (episode + 1) % 100 == 0:
            mean_loss = meta_batch_loss.cpu().detach().numpy()
            print(f"episode : {episode+1}, meta_loss : {mean_loss:.4f}")
            loss_list.append(mean_loss)

        if (episode + 1) % 500 == 0:
            print("Testing...")
            total_reward = 0
            total_num_of_test_samples = 0
            # Evaluation needs no autograd graph. NOTE(review): the networks
            # stay in train() mode here, so any BatchNorm running stats are
            # still updated with meta-test batches — confirm this is intended.
            with torch.no_grad():
                for _ in range(TEST_EPISODE):
                    sample_images, test_images, test_labels = (
                        _sample_task_tensors(metatest_character_folders,
                                             test_query_num, device))
                    total_num_of_test_samples += len(test_labels)
                    relation_pairs = _relation_pairs(
                        feature_encoder, sample_images, test_images,
                        test_query_num)
                    test_env = a2cAgent.env([relation_pairs], [test_labels])
                    total_reward += agent.test(test_env, model)

            test_accuracy = total_reward / (1.0 * total_num_of_test_samples)
            mean_loss = np.mean(loss_list)
            print(f'mean loss : {mean_loss}')
            print("test accuracy : ", test_accuracy)
            writer.add_scalar('1.loss', mean_loss, episode + 1)
            writer.add_scalar('4.test accuracy', test_accuracy, episode + 1)
            loss_list = []

            if test_accuracy > last_accuracy:
                # save networks; create ./models first so saving cannot fail
                # on a fresh checkout.
                os.makedirs(os.path.dirname(feature_encoder_path),
                            exist_ok=True)
                torch.save(feature_encoder.state_dict(), feature_encoder_path)
                torch.save(model.state_dict(), actor_network_path)
                print("save networks for episode:", episode)
                last_accuracy = test_accuracy