Code example #1
def get_static_table(tmpdir,
                     vocab_list,
                     mask_token=None,
                     dtype=dtypes.string,
                     oov_tokens=None):
    vocabulary_file = os.path.join(tmpdir, "tmp_vocab.txt")

    if dtype == dtypes.string:
        with open(vocabulary_file, "w") as f:
            f.write("\n".join(vocab_list) + "\n")
    else:
        with open(vocabulary_file, "w") as f:
            f.write("\n".join([str(v) for v in vocab_list]) + "\n")

    offset = ((0 if mask_token is None else 1) +
              (len(oov_tokens) if oov_tokens is not None else 0))
    init = lookup_ops.TextFileInitializer(vocabulary_file,
                                          dtype,
                                          lookup_ops.TextFileIndex.WHOLE_LINE,
                                          dtypes.int64,
                                          lookup_ops.TextFileIndex.LINE_NUMBER,
                                          value_index_offset=offset)
    if context.executing_eagerly():
        table = lookup_ops.StaticHashTable(init, default_value=-7)
    else:
        table = lookup_ops.StaticHashTableV1(init, default_value=-7)

    return table_utils.TableHandler(
        table,
        oov_tokens,
        mask_token=mask_token,
        use_v1_apis=(not context.executing_eagerly()))
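
The example above goes through TensorFlow-internal modules (lookup_ops, table_utils.TableHandler). For reference, here is a minimal sketch of the same whole-line-to-line-number lookup using the public tf.lookup API, leaving out the mask/OOV offset handling; the vocabulary file, its path, and its contents are made up for illustration:

import tensorflow as tf

# Hypothetical vocabulary file: one token per line.
with open("/tmp/tmp_vocab.txt", "w") as f:
    f.write("\n".join(["earth", "wind", "fire"]) + "\n")

# Each whole line becomes a key; its line number becomes the value.
init = tf.lookup.TextFileInitializer(
    "/tmp/tmp_vocab.txt",
    key_dtype=tf.string,
    key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
    value_dtype=tf.int64,
    value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
table = tf.lookup.StaticHashTable(init, default_value=-7)

print(table.lookup(tf.constant(["wind", "quake"])))  # [1, -7]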
Code example #2
 def __init__(self):
     self.asset = asset.Asset(
         test.test_src_dir_path(
             "cc/saved_model/testdata/static_hashtable_asset.txt"))
     self.table = lookup_ops.StaticHashTable(
         lookup_ops.TextFileInitializer(
             self.asset, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
             dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER), -1)
Code example #3
 def testDistributeLookupTable(self, init_source):
   cluster = data_service_test_base.TestCluster(num_workers=1)
   initializer = self.lookupTableInitializer(init_source, [10, 11])
   table = lookup_ops.StaticHashTable(initializer, -1)
   ds = dataset_ops.Dataset.range(3)
   ds = ds.map(table.lookup)
   ds = self.make_distributed_dataset(ds, cluster)
   self.evaluate(lookup_ops.tables_initializer())
   self.assertDatasetProduces(ds, [10, 11, -1], requires_initialization=True)
Code example #4
 def createStaticHashTable(self,
                           init_source=None,
                           vals=None,
                           default_value=None,
                           initializer=None):
     if not initializer:
         initializer = self.make_initializer(init_source, vals)
     return lookup_ops.StaticHashTable(initializer=initializer,
                                       default_value=default_value)
Code example #5
 def get_graph_def():
   with ops.Graph().as_default() as g:
     x = constant_op.constant([2, 9], name="x")
     keys = constant_op.constant([1, 2], name="keys")
     values = constant_op.constant([3, 4], name="values")
     default = constant_op.constant(-1, name="default")
     table = lookup_ops.StaticHashTable(
         lookup_ops.KeyValueTensorInitializer(keys, values), default)
     _ = table.lookup(x)
   return g.as_graph_def()
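
A minimal eager-mode counterpart of the graph built above, using the public tf.lookup API with the same constants; keys found in the table return their values, anything else returns the default:

import tensorflow as tf

keys = tf.constant([1, 2])
values = tf.constant([3, 4])
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys, values), default_value=-1)

# 2 is a known key (-> 4); 9 is unknown (-> default -1).
print(table.lookup(tf.constant([2, 9])))  # [4, -1]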
Code example #6
 def testDistributeLookupTable(self):
     cluster = data_service_test_base.TestCluster(num_workers=1)
     keys_tensor = constant_op.constant([1, 2])
     vals_tensor = constant_op.constant([11, 12])
     table = lookup_ops.StaticHashTable(
         lookup_ops.KeyValueTensorInitializer(keys_tensor, vals_tensor), -1)
     ds = dataset_ops.Dataset.range(3, output_type=dtypes.int32)
     ds = ds.map(table.lookup)
     ds = self.make_distributed_dataset(ds, cluster)
     self.assertDatasetProduces(ds, [-1, 11, 12])
Code example #7
 def testLookupTableGraphSerialization(self, init_source):
     vals = [10, 11]
     initializer = self.lookupTableInitializer(init_source, vals)
     table = lookup_ops.StaticHashTable(initializer, -1)
     dataset = dataset_ops.Dataset.range(3)
     dataset = dataset.map(table.lookup)
     self.evaluate(lookup_ops.tables_initializer())
     round_tripped = self.graphRoundTrip(dataset)
     del table
     del dataset
     self.assertDatasetProduces(round_tripped, [10, 11, -1],
                                requires_initialization=True)
Code example #8
  def testResourceOnWrongDevice(self):
    cluster = data_service_test_base.TestCluster(num_workers=1)
    with ops.device(self._devices[0]):
      initializer = self.lookupTableInitializer("keyvaluetensor", [10, 11])
      table = lookup_ops.StaticHashTable(initializer, -1)
      self.evaluate(lookup_ops.tables_initializer())

    with ops.device(self._devices[1]):
      ds = dataset_ops.Dataset.range(3)
      ds = ds.map(table.lookup)
      with self.assertRaisesRegex(
          errors.FailedPreconditionError,
          "Serialization error while trying to register a dataset"):
        ds = self.make_distributed_dataset(ds, cluster)
        self.getDatasetOutput(ds, requires_initialization=True)
Code example #9
  def testStaticHashTableDatasetFnHostTrainingLoop(self, enable_packed_var):
    self._dataset_fn_tracing_count = 0
    strategy = get_tpu_strategy(enable_packed_var)

    with strategy.scope():
      vals = [0, 1, 2]
      keys_tensor = constant_op.constant(
          list(range(len(vals))), dtype=dtypes.int64)
      vals_tensor = constant_op.constant(vals)
      initializer = lookup_ops.KeyValueTensorInitializer(
          keys_tensor, vals_tensor)
      per_worker_table = lookup_ops.StaticHashTable(
          initializer, default_value=-1)

    @def_function.function
    def dataset_fn(input_context):
      tensor = constant_op.constant([0, 1, 3], dtype=dtypes.int64)
      global_batch_size = 2
      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
      dataset = dataset_ops.Dataset.from_tensors(tensor).repeat().batch(
          batch_size, drop_remainder=True)
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
      dataset = dataset.prefetch(2)  # This prefetches 2 batches per device.
      dataset = dataset.map(per_worker_table.lookup)
      self._dataset_fn_tracing_count += 1
      return dataset

    dist_iterator = iter(
        strategy.experimental_distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def step_fn(inputs):
      # inputs should be [0, 1, -1]
      return math_ops.reduce_sum(inputs)

    def train_steps(iterator, steps):

      for _ in math_ops.range(steps):
        strategy.run(step_fn, args=(next(iterator),))

    train_steps(dist_iterator, steps=5)
    self.assertEqual(self._dataset_fn_tracing_count, 1)
Code example #10
    def __init__(self, init_source, filepath):
      vals = [0, 1, 2]
      if init_source == "textfile":

        with open(filepath, "w") as f:
          f.write("\n".join(str(v) for v in vals) + "\n")

        self.initializer = lookup_ops.TextFileInitializer(
            filepath, dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER,
            dtypes.int64, lookup_ops.TextFileIndex.WHOLE_LINE)
      else:
        keys_tensor = constant_op.constant(
            list(range(len(vals))), dtype=dtypes.int64)
        vals_tensor = constant_op.constant(vals)
        self.initializer = lookup_ops.KeyValueTensorInitializer(
            keys_tensor, vals_tensor)

      self.table = lookup_ops.StaticHashTable(
          self.initializer, default_value=-2)
Code example #11
 def testDistributeLookupTable(self, init_from_file):
     cluster = data_service_test_base.TestCluster(num_workers=1)
     if init_from_file:
         file = os.path.join(self.get_temp_dir(), "distribute_lookup_table")
         with open(file, "w") as f:
             f.write("10\n11\n")
         initializer = lookup_ops.TextFileInitializer(
             file, dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER,
             dtypes.int64, lookup_ops.TextFileIndex.WHOLE_LINE)
     else:
         keys_tensor = constant_op.constant([0, 1], dtype=dtypes.int64)
         vals_tensor = constant_op.constant([10, 11])
         initializer = lookup_ops.KeyValueTensorInitializer(
             keys_tensor, vals_tensor)
     table = lookup_ops.StaticHashTable(initializer, -1)
     ds = dataset_ops.Dataset.range(3)
     ds = ds.map(table.lookup)
     ds = self.make_distributed_dataset(ds, cluster)
     self.evaluate(lookup_ops.tables_initializer())
     self.assertDatasetProduces(ds, [10, 11, -1],
                                requires_initialization=True)
Code example #12
    def testLookupTableGraphSerialization(self, init_from_file):
        if init_from_file:
            file = os.path.join(self.get_temp_dir(),
                                "lookup_table_graph_serialize")
            with open(file, "w") as f:
                f.write("10\n11\n")
            initializer = lookup_ops.TextFileInitializer(
                file, dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER,
                dtypes.int64, lookup_ops.TextFileIndex.WHOLE_LINE)
        else:
            keys_tensor = constant_op.constant([0, 1], dtype=dtypes.int64)
            vals_tensor = constant_op.constant([10, 11])
            initializer = lookup_ops.KeyValueTensorInitializer(
                keys_tensor, vals_tensor)

        table = lookup_ops.StaticHashTable(initializer, -1)
        dataset = dataset_ops.Dataset.range(3)
        dataset = dataset.map(table.lookup)
        self.evaluate(lookup_ops.tables_initializer())
        round_tripped = self.graphRoundTrip(dataset)
        del table
        del dataset
        self.assertDatasetProduces(round_tripped, [10, 11, -1],
                                   requires_initialization=True)
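
The two examples above invert the usual direction of a vocabulary file: the line number is the key and the whole line, parsed as int64, is the value. A minimal sketch of that pattern with the public tf.lookup API; the file path is hypothetical:

import os
import tempfile
import tensorflow as tf

path = os.path.join(tempfile.gettempdir(), "lookup_values.txt")
with open(path, "w") as f:
    f.write("10\n11\n")

# Line number -> line content (parsed as int64).
initializer = tf.lookup.TextFileInitializer(
    path, tf.int64, tf.lookup.TextFileIndex.LINE_NUMBER,
    tf.int64, tf.lookup.TextFileIndex.WHOLE_LINE)
table = tf.lookup.StaticHashTable(initializer, default_value=-1)

print(table.lookup(tf.constant([0, 1, 2], dtype=tf.int64)))  # [10, 11, -1]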
Code example #13
  def __init__(self,
               max_tokens,
               num_oov_indices,
               mask_token,
               oov_token,
               vocabulary=None,
               invert=False,
               output_mode=INT,
               sparse=False,
               pad_to_max_tokens=False,
               **kwargs):
    # If max_tokens is set, the value must be greater than 1 - otherwise we
    # are creating a 0-element vocab, which doesn't make sense.
    if max_tokens is not None and max_tokens <= 1:
      raise ValueError("If set, `max_tokens` must be greater than 1. "
                       "You passed {}".format(max_tokens))

    if num_oov_indices < 0:
      raise ValueError("`num_oov_indices` must be greater than or equal to 0. "
                       "You passed {}".format(num_oov_indices))

    # Support deprecated names for output_modes.
    if output_mode == "binary":
      output_mode = MULTI_HOT
    if output_mode == "tf-idf":
      output_mode = TF_IDF
    # 'output_mode' must be one of (INT, MULTI_HOT, COUNT, TF_IDF)
    layer_utils.validate_string_arg(
        output_mode,
        allowable_strings=(INT, MULTI_HOT, COUNT, TF_IDF),
        layer_name=self.__class__.__name__,
        arg_name="output_mode")

    if invert and output_mode != INT:
      raise ValueError("`output_mode` must be {} when `invert` is true. You "
                       "passed {}".format(INT, output_mode))

    self.invert = invert
    self.max_tokens = max_tokens
    self.num_oov_indices = num_oov_indices
    self.output_mode = output_mode
    self.sparse = sparse
    self.pad_to_max_tokens = pad_to_max_tokens
    self._called = False

    # A note on vocab_size: we need to always keep a non-Tensor representation
    # of vocab_size around to use in graph building. Because we might be
    # in a tf.function, we can't rely on evaluating the actual tables to
    # find the value either.
    self._vocab_size = None
    # We need to keep track of our current vocab size outside of our layer weights
    # to support a static output shape when `output_mode != INT`. The bincount
    # ops do not set shape on their outputs, which means we have to set it
    # ourselves. We persist the current vocab size as a hidden part of the
    # config when serializing our model.
    if "vocabulary_size" in kwargs:
      self._vocab_size = kwargs["vocabulary_size"]
      del kwargs["vocabulary_size"]

    restore_from_static_table = kwargs.pop("has_static_table", False)

    # Make sure the mask token and oov token are truly of the dtype we want. We
    # can ignore strings here, because they have only one dtype.
    dtype = kwargs["dtype"]
    if dtype == dtypes.int32:
      mask_token = None if mask_token is None else np.int32(mask_token)
      oov_token = None if oov_token is None else np.int32(oov_token)
    elif dtype == dtypes.int64:
      mask_token = None if mask_token is None else np.int64(mask_token)
      oov_token = None if oov_token is None else np.int64(oov_token)
    self.mask_token = mask_token
    self.oov_token = oov_token

    if max_tokens is not None:
      available_vocab_size = max_tokens - self._token_start_index()
    else:
      available_vocab_size = None

    super(IndexLookup, self).__init__(
        combiner=_IndexLookupCombiner(
            vocab_size=available_vocab_size,
            mask_value=mask_token,
            oov_value=oov_token,
            compute_idf=(output_mode == TF_IDF)),
        **kwargs)

    # We need to save the key dtype so that we know if we're expecting int64
    # keys. If we are, we will cast int32 inputs to int64 as well.
    if invert:
      self._key_dtype = dtypes.int64
      self._value_dtype = self.dtype
      self._mask_key = 0
      self._mask_value = mask_token
      key_index = lookup_ops.TextFileIndex.LINE_NUMBER
      value_index = lookup_ops.TextFileIndex.WHOLE_LINE
      default_value = self.oov_token
      oov_indices = None
    else:
      self._key_dtype = self.dtype
      self._value_dtype = dtypes.int64
      self._mask_key = mask_token
      key_index = lookup_ops.TextFileIndex.WHOLE_LINE
      value_index = lookup_ops.TextFileIndex.LINE_NUMBER
      # Masks should map to 0 for int output and be dropped otherwise. Max ints
      # will be dropped from the bincount op.
      self._mask_value = 0 if self.output_mode == INT else dtypes.int64.max
      oov_start = self._oov_start_index()
      token_start = self._token_start_index()
      if self.num_oov_indices == 0:
        # If there are no OOV indices, we map OOV tokens to -1 and error out
        # during call if we find a negative index.
        default_value = -1
        oov_indices = None
      elif self.num_oov_indices == 1:
        # If there is only one OOV index, we can set that index as the default
        # value of the index_lookup table.
        default_value = oov_start
        oov_indices = None
      else:
        # If we have multiple OOV values, we need to do a further hashing step;
        # to make this easier, we set the OOV value to -1. (This lets us do a
        # vectorized add and cast to boolean to determine locations where we
        # need to do extra hashing.)
        default_value = -1
        oov_indices = list(range(oov_start, token_start))

    self._static_vocabulary_path = None
    has_vocab_path = (vocabulary is not None and isinstance(vocabulary, str))
    if has_vocab_path or restore_from_static_table:
      self._has_static_table = True
      if vocabulary is None:
        # If we're restoring a layer that was saved with a static table
        # initializer, we create a fake initializer object to let the code
        # progress. The savedmodel restoration code will handle restoring
        # the actual data.
        initializer = _NullInitializer(self._key_dtype, self._value_dtype)
      else:
        if not gfile.Exists(vocabulary):
          raise ValueError("Vocabulary file %s does not exist." % (vocabulary,))
        self._static_vocabulary_path = vocabulary
        num_tokens = table_utils.num_tokens_in_file(vocabulary)
        self._vocab_size = self._token_start_index() + num_tokens

        initializer = lookup_ops.TextFileInitializer(
            filename=vocabulary,
            key_dtype=self._key_dtype,
            key_index=key_index,
            value_dtype=self._value_dtype,
            value_index=value_index,
            value_index_offset=self._token_start_index())

      self._table = lookup_ops.StaticHashTable(
          initializer, default_value=default_value)
      self._table_handler = table_utils.TableHandler(
          table=self._table,
          mask_token=self._mask_key if self.mask_token is not None else None,
          mask_value=self._mask_value,
          oov_tokens=oov_indices)

      tracked_table = self._add_trackable(self._table, trainable=False)

    else:
      self._has_static_table = False
      self._table = lookup_ops.MutableHashTable(
          key_dtype=self._key_dtype,
          value_dtype=self._value_dtype,
          default_value=default_value,
          name=(self._name + "_index_table"))
      self._table_handler = table_utils.TableHandler(
          table=self._table,
          oov_tokens=oov_indices)
      if vocabulary is not None:
        self.set_vocabulary(vocabulary)
      tracked_table = self._add_trackable(self._table, trainable=False)

    if self.output_mode == TF_IDF:
      # The TF-IDF weight may have a (None,) tensorshape. This creates
      # a 1D variable with arbitrary shape, which we can assign any weight to
      # so long as it has 1 dimension. In order to properly initialize this
      # weight in Keras, we need to provide a custom callable initializer which
      # does not depend on the shape of the weight (as all other initializers
      # do) since the weight's shape is not known. Hence the lambda shape, dtype: [0].
      if not self.pad_to_max_tokens or max_tokens is None:
        initializer = lambda shape, dtype: [0]
      else:
        initializer = init_ops.zeros_initializer

      # We are adding these here instead of in build() since they do not depend
      # on the input shape at all.
      idf_shape = (max_tokens,) if self.pad_to_max_tokens else (None,)
      self.tf_idf_weights = self._add_state_variable(
          name="idf",
          shape=tensor_shape.TensorShape(idf_shape),
          dtype=backend.floatx(),
          initializer=initializer)

    # This is a workaround for summary() on this layer. Because the table is
    # not mutable during training, the effective number of parameters (and so
    # the weight shape) is 0; we add this as an attr so that the parameter
    # counting code in the Model object doesn't throw an attribute error.
    tracked_table.shape = tensor_shape.TensorShape((0,))
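
IndexLookup itself is a Keras-internal base layer; in recent TensorFlow releases its behavior is exposed through public layers such as StringLookup (the exact version in which this became non-experimental is an assumption here). A minimal sketch of the forward and inverted lookups it implements, with one OOV bucket and no mask token:

import tensorflow as tf

vocab = ["earth", "wind", "fire"]

# Forward lookup: tokens -> indices; index 0 is the single OOV bucket.
lookup = tf.keras.layers.StringLookup(
    vocabulary=vocab, num_oov_indices=1, mask_token=None)
print(lookup(tf.constant([["wind", "fire", "volcano"]])))  # [[2, 3, 0]]

# Inverted lookup: indices -> tokens; OOV indices map to the oov_token.
inverse = tf.keras.layers.StringLookup(
    vocabulary=vocab, invert=True, mask_token=None)
print(inverse(tf.constant([[2, 3, 0]])))  # [[b'wind' b'fire' b'[UNK]']]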