Ejemplo n.º 1
0
    def _build(self, *args):
        """Connects the BatchApply module into the graph.

        The leading `self._n_dims` dimensions of every input Tensor are merged
        into one. The wrapped op/module is then called on the merged inputs.
        Finally, the leading dimension of every output is split back out again.
        The sizes used for that split come from one reference input: the entry
        at `self._input_example_index` of the flattened inputs, typically the
        first input.

        Args:
          *args: a Tensor or a nested list of Tensors. The input tensors will
              have their first dimensions merged, then an op or a module will
              be called on the input. The first dimension of the output will be
              split again based on the leading dimensions of the reference
              input tensor.

        Returns:
          A Tensor (or nested structure of Tensors) resulting of applying the
          process above.
        """
        def _merge(tensor):
            return merge_leading_dims(tensor, self._n_dims)

        outputs = self._module(*nest.map(_merge, args))

        # The reference input must share the same leading `n_dims` sizes as the
        # other inputs; those sizes are what the outputs get reshaped back to.
        flat_inputs = nest.flatten(args)
        reference = tf.convert_to_tensor(flat_inputs[self._input_example_index])

        def _unmerge(tensor):
            return split_leading_dim(tensor, reference, self._n_dims)

        return nest.map(_unmerge, outputs)
Ejemplo n.º 2
0
  def testEmptySequences(self):
    """nest.map over an empty structure returns an empty structure of that type."""
    add_one = lambda x: x + 1
    EmptyNT = collections.namedtuple("empty_nt", "")

    for expected, structure in [((), ()), ([], []), (EmptyNT(), EmptyNT())]:
      self.assertEqual(expected, nest.map(add_one, structure))

    # () == [] is False in Python, so this fails if nest.map turned the empty
    # list into a tuple: the container type must follow the input.
    self.assertNotEqual((), nest.map(add_one, []))
Ejemplo n.º 3
0
    def testEmptySequences(self):
        """Mapping over an empty structure yields an empty structure of that type."""
        def increment(value):
            return value + 1

        empty_namedtuple_cls = collections.namedtuple("empty_nt", "")

        self.assertEqual((), nest.map(increment, ()))
        self.assertEqual([], nest.map(increment, []))
        self.assertEqual(empty_namedtuple_cls(),
                         nest.map(increment, empty_namedtuple_cls()))

        # An empty tuple never equals an empty list, so together with the
        # list assertion above this pins the list-input result to a list.
        self.assertNotEqual((), nest.map(increment, []))
Ejemplo n.º 4
0
 def testStringRepeat(self):
     """nest.map pairs up matching leaves of two structures for the function."""
     pair = collections.namedtuple("ab_tuple", "a, b")
     strings = pair(a="foo", b=("bar", "baz"))
     counts = pair(a=2, b=(1, 3))

     def repeat(text, times):
         return text * times

     result = nest.map(repeat, strings, counts)

     self.assertEqual(result.a, "foofoo")
     self.assertEqual(result.b[0], "bar")
     self.assertEqual(result.b[1], "bazbazbaz")
Ejemplo n.º 5
0
 def testStringRepeat(self):
   """Elementwise string repetition across two parallel namedtuples."""
   ab_tuple = collections.namedtuple("ab_tuple", "a, b")
   words = ab_tuple(a="foo", b=("bar", "baz"))
   multipliers = ab_tuple(a=2, b=(1, 3))

   repeated = nest.map(lambda word, k: word * k, words, multipliers)

   # Leaf-by-leaf: "foo"*2, "bar"*1, "baz"*3.
   self.assertEqual(repeated.a, "foofoo")
   self.assertEqual(repeated.b[0], "bar")
   self.assertEqual(repeated.b[1], "bazbazbaz")
Ejemplo n.º 6
0
def _nested_unary_mul(nested_a, p):
  """Multiply `Tensors` in arbitrarily nested `Tensor` `nested_a` with `p`.

  Rank-2 leaves are multiplied by `p` directly. A leaf of any other
  (statically known) rank `r` is handled differently: `p` is first reshaped to
  `[-1] + [1] * (r - 1)`, so it lines up with the leaf's leading dimension and
  broadcasts over the rest.
  NOTE(review): presumably `p` holds one value per batch element (e.g. shape
  `[batch, 1]`) -- confirm against callers.

  Args:
    nested_a: a `Tensor` or nested structure of `Tensor`s.
    p: the `Tensor` to multiply every leaf of `nested_a` by.

  Returns:
    A structure matching `nested_a` whose leaves are the products with `p`.

  Raises:
    ValueError: if a leaf of `nested_a` has unknown static rank, because the
      broadcast shape for `p` cannot be built without it.
  """
  def mul_with_broadcast(tensor):
    ndims = tensor.shape.ndims
    if ndims is None:
      # `[1] * (ndims - 1)` needs a concrete rank; previously this surfaced as
      # an opaque `TypeError` from `None - 1`.
      raise ValueError(
          "_nested_unary_mul requires tensors with a statically known rank, "
          "got a tensor of unknown rank: {}".format(tensor))
    if ndims == 2:
      return p * tensor
    p_reshaped = tf.reshape(p, [-1] + [1] * (ndims - 1))
    return p_reshaped * tensor
  return nest.map(mul_with_broadcast, nested_a)
Ejemplo n.º 7
0
    def testMapSingleCollection(self):
        """Reversing every string leaf preserves the namedtuple structure."""
        ab_tuple = collections.namedtuple("ab_tuple", "a, b")
        original = ab_tuple(a=("something", "something_else"),
                            b="yet another thing")

        def reverse(text):
            return text[::-1]

        reversed_nt = nest.map(reverse, original)

        # Same nesting in and out, and every leaf is its input leaf reversed.
        nest.assert_same_structure(original, reversed_nt)
        for before, after in zip(original.a, reversed_nt.a):
            self.assertEqual(reverse(before), after)
        self.assertEqual(reverse(original.b), reversed_nt.b)
Ejemplo n.º 8
0
  def testMapSingleCollection(self):
    """nest.map with a single structure applies the function to every leaf."""
    ab_tuple = collections.namedtuple("ab_tuple", "a, b")
    nt = ab_tuple(a=("something", "something_else"), b="yet another thing")
    flipped = nest.map(lambda s: s[::-1], nt)

    # The output keeps the input's nesting, with each string leaf reversed.
    nest.assert_same_structure(nt, flipped)
    leaf_pairs = [(nt.a[0], flipped.a[0]),
                  (nt.a[1], flipped.a[1]),
                  (nt.b, flipped.b)]
    for source_leaf, mapped_leaf in leaf_pairs:
      self.assertEqual(source_leaf[::-1], mapped_leaf)
Ejemplo n.º 9
0
  def testMapOverTwoTuples(self):
    """Elementwise addition of two tuples of placeholders via nest.map."""
    shapes = ([3, 4], [3, 7])
    inp_a = tuple(tf.placeholder(tf.float32, shape=s) for s in shapes)
    inp_b = tuple(tf.placeholder(tf.float32, shape=s) for s in shapes)

    output = nest.map(lambda lhs, rhs: lhs + rhs, inp_a, inp_b)

    # Static checks: the result mirrors the inputs' structure and shapes.
    nest.assert_same_structure(output, inp_a)
    for shape, out in zip(shapes, output):
      self.assertShapeEqual(np.zeros(shape), out)

    # Tuples of placeholders are fed with matching tuples of arrays.
    feed_dict = {
        inp_a: tuple(np.random.randn(*s) for s in shapes),
        inp_b: tuple(np.random.randn(*s) for s in shapes),
    }

    with self.test_session() as sess:
      output_np = sess.run(output, feed_dict=feed_dict)
    for actual, a_val, b_val in zip(output_np, feed_dict[inp_a],
                                    feed_dict[inp_b]):
      self.assertAllClose(actual, a_val + b_val)
Ejemplo n.º 10
0
    def testMapOverTwoTuples(self):
        """nest.map adds corresponding placeholders from two parallel tuples."""
        def make_inputs():
            return (tf.placeholder(tf.float32, shape=[3, 4]),
                    tf.placeholder(tf.float32, shape=[3, 7]))

        inp_a = make_inputs()
        inp_b = make_inputs()

        summed = nest.map(lambda x, y: x + y, inp_a, inp_b)

        nest.assert_same_structure(summed, inp_a)
        self.assertShapeEqual(np.zeros((3, 4)), summed[0])
        self.assertShapeEqual(np.zeros((3, 7)), summed[1])

        def random_values():
            return (np.random.randn(3, 4), np.random.randn(3, 7))

        values_a = random_values()
        values_b = random_values()
        feed_dict = {inp_a: values_a, inp_b: values_b}

        with self.test_session() as sess:
            summed_np = sess.run(summed, feed_dict=feed_dict)

        # Each evaluated sum must match the NumPy sum of the fed values.
        self.assertAllClose(summed_np[0], values_a[0] + values_b[0])
        self.assertAllClose(summed_np[1], values_a[1] + values_b[1])
Ejemplo n.º 11
0
 def testStructureMustBeSame(self):
   """nest.map rejects structures with differing numbers of elements."""
   short_seq, long_seq = (3, 4), (42, 42, 44)
   expected_msg = "The two structures don't have the same number of elements."
   with self.assertRaisesRegexp(ValueError, expected_msg):
     nest.map(lambda x, y: x + y, short_seq, long_seq)
Ejemplo n.º 12
0
def _nested_add(nested_a, nested_b):
  """Add two arbitrarily nested `Tensors`.

  Args:
    nested_a: a `Tensor` or nested structure of `Tensor`s.
    nested_b: a structure nested the same way as `nested_a`.

  Returns:
    A structure like `nested_a` whose leaves are `leaf_a + leaf_b` for each
    pair of corresponding leaves.
  """
  def _add(lhs, rhs):
    return lhs + rhs

  return nest.map(_add, nested_a, nested_b)
Ejemplo n.º 13
0
def _nested_zeros_like(nested_a):
  """Return a structure like `nested_a` with each leaf replaced by zeros.

  Every leaf is mapped through `tf.zeros_like`.
  """
  return nest.map(lambda leaf: tf.zeros_like(leaf), nested_a)
Ejemplo n.º 14
0
 def testMultiNest(self):
   """nest.map recurses into inner tuples when combining two structures."""
   lhs = (3, (4, 5))
   rhs = (42, (42, 44))
   expected = (45, (46, 49))
   self.assertEqual(expected, nest.map(lambda x, y: x + y, lhs, rhs))
Ejemplo n.º 15
0
def _nested_add(nested_a, nested_b):
  """Add two arbitrarily nested `Tensors`.

  Both arguments are walked together by `nest.map`; each output leaf is the
  sum of the corresponding input leaves.
  """
  summed = nest.map(lambda left, right: left + right, nested_a, nested_b)
  return summed
Ejemplo n.º 16
0
def _nested_unary_mul(nested_a, p):
  """Multiply `Tensors` in arbitrarily nested `Tensor` `nested_a` with `p`.

  Each leaf becomes `p * leaf`; `p` stays the left operand.
  """
  def _scale(leaf):
    return p * leaf

  return nest.map(_scale, nested_a)
Ejemplo n.º 17
0
 def testNoSequences(self):
   """Calling nest.map without any structure raises a ValueError."""
   identity = lambda x: x
   expected_msg = "Must provide at least one structure"
   with self.assertRaisesRegexp(ValueError, expected_msg):
     nest.map(identity)
Ejemplo n.º 18
0
 def testNoSequences(self):
     """nest.map refuses to run when given a function but no structures."""
     def identity(value):
         return value

     with self.assertRaisesRegexp(ValueError,
                                  "Must provide at least one structure"):
         nest.map(identity)
Ejemplo n.º 19
0
 def testMultiNest(self):
     """Two identically nested tuples are summed leaf by leaf."""
     def add(x, y):
         return x + y

     result = nest.map(add, (3, (4, 5)), (42, (42, 44)))
     self.assertEqual((45, (46, 49)), result)
Ejemplo n.º 20
0
 def testNoSequences(self):
     """nest.map with no structures to map over raises a ValueError."""
     expected_msg = "Cannot map over no sequences"
     with self.assertRaisesRegexp(ValueError, expected_msg):
         nest.map(lambda value: value)
Ejemplo n.º 21
0
def _nested_unary_mul(nested_a, p):
  """Multiply `Tensors` in arbitrarily nested `Tensor` `nested_a` with `p`.

  Returns a structure matching `nested_a` in which each leaf is `p * leaf`.
  """
  product = nest.map(lambda leaf: p * leaf, nested_a)
  return product
Ejemplo n.º 22
0
def _nested_zeros_like(nested_a):
  """Map `tf.zeros_like` over every leaf of the nested structure `nested_a`."""
  zeros_fn = tf.zeros_like
  return nest.map(zeros_fn, nested_a)
Ejemplo n.º 23
0
 def testStructureMustBeSame(self):
     """Mapping over a 2-tuple and a 3-tuple together raises a ValueError."""
     def add(x, y):
         return x + y

     two_items = (3, 4)
     three_items = (42, 42, 44)
     with self.assertRaisesRegexp(
         ValueError,
         "The two structures don't have the same number of elements."):
         nest.map(add, two_items, three_items)