def GetAgent(self, env, params):
  def init_gnn(name):
    """Returns a new `GNNWrapper` instance with the given `name`.

    We need this function to be able to prefix the variable names with the
    names of the parent actor or critic network, by passing in this function
    and initializing the instance in the parent network.
    """
    return GNNWrapper(
        params=self._gnn_sac_params["GNN"],
        graph_dims=self._observer.graph_dimensions,
        name=name)

  # actor network
  actor_net = GNNActorNetwork(
      input_tensor_spec=env.observation_spec(),
      output_tensor_spec=env.action_spec(),
      gnn=init_gnn,
      fc_layer_params=self._gnn_sac_params[
          "ActorFcLayerParams", "", [128, 64]])

  # critic network
  critic_net = GNNCriticNetwork(
      (env.observation_spec(), env.action_spec()),
      gnn=init_gnn,
      observation_fc_layer_params=self._gnn_sac_params[
          "CriticObservationFcLayerParams", "", [128]],
      action_fc_layer_params=self._gnn_sac_params[
          "CriticActionFcLayerParams", "", None],
      joint_fc_layer_params=self._gnn_sac_params[
          "CriticJointFcLayerParams", "", [128, 128]])

  # agent
  tf_agent = sac_agent.SacAgent(
      env.time_step_spec(),
      env.action_spec(),
      actor_network=actor_net,
      critic_network=critic_net,
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=self._gnn_sac_params["ActorLearningRate", "", 3e-4]),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=self._gnn_sac_params["CriticLearningRate", "", 3e-4]),
      alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=self._gnn_sac_params["AlphaLearningRate", "", 3e-4]),
      target_update_tau=self._gnn_sac_params["TargetUpdateTau", "", 0.05],
      target_update_period=self._gnn_sac_params["TargetUpdatePeriod", "", 3],
      td_errors_loss_fn=tf.math.squared_difference,
      gamma=self._gnn_sac_params["Gamma", "", 0.995],
      reward_scale_factor=self._gnn_sac_params["RewardScaleFactor", "", 1.],
      train_step_counter=self._ckpt.step,
      name=self._gnn_sac_params["AgentName", "", "gnn_sac_agent"],
      debug_summaries=self._gnn_sac_params["DebugSummaries", "", False])
  tf_agent.initialize()
  return tf_agent
def testAgentTransitionTrain(self):
  actor_net = actor_distribution_network.ActorDistributionNetwork(
      self._obs_spec,
      self._action_spec,
      fc_layer_params=(10,),
      continuous_projection_net=tanh_normal_projection_network.
      TanhNormalProjectionNetwork)
  agent = sac_agent.SacAgent(
      self._time_step_spec,
      self._action_spec,
      critic_network=DummyCriticNet(),
      actor_network=actor_net,
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(0.001),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(0.001),
      alpha_optimizer=tf.compat.v1.train.AdamOptimizer(0.001))
  time_step_spec = self._time_step_spec._replace(
      reward=tensor_spec.BoundedTensorSpec(
          [], tf.float32, minimum=0.0, maximum=1.0, name='reward'))
  transition_spec = trajectory.Transition(
      time_step=time_step_spec,
      action_step=policy_step.PolicyStep(
          action=self._action_spec, state=(), info=()),
      next_time_step=time_step_spec)
  sample_trajectory_experience = tensor_spec.sample_spec_nest(
      transition_spec, outer_dims=(3,))
  agent.train(sample_trajectory_experience)
def testAgentTrajectoryTrain(self):
  actor_net = actor_distribution_network.ActorDistributionNetwork(
      self._obs_spec,
      self._action_spec,
      fc_layer_params=(10,),
      continuous_projection_net=tanh_normal_projection_network.
      TanhNormalProjectionNetwork)
  agent = sac_agent.SacAgent(
      self._time_step_spec,
      self._action_spec,
      critic_network=DummyCriticNet(),
      actor_network=actor_net,
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(0.001),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(0.001),
      alpha_optimizer=tf.compat.v1.train.AdamOptimizer(0.001))
  trajectory_spec = trajectory.Trajectory(
      step_type=self._time_step_spec.step_type,
      observation=self._time_step_spec.observation,
      action=self._action_spec,
      policy_info=(),
      next_step_type=self._time_step_spec.step_type,
      reward=tensor_spec.BoundedTensorSpec(
          [], tf.float32, minimum=0.0, maximum=1.0, name='reward'),
      discount=self._time_step_spec.discount)
  sample_trajectory_experience = tensor_spec.sample_spec_nest(
      trajectory_spec, outer_dims=(3, 2))
  agent.train(sample_trajectory_experience)
def testCriticRegLoss(self):
  agent = sac_agent.SacAgent(
      self._time_step_spec,
      self._action_spec,
      critic_network=DummyCriticNet(0.5),
      actor_network=None,
      actor_optimizer=None,
      critic_optimizer=None,
      alpha_optimizer=None,
      actor_policy_ctor=DummyActorPolicy)
  observations = tf.zeros((2, 2), dtype=tf.float32)
  time_steps = ts.restart(observations, batch_size=2)
  actions = tf.zeros((2, 1), dtype=tf.float32)
  rewards = tf.zeros((2,), dtype=tf.float32)
  discounts = tf.zeros((2,), dtype=tf.float32)
  next_observations = tf.zeros((2, 2), dtype=tf.float32)
  next_time_steps = ts.transition(next_observations, rewards, discounts)

  # The expected loss is only the regularization loss.
  expected_loss = 2.0
  loss = agent.critic_loss(
      time_steps,
      actions,
      next_time_steps,
      td_errors_loss_fn=tf.math.squared_difference)

  self.evaluate(tf.compat.v1.global_variables_initializer())
  loss_ = self.evaluate(loss)
  self.assertAllClose(loss_, expected_loss)
def testTrainWithRnn(self):
  actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
      self._obs_spec,
      self._action_spec,
      input_fc_layer_params=None,
      output_fc_layer_params=None,
      conv_layer_params=None,
      lstm_size=(40,),
  )
  critic_net = critic_rnn_network.CriticRnnNetwork(
      (self._obs_spec, self._action_spec),
      observation_fc_layer_params=(16,),
      action_fc_layer_params=(16,),
      joint_fc_layer_params=(16,),
      lstm_size=(16,),
      output_fc_layer_params=None,
  )
  counter = common.create_variable('test_train_counter')
  optimizer_fn = tf.compat.v1.train.AdamOptimizer
  agent = sac_agent.SacAgent(
      self._time_step_spec,
      self._action_spec,
      critic_network=critic_net,
      actor_network=actor_net,
      actor_optimizer=optimizer_fn(1e-3),
      critic_optimizer=optimizer_fn(1e-3),
      alpha_optimizer=optimizer_fn(1e-3),
      train_step_counter=counter,
  )
  batch_size = 5
  observations = tf.constant(
      [[[1, 2], [3, 4], [5, 6]]] * batch_size, dtype=tf.float32)
  actions = tf.constant([[[0], [1], [1]]] * batch_size, dtype=tf.float32)
  time_steps = ts.TimeStep(
      step_type=tf.constant([[1] * 3] * batch_size, dtype=tf.int32),
      reward=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
      discount=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
      observation=[observations])
  experience = trajectory.Trajectory(
      time_steps.step_type, [observations], actions, (),
      time_steps.step_type, time_steps.reward, time_steps.discount)

  # Force variable creation.
  agent.policy.variables()

  if tf.executing_eagerly():
    loss = lambda: agent.train(experience)
  else:
    loss = agent.train(experience)

  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.assertEqual(self.evaluate(counter), 0)
  self.evaluate(loss)
  self.assertEqual(self.evaluate(counter), 1)
def GetAgent(self, env, params): def _normal_projection_net(action_spec, init_means_output_factor=0.1): return normal_projection_network.NormalProjectionNetwork( action_spec, mean_transform=None, state_dependent_std=True, init_means_output_factor=init_means_output_factor, std_transform=sac_agent.std_clip_transform, scale_distribution=True) # actor network actor_net = actor_distribution_network.ActorDistributionNetwork( env.observation_spec(), env.action_spec(), fc_layer_params=tuple( self._params["ML"]["BehaviorSACAgent"]["ActorFcLayerParams", "", [512, 256, 256]]), continuous_projection_net=_normal_projection_net) # critic network critic_net = critic_network.CriticNetwork( (env.observation_spec(), env.action_spec()), observation_fc_layer_params=None, action_fc_layer_params=None, joint_fc_layer_params=tuple(self._params["ML"]["BehaviorSACAgent"][ "CriticJointFcLayerParams", "", [512, 256, 256]])) # agent tf_agent = sac_agent.SacAgent( env.time_step_spec(), env.action_spec(), actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=self._params["ML"]["BehaviorSACAgent"][ "ActorLearningRate", "", 3e-4]), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=self._params["ML"]["BehaviorSACAgent"][ "CriticLearningRate", "", 3e-4]), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=self._params["ML"]["BehaviorSACAgent"][ "AlphaLearningRate", "", 3e-4]), target_update_tau=self._params["ML"]["BehaviorSACAgent"][ "TargetUpdateTau", "", 0.05], target_update_period=self._params["ML"]["BehaviorSACAgent"][ "TargetUpdatePeriod", "", 3], td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error, gamma=self._params["ML"]["BehaviorSACAgent"]["Gamma", "", 0.995], reward_scale_factor=self._params["ML"]["BehaviorSACAgent"][ "RewardScaleFactor", "", 1.], train_step_counter=self._ckpt.step, name=self._params["ML"]["BehaviorSACAgent"]["AgentName", "", "sac_agent"], debug_summaries=self._params["ML"]["BehaviorSACAgent"][ "DebugSummaries", "", False]) tf_agent.initialize() return tf_agent
def testCreateAgent(self):
  sac_agent.SacAgent(
      self._time_step_spec,
      self._action_spec,
      critic_network=DummyCriticNet(),
      actor_network=None,
      actor_optimizer=None,
      critic_optimizer=None,
      alpha_optimizer=None,
      actor_policy_ctor=DummyActorPolicy)
def GetAgent(self, env, params): self._params["ML"]["GraphDims"] = self._observer.graph_dimensions # actor network actor_net = GNNActorNetwork( input_tensor_spec=env.observation_spec(), output_tensor_spec=env.action_spec(), gnn=self._init_gnn, fc_layer_params=self._gnn_sac_params["ActorFcLayerParams", "", [256, 256]], params=params) # critic network critic_net = GNNCriticNetwork( (env.observation_spec(), env.action_spec()), gnn=self._init_gnn, observation_fc_layer_params=self._gnn_sac_params[ "CriticObservationFcLayerParams", "", [256]], action_fc_layer_params=self._gnn_sac_params[ "CriticActionFcLayerParams", "", None], joint_fc_layer_params=self._gnn_sac_params[ "CriticJointFcLayerParams", "", [256, 256]], params=params) # agent tf_agent = sac_agent.SacAgent( env.time_step_spec(), env.action_spec(), actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=self._gnn_sac_params["ActorLearningRate", "", 3e-4]), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=self._gnn_sac_params["CriticLearningRate", "", 3e-4]), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=self._gnn_sac_params["AlphaLearningRate", "", 0.]), target_update_tau=self._gnn_sac_params["TargetUpdateTau", "", 1.], target_update_period=self._gnn_sac_params["TargetUpdatePeriod", "", 1], td_errors_loss_fn=tf.math.squared_difference, gamma=self._gnn_sac_params["Gamma", "", 0.995], reward_scale_factor=self._gnn_sac_params["RewardScaleFactor", "", 1.], train_step_counter=self._ckpt.step, name=self._gnn_sac_params["AgentName", "", "gnn_sac_agent"], debug_summaries=self._gnn_sac_params["DebugSummaries", "", True]) tf_agent.initialize() return tf_agent
def testSharedLayer(self):
  shared_layer = tf.keras.layers.Dense(
      1,
      kernel_initializer=tf.compat.v1.initializers.constant([0]),
      bias_initializer=tf.compat.v1.initializers.constant([0]),
      name='shared')
  critic_net_1 = DummyCriticNet(shared_layer=shared_layer)
  critic_net_2 = DummyCriticNet(shared_layer=shared_layer)

  target_shared_layer = tf.keras.layers.Dense(
      1,
      kernel_initializer=tf.compat.v1.initializers.constant([0]),
      bias_initializer=tf.compat.v1.initializers.constant([0]),
      name='shared_target')
  target_critic_net_1 = DummyCriticNet(shared_layer=target_shared_layer)
  target_critic_net_2 = DummyCriticNet(shared_layer=target_shared_layer)

  agent = sac_agent.SacAgent(
      self._time_step_spec,
      self._action_spec,
      critic_network=critic_net_1,
      critic_network_2=critic_net_2,
      target_critic_network=target_critic_net_1,
      target_critic_network_2=target_critic_net_2,
      actor_network=None,
      actor_optimizer=None,
      critic_optimizer=None,
      alpha_optimizer=None,
      target_entropy=3.0,
      initial_log_alpha=4.0,
      target_update_tau=0.5,
      actor_policy_ctor=DummyActorPolicy)

  self.evaluate([
      tf.compat.v1.global_variables_initializer(),
      tf.compat.v1.local_variables_initializer()
  ])
  self.evaluate(agent.initialize())

  for v in shared_layer.variables:
    self.evaluate(v.assign(v * 0 + 1))
  self.evaluate(agent._update_target())

  self.assertEqual(1.0, self.evaluate(shared_layer.variables[0][0][0]))
  self.assertEqual(0.5, self.evaluate(target_shared_layer.variables[0][0][0]))
def testCreateAgent(self, create_critic_net_fn, skip_in_tf1):
  if skip_in_tf1 and not common.has_eager_been_enabled():
    self.skipTest('Skipping test: sequential networks not supported in TF1')
  critic_network = create_critic_net_fn()
  sac_agent.SacAgent(
      self._time_step_spec,
      self._action_spec,
      critic_network=critic_network,
      actor_network=None,
      actor_optimizer=None,
      critic_optimizer=None,
      alpha_optimizer=None,
      actor_policy_ctor=DummyActorPolicy)
def load_policy(agent_class, tf_env):
  load_dir = FLAGS.load_dir
  assert load_dir and osp.exists(load_dir), (
      'need to provide valid load_dir to load policy, got: {}'.format(load_dir))
  global_step = tf.compat.v1.train.get_or_create_global_step()
  time_step_spec = tf_env.time_step_spec()
  observation_spec = time_step_spec.observation
  action_spec = tf_env.action_spec()
  actor_net = actor_distribution_network.ActorDistributionNetwork(
      observation_spec,
      action_spec,
      fc_layer_params=(256, 256),
      continuous_projection_net=normal_projection_net)
  critic_net = critic_network.CriticNetwork(
      (observation_spec, action_spec),
      joint_fc_layer_params=(256, 256))
  tf_agent = sac_agent.SacAgent(
      time_step_spec,
      action_spec,
      actor_network=actor_net,
      critic_network=critic_net,
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=3e-4),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=3e-4),
      alpha_optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=3e-4),
      target_update_tau=0.005,
      target_update_period=1,
      td_errors_loss_fn=tf.keras.losses.mse,
      gamma=0,
      reward_scale_factor=1.,
      gradient_clipping=1.,
      debug_summaries=False,
      summarize_grads_and_vars=False,
      train_step_counter=global_step)
  train_checkpointer = common.Checkpointer(
      ckpt_dir=load_dir, agent=tf_agent, global_step=global_step)
  status = train_checkpointer.initialize_or_restore()
  status.expect_partial()
  logging.info('Loaded from checkpoint: %s, trained %s steps',
               train_checkpointer._manager.latest_checkpoint,
               global_step.numpy())
  return tf_agent.policy
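# Note: `normal_projection_net` is referenced above but not defined in this
# snippet. A minimal sketch, assuming it follows the `_normal_projection_net`
# pattern shown in the BehaviorSACAgent `GetAgent` snippet below; the exact
# arguments in the original code base may differ:
def normal_projection_net(action_spec, init_means_output_factor=0.1):
  return normal_projection_network.NormalProjectionNetwork(
      action_spec,
      mean_transform=None,
      state_dependent_std=True,
      init_means_output_factor=init_means_output_factor,
      std_transform=sac_agent.std_clip_transform,
      scale_distribution=True)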
def GetAgent(self, env, params): gnn_sac_params = self._params["ML"]["BehaviorGraphSACAgent"] # actor network actor_net = GNNActorNetwork( input_tensor_spec=env.observation_spec(), output_tensor_spec=env.action_spec(), gnn=GNNWrapper(params=gnn_sac_params["GNN"], graph_dims=self._observer.graph_dimensions), fc_layer_params=gnn_sac_params["ActorFcLayerParams", "", [128, 64]]) # critic network critic_net = GNNCriticNetwork( (env.observation_spec(), env.action_spec()), gnn=GNNWrapper(params=gnn_sac_params["GNN"], graph_dims=self._observer.graph_dimensions), observation_fc_layer_params=gnn_sac_params[ "CriticObservationFcLayerParams", "", [128]], action_fc_layer_params=gnn_sac_params["CriticActionFcLayerParams", "", None], joint_fc_layer_params=gnn_sac_params["CriticJointFcLayerParams", "", [128, 128]]) # agent tf_agent = sac_agent.SacAgent( env.time_step_spec(), env.action_spec(), actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=gnn_sac_params["ActorLearningRate", "", 3e-4]), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=gnn_sac_params["CriticLearningRate", "", 3e-4]), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=gnn_sac_params["AlphaLearningRate", "", 3e-4]), target_update_tau=gnn_sac_params["TargetUpdateTau", "", 0.05], target_update_period=gnn_sac_params["TargetUpdatePeriod", "", 3], td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error, gamma=gnn_sac_params["Gamma", "", 0.995], reward_scale_factor=gnn_sac_params["RewardScaleFactor", "", 1.], train_step_counter=self._ckpt.step, name=gnn_sac_params["AgentName", "", "gnn_sac_agent"], debug_summaries=gnn_sac_params["DebugSummaries", "", False]) tf_agent.initialize() return tf_agent
def testLoss(self, mock_actions_and_log_probs, mock_apply_gradients):
  # Mock _actions_and_log_probs so that _train() and _loss() run on the same
  # sampled values.
  actions = tf.constant([[0.2], [0.5], [-0.8]])
  log_pi = tf.constant([-1.1, -0.8, -0.5])
  mock_actions_and_log_probs.return_value = (actions, log_pi)
  # Skip applying gradients since mocking _actions_and_log_probs.
  del mock_apply_gradients

  actor_net = actor_distribution_network.ActorDistributionNetwork(
      self._obs_spec,
      self._action_spec,
      fc_layer_params=(10,),
      continuous_projection_net=tanh_normal_projection_network
      .TanhNormalProjectionNetwork)
  agent = sac_agent.SacAgent(
      self._time_step_spec,
      self._action_spec,
      critic_network=DummyCriticNet(),
      actor_network=actor_net,
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(0.001),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(0.001),
      alpha_optimizer=tf.compat.v1.train.AdamOptimizer(0.001))

  observations = tf.constant(
      [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]],
      dtype=tf.float32)
  actions = tf.constant([[[0], [1]], [[2], [3]], [[4], [5]]], dtype=tf.float32)
  time_steps = ts.TimeStep(
      step_type=tf.constant([[1, 1]] * 3, dtype=tf.int32),
      reward=tf.constant([[1, 1]] * 3, dtype=tf.float32),
      discount=tf.constant([[1, 1]] * 3, dtype=tf.float32),
      observation=observations)
  experience = trajectory.Trajectory(
      time_steps.step_type, observations, actions, (),
      time_steps.step_type, time_steps.reward, time_steps.discount)

  test_util.test_loss_and_train_output(
      test=self,
      expect_equal_loss_values=True,
      agent=agent,
      experience=experience)
def create_sac_agent(train_env, reward_scale_factor):
  return sac_agent.SacAgent(
      train_env.time_step_spec(),
      train_env.action_spec(),
      actor_network=create_actor_network(train_env),
      critic_network=create_critic_network(train_env),
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=3e-4),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=3e-4),
      alpha_optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=3e-4),
      target_update_tau=0.005,
      target_update_period=1,
      td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
      gamma=0.99,
      reward_scale_factor=reward_scale_factor,
      gradient_clipping=None,
      train_step_counter=tf.compat.v1.train.get_or_create_global_step(),
  )
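# Note: `create_actor_network` and `create_critic_network` are referenced
# above but not shown here. A minimal sketch using the standard TF-Agents
# network classes; the layer sizes and projection network are assumptions,
# not taken from the original code:
def create_actor_network(train_env):
  return actor_distribution_network.ActorDistributionNetwork(
      train_env.observation_spec(),
      train_env.action_spec(),
      fc_layer_params=(256, 256),
      continuous_projection_net=tanh_normal_projection_network
      .TanhNormalProjectionNetwork)


def create_critic_network(train_env):
  return critic_network.CriticNetwork(
      (train_env.observation_spec(), train_env.action_spec()),
      joint_fc_layer_params=(256, 256))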
def testPolicy(self):
  agent = sac_agent.SacAgent(
      self._time_step_spec,
      self._action_spec,
      critic_network=DummyCriticNet(),
      actor_network=None,
      actor_optimizer=None,
      critic_optimizer=None,
      alpha_optimizer=None,
      actor_policy_ctor=DummyActorPolicy)
  observations = tf.constant([[1, 2]], dtype=tf.float32)
  time_steps = ts.restart(observations)
  action_step = agent.policy.action(time_steps)
  self.evaluate(tf.compat.v1.global_variables_initializer())
  action_ = self.evaluate(action_step.action)
  self.assertLessEqual(action_, self._action_spec.maximum)
  self.assertGreaterEqual(action_, self._action_spec.minimum)
def testActorLoss(self):
  agent = sac_agent.SacAgent(
      self._time_step_spec,
      self._action_spec,
      critic_network=DummyCriticNet(),
      actor_network=None,
      actor_optimizer=None,
      critic_optimizer=None,
      alpha_optimizer=None,
      actor_policy_ctor=DummyActorPolicy)
  observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
  time_steps = ts.restart(observations, batch_size=2)

  expected_loss = (2 * 10 - (2 + 1) - (4 + 1)) / 2
  loss = agent.actor_loss(time_steps)

  self.evaluate(tf.compat.v1.global_variables_initializer())
  loss_ = self.evaluate(loss)
  self.assertAllClose(loss_, expected_loss)
def create_sac_agent(self, actor, critic, actor_alpha, critic_alpha,
                     alpha_alpha, gamma):
  train_step_counter = tf.Variable(0)
  return sac_agent.SacAgent(
      spec.get_time_step_spec(),
      spec.get_action_spec(),
      actor_network=actor,
      critic_network=critic,
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=actor_alpha),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=critic_alpha),
      alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=alpha_alpha),
      target_update_tau=0.05,
      target_update_period=5,
      td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
      gamma=gamma,
      train_step_counter=train_step_counter)
def sac_agent(self):
  return sac_agent.SacAgent(
      self.train_env.time_step_spec(),
      self.action_spec,
      actor_network=self.actor_net,
      critic_network=self.critic_net,
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=self.actor_lr),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=self.critic_lr),
      alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=self.alpha_lr),
      target_update_tau=self.target_update_tau,
      target_update_period=self.target_update_period,
      td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
      gamma=self.gamma,
      reward_scale_factor=self.reward_scale,
      gradient_clipping=self.gradient_clipping,
      train_step_counter=self.global_step)
def testCriticLoss(self, create_critic_net_fn, skip_in_tf1):
  if skip_in_tf1 and not common.has_eager_been_enabled():
    self.skipTest('Skipping test: sequential networks not supported in TF1')
  critic_network = create_critic_net_fn()
  agent = sac_agent.SacAgent(
      self._time_step_spec,
      self._action_spec,
      critic_network=critic_network,
      actor_network=None,
      actor_optimizer=None,
      critic_optimizer=None,
      alpha_optimizer=None,
      actor_policy_ctor=DummyActorPolicy)

  observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
  time_steps = ts.restart(observations, batch_size=2)
  actions = tf.constant([[5], [6]], dtype=tf.float32)
  rewards = tf.constant([10, 20], dtype=tf.float32)
  discounts = tf.constant([0.9, 0.9], dtype=tf.float32)
  next_observations = tf.constant([[5, 6], [7, 8]], dtype=tf.float32)
  next_time_steps = ts.transition(next_observations, rewards, discounts)

  td_targets = [7.3, 19.1]
  pred_td_targets = [7., 10.]

  self.evaluate(tf.compat.v1.global_variables_initializer())
  # Expected critic loss has factor of 2, for the two TD3 critics.
  expected_loss = self.evaluate(2 * tf.compat.v1.losses.mean_squared_error(
      tf.constant(td_targets), tf.constant(pred_td_targets)))

  loss = agent.critic_loss(
      time_steps,
      actions,
      next_time_steps,
      td_errors_loss_fn=tf.math.squared_difference)

  self.evaluate(tf.compat.v1.global_variables_initializer())
  loss_ = self.evaluate(loss)
  self.assertAllClose(loss_, expected_loss)
def testAlphaLoss(self):
  agent = sac_agent.SacAgent(
      self._time_step_spec,
      self._action_spec,
      critic_network=DummyCriticNet(),
      actor_network=None,
      actor_optimizer=None,
      critic_optimizer=None,
      alpha_optimizer=None,
      squash_actions=False,
      target_entropy=3.0,
      initial_log_alpha=4.0,
      actor_policy_ctor=DummyActorPolicy)
  observations = [tf.constant([[1, 2], [3, 4]], dtype=tf.float32)]
  time_steps = ts.restart(observations, batch_size=2)

  expected_loss = 4.0 * (-10 - 3)
  loss = agent.alpha_loss(time_steps)

  self.evaluate(tf.compat.v1.global_variables_initializer())
  loss_ = self.evaluate(loss)
  self.assertAllClose(loss_, expected_loss)
def _create_agent(train_step: tf.Variable,
                  observation_tensor_spec: types.NestedTensorSpec,
                  action_tensor_spec: types.NestedTensorSpec,
                  time_step_tensor_spec: ts.TimeStep,
                  learning_rate: float) -> tf_agent.TFAgent:
  """Creates an agent."""
  critic_net = critic_network.CriticNetwork(
      (observation_tensor_spec, action_tensor_spec),
      observation_fc_layer_params=None,
      action_fc_layer_params=None,
      joint_fc_layer_params=(256, 256),
      kernel_initializer='glorot_uniform',
      last_kernel_initializer='glorot_uniform')

  actor_net = actor_distribution_network.ActorDistributionNetwork(
      observation_tensor_spec,
      action_tensor_spec,
      fc_layer_params=(256, 256),
      continuous_projection_net=tanh_normal_projection_network
      .TanhNormalProjectionNetwork)

  return sac_agent.SacAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      actor_network=actor_net,
      critic_network=critic_net,
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=learning_rate),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=learning_rate),
      alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=learning_rate),
      target_update_tau=0.005,
      target_update_period=1,
      td_errors_loss_fn=tf.math.squared_difference,
      gamma=0.99,
      reward_scale_factor=0.1,
      gradient_clipping=None,
      train_step_counter=train_step)
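# A minimal usage sketch for `_create_agent`, assuming a TF-Agents
# `tf_py_environment.TFPyEnvironment` named `tf_env` and the `train_utils`
# step-counter helper; these names and the learning rate are assumptions,
# not taken from the original code:
train_step = train_utils.create_train_step()
agent = _create_agent(
    train_step=train_step,
    observation_tensor_spec=tf_env.observation_spec(),
    action_tensor_spec=tf_env.action_spec(),
    time_step_tensor_spec=tf_env.time_step_spec(),
    learning_rate=3e-4)
agent.initialize()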
def testCriticLoss(self):
  agent = sac_agent.SacAgent(
      self._time_step_spec,
      self._action_spec,
      critic_network=DummyCriticNet(),
      actor_network=None,
      actor_optimizer=None,
      critic_optimizer=None,
      alpha_optimizer=None,
      squash_actions=False,
      actor_policy_ctor=DummyActorPolicy)
  observations = [tf.constant([[1, 2], [3, 4]], dtype=tf.float32)]
  time_steps = ts.restart(observations)
  actions = tf.constant([[5], [6]], dtype=tf.float32)
  rewards = tf.constant([10, 20], dtype=tf.float32)
  discounts = tf.constant([0.9, 0.9], dtype=tf.float32)
  next_observations = [tf.constant([[5, 6], [7, 8]], dtype=tf.float32)]
  next_time_steps = ts.transition(next_observations, rewards, discounts)

  td_targets = [7.3, 19.1]
  pred_td_targets = [7., 10.]

  self.evaluate(tf.compat.v1.global_variables_initializer())
  # Expected critic loss has factor of 2, for the two TD3 critics.
  expected_loss = self.evaluate(2 * tf.compat.v1.losses.mean_squared_error(
      tf.constant(td_targets), tf.constant(pred_td_targets)))

  loss = agent.critic_loss(
      time_steps,
      actions,
      next_time_steps,
      td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error)

  self.evaluate(tf.compat.v1.global_variables_initializer())
  loss_ = self.evaluate(loss)
  self.assertAllClose(loss_, expected_loss)
def train_eval( root_dir, env_name='HalfCheetah-v2', num_iterations=1000000, actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), # Params for collect initial_collect_steps=10000, collect_steps_per_iteration=1, replay_buffer_capacity=1000000, # Params for target update target_update_tau=0.005, target_update_period=1, # Params for train train_steps_per_iteration=1, batch_size=256, actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error, gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, use_tf_functions=True, # Params for eval num_eval_episodes=30, eval_interval=10000, # Params for summaries and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=50000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for SAC.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes) ] global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name)) eval_tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name)) time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() actor_net = actor_distribution_network.ActorDistributionNetwork( observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=normal_projection_net) critic_net = critic_network.CriticNetwork( (observation_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers) tf_agent = sac_agent.SacAgent( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) tf_agent.initialize() # Make the replay buffer. 
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=1, max_length=replay_buffer_capacity) replay_observer = [replay_buffer.add_batch] train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()), tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()), ] eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy) initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) collect_policy = tf_agent.collect_policy train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'policy'), policy=eval_policy, global_step=global_step) rb_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) train_checkpointer.initialize_or_restore() rb_checkpointer.initialize_or_restore() initial_collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=replay_observer, num_steps=initial_collect_steps) collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_steps=collect_steps_per_iteration) if use_tf_functions: initial_collect_driver.run = common.function(initial_collect_driver.run) collect_driver.run = common.function(collect_driver.run) tf_agent.train = common.function(tf_agent.train) # Collect initial replay data. logging.info( 'Initializing replay buffer by collecting experience for %d steps with ' 'a random policy.', initial_collect_steps) initial_collect_driver.run() results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) time_step = None policy_state = collect_policy.get_initial_state(tf_env.batch_size) timed_at_step = global_step.numpy() time_acc = 0 # Dataset generates trajectories with shape [Bx2x...] 
dataset = replay_buffer.as_dataset( num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3) iterator = iter(dataset) for _ in range(num_iterations): start_time = time.time() time_step, policy_state = collect_driver.run( time_step=time_step, policy_state=policy_state, ) for _ in range(train_steps_per_iteration): experience, _ = next(iterator) train_loss = tf_agent.train(experience) time_acc += time.time() - start_time if global_step.numpy() % log_interval == 0: logging.info('step = %d, loss = %f', global_step.numpy(), train_loss.loss) steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) tf.compat.v2.summary.scalar( name='global_steps_per_sec', data=steps_per_sec, step=global_step) timed_at_step = global_step.numpy() time_acc = 0 for train_metric in train_metrics: train_metric.tf_summaries( train_step=global_step, step_metrics=train_metrics[:2]) if global_step.numpy() % eval_interval == 0: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) global_step_val = global_step.numpy() if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val) return train_loss
    fc_layer_params=HyperParms.actor_fc_layer_params,
    continuous_projection_net=(
        tanh_normal_projection_network.TanhNormalProjectionNetwork))

with objStrategy.scope():
  train_step = train_utils.create_train_step()
  tf_agent = sac_agent.SacAgent(
      specTimeStep,
      specAction,
      actor_network=nnActor,
      critic_network=nnCritic,
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=HyperParms.actor_learning_rate),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=HyperParms.critic_learning_rate),
      alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=HyperParms.alpha_learning_rate),
      target_update_tau=HyperParms.target_update_tau,
      target_update_period=HyperParms.target_update_period,
      td_errors_loss_fn=tf.math.squared_difference,
      gamma=HyperParms.gamma,
      reward_scale_factor=HyperParms.reward_scale_factor,
      train_step_counter=train_step)
  tf_agent.initialize()

print(f" -- REPLAY BUFFER ({now()}) -- ")

rate_limiter = reverb.rate_limiters.SampleToInsertRatio(
    samples_per_insert=3.0, min_size_to_sample=3, error_buffer=3.0)
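# A sketch of how such a rate limiter is typically wired into a Reverb table
# and a TF-Agents Reverb replay buffer. The table name, sequence length,
# `HyperParms.replay_buffer_capacity`, and the `reverb_replay_buffer` import
# (tf_agents.replay_buffers) are assumptions based on the standard TF-Agents
# SAC setup, not taken from this code base:
table_name = 'uniform_table'
table = reverb.Table(
    table_name,
    max_size=HyperParms.replay_buffer_capacity,
    sampler=reverb.selectors.Uniform(),
    remover=reverb.selectors.Fifo(),
    rate_limiter=rate_limiter)
reverb_server = reverb.Server([table])
reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
    tf_agent.collect_data_spec,
    sequence_length=2,
    table_name=table_name,
    local_server=reverb_server)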
def train_eval( root_dir, env_name='HalfCheetah-v2', num_iterations=1000000, actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), # Params for collect initial_collect_steps=10000, collect_steps_per_iteration=1, replay_buffer_capacity=1000000, # Params for target update target_update_tau=0.005, target_update_period=1, # Params for train train_steps_per_iteration=1, batch_size=256, actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error, gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, # Params for eval num_eval_episodes=30, eval_interval=10000, # Params for summaries and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=50000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for SAC.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes), ] eval_summary_flush_op = eval_summary_writer.flush() global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): # Create the environment. tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name)) eval_py_env = suite_mujoco.load(env_name) # Get the data specs from the environment time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() actor_net = actor_distribution_network.ActorDistributionNetwork( observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=normal_projection_net) critic_net = critic_network.CriticNetwork( (observation_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers) tf_agent = sac_agent.SacAgent( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) # Make the replay buffer. 
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=1, max_length=replay_buffer_capacity) replay_observer = [replay_buffer.add_batch] eval_py_policy = py_tf_policy.PyTFPolicy( greedy_policy.GreedyPolicy(tf_agent.policy)) train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()), tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()), ] collect_policy = tf_agent.collect_policy initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) initial_collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=replay_observer + train_metrics, num_steps=initial_collect_steps).run() collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_steps=collect_steps_per_iteration).run() # Prepare replay buffer as dataset with invalid transitions filtered. def _filter_invalid_transition(trajectories, unused_arg1): return ~trajectories.is_boundary()[0] dataset = replay_buffer.as_dataset( sample_batch_size=5 * batch_size, num_steps=2).apply(tf.data.experimental.unbatch()).filter( _filter_invalid_transition).batch(batch_size).prefetch( batch_size * 5) dataset_iterator = tf.compat.v1.data.make_initializable_iterator( dataset) trajectories, unused_info = dataset_iterator.get_next() train_op = tf_agent.train(trajectories) summary_ops = [] for train_metric in train_metrics: summary_ops.append( train_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2])) with eval_summary_writer.as_default(), \ tf.compat.v2.summary.record_if(True): for eval_metric in eval_metrics: eval_metric.tf_summaries(train_step=global_step) train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'policy'), policy=tf_agent.policy, global_step=global_step) rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) with tf.compat.v1.Session() as sess: # Initialize graph. train_checkpointer.initialize_or_restore(sess) rb_checkpointer.initialize_or_restore(sess) # Initialize training. sess.run(dataset_iterator.initializer) common.initialize_uninitialized_variables(sess) sess.run(train_summary_writer.init()) sess.run(eval_summary_writer.init()) global_step_val = sess.run(global_step) if global_step_val == 0: # Initial eval of randomly initialized policy metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, log=True, ) sess.run(eval_summary_flush_op) # Run initial collect. logging.info('Global step %d: Running initial collect op.', global_step_val) sess.run(initial_collect_op) # Checkpoint the initial replay buffer contents. 
rb_checkpointer.save(global_step=global_step_val) logging.info('Finished initial collect.') else: logging.info('Global step %d: Skipping initial collect op.', global_step_val) collect_call = sess.make_callable(collect_op) train_step_call = sess.make_callable([train_op, summary_ops]) global_step_call = sess.make_callable(global_step) timed_at_step = global_step_call() time_acc = 0 steps_per_second_ph = tf.compat.v1.placeholder( tf.float32, shape=(), name='steps_per_sec_ph') steps_per_second_summary = tf.compat.v2.summary.scalar( name='global_steps_per_sec', data=steps_per_second_ph, step=global_step) for _ in range(num_iterations): start_time = time.time() collect_call() for _ in range(train_steps_per_iteration): total_loss, _ = train_step_call() time_acc += time.time() - start_time global_step_val = global_step_call() if global_step_val % log_interval == 0: logging.info('step = %d, loss = %f', global_step_val, total_loss.loss) steps_per_sec = (global_step_val - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) sess.run(steps_per_second_summary, feed_dict={steps_per_second_ph: steps_per_sec}) timed_at_step = global_step_val time_acc = 0 if global_step_val % eval_interval == 0: metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, log=True, ) sess.run(eval_summary_flush_op) if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val)
def train_eval( root_dir, experiment_name, # experiment name env_name='carla-v0', agent_name='sac', # agent's name num_iterations=int(1e7), actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), model_network_ctor_type='non-hierarchical', # model net input_names=['camera', 'lidar'], # names for inputs mask_names=['birdeye'], # names for masks preprocessing_combiner=tf.keras.layers.Add( ), # takes a flat list of tensors and combines them actor_lstm_size=(40, ), # lstm size for actor critic_lstm_size=(40, ), # lstm size for critic actor_output_fc_layers=(100, ), # lstm output critic_output_fc_layers=(100, ), # lstm output epsilon_greedy=0.1, # exploration parameter for DQN q_learning_rate=1e-3, # q learning rate for DQN ou_stddev=0.2, # exploration paprameter for DDPG ou_damping=0.15, # exploration parameter for DDPG dqda_clipping=None, # for DDPG exploration_noise_std=0.1, # exploration paramter for td3 actor_update_period=2, # for td3 # Params for collect initial_collect_steps=1000, collect_steps_per_iteration=1, replay_buffer_capacity=int(1e5), # Params for target update target_update_tau=0.005, target_update_period=1, # Params for train train_steps_per_iteration=1, initial_model_train_steps=100000, # initial model training batch_size=256, model_batch_size=32, # model training batch size sequence_length=4, # number of timesteps to train model actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, model_learning_rate=1e-4, # learning rate for model training td_errors_loss_fn=tf.losses.mean_squared_error, gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, # Params for eval num_eval_episodes=10, eval_interval=10000, # Params for summaries and logging num_images_per_summary=1, # images for each summary train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=50000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, gpu_allow_growth=True, # GPU memory growth gpu_memory_limit=None, # GPU memory limit action_repeat=1 ): # Name of single observation channel, ['camera', 'lidar', 'birdeye'] # Setup GPU gpus = tf.config.experimental.list_physical_devices('GPU') if gpu_allow_growth: for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True) if gpu_memory_limit: for gpu in gpus: tf.config.experimental.set_virtual_device_configuration( gpu, [ tf.config.experimental.VirtualDeviceConfiguration( memory_limit=gpu_memory_limit) ]) # Get train and eval directories root_dir = os.path.expanduser(root_dir) root_dir = os.path.join(root_dir, env_name, experiment_name) # Get summary writers summary_writer = tf.summary.create_file_writer( root_dir, flush_millis=summaries_flush_secs * 1000) summary_writer.set_as_default() # Eval metrics eval_metrics = [ tf_metrics.AverageReturnMetric(name='AverageReturnEvalPolicy', buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric( name='AverageEpisodeLengthEvalPolicy', buffer_size=num_eval_episodes), ] global_step = tf.compat.v1.train.get_or_create_global_step() # Whether to record for summary with tf.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): # Create Carla environment if agent_name == 'latent_sac': py_env, eval_py_env = load_carla_env(env_name='carla-v0', obs_channels=input_names + mask_names, action_repeat=action_repeat) elif agent_name == 'dqn': py_env, eval_py_env = load_carla_env(env_name='carla-v0', 
discrete=True, obs_channels=input_names, action_repeat=action_repeat) else: py_env, eval_py_env = load_carla_env(env_name='carla-v0', obs_channels=input_names, action_repeat=action_repeat) tf_env = tf_py_environment.TFPyEnvironment(py_env) eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env) fps = int(np.round(1.0 / (py_env.dt * action_repeat))) # Specs time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() ## Make tf agent if agent_name == 'latent_sac': # Get model network for latent sac if model_network_ctor_type == 'hierarchical': model_network_ctor = sequential_latent_network.SequentialLatentModelHierarchical elif model_network_ctor_type == 'non-hierarchical': model_network_ctor = sequential_latent_network.SequentialLatentModelNonHierarchical else: raise NotImplementedError model_net = model_network_ctor(input_names, input_names + mask_names) # Get the latent spec latent_size = model_net.latent_size latent_observation_spec = tensor_spec.TensorSpec((latent_size, ), dtype=tf.float32) latent_time_step_spec = ts.time_step_spec( observation_spec=latent_observation_spec) # Get actor and critic net actor_net = actor_distribution_network.ActorDistributionNetwork( latent_observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=normal_projection_net) critic_net = critic_network.CriticNetwork( (latent_observation_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers) # Build the inner SAC agent based on latent space inner_agent = sac_agent.SacAgent( latent_time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) inner_agent.initialize() # Build the latent sac agent tf_agent = latent_sac_agent.LatentSACAgent( time_step_spec, action_spec, inner_agent=inner_agent, model_network=model_net, model_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=model_learning_rate), model_batch_size=model_batch_size, num_images_per_summary=num_images_per_summary, sequence_length=sequence_length, gradient_clipping=gradient_clipping, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step, fps=fps) else: # Set up preprosessing layers for dictionary observation inputs preprocessing_layers = collections.OrderedDict() for name in input_names: preprocessing_layers[name] = Preprocessing_Layer(32, 256) if len(input_names) < 2: preprocessing_combiner = None if agent_name == 'dqn': q_rnn_net = q_rnn_network.QRnnNetwork( observation_spec, action_spec, preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, input_fc_layer_params=critic_joint_fc_layers, lstm_size=critic_lstm_size, output_fc_layer_params=critic_output_fc_layers) tf_agent = dqn_agent.DqnAgent( time_step_spec, action_spec, q_network=q_rnn_net, epsilon_greedy=epsilon_greedy, n_step_update=1, 
target_update_tau=target_update_tau, target_update_period=target_update_period, optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=q_learning_rate), td_errors_loss_fn=common.element_wise_squared_loss, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) elif agent_name == 'ddpg' or agent_name == 'td3': actor_rnn_net = multi_inputs_actor_rnn_network.MultiInputsActorRnnNetwork( observation_spec, action_spec, preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, input_fc_layer_params=actor_fc_layers, lstm_size=actor_lstm_size, output_fc_layer_params=actor_output_fc_layers) critic_rnn_net = multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork( (observation_spec, action_spec), preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, lstm_size=critic_lstm_size, output_fc_layer_params=critic_output_fc_layers) if agent_name == 'ddpg': tf_agent = ddpg_agent.DdpgAgent( time_step_spec, action_spec, actor_network=actor_rnn_net, critic_network=critic_rnn_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), ou_stddev=ou_stddev, ou_damping=ou_damping, target_update_tau=target_update_tau, target_update_period=target_update_period, dqda_clipping=dqda_clipping, td_errors_loss_fn=None, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars) elif agent_name == 'td3': tf_agent = td3_agent.Td3Agent( time_step_spec, action_spec, actor_network=actor_rnn_net, critic_network=critic_rnn_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), exploration_noise_std=exploration_noise_std, target_update_tau=target_update_tau, target_update_period=target_update_period, actor_update_period=actor_update_period, dqda_clipping=dqda_clipping, td_errors_loss_fn=None, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) elif agent_name == 'sac': actor_distribution_rnn_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork( observation_spec, action_spec, preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, input_fc_layer_params=actor_fc_layers, lstm_size=actor_lstm_size, output_fc_layer_params=actor_output_fc_layers, continuous_projection_net=normal_projection_net) critic_rnn_net = multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork( (observation_spec, action_spec), preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, lstm_size=critic_lstm_size, output_fc_layer_params=critic_output_fc_layers) tf_agent = sac_agent.SacAgent( time_step_spec, action_spec, actor_network=actor_distribution_rnn_net, critic_network=critic_rnn_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), 
critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=tf.math. squared_difference, # make critic loss dimension compatible gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) else: raise NotImplementedError tf_agent.initialize() # Get replay buffer replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=1, # No parallel environments max_length=replay_buffer_capacity) replay_observer = [replay_buffer.add_batch] # Train metrics env_steps = tf_metrics.EnvironmentSteps() average_return = tf_metrics.AverageReturnMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size) train_metrics = [ tf_metrics.NumberOfEpisodes(), env_steps, average_return, tf_metrics.AverageEpisodeLengthMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size), ] # Get policies # eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy) eval_policy = tf_agent.policy initial_collect_policy = random_tf_policy.RandomTFPolicy( time_step_spec, action_spec) collect_policy = tf_agent.collect_policy # Checkpointers train_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(root_dir, 'train'), agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'), max_to_keep=2) policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( root_dir, 'policy'), policy=eval_policy, global_step=global_step, max_to_keep=2) rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( root_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) train_checkpointer.initialize_or_restore() rb_checkpointer.initialize_or_restore() # Collect driver initial_collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=replay_observer + train_metrics, num_steps=initial_collect_steps) collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_steps=collect_steps_per_iteration) # Optimize the performance by using tf functions initial_collect_driver.run = common.function( initial_collect_driver.run) collect_driver.run = common.function(collect_driver.run) tf_agent.train = common.function(tf_agent.train) # Collect initial replay data. if (env_steps.result() == 0 or replay_buffer.num_frames() == 0): logging.info( 'Initializing replay buffer by collecting experience for %d steps' 'with a random policy.', initial_collect_steps) initial_collect_driver.run() if agent_name == 'latent_sac': compute_summaries(eval_metrics, eval_tf_env, eval_policy, train_step=global_step, summary_writer=summary_writer, num_episodes=1, num_episodes_to_render=1, model_net=model_net, fps=10, image_keys=input_names + mask_names) else: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=1, train_step=env_steps.result(), summary_writer=summary_writer, summary_prefix='Eval', ) metric_utils.log_metrics(eval_metrics) # Dataset generates trajectories with shape [Bxslx...] 
dataset = replay_buffer.as_dataset(num_parallel_calls=3, sample_batch_size=batch_size, num_steps=sequence_length + 1).prefetch(3) iterator = iter(dataset) # Get train step def train_step(): experience, _ = next(iterator) return tf_agent.train(experience) train_step = common.function(train_step) if agent_name == 'latent_sac': def train_model_step(): experience, _ = next(iterator) return tf_agent.train_model(experience) train_model_step = common.function(train_model_step) # Training initializations time_step = None time_acc = 0 env_steps_before = env_steps.result().numpy() # Start training for iteration in range(num_iterations): start_time = time.time() if agent_name == 'latent_sac' and iteration < initial_model_train_steps: train_model_step() else: # Run collect time_step, _ = collect_driver.run(time_step=time_step) # Train an iteration for _ in range(train_steps_per_iteration): train_step() time_acc += time.time() - start_time # Log training information if global_step.numpy() % log_interval == 0: logging.info('env steps = %d, average return = %f', env_steps.result(), average_return.result()) env_steps_per_sec = (env_steps.result().numpy() - env_steps_before) / time_acc logging.info('%.3f env steps/sec', env_steps_per_sec) tf.summary.scalar(name='env_steps_per_sec', data=env_steps_per_sec, step=env_steps.result()) time_acc = 0 env_steps_before = env_steps.result().numpy() # Get training metrics for train_metric in train_metrics: train_metric.tf_summaries(train_step=env_steps.result()) # Evaluation if global_step.numpy() % eval_interval == 0: # Log evaluation metrics if agent_name == 'latent_sac': compute_summaries( eval_metrics, eval_tf_env, eval_policy, train_step=global_step, summary_writer=summary_writer, num_episodes=num_eval_episodes, num_episodes_to_render=num_images_per_summary, model_net=model_net, fps=10, image_keys=input_names + mask_names) else: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=env_steps.result(), summary_writer=summary_writer, summary_prefix='Eval', ) metric_utils.log_metrics(eval_metrics) # Save checkpoints global_step_val = global_step.numpy() if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val)
def train_eval(
    root_dir,
    env_name='cartpole',
    task_name='balance',
    observations_allowlist='position',
    eval_env_name=None,
    num_iterations=1000000,
    # Params for networks.
    actor_fc_layers=(400, 300),
    actor_output_fc_layers=(100,),
    actor_lstm_size=(40,),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(300,),
    critic_output_fc_layers=(100,),
    critic_lstm_size=(40,),
    num_parallel_environments=1,
    # Params for collect
    initial_collect_episodes=1,
    collect_episodes_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    critic_learning_rate=3e-4,
    train_sequence_length=20,
    actor_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=0.1,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for RNN SAC on DM control."""
  root_dir = os.path.expanduser(root_dir)

  summary_writer = tf.compat.v2.summary.create_file_writer(
      root_dir, flush_millis=summaries_flush_secs * 1000)
  summary_writer.set_as_default()

  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if observations_allowlist is not None:
      env_wrappers = [
          functools.partial(
              wrappers.FlattenObservationsWrapper,
              observations_allowlist=[observations_allowlist])
      ]
    else:
      env_wrappers = []

    env_load_fn = functools.partial(
        suite_dm_control.load,
        task_name=task_name,
        env_wrappers=env_wrappers)

    if num_parallel_environments == 1:
      py_env = env_load_fn(env_name)
    else:
      py_env = parallel_py_environment.ParallelPyEnvironment(
          [lambda: env_load_fn(env_name)] * num_parallel_environments)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)

    eval_env_name = eval_env_name or env_name
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        env_load_fn(eval_env_name))

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
        observation_spec,
        action_spec,
        input_fc_layer_params=actor_fc_layers,
        lstm_size=actor_lstm_size,
        output_fc_layer_params=actor_output_fc_layers,
        continuous_projection_net=tanh_normal_projection_network.
        TanhNormalProjectionNetwork)

    critic_net = critic_rnn_network.CriticRnnNetwork(
        (observation_spec, action_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        lstm_size=critic_lstm_size,
        output_fc_layer_params=critic_output_fc_layers,
        kernel_initializer='glorot_uniform',
        last_kernel_initializer='glorot_uniform')

    tf_agent = sac_agent.SacAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    env_steps = tf_metrics.EnvironmentSteps(prefix='Train')
    average_return = tf_metrics.AverageReturnMetric(
        prefix='Train',
        buffer_size=num_eval_episodes,
        batch_size=tf_env.batch_size)
    train_metrics = [
        tf_metrics.NumberOfEpisodes(prefix='Train'),
        env_steps,
        average_return,
        tf_metrics.AverageEpisodeLengthMetric(
            prefix='Train',
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size),
    ]

    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    collect_policy = tf_agent.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(root_dir, 'train'),
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(root_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(root_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer + train_metrics,
        num_episodes=initial_collect_episodes)

    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_episodes=collect_episodes_per_iteration)

    if use_tf_functions:
      initial_collect_driver.run = common.function(initial_collect_driver.run)
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    # Collect initial replay data.
    if env_steps.result() == 0 or replay_buffer.num_frames() == 0:
      logging.info(
          'Initializing replay buffer by collecting experience for %d episodes '
          'with a random policy.', initial_collect_episodes)
      initial_collect_driver.run()

    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=env_steps.result(),
        summary_writer=summary_writer,
        summary_prefix='Eval',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, env_steps.result())
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    time_acc = 0
    env_steps_before = env_steps.result().numpy()

    # Prepare replay buffer as dataset with invalid transitions filtered.
    def _filter_invalid_transition(trajectories, unused_arg1):
      # Reduce filter_fn over full trajectory sampled. The sequence is kept only
      # if all elements except for the last one pass the filter. This is to
      # allow training on terminal steps.
      return tf.reduce_all(~trajectories.is_boundary()[:-1])

    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size,
        num_steps=train_sequence_length + 1).unbatch().filter(
            _filter_invalid_transition).batch(batch_size).prefetch(5)
    # Dataset generates trajectories with shape
    # [B x (train_sequence_length + 1) x ...]
    iterator = iter(dataset)

    def train_step():
      experience, _ = next(iterator)
      return tf_agent.train(experience)

    if use_tf_functions:
      train_step = common.function(train_step)

    for _ in range(num_iterations):
      start_time = time.time()
      start_env_steps = env_steps.result()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      episode_steps = env_steps.result() - start_env_steps
      # TODO(b/152648849)
      for _ in range(episode_steps):
        for _ in range(train_steps_per_iteration):
          train_step()
      time_acc += time.time() - start_time

      if global_step.numpy() % log_interval == 0:
        logging.info('env steps = %d, average return = %f',
                     env_steps.result(), average_return.result())
        env_steps_per_sec = (env_steps.result().numpy() -
                             env_steps_before) / time_acc
        logging.info('%.3f env steps/sec', env_steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='env_steps_per_sec',
            data=env_steps_per_sec,
            step=env_steps.result())
        time_acc = 0
        env_steps_before = env_steps.result().numpy()

      for train_metric in train_metrics:
        train_metric.tf_summaries(train_step=env_steps.result())

      if global_step.numpy() % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=env_steps.result(),
            summary_writer=summary_writer,
            summary_prefix='Eval',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, env_steps.result())
        metric_utils.log_metrics(eval_metrics)

      global_step_val = global_step.numpy()
      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)
      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)
      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)
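For completeness, a minimal, assumed entry point for the RNN-SAC train_eval above, in the style of TF-Agents example binaries; the flag names and defaults here are illustrative and not part of the original snippet.

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

FLAGS = flags.FLAGS
flags.DEFINE_string('root_dir', None,
                    'Root directory for summaries and checkpoints.')
flags.DEFINE_string('env_name', 'cartpole', 'DM control domain to load.')
flags.DEFINE_string('task_name', 'balance', 'DM control task to load.')
flags.DEFINE_integer('num_iterations', 1000000,
                     'Number of training iterations.')


def main(_):
  logging.set_verbosity(logging.INFO)
  tf.compat.v1.enable_v2_behavior()
  train_eval(
      FLAGS.root_dir,
      env_name=FLAGS.env_name,
      task_name=FLAGS.task_name,
      num_iterations=FLAGS.num_iterations)


if __name__ == '__main__':
  flags.mark_flag_as_required('root_dir')
  app.run(main)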
print('Actor Network Created.')

# create SAC Agent
# https://www.tensorflow.org/agents/api_docs/python/tf_agents/agents/SacAgent
global_step = tf.compat.v1.train.get_or_create_global_step()
if shouldContinueFromLastCheckpoint:
  global_step = tf.compat.v1.train.get_global_step()

# with strategy.scope():
#   train_step = train_utils.create_train_step()

tf_agent = sac_agent.SacAgent(
    env.time_step_spec(),
    action_spec,
    actor_network=actor_net,
    critic_network=critic_net,
    actor_optimizer=tf.compat.v1.train.AdamOptimizer(
        learning_rate=actorLearningRate),
    critic_optimizer=tf.compat.v1.train.AdamOptimizer(
        learning_rate=criticLearningRate),
    alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
        learning_rate=alphaLearningRate),
    target_update_tau=target_update_tau,
    gamma=gamma,
    gradient_clipping=gradientClipping,
    train_step_counter=global_step,
)
tf_agent.initialize()
print('SAC Agent Created.')

# policies
evaluate_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
collect_policy = tf_agent.collect_policy

# metrics and evaluation
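The snippet stops at the "# metrics and evaluation" comment. A minimal sketch of how that section is often filled in with TF-Agents utilities follows; `eval_env` and `num_eval_episodes` are assumed names that do not appear in the snippet above, and this is not the original author's code.

from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics

num_eval_episodes = 10  # assumed value
eval_metrics = [
    tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
    tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
]

# Runs the greedy evaluation policy in an evaluation environment (assumed to
# be a TFPyEnvironment called `eval_env`) and logs the resulting metrics.
results = metric_utils.eager_compute(
    eval_metrics,
    eval_env,
    evaluate_policy,
    num_episodes=num_eval_episodes,
    train_step=global_step,
    summary_prefix='Eval')
metric_utils.log_metrics(eval_metrics)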
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    # Training params
    initial_collect_steps=10000,
    num_iterations=3200000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Agent params
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    gamma=0.99,
    target_update_tau=0.005,
    target_update_period=1,
    reward_scale_factor=0.1,
    # Replay params
    reverb_port=None,
    replay_capacity=1000000,
    # Others
    # Defaults to not checkpointing saved policy. If you wish to enable this,
    # please note the caveat explained in README.md.
    policy_save_interval=-1,
    eval_interval=10000,
    eval_episodes=30,
    debug_summaries=False,
    summarize_grads_and_vars=False):
  """Trains and evaluates SAC."""
  logging.info('Training SAC on: %s', env_name)

  collect_env = suite_mujoco.load(env_name)
  eval_env = suite_mujoco.load(env_name)

  observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))

  train_step = train_utils.create_train_step()

  actor_net = actor_distribution_network.ActorDistributionNetwork(
      observation_tensor_spec,
      action_tensor_spec,
      fc_layer_params=actor_fc_layers,
      continuous_projection_net=(
          tanh_normal_projection_network.TanhNormalProjectionNetwork))
  critic_net = critic_network.CriticNetwork(
      (observation_tensor_spec, action_tensor_spec),
      observation_fc_layer_params=critic_obs_fc_layers,
      action_fc_layer_params=critic_action_fc_layers,
      joint_fc_layer_params=critic_joint_fc_layers,
      kernel_initializer='glorot_uniform',
      last_kernel_initializer='glorot_uniform')

  agent = sac_agent.SacAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      actor_network=actor_net,
      critic_network=critic_net,
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=actor_learning_rate),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=critic_learning_rate),
      alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=alpha_learning_rate),
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      td_errors_loss_fn=tf.math.squared_difference,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      gradient_clipping=None,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step)
  agent.initialize()

  table_name = 'uniform_table'
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1))

  reverb_server = reverb.Server([table], port=reverb_port)

  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=table_name,
      local_server=reverb_server)
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb_replay.py_client,
      table_name,
      sequence_length=2,
      stride_length=1)

  dataset = reverb_replay.as_dataset(
      sample_batch_size=batch_size, num_steps=2).prefetch(50)
  experience_dataset_fn = lambda: dataset

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  env_step_metric = py_metrics.EnvironmentSteps()

  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
      triggers.StepPerSecondLogTrigger(train_step, interval=1000),
  ]

  agent_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers)

  random_policy = random_py_policy.RandomPyPolicy(
      collect_env.time_step_spec(), collect_env.action_spec())

  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)

  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=1,
      metrics=actor.collect_metrics(10),
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
      observers=[rb_observer, env_step_metric])

  tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
  eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)

  eval_actor = actor.Actor(
      eval_env,
      eval_greedy_policy,
      train_step,
      episodes_per_run=eval_episodes,
      metrics=actor.eval_metrics(eval_episodes),
      summary_dir=os.path.join(root_dir, 'eval'),
  )

  if eval_interval:
    logging.info('Evaluating.')
    eval_actor.run_and_log()

  logging.info('Training.')
  for _ in range(num_iterations):
    collect_actor.run()
    agent_learner.run(iterations=1)

    if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      eval_actor.run_and_log()

  rb_observer.close()
  reverb_server.stop()
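When policy_save_interval is positive, the PolicySavedModelTrigger above periodically exports SavedModel policies under saved_model_dir. A rough sketch of reloading one of those exports for inference follows, reusing the names root_dir and eval_env from the function above; the 'greedy_policy' subdirectory name is an assumption about the trigger's layout, so check the actual contents of the policies directory before relying on it.

import os

from tf_agents.policies import py_tf_eager_policy
from tf_agents.train import learner

# Hypothetical reload of an exported policy; the subdirectory name is assumed.
policy_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR,
                          'greedy_policy')
loaded_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
    policy_dir, load_specs_from_pbtxt=True)

# Roll out one episode in the (Python) evaluation environment.
time_step = eval_env.reset()
while not time_step.is_last():
  action_step = loaded_policy.action(time_step)
  time_step = eval_env.step(action_step.action)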
actor_net = actor_distribution_network.ActorDistributionNetwork(
    observation_spec,
    action_spec,
    fc_layer_params=actor_fc_layer_params,
    continuous_projection_net=normal_projection_net)

global_step = tf.compat.v1.train.get_or_create_global_step()

tf_agent = sac_agent.SacAgent(
    train_env.time_step_spec(),
    action_spec,
    actor_network=actor_net,
    critic_network=critic_net,
    actor_optimizer=tf.compat.v1.train.AdamOptimizer(
        learning_rate=actor_learning_rate),
    critic_optimizer=tf.compat.v1.train.AdamOptimizer(
        learning_rate=critic_learning_rate),
    alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
        learning_rate=alpha_learning_rate),
    target_update_tau=target_update_tau,
    target_update_period=target_update_period,
    td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
    gamma=gamma,
    reward_scale_factor=reward_scale_factor,
    gradient_clipping=gradient_clipping,
    train_step_counter=global_step)
tf_agent.initialize()

eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
collect_policy = tf_agent.collect_policy


def compute_avg_return(environment, policy, num_episodes=5):
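  # NOTE: the original snippet is cut off at this signature. The body below is
  # a minimal, assumed completion in the usual TF-Agents tutorial style
  # (average undiscounted return over num_episodes rollouts), not the original
  # author's code.
  total_return = 0.0
  for _ in range(num_episodes):
    time_step = environment.reset()
    episode_return = 0.0
    while not time_step.is_last():
      action_step = policy.action(time_step)
      time_step = environment.step(action_step.action)
      episode_return += time_step.reward
    total_return += episode_return
  avg_return = total_return / num_episodes
  # For a TFPyEnvironment the reward is a batched tensor; .numpy()[0] extracts
  # the scalar value.
  return avg_return.numpy()[0]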