def __init__(
        self, vocab, beam_width, alpha, beta, lm_path, num_cpus, cutoff_prob=1.0, cutoff_top_n=40, input_tensor=True
    ):

        try:
            from ctc_decoders import Scorer
            from ctc_decoders import ctc_beam_search_decoder_batch
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "BeamSearchDecoderWithLM requires the "
                "installation of ctc_decoders "
                "from nemo/scripts/install_decoders.py"
            )

        super().__init__()
        # Override the default placement from neural factory and set placement/device to be CPU.
        self._placement = DeviceType.CPU
        self._device = get_cuda_device(self._placement)

        if self._factory.world_size > 1:
            raise ValueError("BeamSearchDecoderWithLM does not run in distributed mode")

        self.scorer = Scorer(alpha, beta, model_path=lm_path, vocabulary=vocab)
        self.beam_search_func = ctc_beam_search_decoder_batch
        self.vocab = vocab
        self.beam_width = beam_width
        self.num_cpus = num_cpus
        self.cutoff_prob = cutoff_prob
        self.cutoff_top_n = cutoff_top_n
        self.input_tensor = input_tensor
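The constructor above only stores the decoder's inputs; the actual decoding happens later by invoking self.beam_search_func. A minimal sketch of that call, assuming probs is a list of per-utterance softmax matrices (forward() here is hypothetical; the keyword arguments match the ctc_decoders signature used throughout these examples):

def forward(self, probs):
    # probs: list of numpy arrays of shape (time, vocab_size + 1), already softmaxed
    results = self.beam_search_func(
        probs_split=probs,
        vocabulary=self.vocab,
        beam_size=self.beam_width,
        num_processes=self.num_cpus,
        ext_scoring_func=self.scorer,
        cutoff_prob=self.cutoff_prob,
        cutoff_top_n=self.cutoff_top_n,
    )
    # each entry is a list of (score, transcript) beams, best first
    return [beams[0][1] for beams in results]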
Example 2
    def __init__(self,
                 vocab,
                 beam_width,
                 alpha,
                 beta,
                 lm_path,
                 num_cpus,
                 cutoff_prob=1.0,
                 cutoff_top_n=40,
                 input_tensor=False):

        try:
            from ctc_decoders import Scorer, ctc_beam_search_decoder_batch
        except ModuleNotFoundError:
            raise ModuleNotFoundError("BeamSearchDecoderWithLM requires the "
                                      "installation of ctc_decoders "
                                      "from scripts/install_ctc_decoders.sh")

        super().__init__()

        if lm_path is not None:
            self.scorer = Scorer(alpha,
                                 beta,
                                 model_path=lm_path,
                                 vocabulary=vocab)
        else:
            self.scorer = None
        self.beam_search_func = ctc_beam_search_decoder_batch
        self.vocab = vocab
        self.beam_width = beam_width
        self.num_cpus = num_cpus
        self.cutoff_prob = cutoff_prob
        self.cutoff_top_n = cutoff_top_n
        self.input_tensor = input_tensor
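Unlike Example 1, this variant accepts lm_path=None and defaults input_tensor=False, so the decode step has to handle both a missing scorer and tensor inputs. A hedged sketch of such a call (decode() and the tensor handling are assumptions, not part of the original module; ext_scoring_func=None makes ctc_beam_search_decoder_batch run plain, LM-free beam search):

def decode(self, probs):
    # with input_tensor=True the caller passes a torch tensor; move it to numpy
    if self.input_tensor:
        probs = probs.detach().cpu().numpy()
    return self.beam_search_func(
        probs_split=list(probs),
        vocabulary=self.vocab,
        beam_size=self.beam_width,
        num_processes=self.num_cpus,
        ext_scoring_func=self.scorer)  # may be None: no LM rescoring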
Example 3
import os

import numpy as np


def softmax(logits):
    # assumption: the snippet's undefined softmax is a row-wise softmax over the vocab axis
    e = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def _decode_with_lm(filepath_to_logprobs, kenlm_model_path, vocab, beam_width,
                    beam_search_alpha, beam_search_beta):
    filepaths = list(filepath_to_logprobs.keys())
    asr_logprobs = [filepath_to_logprobs[p] for p in filepaths]
    from ctc_decoders import Scorer, ctc_beam_search_decoder_batch
    scorer = Scorer(
        beam_search_alpha,
        beam_search_beta,
        model_path=kenlm_model_path,
        vocabulary=vocab
    )
    asr_probs = [softmax(logits) for logits in asr_logprobs]
    transcriptions = ctc_beam_search_decoder_batch(
        probs_split=asr_probs,
        vocabulary=vocab,
        beam_size=beam_width,
        ext_scoring_func=scorer,
        num_processes=max(os.cpu_count() or 1, 1)  # os.cpu_count() may return None
    )
    # each element is a list of (score, transcript) beams, best first
    transcriptions = [t[0][1] for t in transcriptions]
    return dict(zip(filepaths, transcriptions))
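A hedged usage sketch of the helper above (file names, shapes, and the LM path are illustrative only; the vocabulary has 28 symbols, so the logits carry one extra column for the CTC blank):

import numpy as np

filepath_to_logprobs = {
    "utt1.wav": np.random.randn(120, 29).astype(np.float32),
    "utt2.wav": np.random.randn(95, 29).astype(np.float32),
}
vocab = list("abcdefghijklmnopqrstuvwxyz' ")
filepath_to_transc = _decode_with_lm(
    filepath_to_logprobs, "lm.binary", vocab,
    beam_width=128, beam_search_alpha=2.0, beam_search_beta=1.0)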
Example 4
    def __init__(self,
                 *,
                 vocab,
                 beam_width,
                 alpha,
                 beta,
                 lm_path,
                 num_cpus,
                 cutoff_prob=1.0,
                 cutoff_top_n=40,
                 **kwargs):

        try:
            from ctc_decoders import Scorer
            from ctc_decoders import ctc_beam_search_decoder_batch
        except ModuleNotFoundError:
            raise ModuleNotFoundError("BeamSearchDecoderWithLM requires the "
                                      "installation of ctc_decoders "
                                      "from nemo/scripts/install_decoders.py")

        self.scorer = Scorer(alpha, beta, model_path=lm_path, vocabulary=vocab)
        self.beam_search_func = ctc_beam_search_decoder_batch

        super().__init__(
            # Override default placement from neural factory
            placement=DeviceType.CPU,
            **kwargs)

        self.vocab = vocab
        self.beam_width = beam_width
        self.num_cpus = num_cpus
        self.cutoff_prob = cutoff_prob
        self.cutoff_top_n = cutoff_top_n
Example 5
def language_model(model: str = 'malaya-speech',
                   alpha: float = 2.5,
                   beta: float = 0.3,
                   **kwargs):
    """
    Load KenLM language model.

    Parameters
    ----------
    model : str, optional (default='malaya-speech')
        Model architecture supported. Allowed values:

        * ``'malaya-speech'`` - Gathered from malaya-speech ASR transcript.
        * ``'malaya-speech-wikipedia'`` - Gathered from malaya-speech ASR transcript + Wikipedia (Random sample 300k sentences).
        * ``'local'`` - Gathered from IIUM Confession.
        * ``'wikipedia'`` - Gathered from Malay Wikipedia.
        
    alpha: float, optional (default=2.5)
        score = alpha * np.log(lm) + beta * np.log(word_cnt);
        increasing alpha puts more weight on the LM score computed by KenLM.
    beta: float, optional (default=0.3)
        score = alpha * np.log(lm) + beta * np.log(word_cnt);
        increasing beta puts more weight on the word count.

    Returns
    -------
    result : ctc_decoders.Scorer
        KenLM scorer built over the selected model's vocabulary.
    """
    try:
        from ctc_decoders import Scorer
    except ModuleNotFoundError:
        raise ModuleNotFoundError(
            'ctc_decoders not installed. Please install it by `pip install ctc-decoders` and try again.'
        )
    from malaya_speech.utils import check_file

    check_file(PATH_LM[model], S3_PATH_LM[model], **kwargs)

    with open(PATH_LM[model]['vocab']) as fopen:
        vocab_list = json.load(fopen) + ['{', '}', '[']

    scorer = Scorer(alpha, beta, PATH_LM[model]['model'], vocab_list)
    return scorer
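A hedged usage sketch (vocab_list stands for the same vocabulary used to build the scorer, and probs is an illustrative (time, len(vocab_list) + 1) softmax matrix from an acoustic model; the ctc_beam_search_decoder call follows the signature seen in Example 7):

from ctc_decoders import ctc_beam_search_decoder

scorer = language_model('malaya-speech', alpha=2.5, beta=0.3)
beams = ctc_beam_search_decoder(probs, vocab_list, beam_size=100,
                                ext_scoring_func=scorer)
print(beams[0][1])  # best transcript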
Example 6
    def __init__(self, config, vocab, resource_path=None):
        """
        Create a new CTCBeamSearchDecoder.
        
        See CTCDecoder.from_config() to automatically create
        the correct type of instance depending on config.
        """
        super().__init__(config, vocab)
        self.config.setdefault('word_threshold', -1000.0)
        self.reset()

        self.scorer = None
        #self.num_cores = max(os.cpu_count(), 1)

        # set default config
        # https://github.com/NVIDIA/NeMo/blob/855ce265b80c0dc40f4f06ece76d2c9d6ca1be8d/nemo/collections/asr/modules/beam_search_decoder.py#L21
        self.config.setdefault('language_model', None)
        self.config.setdefault('beam_width', 32)  # alternative: 128
        self.config.setdefault('alpha', 0.7 if self.language_model else 0.0)
        self.config.setdefault('beta', 0.0)
        self.config.setdefault('cutoff_prob', 1.0)
        self.config.setdefault('cutoff_top_n', 40)
        self.config.setdefault('top_k', 3)

        # check for language model file
        if self.language_model:
            if not os.path.isfile(self.language_model):
                self.config['language_model'] = os.path.join(
                    resource_path, self.language_model)
                if not os.path.isfile(self.language_model):
                    raise IOError(
                        f"language model file '{self.language_model}' does not exist"
                    )

        logging.info('creating CTCBeamSearchDecoder')
        logging.info(str(self.config))

        # create scorer
        if self.language_model:
            self.scorer = Scorer(self.config['alpha'],
                                 self.config['beta'],
                                 model_path=self.language_model,
                                 vocabulary=self.vocab)
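The class above stops at scorer construction; a hypothetical decode() wired to these config values could look as follows (not part of the original class; ctc_beam_search_decoder_batch is used with the same signature as in the other examples):

    def decode(self, probs_batch):
        # probs_batch: list of (time, vocab_size + 1) softmax matrices
        from ctc_decoders import ctc_beam_search_decoder_batch
        results = ctc_beam_search_decoder_batch(
            probs_split=probs_batch,
            vocabulary=self.vocab,
            beam_size=self.config['beam_width'],
            num_processes=max(os.cpu_count() or 1, 1),
            ext_scoring_func=self.scorer,  # None when no language model is configured
            cutoff_prob=self.config['cutoff_prob'],
            cutoff_top_n=self.config['cutoff_top_n'])
        # keep the top_k transcripts per utterance
        return [[text for _, text in beams[:self.config['top_k']]]
                for beams in results]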
Example 7
  def test_decoders(self):
    '''
    Test all CTC decoders on a sample transcript ('ten seconds').
    Standard TF decoders should output 'then seconds'.
    Custom CTC decoder with LM rescoring should yield 'ten seconds'.
    '''
    logits = tf.constant(self.seq)
    seq_len = tf.constant([self.seq.shape[0]])

    greedy_decoded = tf.nn.ctc_greedy_decoder(logits, seq_len, 
        merge_repeated=True)

    beam_search_decoded = tf.nn.ctc_beam_search_decoder(logits, seq_len, 
        beam_width=self.beam_width, 
        top_paths=1, 
        merge_repeated=False)

    with tf.Session() as sess:
      res_greedy, res_beam = sess.run([greedy_decoded, 
          beam_search_decoded])

    decoded_greedy, prob_greedy = res_greedy
    decoded_text = ''.join([self.vocab[c] for c in decoded_greedy[0].values])
    self.assertTrue( abs(7079.117 + prob_greedy[0][0]) < self.tol )
    self.assertTrue( decoded_text == 'then seconds' )

    decoded_beam, prob_beam = res_beam
    decoded_text = ''.join([self.vocab[c] for c in decoded_beam[0].values])
    # string comparison misorders versions like '1.2' vs '1.11'; compare numerically
    if tuple(int(v) for v in tf.__version__.split('.')[:2]) >= (1, 11):
      # works for newer versions only (with CTC decoder fix)
      self.assertTrue( abs(1.1842 + prob_beam[0][0]) < self.tol )
    self.assertTrue( decoded_text == 'then seconds' )

    scorer = Scorer(alpha=2.0, beta=0.5,
        model_path='ctc_decoder_with_lm/ctc-test-lm.binary', 
        vocabulary=self.vocab[:-1])
    res = ctc_beam_search_decoder(softmax(self.seq.squeeze()), self.vocab[:-1],
                                  beam_size=self.beam_width,
                                  ext_scoring_func=scorer)
    res_prob, decoded_text = res[0]
    self.assertTrue( abs(4.0845 + res_prob) < self.tol )
    self.assertTrue( decoded_text == self.label )
Example 8
    def __init__(self,
                 vocab,
                 beam_width,
                 alpha,
                 beta,
                 lm_path,
                 num_cpus,
                 cutoff_prob=1.0,
                 cutoff_top_n=40):

        if lm_path is not None:
            self.scorer = Scorer(alpha,
                                 beta,
                                 model_path=lm_path,
                                 vocabulary=vocab)
        else:
            self.scorer = None
        self.vocab = vocab
        self.beam_width = beam_width
        self.num_cpus = num_cpus
        self.cutoff_prob = cutoff_prob
        self.cutoff_top_n = cutoff_top_n
Example 9
    def __init__(self,
                 index_to_token,
                 num_classes,
                 beam_width=1024,
                 lm_path=None,
                 alpha=None,
                 beta=None,
                 vocab_array=None):
        super().__init__(index_to_token, num_classes, vocab_array)
        self.beam_width = beam_width
        self.lm_path = lm_path
        self.alpha = alpha
        self.beta = beta
        if self.lm_path:
            # explicit None checks so alpha=0.0 or beta=0.0 are not rejected
            assert self.alpha is not None and self.beta is not None \
                and self.vocab_array is not None, \
                "alpha, beta and vocab_array must be specified"
            self.scorer = Scorer(self.alpha,
                                 self.beta,
                                 model_path=self.lm_path,
                                 vocabulary=self.vocab_array)
        else:
            self.scorer = None
Example 10
probs_batch = []
for line in labels:
    audio_filename = line[0]
    probs_batch.append(logits[audio_filename])
batch_prob_end = time.time()
print("Batch logit loading took %s seconds" % (batch_prob_end - data_load_end))

if args.mode == 'eval':
    eval_start = time.time()
    wer, _ = evaluate_wer(logits, labels, vocab, greedy_decoder)
    print('Greedy WER = {:.4f}'.format(wer))
    best_result = {'wer': 1e6, 'alpha': 0.0, 'beta': 0.0, 'beams': None}
    for alpha in np.arange(args.alpha, args.alpha_max, args.alpha_step):
        for beta in np.arange(args.beta, args.beta_max, args.beta_step):
            scorer = Scorer(alpha,
                            beta,
                            model_path=args.lm,
                            vocabulary=vocab[:-1])
            print("scorer complete")
            # chunk without rebinding probs_batch, which is reused on the
            # next (alpha, beta) iteration
            probs_chunks = list(divide_chunks(probs_batch, 500))
            res = []
            for chunk in probs_chunks:
                f = time.time()
                result = ctc_beam_search_decoder_batch(
                    chunk,
                    vocab[:-1],
                    beam_size=args.beam_width,
                    num_processes=num_cpus,
                    ext_scoring_func=scorer)
                e = time.time()
                res.extend(result)
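divide_chunks is used above but never defined in the snippet; a minimal sketch of what it presumably does (split a list into fixed-size chunks):

def divide_chunks(items, n):
    # yield successive n-sized chunks from items
    for i in range(0, len(items), n):
        yield items[i:i + n]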
Example 11
from tiramisu_asr.utils.utils import bytes_to_string, merge_two_last_dims

decoder_config = {
    "vocabulary":
    "/mnt/Projects/asrk16/TiramisuASR/examples/deepspeech2/vocabularies/vietnamese.txt",
    "beam_width": 100,
    "blank_at_zero": True,
    "lm_config": {
        "model_path": "/mnt/Data/ML/NLP/vntc_asrtrain_5gram_trie.binary",
        "alpha": 2.0,
        "beta": 2.0
    }
}
text_featurizer = TextFeaturizer(decoder_config)
text_featurizer.add_scorer(
    Scorer(**decoder_config["lm_config"],
           vocabulary=text_featurizer.vocab_array))
speech_featurizer = TFSpeechFeaturizer({
    "sample_rate": 16000,
    "frame_ms": 25,
    "stride_ms": 10,
    "num_feature_bins": 80,
    "feature_type": "spectrogram",
    "preemphasis": 0.97,
    # "delta": True,
    # "delta_delta": True,
    "normalize_signal": True,
    "normalize_feature": True,
    "normalize_per_feature": False,
    # "pitch": False,
})
def main():
    parser = argparse.ArgumentParser(prog="SelfAttentionDS2 Histogram")

    parser.add_argument("--config", type=str, default=None,
                        help="Config file")

    parser.add_argument("--audio", type=str, default=None,
                        help="Audio file")

    parser.add_argument("--saved_model", type=str, default=None,
                        help="Saved model")

    parser.add_argument("--from_weights", type=bool, default=False,
                        help="Load from weights")

    parser.add_argument("--output", type=str, default=None,
                        help="Output dir storing histograms")

    args = parser.parse_args()

    config = UserConfig(args.config, args.config, learning=False)
    speech_featurizer = SpeechFeaturizer(config["speech_config"])
    text_featurizer = CharFeaturizer(config["decoder_config"])
    text_featurizer.add_scorer(Scorer(**text_featurizer.decoder_config["lm_config"],
                                      vocabulary=text_featurizer.vocab_array))

    f, c = speech_featurizer.compute_feature_dim()
    satt_ds2_model = SelfAttentionDS2(input_shape=[None, f, c],
                                      arch_config=config["model_config"],
                                      num_classes=text_featurizer.num_classes)
    satt_ds2_model._build([1, 50, f, c])

    if args.from_weights:
        satt_ds2_model.load_weights(args.saved_model)
    else:
        saved_model = tf.keras.models.load_model(args.saved_model)
        satt_ds2_model.set_weights(saved_model.get_weights())

    satt_ds2_model.summary(line_length=100)

    satt_ds2_model.add_featurizers(speech_featurizer, text_featurizer)

    signal = read_raw_audio(args.audio, speech_featurizer.sample_rate)
    features = speech_featurizer.extract(signal)
    decoded = satt_ds2_model.recognize_beam(tf.expand_dims(features, 0), lm=True)
    print(bytes_to_string(decoded.numpy()))

    for i in range(1, len(satt_ds2_model.base_model.layers)):
        func = tf.keras.backend.function([satt_ds2_model.base_model.input],
                                         [satt_ds2_model.base_model.layers[i].output])
        data = func([np.expand_dims(features, 0), 1])[0][0]
        print(data.shape)
        data = data.flatten()
        plt.hist(data, 200, color='green', histtype="stepfilled")
        plt.title(f"Output of {satt_ds2_model.base_model.layers[i].name}", fontweight="bold")
        plt.savefig(os.path.join(
            args.output, f"{i}_{satt_ds2_model.base_model.layers[i].name}.png"))
        plt.clf()
        plt.cla()
        plt.close()

    fc = satt_ds2_model(tf.expand_dims(features, 0), training=False)
    plt.hist(fc[0].numpy().flatten(), 200, color="green", histtype="stepfilled")
    plt.title(f"Output of {satt_ds2_model.layers[-1].name}", fontweight="bold")
    plt.savefig(os.path.join(args.output, f"{satt_ds2_model.layers[-1].name}.png"))
    plt.clf()
    plt.cla()
    plt.close()
    fc = tf.nn.softmax(fc)
    plt.hist(fc[0].numpy().flatten(), 10, color="green", histtype="stepfilled")
    plt.title("Output of softmax", fontweight="bold")
    plt.savefig(os.path.join(args.output, "softmax_hist.png"))
    plt.clf()
    plt.cla()
    plt.close()
    plt.hist(features.flatten(), 200, color="green", histtype="stepfilled")
    plt.title("Log Mel Spectrogram", fontweight="bold")
    plt.savefig(os.path.join(args.output, "log_mel_spectrogram.png"))
    plt.clf()
    plt.cla()
    plt.close()
Example 13
def run(ref_corpus_path, logits_path, config_path, output_path, lm_path,
        num_workers, beam_width, alpha_start, alpha_end, alpha_step,
        beta_start, beta_end, beta_step):

    print('Load refs')
    refs = []
    lengths = []
    with open(ref_corpus_path, 'r') as f:
        for x in json.load(f):
            refs.append((x['utt_idx'], x['transcript']))
            lengths.append(x['files'][0]['num_samples'])

    logits_raw = get_logits(logits_path)
    print('N Logits: {}'.format(len(logits_raw)))
    print('Shape Logits 0: {}'.format(logits_raw[0].shape))
    logits = []

    for i, l in enumerate(logits_raw):
        num_samples = lengths[i]
        num_frames = int(num_samples / 320) + 1
        logits.append(l[:num_frames])

    print('Shape Logits 0 (after trim): {}'.format(logits[0].shape))
    logits = [softmax(l) for l in logits]

    vocab = get_vocab(config_path)
    print('N Vocab: {}'.format(len(vocab)))

    refs_dict = {x[0]: x[1] for x in refs}
    print(len(refs))

    for alpha in np.arange(alpha_start, alpha_end, alpha_step):
        for beta in np.arange(beta_start, beta_end, beta_step):
            print('alpha: {}, beta: {}'.format(alpha, beta))
            target_folder = os.path.join(output_path,
                                         'lm_{}_{}'.format(alpha, beta))
            os.makedirs(target_folder, exist_ok=True)

            # decoder = ctcdecode.BestPathDecoder(vocab)
            # scorer = ctcdecode.WordKenLMScorer(lm_path, alpha, beta)
            # decoder = ctcdecode.BeamSearchDecoder(
            #     vocab,
            #     num_workers=num_workers,
            #     beam_width=beam_width,
            #     scorers=[scorer],
            #     cutoff_prob=np.log(0.000001),
            #     cutoff_top_n=40
            # )
            # result = decoder.decode_batch(logits)

            start = time.time()
            scorer = Scorer(alpha,
                            beta,
                            model_path=lm_path,
                            vocabulary=vocab[:-1])
            print('Scorer loaded')
            res = ctc_beam_search_decoder_batch(logits,
                                                vocab[:-1],
                                                beam_size=beam_width,
                                                num_processes=num_workers,
                                                ext_scoring_func=scorer)
            # each x is a list of (score, transcript) beams, best first
            result = [x[0][1] for x in res]
            print('Took {}'.format(time.time() - start))

            print(len(result))
            predictions = {}

            for i, pred in enumerate(result):
                predictions[refs[i][0]] = pred

            pred_path = os.path.join(target_folder, 'predictions.txt')
            with open(pred_path, 'w') as f:
                outs = ['{} {}'.format(k, v) for k, v in predictions.items()]
                f.write('\n'.join(outs))

            report = evaluate(refs_dict, predictions)
            report_path = os.path.join(target_folder, 'result.txt')
            report.write_report(report_path, template='asr_detail')
Example 14
    def __init__(self, model_params=MODEL_PARAMS, scope_name='S2T', 
                 sr=16000, frame_len=0.2, frame_overlap=2.4, 
                 timestep_duration=0.02, 
                 ext_model_infer_func=None, merge=True,
                 beam_width=1, language_model=None, 
                 alpha=2.8, beta=1.0):
        '''
        Args:
          model_params: list of OpenSeq2Seq arguments (same as for run.py)
          scope_name: model's scope name
          sr: sample rate, Hz
          frame_len: frame's duration, seconds
          frame_overlap: duration of overlaps before and after current frame, seconds
            frame_overlap should be multiple of frame_len
          timestep_duration: time per step at model's output, seconds
          ext_model_infer_func: callback for external inference engine,
            if it is not None, then we don't build TF inference graph
          merge: whether to do merge in greedy decoder
          beam_width: beam width for beam search decoder if larger than 1
          language_model: path to LM (to use with beam search decoder)
          alpha: LM weight (trade-off between acoustic and LM scores)
          beta: word weight (added per every transcribed word in prediction)
        '''
        if ext_model_infer_func is None:
            # Build TF inference graph
            self.model_S2T, checkpoint_S2T = self._get_model(model_params, scope_name)

            # Create the session and load the checkpoints
            sess_config = tf.ConfigProto(allow_soft_placement=True)
            sess_config.gpu_options.allow_growth = True
            self.sess = tf.InteractiveSession(config=sess_config)
            vars_S2T = {}
            for v in tf.get_collection(tf.GraphKeys.VARIABLES):
                if scope_name in v.name:
                    vars_S2T['/'.join(v.op.name.split('/')[1:])] = v
            saver_S2T = tf.train.Saver(vars_S2T)
            saver_S2T.restore(self.sess, checkpoint_S2T)
            self.params = self.model_S2T.params
        else:
            # No TF, load pre-, post-processing parameters from config,
            # use external inference engine
            _, base_config, _, _ = get_base_config(model_params)
            self.params = base_config

        self.ext_model_infer_func = ext_model_infer_func

        self.vocab = self._load_vocab(
            # read from self.params, which is set in both branches above;
            # self.model_S2T does not exist when an external inference engine is used
            self.params['data_layer_params']['vocab_file']
        )
        self.sr = sr
        self.frame_len = frame_len
        self.n_frame_len = int(frame_len * sr)
        self.frame_overlap = frame_overlap
        self.n_frame_overlap = int(frame_overlap * sr)
        if self.n_frame_overlap % self.n_frame_len:
            raise ValueError(
                "'frame_overlap' should be multiple of 'frame_len'"
            )
        self.n_timesteps_overlap = int(frame_overlap / timestep_duration) - 2
        self.buffer = np.zeros(shape=2*self.n_frame_overlap + self.n_frame_len, dtype=np.float32)
        self.merge = merge
        self._beam_decoder = None
        # greedy decoder's state (unmerged transcription)
        self.text = ''
        # forerunner greedy decoder's state (unmerged transcription)
        self.forerunner_text = ''

        self.offset = 5
        # self._calibrate_offset()
        if beam_width > 1:
            if language_model is None:
                self._beam_decoder = BeamDecoder(self.vocab, beam_width)
            else:
                self._scorer = Scorer(alpha, beta, language_model, self.vocab)
                self._beam_decoder = BeamDecoder(self.vocab, beam_width,
                                                 ext_scorer=self._scorer)
        self.reset()
Example 15
def main():
    parser = argparse.ArgumentParser(prog="SelfAttentionDS2 Histogram")

    parser.add_argument("--config", type=str, default=None, help="Config file")

    parser.add_argument("--audio", type=str, default=None, help="Audio file")

    parser.add_argument("--saved_model",
                        type=str,
                        default=None,
                        help="Saved model")

    parser.add_argument("--from_weights",
                        type=bool,
                        default=False,
                        help="Load from weights")

    parser.add_argument("--output",
                        type=str,
                        default=None,
                        help="Output dir storing histograms")

    args = parser.parse_args()

    config = UserConfig(args.config, args.config, learning=False)
    speech_featurizer = SpeechFeaturizer(config["speech_config"])
    text_featurizer = CharFeaturizer(config["decoder_config"])
    text_featurizer.add_scorer(
        Scorer(**text_featurizer.decoder_config["lm_config"],
               vocabulary=text_featurizer.vocab_array))

    f, c = speech_featurizer.compute_feature_dim()
    satt_ds2_model = SelfAttentionDS2(input_shape=[None, f, c],
                                      arch_config=config["model_config"],
                                      num_classes=text_featurizer.num_classes)
    satt_ds2_model._build([1, 50, f, c])

    if args.from_weights:
        satt_ds2_model.load_weights(args.saved_model)
    else:
        saved_model = tf.keras.models.load_model(args.saved_model)
        satt_ds2_model.set_weights(saved_model.get_weights())

    satt_ds2_model.summary(line_length=100)

    satt_ds2_model.add_featurizers(speech_featurizer, text_featurizer)

    signal = read_raw_audio(args.audio, speech_featurizer.sample_rate)
    features = speech_featurizer.extract(signal)
    decoded = satt_ds2_model.recognize_beam(tf.expand_dims(features, 0),
                                            lm=True)
    print(bytes_to_string(decoded.numpy()))

    # for i in range(1, len(satt_ds2_model.base_model.layers)):
    #     func = tf.keras.backend.function([satt_ds2_model.base_model.input],
    #                                      [satt_ds2_model.base_model.layers[i].output])
    #     data = func([np.expand_dims(features, 0), 1])[0][0]
    #     print(data.shape)
    #     plt.figure(figsize=(16, 5))
    #     ax = plt.gca()
    #     im = ax.imshow(data.T, origin="lower", aspect="auto")
    #     ax.set_title(f"{satt_ds2_model.base_model.layers[i].name}", fontweight="bold")
    #     divider = make_axes_locatable(ax)
    #     cax = divider.append_axes("right", size="5%", pad=0.05)
    #     plt.colorbar(im, cax=cax)
    #     plt.savefig(os.path.join(
    #         args.output, f"{i}_{satt_ds2_model.base_model.layers[i].name}.png"))
    #     plt.clf()
    #     plt.cla()
    #     plt.close()

    fc = satt_ds2_model(tf.expand_dims(features, 0), training=False)
    plt.figure(figsize=(16, 5))
    ax = plt.gca()
    ax.set_title(f"{satt_ds2_model.layers[-1].name}", fontweight="bold")
    im = ax.imshow(fc[0].numpy().T, origin="lower", aspect="auto")
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)
    plt.savefig(
        os.path.join(args.output, f"{satt_ds2_model.layers[-1].name}.png"))
    plt.clf()
    plt.cla()
    plt.close()
    fc = tf.nn.softmax(fc)
    plt.figure(figsize=(16, 5))
    ax = plt.gca()
    ax.set_title("Softmax", fontweight="bold")
    im = ax.imshow(fc[0].numpy().T, origin="lower", aspect="auto")
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)
    plt.savefig(os.path.join(args.output, "softmax.png"))
    plt.clf()
    plt.cla()
    plt.close()
    plt.figure(figsize=(16, 5))
    ax = plt.gca()
    ax.set_title("Log Mel Spectrogram", fontweight="bold")
    im = ax.imshow(features[:, :, 0].T, origin="lower", aspect="auto")
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)
    plt.savefig(os.path.join(args.output, "features.png"))
    plt.clf()
    plt.cla()
    plt.close()
Example 16
    def run(args):
        assert args.mode in modes, f"Mode must be in {modes}"

        config = UserConfig(DEFAULT_YAML, args.config, learning=True)
        speech_featurizer = SpeechFeaturizer(config["speech_config"])
        text_featurizer = TextFeaturizer(config["decoder_config"])

        if args.mode == "train":
            tf.random.set_seed(2020)

            if args.mixed_precision:
                policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
                tf.keras.mixed_precision.experimental.set_policy(policy)
                print("Enabled mixed precision training")

            ctc_trainer = CTCTrainer(speech_featurizer, text_featurizer,
                                     config["learning_config"]["running_config"],
                                     args.mixed_precision)

            if args.tfrecords:
                train_dataset = ASRTFRecordDataset(
                    config["learning_config"]["dataset_config"]["train_paths"],
                    config["learning_config"]["dataset_config"]["tfrecords_dir"],
                    speech_featurizer, text_featurizer, "train",
                    augmentations=config["learning_config"]["augmentations"], shuffle=True,
                )
                eval_dataset = ASRTFRecordDataset(
                    config["learning_config"]["dataset_config"]["eval_paths"],
                    config["learning_config"]["dataset_config"]["tfrecords_dir"],
                    speech_featurizer, text_featurizer, "eval", shuffle=False
                )
            else:
                train_dataset = ASRSliceDataset(
                    stage="train", speech_featurizer=speech_featurizer,
                    text_featurizer=text_featurizer,
                    data_paths=config["learning_config"]["dataset_config"]["train_paths"],
                    augmentations=config["learning_config"]["augmentations"], shuffle=True,
                )
                eval_dataset = ASRSliceDataset(
                    stage="eval", speech_featurizer=speech_featurizer,
                    text_featurizer=text_featurizer,
                    data_paths=config["learning_config"]["dataset_config"]["eval_paths"],
                    shuffle=False
                )

            # Build DS2 model
            f, c = speech_featurizer.compute_feature_dim()
            with ctc_trainer.strategy.scope():
                satt_ds2_model = SelfAttentionDS2(input_shape=[None, f, c],
                                                  arch_config=config["model_config"],
                                                  num_classes=text_featurizer.num_classes)
                satt_ds2_model._build([1, 50, f, c])
                optimizer = create_optimizer(
                    name=config["learning_config"]["optimizer_config"]["name"],
                    d_model=config["model_config"]["att"]["head_size"],
                    **config["learning_config"]["optimizer_config"]["config"]
                )
            # Compile
            ctc_trainer.compile(satt_ds2_model, optimizer, max_to_keep=args.max_ckpts)

            ctc_trainer.fit(train_dataset, eval_dataset, args.eval_train_ratio)

            if args.export:
                if args.from_weights:
                    ctc_trainer.model.save_weights(args.export)
                else:
                    ctc_trainer.model.save(args.export)

        elif args.mode == "test":
            tf.random.set_seed(0)
            assert args.export

            text_featurizer.add_scorer(
                Scorer(**text_featurizer.decoder_config["lm_config"],
                       vocabulary=text_featurizer.vocab_array))

            # Build DS2 model
            f, c = speech_featurizer.compute_feature_dim()
            satt_ds2_model = SelfAttentionDS2(input_shape=[None, f, c],
                                              arch_config=config["model_config"],
                                              num_classes=text_featurizer.num_classes)
            satt_ds2_model._build([1, 50, f, c])
            satt_ds2_model.summary(line_length=100)
            optimizer = create_optimizer(
                name=config["learning_config"]["optimizer_config"]["name"],
                d_model=config["model_config"]["att"]["head_size"],
                **config["learning_config"]["optimizer_config"]["config"]
            )

            batch_size = config["learning_config"]["running_config"]["batch_size"]
            if args.tfrecords:
                test_dataset = ASRTFRecordDataset(
                    config["learning_config"]["dataset_config"]["test_paths"],
                    config["learning_config"]["dataset_config"]["tfrecords_dir"],
                    speech_featurizer, text_featurizer, "test",
                    augmentations=config["learning_config"]["augmentations"], shuffle=False
                ).create(batch_size * args.eval_train_ratio)
            else:
                test_dataset = ASRSliceDataset(
                    stage="test", speech_featurizer=speech_featurizer,
                    text_featurizer=text_featurizer,
                    data_paths=config["learning_config"]["dataset_config"]["test_paths"],
                    augmentations=config["learning_config"]["augmentations"], shuffle=False
                ).create(batch_size * args.eval_train_ratio)

            ctc_tester = BaseTester(
                config=config["learning_config"]["running_config"],
                saved_path=args.export, from_weights=args.from_weights
            )
            ctc_tester.compile(satt_ds2_model, speech_featurizer, text_featurizer)
            ctc_tester.run(test_dataset)

        else:
            assert args.export

            # Build DS2 model
            f, c = speech_featurizer.compute_feature_dim()
            satt_ds2_model = SelfAttentionDS2(input_shape=[None, f, c],
                                              arch_config=config["model_config"],
                                              num_classes=text_featurizer.num_classes)
            satt_ds2_model._build([1, 50, f, c])
            optimizer = create_optimizer(
                name=config["learning_config"]["optimizer_config"]["name"],
                d_model=config["model_config"]["att"]["head_size"],
                **config["learning_config"]["optimizer_config"]["config"]
            )

            def save_func(**kwargs):
                if args.from_weights:
                    kwargs["model"].save_weights(args.export)
                else:
                    kwargs["model"].save(args.export)

            save_from_checkpoint(func=save_func,
                                 outdir=config["learning_config"]["running_config"]["outdir"],
                                 model=satt_ds2_model, optimizer=optimizer)