Example #1
def main(argv):
    # Make sure tf does not allocate gpu memory.
    tf.config.experimental.set_visible_devices([], 'GPU')
    config = FLAGS.config
    game = config.game + 'NoFrameskip-v4'
    num_actions = env_utils.get_num_actions(game)
    print(f'Playing {game} with {num_actions} actions')
    module = models.ActorCritic(num_outputs=num_actions)
    key = jax.random.PRNGKey(0)
    key, subkey = jax.random.split(key)
    initial_params = models.get_initial_params(subkey, module)
    optimizer = models.create_optimizer(initial_params, config.learning_rate)
    optimizer = ppo_lib.train(module, optimizer, config, FLAGS.logdir)
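The function above expects FLAGS.config to hold a config object with at least game and learning_rate, and FLAGS.logdir to point at an output directory. A minimal config sketch, assuming an ml_collections-style setup (the concrete values are placeholders, not the example's defaults):

# Hypothetical config file; only the fields read by main() above are shown.
import ml_collections

def get_config():
    config = ml_collections.ConfigDict()
    config.game = 'Pong'           # expanded to 'PongNoFrameskip-v4' in main()
    config.learning_rate = 2.5e-4  # placeholder value
    return config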
Example #2
 def test_optimization_step(self):
     num_outputs = 4
     trn_data = self.generate_random_data(num_actions=num_outputs)
     clip_param = 0.1
     vf_coeff = 0.5
     entropy_coeff = 0.01
     lr = 2.5e-4
     batch_size = 256
     key = jax.random.PRNGKey(0)
     key, subkey = jax.random.split(key)
     module = models.ActorCritic(num_outputs)
     initial_params = models.get_initial_params(subkey, module)
     optimizer = models.create_optimizer(initial_params, lr)
     optimizer, _ = ppo_lib.train_step(module, optimizer, trn_data,
                                       clip_param, vf_coeff, entropy_coeff,
                                       lr, batch_size)
     self.assertTrue(isinstance(optimizer, flax.optim.base.Optimizer))
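train_step consumes the batch produced by generate_random_data, a test helper that is not shown here. A hypothetical stand-in is sketched below; the field order and shapes (observations, actions, old log-probabilities, returns, advantages) are assumptions about a typical PPO minibatch, not the helper's actual code.

# Hypothetical stand-in for generate_random_data(); shapes and field order are assumptions.
import numpy as np

def generate_random_data(num_actions, batch=256, obs_shape=(84, 84, 4)):
    observations = np.random.random(size=(batch,) + obs_shape).astype(np.float32)
    actions = np.random.randint(0, num_actions, size=(batch,))
    old_log_probs = np.log(np.full((batch,), 1.0 / num_actions, dtype=np.float32))
    returns = np.random.random(size=(batch,)).astype(np.float32)
    advantages = np.random.random(size=(batch,)).astype(np.float32)
    return observations, actions, old_log_probs, returns, advantages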
Example #3
 def test_model(self):
     key = jax.random.PRNGKey(0)
     key, subkey = jax.random.split(key)
     outputs = self.choose_random_outputs()
     module = models.ActorCritic(num_outputs=outputs)
     initial_params = models.get_initial_params(subkey, module)
     lr = 2.5e-4
     optimizer = models.create_optimizer(initial_params, lr)
     self.assertTrue(isinstance(optimizer, flax.optim.base.Optimizer))
     test_batch_size, obs_shape = 10, (84, 84, 4)
     random_input = np.random.random(size=(test_batch_size, ) + obs_shape)
     log_probs, values = agent.policy_action(optimizer.target, module,
                                             random_input)
     self.assertEqual(values.shape, (test_batch_size, 1))
     sum_probs = np.sum(np.exp(log_probs), axis=1)
     self.assertEqual(sum_probs.shape, (test_batch_size, ))
     np_testing.assert_allclose(sum_probs,
                                np.ones((test_batch_size, )),
                                atol=1e-6)
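Since policy_action returns per-action log-probabilities (the assertions above check that they exponentiate to a valid distribution), they can be passed directly to jax.random.categorical to sample actions. A brief usage sketch, reusing the module, optimizer, and random_input from the test:

# Sketch: sampling actions from the policy output verified above.
import jax

sample_key = jax.random.PRNGKey(1)
log_probs, values = agent.policy_action(optimizer.target, module, random_input)
actions = jax.random.categorical(sample_key, log_probs, axis=-1)  # shape: (test_batch_size,)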
Example #4
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    FLAGS.log_dir = FLAGS.workdir
    FLAGS.stderrthreshold = 'info'
    logging.get_absl_handler().start_logging_to_file()

    # Make sure tf does not allocate gpu memory.
    tf.config.experimental.set_visible_devices([], 'GPU')
    config = FLAGS.config
    game = config.game + 'NoFrameskip-v4'
    num_actions = env_utils.get_num_actions(game)
    print(f'Playing {game} with {num_actions} actions')
    module = models.ActorCritic(num_outputs=num_actions)
    key = jax.random.PRNGKey(0)
    key, subkey = jax.random.split(key)
    initial_params = models.get_initial_params(subkey, module)
    optimizer = models.create_optimizer(initial_params, config.learning_rate)
    optimizer = ppo_lib.train(module, optimizer, config, FLAGS.workdir)
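This variant is the same entry point with absl file logging redirected to FLAGS.workdir before training; it assumes command-line flags named workdir and config are defined elsewhere. A hedged sketch of that wiring (flag defaults and the launcher line are assumptions, not the example's actual script):

# Hypothetical flag definitions for the main() above; only the names
# 'workdir' and 'config' are taken from the snippet.
from absl import app, flags
from ml_collections import config_flags

flags.DEFINE_string('workdir', '/tmp/ppo', 'Directory for logs and checkpoints.')
config_flags.DEFINE_config_file('config', None, 'Training configuration file.')
FLAGS = flags.FLAGS

if __name__ == '__main__':
    # e.g. python ppo_main.py --workdir=/tmp/ppo --config=configs/default.py
    app.run(main)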