def get_static_table(tmpdir,
                     vocab_list,
                     mask_token=None,
                     dtype=dtypes.string,
                     oov_tokens=None):
  vocabulary_file = os.path.join(tmpdir, "tmp_vocab.txt")

  if dtype == dtypes.string:
    with open(vocabulary_file, "w") as f:
      f.write("\n".join(vocab_list) + "\n")
  else:
    with open(vocabulary_file, "w") as f:
      f.write("\n".join([str(v) for v in vocab_list]) + "\n")

  offset = ((0 if mask_token is None else 1) +
            (len(oov_tokens) if oov_tokens is not None else 0))
  init = lookup_ops.TextFileInitializer(
      vocabulary_file,
      dtype,
      lookup_ops.TextFileIndex.WHOLE_LINE,
      dtypes.int64,
      lookup_ops.TextFileIndex.LINE_NUMBER,
      value_index_offset=offset)
  if context.executing_eagerly():
    table = lookup_ops.StaticHashTable(init, default_value=-7)
  else:
    table = lookup_ops.StaticHashTableV1(init, default_value=-7)

  return table_utils.TableHandler(
      table,
      oov_tokens,
      mask_token=mask_token,
      use_v1_apis=(not context.executing_eagerly()))

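# Illustrative sketch (not part of the original source): exercises the same
# value_index_offset behavior that get_static_table relies on, but directly
# against a StaticHashTable rather than the TableHandler wrapper. The file
# path, vocab, and offset below are assumptions chosen for demonstration.
def _offset_lookup_sketch(tmpdir):
  path = os.path.join(tmpdir, "sketch_vocab.txt")
  with open(path, "w") as f:
    f.write("a\nb\nc\n")
  init = lookup_ops.TextFileInitializer(
      path,
      dtypes.string,
      lookup_ops.TextFileIndex.WHOLE_LINE,
      dtypes.int64,
      lookup_ops.TextFileIndex.LINE_NUMBER,
      value_index_offset=2)
  table = lookup_ops.StaticHashTable(init, default_value=-7)
  # With offset=2, line 0 maps to index 2 and line 1 to index 3, so the
  # expected eager result is [2, 3, -7] ("z" is not in the file).
  return table.lookup(constant_op.constant(["a", "b", "z"]))
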
def __init__(self):
  self.asset = asset.Asset(
      test.test_src_dir_path(
          "cc/saved_model/testdata/static_hashtable_asset.txt"))
  self.table = lookup_ops.StaticHashTable(
      lookup_ops.TextFileInitializer(self.asset, dtypes.string,
                                     lookup_ops.TextFileIndex.WHOLE_LINE,
                                     dtypes.int64,
                                     lookup_ops.TextFileIndex.LINE_NUMBER), -1)

def testDistributeLookupTable(self, init_source):
  cluster = data_service_test_base.TestCluster(num_workers=1)
  initializer = self.lookupTableInitializer(init_source, [10, 11])
  table = lookup_ops.StaticHashTable(initializer, -1)
  ds = dataset_ops.Dataset.range(3)
  ds = ds.map(table.lookup)
  ds = self.make_distributed_dataset(ds, cluster)
  self.evaluate(lookup_ops.tables_initializer())
  self.assertDatasetProduces(
      ds, [10, 11, -1], requires_initialization=True)

def createStaticHashTable(self,
                          init_source=None,
                          vals=None,
                          default_value=None,
                          initializer=None):
  if not initializer:
    initializer = self.make_initializer(init_source, vals)
  return lookup_ops.StaticHashTable(
      initializer=initializer, default_value=default_value)

def get_graph_def():
  with ops.Graph().as_default() as g:
    x = constant_op.constant([2, 9], name="x")
    keys = constant_op.constant([1, 2], name="keys")
    values = constant_op.constant([3, 4], name="values")
    default = constant_op.constant(-1, name="default")
    table = lookup_ops.StaticHashTable(
        lookup_ops.KeyValueTensorInitializer(keys, values), default)
    _ = table.lookup(x)
  return g.as_graph_def()

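# Illustrative sketch (not part of the original source): one way the GraphDef
# returned by get_graph_def could be re-imported into a fresh graph for
# inspection. The "imported" name scope is an arbitrary choice.
from tensorflow.python.framework import importer


def _import_graph_def_sketch():
  graph_def = get_graph_def()
  with ops.Graph().as_default() as imported_graph:
    importer.import_graph_def(graph_def, name="imported")
  # The table initializer and lookup ops now live under the "imported/"
  # name scope of imported_graph.
  return imported_graph
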
def testDistributeLookupTable(self):
  cluster = data_service_test_base.TestCluster(num_workers=1)
  keys_tensor = constant_op.constant([1, 2])
  vals_tensor = constant_op.constant([11, 12])
  table = lookup_ops.StaticHashTable(
      lookup_ops.KeyValueTensorInitializer(keys_tensor, vals_tensor), -1)
  ds = dataset_ops.Dataset.range(3, output_type=dtypes.int32)
  ds = ds.map(table.lookup)
  ds = self.make_distributed_dataset(ds, cluster)
  self.assertDatasetProduces(ds, [-1, 11, 12])

def testLookupTableGraphSerialization(self, init_source):
  vals = [10, 11]
  initializer = self.lookupTableInitializer(init_source, vals)
  table = lookup_ops.StaticHashTable(initializer, -1)
  dataset = dataset_ops.Dataset.range(3)
  dataset = dataset.map(table.lookup)
  self.evaluate(lookup_ops.tables_initializer())
  round_tripped = self.graphRoundTrip(dataset)
  del table
  del dataset
  self.assertDatasetProduces(
      round_tripped, [10, 11, -1], requires_initialization=True)

def testResourceOnWrongDevice(self):
  cluster = data_service_test_base.TestCluster(num_workers=1)
  with ops.device(self._devices[0]):
    initializer = self.lookupTableInitializer("keyvaluetensor", [10, 11])
    table = lookup_ops.StaticHashTable(initializer, -1)
    self.evaluate(lookup_ops.tables_initializer())

  with ops.device(self._devices[1]):
    ds = dataset_ops.Dataset.range(3)
    ds = ds.map(table.lookup)
    with self.assertRaisesRegex(
        errors.FailedPreconditionError,
        "Serialization error while trying to register a dataset"):
      ds = self.make_distributed_dataset(ds, cluster)
      self.getDatasetOutput(ds, requires_initialization=True)

def testStaticHashTableDatasetFnHostTrainingLoop(self, enable_packed_var):
  self._dataset_fn_tracing_count = 0
  strategy = get_tpu_strategy(enable_packed_var)

  with strategy.scope():
    vals = [0, 1, 2]
    keys_tensor = constant_op.constant(
        list(range(len(vals))), dtype=dtypes.int64)
    vals_tensor = constant_op.constant(vals)
    initializer = lookup_ops.KeyValueTensorInitializer(
        keys_tensor, vals_tensor)
    per_worker_table = lookup_ops.StaticHashTable(
        initializer, default_value=-1)

  @def_function.function
  def dataset_fn(input_context):
    tensor = constant_op.constant([0, 1, 3], dtype=dtypes.int64)
    global_batch_size = 2
    batch_size = input_context.get_per_replica_batch_size(global_batch_size)
    dataset = dataset_ops.Dataset.from_tensors(tensor).repeat().batch(
        batch_size, drop_remainder=True)
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)
    dataset = dataset.prefetch(2)  # This prefetches 2 batches per device.
    dataset = dataset.map(per_worker_table.lookup)
    self._dataset_fn_tracing_count += 1
    return dataset

  dist_iterator = iter(
      strategy.experimental_distribute_datasets_from_function(dataset_fn))

  @def_function.function
  def step_fn(inputs):
    # inputs should be [0, 1, -1]
    return math_ops.reduce_sum(inputs)

  def train_steps(iterator, steps):
    for _ in math_ops.range(steps):
      strategy.run(step_fn, args=(next(iterator),))

  train_steps(dist_iterator, steps=5)
  self.assertEqual(self._dataset_fn_tracing_count, 1)

def __init__(self, init_source, filepath):
  vals = [0, 1, 2]
  if init_source == "textfile":
    with open(filepath, "w") as f:
      f.write("\n".join(str(v) for v in vals) + "\n")
    self.initializer = lookup_ops.TextFileInitializer(
        filepath, dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER,
        dtypes.int64, lookup_ops.TextFileIndex.WHOLE_LINE)
  else:
    keys_tensor = constant_op.constant(
        list(range(len(vals))), dtype=dtypes.int64)
    vals_tensor = constant_op.constant(vals)
    self.initializer = lookup_ops.KeyValueTensorInitializer(
        keys_tensor, vals_tensor)
  self.table = lookup_ops.StaticHashTable(
      self.initializer, default_value=-2)

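# Illustrative sketch (not part of the original source): both initializer
# sources built above describe the same mapping {0: 0, 1: 1, 2: 2}, so lookups
# through either table should agree. The argument names below are assumptions
# standing in for instances of the enclosing class, one per init_source.
def _equivalent_sources_sketch(textfile_obj, keyvalue_obj):
  keys = constant_op.constant([0, 2, 5], dtype=dtypes.int64)
  # Expected eager result from both tables: [0, 2, -2], since key 5 is
  # missing and default_value above is -2.
  return textfile_obj.table.lookup(keys), keyvalue_obj.table.lookup(keys)
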
def testDistributeLookupTable(self, init_from_file):
  cluster = data_service_test_base.TestCluster(num_workers=1)
  if init_from_file:
    file = os.path.join(self.get_temp_dir(), "distribute_lookup_table")
    with open(file, "w") as f:
      f.write("10\n11\n")
    initializer = lookup_ops.TextFileInitializer(
        file, dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER,
        dtypes.int64, lookup_ops.TextFileIndex.WHOLE_LINE)
  else:
    keys_tensor = constant_op.constant([0, 1], dtype=dtypes.int64)
    vals_tensor = constant_op.constant([10, 11])
    initializer = lookup_ops.KeyValueTensorInitializer(
        keys_tensor, vals_tensor)

  table = lookup_ops.StaticHashTable(initializer, -1)
  ds = dataset_ops.Dataset.range(3)
  ds = ds.map(table.lookup)
  ds = self.make_distributed_dataset(ds, cluster)
  self.evaluate(lookup_ops.tables_initializer())
  self.assertDatasetProduces(
      ds, [10, 11, -1], requires_initialization=True)

def testLookupTableGraphSerialization(self, init_from_file):
  if init_from_file:
    file = os.path.join(self.get_temp_dir(), "lookup_table_graph_serialize")
    with open(file, "w") as f:
      f.write("10\n11\n")
    initializer = lookup_ops.TextFileInitializer(
        file, dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER,
        dtypes.int64, lookup_ops.TextFileIndex.WHOLE_LINE)
  else:
    keys_tensor = constant_op.constant([0, 1], dtype=dtypes.int64)
    vals_tensor = constant_op.constant([10, 11])
    initializer = lookup_ops.KeyValueTensorInitializer(
        keys_tensor, vals_tensor)

  table = lookup_ops.StaticHashTable(initializer, -1)
  dataset = dataset_ops.Dataset.range(3)
  dataset = dataset.map(table.lookup)
  self.evaluate(lookup_ops.tables_initializer())
  round_tripped = self.graphRoundTrip(dataset)
  del table
  del dataset
  self.assertDatasetProduces(
      round_tripped, [10, 11, -1], requires_initialization=True)

def __init__(self,
             max_tokens,
             num_oov_indices,
             mask_token,
             oov_token,
             vocabulary=None,
             invert=False,
             output_mode=INT,
             sparse=False,
             pad_to_max_tokens=False,
             **kwargs):
  # If max_tokens is set, the value must be greater than 1 - otherwise we
  # are creating a 0-element vocab, which doesn't make sense.
  if max_tokens is not None and max_tokens <= 1:
    raise ValueError("If set, `max_tokens` must be greater than 1. "
                     "You passed {}".format(max_tokens))

  if num_oov_indices < 0:
    raise ValueError("`num_oov_indices` must be greater than or equal to 0. "
                     "You passed {}".format(num_oov_indices))

  # Support deprecated names for output_modes.
  if output_mode == "binary":
    output_mode = MULTI_HOT
  if output_mode == "tf-idf":
    output_mode = TF_IDF
  # 'output_mode' must be one of (INT, MULTI_HOT, COUNT, TF_IDF).
  layer_utils.validate_string_arg(
      output_mode,
      allowable_strings=(INT, MULTI_HOT, COUNT, TF_IDF),
      layer_name=self.__class__.__name__,
      arg_name="output_mode")

  if invert and output_mode != INT:
    raise ValueError("`output_mode` must be {} when `invert` is true. You "
                     "passed {}".format(INT, output_mode))

  self.invert = invert
  self.max_tokens = max_tokens
  self.num_oov_indices = num_oov_indices
  self.output_mode = output_mode
  self.sparse = sparse
  self.pad_to_max_tokens = pad_to_max_tokens
  self._called = False

  # A note on vocab_size: we need to always keep a non-Tensor representation
  # of vocab_size around to use in graph building. Because we might be
  # in a tf.function, we can't rely on evaluating the actual tables to
  # find the value either.
  self._vocab_size = None
  # We need to keep track of our current vocab size outside of our layer
  # weights to support a static output shape when `output_mode != INT`. The
  # bincount ops do not set shape on their outputs, which means we have to
  # set it ourselves. We persist the current vocab size as a hidden part of
  # the config when serializing our model.
  if "vocabulary_size" in kwargs:
    self._vocab_size = kwargs["vocabulary_size"]
    del kwargs["vocabulary_size"]

  restore_from_static_table = kwargs.pop("has_static_table", False)

  # Make sure the mask token and oov token are truly of the dtype we want. We
  # can ignore strings here, because they have only one dtype.
  dtype = kwargs["dtype"]
  if dtype == dtypes.int32:
    mask_token = None if mask_token is None else np.int32(mask_token)
    oov_token = None if oov_token is None else np.int32(oov_token)
  elif dtype == dtypes.int64:
    mask_token = None if mask_token is None else np.int64(mask_token)
    oov_token = None if oov_token is None else np.int64(oov_token)
  self.mask_token = mask_token
  self.oov_token = oov_token

  if max_tokens is not None:
    available_vocab_size = max_tokens - self._token_start_index()
  else:
    available_vocab_size = None

  super(IndexLookup, self).__init__(
      combiner=_IndexLookupCombiner(
          vocab_size=available_vocab_size,
          mask_value=mask_token,
          oov_value=oov_token,
          compute_idf=(output_mode == TF_IDF)),
      **kwargs)

  # We need to save the key dtype so that we know if we're expecting int64
  # keys. If we are, we will cast int32 inputs to int64 as well.
  if invert:
    self._key_dtype = dtypes.int64
    self._value_dtype = self.dtype
    self._mask_key = 0
    self._mask_value = mask_token
    key_index = lookup_ops.TextFileIndex.LINE_NUMBER
    value_index = lookup_ops.TextFileIndex.WHOLE_LINE
    default_value = self.oov_token
    oov_indices = None
  else:
    self._key_dtype = self.dtype
    self._value_dtype = dtypes.int64
    self._mask_key = mask_token
    key_index = lookup_ops.TextFileIndex.WHOLE_LINE
    value_index = lookup_ops.TextFileIndex.LINE_NUMBER
    # Masks should map to 0 for int output and be dropped otherwise. Max ints
    # will be dropped from the bincount op.
    self._mask_value = 0 if self.output_mode == INT else dtypes.int64.max
    oov_start = self._oov_start_index()
    token_start = self._token_start_index()
    if self.num_oov_indices == 0:
      # If there are no OOV indices, we map OOV tokens to -1 and error out
      # during call if we find a negative index.
      default_value = -1
      oov_indices = None
    elif self.num_oov_indices == 1:
      # If there is only one OOV index, we can set that index as the default
      # value of the index_lookup table.
      default_value = oov_start
      oov_indices = None
    else:
      # If we have multiple OOV values, we need to do a further hashing step;
      # to make this easier, we set the OOV value to -1. (This lets us do a
      # vectorized add and cast to boolean to determine locations where we
      # need to do extra hashing.)
      default_value = -1
      oov_indices = list(range(oov_start, token_start))

  self._static_vocabulary_path = None
  has_vocab_path = (vocabulary is not None and isinstance(vocabulary, str))
  if has_vocab_path or restore_from_static_table:
    self._has_static_table = True
    if vocabulary is None:
      # If we're restoring a layer that was saved with a static table
      # initializer, we create a fake initializer object to let the code
      # progress. The savedmodel restoration code will handle restoring
      # the actual data.
      initializer = _NullInitializer(self._key_dtype, self._value_dtype)
    else:
      if not gfile.Exists(vocabulary):
        raise ValueError("Vocabulary file %s does not exist." % (vocabulary,))
      self._static_vocabulary_path = vocabulary
      num_tokens = table_utils.num_tokens_in_file(vocabulary)
      self._vocab_size = self._token_start_index() + num_tokens

      initializer = lookup_ops.TextFileInitializer(
          filename=vocabulary,
          key_dtype=self._key_dtype,
          key_index=key_index,
          value_dtype=self._value_dtype,
          value_index=value_index,
          value_index_offset=self._token_start_index())

    self._table = lookup_ops.StaticHashTable(
        initializer, default_value=default_value)
    self._table_handler = table_utils.TableHandler(
        table=self._table,
        mask_token=self._mask_key if self.mask_token is not None else None,
        mask_value=self._mask_value,
        oov_tokens=oov_indices)

    tracked_table = self._add_trackable(self._table, trainable=False)
  else:
    self._has_static_table = False
    self._table = lookup_ops.MutableHashTable(
        key_dtype=self._key_dtype,
        value_dtype=self._value_dtype,
        default_value=default_value,
        name=(self._name + "_index_table"))
    self._table_handler = table_utils.TableHandler(
        table=self._table,
        oov_tokens=oov_indices)
    if vocabulary is not None:
      self.set_vocabulary(vocabulary)
    tracked_table = self._add_trackable(self._table, trainable=False)

  if self.output_mode == TF_IDF:
    # The TF-IDF weight may have a (None,) tensorshape. This creates
    # a 1D variable with arbitrary shape, which we can assign any weight to
    # so long as it has 1 dimension. In order to properly initialize this
    # weight in Keras, we need to provide a custom callable initializer which
    # does not depend on the shape of the weight (as all other initializers
    # do) since the weight is not known. Hence the lambda shape, dtype: [0].
    if not self.pad_to_max_tokens or max_tokens is None:
      initializer = lambda shape, dtype: [0]
    else:
      initializer = init_ops.zeros_initializer

    # We are adding these here instead of in build() since they do not depend
    # on the input shape at all.
    idf_shape = (max_tokens,) if self.pad_to_max_tokens else (None,)
    self.tf_idf_weights = self._add_state_variable(
        name="idf",
        shape=tensor_shape.TensorShape(idf_shape),
        dtype=backend.floatx(),
        initializer=initializer)

  # This is a workaround for summary() on this layer. Because the table is
  # not mutable during training, the effective number of parameters (and so
  # the weight shape) is 0; we add this as an attr so that the parameter
  # counting code in the Model object doesn't throw an attribute error.
  tracked_table.shape = tensor_shape.TensorShape((0,))

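# Illustrative sketch (not the layer's actual code path): the multi-OOV branch
# above maps unknown tokens to the -1 sentinel so they can be detected in a
# vectorized way and spread across the reserved OOV index range afterwards. A
# minimal standalone version of that step might look like the following; the
# helper name is hypothetical, and math_ops/string_ops/array_ops are assumed
# to be the usual tensorflow.python.ops modules.
def _spread_oov_sketch(tokens, lookups, oov_start, num_oov_indices):
  # tokens: 1-D string tensor that was fed to the table.
  # lookups: matching 1-D int64 results from a table whose default_value
  # is -1.
  oov_locations = math_ops.equal(lookups, -1)
  # Hash the original tokens into the reserved OOV bucket range.
  hashed = string_ops.string_to_hash_bucket_fast(tokens, num_oov_indices)
  return array_ops.where(oov_locations, oov_start + hashed, lookups)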