예제 #1
0
    def _build(self, *args):
        """Connects the BatchApply module into the graph.

        Args:
          *args: a Tensor or a nested list of Tensors. Every input first has
              its leading dimensions merged (via `merge_leading_dims` with
              `self._n_dims`). The wrapped op or module is then called on the
              merged inputs. Finally, the first dimension of each output is
              split back out using the leading dimensions of the input selected
              by `self._input_example_index`.

        Returns:
          The result of applying the process above, mapped over the structure
          of the wrapped op's or module's output.
        """
        n_dims = self._n_dims

        # Collapse the leading dimensions of every input, then run the wrapped
        # op/module on the collapsed tensors.
        collapsed_inputs = nest.map(
            lambda tensor: merge_leading_dims(tensor, n_dims), args)
        outputs = self._module(*collapsed_inputs)

        # The original leading sizes are read back from one reference input
        # whose first `n_dims` dimensions match the others (typically the
        # first input).
        all_inputs = nest.flatten(args)
        reference = tf.convert_to_tensor(all_inputs[self._input_example_index])

        return nest.map(
            lambda output: split_leading_dim(output, reference, n_dims),
            outputs)
예제 #2
0
  def _build(self, inputs):
    """Connects the MergeDims module into the graph.

    Each tensor found in `inputs` is passed through `self._merge`. A nested
    structure of tensors is rebuilt with the same nesting as `inputs`.

    Args:
      inputs: a Tensor or a nested list of Tensors to merge. Each tensor's rank
          must be at least `start` + `size`.

    Returns:
      The merged Tensor, or a nested list of merged Tensors mirroring `inputs`.

    Raises:
      ValueError: If any of the `inputs` tensors has insufficient rank.
    """
    if not nest.is_sequence(inputs):
      # A lone tf.Tensor: merge it directly.
      return self._merge(inputs)

    # Merge every leaf in flattened order, then restore the original nesting.
    merged_leaves = list(map(self._merge, nest.flatten(inputs)))
    return nest.pack_sequence_as(inputs, merged_leaves)
예제 #3
0
  def _build(self, inputs):
    """Connects the MergeDims module into the graph.

    Applies `self._merge` to a single tensor, or to every leaf tensor of a
    nested structure. In the nested case the output keeps the input's
    nesting.

    Args:
      inputs: a Tensor or a nested list of Tensors to merge. Each tensor's rank
          must be at least `start` + `size`.

    Returns:
      The merged Tensor, or a nested list of merged Tensors mirroring `inputs`.

    Raises:
      ValueError: If any of the `inputs` tensors has insufficient rank.
    """
    if not nest.is_sequence(inputs):
      # Single tf.Tensor input.
      return self._merge(inputs)

    flat_results = [self._merge(leaf) for leaf in nest.flatten(inputs)]
    return nest.pack_sequence_as(inputs, flat_results)
예제 #4
0
    def testFlattenUpTo(self):
        """Checks nest.flatten_up_to on regular trees and on edge cases.

        Covers four situations:
          * two normal nested applications;
          * a non-sequence shallow tree, for which the whole input comes back
            as a single leaf (this includes the case where both are
            non-sequences);
          * a sequence shallow tree with a non-sequence input, for which a
            TypeError with a descriptive message is expected.
        """
        # Normal application (Example 1): the shallow tree ends one level above
        # the innermost lists, so those lists are returned as leaves.
        deep_input = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
        bool_shallow = [[True, True], [False, True]]
        input_leaves = nest.flatten_up_to(bool_shallow, deep_input)
        shallow_leaves = nest.flatten_up_to(bool_shallow, bool_shallow)
        self.assertEqual(input_leaves, [[2, 2], [3, 3], [4, 9], [5, 5]])
        self.assertEqual(shallow_leaves, [True, True, False, True])

        # Normal application (Example 2): tuples at the shallow tree's leaf
        # positions stay intact, whereas a full flatten splits them apart.
        pair_input = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
        level_shallow = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
        pairs_as_leaves = nest.flatten_up_to(level_shallow, pair_input)
        fully_flat = nest.flatten(pair_input)
        self.assertEqual(pairs_as_leaves,
                         [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
        self.assertEqual(fully_flat, ["a", 1, "b", 2, "c", 3, "d", 4])

        # Non-sequence shallow tree: the entire input is treated as one leaf.
        # Strings (iterable, yet treated as leaves here) and plain ints are both
        # exercised; the last two pairs make the input a non-sequence as well.
        leaf_cases = [
            (["input_tree"], "shallow_tree"),
            (["input_tree_0", "input_tree_1"], "shallow_tree"),
            ([0], 9),
            ([0, 1], 9),
            ("input_tree", "shallow_tree"),
            (0, 0),
        ]
        for input_tree, shallow_tree in leaf_cases:
            input_leaves = nest.flatten_up_to(shallow_tree, input_tree)
            shallow_leaves = nest.flatten_up_to(shallow_tree, shallow_tree)
            self.assertEqual(input_leaves, [input_tree])
            self.assertEqual(shallow_leaves, [shallow_tree])

        # Sequence shallow tree with a non-sequence input must raise TypeError,
        # while flattening the shallow tree against itself stays a no-op.
        str_message = (
            "If shallow structure is a sequence, input must also be "
            "a sequence. Input has type: <{} 'str'>.".format(typekw))
        int_message = (
            "If shallow structure is a sequence, input must also be "
            "a sequence. Input has type: <{} 'int'>.".format(typekw))
        mismatch_cases = [
            ("input_tree", ["shallow_tree"], str_message),
            ("input_tree", ["shallow_tree_9", "shallow_tree_8"], str_message),
            (0, [9], int_message),
            (0, [9, 8], int_message),
        ]
        for input_tree, shallow_tree, expected_message in mismatch_cases:
            with self.assertRaises(TypeError) as cm:
                nest.flatten_up_to(shallow_tree, input_tree)
            shallow_leaves = nest.flatten_up_to(shallow_tree, shallow_tree)
            self.assertEqual(str(cm.exception), expected_message)
            self.assertEqual(shallow_leaves, shallow_tree)