コード例 #1
0
 def test_num_spectrogram_bins_dynamic(self):
     """Checks the mel matrix when num_spectrogram_bins is given as a tensor.

     The bin count (129) is wrapped in placeholder_with_default rather than
     passed as a Python int, and the result is compared against the reference
     spectrogram_to_mel_matrix output.
     """
     bins = array_ops.placeholder_with_default(
         ops.convert_to_tensor(129, dtype=dtypes.int32), shape=())
     expected = spectrogram_to_mel_matrix(
         20, 129, 8000.0, 125.0, 3800.0)
     actual = mel_ops.linear_to_mel_weight_matrix(
         20, bins, 8000.0, 125.0, 3800.0)
     self.assertAllClose(expected, actual, atol=3e-6)
コード例 #2
0
 def test_constant_folding(self):
   """Mel functions should be constant foldable.

   For each tested dtype, builds the default mel weight matrix in a fresh
   graph, runs Grappler over it, and expects exactly one node to remain.
   """
   # TODO(rjryan): tf.bfloat16 cannot be constant folded by Grappler.
   for dtype in (dtypes.float32, dtypes.float64):
     # subTest labels a failure with the dtype that caused it.
     with self.subTest(dtype=dtype):
       g = ops.Graph()
       with g.as_default():
         mel_matrix = mel_ops.linear_to_mel_weight_matrix(dtype=dtype)
         rewritten_graph = test_util.grappler_optimize(g, [mel_matrix])
         self.assertEqual(1, len(rewritten_graph.node))
コード例 #3
0
 def test_num_spectrogram_bins_dynamic(self):
   """Checks the mel matrix when num_spectrogram_bins is fed at run time.

   The bin count is an int32 scalar placeholder whose value (129) is supplied
   only via feed_dict when the matrix is evaluated.
   """
   with self.session(use_gpu=True):
     bins = array_ops.placeholder(shape=(), dtype=dtypes.int32)
     expected = spectrogram_to_mel_matrix(20, 129, 8000.0, 125.0, 3800.0)
     weights = mel_ops.linear_to_mel_weight_matrix(
         20, bins, 8000.0, 125.0, 3800.0)
     actual = weights.eval(feed_dict={bins: 129})
     self.assertAllClose(expected, actual, atol=3e-6)
コード例 #4
0
 def test_constant_folding(self, dtype):
     """Mel functions should be constant foldable.

     Builds the weight matrix for `dtype` (with sample_rate supplied as a
     float32 constant tensor) inside a fresh graph, optimizes it with
     Grappler, and expects exactly one node to remain.
     """
     # The check constructs an explicit Graph for Grappler, so eager runs
     # return without asserting anything.
     if context.executing_eagerly():
         return
     # TODO(rjryan): tf.bfloat16 cannot be constant folded by Grappler.
     graph = ops.Graph()
     with graph.as_default():
         rate = constant_op.constant(8000.0, dtype=dtypes.float32)
         weights = mel_ops.linear_to_mel_weight_matrix(
             sample_rate=rate, dtype=dtype)
         optimized = test_util.grappler_optimize(graph, [weights])
         self.assertLen(optimized.node, 1)
コード例 #5
0
 def test_matches_reference_implementation(
     self, num_mel_bins, num_spectrogram_bins, sample_rate,
     use_tensor_sample_rate, lower_edge_hertz, upper_edge_hertz, dtype):
   """Compares linear_to_mel_weight_matrix with spectrogram_to_mel_matrix.

   When use_tensor_sample_rate is set, the sample rate is wrapped in a
   constant tensor before being handed to both implementations.
   """
   rate = (constant_op.constant(sample_rate) if use_tensor_sample_rate
           else sample_rate)
   # Both implementations receive the identical positional arguments.
   args = (num_mel_bins, num_spectrogram_bins, rate, lower_edge_hertz,
           upper_edge_hertz, dtype)
   expected = spectrogram_to_mel_matrix(*args)
   actual = mel_ops.linear_to_mel_weight_matrix(*args)
   self.assertAllClose(expected, actual, atol=3e-6)
コード例 #6
0
 def test_matches_reference_implementation(self):
   """Compares linear_to_mel_weight_matrix with spectrogram_to_mel_matrix.

   Each config tuple is passed positionally to both implementations and the
   evaluated results must agree to within atol=3e-6.
   """
   # Tuples of (num_mel_bins, num_spectrogram_bins, sample_rate,
   # lower_edge_hertz, upper_edge_hertz, dtype) to test.
   configs = [
       # Defaults.
       (20, 129, 8000.0, 125.0, 3800.0, dtypes.float64),
       # Settings used by Tacotron (https://arxiv.org/abs/1703.10135).
       (80, 1025, 24000.0, 80.0, 12000.0, dtypes.float64)
   ]
   with self.session(use_gpu=True):
     for config in configs:
       # subTest labels a failure with the config that caused it.
       with self.subTest(config=config):
         mel_matrix_np = spectrogram_to_mel_matrix(*config)
         mel_matrix = mel_ops.linear_to_mel_weight_matrix(*config)
         self.assertAllClose(mel_matrix_np, mel_matrix.eval(), atol=3e-6)
コード例 #7
0
 def test_matches_reference_implementation(self):
     """Checks linear_to_mel_weight_matrix against spectrogram_to_mel_matrix.

     Each case tuple is passed positionally to both implementations and the
     results must agree to within atol=3e-6.
     """
     # Each case is (num_mel_bins, num_spectrogram_bins, sample_rate,
     # lower_edge_hertz, upper_edge_hertz, dtype).
     cases = (
         # Defaults.
         (20, 129, 8000.0, 125.0, 3800.0, dtypes.float64),
         # Defaults again, but with sample_rate given as a constant Tensor.
         (20, 129, constant_op.constant(8000.0), 125.0, 3800.0,
          dtypes.float64),
         # Settings used by Tacotron (https://arxiv.org/abs/1703.10135).
         (80, 1025, 24000.0, 80.0, 12000.0, dtypes.float64),
     )
     for params in cases:
         expected = spectrogram_to_mel_matrix(*params)
         actual = mel_ops.linear_to_mel_weight_matrix(*params)
         self.assertAllClose(expected, actual, atol=3e-6)
コード例 #8
0
 def test_error(self):
   """Each invalid argument combination must raise ValueError."""
   invalid_kwargs = (
       dict(num_mel_bins=0),
       dict(num_spectrogram_bins=0),
       dict(sample_rate=0.0),
       dict(lower_edge_hertz=-1),
       # Lower edge above the upper edge.
       dict(lower_edge_hertz=100, upper_edge_hertz=10),
       # upper_edge_hertz=1000 combined with sample_rate=800.
       dict(upper_edge_hertz=1000, sample_rate=800),
       # Integer dtype.
       dict(dtype=dtypes.int32),
   )
   for kwargs in invalid_kwargs:
     with self.assertRaises(ValueError):
       mel_ops.linear_to_mel_weight_matrix(**kwargs)
コード例 #9
0
 def test_dtypes(self):
   """The returned weight matrix has the dtype that was requested."""
   # LinSpace is not supported for tf.float16.
   supported = (dtypes.bfloat16, dtypes.float32, dtypes.float64)
   for requested in supported:
     weights = mel_ops.linear_to_mel_weight_matrix(dtype=requested)
     self.assertEqual(requested, weights.dtype)
コード例 #10
0
 def test_error(self):
     """Each invalid argument combination must raise ValueError (graph mode)."""
     # TODO(rjryan): Error types are different under eager.
     if context.executing_eagerly():
         return
     bad_arguments = [
         {'num_mel_bins': 0},
         {'sample_rate': 0.0},
         {'lower_edge_hertz': -1},
         # Lower edge above the upper edge.
         {'lower_edge_hertz': 100, 'upper_edge_hertz': 10},
         # upper_edge_hertz=1000 combined with sample_rate=800.
         {'upper_edge_hertz': 1000, 'sample_rate': 800},
         # Integer dtype.
         {'dtype': dtypes.int32},
     ]
     for arguments in bad_arguments:
         with self.assertRaises(ValueError):
             mel_ops.linear_to_mel_weight_matrix(**arguments)
コード例 #11
0
 def test_dtypes(self, dtype):
     """The returned weight matrix has the dtype that was requested."""
     # LinSpace is not supported for tf.float16.
     weights = mel_ops.linear_to_mel_weight_matrix(dtype=dtype)
     self.assertEqual(dtype, weights.dtype)