Example #1
def main(_):
  logging.set_verbosity(logging.INFO)
  if common.has_eager_been_enabled():
    return 0
  tf.enable_resource_variables()
  TrainEval(FLAGS.root_dir, suite_atari.game(name=FLAGS.game_name),
            **get_run_args()).run()
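These `main(_)` entrypoints rely on absl flags defined elsewhere in the binary. For context, a minimal, hypothetical absl entrypoint (the flag definitions below are illustrative, not taken from the snippet) would look like:

from absl import app
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_string('root_dir', None, 'Root directory for logs and checkpoints.')
flags.DEFINE_string('game_name', 'Pong', 'Name of the Atari game to train on.')

if __name__ == '__main__':
  flags.mark_flag_as_required('root_dir')
  app.run(main)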
Example #2
    def testUpdateFromCheckpoint(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in TF2.x.')

        path = os.path.join(self.get_temp_dir(), 'saved_policy')
        saver = policy_saver.PolicySaver(self.tf_policy)
        saver.save(path)
        self.evaluate(
            tf.nest.map_structure(lambda v: v.assign(v * 0 + -1),
                                  self.tf_policy.variables()))
        checkpoint_path = os.path.join(self.get_temp_dir(), 'checkpoint')
        saver.save_checkpoint(checkpoint_path)

        eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
            path, self.time_step_spec, self.action_spec)

        # Use evaluate to force a copy.
        saved_model_variables = self.evaluate(eager_py_policy.variables())

        eager_py_policy.update_from_checkpoint(checkpoint_path)

        assert_np_not_equal = lambda a, b: self.assertFalse(
            np.equal(a, b).all())
        tf.nest.map_structure(assert_np_not_equal, saved_model_variables,
                              self.evaluate(eager_py_policy.variables()))

        assert_np_all_equal = lambda a, b: self.assertTrue(
            np.equal(a, b).all())
        tf.nest.map_structure(assert_np_all_equal,
                              self.evaluate(self.tf_policy.variables()),
                              self.evaluate(eager_py_policy.variables()),
                              check_types=False)
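Outside a test harness, the same workflow is: export the SavedModel once, then periodically refresh the in-memory policy from newer checkpoints. A rough sketch, assuming `saver`, `policy_dir`, `checkpoint_dir`, `time_step_spec` and `action_spec` analogous to the ones above:

# Export the SavedModel once, then write lightweight variable checkpoints as
# training progresses.
saver.save(policy_dir)
saver.save_checkpoint(checkpoint_dir)

# On the inference side, load the SavedModel once and update its weights in
# place from the newest checkpoint instead of reloading the whole model.
eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
    policy_dir, time_step_spec, action_spec)
eager_py_policy.update_from_checkpoint(checkpoint_dir)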
Example #3
def main(_):
  if common.has_eager_been_enabled():
    return 0
  tf.compat.v1.enable_resource_variables()
  logging.set_verbosity(logging.INFO)
  train_eval(FLAGS.root_dir, num_iterations=FLAGS.num_iterations)
Example #4
  def testRegisterFunction(self):
    if not common.has_eager_been_enabled():
      self.skipTest('Only supported in TF2.x. Step is required in TF1.x')

    network = q_network.QNetwork(
        input_tensor_spec=self._time_step_spec.observation,
        action_spec=self._action_spec)

    policy = q_policy.QPolicy(
        time_step_spec=self._time_step_spec,
        action_spec=self._action_spec,
        q_network=network)

    saver = policy_saver.PolicySaver(policy, batch_size=None)
    saver.register_function('q_network', network,
                            self._time_step_spec.observation)

    path = os.path.join(self.get_temp_dir(), 'save_model')
    saver.save(path)
    reloaded = tf.compat.v2.saved_model.load(path)

    sample_input = self.evaluate(
        tensor_spec.sample_spec_nest(
            self._time_step_spec.observation, outer_dims=(3,)))
    expected_output, _ = network(sample_input)
    reloaded_output, _ = reloaded.q_network(sample_input)

    self.assertAllClose(expected_output, reloaded_output)
Example #5
    def testBatchedPyEnvCompatible(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in eager.')

        actor_net = actor_network.ActorNetwork(
            self._observation_tensor_spec,
            self._action_tensor_spec,
            fc_layer_params=(10, ),
        )

        tf_policy = actor_policy.ActorPolicy(self._time_step_tensor_spec,
                                             self._action_tensor_spec,
                                             actor_network=actor_net)

        py_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_policy,
                                                       batch_time_steps=False)

        env_ctr = lambda: random_py_environment.RandomPyEnvironment(  # pylint: disable=g-long-lambda
            self._observation_spec, self._action_spec)

        env = batched_py_environment.BatchedPyEnvironment(
            [env_ctr() for _ in range(3)])
        time_step = env.reset()

        for _ in range(20):
            action_step = py_policy.action(time_step)
            time_step = env.step(action_step.action)
Example #6
    def testRandomTFPolicyCompatibility(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in eager.')

        tf_policy = random_tf_policy.RandomTFPolicy(
            self._time_step_tensor_spec,
            self._action_tensor_spec,
            info_spec=self._info_tensor_spec)

        py_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_policy)
        time_step = self._env.reset()

        def _check_action_step(action_step):
            self.assertIsInstance(action_step.action, np.ndarray)
            self.assertEqual(action_step.action.shape, (1, ))
            self.assertBetween(action_step.action[0], 2.0, 3.0)

            self.assertIsInstance(action_step.info['a'], np.ndarray)
            self.assertEqual(action_step.info['a'].shape, (1, ))
            self.assertBetween(action_step.info['a'][0], 0.0, 1.0)

            self.assertIsInstance(action_step.info['b'], np.ndarray)
            self.assertEqual(action_step.info['b'].shape, (1, ))
            self.assertBetween(action_step.info['b'][0], 100.0, 101.0)

        for _ in range(100):
            action_step = py_policy.action(time_step)
            _check_action_step(action_step)
            time_step = self._env.step(action_step.action)
Example #7
    def testRegisterConcreteFunction(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in TF2.x. Step is required in TF1.x')

        network = q_network.QNetwork(
            input_tensor_spec=self._time_step_spec.observation,
            action_spec=self._action_spec)

        policy = q_policy.QPolicy(time_step_spec=self._time_step_spec,
                                  action_spec=self._action_spec,
                                  q_network=network)

        saver = policy_saver.PolicySaver(policy, batch_size=None)

        tf_var = tf.Variable(3)

        def add(b):
            return tf_var + b

        add_fn = common.function(add)
        # Called for side effect.
        add_fn.get_concrete_function(tf.TensorSpec((), dtype=tf.int32))

        saver.register_concrete_function(name='add', fn=add_fn)

        path = os.path.join(self.get_temp_dir(), 'save_model')
        saver.save(path)
        reloaded = tf.compat.v2.saved_model.load(path)

        self.assertAllClose(7, reloaded.add(4))
Example #8
    def testPyEnvCompatible(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in eager.')

        observation_spec = array_spec.ArraySpec([2], np.float32)
        action_spec = array_spec.BoundedArraySpec([1], np.float32, 2, 3)

        observation_tensor_spec = tensor_spec.from_spec(observation_spec)
        action_tensor_spec = tensor_spec.from_spec(action_spec)
        time_step_tensor_spec = ts.time_step_spec(observation_tensor_spec)

        actor_net = actor_network.ActorNetwork(
            observation_tensor_spec,
            action_tensor_spec,
            fc_layer_params=(10, ),
        )

        tf_policy = actor_policy.ActorPolicy(time_step_tensor_spec,
                                             action_tensor_spec,
                                             actor_network=actor_net)

        py_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_policy)
        # The env will validate action types automatically since we provided
        # the action_spec.
        env = random_py_environment.RandomPyEnvironment(
            observation_spec, action_spec)

        time_step = env.reset()

        for _ in range(100):
            action_step = py_policy.action(time_step)
            time_step = env.step(action_step.action)
Example #9
    def testRegisterFunction(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in TF2.x. Step is required in TF1.x')

        time_step_spec = ts.TimeStep(
            step_type=tensor_spec.BoundedTensorSpec(dtype=tf.int32,
                                                    shape=(),
                                                    name='st',
                                                    minimum=0,
                                                    maximum=2),
            reward=tensor_spec.BoundedTensorSpec(dtype=tf.float32,
                                                 shape=(),
                                                 name='reward',
                                                 minimum=0.0,
                                                 maximum=5.0),
            discount=tensor_spec.BoundedTensorSpec(dtype=tf.float32,
                                                   shape=(),
                                                   name='discount',
                                                   minimum=0.0,
                                                   maximum=1.0),
            observation=tensor_spec.BoundedTensorSpec(dtype=tf.float32,
                                                      shape=(4, ),
                                                      name='obs',
                                                      minimum=-10.0,
                                                      maximum=10.0))
        action_spec = tensor_spec.BoundedTensorSpec(dtype=tf.int32,
                                                    shape=(),
                                                    minimum=0,
                                                    maximum=10,
                                                    name='act_0')

        network = q_network.QNetwork(
            input_tensor_spec=time_step_spec.observation,
            action_spec=action_spec)

        policy = q_policy.QPolicy(time_step_spec=time_step_spec,
                                  action_spec=action_spec,
                                  q_network=network)

        saver = policy_saver.PolicySaver(policy, batch_size=None)
        async_saver = async_policy_saver.AsyncPolicySaver(saver)
        async_saver.register_function('q_network', network,
                                      time_step_spec.observation)

        path = os.path.join(self.get_temp_dir(), 'save_model')
        async_saver.save(path)
        async_saver.flush()
        async_saver.close()
        self.assertFalse(async_saver._save_thread.is_alive())
        reloaded = tf.compat.v2.saved_model.load(path)

        sample_input = self.evaluate(
            tensor_spec.sample_spec_nest(time_step_spec.observation,
                                         outer_dims=(3, )))
        expected_output, _ = network(sample_input)
        reloaded_output, _ = reloaded.q_network(sample_input)

        self.assertAllClose(expected_output, reloaded_output)
Example #10
def main(_):
    logging.set_verbosity(logging.INFO)
    if common.has_eager_been_enabled():
        return 0
    tf.enable_resource_variables()
    agent_class = dqn_agent.DdqnAgent if FLAGS.use_ddqn else dqn_agent.DqnAgent
    train_eval(FLAGS.root_dir,
               agent_class=agent_class,
               num_iterations=FLAGS.num_iterations)
Example #11
 def add(x, y):
     if common.has_eager_been_enabled():
         # In TF2, this should be executed in eager mode.
         self.assertTrue(tf.executing_eagerly())
     else:
         # In TF1, this should be inside a temporary graph because it's being
         # created inside a tf.function.
         inner_graph = tf.compat.v1.get_default_graph()
         self.assertNotEqual(outer_graph, inner_graph)
     return x + y
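This `add` helper is a nested function: `self` and `outer_graph` come from an enclosing test method that is not part of the snippet. A plausible, purely hypothetical reconstruction of that surrounding test, for context only:

def testFunctionGraphs(self):
  outer_graph = tf.compat.v1.get_default_graph()

  def add(x, y):
    if common.has_eager_been_enabled():
      # In TF2, this should be executed in eager mode.
      self.assertTrue(tf.executing_eagerly())
    else:
      # In TF1, this should run in a temporary FuncGraph because it is traced
      # inside a tf.function.
      inner_graph = tf.compat.v1.get_default_graph()
      self.assertNotEqual(outer_graph, inner_graph)
    return x + y

  add_fn = common.function(add)
  self.assertEqual(5, self.evaluate(add_fn(2, 3)))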
Example #12
    def testInferenceFromCheckpoint(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in TF2.x.')

        path = os.path.join(self.get_temp_dir(), 'saved_policy')
        saver = policy_saver.PolicySaver(self.tf_policy)
        saver.save(path)

        rng = np.random.RandomState()
        sample_time_step = array_spec.sample_spec_nest(self.time_step_spec,
                                                       rng)
        batched_sample_time_step = nest_utils.batch_nested_array(
            sample_time_step)

        self.evaluate(
            tf.nest.map_structure(lambda v: v.assign(v * 0 + -1),
                                  self.tf_policy.variables()))
        checkpoint_path = os.path.join(self.get_temp_dir(), 'checkpoint')
        saver.save_checkpoint(checkpoint_path)

        eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
            path, self.time_step_spec, self.action_spec)

        # Use evaluate to force a copy.
        saved_model_variables = self.evaluate(eager_py_policy.variables())

        eager_py_policy.update_from_checkpoint(checkpoint_path)

        assert_np_not_equal = lambda a, b: self.assertFalse(
            np.equal(a, b).all())
        tf.nest.map_structure(assert_np_not_equal, saved_model_variables,
                              self.evaluate(eager_py_policy.variables()))

        assert_np_all_equal = lambda a, b: self.assertTrue(
            np.equal(a, b).all())
        tf.nest.map_structure(assert_np_all_equal,
                              self.evaluate(self.tf_policy.variables()),
                              self.evaluate(eager_py_policy.variables()),
                              check_types=False)

        # We can't assert that the action changed after the update, since with
        # some variable initializations it may happen to be the same. Instead,
        # check that the checkpointed policy's action always matches the
        # current policy's action.
        checkpoint_action = eager_py_policy.action(sample_time_step)

        current_policy_action = self.tf_policy.action(batched_sample_time_step)
        current_policy_action = self.evaluate(
            nest_utils.unbatch_nested_tensors(current_policy_action))
        tf.nest.map_structure(assert_np_all_equal, current_policy_action,
                              checkpoint_action)
Example #13
  def testCreateAgent(self, create_critic_net_fn, skip_in_tf1):
    if skip_in_tf1 and not common.has_eager_been_enabled():
      self.skipTest('Skipping test: sequential networks not supported in TF1')

    critic_network = create_critic_net_fn()

    sac_agent.SacAgent(
        self._time_step_spec,
        self._action_spec,
        critic_network=critic_network,
        actor_network=None,
        actor_optimizer=None,
        critic_optimizer=None,
        alpha_optimizer=None,
        actor_policy_ctor=DummyActorPolicy)
Example #14
    def testSavedModel(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in eager.')

        observation_spec = array_spec.ArraySpec([2], np.float32)
        action_spec = array_spec.BoundedArraySpec([1], np.float32, 2, 3)
        time_step_spec = ts.time_step_spec(observation_spec)

        observation_tensor_spec = tensor_spec.from_spec(observation_spec)
        action_tensor_spec = tensor_spec.from_spec(action_spec)
        time_step_tensor_spec = tensor_spec.from_spec(time_step_spec)

        actor_net = actor_network.ActorNetwork(
            observation_tensor_spec,
            action_tensor_spec,
            fc_layer_params=(10, ),
        )

        tf_policy = actor_policy.ActorPolicy(time_step_tensor_spec,
                                             action_tensor_spec,
                                             actor_network=actor_net)

        path = os.path.join(self.get_temp_dir(), 'saved_policy')
        saver = policy_saver.PolicySaver(tf_policy)
        saver.save(path)

        eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
            path, time_step_spec, action_spec)

        rng = np.random.RandomState()
        sample_time_step = array_spec.sample_spec_nest(time_step_spec, rng)
        batched_sample_time_step = nest_utils.batch_nested_array(
            sample_time_step)

        original_action = tf_policy.action(batched_sample_time_step)
        unbatched_original_action = nest_utils.unbatch_nested_tensors(
            original_action)
        original_action_np = tf.nest.map_structure(lambda t: t.numpy(),
                                                   unbatched_original_action)
        saved_policy_action = eager_py_policy.action(sample_time_step)

        tf.nest.assert_same_structure(saved_policy_action.action, action_spec)

        np.testing.assert_array_almost_equal(original_action_np.action,
                                             saved_policy_action.action)
Example #15
    def test_compress_image(self):
        if not common.has_eager_been_enabled():
            self.skipTest("Image compression only supported in TF2.x")

        gin.parse_config_files_and_bindings([], """
    _get_feature_encoder.compress_image=True
    _get_feature_parser.compress_image=True
    """)
        spec = {"image": array_spec.ArraySpec((128, 128, 3), np.uint8)}
        serializer = example_encoding.get_example_serializer(spec)
        decoder = example_encoding.get_example_decoder(spec)

        sample = {"image": 128 * np.ones([128, 128, 3], dtype=np.uint8)}
        example_proto = serializer(sample)

        recovered = self.evaluate(decoder(example_proto))
        tf.nest.map_structure(np.testing.assert_almost_equal, sample,
                              recovered)
Example #16
    def testRandomTFPolicyCompatibility(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in eager.')

        observation_spec = array_spec.ArraySpec([2], np.float32)
        action_spec = array_spec.BoundedArraySpec([1], np.float32, 2, 3)
        info_spec = {
            'a': array_spec.BoundedArraySpec([1], np.float32, 0, 1),
            'b': array_spec.BoundedArraySpec([1], np.float32, 100, 101)
        }

        observation_tensor_spec = tensor_spec.from_spec(observation_spec)
        action_tensor_spec = tensor_spec.from_spec(action_spec)
        info_tensor_spec = tensor_spec.from_spec(info_spec)
        time_step_tensor_spec = ts.time_step_spec(observation_tensor_spec)

        tf_policy = random_tf_policy.RandomTFPolicy(time_step_tensor_spec,
                                                    action_tensor_spec,
                                                    info_spec=info_tensor_spec)

        py_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_policy)
        env = random_py_environment.RandomPyEnvironment(
            observation_spec, action_spec)
        time_step = env.reset()

        def _check_action_step(action_step):
            self.assertIsInstance(action_step.action, np.ndarray)
            self.assertEqual(action_step.action.shape, (1, ))
            self.assertBetween(action_step.action[0], 2.0, 3.0)

            self.assertIsInstance(action_step.info['a'], np.ndarray)
            self.assertEqual(action_step.info['a'].shape, (1, ))
            self.assertBetween(action_step.info['a'][0], 0.0, 1.0)

            self.assertIsInstance(action_step.info['b'], np.ndarray)
            self.assertEqual(action_step.info['b'].shape, (1, ))
            self.assertBetween(action_step.info['b'][0], 100.0, 101.0)

        for _ in range(100):
            action_step = py_policy.action(time_step)
            _check_action_step(action_step)
            time_step = env.step(action_step.action)
Example #17
  def testCriticLoss(self, create_critic_net_fn, skip_in_tf1):
    if skip_in_tf1 and not common.has_eager_been_enabled():
      self.skipTest('Skipping test: sequential networks not supported in TF1')

    critic_network = create_critic_net_fn()
    agent = sac_agent.SacAgent(
        self._time_step_spec,
        self._action_spec,
        critic_network=critic_network,
        actor_network=None,
        actor_optimizer=None,
        critic_optimizer=None,
        alpha_optimizer=None,
        actor_policy_ctor=DummyActorPolicy)

    observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
    time_steps = ts.restart(observations, batch_size=2)
    actions = tf.constant([[5], [6]], dtype=tf.float32)

    rewards = tf.constant([10, 20], dtype=tf.float32)
    discounts = tf.constant([0.9, 0.9], dtype=tf.float32)
    next_observations = tf.constant([[5, 6], [7, 8]], dtype=tf.float32)
    next_time_steps = ts.transition(next_observations, rewards, discounts)

    td_targets = [7.3, 19.1]
    pred_td_targets = [7., 10.]

    self.evaluate(tf.compat.v1.global_variables_initializer())

    # Expected critic loss has factor of 2, for the two TD3 critics.
    expected_loss = self.evaluate(2 * tf.compat.v1.losses.mean_squared_error(
        tf.constant(td_targets), tf.constant(pred_td_targets)))

    loss = agent.critic_loss(
        time_steps,
        actions,
        next_time_steps,
        td_errors_loss_fn=tf.math.squared_difference)

    self.evaluate(tf.compat.v1.global_variables_initializer())
    loss_ = self.evaluate(loss)
    self.assertAllClose(loss_, expected_loss)
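For reference, the hard-coded expected loss in this test can be reproduced with plain NumPy (a quick sanity check based on the `td_targets` and `pred_td_targets` values above):

import numpy as np

td_targets = np.array([7.3, 19.1], dtype=np.float32)
pred_td_targets = np.array([7.0, 10.0], dtype=np.float32)
# Mean squared error over the batch, doubled for the two TD3-style critics.
expected_loss = 2 * np.mean((td_targets - pred_td_targets) ** 2)
print(expected_loss)  # ~82.9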
Example #18
    def testPyEnvCompatible(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in eager.')

        actor_net = actor_network.ActorNetwork(
            self._observation_tensor_spec,
            self._action_tensor_spec,
            fc_layer_params=(10, ),
        )

        tf_policy = actor_policy.ActorPolicy(self._time_step_tensor_spec,
                                             self._action_tensor_spec,
                                             actor_network=actor_net)

        py_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_policy)
        time_step = self._env.reset()

        for _ in range(100):
            action_step = py_policy.action(time_step)
            time_step = self._env.step(action_step.action)
Example #19
    def testActionWithSeed(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in eager.')

        tf_policy = random_tf_policy.RandomTFPolicy(
            self._time_step_tensor_spec,
            self._action_tensor_spec,
            info_spec=self._info_tensor_spec)

        py_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_policy)
        time_step = self._env.reset()
        tf.random.set_seed(100)
        action_step_1 = py_policy.action(time_step, seed=100)
        time_step = self._env.reset()
        tf.random.set_seed(100)
        action_step_2 = py_policy.action(time_step, seed=100)
        time_step = self._env.reset()
        tf.random.set_seed(200)
        action_step_3 = py_policy.action(time_step, seed=200)
        self.assertEqual(action_step_1.action[0], action_step_2.action[0])
        self.assertNotEqual(action_step_1.action[0], action_step_3.action[0])
Example #20
  def setUp(self):
    super(SavedModelPYTFEagerPolicyTest, self).setUp()
    if not common.has_eager_been_enabled():
      self.skipTest('Only supported in eager.')

    observation_spec = array_spec.ArraySpec([2], np.float32)
    self.action_spec = array_spec.BoundedArraySpec([1], np.float32, 2, 3)
    self.time_step_spec = ts.time_step_spec(observation_spec)

    observation_tensor_spec = tensor_spec.from_spec(observation_spec)
    action_tensor_spec = tensor_spec.from_spec(self.action_spec)
    time_step_tensor_spec = tensor_spec.from_spec(self.time_step_spec)

    actor_net = actor_network.ActorNetwork(
        observation_tensor_spec,
        action_tensor_spec,
        fc_layer_params=(10,),
    )

    self.tf_policy = actor_policy.ActorPolicy(
        time_step_tensor_spec, action_tensor_spec, actor_network=actor_net)
Example #21
    def testTrainStepNotSaved(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in TF2.x. Step is required in TF1.x')

        network = q_network.QNetwork(
            input_tensor_spec=self._time_step_spec.observation,
            action_spec=self._action_spec)

        policy = q_policy.QPolicy(time_step_spec=self._time_step_spec,
                                  action_spec=self._action_spec,
                                  q_network=network)

        saver = policy_saver.PolicySaver(policy, batch_size=None)
        path = os.path.join(self.get_temp_dir(), 'save_model')

        saver.save(path)
        reloaded = tf.compat.v2.saved_model.load(path)

        self.assertIn('get_train_step', reloaded.signatures)
        train_step_value = self.evaluate(reloaded.get_train_step())
        self.assertEqual(-1, train_step_value)
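By contrast, when a `train_step` variable is passed to the saver, the reloaded `get_train_step()` returns its value at save time rather than the -1 default. A minimal sketch, assuming the same `policy` and `path` as above:

train_step = tf.Variable(0, dtype=tf.int64, trainable=False)
saver = policy_saver.PolicySaver(policy, batch_size=None, train_step=train_step)

train_step.assign(1234)  # e.g. advanced by the training loop
saver.save(path)

reloaded = tf.compat.v2.saved_model.load(path)
print(reloaded.get_train_step())  # tf.Tensor(1234, shape=(), dtype=int64)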
Example #22
  def testRandomTFPolicyCompatibility(self):
    if not common.has_eager_been_enabled():
      self.skipTest('Only supported in eager.')

    observation_spec = array_spec.ArraySpec([2], np.float32)
    action_spec = array_spec.BoundedArraySpec([1], np.float32, 2, 3)

    observation_tensor_spec = tensor_spec.from_spec(observation_spec)
    action_tensor_spec = tensor_spec.from_spec(action_spec)
    time_step_tensor_spec = ts.time_step_spec(observation_tensor_spec)

    tf_policy = random_tf_policy.RandomTFPolicy(time_step_tensor_spec,
                                                action_tensor_spec)

    py_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_policy)
    env = random_py_environment.RandomPyEnvironment(observation_spec,
                                                    action_spec)
    time_step = env.reset()

    for _ in range(100):
      action_step = py_policy.action(time_step)
      time_step = env.step(action_step.action)
Example #23
    def __init__(self,
                 policy,
                 batch_size=None,
                 use_nest_path_signatures=True,
                 seed=None,
                 train_step=None,
                 input_fn_and_spec=None):
        """Initialize PolicySaver for  TF policy `policy`.

    Args:
      policy: A TF Policy.
      batch_size: The number of batch entries the policy will process at a time.
        This must be either `None` (unknown batch size) or a python integer.
      use_nest_path_signatures: SavedModel spec signatures will be created based
        on the sructure of the specs. Otherwise all specs must have unique
        names.
      seed: Random seed for the `policy.action` call, if any (this should
        usually be `None`, except for testing).
      train_step: Variable holding the train step for the policy. The value
        saved will be set at the time `saver.save` is called. If not provided,
        train_step defaults to -1. Note since the train step must be a variable
        it is not safe to create it directly in TF1 so in that case this is a
        required parameter.
      input_fn_and_spec: A `(input_fn, tensor_spec)` tuple where input_fn is a
        function that takes inputs according to tensor_spec and converts them to
        the `(time_step, policy_state)` tuple that is used as the input to the
        action_fn. When `input_fn_and_spec` is set, `tensor_spec` is the input
        for the action signature. When `input_fn_and_spec is None`, the action
        signature takes as input `(time_step, policy_state)`.

    Raises:
      TypeError: If `policy` is not an instance of TFPolicy.
      ValueError: If use_nest_path_signatures is not used and any of the
        following `policy` specs are missing names, or the names collide:
        `policy.time_step_spec`, `policy.action_spec`,
        `policy.policy_state_spec`, `policy.info_spec`.
      ValueError: If `batch_size` is not either `None` or a python integer > 0.
    """
        if not isinstance(policy, tf_policy.Base):
            raise TypeError('policy is not a TFPolicy.  Saw: %s' %
                            type(policy))
        if (batch_size is not None
                and (not isinstance(batch_size, int) or batch_size < 1)):
            raise ValueError(
                'Expected batch_size == None or python int > 0, saw: %s' %
                (batch_size, ))

        action_fn_input_spec = (policy.time_step_spec,
                                policy.policy_state_spec)
        if use_nest_path_signatures:
            action_fn_input_spec = _rename_spec_with_nest_paths(
                action_fn_input_spec)
        else:
            _check_spec(action_fn_input_spec)

        # Make a shallow copy as we'll be making some changes in-place.
        policy = copy.copy(policy)
        if train_step is None:
            if not common.has_eager_been_enabled():
                raise ValueError('train_step is required in TF1 and must be a '
                                 '`tf.Variable`: %s' % train_step)
            train_step = tf.Variable(
                -1,
                trainable=False,
                dtype=tf.int64,
                aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
                shape=())
        elif not isinstance(train_step, tf.Variable):
            raise ValueError('train_step must be a TensorFlow variable: %s' %
                             train_step)
        policy.train_step = train_step

        # We will need the train step for the Checkpoint object.
        self._train_step = train_step

        if batch_size is None:
            get_initial_state_fn = policy.get_initial_state
            get_initial_state_input_specs = (tf.TensorSpec(
                dtype=tf.int32, shape=(), name='batch_size'), )
        else:
            get_initial_state_fn = functools.partial(policy.get_initial_state,
                                                     batch_size=batch_size)
            get_initial_state_input_specs = ()

        get_initial_state_fn = common.function()(get_initial_state_fn)

        original_action_fn = policy.action
        if seed is not None:

            def action_fn(time_step, policy_state):
                return original_action_fn(time_step, policy_state, seed=seed)
        else:
            action_fn = original_action_fn

        # We call get_concrete_function() for its side effect.
        get_initial_state_fn.get_concrete_function(
            *get_initial_state_input_specs)

        train_step_fn = common.function(
            lambda: policy.train_step).get_concrete_function()

        action_fn = common.function()(action_fn)

        def add_batch_dim(spec):
            return tf.TensorSpec(shape=tf.TensorShape(
                [batch_size]).concatenate(spec.shape),
                                 name=spec.name,
                                 dtype=spec.dtype)

        batched_time_step_spec = tf.nest.map_structure(add_batch_dim,
                                                       policy.time_step_spec)
        batched_policy_state_spec = tf.nest.map_structure(
            add_batch_dim, policy.policy_state_spec)

        policy_step_spec = policy.policy_step_spec
        policy_state_spec = policy.policy_state_spec

        if use_nest_path_signatures:
            batched_time_step_spec = _rename_spec_with_nest_paths(
                batched_time_step_spec)
            batched_policy_state_spec = _rename_spec_with_nest_paths(
                batched_policy_state_spec)
            policy_step_spec = _rename_spec_with_nest_paths(policy_step_spec)
            policy_state_spec = _rename_spec_with_nest_paths(policy_state_spec)
        else:
            _check_spec(batched_time_step_spec)
            _check_spec(batched_policy_state_spec)
            _check_spec(policy_step_spec)
            _check_spec(policy_state_spec)

        if input_fn_and_spec is not None:
            # Store a signature based on input_fn_and_spec
            @common.function()
            def polymorphic_action_fn(example):
                action_inputs = input_fn_and_spec[0](example)
                tf.nest.map_structure(
                    lambda spec, t: tf.Assert(spec.is_compatible_with(t[
                        0]), [t]), action_fn_input_spec, action_inputs)
                return action_fn(*action_inputs)

            batched_input_spec = tf.nest.map_structure(add_batch_dim,
                                                       input_fn_and_spec[1])
            # We call get_concrete_function() for its side effect.
            polymorphic_action_fn.get_concrete_function(
                example=batched_input_spec)

            action_input_spec = (input_fn_and_spec[1], )

        else:
            action_input_spec = action_fn_input_spec
            if batched_policy_state_spec:
                # Store the signature with a required policy state spec
                polymorphic_action_fn = action_fn
                polymorphic_action_fn.get_concrete_function(
                    time_step=batched_time_step_spec,
                    policy_state=batched_policy_state_spec)
            else:
                # Create a polymorphic action_fn which you can call as
                #  restored.action(time_step)
                # or
                #  restored.action(time_step, ())
                # (without retracing the inner action twice)
                @common.function()
                def polymorphic_action_fn(
                        time_step, policy_state=batched_policy_state_spec):
                    return action_fn(time_step, policy_state)

                polymorphic_action_fn.get_concrete_function(
                    time_step=batched_time_step_spec,
                    policy_state=batched_policy_state_spec)
                polymorphic_action_fn.get_concrete_function(
                    time_step=batched_time_step_spec)

        signatures = {
            'action':
            _function_with_flat_signature(polymorphic_action_fn,
                                          input_specs=action_input_spec,
                                          output_spec=policy_step_spec,
                                          include_batch_dimension=True,
                                          batch_size=batch_size),
            'get_initial_state':
            _function_with_flat_signature(
                get_initial_state_fn,
                input_specs=get_initial_state_input_specs,
                output_spec=policy_state_spec,
                include_batch_dimension=False),
            'get_train_step':
            _function_with_flat_signature(train_step_fn,
                                          input_specs=(),
                                          output_spec=train_step.dtype,
                                          include_batch_dimension=False),
        }

        policy.action = polymorphic_action_fn
        policy.get_initial_state = get_initial_state_fn
        policy.get_train_step = train_step_fn
        # Adding variables as an attribute to facilitate updating them.
        policy.model_variables = policy.variables()

        self._policy = policy
        self._signatures = signatures
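The least obvious argument above is `input_fn_and_spec`: it makes the exported `action` signature accept a custom flat input (for example a serialized record) that is converted to `(time_step, policy_state)` inside the SavedModel. A rough, hypothetical sketch of how it could be wired up, assuming a `policy` and `export_dir` as elsewhere in these examples (the parsing helper and spec below are illustrative, not from the source):

# Spec of the raw input the exported action signature should accept.
input_spec = tf.TensorSpec(shape=(), dtype=tf.string, name='serialized_example')

def input_fn(serialized_example):
  # Convert the raw input into the (time_step, policy_state) tuple expected by
  # policy.action. The parsing logic is application specific.
  time_step = my_parse_fn(serialized_example)  # hypothetical helper
  return time_step, ()

saver = policy_saver.PolicySaver(
    policy, batch_size=None, input_fn_and_spec=(input_fn, input_spec))
saver.save(export_dir)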
Example #24
    def __init__(self,
                 policy: tf_policy.TFPolicy,
                 batch_size: Optional[int] = None,
                 use_nest_path_signatures: bool = True,
                 seed: Optional[types.Seed] = None,
                 train_step: Optional[tf.Variable] = None,
                 input_fn_and_spec: Optional[InputFnAndSpecType] = None,
                 metadata: Optional[Dict[Text, tf.Variable]] = None):
        """Initialize PolicySaver for  TF policy `policy`.

    Args:
      policy: A TF Policy.
      batch_size: The number of batch entries the policy will process at a time.
        This must be either `None` (unknown batch size) or a python integer.
      use_nest_path_signatures: SavedModel spec signatures will be created based
        on the sructure of the specs. Otherwise all specs must have unique
        names.
      seed: Random seed for the `policy.action` call, if any (this should
        usually be `None`, except for testing).
      train_step: Variable holding the train step for the policy. The value
        saved will be set at the time `saver.save` is called. If not provided,
        train_step defaults to -1. Note since the train step must be a variable
        it is not safe to create it directly in TF1 so in that case this is a
        required parameter.
      input_fn_and_spec: A `(input_fn, tensor_spec)` tuple where input_fn is a
        function that takes inputs according to tensor_spec and converts them to
        the `(time_step, policy_state)` tuple that is used as the input to the
        action_fn. When `input_fn_and_spec` is set, `tensor_spec` is the input
        for the action signature. When `input_fn_and_spec is None`, the action
        signature takes as input `(time_step, policy_state)`.
      metadata: A dictionary of `tf.Variables` to be saved along with the
        policy.

    Raises:
      TypeError: If `policy` is not an instance of TFPolicy.
      TypeError: If `metadata` is not a dictionary of tf.Variables.
      ValueError: If use_nest_path_signatures is not used and any of the
        following `policy` specs are missing names, or the names collide:
        `policy.time_step_spec`, `policy.action_spec`,
        `policy.policy_state_spec`, `policy.info_spec`.
      ValueError: If `batch_size` is not either `None` or a python integer > 0.
    """
        if not isinstance(policy, tf_policy.TFPolicy):
            raise TypeError('policy is not a TFPolicy.  Saw: %s' %
                            type(policy))
        if (batch_size is not None
                and (not isinstance(batch_size, int) or batch_size < 1)):
            raise ValueError(
                'Expected batch_size == None or python int > 0, saw: %s' %
                (batch_size, ))

        action_fn_input_spec = (policy.time_step_spec,
                                policy.policy_state_spec)
        if use_nest_path_signatures:
            action_fn_input_spec = _rename_spec_with_nest_paths(
                action_fn_input_spec)
        else:
            _check_spec(action_fn_input_spec)

        # Make a shallow copy as we'll be making some changes in-place.
        saved_policy = tf.Module()
        saved_policy.collect_data_spec = copy.copy(policy.collect_data_spec)
        saved_policy.policy_state_spec = copy.copy(policy.policy_state_spec)

        if train_step is None:
            if not common.has_eager_been_enabled():
                raise ValueError('train_step is required in TF1 and must be a '
                                 '`tf.Variable`: %s' % train_step)
            train_step = tf.Variable(
                -1,
                trainable=False,
                dtype=tf.int64,
                aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
                shape=())
        elif not isinstance(train_step, tf.Variable):
            raise ValueError('train_step must be a TensorFlow variable: %s' %
                             train_step)

        # We will need the train step for the Checkpoint object.
        self._train_step = train_step
        saved_policy.train_step = self._train_step

        self._metadata = metadata or {}
        for key, value in self._metadata.items():
            if not isinstance(key, str):
                raise TypeError('Keys of metadata must be strings: %s' % key)
            if not isinstance(value, tf.Variable):
                raise TypeError('Values of metadata must be tf.Variable: %s' %
                                value)
        saved_policy.metadata = self._metadata

        if batch_size is None:
            get_initial_state_fn = policy.get_initial_state
            get_initial_state_input_specs = (tf.TensorSpec(
                dtype=tf.int32, shape=(), name='batch_size'), )
        else:
            get_initial_state_fn = functools.partial(policy.get_initial_state,
                                                     batch_size=batch_size)
            get_initial_state_input_specs = ()

        get_initial_state_fn = common.function()(get_initial_state_fn)

        original_action_fn = policy.action

        if seed is not None:

            def action_fn(time_step, policy_state):
                time_step = cast(ts.TimeStep, time_step)
                return original_action_fn(time_step, policy_state, seed=seed)
        else:
            action_fn = original_action_fn

        def distribution_fn(time_step, policy_state):
            """Wrapper for policy.distribution() in the SavedModel."""
            try:
                time_step = cast(ts.TimeStep, time_step)
                outs = policy.distribution(time_step=time_step,
                                           policy_state=policy_state)
                return tf.nest.map_structure(_composite_distribution, outs)
            except (TypeError, NotImplementedError) as e:
                # TODO(b/156526399): Move this to just the policy.distribution() call
                # once tfp.experimental.as_composite() properly handles LinearOperator*
                # components as well as TransformedDistributions.
                logging.warning(
                    'WARNING: Could not serialize policy.distribution() for policy '
                    '"%s". Calling saved_model.distribution() will raise the following '
                    'assertion error: %s', policy, e)

                @common.function()
                def _raise():
                    tf.Assert(False, [str(e)])
                    return ()

                outs = _raise()

        # We call get_concrete_function() for its side effect: to ensure the proper
        # ConcreteFunction is stored in the SavedModel.
        get_initial_state_fn.get_concrete_function(
            *get_initial_state_input_specs)

        train_step_fn = common.function(
            lambda: saved_policy.train_step).get_concrete_function()
        get_metadata_fn = common.function(
            lambda: saved_policy.metadata).get_concrete_function()

        batched_time_step_spec = tf.nest.map_structure(
            lambda spec: add_batch_dim(spec, [batch_size]),
            policy.time_step_spec)
        batched_time_step_spec = cast(ts.TimeStep, batched_time_step_spec)
        batched_policy_state_spec = tf.nest.map_structure(
            lambda spec: add_batch_dim(spec, [batch_size]),
            policy.policy_state_spec)

        policy_step_spec = policy.policy_step_spec
        policy_state_spec = policy.policy_state_spec

        if use_nest_path_signatures:
            batched_time_step_spec = _rename_spec_with_nest_paths(
                batched_time_step_spec)
            batched_policy_state_spec = _rename_spec_with_nest_paths(
                batched_policy_state_spec)
            policy_step_spec = _rename_spec_with_nest_paths(policy_step_spec)
            policy_state_spec = _rename_spec_with_nest_paths(policy_state_spec)
        else:
            _check_spec(batched_time_step_spec)
            _check_spec(batched_policy_state_spec)
            _check_spec(policy_step_spec)
            _check_spec(policy_state_spec)

        if input_fn_and_spec is not None:
            # Store a signature based on input_fn_and_spec
            @common.function()
            def polymorphic_action_fn(example):
                action_inputs = input_fn_and_spec[0](example)
                tf.nest.map_structure(
                    lambda spec, t: tf.Assert(spec.is_compatible_with(t[
                        0]), [t]), action_fn_input_spec, action_inputs)
                return action_fn(*action_inputs)

            @common.function()
            def polymorphic_distribution_fn(example):
                action_inputs = input_fn_and_spec[0](example)
                tf.nest.map_structure(
                    lambda spec, t: tf.Assert(spec.is_compatible_with(t[
                        0]), [t]), action_fn_input_spec, action_inputs)
                return distribution_fn(*action_inputs)

            batched_input_spec = tf.nest.map_structure(
                lambda spec: add_batch_dim(spec, [batch_size]),
                input_fn_and_spec[1])
            # We call get_concrete_function() for its side effect: to ensure the
            # proper ConcreteFunction is stored in the SavedModel.
            polymorphic_action_fn.get_concrete_function(
                example=batched_input_spec)
            polymorphic_distribution_fn.get_concrete_function(
                example=batched_input_spec)

            action_input_spec = (input_fn_and_spec[1], )

        else:
            action_input_spec = action_fn_input_spec
            if batched_policy_state_spec:
                # Store the signature with a required policy state spec
                polymorphic_action_fn = common.function()(action_fn)
                polymorphic_action_fn.get_concrete_function(
                    time_step=batched_time_step_spec,
                    policy_state=batched_policy_state_spec)

                polymorphic_distribution_fn = common.function()(
                    distribution_fn)
                polymorphic_distribution_fn.get_concrete_function(
                    time_step=batched_time_step_spec,
                    policy_state=batched_policy_state_spec)
            else:
                # Create a polymorphic action_fn which you can call as
                #  restored.action(time_step)
                # or
                #  restored.action(time_step, ())
                # (without retracing the inner action twice)
                @common.function()
                def polymorphic_action_fn(
                        time_step, policy_state=batched_policy_state_spec):
                    return action_fn(time_step, policy_state)

                polymorphic_action_fn.get_concrete_function(
                    time_step=batched_time_step_spec,
                    policy_state=batched_policy_state_spec)
                polymorphic_action_fn.get_concrete_function(
                    time_step=batched_time_step_spec)

                @common.function()
                def polymorphic_distribution_fn(
                        time_step, policy_state=batched_policy_state_spec):
                    return distribution_fn(time_step, policy_state)

                polymorphic_distribution_fn.get_concrete_function(
                    time_step=batched_time_step_spec,
                    policy_state=batched_policy_state_spec)
                polymorphic_distribution_fn.get_concrete_function(
                    time_step=batched_time_step_spec)

        signatures = {
            # CompositeTensors aren't well supported by old-style signature
            # mechanisms, so we do not have a signature for policy.distribution.
            'action':
            _function_with_flat_signature(polymorphic_action_fn,
                                          input_specs=action_input_spec,
                                          output_spec=policy_step_spec,
                                          include_batch_dimension=True,
                                          batch_size=batch_size),
            'get_initial_state':
            _function_with_flat_signature(
                get_initial_state_fn,
                input_specs=get_initial_state_input_specs,
                output_spec=policy_state_spec,
                include_batch_dimension=False),
            'get_train_step':
            _function_with_flat_signature(train_step_fn,
                                          input_specs=(),
                                          output_spec=train_step.dtype,
                                          include_batch_dimension=False),
            'get_metadata':
            _function_with_flat_signature(get_metadata_fn,
                                          input_specs=(),
                                          output_spec=tf.nest.map_structure(
                                              lambda v: v.dtype,
                                              self._metadata),
                                          include_batch_dimension=False),
        }

        saved_policy.action = polymorphic_action_fn
        saved_policy.distribution = polymorphic_distribution_fn
        saved_policy.get_initial_state = get_initial_state_fn
        saved_policy.get_train_step = train_step_fn
        saved_policy.get_metadata = get_metadata_fn
        # Adding variables as an attribute to facilitate updating them.
        saved_policy.model_variables = policy.variables()

        # TODO(b/156779400): Move to a public API for accessing all trackable leaf
        # objects (once it's available).  For now, we have no other way of tracking
        # objects like Tables, Vocabulary files, etc.
        try:
            saved_policy._all_assets = policy._unconditional_checkpoint_dependencies  # pylint: disable=protected-access
        except AttributeError as e:
            if '_self_unconditional' in str(e):
                logging.warning(
                    'Unable to capture all trackable objects in policy "%s".  This '
                    'may be okay.  Error: %s', policy, e)
            else:
                raise e

        self._policy = saved_policy
        self._signatures = signatures
        self._action_input_spec = action_input_spec
        self._policy_step_spec = policy_step_spec
        self._policy_state_spec = policy_state_spec
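Compared with the earlier version of this constructor, the main additions are the exported `distribution` function and the `metadata` dictionary of variables saved alongside the policy. A minimal sketch of the metadata round trip, assuming a `policy` and `path` as in the tests above:

metadata = {'env_steps': tf.Variable(0, dtype=tf.int64, trainable=False)}
saver = policy_saver.PolicySaver(policy, batch_size=None, metadata=metadata)

metadata['env_steps'].assign(50000)  # e.g. updated by the training loop
saver.save(path)

reloaded = tf.compat.v2.saved_model.load(path)
print(reloaded.get_metadata())  # {'env_steps': <tf.Tensor: ... numpy=50000>}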
Example #25
 def setUp(self):
     if not common.has_eager_been_enabled():
         self.skipTest('Only supported in TF2.x.')
     super(SequentialTest, self).setUp()
Example #26
    def testInferenceWithCheckpoint(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in TF2.x.')

        # Create and saved_model for a q_policy.
        network = q_network.QNetwork(
            input_tensor_spec=self._time_step_spec.observation,
            action_spec=self._action_spec)

        policy = q_policy.QPolicy(time_step_spec=self._time_step_spec,
                                  action_spec=self._action_spec,
                                  q_network=network)
        sample_input = self.evaluate(
            tensor_spec.sample_spec_nest(self._time_step_spec,
                                         outer_dims=(3, )))

        saver = policy_saver.PolicySaver(policy, batch_size=None)
        path = os.path.join(self.get_temp_dir(), 'save_model')

        self.evaluate(tf.compat.v1.global_variables_initializer())
        original_eval = self.evaluate(policy.action(sample_input))
        saver.save(path)
        # Assign -1 to all variables in the policy, making the checkpoint
        # different from the initial saved_model.
        self.evaluate(
            tf.nest.map_structure(lambda v: v.assign(v * 0 + -1),
                                  policy.variables()))
        checkpoint_path = os.path.join(self.get_temp_dir(), 'checkpoint')
        saver.save_checkpoint(checkpoint_path)

        # Get an instance of the saved_model.
        reloaded_policy = tf.compat.v2.saved_model.load(path)
        self.evaluate(
            tf.compat.v1.initializers.variables(
                reloaded_policy.model_variables))

        # Verify the loaded saved_model variables differ from the current policy's.
        model_variables = self.evaluate(policy.variables())
        reloaded_model_variables = self.evaluate(
            reloaded_policy.model_variables)

        assert_np_not_equal = lambda a, b: self.assertFalse(
            np.equal(a, b).any())
        tf.nest.map_structure(assert_np_not_equal, model_variables,
                              reloaded_model_variables)

        # Update from checkpoint.
        checkpoint = tf.train.Checkpoint(policy=reloaded_policy)
        checkpoint_file_prefix = os.path.join(checkpoint_path, 'variables',
                                              'variables')
        checkpoint.read(
            checkpoint_file_prefix).assert_existing_objects_matched()

        self.evaluate(
            tf.compat.v1.initializers.variables(
                reloaded_policy.model_variables))

        # Verify variables are now equal.
        model_variables = self.evaluate(policy.variables())
        reloaded_model_variables = self.evaluate(
            reloaded_policy.model_variables)

        assert_np_all_equal = lambda a, b: self.assertTrue(
            np.equal(a, b).all())
        tf.nest.map_structure(assert_np_all_equal, model_variables,
                              reloaded_model_variables)

        # Verify variable update affects inference.
        reloaded_eval = self.evaluate(reloaded_policy.action(sample_input))
        tf.nest.map_structure(assert_np_not_equal, original_eval,
                              reloaded_eval)
        current_eval = self.evaluate(policy.action(sample_input))
        tf.nest.map_structure(assert_np_not_equal, current_eval, reloaded_eval)
Example #27
def _get_feature_encoder(shape, dtype, compress_image=False, image_quality=95):
    """Get feature encoder function for shape and dtype.

  Args:
    shape: An array shape.
    dtype: The dtype of the feature.
    compress_image: Whether to compress the image. It is assumed that any uint8
      tensor of rank 3 with shape (w, h, 3) is an image.
    image_quality: An optional int. Defaults to 95. Quality of the compression
      from 0 to 100 (higher is better and slower).

  Returns:
    A tf.train.Feature encoder.
  """
    shape = _validate_shape(shape)
    dtype = _validate_dtype(dtype)

    if compress_image and len(
            shape) == 3 and shape[2] == 3 and dtype == tf.uint8:
        if not common.has_eager_been_enabled():
            raise ValueError('Only supported in TF2.x.')

        def _encode_to_jpeg_bytes_list(value):
            value = tf.io.encode_jpeg(value, quality=image_quality)

            return tf.train.Feature(bytes_list=tf.train.BytesList(
                value=[value.numpy()]))

        return _encode_to_jpeg_bytes_list

    if dtype == tf.float32:  # Serialize float32 to FloatList.

        def _encode_to_float_list(value):
            value = np.asarray(value)
            _check_shape_and_dtype(value, shape, dtype)
            return tf.train.Feature(float_list=tf.train.FloatList(
                value=value.flatten(order='C').tolist()))

        return _encode_to_float_list
    elif dtype == tf.int64:  # Serialize int64 to Int64List.

        def _encode_to_int64_list(value):
            value = np.asarray(value)
            _check_shape_and_dtype(value, shape, dtype)
            return tf.train.Feature(int64_list=tf.train.Int64List(
                value=value.flatten(order='C').tolist()))

        return _encode_to_int64_list
    else:  # Serialize anything else to BytesList in little endian order.
        le_dtype = dtype.as_numpy_dtype(0).newbyteorder('L')

        def _encode_to_bytes_list(value):
            value = np.asarray(value)
            _check_shape_and_dtype(value, shape, dtype)
            bytes_list_value = np.require(value,
                                          dtype=le_dtype,
                                          requirements='C').tostring()
            return tf.train.Feature(bytes_list=tf.train.BytesList(
                value=[bytes_list_value]))

        return _encode_to_bytes_list
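The returned encoder maps a single numpy value to a `tf.train.Feature`. A small sketch of the float32 branch (assuming `_get_feature_encoder` is importable from this module):

encoder = _get_feature_encoder(shape=(2,), dtype=tf.float32)
feature = encoder(np.array([0.5, 1.5], dtype=np.float32))
# feature is a tf.train.Feature with float_list.value == [0.5, 1.5], ready to
# be packed into a tf.train.Example by the example serializer.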
Example #28
    def testUpdateWithCheckpoint(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in TF2.x.')

        # Create and saved_model for a q_policy.
        network = q_network.QNetwork(
            input_tensor_spec=self._time_step_spec.observation,
            action_spec=self._action_spec)

        policy = q_policy.QPolicy(time_step_spec=self._time_step_spec,
                                  action_spec=self._action_spec,
                                  q_network=network)

        saver = policy_saver.PolicySaver(policy, batch_size=None)
        path = os.path.join(self.get_temp_dir(), 'save_model')

        self.evaluate(tf.compat.v1.global_variables_initializer())
        saver.save(path)

        # Assign -1 to all variables in the policy, making the checkpoint
        # different from the initial saved_model.
        self.evaluate(
            tf.nest.map_structure(lambda v: v.assign(v * 0 + -1),
                                  policy.variables()))
        checkpoint_path = os.path.join(self.get_temp_dir(), 'checkpoint')
        saver.save_checkpoint(checkpoint_path)

        # Get an instance of the saved_model.
        reloaded_policy = tf.compat.v2.saved_model.load(path)
        self.evaluate(
            tf.compat.v1.initializers.variables(
                reloaded_policy.model_variables))

        # Verify the loaded saved_model variables differ from the current policy's.
        model_variables = self.evaluate(policy.variables())
        reloaded_model_variables = self.evaluate(
            reloaded_policy.model_variables)

        assert_np_not_equal = lambda a, b: self.assertFalse(
            np.equal(a, b).any())
        tf.nest.map_structure(assert_np_not_equal, model_variables,
                              reloaded_model_variables)

        # Update from checkpoint.
        checkpoint = tf.train.Checkpoint(policy=reloaded_policy)
        manager = tf.train.CheckpointManager(checkpoint,
                                             directory=checkpoint_path,
                                             max_to_keep=None)
        checkpoint.restore(manager.latest_checkpoint).expect_partial()

        self.evaluate(
            tf.compat.v1.initializers.variables(
                reloaded_policy.model_variables))

        # Verify variables are now equal.
        model_variables = self.evaluate(policy.variables())
        reloaded_model_variables = self.evaluate(
            reloaded_policy.model_variables)

        assert_np_all_equal = lambda a, b: self.assertTrue(
            np.equal(a, b).all())
        tf.nest.map_structure(assert_np_all_equal, model_variables,
                              reloaded_model_variables)
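The pattern in this test and in Example #26 (load the SavedModel, then restore its `model_variables` via `tf.train.Checkpoint(policy=reloaded_policy)`) is the lower-level equivalent of `SavedModelPyTFEagerPolicy.update_from_checkpoint` shown earlier. A condensed sketch, assuming `path` and `checkpoint_path` as above:

reloaded_policy = tf.compat.v2.saved_model.load(path)
checkpoint = tf.train.Checkpoint(policy=reloaded_policy)
# save_checkpoint writes its variables under '<checkpoint_path>/variables/variables'.
checkpoint.read(
    os.path.join(checkpoint_path, 'variables', 'variables')
).assert_existing_objects_matched()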