예제 #1
0
  def test_iterate_with_unknown_sequence_length(self, table_name,
                                                sequence_length):
    """Checks sequence sampling when `sequence_length` is left unspecified.

    The dataset is built with `sequence_length=None` and a requested shape
    whose leading dimension is `None`. The static shapes of the iterator
    output must then have an unknown leading dimension, while evaluated
    samples must have `sequence_length` entries.

    Args:
      table_name: Name of the table to sample from.
      sequence_length: Sequence length forwarded to `_populate_replay` and
        expected as the leading dimension of every evaluated sample.
    """
    self._populate_replay(sequence_length)

    client = tf_client.TFClient(self._client.server_address)
    dataset = client.dataset(
        table=table_name,
        dtypes=(tf.float32,),
        shapes=(tf.TensorShape([None, 3, 3]),),
        emit_timesteps=False,
        sequence_length=None)

    # Check the static shape of the items: the leading dimension is unknown.
    # (`assertIsNone` takes the value and an optional failure message only.)
    iterator = dataset.make_initializable_iterator()
    dataset_item = iterator.get_next()
    self.assertIsNone(dataset_item.info.key.shape.as_list()[0])
    self.assertIsNone(dataset_item.data[0].shape.as_list()[0])

    # Verify that once evaluated, the samples have the expected length.
    got = self._sample_from(dataset, 10)
    for sample in got:
      self.assertIsInstance(sample, replay_sample.ReplaySample)

      # The keys and data should be batched up by the sequence length.
      self.assertEqual(sample.info.key.shape, (sequence_length,))
      np.testing.assert_array_equal(
          sample.data[0], np.zeros((sequence_length, 3, 3), dtype=np.float32))
예제 #2
0
  def test_priority_update_is_applied(self):
    """Checks that `update_priorities` changes the sampling probability."""
    # Begin with four items that all share the same priority.
    for value in range(4):
      self._client.insert([np.array([value], dtype=np.uint32)], {'dist': 1})

    # Keep sampling until all 4 items have been received.
    probability_by_key = {}
    while len(probability_by_key) < 4:
      sampled = next(self._client.sample('dist'))[0]
      probability_by_key[sampled.info.key] = sampled.info.probability
      self.assertEqual(sampled.info.probability, 0.25)

    # Raise the priority of one (arbitrary) item through the TF op.
    update_key = next(iter(probability_by_key))
    with self.session() as session:
      client = tf_client.TFClient(self._client.server_address)
      table = tf.constant('dist')
      keys = tf.constant([update_key], dtype=tf.uint64)
      priorities = tf.constant([3], dtype=tf.float64)
      update_op = client.update_priorities(
          table=table, keys=keys, priorities=priorities)
      self.assertIsNone(session.run(update_op))

    # One item now has priority 3 and the remaining three have priority 1, so
    # the updated item should be sampled with probability 3 / 6 = 50%. Sample
    # until it shows up and verify its reported probability.
    for _ in range(1000):
      sampled = next(self._client.sample('dist'))[0]
      if sampled.info.key != update_key:
        continue
      self.assertEqual(sampled.info.probability, 0.5)
      return
    self.fail('Updated item was not found')
예제 #3
0
  def test_multiple_iterators(self):
    """Checks that several iterators over one dataset can be run together."""
    # Write ten timesteps, the i-th one filled with the value i, as one item.
    with self._client.writer(100) as writer:
      for step in range(10):
        writer.append([np.ones((81, 81), dtype=np.float32) * step])
      writer.create_item(table='dist', num_timesteps=10, priority=1)

    trajectory_length = 5
    batch_size = 3

    client = tf_client.TFClient(self._client.server_address)
    dataset = client.dataset(
        table='dist', dtypes=(tf.float32,), shapes=(tf.TensorShape([81, 81]),))
    dataset = dataset.batch(trajectory_length)

    # One independent iterator per batch entry.
    iterators = []
    for _ in range(batch_size):
      iterators.append(dataset.make_initializable_iterator())
    squeezed = [tf.squeeze(it.get_next().data) for it in iterators]
    items = tf.stack(squeezed)

    with self.session() as session:
      session.run([it.initializer for it in iterators])
      got = session.run(items)
      self.assertEqual(got.shape, (batch_size, trajectory_length, 81, 81))

      # Every iterator is expected to yield the first `trajectory_length`
      # timesteps in order.
      trajectory = [np.ones([81, 81]) * step
                    for step in range(trajectory_length)]
      want = np.array([trajectory] * batch_size)
      np.testing.assert_array_equal(got, want)
예제 #4
0
 def test_dtype_mismatch_result_in_error_raised(self):
   """Checks that sampling with a mismatching dtype raises `InternalError`."""
   # `np.zeros` produces float64 by default; float32 is requested below.
   self._client.insert([np.zeros((81, 81))], {'dist': 1})
   with self.session() as session:
     client = tf_client.TFClient(self._client.server_address)
     with self.assertRaises(tf.errors.InternalError):
       sample_op = client.sample('dist', [tf.float32])
       session.run(sample_op)
예제 #5
0
 def test_shape_result_in_error_raised(self):
   """Checks that mismatching key/priority lengths fail when the op runs."""
   with self.session() as session:
     client = tf_client.TFClient(self._client.server_address)
     # Two keys but only a single priority.
     table = tf.constant('dist')
     keys = tf.constant([1, 2], dtype=tf.uint64)
     priorities = tf.constant([1], dtype=tf.float64)
     update_op = client.update_priorities(table, keys, priorities)
     with self.assertRaises(tf.errors.InvalidArgumentError):
       session.run(update_op)
예제 #6
0
    def test_timeout_invalid_arguments(self):
        """Checks that `rate_limiter_timeout_ms=-2` is rejected."""
        client = tf_client.TFClient(self._client.server_address)
        dataset_kwargs = dict(
            table='dist',
            dtypes=(tf.float32, ),
            shapes=(tf.TensorShape([3, 3]), ),
            rate_limiter_timeout_ms=-2,
        )

        with self.assertRaisesRegex(ValueError, r'must be an integer >= -1'):
            client.dataset(**dataset_kwargs)
예제 #7
0
    def test_iterate_nested_and_batched(self):
        """Checks sampling of nested data through a batched dataset.

        Timesteps holding a nested dict are appended and items covering 100
        timesteps are created. The dataset output is batched in groups of 3,
        re-nested with `tree.unflatten_as` and compared against the written
        values with a leading batch dimension of 3.
        """
        with self._client.writer(100) as writer:
            for i in range(1000):
                writer.append({
                    'observation': {
                        'data': np.zeros((3, 3), dtype=np.float32),
                        'extras': [
                            np.int64(10),
                            np.ones([1], dtype=np.int32),
                        ],
                    },
                    'reward': np.zeros((10, 10), dtype=np.float32),
                })
                # Create an item on every fifth step once 100 steps exist.
                if i % 5 == 0 and i >= 100:
                    writer.create_item(table='dist',
                                       num_timesteps=100,
                                       priority=1)

        client = tf_client.TFClient(self._client.server_address)
        dataset = client.dataset(
            table='dist',
            dtypes=((tf.float32, (tf.int64, tf.int32)), tf.float32),
            shapes=((tf.TensorShape([3, 3]), (tf.TensorShape(None),
                                              tf.TensorShape([1]))),
                    tf.TensorShape([10, 10])),
        )
        dataset = dataset.batch(3)

        # Only the nesting of `structure` is consumed (by
        # `tree.unflatten_as`); the leaf specs mirror the per-timestep data
        # appended above so that the template is not misleading.
        structure = {
            'observation': {
                'data': tf.TensorSpec([3, 3], tf.float32),
                'extras': [
                    tf.TensorSpec([], tf.int64),
                    tf.TensorSpec([1], tf.int32),
                ],
            },
            'reward': tf.TensorSpec([10, 10], tf.float32),
        }

        got = self._sample_from(dataset, 10)
        self.assertLen(got, 10)
        for sample in got:
            self.assertIsInstance(sample, replay_sample.ReplaySample)

            transition = tree.unflatten_as(structure,
                                           tree.flatten(sample.data))
            np.testing.assert_array_equal(
                transition['observation']['data'],
                np.zeros([3, 3, 3], dtype=np.float32))
            np.testing.assert_array_equal(
                transition['observation']['extras'][0],
                np.ones([3], dtype=np.int64) * 10)
            np.testing.assert_array_equal(
                transition['observation']['extras'][1],
                np.ones([3, 1], dtype=np.int32))
            np.testing.assert_array_equal(
                transition['reward'], np.zeros([3, 10, 10], dtype=np.float32))
예제 #8
0
 def test_incompatible_shape_when_using_sequence_length(self, sequence_length):
   """Checks that a shape disagreeing with `sequence_length` is rejected.

   Args:
     sequence_length: Sequence length passed to the dataset; the requested
       shape uses `sequence_length + 1` as its leading dimension.
   """
   client = tf_client.TFClient(self._client.server_address)
   with self.assertRaises(ValueError):
     # The leading dimension is deliberately off by one.
     mismatched_shape = tf.TensorShape([sequence_length + 1, 3, 3])
     client.dataset(
         table='dist',
         dtypes=(tf.float32,),
         shapes=(mismatched_shape,),
         emit_timesteps=False,
         sequence_length=sequence_length)
예제 #9
0
 def test_retries_until_success_or_fatal_error(self):
   """Checks that a sample submitted before the insert still returns the data."""
   with self.session() as session:
     client = tf_client.TFClient(self._client.server_address)
     with futures.ThreadPoolExecutor(max_workers=1) as executor:
       # Start sampling on a worker thread, then insert from this thread.
       sample_op = client.sample('dist', [tf.float64])
       pending = executor.submit(session.run, sample_op)
       inserted = [np.zeros((81, 81))]
       self._client.insert(inserted, {'dist': 1})
       np.testing.assert_equal(inserted, pending.result().data)
예제 #10
0
 def test_sets_meta_data_fields(self):
   """Checks the `info` fields of a sample drawn from a one-item table."""
   inserted = [np.ones((81, 81), dtype=np.float64)]
   self._client.insert(inserted, {'dist': 1})
   with self.session() as session:
     client = tf_client.TFClient(self._client.server_address)
     sample_op = client.sample('dist', [tf.float64])
     sample = session.run(sample_op)
     np.testing.assert_equal(inserted, sample.data)
     info = sample.info
     self.assertNotEqual(info.key, 0)
     self.assertEqual(info.probability, 1)
     self.assertEqual(info.table_size, 1)
예제 #11
0
  def test_sampler_parameter_validation(self, **kwargs):
    """Checks dataset construction for a given set of keyword arguments.

    Args:
      **kwargs: Keyword arguments forwarded to `client.dataset`. If the key
        `want_error` is present it is removed and its value is used as the
        exception type the call is expected to raise.
    """
    client = tf_client.TFClient(self._client.server_address)
    dtypes = (tf.float32,)
    shapes = (tf.TensorShape([3, 3]),)

    if 'want_error' not in kwargs:
      # No error expected: the call simply has to succeed.
      client.dataset('dist', dtypes, shapes, **kwargs)
      return

    expected_error = kwargs.pop('want_error')
    with self.assertRaises(expected_error):
      client.dataset('dist', dtypes, shapes, **kwargs)
예제 #12
0
  def test_incompatible_signature_shape(self, table_name):
    """Checks the error raised when the requested shape mismatches the table.

    Args:
      table_name: Name of the table to sample from; it is also interpolated
        into the expected error message.
    """
    self._populate_replay()

    client = tf_client.TFClient(self._client.server_address)
    dataset = client.dataset(
        table=table_name, dtypes=(tf.float32,), shapes=(tf.TensorShape([3]),))

    expected_message = (
        r'Requested incompatible tensor at flattened index 3 from table '
        r'\'{}\'.  Requested \(dtype, shape\): \(float, \[3\]\).  '
        r'Signature \(dtype, shape\): \(float, \[\?,\?\]\)'.format(table_name))
    with self.assertRaisesWithPredicateMatch(tf.errors.InvalidArgumentError,
                                             expected_message):
      self._sample_from(dataset, 10)
예제 #13
0
  def test_iterate(self):
    """Checks the type and content of unbatched samples from the dataset."""
    self._populate_replay()

    client = tf_client.TFClient(self._client.server_address)
    dataset = client.dataset(
        table='dist', dtypes=(tf.float32,), shapes=(tf.TensorShape([3, 3]),))

    expected_data = np.zeros((3, 3), dtype=np.float32)
    for sample in self._sample_from(dataset, 10):
      self.assertIsInstance(sample, replay_sample.ReplaySample)
      # A single sample is returned so the key should be a scalar uint64.
      self.assertIsInstance(sample.info.key, np.uint64)
      np.testing.assert_array_equal(sample.data[0], expected_data)
예제 #14
0
    def test_inconsistent_signature_size(self):
        """Checks the error raised when too many tensors are requested."""
        self._populate_replay()

        client = tf_client.TFClient(self._client.server_address)
        requested_dtypes = (tf.float32, tf.float64)
        requested_shapes = (tf.TensorShape([3, 3]), tf.TensorShape([]))
        dataset = client.dataset(table='signatured',
                                 dtypes=requested_dtypes,
                                 shapes=requested_shapes)

        expected_message = (
            r'Inconsistent number of tensors requested from table \'signatured\'.  '
            r'Requested 5 tensors, but table signature shows 4 tensors.')
        with self.assertRaisesWithPredicateMatch(
                tf.errors.InvalidArgumentError, expected_message):
            self._sample_from(dataset, 10)
예제 #15
0
  def test_checks_that_table_has_rank_1(self):
    """Checks that `insert` only accepts a rank 1 `tables` tensor."""
    client = tf_client.TFClient(self._client.server_address)
    priorities = tf.constant([1.0], dtype=tf.float64)

    # Rank 1 is accepted.
    rank_1_tables = tf.constant(['dist'])
    client.insert(self.data, rank_1_tables, priorities)

    # Rank 2 is rejected.
    with self.assertRaises(ValueError):
      rank_2_tables = tf.constant([['dist']])
      client.insert(self.data, rank_2_tables, priorities)

    # Rank 0 (a scalar) is rejected.
    with self.assertRaises(ValueError):
      rank_0_tables = tf.constant('dist')
      client.insert(self.data, rank_0_tables, priorities)
예제 #16
0
  def test_checks_sequence_length_when_timesteps_emitted(
      self, table_name, actual_sequence_length, provided_sequence_length):
    """Checks that sampling fails for the given sequence length combination.

    Args:
      table_name: Name of the table to sample from.
      actual_sequence_length: Sequence length forwarded to `_populate_replay`.
      provided_sequence_length: Sequence length given to the dataset, also
        used as the leading dimension of the requested shape.
    """
    self._populate_replay(actual_sequence_length)

    client = tf_client.TFClient(self._client.server_address)
    requested_shape = tf.TensorShape([provided_sequence_length, 3, 3])
    dataset = client.dataset(
        table=table_name,
        dtypes=(tf.float32,),
        shapes=(requested_shape,),
        emit_timesteps=True,
        sequence_length=provided_sequence_length)

    # The dataset can be constructed; the failure surfaces when sampling.
    with self.assertRaises(tf.errors.InvalidArgumentError):
      self._sample_from(dataset, 10)
예제 #17
0
  def test_checks_that_priorities_argument_has_rank_1(self):
    """Checks that `insert` only accepts a rank 1 `priorities` tensor."""
    client = tf_client.TFClient(self._client.server_address)
    data = [tf.constant([1, 2])]
    tables = tf.constant(['dist'])

    # Rank 1 is accepted.
    rank_1_priorities = tf.constant([1.0], dtype=tf.float64)
    client.insert(data, tables, rank_1_priorities)

    # Rank 2 is rejected.
    with self.assertRaises(ValueError):
      rank_2_priorities = tf.constant([[1.0]], dtype=tf.float64)
      client.insert(data, tables, rank_2_priorities)

    # Rank 0 (a scalar) is rejected.
    with self.assertRaises(ValueError):
      rank_0_priorities = tf.constant(1.0, dtype=tf.float64)
      client.insert(data, tables, rank_0_priorities)
예제 #18
0
  def test_iterate_over_blobs(self):
    """Checks samples of data inserted directly through the client."""
    for _ in range(10):
      self._client.insert(np.ones([3, 3], dtype=np.int32), {'dist': 1})

    client = tf_client.TFClient(self._client.server_address)
    dataset = client.dataset(
        table='dist', dtypes=(tf.int32,), shapes=(tf.TensorShape([3, 3]),))

    # Twice as many samples are requested as items were inserted.
    num_samples = 20
    got = self._sample_from(dataset, num_samples)
    self.assertLen(got, num_samples)

    expected_data = np.ones((3, 3), dtype=np.int32)
    for sample in got:
      self.assertIsInstance(sample, replay_sample.ReplaySample)
      info = sample.info
      self.assertIsInstance(info.key, np.uint64)
      self.assertIsInstance(info.probability, np.float64)
      np.testing.assert_array_equal(sample.data[0], expected_data)
예제 #19
0
    def test_timeout(self):
        """Checks that `rate_limiter_timeout_ms` bounds how long sampling waits.

        Datasets with timeouts of 0, 1000 and 2000 ms are sampled before
        `_populate_replay` is called. Each attempt must end with an
        `OutOfRangeError` after at least the configured timeout and within a
        generous upper bound. After `_populate_replay`, the 0 ms dataset must
        return samples.
        """
        client = tf_client.TFClient(self._client.server_address)

        # (timeout in ms, minimum wait in seconds, maximum wait in seconds).
        cases = ((0, 0, 5), (1000, 1, 10), (2000, 2, 10))

        # Build every dataset up front, before any sampling takes place.
        datasets = {
            timeout_ms: client.dataset(table='dist',
                                       dtypes=(tf.float32, ),
                                       shapes=(tf.TensorShape([3, 3]), ),
                                       rate_limiter_timeout_ms=timeout_ms)
            for timeout_ms, _, _ in cases
        }

        for timeout_ms, min_seconds, max_seconds in cases:
            failure_context = 'rate_limiter_timeout_ms={}'.format(timeout_ms)
            # A monotonic clock is used so that wall-clock adjustments during
            # the test cannot distort the measured duration.
            start_time = time.monotonic()
            with self.assertRaisesWithPredicateMatch(
                    tf.errors.OutOfRangeError, r'End of sequence'):
                self._sample_from(datasets[timeout_ms], 1)
            duration = time.monotonic() - start_time
            self.assertGreaterEqual(duration, min_seconds, failure_context)
            self.assertLess(duration, max_seconds, failure_context)

        # If we insert some data, and the rate limiter doesn't force any
        # waiting, then we can ask for a timeout of 0s and still get data back.
        self._populate_replay()
        got = self._sample_from(datasets[0], 2)
        self.assertLen(got, 2)
예제 #20
0
  def test_single_table_insert(self):
    """Checks that data inserted through the TF op can be sampled back."""
    with self.session() as session:
      client = tf_client.TFClient(self._client.server_address)
      data = [tf.constant([1, 2, 3], dtype=tf.int8)]
      tables = tf.constant(['dist'])
      priorities = tf.constant([1.0], dtype=tf.float64)
      insert_op = client.insert(
          data=data, tables=tables, priorities=priorities)
      sample_op = client.sample('dist', [tf.int8])

      # Running the insert op yields no value.
      insert_result = session.run(insert_op)
      self.assertIsNone(insert_result)

      # The sampled data must match what was inserted.
      sample = session.run(sample_op)
      self.assertLen(sample.data, 1)
      expected = np.array([1, 2, 3], dtype=np.int8)
      np.testing.assert_equal(expected, sample.data[0])
예제 #21
0
    def test_iterate_batched(self, table_name):
        """Checks that keys and data are batched together by `batch`.

        Args:
            table_name: Name of the table to sample from.
        """
        self._populate_replay()

        client = tf_client.TFClient(self._client.server_address)
        dataset = client.dataset(table=table_name,
                                 dtypes=(tf.float32, ),
                                 shapes=(tf.TensorShape([3, 3]), ))
        batch_size = 2
        dataset = dataset.batch(batch_size, drop_remainder=True)

        expected_data = np.zeros((batch_size, 3, 3), dtype=np.float32)
        for sample in self._sample_from(dataset, 10):
            self.assertIsInstance(sample, replay_sample.ReplaySample)

            # The keys carry the same leading batch dimension as the data.
            self.assertEqual(sample.info.key.shape, (batch_size, ))

            np.testing.assert_array_equal(sample.data[0], expected_data)
예제 #22
0
  def test_converts_spec_lists_into_tuples(self):
    """Checks that sampled data is nested in tuples when specs use lists."""
    for _ in range(10):
      # The inserted data mixes tuples and lists in its nesting.
      inner = [
          np.ones([3, 3], dtype=np.int8),
          (np.ones([2, 2], dtype=np.float64),),
      ]
      data = [(np.ones([1, 1], dtype=np.int32),), inner]
      self._client.insert(data, {'dist': 1})

    client = tf_client.TFClient(self._client.server_address)
    # Both specs mirror the list/tuple nesting of the inserted data.
    dtypes = [
        (tf.int32,),
        [
            tf.int8,
            (tf.float64,),
        ],
    ]
    shapes = [
        (tf.TensorShape([1, 1]),),
        [
            tf.TensorShape([3, 3]),
            (tf.TensorShape([2, 2]),),
        ],
    ]
    dataset = client.dataset(table='dist', dtypes=dtypes, shapes=shapes)

    # Same nesting as the specs, but expressed with tuples only.
    expected_structure = (
        (None,),
        (
            None,
            (None,),
        ),
    )

    for sample in self._sample_from(dataset, 10):
      self.assertIsInstance(sample, replay_sample.ReplaySample)
      self.assertIsInstance(sample.info.key, np.uint64)
      tree.assert_same_structure(sample.data, expected_structure)
예제 #23
0
  def test_session_is_closed_while_op_pending(self):
    """Checks that closing the session cancels a pending `get_next` op.

    A background thread closes the session after 3 seconds while the main
    thread is inside `sess.run(item)`; the run is expected to end with a
    `CancelledError`. Nothing is inserted by this method, so the run
    presumably blocks until the session is closed.
    """
    client = tf_client.TFClient(self._client.server_address)
    dataset = client.dataset(
        table='dist', dtypes=tf.float32, shapes=tf.TensorShape([]))

    iterator = dataset.make_initializable_iterator()
    item = iterator.get_next()

    def _close_after(sess, wait_time_secs):
      """Sleeps for `wait_time_secs` seconds and then closes `sess`."""
      time.sleep(wait_time_secs)
      sess.close()

    with self.session() as sess:
      sess.run(iterator.initializer)
      thread = threading.Thread(target=_close_after, args=(sess, 3))
      thread.start()
      try:
        with self.assertRaises(tf.errors.CancelledError):
          sess.run(item)
      finally:
        # Always wait for the closer thread so that it cannot outlive the
        # test, even when the assertion above fails.
        thread.join()
예제 #24
0
  def test_iterate_with_sequence_length(self, table_name, sequence_length,
                                        max_time_steps):
    """Checks that whole sequences of `sequence_length` steps are sampled.

    Args:
      table_name: Name of the table to sample from.
      sequence_length: Sequence length used both to populate the replay and
        to configure the dataset.
      max_time_steps: Forwarded to `_populate_replay` as `max_time_steps`.
    """
    # Outputs must be sequence_length-shaped even when the writers'
    # max_time_steps differs from sequence_length.
    self._populate_replay(sequence_length, max_time_steps=max_time_steps)

    client = tf_client.TFClient(self._client.server_address)
    sequence_shape = tf.TensorShape([sequence_length, 3, 3])
    dataset = client.dataset(
        table=table_name,
        dtypes=(tf.float32,),
        shapes=(sequence_shape,),
        emit_timesteps=False,
        sequence_length=sequence_length)

    expected_data = np.zeros((sequence_length, 3, 3), dtype=np.float32)
    for sample in self._sample_from(dataset, 10):
      self.assertIsInstance(sample, replay_sample.ReplaySample)

      # Keys and data share the leading sequence dimension.
      self.assertEqual(sample.info.key.shape, (sequence_length,))
      np.testing.assert_array_equal(sample.data[0], expected_data)
예제 #25
0
 def test_forwards_server_error(self):
   """Checks that sampling from the table 'invalid' raises `NotFoundError`."""
   with self.session() as session:
     client = tf_client.TFClient(self._client.server_address)
     with self.assertRaises(tf.errors.NotFoundError):
       sample_op = client.sample('invalid', [tf.float64])
       session.run(sample_op)
예제 #26
0
 def test_checks_dtype_of_table_argument(self):
   """Checks that `insert` rejects an integer `tables` tensor."""
   client = tf_client.TFClient(self._client.server_address)
   with self.assertRaises(ValueError):
     # An integer tensor is passed where table names are expected.
     integer_tables = tf.constant([1])
     priorities = tf.constant([1.0], dtype=tf.float64)
     client.insert(self.data, integer_tables, priorities)
예제 #27
0
 def test_checks_that_tables_and_priorities_arguments_have_same_shape(self):
   """Checks that `insert` rejects two tables paired with one priority."""
   client = tf_client.TFClient(self._client.server_address)
   with self.assertRaises(ValueError):
     two_tables = tf.constant(['dist', 'dist2'])
     one_priority = tf.constant([1.0], dtype=tf.float64)
     client.insert(self.data, two_tables, one_priority)
예제 #28
0
 def test_checks_that_priorities_argument_has_dtype_float64(self):
   """Checks that `insert` rejects float32 priorities."""
   client = tf_client.TFClient(self._client.server_address)
   with self.assertRaises(ValueError):
     tables = tf.constant(['dist'])
     float32_priorities = tf.constant([1.0], dtype=tf.float32)
     client.insert(self.data, tables, float32_priorities)