コード例 #1
0
ファイル: extensions_test.py プロジェクト: yliu45/trax
 def f(c, a):
   """Step function mapping (carry, x) -> (new_carry, y).

   Appears to be a scan-style body; its caller is not visible here.
   `tf_np` and `d` are not defined in this function and are resolved from
   an outer scope outside this view.

   Args:
     c: carry; asserted to have shape (4,).
     a: per-step input; asserted to have shape (3,).

   Returns:
     A tuple (sin(c * y), y) where
     y = sum(sin(a)) + sum(sin(c)) + sum(sin(d)), asserted to be rank-0.
   """
   assert a.shape == (3,)
   assert c.shape == (4,)
   sum_sin_a = tf_np.sum(tf_np.sin(a))
   sum_sin_c = tf_np.sum(tf_np.sin(c))
   sum_sin_d = tf_np.sum(tf_np.sin(d))
   # Left-to-right association, i.e. ((A + C) + D), as in the original.
   y = sum_sin_a + sum_sin_c + sum_sin_d
   next_carry = tf_np.sin(c * y)
   # Comparing against `()` is deliberate: it checks that y is rank-0.
   assert y.shape == ()  # pylint: disable=g-explicit-bool-comparison
   return next_carry, y
コード例 #2
0
 def f(c_g, a_e):
   """Step function over a 2-item carry and a 2-item input.

   `tf_np` and `d` are not defined in this function and are resolved from
   an outer scope outside this view.

   Args:
     c_g: carry (c, g); c asserted shape (4,), g asserted shape (2,).
     a_e: input (a, e); a asserted shape (3,), e asserted rank-0.
       `e` is only shape-checked; it does not feed the computation.

   Returns:
     ([sin(c * y), sin(g * y)], (y, cos(a))), i.e. a list carry and a
     tuple output, where y = cos(sum(sin(a)) + sum(cos(c)) + sum(tan(d))).
   """
   (c, g), (a, e) = c_g, a_e
   assert a.shape == (3,)
   assert e.shape == ()  # pylint: disable=g-explicit-bool-comparison
   assert c.shape == (4,)
   assert g.shape == (2,)
   summed = (tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.cos(c)) +
             tf_np.sum(tf_np.tan(d)))
   y = tf_np.cos(summed)
   # Named `cos_a` rather than `f` so it does not shadow this function's name.
   cos_a = tf_np.cos(a)
   next_c = tf_np.sin(c * y)
   next_g = tf_np.sin(g * y)
   assert y.shape == ()  # pylint: disable=g-explicit-bool-comparison
   assert cos_a.shape == (3,)
   # Keep the exact container types: list for the carry, tuple for the output.
   next_carry = [next_c, next_g]
   output = (y, cos_a)
   return next_carry, output
コード例 #3
0
ファイル: extensions_test.py プロジェクト: yliu45/trax
 def f(c_g_i, a_e_h):
   """Step function over nested structures containing `None` leaves.

   `tf_np` and `d` are not defined in this function and are resolved from
   an outer scope outside this view.

   Args:
     c_g_i: carry ((c, g), i); c asserted shape (4,), g asserted shape (2,),
       i asserted to be None.
     a_e_h: input (a, (e, h)); a asserted shape (3,), e asserted rank-0,
       h asserted to be None. `e` is only shape-checked.

   Returns:
     ([(sin(c * y), sin(g * y)), i], (y, [cos(a), h])), where
     y = cos(sum(sin(a)) + sum(cos(c)) + sum(tan(d))). The `None` leaves
     i and h are passed through unchanged.
   """
   (c, g), i = c_g_i
   a, (e, h) = a_e_h
   assert a.shape == (3,)
   assert e.shape == ()  # pylint: disable=g-explicit-bool-comparison
   assert c.shape == (4,)
   assert g.shape == (2,)
   assert i is None
   assert h is None
   summed = (tf_np.sum(tf_np.sin(a)) + tf_np.sum(tf_np.cos(c)) +
             tf_np.sum(tf_np.tan(d)))
   y = tf_np.cos(summed)
   # Named `cos_a` rather than `f` so it does not shadow this function's name.
   cos_a = tf_np.cos(a)
   next_c = tf_np.sin(c * y)
   next_g = tf_np.sin(g * y)
   assert y.shape == ()  # pylint: disable=g-explicit-bool-comparison
   assert cos_a.shape == (3,)
   # Mirror the input nesting exactly (tuple inside list, list inside tuple).
   next_carry = [(next_c, next_g), i]
   output = (y, [cos_a, h])
   return next_carry, output