Esempio n. 1
0
    def sample(self,
               table: str,
               data_dtypes,
               name: Optional[str] = None) -> replay_sample.ReplaySample:
        """Samples a single item (that has a data field) from the replay.

        Items without a data field cannot be sampled through this method.

        Args:
          table: Probability table to sample from.
          data_dtypes: Dtypes of the data output. May be a nested structure.
          name: Optional name for the Client operations.

        Returns:
          A `ReplaySample` whose `data` is nested like `data_dtypes`. See
          `ReplaySample` for more details.
        """
        default_scope = f'{self._name}_sample'
        with tf.name_scope(name, default_scope, ['sample']) as scope:
            # The op consumes and produces flat lists; the nesting of
            # `data_dtypes` is restored after the call.
            flat_dtypes = tree.flatten(data_dtypes)
            outputs = gen_client_ops.reverb_client_sample(
                self._handle, table, flat_dtypes, name=scope)
            key, probability, table_size, priority, flat_data = outputs
            info = replay_sample.SampleInfo(
                key=key,
                probability=probability,
                table_size=table_size,
                priority=priority)
            return replay_sample.ReplaySample(
                info, tree.unflatten_as(data_dtypes, flat_data))
Esempio n. 2
0
def sample_trajectory(client: client_lib.Client, table: str,
                      structure: Any) -> replay_sample.ReplaySample:
    """Temporary helper that samples a single trajectory from `table`.

    Note! This function only exists to make it easier for alpha testers to
    experiment with the new API. It will be removed before this file is made
    public.

    Args:
      client: Client connected to the server to sample from.
      table: Name of the table to sample from.
      structure: Structure used to unpack the flat data.

    Returns:
      ReplaySample whose `data` field holds the trajectory unpacked as
      `structure`.
    """
    sampler = client._client.NewSampler(table, 1, 1, 1)  # pylint: disable=protected-access
    columns = sampler.GetNextSample()

    # The first four columns carry the sample info; the first element of each
    # is read. Every remaining column is trajectory data.
    key = int(columns[0][0])
    probability = float(columns[1][0])
    table_size = int(columns[2][0])
    priority = float(columns[3][0])
    info = replay_sample.SampleInfo(key=key,
                                    probability=probability,
                                    table_size=table_size,
                                    priority=priority)
    unpacked = tree.unflatten_as(structure, columns[4:])
    return replay_sample.ReplaySample(info=info, data=unpacked)
Esempio n. 3
0
    def sample(
        self,
        table: str,
        num_samples: int = 1
    ) -> Generator[List[replay_sample.ReplaySample], None, None]:
        """Samples `num_samples` items from table `table` of the Server.

        NOTE: This method should NOT be used for real training. TFClient (see
        tf_client.py) has far superior performance and should always be
        preferred.

        Note: If data was written using `insert` (e.g when inserting complete
        trajectories) then the returned "sequence" will be a list of length 1
        containing the trajectory as a single item.

        If `num_samples` is greater than the number of items in `table`, (or
        a rate limiter is used to control sampling), then the returned
        generator will block when an item past the sampling limit is
        requested. It will unblock when sufficient additional items have been
        added to `table`.

        Example:
        ```python
        server = Server(..., tables=[queue("queue", ...)])
        client = Client(...)
        # Don't insert anything into "queue"
        generator = client.sample("queue")
        next(generator)  # Blocks until another thread/process writes to queue.
        ```

        Args:
          table: Name of the priority table to sample from.
          num_samples: (default to 1) The number of samples to fetch.

        Yields:
          Lists of timesteps (lists of instances of `ReplaySample`).
          If data was inserted into the table via `insert`, then each element
          of the generator is a length 1 list containing a `ReplaySample`.
          If data was inserted via a writer, then each element is a list whose
          length is the sampled trajectory's length.
        """
        sampler = self._client.NewSampler(table, num_samples, 1)

        for _ in range(num_samples):
            sequence = []
            last = False

            # Each sampled item is streamed one timestep at a time; the
            # sampler flags the final timestep of the item with `last`.
            while not last:
                step, last = sampler.GetNextTimestep()
                info = replay_sample.SampleInfo(
                    key=int(step[0]),
                    probability=float(step[1]),
                    table_size=int(step[2]),
                    priority=float(step[3]))
                # Everything after the four info fields is timestep data.
                sequence.append(
                    replay_sample.ReplaySample(info=info, data=step[4:]))

            yield sequence
Esempio n. 4
0
  def test_sample_variable_length_trajectory(self):
    # Write 10 items; each references the latest step ('last') and the whole
    # history written so far ('all'), so the items span lengths 1..10.
    with self._client.trajectory_writer(10) as writer:
      for step in range(10):
        writer.append([np.ones([3, 3], np.int32) * step])
        item_columns = {
            'last': writer.history[0][-1],
            'all': writer.history[0][:],
        }
        writer.create_item(TABLE, 1.0, item_columns)

    column_dtypes = {'last': tf.int32, 'all': tf.int32}
    column_shapes = {
        'last': tf.TensorShape([3, 3]),
        'all': tf.TensorShape([None, 3, 3]),
    }
    dataset = trajectory_dataset.TrajectoryDataset(
        tf.constant(self._client.server_address),
        table=tf.constant(TABLE),
        dtypes=column_dtypes,
        shapes=column_shapes,
        max_in_flight_samples_per_worker=1,
        flexible_batch_size=1)

    # Only the structure is compared, so the info values are placeholders.
    reference_structure = replay_sample.ReplaySample(
        info=replay_sample.SampleInfo(
            key=1, probability=1.0, table_size=10, priority=0.5),
        data={'last': None, 'all': None})

    # Keep sampling until every trajectory length has been observed.
    observed_lengths = set()
    while len(observed_lengths) < 10:
      sample = self._sample_from(dataset, 1)[0]
      # The structure should always be the same.
      tree.assert_same_structure(sample, reference_structure)
      observed_lengths.add(sample.data['all'].shape[0])

    self.assertEqual(observed_lengths, set(range(1, 11)))
    def test_sample_fixed_length_trajectory(self):
        self._populate_replay()

        server = tf.constant(self._client.server_address)
        table = tf.constant(TABLE)
        dataset = trajectory_dataset.TrajectoryDataset(
            server,
            table=table,
            dtypes=DTYPES,
            shapes=SHAPES,
            max_in_flight_samples_per_worker=1,
            flexible_batch_size=1)

        # Only the structure is compared, so the info values in the reference
        # are placeholders.
        first_sample = self._sample_from(dataset, 1)[0]
        reference = replay_sample.ReplaySample(
            info=replay_sample.SampleInfo(
                key=1, probability=1.0, table_size=10, priority=0.5),
            data=SHAPES)
        tree.assert_same_structure(first_sample, reference)
Esempio n. 6
0
  def __init__(self,
               server_address: Union[str, tf.Tensor],
               table: Union[str, tf.Tensor],
               dtypes: Any,
               shapes: Any,
               max_in_flight_samples_per_worker: int,
               num_workers_per_iterator: int = -1,
               max_samples_per_stream: int = -1,
               rate_limiter_timeout_ms: int = -1,
               flexible_batch_size: int = -1):
    """Constructs a new TrajectoryDataset.

    Each element of the dataset is a `ReplaySample` whose `info` holds scalar
    sample-info fields and whose `data` follows `dtypes` / `shapes`.

    Args:
      server_address: Address of gRPC ReverbService.
      table: Probability table to sample from.
      dtypes: Dtypes of the data output. Can be nested.
      shapes: Shapes of the data output. Can be nested.
      max_in_flight_samples_per_worker: The number of samples requested in each
        batch of samples. Higher values give higher throughput but too big
        values can result in skewed sampling distributions as large number of
        samples are fetched from single snapshot of the replay (followed by a
        period of lower activity as the samples are consumed). A good rule of
        thumb is to set this value to 2-3x times the batch size used.
      num_workers_per_iterator: (Defaults to -1, i.e auto selected) The number
        of worker threads to create per dataset iterator. When the selected
        table uses a FIFO sampler (i.e a queue) then exactly 1 worker must be
        used to avoid races causing invalid ordering of items. For all other
        samplers, this value should be roughly equal to the number of threads
        available on the CPU.
      max_samples_per_stream: (Defaults to -1, i.e auto selected) The maximum
        number of samples to fetch from a stream before a new call is made.
        Keeping this number low ensures that the data is fetched uniformly from
        all servers.
      rate_limiter_timeout_ms: (Defaults to -1: infinite). Timeout (in
        milliseconds) to wait on the rate limiter when sampling from the table.
        If `rate_limiter_timeout_ms >= 0`, this is the timeout passed to
        `Table::Sample` describing how long to wait for the rate limiter to
        allow sampling. The first time that a request times out (across any of
        the workers), the Dataset iterator is closed and the sequence is
        considered finished.
      flexible_batch_size: (Defaults to -1: auto selected) The maximum number of
        items to sample from `Table` with a single call. Values > 1 enable
        `Table::SampleFlexibleBatch` to return more than one item (but no more
        than `flexible_batch_size`) in a single call without releasing the
        table lock iff the rate limiter allows it. NOTE! It is unlikely that
        you need to tune this value yourself. The auto selected value should
        almost always be preferred. Larger `flexible_batch_size` values result
        in a bias towards sampling over inserts. In highly overloaded systems
        this results in higher sample QPS and lower insert QPS compared to
        lower `flexible_batch_size` values.

    Raises:
      ValueError: If `dtypes` and `shapes` don't share the same structure.
      ValueError: If `max_in_flight_samples_per_worker` is not a
        positive integer.
      ValueError: If `num_workers_per_iterator` is not a positive integer or -1.
      ValueError: If `max_samples_per_stream` is not a positive integer or -1.
      ValueError: If `rate_limiter_timeout_ms < -1`.
      ValueError: If `flexible_batch_size` is not a positive integer or -1.
    """
    tree.assert_same_structure(dtypes, shapes, False)
    if max_in_flight_samples_per_worker < 1:
      raise ValueError(
          'max_in_flight_samples_per_worker (%d) must be a positive integer' %
          max_in_flight_samples_per_worker)
    if num_workers_per_iterator < 1 and num_workers_per_iterator != -1:
      raise ValueError(
          'num_workers_per_iterator (%d) must be a positive integer or -1' %
          num_workers_per_iterator)
    if max_samples_per_stream < 1 and max_samples_per_stream != -1:
      raise ValueError(
          'max_samples_per_stream (%d) must be a positive integer or -1' %
          max_samples_per_stream)
    if rate_limiter_timeout_ms < -1:
      raise ValueError('rate_limiter_timeout_ms (%d) must be an integer >= -1' %
                       rate_limiter_timeout_ms)
    if flexible_batch_size < 1 and flexible_batch_size != -1:
      raise ValueError(
          'flexible_batch_size (%d) must be a positive integer or -1' %
          flexible_batch_size)

    # Add the info fields (all scalars).
    dtypes = replay_sample.ReplaySample(
        info=replay_sample.SampleInfo.tf_dtypes(), data=dtypes)
    shapes = replay_sample.ReplaySample(
        info=replay_sample.SampleInfo(
            key=tf.TensorShape([]),
            probability=tf.TensorShape([]),
            table_size=tf.TensorShape([]),
            priority=tf.TensorShape([])),
        data=shapes)

    # The tf.data API doesn't fully support lists so we convert all uses of
    # lists into tuples.
    dtypes = _convert_lists_to_tuples(dtypes)
    shapes = _convert_lists_to_tuples(shapes)

    self._server_address = server_address
    self._table = table
    self._dtypes = dtypes
    self._shapes = shapes
    self._max_in_flight_samples_per_worker = max_in_flight_samples_per_worker
    self._num_workers_per_iterator = num_workers_per_iterator
    self._max_samples_per_stream = max_samples_per_stream
    self._rate_limiter_timeout_ms = rate_limiter_timeout_ms
    self._flexible_batch_size = flexible_batch_size

    if _is_tf1_runtime():
      # Disabling to avoid errors given the different tf.data.Dataset init args
      # between v1 and v2 APIs.
      # pytype: disable=wrong-arg-count
      super().__init__()
    else:
      # DatasetV2 requires the dataset as a variant tensor during init.
      super().__init__(self._as_variant_tensor())
Esempio n. 7
0
  def sample(
      self,
      table: str,
      num_samples: int = 1,
      *,
      emit_timesteps: bool = True,
      unpack_as_table_signature: bool = False,
  ) -> Generator[Union[List[replay_sample.ReplaySample],
                       replay_sample.ReplaySample], None, None]:
    """Samples `num_samples` items from table `table` of the Server.

    NOTE: This method should NOT be used for real training. TrajectoryDataset
    and TimestepDataset should always be preferred over this method.

    Note: If data was written using `insert` (e.g when inserting complete
    trajectories) then the returned "sequence" will be a list of length 1
    containing the trajectory as a single item.

    If `num_samples` is greater than the number of items in `table`, (or
    a rate limiter is used to control sampling), then the returned generator
    will block when an item past the sampling limit is requested. It will
    unblock when sufficient additional items have been added to `table`.

    Example:

    ```python

    server = Server(..., tables=[queue("queue", ...)])
    client = Client(...)

    # Don't insert anything into "queue"
    generator = client.sample("queue")
    next(generator)  # Blocks until another thread/process writes to queue.

    ```

    Args:
      table: Name of the priority table to sample from.
      num_samples: (default to 1) The number of samples to fetch.
      emit_timesteps: If True then trajectories are returned as a list of
        `ReplaySample`, each representing a single step within the trajectory.
      unpack_as_table_signature: If True then the sampled data is unpacked
        according to the structure of the table signature. If the table does
        not have a signature then flat data is returned.

    Yields:
      If `emit_timesteps` is `True`:

        Lists of timesteps (lists of instances of `ReplaySample`). All
        timesteps of one trajectory share the same `info`.
        If data was inserted into the table via `insert`, then each element
        of the generator is a length 1 list containing a `ReplaySample`.
        If data was inserted via a writer, then each element is a list whose
        length is the sampled trajectory's length.

      If `emit_timesteps` is `False`:

        An instance of `ReplaySample` holding the full trajectory. If
        `unpack_as_table_signature` is True and the table has a signature,
        the data is unpacked according to that signature; otherwise the data
        is flat, i.e each element is a leaf node of the full trajectory.

    Raises:
      ValueError: If `emit_timesteps` is True but the trajectory cannot be
        decomposed into timesteps (no data columns, a column without a
        leading dimension, or columns whose leading dimensions differ).
    """
    buffer_size = 1

    if unpack_as_table_signature:
      signature = self._get_signature_for_table(table)
    else:
      signature = None

    if signature:
      unflatten = lambda x: tree.unflatten_as(signature, x)
    else:
      unflatten = lambda x: x

    sampler = self._client.NewSampler(table, num_samples, buffer_size)

    for _ in range(num_samples):
      sample = sampler.GetNextTrajectory()

      info = replay_sample.SampleInfo(
          key=int(sample[0]),
          probability=float(sample[1]),
          table_size=int(sample[2]),
          priority=float(sample[3]),
          times_sampled=int(sample[4]))
      # The info fields come first; every remaining column is trajectory data.
      data = sample[len(info):]

      if not emit_timesteps:
        yield replay_sample.ReplaySample(info, unflatten(data))
        continue

      # Splitting into timesteps needs every column to have a leading (time)
      # dimension of one common size. 0-d columns have no len(); they map to
      # None so they are rejected with the documented ValueError instead of
      # surfacing as a TypeError.
      lengths = {len(col) if np.ndim(col) else None for col in data}
      if len(lengths) != 1 or None in lengths:
        raise ValueError(
            'Can\'t split non timestep trajectory into timesteps.')
      (num_steps,) = lengths

      yield [
          replay_sample.ReplaySample(
              info=info,
              data=unflatten([np.asarray(col[i], col.dtype) for col in data]))
          for i in range(num_steps)
      ]
Esempio n. 8
0
    def __init__(self,
                 server_address: str,
                 table: str,
                 dtypes: Any,
                 shapes: Any,
                 max_in_flight_samples_per_worker: int,
                 num_workers_per_iterator: int = -1,
                 max_samples_per_stream: int = -1,
                 sequence_length: Optional[int] = None,
                 emit_timesteps: bool = True,
                 rate_limiter_timeout_ms: int = -1):
        """Constructs a new ReplayDataset.

        Args:
          server_address: Address of gRPC ReverbService.
          table: Probability table to sample from.
          dtypes: Dtypes of the data output. Can be nested.
          shapes: Shapes of the data output. Can be nested.
          max_in_flight_samples_per_worker: The number of samples requested in
            each batch of samples. Higher values give higher throughput but too
            big values can result in skewed sampling distributions as large
            number of samples are fetched from single snapshot of the replay
            (followed by a period of lower activity as the samples are
            consumed). A good rule of thumb is to set this value to 2-3x times
            the batch size used.
          num_workers_per_iterator: (Defaults to -1, i.e auto selected) The
            number of worker threads to create per dataset iterator. When the
            selected table uses a FIFO sampler (i.e a queue) then exactly 1
            worker must be used to avoid races causing invalid ordering of
            items. For all other samplers, this value should be roughly equal
            to the number of threads available on the CPU.
          max_samples_per_stream: (Defaults to -1, i.e auto selected) The
            maximum number of samples to fetch from a stream before a new call
            is made. Keeping this number low ensures that the data is fetched
            uniformly from all servers.
          sequence_length: (Defaults to None, i.e unknown) The number of
            timesteps that each sample consists of. If set then the length of
            samples received from the server will be validated against this
            number.
          emit_timesteps: (Defaults to True) If set, timesteps instead of full
            sequences are returned from the dataset. Returning sequences
            instead of timesteps can be more efficient as the memcopies caused
            by the splitting and batching of tensor can be avoided. Note that
            if set to False then all `shapes` must have dim[0] equal to
            `sequence_length`.
          rate_limiter_timeout_ms: (Defaults to -1: infinite). Timeout
            (in milliseconds) to wait on the rate limiter when sampling from
            the table. If `rate_limiter_timeout_ms >= 0`, this is the timeout
            passed to `Table::Sample` describing how long to wait for the rate
            limiter to allow sampling. The first time that a request times out
            (across any of the workers), the Dataset iterator is closed and the
            sequence is considered finished.

        Raises:
          ValueError: If `dtypes` and `shapes` don't share the same structure.
          ValueError: If `max_in_flight_samples_per_worker` is not a
            positive integer.
          ValueError: If `num_workers_per_iterator` is not a positive integer
            or -1.
          ValueError: If `max_samples_per_stream` is not a positive integer or
            -1.
          ValueError: If `sequence_length` is not a positive integer or None.
          ValueError: If `emit_timesteps is False` and not all items in
            `shapes` have `sequence_length` as their leading dimension.
          ValueError: If `rate_limiter_timeout_ms < -1`.
        """
        tree.assert_same_structure(dtypes, shapes, False)
        if max_in_flight_samples_per_worker < 1:
            raise ValueError(
                'max_in_flight_samples_per_worker (%d) must be a positive integer'
                % max_in_flight_samples_per_worker)
        if num_workers_per_iterator < 1 and num_workers_per_iterator != -1:
            raise ValueError(
                'num_workers_per_iterator (%d) must be a positive integer or -1'
                % num_workers_per_iterator)
        if max_samples_per_stream < 1 and max_samples_per_stream != -1:
            raise ValueError(
                'max_samples_per_stream (%d) must be a positive integer or -1'
                % max_samples_per_stream)
        if sequence_length is not None and sequence_length < 1:
            raise ValueError(
                'sequence_length (%s) must be None or a positive integer' %
                sequence_length)
        if rate_limiter_timeout_ms < -1:
            raise ValueError(
                'rate_limiter_timeout_ms (%d) must be an integer >= -1' %
                rate_limiter_timeout_ms)

        # Add the info fields. They are scalars when timesteps are emitted and
        # vectors of length `sequence_length` when whole sequences are emitted.
        info_dims = [] if emit_timesteps else [sequence_length]
        dtypes = replay_sample.ReplaySample(
            replay_sample.SampleInfo.tf_dtypes(), dtypes)
        shapes = replay_sample.ReplaySample(
            replay_sample.SampleInfo(
                tf.TensorShape(info_dims),
                tf.TensorShape(info_dims),
                tf.TensorShape(info_dims),
                tf.TensorShape(info_dims)), shapes)

        # If sequences are to be emitted then all shapes must use
        # sequence_length as their leading (batch) dimension.
        if not emit_timesteps:

            def _validate_batch_dim(path: tuple, shape: tf.TensorShape):
                if (not shape.ndims or
                        tf.compat.dimension_value(shape[0]) != sequence_length):
                    # `path` holds the keys/indices leading to this leaf. It
                    # is empty when `shapes` is a single leaf, so indexing
                    # into it would raise IndexError.
                    location = '/'.join(str(p) for p in path) or '<root>'
                    raise ValueError(
                        'All items in shapes must use sequence_length (%s) as '
                        'the leading dimension, but "%s" has shape %s' %
                        (sequence_length, location, shape))

            tree.map_structure_with_path(_validate_batch_dim, shapes.data)

        # The tf.data API doesn't fully support lists so we convert all uses of
        # lists into tuples.
        dtypes = _convert_lists_to_tuples(dtypes)
        shapes = _convert_lists_to_tuples(shapes)

        self._server_address = server_address
        self._table = table
        self._dtypes = dtypes
        self._shapes = shapes
        self._sequence_length = sequence_length
        self._emit_timesteps = emit_timesteps
        self._max_in_flight_samples_per_worker = max_in_flight_samples_per_worker
        self._num_workers_per_iterator = num_workers_per_iterator
        self._max_samples_per_stream = max_samples_per_stream
        self._rate_limiter_timeout_ms = rate_limiter_timeout_ms

        if _is_tf1_runtime():
            # Disabling to avoid errors given the different tf.data.Dataset init args
            # between v1 and v2 APIs.
            # pytype: disable=wrong-arg-count
            super().__init__()
        else:
            # DatasetV2 requires the dataset as a variant tensor during init.
            super().__init__(self._as_variant_tensor())