Example #1
def make_server():
    return server.Server(
        tables=[
            server.Table(
                'dist',
                sampler=item_selectors.Prioritized(priority_exponent=1),
                remover=item_selectors.Fifo(),
                max_size=1000000,
                rate_limiter=rate_limiters.MinSize(1)),
            server.Table(
                'dist2',
                sampler=item_selectors.Prioritized(priority_exponent=1),
                remover=item_selectors.Fifo(),
                max_size=1000000,
                rate_limiter=rate_limiters.MinSize(1)),
            server.Table(
                'signatured',
                sampler=item_selectors.Prioritized(priority_exponent=1),
                remover=item_selectors.Fifo(),
                max_size=1000000,
                rate_limiter=rate_limiters.MinSize(1),
                signature=tf.TensorSpec(dtype=tf.float32, shape=(None, None))),
        ],
        port=None,
    )
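The server built above can be exercised with a Reverb client. Below is a minimal sketch, assuming the public `reverb.Client` API and that `make_server()` is importable; the inserted value and priority are illustrative:

import reverb

my_server = make_server()
my_client = reverb.Client(f'localhost:{my_server.port}')

# Insert a single value into the 'dist' table with priority 1.0.
my_client.insert([1.0], priorities={'dist': 1.0})

# Sample one item back; `sample` yields lists of ReplaySample objects.
for sample in my_client.sample('dist', num_samples=1):
    print(sample[0].data)

my_server.stop()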
Example #2
  def test_max_samples(self, num_workers_per_iterator,
                       max_in_flight_samples_per_worker, max_samples):
    s = reverb_server.Server([reverb_server.Table.queue('q', 10)])
    c = s.localhost_client()

    for i in range(10):
      c.insert(i, {'q': 1})

    ds = timestep_dataset.TimestepDataset(
        server_address=c.server_address,
        table='q',
        dtypes=tf.int64,
        shapes=tf.TensorShape([]),
        max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
        num_workers_per_iterator=num_workers_per_iterator,
        max_samples=max_samples)

    iterator = ds.make_one_shot_iterator()
    item = iterator.get_next()
    # Check that it fetches exactly `max_samples` samples.
    with self.session() as sess:
      for _ in range(max_samples):
        sess.run(item)
      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run(item)

    # Check that no prefetching happened and that there are `10 - max_samples`
    # items left in the queue.
    self.assertEqual(c.server_info()['q'].current_size, 10 - max_samples)
    np.testing.assert_array_equal(
        next(c.sample('q', 1))[0].data[0], np.asarray(max_samples))
Example #3
    def test_max_samples(self, num_workers_per_iterator,
                         max_in_flight_samples_per_worker, max_samples):
        s = reverb_server.Server([reverb_server.Table.queue('q', 10)])
        c = s.localhost_client()

        for i in range(10):
            c.insert(i, {'q': 1})

        ds = timestep_dataset.TimestepDataset(
            server_address=c.server_address,
            table='q',
            dtypes=tf.int64,
            shapes=tf.TensorShape([]),
            max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
            num_workers_per_iterator=num_workers_per_iterator,
            max_samples=max_samples)

        # Check that it fetches exactly `max_samples` samples.
        it = ds.as_numpy_iterator()
        self.assertLen(list(it), max_samples)

        # Check that no prefetching happened in the queue.
        self.assertEqual(c.server_info()['q'].current_size, 10 - max_samples)
        np.testing.assert_array_equal(
            next(c.sample('q', 1))[0].data[0], np.asarray(max_samples))
Example #4
def make_server():
    return reverb_server.Server(
        tables=[
            reverb_server.Table(
                'dist',
                sampler=item_selectors.Prioritized(priority_exponent=1),
                remover=item_selectors.Fifo(),
                max_size=1000000,
                rate_limiter=rate_limiters.MinSize(1)),
            reverb_server.Table(
                'signatured',
                sampler=item_selectors.Prioritized(priority_exponent=1),
                remover=item_selectors.Fifo(),
                max_size=1000000,
                rate_limiter=rate_limiters.MinSize(1),
                signature=tf.TensorSpec(dtype=tf.float32, shape=(None, None))),
            reverb_server.Table(
                'bounded_spec_signatured',
                sampler=item_selectors.Prioritized(priority_exponent=1),
                remover=item_selectors.Fifo(),
                max_size=1000000,
                rate_limiter=rate_limiters.MinSize(1),
                # Currently only the `shape` and `dtype` of the bounded spec
                # are considered during the signature check.
                # TODO(b/158033101): Check the boundaries as well.
                signature=tensor_spec.BoundedTensorSpec(dtype=tf.float32,
                                                        shape=(None, None),
                                                        minimum=(0.0, 0.0),
                                                        maximum=(10.0, 10.)),
            ),
        ],
        port=None,
    )
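Items written to the 'signatured' table above must match its `tf.TensorSpec(dtype=tf.float32, shape=(None, None))` signature. A minimal sketch of a matching write, assuming the trajectory writer API used in the later examples (the per-step shape `[3]` and the priority are illustrative):

import numpy as np

my_server = make_server()
my_client = my_server.localhost_client()

with my_client.trajectory_writer(num_keep_alive_refs=5) as writer:
    for i in range(5):
        # Each step is a float32 vector of length 3.
        writer.append([np.ones([3], np.float32) * i])
    # The stacked history has shape [5, 3], which satisfies the
    # (None, None) float32 signature of the 'signatured' table.
    writer.create_item('signatured', priority=1.0,
                       trajectory=writer.history[0][:])

my_server.stop()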
Example #5
    def test_combines_sequence_length_with_signature_if_not_emit_timestamps(
            self):
        server = reverb_server.Server([
            reverb_server.Table.queue(
                'queue',
                10,
                signature={
                    'a': {
                        'b': tf.TensorSpec([3, 3], tf.float32),
                        'c': tf.TensorSpec([], tf.int64),
                    },
                })
        ])

        dataset = reverb_dataset.ReplayDataset.from_table_signature(
            f'localhost:{server.port}',
            'queue',
            100,
            emit_timesteps=False,
            sequence_length=5)
        self.assertDictEqual(
            dataset.element_spec.data, {
                'a': {
                    'b': tf.TensorSpec([5, 3, 3], tf.float32),
                    'c': tf.TensorSpec([5], tf.int64),
                },
            })
Example #6
def make_server():
    return reverb_server.Server(tables=[
        reverb_server.Table(name=TABLE,
                            sampler=item_selectors.Prioritized(
                                priority_exponent=1),
                            remover=item_selectors.Fifo(),
                            max_size=1000,
                            rate_limiter=rate_limiters.MinSize(1)),
    ])
Example #7
 def test_in_process_client(self):
     my_server = server.Server(tables=[
         server.Table(name=TABLE_NAME,
                      sampler=item_selectors.Prioritized(1),
                      remover=item_selectors.Fifo(),
                      max_size=100,
                      rate_limiter=rate_limiters.MinSize(2)),
     ])
     my_client = my_server.localhost_client()
     my_client.reset(TABLE_NAME)
     del my_client
     my_server.stop()
Example #8
 @classmethod
 def setUpClass(cls):
     super().setUpClass()
     cls.server = server.Server(tables=[
         server.Table(
             name=TABLE_NAME,
             sampler=item_selectors.Prioritized(1),
             remover=item_selectors.Fifo(),
             max_size=1000,
             rate_limiter=rate_limiters.MinSize(3),
             signature=tf.TensorSpec(dtype=tf.int64, shape=()),
         ),
     ],
                                port=None)
     cls.client = client.Client(f'localhost:{cls.server.port}')
Example #9
def make_tables_and_server():
    tables = [
        server.Table('dist',
                     sampler=item_selectors.Prioritized(priority_exponent=1),
                     remover=item_selectors.Fifo(),
                     max_size=1000000,
                     rate_limiter=rate_limiters.MinSize(1)),
        server.Table('dist2',
                     sampler=item_selectors.Prioritized(priority_exponent=1),
                     remover=item_selectors.Fifo(),
                     max_size=1000000,
                     rate_limiter=rate_limiters.MinSize(1)),
    ]
    return tables, server.Server(tables=tables)
Example #10
 def test_duplicate_priority_table_name(self):
     with self.assertRaises(ValueError):
         server.Server(tables=[
             server.Table(name='test',
                          sampler=item_selectors.Prioritized(1),
                          remover=item_selectors.Fifo(),
                          max_size=100,
                          rate_limiter=rate_limiters.MinSize(2)),
             server.Table(name='test',
                          sampler=item_selectors.Prioritized(2),
                          remover=item_selectors.Fifo(),
                          max_size=200,
                          rate_limiter=rate_limiters.MinSize(1))
         ])
Example #11
    def test_table_not_found(self):
        server = reverb_server.Server([
            reverb_server.Table.queue('table_a', 10),
            reverb_server.Table.queue('table_c', 10),
            reverb_server.Table.queue('table_b', 10),
        ])
        address = f'localhost:{server.port}'

        with self.assertRaisesWithPredicateMatch(
                ValueError,
                f'Server at {address} does not contain any table named not_found. '
                f'Found: table_a, table_b, table_c.'):
            reverb_dataset.ReplayDataset.from_table_signature(
                address, 'not_found', 100)
Example #12
    def test_sets_dtypes_from_signature(self):
        signature = {
            'a': {
                'b': tf.TensorSpec([3, 3], tf.float32),
                'c': tf.TensorSpec([], tf.int64),
            },
            'x': tf.TensorSpec([None], tf.uint64),
        }

        server = reverb_server.Server(
            [reverb_server.Table.queue('queue', 10, signature=signature)])

        dataset = reverb_dataset.ReplayDataset.from_table_signature(
            f'localhost:{server.port}', 'queue', 100)
        self.assertDictEqual(dataset.element_spec.data, signature)
Example #13
  def test_episode_steps_reset_on_end_episode(self, clear_buffers: bool):
    server = server_lib.Server([server_lib.Table.queue('queue', 1)])
    client = client_lib.Client(f'localhost:{server.port}')

    # Create a writer and check that the counter starts at 0.
    writer = client.trajectory_writer(num_keep_alive_refs=1)
    self.assertEqual(writer.episode_steps, 0)

    # Append a step and check that the counter is incremented.
    writer.append([1])
    self.assertEqual(writer.episode_steps, 1)

    # End the episode and check the counter is reset.
    writer.end_episode(clear_buffers=clear_buffers)
    self.assertEqual(writer.episode_steps, 0)
Example #14
  def test_episode_steps(self):
    server = server_lib.Server([server_lib.Table.queue('queue', 1)])
    client = client_lib.Client(f'localhost:{server.port}')
    writer = client.trajectory_writer(num_keep_alive_refs=1)

    for _ in range(10):
      # Every episode, including the first, should start at zero.
      self.assertEqual(writer.episode_steps, 0)

      for i in range(1, 21):
        writer.append({'x': 3, 'y': 2})

        # Step count should increment with each append call.
        self.assertEqual(writer.episode_steps, i)

      # Ending the episode should reset the step count to zero.
      writer.end_episode()
Example #15
  def test_timeout_on_end_episode(self):
    server = server_lib.Server([server_lib.Table.queue('queue', 1)])
    client = client_lib.Client(f'localhost:{server.port}')

    writer = client.trajectory_writer(num_keep_alive_refs=1)
    writer.append([1])

    # The table has space for one item, and up to 2 more items can be queued
    # in the table worker queues. Since there isn't space for all 4 items,
    # end_episode should time out.
    with self.assertRaises(errors.DeadlineExceededError):
      for _ in range(4):
        writer.create_item('queue', 1.0, writer.history[0][:])
        writer.end_episode(clear_buffers=False, timeout_ms=1)

    writer.close()
    server.stop()
Example #16
 def test_can_sample(self):
     table = server.Table(name=TABLE_NAME,
                          sampler=item_selectors.Prioritized(1),
                          remover=item_selectors.Fifo(),
                          max_size=100,
                          max_times_sampled=1,
                          rate_limiter=rate_limiters.MinSize(2))
     my_server = server.Server(tables=[table], port=None)
     my_client = my_server.in_process_client()
     self.assertFalse(table.can_sample(1))
     self.assertTrue(table.can_insert(1))
     my_client.insert(1, {TABLE_NAME: 1.0})
     self.assertFalse(table.can_sample(1))
     my_client.insert(1, {TABLE_NAME: 1.0})
     self.assertTrue(table.can_sample(2))
     # TODO(b/153258711): This should return False since max_times_sampled=1.
     self.assertTrue(table.can_sample(3))
     del my_client
     my_server.stop()
Example #17
 @classmethod
 def setUpClass(cls):
   super().setUpClass()
   cls.tables = [
       server.Table(
           name=TABLE_NAME,
           sampler=item_selectors.Prioritized(1),
           remover=item_selectors.Fifo(),
           max_size=1000,
           rate_limiter=rate_limiters.MinSize(3),
           signature=tf.TensorSpec(dtype=tf.int64, shape=[]),
       ),
       server.Table.queue(
           name=NESTED_SIGNATURE_TABLE_NAME,
           max_size=10,
           signature=QUEUE_SIGNATURE,
       ),
       server.Table.queue(SIMPLE_QUEUE_NAME, 10),
   ]
   cls.server = server.Server(tables=cls.tables)
   cls.client = cls.server.localhost_client()
Example #18
  def test_sets_dtypes_from_bounded_spec_signature(self):
    bounded_spec_signature = {
        'a': {
            'b': tensor_spec.BoundedTensorSpec([3, 3], tf.float32, 0, 3),
            'c': tensor_spec.BoundedTensorSpec([], tf.int64, 0, 5),
        },
    }

    server = reverb_server.Server([
        reverb_server.Table.queue(
            'queue', 10, signature=bounded_spec_signature)
    ])

    dataset = reverb_dataset.ReplayDataset.from_table_signature(
        f'localhost:{server.port}', 'queue', 100)
    self.assertDictEqual(
        dataset.element_spec.data, {
            'a': {
                'b': tf.TensorSpec([3, 3], tf.float32),
                'c': tf.TensorSpec([], tf.int64),
            },
        })
Example #19
  def test_episode_steps_partial_step(self):
    server = server_lib.Server([server_lib.Table.queue('queue', 1)])
    client = client_lib.Client(f'localhost:{server.port}')
    writer = client.trajectory_writer(num_keep_alive_refs=1)

    for _ in range(3):
      # Every episode, including the first, should start at zero.
      self.assertEqual(writer.episode_steps, 0)

      for i in range(1, 4):
        writer.append({'x': 3}, partial_step=True)

        # Step count should not increment on partial append calls.
        self.assertEqual(writer.episode_steps, i - 1)

        writer.append({'y': 2})

        # Step count should increment after the unqualified append call.
        self.assertEqual(writer.episode_steps, i)

      # Ending the episode should reset the step count to zero.
      writer.end_episode()
Example #20
    def test_max_samples(self, num_workers_per_iterator,
                         max_in_flight_samples_per_worker, max_samples):

        server = reverb_server.Server(
            tables=[reverb_server.Table.queue(_TABLE, 10)])
        client = server.localhost_client()

        with client.trajectory_writer(10) as writer:
            for i in range(10):
                writer.append([np.ones([3, 3], np.int32) * i])
                writer.create_item(_TABLE, 1.0, {
                    'last': writer.history[0][-1],
                    'all': writer.history[0][:],
                })

        dataset = trajectory_dataset.TrajectoryDataset(
            tf.constant(client.server_address),
            table=tf.constant(_TABLE),
            dtypes={
                'last': tf.int32,
                'all': tf.int32,
            },
            shapes={
                'last': tf.TensorShape([3, 3]),
                'all': tf.TensorShape([None, 3, 3]),
            },
            num_workers_per_iterator=num_workers_per_iterator,
            max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
            max_samples=max_samples)

        # Check that it fetches exactly `max_samples` samples.
        it = dataset.as_numpy_iterator()
        self.assertLen(list(it), max_samples)
        # Check that no prefetching happened on the queue.
        self.assertEqual(client.server_info()[_TABLE].current_size,
                         10 - max_samples)
Example #21
 def test_can_sample(self):
     table = server.Table(name=TABLE_NAME,
                          sampler=item_selectors.Prioritized(1),
                          remover=item_selectors.Fifo(),
                          max_size=100,
                          max_times_sampled=1,
                          rate_limiter=rate_limiters.MinSize(2))
     my_server = server.Server(tables=[table])
     my_client = my_server.localhost_client()
     self.assertFalse(table.can_sample(1))
     self.assertTrue(table.can_insert(1))
     my_client.insert(1, {TABLE_NAME: 1.0})
     self.assertFalse(table.can_sample(1))
     my_client.insert(1, {TABLE_NAME: 1.0})
     for _ in range(100):
         if table.info.current_size == 2:
             break
         time.sleep(0.01)
     self.assertEqual(table.info.current_size, 2)
     self.assertTrue(table.can_sample(2))
     # TODO(b/153258711): This should return False since max_times_sampled=1.
     self.assertTrue(table.can_sample(3))
     del my_client
     my_server.stop()
Example #22
 def test_no_priority_table_provided(self):
     with self.assertRaises(ValueError):
         server.Server(tables=[], port=None)
Example #23
 @classmethod
 def setUpClass(cls):
     super().setUpClass()
     cls._server = server_lib.Server([server_lib.Table.queue('queue', 100)])
Example #24
 @classmethod
 def setUpClass(cls):
     super().setUpClass()
     cls._server = server_lib.Server(
         [server_lib.Table.queue(table, 100) for table in TABLES])
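The setUpClass helpers above are typically paired with a tearDownClass that shuts the server down; a minimal sketch, assuming the same `server_lib` alias and the `cls._server` attribute set above:

 @classmethod
 def tearDownClass(cls):
     cls._server.stop()
     super().tearDownClass()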