Ejemplo n.º 1
0
 def test_basic_bad_shape(self):
     """Constructing DatasetInitializer from batched values raises ValueError.

     The value component is batched into groups of 4, so each value element
     is a batch of strings instead of a single string. The test expects the
     initializer constructor itself to reject this dataset.
     """
     key_ds = dataset_ops.Dataset.range(100)

     def _double_as_string(idx):
         return string_ops.as_string(idx * 2)

     batched_value_ds = (
         dataset_ops.Dataset.range(100).map(_double_as_string).batch(4))
     zipped = dataset_ops.Dataset.zip((key_ds, batched_value_ds))
     with self.assertRaises(ValueError):
         lookup_ops.DatasetInitializer(zipped)
Ejemplo n.º 2
0
    def test_from_file(self):
        """Builds a table from a text file, keyed by 1-based line number.

        `enumerate(start=1)` supplies the keys and the file's lines supply the
        values. Keys 2 and 3 therefore resolve to "two" and "three". Key 4 is
        past the end of the three-line file and falls back to the default "".
        """
        path = self._createVocabFile("test.txt", ("one", "two", "three"))
        numbered_lines = reader_ops.TextLineDataset(path).enumerate(start=1)
        lookup_table = self.getHashTable()(
            lookup_ops.DatasetInitializer(numbered_lines), default_value="")
        self.initialize_table(lookup_table)

        query = constant_op.constant([2, 3, 4], dtypes.int64)
        looked_up = self.evaluate(lookup_table.lookup(query))
        self.assertAllEqual(["two", "three", ""], looked_up)
Ejemplo n.º 3
0
    def test_basic(self):
        """Populates a table from zipped (key, as_string(2 * key)) datasets.

        Keys 0, 2 and 5 are expected to resolve to "0", "4" and "10".
        """
        key_ds = dataset_ops.Dataset.range(100)
        value_ds = dataset_ops.Dataset.range(100).map(
            lambda n: string_ops.as_string(n * 2))
        pairs = dataset_ops.Dataset.zip((key_ds, value_ds))
        table = self.getHashTable()(
            lookup_ops.DatasetInitializer(pairs), default_value="")
        self.initialize_table(table)

        query = constant_op.constant([0, 2, 5], dtypes.int64)
        looked_up = self.evaluate(table.lookup(query))
        self.assertAllEqual(["0", "4", "10"], looked_up)
Ejemplo n.º 4
0
 def test_compatibility(self):
     """Checks the dataset-backed table inside an explicit graph context.

     The table is built and queried under ``ops.Graph().as_default()``. It is
     initialized by evaluating ``core_lookup_ops.tables_initializer()``
     instead of calling ``self.initialize_table``.
     """
     with ops.Graph().as_default():
         key_ds = dataset_ops.Dataset.range(100)
         value_ds = dataset_ops.Dataset.range(100).map(string_ops.as_string)
         initializer = lookup_ops.DatasetInitializer(
             dataset_ops.Dataset.zip((key_ds, value_ds)))
         table = self.getHashTable()(initializer, default_value="")
         query = constant_op.constant([0, 2, 5], dtypes.int64)
         lookup_result = table.lookup(query)
         self.evaluate(core_lookup_ops.tables_initializer())
         evaluated = self.evaluate(lookup_result)
     # `evaluated` is already a concrete value, so it is safe to compare
     # after leaving the graph context.
     self.assertAllEqual(["0", "2", "5"], evaluated)
Ejemplo n.º 5
0
    def test_map_variable(self):
        """Builds table values with a map function that mutates a captured variable.

        Each dataset element maps to the result of ``counter.assign_add(1)``.
        ``enumerate(start=1)`` pairs each result with a 1-based key. The
        assertion expects keys 1 and 2 to resolve to 1 and 2. Key 101 lies
        beyond the 100 elements and yields the default of -1.
        """
        source = dataset_ops.Dataset.range(100)
        counter = variables.Variable(0)

        def increment(_unused_element):
            return counter.assign_add(1)

        keyed = source.map(increment).enumerate(start=1)
        table = self.getHashTable()(
            lookup_ops.DatasetInitializer(keyed), default_value=-1)
        # The variable must be initialized before the table initializer runs
        # the map function that reads and updates it.
        self.evaluate(counter.initializer)
        self.initialize_table(table)

        query = constant_op.constant([1, 2, 101], dtypes.int64)
        looked_up = self.evaluate(table.lookup(query))
        self.assertAllEqual([1, 2, -1], looked_up)
Ejemplo n.º 6
0
 def datasetInitializer(self, vals):
   """Returns a DatasetInitializer that maps index i to vals[i].

   Args:
     vals: values passed to ``Dataset.from_tensor_slices``. ``len(vals)``
       sets the key range ``0 .. len(vals) - 1``.

   Returns:
     A ``data_lookup_ops.DatasetInitializer`` over the zipped
     (index, value) dataset.
   """
   index_ds = dataset_ops.Dataset.range(len(vals))
   pairs = dataset_ops.Dataset.zip(
       (index_ds, dataset_ops.Dataset.from_tensor_slices(vals)))
   return data_lookup_ops.DatasetInitializer(pairs)