コード例 #1
0
def inference(
    output_dir: str,
    batch_size: int,
    dtype: str,
    fs: int,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    model_tag: Optional[str],
    inference_config: Optional[str],
    allow_variable_data_keys: bool,
    segment_size: Optional[float],
    hop_size: Optional[float],
    normalize_segment_scale: bool,
    show_progressbar: bool,
    ref_channel: Optional[int],
    normalize_output_wav: bool,
    enh_s2t_task: bool,
):
    """Run speech separation/enhancement inference and write one wav per speaker.

    A ``SeparateSpeech`` is built via ``SeparateSpeech.from_pretrained`` (from
    ``model_tag`` and/or ``train_config``/``model_file``). The data described by
    ``data_path_and_name_and_type`` is streamed through it. For every speaker
    index ``k`` (1-based), waveforms go to ``<output_dir>/wavs/<k>``, indexed
    by ``<output_dir>/spk<k>.scp``.

    Args:
        output_dir: Destination directory (``~`` is expanded, path is resolved).
        batch_size: Decoding batch size; only values <= 1 are accepted.
        dtype: Data type forwarded to the model and to the data iterator.
        fs: Rate stored with each waveform handed to the writer.
        ngpu: 0 runs on CPU, 1 on CUDA; more than one GPU is rejected.
        seed: Seed passed to ``set_all_random_seed``.
        num_workers: Worker count for the streaming data iterator.
        log_level: Level passed to ``logging.basicConfig``.
        data_path_and_name_and_type: Input specs for the streaming iterator.
        key_file: Optional key file restricting/ordering the utterances.
        train_config, model_file, model_tag, inference_config: Model sources
            and settings forwarded to ``SeparateSpeech.from_pretrained``.
        allow_variable_data_keys: Forwarded to the streaming iterator.
        segment_size, hop_size, normalize_segment_scale, show_progressbar,
        ref_channel, normalize_output_wav, enh_s2t_task: Forwarded unchanged
            to ``SeparateSpeech``.

    Raises:
        NotImplementedError: If ``batch_size > 1`` or ``ngpu > 1``.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    device = "cuda" if ngpu >= 1 else "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build separate_speech
    separate_speech_kwargs = dict(
        train_config=train_config,
        model_file=model_file,
        inference_config=inference_config,
        segment_size=segment_size,
        hop_size=hop_size,
        normalize_segment_scale=normalize_segment_scale,
        show_progressbar=show_progressbar,
        ref_channel=ref_channel,
        normalize_output_wav=normalize_output_wav,
        device=device,
        dtype=dtype,
        enh_s2t_task=enh_s2t_task,
    )
    separate_speech = SeparateSpeech.from_pretrained(
        model_tag=model_tag,
        **separate_speech_kwargs,
    )

    # 3. Build data-iterator
    loader = EnhancementTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=EnhancementTask.build_preprocess_fn(
            separate_speech.enh_train_args, False),
        collate_fn=EnhancementTask.build_collate_fn(
            separate_speech.enh_train_args, False),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )

    # 4. Start for-loop
    output_dir = Path(output_dir).expanduser().resolve()
    writers = []
    # Writers are created and used inside try/finally so every writer that was
    # opened is closed even if model inference or data loading raises.
    try:
        for spk in range(separate_speech.num_spk):
            writers.append(
                SoundScpWriter(f"{output_dir}/wavs/{spk + 1}",
                               f"{output_dir}/spk{spk + 1}.scp"))

        for i, (keys, batch) in enumerate(loader):
            logging.info(f"[{i}] Enhancing {keys}")
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # Length tensors are not inputs of SeparateSpeech.__call__.
            batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}

            waves = separate_speech(**batch)
            for spk, w in enumerate(waves):
                # Iterate over the keys actually present in this batch instead
                # of the configured batch_size, which may not match it.
                for b, key in enumerate(keys):
                    writers[spk][key] = fs, w[b]
    finally:
        for writer in writers:
            writer.close()
コード例 #2
0
ファイル: enh_inference.py プロジェクト: jxncyym/espnet
def inference(
    output_dir: str,
    batch_size: int,
    dtype: str,
    fs: int,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    enh_train_config: str,
    enh_model_file: str,
    allow_variable_data_keys: bool,
    normalize_output_wav: bool,
):
    """Run speech enhancement inference and write one wav per speaker.

    The model is loaded with ``EnhancementTask.build_model_from_file``. Each
    batch from the streaming iterator is run through
    ``enh_model.enh_model.forward_rawwav`` under ``torch.no_grad()``. For each
    speaker index ``k`` (1-based), the output goes to
    ``<output_dir>/wavs/<k>``, indexed by ``<output_dir>/spk<k>.scp``.

    Args:
        output_dir: Destination directory for the wavs and scp files.
        batch_size: Decoding batch size; only values <= 1 are accepted.
        dtype: Data type forwarded to the data iterator.
        fs: Rate stored with each waveform handed to the writer.
        ngpu: 0 runs on CPU, 1 on CUDA; more than one GPU is rejected.
        seed: Seed passed to ``set_all_random_seed``.
        num_workers: Worker count for the streaming data iterator.
        log_level: Level passed to ``logging.basicConfig``.
        data_path_and_name_and_type: Input specs for the streaming iterator.
        key_file: Optional key file restricting/ordering the utterances.
        enh_train_config: Training config used to rebuild the model.
        enh_model_file: Model parameter file to load.
        allow_variable_data_keys: Forwarded to the streaming iterator.
        normalize_output_wav: If true, scale each output so its peak is 0.9.

    Raises:
        NotImplementedError: If ``batch_size > 1`` or ``ngpu > 1``.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    device = "cuda" if ngpu >= 1 else "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build Enh model
    enh_model, enh_train_args = EnhancementTask.build_model_from_file(
        enh_train_config, enh_model_file, device)
    enh_model.eval()

    num_spk = enh_model.num_spk

    # 3. Build data-iterator
    loader = EnhancementTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=EnhancementTask.build_preprocess_fn(
            enh_train_args, False),
        collate_fn=EnhancementTask.build_collate_fn(enh_train_args),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )

    writers = []
    # Writers are created and used inside try/finally so every writer that was
    # opened is closed even if model inference or data loading raises.
    try:
        for spk in range(num_spk):
            writers.append(
                SoundScpWriter(f"{output_dir}/wavs/{spk + 1}",
                               f"{output_dir}/spk{spk + 1}.scp"))

        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"

            with torch.no_grad():
                # a. To device
                batch = to_device(batch, device)
                # b. Forward Enhancement Frontend
                waves, _, _ = enh_model.enh_model.forward_rawwav(
                    batch["speech_mix"], batch["speech_mix_lengths"])
                assert len(waves[0]) == batch_size, len(waves[0])

            # FIXME(Chenda): will be incorrect when
            #  batch size is not 1 or multi-channel case
            if normalize_output_wav:
                normalized = []
                for w in waves:
                    peak = abs(w).max(dim=1, keepdim=True)[0]
                    # An all-zero output has peak 0; dividing by it would give
                    # 0/0 = NaN. Use 1 instead so the output stays zero. Any
                    # non-zero peak is left untouched.
                    peak = peak.masked_fill(peak == 0, 1.0)
                    normalized.append((w / peak * 0.9).T.cpu().numpy())
                waves = normalized  # list[(sample,batch)]
            else:
                waves = [w.T.cpu().numpy() for w in waves]
            # Only keys[0] is written: batch_size is restricted to 1 above.
            for spk, w in enumerate(waves):
                writers[spk][keys[0]] = fs, w
    finally:
        for writer in writers:
            writer.close()