Example 1
  def testBuilds(self):
    observation_spec = tensor_spec.BoundedTensorSpec((8, 8, 3), tf.float32, 0,
                                                     1)
    observation = tensor_spec.sample_spec_nest(
        observation_spec, outer_dims=(1,))

    net = value_network.ValueNetwork(
        observation_spec, conv_layer_params=[(4, 2, 2)], fc_layer_params=(5,))

    value, _ = net(observation)
    self.evaluate(tf.compat.v1.global_variables_initializer())

    self.assertEqual([1], value.shape.as_list())

    self.assertEqual(6, len(net.variables))
    # Conv Net Kernel
    self.assertEqual((2, 2, 3, 4), net.variables[0].shape)
    # Conv Net bias
    self.assertEqual((4,), net.variables[1].shape)
    # Fc Kernel
    self.assertEqual((64, 5), net.variables[2].shape)
    # Fc Bias
    self.assertEqual((5,), net.variables[3].shape)
    # Value Shrink Kernel
    self.assertEqual((5, 1), net.variables[4].shape)
    # Value Shrink bias
    self.assertEqual((1,), net.variables[5].shape)
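The variable shapes asserted above follow from the layer parameters. A quick, illustrative sanity check of the arithmetic (local variables only; note that with a 2x2 kernel and stride 2 on an 8x8 input, both 'valid' and 'same' padding produce a 4x4 feature map):

filters, kernel, stride = 4, 2, 2            # conv_layer_params=[(4, 2, 2)]
spatial = (8 - kernel) // stride + 1         # 8x8 observation -> 4x4 feature map
flat_features = spatial * spatial * filters  # 4 * 4 * 4
assert flat_features == 64                   # matches the (64, 5) fc kernel asserted above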
Example 2
    def __init__(
        self, batch_size,
        action_spec,
        time_step_spec,
        n_iterations,
        replay_buffer_max_length,
        learning_rate=1e-3,
        checkpoint_dir=None
    ):
        self.batch_size = batch_size
        self.time_step_spec = time_step_spec
        self.action_spec = action_spec
        observation_spec = self.time_step_spec.observation

        self.actor_net = HierachyActorNetwork(
            observation_spec,
            action_spec,
            n_iterations,
            n_options=4
        )
        value_net = value_network.ValueNetwork(
            observation_spec,
            fc_layer_params=(100,)
        )

        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
        self.global_step = tf.compat.v1.train.get_or_create_global_step()

        self.agent = ppo_agent.PPOAgent(
            time_step_spec,
            self.action_spec,
            actor_net=self.actor_net,
            value_net=value_net,
            optimizer=optimizer,
            normalize_rewards=True,
            normalize_observations=False,
            train_step_counter=self.global_step
        )
        self.agent.initialize()
        self.agent.train = common.function(self.agent.train)

        self.replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=self.agent.collect_data_spec,
            batch_size=self.batch_size,
            max_length=replay_buffer_max_length
        )

        self.train_checkpointer = None
        if checkpoint_dir:
            self.train_checkpointer = common.Checkpointer(
                ckpt_dir=checkpoint_dir,
                max_to_keep=1,
                agent=self.agent,
                policy=self.agent.policy,
                replay_buffer=self.replay_buffer,
                global_step=self.global_step
            )
            self.train_checkpointer.initialize_or_restore()

        self.policy_saver = policy_saver.PolicySaver(self.agent.policy)
Example 3
    def testUpdateAdaptiveKlBeta(self):
        if tf.executing_eagerly():
            self.skipTest('b/123777119')  # Secondary bug: ('b/123770194')
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            self._time_step_spec.observation,
            self._action_spec,
            fc_layer_params=None)
        value_net = value_network.ValueNetwork(
            self._time_step_spec.observation, fc_layer_params=None)
        agent = ppo_agent.PPOAgent(
            self._time_step_spec,
            self._action_spec,
            tf.compat.v1.train.AdamOptimizer(),
            actor_net=actor_net,
            value_net=value_net,
            initial_adaptive_kl_beta=1.0,
            adaptive_kl_target=10.0,
            adaptive_kl_tolerance=0.5,
        )

        self.evaluate(tf.compat.v1.global_variables_initializer())

        # When KL is target kl, beta should not change.
        beta_0 = agent.update_adaptive_kl_beta(10.0)
        self.assertEqual(self.evaluate(beta_0), 1.0)

        # When KL is large, beta should increase.
        beta_1 = agent.update_adaptive_kl_beta(100.0)
        self.assertEqual(self.evaluate(beta_1), 1.5)

        # When KL is small, beta should decrease.
        beta_2 = agent.update_adaptive_kl_beta(1.0)
        self.assertEqual(self.evaluate(beta_2), 1.0)
Example 4
    def __init__(self):

        observation_tensor_spec = tf.TensorSpec(shape=[1], dtype=tf.float32)
        action_tensor_spec = tensor_spec.BoundedTensorSpec([1], tf.float32, -1,
                                                           1)

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_tensor_spec,
            action_tensor_spec,
            fc_layer_params=(1, ),
            activation_fn=tf.nn.tanh,
            kernel_initializer=tf.keras.initializers.Orthogonal(seed=1),
            seed_stream_class=DeterministicSeedStream,
            seed=1)

        value_net = value_network.ValueNetwork(observation_tensor_spec,
                                               fc_layer_params=(1, ))

        super(PPOAgentActorDist, self).__init__(
            time_step_spec=ts.time_step_spec(observation_tensor_spec),
            action_spec=action_tensor_spec,
            actor_net=actor_net,
            value_net=value_net,
            # Ensures value_prediction, return and advantage are included as part
            # of the training_data_spec.
            compute_value_and_advantage_in_train=True,
            update_normalizers_in_train=False,
            optimizer=tf.compat.v1.train.AdamOptimizer(),
        )
        # There is an artificial call on `_train` during the initialization which
        # ensures that the variables of the optimizer are initialized. This is
        # excluded from the call count.
        self.train_called_times = -1
        self.experiences = []
Example 5
    def GetAgent(self, env, params):
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            env.observation_spec(),
            env.action_spec(),
            fc_layer_params=tuple(self._ppo_params["ActorFcLayerParams", "",
                                                   [512, 256, 256]]))
        value_net = value_network.ValueNetwork(
            env.observation_spec(),
            fc_layer_params=tuple(self._ppo_params["CriticFcLayerParams", "",
                                                   [512, 256, 256]]))

        tf_agent = ppo_agent.PPOAgent(
            env.time_step_spec(),
            env.action_spec(),
            actor_net=actor_net,
            value_net=value_net,
            normalize_observations=self._ppo_params["NormalizeObservations",
                                                    "", False],
            normalize_rewards=self._ppo_params["NormalizeRewards", "", False],
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=self._ppo_params["LearningRate", "", 3e-4]),
            train_step_counter=self._ckpt.step,
            num_epochs=self._ppo_params["NumEpochs", "", 30],
            name=self._ppo_params["AgentName", "", "ppo_agent"],
            debug_summaries=self._ppo_params["DebugSummaries", "", False])
        tf_agent.initialize()
        return tf_agent
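The bracket syntax used above (e.g. self._ppo_params["LearningRate", "", 3e-4]) reads a named parameter together with a description and a default value from a parameter store; Example 20 below uses the same pattern with BARK's ParameterServer. A toy stand-in with that lookup behaviour, purely illustrative and not the real class:

class ToyParams(dict):
    """Minimal sketch of a (name, description, default) parameter lookup."""

    def __getitem__(self, key):
        name, _description, default = key
        return self.get(name, default)

params = ToyParams({"LearningRate": 1e-3})
assert params["LearningRate", "", 3e-4] == 1e-3   # stored value wins
assert params["NumEpochs", "", 30] == 30          # otherwise the default is returned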
Example 6
    def test_tf_agents_on_policy_agent(self):
        learning_rate = 1e-3
        actor_fc_layers = (200, 100)
        value_fc_layers = (200, 100)
        env_name = "CartPole-v0"
        gym_env = gym.make(env_name)
        model_name = "ppo_tf_agent"
        train_env = environment_converter.gym_to_tf(gym_env)
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            train_env.observation_spec(),
            train_env.action_spec(),
            fc_layer_params=actor_fc_layers,
        )
        value_net = value_network.ValueNetwork(train_env.observation_spec(),
                                               fc_layer_params=value_fc_layers)

        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)
        agent = ppo_agent.PPOAgent(
            train_env.time_step_spec(),
            train_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
        )
        agent.initialize()

        # Train
        train(agent, gym_env, 2000, 195, model_name, 200)
        trained_env = get_saved_environments()[0]
        trained_models = get_trained_model_names(trained_env)
        model_saved = model_name in trained_models
        shutil.rmtree(save_path)
        self.assertTrue(model_saved)
Example 7
  def testKlCutoffLoss(self, not_zero):
    kl_cutoff_coef = 30.0 * not_zero
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        self._time_step_spec.observation,
        self._action_spec,
        fc_layer_params=None)
    value_net = value_network.ValueNetwork(
        self._time_step_spec.observation, fc_layer_params=None)
    agent = ppo_agent.PPOAgent(
        self._time_step_spec,
        self._action_spec,
        tf.compat.v1.train.AdamOptimizer(),
        actor_net=actor_net,
        value_net=value_net,
        kl_cutoff_factor=5.0,
        adaptive_kl_target=0.1,
        kl_cutoff_coef=kl_cutoff_coef,
    )
    kl_divergence = tf.constant([[1.5, -0.5, 6.5, -1.5, -2.3]],
                                dtype=tf.float32)
    expected_kl_cutoff_loss = kl_cutoff_coef * (.24**2)  # (0.74 - 0.5) ^ 2

    loss = agent.kl_cutoff_loss(kl_divergence)
    self.evaluate(tf.compat.v1.initialize_all_variables())
    loss_ = self.evaluate(loss)
    self.assertAllClose([loss_], [expected_kl_cutoff_loss])
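For reference, the expected value above unpacks as follows (illustrative arithmetic that mirrors the inline comment; the cutoff is kl_cutoff_factor * adaptive_kl_target):

# mean KL  = (1.5 - 0.5 + 6.5 - 1.5 - 2.3) / 5 = 0.74
# cutoff   = kl_cutoff_factor * adaptive_kl_target = 5.0 * 0.1 = 0.5
# loss     = kl_cutoff_coef * max(0.0, 0.74 - 0.5) ** 2 = kl_cutoff_coef * 0.24 ** 2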
Example 8
    def __init__(self, strategy=None):

        self._strategy = strategy

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            tensor_spec.TensorSpec(shape=[], dtype=tf.float32),
            tensor_spec.BoundedTensorSpec([1], tf.float32, -1, 1),
            fc_layer_params=(1, ),
            activation_fn=tf.nn.tanh)
        value_net = value_network.ValueNetwork(tensor_spec.TensorSpec(
            shape=[], dtype=tf.float32),
                                               fc_layer_params=(1, ))

        super(FakePPOAgent, self).__init__(
            time_step_spec=ts.time_step_spec(
                tensor_spec.TensorSpec(shape=[], dtype=tf.float32)),
            action_spec=tensor_spec.BoundedTensorSpec([1], tf.float32, -1, 1),
            actor_net=actor_net,
            value_net=value_net,
            # Ensures value_prediction, return and normalized_advantage are included
            # as part of the training_data_spec.
            compute_value_and_advantage_in_train=False,
            update_normalizers_in_train=False,
        )
        self.train_called_times = tf.Variable(0, dtype=tf.int32)
        self.experiences = []
Example 9
  def testUpdateAdaptiveKlBeta(self):
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        self._time_step_spec.observation,
        self._action_spec,
        fc_layer_params=None)
    value_net = value_network.ValueNetwork(
        self._time_step_spec.observation, fc_layer_params=None)
    agent = ppo_agent.PPOAgent(
        self._time_step_spec,
        self._action_spec,
        tf.compat.v1.train.AdamOptimizer(),
        actor_net=actor_net,
        value_net=value_net,
        initial_adaptive_kl_beta=1.0,
        adaptive_kl_target=10.0,
        adaptive_kl_tolerance=0.5,
    )

    self.evaluate(tf.compat.v1.initialize_all_variables())

    # When KL is target kl, beta should not change.
    update_adaptive_kl_beta_fn = common.function(agent.update_adaptive_kl_beta)
    beta_0 = update_adaptive_kl_beta_fn([10.0])
    expected_beta_0 = 1.0
    self.assertEqual(expected_beta_0, self.evaluate(beta_0))

    # When KL is large, beta should increase.
    beta_1 = update_adaptive_kl_beta_fn([100.0])
    expected_beta_1 = 1.5
    self.assertEqual(expected_beta_1, self.evaluate(beta_1))

    # When KL is small, beta should decrease.
    beta_2 = update_adaptive_kl_beta_fn([1.0])
    expected_beta_2 = 1.0
    self.assertEqual(expected_beta_2, self.evaluate(beta_2))
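Taken together, the assertions in this test describe the adaptive-KL schedule: with adaptive_kl_target=10.0 and adaptive_kl_tolerance=0.5, a KL inside the tolerance band leaves beta unchanged, a larger KL scales it up, and a smaller KL scales it back down. A minimal stand-in consistent with these expected values (illustrative only; the factor 1.5 is inferred from the assertions, not taken from the library source):

def sketch_update_beta(beta, kl, target=10.0, tolerance=0.5, factor=1.5):
    """Toy version of the adaptive KL beta update implied by the assertions above."""
    if kl > target * (1.0 + tolerance):   # KL too large -> strengthen the penalty
        return beta * factor
    if kl < target * (1.0 - tolerance):   # KL too small -> relax the penalty
        return beta / factor
    return beta                           # inside the tolerance band -> unchanged

assert sketch_update_beta(1.0, 10.0) == 1.0    # beta_0
assert sketch_update_beta(1.0, 100.0) == 1.5   # beta_1
assert sketch_update_beta(1.5, 1.0) == 1.0     # beta_2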
Example 10
def get_actor_and_value_network(action_spec, observation_spec):
    preprocessing_layers = tfk.Sequential([
        tfk.layers.Lambda(lambda x: x - 0.5),  # Normalization
        tfk.layers.MaxPooling2D((5, 5), strides=(5, 5)),
        tfk.layers.Conv2D(256, (11, 3), (1, 1),
                          padding='valid',
                          activation='relu'),
        tfk.layers.Reshape((-1, 256)),
        tfk.layers.Conv1D(128, 1, activation='relu'),
        tfk.layers.Flatten()
    ])

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        preprocessing_layers=preprocessing_layers,
        fc_layer_params=(200, 100),
        activation_fn=tfk.activations.relu)
    value_net = value_network.ValueNetwork(
        observation_spec,
        preprocessing_layers=preprocessing_layers,
        fc_layer_params=(200, 100),
        activation_fn=tfk.activations.relu)

    return actor_net, value_net
Example 11
    def __init__(self):

        observation_tensor_spec = tf.TensorSpec(shape=[1], dtype=tf.float32)
        action_tensor_spec = tensor_spec.BoundedTensorSpec([2], tf.float32, -1,
                                                           1)

        actor_net = train_eval_lib.create_sequential_actor_net(
            fc_layer_units=(1, ), action_tensor_spec=action_tensor_spec)
        value_net = value_network.ValueNetwork(observation_tensor_spec,
                                               fc_layer_params=(1, ))

        super(FakePPOAgent, self).__init__(
            time_step_spec=ts.time_step_spec(observation_tensor_spec),
            action_spec=action_tensor_spec,
            actor_net=actor_net,
            value_net=value_net,
            # Ensures value_prediction, return and advantage are included as part
            # of the training_data_spec.
            compute_value_and_advantage_in_train=False,
            update_normalizers_in_train=False,
        )
        # There is an artificial call on `_train` during the initialization which
        # ensures that the variables of the optimizer are initialized. This is
        # excluded from the call count.
        self.train_called_times = -1
        self.experiences = []
Example 12
    def test_same_policy_same_output(self):
        if not tf.executing_eagerly():
            self.skipTest(
                'Skipping test: sequential networks not supported in TF1')
        observation_tensor_spec = tf.TensorSpec(shape=[1], dtype=tf.float32)
        action_tensor_spec = tensor_spec.BoundedTensorSpec((8, ), tf.float32,
                                                           -1, 1)

        value_net = value_network.ValueNetwork(observation_tensor_spec,
                                               fc_layer_params=(1, ))

        actor_net_lib = ppo_actor_network.PPOActorNetwork()
        actor_net_lib.seed_stream_class = DeterministicSeedStream
        actor_net_sequential = actor_net_lib.create_sequential_actor_net(
            fc_layer_units=(1, ),
            action_tensor_spec=action_tensor_spec,
            seed=1)
        actor_net_actor_dist = actor_distribution_network.ActorDistributionNetwork(
            observation_tensor_spec,
            action_tensor_spec,
            fc_layer_params=(1, ),
            activation_fn=tf.nn.tanh,
            kernel_initializer=tf.keras.initializers.Orthogonal(seed=1),
            seed_stream_class=DeterministicSeedStream,
            seed=1)

        tf.random.set_seed(111)
        seq_policy = ppo_policy.PPOPolicy(
            ts.time_step_spec(observation_tensor_spec),
            action_tensor_spec,
            actor_net_sequential,
            value_net,
            collect=True)
        tf.random.set_seed(111)
        actor_dist_policy = ppo_policy.PPOPolicy(
            ts.time_step_spec(observation_tensor_spec),
            action_tensor_spec,
            actor_net_actor_dist,
            value_net,
            collect=True)

        sample_timestep = ts.TimeStep(step_type=tf.constant([1, 1],
                                                            dtype=tf.int32),
                                      reward=tf.constant([1, 1],
                                                         dtype=tf.float32),
                                      discount=tf.constant([1, 1],
                                                           dtype=tf.float32),
                                      observation=tf.constant(
                                          [[1], [2]], dtype=tf.float32))
        seq_policy_step = seq_policy._distribution(sample_timestep,
                                                   policy_state=())
        act_dist_policy_step = actor_dist_policy._distribution(sample_timestep,
                                                               policy_state=())

        seq_scale = seq_policy_step.info['dist_params']['scale_diag']
        act_dist_scale = act_dist_policy_step.info['dist_params']['scale']
        self.assertAllEqual(seq_scale, act_dist_scale)
        self.assertAllEqual(seq_policy_step.info['dist_params']['loc'],
                            act_dist_policy_step.info['dist_params']['loc'])
Example 13
def get_networks(tf_env, actor_fc_layers, value_fc_layers):
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        tf_env.observation_spec(),
        tf_env.action_spec(),
        fc_layer_params=actor_fc_layers)
    value_net = value_network.ValueNetwork(tf_env.observation_spec(),
                                           fc_layer_params=value_fc_layers)
    return actor_net, value_net
Example 14
def load_ppo_agent(train_env,
                   actor_fc_layers,
                   value_fc_layers,
                   learning_rate,
                   num_epochs,
                   preprocessing_layers=None,
                   preprocessing_combiner=None):
    """
	Function which creates a tensorflow agent for a given environment with specified parameters, which uses the 
	proximal policy optimization (PPO) algorithm for training. 
	actor_fc_layers: tuple of integers, indicating the number of units in intermediate layers of the actor network. All layers are Keras Dense layers
	actor_fc_layers: same for value network
	preprocessing_layers: already-contructed layers of the preprocessing networks, which converts observations to tensors. Needed when the observation spec is either a list or dictionary
	preprocessing_combiner: combiner for the preprocessing networks, typically by concatenation. 
	learning_rate: learning rate, recommended value 0.001 or less
	num_epochs: number of training epochs which the agent executes per batch of collected episodes. 
	
	For more details on PPO, see the documentation of tf_agents: https://github.com/tensorflow/agents/tree/master/tf_agents
	or the paper: https://arxiv.org/abs/1707.06347
	"""

    optimizer = tf.compat.v1.train.AdamOptimizer(
        learning_rate=learning_rate
    )  # Adam: a first-order gradient method with momentum and adaptive per-parameter step sizes

    train_step_counter = tf.compat.v2.Variable(
        0)  # counter of training steps, starting at 0

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        train_env.observation_spec(),
        train_env.action_spec(),
        preprocessing_combiner=preprocessing_combiner,
        preprocessing_layers=preprocessing_layers,
        fc_layer_params=actor_fc_layers,
    )
    value_net = value_network.ValueNetwork(
        train_env.observation_spec(),
        preprocessing_combiner=preprocessing_combiner,
        preprocessing_layers=preprocessing_layers,
        fc_layer_params=value_fc_layers)

    tf_agent = ppo_agent.PPOAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        optimizer=optimizer,
        actor_net=actor_net,
        value_net=value_net,
        num_epochs=num_epochs,
        train_step_counter=train_step_counter,
        normalize_rewards=False,  # crucial to avoid the agent getting stuck
        normalize_observations=False,  # same
        discount_factor=1.0,
    )

    tf_agent.initialize()  # necessary to create the networks' variables
    return tf_agent
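A hedged usage sketch for the factory above; the environment name and layer sizes are placeholders chosen for illustration and are not part of the original snippet:

from tf_agents.environments import suite_gym, tf_py_environment

train_env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))
agent = load_ppo_agent(
    train_env,
    actor_fc_layers=(64, 64),
    value_fc_layers=(64, 64),
    learning_rate=1e-3,
    num_epochs=10,
)
collect_policy = agent.collect_policy  # handed to a driver to gather episodes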
Example 15
    def __init__(
        self,
        model: flexs.Model,
        rounds: int,
        sequences_batch_size: int,
        model_queries_per_batch: int,
        starting_sequence: str,
        alphabet: str,
        log_file: Optional[str] = None,
    ):
        """Create PPO explorer."""
        super().__init__(
            model,
            "PPO_Agent",
            rounds,
            sequences_batch_size,
            model_queries_per_batch,
            starting_sequence,
            log_file,
        )

        self.alphabet = alphabet

        # Initialize tf_environment
        env = PPOEnv(
            alphabet=self.alphabet,
            starting_seq=starting_sequence,
            model=self.model,
            max_num_steps=self.model_queries_per_batch,
        )
        self.tf_env = tf_py_environment.TFPyEnvironment(env)

        encoder_layer = tf.keras.layers.Lambda(lambda obs: obs["sequence"])
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            self.tf_env.observation_spec(),
            self.tf_env.action_spec(),
            preprocessing_combiner=encoder_layer,
            fc_layer_params=[128],
        )
        value_net = value_network.ValueNetwork(
            self.tf_env.observation_spec(),
            preprocessing_combiner=encoder_layer,
            fc_layer_params=[128],
        )

        # Create the PPO agent
        self.agent = ppo_agent.PPOAgent(
            time_step_spec=self.tf_env.time_step_spec(),
            action_spec=self.tf_env.action_spec(),
            optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
            actor_net=actor_net,
            value_net=value_net,
            num_epochs=10,
            summarize_grads_and_vars=False,
        )
        self.agent.initialize()
Example 16
  def testHandlesExtraOuterDims(self):
    observation_spec = tensor_spec.BoundedTensorSpec((8, 8, 3), tf.float32, 0,
                                                     1)
    observation = tensor_spec.sample_spec_nest(
        observation_spec, outer_dims=(3, 3, 2))

    net = value_network.ValueNetwork(
        observation_spec, conv_layer_params=[(4, 2, 2)], fc_layer_params=(5,))

    value, _ = net(observation)
    self.assertEqual([3, 3, 2], value.shape.as_list())
Example 17
    def testStatelessValueNetTrain(self, compute_value_and_advantage_in_train):
        actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
            self._time_step_spec.observation,
            self._action_spec,
            input_fc_layer_params=None,
            output_fc_layer_params=None,
            lstm_size=(20, ))
        value_net = value_network.ValueNetwork(
            self._time_step_spec.observation, fc_layer_params=None)
        global_step = tf.compat.v1.train.get_or_create_global_step()
        agent = ppo_agent.PPOAgent(
            self._time_step_spec,
            self._action_spec,
            optimizer=tf.compat.v1.train.AdamOptimizer(),
            actor_net=actor_net,
            value_net=value_net,
            num_epochs=1,
            train_step_counter=global_step,
            compute_value_and_advantage_in_train=
            compute_value_and_advantage_in_train)
        # Use a random env, policy, and replay buffer to collect training data.
        random_env = random_tf_environment.RandomTFEnvironment(
            self._time_step_spec, self._action_spec, batch_size=1)
        collection_policy = random_tf_policy.RandomTFPolicy(
            self._time_step_spec,
            self._action_spec,
            info_spec=agent.collect_policy.info_spec)
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            collection_policy.trajectory_spec, batch_size=1, max_length=7)
        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            random_env,
            collection_policy,
            observers=[replay_buffer.add_batch],
            num_episodes=1)

        # In graph mode: finish building the graph so the optimizer
        # variables are created.
        if not tf.executing_eagerly():
            _, _ = agent.train(experience=replay_buffer.gather_all())

        # Initialize.
        self.evaluate(agent.initialize())
        self.evaluate(tf.compat.v1.global_variables_initializer())

        # Train one step.
        self.assertEqual(0, self.evaluate(global_step))
        self.evaluate(collect_driver.run())
        self.evaluate(agent.train(experience=replay_buffer.gather_all()))
        self.assertEqual(1, self.evaluate(global_step))
Example 18
  def testPolicy(self):
    value_net = value_network.ValueNetwork(
        self._time_step_spec.observation, fc_layer_params=None)
    agent = ppo_agent.PPOAgent(
        self._time_step_spec,
        self._action_spec,
        tf.train.AdamOptimizer(),
        actor_net=DummyActorNet(self._action_spec),
        value_net=value_net)
    observations = tf.constant([[1, 2]], dtype=tf.float32)
    time_steps = ts.restart(observations, batch_size=1)
    action_step = agent.policy().action(time_steps)
    actions = action_step.action
    self.assertEqual(actions.shape.as_list(), [1, 1])
    self.evaluate(tf.global_variables_initializer())
    _ = self.evaluate(actions)
Example 19
  def testKlPenaltyLoss(self):
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        self._time_step_spec.observation,
        self._action_spec,
        fc_layer_params=None)
    value_net = value_network.ValueNetwork(
        self._time_step_spec.observation, fc_layer_params=None)
    agent = ppo_agent.PPOAgent(
        self._time_step_spec,
        self._action_spec,
        tf.compat.v1.train.AdamOptimizer(),
        actor_net=actor_net,
        value_net=value_net,
        kl_cutoff_factor=5.0,
        adaptive_kl_target=0.1,
        kl_cutoff_coef=100,
    )

    agent.kl_cutoff_loss = mock.MagicMock(
        return_value=tf.constant(3.0, dtype=tf.float32))
    agent.adaptive_kl_loss = mock.MagicMock(
        return_value=tf.constant(4.0, dtype=tf.float32))

    observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
    time_steps = ts.restart(observations, batch_size=2)
    action_distribution_parameters = {
        'loc': tf.constant([1.0, 1.0], dtype=tf.float32),
        'scale': tf.constant([1.0, 1.0], dtype=tf.float32),
    }
    current_policy_distribution, unused_network_state = DummyActorNet(
        self._obs_spec, self._action_spec)(time_steps.observation,
                                           time_steps.step_type, ())
    weights = tf.ones_like(time_steps.discount)

    expected_kl_penalty_loss = 7.0

    kl_penalty_loss = agent.kl_penalty_loss(time_steps,
                                            action_distribution_parameters,
                                            current_policy_distribution,
                                            weights)
    self.evaluate(tf.compat.v1.global_variables_initializer())
    kl_penalty_loss_ = self.evaluate(kl_penalty_loss)
    self.assertEqual(expected_kl_penalty_loss, kl_penalty_loss_)
Example 20
    def get_agent(self, env, params):
        """Returns a TensorFlow PPO-Agent
    
    Arguments:
        env {TFAPyEnvironment} -- Tensorflow-Agents PyEnvironment
        params {ParameterServer} -- ParameterServer from BARK
    
    Returns:
        agent -- tf-agent
    """

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            env.observation_spec(),
            env.action_spec(),
            fc_layer_params=tuple(
                self._params["ML"]["Agent"]["actor_fc_layer_params", "",
                                            [512, 256, 256]]))
        value_net = value_network.ValueNetwork(
            env.observation_spec(),
            fc_layer_params=tuple(
                self._params["ML"]["Agent"]["critic_fc_layer_params", "",
                                            [512, 256, 256]]))

        # agent
        tf_agent = ppo_agent.PPOAgent(
            env.time_step_spec(),
            env.action_spec(),
            actor_net=actor_net,
            value_net=value_net,
            normalize_observations=self._params["ML"]["Agent"][
                "normalize_observations", "", False],
            normalize_rewards=self._params["ML"]["Agent"]["normalize_rewards",
                                                          "", False],
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=self._params["ML"]["Agent"]["learning_rate", "",
                                                          3e-4]),
            train_step_counter=self._ckpt.step,
            num_epochs=self._params["ML"]["Agent"]["num_epochs", "", 30],
            name=self._params["ML"]["Agent"]["agent_name", "", "ppo_agent"],
            debug_summaries=self._params["ML"]["Agent"]["debug_summaries", "",
                                                        False])
        tf_agent.initialize()
        return tf_agent
Example 21
  def testHandlePreprocessingLayers(self):
    observation_spec = (tensor_spec.TensorSpec([1], tf.float32),
                        tensor_spec.TensorSpec([], tf.float32))
    observation = tensor_spec.sample_spec_nest(
        observation_spec, outer_dims=(3,))

    preprocessing_layers = (tf.keras.layers.Dense(4),
                            tf.keras.Sequential([
                                tf.keras.layers.Reshape((1,)),
                                tf.keras.layers.Dense(4)
                            ]))

    net = value_network.ValueNetwork(
        observation_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=tf.keras.layers.Add())

    value, _ = net(observation)
    self.assertEqual([3], value.shape.as_list())
    self.assertGreater(len(net.trainable_variables), 4)
Example 22
    def testCopyUsesSameWrappedNetwork(self):
        # Create a wrapped network.
        wrapped_network = value_network.ValueNetwork(self._observation_spec,
                                                     fc_layer_params=(2, ))

        # Create and build a `splitter_network`.
        splitter_network = mask_splitter_network.MaskSplitterNetwork(
            splitter_fn=self._splitter_fn,
            wrapped_network=wrapped_network,
            passthrough_mask=True,
            input_tensor_spec=self._observation_and_mask_spec)
        splitter_network.create_variables()

        # Create a copy of the splitter network while passing in the same wrapped network.
        copied_splitter_network = splitter_network.copy(
            wrapped_network=wrapped_network)

        # Check that the underlying wrapped network object is the same.
        self.assertIs(copied_splitter_network._wrapped_network,
                      splitter_network._wrapped_network)
Example 23
    def testCopyCreateNewInstanceOfNetworkIfNotRedefined(self):
        # Create a wrapped network.
        wrapped_network = value_network.ValueNetwork(self._observation_spec,
                                                     fc_layer_params=(2, ))

        # Create and build a `splitter_network`.
        splitter_network = mask_splitter_network.MaskSplitterNetwork(
            splitter_fn=self._splitter_fn,
            wrapped_network=wrapped_network,
            passthrough_mask=True,
            input_tensor_spec=self._observation_and_mask_spec)
        splitter_network.create_variables()

        # Copy and build the copied network.
        copied_splitter_network = splitter_network.copy()
        copied_splitter_network.create_variables()

        # Check if the underlying wrapped network objects are different.
        self.assertIsNot(copied_splitter_network._wrapped_network,
                         splitter_network._wrapped_network)
Example 24
    def testAdaptiveKlLoss(self):
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            self._time_step_spec.observation,
            self._action_spec,
            fc_layer_params=None)
        value_net = value_network.ValueNetwork(
            self._time_step_spec.observation, fc_layer_params=None)
        agent = ppo_agent.PPOAgent(
            self._time_step_spec,
            self._action_spec,
            tf.compat.v1.train.AdamOptimizer(),
            actor_net=actor_net,
            value_net=value_net,
            initial_adaptive_kl_beta=1.0,
            adaptive_kl_target=10.0,
            adaptive_kl_tolerance=0.5,
        )

        # Force variable creation
        agent.policy.variables()
        self.evaluate(tf.compat.v1.initialize_all_variables())

        # Loss should not change if data kl is target kl.
        loss_1 = agent.adaptive_kl_loss([10.0])
        loss_2 = agent.adaptive_kl_loss([10.0])
        self.assertEqual(self.evaluate(loss_1), self.evaluate(loss_2))

        # If data kl is low, kl penalty should decrease between calls.
        loss_1 = self.evaluate(agent.adaptive_kl_loss([1.0]))
        adaptive_kl_beta_update_fn = common.function(
            agent.update_adaptive_kl_beta)
        self.evaluate(adaptive_kl_beta_update_fn([1.0]))
        loss_2 = self.evaluate(agent.adaptive_kl_loss([1.0]))
        self.assertGreater(loss_1, loss_2)

        # If data kl is high, kl penalty should increase between calls.
        loss_1 = self.evaluate(agent.adaptive_kl_loss([100.0]))
        self.evaluate(adaptive_kl_beta_update_fn([100.0]))
        loss_2 = self.evaluate(agent.adaptive_kl_loss([100.0]))
        self.assertLess(loss_1, loss_2)
Example 25
    def testAdaptiveKlLoss(self):
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            self._time_step_spec.observation,
            self._action_spec,
            fc_layer_params=None)
        value_net = value_network.ValueNetwork(
            self._time_step_spec.observation, fc_layer_params=None)
        agent = ppo_agent.PPOAgent(
            self._time_step_spec,
            self._action_spec,
            tf.train.AdamOptimizer(),
            actor_net=actor_net,
            value_net=value_net,
            initial_adaptive_kl_beta=1.0,
            adaptive_kl_target=10.0,
            adaptive_kl_tolerance=0.5,
        )
        kl_divergence = tf.placeholder(shape=[1], dtype=tf.float32)
        loss = agent.adaptive_kl_loss(kl_divergence)
        update = agent.update_adaptive_kl_beta(kl_divergence)

        with self.cached_session() as sess:
            sess.run(tf.global_variables_initializer())

            # Loss should not change if data kl is target kl.
            loss_1 = sess.run(loss, feed_dict={kl_divergence: [10.0]})
            loss_2 = sess.run(loss, feed_dict={kl_divergence: [10.0]})
            self.assertEqual(loss_1, loss_2)

            # If data kl is low, kl penalty should decrease between calls.
            loss_1 = sess.run(loss, feed_dict={kl_divergence: [1.0]})
            sess.run(update, feed_dict={kl_divergence: [1.0]})
            loss_2 = sess.run(loss, feed_dict={kl_divergence: [1.0]})
            self.assertGreater(loss_1, loss_2)

            # If data kl is high, kl penalty should increase between calls.
            loss_1 = sess.run(loss, feed_dict={kl_divergence: [100.0]})
            sess.run(update, feed_dict={kl_divergence: [100.0]})
            loss_2 = sess.run(loss, feed_dict={kl_divergence: [100.0]})
            self.assertLess(loss_1, loss_2)
Example 26
    def testUpdateAdaptiveKlBeta(self):
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            self._time_step_spec.observation,
            self._action_spec,
            fc_layer_params=None)
        value_net = value_network.ValueNetwork(
            self._time_step_spec.observation, fc_layer_params=None)
        agent = ppo_agent.PPOAgent(
            self._time_step_spec,
            self._action_spec,
            tf.train.AdamOptimizer(),
            actor_net=actor_net,
            value_net=value_net,
            initial_adaptive_kl_beta=1.0,
            adaptive_kl_target=10.0,
            adaptive_kl_tolerance=0.5,
        )
        kl_divergence = tf.placeholder(shape=[1], dtype=tf.float32)
        updated_adaptive_kl_beta = agent.update_adaptive_kl_beta(kl_divergence)

        with self.cached_session() as sess:
            sess.run(tf.global_variables_initializer())

            # When KL is target kl, beta should not change.
            beta_0 = sess.run(updated_adaptive_kl_beta,
                              feed_dict={kl_divergence: [10.0]})
            expected_beta_0 = 1.0
            self.assertEqual(expected_beta_0, beta_0)

            # When KL is large, beta should increase.
            beta_1 = sess.run(updated_adaptive_kl_beta,
                              feed_dict={kl_divergence: [100.0]})
            expected_beta_1 = 1.5
            self.assertEqual(expected_beta_1, beta_1)

            # When KL is small, beta should decrease.
            beta_2 = sess.run(updated_adaptive_kl_beta,
                              feed_dict={kl_divergence: [1.0]})
            expected_beta_2 = 1.0
            self.assertEqual(expected_beta_2, beta_2)
Example 27
    def testAdaptiveKlLoss(self):
        if tf.executing_eagerly():
            self.skipTest('b/123777119')  # Secondary bug: ('b/123770194')
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            self._time_step_spec.observation,
            self._action_spec,
            fc_layer_params=None)
        value_net = value_network.ValueNetwork(
            self._time_step_spec.observation, fc_layer_params=None)
        agent = ppo_agent.PPOAgent(
            self._time_step_spec,
            self._action_spec,
            tf.compat.v1.train.AdamOptimizer(),
            actor_net=actor_net,
            value_net=value_net,
            initial_adaptive_kl_beta=1.0,
            adaptive_kl_target=10.0,
            adaptive_kl_tolerance=0.5,
        )

        self.evaluate(tf.compat.v1.global_variables_initializer())

        # Loss should not change if data kl is target kl.
        loss_1 = self.evaluate(agent.adaptive_kl_loss(10.0))
        loss_2 = self.evaluate(agent.adaptive_kl_loss(10.0))
        self.assertEqual(loss_1, loss_2)

        # If data kl is low, kl penalty should decrease between calls.
        loss_1 = self.evaluate(agent.adaptive_kl_loss(1.0))
        self.evaluate(agent.update_adaptive_kl_beta(1.0))
        loss_2 = self.evaluate(agent.adaptive_kl_loss(1.0))
        self.assertGreater(loss_1, loss_2)

        # If data kl is high, kl penalty should increase between calls.
        loss_1 = self.evaluate(agent.adaptive_kl_loss(100.0))
        self.evaluate(agent.update_adaptive_kl_beta(100.0))
        loss_2 = self.evaluate(agent.adaptive_kl_loss(100.0))
        self.assertLess(loss_1, loss_2)
Example 28
def make_networks(env, conv_params=[(16, 8, 4), (32, 3, 2)]):
    """Function for creating the neural networks for the PPO agent, namely the actor and value networks.

    Source for network params: https://www.arconsis.com/unternehmen/blog/reinforcement-learning-doom-with-tf-agents-and-ppo

    Arguments:
        1. env (tf env): A TensorFlow environment that the agent interacts with via the neural networks.
        2. conv_params (list): A list corresponding to convolutional layer parameters for each neural network.

    Returns:
        1. actor_net (ActorDistributionNetwork): A tf-agents Actor Distribution Network that is used for action selection
                                                 with the PPO agent.
        2. value_net (ValueNetwork): A tf-agents Value Network that is used for value estimation with the PPO agent.
    """
    # Define actor network
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        env.observation_spec(),
        env.action_spec(),
        conv_layer_params=conv_params)
    # Define value network
    value_net = value_network.ValueNetwork(env.observation_spec(),
                                           conv_layer_params=conv_params)

    return actor_net, value_net
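A sketch of how these networks might be wired into a PPO agent; here tf_env is assumed to be an already-built TFPyEnvironment with image observations (as in the Doom setup linked in the docstring), and the optimizer settings are placeholders:

actor_net, value_net = make_networks(tf_env)
tf_agent = ppo_agent.PPOAgent(
    tf_env.time_step_spec(),
    tf_env.action_spec(),
    optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=1e-4),
    actor_net=actor_net,
    value_net=value_net,
    num_epochs=10,
)
tf_agent.initialize()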
Example 29
def train_eval(
        root_dir,
        tf_master='',
        env_name='HalfCheetah-v2',
        env_load_fn=suite_mujoco.load,
        random_seed=0,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        # Params for collect
        num_environment_steps=10000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=100,
        policy_checkpoint_interval=50,
        rb_checkpoint_interval=200,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for PPO."""
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        batched_py_metric.BatchedPyMetric(
            AverageReturnMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments),
        batched_py_metric.BatchedPyMetric(
            AverageEpisodeLengthMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments),
    ]
    eval_summary_writer_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf.compat.v1.set_random_seed(random_seed)
        eval_py_env = parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments))
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(), fc_layer_params=value_fc_layers)

        tf_agent = ppo_agent.PPOAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        environment_steps_metric = tf_metrics.EnvironmentSteps()
        environment_steps_count = environment_steps_metric.result()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]
        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        # Add to replay buffer and other agent specific observers.
        replay_buffer_observer = [replay_buffer.add_batch]

        collect_policy = tf_agent.collect_policy

        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=replay_buffer_observer + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

        trajectories = replay_buffer.gather_all()

        train_op, _ = tf_agent.train(experience=trajectories)

        with tf.control_dependencies([train_op]):
            clear_replay_op = replay_buffer.clear()

        with tf.control_dependencies([clear_replay_op]):
            train_op = tf.identity(train_op)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=tf_agent.policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(train_step=global_step,
                                      step_metrics=step_metrics)

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(step_metrics=step_metrics)

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session(tf_master) as sess:
            # Initialize graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            collect_time = 0
            train_time = 0
            timed_at_step = sess.run(global_step)
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.contrib.summary.scalar(
                name='global_steps/sec', tensor=steps_per_second_ph)

            while sess.run(environment_steps_count) < num_environment_steps:
                global_step_val = sess.run(global_step)
                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
                    sess.run(eval_summary_writer_flush_op)

                start_time = time.time()
                sess.run(collect_op)
                collect_time += time.time() - start_time
                start_time = time.time()
                total_loss = sess.run(train_op)
                train_time += time.time() - start_time

                global_step_val = sess.run(global_step)
                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss)
                    steps_per_sec = ((global_step_val - timed_at_step) /
                                     (collect_time + train_time))
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    logging.info(
                        '%s', 'collect_time = {}, train_time = {}'.format(
                            collect_time, train_time))
                    timed_at_step = global_step_val
                    collect_time = 0
                    train_time = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

            # One final eval before exiting.
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )
            sess.run(eval_summary_writer_flush_op)
Example 30
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        env_load_fn=suite_mujoco.load,
        random_seed=None,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        lstm_size=(20, ),
        # Params for collect
        num_environment_steps=25000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-3,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=500,
        policy_checkpoint_interval=500,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        use_tf_functions=True,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """A simple train and eval for PPO."""
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if random_seed is not None:
            tf.compat.v1.set_random_seed(random_seed)
        eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments))
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None,
                lstm_size=lstm_size)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers,
                activation_fn=tf.keras.activations.tanh)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(),
                fc_layer_params=value_fc_layers,
                activation_fn=tf.keras.activations.tanh)

        tf_agent = ppo_clip_agent.PPOClipAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            entropy_regularization=0.0,
            importance_ratio_clipping=0.2,
            normalize_observations=False,
            normalize_rewards=False,
            use_gae=True,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        environment_steps_metric = tf_metrics.EnvironmentSteps()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]

        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_environments),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=global_step)

        train_checkpointer.initialize_or_restore()

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        def train_step():
            trajectories = replay_buffer.gather_all()
            return tf_agent.train(experience=trajectories)

        if use_tf_functions:
            # TODO(b/123828980): Enable once the cause for slowdown was identified.
            collect_driver.run = common.function(collect_driver.run,
                                                 autograph=False)
            tf_agent.train = common.function(tf_agent.train, autograph=False)
            train_step = common.function(train_step)

        collect_time = 0
        train_time = 0
        timed_at_step = global_step.numpy()

        while environment_steps_metric.result() < num_environment_steps:
            global_step_val = global_step.numpy()
            if global_step_val % eval_interval == 0:
                metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )

            start_time = time.time()
            collect_driver.run()
            collect_time += time.time() - start_time

            start_time = time.time()
            total_loss, _ = train_step()
            replay_buffer.clear()
            train_time += time.time() - start_time

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             total_loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = %.3f, train_time = %.3f',
                             collect_time, train_time)
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)
                    saved_model_path = os.path.join(
                        saved_model_dir,
                        'policy_' + ('%d' % global_step_val).zfill(9))
                    saved_model.save(saved_model_path)

                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

        # One final eval before exiting.
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )