예제 #1
0
    def testShardedMutableHashTable(self):
        """Insert and look up string->int64 pairs for several shard counts.

        For each shard count the table starts empty, reports size 3 after
        inserting three keys, and a lookup of two inserted keys plus one
        missing key returns their values and ``default_val`` respectively.
        """
        for num_shards in [1, 3, 10]:
            with self._single_threaded_test_session():
                default_val = -1
                keys = tf.constant(["brain", "salad", "surgery"])
                values = tf.constant([0, 1, 2], tf.int64)
                table = _ShardedMutableHashTable(tf.string, tf.int64, default_val, num_shards=num_shards)
                self.assertAllEqual(0, table.size().eval())

                table.insert(keys, values).run()
                self.assertAllEqual(3, table.size().eval())

                # "tank" was never inserted, so it must resolve to default_val.
                input_string = tf.constant(["brain", "salad", "tank"])
                output = table.lookup(input_string)
                # Compare the static shape as a plain list instead of relying on
                # numpy coercion of a TensorShape object.
                self.assertAllEqual([3], output.get_shape().as_list())

                result = output.eval()
                self.assertAllEqual([0, 1, default_val], result)
예제 #2
0
    def testExportSharded(self):
        """Export a two-shard table and check each shard's keys and values.

        The expected per-shard contents pin which shard each key lands in.
        NOTE(review): presumably that assignment follows from how
        _ShardedMutableHashTable hashes string keys; update these expectations
        if the sharding function changes.
        """
        with self._single_threaded_test_session():
            default_val = -1
            num_shards = 2
            keys = tf.constant(["a1", "b1", "c2"])
            values = tf.constant([0, 1, 2], tf.int64)
            table = _ShardedMutableHashTable(tf.string, tf.int64, default_val, num_shards=num_shards)
            self.assertAllEqual(0, table.size().eval())

            table.insert(keys, values).run()
            self.assertAllEqual(3, table.size().eval())

            # One key tensor and one value tensor are expected per shard.
            keys_list, values_list = table.export_sharded()
            self.assertEqual(num_shards, len(keys_list))
            self.assertEqual(num_shards, len(values_list))

            # Order within a shard is not relied on, so shard 0 is compared as
            # a set. assertEqual gives a real set diff; assertAllEqual targets
            # array-likes, and numpy does not treat a set as a sequence.
            self.assertEqual({b"b1", b"c2"}, set(keys_list[0].eval()))
            self.assertAllEqual([b"a1"], keys_list[1].eval())
            self.assertEqual({1, 2}, set(values_list[0].eval()))
            self.assertAllEqual([0], values_list[1].eval())
예제 #3
0
  def testExportSharded(self):
    """Export a two-shard table and check each shard's keys and values.

    The expected per-shard contents pin which shard each key lands in.
    NOTE(review): presumably that assignment follows from how
    _ShardedMutableHashTable hashes string keys; update these expectations
    if the sharding function changes.
    """
    with self._single_threaded_test_session():
      default_val = -1
      num_shards = 2
      keys = tf.constant(['a1', 'b1', 'c2'])
      values = tf.constant([0, 1, 2], tf.int64)
      table = _ShardedMutableHashTable(
          tf.string, tf.int64, default_val, num_shards=num_shards)
      self.assertAllEqual(0, table.size().eval())

      table.insert(keys, values).run()
      self.assertAllEqual(3, table.size().eval())

      # One key tensor and one value tensor are expected per shard.
      keys_list, values_list = table.export_sharded()
      self.assertEqual(num_shards, len(keys_list))
      self.assertEqual(num_shards, len(values_list))

      # Order within a shard is not relied on, so shard 0 is compared as a
      # set. assertEqual gives a real set diff; assertAllEqual targets
      # array-likes, and numpy does not treat a set as a sequence.
      self.assertEqual({b'b1', b'c2'}, set(keys_list[0].eval()))
      self.assertAllEqual([b'a1'], keys_list[1].eval())
      self.assertEqual({1, 2}, set(values_list[0].eval()))
      self.assertAllEqual([0], values_list[1].eval())
예제 #4
0
  def testShardedMutableHashTable(self):
    """Insert and look up string->int64 pairs for several shard counts.

    For each shard count the table starts empty, reports size 3 after
    inserting three keys, and a lookup of two inserted keys plus one missing
    key returns their values and ``default_val`` respectively.
    """
    for num_shards in [1, 3, 10]:
      with self._single_threaded_test_session():
        default_val = -1
        keys = tf.constant(['brain', 'salad', 'surgery'])
        values = tf.constant([0, 1, 2], tf.int64)
        table = _ShardedMutableHashTable(tf.string,
                                         tf.int64,
                                         default_val,
                                         num_shards=num_shards)
        self.assertAllEqual(0, table.size().eval())

        table.insert(keys, values).run()
        self.assertAllEqual(3, table.size().eval())

        # 'tank' was never inserted, so it must resolve to default_val.
        input_string = tf.constant(['brain', 'salad', 'tank'])
        output = table.lookup(input_string)
        # Compare the static shape as a plain list instead of relying on
        # numpy coercion of a TensorShape object.
        self.assertAllEqual([3], output.get_shape().as_list())

        result = output.eval()
        self.assertAllEqual([0, 1, default_val], result)
예제 #5
0
  def testShardedMutableHashTable(self):
    """Insert, look up and sum string->int64 pairs for several shard counts.

    For each shard count the table starts empty, reports size 3 after
    inserting three keys, a lookup of two inserted keys plus one missing key
    returns their values and ``default_val`` respectively, and
    ``values_reduce_sum`` over the stored values 0, 1 and 2 yields 3.
    """
    for num_shards in [1, 3, 10]:
      with self.test_session():
        default_val = -1
        keys = tf.constant(['brain', 'salad', 'surgery'])
        values = tf.constant([0, 1, 2], tf.int64)
        table = _ShardedMutableHashTable(tf.string,
                                         tf.int64,
                                         default_val,
                                         num_shards=num_shards)
        self.assertAllEqual(0, table.size().eval())

        table.insert(keys, values).run()
        self.assertAllEqual(3, table.size().eval())

        # 'tank' was never inserted, so it must resolve to default_val.
        input_string = tf.constant(['brain', 'salad', 'tank'])
        output = table.lookup(input_string)
        # Compare the static shape as a plain list instead of relying on
        # numpy coercion of a TensorShape object.
        self.assertAllEqual([3], output.get_shape().as_list())

        result = output.eval()
        self.assertAllEqual([0, 1, default_val], result)

        # 0 + 1 + 2 across all shards.
        self.assertAllEqual(3, table.values_reduce_sum().eval())