예제 #1
0
    def test_read_wav_bad_arg(self):
        """read_wav_file() rejects an argument that is not file-like."""
        # The class name matters: the expected TypeError message mentions it.
        class Nonsense(object):
            pass

        bogus_input = Nonsense()
        with self.assertRaisesRegex(TypeError, 'Nonsense found'):
            dsp.read_wav_file(bogus_input)
예제 #2
0
    def test_read_wav_reader_raises_exception(self):
        """An OSError raised by the reader's read method propagates unchanged."""
        error_message = 'read method failed'

        def _failing_read(unused_size):
            raise OSError(error_message)

        failing_reader = MockReader(_failing_read)
        with self.assertRaisesRegex(OSError, error_message):
            dsp.read_wav_file(failing_reader)
예제 #3
0
    def test_read_wav_given_memory_stream(self):
        """read_wav_file() decodes WAV data held in an io.BytesIO stream."""
        in_memory_wav = io.BytesIO(self.wav_bytes)
        decoded, rate_hz = dsp.read_wav_file(in_memory_wav)

        # Expect the fixture's int16 samples, unchanged, at 48 kHz.
        self.assertEqual(decoded.dtype, np.int16)
        np.testing.assert_array_equal(decoded, self.wav_samples)
        self.assertEqual(rate_hz, 48000)
예제 #4
0
    def test_read_wav_given_filename(self):
        """read_wav_file() accepts a path string and reads the file it names."""
        decoded, rate_hz = dsp.read_wav_file(self.read_filename)

        # Expect the fixture's int16 samples, unchanged, at 48 kHz.
        self.assertEqual(decoded.dtype, np.int16)
        np.testing.assert_array_equal(decoded, self.wav_samples)
        self.assertEqual(rate_hz, 48000)
예제 #5
0
def main(_):
    """Plots phoneme classifier scores over time for the WAV file FLAGS.input.

    The audio is averaged over axis 1 (channels) to mono, resampled to
    CLASSIFIER_INPUT_HZ if needed, run through the CARL frontend, and
    classified window by window. Figures are saved to FLAGS.output (with
    "-combined" and "-phoneme" inserted before the extension) or, when
    FLAGS.output is empty, shown interactively.

    Returns:
      0.
    """
    samples, sample_rate_hz = dsp.read_wav_file(FLAGS.input, dtype=np.float32)
    mono = samples.mean(axis=1)

    # The classifier expects input sample rate CLASSIFIER_INPUT_HZ,
    # block_size=128, pcen_cross_channel_diffusivity=60, and otherwise the
    # default frontend settings.
    carl = frontend.CarlFrontend(input_sample_rate_hz=CLASSIFIER_INPUT_HZ,
                                 block_size=128,
                                 pcen_cross_channel_diffusivity=60.0)
    if sample_rate_hz != CLASSIFIER_INPUT_HZ:
        to_classifier_rate = dsp.Resampler(sample_rate_hz, CLASSIFIER_INPUT_HZ)
        mono = to_classifier_rate.process_samples(mono)
    frames = phone_util.run_frontend(carl, mono)
    # One frame per block: 125Hz (8ms hop) per the original settings.
    frame_rate = CLASSIFIER_INPUT_HZ / carl.block_size

    # Collect per-window classifier scores into time series.
    timeseries = {}
    for window in sliding_window(frames, classify_phoneme.NUM_FRAMES):
        window_scores = classify_phoneme.classify_phoneme_scores(window)
        append_to_dict(timeseries, window_scores)

    title = os.path.basename(FLAGS.input)
    fig_combined, fig_phoneme = plot_output(frames, frame_rate, timeseries,
                                            title)

    if not FLAGS.output:
        plt.show()  # No output path: show the plot interactively.
        return 0

    stem, ext = os.path.splitext(FLAGS.output)
    plot.save_figure(f'{stem}-combined{ext}', fig_combined)
    plot.save_figure(f'{stem}-phoneme{ext}', fig_phoneme)
    return 0
예제 #6
0
    def test_read_wav_given_local_file_object(self):
        """read_wav_file() accepts a local file object opened in binary mode."""
        with open(self.read_filename, 'rb') as wav_file:
            decoded, rate_hz = dsp.read_wav_file(wav_file)

        # Expect the fixture's int16 samples, unchanged, at 48 kHz.
        self.assertEqual(decoded.dtype, np.int16)
        np.testing.assert_array_equal(decoded, self.wav_samples)
        self.assertEqual(rate_hz, 48000)
예제 #7
0
    def test_write_wav_local_file(self):
        """Round-trips samples through write_wav_file() and read_wav_file().

        Writes self.wav_samples at 44100 Hz to a file in self.temp_dir, reads
        it back, checks that samples and sample rate survive unchanged, and
        always removes the file afterwards.
        """
        # Build the path before entering `try`, so the `finally` cleanup can
        # never reference an unbound name (a NameError there would mask the
        # real failure).
        write_filename = os.path.join(self.temp_dir, 'write.wav')
        try:
            dsp.write_wav_file(write_filename, self.wav_samples, 44100)

            samples, sample_rate_hz = dsp.read_wav_file(write_filename)
            np.testing.assert_array_equal(samples, self.wav_samples)
            self.assertEqual(sample_rate_hz, 44100)
        finally:
            if os.path.isfile(write_filename):
                os.remove(write_filename)
예제 #8
0
    def _run_phone(self, phone):
        """Checks embed_vowel on the test recording of a single vowel `phone`.

        Verifies that the embed_vowel target tables have consistent lengths,
        that embed_vowel_scores() agrees with the embed_vowel() coordinates,
        and that the embedded coordinates are usually nearest the intended
        target.

        Args:
          phone: String, phone name; must be present in
            embed_vowel.TARGET_NAMES.
        """
        num_targets = embed_vowel.NUM_TARGETS
        self.assertEqual(len(embed_vowel.TARGET_NAMES), num_targets)
        self.assertEqual(len(embed_vowel.TARGET_COORDS), num_targets)
        phone_index = embed_vowel.TARGET_NAMES.index(phone)

        recording, _ = dsp.read_wav_file(
            f'extras/test/testdata/phone_{phone}.wav', dtype=np.float32)
        mono = recording.mean(axis=1)

        carl = frontend.CarlFrontend()
        self.assertEqual(carl.num_channels, embed_vowel.NUM_CHANNELS)
        # Drop trailing samples so the length is a multiple of block_size.
        usable_length = len(mono) - len(mono) % carl.block_size
        frames = carl.process_samples(mono[:usable_length])
        coords = embed_vowel.embed_vowel(frames)

        # distances[i, j] is the L2 distance from frame i's coordinate to
        # target j (broadcast over frames and targets, norm over the last axis).
        offsets = (coords[:, np.newaxis, :] -
                   embed_vowel.TARGET_COORDS[np.newaxis, :, :])
        distances = np.linalg.norm(offsets, axis=-1)

        # embed_vowel_scores() must match exp(-4 * distance) to each target.
        scores = embed_vowel.embed_vowel_scores(frames)
        self.assertEqual(scores.shape, (frames.shape[0], num_targets))
        np.testing.assert_array_equal(scores.argmax(axis=1),
                                      distances.argmin(axis=1))
        np.testing.assert_allclose(scores, np.exp(-4.0 * distances),
                                   atol=1e-6)

        distance_from_intended = distances[:, phone_index]
        # Mean distance from the intended target is not too large.
        self.assertLessEqual(distance_from_intended.mean(), 0.35)

        is_other_target = np.arange(num_targets) != phone_index
        min_distance_from_other = distances[:, is_other_target].min(axis=1)
        num_closest = np.count_nonzero(
            distance_from_intended < min_distance_from_other)
        accuracy = float(num_closest) / len(frames)
        # Coordinate is closest to the intended target >= 70% of the time.
        self.assertGreaterEqual(accuracy, 0.7)
    def _run_phoneme(self, phoneme: str) -> None:
        """A forgiving test that the classifier is basically working.

        Runs CarlFrontend + ClassifyPhoneme on a short WAV recording of a pure
        phone, and checks that a moderately confident score is sometimes given
        to the correct label.

        Args:
          phoneme: String, name of the phoneme to test.
        """
        wav_file = f'extras/test/testdata/phone_{phoneme}.wav'
        samples, sample_rate_hz = dsp.read_wav_file(wav_file, dtype=np.float32)
        samples = samples.mean(axis=1)  # Average channels to mono.
        self.assertEqual(sample_rate_hz, CLASSIFIER_INPUT_HZ)

        # Run frontend to get CARL frames. The classifier expects input sample rate
        # CLASSIFIER_INPUT_HZ, block_size=128, pcen_cross_channel_diffusivity=60,
        # and otherwise the default frontend settings.
        carl = frontend.CarlFrontend(input_sample_rate_hz=CLASSIFIER_INPUT_HZ,
                                     block_size=128,
                                     pcen_cross_channel_diffusivity=60.0)
        self.assertEqual(carl.num_channels, classify_phoneme.NUM_CHANNELS)
        # Drop trailing samples so the length is a multiple of block_size.
        samples = samples[:len(samples) - len(samples) % carl.block_size]
        frames = carl.process_samples(samples)

        count_correct = 0
        count_total = 0
        for window in sliding_window(frames, classify_phoneme.NUM_FRAMES):
            scores = classify_phoneme.classify_phoneme_scores(window)
            # Count as "correct" if correct label's score is moderately confident.
            count_correct += (scores['phoneme'][phoneme] > 0.1)
            count_total += 1

        # If sliding_window() yielded nothing, `scores` would be unbound and the
        # accuracy division below would divide by zero; fail clearly instead.
        self.assertGreater(count_total, 0,
                           f'No classifier windows for {wav_file}')

        # The last window's scores show the label sets the classifier reports.
        self.assertCountEqual(scores['phoneme'].keys(),
                              classify_phoneme.PHONEMES)
        self.assertCountEqual(scores['manner'].keys(),
                              classify_phoneme.MANNERS)
        self.assertCountEqual(scores['place'].keys(), classify_phoneme.PLACES)

        accuracy = float(count_correct) / count_total
        self.assertGreaterEqual(accuracy, 0.6)
예제 #10
0
def main(_):
    """Plots phone-network scores over time for the WAV file FLAGS.input.

    The audio is averaged over axis 1 (channels) to mono and run through a
    CARL frontend at the file's own sample rate. The network loaded from
    FLAGS.model is then applied to the frames. The figure is saved to
    FLAGS.output or, when that is empty, shown interactively.

    Returns:
      0.
    """
    samples, sample_rate_hz = dsp.read_wav_file(FLAGS.input, dtype=np.float32)
    mono = samples.mean(axis=1)

    carl = frontend.CarlFrontend(input_sample_rate_hz=sample_rate_hz)
    phone_net, target_names = get_phone_net(FLAGS.model)

    frames = phone_util.run_frontend(carl, mono)
    frame_rate = sample_rate_hz / carl.block_size  # One frame per block.
    scores = phone_net(frames)

    title = os.path.basename(FLAGS.input) + '\n' + FLAGS.model
    fig = plot_output(frames, frame_rate, scores, target_names, title)

    if not FLAGS.output:
        plt.show()  # No output path: show the plot interactively.
        return 0

    plot.save_figure(FLAGS.output, fig)
    return 0
    def test_phone(self):
        """TactileProcessor drives the intended tactor hardest for each vowel.

        For several vowel recordings and decimation factors, streams the first
        audio channel through TactileProcessor block by block, accumulates
        per-tactor output energy, and checks that the intended tactor's energy
        is at least 1.65x that of every other tactor in range(1, 8).
        """
        vowel_cases = [('aa', 1), ('eh', 5), ('uw', 2)]
        for phone, intended_tactor in vowel_cases:
            recording, rate_hz = dsp.read_wav_file(
                f'extras/test/testdata/phone_{phone}.wav')
            channel = recording[:, 0]

            for decimation_factor in [1, 2, 4, 8]:
                processor = tactile_processor.TactileProcessor(
                    input_sample_rate_hz=rate_hz,
                    decimation_factor=decimation_factor)
                block_size = processor.block_size
                expected_rows = block_size // decimation_factor

                energy = np.zeros(tactile_processor.NUM_TACTORS)

                # Visit every block that ends strictly before the last sample
                # (start + block_size < len(channel)).
                for start in range(0, len(channel) - block_size, block_size):
                    block = channel[start:start + block_size]
                    # Convert to floats in [-1, 1].
                    block = block.astype(np.float32) / 2.0**15

                    # Run audio-to-tactile processing in a streaming manner.
                    output_block = processor.process_samples(block)

                    self.assertEqual(output_block.shape[0], expected_rows)
                    self.assertEqual(output_block.shape[1],
                                     tactile_processor.NUM_TACTORS)

                    # Accumulate energy for each tactor.
                    energy += (output_block**2).sum(axis=0)

                # The intended tactor has the largest energy in the vowel cluster.
                for other_tactor in range(1, 8):
                    if other_tactor == intended_tactor:
                        continue
                    self.assertGreaterEqual(energy[intended_tactor],
                                            1.65 * energy[other_tactor])
예제 #12
0
def read_wav(s):
    """Reads the WAV file `wav_dir + s + '.wav'` and averages its channels.

    Args:
      s: String, file stem appended to the module-level `wav_dir` prefix.

    Returns:
      float32 samples averaged over axis 1 (channels).

    Raises:
      ValueError: If the file's sample rate is not 16000 Hz.
    """
    wav_samples, wav_sample_rate_hz = dsp.read_wav_file(wav_dir + s + '.wav',
                                                        dtype=np.float32)
    # Explicit check instead of `assert`, which is stripped when Python runs
    # with -O and would then let wrong-rate audio through silently.
    if wav_sample_rate_hz != 16000:
        raise ValueError(
            f'Expected 16000 Hz audio for {s!r}, got {wav_sample_rate_hz} Hz')
    return wav_samples.mean(axis=1)
예제 #13
0
 def test_read_wav_reader_result_too_large(self):
     """read_wav_file() rejects a reader that returns more than requested."""
     def _overlong_read(size):
         return b'\000' * (size + 1)

     oversized_reader = MockReader(_overlong_read)
     with self.assertRaisesRegex(ValueError, 'exceeds requested size'):
         dsp.read_wav_file(oversized_reader)
예제 #14
0
 def test_read_wav_reader_returns_wrong_type(self):
     """read_wav_file() rejects a reader whose read returns a list, not bytes."""
     def _list_read(size):
         return [0] * size

     list_reader = MockReader(_list_read)
     with self.assertRaisesRegex(TypeError, 'list found'):
         dsp.read_wav_file(list_reader)
예제 #15
0
def process_one_wav_file(wav_file: str) -> Dict[str, List[np.ndarray]]:
  """Processes one WAV file to create observed frames.

  Processes one TIMIT WAV file with the frontend, and uses the associated label
  file to group observed frames by phone. Segments shorter than
  FLAGS.min_phone_length_s or with labels in PHONES_TO_EXCLUDE_FROM_DATASET are
  skipped.

  Audio channels are averaged (if there are multiple channels) to reduce to mono
  before processing.

  The file is processed FLAGS.num_draws times. The first draw only resamples to
  AUDIO_SAMPLE_RATE_HZ. Later draws add data augmentation: random resampling
  perturbation, a random sub-block shift, white Gaussian noise, and random
  gain.

  Args:
    wav_file: String, WAV file path.
  Returns:
    Examples dict with values of shape (num_examples, num_frames, num_channels).
    `examples[phone][i]` is the input for the ith example with label `phone`.
  """
  samples_orig, sample_rate_hz = dsp.read_wav_file(wav_file, dtype=np.float32)
  samples_orig = samples_orig.mean(axis=1)

  phone_times = phone_util.get_phone_times(
      phone_util.get_phone_label_filename(wav_file))
  frontend = carl_frontend.CarlFrontend(**get_frontend_params_from_flags())
  examples = collections.defaultdict(list)
  translation = 0

  for draw_index in range(FLAGS.num_draws):
    samples = np.copy(samples_orig)

    # Resample from sample_rate_hz to AUDIO_SAMPLE_RATE_HZ, perturbed up to
    # +/-max_resample_percent to change pitch and compress/dilate time.
    # TODO: For more data augmentation, consider changing pitch and
    # time stretching independently.
    dilation_factor = AUDIO_SAMPLE_RATE_HZ / sample_rate_hz
    if draw_index > 0:
      max_log_dilation = np.log(1.0 + FLAGS.max_resample_percent / 100.0)
      dilation_factor *= np.exp(
          np.random.uniform(-max_log_dilation, max_log_dilation))

    if abs(dilation_factor - 1.0) >= 1e-4:
      resampler = dsp.Resampler(1.0, dilation_factor, max_denominator=2000)
      samples = resampler.process_samples(samples)

    if draw_index > 0:
      # Prepend a random fraction of a block of silence. This randomizes the
      # input phase with respect to the frontend's decimation by block_size.
      translation = np.random.randint(FLAGS.block_size)
      samples = np.append(np.zeros(translation), samples)
      # Add white Gaussian noise.
      samples = np.random.normal(
          samples, FLAGS.noise_stddev).astype(np.float32)
      # Scale the samples to simulate the recording at a different distance.
      samples /= np.exp(np.random.uniform(
          np.log(FLAGS.min_simulated_distance),
          np.log(FLAGS.max_simulated_distance)))

    observed = phone_util.run_frontend(frontend, samples)

    for start, end, phone in phone_times:
      # Map label boundaries into the resampled, shifted signal. `samples`
      # already includes the `translation` samples of prepended silence, so
      # the offset is added before clamping to len(samples).
      start = int(round(dilation_factor * start)) + translation
      end = min(int(round(dilation_factor * end)) + translation, len(samples))
      # After resampling, indices count samples at AUDIO_SAMPLE_RATE_HZ.
      phone_length_s = float(end - start) / AUDIO_SAMPLE_RATE_HZ

      # Skip short (quickly-spoken) phone segments. They are likely influenced
      # by preceding/following phones, making classification less clear.
      if phone_length_s < FLAGS.min_phone_length_s:
        continue  # Skip short phone.

      phone = COALESCE_SIMILAR_PHONES.get(phone, phone)

      if phone in PHONES_TO_EXCLUDE_FROM_DATASET:
        continue

      # There may be confusing transitions (or possible labeling inaccuracy)
      # near the segment endpoints, so trim a fraction from each end.
      length = end - start
      start += int(round(length * FLAGS.phone_trim_left))
      end -= int(round(length * FLAGS.phone_trim_right))

      # Convert sample indices from audio sample rate to frame rate.
      start //= FLAGS.block_size
      end //= FLAGS.block_size

      left_context = FLAGS.num_frames_left_context
      # Extract a window every `hop` frames and append to examples.
      examples[phone].append(sliding_window(
          observed[max(0, start - left_context):end],
          window_size=left_context + 1,
          hop=FLAGS.downsample_factor // frontend.block_size))

  return examples
def main(argv):
    """Runs the TactileProcessor demo on a WAV file or a live input device.

    Parses command-line flags from `argv` and builds a
    tactile_worker.TactileWorker. If --input ends with '.wav', the file is
    played in a loop; otherwise --input is passed to the worker as an input
    device. The loop prints a text volume meter per entry of
    worker.volume_meters, sleeping 25 ms between updates, until Ctrl+C.

    Args:
      argv: Command-line arguments; argv[0] (the program name) is skipped.

    Returns:
      None. Returns early after printing a message if --input, --output, or
      --channels is missing.
    """
    parser = argparse.ArgumentParser(
        description='TactileProcessor Python demo')
    parser.add_argument('--input', type=str, help='Input WAV or device')
    parser.add_argument('--output', type=str, help='Output device')
    parser.add_argument('--sample_rate_hz', default=16000, type=int)
    parser.add_argument('--channels', type=str)
    parser.add_argument('--channel_gains_db', type=str, default='')
    parser.add_argument('--chunk_size', default=256, type=int)
    parser.add_argument('--cutoff_hz', default=500.0, type=float)
    parser.add_argument('--global_gain_db', default=0.0, type=float)
    # --use_equalizer / --nouse_equalizer toggle one boolean, default True.
    parser.add_argument('--use_equalizer',
                        dest='use_equalizer',
                        action='store_true')
    parser.add_argument('--nouse_equalizer',
                        dest='use_equalizer',
                        action='store_false')
    parser.set_defaults(use_equalizer=True)
    parser.add_argument('--mid_gain_db', default=-10.0, type=float)
    parser.add_argument('--high_gain_db', default=-5.5, type=float)
    args = parser.parse_args(argv[1:])

    # These flags have no default and must be given explicitly.
    for arg_name in ('input', 'output', 'channels'):
        if not getattr(args, arg_name):
            print('Must specify --' + arg_name)
            return

    # An --input ending in '.wav' (case-sensitive) is treated as a file to
    # play; anything else is treated as an input device name.
    play_wav_file = args.input.endswith('.wav')

    if play_wav_file:
        # Use the file's own sample rate (--sample_rate_hz is ignored) and
        # average channels to mono.
        wav_samples, sample_rate_hz = dsp.read_wav_file(args.input,
                                                        dtype=np.float32)
        wav_samples = wav_samples.mean(axis=1)
        input_device = None
    else:
        input_device = args.input
        sample_rate_hz = args.sample_rate_hz

    worker = tactile_worker.TactileWorker(
        input_device=input_device,
        output_device=args.output,
        sample_rate_hz=sample_rate_hz,
        channels=args.channels,
        channel_gains_db=args.channel_gains_db,
        global_gain_db=args.global_gain_db,
        chunk_size=args.chunk_size,
        cutoff_hz=args.cutoff_hz,
        use_equalizer=args.use_equalizer,
        mid_gain_db=args.mid_gain_db,
        high_gain_db=args.high_gain_db)

    if play_wav_file:
        worker.set_playback_input()
        worker.play(wav_samples)
    else:
        worker.set_mic_input()

    print('\nPress Ctrl+C to stop program.\n')
    # Column headers for the meter line. NOTE(review): presumably these match
    # the order of worker.volume_meters; confirm against TactileWorker.
    print(' '.join(f'{s:<7}' for s in ('base', 'aa', 'uw', 'ih', 'iy', 'eh',
                                       'ae', 'uh', 'sh', 's')))
    # RMS levels mapped to the bottom (0) and top (1) of each meter.
    rms_min, rms_max = 0.003, 0.05

    try:
        while True:
            # When playback is almost done, add wav_samples to the queue again. This
            # makes the WAV play in a loop.
            if (play_wav_file and
                    worker.remaining_playback_samples < 0.1 * sample_rate_hz):
                worker.reset()
                worker.play(wav_samples)

            # Get the volume meters for each tactor and make a simple visualization.
            volume = worker.volume_meters
            # Log-scale map: rms_min -> ~0, rms_max -> 1; 1e-9 avoids log2(0).
            activation = np.log2(1e-9 + volume / rms_min) / np.log2(
                rms_max / rms_min)
            activation = activation.clip(0.0, 1.0)
            # '\r' with end='' redraws the meters in place on one line.
            print('\r' + ' '.join(volume_meter(a) for a in activation), end='')

            time.sleep(0.025)
    except KeyboardInterrupt:  # Stop gracefully on Ctrl+C.
        print('\n')
예제 #17
0
 def test_read_wav_read_not_callable(self):
     """read_wav_file() rejects a reader whose read attribute is not callable."""
     reader_with_bad_read = MockReader(None)
     with self.assertRaisesRegex(TypeError, 'not callable'):
         dsp.read_wav_file(reader_with_bad_read)