Example #1
    def test_eval_job(self):
        # Create test context.
        summary_dir = self.create_tempdir().full_path
        environment = test_envs.CountingEnv(steps_per_episode=4)
        action_tensor_spec = tensor_spec.from_spec(environment.action_spec())
        time_step_tensor_spec = tensor_spec.from_spec(
            environment.time_step_spec())
        policy = py_tf_eager_policy.PyTFEagerPolicy(
            random_tf_policy.RandomTFPolicy(time_step_tensor_spec,
                                            action_tensor_spec))
        mock_variable_container = mock.create_autospec(
            reverb_variable_container.ReverbVariableContainer)

        with mock.patch.object(
                tf.summary, 'scalar',
                autospec=True) as mock_scalar_summary, mock.patch.object(
                    train_utils, 'wait_for_predicate', autospec=True):
            # Run the function tested.
            eval_job.evaluate(summary_dir=summary_dir,
                              policy=policy,
                              environment_name=None,
                              suite_load_fn=lambda _: environment,
                              variable_container=mock_variable_container,
                              is_running=_NTimesReturnTrue(n=2))

            # Check if the expected calls happened.
            # As an input, an eval job is expected to fetch data from the variable
            # container.
            mock_variable_container.assert_has_calls(
                [mock.call.update(mock.ANY)])

            # As an output, an eval job is expected to write at least the average
            # return corresponding to the first step.
            mock_scalar_summary.assert_any_call(
                name='eval_actor/AverageReturn', data=mock.ANY, step=mock.ANY)
    def testPyEnvCompatible(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in eager.')

        observation_spec = array_spec.ArraySpec([2], np.float32)
        action_spec = array_spec.BoundedArraySpec([1], np.float32, 2, 3)

        observation_tensor_spec = tensor_spec.from_spec(observation_spec)
        action_tensor_spec = tensor_spec.from_spec(action_spec)
        time_step_tensor_spec = ts.time_step_spec(observation_tensor_spec)

        actor_net = actor_network.ActorNetwork(
            observation_tensor_spec,
            action_tensor_spec,
            fc_layer_params=(10, ),
        )

        tf_policy = actor_policy.ActorPolicy(time_step_tensor_spec,
                                             action_tensor_spec,
                                             actor_network=actor_net)

        py_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_policy)
        # Env will validate action types automatically since we provided the
        # action_spec.
        env = random_py_environment.RandomPyEnvironment(
            observation_spec, action_spec)

        time_step = env.reset()

        for _ in range(100):
            action_step = py_policy.action(time_step)
            time_step = env.step(action_step.action)
    def testBatchedPyEnvCompatible(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in eager.')

        actor_net = actor_network.ActorNetwork(
            self._observation_tensor_spec,
            self._action_tensor_spec,
            fc_layer_params=(10, ),
        )

        tf_policy = actor_policy.ActorPolicy(self._time_step_tensor_spec,
                                             self._action_tensor_spec,
                                             actor_network=actor_net)

        py_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_policy,
                                                       batch_time_steps=False)

        env_ctr = lambda: random_py_environment.RandomPyEnvironment(  # pylint: disable=g-long-lambda
            self._observation_spec, self._action_spec)

        env = batched_py_environment.BatchedPyEnvironment(
            [env_ctr() for _ in range(3)])
        time_step = env.reset()

        for _ in range(20):
            action_step = py_policy.action(time_step)
            time_step = env.step(action_step.action)
    def testRandomTFPolicyCompatibility(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in eager.')

        tf_policy = random_tf_policy.RandomTFPolicy(
            self._time_step_tensor_spec,
            self._action_tensor_spec,
            info_spec=self._info_tensor_spec)

        py_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_policy)
        time_step = self._env.reset()

        def _check_action_step(action_step):
            self.assertIsInstance(action_step.action, np.ndarray)
            self.assertEqual(action_step.action.shape, (1, ))
            self.assertBetween(action_step.action[0], 2.0, 3.0)

            self.assertIsInstance(action_step.info['a'], np.ndarray)
            self.assertEqual(action_step.info['a'].shape, (1, ))
            self.assertBetween(action_step.info['a'][0], 0.0, 1.0)

            self.assertIsInstance(action_step.info['b'], np.ndarray)
            self.assertEqual(action_step.info['b'].shape, (1, ))
            self.assertBetween(action_step.info['b'][0], 100.0, 101.0)

        for _ in range(100):
            action_step = py_policy.action(time_step)
            _check_action_step(action_step)
            time_step = self._env.step(action_step.action)
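The eval-job test above stops its loop via `is_running=_NTimesReturnTrue(n=2)`. That helper is defined elsewhere in the test module; a minimal sketch of such a predicate, assuming it only needs to report "running" for its first `n` calls (a hypothetical reconstruction, not the actual TF-Agents source), is:

class _NTimesReturnTrue(object):
    """Hypothetical sketch: returns True for the first `n` calls, then False."""

    def __init__(self, n=1):
        self._n = n
        self._call_count = 0

    def __call__(self, *args, **kwargs):
        # Count invocations and report "still running" only for the first n.
        self._call_count += 1
        return self._call_count <= self._n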
Example #5
def collect_episode(environment, policy, num_episodes):

    driver = py_driver.PyDriver(environment,
                                py_tf_eager_policy.PyTFEagerPolicy(
                                    policy, use_tf_function=True),
                                [rb_observer],
                                max_episodes=num_episodes)
    initial_time_step = environment.reset()
    driver.run(initial_time_step)
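This snippet relies on a module-level `rb_observer` created elsewhere in the surrounding tutorial code. A hedged call-site sketch, where `train_py_env`, `agent`, and `rb_observer` are assumed to already exist:

# Hypothetical usage; `train_py_env`, `agent` and the module-level `rb_observer`
# are assumed to be created earlier in the surrounding code.
collect_episode(train_py_env, agent.collect_policy, num_episodes=2)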
Example #6
def collect_episode(environment, policy, num_episodes, replay_buffer_observer):
    """Collect game episode trajectories."""

    driver = py_driver.PyDriver(environment,
                                py_tf_eager_policy.PyTFEagerPolicy(
                                    policy, use_tf_function=True),
                                [replay_buffer_observer],
                                max_episodes=num_episodes)
    initial_time_step = environment.reset()
    driver.run(initial_time_step)
Example #7
def build_actor(root_dir, env, agent, rb_observer, train_step):
    """Builds the Actor."""
    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy,
                                                        use_tf_function=True)
    temp_dir = root_dir + 'actor'
    test_actor = actor.Actor(env,
                             collect_policy,
                             train_step,
                             steps_per_run=1,
                             metrics=actor.collect_metrics(10),
                             summary_dir=temp_dir,
                             observers=[rb_observer])

    return test_actor
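A hedged usage sketch for `build_actor`, assuming `root_dir`, `env`, `agent`, `rb_observer`, and `train_step` were constructed earlier (for example by a helper such as the `_build_components` used in the tests above):

# Hypothetical call site; every argument is assumed to exist already.
collect_actor = build_actor(root_dir, env, agent, rb_observer, train_step)
for _ in range(10):
    collect_actor.run()  # each run() collects `steps_per_run` environment steps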
Example #8
  def test_eval_job_constant_eval(self):
    """Tests eval every step for 2 steps.

    This test's `variable_container` passes the same train step twice to test
    that `is_train_step_the_same_or_behind` is working as expected. If it were
    not working, the number of train steps processed would be incorrect (2x
    higher).
    """
    summary_dir = self.create_tempdir().full_path
    environment = test_envs.CountingEnv(steps_per_episode=4)
    action_tensor_spec = tensor_spec.from_spec(environment.action_spec())
    time_step_tensor_spec = tensor_spec.from_spec(environment.time_step_spec())
    policy = py_tf_eager_policy.PyTFEagerPolicy(
        random_tf_policy.RandomTFPolicy(time_step_tensor_spec,
                                        action_tensor_spec))
    mock_variable_container = mock.create_autospec(
        reverb_variable_container.ReverbVariableContainer)

    class VCUpdateIncrementEveryOtherTrainStep(object):
      """Side effect that updates train_step on every other call."""

      def __init__(self):
        self.fake_train_step = -1
        self.call_count = 0

      def __call__(self, variables):
        if self.call_count % 2:
          self.fake_train_step += 1
          variables[reverb_variable_container.TRAIN_STEP_KEY].assign(
              self.fake_train_step)
        self.call_count += 1

    fake_update = VCUpdateIncrementEveryOtherTrainStep()
    mock_variable_container.update.side_effect = fake_update

    with mock.patch.object(
        tf.summary, 'scalar', autospec=True) as mock_scalar_summary:
      eval_job.evaluate(
          summary_dir=summary_dir,
          policy=policy,
          environment_name=None,
          suite_load_fn=lambda _: environment,
          variable_container=mock_variable_container,
          eval_interval=1,
          is_running=_NTimesReturnTrue(n=2))

      summary_count = self.count_summary_scalar_tags_in_call_list(
          mock_scalar_summary, 'Metrics/eval_actor/AverageReturn')
      self.assertEqual(summary_count, 2)
Example #9
  def test_eval_job(self):
    """Tests the eval job doing an eval every 5 steps for 10 train steps."""
    summary_dir = self.create_tempdir().full_path
    environment = test_envs.CountingEnv(steps_per_episode=4)
    action_tensor_spec = tensor_spec.from_spec(environment.action_spec())
    time_step_tensor_spec = tensor_spec.from_spec(environment.time_step_spec())
    policy = py_tf_eager_policy.PyTFEagerPolicy(
        random_tf_policy.RandomTFPolicy(time_step_tensor_spec,
                                        action_tensor_spec))

    class VCUpdateIncrementTrainStep(object):
      """Side effect that updates train_step."""

      def __init__(self):
        self.fake_train_step = -1

      def __call__(self, variables):
        self.fake_train_step += 1
        variables[reverb_variable_container.TRAIN_STEP_KEY].assign(
            self.fake_train_step)

    mock_variable_container = mock.create_autospec(
        reverb_variable_container.ReverbVariableContainer)
    fake_update = VCUpdateIncrementTrainStep()
    mock_variable_container.update.side_effect = fake_update

    with mock.patch.object(
        tf.summary, 'scalar', autospec=True) as mock_scalar_summary:
      # Run the function tested.
      # 11 loops to do 10 steps because the eval occurs on the loop after the
      # train_step is found.
      eval_job.evaluate(
          summary_dir=summary_dir,
          policy=policy,
          environment_name=None,
          suite_load_fn=lambda _: environment,
          variable_container=mock_variable_container,
          eval_interval=5,
          is_running=_NTimesReturnTrue(n=11))

      summary_count = self.count_summary_scalar_tags_in_call_list(
          mock_scalar_summary, 'Metrics/eval_actor/AverageReturn')
      self.assertEqual(summary_count, 3)
Example #10
    def testActorRun(self):
        rb_port = portpicker.pick_unused_port(portserver_address='localhost')

        env, agent, train_step, replay_buffer, rb_observer = (
            self._build_components(rb_port))

        tf_collect_policy = agent.collect_policy
        collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
            tf_collect_policy, use_tf_function=True)
        test_actor = actor.Actor(env,
                                 collect_policy,
                                 train_step,
                                 steps_per_run=1,
                                 observers=[rb_observer])

        self.assertEqual(replay_buffer.num_frames(), 0)
        for _ in range(10):
            test_actor.run()
        self.assertGreater(replay_buffer.num_frames(), 0)
    def testRandomTFPolicyCompatibility(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in eager.')

        observation_spec = array_spec.ArraySpec([2], np.float32)
        action_spec = array_spec.BoundedArraySpec([1], np.float32, 2, 3)
        info_spec = {
            'a': array_spec.BoundedArraySpec([1], np.float32, 0, 1),
            'b': array_spec.BoundedArraySpec([1], np.float32, 100, 101)
        }

        observation_tensor_spec = tensor_spec.from_spec(observation_spec)
        action_tensor_spec = tensor_spec.from_spec(action_spec)
        info_tensor_spec = tensor_spec.from_spec(info_spec)
        time_step_tensor_spec = ts.time_step_spec(observation_tensor_spec)

        tf_policy = random_tf_policy.RandomTFPolicy(time_step_tensor_spec,
                                                    action_tensor_spec,
                                                    info_spec=info_tensor_spec)

        py_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_policy)
        env = random_py_environment.RandomPyEnvironment(
            observation_spec, action_spec)
        time_step = env.reset()

        def _check_action_step(action_step):
            self.assertIsInstance(action_step.action, np.ndarray)
            self.assertEqual(action_step.action.shape, (1, ))
            self.assertBetween(action_step.action[0], 2.0, 3.0)

            self.assertIsInstance(action_step.info['a'], np.ndarray)
            self.assertEqual(action_step.info['a'].shape, (1, ))
            self.assertBetween(action_step.info['a'][0], 0.0, 1.0)

            self.assertIsInstance(action_step.info['b'], np.ndarray)
            self.assertEqual(action_step.info['b'].shape, (1, ))
            self.assertBetween(action_step.info['b'][0], 100.0, 101.0)

        for _ in range(100):
            action_step = py_policy.action(time_step)
            _check_action_step(action_step)
            time_step = env.step(action_step.action)
    def testPyEnvCompatible(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in eager.')

        actor_net = actor_network.ActorNetwork(
            self._observation_tensor_spec,
            self._action_tensor_spec,
            fc_layer_params=(10, ),
        )

        tf_policy = actor_policy.ActorPolicy(self._time_step_tensor_spec,
                                             self._action_tensor_spec,
                                             actor_network=actor_net)

        py_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_policy)
        time_step = self._env.reset()

        for _ in range(100):
            action_step = py_policy.action(time_step)
            time_step = self._env.step(action_step.action)
    def testActionWithSeed(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in eager.')

        tf_policy = random_tf_policy.RandomTFPolicy(
            self._time_step_tensor_spec,
            self._action_tensor_spec,
            info_spec=self._info_tensor_spec)

        py_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_policy)
        time_step = self._env.reset()
        tf.random.set_seed(100)
        action_step_1 = py_policy.action(time_step, seed=100)
        time_step = self._env.reset()
        tf.random.set_seed(100)
        action_step_2 = py_policy.action(time_step, seed=100)
        time_step = self._env.reset()
        tf.random.set_seed(200)
        action_step_3 = py_policy.action(time_step, seed=200)
        self.assertEqual(action_step_1.action[0], action_step_2.action[0])
        self.assertNotEqual(action_step_1.action[0], action_step_3.action[0])
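Several of the tests above rely on specs and an environment created in the test fixture's `setUp` (e.g. `self._observation_tensor_spec`, `self._info_tensor_spec`, `self._env`). A hypothetical sketch of such a fixture, inferred from the standalone tests earlier in this listing (not the actual TF-Agents source):

    # Hypothetical reconstruction of the fixture assumed by the tests above.
    def setUp(self):
        super(PyTFEagerPolicyTest, self).setUp()  # the class name is an assumption
        self._observation_spec = array_spec.ArraySpec([2], np.float32)
        self._action_spec = array_spec.BoundedArraySpec([1], np.float32, 2, 3)
        self._info_spec = {
            'a': array_spec.BoundedArraySpec([1], np.float32, 0, 1),
            'b': array_spec.BoundedArraySpec([1], np.float32, 100, 101)
        }
        self._observation_tensor_spec = tensor_spec.from_spec(
            self._observation_spec)
        self._action_tensor_spec = tensor_spec.from_spec(self._action_spec)
        self._info_tensor_spec = tensor_spec.from_spec(self._info_spec)
        self._time_step_tensor_spec = ts.time_step_spec(
            self._observation_tensor_spec)
        self._env = random_py_environment.RandomPyEnvironment(
            self._observation_spec, self._action_spec)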
Example #14
    def testCollectLocalPyActorRun(self):
        rb_port = portpicker.pick_unused_port(portserver_address='localhost')

        env, agent, train_step, replay_buffer, rb_observer = (
            self._build_components(rb_port))

        temp_dir = self.create_tempdir().full_path
        tf_collect_policy = agent.collect_policy
        collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
            tf_collect_policy, use_tf_function=True)
        test_actor = actor.Actor(env,
                                 collect_policy,
                                 train_step,
                                 steps_per_run=1,
                                 metrics=actor.collect_metrics(buffer_size=1),
                                 summary_dir=temp_dir,
                                 observers=[rb_observer])

        self.assertEqual(replay_buffer.num_frames(), 0)
        for _ in range(10):
            test_actor.run()
        self.assertGreater(replay_buffer.num_frames(), 0)
Example #15
  def testRandomTFPolicyCompatibility(self):
    if not common.has_eager_been_enabled():
      self.skipTest('Only supported in eager.')

    observation_spec = array_spec.ArraySpec([2], np.float32)
    action_spec = array_spec.BoundedArraySpec([1], np.float32, 2, 3)

    observation_tensor_spec = tensor_spec.from_spec(observation_spec)
    action_tensor_spec = tensor_spec.from_spec(action_spec)
    time_step_tensor_spec = ts.time_step_spec(observation_tensor_spec)

    tf_policy = random_tf_policy.RandomTFPolicy(time_step_tensor_spec,
                                                action_tensor_spec)

    py_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_policy)
    env = random_py_environment.RandomPyEnvironment(observation_spec,
                                                    action_spec)
    time_step = env.reset()

    for _ in range(100):
      action_step = py_policy.action(time_step)
      time_step = env.step(action_step.action)
Example #16
    def testReferenceMetricsNotInObservers(self):
        rb_port = portpicker.pick_unused_port(portserver_address='localhost')

        env, agent, train_step, _, rb_observer = (
            self._build_components(rb_port))

        temp_dir = self.create_tempdir().full_path
        tf_collect_policy = agent.collect_policy
        collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
            tf_collect_policy, use_tf_function=True)
        metrics = actor.collect_metrics(buffer_size=1)
        step_metric = py_metrics.EnvironmentSteps()
        test_actor = actor.Actor(env,
                                 collect_policy,
                                 train_step,
                                 steps_per_run=1,
                                 metrics=metrics,
                                 reference_metrics=[step_metric],
                                 summary_dir=temp_dir,
                                 observers=[rb_observer])

        self.assertNotIn(step_metric, test_actor._observers)
Example #17
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    # Training params
    num_iterations=1600,
    actor_fc_layers=(64, 64),
    value_fc_layers=(64, 64),
    learning_rate=3e-4,
    collect_sequence_length=2048,
    minibatch_size=64,
    num_epochs=10,
    # Agent params
    importance_ratio_clipping=0.2,
    lambda_value=0.95,
    discount_factor=0.99,
    entropy_regularization=0.,
    value_pred_loss_coef=0.5,
    use_gae=True,
    use_td_lambda_return=True,
    gradient_clipping=0.5,
    value_clipping=None,
    # Replay params
    reverb_port=None,
    replay_capacity=10000,
    # Others
    policy_save_interval=5000,
    summary_interval=1000,
    eval_interval=10000,
    eval_episodes=100,
    debug_summaries=False,
    summarize_grads_and_vars=False):
  """Trains and evaluates PPO (Importance Ratio Clipping).

  Args:
    root_dir: Main directory path where checkpoints, saved_models, and summaries
      will be written to.
    env_name: Name for the Mujoco environment to load.
    num_iterations: The number of iterations to perform collection and training.
    actor_fc_layers: List of fully_connected parameters for the actor network,
      where each item is the number of units in the layer.
    value_fc_layers: List of fully_connected parameters for the value network,
      where each item is the number of units in the layer.
    learning_rate: Learning rate used on the Adam optimizer.
    collect_sequence_length: Number of steps to take in each collect run.
    minibatch_size: Number of elements in each mini batch. If `None`, the entire
      collected sequence will be treated as one batch.
    num_epochs: Number of iterations to repeat over all collected data per data
      collection step. (Schulman, 2017) sets this to 10 for Mujoco, 15 for
      Roboschool and 3 for Atari.
    importance_ratio_clipping: Epsilon in clipped, surrogate PPO objective. For
      more detail, see explanation at the top of the doc.
    lambda_value: Lambda parameter for TD-lambda computation.
    discount_factor: Discount factor for return computation. Defaults to
      `0.99`, which is the value used for all environments in (Schulman,
      2017).
    entropy_regularization: Coefficient for entropy regularization loss term.
      Defaults to `0.0` because no entropy bonus was used in (Schulman, 2017).
    value_pred_loss_coef: Multiplier for value prediction loss to balance with
      policy gradient loss. Defaults to `0.5`, which was used for all
      environments in the OpenAI baselines implementation. This parameter is
      irrelevant unless you are sharing part of actor_net and value_net; in
      that case, you would want to tune this coefficient, whose value depends
      on the network architecture of your choice.
    use_gae: If True (default False), uses generalized advantage estimation for
      computing per-timestep advantage. Else, just subtracts value predictions
      from empirical return.
    use_td_lambda_return: If True (default False), uses td_lambda_return for
      training the value function, where `td_lambda_return = gae_advantage +
      value_predictions`. `use_gae` must also be set to `True` to enable
      TD-lambda returns. If `use_td_lambda_return` is set to True while
      `use_gae` is False, the empirical return will be used and a warning will
      be logged.
    gradient_clipping: Norm length to clip gradients.
    value_clipping: The difference between new and old value predictions is
      clipped to this threshold. Value clipping can be helpful when training
      very deep networks. Default: no clipping.
    reverb_port: Port for the Reverb server. If `None`, a randomly chosen
      unused port is used.
    replay_capacity: The maximum number of elements for the replay buffer. Items
      will be wasted if this is smaller than collect_sequence_length.
    policy_save_interval: How often, in train_steps, the policy will be saved.
    summary_interval: How often to write data into Tensorboard.
    eval_interval: How often to run evaluation, in train_steps.
    eval_episodes: Number of episodes to evaluate over.
    debug_summaries: Boolean for whether to gather debug summaries.
    summarize_grads_and_vars: If true, gradient summaries will be written.
  """
  collect_env = suite_mujoco.load(env_name)
  eval_env = suite_mujoco.load(env_name)
  num_environments = 1

  observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))
  # TODO(b/172267869): Remove this conversion once TensorNormalizer stops
  # converting float64 inputs to float32.
  observation_tensor_spec = tf.TensorSpec(
      dtype=tf.float32, shape=observation_tensor_spec.shape)

  train_step = train_utils.create_train_step()
  actor_net_builder = ppo_actor_network.PPOActorNetwork()
  actor_net = actor_net_builder.create_sequential_actor_net(
      actor_fc_layers, action_tensor_spec)
  value_net = value_network.ValueNetwork(
      observation_tensor_spec,
      fc_layer_params=value_fc_layers,
      kernel_initializer=tf.keras.initializers.Orthogonal())

  current_iteration = tf.Variable(0, dtype=tf.int64)
  def learning_rate_fn():
    # Linearly decay the learning rate.
    return learning_rate * (1 - current_iteration / num_iterations)

  agent = ppo_clip_agent.PPOClipAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      optimizer=tf.keras.optimizers.Adam(
          learning_rate=learning_rate_fn, epsilon=1e-5),
      actor_net=actor_net,
      value_net=value_net,
      importance_ratio_clipping=importance_ratio_clipping,
      lambda_value=lambda_value,
      discount_factor=discount_factor,
      entropy_regularization=entropy_regularization,
      value_pred_loss_coef=value_pred_loss_coef,
      # This is a legacy argument for the number of times we repeat the data
      # inside of the train function, incompatible with mini batch learning.
      # We set the epoch number from the replay buffer and tf.Data instead.
      num_epochs=1,
      use_gae=use_gae,
      use_td_lambda_return=use_td_lambda_return,
      gradient_clipping=gradient_clipping,
      value_clipping=value_clipping,
      # TODO(b/150244758): Default compute_value_and_advantage_in_train to False
      # after Reverb open source.
      compute_value_and_advantage_in_train=False,
      # Skips updating normalizers in the agent, as it's handled in the learner.
      update_normalizers_in_train=False,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step)
  agent.initialize()

  reverb_server = reverb.Server(
      [
          reverb.Table(  # Replay buffer storing experience for training.
              name='training_table',
              sampler=reverb.selectors.Fifo(),
              remover=reverb.selectors.Fifo(),
              rate_limiter=reverb.rate_limiters.MinSize(1),
              max_size=replay_capacity,
              max_times_sampled=1,
          ),
          reverb.Table(  # Replay buffer storing experience for normalization.
              name='normalization_table',
              sampler=reverb.selectors.Fifo(),
              remover=reverb.selectors.Fifo(),
              rate_limiter=reverb.rate_limiters.MinSize(1),
              max_size=replay_capacity,
              max_times_sampled=1,
          )
      ],
      port=reverb_port)

  # Create the replay buffer.
  reverb_replay_train = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=collect_sequence_length,
      table_name='training_table',
      server_address='localhost:{}'.format(reverb_server.port),
      # The only collected sequence is used to populate the batches.
      max_cycle_length=1,
      rate_limiter_timeout_ms=1000)
  reverb_replay_normalization = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=collect_sequence_length,
      table_name='normalization_table',
      server_address='localhost:{}'.format(reverb_server.port),
      # The only collected sequence is used to populate the batches.
      max_cycle_length=1,
      rate_limiter_timeout_ms=1000)

  rb_observer = reverb_utils.ReverbTrajectorySequenceObserver(
      reverb_replay_train.py_client, ['training_table', 'normalization_table'],
      sequence_length=collect_sequence_length,
      stride_length=collect_sequence_length)

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  collect_env_step_metric = py_metrics.EnvironmentSteps()
  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={
              triggers.ENV_STEP_METADATA_KEY: collect_env_step_metric
          }),
      triggers.StepPerSecondLogTrigger(train_step, interval=summary_interval),
  ]

  def training_dataset_fn():
    return reverb_replay_train.as_dataset(
        sample_batch_size=num_environments,
        sequence_preprocess_fn=agent.preprocess_sequence)

  def normalization_dataset_fn():
    return reverb_replay_normalization.as_dataset(
        sample_batch_size=num_environments,
        sequence_preprocess_fn=agent.preprocess_sequence)

  agent_learner = ppo_learner.PPOLearner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn=training_dataset_fn,
      normalization_dataset_fn=normalization_dataset_fn,
      num_samples=1,
      num_epochs=num_epochs,
      minibatch_size=minibatch_size,
      shuffle_buffer_size=collect_sequence_length,
      triggers=learning_triggers)

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)

  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=collect_sequence_length,
      observers=[rb_observer],
      metrics=actor.collect_metrics(buffer_size=10) + [collect_env_step_metric],
      reference_metrics=[collect_env_step_metric],
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
      summary_interval=summary_interval)

  eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      agent.policy, use_tf_function=True)

  if eval_interval:
    logging.info('Initial evaluation.')
    eval_actor = actor.Actor(
        eval_env,
        eval_greedy_policy,
        train_step,
        metrics=actor.eval_metrics(eval_episodes),
        reference_metrics=[collect_env_step_metric],
        summary_dir=os.path.join(root_dir, 'eval'),
        episodes_per_run=eval_episodes)

    eval_actor.run_and_log()

  logging.info('Training on %s', env_name)
  last_eval_step = 0
  for i in range(num_iterations):
    collect_actor.run()
    rb_observer.flush()
    agent_learner.run()
    reverb_replay_train.clear()
    reverb_replay_normalization.clear()
    current_iteration.assign_add(1)

    # Eval only if `eval_interval` has been set. Then, eval if the current
    # train step is greater than or equal to `last_eval_step + eval_interval`,
    # or if this is the last iteration. This logic exists because
    # agent_learner.run() does not return after every train step.
    if (eval_interval and
        (agent_learner.train_step_numpy >= eval_interval + last_eval_step
         or i == num_iterations - 1)):
      logging.info('Evaluating.')
      eval_actor.run_and_log()
      last_eval_step = agent_learner.train_step_numpy

  rb_observer.close()
  reverb_server.stop()
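A minimal entry point for invoking this `train_eval`, assuming it lives in an absl-based script; the flag name and default shown below are assumptions, not part of the example above:

# Hypothetical entry point; flag names and defaults are assumptions.
from absl import app
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_string('root_dir', '/tmp/ppo_halfcheetah', 'Output directory.')


def main(_):
  train_eval(FLAGS.root_dir, env_name='HalfCheetah-v2', num_iterations=1600)


if __name__ == '__main__':
  app.run(main)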
Example #18
def train_eval(
        root_dir,
        offline_dir=None,
        random_seed=None,
        env_name='sawyer_push',
        eval_env_name=None,
        env_load_fn=get_env,
        max_episode_steps=1000,
        eval_episode_steps=1000,
        # The SAC paper reported:
        # Hopper and Cartpole results up to 1000000 iters,
        # Humanoid results up to 10000000 iters,
        # Other mujoco tasks up to 3000000 iters.
        num_iterations=3000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Params for collect
        # Follow https://github.com/haarnoja/sac/blob/master/examples/variants.py
        # HalfCheetah and Ant take 10000 initial collection steps.
        # Other mujoco tasks take 1000.
        # Different choices roughly keep the initial episodes about the same.
        initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        reset_goal_frequency=1000,  # virtual episode size for reset-free training
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        # reset-free parameters
        use_minimum=True,
        reset_lagrange_learning_rate=3e-4,
        value_threshold=None,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=0.99,
        reward_scale_factor=0.1,
        # Td3 parameters
        actor_update_period=1,
        exploration_noise_std=0.1,
        target_policy_noise=0.1,
        target_policy_noise_clip=0.1,
        dqda_clipping=None,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        # video recording for the environment
        video_record_interval=10000,
        num_videos=0,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):

    start_time = time.time()

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    video_dir = os.path.join(eval_dir, 'videos')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if random_seed is not None:
            tf.compat.v1.set_random_seed(random_seed)

        if FLAGS.use_reset_goals in [-1]:
            gym_env_wrappers = (functools.partial(
                reset_free_wrapper.GoalTerminalResetWrapper,
                num_success_states=FLAGS.num_success_states,
                full_reset_frequency=max_episode_steps), )
        elif FLAGS.use_reset_goals in [0, 1]:
            gym_env_wrappers = (functools.partial(
                reset_free_wrapper.ResetFreeWrapper,
                reset_goal_frequency=reset_goal_frequency,
                variable_horizon_for_reset=FLAGS.variable_reset_horizon,
                num_success_states=FLAGS.num_success_states,
                full_reset_frequency=max_episode_steps), )
        elif FLAGS.use_reset_goals in [2]:
            gym_env_wrappers = (functools.partial(
                reset_free_wrapper.CustomOracleResetWrapper,
                partial_reset_frequency=reset_goal_frequency,
                episodes_before_full_reset=max_episode_steps //
                reset_goal_frequency), )
        elif FLAGS.use_reset_goals in [3, 4]:
            gym_env_wrappers = (functools.partial(
                reset_free_wrapper.GoalTerminalResetFreeWrapper,
                reset_goal_frequency=reset_goal_frequency,
                num_success_states=FLAGS.num_success_states,
                full_reset_frequency=max_episode_steps), )
        elif FLAGS.use_reset_goals in [5, 7]:
            gym_env_wrappers = (functools.partial(
                reset_free_wrapper.CustomOracleResetGoalTerminalWrapper,
                partial_reset_frequency=reset_goal_frequency,
                episodes_before_full_reset=max_episode_steps //
                reset_goal_frequency), )
        elif FLAGS.use_reset_goals in [6]:
            gym_env_wrappers = (functools.partial(
                reset_free_wrapper.VariableGoalTerminalResetWrapper,
                full_reset_frequency=max_episode_steps), )

        if env_name == 'playpen_reduced':
            train_env_load_fn = functools.partial(
                env_load_fn, reset_at_goal=FLAGS.reset_at_goal)
        else:
            train_env_load_fn = env_load_fn

        env, env_train_metrics, env_eval_metrics, aux_info = train_env_load_fn(
            name=env_name,
            max_episode_steps=None,
            gym_env_wrappers=gym_env_wrappers)

        tf_env = tf_py_environment.TFPyEnvironment(env)
        eval_env_name = eval_env_name or env_name
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(name=eval_env_name,
                        max_episode_steps=eval_episode_steps)[0])

        eval_metrics += env_eval_metrics

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        if FLAGS.agent_type == 'sac':
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                observation_spec,
                action_spec,
                fc_layer_params=actor_fc_layers,
                continuous_projection_net=functools.partial(
                    tanh_normal_projection_network.TanhNormalProjectionNetwork,
                    std_transform=std_clip_transform))
            critic_net = critic_network.CriticNetwork(
                (observation_spec, action_spec),
                observation_fc_layer_params=critic_obs_fc_layers,
                action_fc_layer_params=critic_action_fc_layers,
                joint_fc_layer_params=critic_joint_fc_layers,
                kernel_initializer='glorot_uniform',
                last_kernel_initializer='glorot_uniform',
            )

            critic_net_no_entropy = None
            critic_no_entropy_optimizer = None
            if FLAGS.use_no_entropy_q:
                critic_net_no_entropy = critic_network.CriticNetwork(
                    (observation_spec, action_spec),
                    observation_fc_layer_params=critic_obs_fc_layers,
                    action_fc_layer_params=critic_action_fc_layers,
                    joint_fc_layer_params=critic_joint_fc_layers,
                    kernel_initializer='glorot_uniform',
                    last_kernel_initializer='glorot_uniform',
                    name='CriticNetworkNoEntropy1')
                critic_no_entropy_optimizer = tf.compat.v1.train.AdamOptimizer(
                    learning_rate=critic_learning_rate)

            tf_agent = SacAgent(
                time_step_spec,
                action_spec,
                num_action_samples=FLAGS.num_action_samples,
                actor_network=actor_net,
                critic_network=critic_net,
                critic_network_no_entropy=critic_net_no_entropy,
                actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=actor_learning_rate),
                critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=critic_learning_rate),
                critic_no_entropy_optimizer=critic_no_entropy_optimizer,
                alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=alpha_learning_rate),
                target_update_tau=target_update_tau,
                target_update_period=target_update_period,
                td_errors_loss_fn=td_errors_loss_fn,
                gamma=gamma,
                reward_scale_factor=reward_scale_factor,
                gradient_clipping=gradient_clipping,
                debug_summaries=debug_summaries,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step)

        elif FLAGS.agent_type == 'td3':
            actor_net = actor_network.ActorNetwork(
                tf_env.time_step_spec().observation,
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers,
            )
            critic_net = critic_network.CriticNetwork(
                (observation_spec, action_spec),
                observation_fc_layer_params=critic_obs_fc_layers,
                action_fc_layer_params=critic_action_fc_layers,
                joint_fc_layer_params=critic_joint_fc_layers,
                kernel_initializer='glorot_uniform',
                last_kernel_initializer='glorot_uniform')

            tf_agent = Td3Agent(
                tf_env.time_step_spec(),
                tf_env.action_spec(),
                actor_network=actor_net,
                critic_network=critic_net,
                actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=actor_learning_rate),
                critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=critic_learning_rate),
                exploration_noise_std=exploration_noise_std,
                target_update_tau=target_update_tau,
                target_update_period=target_update_period,
                actor_update_period=actor_update_period,
                dqda_clipping=dqda_clipping,
                td_errors_loss_fn=td_errors_loss_fn,
                gamma=gamma,
                reward_scale_factor=reward_scale_factor,
                target_policy_noise=target_policy_noise,
                target_policy_noise_clip=target_policy_noise_clip,
                gradient_clipping=gradient_clipping,
                debug_summaries=debug_summaries,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step,
            )

        tf_agent.initialize()

        if FLAGS.use_reset_goals > 0:
            if FLAGS.use_reset_goals in [4, 5, 6]:
                reset_goal_generator = ScheduledResetGoal(
                    goal_dim=aux_info['reset_state_shape'][0],
                    num_success_for_switch=FLAGS.num_success_for_switch,
                    num_chunks=FLAGS.num_chunks,
                    name='ScheduledResetGoalGenerator')
            else:
                # distance to initial state distribution
                initial_state_distance = state_distribution_distance.L2Distance(
                    initial_state_shape=aux_info['reset_state_shape'])
                initial_state_distance.update(tf.constant(
                    aux_info['reset_states'], dtype=tf.float32),
                                              update_type='complete')

                if use_tf_functions:
                    initial_state_distance.distance = common.function(
                        initial_state_distance.distance)
                    tf_agent.compute_value = common.function(
                        tf_agent.compute_value)

                # initialize reset / practice goal proposer
                if reset_lagrange_learning_rate > 0:
                    reset_goal_generator = ResetGoalGenerator(
                        goal_dim=aux_info['reset_state_shape'][0],
                        compute_value_fn=tf_agent.compute_value,
                        distance_fn=initial_state_distance,
                        use_minimum=use_minimum,
                        value_threshold=value_threshold,
                        lagrange_variable_max=FLAGS.lagrange_max,
                        optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=reset_lagrange_learning_rate),
                        name='reset_goal_generator')
                else:
                    reset_goal_generator = FixedResetGoal(
                        distance_fn=initial_state_distance)

            # if use_tf_functions:
            #   reset_goal_generator.get_reset_goal = common.function(
            #       reset_goal_generator.get_reset_goal)

            # modify the reset-free wrapper to use the reset goal generator
            tf_env.pyenv.envs[0].set_reset_goal_fn(
                reset_goal_generator.get_reset_goal)

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        if FLAGS.relabel_goals:
            cur_episode_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
                data_spec=tf_agent.collect_data_spec,
                batch_size=1,
                scope='CurEpisodeReplayBuffer',
                max_length=int(2 *
                               min(reset_goal_frequency, max_episode_steps)))

            # NOTE: the observer is replaced because we cannot have two
            # buffers' add_batch calls as observers.
            replay_observer = [cur_episode_buffer.add_batch]

        # initialize metrics and observers
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ]

        train_metrics += env_train_metrics

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        eval_py_policy = py_tf_eager_policy.PyTFEagerPolicy(
            tf_agent.policy, use_tf_function=True)

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)
        if use_tf_functions:
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        if offline_dir is not None:
            offline_data = tf_uniform_replay_buffer.TFUniformReplayBuffer(
                data_spec=tf_agent.collect_data_spec,
                batch_size=1,
                max_length=int(1e5))  # this has to be 100_000
            offline_checkpointer = common.Checkpointer(
                ckpt_dir=offline_dir,
                max_to_keep=1,
                replay_buffer=offline_data)
            offline_checkpointer.initialize_or_restore()

            # set the reset candidates to be all the data in offline buffer
            if (FLAGS.use_reset_goals > 0 and reset_lagrange_learning_rate > 0
                ) or FLAGS.use_reset_goals in [4, 5, 6, 7]:
                tf_env.pyenv.envs[0].set_reset_candidates(
                    nest_utils.unbatch_nested_tensors(
                        offline_data.gather_all()))

        if replay_buffer.num_frames() == 0:
            if offline_dir is not None:
                copy_replay_buffer(offline_data, replay_buffer)
                print(replay_buffer.num_frames())

                # multiply offline data
                if FLAGS.relabel_offline_data:
                    data_multiplier(replay_buffer,
                                    tf_env.pyenv.envs[0].env.compute_reward)
                    print('after data multiplication:',
                          replay_buffer.num_frames())

            initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
                tf_env,
                initial_collect_policy,
                observers=replay_observer + train_metrics,
                num_steps=1)
            if use_tf_functions:
                initial_collect_driver.run = common.function(
                    initial_collect_driver.run)

            # Collect initial replay data.
            logging.info(
                'Initializing replay buffer by collecting experience for %d steps with '
                'a random policy.', initial_collect_steps)

            time_step = None
            policy_state = collect_policy.get_initial_state(tf_env.batch_size)

            for iter_idx in range(initial_collect_steps):
                time_step, policy_state = initial_collect_driver.run(
                    time_step=time_step, policy_state=policy_state)

                if time_step.is_last() and FLAGS.relabel_goals:
                    reward_fn = tf_env.pyenv.envs[0].env.compute_reward
                    relabel_function(cur_episode_buffer, time_step, reward_fn,
                                     replay_buffer)
                    cur_episode_buffer.clear()

                if FLAGS.use_reset_goals > 0 and time_step.is_last(
                ) and FLAGS.num_reset_candidates > 0:
                    tf_env.pyenv.envs[0].set_reset_candidates(
                        replay_buffer.get_next(
                            sample_batch_size=FLAGS.num_reset_candidates)[0])

        else:
            time_step = None
            policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        dataset = replay_buffer.as_dataset(
            sample_batch_size=batch_size, num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5)
        # Dataset generates trajectories with shape [Bx2x...]
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        # manual data save for plotting utils
        np_custom_save(os.path.join(eval_dir, 'eval_interval.npy'),
                       eval_interval)
        try:
            average_eval_return = np_custom_load(
                os.path.join(eval_dir, 'average_eval_return.npy')).tolist()
            average_eval_success = np_custom_load(
                os.path.join(eval_dir, 'average_eval_success.npy')).tolist()
            average_eval_final_success = np_custom_load(
                os.path.join(eval_dir,
                             'average_eval_final_success.npy')).tolist()
        except:  # pylint: disable=bare-except
            average_eval_return = []
            average_eval_success = []
            average_eval_final_success = []

        print('initialization_time:', time.time() - start_time)
        for iter_idx in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )

            if time_step.is_last() and FLAGS.relabel_goals:
                reward_fn = tf_env.pyenv.envs[0].env.compute_reward
                relabel_function(cur_episode_buffer, time_step, reward_fn,
                                 replay_buffer)
                cur_episode_buffer.clear()

            # reset goal generator updates
            if FLAGS.use_reset_goals > 0 and iter_idx % (
                    FLAGS.reset_goal_frequency *
                    collect_steps_per_iteration) == 0:
                if FLAGS.num_reset_candidates > 0:
                    tf_env.pyenv.envs[0].set_reset_candidates(
                        replay_buffer.get_next(
                            sample_batch_size=FLAGS.num_reset_candidates)[0])
                if reset_lagrange_learning_rate > 0:
                    reset_goal_generator.update_lagrange_multipliers()

            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in train_metrics:
                if 'Heatmap' in train_metric.name:
                    if global_step_val % summary_interval == 0:
                        train_metric.tf_summaries(
                            train_step=global_step,
                            step_metrics=train_metrics[:2])
                else:
                    train_metric.tf_summaries(train_step=global_step,
                                              step_metrics=train_metrics[:2])

            if (global_step_val % summary_interval == 0
                    and FLAGS.use_reset_goals > 0
                    and reset_lagrange_learning_rate > 0):
                (reset_states, values, initial_state_distance_vals,
                 lagrangian) = reset_goal_generator.update_summaries(
                     step_counter=global_step)
                for vf_viz_metric in aux_info['value_fn_viz_metrics']:
                    vf_viz_metric.tf_summaries(reset_states,
                                               values,
                                               train_step=global_step,
                                               step_metrics=train_metrics[:2])

                if FLAGS.debug_value_fn_for_reset:
                    num_test_lagrange = 20
                    hyp_lagranges = [
                        1.0 * increment / num_test_lagrange
                        for increment in range(num_test_lagrange + 1)
                    ]

                    door_pos = reset_states[
                        np.argmin(initial_state_distance_vals.numpy() -
                                  lagrangian.numpy() * values.numpy())][3:5]
                    print('cur lagrange: %.2f, cur reset goal: (%.2f, %.2f)' %
                          (lagrangian.numpy(), door_pos[0], door_pos[1]))
                    for lagrange in hyp_lagranges:
                        door_pos = reset_states[
                            np.argmin(initial_state_distance_vals.numpy() -
                                      lagrange * values.numpy())][3:5]
                        print(
                            'test lagrange: %.2f, cur reset goal: (%.2f, %.2f)'
                            % (lagrange, door_pos[0], door_pos[1]))
                    print('\n')

            if global_step_val % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step_val)
                metric_utils.log_metrics(eval_metrics)

                # numpy saves for plotting
                if 'AverageReturn' in results.keys():
                    average_eval_return.append(
                        results['AverageReturn'].numpy())
                if 'EvalSuccessfulAtAnyStep' in results.keys():
                    average_eval_success.append(
                        results['EvalSuccessfulAtAnyStep'].numpy())
                if 'EvalSuccessfulEpisodes' in results.keys():
                    average_eval_final_success.append(
                        results['EvalSuccessfulEpisodes'].numpy())
                elif 'EvalSuccessfulAtLastStep' in results.keys():
                    average_eval_final_success.append(
                        results['EvalSuccessfulAtLastStep'].numpy())

                if average_eval_return:
                    np_custom_save(
                        os.path.join(eval_dir, 'average_eval_return.npy'),
                        average_eval_return)
                if average_eval_success:
                    np_custom_save(
                        os.path.join(eval_dir, 'average_eval_success.npy'),
                        average_eval_success)
                if average_eval_final_success:
                    np_custom_save(
                        os.path.join(eval_dir,
                                     'average_eval_final_success.npy'),
                        average_eval_final_success)

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)

            if global_step_val % video_record_interval == 0:
                for video_idx in range(num_videos):
                    video_name = os.path.join(
                        video_dir, str(global_step_val),
                        'video_' + str(video_idx) + '.mp4')
                    record_video(
                        lambda: env_load_fn(  # pylint: disable=g-long-lambda
                            name=env_name,
                            max_episode_steps=max_episode_steps)[0],
                        video_name,
                        eval_py_policy,
                        max_episode_length=eval_episode_steps)

        return train_loss
Example #19
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        # Training params
        initial_collect_steps=1000,
        num_iterations=100000,
        fc_layer_params=(100, ),
        # Agent params
        epsilon_greedy=0.1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        target_update_tau=0.05,
        target_update_period=5,
        reward_scale_factor=1.0,
        # Replay params
        reverb_port=None,
        replay_capacity=100000,
        # Others
        policy_save_interval=1000,
        eval_interval=1000,
        eval_episodes=10):
    """Trains and evaluates DQN."""
    collect_env = suite_gym.load(env_name)
    eval_env = suite_gym.load(env_name)

    time_step_tensor_spec = tensor_spec.from_spec(collect_env.time_step_spec())
    action_tensor_spec = tensor_spec.from_spec(collect_env.action_spec())

    train_step = train_utils.create_train_step()
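    # Number of discrete actions implied by the bounded action spec (bounds are inclusive).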
    num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

    # Define a helper function to create Dense layers configured with the right
    # activation and kernel initializer.
    def dense_layer(num_units):
        return tf.keras.layers.Dense(
            num_units,
            activation=tf.keras.activations.relu,
            kernel_initializer=tf.keras.initializers.VarianceScaling(
                scale=2.0, mode='fan_in', distribution='truncated_normal'))

    # QNetwork consists of a sequence of Dense layers followed by a dense layer
    # with `num_actions` units to generate one q_value per available action as
    # its output.
    dense_layers = [dense_layer(num_units) for num_units in fc_layer_params]
    q_values_layer = tf.keras.layers.Dense(
        num_actions,
        activation=None,
        kernel_initializer=tf.keras.initializers.RandomUniform(minval=-0.03,
                                                               maxval=0.03),
        bias_initializer=tf.keras.initializers.Constant(-0.2))
    q_net = sequential.Sequential(dense_layers + [q_values_layer])

    agent = dqn_agent.DqnAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        q_network=q_net,
        epsilon_greedy=epsilon_greedy,
        n_step_update=n_step_update,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        td_errors_loss_fn=common.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        train_step_counter=train_step)

    table_name = 'uniform_table'
    table = reverb.Table(table_name,
                         max_size=replay_capacity,
                         sampler=reverb.selectors.Uniform(),
                         remover=reverb.selectors.Fifo(),
                         rate_limiter=reverb.rate_limiters.MinSize(1))
    reverb_server = reverb.Server([table], port=reverb_port)
    reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
        agent.collect_data_spec,
        sequence_length=2,
        table_name=table_name,
        local_server=reverb_server)
    rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
        reverb_replay.py_client,
        table_name,
        sequence_length=2,
        stride_length=1)

    dataset = reverb_replay.as_dataset(num_parallel_calls=3,
                                       sample_batch_size=batch_size,
                                       num_steps=2).prefetch(3)
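    # The learner expects a callable that returns the dataset, so wrap it in a lambda.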
    experience_dataset_fn = lambda: dataset

    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    env_step_metric = py_metrics.EnvironmentSteps()

    learning_triggers = [
        triggers.PolicySavedModelTrigger(
            saved_model_dir,
            agent,
            train_step,
            interval=policy_save_interval,
            metadata_metrics={triggers.ENV_STEP_METADATA_KEY:
                              env_step_metric}),
        triggers.StepPerSecondLogTrigger(train_step, interval=100),
    ]

    dqn_learner = learner.Learner(root_dir,
                                  train_step,
                                  agent,
                                  experience_dataset_fn,
                                  triggers=learning_triggers)

    # If we haven't trained yet, make sure we collect some random samples first
    # to fill up the Replay Buffer with some experience.
    random_policy = random_py_policy.RandomPyPolicy(
        collect_env.time_step_spec(), collect_env.action_spec())
    initial_collect_actor = actor.Actor(collect_env,
                                        random_policy,
                                        train_step,
                                        steps_per_run=initial_collect_steps,
                                        observers=[rb_observer])
    logging.info('Doing initial collect.')
    initial_collect_actor.run()

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy,
                                                        use_tf_function=True)

    collect_actor = actor.Actor(
        collect_env,
        collect_policy,
        train_step,
        steps_per_run=1,
        observers=[rb_observer, env_step_metric],
        metrics=actor.collect_metrics(10),
        summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
    )

    tf_greedy_policy = agent.policy
    greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_greedy_policy,
                                                       use_tf_function=True)

    eval_actor = actor.Actor(
        eval_env,
        greedy_policy,
        train_step,
        episodes_per_run=eval_episodes,
        metrics=actor.eval_metrics(eval_episodes),
        summary_dir=os.path.join(root_dir, 'eval'),
    )

    if eval_interval:
        logging.info('Evaluating.')
        eval_actor.run_and_log()

    logging.info('Training.')
    for _ in range(num_iterations):
        collect_actor.run()
        dqn_learner.run(iterations=1)

        if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0:
            logging.info('Evaluating.')
            eval_actor.run_and_log()

    rb_observer.close()
    reverb_server.stop()
Example #20
def train_eval(
        root_dir,
        env_name,
        # Training params
        train_sequence_length,
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        num_iterations=100000,
        # RNN params.
        q_network_fn=q_lstm_network,  # defaults to q_lstm_network.
        # Agent params
        epsilon_greedy=0.1,
        batch_size=64,
        learning_rate=1e-3,
        gamma=0.99,
        target_update_tau=0.05,
        target_update_period=5,
        reward_scale_factor=1.0,
        # Replay params
        reverb_port=None,
        replay_capacity=100000,
        # Others
        policy_save_interval=1000,
        eval_interval=1000,
        eval_episodes=10):
    """Trains and evaluates DQN."""

    collect_env = suite_gym.load(env_name)
    eval_env = suite_gym.load(env_name)

    unused_observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
        spec_utils.get_tensor_specs(collect_env))

    train_step = train_utils.create_train_step()

    num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1
    q_net = q_network_fn(num_actions=num_actions)

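    # Sample one extra time step so each sequence of train_sequence_length + 1
    # steps yields train_sequence_length transitions for the RNN.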
    sequence_length = train_sequence_length + 1
    agent = dqn_agent.DqnAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        q_network=q_net,
        epsilon_greedy=epsilon_greedy,
        # n-step updates aren't supported with RNNs yet.
        n_step_update=1,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        td_errors_loss_fn=common.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        train_step_counter=train_step)

    table_name = 'uniform_table'
    table = reverb.Table(table_name,
                         max_size=replay_capacity,
                         sampler=reverb.selectors.Uniform(),
                         remover=reverb.selectors.Fifo(),
                         rate_limiter=reverb.rate_limiters.MinSize(1))
    reverb_server = reverb.Server([table], port=reverb_port)
    reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
        agent.collect_data_spec,
        sequence_length=sequence_length,
        table_name=table_name,
        local_server=reverb_server)
    rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
        reverb_replay.py_client,
        table_name,
        sequence_length=sequence_length,
        stride_length=1,
        pad_end_of_episodes=True)

    def experience_dataset_fn():
        return reverb_replay.as_dataset(sample_batch_size=batch_size,
                                        num_steps=sequence_length)

    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    env_step_metric = py_metrics.EnvironmentSteps()

    learning_triggers = [
        triggers.PolicySavedModelTrigger(
            saved_model_dir,
            agent,
            train_step,
            interval=policy_save_interval,
            metadata_metrics={triggers.ENV_STEP_METADATA_KEY:
                              env_step_metric}),
        triggers.StepPerSecondLogTrigger(train_step, interval=100),
    ]

    dqn_learner = learner.Learner(root_dir,
                                  train_step,
                                  agent,
                                  experience_dataset_fn,
                                  triggers=learning_triggers)

    # If we haven't trained yet, make sure we collect some random samples first
    # to fill up the Replay Buffer with some experience.
    random_policy = random_py_policy.RandomPyPolicy(
        collect_env.time_step_spec(), collect_env.action_spec())
    initial_collect_actor = actor.Actor(collect_env,
                                        random_policy,
                                        train_step,
                                        steps_per_run=initial_collect_steps,
                                        observers=[rb_observer])
    logging.info('Doing initial collect.')
    initial_collect_actor.run()

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy,
                                                        use_tf_function=True)

    collect_actor = actor.Actor(
        collect_env,
        collect_policy,
        train_step,
        steps_per_run=collect_steps_per_iteration,
        observers=[rb_observer, env_step_metric],
        metrics=actor.collect_metrics(10),
        summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
    )

    tf_greedy_policy = agent.policy
    greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_greedy_policy,
                                                       use_tf_function=True)

    eval_actor = actor.Actor(
        eval_env,
        greedy_policy,
        train_step,
        episodes_per_run=eval_episodes,
        metrics=actor.eval_metrics(eval_episodes),
        summary_dir=os.path.join(root_dir, 'eval'),
    )

    if eval_interval:
        logging.info('Evaluating.')
        eval_actor.run_and_log()

    logging.info('Training.')
    for _ in range(num_iterations):
        collect_actor.run()
        dqn_learner.run(iterations=1)

        if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0:
            logging.info('Evaluating.')
            eval_actor.run_and_log()

    rb_observer.close()
    reverb_server.stop()
Example #21
def train_eval(
    root_dir,
    env_name='Pong-v0',
    # Training params
    update_frequency=4,  # Number of collect steps per policy update
    initial_collect_steps=50000,  # 50k collect steps
    num_iterations=50000000,  # 50M collect steps
    # Taken from Rainbow as it's not specified in Mnih et al., 2015.
    max_episode_frames_collect=50000,  # env frames observed by the agent
    max_episode_frames_eval=108000,  # env frames observed by the agent
    # Agent params
    epsilon_greedy=0.1,
    epsilon_decay_period=250000,  # 1M collect steps / update_frequency
    batch_size=32,
    learning_rate=0.00025,
    n_step_update=1,
    gamma=0.99,
    target_update_tau=1.0,
    target_update_period=2500,  # 10k collect steps / update_frequency
    reward_scale_factor=1.0,
    # Replay params
    reverb_port=None,
    replay_capacity=1000000,
    # Others
    policy_save_interval=250000,
    eval_interval=1000,
    eval_episodes=30,
    debug_summaries=True):
  """Trains and evaluates DQN."""

  collect_env = suite_atari.load(
      env_name,
      max_episode_steps=max_episode_frames_collect,
      gym_env_wrappers=suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)
  eval_env = suite_atari.load(
      env_name,
      max_episode_steps=max_episode_frames_eval,
      gym_env_wrappers=suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)

  unused_observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))

  train_step = train_utils.create_train_step()

  num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1
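  # Linearly anneal epsilon from 1.0 down to epsilon_greedy over epsilon_decay_period train steps.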
  epsilon = tf.compat.v1.train.polynomial_decay(
      1.0,
      train_step,
      epsilon_decay_period,
      end_learning_rate=epsilon_greedy)
  agent = dqn_agent.DqnAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      q_network=create_q_network(num_actions),
      epsilon_greedy=epsilon,
      n_step_update=n_step_update,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      optimizer=tf.compat.v1.train.RMSPropOptimizer(
          learning_rate=learning_rate,
          decay=0.95,
          momentum=0.95,
          epsilon=0.01,
          centered=True),
      td_errors_loss_fn=common.element_wise_huber_loss,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      train_step_counter=train_step,
      debug_summaries=debug_summaries)

  table_name = 'uniform_table'
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1))
  reverb_server = reverb.Server([table], port=reverb_port)
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=table_name,
      local_server=reverb_server)
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb_replay.py_client, table_name,
      sequence_length=2,
      stride_length=1)

  dataset = reverb_replay.as_dataset(
      sample_batch_size=batch_size, num_steps=2).prefetch(3)
  experience_dataset_fn = lambda: dataset

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  env_step_metric = py_metrics.EnvironmentSteps()

  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
      triggers.StepPerSecondLogTrigger(train_step, interval=100),
  ]

  dqn_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers)

  # If we haven't trained yet, make sure we collect some random samples first
  # to fill up the Replay Buffer with some experience.
  random_policy = random_py_policy.RandomPyPolicy(collect_env.time_step_spec(),
                                                  collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy,
                                                      use_tf_function=True)

  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=update_frequency,
      observers=[rb_observer, env_step_metric],
      metrics=actor.collect_metrics(10),
      reference_metrics=[env_step_metric],
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
  )

  tf_greedy_policy = agent.policy
  greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_greedy_policy,
                                                     use_tf_function=True)

  eval_actor = actor.Actor(
      eval_env,
      greedy_policy,
      train_step,
      episodes_per_run=eval_episodes,
      metrics=actor.eval_metrics(eval_episodes),
      reference_metrics=[env_step_metric],
      summary_dir=os.path.join(root_dir, 'eval'),
  )

  if eval_interval:
    logging.info('Evaluating.')
    eval_actor.run_and_log()

  logging.info('Training.')
  for _ in range(num_iterations):
    collect_actor.run()
    dqn_learner.run(iterations=1)

    if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      eval_actor.run_and_log()

  rb_observer.close()
  reverb_server.stop()
Example #22
reverb_server = reverb.Server([table])

reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
    tf_agent.collect_data_spec,
    sequence_length=2,
    table_name=table_name,
    local_server=reverb_server)

dataset = reverb_replay.as_dataset(sample_batch_size=HyperParms.batch_size,
                                   num_steps=2).prefetch(50)
experience_dataset_fn = lambda: dataset

print(f" --  POLICIES  ({now()})  -- ")
tf_eval_policy = tf_agent.policy
eval_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_eval_policy,
                                                 use_tf_function=True)
tf_collect_policy = tf_agent.collect_policy
collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy,
                                                    use_tf_function=True)
random_policy = random_py_policy.RandomPyPolicy(collect_env.time_step_spec(),
                                                collect_env.action_spec())

print(f" --  ACTORS  ({now()})  -- ")
rb_observer = reverb_utils.ReverbAddTrajectoryObserver(reverb_replay.py_client,
                                                       table_name,
                                                       sequence_length=2,
                                                       stride_length=1)

initial_collect_actor = actor.Actor(
    collect_env,
    random_policy,
Example #23
def train_eval(
        root_dir,
        # Dataset params
        env_name,
        data_dir=None,
        load_pretrained=False,
        pretrained_model_dir=None,
        img_pad=4,
        frame_shape=(84, 84, 3),
        frame_stack=3,
        num_augmentations=2,  # K and M in DrQ
        # Training params
        contrastive_loss_weight=1.0,
        contrastive_loss_temperature=0.5,
        image_encoder_representation=True,
        initial_collect_steps=1000,
        num_train_steps=3000000,
        actor_fc_layers=(1024, 1024),
        critic_joint_fc_layers=(1024, 1024),
        # Agent params
        batch_size=256,
        actor_learning_rate=1e-3,
        critic_learning_rate=1e-3,
        alpha_learning_rate=1e-3,
        encoder_learning_rate=1e-3,
        actor_update_freq=2,
        gamma=0.99,
        target_update_tau=0.01,
        target_update_period=2,
        reward_scale_factor=1.0,
        # Replay params
        reverb_port=None,
        replay_capacity=100000,
        # Others
        checkpoint_interval=10000,
        policy_save_interval=5000,
        eval_interval=10000,
        summary_interval=250,
        debug_summaries=False,
        eval_episodes_per_run=10,
        summarize_grads_and_vars=False):
    """Trains and evaluates SAC."""
    collect_env = env_utils.load_dm_env_for_training(env_name,
                                                     frame_shape,
                                                     frame_stack=frame_stack)
    eval_env = env_utils.load_dm_env_for_eval(env_name,
                                              frame_shape,
                                              frame_stack=frame_stack)

    logging.info('Data directory: %s', data_dir)
    logging.info('Num train steps: %d', num_train_steps)
    logging.info('Contrastive loss coeff: %.2f', contrastive_loss_weight)
    logging.info('Contrastive loss temperature: %.4f',
                 contrastive_loss_temperature)
    logging.info('load_pretrained: %s', 'yes' if load_pretrained else 'no')
    logging.info('encoder representation: %s',
                 'yes' if image_encoder_representation else 'no')

    load_episode_data = (contrastive_loss_weight > 0)
    observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
        spec_utils.get_tensor_specs(collect_env))

    train_step = train_utils.create_train_step()
    image_encoder = networks.ImageEncoder(observation_tensor_spec)

    actor_net = model_utils.Actor(
        observation_tensor_spec,
        action_tensor_spec,
        image_encoder=image_encoder,
        fc_layers=actor_fc_layers,
        image_encoder_representation=image_encoder_representation)

    critic_net = networks.Critic((observation_tensor_spec, action_tensor_spec),
                                 image_encoder=image_encoder,
                                 joint_fc_layers=critic_joint_fc_layers)
    critic_net_2 = networks.Critic(
        (observation_tensor_spec, action_tensor_spec),
        image_encoder=image_encoder,
        joint_fc_layers=critic_joint_fc_layers)

    target_image_encoder = networks.ImageEncoder(observation_tensor_spec)
    target_critic_net_1 = networks.Critic(
        (observation_tensor_spec, action_tensor_spec),
        image_encoder=target_image_encoder)
    target_critic_net_2 = networks.Critic(
        (observation_tensor_spec, action_tensor_spec),
        image_encoder=target_image_encoder)

    agent = pse_drq_agent.DrQSacModifiedAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        critic_network_2=critic_net_2,
        target_critic_network=target_critic_net_1,
        target_critic_network_2=target_critic_net_2,
        actor_update_frequency=actor_update_freq,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        contrastive_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=encoder_learning_rate),
        contrastive_loss_weight=contrastive_loss_weight,
        contrastive_loss_temperature=contrastive_loss_temperature,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        use_log_alpha_in_alpha_loss=False,
        gradient_clipping=None,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step,
        num_augmentations=num_augmentations)
    agent.initialize()

    # Setup the replay buffer.
    reverb_replay, rb_observer = (
        replay_buffer_utils.get_reverb_buffer_and_observer(
            agent.collect_data_spec,
            sequence_length=2,
            replay_capacity=replay_capacity,
            port=reverb_port))

    # pylint: disable=g-long-lambda
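    # With augmentation disabled, pass trajectories through with empty augmented
    # observation lists; otherwise produce num_augmentations augmented copies.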
    if num_augmentations == 0:
        image_aug = lambda traj, meta: (dict(
            experience=traj, augmented_obs=[], augmented_next_obs=[]), meta)
    else:
        image_aug = lambda traj, meta: pse_drq_agent.image_aug(
            traj, meta, img_pad, num_augmentations)
    augmented_dataset = reverb_replay.as_dataset(sample_batch_size=batch_size,
                                                 num_steps=2).unbatch().map(
                                                     image_aug,
                                                     num_parallel_calls=3)
    augmented_iterator = iter(augmented_dataset)

    trajs = augmented_dataset.batch(batch_size).prefetch(50)
    if load_episode_data:
        # Load full episodes and zip them
        episodes = dataset_utils.load_episodes(
            os.path.join(data_dir, 'episodes2'), img_pad)
        episode_iterator = iter(episodes)
        dataset = tf.data.Dataset.zip((trajs, episodes)).prefetch(10)
    else:
        dataset = trajs
    experience_dataset_fn = lambda: dataset

    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    learning_triggers = [
        triggers.PolicySavedModelTrigger(saved_model_dir,
                                         agent,
                                         train_step,
                                         interval=policy_save_interval),
        triggers.StepPerSecondLogTrigger(train_step,
                                         interval=summary_interval),
    ]

    agent_learner = model_utils.Learner(
        root_dir,
        train_step,
        agent,
        experience_dataset_fn=experience_dataset_fn,
        triggers=learning_triggers,
        checkpoint_interval=checkpoint_interval,
        summary_interval=summary_interval,
        load_episode_data=load_episode_data,
        use_kwargs_in_agent_train=True,
        # Turn off the initialization of the optimizer variables since the agent
        # expects different batching for the `training_data_spec` and
        # `train_argspec`, which can't be handled in general by the
        # initialization logic in the learner.
        run_optimizer_variable_init=False)

    # If we haven't trained yet, make sure we collect some random samples first
    # to fill up the Replay Buffer with some experience.
    train_dir = os.path.join(root_dir, learner.TRAIN_DIR)

    # Code for loading pretrained policy.
    if load_pretrained:
        # Note that num_train_steps is the same as the max train step at which
        # we want to load the pretrained policy for our experiments.
        pretrained_policy = model_utils.load_pretrained_policy(
            pretrained_model_dir, num_train_steps)
        initial_collect_policy = pretrained_policy

        agent.policy.update_partial(pretrained_policy)
        agent.collect_policy.update_partial(pretrained_policy)
        logging.info('Restored pretrained policy.')
    else:
        initial_collect_policy = random_py_policy.RandomPyPolicy(
            collect_env.time_step_spec(), collect_env.action_spec())
    initial_collect_actor = actor.Actor(collect_env,
                                        initial_collect_policy,
                                        train_step,
                                        steps_per_run=initial_collect_steps,
                                        observers=[rb_observer])
    logging.info('Doing initial collect.')
    initial_collect_actor.run()

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy,
                                                        use_tf_function=True)

    collect_actor = actor.Actor(collect_env,
                                collect_policy,
                                train_step,
                                steps_per_run=1,
                                observers=[rb_observer],
                                metrics=actor.collect_metrics(buffer_size=10),
                                summary_dir=train_dir,
                                summary_interval=summary_interval,
                                name='CollectActor')

    # If restarting with train_step > 0, the replay buffer will be empty
    # except for random experience. Populate the buffer with some on-policy
    # experience.
    if load_pretrained or (agent_learner.train_step_numpy > 0):
        for _ in range(batch_size * 50):
            collect_actor.run()

    tf_greedy_policy = agent.policy
    greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_greedy_policy,
                                                       use_tf_function=True)

    eval_actor = actor.Actor(eval_env,
                             greedy_policy,
                             train_step,
                             episodes_per_run=eval_episodes_per_run,
                             metrics=actor.eval_metrics(buffer_size=10),
                             summary_dir=os.path.join(root_dir, 'eval'),
                             summary_interval=-1,
                             name='EvalTrainActor')

    if eval_interval:
        logging.info('Evaluating.')
        img_summary(
            next(augmented_iterator)[0], eval_actor.summary_writer, train_step)
        if load_episode_data:
            contrastive_img_summary(next(episode_iterator), agent,
                                    eval_actor.summary_writer, train_step)
        eval_actor.run_and_log()

    logging.info('Saving operative gin config file.')
    gin_path = os.path.join(train_dir, 'train_operative_gin_config.txt')
    with tf.io.gfile.GFile(gin_path, mode='w') as f:
        f.write(gin.operative_config_str())

    logging.info('Training starting at: %r', train_step.numpy())
    while train_step < num_train_steps:
        collect_actor.run()
        agent_learner.run(iterations=1)
        if (not eval_interval) and (train_step % 10000 == 0):
            img_summary(
                next(augmented_iterator)[0],
                agent_learner.train_summary_writer, train_step)
        if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
            logging.info('Evaluating.')
            img_summary(
                next(augmented_iterator)[0], eval_actor.summary_writer,
                train_step)
            if load_episode_data:
                contrastive_img_summary(next(episode_iterator), agent,
                                        eval_actor.summary_writer, train_step)
            eval_actor.run_and_log()
Example #24
def train_eval(
        root_dir,
        dataset_path,
        env_name,
        # Training params
        tpu=False,
        use_gpu=False,
        num_gradient_updates=1000000,
        actor_fc_layers=(256, 256),
        critic_joint_fc_layers=(256, 256, 256),
        # Agent params
        batch_size=256,
        bc_steps=0,
        actor_learning_rate=3e-5,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        reward_scale_factor=1.0,
        cql_alpha_learning_rate=3e-4,
        cql_alpha=5.0,
        cql_tau=10.0,
        num_cql_samples=10,
        reward_noise_variance=0.0,
        include_critic_entropy_term=False,
        use_lagrange_cql_alpha=True,
        log_cql_alpha_clipping=None,
        softmax_temperature=1.0,
        # Data params
        reward_shift=0.0,
        action_clipping=None,
        use_trajectories=False,
        data_shuffle_buffer_size_per_record=1,
        data_shuffle_buffer_size=100,
        data_num_shards=1,
        data_block_length=10,
        data_parallel_reads=None,
        data_parallel_calls=10,
        data_prefetch=10,
        data_cycle_length=10,
        # Others
        policy_save_interval=10000,
        eval_interval=10000,
        summary_interval=1000,
        learner_iterations_per_call=1,
        eval_episodes=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        seed=None):
    """Trains and evaluates CQL-SAC."""
    logging.info('Training CQL-SAC on: %s', env_name)
    tf.random.set_seed(seed)
    np.random.seed(seed)

    # Load environment.
    env = load_d4rl(env_name)
    tf_env = tf_py_environment.TFPyEnvironment(env)
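    # Pick a TPU, GPU, or default distribution strategy based on the flags.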
    strategy = strategy_utils.get_strategy(tpu, use_gpu)

    if not dataset_path.endswith('.tfrecord'):
        dataset_path = os.path.join(dataset_path, env_name,
                                    '%s*.tfrecord' % env_name)
    logging.info('Loading dataset from %s', dataset_path)
    dataset_paths = tf.io.gfile.glob(dataset_path)

    # Create dataset.
    with strategy.scope():
        dataset = create_tf_record_dataset(
            dataset_paths,
            batch_size,
            shuffle_buffer_size_per_record=data_shuffle_buffer_size_per_record,
            shuffle_buffer_size=data_shuffle_buffer_size,
            num_shards=data_num_shards,
            cycle_length=data_cycle_length,
            block_length=data_block_length,
            num_parallel_reads=data_parallel_reads,
            num_parallel_calls=data_parallel_calls,
            num_prefetch=data_prefetch,
            strategy=strategy,
            reward_shift=reward_shift,
            action_clipping=action_clipping,
            use_trajectories=use_trajectories)

    # Create agent.
    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()
    with strategy.scope():
        train_step = train_utils.create_train_step()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=tanh_normal_projection_network.
            TanhNormalProjectionNetwork)

        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform')

        agent = cql_sac_agent.CqlSacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.keras.optimizers.Adam(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.keras.optimizers.Adam(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.keras.optimizers.Adam(
                learning_rate=alpha_learning_rate),
            cql_alpha=cql_alpha,
            num_cql_samples=num_cql_samples,
            include_critic_entropy_term=include_critic_entropy_term,
            use_lagrange_cql_alpha=use_lagrange_cql_alpha,
            cql_alpha_learning_rate=cql_alpha_learning_rate,
            target_update_tau=5e-3,
            target_update_period=1,
            random_seed=seed,
            cql_tau=cql_tau,
            reward_noise_variance=reward_noise_variance,
            num_bc_steps=bc_steps,
            td_errors_loss_fn=tf.math.squared_difference,
            gamma=0.99,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=None,
            log_cql_alpha_clipping=log_cql_alpha_clipping,
            softmax_temperature=softmax_temperature,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=train_step)
        agent.initialize()

    # Create learner.
    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    collect_env_step_metric = py_metrics.EnvironmentSteps()
    learning_triggers = [
        triggers.PolicySavedModelTrigger(saved_model_dir,
                                         agent,
                                         train_step,
                                         interval=policy_save_interval,
                                         metadata_metrics={
                                             triggers.ENV_STEP_METADATA_KEY:
                                             collect_env_step_metric
                                         }),
        triggers.StepPerSecondLogTrigger(train_step, interval=100)
    ]
    cql_learner = learner.Learner(root_dir,
                                  train_step,
                                  agent,
                                  experience_dataset_fn=lambda: dataset,
                                  triggers=learning_triggers,
                                  summary_interval=summary_interval,
                                  strategy=strategy)

    # Create actor for evaluation.
    tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
    eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_greedy_policy, use_tf_function=True)
    eval_actor = actor.Actor(env,
                             eval_greedy_policy,
                             train_step,
                             metrics=actor.eval_metrics(eval_episodes),
                             summary_dir=os.path.join(root_dir, 'eval'),
                             episodes_per_run=eval_episodes)

    # Run.
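    # Placeholder mid-episode trajectory used only to advance the environment-step metric.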
    dummy_trajectory = trajectory.mid((), (), (), 0., 1.)
    num_learner_iterations = int(num_gradient_updates /
                                 learner_iterations_per_call)
    for _ in range(num_learner_iterations):
        # Mimic collecting environment steps since we loaded a static dataset.
        for _ in range(learner_iterations_per_call):
            collect_env_step_metric(dummy_trajectory)

        cql_learner.run(iterations=learner_iterations_per_call)
        if eval_interval and train_step.numpy() % eval_interval == 0:
            eval_actor.run_and_log()
Example #25
def train_eval(
    root_dir,
    random_seed=None,
    env_name='sawyer_push',
    eval_env_name=None,
    env_load_fn=get_env,
    max_episode_steps=1000,
    eval_episode_steps=1000,
    # The SAC paper reported:
    # Hopper and Cartpole results up to 1000000 iters,
    # Humanoid results up to 10000000 iters,
    # Other mujoco tasks up to 3000000 iters.
    num_iterations=3000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    # Follow https://github.com/haarnoja/sac/blob/master/examples/variants.py
    # HalfCheetah and Ant take 10000 initial collection steps.
    # Other mujoco tasks take 1000.
    # Different choices roughly keep the initial episodes about the same.
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    reset_goal_frequency=1000,  # virtual episode size for reset-free training
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    # reset-free parameters
    use_minimum=True,
    reset_lagrange_learning_rate=3e-4,
    value_threshold=None,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=0.1,
    # Td3 parameters
    actor_update_period=1,
    exploration_noise_std=0.1,
    target_policy_noise=0.1,
    target_policy_noise_clip=0.1,
    dqda_clipping=None,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    # video recording for the environment
    video_record_interval=10000,
    num_videos=0,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
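  """Trains and evaluates a SAC or TD3 agent in the reset-free setup."""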

  start_time = time.time()

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')
  video_dir = os.path.join(eval_dir, 'videos')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if random_seed is not None:
      tf.compat.v1.set_random_seed(random_seed)
    env, env_train_metrics, env_eval_metrics, aux_info = env_load_fn(
        name=env_name,
        max_episode_steps=None,
        gym_env_wrappers=(functools.partial(
            reset_free_wrapper.ResetFreeWrapper,
            reset_goal_frequency=reset_goal_frequency,
            full_reset_frequency=max_episode_steps),))

    tf_env = tf_py_environment.TFPyEnvironment(env)
    eval_env_name = eval_env_name or env_name
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        env_load_fn(name=eval_env_name,
                    max_episode_steps=eval_episode_steps)[0])

    eval_metrics += env_eval_metrics

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    if FLAGS.agent_type == 'sac':
      actor_net = actor_distribution_network.ActorDistributionNetwork(
          observation_spec,
          action_spec,
          fc_layer_params=actor_fc_layers,
          continuous_projection_net=functools.partial(
              tanh_normal_projection_network.TanhNormalProjectionNetwork,
              std_transform=std_clip_transform),
          name='forward_actor')
      critic_net = critic_network.CriticNetwork(
          (observation_spec, action_spec),
          observation_fc_layer_params=critic_obs_fc_layers,
          action_fc_layer_params=critic_action_fc_layers,
          joint_fc_layer_params=critic_joint_fc_layers,
          kernel_initializer='glorot_uniform',
          last_kernel_initializer='glorot_uniform',
          name='forward_critic')

      tf_agent = SacAgent(
          time_step_spec,
          action_spec,
          num_action_samples=FLAGS.num_action_samples,
          actor_network=actor_net,
          critic_network=critic_net,
          actor_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=actor_learning_rate),
          critic_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=critic_learning_rate),
          alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=alpha_learning_rate),
          target_update_tau=target_update_tau,
          target_update_period=target_update_period,
          td_errors_loss_fn=td_errors_loss_fn,
          gamma=gamma,
          reward_scale_factor=reward_scale_factor,
          gradient_clipping=gradient_clipping,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step,
          name='forward_agent')

      actor_net_rev = actor_distribution_network.ActorDistributionNetwork(
          observation_spec,
          action_spec,
          fc_layer_params=actor_fc_layers,
          continuous_projection_net=functools.partial(
              tanh_normal_projection_network.TanhNormalProjectionNetwork,
              std_transform=std_clip_transform),
          name='reverse_actor')

      critic_net_rev = critic_network.CriticNetwork(
          (observation_spec, action_spec),
          observation_fc_layer_params=critic_obs_fc_layers,
          action_fc_layer_params=critic_action_fc_layers,
          joint_fc_layer_params=critic_joint_fc_layers,
          kernel_initializer='glorot_uniform',
          last_kernel_initializer='glorot_uniform',
          name='reverse_critic')

      tf_agent_rev = SacAgent(
          time_step_spec,
          action_spec,
          num_action_samples=FLAGS.num_action_samples,
          actor_network=actor_net_rev,
          critic_network=critic_net_rev,
          actor_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=actor_learning_rate),
          critic_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=critic_learning_rate),
          alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=alpha_learning_rate),
          target_update_tau=target_update_tau,
          target_update_period=target_update_period,
          td_errors_loss_fn=td_errors_loss_fn,
          gamma=gamma,
          reward_scale_factor=reward_scale_factor,
          gradient_clipping=gradient_clipping,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step,
          name='reverse_agent')

    elif FLAGS.agent_type == 'td3':
      actor_net = actor_network.ActorNetwork(
          tf_env.time_step_spec().observation,
          tf_env.action_spec(),
          fc_layer_params=actor_fc_layers,
      )
      critic_net = critic_network.CriticNetwork(
          (observation_spec, action_spec),
          observation_fc_layer_params=critic_obs_fc_layers,
          action_fc_layer_params=critic_action_fc_layers,
          joint_fc_layer_params=critic_joint_fc_layers,
          kernel_initializer='glorot_uniform',
          last_kernel_initializer='glorot_uniform')

      tf_agent = Td3Agent(
          tf_env.time_step_spec(),
          tf_env.action_spec(),
          actor_network=actor_net,
          critic_network=critic_net,
          actor_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=actor_learning_rate),
          critic_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=critic_learning_rate),
          exploration_noise_std=exploration_noise_std,
          target_update_tau=target_update_tau,
          target_update_period=target_update_period,
          actor_update_period=actor_update_period,
          dqda_clipping=dqda_clipping,
          td_errors_loss_fn=td_errors_loss_fn,
          gamma=gamma,
          reward_scale_factor=reward_scale_factor,
          target_policy_noise=target_policy_noise,
          target_policy_noise_clip=target_policy_noise_clip,
          gradient_clipping=gradient_clipping,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step,
      )

    tf_agent.initialize()
    tf_agent_rev.initialize()

    if FLAGS.use_reset_goals:
      # distance to initial state distribution
      initial_state_distance = state_distribution_distance.L2Distance(
          initial_state_shape=aux_info['reset_state_shape'])
      initial_state_distance.update(
          tf.constant(aux_info['reset_states'], dtype=tf.float32),
          update_type='complete')

      if use_tf_functions:
        initial_state_distance.distance = common.function(
            initial_state_distance.distance)
        tf_agent.compute_value = common.function(tf_agent.compute_value)

      # initialize reset / practice goal proposer
      if reset_lagrange_learning_rate > 0:
        reset_goal_generator = ResetGoalGenerator(
            goal_dim=aux_info['reset_state_shape'][0],
            num_reset_candidates=FLAGS.num_reset_candidates,
            compute_value_fn=tf_agent.compute_value,
            distance_fn=initial_state_distance,
            use_minimum=use_minimum,
            value_threshold=value_threshold,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=reset_lagrange_learning_rate),
            name='reset_goal_generator')
      else:
        reset_goal_generator = FixedResetGoal(
            distance_fn=initial_state_distance)

      # if use_tf_functions:
      #   reset_goal_generator.get_reset_goal = common.function(
      #       reset_goal_generator.get_reset_goal)

      # modify the reset-free wrapper to use the reset goal generator
      tf_env.pyenv.envs[0].set_reset_goal_fn(
          reset_goal_generator.get_reset_goal)

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    replay_buffer_rev = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent_rev.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer_rev = [replay_buffer_rev.add_batch]

    # initialize metrics and observers
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
    ]
    train_metrics += env_train_metrics
    train_metrics_rev = [
        tf_metrics.NumberOfEpisodes(name='NumberOfEpisodesRev'),
        tf_metrics.EnvironmentSteps(name='EnvironmentStepsRev'),
        tf_metrics.AverageReturnMetric(
            name='AverageReturnRev',
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            name='AverageEpisodeLengthRev',
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size),
    ]
    train_metrics_rev += aux_info['train_metrics_rev']

    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
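    # Python-callable wrapper around the TF policy for stepping Python environments.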
    eval_py_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_agent.policy, use_tf_function=True)

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    initial_collect_policy_rev = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    collect_policy = tf_agent.collect_policy
    collect_policy_rev = tf_agent_rev.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'forward'),
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'forward', 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)
    # reverse policy savers
    train_checkpointer_rev = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'reverse'),
        agent=tf_agent_rev,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics_rev,
                                          'train_metrics_rev'))
    rb_checkpointer_rev = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer_rev'),
        max_to_keep=1,
        replay_buffer=replay_buffer_rev)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()
    train_checkpointer_rev.initialize_or_restore()
    rb_checkpointer_rev.initialize_or_restore()

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)
    collect_driver_rev = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy_rev,
        observers=replay_observer_rev + train_metrics_rev,
        num_steps=collect_steps_per_iteration)

    if use_tf_functions:
      collect_driver.run = common.function(collect_driver.run)
      collect_driver_rev.run = common.function(collect_driver_rev.run)
      tf_agent.train = common.function(tf_agent.train)
      tf_agent_rev.train = common.function(tf_agent_rev.train)

    if replay_buffer.num_frames() == 0:
      initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
          tf_env,
          initial_collect_policy,
          observers=replay_observer + train_metrics,
          num_steps=1)
      initial_collect_driver_rev = dynamic_step_driver.DynamicStepDriver(
          tf_env,
          initial_collect_policy_rev,
          observers=replay_observer_rev + train_metrics_rev,
          num_steps=1)
      # Wrapping the initial collect drivers in common.function does not work for some reason.
      if use_tf_functions:
        initial_collect_driver.run = common.function(initial_collect_driver.run)
        initial_collect_driver_rev.run = common.function(
            initial_collect_driver_rev.run)

      # Collect initial replay data.
      logging.info(
          'Initializing replay buffer by collecting experience for %d steps with '
          'a random policy.', initial_collect_steps)
      for iter_idx_initial in range(initial_collect_steps):
        if tf_env.pyenv.envs[0]._forward_or_reset_goal:
          initial_collect_driver.run()
        else:
          initial_collect_driver_rev.run()
        if FLAGS.use_reset_goals and iter_idx_initial % FLAGS.reset_goal_frequency == 0:
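          # Once the reverse buffer has data, draw reset candidates half from the
          # forward buffer and half from the reverse buffer.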
          if replay_buffer_rev.num_frames():
            reset_candidates_from_forward_buffer = replay_buffer.get_next(
                sample_batch_size=FLAGS.num_reset_candidates // 2)[0]
            reset_candidates_from_reverse_buffer = replay_buffer_rev.get_next(
                sample_batch_size=FLAGS.num_reset_candidates // 2)[0]
            flat_forward_tensors = tf.nest.flatten(
                reset_candidates_from_forward_buffer)
            flat_reverse_tensors = tf.nest.flatten(
                reset_candidates_from_reverse_buffer)
            concatenated_tensors = [
                tf.concat([x, y], axis=0)
                for x, y in zip(flat_forward_tensors, flat_reverse_tensors)
            ]
            reset_candidates = tf.nest.pack_sequence_as(
                reset_candidates_from_forward_buffer, concatenated_tensors)
            tf_env.pyenv.envs[0].set_reset_candidates(reset_candidates)
          else:
            reset_candidates = replay_buffer.get_next(
                sample_batch_size=FLAGS.num_reset_candidates)[0]
            tf_env.pyenv.envs[0].set_reset_candidates(reset_candidates)

    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, global_step.numpy())
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Prepare replay buffer as dataset with invalid transitions filtered.
    def _filter_invalid_transition(trajectories, unused_arg1):
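      # Keep only 2-step windows whose first step is not an episode boundary.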
      return ~trajectories.is_boundary()[0]

    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size, num_steps=2).unbatch().filter(
            _filter_invalid_transition).batch(batch_size).prefetch(5)
    # Dataset generates trajectories with shape [Bx2x...]
    iterator = iter(dataset)

    def train_step():
      experience, _ = next(iterator)
      return tf_agent.train(experience)

    dataset_rev = replay_buffer_rev.as_dataset(
        sample_batch_size=batch_size, num_steps=2).unbatch().filter(
            _filter_invalid_transition).batch(batch_size).prefetch(5)
    # Dataset generates trajectories with shape [Bx2x...]
    iterator_rev = iter(dataset_rev)

    def train_step_rev():
      experience_rev, _ = next(iterator_rev)
      return tf_agent_rev.train(experience_rev)

    if use_tf_functions:
      train_step = common.function(train_step)
      train_step_rev = common.function(train_step_rev)

    # Manually save eval data for the plotting utilities.
    np_on_cns_save(os.path.join(eval_dir, 'eval_interval.npy'), eval_interval)
    try:
      average_eval_return = np_on_cns_load(
          os.path.join(eval_dir, 'average_eval_return.npy')).tolist()
      average_eval_success = np_on_cns_load(
          os.path.join(eval_dir, 'average_eval_success.npy')).tolist()
    except Exception:  # pylint: disable=broad-except
      # No previously saved eval curves; start fresh.
      average_eval_return = []
      average_eval_success = []

    logging.info('initialization_time: %.3f sec', time.time() - start_time)
    for iter_idx in range(num_iterations):
      start_time = time.time()
      if tf_env.pyenv.envs[0]._forward_or_reset_goal:
        time_step, policy_state = collect_driver.run(
            time_step=time_step,
            policy_state=policy_state,
        )
      else:
        time_step, policy_state = collect_driver_rev.run(
            time_step=time_step,
            policy_state=policy_state,
        )

      # reset goal generator updates
      if FLAGS.use_reset_goals and iter_idx % (
          FLAGS.reset_goal_frequency * collect_steps_per_iteration) == 0:
        reset_candidates_from_forward_buffer = replay_buffer.get_next(
            sample_batch_size=FLAGS.num_reset_candidates // 2)[0]
        reset_candidates_from_reverse_buffer = replay_buffer_rev.get_next(
            sample_batch_size=FLAGS.num_reset_candidates // 2)[0]
        flat_forward_tensors = tf.nest.flatten(
            reset_candidates_from_forward_buffer)
        flat_reverse_tensors = tf.nest.flatten(
            reset_candidates_from_reverse_buffer)
        concatenated_tensors = [
            tf.concat([x, y], axis=0)
            for x, y in zip(flat_forward_tensors, flat_reverse_tensors)
        ]
        reset_candidates = tf.nest.pack_sequence_as(
            reset_candidates_from_forward_buffer, concatenated_tensors)
        tf_env.pyenv.envs[0].set_reset_candidates(reset_candidates)
        if reset_lagrange_learning_rate > 0:
          reset_goal_generator.update_lagrange_multipliers()

      for _ in range(train_steps_per_iteration):
        train_loss_rev = train_step_rev()
        train_loss = train_step()

      time_acc += time.time() - start_time

      global_step_val = global_step.numpy()

      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val, train_loss.loss)
        logging.info('step = %d, loss_rev = %f', global_step_val,
                     train_loss_rev.loss)
        steps_per_sec = (global_step_val - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = global_step_val
        time_acc = 0

      for train_metric in train_metrics:
        if 'Heatmap' in train_metric.name:
          if global_step_val % summary_interval == 0:
            train_metric.tf_summaries(
                train_step=global_step, step_metrics=train_metrics[:2])
        else:
          train_metric.tf_summaries(
              train_step=global_step, step_metrics=train_metrics[:2])

      for train_metric in train_metrics_rev:
        if 'Heatmap' in train_metric.name:
          if global_step_val % summary_interval == 0:
            train_metric.tf_summaries(
                train_step=global_step, step_metrics=train_metrics_rev[:2])
        else:
          train_metric.tf_summaries(
              train_step=global_step, step_metrics=train_metrics_rev[:2])

      if global_step_val % summary_interval == 0 and FLAGS.use_reset_goals:
        reset_goal_generator.update_summaries(step_counter=global_step)

      if global_step_val % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step_val)
        metric_utils.log_metrics(eval_metrics)

        # numpy saves for plotting
        average_eval_return.append(results['AverageReturn'].numpy())
        average_eval_success.append(results['EvalSuccessfulEpisodes'].numpy())
        np_on_cns_save(
            os.path.join(eval_dir, 'average_eval_return.npy'),
            average_eval_return)
        np_on_cns_save(
            os.path.join(eval_dir, 'average_eval_success.npy'),
            average_eval_success)

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)
        train_checkpointer_rev.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)
        rb_checkpointer_rev.save(global_step=global_step_val)

      if global_step_val % video_record_interval == 0:
        for video_idx in range(num_videos):
          video_name = os.path.join(video_dir, str(global_step_val),
                                    'video_' + str(video_idx) + '.mp4')
          record_video(
              lambda: env_load_fn(  # pylint: disable=g-long-lambda
                  name=env_name,
                  max_episode_steps=max_episode_steps)[0],
              video_name,
              eval_py_policy,
              max_episode_length=eval_episode_steps)

    return train_loss
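# --- Hedged illustration (not part of the original example) ---
# A minimal, self-contained sketch of the nest-merging pattern used twice above
# to combine reset candidates sampled from the forward and reverse replay
# buffers: flatten both trajectory nests, concatenate element-wise along the
# batch dimension, and pack the result back into the original structure. The
# toy dicts below stand in for the replay-buffer samples.
import tensorflow as tf

forward_batch = {'observation': tf.zeros([2, 3]), 'reward': tf.zeros([2])}
reverse_batch = {'observation': tf.ones([2, 3]), 'reward': tf.ones([2])}

flat_forward = tf.nest.flatten(forward_batch)
flat_reverse = tf.nest.flatten(reverse_batch)
merged_flat = [
    tf.concat([f, r], axis=0)
    for f, r in zip(flat_forward, flat_reverse)
]
merged = tf.nest.pack_sequence_as(forward_batch, merged_flat)
# merged['observation'].shape == (4, 3); merged['reward'].shape == (4,)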
Example #26
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        # Training params
        initial_collect_steps=10000,
        num_iterations=3200000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Agent params
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        gamma=0.99,
        target_update_tau=0.005,
        target_update_period=1,
        reward_scale_factor=0.1,
        # Replay params
        reverb_port=None,
        replay_capacity=1000000,
        # Others
        # Defaults to not checkpointing saved policy. If you wish to enable this,
        # please note the caveat explained in README.md.
        policy_save_interval=-1,
        eval_interval=10000,
        eval_episodes=30,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """Trains and evaluates SAC."""
    logging.info('Training SAC on: %s', env_name)
    collect_env = suite_mujoco.load(env_name)
    eval_env = suite_mujoco.load(env_name)

    observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
        spec_utils.get_tensor_specs(collect_env))

    train_step = train_utils.create_train_step()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_tensor_spec,
        action_tensor_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=tanh_normal_projection_network.
        TanhNormalProjectionNetwork)
    critic_net = critic_network.CriticNetwork(
        (observation_tensor_spec, action_tensor_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        kernel_initializer='glorot_uniform',
        last_kernel_initializer='glorot_uniform')

    agent = sac_agent.SacAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=None,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step)
    agent.initialize()

    table_name = 'uniform_table'
    table = reverb.Table(table_name,
                         max_size=replay_capacity,
                         sampler=reverb.selectors.Uniform(),
                         remover=reverb.selectors.Fifo(),
                         rate_limiter=reverb.rate_limiters.MinSize(1))

    reverb_server = reverb.Server([table], port=reverb_port)
    reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
        agent.collect_data_spec,
        sequence_length=2,
        table_name=table_name,
        local_server=reverb_server)
    rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
        reverb_replay.py_client,
        table_name,
        sequence_length=2,
        stride_length=1)

    dataset = reverb_replay.as_dataset(sample_batch_size=batch_size,
                                       num_steps=2).prefetch(50)
    experience_dataset_fn = lambda: dataset

    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    env_step_metric = py_metrics.EnvironmentSteps()
    learning_triggers = [
        triggers.PolicySavedModelTrigger(
            saved_model_dir,
            agent,
            train_step,
            interval=policy_save_interval,
            metadata_metrics={triggers.ENV_STEP_METADATA_KEY:
                              env_step_metric}),
        triggers.StepPerSecondLogTrigger(train_step, interval=1000),
    ]

    agent_learner = learner.Learner(root_dir,
                                    train_step,
                                    agent,
                                    experience_dataset_fn,
                                    triggers=learning_triggers)

    random_policy = random_py_policy.RandomPyPolicy(
        collect_env.time_step_spec(), collect_env.action_spec())
    initial_collect_actor = actor.Actor(collect_env,
                                        random_policy,
                                        train_step,
                                        steps_per_run=initial_collect_steps,
                                        observers=[rb_observer])
    logging.info('Doing initial collect.')
    initial_collect_actor.run()

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy,
                                                        use_tf_function=True)

    collect_actor = actor.Actor(collect_env,
                                collect_policy,
                                train_step,
                                steps_per_run=1,
                                metrics=actor.collect_metrics(10),
                                summary_dir=os.path.join(
                                    root_dir, learner.TRAIN_DIR),
                                observers=[rb_observer, env_step_metric])

    tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
    eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_greedy_policy, use_tf_function=True)

    eval_actor = actor.Actor(
        eval_env,
        eval_greedy_policy,
        train_step,
        episodes_per_run=eval_episodes,
        metrics=actor.eval_metrics(eval_episodes),
        summary_dir=os.path.join(root_dir, 'eval'),
    )

    if eval_interval:
        logging.info('Evaluating.')
        eval_actor.run_and_log()

    logging.info('Training.')
    for _ in range(num_iterations):
        collect_actor.run()
        agent_learner.run(iterations=1)

        if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
            logging.info('Evaluating.')
            eval_actor.run_and_log()

    rb_observer.close()
    reverb_server.stop()
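# --- Hedged illustration (not part of the original example) ---
# One plausible entry point for driving the train_eval above with absl flags.
# The flag names and defaults here are illustrative assumptions, not taken from
# the original script.
from absl import app
from absl import flags
from absl import logging

FLAGS = flags.FLAGS
flags.DEFINE_string('root_dir', None, 'Root directory for checkpoints/summaries.')
flags.DEFINE_string('env_name', 'HalfCheetah-v2', 'MuJoCo environment to train on.')
flags.DEFINE_integer('reverb_port', None, 'Port for the Reverb server (None = auto).')


def main(_):
    logging.set_verbosity(logging.INFO)
    train_eval(FLAGS.root_dir,
               env_name=FLAGS.env_name,
               reverb_port=FLAGS.reverb_port)


if __name__ == '__main__':
    flags.mark_flag_as_required('root_dir')
    app.run(main)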
Example #27
def train_eval(
        root_dir,
        strategy: tf.distribute.Strategy,
        env_name='HalfCheetah-v2',
        # Training params
        initial_collect_steps=10000,
        num_iterations=3200000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Agent params
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        gamma=0.99,
        target_update_tau=0.005,
        target_update_period=1,
        reward_scale_factor=0.1,
        # Replay params
        reverb_port=None,
        replay_capacity=1000000,
        # Others
        policy_save_interval=10000,
        replay_buffer_save_interval=100000,
        eval_interval=10000,
        eval_episodes=30,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """Trains and evaluates SAC."""
    logging.info('Training SAC on: %s', env_name)
    collect_env = suite_mujoco.load(env_name)
    eval_env = suite_mujoco.load(env_name)

    _, action_tensor_spec, time_step_tensor_spec = (
        spec_utils.get_tensor_specs(collect_env))

    actor_net = create_sequential_actor_network(
        actor_fc_layers=actor_fc_layers, action_tensor_spec=action_tensor_spec)

    critic_net = create_sequential_critic_network(
        obs_fc_layer_units=critic_obs_fc_layers,
        action_fc_layer_units=critic_action_fc_layers,
        joint_fc_layer_units=critic_joint_fc_layers)

    with strategy.scope():
        train_step = train_utils.create_train_step()
        agent = sac_agent.SacAgent(
            time_step_tensor_spec,
            action_tensor_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.keras.optimizers.Adam(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.keras.optimizers.Adam(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.keras.optimizers.Adam(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=tf.math.squared_difference,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=None,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=train_step)
        agent.initialize()

    table_name = 'uniform_table'
    table = reverb.Table(table_name,
                         max_size=replay_capacity,
                         sampler=reverb.selectors.Uniform(),
                         remover=reverb.selectors.Fifo(),
                         rate_limiter=reverb.rate_limiters.MinSize(1))

    reverb_checkpoint_dir = os.path.join(root_dir, learner.TRAIN_DIR,
                                         learner.REPLAY_BUFFER_CHECKPOINT_DIR)
    reverb_checkpointer = reverb.platform.checkpointers_lib.DefaultCheckpointer(
        path=reverb_checkpoint_dir)
    reverb_server = reverb.Server([table],
                                  port=reverb_port,
                                  checkpointer=reverb_checkpointer)
    reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
        agent.collect_data_spec,
        sequence_length=2,
        table_name=table_name,
        local_server=reverb_server)
    rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
        reverb_replay.py_client,
        table_name,
        sequence_length=2,
        stride_length=1)

    def experience_dataset_fn():
        return reverb_replay.as_dataset(sample_batch_size=batch_size,
                                        num_steps=2).prefetch(50)

    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    env_step_metric = py_metrics.EnvironmentSteps()
    learning_triggers = [
        triggers.PolicySavedModelTrigger(
            saved_model_dir,
            agent,
            train_step,
            interval=policy_save_interval,
            metadata_metrics={triggers.ENV_STEP_METADATA_KEY:
                              env_step_metric}),
        triggers.ReverbCheckpointTrigger(
            train_step,
            interval=replay_buffer_save_interval,
            reverb_client=reverb_replay.py_client),
        # TODO(b/165023684): Add SIGTERM handler to checkpoint before preemption.
        triggers.StepPerSecondLogTrigger(train_step, interval=1000),
    ]

    agent_learner = learner.Learner(root_dir,
                                    train_step,
                                    agent,
                                    experience_dataset_fn,
                                    triggers=learning_triggers,
                                    strategy=strategy)

    random_policy = random_py_policy.RandomPyPolicy(
        collect_env.time_step_spec(), collect_env.action_spec())
    initial_collect_actor = actor.Actor(collect_env,
                                        random_policy,
                                        train_step,
                                        steps_per_run=initial_collect_steps,
                                        observers=[rb_observer])
    logging.info('Doing initial collect.')
    initial_collect_actor.run()

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy,
                                                        use_tf_function=True)

    collect_actor = actor.Actor(collect_env,
                                collect_policy,
                                train_step,
                                steps_per_run=1,
                                metrics=actor.collect_metrics(10),
                                summary_dir=os.path.join(
                                    root_dir, learner.TRAIN_DIR),
                                observers=[rb_observer, env_step_metric])

    tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
    eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_greedy_policy, use_tf_function=True)

    eval_actor = actor.Actor(
        eval_env,
        eval_greedy_policy,
        train_step,
        episodes_per_run=eval_episodes,
        metrics=actor.eval_metrics(eval_episodes),
        summary_dir=os.path.join(root_dir, 'eval'),
    )

    if eval_interval:
        logging.info('Evaluating.')
        eval_actor.run_and_log()

    logging.info('Training.')
    for _ in range(num_iterations):
        collect_actor.run()
        agent_learner.run(iterations=1)

        if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
            logging.info('Evaluating.')
            eval_actor.run_and_log()

    rb_observer.close()
    reverb_server.stop()
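# --- Hedged illustration (not part of the original example) ---
# A minimal sketch of how a distribution strategy could be chosen for the
# train_eval variant above, which expects a tf.distribute.Strategy. The helper
# name and the use_gpu switch are illustrative assumptions.
import tensorflow as tf


def make_strategy(use_gpu=False):
    # MirroredStrategy replicates the agent's variables across all visible GPUs;
    # the default (no-op) strategy keeps everything on a single device.
    if use_gpu and tf.config.list_physical_devices('GPU'):
        return tf.distribute.MirroredStrategy()
    return tf.distribute.get_strategy()

# Usage sketch:
#   strategy = make_strategy(use_gpu=True)
#   train_eval('/tmp/sac_halfcheetah', strategy=strategy)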