Code example #1
File: trm_lm.py  Project: hommmm/TensorflowASR-1
    def load_model(self,training=True):
        self.model = Transformer(**self.model_config)

        if not training:
            self.model._build()
            self.load_checkpoint()

        self.model.start_id=self.lm_featurizer.start
        self.model.end_id=self.lm_featurizer.stop
Code example #2
    def load_model(self, training=True):
        self.model = Transformer(**self.model_config)

        try:
            if not training:
                self.model._build()
                self.load_checkpoint()
        except Exception:
            logging.info('lm loading model failed.')
        self.model.start_id = self.word_featurizer.start
        self.model.end_id = self.word_featurizer.stop
Code example #3
class LM():
    def __init__(self, config):
        self.config = config
        self.vocab_featurizer = TextFeaturizer(config['lm_vocab'])
        self.word_featurizer = TextFeaturizer(config['lm_word'])
        self.model_config = self.config['model_config']
        self.model_config.update({
            'input_vocab_size':
            self.vocab_featurizer.num_classes,
            'target_vocab_size':
            self.word_featurizer.num_classes
        })

    def load_model(self, training=True):
        self.model = Transformer(**self.model_config)

        try:
            if not training:
                self.model._build()
                self.load_checkpoint()
        except Exception:
            logging.info('lm loading model failed.')
        self.model.start_id = self.word_featurizer.start
        self.model.end_id = self.word_featurizer.stop

    def convert_to_pb(self, export_path):
        import tensorflow as tf
        self.model.inference(np.ones([1, 10], 'int32'))

        concrete_func = self.model.inference.get_concrete_function()
        tf.saved_model.save(self.model, export_path, signatures=concrete_func)

    def load_checkpoint(self):
        """Load the latest checkpoint."""

        self.checkpoint_dir = os.path.join(
            self.config['running_config']["outdir"], "checkpoints")
        files = os.listdir(self.checkpoint_dir)
        # Sort by the numeric suffix of the file name and load the newest *.h5 file.
        files.sort(key=lambda x: int(x.split('_')[-1].replace('.h5', '')))
        self.model.load_weights(os.path.join(self.checkpoint_dir, files[-1]))

    def encode(self, word, token):
        # Map tokens to ids, wrap them with the start/stop ids and add a batch axis.
        x = [token.start]
        for i in word:
            x.append(token.token_to_index[i])
        x.append(token.stop)
        return np.array(x)[np.newaxis, :]

    def decode(self, out, token):
        # Map ids back to tokens, skipping the first id.
        de = []
        for i in out[1:]:
            de.append(token.index_to_token[i])
        return de

    def predict(self, pins):
        x = self.encode(pins, self.vocab_featurizer)
        result = self.model.inference(x)
        return result
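
Since example #3 is a complete class, a short usage sketch may help. It is hypothetical: the dictionary keys are the ones the constructor, load_checkpoint and predict read, but every file name, hyperparameter name and value below is a placeholder, and the input tokens are assumed to exist in the input vocabulary.

# Hypothetical configuration; all paths, hyperparameter names and values are placeholders.
config = {
    'lm_vocab': 'lm_input_tokens.txt',       # input-token list read by TextFeaturizer
    'lm_word': 'lm_target_words.txt',        # target-token list read by TextFeaturizer
    'model_config': {                        # forwarded to Transformer(**model_config);
        'num_layers': 4,                     # the two vocab sizes are filled in by __init__
        'd_model': 256,
        'num_heads': 4,
        'dff': 1024,
    },
    'running_config': {'outdir': 'lm_out'},  # checkpoints are read from lm_out/checkpoints
}

lm = LM(config)
lm.load_model(training=False)           # builds the model and loads the newest *.h5 checkpoint
result = lm.predict(['ni3', 'hao3'])    # assumed pinyin-style tokens present in lm_vocab
# Assuming inference() returns a (batch, length) tensor of target ids:
tokens = lm.decode(result[0].numpy(), lm.word_featurizer)
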
Code example #4
class LM():
    def __init__(self, config, punc_config=None):
        self.config = config
        self.am_featurizer = TextFeaturizer(config['am_token'])
        self.lm_featurizer = TextFeaturizer(config['lm_token'])
        self.model_config = self.config['model_config']
        self.model_config.update({
            'input_vocab_size':
            self.am_featurizer.num_classes,
            'target_vocab_size':
            self.lm_featurizer.num_classes
        })
        self.punc_config = punc_config
        if punc_config:
            self.punc_vocab_featurizer = TextFeaturizer(
                punc_config['punc_vocab'])
            self.punc_bd_featurizer = TextFeaturizer(
                punc_config['punc_biaodian'])
            self.punc_model_config = self.punc_config['model_config']
            self.punc_model_config.update({
                'input_vocab_size':
                self.punc_vocab_featurizer.num_classes,
                'bd_vocab_size':
                self.punc_bd_featurizer.num_classes
            })

    def load_model(self, training=True):
        self.model = Transformer(**self.model_config)
        if self.punc_config is not None:
            self.punc_model = punc_transformer.Transformer(
                **self.punc_model_config)
        if not training:
            self.model._build()

            if self.punc_config is not None:
                self.punc_model._build()
            self.load_checkpoint()

        self.model.start_id = self.lm_featurizer.start
        self.model.end_id = self.lm_featurizer.stop

    def convert_to_pb(self, export_path):
        import tensorflow as tf
        self.model.inference(np.ones([1, 10], 'int32'))

        concrete_func = self.model.inference.get_concrete_function()
        tf.saved_model.save(self.model, export_path, signatures=concrete_func)

    def load_checkpoint(self):
        """Load checkpoint."""

        self.checkpoint_dir = os.path.join(
            self.config['running_config']["outdir"], "checkpoints")
        files = os.listdir(self.checkpoint_dir)
        files.sort(key=lambda x: int(x.split('_')[-1].replace('.h5', '')))
        self.model.load_weights(os.path.join(self.checkpoint_dir, files[-1]))
        if self.punc_config is not None:
            self.checkpoint_dir = os.path.join(
                self.punc_config['running_config']["outdir"], "checkpoints")
            files = os.listdir(self.checkpoint_dir)
            files.sort(key=lambda x: int(x.split('_')[-1].replace('.h5', '')))
            self.punc_model.load_weights(
                os.path.join(self.checkpoint_dir, files[-1]))

    def encode(self, word, token):
        x = [token.start]
        for i in word:
            x.append(token.token_to_index[i])
        x.append(token.stop)
        return np.array(x)[np.newaxis, :]

    def decode(self, out, token):
        de = []
        for i in out[1:]:
            de.append(token.index_to_token[i])
        return de

    def predict(self, pins):
        x = self.encode(pins, self.am_featurizer)
        result = self.model.inference(x)
        return result

    def creat_mask(self, seq):
        # Padding mask: 1.0 wherever the input id is 0, shaped to broadcast over attention heads.
        seq_pad = tf.cast(tf.equal(seq, 0), tf.float32)
        return seq_pad[:, tf.newaxis,
                       tf.newaxis, :]  # (batch_size, 1, 1, seq_len)

    def only_chinese(self, word):
        # Keep only characters in the CJK Unified Ideographs range (U+4E00..U+9FFF).
        n = ''
        for ch in word:
            if '\u4e00' <= ch <= '\u9fff':
                n += ch
        return n

    def punc_predict(self, txt):
        # Run the punctuation model on the Chinese characters only, then merge the
        # predicted marks back into the original string.
        chinese = self.only_chinese(txt)
        x = np.array(self.encode(chinese, self.punc_vocab_featurizer), 'int32')
        mask = self.creat_mask(x)
        result = self.punc_model.inference(x, mask)
        decoded = self.punc_decoded(chinese, result[0].numpy())
        value = self.iextract(decoded, txt)
        return value

    def iextract(self, decoded, input_strs):
        # Splice each (character, punctuation) pair from `decoded` back into the
        # original string at the position of the matching character.
        idx = 0
        inp = list(input_strs)
        for n in decoded:
            idx_ = inp.index(n[0], idx)
            inp[idx_] = ''.join(n)
            idx = idx_ + 1

        return inp

    def punc_decoded(self, chinese, bd_out):
        # Pair each character with a predicted punctuation mark when the top class is a
        # real mark (index > 1) and its score is at least 0.8.
        de = []
        for i in range(1, len(chinese) + 1):
            now = [chinese[i - 1]]

            if bd_out[i].argmax(-1) > 1 and bd_out[i].max() >= 0.8:
                result = bd_out[i].argmax(-1)
                now.append(self.punc_bd_featurizer.vocab_array[result])
            de.append(now)
        return de
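
A similar hypothetical sketch for this variant's punctuation path. The dictionary keys are the ones the constructor and load_checkpoint read; every file name, hyperparameter name and value is a placeholder, and the sample sentence is arbitrary.

# Hypothetical configuration; all paths, hyperparameter names and values are placeholders.
config = {
    'am_token': 'am_tokens.txt',              # input (acoustic-model) token list
    'lm_token': 'lm_tokens.txt',              # target (language-model) token list
    'model_config': {'num_layers': 4, 'd_model': 256, 'num_heads': 4, 'dff': 1024},
    'running_config': {'outdir': 'lm_out'},   # checkpoints read from lm_out/checkpoints
}
punc_config = {
    'punc_vocab': 'punc_input_tokens.txt',    # character list for the punctuation model
    'punc_biaodian': 'punc_marks.txt',        # punctuation-mark ("biaodian") list
    'model_config': {'num_layers': 2, 'd_model': 256, 'num_heads': 4, 'dff': 512},
    'running_config': {'outdir': 'punc_out'},
}

lm = LM(config, punc_config=punc_config)
lm.load_model(training=False)   # builds both models and loads their newest checkpoints

# punc_predict() keeps only the Chinese characters, predicts a mark per character,
# and splices the marks back into the original text with iextract().
pieces = lm.punc_predict('今天天气不错我们出去玩吧')
print(''.join(pieces))
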