def sample(self,
           table: str,
           data_dtypes,
           name: Optional[str] = None) -> replay_sample.ReplaySample:
  """Samples a single item that has a data field from the replay.

  Items without a data field cannot be sampled through this method.

  Args:
    table: Probability table to sample from.
    data_dtypes: Dtypes of the data output. Can be nested.
    name: Optional name for the Client operations.

  Returns:
    A ReplaySample whose `data` is nested according to `data_dtypes`. See
    ReplaySample for more details.
  """
  # The op works on flat dtype lists; the nesting is restored afterwards.
  flat_dtypes = tree.flatten(data_dtypes)
  default_scope_name = f'{self._name}_sample'
  with tf.name_scope(name, default_scope_name, ['sample']) as scope:
    op_outputs = gen_client_ops.reverb_client_sample(
        self._handle, table, flat_dtypes, name=scope)
    key, probability, table_size, priority, flat_data = op_outputs
    info = replay_sample.SampleInfo(
        key=key,
        probability=probability,
        table_size=table_size,
        priority=priority)
    nested_data = tree.unflatten_as(data_dtypes, flat_data)
    return replay_sample.ReplaySample(info, nested_data)
def sample_trajectory(client: client_lib.Client, table: str,
                      structure: Any) -> replay_sample.ReplaySample:
  """Temporary helper method for sampling a trajectory.

  Note! This function is only intended to make it easier for alpha testers to
  experiment with the new API. It will be removed before this file is made
  public.

  Args:
    client: Client connected to the server to sample from.
    table: Name of the table to sample from.
    structure: Structure to unpack flat data as.

  Returns:
    ReplaySample with trajectory unpacked as `structure` in `data`-field.
  """
  # Reaches into the client's private handle to create a sampler directly.
  sampler = client._client.NewSampler(table, 1, 1, 1)  # pylint: disable=protected-access
  flat_sample = sampler.GetNextSample()

  # The first four entries hold the sample info; only their first element is
  # read. Everything after them is the flat trajectory data.
  key = int(flat_sample[0][0])
  probability = float(flat_sample[1][0])
  table_size = int(flat_sample[2][0])
  priority = float(flat_sample[3][0])
  info = replay_sample.SampleInfo(
      key=key,
      probability=probability,
      table_size=table_size,
      priority=priority)

  unpacked_data = tree.unflatten_as(structure, flat_sample[4:])
  return replay_sample.ReplaySample(info=info, data=unpacked_data)
def sample(
    self, table: str, num_samples=1
) -> Generator[List[replay_sample.ReplaySample], None, None]:
  """Samples `num_samples` items from table `table` of the Server.

  NOTE: This method should NOT be used for real training. TFClient (see
  tf_client.py) has far superior performance and should always be preferred.

  Note: If data was written using `insert` (e.g when inserting complete
  trajectories) then the returned "sequence" will be a list of length 1
  containing the trajectory as a single item.

  If `num_samples` is greater than the number of items in `table`, (or a rate
  limiter is used to control sampling), then the returned generator will
  block when an item past the sampling limit is requested. It will unblock
  when sufficient additional items have been added to `table`.

  Example:

  ```python
  server = Server(..., tables=[queue("queue", ...)])
  client = Client(...)

  # Don't insert anything into "queue"
  generator = client.sample("queue")
  next(generator)  # Blocks until another thread/process writes to queue.
  ```

  Args:
    table: Name of the priority table to sample from.
    num_samples: (Defaults to 1) The number of samples to fetch.

  Yields:
    Lists of timesteps (lists of instances of `ReplaySample`).
    If data was inserted into the table via `insert`, then each element
    of the generator is a length 1 list containing a `ReplaySample`.
    If data was inserted via a writer, then each element is a list whose
    length is the sampled trajectory's length.
  """
  sampler = self._client.NewSampler(table, num_samples, 1)
  for _ in range(num_samples):
    sequence = []
    last = False
    # The sampler hands out one timestep per call; the sequence is complete
    # once a step is returned with `last` set.
    while not last:
      step, last = sampler.GetNextTimestep()
      # The first four entries are the sample info; the rest is step data.
      info = replay_sample.SampleInfo(
          int(step[0]), float(step[1]), int(step[2]), float(step[3]))
      sequence.append(replay_sample.ReplaySample(info=info, data=step[4:]))
    yield sequence
def test_sample_variable_length_trajectory(self):
  # Write 10 items: 'last' references the newest step and 'all' references
  # every step appended so far, so 'all' has a different length per item.
  with self._client.trajectory_writer(10) as writer:
    for step_value in range(10):
      writer.append([np.ones([3, 3], np.int32) * step_value])
      writer.create_item(TABLE, 1.0, {
          'last': writer.history[0][-1],
          'all': writer.history[0][:],
      })

  dtypes = {'last': tf.int32, 'all': tf.int32}
  shapes = {
      'last': tf.TensorShape([3, 3]),
      'all': tf.TensorShape([None, 3, 3]),
  }
  dataset = trajectory_dataset.TrajectoryDataset(
      tf.constant(self._client.server_address),
      table=tf.constant(TABLE),
      dtypes=dtypes,
      shapes=shapes,
      max_in_flight_samples_per_worker=1,
      flexible_batch_size=1)

  # Every sample must share this structure regardless of its length.
  expected_structure = replay_sample.ReplaySample(
      info=replay_sample.SampleInfo(
          key=1,
          probability=1.0,
          table_size=10,
          priority=0.5,
      ),
      data={'last': None, 'all': None})

  # Keep sampling until trajectories of every length have been observed.
  seen_lengths = set()
  while len(seen_lengths) < 10:
    sampled = self._sample_from(dataset, 1)[0]
    tree.assert_same_structure(sampled, expected_structure)
    seen_lengths.add(sampled.data['all'].shape[0])

  self.assertEqual(seen_lengths, set(range(1, 11)))
def test_sample_fixed_length_trajectory(self):
  self._populate_replay()

  dataset = trajectory_dataset.TrajectoryDataset(
      tf.constant(self._client.server_address),
      table=tf.constant(TABLE),
      dtypes=DTYPES,
      shapes=SHAPES,
      max_in_flight_samples_per_worker=1,
      flexible_batch_size=1)

  # Only the nesting is compared, so the leaf values below are placeholders.
  first_sample = self._sample_from(dataset, 1)[0]
  expected_info = replay_sample.SampleInfo(
      key=1,
      probability=1.0,
      table_size=10,
      priority=0.5,
  )
  expected = replay_sample.ReplaySample(info=expected_info, data=SHAPES)
  tree.assert_same_structure(first_sample, expected)
def __init__(self,
             server_address: Union[str, tf.Tensor],
             table: Union[str, tf.Tensor],
             dtypes: Any,
             shapes: Any,
             max_in_flight_samples_per_worker: int,
             num_workers_per_iterator: int = -1,
             max_samples_per_stream: int = -1,
             rate_limiter_timeout_ms: int = -1,
             flexible_batch_size: int = -1):
  """Constructs a new TimestepDataset.

  Args:
    server_address: Address of gRPC ReverbService.
    table: Probability table to sample from.
    dtypes: Dtypes of the data output. Can be nested.
    shapes: Shapes of the data output. Can be nested.
    max_in_flight_samples_per_worker: The number of samples requested in each
      batch of samples. Higher values give higher throughput but too big
      values can result in skewed sampling distributions as a large number of
      samples are fetched from a single snapshot of the replay (followed by a
      period of lower activity as the samples are consumed). A good rule of
      thumb is to set this value to 2-3x the batch size used.
    num_workers_per_iterator: (Defaults to -1, i.e auto selected) The number
      of worker threads to create per dataset iterator. When the selected
      table uses a FIFO sampler (i.e a queue) then exactly 1 worker must be
      used to avoid races causing invalid ordering of items. For all other
      samplers, this value should be roughly equal to the number of threads
      available on the CPU.
    max_samples_per_stream: (Defaults to -1, i.e auto selected) The maximum
      number of samples to fetch from a stream before a new call is made.
      Keeping this number low ensures that the data is fetched uniformly from
      all servers.
    rate_limiter_timeout_ms: (Defaults to -1: infinite). Timeout (in
      milliseconds) to wait on the rate limiter when sampling from the table.
      If `rate_limiter_timeout_ms >= 0`, this is the timeout passed to
      `Table::Sample` describing how long to wait for the rate limiter to
      allow sampling. The first time that a request times out (across any of
      the workers), the Dataset iterator is closed and the sequence is
      considered finished.
    flexible_batch_size: (Defaults to -1: auto selected) The maximum number
      of items to sample from `Table` in a single call. Values > 1 enable
      `Table::SampleFlexibleBatch` to return more than one item (but no more
      than `flexible_batch_size`) in a single call without releasing the
      table lock iff the rate limiter allows it. NOTE! It is unlikely that you
      need to tune this value yourself. The auto selected value should almost
      always be preferred. Larger `flexible_batch_size` values result in a
      bias towards sampling over inserts. In highly overloaded systems this
      results in higher sample QPS and lower insert QPS compared to lower
      `flexible_batch_size` values.

  Raises:
    ValueError: If `dtypes` and `shapes` don't share the same structure.
    ValueError: If `max_in_flight_samples_per_worker` is not a
      positive integer.
    ValueError: If `num_workers_per_iterator` is not a positive integer or -1.
    ValueError: If `max_samples_per_stream` is not a positive integer or -1.
    ValueError: If `rate_limiter_timeout_ms < -1`.
    ValueError: If `flexible_batch_size` is not a positive integer or -1.
  """
  # The third argument is dm-tree's `check_types` flag; passing False means
  # only the nesting has to match, not the exact container types.
  tree.assert_same_structure(dtypes, shapes, False)
  if max_in_flight_samples_per_worker < 1:
    raise ValueError(
        'max_in_flight_samples_per_worker (%d) must be a positive integer' %
        max_in_flight_samples_per_worker)
  if num_workers_per_iterator < 1 and num_workers_per_iterator != -1:
    raise ValueError(
        'num_workers_per_iterator (%d) must be a positive integer or -1' %
        num_workers_per_iterator)
  if max_samples_per_stream < 1 and max_samples_per_stream != -1:
    raise ValueError(
        'max_samples_per_stream (%d) must be a positive integer or -1' %
        max_samples_per_stream)
  if rate_limiter_timeout_ms < -1:
    raise ValueError('rate_limiter_timeout_ms (%d) must be an integer >= -1' %
                     rate_limiter_timeout_ms)
  if flexible_batch_size < 1 and flexible_batch_size != -1:
    raise ValueError(
        'flexible_batch_size (%d) must be a positive integer or -1' %
        flexible_batch_size)

  # Add the info fields (all scalars).
  dtypes = replay_sample.ReplaySample(
      info=replay_sample.SampleInfo.tf_dtypes(), data=dtypes)
  shapes = replay_sample.ReplaySample(
      info=replay_sample.SampleInfo(
          key=tf.TensorShape([]),
          probability=tf.TensorShape([]),
          table_size=tf.TensorShape([]),
          priority=tf.TensorShape([])),
      data=shapes)

  # The tf.data API doesn't fully support lists so we convert all uses of
  # lists into tuples.
  dtypes = _convert_lists_to_tuples(dtypes)
  shapes = _convert_lists_to_tuples(shapes)

  self._server_address = server_address
  self._table = table
  self._dtypes = dtypes
  self._shapes = shapes
  self._max_in_flight_samples_per_worker = max_in_flight_samples_per_worker
  self._num_workers_per_iterator = num_workers_per_iterator
  self._max_samples_per_stream = max_samples_per_stream
  self._rate_limiter_timeout_ms = rate_limiter_timeout_ms
  self._flexible_batch_size = flexible_batch_size

  if _is_tf1_runtime():
    # Disabling to avoid errors given the different tf.data.Dataset init args
    # between v1 and v2 APIs.
    # pytype: disable=wrong-arg-count
    super().__init__()
  else:
    # DatasetV2 requires the dataset as a variant tensor during init.
    super().__init__(self._as_variant_tensor())
def sample(
    self,
    table: str,
    num_samples: int = 1,
    *,
    emit_timesteps: bool = True,
    unpack_as_table_signature: bool = False,
) -> Generator[Union[List[replay_sample.ReplaySample],
                     replay_sample.ReplaySample], None, None]:
  """Samples `num_samples` items from table `table` of the Server.

  NOTE: This method should NOT be used for real training. TrajectoryDataset
  and TimestepDataset should always be preferred over this method.

  Note: If data was written using `insert` (e.g when inserting complete
  trajectories) then the returned "sequence" will be a list of length 1
  containing the trajectory as a single item.

  If `num_samples` is greater than the number of items in `table`, (or a rate
  limiter is used to control sampling), then the returned generator will
  block when an item past the sampling limit is requested. It will unblock
  when sufficient additional items have been added to `table`.

  Example:

  ```python
  server = Server(..., tables=[queue("queue", ...)])
  client = Client(...)

  # Don't insert anything into "queue"
  generator = client.sample("queue")
  next(generator)  # Blocks until another thread/process writes to queue.
  ```

  Args:
    table: Name of the priority table to sample from.
    num_samples: (Defaults to 1) The number of samples to fetch.
    emit_timesteps: If True then trajectories are returned as a list of
      `ReplaySample`, each representing a single step within the trajectory.
      If False then each trajectory is returned as a single `ReplaySample`.
    unpack_as_table_signature: If True then the sampled data is unpacked
      according to the structure of the table signature. If the table does
      not have a signature then flat data is returned.

  Yields:
    If `emit_timesteps` is `True`:

      Lists of timesteps (lists of instances of `ReplaySample`).
      If data was inserted into the table via `insert`, then each element
      of the generator is a length 1 list containing a `ReplaySample`.
      If data was inserted via a writer, then each element is a list whose
      length is the sampled trajectory's length.

    If `emit_timesteps` is `False`:

      An instance of `ReplaySample` holding the whole trajectory. Its data is
      unpacked according to the signature of the table when
      `unpack_as_table_signature` is True and the table has a signature;
      otherwise the data is flat, i.e each element is a leaf node of the
      full trajectory.

  Raises:
    ValueError: If `emit_timesteps` is True but the trajectory cannot be
      decomposed into timesteps (i.e. its columns do not all have the same
      length).
  """
  buffer_size = 1

  signature = (
      self._get_signature_for_table(table)
      if unpack_as_table_signature else None)
  if signature:
    unflatten = lambda flat: tree.unflatten_as(signature, flat)
  else:
    unflatten = lambda flat: flat

  sampler = self._client.NewSampler(table, num_samples, buffer_size)
  for _ in range(num_samples):
    sample = sampler.GetNextTrajectory()
    info = replay_sample.SampleInfo(
        key=int(sample[0]),
        probability=float(sample[1]),
        table_size=int(sample[2]),
        priority=float(sample[3]),
        times_sampled=int(sample[4]))
    # Everything after the SampleInfo fields is trajectory data, one entry
    # per column.
    data = sample[len(info):]

    if not emit_timesteps:
      yield replay_sample.ReplaySample(info, unflatten(data))
      continue

    # Splitting into timesteps is only possible when every column has the
    # same length (this also rejects a trajectory with no columns).
    if len({len(col) for col in data}) != 1:
      raise ValueError(
          'Can\'t split non timestep trajectory into timesteps.')

    yield [
        replay_sample.ReplaySample(
            info=info,
            data=unflatten([np.asarray(col[i], col.dtype) for col in data]))
        for i in range(data[0].shape[0])
    ]
def __init__(self,
             server_address: str,
             table: str,
             dtypes: Any,
             shapes: Any,
             max_in_flight_samples_per_worker: int,
             num_workers_per_iterator: int = -1,
             max_samples_per_stream: int = -1,
             sequence_length: Optional[int] = None,
             emit_timesteps: bool = True,
             rate_limiter_timeout_ms: int = -1):
  """Constructs a new ReplayDataset.

  Args:
    server_address: Address of gRPC ReverbService.
    table: Probability table to sample from.
    dtypes: Dtypes of the data output. Can be nested.
    shapes: Shapes of the data output. Can be nested.
    max_in_flight_samples_per_worker: The number of samples requested in each
      batch of samples. Higher values give higher throughput but too big
      values can result in skewed sampling distributions as a large number of
      samples are fetched from a single snapshot of the replay (followed by a
      period of lower activity as the samples are consumed). A good rule of
      thumb is to set this value to 2-3x the batch size used.
    num_workers_per_iterator: (Defaults to -1, i.e auto selected) The number
      of worker threads to create per dataset iterator. When the selected
      table uses a FIFO sampler (i.e a queue) then exactly 1 worker must be
      used to avoid races causing invalid ordering of items. For all other
      samplers, this value should be roughly equal to the number of threads
      available on the CPU.
    max_samples_per_stream: (Defaults to -1, i.e auto selected) The maximum
      number of samples to fetch from a stream before a new call is made.
      Keeping this number low ensures that the data is fetched uniformly from
      all servers.
    sequence_length: (Defaults to None, i.e unknown) The number of timesteps
      that each sample consists of. If set then the length of samples received
      from the server will be validated against this number.
    emit_timesteps: (Defaults to True) If set, timesteps instead of full
      sequences are returned from the dataset. Returning sequences instead of
      timesteps can be more efficient as the memcopies caused by the splitting
      and batching of tensor can be avoided. Note that if set to False then
      all `shapes` must have dim[0] equal to `sequence_length`.
    rate_limiter_timeout_ms: (Defaults to -1: infinite). Timeout (in
      milliseconds) to wait on the rate limiter when sampling from the table.
      If `rate_limiter_timeout_ms >= 0`, this is the timeout passed to
      `Table::Sample` describing how long to wait for the rate limiter to
      allow sampling. The first time that a request times out (across any of
      the workers), the Dataset iterator is closed and the sequence is
      considered finished.

  Raises:
    ValueError: If `dtypes` and `shapes` don't share the same structure.
    ValueError: If `max_in_flight_samples_per_worker` is not a
      positive integer.
    ValueError: If `num_workers_per_iterator` is not a positive integer or -1.
    ValueError: If `max_samples_per_stream` is not a positive integer or -1.
    ValueError: If `sequence_length` is not a positive integer or None.
    ValueError: If `emit_timesteps is False` and not all items in `shapes`
      have `sequence_length` as their leading dimension.
    ValueError: If `rate_limiter_timeout_ms < -1`.
  """
  # The third argument is dm-tree's `check_types` flag; passing False means
  # only the nesting has to match, not the exact container types.
  tree.assert_same_structure(dtypes, shapes, False)
  if max_in_flight_samples_per_worker < 1:
    raise ValueError(
        'max_in_flight_samples_per_worker (%d) must be a positive integer' %
        max_in_flight_samples_per_worker)
  if num_workers_per_iterator < 1 and num_workers_per_iterator != -1:
    raise ValueError(
        'num_workers_per_iterator (%d) must be a positive integer or -1' %
        num_workers_per_iterator)
  if max_samples_per_stream < 1 and max_samples_per_stream != -1:
    raise ValueError(
        'max_samples_per_stream (%d) must be a positive integer or -1' %
        max_samples_per_stream)
  if sequence_length is not None and sequence_length < 1:
    raise ValueError(
        'sequence_length (%s) must be None or a positive integer' %
        sequence_length)
  if rate_limiter_timeout_ms < -1:
    raise ValueError(
        'rate_limiter_timeout_ms (%d) must be an integer >= -1' %
        rate_limiter_timeout_ms)

  # Add the info fields. When full sequences are emitted each info field has
  # a leading `sequence_length` dimension; when timesteps are emitted they
  # are scalars.
  info_dims = [] if emit_timesteps else [sequence_length]
  dtypes = replay_sample.ReplaySample(
      replay_sample.SampleInfo.tf_dtypes(), dtypes)
  shapes = replay_sample.ReplaySample(
      replay_sample.SampleInfo(
          tf.TensorShape(info_dims),
          tf.TensorShape(info_dims),
          tf.TensorShape(info_dims),
          tf.TensorShape(info_dims)),
      shapes)

  # If sequences are to be emitted then all shapes must use sequence_length
  # as their leading (batch) dimension.
  if not emit_timesteps:

    def _validate_batch_dim(path, shape: tf.TensorShape):
      # `path` is the tuple of keys/indices leading to `shape`; it is empty
      # when `shapes` is a single, non-nested TensorShape, so it must not be
      # indexed blindly.
      if (not shape.ndims or
          tf.compat.dimension_value(shape[0]) != sequence_length):
        path_str = '/'.join(str(part) for part in path) or '<root>'
        raise ValueError(
            'All items in shapes must use sequence_length (%s) as the leading '
            'dimension, but "%s" has shape %s' %
            (sequence_length, path_str, shape))

    tree.map_structure_with_path(_validate_batch_dim, shapes.data)

  # The tf.data API doesn't fully support lists so we convert all uses of
  # lists into tuples.
  dtypes = _convert_lists_to_tuples(dtypes)
  shapes = _convert_lists_to_tuples(shapes)

  self._server_address = server_address
  self._table = table
  self._dtypes = dtypes
  self._shapes = shapes
  self._sequence_length = sequence_length
  self._emit_timesteps = emit_timesteps
  self._max_in_flight_samples_per_worker = max_in_flight_samples_per_worker
  self._num_workers_per_iterator = num_workers_per_iterator
  self._max_samples_per_stream = max_samples_per_stream
  self._rate_limiter_timeout_ms = rate_limiter_timeout_ms

  if _is_tf1_runtime():
    # Disabling to avoid errors given the different tf.data.Dataset init args
    # between v1 and v2 APIs.
    # pytype: disable=wrong-arg-count
    super().__init__()
  else:
    # DatasetV2 requires the dataset as a variant tensor during init.
    super().__init__(self._as_variant_tensor())