Ejemplo n.º 1
0
  def make_replay_tables(
      self,
      environment_spec: specs.EnvironmentSpec,
  ) -> List[reverb.Table]:
    """Build the single replay table that transitions are inserted into.

    Args:
      environment_spec: Spec handed to `NStepTransitionAdder.signature` to
        derive the table signature.

    Returns:
      A one-element list holding the configured `reverb.Table`.
    """
    config = self._config
    ratio = config.samples_per_insert
    if ratio is not None:
      # The tolerance is 10% of the ratio; the error buffer scales it by the
      # minimum replay size.
      tolerance = 0.1 * ratio
      rate_limiter = reverb.rate_limiters.SampleToInsertRatio(
          min_size_to_sample=config.min_replay_size,
          samples_per_insert=ratio,
          error_buffer=config.min_replay_size * tolerance)
    else:
      # A ratio of None is taken to mean "no rate limit": only a minimum
      # table size is enforced.
      rate_limiter = reverb.rate_limiters.MinSize(config.min_replay_size)

    return [
        reverb.Table(
            name=config.replay_table_name,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=config.max_replay_size,
            rate_limiter=rate_limiter,
            signature=reverb_adders.NStepTransitionAdder.signature(
                environment_spec))
    ]
Ejemplo n.º 2
0
  def replay(self) -> List[reverb.Table]:
    """Build the prioritized replay table for sequence data.

    Returns:
      A one-element list holding the configured `reverb.Table`.
    """
    # Ask the network for its initial state with a batch of one, then remove
    # the batch dimension before using it as the extras spec.
    network = self._network_factory(self._environment_spec.actions)
    extras = tf2_utils.squeeze_batch_dim(
        {'core_state': network.initial_state(1)})

    if not self._samples_per_insert:
      # No ratio configured: only require a minimum number of items.
      rate_limiter = reverb.rate_limiters.MinSize(self._min_replay_size)
    else:
      rate_limiter = reverb.rate_limiters.SampleToInsertRatio(
          min_size_to_sample=self._min_replay_size,
          samples_per_insert=self._samples_per_insert,
          error_buffer=self._batch_size)

    # Sequence length is burn-in length plus trace length plus one.
    steps_per_sequence = self._burn_in_length + self._trace_length + 1

    return [
        reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(self._priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=self._max_replay_size,
            rate_limiter=rate_limiter,
            signature=adders.SequenceAdder.signature(
                self._environment_spec,
                extras,
                sequence_length=steps_per_sequence))
    ]
Ejemplo n.º 3
0
 def make_replay_tables(
     self,
     environment_spec: specs.EnvironmentSpec,
 ) -> List[reverb.Table]:
     """Build the prioritized sequence replay table.

     Args:
         environment_spec: Spec handed to `SequenceAdder.signature` together
             with `self._extra_spec`.

     Returns:
         A one-element list holding the configured `reverb.Table`.
     """
     config = self._config
     if not config.samples_per_insert:
         # No ratio configured: only a minimum table size is enforced.
         rate_limiter = reverb.rate_limiters.MinSize(config.min_replay_size)
     else:
         # Error buffer = min replay size * (tolerance rate * ratio).
         tolerance = (config.samples_per_insert_tolerance_rate *
                      config.samples_per_insert)
         rate_limiter = reverb.rate_limiters.SampleToInsertRatio(
             min_size_to_sample=config.min_replay_size,
             samples_per_insert=config.samples_per_insert,
             error_buffer=config.min_replay_size * tolerance)

     replay_table = reverb.Table(
         name=config.replay_table_name,
         sampler=reverb.selectors.Prioritized(config.priority_exponent),
         remover=reverb.selectors.Fifo(),
         max_size=config.max_replay_size,
         rate_limiter=rate_limiter,
         signature=adders_reverb.SequenceAdder.signature(
             environment_spec, self._extra_spec))
     return [replay_table]
Ejemplo n.º 4
0
 def make_replay_tables(
     self,
     environment_spec: specs.EnvironmentSpec,
     policy: dqn_actor.EpsilonPolicy,
 ) -> List[reverb.Table]:
     """Build the prioritized transition replay table for the algorithm.

     Args:
         environment_spec: Spec handed to `NStepTransitionAdder.signature`.
         policy: Unused by this method.

     Returns:
         A one-element list holding the configured `reverb.Table`.
     """
     del policy  # Unused.
     config = self._config

     # Error buffer = min replay size * (tolerance rate * ratio).
     tolerance = (config.samples_per_insert_tolerance_rate *
                  config.samples_per_insert)
     rate_limiter = rate_limiters.SampleToInsertRatio(
         min_size_to_sample=config.min_replay_size,
         samples_per_insert=config.samples_per_insert,
         error_buffer=config.min_replay_size * tolerance)

     replay_table = reverb.Table(
         name=config.replay_table_name,
         sampler=reverb.selectors.Prioritized(config.priority_exponent),
         remover=reverb.selectors.Fifo(),
         max_size=config.max_replay_size,
         rate_limiter=rate_limiter,
         signature=adders_reverb.NStepTransitionAdder.signature(
             environment_spec))
     return [replay_table]
Ejemplo n.º 5
0
    def make_replay_tables(
        self,
        environment_spec: specs.EnvironmentSpec,
        policy: r2d2_actor.R2D2Policy,
    ) -> List[reverb.Table]:
        """Build the prioritized sequence replay table.

        Args:
            environment_spec: Spec handed to `SequenceAdder.signature`.
            policy: Policy used to obtain the extras spec via `init` and
                `get_extras`.

        Returns:
            A one-element list holding the configured `reverb.Table`.
        """
        # Initialise a throwaway actor state (fixed PRNG key 0) only to read
        # the extras the policy reports for it.
        extras = policy.get_extras(policy.init(jax.random.PRNGKey(0)))

        config = self._config
        if not config.samples_per_insert:
            # No ratio configured: only a minimum table size is enforced.
            rate_limiter = reverb.rate_limiters.MinSize(
                config.min_replay_size)
        else:
            # Error buffer = min replay size * (tolerance rate * ratio).
            tolerance = (config.samples_per_insert_tolerance_rate *
                         config.samples_per_insert)
            rate_limiter = reverb.rate_limiters.SampleToInsertRatio(
                min_size_to_sample=config.min_replay_size,
                samples_per_insert=config.samples_per_insert,
                error_buffer=config.min_replay_size * tolerance)

        replay_table = reverb.Table(
            name=config.replay_table_name,
            sampler=reverb.selectors.Prioritized(config.priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=config.max_replay_size,
            rate_limiter=rate_limiter,
            signature=adders_reverb.SequenceAdder.signature(
                environment_spec, extras))
        return [replay_table]
Ejemplo n.º 6
0
  def setUp(self):
    """Create the test environment, its data specs and a local Reverb server."""
    super(ReverbReplayBufferTest, self).setUp()

    # Environment and the tensor specs derived from it.
    self._env = test_envs.EpisodeCountingEnv(steps_per_episode=3)
    time_step_spec = tf.nest.map_structure(
        tensor_spec.from_spec, self._env.time_step_spec())
    action_spec = tensor_spec.from_spec(self._env.action_spec())
    self._data_spec = trajectory.Trajectory(
        step_type=time_step_spec.step_type,
        observation=time_step_spec.observation,
        action=action_spec,
        policy_info=(),
        next_step_type=time_step_spec.step_type,
        reward=time_step_spec.reward,
        discount=time_step_spec.discount,
    )

    def _with_leading_unknown_dim(spec):
      # The table signature gets an extra leading dimension of unknown size.
      return tf.TensorSpec(dtype=spec.dtype, shape=(None,) + spec.shape)

    signature = tf.nest.map_structure(
        _with_leading_unknown_dim, self._data_spec)
    self._array_data_spec = tensor_spec.to_nest_array_spec(self._data_spec)

    # Reverb server with one uniform/FIFO table, plus a client connected to it.
    self._table_name = 'test_table'
    self._server = reverb.Server([
        reverb.Table(
            self._table_name,
            max_size=100,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            rate_limiter=reverb.rate_limiters.MinSize(1),
            signature=signature,
        )
    ])
    self._py_client = reverb.Client('localhost:{}'.format(self._server.port))
Ejemplo n.º 7
0
    def make_replay_tables(
        self,
        environment_spec: specs.EnvironmentSpec,
        sequence_length: int,
    ) -> List[reverb.Table]:
        """Build the uniform sequence replay table.

        Args:
            environment_spec: Spec handed to `SequenceAdder.signature`.
            sequence_length: The table signature uses this value plus one.

        Returns:
            A one-element list holding the configured `reverb.Table`.
        """
        config = self._config
        ratio = config.samples_per_insert
        if ratio is not None:
            # The error buffer is the ratio itself, but never below 1.
            rate_limiter = reverb.rate_limiters.SampleToInsertRatio(
                min_size_to_sample=config.min_replay_size,
                samples_per_insert=ratio,
                error_buffer=max(1, ratio))
        else:
            # A ratio of None is taken to mean "no rate limit": only a
            # minimum table size is enforced.
            rate_limiter = reverb.rate_limiters.MinSize(
                config.min_replay_size)

        # A scalar float32 'log_prob' is the only extra.
        extras = {'log_prob': tf.ones(shape=(), dtype=tf.float32)}
        return [
            reverb.Table(
                name=config.replay_table_name,
                sampler=reverb.selectors.Uniform(),
                remover=reverb.selectors.Fifo(),
                max_size=config.max_replay_size,
                rate_limiter=rate_limiter,
                signature=reverb_adders.SequenceAdder.signature(
                    environment_spec,
                    extras_spec=extras,
                    sequence_length=sequence_length + 1))
        ]
Ejemplo n.º 8
0
  def test_prioritized_table(self):
    """Checks that sample counts from a prioritized table follow priorities."""
    table_name = 'test_prioritized_table'
    prioritized_table = reverb.Table(
        table_name,
        sampler=reverb.selectors.Prioritized(1.0),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1),
        max_size=3)
    replay = reverb_replay_buffer.ReverbReplayBuffer(
        tensor_spec.TensorSpec((), dtype=tf.int64),
        table_name,
        sequence_length=1,
        local_server=reverb.Server([prioritized_table]),
        dataset_buffer_size=1)

    # Insert the values 0, 1, 2, each with a priority equal to its value.
    with replay.py_client.trajectory_writer(1) as writer:
      for value in range(3):
        writer.append(value)
        writer.create_item(
            table_name, trajectory=writer.history[-1:], priority=value)

    samples = iter(
        replay.as_dataset(
            sample_batch_size=1, num_steps=None, num_parallel_calls=1))

    counts = [0, 0, 0]
    for _ in range(1000):
      sampled = next(samples)[0].numpy()  # This is a matrix shaped 1x1.
      counts[int(sampled)] += 1

    self.assertEqual(counts[0], 0)  # priority 0
    self.assertGreater(counts[1], 250)  # priority 1
    self.assertGreater(counts[2], 600)  # priority 2
Ejemplo n.º 9
0
    def __init__(self, data_spec, batch_size=1, n_steps=2):
        """Set up a local Reverb server, replay buffer and trajectory writer.

        Args:
            data_spec: Spec passed to `ReverbReplayBuffer`.
            batch_size: Stored on the instance as `batch_size`.
            n_steps: Sequence length for both the buffer and the writer.
        """
        self.data_spec = data_spec
        self.batch_size = batch_size
        self.n_steps = n_steps
        self.replay_buffer_capacity = 100000

        self.name = 'PER'
        prioritized_table = reverb.Table(
            name=self.name,
            max_size=self.replay_buffer_capacity,
            sampler=reverb.selectors.Prioritized(0.8),
            remover=reverb.selectors.Fifo(),
            rate_limiter=reverb.rate_limiters.MinSize(1))
        self.server = reverb.Server([prioritized_table])

        self.buffer = reverb_replay_buffer.ReverbReplayBuffer(
            data_spec=self.data_spec,
            sequence_length=self.n_steps,
            table_name=self.name,
            local_server=self.server)

        # The observer writes through the buffer's own client, with stride 1
        # and a constant priority of 1.
        self.writer = reverb_utils.ReverbAddTrajectoryObserver(
            self.buffer.py_client,
            table_name=self.name,
            sequence_length=self.n_steps,
            stride_length=1,
            priority=1,
        )
        self.states = None
Ejemplo n.º 10
0
 def make_replay_tables(
     self,
     environment_spec: specs.EnvironmentSpec,
     policy: actor_core_lib.FeedForwardPolicy,
 ) -> List[reverb.Table]:
     """Build the uniform transition replay table.

     Args:
         environment_spec: Spec handed to `NStepTransitionAdder.signature`.
         policy: Unused by this method.

     Returns:
         A one-element list holding the configured `reverb.Table`.
     """
     del policy  # Unused.
     config = self._config

     if not config.samples_per_insert:
         # No ratio configured: only a minimum table size is enforced.
         rate_limiter = rate_limiters.MinSize(config.min_replay_size)
     else:
         # Error buffer = min replay size * (tolerance rate * ratio).
         tolerance = (config.samples_per_insert_tolerance_rate *
                      config.samples_per_insert)
         rate_limiter = rate_limiters.SampleToInsertRatio(
             min_size_to_sample=config.min_replay_size,
             samples_per_insert=config.samples_per_insert,
             error_buffer=config.min_replay_size * tolerance)

     replay_table = reverb.Table(
         name=config.replay_table_name,
         sampler=reverb.selectors.Uniform(),
         remover=reverb.selectors.Fifo(),
         max_size=config.max_replay_size,
         rate_limiter=rate_limiter,
         signature=adders_reverb.NStepTransitionAdder.signature(
             environment_spec))
     return [replay_table]
Ejemplo n.º 11
0
def _create_reverb_server(table_name: str) -> reverb.Server:
    """Create a Reverb server hosting one uniform/FIFO table.

    Args:
        table_name: Name given to the table.

    Returns:
        The `reverb.Server` serving that table.
    """
    return reverb.Server([
        reverb.Table(
            table_name,
            max_size=_replay_buffer_capacity,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            rate_limiter=reverb.rate_limiters.MinSize(1),
        )
    ])
Ejemplo n.º 12
0
def priority_tables_fn():
    """Return a list with one uniform/FIFO table named `_TABLE_NAME`.

    The minimum size required by the rate limiter equals the table's maximum
    size (both 100).
    """
    table = reverb.Table(
        name=_TABLE_NAME,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=100,
        rate_limiter=rate_limiters.MinSize(100),
    )
    return [table]
Ejemplo n.º 13
0
def create_reverb_server_for_replay_buffer_and_variable_container(
    collect_policy, train_step, replay_buffer_capacity, port):
  """Create one Reverb server hosting the replay buffer and variable tables.

  Args:
    collect_policy: Policy whose `variables()` and `collect_data_spec` define
      the two table signatures.
    train_step: Stored under `TRAIN_STEP_KEY` in the variable signature.
    replay_buffer_capacity: `max_size` of the replay buffer table.
    port: Port passed to `reverb.Server`.

  Returns:
    The created `reverb.Server`.
  """

  def _variable_spec(variable):
    return tf.TensorSpec(variable.shape, dtype=variable.dtype)

  def _with_time_axis(spec):
    # Prefix a time axis for trajectories.
    return tf.TensorSpec((None,) + spec.shape, spec.dtype, spec.name)

  # Signature for the variable container holding the policy weights.
  variable_container_signature = tf.nest.map_structure(
      _variable_spec,
      {
          reverb_variable_container.POLICY_KEY: collect_policy.variables(),
          reverb_variable_container.TRAIN_STEP_KEY: train_step,
      })

  # Signature for the replay buffer holding observed experience.
  replay_buffer_signature = tf.nest.map_structure(
      _with_time_axis,
      tensor_spec.from_spec(collect_policy.collect_data_spec))

  # Replay buffer storing experience.
  replay_table = reverb.Table(
      name=reverb_replay_buffer.DEFAULT_TABLE,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      # TODO(b/159073060): Set rate limiter for SAC properly.
      rate_limiter=reverb.rate_limiters.MinSize(1),
      max_size=replay_buffer_capacity,
      max_times_sampled=0,
      signature=replay_buffer_signature,
  )
  # Variable container storing policy parameters.
  variable_table = reverb.Table(
      name=reverb_variable_container.DEFAULT_TABLE,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1),
      max_size=1,
      max_times_sampled=0,
      signature=variable_container_signature,
  )

  # Create and start the server hosting both tables.
  return reverb.Server(tables=[replay_table, variable_table], port=port)
Ejemplo n.º 14
0
def _reverb_server():
    """Create a Reverb server with a single uniform/FIFO 'test_table'.

    The table holds at most 100 items and its rate limiter requires at least
    95 of them. `port=None` is passed to the server.
    """
    table = reverb.Table(
        'test_table',
        reverb.selectors.Uniform(),
        reverb.selectors.Fifo(),
        max_size=100,
        rate_limiter=reverb.rate_limiters.MinSize(95),
    )
    return reverb.Server(tables=[table], port=None)
Ejemplo n.º 15
0
    def make_replay_tables(
        self,
        environment_spec: specs.MAEnvironmentSpec,
    ) -> List[reverb.Table]:
        """Create tables to insert data into.

        The adder signature is chosen from the executor class: feed-forward
        executors use the parallel n-step transition adder, recurrent
        executors use the parallel sequence adder. For feed-forward executors
        with a replay stabiliser configured, a "fingerprint" entry is added to
        `self._extra_specs` (this mutates the instance attribute).

        Args:
            environment_spec (specs.MAEnvironmentSpec): description of the action and
                observation spaces etc. for each agent in the system.

        Raises:
            NotImplementedError: unknown executor type.

        Returns:
            List[reverb.Table]: a list of data tables for inserting data.
        """

        # Select adder
        if issubclass(self._executor_fn, executors.FeedForwardExecutor):
            # Check if we should use fingerprints
            if self._replay_stabiliser_fn is not None:
                self._extra_specs.update({"fingerprint": np.array([1.0, 1.0])})
            adder_sig = reverb_adders.ParallelNStepTransitionAdder.signature(
                environment_spec, self._extra_specs
            )
        elif issubclass(self._executor_fn, executors.RecurrentExecutor):
            adder_sig = reverb_adders.ParallelSequenceAdder.signature(
                environment_spec, self._extra_specs
            )
        else:
            # Format a single message string: passing the executor as a second
            # positional argument makes the exception render as a tuple.
            raise NotImplementedError(
                f"Unknown executor type: {self._executor_fn}"
            )

        if self._config.samples_per_insert is None:
            # We will take a samples_per_insert ratio of None to mean that there is
            # no limit, i.e. this only implies a min size limit.
            limiter = reverb.rate_limiters.MinSize(self._config.min_replay_size)

        else:
            # Create enough of an error buffer to give a 10% tolerance in rate.
            samples_per_insert_tolerance = 0.1 * self._config.samples_per_insert
            error_buffer = self._config.min_replay_size * samples_per_insert_tolerance
            limiter = reverb.rate_limiters.SampleToInsertRatio(
                min_size_to_sample=self._config.min_replay_size,
                samples_per_insert=self._config.samples_per_insert,
                error_buffer=error_buffer,
            )

        replay_table = reverb.Table(
            name=self._config.replay_table_name,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=self._config.max_replay_size,
            rate_limiter=limiter,
            signature=adder_sig,
        )

        return [replay_table]
Ejemplo n.º 16
0
    def test_dataset_with_variable_sequence_length_truncates(self):
        """Reads variable-length episodes back in chunks of two steps.

        Two episodes (lengths 3 and 5) are written; with `num_steps=2` the
        dataset is expected to yield [1, 2], [10, 20], [30, 40] and then stop.
        """
        table = reverb.Table(
            name=self._table_name,
            sampler=reverb.selectors.Fifo(),
            remover=reverb.selectors.Fifo(),
            max_times_sampled=1,
            max_size=100,
            rate_limiter=reverb.rate_limiters.MinSize(1),
            signature=tf.TensorSpec((None, ), tf.int64),
        )
        server = reverb.Server([table])
        py_client = reverb.Client('localhost:{}'.format(server.port))

        # Insert two episodes: one of length 3 and one of length 5, each as a
        # single item written by its own trajectory writer.
        episodes = ((1, 2, 3), (10, 20, 30, 40, 50))
        for episode in episodes:
            with py_client.trajectory_writer(10) as writer:
                for step in episode:
                    writer.append(step)
                writer.create_item(self._table_name,
                                   trajectory=writer.history[-len(episode):],
                                   priority=5)

        replay = reverb_replay_buffer.ReverbReplayBuffer(
            tf.TensorSpec((), tf.int64),
            self._table_name,
            local_server=server,
            sequence_length=None,
            rate_limiter_timeout_ms=100)
        chunks = iter(
            replay.as_dataset(single_deterministic_pass=True, num_steps=2))

        for expected in ([1, 2], [10, 20], [30, 40]):
            data, _ = next(chunks)
            self.assertAllEqual(data, expected)

        with self.assertRaises(StopIteration):
            next(chunks)
Ejemplo n.º 17
0
    def __init__(self):
        """Create the Reverb server with three tables and seed it with zeros.

        The tables are a prioritized one named `Params.BUFFER_TYPE` and two
        heap-sampled ones with "_max" / "_min" suffixes. After the base class
        is initialised, one all-zero item is inserted into the tables listed
        in `Params.BUFFER_PRIORITY_TABLE_NAMES`.
        """
        # Initialize the reverb server
        capacity = tf.cast(Params.BUFFER_SIZE, tf.int64)
        min_items_to_sample = tf.cast(Params.MINIBATCH_SIZE, tf.int64)

        prioritized_table = reverb.Table(
            name=Params.BUFFER_TYPE,
            sampler=reverb.selectors.Prioritized(
                priority_exponent=Params.BUFFER_PRIORITY_ALPHA),
            remover=reverb.selectors.Fifo(),
            max_size=capacity,
            rate_limiter=reverb.rate_limiters.MinSize(min_items_to_sample),
        )
        max_heap_table = reverb.Table(
            name=Params.BUFFER_TYPE + "_max",
            sampler=reverb.selectors.MaxHeap(),
            remover=reverb.selectors.Fifo(),
            max_size=capacity,
            rate_limiter=reverb.rate_limiters.MinSize(tf.constant(1)),
        )
        min_heap_table = reverb.Table(
            name=Params.BUFFER_TYPE + "_min",
            sampler=reverb.selectors.MinHeap(),
            remover=reverb.selectors.Fifo(),
            max_size=capacity,
            rate_limiter=reverb.rate_limiters.MinSize(tf.constant(1)),
        )
        reverb_server = reverb.Server(
            tables=[prioritized_table, max_heap_table, min_heap_table])

        super().__init__(name="ReverbPrioritizedReplayBuffer",
                         reverb_server=reverb_server)

        # Init client for updating priorities
        self.client = self.get_client()

        # Insert dummy trajectory: one zero tensor per entry of the data spec.
        dummy_item = [
            tf.zeros(spec.shape, dtype=spec.dtype)
            for spec in Params.BUFFER_DATA_SPEC
        ]
        self.client.insert(
            dummy_item,
            tables=Params.BUFFER_PRIORITY_TABLE_NAMES,
            priorities=tf.constant([1., 1., 1.], dtype=tf.float64))
Ejemplo n.º 18
0
 def test_make_replay_table_preserves_table_info(self):
     """A table rebuilt from `table.info` reports the same info."""
     original = reverb.Table(
         name='test',
         sampler=reverb.selectors.Uniform(),
         remover=reverb.selectors.Fifo(),
         max_size=10,
         rate_limiter=reverb.rate_limiters.SampleToInsertRatio(
             samples_per_insert=1,
             min_size_to_sample=2,
             error_buffer=(0, 10)))
     rebuilt = reverb_utils.make_replay_table_from_info(original.info)
     self.assertEqual(rebuilt.info, original.info)
Ejemplo n.º 19
0
    def setUpClass(cls):
        """Create a class-level Reverb server with one uniform/FIFO table."""
        super().setUpClass()

        cls.server = reverb.Server([
            reverb.Table(
                name=adders.DEFAULT_PRIORITY_TABLE,
                sampler=reverb.selectors.Uniform(),
                remover=reverb.selectors.Fifo(),
                max_size=1000,
                rate_limiter=reverb.rate_limiters.MinSize(1),
            )
        ])
Ejemplo n.º 20
0
def main(argv):
    """Create a Reverb server with two tables and wait on it.

    Args:
        argv: Command-line arguments; more than one entry is rejected.

    Raises:
        app.UsageError: If `argv` has more than one element.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    eval_table_size = _CONFIG.value.num_eval_points

    # TODO(joshgreaves): Choose an appropriate rate_limiter, max_size.
    successor_table = reverb.Table(
        name='successor_table',
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        max_size=1_000_000,
        rate_limiter=reverb.rate_limiters.MinSize(20_000))
    # For the eval table the required minimum size equals its maximum size.
    eval_table = reverb.Table(
        name='eval_table',
        sampler=reverb.selectors.Fifo(),
        remover=reverb.selectors.Fifo(),
        max_size=eval_table_size,
        rate_limiter=reverb.rate_limiters.MinSize(eval_table_size))

    server = reverb.Server(
        tables=[successor_table, eval_table], port=FLAGS.port)
    server.wait()
Ejemplo n.º 21
0
def make_replay_tables(environment_spec: specs.EnvironmentSpec
                      ) -> Sequence[reverb.Table]:
  """Build the 'default' uniform/FIFO transition table.

  Args:
    environment_spec: Spec handed to `NStepTransitionAdder.signature`.

  Returns:
    A one-element list holding the configured `reverb.Table`.
  """
  default_table = reverb.Table(
      name='default',
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      max_size=1000000,
      rate_limiter=rate_limiters.MinSize(1),
      signature=adders_reverb.NStepTransitionAdder.signature(
          environment_spec))
  return [default_table]
Ejemplo n.º 22
0
 def __init__(self, min_size: int = 64, max_size: int = 40000):
     """Create a Reverb server with one prioritized table.

     Args:
         min_size: Minimum table size required by the rate limiter.
         max_size: Maximum table size; converted with `int()`.
     """
     self._min_size = min_size
     self._table_name = 'priority_table'

     priority_table = reverb.Table(
         name=self._table_name,
         sampler=reverb.selectors.Prioritized(priority_exponent=0.8),
         remover=reverb.selectors.Fifo(),
         max_size=int(max_size),
         rate_limiter=reverb.rate_limiters.MinSize(min_size))

     # Sets the port to None to make the server pick one automatically.
     self._server = reverb.Server(tables=[priority_table], port=None)
Ejemplo n.º 23
0
def tables_from_proto(
    configs: Sequence[checkpoint_pb2.PriorityTableCheckpoint]
) -> Sequence[reverb.Table]:
  """Convert table checkpoint protos into `reverb.Table` objects.

  Args:
    configs: Protos providing the name, sampler, remover, max size, rate
      limiter and max-times-sampled for each table.

  Returns:
    A list with one `reverb.Table` per input proto, in the same order.
  """
  return [
      reverb.Table(
          name=proto.table_name,
          sampler=selector_from_proto(proto.sampler),
          remover=selector_from_proto(proto.remover),
          max_size=proto.max_size,
          rate_limiter=rate_limiter_from_proto(proto.rate_limiter),
          max_times_sampled=proto.max_times_sampled,
      )
      for proto in configs
  ]
Ejemplo n.º 24
0
    def test_deterministic_dataset_from_heap_sampler_remover(self):
        """Builds a deterministic-pass dataset over a heap-based table.

        The table samples with a max-heap and removes with a min-heap; the
        test only checks that `as_dataset` can be called.
        """
        heap_table = reverb.Table(
            name=self._table_name,
            sampler=reverb.selectors.MaxHeap(),
            remover=reverb.selectors.MinHeap(),
            max_size=100,
            max_times_sampled=0,
            rate_limiter=reverb.rate_limiters.MinSize(1))
        server = reverb.Server([heap_table])

        replay = reverb_replay_buffer.ReverbReplayBuffer(
            self._data_spec,
            self._table_name,
            local_server=server,
            sequence_length=None)
        replay.as_dataset(single_deterministic_pass=True)
        server.stop()
Ejemplo n.º 25
0
    def setUp(self):
        """Prepare a mocked client/writer pair and a real local Reverb server."""
        super().setUp()

        # The writer mock returns itself when called and is installed as the
        # client's `trajectory_writer` attribute.
        self._mock_writer = mock.MagicMock()
        self._mock_writer.return_value = self._mock_writer
        self._mock_client = mock.MagicMock()
        self._mock_client.trajectory_writer = self._mock_writer

        # Real server with one uniform/FIFO table, plus a localhost client.
        self._table_name = 'uniform_table'
        self._table = reverb.Table(
            self._table_name,
            max_size=100,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            rate_limiter=reverb.rate_limiters.MinSize(1),
        )
        self._reverb_server = reverb.Server([self._table], port=None)
        self._reverb_client = self._reverb_server.localhost_client()
Ejemplo n.º 26
0
  def make_replay_tables(
      self,
      environment_spec: specs.EnvironmentSpec,
      policy: impala_networks.IMPALANetworks,
  ) -> List[reverb.Table]:
    """Build the queue-like table that sequences are inserted into.

    Args:
      environment_spec: Spec used for the table signature and to read the
        number of discrete actions.
      policy: Unused by this method.

    Returns:
      A one-element list holding the configured `reverb.Table`.

    Raises:
      ValueError: If a truthy `samples_per_insert` is greater than 1.0 or not
        greater than 0.0.
    """
    del policy  # Unused.

    # Extras carry the core state and one float32 logit per action.
    action_count = environment_spec.actions.num_values
    extras = {
        'core_state': self._core_state_spec,
        'logits': jnp.ones(shape=(action_count,), dtype=jnp.float32)
    }
    signature = reverb_adders.SequenceAdder.signature(
        environment_spec,
        extras,
        sequence_length=self._config.sequence_length)

    # Maybe create rate limiter.
    # Setting the samples_per_insert ratio less than the default of 1.0, allows
    # the agent to drop data for the benefit of using data from most up-to-date
    # policies to compute its learner updates.
    samples_per_insert = self._config.samples_per_insert
    if not samples_per_insert:
      limiter = reverb.rate_limiters.MinSize(1)
    else:
      out_of_range = samples_per_insert <= 0.0 or samples_per_insert > 1.0
      if out_of_range:
        raise ValueError(
            'Impala requires a samples_per_insert ratio in the range (0, 1],'
            f' but received {samples_per_insert}.')
      limiter = reverb.rate_limiters.SampleToInsertRatio(
          samples_per_insert=samples_per_insert,
          min_size_to_sample=1,
          error_buffer=self._config.batch_size)

    # An optional extension factory contributes a single table extension.
    if self._table_extension is None:
      extensions = []
    else:
      extensions = [self._table_extension()]

    return [
        reverb.Table(
            name=self._config.replay_table_name,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=self._config.max_queue_size,
            max_times_sampled=1,
            rate_limiter=limiter,
            extensions=extensions,
            signature=signature)
    ]
Ejemplo n.º 27
0
 def make_replay_tables(
         self,
         environment_spec: specs.EnvironmentSpec) -> List[reverb.Table]:
     """Return the wrapped agent's tables, optionally plus one extra table.

     When `share_iterator` is set in the config, the wrapped agent's tables
     are returned as they are. Otherwise a uniform/FIFO transition table is
     appended to that same list before it is returned.

     Args:
         environment_spec: Spec forwarded to the wrapped agent and used for
             the extra table's signature.

     Returns:
         The list of tables produced by the wrapped agent, possibly extended.
     """
     tables = self._rl_agent.make_replay_tables(environment_spec)
     if not self._config.share_iterator:
         extra_table = reverb.Table(
             name=self._config.replay_table_name,
             sampler=reverb.selectors.Uniform(),
             remover=reverb.selectors.Fifo(),
             max_size=self._config.max_replay_size,
             rate_limiter=rate_limiters.MinSize(
                 self._config.min_replay_size),
             signature=adders_reverb.NStepTransitionAdder.signature(
                 environment_spec))
         tables.append(extra_table)
     return tables
Ejemplo n.º 28
0
  def test_capacity_set(self):
    """The replay buffer's `capacity` matches the table's `max_size`."""
    table_name = 'test_table'
    capacity = 100

    server = reverb.Server([
        reverb.Table(
            table_name,
            max_size=capacity,
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            rate_limiter=reverb.rate_limiters.MinSize(3))
    ])
    replay = reverb_replay_buffer.ReverbReplayBuffer(
        tensor_spec.TensorSpec((), tf.float32),
        table_name,
        local_server=server,
        sequence_length=None)

    self.assertEqual(capacity, replay.capacity)
    server.stop()
Ejemplo n.º 29
0
 def replay(self):
     """Build the prioritized transition replay table.

     Returns:
         A one-element list holding the configured `reverb.Table`.
     """
     if not self._samples_per_insert:
         # No ratio configured: only a minimum table size is enforced.
         rate_limiter = reverb.rate_limiters.MinSize(self._min_replay_size)
     else:
         rate_limiter = reverb.rate_limiters.SampleToInsertRatio(
             min_size_to_sample=self._min_replay_size,
             samples_per_insert=self._samples_per_insert,
             error_buffer=self._batch_size)

     return [
         reverb.Table(
             name=adders.DEFAULT_PRIORITY_TABLE,
             sampler=reverb.selectors.Prioritized(self._priority_exponent),
             remover=reverb.selectors.Fifo(),
             max_size=self._max_replay_size,
             rate_limiter=rate_limiter,
             signature=adders.NStepTransitionAdder.signature(
                 self._env_spec))
     ]
Ejemplo n.º 30
0
def main():
    """Start the perwez and Reverb servers and the workers, then train.

    Loads `config.yaml` from this file's directory, starts the perwez server,
    a Reverb server with one prioritized table and `num_workers` daemon
    worker processes, then runs the trainer in the main process. On exit
    (normal or via an exception) the workers, the Reverb server and the
    perwez process are shut down.
    """
    mp.set_start_method("spawn", force=True)
    config_path = path.join(path.dirname(__file__), "config.yaml")
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    # init perwez
    pwz_proc, pwz_config = perwez.start_server()

    worker_processes = []
    reverb_server = None
    try:
        # init reverb_server
        # Keep a reference for the whole run: an unreferenced Server object
        # can be garbage-collected, which would shut the server down.
        reverb_server = reverb.Server(
            tables=[
                reverb.Table(
                    name=TABLE_NAME,
                    sampler=reverb.selectors.Prioritized(0.6),
                    remover=reverb.selectors.Fifo(),
                    max_size=config["replay_buffer"]["capacity"],
                    rate_limiter=reverb.rate_limiters.MinSize(1000),
                )
            ],
            port=PORT,
        )

        # worker subprocesses
        num_workers = config["common"]["num_workers"]
        for idx in range(num_workers):
            p = mp.Process(
                name=f"apex-worker-{idx}",
                target=worker_main,
                args=(pwz_config["url"], config, idx),
                daemon=True,
            )
            p.start()
            worker_processes.append(p)

        # trainer process should be the main process
        trainer_main_tf_dataset(pwz_config["url"], config)
    finally:
        # Startup now happens inside the try, so a failure while creating the
        # server or a worker still tears down whatever was already started.
        print("exiting...")
        for p in worker_processes:
            p.terminate()
            p.join()
        if reverb_server is not None:
            reverb_server.stop()
        pwz_proc.terminate()
        pwz_proc.join()