def testNestFlattenWithTuplePaths(self):
    """Compares tuple-path flattening with and without composite expansion."""
    structure = [
        [TestCompositeTensor(1, 2, 3)],
        100,
        {'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)},
    ]

    # Expanded: each composite is broken into its components, and the
    # component index is appended to the path.
    expanded = nest.flatten_with_tuple_paths(structure, expand_composites=True)
    self.assertEqual(expanded, [
        ((0, 0, 0), 1),
        ((0, 0, 1), 2),
        ((0, 0, 2), 3),
        ((1,), 100),
        ((2, 'y', 0, 0), 4),
        ((2, 'y', 0, 1), 5),
        ((2, 'y', 1), 6),
    ])

    # Not expanded: each composite is reported as a single opaque leaf.
    opaque = nest.flatten_with_tuple_paths(structure, expand_composites=False)
    self.assertEqual(opaque, [
        ((0, 0), TestCompositeTensor(1, 2, 3)),
        ((1,), 100),
        ((2, 'y'), TestCompositeTensor(TestCompositeTensor(4, 5), 6)),
    ])
Exemplo n.º 2
0
def convert_structure_to_signature(structure, arg_names=None):
  """Convert a potentially nested structure to a signature.

  Args:
    structure: Structure to convert, where top level collection is a list or a
      tuple.
    arg_names: Optional list of arguments that has equal number of elements as
      `structure` and is used for naming corresponding TensorSpecs.

  Returns:
    Identical structure that has TensorSpec objects instead of Tensors and
    UnknownArgument instead of any unsupported types.

  Raises:
    ValueError: If `arg_names` is non-empty and its length differs from the
      length of `structure`.
  """
  def encode_arg(arg, path):
    """A representation for this argument, for converting into signatures."""
    if isinstance(arg, ops.Tensor):
      user_specified_name = None
      try:
        user_specified_name = compat.as_str(
            arg.op.get_attr("_user_specified_name"))
      except ValueError:
        # Presumably the op has no `_user_specified_name` attribute; fall back
        # to naming the spec after its path.
        pass

      if path and user_specified_name and user_specified_name != path[0]:
        # The user has explicitly named the argument differently than the name
        # of the function argument.
        name = user_specified_name
      else:
        name = "/".join(str(p) for p in path)
      return tensor_spec.TensorSpec(arg.shape, arg.dtype, name)
    if isinstance(arg, composite_tensor.CompositeTensor):
      # TODO(b/133606651) Do we need to inject arg_name?
      return arg._type_spec  # pylint: disable=protected-access
    if isinstance(arg, (
        int,
        float,
        bool,
        type(None),
        dtypes.DType,
        tensor_spec.TensorSpec,
        type_spec.TypeSpec,
    )):
      return arg
    return UnknownArgument()

  # Validate before doing any flattening work. `arg_names` is wrapped in a
  # 1-tuple so that a tuple `arg_names` is printed as a whole instead of being
  # interpreted as several %-format arguments (which raises TypeError).
  if arg_names and len(arg_names) != len(structure):
    raise ValueError(
        "Passed in arg_names don't match actual signature (%s)." %
        (arg_names,))

  # We are using the flattened paths to name the TensorSpecs. We need an
  # explicit name for them downstream.
  flattened = nest.flatten_with_tuple_paths(structure)
  if arg_names:
    # Replace all top-level names with their actual arg_names. If a path before
    # was "(2,'a',1)", it will become "(arg_names[2],'a',1)".
    flattened = [
        ((arg_names[path[0]],) + path[1:], arg) for path, arg in flattened
    ]

  mapped = [encode_arg(arg, path) for path, arg in flattened]
  return nest.pack_sequence_as(structure, mapped)
    def testNestFlattenWithTuplePaths(self):
        """Checks tuple paths for both composite-expansion modes."""
        structure = [
            [TestCompositeTensor(1, 2, 3)],
            100,
            {'y': TestCompositeTensor(TestCompositeTensor(4, 5), 6)},
        ]
        cases = (
            # Expanded: composites are split into components whose indices
            # extend the path.
            (True, [((0, 0, 0), 1), ((0, 0, 1), 2), ((0, 0, 2), 3),
                    ((1, ), 100), ((2, 'y', 0, 0), 4), ((2, 'y', 0, 1), 5),
                    ((2, 'y', 1), 6)]),
            # Not expanded: every composite is a single leaf.
            (False, [((0, 0), TestCompositeTensor(1, 2, 3)), ((1, ), 100),
                     ((2, 'y'),
                      TestCompositeTensor(TestCompositeTensor(4, 5), 6))]),
        )
        for expand, expected in cases:
            actual = nest.flatten_with_tuple_paths(structure,
                                                   expand_composites=expand)
            self.assertEqual(actual, expected)
Exemplo n.º 4
0
 def testFlatten(self):
     """A `_TupleWrapper` must flatten and repack like the tuple it wraps."""
     inner = data_structures._TupleWrapper((2, ))
     wrapped = data_structures._TupleWrapper((1, inner))
     self.assertEqual([1, 2], nest.flatten(wrapped))
     self.assertEqual(nest.flatten_with_tuple_paths((1, (2, ))),
                      nest.flatten_with_tuple_paths(wrapped))
     self.assertEqual((3, (4, )), nest.pack_sequence_as(wrapped, [3, 4]))

     # Wrapping a namedtuple must not change its flattened paths, and
     # repacking must still give an object with named fields.
     point_type = collections.namedtuple("nt", ["x", "y"])
     point = point_type(1., 2.)
     wrapped_point = data_structures._TupleWrapper(point)
     self.assertEqual(nest.flatten_with_tuple_paths(point),
                      nest.flatten_with_tuple_paths(wrapped_point))
     self.assertEqual((3, 4), nest.pack_sequence_as(wrapped_point, [3, 4]))
     self.assertEqual(3, nest.pack_sequence_as(wrapped_point, [3, 4]).x)
Exemplo n.º 5
0
def convert_structure_to_signature(structure, arg_names=None):
  """Convert a potentially nested structure to a signature.

  Composite tensors in `structure` are first replaced by their components,
  and the result is flattened with composite expansion enabled.

  Args:
    structure: Structure to convert, where top level collection is a list or a
      tuple.
    arg_names: Optional list of arguments that has equal number of elements as
      `structure` and is used for naming corresponding TensorSpecs.

  Returns:
    Identical structure that has TensorSpec objects instead of Tensors and
    UnknownArgument instead of any unsupported types.

  Raises:
    ValueError: If `arg_names` is non-empty and its length differs from the
      length of the (component-replaced) `structure`.
  """
  structure = composite_tensor.replace_composites_with_components(structure)
  def encode_arg(arg, path):
    """A representation for this argument, for converting into signatures."""
    if isinstance(arg, ops.Tensor):
      user_specified_name = None
      try:
        user_specified_name = compat.as_str(
            arg.op.get_attr("_user_specified_name"))
      except ValueError:
        # Presumably the op has no `_user_specified_name` attribute; fall back
        # to naming the spec after its path.
        pass

      if path and user_specified_name and user_specified_name != path[0]:
        # The user has explicitly named the argument differently than the name
        # of the function argument.
        name = user_specified_name
      else:
        name = "/".join(str(p) for p in path)
      return tensor_spec.TensorSpec(arg.shape, arg.dtype, name)
    if isinstance(arg, (
        int,
        float,
        bool,
        type(None),
        dtypes.DType,
        tensor_spec.TensorSpec,
    )):
      return arg
    return UnknownArgument()

  # Validate before flattening. `arg_names` is wrapped in a 1-tuple so that a
  # tuple `arg_names` is printed as a whole instead of being interpreted as
  # several %-format arguments (which raises TypeError).
  if arg_names and len(arg_names) != len(structure):
    raise ValueError(
        "Passed in arg_names don't match actual signature (%s)." %
        (arg_names,))

  # We are using the flattened paths to name the TensorSpecs. We need an
  # explicit name for them downstream.
  flattened = nest.flatten_with_tuple_paths(structure, expand_composites=True)
  if arg_names:
    # Replace all top-level names with their actual arg_names. If a path before
    # was "(2,'a',1)", it will become "(arg_names[2],'a',1)".
    flattened = [
        ((arg_names[path[0]],) + path[1:], arg) for path, arg in flattened
    ]

  mapped = [encode_arg(arg, path) for path, arg in flattened]
  return nest.pack_sequence_as(structure, mapped, expand_composites=True)
Exemplo n.º 6
0
def _maybe_populate_ground_truth(sample_transformations, module_name):
    """Populates the ground truth values.

  This tries to import a module from the `ground_truth` module with the
  `module_name` name.

  Args:
    sample_transformations: A dictionary of Python strings to
      `SampleTransformation`s.
    module_name: Python string. Module name from which to load the ground truth.

  Returns:
    sample_transformations: Same as the input `sample_transformations`, but with
      the ground truth values populated.
  """
    sample_transformations = sample_transformations.copy()
    try:
        # This assumes `model` is a 2 levels deep inside of `inference_gym`. If we
        # move it, we'd change the `-2` to equal the (negative) nesting level.
        root_name_comps = __name__.split('.')[:-2]
        module = importlib.import_module(
            '.'.join(root_name_comps +
                     ['targets', 'ground_truth', module_name]))
    except ImportError:
        return sample_transformations
    for name, sample_transformation in sample_transformations.items():
        flat_mean = []
        flat_sem = []
        flat_std = []
        flat_sestd = []
        for tuple_path, _ in nest.flatten_with_tuple_paths(
                sample_transformation.dtype):
            mean, sem, std, sestd = ground_truth_encoding.load_ground_truth_part(
                module, name, tuple_path)
            flat_mean.append(mean)
            flat_sem.append(sem)
            flat_std.append(std)
            flat_sestd.append(sestd)

        def _pack_or_none(flat_parts):
            if any(part is None for part in flat_parts):
                return None
            else:
                return tf.nest.pack_sequence_as(sample_transformation.dtype,
                                                flat_parts)  # pylint: disable=cell-var-from-loop

        new_transformation = sample_transformation._replace(
            ground_truth_mean=_pack_or_none(flat_mean),
            ground_truth_mean_standard_error=_pack_or_none(flat_sem),
            ground_truth_standard_deviation=_pack_or_none(flat_std),
            ground_truth_standard_deviation_standard_error=_pack_or_none(
                flat_sestd),
        )
        sample_transformations[name] = new_transformation
    return sample_transformations
Exemplo n.º 7
0
def _flatten_module(module,
                    recursive,
                    predicate,
                    attribute_traversal_key,
                    attributes_to_ignore,
                    with_path,
                    module_path=(),
                    seen=None,
                    recursion_stack=None):
  """Implementation of `flatten`.

  Args:
    module: Current module to process.
    recursive: Whether to recurse into child modules (those for which
      `_is_module` is true).
    predicate: Callable applied to each leaf; only leaves for which it returns
      a truthy value are yielded.
    attribute_traversal_key: Key function used to sort attribute names before
      traversal (same contract as `key` for builtin `sorted`).
    attributes_to_ignore: Attribute names to skip.
    with_path: If true, yield `(path, leaf)` pairs and do not de-duplicate
      leaves; otherwise yield each distinct leaf (by `id`) at most once.
    module_path: Tuple path from the root module to `module`.
    seen: Set of visited leaf ids, shared across recursive calls.
    recursion_stack: Sequence of ids of the ancestor modules on the current
      recursion path; used to break reference cycles when `with_path` is true.

  Yields:
    Matching leaves (or `(path, leaf)` pairs when `with_path` is true) of
    `module`, followed by those of its submodules.
  """
  module_id = id(module)
  if seen is None:
    seen = set([module_id])
  ancestors = () if recursion_stack is None else tuple(recursion_stack)

  # With `with_path=False`, `seen` already prevents revisiting objects. With
  # `with_path=True` leaves are not de-duplicated, so a reference cycle (e.g.
  # `a.a = a`) would recurse forever. Break cycles by not recursing out of a
  # module that is already one of its own ancestors.
  if module_id in ancestors:
    recursive = False

  module_dict = vars(module)
  submodules = []

  for key in sorted(module_dict, key=attribute_traversal_key):
    if key in attributes_to_ignore:
      continue

    for leaf_path, leaf in nest.flatten_with_tuple_paths(module_dict[key]):
      leaf_path = (key,) + leaf_path

      if not with_path:
        leaf_id = id(leaf)
        if leaf_id in seen:
          continue
        seen.add(leaf_id)

      if predicate(leaf):
        if with_path:
          yield module_path + leaf_path, leaf
        else:
          yield leaf

      if recursive and _is_module(leaf):
        # Walk direct properties first then recurse.
        submodules.append((module_path + leaf_path, leaf))

  # An immutable tuple is passed down, so no cleanup is needed even if the
  # caller stops consuming this generator early.
  child_stack = ancestors + (module_id,)
  for submodule_path, submodule in submodules:
    subvalues = _flatten_module(
        submodule,
        recursive=recursive,
        predicate=predicate,
        attribute_traversal_key=attribute_traversal_key,
        attributes_to_ignore=submodule._TF_MODULE_IGNORED_PROPERTIES,  # pylint: disable=protected-access
        with_path=with_path,
        module_path=submodule_path,
        seen=seen,
        recursion_stack=child_stack)

    for subvalue in subvalues:
      # Predicate is already tested for these values.
      yield subvalue
Exemplo n.º 8
0
def _flatten_module(module,
                    recursive,
                    predicate,
                    attribute_traversal_key,
                    attributes_to_ignore,
                    with_path,
                    module_path=(),
                    seen=None,
                    recursion_stack=None):
    """Implementation of `flatten`.

    Args:
      module: Current module to process.
      recursive: Whether to recurse into child modules (those for which
        `_is_module` is true).
      predicate: Callable applied to each leaf; only leaves for which it
        returns a truthy value are yielded.
      attribute_traversal_key: Key function used to sort attribute names
        before traversal (same contract as `key` for builtin `sorted`).
      attributes_to_ignore: Attribute names to skip.
      with_path: If true, yield `(path, leaf)` pairs and do not de-duplicate
        leaves; otherwise yield each distinct leaf (by `id`) at most once.
      module_path: Tuple path from the root module to `module`.
      seen: Set of visited leaf ids, shared across recursive calls.
      recursion_stack: Sequence of ids of the ancestor modules on the current
        recursion path; used to break reference cycles when `with_path` is
        true.

    Yields:
      Matching leaves (or `(path, leaf)` pairs when `with_path` is true) of
      `module`, followed by those of its submodules.
    """
    module_id = id(module)
    if seen is None:
        seen = set([module_id])
    ancestors = () if recursion_stack is None else tuple(recursion_stack)

    # With `with_path=False`, `seen` already prevents revisiting objects. With
    # `with_path=True` leaves are not de-duplicated, so a reference cycle (e.g.
    # `a.a = a`) would recurse forever. Break cycles by not recursing out of a
    # module that is already one of its own ancestors.
    if module_id in ancestors:
        recursive = False

    module_dict = vars(module)
    submodules = []

    for key in sorted(module_dict, key=attribute_traversal_key):
        if key in attributes_to_ignore:
            continue

        for leaf_path, leaf in nest.flatten_with_tuple_paths(module_dict[key]):
            leaf_path = (key, ) + leaf_path

            if not with_path:
                leaf_id = id(leaf)
                if leaf_id in seen:
                    continue
                seen.add(leaf_id)

            if predicate(leaf):
                if with_path:
                    yield module_path + leaf_path, leaf
                else:
                    yield leaf

            if recursive and _is_module(leaf):
                # Walk direct properties first then recurse.
                submodules.append((module_path + leaf_path, leaf))

    # An immutable tuple is passed down, so no cleanup is needed even if the
    # caller stops consuming this generator early.
    child_stack = ancestors + (module_id, )
    for submodule_path, submodule in submodules:
        subvalues = _flatten_module(
            submodule,
            recursive=recursive,
            predicate=predicate,
            attribute_traversal_key=attribute_traversal_key,
            attributes_to_ignore=submodule._TF_MODULE_IGNORED_PROPERTIES,  # pylint: disable=protected-access
            with_path=with_path,
            module_path=submodule_path,
            seen=seen,
            recursion_stack=child_stack)

        for subvalue in subvalues:
            # Predicate is already tested for these values.
            yield subvalue
Exemplo n.º 9
0
def _flatten_non_variable_composites_with_tuple_path(structure, path_prefix=()):
  """Flattens `structure` with tuple paths, expanding composites except variables.

  Each `CompositeTensor` leaf for which `_is_variable` is false is recursively
  replaced by the components of its type spec. The spec's value type name is
  appended to the path of those components. All other leaves are yielded
  unchanged.

  Args:
    structure: Nested structure to flatten.
    path_prefix: Tuple prepended to every yielded path.

  Yields:
    `(path, leaf)` pairs, where `path` is a tuple.
  """
  for child_path, child in nest.flatten_with_tuple_paths(structure):
    full_path = path_prefix + child_path
    expandable = (isinstance(child, composite_tensor.CompositeTensor) and
                  not _is_variable(child))
    if not expandable:
      yield full_path, child
      continue
    # pylint: disable=protected-access
    spec = child._type_spec
    components = spec._to_components(child)
    # pylint: enable=protected-access
    yield from _flatten_non_variable_composites_with_tuple_path(
        components, full_path + (spec.value_type.__name__,))
Exemplo n.º 10
0
def _maybe_populate_ground_truth(sample_transformations, module_name):
    """Populates the ground truth values.

    Tries to import the module named by `_GROUND_TRUTH_MODULE_PATTERN`
    formatted with `module_name`. If that import raises `ImportError`, a
    shallow copy of `sample_transformations` is returned unchanged.

    Args:
      sample_transformations: A dictionary of Python strings to
        `SampleTransformation`s.
      module_name: Python string. Module name from which to load the ground
        truth.

    Returns:
      sample_transformations: A new dictionary with the same keys as the
        input. Each transformation has its four ground-truth fields set to the
        loaded values packed like its `dtype`, or `None` for a field when any
        loaded part of it is `None`.
    """
    populated = sample_transformations.copy()
    try:
        module = importlib.import_module(
            _GROUND_TRUTH_MODULE_PATTERN.format(module_name))
    except ImportError:
        return populated

    def _pack_or_none(structure, flat_parts):
        """Packs `flat_parts` like `structure`; None if any part is None."""
        if any(part is None for part in flat_parts):
            return None
        return tf.nest.pack_sequence_as(structure, flat_parts)

    for name, transformation in sample_transformations.items():
        dtype = transformation.dtype
        means, sems, stds, sestds = [], [], [], []
        for tuple_path, _ in nest.flatten_with_tuple_paths(dtype):
            mean, sem, std, sestd = ground_truth_encoding.load_ground_truth_part(
                module, name, tuple_path)
            means.append(mean)
            sems.append(sem)
            stds.append(std)
            sestds.append(sestd)
        populated[name] = transformation._replace(
            ground_truth_mean=_pack_or_none(dtype, means),
            ground_truth_mean_standard_error=_pack_or_none(dtype, sems),
            ground_truth_standard_deviation=_pack_or_none(dtype, stds),
            ground_truth_standard_deviation_standard_error=_pack_or_none(
                dtype, sestds),
        )
    return populated
Exemplo n.º 11
0
def hash_structure(struct):
    """Returns a hash for a possibly mutable nested structure of tensors.

    The structure is flattened into `(tuple_path, leaf)` pairs, so mutable
    containers never need to be hashed themselves. `HashableWeakRef` and
    `WeakStructRef` leaves are used as-is. NumPy arrays are replaced by a
    string built from their `__array_interface__` and `id`. Any other leaf is
    wrapped in `_IdentityHash`.
    """
    def _hashable_leaf(leaf):
        if isinstance(leaf, (HashableWeakRef, WeakStructRef)):
            return leaf
        if isinstance(leaf, np.ndarray):
            return str(leaf.__array_interface__) + str(id(leaf))
        return _IdentityHash(leaf)

    return hash(tuple(
        (path, _hashable_leaf(leaf))
        for path, leaf in nest.flatten_with_tuple_paths(struct)))
Exemplo n.º 12
0
def convert_structure_to_signature(structure, arg_names=None):
  """Convert a potentially nested structure to a signature.

  Args:
    structure: Structure to convert, where top level collection is a list or a
      tuple.
    arg_names: Optional list of arguments that has equal number of elements as
      `structure` and is used for naming corresponding TensorSpecs.

  Returns:
    Identical structure that has TensorSpec objects instead of Tensors and
    UnknownArgument instead of any unsupported types.

  Raises:
    ValueError: If `arg_names` is non-empty and its length differs from the
      length of `structure`.
  """

  def encode_arg(arg, name=None):
    """A representation for this argument, for converting into signatures."""
    if isinstance(arg, ops.Tensor):
      return tensor_spec.TensorSpec(arg.shape, arg.dtype, name)
    if isinstance(arg, (
        int,
        float,
        bool,
        type(None),
        dtypes.DType,
        tensor_spec.TensorSpec,
    )):
      return arg
    return UnknownArgument()

  # Validate before doing any flattening work. `arg_names` is wrapped in a
  # 1-tuple so that a tuple `arg_names` is printed as a whole instead of being
  # interpreted as several %-format arguments (which raises TypeError).
  if arg_names and len(arg_names) != len(structure):
    raise ValueError(
        "Passed in arg_names don't match actual signature (%s)." %
        (arg_names,))

  # We are using the flattened paths to name the TensorSpecs. We need an
  # explicit name for them downstream.
  flattened = nest.flatten_with_tuple_paths(structure)
  if arg_names:
    # Replace all top-level names with their actual arg_names. If a path before
    # was "(2,'a',1)", it will become "(arg_names[2],'a',1)".
    flattened = [
        ((arg_names[path[0]],) + path[1:], arg) for path, arg in flattened
    ]

  mapped = [
      encode_arg(arg, "/".join(str(p) for p in path))
      for path, arg in flattened
  ]
  return nest.pack_sequence_as(structure, mapped)
Exemplo n.º 13
0
  def testNestFlatten(self, structure, expected, paths, expand_composites=True):
    """Checks that every flattening helper agrees on leaves and paths."""
    flat = nest.flatten(structure, expand_composites=expand_composites)
    self.assertEqual(flat, expected)

    with_tuple_paths = nest.flatten_with_tuple_paths(
        structure, expand_composites=expand_composites)
    self.assertEqual(with_tuple_paths, list(zip(paths, expected)))

    # Joined string paths are the tuple paths rendered with '/' separators.
    joined = ['/'.join(map(str, path)) for path in paths]
    with_string_paths = nest.flatten_with_joined_string_paths(
        structure, expand_composites=expand_composites)
    self.assertEqual(with_string_paths, list(zip(joined, expected)))

    yielded = list(
        nest.yield_flat_paths(structure, expand_composites=expand_composites))
    self.assertEqual(yielded, paths)
Exemplo n.º 14
0
  def testNestFlatten(self, structure, expected, paths, expand_composites=True):
    """Verifies flatten, path-flatten and path-yield helpers are consistent."""
    options = {'expand_composites': expand_composites}

    self.assertEqual(nest.flatten(structure, **options), expected)
    self.assertEqual(nest.flatten_with_tuple_paths(structure, **options),
                     list(zip(paths, expected)))

    string_paths = ['/'.join(str(part) for part in path) for path in paths]
    self.assertEqual(
        nest.flatten_with_joined_string_paths(structure, **options),
        list(zip(string_paths, expected)))

    self.assertEqual(list(nest.yield_flat_paths(structure, **options)), paths)
Exemplo n.º 15
0
def hashable_structure(struct):
    """Converts a possibly mutable structure of `Tensor`s to a hashable tuple.

    The structure is flattened into `(tuple_path, leaf)` pairs, so mutable
    containers never need to be hashed themselves. Each leaf is handled as
    follows:

    - Already-wrapped references are kept.
    - NumPy arrays and tensors are wrapped in `ObjectIdentityWrapper`. An
      array's hash is derived from its `__array_interface__` and `id`.
    - Other hashable objects are kept.
    - Unhashable objects are wrapped in `ObjectIdentityWrapper`.
    """
    def _as_hashable(leaf):
        if isinstance(leaf,
                      (HashableWeakRef, WeakStructRef, ObjectIdentityWrapper)):
            return leaf
        if isinstance(leaf, np.ndarray):
            array_hash = hash(str(leaf.__array_interface__) + str(id(leaf)))
            return ObjectIdentityWrapper(leaf, object_hash=array_hash)
        if tf.is_tensor(leaf):
            return ObjectIdentityWrapper(leaf)
        try:
            hash(leaf)
        except TypeError:
            return ObjectIdentityWrapper(leaf)
        return leaf

    pairs = nest.flatten_with_tuple_paths(struct)
    return tuple([(path, _as_hashable(leaf)) for path, leaf in pairs])
  def batch_shape(self):
    """Static batch shape of this distribution, possibly unknown.

    The batch dimensions are indexes into independent, non-identical
    parameterizations of this distribution. The shape may be partially
    defined or unknown.

    If `_batch_shape()` returns a `tf.TensorShape`, or a flat list/tuple that
    contains no `tf.TensorShape`s, the result is a single `tf.TensorShape`.
    Otherwise the batch shape is treated as structured. Each part is then
    converted with `tf.TensorShape` via `nest.map_structure_up_to` over
    `self.dtype`.

    Returns:
      batch_shape: `TensorShape` or structure of `TensorShape`s, possibly
        unknown.
    """
    raw_shape = self._batch_shape()
    # See comment in `batch_shape_tensor()` on structured batch shapes.
    if isinstance(raw_shape, tf.TensorShape):
      return tf.TensorShape(raw_shape)
    is_flat_unstructured = all(
        len(path) == 1 and not isinstance(part, tf.TensorShape)
        for path, part in nest.flatten_with_tuple_paths(raw_shape))
    if is_flat_unstructured:
      return tf.TensorShape(raw_shape)
    return nest.map_structure_up_to(
        self.dtype, tf.TensorShape, raw_shape, check_types=False)
Exemplo n.º 17
0
def main(argv):
    """Computes ground-truth statistics for a Stan target and writes a module.

    Samples `FLAGS.target` with Stan. For each registered extraction function,
    the output is reduced to a per-element mean, standard error of the mean
    and standard deviation. The encoded values are written to
    `<output_directory>/<target>.py`.

    Args:
      argv: Command-line arguments; only the program name is accepted.

    Raises:
      app.UsageError: If extra command-line arguments are passed.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    tf.enable_v2_behavior()

    stan_model = getattr(targets, FLAGS.target)()

    with stan_model.sample_fn(sampling_iters=FLAGS.stan_samples,
                              chains=FLAGS.stan_chains,
                              show_progress=True) as mcmc_output:
        summary = mcmc_output.summary()
        if FLAGS.print_summary:
            pd.set_option('display.max_rows', sys.maxsize)
            pd.set_option('display.max_columns', sys.maxsize)
            print(mcmc_output.diagnose())
            print(summary)

        array_strs = []
        for name, fn in sorted(stan_model.extract_fns.items()):
            # We handle one chain at a time to reduce memory usage.
            chain_means = []
            chain_stds = []
            chain_esss = []
            for chain_id in range(FLAGS.stan_chains):
                # TODO(https://github.com/stan-dev/cmdstanpy/issues/218): This step is
                # very slow and wastes memory. Consider reading the CSV files ourselves.

                # sample shape is [num_samples, num_chains, num_columns]
                chain = mcmc_output.sample[:, chain_id, :]
                dataframe = pd.DataFrame(chain,
                                         columns=mcmc_output.column_names)

                transformed_samples = fn(dataframe)

                # We reduce over the samples dimension. Transformations can return
                # nested outputs.
                mean = tf.nest.map_structure(lambda s: s.mean(0),
                                             transformed_samples)
                std = tf.nest.map_structure(lambda s: s.std(0),
                                            transformed_samples)
                ess = tf.nest.map_structure(get_ess, transformed_samples)

                chain_means.append(mean)
                chain_stds.append(std)
                chain_esss.append(ess)

            # Now we reduce across chains. The cross-chain `std` must be
            # computed before `sem`, which is derived from it. Otherwise `sem`
            # would silently use the per-chain `std` left over from the last
            # iteration of the loop above.
            ess = tf.nest.map_structure(lambda *s: np.sum(s, 0), *chain_esss)
            mean = tf.nest.map_structure(lambda *s: np.mean(s, 0),
                                         *chain_means)
            std = tf.nest.map_structure(lambda *s: np.mean(s, 0), *chain_stds)
            sem = tf.nest.map_structure(lambda s, e: s / np.sqrt(e), std, ess)

            for (tuple_path, mean_part), sem_part, std_part in zip(
                    nest.flatten_with_tuple_paths(mean), tf.nest.flatten(sem),
                    tf.nest.flatten(std)):
                array_strs.extend(
                    ground_truth_encoding.save_ground_truth_part(
                        name=name,
                        tuple_path=tuple_path,
                        mean=mean_part,
                        sem=sem_part,
                        std=std_part,
                        sestd=None,
                    ))

    argv_str = '\n'.join(['  {} \\'.format(arg) for arg in sys.argv[1:]])
    command_str = (
        """bazel run //tools/inference_gym_ground_truth:get_ground_truth -- \
{argv_str}""".format(argv_str=argv_str))

    file_str = ground_truth_encoding.get_ground_truth_module_source(
        target_name=FLAGS.target,
        command_str=command_str,
        array_strs=array_strs)

    if FLAGS.output_directory is None:
        file_basedir = os.path.dirname(os.path.realpath(__file__))
        output_directory = os.path.join(
            file_basedir, '../../tensorflow_probability/python/experimental/'
            'inference_gym/targets/ground_truth')
    else:
        output_directory = FLAGS.output_directory
    file_path = os.path.join(output_directory, '{}.py'.format(FLAGS.target))
    print('Writing ground truth values to: {}'.format(file_path))
    with open(file_path, 'w') as f:
        f.write(file_str)
Exemplo n.º 18
0
def _flatten_module(module,
                    recursive,
                    predicate,
                    attribute_traversal_key,
                    attributes_to_ignore,
                    with_path,
                    expand_composites,
                    module_path=(),
                    seen=None,
                    recursion_stack=None):
    """Implementation of `flatten`.

    Args:
      module: Current module to process.
      recursive: Whether to recurse into child modules (those for which
        `_is_module` is true).
      predicate: Callable applied to each leaf; only leaves for which it
        returns a truthy value are yielded.
      attribute_traversal_key: Key function used to sort attribute names
        before traversal (same contract as `key` for builtin `sorted`).
      attributes_to_ignore: Attribute names to skip.
      with_path: If true, yield `(path, leaf)` pairs and do not de-duplicate
        leaves; otherwise yield each distinct leaf (by `id`) at most once.
      expand_composites: Forwarded to `nest.flatten_with_tuple_paths`.
      module_path: Tuple path from the root module to `module`.
      seen: Set of visited leaf ids, shared across recursive calls.
      recursion_stack: Sequence of ids of the ancestor modules on the current
        recursion path; used to break reference cycles when `with_path` is
        true.

    Yields:
      Matching leaves (or `(path, leaf)` pairs when `with_path` is true) of
      `module`, followed by those of its submodules.

    Raises:
      ValueError: If flattening an attribute fails; the original exception is
        attached as the cause.
    """
    module_id = id(module)
    if seen is None:
        seen = set([module_id])
    ancestors = () if recursion_stack is None else tuple(recursion_stack)

    # With `with_path=False`, `seen` already prevents revisiting objects. With
    # `with_path=True` leaves are not de-duplicated, so a reference cycle (e.g.
    # `a.a = a`) would recurse forever. Break cycles by not recursing out of a
    # module that is already one of its own ancestors.
    if module_id in ancestors:
        recursive = False

    module_dict = vars(module)
    submodules = []

    for key in sorted(module_dict, key=attribute_traversal_key):
        if key in attributes_to_ignore:
            continue

        prop = module_dict[key]
        try:
            leaves = nest.flatten_with_tuple_paths(
                prop, expand_composites=expand_composites)
        except Exception as cause:  # pylint: disable=broad-except
            six.raise_from(
                ValueError("Error processing property {!r} of {!r}".format(
                    key, prop)), cause)

        for leaf_path, leaf in leaves:
            leaf_path = (key, ) + leaf_path

            if not with_path:
                leaf_id = id(leaf)
                if leaf_id in seen:
                    continue
                seen.add(leaf_id)

            if predicate(leaf):
                if with_path:
                    yield module_path + leaf_path, leaf
                else:
                    yield leaf

            if recursive and _is_module(leaf):
                # Walk direct properties first then recurse.
                submodules.append((module_path + leaf_path, leaf))

    # An immutable tuple is passed down, so no cleanup is needed even if the
    # caller stops consuming this generator early.
    child_stack = ancestors + (module_id, )
    for submodule_path, submodule in submodules:
        subvalues = _flatten_module(
            submodule,
            recursive=recursive,
            predicate=predicate,
            attribute_traversal_key=attribute_traversal_key,
            attributes_to_ignore=submodule._TF_MODULE_IGNORED_PROPERTIES,  # pylint: disable=protected-access
            with_path=with_path,
            expand_composites=expand_composites,
            module_path=submodule_path,
            seen=seen,
            recursion_stack=child_stack)

        for subvalue in subvalues:
            # Predicate is already tested for these values.
            yield subvalue
Exemplo n.º 19
0
 def testFlattenWithTuplePaths(self, inputs, expected):
   """`flatten_with_tuple_paths(inputs)` must equal `expected`."""
   flattened = nest.flatten_with_tuple_paths(inputs)
   self.assertEqual(flattened, expected)
Exemplo n.º 20
0
def _flatten_module(module,
                    recursive,
                    predicate,
                    attribute_traversal_key,
                    attributes_to_ignore,
                    with_path,
                    expand_composites,
                    module_path=(),
                    seen=None,
                    recursion_stack=None):
  """Implementation of `flatten`.

  Args:
    module: Module whose attributes are traversed.
    recursive: Whether to descend into child modules (those for which
      `_is_module` is true).
    predicate: Callable applied to every leaf; only leaves for which it
      returns a truthy value are yielded.
    attribute_traversal_key: Key function applied to attribute names before
      sorting them (same contract as `key` for builtin `sorted`).
    attributes_to_ignore: Attribute names to skip.
    with_path: If true, yield `(path, leaf)` pairs. Leaves are then not
      de-duplicated, so an object reachable via several paths is yielded once
      per path. If false, each distinct leaf (by `id`) is yielded at most
      once.
    expand_composites: If true, composite tensors other than variables are
      expanded into their components via
      `_flatten_non_variable_composites_with_tuple_path`.
    module_path: Tuple path from the root module to `module`.
    seen: Set of leaf ids visited so far, shared across recursive calls.
    recursion_stack: List of ids of the modules on the current recursion
      path. It is mutated in place: this module's id is pushed before
      recursing and popped afterwards.

  Yields:
    Matching leaves (or `(path, leaf)` pairs) of `module`, followed by those
    of its submodules.

  Raises:
    ValueError: If flattening an attribute fails; the original exception is
      chained as the cause.
  """
  this_id = id(module)
  if seen is None:
    seen = {this_id}
  attrs = vars(module)
  if recursion_stack is None:
    recursion_stack = []

  # `seen` only de-duplicates when `with_path=False`; with paths, the same
  # object may legitimately be reported under several paths. To still
  # terminate on reference cycles, never recurse out of a module that is
  # already on the recursion path (i.e. do not follow back edges).
  if this_id in recursion_stack:
    recursive = False

  pending_submodules = []
  for key in sorted(attrs, key=attribute_traversal_key):
    if key in attributes_to_ignore:
      continue

    value = attrs[key]
    try:
      leaves = (list(_flatten_non_variable_composites_with_tuple_path(value))
                if expand_composites
                else nest.flatten_with_tuple_paths(value))
    except Exception as cause:  # pylint: disable=broad-except
      raise ValueError("Error processing property {!r} of {!r}".format(
          key, value)) from cause

    for relative_path, leaf in leaves:
      full_path = module_path + (key,) + relative_path

      if not with_path:
        if id(leaf) in seen:
          continue
        seen.add(id(leaf))

      if predicate(leaf):
        yield (full_path, leaf) if with_path else leaf

      if recursive and _is_module(leaf):
        # Yield direct attributes first; descend into submodules afterwards.
        pending_submodules.append((full_path, leaf))

  recursion_stack.append(this_id)

  for child_path, child in pending_submodules:
    for item in _flatten_module(
        child,
        recursive=recursive,
        predicate=predicate,
        attribute_traversal_key=attribute_traversal_key,
        attributes_to_ignore=child._TF_MODULE_IGNORED_PROPERTIES,  # pylint: disable=protected-access
        with_path=with_path,
        expand_composites=expand_composites,
        module_path=child_path,
        seen=seen,
        recursion_stack=recursion_stack):
      # The predicate was already applied inside the recursive call.
      yield item

  recursion_stack.pop()
Exemplo n.º 21
0
 def testFlattenWithTuplePaths(self, inputs, expected):
     """Checks the (path, leaf) pairs produced for a parameterized input."""
     actual = nest.flatten_with_tuple_paths(inputs)
     self.assertEqual(actual, expected)