def _supported_lms(
    draw, return_kwargs: bool = False
) -> Union[st.SearchStrategy[empty_pb2.Empty],
           st.SearchStrategy[Tuple[empty_pb2.Empty, Dict]], ]:
    """Returns a SearchStrategy for supported_lms plus maybe the kwargs.

    Args:
        draw: hypothesis draw function (supplied by ``st.composite``).
        return_kwargs: when True, also return the kwargs dict used to
            build the message.

    Raises:
        ValueError: if the sampled oneof field is not supported by this
            test helper.
    """
    kwargs: Dict = {}
    descript = language_model_pb2.LanguageModel.DESCRIPTOR
    lm_type_str = draw(
        st.sampled_from(
            [f.name for f in descript.oneofs_by_name["supported_lms"].fields]))

    # get kwargs for chosen lm_type
    if lm_type_str == "no_lm":
        lm_type = empty_pb2.Empty
    else:
        # bug fix: the original interpolated ``lm_type`` here, a name that
        # is unbound on this path, so the intended ValueError surfaced as a
        # NameError. Use the field-name string that was actually sampled.
        raise ValueError(f"test does not support generation of {lm_type_str}")

    # initialise lm_type and return
    all_fields_set(lm_type, kwargs)
    lm = lm_type(**kwargs)  # type: ignore
    if not return_kwargs:
        return lm
    return lm, kwargs
def sgds(
    draw, return_kwargs: bool = False
) -> Union[st.SearchStrategy[optimizer_pb2.SGD],
           st.SearchStrategy[Tuple[optimizer_pb2.SGD, Dict]], ]:
    """Returns a SearchStrategy for an SGD plus maybe the kwargs."""
    kwargs: Dict = {}

    kwargs["learning_rate"] = draw(
        st.floats(min_value=1e-4, max_value=1.0, allow_nan=False))
    kwargs["nesterov_momentum"] = draw(st.booleans())
    # momentum must be strictly positive when nesterov momentum is enabled
    momentum_floor = 0.1 if kwargs["nesterov_momentum"] else 0.0
    kwargs["momentum"] = FloatValue(value=draw(
        st.floats(min_value=momentum_floor, max_value=10.0, allow_nan=False)))
    kwargs["l2_weight_decay"] = FloatValue(value=draw(
        st.floats(min_value=0.0, max_value=10.0, allow_nan=False)))

    # build the SGD message and return it
    all_fields_set(optimizer_pb2.SGD, kwargs)
    sgd = optimizer_pb2.SGD(**kwargs)
    return (sgd, kwargs) if return_kwargs else sgd
def task_configs(
    draw, return_kwargs: bool = False
) -> Union[st.SearchStrategy[task_config_pb2.TaskConfig],
           st.SearchStrategy[Tuple[task_config_pb2.TaskConfig, Dict]], ]:
    """Returns a SearchStrategy for a TaskConfig plus maybe the kwargs."""
    kwargs: Dict = {}
    descript = task_config_pb2.TaskConfig.DESCRIPTOR

    # choose one of the supported_models oneof fields
    model_names = [
        f.name for f in descript.oneofs_by_name["supported_models"].fields
    ]
    model_str = draw(st.sampled_from(model_names))
    if model_str != "speech_to_text":
        raise ValueError(f"unknown model type {model_str}")
    kwargs[model_str] = draw(speech_to_texts())

    # training and evaluation configs
    kwargs["train_config"] = draw(train_configs())
    kwargs["eval_config"] = draw(eval_configs())

    # build the TaskConfig message and return it
    all_fields_set(task_config_pb2.TaskConfig, kwargs)
    task_config = task_config_pb2.TaskConfig(**kwargs)
    return (task_config, kwargs) if return_kwargs else task_config
def activations(
    draw, return_kwargs: bool = False
) -> Union[st.SearchStrategy[activation_pb2.Activation],
           st.SearchStrategy[Tuple[activation_pb2.Activation, Dict]], ]:
    """Returns a SearchStrategy for activation fns plus maybe the kwargs."""
    kwargs: Dict = {}
    descript = activation_pb2.Activation.DESCRIPTOR
    choice = draw(
        st.sampled_from(
            [f.name for f in descript.oneofs_by_name["activation"].fields]))

    if choice == "identity":
        kwargs["identity"] = empty_pb2.Empty()
    elif choice == "hardtanh":
        # drawn ranges keep min_val strictly negative and max_val positive
        kwargs["hardtanh"] = activation_pb2.Activation.Hardtanh(
            min_val=draw(st.floats(-20.0, -0.1, allow_nan=False)),
            max_val=draw(st.floats(0.1, 20.0, allow_nan=False)),
        )
    elif choice == "relu":
        kwargs["relu"] = activation_pb2.Activation.ReLU()
    else:
        raise ValueError(f"test does not support activation={choice}")

    all_fields_set(activation_pb2.Activation, kwargs)
    act = activation_pb2.Activation(**kwargs)
    return (act, kwargs) if return_kwargs else act
def adams(
    draw, return_kwargs: bool = False
) -> Union[st.SearchStrategy[optimizer_pb2.Adam],
           st.SearchStrategy[Tuple[optimizer_pb2.Adam, Dict]], ]:
    """Returns a SearchStrategy for an Adam optimizer plus maybe the kwargs."""
    kwargs: Dict = {}

    kwargs["learning_rate"] = draw(
        st.floats(min_value=1e-4, max_value=1.0, allow_nan=False))
    # moment-estimate decay rates and numerical-stability epsilon
    kwargs["beta_1"] = FloatValue(value=draw(
        st.floats(min_value=0.0, max_value=1.0, allow_nan=False)))
    kwargs["beta_2"] = FloatValue(value=draw(
        st.floats(min_value=0.0, max_value=1.0, allow_nan=False)))
    kwargs["eps"] = FloatValue(value=draw(
        st.floats(min_value=0.0, max_value=1e-3, allow_nan=False)))
    kwargs["l2_weight_decay"] = FloatValue(value=draw(
        st.floats(min_value=0.0, max_value=10.0, allow_nan=False)))
    kwargs["amsgrad"] = draw(st.booleans())

    # build the Adam message and return it
    all_fields_set(optimizer_pb2.Adam, kwargs)
    adam = optimizer_pb2.Adam(**kwargs)
    return (adam, kwargs) if return_kwargs else adam
def pre_process_steps(
    draw, return_kwargs: bool = False
) -> Union[st.SearchStrategy[pre_process_step_pb2.PreProcessStep],
           st.SearchStrategy[Tuple[pre_process_step_pb2.PreProcessStep,
                                   Dict]], ]:
    """Returns a SearchStrategy for a pre_process_step + maybe the kwargs."""
    kwargs: Dict = {}
    kwargs["stage"] = draw(stages())

    descript = pre_process_step_pb2.PreProcessStep.DESCRIPTOR
    step_name = draw(
        st.sampled_from([
            f.name
            for f in descript.oneofs_by_name["pre_process_step"].fields
        ]))

    # map each oneof field name to the strategy that builds its message
    step_strategies = {
        "mfcc": _mfccs,
        "spec_augment": _spec_augments,
        "standardize": _standardizes,
        "context_frames": _context_frames,
    }
    if step_name not in step_strategies:
        raise ValueError(f"unknown pre_process_step type {step_name}")
    kwargs[step_name] = draw(step_strategies[step_name]())

    # build the PreProcessStep message and return it
    all_fields_set(pre_process_step_pb2.PreProcessStep, kwargs)
    step = pre_process_step_pb2.PreProcessStep(**kwargs)
    return (step, kwargs) if return_kwargs else step
def rnns(
    draw, return_kwargs: bool = False
) -> Union[st.SearchStrategy[rnn_pb2.RNN],
           st.SearchStrategy[Tuple[rnn_pb2.RNN, Dict]]]:
    """Returns a SearchStrategy for RNN plus maybe the kwargs."""
    kwargs: Dict = {}
    kwargs["rnn_type"] = draw(st.sampled_from(rnn_pb2.RNN.RNN_TYPE.values()))
    kwargs["hidden_size"] = draw(st.integers(1, 32))
    kwargs["num_layers"] = draw(st.integers(1, 4))
    kwargs["bias"] = draw(st.booleans())
    kwargs["bidirectional"] = draw(st.booleans())

    # forget_gate_bias only applies to an LSTM that actually has a bias
    skipped: List[str] = []
    is_biased_lstm = (kwargs["rnn_type"] == rnn_pb2.RNN.RNN_TYPE.LSTM
                      and kwargs["bias"])
    if is_biased_lstm:
        kwargs["forget_gate_bias"] = FloatValue(value=draw(
            st.floats(min_value=-10.0, max_value=10.0, allow_nan=False)))
    else:
        skipped = ["forget_gate_bias"]

    all_fields_set(rnn_pb2.RNN, kwargs, skipped)
    rnn = rnn_pb2.RNN(**kwargs)
    return (rnn, kwargs) if return_kwargs else rnn
def fully_connecteds(
    draw, return_kwargs: bool = False, valid_only: bool = False
) -> Union[st.SearchStrategy[fully_connected_pb2.FullyConnected],
           st.SearchStrategy[Tuple[fully_connected_pb2.FullyConnected,
                                   Dict]], ]:
    """Returns a SearchStrategy for a FC layer plus maybe the kwargs."""
    kwargs: Dict = {}
    kwargs["num_hidden_layers"] = draw(st.integers(0, 3))

    # with zero hidden layers, a "valid" FC block carries no hidden size,
    # an identity activation, and no dropout
    no_hidden = valid_only and kwargs["num_hidden_layers"] == 0
    if no_hidden:
        kwargs["hidden_size"] = None
        kwargs["activation"] = activation_pb2.Activation(
            identity=empty_pb2.Empty())
        kwargs["dropout"] = None
    else:
        kwargs["hidden_size"] = draw(st.integers(1, 32))
        kwargs["activation"] = draw(activations())
        kwargs["dropout"] = FloatValue(value=draw(
            st.one_of(st.none(), st.floats(min_value=0.0, max_value=1.0))))

    all_fields_set(fully_connected_pb2.FullyConnected, kwargs)
    fc = fully_connected_pb2.FullyConnected(**kwargs)
    return (fc, kwargs) if return_kwargs else fc
def datasets(
    draw, return_kwargs: bool = False
) -> Union[st.SearchStrategy[dataset_pb2.Dataset],
           st.SearchStrategy[Tuple[dataset_pb2.Dataset, Dict]], ]:
    """Returns a SearchStrategy for a Dataset plus maybe the kwargs."""
    kwargs: Dict = {}
    desc = dataset_pb2.Dataset.DESCRIPTOR
    dataset_name = draw(
        st.sampled_from(
            [f.name for f in desc.oneofs_by_name["supported_datasets"].fields]))

    if dataset_name == "fake_speech_to_text":
        # audio-length range: upper bound at most 4x the lower bound
        ms_lo = draw(st.integers(1, 1000))
        ms_hi = draw(st.integers(ms_lo, 4 * ms_lo))
        # label alphabet and label-length range (same 4x construction)
        symbols = "".join(draw(random_alphabet(min_size=2)).symbols)
        len_lo = draw(st.integers(1, 1000))
        len_hi = draw(st.integers(len_lo, 4 * len_lo))
        kwargs["fake_speech_to_text"] = dataset_pb2.Dataset.FakeSpeechToText(
            dataset_len=draw(st.integers(1, 100)),
            audio_ms=range_pb2.Range(lower=ms_lo, upper=ms_hi),
            label_symbols=symbols,
            label_len=range_pb2.Range(lower=len_lo, upper=len_hi),
        )
    elif dataset_name == "librispeech":
        warnings.warn("librispeech dataset not supported")
        assume(False)
    else:
        raise ValueError(
            f"test does not support generation of {dataset_name}")

    # build the Dataset message and return it
    all_fields_set(dataset_pb2.Dataset, kwargs)
    dataset = dataset_pb2.Dataset(**kwargs)  # type: ignore
    return (dataset, kwargs) if return_kwargs else dataset
def constant_lrs(
    draw, return_kwargs: bool = False
) -> Union[st.SearchStrategy[lr_scheduler_pb2.ConstantLR],
           st.SearchStrategy[Tuple[lr_scheduler_pb2.ConstantLR, Dict]], ]:
    """Returns a SearchStrategy for an ConstantLR plus maybe the kwargs."""
    # no kwargs are drawn for ConstantLR
    kwargs: Dict = {}
    all_fields_set(lr_scheduler_pb2.ConstantLR, kwargs)
    constant_lr = lr_scheduler_pb2.ConstantLR(**kwargs)
    return (constant_lr, kwargs) if return_kwargs else constant_lr
def _standardizes(
    draw, return_kwargs: bool = False
) -> Union[st.SearchStrategy[pre_process_step_pb2.Standardize],
           st.SearchStrategy[Tuple[pre_process_step_pb2.Standardize, Dict]], ]:
    """Returns a SearchStrategy for Standardizes plus maybe the kwargs."""
    # no kwargs are drawn for Standardize
    kwargs: Dict = {}
    all_fields_set(pre_process_step_pb2.Standardize, kwargs)
    std = pre_process_step_pb2.Standardize(**kwargs)  # type: ignore
    return (std, kwargs) if return_kwargs else std
def lookaheads(
    draw, return_kwargs: bool = False
) -> Union[st.SearchStrategy[lookahead_pb2.Lookahead],
           st.SearchStrategy[Tuple[lookahead_pb2.Lookahead, Dict]], ]:
    """Returns a SearchStrategy for a lookahead layer plus maybe the kwargs."""
    kwargs: Dict = {"context": draw(st.integers(1, 32))}
    all_fields_set(lookahead_pb2.Lookahead, kwargs)
    lookahead = lookahead_pb2.Lookahead(**kwargs)
    return (lookahead, kwargs) if return_kwargs else lookahead
def _context_frames(
    draw, return_kwargs: bool = False
) -> Union[st.SearchStrategy[pre_process_step_pb2.ContextFrames],
           st.SearchStrategy[Tuple[pre_process_step_pb2.ContextFrames,
                                   Dict]], ]:
    """Returns a SearchStrategy for ContextFrames plus maybe the kwargs."""
    kwargs: Dict = {"n_context": draw(st.integers(1, 18))}
    all_fields_set(pre_process_step_pb2.ContextFrames, kwargs)
    cf = pre_process_step_pb2.ContextFrames(**kwargs)  # type: ignore
    return (cf, kwargs) if return_kwargs else cf
def eval_configs(
    draw, return_kwargs: bool = False
) -> Union[st.SearchStrategy[eval_config_pb2.EvalConfig],
           st.SearchStrategy[Tuple[eval_config_pb2.EvalConfig, Dict]], ]:
    """Returns a SearchStrategy for a EvalConfig plus maybe the kwargs."""
    kwargs: Dict = {}
    kwargs["batch_size"] = draw(st.integers(min_value=1, max_value=128))
    kwargs["dataset"] = draw(datasets())

    # build the EvalConfig message and return it
    all_fields_set(eval_config_pb2.EvalConfig, kwargs)
    eval_config = eval_config_pb2.EvalConfig(**kwargs)
    return (eval_config, kwargs) if return_kwargs else eval_config
def exponential_lrs(
    draw, return_kwargs: bool = False
) -> Union[st.SearchStrategy[lr_scheduler_pb2.ExponentialLR],
           st.SearchStrategy[Tuple[lr_scheduler_pb2.ExponentialLR, Dict]], ]:
    """Returns a SearchStrategy for an ExponentialLR plus maybe the kwargs."""
    kwargs: Dict = {}
    kwargs["gamma"] = draw(
        st.floats(min_value=1e-4, max_value=1.0, allow_nan=False))

    # build the ExponentialLR message and return it
    all_fields_set(lr_scheduler_pb2.ExponentialLR, kwargs)
    exponential_lr = lr_scheduler_pb2.ExponentialLR(**kwargs)
    return (exponential_lr, kwargs) if return_kwargs else exponential_lr
def ctc_beam_decoders(
    draw,
    return_kwargs: bool = False,
    alphabet_len: Optional[int] = None,
    blank_index: Optional[int] = None,
) -> Union[st.SearchStrategy[ctc_beam_decoder_pb2.CTCBeamDecoder],
           st.SearchStrategy[Tuple[ctc_beam_decoder_pb2.CTCBeamDecoder,
                                   Dict]], ]:
    """Returns a SearchStrategy for CTCBeamDecoder plus maybe the kwargs."""
    kwargs: Dict = {}

    # largest valid symbol index; defaults to 100 when no alphabet is given
    max_index = 100 if alphabet_len is None else max(0, alphabet_len - 1)

    if blank_index is None:
        kwargs["blank_index"] = draw(st.integers(0, max_index))
    else:
        kwargs["blank_index"] = blank_index

    kwargs["beam_width"] = draw(st.integers(1, 2048))
    kwargs["prune_threshold"] = draw(st.floats(0.0, 1.0, allow_nan=False))

    kwargs["language_model"] = draw(language_models())
    # lm_weight only makes sense when an actual language model is present
    if not isinstance(kwargs["language_model"], empty_pb2.Empty):
        kwargs["lm_weight"] = FloatValue(value=draw(
            st.floats(allow_nan=False, allow_infinity=False)))

    # NOTE(review): if alphabet_len == 1, every index equals blank_index and
    # this filter can never succeed -- confirm callers pass >= 2 symbols
    kwargs["separator_index"] = UInt32Value(value=draw(
        st.integers(0, max_index).filter(
            lambda idx: idx != kwargs["blank_index"])))
    kwargs["word_weight"] = draw(
        st.floats(allow_nan=False, allow_infinity=False))

    # build the CTCBeamDecoder message and return it
    all_fields_set(ctc_beam_decoder_pb2.CTCBeamDecoder, kwargs)
    ctc_beam_decoder = ctc_beam_decoder_pb2.CTCBeamDecoder(**kwargs)
    return (ctc_beam_decoder, kwargs) if return_kwargs else ctc_beam_decoder
def train_configs(
    draw, return_kwargs: bool = False
) -> Union[st.SearchStrategy[train_config_pb2.TrainConfig],
           st.SearchStrategy[Tuple[train_config_pb2.TrainConfig, Dict]], ]:
    """Returns a SearchStrategy for a TrainConfig + maybe the kwargs."""
    kwargs: Dict = {}
    kwargs["batch_size"] = draw(st.integers(min_value=1, max_value=128))
    kwargs["epochs"] = draw(st.integers(min_value=1, max_value=128))

    descript = train_config_pb2.TrainConfig.DESCRIPTOR

    # optimizer: one of the supported_optimizers oneof fields
    optim_str = draw(
        st.sampled_from([
            f.name
            for f in descript.oneofs_by_name["supported_optimizers"].fields
        ]))
    if optim_str == "sgd":
        kwargs[optim_str] = draw(sgds())
    elif optim_str == "adam":
        kwargs[optim_str] = draw(adams())
    else:
        raise ValueError(f"unknown optim type {optim_str}")

    kwargs["dataset"] = draw(datasets())

    # shuffle: one of the supported_shuffles oneof fields
    shuffle_str = draw(
        st.sampled_from([
            f.name
            for f in descript.oneofs_by_name["supported_shuffles"].fields
        ]))
    if shuffle_str != "shuffle_batches_before_every_epoch":
        raise ValueError(f"unknown shuffle type {shuffle_str}")
    kwargs[shuffle_str] = draw(st.booleans())

    # build the TrainConfig message and return it
    all_fields_set(train_config_pb2.TrainConfig, kwargs)
    train_config = train_config_pb2.TrainConfig(**kwargs)
    return (train_config, kwargs) if return_kwargs else train_config
def conv1ds(
    draw, return_kwargs: bool = False
) -> Union[st.SearchStrategy[conv_layer_pb2.Conv1d],
           st.SearchStrategy[Tuple[conv_layer_pb2.Conv1d, Dict]], ]:
    """Returns a SearchStrategy for a Conv1d layer plus maybe the kwargs."""
    kwargs: Dict = {}
    kwargs["output_channels"] = draw(st.integers(1, 32))
    kwargs["kernel_time"] = draw(st.integers(1, 7))
    kwargs["stride_time"] = draw(st.integers(1, 7))
    kwargs["padding_mode"] = draw(padding_modes())
    kwargs["bias"] = draw(st.booleans())

    # build the Conv1d message and return it
    all_fields_set(conv_layer_pb2.Conv1d, kwargs)
    conv1d = conv_layer_pb2.Conv1d(**kwargs)
    return (conv1d, kwargs) if return_kwargs else conv1d
def _mfccs(
    draw, return_kwargs: bool = False
) -> Union[st.SearchStrategy[pre_process_step_pb2.MFCC],
           st.SearchStrategy[Tuple[pre_process_step_pb2.MFCC, Dict]], ]:
    """Returns a SearchStrategy for MFCCs plus maybe the kwargs."""
    kwargs: Dict = {}
    kwargs["n_mfcc"] = draw(st.integers(1, 128))
    kwargs["win_length"] = draw(st.integers(100, 400))
    # hop length is capped by the window length just drawn
    kwargs["hop_length"] = draw(st.integers(50, kwargs["win_length"]))
    kwargs["legacy"] = draw(st.booleans())

    # build the MFCC message and return it
    all_fields_set(pre_process_step_pb2.MFCC, kwargs)
    mfcc = pre_process_step_pb2.MFCC(**kwargs)  # type: ignore
    return (mfcc, kwargs) if return_kwargs else mfcc
def _spec_augments(
    draw, return_kwargs: bool = False
) -> Union[st.SearchStrategy[pre_process_step_pb2.SpecAugment],
           st.SearchStrategy[Tuple[pre_process_step_pb2.SpecAugment, Dict]], ]:
    """Returns a SearchStrategy for SpecAugments plus maybe the kwargs."""
    kwargs: Dict = {}
    kwargs["feature_mask"] = draw(st.integers(0, 80))
    kwargs["time_mask"] = draw(st.integers(0, 100))
    kwargs["n_feature_masks"] = draw(st.integers(0, 3))
    kwargs["n_time_masks"] = draw(st.integers(0, 3))

    # build the SpecAugment message and return it
    all_fields_set(pre_process_step_pb2.SpecAugment, kwargs)
    spec_augment = pre_process_step_pb2.SpecAugment(**kwargs)  # type: ignore
    return (spec_augment, kwargs) if return_kwargs else spec_augment
def step_lrs(
    draw, return_kwargs: bool = False
) -> Union[st.SearchStrategy[lr_scheduler_pb2.StepLR],
           st.SearchStrategy[Tuple[lr_scheduler_pb2.StepLR, Dict]], ]:
    """Returns a SearchStrategy for an StepLR plus maybe the kwargs."""
    kwargs: Dict = {}
    kwargs["step_size"] = draw(st.integers(1, 30))
    kwargs["gamma"] = FloatValue(value=draw(
        st.floats(min_value=0.1, max_value=1.0, allow_nan=False)))

    # build the StepLR message and return it
    all_fields_set(lr_scheduler_pb2.StepLR, kwargs)
    step_lr = lr_scheduler_pb2.StepLR(**kwargs)
    return (step_lr, kwargs) if return_kwargs else step_lr
def deep_speech_2s(
    draw, return_kwargs: bool = False
) -> Union[st.SearchStrategy[deep_speech_2_pb2.DeepSpeech2],
           st.SearchStrategy[Tuple[deep_speech_2_pb2.DeepSpeech2, Dict]], ]:
    """Returns a SearchStrategy for DeepSpeech2 plus maybe the kwargs."""
    kwargs: Dict = {}
    kwargs["conv_block"] = draw(_conv_blocks())
    kwargs["rnn"] = draw(rnns())
    kwargs["lookahead_block"] = draw(_lookahead_blocks())
    # restrict to FC configurations flagged as valid
    kwargs["fully_connected"] = draw(fully_connecteds(valid_only=True))

    # build the DeepSpeech2 message and return it
    all_fields_set(deep_speech_2_pb2.DeepSpeech2, kwargs)
    ds2 = deep_speech_2_pb2.DeepSpeech2(**kwargs)  # type: ignore
    return (ds2, kwargs) if return_kwargs else ds2
def cosine_annealing_lrs(
    draw, return_kwargs: bool = False
) -> Union[st.SearchStrategy[lr_scheduler_pb2.CosineAnnealingLR],
           st.SearchStrategy[Tuple[lr_scheduler_pb2.CosineAnnealingLR,
                                   Dict]], ]:
    """Returns a SearchStrategy for an CosineAnnealingLR plus maybe the kwargs.
    """
    kwargs: Dict = {}
    kwargs["t_max"] = draw(st.integers(1, 30))
    kwargs["eta_min"] = FloatValue(value=draw(
        st.floats(min_value=1e-9, max_value=1e-3, allow_nan=False)))

    # build the CosineAnnealingLR message and return it
    all_fields_set(lr_scheduler_pb2.CosineAnnealingLR, kwargs)
    cosine_annealing_lr = lr_scheduler_pb2.CosineAnnealingLR(**kwargs)
    return ((cosine_annealing_lr, kwargs)
            if return_kwargs else cosine_annealing_lr)
def ctc_losses(
    draw, return_kwargs: bool = False, alphabet_len: Optional[int] = None
) -> Union[st.SearchStrategy[ctc_loss_pb2.CTCLoss],
           st.SearchStrategy[Tuple[ctc_loss_pb2.CTCLoss, Dict]], ]:
    """Returns a SearchStrategy for CTCLoss plus maybe the kwargs."""
    kwargs: Dict = {}
    # blank_index must stay within the alphabet when its length is known
    max_blank = 1000 if alphabet_len is None else max(0, alphabet_len - 1)
    kwargs["blank_index"] = draw(st.integers(0, max_blank))
    kwargs["reduction"] = draw(
        st.sampled_from(ctc_loss_pb2.CTCLoss.REDUCTION.values()))

    # build the CTCLoss message and return it
    all_fields_set(ctc_loss_pb2.CTCLoss, kwargs)
    ctc_loss = ctc_loss_pb2.CTCLoss(**kwargs)
    return (ctc_loss, kwargs) if return_kwargs else ctc_loss
def deep_speech_1s(
    draw, return_kwargs: bool = False
) -> Union[st.SearchStrategy[deep_speech_1_pb2.DeepSpeech1],
           st.SearchStrategy[Tuple[deep_speech_1_pb2.DeepSpeech1, Dict]], ]:
    """Returns a SearchStrategy for DeepSpeech1 plus maybe the kwargs."""
    kwargs: Dict = {}
    kwargs["n_hidden"] = draw(st.integers(1, 128))
    kwargs["drop_prob"] = draw(st.floats(0.0, 1.0, allow_nan=False))
    kwargs["relu_clip"] = draw(st.floats(1.0, 20.0))
    kwargs["forget_gate_bias"] = draw(st.floats(0.0, 1.0))
    kwargs["hard_lstm"] = draw(st.booleans())

    # build the DeepSpeech1 message and return it
    all_fields_set(deep_speech_1_pb2.DeepSpeech1, kwargs)
    ds1 = deep_speech_1_pb2.DeepSpeech1(**kwargs)  # type: ignore
    return (ds1, kwargs) if return_kwargs else ds1
def language_models(
    draw, return_kwargs: bool = False
) -> Union[st.SearchStrategy[Union[empty_pb2.Empty,
                                   Callable[[List[int]], float]]],
           st.SearchStrategy[Tuple[Union[empty_pb2.Empty,
                                         Callable[[List[int]], float]],
                                   Dict]], ]:
    """Returns a SearchStrategy for language models plus maybe the kwargs."""
    kwargs: Dict = {}

    # draw the supported_lms oneof value and key it by its field name
    supported_lm = draw(_supported_lms())
    if not isinstance(supported_lm, empty_pb2.Empty):
        raise ValueError(f"unknown lm type {type(supported_lm)}")
    kwargs["no_lm"] = supported_lm

    # build the LanguageModel message and return it
    all_fields_set(language_model_pb2.LanguageModel, kwargs)
    lm = language_model_pb2.LanguageModel(**kwargs)
    return (lm, kwargs) if return_kwargs else lm
def speech_to_texts(
    draw, return_kwargs: bool = False
) -> Union[st.SearchStrategy[speech_to_text_pb2.SpeechToText],
           st.SearchStrategy[Tuple[speech_to_text_pb2.SpeechToText, Dict]], ]:
    """Returns a SearchStrategy for a SpeechToText model + maybe the kwargs."""
    kwargs: Dict = {}
    kwargs["alphabet"] = "".join(draw(random_alphabet(min_size=2)).symbols)

    descript = speech_to_text_pb2.SpeechToText.DESCRIPTOR

    # optionally attach a single pre-processing step
    kwargs["pre_process_step"] = []
    if draw(st.booleans()):
        kwargs["pre_process_step"].append(draw(pre_process_steps()))
    # record input_features and input_channels to ensure built model is valid
    _, input_features, input_channels = _build_pre_process_steps(
        kwargs["pre_process_step"])

    # model: one of the supported_models oneof fields
    model_str = draw(
        st.sampled_from([
            f.name for f in descript.oneofs_by_name["supported_models"].fields
        ]))
    if model_str == "deep_speech_1":
        kwargs[model_str] = draw(deep_speech_1s())
    elif model_str == "deep_speech_2":
        kwargs[model_str] = draw(deep_speech_2s())
        warnings.warn(
            "TODO: fix hack that assumes input_features > 200 for deep_speech_2"
        )
        assume(input_features > 200)
    else:
        raise ValueError(f"unknown model type {model_str}")

    # blank index shared between the CTC loss and the CTC decoders
    ctc_blank_index: Optional[int] = None

    # loss: one of the supported_losses oneof fields
    loss_str = draw(
        st.sampled_from([
            f.name for f in descript.oneofs_by_name["supported_losses"].fields
        ]))
    if loss_str != "ctc_loss":
        raise ValueError(f"unknown loss type {loss_str}")
    kwargs["ctc_loss"] = draw(
        ctc_losses(alphabet_len=len(kwargs["alphabet"])))
    ctc_blank_index = kwargs["ctc_loss"].blank_index

    # post process: one of the supported_post_processes oneof fields
    post_str = draw(
        st.sampled_from([
            f.name
            for f in descript.oneofs_by_name["supported_post_processes"].fields
        ]))
    if post_str == "ctc_greedy_decoder":
        if ctc_blank_index is None:
            ctc_blank_index = draw(
                st.integers(0, max(0, len(kwargs["alphabet"]) - 1)))
        kwargs["ctc_greedy_decoder"] = ctc_greedy_decoder_pb2.CTCGreedyDecoder(
            blank_index=ctc_blank_index)
    elif post_str == "ctc_beam_decoder":
        beam_kwargs: Dict = {"alphabet_len": len(kwargs["alphabet"])}
        if ctc_blank_index is not None:
            beam_kwargs["blank_index"] = ctc_blank_index
        kwargs["ctc_beam_decoder"] = draw(ctc_beam_decoders(**beam_kwargs))
    else:
        raise ValueError(f"unknown post_process type {post_str}")

    # build the SpeechToText message and return it
    all_fields_set(speech_to_text_pb2.SpeechToText, kwargs)
    speech_to_text = speech_to_text_pb2.SpeechToText(**kwargs)  # type: ignore
    return (speech_to_text, kwargs) if return_kwargs else speech_to_text