Exemplo n.º 1
0
    def set_model(self):
        ''' Setup ASR model, losses, optional embedding plug-in, optimizer
        and beam decoder.

        Feature dimension (120), vocabulary size (46) and the Adadelta-style
        init flag are fixed here rather than read from the config. When
        ``self.finetune_first > 0`` only the parameters belonging to encoder
        layers ``0 .. finetune_first - 1`` are handed to the optimizer;
        otherwise every model parameter is optimized.
        '''
        # Model -- hard-coded sizes; presumably they must match the source
        # checkpoint being fine-tuned (TODO confirm).
        self.feat_dim = 120
        self.vocab_size = 46
        init_adadelta = True
        self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta,
                         **self.src_config['model']).to(self.device)
        self.verbose(self.model.create_msg())

        if self.finetune_first > 0:
            # Match "encoder.layers.<i>." with the trailing dot (and append a
            # dot to the parameter name so a bare "...layers.<i>" also
            # matches). Plain substring matching on "encoder.layers.1" would
            # wrongly select layers 10, 11, ... as well.
            wanted = ["encoder.layers.%d." % i
                      for i in range(self.finetune_first)]
            selected = [param for name, param in self.model.named_parameters()
                        if any(w in name + '.' for w in wanted)]
            model_paras = [{'params': selected}]
        else:
            model_paras = [{'params': self.model.parameters()}]

        # Losses
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        # Note: zero_infinity=False is unstable?
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

        # Plug-ins
        self.emb_fuse = False
        self.emb_reg = ('emb' in self.config) and (
            self.config['emb']['enable'])
        if self.emb_reg:
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(
                self.tokenizer, self.model.dec_dim, **self.config['emb']).to(self.device)
            # Plug-in parameters are optimized together with the model's
            model_paras.append({'params': self.emb_decoder.parameters()})
            self.emb_fuse = self.emb_decoder.apply_fuse
            if self.emb_fuse:
                # With fusion enabled the sequence loss switches to NLLLoss
                self.seq_loss = torch.nn.NLLLoss(ignore_index=0)
            self.verbose(self.emb_decoder.create_msg())

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.src_config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Enable AMP if needed
        self.enable_apex()

        # Automatically load pre-trained model if self.paras.load is given
        self.load_ckpt()

        # Beam decoder. emb_decoder is only assigned above when the plug-in is
        # enabled, so fall back to None instead of raising AttributeError.
        self.decoder = BeamDecoder(
            self.model, getattr(self, 'emb_decoder', None),
            **self.config['decode'])
        self.verbose(self.decoder.create_msg())
        self.decoder.to(self.device)
Exemplo n.º 2
0
    def set_model(self):
        ''' Setup ASR model and the decoder used at inference time.

        Builds the ASR network, whose Adadelta-init flag is derived from
        ``config['hparas']['optimizer']``. Attaches the embedding-regularizer
        plug-in when ``emb.enable`` is truthy and ``emb.fuse > 0``, then
        restores weights with ``load_ckpt``. Finally it picks
        ``self.decoder``:

        * greedy mode: a deep copy of the model moved to ``self.device``;
        * attention disabled, or ``decode.ctc_weight == 1.0``: a
          ``CTCBeamDecoder`` (and ``self.ctc_only`` becomes True);
        * otherwise: a joint CTC-attention ``BeamDecoder`` built from the
          model after calling ``.cpu()`` on it.

        ``self.model`` and ``self.emb_decoder`` are deleted at the end.
        '''
        use_adadelta = self.config['hparas']['optimizer'] == 'Adadelta'
        self.model = ASR(self.feat_dim, self.vocab_size, use_adadelta,
                         **self.config['model']).to(self.device)

        # Optional embedding plug-in
        if 'emb' in self.config:
            emb_conf = self.config['emb']
            attach_emb = emb_conf['enable'] and emb_conf['fuse'] > 0
        else:
            attach_emb = False
        if attach_emb:
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(
                self.tokenizer, self.model.dec_dim, **emb_conf)

        # Restore weights before wrapping the model in a decoder
        self.load_ckpt()

        self.ctc_only = False
        if self.greedy:
            # Greedy decoding runs on an independent, on-device copy
            self.decoder = copy.deepcopy(self.model).to(self.device)
        else:
            no_attention = not self.model.enable_att
            dec_conf = self.config['decode']
            if no_attention or dec_conf.get('ctc_weight', 0.0) == 1.0:
                # Pure CTC beam search
                assert dec_conf['beam_size'] <= dec_conf['vocab_candidate']
                ctc_model = self.model.to(self.device)
                # Token ids handed to the CTC search: 1, then 3..vocab_size-1
                # (ids 0 and 2 skipped; presumably special tokens - TODO confirm)
                token_ids = [1, *range(3, self.vocab_size)]
                self.decoder = CTCBeamDecoder(
                    ctc_model,
                    token_ids,
                    dec_conf['beam_size'],
                    dec_conf['vocab_candidate'],
                    lm_path=dec_conf['lm_path'],
                    lm_config=dec_conf['lm_config'],
                    lm_weight=dec_conf['lm_weight'],
                    device=self.device)
                self.ctc_only = True
            else:
                # Joint CTC-attention beam search
                self.decoder = BeamDecoder(self.model.cpu(), self.emb_decoder,
                                           **dec_conf)

        self.verbose(self.decoder.create_msg())
        del self.model
        del self.emb_decoder
    def set_model(self):
        ''' Setup ASR model and the decoder used at inference time.

        Builds the ASR network. The embedding-regularizer plug-in is attached
        when ``emb.enable`` is truthy and ``emb.fuse > 0``, and weights are
        restored with ``load_ckpt``. The model is then wrapped in
        ``self.decoder``: a deep copy moved to ``self.device`` in greedy mode,
        otherwise a ``BeamDecoder``. Afterwards ``self.model`` is deleted and
        ``self.emb_decoder`` is reset to ``None``.
        '''
        # Model
        self.model = ASR(self.feat_dim, self.vocab_size,
                         **self.config['model'])

        # Plug-ins
        if ('emb' in self.config) and (self.config['emb']['enable']) \
                and (self.config['emb']['fuse'] > 0):
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(self.tokenizer,
                                                    self.model.dec_dim,
                                                    **self.config['emb'])

        # Restore weights before wrapping the model in a decoder
        self.load_ckpt()

        if self.greedy:
            self.decoder = copy.deepcopy(self.model).to(self.device)
        else:
            # NOTE: this variant has no CTC-only beam decoder wired in; every
            # beam search goes through BeamDecoder.
            self.decoder = BeamDecoder(self.model, self.emb_decoder,
                                       **self.config['decode'])

        self.verbose(self.decoder.create_msg())
        del self.model
        # Rebinding drops the reference just like `del`, but does not raise
        # AttributeError when the plug-in was never created, and leaves the
        # attribute readable as "no plug-in".
        self.emb_decoder = None
Exemplo n.º 4
0
    def set_model(self):
        ''' Setup ASR model and wrap it in a beam-search decoder.

        The ASR network is built from ``self.config['model']``, with its
        Adadelta-init flag taken from ``self.src_config['hparas']``. The
        embedding-regularizer plug-in is attached when ``emb.enable`` is
        truthy and ``emb.fuse > 0``. After ``load_ckpt`` restores weights,
        the model (after ``.cpu()``) and the plug-in are passed to
        ``BeamDecoder``. The solver's ``model`` and ``emb_decoder``
        attributes are then deleted.
        '''
        optimizer_name = self.src_config['hparas']['optimizer']
        self.model = ASR(self.feat_dim, self.vocab_size,
                         optimizer_name == 'Adadelta',
                         **self.config['model']).to(self.device)

        # Optional embedding plug-in
        emb_wanted = False
        if 'emb' in self.config:
            emb_conf = self.config['emb']
            emb_wanted = emb_conf['enable'] and emb_conf['fuse'] > 0
        if emb_wanted:
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(
                self.tokenizer, self.model.dec_dim, **emb_conf)

        # Restore weights before building the decoder
        self.load_ckpt()

        cpu_model = self.model.cpu()
        self.decoder = BeamDecoder(cpu_model, self.emb_decoder,
                                   **self.config['decode'])
        self.verbose(self.decoder.create_msg())
        del self.model
        del self.emb_decoder