Esempio n. 1
0
  def _build(self, *args, **kwargs):
    """Applies the wrapped module to inputs with merged leading dimensions.

    Every non-None input found in `args` and `kwargs` is passed through
    `merge_leading_dims` with `self._n_dims`, the wrapped module is called on
    the merged inputs, and every non-None output is then passed through
    `split_leading_dim`, using one of the original inputs as the reference.

    Args:
      *args: a Tensor or a nested list or dictionary of Tensors. `None` entries
          are handed to the module unchanged.
      **kwargs: Dictionary of named arguments; treated the same way as `*args`.

    Returns:
      The module's result with its nesting preserved, where each non-None
      entry has been passed through `split_leading_dim`. `None` entries in
      the result are returned as `None`.
    """
    inputs_structure = [args, kwargs]
    flat_inputs = nest.flatten_iterable(inputs_structure)

    # Merge the leading dimensions of every real input; `None` placeholders are
    # kept in position so the structure can be rebuilt afterwards.
    flat_merged = []
    for tensor in flat_inputs:
      if tensor is None:
        flat_merged.append(None)
      else:
        flat_merged.append(merge_leading_dims(tensor, self._n_dims))
    merged_args, merged_kwargs = nest.pack_iterable_as(inputs_structure,
                                                       flat_merged)

    outputs = self._module(*merged_args, **merged_kwargs)

    # The reference used for unmerging is the flattened input selected by
    # `self._input_example_index`; it supplies the leading dimension sizes
    # (typically this is the first input).
    reference = tf.convert_to_tensor(flat_inputs[self._input_example_index])

    flat_outputs = nest.flatten_iterable(outputs)
    flat_unmerged = [
        None if out is None
        else split_leading_dim(out, reference, self._n_dims)
        for out in flat_outputs]
    return nest.pack_iterable_as(outputs, flat_unmerged)
Esempio n. 2
0
  def testFlattenAndPackIterable(self):
    # Round-trips a deliberately irregular nest (list, namedtuple, dict and
    # `OrderedDict`) through `flatten_iterable` and `pack_iterable_as`.
    pair_type = collections.namedtuple("A", ("b", "c"))

    original = [
        "z",
        pair_type(3, 4),
        {
            "c": [1, collections.OrderedDict([("b", 3), ("a", 2)])],
            "b": 5,
        },
        17,
    ]

    # Expected leaf order: the dict's "b" entry (5) precedes its "c" entry, and
    # the `OrderedDict` values appear as 2 then 3.
    flat_values = nest.flatten_iterable(original)
    expected_flat = ["z", 3, 4, 5, 1, 2, 3, 17]
    self.assertEqual(flat_values, expected_flat)

    # A template with the same nesting as `original` but different leaves; the
    # final assertion expects packing into it to reproduce `original`.
    template = [
        14,
        pair_type("a", True),
        {
            "c": [0, collections.OrderedDict([("b", 9), ("a", 8)])],
            "b": 3,
        },
        "hi everybody",
    ]

    repacked = nest.pack_iterable_as(template, flat_values)
    self.assertEqual(repacked, original)
Esempio n. 3
0
  def testFlattenAndPackIterable(self):
    # Flattens a mixed nest of list, namedtuple, dict and `OrderedDict`, then
    # packs the flat values back into a structure of the same nesting.
    record = collections.namedtuple("A", ("b", "c"))

    inner_ordered = collections.OrderedDict([("b", 3), ("a", 2)])
    inner_dict = {"c": [1, inner_ordered], "b": 5}
    nested = ["z", record(3, 4), inner_dict, 17]

    # Expected leaf order: the plain dict's "b" entry (5) comes before its "c"
    # entry, and the `OrderedDict` values appear as 3 then 2.
    leaves = nest.flatten_iterable(nested)
    self.assertEqual(leaves, ["z", 3, 4, 5, 1, 3, 2, 17])

    # Same nesting as `nested` with different leaf values; the assertion below
    # expects packing `leaves` into it to give back `nested`.
    skeleton_ordered = collections.OrderedDict([("b", 9), ("a", 8)])
    skeleton_dict = {"c": [0, skeleton_ordered], "b": 3}
    skeleton = [14, record("a", True), skeleton_dict, "hi everybody"]

    rebuilt = nest.pack_iterable_as(skeleton, leaves)
    self.assertEqual(rebuilt, nested)
Esempio n. 4
0
 def testPackIterableAs_wrongLengthsError(self):
   # Packing three flat values into a two-element structure is expected to
   # raise a ValueError that reports both element counts.
   structure = ["hello", "world"]
   flat_sequence = ["and", "goodbye", "again"]
   message_regex = (
       "Structure had 2 elements, but flat_sequence had 3 elements.")

   with self.assertRaisesRegexp(ValueError, message_regex):
     nest.pack_iterable_as(structure, flat_sequence)
Esempio n. 5
0
 def testPackIterableAs_scalarStructureError(self):
   # A non-nested ("scalar") structure combined with a two-element flat
   # sequence is expected to raise a ValueError.
   scalar_structure = "hi"
   flat_sequence = ["bye", "twice"]
   pattern = r"Structure is a scalar but len\(flat_sequence\) == 2 > 1"

   with self.assertRaisesRegexp(ValueError, pattern):
     nest.pack_iterable_as(scalar_structure, flat_sequence)
Esempio n. 6
0
 def testPackIterableAs_notIterableError(self):
   # Passing a plain string as `flat_sequence` is expected to raise a
   # TypeError.
   pattern = "flat_sequence must be a sequence"

   with self.assertRaisesRegexp(TypeError, pattern):
     nest.pack_iterable_as("hi", "bye")
Esempio n. 7
0
 def testFlatternIterable_scalarStructure(self):
   # `flatten_iterable` accepts a single "scalar" object, and packing its
   # output into another scalar structure is expected to return the original
   # value.
   scalar = "hello"
   leaves = nest.flatten_iterable(scalar)
   repacked = nest.pack_iterable_as("goodbye", leaves)
   self.assertEqual(scalar, repacked)