コード例 #1
0
 def test_constant_folding(self):
   """Check that Grappler reduces the mel weight matrix graph to one node."""
   # TODO(rjryan): tf.bfloat16 cannot be constant folded by Grappler.
   foldable_dtypes = (dtypes.float32, dtypes.float64)
   for dtype in foldable_dtypes:
     # A fresh graph per dtype keeps the node count independent of earlier
     # iterations.
     graph = ops.Graph()
     with graph.as_default():
       # Only dtype is supplied; every other argument keeps its default.
       weights = mel_ops.linear_to_mel_weight_matrix(dtype=dtype)
       optimized = test_util.grappler_optimize(graph, [weights])
       self.assertEqual(1, len(optimized.node))
コード例 #2
0
 def test_constant_folding(self, window_fn, periodic, tf_dtype_tol):
     """Verify a window built from constant inputs is optimized to one node.

     Args:
       window_fn: Callable under test; invoked as
         `window_fn(100, periodic=..., dtype=...)`.
       periodic: Forwarded to `window_fn` as its `periodic` argument.
       tf_dtype_tol: Indexable whose first element is used as the dtype.
     """
     # Nothing is checked when running eagerly.
     if context.executing_eagerly():
         return
     graph = ops.Graph()
     with graph.as_default():
         window_dtype = tf_dtype_tol[0]
         window_tensor = window_fn(
             100, periodic=periodic, dtype=window_dtype)
         optimized = test_util.grappler_optimize(graph, [window_tensor])
         self.assertLen(optimized.node, 1)
コード例 #3
0
ファイル: window_ops_test.py プロジェクト: Wajih-O/tensorflow
 def test_constant_folding(self):
   """Hann and Hamming windows with constant inputs should fold to one node."""
   # Enumerate every (window function, dtype, periodic) combination up front
   # so the checking loop below stays flat.
   cases = [
       (make_window, dtype, is_periodic)
       for make_window in (window_ops.hann_window, window_ops.hamming_window)
       for dtype, _ in self._dtypes
       for is_periodic in (False, True)
   ]
   for make_window, dtype, is_periodic in cases:
     # Each combination is optimized in its own graph.
     graph = ops.Graph()
     with graph.as_default():
       window_tensor = make_window(100, periodic=is_periodic, dtype=dtype)
       optimized = test_util.grappler_optimize(graph, [window_tensor])
       self.assertEqual(1, len(optimized.node))
コード例 #4
0
 def test_constant_folding(self):
     """Each window function should collapse to one node after optimization.

     Covers hann_window and hamming_window for every dtype listed in
     `self._dtypes`, with `periodic` both off and on.
     """
     window_makers = (window_ops.hann_window, window_ops.hamming_window)
     periodic_options = (False, True)
     for make_window in window_makers:
         # The second element of each `_dtypes` entry is not needed here.
         for dtype, _ in self._dtypes:
             for is_periodic in periodic_options:
                 graph = ops.Graph()
                 with graph.as_default():
                     window_tensor = make_window(
                         100, periodic=is_periodic, dtype=dtype)
                     optimized = test_util.grappler_optimize(
                         graph, [window_tensor])
                     self.assertEqual(1, len(optimized.node))
コード例 #5
0
 def test_constant_folding(self, dtype):
     """Verify the mel weight matrix graph is optimized down to one node.

     Args:
       dtype: Forwarded to `linear_to_mel_weight_matrix` as its `dtype`.
     """
     # Nothing is checked when running eagerly.
     if context.executing_eagerly():
         return
     # TODO(rjryan): tf.bfloat16 cannot be constant folded by Grappler.
     graph = ops.Graph()
     with graph.as_default():
         # The sample rate is fed as a float32 constant tensor rather than a
         # Python number.
         sample_rate = constant_op.constant(8000.0, dtype=dtypes.float32)
         weights = mel_ops.linear_to_mel_weight_matrix(
             sample_rate=sample_rate, dtype=dtype)
         optimized = test_util.grappler_optimize(graph, [weights])
         self.assertLen(optimized.node, 1)
コード例 #6
0
 def test_constant_folding(self):
   """Framing a constant signal should optimize to one node, padded or not."""
   # Framing parameters are the same for both pad_end settings.
   frame_length = 32
   frame_step = 16
   signal_shape = (2, 128)
   for pad_end in (True, False):
     graph = ops.Graph()
     with graph.as_default():
       ones_signal = array_ops.ones(signal_shape)
       framed = shape_ops.frame(
           ones_signal, frame_length, frame_step, pad_end=pad_end)
       optimized = test_util.grappler_optimize(graph, [framed])
       self.assertEqual(1, len(optimized.node))
コード例 #7
0
 def test_constant_folding(self):
   """Check that `frame` on an all-ones input optimizes to a single node.

   Runs once with `pad_end=True` and once with `pad_end=False`, each in its
   own graph.
   """
   for pad_end in (True, False):
     graph = ops.Graph()
     with graph.as_default():
       # (2, 128) all-ones input, framed with length 32 and step 16.
       ones_signal = array_ops.ones((2, 128))
       framed = shape_ops.frame(ones_signal, 32, 16, pad_end=pad_end)
       optimized = test_util.grappler_optimize(graph, [framed])
       self.assertEqual(1, len(optimized.node))