Example #1
  def testBuilds(self):
    observation_spec = tensor_spec.BoundedTensorSpec((8, 8, 3), tf.float32, 0,
                                                     1)
    time_step_spec = ts.time_step_spec(observation_spec)
    time_step = tensor_spec.sample_spec_nest(time_step_spec, outer_dims=(1,))

    action_spec = [
        tensor_spec.BoundedTensorSpec((2,), tf.float32, 2, 3),
        tensor_spec.BoundedTensorSpec((3,), tf.int32, 0, 3)
    ]

    net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
        observation_spec,
        action_spec,
        conv_layer_params=[(4, 2, 2)],
        input_fc_layer_params=(5,),
        output_fc_layer_params=(5,),
        lstm_size=(3,))

    action_distributions, network_state = net(
        time_step.observation,
        time_step.step_type,
        net.get_initial_state(batch_size=1))
    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.assertEqual([1, 2], action_distributions[0].mode().shape.as_list())
    self.assertEqual([1, 3], action_distributions[1].mode().shape.as_list())

    self.assertEqual(14, len(net.variables))
    # Conv Net Kernel
    self.assertEqual((2, 2, 3, 4), net.variables[0].shape)
    # Conv Net bias
    self.assertEqual((4,), net.variables[1].shape)
    # Fc Kernel
    self.assertEqual((64, 5), net.variables[2].shape)
    # Fc Bias
    self.assertEqual((5,), net.variables[3].shape)
    # LSTM Cell Kernel
    self.assertEqual((5, 12), net.variables[4].shape)
    # LSTM Cell Recurrent Kernel
    self.assertEqual((3, 12), net.variables[5].shape)
    # LSTM Cell Bias
    self.assertEqual((12,), net.variables[6].shape)
    # Fc Kernel
    self.assertEqual((3, 5), net.variables[7].shape)
    # Fc Bias
    self.assertEqual((5,), net.variables[8].shape)
    # Normal Projection Kernel
    self.assertEqual((5, 2), net.variables[9].shape)
    # Normal Projection Bias
    self.assertEqual((2,), net.variables[10].shape)
    # Normal Projection STD Bias layer
    self.assertEqual((2,), net.variables[11].shape)
    # Categorical Projection Kernel
    self.assertEqual((5, 12), net.variables[12].shape)
    # Categorical Projection Bias
    self.assertEqual((12,), net.variables[13].shape)

    # Assert LSTM cell is created.
    self.assertEqual((1, 3), network_state[0].shape)
    self.assertEqual((1, 3), network_state[1].shape)
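The test snippets on this page omit their imports. A minimal sketch of the imports they assume (module paths follow the TF-Agents package layout; the exact aliases used in the original test files may differ):

import tensorflow as tf

from tf_agents.networks import actor_distribution_rnn_network
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts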
Example #2
  def testRunsWithLstmStack(self):
    observation_spec = tensor_spec.BoundedTensorSpec((8, 8, 3), tf.float32, 0,
                                                     1)
    time_step_spec = ts.time_step_spec(observation_spec)
    time_step = tensor_spec.sample_spec_nest(time_step_spec, outer_dims=(1, 5))

    action_spec = [
        tensor_spec.BoundedTensorSpec((2,), tf.float32, 2, 3),
        tensor_spec.BoundedTensorSpec((3,), tf.int32, 0, 3)
    ]

    net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
        observation_spec,
        action_spec,
        conv_layer_params=[(4, 2, 2)],
        input_fc_layer_params=(5,),
        output_fc_layer_params=(5,),
        lstm_size=(3, 3))

    initial_state = actor_policy.ActorPolicy(time_step_spec, action_spec,
                                             net).get_initial_state(1)

    net_call = net(time_step.observation, time_step.step_type,
                   initial_state)

    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.evaluate(tf.nest.map_structure(lambda d: d.sample(), net_call[0]))
Example #3
  def testHandlePreprocessingLayers(self):
    observation_spec = (tensor_spec.TensorSpec([1], tf.float32),
                        tensor_spec.TensorSpec([], tf.float32))
    time_step_spec = ts.time_step_spec(observation_spec)
    time_step = tensor_spec.sample_spec_nest(time_step_spec, outer_dims=(3, 4))

    action_spec = [
        tensor_spec.BoundedTensorSpec((2,), tf.float32, 2, 3),
        tensor_spec.BoundedTensorSpec((3,), tf.int32, 0, 3)
    ]

    preprocessing_layers = (tf.keras.layers.Dense(4),
                            sequential_layer.SequentialLayer([
                                tf.keras.layers.Reshape((1,)),
                                tf.keras.layers.Dense(4)
                            ]))

    net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
        observation_spec,
        action_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=tf.keras.layers.Add())

    initial_state = actor_policy.ActorPolicy(time_step_spec, action_spec,
                                             net).get_initial_state(3)

    action_distributions, _ = net(time_step.observation, time_step.step_type,
                                  initial_state)

    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.assertEqual([3, 4, 2], action_distributions[0].mode().shape.as_list())
    self.assertEqual([3, 4, 3], action_distributions[1].mode().shape.as_list())
    self.assertGreater(len(net.trainable_variables), 4)
Example #4
  def testTrainWithRnn(self):
    actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
        self._obs_spec,
        self._action_spec,
        input_fc_layer_params=None,
        output_fc_layer_params=None,
        conv_layer_params=None,
        lstm_size=(40,),
    )

    critic_net = critic_rnn_network.CriticRnnNetwork(
        (self._obs_spec, self._action_spec),
        observation_fc_layer_params=(16,),
        action_fc_layer_params=(16,),
        joint_fc_layer_params=(16,),
        lstm_size=(16,),
        output_fc_layer_params=None,
    )

    counter = common.create_variable('test_train_counter')

    optimizer_fn = tf.compat.v1.train.AdamOptimizer

    agent = sac_agent.SacAgent(
        self._time_step_spec,
        self._action_spec,
        critic_network=critic_net,
        actor_network=actor_net,
        actor_optimizer=optimizer_fn(1e-3),
        critic_optimizer=optimizer_fn(1e-3),
        alpha_optimizer=optimizer_fn(1e-3),
        train_step_counter=counter,
    )

    batch_size = 5
    observations = tf.constant(
        [[[1, 2], [3, 4], [5, 6]]] * batch_size, dtype=tf.float32)
    actions = tf.constant([[[0], [1], [1]]] * batch_size, dtype=tf.float32)
    time_steps = ts.TimeStep(
        step_type=tf.constant([[1] * 3] * batch_size, dtype=tf.int32),
        reward=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
        discount=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
        observation=[observations])

    experience = trajectory.Trajectory(
        time_steps.step_type, [observations], actions, (),
        time_steps.step_type, time_steps.reward, time_steps.discount)

    # Force variable creation.
    agent.policy.variables()
    if tf.executing_eagerly():
      loss = lambda: agent.train(experience)
    else:
      loss = agent.train(experience)

    self.evaluate(tf.compat.v1.initialize_all_variables())
    self.assertEqual(self.evaluate(counter), 0)
    self.evaluate(loss)
    self.assertEqual(self.evaluate(counter), 1)
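The positional Trajectory above packs the fields (step_type, observation, action, policy_info, next_step_type, reward, discount). For readability, the same construction can be written with keyword arguments; a sketch:

experience = trajectory.Trajectory(
    step_type=time_steps.step_type,
    observation=[observations],
    action=actions,
    policy_info=(),
    next_step_type=time_steps.step_type,
    reward=time_steps.reward,
    discount=time_steps.discount)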
Example #5
    def testRNNTrain(self, compute_value_and_advantage_in_train):
        actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
            self._time_step_spec.observation,
            self._action_spec,
            input_fc_layer_params=None,
            output_fc_layer_params=None,
            lstm_size=(20, ))
        value_net = value_rnn_network.ValueRnnNetwork(
            self._time_step_spec.observation,
            input_fc_layer_params=None,
            output_fc_layer_params=None,
            lstm_size=(10, ))
        global_step = tf.compat.v1.train.get_or_create_global_step()
        agent = ppo_agent.PPOAgent(
            self._time_step_spec,
            self._action_spec,
            optimizer=tf.compat.v1.train.AdamOptimizer(),
            actor_net=actor_net,
            value_net=value_net,
            num_epochs=1,
            train_step_counter=global_step,
            compute_value_and_advantage_in_train=
            compute_value_and_advantage_in_train)
        # Use a random env, policy, and replay buffer to collect training data.
        random_env = random_tf_environment.RandomTFEnvironment(
            self._time_step_spec, self._action_spec, batch_size=1)
        collection_policy = random_tf_policy.RandomTFPolicy(
            self._time_step_spec,
            self._action_spec,
            info_spec=agent.collect_policy.info_spec)
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            collection_policy.trajectory_spec, batch_size=1, max_length=7)
        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            random_env,
            collection_policy,
            observers=[replay_buffer.add_batch],
            num_episodes=1)

        # In graph mode: finish building the graph so the optimizer
        # variables are created.
        if not tf.executing_eagerly():
            _, _ = agent.train(experience=replay_buffer.gather_all())

        # Initialize.
        self.evaluate(agent.initialize())
        self.evaluate(tf.compat.v1.global_variables_initializer())

        # Train one step.
        self.assertEqual(0, self.evaluate(global_step))
        self.evaluate(collect_driver.run())
        self.evaluate(agent.train(experience=replay_buffer.gather_all()))
        self.assertEqual(1, self.evaluate(global_step))
Example #6
    def testTrainWithRnn(self):
        with tf.compat.v2.summary.record_if(False):
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                self._obs_spec,
                self._action_spec,
                input_fc_layer_params=None,
                output_fc_layer_params=None,
                conv_layer_params=None,
                lstm_size=(40, ))

            counter = common.create_variable('test_train_counter')
            agent = reinforce_agent.ReinforceAgent(
                self._time_step_spec,
                self._action_spec,
                actor_network=actor_net,
                optimizer=tf.compat.v1.train.AdamOptimizer(0.001),
                train_step_counter=counter)

            batch_size = 5
            observations = tf.constant([[[1, 2], [3, 4], [5, 6]]] * batch_size,
                                       dtype=tf.float32)
            time_steps = ts.TimeStep(
                step_type=tf.constant([[1, 1, 2]] * batch_size,
                                      dtype=tf.int32),
                reward=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
                discount=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
                observation=observations)
            actions = tf.constant([[[0], [1], [1]]] * batch_size,
                                  dtype=tf.float32)

            experience = trajectory.Trajectory(time_steps.step_type,
                                               observations, actions, (),
                                               time_steps.step_type,
                                               time_steps.reward,
                                               time_steps.discount)

            # Force variable creation.
            agent.policy.variables()

            if tf.executing_eagerly():
                loss = lambda: agent.train(experience)
            else:
                loss = agent.train(experience)

            self.evaluate(tf.compat.v1.initialize_all_variables())
            self.assertEqual(self.evaluate(counter), 0)
            self.evaluate(loss)
            self.assertEqual(self.evaluate(counter), 1)
Example #7
  def testTrainWithRnnTransitions(self):
    actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
        self._obs_spec,
        self._action_spec,
        input_fc_layer_params=None,
        output_fc_layer_params=None,
        conv_layer_params=None,
        lstm_size=(40,))

    counter = common.create_variable('test_train_counter')
    agent = reinforce_agent.ReinforceAgent(
        self._time_step_spec,
        self._action_spec,
        actor_network=actor_net,
        optimizer=tf.compat.v1.train.AdamOptimizer(0.001),
        train_step_counter=counter
    )

    batch_size = 5
    observations = tf.constant(
        [[[1, 2], [3, 4], [5, 6]]] * batch_size, dtype=tf.float32)
    time_steps = ts.TimeStep(
        step_type=tf.constant([[1, 1, 1]] * batch_size, dtype=tf.int32),
        reward=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
        discount=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
        observation=observations)
    actions = policy_step.PolicyStep(
        tf.constant([[[0], [1], [1]]] * batch_size, dtype=tf.float32))
    next_time_steps = ts.TimeStep(
        step_type=tf.constant([[1, 1, 2]] * batch_size, dtype=tf.int32),
        reward=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
        discount=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
        observation=observations)

    experience = trajectory.Transition(time_steps, actions, next_time_steps)

    agent.initialize()
    agent.train(experience)
Example #8
def load_agents_and_create_videos(root_dir,
                                  env_name='CartPole-v0',
                                  env_load_fn=suite_gym.load,
                                  random_seed=None,
                                  max_ep_steps=1000,
                                  # TODO(b/127576522): rename to policy_fc_layers.
                                  actor_fc_layers=(200, 100),
                                  value_fc_layers=(200, 100),
                                  use_rnns=False,
                                  # Params for collect
                                  num_environment_steps=5000000,
                                  collect_episodes_per_iteration=1,
                                  num_parallel_environments=1,
                                  replay_buffer_capacity=10000,    # Per-environment
                                  # Params for train
                                  num_epochs=25,
                                  learning_rate=1e-3,
                                  # Params for eval
                                  num_eval_episodes=10,
                                  num_random_episodes=1,
                                  eval_interval=500,
                                  # Params for summaries and logging
                                  train_checkpoint_interval=500,
                                  policy_checkpoint_interval=500,
                                  rb_checkpoint_interval=20000,
                                  log_interval=50,
                                  summary_interval=50,
                                  summaries_flush_secs=10,
                                  use_tf_functions=True,
                                  debug_summaries=False,
                                  eval_metrics_callback=None,
                                  random_metrics_callback=None,
                                  summarize_grads_and_vars=False):
    
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    random_dir = os.path.join(root_dir, 'random')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
            train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
    
    eval_metrics = [
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
            tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)]

    random_summary_writer = tf.compat.v2.summary.create_file_writer(
            random_dir, flush_millis=summaries_flush_secs * 1000)
    
    random_metrics = [
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
            tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)]


    global_step = tf.compat.v1.train.get_or_create_global_step()

    if random_seed is not None:
        tf.compat.v1.set_random_seed(random_seed)
    
    eval_py_env = env_load_fn(env_name, max_episode_steps=max_ep_steps)
    eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)
#     tf_env = tf_py_environment.TFPyEnvironment(
#             parallel_py_environment.ParallelPyEnvironment(
#                      [lambda: env_load_fn(env_name, max_episode_steps=max_ep_steps)] * num_parallel_environments))

    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name, max_episode_steps=max_ep_steps))

    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

    if use_rnns:
        actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None)
        
        value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
    
    else:
        actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers,
                activation_fn=tf.keras.activations.tanh)
        
        value_net = value_network.ValueNetwork(
                tf_env.observation_spec(),
                fc_layer_params=value_fc_layers,
                activation_fn=tf.keras.activations.tanh)

    tf_agent = ppo_agent.PPOAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            entropy_regularization=0.0,
            importance_ratio_clipping=0.2,
            normalize_observations=False,
            normalize_rewards=False,
            use_gae=True,
            kl_cutoff_factor=0.0,
            initial_adaptive_kl_beta=0.0,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
    
    tf_agent.initialize()

    environment_steps_metric = tf_metrics.EnvironmentSteps()
    
    step_metrics = [tf_metrics.NumberOfEpisodes(),
                    environment_steps_metric,]

    train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                    batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                    batch_size=num_parallel_environments),]

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

    train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    
    policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
    
    rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)
    
    saved_model = policy_saver.PolicySaver(
            eval_policy, train_step=global_step)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

    # if use_tf_functions:
    #     # To speed up collect use common.function.
    #     collect_driver.run = common.function(collect_driver.run)
    #     tf_agent.train = common.function(tf_agent.train)

    # initial_collect_policy = random_tf_policy.RandomTFPolicy(
    #         tf_env.time_step_spec(), tf_env.action_spec())

    random_policy = random_tf_policy.RandomTFPolicy(eval_tf_env.time_step_spec(),
                                                    eval_tf_env.action_spec())

    # Make movies of the trained agent and a random agent
    date_string = datetime.datetime.now().strftime('%Y-%m-%d_%H%M%S')
    
    trained_filename = "trainedPPO_" + date_string
    create_policy_eval_video(eval_tf_env, eval_py_env, tf_agent.policy, trained_filename)

    random_filename = 'random_' + date_string
    create_policy_eval_video(eval_tf_env, eval_py_env, random_policy, random_filename)
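create_policy_eval_video is not defined in this snippet. A minimal sketch of such a helper, modeled on the TF-Agents tutorial pattern and assuming imageio plus a renderable eval_py_env, could look like this:

import imageio

def create_policy_eval_video(tf_env, py_env, policy, filename, num_episodes=1, fps=30):
    # Roll out `policy` in the TF environment and record frames rendered by the
    # underlying Python environment into an MP4 file.
    with imageio.get_writer(filename + '.mp4', fps=fps) as video:
        for _ in range(num_episodes):
            time_step = tf_env.reset()
            video.append_data(py_env.render())
            while not time_step.is_last():
                action_step = policy.action(time_step)
                time_step = tf_env.step(action_step.action)
                video.append_data(py_env.render())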
Example #9
    def testTrainWithRnn(self, cql_alpha, num_cql_samples,
                         include_critic_entropy_term, use_lagrange_cql_alpha,
                         expected_loss):
        actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
            self._obs_spec,
            self._action_spec,
            input_fc_layer_params=None,
            output_fc_layer_params=None,
            conv_layer_params=None,
            lstm_size=(40, ),
        )

        critic_net = critic_rnn_network.CriticRnnNetwork(
            (self._obs_spec, self._action_spec),
            observation_fc_layer_params=(16, ),
            action_fc_layer_params=(16, ),
            joint_fc_layer_params=(16, ),
            lstm_size=(16, ),
            output_fc_layer_params=None,
        )

        counter = common.create_variable('test_train_counter')

        optimizer_fn = tf.compat.v1.train.AdamOptimizer

        agent = cql_sac_agent.CqlSacAgent(
            self._time_step_spec,
            self._action_spec,
            critic_network=critic_net,
            actor_network=actor_net,
            actor_optimizer=optimizer_fn(1e-3),
            critic_optimizer=optimizer_fn(1e-3),
            alpha_optimizer=optimizer_fn(1e-3),
            cql_alpha=cql_alpha,
            num_cql_samples=num_cql_samples,
            include_critic_entropy_term=include_critic_entropy_term,
            use_lagrange_cql_alpha=use_lagrange_cql_alpha,
            random_seed=self._random_seed,
            train_step_counter=counter,
        )

        batch_size = 5
        observations = tf.constant([[[1, 2], [3, 4], [5, 6]]] * batch_size,
                                   dtype=tf.float32)
        actions = tf.constant([[[0], [1], [1]]] * batch_size, dtype=tf.float32)
        time_steps = ts.TimeStep(step_type=tf.constant([[1] * 3] * batch_size,
                                                       dtype=tf.int32),
                                 reward=tf.constant([[1] * 3] * batch_size,
                                                    dtype=tf.float32),
                                 discount=tf.constant([[1] * 3] * batch_size,
                                                      dtype=tf.float32),
                                 observation=observations)

        experience = trajectory.Trajectory(time_steps.step_type, observations,
                                           actions, (), time_steps.step_type,
                                           time_steps.reward,
                                           time_steps.discount)

        # Force variable creation.
        agent.policy.variables()

        if not tf.executing_eagerly():
            # Get experience first to make sure optimizer variables are created and
            # can be initialized.
            experience = agent.train(experience)
            with self.cached_session() as sess:
                common.initialize_uninitialized_variables(sess)
            self.assertEqual(self.evaluate(counter), 0)
            self.evaluate(experience)
            self.assertEqual(self.evaluate(counter), 1)
        else:
            self.assertEqual(self.evaluate(counter), 0)
            loss = self.evaluate(agent.train(experience))
            self.assertAllClose(loss.loss, expected_loss)
            self.assertEqual(self.evaluate(counter), 1)
Example #10
def construct_multigrid_networks(
    observation_spec,
    action_spec,
    use_rnns=True,
    actor_fc_layers=(200, 100),
    value_fc_layers=(200, 100),
    lstm_size=(128, ),
    conv_filters=8,
    conv_kernel=3,
    scalar_fc=5,
    scalar_name="direction",
    scalar_dim=4,
    use_stacks=False,
):
    """Creates an actor and critic network designed for use with MultiGrid.

  A convolution layer processes the image and a dense layer processes the
  direction the agent is facing. These are fed into some fully connected layers
  and an LSTM.

  Args:
    observation_spec: A tf-agents observation spec.
    action_spec: A tf-agents action spec.
    use_rnns: If True, will construct RNN networks.
    actor_fc_layers: Dimension and number of fully connected layers in actor.
    value_fc_layers: Dimension and number of fully connected layers in critic.
    lstm_size: Number of cells in each LSTM layer.
    conv_filters: Number of convolution filters.
    conv_kernel: Size of the convolution kernel.
    scalar_fc: Number of neurons in the fully connected layer processing the
      scalar input.
    scalar_name: Name of the scalar input.
    scalar_dim: Highest possible value for the scalar input. Used to convert to
      one-hot representation.
    use_stacks: Use ResNet stacks (compresses the image).

  Returns:
    A tf-agents actor network and a value network: ActorDistributionRnnNetwork and
    ValueRnnNetwork when `use_rnns` is True, otherwise their feed-forward counterparts.
  """

    preprocessing_layers = {
        "policy_state": tf.keras.layers.Lambda(lambda x: x)
    }
    if use_stacks:
        preprocessing_layers["image"] = tf.keras.models.Sequential([
            multigrid_networks.cast_and_scale(),
            _Stack(conv_filters // 2, 2),
            _Stack(conv_filters, 2),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Flatten()
        ])
    else:
        preprocessing_layers["image"] = tf.keras.models.Sequential([
            multigrid_networks.cast_and_scale(),
            tf.keras.layers.Conv2D(conv_filters, conv_kernel, padding="same"),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Flatten()
        ])
    if scalar_name in observation_spec:
        preprocessing_layers[scalar_name] = tf.keras.models.Sequential([
            multigrid_networks.one_hot_layer(scalar_dim),
            tf.keras.layers.Dense(scalar_fc)
        ])
    if "position" in observation_spec:
        preprocessing_layers["position"] = tf.keras.models.Sequential([
            multigrid_networks.cast_and_scale(),
            tf.keras.layers.Dense(scalar_fc)
        ])

    preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1)

    custom_objects = {"_Stack": _Stack}
    with tf.keras.utils.custom_object_scope(custom_objects):
        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                observation_spec,
                action_spec,
                preprocessing_layers=preprocessing_layers,
                preprocessing_combiner=preprocessing_combiner,
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None,
                lstm_size=lstm_size)
            value_net = value_rnn_network.ValueRnnNetwork(
                observation_spec,
                preprocessing_layers=preprocessing_layers,
                preprocessing_combiner=preprocessing_combiner,
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                observation_spec,
                action_spec,
                preprocessing_layers=preprocessing_layers,
                preprocessing_combiner=preprocessing_combiner,
                fc_layer_params=actor_fc_layers,
                activation_fn=tf.keras.activations.tanh)
            value_net = value_network.ValueNetwork(
                observation_spec,
                preprocessing_layers=preprocessing_layers,
                preprocessing_combiner=preprocessing_combiner,
                fc_layer_params=value_fc_layers,
                activation_fn=tf.keras.activations.tanh)

    return actor_net, value_net
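A hypothetical call site for the helper above, assuming `tf_env` is a TFPyEnvironment wrapping a MultiGrid environment whose dict observation contains 'image', 'direction', and 'policy_state' entries:

actor_net, value_net = construct_multigrid_networks(
    tf_env.observation_spec(),
    tf_env.action_spec(),
    use_rnns=True,
    lstm_size=(128,))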
Example #11
def train_eval(
        root_dir,
        env_name=None,
        env_load_fn=suite_mujoco.load,
        random_seed=0,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        inference_fc_layers=(200, 100),
        use_rnns=None,
        dim_z=4,
        categorical=True,
        # Params for collect
        num_environment_steps=10000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-4,
        entropy_regularization=None,
        kl_posteriors_penalty=None,
        mock_inference=None,
        mock_reward=None,
        l2_distance=None,
        rl_steps=None,
        inference_steps=None,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=1000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=10000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=1,
        use_tf_functions=True,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """A simple train and eval for PPO."""
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf.compat.v1.set_random_seed(random_seed)

        def _env_load_fn(env_name):
            diayn_wrapper = (
                lambda x: diayn_gym_env.DiaynGymEnv(x, dim_z, categorical))
            return env_load_fn(
                env_name,
                gym_env_wrappers=[diayn_wrapper],
            )

        eval_tf_env = tf_py_environment.TFPyEnvironment(_env_load_fn(env_name))
        if num_parallel_environments == 1:
            py_env = _env_load_fn(env_name)
        else:
            py_env = parallel_py_environment.ParallelPyEnvironment(
                [lambda: _env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(py_env)

        augmented_time_step_spec = tf_env.time_step_spec()
        augmented_observation_spec = augmented_time_step_spec.observation
        observation_spec = augmented_observation_spec['observation']
        z_spec = augmented_observation_spec['z']
        reward_spec = augmented_time_step_spec.reward
        action_spec = tf_env.action_spec()
        time_step_spec = ts.time_step_spec(observation_spec)
        infer_from_com = False
        if env_name == "AntRandGoalEval-v1":
            infer_from_com = True
        if infer_from_com:
            input_inference_spec = tspec.BoundedTensorSpec(
                shape=[2],
                dtype=tf.float64,
                minimum=-1.79769313e+308,
                maximum=1.79769313e+308,
                name='body_com')
        else:
            input_inference_spec = observation_spec

        if tensor_spec.is_discrete(z_spec):
            _preprocessing_combiner = OneHotConcatenateLayer(dim_z)
        else:
            _preprocessing_combiner = DictConcatenateLayer()

        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                augmented_observation_spec,
                action_spec,
                preprocessing_combiner=_preprocessing_combiner,
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None)
            value_net = value_rnn_network.ValueRnnNetwork(
                augmented_observation_spec,
                preprocessing_combiner=_preprocessing_combiner,
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                augmented_observation_spec,
                action_spec,
                preprocessing_combiner=_preprocessing_combiner,
                fc_layer_params=actor_fc_layers,
                name="actor_net")
            value_net = value_network.ValueNetwork(
                augmented_observation_spec,
                preprocessing_combiner=_preprocessing_combiner,
                fc_layer_params=value_fc_layers,
                name="critic_net")
        inference_net = actor_distribution_network.ActorDistributionNetwork(
            input_tensor_spec=input_inference_spec,
            output_tensor_spec=z_spec,
            fc_layer_params=inference_fc_layers,
            continuous_projection_net=normal_projection_net,
            name="inference_net")

        tf_agent = ppo_diayn_agent.PPODiaynAgent(
            augmented_time_step_spec,
            action_spec,
            z_spec,
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            inference_net=inference_net,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step,
            entropy_regularization=entropy_regularization,
            kl_posteriors_penalty=kl_posteriors_penalty,
            mock_inference=mock_inference,
            mock_reward=mock_reward,
            infer_from_com=infer_from_com,
            l2_distance=l2_distance,
            rl_steps=rl_steps,
            inference_steps=inference_steps)
        tf_agent.initialize()

        environment_steps_metric = tf_metrics.EnvironmentSteps()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]

        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_environments),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        actor_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'diayn_actor'),
                                                 actor_net=actor_net,
                                                 global_step=global_step)
        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'diayn_policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'diayn_replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)
        inference_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'diayn_inference'),
            inference_net=inference_net,
            global_step=global_step)

        actor_checkpointer.initialize_or_restore()
        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()
        inference_checkpointer.initialize_or_restore()

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        # option_length = 200
        # if env_name == "Plane-v1":
        #     option_length = 10
        # dataset = replay_buffer.as_dataset(
        #         num_parallel_calls=3, sample_batch_size=num_parallel_environments,
        #         num_steps=option_length)
        # iterator_dataset = iter(dataset)

        def train_step():
            trajectories = replay_buffer.gather_all()
            #   trajectories, _ = next(iterator_dataset)
            return tf_agent.train(experience=trajectories)

        if use_tf_functions:
            # TODO(b/123828980): Enable once the cause for slowdown was identified.
            collect_driver.run = common.function(collect_driver.run,
                                                 autograph=False)
            tf_agent.train = common.function(tf_agent.train, autograph=False)
            train_step = common.function(train_step)

        collect_time = 0
        train_time = 0
        timed_at_step = global_step.numpy()

        while environment_steps_metric.result() < num_environment_steps:
            global_step_val = global_step.numpy()
            if global_step_val % eval_interval == 0:
                metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )

            start_time = time.time()
            collect_driver.run()
            collect_time += time.time() - start_time

            start_time = time.time()
            total_loss, _ = train_step()
            replay_buffer.clear()
            train_time += time.time() - start_time

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             total_loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = {}, train_time = {}'.format(
                    collect_time, train_time))
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)
                    inference_checkpointer.save(global_step=global_step_val)
                    actor_checkpointer.save(global_step=global_step_val)
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)
                    saved_model_path = os.path.join(
                        saved_model_dir,
                        'policy_' + ('%d' % global_step_val).zfill(9))
                    saved_model.save(saved_model_path)

                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

        # One final eval before exiting.
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
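normal_projection_net, passed as continuous_projection_net above, is not shown in this snippet. A typical factory in TF-Agents examples wraps NormalProjectionNetwork; the version below is a sketch of what the script might define, not its actual settings:

from tf_agents.networks import normal_projection_network

def normal_projection_net(action_spec, init_means_output_factor=0.1):
    # A sketch only; the original script's projection settings may differ.
    return normal_projection_network.NormalProjectionNetwork(
        action_spec,
        init_means_output_factor=init_means_output_factor,
        std_transform=tf.nn.softplus,  # assumption: keep the learned std positive
        scale_distribution=True)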
Example #12
def train_eval(
        root_dir,
        # env_name='HalfCheetah-v2',
        # env_load_fn=suite_mujoco.load,
        env_load_fn=None,
        random_seed=0,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        # Params for collect
        num_environment_steps=int(1e7),
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=500,
        policy_checkpoint_interval=500,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        use_tf_functions=True,  # use_tf_functions=False,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf.compat.v1.set_random_seed(random_seed)

        # eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        # tf_env = tf_py_environment.TFPyEnvironment(
        #     parallel_py_environment.ParallelPyEnvironment(
        #         [lambda: env_load_fn(env_name)] * num_parallel_environments))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.wrap_env(RectEnv()))
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: suite_gym.wrap_env(RectEnv())] *
                num_parallel_environments))

        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        preprocessing_layers = {
            'target':
            tf.keras.models.Sequential([
                # tf.keras.applications.MobileNetV2(
                #     input_shape=(64, 64, 1), include_top=False, weights=None),
                # tf.keras.layers.Conv2D(1, 6),
                easy.encoder((CANVAS_WIDTH, CANVAS_WIDTH, 1)),
                tf.keras.layers.Flatten()
            ]),
            'canvas':
            tf.keras.models.Sequential([
                # tf.keras.applications.MobileNetV2(
                #     input_shape=(64, 64, 1), include_top=False, weights=None),
                # tf.keras.layers.Conv2D(1, 6),
                easy.encoder((CANVAS_WIDTH, CANVAS_WIDTH, 1)),
                tf.keras.layers.Flatten()
            ]),
            'coord':
            tf.keras.models.Sequential([
                tf.keras.layers.Dense(64),
                tf.keras.layers.Dense(64),
                tf.keras.layers.Flatten()
            ])
        }
        preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers,
                preprocessing_layers=preprocessing_layers,
                preprocessing_combiner=preprocessing_combiner)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(),
                fc_layer_params=value_fc_layers,
                preprocessing_layers=preprocessing_layers,
                preprocessing_combiner=preprocessing_combiner)

        tf_agent = ppo_agent.PPOAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        environment_steps_metric = tf_metrics.EnvironmentSteps()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]

        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_environments),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            max_to_keep=5,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  max_to_keep=5,
                                                  policy=eval_policy,
                                                  global_step=global_step)
        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=global_step)

        train_checkpointer.initialize_or_restore()

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        def train_step():
            trajectories = replay_buffer.gather_all()
            return tf_agent.train(experience=trajectories)

        if use_tf_functions:
            # TODO(b/123828980): Enable once the cause for slowdown was identified.
            collect_driver.run = common.function(collect_driver.run,
                                                 autograph=False)
            tf_agent.train = common.function(tf_agent.train, autograph=False)
            train_step = common.function(train_step)

        collect_time = 0
        train_time = 0
        timed_at_step = global_step.numpy()

        while environment_steps_metric.result() < num_environment_steps:
            global_step_val = global_step.numpy()
            if global_step_val % eval_interval == 0:
                metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )

            start_time = time.time()
            collect_driver.run()
            collect_time += time.time() - start_time

            start_time = time.time()
            total_loss, _ = train_step()
            replay_buffer.clear()
            train_time += time.time() - start_time

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             total_loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = {}, train_time = {}'.format(
                    collect_time, train_time))
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)
                    saved_model_path = os.path.join(
                        saved_model_dir,
                        'policy_' + ('%d' % global_step_val).zfill(9))
                    saved_model.save(saved_model_path)

                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

        # One final eval before exiting.
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
Example #13
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        num_iterations=1000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Params for collect
        initial_collect_steps=1,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=0,
        eval_interval=1000000,
        # Params for summaries and logging
        train_checkpoint_interval=100,
        policy_checkpoint_interval=200,
        rb_checkpoint_interval=300,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1000,
        debug_summaries=True,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for SAC."""
    global interrupted
    dir_key = root_dir
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    # eval_summary_writer = tf.compat.v2.summary.create_file_writer(
    #     eval_dir, flush_millis=summaries_flush_secs * 1000)
    # eval_metrics = [
    #     tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
    #     tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    # ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):

        train_env = T4TFEnv(metrics_key=dir_key)
        eval_env = T4TFEnv(fake=True)
        tf_env = tf_py_environment.TFPyEnvironment(train_env)
        eval_tf_env = tf_py_environment.TFPyEnvironment(eval_env)

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
            observation_spec,
            action_spec,
            #             fc_layer_params=actor_fc_layers,
            conv_layer_params=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
            lstm_size=(8, ),
            normal_projection_net=normal_projection_net)
        critic_net = critic_rnn_network.CriticRnnNetwork(
            (observation_spec, action_spec),
            observation_conv_layer_params=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
            lstm_size=(8, ),
            #             observation_fc_layer_params=critic_obs_fc_layers,
            #             action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)
        print(actor_net)
        print(critic_net)
        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
            tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer,
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps '
            'with the collect policy.', initial_collect_steps)
        initial_collect_driver.run()

        # results = metric_utils.eager_compute(
        #     eval_metrics,
        #     eval_tf_env,
        #     eval_policy,
        #     num_episodes=num_eval_episodes,
        #     train_step=global_step,
        #     summary_writer=eval_summary_writer,
        #     summary_prefix='Metrics',
        # )
        # if eval_metrics_callback is not None:
        #     eval_metrics_callback(results, global_step.numpy())
        # metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=1,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)
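        # Each element yielded by the iterator below is a (Trajectory, BufferInfo)
        # pair whose tensors have shape [batch_size, 2, ...], i.e. one (t, t+1)
        # transition per batch entry. Illustrative check only:
        #   sample, _ = next(iterator)
        #   assert sample.observation.shape[1] == 2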
        iterator = iter(dataset)

        for _ in range(num_iterations):
            if interrupted:
                train_env.interrupted = True

            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                if interrupted:
                    train_env.interrupted = True

                experience, _ = next(iterator)
                train_loss = tf_agent.train(experience)
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:

                # actor_net.save(os.path.join(root_dir, 'actor'), save_format='tf')
                # critic_net.save(os.path.join(root_dir, 'critic'), save_format='tf')

                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)

                obs_shape = time_step.observation.shape
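                # The reshape below assumes the batched observation is a single
                # RGB frame of shape [1, height, width, 3] so that it can be
                # logged as an image summary.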
                tf.compat.v2.summary.image(
                    name='input_image',
                    data=np.reshape(time_step.observation,
                                    (1, obs_shape[1], obs_shape[2], 3)),
                    step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)

            global_step_val = global_step.numpy()

            print('global step: %d' % global_step_val)
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
        return train_loss


def train_eval(
    root_dir,
    experiment_name,  # experiment name
    env_name='carla-v0',
    agent_name='sac',  # agent's name
    num_iterations=int(1e7),
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    model_network_ctor_type='non-hierarchical',  # model net
    input_names=['camera', 'lidar'],  # names for inputs
    mask_names=['birdeye'],  # names for masks
    preprocessing_combiner=tf.keras.layers.Add(),  # takes a flat list of tensors and combines them
    actor_lstm_size=(40, ),  # lstm size for actor
    critic_lstm_size=(40, ),  # lstm size for critic
    actor_output_fc_layers=(100, ),  # lstm output
    critic_output_fc_layers=(100, ),  # lstm output
    epsilon_greedy=0.1,  # exploration parameter for DQN
    q_learning_rate=1e-3,  # q learning rate for DQN
    ou_stddev=0.2,  # exploration parameter for DDPG
    ou_damping=0.15,  # exploration parameter for DDPG
    dqda_clipping=None,  # for DDPG
    exploration_noise_std=0.1,  # exploration parameter for td3
    actor_update_period=2,  # for td3
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=int(1e5),
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    initial_model_train_steps=100000,  # initial model training
    batch_size=256,
    model_batch_size=32,  # model training batch size
    sequence_length=4,  # number of timesteps to train model
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    model_learning_rate=1e-4,  # learning rate for model training
    td_errors_loss_fn=tf.losses.mean_squared_error,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for summaries and logging
    num_images_per_summary=1,  # images for each summary
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    gpu_allow_growth=True,  # GPU memory growth
    gpu_memory_limit=None,  # GPU memory limit
    action_repeat=1
):  # valid observation channels: ['camera', 'lidar', 'birdeye']
    # Setup GPU
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpu_allow_growth:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    if gpu_memory_limit:
        for gpu in gpus:
            tf.config.experimental.set_virtual_device_configuration(
                gpu, [
                    tf.config.experimental.VirtualDeviceConfiguration(
                        memory_limit=gpu_memory_limit)
                ])
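    # Note: memory_limit above is specified in megabytes.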

    # Get train and eval directories
    root_dir = os.path.expanduser(root_dir)
    root_dir = os.path.join(root_dir, env_name, experiment_name)

    # Get summary writers
    summary_writer = tf.summary.create_file_writer(
        root_dir, flush_millis=summaries_flush_secs * 1000)
    summary_writer.set_as_default()

    # Eval metrics
    eval_metrics = [
        tf_metrics.AverageReturnMetric(name='AverageReturnEvalPolicy',
                                       buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(
            name='AverageEpisodeLengthEvalPolicy',
            buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()

    # Whether to record for summary
    with tf.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Create Carla environment
        if agent_name == 'latent_sac':
            py_env, eval_py_env = load_carla_env(env_name='carla-v0',
                                                 obs_channels=input_names +
                                                 mask_names,
                                                 action_repeat=action_repeat)
        elif agent_name == 'dqn':
            py_env, eval_py_env = load_carla_env(env_name='carla-v0',
                                                 discrete=True,
                                                 obs_channels=input_names,
                                                 action_repeat=action_repeat)
        else:
            py_env, eval_py_env = load_carla_env(env_name='carla-v0',
                                                 obs_channels=input_names,
                                                 action_repeat=action_repeat)

        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)
        fps = int(np.round(1.0 / (py_env.dt * action_repeat)))

        # Specs
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        # Make tf agent
        if agent_name == 'latent_sac':
            # Get model network for latent sac
            if model_network_ctor_type == 'hierarchical':
                model_network_ctor = sequential_latent_network.SequentialLatentModelHierarchical
            elif model_network_ctor_type == 'non-hierarchical':
                model_network_ctor = sequential_latent_network.SequentialLatentModelNonHierarchical
            else:
                raise NotImplementedError
            model_net = model_network_ctor(input_names,
                                           input_names + mask_names)

            # Get the latent spec
            latent_size = model_net.latent_size
            latent_observation_spec = tensor_spec.TensorSpec((latent_size, ),
                                                             dtype=tf.float32)
            latent_time_step_spec = ts.time_step_spec(
                observation_spec=latent_observation_spec)

            # Get actor and critic net
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                latent_observation_spec,
                action_spec,
                fc_layer_params=actor_fc_layers,
                continuous_projection_net=normal_projection_net)
            critic_net = critic_network.CriticNetwork(
                (latent_observation_spec, action_spec),
                observation_fc_layer_params=critic_obs_fc_layers,
                action_fc_layer_params=critic_action_fc_layers,
                joint_fc_layer_params=critic_joint_fc_layers)

            # Build the inner SAC agent based on latent space
            inner_agent = sac_agent.SacAgent(
                latent_time_step_spec,
                action_spec,
                actor_network=actor_net,
                critic_network=critic_net,
                actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=actor_learning_rate),
                critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=critic_learning_rate),
                alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=alpha_learning_rate),
                target_update_tau=target_update_tau,
                target_update_period=target_update_period,
                td_errors_loss_fn=td_errors_loss_fn,
                gamma=gamma,
                reward_scale_factor=reward_scale_factor,
                gradient_clipping=gradient_clipping,
                debug_summaries=debug_summaries,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step)
            inner_agent.initialize()

            # Build the latent sac agent
            tf_agent = latent_sac_agent.LatentSACAgent(
                time_step_spec,
                action_spec,
                inner_agent=inner_agent,
                model_network=model_net,
                model_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=model_learning_rate),
                model_batch_size=model_batch_size,
                num_images_per_summary=num_images_per_summary,
                sequence_length=sequence_length,
                gradient_clipping=gradient_clipping,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step,
                fps=fps)

        else:
            # Set up preprocessing layers for dictionary observation inputs
            preprocessing_layers = collections.OrderedDict()
            for name in input_names:
                preprocessing_layers[name] = Preprocessing_Layer(32, 256)
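            # Preprocessing_Layer is a custom encoder defined elsewhere in this
            # file (not a TF-Agents class); each observation channel gets its own
            # encoder instance.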
            if len(input_names) < 2:
                preprocessing_combiner = None

            if agent_name == 'dqn':
                q_rnn_net = q_rnn_network.QRnnNetwork(
                    observation_spec,
                    action_spec,
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    input_fc_layer_params=critic_joint_fc_layers,
                    lstm_size=critic_lstm_size,
                    output_fc_layer_params=critic_output_fc_layers)

                tf_agent = dqn_agent.DqnAgent(
                    time_step_spec,
                    action_spec,
                    q_network=q_rnn_net,
                    epsilon_greedy=epsilon_greedy,
                    n_step_update=1,
                    target_update_tau=target_update_tau,
                    target_update_period=target_update_period,
                    optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=q_learning_rate),
                    td_errors_loss_fn=common.element_wise_squared_loss,
                    gamma=gamma,
                    reward_scale_factor=reward_scale_factor,
                    gradient_clipping=gradient_clipping,
                    debug_summaries=debug_summaries,
                    summarize_grads_and_vars=summarize_grads_and_vars,
                    train_step_counter=global_step)

            elif agent_name == 'ddpg' or agent_name == 'td3':
                actor_rnn_net = multi_inputs_actor_rnn_network.MultiInputsActorRnnNetwork(
                    observation_spec,
                    action_spec,
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    input_fc_layer_params=actor_fc_layers,
                    lstm_size=actor_lstm_size,
                    output_fc_layer_params=actor_output_fc_layers)

                critic_rnn_net = multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork(
                    (observation_spec, action_spec),
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    action_fc_layer_params=critic_action_fc_layers,
                    joint_fc_layer_params=critic_joint_fc_layers,
                    lstm_size=critic_lstm_size,
                    output_fc_layer_params=critic_output_fc_layers)

                if agent_name == 'ddpg':
                    tf_agent = ddpg_agent.DdpgAgent(
                        time_step_spec,
                        action_spec,
                        actor_network=actor_rnn_net,
                        critic_network=critic_rnn_net,
                        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=actor_learning_rate),
                        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=critic_learning_rate),
                        ou_stddev=ou_stddev,
                        ou_damping=ou_damping,
                        target_update_tau=target_update_tau,
                        target_update_period=target_update_period,
                        dqda_clipping=dqda_clipping,
                        td_errors_loss_fn=None,
                        gamma=gamma,
                        reward_scale_factor=reward_scale_factor,
                        gradient_clipping=gradient_clipping,
                        debug_summaries=debug_summaries,
                        summarize_grads_and_vars=summarize_grads_and_vars)
                elif agent_name == 'td3':
                    tf_agent = td3_agent.Td3Agent(
                        time_step_spec,
                        action_spec,
                        actor_network=actor_rnn_net,
                        critic_network=critic_rnn_net,
                        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=actor_learning_rate),
                        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=critic_learning_rate),
                        exploration_noise_std=exploration_noise_std,
                        target_update_tau=target_update_tau,
                        target_update_period=target_update_period,
                        actor_update_period=actor_update_period,
                        dqda_clipping=dqda_clipping,
                        td_errors_loss_fn=None,
                        gamma=gamma,
                        reward_scale_factor=reward_scale_factor,
                        gradient_clipping=gradient_clipping,
                        debug_summaries=debug_summaries,
                        summarize_grads_and_vars=summarize_grads_and_vars,
                        train_step_counter=global_step)

            elif agent_name == 'sac':
                actor_distribution_rnn_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                    observation_spec,
                    action_spec,
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    input_fc_layer_params=actor_fc_layers,
                    lstm_size=actor_lstm_size,
                    output_fc_layer_params=actor_output_fc_layers,
                    continuous_projection_net=normal_projection_net)

                critic_rnn_net = multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork(
                    (observation_spec, action_spec),
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    action_fc_layer_params=critic_action_fc_layers,
                    joint_fc_layer_params=critic_joint_fc_layers,
                    lstm_size=critic_lstm_size,
                    output_fc_layer_params=critic_output_fc_layers)

                tf_agent = sac_agent.SacAgent(
                    time_step_spec,
                    action_spec,
                    actor_network=actor_distribution_rnn_net,
                    critic_network=critic_rnn_net,
                    actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=actor_learning_rate),
                    critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=critic_learning_rate),
                    alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=alpha_learning_rate),
                    target_update_tau=target_update_tau,
                    target_update_period=target_update_period,
                    td_errors_loss_fn=tf.math.squared_difference,  # make critic loss dimension compatible
                    gamma=gamma,
                    reward_scale_factor=reward_scale_factor,
                    gradient_clipping=gradient_clipping,
                    debug_summaries=debug_summaries,
                    summarize_grads_and_vars=summarize_grads_and_vars,
                    train_step_counter=global_step)

            else:
                raise NotImplementedError

        tf_agent.initialize()

        # Get replay buffer
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,  # No parallel environments
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        # Train metrics
        env_steps = tf_metrics.EnvironmentSteps()
        average_return = tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size)
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            env_steps,
            average_return,
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ]

        # Get policies
        # eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        eval_policy = tf_agent.policy
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)
        collect_policy = tf_agent.collect_policy

        # Checkpointers
        train_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'train'),
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
            max_to_keep=2)
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step,
                                                  max_to_keep=2)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)
        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        # Collect driver
        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        # Optimize the performance by using tf functions
        initial_collect_driver.run = common.function(
            initial_collect_driver.run)
        collect_driver.run = common.function(collect_driver.run)
        tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if (env_steps.result() == 0 or replay_buffer.num_frames() == 0):
            logging.info(
                'Initializing replay buffer by collecting experience for %d steps '
                'with a random policy.', initial_collect_steps)
            initial_collect_driver.run()

        if agent_name == 'latent_sac':
            compute_summaries(eval_metrics,
                              eval_tf_env,
                              eval_policy,
                              train_step=global_step,
                              summary_writer=summary_writer,
                              num_episodes=1,
                              num_episodes_to_render=1,
                              model_net=model_net,
                              fps=10,
                              image_keys=input_names + mask_names)
        else:
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=1,
                train_step=env_steps.result(),
                summary_writer=summary_writer,
                summary_prefix='Eval',
            )
            metric_utils.log_metrics(eval_metrics)

        # Dataset generates trajectories with shape [B x (sequence_length + 1) x ...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=sequence_length +
                                           1).prefetch(3)
        iterator = iter(dataset)

        # Get train step
        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        train_step = common.function(train_step)

        if agent_name == 'latent_sac':

            def train_model_step():
                experience, _ = next(iterator)
                return tf_agent.train_model(experience)

            train_model_step = common.function(train_model_step)

        # Training initializations
        time_step = None
        time_acc = 0
        env_steps_before = env_steps.result().numpy()

        # Start training
        for iteration in range(num_iterations):
            start_time = time.time()

            if agent_name == 'latent_sac' and iteration < initial_model_train_steps:
                train_model_step()
            else:
                # Run collect
                time_step, _ = collect_driver.run(time_step=time_step)

                # Train an iteration
                for _ in range(train_steps_per_iteration):
                    train_step()

            time_acc += time.time() - start_time

            # Log training information
            if global_step.numpy() % log_interval == 0:
                logging.info('env steps = %d, average return = %f',
                             env_steps.result(), average_return.result())
                env_steps_per_sec = (env_steps.result().numpy() -
                                     env_steps_before) / time_acc
                logging.info('%.3f env steps/sec', env_steps_per_sec)
                tf.summary.scalar(name='env_steps_per_sec',
                                  data=env_steps_per_sec,
                                  step=env_steps.result())
                time_acc = 0
                env_steps_before = env_steps.result().numpy()

            # Get training metrics
            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=env_steps.result())

            # Evaluation
            if global_step.numpy() % eval_interval == 0:
                # Log evaluation metrics
                if agent_name == 'latent_sac':
                    compute_summaries(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        train_step=global_step,
                        summary_writer=summary_writer,
                        num_episodes=num_eval_episodes,
                        num_episodes_to_render=num_images_per_summary,
                        model_net=model_net,
                        fps=10,
                        image_keys=input_names + mask_names)
                else:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=env_steps.result(),
                        summary_writer=summary_writer,
                        summary_prefix='Eval',
                    )
                    metric_utils.log_metrics(eval_metrics)

            # Save checkpoints
            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
Example No. 15
def train_eval(
        root_dir,
        env_name='cartpole',
        task_name='balance',
        observations_allowlist='position',
        eval_env_name=None,
        num_iterations=1000000,
        # Params for networks.
        actor_fc_layers=(400, 300),
        actor_output_fc_layers=(100, ),
        actor_lstm_size=(40, ),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        critic_output_fc_layers=(100, ),
        critic_lstm_size=(40, ),
        num_parallel_environments=1,
        # Params for collect
        initial_collect_episodes=1,
        collect_episodes_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        critic_learning_rate=3e-4,
        train_sequence_length=20,
        actor_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=0.99,
        reward_scale_factor=0.1,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for RNN SAC on DM control."""
    root_dir = os.path.expanduser(root_dir)

    summary_writer = tf.compat.v2.summary.create_file_writer(
        root_dir, flush_millis=summaries_flush_secs * 1000)
    summary_writer.set_as_default()

    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if observations_allowlist is not None:
            env_wrappers = [
                functools.partial(
                    wrappers.FlattenObservationsWrapper,
                    observations_allowlist=[observations_allowlist])
            ]
        else:
            env_wrappers = []
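        # FlattenObservationsWrapper keeps only the allow-listed observation keys
        # and flattens them into a single observation vector.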

        env_load_fn = functools.partial(suite_dm_control.load,
                                        task_name=task_name,
                                        env_wrappers=env_wrappers)

        if num_parallel_environments == 1:
            py_env = env_load_fn(env_name)
        else:
            py_env = parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        eval_env_name = eval_env_name or env_name
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(eval_env_name))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

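        # TanhNormalProjectionNetwork squashes sampled actions with tanh so they
        # stay inside the bounded action spec.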
        actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
            observation_spec,
            action_spec,
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers,
            continuous_projection_net=tanh_normal_projection_network.TanhNormalProjectionNetwork)

        critic_net = critic_rnn_network.CriticRnnNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform')

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        env_steps = tf_metrics.EnvironmentSteps(prefix='Train')
        average_return = tf_metrics.AverageReturnMetric(
            prefix='Train',
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size)
        train_metrics = [
            tf_metrics.NumberOfEpisodes(prefix='Train'),
            env_steps,
            average_return,
            tf_metrics.AverageEpisodeLengthMetric(
                prefix='Train',
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
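        # Evaluation uses the greedy (mode) policy, while data collection keeps
        # the agent's stochastic collect policy.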
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'train'),
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_episodes=initial_collect_episodes)

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if env_steps.result() == 0 or replay_buffer.num_frames() == 0:
            logging.info(
                'Initializing replay buffer by collecting experience for %d episodes '
                'with a random policy.', initial_collect_episodes)
            initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=env_steps.result(),
            summary_writer=summary_writer,
            summary_prefix='Eval',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, env_steps.result())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        time_acc = 0
        env_steps_before = env_steps.result().numpy()

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            # Reduce filter_fn over full trajectory sampled. The sequence is kept only
            # if all elements except for the last one pass the filter. This is to
            # allow training on terminal steps.
            return tf.reduce_all(~trajectories.is_boundary()[:-1])

        dataset = replay_buffer.as_dataset(
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5)
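        # The unbatch/filter/batch pattern applies _filter_invalid_transition per
        # sampled sequence (rather than per batch), dropping sequences that cross
        # an episode boundary anywhere except at the final step.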
        # Dataset generates trajectories with shape [B x (train_sequence_length + 1) x ...]
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            start_env_steps = env_steps.result()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            episode_steps = env_steps.result() - start_env_steps
            # TODO(b/152648849)
            for _ in range(episode_steps):
                for _ in range(train_steps_per_iteration):
                    train_step()
                time_acc += time.time() - start_time

                if global_step.numpy() % log_interval == 0:
                    logging.info('env steps = %d, average return = %f',
                                 env_steps.result(), average_return.result())
                    env_steps_per_sec = (env_steps.result().numpy() -
                                         env_steps_before) / time_acc
                    logging.info('%.3f env steps/sec', env_steps_per_sec)
                    tf.compat.v2.summary.scalar(name='env_steps_per_sec',
                                                data=env_steps_per_sec,
                                                step=env_steps.result())
                    time_acc = 0
                    env_steps_before = env_steps.result().numpy()

                for train_metric in train_metrics:
                    train_metric.tf_summaries(train_step=env_steps.result())

                if global_step.numpy() % eval_interval == 0:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=env_steps.result(),
                        summary_writer=summary_writer,
                        summary_prefix='Eval',
                    )
                    if eval_metrics_callback is not None:
                        eval_metrics_callback(results, env_steps.result().numpy())
                    metric_utils.log_metrics(eval_metrics)

                global_step_val = global_step.numpy()
                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)
Example No. 16
def train_eval(
        root_dir,
        gpu='1',
        env_load_fn=None,
        model_ids=None,
        eval_env_mode='headless',
        conv_layer_params=None,
        encoder_fc_layers=[256],
        actor_fc_layers=[256, 256],
        value_fc_layers=[256, 256],
        use_rnns=False,
        # Params for collect
        num_environment_steps=10000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        eval_only=False,
        eval_deterministic=False,
        num_parallel_environments_eval=1,
        model_ids_eval=None,
        # Params for summaries and logging
        train_checkpoint_interval=500,
        policy_checkpoint_interval=500,
        rb_checkpoint_interval=500,
        log_interval=10,
        summary_interval=50,
        summaries_flush_secs=1,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for PPO."""
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        batched_py_metric.BatchedPyMetric(
            py_metrics.AverageReturnMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments_eval),
        batched_py_metric.BatchedPyMetric(
            py_metrics.AverageEpisodeLengthMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments_eval),
    ]
    eval_summary_writer_flush_op = eval_summary_writer.flush()
    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if model_ids is None:
            model_ids = [None] * num_parallel_environments
        else:
            assert len(model_ids) == num_parallel_environments,\
                'model ids provided, but length not equal to num_parallel_environments'

        if model_ids_eval is None:
            model_ids_eval = [None] * num_parallel_environments_eval
        else:
            assert len(model_ids_eval) == num_parallel_environments_eval,\
                'model ids eval provided, but length not equal to num_parallel_environments_eval'

        tf_py_env = [lambda model_id=model_ids[i]: env_load_fn(model_id, 'headless', gpu)
                     for i in range(num_parallel_environments)]
        tf_env = tf_py_environment.TFPyEnvironment(parallel_py_environment.ParallelPyEnvironment(tf_py_env))

        if eval_env_mode == 'gui':
            assert num_parallel_environments_eval == 1, 'only one GUI env is allowed'
        eval_py_env = [lambda model_id=model_ids_eval[i]: env_load_fn(model_id, eval_env_mode, gpu)
                       for i in range(num_parallel_environments_eval)]
        eval_py_env = parallel_py_environment.ParallelPyEnvironment(eval_py_env)

        optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

        time_step_spec = tf_env.time_step_spec()
        observation_spec = tf_env.observation_spec()
        action_spec = tf_env.action_spec()
        print('observation_spec', observation_spec)
        print('action_spec', action_spec)

        glorot_uniform_initializer = tf.compat.v1.keras.initializers.glorot_uniform()
        preprocessing_layers = {
            'depth_seg': tf.keras.Sequential(mlp_layers(
                conv_layer_params=conv_layer_params,
                fc_layer_params=encoder_fc_layers,
                kernel_initializer=glorot_uniform_initializer,
            )),
            'sensor': tf.keras.Sequential(mlp_layers(
                conv_layer_params=None,
                fc_layer_params=encoder_fc_layers,
                kernel_initializer=glorot_uniform_initializer,
            )),
        }
        preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1)
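        # The encoded 'depth_seg' and 'sensor' features are concatenated along
        # the last axis before being fed to the actor and value trunks.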

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                observation_spec,
                action_spec,
                preprocessing_layers=preprocessing_layers,
                preprocessing_combiner=preprocessing_combiner,
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None)
            value_net = value_rnn_network.ValueRnnNetwork(
                observation_spec,
                preprocessing_layers=preprocessing_layers,
                preprocessing_combiner=preprocessing_combiner,
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                observation_spec,
                action_spec,
                preprocessing_layers=preprocessing_layers,
                preprocessing_combiner=preprocessing_combiner,
                fc_layer_params=actor_fc_layers,
                kernel_initializer=glorot_uniform_initializer
            )
            value_net = value_network.ValueNetwork(
                observation_spec,
                preprocessing_layers=preprocessing_layers,
                preprocessing_combiner=preprocessing_combiner,
                fc_layer_params=value_fc_layers,
                kernel_initializer=glorot_uniform_initializer
            )

        tf_agent = ppo_agent.PPOAgent(
            time_step_spec,
            action_spec,
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.compat.v1.Session(config=config)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        if eval_deterministic:
            eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)
        else:
            eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.collect_policy)

        environment_steps_metric = tf_metrics.EnvironmentSteps()
        environment_steps_count = environment_steps_metric.result()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]
        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                buffer_size=100,
                batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=100,
                batch_size=num_parallel_environments),
        ]

        # Add to replay buffer and other agent specific observers.
        replay_buffer_observer = [replay_buffer.add_batch]

        collect_policy = tf_agent.collect_policy

        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=replay_buffer_observer + train_metrics,
            num_episodes=collect_episodes_per_iteration * num_parallel_environments).run()
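        # In this graph-mode setup, .run() only builds the collection ops; they
        # are executed later with sess.run(collect_op).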

        trajectories = replay_buffer.gather_all()

        train_op, _ = tf_agent.train(experience=trajectories)

        with tf.control_dependencies([train_op]):
            clear_replay_op = replay_buffer.clear()

        with tf.control_dependencies([clear_replay_op]):
            train_op = tf.identity(train_op)
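        # The control dependencies above make sure the replay buffer is cleared
        # only after training on the gathered trajectories has finished.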

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(train_metric.tf_summaries(
                train_step=global_step, step_metrics=step_metrics))

        with eval_summary_writer.as_default(), tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(
                    train_step=global_step, step_metrics=step_metrics)

        init_agent_op = tf_agent.initialize()

        with sess.as_default():
            # Initialize graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)

            if eval_only:
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    eval_py_policy,
                    num_episodes=num_eval_episodes,
                    global_step=0,
                    callback=eval_metrics_callback,
                    tf_summaries=False,
                    log=True,
                )
                episodes = eval_py_env.get_stored_episodes()
                episodes = [episode for sublist in episodes for episode in sublist][:num_eval_episodes]
                metrics = episode_utils.get_metrics(episodes)
                for key in sorted(metrics.keys()):
                    print(key, ':', metrics[key])

                save_path = os.path.join(eval_dir, 'episodes_eval.pkl')
                episode_utils.save(episodes, save_path)
                print('EVAL DONE')
                return

            common.initialize_uninitialized_variables(sess)
            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            collect_time = 0
            train_time = 0
            timed_at_step = sess.run(global_step)
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec', data=steps_per_second_ph,
                step=global_step)

            global_step_val = sess.run(global_step)
            while sess.run(environment_steps_count) < num_environment_steps:
                global_step_val = sess.run(global_step)
                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
                    with eval_summary_writer.as_default(), tf.compat.v2.summary.record_if(True):
                        with tf.name_scope('Metrics/'):
                            episodes = eval_py_env.get_stored_episodes()
                            episodes = [episode for sublist in episodes for episode in sublist][:num_eval_episodes]
                            metrics = episode_utils.get_metrics(episodes)
                            for key in sorted(metrics.keys()):
                                print(key, ':', metrics[key])
                                metric_op = tf.compat.v2.summary.scalar(name=key,
                                                                        data=metrics[key],
                                                                        step=global_step_val)
                                sess.run(metric_op)
                    sess.run(eval_summary_writer_flush_op)

                start_time = time.time()
                sess.run(collect_op)
                collect_time += time.time() - start_time
                start_time = time.time()
                total_loss, _ = sess.run([train_op, summary_ops])
                train_time += time.time() - start_time

                global_step_val = sess.run(global_step)
                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val, total_loss)
                    steps_per_sec = (
                            (global_step_val - timed_at_step) / (collect_time + train_time))
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(
                        steps_per_second_summary,
                        feed_dict={steps_per_second_ph: steps_per_sec})
                    logging.info('%s', 'collect_time = {}, train_time = {}'.format(
                        collect_time, train_time))
                    timed_at_step = global_step_val
                    collect_time = 0
                    train_time = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

            # One final eval before exiting.
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )
            sess.run(eval_summary_writer_flush_op)

        sess.close()
Example No. 17
def train_eval(
        root_dir,
        env_name='SocialBot-GroceryGround-v0',
        env_load_fn=suite_socialbot.load,
        random_seed=0,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(192, 64),
        value_fc_layers=(192, 64),
        use_rnns=False,
        # Params for collect
        num_environment_steps=10000000,
        collect_episodes_per_iteration=8,
        num_parallel_environments=8,
        replay_buffer_capacity=2001,  # Per-environment
        # Params for train
        num_epochs=16,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=500,
        # Params for summaries and logging
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        use_tf_functions=True,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """A simple train and eval for GroceryGround."""

    # Set summary writer and eval metrics
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Create envs and optimizer
        tf.compat.v1.set_random_seed(random_seed)
        eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments))
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)
        # Create actor and value network
        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(), fc_layer_params=value_fc_layers)

        # Create ppo agent
        tf_agent = ppo_agent.PPOAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Create metrics, replay_buffer and collect_driver
        environment_steps_metric = tf_metrics.EnvironmentSteps()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]
        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]
        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)
        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)
        if use_tf_functions:
            # TODO(b/123828980): Enable once the cause for slowdown was identified.
            collect_driver.run = common.function(collect_driver.run,
                                                 autograph=False)
            tf_agent.train = common.function(tf_agent.train, autograph=False)

        collect_time = 0
        train_time = 0
        timed_at_step = global_step.numpy()

        # Evaluate and train
        while environment_steps_metric.result() < num_environment_steps:
            global_step_val = global_step.numpy()
            if global_step_val % eval_interval == 0:
                metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )

            start_time = time.time()
            collect_driver.run()
            collect_time += time.time() - start_time

            start_time = time.time()
            trajectories = replay_buffer.gather_all()
            total_loss, _ = tf_agent.train(experience=trajectories)
            replay_buffer.clear()
            train_time += time.time() - start_time

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             total_loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = {}, train_time = {}'.format(
                    collect_time, train_time))
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)

                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0
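A minimal invocation sketch for the GroceryGround train_eval above; the output directory and the reduced step budget are illustrative assumptions, not values from the source.

# Sketch only: run the GroceryGround PPO train_eval with hypothetical settings.
if __name__ == '__main__':
    train_eval(
        root_dir='/tmp/grocery_ground_ppo',   # hypothetical output directory
        num_environment_steps=100000,         # shortened run for a smoke test (assumption)
        num_parallel_environments=2)          # fewer workers than the default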
Ejemplo n.º 18
0
def train(tf_env):
    # tf_env: a batched TFEnvironment (e.g. a tf_py_environment.TFPyEnvironment)
    # that supplies the specs and batch size used throughout this function.
    num_iterations=1000000
    # Params for networks.
    actor_fc_layers=(128, 64)
    actor_output_fc_layers=(64,)
    actor_lstm_size=(32,)
    critic_obs_fc_layers=None
    critic_action_fc_layers=None
    critic_joint_fc_layers=(128,)
    critic_output_fc_layers=(64,)
    critic_lstm_size=(32,)
    num_parallel_environments=1
    # Params for collect
    initial_collect_episodes=1
    collect_episodes_per_iteration=1
    replay_buffer_capacity=1000000
    # Params for target update
    target_update_tau=0.05
    target_update_period=5
    # Params for train
    train_steps_per_iteration=1
    batch_size=256
    critic_learning_rate=3e-4
    train_sequence_length=20
    actor_learning_rate=3e-4
    alpha_learning_rate=3e-4
    td_errors_loss_fn=tf.math.squared_difference
    gamma=0.99
    reward_scale_factor=0.1
    gradient_clipping=None
    use_tf_functions=True
    # Params for eval
    num_eval_episodes=30
    eval_interval=10000

    debug_summaries=False
    summarize_grads_and_vars=False

    global_step = tf.compat.v1.train.get_or_create_global_step()
    
    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()
    
    actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
        observation_spec,
        action_spec,
        input_fc_layer_params=actor_fc_layers,
        lstm_size=actor_lstm_size,
        output_fc_layer_params=actor_output_fc_layers,
        continuous_projection_net=tanh_normal_projection_network
        .TanhNormalProjectionNetwork)

    critic_net = critic_rnn_network.CriticRnnNetwork(
        (observation_spec, action_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        lstm_size=critic_lstm_size,
        output_fc_layer_params=critic_output_fc_layers,
        kernel_initializer='glorot_uniform',
        last_kernel_initializer='glorot_uniform')
    
    tf_agent = sac_agent.SacAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()
    
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]
    
    env_steps = tf_metrics.EnvironmentSteps(prefix='Train')

    
    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())

    collect_policy = tf_agent.collect_policy
    
    
    initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer,
        num_episodes=initial_collect_episodes)

    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        # Count environment steps during collection so episode_steps below is nonzero.
        observers=replay_observer + [env_steps],
        num_episodes=collect_episodes_per_iteration)

    if use_tf_functions:
        initial_collect_driver.run = common.function(initial_collect_driver.run)
        collect_driver.run = common.function(collect_driver.run)
        tf_agent.train = common.function(tf_agent.train)
        
    if env_steps.result() == 0 or replay_buffer.num_frames() == 0:
        logging.info(
          'Initializing replay buffer by collecting experience for %d episodes '
          'with a random policy.', initial_collect_episodes)
        initial_collect_driver.run()


    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)
    
    time_acc = 0
    env_steps_before = env_steps.result().numpy()

    # Prepare replay buffer as dataset with invalid transitions filtered.
    def _filter_invalid_transition(trajectories, unused_arg1):
        # Reduce filter_fn over full trajectory sampled. The sequence is kept only
        # if all elements except for the last one pass the filter. This is to
        # allow training on terminal steps.
        return tf.reduce_all(~trajectories.is_boundary()[:-1])
    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size,
        num_steps=train_sequence_length+1).unbatch().filter(
            _filter_invalid_transition).batch(batch_size).prefetch(5)
    # Dataset generates trajectories with shape [B x (train_sequence_length + 1) x ...]
    iterator = iter(dataset)

    def train_step():
        experience, _ = next(iterator)
        return tf_agent.train(experience)

    if use_tf_functions:
        train_step = common.function(train_step)

    for _ in range(num_iterations):
        start_env_steps = env_steps.result()
        time_step, policy_state = collect_driver.run(
            time_step=time_step,
            policy_state=policy_state,
        )
        episode_steps = env_steps.result() - start_env_steps
        # TODO(b/152648849)
        for _ in range(episode_steps):
            for _ in range(train_steps_per_iteration):
                train_step()
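A usage sketch for the SAC train() above, assuming tf_env is built from a continuous-control Gym task; the suite and environment name are assumptions, not from the source.

# Sketch only (assumes suite_gym is imported and a continuous-action
# environment such as 'Pendulum-v1' is available).
env = tf_py_environment.TFPyEnvironment(suite_gym.load('Pendulum-v1'))
train(env)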
Ejemplo n.º 19
0
summaries_flush_secs = 1
use_tf_functions = True
debug_summaries = False
summarize_grads_and_vars = False
initial_collect_steps = 100

importance_ratio_clipping = 0.2

# +
global_step = tf.compat.v1.train.get_or_create_global_step()
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

if use_rnns:
    actor_net = (actor_distribution_rnn_network.ActorDistributionRnnNetwork(
        tf_env.observation_spec(),
        tf_env.action_spec(),
        input_fc_layer_params=actor_fc_layers,
        output_fc_layer_params=None))
    value_net = (value_rnn_network.ValueRnnNetwork(
        tf_env.observation_spec(),
        input_fc_layer_params=value_fc_layers,
        output_fc_layer_params=None))
else:
    actor_net = (actor_distribution_network.ActorDistributionNetwork(
        tf_env.observation_spec(),
        tf_env.action_spec(),
        fc_layer_params=actor_fc_layers,
        activation_fn=tf.keras.activations.tanh))
    value_net = (value_network.ValueNetwork(
        tf_env.observation_spec(),
        fc_layer_params=value_fc_layers))
Ejemplo n.º 20
0
def create_ppo_agent(env, global_step, FLAGS):

    actor_fc_layers = (512, 256)
    value_fc_layers = (512, 256)

    lstm_fc_input = (1024, 512)
    lstm_size = (256, )
    lstm_fc_output = (256, 256)

    minimap_preprocessing = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(filters=16,
                               kernel_size=(5, 5),
                               strides=(2, 2),
                               activation='relu'),
        tf.keras.layers.Conv2D(filters=32,
                               kernel_size=(3, 3),
                               strides=(2, 2),
                               activation='relu'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(units=256, activation='relu')
    ])

    screen_preprocessing = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(filters=16,
                               kernel_size=(5, 5),
                               strides=(2, 2),
                               activation='relu'),
        tf.keras.layers.Conv2D(filters=32,
                               kernel_size=(3, 3),
                               strides=(2, 2),
                               activation='relu'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(units=256, activation='relu')
    ])

    info_preprocessing = tf.keras.models.Sequential([
        tf.keras.layers.Dense(units=128, activation='relu'),
        tf.keras.layers.Dense(units=128, activation='relu')
    ])

    entities_preprocessing = tf.keras.models.Sequential([
        tf.keras.layers.Conv1D(filters=4, kernel_size=4, activation='relu'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(units=256, activation='relu')
    ])

    actor_preprocessing_layers = {
        'minimap': minimap_preprocessing,
        'screen': screen_preprocessing,
        'info': info_preprocessing,
        'entities': entities_preprocessing,
    }
    actor_preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1)

    if FLAGS.use_lstms:
        actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
            env.observation_spec(),
            env.action_spec(),
            preprocessing_layers=actor_preprocessing_layers,
            preprocessing_combiner=actor_preprocessing_combiner,
            input_fc_layer_params=lstm_fc_input,
            output_fc_layer_params=lstm_fc_output,
            lstm_size=lstm_size)
    else:
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            input_tensor_spec=env.observation_spec(),
            output_tensor_spec=env.action_spec(),
            preprocessing_layers=actor_preprocessing_layers,
            preprocessing_combiner=actor_preprocessing_combiner,
            fc_layer_params=actor_fc_layers,
            activation_fn=tf.keras.activations.tanh)

    value_preprocessing_layers = {
        'minimap': minimap_preprocessing,
        'screen': screen_preprocessing,
        'info': info_preprocessing,
        'entities': entities_preprocessing,
    }
    value_preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1)

    if FLAGS.use_lstms:
        value_net = value_rnn_network.ValueRnnNetwork(
            env.observation_spec(),
            preprocessing_layers=value_preprocessing_layers,
            preprocessing_combiner=value_preprocessing_combiner,
            input_fc_layer_params=lstm_fc_input,
            output_fc_layer_params=lstm_fc_output,
            lstm_size=lstm_size)
    else:
        value_net = value_network.ValueNetwork(
            env.observation_spec(),
            preprocessing_layers=value_preprocessing_layers,
            preprocessing_combiner=value_preprocessing_combiner,
            fc_layer_params=value_fc_layers,
            activation_fn=tf.keras.activations.tanh)

    optimizer = tf.compat.v1.train.AdamOptimizer(
        learning_rate=FLAGS.learning_rate)

    # commented out values are the defaults
    tf_agent = my_ppo_agent.PPOAgent(
        time_step_spec=env.time_step_spec(),
        action_spec=env.action_spec(),
        optimizer=optimizer,
        actor_net=actor_net,
        value_net=value_net,
        importance_ratio_clipping=0.1,
        # lambda_value=0.95,
        discount_factor=0.95,
        entropy_regularization=0.003,
        # policy_l2_reg=0.0,
        # value_function_l2_reg=0.0,
        # shared_vars_l2_reg=0.0,
        # value_pred_loss_coef=0.5,
        num_epochs=FLAGS.num_epochs,
        use_gae=True,
        use_td_lambda_return=True,
        normalize_rewards=FLAGS.norm_rewards,
        reward_norm_clipping=0.0,
        normalize_observations=True,
        # log_prob_clipping=0.0,
        # KL from here...
        # To disable the fixed KL cutoff penalty, set the kl_cutoff_factor parameter to 0.0
        kl_cutoff_factor=0.0,
        kl_cutoff_coef=0.0,
        # To disable the adaptive KL penalty, set the initial_adaptive_kl_beta parameter to 0.0
        initial_adaptive_kl_beta=0.0,
        adaptive_kl_target=0.00,
        adaptive_kl_tolerance=0.0,  # ...to here.
        # gradient_clipping=None,
        value_clipping=0.5,
        # check_numerics=False,
        # compute_value_and_advantage_in_train=True,
        # update_normalizers_in_train=True,
        # debug_summaries=False,
        # summarize_grads_and_vars=False,
        train_step_counter=global_step,
        # name='PPOClipAgent'
    )

    tf_agent.initialize()
    return tf_agent
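A hypothetical call to create_ppo_agent; FLAGS can be any object exposing use_lstms, learning_rate, num_epochs and norm_rewards, and env is assumed to be a TFEnvironment whose observation spec is a dict with 'minimap', 'screen', 'info' and 'entities' entries matching the preprocessing layers above.

# Sketch only: a stand-in FLAGS namespace; attribute values are assumptions.
from types import SimpleNamespace
example_flags = SimpleNamespace(use_lstms=True, learning_rate=1e-4,
                                num_epochs=10, norm_rewards=True)
example_global_step = tf.compat.v1.train.get_or_create_global_step()
# `env` must provide the dict observation spec described above.
example_agent = create_ppo_agent(env, example_global_step, example_flags)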
Ejemplo n.º 21
0
def train_eval(
        root_dir,
        random_seed=0,
        num_epochs=1000000,
        # Params for train
        normalize_observations=True,
        normalize_rewards=True,
        discount_factor=1.0,
        lr=1e-5,
        lr_schedule=None,
        num_policy_updates=20,
        initial_adaptive_kl_beta=0.0,
        kl_cutoff_factor=0,
        importance_ratio_clipping=0.2,
        value_pred_loss_coef=0.5,
        gradient_clipping=None,
        entropy_regularization=0.0,
        log_prob_clipping=0.0,
        # Params for log, eval, save
        eval_interval=100,
        save_interval=1000,
        checkpoint_interval=None,
        summary_interval=100,
        do_evaluation=True,
        # Params for data collection
        train_batch_size=10,
        eval_batch_size=100,
        collect_driver=None,
        eval_driver=None,
        replay_buffer_capacity=20000,
        # Policy and value networks
        ActorNet=actor_distribution_network.ActorDistributionNetwork,
        zero_means_kernel_initializer=False,
        init_action_stddev=0.35,
        actor_fc_layers=(),
        value_fc_layers=(),
        use_rnn=True,
        actor_lstm_size=(12, ),
        value_lstm_size=(12, ),
        **kwargs):
    """ A simple train and eval for PPO agent. 
    
    Args:
        root_dir (str): directory for saving training and evalutaion data
        random_seed (int): seed for random number generator
        num_epochs (int): number of training epochs. At each epoch a batch
            of data is collected according to one stochastic policy, and then
            the policy is updated.
        normalize_observations (bool): flag for normalization of observations.
            Uses StreamingTensorNormalizer which normalizes based on the whole
            history of observations.
        normalize_rewards (bool): flag for normalization of rewards.
            Uses StreamingTensorNormalizer which normalizes based on the whole
            history of rewards.
        discount_factor (float): rewards discount factor, should be in (0,1]
        lr (float): learning rate for Adam optimizer
        lr_schedule (callable: int -> float, optional): function to schedule
            the learning rate annealing. Takes the int epoch number as argument
            and returns the float learning rate.
        num_policy_updates (int): number of policy gradient steps to do on each
            epoch of training. In PPO this is typically >1.
        initial_adaptive_kl_beta (float): see tf-agents PPO docs 
        kl_cutoff_factor (float): see tf-agents PPO docs 
        importance_ratio_clipping (float): clipping value for the importance
            ratio. Discourages updates that change the policy significantly.
            Should be in (0,1]
        value_pred_loss_coef (float): weight coefficient for quadratic value
            estimation loss.
        gradient_clipping (float): gradient clipping coefficient.
        entropy_regularization (float): entropy regularization loss coefficient.
        log_prob_clipping (float): +/- value for clipping log probs to prevent 
            inf / NaN values.  Default: no clipping.
        eval_interval (int): interval between evaluations, counted in epochs.
        save_interval (int): interval between savings, counted in epochs. It
            updates the log file and saves the deterministic policy.
        checkpoint_interval (int): interval between saving checkpoints, counted 
            in epochs. Overwrites the previous saved one. Defaults to None, 
            in which case checkpoints are not saved.
        summary_interval (int): interval between summary writing, counted in 
            epochs. tf-agents takes care of summary writing; results can be
            later displayed in tensorboard.
        do_evaluation (bool): flag to interleave training epochs with 
            evaluation epochs.
        train_batch_size (int): training batch size, collected in parallel.
        eval_batch_size (int): batch size for evaluation of the policy.
        collect_driver (Driver): driver for training data collection
        eval_driver (Driver): driver for evaluation data collection
        replay_buffer_capacity (int): How many transition tuples the buffer 
            can store. The buffer is emptied and re-populated at each epoch.
        ActorNet (network.DistributionNetwork): a distribution actor network 
            to use for training. The default is ActorDistributionNetwork from
            tf-agents, but this can also be customized.
        zero_means_kernel_initializer (bool): flag to initialize the means
            projection network with zeros. If this flag is not set, the default
            tf-agents random initializer is used.
        init_action_stddev (float): initial stddev of the normal action dist.
        actor_fc_layers (tuple): sizes of fully connected layers in actor net.
        value_fc_layers (tuple): sizes of fully connected layers in value net.
        use_rnn (bool): whether to use LSTM units in the neural net.
        actor_lstm_size (tuple): sizes of LSTM layers in actor net.
        value_lstm_size (tuple): sizes of LSTM layers in value net.
    """
    # --------------------------------------------------------------------
    # --------------------------------------------------------------------
    tf.compat.v1.set_random_seed(random_seed)

    # Setup directories within 'root_dir'
    if not os.path.isdir(root_dir): os.mkdir(root_dir)
    policy_dir = os.path.join(root_dir, 'policy')
    checkpoint_dir = os.path.join(root_dir, 'checkpoint')
    logfile = os.path.join(root_dir, 'log.hdf5')
    train_dir = os.path.join(root_dir, 'train_summaries')

    # Create tf summary writer
    train_summary_writer = tf.compat.v2.summary.create_file_writer(train_dir)
    train_summary_writer.set_as_default()
    summary_interval *= num_policy_updates
    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):

        # Define action and observation specs
        observation_spec = collect_driver.observation_spec()
        action_spec = collect_driver.action_spec()

        # Preprocessing: flatten and concatenate observation components
        preprocessing_layers = {
            obs: tf.keras.layers.Flatten()
            for obs in observation_spec.keys()
        }
        preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1)

        # Define actor network and value network
        if use_rnn:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                input_tensor_spec=observation_spec,
                output_tensor_spec=action_spec,
                preprocessing_layers=preprocessing_layers,
                preprocessing_combiner=preprocessing_combiner,
                input_fc_layer_params=None,
                lstm_size=actor_lstm_size,
                output_fc_layer_params=actor_fc_layers)

            value_net = value_rnn_network.ValueRnnNetwork(
                input_tensor_spec=observation_spec,
                preprocessing_layers=preprocessing_layers,
                preprocessing_combiner=preprocessing_combiner,
                input_fc_layer_params=None,
                lstm_size=value_lstm_size,
                output_fc_layer_params=value_fc_layers)
        else:
            npn = actor_distribution_network._normal_projection_net
            normal_projection_net = lambda specs: npn(
                specs,
                zero_means_kernel_initializer=zero_means_kernel_initializer,
                init_action_stddev=init_action_stddev)

            actor_net = ActorNet(
                input_tensor_spec=observation_spec,
                output_tensor_spec=action_spec,
                preprocessing_layers=preprocessing_layers,
                preprocessing_combiner=preprocessing_combiner,
                fc_layer_params=actor_fc_layers,
                continuous_projection_net=normal_projection_net)

            value_net = value_network.ValueNetwork(
                input_tensor_spec=observation_spec,
                preprocessing_layers=preprocessing_layers,
                preprocessing_combiner=preprocessing_combiner,
                fc_layer_params=value_fc_layers)

        # Create PPO agent
        optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=lr)
        tf_agent = ppo_agent.PPOAgent(
            time_step_spec=collect_driver.time_step_spec(),
            action_spec=action_spec,
            optimizer=optimizer,
            actor_net=actor_net,
            value_net=value_net,
            num_epochs=num_policy_updates,
            train_step_counter=global_step,
            discount_factor=discount_factor,
            normalize_observations=normalize_observations,
            normalize_rewards=normalize_rewards,
            initial_adaptive_kl_beta=initial_adaptive_kl_beta,
            kl_cutoff_factor=kl_cutoff_factor,
            importance_ratio_clipping=importance_ratio_clipping,
            gradient_clipping=gradient_clipping,
            value_pred_loss_coef=value_pred_loss_coef,
            entropy_regularization=entropy_regularization,
            log_prob_clipping=log_prob_clipping,
            debug_summaries=True)

        tf_agent.initialize()
        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        # Create replay buffer and collection driver
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=train_batch_size,
            max_length=replay_buffer_capacity)

        def train_step():
            experience = replay_buffer.gather_all()
            return tf_agent.train(experience)

        tf_agent.train = common.function(tf_agent.train)

        avg_return_metric = tf_metrics.AverageReturnMetric(
            batch_size=eval_batch_size, buffer_size=eval_batch_size)

        collect_driver.setup(collect_policy, [replay_buffer.add_batch])
        eval_driver.setup(eval_policy, [avg_return_metric])

        # Create a checkpointer and load the saved agent
        train_checkpointer = common.Checkpointer(ckpt_dir=checkpoint_dir,
                                                 max_to_keep=1,
                                                 agent=tf_agent,
                                                 policy=tf_agent.policy,
                                                 replay_buffer=replay_buffer,
                                                 global_step=global_step)

        train_checkpointer.initialize_or_restore()
        global_step = tf.compat.v1.train.get_global_step()

        # Saver for the deterministic policy
        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=global_step)

        # Evaluate policy once before training
        if do_evaluation:
            eval_driver.run(0)
            avg_return = avg_return_metric.result().numpy()
            avg_return_metric.reset()
            log = {
                'returns': [avg_return],
                'epochs': [0],
                'policy_steps': [0],
                'experience_time': [0.0],
                'train_time': [0.0]
            }
            print('-------------------')
            print('Epoch 0')
            print('  Policy steps: 0')
            print('  Experience time: 0.00 mins')
            print('  Policy train time: 0.00 mins')
            print('  Average return: %.5f' % avg_return)

        # Save initial random policy
        path = os.path.join(policy_dir, ('0').zfill(6))
        saved_model.save(path)

        # Training loop
        train_timer = timer.Timer()
        experience_timer = timer.Timer()
        for epoch in range(1, num_epochs + 1):
            # Collect new experience
            experience_timer.start()
            collect_driver.run(epoch)
            experience_timer.stop()
            # Update the policy
            train_timer.start()
            if lr_schedule: optimizer._lr = lr_schedule(epoch)
            train_loss = train_step()
            replay_buffer.clear()
            train_timer.stop()

            if (epoch % eval_interval == 0) and do_evaluation:
                # Evaluate the policy
                eval_driver.run(epoch)
                avg_return = avg_return_metric.result().numpy()
                avg_return_metric.reset()

                # Print out and log all metrics
                print('-------------------')
                print('Epoch %d' % epoch)
                print('  Policy steps: %d' % (epoch * num_policy_updates))
                print('  Experience time: %.2f mins' %
                      (experience_timer.value() / 60))
                print('  Policy train time: %.2f mins' %
                      (train_timer.value() / 60))
                print('  Average return: %.5f' % avg_return)
                log['epochs'].append(epoch)
                log['policy_steps'].append(epoch * num_policy_updates)
                log['returns'].append(avg_return)
                log['experience_time'].append(experience_timer.value())
                log['train_time'].append(train_timer.value())
                # Save updated log
                save_log(log, logfile, ('%d' % epoch).zfill(6))

            if epoch % save_interval == 0:
                # Save deterministic policy
                path = os.path.join(policy_dir, ('%d' % epoch).zfill(6))
                saved_model.save(path)

            if checkpoint_interval is not None and \
                epoch % checkpoint_interval == 0:
                # Save training checkpoint
                train_checkpointer.save(global_step)
        collect_driver.finish_training()
        eval_driver.finish_training()
Ejemplo n.º 22
0
  def testStatelessValueNetTrain(self, compute_value_and_advantage_in_train):
    counter = common.create_variable('test_train_counter')
    actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
        self._time_step_spec.observation,
        self._action_spec,
        input_fc_layer_params=None,
        output_fc_layer_params=None,
        lstm_size=(20,))
    value_net = value_network.ValueNetwork(
        self._time_step_spec.observation, fc_layer_params=None)
    agent = ppo_agent.PPOAgent(
        self._time_step_spec,
        self._action_spec,
        optimizer=tf.compat.v1.train.AdamOptimizer(),
        actor_net=actor_net,
        value_net=value_net,
        num_epochs=1,
        train_step_counter=counter,
        compute_value_and_advantage_in_train=compute_value_and_advantage_in_train
    )
    observations = tf.constant([
        [[1, 2], [3, 4], [5, 6]],
        [[1, 2], [3, 4], [5, 6]],
    ],
                               dtype=tf.float32)

    mid_time_step_val = ts.StepType.MID.tolist()
    time_steps = ts.TimeStep(
        step_type=tf.constant([[mid_time_step_val] * 3] * 2, dtype=tf.int32),
        reward=tf.constant([[1] * 3] * 2, dtype=tf.float32),
        discount=tf.constant([[1] * 3] * 2, dtype=tf.float32),
        observation=observations)
    actions = tf.constant([[[0], [1], [1]], [[0], [1], [1]]], dtype=tf.float32)

    action_distribution_parameters = {
        'loc': tf.constant([[[0.0]] * 3] * 2, dtype=tf.float32),
        'scale': tf.constant([[[1.0]] * 3] * 2, dtype=tf.float32),
    }
    value_preds = tf.constant([[9., 15., 21.], [9., 15., 21.]],
                              dtype=tf.float32)

    policy_info = {
        'dist_params': action_distribution_parameters,
    }
    if not compute_value_and_advantage_in_train:
      policy_info['value_prediction'] = value_preds
    experience = trajectory.Trajectory(time_steps.step_type, observations,
                                       actions, policy_info,
                                       time_steps.step_type, time_steps.reward,
                                       time_steps.discount)
    if not compute_value_and_advantage_in_train:
      experience = agent._preprocess(experience)

    if tf.executing_eagerly():
      loss = lambda: agent.train(experience)
    else:
      loss = agent.train(experience)

    self.evaluate(tf.compat.v1.initialize_all_variables())

    loss_type = self.evaluate(loss)
    loss_numpy = loss_type.loss
    # Assert that loss is not zero as we are training in a non-episodic env.
    self.assertNotEqual(
        loss_numpy,
        0.0,
        msg=('Loss is exactly zero, looks like no training '
             'was performed due to incomplete episodes.'))
Ejemplo n.º 23
0
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        env_load_fn=suite_mujoco.load,
        random_seed=0,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(512, 256, 256, 30),
        value_fc_layers=(512, 256, 256, 25),
        use_rnns=False,
        # Params for collect
        num_environment_steps=10000000,
        collect_episodes_per_iteration=NumEpisodes,
        num_parallel_environments=1,
        replay_buffer_capacity=10000,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=5e-4,
        # Params for eval
        num_eval_episodes=5,
        eval_interval=500,
        # Params for summaries and logging
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        use_tf_functions=True,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """A simple train and eval for PPO."""
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train6')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf.compat.v1.set_random_seed(random_seed)
        #eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        #tf_env = tf_py_environment.TFPyEnvironment(
        #    parallel_py_environment.ParallelPyEnvironment(
        #        [lambda: env_load_fn(env_name)] * num_parallel_environments))
        env = xSpace()
        if isinstance(env, py_environment.PyEnvironment):
            eval_tf_env = tf_py_environment.TFPyEnvironment(env)
            tf_env = tf_py_environment.TFPyEnvironment(env)
            print("Py Env")
        elif isinstance(env, tf_environment.TFEnvironment):
            eval_tf_env = env
            tf_env = env
            print("TF Env")
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(), fc_layer_params=value_fc_layers)

        tf_agent = ppo_agent.PPOAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            lambda_value=0.98,
            discount_factor=0.995,
            #value_pred_loss_coef=0.005,
            use_gae=True,
            actor_net=actor_net,
            value_net=value_net,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step,
            normalize_observations=False)
        tf_agent.initialize()
        print("************ INITIALIZING **********************")
        environment_steps_metric = tf_metrics.EnvironmentSteps()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]

        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy
        # this is for tensorboard
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        if use_tf_functions:
            # TODO(b/123828980): Enable once the cause for slowdown was identified.
            collect_driver.run = common.function(collect_driver.run,
                                                 autograph=False)
            tf_agent.train = common.function(tf_agent.train, autograph=False)

        collect_time = 0
        train_time = 0
        timed_at_step = global_step.numpy()

        while environment_steps_metric.result() < num_environment_steps:
            global_step_val = global_step.numpy()
            eval_tf_env.reset()
            if global_step_val % eval_interval == 0:
                #tf_env.ResetMattData()
                metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
            #print("eager compute completed")
            eval_tf_env.reset()
            start_time = time.time()
            collect_driver.run()
            #print("collect completed")
            collect_time += time.time() - start_time
            print("collect_time:" + str(collect_time))
            start_time = time.time()
            trajectories = replay_buffer.gather_all()
            #print("start train completed")
            #pdb.set_trace()
            #k=trajectories[5]
            #xMean=tf.reduce_mean(k)

            print('training...')
            total_loss, _ = tf_agent.train(experience=trajectories)
            print('training complete. total loss:' + str(total_loss))
            #print("end train completed")
            replay_buffer.clear()
            train_time += time.time() - start_time

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             total_loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = {}, train_time = {}'.format(
                    collect_time, train_time))
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)

                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

        # One final eval before exiting.
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
Ejemplo n.º 24
0
def construct_attention_networks(observation_spec,
                                 action_spec,
                                 use_rnns=True,
                                 actor_fc_layers=(200, 100),
                                 value_fc_layers=(200, 100),
                                 lstm_size=(128,),
                                 conv_filters=8,
                                 conv_kernel=3,
                                 scalar_fc=5,
                                 scalar_name='direction',
                                 scalar_dim=4):
  """Creates an actor and critic network designed for use with MultiGrid.

  A convolution layer processes the image and a dense layer processes the
  direction the agent is facing. These are fed into some fully connected layers
  and an LSTM.

  Args:
    observation_spec: A tf-agents observation spec.
    action_spec: A tf-agents action spec.
    use_rnns: If True, will construct RNN networks.
    actor_fc_layers: Dimension and number of fully connected layers in actor.
    value_fc_layers: Dimension and number of fully connected layers in critic.
    lstm_size: Number of cells in each LSTM layer.
    conv_filters: Number of convolution filters.
    conv_kernel: Size of the convolution kernel.
    scalar_fc: Number of neurons in the fully connected layer processing the
      scalar input.
    scalar_name: Name of the scalar input.
    scalar_dim: Highest possible value for the scalar input. Used to convert to
      one-hot representation.

  Returns:
    A tf-agents ActorDistributionRnnNetwork for the actor, and a ValueRnnNetwork
    for the critic.
  """
  preprocessing_layers = {
      'image':
          tf.keras.models.Sequential([
              cast_and_scale(),
              tf.keras.layers.Conv2D(conv_filters, conv_kernel, padding='same'),
              tf.keras.layers.ReLU(),
          ]),
      'policy_state':
          tf.keras.layers.Lambda(lambda x: x)
  }
  if scalar_name in observation_spec:
    preprocessing_layers[scalar_name] = tf.keras.models.Sequential(
        [one_hot_layer(scalar_dim),
         tf.keras.layers.Dense(scalar_fc)])
  if 'position' in observation_spec:
    preprocessing_layers['position'] = tf.keras.models.Sequential(
        [cast_and_scale(), tf.keras.layers.Dense(scalar_fc)])

  preprocessing_nest = tf.nest.map_structure(lambda l: None,
                                             preprocessing_layers)
  flat_observation_spec = nest_utils.flatten_up_to(
      preprocessing_nest,
      observation_spec,
  )
  image_index_flat = flat_observation_spec.index(observation_spec['image'])
  network_state_index_flat = flat_observation_spec.index(
      observation_spec['policy_state'])
  image_shape = observation_spec['image'].shape  # N x H x W x D
  preprocessing_combiner = AttentionCombinerConv(image_index_flat,
                                                 network_state_index_flat,
                                                 image_shape)

  if use_rnns:
    actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
        observation_spec,
        action_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner,
        input_fc_layer_params=actor_fc_layers,
        output_fc_layer_params=None,
        lstm_size=lstm_size)
    value_net = value_rnn_network.ValueRnnNetwork(
        observation_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner,
        input_fc_layer_params=value_fc_layers,
        output_fc_layer_params=None)
  else:
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner,
        fc_layer_params=actor_fc_layers,
        activation_fn=tf.keras.activations.tanh)
    value_net = value_network.ValueNetwork(
        observation_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner,
        fc_layer_params=value_fc_layers,
        activation_fn=tf.keras.activations.tanh)

  return actor_net, value_net
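A hedged usage sketch for construct_attention_networks; the spec names, shapes and bounds below are assumptions for illustration, and the helpers cast_and_scale, one_hot_layer and AttentionCombinerConv are assumed to be defined alongside this function.

# Sketch only; shapes and bounds are illustrative assumptions.
example_observation_spec = {
    'image': tensor_spec.BoundedTensorSpec((5, 5, 3), tf.uint8, 0, 255),
    'direction': tensor_spec.BoundedTensorSpec((), tf.int32, 0, 3),
    'policy_state': tensor_spec.TensorSpec((128,), tf.float32),
}
example_action_spec = tensor_spec.BoundedTensorSpec((), tf.int32, 0, 6)
attention_actor_net, attention_value_net = construct_attention_networks(
    example_observation_spec, example_action_spec)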
Ejemplo n.º 25
0
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        env_load_fn=suite_mujoco.load,
        random_seed=None,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        lstm_size=(20, ),
        # Params for collect
        num_environment_steps=25000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-3,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=500,
        policy_checkpoint_interval=500,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        use_tf_functions=True,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """A simple train and eval for PPO."""
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if random_seed is not None:
            tf.compat.v1.set_random_seed(random_seed)
        eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments))
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None,
                lstm_size=lstm_size)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers,
                activation_fn=tf.keras.activations.tanh)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(),
                fc_layer_params=value_fc_layers,
                activation_fn=tf.keras.activations.tanh)

        tf_agent = ppo_clip_agent.PPOClipAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            entropy_regularization=0.0,
            importance_ratio_clipping=0.2,
            normalize_observations=False,
            normalize_rewards=False,
            use_gae=True,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        environment_steps_metric = tf_metrics.EnvironmentSteps()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]

        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_environments),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=global_step)

        train_checkpointer.initialize_or_restore()

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        def train_step():
            trajectories = replay_buffer.gather_all()
            return tf_agent.train(experience=trajectories)

        if use_tf_functions:
            # TODO(b/123828980): Enable once the cause for slowdown was identified.
            collect_driver.run = common.function(collect_driver.run,
                                                 autograph=False)
            tf_agent.train = common.function(tf_agent.train, autograph=False)
            train_step = common.function(train_step)

        collect_time = 0
        train_time = 0
        timed_at_step = global_step.numpy()

        while environment_steps_metric.result() < num_environment_steps:
            global_step_val = global_step.numpy()
            if global_step_val % eval_interval == 0:
                metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )

            start_time = time.time()
            collect_driver.run()
            collect_time += time.time() - start_time

            start_time = time.time()
            total_loss, _ = train_step()
            replay_buffer.clear()
            train_time += time.time() - start_time

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             total_loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = %.3f, train_time = %.3f',
                             collect_time, train_time)
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)
                    saved_model_path = os.path.join(
                        saved_model_dir,
                        'policy_' + ('%d' % global_step_val).zfill(9))
                    saved_model.save(saved_model_path)

                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

        # One final eval before exiting.
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
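The `common.function` wrapping in the example above compiles the eager collect and train calls into tf.functions. A minimal standalone sketch of the same pattern follows; the toy `add_one` function is an assumption used for illustration and is not part of the example.

import tensorflow as tf
from tf_agents.utils import common


def add_one(x):
    # Stand-in for an eager collect or train step.
    return x + 1


# Mirrors how collect_driver.run and tf_agent.train are wrapped above;
# common.function is TF-Agents' thin wrapper around tf.function.
add_one_fn = common.function(add_one, autograph=False)
print(add_one_fn(tf.constant(1)))  # tf.Tensor(2, shape=(), dtype=int32)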
Ejemplo n.º 26
0
def construct_multigrid_networks(observation_spec,
                                 action_spec,
                                 use_rnns=True,
                                 actor_fc_layers=(200, 100),
                                 value_fc_layers=(200, 100),
                                 lstm_size=(128,),
                                 conv_filters=8,
                                 conv_kernel=3,
                                 scalar_fc=5,
                                 scalar_name='direction',
                                 scalar_dim=4,
                                 random_z=False,
                                 xy_dim=None):
  """Creates an actor and critic network designed for use with MultiGrid.

  A convolution layer processes the image and a dense layer processes the
  direction the agent is facing. These are fed into some fully connected layers
  and an LSTM.

  Args:
    observation_spec: A tf-agents observation spec.
    action_spec: A tf-agents action spec.
    use_rnns: If True, will construct RNN networks.
    actor_fc_layers: Dimension and number of fully connected layers in actor.
    value_fc_layers: Dimension and number of fully connected layers in critic.
    lstm_size: Number of cells in each LSTM layer.
    conv_filters: Number of convolution filters.
    conv_kernel: Size of the convolution kernel.
    scalar_fc: Number of neurons in the fully connected layer processing the
      scalar input.
    scalar_name: Name of the scalar input.
    scalar_dim: Highest possible value for the scalar input. Used to convert to
      one-hot representation.
    random_z: If True, will provide an additional layer to process a randomly
      generated float input vector.
    xy_dim: If not None, will provide two additional layers to process 'x' and
      'y' inputs. The dimension provided is the maximum value of x and y, and
      is used to create one-hot representation.

  Returns:
    An actor network and a value network. When use_rnns is True these are an
    ActorDistributionRnnNetwork and a ValueRnnNetwork; otherwise a feed-forward
    ActorDistributionNetwork and ValueNetwork are returned.
  """
  preprocessing_layers = {
      'image':
          tf.keras.models.Sequential([
              cast_and_scale(),
              tf.keras.layers.Conv2D(conv_filters, conv_kernel),
              tf.keras.layers.ReLU(),
              tf.keras.layers.Flatten()
          ]),
  }
  if scalar_name in observation_spec:
    preprocessing_layers[scalar_name] = tf.keras.models.Sequential(
        [one_hot_layer(scalar_dim),
         tf.keras.layers.Dense(scalar_fc)])
  if 'position' in observation_spec:
    preprocessing_layers['position'] = tf.keras.models.Sequential(
        [cast_and_scale(), tf.keras.layers.Dense(scalar_fc)])

  if random_z:
    preprocessing_layers['random_z'] = tf.keras.models.Sequential(
        [tf.keras.layers.Lambda(lambda x: x)])  # Identity layer
  if xy_dim is not None:
    preprocessing_layers['x'] = tf.keras.models.Sequential(
        [one_hot_layer(xy_dim)])
    preprocessing_layers['y'] = tf.keras.models.Sequential(
        [one_hot_layer(xy_dim)])

  preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1)

  if use_rnns:
    actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
        observation_spec,
        action_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner,
        input_fc_layer_params=actor_fc_layers,
        output_fc_layer_params=None,
        lstm_size=lstm_size)
    value_net = value_rnn_network.ValueRnnNetwork(
        observation_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner,
        input_fc_layer_params=value_fc_layers,
        output_fc_layer_params=None)
  else:
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner,
        fc_layer_params=actor_fc_layers,
        activation_fn=tf.keras.activations.tanh)
    value_net = value_network.ValueNetwork(
        observation_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner,
        fc_layer_params=value_fc_layers,
        activation_fn=tf.keras.activations.tanh)

  return actor_net, value_net
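A minimal usage sketch for the constructor above; the MultiGrid-style dict keys ('image', 'direction') and the spec shapes are assumptions chosen to line up with the preprocessing layers, not values taken from the original code.

import tensorflow as tf
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

# Assumed observation: a small RGB image plus a 'direction' scalar matching
# the scalar_name/scalar_dim defaults of construct_multigrid_networks.
observation_spec = {
    'image': tensor_spec.BoundedTensorSpec((5, 5, 3), tf.uint8, 0, 255),
    'direction': tensor_spec.BoundedTensorSpec((), tf.int32, 0, 3),
}
action_spec = tensor_spec.BoundedTensorSpec((), tf.int32, 0, 6)

actor_net, value_net = construct_multigrid_networks(
    observation_spec, action_spec, use_rnns=True)

# Evaluate the actor on a sampled batch-1 time step.
time_step = tensor_spec.sample_spec_nest(
    ts.time_step_spec(observation_spec), outer_dims=(1,))
action_distribution, network_state = actor_net(
    time_step.observation, time_step.step_type,
    actor_net.get_initial_state(batch_size=1))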
Ejemplo n.º 27
0
def train_eval(
        root_dir,
        tf_master='',
        env_name='HalfCheetah-v2',
        env_load_fn=suite_mujoco.load,
        random_seed=0,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        # Params for collect
        num_environment_steps=10000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=100,
        policy_checkpoint_interval=50,
        rb_checkpoint_interval=200,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for PPO."""
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        batched_py_metric.BatchedPyMetric(
            AverageReturnMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments),
        batched_py_metric.BatchedPyMetric(
            AverageEpisodeLengthMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments),
    ]
    eval_summary_writer_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf.compat.v1.set_random_seed(random_seed)
        eval_py_env = parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments))
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(), fc_layer_params=value_fc_layers)

        tf_agent = ppo_agent.PPOAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        environment_steps_metric = tf_metrics.EnvironmentSteps()
        environment_steps_count = environment_steps_metric.result()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]
        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        # Add to replay buffer and other agent specific observers.
        replay_buffer_observer = [replay_buffer.add_batch]

        collect_policy = tf_agent.collect_policy

        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=replay_buffer_observer + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

        trajectories = replay_buffer.gather_all()

        train_op, _ = tf_agent.train(experience=trajectories)

        with tf.control_dependencies([train_op]):
            clear_replay_op = replay_buffer.clear()

        with tf.control_dependencies([clear_replay_op]):
            train_op = tf.identity(train_op)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=tf_agent.policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(train_step=global_step,
                                      step_metrics=step_metrics)

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(step_metrics=step_metrics)

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session(tf_master) as sess:
            # Initialize graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            collect_time = 0
            train_time = 0
            timed_at_step = sess.run(global_step)
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.contrib.summary.scalar(
                name='global_steps/sec', tensor=steps_per_second_ph)

            while sess.run(environment_steps_count) < num_environment_steps:
                global_step_val = sess.run(global_step)
                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
                    sess.run(eval_summary_writer_flush_op)

                start_time = time.time()
                sess.run(collect_op)
                collect_time += time.time() - start_time
                start_time = time.time()
                total_loss = sess.run(train_op)
                train_time += time.time() - start_time

                global_step_val = sess.run(global_step)
                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss)
                    steps_per_sec = ((global_step_val - timed_at_step) /
                                     (collect_time + train_time))
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    logging.info(
                        '%s', 'collect_time = {}, train_time = {}'.format(
                            collect_time, train_time))
                    timed_at_step = global_step_val
                    collect_time = 0
                    train_time = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

            # One final eval before exiting.
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )
            sess.run(eval_summary_writer_flush_op)
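The graph-mode example above uses tf.control_dependencies to guarantee that the replay buffer is cleared only after the train op has run. Below is a minimal standalone sketch of that ordering trick; the toy counter variable stands in for the agent and replay buffer and is an assumption, not part of the example.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

counter = tf.Variable(0, name='counter')
train_op = counter.assign_add(1)        # stand-in for tf_agent.train(...)

with tf.control_dependencies([train_op]):
    clear_op = counter.assign(0)        # stand-in for replay_buffer.clear()

with tf.control_dependencies([clear_op]):
    read_op = tf.identity(counter)      # executes only after train + clear

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(read_op))            # 0: the increment ran, then the reset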
Ejemplo n.º 28
0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        env_load_fn=suite_gym.load,
        random_seed=None,
        max_ep_steps=1000,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        # Params for collect
        num_environment_steps=5000000,
        collect_episodes_per_iteration=1,
        num_parallel_environments=1,
        replay_buffer_capacity=10000,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-3,
        # Params for eval
        num_eval_episodes=10,
        num_random_episodes=1,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=500,
        policy_checkpoint_interval=500,
        rb_checkpoint_interval=20000,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=10,
        use_tf_functions=True,
        debug_summaries=False,
        eval_metrics_callback=None,
        random_metrics_callback=None,
        summarize_grads_and_vars=False):

    # Set up the directories that will contain the log data and model saves.
    # If data already exists in these folders, we will try to load it later.
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    random_dir = os.path.join(root_dir, 'random')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    # Create summary writers for logging and specify the metrics to log with
    # each writer.
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)

    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    random_summary_writer = tf.compat.v2.summary.create_file_writer(
        random_dir, flush_millis=summaries_flush_secs * 1000)

    random_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()

    # Set up the agent and train, recording data every summary_interval steps.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):

        if random_seed is not None:
            tf.compat.v1.set_random_seed(random_seed)

        # Load the environments. Here we use the same environment for
        # evaluation and training, but they could be different.
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(env_name, max_episode_steps=max_ep_steps))
        # tf_env = tf_py_environment.TFPyEnvironment(
        #         parallel_py_environment.ParallelPyEnvironment(
        #                 [lambda: env_load_fn(env_name, max_episode_steps=max_ep_steps)] * num_parallel_environments))

        tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name, max_episode_steps=max_ep_steps))

        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None)

            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)

        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers,
                activation_fn=tf.keras.activations.tanh)

            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(),
                fc_layer_params=value_fc_layers,
                activation_fn=tf.keras.activations.tanh)

        tf_agent = ppo_agent.PPOAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            entropy_regularization=0.0,
            importance_ratio_clipping=0.2,
            normalize_observations=False,
            normalize_rewards=False,
            use_gae=True,
            kl_cutoff_factor=0.0,
            initial_adaptive_kl_beta=0.0,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        tf_agent.initialize()

        environment_steps_metric = tf_metrics.EnvironmentSteps()

        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]

        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_environments),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))

        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)

        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=global_step)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        def train_step():
            trajectories = replay_buffer.gather_all()
            return tf_agent.train(experience=trajectories)

        if use_tf_functions:
            # TODO(b/123828980): Enable once the cause for slowdown was identified.
            collect_driver.run = common.function(collect_driver.run,
                                                 autograph=False)
            tf_agent.train = common.function(tf_agent.train, autograph=False)
            train_step = common.function(train_step)

        random_policy = random_tf_policy.RandomTFPolicy(
            eval_tf_env.time_step_spec(), eval_tf_env.action_spec())

        collect_time = 0
        train_time = 0
        timed_at_step = global_step.numpy()

        while environment_steps_metric.result() < num_environment_steps:
            global_step_val = global_step.numpy()

            if global_step_val % eval_interval == 0:
                metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )

                metric_utils.eager_compute(
                    random_metrics,
                    eval_tf_env,
                    random_policy,
                    num_episodes=num_random_episodes,
                    train_step=global_step,
                    summary_writer=random_summary_writer,
                    summary_prefix='Metrics',
                )

            start_time = time.time()
            collect_driver.run()
            collect_time += time.time() - start_time

            start_time = time.time()
            total_loss, _ = train_step()
            replay_buffer.clear()
            train_time += time.time() - start_time

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)

            if global_step_val % log_interval == 0:
                logging.info('Step: {:>6d}\tLoss: {:>+20.4f}'.format(
                    global_step_val, total_loss))

                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))

                logging.info('{:6.3f} steps/sec'.format(steps_per_sec))
                logging.info(
                    'collect_time = {:.3f}, train_time = {:.3f}'.format(
                        collect_time, train_time))

                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)
                saved_model_path = os.path.join(
                    saved_model_dir,
                    'policy_' + ('%d' % global_step_val).zfill(9))
                saved_model.save(saved_model_path)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step.numpy())

            timed_at_step = global_step_val
            collect_time = 0
            train_time = 0

        # One final eval before exiting.
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
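As a follow-up sketch (not part of the original example), a policy directory written by policy_saver.PolicySaver in the loop above can be reloaded as a SavedModel and rolled out against the eval environment. Here saved_model_path is assumed to point at one of the saved policy directories, and eval_tf_env is the evaluation environment created above.

import tensorflow as tf

# Sketch only: reload an exported policy and run one evaluation episode.
saved_policy = tf.saved_model.load(saved_model_path)
policy_state = saved_policy.get_initial_state(batch_size=1)

time_step = eval_tf_env.reset()
while not time_step.is_last():
    policy_step = saved_policy.action(time_step, policy_state)
    policy_state = policy_step.state
    time_step = eval_tf_env.step(policy_step.action)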