Exemplo n.º 1
0
    def __init__(self, asr, emb_decoder, beam_size, min_len_ratio, max_len_ratio,
                 lm_path='', lm_config='', lm_weight=0.0, ctc_weight=0.0):
        ''' Set up the decoder around an ASR model plus optional CTC / RNNLM / embedding modules

        Args:
            asr: ASR model. `asr.ctc_weight` is checked when CTC is enabled and
                `asr.vocab_size` sizes the RNNLM when the LM is enabled.
            emb_decoder: optional module kept as `self.emb_decoder`; `None` disables it.
            beam_size: beam width, stored as-is; also scales the CTC beam size.
            min_len_ratio: stored as-is (presumably a lower bound on output
                length relative to input length -- confirm against the caller).
            max_len_ratio: stored as-is (presumably the matching upper bound -- confirm).
            lm_path: checkpoint file whose 'model' entry is the RNNLM state dict.
            lm_config: path to a YAML file whose 'model' section holds the RNNLM kwargs.
            lm_weight: values > 0 enable the RNNLM and are stored as `self.lm_w`.
            ctc_weight: values > 0 enable CTC and are stored as `self.ctc_w`.

        Raises:
            AssertionError: if `ctc_weight > 0` but `asr.ctc_weight` is not positive.
        '''
        super().__init__()
        # Setup
        self.beam_size = beam_size
        self.min_len_ratio = min_len_ratio
        self.max_len_ratio = max_len_ratio
        self.asr = asr

        # ToDo : implement pure ctc decode
        # assert self.asr.enable_att

        # Additional decoding modules
        self.apply_ctc = ctc_weight > 0
        if self.apply_ctc:
            assert self.asr.ctc_weight > 0, 'ASR was not trained with CTC decoder'
            self.ctc_w = ctc_weight
            self.ctc_beam_size = int(CTC_BEAM_RATIO * self.beam_size)

        self.apply_lm = lm_weight > 0
        if self.apply_lm:
            self.lm_w = lm_weight
            self.lm_path = lm_path
            # Read the config inside a context manager so the file handle is
            # closed deterministically instead of being left to the GC.
            # NOTE(review): FullLoader and torch.load both assume the config
            # and checkpoint come from a trusted source.
            with open(lm_config, 'r') as cfg_file:
                lm_cfg = yaml.load(cfg_file, Loader=yaml.FullLoader)
            self.lm = RNNLM(self.asr.vocab_size, **lm_cfg['model'])
            self.lm.load_state_dict(torch.load(
                self.lm_path, map_location='cpu')['model'])
            # Put the LM in evaluation mode
            self.lm.eval()

        self.apply_emb = emb_decoder is not None
        if self.apply_emb:
            self.emb_decoder = emb_decoder
Exemplo n.º 2
0
    def set_model(self):
        ''' Build the two networks, the loss and the optimizer, then optionally restore a checkpoint '''

        model_kwargs = self.config['model']

        # Networks: both are built from the same 'model' config section and
        # moved to the target device (Prediction first, then RNNLM)
        self.model = Prediction(self.vocab_size, **model_kwargs).to(self.device)
        self.rnnlm = RNNLM(self.vocab_size, **model_kwargs).to(self.device)
        self.verbose(self.rnnlm.create_msg())

        # Loss: cross entropy that skips targets equal to index 0
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)

        # Optimizer over the parameters of both networks (model's first)
        trainable = [*self.model.parameters(), *self.rnnlm.parameters()]
        self.optimizer = Optimizer(trainable, **self.config['hparas'])

        # Enable AMP if needed
        self.enable_apex()

        # Nothing to restore unless a checkpoint path was supplied
        if not self.paras.load:
            return

        # NOTE(review): load_ckpt() is defined elsewhere; the explicit reload
        # below restores only self.model and the optimizer -- confirm whether
        # self.rnnlm is expected to be restored as well.
        self.load_ckpt()
        ckpt_path = self.paras.load
        state = torch.load(ckpt_path, map_location=self.device)
        self.model.load_state_dict(state['model'])
        self.optimizer.load_opt_state_dict(state['optimizer'])
        self.step = state['global_step']
        self.verbose('Load ckpt from {}, restarting at step {}'.format(
            ckpt_path, self.step))
Exemplo n.º 3
0
    def __init__(self,
                 asr,
                 vocab_range,
                 beam_size,
                 vocab_candidate,
                 lm_path='',
                 lm_config='',
                 lm_weight=0.0,
                 device=None):
        ''' Set up the CTC decoder around an ASR model, with an optional RNNLM

        Args:
            asr: ASR model; must have a truthy `enable_ctc`. `asr.vocab_size`
                sizes the RNNLM when the LM is enabled.
            vocab_range: sized collection stored as-is; its length bounds
                `vocab_candidate`.
            beam_size: beam width, stored as-is.
            vocab_candidate: stored as `self.vocab_cand`; must not exceed
                `len(vocab_range)`.
            lm_path: checkpoint file whose 'model' entry is the RNNLM state dict.
            lm_config: path to a YAML file whose 'model' section holds the RNNLM kwargs.
            lm_weight: values > 0 enable the RNNLM and are stored as
                `self.lm_w` (which is 0 otherwise).
            device: device the RNNLM is moved to; only stored (as
                `self.device`) when the LM is enabled.

        Raises:
            AssertionError: if `vocab_candidate > len(vocab_range)` or
                `asr.enable_ctc` is falsy.
        '''
        super().__init__()
        # Setup
        self.asr = asr
        self.vocab_range = vocab_range
        self.beam_size = beam_size
        self.vocab_cand = vocab_candidate
        assert self.vocab_cand <= len(self.vocab_range)

        assert self.asr.enable_ctc

        # Setup RNNLM
        self.apply_lm = lm_weight > 0
        self.lm_w = 0
        if self.apply_lm:
            self.device = device
            self.lm_w = lm_weight
            self.lm_path = lm_path
            # Read the config inside a context manager so the file handle is
            # closed deterministically instead of being left to the GC.
            # NOTE(review): FullLoader and torch.load both assume the config
            # and checkpoint come from a trusted source.
            with open(lm_config, 'r') as cfg_file:
                lm_cfg = yaml.load(cfg_file, Loader=yaml.FullLoader)
            self.lm = RNNLM(self.asr.vocab_size,
                            **lm_cfg['model']).to(self.device)
            # Weights are deserialized on CPU and copied into the module,
            # which already lives on self.device.
            self.lm.load_state_dict(
                torch.load(self.lm_path, map_location='cpu')['model'])
            # Put the LM in evaluation mode
            self.lm.eval()