Пример #1
0
def main(argv):
    """Transcribe onsets from a WAV file or from live microphone input.

    A background thread running ``result_collector`` is started on a shared
    results queue. Then, depending on ``FLAGS.wav_file``:

    * set: the first 10 seconds of the file are cut into overlapping chunks
      of the model's input length, each chunk is run through one in-process
      ``tflite_model.Model`` and the finished task is put on the results
      queue;
    * unset: four ``TfLiteWorker`` instances are started on a task queue and
      an ``AudioQueue`` is started whose callback enqueues one ``OnsetsTask``
      per captured audio chunk.

    Args:
        argv: Command-line arguments; anything beyond the program name is
            rejected.

    Raises:
        app.UsageError: If ``argv`` holds more than one element.
        ValueError: If the model's input length, window length and output
            roll length do not yield an integer hop size.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    window_length = 2048  # Model specific constant.

    input_details, output_details = tflite_model.get_model_detail(
        FLAGS.model_path)
    input_wav_length = input_details[0]['shape'][0]
    output_roll_length = output_details[0]['shape'][1]

    # Validate explicitly instead of with ``assert``: the check must survive
    # ``python -O`` and must not divide by zero for a single-frame output.
    hop_span = input_wav_length - window_length
    hop_count = output_roll_length - 1
    if hop_count <= 0 or hop_span % hop_count != 0:
        raise ValueError(
            'Model shapes are inconsistent: input length {} minus window '
            'length {} is not a multiple of output roll length {} minus 1.'
            .format(input_wav_length, window_length, output_roll_length))
    hop_size = hop_span // hop_count

    # Consecutive chunks share this many samples.
    overlap_timesteps = 4
    overlap_wav = hop_size * overlap_timesteps + window_length

    results = multiprocessing.Queue()

    results_thread = threading.Thread(target=result_collector,
                                      args=(results, ))
    results_thread.start()

    if FLAGS.wav_file:
        model = tflite_model.Model(model_path=FLAGS.model_path)

        # Close the file handle deterministically instead of leaking it.
        with tf.gfile.Open(FLAGS.wav_file, 'rb') as wav_file:
            wav_data = wav_file.read()
        samples = audio_io.wav_data_to_samples(wav_data, MODEL_SAMPLE_RATE)
        samples = samples[:MODEL_SAMPLE_RATE * 10]  # Only the first 10 seconds
        samples = samples.reshape((-1, 1))
        samples_length = samples.shape[0]
        # Extend samples with zeros along the time axis only. A flat
        # ``(0, n)`` pad width would also widen the single-column axis and
        # allocate an (n_samples + n) x (1 + n) array.
        samples = np.pad(samples, ((0, input_wav_length), (0, 0)),
                         mode='constant')
        chunk_stride = input_wav_length - overlap_wav
        chunk_starts = range(0, samples_length - chunk_stride, chunk_stride)
        for i, pos in enumerate(chunk_starts):
            chunk = samples[pos:pos + input_wav_length]
            task = OnsetsTask(AudioChunk(i, chunk))
            task(model)
            results.put(task)
    else:
        tasks = multiprocessing.JoinableQueue()

        ## Make and start the workers
        num_workers = 4
        workers = [
            TfLiteWorker(FLAGS.model_path, tasks, results)
            for _ in range(num_workers)
        ]
        for w in workers:
            w.start()

        audio_feeder = AudioQueue(
            callback=lambda audio_chunk: tasks.put(OnsetsTask(audio_chunk)),
            audio_device_index=FLAGS.mic
            if FLAGS.mic is None else int(FLAGS.mic),
            sample_rate_hz=int(FLAGS.sample_rate_hz),
            frame_length=input_wav_length,
            overlap=overlap_wav)

        audio_feeder.start()
def main(argv):
    """Transcribe onsets from a WAV file or from live microphone input.

    A background thread running ``result_collector`` is started on a shared
    results queue and a ``tflite_model.Model`` is loaded from
    ``FLAGS.model_path`` to obtain the chunking parameters. Then, depending
    on ``FLAGS.wav_file``:

    * set: the first 10 seconds of the file are cut into overlapping chunks
      of the model's input length, each chunk is run through the model and
      the finished task is put on the results queue;
    * unset: four ``TfLiteWorker`` instances are started on a task queue and
      an ``AudioQueue`` is started whose callback enqueues one ``OnsetsTask``
      per captured audio chunk.

    Args:
        argv: Command-line arguments; anything beyond the program name is
            rejected.

    Raises:
        app.UsageError: If ``argv`` holds more than one element.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    results = multiprocessing.Queue()
    results_thread = threading.Thread(target=result_collector,
                                      args=(results, ))
    results_thread.start()

    model = tflite_model.Model(model_path=FLAGS.model_path)
    # Consecutive chunks share this many samples.
    overlap_timesteps = 4
    overlap_wav = (model.get_hop_size() * overlap_timesteps +
                   model.get_window_length())

    # Read the model parameters once instead of on every use below.
    sample_rate = model.get_sample_rate()
    input_wav_length = model.get_input_wav_length()

    if FLAGS.wav_file:
        # Close the file handle deterministically instead of leaking it.
        with open(FLAGS.wav_file, 'rb') as wav_file:
            wav_data = wav_file.read()
        samples = audio_recorder.wav_data_to_samples(wav_data, sample_rate)
        samples = samples[:sample_rate * 10]  # Only the first 10 seconds
        samples = samples.reshape((-1, 1))
        samples_length = samples.shape[0]
        # Extend samples with zeros along the time axis only. A flat
        # ``(0, n)`` pad width would also widen the single-column axis and
        # allocate an (n_samples + n) x (1 + n) array.
        samples = np.pad(samples, ((0, input_wav_length), (0, 0)),
                         mode='constant')
        chunk_stride = input_wav_length - overlap_wav
        chunk_starts = range(0, samples_length - chunk_stride, chunk_stride)
        for i, pos in enumerate(chunk_starts):
            chunk = samples[pos:pos + input_wav_length]
            task = OnsetsTask(AudioChunk(i, chunk))
            task(model)
            results.put(task)
    else:
        tasks = multiprocessing.JoinableQueue()

        ## Make and start the workers
        num_workers = 4
        workers = [
            TfLiteWorker(FLAGS.model_path, tasks, results)
            for _ in range(num_workers)
        ]
        for w in workers:
            w.start()

        audio_feeder = AudioQueue(
            callback=lambda audio_chunk: tasks.put(OnsetsTask(audio_chunk)),
            audio_device_index=FLAGS.mic
            if FLAGS.mic is None else int(FLAGS.mic),
            sample_rate_hz=int(FLAGS.sample_rate_hz),
            model_sample_rate=sample_rate,
            frame_length=input_wav_length,
            overlap=overlap_wav)

        audio_feeder.start()
    def setup(self):
        """Create the TFLite model for this object unless one already exists.

        When ``self._model`` is already set the call returns immediately;
        otherwise a ``tflite_model.Model`` is built from ``self._model_path``
        and stored on ``self._model``.
        """
        if self._model is not None:
            return

        self._model = tflite_model.Model(model_path=self._model_path)