Exemple #1
0
    def testShardedMutableHashTable(self):
        """Check size and lookup of a sharded table holding scalar int64 pairs.

        The same scenario is repeated for 1, 3 and 10 shards: the table
        starts empty, holds three entries after one insert, and a lookup of
        a key that was never inserted (14) is expected to give the default.
        """
        for shard_count in (1, 3, 10):
            with self._single_threaded_test_session():
                fallback_value = -1
                vacant_key = 0
                inserted_keys = tf.constant([11, 12, 13], tf.int64)
                inserted_values = tf.constant([0, 1, 2], tf.int64)
                table = _ShardedMutableDenseHashTable(
                    tf.int64,
                    tf.int64,
                    fallback_value,
                    vacant_key,
                    num_shards=shard_count,
                )
                # Nothing has been inserted yet.
                self.assertAllEqual(0, table.size().eval())

                insert_op = table.insert(inserted_keys, inserted_values)
                insert_op.run()
                self.assertAllEqual(3, table.size().eval())

                # Keys 11 and 12 were inserted above; 14 was not.
                query_keys = tf.constant([11, 12, 14], tf.int64)
                found = table.lookup(query_keys)
                self.assertAllEqual([3], found.get_shape())
                self.assertAllEqual([0, 1, -1], found.eval())
Exemple #2
0
  def testShardedMutableHashTable(self):
    """Check size and lookup of a sharded table holding scalar int64 pairs.

    The same scenario is repeated for 1, 3 and 10 shards: the table starts
    empty, holds three entries after one insert, and a lookup of a key that
    was never inserted (14) is expected to give the default value.
    """
    for shard_count in (1, 3, 10):
      with self._single_threaded_test_session():
        fallback_value = -1
        vacant_key = 0
        inserted_keys = tf.constant([11, 12, 13], tf.int64)
        inserted_values = tf.constant([0, 1, 2], tf.int64)
        table = _ShardedMutableDenseHashTable(
            tf.int64,
            tf.int64,
            fallback_value,
            vacant_key,
            num_shards=shard_count)
        # Nothing has been inserted yet.
        self.assertAllEqual(0, table.size().eval())

        insert_op = table.insert(inserted_keys, inserted_values)
        insert_op.run()
        self.assertAllEqual(3, table.size().eval())

        # Keys 11 and 12 were inserted above; 14 was not.
        query_keys = tf.constant([11, 12, 14], tf.int64)
        found = table.lookup(query_keys)
        self.assertAllEqual([3], found.get_shape())
        self.assertAllEqual([0, 1, -1], found.eval())
Exemple #3
0
  def testShardedMutableHashTableVectors(self):
    """Check size and lookup of a sharded table with 2-element vector entries.

    Keys are int64 pairs and values are float32 pairs. The scenario is
    repeated for 1, 3 and 10 shards; the key [11, 14] is never inserted, so
    its lookup is expected to give the default value pair.
    """
    for shard_count in (1, 3, 10):
      with self._single_threaded_test_session():
        fallback_value = [-0.1, 0.2]
        vacant_key = [0, 1]
        inserted_keys = tf.constant([[11, 12], [13, 14], [15, 16]], tf.int64)
        inserted_values = tf.constant(
            [[0.5, 0.6], [1.5, 1.6], [2.5, 2.6]], tf.float32)
        table = _ShardedMutableDenseHashTable(
            tf.int64,
            tf.float32,
            fallback_value,
            vacant_key,
            num_shards=shard_count)
        # Nothing has been inserted yet.
        self.assertAllEqual(0, table.size().eval())

        insert_op = table.insert(inserted_keys, inserted_values)
        insert_op.run()
        self.assertAllEqual(3, table.size().eval())

        # The first two query keys were inserted above; [11, 14] was not.
        query_keys = tf.constant([[11, 12], [13, 14], [11, 14]], tf.int64)
        found = table.lookup(query_keys)
        self.assertAllEqual([3, 2], found.get_shape())
        expected_rows = [[0.5, 0.6], [1.5, 1.6], [-0.1, 0.2]]
        self.assertAllClose(expected_rows, found.eval())
Exemple #4
0
    def testExportSharded(self):
        """Check the per-shard keys and values returned by export_sharded().

        Three int64 pairs are inserted into a two-shard table, then the
        exported key and value lists are compared, shard by shard, against
        the expected sets of distinct entries.
        """
        with self._single_threaded_test_session():
            vacant_key = -2
            fallback_value = -1
            shard_count = 2
            inserted_keys = tf.constant([10, 11, 12], tf.int64)
            inserted_values = tf.constant([2, 3, 4], tf.int64)
            table = _ShardedMutableDenseHashTable(
                tf.int64,
                tf.int64,
                fallback_value,
                vacant_key,
                num_shards=shard_count,
            )
            # Nothing has been inserted yet.
            self.assertAllEqual(0, table.size().eval())

            insert_op = table.insert(inserted_keys, inserted_values)
            insert_op.run()
            self.assertAllEqual(3, table.size().eval())

            exported = table.export_sharded()
            keys_list, values_list = exported
            self.assertAllEqual(shard_count, len(keys_list))
            self.assertAllEqual(shard_count, len(values_list))

            # Exported keys include empty key buckets set to the empty_key
            expected_key_sets = [{-2, 10, 12}, {-2, 11}]
            for shard, wanted in enumerate(expected_key_sets):
                actual = set(keys_list[shard].eval().flatten())
                self.assertAllEqual(wanted, actual)

            # Exported values include empty value buckets set to 0
            expected_value_sets = [{0, 2, 4}, {0, 3}]
            for shard, wanted in enumerate(expected_value_sets):
                actual = set(values_list[shard].eval().flatten())
                self.assertAllEqual(wanted, actual)
Exemple #5
0
  def testExportSharded(self):
    """Check the per-shard keys and values returned by export_sharded().

    Three int64 pairs are inserted into a two-shard table, then the exported
    key and value lists are compared, shard by shard, against the expected
    sets of distinct entries.
    """
    with self._single_threaded_test_session():
      vacant_key = -2
      fallback_value = -1
      shard_count = 2
      inserted_keys = tf.constant([10, 11, 12], tf.int64)
      inserted_values = tf.constant([2, 3, 4], tf.int64)
      table = _ShardedMutableDenseHashTable(
          tf.int64,
          tf.int64,
          fallback_value,
          vacant_key,
          num_shards=shard_count)
      # Nothing has been inserted yet.
      self.assertAllEqual(0, table.size().eval())

      insert_op = table.insert(inserted_keys, inserted_values)
      insert_op.run()
      self.assertAllEqual(3, table.size().eval())

      exported = table.export_sharded()
      keys_list, values_list = exported
      self.assertAllEqual(shard_count, len(keys_list))
      self.assertAllEqual(shard_count, len(values_list))

      # Exported keys include empty key buckets set to the empty_key
      expected_key_sets = [{-2, 10, 12}, {-2, 11}]
      for shard, wanted in enumerate(expected_key_sets):
        actual = set(keys_list[shard].eval().flatten())
        self.assertAllEqual(wanted, actual)

      # Exported values include empty value buckets set to 0
      expected_value_sets = [{0, 2, 4}, {0, 3}]
      for shard, wanted in enumerate(expected_value_sets):
        actual = set(values_list[shard].eval().flatten())
        self.assertAllEqual(wanted, actual)