def _build(self, *args):
    """Connects the BatchApply module into the graph.

    Every input tensor has its leading `self._n_dims` dimensions merged, the
    wrapped module is called on the merged inputs, and each output has its
    leading dimension split back using a reference input.

    Args:
      *args: a Tensor or a nested list of Tensors to feed to the wrapped
        module after merging their leading dimensions.

    Returns:
      The outputs of the wrapped module, in the same nested structure, with
      the leading dimension of each output split again.
    """

    def _merge(tensor):
        return merge_leading_dims(tensor, self._n_dims)

    # Collapse the leading dimensions of every input, then run the module.
    module_outputs = self._module(*nest.map(_merge, args))

    # The sizes of the leading dimensions are recovered from one reference
    # input, selected among the flattened inputs by
    # `self._input_example_index`.
    flat_inputs = nest.flatten(args)
    reference = tf.convert_to_tensor(flat_inputs[self._input_example_index])

    def _unmerge(output):
        return split_leading_dim(output, reference, self._n_dims)

    return nest.map(_unmerge, module_outputs)
def testEmptySequences(self):
    """Checks `nest.map` on an empty tuple, list and namedtuple."""

    def add_one(value):
        return value + 1

    empty_nt = collections.namedtuple("empty_nt", "")

    mapped_tuple = nest.map(add_one, ())
    self.assertEqual((), mapped_tuple)

    mapped_list = nest.map(add_one, [])
    self.assertEqual([], mapped_list)

    mapped_nt = nest.map(add_one, empty_nt())
    self.assertEqual(empty_nt(), mapped_nt)

    # The container type must be preserved: an empty list is not an empty
    # tuple.
    self.assertNotEqual((), nest.map(add_one, []))
def testStringRepeat(self):
    """Maps a two-argument function over two parallel namedtuples."""
    ab_tuple = collections.namedtuple("ab_tuple", "a, b")
    strings = ab_tuple(a="foo", b=("bar", "baz"))
    counts = ab_tuple(a=2, b=(1, 3))

    def repeat(string, repeats):
        return string * repeats

    repeated = nest.map(repeat, strings, counts)

    self.assertEqual(repeated.a, "foofoo")
    self.assertEqual(repeated.b[0], "bar")
    self.assertEqual(repeated.b[1], "bazbazbaz")
def _nested_unary_mul(nested_a, p):
    """Scales every `Tensor` in the nested structure `nested_a` by `p`.

    Args:
      nested_a: an arbitrarily nested structure of `Tensor`s.
      p: the multiplier applied to each `Tensor` in `nested_a`.

    Returns:
      A structure shaped like `nested_a` holding the products.
    """

    def _scale(tensor):
        rank = tensor.shape.ndims
        if rank == 2:
            return p * tensor
        # For any other rank, reshape `p` to [-1, 1, ..., 1] so that it has
        # as many dimensions as `tensor` before multiplying.
        broadcast_p = tf.reshape(p, [-1] + [1] * (rank - 1))
        return broadcast_p * tensor

    return nest.map(_scale, nested_a)
def testMapSingleCollection(self):
    """Maps string reversal over a single nested namedtuple."""
    ab_tuple = collections.namedtuple("ab_tuple", "a, b")
    original = ab_tuple(a=("something", "something_else"),
                        b="yet another thing")

    def reverse(text):
        return text[::-1]

    reversed_nt = nest.map(reverse, original)

    # The result must mirror the input structure, with every string reversed.
    nest.assert_same_structure(original, reversed_nt)
    self.assertEqual(original.a[0][::-1], reversed_nt.a[0])
    self.assertEqual(original.a[1][::-1], reversed_nt.a[1])
    self.assertEqual(original.b[::-1], reversed_nt.b)
def testMapOverTwoTuples(self):
    """Adds two tuples of placeholders element-wise through `nest.map`."""
    inp_a = (tf.placeholder(tf.float32, shape=[3, 4]),
             tf.placeholder(tf.float32, shape=[3, 7]))
    inp_b = (tf.placeholder(tf.float32, shape=[3, 4]),
             tf.placeholder(tf.float32, shape=[3, 7]))

    def add(x1, x2):
        return x1 + x2

    output = nest.map(add, inp_a, inp_b)

    # Static checks: same structure as the inputs, same per-element shapes.
    nest.assert_same_structure(output, inp_a)
    self.assertShapeEqual(np.zeros((3, 4)), output[0])
    self.assertShapeEqual(np.zeros((3, 7)), output[1])

    # Each placeholder tuple is fed with a matching tuple of random arrays.
    values_a = (np.random.randn(3, 4), np.random.randn(3, 7))
    values_b = (np.random.randn(3, 4), np.random.randn(3, 7))
    feed_dict = {inp_a: values_a, inp_b: values_b}

    with self.test_session() as sess:
        output_np = sess.run(output, feed_dict=feed_dict)

    self.assertAllClose(output_np[0],
                        feed_dict[inp_a][0] + feed_dict[inp_b][0])
    self.assertAllClose(output_np[1],
                        feed_dict[inp_a][1] + feed_dict[inp_b][1])
def testStructureMustBeSame(self):
    """Expects a `ValueError` when mapping over tuples of unequal length."""
    two_elements = (3, 4)
    three_elements = (42, 42, 44)
    err = "The two structures don't have the same number of elements."

    def add(a, b):
        return a + b

    with self.assertRaisesRegexp(ValueError, err):
        nest.map(add, two_elements, three_elements)
def _nested_add(nested_a, nested_b):
    """Sums two arbitrarily nested structures of `Tensors` element-wise.

    Args:
      nested_a: the left-hand nested structure.
      nested_b: the right-hand nested structure, mapped in parallel with
        `nested_a`.

    Returns:
      The structure produced by `nest.map`, holding `a + b` for each pair of
      corresponding elements.
    """

    def _add(a, b):
        return a + b

    return nest.map(_add, nested_a, nested_b)
def _nested_zeros_like(nested_a):
    """Applies `tf.zeros_like` to every element of the nested `nested_a`."""
    zeros = nest.map(tf.zeros_like, nested_a)
    return zeros
def testMultiNest(self):
    """Adds two nested tuples element-wise and checks the nested result."""
    left = (3, (4, 5))
    right = (42, (42, 44))

    def add(a, b):
        return a + b

    summed = nest.map(add, left, right)
    self.assertEqual((45, (46, 49)), summed)
def _nested_unary_mul(nested_a, p):
    """Scales every `Tensor` in the nested structure `nested_a` by `p`.

    Args:
      nested_a: an arbitrarily nested structure of `Tensor`s.
      p: the multiplier applied to each element of `nested_a`.

    Returns:
      The structure produced by `nest.map`, holding `p * a` for each element
      `a` of `nested_a`.
    """

    def _scale(a):
        return p * a

    return nest.map(_scale, nested_a)
def testNoSequences(self):
    """Expects a `ValueError` when `nest.map` is given no structure."""

    def identity(x):
        return x

    expected_error = "Must provide at least one structure"
    with self.assertRaisesRegexp(ValueError, expected_error):
        nest.map(identity)
def testNoSequences(self):
    """Expects a `ValueError` when `nest.map` is given no sequence."""

    def identity(x):
        return x

    expected_error = "Cannot map over no sequences"
    with self.assertRaisesRegexp(ValueError, expected_error):
        nest.map(identity)