Example 1
  def GetAgent(self, env, params):
    

    def init_gnn(name):
      """
      Returns a new `GNNWrapper`instance with the given `name`.
      We need this function to be able to prefix the variable
      names with the names of the parent actor or critic network,
      by passing in this function and initializing the instance in 
      the parent network.
      """
      return GNNWrapper(
        params=self._gnn_sac_params["GNN"], 
        graph_dims=self._observer.graph_dimensions,
        name=name)

    # actor network
    actor_net = GNNActorNetwork(
      input_tensor_spec=env.observation_spec(),
      output_tensor_spec=env.action_spec(),
      gnn=init_gnn,
      fc_layer_params=self._gnn_sac_params[
        "ActorFcLayerParams", "", [128, 64]]
    )

    # critic network
    critic_net = GNNCriticNetwork(
      (env.observation_spec(), env.action_spec()),
      gnn=init_gnn,
      observation_fc_layer_params=self._gnn_sac_params[
        "CriticObservationFcLayerParams", "", [128]],
      action_fc_layer_params=self._gnn_sac_params[
        "CriticActionFcLayerParams", "", None],
      joint_fc_layer_params=self._gnn_sac_params[
        "CriticJointFcLayerParams", "", [128, 128]]
    )
    
    # agent
    tf_agent = sac_agent.SacAgent(
      env.time_step_spec(),
      env.action_spec(),
      actor_network=actor_net,
      critic_network=critic_net,
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=self._gnn_sac_params["ActorLearningRate", "", 3e-4]),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=self._gnn_sac_params["CriticLearningRate", "", 3e-4]),
      alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=self._gnn_sac_params["AlphaLearningRate", "", 3e-4]),
      target_update_tau=self._gnn_sac_params["TargetUpdateTau", "", 0.05],
      target_update_period=self._gnn_sac_params["TargetUpdatePeriod", "", 3],
      td_errors_loss_fn=tf.math.squared_difference,
      gamma=self._gnn_sac_params["Gamma", "", 0.995],
      reward_scale_factor=self._gnn_sac_params["RewardScaleFactor", "", 1.],
      train_step_counter=self._ckpt.step,
      name=self._gnn_sac_params["AgentName", "", "gnn_sac_agent"],
      debug_summaries=self._gnn_sac_params["DebugSummaries", "", False])
    
    tf_agent.initialize()
    return tf_agent
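
The three-element bracket lookups used throughout these examples (e.g. self._gnn_sac_params["ActorFcLayerParams", "", [128, 64]]) assume a parameter-store object whose item access takes a (name, description, default) tuple and returns either the stored value or the supplied default; the real container is not shown in the snippets. A minimal, purely illustrative stand-in for that behavior:

# Illustrative stand-in for the parameter container assumed above; the real
# project presumably provides a richer implementation. Only the
# (name, description, default) lookup used in these examples is mimicked.
class SimpleParamServer:
  def __init__(self, values=None):
    self._values = dict(values or {})

  def __getitem__(self, key):
    if isinstance(key, tuple):            # ("Name", "description", default)
      name, _description, default = key
      return self._values.get(name, default)
    # nested access such as params["ML"]["BehaviorSACAgent"][...]
    return self._values.setdefault(key, SimpleParamServer())


params = SimpleParamServer({"ActorLearningRate": 1e-4})
print(params["ActorLearningRate", "", 3e-4])  # -> 0.0001 (stored value)
print(params["Gamma", "", 0.995])             # -> 0.995 (default)
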
Example 2
    def testAgentTransitionTrain(self):
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            self._obs_spec,
            self._action_spec,
            fc_layer_params=(10, ),
            continuous_projection_net=tanh_normal_projection_network.
            TanhNormalProjectionNetwork)

        agent = sac_agent.SacAgent(
            self._time_step_spec,
            self._action_spec,
            critic_network=DummyCriticNet(),
            actor_network=actor_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(0.001),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(0.001),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(0.001))

        time_step_spec = self._time_step_spec._replace(
            reward=tensor_spec.BoundedTensorSpec(
                [], tf.float32, minimum=0.0, maximum=1.0, name='reward'))

        transition_spec = trajectory.Transition(
            time_step=time_step_spec,
            action_step=policy_step.PolicyStep(action=self._action_spec,
                                               state=(),
                                               info=()),
            next_time_step=time_step_spec)

        sample_trajectory_experience = tensor_spec.sample_spec_nest(
            transition_spec, outer_dims=(3, ))
        agent.train(sample_trajectory_experience)
Example 3
    def testAgentTrajectoryTrain(self):
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            self._obs_spec,
            self._action_spec,
            fc_layer_params=(10, ),
            continuous_projection_net=tanh_normal_projection_network.
            TanhNormalProjectionNetwork)

        agent = sac_agent.SacAgent(
            self._time_step_spec,
            self._action_spec,
            critic_network=DummyCriticNet(),
            actor_network=actor_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(0.001),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(0.001),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(0.001))

        trajectory_spec = trajectory.Trajectory(
            step_type=self._time_step_spec.step_type,
            observation=self._time_step_spec.observation,
            action=self._action_spec,
            policy_info=(),
            next_step_type=self._time_step_spec.step_type,
            reward=tensor_spec.BoundedTensorSpec([],
                                                 tf.float32,
                                                 minimum=0.0,
                                                 maximum=1.0,
                                                 name='reward'),
            discount=self._time_step_spec.discount)

        sample_trajectory_experience = tensor_spec.sample_spec_nest(
            trajectory_spec, outer_dims=(3, 2))
        agent.train(sample_trajectory_experience)
Example 4
    def testCriticRegLoss(self):
        agent = sac_agent.SacAgent(self._time_step_spec,
                                   self._action_spec,
                                   critic_network=DummyCriticNet(0.5),
                                   actor_network=None,
                                   actor_optimizer=None,
                                   critic_optimizer=None,
                                   alpha_optimizer=None,
                                   actor_policy_ctor=DummyActorPolicy)

        observations = tf.zeros((2, 2), dtype=tf.float32)
        time_steps = ts.restart(observations, batch_size=2)
        actions = tf.zeros((2, 1), dtype=tf.float32)

        rewards = tf.zeros((2, ), dtype=tf.float32)
        discounts = tf.zeros((2, ), dtype=tf.float32)
        next_observations = tf.zeros((2, 2), dtype=tf.float32)
        next_time_steps = ts.transition(next_observations, rewards, discounts)

        # The expected loss consists only of the regularization loss.
        expected_loss = 2.0

        loss = agent.critic_loss(time_steps,
                                 actions,
                                 next_time_steps,
                                 td_errors_loss_fn=tf.math.squared_difference)

        self.evaluate(tf.compat.v1.global_variables_initializer())
        loss_ = self.evaluate(loss)
        self.assertAllClose(loss_, expected_loss)
Example 5
  def testTrainWithRnn(self):
    actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
        self._obs_spec,
        self._action_spec,
        input_fc_layer_params=None,
        output_fc_layer_params=None,
        conv_layer_params=None,
        lstm_size=(40,),
    )

    critic_net = critic_rnn_network.CriticRnnNetwork(
        (self._obs_spec, self._action_spec),
        observation_fc_layer_params=(16,),
        action_fc_layer_params=(16,),
        joint_fc_layer_params=(16,),
        lstm_size=(16,),
        output_fc_layer_params=None,
    )

    counter = common.create_variable('test_train_counter')

    optimizer_fn = tf.compat.v1.train.AdamOptimizer

    agent = sac_agent.SacAgent(
        self._time_step_spec,
        self._action_spec,
        critic_network=critic_net,
        actor_network=actor_net,
        actor_optimizer=optimizer_fn(1e-3),
        critic_optimizer=optimizer_fn(1e-3),
        alpha_optimizer=optimizer_fn(1e-3),
        train_step_counter=counter,
    )

    batch_size = 5
    observations = tf.constant(
        [[[1, 2], [3, 4], [5, 6]]] * batch_size, dtype=tf.float32)
    actions = tf.constant([[[0], [1], [1]]] * batch_size, dtype=tf.float32)
    time_steps = ts.TimeStep(
        step_type=tf.constant([[1] * 3] * batch_size, dtype=tf.int32),
        reward=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
        discount=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
        observation=[observations])

    experience = trajectory.Trajectory(
        time_steps.step_type, [observations], actions, (),
        time_steps.step_type, time_steps.reward, time_steps.discount)

    # Force variable creation.
    agent.policy.variables()
    if tf.executing_eagerly():
      loss = lambda: agent.train(experience)
    else:
      loss = agent.train(experience)

    self.evaluate(tf.compat.v1.initialize_all_variables())
    self.assertEqual(self.evaluate(counter), 0)
    self.evaluate(loss)
    self.assertEqual(self.evaluate(counter), 1)
Example 6
    def GetAgent(self, env, params):
        def _normal_projection_net(action_spec, init_means_output_factor=0.1):
            return normal_projection_network.NormalProjectionNetwork(
                action_spec,
                mean_transform=None,
                state_dependent_std=True,
                init_means_output_factor=init_means_output_factor,
                std_transform=sac_agent.std_clip_transform,
                scale_distribution=True)

        # actor network
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            env.observation_spec(),
            env.action_spec(),
            fc_layer_params=tuple(
                self._params["ML"]["BehaviorSACAgent"]["ActorFcLayerParams",
                                                       "", [512, 256, 256]]),
            continuous_projection_net=_normal_projection_net)

        # critic network
        critic_net = critic_network.CriticNetwork(
            (env.observation_spec(), env.action_spec()),
            observation_fc_layer_params=None,
            action_fc_layer_params=None,
            joint_fc_layer_params=tuple(self._params["ML"]["BehaviorSACAgent"][
                "CriticJointFcLayerParams", "", [512, 256, 256]]))

        # agent
        tf_agent = sac_agent.SacAgent(
            env.time_step_spec(),
            env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=self._params["ML"]["BehaviorSACAgent"][
                    "ActorLearningRate", "", 3e-4]),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=self._params["ML"]["BehaviorSACAgent"][
                    "CriticLearningRate", "", 3e-4]),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=self._params["ML"]["BehaviorSACAgent"][
                    "AlphaLearningRate", "", 3e-4]),
            target_update_tau=self._params["ML"]["BehaviorSACAgent"][
                "TargetUpdateTau", "", 0.05],
            target_update_period=self._params["ML"]["BehaviorSACAgent"][
                "TargetUpdatePeriod", "", 3],
            td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
            gamma=self._params["ML"]["BehaviorSACAgent"]["Gamma", "", 0.995],
            reward_scale_factor=self._params["ML"]["BehaviorSACAgent"][
                "RewardScaleFactor", "", 1.],
            train_step_counter=self._ckpt.step,
            name=self._params["ML"]["BehaviorSACAgent"]["AgentName", "",
                                                        "sac_agent"],
            debug_summaries=self._params["ML"]["BehaviorSACAgent"][
                "DebugSummaries", "", False])

        tf_agent.initialize()
        return tf_agent
Example 7
 def testCreateAgent(self):
     sac_agent.SacAgent(self._time_step_spec,
                        self._action_spec,
                        critic_network=DummyCriticNet(),
                        actor_network=None,
                        actor_optimizer=None,
                        critic_optimizer=None,
                        alpha_optimizer=None,
                        actor_policy_ctor=DummyActorPolicy)
Example 8
    def GetAgent(self, env, params):
        self._params["ML"]["GraphDims"] = self._observer.graph_dimensions

        # actor network
        actor_net = GNNActorNetwork(
            input_tensor_spec=env.observation_spec(),
            output_tensor_spec=env.action_spec(),
            gnn=self._init_gnn,
            fc_layer_params=self._gnn_sac_params["ActorFcLayerParams", "",
                                                 [256, 256]],
            params=params)

        # critic network
        critic_net = GNNCriticNetwork(
            (env.observation_spec(), env.action_spec()),
            gnn=self._init_gnn,
            observation_fc_layer_params=self._gnn_sac_params[
                "CriticObservationFcLayerParams", "", [256]],
            action_fc_layer_params=self._gnn_sac_params[
                "CriticActionFcLayerParams", "", None],
            joint_fc_layer_params=self._gnn_sac_params[
                "CriticJointFcLayerParams", "", [256, 256]],
            params=params)

        # agent
        tf_agent = sac_agent.SacAgent(
            env.time_step_spec(),
            env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=self._gnn_sac_params["ActorLearningRate", "",
                                                   3e-4]),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=self._gnn_sac_params["CriticLearningRate", "",
                                                   3e-4]),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=self._gnn_sac_params["AlphaLearningRate", "",
                                                   0.]),
            target_update_tau=self._gnn_sac_params["TargetUpdateTau", "", 1.],
            target_update_period=self._gnn_sac_params["TargetUpdatePeriod", "",
                                                      1],
            td_errors_loss_fn=tf.math.squared_difference,
            gamma=self._gnn_sac_params["Gamma", "", 0.995],
            reward_scale_factor=self._gnn_sac_params["RewardScaleFactor", "",
                                                     1.],
            train_step_counter=self._ckpt.step,
            name=self._gnn_sac_params["AgentName", "", "gnn_sac_agent"],
            debug_summaries=self._gnn_sac_params["DebugSummaries", "", True])
        tf_agent.initialize()
        return tf_agent
Example 9
    def testSharedLayer(self):
        shared_layer = tf.keras.layers.Dense(
            1,
            kernel_initializer=tf.compat.v1.initializers.constant([0]),
            bias_initializer=tf.compat.v1.initializers.constant([0]),
            name='shared')

        critic_net_1 = DummyCriticNet(shared_layer=shared_layer)
        critic_net_2 = DummyCriticNet(shared_layer=shared_layer)

        target_shared_layer = tf.keras.layers.Dense(
            1,
            kernel_initializer=tf.compat.v1.initializers.constant([0]),
            bias_initializer=tf.compat.v1.initializers.constant([0]),
            name='shared_target')

        target_critic_net_1 = DummyCriticNet(shared_layer=target_shared_layer)
        target_critic_net_2 = DummyCriticNet(shared_layer=target_shared_layer)

        agent = sac_agent.SacAgent(self._time_step_spec,
                                   self._action_spec,
                                   critic_network=critic_net_1,
                                   critic_network_2=critic_net_2,
                                   target_critic_network=target_critic_net_1,
                                   target_critic_network_2=target_critic_net_2,
                                   actor_network=None,
                                   actor_optimizer=None,
                                   critic_optimizer=None,
                                   alpha_optimizer=None,
                                   target_entropy=3.0,
                                   initial_log_alpha=4.0,
                                   target_update_tau=0.5,
                                   actor_policy_ctor=DummyActorPolicy)

        self.evaluate([
            tf.compat.v1.global_variables_initializer(),
            tf.compat.v1.local_variables_initializer()
        ])

        self.evaluate(agent.initialize())

        for v in shared_layer.variables:
            self.evaluate(v.assign(v * 0 + 1))

        self.evaluate(agent._update_target())

        self.assertEqual(1.0, self.evaluate(shared_layer.variables[0][0][0]))
        self.assertEqual(0.5,
                         self.evaluate(target_shared_layer.variables[0][0][0]))
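
The expected values at the end of testSharedLayer follow from SAC's soft (Polyak) target update with target_update_tau=0.5: the source layer is set to 1.0, the target starts at 0.0, and a single update moves the target halfway. A minimal arithmetic check, for illustration only:

# target_new = (1 - tau) * target_old + tau * source
tau, source, target_old = 0.5, 1.0, 0.0
target_new = (1.0 - tau) * target_old + tau * source
assert target_new == 0.5  # the value asserted for target_shared_layer above
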
Example 10
  def testCreateAgent(self, create_critic_net_fn, skip_in_tf1):
    if skip_in_tf1 and not common.has_eager_been_enabled():
      self.skipTest('Skipping test: sequential networks not supported in TF1')

    critic_network = create_critic_net_fn()

    sac_agent.SacAgent(
        self._time_step_spec,
        self._action_spec,
        critic_network=critic_network,
        actor_network=None,
        actor_optimizer=None,
        critic_optimizer=None,
        alpha_optimizer=None,
        actor_policy_ctor=DummyActorPolicy)
Example 11
def load_policy(agent_class, tf_env):
    load_dir = FLAGS.load_dir
    assert load_dir and osp.exists(
        load_dir
    ), 'need to provide valid load_dir to load policy, got: {}'.format(
        load_dir)
    global_step = tf.compat.v1.train.get_or_create_global_step()
    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=(256, 256),
        continuous_projection_net=normal_projection_net)

    critic_net = critic_network.CriticNetwork((observation_spec, action_spec),
                                              joint_fc_layer_params=(256, 256))

    tf_agent = sac_agent.SacAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=3e-4),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=3e-4),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=3e-4),
        target_update_tau=0.005,
        target_update_period=1,
        td_errors_loss_fn=tf.keras.losses.mse,
        gamma=0,
        reward_scale_factor=1.,
        gradient_clipping=1.,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        train_step_counter=global_step)

    train_checkpointer = common.Checkpointer(ckpt_dir=load_dir,
                                             agent=tf_agent,
                                             global_step=global_step)
    status = train_checkpointer.initialize_or_restore()
    status.expect_partial()
    logging.info('Loaded from checkpoint: %s, trained %s steps',
                 train_checkpointer._manager.latest_checkpoint,
                 global_step.numpy())
    return tf_agent.policy
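
For orientation, the returned policy would typically be driven against the environment as sketched below (illustrative only; assumes tf_env and a valid FLAGS.load_dir are already set up):

policy = load_policy(sac_agent.SacAgent, tf_env)
time_step = tf_env.reset()
action_step = policy.action(time_step)      # PolicyStep(action, state, info)
time_step = tf_env.step(action_step.action)
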
Example 12
    def GetAgent(self, env, params):
        gnn_sac_params = self._params["ML"]["BehaviorGraphSACAgent"]

        # actor network
        actor_net = GNNActorNetwork(
            input_tensor_spec=env.observation_spec(),
            output_tensor_spec=env.action_spec(),
            gnn=GNNWrapper(params=gnn_sac_params["GNN"],
                           graph_dims=self._observer.graph_dimensions),
            fc_layer_params=gnn_sac_params["ActorFcLayerParams", "",
                                           [128, 64]])

        # critic network
        critic_net = GNNCriticNetwork(
            (env.observation_spec(), env.action_spec()),
            gnn=GNNWrapper(params=gnn_sac_params["GNN"],
                           graph_dims=self._observer.graph_dimensions),
            observation_fc_layer_params=gnn_sac_params[
                "CriticObservationFcLayerParams", "", [128]],
            action_fc_layer_params=gnn_sac_params["CriticActionFcLayerParams",
                                                  "", None],
            joint_fc_layer_params=gnn_sac_params["CriticJointFcLayerParams",
                                                 "", [128, 128]])

        # agent
        tf_agent = sac_agent.SacAgent(
            env.time_step_spec(),
            env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=gnn_sac_params["ActorLearningRate", "", 3e-4]),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=gnn_sac_params["CriticLearningRate", "", 3e-4]),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=gnn_sac_params["AlphaLearningRate", "", 3e-4]),
            target_update_tau=gnn_sac_params["TargetUpdateTau", "", 0.05],
            target_update_period=gnn_sac_params["TargetUpdatePeriod", "", 3],
            td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
            gamma=gnn_sac_params["Gamma", "", 0.995],
            reward_scale_factor=gnn_sac_params["RewardScaleFactor", "", 1.],
            train_step_counter=self._ckpt.step,
            name=gnn_sac_params["AgentName", "", "gnn_sac_agent"],
            debug_summaries=gnn_sac_params["DebugSummaries", "", False])

        tf_agent.initialize()
        return tf_agent
Example 13
  def testLoss(self, mock_actions_and_log_probs, mock_apply_gradients):
    # Mock _actions_and_log_probs so that _train() and _loss() run on the same
    # sampled values.
    actions = tf.constant([[0.2], [0.5], [-0.8]])
    log_pi = tf.constant([-1.1, -0.8, -0.5])
    mock_actions_and_log_probs.return_value = (actions, log_pi)

    # Skip applying gradients since mocking _actions_and_log_probs.
    del mock_apply_gradients

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        self._obs_spec,
        self._action_spec,
        fc_layer_params=(10,),
        continuous_projection_net=tanh_normal_projection_network
        .TanhNormalProjectionNetwork)

    agent = sac_agent.SacAgent(
        self._time_step_spec,
        self._action_spec,
        critic_network=DummyCriticNet(),
        actor_network=actor_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(0.001),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(0.001),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(0.001))

    observations = tf.constant(
        [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]],
        dtype=tf.float32)
    actions = tf.constant([[[0], [1]], [[2], [3]], [[4], [5]]],
                          dtype=tf.float32)
    time_steps = ts.TimeStep(
        step_type=tf.constant([[1, 1]] * 3, dtype=tf.int32),
        reward=tf.constant([[1, 1]] * 3, dtype=tf.float32),
        discount=tf.constant([[1, 1]] * 3, dtype=tf.float32),
        observation=observations)

    experience = trajectory.Trajectory(
        time_steps.step_type, observations, actions, (),
        time_steps.step_type, time_steps.reward, time_steps.discount)

    test_util.test_loss_and_train_output(
        test=self,
        expect_equal_loss_values=True,
        agent=agent,
        experience=experience)
Example 14
def create_sac_agent(train_env, reward_scale_factor):
    return sac_agent.SacAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        actor_network=create_actor_network(train_env),
        critic_network=create_critic_network(train_env),
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=3e-4),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=3e-4),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=3e-4),
        target_update_tau=0.005,
        target_update_period=1,
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
        gamma=0.99,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=None,
        train_step_counter=tf.compat.v1.train.get_or_create_global_step(),
    )
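
Unlike the GetAgent helpers above, this factory does not call initialize() on the agent, so the caller is expected to do that. An illustrative usage sketch (train_env assumed to be a TFPyEnvironment):

agent = create_sac_agent(train_env, reward_scale_factor=1.0)
agent.initialize()
collect_policy = agent.collect_policy  # stochastic policy used for data collection
eval_policy = agent.policy             # often wrapped in GreedyPolicy for evaluation
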
Example 15
    def testPolicy(self):
        agent = sac_agent.SacAgent(self._time_step_spec,
                                   self._action_spec,
                                   critic_network=DummyCriticNet(),
                                   actor_network=None,
                                   actor_optimizer=None,
                                   critic_optimizer=None,
                                   alpha_optimizer=None,
                                   actor_policy_ctor=DummyActorPolicy)

        observations = tf.constant([[1, 2]], dtype=tf.float32)
        time_steps = ts.restart(observations)
        action_step = agent.policy.action(time_steps)

        self.evaluate(tf.compat.v1.global_variables_initializer())
        action_ = self.evaluate(action_step.action)
        self.assertLessEqual(action_, self._action_spec.maximum)
        self.assertGreaterEqual(action_, self._action_spec.minimum)
Example 16
    def testActorLoss(self):
        agent = sac_agent.SacAgent(self._time_step_spec,
                                   self._action_spec,
                                   critic_network=DummyCriticNet(),
                                   actor_network=None,
                                   actor_optimizer=None,
                                   critic_optimizer=None,
                                   alpha_optimizer=None,
                                   actor_policy_ctor=DummyActorPolicy)

        observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
        time_steps = ts.restart(observations, batch_size=2)

        expected_loss = (2 * 10 - (2 + 1) - (4 + 1)) / 2
        loss = agent.actor_loss(time_steps)

        self.evaluate(tf.compat.v1.global_variables_initializer())
        loss_ = self.evaluate(loss)
        self.assertAllClose(loss_, expected_loss)
Example 17
 def create_sac_agent(self, actor, critic, actor_alpha, critic_alpha,
                      alpha_alpha, gamma):
     train_step_counter = tf.Variable(0)
     return sac_agent.SacAgent(
         spec.get_time_step_spec(),
         spec.get_action_spec(),
         actor_network=actor,
         critic_network=critic,
         actor_optimizer=tf.compat.v1.train.AdamOptimizer(
             learning_rate=actor_alpha),
         critic_optimizer=tf.compat.v1.train.AdamOptimizer(
             learning_rate=critic_alpha),
         alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
             learning_rate=alpha_alpha),
         target_update_tau=0.05,
         target_update_period=5,
         td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
         gamma=gamma,
         train_step_counter=train_step_counter)
Example 18
 def sac_agent(self):
     return sac_agent.SacAgent(
         self.train_env.time_step_spec(),
         self.action_spec,
         actor_network=self.actor_net,
         critic_network=self.critic_net,
         actor_optimizer=tf.compat.v1.train.AdamOptimizer(
             learning_rate=self.actor_lr),
         critic_optimizer=tf.compat.v1.train.AdamOptimizer(
             learning_rate=self.critic_lr),
         alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
             learning_rate=self.alpha_lr),
         target_update_tau=self.target_update_tau,
         target_update_period=self.target_update_period,
         td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
         gamma=self.gamma,
         reward_scale_factor=self.reward_scale,
         gradient_clipping=self.gradient_clipping,
         train_step_counter=self.global_step)
Example 19
  def testCriticLoss(self, create_critic_net_fn, skip_in_tf1):
    if skip_in_tf1 and not common.has_eager_been_enabled():
      self.skipTest('Skipping test: sequential networks not supported in TF1')

    critic_network = create_critic_net_fn()
    agent = sac_agent.SacAgent(
        self._time_step_spec,
        self._action_spec,
        critic_network=critic_network,
        actor_network=None,
        actor_optimizer=None,
        critic_optimizer=None,
        alpha_optimizer=None,
        actor_policy_ctor=DummyActorPolicy)

    observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
    time_steps = ts.restart(observations, batch_size=2)
    actions = tf.constant([[5], [6]], dtype=tf.float32)

    rewards = tf.constant([10, 20], dtype=tf.float32)
    discounts = tf.constant([0.9, 0.9], dtype=tf.float32)
    next_observations = tf.constant([[5, 6], [7, 8]], dtype=tf.float32)
    next_time_steps = ts.transition(next_observations, rewards, discounts)

    td_targets = [7.3, 19.1]
    pred_td_targets = [7., 10.]

    self.evaluate(tf.compat.v1.global_variables_initializer())

    # Expected critic loss has factor of 2, for the two TD3 critics.
    expected_loss = self.evaluate(2 * tf.compat.v1.losses.mean_squared_error(
        tf.constant(td_targets), tf.constant(pred_td_targets)))

    loss = agent.critic_loss(
        time_steps,
        actions,
        next_time_steps,
        td_errors_loss_fn=tf.math.squared_difference)

    self.evaluate(tf.compat.v1.global_variables_initializer())
    loss_ = self.evaluate(loss)
    self.assertAllClose(loss_, expected_loss)
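
For reference, the expected value in testCriticLoss is simply twice the mean squared error between the hand-computed TD targets and the dummy critic's predictions, one term per critic. The same arithmetic in plain Python (illustration only):

td_targets = [7.3, 19.1]
pred_td_targets = [7., 10.]
mse = sum((t - p) ** 2 for t, p in zip(td_targets, pred_td_targets)) / len(td_targets)
print(2 * mse)  # equals the expected_loss evaluated in the test
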
Example 20
    def testAlphaLoss(self):
        agent = sac_agent.SacAgent(self._time_step_spec,
                                   self._action_spec,
                                   critic_network=DummyCriticNet(),
                                   actor_network=None,
                                   actor_optimizer=None,
                                   critic_optimizer=None,
                                   alpha_optimizer=None,
                                   squash_actions=False,
                                   target_entropy=3.0,
                                   initial_log_alpha=4.0,
                                   actor_policy_ctor=DummyActorPolicy)
        observations = [tf.constant([[1, 2], [3, 4]], dtype=tf.float32)]
        time_steps = ts.restart(observations, batch_size=2)

        expected_loss = 4.0 * (-10 - 3)
        loss = agent.alpha_loss(time_steps)

        self.evaluate(tf.compat.v1.global_variables_initializer())
        loss_ = self.evaluate(loss)
        self.assertAllClose(loss_, expected_loss)
Example 21
def _create_agent(train_step: tf.Variable,
                  observation_tensor_spec: types.NestedTensorSpec,
                  action_tensor_spec: types.NestedTensorSpec,
                  time_step_tensor_spec: ts.TimeStep,
                  learning_rate: float) -> tf_agent.TFAgent:
  """Creates an agent."""
  critic_net = critic_network.CriticNetwork(
      (observation_tensor_spec, action_tensor_spec),
      observation_fc_layer_params=None,
      action_fc_layer_params=None,
      joint_fc_layer_params=(256, 256),
      kernel_initializer='glorot_uniform',
      last_kernel_initializer='glorot_uniform')

  actor_net = actor_distribution_network.ActorDistributionNetwork(
      observation_tensor_spec,
      action_tensor_spec,
      fc_layer_params=(256, 256),
      continuous_projection_net=tanh_normal_projection_network
      .TanhNormalProjectionNetwork)

  return sac_agent.SacAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      actor_network=actor_net,
      critic_network=critic_net,
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=learning_rate),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=learning_rate),
      alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=learning_rate),
      target_update_tau=0.005,
      target_update_period=1,
      td_errors_loss_fn=tf.math.squared_difference,
      gamma=0.99,
      reward_scale_factor=0.1,
      gradient_clipping=None,
      train_step_counter=train_step)
Example 22
  def testCriticLoss(self):
    agent = sac_agent.SacAgent(
        self._time_step_spec,
        self._action_spec,
        critic_network=DummyCriticNet(),
        actor_network=None,
        actor_optimizer=None,
        critic_optimizer=None,
        alpha_optimizer=None,
        squash_actions=False,
        actor_policy_ctor=DummyActorPolicy)

    observations = [tf.constant([[1, 2], [3, 4]], dtype=tf.float32)]
    time_steps = ts.restart(observations)
    actions = tf.constant([[5], [6]], dtype=tf.float32)

    rewards = tf.constant([10, 20], dtype=tf.float32)
    discounts = tf.constant([0.9, 0.9], dtype=tf.float32)
    next_observations = [tf.constant([[5, 6], [7, 8]], dtype=tf.float32)]
    next_time_steps = ts.transition(next_observations, rewards, discounts)

    td_targets = [7.3, 19.1]
    pred_td_targets = [7., 10.]

    self.evaluate(tf.compat.v1.global_variables_initializer())

    # Expected critic loss has factor of 2, for the two TD3 critics.
    expected_loss = self.evaluate(2 * tf.compat.v1.losses.mean_squared_error(
        tf.constant(td_targets), tf.constant(pred_td_targets)))

    loss = agent.critic_loss(
        time_steps,
        actions,
        next_time_steps,
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error)

    self.evaluate(tf.compat.v1.global_variables_initializer())
    loss_ = self.evaluate(loss)
    self.assertAllClose(loss_, expected_loss)
Example 23
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for SAC."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
    eval_tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=normal_projection_net)
    critic_net = critic_network.CriticNetwork(
        (observation_spec, action_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers)

    tf_agent = sac_agent.SacAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
        tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
    ]

    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    collect_policy = tf_agent.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer,
        num_steps=initial_collect_steps)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)

    if use_tf_functions:
      initial_collect_driver.run = common.function(initial_collect_driver.run)
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    # Collect initial replay data.
    logging.info(
        'Initializing replay buffer by collecting experience for %d steps with '
        'a random policy.', initial_collect_steps)
    initial_collect_driver.run()

    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, global_step.numpy())
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)
    iterator = iter(dataset)

    for _ in range(num_iterations):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        experience, _ = next(iterator)
        train_loss = tf_agent.train(experience)
      time_acc += time.time() - start_time

      if global_step.numpy() % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step.numpy(),
                     train_loss.loss)
        steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = global_step.numpy()
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      if global_step.numpy() % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

      global_step_val = global_step.numpy()
      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)
    return train_loss
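
Such a train_eval entry point is usually driven from a main() that forwards command-line flags; a minimal direct call (illustrative only) would be:

train_eval(root_dir='/tmp/sac_halfcheetah', num_iterations=100000)
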
Example 24
        fc_layer_params=HyperParms.actor_fc_layer_params,
        continuous_projection_net=(
            tanh_normal_projection_network.TanhNormalProjectionNetwork))

with objStrategy.scope():
    train_step = train_utils.create_train_step()

    tf_agent = sac_agent.SacAgent(
        specTimeStep,
        specAction,
        actor_network=nnActor,
        critic_network=nnCritic,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=HyperParms.actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=HyperParms.critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=HyperParms.alpha_learning_rate),
        target_update_tau=HyperParms.target_update_tau,
        target_update_period=HyperParms.target_update_period,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=HyperParms.gamma,
        reward_scale_factor=HyperParms.reward_scale_factor,
        train_step_counter=train_step)

    tf_agent.initialize()

print(f" --  REPLAY BUFFER  ({now()})  -- ")
rate_limiter = reverb.rate_limiters.SampleToInsertRatio(samples_per_insert=3.0,
                                                        min_size_to_sample=3,
                                                        error_buffer=3.0)
Example 25
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        num_iterations=1000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Params for collect
        initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for SAC."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]
    eval_summary_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Create the environment.
        tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
        eval_py_env = suite_mujoco.load(env_name)

        # Get the data specs from the environment
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=normal_projection_net)
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        eval_py_policy = py_tf_policy.PyTFPolicy(
            greedy_policy.GreedyPolicy(tf_agent.policy))

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
            tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
        ]

        collect_policy = tf_agent.collect_policy
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps).run()

        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration).run()

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        dataset = replay_buffer.as_dataset(
            sample_batch_size=5 * batch_size,
            num_steps=2).apply(tf.data.experimental.unbatch()).filter(
                _filter_invalid_transition).batch(batch_size).prefetch(
                    batch_size * 5)
        dataset_iterator = tf.compat.v1.data.make_initializable_iterator(
            dataset)
        trajectories, unused_info = dataset_iterator.get_next()
        train_op = tf_agent.train(trajectories)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=tf_agent.policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        with tf.compat.v1.Session() as sess:
            # Initialize graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)

            # Initialize training.
            sess.run(dataset_iterator.initializer)
            common.initialize_uninitialized_variables(sess)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            global_step_val = sess.run(global_step)

            if global_step_val == 0:
                # Initial eval of randomly initialized policy
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    eval_py_policy,
                    num_episodes=num_eval_episodes,
                    global_step=global_step_val,
                    callback=eval_metrics_callback,
                    log=True,
                )
                sess.run(eval_summary_flush_op)

                # Run initial collect.
                logging.info('Global step %d: Running initial collect op.',
                             global_step_val)
                sess.run(initial_collect_op)

                # Checkpoint the initial replay buffer contents.
                rb_checkpointer.save(global_step=global_step_val)

                logging.info('Finished initial collect.')
            else:
                logging.info('Global step %d: Skipping initial collect op.',
                             global_step_val)

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([train_op, summary_ops])
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                for _ in range(train_steps_per_iteration):
                    total_loss, _ = train_step_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()
                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
                    sess.run(eval_summary_flush_op)

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

def train_eval(
    root_dir,
    experiment_name,  # experiment name
    env_name='carla-v0',
    agent_name='sac',  # agent's name
    num_iterations=int(1e7),
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    model_network_ctor_type='non-hierarchical',  # model net
    input_names=['camera', 'lidar'],  # names for inputs
    mask_names=['birdeye'],  # names for masks
    preprocessing_combiner=tf.keras.layers.Add(
    ),  # takes a flat list of tensors and combines them
    actor_lstm_size=(40, ),  # lstm size for actor
    critic_lstm_size=(40, ),  # lstm size for critic
    actor_output_fc_layers=(100, ),  # lstm output
    critic_output_fc_layers=(100, ),  # lstm output
    epsilon_greedy=0.1,  # exploration parameter for DQN
    q_learning_rate=1e-3,  # q learning rate for DQN
    ou_stddev=0.2,  # exploration parameter for DDPG
    ou_damping=0.15,  # exploration parameter for DDPG
    dqda_clipping=None,  # for DDPG
    exploration_noise_std=0.1,  # exploration parameter for td3
    actor_update_period=2,  # for td3
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=int(1e5),
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    initial_model_train_steps=100000,  # initial model training
    batch_size=256,
    model_batch_size=32,  # model training batch size
    sequence_length=4,  # number of timesteps to train model
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    model_learning_rate=1e-4,  # learning rate for model training
    td_errors_loss_fn=tf.losses.mean_squared_error,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for summaries and logging
    num_images_per_summary=1,  # images for each summary
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    gpu_allow_growth=True,  # GPU memory growth
    gpu_memory_limit=None,  # GPU memory limit
    action_repeat=1  # number of times each action is repeated in the environment
):
    # Setup GPU
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpu_allow_growth:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    if gpu_memory_limit:
        for gpu in gpus:
            tf.config.experimental.set_virtual_device_configuration(
                gpu, [
                    tf.config.experimental.VirtualDeviceConfiguration(
                        memory_limit=gpu_memory_limit)
                ])

    # Get train and eval directories
    root_dir = os.path.expanduser(root_dir)
    root_dir = os.path.join(root_dir, env_name, experiment_name)

    # Get summary writers
    summary_writer = tf.summary.create_file_writer(
        root_dir, flush_millis=summaries_flush_secs * 1000)
    summary_writer.set_as_default()

    # Eval metrics
    eval_metrics = [
        tf_metrics.AverageReturnMetric(name='AverageReturnEvalPolicy',
                                       buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(
            name='AverageEpisodeLengthEvalPolicy',
            buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()

    # Record summaries only on steps where global_step is a multiple of summary_interval
    with tf.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Create Carla environment
        if agent_name == 'latent_sac':
            py_env, eval_py_env = load_carla_env(env_name='carla-v0',
                                                 obs_channels=input_names +
                                                 mask_names,
                                                 action_repeat=action_repeat)
        elif agent_name == 'dqn':
            py_env, eval_py_env = load_carla_env(env_name='carla-v0',
                                                 discrete=True,
                                                 obs_channels=input_names,
                                                 action_repeat=action_repeat)
        else:
            py_env, eval_py_env = load_carla_env(env_name='carla-v0',
                                                 obs_channels=input_names,
                                                 action_repeat=action_repeat)

        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)
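        # Effective frames per second of the environment, given its timestep dt and the action repeat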
        fps = int(np.round(1.0 / (py_env.dt * action_repeat)))

        # Specs
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        # Make tf agent
        if agent_name == 'latent_sac':
            # Get model network for latent sac
            if model_network_ctor_type == 'hierarchical':
                model_network_ctor = sequential_latent_network.SequentialLatentModelHierarchical
            elif model_network_ctor_type == 'non-hierarchical':
                model_network_ctor = sequential_latent_network.SequentialLatentModelNonHierarchical
            else:
                raise NotImplementedError
            model_net = model_network_ctor(input_names,
                                           input_names + mask_names)

            # Get the latent spec
            latent_size = model_net.latent_size
            latent_observation_spec = tensor_spec.TensorSpec((latent_size, ),
                                                             dtype=tf.float32)
            latent_time_step_spec = ts.time_step_spec(
                observation_spec=latent_observation_spec)

            # Get actor and critic net
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                latent_observation_spec,
                action_spec,
                fc_layer_params=actor_fc_layers,
                continuous_projection_net=normal_projection_net)
            critic_net = critic_network.CriticNetwork(
                (latent_observation_spec, action_spec),
                observation_fc_layer_params=critic_obs_fc_layers,
                action_fc_layer_params=critic_action_fc_layers,
                joint_fc_layer_params=critic_joint_fc_layers)

            # Build the inner SAC agent based on latent space
            inner_agent = sac_agent.SacAgent(
                latent_time_step_spec,
                action_spec,
                actor_network=actor_net,
                critic_network=critic_net,
                actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=actor_learning_rate),
                critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=critic_learning_rate),
                alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=alpha_learning_rate),
                target_update_tau=target_update_tau,
                target_update_period=target_update_period,
                td_errors_loss_fn=td_errors_loss_fn,
                gamma=gamma,
                reward_scale_factor=reward_scale_factor,
                gradient_clipping=gradient_clipping,
                debug_summaries=debug_summaries,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step)
            inner_agent.initialize()

            # Build the latent sac agent
            tf_agent = latent_sac_agent.LatentSACAgent(
                time_step_spec,
                action_spec,
                inner_agent=inner_agent,
                model_network=model_net,
                model_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=model_learning_rate),
                model_batch_size=model_batch_size,
                num_images_per_summary=num_images_per_summary,
                sequence_length=sequence_length,
                gradient_clipping=gradient_clipping,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step,
                fps=fps)

        else:
            # Set up preprocessing layers for dictionary observation inputs
            preprocessing_layers = collections.OrderedDict()
            for name in input_names:
                preprocessing_layers[name] = Preprocessing_Layer(32, 256)
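            # A preprocessing combiner is only needed when there is more than one input channel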
            if len(input_names) < 2:
                preprocessing_combiner = None

            if agent_name == 'dqn':
                q_rnn_net = q_rnn_network.QRnnNetwork(
                    observation_spec,
                    action_spec,
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    input_fc_layer_params=critic_joint_fc_layers,
                    lstm_size=critic_lstm_size,
                    output_fc_layer_params=critic_output_fc_layers)

                tf_agent = dqn_agent.DqnAgent(
                    time_step_spec,
                    action_spec,
                    q_network=q_rnn_net,
                    epsilon_greedy=epsilon_greedy,
                    n_step_update=1,
                    target_update_tau=target_update_tau,
                    target_update_period=target_update_period,
                    optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=q_learning_rate),
                    td_errors_loss_fn=common.element_wise_squared_loss,
                    gamma=gamma,
                    reward_scale_factor=reward_scale_factor,
                    gradient_clipping=gradient_clipping,
                    debug_summaries=debug_summaries,
                    summarize_grads_and_vars=summarize_grads_and_vars,
                    train_step_counter=global_step)

            elif agent_name == 'ddpg' or agent_name == 'td3':
                actor_rnn_net = multi_inputs_actor_rnn_network.MultiInputsActorRnnNetwork(
                    observation_spec,
                    action_spec,
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    input_fc_layer_params=actor_fc_layers,
                    lstm_size=actor_lstm_size,
                    output_fc_layer_params=actor_output_fc_layers)

                critic_rnn_net = multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork(
                    (observation_spec, action_spec),
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    action_fc_layer_params=critic_action_fc_layers,
                    joint_fc_layer_params=critic_joint_fc_layers,
                    lstm_size=critic_lstm_size,
                    output_fc_layer_params=critic_output_fc_layers)

                if agent_name == 'ddpg':
                    tf_agent = ddpg_agent.DdpgAgent(
                        time_step_spec,
                        action_spec,
                        actor_network=actor_rnn_net,
                        critic_network=critic_rnn_net,
                        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=actor_learning_rate),
                        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=critic_learning_rate),
                        ou_stddev=ou_stddev,
                        ou_damping=ou_damping,
                        target_update_tau=target_update_tau,
                        target_update_period=target_update_period,
                        dqda_clipping=dqda_clipping,
                        td_errors_loss_fn=None,
                        gamma=gamma,
                        reward_scale_factor=reward_scale_factor,
                        gradient_clipping=gradient_clipping,
                        debug_summaries=debug_summaries,
                        summarize_grads_and_vars=summarize_grads_and_vars)
                elif agent_name == 'td3':
                    tf_agent = td3_agent.Td3Agent(
                        time_step_spec,
                        action_spec,
                        actor_network=actor_rnn_net,
                        critic_network=critic_rnn_net,
                        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=actor_learning_rate),
                        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=critic_learning_rate),
                        exploration_noise_std=exploration_noise_std,
                        target_update_tau=target_update_tau,
                        target_update_period=target_update_period,
                        actor_update_period=actor_update_period,
                        dqda_clipping=dqda_clipping,
                        td_errors_loss_fn=None,
                        gamma=gamma,
                        reward_scale_factor=reward_scale_factor,
                        gradient_clipping=gradient_clipping,
                        debug_summaries=debug_summaries,
                        summarize_grads_and_vars=summarize_grads_and_vars,
                        train_step_counter=global_step)

            elif agent_name == 'sac':
                actor_distribution_rnn_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                    observation_spec,
                    action_spec,
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    input_fc_layer_params=actor_fc_layers,
                    lstm_size=actor_lstm_size,
                    output_fc_layer_params=actor_output_fc_layers,
                    continuous_projection_net=normal_projection_net)

                critic_rnn_net = multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork(
                    (observation_spec, action_spec),
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    action_fc_layer_params=critic_action_fc_layers,
                    joint_fc_layer_params=critic_joint_fc_layers,
                    lstm_size=critic_lstm_size,
                    output_fc_layer_params=critic_output_fc_layers)

                tf_agent = sac_agent.SacAgent(
                    time_step_spec,
                    action_spec,
                    actor_network=actor_distribution_rnn_net,
                    critic_network=critic_rnn_net,
                    actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=actor_learning_rate),
                    critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=critic_learning_rate),
                    alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=alpha_learning_rate),
                    target_update_tau=target_update_tau,
                    target_update_period=target_update_period,
                    # squared_difference keeps the critic loss dimensions compatible
                    td_errors_loss_fn=tf.math.squared_difference,
                    gamma=gamma,
                    reward_scale_factor=reward_scale_factor,
                    gradient_clipping=gradient_clipping,
                    debug_summaries=debug_summaries,
                    summarize_grads_and_vars=summarize_grads_and_vars,
                    train_step_counter=global_step)

            else:
                raise NotImplementedError

        tf_agent.initialize()

        # Get replay buffer
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,  # No parallel environments
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        # Train metrics
        env_steps = tf_metrics.EnvironmentSteps()
        average_return = tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size)
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            env_steps,
            average_return,
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ]

        # Get policies
        # eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
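        # The agent's stochastic policy is evaluated directly here; the greedy alternative above is left commented out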
        eval_policy = tf_agent.policy
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)
        collect_policy = tf_agent.collect_policy

        # Checkpointers
        train_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'train'),
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
            max_to_keep=2)
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step,
                                                  max_to_keep=2)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)
        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        # Collect driver
        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        # Optimize performance by wrapping these calls in tf.functions
        initial_collect_driver.run = common.function(
            initial_collect_driver.run)
        collect_driver.run = common.function(collect_driver.run)
        tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if (env_steps.result() == 0 or replay_buffer.num_frames() == 0):
            logging.info(
                'Initializing replay buffer by collecting experience for %d steps '
                'with a random policy.', initial_collect_steps)
            initial_collect_driver.run()

        if agent_name == 'latent_sac':
            compute_summaries(eval_metrics,
                              eval_tf_env,
                              eval_policy,
                              train_step=global_step,
                              summary_writer=summary_writer,
                              num_episodes=1,
                              num_episodes_to_render=1,
                              model_net=model_net,
                              fps=10,
                              image_keys=input_names + mask_names)
        else:
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=1,
                train_step=env_steps.result(),
                summary_writer=summary_writer,
                summary_prefix='Eval',
            )
            metric_utils.log_metrics(eval_metrics)

        # Dataset generates trajectories with shape [B x (sequence_length + 1) x ...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=sequence_length +
                                           1).prefetch(3)
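        # Sampling sequence_length + 1 steps per window provides the extra next step needed to form training transitions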
        iterator = iter(dataset)

        # Get train step
        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        train_step = common.function(train_step)

        if agent_name == 'latent_sac':

            def train_model_step():
                experience, _ = next(iterator)
                return tf_agent.train_model(experience)

            train_model_step = common.function(train_model_step)

        # Training initializations
        time_step = None
        time_acc = 0
        env_steps_before = env_steps.result().numpy()

        # Start training
        for iteration in range(num_iterations):
            start_time = time.time()

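            # For latent_sac, the first initial_model_train_steps iterations only train the latent model;
            # regular collection and RL training start afterwards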
            if agent_name == 'latent_sac' and iteration < initial_model_train_steps:
                train_model_step()
            else:
                # Run collect
                time_step, _ = collect_driver.run(time_step=time_step)

                # Train an iteration
                for _ in range(train_steps_per_iteration):
                    train_step()

            time_acc += time.time() - start_time

            # Log training information
            if global_step.numpy() % log_interval == 0:
                logging.info('env steps = %d, average return = %f',
                             env_steps.result(), average_return.result())
                env_steps_per_sec = (env_steps.result().numpy() -
                                     env_steps_before) / time_acc
                logging.info('%.3f env steps/sec', env_steps_per_sec)
                tf.summary.scalar(name='env_steps_per_sec',
                                  data=env_steps_per_sec,
                                  step=env_steps.result())
                time_acc = 0
                env_steps_before = env_steps.result().numpy()

            # Write training metric summaries
            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=env_steps.result())

            # Evaluation
            if global_step.numpy() % eval_interval == 0:
                # Log evaluation metrics
                if agent_name == 'latent_sac':
                    compute_summaries(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        train_step=global_step,
                        summary_writer=summary_writer,
                        num_episodes=num_eval_episodes,
                        num_episodes_to_render=num_images_per_summary,
                        model_net=model_net,
                        fps=10,
                        image_keys=input_names + mask_names)
                else:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=env_steps.result(),
                        summary_writer=summary_writer,
                        summary_prefix='Eval',
                    )
                    metric_utils.log_metrics(eval_metrics)

            # Save checkpoints
            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
Example no. 27
def train_eval(
        root_dir,
        env_name='cartpole',
        task_name='balance',
        observations_allowlist='position',
        eval_env_name=None,
        num_iterations=1000000,
        # Params for networks.
        actor_fc_layers=(400, 300),
        actor_output_fc_layers=(100, ),
        actor_lstm_size=(40, ),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        critic_output_fc_layers=(100, ),
        critic_lstm_size=(40, ),
        num_parallel_environments=1,
        # Params for collect
        initial_collect_episodes=1,
        collect_episodes_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        critic_learning_rate=3e-4,
        train_sequence_length=20,
        actor_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=0.99,
        reward_scale_factor=0.1,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for RNN SAC on DM control."""
    root_dir = os.path.expanduser(root_dir)

    summary_writer = tf.compat.v2.summary.create_file_writer(
        root_dir, flush_millis=summaries_flush_secs * 1000)
    summary_writer.set_as_default()

    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if observations_allowlist is not None:
            env_wrappers = [
                functools.partial(
                    wrappers.FlattenObservationsWrapper,
                    observations_allowlist=[observations_allowlist])
            ]
        else:
            env_wrappers = []

        env_load_fn = functools.partial(suite_dm_control.load,
                                        task_name=task_name,
                                        env_wrappers=env_wrappers)

        if num_parallel_environments == 1:
            py_env = env_load_fn(env_name)
        else:
            py_env = parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        eval_env_name = eval_env_name or env_name
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(eval_env_name))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
            observation_spec,
            action_spec,
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers,
            continuous_projection_net=tanh_normal_projection_network.
            TanhNormalProjectionNetwork)

        critic_net = critic_rnn_network.CriticRnnNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform')

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        env_steps = tf_metrics.EnvironmentSteps(prefix='Train')
        average_return = tf_metrics.AverageReturnMetric(
            prefix='Train',
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size)
        train_metrics = [
            tf_metrics.NumberOfEpisodes(prefix='Train'),
            env_steps,
            average_return,
            tf_metrics.AverageEpisodeLengthMetric(
                prefix='Train',
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'train'),
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_episodes=initial_collect_episodes)

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if env_steps.result() == 0 or replay_buffer.num_frames() == 0:
            logging.info(
                'Initializing replay buffer by collecting experience for %d episodes '
                'with a random policy.', initial_collect_episodes)
            initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=env_steps.result(),
            summary_writer=summary_writer,
            summary_prefix='Eval',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, env_steps.result())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
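        # The RNN-based collect policy carries recurrent state across driver runs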
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        time_acc = 0
        env_steps_before = env_steps.result().numpy()

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            # Reduce filter_fn over full trajectory sampled. The sequence is kept only
            # if all elements except for the last one pass the filter. This is to
            # allow training on terminal steps.
            return tf.reduce_all(~trajectories.is_boundary()[:-1])

        dataset = replay_buffer.as_dataset(
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5)
        # Dataset generates trajectories with shape [B x (train_sequence_length + 1) x ...]
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            start_env_steps = env_steps.result()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            episode_steps = env_steps.result() - start_env_steps
            # TODO(b/152648849)
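            # For each environment step collected this episode, run train_steps_per_iteration training steps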
            for _ in range(episode_steps):
                for _ in range(train_steps_per_iteration):
                    train_step()
                time_acc += time.time() - start_time

                if global_step.numpy() % log_interval == 0:
                    logging.info('env steps = %d, average return = %f',
                                 env_steps.result(), average_return.result())
                    env_steps_per_sec = (env_steps.result().numpy() -
                                         env_steps_before) / time_acc
                    logging.info('%.3f env steps/sec', env_steps_per_sec)
                    tf.compat.v2.summary.scalar(name='env_steps_per_sec',
                                                data=env_steps_per_sec,
                                                step=env_steps.result())
                    time_acc = 0
                    env_steps_before = env_steps.result().numpy()

                for train_metric in train_metrics:
                    train_metric.tf_summaries(train_step=env_steps.result())

                if global_step.numpy() % eval_interval == 0:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=env_steps.result(),
                        summary_writer=summary_writer,
                        summary_prefix='Eval',
                    )
                    if eval_metrics_callback is not None:
                        eval_metrics_callback(results, env_steps.result())
                    metric_utils.log_metrics(eval_metrics)

                global_step_val = global_step.numpy()
                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)
Example no. 28
    print('Actor Network Created.')

    # create SAC Agent
    # https://www.tensorflow.org/agents/api_docs/python/tf_agents/agents/SacAgent
    global_step = tf.compat.v1.train.get_or_create_global_step()
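    # When resuming from a checkpoint, fetch the existing global step instead of relying on the one just created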
    if shouldContinueFromLastCheckpoint:
        global_step = tf.compat.v1.train.get_global_step()
    # with strategy.scope():
    #     train_step = train_utils.create_train_step()
    tf_agent = sac_agent.SacAgent(
        env.time_step_spec(),
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=actorLearningRate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=criticLearningRate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=alphaLearningRate),
        target_update_tau=target_update_tau,
        gamma=gamma,
        gradient_clipping=gradientClipping,
        train_step_counter=global_step,
    )
    tf_agent.initialize()
    print('SAC Agent Created.')


    # policies
    evaluate_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    collect_policy = tf_agent.collect_policy

    # metrics and evaluation
Example no. 29
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        # Training params
        initial_collect_steps=10000,
        num_iterations=3200000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Agent params
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        gamma=0.99,
        target_update_tau=0.005,
        target_update_period=1,
        reward_scale_factor=0.1,
        # Replay params
        reverb_port=None,
        replay_capacity=1000000,
        # Others
        # Defaults to not checkpointing saved policy. If you wish to enable this,
        # please note the caveat explained in README.md.
        policy_save_interval=-1,
        eval_interval=10000,
        eval_episodes=30,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """Trains and evaluates SAC."""
    logging.info('Training SAC on: %s', env_name)
    collect_env = suite_mujoco.load(env_name)
    eval_env = suite_mujoco.load(env_name)

    observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
        spec_utils.get_tensor_specs(collect_env))

    train_step = train_utils.create_train_step()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_tensor_spec,
        action_tensor_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=tanh_normal_projection_network.
        TanhNormalProjectionNetwork)
    critic_net = critic_network.CriticNetwork(
        (observation_tensor_spec, action_tensor_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        kernel_initializer='glorot_uniform',
        last_kernel_initializer='glorot_uniform')

    agent = sac_agent.SacAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=None,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step)
    agent.initialize()

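    # Replay is backed by a single Reverb table with uniform sampling and FIFO eviction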
    table_name = 'uniform_table'
    table = reverb.Table(table_name,
                         max_size=replay_capacity,
                         sampler=reverb.selectors.Uniform(),
                         remover=reverb.selectors.Fifo(),
                         rate_limiter=reverb.rate_limiters.MinSize(1))

    reverb_server = reverb.Server([table], port=reverb_port)
    reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
        agent.collect_data_spec,
        sequence_length=2,
        table_name=table_name,
        local_server=reverb_server)
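    # The observer writes overlapping length-2 trajectory windows (stride 1) from collection into the table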
    rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
        reverb_replay.py_client,
        table_name,
        sequence_length=2,
        stride_length=1)

    dataset = reverb_replay.as_dataset(sample_batch_size=batch_size,
                                       num_steps=2).prefetch(50)
    experience_dataset_fn = lambda: dataset

    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    env_step_metric = py_metrics.EnvironmentSteps()
    learning_triggers = [
        triggers.PolicySavedModelTrigger(
            saved_model_dir,
            agent,
            train_step,
            interval=policy_save_interval,
            metadata_metrics={triggers.ENV_STEP_METADATA_KEY:
                              env_step_metric}),
        triggers.StepPerSecondLogTrigger(train_step, interval=1000),
    ]

    agent_learner = learner.Learner(root_dir,
                                    train_step,
                                    agent,
                                    experience_dataset_fn,
                                    triggers=learning_triggers)

    random_policy = random_py_policy.RandomPyPolicy(
        collect_env.time_step_spec(), collect_env.action_spec())
    initial_collect_actor = actor.Actor(collect_env,
                                        random_policy,
                                        train_step,
                                        steps_per_run=initial_collect_steps,
                                        observers=[rb_observer])
    logging.info('Doing initial collect.')
    initial_collect_actor.run()

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy,
                                                        use_tf_function=True)

    collect_actor = actor.Actor(collect_env,
                                collect_policy,
                                train_step,
                                steps_per_run=1,
                                metrics=actor.collect_metrics(10),
                                summary_dir=os.path.join(
                                    root_dir, learner.TRAIN_DIR),
                                observers=[rb_observer, env_step_metric])

    tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
    eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_greedy_policy, use_tf_function=True)

    eval_actor = actor.Actor(
        eval_env,
        eval_greedy_policy,
        train_step,
        episodes_per_run=eval_episodes,
        metrics=actor.eval_metrics(eval_episodes),
        summary_dir=os.path.join(root_dir, 'eval'),
    )

    if eval_interval:
        logging.info('Evaluating.')
        eval_actor.run_and_log()

    logging.info('Training.')
    for _ in range(num_iterations):
        collect_actor.run()
        agent_learner.run(iterations=1)

        if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
            logging.info('Evaluating.')
            eval_actor.run_and_log()

    rb_observer.close()
    reverb_server.stop()
Example no. 30
actor_net = actor_distribution_network.ActorDistributionNetwork(
    observation_spec,
    action_spec,
    fc_layer_params=actor_fc_layer_params,
    continuous_projection_net=normal_projection_net)

global_step = tf.compat.v1.train.get_or_create_global_step()
tf_agent = sac_agent.SacAgent(
    train_env.time_step_spec(),
    action_spec,
    actor_network=actor_net,
    critic_network=critic_net,
    actor_optimizer=tf.compat.v1.train.AdamOptimizer(
        learning_rate=actor_learning_rate),
    critic_optimizer=tf.compat.v1.train.AdamOptimizer(
        learning_rate=critic_learning_rate),
    alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
        learning_rate=alpha_learning_rate),
    target_update_tau=target_update_tau,
    target_update_period=target_update_period,
    td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
    gamma=gamma,
    reward_scale_factor=reward_scale_factor,
    gradient_clipping=gradient_clipping,
    train_step_counter=global_step)
tf_agent.initialize()

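# Greedy policy for evaluation; the agent's stochastic collect policy is used for data collection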
eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
collect_policy = tf_agent.collect_policy


def compute_avg_return(environment, policy, num_episodes=5):