def testGetFeaturesForWav(self):
    """Check get_features_for_wav on a synthetic periodic waveform.

    Builds a small labelled wav corpus ("a", "b", "c", one file each) so
    that an AudioProcessor can be constructed with "average"
    preprocessing and 1600 desired samples. The test then writes a
    1600-sample clip at 16000 Hz whose samples repeat the cycle
    0, -1, 0, 1 and extracts features from it. It asserts that the first
    returned element has shape (1, 16, 11), is ~0 at [0, 0, 0] and ~200
    at [0, 0, 5] (tolerance 0.1).
    """
    root_dir = self.get_temp_dir()
    wav_root = os.path.join(root_dir, "wavs")
    os.mkdir(wav_root)
    self._SaveWavFolders(wav_root, ["a", "b", "c"], 1)

    flags = self._GetDefaultFlags()
    flags.preprocess = "average"
    flags.desired_samples = 1600
    flags.train_dir = root_dir
    flags.summaries_dir = root_dir
    flags.data_dir = wav_root

    with self.cached_session() as sess:
        # The processor is created inside the session block so any graph
        # it builds lives in the session's default graph.
        processor = input_data.AudioProcessor(flags)

        # One period of the test signal; sample i takes cycle[i % 4],
        # reproducing the 0, -1, 0, 1 pattern as a float64 column vector.
        cycle = np.array([0.0, -1.0, 0.0, 1.0])
        positions = np.arange(flags.desired_samples) % len(cycle)
        waveform = cycle[positions].reshape(-1, 1)

        clip_path = os.path.join(root_dir, "test_wav.wav")
        input_data.save_wav_file(clip_path, waveform, 16000)

        features = processor.get_features_for_wav(clip_path, flags, sess)[0]

        height, width, depth = features.shape
        self.assertEqual(1, height)
        self.assertEqual(16, width)
        self.assertEqual(11, depth)
        self.assertNear(0, features[0, 0, 0], 0.1)
        self.assertNear(200, features[0, 0, 5], 0.1)
def testSaveWavFile(self):
    """Round-trip a silent clip through save_wav_file and load_wav_file.

    Writes 16000 zero-valued samples (one column) at a 16000 Hz sample
    rate into the test's temporary directory, reads the file back, and
    checks that the loader returned a non-None result containing the same
    number of samples.
    """
    sample_count = 16000
    sample_rate = 16000
    wav_path = os.path.join(self.get_temp_dir(), "load_test.wav")

    silence = np.zeros([sample_count, 1])
    input_data.save_wav_file(wav_path, silence, sample_rate)

    round_trip = input_data.load_wav_file(wav_path)
    self.assertIsNotNone(round_trip)
    self.assertLen(round_trip, sample_count)