def main(argv):
  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')
  config = FLAGS.config
  game = config.game + 'NoFrameskip-v4'
  num_actions = env_utils.get_num_actions(game)
  print(f'Playing {game} with {num_actions} actions')
  module = models.ActorCritic(num_outputs=num_actions)
  key = jax.random.PRNGKey(0)
  key, subkey = jax.random.split(key)
  initial_params = models.get_initial_params(subkey, module)
  optimizer = models.create_optimizer(initial_params, config.learning_rate)
  optimizer = ppo_lib.train(module, optimizer, config, FLAGS.logdir)
def test_optimization_step(self):
  num_outputs = 4
  trn_data = self.generate_random_data(num_actions=num_outputs)
  clip_param = 0.1
  vf_coeff = 0.5
  entropy_coeff = 0.01
  lr = 2.5e-4
  batch_size = 256
  key = jax.random.PRNGKey(0)
  key, subkey = jax.random.split(key)
  module = models.ActorCritic(num_outputs)
  initial_params = models.get_initial_params(subkey, module)
  optimizer = models.create_optimizer(initial_params, lr)
  optimizer, _ = ppo_lib.train_step(
      module, optimizer, trn_data, clip_param, vf_coeff, entropy_coeff, lr,
      batch_size)
  self.assertTrue(isinstance(optimizer, flax.optim.base.Optimizer))
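# The test above calls a self.generate_random_data helper that is not shown in
# this snippet. A minimal sketch of what such a helper could look like,
# assuming the trajectory data consumed by ppo_lib.train_step is a tuple of
# (states, actions, old_log_probs, returns, advantages) arrays and that numpy
# is imported as np, as in the tests above; the exact layout is an assumption
# for illustration, not taken from the snippet.
def generate_random_data(self, num_actions):
  data_len = 256  # Matches the batch_size used in test_optimization_step.
  state_shape = (84, 84, 4)  # Stacked, preprocessed Atari frames.
  states = np.random.randint(0, 255, size=(data_len,) + state_shape)
  actions = np.random.choice(num_actions, size=data_len)
  old_log_probs = np.random.random(size=data_len)
  returns = np.random.random(size=data_len)
  advantages = np.random.random(size=data_len)
  return states, actions, old_log_probs, returns, advantages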
def test_model(self):
  key = jax.random.PRNGKey(0)
  key, subkey = jax.random.split(key)
  outputs = self.choose_random_outputs()
  module = models.ActorCritic(num_outputs=outputs)
  initial_params = models.get_initial_params(subkey, module)
  lr = 2.5e-4
  optimizer = models.create_optimizer(initial_params, lr)
  self.assertTrue(isinstance(optimizer, flax.optim.base.Optimizer))
  test_batch_size, obs_shape = 10, (84, 84, 4)
  random_input = np.random.random(size=(test_batch_size,) + obs_shape)
  log_probs, values = agent.policy_action(optimizer.target, module, random_input)
  self.assertEqual(values.shape, (test_batch_size, 1))
  sum_probs = np.sum(np.exp(log_probs), axis=1)
  self.assertEqual(sum_probs.shape, (test_batch_size,))
  np_testing.assert_allclose(sum_probs, np.ones((test_batch_size,)), atol=1e-6)
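# The final assertions above rely on the policy head returning per-state
# log-probabilities, so exponentiating them must give distributions that sum
# to one over actions. A small, self-contained sanity check of that property,
# added here for illustration (it assumes jax, np, and np_testing are imported
# as in the tests above; it is not part of the original test suite):
def test_log_softmax_normalization(self):
  logits = jax.numpy.array([[0.3, -1.2, 2.0, 0.0], [1.0, 1.0, 1.0, 1.0]])
  log_probs = jax.nn.log_softmax(logits)
  sums = np.exp(np.asarray(log_probs)).sum(axis=-1)
  np_testing.assert_allclose(sums, np.ones(2), atol=1e-6)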
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  FLAGS.log_dir = FLAGS.workdir
  FLAGS.stderrthreshold = 'info'
  logging.get_absl_handler().start_logging_to_file()
  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')
  config = FLAGS.config
  game = config.game + 'NoFrameskip-v4'
  num_actions = env_utils.get_num_actions(game)
  print(f'Playing {game} with {num_actions} actions')
  module = models.ActorCritic(num_outputs=num_actions)
  key = jax.random.PRNGKey(0)
  key, subkey = jax.random.split(key)
  initial_params = models.get_initial_params(subkey, module)
  optimizer = models.create_optimizer(initial_params, config.learning_rate)
  optimizer = ppo_lib.train(module, optimizer, config, FLAGS.workdir)
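# main above is written to be driven by absl, but the module-level flag and
# entry-point boilerplate is not shown in the snippet. A minimal sketch of
# what it could look like: the 'workdir' and 'config' flag names come from
# main's use of FLAGS.workdir and FLAGS.config, while the default values and
# the use of ml_collections.config_flags are assumptions for illustration.
from absl import app, flags
from ml_collections import config_flags

FLAGS = flags.FLAGS
flags.DEFINE_string('workdir', default='/tmp/ppo_training',
                    help='Directory to write logs and checkpoints to.')
config_flags.DEFINE_config_file(
    'config', 'configs/default.py',
    'File path to the hyperparameter configuration.')

if __name__ == '__main__':
  app.run(main)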