def testSaveGetInitialState(self):
  if not tf.executing_eagerly():
    self.skipTest('b/129079730: PolicySaver does not work in TF1.x yet')
  network = q_rnn_network.QRnnNetwork(
      input_tensor_spec=self._time_step_spec.observation,
      action_spec=self._action_spec)
  policy = q_policy.QPolicy(
      time_step_spec=self._time_step_spec,
      action_spec=self._action_spec,
      q_network=network)
  saver_nobatch = policy_saver.PolicySaver(policy, batch_size=None)
  path = os.path.join(self.get_temp_dir(),
                      'save_model_initial_state_nobatch')
  saver_nobatch.save(path)
  reloaded_nobatch = tf.compat.v2.saved_model.load(path)
  self.assertIn('get_initial_state', reloaded_nobatch.signatures)
  reloaded_get_initial_state = (
      reloaded_nobatch.signatures['get_initial_state'])
  self._compare_input_output_specs(
      reloaded_get_initial_state,
      expected_input_specs=(
          tf.TensorSpec(dtype=tf.int32, shape=(), name='batch_size'),),
      expected_output_spec=policy.policy_state_spec,
      batch_input=False,
      batch_size=None)
  initial_state = policy.get_initial_state(batch_size=3)
  initial_state = self.evaluate(initial_state)
  reloaded_nobatch_initial_state = reloaded_nobatch.get_initial_state(
      batch_size=3)
  reloaded_nobatch_initial_state = self.evaluate(
      reloaded_nobatch_initial_state)
  tf.nest.map_structure(self.assertAllClose, initial_state,
                        reloaded_nobatch_initial_state)
  saver_batch = policy_saver.PolicySaver(policy, batch_size=3)
  path = os.path.join(self.get_temp_dir(), 'save_model_initial_state_batch')
  saver_batch.save(path)
  reloaded_batch = tf.compat.v2.saved_model.load(path)
  self.assertIn('get_initial_state', reloaded_batch.signatures)
  reloaded_get_initial_state = reloaded_batch.signatures['get_initial_state']
  self._compare_input_output_specs(
      reloaded_get_initial_state,
      expected_input_specs=(),
      expected_output_spec=policy.policy_state_spec,
      batch_input=False,
      batch_size=3)
  reloaded_batch_initial_state = reloaded_batch.get_initial_state()
  reloaded_batch_initial_state = self.evaluate(reloaded_batch_initial_state)
  tf.nest.map_structure(self.assertAllClose, initial_state,
                        reloaded_batch_initial_state)
def testVariablesAccessible(self):
  network = q_network.QNetwork(
      input_tensor_spec=self._time_step_spec.observation,
      action_spec=self._action_spec)
  policy = q_policy.QPolicy(
      time_step_spec=self._time_step_spec,
      action_spec=self._action_spec,
      q_network=network)
  saver = policy_saver.PolicySaver(policy, batch_size=None)
  path = os.path.join(self.get_temp_dir(), 'save_model')
  self.evaluate(tf.compat.v1.global_variables_initializer())
  saver.save(path)
  reloaded = tf.compat.v2.saved_model.load(path)
  self.evaluate(
      tf.compat.v1.initializers.variables(reloaded.model_variables))
  model_variables = self.evaluate(policy.variables())
  reloaded_model_variables = self.evaluate(reloaded.model_variables)
  assert_np_all_equal = lambda a, b: self.assertTrue(np.equal(a, b).all())
  tf.nest.map_structure(assert_np_all_equal, model_variables,
                        reloaded_model_variables)
def testTrainStepSaved(self):
  if not tf.executing_eagerly():
    self.skipTest('b/129079730: PolicySaver does not work in TF1.x yet')
  network = q_network.QNetwork(
      input_tensor_spec=self._time_step_spec.observation,
      action_spec=self._action_spec)
  policy = q_policy.QPolicy(
      time_step_spec=self._time_step_spec,
      action_spec=self._action_spec,
      q_network=network)
  train_step = common.create_variable('train_step', initial_value=7)
  saver = policy_saver.PolicySaver(
      policy, batch_size=None, train_step=train_step)
  path = os.path.join(self.get_temp_dir(), 'save_model')
  saver.save(path)
  reloaded = tf.compat.v2.saved_model.load(path)
  self.assertIn('get_train_step', reloaded.signatures)
  train_step_value = self.evaluate(reloaded.train_step())
  self.assertEqual(7, train_step_value)
  train_step = train_step.assign_add(3)
  saver.save(path)
  reloaded = tf.compat.v2.saved_model.load(path)
  train_step_value = self.evaluate(reloaded.train_step())
  self.assertEqual(10, train_step_value)
def testUpdateWithCompositeSavedModelAndCheckpoint(self):
  # Create a saved_model for a q_policy.
  network = q_network.QNetwork(
      input_tensor_spec=self._time_step_spec.observation,
      action_spec=self._action_spec)
  policy = q_policy.QPolicy(
      time_step_spec=self._time_step_spec,
      action_spec=self._action_spec,
      q_network=network)
  train_step = common.create_variable('train_step', initial_value=0)
  saver = policy_saver.PolicySaver(
      policy, train_step=train_step, batch_size=None)
  full_model_path = os.path.join(self.get_temp_dir(), 'save_model')

  def assert_val_equal_var(val, var):
    self.assertTrue(np.array_equal(np.full_like(var, val), var))

  self.evaluate(tf.compat.v1.global_variables_initializer())
  # Set all variables in the saved model to 1.
  variables = policy.variables()
  self.evaluate(
      tf.nest.map_structure(lambda v: v.assign(v * 0 + 1), variables))
  for v in self.evaluate(variables):
    assert_val_equal_var(1, v)
  with self.cached_session():
    saver.save(full_model_path)

  # Assign 2 to all variables in the policy, making the checkpoint
  # different from the initial saved_model.
  self.evaluate(
      tf.nest.map_structure(lambda v: v.assign(v * 0 + 2), variables))
  for v in self.evaluate(variables):
    assert_val_equal_var(2, v)
  checkpoint_path = os.path.join(self.get_temp_dir(), 'checkpoint')
  with self.cached_session():
    saver.save_checkpoint(checkpoint_path)

  # Reload the full model and check that all variables are 1.
  reloaded_policy = tf.compat.v2.saved_model.load(full_model_path)
  self.evaluate(
      tf.compat.v1.initializers.variables(reloaded_policy.model_variables))
  for v in self.evaluate(reloaded_policy.model_variables):
    assert_val_equal_var(1, v)

  # Compose a new full saved model from the original saved model files
  # and variables from the checkpoint.
  composite_path = os.path.join(self.get_temp_dir(), 'composite_model')
  self.copy_tree(full_model_path, composite_path, skip_variables=True)
  self.copy_tree(checkpoint_path, composite_path)

  # Reload the composite model and check that all variables are 2.
  reloaded_policy = tf.compat.v2.saved_model.load(composite_path)
  self.evaluate(
      tf.compat.v1.initializers.variables(reloaded_policy.model_variables))
  for v in self.evaluate(reloaded_policy.model_variables):
    assert_val_equal_var(2, v)
def __init__(self,
             batch_size,
             action_spec,
             time_step_spec,
             n_iterations,
             replay_buffer_max_length,
             learning_rate=1e-3,
             checkpoint_dir=None):
  self.batch_size = batch_size
  self.time_step_spec = time_step_spec
  self.action_spec = action_spec
  observation_spec = self.time_step_spec.observation
  self.actor_net = HierachyActorNetwork(
      observation_spec, action_spec, n_iterations, n_options=4)
  value_net = value_network.ValueNetwork(
      observation_spec, fc_layer_params=(100,))
  optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
  self.global_step = tf.compat.v1.train.get_or_create_global_step()
  self.agent = ppo_agent.PPOAgent(
      time_step_spec,
      self.action_spec,
      actor_net=self.actor_net,
      value_net=value_net,
      optimizer=optimizer,
      normalize_rewards=True,
      normalize_observations=False,
      train_step_counter=self.global_step)
  self.agent.initialize()
  self.agent.train = common.function(self.agent.train)
  self.replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec=self.agent.collect_data_spec,
      batch_size=self.batch_size,
      max_length=replay_buffer_max_length)
  self.train_checkpointer = None
  if checkpoint_dir:
    self.train_checkpointer = common.Checkpointer(
        ckpt_dir=checkpoint_dir,
        max_to_keep=1,
        agent=self.agent,
        policy=self.agent.policy,
        replay_buffer=self.replay_buffer,
        global_step=self.global_step)
    self.train_checkpointer.initialize_or_restore()
  self.policy_saver = policy_saver.PolicySaver(self.agent.policy)
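# A hypothetical training-step companion for the constructor above; a minimal
# sketch of how the pieces it builds fit together for on-policy PPO: gather
# the collected episodes from the replay buffer, run one train call, clear
# the buffer, and periodically persist both the training state and a
# deployable policy. The method name and `save_dir` argument are assumptions
# for illustration, not part of the original class.
def train_and_save(self, save_dir):
  # PPO trains on the full on-policy batch, then discards it.
  experience = self.replay_buffer.gather_all()
  loss_info = self.agent.train(experience)
  self.replay_buffer.clear()
  if self.train_checkpointer:
    # Checkpoint carries agent, replay buffer, and step counter for resuming.
    self.train_checkpointer.save(self.global_step)
  # SavedModel export of the policy alone, for serving/evaluation.
  self.policy_saver.save(os.path.join(save_dir, 'policy'))
  return loss_info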
def testPolicySaverCompatibility(self):
  observation_spec = {
      'a': tf.TensorSpec(4, tf.float32),
      'b': tf.TensorSpec(3, tf.float32)
  }
  time_step_tensor_spec = ts.time_step_spec(observation_spec)
  net = nest_map.NestMap({
      'a': tf.keras.layers.LSTM(8, return_state=True, return_sequences=True),
      'b': tf.keras.layers.Dense(8)
  })
  net.create_variables(observation_spec)
  policy = MyPolicy(time_step_tensor_spec, net)
  sample = tensor_spec.sample_spec_nest(
      time_step_tensor_spec, outer_dims=(5,))
  step = policy.action(sample)
  self.assertEqual(step.action.shape.as_list(), [5, 8])
  train_step = common.create_variable('train_step')
  saver = policy_saver.PolicySaver(policy, train_step=train_step)
  self.initialize_v1_variables()
  with self.cached_session():
    saver.save(os.path.join(FLAGS.test_tmpdir, 'nest_map_model'))
def save_model():
  optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3)
  obs_spec = TensorSpec((7,), dtype=tf.float32, name='observation')
  action_spec = BoundedTensorSpec(
      (1,), dtype=tf.int32, minimum=0, maximum=3, name='action')
  actor_net = ActorDistributionRnnNetwork(
      obs_spec, action_spec, lstm_size=(100, 100))
  value_net = ValueRnnNetwork(obs_spec)
  agent = ppo_agent.PPOAgent(
      time_step_spec=time_step_spec(obs_spec),
      action_spec=action_spec,
      optimizer=optimizer,
      actor_net=actor_net,
      value_net=value_net,
      normalize_observations=True,
      normalize_rewards=True,
      use_gae=True,
      num_epochs=1)
  checkpointer = Checkpointer(
      ckpt_dir='checkpoints/policy',
      max_to_keep=1,
      agent=agent,
      policy=agent.policy,
      global_step=tf.compat.v1.train.get_or_create_global_step())
  checkpointer.initialize_or_restore()
  saver = policy_saver.PolicySaver(agent.policy)
  saver.save('final_policy')
def testMetadataSaved(self):
  # We need to use one default session so that self.evaluate and the
  # SavedModel loader share the same session.
  with tf.compat.v1.Session().as_default():
    network = q_network.QNetwork(
        input_tensor_spec=self._time_step_spec.observation,
        action_spec=self._action_spec)
    policy = q_policy.QPolicy(
        time_step_spec=self._time_step_spec,
        action_spec=self._action_spec,
        q_network=network)
    self.evaluate(tf.compat.v1.initializers.variables(policy.variables()))
    train_step = common.create_variable('train_step', initial_value=1)
    env_step = common.create_variable('env_step', initial_value=7)
    metadata = {'env_step': env_step}
    self.evaluate(
        tf.compat.v1.initializers.variables([train_step, env_step]))
    saver = policy_saver.PolicySaver(
        policy, batch_size=None, train_step=train_step, metadata=metadata)
    if tf.executing_eagerly():
      loaded_metadata = saver.get_metadata()
    else:
      loaded_metadata = self.evaluate(saver.get_metadata())
    self.assertEqual(self.evaluate(metadata), loaded_metadata)
    path = os.path.join(self.get_temp_dir(), 'save_model')
    saver.save(path)
    reloaded = tf.compat.v2.saved_model.load(path)
    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.assertIn('get_metadata', reloaded.signatures)
    env_step_value = self.evaluate(reloaded.get_metadata())['env_step']
    self.assertEqual(7, env_step_value)
def testGetTrainStep(self, train_step):
  path = os.path.join(self.get_temp_dir(), 'saved_policy')
  if train_step is None:
    # Use the default argument, which should set the train step to -1.
    saver = policy_saver.PolicySaver(self.tf_policy)
    expected_train_step = -1
  else:
    saver = policy_saver.PolicySaver(
        self.tf_policy, train_step=tf.constant(train_step))
    expected_train_step = train_step
  saver.save(path)
  eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
      path, self.time_step_spec, self.action_spec)
  self.assertEqual(expected_train_step, eager_py_policy.get_train_step())
def testTrainStepSaved(self):
  # We need to use one default session so that self.evaluate and the
  # SavedModel loader share the same session.
  with tf.compat.v1.Session().as_default():
    network = q_network.QNetwork(
        input_tensor_spec=self._time_step_spec.observation,
        action_spec=self._action_spec)
    policy = q_policy.QPolicy(
        time_step_spec=self._time_step_spec,
        action_spec=self._action_spec,
        q_network=network)
    train_step = common.create_variable('train_step', initial_value=7)
    self.evaluate(tf.compat.v1.initializers.variables([train_step]))
    saver = policy_saver.PolicySaver(
        policy, batch_size=None, train_step=train_step)
    path = os.path.join(self.get_temp_dir(), 'save_model')
    saver.save(path)
    reloaded = tf.compat.v2.saved_model.load(path)
    self.assertIn('get_train_step', reloaded.signatures)
    self.evaluate(tf.compat.v1.global_variables_initializer())
    train_step_value = self.evaluate(reloaded.train_step())
    self.assertEqual(7, train_step_value)
    train_step = train_step.assign_add(3)
    self.evaluate(train_step)
    saver.save(path)
    reloaded = tf.compat.v2.saved_model.load(path)
    self.evaluate(tf.compat.v1.global_variables_initializer())
    train_step_value = self.evaluate(reloaded.train_step())
    self.assertEqual(10, train_step_value)
def testRegisterConcreteFunction(self):
  if not common.has_eager_been_enabled():
    self.skipTest('Only supported in TF2.x. Step is required in TF1.x')
  network = q_network.QNetwork(
      input_tensor_spec=self._time_step_spec.observation,
      action_spec=self._action_spec)
  policy = q_policy.QPolicy(
      time_step_spec=self._time_step_spec,
      action_spec=self._action_spec,
      q_network=network)
  saver = policy_saver.PolicySaver(policy, batch_size=None)
  tf_var = tf.Variable(3)

  def add(b):
    return tf_var + b

  add_fn = common.function(add)
  # Called for side effect.
  add_fn.get_concrete_function(tf.TensorSpec((), dtype=tf.int32))
  saver.register_concrete_function(name='add', fn=add_fn)
  path = os.path.join(self.get_temp_dir(), 'save_model')
  saver.save(path)
  reloaded = tf.compat.v2.saved_model.load(path)
  self.assertAllClose(7, reloaded.add(4))
def save_model(self):
  # Create a directory to save the checkpoint.
  checkpoint_dir = os.path.join(self._savedir, 'checkpoint')
  train_checkpointer = common.Checkpointer(
      ckpt_dir=checkpoint_dir,
      max_to_keep=1,
      agent=self._agent,
      policy=self._agent.policy,
      replay_buffer=self._replay_buffer,
      global_step=self._train_step)
  # Save the checkpoint.
  train_checkpointer.save(self._train_step)
  # Create a directory to save the policy.
  policy_dir = os.path.join(self._savedir, 'policy')
  tf_policy_saver = policy_saver.PolicySaver(self._agent.policy)
  # Save the policy.
  tf_policy_saver.save(policy_dir)
  print("model saved.")
  if self._visual_flag:
    # Save the animation of each evaluation episode.
    for epi in range(self._num_episodes):
      f = self._vizdir + str(epi) + "/eval-animation.mp4"
      writervideo = FFMpegWriter(fps=33)
      self._ani[epi].save(f, writer=writervideo)
    plt.show()
  # Try loading the saved policy.
  loaded_policy = tf.saved_model.load(policy_dir)
  eval_timestep = self._eval_env.reset()
  loaded_action = loaded_policy.action(eval_timestep)
  print("example policy: ", loaded_action)
def testUpdateFromCheckpoint(self):
  path = os.path.join(self.get_temp_dir(), 'saved_policy')
  saver = policy_saver.PolicySaver(self.tf_policy)
  saver.save(path)
  self.evaluate(
      tf.nest.map_structure(lambda v: v.assign(v * 0 + -1),
                            self.tf_policy.variables()))
  checkpoint_path = os.path.join(self.get_temp_dir(), 'checkpoint')
  saver.save_checkpoint(checkpoint_path)
  eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
      path, self.time_step_spec, self.action_spec)
  # Use evaluate to force a copy.
  saved_model_variables = self.evaluate(eager_py_policy.variables())
  checkpoint = tf.train.Checkpoint(policy=eager_py_policy._policy)
  manager = tf.train.CheckpointManager(
      checkpoint, directory=checkpoint_path, max_to_keep=None)
  eager_py_policy.update_from_checkpoint(manager.latest_checkpoint)
  assert_np_not_equal = lambda a, b: self.assertFalse(np.equal(a, b).all())
  tf.nest.map_structure(assert_np_not_equal, saved_model_variables,
                        self.evaluate(eager_py_policy.variables()))
  assert_np_all_equal = lambda a, b: self.assertTrue(np.equal(a, b).all())
  tf.nest.map_structure(assert_np_all_equal,
                        self.evaluate(self.tf_policy.variables()),
                        self.evaluate(eager_py_policy.variables()))
def testUniqueSignatures(self):
  if not tf.executing_eagerly():
    self.skipTest('b/129079730: PolicySaver does not work in TF1.x yet')
  network = q_network.QNetwork(
      input_tensor_spec=self._time_step_spec.observation,
      action_spec=self._action_spec)
  policy = q_policy.QPolicy(
      time_step_spec=self._time_step_spec,
      action_spec=self._action_spec,
      q_network=network)
  saver = policy_saver.PolicySaver(policy, batch_size=None)
  action_signature_names = [
      s.name for s in saver._signatures['action'].input_signature
  ]
  self.assertAllEqual(
      ['0/step_type', '0/reward', '0/discount', '0/observation'],
      action_signature_names)
  initial_state_signature_names = [
      s.name for s in saver._signatures['get_initial_state'].input_signature
  ]
  self.assertAllEqual(['batch_size'], initial_state_signature_names)
def testRenamedSignatures(self):
  time_step_spec = self._time_step_spec._replace(
      observation=tensor_spec.BoundedTensorSpec(
          dtype=tf.float32, shape=(4,), minimum=-10.0, maximum=10.0))
  network = q_network.QNetwork(
      input_tensor_spec=time_step_spec.observation,
      action_spec=self._action_spec)
  policy = q_policy.QPolicy(
      time_step_spec=time_step_spec,
      action_spec=self._action_spec,
      q_network=network)
  train_step = common.create_variable('train_step', initial_value=7)
  saver = policy_saver.PolicySaver(
      policy, train_step=train_step, batch_size=None)
  action_signature_names = [
      s.name for s in saver._signatures['action'].input_signature
  ]
  self.assertAllEqual(
      ['0/step_type', '0/reward', '0/discount', '0/observation'],
      action_signature_names)
  initial_state_signature_names = [
      s.name for s in saver._signatures['get_initial_state'].input_signature
  ]
  self.assertAllEqual(['batch_size'], initial_state_signature_names)
def save_agent_policy():
  now = datetime.datetime.now()
  policy_dir = config.POLICIES_PATH + now.strftime("%m%d%Y-%H%M%S")
  os.mkdir(policy_dir)
  tf_policy_saver = policy_saver.PolicySaver(agent.policy)
  tf_policy_saver.save(policy_dir)
  print(">>>Policy saved in ", policy_dir)
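# A minimal companion sketch for loading the policy saved above. It assumes
# the same `config.POLICIES_PATH` layout plus a `timestamp` argument and an
# environment `eval_env` yielding batched TimeSteps; those names are
# hypothetical. A reloaded SavedModel policy exposes `action` (and
# `get_initial_state` for stateful policies) rather than the full TFPolicy
# interface.
def load_agent_policy(timestamp):
  policy_dir = config.POLICIES_PATH + timestamp
  loaded_policy = tf.compat.v2.saved_model.load(policy_dir)
  # Stateful (e.g. RNN) policies need an explicit initial state.
  policy_state = loaded_policy.get_initial_state(batch_size=1)
  time_step = eval_env.reset()
  action_step = loaded_policy.action(time_step, policy_state)
  return loaded_policy, action_step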
def testUpdateFromCheckpoint(self):
  if not common.has_eager_been_enabled():
    self.skipTest('Only supported in TF2.x.')
  path = os.path.join(self.get_temp_dir(), 'saved_policy')
  saver = policy_saver.PolicySaver(self.tf_policy)
  saver.save(path)
  self.evaluate(
      tf.nest.map_structure(lambda v: v.assign(v * 0 + -1),
                            self.tf_policy.variables()))
  checkpoint_path = os.path.join(self.get_temp_dir(), 'checkpoint')
  saver.save_checkpoint(checkpoint_path)
  eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
      path, self.time_step_spec, self.action_spec)
  # Use evaluate to force a copy.
  saved_model_variables = self.evaluate(eager_py_policy.variables())
  eager_py_policy.update_from_checkpoint(checkpoint_path)
  assert_np_not_equal = lambda a, b: self.assertFalse(np.equal(a, b).all())
  tf.nest.map_structure(assert_np_not_equal, saved_model_variables,
                        self.evaluate(eager_py_policy.variables()))
  assert_np_all_equal = lambda a, b: self.assertTrue(np.equal(a, b).all())
  tf.nest.map_structure(
      assert_np_all_equal,
      self.evaluate(self.tf_policy.variables()),
      self.evaluate(eager_py_policy.variables()),
      check_types=False)
def testRegisterFunction(self):
  if not common.has_eager_been_enabled():
    self.skipTest('Only supported in TF2.x. Step is required in TF1.x')
  network = q_network.QNetwork(
      input_tensor_spec=self._time_step_spec.observation,
      action_spec=self._action_spec)
  policy = q_policy.QPolicy(
      time_step_spec=self._time_step_spec,
      action_spec=self._action_spec,
      q_network=network)
  saver = policy_saver.PolicySaver(policy, batch_size=None)
  saver.register_function('q_network', network,
                          self._time_step_spec.observation)
  path = os.path.join(self.get_temp_dir(), 'save_model')
  saver.save(path)
  reloaded = tf.compat.v2.saved_model.load(path)
  sample_input = self.evaluate(
      tensor_spec.sample_spec_nest(
          self._time_step_spec.observation, outer_dims=(3,)))
  expected_output, _ = network(sample_input)
  reloaded_output, _ = reloaded.q_network(sample_input)
  self.assertAllClose(expected_output, reloaded_output)
def testCheckpointSave(self):
  network = q_network.QNetwork(
      input_tensor_spec=self._time_step_spec.observation,
      action_spec=self._action_spec)
  policy = q_policy.QPolicy(
      time_step_spec=self._time_step_spec,
      action_spec=self._action_spec,
      q_network=network)
  train_step = common.create_variable('train_step', initial_value=0)
  saver = policy_saver.PolicySaver(
      policy, train_step=train_step, batch_size=None)
  path = os.path.join(self.get_temp_dir(), 'save_model')
  self.evaluate(tf.compat.v1.global_variables_initializer())
  with self.cached_session():
    saver.save(path)
  checkpoint_path = os.path.join(self.get_temp_dir(), 'checkpoint')
  with self.cached_session():
    saver.save_checkpoint(checkpoint_path)
  self.assertTrue(tf.compat.v2.io.gfile.exists(checkpoint_path))
  # Also test CheckpointOptions.
  checkpoint2_path = os.path.join(self.get_temp_dir(), 'checkpoint2')
  options = tf.train.CheckpointOptions(
      experimental_io_device='/job:localhost')
  with self.cached_session():
    saver.save_checkpoint(checkpoint2_path, options=options)
  self.assertTrue(tf.compat.v2.io.gfile.exists(checkpoint2_path))
def testDistributionNotImplemented(self):
  policy = PolicyNoDistribution()
  with self.assertRaisesRegex(NotImplementedError,
                              '_distribution has not been implemented'):
    policy.distribution(
        ts.TimeStep(step_type=(), reward=(), discount=(), observation=()))
  train_step = common.create_variable('train_step', initial_value=0)
  saver = policy_saver.PolicySaver(
      policy, train_step=train_step, batch_size=None)
  path = os.path.join(self.get_temp_dir(), 'save_model')
  self.evaluate(tf.compat.v1.global_variables_initializer())
  with self.cached_session():
    saver.save(path)
  reloaded = tf.compat.v2.saved_model.load(path)
  with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
                              '_distribution has not been implemented'):
    self.evaluate(
        reloaded.distribution(
            ts.TimeStep(step_type=(), reward=(), discount=(),
                        observation=())))
def testSavedModel(self):
  path = os.path.join(self.get_temp_dir(), 'saved_policy')
  saver = policy_saver.PolicySaver(self.tf_policy)
  saver.save(path)
  eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
      path, self.time_step_spec, self.action_spec)
  rng = np.random.RandomState()
  sample_time_step = array_spec.sample_spec_nest(self.time_step_spec, rng)
  batched_sample_time_step = nest_utils.batch_nested_array(sample_time_step)
  original_action = self.tf_policy.action(batched_sample_time_step)
  unbatched_original_action = nest_utils.unbatch_nested_tensors(
      original_action)
  original_action_np = tf.nest.map_structure(lambda t: t.numpy(),
                                             unbatched_original_action)
  saved_policy_action = eager_py_policy.action(sample_time_step)
  tf.nest.assert_same_structure(saved_policy_action.action, self.action_spec)
  np.testing.assert_array_almost_equal(original_action_np.action,
                                       saved_policy_action.action)
def testRenamedSignatures(self):
  if not tf.executing_eagerly():
    self.skipTest('b/129079730: PolicySaver does not work in TF1.x yet')
  time_step_spec = self._time_step_spec._replace(
      observation=tensor_spec.BoundedTensorSpec(
          dtype=tf.float32, shape=(4,), minimum=-10.0, maximum=10.0))
  network = q_network.QNetwork(
      input_tensor_spec=time_step_spec.observation,
      action_spec=self._action_spec)
  policy = q_policy.QPolicy(
      time_step_spec=time_step_spec,
      action_spec=self._action_spec,
      q_network=network)
  saver = policy_saver.PolicySaver(policy, batch_size=None)
  action_signature_names = [
      s.name for s in saver._signatures['action'].input_signature
  ]
  self.assertAllEqual(
      ['0/step_type', '0/reward', '0/discount', '0/observation'],
      action_signature_names)
  initial_state_signature_names = [
      s.name for s in saver._signatures['get_initial_state'].input_signature
  ]
  self.assertAllEqual(['batch_size'], initial_state_signature_names)
def _build_saver(
    self, policy: tf_policy.TFPolicy
) -> Union[policy_saver.PolicySaver, async_policy_saver.AsyncPolicySaver]:
  saver = policy_saver.PolicySaver(
      policy, train_step=self._train_step, metadata=self._metadata)
  if self._async_saving:
    saver = async_policy_saver.AsyncPolicySaver(saver)
  return saver
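# A hypothetical caller for _build_saver above, sketching how the two saver
# variants would be driven. AsyncPolicySaver queues the save on a background
# thread, so flush() is needed before relying on the files being on disk;
# the synchronous PolicySaver writes immediately. `self._saver` and the
# `export_dir` argument are assumptions for illustration.
def _export_policy(self, export_dir):
  self._saver.save(export_dir)
  if self._async_saving:
    # Block until the background save thread has finished writing.
    self._saver.flush()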
def testRegisterFunction(self):
  if not common.has_eager_been_enabled():
    self.skipTest('Only supported in TF2.x. Step is required in TF1.x')
  time_step_spec = ts.TimeStep(
      step_type=tensor_spec.BoundedTensorSpec(
          dtype=tf.int32, shape=(), name='st', minimum=0, maximum=2),
      reward=tensor_spec.BoundedTensorSpec(
          dtype=tf.float32, shape=(), name='reward', minimum=0.0,
          maximum=5.0),
      discount=tensor_spec.BoundedTensorSpec(
          dtype=tf.float32, shape=(), name='discount', minimum=0.0,
          maximum=1.0),
      observation=tensor_spec.BoundedTensorSpec(
          dtype=tf.float32, shape=(4,), name='obs', minimum=-10.0,
          maximum=10.0))
  action_spec = tensor_spec.BoundedTensorSpec(
      dtype=tf.int32, shape=(), minimum=0, maximum=10, name='act_0')
  network = q_network.QNetwork(
      input_tensor_spec=time_step_spec.observation, action_spec=action_spec)
  policy = q_policy.QPolicy(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      q_network=network)
  saver = policy_saver.PolicySaver(policy, batch_size=None)
  async_saver = async_policy_saver.AsyncPolicySaver(saver)
  async_saver.register_function('q_network', network,
                                time_step_spec.observation)
  path = os.path.join(self.get_temp_dir(), 'save_model')
  async_saver.save(path)
  async_saver.flush()
  async_saver.close()
  self.assertFalse(async_saver._save_thread.is_alive())
  reloaded = tf.compat.v2.saved_model.load(path)
  sample_input = self.evaluate(
      tensor_spec.sample_spec_nest(
          time_step_spec.observation, outer_dims=(3,)))
  expected_output, _ = network(sample_input)
  reloaded_output, _ = reloaded.q_network(sample_input)
  self.assertAllClose(expected_output, reloaded_output)
def setUp(self):
  super(PolicyLoaderTest, self).setUp()
  self.root_dir = self.get_temp_dir()
  tf_observation_spec = tensor_spec.TensorSpec((), np.float32)
  tf_time_step_spec = ts.time_step_spec(tf_observation_spec)
  tf_action_spec = tensor_spec.BoundedTensorSpec((), np.float32, 0, 3)
  self.net = AddNet()
  self.policy = greedy_policy.GreedyPolicy(
      q_policy.QPolicy(tf_time_step_spec, tf_action_spec, self.net))
  self.train_step = common.create_variable('train_step', initial_value=0)
  self.saver = policy_saver.PolicySaver(
      self.policy, train_step=self.train_step)
def testMixturePolicyDynamicBatchSize(self):
  context_dim = 35
  observation_spec = tensor_spec.TensorSpec([context_dim], tf.float32)
  time_step_spec = ts.time_step_spec(observation_spec)
  action_spec = tensor_spec.BoundedTensorSpec(
      shape=(), dtype=tf.int32, minimum=0, maximum=9, name='action')
  sub_policies = [
      ConstantPolicy(action_spec, time_step_spec, i) for i in range(10)
  ]
  weights = [0, 0, 0.2, 0, 0, 0.3, 0, 0, 0.5, 0]
  dist = tfd.Categorical(probs=weights)
  policy = mixture_policy.MixturePolicy(dist, sub_policies)
  batch_size = tf.random.uniform(
      shape=(), minval=10, maxval=15, dtype=tf.int32)
  time_step = ts.TimeStep(
      tf.fill(
          tf.expand_dims(batch_size, axis=0),
          ts.StepType.FIRST,
          name='step_type'),
      tf.zeros(shape=[batch_size], dtype=tf.float32, name='reward'),
      tf.ones(shape=[batch_size], dtype=tf.float32, name='discount'),
      tf.reshape(
          tf.range(
              tf.cast(batch_size * context_dim, dtype=tf.float32),
              dtype=tf.float32),
          shape=[-1, context_dim],
          name='observation'))
  action_step = policy.action(time_step)
  actions, bsize = self.evaluate([action_step.action, batch_size])
  self.assertAllEqual(actions.shape, [bsize])
  self.assertAllInSet(actions, [2, 5, 8])
  saver = policy_saver.PolicySaver(policy)
  location = os.path.join(self.get_temp_dir(), 'saved_policy')
  saver.save(location)
  loaded_policy = tf.compat.v2.saved_model.load(location)
  new_batch_size = 3
  new_time_step = ts.TimeStep(
      tf.fill(
          tf.expand_dims(new_batch_size, axis=0),
          ts.StepType.FIRST,
          name='step_type'),
      tf.zeros(shape=[new_batch_size], dtype=tf.float32, name='reward'),
      tf.ones(shape=[new_batch_size], dtype=tf.float32, name='discount'),
      tf.reshape(
          tf.range(
              tf.cast(new_batch_size * context_dim, dtype=tf.float32),
              dtype=tf.float32),
          shape=[-1, context_dim],
          name='observation'))
  new_action = self.evaluate(loaded_policy.action(new_time_step).action)
  self.assertAllEqual(new_action.shape, [new_batch_size])
  self.assertAllInSet(new_action, [2, 5, 8])
def initiate(self):
  self._checkpoint_dir = os.path.join(self._save_dir, 'checkpoint')
  self._train_checkpointer = common.Checkpointer(
      ckpt_dir=self._checkpoint_dir,
      max_to_keep=1,
      agent=self._agent,
      policy=self._agent.policy,
      replay_buffer=self._replay_buffer,
      global_step=self._global_step)
  self._policy_dir = os.path.join(self._save_dir, 'policy')
  self._tf_policy_saver = policy_saver.PolicySaver(self._agent.policy)
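# A hypothetical save step pairing with initiate() above, mirroring the
# save_model() pattern earlier in this section: the checkpointer persists the
# full training state (agent, replay buffer, step counter) for resuming
# training, while PolicySaver exports a standalone SavedModel for deployment.
# The method name is a sketch, not part of the original class.
def save(self):
  self._train_checkpointer.save(self._global_step)
  self._tf_policy_saver.save(self._policy_dir)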
def testTrainStepSaved(self):
  # We need to use one default session so that self.evaluate and the
  # SavedModel loader share the same session.
  with tf.compat.v1.Session().as_default():
    network = q_network.QNetwork(
        input_tensor_spec=self._time_step_spec.observation,
        action_spec=self._action_spec)
    policy = q_policy.QPolicy(
        time_step_spec=self._time_step_spec,
        action_spec=self._action_spec,
        q_network=network)
    self.evaluate(tf.compat.v1.initializers.variables(policy.variables()))
    train_step = common.create_variable('train_step', initial_value=7)
    self.evaluate(tf.compat.v1.initializers.variables([train_step]))
    saver = policy_saver.PolicySaver(
        policy, batch_size=None, train_step=train_step)
    if tf.executing_eagerly():
      step = saver.get_train_step()
    else:
      step = self.evaluate(saver.get_train_step())
    self.assertEqual(7, step)
    path = os.path.join(self.get_temp_dir(), 'save_model')
    saver.save(path)
    reloaded = tf.compat.v2.saved_model.load(path)
    self.assertIn('get_train_step', reloaded.signatures)
    self.evaluate(tf.compat.v1.global_variables_initializer())
    train_step_value = self.evaluate(reloaded.get_train_step())
    self.assertEqual(7, train_step_value)
    train_step = train_step.assign_add(3)
    self.evaluate(train_step)
    saver.save(path)
    reloaded = tf.compat.v2.saved_model.load(path)
    self.evaluate(tf.compat.v1.global_variables_initializer())
    train_step_value = self.evaluate(reloaded.get_train_step())
    self.assertEqual(10, train_step_value)
    # Also test passing SaveOptions.
    train_step = train_step.assign_add(3)
    self.evaluate(train_step)
    path2 = os.path.join(self.get_temp_dir(), 'save_model2')
    saver.save(
        path2,
        options=tf.saved_model.SaveOptions(
            experimental_io_device='/job:localhost'))
    reloaded = tf.compat.v2.saved_model.load(path2)
    self.evaluate(tf.compat.v1.global_variables_initializer())
    train_step_value = self.evaluate(reloaded.get_train_step())
    self.assertEqual(13, train_step_value)
def policy_saver(self, reward=0, iteration=0):
  """Saves the policy in the directory returned by this class's
  _save_points_dir method."""
  save_dir = self._save_points_dir() + '_{}_{}'.format(
      str(reward), str(iteration))
  print(save_dir)
  tf_policy_saver = policy_saver.PolicySaver(self.agent.policy)
  tf_policy_saver.save(save_dir)
  print('\nThe policy of the deep Q network is saved in: "{}"'.format(
      save_dir))
  return save_dir
def testSaver(self):
  policy = categorical_q_policy.CategoricalQPolicy(
      self._time_step_spec, self._action_spec, self._q_network,
      self._min_q_value, self._max_q_value)
  saver = policy_saver.PolicySaver(policy)
  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.evaluate(tf.compat.v1.local_variables_initializer())
  save_path = os.path.join(flags.FLAGS.test_tmpdir,
                           'saved_categorical_q_policy')
  saver.save(save_path)