コード例 #1
0
    def testGetFeaturesForWav(self):
        """Checks the features computed for a synthetic WAV file.

        Builds an ``input_data.AudioProcessor`` over three dummy label folders
        ("a", "b", "c") created by ``_SaveWavFolders``. It then writes a
        1600-sample waveform repeating ``0, -1, 0, 1`` to ``test_wav.wav`` and
        verifies the shape and two entries of the first element returned by
        ``get_features_for_wav``.
        """
        tmp_dir = self.get_temp_dir()
        wav_dir = os.path.join(tmp_dir, "wavs")
        os.mkdir(wav_dir)
        self._SaveWavFolders(wav_dir, ["a", "b", "c"], 1)
        flags = self._GetDefaultFlags()
        flags.preprocess = "average"
        flags.desired_samples = 1600
        flags.train_dir = tmp_dir
        flags.summaries_dir = tmp_dir
        flags.data_dir = wav_dir
        with self.cached_session() as sess:
            audio_processor = input_data.AudioProcessor(flags)
            # Sampled sinusoid with a period of exactly 4 samples
            # (0, -1, 0, 1, 0, -1, ...). It is built in one vectorized
            # step rather than a per-sample Python loop. Indexing with
            # ``% 4`` stays correct even when desired_samples is not a
            # multiple of 4. The result matches np.zeros([N, 1]) filled
            # element-wise: shape (N, 1), dtype float64.
            one_cycle = np.array([0.0, -1.0, 0.0, 1.0])
            sample_data = one_cycle[np.arange(flags.desired_samples) % 4]
            sample_data = sample_data.reshape(-1, 1)
            test_wav_path = os.path.join(tmp_dir, "test_wav.wav")
            # NOTE(review): 16000 is presumably the sample rate in Hz, which
            # would make the tone above fs/4 = 4000 Hz; confirm against
            # input_data.save_wav_file.
            input_data.save_wav_file(test_wav_path, sample_data, 16000)

            results = audio_processor.get_features_for_wav(
                test_wav_path, flags, sess)
            spectrogram = results[0]
            # The expected shape and values depend on the window/stride/
            # feature settings from _GetDefaultFlags (not visible here).
            self.assertEqual(1, spectrogram.shape[0])
            self.assertEqual(16, spectrogram.shape[1])
            self.assertEqual(11, spectrogram.shape[2])
            self.assertNear(0, spectrogram[0, 0, 0], 0.1)
            self.assertNear(200, spectrogram[0, 0, 5], 0.1)
コード例 #2
0
 def testSaveWavFile(self):
     """Round-trips an all-zero buffer through save_wav_file/load_wav_file.

     Writes a (16000, 1) array of zeros to ``load_test.wav`` in the test's
     temp directory, reads it back with ``input_data.load_wav_file``, and
     checks that a value came back with 16000 entries.
     """
     num_samples = 16000
     # NOTE(review): presumably the sample rate in Hz; confirm against
     # input_data.save_wav_file.
     rate_arg = 16000
     wav_path = os.path.join(self.get_temp_dir(), "load_test.wav")
     silence = np.zeros([num_samples, 1])
     input_data.save_wav_file(wav_path, silence, rate_arg)
     round_tripped = input_data.load_wav_file(wav_path)
     self.assertIsNotNone(round_tripped)
     self.assertLen(round_tripped, num_samples)