# Example 1
def main(cmd=None):
    r"""Command-line entry point for speaker diarization training.

    Delegates straight to ``DiarizationTask.main``; ``cmd`` is an optional
    list of argument strings (``None`` means read ``sys.argv``).

    Example:
        % python diar_train.py diar --print_config --optim adadelta \
                > conf/train_diar.yaml
        % python diar_train.py --config conf/train_diar.yaml
    """
    DiarizationTask.main(cmd=cmd)
# Example 2
def config_file(tmp_path: Path):
    """Write a default diarization training config into ``tmp_path``.

    Runs the task in dry-run mode so only the config is emitted, then
    returns the path of the generated ``config.yaml``.
    """
    cli_args = [
        "--dry_run", "true",
        "--output_dir", str(tmp_path),
        "--num_spk", "2",
    ]
    DiarizationTask.main(cmd=cli_args)
    return tmp_path / "config.yaml"
# Example 3
def diar_config_file2(tmp_path: Path):
    """Write a diarization training config that uses an RNN attractor.

    Same dry-run pattern as the default fixture, but with
    ``--attractor rnn`` and a 256-unit attractor configuration.
    Returns the path of the generated ``config.yaml``.
    """
    cli_args = [
        "--dry_run", "true",
        "--output_dir", str(tmp_path),
        "--attractor", "rnn",
        "--attractor_conf", "unit=256",
        "--num_spk", "2",
    ]
    DiarizationTask.main(cmd=cli_args)
    return tmp_path / "config.yaml"
# Example 4
    def __init__(
        self,
        # PEP 484: parameters defaulting to None must be Optional; the
        # original implicit-Optional annotations contradicted the defaults
        # (and the check_argument_types() runtime check below).
        train_config: Optional[Union[Path, str]] = None,
        model_file: Optional[Union[Path, str]] = None,
        segment_size: Optional[float] = None,
        normalize_segment_scale: bool = False,
        show_progressbar: bool = False,
        num_spk: Optional[int] = None,
        device: str = "cpu",
        dtype: str = "float32",
    ):
        """Load a trained diarization model and prepare it for inference.

        Args:
            train_config: Path to the training config yaml (may be None if
                the model file embeds it — TODO confirm against
                ``build_model_from_file``).
            model_file: Path to the trained model parameters.
            segment_size: Segment length in seconds; when given, long inputs
                are diarized segment by segment.
            normalize_segment_scale: Whether to normalize per-segment scale.
            show_progressbar: Whether to display a progress bar.
            num_spk: Fixed number of speakers; ``None`` enables speaker
                number prediction during inference.
            device: Torch device string, e.g. "cpu" or "cuda".
            dtype: Torch dtype name used for the model, e.g. "float32".
        """
        assert check_argument_types()

        # 1. Build Diar model and switch it to eval mode in the target dtype
        diar_model, diar_train_args = DiarizationTask.build_model_from_file(
            train_config, model_file, device)
        diar_model.to(dtype=getattr(torch, dtype)).eval()

        self.device = device
        self.dtype = dtype
        self.diar_train_args = diar_train_args
        self.diar_model = diar_model

        # only used when processing long speech, i.e.
        # segment_size is not None and hop_size is not None
        self.segment_size = segment_size
        self.normalize_segment_scale = normalize_segment_scale
        self.show_progressbar = show_progressbar
        # not specifying "num_spk" in inference config file
        # will enable speaker number prediction during inference
        self.num_spk = num_spk

        # Segment-wise processing is enabled purely by segment_size presence.
        self.segmenting = segment_size is not None
        if self.segmenting:
            logging.info("Perform segment-wise speaker diarization")
            logging.info("Segment length = {} sec".format(segment_size))
        else:
            logging.info("Perform direct speaker diarization on the input")
# Example 5
def inference(
    output_dir: str,
    batch_size: int,
    dtype: str,
    fs: int,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    model_tag: Optional[str],
    allow_variable_data_keys: bool,
    segment_size: Optional[float],
    show_progressbar: bool,
):
    """Run speaker diarization inference over a dataset.

    Builds a ``DiarizeSpeech`` instance from the given model, streams the
    data through it one batch at a time, and writes the per-utterance
    speaker predictions as ``.npy`` files indexed by ``diarize.scp`` under
    ``output_dir``.

    Raises:
        NotImplementedError: If ``batch_size > 1`` or ``ngpu > 1`` —
            only single-utterance, single-GPU decoding is supported.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1:
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build separate_speech
    diarize_speech_kwargs = dict(
        train_config=train_config,
        model_file=model_file,
        segment_size=segment_size,
        show_progressbar=show_progressbar,
        device=device,
        dtype=dtype,
    )
    diarize_speech = DiarizeSpeech.from_pretrained(
        model_tag=model_tag,
        **diarize_speech_kwargs,
    )

    # 3. Build data-iterator
    loader = DiarizationTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=DiarizationTask.build_preprocess_fn(
            diarize_speech.diar_train_args, False),
        collate_fn=DiarizationTask.build_collate_fn(
            diarize_speech.diar_train_args, False),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )

    # 4. Start for-loop
    writer = NpyScpWriter(f"{output_dir}/predictions",
                          f"{output_dir}/diarize.scp")

    # Ensure the scp writer is closed even if diarization raises mid-loop;
    # the original leaked the writer (and any buffered entries) on error.
    try:
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # Length tensors are not model inputs here; drop them.
            batch = {
                k: v for k, v in batch.items() if not k.endswith("_lengths")
            }

            spk_predictions = diarize_speech(**batch)
            # Iterate the actual batch contents rather than the configured
            # batch_size: robust if a final batch is ever short.
            for b in range(len(keys)):
                writer[keys[b]] = spk_predictions[b]
    finally:
        writer.close()
# Example 6
def get_parser():
    """Return the command-line argument parser of ``DiarizationTask``."""
    return DiarizationTask.get_parser()