示例#1
0
    def end_episode(self,
                    clear_buffers: bool = True,
                    timeout_ms: Optional[int] = None):
        """Flush all pending items and generate a new episode ID.

        Configurations that are conditioned to only be applied on episode end
        are applied (assuming all other conditions are fulfilled) and the items
        inserted before flush is called.

        Args:
          clear_buffers: Whether the history should be cleared or not. Buffers
            should only not be cleared when trajectories spanning multiple
            episodes are used.
          timeout_ms: (optional, default is no timeout) Maximum time to block
            for before unblocking and raising a `DeadlineExceededError` instead.
            Note that although the block is interrupted, the buffers and episode
            ID are reset all the same and the insertion of the items will
            proceed in the background thread.

        Raises:
          DeadlineExceededError: If operation did not complete before the
            timeout.
          RuntimeError: Any other error reported by the underlying writer is
            re-raised unchanged.
        """
        try:
            self._writer.EndEpisode(clear_buffers, timeout_ms)
        except RuntimeError as e:
            # The underlying writer reports timeouts as a plain RuntimeError.
            # Only translate it when the caller actually asked for a timeout,
            # and chain the original error so its details stay visible.
            if 'Timeout exceeded' in str(e) and timeout_ms is not None:
                raise errors.DeadlineExceededError(
                    'End episode call did not complete within provided timeout of '
                    f'{datetime.timedelta(milliseconds=timeout_ms)}') from e
            raise
示例#2
0
    def server_info(
            self,
            timeout: Optional[int] = None
    ) -> Dict[str, reverb_types.TableInfo]:
        """Get table metadata information.

        Args:
          timeout: Timeout in seconds to wait for server response. By default no
            deadline is set and call will block indefinitely until server
            responds.

        Returns:
          A dictionary mapping table names to their associated `TableInfo`
          instances, which contain metadata about the table.

        Raises:
          errors.DeadlineExceededError: If timeout provided and exceeded.
        """
        try:
            # `None` is forwarded as 0, which per the docstring means "no
            # deadline". NOTE(review): an explicit `timeout=0` is therefore
            # indistinguishable from no timeout — confirm this is intended.
            info_proto_strings = self._client.ServerInfo(timeout or 0)
        except RuntimeError as e:
            if 'Deadline Exceeded' in str(e) and timeout is not None:
                raise errors.DeadlineExceededError(
                    'ServerInfo call did not complete within provided timeout of '
                    f'{timeout}s') from e
            raise

        # Parse lazily so each serialized proto is decoded exactly once; if two
        # entries share a name, the later one wins (same as sequential insert).
        parsed_infos = (
            reverb_types.TableInfo.from_serialized_proto(proto_string)
            for proto_string in info_proto_strings)
        return {table_info.name: table_info for table_info in parsed_infos}
示例#3
0
    def flush(self,
              block_until_num_items: int = 0,
              timeout_ms: Optional[int] = None):
        """Block until all but `block_until_num_items` confirmed by the server.

        There are two ways that an item could be "pending":

          1. Some of the data elements referenced by the item have not yet been
             finalized (and compressed) as a `ChunkData`.
          2. The item has been written to the gRPC stream but the response
             confirming the insertion has not yet been received.

        Type 1 pending items are transformed into type 2 when flush is called by
        forcing (premature) chunk finalization of the data elements referenced
        by the items. This will allow the background worker to write the data
        and items to the gRPC stream and turn them into type 2 pending items.

        The time it takes for type 2 pending items to be confirmed is primarily
        due to the state of the table rate limiter. After the items have been
        written to the gRPC stream then all we can do is wait (GIL is not held).

        Args:
          block_until_num_items: If > 0 then this many pending items will be
            allowed to remain as type 1. If the number of type 1 pending items
            is less than `block_until_num_items` then we simply wait until the
            total number of pending items is <= `block_until_num_items`.
          timeout_ms: (optional, default is no timeout) Maximum time to block
            for before unblocking and raising a `DeadlineExceededError` instead.
            Note that although the block is interrupted, the insertion of the
            items will proceed in the background.

        Raises:
          ValueError: If block_until_num_items < 0.
          DeadlineExceededError: If operation did not complete before the
            timeout.
        """
        if block_until_num_items < 0:
            raise ValueError(
                f'block_until_num_items must be >= 0, got {block_until_num_items}'
            )

        # The underlying writer is given -1 when no timeout was requested. The
        # caller's `timeout_ms` is kept as-is so that the handler below only
        # translates timeouts the caller actually asked for, and reports the
        # real requested duration.
        writer_timeout_ms = -1 if timeout_ms is None else timeout_ms

        try:
            self._writer.Flush(block_until_num_items, writer_timeout_ms)
        except RuntimeError as e:
            if 'Timeout exceeded' in str(e) and timeout_ms is not None:
                raise errors.DeadlineExceededError(
                    'Flush call did not complete within provided timeout of '
                    f'{datetime.timedelta(milliseconds=timeout_ms)}') from e
            raise
示例#4
0
    def server_info(
            self,
            timeout: Optional[int] = None
    ) -> Dict[str, reverb_types.TableInfo]:
        """Get table metadata information.

        Args:
          timeout: Timeout in seconds to wait for server response. By default no
            deadline is set and call will block indefinitely until server
            responds.

        Returns:
          A dictionary mapping table names to their associated `TableInfo`
          instances, which contain metadata about the table.

        Raises:
          errors.DeadlineExceededError: If timeout provided and exceeded.
        """
        try:
            # `None` is forwarded as 0, which per the docstring means "no
            # deadline". NOTE(review): an explicit `timeout=0` is therefore
            # indistinguishable from no timeout — confirm this is intended.
            info_proto_strings = self._client.ServerInfo(timeout or 0)
        except RuntimeError as e:
            if 'Deadline Exceeded' in str(e) and timeout is not None:
                raise errors.DeadlineExceededError(
                    'ServerInfo call did not complete within provided timeout of '
                    f'{timeout}s') from e
            raise

        table_infos = {}
        for proto_string in info_proto_strings:
            proto = schema_pb2.TableInfo.FromString(proto_string)
            # The raw `signature` proto field is replaced by its decoded
            # structure, or None when the field is not set.
            if proto.HasField('signature'):
                signature = nested_structure_coder.StructureCoder(
                ).decode_proto(proto.signature)
            else:
                signature = None
            # Copy every declared proto field into a kwargs dict for TableInfo.
            info_kwargs = {
                descr.name: getattr(proto, descr.name)
                for descr in proto.DESCRIPTOR.fields
            }
            info_kwargs['signature'] = signature
            name = str(info_kwargs['name'])
            table_infos[name] = reverb_types.TableInfo(**info_kwargs)
        return table_infos
示例#5
0
  def test_exit_does_not_flush_on_reverb_error(self):
    """Checks when leaving the writer's context triggers a `flush`.

    A clean exit and an exit caused by a non-Reverb exception must both flush
    exactly once, whereas an exit caused by a Reverb error must not flush.
    """
    # Clean exit: `flush` is expected to be called exactly once.
    with mock.patch.object(self.writer, 'flush') as patched_flush:
      with self.writer:
        pass
      patched_flush.assert_called_once()

    # An unrelated (non-Reverb) exception should still lead to a flush.
    with mock.patch.object(self.writer, 'flush') as patched_flush:
      with self.assertRaises(ValueError), self.writer:
        raise ValueError('Test')
      patched_flush.assert_called_once()

    # A Reverb error (a `DeadlineExceededError`, which this test expects to be
    # caught as `errors.ReverbError`) must skip the flush.
    with mock.patch.object(self.writer, 'flush') as patched_flush:
      with self.assertRaises(errors.ReverbError), self.writer:
        raise errors.DeadlineExceededError('Test')
      patched_flush.assert_not_called()