def _build(self, *args, **kwargs):
  """Connects the BatchApply module into the graph.

  Every non-`None` input has its leading `self._n_dims` dimensions merged into
  one, the wrapped op or module (`self._module`) is called on the merged
  inputs, and every non-`None` output then has its first dimension split back
  into the original leading dimensions. Those leading sizes are read from a
  single example input: the entry at position `self._input_example_index` of
  the flattened `[args, kwargs]` structure (typically the first input).

  Args:
    *args: a Tensor or a nested list or dictionary of Tensors. `None` entries
      are forwarded to the wrapped module unchanged.
    **kwargs: Dictionary of named arguments; used in the same way as `*args`.

  Returns:
    A Tensor or nested list or dictionary of Tensors as a result of applying
    the process above. ("None" return values are also supported.)

  Raises:
    ValueError: if the example input selected by `self._input_example_index`
      is `None`, because its leading dimensions are needed to unmerge the
      outputs.
  """
  # Flatten positional and keyword inputs together so they are processed
  # uniformly, then restore the original nesting after merging.
  flattened = nest.flatten_iterable([args, kwargs])
  merged_flattened = [
      merge_leading_dims(inp, self._n_dims) if inp is not None else None
      for inp in flattened
  ]
  merged_args, merged_kwargs = nest.pack_iterable_as([args, kwargs],
                                                     merged_flattened)

  results = self._module(*merged_args, **merged_kwargs)

  # Unmerging takes the sizes of the leading dimensions from an input example
  # with equal shape for the leading `n_dims` dimensions.
  example = flattened[self._input_example_index]
  if example is None:
    # Report the real cause instead of letting `tf.convert_to_tensor(None)`
    # fail with a generic conversion error.
    raise ValueError(
        "BatchApply needs the input at flattened index {} to determine the "
        "leading dimensions of the outputs, but it is None.".format(
            self._input_example_index))
  example_input = tf.convert_to_tensor(example)

  def _split_to_original_leading_dims(result):
    # `None` outputs are passed through untouched.
    if result is None:
      return None
    return split_leading_dim(result, example_input, self._n_dims)

  flat_unmerged_results = [
      _split_to_original_leading_dims(result)
      for result in nest.flatten_iterable(results)
  ]
  return nest.pack_iterable_as(results, flat_unmerged_results)
def _build(self, *args, **kwargs):
  """Connects the BatchApply module into the graph.

  Every non-`None` input has its leading `self._n_dims` dimensions merged into
  one, the wrapped op or module (`self._module`) is called on the merged
  inputs, and every non-`None` output then has its first dimension split back
  into the original leading dimensions. Those leading sizes are read from a
  single example input: the entry at position `self._input_example_index` of
  the flattened `[args, kwargs]` structure (typically the first input).

  Args:
    *args: a Tensor or a nested list or dictionary of Tensors. `None` entries
      are forwarded to the wrapped module unchanged.
    **kwargs: Dictionary of named arguments; used in the same way as `*args`.

  Returns:
    A Tensor or nested list or dictionary of Tensors as a result of applying
    the process above. ("None" return values are also supported.)

  Raises:
    ValueError: if the example input selected by `self._input_example_index`
      is `None`, because its leading dimensions are needed to unmerge the
      outputs.
  """
  # Treat positional and keyword inputs as one flat sequence of leaves.
  flat_inputs = nest.flatten_iterable([args, kwargs])
  flat_merged = [
      None if leaf is None else merge_leading_dims(leaf, self._n_dims)
      for leaf in flat_inputs
  ]
  merged_args, merged_kwargs = nest.pack_iterable_as([args, kwargs],
                                                     flat_merged)

  results = self._module(*merged_args, **merged_kwargs)

  # Unmerging takes the sizes of the leading dimensions from an input example
  # with equal shape for the leading `n_dims` dimensions.
  example = flat_inputs[self._input_example_index]
  if example is None:
    # Report the real cause instead of letting `tf.convert_to_tensor(None)`
    # fail with a generic conversion error.
    raise ValueError(
        "BatchApply needs the input at flattened index {} to determine the "
        "leading dimensions of the outputs, but it is None.".format(
            self._input_example_index))
  example_input = tf.convert_to_tensor(example)

  # `None` outputs are passed through untouched.
  flat_unmerged = [
      None if leaf is None else split_leading_dim(leaf, example_input,
                                                  self._n_dims)
      for leaf in nest.flatten_iterable(results)
  ]
  return nest.pack_iterable_as(results, flat_unmerged)
def testFlattenAndPackIterable(self):
  """Flattens a mixed nested structure and packs it back.

  Plain dicts are flattened in sorted key order, while `OrderedDict`s keep
  their insertion order. Packing the leaves into a structure of the same shape
  must reproduce the original.
  """
  # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
  named_tuple = collections.namedtuple("A", ("b", "c"))
  mess = [
      "z",
      named_tuple(3, 4),
      {
          "c": [
              1,
              collections.OrderedDict([
                  ("b", 3),
                  ("a", 2),
              ]),
          ],
          "b": 5
      },
      17
  ]

  flattened = nest.flatten_iterable(mess)
  # The plain dict yields key "b" (5) before key "c" because its keys are
  # sorted. The OrderedDict yields "b" (3) before "a" (2) because it keeps
  # insertion order; its keys are deliberately out of sorted order so the
  # test can tell the two behaviors apart.
  self.assertEqual(flattened, ["z", 3, 4, 5, 1, 3, 2, 17])

  # Same shape as `mess`, with placeholder leaves.
  structure_of_mess = [
      14,
      named_tuple("a", True),
      {
          "c": [
              0,
              collections.OrderedDict([
                  ("b", 9),
                  ("a", 8),
              ]),
          ],
          "b": 3
      },
      "hi everybody",
  ]

  unflattened = nest.pack_iterable_as(structure_of_mess, flattened)
  self.assertEqual(unflattened, mess)
def testFlattenAndPackIterable(self):
  """Round-trips a mixed nested structure through flatten and pack.

  The structure combines a namedtuple, lists, a plain dict and an
  `OrderedDict` whose keys are not in sorted order.
  """
  record_type = collections.namedtuple("A", ("b", "c"))

  # The value under test.
  inner_ordered = collections.OrderedDict([("b", 3), ("a", 2)])
  mess = ["z", record_type(3, 4), {"c": [1, inner_ordered], "b": 5}, 17]

  # A structure of identical shape whose leaves are arbitrary placeholders.
  template_ordered = collections.OrderedDict([("b", 9), ("a", 8)])
  template = [
      14,
      record_type("a", True),
      {"c": [0, template_ordered], "b": 3},
      "hi everybody",
  ]

  # Expected order: the plain dict's keys sorted ("b" then "c"); the
  # OrderedDict's keys in insertion order ("b" then "a").
  leaves = nest.flatten_iterable(mess)
  self.assertEqual(leaves, ["z", 3, 4, 5, 1, 3, 2, 17])

  self.assertEqual(nest.pack_iterable_as(template, leaves), mess)
def testFlatternIterable_scalarStructure(self):
  """A lone "scalar" value survives a flatten/pack round trip.

  `flatten_iterable` accepts a single non-container object, and packing the
  result against another scalar template returns that object.
  """
  original = "hello"
  leaves = nest.flatten_iterable(original)
  self.assertEqual(original, nest.pack_iterable_as("goodbye", leaves))
def testFlattenIterable_stringIsNotFlattened(self):
  """A string is a single leaf; its characters are not split out."""
  leaves = nest.flatten_iterable("lots of letters")
  self.assertEqual(len(leaves), 1)
def testFlattenIterable_numpyIsNotFlattened(self):
  """A numpy array is one leaf; it is not iterated element by element."""
  leaves = nest.flatten_iterable(np.array([1, 2, 3]))
  self.assertEqual(len(leaves), 1)