def task_configs(
    draw, return_kwargs: bool = False
) -> Union[
    st.SearchStrategy[task_config_pb2.TaskConfig],
    st.SearchStrategy[Tuple[task_config_pb2.TaskConfig, Dict]],
]:
    """Returns a SearchStrategy for a TaskConfig plus maybe the kwargs."""
    kwargs: Dict = {}

    descript = task_config_pb2.TaskConfig.DESCRIPTOR

    # model
    model_str = draw(
        st.sampled_from([
            f.name for f in descript.oneofs_by_name["supported_models"].fields
        ]))
    if model_str == "speech_to_text":
        kwargs[model_str] = draw(speech_to_texts())
    else:
        raise ValueError(f"unknown model type {model_str}")

    # train config
    kwargs["train_config"] = draw(train_configs())

    # eval config
    kwargs["eval_config"] = draw(eval_configs())

    # initialise and return
    all_fields_set(task_config_pb2.TaskConfig, kwargs)
    task_config = task_config_pb2.TaskConfig(**kwargs)
    if not return_kwargs:
        return task_config
    return task_config, kwargs
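# A minimal usage sketch (not from the original source): assuming the strategy
# above is decorated with hypothesis's ``st.composite`` (as its ``draw``
# parameter suggests), it can drive a property-based test. The test name and
# round-trip check are illustrative only.
from hypothesis import given


@given(task_configs())
def test_generated_task_config_roundtrips(task_config):
    # Serialising and re-parsing should preserve the generated config.
    serialized = task_config.SerializeToString()
    assert task_config_pb2.TaskConfig.FromString(serialized) == task_config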
def export_ds1(
    ds1_cfg_fp: str, weights_fp: str, onnx_fp: str, opset_version: int = 11
):
    """Exports :py:class:`DeepSpeech1` model.

    Args:
        ds1_cfg_fp: filepath to config for DeepSpeech1 task config.

        weights_fp: filpath of weights.

        onnx_fp: filepath to save onnx file to.

        opset_version: onnx opset_version.
    """

    # Define attributes for ONNX export
    input_names = ["input", "in_lens", "h_n_in", "c_n_in"]
    output_names = ["output", "out_lens", "h_n_out", "c_n_out"]
    dynamic_axes = {
        "input": {0: "batch", 3: "seq_len"},
        "in_lens": {0: "batch"},
        "h_n_in": {1: "batch"},
        "c_n_in": {1: "batch"},
        "output": {0: "seq_len", 1: "batch"},
        "out_lens": {0: "batch"},
        "h_n_out": {1: "batch"},
        "c_n_out": {1: "batch"},
    }

    # build model and load weights
    with open(ds1_cfg_fp) as f:
        pb = task_config_pb2.TaskConfig()
        task_config = text_format.Merge(f.read(), pb)
    stt = build_stt(task_config.speech_to_text)
    ds1 = stt.model
    state_dict = torch.load(weights_fp, map_location=torch.device("cpu"))
    ds1.load_state_dict(state_dict=state_dict, strict=True)

    # gen ds1 input args
    args = gen_ds1_args(ds1)

    model = CollapseArgsDS1(ds1)

    model.eval()
    example_outputs = model(*args)
    torch.onnx.export(
        model,
        args,
        onnx_fp,
        export_params=True,
        verbose=False,
        example_outputs=example_outputs,
        dynamic_axes=dynamic_axes,
        input_names=input_names,
        output_names=output_names,
        opset_version=opset_version,
    )
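# A hedged usage sketch (not from the original source; all paths are
# illustrative): export a trained checkpoint to ONNX and sanity-check the
# resulting graph with the ``onnx`` package.
if __name__ == "__main__":
    import onnx

    export_ds1(
        ds1_cfg_fp="deep_speech_1_en.config",  # hypothetical config path
        weights_fp="ds1_weights.pt",  # hypothetical checkpoint path
        onnx_fp="ds1.onnx",
    )
    onnx_model = onnx.load("ds1.onnx")
    onnx.checker.check_model(onnx_model)  # raises if the exported graph is invalid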
def test_model_in_configs_can_be_built(config_path):
    """Ensures :py:class:`task_config` in .config file can be built.

    This attempts to build the task config **minus the dataset** which is
    replaced with fake_speech_to_text for speed.
    """
    with open(config_path, "r") as config_file:
        config = config_file.read()

    compiled = text_format.Merge(config, task_config_pb2.TaskConfig())
    replace_dataset_w_fake_dataset(compiled.train_config.dataset)
    replace_dataset_w_fake_dataset(compiled.eval_config.dataset)
    build(compiled)
def test_all_configs_build(config_path):
    """Ensures all `myrtlespeech/config/*.config` files parse."""
    with open(config_path, "r") as config_file:
        config = config_file.read()
    text_format.Merge(config, task_config_pb2.TaskConfig())
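# A hypothetical sketch (not from the original source) of how the
# ``config_path`` fixture used by the tests above could be parametrised over
# every .config file in the repository; the glob pattern is an assumption.
import glob

import pytest


@pytest.fixture(params=sorted(glob.glob("src/myrtlespeech/configs/*.config")))
def config_path(request):
    return request.param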
def parse(config_path: str) -> task_config_pb2.TaskConfig:
    """TODO"""
    with open(config_path) as f:
        task_config = text_format.Merge(f.read(), task_config_pb2.TaskConfig())
    return task_config
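# Hypothetical usage of ``parse`` (not from the original source; the config
# path is illustrative only).
cfg = parse("../src/myrtlespeech/configs/deep_speech_2_en.config")
print(cfg.WhichOneof("supported_models"))  # e.g. "speech_to_text"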
import os

import torch
from google.protobuf import text_format

from myrtlespeech.post_process.ctc_greedy_decoder import CTCGreedyDecoder
from myrtlespeech.post_process.ctc_beam_decoder import CTCBeamDecoder
from myrtlespeech.builders.task_config import build
from myrtlespeech.run.train import fit
from myrtlespeech.protos import task_config_pb2
from myrtlespeech.run.stage import Stage
from myrtlespeech.run.callbacks.callbacks import Callback  # assumed import path

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

torch.backends.cudnn.benchmark = False

from myrtlespeech.model.cnn import MaskConv1d, MaskConv2d, PaddingMode

# parse example config file
with open("../src/myrtlespeech/configs/deep_speech_2_en.config") as f:
    task_config = text_format.Merge(f.read(), task_config_pb2.TaskConfig())

# create all components for config
seq_to_seq, epochs, train_loader, eval_loader = build(task_config)


class Profiler(Callback):
    """Profiles a training run with ``nvprof``::

        nvprof -f --profile-from-start off -o trace.nvvp -- python3 script.py

    Read the resulting ``trace.nvvp`` with the NVIDIA Visual Profiler (nvvp).
    """

    def on_batch_begin(self, *args, **kwargs):