def test_nonfinite(self):
  """Checks that Dataset.write_npz rejects examples holding NaN values."""
  nan_examples = np.full((1, self.num_frames, self.num_channels), np.nan)
  empty_dataset = phone_util.Dataset({}, self.metadata)
  # Inject the nonfinite examples after construction so that only
  # write_npz is exercised by the expected failure.
  empty_dataset.examples['ae'] = nan_examples
  with self.assertRaisesRegex(ValueError, 'nonfinite value'):
    empty_dataset.write_npz(io.BytesIO())
def test_split(self):
  """Checks Dataset.split(0.1) per-phone counts and metadata propagation."""
  full = phone_util.Dataset(self.examples, self.metadata)
  small, large = full.split(0.1)
  expected_counts = (
      (small, {'ae': 7, 'sil': 20}),
      (large, {'ae': 63, 'sil': 180}),
  )
  for part, counts in expected_counts:
    self.assertEqual(part.example_counts, counts)
  # Both halves must carry the original metadata unchanged.
  for part in (small, large):
    self.assertEqual(part.metadata, full.metadata)
def test_subsample(self):
  """Tests Dataset.subsample() method.

  Phones listed in the fraction dict keep round(fraction * count) examples;
  a phone absent from the dict ('x' here) is removed from the dataset.
  """
  # Work on a shallow copy so the extra 'x' phone does not leak into the
  # shared self.examples fixture used by other tests.
  examples = dict(self.examples)
  examples['x'] = np.random.rand(
      20, self.num_frames, self.num_channels).astype(np.float32)
  dataset = phone_util.Dataset(examples, self.metadata)
  dataset.subsample({'ae': 0.6, 'sil': 0.1})
  self.assertNotIn('x', dataset.examples)
  self.assertEqual(dataset.example_counts,
                   {'ae': round(0.6 * 70), 'sil': round(0.1 * 200)})
def test_get_xy_arrays(self):
  """Tests Dataset.get_xy_arrays() method.

  Each phone is labeled by its index in the requested phone list. 'z' has no
  examples in the dataset, so label 1 never appears in y.
  """
  dataset = phone_util.Dataset(self.examples, self.metadata)
  phones = ['sil', 'z', 'ae']
  expected_x = np.concatenate((self.examples['sil'], self.examples['ae']))
  expected_y = np.hstack(([0] * 200, [2] * 70))

  x, y = dataset.get_xy_arrays(phones)
  np.testing.assert_array_equal(x, expected_x)
  np.testing.assert_array_equal(y, expected_y)

  # With shuffle=True the order is random, but the shape of x and the
  # multiset of labels in y must be unchanged. expected_y is already sorted.
  x, y = dataset.get_xy_arrays(phones, shuffle=True)
  self.assertEqual(x.shape, (270, self.num_frames, self.num_channels))
  np.testing.assert_array_equal(np.sort(y), expected_y)
def random_dataset():
  """Returns a small Dataset of uniform-random float32 examples.

  Seeds NumPy's global RNG with 0 first, so the contents are reproducible
  (and the global RNG state is reset as a side effect). Phones 'aa', 'eh',
  and 'iy' get 10, 14, and 11 examples respectively, each example shaped
  (num_frames_left_context + 1, num_channels).
  """
  np.random.seed(0)
  left_context = 2
  channels = 16
  frames = left_context + 1
  # Draw in the same phone order every time so the seeded values are stable.
  phone_counts = (('aa', 10), ('eh', 14), ('iy', 11))
  examples = {
      phone: np.random.rand(count, frames, channels).astype(np.float32)
      for phone, count in phone_counts
  }
  metadata = {
      'num_frames_left_context': left_context,
      'num_channels': channels,
  }
  return phone_util.Dataset(examples, metadata)
def test_basic(self):
  """Builds a Dataset, round-trips it through in-memory .npz, checks it."""
  shape_attrs = ('num_frames', 'num_channels')

  dataset = phone_util.Dataset(self.examples, self.metadata)
  # The 'z' entry of the fixture is not kept by the constructor.
  self.assertNotIn('z', dataset.examples)
  for attr in shape_attrs:
    self.assertEqual(getattr(dataset, attr), getattr(self, attr))
  self.assertDictEqual(dataset.example_counts, {'ae': 70, 'sil': 200})

  buffer = io.BytesIO()
  dataset.write_npz(buffer)
  buffer.seek(0)  # Reading must start from the beginning of the buffer.
  recovered = phone_util.read_dataset_npz(buffer)

  for phone in ('ae', 'sil'):
    np.testing.assert_array_equal(recovered.examples[phone],
                                  self.examples[phone])
  self.assertDictEqual(recovered.metadata, self.metadata)
  for attr in shape_attrs:
    self.assertEqual(getattr(recovered, attr), getattr(self, attr))
def main(argv) -> int:
  """Builds a phone example Dataset from .wav/.phn files and writes it.

  Collects .wav files matching the --examples glob patterns, checks that
  each has a corresponding .phn label file, extracts examples, balances them,
  and writes the result as .npz to --output.

  Args:
    argv: Command-line arguments remaining after flag parsing.

  Returns:
    Process exit status: 0 on success, 1 on error.
  """
  if len(argv) > 1:
    # argv[0] is the program name; only report the stray arguments.
    print(f'WARNING: Non-flag arguments: {argv[1:]}')

  # Explicit check instead of `assert`, which is stripped under `python -O`;
  # also avoids a ZeroDivisionError when block_size is 0.
  if (FLAGS.block_size <= 0 or
      FLAGS.downsample_factor % FLAGS.block_size != 0):
    print(f'Error: --block_size ({FLAGS.block_size}) must be positive and '
          f'divide --downsample_factor ({FLAGS.downsample_factor})')
    return 1

  wav_files = set()
  for glob_pattern in FLAGS.examples:
    matches = glob.glob(os.path.expanduser(glob_pattern))
    wav_files.update(
        wav_file for wav_file in matches
        if wav_file.lower().endswith('.wav'))
  if not wav_files:
    print(f'Error: No .wav files found matching {FLAGS.examples}')
    return 1

  # Sorted so the reported missing .phn file is deterministic across runs.
  for wav_file in sorted(wav_files):
    phn_file = phone_util.get_phone_label_filename(wav_file)
    if not os.path.isfile(phn_file):
      print(f'Error: .phn file not found: {phn_file}')
      return 1

  # Compute once so the recorded metadata matches the frontend actually used.
  frontend_params = get_frontend_params_from_flags()
  frontend = carl_frontend.CarlFrontend(**frontend_params)
  examples = process_wav_files(wav_files)
  metadata = {
      'frontend_params': frontend_params,
      'num_channels': frontend.num_channels,
      'num_frames_left_context': FLAGS.num_frames_left_context,
      'flags': phone_util.get_main_module_flags_dict(),
  }
  dataset = phone_util.Dataset(examples, metadata)
  balance_examples(dataset)

  print(f'\nWriting dataset to {FLAGS.output}')
  dataset.write_npz(FLAGS.output)
  return 0