Ejemplo n.º 1
0
  def test_nonfinite(self):
    """Checks that Dataset.write_npz rejects examples with nonfinite values.

    A Dataset created with no examples is given a single all-NaN example for
    phone 'ae'. Writing it must raise a ValueError whose message matches
    'nonfinite value'.
    """
    nan_example = np.full((1, self.num_frames, self.num_channels), np.nan)
    dataset = phone_util.Dataset({}, self.metadata)
    dataset.examples['ae'] = nan_example

    sink = io.BytesIO()  # In-memory stand-in for an output file.
    with self.assertRaisesRegex(ValueError, 'nonfinite value'):
      dataset.write_npz(sink)
Ejemplo n.º 2
0
  def test_split(self):
    """Checks the example counts and metadata of both parts of split().

    With a split fraction of 0.1, the first part is expected to hold 7 'ae'
    and 20 'sil' examples and the second part 63 'ae' and 180 'sil' examples.
    Both parts must carry the metadata of the source Dataset.
    """
    dataset = phone_util.Dataset(self.examples, self.metadata)

    first, second = dataset.split(0.1)

    expected_counts = (
        (first, {'ae': 7, 'sil': 20}),
        (second, {'ae': 63, 'sil': 180}),
    )
    for part, counts in expected_counts:
      self.assertEqual(part.example_counts, counts)
    for part in (first, second):
      self.assertEqual(part.metadata, dataset.metadata)
Ejemplo n.º 3
0
  def test_subsample(self):
    """Checks that subsample() keeps the requested fraction of each phone.

    An extra phone 'x' with 20 random examples is added to the fixture before
    the Dataset is built. It is not listed in the fractions passed to
    subsample(), and the test expects it to be absent afterwards.
    """
    extra_shape = (20, self.num_frames, self.num_channels)
    self.examples['x'] = np.random.rand(*extra_shape).astype(np.float32)
    dataset = phone_util.Dataset(self.examples, self.metadata)

    # The fixture is expected to hold 70 'ae' and 200 'sil' examples.
    expected_counts = {'ae': round(0.6 * 70), 'sil': round(0.1 * 200)}

    dataset.subsample({'ae': 0.6, 'sil': 0.1})

    self.assertNotIn('x', dataset.examples)
    self.assertEqual(dataset.example_counts, expected_counts)
Ejemplo n.º 4
0
  def test_get_xy_arrays(self):
    """Checks the arrays returned by Dataset.get_xy_arrays().

    The phone list includes 'z', for which the test expects no examples: the
    unshuffled x must equal the 'sil' examples followed by the 'ae' examples,
    and y must hold each example's index into the phone list (0 for 'sil',
    2 for 'ae'). With shuffle=True only the shape of x is checked.
    """
    phones = ('sil', 'z', 'ae')
    dataset = phone_util.Dataset(self.examples, self.metadata)

    # Unshuffled: examples come out grouped by phone, in phone-list order.
    x, y = dataset.get_xy_arrays(list(phones))

    sil_examples = self.examples['sil']
    ae_examples = self.examples['ae']
    np.testing.assert_array_equal(
        x, np.concatenate((sil_examples, ae_examples)))
    np.testing.assert_array_equal(y, np.hstack(([0] * 200, [2] * 70)))

    # Shuffled: the order is unknown, so only the overall shape is verified.
    x, y = dataset.get_xy_arrays(list(phones), shuffle=True)

    shuffled_shape = (270, self.num_frames, self.num_channels)
    self.assertEqual(x.shape, shuffled_shape)
Ejemplo n.º 5
0
def random_dataset():
  """Builds a small Dataset of random examples for three phones.

  Seeds NumPy's global random number generator with 0 before drawing, so the
  generated values are the same on every call. Note that this reseeds the
  global generator as a side effect.

  Returns:
    A phone_util.Dataset holding 10 'aa', 14 'eh' and 11 'iy' float32
    examples, each with 3 frames (2 frames of left context plus one) and
    16 channels.
  """
  np.random.seed(0)
  num_frames_left_context = 2
  num_channels = 16
  num_frames = num_frames_left_context + 1

  # Drawn in this order so the random stream is consumed deterministically.
  example_counts = (('aa', 10), ('eh', 14), ('iy', 11))
  examples = {
      phone: np.random.rand(count, num_frames, num_channels).astype(np.float32)
      for phone, count in example_counts
  }
  metadata = dict(
      num_frames_left_context=num_frames_left_context,
      num_channels=num_channels,
  )
  return phone_util.Dataset(examples, metadata)
Ejemplo n.º 6
0
  def test_basic(self):
    """Checks Dataset construction and an in-memory .npz round trip."""
    dataset = phone_util.Dataset(self.examples, self.metadata)

    # Properties of the freshly constructed Dataset.
    self.assertNotIn('z', dataset.examples)
    self.assertEqual(dataset.num_frames, self.num_frames)
    self.assertEqual(dataset.num_channels, self.num_channels)
    self.assertDictEqual(dataset.example_counts, {'ae': 70, 'sil': 200})

    # Serialize to an in-memory buffer, then read it back from the start.
    buffer = io.BytesIO()
    dataset.write_npz(buffer)
    buffer.seek(0)
    recovered = phone_util.read_dataset_npz(buffer)

    # The recovered Dataset must match the original examples and metadata.
    for phone in ('ae', 'sil'):
      np.testing.assert_array_equal(recovered.examples[phone],
                                    self.examples[phone])
    self.assertDictEqual(recovered.metadata, self.metadata)
    self.assertEqual(recovered.num_frames, self.num_frames)
    self.assertEqual(recovered.num_channels, self.num_channels)
Ejemplo n.º 7
0
def main(argv) -> int:
  """Builds a phone dataset from WAV files and writes it to an .npz file.

  Collects the .wav files matching the glob patterns in FLAGS.examples, checks
  that each one has a .phn label file, processes them into examples, wraps the
  result in a phone_util.Dataset, passes it through balance_examples(), and
  writes it to FLAGS.output.

  Args:
    argv: Command line arguments. More than one element (presumably the
      program name plus leftover non-flag arguments) only triggers a warning.

  Returns:
    Exit code: 0 on success, or 1 if FLAGS.downsample_factor is not a multiple
    of FLAGS.block_size, no .wav files were found, or a .phn file is missing.
  """
  if len(argv) > 1:
    print(f'WARNING: Non-flag arguments: {argv}')

  # Validate explicitly instead of with `assert`, which is stripped when
  # Python runs with -O and would let an inconsistent configuration through.
  if FLAGS.downsample_factor % FLAGS.block_size != 0:
    print(f'Error: downsample_factor ({FLAGS.downsample_factor}) must be a '
          f'multiple of block_size ({FLAGS.block_size})')
    return 1

  # A set removes files matched by more than one pattern.
  matched_files = set()
  for glob_pattern in FLAGS.examples:
    glob_pattern = os.path.expanduser(glob_pattern)
    matched_files.update(wav_file for wav_file in glob.glob(glob_pattern)
                         if wav_file.lower().endswith('.wav'))

  if not matched_files:
    print(f'Error: No .wav files found matching {FLAGS.examples}')
    return 1

  # Sort so files are processed in the same order on every run; iteration
  # order of a set of strings can differ between Python processes.
  wav_files = sorted(matched_files)

  # Fail before any processing if a label file is missing.
  for wav_file in wav_files:
    phn_file = phone_util.get_phone_label_filename(wav_file)
    if not os.path.isfile(phn_file):
      print(f'Error: .phn file not found: {phn_file}')
      return 1

  # Compute the frontend parameters once and reuse them for both the frontend
  # instance and the metadata record.
  frontend_params = get_frontend_params_from_flags()
  # The frontend is instantiated here only to read its channel count.
  frontend = carl_frontend.CarlFrontend(**frontend_params)
  examples = process_wav_files(wav_files)

  metadata = {
      'frontend_params': frontend_params,
      'num_channels': frontend.num_channels,
      'num_frames_left_context': FLAGS.num_frames_left_context,
      'flags': phone_util.get_main_module_flags_dict(),
  }
  dataset = phone_util.Dataset(examples, metadata)
  balance_examples(dataset)

  print(f'\nWriting dataset to {FLAGS.output}')
  dataset.write_npz(FLAGS.output)
  return 0