def test_basic_bad_shape(self):
    """A dataset whose value component is batched must be rejected.

    Keys are single elements from ``Dataset.range`` while values are grouped
    into batches of four strings, so building the initializer is expected to
    raise ``ValueError``.
    """
    scalar_keys = dataset_ops.Dataset.range(100)
    batched_values = (
        dataset_ops.Dataset.range(100)
        .map(lambda x: string_ops.as_string(x * 2))
        .batch(4)
    )
    mismatched = dataset_ops.Dataset.zip((scalar_keys, batched_values))
    with self.assertRaises(ValueError):
        lookup_ops.DatasetInitializer(mismatched)
def test_from_file(self):
    """Builds a table from a text file, keyed by 1-based line number."""
    vocab_path = self._createVocabFile("test.txt", ("one", "two", "three"))
    # enumerate(start=1) pairs each line with its number, starting at 1.
    numbered_lines = reader_ops.TextLineDataset(vocab_path).enumerate(start=1)
    table = self.getHashTable()(
        lookup_ops.DatasetInitializer(numbered_lines), default_value="")
    self.initialize_table(table)
    # Key 4 is past the last line, so it should fall back to the default "".
    query = constant_op.constant([2, 3, 4], dtypes.int64)
    self.assertAllEqual(["two", "three", ""],
                        self.evaluate(table.lookup(query)))
def test_basic(self):
    """Maps each key k in range(100) to the string form of 2 * k."""
    def _doubled_as_string(x):
        return string_ops.as_string(x * 2)

    pairs = dataset_ops.Dataset.zip((
        dataset_ops.Dataset.range(100),
        dataset_ops.Dataset.range(100).map(_doubled_as_string),
    ))
    table = self.getHashTable()(
        lookup_ops.DatasetInitializer(pairs), default_value="")
    self.initialize_table(table)
    looked_up = self.evaluate(
        table.lookup(constant_op.constant([0, 2, 5], dtypes.int64)))
    self.assertAllEqual(["0", "4", "10"], looked_up)
def test_compatibility(self):
    """Checks the initializer inside an explicit graph via tables_initializer()."""
    with ops.Graph().as_default():
        pairs = dataset_ops.Dataset.zip((
            dataset_ops.Dataset.range(100),
            dataset_ops.Dataset.range(100).map(string_ops.as_string),
        ))
        table = self.getHashTable()(
            lookup_ops.DatasetInitializer(pairs), default_value="")
        lookup_result = table.lookup(
            constant_op.constant([0, 2, 5], dtypes.int64))
        # Initialization goes through the collective tables_initializer() op
        # rather than self.initialize_table.
        self.evaluate(core_lookup_ops.tables_initializer())
        self.assertAllEqual(["0", "2", "5"], self.evaluate(lookup_result))
def test_map_variable(self):
    """Values may come from a map function that mutates a captured variable."""
    source = dataset_ops.Dataset.range(100)
    counter = variables.Variable(0)

    def _bump(_):
        # Ignores the element and increments the captured counter.
        return counter.assign_add(1)

    numbered = source.map(_bump).enumerate(start=1)
    table = self.getHashTable()(
        lookup_ops.DatasetInitializer(numbered), default_value=-1)
    # The counter is initialized before the table is initialized from the dataset.
    self.evaluate(counter.initializer)
    self.initialize_table(table)
    # Key 101 is beyond the 100 enumerated elements, so -1 is expected.
    found = self.evaluate(
        table.lookup(constant_op.constant([1, 2, 101], dtypes.int64)))
    self.assertAllEqual([1, 2, -1], found)
def datasetInitializer(self, vals):
    """Returns a DatasetInitializer pairing each index i with vals[i].

    Args:
      vals: a sequence accepted by ``Dataset.from_tensor_slices``; its
        ``len`` determines the index range.
    """
    index_ds = dataset_ops.Dataset.range(len(vals))
    pairs = dataset_ops.Dataset.zip(
        (index_ds, dataset_ops.Dataset.from_tensor_slices(vals)))
    return data_lookup_ops.DatasetInitializer(pairs)