def task_configs(
    draw, return_kwargs: bool = False
) -> Union[
    st.SearchStrategy[task_config_pb2.TaskConfig],
    st.SearchStrategy[Tuple[task_config_pb2.TaskConfig, Dict]],
]:
    """Draws a ``TaskConfig`` and, optionally, the kwargs used to build it.

    NOTE(review): the ``draw`` parameter suggests this is wrapped by
    ``st.composite`` outside the visible code -- confirm.

    Args:
        draw: Callable used to draw a value from a strategy.
        return_kwargs: When truthy, the kwargs dict is returned alongside
            the ``TaskConfig``.

    Returns:
        The ``TaskConfig`` on its own, or a ``(TaskConfig, kwargs)`` tuple
        when ``return_kwargs`` is truthy.

    Raises:
        ValueError: If the drawn member of the ``supported_models`` oneof is
            anything other than ``speech_to_text``.
    """
    descriptor = task_config_pb2.TaskConfig.DESCRIPTOR

    # Candidate models are the field names of the "supported_models" oneof.
    supported = [
        field.name
        for field in descriptor.oneofs_by_name["supported_models"].fields
    ]
    model_str = draw(st.sampled_from(supported))

    # Only one model type has a strategy here; reject everything else early.
    if model_str != "speech_to_text":
        raise ValueError(f"unknown model type {model_str}")

    kwargs: Dict = {model_str: draw(speech_to_texts())}
    kwargs["train_config"] = draw(train_configs())
    kwargs["eval_config"] = draw(eval_configs())

    # Presumably verifies that kwargs covers every TaskConfig field --
    # confirm against the helper's definition.
    all_fields_set(task_config_pb2.TaskConfig, kwargs)

    task_config = task_config_pb2.TaskConfig(**kwargs)
    if return_kwargs:
        return task_config, kwargs
    return task_config
def export_ds1(
    ds1_cfg_fp: str, weights_fp: str, onnx_fp: str, opset_version: int = 11
):
    """Exports a :py:class:`DeepSpeech1` model to an ONNX file.

    Args:
        ds1_cfg_fp: Filepath of the text-format task config for the
            DeepSpeech1 model.
        weights_fp: Filepath of the saved weights.
        onnx_fp: Filepath to save the ONNX file to.
        opset_version: ONNX ``opset_version`` handed to the exporter.
    """
    # Parse the task config from its text-format protobuf file.
    with open(ds1_cfg_fp) as cfg_file:
        cfg_text = cfg_file.read()
    task_config = text_format.Merge(cfg_text, task_config_pb2.TaskConfig())

    # Build the network and load its weights onto the CPU.
    net = build_stt(task_config.speech_to_text).model
    net.load_state_dict(
        state_dict=torch.load(weights_fp, map_location=torch.device("cpu")),
        strict=True,
    )

    # Generate example inputs and run the wrapped model once in eval mode to
    # obtain the example outputs passed to the exporter.
    example_inputs = gen_ds1_args(net)
    wrapped = CollapseArgsDS1(net)
    wrapped.eval()
    example_outputs = wrapped(*example_inputs)

    # Names and dynamic axes of the exported graph's inputs and outputs.
    graph_inputs = ["input", "in_lens", "h_n_in", "c_n_in"]
    graph_outputs = ["output", "out_lens", "h_n_out", "c_n_out"]
    variable_axes = {
        "input": {0: "batch", 3: "seq_len"},
        "in_lens": {0: "batch"},
        "h_n_in": {1: "batch"},
        "c_n_in": {1: "batch"},
        "output": {0: "seq_len", 1: "batch"},
        "out_lens": {0: "batch"},
        "h_n_out": {1: "batch"},
        "c_n_out": {1: "batch"},
    }

    torch.onnx.export(
        wrapped,
        example_inputs,
        onnx_fp,
        export_params=True,
        verbose=False,
        example_outputs=example_outputs,
        dynamic_axes=variable_axes,
        input_names=graph_inputs,
        output_names=graph_outputs,
        opset_version=opset_version,
    )
def test_model_in_configs_can_be_built(config_path):
    """Ensures the task config in a ``.config`` file can be built.

    The train and eval datasets are swapped out via
    ``replace_dataset_w_fake_dataset`` before building, so the build covers
    the task config **minus the real datasets**.
    """
    with open(config_path, "r") as handle:
        text = handle.read()

    task_config = text_format.Merge(text, task_config_pb2.TaskConfig())

    # Replace the train dataset first, then the eval dataset.
    for stage_config in (task_config.train_config, task_config.eval_config):
        replace_dataset_w_fake_dataset(stage_config.dataset)

    build(task_config)
def test_all_configs_build(config_path):
    """Ensures the config file at ``config_path`` parses as a ``TaskConfig``.

    The parsed message is discarded; the test relies on
    ``text_format.Merge`` raising if the text cannot be parsed.
    """
    with open(config_path, "r") as handle:
        text = handle.read()

    empty_message = task_config_pb2.TaskConfig()
    text_format.Merge(text, empty_message)
def parse(config_path: str) -> task_config_pb2.TaskConfig:
    """Parses a text-format ``TaskConfig`` from a file.

    Args:
        config_path: Filepath of the text-format config to read.

    Returns:
        The ``TaskConfig`` message populated from the file's contents.
    """
    with open(config_path) as config_file:
        text = config_file.read()
    return text_format.Merge(text, task_config_pb2.TaskConfig())
from myrtlespeech.post_process.ctc_greedy_decoder import CTCGreedyDecoder from myrtlespeech.post_process.ctc_beam_decoder import CTCBeamDecoder from myrtlespeech.builders.task_config import build from myrtlespeech.run.train import fit from myrtlespeech.protos import task_config_pb2 from myrtlespeech.run.stage import Stage os.environ["CUDA_VISIBLE_DEVICES"] = "1" torch.backends.cudnn.benchmark = False from myrtlespeech.model.cnn import MaskConv1d, MaskConv2d, PaddingMode # parse example config file with open("../src/myrtlespeech/configs/deep_speech_2_en.config") as f: task_config = text_format.Merge(f.read(), task_config_pb2.TaskConfig()) # create all components for config seq_to_seq, epochs, train_loader, eval_loader = build(task_config) class Profiler(Callback): """ nvprof -f --profile-from-start off -o trace.nvvp -- python3 script.py Read using NVIDIA Visual Profiler (nvvp) """ def on_batch_begin(self, *args, **kwargs):