def main(_):
  logging.set_verbosity(logging.INFO)
  if common.has_eager_been_enabled():
    return 0
  tf.enable_resource_variables()
  TrainEval(FLAGS.root_dir, suite_atari.game(name=FLAGS.game_name),
            **get_run_args()).run()
def testUpdateFromCheckpoint(self):
  if not common.has_eager_been_enabled():
    self.skipTest('Only supported in TF2.x.')

  path = os.path.join(self.get_temp_dir(), 'saved_policy')
  saver = policy_saver.PolicySaver(self.tf_policy)
  saver.save(path)
  self.evaluate(
      tf.nest.map_structure(lambda v: v.assign(v * 0 + -1),
                            self.tf_policy.variables()))
  checkpoint_path = os.path.join(self.get_temp_dir(), 'checkpoint')
  saver.save_checkpoint(checkpoint_path)

  eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
      path, self.time_step_spec, self.action_spec)

  # Use evaluate to force a copy.
  saved_model_variables = self.evaluate(eager_py_policy.variables())
  eager_py_policy.update_from_checkpoint(checkpoint_path)

  assert_np_not_equal = lambda a, b: self.assertFalse(np.equal(a, b).all())
  tf.nest.map_structure(assert_np_not_equal, saved_model_variables,
                        self.evaluate(eager_py_policy.variables()))

  assert_np_all_equal = lambda a, b: self.assertTrue(np.equal(a, b).all())
  tf.nest.map_structure(
      assert_np_all_equal,
      self.evaluate(self.tf_policy.variables()),
      self.evaluate(eager_py_policy.variables()),
      check_types=False)
def main(_):
  if common.has_eager_been_enabled():
    return 0
  logging.set_verbosity(logging.INFO)
  tf.compat.v1.enable_resource_variables()
  train_eval(FLAGS.root_dir, num_iterations=FLAGS.num_iterations)
def testRegisterFunction(self):
  if not common.has_eager_been_enabled():
    self.skipTest('Only supported in TF2.x. Step is required in TF1.x')

  network = q_network.QNetwork(
      input_tensor_spec=self._time_step_spec.observation,
      action_spec=self._action_spec)
  policy = q_policy.QPolicy(
      time_step_spec=self._time_step_spec,
      action_spec=self._action_spec,
      q_network=network)

  saver = policy_saver.PolicySaver(policy, batch_size=None)
  saver.register_function('q_network', network,
                          self._time_step_spec.observation)

  path = os.path.join(self.get_temp_dir(), 'save_model')
  saver.save(path)

  reloaded = tf.compat.v2.saved_model.load(path)
  sample_input = self.evaluate(
      tensor_spec.sample_spec_nest(
          self._time_step_spec.observation, outer_dims=(3,)))
  expected_output, _ = network(sample_input)
  reloaded_output, _ = reloaded.q_network(sample_input)
  self.assertAllClose(expected_output, reloaded_output)
def testBatchedPyEnvCompatible(self):
  if not common.has_eager_been_enabled():
    self.skipTest('Only supported in eager.')

  actor_net = actor_network.ActorNetwork(
      self._observation_tensor_spec,
      self._action_tensor_spec,
      fc_layer_params=(10,),
  )
  tf_policy = actor_policy.ActorPolicy(
      self._time_step_tensor_spec,
      self._action_tensor_spec,
      actor_network=actor_net)

  py_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_policy, batch_time_steps=False)

  env_ctr = lambda: random_py_environment.RandomPyEnvironment(  # pylint: disable=g-long-lambda
      self._observation_spec, self._action_spec)
  env = batched_py_environment.BatchedPyEnvironment(
      [env_ctr() for _ in range(3)])
  time_step = env.reset()

  for _ in range(20):
    action_step = py_policy.action(time_step)
    time_step = env.step(action_step.action)
def testRandomTFPolicyCompatibility(self):
  if not common.has_eager_been_enabled():
    self.skipTest('Only supported in eager.')

  tf_policy = random_tf_policy.RandomTFPolicy(
      self._time_step_tensor_spec,
      self._action_tensor_spec,
      info_spec=self._info_tensor_spec)
  py_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_policy)
  time_step = self._env.reset()

  def _check_action_step(action_step):
    self.assertIsInstance(action_step.action, np.ndarray)
    self.assertEqual(action_step.action.shape, (1,))
    self.assertBetween(action_step.action[0], 2.0, 3.0)

    self.assertIsInstance(action_step.info['a'], np.ndarray)
    self.assertEqual(action_step.info['a'].shape, (1,))
    self.assertBetween(action_step.info['a'][0], 0.0, 1.0)

    self.assertIsInstance(action_step.info['b'], np.ndarray)
    self.assertEqual(action_step.info['b'].shape, (1,))
    self.assertBetween(action_step.info['b'][0], 100.0, 101.0)

  for _ in range(100):
    action_step = py_policy.action(time_step)
    _check_action_step(action_step)
    time_step = self._env.step(action_step.action)
def testRegisterConcreteFunction(self):
  if not common.has_eager_been_enabled():
    self.skipTest('Only supported in TF2.x. Step is required in TF1.x')

  network = q_network.QNetwork(
      input_tensor_spec=self._time_step_spec.observation,
      action_spec=self._action_spec)
  policy = q_policy.QPolicy(
      time_step_spec=self._time_step_spec,
      action_spec=self._action_spec,
      q_network=network)

  saver = policy_saver.PolicySaver(policy, batch_size=None)

  tf_var = tf.Variable(3)

  def add(b):
    return tf_var + b

  add_fn = common.function(add)
  # Called for side effect.
  add_fn.get_concrete_function(tf.TensorSpec((), dtype=tf.int32))

  saver.register_concrete_function(name='add', fn=add_fn)

  path = os.path.join(self.get_temp_dir(), 'save_model')
  saver.save(path)

  reloaded = tf.compat.v2.saved_model.load(path)
  self.assertAllClose(7, reloaded.add(4))
def testPyEnvCompatible(self):
  if not common.has_eager_been_enabled():
    self.skipTest('Only supported in eager.')

  observation_spec = array_spec.ArraySpec([2], np.float32)
  action_spec = array_spec.BoundedArraySpec([1], np.float32, 2, 3)

  observation_tensor_spec = tensor_spec.from_spec(observation_spec)
  action_tensor_spec = tensor_spec.from_spec(action_spec)
  time_step_tensor_spec = ts.time_step_spec(observation_tensor_spec)

  actor_net = actor_network.ActorNetwork(
      observation_tensor_spec,
      action_tensor_spec,
      fc_layer_params=(10,),
  )
  tf_policy = actor_policy.ActorPolicy(
      time_step_tensor_spec, action_tensor_spec, actor_network=actor_net)

  py_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_policy)
  # The env will validate action types automatically since we provided the
  # action_spec.
  env = random_py_environment.RandomPyEnvironment(observation_spec,
                                                  action_spec)
  time_step = env.reset()

  for _ in range(100):
    action_step = py_policy.action(time_step)
    time_step = env.step(action_step.action)
def testRegisterFunction(self):
  if not common.has_eager_been_enabled():
    self.skipTest('Only supported in TF2.x. Step is required in TF1.x')

  time_step_spec = ts.TimeStep(
      step_type=tensor_spec.BoundedTensorSpec(
          dtype=tf.int32, shape=(), name='st', minimum=0, maximum=2),
      reward=tensor_spec.BoundedTensorSpec(
          dtype=tf.float32, shape=(), name='reward', minimum=0.0,
          maximum=5.0),
      discount=tensor_spec.BoundedTensorSpec(
          dtype=tf.float32, shape=(), name='discount', minimum=0.0,
          maximum=1.0),
      observation=tensor_spec.BoundedTensorSpec(
          dtype=tf.float32, shape=(4,), name='obs', minimum=-10.0,
          maximum=10.0))
  action_spec = tensor_spec.BoundedTensorSpec(
      dtype=tf.int32, shape=(), minimum=0, maximum=10, name='act_0')

  network = q_network.QNetwork(
      input_tensor_spec=time_step_spec.observation, action_spec=action_spec)
  policy = q_policy.QPolicy(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      q_network=network)

  saver = policy_saver.PolicySaver(policy, batch_size=None)
  async_saver = async_policy_saver.AsyncPolicySaver(saver)
  async_saver.register_function('q_network', network,
                                time_step_spec.observation)

  path = os.path.join(self.get_temp_dir(), 'save_model')
  async_saver.save(path)
  async_saver.flush()
  async_saver.close()
  self.assertFalse(async_saver._save_thread.is_alive())

  reloaded = tf.compat.v2.saved_model.load(path)
  sample_input = self.evaluate(
      tensor_spec.sample_spec_nest(
          time_step_spec.observation, outer_dims=(3,)))
  expected_output, _ = network(sample_input)
  reloaded_output, _ = reloaded.q_network(sample_input)
  self.assertAllClose(expected_output, reloaded_output)
def main(_):
  logging.set_verbosity(logging.INFO)
  if common.has_eager_been_enabled():
    return 0
  tf.enable_resource_variables()
  agent_class = dqn_agent.DdqnAgent if FLAGS.use_ddqn else dqn_agent.DqnAgent
  train_eval(
      FLAGS.root_dir,
      agent_class=agent_class,
      num_iterations=FLAGS.num_iterations)
# Inner helper defined inside a test method; `self` and `outer_graph` come
# from the enclosing scope, and `add` is wrapped in a `common.function`.
def add(x, y):
  if common.has_eager_been_enabled():
    # In TF2, this should be executed in eager mode.
    self.assertTrue(tf.executing_eagerly())
  else:
    # In TF1, this should be inside a temporary graph because it's being
    # created inside a tf.function.
    inner_graph = tf.compat.v1.get_default_graph()
    self.assertNotEqual(outer_graph, inner_graph)
  return x + y
def testInferenceFromCheckpoint(self):
  if not common.has_eager_been_enabled():
    self.skipTest('Only supported in TF2.x.')

  path = os.path.join(self.get_temp_dir(), 'saved_policy')
  saver = policy_saver.PolicySaver(self.tf_policy)
  saver.save(path)

  rng = np.random.RandomState()
  sample_time_step = array_spec.sample_spec_nest(self.time_step_spec, rng)
  batched_sample_time_step = nest_utils.batch_nested_array(sample_time_step)

  self.evaluate(
      tf.nest.map_structure(lambda v: v.assign(v * 0 + -1),
                            self.tf_policy.variables()))
  checkpoint_path = os.path.join(self.get_temp_dir(), 'checkpoint')
  saver.save_checkpoint(checkpoint_path)

  eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
      path, self.time_step_spec, self.action_spec)

  # Use evaluate to force a copy.
  saved_model_variables = self.evaluate(eager_py_policy.variables())
  eager_py_policy.update_from_checkpoint(checkpoint_path)

  assert_np_not_equal = lambda a, b: self.assertFalse(np.equal(a, b).all())
  tf.nest.map_structure(assert_np_not_equal, saved_model_variables,
                        self.evaluate(eager_py_policy.variables()))

  assert_np_all_equal = lambda a, b: self.assertTrue(np.equal(a, b).all())
  tf.nest.map_structure(
      assert_np_all_equal,
      self.evaluate(self.tf_policy.variables()),
      self.evaluate(eager_py_policy.variables()),
      check_types=False)

  # We can't assert that the action changed after the update, because
  # depending on variable initialization the two actions may coincide.
  # Instead, check that the updated policy and the current tf_policy always
  # produce the same action.
  checkpoint_action = eager_py_policy.action(sample_time_step)

  current_policy_action = self.tf_policy.action(batched_sample_time_step)
  current_policy_action = self.evaluate(
      nest_utils.unbatch_nested_tensors(current_policy_action))
  tf.nest.map_structure(assert_np_all_equal, current_policy_action,
                        checkpoint_action)
def testCreateAgent(self, create_critic_net_fn, skip_in_tf1):
  if skip_in_tf1 and not common.has_eager_been_enabled():
    self.skipTest('Skipping test: sequential networks not supported in TF1')
  critic_network = create_critic_net_fn()

  sac_agent.SacAgent(
      self._time_step_spec,
      self._action_spec,
      critic_network=critic_network,
      actor_network=None,
      actor_optimizer=None,
      critic_optimizer=None,
      alpha_optimizer=None,
      actor_policy_ctor=DummyActorPolicy)
def testSavedModel(self):
  if not common.has_eager_been_enabled():
    self.skipTest('Only supported in eager.')

  observation_spec = array_spec.ArraySpec([2], np.float32)
  action_spec = array_spec.BoundedArraySpec([1], np.float32, 2, 3)
  time_step_spec = ts.time_step_spec(observation_spec)

  observation_tensor_spec = tensor_spec.from_spec(observation_spec)
  action_tensor_spec = tensor_spec.from_spec(action_spec)
  time_step_tensor_spec = tensor_spec.from_spec(time_step_spec)

  actor_net = actor_network.ActorNetwork(
      observation_tensor_spec,
      action_tensor_spec,
      fc_layer_params=(10,),
  )
  tf_policy = actor_policy.ActorPolicy(
      time_step_tensor_spec, action_tensor_spec, actor_network=actor_net)

  path = os.path.join(self.get_temp_dir(), 'saved_policy')
  saver = policy_saver.PolicySaver(tf_policy)
  saver.save(path)

  eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
      path, time_step_spec, action_spec)

  rng = np.random.RandomState()
  sample_time_step = array_spec.sample_spec_nest(time_step_spec, rng)
  batched_sample_time_step = nest_utils.batch_nested_array(sample_time_step)

  original_action = tf_policy.action(batched_sample_time_step)
  unbatched_original_action = nest_utils.unbatch_nested_tensors(
      original_action)
  original_action_np = tf.nest.map_structure(lambda t: t.numpy(),
                                             unbatched_original_action)
  saved_policy_action = eager_py_policy.action(sample_time_step)

  tf.nest.assert_same_structure(saved_policy_action.action, action_spec)

  np.testing.assert_array_almost_equal(original_action_np.action,
                                       saved_policy_action.action)
def test_compress_image(self):
  if not common.has_eager_been_enabled():
    self.skipTest("Image compression only supported in TF2.x")

  gin.parse_config_files_and_bindings([], """
      _get_feature_encoder.compress_image=True
      _get_feature_parser.compress_image=True
  """)
  spec = {"image": array_spec.ArraySpec((128, 128, 3), np.uint8)}
  serializer = example_encoding.get_example_serializer(spec)
  decoder = example_encoding.get_example_decoder(spec)

  sample = {"image": 128 * np.ones([128, 128, 3], dtype=np.uint8)}
  example_proto = serializer(sample)

  recovered = self.evaluate(decoder(example_proto))
  tf.nest.map_structure(np.testing.assert_almost_equal, sample, recovered)
def testRandomTFPolicyCompatibility(self):
  if not common.has_eager_been_enabled():
    self.skipTest('Only supported in eager.')

  observation_spec = array_spec.ArraySpec([2], np.float32)
  action_spec = array_spec.BoundedArraySpec([1], np.float32, 2, 3)
  info_spec = {
      'a': array_spec.BoundedArraySpec([1], np.float32, 0, 1),
      'b': array_spec.BoundedArraySpec([1], np.float32, 100, 101)
  }

  observation_tensor_spec = tensor_spec.from_spec(observation_spec)
  action_tensor_spec = tensor_spec.from_spec(action_spec)
  info_tensor_spec = tensor_spec.from_spec(info_spec)
  time_step_tensor_spec = ts.time_step_spec(observation_tensor_spec)

  tf_policy = random_tf_policy.RandomTFPolicy(
      time_step_tensor_spec, action_tensor_spec, info_spec=info_tensor_spec)
  py_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_policy)

  env = random_py_environment.RandomPyEnvironment(observation_spec,
                                                  action_spec)
  time_step = env.reset()

  def _check_action_step(action_step):
    self.assertIsInstance(action_step.action, np.ndarray)
    self.assertEqual(action_step.action.shape, (1,))
    self.assertBetween(action_step.action[0], 2.0, 3.0)

    self.assertIsInstance(action_step.info['a'], np.ndarray)
    self.assertEqual(action_step.info['a'].shape, (1,))
    self.assertBetween(action_step.info['a'][0], 0.0, 1.0)

    self.assertIsInstance(action_step.info['b'], np.ndarray)
    self.assertEqual(action_step.info['b'].shape, (1,))
    self.assertBetween(action_step.info['b'][0], 100.0, 101.0)

  for _ in range(100):
    action_step = py_policy.action(time_step)
    _check_action_step(action_step)
    time_step = env.step(action_step.action)
def testCriticLoss(self, create_critic_net_fn, skip_in_tf1):
  if skip_in_tf1 and not common.has_eager_been_enabled():
    self.skipTest('Skipping test: sequential networks not supported in TF1')
  critic_network = create_critic_net_fn()

  agent = sac_agent.SacAgent(
      self._time_step_spec,
      self._action_spec,
      critic_network=critic_network,
      actor_network=None,
      actor_optimizer=None,
      critic_optimizer=None,
      alpha_optimizer=None,
      actor_policy_ctor=DummyActorPolicy)

  observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
  time_steps = ts.restart(observations, batch_size=2)
  actions = tf.constant([[5], [6]], dtype=tf.float32)

  rewards = tf.constant([10, 20], dtype=tf.float32)
  discounts = tf.constant([0.9, 0.9], dtype=tf.float32)
  next_observations = tf.constant([[5, 6], [7, 8]], dtype=tf.float32)
  next_time_steps = ts.transition(next_observations, rewards, discounts)

  td_targets = [7.3, 19.1]
  pred_td_targets = [7., 10.]

  self.evaluate(tf.compat.v1.global_variables_initializer())

  # Expected critic loss has factor of 2, for the two TD3 critics.
  expected_loss = self.evaluate(2 * tf.compat.v1.losses.mean_squared_error(
      tf.constant(td_targets), tf.constant(pred_td_targets)))

  loss = agent.critic_loss(
      time_steps,
      actions,
      next_time_steps,
      td_errors_loss_fn=tf.math.squared_difference)

  self.evaluate(tf.compat.v1.global_variables_initializer())
  loss_ = self.evaluate(loss)
  self.assertAllClose(loss_, expected_loss)
def testPyEnvCompatible(self):
  if not common.has_eager_been_enabled():
    self.skipTest('Only supported in eager.')

  actor_net = actor_network.ActorNetwork(
      self._observation_tensor_spec,
      self._action_tensor_spec,
      fc_layer_params=(10,),
  )
  tf_policy = actor_policy.ActorPolicy(
      self._time_step_tensor_spec,
      self._action_tensor_spec,
      actor_network=actor_net)

  py_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_policy)
  time_step = self._env.reset()

  for _ in range(100):
    action_step = py_policy.action(time_step)
    time_step = self._env.step(action_step.action)
def testActionWithSeed(self):
  if not common.has_eager_been_enabled():
    self.skipTest('Only supported in eager.')

  tf_policy = random_tf_policy.RandomTFPolicy(
      self._time_step_tensor_spec,
      self._action_tensor_spec,
      info_spec=self._info_tensor_spec)
  py_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_policy)

  time_step = self._env.reset()
  tf.random.set_seed(100)
  action_step_1 = py_policy.action(time_step, seed=100)

  time_step = self._env.reset()
  tf.random.set_seed(100)
  action_step_2 = py_policy.action(time_step, seed=100)

  time_step = self._env.reset()
  tf.random.set_seed(200)
  action_step_3 = py_policy.action(time_step, seed=200)

  self.assertEqual(action_step_1.action[0], action_step_2.action[0])
  self.assertNotEqual(action_step_1.action[0], action_step_3.action[0])
def setUp(self):
  super(SavedModelPYTFEagerPolicyTest, self).setUp()
  if not common.has_eager_been_enabled():
    self.skipTest('Only supported in eager.')

  observation_spec = array_spec.ArraySpec([2], np.float32)
  self.action_spec = array_spec.BoundedArraySpec([1], np.float32, 2, 3)
  self.time_step_spec = ts.time_step_spec(observation_spec)

  observation_tensor_spec = tensor_spec.from_spec(observation_spec)
  action_tensor_spec = tensor_spec.from_spec(self.action_spec)
  time_step_tensor_spec = tensor_spec.from_spec(self.time_step_spec)

  actor_net = actor_network.ActorNetwork(
      observation_tensor_spec,
      action_tensor_spec,
      fc_layer_params=(10,),
  )
  self.tf_policy = actor_policy.ActorPolicy(
      time_step_tensor_spec, action_tensor_spec, actor_network=actor_net)
def testTrainStepNotSaved(self):
  if not common.has_eager_been_enabled():
    self.skipTest('Only supported in TF2.x. Step is required in TF1.x')

  network = q_network.QNetwork(
      input_tensor_spec=self._time_step_spec.observation,
      action_spec=self._action_spec)
  policy = q_policy.QPolicy(
      time_step_spec=self._time_step_spec,
      action_spec=self._action_spec,
      q_network=network)

  saver = policy_saver.PolicySaver(policy, batch_size=None)
  path = os.path.join(self.get_temp_dir(), 'save_model')

  saver.save(path)
  reloaded = tf.compat.v2.saved_model.load(path)

  self.assertIn('get_train_step', reloaded.signatures)
  train_step_value = self.evaluate(reloaded.get_train_step())
  self.assertEqual(-1, train_step_value)
def testRandomTFPolicyCompatibility(self):
  if not common.has_eager_been_enabled():
    self.skipTest('Only supported in eager.')

  observation_spec = array_spec.ArraySpec([2], np.float32)
  action_spec = array_spec.BoundedArraySpec([1], np.float32, 2, 3)

  observation_tensor_spec = tensor_spec.from_spec(observation_spec)
  action_tensor_spec = tensor_spec.from_spec(action_spec)
  time_step_tensor_spec = ts.time_step_spec(observation_tensor_spec)

  tf_policy = random_tf_policy.RandomTFPolicy(time_step_tensor_spec,
                                              action_tensor_spec)
  py_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_policy)

  env = random_py_environment.RandomPyEnvironment(observation_spec,
                                                  action_spec)
  time_step = env.reset()

  for _ in range(100):
    action_step = py_policy.action(time_step)
    time_step = env.step(action_step.action)
def __init__(self,
             policy,
             batch_size=None,
             use_nest_path_signatures=True,
             seed=None,
             train_step=None,
             input_fn_and_spec=None):
  """Initialize PolicySaver for TF policy `policy`.

  Args:
    policy: A TF Policy.
    batch_size: The number of batch entries the policy will process at a
      time. This must be either `None` (unknown batch size) or a python
      integer.
    use_nest_path_signatures: SavedModel spec signatures will be created
      based on the structure of the specs. Otherwise all specs must have
      unique names.
    seed: Random seed for the `policy.action` call, if any (this should
      usually be `None`, except for testing).
    train_step: Variable holding the train step for the policy. The value
      saved will be set at the time `saver.save` is called. If not provided,
      train_step defaults to -1. Note since the train step must be a variable
      it is not safe to create it directly in TF1 so in that case this is a
      required parameter.
    input_fn_and_spec: A `(input_fn, tensor_spec)` tuple where input_fn is a
      function that takes inputs according to tensor_spec and converts them
      to the `(time_step, policy_state)` tuple that is used as the input to
      the action_fn. When `input_fn_and_spec` is set, `tensor_spec` is the
      input for the action signature. When `input_fn_and_spec is None`, the
      action signature takes as input `(time_step, policy_state)`.

  Raises:
    TypeError: If `policy` is not an instance of TFPolicy.
    ValueError: If use_nest_path_signatures is not used and any of the
      following `policy` specs are missing names, or the names collide:
      `policy.time_step_spec`, `policy.action_spec`,
      `policy.policy_state_spec`, `policy.info_spec`.
    ValueError: If `batch_size` is not either `None` or a python integer > 0.
  """
  if not isinstance(policy, tf_policy.Base):
    raise TypeError('policy is not a TFPolicy. Saw: %s' % type(policy))
  if (batch_size is not None and
      (not isinstance(batch_size, int) or batch_size < 1)):
    raise ValueError(
        'Expected batch_size == None or python int > 0, saw: %s' %
        (batch_size,))

  action_fn_input_spec = (policy.time_step_spec, policy.policy_state_spec)
  if use_nest_path_signatures:
    action_fn_input_spec = _rename_spec_with_nest_paths(action_fn_input_spec)
  else:
    _check_spec(action_fn_input_spec)

  # Make a shallow copy as we'll be making some changes in-place.
  policy = copy.copy(policy)

  if train_step is None:
    if not common.has_eager_been_enabled():
      raise ValueError('train_step is required in TF1 and must be a '
                       '`tf.Variable`: %s' % train_step)
    train_step = tf.Variable(
        -1,
        trainable=False,
        dtype=tf.int64,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        shape=())
  elif not isinstance(train_step, tf.Variable):
    raise ValueError('train_step must be a TensorFlow variable: %s' %
                     train_step)
  policy.train_step = train_step

  # We will need the train step for the Checkpoint object.
  self._train_step = train_step

  if batch_size is None:
    get_initial_state_fn = policy.get_initial_state
    get_initial_state_input_specs = (tf.TensorSpec(
        dtype=tf.int32, shape=(), name='batch_size'),)
  else:
    get_initial_state_fn = functools.partial(
        policy.get_initial_state, batch_size=batch_size)
    get_initial_state_input_specs = ()

  get_initial_state_fn = common.function()(get_initial_state_fn)

  original_action_fn = policy.action

  if seed is not None:

    def action_fn(time_step, policy_state):
      return original_action_fn(time_step, policy_state, seed=seed)
  else:
    action_fn = original_action_fn

  # We call get_concrete_function() for its side effect.
  get_initial_state_fn.get_concrete_function(*get_initial_state_input_specs)

  train_step_fn = common.function(
      lambda: policy.train_step).get_concrete_function()

  action_fn = common.function()(action_fn)

  def add_batch_dim(spec):
    return tf.TensorSpec(
        shape=tf.TensorShape([batch_size]).concatenate(spec.shape),
        name=spec.name,
        dtype=spec.dtype)

  batched_time_step_spec = tf.nest.map_structure(add_batch_dim,
                                                 policy.time_step_spec)
  batched_policy_state_spec = tf.nest.map_structure(add_batch_dim,
                                                    policy.policy_state_spec)

  policy_step_spec = policy.policy_step_spec
  policy_state_spec = policy.policy_state_spec

  if use_nest_path_signatures:
    batched_time_step_spec = _rename_spec_with_nest_paths(
        batched_time_step_spec)
    batched_policy_state_spec = _rename_spec_with_nest_paths(
        batched_policy_state_spec)
    policy_step_spec = _rename_spec_with_nest_paths(policy_step_spec)
    policy_state_spec = _rename_spec_with_nest_paths(policy_state_spec)
  else:
    _check_spec(batched_time_step_spec)
    _check_spec(batched_policy_state_spec)
    _check_spec(policy_step_spec)
    _check_spec(policy_state_spec)

  if input_fn_and_spec is not None:
    # Store a signature based on input_fn_and_spec
    @common.function()
    def polymorphic_action_fn(example):
      action_inputs = input_fn_and_spec[0](example)
      tf.nest.map_structure(
          lambda spec, t: tf.Assert(spec.is_compatible_with(t[0]), [t]),
          action_fn_input_spec, action_inputs)
      return action_fn(*action_inputs)

    batched_input_spec = tf.nest.map_structure(add_batch_dim,
                                               input_fn_and_spec[1])
    # We call get_concrete_function() for its side effect.
    polymorphic_action_fn.get_concrete_function(example=batched_input_spec)

    action_input_spec = (input_fn_and_spec[1],)

  else:
    action_input_spec = action_fn_input_spec

    if batched_policy_state_spec:
      # Store the signature with a required policy state spec
      polymorphic_action_fn = action_fn
      polymorphic_action_fn.get_concrete_function(
          time_step=batched_time_step_spec,
          policy_state=batched_policy_state_spec)
    else:
      # Create a polymorphic action_fn which you can call as
      #   restored.action(time_step)
      # or
      #   restored.action(time_step, ())
      # (without retracing the inner action twice)
      @common.function()
      def polymorphic_action_fn(time_step,
                                policy_state=batched_policy_state_spec):
        return action_fn(time_step, policy_state)

      polymorphic_action_fn.get_concrete_function(
          time_step=batched_time_step_spec,
          policy_state=batched_policy_state_spec)
      polymorphic_action_fn.get_concrete_function(
          time_step=batched_time_step_spec)

  signatures = {
      'action':
          _function_with_flat_signature(
              polymorphic_action_fn,
              input_specs=action_input_spec,
              output_spec=policy_step_spec,
              include_batch_dimension=True,
              batch_size=batch_size),
      'get_initial_state':
          _function_with_flat_signature(
              get_initial_state_fn,
              input_specs=get_initial_state_input_specs,
              output_spec=policy_state_spec,
              include_batch_dimension=False),
      'get_train_step':
          _function_with_flat_signature(
              train_step_fn,
              input_specs=(),
              output_spec=train_step.dtype,
              include_batch_dimension=False),
  }

  policy.action = polymorphic_action_fn
  policy.get_initial_state = get_initial_state_fn
  policy.get_train_step = train_step_fn
  # Adding variables as an attribute to facilitate updating them.
  policy.model_variables = policy.variables()

  self._policy = policy
  self._signatures = signatures
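# --- Usage sketch (added for illustration; not part of the original module).
# A minimal, self-contained example of driving the PolicySaver defined above:
# build a trivial QPolicy, save a SavedModel plus a variables-only checkpoint,
# and read the default train step back through the saved signatures. The
# specs, network, and paths here are arbitrary assumptions.
import os
import tempfile

import tensorflow as tf
from tf_agents.networks import q_network
from tf_agents.policies import policy_saver
from tf_agents.policies import q_policy
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

observation_spec = tensor_spec.TensorSpec([2], tf.float32, name='obs')
action_spec = tensor_spec.BoundedTensorSpec(
    (), tf.int32, minimum=0, maximum=3, name='act')
time_step_spec = ts.time_step_spec(observation_spec)

network = q_network.QNetwork(
    input_tensor_spec=observation_spec, action_spec=action_spec)
policy = q_policy.QPolicy(time_step_spec, action_spec, q_network=network)

saver = policy_saver.PolicySaver(policy, batch_size=None)
path = os.path.join(tempfile.mkdtemp(), 'saved_policy')
saver.save(path)                             # Full SavedModel.
saver.save_checkpoint(path + '_checkpoint')  # Variables only.

loaded = tf.compat.v2.saved_model.load(path)
# train_step defaults to -1 because none was passed to the saver.
print(loaded.get_train_step())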
def __init__(self,
             policy: tf_policy.TFPolicy,
             batch_size: Optional[int] = None,
             use_nest_path_signatures: bool = True,
             seed: Optional[types.Seed] = None,
             train_step: Optional[tf.Variable] = None,
             input_fn_and_spec: Optional[InputFnAndSpecType] = None,
             metadata: Optional[Dict[Text, tf.Variable]] = None):
  """Initialize PolicySaver for TF policy `policy`.

  Args:
    policy: A TF Policy.
    batch_size: The number of batch entries the policy will process at a
      time. This must be either `None` (unknown batch size) or a python
      integer.
    use_nest_path_signatures: SavedModel spec signatures will be created
      based on the structure of the specs. Otherwise all specs must have
      unique names.
    seed: Random seed for the `policy.action` call, if any (this should
      usually be `None`, except for testing).
    train_step: Variable holding the train step for the policy. The value
      saved will be set at the time `saver.save` is called. If not provided,
      train_step defaults to -1. Note since the train step must be a variable
      it is not safe to create it directly in TF1 so in that case this is a
      required parameter.
    input_fn_and_spec: A `(input_fn, tensor_spec)` tuple where input_fn is a
      function that takes inputs according to tensor_spec and converts them
      to the `(time_step, policy_state)` tuple that is used as the input to
      the action_fn. When `input_fn_and_spec` is set, `tensor_spec` is the
      input for the action signature. When `input_fn_and_spec is None`, the
      action signature takes as input `(time_step, policy_state)`.
    metadata: A dictionary of `tf.Variables` to be saved along with the
      policy.

  Raises:
    TypeError: If `policy` is not an instance of TFPolicy.
    TypeError: If `metadata` is not a dictionary of tf.Variables.
    ValueError: If use_nest_path_signatures is not used and any of the
      following `policy` specs are missing names, or the names collide:
      `policy.time_step_spec`, `policy.action_spec`,
      `policy.policy_state_spec`, `policy.info_spec`.
    ValueError: If `batch_size` is not either `None` or a python integer > 0.
  """
  if not isinstance(policy, tf_policy.TFPolicy):
    raise TypeError('policy is not a TFPolicy. Saw: %s' % type(policy))
  if (batch_size is not None and
      (not isinstance(batch_size, int) or batch_size < 1)):
    raise ValueError(
        'Expected batch_size == None or python int > 0, saw: %s' %
        (batch_size,))

  action_fn_input_spec = (policy.time_step_spec, policy.policy_state_spec)
  if use_nest_path_signatures:
    action_fn_input_spec = _rename_spec_with_nest_paths(action_fn_input_spec)
  else:
    _check_spec(action_fn_input_spec)

  # Make a shallow copy as we'll be making some changes in-place.
  saved_policy = tf.Module()
  saved_policy.collect_data_spec = copy.copy(policy.collect_data_spec)
  saved_policy.policy_state_spec = copy.copy(policy.policy_state_spec)

  if train_step is None:
    if not common.has_eager_been_enabled():
      raise ValueError('train_step is required in TF1 and must be a '
                       '`tf.Variable`: %s' % train_step)
    train_step = tf.Variable(
        -1,
        trainable=False,
        dtype=tf.int64,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        shape=())
  elif not isinstance(train_step, tf.Variable):
    raise ValueError('train_step must be a TensorFlow variable: %s' %
                     train_step)

  # We will need the train step for the Checkpoint object.
  self._train_step = train_step
  saved_policy.train_step = self._train_step

  self._metadata = metadata or {}
  for key, value in self._metadata.items():
    if not isinstance(key, str):
      raise TypeError('Keys of metadata must be strings: %s' % key)
    if not isinstance(value, tf.Variable):
      raise TypeError('Values of metadata must be tf.Variable: %s' % value)
  saved_policy.metadata = self._metadata

  if batch_size is None:
    get_initial_state_fn = policy.get_initial_state
    get_initial_state_input_specs = (tf.TensorSpec(
        dtype=tf.int32, shape=(), name='batch_size'),)
  else:
    get_initial_state_fn = functools.partial(
        policy.get_initial_state, batch_size=batch_size)
    get_initial_state_input_specs = ()

  get_initial_state_fn = common.function()(get_initial_state_fn)

  original_action_fn = policy.action

  if seed is not None:

    def action_fn(time_step, policy_state):
      time_step = cast(ts.TimeStep, time_step)
      return original_action_fn(time_step, policy_state, seed=seed)
  else:
    action_fn = original_action_fn

  def distribution_fn(time_step, policy_state):
    """Wrapper for policy.distribution() in the SavedModel."""
    try:
      time_step = cast(ts.TimeStep, time_step)
      outs = policy.distribution(
          time_step=time_step, policy_state=policy_state)
      return tf.nest.map_structure(_composite_distribution, outs)
    except (TypeError, NotImplementedError) as e:
      # TODO(b/156526399): Move this to just the policy.distribution() call
      # once tfp.experimental.as_composite() properly handles LinearOperator*
      # components as well as TransformedDistributions.
      logging.warning(
          'WARNING: Could not serialize policy.distribution() for policy '
          '"%s". Calling saved_model.distribution() will raise the following '
          'assertion error: %s', policy, e)

      @common.function()
      def _raise():
        tf.Assert(False, [str(e)])
        return ()

      outs = _raise()

  # We call get_concrete_function() for its side effect: to ensure the proper
  # ConcreteFunction is stored in the SavedModel.
  get_initial_state_fn.get_concrete_function(*get_initial_state_input_specs)

  train_step_fn = common.function(
      lambda: saved_policy.train_step).get_concrete_function()
  get_metadata_fn = common.function(
      lambda: saved_policy.metadata).get_concrete_function()

  batched_time_step_spec = tf.nest.map_structure(
      lambda spec: add_batch_dim(spec, [batch_size]), policy.time_step_spec)
  batched_time_step_spec = cast(ts.TimeStep, batched_time_step_spec)
  batched_policy_state_spec = tf.nest.map_structure(
      lambda spec: add_batch_dim(spec, [batch_size]),
      policy.policy_state_spec)

  policy_step_spec = policy.policy_step_spec
  policy_state_spec = policy.policy_state_spec

  if use_nest_path_signatures:
    batched_time_step_spec = _rename_spec_with_nest_paths(
        batched_time_step_spec)
    batched_policy_state_spec = _rename_spec_with_nest_paths(
        batched_policy_state_spec)
    policy_step_spec = _rename_spec_with_nest_paths(policy_step_spec)
    policy_state_spec = _rename_spec_with_nest_paths(policy_state_spec)
  else:
    _check_spec(batched_time_step_spec)
    _check_spec(batched_policy_state_spec)
    _check_spec(policy_step_spec)
    _check_spec(policy_state_spec)

  if input_fn_and_spec is not None:
    # Store a signature based on input_fn_and_spec
    @common.function()
    def polymorphic_action_fn(example):
      action_inputs = input_fn_and_spec[0](example)
      tf.nest.map_structure(
          lambda spec, t: tf.Assert(spec.is_compatible_with(t[0]), [t]),
          action_fn_input_spec, action_inputs)
      return action_fn(*action_inputs)

    @common.function()
    def polymorphic_distribution_fn(example):
      action_inputs = input_fn_and_spec[0](example)
      tf.nest.map_structure(
          lambda spec, t: tf.Assert(spec.is_compatible_with(t[0]), [t]),
          action_fn_input_spec, action_inputs)
      return distribution_fn(*action_inputs)

    batched_input_spec = tf.nest.map_structure(
        lambda spec: add_batch_dim(spec, [batch_size]), input_fn_and_spec[1])
    # We call get_concrete_function() for its side effect: to ensure the
    # proper ConcreteFunction is stored in the SavedModel.
    polymorphic_action_fn.get_concrete_function(example=batched_input_spec)
    polymorphic_distribution_fn.get_concrete_function(
        example=batched_input_spec)

    action_input_spec = (input_fn_and_spec[1],)

  else:
    action_input_spec = action_fn_input_spec

    if batched_policy_state_spec:
      # Store the signature with a required policy state spec
      polymorphic_action_fn = common.function()(action_fn)
      polymorphic_action_fn.get_concrete_function(
          time_step=batched_time_step_spec,
          policy_state=batched_policy_state_spec)

      polymorphic_distribution_fn = common.function()(distribution_fn)
      polymorphic_distribution_fn.get_concrete_function(
          time_step=batched_time_step_spec,
          policy_state=batched_policy_state_spec)
    else:
      # Create a polymorphic action_fn which you can call as
      #   restored.action(time_step)
      # or
      #   restored.action(time_step, ())
      # (without retracing the inner action twice)
      @common.function()
      def polymorphic_action_fn(time_step,
                                policy_state=batched_policy_state_spec):
        return action_fn(time_step, policy_state)

      polymorphic_action_fn.get_concrete_function(
          time_step=batched_time_step_spec,
          policy_state=batched_policy_state_spec)
      polymorphic_action_fn.get_concrete_function(
          time_step=batched_time_step_spec)

      @common.function()
      def polymorphic_distribution_fn(time_step,
                                      policy_state=batched_policy_state_spec):
        return distribution_fn(time_step, policy_state)

      polymorphic_distribution_fn.get_concrete_function(
          time_step=batched_time_step_spec,
          policy_state=batched_policy_state_spec)
      polymorphic_distribution_fn.get_concrete_function(
          time_step=batched_time_step_spec)

  signatures = {
      # CompositeTensors aren't well supported by old-style signature
      # mechanisms, so we do not have a signature for policy.distribution.
      'action':
          _function_with_flat_signature(
              polymorphic_action_fn,
              input_specs=action_input_spec,
              output_spec=policy_step_spec,
              include_batch_dimension=True,
              batch_size=batch_size),
      'get_initial_state':
          _function_with_flat_signature(
              get_initial_state_fn,
              input_specs=get_initial_state_input_specs,
              output_spec=policy_state_spec,
              include_batch_dimension=False),
      'get_train_step':
          _function_with_flat_signature(
              train_step_fn,
              input_specs=(),
              output_spec=train_step.dtype,
              include_batch_dimension=False),
      'get_metadata':
          _function_with_flat_signature(
              get_metadata_fn,
              input_specs=(),
              output_spec=tf.nest.map_structure(lambda v: v.dtype,
                                                self._metadata),
              include_batch_dimension=False),
  }

  saved_policy.action = polymorphic_action_fn
  saved_policy.distribution = polymorphic_distribution_fn
  saved_policy.get_initial_state = get_initial_state_fn
  saved_policy.get_train_step = train_step_fn
  saved_policy.get_metadata = get_metadata_fn
  # Adding variables as an attribute to facilitate updating them.
  saved_policy.model_variables = policy.variables()

  # TODO(b/156779400): Move to a public API for accessing all trackable leaf
  # objects (once it's available). For now, we have no other way of tracking
  # objects like Tables, Vocabulary files, etc.
  try:
    saved_policy._all_assets = policy._unconditional_checkpoint_dependencies  # pylint: disable=protected-access
  except AttributeError as e:
    if '_self_unconditional' in str(e):
      logging.warning(
          'Unable to capture all trackable objects in policy "%s". This '
          'may be okay. Error: %s', policy, e)
    else:
      raise e

  self._policy = saved_policy
  self._signatures = signatures
  self._action_input_spec = action_input_spec
  self._policy_step_spec = policy_step_spec
  self._policy_state_spec = policy_state_spec
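# --- Usage sketch for the metadata path of the newer __init__ above (added
# for illustration; the names and values are assumptions, and `policy` and
# `path` are reused from the earlier sketch). Extra `tf.Variable`s passed via
# `metadata` are stored alongside the policy and exposed through the
# `get_metadata` signature; `train_step` is captured at `save()` time.
train_step = tf.Variable(0, trainable=False, dtype=tf.int64)
metadata = {'data_version': tf.Variable(3, dtype=tf.int64)}

saver = policy_saver.PolicySaver(
    policy, batch_size=None, train_step=train_step, metadata=metadata)
train_step.assign(100)
saver.save(path)

loaded = tf.compat.v2.saved_model.load(path)
print(loaded.get_train_step())  # 100: the value at save() time.
print(loaded.get_metadata())    # {'data_version': 3}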
def setUp(self):
  if not common.has_eager_been_enabled():
    self.skipTest('Only supported in TF2.x.')
  super(SequentialTest, self).setUp()
def testInferenceWithCheckpoint(self):
  if not common.has_eager_been_enabled():
    self.skipTest('Only supported in TF2.x.')

  # Create a saved_model for a q_policy.
  network = q_network.QNetwork(
      input_tensor_spec=self._time_step_spec.observation,
      action_spec=self._action_spec)
  policy = q_policy.QPolicy(
      time_step_spec=self._time_step_spec,
      action_spec=self._action_spec,
      q_network=network)
  sample_input = self.evaluate(
      tensor_spec.sample_spec_nest(self._time_step_spec, outer_dims=(3,)))

  saver = policy_saver.PolicySaver(policy, batch_size=None)
  path = os.path.join(self.get_temp_dir(), 'save_model')

  self.evaluate(tf.compat.v1.global_variables_initializer())
  original_eval = self.evaluate(policy.action(sample_input))
  saver.save(path)

  # Assign -1 to all variables in the policy, making the checkpoint differ
  # from the initial saved_model.
  self.evaluate(
      tf.nest.map_structure(lambda v: v.assign(v * 0 + -1),
                            policy.variables()))
  checkpoint_path = os.path.join(self.get_temp_dir(), 'checkpoint')
  saver.save_checkpoint(checkpoint_path)

  # Get an instance of the saved_model.
  reloaded_policy = tf.compat.v2.saved_model.load(path)
  self.evaluate(
      tf.compat.v1.initializers.variables(reloaded_policy.model_variables))

  # Verify the loaded saved_model variables differ from the current policy's.
  model_variables = self.evaluate(policy.variables())
  reloaded_model_variables = self.evaluate(reloaded_policy.model_variables)

  assert_np_not_equal = lambda a, b: self.assertFalse(np.equal(a, b).any())
  tf.nest.map_structure(assert_np_not_equal, model_variables,
                        reloaded_model_variables)

  # Update from checkpoint.
  checkpoint = tf.train.Checkpoint(policy=reloaded_policy)
  checkpoint_file_prefix = os.path.join(checkpoint_path, 'variables',
                                        'variables')
  checkpoint.read(checkpoint_file_prefix).assert_existing_objects_matched()
  self.evaluate(
      tf.compat.v1.initializers.variables(reloaded_policy.model_variables))

  # Verify variables are now equal.
  model_variables = self.evaluate(policy.variables())
  reloaded_model_variables = self.evaluate(reloaded_policy.model_variables)

  assert_np_all_equal = lambda a, b: self.assertTrue(np.equal(a, b).all())
  tf.nest.map_structure(assert_np_all_equal, model_variables,
                        reloaded_model_variables)

  # Verify the variable update affects inference.
  reloaded_eval = self.evaluate(reloaded_policy.action(sample_input))
  tf.nest.map_structure(assert_np_not_equal, original_eval, reloaded_eval)
  current_eval = self.evaluate(policy.action(sample_input))
  tf.nest.map_structure(assert_np_not_equal, current_eval, reloaded_eval)
def _get_feature_encoder(shape, dtype, compress_image=False,
                         image_quality=95):
  """Get feature encoder function for shape and dtype.

  Args:
    shape: An array shape.
    dtype: An array dtype.
    compress_image: Whether to compress image. It is assumed that any uint8
      tensor of rank 3 with shape (w,h,3) is an image.
    image_quality: An optional int. Defaults to 95. Quality of the
      compression from 0 to 100 (higher is better and slower).

  Returns:
    A tf.train.Feature encoder.
  """
  shape = _validate_shape(shape)
  dtype = _validate_dtype(dtype)

  if (compress_image and len(shape) == 3 and shape[2] == 3 and
      dtype == tf.uint8):
    if not common.has_eager_been_enabled():
      raise ValueError('Only supported in TF2.x.')

    def _encode_to_jpeg_bytes_list(value):
      value = tf.io.encode_jpeg(value, quality=image_quality)
      return tf.train.Feature(
          bytes_list=tf.train.BytesList(value=[value.numpy()]))

    return _encode_to_jpeg_bytes_list

  if dtype == tf.float32:
    # Serialize float32 to FloatList.
    def _encode_to_float_list(value):
      value = np.asarray(value)
      _check_shape_and_dtype(value, shape, dtype)
      return tf.train.Feature(
          float_list=tf.train.FloatList(
              value=value.flatten(order='C').tolist()))

    return _encode_to_float_list
  elif dtype == tf.int64:
    # Serialize int64 to Int64List.
    def _encode_to_int64_list(value):
      value = np.asarray(value)
      _check_shape_and_dtype(value, shape, dtype)
      return tf.train.Feature(
          int64_list=tf.train.Int64List(
              value=value.flatten(order='C').tolist()))

    return _encode_to_int64_list
  else:
    # Serialize anything else to BytesList in little-endian order.
    le_dtype = dtype.as_numpy_dtype(0).newbyteorder('L')

    def _encode_to_bytes_list(value):
      value = np.asarray(value)
      _check_shape_and_dtype(value, shape, dtype)
      bytes_list_value = np.require(
          value, dtype=le_dtype, requirements='C').tostring()
      return tf.train.Feature(
          bytes_list=tf.train.BytesList(value=[bytes_list_value]))

    return _encode_to_bytes_list
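# --- Illustration of the fallback BytesList branch above (added; assumes the
# module's private helper is in scope and that _validate_dtype accepts numpy
# dtypes). int32 is neither float32 nor int64, so its values are serialized
# as raw little-endian bytes.
import numpy as np

encoder = _get_feature_encoder(shape=(2,), dtype=np.int32)
feature = encoder(np.array([1, -2], dtype=np.int32))
# feature.bytes_list.value[0] == b'\x01\x00\x00\x00\xfe\xff\xff\xff'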
def testUpdateWithCheckpoint(self):
  if not common.has_eager_been_enabled():
    self.skipTest('Only supported in TF2.x.')

  # Create a saved_model for a q_policy.
  network = q_network.QNetwork(
      input_tensor_spec=self._time_step_spec.observation,
      action_spec=self._action_spec)
  policy = q_policy.QPolicy(
      time_step_spec=self._time_step_spec,
      action_spec=self._action_spec,
      q_network=network)

  saver = policy_saver.PolicySaver(policy, batch_size=None)
  path = os.path.join(self.get_temp_dir(), 'save_model')

  self.evaluate(tf.compat.v1.global_variables_initializer())
  saver.save(path)

  # Assign -1 to all variables in the policy, making the checkpoint differ
  # from the initial saved_model.
  self.evaluate(
      tf.nest.map_structure(lambda v: v.assign(v * 0 + -1),
                            policy.variables()))
  checkpoint_path = os.path.join(self.get_temp_dir(), 'checkpoint')
  saver.save_checkpoint(checkpoint_path)

  # Get an instance of the saved_model.
  reloaded_policy = tf.compat.v2.saved_model.load(path)
  self.evaluate(
      tf.compat.v1.initializers.variables(reloaded_policy.model_variables))

  # Verify the loaded saved_model variables differ from the current policy's.
  model_variables = self.evaluate(policy.variables())
  reloaded_model_variables = self.evaluate(reloaded_policy.model_variables)

  assert_np_not_equal = lambda a, b: self.assertFalse(np.equal(a, b).any())
  tf.nest.map_structure(assert_np_not_equal, model_variables,
                        reloaded_model_variables)

  # Update from checkpoint.
  checkpoint = tf.train.Checkpoint(policy=reloaded_policy)
  manager = tf.train.CheckpointManager(
      checkpoint, directory=checkpoint_path, max_to_keep=None)
  checkpoint.restore(manager.latest_checkpoint).expect_partial()
  self.evaluate(
      tf.compat.v1.initializers.variables(reloaded_policy.model_variables))

  # Verify variables are now equal.
  model_variables = self.evaluate(policy.variables())
  reloaded_model_variables = self.evaluate(reloaded_policy.model_variables)

  assert_np_all_equal = lambda a, b: self.assertTrue(np.equal(a, b).all())
  tf.nest.map_structure(assert_np_all_equal, model_variables,
                        reloaded_model_variables)