Example #1
def testNestFlattenWithJoinedStringPaths(self):
    structure = [[TestCompositeTensor(1, 2, 3)], 100, {
        'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)
    }]
    result1 = nest.flatten_with_joined_string_paths(
        structure, expand_composites=True)
    expected1 = [('0/0/0', 1), ('0/0/1', 2), ('0/0/2', 3), ('1', 100),
                 ('2/y/0/0', 4), ('2/y/0/1', 5), ('2/y/1', 6)]
    self.assertEqual(result1, expected1)

    result2 = nest.flatten_with_joined_string_paths(
        structure, expand_composites=False)
    expected2 = [('0/0', TestCompositeTensor(1, 2, 3)), ('1', 100),
                 ('2/y', TestCompositeTensor(TestCompositeTensor(4, 5), 6))]
    self.assertEqual(result2, expected2)
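
For readers unfamiliar with the helper, here is a minimal pure-Python sketch of the path-joining behavior the test above exercises. It handles only lists, tuples, and dicts (no composite tensors), and the function name is hypothetical; the real `nest.flatten_with_joined_string_paths` lives in TensorFlow's internal `tensorflow.python.util.nest` module.

def flatten_with_joined_string_paths_sketch(structure, separator='/'):
    """Return (joined_path, leaf) pairs for a nest of lists, tuples, and dicts."""
    def _walk(node, path):
        if isinstance(node, dict):
            # Dict keys are visited in sorted order, mirroring tf.nest.
            for key in sorted(node):
                yield from _walk(node[key], path + (str(key),))
        elif isinstance(node, (list, tuple)):
            for i, item in enumerate(node):
                yield from _walk(item, path + (str(i),))
        else:
            yield separator.join(path), node
    return list(_walk(structure, ()))

print(flatten_with_joined_string_paths_sketch([[1, 2], {'y': 3}]))
# [('0/0', 1), ('0/1', 2), ('1/y', 3)]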
Example #2
    def encode_fn(x, flat_state):
      state = tf.nest.pack_sequence_as(state_py_structure['state'], flat_state)
      encode_params, decode_params = encoder.get_params(state)
      encoded_x, state_update_tensors, input_shapes = encoder.encode(
          x, encode_params)
      updated_flat_state = tuple(
          tf.nest.flatten(encoder.update_state(state, state_update_tensors)))

      # The following code converts the nested structures needed by the
      # underlying encoder into a single flat dictionary, which is simpler
      # for users of SimpleEncoder to manipulate.
      full_encoded_structure = {
          _TENSORS: encoded_x,
          _PARAMS: decode_params,
          _SHAPES: input_shapes
      }
      flat_encoded_structure = dict(
          core_nest.flatten_with_joined_string_paths(
              full_encoded_structure, separator='/'))
      flat_encoded_py_structure, flat_encoded_tf_structure = (
          py_utils.split_dict_py_tf(flat_encoded_structure))

      if not encoded_py_structure:
        encoded_py_structure['full'] = tf.nest.map_structure(
            lambda _: None, full_encoded_structure)
        encoded_py_structure['flat_py'] = flat_encoded_py_structure
      return flat_encoded_tf_structure, updated_flat_state
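
The pattern above — flatten a nested result into a dict keyed by joined string paths, then separate plain Python values from TensorFlow tensors — recurs throughout these examples. Below is a minimal sketch of the split step; `split_dict_py_tf_sketch` is a hypothetical stand-in for the `py_utils.split_dict_py_tf` helper used above, and `nest` is TensorFlow's internal `tensorflow.python.util.nest` module.

import tensorflow as tf
from tensorflow.python.util import nest  # internal module, as used above

def split_dict_py_tf_sketch(flat_dict):
    # Hypothetical stand-in for py_utils.split_dict_py_tf: partition a flat
    # dict into non-tensor (Python) values and tensor values.
    py_part = {k: v for k, v in flat_dict.items() if not tf.is_tensor(v)}
    tf_part = {k: v for k, v in flat_dict.items() if tf.is_tensor(v)}
    return py_part, tf_part

nested = {'tensors': {'x': tf.constant(1.0)}, 'params': {'seed': 42}}
flat = dict(nest.flatten_with_joined_string_paths(nested, separator='/'))
py_part, tf_part = split_dict_py_tf_sketch(flat)
# py_part == {'params/seed': 42}; tf_part holds the tensor under 'tensors/x'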
Example #4
    def extract_loss_metric_dict_from_history(
            self,
            history: tf.keras.callbacks.History,
            structure: dict,
            prefix='val_') -> dict:
        history: Dict[str, List[float]] = history.history

        # metrics from the validation set start with the 'val_' prefix
        if prefix:
            if prefix != 'val_':
                raise ValueError('prefix should either be "val_" or None')
            history = {
                k.replace(prefix, ''): v
                for k, v in history.items() if k.startswith(prefix)
            }

        # get structure path
        structure_path = [
            p for p, _ in flatten_with_joined_string_paths(structure)
        ]
        # make flat history and pack
        flat_history = [history[p] for p in structure_path]
        history = tf.nest.pack_sequence_as(structure=structure,
                                           flat_sequence=flat_history)

        return history
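
The trick in this example is that the joined string paths of the target structure double as lookup keys into the flat `history.history` dict. Here is a small self-contained sketch of that reordering step; the structure and metric keys are made up for illustration, and `nest` is again TensorFlow's internal `tensorflow.python.util.nest`.

import tensorflow as tf
from tensorflow.python.util import nest  # internal module, as used above

structure = {'ctr': {'auc': None}, 'cvr': {'auc': None}}
history = {'ctr/auc': [0.71, 0.74], 'cvr/auc': [0.65, 0.69]}

# Flatten the reference structure to get its paths in flatten order, then
# pull values from the flat history dict in exactly that order.
paths = [p for p, _ in nest.flatten_with_joined_string_paths(structure)]
packed = tf.nest.pack_sequence_as(structure, [history[p] for p in paths])
# packed == {'ctr': {'auc': [0.71, 0.74]}, 'cvr': {'auc': [0.65, 0.69]}}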
Example #5
        def decode_before_sum_fn(encoded_structure, params):
            """See the `decode_before_sum` method of this class."""
            py_utils.assert_compatible(encoded_structure_spec,
                                       encoded_structure)
            py_utils.assert_compatible(decode_before_sum_params_spec, params)

            encoded_structure = py_utils.merge_dicts(
                tf.nest.pack_sequence_as(
                    internal_structure['encoded_structure'],
                    tf.nest.flatten(encoded_structure)),
                internal_py_values['encoded_structure'])
            params = py_utils.merge_dicts(
                tf.nest.pack_sequence_as(
                    internal_structure['decode_before_sum_params'], params),
                internal_py_values['decode_before_sum_params'])

            encoded_tensors = encoded_structure[_TENSORS]
            input_shapes = encoded_structure[_SHAPES]
            part_decoded_structure = encoder.decode_before_sum(
                encoded_tensors, params, input_shapes)

            _add_to_structure('part_decoded_structure', part_decoded_structure)
            if isinstance(part_decoded_structure, dict):
                return dict(
                    core_nest.flatten_with_joined_string_paths(
                        part_decoded_structure, separator='/'))
            else:
                return part_decoded_structure
Example #6
        def encode_fn(x, params):
            """See the `encode` method of this class."""
            if not tensorspec.is_compatible_with(x):
                raise ValueError(
                    'The provided x is not compatible with the expected tensorspec.'
                )
            py_utils.assert_compatible(encode_params_spec, params)

            params = py_utils.merge_dicts(
                tf.nest.pack_sequence_as(internal_structure['encode_params'],
                                         params),
                internal_py_values['encode_params'])
            encoded_x, state_update_tensors, input_shapes = encoder.encode(
                x, params)
            input_shapes_before_sum, _ = (
                core_encoder.split_shapes_by_commuting_structure(
                    input_shapes, commuting_structure))

            encoded_structure = {
                _TENSORS: encoded_x,
                _SHAPES: input_shapes_before_sum
            }
            encoded_structure_py, encoded_structure_tf = py_utils.split_dict_py_tf(
                encoded_structure)

            _add_to_structure('encoded_structure', encoded_structure_tf)
            _add_to_structure('state_update_tensors', state_update_tensors)
            _add_to_py_values('encoded_structure', encoded_structure_py)

            return (dict(
                core_nest.flatten_with_joined_string_paths(
                    encoded_structure_tf, separator='/')),
                    tuple(tf.nest.flatten(state_update_tensors)))
Example #7
 def testFlattenWithStringPaths(self):
   for inputs_expected in (
       {"inputs": [], "expected": []},
       {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]},
       {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}):
     inputs = inputs_expected["inputs"]
     expected = inputs_expected["expected"]
     self.assertEqual(
         nest.flatten_with_joined_string_paths(inputs, separator="/"),
         expected)
Example #8
    def test_identity(self):
        encoder = common_encoders.identity()
        self.assertIsInstance(encoder, core_encoder.Encoder)

        params, _ = encoder.get_params(encoder.initial_state())
        encoded_x, _, _ = encoder.encode(tf.constant(1.0), params)
        keys = [
            k for k, _ in core_nest.flatten_with_joined_string_paths(encoded_x)
        ]
        self.assertSameElements(['identity_values'], keys)
Example #10
    def test_hadamard_quantization(self):
        encoder = common_encoders.hadamard_quantization(8)
        self.assertIsInstance(encoder, core_encoder.Encoder)

        params, _ = encoder.get_params(encoder.initial_state())
        encoded_x, _, _ = encoder.encode(tf.constant(1.0), params)
        keys = [
            k for k, _ in core_nest.flatten_with_joined_string_paths(encoded_x)
        ]
        self.assertSameElements([
            'flattened_values/hadamard_values/min_max',
            'flattened_values/hadamard_values/quantized_values/bitpacked_values'
        ], keys)
Example #11
 def _check_spec(self, element_spec):
     if isinstance(element_spec, values.PerReplicaSpec):
         element_spec = element_spec._component_specs  # pylint: disable=protected-access
     specs = nest.flatten_with_joined_string_paths(element_spec)
     for path, spec in specs:
         if isinstance(spec, (sparse_tensor.SparseTensorSpec,
                              ragged_tensor.RaggedTensorSpec)):
             raise ValueError(
                 "Found tensor {} with spec {}. TPUStrategy does not support "
                 "distributed datasets with device prefetch when using sparse or "
                 "ragged tensors. If you indend to use sparse or ragged tensors, "
                 "please pass a tf.distribute.InputOptions object with "
                 "experimental_prefetch_to_device set to False to your dataset "
                 "distribution function.".format(path, type(spec)))
Example #12
  def testNestFlatten(self, structure, expected, paths, expand_composites=True):
    result = nest.flatten(structure, expand_composites=expand_composites)
    self.assertEqual(result, expected)

    result_with_paths = nest.flatten_with_tuple_paths(
        structure, expand_composites=expand_composites)
    self.assertEqual(result_with_paths, list(zip(paths, expected)))

    string_paths = ['/'.join(str(p) for p in path) for path in paths]  # pylint: disable=g-complex-comprehension
    result_with_string_paths = nest.flatten_with_joined_string_paths(
        structure, expand_composites=expand_composites)
    self.assertEqual(result_with_string_paths,
                     list(zip(string_paths, expected)))

    flat_paths_result = list(
        nest.yield_flat_paths(structure, expand_composites=expand_composites))
    self.assertEqual(flat_paths_result, paths)
Example #14
 def testFlattenNamedTuple(self):
   # pylint: disable=invalid-name
   Foo = collections.namedtuple("Foo", ["a", "b"])
   Bar = collections.namedtuple("Bar", ["c", "d"])
   # pylint: enable=invalid-name
   test_cases = [
       (Foo(a=3, b=Bar(c=23, d=42)),
        [("a", 3), ("b/c", 23), ("b/d", 42)]),
       (Foo(a=Bar(c=23, d=42), b=Bar(c=0, d="something")),
        [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]),
       (Bar(c=42, d=43),
        [("c", 42), ("d", 43)]),
       (Bar(c=[42], d=43),
        [("c/0", 42), ("d", 43)]),
   ]
   for inputs, expected in test_cases:
     self.assertEqual(
         list(nest.flatten_with_joined_string_paths(inputs)), expected)
Example #16
def convert_structure_to_signature(structure):
    """Convert a potentially nested structure to a signature.

  Args:
    structure: Structure to convert.

  Returns:
    Identical structure that has TensorSpec objects instead of Tensors and
    UnknownArgument instead of any unsupported types.
  """
    def encode_arg(arg, name=None):
        """A representation for this argument, for converting into signatures."""
        if isinstance(arg, ops.Tensor):
            return tensor_spec.TensorSpec(arg.shape, arg.dtype, name)
        if isinstance(arg, (int, float, bool, tensor_spec.TensorSpec)):
            return arg
        return UnknownArgument()

    # We are using the flattened paths to name the TensorSpecs. We need an
    # explicit name for them downstream.
    flattened_with_paths = nest.flatten_with_joined_string_paths(structure)
    mapped = [encode_arg(arg, path) for path, arg in flattened_with_paths]
    return nest.pack_sequence_as(structure, mapped)
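
To see why the flattened paths matter here: they give every `TensorSpec` a deterministic, structure-derived name. A runnable sketch of the same idea using public `tf` symbols plus the internal `nest` helper (the structure is made up for illustration):

import tensorflow as tf
from tensorflow.python.util import nest  # internal module, as used above

structure = {'image': tf.zeros([2, 3]),
             'meta': {'count': tf.zeros([], tf.int32)}}
# Name each spec after its joined path, then pack back into the structure.
flat = nest.flatten_with_joined_string_paths(structure)
specs = [tf.TensorSpec(t.shape, t.dtype, name=path) for path, t in flat]
signature = tf.nest.pack_sequence_as(structure, specs)
# signature['meta']['count'].name == 'meta/count'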
Example #17
 def add_flatten_losses_metrics(self, return_dict: dict):
     current_eval_loss_dict = create_dict_from_nested_model(self, ele_name='losses')
     flatten_losses = dict(
         flatten_with_joined_string_paths(current_eval_loss_dict))
     return_dict.update(flatten_losses)
     return return_dict
Example #18
    def assertAllAssertsNested(self, assert_fn, *structure, **kwargs):
        """Run `assert_fn` on `structure` and report which elements errored.

    This function will run `assert_fn` on each element of `structure` as
    `assert_fn(structure[0], structure[1], ...)`, collecting any exceptions
    raised in the process. Afterward, it will report which elements of
    `structure` triggered an assertion, as well as the assertions themselves.

    Args:
      assert_fn: A callable that accepts as many arguments as there are
        structures.
      *structure: A list of nested structures.
      **kwargs: Valid keyword args are:

        * `shallow`: If not None, uses this as the shared tree prefix of
          `structure` for the purpose of being able to use structures which
          only share that tree prefix (e.g. `[1, 2]` and `[[1], 2]` share the
          `[., .]` tree prefix).
        * `msg`: Used as the message when a failure happened. Default:
          `"AllAssertsNested failed"`.
        * `check_types`: If `True`, types of sequences are checked as well,
          including the keys of dictionaries. If `False`, for example a list and
          a tuple of objects may be equivalent. Default: `False`.

    Raises:
      AssertionError: If the structures are mismatched, or if `assert_fn`
        raised an exception at least once.
    """
        shallow = kwargs.pop('shallow', None)
        if shallow is None:
            shallow = structure[0]
        msg = kwargs.pop('msg', 'AllAssertsNested failed')

        def _one_part(*structure):
            try:
                assert_fn(*structure)
            except Exception as part_e:  # pylint: disable=broad-except
                return part_e

        try:
            maybe_exceptions = nest.map_structure_up_to(
                shallow, _one_part, *structure, **kwargs)
            overall_exception = None
            exceptions_with_paths = [
                (p, e) for p, e in nest.flatten_with_joined_string_paths(
                    maybe_exceptions) if e is not None
            ]
        except Exception as e:  # pylint: disable=broad-except
            overall_exception = e
            exceptions_with_paths = []

        final_msg = '{}:\n\n'.format(msg)
        if overall_exception:
            final_msg += str(overall_exception)
            raise AssertionError(final_msg)
        if exceptions_with_paths:
            for i, one_structure in enumerate(structure):
                final_msg += 'Structure {}:\n{}\n\n'.format(i, one_structure)
            final_msg += 'Exceptions:\n\n'
            for p, exception in exceptions_with_paths:
                final_msg += 'Path: {}\nException: {}\n{}\n\n'.format(
                    p,
                    type(exception).__name__, exception)
            # Drop the final two newlines.
            raise AssertionError(final_msg[:-2])
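
A hypothetical usage sketch, assuming the test case inherits from TensorFlow Probability's internal `test_util.TestCase` (where a method with this docstring appears); the class, test name, and data are made up.

import functools
from tensorflow_probability.python.internal import test_util  # assumed source

class NestedAssertionsTest(test_util.TestCase):

  def test_nested_all_close(self):
    expected = {'loc': 0.0, 'scale': {'diag': [1.0, 1.0]}}
    actual = {'loc': 0.01, 'scale': {'diag': [1.0, 0.99]}}
    # Runs the assertion element-wise; a failure message lists the joined
    # string path (e.g. 'scale/diag') of every element whose check raised.
    self.assertAllAssertsNested(
        functools.partial(self.assertAllClose, atol=0.1), expected, actual)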
Example #19
 def testFlattenWithStringPaths(self, inputs, expected):
     self.assertEqual(
         nest.flatten_with_joined_string_paths(inputs, separator="/"),
         expected)
Example #21
    def embedding_lookup(self,
                         features: Any,
                         weights: Optional[Any] = None) -> Any:
        """Apply embedding lookup on TPUs using Tensorcore.

    Note that all the sparse and ragged tensors will be converted to dense
    tensors on CPU and then passed to the TPU to do the embedding lookup.
    Large embedding lookups are not supported by this API; use the
    TPUEmbedding mid-level API instead.

    Args:
      features: a nested structure of Tensors, SparseTensors or RaggedTensors.
      weights: a nested structure of Tensors, SparseTensors or RaggedTensors or
        None for no weights. If not None, structure must match that of inputs,
        but entries are allowed to be None.

    Returns:
      A nested structure of Tensors with the same structure as inputs.
    """
        if not self._built:
            self.build()
        nest.assert_same_structure(features, self._feature_config)

        flat_inputs = nest.flatten(features)
        flat_weights = [None] * len(flat_inputs)
        if weights is not None:
            nest.assert_same_structure(features, weights)
            flat_weights = nest.flatten(weights)
        flat_features = nest.flatten_with_joined_string_paths(
            self._feature_config)

        outputs = []
        for inp, weight, (path, feature) in zip(flat_inputs, flat_weights,
                                                flat_features):
            table = self.embedding_tables[feature.table]

            if weight is not None:
                if isinstance(inp, ops.Tensor):
                    raise ValueError(
                        "Weight specified for {}, but input is dense.".format(
                            path))
                elif type(weight) is not type(inp):
                    raise ValueError(
                        "Weight for {} is of type {} but it does not match type of the "
                        "input which is {}.".format(path, type(weight),
                                                    type(inp)))
                elif feature.max_sequence_length > 0:
                    raise ValueError(
                        "Weight specified for {}, but this is a sequence "
                        "feature.".format(path))

            if isinstance(inp, ops.Tensor):
                if feature.max_sequence_length > 0:
                    raise ValueError(
                        "Feature {} is a sequence feature but a dense tensor "
                        "was passed.".format(path))
                outputs.append(embedding_ops.embedding_lookup_v2(table, inp))

            elif isinstance(inp, sparse_tensor.SparseTensor):
                outputs.append(
                    self._embedding_lookup_for_sparse_tensor(
                        inp, weight, table, feature))
            elif isinstance(inp, ragged_tensor.RaggedTensor):
                outputs.append(
                    self._embedding_lookup_for_ragged_tensor(
                        inp, weight, table, feature))
            else:
                raise ValueError(
                    "Input {} is type {}. Tensor, SparseTensor or "
                    "RaggedTensor expected.".format(path, type(inp)))
        return nest.pack_sequence_as(self._feature_config, outputs)
Example #22
 def get_problem_loss(self, current_loss_dict: dict,
                      problem: str) -> List[tf.Tensor]:
     flatten_loss_with_path = flatten_with_joined_string_paths(
         current_loss_dict)
     return [v for p, v in flatten_loss_with_path if problem in p]
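
To make the path filter above concrete, here is a small self-contained run with a made-up two-problem loss dict; any loss whose joined path contains the problem name is selected. `nest` is again TensorFlow's internal `tensorflow.python.util.nest`.

import tensorflow as tf
from tensorflow.python.util import nest  # internal module, as used above

current_loss_dict = {
    'ner': {'crf_loss': tf.constant(0.3)},
    'cls': {'ce_loss': tf.constant(0.7)},
}
flat = nest.flatten_with_joined_string_paths(current_loss_dict)
ner_losses = [v for p, v in flat if 'ner' in p]
# ner_losses == [tf.constant(0.3)], since the path 'ner/crf_loss' matches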
Example #23
def cpu_embedding_lookup(
    inputs: Any,
    weights: Optional[Any],
    tables: Dict[tpu_embedding_v2_utils.TableConfig, tf_variables.Variable],
    feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable]  # pylint:disable=g-bare-generic
) -> Any:
    """Apply standard lookup ops with `tf.tpu.experimental.embedding` configs.

  This function is a utility which allows using the
  `tf.tpu.experimental.embedding` config objects with standard lookup functions.
  This can be used when exporting a model which uses
  `tf.tpu.experimental.embedding.TPUEmbedding` for serving on CPU. In particular
  `tf.tpu.experimental.embedding.TPUEmbedding` only supports lookups on TPUs and
  should not be part of your serving graph.

  Note that TPU specific options (such as `max_sequence_length`) in the
  configuration objects will be ignored.

  In the following example we take a trained model (see the documentation for
  `tf.tpu.experimental.embedding.TPUEmbedding` for the context) and create a
  saved model with a serving function that will perform the embedding lookup and
  pass the results to your model:

  ```python
  model = model_fn(...)
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      feature_config=feature_config,
      batch_size=1024,
      optimizer=tf.tpu.experimental.embedding.SGD(0.1))
  checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
  checkpoint.restore(...)

  @tf.function(input_signature=[{'feature_one': tf.TensorSpec(...),
                                 'feature_two': tf.TensorSpec(...),
                                 'feature_three': tf.TensorSpec(...)}])
  def serve_tensors(embedding_features):
    embedded_features = tf.tpu.experimental.embedding.serving_embedding_lookup(
        embedding_features, None, embedding.embedding_tables,
        feature_config)
    return model(embedded_features)

  model.embedding_api = embedding
  tf.saved_model.save(model,
                      export_dir=...,
                      signatures={'serving_default': serve_tensors})

  ```

  NOTE: It's important to assign the embedding API object to a member of your
  model as `tf.saved_model.save` only supports saving variables as one
  `Trackable` object. Since the model's weights are in `model` and the
  embedding tables are managed by `embedding`, we assign `embedding` to an
  attribute of `model` so that `tf.saved_model.save` can find the embedding
  variables.

  NOTE: The same `serve_tensors` function and `tf.saved_model.save` call will
  work directly from training.

  Args:
    inputs: a nested structure of Tensors, SparseTensors or RaggedTensors.
    weights: a nested structure of Tensors, SparseTensors or RaggedTensors or
      None for no weights. If not None, structure must match that of inputs, but
      entries are allowed to be None.
    tables: a dict of mapping TableConfig objects to Variables.
    feature_config: a nested structure of FeatureConfig objects with the same
      structure as inputs.

  Returns:
    A nested structure of Tensors with the same structure as inputs.
  """

    nest.assert_same_structure(inputs, feature_config)

    flat_inputs = nest.flatten(inputs)
    flat_weights = [None] * len(flat_inputs)
    if weights is not None:
        nest.assert_same_structure(inputs, weights)
        flat_weights = nest.flatten(weights)
    flat_features = nest.flatten_with_joined_string_paths(feature_config)

    outputs = []
    for inp, weight, (path, feature) in zip(flat_inputs, flat_weights,
                                            flat_features):
        table = tables[feature.table]

        if weight is not None:
            if isinstance(inp, ops.Tensor):
                raise ValueError(
                    "Weight specified for {}, but input is dense.".format(
                        path))
            elif type(weight) is not type(inp):
                raise ValueError(
                    "Weight for {} is of type {} but it does not match type of the "
                    "input which is {}.".format(path, type(weight), type(inp)))
            elif feature.max_sequence_length > 0:
                raise ValueError(
                    "Weight specified for {}, but this is a sequence "
                    "feature.".format(path))

        if isinstance(inp, ops.Tensor):
            if feature.max_sequence_length > 0:
                raise ValueError(
                    "Feature {} is a sequence feature but a dense tensor "
                    "was passed.".format(path))
            outputs.append(embedding_ops.embedding_lookup_v2(table, inp))

        elif isinstance(inp, sparse_tensor.SparseTensor):
            outputs.append(
                _embedding_lookup_for_sparse_tensor(inp, weight, table,
                                                    feature))
        elif isinstance(inp, ragged_tensor.RaggedTensor):
            outputs.append(
                _embedding_lookup_for_ragged_tensor(inp, weight, table,
                                                    feature))
        else:
            raise ValueError("Input {} is type {}. Tensor, SparseTensor or "
                             "RaggedTensor expected.".format(path, type(inp)))
    return nest.pack_sequence_as(feature_config, outputs)