# Imports assumed by the test snippets below (standard absl and TF-Agents
# module paths).
from absl.testing import parameterized
import tensorflow as tf

from tf_agents.agents.ppo import ppo_agent
from tf_agents.drivers import dynamic_episode_driver
from tf_agents.environments import random_tf_environment
from tf_agents.networks import actor_distribution_rnn_network
from tf_agents.networks import value_rnn_network
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts


def setUp(self):
  super().setUp()
  # Specs with a (2,)-shaped reward, wrapped in an environment that emits
  # random spec-conforming values.
  self.observation_spec = tensor_spec.TensorSpec((2, 3), tf.float32)
  self.reward_spec = tensor_spec.TensorSpec((2,), tf.float32)
  self.time_step_spec = ts.time_step_spec(
      self.observation_spec, reward_spec=self.reward_spec)
  self.action_spec = tensor_spec.TensorSpec((2,), tf.float32)
  self.random_env = random_tf_environment.RandomTFEnvironment(
      self.time_step_spec, self.action_spec)
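# A minimal companion check (hypothetical; the method name below is not from
# the original suite): ts.time_step_spec threads the custom reward_spec
# through, and the random environment emits batched tensors matching the
# specs. RandomTFEnvironment defaults to batch_size=1, so observations come
# back with a leading batch dimension of 1.
def testSpecsRoundTrip(self):
  self.assertEqual(self.reward_spec, self.time_step_spec.reward)
  time_step = self.random_env.reset()
  self.assertEqual((1, 2, 3), tuple(time_step.observation.shape))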
def _build_test_env(obs_spec=None, action_spec=None, batch_size=2):
  if obs_spec is None:
    obs_spec = tensor_spec.BoundedTensorSpec((2, 3), tf.int32, -10, 10)
  if action_spec is None:
    action_spec = tensor_spec.BoundedTensorSpec((1,), tf.int32, 0, 4)
  time_step_spec = ts.time_step_spec(obs_spec)
  return random_tf_environment.RandomTFEnvironment(
      time_step_spec, action_spec, batch_size=batch_size)
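# Illustrative usage of the helper above (a sketch; the function name below
# is hypothetical and not from the original file). With the default specs,
# reset() returns observations of shape (batch_size,) + obs_spec.shape.
def _example_build_test_env_usage():
  env = _build_test_env()
  time_step = env.reset()
  assert tuple(time_step.observation.shape) == (2, 2, 3)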
# Assumed decorator: compute_value_and_advantage_in_train has to be injected
# by absl's parameterized machinery; the two named cases cover both settings.
@parameterized.named_parameters([
    ('ValueCalculationInTrain', True),
    ('ValueCalculationInCollect', False),
])
def testRNNTrain(self, compute_value_and_advantage_in_train):
  actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
      self._time_step_spec.observation,
      self._action_spec,
      input_fc_layer_params=None,
      output_fc_layer_params=None,
      lstm_size=(20,))
  value_net = value_rnn_network.ValueRnnNetwork(
      self._time_step_spec.observation,
      input_fc_layer_params=None,
      output_fc_layer_params=None,
      lstm_size=(10,))
  global_step = tf.compat.v1.train.get_or_create_global_step()
  agent = ppo_agent.PPOAgent(
      self._time_step_spec,
      self._action_spec,
      optimizer=tf.compat.v1.train.AdamOptimizer(),
      actor_net=actor_net,
      value_net=value_net,
      num_epochs=1,
      train_step_counter=global_step,
      compute_value_and_advantage_in_train=(
          compute_value_and_advantage_in_train))

  # Use a random env, policy, and replay buffer to collect training data.
  random_env = random_tf_environment.RandomTFEnvironment(
      self._time_step_spec, self._action_spec, batch_size=1)
  collection_policy = random_tf_policy.RandomTFPolicy(
      self._time_step_spec,
      self._action_spec,
      info_spec=agent.collect_policy.info_spec)
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      collection_policy.trajectory_spec, batch_size=1, max_length=7)
  collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
      random_env,
      collection_policy,
      observers=[replay_buffer.add_batch],
      num_episodes=1)

  # In graph mode: finish building the graph so the optimizer variables
  # are created.
  if not tf.executing_eagerly():
    _, _ = agent.train(experience=replay_buffer.gather_all())

  # Initialize.
  self.evaluate(agent.initialize())
  self.evaluate(tf.compat.v1.global_variables_initializer())

  # Train one step.
  self.assertEqual(0, self.evaluate(global_step))
  self.evaluate(collect_driver.run())
  self.evaluate(agent.train(experience=replay_buffer.gather_all()))
  self.assertEqual(1, self.evaluate(global_step))