def GetAgent(self, env, params):
  def _normal_projection_net(action_spec, init_means_output_factor=0.1):
    return normal_projection_network.NormalProjectionNetwork(
        action_spec,
        mean_transform=None,
        state_dependent_std=True,
        init_means_output_factor=init_means_output_factor,
        std_transform=sac_agent.std_clip_transform,
        scale_distribution=True)

  # actor network
  actor_net = actor_distribution_network.ActorDistributionNetwork(
      env.observation_spec(),
      env.action_spec(),
      fc_layer_params=tuple(
          self._params["ML"]["BehaviorSACAgent"]["ActorFcLayerParams", "",
                                                 [512, 256, 256]]),
      continuous_projection_net=_normal_projection_net)

  # critic network
  critic_net = critic_network.CriticNetwork(
      (env.observation_spec(), env.action_spec()),
      observation_fc_layer_params=None,
      action_fc_layer_params=None,
      joint_fc_layer_params=tuple(
          self._params["ML"]["BehaviorSACAgent"]["CriticJointFcLayerParams",
                                                 "", [512, 256, 256]]))

  # agent
  tf_agent = sac_agent.SacAgent(
      env.time_step_spec(),
      env.action_spec(),
      actor_network=actor_net,
      critic_network=critic_net,
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=self._params["ML"]["BehaviorSACAgent"][
              "ActorLearningRate", "", 3e-4]),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=self._params["ML"]["BehaviorSACAgent"][
              "CriticLearningRate", "", 3e-4]),
      alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=self._params["ML"]["BehaviorSACAgent"][
              "AlphaLearningRate", "", 3e-4]),
      target_update_tau=self._params["ML"]["BehaviorSACAgent"][
          "TargetUpdateTau", "", 0.05],
      target_update_period=self._params["ML"]["BehaviorSACAgent"][
          "TargetUpdatePeriod", "", 3],
      td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
      gamma=self._params["ML"]["BehaviorSACAgent"]["Gamma", "", 0.995],
      reward_scale_factor=self._params["ML"]["BehaviorSACAgent"][
          "RewardScaleFactor", "", 1.],
      train_step_counter=self._ckpt.step,
      name=self._params["ML"]["BehaviorSACAgent"]["AgentName", "", "sac_agent"],
      debug_summaries=self._params["ML"]["BehaviorSACAgent"][
          "DebugSummaries", "", False])
  tf_agent.initialize()
  return tf_agent
def verifyTrainAndRestore(self, loss_fn=None):
  """Helper function for testing correct variable updating and restoring."""
  batch_size = 2
  seq_len = 2
  observations = tensor_spec.sample_spec_nest(
      self._observation_spec, outer_dims=(batch_size, seq_len))
  actions = tensor_spec.sample_spec_nest(
      self._action_spec, outer_dims=(batch_size, seq_len))
  rewards = tf.constant([[10, 10], [20, 20]], dtype=tf.float32)
  discounts = tf.constant([[0.9, 0.9], [0.9, 0.9]], dtype=tf.float32)
  experience = trajectory.first(
      observation=observations,
      action=actions,
      policy_info=(),
      reward=rewards,
      discount=discounts)

  strategy = tf.distribute.get_strategy()
  with strategy.scope():
    q_net = critic_network.CriticNetwork(
        (self._observation_spec, self._action_spec))
    agent = qtopt_agent.QtOptAgent(
        self._time_step_spec,
        self._action_spec,
        q_network=q_net,
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        init_mean_cem=self._mean,
        init_var_cem=self._var,
        num_samples_cem=self._num_samples,
        actions_sampler=self._sampler,
        in_graph_bellman_update=True)

  loss_before_train = agent.loss(experience).loss
  # Check loss is stable.
  self.assertEqual(loss_before_train, agent.loss(experience).loss)

  # Train 1 step, verify that loss is decreased for the same input.
  agent.train(experience)
  loss_after_train = agent.loss(experience).loss
  self.assertLessEqual(loss_after_train, loss_before_train)

  # Assert loss evaluation is still stable, e.g. deterministic.
  self.assertLessEqual(loss_after_train, agent.loss(experience).loss)

  # Save checkpoint
  ckpt_dir = self.create_tempdir()
  checkpointer = common.Checkpointer(ckpt_dir=ckpt_dir, agent=agent)
  global_step = tf.constant(1)
  checkpointer.save(global_step)

  # Assign all vars to 0.
  for var in tf.nest.flatten(agent.variables):
    var.assign(tf.zeros_like(var))
  loss_after_zero = agent.loss(experience).loss
  self.assertEqual(loss_after_zero, agent.loss(experience).loss)
  self.assertNotEqual(loss_after_zero, loss_after_train)

  # Restore
  checkpointer._checkpoint.restore(checkpointer._manager.latest_checkpoint)
  loss_after_restore = agent.loss(experience).loss
  self.assertNotEqual(loss_after_restore, loss_after_zero)
  self.assertEqual(loss_after_restore, loss_after_train)
def create_critic_network(self, observation_fc_layer_params,
                          action_fc_layer_params, joint_fc_layer_params):
  critic_net_input_specs = (spec.get_observation_spec(),
                            spec.get_action_spec())
  return critic_network.CriticNetwork(
      critic_net_input_specs,
      observation_fc_layer_params=observation_fc_layer_params,
      action_fc_layer_params=action_fc_layer_params,
      joint_fc_layer_params=joint_fc_layer_params,
      name='critic_' + self.name)
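# Hedged usage sketch (not from the original source): a helper like the one
# above is typically called once per critic an algorithm needs. `builder` is a
# hypothetical instance of the class that defines create_critic_network, and
# the layer sizes below are illustrative assumptions only.
critic_net = builder.create_critic_network(
    observation_fc_layer_params=(400,),
    action_fc_layer_params=None,
    joint_fc_layer_params=(300,))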
def testInitializeAgent(self):
  q_net = critic_network.CriticNetwork(
      (self._observation_spec, self._action_spec))
  agent = qtopt_agent.QtOptAgent(
      self._time_step_spec,
      self._action_spec,
      q_network=q_net,
      optimizer=None,
      init_mean_cem=self._mean,
      init_var_cem=self._var,
      num_samples_cem=self._num_samples,
      actions_sampler=self._sampler)
  agent.initialize()
def testCreateAgent(self):
  q_net = critic_network.CriticNetwork(
      (self._observation_spec, self._action_spec))
  agent = qtopt_agent.QtOptAgent(
      self._time_step_spec,
      self._action_spec,
      q_network=q_net,
      optimizer=None,
      init_mean_cem=self._mean,
      init_var_cem=self._var,
      num_samples_cem=self._num_samples,
      actions_sampler=self._sampler)
  self.assertIsNotNone(agent.policy)
def testBuild(self):
  batch_size = 3
  num_obs_dims = 5
  num_actions_dims = 2
  obs_spec = tensor_spec.TensorSpec([num_obs_dims], tf.float32)
  action_spec = tensor_spec.TensorSpec([num_actions_dims], tf.float32)
  obs = tf.random.uniform([batch_size, num_obs_dims])
  actions = tf.random.uniform([batch_size, num_actions_dims])

  critic_net = critic_network.CriticNetwork((obs_spec, action_spec))

  q_values, _ = critic_net((obs, actions))
  self.assertAllEqual(q_values.shape.as_list(), [batch_size])
  self.assertLen(critic_net.trainable_variables, 2)
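# Standalone sketch (not from the original sources): the minimal imports and
# call pattern the tests above rely on. CriticNetwork takes an
# (observation, action) tuple and returns one Q-value per batch entry plus a
# network state.
import tensorflow as tf
from tf_agents.agents.ddpg import critic_network
from tf_agents.specs import tensor_spec

obs_spec = tensor_spec.TensorSpec([5], tf.float32)
action_spec = tensor_spec.TensorSpec([2], tf.float32)
critic_net = critic_network.CriticNetwork((obs_spec, action_spec))

obs = tf.random.uniform([3, 5])
actions = tf.random.uniform([3, 2])
q_values, _ = critic_net((obs, actions))  # q_values has shape [3]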
def testAddJointFCLayers(self):
  batch_size = 3
  num_obs_dims = 5
  num_actions_dims = 2
  obs_spec = tensor_spec.TensorSpec([num_obs_dims], tf.float32)
  action_spec = tensor_spec.TensorSpec([num_actions_dims], tf.float32)

  critic_net = critic_network.CriticNetwork(
      (obs_spec, action_spec), joint_fc_layer_params=[20])

  obs = tf.random.uniform([batch_size, num_obs_dims])
  actions = tf.random.uniform([batch_size, num_actions_dims])
  q_values, _ = critic_net((obs, actions))
  self.assertAllEqual(q_values.shape.as_list(), [batch_size])
  self.assertLen(critic_net.trainable_variables, 4)
def get_agent(self, env, params):
  """Returns a TensorFlow TD3-Agent.

  Arguments:
    env {TFAPyEnvironment} -- TensorFlow-Agents PyEnvironment
    params {ParameterServer} -- ParameterServer from BARK

  Returns:
    agent -- tf-agent
  """
  # actor network
  actor_net = actor_network.ActorNetwork(
      env.observation_spec(),
      env.action_spec(),
      fc_layer_params=tuple(
          self._params["ML"]["Agent"]["actor_fc_layer_params"]),
  )

  # critic network
  critic_net = critic_network.CriticNetwork(
      (env.observation_spec(), env.action_spec()),
      observation_fc_layer_params=None,
      action_fc_layer_params=None,
      joint_fc_layer_params=tuple(
          self._params["ML"]["Agent"]["critic_joint_fc_layer_params"]))

  # agent
  # TODO(@hart): put all parameters in config file
  tf_agent = td3_agent.Td3Agent(
      env.time_step_spec(),
      env.action_spec(),
      critic_network=critic_net,
      actor_network=actor_net,
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=self._params["ML"]["Agent"]["actor_learning_rate"]),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=self._params["ML"]["Agent"]["critic_learning_rate"]),
      debug_summaries=self._params["ML"]["Agent"]["debug_summaries"],
      train_step_counter=self._ckpt.step,
      gamma=0.99,
      target_update_tau=0.5,
      target_policy_noise_clip=0.5)
  tf_agent.initialize()
  return tf_agent
def create_critic_network(train_env):
  return critic_network.CriticNetwork(
      (train_env.observation_spec(), train_env.action_spec()),
      observation_conv_layer_params=[
          (4, (5, 1), 1),
          (4, (1, 5), 2),
          (8, (5, 1), 1),
          (8, (1, 5), 2),
          (16, (5, 1), 1),
          (16, (1, 5), 2),
          (32, (5, 1), 1),
          (32, (1, 5), 2),
      ],
      action_fc_layer_params=[128],
      joint_fc_layer_params=[128, 128],
  )
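# Note (hedged): each entry in observation_conv_layer_params is a
# (filters, kernel_size, stride) tuple, so (4, (5, 1), 1) above means 4
# filters with a 5x1 kernel and stride 1. A smaller illustrative variant
# follows; the observation/action shapes are assumptions, not taken from the
# original code.
import tensorflow as tf
from tf_agents.agents.ddpg import critic_network
from tf_agents.specs import tensor_spec

conv_critic = critic_network.CriticNetwork(
    (tensor_spec.TensorSpec([32, 32, 3], tf.float32),
     tensor_spec.TensorSpec([2], tf.float32)),
    observation_conv_layer_params=[(16, 3, 2), (32, 3, 2)],
    joint_fc_layer_params=[64])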
def testAddObsConvLayers(self):
  batch_size = 3
  num_obs_dims = 5
  num_actions_dims = 2
  obs_spec = tensor_spec.TensorSpec([3, 3, num_obs_dims], tf.float32)
  action_spec = tensor_spec.TensorSpec([num_actions_dims], tf.float32)

  critic_net = critic_network.CriticNetwork(
      (obs_spec, action_spec), observation_conv_layer_params=[(16, 3, 2)])

  obs = tf.random.uniform([batch_size, 3, 3, num_obs_dims])
  actions = tf.random.uniform([batch_size, num_actions_dims])
  q_values, _ = critic_net((obs, actions))
  self.assertAllEqual(q_values.shape.as_list(), [batch_size])
  self.assertEqual(len(critic_net.trainable_variables), 4)
def init_agent():
  """ a DDPG agent is set by default in the application"""
  # get the global step
  global_step = tf.compat.v1.train.get_or_create_global_step()

  # TODO: update this to get the optimizer from tensorflow 2.0 if possible
  optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

  time_step_spec = time_step.time_step_spec(self._rl_app.observation_spec)

  actor_net = actor_network.ActorNetwork(
      self._rl_app.observation_spec,
      self._rl_app.action_spec,
      fc_layer_params=(400, 300))

  value_net = critic_network.CriticNetwork(
      (time_step_spec.observation, self._rl_app.action_spec),
      observation_fc_layer_params=(400,),
      action_fc_layer_params=None,
      joint_fc_layer_params=(300,))

  tf_agent = ddpg_agent.DdpgAgent(
      time_step_spec,
      self._rl_app.action_spec,
      actor_network=actor_net,
      critic_network=value_net,
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=1e-4),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3),
      ou_stddev=0.2,
      ou_damping=0.15,
      target_update_tau=0.05,
      target_update_period=5,
      dqda_clipping=None,
      td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
      gamma=discount,
      reward_scale_factor=1.0,
      gradient_clipping=gradient_clipping,
      debug_summaries=True,
      summarize_grads_and_vars=True,
      train_step_counter=global_step)
  tf_agent.initialize()
  logger.info("tf_agent initialization is complete")

  # Optimize by wrapping some of the code in a graph using TF function.
  tf_agent.train = common.function(tf_agent.train)
  return tf_agent
def load_policy(agent_class, tf_env):
  load_dir = FLAGS.load_dir
  assert load_dir and osp.exists(load_dir), (
      'need to provide valid load_dir to load policy, got: {}'.format(load_dir))
  global_step = tf.compat.v1.train.get_or_create_global_step()
  time_step_spec = tf_env.time_step_spec()
  observation_spec = time_step_spec.observation
  action_spec = tf_env.action_spec()
  actor_net = actor_distribution_network.ActorDistributionNetwork(
      observation_spec,
      action_spec,
      fc_layer_params=(256, 256),
      continuous_projection_net=normal_projection_net)
  critic_net = critic_network.CriticNetwork(
      (observation_spec, action_spec), joint_fc_layer_params=(256, 256))
  tf_agent = sac_agent.SacAgent(
      time_step_spec,
      action_spec,
      actor_network=actor_net,
      critic_network=critic_net,
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=3e-4),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=3e-4),
      alpha_optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=3e-4),
      target_update_tau=0.005,
      target_update_period=1,
      td_errors_loss_fn=tf.keras.losses.mse,
      gamma=0,
      reward_scale_factor=1.,
      gradient_clipping=1.,
      debug_summaries=False,
      summarize_grads_and_vars=False,
      train_step_counter=global_step)
  train_checkpointer = common.Checkpointer(
      ckpt_dir=load_dir, agent=tf_agent, global_step=global_step)
  status = train_checkpointer.initialize_or_restore()
  status.expect_partial()
  logging.info('Loaded from checkpoint: %s, trained %s steps',
               train_checkpointer._manager.latest_checkpoint,
               global_step.numpy())
  return tf_agent.policy
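# Hedged usage sketch (not part of the original source): the restored policy
# can be rolled out against a TF environment. `tf_env` is assumed to be the
# same tf_py_environment.TFPyEnvironment the specs were taken from, and the
# number of steps is arbitrary.
policy = load_policy(sac_agent.SacAgent, tf_env)
time_step = tf_env.reset()
for _ in range(10):
    action_step = policy.action(time_step)
    time_step = tf_env.step(action_step.action)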
def testDropoutJointFCLayers(self, training):
  batch_size = 3
  num_obs_dims = 5
  num_actions_dims = 2
  obs_spec = tensor_spec.TensorSpec([num_obs_dims], tf.float32)
  action_spec = tensor_spec.TensorSpec([num_actions_dims], tf.float32)

  critic_net = critic_network.CriticNetwork(
      (obs_spec, action_spec),
      joint_fc_layer_params=[20],
      joint_dropout_layer_params=[0.5])

  obs = tf.random.uniform([batch_size, num_obs_dims])
  actions = tf.random.uniform([batch_size, num_actions_dims])
  q_values1, _ = critic_net((obs, actions), training=training)
  q_values2, _ = critic_net((obs, actions), training=training)
  self.evaluate(tf.compat.v1.global_variables_initializer())
  q_values1, q_values2 = self.evaluate([q_values1, q_values2])
  if training:
    self.assertGreater(np.linalg.norm(q_values1 - q_values2), 0)
  else:
    self.assertAllEqual(q_values1, q_values2)
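# Note (hedged): joint_dropout_layer_params is parallel to
# joint_fc_layer_params, one dropout rate per joint fully connected layer, and
# dropout only fires when the network is called with training=True, which is
# why the test above expects identical Q-values when training=False.
# Illustrative construction; the specs and sizes are assumptions:
import tensorflow as tf
from tf_agents.agents.ddpg import critic_network
from tf_agents.specs import tensor_spec

dropout_critic = critic_network.CriticNetwork(
    (tensor_spec.TensorSpec([5], tf.float32),
     tensor_spec.TensorSpec([2], tf.float32)),
    joint_fc_layer_params=[64, 64],
    joint_dropout_layer_params=[0.2, 0.2])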
def _create_agent(train_step: tf.Variable,
                  observation_tensor_spec: types.NestedTensorSpec,
                  action_tensor_spec: types.NestedTensorSpec,
                  time_step_tensor_spec: ts.TimeStep,
                  learning_rate: float) -> tf_agent.TFAgent:
  """Creates an agent."""
  critic_net = critic_network.CriticNetwork(
      (observation_tensor_spec, action_tensor_spec),
      observation_fc_layer_params=None,
      action_fc_layer_params=None,
      joint_fc_layer_params=(256, 256),
      kernel_initializer='glorot_uniform',
      last_kernel_initializer='glorot_uniform')

  actor_net = actor_distribution_network.ActorDistributionNetwork(
      observation_tensor_spec,
      action_tensor_spec,
      fc_layer_params=(256, 256),
      continuous_projection_net=tanh_normal_projection_network
      .TanhNormalProjectionNetwork)

  return sac_agent.SacAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      actor_network=actor_net,
      critic_network=critic_net,
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=learning_rate),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=learning_rate),
      alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=learning_rate),
      target_update_tau=0.005,
      target_update_period=1,
      td_errors_loss_fn=tf.math.squared_difference,
      gamma=0.99,
      reward_scale_factor=0.1,
      gradient_clipping=None,
      train_step_counter=train_step)
def ACnetworks(environment, hyperparams) -> (actor_network, critic_network):
  observation_spec = environment.observation_spec()
  action_spec = environment.action_spec()

  actor_net = actor_network.ActorNetwork(
      input_tensor_spec=observation_spec,
      output_tensor_spec=action_spec,
      fc_layer_params=hyperparams['actor_fc_layer_params'],
      dropout_layer_params=hyperparams['actor_dropout'],
      activation_fn=tf.nn.relu)

  critic_net = critic_network.CriticNetwork(
      input_tensor_spec=(observation_spec, action_spec),
      observation_fc_layer_params=hyperparams['critic_obs_fc_layer_params'],
      action_fc_layer_params=hyperparams['critic_action_fc_layer_params'],
      joint_fc_layer_params=hyperparams['critic_joint_fc_layer_params'],
      joint_dropout_layer_params=hyperparams['critic_joint_dropout'],
      activation_fn=tf.nn.relu)

  return (actor_net, critic_net)
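# Hedged example of the hyperparameter dictionary ACnetworks expects. Only the
# keys are taken from the function body; the concrete layer sizes and the
# `environment` instance are illustrative assumptions.
hyperparams = {
    'actor_fc_layer_params': (256, 256),
    'actor_dropout': None,
    'critic_obs_fc_layer_params': (400,),
    'critic_action_fc_layer_params': None,
    'critic_joint_fc_layer_params': (300,),
    'critic_joint_dropout': None,
}
actor_net, critic_net = ACnetworks(environment, hyperparams)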
def verifyVariableAssignAndRestore(self, loss_fn=None):
  strategy = tf.distribute.get_strategy()
  with strategy.scope():
    # Use BehaviorCloningAgent instead of AWRAgent to test the network.
    q_net = critic_network.CriticNetwork(
        (self._observation_spec, self._action_spec))
    agent = qtopt_agent.QtOptAgent(
        self._time_step_spec,
        self._action_spec,
        q_network=q_net,
        optimizer=None,
        init_mean_cem=self._mean,
        init_var_cem=self._var,
        num_samples_cem=self._num_samples,
        actions_sampler=self._sampler)

  # Assign all vars to 0.
  for var in tf.nest.flatten(agent.variables):
    var.assign(tf.zeros_like(var))

  # Save checkpoint
  ckpt_dir = self.create_tempdir()
  checkpointer = common.Checkpointer(ckpt_dir=ckpt_dir, agent=agent)
  global_step = tf.constant(0)
  checkpointer.save(global_step)

  # Assign all vars to 1.
  for var in tf.nest.flatten(agent.variables):
    var.assign(tf.ones_like(var))

  # Restore to 0.
  checkpointer._checkpoint.restore(checkpointer._manager.latest_checkpoint)
  for var in tf.nest.flatten(agent.variables):
    value = var.numpy()
    if isinstance(value, np.int64):
      self.assertEqual(value, 0)
    else:
      self.assertAllEqual(
          value,
          np.zeros_like(value),
          msg='{} has var mean {}, expected 0.'.format(var.name, value))
def testPolicy(self):
  q_net = critic_network.CriticNetwork(
      (self._observation_spec, self._action_spec))
  agent = qtopt_agent.QtOptAgent(
      self._time_step_spec,
      self._action_spec,
      q_network=q_net,
      optimizer=None,
      init_mean_cem=self._mean,
      init_var_cem=self._var,
      num_samples_cem=self._num_samples,
      actions_sampler=self._sampler)
  observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
  time_steps = ts.restart(observations, batch_size=2)
  policy = agent.policy
  action_step = policy.action(time_steps)
  # Batch size 2.
  self.assertAllEqual(
      [2] + self._action_spec.shape.as_list(),
      action_step.action.shape,
  )
  self.evaluate(tf.compat.v1.initialize_all_variables())
  actions_ = self.evaluate(action_step.action)
  self.assertTrue(all(actions_ <= self._action_spec.maximum))
  self.assertTrue(all(actions_ >= self._action_spec.minimum))
def _create_agent(agent: DuelAgent, train_step) -> SacAgent:
  observation_spec, action_spec, time_step_spec = spec_utils.get_tensor_specs(
      agent._collect_env)

  critic_net = critic_network.CriticNetwork(
      (observation_spec, action_spec),
      observation_fc_layer_params=None,
      action_fc_layer_params=None,
      joint_fc_layer_params=agent._critic_joint_fc_layer_params,
  )

  actor_net = actor_distribution_network.ActorDistributionNetwork(
      observation_spec,
      action_spec,
      fc_layer_params=agent._actor_fc_layer_params,
      continuous_projection_net=TanhNormalProjectionNetwork)

  tf_agent = SacAgent(
      time_step_spec,
      action_spec,
      actor_network=actor_net,
      critic_network=critic_net,
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=agent._actor_learning_rate),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=agent._critic_learning_rate),
      alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=agent._alpha_learning_rate),
      target_update_tau=agent._target_update_tau,
      target_update_period=agent._target_update_period,
      td_errors_loss_fn=tf.math.squared_difference,
      gamma=agent._gamma,
      reward_scale_factor=agent._reward_scale_factor,
      train_step_counter=train_step)
  tf_agent.initialize()
  return tf_agent
print(env.time_step_spec().observation)
print('Action Spec:')
print(env.action_spec())

collect_env = get_tf_wrapped_robo_rugby_env()
eval_env = get_tf_wrapped_robo_rugby_env()

objStrategy = strategy_utils.get_strategy(tpu=False, use_gpu=True)

specObservation, specAction, specTimeStep = (
    spec_utils.get_tensor_specs(collect_env))

with objStrategy.scope():
    # Critic network trains the Actor network
    nnCritic = critic_network.CriticNetwork(
        (specObservation, specAction),
        observation_fc_layer_params=None,
        action_fc_layer_params=None,
        joint_fc_layer_params=HyperParms.critic_joint_fc_layer_params,
        kernel_initializer='glorot_uniform',
        last_kernel_initializer='glorot_uniform')

with objStrategy.scope():
    nnActor = actor_distribution_network.ActorDistributionNetwork(
        specObservation,
        specAction,
        fc_layer_params=HyperParms.actor_fc_layer_params,
        continuous_projection_net=(
            tanh_normal_projection_network.TanhNormalProjectionNetwork))

with objStrategy.scope():
    train_step = train_utils.create_train_step()
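# Hedged continuation sketch (not from the original listing): the networks
# above are typically combined into a SacAgent inside the same strategy scope,
# mirroring the standard TF-Agents SAC tutorial layout. Every HyperParms field
# used below is an assumption for illustration.
with objStrategy.scope():
    tf_agent = sac_agent.SacAgent(
        specTimeStep,
        specAction,
        actor_network=nnActor,
        critic_network=nnCritic,
        actor_optimizer=tf.keras.optimizers.Adam(
            learning_rate=HyperParms.actor_learning_rate),
        critic_optimizer=tf.keras.optimizers.Adam(
            learning_rate=HyperParms.critic_learning_rate),
        alpha_optimizer=tf.keras.optimizers.Adam(
            learning_rate=HyperParms.alpha_learning_rate),
        target_update_tau=HyperParms.target_update_tau,
        target_update_period=HyperParms.target_update_period,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=HyperParms.gamma,
        reward_scale_factor=HyperParms.reward_scale_factor,
        train_step_counter=train_step)
    tf_agent.initialize()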
def train_eval( root_dir, env_name='HalfCheetah-v2', num_iterations=2000000, actor_fc_layers=(400, 300), critic_obs_fc_layers=(400, ), critic_action_fc_layers=None, critic_joint_fc_layers=(300, ), # Params for collect initial_collect_steps=1000, collect_steps_per_iteration=1, replay_buffer_capacity=100000, exploration_noise_std=0.1, # Params for target update target_update_tau=0.05, target_update_period=5, # Params for train train_steps_per_iteration=1, batch_size=64, actor_update_period=2, actor_learning_rate=1e-4, critic_learning_rate=1e-3, dqda_clipping=None, td_errors_loss_fn=tf.compat.v1.losses.huber_loss, gamma=0.995, reward_scale_factor=1.0, gradient_clipping=None, use_tf_functions=True, # Params for eval num_eval_episodes=10, eval_interval=10000, # Params for checkpoints, summaries, and logging log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for TD3.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes) ] global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name)) eval_tf_env = tf_py_environment.TFPyEnvironment( suite_mujoco.load(env_name)) actor_net = actor_network.ActorNetwork( tf_env.time_step_spec().observation, tf_env.action_spec(), fc_layer_params=actor_fc_layers, ) critic_net_input_specs = (tf_env.time_step_spec().observation, tf_env.action_spec()) critic_net = critic_network.CriticNetwork( critic_net_input_specs, observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, ) tf_agent = td3_agent.Td3Agent( tf_env.time_step_spec(), tf_env.action_spec(), actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), exploration_noise_std=exploration_noise_std, target_update_tau=target_update_tau, target_update_period=target_update_period, actor_update_period=actor_update_period, dqda_clipping=dqda_clipping, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step, ) tf_agent.initialize() train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_metrics.AverageReturnMetric(), tf_metrics.AverageEpisodeLengthMetric(), ] eval_policy = tf_agent.policy collect_policy = tf_agent.collect_policy replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( tf_agent.collect_data_spec, batch_size=tf_env.batch_size, max_length=replay_buffer_capacity) initial_collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, 
observers=[replay_buffer.add_batch], num_steps=initial_collect_steps) collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=[replay_buffer.add_batch] + train_metrics, num_steps=collect_steps_per_iteration) if use_tf_functions: initial_collect_driver.run = common.function( initial_collect_driver.run) collect_driver.run = common.function(collect_driver.run) tf_agent.train = common.function(tf_agent.train) # Collect initial replay data. logging.info( 'Initializing replay buffer by collecting experience for %d steps with ' 'a random policy.', initial_collect_steps) initial_collect_driver.run() results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) time_step = None policy_state = collect_policy.get_initial_state(tf_env.batch_size) timed_at_step = global_step.numpy() time_acc = 0 # Dataset generates trajectories with shape [Bx2x...] dataset = replay_buffer.as_dataset(num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3) iterator = iter(dataset) def train_step(): experience, _ = next(iterator) return tf_agent.train(experience) if use_tf_functions: train_step = common.function(train_step) for _ in range(num_iterations): start_time = time.time() time_step, policy_state = collect_driver.run( time_step=time_step, policy_state=policy_state, ) for _ in range(train_steps_per_iteration): train_loss = train_step() time_acc += time.time() - start_time if global_step.numpy() % log_interval == 0: logging.info('step = %d, loss = %f', global_step.numpy(), train_loss.loss) steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) tf.compat.v2.summary.scalar(name='global_steps_per_sec', data=steps_per_sec, step=global_step) timed_at_step = global_step.numpy() time_acc = 0 for train_metric in train_metrics: train_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2]) if global_step.numpy() % eval_interval == 0: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) return train_loss
def train_eval( root_dir, env_name='HalfCheetah-v2', num_iterations=1000000, actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), # Params for collect initial_collect_steps=10000, collect_steps_per_iteration=1, replay_buffer_capacity=1000000, # Params for target update target_update_tau=0.005, target_update_period=1, # Params for train train_steps_per_iteration=1, batch_size=256, actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error, gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, # Params for eval num_eval_episodes=30, eval_interval=10000, # Params for summaries and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=50000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for SAC.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes), ] eval_summary_flush_op = eval_summary_writer.flush() global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): # Create the environment. tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name)) eval_py_env = suite_mujoco.load(env_name) # Get the data specs from the environment time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() actor_net = actor_distribution_network.ActorDistributionNetwork( observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=normal_projection_net) critic_net = critic_network.CriticNetwork( (observation_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers) tf_agent = sac_agent.SacAgent( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) # Make the replay buffer. 
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=1, max_length=replay_buffer_capacity) replay_observer = [replay_buffer.add_batch] eval_py_policy = py_tf_policy.PyTFPolicy( greedy_policy.GreedyPolicy(tf_agent.policy)) train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()), tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()), ] collect_policy = tf_agent.collect_policy initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) initial_collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=replay_observer + train_metrics, num_steps=initial_collect_steps).run() collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_steps=collect_steps_per_iteration).run() # Prepare replay buffer as dataset with invalid transitions filtered. def _filter_invalid_transition(trajectories, unused_arg1): return ~trajectories.is_boundary()[0] dataset = replay_buffer.as_dataset( sample_batch_size=5 * batch_size, num_steps=2).apply(tf.data.experimental.unbatch()).filter( _filter_invalid_transition).batch(batch_size).prefetch( batch_size * 5) dataset_iterator = tf.compat.v1.data.make_initializable_iterator( dataset) trajectories, unused_info = dataset_iterator.get_next() train_op = tf_agent.train(trajectories) summary_ops = [] for train_metric in train_metrics: summary_ops.append( train_metric.tf_summaries(train_step=global_step, step_metrics=train_metrics[:2])) with eval_summary_writer.as_default(), \ tf.compat.v2.summary.record_if(True): for eval_metric in eval_metrics: eval_metric.tf_summaries(train_step=global_step) train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'policy'), policy=tf_agent.policy, global_step=global_step) rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) with tf.compat.v1.Session() as sess: # Initialize graph. train_checkpointer.initialize_or_restore(sess) rb_checkpointer.initialize_or_restore(sess) # Initialize training. sess.run(dataset_iterator.initializer) common.initialize_uninitialized_variables(sess) sess.run(train_summary_writer.init()) sess.run(eval_summary_writer.init()) global_step_val = sess.run(global_step) if global_step_val == 0: # Initial eval of randomly initialized policy metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, log=True, ) sess.run(eval_summary_flush_op) # Run initial collect. logging.info('Global step %d: Running initial collect op.', global_step_val) sess.run(initial_collect_op) # Checkpoint the initial replay buffer contents. 
rb_checkpointer.save(global_step=global_step_val) logging.info('Finished initial collect.') else: logging.info('Global step %d: Skipping initial collect op.', global_step_val) collect_call = sess.make_callable(collect_op) train_step_call = sess.make_callable([train_op, summary_ops]) global_step_call = sess.make_callable(global_step) timed_at_step = global_step_call() time_acc = 0 steps_per_second_ph = tf.compat.v1.placeholder( tf.float32, shape=(), name='steps_per_sec_ph') steps_per_second_summary = tf.compat.v2.summary.scalar( name='global_steps_per_sec', data=steps_per_second_ph, step=global_step) for _ in range(num_iterations): start_time = time.time() collect_call() for _ in range(train_steps_per_iteration): total_loss, _ = train_step_call() time_acc += time.time() - start_time global_step_val = global_step_call() if global_step_val % log_interval == 0: logging.info('step = %d, loss = %f', global_step_val, total_loss.loss) steps_per_sec = (global_step_val - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) sess.run(steps_per_second_summary, feed_dict={steps_per_second_ph: steps_per_sec}) timed_at_step = global_step_val time_acc = 0 if global_step_val % eval_interval == 0: metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, log=True, ) sess.run(eval_summary_flush_op) if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val)
def train_eval( root_dir, experiment_name, # experiment name env_name='carla-v0', agent_name='sac', # agent's name num_iterations=int(1e7), actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), model_network_ctor_type='non-hierarchical', # model net input_names=['camera', 'lidar'], # names for inputs mask_names=['birdeye'], # names for masks preprocessing_combiner=tf.keras.layers.Add( ), # takes a flat list of tensors and combines them actor_lstm_size=(40, ), # lstm size for actor critic_lstm_size=(40, ), # lstm size for critic actor_output_fc_layers=(100, ), # lstm output critic_output_fc_layers=(100, ), # lstm output epsilon_greedy=0.1, # exploration parameter for DQN q_learning_rate=1e-3, # q learning rate for DQN ou_stddev=0.2, # exploration paprameter for DDPG ou_damping=0.15, # exploration parameter for DDPG dqda_clipping=None, # for DDPG exploration_noise_std=0.1, # exploration paramter for td3 actor_update_period=2, # for td3 # Params for collect initial_collect_steps=1000, collect_steps_per_iteration=1, replay_buffer_capacity=int(1e5), # Params for target update target_update_tau=0.005, target_update_period=1, # Params for train train_steps_per_iteration=1, initial_model_train_steps=100000, # initial model training batch_size=256, model_batch_size=32, # model training batch size sequence_length=4, # number of timesteps to train model actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, model_learning_rate=1e-4, # learning rate for model training td_errors_loss_fn=tf.losses.mean_squared_error, gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, # Params for eval num_eval_episodes=10, eval_interval=10000, # Params for summaries and logging num_images_per_summary=1, # images for each summary train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=50000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, gpu_allow_growth=True, # GPU memory growth gpu_memory_limit=None, # GPU memory limit action_repeat=1 ): # Name of single observation channel, ['camera', 'lidar', 'birdeye'] # Setup GPU gpus = tf.config.experimental.list_physical_devices('GPU') if gpu_allow_growth: for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True) if gpu_memory_limit: for gpu in gpus: tf.config.experimental.set_virtual_device_configuration( gpu, [ tf.config.experimental.VirtualDeviceConfiguration( memory_limit=gpu_memory_limit) ]) # Get train and eval directories root_dir = os.path.expanduser(root_dir) root_dir = os.path.join(root_dir, env_name, experiment_name) # Get summary writers summary_writer = tf.summary.create_file_writer( root_dir, flush_millis=summaries_flush_secs * 1000) summary_writer.set_as_default() # Eval metrics eval_metrics = [ tf_metrics.AverageReturnMetric(name='AverageReturnEvalPolicy', buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric( name='AverageEpisodeLengthEvalPolicy', buffer_size=num_eval_episodes), ] global_step = tf.compat.v1.train.get_or_create_global_step() # Whether to record for summary with tf.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): # Create Carla environment if agent_name == 'latent_sac': py_env, eval_py_env = load_carla_env(env_name='carla-v0', obs_channels=input_names + mask_names, action_repeat=action_repeat) elif agent_name == 'dqn': py_env, eval_py_env = load_carla_env(env_name='carla-v0', 
discrete=True, obs_channels=input_names, action_repeat=action_repeat) else: py_env, eval_py_env = load_carla_env(env_name='carla-v0', obs_channels=input_names, action_repeat=action_repeat) tf_env = tf_py_environment.TFPyEnvironment(py_env) eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env) fps = int(np.round(1.0 / (py_env.dt * action_repeat))) # Specs time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() ## Make tf agent if agent_name == 'latent_sac': # Get model network for latent sac if model_network_ctor_type == 'hierarchical': model_network_ctor = sequential_latent_network.SequentialLatentModelHierarchical elif model_network_ctor_type == 'non-hierarchical': model_network_ctor = sequential_latent_network.SequentialLatentModelNonHierarchical else: raise NotImplementedError model_net = model_network_ctor(input_names, input_names + mask_names) # Get the latent spec latent_size = model_net.latent_size latent_observation_spec = tensor_spec.TensorSpec((latent_size, ), dtype=tf.float32) latent_time_step_spec = ts.time_step_spec( observation_spec=latent_observation_spec) # Get actor and critic net actor_net = actor_distribution_network.ActorDistributionNetwork( latent_observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=normal_projection_net) critic_net = critic_network.CriticNetwork( (latent_observation_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers) # Build the inner SAC agent based on latent space inner_agent = sac_agent.SacAgent( latent_time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) inner_agent.initialize() # Build the latent sac agent tf_agent = latent_sac_agent.LatentSACAgent( time_step_spec, action_spec, inner_agent=inner_agent, model_network=model_net, model_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=model_learning_rate), model_batch_size=model_batch_size, num_images_per_summary=num_images_per_summary, sequence_length=sequence_length, gradient_clipping=gradient_clipping, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step, fps=fps) else: # Set up preprosessing layers for dictionary observation inputs preprocessing_layers = collections.OrderedDict() for name in input_names: preprocessing_layers[name] = Preprocessing_Layer(32, 256) if len(input_names) < 2: preprocessing_combiner = None if agent_name == 'dqn': q_rnn_net = q_rnn_network.QRnnNetwork( observation_spec, action_spec, preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, input_fc_layer_params=critic_joint_fc_layers, lstm_size=critic_lstm_size, output_fc_layer_params=critic_output_fc_layers) tf_agent = dqn_agent.DqnAgent( time_step_spec, action_spec, q_network=q_rnn_net, epsilon_greedy=epsilon_greedy, n_step_update=1, 
target_update_tau=target_update_tau, target_update_period=target_update_period, optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=q_learning_rate), td_errors_loss_fn=common.element_wise_squared_loss, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) elif agent_name == 'ddpg' or agent_name == 'td3': actor_rnn_net = multi_inputs_actor_rnn_network.MultiInputsActorRnnNetwork( observation_spec, action_spec, preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, input_fc_layer_params=actor_fc_layers, lstm_size=actor_lstm_size, output_fc_layer_params=actor_output_fc_layers) critic_rnn_net = multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork( (observation_spec, action_spec), preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, lstm_size=critic_lstm_size, output_fc_layer_params=critic_output_fc_layers) if agent_name == 'ddpg': tf_agent = ddpg_agent.DdpgAgent( time_step_spec, action_spec, actor_network=actor_rnn_net, critic_network=critic_rnn_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), ou_stddev=ou_stddev, ou_damping=ou_damping, target_update_tau=target_update_tau, target_update_period=target_update_period, dqda_clipping=dqda_clipping, td_errors_loss_fn=None, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars) elif agent_name == 'td3': tf_agent = td3_agent.Td3Agent( time_step_spec, action_spec, actor_network=actor_rnn_net, critic_network=critic_rnn_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), exploration_noise_std=exploration_noise_std, target_update_tau=target_update_tau, target_update_period=target_update_period, actor_update_period=actor_update_period, dqda_clipping=dqda_clipping, td_errors_loss_fn=None, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) elif agent_name == 'sac': actor_distribution_rnn_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork( observation_spec, action_spec, preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, input_fc_layer_params=actor_fc_layers, lstm_size=actor_lstm_size, output_fc_layer_params=actor_output_fc_layers, continuous_projection_net=normal_projection_net) critic_rnn_net = multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork( (observation_spec, action_spec), preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, lstm_size=critic_lstm_size, output_fc_layer_params=critic_output_fc_layers) tf_agent = sac_agent.SacAgent( time_step_spec, action_spec, actor_network=actor_distribution_rnn_net, critic_network=critic_rnn_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), 
critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=tf.math. squared_difference, # make critic loss dimension compatible gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) else: raise NotImplementedError tf_agent.initialize() # Get replay buffer replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=1, # No parallel environments max_length=replay_buffer_capacity) replay_observer = [replay_buffer.add_batch] # Train metrics env_steps = tf_metrics.EnvironmentSteps() average_return = tf_metrics.AverageReturnMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size) train_metrics = [ tf_metrics.NumberOfEpisodes(), env_steps, average_return, tf_metrics.AverageEpisodeLengthMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size), ] # Get policies # eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy) eval_policy = tf_agent.policy initial_collect_policy = random_tf_policy.RandomTFPolicy( time_step_spec, action_spec) collect_policy = tf_agent.collect_policy # Checkpointers train_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(root_dir, 'train'), agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'), max_to_keep=2) policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( root_dir, 'policy'), policy=eval_policy, global_step=global_step, max_to_keep=2) rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( root_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) train_checkpointer.initialize_or_restore() rb_checkpointer.initialize_or_restore() # Collect driver initial_collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=replay_observer + train_metrics, num_steps=initial_collect_steps) collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_steps=collect_steps_per_iteration) # Optimize the performance by using tf functions initial_collect_driver.run = common.function( initial_collect_driver.run) collect_driver.run = common.function(collect_driver.run) tf_agent.train = common.function(tf_agent.train) # Collect initial replay data. if (env_steps.result() == 0 or replay_buffer.num_frames() == 0): logging.info( 'Initializing replay buffer by collecting experience for %d steps' 'with a random policy.', initial_collect_steps) initial_collect_driver.run() if agent_name == 'latent_sac': compute_summaries(eval_metrics, eval_tf_env, eval_policy, train_step=global_step, summary_writer=summary_writer, num_episodes=1, num_episodes_to_render=1, model_net=model_net, fps=10, image_keys=input_names + mask_names) else: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=1, train_step=env_steps.result(), summary_writer=summary_writer, summary_prefix='Eval', ) metric_utils.log_metrics(eval_metrics) # Dataset generates trajectories with shape [Bxslx...] 
dataset = replay_buffer.as_dataset(num_parallel_calls=3, sample_batch_size=batch_size, num_steps=sequence_length + 1).prefetch(3) iterator = iter(dataset) # Get train step def train_step(): experience, _ = next(iterator) return tf_agent.train(experience) train_step = common.function(train_step) if agent_name == 'latent_sac': def train_model_step(): experience, _ = next(iterator) return tf_agent.train_model(experience) train_model_step = common.function(train_model_step) # Training initializations time_step = None time_acc = 0 env_steps_before = env_steps.result().numpy() # Start training for iteration in range(num_iterations): start_time = time.time() if agent_name == 'latent_sac' and iteration < initial_model_train_steps: train_model_step() else: # Run collect time_step, _ = collect_driver.run(time_step=time_step) # Train an iteration for _ in range(train_steps_per_iteration): train_step() time_acc += time.time() - start_time # Log training information if global_step.numpy() % log_interval == 0: logging.info('env steps = %d, average return = %f', env_steps.result(), average_return.result()) env_steps_per_sec = (env_steps.result().numpy() - env_steps_before) / time_acc logging.info('%.3f env steps/sec', env_steps_per_sec) tf.summary.scalar(name='env_steps_per_sec', data=env_steps_per_sec, step=env_steps.result()) time_acc = 0 env_steps_before = env_steps.result().numpy() # Get training metrics for train_metric in train_metrics: train_metric.tf_summaries(train_step=env_steps.result()) # Evaluation if global_step.numpy() % eval_interval == 0: # Log evaluation metrics if agent_name == 'latent_sac': compute_summaries( eval_metrics, eval_tf_env, eval_policy, train_step=global_step, summary_writer=summary_writer, num_episodes=num_eval_episodes, num_episodes_to_render=num_images_per_summary, model_net=model_net, fps=10, image_keys=input_names + mask_names) else: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=env_steps.result(), summary_writer=summary_writer, summary_prefix='Eval', ) metric_utils.log_metrics(eval_metrics) # Save checkpoints global_step_val = global_step.numpy() if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val)
def train_eval( root_dir, env_name='HalfCheetah-v2', num_iterations=2000000, actor_fc_layers=(400, 300), critic_obs_fc_layers=(400,), critic_action_fc_layers=None, critic_joint_fc_layers=(300,), # Params for collect initial_collect_steps=1000, collect_steps_per_iteration=1, replay_buffer_capacity=100000, exploration_noise_std=0.1, # Params for target update target_update_tau=0.05, target_update_period=5, # Params for train train_steps_per_iteration=1, batch_size=64, actor_update_period=2, actor_learning_rate=1e-4, critic_learning_rate=1e-3, dqda_clipping=None, td_errors_loss_fn=tf.compat.v1.losses.huber_loss, gamma=0.995, reward_scale_factor=1.0, gradient_clipping=None, # Params for eval num_eval_episodes=10, eval_interval=10000, # Params for checkpoints, summaries, and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=20000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for TD3.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes), ] global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name)) eval_py_env = suite_mujoco.load(env_name) actor_net = actor_network.ActorNetwork( tf_env.time_step_spec().observation, tf_env.action_spec(), fc_layer_params=actor_fc_layers, ) critic_net_input_specs = (tf_env.time_step_spec().observation, tf_env.action_spec()) critic_net = critic_network.CriticNetwork( critic_net_input_specs, observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, ) tf_agent = td3_agent.Td3Agent( tf_env.time_step_spec(), tf_env.action_spec(), actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), exploration_noise_std=exploration_noise_std, target_update_tau=target_update_tau, target_update_period=target_update_period, actor_update_period=actor_update_period, dqda_clipping=dqda_clipping, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step, ) replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( tf_agent.collect_data_spec, batch_size=tf_env.batch_size, max_length=replay_buffer_capacity) eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy) train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_metrics.AverageReturnMetric(), tf_metrics.AverageEpisodeLengthMetric(), ] collect_policy = tf_agent.collect_policy initial_collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, 
collect_policy, observers=[replay_buffer.add_batch] + train_metrics, num_steps=initial_collect_steps).run() collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=[replay_buffer.add_batch] + train_metrics, num_steps=collect_steps_per_iteration).run() dataset = replay_buffer.as_dataset( num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3) iterator = tf.compat.v1.data.make_initializable_iterator(dataset) trajectories, unused_info = iterator.get_next() train_fn = common.function(tf_agent.train) train_op = train_fn(experience=trajectories) train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'policy'), policy=tf_agent.policy, global_step=global_step) rb_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) summary_ops = [] for train_metric in train_metrics: summary_ops.append(train_metric.tf_summaries( train_step=global_step, step_metrics=train_metrics[:2])) with eval_summary_writer.as_default(), \ tf.compat.v2.summary.record_if(True): for eval_metric in eval_metrics: eval_metric.tf_summaries(train_step=global_step) init_agent_op = tf_agent.initialize() with tf.compat.v1.Session() as sess: # Initialize the graph. train_checkpointer.initialize_or_restore(sess) rb_checkpointer.initialize_or_restore(sess) sess.run(iterator.initializer) # TODO(b/126239733): Remove once Periodically can be saved. common.initialize_uninitialized_variables(sess) sess.run(init_agent_op) sess.run(train_summary_writer.init()) sess.run(eval_summary_writer.init()) sess.run(initial_collect_op) global_step_val = sess.run(global_step) metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, log=True, ) collect_call = sess.make_callable(collect_op) train_step_call = sess.make_callable([train_op, summary_ops, global_step]) timed_at_step = sess.run(global_step) time_acc = 0 steps_per_second_ph = tf.compat.v1.placeholder( tf.float32, shape=(), name='steps_per_sec_ph') steps_per_second_summary = tf.compat.v2.summary.scalar( name='global_steps_per_sec', data=steps_per_second_ph, step=global_step) for _ in range(num_iterations): start_time = time.time() collect_call() for _ in range(train_steps_per_iteration): loss_info_value, _, global_step_val = train_step_call() time_acc += time.time() - start_time if global_step_val % log_interval == 0: logging.info('step = %d, loss = %f', global_step_val, loss_info_value.loss) steps_per_sec = (global_step_val - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) sess.run( steps_per_second_summary, feed_dict={steps_per_second_ph: steps_per_sec}) timed_at_step = global_step_val time_acc = 0 if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val) if global_step_val % eval_interval == 0: metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, log=True, )
eval_py_env = StockEnv(eval_states, discrete=False, delay=delay, eval=True)
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

from tf_agents.agents.ddpg import ddpg_agent, actor_network, critic_network
from TrainAndEvaluate import train_and_evaluate_ACagent

actor_net = CustomActorNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    preprocessing_layers=preprocessing_layers,
    fc_layer_params=actor_fc_layer_params)

critic_net = critic_network.CriticNetwork(
    (train_env.observation_spec(), train_env.action_spec()),
    observation_fc_layer_params=None,
    action_fc_layer_params=None,
    joint_fc_layer_params=critic_joint_fc_layer_params)

tf_agent = td3_agent.Td3Agent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    actor_network=actor_net,
    critic_network=critic_net,
    actor_optimizer=tf.compat.v1.train.AdamOptimizer(
        learning_rate=actor_learning_rate),
    critic_optimizer=tf.compat.v1.train.AdamOptimizer(
        learning_rate=critic_learning_rate),
    target_update_tau=target_update_tau,
    target_update_period=target_update_period,
    td_errors_loss_fn=loss_fn,
def train_eval( root_dir, env_name='HalfCheetah-v1', env_load_fn=suite_mujoco.load, num_iterations=2000000, actor_fc_layers=(400, 300), critic_obs_fc_layers=(400,), critic_action_fc_layers=None, critic_joint_fc_layers=(300,), # Params for collect initial_collect_steps=1000, collect_steps_per_iteration=1, num_parallel_environments=1, replay_buffer_capacity=100000, ou_stddev=0.2, ou_damping=0.15, # Params for target update target_update_tau=0.05, target_update_period=5, # Params for train train_steps_per_iteration=1, batch_size=64, actor_learning_rate=1e-4, critic_learning_rate=1e-3, dqda_clipping=None, td_errors_loss_fn=tf.losses.huber_loss, gamma=0.995, reward_scale_factor=1.0, gradient_clipping=None, # Params for eval num_eval_episodes=10, eval_interval=10000, # Params for checkpoints, summaries, and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=20000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for DDPG.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.contrib.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.contrib.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes), ] # TODO(kbanoop): Figure out if it is possible to avoid the with block. with tf.contrib.summary.record_summaries_every_n_global_steps( summary_interval): if num_parallel_environments > 1: tf_env = tf_py_environment.TFPyEnvironment( parallel_py_environment.ParallelPyEnvironment( [lambda: env_load_fn(env_name)] * num_parallel_environments)) else: tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name)) eval_py_env = env_load_fn(env_name) actor_net = actor_network.ActorNetwork( tf_env.time_step_spec().observation, tf_env.action_spec(), fc_layer_params=actor_fc_layers, ) critic_net_input_specs = (tf_env.time_step_spec().observation, tf_env.action_spec()) critic_net = critic_network.CriticNetwork( critic_net_input_specs, observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, ) tf_agent = ddpg_agent.DdpgAgent( tf_env.time_step_spec(), tf_env.action_spec(), actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.train.AdamOptimizer( learning_rate=critic_learning_rate), ou_stddev=ou_stddev, ou_damping=ou_damping, target_update_tau=target_update_tau, target_update_period=target_update_period, dqda_clipping=dqda_clipping, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars) replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( tf_agent.collect_data_spec(), batch_size=tf_env.batch_size, max_length=replay_buffer_capacity) eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy()) train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_metrics.AverageReturnMetric(), 
tf_metrics.AverageEpisodeLengthMetric(), ] global_step = tf.train.get_or_create_global_step() collect_policy = tf_agent.collect_policy() initial_collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=[replay_buffer.add_batch], num_steps=initial_collect_steps).run() collect_op = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=[replay_buffer.add_batch] + train_metrics, num_steps=collect_steps_per_iteration).run() # Dataset generates trajectories with shape [Bx2x...] dataset = replay_buffer.as_dataset( num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3) iterator = dataset.make_initializable_iterator() trajectories, unused_info = iterator.get_next() train_op = tf_agent.train( experience=trajectories, train_step_counter=global_step) train_checkpointer = common_utils.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=tf.contrib.checkpoint.List(train_metrics)) policy_checkpointer = common_utils.Checkpointer( ckpt_dir=os.path.join(train_dir, 'policy'), policy=tf_agent.policy(), global_step=global_step) rb_checkpointer = common_utils.Checkpointer( ckpt_dir=os.path.join(train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) for train_metric in train_metrics: train_metric.tf_summaries(step_metrics=train_metrics[:2]) summary_op = tf.contrib.summary.all_summary_ops() with eval_summary_writer.as_default(), \ tf.contrib.summary.always_record_summaries(): for eval_metric in eval_metrics: eval_metric.tf_summaries() init_agent_op = tf_agent.initialize() with tf.Session() as sess: # Initialize the graph. train_checkpointer.initialize_or_restore(sess) rb_checkpointer.initialize_or_restore(sess) sess.run(iterator.initializer) # TODO(sguada) Remove once Periodically can be saved. 
common_utils.initialize_uninitialized_variables(sess) sess.run(init_agent_op) tf.contrib.summary.initialize(session=sess) sess.run(initial_collect_op) global_step_val = sess.run(global_step) metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, ) collect_call = sess.make_callable(collect_op) train_step_call = sess.make_callable([train_op, summary_op, global_step]) timed_at_step = sess.run(global_step) time_acc = 0 steps_per_second_ph = tf.placeholder( tf.float32, shape=(), name='steps_per_sec_ph') steps_per_second_summary = tf.contrib.summary.scalar( name='global_steps/sec', tensor=steps_per_second_ph) for _ in range(num_iterations): start_time = time.time() collect_call() for _ in range(train_steps_per_iteration): loss_info_value, _, global_step_val = train_step_call() time_acc += time.time() - start_time if global_step_val % log_interval == 0: tf.logging.info('step = %d, loss = %f', global_step_val, loss_info_value.loss) steps_per_sec = (global_step_val - timed_at_step) / time_acc tf.logging.info('%.3f steps/sec' % steps_per_sec) sess.run( steps_per_second_summary, feed_dict={steps_per_second_ph: steps_per_sec}) timed_at_step = global_step_val time_acc = 0 if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val) if global_step_val % eval_interval == 0: metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, )
def DDPG_Bipedal(root_dir):
    # Setting up directories for results
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train' + '/' + str(run_id))
    eval_dir = os.path.join(root_dir, 'eval' + '/' + str(run_id))
    vid_dir = os.path.join(root_dir, 'vid' + '/' + str(run_id))

    # Set up summary writers for training and evaluation
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        # Metric to record average return
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        # Metric to record average episode length
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    # Create global step
    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Load environment with different wrappers
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))
        eval_py_env = suite_gym.load(env_name)

        # Define actor network
        actorNN = actor_network.ActorNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=(400, 300),
        )
        # Define critic network
        NN_input_specs = (tf_env.time_step_spec().observation,
                          tf_env.action_spec())
        criticNN = critic_network.CriticNetwork(
            NN_input_specs,
            observation_fc_layer_params=(400,),
            action_fc_layer_params=None,
            joint_fc_layer_params=(300,),
        )

        # Define & initialize DDPG agent
        agent = ddpg_agent.DdpgAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actorNN,
            critic_network=criticNN,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
            gamma=gamma,
            train_step_counter=global_step)
        agent.initialize()

        # Determine which train metrics to display with summary writer
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        # Set policies for evaluation and initial collection
        eval_policy = agent.policy  # Actor policy
        collect_policy = agent.collect_policy  # Actor policy with OU noise

        # Set up replay buffer
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        # Driver for initial replay buffer filling
        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,  # Initializes with random parameters
            observers=[replay_buffer.add_batch],
            num_steps=initial_collect_steps)
        # Collect driver for collect steps per iteration
        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            agent.train = common.function(agent.train)

        # Take initial_collect_steps random steps in tf_env and save them in
        # the replay buffer.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps '
            'with a random policy.', initial_collect_steps)
        initial_collect_driver.run()

        # Compute evaluation metrics
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)
        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset outputs steps in batches of 64
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=64,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            # Get experience from the dataset (replay buffer)
            experience, _ = next(iterator)
            # Train the agent on that experience
            return agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()  # Get start time
            # Collect data for the replay buffer
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            # Train on experience
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='iterations_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                metric_utils.log_metrics(eval_metrics)
                # Record a video once average return reaches the solve threshold
                if results['AverageReturn'].numpy() >= 230.0:
                    video_score = create_video(video_dir=vid_dir,
                                               env_name="BipedalWalker-v2",
                                               vid_policy=eval_policy,
                                               video_id=global_step.numpy())

        return train_loss
def train_eval( root_dir, environment_name="broken_reacher", num_iterations=1000000, actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), initial_collect_steps=10000, real_initial_collect_steps=10000, collect_steps_per_iteration=1, real_collect_interval=10, replay_buffer_capacity=1000000, # Params for target update target_update_tau=0.005, target_update_period=1, # Params for train train_steps_per_iteration=1, batch_size=256, actor_learning_rate=3e-4, critic_learning_rate=3e-4, classifier_learning_rate=3e-4, alpha_learning_rate=3e-4, td_errors_loss_fn=tf.math.squared_difference, gamma=0.99, reward_scale_factor=0.1, gradient_clipping=None, use_tf_functions=True, # Params for eval num_eval_episodes=30, eval_interval=10000, # Params for summaries and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=50000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=True, summarize_grads_and_vars=False, train_on_real=False, delta_r_warmup=0, random_seed=0, checkpoint_dir=None, ): """A simple train and eval for SAC.""" np.random.seed(random_seed) tf.random.set_seed(random_seed) root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, "train") eval_dir = os.path.join(root_dir, "eval") train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) if environment_name == "broken_reacher": get_env_fn = darc_envs.get_broken_reacher_env elif environment_name == "half_cheetah_obstacle": get_env_fn = darc_envs.get_half_cheetah_direction_env elif environment_name == "inverted_pendulum": get_env_fn = darc_envs.get_inverted_pendulum_env elif environment_name.startswith("broken_joint"): base_name = environment_name.split("broken_joint_")[1] get_env_fn = functools.partial(darc_envs.get_broken_joint_env, env_name=base_name) elif environment_name.startswith("falling"): base_name = environment_name.split("falling_")[1] get_env_fn = functools.partial(darc_envs.get_falling_env, env_name=base_name) else: raise NotImplementedError("Unknown environment: %s" % environment_name) eval_name_list = ["sim", "real"] eval_env_list = [get_env_fn(mode) for mode in eval_name_list] eval_metrics_list = [] for name in eval_name_list: eval_metrics_list.append([ tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes, name="AverageReturn_%s" % name), ]) global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): tf_env_real = get_env_fn("real") if train_on_real: tf_env = get_env_fn("real") else: tf_env = get_env_fn("sim") time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() actor_net = actor_distribution_network.ActorDistributionNetwork( observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=( tanh_normal_projection_network.TanhNormalProjectionNetwork), ) critic_net = critic_network.CriticNetwork( (observation_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, kernel_initializer="glorot_uniform", last_kernel_initializer="glorot_uniform", ) classifier = 
classifiers.build_classifier(observation_spec, action_spec) tf_agent = darc_agent.DarcAgent( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, classifier=classifier, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), classifier_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=classifier_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step, ) tf_agent.initialize() # Make the replay buffer. replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=1, max_length=replay_buffer_capacity, ) replay_observer = [replay_buffer.add_batch] real_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=1, max_length=replay_buffer_capacity, ) real_replay_observer = [real_replay_buffer.add_batch] sim_train_metrics = [ tf_metrics.NumberOfEpisodes(name="NumberOfEpisodesSim"), tf_metrics.EnvironmentSteps(name="EnvironmentStepsSim"), tf_metrics.AverageReturnMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size, name="AverageReturnSim", ), tf_metrics.AverageEpisodeLengthMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size, name="AverageEpisodeLengthSim", ), ] real_train_metrics = [ tf_metrics.NumberOfEpisodes(name="NumberOfEpisodesReal"), tf_metrics.EnvironmentSteps(name="EnvironmentStepsReal"), tf_metrics.AverageReturnMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size, name="AverageReturnReal", ), tf_metrics.AverageEpisodeLengthMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size, name="AverageEpisodeLengthReal", ), ] eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy) initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) collect_policy = tf_agent.collect_policy train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup( sim_train_metrics + real_train_metrics, "train_metrics"), ) policy_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, "policy"), policy=eval_policy, global_step=global_step, ) rb_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, "replay_buffer"), max_to_keep=1, replay_buffer=(replay_buffer, real_replay_buffer), ) if checkpoint_dir is not None: checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir) assert checkpoint_path is not None train_checkpointer._load_status = train_checkpointer._checkpoint.restore( # pylint: disable=protected-access checkpoint_path) train_checkpointer._load_status.initialize_or_restore() # pylint: disable=protected-access else: train_checkpointer.initialize_or_restore() rb_checkpointer.initialize_or_restore() if replay_buffer.num_frames() == 0: initial_collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=replay_observer + sim_train_metrics, num_steps=initial_collect_steps, ) real_initial_collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env_real, 
initial_collect_policy, observers=real_replay_observer + real_train_metrics, num_steps=real_initial_collect_steps, ) collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + sim_train_metrics, num_steps=collect_steps_per_iteration, ) real_collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env_real, collect_policy, observers=real_replay_observer + real_train_metrics, num_steps=collect_steps_per_iteration, ) config_str = gin.operative_config_str() logging.info(config_str) with tf.compat.v1.gfile.Open(os.path.join(root_dir, "operative.gin"), "w") as f: f.write(config_str) if use_tf_functions: initial_collect_driver.run = common.function( initial_collect_driver.run) real_initial_collect_driver.run = common.function( real_initial_collect_driver.run) collect_driver.run = common.function(collect_driver.run) real_collect_driver.run = common.function(real_collect_driver.run) tf_agent.train = common.function(tf_agent.train) # Collect initial replay data. if replay_buffer.num_frames() == 0: logging.info( "Initializing replay buffer by collecting experience for %d steps with " "a random policy.", initial_collect_steps, ) initial_collect_driver.run() real_initial_collect_driver.run() for eval_name, eval_env, eval_metrics in zip(eval_name_list, eval_env_list, eval_metrics_list): metric_utils.eager_compute( eval_metrics, eval_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix="Metrics-%s" % eval_name, ) metric_utils.log_metrics(eval_metrics) time_step = None real_time_step = None policy_state = collect_policy.get_initial_state(tf_env.batch_size) timed_at_step = global_step.numpy() time_acc = 0 # Prepare replay buffer as dataset with invalid transitions filtered. def _filter_invalid_transition(trajectories, unused_arg1): return ~trajectories.is_boundary()[0] dataset = (replay_buffer.as_dataset( sample_batch_size=batch_size, num_steps=2).unbatch().filter( _filter_invalid_transition).batch(batch_size).prefetch(5)) real_dataset = (real_replay_buffer.as_dataset( sample_batch_size=batch_size, num_steps=2).unbatch().filter( _filter_invalid_transition).batch(batch_size).prefetch(5)) # Dataset generates trajectories with shape [Bx2x...] iterator = iter(dataset) real_iterator = iter(real_dataset) def train_step(): experience, _ = next(iterator) real_experience, _ = next(real_iterator) return tf_agent.train(experience, real_experience=real_experience) if use_tf_functions: train_step = common.function(train_step) for _ in range(num_iterations): start_time = time.time() time_step, policy_state = collect_driver.run( time_step=time_step, policy_state=policy_state, ) assert not policy_state # We expect policy_state == (). 
if (global_step.numpy() % real_collect_interval == 0 and global_step.numpy() >= delta_r_warmup): real_time_step, policy_state = real_collect_driver.run( time_step=real_time_step, policy_state=policy_state, ) for _ in range(train_steps_per_iteration): train_loss = train_step() time_acc += time.time() - start_time global_step_val = global_step.numpy() if global_step_val % log_interval == 0: logging.info("step = %d, loss = %f", global_step_val, train_loss.loss) steps_per_sec = (global_step_val - timed_at_step) / time_acc logging.info("%.3f steps/sec", steps_per_sec) tf.compat.v2.summary.scalar(name="global_steps_per_sec", data=steps_per_sec, step=global_step) timed_at_step = global_step_val time_acc = 0 for train_metric in sim_train_metrics: train_metric.tf_summaries(train_step=global_step, step_metrics=sim_train_metrics[:2]) for train_metric in real_train_metrics: train_metric.tf_summaries(train_step=global_step, step_metrics=real_train_metrics[:2]) if global_step_val % eval_interval == 0: for eval_name, eval_env, eval_metrics in zip( eval_name_list, eval_env_list, eval_metrics_list): metric_utils.eager_compute( eval_metrics, eval_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix="Metrics-%s" % eval_name, ) metric_utils.log_metrics(eval_metrics) if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val) return train_loss
def train_eval( root_dir, env_name='HalfCheetah-v2', num_iterations=1000000, actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), # Params for collect initial_collect_steps=10000, collect_steps_per_iteration=1, replay_buffer_capacity=1000000, # Params for target update target_update_tau=0.005, target_update_period=1, # Params for train train_steps_per_iteration=1, batch_size=256, actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error, gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, use_tf_functions=True, # Params for eval num_eval_episodes=30, eval_interval=10000, # Params for summaries and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=50000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for SAC.""" root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes) ] global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name)) eval_tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name)) time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() actor_net = actor_distribution_network.ActorDistributionNetwork( observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=normal_projection_net) critic_net = critic_network.CriticNetwork( (observation_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers) tf_agent = sac_agent.SacAgent( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) tf_agent.initialize() # Make the replay buffer. 
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=1, max_length=replay_buffer_capacity) replay_observer = [replay_buffer.add_batch] train_metrics = [ tf_metrics.NumberOfEpisodes(), tf_metrics.EnvironmentSteps(), tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()), tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()), ] eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy) initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) collect_policy = tf_agent.collect_policy train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'policy'), policy=eval_policy, global_step=global_step) rb_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) train_checkpointer.initialize_or_restore() rb_checkpointer.initialize_or_restore() initial_collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=replay_observer, num_steps=initial_collect_steps) collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_steps=collect_steps_per_iteration) if use_tf_functions: initial_collect_driver.run = common.function(initial_collect_driver.run) collect_driver.run = common.function(collect_driver.run) tf_agent.train = common.function(tf_agent.train) # Collect initial replay data. logging.info( 'Initializing replay buffer by collecting experience for %d steps with ' 'a random policy.', initial_collect_steps) initial_collect_driver.run() results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) time_step = None policy_state = collect_policy.get_initial_state(tf_env.batch_size) timed_at_step = global_step.numpy() time_acc = 0 # Dataset generates trajectories with shape [Bx2x...] 
dataset = replay_buffer.as_dataset( num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3) iterator = iter(dataset) for _ in range(num_iterations): start_time = time.time() time_step, policy_state = collect_driver.run( time_step=time_step, policy_state=policy_state, ) for _ in range(train_steps_per_iteration): experience, _ = next(iterator) train_loss = tf_agent.train(experience) time_acc += time.time() - start_time if global_step.numpy() % log_interval == 0: logging.info('step = %d, loss = %f', global_step.numpy(), train_loss.loss) steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc logging.info('%.3f steps/sec', steps_per_sec) tf.compat.v2.summary.scalar( name='global_steps_per_sec', data=steps_per_sec, step=global_step) timed_at_step = global_step.numpy() time_acc = 0 for train_metric in train_metrics: train_metric.tf_summaries( train_step=global_step, step_metrics=train_metrics[:2]) if global_step.numpy() % eval_interval == 0: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) if eval_metrics_callback is not None: eval_metrics_callback(results, global_step.numpy()) metric_utils.log_metrics(eval_metrics) global_step_val = global_step.numpy() if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val) return train_loss
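# ---------------------------------------------------------------------------
# Hypothetical launcher (not part of the original snippet): a minimal sketch
# of how a train_eval function like the one above is typically invoked in
# TF-Agents examples, assuming absl is available; the flag name is
# illustrative only.
# ---------------------------------------------------------------------------
from absl import app
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_string('root_dir', None,
                    'Root directory for summaries and checkpoints.')


def main(_):
    # The train_eval above relies on eager/TF2 behavior (iter(dataset),
    # global_step.numpy()), so enable it before training.
    tf.compat.v1.enable_v2_behavior()
    train_eval(FLAGS.root_dir)


if __name__ == '__main__':
    flags.mark_flag_as_required('root_dir')
    app.run(main)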
def train_eval( root_dir, env_name='HalfCheetah-v2', # Training params initial_collect_steps=10000, num_iterations=3200000, actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), # Agent params batch_size=256, actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, gamma=0.99, target_update_tau=0.005, target_update_period=1, reward_scale_factor=0.1, # Replay params reverb_port=None, replay_capacity=1000000, # Others # Defaults to not checkpointing saved policy. If you wish to enable this, # please note the caveat explained in README.md. policy_save_interval=-1, eval_interval=10000, eval_episodes=30, debug_summaries=False, summarize_grads_and_vars=False): """Trains and evaluates SAC.""" logging.info('Training SAC on: %s', env_name) collect_env = suite_mujoco.load(env_name) eval_env = suite_mujoco.load(env_name) observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = ( spec_utils.get_tensor_specs(collect_env)) train_step = train_utils.create_train_step() actor_net = actor_distribution_network.ActorDistributionNetwork( observation_tensor_spec, action_tensor_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=tanh_normal_projection_network. TanhNormalProjectionNetwork) critic_net = critic_network.CriticNetwork( (observation_tensor_spec, action_tensor_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, kernel_initializer='glorot_uniform', last_kernel_initializer='glorot_uniform') agent = sac_agent.SacAgent( time_step_tensor_spec, action_tensor_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=tf.math.squared_difference, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=None, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=train_step) agent.initialize() table_name = 'uniform_table' table = reverb.Table(table_name, max_size=replay_capacity, sampler=reverb.selectors.Uniform(), remover=reverb.selectors.Fifo(), rate_limiter=reverb.rate_limiters.MinSize(1)) reverb_server = reverb.Server([table], port=reverb_port) reverb_replay = reverb_replay_buffer.ReverbReplayBuffer( agent.collect_data_spec, sequence_length=2, table_name=table_name, local_server=reverb_server) rb_observer = reverb_utils.ReverbAddTrajectoryObserver( reverb_replay.py_client, table_name, sequence_length=2, stride_length=1) dataset = reverb_replay.as_dataset(sample_batch_size=batch_size, num_steps=2).prefetch(50) experience_dataset_fn = lambda: dataset saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR) env_step_metric = py_metrics.EnvironmentSteps() learning_triggers = [ triggers.PolicySavedModelTrigger( saved_model_dir, agent, train_step, interval=policy_save_interval, metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}), triggers.StepPerSecondLogTrigger(train_step, interval=1000), ] agent_learner = learner.Learner(root_dir, train_step, agent, experience_dataset_fn, triggers=learning_triggers) random_policy = random_py_policy.RandomPyPolicy( 
collect_env.time_step_spec(), collect_env.action_spec()) initial_collect_actor = actor.Actor(collect_env, random_policy, train_step, steps_per_run=initial_collect_steps, observers=[rb_observer]) logging.info('Doing initial collect.') initial_collect_actor.run() tf_collect_policy = agent.collect_policy collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy, use_tf_function=True) collect_actor = actor.Actor(collect_env, collect_policy, train_step, steps_per_run=1, metrics=actor.collect_metrics(10), summary_dir=os.path.join( root_dir, learner.TRAIN_DIR), observers=[rb_observer, env_step_metric]) tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy) eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy( tf_greedy_policy, use_tf_function=True) eval_actor = actor.Actor( eval_env, eval_greedy_policy, train_step, episodes_per_run=eval_episodes, metrics=actor.eval_metrics(eval_episodes), summary_dir=os.path.join(root_dir, 'eval'), ) if eval_interval: logging.info('Evaluating.') eval_actor.run_and_log() logging.info('Training.') for _ in range(num_iterations): collect_actor.run() agent_learner.run(iterations=1) if eval_interval and agent_learner.train_step_numpy % eval_interval == 0: logging.info('Evaluating.') eval_actor.run_and_log() rb_observer.close() reverb_server.stop()
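# ---------------------------------------------------------------------------
# Hypothetical follow-up (not part of the original snippet): a sketch of
# reloading a policy exported by the PolicySavedModelTrigger above for a quick
# greedy rollout. It reuses the names root_dir and eval_env from the function
# above; the 'greedy_policy' sub-directory is an assumption about the
# saved-model layout, so adjust the path to whatever the trigger wrote.
# ---------------------------------------------------------------------------
from tf_agents.policies import py_tf_eager_policy

policy_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR,
                          'greedy_policy')
saved_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
    policy_dir, load_specs_from_pbtxt=True)

# Roll out one episode in the Python eval environment with the loaded policy.
time_step = eval_env.reset()
while not time_step.is_last():
    action_step = saved_policy.action(time_step)
    time_step = eval_env.step(action_step.action)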
log_interval = 5000  # @param {type:"integer"}

num_eval_episodes = 30  # @param {type:"integer"}
eval_interval = 10000  # @param {type:"integer"}

train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

observation_spec = train_env.observation_spec()
action_spec = train_env.action_spec()

critic_net = critic_network.CriticNetwork(
    (observation_spec, action_spec),
    observation_fc_layer_params=None,
    action_fc_layer_params=None,
    joint_fc_layer_params=critic_joint_fc_layer_params)


def normal_projection_net(action_spec, init_means_output_factor=0.1):
    return normal_projection_network.NormalProjectionNetwork(
        action_spec,
        mean_transform=None,
        state_dependent_std=True,
        init_means_output_factor=init_means_output_factor,
        std_transform=sac_agent.std_clip_transform,
        scale_distribution=True)


actor_net = actor_distribution_network.ActorDistributionNetwork(