def __init__(
    self, vocab, beam_width, alpha, beta, lm_path, num_cpus, cutoff_prob=1.0, cutoff_top_n=40, input_tensor=True
):
    try:
        from ctc_decoders import Scorer
        from ctc_decoders import ctc_beam_search_decoder_batch
    except ModuleNotFoundError:
        raise ModuleNotFoundError(
            "BeamSearchDecoderWithLM requires the "
            "installation of ctc_decoders "
            "from nemo/scripts/install_decoders.py"
        )
    super().__init__()
    # Override the default placement from neural factory and set placement/device to be CPU.
    self._placement = DeviceType.CPU
    self._device = get_cuda_device(self._placement)
    if self._factory.world_size > 1:
        raise ValueError("BeamSearchDecoderWithLM does not run in distributed mode")
    self.scorer = Scorer(alpha, beta, model_path=lm_path, vocabulary=vocab)
    self.beam_search_func = ctc_beam_search_decoder_batch
    self.vocab = vocab
    self.beam_width = beam_width
    self.num_cpus = num_cpus
    self.cutoff_prob = cutoff_prob
    self.cutoff_top_n = cutoff_top_n
    self.input_tensor = input_tensor

def __init__(self, vocab, beam_width, alpha, beta, lm_path, num_cpus, cutoff_prob=1.0, cutoff_top_n=40, input_tensor=False):
    try:
        from ctc_decoders import Scorer, ctc_beam_search_decoder_batch
    except ModuleNotFoundError:
        raise ModuleNotFoundError(
            "BeamSearchDecoderWithLM requires the "
            "installation of ctc_decoders "
            "from scripts/install_ctc_decoders.sh"
        )
    super().__init__()
    if lm_path is not None:
        self.scorer = Scorer(alpha, beta, model_path=lm_path, vocabulary=vocab)
    else:
        self.scorer = None
    self.beam_search_func = ctc_beam_search_decoder_batch
    self.vocab = vocab
    self.beam_width = beam_width
    self.num_cpus = num_cpus
    self.cutoff_prob = cutoff_prob
    self.cutoff_top_n = cutoff_top_n
    self.input_tensor = input_tensor

def _decode_with_lm(filepath_to_logprobs, kenlm_model_path, vocab, beam_width, beam_search_alpha, beam_search_beta):
    filepaths = list(filepath_to_logprobs.keys())
    asr_logprobs = [filepath_to_logprobs[p] for p in filepaths]

    from ctc_decoders import Scorer, ctc_beam_search_decoder_batch

    scorer = Scorer(beam_search_alpha, beam_search_beta, model_path=kenlm_model_path, vocabulary=vocab)
    asr_probs = [softmax(logits) for logits in asr_logprobs]
    transcriptions = ctc_beam_search_decoder_batch(
        probs_split=asr_probs,
        vocabulary=vocab,
        beam_size=beam_width,
        ext_scoring_func=scorer,
        # os.cpu_count() may return None, so guard before taking the max
        num_processes=max(os.cpu_count() or 1, 1),
    )
    # Each entry is a list of (score, transcript) beams; keep the best transcript.
    transcriptions = [t[0][1] for t in transcriptions]
    filepath_to_transc = {}
    for filepath, transc in zip(filepaths, transcriptions):
        filepath_to_transc[filepath] = transc
    return filepath_to_transc

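# A minimal usage sketch for _decode_with_lm. The utterance name, logit array,
# and LM path below are hypothetical stand-ins, and `softmax`/`os` are assumed
# to be imported at module level, as in the function above.
import numpy as np

logprobs = {"utt1.wav": np.random.randn(50, 29).astype(np.float32)}  # [T, |vocab| + 1 blank]
transcripts = _decode_with_lm(
    logprobs,
    kenlm_model_path="lm.binary",                 # hypothetical KenLM binary
    vocab=list(" 'abcdefghijklmnopqrstuvwxyz"),   # 28 symbols; blank is the last logit column
    beam_width=128,
    beam_search_alpha=2.0,
    beam_search_beta=1.0,
)
print(transcripts["utt1.wav"])
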
def __init__(self, *, vocab, beam_width, alpha, beta, lm_path, num_cpus, cutoff_prob=1.0, cutoff_top_n=40, **kwargs):
    try:
        from ctc_decoders import Scorer
        from ctc_decoders import ctc_beam_search_decoder_batch
    except ModuleNotFoundError:
        raise ModuleNotFoundError(
            "BeamSearchDecoderWithLM requires the "
            "installation of ctc_decoders "
            "from nemo/scripts/install_decoders.py"
        )
    self.scorer = Scorer(alpha, beta, model_path=lm_path, vocabulary=vocab)
    self.beam_search_func = ctc_beam_search_decoder_batch
    super().__init__(
        # Override default placement from neural factory
        placement=DeviceType.CPU,
        **kwargs,
    )
    self.vocab = vocab
    self.beam_width = beam_width
    self.num_cpus = num_cpus
    self.cutoff_prob = cutoff_prob
    self.cutoff_top_n = cutoff_top_n

def language_model(model: str = 'malaya-speech', alpha: float = 2.5, beta: float = 0.3, **kwargs):
    """
    Load KenLM language model.

    Parameters
    ----------
    model : str, optional (default='malaya-speech')
        Model architecture supported. Allowed values:

        * ``'malaya-speech'`` - Gathered from malaya-speech ASR transcript.
        * ``'malaya-speech-wikipedia'`` - Gathered from malaya-speech ASR transcript + Wikipedia (random sample of 300k sentences).
        * ``'local'`` - Gathered from IIUM Confession.
        * ``'wikipedia'`` - Gathered from malay Wikipedia.
    alpha : float, optional (default=2.5)
        score = alpha * np.log(lm) + beta * np.log(word_cnt);
        increasing it puts more weight on the LM score computed by KenLM.
    beta : float, optional (default=0.3)
        score = alpha * np.log(lm) + beta * np.log(word_cnt);
        increasing it puts more weight on the word count.

    Returns
    -------
    result : ctc_decoders.Scorer
    """
    try:
        from ctc_decoders import Scorer
    except ModuleNotFoundError:
        raise ModuleNotFoundError(
            'ctc_decoders not installed. Please install it by `pip install ctc-decoders` and try again.'
        )
    from malaya_speech.utils import check_file

    check_file(PATH_LM[model], S3_PATH_LM[model], **kwargs)

    with open(PATH_LM[model]['vocab']) as fopen:
        vocab_list = json.load(fopen) + ['{', '}', '[']

    scorer = Scorer(alpha, beta, PATH_LM[model]['model'], vocab_list)
    return scorer

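# A short sketch (not from malaya-speech itself) showing how the returned scorer
# plugs into beam search. `probs` and `vocab_list` are hypothetical stand-ins for
# a [T, vocab] softmax matrix and the same vocabulary the scorer was built with.
from ctc_decoders import ctc_beam_search_decoder

scorer = language_model(model='malaya-speech', alpha=2.5, beta=0.3)
beams = ctc_beam_search_decoder(probs, vocab_list,
                                beam_size=20,
                                ext_scoring_func=scorer)
print(beams[0])  # best beam: (score, transcript)
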
def __init__(self, config, vocab, resource_path=None):
    """
    Create a new CTCBeamSearchDecoder.

    See CTCDecoder.from_config() to automatically create the correct type of
    instance depending on config.
    """
    super().__init__(config, vocab)
    self.config.setdefault('word_threshold', -1000.0)
    self.reset()
    self.scorer = None
    # self.num_cores = max(os.cpu_count(), 1)

    # Set default config, following:
    # https://github.com/NVIDIA/NeMo/blob/855ce265b80c0dc40f4f06ece76d2c9d6ca1be8d/nemo/collections/asr/modules/beam_search_decoder.py#L21
    self.config.setdefault('language_model', None)
    self.config.setdefault('beam_width', 32)  # 128
    self.config.setdefault('alpha', 0.7 if self.language_model else 0.0)
    self.config.setdefault('beta', 0.0)
    self.config.setdefault('cutoff_prob', 1.0)
    self.config.setdefault('cutoff_top_n', 40)
    self.config.setdefault('top_k', 3)

    # Check for the language model file; fall back to resource_path if needed.
    if self.language_model:
        if not os.path.isfile(self.language_model):
            self.config['language_model'] = os.path.join(resource_path, self.language_model)
        if not os.path.isfile(self.language_model):
            raise IOError(
                f"language model file '{self.language_model}' does not exist"
            )

    logging.info('creating CTCBeamSearchDecoder')
    logging.info(str(self.config))

    # Create scorer
    if self.language_model:
        self.scorer = Scorer(self.config['alpha'], self.config['beta'],
                             model_path=self.language_model,
                             vocabulary=self.vocab)

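# For reference, a config dict that exercises the defaults above might look like
# this (values are illustrative; keys mirror the setdefault calls):
config = {
    'language_model': 'lm.binary',  # resolved against resource_path if not found as-is
    'beam_width': 32,
    'alpha': 0.7,
    'beta': 0.0,
    'cutoff_prob': 1.0,
    'cutoff_top_n': 40,
    'top_k': 3,
    'word_threshold': -1000.0,
}
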
def test_decoders(self):
    '''
    Test all CTC decoders on a sample transcript ('ten seconds').
    Standard TF decoders should output 'then seconds'.
    Custom CTC decoder with LM rescoring should yield 'ten seconds'.
    '''
    logits = tf.constant(self.seq)
    seq_len = tf.constant([self.seq.shape[0]])
    greedy_decoded = tf.nn.ctc_greedy_decoder(logits, seq_len, merge_repeated=True)
    beam_search_decoded = tf.nn.ctc_beam_search_decoder(
        logits, seq_len, beam_width=self.beam_width, top_paths=1, merge_repeated=False)

    with tf.Session() as sess:
        res_greedy, res_beam = sess.run([greedy_decoded, beam_search_decoded])

    decoded_greedy, prob_greedy = res_greedy
    decoded_text = ''.join([self.vocab[c] for c in decoded_greedy[0].values])
    self.assertTrue(abs(7079.117 + prob_greedy[0][0]) < self.tol)
    self.assertTrue(decoded_text == 'then seconds')

    decoded_beam, prob_beam = res_beam
    decoded_text = ''.join([self.vocab[c] for c in decoded_beam[0].values])
    if tf.__version__ >= '1.11':
        # works for newer versions only (with CTC decoder fix)
        self.assertTrue(abs(1.1842 + prob_beam[0][0]) < self.tol)
        self.assertTrue(decoded_text == 'then seconds')

    scorer = Scorer(alpha=2.0, beta=0.5,
                    model_path='ctc_decoder_with_lm/ctc-test-lm.binary',
                    vocabulary=self.vocab[:-1])
    res = ctc_beam_search_decoder(softmax(self.seq.squeeze()), self.vocab[:-1],
                                  beam_size=self.beam_width, ext_scoring_func=scorer)
    res_prob, decoded_text = res[0]
    self.assertTrue(abs(4.0845 + res_prob) < self.tol)
    self.assertTrue(decoded_text == self.label)

def __init__(self, vocab, beam_width, alpha, beta, lm_path, num_cpus, cutoff_prob=1.0, cutoff_top_n=40):
    if lm_path is not None:
        self.scorer = Scorer(alpha, beta, model_path=lm_path, vocabulary=vocab)
    else:
        self.scorer = None
    self.vocab = vocab
    self.beam_width = beam_width
    self.num_cpus = num_cpus
    self.cutoff_prob = cutoff_prob
    self.cutoff_top_n = cutoff_top_n

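# A sketch (not the original class's method) of the decode step such an __init__
# typically pairs with, assuming `probs_batch` is a list of [T, vocab] probability
# matrices:
def decode(self, probs_batch):
    from ctc_decoders import ctc_beam_search_decoder_batch
    beams = ctc_beam_search_decoder_batch(
        probs_split=probs_batch,
        vocabulary=self.vocab,
        beam_size=self.beam_width,
        num_processes=self.num_cpus,
        ext_scoring_func=self.scorer,   # None disables LM rescoring
        cutoff_prob=self.cutoff_prob,
        cutoff_top_n=self.cutoff_top_n,
    )
    # Keep the best transcript of each utterance's beam list.
    return [b[0][1] for b in beams]
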
def __init__(self, index_to_token, num_classes, beam_width=1024, lm_path=None,
             alpha=None, beta=None, vocab_array=None):
    super().__init__(index_to_token, num_classes, vocab_array)
    self.beam_width = beam_width
    self.lm_path = lm_path
    self.alpha = alpha
    self.beta = beta
    if self.lm_path:
        assert self.alpha and self.beta and self.vocab_array, \
            "alpha, beta and vocab_array must be specified"
        self.scorer = Scorer(self.alpha, self.beta,
                             model_path=self.lm_path,
                             vocabulary=self.vocab_array)
    else:
        self.scorer = None

probs_batch = []
for line in labels:
    audio_filename = line[0]
    probs_batch.append(logits[audio_filename])
batch_prob_end = time.time()
print("Batch logit loading took %s seconds" % (batch_prob_end - data_load_end))

if args.mode == 'eval':
    eval_start = time.time()
    wer, _ = evaluate_wer(logits, labels, vocab, greedy_decoder)
    print('Greedy WER = {:.4f}'.format(wer))

    best_result = {'wer': 1e6, 'alpha': 0.0, 'beta': 0.0, 'beams': None}
    for alpha in np.arange(args.alpha, args.alpha_max, args.alpha_step):
        for beta in np.arange(args.beta, args.beta_max, args.beta_step):
            scorer = Scorer(alpha, beta, model_path=args.lm, vocabulary=vocab[:-1])
            print("scorer complete")
            # Decode in chunks of 500 utterances. The chunk variable must not
            # shadow probs_batch, or every grid point after the first would
            # only ever see the last chunk.
            probs_batch_list = list(divide_chunks(probs_batch, 500))
            res = []
            for probs_chunk in probs_batch_list:
                f = time.time()
                result = ctc_beam_search_decoder_batch(
                    probs_chunk, vocab[:-1],
                    beam_size=args.beam_width,
                    num_processes=num_cpus,
                    ext_scoring_func=scorer)
                e = time.time()
                for j in result:
                    res.append(j)

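# divide_chunks is called above but not defined in this excerpt; a plausible
# definition consistent with its usage (a generator of successive n-sized slices):
def divide_chunks(items, n):
    for i in range(0, len(items), n):
        yield items[i:i + n]
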
from tiramisu_asr.utils.utils import bytes_to_string, merge_two_last_dims

decoder_config = {
    "vocabulary": "/mnt/Projects/asrk16/TiramisuASR/examples/deepspeech2/vocabularies/vietnamese.txt",
    "beam_width": 100,
    "blank_at_zero": True,
    "lm_config": {
        "model_path": "/mnt/Data/ML/NLP/vntc_asrtrain_5gram_trie.binary",
        "alpha": 2.0,
        "beta": 2.0
    }
}
text_featurizer = TextFeaturizer(decoder_config)
text_featurizer.add_scorer(
    Scorer(**decoder_config["lm_config"], vocabulary=text_featurizer.vocab_array))
speech_featurizer = TFSpeechFeaturizer({
    "sample_rate": 16000,
    "frame_ms": 25,
    "stride_ms": 10,
    "num_feature_bins": 80,
    "feature_type": "spectrogram",
    "preemphasis": 0.97,
    # "delta": True,
    # "delta_delta": True,
    "normalize_signal": True,
    "normalize_feature": True,
    "normalize_per_feature": False,
    # "pitch": False,
})

def main():
    parser = argparse.ArgumentParser(prog="SelfAttentionDS2 Histogram")
    parser.add_argument("--config", type=str, default=None, help="Config file")
    parser.add_argument("--audio", type=str, default=None, help="Audio file")
    parser.add_argument("--saved_model", type=str, default=None, help="Saved model")
    parser.add_argument("--from_weights", type=bool, default=False, help="Load from weights")
    parser.add_argument("--output", type=str, default=None, help="Output dir storing histograms")
    args = parser.parse_args()

    config = UserConfig(args.config, args.config, learning=False)
    speech_featurizer = SpeechFeaturizer(config["speech_config"])
    text_featurizer = CharFeaturizer(config["decoder_config"])
    text_featurizer.add_scorer(
        Scorer(**text_featurizer.decoder_config["lm_config"],
               vocabulary=text_featurizer.vocab_array))

    f, c = speech_featurizer.compute_feature_dim()
    satt_ds2_model = SelfAttentionDS2(input_shape=[None, f, c],
                                      arch_config=config["model_config"],
                                      num_classes=text_featurizer.num_classes)
    satt_ds2_model._build([1, 50, f, c])

    if args.from_weights:
        satt_ds2_model.load_weights(args.saved_model)
    else:
        saved_model = tf.keras.models.load_model(args.saved_model)
        satt_ds2_model.set_weights(saved_model.get_weights())

    satt_ds2_model.summary(line_length=100)
    satt_ds2_model.add_featurizers(speech_featurizer, text_featurizer)

    signal = read_raw_audio(args.audio, speech_featurizer.sample_rate)
    features = speech_featurizer.extract(signal)
    decoded = satt_ds2_model.recognize_beam(tf.expand_dims(features, 0), lm=True)
    print(bytes_to_string(decoded.numpy()))

    for i in range(1, len(satt_ds2_model.base_model.layers)):
        func = tf.keras.backend.function([satt_ds2_model.base_model.input],
                                         [satt_ds2_model.base_model.layers[i].output])
        data = func([np.expand_dims(features, 0), 1])[0][0]
        print(data.shape)
        data = data.flatten()
        plt.hist(data, 200, color='green', histtype="stepfilled")
        plt.title(f"Output of {satt_ds2_model.base_model.layers[i].name}", fontweight="bold")
        plt.savefig(os.path.join(
            args.output, f"{i}_{satt_ds2_model.base_model.layers[i].name}.png"))
        plt.clf()
        plt.cla()
        plt.close()

    fc = satt_ds2_model(tf.expand_dims(features, 0), training=False)
    plt.hist(fc[0].numpy().flatten(), 200, color="green", histtype="stepfilled")
    plt.title(f"Output of {satt_ds2_model.layers[-1].name}", fontweight="bold")
    plt.savefig(os.path.join(args.output, f"{satt_ds2_model.layers[-1].name}.png"))
    plt.clf()
    plt.cla()
    plt.close()

    fc = tf.nn.softmax(fc)
    plt.hist(fc[0].numpy().flatten(), 10, color="green", histtype="stepfilled")
    plt.title("Output of softmax", fontweight="bold")
    plt.savefig(os.path.join(args.output, "softmax_hist.png"))
    plt.clf()
    plt.cla()
    plt.close()

    plt.hist(features.flatten(), 200, color="green", histtype="stepfilled")
    plt.title("Log Mel Spectrogram", fontweight="bold")
    plt.savefig(os.path.join(args.output, "log_mel_spectrogram.png"))
    plt.clf()
    plt.cla()
    plt.close()

def run(ref_corpus_path, logits_path, config_path, output_path, lm_path, num_workers,
        beam_width, alpha_start, alpha_end, alpha_step, beta_start, beta_end, beta_step):
    print('Load refs')
    refs = []
    lengths = []
    with open(ref_corpus_path, 'r') as f:
        for x in json.load(f):
            refs.append((x['utt_idx'], x['transcript']))
            lengths.append(x['files'][0]['num_samples'])

    logits_raw = get_logits(logits_path)
    print('N Logits: {}'.format(len(logits_raw)))
    print('Shape Logits 0: {}'.format(logits_raw[0].shape))

    # Trim each logit matrix to the number of frames implied by the audio
    # length (hop size of 320 samples).
    logits = []
    for i, l in enumerate(logits_raw):
        num_samples = lengths[i]
        num_frames = int(num_samples / 320) + 1
        logits.append(l[:num_frames])

    print('Shape Logits 0 (after trim): {}'.format(logits[0].shape))

    logits = [softmax(l) for l in logits]

    vocab = get_vocab(config_path)
    print('N Vocab: {}'.format(len(vocab)))

    refs_dict = {x[0]: x[1] for x in refs}
    print(len(refs))

    for alpha in np.arange(alpha_start, alpha_end, alpha_step):
        for beta in np.arange(beta_start, beta_end, beta_step):
            print('alpha: {}, beta: {}'.format(alpha, beta))
            target_folder = os.path.join(output_path, 'lm_{}_{}'.format(alpha, beta))
            os.makedirs(target_folder, exist_ok=True)

            # decoder = ctcdecode.BestPathDecoder(vocab)
            # scorer = ctcdecode.WordKenLMScorer(lm_path, alpha, beta)
            # decoder = ctcdecode.BeamSearchDecoder(
            #     vocab,
            #     num_workers=num_workers,
            #     beam_width=beam_width,
            #     scorers=[scorer],
            #     cutoff_prob=np.log(0.000001),
            #     cutoff_top_n=40
            # )
            # result = decoder.decode_batch(logits)

            start = time.time()
            scorer = Scorer(alpha, beta, model_path=lm_path, vocabulary=vocab[:-1])
            print('Scorer loaded')
            res = ctc_beam_search_decoder_batch(logits, vocab[:-1],
                                                beam_size=beam_width,
                                                num_processes=num_workers,
                                                ext_scoring_func=scorer)
            # Keep the top transcript of each utterance's beam list.
            result = [[v for v in zip(*x)][1][0] for x in res]
            print('Took {}'.format(time.time() - start))
            print(len(result))

            predictions = {}
            for i, pred in enumerate(result):
                predictions[refs[i][0]] = pred

            pred_path = os.path.join(target_folder, 'predictions.txt')
            with open(pred_path, 'w') as f:
                outs = ['{} {}'.format(k, v) for k, v in predictions.items()]
                f.write('\n'.join(outs))

            report = evaluate(refs_dict, predictions)
            report_path = os.path.join(target_folder, 'result.txt')
            report.write_report(report_path, template='asr_detail')

def __init__(self, model_params=MODEL_PARAMS, scope_name='S2T', sr=16000,
             frame_len=0.2, frame_overlap=2.4, timestep_duration=0.02,
             ext_model_infer_func=None, merge=True, beam_width=1,
             language_model=None, alpha=2.8, beta=1.0):
    '''
    Args:
      model_params: list of OpenSeq2Seq arguments (same as for run.py)
      scope_name: model's scope name
      sr: sample rate, Hz
      frame_len: frame's duration, seconds
      frame_overlap: duration of the overlaps before and after the current
        frame, seconds; frame_overlap should be a multiple of frame_len
      timestep_duration: time per step at the model's output, seconds
      ext_model_infer_func: callback for an external inference engine; if it
        is not None, then we don't build a TF inference graph
      merge: whether to merge repeats in the greedy decoder
      beam_width: beam width for the beam search decoder if larger than 1
      language_model: path to the LM (to use with the beam search decoder)
      alpha: LM weight (trade-off between acoustic and LM scores)
      beta: word weight (added for every transcribed word in the prediction)
    '''
    if ext_model_infer_func is None:
        # Build TF inference graph
        self.model_S2T, checkpoint_S2T = self._get_model(model_params, scope_name)

        # Create the session and load the checkpoints
        sess_config = tf.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.allow_growth = True
        self.sess = tf.InteractiveSession(config=sess_config)
        vars_S2T = {}
        for v in tf.get_collection(tf.GraphKeys.VARIABLES):
            if scope_name in v.name:
                vars_S2T['/'.join(v.op.name.split('/')[1:])] = v
        saver_S2T = tf.train.Saver(vars_S2T)
        saver_S2T.restore(self.sess, checkpoint_S2T)
        self.params = self.model_S2T.params
    else:
        # No TF: load pre-/post-processing parameters from the config and
        # use the external inference engine
        _, base_config, _, _ = get_base_config(model_params)
        self.params = base_config
    self.ext_model_infer_func = ext_model_infer_func

    # Read the vocab via self.params, which is set in both branches above;
    # self.model_S2T does not exist when an external inference engine is used.
    self.vocab = self._load_vocab(
        self.params['data_layer_params']['vocab_file']
    )

    self.sr = sr
    self.frame_len = frame_len
    self.n_frame_len = int(frame_len * sr)
    self.frame_overlap = frame_overlap
    self.n_frame_overlap = int(frame_overlap * sr)
    if self.n_frame_overlap % self.n_frame_len:
        raise ValueError(
            "'frame_overlap' should be a multiple of 'frame_len'"
        )
    self.n_timesteps_overlap = int(frame_overlap / timestep_duration) - 2
    self.buffer = np.zeros(shape=2 * self.n_frame_overlap + self.n_frame_len,
                           dtype=np.float32)
    self.merge = merge
    self._beam_decoder = None
    # greedy decoder's state (unmerged transcription)
    self.text = ''
    # forerunner greedy decoder's state (unmerged transcription)
    self.forerunner_text = ''
    self.offset = 5
    # self._calibrate_offset()
    if beam_width > 1:
        if language_model is None:
            self._beam_decoder = BeamDecoder(self.vocab, beam_width)
        else:
            self._scorer = Scorer(alpha, beta, language_model, self.vocab)
            self._beam_decoder = BeamDecoder(self.vocab, beam_width,
                                             ext_scorer=self._scorer)
    self.reset()

def main():
    parser = argparse.ArgumentParser(prog="SelfAttentionDS2 Histogram")
    parser.add_argument("--config", type=str, default=None, help="Config file")
    parser.add_argument("--audio", type=str, default=None, help="Audio file")
    parser.add_argument("--saved_model", type=str, default=None, help="Saved model")
    parser.add_argument("--from_weights", type=bool, default=False, help="Load from weights")
    parser.add_argument("--output", type=str, default=None, help="Output dir storing histograms")
    args = parser.parse_args()

    config = UserConfig(args.config, args.config, learning=False)
    speech_featurizer = SpeechFeaturizer(config["speech_config"])
    text_featurizer = CharFeaturizer(config["decoder_config"])
    text_featurizer.add_scorer(
        Scorer(**text_featurizer.decoder_config["lm_config"],
               vocabulary=text_featurizer.vocab_array))

    f, c = speech_featurizer.compute_feature_dim()
    satt_ds2_model = SelfAttentionDS2(input_shape=[None, f, c],
                                      arch_config=config["model_config"],
                                      num_classes=text_featurizer.num_classes)
    satt_ds2_model._build([1, 50, f, c])

    if args.from_weights:
        satt_ds2_model.load_weights(args.saved_model)
    else:
        saved_model = tf.keras.models.load_model(args.saved_model)
        satt_ds2_model.set_weights(saved_model.get_weights())

    satt_ds2_model.summary(line_length=100)
    satt_ds2_model.add_featurizers(speech_featurizer, text_featurizer)

    signal = read_raw_audio(args.audio, speech_featurizer.sample_rate)
    features = speech_featurizer.extract(signal)
    decoded = satt_ds2_model.recognize_beam(tf.expand_dims(features, 0), lm=True)
    print(bytes_to_string(decoded.numpy()))

    # for i in range(1, len(satt_ds2_model.base_model.layers)):
    #     func = tf.keras.backend.function([satt_ds2_model.base_model.input],
    #                                      [satt_ds2_model.base_model.layers[i].output])
    #     data = func([np.expand_dims(features, 0), 1])[0][0]
    #     print(data.shape)
    #     plt.figure(figsize=(16, 5))
    #     ax = plt.gca()
    #     im = ax.imshow(data.T, origin="lower", aspect="auto")
    #     ax.set_title(f"{satt_ds2_model.base_model.layers[i].name}", fontweight="bold")
    #     divider = make_axes_locatable(ax)
    #     cax = divider.append_axes("right", size="5%", pad=0.05)
    #     plt.colorbar(im, cax=cax)
    #     plt.savefig(os.path.join(
    #         args.output, f"{i}_{satt_ds2_model.base_model.layers[i].name}.png"))
    #     plt.clf()
    #     plt.cla()
    #     plt.close()

    fc = satt_ds2_model(tf.expand_dims(features, 0), training=False)
    plt.figure(figsize=(16, 5))
    ax = plt.gca()
    ax.set_title(f"{satt_ds2_model.layers[-1].name}", fontweight="bold")
    im = ax.imshow(fc[0].numpy().T, origin="lower", aspect="auto")
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)
    plt.savefig(
        os.path.join(args.output, f"{satt_ds2_model.layers[-1].name}.png"))
    plt.clf()
    plt.cla()
    plt.close()

    fc = tf.nn.softmax(fc)
    plt.figure(figsize=(16, 5))
    ax = plt.gca()
    ax.set_title("Softmax", fontweight="bold")
    im = ax.imshow(fc[0].numpy().T, origin="lower", aspect="auto")
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)
    plt.savefig(os.path.join(args.output, "softmax.png"))
    plt.clf()
    plt.cla()
    plt.close()

    plt.figure(figsize=(16, 5))
    ax = plt.gca()
    ax.set_title("Log Mel Spectrogram", fontweight="bold")
    im = ax.imshow(features[:, :, 0].T, origin="lower", aspect="auto")
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)
    plt.savefig(os.path.join(args.output, "features.png"))
    plt.clf()
    plt.cla()
    plt.close()

def run(args):
    assert args.mode in modes, f"Mode must be in {modes}"

    config = UserConfig(DEFAULT_YAML, args.config, learning=True)
    speech_featurizer = SpeechFeaturizer(config["speech_config"])
    text_featurizer = TextFeaturizer(config["decoder_config"])

    if args.mode == "train":
        tf.random.set_seed(2020)

        if args.mixed_precision:
            policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
            tf.keras.mixed_precision.experimental.set_policy(policy)
            print("Enabled mixed precision training")

        ctc_trainer = CTCTrainer(speech_featurizer, text_featurizer,
                                 config["learning_config"]["running_config"],
                                 args.mixed_precision)

        if args.tfrecords:
            train_dataset = ASRTFRecordDataset(
                config["learning_config"]["dataset_config"]["train_paths"],
                config["learning_config"]["dataset_config"]["tfrecords_dir"],
                speech_featurizer, text_featurizer, "train",
                augmentations=config["learning_config"]["augmentations"],
                shuffle=True,
            )
            eval_dataset = ASRTFRecordDataset(
                config["learning_config"]["dataset_config"]["eval_paths"],
                config["learning_config"]["dataset_config"]["tfrecords_dir"],
                speech_featurizer, text_featurizer, "eval",
                shuffle=False
            )
        else:
            train_dataset = ASRSliceDataset(
                stage="train",
                speech_featurizer=speech_featurizer,
                text_featurizer=text_featurizer,
                data_paths=config["learning_config"]["dataset_config"]["train_paths"],
                augmentations=config["learning_config"]["augmentations"],
                shuffle=True,
            )
            eval_dataset = ASRSliceDataset(
                stage="eval",
                speech_featurizer=speech_featurizer,
                text_featurizer=text_featurizer,
                data_paths=config["learning_config"]["dataset_config"]["eval_paths"],
                shuffle=False
            )

        # Build DS2 model
        f, c = speech_featurizer.compute_feature_dim()
        with ctc_trainer.strategy.scope():
            satt_ds2_model = SelfAttentionDS2(input_shape=[None, f, c],
                                              arch_config=config["model_config"],
                                              num_classes=text_featurizer.num_classes)
            satt_ds2_model._build([1, 50, f, c])
            optimizer = create_optimizer(
                name=config["learning_config"]["optimizer_config"]["name"],
                d_model=config["model_config"]["att"]["head_size"],
                **config["learning_config"]["optimizer_config"]["config"]
            )
        # Compile
        ctc_trainer.compile(satt_ds2_model, optimizer, max_to_keep=args.max_ckpts)
        ctc_trainer.fit(train_dataset, eval_dataset, args.eval_train_ratio)

        if args.export:
            if args.from_weights:
                ctc_trainer.model.save_weights(args.export)
            else:
                ctc_trainer.model.save(args.export)

    elif args.mode == "test":
        tf.random.set_seed(0)
        assert args.export

        text_featurizer.add_scorer(
            Scorer(**text_featurizer.decoder_config["lm_config"],
                   vocabulary=text_featurizer.vocab_array))

        # Build DS2 model
        f, c = speech_featurizer.compute_feature_dim()
        satt_ds2_model = SelfAttentionDS2(input_shape=[None, f, c],
                                          arch_config=config["model_config"],
                                          num_classes=text_featurizer.num_classes)
        satt_ds2_model._build([1, 50, f, c])
        satt_ds2_model.summary(line_length=100)

        optimizer = create_optimizer(
            name=config["learning_config"]["optimizer_config"]["name"],
            d_model=config["model_config"]["att"]["head_size"],
            **config["learning_config"]["optimizer_config"]["config"]
        )

        batch_size = config["learning_config"]["running_config"]["batch_size"]
        if args.tfrecords:
            test_dataset = ASRTFRecordDataset(
                config["learning_config"]["dataset_config"]["test_paths"],
                config["learning_config"]["dataset_config"]["tfrecords_dir"],
                speech_featurizer, text_featurizer, "test",
                augmentations=config["learning_config"]["augmentations"],
                shuffle=False
            ).create(batch_size * args.eval_train_ratio)
        else:
            test_dataset = ASRSliceDataset(
                stage="test",
                speech_featurizer=speech_featurizer,
                text_featurizer=text_featurizer,
                data_paths=config["learning_config"]["dataset_config"]["test_paths"],
                augmentations=config["learning_config"]["augmentations"],
                shuffle=False
            ).create(batch_size * args.eval_train_ratio)

        ctc_tester = BaseTester(
            config=config["learning_config"]["running_config"],
            saved_path=args.export,
            from_weights=args.from_weights
        )
        ctc_tester.compile(satt_ds2_model, speech_featurizer, text_featurizer)
        ctc_tester.run(test_dataset)

    else:
        assert args.export

        # Build DS2 model
        f, c = speech_featurizer.compute_feature_dim()
        satt_ds2_model = SelfAttentionDS2(input_shape=[None, f, c],
                                          arch_config=config["model_config"],
                                          num_classes=text_featurizer.num_classes)
        satt_ds2_model._build([1, 50, f, c])

        optimizer = create_optimizer(
            name=config["learning_config"]["optimizer_config"]["name"],
            d_model=config["model_config"]["att"]["head_size"],
            **config["learning_config"]["optimizer_config"]["config"]
        )

        def save_func(**kwargs):
            if args.from_weights:
                kwargs["model"].save_weights(args.export)
            else:
                kwargs["model"].save(args.export)

        save_from_checkpoint(func=save_func,
                             outdir=config["learning_config"]["running_config"]["outdir"],
                             model=satt_ds2_model,
                             optimizer=optimizer)