Beispiel #1
0
 def test_server_info_timeout(self):
   """Checks that `server_info` raises when its deadline expires."""
   # Target the port right after the test server's one; the client is not
   # expected to actually connect to anything there.
   unreachable_address = f'localhost:{self.server.port + 1}'
   dummy_client = client.Client(unreachable_address)

   expect_deadline_error = self.assertRaises(
       errors.DeadlineExceededError,
       msg='ServerInfo call did not complete within provided timeout of 1s')
   with expect_deadline_error:
     dummy_client.server_info(timeout=1)
  def test_numpy_squeeze(self):
    """Checks that the latest `history` entry equals the appended value."""
    # The address is only a placeholder: this test never sends any data to a
    # server, so it does not need to be valid.
    unused_client = client_lib.Client('localhost:1234')
    writer = trajectory_writer.TrajectoryWriter(unused_client, 5, 10)

    for value in range(10):
      writer.append({'a': value})
      latest = writer.history['a'][-1]
      self.assertEqual(latest.numpy(), value)
Beispiel #3
0
  def in_process_client(self):
    """Returns a client that talks to this server from within the process.

    Proto serialization and network overhead are bypassed.

    Returns:
      Client. It must not be used once this ReplayServer has been stopped!
    """
    address = f'[::1]:{self._port}'
    local_impl = self._server.InProcessClient()
    return client.Client(address, local_impl)
Beispiel #4
0
 def test_server_info_timeout(self):
   """Checks that `server_info` raises when its deadline expires.

   The client is pointed at a port reserved through `portpicker`, so it does
   not actually connect to anything. The port is handed back afterwards.
   """
   # Reserve the port before entering the `try`. If picking a port fails
   # there is nothing to return, and a `finally` that touched the unbound
   # `dummy_port` would raise and hide the real error.
   dummy_port = portpicker.pick_unused_port()
   try:
     dummy_client = client.Client(f'localhost:{dummy_port}')
     with self.assertRaises(
         errors.DeadlineExceededError,
         msg='ServerInfo call did not complete within provided timeout of 1s'):
       dummy_client.server_info(timeout=1)
   finally:
     portpicker.return_port(dummy_port)
Beispiel #5
0
  def test_numpy_squeeze(self):
    """Checks that the latest `history` entry equals the appended value."""
    # The address is only a placeholder: this test never sends any data to a
    # server, so it does not need to be valid.
    unused_client = client_lib.Client('localhost:1234')
    writer = trajectory_writer.TrajectoryWriter(
        unused_client,
        max_chunk_length=5,
        num_keep_alive_refs=10,
        get_signature_timeout_ms=None)

    for value in range(10):
      writer.append({'a': value})
      latest = writer.history['a'][-1]
      self.assertEqual(latest.numpy(), value)
Beispiel #6
0
 def setUpClass(cls):
     """Creates the class-level server (one table) and a client pointed at it."""
     super().setUpClass()
     table = server.Table(
         name=TABLE_NAME,
         sampler=item_selectors.Prioritized(1),
         remover=item_selectors.Fifo(),
         max_size=1000,
         rate_limiter=rate_limiters.MinSize(3),
         signature=tf.TensorSpec(dtype=tf.int64, shape=()),
     )
     cls.server = server.Server(tables=[table], port=None)
     cls.client = client.Client(f'localhost:{cls.server.port}')
  def test_episode_steps_reset_on_end_episode(self, clear_buffers: bool):
    """Checks that `end_episode` resets `episode_steps` to zero."""
    server = server_lib.Server([server_lib.Table.queue('queue', 1)])
    test_client = client_lib.Client(f'localhost:{server.port}')
    writer = test_client.trajectory_writer(num_keep_alive_refs=1)

    # A freshly created writer reports zero steps.
    self.assertEqual(writer.episode_steps, 0)

    # One appended step moves the counter to one.
    writer.append([1])
    self.assertEqual(writer.episode_steps, 1)

    # Ending the episode brings the counter back to zero.
    writer.end_episode(clear_buffers=clear_buffers)
    self.assertEqual(writer.episode_steps, 0)
  def test_numpy(self):
    """Checks that `history` columns convert to arrays of all steps so far.

    After each append, column 'a' is compared against `arange(i + 1)` as int64
    and column 'b' against the stack of the 3x3 arrays appended so far.
    """
    # No data will ever be sent to the server so it doesn't matter that we use
    # an invalid address.
    client = client_lib.Client('localhost:1234')
    writer = trajectory_writer.TrajectoryWriter(client, 5, 10)

    for i in range(10):
      # `np.float` was only an alias of the builtin `float` and has been
      # removed from NumPy (1.24+); `np.float64` is the same dtype.
      writer.append({'a': i, 'b': np.ones([3, 3], np.float64) * i})

      np.testing.assert_array_equal(writer.history['a'][:].numpy(),
                                    np.arange(i + 1, dtype=np.int64))

      np.testing.assert_array_equal(
          writer.history['b'][:].numpy(),
          np.stack([np.ones([3, 3], np.float64) * x for x in range(i + 1)]))
  def test_timeout_on_end_episode(self):
    """Checks that `end_episode` raises `DeadlineExceededError` on timeout."""
    server = server_lib.Server([server_lib.Table.queue('queue', 1)])
    test_client = client_lib.Client(f'localhost:{server.port}')
    writer = test_client.trajectory_writer(num_keep_alive_refs=1)
    writer.append([1])

    # The table holds a single item and up to 2 more can wait in the table
    # worker queues, so 4 items cannot all fit and end_episode should time out.
    num_items = 4
    with self.assertRaises(errors.DeadlineExceededError):
      for _ in range(num_items):
        trajectory = writer.history[0][:]
        writer.create_item('queue', 1.0, trajectory)
        writer.end_episode(clear_buffers=False, timeout_ms=1)

    writer.close()
    server.stop()
  def test_episode_steps(self):
    """Checks that `episode_steps` counts appends and resets every episode."""
    server = server_lib.Server([server_lib.Table.queue('queue', 1)])
    test_client = client_lib.Client(f'localhost:{server.port}')
    writer = test_client.trajectory_writer(num_keep_alive_refs=1)

    num_episodes, steps_per_episode = 10, 20
    for _ in range(num_episodes):
      # Each episode, the very first one included, begins with a zero count.
      self.assertEqual(writer.episode_steps, 0)

      for steps_done in range(steps_per_episode):
        writer.append({'x': 3, 'y': 2})

        # Every append call adds one to the step count.
        self.assertEqual(writer.episode_steps, steps_done + 1)

      # Ending the episode is what sets the count back to zero.
      writer.end_episode()
  def test_episode_steps_partial_step(self):
    """Checks that partial appends leave `episode_steps` unchanged."""
    server = server_lib.Server([server_lib.Table.queue('queue', 1)])
    test_client = client_lib.Client(f'localhost:{server.port}')
    writer = test_client.trajectory_writer(num_keep_alive_refs=1)

    num_episodes, steps_per_episode = 3, 3
    for _ in range(num_episodes):
      # Each episode, the very first one included, begins with a zero count.
      self.assertEqual(writer.episode_steps, 0)

      for completed in range(steps_per_episode):
        writer.append({'x': 3}, partial_step=True)

        # A partial append must leave the step count where it was.
        self.assertEqual(writer.episode_steps, completed)

        writer.append({'y': 2})

        # The unqualified append call is the one that bumps the count.
        self.assertEqual(writer.episode_steps, completed + 1)

      # Ending the episode is what sets the count back to zero.
      writer.end_episode()
Beispiel #12
0
 def localhost_client(self) -> client.Client:
     """Returns a new client addressed at this object's port on localhost."""
     address = f'localhost:{self._port}'
     return client.Client(address)
Beispiel #13
0
    def from_table_signature(
        cls,
        server_address: str,
        table: str,
        max_in_flight_samples_per_worker: int,
        num_workers_per_iterator: int = -1,
        max_samples_per_stream: int = -1,
        sequence_length: Optional[int] = None,
        emit_timesteps: bool = True,
        rate_limiter_timeout_ms: int = -1,
        get_signature_timeout_secs: Optional[int] = None,
    ):
        """Builds a ReplayDataset whose specs come from the table's signature.

        Note: The signature must be provided to `Table` at construction. See
        `Table.__init__` (./server.py) for more details.

        Args:
          server_address: Address of gRPC ReverbService.
          table: Table to read the signature and sample from.
          max_in_flight_samples_per_worker: See __init__ for details.
          num_workers_per_iterator: See __init__ for details.
          max_samples_per_stream: See __init__ for details.
          sequence_length: See __init__ for details.
          emit_timesteps: See __init__ for details.
          rate_limiter_timeout_ms: See __init__ for details.
          get_signature_timeout_secs: Timeout in seconds to wait for server to
            respond when fetching the table signature. By default no timeout
            is set and the call will block indefinitely if the server does not
            respond.

        Returns:
          ReplayDataset using the specs defined by the table signature to
            build `shapes` and `dtypes`.

        Raises:
          ValueError: If `table` does not exist on server at `server_address`.
          ValueError: If `table` does not have a signature.
          errors.DeadlineExceededError: If `get_signature_timeout_secs`
            provided and exceeded.
          ValueError: See __init__.
        """
        signature_client = reverb_client.Client(server_address)
        info = signature_client.server_info(get_signature_timeout_secs)

        if table not in info:
            raise ValueError(
                f'Server at {server_address} does not contain any table named '
                f'{table}. Found: {", ".join(sorted(info.keys()))}.')

        signature = info[table].signature
        if not signature:
            raise ValueError(
                f'Table {table} at {server_address} does not have a signature.'
            )

        shapes = tree.map_structure(lambda spec: spec.shape, signature)
        dtypes = tree.map_structure(lambda spec: spec.dtype, signature)

        if not emit_timesteps:
            # Whole sequences are emitted, so every shape gets a leading
            # `sequence_length` dimension.
            time_dim = tf.TensorShape([sequence_length])
            shapes = tree.map_structure(time_dim.concatenate, shapes)

        dataset_kwargs = dict(
            server_address=server_address,
            table=table,
            shapes=shapes,
            dtypes=dtypes,
            max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
            num_workers_per_iterator=num_workers_per_iterator,
            max_samples_per_stream=max_samples_per_stream,
            sequence_length=sequence_length,
            emit_timesteps=emit_timesteps,
            rate_limiter_timeout_ms=rate_limiter_timeout_ms,
        )
        return cls(**dataset_kwargs)
 def setUp(self):
     """Gives each test a fresh client pointed at the shared server's port."""
     super().setUp()
     address = f'localhost:{self._server.port}'
     self.client = client_lib.Client(address)
Beispiel #15
0
  def from_table_signature(cls,
                           server_address: str,
                           table: str,
                           max_in_flight_samples_per_worker: int,
                           num_workers_per_iterator: int = -1,
                           max_samples_per_stream: int = -1,
                           rate_limiter_timeout_ms: int = -1,
                           get_signature_timeout_secs: Optional[int] = None,
                           flexible_batch_size: int = -1):
    """Builds a TimestepDataset whose specs come from the table's signature.

    Note: The target `Table` must specify a signature that represents a single
      timestep (as opposed to an entire trajectory). See `Table.__init__`
      (./server.py) for more details.

    Args:
      server_address: Address of gRPC ReverbService.
      table: Table to read the signature and sample from.
      max_in_flight_samples_per_worker: See __init__ for details.
      num_workers_per_iterator: See __init__ for details.
      max_samples_per_stream: See __init__ for details.
      rate_limiter_timeout_ms: See __init__ for details.
      get_signature_timeout_secs: Timeout in seconds to wait for server to
        respond when fetching the table signature. By default no timeout is set
        and the call will block indefinitely if the server does not respond.
      flexible_batch_size: See __init__ for details.

    Returns:
      TimestepDataset using the specs defined by the table signature to build
        `shapes` and `dtypes`.

    Raises:
      ValueError: If `table` does not exist on server at `server_address`.
      ValueError: If `table` does not have a signature.
      errors.DeadlineExceededError: If `get_signature_timeout_secs` provided and
        exceeded.
      ValueError: See __init__.
    """
    signature_client = reverb_client.Client(server_address)
    info = signature_client.server_info(get_signature_timeout_secs)

    if table not in info:
      raise ValueError(
          f'Server at {server_address} does not contain any table named '
          f'{table}. Found: {", ".join(sorted(info.keys()))}.')

    signature = info[table].signature
    if not signature:
      raise ValueError(
          f'Table {table} at {server_address} does not have a signature.')

    # Shapes and dtypes mirror the nested structure of the signature.
    shapes = tree.map_structure(lambda spec: spec.shape, signature)
    dtypes = tree.map_structure(lambda spec: spec.dtype, signature)

    dataset_kwargs = dict(
        server_address=server_address,
        table=table,
        shapes=shapes,
        dtypes=dtypes,
        max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
        num_workers_per_iterator=num_workers_per_iterator,
        max_samples_per_stream=max_samples_per_stream,
        rate_limiter_timeout_ms=rate_limiter_timeout_ms,
        flexible_batch_size=flexible_batch_size,
    )
    return cls(**dataset_kwargs)
Beispiel #16
0
 def setUpClass(cls):
     """Creates the class-level server and a client pointed at its port."""
     super().setUpClass()
     started_server = make_server()
     cls._server = started_server
     cls._client = client.Client(f'localhost:{started_server.port}')
Beispiel #17
0
 def setUpClass(cls):
     """Creates the class-level tables, server and a client for that server."""
     super().setUpClass()
     tables, started_server = make_tables_and_server()
     cls._tables = tables
     cls._server = started_server
     cls._client = reverb_client.Client(f'localhost:{started_server.port}')