Ejemplo n.º 1
0
    def testAttrsMapStructure(self, values):
        """Check that an identity `nest.map_structure` round-trips an attrs object.

        Args:
          values: positional field values used to build a
            ``NestTest.UnsortedSampleAttr`` (presumably supplied by a
            parameterization decorator outside this view -- TODO confirm).
        """
        # `attr` may be None (presumably when the optional attrs package
        # failed to import); skip instead of failing in that case.
        if attr is None:
            self.skipTest("attr module is unavailable.")

        def identity(leaf):
            return leaf

        original = NestTest.UnsortedSampleAttr(*values)
        mapped = nest.map_structure(identity, original)
        self.assertEqual(original, mapped)
Ejemplo n.º 2
0
    def testMapStructureOverPlaceholders(self):
        """Check that map_structure adds two tuples of placeholders element-wise.

        Verifies the static shapes of the mapped outputs and, after feeding
        random arrays through a session, the computed sums.
        """
        # Placeholders and feed_dicts do not work in eager mode, so every op is
        # built inside a dedicated graph.
        shapes = ([3, 4], [3, 7])
        with tf.Graph().as_default() as graph:
            inp_a = tuple(tf.compat.v1.placeholder(tf.float32, shape=shape)
                          for shape in shapes)
            inp_b = tuple(tf.compat.v1.placeholder(tf.float32, shape=shape)
                          for shape in shapes)

            def add(lhs, rhs):
                return lhs + rhs

            output = nest.map_structure(add, inp_a, inp_b)

            nest.assert_same_structure(output, inp_a)
            for shape, tensor in zip(shapes, output):
                self.assertShapeEqual(np.zeros(shape), tensor)

            # Each key is a tuple of placeholders; Session.run feeds each member
            # from the matching array in the value tuple.
            feed_dict = {
                inp_a: tuple(np.random.randn(*shape) for shape in shapes),
                inp_b: tuple(np.random.randn(*shape) for shape in shapes),
            }

            with self.session(graph=graph) as sess:
                output_np = sess.run(output, feed_dict=feed_dict)
            for actual, a_val, b_val in zip(output_np, feed_dict[inp_a],
                                            feed_dict[inp_b]):
                self.assertAllClose(actual, a_val + b_val)
Ejemplo n.º 3
0
    def testMapStructureWithStrings(self):
        """Check that map_structure hands each string to the mapped function whole.

        Covers a two-structure map (string repetition) and a one-structure map
        (string reversal) over ``NestTest.ABTuple`` instances.
        """
        def repeat(text, times):
            return text * times

        def reverse(text):
            return text[::-1]

        texts = NestTest.ABTuple(a="foo", b=("bar", "baz"))
        counts = NestTest.ABTuple(a=2, b=(1, 3))
        repeated = nest.map_structure(repeat, texts, counts)
        self.assertEqual("foofoo", repeated.a)
        self.assertEqual("bar", repeated.b[0])
        self.assertEqual("bazbazbaz", repeated.b[1])

        original = NestTest.ABTuple(a=("something", "something_else"),
                                    b="yet another thing")
        reversed_nt = nest.map_structure(reverse, original)
        # The nesting must be unchanged and every leaf string reversed.
        nest.assert_same_structure(original, reversed_nt)
        self.assertEqual(reverse(original.a[0]), reversed_nt.a[0])
        self.assertEqual(reverse(original.a[1]), reversed_nt.a[1])
        self.assertEqual(reverse(original.b), reversed_nt.b)
Ejemplo n.º 4
0
  def testMapStructureOverPlaceholders(self):
    """Check that map_structure adds two tuples of placeholders element-wise.

    Verifies the static shapes of the mapped outputs and, after feeding random
    arrays through the test's cached session, the computed sums.
    """
    # NOTE(review): tf.placeholder is a graph-mode API; this presumably relies
    # on the test harness running in graph mode -- confirm.
    shapes = ([3, 4], [3, 7])
    inp_a = tuple(tf.placeholder(tf.float32, shape=shape) for shape in shapes)
    inp_b = tuple(tf.placeholder(tf.float32, shape=shape) for shape in shapes)

    def add(lhs, rhs):
      return lhs + rhs

    output = nest.map_structure(add, inp_a, inp_b)

    nest.assert_same_structure(output, inp_a)
    for shape, tensor in zip(shapes, output):
      self.assertShapeEqual(np.zeros(shape), tensor)

    # Each key is a tuple of placeholders; Session.run feeds each member from
    # the matching array in the value tuple.
    feed_dict = {
        inp_a: tuple(np.random.randn(*shape) for shape in shapes),
        inp_b: tuple(np.random.randn(*shape) for shape in shapes),
    }

    with self.cached_session() as sess:
      output_np = sess.run(output, feed_dict=feed_dict)
    for actual, a_val, b_val in zip(output_np, feed_dict[inp_a],
                                    feed_dict[inp_b]):
      self.assertAllClose(actual, a_val + b_val)