Example #1
0
 def test_tensor(self):
     """A tensor computed from a random variable relays its child link.

     ``root + offset`` yields the tensor ``shifted``, which becomes the mean
     of ``child``. The random variable, the constant, and the derived tensor
     must each report ``child`` as their sole child; ``child`` has none.
     """
     with self.test_session():
         root = Normal(mu=0.0, sigma=1.0)
         offset = tf.constant(2.0)
         shifted = root + offset
         child = Normal(mu=shifted, sigma=1.0)
         for parent in (root, offset, shifted):
             self.assertEqual(get_children(parent), [child])
         self.assertEqual(get_children(child), [])
Example #2
0
 def test_tensor(self):
   """Every input to the mean of ``y`` reports ``y`` as its only child.

   ``x_plus_two`` is an ordinary tensor built from the random variable ``x``
   and the constant ``two``; all three feed ``y``, which itself has no
   children.
   """
   with self.test_session():
     x = Normal(0.0, 1.0)
     two = tf.constant(2.0)
     x_plus_two = x + two
     y = Normal(x_plus_two, 1.0)
     cases = [(x, [y]), (two, [y]), (x_plus_two, [y]), (y, [])]
     for node, children in cases:
       self.assertEqual(get_children(node), children)
Example #3
0
 def test_control_flow(self):
     """Child links propagate through a ``tf.cond``.

     The Bernoulli predicate, both branch values, and the cond output all
     feed the mean of ``observed``. Each must report ``observed`` as its sole
     child, and ``observed`` must report none.
     """
     with self.test_session():
         coin = Bernoulli(p=0.5)
         heads_value = Normal(mu=0.0, sigma=1.0)
         tails_value = tf.constant(0.0)
         chosen = tf.cond(tf.cast(coin, tf.bool),
                          lambda: heads_value,
                          lambda: tails_value)
         observed = Normal(mu=chosen, sigma=1.0)
         upstream = [coin, heads_value, tails_value, chosen]
         for node in upstream:
             self.assertEqual(get_children(node), [observed])
         self.assertEqual(get_children(observed), [])
Example #4
0
 def test_chain_structure(self):
     """Five Normals in a chain: each one's only child is its successor.

     The chain is built iteratively, with each new Normal's mean set to the
     previously created Normal; the last Normal must have no children.
     """
     with self.test_session():
         chain = [Normal(mu=0.0, sigma=1.0)]
         for _ in range(4):
             chain.append(Normal(mu=chain[-1], sigma=1.0))
         for upstream, downstream in zip(chain, chain[1:]):
             self.assertEqual(get_children(upstream), [downstream])
         self.assertEqual(get_children(chain[-1]), [])
Example #5
0
 def test_control_flow(self):
   """The predicate, both branches, and the output of tf.cond feed ``leaf``."""
   with self.test_session():
     switch = Bernoulli(0.5)
     when_true = Normal(0.0, 1.0)
     when_false = tf.constant(0.0)
     predicate = tf.cast(switch, tf.bool)
     merged = tf.cond(predicate, lambda: when_true, lambda: when_false)
     leaf = Normal(merged, 1.0)
     for node in [switch, when_true, when_false, merged]:
       self.assertEqual(get_children(node), [leaf])
     self.assertEqual(get_children(leaf), [])
Example #6
0
 def test_chain_structure(self):
   """Chain of five Normals, each centred on the previous one.

   The first Normal has mean 0.0; every node's only child is the next node,
   and the final node has no children.
   """
   with self.test_session():
     nodes = []
     prev = 0.0
     for _ in range(5):
       prev = Normal(mu=prev, sigma=1.0)
       nodes.append(prev)
     successors = nodes[1:]
     for node, nxt in zip(nodes, successors):
       self.assertEqual(get_children(node), [nxt])
     self.assertEqual(get_children(nodes[-1]), [])
 def test_chain_structure(self):
   """Linear chain a -> b -> c -> d -> e: each node's only child is the next."""
   with self.test_session():
     first = Normal(0.0, 1.0)
     second = Normal(first, 1.0)
     third = Normal(second, 1.0)
     fourth = Normal(third, 1.0)
     fifth = Normal(fourth, 1.0)
     expected_children = [
         (first, [second]),
         (second, [third]),
         (third, [fourth]),
         (fourth, [fifth]),
         (fifth, []),
     ]
     for node, expected in expected_children:
       self.assertEqual(get_children(node), expected)
 def test_a_structure(self):
   """Branching graph e <- d <- a -> b -> c.

   ``root`` has two children (``left`` and ``right``), compared as a set so
   their order is ignored; every other node has at most one child.
   """
   with self.test_session():
     root = Normal(0.0, 1.0)
     left = Normal(root, 1.0)
     left_leaf = Normal(left, 1.0)
     right = Normal(root, 1.0)
     right_leaf = Normal(right, 1.0)
     self.assertEqual(set(get_children(root)), {left, right})
     for node, expected in [(left, [left_leaf]), (left_leaf, []),
                            (right, [right_leaf]), (right_leaf, [])]:
       self.assertEqual(get_children(node), expected)
Example #9
0
    def test_scan(self):
        """Child links survive a ``tf.scan`` between random variables.

        Copied from test_chain_structure, except that each Normal's mean is
        the running sum (computed with ``tf.scan``) of the previous Normal.
        Each node's only child must be the next node; the last has none.
        """
        def cumsum(x):
            # Running (cumulative) sum of x computed with tf.scan; the lambda
            # parameters are named so they do not shadow the outer ``x``.
            return tf.scan(lambda acc, elem: acc + elem, x)

        with self.test_session():
            a = Normal(mu=tf.ones([3]), sigma=tf.ones([3]))
            b = Normal(mu=cumsum(a), sigma=tf.ones([3]))
            c = Normal(mu=cumsum(b), sigma=tf.ones([3]))
            d = Normal(mu=cumsum(c), sigma=tf.ones([3]))
            e = Normal(mu=cumsum(d), sigma=tf.ones([3]))
            self.assertEqual(get_children(a), [b])
            self.assertEqual(get_children(b), [c])
            self.assertEqual(get_children(c), [d])
            self.assertEqual(get_children(d), [e])
            self.assertEqual(get_children(e), [])
Example #10
0
  def test_scan(self):
    """Chain structure test where each link passes through ``tf.scan``.

    Each Normal after the first is centred on the running sum of the
    previous Normal; each node's only child is the next, the last has none.
    """
    def running_sum(values):
      return tf.scan(lambda total, item: total + item, values)

    with self.test_session():
      rvs = [Normal(tf.ones([3]), tf.ones([3]))]
      for _ in range(4):
        rvs.append(Normal(running_sum(rvs[-1]), tf.ones([3])))
      for parent, child in zip(rvs, rvs[1:]):
        self.assertEqual(get_children(parent), [child])
      self.assertEqual(get_children(rvs[-1]), [])