def testScanStruct(self):
  """Runs `extensions.scan` with a list carry and a 3-tuple of inputs.

  The carry is a list of two arrays with shapes (4,) and (2,). The scanned
  inputs are a tuple of a (5, 3) array, a (5,) array and `None`. The test
  checks that the container types (list for the carry, tuple for the
  outputs) and the resulting shapes come back as expected, and that the
  `None` output slot stays `None`.
  """
  # NOTE(review): another method with this exact name appears in this file.
  # If both are defined in the same class, the later definition shadows the
  # earlier one and only one of them runs -- confirm and rename if needed.
  rng = np.random.RandomState(0)
  closed_over = rng.randn(2)

  def step(carry, inputs):
    """Single scan step: checks per-step shapes and updates the carry."""
    carry_c, carry_g = carry
    in_a, in_e, in_h = inputs

    # Per-step slices must have lost their leading (scanned) axis.
    assert in_a.shape == (3,)
    assert in_e.shape == ()  # pylint: disable=g-explicit-bool-comparison
    assert carry_c.shape == (4,)
    assert carry_g.shape == (2,)

    sin_total = tf_np.sum(tf_np.sin(in_a))
    cos_total = tf_np.sum(tf_np.cos(carry_c))
    tan_total = tf_np.sum(tf_np.tan(closed_over))
    out_b = tf_np.cos(sin_total + cos_total + tan_total)
    out_f = tf_np.cos(in_a)

    next_c = tf_np.sin(carry_c * out_b)
    next_g = tf_np.sin(carry_g * out_b)

    assert out_b.shape == ()  # pylint: disable=g-explicit-bool-comparison
    assert out_f.shape == (3,)
    return [next_c, next_g], (out_b, out_f, in_h)

  # Keep the draw order (xs first, then init) so the random values match.
  xs_a = rng.randn(5, 3)
  xs_e = rng.randn(5)
  xs = (xs_a, xs_e, None)
  init_c = rng.randn(4)
  init_g = rng.randn(2)
  init = [init_c, init_g]

  final_carry, stacked = extensions.scan(step, init, xs)

  self.assertIsInstance(final_carry, list)
  self.assertIsInstance(stacked, tuple)
  final_c, final_g = final_carry
  stacked_b, stacked_f, stacked_h = stacked

  expected_shapes = (
      ((4,), final_c),
      ((2,), final_g),
      ((5,), stacked_b),
      ((5, 3), stacked_f),
  )
  for expected, array in expected_shapes:
    self.assertEqual(expected, array.shape)
  self.assertIsNone(stacked_h)
def testScanStruct(self):
  """Runs `extensions.scan` with nested carry and nested input structures.

  The carry is `[(c, g), None]` with `c` of shape (4,) and `g` of shape
  (2,). The scanned inputs are `(a, [e, None])` with `a` of shape (5, 3)
  and `e` of shape (5,). The test checks that the nested container types
  (list / tuple) are preserved in the results, that array leaves have the
  expected shapes, and that `None` leaves stay `None`.
  """
  # NOTE(review): another method with this exact name appears in this file.
  # If both are defined in the same class, the later definition shadows the
  # earlier one and only one of them runs -- confirm and rename if needed.
  rng = np.random.RandomState(0)
  closed_over = rng.randn(2)

  def step(carry, inputs):
    """Single scan step: checks per-step structure and updates the carry."""
    carry_pair, carry_none = carry
    carry_c, carry_g = carry_pair
    in_a, in_rest = inputs
    in_e, in_h = in_rest

    # Per-step slices must have lost their leading (scanned) axis.
    assert in_a.shape == (3,)
    assert in_e.shape == ()  # pylint: disable=g-explicit-bool-comparison
    assert carry_c.shape == (4,)
    assert carry_g.shape == (2,)
    # `None` leaves are expected to reach the step function unchanged.
    assert carry_none is None
    assert in_h is None

    sin_total = tf_np.sum(tf_np.sin(in_a))
    cos_total = tf_np.sum(tf_np.cos(carry_c))
    tan_total = tf_np.sum(tf_np.tan(closed_over))
    out_b = tf_np.cos(sin_total + cos_total + tan_total)
    out_f = tf_np.cos(in_a)

    next_c = tf_np.sin(carry_c * out_b)
    next_g = tf_np.sin(carry_g * out_b)

    assert out_b.shape == ()  # pylint: disable=g-explicit-bool-comparison
    assert out_f.shape == (3,)
    return [(next_c, next_g), carry_none], (out_b, [out_f, in_h])

  # Keep the draw order (xs first, then init) so the random values match.
  xs_a = rng.randn(5, 3)
  xs_e = rng.randn(5)
  xs = (xs_a, [xs_e, None])
  init_c = rng.randn(4)
  init_g = rng.randn(2)
  init = [(init_c, init_g), None]

  final_carry, stacked = extensions.scan(step, init, xs)

  self.assertIsInstance(final_carry, list)
  self.assertIsInstance(stacked, tuple)

  # Carry side: [(c, g), None].
  final_pair, final_none = final_carry
  final_c, final_g = final_pair
  self.assertIsInstance(final_pair, tuple)
  for expected, array in (((4,), final_c), ((2,), final_g)):
    self.assertEqual(expected, array.shape)
  self.assertIsNone(final_none)

  # Output side: (b, [f, None]).
  stacked_b, stacked_rest = stacked
  stacked_f, stacked_h = stacked_rest
  self.assertIsInstance(stacked_rest, list)
  for expected, array in (((5,), stacked_b), ((5, 3), stacked_f)):
    self.assertEqual(expected, array.shape)
  self.assertIsNone(stacked_h)