def test_deterministic_dataset_from_heap_sampler_remover(self):
    """Builds a single-deterministic-pass dataset on a heap-based table.

    The table samples with a `MaxHeap` selector, removes with a `MinHeap`
    selector and sets `max_times_sampled=0`.  The test only checks that
    `as_dataset(single_deterministic_pass=True)` can be called on a replay
    buffer backed by that table without raising; nothing is read from the
    resulting dataset.
    """
    max_heap_sampler_min_heap_remover_table = reverb.Table(
        name=self._table_name,
        sampler=reverb.selectors.MaxHeap(),
        remover=reverb.selectors.MinHeap(),
        max_size=100,
        max_times_sampled=0,
        rate_limiter=reverb.rate_limiters.MinSize(1))
    server = reverb.Server([max_heap_sampler_min_heap_remover_table])
    try:
        replay = reverb_replay_buffer.ReverbReplayBuffer(
            self._data_spec,
            self._table_name,
            local_server=server,
            sequence_length=None)
        replay.as_dataset(single_deterministic_pass=True)
    finally:
        # Stop the local server even when buffer or dataset construction
        # raises, so a failing test does not leave the server running.
        server.stop()
  def test_capacity_set(self):
    """Checks that `replay.capacity` equals the table's `max_size`."""
    table_name = 'test_table'
    capacity = 100

    uniform_table = reverb.Table(
        table_name,
        max_size=capacity,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(3))
    server = reverb.Server([uniform_table])
    try:
      data_spec = tensor_spec.TensorSpec((), tf.float32)
      replay = reverb_replay_buffer.ReverbReplayBuffer(
          data_spec, table_name, local_server=server, sequence_length=None)

      self.assertEqual(capacity, replay.capacity)
    finally:
      # Stop the local server even when the assertion above fails.
      server.stop()
Exemple #3
0
    def test_experimental_distribute_datasets_from_function(self, strategy):
        """Samples through a distributed dataset and checks the reduced value.

        Inserts `sequence_length` steps, builds a per-replica dataset with
        `strategy.experimental_distribute_datasets_from_function`, and asserts
        that the SUM-reduced per-replica mean of the `step` observation equals
        `batch_size` on every call.

        Args:
          strategy: the `tf.distribute` strategy under test.
        """
        sequence_length = 3
        batch_size = 10
        self._insert_random_data(
            self._env,
            num_steps=sequence_length,
            sequence_length=sequence_length)

        replay = reverb_replay_buffer.ReverbReplayBuffer(
            self._data_spec,
            self._table_name,
            sequence_length=sequence_length,
            local_server=self._server)

        # Each replica draws an equal share of the global batch.
        per_replica_batch_size = batch_size // strategy.num_replicas_in_sync

        def make_dataset(unused_input_context):
            return replay.as_dataset(per_replica_batch_size)

        with strategy.scope():
            distributed_dataset = (
                strategy.experimental_distribute_datasets_from_function(
                    make_dataset))
            iterator = iter(distributed_dataset)

        def mean_over_last_axis(values):
            return tf.reduce_mean(values, axis=-1)

        @common.function()
        def train_step():
            with strategy.scope():
                sample, _ = next(iterator)
                _, step = sample.observation
                per_replica = strategy.run(mean_over_last_axis, args=(step, ))
                return strategy.reduce(
                    tf.distribute.ReduceOp.SUM, per_replica, axis=0)

        def check_five_calls(step_fn):
            for _ in range(5):
                with strategy.scope():
                    loss = step_fn()
                    self.assertEqual(batch_size, loss)

        # First call `train_step` as defined above (it is already decorated
        # with `common.function()`).
        check_five_calls(train_step)

        # Then call it again through one more `common.function` wrapper.
        check_five_calls(common.function(train_step))
  def test_dataset_with_preprocess(self):
    """Checks that `sequence_preprocess_fn` is applied to sampled sequences.

    Two-step samples are read twice: once without preprocessing, where the
    first step index is asserted to be even, and once with a preprocess
    function that adds one to every step index, where the first step index is
    asserted to be odd.
    """

    def check_against_array_spec(traj):
      if array_spec.check_arrays_nest(traj, self._array_data_spec):
        return
      raise ValueError('Trajectory incompatible with array_data_spec')

    def add_one_to_step(traj):
      episode, step = traj.observation
      shifted_observation = (episode, step + 1)
      return traj.replace(observation=shifted_observation)

    # Ten env steps are written in sequences of four; the sampled `num_steps`
    # below is a separate setting.
    self._insert_random_data(
        self._env,
        num_steps=10,
        additional_observers=[check_against_array_spec],
        sequence_length=4)

    replay = reverb_replay_buffer.ReverbReplayBuffer(
        self._data_spec,
        self._table_name,
        local_server=self._server,
        sequence_length=4)

    # Without preprocessing: unbatched pairs going from an even to an odd step.
    raw_dataset = replay.as_dataset(num_steps=2)
    for sample, _ in raw_dataset.take(5):
      episode, step = sample.observation
      self.assertEqual(episode[0], episode[1])
      self.assertEqual(step[0] + 1, step[1])
      self.assertEqual(0, step[0].numpy() % 2)
      self.assertEqual(1, step[1].numpy() % 2)

    # With preprocessing (batch of one): every step index is shifted by one,
    # so pairs now go from an odd to an even step.
    preprocessed_dataset = replay.as_dataset(
        num_steps=2,
        sample_batch_size=1,
        sequence_preprocess_fn=add_one_to_step)
    for sample, _ in preprocessed_dataset.take(5):
      episode, step = sample.observation
      self.assertEqual(episode[0, 0], episode[0, 1])
      self.assertEqual(step[0, 0] + 1, step[0, 1])
      self.assertEqual(1, step[0, 0].numpy() % 2)
      self.assertEqual(0, step[0, 1].numpy() % 2)
Exemple #5
0
    def test_single_episode_dataset(self):
        """Reads samples with `sequence_length=None` and checks their layout.

        After inserting `sequence_length` steps, every sample is expected to
        have `sequence_length` entries that share one episode id and whose
        step indices increase by one from the first entry.
        """
        sequence_length = 3
        self._insert_random_data(
            self._env,
            num_steps=sequence_length,
            sequence_length=sequence_length)

        replay = reverb_replay_buffer.ReverbReplayBuffer(
            self._data_spec,
            self._table_name,
            sequence_length=None,
            local_server=self._server)

        expected_shape = (sequence_length, )
        same_episode = [0] * sequence_length
        consecutive_steps = list(range(sequence_length))

        # Values are compared relative to the first entry of each sample, so
        # only the offsets within a sample matter, not the absolute counters.
        for sample, _ in replay.as_dataset().take(5):
            episode, step = sample.observation
            self.assertEqual(expected_shape, episode.shape)
            self.assertEqual(expected_shape, step.shape)
            self.assertAllEqual(same_episode, episode - episode[:1])
            self.assertAllEqual(consecutive_steps, step - step[:1])
    def test_uniform_table(self):
        """Samples 1000 times from a uniform table holding three items.

        Items 0, 1 and 2 are written with equal priority; each one must be
        drawn more than 200 times.
        """
        # NOTE(review): the server created here is never stopped, unlike in
        # the tests that call `server.stop()` - confirm this is intentional.
        table_name = 'test_uniform_table'
        uniform_table = reverb.Table(
            table_name,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=1000,
            rate_limiter=reverb.rate_limiters.MinSize(3))
        reverb_server = reverb.Server([uniform_table])
        data_spec = tensor_spec.TensorSpec((), dtype=tf.int64)
        replay = reverb_replay_buffer.ReverbReplayBuffer(
            data_spec,
            table_name,
            local_server=reverb_server,
            sequence_length=1,
            dataset_buffer_size=1)

        # Write each value as its own one-step item.
        with replay.py_client.trajectory_writer(
                num_keep_alive_refs=1) as writer:
            for value in range(3):
                writer.append(value)
                writer.create_item(
                    table_name, trajectory=writer.history[-1:], priority=1)

        dataset = replay.as_dataset(
            sample_batch_size=1, num_steps=None, num_parallel_calls=1)
        iterator = iter(dataset)

        counts = [0] * 3
        for _ in range(1000):
            # The sampled data is a 1x1 matrix holding the item value.
            sampled = next(iterator)[0].numpy()
            counts[int(sampled)] += 1

        # The threshold of 200 is kept well below the expected count so the
        # test does not become flaky.
        for count in counts:
            self.assertGreater(count, 200)
  def test_dataset_samples_sequential(self, sequence_length):
    """Checks that two-step samples are consecutive steps of one episode.

    Args:
      sequence_length: sequence length given to the replay buffer; when it is
        falsy the data is still written with a sequence length of 4.
    """

    def check_against_array_spec(traj):
      if array_spec.check_arrays_nest(traj, self._array_data_spec):
        return
      raise ValueError('Trajectory incompatible with array_data_spec')

    # Twenty env steps are written; the sampled `num_steps` below is a
    # separate setting.
    write_length = sequence_length or 4
    self._insert_random_data(
        self._env,
        num_steps=20,
        additional_observers=[check_against_array_spec],
        sequence_length=write_length)

    replay = reverb_replay_buffer.ReverbReplayBuffer(
        self._data_spec,
        self._table_name,
        local_server=self._server,
        sequence_length=sequence_length)

    # Both entries of a sample must share the episode id, and the second step
    # index must follow the first one directly.
    dataset = replay.as_dataset(num_steps=2)
    for sample, _ in dataset.take(100):
      episode, step = sample.observation
      self.assertEqual(episode[0], episode[1])
      self.assertEqual(step[0] + 1, step[1])
    def test_prioritized_table_max_sample(self):
        """Checks that `max_times_sampled` caps how often each item is drawn.

        Three items with priorities 0, 1 and 2 are written to a table with
        `max_times_sampled=10`.  After ten batches of three, every item must
        have been seen exactly ten times and the table must no longer be able
        to serve three samples.
        """
        table_name = 'test_prioritized_table'
        table = reverb.Table(
            table_name,
            sampler=reverb.selectors.Prioritized(1.0),
            remover=reverb.selectors.Fifo(),
            max_times_sampled=10,
            rate_limiter=reverb.rate_limiters.MinSize(1),
            max_size=3)
        reverb_server = reverb.Server([table])
        data_spec = tensor_spec.TensorSpec((), dtype=tf.int64)
        replay = reverb_replay_buffer.ReverbReplayBuffer(
            data_spec,
            table_name,
            sequence_length=1,
            local_server=reverb_server,
            dataset_buffer_size=1)

        # Item `value` is written with priority `value`.
        with replay.py_client.trajectory_writer(1) as writer:
            for value in range(3):
                writer.append(value)
                last_step = writer.history[-1:]
                writer.create_item(
                    table_name, trajectory=last_step, priority=value)

        dataset = replay.as_dataset(sample_batch_size=3, num_parallel_calls=3)

        self.assertTrue(table.can_sample(3))
        iterator = iter(dataset)
        counts = [0] * 3
        for _ in range(10):
            # Each batch holds three sampled item values.
            batch = next(iterator)[0].numpy()
            for item in batch:
                counts[int(item)] += 1
        self.assertFalse(table.can_sample(3))

        # The cap on `max_times_sampled` makes all counts equal, whatever the
        # priority of the item (0, 1 and 2 respectively).
        for count in counts:
            self.assertEqual(count, 10)
    def test_prioritized_table(self):
        """Samples 1000 times from a prioritized table holding three items.

        Items 0, 1 and 2 are written with priorities 0, 1 and 2.  The item
        with priority 0 must never be drawn, and the other two must be drawn
        more than 250 and more than 600 times respectively.
        """
        table_name = 'test_prioritized_table'
        prioritized_table = reverb.Table(
            table_name,
            sampler=reverb.selectors.Prioritized(1.0),
            remover=reverb.selectors.Fifo(),
            rate_limiter=reverb.rate_limiters.MinSize(1),
            max_size=3)
        reverb_server = reverb.Server([prioritized_table])
        data_spec = tensor_spec.TensorSpec((), dtype=tf.int64)
        replay = reverb_replay_buffer.ReverbReplayBuffer(
            data_spec,
            table_name,
            sequence_length=1,
            local_server=reverb_server,
            dataset_buffer_size=1)

        # Use the trajectory writer, like the other table tests in this file,
        # instead of the legacy `writer(max_sequence_length=...)` API.  Each
        # value becomes its own one-step item with priority equal to the value.
        with replay.py_client.trajectory_writer(
                num_keep_alive_refs=1) as writer:
            for i in range(3):
                writer.append(i)
                writer.create_item(table_name,
                                   trajectory=writer.history[-1:],
                                   priority=i)

        dataset = replay.as_dataset(sample_batch_size=1,
                                    num_steps=None,
                                    num_parallel_calls=None)

        iterator = iter(dataset)
        counts = [0] * 3
        for _ in range(1000):
            item_0 = next(iterator)[0].numpy()  # This is a matrix shaped 1x1.
            counts[int(item_0)] += 1

        self.assertEqual(counts[0], 0)  # priority 0
        self.assertGreater(counts[1], 250)  # priority 1
        self.assertGreater(counts[2], 600)  # priority 2
Exemple #10
0
    def test_batched_episodes_dataset(self, sequence_length):
        """Checks the shape and content of batched full-episode samples.

        Args:
          sequence_length: number of steps per episode of the counting env,
            also used as the sequence length when inserting data.
        """
        batch_size = 3
        # Writing batch_size * sequence_length steps is meant to provide at
        # least three episodes.
        env = test_envs.EpisodeCountingEnv(steps_per_episode=sequence_length)
        self._insert_random_data(
            env,
            num_steps=batch_size * sequence_length,
            sequence_length=sequence_length)

        replay = reverb_replay_buffer.ReverbReplayBuffer(
            self._data_spec,
            self._table_name,
            sequence_length=None,
            local_server=self._server)

        expected_shape = (batch_size, sequence_length)
        for sample, _ in replay.as_dataset(batch_size).take(5):
            episode, step = sample.observation
            self.assertEqual(expected_shape, episode.shape)
            self.assertEqual(expected_shape, step.shape)
            for offset in range(sequence_length):
                # Within each row, position `offset` carries the same episode
                # id as position 0.
                self.assertAllEqual(episode[:, 0], episode[:, offset])
                # Within each row, the step index at position `offset` is the
                # first step index plus `offset`.
                self.assertAllEqual(step[:, 0] + offset, step[:, offset])
Exemple #11
0
def train(
    root_dir: Text,
    environment_name: Text,
    strategy: tf.distribute.Strategy,
    replay_buffer_server_address: Text,
    variable_container_server_address: Text,
    suite_load_fn: Callable[[Text],
                            py_environment.PyEnvironment] = suite_mujoco.load,
    # Training params
    learning_rate: float = 3e-4,
    batch_size: int = 256,
    num_iterations: int = 2000000,
    learner_iterations_per_call: int = 1) -> None:
  """Runs the learner loop against remote replay and variable servers.

  The log line and the learner variable label the agent as SAC; the agent
  itself comes from `_create_agent`.

  Args:
    root_dir: directory used for the saved policy and given to the learner.
    environment_name: name handed to `suite_load_fn` to load the environment
      whose specs are used to build the agent.
    strategy: distribution strategy under whose scope the train step, the
      agent and the dataset are created.
    replay_buffer_server_address: address of the replay buffer server.
    variable_container_server_address: address of the variable container
      server.
    suite_load_fn: callable that loads an environment from its name.
    learning_rate: learning rate forwarded to `_create_agent`.
    batch_size: sample batch size of the experience dataset.
    num_iterations: the loop runs while the train step is below this value.
    learner_iterations_per_call: learner iterations per loop pass.
  """
  logging.info('Training SAC with learning rate: %f', learning_rate)

  # Derive the tensor specs from a freshly loaded environment.
  env = suite_load_fn(environment_name)
  specs = spec_utils.get_tensor_specs(env)
  observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = specs

  # The train step counter and the agent are created under the strategy.
  with strategy.scope():
    train_step = train_utils.create_train_step()
    agent = _create_agent(
        train_step=train_step,
        observation_tensor_spec=observation_tensor_spec,
        action_tensor_spec=action_tensor_spec,
        time_step_tensor_spec=time_step_tensor_spec,
        learning_rate=learning_rate)

  # Policy saver trigger, checkpointing the policy every 1000 train steps.
  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  save_model_trigger = triggers.PolicySavedModelTrigger(
      saved_model_dir, agent, train_step, interval=1000)

  # Publish the initial policy variables and train step.
  variables = {
      reverb_variable_container.POLICY_KEY: agent.collect_policy.variables(),
      reverb_variable_container.TRAIN_STEP_KEY: train_step,
  }
  variable_container = reverb_variable_container.ReverbVariableContainer(
      variable_container_server_address,
      table_names=[reverb_variable_container.DEFAULT_TABLE])
  variable_container.push(variables)

  # Replay buffer client pointing at the remote replay server.
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=reverb_replay_buffer.DEFAULT_TABLE,
      server_address=replay_buffer_server_address)

  def experience_dataset_fn():
    # The dataset is built under the strategy scope.
    with strategy.scope():
      dataset = reverb_replay.as_dataset(
          sample_batch_size=batch_size, num_steps=2)
      return dataset.prefetch(3)

  step_per_second_trigger = triggers.StepPerSecondLogTrigger(
      train_step, interval=1000)
  learning_triggers = [save_model_trigger, step_per_second_trigger]
  sac_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers,
      strategy=strategy)

  # Train, pushing the variables after every learner call.
  while train_step.numpy() < num_iterations:
    sac_learner.run(iterations=learner_iterations_per_call)
    variable_container.push(variables)
Exemple #12
0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        # Training params
        initial_collect_steps=1000,
        num_iterations=100000,
        fc_layer_params=(100, ),
        # Agent params
        epsilon_greedy=0.1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        target_update_tau=0.05,
        target_update_period=5,
        reward_scale_factor=1.0,
        # Replay params
        reverb_port=None,
        replay_capacity=100000,
        # Others
        policy_save_interval=1000,
        eval_interval=1000,
        eval_episodes=10):
    """Trains and evaluates DQN.

    Builds a DQN agent with a dense Q-network and a local Reverb server with a
    uniform table, runs an initial random collect, then alternates one collect
    step with one learner iteration.  The replay observer is closed and the
    server is stopped when the function exits, including on error.

    Args:
      root_dir: directory for the saved policy, learner output and summaries.
      env_name: name passed to `suite_gym.load` for both environments.
      initial_collect_steps: steps collected with a random policy up front.
      num_iterations: number of collect/train iterations.
      fc_layer_params: units of each hidden Dense layer of the Q-network.
      epsilon_greedy: epsilon forwarded to the agent.
      batch_size: sample batch size of the experience dataset.
      learning_rate: learning rate of the Adam optimizer.
      n_step_update: forwarded to the agent.
      gamma: forwarded to the agent.
      target_update_tau: forwarded to the agent.
      target_update_period: forwarded to the agent.
      reward_scale_factor: forwarded to the agent.
      reverb_port: port given to the local Reverb server.
      replay_capacity: `max_size` of the replay table.
      policy_save_interval: interval of the policy saved-model trigger.
      eval_interval: evaluate when the train step is a multiple of this value;
        a falsy value disables evaluation.
      eval_episodes: episodes per evaluation run.
    """
    collect_env = suite_gym.load(env_name)
    eval_env = suite_gym.load(env_name)

    time_step_tensor_spec = tensor_spec.from_spec(collect_env.time_step_spec())
    action_tensor_spec = tensor_spec.from_spec(collect_env.action_spec())

    train_step = train_utils.create_train_step()
    num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

    # Define a helper function to create Dense layers configured with the right
    # activation and kernel initializer.
    def dense_layer(num_units):
        return tf.keras.layers.Dense(
            num_units,
            activation=tf.keras.activations.relu,
            kernel_initializer=tf.keras.initializers.VarianceScaling(
                scale=2.0, mode='fan_in', distribution='truncated_normal'))

    # QNetwork consists of a sequence of Dense layers followed by a dense layer
    # with `num_actions` units to generate one q_value per available action as
    # its output.
    dense_layers = [dense_layer(num_units) for num_units in fc_layer_params]
    q_values_layer = tf.keras.layers.Dense(
        num_actions,
        activation=None,
        kernel_initializer=tf.keras.initializers.RandomUniform(minval=-0.03,
                                                               maxval=0.03),
        bias_initializer=tf.keras.initializers.Constant(-0.2))
    q_net = sequential.Sequential(dense_layers + [q_values_layer])

    agent = dqn_agent.DqnAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        q_network=q_net,
        epsilon_greedy=epsilon_greedy,
        n_step_update=n_step_update,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        td_errors_loss_fn=common.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        train_step_counter=train_step)

    table_name = 'uniform_table'
    table = reverb.Table(table_name,
                         max_size=replay_capacity,
                         sampler=reverb.selectors.Uniform(),
                         remover=reverb.selectors.Fifo(),
                         rate_limiter=reverb.rate_limiters.MinSize(1))
    reverb_server = reverb.Server([table], port=reverb_port)
    rb_observer = None
    try:
        reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
            agent.collect_data_spec,
            sequence_length=2,
            table_name=table_name,
            local_server=reverb_server)
        rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
            reverb_replay.py_client,
            table_name,
            sequence_length=2,
            stride_length=1)

        dataset = reverb_replay.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)
        experience_dataset_fn = lambda: dataset

        saved_model_dir = os.path.join(root_dir,
                                       learner.POLICY_SAVED_MODEL_DIR)
        env_step_metric = py_metrics.EnvironmentSteps()

        learning_triggers = [
            triggers.PolicySavedModelTrigger(
                saved_model_dir,
                agent,
                train_step,
                interval=policy_save_interval,
                metadata_metrics={triggers.ENV_STEP_METADATA_KEY:
                                  env_step_metric}),
            triggers.StepPerSecondLogTrigger(train_step, interval=100),
        ]

        dqn_learner = learner.Learner(root_dir,
                                      train_step,
                                      agent,
                                      experience_dataset_fn,
                                      triggers=learning_triggers)

        # If we haven't trained yet make sure we collect some random samples
        # first to fill up the Replay Buffer with some experience.
        random_policy = random_py_policy.RandomPyPolicy(
            collect_env.time_step_spec(), collect_env.action_spec())
        initial_collect_actor = actor.Actor(
            collect_env,
            random_policy,
            train_step,
            steps_per_run=initial_collect_steps,
            observers=[rb_observer])
        logging.info('Doing initial collect.')
        initial_collect_actor.run()

        tf_collect_policy = agent.collect_policy
        collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
            tf_collect_policy, use_tf_function=True)

        collect_actor = actor.Actor(
            collect_env,
            collect_policy,
            train_step,
            steps_per_run=1,
            observers=[rb_observer, env_step_metric],
            metrics=actor.collect_metrics(10),
            summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
        )

        tf_greedy_policy = agent.policy
        greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
            tf_greedy_policy, use_tf_function=True)

        eval_actor = actor.Actor(
            eval_env,
            greedy_policy,
            train_step,
            episodes_per_run=eval_episodes,
            metrics=actor.eval_metrics(eval_episodes),
            summary_dir=os.path.join(root_dir, 'eval'),
        )

        if eval_interval:
            logging.info('Evaluating.')
            eval_actor.run_and_log()

        logging.info('Training.')
        for _ in range(num_iterations):
            collect_actor.run()
            dqn_learner.run(iterations=1)

            if (eval_interval
                    and dqn_learner.train_step_numpy % eval_interval == 0):
                logging.info('Evaluating.')
                eval_actor.run_and_log()
    finally:
        # Release the observer and the local server even when collection or
        # training raises, so a failed run does not leave the server running.
        if rb_observer is not None:
            rb_observer.close()
        reverb_server.stop()
def get_reverb_buffer(data_spec,
                      sequence_length=None,
                      table_name='uniform_table',
                      table=None,
                      reverb_server_address=None,
                      port=None,
                      replay_capacity=1000,
                      min_size_limiter_size=1):
    """Returns a Reverb replay buffer backed by a local or a remote server.

    When `reverb_server_address` is set, the buffer connects to that remote
    server and `table`, `port`, `replay_capacity` and `min_size_limiter_size`
    are not used.  Otherwise a local server is started, holding `table` or,
    if `table` is None, a table built by `_create_uniform_table`.

    Args:
      data_spec: spec of the data elements to be stored in the replay buffer.
      sequence_length: integer specifying the sequence length used to write
        to the given table.
      table_name: name of the table to create.
      table: optional table for the backing local server. If None, a table is
        created with `_create_uniform_table`.
      reverb_server_address: address of the remote reverb server; if None a
        local server is created.
      port: port to launch the local server on.
      replay_capacity: optional capacity of the automatically created table
        (only used when `table` is None and the server is local).
      min_size_limiter_size: optional minimum number of items required in the
        automatically created table before sampling can begin (only used when
        `table` is None and the server is local).

    Returns:
      The Reverb replay buffer instance.

      Note: a locally created server is not returned separately; it is handed
      to the replay buffer as `local_server` and is presumably reachable
      through the returned buffer - confirm the accessor against
      `ReverbReplayBuffer`.
    """
    table_signature = tensor_spec.add_outer_dim(data_spec, sequence_length)

    # Remote server: just connect to it.
    if reverb_server_address is not None:
        return reverb_replay_buffer.ReverbReplayBuffer(
            data_spec,
            sequence_length=sequence_length,
            table_name=table_name,
            server_address=reverb_server_address)

    # Local server: fall back to a default table when none was supplied.
    if table is None:
        table = _create_uniform_table(
            table_name,
            table_signature,
            table_capacity=replay_capacity,
            min_size_limiter_size=min_size_limiter_size)

    local_server = reverb.Server([table], port=port)
    return reverb_replay_buffer.ReverbReplayBuffer(
        data_spec,
        sequence_length=sequence_length,
        table_name=table_name,
        local_server=local_server)
Exemple #14
0
def train_agent(iterations, modeldir, logdir, policydir):
    """Train and convert the model using TF Agents.

    This is a template: the environments, the agent (`tf_agent`, `eval_env`)
    and the collect/train step are left to the TODO markers below.

    Args:
      iterations: number of passes through the training loop.
      modeldir: not used by the code visible in this function.
      logdir: directory for the TensorBoard summary writer.
      policydir: directory the policy saver writes to after the loop.
    """

    # TODO: add code to instantiate the training and evaluation environments

    # TODO: add code to create a reinforcement learning agent that is going to be trained

    tf_agent.initialize()

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    tf_policy_saver = policy_saver.PolicySaver(collect_policy)

    # Use reverb as replay buffer
    replay_buffer_signature = tensor_spec.from_spec(tf_agent.collect_data_spec)
    table = reverb.Table(
        REPLAY_BUFFER_TABLE_NAME,
        max_size=REPLAY_BUFFER_CAPACITY,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1),
        signature=replay_buffer_signature,
    )  # specify signature here for validation at insertion time

    # NOTE(review): this server is never stopped in the visible code - confirm
    # whether the collect/train code should shut it down.
    reverb_server = reverb.Server([table])

    replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
        tf_agent.collect_data_spec,
        sequence_length=None,
        table_name=REPLAY_BUFFER_TABLE_NAME,
        local_server=reverb_server,
    )

    replay_buffer_observer = reverb_utils.ReverbAddEpisodeObserver(
        replay_buffer.py_client, REPLAY_BUFFER_TABLE_NAME,
        REPLAY_BUFFER_CAPACITY)

    # Optimize by wrapping some of the code in a graph using TF function.
    tf_agent.train = common.function(tf_agent.train)

    # Evaluate the agent's policy once before training.  The helper returns a
    # (return, episode length) pair, so unpack it rather than binding the
    # whole tuple to `avg_return`.
    avg_return, avg_episode_length = compute_avg_return_and_steps(
        eval_env, eval_policy, NUM_EVAL_EPISODES)

    summary_writer = tf.summary.create_file_writer(logdir)

    # The logger does not change between iterations, so fetch it once.
    logger = tf.get_logger()

    for i in range(iterations):
        # TODO: add code to collect game episodes and train the agent

        if i % EVAL_INTERVAL == 0:
            avg_return, avg_episode_length = compute_avg_return_and_steps(
                eval_env, eval_policy, NUM_EVAL_EPISODES)
            with summary_writer.as_default():
                tf.summary.scalar("Average return", avg_return, step=i)
                tf.summary.scalar("Average episode length",
                                  avg_episode_length,
                                  step=i)
                summary_writer.flush()
            logger.info(
                "iteration = {0}: Average Return = {1}, Average Episode Length = {2}"
                .format(i, avg_return, avg_episode_length))

    summary_writer.close()

    tf_policy_saver.save(policydir)
 def test_size_empty(self):
   """Checks that a buffer nothing was written to reports zero frames."""
   replay = reverb_replay_buffer.ReverbReplayBuffer(
       self._data_spec,
       self._table_name,
       local_server=self._server,
       sequence_length=None)
   frame_count = replay.num_frames()
   self.assertEqual(frame_count, 0)
Exemple #16
0
def train_eval(
    root_dir,
    env_name='Pong-v0',
    # Training params
    update_frequency=4,  # Number of collect steps per policy update
    initial_collect_steps=50000,  # 50k collect steps
    num_iterations=50000000,  # 50M collect steps
    # Taken from Rainbow as it's not specified in Mnih,15.
    max_episode_frames_collect=50000,  # env frames observed by the agent
    max_episode_frames_eval=108000,  # env frames observed by the agent
    # Agent params
    epsilon_greedy=0.1,
    epsilon_decay_period=250000,  # 1M collect steps / update_frequency
    batch_size=32,
    learning_rate=0.00025,
    n_step_update=1,
    gamma=0.99,
    target_update_tau=1.0,
    target_update_period=2500,  # 10k collect steps / update_frequency
    reward_scale_factor=1.0,
    # Replay params
    reverb_port=None,
    replay_capacity=1000000,
    # Others
    policy_save_interval=250000,
    eval_interval=1000,
    eval_episodes=30,
    debug_summaries=True):
  """Trains and evaluates DQN.

  Builds a DQN agent for an Atari environment and a local Reverb server with a
  uniform table, runs an initial random collect, then alternates
  `update_frequency` collect steps with one learner iteration.  The replay
  observer is closed and the server is stopped when the function exits,
  including on error.

  Args:
    root_dir: directory for the saved policy, learner output and summaries.
    env_name: name passed to `suite_atari.load` for both environments.
    update_frequency: collect steps per learner iteration.
    initial_collect_steps: steps collected with a random policy up front.
    num_iterations: number of collect/train iterations.
    max_episode_frames_collect: `max_episode_steps` of the collect env.
    max_episode_frames_eval: `max_episode_steps` of the eval env.
    epsilon_greedy: final value of the decayed epsilon.
    epsilon_decay_period: decay steps of the epsilon schedule.
    batch_size: sample batch size of the experience dataset.
    learning_rate: learning rate of the RMSProp optimizer.
    n_step_update: forwarded to the agent.
    gamma: forwarded to the agent.
    target_update_tau: forwarded to the agent.
    target_update_period: forwarded to the agent.
    reward_scale_factor: forwarded to the agent.
    reverb_port: port given to the local Reverb server.
    replay_capacity: `max_size` of the replay table.
    policy_save_interval: interval of the policy saved-model trigger.
    eval_interval: evaluate when the train step is a multiple of this value;
      a falsy value disables evaluation.
    eval_episodes: episodes per evaluation run.
    debug_summaries: forwarded to the agent.
  """

  collect_env = suite_atari.load(
      env_name,
      max_episode_steps=max_episode_frames_collect,
      gym_env_wrappers=suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)
  eval_env = suite_atari.load(
      env_name,
      max_episode_steps=max_episode_frames_eval,
      gym_env_wrappers=suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)

  unused_observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))

  train_step = train_utils.create_train_step()

  num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1
  # Epsilon decays from 1.0 to `epsilon_greedy` over `epsilon_decay_period`
  # train steps.
  epsilon = tf.compat.v1.train.polynomial_decay(
      1.0,
      train_step,
      epsilon_decay_period,
      end_learning_rate=epsilon_greedy)
  agent = dqn_agent.DqnAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      q_network=create_q_network(num_actions),
      epsilon_greedy=epsilon,
      n_step_update=n_step_update,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      optimizer=tf.compat.v1.train.RMSPropOptimizer(
          learning_rate=learning_rate,
          decay=0.95,
          momentum=0.95,
          epsilon=0.01,
          centered=True),
      td_errors_loss_fn=common.element_wise_huber_loss,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      train_step_counter=train_step,
      debug_summaries=debug_summaries)

  table_name = 'uniform_table'
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1))
  reverb_server = reverb.Server([table], port=reverb_port)
  rb_observer = None
  try:
    reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
        agent.collect_data_spec,
        sequence_length=2,
        table_name=table_name,
        local_server=reverb_server)
    rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
        reverb_replay.py_client, table_name,
        sequence_length=2,
        stride_length=1)

    dataset = reverb_replay.as_dataset(
        sample_batch_size=batch_size, num_steps=2).prefetch(3)
    experience_dataset_fn = lambda: dataset

    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    env_step_metric = py_metrics.EnvironmentSteps()

    learning_triggers = [
        triggers.PolicySavedModelTrigger(
            saved_model_dir,
            agent,
            train_step,
            interval=policy_save_interval,
            metadata_metrics={
                triggers.ENV_STEP_METADATA_KEY: env_step_metric
            }),
        triggers.StepPerSecondLogTrigger(train_step, interval=100),
    ]

    dqn_learner = learner.Learner(
        root_dir,
        train_step,
        agent,
        experience_dataset_fn,
        triggers=learning_triggers)

    # If we haven't trained yet make sure we collect some random samples first
    # to fill up the Replay Buffer with some experience.
    random_policy = random_py_policy.RandomPyPolicy(
        collect_env.time_step_spec(), collect_env.action_spec())
    initial_collect_actor = actor.Actor(
        collect_env,
        random_policy,
        train_step,
        steps_per_run=initial_collect_steps,
        observers=[rb_observer])
    logging.info('Doing initial collect.')
    initial_collect_actor.run()

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_collect_policy, use_tf_function=True)

    collect_actor = actor.Actor(
        collect_env,
        collect_policy,
        train_step,
        steps_per_run=update_frequency,
        observers=[rb_observer, env_step_metric],
        metrics=actor.collect_metrics(10),
        reference_metrics=[env_step_metric],
        summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
    )

    tf_greedy_policy = agent.policy
    greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_greedy_policy, use_tf_function=True)

    eval_actor = actor.Actor(
        eval_env,
        greedy_policy,
        train_step,
        episodes_per_run=eval_episodes,
        metrics=actor.eval_metrics(eval_episodes),
        reference_metrics=[env_step_metric],
        summary_dir=os.path.join(root_dir, 'eval'),
    )

    if eval_interval:
      logging.info('Evaluating.')
      eval_actor.run_and_log()

    logging.info('Training.')
    for _ in range(num_iterations):
      collect_actor.run()
      dqn_learner.run(iterations=1)

      if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0:
        logging.info('Evaluating.')
        eval_actor.run_and_log()
  finally:
    # Release the observer and the local server even when collection or
    # training raises, so a failed run does not leave the server running.
    if rb_observer is not None:
      rb_observer.close()
    reverb_server.stop()
def train(
    root_dir,
    strategy,
    replay_buffer_server_address,
    variable_container_server_address,
    create_agent_fn,
    create_env_fn,
    # Training params
    learning_rate=3e-4,
    batch_size=256,
    num_iterations=32000,
    learner_iterations_per_call=100):
  """Runs the learner loop against remote replay and variable servers.

  Builds an agent through `create_agent_fn`, samples experience from the Reverb
  replay buffer at `replay_buffer_server_address`, and after every learner call
  pushes the collect-policy variables and the train step to the variable
  container at `variable_container_server_address`.

  Args:
    root_dir: Directory handed to the learner; the policy saved-model directory
      is placed beneath it.
    strategy: Object whose `scope()` is entered while the train step and agent
      are created and while the dataset is built; also passed to the learner.
    replay_buffer_server_address: Address of the replay buffer server.
    variable_container_server_address: Address of the variable container
      server.
    create_agent_fn: Callable invoked as `create_agent_fn(train_step,
      observation_spec, action_spec, time_step_spec, learning_rate)`; the
      returned agent is initialized here.
    create_env_fn: Zero-argument callable returning the environment the tensor
      specs are read from.
    learning_rate: Learning rate forwarded to `create_agent_fn`.
    batch_size: Sample batch size requested from the replay dataset.
    num_iterations: Number of learner calls, each followed by a variable push.
    learner_iterations_per_call: Value passed as `iterations` to every learner
      call.
  """
  logging.info('Training SAC with learning rate: %f', learning_rate)

  # The environment is only needed as the source of the tensor specs.
  spec_env = create_env_fn()
  obs_spec, act_spec, step_spec = spec_utils.get_tensor_specs(spec_env)

  # The train step and the agent are both created under the strategy scope.
  with strategy.scope():
    step_counter = train_utils.create_train_step()
    learning_agent = create_agent_fn(
        step_counter, obs_spec, act_spec, step_spec, learning_rate)
    learning_agent.initialize()

  # Policy saver trigger: saves the initial model now, then periodically
  # checkpoints the policy weights.
  policy_export_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  policy_saver_trigger = triggers.PolicySavedModelTrigger(
      policy_export_dir, learning_agent, step_counter, interval=1000)

  # Variables published through the variable container.
  shared_variables = {
      reverb_variable_container.POLICY_KEY:
          learning_agent.collect_policy.variables(),
      reverb_variable_container.TRAIN_STEP_KEY:
          step_counter,
  }
  container = reverb_variable_container.ReverbVariableContainer(
      variable_container_server_address,
      table_names=[reverb_variable_container.DEFAULT_TABLE])
  container.push(shared_variables)

  # Replay buffer client pointing at the remote server.
  replay = reverb_replay_buffer.ReverbReplayBuffer(
      learning_agent.collect_data_spec,
      table_name=reverb_replay_buffer.DEFAULT_TABLE,
      sequence_length=2,
      server_address=replay_buffer_server_address)

  def make_experience_dataset():
    # The dataset is built lazily, inside the strategy scope.
    with strategy.scope():
      sampled = replay.as_dataset(sample_batch_size=batch_size, num_steps=2)
      return sampled.prefetch(3)

  throughput_trigger = triggers.StepPerSecondLogTrigger(
      step_counter, interval=1000)
  trainer = learner.Learner(
      root_dir,
      step_counter,
      learning_agent,
      make_experience_dataset,
      triggers=[policy_saver_trigger, throughput_trigger],
      strategy=strategy)

  # Run the training loop.
  # TODO(b/162440911) change the loop use train_step to handle preemptions
  for _ in range(num_iterations):
    trainer.run(iterations=learner_iterations_per_call)
    container.push(shared_variables)
Exemple #18
0
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    # Training params
    num_iterations=1600,
    actor_fc_layers=(64, 64),
    value_fc_layers=(64, 64),
    learning_rate=3e-4,
    collect_sequence_length=2048,
    minibatch_size=64,
    num_epochs=10,
    # Agent params
    importance_ratio_clipping=0.2,
    lambda_value=0.95,
    discount_factor=0.99,
    entropy_regularization=0.,
    value_pred_loss_coef=0.5,
    use_gae=True,
    use_td_lambda_return=True,
    gradient_clipping=0.5,
    value_clipping=None,
    # Replay params
    reverb_port=None,
    replay_capacity=10000,
    # Others
    policy_save_interval=5000,
    summary_interval=1000,
    eval_interval=10000,
    eval_episodes=100,
    debug_summaries=False,
    summarize_grads_and_vars=False):
  """Trains and evaluates PPO (Importance Ratio Clipping).

  Each iteration collects one sequence of `collect_sequence_length` steps into
  two Reverb tables (training and normalization), runs the PPO learner on it,
  and then clears both tables. The learning rate is decayed linearly with the
  iteration count.

  Args:
    root_dir: Main directory path where checkpoints, saved_models, and summaries
      will be written to.
    env_name: Name for the Mujoco environment to load.
    num_iterations: The number of iterations to perform collection and training.
    actor_fc_layers: List of fully_connected parameters for the actor network,
      where each item is the number of units in the layer.
    value_fc_layers: List of fully_connected parameters for the value network,
      where each item is the number of units in the layer.
    learning_rate: Initial learning rate used on the Adam optimizer; it is
      decayed linearly over `num_iterations`.
    collect_sequence_length: Number of steps to take in each collect run.
    minibatch_size: Number of elements in each mini batch. If `None`, the entire
      collected sequence will be treated as one batch.
    num_epochs: Number of iterations to repeat over all collected data per data
      collection step. (Schulman,2017) sets this to 10 for Mujoco, 15 for
      Roboschool and 3 for Atari.
    importance_ratio_clipping: Epsilon in clipped, surrogate PPO objective. For
      more detail, see explanation at the top of the doc.
    lambda_value: Lambda parameter for TD-lambda computation.
    discount_factor: Discount factor for return computation. Default to `0.99`
      which is the value used for all environments from (Schulman, 2017).
    entropy_regularization: Coefficient for entropy regularization loss term.
      Default to `0.0` because no entropy bonus was used in (Schulman, 2017).
    value_pred_loss_coef: Multiplier for value prediction loss to balance with
      policy gradient loss. Default to `0.5`, which was used for all
      environments in the OpenAI baseline implementation. This parameter is
      irrelevant unless you are sharing part of actor_net and value_net. In that
      case, you would want to tune this coefficient, whose value depends on the
      network architecture of your choice.
    use_gae: If True (the default for this function), uses generalized
      advantage estimation for computing per-timestep advantage. Else, just
      subtracts value predictions from empirical return.
    use_td_lambda_return: If True (the default for this function), uses
      td_lambda_return for training value function; here: `td_lambda_return =
      gae_advantage + value_predictions`. `use_gae` must be set to `True` as
      well to enable TD-lambda returns. If `use_td_lambda_return` is set to
      True while `use_gae` is False, the empirical return will be used and a
      warning will be logged.
    gradient_clipping: Norm length to clip gradients.
    value_clipping: Difference between new and old value predictions are clipped
      to this threshold. Value clipping could be helpful when training
      very deep networks. Default: no clipping.
    reverb_port: Port for reverb server, if None, use a randomly chosen unused
      port.
    replay_capacity: The maximum number of elements for the replay buffer. Items
      will be wasted if this is smaller than collect_sequence_length.
    policy_save_interval: How often, in train_steps, the policy will be saved.
    summary_interval: How often to write data into Tensorboard.
    eval_interval: How often to run evaluation, in train_steps. A falsy value
      disables evaluation entirely.
    eval_episodes: Number of episodes to evaluate over.
    debug_summaries: Boolean for whether to gather debug summaries.
    summarize_grads_and_vars: If true, gradient summaries will be written.
  """
  collect_env = suite_mujoco.load(env_name)
  eval_env = suite_mujoco.load(env_name)
  # Used as the sample batch size of both replay datasets below.
  num_environments = 1

  observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))
  # TODO(b/172267869): Remove this conversion once TensorNormalizer stops
  # converting float64 inputs to float32.
  observation_tensor_spec = tf.TensorSpec(
      dtype=tf.float32, shape=observation_tensor_spec.shape)

  train_step = train_utils.create_train_step()
  actor_net_builder = ppo_actor_network.PPOActorNetwork()
  actor_net = actor_net_builder.create_sequential_actor_net(
      actor_fc_layers, action_tensor_spec)
  value_net = value_network.ValueNetwork(
      observation_tensor_spec,
      fc_layer_params=value_fc_layers,
      kernel_initializer=tf.keras.initializers.Orthogonal())

  # Incremented once at the end of every outer training iteration; read by
  # learning_rate_fn each time the optimizer asks for the learning rate.
  current_iteration = tf.Variable(0, dtype=tf.int64)
  def learning_rate_fn():
    # Linearly decay the learning rate.
    return learning_rate * (1 - current_iteration / num_iterations)

  agent = ppo_clip_agent.PPOClipAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      optimizer=tf.keras.optimizers.Adam(
          learning_rate=learning_rate_fn, epsilon=1e-5),
      actor_net=actor_net,
      value_net=value_net,
      importance_ratio_clipping=importance_ratio_clipping,
      lambda_value=lambda_value,
      discount_factor=discount_factor,
      entropy_regularization=entropy_regularization,
      value_pred_loss_coef=value_pred_loss_coef,
      # This is a legacy argument for the number of times we repeat the data
      # inside of the train function, incompatible with mini batch learning.
      # We set the epoch number from the replay buffer and tf.Data instead.
      num_epochs=1,
      use_gae=use_gae,
      use_td_lambda_return=use_td_lambda_return,
      gradient_clipping=gradient_clipping,
      value_clipping=value_clipping,
      # TODO(b/150244758): Default compute_value_and_advantage_in_train to False
      # after Reverb open source.
      compute_value_and_advantage_in_train=False,
      # Skips updating normalizers in the agent, as it's handled in the learner.
      update_normalizers_in_train=False,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step)
  agent.initialize()

  # Two identically configured tables: the observer below writes every
  # collected sequence to both of them.
  reverb_server = reverb.Server(
      [
          reverb.Table(  # Replay buffer storing experience for training.
              name='training_table',
              sampler=reverb.selectors.Fifo(),
              remover=reverb.selectors.Fifo(),
              rate_limiter=reverb.rate_limiters.MinSize(1),
              max_size=replay_capacity,
              max_times_sampled=1,
          ),
          reverb.Table(  # Replay buffer storing experience for normalization.
              name='normalization_table',
              sampler=reverb.selectors.Fifo(),
              remover=reverb.selectors.Fifo(),
              rate_limiter=reverb.rate_limiters.MinSize(1),
              max_size=replay_capacity,
              max_times_sampled=1,
          )
      ],
      port=reverb_port)

  # Create the replay buffer.
  reverb_replay_train = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=collect_sequence_length,
      table_name='training_table',
      server_address='localhost:{}'.format(reverb_server.port),
      # The only collected sequence is used to populate the batches.
      max_cycle_length=1,
      rate_limiter_timeout_ms=1000)
  reverb_replay_normalization = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=collect_sequence_length,
      table_name='normalization_table',
      server_address='localhost:{}'.format(reverb_server.port),
      # The only collected sequence is used to populate the batches.
      max_cycle_length=1,
      rate_limiter_timeout_ms=1000)

  # A single observer feeds both tables through the training buffer's client.
  rb_observer = reverb_utils.ReverbTrajectorySequenceObserver(
      reverb_replay_train.py_client, ['training_table', 'normalization_table'],
      sequence_length=collect_sequence_length,
      stride_length=collect_sequence_length)

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  collect_env_step_metric = py_metrics.EnvironmentSteps()
  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={
              triggers.ENV_STEP_METADATA_KEY: collect_env_step_metric
          }),
      triggers.StepPerSecondLogTrigger(train_step, interval=summary_interval),
  ]

  def training_dataset_fn():
    return reverb_replay_train.as_dataset(
        sample_batch_size=num_environments,
        sequence_preprocess_fn=agent.preprocess_sequence)

  def normalization_dataset_fn():
    return reverb_replay_normalization.as_dataset(
        sample_batch_size=num_environments,
        sequence_preprocess_fn=agent.preprocess_sequence)

  agent_learner = ppo_learner.PPOLearner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn=training_dataset_fn,
      normalization_dataset_fn=normalization_dataset_fn,
      num_samples=1,
      num_epochs=num_epochs,
      minibatch_size=minibatch_size,
      shuffle_buffer_size=collect_sequence_length,
      triggers=learning_triggers)

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)

  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=collect_sequence_length,
      observers=[rb_observer],
      metrics=actor.collect_metrics(buffer_size=10) + [collect_env_step_metric],
      reference_metrics=[collect_env_step_metric],
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
      summary_interval=summary_interval)

  eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      agent.policy, use_tf_function=True)

  # eval_actor only exists when eval_interval is truthy; the use inside the
  # training loop is guarded by the same check.
  if eval_interval:
    logging.info('Intial evaluation.')
    eval_actor = actor.Actor(
        eval_env,
        eval_greedy_policy,
        train_step,
        metrics=actor.eval_metrics(eval_episodes),
        reference_metrics=[collect_env_step_metric],
        summary_dir=os.path.join(root_dir, 'eval'),
        episodes_per_run=eval_episodes)

    eval_actor.run_and_log()

  logging.info('Training on %s', env_name)
  last_eval_step = 0
  for i in range(num_iterations):
    # Collect one sequence, flush it to Reverb, train on it, then drop it from
    # both tables before the next iteration.
    collect_actor.run()
    rb_observer.flush()
    agent_learner.run()
    reverb_replay_train.clear()
    reverb_replay_normalization.clear()
    current_iteration.assign_add(1)

    # Eval only if `eval_interval` has been set. Then, eval if the current train
    # step is equal or greater than the `last_eval_step` + `eval_interval` or if
    # this is the last iteration. This logic exists because agent_learner.run()
    # does not return after every train step.
    if (eval_interval and
        (agent_learner.train_step_numpy >= eval_interval + last_eval_step
         or i == num_iterations - 1)):
      logging.info('Evaluating.')
      eval_actor.run_and_log()
      last_eval_step = agent_learner.train_step_numpy

  rb_observer.close()
  reverb_server.stop()
Exemple #19
0
def train_eval(
        root_dir,
        env_name,
        # Training params
        train_sequence_length,
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        num_iterations=100000,
        # RNN params.
        q_network_fn=q_lstm_network,  # defaults to q_lstm_network.
        # Agent params
        epsilon_greedy=0.1,
        batch_size=64,
        learning_rate=1e-3,
        gamma=0.99,
        target_update_tau=0.05,
        target_update_period=5,
        reward_scale_factor=1.0,
        # Replay params
        reverb_port=None,
        replay_capacity=100000,
        # Others
        policy_save_interval=1000,
        eval_interval=1000,
        eval_episodes=10):
    """Trains and evaluates DQN with a Q-network built by `q_network_fn`.

    Experience is written to a local Reverb table as overlapping sequences of
    `train_sequence_length + 1` steps; every iteration runs one collect call
    followed by one learner iteration.

    Args:
        root_dir: Directory for the learner, saved policies and summaries.
        env_name: Environment name passed to `suite_gym.load` (loaded twice:
            once for collection, once for evaluation).
        train_sequence_length: Training sequence length; stored and sampled
            sequences are one step longer.
        initial_collect_steps: Steps collected with a random policy before
            training starts.
        collect_steps_per_iteration: Steps collected per training iteration.
        num_iterations: Number of collect/train iterations.
        q_network_fn: Callable invoked as `q_network_fn(num_actions=...)` to
            build the Q-network.
        epsilon_greedy: Forwarded to the DQN agent.
        batch_size: Sample batch size of the replay dataset.
        learning_rate: Learning rate of the Adam optimizer.
        gamma: Forwarded to the DQN agent.
        target_update_tau: Forwarded to the DQN agent.
        target_update_period: Forwarded to the DQN agent.
        reward_scale_factor: Forwarded to the DQN agent.
        reverb_port: Port for the Reverb server.
        replay_capacity: `max_size` of the Reverb table.
        policy_save_interval: Interval of the policy saved-model trigger.
        eval_interval: Evaluate whenever the train step is a multiple of this
            value; a falsy value disables evaluation.
        eval_episodes: Episodes per evaluation run.
    """
    env_for_collect = suite_gym.load(env_name)
    env_for_eval = suite_gym.load(env_name)

    # The observation spec is not needed here.
    _, action_spec, time_step_spec = spec_utils.get_tensor_specs(
        env_for_collect)

    step_counter = train_utils.create_train_step()

    action_count = action_spec.maximum - action_spec.minimum + 1
    q_net = q_network_fn(num_actions=action_count)

    # Stored sequences carry one extra step beyond the training length.
    stored_length = train_sequence_length + 1
    dqn = dqn_agent.DqnAgent(
        time_step_spec,
        action_spec,
        q_network=q_net,
        epsilon_greedy=epsilon_greedy,
        # n-step updates aren't supported with RNNs yet.
        n_step_update=1,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        td_errors_loss_fn=common.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        train_step_counter=step_counter)

    # Local Reverb server with a single uniformly sampled table.
    replay_table_name = 'uniform_table'
    replay_table = reverb.Table(
        replay_table_name,
        max_size=replay_capacity,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1))
    server = reverb.Server([replay_table], port=reverb_port)
    replay = reverb_replay_buffer.ReverbReplayBuffer(
        dqn.collect_data_spec,
        table_name=replay_table_name,
        sequence_length=stored_length,
        local_server=server)
    trajectory_writer = reverb_utils.ReverbAddTrajectoryObserver(
        replay.py_client,
        replay_table_name,
        sequence_length=stored_length,
        stride_length=1,
        pad_end_of_episodes=True)

    def make_experience_dataset():
        return replay.as_dataset(
            sample_batch_size=batch_size, num_steps=stored_length)

    policy_export_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    steps_metric = py_metrics.EnvironmentSteps()

    policy_saver_trigger = triggers.PolicySavedModelTrigger(
        policy_export_dir,
        dqn,
        step_counter,
        interval=policy_save_interval,
        metadata_metrics={triggers.ENV_STEP_METADATA_KEY: steps_metric})
    throughput_trigger = triggers.StepPerSecondLogTrigger(
        step_counter, interval=100)

    trainer = learner.Learner(
        root_dir,
        step_counter,
        dqn,
        make_experience_dataset,
        triggers=[policy_saver_trigger, throughput_trigger])

    # Seed the replay buffer with experience from a random policy before any
    # training happens.
    warmup_policy = random_py_policy.RandomPyPolicy(
        env_for_collect.time_step_spec(), env_for_collect.action_spec())
    warmup_actor = actor.Actor(
        env_for_collect,
        warmup_policy,
        step_counter,
        steps_per_run=initial_collect_steps,
        observers=[trajectory_writer])
    logging.info('Doing initial collect.')
    warmup_actor.run()

    behaviour_policy = py_tf_eager_policy.PyTFEagerPolicy(
        dqn.collect_policy, use_tf_function=True)
    collector = actor.Actor(
        env_for_collect,
        behaviour_policy,
        step_counter,
        steps_per_run=collect_steps_per_iteration,
        observers=[trajectory_writer, steps_metric],
        metrics=actor.collect_metrics(10),
        summary_dir=os.path.join(root_dir, learner.TRAIN_DIR))

    evaluation_policy = py_tf_eager_policy.PyTFEagerPolicy(
        dqn.policy, use_tf_function=True)
    evaluator = actor.Actor(
        env_for_eval,
        evaluation_policy,
        step_counter,
        episodes_per_run=eval_episodes,
        metrics=actor.eval_metrics(eval_episodes),
        summary_dir=os.path.join(root_dir, 'eval'))

    if eval_interval:
        logging.info('Evaluating.')
        evaluator.run_and_log()

    logging.info('Training.')
    for _ in range(num_iterations):
        collector.run()
        trainer.run(iterations=1)

        # Evaluation is disabled when eval_interval is falsy.
        if not eval_interval:
            continue
        if trainer.train_step_numpy % eval_interval == 0:
            logging.info('Evaluating.')
            evaluator.run_and_log()

    trajectory_writer.close()
    server.stop()
Exemple #20
0
# Name of the single Reverb table that backs the replay buffer.
table_name = 'uniform_table'

# Table signature: the agent's collect data spec with an outer dimension added.
replay_buffer_signature = tensor_spec.add_outer_dim(
    tensor_spec.from_spec(tf_agent.collect_data_spec))

# Uniformly sampled table, oldest items removed first, sampling allowed as soon
# as one item is present.
table = reverb.Table(
    table_name,
    signature=replay_buffer_signature,
    max_size=replay_buffer_capacity,
    sampler=reverb.selectors.Uniform(),
    remover=reverb.selectors.Fifo(),
    rate_limiter=reverb.rate_limiters.MinSize(1))

# Local Reverb server hosting that table.
reverb_server = reverb.Server([table])

# Replay buffer bound to the local server; no fixed sequence length is set.
replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
    tf_agent.collect_data_spec,
    sequence_length=None,
    table_name=table_name,
    local_server=reverb_server)

# Observer that writes collected experience into the table.
rb_observer = reverb_utils.ReverbAddEpisodeObserver(
    replay_buffer.py_client, table_name, replay_buffer_capacity)


def collect_episode(environment, policy, num_episodes, observer=None):
    """Collect `num_episodes` episodes and feed them to a replay observer.

    The original body only constructed the driver and never ran it, so no
    experience was ever collected. The driver is now run from a freshly reset
    environment.

    Args:
        environment: Python environment to step; it is reset before collecting.
        policy: TF policy, wrapped in a `PyTFEagerPolicy` for the driver.
        num_episodes: Maximum number of episodes the driver runs.
        observer: Optional observer receiving the collected trajectories.
            Defaults to the module-level `rb_observer`, which keeps existing
            three-argument calls working.
    """
    if observer is None:
        observer = rb_observer

    driver = py_driver.PyDriver(
        environment,
        py_tf_eager_policy.PyTFEagerPolicy(policy, use_tf_function=True),
        [observer],
        max_episodes=num_episodes)

    # Start every collection from the beginning of an episode.
    initial_time_step = environment.reset()
    driver.run(initial_time_step)
def train_agent(iterations, modeldir, logdir, policydir):
    """Train a REINFORCE agent on the Plane Strike environment and export it.

    Each iteration collects episodes into a Reverb replay buffer, trains on one
    sampled batch, clears the buffer, and periodically logs evaluation results.
    Afterwards the collect policy is saved as a SavedModel and converted to a
    TFLite model.

    Args:
        iterations: Number of collect/train iterations to run.
        modeldir: Directory that receives `planestrike_tf_agents.tflite`.
        logdir: Directory for the TensorBoard summary writer.
        policydir: Directory the policy SavedModel is written to and that the
            TFLite converter reads from.
    """

    train_py_env = planestrike_py_environment.PlaneStrikePyEnvironment(
        board_size=BOARD_SIZE, discount=DISCOUNT, max_steps=BOARD_SIZE**2)
    eval_py_env = planestrike_py_environment.PlaneStrikePyEnvironment(
        board_size=BOARD_SIZE, discount=DISCOUNT, max_steps=BOARD_SIZE**2)

    train_env = tf_py_environment.TFPyEnvironment(train_py_env)
    eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

    # Alternatively you could use ActorDistributionNetwork as actor_net
    actor_net = tfa.networks.Sequential(
        [
            tfa.keras_layers.InnerReshape([BOARD_SIZE, BOARD_SIZE],
                                          [BOARD_SIZE**2]),
            tf.keras.layers.Dense(FC_LAYER_PARAMS, activation='relu'),
            tf.keras.layers.Dense(BOARD_SIZE**2),
            tf.keras.layers.Lambda(
                lambda t: tfp.distributions.Categorical(logits=t)),
        ],
        input_spec=train_py_env.observation_spec())

    optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)

    train_step_counter = tf.Variable(0)

    tf_agent = reinforce_agent.ReinforceAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        actor_network=actor_net,
        optimizer=optimizer,
        normalize_returns=True,
        train_step_counter=train_step_counter)

    tf_agent.initialize()

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    tf_policy_saver = policy_saver.PolicySaver(collect_policy)

    # Use reverb as replay buffer
    replay_buffer_signature = tensor_spec.from_spec(tf_agent.collect_data_spec)
    table = reverb.Table(
        REPLAY_BUFFER_TABLE_NAME,
        max_size=REPLAY_BUFFER_CAPACITY,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1),
        signature=replay_buffer_signature
    )  # specify signature here for validation at insertion time

    reverb_server = reverb.Server([table])

    replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
        tf_agent.collect_data_spec,
        sequence_length=None,
        table_name=REPLAY_BUFFER_TABLE_NAME,
        local_server=reverb_server)

    replay_buffer_observer = reverb_utils.ReverbAddEpisodeObserver(
        replay_buffer.py_client, REPLAY_BUFFER_TABLE_NAME,
        REPLAY_BUFFER_CAPACITY)

    # Optimize by wrapping some of the code in a graph using TF function.
    tf_agent.train = common.function(tf_agent.train)

    # Evaluate the agent's policy once before training. The result is not used;
    # the loop below recomputes it whenever it logs.
    compute_avg_return_and_steps(eval_env, tf_agent.policy, NUM_EVAL_EPISODES)

    summary_writer = tf.summary.create_file_writer(logdir)
    # The logger does not change between iterations, so fetch it once.
    logger = tf.get_logger()

    try:
        for i in range(iterations):
            # Collect a few episodes using collect_policy and save to the
            # replay buffer.
            collect_episode(train_py_env, collect_policy,
                            COLLECT_EPISODES_PER_ITERATION,
                            replay_buffer_observer)

            # Use data from the buffer and update the agent's network.
            iterator = iter(replay_buffer.as_dataset(sample_batch_size=1))
            trajectories, _ = next(iterator)
            tf_agent.train(experience=trajectories)
            # Drop the trained-on data so the next iteration only sees
            # freshly collected episodes.
            replay_buffer.clear()

            if i % EVAL_INTERVAL == 0:
                avg_return, avg_episode_length = compute_avg_return_and_steps(
                    eval_env, eval_policy, NUM_EVAL_EPISODES)
                with summary_writer.as_default():
                    tf.summary.scalar('Average return', avg_return, step=i)
                    tf.summary.scalar('Average episode length',
                                      avg_episode_length,
                                      step=i)
                    summary_writer.flush()
                logger.info(
                    'iteration = {0}: Average Return = {1}, Average Episode Length = {2}'
                    .format(i, avg_return, avg_episode_length))
    finally:
        # Release the summary writer, the Reverb writer and the local Reverb
        # server even if training fails part-way through.
        summary_writer.close()
        replay_buffer_observer.close()
        reverb_server.stop()

    tf_policy_saver.save(policydir)
    # Convert to tflite model
    converter = tf.lite.TFLiteConverter.from_saved_model(
        policydir, signature_keys=['action'])
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
        tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
    ]
    tflite_policy = converter.convert()
    with open(os.path.join(modeldir, 'planestrike_tf_agents.tflite'),
              'wb') as f:
        f.write(tflite_policy)
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        # Training params
        initial_collect_steps=10000,
        num_iterations=3200000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Agent params
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        gamma=0.99,
        target_update_tau=0.005,
        target_update_period=1,
        reward_scale_factor=0.1,
        # Replay params
        reverb_port=None,
        replay_capacity=1000000,
        # Others
        # Defaults to not checkpointing saved policy. If you wish to enable this,
        # please note the caveat explained in README.md.
        policy_save_interval=-1,
        eval_interval=10000,
        eval_episodes=30,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """Trains and evaluates SAC.

    Experience is stored in a local Reverb table as two-step sequences. After
    an initial random-policy collection, every iteration performs one collect
    step followed by one learner iteration.

    Args:
        root_dir: Directory for the learner, saved policies and summaries.
        env_name: Environment name passed to `suite_mujoco.load` (loaded twice:
            once for collection, once for evaluation).
        initial_collect_steps: Steps collected with a random policy before
            training starts.
        num_iterations: Number of collect/train iterations.
        actor_fc_layers: `fc_layer_params` of the actor network.
        critic_obs_fc_layers: `observation_fc_layer_params` of the critic.
        critic_action_fc_layers: `action_fc_layer_params` of the critic.
        critic_joint_fc_layers: `joint_fc_layer_params` of the critic.
        batch_size: Sample batch size of the replay dataset.
        actor_learning_rate: Learning rate of the actor optimizer.
        critic_learning_rate: Learning rate of the critic optimizer.
        alpha_learning_rate: Learning rate of the alpha optimizer.
        gamma: Forwarded to the SAC agent.
        target_update_tau: Forwarded to the SAC agent.
        target_update_period: Forwarded to the SAC agent.
        reward_scale_factor: Forwarded to the SAC agent.
        reverb_port: Port for the Reverb server.
        replay_capacity: `max_size` of the Reverb table.
        policy_save_interval: Interval of the policy saved-model trigger.
        eval_interval: Evaluate whenever the train step is a multiple of this
            value; a falsy value disables evaluation.
        eval_episodes: Episodes per evaluation run.
        debug_summaries: Forwarded to the SAC agent.
        summarize_grads_and_vars: Forwarded to the SAC agent.
    """
    logging.info('Training SAC on: %s', env_name)
    env_for_collect = suite_mujoco.load(env_name)
    env_for_eval = suite_mujoco.load(env_name)

    obs_spec, act_spec, step_spec = spec_utils.get_tensor_specs(
        env_for_collect)

    step_counter = train_utils.create_train_step()

    projection_net = tanh_normal_projection_network.TanhNormalProjectionNetwork
    policy_net = actor_distribution_network.ActorDistributionNetwork(
        obs_spec,
        act_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=projection_net)
    q_value_net = critic_network.CriticNetwork(
        (obs_spec, act_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        kernel_initializer='glorot_uniform',
        last_kernel_initializer='glorot_uniform')

    sac = sac_agent.SacAgent(
        step_spec,
        act_spec,
        actor_network=policy_net,
        critic_network=q_value_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=None,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=step_counter)
    sac.initialize()

    # Local Reverb server with a single uniformly sampled table.
    replay_table_name = 'uniform_table'
    replay_table = reverb.Table(
        replay_table_name,
        max_size=replay_capacity,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1))

    server = reverb.Server([replay_table], port=reverb_port)
    replay = reverb_replay_buffer.ReverbReplayBuffer(
        sac.collect_data_spec,
        table_name=replay_table_name,
        sequence_length=2,
        local_server=server)
    trajectory_writer = reverb_utils.ReverbAddTrajectoryObserver(
        replay.py_client,
        replay_table_name,
        sequence_length=2,
        stride_length=1)

    # The dataset is built once, up front; the learner receives a callable that
    # hands back this same object.
    sampled = replay.as_dataset(sample_batch_size=batch_size, num_steps=2)
    prefetched_dataset = sampled.prefetch(50)

    def provide_dataset():
        return prefetched_dataset

    policy_export_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    steps_metric = py_metrics.EnvironmentSteps()
    policy_saver_trigger = triggers.PolicySavedModelTrigger(
        policy_export_dir,
        sac,
        step_counter,
        interval=policy_save_interval,
        metadata_metrics={triggers.ENV_STEP_METADATA_KEY: steps_metric})
    throughput_trigger = triggers.StepPerSecondLogTrigger(
        step_counter, interval=1000)

    trainer = learner.Learner(
        root_dir,
        step_counter,
        sac,
        provide_dataset,
        triggers=[policy_saver_trigger, throughput_trigger])

    # Seed the replay buffer with experience from a random policy.
    warmup_policy = random_py_policy.RandomPyPolicy(
        env_for_collect.time_step_spec(), env_for_collect.action_spec())
    warmup_actor = actor.Actor(
        env_for_collect,
        warmup_policy,
        step_counter,
        steps_per_run=initial_collect_steps,
        observers=[trajectory_writer])
    logging.info('Doing initial collect.')
    warmup_actor.run()

    behaviour_policy = py_tf_eager_policy.PyTFEagerPolicy(
        sac.collect_policy, use_tf_function=True)
    collector = actor.Actor(
        env_for_collect,
        behaviour_policy,
        step_counter,
        steps_per_run=1,
        metrics=actor.collect_metrics(10),
        summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
        observers=[trajectory_writer, steps_metric])

    # Evaluation uses the greedy version of the agent's policy.
    evaluation_policy = py_tf_eager_policy.PyTFEagerPolicy(
        greedy_policy.GreedyPolicy(sac.policy), use_tf_function=True)
    evaluator = actor.Actor(
        env_for_eval,
        evaluation_policy,
        step_counter,
        episodes_per_run=eval_episodes,
        metrics=actor.eval_metrics(eval_episodes),
        summary_dir=os.path.join(root_dir, 'eval'))

    if eval_interval:
        logging.info('Evaluating.')
        evaluator.run_and_log()

    logging.info('Training.')
    for _ in range(num_iterations):
        collector.run()
        trainer.run(iterations=1)

        # Evaluation is disabled when eval_interval is falsy.
        if not eval_interval:
            continue
        if trainer.train_step_numpy % eval_interval == 0:
            logging.info('Evaluating.')
            evaluator.run_and_log()

    trajectory_writer.close()
    server.stop()
Exemple #23
0
def train_eval(
        root_dir,
        strategy: tf.distribute.Strategy,
        env_name='HalfCheetah-v2',
        # Training params
        initial_collect_steps=10000,
        num_iterations=3200000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Agent params
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        gamma=0.99,
        target_update_tau=0.005,
        target_update_period=1,
        reward_scale_factor=0.1,
        # Replay params
        reverb_port=None,
        replay_capacity=1000000,
        # Others
        policy_save_interval=10000,
        replay_buffer_save_interval=100000,
        eval_interval=10000,
        eval_episodes=30,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """Trains and evaluates SAC.

    Builds a SAC agent under `strategy.scope()`, starts a local Reverb server
    holding a single uniform-sampling / FIFO-removal table, fills it with an
    initial random-policy collect, and then alternates one collect step with
    one learner iteration for `num_iterations` iterations. The replay
    observer is closed and the Reverb server is stopped when the function
    exits, including when an exception is raised after the server has been
    started.

    Args:
      root_dir: Base directory. Joined with `learner.TRAIN_DIR`, `'eval'`,
        `learner.POLICY_SAVED_MODEL_DIR` and
        `learner.REPLAY_BUFFER_CHECKPOINT_DIR` for summaries, saved policies
        and replay checkpoints, and passed to the `Learner`.
      strategy: Distribution strategy; the train step and agent are created
        inside its scope and it is passed to the `Learner`.
      env_name: Environment name passed to `suite_mujoco.load` for both the
        collect and the eval environment.
      initial_collect_steps: `steps_per_run` of the random-policy actor that
        runs once before training.
      num_iterations: Number of collect + train iterations.
      actor_fc_layers: Forwarded to `create_sequential_actor_network`.
      critic_obs_fc_layers: Forwarded to `create_sequential_critic_network`
        as `obs_fc_layer_units`.
      critic_action_fc_layers: Forwarded to
        `create_sequential_critic_network` as `action_fc_layer_units`.
      critic_joint_fc_layers: Forwarded to `create_sequential_critic_network`
        as `joint_fc_layer_units`.
      batch_size: `sample_batch_size` of the replay dataset.
      actor_learning_rate: Learning rate of the actor Adam optimizer.
      critic_learning_rate: Learning rate of the critic Adam optimizer.
      alpha_learning_rate: Learning rate of the alpha Adam optimizer.
      gamma: Forwarded to `SacAgent`.
      target_update_tau: Forwarded to `SacAgent`.
      target_update_period: Forwarded to `SacAgent`.
      reward_scale_factor: Forwarded to `SacAgent`.
      reverb_port: Port passed to `reverb.Server`.
      replay_capacity: `max_size` of the Reverb table.
      policy_save_interval: Interval of the `PolicySavedModelTrigger`.
      replay_buffer_save_interval: Interval of the `ReverbCheckpointTrigger`.
      eval_interval: If truthy, evaluation runs once before training and then
        whenever the learner's train step is a multiple of this value. A
        falsy value disables evaluation.
      eval_episodes: Episodes per evaluation run; also passed to
        `actor.eval_metrics`.
      debug_summaries: Forwarded to `SacAgent`.
      summarize_grads_and_vars: Forwarded to `SacAgent`.
    """
    logging.info('Training SAC on: %s', env_name)
    collect_env = suite_mujoco.load(env_name)
    eval_env = suite_mujoco.load(env_name)

    _, action_tensor_spec, time_step_tensor_spec = (
        spec_utils.get_tensor_specs(collect_env))

    actor_net = create_sequential_actor_network(
        actor_fc_layers=actor_fc_layers, action_tensor_spec=action_tensor_spec)

    critic_net = create_sequential_critic_network(
        obs_fc_layer_units=critic_obs_fc_layers,
        action_fc_layer_units=critic_action_fc_layers,
        joint_fc_layer_units=critic_joint_fc_layers)

    with strategy.scope():
        train_step = train_utils.create_train_step()
        agent = sac_agent.SacAgent(
            time_step_tensor_spec,
            action_tensor_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.keras.optimizers.Adam(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.keras.optimizers.Adam(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.keras.optimizers.Adam(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=tf.math.squared_difference,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=None,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=train_step)
        agent.initialize()

    # One value shared by the replay buffer, the observer that writes into
    # it and the dataset that reads from it, so the three cannot drift apart.
    sequence_length = 2

    table_name = 'uniform_table'
    table = reverb.Table(table_name,
                         max_size=replay_capacity,
                         sampler=reverb.selectors.Uniform(),
                         remover=reverb.selectors.Fifo(),
                         rate_limiter=reverb.rate_limiters.MinSize(1))

    reverb_checkpoint_dir = os.path.join(root_dir, learner.TRAIN_DIR,
                                         learner.REPLAY_BUFFER_CHECKPOINT_DIR)
    reverb_checkpointer = reverb.platform.checkpointers_lib.DefaultCheckpointer(
        path=reverb_checkpoint_dir)
    reverb_server = reverb.Server([table],
                                  port=reverb_port,
                                  checkpointer=reverb_checkpointer)

    # From here on the server is live; everything below runs under
    # try/finally so a failure anywhere still shuts it down.
    rb_observer = None
    try:
        reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
            agent.collect_data_spec,
            sequence_length=sequence_length,
            table_name=table_name,
            local_server=reverb_server)
        rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
            reverb_replay.py_client,
            table_name,
            sequence_length=sequence_length,
            stride_length=1)

        def experience_dataset_fn():
            return reverb_replay.as_dataset(
                sample_batch_size=batch_size,
                num_steps=sequence_length).prefetch(50)

        saved_model_dir = os.path.join(root_dir,
                                       learner.POLICY_SAVED_MODEL_DIR)
        env_step_metric = py_metrics.EnvironmentSteps()
        learning_triggers = [
            triggers.PolicySavedModelTrigger(
                saved_model_dir,
                agent,
                train_step,
                interval=policy_save_interval,
                metadata_metrics={triggers.ENV_STEP_METADATA_KEY:
                                  env_step_metric}),
            triggers.ReverbCheckpointTrigger(
                train_step,
                interval=replay_buffer_save_interval,
                reverb_client=reverb_replay.py_client),
            # TODO(b/165023684): Add SIGTERM handler to checkpoint before
            # preemption.
            triggers.StepPerSecondLogTrigger(train_step, interval=1000),
        ]

        agent_learner = learner.Learner(root_dir,
                                        train_step,
                                        agent,
                                        experience_dataset_fn,
                                        triggers=learning_triggers,
                                        strategy=strategy)

        # Seed the replay table with random-policy experience before any
        # training step samples from it.
        random_policy = random_py_policy.RandomPyPolicy(
            collect_env.time_step_spec(), collect_env.action_spec())
        initial_collect_actor = actor.Actor(
            collect_env,
            random_policy,
            train_step,
            steps_per_run=initial_collect_steps,
            observers=[rb_observer])
        logging.info('Doing initial collect.')
        initial_collect_actor.run()

        tf_collect_policy = agent.collect_policy
        collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
            tf_collect_policy, use_tf_function=True)

        collect_actor = actor.Actor(
            collect_env,
            collect_policy,
            train_step,
            steps_per_run=1,
            metrics=actor.collect_metrics(10),
            summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
            observers=[rb_observer, env_step_metric])

        tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
        eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
            tf_greedy_policy, use_tf_function=True)

        eval_actor = actor.Actor(
            eval_env,
            eval_greedy_policy,
            train_step,
            episodes_per_run=eval_episodes,
            metrics=actor.eval_metrics(eval_episodes),
            summary_dir=os.path.join(root_dir, 'eval'),
        )

        # Baseline evaluation before any training step has run.
        if eval_interval:
            logging.info('Evaluating.')
            eval_actor.run_and_log()

        logging.info('Training.')
        for _ in range(num_iterations):
            collect_actor.run()
            agent_learner.run(iterations=1)

            if (eval_interval and
                    agent_learner.train_step_numpy % eval_interval == 0):
                logging.info('Evaluating.')
                eval_actor.run_and_log()
    finally:
        # Nested so the server is stopped even if closing the observer
        # raises; the observer may not exist if setup failed early.
        try:
            if rb_observer is not None:
                rb_observer.close()
        finally:
            reverb_server.stop()