Exemple #1
0
  def test_agent_is_checkpointable(self):
    """Check that saving and restoring the agent round-trips its variables.

    The agent is run once, checkpointed, perturbed by adding 1 to every
    trainable variable, run again, restored from the checkpoint and run a
    third time. The restored run must reproduce the first run's policy
    logits, while the perturbed run must not.
    """
    agent = networks.ImpalaDeep(9)
    output0 = _run_actor(agent)

    # Use the per-test temporary directory rather than a fixed path under
    # /tmp, so concurrent or repeated runs cannot pick up each other's
    # checkpoint files through `latest_checkpoint`.
    checkpoint_dir = os.path.join(self.get_temp_dir(), 'training_checkpoints')
    checkpoint_prefix = os.path.join(checkpoint_dir, 'model.ckpt')
    ckpt = tf.train.Checkpoint(agent=agent)

    ckpt.save(file_prefix=checkpoint_prefix)

    # Shift every trainable variable so the restore below has a visible
    # effect on the outputs.
    for v in agent.trainable_variables:
      v.assign_add(tf.broadcast_to(1., v.shape))

    output1 = _run_actor(agent)

    ckpt_path = tf.train.latest_checkpoint(checkpoint_dir)
    # Fail with a clear message if nothing was written, instead of passing
    # None to `restore`.
    self.assertIsNotNone(ckpt_path)
    ckpt.restore(ckpt_path).assert_consumed()

    output2 = _run_actor(agent)

    self.assertEqual(len(agent.trainable_variables), 39)
    self.assertAllEqual(output0[0].policy_logits, output2[0].policy_logits)
    self.assertNotAllEqual(output0[0].policy_logits, output1[0].policy_logits)
Exemple #2
0
def create_agent(action_space, unused_env_observation_space,
                 unused_parametric_action_distribution):
    """Build an `ImpalaDeep` network from the given action space.

    Args:
      action_space: Object whose `n` attribute is forwarded to
        `networks.ImpalaDeep` as its only argument.
      unused_env_observation_space: Ignored; accepted only so the signature
        matches what callers pass.
      unused_parametric_action_distribution: Ignored; accepted only so the
        signature matches what callers pass.

    Returns:
      The newly constructed `networks.ImpalaDeep` instance.
    """
    num_actions = action_space.n
    agent = networks.ImpalaDeep(num_actions)
    return agent