def testBuilds(self):
  observation_spec = tensor_spec.BoundedTensorSpec((8, 8, 3), tf.float32, 0, 1)
  time_step_spec = ts.time_step_spec(observation_spec)
  time_step = tensor_spec.sample_spec_nest(time_step_spec, outer_dims=(1,))

  action_spec = [
      tensor_spec.BoundedTensorSpec((2,), tf.float32, 2, 3),
      tensor_spec.BoundedTensorSpec((3,), tf.int32, 0, 3)
  ]

  net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
      observation_spec,
      action_spec,
      conv_layer_params=[(4, 2, 2)],
      input_fc_layer_params=(5,),
      output_fc_layer_params=(5,),
      lstm_size=(3,))

  action_distributions, network_state = net(
      time_step.observation, time_step.step_type,
      net.get_initial_state(batch_size=1))
  self.evaluate(tf.compat.v1.global_variables_initializer())

  self.assertEqual([1, 2], action_distributions[0].mode().shape.as_list())
  self.assertEqual([1, 3], action_distributions[1].mode().shape.as_list())

  self.assertEqual(14, len(net.variables))
  # Conv Net Kernel
  self.assertEqual((2, 2, 3, 4), net.variables[0].shape)
  # Conv Net bias
  self.assertEqual((4,), net.variables[1].shape)
  # Fc Kernel
  self.assertEqual((64, 5), net.variables[2].shape)
  # Fc Bias
  self.assertEqual((5,), net.variables[3].shape)
  # LSTM Cell Kernel
  self.assertEqual((5, 12), net.variables[4].shape)
  # LSTM Cell Recurrent Kernel
  self.assertEqual((3, 12), net.variables[5].shape)
  # LSTM Cell Bias
  self.assertEqual((12,), net.variables[6].shape)
  # Fc Kernel
  self.assertEqual((3, 5), net.variables[7].shape)
  # Fc Bias
  self.assertEqual((5,), net.variables[8].shape)
  # Normal Projection Kernel
  self.assertEqual((5, 2), net.variables[9].shape)
  # Normal Projection Bias
  self.assertEqual((2,), net.variables[10].shape)
  # Normal Projection STD Bias layer
  self.assertEqual((2,), net.variables[11].shape)
  # Categorical Projection Kernel
  self.assertEqual((5, 12), net.variables[12].shape)
  # Categorical Projection Bias
  self.assertEqual((12,), net.variables[13].shape)

  # Assert LSTM cell is created.
  self.assertEqual((1, 3), network_state[0].shape)
  self.assertEqual((1, 3), network_state[1].shape)
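A note on the LSTM shape assertions above: a Keras LSTM cell stacks its four gate matrices along the last axis, so its kernels end in 4 * units (here 4 * 3 = 12), and the categorical projection is (5, 12) because the 3 int32 actions each take 4 values (bounds 0..3). A standalone check of the gate arithmetic with plain Keras, independent of the test:

# Verify the 4-gates-per-unit layout of a Keras LSTM cell.
cell = tf.keras.layers.LSTMCell(3)
cell.build((None, 5))  # 5 input features, matching the 5-unit input fc layer
assert cell.kernel.shape == (5, 12)            # (input_dim, 4 * units)
assert cell.recurrent_kernel.shape == (3, 12)  # (units, 4 * units)
assert cell.bias.shape == (12,)                # (4 * units,)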
def testRunsWithLstmStack(self):
  observation_spec = tensor_spec.BoundedTensorSpec((8, 8, 3), tf.float32, 0, 1)
  time_step_spec = ts.time_step_spec(observation_spec)
  time_step = tensor_spec.sample_spec_nest(time_step_spec, outer_dims=(1, 5))

  action_spec = [
      tensor_spec.BoundedTensorSpec((2,), tf.float32, 2, 3),
      tensor_spec.BoundedTensorSpec((3,), tf.int32, 0, 3)
  ]

  net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
      observation_spec,
      action_spec,
      conv_layer_params=[(4, 2, 2)],
      input_fc_layer_params=(5,),
      output_fc_layer_params=(5,),
      lstm_size=(3, 3))

  initial_state = actor_policy.ActorPolicy(
      time_step_spec, action_spec, net).get_initial_state(1)
  net_call = net(time_step.observation, time_step.step_type, initial_state)

  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.evaluate(tf.nest.map_structure(lambda d: d.sample(), net_call[0]))
def testHandlePreprocessingLayers(self):
  observation_spec = (tensor_spec.TensorSpec([1], tf.float32),
                      tensor_spec.TensorSpec([], tf.float32))
  time_step_spec = ts.time_step_spec(observation_spec)
  time_step = tensor_spec.sample_spec_nest(time_step_spec, outer_dims=(3, 4))

  action_spec = [
      tensor_spec.BoundedTensorSpec((2,), tf.float32, 2, 3),
      tensor_spec.BoundedTensorSpec((3,), tf.int32, 0, 3)
  ]

  preprocessing_layers = (
      tf.keras.layers.Dense(4),
      sequential_layer.SequentialLayer([
          tf.keras.layers.Reshape((1,)),
          tf.keras.layers.Dense(4)
      ]))

  net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
      observation_spec,
      action_spec,
      preprocessing_layers=preprocessing_layers,
      preprocessing_combiner=tf.keras.layers.Add())

  initial_state = actor_policy.ActorPolicy(
      time_step_spec, action_spec, net).get_initial_state(3)
  action_distributions, _ = net(
      time_step.observation, time_step.step_type, initial_state)
  self.evaluate(tf.compat.v1.global_variables_initializer())

  self.assertEqual([3, 4, 2], action_distributions[0].mode().shape.as_list())
  self.assertEqual([3, 4, 3], action_distributions[1].mode().shape.as_list())
  self.assertGreater(len(net.trainable_variables), 4)
def testTrainWithRnn(self):
  actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
      self._obs_spec,
      self._action_spec,
      input_fc_layer_params=None,
      output_fc_layer_params=None,
      conv_layer_params=None,
      lstm_size=(40,),
  )

  critic_net = critic_rnn_network.CriticRnnNetwork(
      (self._obs_spec, self._action_spec),
      observation_fc_layer_params=(16,),
      action_fc_layer_params=(16,),
      joint_fc_layer_params=(16,),
      lstm_size=(16,),
      output_fc_layer_params=None,
  )

  counter = common.create_variable('test_train_counter')
  optimizer_fn = tf.compat.v1.train.AdamOptimizer

  agent = sac_agent.SacAgent(
      self._time_step_spec,
      self._action_spec,
      critic_network=critic_net,
      actor_network=actor_net,
      actor_optimizer=optimizer_fn(1e-3),
      critic_optimizer=optimizer_fn(1e-3),
      alpha_optimizer=optimizer_fn(1e-3),
      train_step_counter=counter,
  )

  batch_size = 5
  observations = tf.constant(
      [[[1, 2], [3, 4], [5, 6]]] * batch_size, dtype=tf.float32)
  actions = tf.constant([[[0], [1], [1]]] * batch_size, dtype=tf.float32)
  time_steps = ts.TimeStep(
      step_type=tf.constant([[1] * 3] * batch_size, dtype=tf.int32),
      reward=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
      discount=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
      observation=[observations])

  experience = trajectory.Trajectory(
      time_steps.step_type, [observations], actions, (),
      time_steps.step_type, time_steps.reward, time_steps.discount)

  # Force variable creation.
  agent.policy.variables()

  if tf.executing_eagerly():
    loss = lambda: agent.train(experience)
  else:
    loss = agent.train(experience)

  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.assertEqual(self.evaluate(counter), 0)
  self.evaluate(loss)
  self.assertEqual(self.evaluate(counter), 1)
def testRNNTrain(self, compute_value_and_advantage_in_train):
  actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
      self._time_step_spec.observation,
      self._action_spec,
      input_fc_layer_params=None,
      output_fc_layer_params=None,
      lstm_size=(20,))
  value_net = value_rnn_network.ValueRnnNetwork(
      self._time_step_spec.observation,
      input_fc_layer_params=None,
      output_fc_layer_params=None,
      lstm_size=(10,))
  global_step = tf.compat.v1.train.get_or_create_global_step()
  agent = ppo_agent.PPOAgent(
      self._time_step_spec,
      self._action_spec,
      optimizer=tf.compat.v1.train.AdamOptimizer(),
      actor_net=actor_net,
      value_net=value_net,
      num_epochs=1,
      train_step_counter=global_step,
      compute_value_and_advantage_in_train=(
          compute_value_and_advantage_in_train))

  # Use a random env, policy, and replay buffer to collect training data.
  random_env = random_tf_environment.RandomTFEnvironment(
      self._time_step_spec, self._action_spec, batch_size=1)
  collection_policy = random_tf_policy.RandomTFPolicy(
      self._time_step_spec,
      self._action_spec,
      info_spec=agent.collect_policy.info_spec)
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      collection_policy.trajectory_spec, batch_size=1, max_length=7)
  collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
      random_env,
      collection_policy,
      observers=[replay_buffer.add_batch],
      num_episodes=1)

  # In graph mode: finish building the graph so the optimizer variables
  # are created.
  if not tf.executing_eagerly():
    _, _ = agent.train(experience=replay_buffer.gather_all())

  # Initialize.
  self.evaluate(agent.initialize())
  self.evaluate(tf.compat.v1.global_variables_initializer())

  # Train one step.
  self.assertEqual(0, self.evaluate(global_step))
  self.evaluate(collect_driver.run())
  self.evaluate(agent.train(experience=replay_buffer.gather_all()))
  self.assertEqual(1, self.evaluate(global_step))
def testTrainWithRnn(self):
  with tf.compat.v2.summary.record_if(False):
    actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
        self._obs_spec,
        self._action_spec,
        input_fc_layer_params=None,
        output_fc_layer_params=None,
        conv_layer_params=None,
        lstm_size=(40,))

    counter = common.create_variable('test_train_counter')

    agent = reinforce_agent.ReinforceAgent(
        self._time_step_spec,
        self._action_spec,
        actor_network=actor_net,
        optimizer=tf.compat.v1.train.AdamOptimizer(0.001),
        train_step_counter=counter)

    batch_size = 5
    observations = tf.constant(
        [[[1, 2], [3, 4], [5, 6]]] * batch_size, dtype=tf.float32)
    time_steps = ts.TimeStep(
        step_type=tf.constant([[1, 1, 2]] * batch_size, dtype=tf.int32),
        reward=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
        discount=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
        observation=observations)
    actions = tf.constant([[[0], [1], [1]]] * batch_size, dtype=tf.float32)

    experience = trajectory.Trajectory(
        time_steps.step_type, observations, actions, (),
        time_steps.step_type, time_steps.reward, time_steps.discount)

    # Force variable creation.
    agent.policy.variables()

    if tf.executing_eagerly():
      loss = lambda: agent.train(experience)
    else:
      loss = agent.train(experience)

    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.assertEqual(self.evaluate(counter), 0)
    self.evaluate(loss)
    self.assertEqual(self.evaluate(counter), 1)
def testTrainWithRnnTransitions(self):
  actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
      self._obs_spec,
      self._action_spec,
      input_fc_layer_params=None,
      output_fc_layer_params=None,
      conv_layer_params=None,
      lstm_size=(40,))

  counter = common.create_variable('test_train_counter')

  agent = reinforce_agent.ReinforceAgent(
      self._time_step_spec,
      self._action_spec,
      actor_network=actor_net,
      optimizer=tf.compat.v1.train.AdamOptimizer(0.001),
      train_step_counter=counter)

  batch_size = 5
  observations = tf.constant(
      [[[1, 2], [3, 4], [5, 6]]] * batch_size, dtype=tf.float32)
  time_steps = ts.TimeStep(
      step_type=tf.constant([[1, 1, 1]] * batch_size, dtype=tf.int32),
      reward=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
      discount=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
      observation=observations)
  actions = policy_step.PolicyStep(
      tf.constant([[[0], [1], [1]]] * batch_size, dtype=tf.float32))
  next_time_steps = ts.TimeStep(
      step_type=tf.constant([[1, 1, 2]] * batch_size, dtype=tf.int32),
      reward=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
      discount=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
      observation=observations)

  experience = trajectory.Transition(time_steps, actions, next_time_steps)

  agent.initialize()
  agent.train(experience)
def load_agents_and_create_videos(
    root_dir,
    env_name='CartPole-v0',
    env_load_fn=suite_gym.load,
    random_seed=None,
    max_ep_steps=1000,
    # TODO(b/127576522): rename to policy_fc_layers.
    actor_fc_layers=(200, 100),
    value_fc_layers=(200, 100),
    use_rnns=False,
    # Params for collect
    num_environment_steps=5000000,
    collect_episodes_per_iteration=1,
    num_parallel_environments=1,
    replay_buffer_capacity=10000,  # Per-environment
    # Params for train
    num_epochs=25,
    learning_rate=1e-3,
    # Params for eval
    num_eval_episodes=10,
    num_random_episodes=1,
    eval_interval=500,
    # Params for summaries and logging
    train_checkpoint_interval=500,
    policy_checkpoint_interval=500,
    rb_checkpoint_interval=20000,
    log_interval=50,
    summary_interval=50,
    summaries_flush_secs=10,
    use_tf_functions=True,
    debug_summaries=False,
    eval_metrics_callback=None,
    random_metrics_callback=None,
    summarize_grads_and_vars=False):

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')
  random_dir = os.path.join(root_dir, 'random')
  saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  random_summary_writer = tf.compat.v2.summary.create_file_writer(
      random_dir, flush_millis=summaries_flush_secs * 1000)
  random_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()

  if random_seed is not None:
    tf.compat.v1.set_random_seed(random_seed)

  eval_py_env = env_load_fn(env_name, max_episode_steps=max_ep_steps)
  eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

  # tf_env = tf_py_environment.TFPyEnvironment(
  #     parallel_py_environment.ParallelPyEnvironment(
  #         [lambda: env_load_fn(env_name, max_episode_steps=max_ep_steps)]
  #         * num_parallel_environments))
  tf_env = tf_py_environment.TFPyEnvironment(
      suite_gym.load(env_name, max_episode_steps=max_ep_steps))

  optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

  if use_rnns:
    actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
        tf_env.observation_spec(),
        tf_env.action_spec(),
        input_fc_layer_params=actor_fc_layers,
        output_fc_layer_params=None)
    value_net = value_rnn_network.ValueRnnNetwork(
        tf_env.observation_spec(),
        input_fc_layer_params=value_fc_layers,
        output_fc_layer_params=None)
  else:
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        tf_env.observation_spec(),
        tf_env.action_spec(),
        fc_layer_params=actor_fc_layers,
        activation_fn=tf.keras.activations.tanh)
    value_net = value_network.ValueNetwork(
        tf_env.observation_spec(),
        fc_layer_params=value_fc_layers,
        activation_fn=tf.keras.activations.tanh)

  tf_agent = ppo_agent.PPOAgent(
      tf_env.time_step_spec(),
      tf_env.action_spec(),
      optimizer,
      actor_net=actor_net,
      value_net=value_net,
      entropy_regularization=0.0,
      importance_ratio_clipping=0.2,
      normalize_observations=False,
      normalize_rewards=False,
      use_gae=True,
      kl_cutoff_factor=0.0,
      initial_adaptive_kl_beta=0.0,
      num_epochs=num_epochs,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=global_step)
  tf_agent.initialize()

  environment_steps_metric = tf_metrics.EnvironmentSteps()
  step_metrics = [
      tf_metrics.NumberOfEpisodes(),
      environment_steps_metric,
  ]
  train_metrics = step_metrics + [
      tf_metrics.AverageReturnMetric(batch_size=num_parallel_environments),
      tf_metrics.AverageEpisodeLengthMetric(
          batch_size=num_parallel_environments),
  ]

  eval_policy = tf_agent.policy
  collect_policy = tf_agent.collect_policy

  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      tf_agent.collect_data_spec,
      batch_size=num_parallel_environments,
      max_length=replay_buffer_capacity)

  train_checkpointer = common.Checkpointer(
      ckpt_dir=train_dir,
      agent=tf_agent,
      global_step=global_step,
      metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
  policy_checkpointer = common.Checkpointer(
      ckpt_dir=os.path.join(train_dir, 'policy'),
      policy=eval_policy,
      global_step=global_step)
  rb_checkpointer = common.Checkpointer(
      ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
      max_to_keep=1,
      replay_buffer=replay_buffer)

  saved_model = policy_saver.PolicySaver(eval_policy, train_step=global_step)

  train_checkpointer.initialize_or_restore()
  rb_checkpointer.initialize_or_restore()

  collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
      tf_env,
      collect_policy,
      observers=[replay_buffer.add_batch] + train_metrics,
      num_episodes=collect_episodes_per_iteration)

  # if use_tf_functions:
  #   # To speed up collect use common.function.
  #   collect_driver.run = common.function(collect_driver.run)
  #   tf_agent.train = common.function(tf_agent.train)

  # initial_collect_policy = random_tf_policy.RandomTFPolicy(
  #     tf_env.time_step_spec(), tf_env.action_spec())
  random_policy = random_tf_policy.RandomTFPolicy(
      eval_tf_env.time_step_spec(), eval_tf_env.action_spec())

  # Make movies of the trained agent and a random agent.
  date_string = datetime.datetime.now().strftime('%Y-%m-%d_%H%M%S')

  trained_filename = 'trainedPPO_' + date_string
  create_policy_eval_video(
      eval_tf_env, eval_py_env, tf_agent.policy, trained_filename)

  random_filename = 'random_' + date_string
  create_policy_eval_video(
      eval_tf_env, eval_py_env, random_policy, random_filename)
def testTrainWithRnn(self, cql_alpha, num_cql_samples,
                     include_critic_entropy_term, use_lagrange_cql_alpha,
                     expected_loss):
  actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
      self._obs_spec,
      self._action_spec,
      input_fc_layer_params=None,
      output_fc_layer_params=None,
      conv_layer_params=None,
      lstm_size=(40,),
  )

  critic_net = critic_rnn_network.CriticRnnNetwork(
      (self._obs_spec, self._action_spec),
      observation_fc_layer_params=(16,),
      action_fc_layer_params=(16,),
      joint_fc_layer_params=(16,),
      lstm_size=(16,),
      output_fc_layer_params=None,
  )

  counter = common.create_variable('test_train_counter')
  optimizer_fn = tf.compat.v1.train.AdamOptimizer

  agent = cql_sac_agent.CqlSacAgent(
      self._time_step_spec,
      self._action_spec,
      critic_network=critic_net,
      actor_network=actor_net,
      actor_optimizer=optimizer_fn(1e-3),
      critic_optimizer=optimizer_fn(1e-3),
      alpha_optimizer=optimizer_fn(1e-3),
      cql_alpha=cql_alpha,
      num_cql_samples=num_cql_samples,
      include_critic_entropy_term=include_critic_entropy_term,
      use_lagrange_cql_alpha=use_lagrange_cql_alpha,
      random_seed=self._random_seed,
      train_step_counter=counter,
  )

  batch_size = 5
  observations = tf.constant(
      [[[1, 2], [3, 4], [5, 6]]] * batch_size, dtype=tf.float32)
  actions = tf.constant([[[0], [1], [1]]] * batch_size, dtype=tf.float32)
  time_steps = ts.TimeStep(
      step_type=tf.constant([[1] * 3] * batch_size, dtype=tf.int32),
      reward=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
      discount=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
      observation=observations)

  experience = trajectory.Trajectory(
      time_steps.step_type, observations, actions, (),
      time_steps.step_type, time_steps.reward, time_steps.discount)

  # Force variable creation.
  agent.policy.variables()

  if not tf.executing_eagerly():
    # Build the train op first to make sure optimizer variables are created
    # and can be initialized.
    train_op = agent.train(experience)
    with self.cached_session() as sess:
      common.initialize_uninitialized_variables(sess)
      self.assertEqual(self.evaluate(counter), 0)
      self.evaluate(train_op)
      self.assertEqual(self.evaluate(counter), 1)
  else:
    self.assertEqual(self.evaluate(counter), 0)
    loss = self.evaluate(agent.train(experience))
    self.assertAllClose(loss.loss, expected_loss)
    self.assertEqual(self.evaluate(counter), 1)
def construct_multigrid_networks(observation_spec,
                                 action_spec,
                                 use_rnns=True,
                                 actor_fc_layers=(200, 100),
                                 value_fc_layers=(200, 100),
                                 lstm_size=(128,),
                                 conv_filters=8,
                                 conv_kernel=3,
                                 scalar_fc=5,
                                 scalar_name="direction",
                                 scalar_dim=4,
                                 use_stacks=False):
  """Creates an actor and critic network designed for use with MultiGrid.

  A convolution layer processes the image and a dense layer processes the
  direction the agent is facing. These are fed into some fully connected
  layers and an LSTM.

  Args:
    observation_spec: A tf-agents observation spec.
    action_spec: A tf-agents action spec.
    use_rnns: If True, will construct RNN networks.
    actor_fc_layers: Dimension and number of fully connected layers in actor.
    value_fc_layers: Dimension and number of fully connected layers in critic.
    lstm_size: Number of cells in each LSTM layer.
    conv_filters: Number of convolution filters.
    conv_kernel: Size of the convolution kernel.
    scalar_fc: Number of neurons in the fully connected layer processing the
      scalar input.
    scalar_name: Name of the scalar input.
    scalar_dim: Highest possible value for the scalar input. Used to convert
      to one-hot representation.
    use_stacks: Use ResNet stacks (compresses the image).

  Returns:
    A tf-agents ActorDistributionRnnNetwork for the actor, and a
    ValueRnnNetwork for the critic.
  """
  preprocessing_layers = {
      "policy_state": tf.keras.layers.Lambda(lambda x: x)
  }
  if use_stacks:
    preprocessing_layers["image"] = tf.keras.models.Sequential([
        multigrid_networks.cast_and_scale(),
        _Stack(conv_filters // 2, 2),
        _Stack(conv_filters, 2),
        tf.keras.layers.ReLU(),
        tf.keras.layers.Flatten()
    ])
  else:
    preprocessing_layers["image"] = tf.keras.models.Sequential([
        multigrid_networks.cast_and_scale(),
        tf.keras.layers.Conv2D(conv_filters, conv_kernel, padding="same"),
        tf.keras.layers.ReLU(),
        tf.keras.layers.Flatten()
    ])
  if scalar_name in observation_spec:
    preprocessing_layers[scalar_name] = tf.keras.models.Sequential([
        multigrid_networks.one_hot_layer(scalar_dim),
        tf.keras.layers.Dense(scalar_fc)
    ])
  if "position" in observation_spec:
    preprocessing_layers["position"] = tf.keras.models.Sequential([
        multigrid_networks.cast_and_scale(),
        tf.keras.layers.Dense(scalar_fc)
    ])

  preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1)

  custom_objects = {"_Stack": _Stack}
  with tf.keras.utils.custom_object_scope(custom_objects):
    if use_rnns:
      actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
          observation_spec,
          action_spec,
          preprocessing_layers=preprocessing_layers,
          preprocessing_combiner=preprocessing_combiner,
          input_fc_layer_params=actor_fc_layers,
          output_fc_layer_params=None,
          lstm_size=lstm_size)
      value_net = value_rnn_network.ValueRnnNetwork(
          observation_spec,
          preprocessing_layers=preprocessing_layers,
          preprocessing_combiner=preprocessing_combiner,
          input_fc_layer_params=value_fc_layers,
          output_fc_layer_params=None)
    else:
      actor_net = actor_distribution_network.ActorDistributionNetwork(
          observation_spec,
          action_spec,
          preprocessing_layers=preprocessing_layers,
          preprocessing_combiner=preprocessing_combiner,
          fc_layer_params=actor_fc_layers,
          activation_fn=tf.keras.activations.tanh)
      value_net = value_network.ValueNetwork(
          observation_spec,
          preprocessing_layers=preprocessing_layers,
          preprocessing_combiner=preprocessing_combiner,
          fc_layer_params=value_fc_layers,
          activation_fn=tf.keras.activations.tanh)

  return actor_net, value_net
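A quick usage sketch for the function above. The specs here are hypothetical stand-ins for illustration (any dict observation spec containing the "image", "direction", and "policy_state" keys the preprocessing layers expect would do); they are not values taken from the original code:

# Hypothetical specs for illustration only.
observation_spec = {
    "image": tensor_spec.BoundedTensorSpec((15, 15, 3), tf.float32, 0, 255),
    "direction": tensor_spec.BoundedTensorSpec((), tf.int32, 0, 3),
    "policy_state": tensor_spec.TensorSpec((10,), tf.float32),
}
action_spec = tensor_spec.BoundedTensorSpec((), tf.int64, 0, 6)

# Builds an ActorDistributionRnnNetwork / ValueRnnNetwork pair whose
# per-key preprocessing layers are concatenated before the fc + LSTM stack.
actor_net, value_net = construct_multigrid_networks(
    observation_spec, action_spec, use_rnns=True)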
def train_eval(
    root_dir,
    env_name=None,
    env_load_fn=suite_mujoco.load,
    random_seed=0,
    # TODO(b/127576522): rename to policy_fc_layers.
    actor_fc_layers=(200, 100),
    value_fc_layers=(200, 100),
    inference_fc_layers=(200, 100),
    use_rnns=None,
    dim_z=4,
    categorical=True,
    # Params for collect
    num_environment_steps=10000000,
    collect_episodes_per_iteration=30,
    num_parallel_environments=30,
    replay_buffer_capacity=1001,  # Per-environment
    # Params for train
    num_epochs=25,
    learning_rate=1e-4,
    entropy_regularization=None,
    kl_posteriors_penalty=None,
    mock_inference=None,
    mock_reward=None,
    l2_distance=None,
    rl_steps=None,
    inference_steps=None,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=1000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=10000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=1,
    use_tf_functions=True,
    debug_summaries=False,
    summarize_grads_and_vars=False):
  """A simple train and eval for PPO."""
  if root_dir is None:
    raise AttributeError('train_eval requires a root_dir.')

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')
  saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf.compat.v1.set_random_seed(random_seed)

    def _env_load_fn(env_name):
      diayn_wrapper = (
          lambda x: diayn_gym_env.DiaynGymEnv(x, dim_z, categorical))
      return env_load_fn(
          env_name,
          gym_env_wrappers=[diayn_wrapper],
      )

    eval_tf_env = tf_py_environment.TFPyEnvironment(_env_load_fn(env_name))
    if num_parallel_environments == 1:
      py_env = _env_load_fn(env_name)
    else:
      py_env = parallel_py_environment.ParallelPyEnvironment(
          [lambda: _env_load_fn(env_name)] * num_parallel_environments)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)

    augmented_time_step_spec = tf_env.time_step_spec()
    augmented_observation_spec = augmented_time_step_spec.observation
    observation_spec = augmented_observation_spec['observation']
    z_spec = augmented_observation_spec['z']
    reward_spec = augmented_time_step_spec.reward
    action_spec = tf_env.action_spec()
    time_step_spec = ts.time_step_spec(observation_spec)

    infer_from_com = False
    if env_name == "AntRandGoalEval-v1":
      infer_from_com = True
    if infer_from_com:
      input_inference_spec = tspec.BoundedTensorSpec(
          shape=[2],
          dtype=tf.float64,
          minimum=-1.79769313e+308,
          maximum=1.79769313e+308,
          name='body_com')
    else:
      input_inference_spec = observation_spec

    if tensor_spec.is_discrete(z_spec):
      _preprocessing_combiner = OneHotConcatenateLayer(dim_z)
    else:
      _preprocessing_combiner = DictConcatenateLayer()

    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

    if use_rnns:
      actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
          augmented_observation_spec,
          action_spec,
          preprocessing_combiner=_preprocessing_combiner,
          input_fc_layer_params=actor_fc_layers,
          output_fc_layer_params=None)
      value_net = value_rnn_network.ValueRnnNetwork(
          augmented_observation_spec,
          preprocessing_combiner=_preprocessing_combiner,
          input_fc_layer_params=value_fc_layers,
          output_fc_layer_params=None)
    else:
      actor_net = actor_distribution_network.ActorDistributionNetwork(
          augmented_observation_spec,
          action_spec,
          preprocessing_combiner=_preprocessing_combiner,
          fc_layer_params=actor_fc_layers,
          name="actor_net")
      value_net = value_network.ValueNetwork(
          augmented_observation_spec,
          preprocessing_combiner=_preprocessing_combiner,
          fc_layer_params=value_fc_layers,
          name="critic_net")

    inference_net = actor_distribution_network.ActorDistributionNetwork(
        input_tensor_spec=input_inference_spec,
        output_tensor_spec=z_spec,
        fc_layer_params=inference_fc_layers,
        continuous_projection_net=normal_projection_net,
        name="inference_net")

    tf_agent = ppo_diayn_agent.PPODiaynAgent(
        augmented_time_step_spec,
        action_spec,
        z_spec,
        optimizer,
        actor_net=actor_net,
        value_net=value_net,
        inference_net=inference_net,
        num_epochs=num_epochs,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step,
        entropy_regularization=entropy_regularization,
        kl_posteriors_penalty=kl_posteriors_penalty,
        mock_inference=mock_inference,
        mock_reward=mock_reward,
        infer_from_com=infer_from_com,
        l2_distance=l2_distance,
        rl_steps=rl_steps,
        inference_steps=inference_steps)
    tf_agent.initialize()

    environment_steps_metric = tf_metrics.EnvironmentSteps()
    step_metrics = [
        tf_metrics.NumberOfEpisodes(),
        environment_steps_metric,
    ]
    train_metrics = step_metrics + [
        tf_metrics.AverageReturnMetric(batch_size=num_parallel_environments),
        tf_metrics.AverageEpisodeLengthMetric(
            batch_size=num_parallel_environments),
    ]

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=num_parallel_environments,
        max_length=replay_buffer_capacity)

    actor_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(root_dir, 'diayn_actor'),
        actor_net=actor_net,
        global_step=global_step)
    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'diayn_policy'),
        policy=eval_policy,
        global_step=global_step)
    saved_model = policy_saver.PolicySaver(eval_policy, train_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(root_dir, 'diayn_replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)
    inference_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(root_dir, 'diayn_inference'),
        inference_net=inference_net,
        global_step=global_step)

    actor_checkpointer.initialize_or_restore()
    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()
    inference_checkpointer.initialize_or_restore()

    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration)

    # option_length = 200
    # if env_name == "Plane-v1":
    #   option_length = 10
    # dataset = replay_buffer.as_dataset(
    #     num_parallel_calls=3, sample_batch_size=num_parallel_environments,
    #     num_steps=option_length)
    # iterator_dataset = iter(dataset)

    def train_step():
      trajectories = replay_buffer.gather_all()
      # trajectories, _ = next(iterator_dataset)
      return tf_agent.train(experience=trajectories)

    if use_tf_functions:
      # TODO(b/123828980): Enable once the cause for slowdown was identified.
      collect_driver.run = common.function(collect_driver.run, autograph=False)
      tf_agent.train = common.function(tf_agent.train, autograph=False)
      train_step = common.function(train_step)

    collect_time = 0
    train_time = 0
    timed_at_step = global_step.numpy()

    while environment_steps_metric.result() < num_environment_steps:
      global_step_val = global_step.numpy()
      if global_step_val % eval_interval == 0:
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )

      start_time = time.time()
      collect_driver.run()
      collect_time += time.time() - start_time

      start_time = time.time()
      total_loss, _ = train_step()
      replay_buffer.clear()
      train_time += time.time() - start_time

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=step_metrics)

      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val, total_loss)
        steps_per_sec = (
            (global_step_val - timed_at_step) / (collect_time + train_time))
        logging.info('%.3f steps/sec', steps_per_sec)
        logging.info('collect_time = {}, train_time = {}'.format(
            collect_time, train_time))
        with tf.compat.v2.summary.record_if(True):
          tf.compat.v2.summary.scalar(
              name='global_steps_per_sec', data=steps_per_sec,
              step=global_step)

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)
          inference_checkpointer.save(global_step=global_step_val)
          actor_checkpointer.save(global_step=global_step_val)
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)
          saved_model_path = os.path.join(
              saved_model_dir, 'policy_' + ('%d' % global_step_val).zfill(9))
          saved_model.save(saved_model_path)

        timed_at_step = global_step_val
        collect_time = 0
        train_time = 0

    # One final eval before exiting.
    metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
def train_eval(
    root_dir,
    # env_name='HalfCheetah-v2',
    # env_load_fn=suite_mujoco.load,
    env_load_fn=None,
    random_seed=0,
    # TODO(b/127576522): rename to policy_fc_layers.
    actor_fc_layers=(200, 100),
    value_fc_layers=(200, 100),
    use_rnns=False,
    # Params for collect
    num_environment_steps=int(1e7),
    collect_episodes_per_iteration=30,
    num_parallel_environments=30,
    replay_buffer_capacity=1001,  # Per-environment
    # Params for train
    num_epochs=25,
    learning_rate=1e-4,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=500,
    # Params for summaries and logging
    train_checkpoint_interval=500,
    policy_checkpoint_interval=500,
    log_interval=50,
    summary_interval=50,
    summaries_flush_secs=1,
    use_tf_functions=True,
    # use_tf_functions=False,
    debug_summaries=False,
    summarize_grads_and_vars=False):
  if root_dir is None:
    raise AttributeError('train_eval requires a root_dir.')

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')
  saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf.compat.v1.set_random_seed(random_seed)

    # eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
    # tf_env = tf_py_environment.TFPyEnvironment(
    #     parallel_py_environment.ParallelPyEnvironment(
    #         [lambda: env_load_fn(env_name)] * num_parallel_environments))
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        suite_gym.wrap_env(RectEnv()))
    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: suite_gym.wrap_env(RectEnv())]
            * num_parallel_environments))

    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

    preprocessing_layers = {
        'target':
            tf.keras.models.Sequential([
                # tf.keras.applications.MobileNetV2(
                #     input_shape=(64, 64, 1), include_top=False,
                #     weights=None),
                # tf.keras.layers.Conv2D(1, 6),
                easy.encoder((CANVAS_WIDTH, CANVAS_WIDTH, 1)),
                tf.keras.layers.Flatten()
            ]),
        'canvas':
            tf.keras.models.Sequential([
                # tf.keras.applications.MobileNetV2(
                #     input_shape=(64, 64, 1), include_top=False,
                #     weights=None),
                # tf.keras.layers.Conv2D(1, 6),
                easy.encoder((CANVAS_WIDTH, CANVAS_WIDTH, 1)),
                tf.keras.layers.Flatten()
            ]),
        'coord':
            tf.keras.models.Sequential([
                tf.keras.layers.Dense(64),
                tf.keras.layers.Dense(64),
                tf.keras.layers.Flatten()
            ])
    }
    preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1)

    if use_rnns:
      # Note: unlike the branch below, these RNN networks are not given the
      # preprocessing layers/combiner, so this branch would fail on the
      # dictionary observations produced by RectEnv.
      actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          input_fc_layer_params=actor_fc_layers,
          output_fc_layer_params=None)
      value_net = value_rnn_network.ValueRnnNetwork(
          tf_env.observation_spec(),
          input_fc_layer_params=value_fc_layers,
          output_fc_layer_params=None)
    else:
      actor_net = actor_distribution_network.ActorDistributionNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          fc_layer_params=actor_fc_layers,
          preprocessing_layers=preprocessing_layers,
          preprocessing_combiner=preprocessing_combiner)
      value_net = value_network.ValueNetwork(
          tf_env.observation_spec(),
          fc_layer_params=value_fc_layers,
          preprocessing_layers=preprocessing_layers,
          preprocessing_combiner=preprocessing_combiner)

    tf_agent = ppo_agent.PPOAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        optimizer,
        actor_net=actor_net,
        value_net=value_net,
        num_epochs=num_epochs,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    environment_steps_metric = tf_metrics.EnvironmentSteps()
    step_metrics = [
        tf_metrics.NumberOfEpisodes(),
        environment_steps_metric,
    ]
    train_metrics = step_metrics + [
        tf_metrics.AverageReturnMetric(batch_size=num_parallel_environments),
        tf_metrics.AverageEpisodeLengthMetric(
            batch_size=num_parallel_environments),
    ]

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=num_parallel_environments,
        max_length=replay_buffer_capacity)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        max_to_keep=5,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        max_to_keep=5,
        policy=eval_policy,
        global_step=global_step)
    saved_model = policy_saver.PolicySaver(eval_policy, train_step=global_step)

    train_checkpointer.initialize_or_restore()

    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration)

    def train_step():
      trajectories = replay_buffer.gather_all()
      return tf_agent.train(experience=trajectories)

    if use_tf_functions:
      # TODO(b/123828980): Enable once the cause for slowdown was identified.
      collect_driver.run = common.function(collect_driver.run, autograph=False)
      tf_agent.train = common.function(tf_agent.train, autograph=False)
      train_step = common.function(train_step)

    collect_time = 0
    train_time = 0
    timed_at_step = global_step.numpy()

    while environment_steps_metric.result() < num_environment_steps:
      global_step_val = global_step.numpy()
      if global_step_val % eval_interval == 0:
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )

      start_time = time.time()
      collect_driver.run()
      collect_time += time.time() - start_time

      start_time = time.time()
      total_loss, _ = train_step()
      replay_buffer.clear()
      train_time += time.time() - start_time

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=step_metrics)

      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val, total_loss)
        steps_per_sec = (
            (global_step_val - timed_at_step) / (collect_time + train_time))
        logging.info('%.3f steps/sec', steps_per_sec)
        logging.info('collect_time = {}, train_time = {}'.format(
            collect_time, train_time))
        with tf.compat.v2.summary.record_if(True):
          tf.compat.v2.summary.scalar(
              name='global_steps_per_sec', data=steps_per_sec,
              step=global_step)

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)
          saved_model_path = os.path.join(
              saved_model_dir, 'policy_' + ('%d' % global_step_val).zfill(9))
          saved_model.save(saved_model_path)

        timed_at_step = global_step_val
        collect_time = 0
        train_time = 0

    # One final eval before exiting.
    metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    initial_collect_steps=1,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=0,
    eval_interval=1000000,
    # Params for summaries and logging
    train_checkpoint_interval=100,
    policy_checkpoint_interval=200,
    rb_checkpoint_interval=300,
    log_interval=50,
    summary_interval=50,
    summaries_flush_secs=1000,
    debug_summaries=True,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for SAC."""
  global interrupted
  dir_key = root_dir
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  # NOTE: The eval branch in the training loop below still references
  # eval_summary_writer and eval_metrics; restore these definitions if
  # eval_interval can ever be reached.
  # eval_summary_writer = tf.compat.v2.summary.create_file_writer(
  #     eval_dir, flush_millis=summaries_flush_secs * 1000)
  # eval_metrics = [
  #     tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
  #     tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  # ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    train_env = T4TFEnv(metrics_key=dir_key)
    eval_env = T4TFEnv(fake=True)
    tf_env = tf_py_environment.TFPyEnvironment(train_env)
    eval_tf_env = tf_py_environment.TFPyEnvironment(eval_env)

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
        observation_spec,
        action_spec,
        # fc_layer_params=actor_fc_layers,
        conv_layer_params=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        lstm_size=(8,),
        normal_projection_net=normal_projection_net)
    critic_net = critic_rnn_network.CriticRnnNetwork(
        (observation_spec, action_spec),
        observation_conv_layer_params=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        lstm_size=(8,),
        # observation_fc_layer_params=critic_obs_fc_layers,
        # action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers)
    print(actor_net)
    print(critic_net)

    tf_agent = sac_agent.SacAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
        tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
    ]

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer,
        num_steps=initial_collect_steps)
    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)

    if use_tf_functions:
      initial_collect_driver.run = common.function(initial_collect_driver.run)
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    # Collect initial replay data.
    logging.info(
        'Initializing replay buffer by collecting experience for %d steps '
        'with a random policy.', initial_collect_steps)
    initial_collect_driver.run()

    # results = metric_utils.eager_compute(
    #     eval_metrics,
    #     eval_tf_env,
    #     eval_policy,
    #     num_episodes=num_eval_episodes,
    #     train_step=global_step,
    #     summary_writer=eval_summary_writer,
    #     summary_prefix='Metrics',
    # )
    # if eval_metrics_callback is not None:
    #   eval_metrics_callback(results, global_step.numpy())
    # metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=1,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)
    iterator = iter(dataset)

    for _ in range(num_iterations):
      if interrupted:
        train_env.interrupted = True
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        if interrupted:
          train_env.interrupted = True
        experience, _ = next(iterator)
        train_loss = tf_agent.train(experience)
      time_acc += time.time() - start_time

      if global_step.numpy() % log_interval == 0:
        # actor_net.save(os.path.join(root_dir, 'actor'), save_format='tf')
        # critic_net.save(os.path.join(root_dir, 'critic'), save_format='tf')
        logging.info('step = %d, loss = %f', global_step.numpy(),
                     train_loss.loss)
        steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        obs_shape = time_step.observation.shape
        tf.compat.v2.summary.image(
            name='input_image',
            data=np.reshape(time_step.observation,
                            (1, obs_shape[1], obs_shape[2], 3)),
            step=global_step)
        timed_at_step = global_step.numpy()
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      if global_step.numpy() % eval_interval == 0:
        # NOTE: eval_metrics and eval_summary_writer are commented out above,
        # so this branch raises NameError if it is ever reached.
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

      global_step_val = global_step.numpy()
      print('global step: %d' % global_step_val)
      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)

    return train_loss
def train_eval( root_dir, experiment_name, # experiment name env_name='carla-v0', agent_name='sac', # agent's name num_iterations=int(1e7), actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), model_network_ctor_type='non-hierarchical', # model net input_names=['camera', 'lidar'], # names for inputs mask_names=['birdeye'], # names for masks preprocessing_combiner=tf.keras.layers.Add( ), # takes a flat list of tensors and combines them actor_lstm_size=(40, ), # lstm size for actor critic_lstm_size=(40, ), # lstm size for critic actor_output_fc_layers=(100, ), # lstm output critic_output_fc_layers=(100, ), # lstm output epsilon_greedy=0.1, # exploration parameter for DQN q_learning_rate=1e-3, # q learning rate for DQN ou_stddev=0.2, # exploration paprameter for DDPG ou_damping=0.15, # exploration parameter for DDPG dqda_clipping=None, # for DDPG exploration_noise_std=0.1, # exploration paramter for td3 actor_update_period=2, # for td3 # Params for collect initial_collect_steps=1000, collect_steps_per_iteration=1, replay_buffer_capacity=int(1e5), # Params for target update target_update_tau=0.005, target_update_period=1, # Params for train train_steps_per_iteration=1, initial_model_train_steps=100000, # initial model training batch_size=256, model_batch_size=32, # model training batch size sequence_length=4, # number of timesteps to train model actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, model_learning_rate=1e-4, # learning rate for model training td_errors_loss_fn=tf.losses.mean_squared_error, gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, # Params for eval num_eval_episodes=10, eval_interval=10000, # Params for summaries and logging num_images_per_summary=1, # images for each summary train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=50000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, gpu_allow_growth=True, # GPU memory growth gpu_memory_limit=None, # GPU memory limit action_repeat=1 ): # Name of single observation channel, ['camera', 'lidar', 'birdeye'] # Setup GPU gpus = tf.config.experimental.list_physical_devices('GPU') if gpu_allow_growth: for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True) if gpu_memory_limit: for gpu in gpus: tf.config.experimental.set_virtual_device_configuration( gpu, [ tf.config.experimental.VirtualDeviceConfiguration( memory_limit=gpu_memory_limit) ]) # Get train and eval directories root_dir = os.path.expanduser(root_dir) root_dir = os.path.join(root_dir, env_name, experiment_name) # Get summary writers summary_writer = tf.summary.create_file_writer( root_dir, flush_millis=summaries_flush_secs * 1000) summary_writer.set_as_default() # Eval metrics eval_metrics = [ tf_metrics.AverageReturnMetric(name='AverageReturnEvalPolicy', buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric( name='AverageEpisodeLengthEvalPolicy', buffer_size=num_eval_episodes), ] global_step = tf.compat.v1.train.get_or_create_global_step() # Whether to record for summary with tf.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): # Create Carla environment if agent_name == 'latent_sac': py_env, eval_py_env = load_carla_env(env_name='carla-v0', obs_channels=input_names + mask_names, action_repeat=action_repeat) elif agent_name == 'dqn': py_env, eval_py_env = load_carla_env(env_name='carla-v0', 
discrete=True, obs_channels=input_names, action_repeat=action_repeat) else: py_env, eval_py_env = load_carla_env(env_name='carla-v0', obs_channels=input_names, action_repeat=action_repeat) tf_env = tf_py_environment.TFPyEnvironment(py_env) eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env) fps = int(np.round(1.0 / (py_env.dt * action_repeat))) # Specs time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() ## Make tf agent if agent_name == 'latent_sac': # Get model network for latent sac if model_network_ctor_type == 'hierarchical': model_network_ctor = sequential_latent_network.SequentialLatentModelHierarchical elif model_network_ctor_type == 'non-hierarchical': model_network_ctor = sequential_latent_network.SequentialLatentModelNonHierarchical else: raise NotImplementedError model_net = model_network_ctor(input_names, input_names + mask_names) # Get the latent spec latent_size = model_net.latent_size latent_observation_spec = tensor_spec.TensorSpec((latent_size, ), dtype=tf.float32) latent_time_step_spec = ts.time_step_spec( observation_spec=latent_observation_spec) # Get actor and critic net actor_net = actor_distribution_network.ActorDistributionNetwork( latent_observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=normal_projection_net) critic_net = critic_network.CriticNetwork( (latent_observation_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers) # Build the inner SAC agent based on latent space inner_agent = sac_agent.SacAgent( latent_time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) inner_agent.initialize() # Build the latent sac agent tf_agent = latent_sac_agent.LatentSACAgent( time_step_spec, action_spec, inner_agent=inner_agent, model_network=model_net, model_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=model_learning_rate), model_batch_size=model_batch_size, num_images_per_summary=num_images_per_summary, sequence_length=sequence_length, gradient_clipping=gradient_clipping, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step, fps=fps) else: # Set up preprosessing layers for dictionary observation inputs preprocessing_layers = collections.OrderedDict() for name in input_names: preprocessing_layers[name] = Preprocessing_Layer(32, 256) if len(input_names) < 2: preprocessing_combiner = None if agent_name == 'dqn': q_rnn_net = q_rnn_network.QRnnNetwork( observation_spec, action_spec, preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, input_fc_layer_params=critic_joint_fc_layers, lstm_size=critic_lstm_size, output_fc_layer_params=critic_output_fc_layers) tf_agent = dqn_agent.DqnAgent( time_step_spec, action_spec, q_network=q_rnn_net, epsilon_greedy=epsilon_greedy, n_step_update=1, 
target_update_tau=target_update_tau, target_update_period=target_update_period, optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=q_learning_rate), td_errors_loss_fn=common.element_wise_squared_loss, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) elif agent_name == 'ddpg' or agent_name == 'td3': actor_rnn_net = multi_inputs_actor_rnn_network.MultiInputsActorRnnNetwork( observation_spec, action_spec, preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, input_fc_layer_params=actor_fc_layers, lstm_size=actor_lstm_size, output_fc_layer_params=actor_output_fc_layers) critic_rnn_net = multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork( (observation_spec, action_spec), preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, lstm_size=critic_lstm_size, output_fc_layer_params=critic_output_fc_layers) if agent_name == 'ddpg': tf_agent = ddpg_agent.DdpgAgent( time_step_spec, action_spec, actor_network=actor_rnn_net, critic_network=critic_rnn_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), ou_stddev=ou_stddev, ou_damping=ou_damping, target_update_tau=target_update_tau, target_update_period=target_update_period, dqda_clipping=dqda_clipping, td_errors_loss_fn=None, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars) elif agent_name == 'td3': tf_agent = td3_agent.Td3Agent( time_step_spec, action_spec, actor_network=actor_rnn_net, critic_network=critic_rnn_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), exploration_noise_std=exploration_noise_std, target_update_tau=target_update_tau, target_update_period=target_update_period, actor_update_period=actor_update_period, dqda_clipping=dqda_clipping, td_errors_loss_fn=None, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) elif agent_name == 'sac': actor_distribution_rnn_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork( observation_spec, action_spec, preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, input_fc_layer_params=actor_fc_layers, lstm_size=actor_lstm_size, output_fc_layer_params=actor_output_fc_layers, continuous_projection_net=normal_projection_net) critic_rnn_net = multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork( (observation_spec, action_spec), preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, lstm_size=critic_lstm_size, output_fc_layer_params=critic_output_fc_layers) tf_agent = sac_agent.SacAgent( time_step_spec, action_spec, actor_network=actor_distribution_rnn_net, critic_network=critic_rnn_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), 
critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=tf.math. squared_difference, # make critic loss dimension compatible gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) else: raise NotImplementedError tf_agent.initialize() # Get replay buffer replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=1, # No parallel environments max_length=replay_buffer_capacity) replay_observer = [replay_buffer.add_batch] # Train metrics env_steps = tf_metrics.EnvironmentSteps() average_return = tf_metrics.AverageReturnMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size) train_metrics = [ tf_metrics.NumberOfEpisodes(), env_steps, average_return, tf_metrics.AverageEpisodeLengthMetric( buffer_size=num_eval_episodes, batch_size=tf_env.batch_size), ] # Get policies # eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy) eval_policy = tf_agent.policy initial_collect_policy = random_tf_policy.RandomTFPolicy( time_step_spec, action_spec) collect_policy = tf_agent.collect_policy # Checkpointers train_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(root_dir, 'train'), agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'), max_to_keep=2) policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( root_dir, 'policy'), policy=eval_policy, global_step=global_step, max_to_keep=2) rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( root_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) train_checkpointer.initialize_or_restore() rb_checkpointer.initialize_or_restore() # Collect driver initial_collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=replay_observer + train_metrics, num_steps=initial_collect_steps) collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_steps=collect_steps_per_iteration) # Optimize the performance by using tf functions initial_collect_driver.run = common.function( initial_collect_driver.run) collect_driver.run = common.function(collect_driver.run) tf_agent.train = common.function(tf_agent.train) # Collect initial replay data. if (env_steps.result() == 0 or replay_buffer.num_frames() == 0): logging.info( 'Initializing replay buffer by collecting experience for %d steps' 'with a random policy.', initial_collect_steps) initial_collect_driver.run() if agent_name == 'latent_sac': compute_summaries(eval_metrics, eval_tf_env, eval_policy, train_step=global_step, summary_writer=summary_writer, num_episodes=1, num_episodes_to_render=1, model_net=model_net, fps=10, image_keys=input_names + mask_names) else: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=1, train_step=env_steps.result(), summary_writer=summary_writer, summary_prefix='Eval', ) metric_utils.log_metrics(eval_metrics) # Dataset generates trajectories with shape [Bxslx...] 
dataset = replay_buffer.as_dataset(num_parallel_calls=3, sample_batch_size=batch_size, num_steps=sequence_length + 1).prefetch(3) iterator = iter(dataset) # Get train step def train_step(): experience, _ = next(iterator) return tf_agent.train(experience) train_step = common.function(train_step) if agent_name == 'latent_sac': def train_model_step(): experience, _ = next(iterator) return tf_agent.train_model(experience) train_model_step = common.function(train_model_step) # Training initializations time_step = None time_acc = 0 env_steps_before = env_steps.result().numpy() # Start training for iteration in range(num_iterations): start_time = time.time() if agent_name == 'latent_sac' and iteration < initial_model_train_steps: train_model_step() else: # Run collect time_step, _ = collect_driver.run(time_step=time_step) # Train an iteration for _ in range(train_steps_per_iteration): train_step() time_acc += time.time() - start_time # Log training information if global_step.numpy() % log_interval == 0: logging.info('env steps = %d, average return = %f', env_steps.result(), average_return.result()) env_steps_per_sec = (env_steps.result().numpy() - env_steps_before) / time_acc logging.info('%.3f env steps/sec', env_steps_per_sec) tf.summary.scalar(name='env_steps_per_sec', data=env_steps_per_sec, step=env_steps.result()) time_acc = 0 env_steps_before = env_steps.result().numpy() # Get training metrics for train_metric in train_metrics: train_metric.tf_summaries(train_step=env_steps.result()) # Evaluation if global_step.numpy() % eval_interval == 0: # Log evaluation metrics if agent_name == 'latent_sac': compute_summaries( eval_metrics, eval_tf_env, eval_policy, train_step=global_step, summary_writer=summary_writer, num_episodes=num_eval_episodes, num_episodes_to_render=num_images_per_summary, model_net=model_net, fps=10, image_keys=input_names + mask_names) else: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=env_steps.result(), summary_writer=summary_writer, summary_prefix='Eval', ) metric_utils.log_metrics(eval_metrics) # Save checkpoints global_step_val = global_step.numpy() if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val)
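# A minimal runnable sketch (toy data spec; names here are illustrative, not
# from the original) of the window-sampling pattern used above: sampling
# num_steps = sequence_length + 1 yields [B, sequence_length + 1, ...]
# windows, giving the RNN agent full sequences plus the bootstrap step.
import tensorflow as tf
from tf_agents.replay_buffers import tf_uniform_replay_buffer

toy_spec = tf.TensorSpec(shape=(3,), dtype=tf.float32, name='obs')
toy_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=toy_spec, batch_size=1, max_length=100)
for i in range(10):
    # add_batch expects a leading dimension equal to the buffer batch_size.
    toy_buffer.add_batch(tf.fill((1, 3), float(i)))
toy_sequence_length = 4
toy_dataset = toy_buffer.as_dataset(
    sample_batch_size=2, num_steps=toy_sequence_length + 1).prefetch(3)
window, _ = next(iter(toy_dataset))
print(window.shape)  # (2, 5, 3): [B, sequence_length + 1, ...]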
def train_eval( root_dir, env_name='cartpole', task_name='balance', observations_allowlist='position', eval_env_name=None, num_iterations=1000000, # Params for networks. actor_fc_layers=(400, 300), actor_output_fc_layers=(100, ), actor_lstm_size=(40, ), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(300, ), critic_output_fc_layers=(100, ), critic_lstm_size=(40, ), num_parallel_environments=1, # Params for collect initial_collect_episodes=1, collect_episodes_per_iteration=1, replay_buffer_capacity=1000000, # Params for target update target_update_tau=0.05, target_update_period=5, # Params for train train_steps_per_iteration=1, batch_size=256, critic_learning_rate=3e-4, train_sequence_length=20, actor_learning_rate=3e-4, alpha_learning_rate=3e-4, td_errors_loss_fn=tf.math.squared_difference, gamma=0.99, reward_scale_factor=0.1, gradient_clipping=None, use_tf_functions=True, # Params for eval num_eval_episodes=30, eval_interval=10000, # Params for summaries and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=50000, log_interval=1000, summary_interval=1000, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for RNN SAC on DM control.""" root_dir = os.path.expanduser(root_dir) summary_writer = tf.compat.v2.summary.create_file_writer( root_dir, flush_millis=summaries_flush_secs * 1000) summary_writer.set_as_default() eval_metrics = [ tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes) ] global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): if observations_allowlist is not None: env_wrappers = [ functools.partial( wrappers.FlattenObservationsWrapper, observations_allowlist=[observations_allowlist]) ] else: env_wrappers = [] env_load_fn = functools.partial(suite_dm_control.load, task_name=task_name, env_wrappers=env_wrappers) if num_parallel_environments == 1: py_env = env_load_fn(env_name) else: py_env = parallel_py_environment.ParallelPyEnvironment( [lambda: env_load_fn(env_name)] * num_parallel_environments) tf_env = tf_py_environment.TFPyEnvironment(py_env) eval_env_name = eval_env_name or env_name eval_tf_env = tf_py_environment.TFPyEnvironment( env_load_fn(eval_env_name)) time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork( observation_spec, action_spec, input_fc_layer_params=actor_fc_layers, lstm_size=actor_lstm_size, output_fc_layer_params=actor_output_fc_layers, continuous_projection_net=tanh_normal_projection_network. 
TanhNormalProjectionNetwork) critic_net = critic_rnn_network.CriticRnnNetwork( (observation_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, lstm_size=critic_lstm_size, output_fc_layer_params=critic_output_fc_layers, kernel_initializer='glorot_uniform', last_kernel_initializer='glorot_uniform') tf_agent = sac_agent.SacAgent( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) tf_agent.initialize() # Make the replay buffer. replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=tf_env.batch_size, max_length=replay_buffer_capacity) replay_observer = [replay_buffer.add_batch] env_steps = tf_metrics.EnvironmentSteps(prefix='Train') average_return = tf_metrics.AverageReturnMetric( prefix='Train', buffer_size=num_eval_episodes, batch_size=tf_env.batch_size) train_metrics = [ tf_metrics.NumberOfEpisodes(prefix='Train'), env_steps, average_return, tf_metrics.AverageEpisodeLengthMetric( prefix='Train', buffer_size=num_eval_episodes, batch_size=tf_env.batch_size), ] eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy) initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) collect_policy = tf_agent.collect_policy train_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(root_dir, 'train'), agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( root_dir, 'policy'), policy=eval_policy, global_step=global_step) rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( root_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) train_checkpointer.initialize_or_restore() rb_checkpointer.initialize_or_restore() initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver( tf_env, initial_collect_policy, observers=replay_observer + train_metrics, num_episodes=initial_collect_episodes) collect_driver = dynamic_episode_driver.DynamicEpisodeDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_episodes=collect_episodes_per_iteration) if use_tf_functions: initial_collect_driver.run = common.function( initial_collect_driver.run) collect_driver.run = common.function(collect_driver.run) tf_agent.train = common.function(tf_agent.train) # Collect initial replay data. 
if env_steps.result() == 0 or replay_buffer.num_frames() == 0: logging.info( 'Initializing replay buffer by collecting experience for %d episodes ' 'with a random policy.', initial_collect_episodes) initial_collect_driver.run() results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=env_steps.result(), summary_writer=summary_writer, summary_prefix='Eval', ) if eval_metrics_callback is not None: eval_metrics_callback(results, env_steps.result()) metric_utils.log_metrics(eval_metrics) time_step = None policy_state = collect_policy.get_initial_state(tf_env.batch_size) time_acc = 0 env_steps_before = env_steps.result().numpy() # Prepare replay buffer as dataset with invalid transitions filtered. def _filter_invalid_transition(trajectories, unused_arg1): # Reduce filter_fn over full trajectory sampled. The sequence is kept only # if all elements except for the last one pass the filter. This is to # allow training on terminal steps. return tf.reduce_all(~trajectories.is_boundary()[:-1]) dataset = replay_buffer.as_dataset( sample_batch_size=batch_size, num_steps=train_sequence_length + 1).unbatch().filter( _filter_invalid_transition).batch(batch_size).prefetch(5) # Dataset generates trajectories with shape [B x train_sequence_length + 1 x ...] iterator = iter(dataset) def train_step(): experience, _ = next(iterator) return tf_agent.train(experience) if use_tf_functions: train_step = common.function(train_step) for _ in range(num_iterations): start_time = time.time() start_env_steps = env_steps.result() time_step, policy_state = collect_driver.run( time_step=time_step, policy_state=policy_state, ) episode_steps = env_steps.result() - start_env_steps # TODO(b/152648849) for _ in range(episode_steps): for _ in range(train_steps_per_iteration): train_step() time_acc += time.time() - start_time if global_step.numpy() % log_interval == 0: logging.info('env steps = %d, average return = %f', env_steps.result(), average_return.result()) env_steps_per_sec = (env_steps.result().numpy() - env_steps_before) / time_acc logging.info('%.3f env steps/sec', env_steps_per_sec) tf.compat.v2.summary.scalar(name='env_steps_per_sec', data=env_steps_per_sec, step=env_steps.result()) time_acc = 0 env_steps_before = env_steps.result().numpy() for train_metric in train_metrics: train_metric.tf_summaries(train_step=env_steps.result()) if global_step.numpy() % eval_interval == 0: results = metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=env_steps.result(), summary_writer=summary_writer, summary_prefix='Eval', ) if eval_metrics_callback is not None: eval_metrics_callback(results, env_steps.result()) metric_utils.log_metrics(eval_metrics) global_step_val = global_step.numpy() if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val)
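# A minimal runnable sketch of the _filter_invalid_transition pattern above,
# on toy data: unbatch the sampled windows, drop any window whose non-final
# steps cross an episode boundary, then rebatch. Here 1 marks a boundary step.
import tensorflow as tf

toy_windows = tf.constant([[0, 0, 0], [0, 1, 0], [0, 0, 1]])

def _keep(window):
    # Mirrors tf.reduce_all(~trajectories.is_boundary()[:-1]): a boundary is
    # allowed only at the last position, so terminal steps stay trainable.
    return tf.reduce_all(tf.equal(window[:-1], 0))

filtered = tf.data.Dataset.from_tensor_slices(toy_windows).filter(_keep).batch(2)
for batch in filtered:
    print(batch.numpy())  # [[0 0 0] [0 0 1]]; the middle window is dropped.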
def train_eval( root_dir, gpu='1', env_load_fn=None, model_ids=None, eval_env_mode='headless', conv_layer_params=None, encoder_fc_layers=[256], actor_fc_layers=[256, 256], value_fc_layers=[256, 256], use_rnns=False, # Params for collect num_environment_steps=10000000, collect_episodes_per_iteration=30, num_parallel_environments=30, replay_buffer_capacity=1001, # Per-environment # Params for train num_epochs=25, learning_rate=1e-4, # Params for eval num_eval_episodes=30, eval_interval=500, eval_only=False, eval_deterministic=False, num_parallel_environments_eval=1, model_ids_eval=None, # Params for summaries and logging train_checkpoint_interval=500, policy_checkpoint_interval=500, rb_checkpoint_interval=500, log_interval=10, summary_interval=50, summaries_flush_secs=1, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """A simple train and eval for PPO.""" if root_dir is None: raise AttributeError('train_eval requires a root_dir.') root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ batched_py_metric.BatchedPyMetric( py_metrics.AverageReturnMetric, metric_args={'buffer_size': num_eval_episodes}, batch_size=num_parallel_environments_eval), batched_py_metric.BatchedPyMetric( py_metrics.AverageEpisodeLengthMetric, metric_args={'buffer_size': num_eval_episodes}, batch_size=num_parallel_environments_eval), ] eval_summary_writer_flush_op = eval_summary_writer.flush() global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): if model_ids is None: model_ids = [None] * num_parallel_environments else: assert len(model_ids) == num_parallel_environments,\ 'model ids provided, but length not equal to num_parallel_environments' if model_ids_eval is None: model_ids_eval = [None] * num_parallel_environments_eval else: assert len(model_ids_eval) == num_parallel_environments_eval,\ 'model ids eval provided, but length not equal to num_parallel_environments_eval' tf_py_env = [lambda model_id=model_ids[i]: env_load_fn(model_id, 'headless', gpu) for i in range(num_parallel_environments)] tf_env = tf_py_environment.TFPyEnvironment(parallel_py_environment.ParallelPyEnvironment(tf_py_env)) if eval_env_mode == 'gui': assert num_parallel_environments_eval == 1, 'only one GUI env is allowed' eval_py_env = [lambda model_id=model_ids_eval[i]: env_load_fn(model_id, eval_env_mode, gpu) for i in range(num_parallel_environments_eval)] eval_py_env = parallel_py_environment.ParallelPyEnvironment(eval_py_env) optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate) time_step_spec = tf_env.time_step_spec() observation_spec = tf_env.observation_spec() action_spec = tf_env.action_spec() print('observation_spec', observation_spec) print('action_spec', action_spec) glorot_uniform_initializer = tf.compat.v1.keras.initializers.glorot_uniform() preprocessing_layers = { 'depth_seg': tf.keras.Sequential(mlp_layers( conv_layer_params=conv_layer_params, fc_layer_params=encoder_fc_layers, kernel_initializer=glorot_uniform_initializer, )), 'sensor': tf.keras.Sequential(mlp_layers( conv_layer_params=None, 
fc_layer_params=encoder_fc_layers, kernel_initializer=glorot_uniform_initializer, )), } preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1) if use_rnns: actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork( observation_spec, action_spec, preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, input_fc_layer_params=actor_fc_layers, output_fc_layer_params=None) value_net = value_rnn_network.ValueRnnNetwork( observation_spec, preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, input_fc_layer_params=value_fc_layers, output_fc_layer_params=None) else: actor_net = actor_distribution_network.ActorDistributionNetwork( observation_spec, action_spec, preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, fc_layer_params=actor_fc_layers, kernel_initializer=glorot_uniform_initializer ) value_net = value_network.ValueNetwork( observation_spec, preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, fc_layer_params=value_fc_layers, kernel_initializer=glorot_uniform_initializer ) tf_agent = ppo_agent.PPOAgent( time_step_spec, action_spec, optimizer, actor_net=actor_net, value_net=value_net, num_epochs=num_epochs, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) config = tf.compat.v1.ConfigProto() config.gpu_options.allow_growth = True sess = tf.compat.v1.Session(config=config) replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( tf_agent.collect_data_spec, batch_size=num_parallel_environments, max_length=replay_buffer_capacity) if eval_deterministic: eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy) else: eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.collect_policy) environment_steps_metric = tf_metrics.EnvironmentSteps() environment_steps_count = environment_steps_metric.result() step_metrics = [ tf_metrics.NumberOfEpisodes(), environment_steps_metric, ] train_metrics = step_metrics + [ tf_metrics.AverageReturnMetric( buffer_size=100, batch_size=num_parallel_environments), tf_metrics.AverageEpisodeLengthMetric( buffer_size=100, batch_size=num_parallel_environments), ] # Add to replay buffer and other agent specific observers. 
replay_buffer_observer = [replay_buffer.add_batch] collect_policy = tf_agent.collect_policy collect_op = dynamic_episode_driver.DynamicEpisodeDriver( tf_env, collect_policy, observers=replay_buffer_observer + train_metrics, num_episodes=collect_episodes_per_iteration * num_parallel_environments).run() trajectories = replay_buffer.gather_all() train_op, _ = tf_agent.train(experience=trajectories) with tf.control_dependencies([train_op]): clear_replay_op = replay_buffer.clear() with tf.control_dependencies([clear_replay_op]): train_op = tf.identity(train_op) train_checkpointer = common.Checkpointer( ckpt_dir=train_dir, agent=tf_agent, global_step=global_step, metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')) policy_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'policy'), policy=tf_agent.policy, global_step=global_step) rb_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(train_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer) summary_ops = [] for train_metric in train_metrics: summary_ops.append(train_metric.tf_summaries( train_step=global_step, step_metrics=step_metrics)) with eval_summary_writer.as_default(), tf.compat.v2.summary.record_if(True): for eval_metric in eval_metrics: eval_metric.tf_summaries( train_step=global_step, step_metrics=step_metrics) init_agent_op = tf_agent.initialize() with sess.as_default(): # Initialize graph. train_checkpointer.initialize_or_restore(sess) rb_checkpointer.initialize_or_restore(sess) if eval_only: metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=0, callback=eval_metrics_callback, tf_summaries=False, log=True, ) episodes = eval_py_env.get_stored_episodes() episodes = [episode for sublist in episodes for episode in sublist][:num_eval_episodes] metrics = episode_utils.get_metrics(episodes) for key in sorted(metrics.keys()): print(key, ':', metrics[key]) save_path = os.path.join(eval_dir, 'episodes_eval.pkl') episode_utils.save(episodes, save_path) print('EVAL DONE') return common.initialize_uninitialized_variables(sess) sess.run(init_agent_op) sess.run(train_summary_writer.init()) sess.run(eval_summary_writer.init()) collect_time = 0 train_time = 0 timed_at_step = sess.run(global_step) steps_per_second_ph = tf.compat.v1.placeholder( tf.float32, shape=(), name='steps_per_sec_ph') steps_per_second_summary = tf.compat.v2.summary.scalar( name='global_steps_per_sec', data=steps_per_second_ph, step=global_step) global_step_val = sess.run(global_step) while sess.run(environment_steps_count) < num_environment_steps: global_step_val = sess.run(global_step) if global_step_val % eval_interval == 0: metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, log=True, ) with eval_summary_writer.as_default(), tf.compat.v2.summary.record_if(True): with tf.name_scope('Metrics/'): episodes = eval_py_env.get_stored_episodes() episodes = [episode for sublist in episodes for episode in sublist][:num_eval_episodes] metrics = episode_utils.get_metrics(episodes) for key in sorted(metrics.keys()): print(key, ':', metrics[key]) metric_op = tf.compat.v2.summary.scalar(name=key, data=metrics[key], step=global_step_val) sess.run(metric_op) sess.run(eval_summary_writer_flush_op) start_time = time.time() sess.run(collect_op) collect_time += time.time() - start_time start_time = time.time() total_loss, _ = sess.run([train_op, summary_ops]) 
train_time += time.time() - start_time global_step_val = sess.run(global_step) if global_step_val % log_interval == 0: logging.info('step = %d, loss = %f', global_step_val, total_loss) steps_per_sec = ( (global_step_val - timed_at_step) / (collect_time + train_time)) logging.info('%.3f steps/sec', steps_per_sec) sess.run( steps_per_second_summary, feed_dict={steps_per_second_ph: steps_per_sec}) logging.info('%s', 'collect_time = {}, train_time = {}'.format( collect_time, train_time)) timed_at_step = global_step_val collect_time = 0 train_time = 0 if global_step_val % train_checkpoint_interval == 0: train_checkpointer.save(global_step=global_step_val) if global_step_val % policy_checkpoint_interval == 0: policy_checkpointer.save(global_step=global_step_val) if global_step_val % rb_checkpoint_interval == 0: rb_checkpointer.save(global_step=global_step_val) # One final eval before exiting. metric_utils.compute_summaries( eval_metrics, eval_py_env, eval_py_policy, num_episodes=num_eval_episodes, global_step=global_step_val, callback=eval_metrics_callback, log=True, ) sess.run(eval_summary_writer_flush_op) sess.close()
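# The graph-mode dependency chain above guarantees the replay buffer is only
# cleared after the train op has run. A minimal TF1 sketch with stand-in ops
# (an assumed simplification; a counter plays the role of agent and buffer):
import tensorflow as tf

tf.compat.v1.disable_eager_execution()
counter = tf.compat.v1.get_variable('counter', initializer=0)
toy_train_op = counter.assign_add(1)        # stand-in for tf_agent.train(...)
with tf.control_dependencies([toy_train_op]):
    toy_clear_op = counter.assign_add(0)    # stand-in for replay_buffer.clear()
with tf.control_dependencies([toy_clear_op]):
    toy_step = tf.identity(counter)

with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    print(sess.run(toy_step))  # 1: train ran before both the clear and the read.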
def train_eval( root_dir, env_name='SocialBot-GroceryGround-v0', env_load_fn=suite_socialbot.load, random_seed=0, # TODO(b/127576522): rename to policy_fc_layers. actor_fc_layers=(192, 64), value_fc_layers=(192, 64), use_rnns=False, # Params for collect num_environment_steps=10000000, collect_episodes_per_iteration=8, num_parallel_environments=8, replay_buffer_capacity=2001, # Per-environment # Params for train num_epochs=16, learning_rate=1e-4, # Params for eval num_eval_episodes=10, eval_interval=500, # Params for summaries and logging log_interval=50, summary_interval=50, summaries_flush_secs=1, use_tf_functions=True, debug_summaries=False, summarize_grads_and_vars=False): """A simple train and eval for GroceryGround.""" # Set summary writer and eval metrics root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes) ] global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): # Create envs and optimizer tf.compat.v1.set_random_seed(random_seed) eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name)) tf_env = tf_py_environment.TFPyEnvironment( parallel_py_environment.ParallelPyEnvironment( [lambda: env_load_fn(env_name)] * num_parallel_environments)) optimizer = tf.compat.v1.train.AdamOptimizer( learning_rate=learning_rate) # Create actor and value network if use_rnns: actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork( tf_env.observation_spec(), tf_env.action_spec(), input_fc_layer_params=actor_fc_layers, output_fc_layer_params=None) value_net = value_rnn_network.ValueRnnNetwork( tf_env.observation_spec(), input_fc_layer_params=value_fc_layers, output_fc_layer_params=None) else: actor_net = actor_distribution_network.ActorDistributionNetwork( tf_env.observation_spec(), tf_env.action_spec(), fc_layer_params=actor_fc_layers) value_net = value_network.ValueNetwork( tf_env.observation_spec(), fc_layer_params=value_fc_layers) # Create ppo agent tf_agent = ppo_agent.PPOAgent( tf_env.time_step_spec(), tf_env.action_spec(), optimizer, actor_net=actor_net, value_net=value_net, num_epochs=num_epochs, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) tf_agent.initialize() # Create metrics, replay_buffer and collect_driver environment_steps_metric = tf_metrics.EnvironmentSteps() step_metrics = [ tf_metrics.NumberOfEpisodes(), environment_steps_metric, ] train_metrics = step_metrics + [ tf_metrics.AverageReturnMetric(), tf_metrics.AverageEpisodeLengthMetric(), ] eval_policy = tf_agent.policy collect_policy = tf_agent.collect_policy replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( tf_agent.collect_data_spec, batch_size=num_parallel_environments, max_length=replay_buffer_capacity) collect_driver = dynamic_episode_driver.DynamicEpisodeDriver( tf_env, collect_policy, observers=[replay_buffer.add_batch] + train_metrics, num_episodes=collect_episodes_per_iteration) if use_tf_functions: # TODO(b/123828980): Enable once 
the cause for slowdown was identified. collect_driver.run = common.function(collect_driver.run, autograph=False) tf_agent.train = common.function(tf_agent.train, autograph=False) collect_time = 0 train_time = 0 timed_at_step = global_step.numpy() # Evaluate and train while environment_steps_metric.result() < num_environment_steps: global_step_val = global_step.numpy() if global_step_val % eval_interval == 0: metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) start_time = time.time() collect_driver.run() collect_time += time.time() - start_time start_time = time.time() trajectories = replay_buffer.gather_all() total_loss, _ = tf_agent.train(experience=trajectories) replay_buffer.clear() train_time += time.time() - start_time for train_metric in train_metrics: train_metric.tf_summaries(train_step=global_step, step_metrics=step_metrics) if global_step_val % log_interval == 0: logging.info('step = %d, loss = %f', global_step_val, total_loss) steps_per_sec = ((global_step_val - timed_at_step) / (collect_time + train_time)) logging.info('%.3f steps/sec', steps_per_sec) logging.info('collect_time = {}, train_time = {}'.format( collect_time, train_time)) with tf.compat.v2.summary.record_if(True): tf.compat.v2.summary.scalar(name='global_steps_per_sec', data=steps_per_sec, step=global_step) timed_at_step = global_step_val collect_time = 0 train_time = 0
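# common.function above wraps the Python callables into TF graph functions;
# autograph=False skips source rewriting (see the TODO). A minimal sketch
# with a toy callable (the name is illustrative):
import tensorflow as tf
from tf_agents.utils import common

def toy_collect_step(x):
    return x + 1

graph_collect_step = common.function(toy_collect_step, autograph=False)
print(graph_collect_step(tf.constant(1)).numpy())  # 2, traced as a graph fn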
def train(): num_iterations=1000000 # Params for networks. actor_fc_layers=(128, 64) actor_output_fc_layers=(64,) actor_lstm_size=(32,) critic_obs_fc_layers=None critic_action_fc_layers=None critic_joint_fc_layers=(128,) critic_output_fc_layers=(64,) critic_lstm_size=(32,) num_parallel_environments=1 # Params for collect initial_collect_episodes=1 collect_episodes_per_iteration=1 replay_buffer_capacity=1000000 # Params for target update target_update_tau=0.05 target_update_period=5 # Params for train train_steps_per_iteration=1 batch_size=256 critic_learning_rate=3e-4 train_sequence_length=20 actor_learning_rate=3e-4 alpha_learning_rate=3e-4 td_errors_loss_fn=tf.math.squared_difference gamma=0.99 reward_scale_factor=0.1 gradient_clipping=None use_tf_functions=True # Params for eval num_eval_episodes=30 eval_interval=10000 debug_summaries=False summarize_grads_and_vars=False global_step = tf.compat.v1.train.get_or_create_global_step() time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork( observation_spec, action_spec, input_fc_layer_params=actor_fc_layers, lstm_size=actor_lstm_size, output_fc_layer_params=actor_output_fc_layers, continuous_projection_net=tanh_normal_projection_network .TanhNormalProjectionNetwork) critic_net = critic_rnn_network.CriticRnnNetwork( (observation_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, lstm_size=critic_lstm_size, output_fc_layer_params=critic_output_fc_layers, kernel_initializer='glorot_uniform', last_kernel_initializer='glorot_uniform') tf_agent = sac_agent.SacAgent( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_learning_rate), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_learning_rate), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step) tf_agent.initialize() replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=tf_env.batch_size, max_length=replay_buffer_capacity) replay_observer = [replay_buffer.add_batch] env_steps = tf_metrics.EnvironmentSteps(prefix='Train') eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy) initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), tf_env.action_spec()) collect_policy = tf_agent.collect_policy initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver( tf_env, initial_collect_policy, observers=replay_observer, num_episodes=initial_collect_episodes) collect_driver = dynamic_episode_driver.DynamicEpisodeDriver( tf_env, collect_policy, observers=replay_observer, num_episodes=collect_episodes_per_iteration) if use_tf_functions: initial_collect_driver.run = common.function(initial_collect_driver.run) collect_driver.run = common.function(collect_driver.run) tf_agent.train = common.function(tf_agent.train) if env_steps.result() == 0 or replay_buffer.num_frames() == 0: logging.info( 
'Initializing replay buffer by collecting experience for %d episodes ' 'with a random policy.', initial_collect_episodes) initial_collect_driver.run() time_step = None policy_state = collect_policy.get_initial_state(tf_env.batch_size) time_acc = 0 env_steps_before = env_steps.result().numpy() # Prepare replay buffer as dataset with invalid transitions filtered. def _filter_invalid_transition(trajectories, unused_arg1): # Reduce filter_fn over full trajectory sampled. The sequence is kept only # if all elements except for the last one pass the filter. This is to # allow training on terminal steps. return tf.reduce_all(~trajectories.is_boundary()[:-1]) dataset = replay_buffer.as_dataset( sample_batch_size=batch_size, num_steps=train_sequence_length+1).unbatch().filter( _filter_invalid_transition).batch(batch_size).prefetch(5) # Dataset generates trajectories with shape [B x train_sequence_length + 1 x ...] iterator = iter(dataset) def train_step(): experience, _ = next(iterator) return tf_agent.train(experience) if use_tf_functions: train_step = common.function(train_step) for _ in range(num_iterations): start_env_steps = env_steps.result() time_step, policy_state = collect_driver.run( time_step=time_step, policy_state=policy_state, ) episode_steps = env_steps.result() - start_env_steps # TODO(b/152648849) for _ in range(episode_steps): for _ in range(train_steps_per_iteration): train_step()
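# The loop above threads recurrent policy state through successive driver
# runs instead of resetting it each time. A runnable sketch with random
# stand-ins for the environment and policy (toy specs, illustrative names):
import tensorflow as tf
from tf_agents.drivers import dynamic_episode_driver
from tf_agents.environments import random_tf_environment
from tf_agents.policies import random_tf_policy
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

toy_obs_spec = tensor_spec.TensorSpec((2,), tf.float32)
toy_action_spec = tensor_spec.BoundedTensorSpec((), tf.int32, 0, 1)
toy_env = random_tf_environment.RandomTFEnvironment(
    ts.time_step_spec(toy_obs_spec), toy_action_spec, batch_size=1)
toy_policy = random_tf_policy.RandomTFPolicy(
    toy_env.time_step_spec(), toy_action_spec)
toy_driver = dynamic_episode_driver.DynamicEpisodeDriver(
    toy_env, toy_policy, observers=[], num_episodes=1)

time_step = None
policy_state = toy_policy.get_initial_state(toy_env.batch_size)
for _ in range(2):
    # Each run resumes from the previous time_step and recurrent state.
    time_step, policy_state = toy_driver.run(
        time_step=time_step, policy_state=policy_state)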
summaries_flush_secs = 1 use_tf_functions = True debug_summaries = False summarize_grads_and_vars = False initial_collect_steps = 100 importance_ratio_clipping = 0.2 global_step = tf.compat.v1.train.get_or_create_global_step() optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate) if use_rnns: actor_net = (actor_distribution_rnn_network.ActorDistributionRnnNetwork( tf_env.observation_spec(), tf_env.action_spec(), input_fc_layer_params=actor_fc_layers, output_fc_layer_params=None)) value_net = (value_rnn_network.ValueRnnNetwork( tf_env.observation_spec(), input_fc_layer_params=value_fc_layers, output_fc_layer_params=None)) else: actor_net = (actor_distribution_network.ActorDistributionNetwork( tf_env.observation_spec(), tf_env.action_spec(), fc_layer_params=actor_fc_layers, activation_fn=tf.keras.activations.tanh)) value_net = (value_network.ValueNetwork( tf_env.observation_spec(), fc_layer_params=value_fc_layers, activation_fn=tf.keras.activations.tanh))
def create_ppo_agent(env, global_step, FLAGS): actor_fc_layers = (512, 256) value_fc_layers = (512, 256) lstm_fc_input = (1024, 512) lstm_size = (256, ) lstm_fc_output = (256, 256) minimap_preprocessing = tf.keras.models.Sequential([ tf.keras.layers.Conv2D(filters=16, kernel_size=(5, 5), strides=(2, 2), activation='relu'), tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), strides=(2, 2), activation='relu'), tf.keras.layers.Flatten(), tf.keras.layers.Dense(units=256, activation='relu') ]) screen_preprocessing = tf.keras.models.Sequential([ tf.keras.layers.Conv2D(filters=16, kernel_size=(5, 5), strides=(2, 2), activation='relu'), tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), strides=(2, 2), activation='relu'), tf.keras.layers.Flatten(), tf.keras.layers.Dense(units=256, activation='relu') ]) info_preprocessing = tf.keras.models.Sequential([ tf.keras.layers.Dense(units=128, activation='relu'), tf.keras.layers.Dense(units=128, activation='relu') ]) entities_preprocessing = tf.keras.models.Sequential([ tf.keras.layers.Conv1D(filters=4, kernel_size=4, activation='relu'), tf.keras.layers.Flatten(), tf.keras.layers.Dense(units=256, activation='relu') ]) actor_preprocessing_layers = { 'minimap': minimap_preprocessing, 'screen': screen_preprocessing, 'info': info_preprocessing, 'entities': entities_preprocessing, } actor_preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1) if FLAGS.use_lstms: actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork( env.observation_spec(), env.action_spec(), preprocessing_layers=actor_preprocessing_layers, preprocessing_combiner=actor_preprocessing_combiner, input_fc_layer_params=lstm_fc_input, output_fc_layer_params=lstm_fc_output, lstm_size=lstm_size) else: actor_net = actor_distribution_network.ActorDistributionNetwork( input_tensor_spec=env.observation_spec(), output_tensor_spec=env.action_spec(), preprocessing_layers=actor_preprocessing_layers, preprocessing_combiner=actor_preprocessing_combiner, fc_layer_params=actor_fc_layers, activation_fn=tf.keras.activations.tanh) value_preprocessing_layers = { 'minimap': minimap_preprocessing, 'screen': screen_preprocessing, 'info': info_preprocessing, 'entities': entities_preprocessing, } value_preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1) if FLAGS.use_lstms: value_net = value_rnn_network.ValueRnnNetwork( env.observation_spec(), preprocessing_layers=value_preprocessing_layers, preprocessing_combiner=value_preprocessing_combiner, input_fc_layer_params=lstm_fc_input, output_fc_layer_params=lstm_fc_output, lstm_size=lstm_size) else: value_net = value_network.ValueNetwork( env.observation_spec(), preprocessing_layers=value_preprocessing_layers, preprocessing_combiner=value_preprocessing_combiner, fc_layer_params=value_fc_layers, activation_fn=tf.keras.activations.tanh) optimizer = tf.compat.v1.train.AdamOptimizer( learning_rate=FLAGS.learning_rate) # commented out values are the defaults tf_agent = my_ppo_agent.PPOAgent( time_step_spec=env.time_step_spec(), action_spec=env.action_spec(), optimizer=optimizer, actor_net=actor_net, value_net=value_net, importance_ratio_clipping=0.1, # lambda_value=0.95, discount_factor=0.95, entropy_regularization=0.003, # policy_l2_reg=0.0, # value_function_l2_reg=0.0, # shared_vars_l2_reg=0.0, # value_pred_loss_coef=0.5, num_epochs=FLAGS.num_epochs, use_gae=True, use_td_lambda_return=True, normalize_rewards=FLAGS.norm_rewards, reward_norm_clipping=0.0, normalize_observations=True, # log_prob_clipping=0.0, # KL from here... 
# To disable the fixed KL cutoff penalty, set the kl_cutoff_factor parameter to 0.0 kl_cutoff_factor=0.0, kl_cutoff_coef=0.0, # To disable the adaptive KL penalty, set the initial_adaptive_kl_beta parameter to 0.0 initial_adaptive_kl_beta=0.0, adaptive_kl_target=0.00, adaptive_kl_tolerance=0.0, # ...to here. # gradient_clipping=None, value_clipping=0.5, # check_numerics=False, # compute_value_and_advantage_in_train=True, # update_normalizers_in_train=True, # debug_summaries=False, # summarize_grads_and_vars=False, train_step_counter=global_step, # name='PPOClipAgent' ) tf_agent.initialize() return tf_agent
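# The dict-of-preprocessing-layers pattern above (one Keras model per
# observation key, outputs concatenated) also works standalone. A minimal
# sketch with toy specs via EncodingNetwork, the encoder that the
# actor/value networks use internally:
import tensorflow as tf
from tf_agents.networks import encoding_network
from tf_agents.specs import tensor_spec

toy_observation_spec = {
    'screen': tensor_spec.TensorSpec((4,), tf.float32),
    'info': tensor_spec.TensorSpec((3,), tf.float32),
}
toy_net = encoding_network.EncodingNetwork(
    toy_observation_spec,
    preprocessing_layers={
        'screen': tf.keras.layers.Dense(8),
        'info': tf.keras.layers.Dense(8),
    },
    preprocessing_combiner=tf.keras.layers.Concatenate(axis=-1),
    fc_layer_params=(16,))
toy_obs = {'screen': tf.ones((2, 4)), 'info': tf.ones((2, 3))}
encoded, _ = toy_net(toy_obs)
print(encoded.shape)  # (2, 16): both keys preprocessed, combined, encoded.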
def train_eval( root_dir, random_seed=0, num_epochs=1000000, # Params for train normalize_observations=True, normalize_rewards=True, discount_factor=1.0, lr=1e-5, lr_schedule=None, num_policy_updates=20, initial_adaptive_kl_beta=0.0, kl_cutoff_factor=0, importance_ratio_clipping=0.2, value_pred_loss_coef=0.5, gradient_clipping=None, entropy_regularization=0.0, log_prob_clipping=0.0, # Params for log, eval, save eval_interval=100, save_interval=1000, checkpoint_interval=None, summary_interval=100, do_evaluation=True, # Params for data collection train_batch_size=10, eval_batch_size=100, collect_driver=None, eval_driver=None, replay_buffer_capacity=20000, # Policy and value networks ActorNet=actor_distribution_network.ActorDistributionNetwork, zero_means_kernel_initializer=False, init_action_stddev=0.35, actor_fc_layers=(), value_fc_layers=(), use_rnn=True, actor_lstm_size=(12, ), value_lstm_size=(12, ), **kwargs): """ A simple train and eval for PPO agent. Args: root_dir (str): directory for saving training and evaluation data random_seed (int): seed for the random number generator num_epochs (int): number of training epochs. At each epoch a batch of data is collected according to one stochastic policy, and then the policy is updated. normalize_observations (bool): flag for normalization of observations. Uses StreamingTensorNormalizer which normalizes based on the whole history of observations. normalize_rewards (bool): flag for normalization of rewards. Uses StreamingTensorNormalizer which normalizes based on the whole history of rewards. discount_factor (float): reward discount factor, should be in (0, 1] lr (float): learning rate for Adam optimizer lr_schedule (callable: int -> float, optional): function to schedule the learning rate annealing. Takes as argument the int epoch number and returns the float value of the learning rate. num_policy_updates (int): number of policy gradient steps to do on each epoch of training. In PPO this is typically >1. initial_adaptive_kl_beta (float): see tf-agents PPO docs kl_cutoff_factor (float): see tf-agents PPO docs importance_ratio_clipping (float): clipping value for the importance ratio. Discourages updates that significantly change the policy. Should be in (0, 1] value_pred_loss_coef (float): weight coefficient for the quadratic value estimation loss. gradient_clipping (float): gradient clipping coefficient. entropy_regularization (float): entropy regularization loss coefficient. log_prob_clipping (float): +/- value for clipping log probs to prevent inf / NaN values. Default: no clipping. eval_interval (int): interval between evaluations, counted in epochs. save_interval (int): interval between savings, counted in epochs. It updates the log file and saves the deterministic policy. checkpoint_interval (int): interval between saving checkpoints, counted in epochs. Overwrites the previously saved one. Defaults to None, in which case checkpoints are not saved. summary_interval (int): interval between summary writing, counted in epochs. tf-agents takes care of summary writing; results can later be displayed in tensorboard. do_evaluation (bool): flag to interleave training epochs with evaluation epochs. train_batch_size (int): training batch size, collected in parallel. eval_batch_size (int): batch size for evaluation of the policy. collect_driver (Driver): driver for training data collection eval_driver (Driver): driver for evaluation data collection replay_buffer_capacity (int): how many transition tuples the buffer can store.
The buffer is emptied and re-populated at each epoch. ActorNet (network.DistributionNetwork): a distribution actor network to use for training. The default is ActorDistributionNetwork from tf-agents, but this can also be customized. zero_means_kernel_initializer (bool): flag to initialize the means projection network with zeros. If this flag is not set, it will use default tf-agent random initializer. init_action_stddev (float): initial stddev of the normal action dist. actor_fc_layers (tuple): sizes of fully connected layers in actor net. value_fc_layers (tuple): sizes of fully connected layers in value net. use_rnn (bool): whether to use LSTM units in the neural net. actor_lstm_size (tuple): sizes of LSTM layers in actor net. value_lstm_size (tuple): sizes of LSTM layers in value net. """ # -------------------------------------------------------------------- # -------------------------------------------------------------------- tf.compat.v1.set_random_seed(random_seed) # Setup directories within 'root_dir' if not os.path.isdir(root_dir): os.mkdir(root_dir) policy_dir = os.path.join(root_dir, 'policy') checkpoint_dir = os.path.join(root_dir, 'checkpoint') logfile = os.path.join(root_dir, 'log.hdf5') train_dir = os.path.join(root_dir, 'train_summaries') # Create tf summary writer train_summary_writer = tf.compat.v2.summary.create_file_writer(train_dir) train_summary_writer.set_as_default() summary_interval *= num_policy_updates global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): # Define action and observation specs observation_spec = collect_driver.observation_spec() action_spec = collect_driver.action_spec() # Preprocessing: flatten and concatenate observation components preprocessing_layers = { obs: tf.keras.layers.Flatten() for obs in observation_spec.keys() } preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1) # Define actor network and value network if use_rnn: actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork( input_tensor_spec=observation_spec, output_tensor_spec=action_spec, preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, input_fc_layer_params=None, lstm_size=actor_lstm_size, output_fc_layer_params=actor_fc_layers) value_net = value_rnn_network.ValueRnnNetwork( input_tensor_spec=observation_spec, preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, input_fc_layer_params=None, lstm_size=value_lstm_size, output_fc_layer_params=value_fc_layers) else: npn = actor_distribution_network._normal_projection_net normal_projection_net = lambda specs: npn( specs, zero_means_kernel_initializer=zero_means_kernel_initializer, init_action_stddev=init_action_stddev) actor_net = ActorNet( input_tensor_spec=observation_spec, output_tensor_spec=action_spec, preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, fc_layer_params=actor_fc_layers, continuous_projection_net=normal_projection_net) value_net = value_network.ValueNetwork( input_tensor_spec=observation_spec, preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, fc_layer_params=value_fc_layers) # Create PPO agent optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=lr) tf_agent = ppo_agent.PPOAgent( time_step_spec=collect_driver.time_step_spec(), action_spec=action_spec, optimizer=optimizer, actor_net=actor_net, value_net=value_net, 
num_epochs=num_policy_updates, train_step_counter=global_step, discount_factor=discount_factor, normalize_observations=normalize_observations, normalize_rewards=normalize_rewards, initial_adaptive_kl_beta=initial_adaptive_kl_beta, kl_cutoff_factor=kl_cutoff_factor, importance_ratio_clipping=importance_ratio_clipping, gradient_clipping=gradient_clipping, value_pred_loss_coef=value_pred_loss_coef, entropy_regularization=entropy_regularization, log_prob_clipping=log_prob_clipping, debug_summaries=True) tf_agent.initialize() eval_policy = tf_agent.policy collect_policy = tf_agent.collect_policy # Create replay buffer and collection driver replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=tf_agent.collect_data_spec, batch_size=train_batch_size, max_length=replay_buffer_capacity) def train_step(): experience = replay_buffer.gather_all() return tf_agent.train(experience) tf_agent.train = common.function(tf_agent.train) avg_return_metric = tf_metrics.AverageReturnMetric( batch_size=eval_batch_size, buffer_size=eval_batch_size) collect_driver.setup(collect_policy, [replay_buffer.add_batch]) eval_driver.setup(eval_policy, [avg_return_metric]) # Create a checkpointer and load the saved agent train_checkpointer = common.Checkpointer(ckpt_dir=checkpoint_dir, max_to_keep=1, agent=tf_agent, policy=tf_agent.policy, replay_buffer=replay_buffer, global_step=global_step) train_checkpointer.initialize_or_restore() global_step = tf.compat.v1.train.get_global_step() # Saver for the deterministic policy saved_model = policy_saver.PolicySaver(eval_policy, train_step=global_step) # Evaluate policy once before training if do_evaluation: eval_driver.run(0) avg_return = avg_return_metric.result().numpy() avg_return_metric.reset() log = { 'returns': [avg_return], 'epochs': [0], 'policy_steps': [0], 'experience_time': [0.0], 'train_time': [0.0] } print('-------------------') print('Epoch 0') print(' Policy steps: 0') print(' Experience time: 0.00 mins') print(' Policy train time: 0.00 mins') print(' Average return: %.5f' % avg_return) # Save initial random policy path = os.path.join(policy_dir, ('0').zfill(6)) saved_model.save(path) # Training loop train_timer = timer.Timer() experience_timer = timer.Timer() for epoch in range(1, num_epochs + 1): # Collect new experience experience_timer.start() collect_driver.run(epoch) experience_timer.stop() # Update the policy train_timer.start() if lr_schedule: optimizer._lr = lr_schedule(epoch) train_loss = train_step() replay_buffer.clear() train_timer.stop() if (epoch % eval_interval == 0) and do_evaluation: # Evaluate the policy eval_driver.run(epoch) avg_return = avg_return_metric.result().numpy() avg_return_metric.reset() # Print out and log all metrics print('-------------------') print('Epoch %d' % epoch) print(' Policy steps: %d' % (epoch * num_policy_updates)) print(' Experience time: %.2f mins' % (experience_timer.value() / 60)) print(' Policy train time: %.2f mins' % (train_timer.value() / 60)) print(' Average return: %.5f' % avg_return) log['epochs'].append(epoch) log['policy_steps'].append(epoch * num_policy_updates) log['returns'].append(avg_return) log['experience_time'].append(experience_timer.value()) log['train_time'].append(train_timer.value()) # Save updated log save_log(log, logfile, ('%d' % epoch).zfill(6)) if epoch % save_interval == 0: # Save deterministic policy path = os.path.join(policy_dir, ('%d' % epoch).zfill(6)) saved_model.save(path) if checkpoint_interval is not None and \ epoch % checkpoint_interval == 0: # Save training 
checkpoint train_checkpointer.save(global_step) collect_driver.finish_training() eval_driver.finish_training()
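# PolicySaver above exports the deterministic policy as a SavedModel that can
# be reloaded without reconstructing the agent. A minimal sketch with a random
# stand-in policy (toy specs; the export path is illustrative):
import tensorflow as tf
from tf_agents.policies import policy_saver, random_tf_policy
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

toy_obs_spec = tensor_spec.TensorSpec((2,), tf.float32)
toy_action_spec = tensor_spec.BoundedTensorSpec((), tf.int32, 0, 1)
toy_policy = random_tf_policy.RandomTFPolicy(
    ts.time_step_spec(toy_obs_spec), toy_action_spec)

saver = policy_saver.PolicySaver(toy_policy)
saver.save('/tmp/toy_policy_000000')
reloaded = tf.saved_model.load('/tmp/toy_policy_000000')  # exposes action()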
def testStatelessValueNetTrain(self, compute_value_and_advantage_in_train): counter = common.create_variable('test_train_counter') actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork( self._time_step_spec.observation, self._action_spec, input_fc_layer_params=None, output_fc_layer_params=None, lstm_size=(20,)) value_net = value_network.ValueNetwork( self._time_step_spec.observation, fc_layer_params=None) agent = ppo_agent.PPOAgent( self._time_step_spec, self._action_spec, optimizer=tf.compat.v1.train.AdamOptimizer(), actor_net=actor_net, value_net=value_net, num_epochs=1, train_step_counter=counter, compute_value_and_advantage_in_train=compute_value_and_advantage_in_train ) observations = tf.constant([ [[1, 2], [3, 4], [5, 6]], [[1, 2], [3, 4], [5, 6]], ], dtype=tf.float32) mid_time_step_val = ts.StepType.MID.tolist() time_steps = ts.TimeStep( step_type=tf.constant([[mid_time_step_val] * 3] * 2, dtype=tf.int32), reward=tf.constant([[1] * 3] * 2, dtype=tf.float32), discount=tf.constant([[1] * 3] * 2, dtype=tf.float32), observation=observations) actions = tf.constant([[[0], [1], [1]], [[0], [1], [1]]], dtype=tf.float32) action_distribution_parameters = { 'loc': tf.constant([[[0.0]] * 3] * 2, dtype=tf.float32), 'scale': tf.constant([[[1.0]] * 3] * 2, dtype=tf.float32), } value_preds = tf.constant([[9., 15., 21.], [9., 15., 21.]], dtype=tf.float32) policy_info = { 'dist_params': action_distribution_parameters, } if not compute_value_and_advantage_in_train: policy_info['value_prediction'] = value_preds experience = trajectory.Trajectory(time_steps.step_type, observations, actions, policy_info, time_steps.step_type, time_steps.reward, time_steps.discount) if not compute_value_and_advantage_in_train: experience = agent._preprocess(experience) if tf.executing_eagerly(): loss = lambda: agent.train(experience) else: loss = agent.train(experience) self.evaluate(tf.compat.v1.global_variables_initializer()) loss_type = self.evaluate(loss) loss_numpy = loss_type.loss # Assert that loss is not zero as we are training in a non-episodic env. self.assertNotEqual( loss_numpy, 0.0, msg=('Loss is exactly zero, looks like no training ' 'was performed due to incomplete episodes.'))
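# The test above assembles a Trajectory by calling the raw constructor;
# trajectory.from_transition is the usual shortcut for single transitions.
# A minimal sketch (toy batched tensors, illustrative values):
import tensorflow as tf
from tf_agents.trajectories import policy_step, time_step as ts, trajectory

toy_obs = tf.ones((2, 3))
first = ts.restart(toy_obs, batch_size=2)
step = policy_step.PolicyStep(action=tf.zeros((2, 1)), state=(), info=())
second = ts.transition(toy_obs, reward=tf.ones((2,)))
toy_traj = trajectory.from_transition(first, step, second)
print(toy_traj.step_type.shape)  # (2,): one batched transition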
def train_eval( root_dir, env_name='HalfCheetah-v2', env_load_fn=suite_mujoco.load, random_seed=0, # TODO(b/127576522): rename to policy_fc_layers. actor_fc_layers=(512, 256, 256, 30), value_fc_layers=(512, 256, 256, 25), use_rnns=False, # Params for collect num_environment_steps=10000000, collect_episodes_per_iteration=30, # was NumEpisodes, which is undefined; 30 assumed num_parallel_environments=1, replay_buffer_capacity=10000, # Per-environment # Params for train num_epochs=25, learning_rate=5e-4, # Params for eval num_eval_episodes=5, eval_interval=500, # Params for summaries and logging log_interval=50, summary_interval=50, summaries_flush_secs=1, use_tf_functions=True, debug_summaries=False, summarize_grads_and_vars=False): """A simple train and eval for PPO.""" if root_dir is None: raise AttributeError('train_eval requires a root_dir.') root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train6') eval_dir = os.path.join(root_dir, 'eval') train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes) ] global_step = tf.compat.v1.train.get_or_create_global_step() with tf.compat.v2.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): tf.compat.v1.set_random_seed(random_seed) #eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name)) #tf_env = tf_py_environment.TFPyEnvironment( # parallel_py_environment.ParallelPyEnvironment( # [lambda: env_load_fn(env_name)] * num_parallel_environments)) env = xSpace() if isinstance(env, py_environment.PyEnvironment): eval_tf_env = tf_py_environment.TFPyEnvironment(env) tf_env = tf_py_environment.TFPyEnvironment(env) print("Py Env") elif isinstance(env, tf_environment.TFEnvironment): eval_tf_env = env tf_env = env print("TF Env") optimizer = tf.compat.v1.train.AdamOptimizer( learning_rate=learning_rate) if use_rnns: actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork( tf_env.observation_spec(), tf_env.action_spec(), input_fc_layer_params=actor_fc_layers, output_fc_layer_params=None) value_net = value_rnn_network.ValueRnnNetwork( tf_env.observation_spec(), input_fc_layer_params=value_fc_layers, output_fc_layer_params=None) else: actor_net = actor_distribution_network.ActorDistributionNetwork( tf_env.observation_spec(), tf_env.action_spec(), fc_layer_params=actor_fc_layers) value_net = value_network.ValueNetwork( tf_env.observation_spec(), fc_layer_params=value_fc_layers) tf_agent = ppo_agent.PPOAgent( tf_env.time_step_spec(), tf_env.action_spec(), optimizer, lambda_value=0.98, discount_factor=0.995, #value_pred_loss_coef=0.005, use_gae=True, actor_net=actor_net, value_net=value_net, num_epochs=num_epochs, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=global_step, normalize_observations=False) tf_agent.initialize() print("************ INITIALIZING **********************") environment_steps_metric = tf_metrics.EnvironmentSteps() step_metrics = [ tf_metrics.NumberOfEpisodes(), environment_steps_metric, ] train_metrics = step_metrics + [ tf_metrics.AverageReturnMetric(), tf_metrics.AverageEpisodeLengthMetric(), ] eval_policy = tf_agent.policy collect_policy = tf_agent.collect_policy # this is for tensorboard
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( tf_agent.collect_data_spec, batch_size=num_parallel_environments, max_length=replay_buffer_capacity) collect_driver = dynamic_episode_driver.DynamicEpisodeDriver( tf_env, collect_policy, observers=[replay_buffer.add_batch] + train_metrics, num_episodes=collect_episodes_per_iteration) if use_tf_functions: # TODO(b/123828980): Enable once the cause for slowdown was identified. collect_driver.run = common.function(collect_driver.run, autograph=False) tf_agent.train = common.function(tf_agent.train, autograph=False) collect_time = 0 train_time = 0 timed_at_step = global_step.numpy() while environment_steps_metric.result() < num_environment_steps: global_step_val = global_step.numpy() eval_tf_env.reset() if global_step_val % eval_interval == 0: #tf_env.ResetMattData() metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', ) #print("eager compute completed") eval_tf_env.reset() start_time = time.time() collect_driver.run() #print("collect completed") collect_time += time.time() - start_time print("collect_time:" + str(collect_time)) start_time = time.time() trajectories = replay_buffer.gather_all() #print("start train completed") #pdb.set_trace() #k=trajectories[5] #xMean=tf.reduce_mean(k) print('training...') total_loss, _ = tf_agent.train(experience=trajectories) print('training complete. total loss:' + str(total_loss)) #print("end train completed") replay_buffer.clear() train_time += time.time() - start_time for train_metric in train_metrics: train_metric.tf_summaries(train_step=global_step, step_metrics=step_metrics) if global_step_val % log_interval == 0: logging.info('step = %d, loss = %f', global_step_val, total_loss) steps_per_sec = ((global_step_val - timed_at_step) / (collect_time + train_time)) logging.info('%.3f steps/sec', steps_per_sec) logging.info('collect_time = {}, train_time = {}'.format( collect_time, train_time)) with tf.compat.v2.summary.record_if(True): tf.compat.v2.summary.scalar(name='global_steps_per_sec', data=steps_per_sec, step=global_step) timed_at_step = global_step_val collect_time = 0 train_time = 0 # One final eval before exiting. metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, num_episodes=num_eval_episodes, train_step=global_step, summary_writer=eval_summary_writer, summary_prefix='Metrics', )
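# The PPO loop above relies on the on-policy buffer cycle: gather everything
# collected this iteration, train on it, then clear. A minimal runnable
# sketch with a toy data spec (illustrative names):
import tensorflow as tf
from tf_agents.replay_buffers import tf_uniform_replay_buffer

toy_spec = tf.TensorSpec((3,), tf.float32)
toy_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=toy_spec, batch_size=2, max_length=8)
toy_buffer.add_batch(tf.ones((2, 3)))

experience = toy_buffer.gather_all()      # everything collected so far
print(experience.shape)                   # (2, 1, 3): [batch, time, ...]
# ...train on `experience` here, as tf_agent.train(experience=...) above...
toy_buffer.clear()                        # start the next iteration empty
print(toy_buffer.num_frames().numpy())    # 0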
def construct_attention_networks(observation_spec, action_spec, use_rnns=True, actor_fc_layers=(200, 100), value_fc_layers=(200, 100), lstm_size=(128,), conv_filters=8, conv_kernel=3, scalar_fc=5, scalar_name='direction', scalar_dim=4): """Creates an actor and critic network designed for use with MultiGrid. A convolution layer processes the image and a dense layer processes the direction the agent is facing. These are fed into some fully connected layers and an LSTM. Args: observation_spec: A tf-agents observation spec. action_spec: A tf-agents action spec. use_rnns: If True, will construct RNN networks. actor_fc_layers: Dimension and number of fully connected layers in actor. value_fc_layers: Dimension and number of fully connected layers in critic. lstm_size: Number of cells in each LSTM layers. conv_filters: Number of convolution filters. conv_kernel: Size of the convolution kernel. scalar_fc: Number of neurons in the fully connected layer processing the scalar input. scalar_name: Name of the scalar input. scalar_dim: Highest possible value for the scalar input. Used to convert to one-hot representation. Returns: A tf-agents ActorDistributionRnnNetwork for the actor, and a ValueRnnNetwork for the critic. """ preprocessing_layers = { 'image': tf.keras.models.Sequential([ cast_and_scale(), tf.keras.layers.Conv2D(conv_filters, conv_kernel, padding='same'), tf.keras.layers.ReLU(), ]), 'policy_state': tf.keras.layers.Lambda(lambda x: x) } if scalar_name in observation_spec: preprocessing_layers[scalar_name] = tf.keras.models.Sequential( [one_hot_layer(scalar_dim), tf.keras.layers.Dense(scalar_fc)]) if 'position' in observation_spec: preprocessing_layers['position'] = tf.keras.models.Sequential( [cast_and_scale(), tf.keras.layers.Dense(scalar_fc)]) preprocessing_nest = tf.nest.map_structure(lambda l: None, preprocessing_layers) flat_observation_spec = nest_utils.flatten_up_to( preprocessing_nest, observation_spec, ) image_index_flat = flat_observation_spec.index(observation_spec['image']) network_state_index_flat = flat_observation_spec.index( observation_spec['policy_state']) image_shape = observation_spec['image'].shape # N x H x W x D preprocessing_combiner = AttentionCombinerConv(image_index_flat, network_state_index_flat, image_shape) if use_rnns: actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork( observation_spec, action_spec, preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, input_fc_layer_params=actor_fc_layers, output_fc_layer_params=None, lstm_size=lstm_size) value_net = value_rnn_network.ValueRnnNetwork( observation_spec, preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, input_fc_layer_params=value_fc_layers, output_fc_layer_params=None) else: actor_net = actor_distribution_network.ActorDistributionNetwork( observation_spec, action_spec, preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, fc_layer_params=actor_fc_layers, activation_fn=tf.keras.activations.tanh) value_net = value_network.ValueNetwork( observation_spec, preprocessing_layers=preprocessing_layers, preprocessing_combiner=preprocessing_combiner, fc_layer_params=value_fc_layers, activation_fn=tf.keras.activations.tanh) return actor_net, value_net
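# The scalar preprocessing above assumes one_hot_layer(scalar_dim) converts an
# integer-coded observation to a one-hot vector before the Dense layer. A
# minimal stand-in sketch of that behavior (assumed, since one_hot_layer is
# defined elsewhere in the source):
import tensorflow as tf

scalar_dim, scalar_fc = 4, 5
direction_preprocessing = tf.keras.models.Sequential([
    tf.keras.layers.Lambda(
        lambda x: tf.one_hot(tf.cast(x, tf.int32), depth=scalar_dim)),
    tf.keras.layers.Dense(scalar_fc),
])
print(direction_preprocessing(tf.constant([[1.0], [3.0]])).shape)  # (2, 1, 5)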
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    env_load_fn=suite_mujoco.load,
    random_seed=None,
    # TODO(b/127576522): rename to policy_fc_layers.
    actor_fc_layers=(200, 100),
    value_fc_layers=(200, 100),
    use_rnns=False,
    lstm_size=(20,),
    # Params for collect
    num_environment_steps=25000000,
    collect_episodes_per_iteration=30,
    num_parallel_environments=30,
    replay_buffer_capacity=1001,  # Per-environment
    # Params for train
    num_epochs=25,
    learning_rate=1e-3,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=500,
    # Params for summaries and logging
    train_checkpoint_interval=500,
    policy_checkpoint_interval=500,
    log_interval=50,
    summary_interval=50,
    summaries_flush_secs=1,
    use_tf_functions=True,
    debug_summaries=False,
    summarize_grads_and_vars=False):
  """A simple train and eval for PPO."""
  if root_dir is None:
    raise AttributeError('train_eval requires a root_dir.')

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')
  saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if random_seed is not None:
      tf.compat.v1.set_random_seed(random_seed)
    eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments))
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

    if use_rnns:
      actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          input_fc_layer_params=actor_fc_layers,
          output_fc_layer_params=None,
          lstm_size=lstm_size)
      value_net = value_rnn_network.ValueRnnNetwork(
          tf_env.observation_spec(),
          input_fc_layer_params=value_fc_layers,
          output_fc_layer_params=None)
    else:
      actor_net = actor_distribution_network.ActorDistributionNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          fc_layer_params=actor_fc_layers,
          activation_fn=tf.keras.activations.tanh)
      value_net = value_network.ValueNetwork(
          tf_env.observation_spec(),
          fc_layer_params=value_fc_layers,
          activation_fn=tf.keras.activations.tanh)

    tf_agent = ppo_clip_agent.PPOClipAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        optimizer,
        actor_net=actor_net,
        value_net=value_net,
        entropy_regularization=0.0,
        importance_ratio_clipping=0.2,
        normalize_observations=False,
        normalize_rewards=False,
        use_gae=True,
        num_epochs=num_epochs,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    environment_steps_metric = tf_metrics.EnvironmentSteps()
    step_metrics = [
        tf_metrics.NumberOfEpisodes(),
        environment_steps_metric,
    ]

    train_metrics = step_metrics + [
        tf_metrics.AverageReturnMetric(batch_size=num_parallel_environments),
        tf_metrics.AverageEpisodeLengthMetric(
            batch_size=num_parallel_environments),
    ]

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=num_parallel_environments,
        max_length=replay_buffer_capacity)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    saved_model = policy_saver.PolicySaver(eval_policy, train_step=global_step)

    train_checkpointer.initialize_or_restore()

    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration)

    def train_step():
      trajectories = replay_buffer.gather_all()
      return tf_agent.train(experience=trajectories)

    if use_tf_functions:
      # TODO(b/123828980): Enable once the cause for slowdown was identified.
      collect_driver.run = common.function(collect_driver.run, autograph=False)
      tf_agent.train = common.function(tf_agent.train, autograph=False)
      train_step = common.function(train_step)

    collect_time = 0
    train_time = 0
    timed_at_step = global_step.numpy()

    while environment_steps_metric.result() < num_environment_steps:
      global_step_val = global_step.numpy()
      if global_step_val % eval_interval == 0:
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )

      start_time = time.time()
      collect_driver.run()
      collect_time += time.time() - start_time

      start_time = time.time()
      total_loss, _ = train_step()
      replay_buffer.clear()
      train_time += time.time() - start_time

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=step_metrics)

      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val, total_loss)
        steps_per_sec = (
            (global_step_val - timed_at_step) / (collect_time + train_time))
        logging.info('%.3f steps/sec', steps_per_sec)
        logging.info('collect_time = %.3f, train_time = %.3f', collect_time,
                     train_time)
        with tf.compat.v2.summary.record_if(True):
          tf.compat.v2.summary.scalar(
              name='global_steps_per_sec', data=steps_per_sec,
              step=global_step)

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)
          saved_model_path = os.path.join(
              saved_model_dir, 'policy_' + ('%d' % global_step_val).zfill(9))
          saved_model.save(saved_model_path)

        timed_at_step = global_step_val
        collect_time = 0
        train_time = 0

    # One final eval before exiting.
    metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
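# A hedged sketch of driving the eager train_eval above from a binary entry
# point. The directory path, the step-count override, and the absl usage are
# illustrative assumptions, not part of the original module.


def main(_):
  logging.set_verbosity(logging.INFO)
  tf.compat.v1.enable_v2_behavior()  # The loop relies on eager .numpy() calls.
  train_eval(
      root_dir='/tmp/ppo_halfcheetah',  # Placeholder output directory.
      use_rnns=False,
      num_environment_steps=100000)     # Shortened run for illustration.


# To run:  from absl import app; app.run(main)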
def construct_multigrid_networks(observation_spec,
                                 action_spec,
                                 use_rnns=True,
                                 actor_fc_layers=(200, 100),
                                 value_fc_layers=(200, 100),
                                 lstm_size=(128,),
                                 conv_filters=8,
                                 conv_kernel=3,
                                 scalar_fc=5,
                                 scalar_name='direction',
                                 scalar_dim=4,
                                 random_z=False,
                                 xy_dim=None):
  """Creates an actor and critic network designed for use with MultiGrid.

  A convolution layer processes the image and a dense layer processes the
  direction the agent is facing. These are fed into some fully connected
  layers and an LSTM.

  Args:
    observation_spec: A tf-agents observation spec.
    action_spec: A tf-agents action spec.
    use_rnns: If True, will construct RNN networks.
    actor_fc_layers: Dimension and number of fully connected layers in actor.
    value_fc_layers: Dimension and number of fully connected layers in critic.
    lstm_size: Number of cells in each LSTM layer.
    conv_filters: Number of convolution filters.
    conv_kernel: Size of the convolution kernel.
    scalar_fc: Number of neurons in the fully connected layer processing the
      scalar input.
    scalar_name: Name of the scalar input.
    scalar_dim: Highest possible value for the scalar input. Used to convert
      to one-hot representation.
    random_z: If True, will provide an additional layer to process a randomly
      generated float input vector.
    xy_dim: If not None, will provide two additional layers to process 'x'
      and 'y' inputs. The dimension provided is the maximum value of x and y,
      and is used to create one-hot representation.

  Returns:
    A tf-agents ActorDistributionRnnNetwork for the actor, and a
    ValueRnnNetwork for the critic.
  """
  preprocessing_layers = {
      'image':
          tf.keras.models.Sequential([
              cast_and_scale(),
              tf.keras.layers.Conv2D(conv_filters, conv_kernel),
              tf.keras.layers.ReLU(),
              tf.keras.layers.Flatten()
          ]),
  }
  if scalar_name in observation_spec:
    preprocessing_layers[scalar_name] = tf.keras.models.Sequential(
        [one_hot_layer(scalar_dim),
         tf.keras.layers.Dense(scalar_fc)])
  if 'position' in observation_spec:
    preprocessing_layers['position'] = tf.keras.models.Sequential(
        [cast_and_scale(), tf.keras.layers.Dense(scalar_fc)])

  if random_z:
    preprocessing_layers['random_z'] = tf.keras.models.Sequential(
        [tf.keras.layers.Lambda(lambda x: x)])  # Identity layer
  if xy_dim is not None:
    preprocessing_layers['x'] = tf.keras.models.Sequential(
        [one_hot_layer(xy_dim)])
    preprocessing_layers['y'] = tf.keras.models.Sequential(
        [one_hot_layer(xy_dim)])

  preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1)

  if use_rnns:
    actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
        observation_spec,
        action_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner,
        input_fc_layer_params=actor_fc_layers,
        output_fc_layer_params=None,
        lstm_size=lstm_size)
    value_net = value_rnn_network.ValueRnnNetwork(
        observation_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner,
        input_fc_layer_params=value_fc_layers,
        output_fc_layer_params=None)
  else:
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner,
        fc_layer_params=actor_fc_layers,
        activation_fn=tf.keras.activations.tanh)
    value_net = value_network.ValueNetwork(
        observation_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner,
        fc_layer_params=value_fc_layers,
        activation_fn=tf.keras.activations.tanh)
  return actor_net, value_net
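# Illustrative usage of construct_multigrid_networks with the optional
# random_z and xy inputs enabled. All spec shapes and bounds here are
# assumptions standing in for a real MultiGrid environment's specs; with
# xy_dim=10, 'x' and 'y' values in [0, 9] map cleanly to one-hot vectors.


def example_multigrid_networks():
  observation_spec = {
      'image': tensor_spec.BoundedTensorSpec((5, 5, 3), tf.uint8, 0, 255),
      'direction': tensor_spec.BoundedTensorSpec((), tf.int32, 0, 3),
      'random_z': tensor_spec.TensorSpec((4,), tf.float32),
      'x': tensor_spec.BoundedTensorSpec((), tf.int32, 0, 9),
      'y': tensor_spec.BoundedTensorSpec((), tf.int32, 0, 9),
  }
  action_spec = tensor_spec.BoundedTensorSpec((), tf.int32, 0, 6)
  return construct_multigrid_networks(
      observation_spec, action_spec, use_rnns=True, random_z=True, xy_dim=10)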
def train_eval(
    root_dir,
    tf_master='',
    env_name='HalfCheetah-v2',
    env_load_fn=suite_mujoco.load,
    random_seed=0,
    # TODO(b/127576522): rename to policy_fc_layers.
    actor_fc_layers=(200, 100),
    value_fc_layers=(200, 100),
    use_rnns=False,
    # Params for collect
    num_environment_steps=10000000,
    collect_episodes_per_iteration=30,
    num_parallel_environments=30,
    replay_buffer_capacity=1001,  # Per-environment
    # Params for train
    num_epochs=25,
    learning_rate=1e-4,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=500,
    # Params for summaries and logging
    train_checkpoint_interval=100,
    policy_checkpoint_interval=50,
    rb_checkpoint_interval=200,
    log_interval=50,
    summary_interval=50,
    summaries_flush_secs=1,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for PPO."""
  if root_dir is None:
    raise AttributeError('train_eval requires a root_dir.')

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      batched_py_metric.BatchedPyMetric(
          AverageReturnMetric,
          metric_args={'buffer_size': num_eval_episodes},
          batch_size=num_parallel_environments),
      batched_py_metric.BatchedPyMetric(
          AverageEpisodeLengthMetric,
          metric_args={'buffer_size': num_eval_episodes},
          batch_size=num_parallel_environments),
  ]
  eval_summary_writer_flush_op = eval_summary_writer.flush()

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf.compat.v1.set_random_seed(random_seed)
    eval_py_env = parallel_py_environment.ParallelPyEnvironment(
        [lambda: env_load_fn(env_name)] * num_parallel_environments)
    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments))
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

    if use_rnns:
      actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          input_fc_layer_params=actor_fc_layers,
          output_fc_layer_params=None)
      value_net = value_rnn_network.ValueRnnNetwork(
          tf_env.observation_spec(),
          input_fc_layer_params=value_fc_layers,
          output_fc_layer_params=None)
    else:
      actor_net = actor_distribution_network.ActorDistributionNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          fc_layer_params=actor_fc_layers)
      value_net = value_network.ValueNetwork(
          tf_env.observation_spec(), fc_layer_params=value_fc_layers)

    tf_agent = ppo_agent.PPOAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        optimizer,
        actor_net=actor_net,
        value_net=value_net,
        num_epochs=num_epochs,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=num_parallel_environments,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

    environment_steps_metric = tf_metrics.EnvironmentSteps()
    environment_steps_count = environment_steps_metric.result()
    step_metrics = [
        tf_metrics.NumberOfEpisodes(),
        environment_steps_metric,
    ]
    train_metrics = step_metrics + [
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    # Add to replay buffer and other agent specific observers.
    replay_buffer_observer = [replay_buffer.add_batch]

    collect_policy = tf_agent.collect_policy

    collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=replay_buffer_observer + train_metrics,
        num_episodes=collect_episodes_per_iteration).run()

    trajectories = replay_buffer.gather_all()

    train_op, _ = tf_agent.train(experience=trajectories)

    with tf.control_dependencies([train_op]):
      clear_replay_op = replay_buffer.clear()

    with tf.control_dependencies([clear_replay_op]):
      train_op = tf.identity(train_op)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    for train_metric in train_metrics:
      train_metric.tf_summaries(
          train_step=global_step, step_metrics=step_metrics)

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(step_metrics=step_metrics)

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session(tf_master) as sess:
      # Initialize graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      common.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())

      collect_time = 0
      train_time = 0
      timed_at_step = sess.run(global_step)
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      # Note: tf.contrib is only available in TF 1.x, so this variant of
      # train_eval requires a TF 1.x runtime.
      steps_per_second_summary = tf.contrib.summary.scalar(
          name='global_steps/sec', tensor=steps_per_second_ph)

      while sess.run(environment_steps_count) < num_environment_steps:
        global_step_val = sess.run(global_step)
        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
              log=True,
          )
          sess.run(eval_summary_writer_flush_op)

        start_time = time.time()
        sess.run(collect_op)
        collect_time += time.time() - start_time
        start_time = time.time()
        total_loss = sess.run(train_op)
        train_time += time.time() - start_time

        global_step_val = sess.run(global_step)
        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val, total_loss)
          steps_per_sec = (
              (global_step_val - timed_at_step) / (collect_time + train_time))
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          logging.info('collect_time = %.3f, train_time = %.3f', collect_time,
                       train_time)
          timed_at_step = global_step_val
          collect_time = 0
          train_time = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

      # One final eval before exiting.
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
          log=True,
      )
      sess.run(eval_summary_writer_flush_op)
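# Unlike the eager variants elsewhere in this document, the train_eval above
# wires the whole pipeline into a TF1 graph (collect_op, train_op) and drives
# it with sess.run. A sketch of invoking it in a clean graph, assuming a
# placeholder directory path; disable_eager_execution is only needed on
# runtimes where eager mode is on by default.


def run_graph_mode_train_eval():
  tf.compat.v1.disable_eager_execution()  # Enforce TF1 graph-mode semantics.
  tf.compat.v1.reset_default_graph()
  train_eval(root_dir='/tmp/ppo_graph_mode')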
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    env_load_fn=suite_gym.load,
    random_seed=None,
    max_ep_steps=1000,
    # TODO(b/127576522): rename to policy_fc_layers.
    actor_fc_layers=(200, 100),
    value_fc_layers=(200, 100),
    use_rnns=False,
    # Params for collect
    num_environment_steps=5000000,
    collect_episodes_per_iteration=1,
    num_parallel_environments=1,
    replay_buffer_capacity=10000,  # Per-environment
    # Params for train
    num_epochs=25,
    learning_rate=1e-3,
    # Params for eval
    num_eval_episodes=10,
    num_random_episodes=1,
    eval_interval=500,
    # Params for summaries and logging
    train_checkpoint_interval=500,
    policy_checkpoint_interval=500,
    rb_checkpoint_interval=20000,
    log_interval=50,
    summary_interval=50,
    summaries_flush_secs=10,
    use_tf_functions=True,
    debug_summaries=False,
    eval_metrics_callback=None,
    random_metrics_callback=None,
    summarize_grads_and_vars=False):
  """A simple train and eval for PPO, with a random-policy baseline."""
  # Set up the directories that contain the log data and model saves.
  # If data already exists in these folders, we will try to load it later.
  if root_dir is None:
    raise AttributeError('train_eval requires a root_dir.')

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')
  random_dir = os.path.join(root_dir, 'random')
  saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

  # Create writers for logging, and specify the metrics to log for each.
  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  random_summary_writer = tf.compat.v2.summary.create_file_writer(
      random_dir, flush_millis=summaries_flush_secs * 1000)
  random_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()

  # Set up the agent and train, recording data every summary_interval steps.
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if random_seed is not None:
      tf.compat.v1.set_random_seed(random_seed)

    # Load the environments. Here, the same environment is used for
    # evaluation and training, but they could differ.
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        env_load_fn(env_name, max_episode_steps=max_ep_steps))
    # A parallel alternative for collection:
    # tf_env = tf_py_environment.TFPyEnvironment(
    #     parallel_py_environment.ParallelPyEnvironment(
    #         [lambda: env_load_fn(env_name, max_episode_steps=max_ep_steps)]
    #         * num_parallel_environments))
    tf_env = tf_py_environment.TFPyEnvironment(
        env_load_fn(env_name, max_episode_steps=max_ep_steps))

    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

    if use_rnns:
      actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          input_fc_layer_params=actor_fc_layers,
          output_fc_layer_params=None)
      value_net = value_rnn_network.ValueRnnNetwork(
          tf_env.observation_spec(),
          input_fc_layer_params=value_fc_layers,
          output_fc_layer_params=None)
    else:
      actor_net = actor_distribution_network.ActorDistributionNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          fc_layer_params=actor_fc_layers,
          activation_fn=tf.keras.activations.tanh)
      value_net = value_network.ValueNetwork(
          tf_env.observation_spec(),
          fc_layer_params=value_fc_layers,
          activation_fn=tf.keras.activations.tanh)

    tf_agent = ppo_agent.PPOAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        optimizer,
        actor_net=actor_net,
        value_net=value_net,
        entropy_regularization=0.0,
        importance_ratio_clipping=0.2,
        normalize_observations=False,
        normalize_rewards=False,
        use_gae=True,
        kl_cutoff_factor=0.0,
        initial_adaptive_kl_beta=0.0,
        num_epochs=num_epochs,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    environment_steps_metric = tf_metrics.EnvironmentSteps()
    step_metrics = [
        tf_metrics.NumberOfEpisodes(),
        environment_steps_metric,
    ]

    train_metrics = step_metrics + [
        tf_metrics.AverageReturnMetric(batch_size=num_parallel_environments),
        tf_metrics.AverageEpisodeLengthMetric(
            batch_size=num_parallel_environments),
    ]

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=num_parallel_environments,
        max_length=replay_buffer_capacity)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)
    saved_model = policy_saver.PolicySaver(eval_policy, train_step=global_step)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration)

    def train_step():
      trajectories = replay_buffer.gather_all()
      return tf_agent.train(experience=trajectories)

    if use_tf_functions:
      # TODO(b/123828980): Enable once the cause for slowdown was identified.
      collect_driver.run = common.function(collect_driver.run, autograph=False)
      tf_agent.train = common.function(tf_agent.train, autograph=False)
      train_step = common.function(train_step)

    random_policy = random_tf_policy.RandomTFPolicy(
        eval_tf_env.time_step_spec(), eval_tf_env.action_spec())

    collect_time = 0
    train_time = 0
    timed_at_step = global_step.numpy()

    while environment_steps_metric.result() < num_environment_steps:
      global_step_val = global_step.numpy()
      if global_step_val % eval_interval == 0:
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        metric_utils.eager_compute(
            random_metrics,
            eval_tf_env,
            random_policy,
            num_episodes=num_random_episodes,
            train_step=global_step,
            summary_writer=random_summary_writer,
            summary_prefix='Metrics',
        )

      start_time = time.time()
      collect_driver.run()
      collect_time += time.time() - start_time

      start_time = time.time()
      total_loss, _ = train_step()
      replay_buffer.clear()
      train_time += time.time() - start_time

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=step_metrics)

      if global_step_val % log_interval == 0:
        logging.info('Step: %6d\tLoss: %+20.4f', global_step_val, total_loss)
        steps_per_sec = (
            (global_step_val - timed_at_step) / (collect_time + train_time))
        logging.info('%6.3f steps/sec', steps_per_sec)
        logging.info('collect_time = %.3f, train_time = %.3f', collect_time,
                     train_time)
        with tf.compat.v2.summary.record_if(True):
          tf.compat.v2.summary.scalar(
              name='global_steps_per_sec', data=steps_per_sec,
              step=global_step)

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)
          saved_model_path = os.path.join(
              saved_model_dir, 'policy_' + ('%d' % global_step_val).zfill(9))
          saved_model.save(saved_model_path)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        timed_at_step = global_step_val
        collect_time = 0
        train_time = 0

    # One final eval before exiting.
    metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
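# This CartPole variant additionally tracks a random-policy baseline so the
# learned policy's returns can be compared against chance. A standalone
# sketch of that baseline evaluation, assuming the tf_agents modules imported
# above; the helper name is illustrative.


def evaluate_random_baseline(eval_tf_env, num_episodes=1):
  """Returns the average return of a uniform-random policy on eval_tf_env."""
  random_policy = random_tf_policy.RandomTFPolicy(
      eval_tf_env.time_step_spec(), eval_tf_env.action_spec())
  metric = tf_metrics.AverageReturnMetric(buffer_size=num_episodes)
  # eager_compute returns a dict of results keyed by metric name.
  results = metric_utils.eager_compute(
      [metric], eval_tf_env, random_policy, num_episodes=num_episodes)
  return results['AverageReturn']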