Example #1
  def testStaticRegexFullMatchDelegation(self):
    with self.cached_session():
      input_tensor = constant_op.constant("foo", dtypes.string)
      pattern = "[a-z]*"
      op = string_ops.regex_full_match(input_tensor, pattern)
      self.assertTrue(op.name.startswith("StaticRegexFullMatch"), op.name)

      pattern_tensor = constant_op.constant("[a-z]*", dtypes.string)
      op_vec = string_ops.regex_full_match(input_tensor, pattern_tensor)
      self.assertTrue(op_vec.name.startswith("RegexFullMatch"), op_vec.name)
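The delegation above is also visible through the public `tf.strings.regex_full_match` API: a plain Python-string pattern can be bound into the static kernel, while a tensor-valued pattern forces the dynamic op. A minimal eager sketch (assumes TensorFlow 2.x; the match results are identical either way, only the underlying op differs):

import tensorflow as tf

inputs = tf.constant(["foo", "Foo", ""])

# Static pattern: a Python string, checked against the *entire* input.
print(tf.strings.regex_full_match(inputs, "[a-z]*").numpy())
# => [ True False  True]  ("" fully matches "[a-z]*"; "Foo" does not)

# Dynamic pattern: a scalar string tensor, resolved when the op executes.
pattern = tf.constant("[a-z]*")
print(tf.strings.regex_full_match(inputs, pattern).numpy())
# => same values; only the underlying kernel differs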
Example #2
  def testStaticRegexFullMatchDelegation(self):
    with compat.forward_compatibility_horizon(2018, 11, 20):
      with self.cached_session():
        input_tensor = constant_op.constant("foo", dtypes.string)
        pattern = "[a-z]*"
        op = string_ops.regex_full_match(input_tensor, pattern)
        self.assertTrue(op.name.startswith("StaticRegexFullMatch"), op.name)

        pattern_tensor = constant_op.constant("[a-z]*", dtypes.string)
        op_vec = string_ops.regex_full_match(input_tensor, pattern_tensor)
        self.assertTrue(op_vec.name.startswith("RegexFullMatch"), op_vec.name)
Example #3
  def testStaticRegexFullMatchDelegation(self):
    with compat.forward_compatibility_horizon(2018, 11, 20):
      with self.test_session():
        input_tensor = constant_op.constant("foo", dtypes.string)
        pattern = "[a-z]*"
        op = string_ops.regex_full_match(input_tensor, pattern)
        self.assertTrue(op.name.startswith("StaticRegexFullMatch"), op.name)

        pattern_tensor = constant_op.constant("[a-z]*", dtypes.string)
        op_vec = string_ops.regex_full_match(input_tensor, pattern_tensor)
        self.assertTrue(op_vec.name.startswith("RegexFullMatch"), op_vec.name)
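Examples #2 and #3 wrap the assertions in `compat.forward_compatibility_horizon`, a context manager that makes forward-compatibility checks behave as if the given date has passed, so dispatch to the then-new `StaticRegexFullMatch` kernel is exercised. Its public spelling is `tf.compat.forward_compatibility_horizon` (a sketch using the same date as above):

import tensorflow as tf

with tf.compat.forward_compatibility_horizon(2018, 11, 20):
  # Inside this block, forward_compatible(...) checks for dates before the
  # horizon report True, enabling the static-regex code path.
  op = tf.strings.regex_full_match(tf.constant("foo"), "[a-z]*")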
Example #4
def wordshape(input_tensor, pattern, name=None):
  """Determine wordshape features for each input string.

  Args:
    input_tensor: string `Tensor` with any shape.
    pattern: A `tftext.WordShape` or a list of WordShapes.
    name: A name for the operation (optional).

  Returns:
    `<bool>[input_tensor.shape + pattern.shape]`: A tensor where
      `result[i1...iN, j]` is true if `input_tensor[i1...iN]` has the wordshape
      specified by `pattern[j]`.

  Raises:
    TypeError: If `pattern` is not a `WordShape` or a list of `WordShape`s.
  """
  if isinstance(pattern, WordShape):
    return string_ops.regex_full_match(input_tensor, pattern.value, name)
  elif (isinstance(pattern, (list, tuple)) and
        all(isinstance(s, WordShape) for s in pattern)):
    with ops.name_scope(name, "Wordshape", input_tensor):
      return array_ops.stack(
          [wordshape(input_tensor, s) for s in pattern], axis=-1)
  else:
    raise TypeError(
        "Expected 'pattern' to be a single WordShape or a list of WordShapes.")
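A hypothetical usage sketch for `wordshape` follows; the `WordShape` member names are illustrative assumptions (the real `tensorflow_text.WordShape` enum defines its own identifiers), but the shape arithmetic matches the docstring: one boolean feature per pattern, stacked along a trailing axis.

# Assumed: a WordShape-style enum whose members carry an RE2 pattern in
# `.value`, matching the contract of the function above.
strings = constant_op.constant(["Hello", "42"], dtypes.string)
shapes = [WordShape.HAS_ONLY_DIGITS, WordShape.HAS_PUNCT]  # hypothetical members
features = wordshape(strings, shapes)
# `features` has shape [2, 2]; features[i, j] is True when strings[i]
# fully matches the regex stored in shapes[j].value.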
Example #5
  def testInvalidPattern(self):
    values = ["abc", "1"]
    with self.test_session():
      input_vector = constant_op.constant(values, dtypes.string)
      invalid_pattern = "A["
      matched = string_ops.regex_full_match(input_vector, invalid_pattern)
      with self.assertRaisesOpError("Invalid pattern"):
        matched.eval()
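Under eager execution the same invalid pattern fails at call time, with no `.eval()` step; the RE2 compilation error surfaces as `tf.errors.InvalidArgumentError` (a TF 2.x sketch):

import tensorflow as tf

try:
  tf.strings.regex_full_match(tf.constant(["abc", "1"]), "A[")
except tf.errors.InvalidArgumentError as e:
  print(e.message)  # message mentions the invalid pattern "A["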
Example #6
  def save(self, file_prefix, options=None):
    """Save the saveable objects to a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object.

    Returns:
      An `Operation`, or None when executing eagerly.
    """
    options = options or checkpoint_options.CheckpointOptions()

    # IMPLEMENTATION DETAILS: most clients should skip.
    #
    # Suffix for any well-formed "checkpoint_prefix", when sharded.
    # Transformations:
    # * Users pass in "save_path" in save() and restore().  Say "myckpt".
    # * checkpoint_prefix gets fed <save_path><sharded_suffix>.
    #
    # Example:
    #   During runtime, a temporary directory is first created, which contains
    #   files
    #
    #     <train dir>/myckpt_temp/
    #        part-?????-of-?????{.index, .data-00000-of-00001}
    #
    #   Before .save() finishes, they will be (hopefully, atomically) renamed to
    #
    #     <train dir>/
    #        myckpt{.index, .data-?????-of-?????}
    #
    #   Filesystems with eventual consistency (such as S3) don't need a
    #   temporary location. Using a temporary directory in those cases might
    #   cause situations where files are not available during copy.
    #
    # Users only need to interact with the user-specified prefix, which is
    # "<train dir>/myckpt" in this case.  Save() and Restore() work with the
    # prefix directly, instead of any physical pathname.  (On failure and
    # subsequent restore, an outdated and orphaned temporary directory can be
    # safely removed.)
    with ops.device("CPU"):
      sharded_suffix = array_ops.where(
          string_ops.regex_full_match(file_prefix, "^s3://.*"),
          constant_op.constant(".part"),
          constant_op.constant("_temp/part"))
      tmp_checkpoint_prefix = string_ops.string_join(
          [file_prefix, sharded_suffix])
      registered_paths = {
          saver_name: registered_saver_filename(file_prefix, saver_name)
          for saver_name in self._registered_savers
      }

    def save_fn():
      saved_prefixes = []
      # Save with the registered savers. These run before default savers due to
      # the API contract.
      for saver_name, (save_fn, _) in self._registered_savers.items():
        maybe_saved_prefixes = save_fn(registered_paths[saver_name])
        if maybe_saved_prefixes is not None:
          flattened_saved_prefixes = nest.flatten(maybe_saved_prefixes)
          if not all(
              tensor_util.is_tf_type(x) and x.dtype == dtypes.string
              for x in flattened_saved_prefixes):
            raise ValueError(
                "Registered saver must return a (maybe empty) list of "
                f"string type tensors. Got {maybe_saved_prefixes}.")
          saved_prefixes.extend(flattened_saved_prefixes)

      # (Default saver) Save with single device savers.
      num_shards = len(self._single_device_savers)
      sharded_saves = []
      num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
      last_device = None
      for shard, (device, saver) in enumerate(
          sorted(self._single_device_savers.items())):
        last_device = device
        with ops.device(saveable_object_util.set_cpu0(device)):
          shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard,
                                          num_shards_tensor)
        saved_prefixes.append(shard_prefix)
        with ops.device(device):
          # _SingleDeviceSaver will use the CPU device when necessary, but
          # initial read operations should be placed on the SaveableObject's
          # device.
          sharded_saves.append(saver.save(shard_prefix, options))

      with ops.control_dependencies(sharded_saves):
        # Merge on the io_device if specified, otherwise co-locates the merge op
        # with the last device used.
        merge_device = (
            options.experimental_io_device or
            saveable_object_util.set_cpu0(last_device))
        with ops.device(merge_device):
          # V2 format write path consists of a metadata merge step.  Once
          # merged, attempts to delete the temporary directory,
          # "<user-fed prefix>_temp".
          return gen_io_ops.merge_v2_checkpoints(
              saved_prefixes, file_prefix, delete_old_dirs=True)

    # Since this will cause a function re-trace on each save, limit this to the
    # cases where it is needed: eager and when there are multiple tasks/single
    # device savers. Note that the retrace is needed to ensure we pick up the
    # latest values of options like experimental_io_device.
    if context.executing_eagerly() and len(self._single_device_savers) > 1:
      # Explicitly place the identity op on the first device.
      @def_function.function(jit_compile=False)
      def tf_function_save():
        save_fn()
      tf_function_save()
    else:
      return save_fn()
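The `regex_full_match` call inside this saver is simply a graph-friendly branch on the prefix scheme: S3-style prefixes skip the `_temp/part` staging directory, since on eventually consistent filesystems staged files might not be visible when copied into place. The selection logic can be exercised on its own with the public API (a standalone sketch, separate from the internal saver):

import tensorflow as tf

def sharded_suffix(file_prefix):
  # Eventually-consistent stores (s3://...) write parts next to the final
  # prefix; everything else stages parts under "<prefix>_temp/".
  return tf.where(
      tf.strings.regex_full_match(file_prefix, "^s3://.*"),
      tf.constant(".part"),
      tf.constant("_temp/part"))

for p in ["s3://bucket/myckpt", "/train/myckpt"]:
  prefix = tf.constant(p)
  print(tf.strings.join([prefix, sharded_suffix(prefix)]).numpy())
# => b's3://bucket/myckpt.part'
# => b'/train/myckpt_temp/part'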
Example #7
    def save(self, file_prefix, options=None):
        """Save the saveable objects to a checkpoint with `file_prefix`.

        Args:
          file_prefix: A string or scalar string Tensor containing the prefix to
            save under.
          options: Optional `CheckpointOptions` object.

        Returns:
          An `Operation`, or None when executing eagerly.
        """
        options = options or checkpoint_options.CheckpointOptions()
        for callback in self._before_save_callbacks:
            callback()

        # IMPLEMENTATION DETAILS: most clients should skip.
        #
        # Suffix for any well-formed "checkpoint_prefix", when sharded.
        # Transformations:
        # * Users pass in "save_path" in save() and restore().  Say "myckpt".
        # * checkpoint_prefix gets fed <save_path><sharded_suffix>.
        #
        # Example:
        #   During runtime, a temporary directory is first created, which contains
        #   files
        #
        #     <train dir>/myckpt_temp/
        #        part-?????-of-?????{.index, .data-00000-of-00001}
        #
        #   Before .save() finishes, they will be (hopefully, atomically) renamed to
        #
        #     <train dir>/
        #        myckpt{.index, .data-?????-of-?????}
        #
        #   Filesystems with eventual consistency (such as S3) don't need a
        #   temporary location. Using a temporary directory in those cases might
        #   cause situations where files are not available during copy.
        #
        # Users only need to interact with the user-specified prefix, which is
        # "<train dir>/myckpt" in this case.  Save() and Restore() work with the
        # prefix directly, instead of any physical pathname.  (On failure and
        # subsequent restore, an outdated and orphaned temporary directory can be
        # safely removed.)
        with ops.device("CPU"):
            sharded_suffix = array_ops.where(
                string_ops.regex_full_match(file_prefix, "^s3://.*"),
                constant_op.constant(".part"),
                constant_op.constant("_temp_%s/part" % uuid.uuid4().hex))
            tmp_checkpoint_prefix = string_ops.string_join(
                [file_prefix, sharded_suffix])

        num_shards = len(self._single_device_savers)
        sharded_saves = []
        sharded_prefixes = []
        num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
        last_device = None
        for shard, (device, saver) in enumerate(
                sorted(self._single_device_savers.items())):
            last_device = device
            with ops.device(saveable_object_util.set_cpu0(device)):
                shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard,
                                                num_shards_tensor)
            sharded_prefixes.append(shard_prefix)
            with ops.device(device):
                # _SingleDeviceSaver will use the CPU device when necessary, but initial
                # read operations should be placed on the SaveableObject's device.
                sharded_saves.append(saver.save(shard_prefix, options))

        with ops.control_dependencies(sharded_saves):
            # Merge on the io_device if specified, otherwise co-locates the merge op
            # with the last device used.
            merge_device = (options.experimental_io_device
                            or saveable_object_util.set_cpu0(last_device))
            with ops.device(merge_device):
                # V2 format write path consists of a metadata merge step.  Once merged,
                # attempts to delete the temporary directory, "<user-fed prefix>_temp".
                return gen_io_ops.merge_v2_checkpoints(sharded_prefixes,
                                                       file_prefix,
                                                       delete_old_dirs=True)
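Example #7 differs from #6 mainly in uniquifying the temporary directory: a fresh `uuid4` hex in the suffix keeps concurrent or interrupted saves from colliding on the same `<prefix>_temp/` path. Just that suffix construction, in isolation (public-API sketch):

import uuid
import tensorflow as tf

prefix = tf.constant("/train/myckpt")
suffix = tf.constant("_temp_%s/part" % uuid.uuid4().hex)
print(tf.strings.join([prefix, suffix]).numpy())
# => e.g. b'/train/myckpt_temp_9f1c.../part'  (hex digest varies per save)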
Example #8
  def testEmptyMatch(self):
    values = ["abc", "1"]
    with self.test_session():
      input_vector = constant_op.constant(values, dtypes.string)
      matched = string_ops.regex_full_match(input_vector, "").eval()
      self.assertAllEqual([False, False], matched)
Example #9
  def testRegexFullMatch(self):
    values = ["abaaba", "abcdabcde"]
    with self.test_session():
      input_vector = constant_op.constant(values, dtypes.string)
      matched = string_ops.regex_full_match(input_vector, "a.*a").eval()
      self.assertAllEqual([True, False], matched)
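The last two tests hinge on full-match semantics: the pattern must consume the whole string, mirroring Python's `re.fullmatch` rather than `re.search` (RE2 and Python's `re` agree on patterns this simple). A quick eager comparison (assumes TF 2.x):

import re
import tensorflow as tf

values = ["abaaba", "abcdabcde"]
print(tf.strings.regex_full_match(values, "a.*a").numpy())  # [ True False]
print([bool(re.fullmatch("a.*a", v)) for v in values])      # [True, False]
print([bool(re.search("a.*a", v)) for v in values])         # [True, True]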