def predict(args):
    assets = AssetManager(args.base_dir)
    storage = PredictionStorage(
        assets.create_prediction_storage(args.model, args.data_name))

    # Note: the path passed as args.model may need to be adjusted manually.
    network = SpeechEnhancementNetwork.load(
        assets.get_model_cache_path(args.model))

    with open(assets.get_normalization_cache_path(args.model),
              'rb') as normalization_fd:
        video_normalizer = pickle.load(normalization_fd)

    try:
        # Extract the inputs once and reuse them: the normalizer works in
        # place, so the same frame array must be fed to the network afterwards.
        video_frames = extract_frames("data")
        audio_data = extract_audio()  # index 0: mixed signal, index 1: mixed spectrograms

        video_normalizer.normalize(video_frames)

        predicted_speech_spectrograms = network.predict(
            audio_data[1], video_frames)

        predicted_speech_signal = data_processor.reconstruct_speech_signal(
            audio_data[0], predicted_speech_spectrograms, 30)

        predicted_speech_signal.save_to_wav_file("enhanced.wav")

    except Exception:
        logging.exception("failed to predict %s. skipping" % "test")
def start(args):

    # Initialize Network
    assets = AssetManager(args.prediction_dir)
    storage = PredictionStorage(args.prediction_dir)
    network = SpeechEnhancementNetwork.load(
        assets.get_model_cache_path(args.model_dir))
    network.start_prediction_mode()
    # Dummy calls with zero-filled inputs: the results are discarded; they only
    # warm up the prediction and reconstruction paths before the real-time
    # pipeline starts.
    network.predict(np.zeros((2, 80, 24)), np.zeros((2, 128, 128, 6)))
    reconstruct_speech_signal(
        AudioSignal.from_wav_file(
            "/cs/engproj/322/real_time/raw_data/mixture.wav"),
        np.zeros((2, 80, 24)), 30)

    with open(assets.get_normalization_cache_path(args.model_dir),
              'rb') as normalization_fd:
        video_normalizer = pickle.load(normalization_fd)

    lock = Lock()
    video_dir = assets.get_video_cache_path(args.video_audio_dir)
    predict_object = RunPredict(network, video_dir, storage.storage_dir)

    # Launch the video capture, audio capture, preprocess, and play workers
    # (each runs in its own process)
    video_queue = Queue()
    audio_queue = Queue()
    predict_queue = Queue()
    play_queue = Queue()
    video_object = VideoProcess(video_dir)
    video_thread = Process(target=video_object.capture_frames,
                           args=(video_queue, lock))
    audio_object = AudioProcess(
        assets.get_audio_cache_path(args.video_audio_dir))
    audio_thread = Process(target=audio_object.capture_frames,
                           args=(audio_queue, lock))
    preprocess_thread = Process(target=predict_object.run_pre_process,
                                args=(video_queue, audio_queue, predict_queue,
                                      video_normalizer, lock))
    play_thread = Process(target=predict_object.play, args=(play_queue, lock))

    video_thread.start()
    audio_thread.start()
    preprocess_thread.start()
    play_thread.start()

    # Run predict
    predict_object.predict(predict_queue, play_queue, lock)

    video_thread.join()
    audio_thread.join()
    preprocess_thread.join()
    play_thread.join()

    # Save files
    predict_object.save_files(storage)

    print("*Finish All*")
def train(args):
    assets = AssetManager(args.base_dir)
    assets.create_model(args.model)

    train_preprocessed_blob_paths = [
        assets.get_preprocessed_blob_path(d) for d in args.train_data_names
    ]
    validation_preprocessed_blob_paths = [
        assets.get_preprocessed_blob_path(d) for d in args.validation_data_names
    ]

    train_samples = load_preprocessed_blobs(train_preprocessed_blob_paths)
    (
        train_video_samples,
        train_mixed_spectrograms,
        train_speech_spectrograms,
    ) = make_sample_set(train_samples)

    validation_samples = load_preprocessed_blobs(validation_preprocessed_blob_paths)
    (
        validation_video_samples,
        validation_mixed_spectrograms,
        validation_speech_spectrograms,
    ) = make_sample_set(validation_samples)

    # Fit the normalization statistics on the training video samples only,
    # then apply them in place to both the training and validation splits.
    video_normalizer = data_processor.VideoNormalizer(train_video_samples)
    video_normalizer.normalize(train_video_samples)
    video_normalizer.normalize(validation_video_samples)

    with open(
        assets.get_normalization_cache_path(args.model), "wb"
    ) as normalization_fd:
        pickle.dump(video_normalizer, normalization_fd)

    network = SpeechEnhancementNetwork.build(
        train_mixed_spectrograms.shape[1:], train_video_samples.shape[1:]
    )
    network.train(
        train_mixed_spectrograms,
        train_video_samples,
        train_speech_spectrograms,
        validation_mixed_spectrograms,
        validation_video_samples,
        validation_speech_spectrograms,
        assets.get_model_cache_path(args.model),
        assets.get_tensorboard_dir(args.model),
    )

    network.save(assets.get_model_cache_path(args.model))
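

# Hypothetical stand-in for the VideoNormalizer used above (the real class
# lives in data_processor and its internals are not shown here). It illustrates
# the pattern train() relies on: statistics are computed from the training
# video samples only, normalize() mutates arrays in place, and the fitted
# object is pickled so predict() can reuse the exact same statistics.
import numpy as np


class SimpleVideoNormalizer(object):
    def __init__(self, video_samples):
        # Global mean/std estimated once, on training data only.
        self.mean = float(np.mean(video_samples))
        self.std = float(np.std(video_samples))

    def normalize(self, video_samples):
        # In-place standardization (assumes a float ndarray).
        video_samples -= self.mean
        video_samples /= max(self.std, 1e-8)
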
def predict(args):
    assets = AssetManager(args.base_dir)
    storage = PredictionStorage(
        assets.create_prediction_storage(args.model, args.data_name)
    )
    network = SpeechEnhancementNetwork.load(assets.get_model_cache_path(args.model))

    with open(
        assets.get_normalization_cache_path(args.model), "rb"
    ) as normalization_fd:
        video_normalizer = pickle.load(normalization_fd)

    samples = load_preprocessed_blob(assets.get_preprocessed_blob_path(args.data_name))
    for sample in samples:
        try:
            print(
                "predicting (%s, %s)..."
                % (sample.video_file_path, sample.noise_file_path)
            )

            video_normalizer.normalize(sample.video_samples)

            loss = network.evaluate(
                sample.mixed_spectrograms,
                sample.video_samples,
                sample.speech_spectrograms,
            )
            print("loss: %f" % loss)

            predicted_speech_spectrograms = network.predict(
                sample.mixed_spectrograms, sample.video_samples
            )

            predicted_speech_signal = data_processor.reconstruct_speech_signal(
                sample.mixed_signal,
                predicted_speech_spectrograms,
                sample.video_frame_rate,
            )

            storage.save_prediction(sample, predicted_speech_signal)

        except Exception:
            logging.exception("failed to predict %s. skipping" % sample.video_file_path)