Beispiel #1
0
    def testScanStruct(self):
        """Tests `extensions.scan` with a list carry and a tuple of scanned inputs.

        The carry is the list `[c, g]` and the scanned inputs are the tuple
        `(a, e, h)`, where `h` is `None`. The test checks:

        - the per-step shapes seen inside the step function;
        - that the `None` entry reaches the step function as `None`;
        - that the final carry comes back as a list and the stacked outputs
          as a tuple;
        - that the `None` output stays `None`.
        """
        rng = np.random.RandomState(0)

        # Closed-over constant used by every step; it is not scanned over.
        d = rng.randn(2)

        def step(c_g, a_e_h):
            c, g = c_g
            a, e, h = a_e_h
            # Scanned leaves arrive with their leading (length-5) axis removed.
            assert a.shape == (3, )
            assert e.shape == ()  # pylint: disable=g-explicit-bool-comparison
            assert c.shape == (4, )
            assert g.shape == (2, )
            # The `None` entry of xs must be passed through unchanged, as the
            # nested-structure variant of this test also requires.
            assert h is None
            b = tf_np.cos(
                tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.cos(c)) +
                tf_np.sum(tf_np.tan(d)))
            f = tf_np.cos(a)
            c = tf_np.sin(c * b)
            g = tf_np.sin(g * b)
            assert b.shape == ()  # pylint: disable=g-explicit-bool-comparison
            assert f.shape == (3, )
            # New carry keeps the list structure; per-step outputs keep the
            # tuple structure (including the `None` leaf).
            return [c, g], (b, f, h)

        xs = (rng.randn(5, 3), rng.randn(5), None)
        init = [rng.randn(4), rng.randn(2)]

        c_g, b_f_h = extensions.scan(step, init, xs)
        self.assertIsInstance(c_g, list)
        self.assertIsInstance(b_f_h, tuple)
        c, g = c_g
        b, f, h = b_f_h
        # Carry keeps its per-step shapes; outputs gain the leading scan axis.
        self.assertEqual((4, ), c.shape)
        self.assertEqual((2, ), g.shape)
        self.assertEqual((5, ), b.shape)
        self.assertEqual((5, 3), f.shape)
        self.assertIsNone(h)
Beispiel #2
0
    def testScanStruct(self):
        """Tests `extensions.scan` on nested list/tuple carry and input structures.

        The carry has the layout `[(c, g), i]` with `i` set to `None`. The
        scanned inputs have the layout `(a, [e, h])` with `h` set to `None`.
        The test checks:

        - per-step shapes inside the step function;
        - that both `None` entries are passed through;
        - that every list/tuple container survives the scan unchanged.
        """
        rng = np.random.RandomState(0)

        # Captured by the step function rather than scanned over.
        d = rng.randn(2)

        def step_fn(carry, x):
            (c, g), i = carry
            a, (e, h) = x
            assert a.shape == (3, )
            assert e.shape == ()  # pylint: disable=g-explicit-bool-comparison
            assert c.shape == (4, )
            assert g.shape == (2, )
            assert i is None
            assert h is None
            total = (tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.cos(c)) +
                     tf_np.sum(tf_np.tan(d)))
            b = tf_np.cos(total)
            f = tf_np.cos(a)
            new_c = tf_np.sin(c * b)
            new_g = tf_np.sin(g * b)
            assert b.shape == ()  # pylint: disable=g-explicit-bool-comparison
            assert f.shape == (3, )
            return [(new_c, new_g), i], (b, [f, h])

        # Draw order matters for the seeded RNG: xs leaves first, then init.
        xs = (rng.randn(5, 3), [rng.randn(5), None])
        init = [(rng.randn(4), rng.randn(2)), None]

        final_carry, stacked = extensions.scan(step_fn, init, xs)
        self.assertIsInstance(final_carry, list)
        self.assertIsInstance(stacked, tuple)

        # Final carry: [(c, g), i] with per-step shapes unchanged.
        cg_pair, i_out = final_carry
        c_out, g_out = cg_pair
        self.assertIsInstance(cg_pair, tuple)
        self.assertEqual((4, ), c_out.shape)
        self.assertEqual((2, ), g_out.shape)
        self.assertIsNone(i_out)

        # Stacked outputs: (b, [f, h]); array leaves gain a length-5 axis.
        b_out, fh_pair = stacked
        f_out, h_out = fh_pair
        self.assertIsInstance(fh_pair, list)
        self.assertEqual((5, ), b_out.shape)
        self.assertEqual((5, 3), f_out.shape)
        self.assertIsNone(h_out)