Example #1
    def set_model(self):
        ''' Setup ASR model and optimizer '''
        # Model
        init_adadelta = self.config['hparas']['optimizer'] == 'Adadelta'
        self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta,
                         **self.config['model']).to(self.device)
        self.verbose(self.model.create_msg())
        model_paras = [{'params': self.model.parameters()}]

        # Losses
        # Note: zero_infinity=False is unstable?
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

        self.eval_target = 'phone' if self.config['data']['corpus'][
            'target'] == 'ipa' else 'char'

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Enable AMP if needed
        self.enable_apex()

        if self.paras.transfer:
            self.transfer_weight()

        # Automatically load pre-trained model if self.paras.load is given
        if self.paras.load:
            self.load_ckpt()
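A hedged sketch of the config layout this snippet assumes; the keys are inferred from the code above, the values are illustrative only (the real project loads them from a YAML file):

# Illustrative config sketch: keys inferred from usage, values assumed.
config = {
    'hparas': {'optimizer': 'Adadelta'},    # drives init_adadelta
    'model': {},                            # kwargs forwarded to ASR(...)
    'data': {'corpus': {'target': 'ipa'}},  # 'ipa' -> evaluate on phones
}
init_adadelta = config['hparas']['optimizer'] == 'Adadelta'
eval_target = 'phone' if config['data']['corpus']['target'] == 'ipa' else 'char'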
Example #2
    def __init__(self):
        self.asr = ASR()
        self.tts = TTS()
        self.dm = DM()

        # TODO: this should move into the config, preferably as JSON
        self.SESSION_ID = 'tcc-chatbot'
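One way to honor the TODO above and move the session id into JSON (the file name and key below are hypothetical):

import json

# Hypothetical config.json: {"session_id": "tcc-chatbot"}
with open('config.json') as f:
    cfg = json.load(f)
SESSION_ID = cfg.get('session_id', 'tcc-chatbot')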
Example #3
    def set_model(self):
        ''' Setup ASR model and optimizer '''
        # Model
        # print(self.feat_dim)  # 160
        batch_size = self.config['data']['corpus']['batch_size'] // 2
        self.model = ASR(self.feat_dim, self.vocab_size, batch_size,
                         **self.config['model']).to(self.device)

        self.verbose(self.model.create_msg())
        model_paras = [{'params': self.model.parameters()}]

        # Losses
        # Label smoothing
        if self.config['hparas']['label_smoothing']:
            # 31 output classes, smoothing factor 0.1
            self.seq_loss = LabelSmoothingLoss(31, 0.1)
            print('[INFO] Using label smoothing.')
        else:
            self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        self.ctc_loss = torch.nn.CTCLoss(
            blank=0,
            zero_infinity=False)  # Note: zero_infinity=False is unstable?

        # Plug-ins
        self.emb_fuse = False
        self.emb_reg = ('emb'
                        in self.config) and (self.config['emb']['enable'])
        if self.emb_reg:
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(
                self.tokenizer, self.model.dec_dim,
                **self.config['emb']).to(self.device)
            model_paras.append({'params': self.emb_decoder.parameters()})
            self.emb_fuse = self.emb_decoder.apply_fuse
            if self.emb_fuse:
                self.seq_loss = torch.nn.NLLLoss(ignore_index=0)
            self.verbose(self.emb_decoder.create_msg())

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.config['hparas'])
        self.lr_scheduler = self.optimizer.lr_scheduler
        self.verbose(self.optimizer.create_msg())

        # Enable AMP if needed
        self.enable_apex()

        # Transfer Learning
        if self.transfer_learning:
            self.verbose('Apply transfer learning: ')
            self.verbose('      Train encoder layers: {}'.format(
                self.train_enc))
            self.verbose('      Train decoder:        {}'.format(
                self.train_dec))
            self.verbose('      Save name:            {}'.format(
                self.save_name))

        # Automatically load pre-trained model if self.paras.load is given
        self.load_ckpt()
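LabelSmoothingLoss(31, 0.1) above is the project's own class. A minimal sketch of what such a loss typically computes, assuming the signature (num_classes, smoothing) and log-probability inputs; the repo's implementation may differ in details such as padding handling:

import torch

class LabelSmoothingLoss(torch.nn.Module):
    ''' Sketch: (1 - smoothing) mass on the target class, the remainder
        spread uniformly; rows whose target is the padding index are masked. '''

    def __init__(self, num_classes, smoothing=0.1, ignore_index=0):
        super().__init__()
        self.num_classes = num_classes
        self.smoothing = smoothing
        self.ignore_index = ignore_index

    def forward(self, log_probs, target):
        # log_probs: (N, num_classes) log-softmax outputs; target: (N,) ids
        true_dist = torch.full_like(log_probs,
                                    self.smoothing / (self.num_classes - 1))
        true_dist.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)
        mask = (target != self.ignore_index).float()
        loss = -(true_dist * log_probs).sum(dim=1)
        return (loss * mask).sum() / mask.sum().clamp(min=1)

crit = LabelSmoothingLoss(31, 0.1)
logp = torch.randn(4, 31).log_softmax(-1)
tgt = torch.tensor([5, 0, 7, 30])  # index 0 is padding and gets masked out
print(crit(logp, tgt))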
Example #4
    def set_model(self):
        ''' Setup ASR model and optimizer '''
        # Model (feature dim, vocab size and optimizer flag are hard-coded)
        self.feat_dim = 120
        self.vocab_size = 46
        init_adadelta = True
        # init_adadelta = self.config['hparas']['optimizer'] == 'Adadelta'
        self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta, **
                         self.src_config['model']).to(self.device)
        self.verbose(self.model.create_msg())

        if self.finetune_first > 0:
            # Only fine-tune the first `finetune_first` encoder layers
            names = ['encoder.layers.%d' % i for i in range(self.finetune_first)]
            model_paras = [{
                'params': [
                    p for n, p in self.model.named_parameters()
                    if any(nd in n for nd in names)
                ]
            }]
        else:
            model_paras = [{'params': self.model.parameters()}]

        # Losses
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        # Note: zero_infinity=False is unstable?
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

        # Plug-ins
        self.emb_fuse = False
        self.emb_reg = ('emb' in self.config) and (
            self.config['emb']['enable'])
        if self.emb_reg:
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(
                self.tokenizer, self.model.dec_dim, **self.config['emb']).to(self.device)
            model_paras.append({'params': self.emb_decoder.parameters()})
            self.emb_fuse = self.emb_decoder.apply_fuse
            if self.emb_fuse:
                self.seq_loss = torch.nn.NLLLoss(ignore_index=0)
            self.verbose(self.emb_decoder.create_msg())

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.src_config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Enable AMP if needed
        self.enable_apex()

        # Automatically load pre-trained model if self.paras.load is given
        self.load_ckpt()
        # Beam decoder
        self.decoder = BeamDecoder(
            self.model, self.emb_decoder, **self.config['decode'])
        self.verbose(self.decoder.create_msg())
        # del self.model
        # del self.emb_decoder
        self.decoder.to(self.device)
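The finetune_first filter above selects trainable parameters by substring match on their names. A self-contained toy (hypothetical module layout) showing the mechanics; note that with ten or more layers, 'encoder.layers.1' would also match 'encoder.layers.10' and up, so a stricter prefix check may be safer:

import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Module()
        self.encoder.layers = nn.ModuleList(nn.Linear(4, 4) for _ in range(6))
        self.decoder = nn.Linear(4, 4)

toy = Toy()
names = ['encoder.layers.%d' % i for i in range(2)]  # finetune_first = 2
picked = [n for n, _ in toy.named_parameters()
          if any(nd in n for nd in names)]
print(picked)  # only weights/biases of encoder layers 0 and 1 appear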
Example #5
    def set_model(self):
        ''' Setup ASR model '''
        # Model
        init_adadelta = self.config['hparas']['optimizer'] == 'Adadelta'
        self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta,
                         **self.config['model']).to(self.device)

        # Plug-ins
        if ('emb' in self.config) and (self.config['emb']['enable']) \
                and (self.config['emb']['fuse'] > 0):
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(self.tokenizer,
                                                    self.model.dec_dim,
                                                    **self.config['emb'])

        # Load target model in eval mode
        self.load_ckpt()

        self.ctc_only = False
        if self.greedy:
            # Greedy decoding: attention-based if the ASR has a decoder, else use CTC
            self.decoder = copy.deepcopy(self.model).to(self.device)
        else:
            if (not self.model.enable_att) or self.config['decode'].get(
                    'ctc_weight', 0.0) == 1.0:
                # Pure CTC Beam Decoder
                assert self.config['decode']['beam_size'] <= self.config[
                    'decode']['vocab_candidate']
                self.decoder = CTCBeamDecoder(
                    self.model.to(self.device),
                    [1] + [r for r in range(3, self.vocab_size)],
                    self.config['decode']['beam_size'],
                    self.config['decode']['vocab_candidate'],
                    lm_path=self.config['decode']['lm_path'],
                    lm_config=self.config['decode']['lm_config'],
                    lm_weight=self.config['decode']['lm_weight'],
                    device=self.device)
                self.ctc_only = True
            else:
                # Joint CTC-Attention Beam Decoder
                self.decoder = BeamDecoder(self.model.cpu(), self.emb_decoder,
                                           **self.config['decode'])

        self.verbose(self.decoder.create_msg())
        del self.model
        del self.emb_decoder
Example #6
    def set_model(self):
        ''' Setup ASR model '''
        # Model
        init_adadelta = self.src_config['hparas']['optimizer'] == 'Adadelta'
        self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta,
                         **self.config['model']).to(self.device)

        # Load target model in eval mode
        self.load_ckpt()
Example #7
    def set_model(self):
        ''' Setup ASR model and optimizer '''
        # Model
        self.model = ASR(self.vocab_size, **self.config['model']).to(self.device)
        self.verbose(self.model.create_msg())
        # Losses
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        # Optimizer
        self.optimizer = Optimizer(
            self.model.parameters(), **self.config['hparas'])
        # Enable AMP if needed
        self.enable_apex()
        # Load pre-trained model and restore training state
        if self.paras.load:
            self.load_ckpt()
            ckpt = torch.load(self.paras.load, map_location=self.device)
            self.model.load_state_dict(ckpt['model'])
            self.optimizer.load_opt_state_dict(ckpt['optimizer'])
            self.step = ckpt['global_step']
            self.verbose('Load ckpt from {}, restarting at step {}'.format(
                self.paras.load, self.step))
Example #8
    def set_model(self):
        ''' Setup ASR model and optimizer '''
        # Model
        init_adadelta = self.config['hparas']['optimizer'] == 'Adadelta'
        self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta,
                         **self.config['model']).to(self.device)
        self.verbose(self.model.create_msg())
        model_paras = [{'params': self.model.parameters()}]

        # Losses
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        # Note: zero_infinity=False is unstable?
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

        # Plug-ins
        self.emb_fuse = False
        self.emb_reg = ('emb'
                        in self.config) and (self.config['emb']['enable'])
        if self.emb_reg:
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(
                self.tokenizer, self.model.dec_dim,
                **self.config['emb']).to(self.device)
            model_paras.append({'params': self.emb_decoder.parameters()})
            self.emb_fuse = self.emb_decoder.apply_fuse
            if self.emb_fuse:
                self.seq_loss = torch.nn.NLLLoss(ignore_index=0)
            self.verbose(self.emb_decoder.create_msg())

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Enable AMP if needed
        self.enable_apex()

        # Hard-coded checkpoint path (overrides any command-line value)
        self.paras.load = 'ckpt/asr_example_sd0/best_att.pth'

        # Automatically load pre-trained model if self.paras.load is given
        self.load_ckpt()
Example #9
    def set_model(self):
        ''' Setup ASR model and optimizer '''
        # Model
        self.model = ASR(self.feat_dim, self.vocab_size, **
                         self.config['model']).to(self.device)
        self.verbose(self.model.create_msg())
        model_paras = [{'params': self.model.parameters()}]

        # Losses
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        # Note: zero_infinity=False is unstable?
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Enable AMP if needed
        self.enable_apex()

        # Automatically load pre-trained model if self.paras.load is given
        self.load_ckpt()
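Several examples carry the note "zero_infinity=False is unstable?". CTCLoss returns inf whenever a target is longer than the usable input, and that inf poisons the gradients. A hedged sketch of the safer setting (a suggestion, not what the snippets above actually use):

import torch

# zero_infinity=True zeroes out inf losses instead of propagating NaNs.
ctc = torch.nn.CTCLoss(blank=0, zero_infinity=True)
log_probs = torch.randn(50, 2, 20).log_softmax(-1)  # (T, N, C)
targets = torch.randint(1, 20, (2, 30))             # labels 1..19, 0 is blank
input_lens = torch.full((2,), 50, dtype=torch.long)
target_lens = torch.full((2,), 30, dtype=torch.long)
print(ctc(log_probs, targets, input_lens, target_lens))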
Example #10
    def set_model(self):
        ''' Setup ASR model '''
        # Model
        init_adadelta = self.src_config['hparas']['optimizer'] == 'Adadelta'
        self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta, **
                         self.config['model']).to(self.device)

        # Plug-ins
        if ('emb' in self.config) and (self.config['emb']['enable']) \
                and (self.config['emb']['fuse'] > 0):
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(
                self.tokenizer, self.model.dec_dim, **self.config['emb'])

        # Load target model in eval mode
        self.load_ckpt()

        # Beam decoder
        self.decoder = BeamDecoder(
            self.model.cpu(), self.emb_decoder, **self.config['decode'])
        self.verbose(self.decoder.create_msg())
        del self.model
        del self.emb_decoder
Example #11
    def set_model(self):
        ''' Setup ASR model '''
        # Model

        self.model = ASR(self.feat_dim, self.vocab_size,
                         **self.config['model'])

        # Plug-ins
        if ('emb' in self.config) and (self.config['emb']['enable']) \
                and (self.config['emb']['fuse'] > 0):
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(self.tokenizer,
                                                    self.model.dec_dim,
                                                    **self.config['emb'])

        # Load target model in eval mode
        self.load_ckpt()

        # self.ctc_only = False
        if self.greedy:
            self.decoder = copy.deepcopy(self.model).to(self.device)
        else:
            # Beam decoder
            # TODO: CTC decoding function hidden by the author
            # if not self.model.enable_att or self.config['decode'].get('ctc_weight', 0.0) == 1.0:
            # For CTC only decoding (character level)

            # self.decoder = CTCBeamDecoder(self.model.to(self.device),
            #     range(self.model.vocab_size),
            #     self.config['decode']['beam_size'],
            #     self.config['decode']['vocab_candidate'])
            # self.ctc_only = True
            # else:
            # self.decoder = BeamDecoder(self.model, self.emb_decoder, **self.config['decode'])
            self.decoder = BeamDecoder(self.model, self.emb_decoder,
                                       **self.config['decode'])

        self.verbose(self.decoder.create_msg())
        del self.model
        del self.emb_decoder
        self.emb_decoder = None
Example #12
class Client:
    def __init__(self):
        self.asr = ASR()
        self.tts = TTS()
        self.dm = DM()

        # TODO: this should move into the config, preferably as JSON
        self.SESSION_ID = 'tcc-chatbot'

    def run(self):

        session_client = dialogflow.SessionsClient(
            credentials=GOOGLE_APPLICATION_CREDENTIALS)
        session = session_client.session_path(DIALOGFLOW_PROJECT_ID,
                                              self.SESSION_ID)

        wait_keyword = True

        while True:

            try:
                if SILENT_MODE:
                    text_to_be_analyzed = input("[Client] <-- ").lower()
                else:
                    if wait_keyword:
                        self.asr.waitKeyword()
                    text_to_be_analyzed = self.asr.listen()

                if not text_to_be_analyzed:
                    print("[Client] Nada reconhecido")
                    continue
                else:
                    text_input = dialogflow.types.TextInput(
                        text=text_to_be_analyzed,
                        language_code=DIALOGFLOW_LANGUAGE_CODE)
                    query_input = dialogflow.types.QueryInput(text=text_input)

                    try:
                        result = session_client.detect_intent(
                            session=session, query_input=query_input)
                        response = self.dm.treatResult(result)

                        # Actuation
                        if SIMULATION:
                            import socket
                            HOST = '127.0.0.1'
                            PORT = 31415
                            a = str(response['actions']).encode()
                            with socket.socket(socket.AF_INET,
                                               socket.SOCK_STREAM) as s:
                                s.connect((HOST, PORT))
                                s.sendall(a)
                        else:
                            pass  # Real actuation would be implemented here
                        print("[Client] Actuation:", response['actions'])

                        # Spoken response
                        if response['answer'] and not SILENT_MODE:
                            self.tts.speak(response['answer'])

                        wait_keyword = response['end_conversation']

                    except Exception as e:
                        print("[Client] Erro ao tentar detectar a inteção:")
                        print("        ", e)

            except KeyboardInterrupt:
                print("\n[Client] Parando cliente")
                break
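The SIMULATION branch above only shows the sending side. A minimal sketch of a matching receiver (hypothetical; HOST and PORT mirror the client's constants):

import socket

# Accept one actuation message from the client above.
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv:
    srv.bind(('127.0.0.1', 31415))
    srv.listen()
    conn, addr = srv.accept()
    with conn:
        print('actions from', addr, ':', conn.recv(1024).decode())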
Example #13
class Solver(BaseSolver):
    ''' Solver for training'''

    def __init__(self, config, paras, mode):
        super().__init__(config, paras, mode)

        # ToDo : support tr/eval on different corpus
        assert self.config['data']['corpus']['name'] == self.src_config['data']['corpus']['name']
        self.config['data']['corpus']['path'] = self.src_config['data']['corpus']['path']
        self.config['data']['corpus']['bucketing'] = False

        # The following attributes must match the training config
        self.config['data']['audio'] = self.src_config['data']['audio']
        self.config['data']['text'] = self.src_config['data']['text']
        self.config['model'] = self.src_config['model']

        # Output file
        self.output_file = str(self.ckpdir)+'_{}_{}.csv'

        # Override batch size for beam decoding
        self.greedy = self.config['decode']['beam_size'] == 1
        if not self.greedy:
            self.config['data']['corpus']['batch_size'] = 1
        else:
            # ToDo : implement greedy
            raise NotImplementedError

    def load_data(self):
        ''' Load data for training/validation, store tokenizer and input/output shape'''
        if self.paras.upstream is not None:
            print(f'[Solver] - using S3PRL {self.paras.upstream}')
            self.dv_set, self.tt_set, self.vocab_size, self.tokenizer, msg = \
                            load_wav_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory, 
                                             False, **self.config['data'])
            self.upstream = torch.hub.load(
                's3prl/s3prl',
                self.paras.upstream,
                feature_selection=self.paras.upstream_feature_selection,
                refresh=self.paras.upstream_refresh,
                ckpt=self.paras.upstream_ckpt,
                force_reload=True,
            )
            self.feat_dim = self.upstream.get_output_dim()
        else:
            self.dv_set, self.tt_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \
                load_dataset(self.paras.njobs, self.paras.gpu,
                             self.paras.pin_memory, False, **self.config['data'])
        self.verbose(msg)

    def set_model(self):
        ''' Setup ASR model '''
        # Model
        init_adadelta = self.src_config['hparas']['optimizer'] == 'Adadelta'
        self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta, **
                         self.config['model']).to(self.device)

        # Plug-ins
        if ('emb' in self.config) and (self.config['emb']['enable']) \
                and (self.config['emb']['fuse'] > 0):
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(
                self.tokenizer, self.model.dec_dim, **self.config['emb'])

        # Load target model in eval mode
        self.load_ckpt()

        # Beam decoder
        self.decoder = BeamDecoder(
            self.model.cpu(), self.emb_decoder, **self.config['decode'])
        self.verbose(self.decoder.create_msg())
        del self.model
        del self.emb_decoder

    def exec(self):
        ''' Testing End-to-end ASR system '''
        for s, ds in zip(['dev', 'test'], [self.dv_set, self.tt_set]):
            # Setup output
            self.cur_output_path = self.output_file.format(s, 'output')
            with open(self.cur_output_path, 'w') as f:
                f.write('idx\thyp\ttruth\n')

            if self.greedy:
                # Greedy decode
                self.verbose(
                    'Performing batch-wise greedy decoding on {} set, num of batch = {}.'.format(s, len(ds)))
                self.verbose('Results will be stored at {}'.format(
                    self.cur_output_path))
            else:
                # Additional output to store all beams
                self.cur_beam_path = self.output_file.format(s, 'beam')
                with open(self.cur_beam_path, 'w') as f:
                    f.write('idx\tbeam\thyp\ttruth\n')
                self.verbose(
                    'Performing instance-wise beam decoding on {} set. (NOTE: use --njobs to speedup)'.format(s))
                # Minimal function to pickle
                beam_decode_func = partial(beam_decode, model=copy.deepcopy(
                    self.decoder), device=self.device)

                def handler(data):
                    if self.paras.upstream is not None:
                        # feat is raw waveform
                        name, feat, feat_len, txt = data

                        device = 'cpu' if self.paras.deterministic else self.device
                        self.upstream.to(device)

                        def to_device(feat):
                            return [f.to(device) for f in feat]

                        def extract_feature(feat):
                            feat = self.upstream(to_device(feat))
                            return feat

                        self.upstream.eval()
                        with torch.no_grad():
                            feat = extract_feature(feat)

                        feat_len = torch.LongTensor([len(f) for f in feat])
                        feat = pad_sequence(feat, batch_first=True)
                        txt = pad_sequence(txt, batch_first=True)
                        data = [name, feat, feat_len, txt]

                    return data

                # Parallel beam decode
                results = Parallel(n_jobs=self.paras.njobs)(
                    delayed(beam_decode_func)(handler(data)) for data in tqdm(ds))
                self.verbose(
                    'Results/Beams will be stored at {} / {}.'.format(self.cur_output_path, self.cur_beam_path))
                self.write_hyp(results, self.cur_output_path,
                               self.cur_beam_path)
        self.verbose('All done !')

    def write_hyp(self, results, best_path, beam_path):
        '''Record decoding results'''
        for name, hyp_seqs, truth in tqdm(results):
            hyp_seqs = [self.tokenizer.decode(hyp) for hyp in hyp_seqs]
            truth = self.tokenizer.decode(truth)
            with open(best_path, 'a') as f:
                f.write('\t'.join([name, hyp_seqs[0], truth])+'\n')
            if not self.greedy:
                with open(beam_path, 'a') as f:
                    for b, hyp in enumerate(hyp_seqs):
                        f.write('\t'.join([name, str(b), hyp, truth])+'\n')
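handler() above relies on torch's pad_sequence to batch variable-length upstream features; its behavior in isolation:

import torch
from torch.nn.utils.rnn import pad_sequence

feats = [torch.ones(3, 2), torch.ones(5, 2)]   # two utterances, T=3 and T=5
batch = pad_sequence(feats, batch_first=True)  # zero-padded to (2, 5, 2)
feat_len = torch.LongTensor([len(f) for f in feats])
print(batch.shape, feat_len)                   # torch.Size([2, 5, 2]) tensor([3, 5])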
Example #14
class Solver(BaseSolver):
    ''' Solver for training'''
    def __init__(self, config, paras, mode):
        super().__init__(config, paras, mode)
        # Logger settings (best WER starts at a high sentinel value)
        self.best_wer = {'att': 3.0, 'ctc': 3.0}
        # Curriculum learning affects data loader
        self.curriculum = self.config['hparas']['curriculum']

    def fetch_data(self, data):
        ''' Move data to device and compute text seq. length'''
        _, feat, feat_len, txt = data
        feat = feat.to(self.device)
        feat_len = feat_len.to(self.device)
        txt = txt.to(self.device)
        txt_len = torch.sum(txt != 0, dim=-1)

        return feat, feat_len, txt, txt_len

    def load_data(self):
        ''' Load data for training/validation, store tokenizer and input/output shape'''
        self.tr_set, self.dv_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \
            load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                         self.curriculum > 0, **self.config['data'])
        self.verbose(msg)

    def set_model(self):
        ''' Setup ASR model and optimizer '''
        # Model
        init_adadelta = self.config['hparas']['optimizer'] == 'Adadelta'
        self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta,
                         **self.config['model']).to(self.device)
        self.verbose(self.model.create_msg())
        model_paras = [{'params': self.model.parameters()}]

        # Losses
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        # Note: zero_infinity=False is unstable?
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

        # Plug-ins
        self.emb_fuse = False
        self.emb_reg = ('emb'
                        in self.config) and (self.config['emb']['enable'])
        if self.emb_reg:
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(
                self.tokenizer, self.model.dec_dim,
                **self.config['emb']).to(self.device)
            model_paras.append({'params': self.emb_decoder.parameters()})
            self.emb_fuse = self.emb_decoder.apply_fuse
            if self.emb_fuse:
                self.seq_loss = torch.nn.NLLLoss(ignore_index=0)
            self.verbose(self.emb_decoder.create_msg())

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Enable AMP if needed
        self.enable_apex()

        # Automatically load pre-trained model if self.paras.load is given
        self.load_ckpt()

        # ToDo: other training methods

    def exec(self):
        ''' Training End-to-end ASR system '''
        self.verbose('Total training steps {}.'.format(
            human_format(self.max_step)))
        ctc_loss, att_loss, emb_loss = None, None, None
        n_epochs = 0
        self.timer.set()

        while self.step < self.max_step:
            # Renew dataloader to enable random sampling
            if self.curriculum > 0 and n_epochs == self.curriculum:
                self.verbose(
                    'Curriculum learning ends after {} epochs, starting random sampling.'
                    .format(n_epochs))
                self.tr_set, _, _, _, _, _ = \
                    load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                                 False, **self.config['data'])
            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad
                tf_rate = self.optimizer.pre_step(self.step)
                total_loss = 0

                # Fetch data
                feat, feat_len, txt, txt_len = self.fetch_data(data)
                self.timer.cnt('rd')

                # Forward model
                # Note: txt should NOT start w/ <sos>
                ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.model(feat, feat_len, max(txt_len), tf_rate=tf_rate,
                               teacher=txt, get_dec_state=self.emb_reg)

                # Plugins
                if self.emb_reg:
                    emb_loss, fuse_output = self.emb_decoder(dec_state,
                                                             att_output,
                                                             label=txt)
                    total_loss += self.emb_decoder.weight * emb_loss

                # Compute all objectives
                if ctc_output is not None:
                    if self.paras.cudnn_ctc:
                        ctc_loss = self.ctc_loss(
                            ctc_output.transpose(0, 1),
                            txt.to_sparse().values().to(device='cpu',
                                                        dtype=torch.int32),
                            [ctc_output.shape[1]] * len(ctc_output),
                            txt_len.cpu().tolist())
                    else:
                        ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1),
                                                 txt, encode_len, txt_len)
                    total_loss += ctc_loss * self.model.ctc_weight

                if att_output is not None:
                    b, t, _ = att_output.shape
                    att_output = fuse_output if self.emb_fuse else att_output
                    att_loss = self.seq_loss(
                        att_output.contiguous().view(b * t, -1),
                        txt.contiguous().view(-1))
                    total_loss += att_loss * (1 - self.model.ctc_weight)

                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(total_loss)
                self.step += 1

                # Logger
                if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                    self.progress(
                        'Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'.
                        format(total_loss.cpu().item(), grad_norm,
                               self.timer.show()))
                    self.write_log('loss', {
                        'tr_ctc': ctc_loss,
                        'tr_att': att_loss
                    })
                    self.write_log('emb_loss', {'tr': emb_loss})
                    self.write_log(
                        'wer', {
                            'tr_att':
                            cal_er(self.tokenizer, att_output, txt),
                            'tr_ctc':
                            cal_er(self.tokenizer, ctc_output, txt, ctc=True)
                        })
                    if self.emb_fuse:
                        if self.emb_decoder.fuse_learnable:
                            self.write_log(
                                'fuse_lambda',
                                {'emb': self.emb_decoder.get_weight()})
                        self.write_log('fuse_temp',
                                       {'temp': self.emb_decoder.get_temp()})

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    self.validate()

                # End of step
                # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
                torch.cuda.empty_cache()
                self.timer.set()
                if self.step > self.max_step:
                    break
            n_epochs += 1
        self.log.close()

    def validate(self):
        # Eval mode
        self.model.eval()
        if self.emb_decoder is not None:
            self.emb_decoder.eval()
        dev_wer = {'att': [], 'ctc': []}

        for i, data in enumerate(self.dv_set):
            self.progress('Valid step - {}/{}'.format(i + 1, len(self.dv_set)))
            # Fetch data
            feat, feat_len, txt, txt_len = self.fetch_data(data)

            # Forward model
            with torch.no_grad():
                ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.model(feat, feat_len, int(max(txt_len)*self.DEV_STEP_RATIO),
                               emb_decoder=self.emb_decoder)

            dev_wer['att'].append(cal_er(self.tokenizer, att_output, txt))
            dev_wer['ctc'].append(
                cal_er(self.tokenizer, ctc_output, txt, ctc=True))

            # Show some example on tensorboard
            if i == len(self.dv_set) // 2:
                for i in range(min(len(txt), self.DEV_N_EXAMPLE)):
                    if self.step == 1:
                        self.write_log('true_text{}'.format(i),
                                       self.tokenizer.decode(txt[i].tolist()))
                    if att_output is not None:
                        self.write_log(
                            'att_align{}'.format(i),
                            feat_to_fig(att_align[i, 0, :, :].cpu().detach()))
                        self.write_log(
                            'att_text{}'.format(i),
                            self.tokenizer.decode(
                                att_output[i].argmax(dim=-1).tolist()))
                    if ctc_output is not None:
                        self.write_log(
                            'ctc_text{}'.format(i),
                            self.tokenizer.decode(
                                ctc_output[i].argmax(dim=-1).tolist(),
                                ignore_repeat=True))

        # Ckpt if performance improves
        for task in ['att', 'ctc']:
            dev_wer[task] = sum(dev_wer[task]) / len(dev_wer[task])
            if dev_wer[task] < self.best_wer[task]:
                self.best_wer[task] = dev_wer[task]
                self.save_checkpoint('best_{}.pth'.format(task), 'wer',
                                     dev_wer[task])
            self.write_log('wer', {'dv_' + task: dev_wer[task]})
        self.save_checkpoint('latest.pth',
                             'wer',
                             dev_wer['att'],
                             show_msg=False)

        # Resume training
        self.model.train()
        if self.emb_decoder is not None:
            self.emb_decoder.train()
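The training loop above mixes the two objectives with a fixed weight, total = ctc_weight * ctc + (1 - ctc_weight) * att. The arithmetic in isolation (toy scalars; the real weight lives in self.model.ctc_weight):

import torch

ctc_weight = 0.5  # hypothetical value
ctc_loss = torch.tensor(2.0)
att_loss = torch.tensor(1.0)
total_loss = ctc_loss * ctc_weight + att_loss * (1 - ctc_weight)
print(total_loss)  # tensor(1.5000)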
Example #15
class Solver(BaseSolver):
    ''' Solver for training'''
    def __init__(self, config, paras, mode):
        super().__init__(config, paras, mode)

        # ToDo : support tr/eval on different corpus
        assert self.config['data']['corpus']['name'] == self.src_config[
            'data']['corpus']['name']
        self.config['data']['corpus']['path'] = self.src_config['data'][
            'corpus']['path']
        self.config['data']['corpus']['bucketing'] = False

        # The following attributes must match the training config
        self.config['data']['audio'] = self.src_config['data']['audio']
        self.config['data']['text'] = self.src_config['data']['text']
        self.config['model'] = self.src_config['model']

        # Output file
        self.output_file = str(self.ckpdir) + '_{}_{}.csv'

        # Override batch size for beam decoding
        self.greedy = self.config['decode']['beam_size'] == 1
        if not self.greedy:
            self.config['data']['corpus']['batch_size'] = 1
        else:
            # ToDo : implement greedy
            raise NotImplementedError

    def load_data(self):
        ''' Load data for training/validation, store tokenizer and input/output shape'''
        self.dv_set, self.tt_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \
            load_dataset(self.paras.njobs, self.paras.gpu,
                         self.paras.pin_memory, False, **self.config['data'])
        self.verbose(msg)

    def set_model(self):
        ''' Setup ASR model '''
        # Model
        init_adadelta = True
        self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta,
                         **self.config['model'])

        # Plug-ins
        if ('emb' in self.config) and (self.config['emb']['enable']) \
                and (self.config['emb']['fuse'] > 0):
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(self.tokenizer,
                                                    self.model.dec_dim,
                                                    **self.config['emb'])

        # Load target model in eval mode
        self.load_ckpt()

        # Beam decoder
        self.decoder = BeamDecoder(self.model.cpu(), self.emb_decoder,
                                   **self.config['decode'])
        self.verbose(self.decoder.create_msg())
        del self.model
        del self.emb_decoder

    def exec(self):
        ''' Testing End-to-end ASR system '''
        for s, ds in zip(['dev', 'test'], [self.dv_set, self.tt_set]):
            # Setup output
            self.cur_output_path = self.output_file.format(s, 'output')
            with open(self.cur_output_path, 'w') as f:
                f.write('idx\thyp\ttruth\n')

            if self.greedy:
                # Greedy decode
                self.verbose(
                    'Performing batch-wise greedy decoding on {} set, num of batch = {}.'
                    .format(s, len(ds)))
                self.verbose('Results will be stored at {}'.format(
                    self.cur_output_path))
            else:
                # Additional output to store all beams
                self.cur_beam_path = self.output_file.format(s, 'beam')
                with open(self.cur_beam_path, 'w') as f:
                    f.write('idx\tbeam\thyp\ttruth\n')
                self.verbose(
                    'Performing instance-wise beam decoding on {} set. (NOTE: use --njobs to speedup)'
                    .format(s))
                # Minimal function to pickle
                beam_decode_func = partial(beam_decode,
                                           model=copy.deepcopy(self.decoder),
                                           device=self.device)
                # Parallel beam decode
                results = Parallel(n_jobs=self.paras.njobs)(
                    delayed(beam_decode_func)(data) for data in tqdm(ds))
                self.verbose('Results/Beams will be stored at {} / {}.'.format(
                    self.cur_output_path, self.cur_beam_path))
                self.write_hyp(results, self.cur_output_path,
                               self.cur_beam_path)
        self.verbose('All done !')

    def write_hyp(self, results, best_path, beam_path):
        '''Record decoding results'''
        for name, hyp_seqs, truth in tqdm(results):
            hyp_seqs = [self.tokenizer.decode(hyp) for hyp in hyp_seqs]
            truth = self.tokenizer.decode(truth)
            with open(best_path, 'a') as f:
                f.write('\t'.join([name, hyp_seqs[0], truth]) + '\n')
            if not self.greedy:
                with open(beam_path, 'a') as f:
                    for b, hyp in enumerate(hyp_seqs):
                        f.write('\t'.join([name, str(b), hyp, truth]) + '\n')
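The instance-wise beam decode above fans work out through joblib. The pattern in isolation, with a generic toy function standing in for beam_decode_func(handler(data)):

from joblib import Parallel, delayed

def square(x):
    return x * x

results = Parallel(n_jobs=2)(delayed(square)(i) for i in range(8))
print(results)  # [0, 1, 4, 9, 16, 25, 36, 49]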
Example #16
    def __init__(self,
                 n_mels,
                 linear_dim,
                 vocab_size,
                 n_spkr,
                 encoder,
                 codebook,
                 decoder,
                 spkr_latent_dim,
                 max_frames_per_phn,
                 stop_threshold,
                 asr_postnet_weight=0.0,
                 txt_update_codebook=False,
                 pretrained_asr=None,
                 pretrained_emb=None,
                 pretrained_tts=None):
        super().__init__()
        # Setup attributes
        self.in_dim = n_mels
        self.vocab_size = vocab_size
        self.n_spkr = n_spkr
        self.n_mels = n_mels
        self.linear_dim = linear_dim
        self.spkr_latent_dim = spkr_latent_dim
        self.stop_threshold = stop_threshold
        self.max_frames_per_phn = max_frames_per_phn
        self.txt_update_codebook = txt_update_codebook

        self.code_bone = codebook.pop('bone')
        self.latent_dim = codebook['latent_dim']
        self.commit_weight = codebook['commit_weight']
        self.vq_weight = codebook['vq_weight']
        self.n_frames_per_step = decoder['decoder']['n_frames_per_step']

        # ----------------- ASR model -----------------
        self.asr = ASR(n_mels, self.latent_dim, **encoder)
        self.time_reduce_factor = self.asr.time_reduce_factor
        self.use_asr_postnet = asr_postnet_weight > 0
        if self.use_asr_postnet:
            self.asr_postnet_weight = asr_postnet_weight
            self.asr_postnet = ASRPostnet(self.latent_dim, self.latent_dim)

        # ----------------- Latent code ---------------
        if self.code_bone == 'l2':
            self.codebook = L2Embedding(vocab_size, False, **codebook)
        elif self.code_bone == 'seperate':
            self.codebook = SeperateEmbedding(vocab_size, False, **codebook)
        else:
            raise NotImplementedError

        # ------------- speaker embedding -------------
        self.spkr_embed = nn.Embedding(self.n_spkr, spkr_latent_dim)
        # self.spkr_enc = SpeakerEncoder(n_mels, spkr_latent_dim, **spkr_encoder)

        # ----------------- TTS model -----------------
        self.tts = TTS(n_mels, self.linear_dim, self.codebook.out_dim,
                       self.spkr_latent_dim, decoder)

        # Load init. weights
        self.pretrain_asr = pretrained_asr is not None and pretrained_asr != ''
        if self.pretrain_asr:
            old_asr = torch.load(pretrained_asr)['model']
            old_asr = OrderedDict([
                (i[0].replace(PRETRAINED_ENCODER_PREFIX, ''), i[1])
                for i in old_asr.items()
            ])  # rename parameters to match vqvae
            missing, _ = self.asr.load_state_dict(old_asr, strict=False)
            assert missing == [], 'Missing pretrained para. {}'.format(missing)
        self.pretrained_emb = pretrained_emb is not None and pretrained_emb != ''
        if self.pretrained_emb:
            old_emb = torch.load(pretrained_emb)['model']
            self.codebook.load_pretrained_embedding(old_emb)
        self.pretrained_tts = pretrained_tts is not None and pretrained_tts != ''
        if self.pretrained_tts:
            old_tts = torch.load(pretrained_tts)['model']
            old_tts = OrderedDict([(i[0].replace(PRETRAINED_DECODER_PREFIX,
                                                 ''), i[1])
                                   for i in old_tts.items()])
            missing, _ = self.tts.decoder.load_state_dict(
                old_tts, strict=False
            )  # TTAO: are ALL weights included in pre-training?
            assert missing == [], 'Missing pretrained para. {}'.format(missing)
            old_postnet = OrderedDict([
                (i[0].replace(PRETRAINED_POSTNET_PREFIX, ''), i[1])
                for i in old_tts.items() if PRETRAINED_POSTNET_PREFIX in i[0]
            ])
            missing, _ = self.tts.postnet.load_state_dict(
                old_postnet, strict=False
            )  # TTAO: are ALL weights included in pre-training?
            assert missing == [], 'Missing pretrained para. {}'.format(missing)
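The checkpoint surgery above repeatedly renames state-dict keys to strip a sub-module prefix before load_state_dict. The trick in isolation (the prefix below is a hypothetical stand-in for the PRETRAINED_*_PREFIX constants defined elsewhere in the project):

from collections import OrderedDict

old = OrderedDict([('asr.encoder.weight', 1), ('asr.encoder.bias', 2)])
PREFIX = 'asr.'  # hypothetical
renamed = OrderedDict((k.replace(PREFIX, '', 1), v) for k, v in old.items())
print(list(renamed))  # ['encoder.weight', 'encoder.bias']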
Example #17
class Solver(BaseSolver):
    ''' Solver for training'''

    def __init__(self, config, paras, mode):
        super().__init__(config, paras, mode)

        # ToDo : support tr/eval on different corpus
        assert self.config['data']['corpus']['name'] == self.src_config['data']['corpus']['name']
        self.config['data']['corpus']['path'] = self.src_config['data']['corpus']['path']
        self.config['data']['corpus']['bucketing'] = False

        # The following attributes must match the training config
        self.config['data']['audio'] = self.src_config['data']['audio']
        self.config['data']['corpus']['train_split'] = self.src_config['data']['corpus']['train_split']
        self.config['data']['text'] = self.src_config['data']['text']
        self.tokenizer = load_text_encoder(**self.config['data']['text'])
        self.config['model'] = self.src_config['model']
        self.finetune_first = 5
        self.best_wer = {'att': 3.0, 'ctc': 3.0}

        # Output file
        self.output_file = str(self.ckpdir)+'_{}_{}.csv'

        # Override batch size for beam decoding
        self.greedy = self.config['decode']['beam_size'] == 1
        self.dealer = Datadealer(self.config['data']['audio'])
        self.ctc = self.config['decode']['ctc_weight'] == 1.0
        if not self.greedy:
            self.config['data']['corpus']['batch_size'] = 1
        else:
            # ToDo : implement greedy
            raise NotImplementedError

        # Logger settings
        self.logdir = os.path.join(paras.logdir, self.exp_name)
        self.log = SummaryWriter(
            self.logdir, flush_secs=self.TB_FLUSH_FREQ)
        self.timer = Timer()

    def fetch_data(self, data):
        ''' Move data to device and compute text seq. length'''
        _, feat, feat_len, txt = data
        feat = feat.to(self.device)
        feat_len = feat_len.to(self.device)
        txt = txt.to(self.device)
        txt_len = torch.sum(txt != 0, dim=-1)

        return feat, feat_len, txt, txt_len

    def load_data(self, batch_size=7):
        ''' Load data for training/validation, store tokenizer and input/output shape'''
        prev_batch_size = self.config['data']['corpus']['batch_size']
        self.config['data']['corpus']['batch_size'] = batch_size
        self.tr_set, self.dv_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \
            load_dataset(self.paras.njobs, self.paras.gpu,
                         self.paras.pin_memory, False, **self.config['data'])
        self.config['data']['corpus']['batch_size'] = prev_batch_size
        self.verbose(msg)

    def set_model(self):
        ''' Setup ASR model and optimizer '''
        # Model (feature dim, vocab size and optimizer flag are hard-coded)
        self.feat_dim = 120
        self.vocab_size = 46
        init_adadelta = True
        # init_adadelta = self.config['hparas']['optimizer'] == 'Adadelta'
        self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta, **
                         self.src_config['model']).to(self.device)
        self.verbose(self.model.create_msg())

        if self.finetune_first > 0:
            # Only fine-tune the first `finetune_first` encoder layers
            names = ['encoder.layers.%d' % i for i in range(self.finetune_first)]
            model_paras = [{
                'params': [
                    p for n, p in self.model.named_parameters()
                    if any(nd in n for nd in names)
                ]
            }]
        else:
            model_paras = [{'params': self.model.parameters()}]

        # Losses
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        # Note: zero_infinity=False is unstable?
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

        # Plug-ins
        self.emb_fuse = False
        self.emb_reg = ('emb' in self.config) and (
            self.config['emb']['enable'])
        if self.emb_reg:
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(
                self.tokenizer, self.model.dec_dim, **self.config['emb']).to(self.device)
            model_paras.append({'params': self.emb_decoder.parameters()})
            self.emb_fuse = self.emb_decoder.apply_fuse
            if self.emb_fuse:
                self.seq_loss = torch.nn.NLLLoss(ignore_index=0)
            self.verbose(self.emb_decoder.create_msg())

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.src_config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Enable AMP if needed
        self.enable_apex()

        # Automatically load pre-trained model if self.paras.load is given
        self.load_ckpt()
        # Beam decoder
        self.decoder = BeamDecoder(
            self.model, self.emb_decoder, **self.config['decode'])
        self.verbose(self.decoder.create_msg())
        # del self.model
        # del self.emb_decoder
        self.decoder.to(self.device)

    def exec(self):
        ''' Testing End-to-end ASR system '''
        while True:
            try:
                filename = input("Input wav file name: ")
                if filename == "exit":
                    return
                feat, feat_len = self.dealer(filename)
                feat = feat.to(self.device)
                feat_len = feat_len.to(self.device)
                # Decode
                with torch.no_grad():
                    hyps = self.decoder(feat, feat_len)

                hyp_seqs = [hyp.outIndex for hyp in hyps]
                hyp_txts = [self.tokenizer.decode(hyp, ignore_repeat=self.ctc) for hyp in hyp_seqs]
                for txt in hyp_txts:
                    print(txt)
            except Exception as e:
                print("Invalid file:", e)

    def recognize(self, filename):
        try:
            feat, feat_len = self.dealer(filename)
            feat = feat.to(self.device)
            feat_len = feat_len.to(self.device)
            # Decode
            with torch.no_grad():
                hyps = self.decoder(feat, feat_len)
            
            hyp_seqs = [hyp.outIndex for hyp in hyps]
            hyp_txts = [self.tokenizer.decode(hyp, ignore_repeat=self.ctc) for hyp in hyp_seqs]
            return hyp_txts[0]
        except Exception as e:
            print(e)
            app.logger.debug(e)
            return "Invalid file"

    def fetch_finetune_data(self, filename, fixed_text):
        feat, feat_len = self.dealer(filename)
        feat = feat.to(self.device)
        feat_len = feat_len.to(self.device)
        text = self.tokenizer.encode(fixed_text)
        text = torch.tensor(text).to(self.device)
        text_len = len(text)
        return [feat, feat_len, text, text_len]

    def merge_batch(self, main_batch, attach_batch):
        max_feat_len = max(main_batch[1])
        max_text_len = max(main_batch[3])
        if attach_batch[0].shape[1] > max_feat_len:
            # reduce extra long example
            attach_batch[0] = attach_batch[0][:,:max_feat_len]
            attach_batch[1][0] = max_feat_len
        else:
            # pad to max_feat_len
            padding = torch.zeros(1, max_feat_len - attach_batch[0].shape[1], attach_batch[0].shape[2], dtype=attach_batch[0].dtype).to(self.device)
            attach_batch[0] = torch.cat([attach_batch[0], padding], dim=1)
        if attach_batch[2].shape[0] > max_text_len:
            # truncate extra long text (keep a batch dim for the cat below)
            attach_batch[2] = attach_batch[2][:max_text_len].unsqueeze(0)
            attach_batch[3] = max_text_len
        else:
            padding = torch.zeros(max_text_len - attach_batch[2].shape[0], dtype=attach_batch[2].dtype).to(self.device)
            attach_batch[2] = torch.cat([attach_batch[2], padding],
                                        dim=0).unsqueeze(0)
        new_batch = (
            torch.cat([main_batch[0], attach_batch[0]], dim=0),
            torch.cat([main_batch[1], attach_batch[1]], dim=0),
            torch.cat([main_batch[2], attach_batch[2]], dim=0),
            torch.cat([main_batch[3], torch.tensor([attach_batch[3]]).to(self.device)], dim=0)
        )
        return new_batch

    def finetune(self, filename, fixed_text, max_step=5):
        # Load data for finetune
        self.verbose('Total training steps {}.'.format(
            human_format(max_step)))
        ctc_loss, att_loss, emb_loss = None, None, None
        n_epochs = 0
        accum_count = 0
        self.timer.set()
        step = 0
        for data in self.tr_set:
            # Pre-step : update tf_rate/lr_rate and do zero_grad
            if max_step == 0:
                break
            tf_rate = self.optimizer.pre_step(400000)  # late step: schedule at its final value
            total_loss = 0

            # Fetch data
            finetune_data = self.fetch_finetune_data(filename, fixed_text)
            main_batch = self.fetch_data(data)
            new_batch = self.merge_batch(main_batch, finetune_data)
            feat, feat_len, txt, txt_len = new_batch
            self.timer.cnt('rd')

            # Forward model
            # Note: txt should NOT start w/ <sos>
            ctc_output, encode_len, att_output, att_align, dec_state = \
                self.model(feat, feat_len, max(txt_len), tf_rate=tf_rate,
                            teacher=txt, get_dec_state=self.emb_reg)

            # Plugins
            if self.emb_reg:
                emb_loss, fuse_output = self.emb_decoder(
                    dec_state, att_output, label=txt)
                total_loss += self.emb_decoder.weight*emb_loss

            # Compute all objectives
            if ctc_output is not None:
                if self.paras.cudnn_ctc:
                    ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1),
                                                txt.to_sparse().values().to(device='cpu', dtype=torch.int32),
                                                [ctc_output.shape[1]] *
                                                len(ctc_output),
                                                txt_len.cpu().tolist())
                else:
                    ctc_loss = self.ctc_loss(ctc_output.transpose(
                        0, 1), txt, encode_len, txt_len)
                total_loss += ctc_loss*self.model.ctc_weight

            if att_output is not None:
                b, t, _ = att_output.shape
                att_output = fuse_output if self.emb_fuse else att_output
                att_loss = self.seq_loss(
                    att_output.contiguous().view(b*t, -1), txt.contiguous().view(-1))
                total_loss += att_loss*(1-self.model.ctc_weight)

            self.timer.cnt('fw')

            # Backprop
            grad_norm = self.backward(total_loss)
            step += 1

            # Logger
            self.progress('Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                        .format(total_loss.cpu().item(), grad_norm, self.timer.show()))
            self.write_log(
                'loss', {'tr_ctc': ctc_loss, 'tr_att': att_loss})
            self.write_log('emb_loss', {'tr': emb_loss})
            self.write_log('wer', {'tr_att': cal_er(self.tokenizer, att_output, txt),
                                'tr_ctc': cal_er(self.tokenizer, ctc_output, txt, ctc=True)})
            if self.emb_fuse:
                if self.emb_decoder.fuse_learnable:
                    self.write_log('fuse_lambda', {
                                'emb': self.emb_decoder.get_weight()})
                self.write_log(
                    'fuse_temp', {'temp': self.emb_decoder.get_temp()})

            # End of step
            # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
            torch.cuda.empty_cache()
            self.timer.set()
            if step > max_step:
                break
        ret = self.validate()
        self.log.close()
        return ret


    def validate(self):
        # Eval mode
        self.model.eval()
        if self.emb_decoder is not None:
            self.emb_decoder.eval()
        dev_wer = {'att': [], 'ctc': []}

        for i, data in enumerate(self.dv_set):
            self.progress('Valid step - {}/{}'.format(i+1, len(self.dv_set)))
            # Fetch data
            feat, feat_len, txt, txt_len = self.fetch_data(data)

            # Forward model
            with torch.no_grad():
                ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.model(feat, feat_len, int(max(txt_len)*self.DEV_STEP_RATIO),
                               emb_decoder=self.emb_decoder)

            dev_wer['att'].append(cal_er(self.tokenizer, att_output, txt))
            dev_wer['ctc'].append(cal_er(self.tokenizer, ctc_output, txt, ctc=True))

            # Show some example on tensorboard
            if i == len(self.dv_set)//2:
                for i in range(min(len(txt), self.DEV_N_EXAMPLE)):
                    self.write_log('true_text{}'.format(i),
                                   self.tokenizer.decode(txt[i].tolist()))
                    if att_output is not None:
                        self.write_log('att_align{}'.format(i), feat_to_fig(
                            att_align[i, 0, :, :].cpu().detach()))
                        self.write_log('att_text{}'.format(i), self.tokenizer.decode(
                            att_output[i].argmax(dim=-1).tolist()))
                    if ctc_output is not None:
                        self.write_log('ctc_text{}'.format(i), self.tokenizer.decode(ctc_output[i].argmax(dim=-1).tolist(),
                                                                                     ignore_repeat=True))

        # Checkpoint saving is skipped here; only track and report the best WER
        to_prints = []
        for task in ['att', 'ctc']:
            dev_wer[task] = sum(dev_wer[task]) / len(dev_wer[task])
            if dev_wer[task] < self.best_wer[task]:
                to_print = f"WER of {task}: {dev_wer[task]} < prev best ({self.best_wer[task]})"
                self.best_wer[task] = dev_wer[task]
                # self.save_checkpoint('best_{}.pth'.format(task), 'wer', dev_wer[task])
            else:
                to_print = f"WER of {task}: {dev_wer[task]} >= prev best ({self.best_wer[task]})"
            print(to_print, flush=True)
            to_prints.append(to_print)
            self.write_log('wer', {'dv_'+task: dev_wer[task]})
        # self.save_checkpoint('latest.pth', 'wer', dev_wer['att'], show_msg=False)

        # Resume training
        self.model.train()
        if self.emb_decoder is not None:
            self.emb_decoder.train()
        return '\n'.join(to_prints)
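
The checkpoint-saving calls in this snippet are left commented out; a minimal sketch of the keep-the-best pattern they stand in for (function name and checkpoint layout hypothetical, not from this repo):

import torch

def maybe_save_best(model, dev_wer, best_wer, path='best_att.pth'):
    # Save only when the averaged dev WER improves (lower is better)
    if dev_wer < best_wer:
        torch.save({'model': model.state_dict(), 'wer': dev_wer}, path)
        return dev_wer
    return best_wer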
Example #18
0
class Solver(BaseSolver):
    ''' Solver for training'''
    def __init__(self, config, paras, mode):
        super().__init__(config, paras, mode)
        # Logger settings
        self.best_wer = {'ctc': 3.0}
        self.best_per = {'ctc': 3.0}
        # Curriculum learning affects data loader
        self.curriculum = self.config['hparas']['curriculum']

    def load_data(self):
        ''' Load data for training/validation, store tokenizer and input/output shape'''
        self.tr_set, self.dv_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \
            load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                         self.curriculum > 0, **self.config['data'])
        self.verbose(msg)

    def transfer_weight(self):
        # Transfer optimizer
        ckpt_path = self.config['data']['transfer'].pop('src_ckpt')
        ckpt = torch.load(ckpt_path, map_location=self.device)

        #optim_ckpt = ckpt['optimizer']
        #for ctc_final_related_param in optim_ckpt['param_groups'][0]['params'][-2:]:
        #    optim_ckpt['state'].pop(ctc_final_related_param)

        #self.optimizer.load_opt_state_dict(optim_ckpt)

        # Load weights
        msg = self.model.transfer_with_mapping(ckpt,
                                               self.config['data']['transfer'],
                                               self.tokenizer)
        del ckpt

        self.verbose(msg)

    def set_model(self):
        ''' Setup ASR model and optimizer '''
        # Model
        init_adadelta = self.config['hparas']['optimizer'] == 'Adadelta'
        self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta,
                         **self.config['model']).to(self.device)
        self.verbose(self.model.create_msg())
        model_paras = [{'params': self.model.parameters()}]

        # Losses
        # Note: zero_infinity=False is unstable?
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

        self.eval_target = 'phone' if self.config['data']['corpus'][
            'target'] == 'ipa' else 'char'

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Enable AMP if needed
        self.enable_apex()

        if self.paras.transfer:
            self.transfer_weight()

        # Automatically load pre-trained model if self.paras.load is given
        if self.paras.load:
            self.load_ckpt()
        # ToDo: other training methods

    def exec(self):
        ''' Training End-to-end ASR system '''
        self.verbose('Total training steps {}.'.format(
            human_format(self.max_step)))
        ctc_loss = None
        n_epochs = 0
        self.timer.set()

        while self.step < self.max_step:
            # Renew dataloader to enable random sampling
            if self.curriculum > 0 and n_epochs == self.curriculum:
                self.verbose(
                    'Curriculum learning ends after {} epochs, starting random sampling.'
                    .format(n_epochs))
                self.tr_set, _, _, _, _, _ = \
                    load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                                 False, **self.config['data'])
            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad
                # zero grad here
                tf_rate = self.optimizer.pre_step(self.step)
                total_loss = 0

                # Fetch data
                feat, feat_len, txt, txt_len = self.fetch_data(data)
                self.timer.cnt('rd')

                # Forward model
                # Note: txt should NOT start w/ <sos>
                ctc_output, encode_len = self.model(feat, feat_len)

                # Compute all objectives
                if self.paras.cudnn_ctc:
                    ctc_loss = self.ctc_loss(
                        ctc_output.transpose(0, 1),
                        txt.to_sparse().values().to(device='cpu',
                                                    dtype=torch.int32),
                        [ctc_output.shape[1]] * len(ctc_output),
                        txt_len.cpu().tolist())
                else:
                    ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1), txt,
                                             encode_len, txt_len)
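                # Sketch (values hypothetical): torch.nn.CTCLoss takes
                # log-probs shaped (T, B, V) -- hence transpose(0, 1) on the
                # (B, T, V) model output -- plus input and target lengths;
                # txt.to_sparse().values() above flattens the zero-padded
                # target matrix into the concatenated form cuDNN expects.
                #   T, B, V = 50, 4, 30
                #   log_probs = torch.randn(T, B, V).log_softmax(dim=-1)
                #   tgt = torch.randint(1, V, (B, 10))
                #   loss = torch.nn.CTCLoss(blank=0)(
                #       log_probs, tgt,
                #       torch.full((B,), T, dtype=torch.long),
                #       torch.full((B,), 10, dtype=torch.long))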

                total_loss = ctc_loss

                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(total_loss)

                self.step += 1
                # Logger

                if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                    self.progress(
                        'Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'.
                        format(total_loss.cpu().item(), grad_norm,
                               self.timer.show()))
                    #self.write_log('wer', {'tr_ctc': cal_er(self.tokenizer, ctc_output, txt, ctc=True)})
                    ctc_output = [
                        x[:length].argmax(dim=-1)
                        for x, length in zip(ctc_output, encode_len)
                    ]
                    self.write_log(
                        'per', {
                            'tr_ctc':
                            cal_er(self.tokenizer,
                                   ctc_output,
                                   txt,
                                   mode='per',
                                   ctc=True)
                        })
                    self.write_log(
                        'wer', {
                            'tr_ctc':
                            cal_er(self.tokenizer,
                                   ctc_output,
                                   txt,
                                   mode='wer',
                                   ctc=True)
                        })
                    self.write_log('loss', {'tr_ctc': ctc_loss.cpu().item()})

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    self.validate()

                # End of step
                # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
                torch.cuda.empty_cache()
                self.timer.set()
                if self.step > self.max_step:
                    break
            n_epochs += 1
        #self.log.close()

    def validate(self):
        # Eval mode
        self.model.eval()
        dev_per = {'ctc': []}
        dev_wer = {'ctc': []}

        for i, data in enumerate(self.dv_set):
            self.progress('Valid step - {}/{}'.format(i + 1, len(self.dv_set)))
            # Fetch data
            feat, feat_len, txt, txt_len = self.fetch_data(data)

            # Forward model
            with torch.no_grad():
                ctc_output, encode_len = self.model(feat, feat_len)

            ctc_output = [
                x[:length].argmax(dim=-1)
                for x, length in zip(ctc_output, encode_len)
            ]
            dev_per['ctc'].append(
                cal_er(self.tokenizer, ctc_output, txt, mode='per', ctc=True))
            dev_wer['ctc'].append(
                cal_er(self.tokenizer, ctc_output, txt, mode='wer', ctc=True))

            # Show some examples on tensorboard
            if i == len(self.dv_set) // 2:
                # Use a separate index so the outer batch counter is not shadowed
                for j in range(min(len(txt), self.DEV_N_EXAMPLE)):
                    # if self.step == 1:
                    self.write_log('true_text{}'.format(j),
                                   self.tokenizer.decode(txt[j].tolist()))
                    self.write_log(
                        'ctc_text{}'.format(j),
                        self.tokenizer.decode(ctc_output[j].tolist(),
                                              ignore_repeat=True))

        # Ckpt if performance improves
        for task in ['ctc']:
            dev_wer[task] = sum(dev_wer[task]) / len(dev_wer[task])
            dev_per[task] = sum(dev_per[task]) / len(dev_per[task])
            if dev_per[task] < self.best_per[task]:
                self.best_per[task] = dev_per[task]
                self.save_checkpoint('best_{}.pth'.format('per'), 'per',
                                     dev_per[task])
                self.log.log_other('dv_best_per', self.best_per['ctc'])
            if self.eval_target == 'char' and dev_wer[task] < self.best_wer[
                    task]:
                self.best_wer[task] = dev_wer[task]
                self.save_checkpoint('best_{}.pth'.format('wer'), 'wer',
                                     dev_wer[task])
                self.log.log_other('dv_best_wer', self.best_wer['ctc'])

            self.write_log('per', {'dv_' + task: dev_per[task]})
            if self.eval_target == 'char':
                self.write_log('wer', {'dv_' + task: dev_wer[task]})
        self.save_checkpoint('latest.pth',
                             'per',
                             dev_per['ctc'],
                             show_msg=False)
        if self.paras.save_every:
            self.save_checkpoint(f'{self.step}.pth',
                                 'per',
                                 dev_per['ctc'],
                                 show_msg=False)

        # Resume training
        self.model.train()
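
Validation in this example relies on the tokenizer's ignore_repeat option to collapse CTC outputs. A minimal standalone sketch of that greedy best-path collapse, assuming blank id 0 as configured in set_model:

def ctc_greedy_collapse(ids, blank=0):
    # Collapse repeated symbols, then drop blanks (standard CTC best-path rule)
    out, prev = [], None
    for t in ids:
        if t != prev and t != blank:
            out.append(t)
        prev = t
    return out

print(ctc_greedy_collapse([0, 5, 5, 0, 5, 3, 3]))  # -> [5, 5, 3]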
Example #19
0
class Solver(BaseSolver):
    ''' Solver for training'''
    def __init__(self, config, paras, mode):
        super().__init__(config, paras, mode)

        # Curriculum learning affects data loader
        self.curriculum = self.config['hparas']['curriculum']
        self.val_mode = self.config['hparas']['val_mode'].lower()
        self.WER = 'per' if self.val_mode == 'per' else 'wer'

    def fetch_data(self, data, train=False):
        ''' Move data to device and compute text seq. length'''
        # feat: B x T x D
        _, feat, feat_len, txt = data

        if self.paras.upstream is not None:
            # feat is raw waveform
            device = 'cpu' if self.paras.deterministic else self.device
            self.upstream.to(device)
            self.specaug.to(device)

            def to_device(feat):
                return [f.to(device) for f in feat]

            def extract_feature(feat):
                feat = self.upstream(to_device(feat))
                if train and self.config['data']['audio'][
                        'augment'] and 'aug' not in self.paras.upstream:
                    feat = [self.specaug(f) for f in feat]
                return feat

            if HALF_BATCHSIZE_AUDIO_LEN < 3500 and train:
                first_len = extract_feature(feat[:1])[0].shape[0]
                if first_len > HALF_BATCHSIZE_AUDIO_LEN:
                    feat = feat[::2]
                    txt = txt[::2]

            if self.paras.upstream_trainable:
                self.upstream.train()
                feat = extract_feature(feat)
            else:
                with torch.no_grad():
                    self.upstream.eval()
                    feat = extract_feature(feat)

            feat_len = torch.LongTensor([len(f) for f in feat])
            feat = pad_sequence(feat, batch_first=True)
            txt = pad_sequence(txt, batch_first=True)

        feat = feat.to(self.device)
        feat_len = feat_len.to(self.device)
        txt = txt.to(self.device)
        txt_len = torch.sum(txt != 0, dim=-1)

        return feat, feat_len, txt, txt_len
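    # Sketch: torch.nn.utils.rnn.pad_sequence (used above) zero-pads a list of
    # variable-length (T_i, D) tensors into a (B, T_max, D) batch when
    # batch_first=True, e.g.:
    #   from torch.nn.utils.rnn import pad_sequence
    #   feats = [torch.ones(3, 2), torch.ones(5, 2)]
    #   pad_sequence(feats, batch_first=True).shape  # torch.Size([2, 5, 2])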

    def load_data(self):
        ''' Load data for training/validation, store tokenizer and input/output shape'''
        if self.paras.upstream is not None:
            print(f'[Solver] - using S3PRL {self.paras.upstream}')
            self.tr_set, self.dv_set, self.vocab_size, self.tokenizer, msg = \
                            load_wav_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                                        self.curriculum>0,
                                        **self.config['data'])
            self.upstream = torch.hub.load(
                's3prl/s3prl',
                self.paras.upstream,
                feature_selection=self.paras.upstream_feature_selection,
                refresh=self.paras.upstream_refresh,
                ckpt=self.paras.upstream_ckpt,
                force_reload=True,
            )
            self.feat_dim = self.upstream.get_output_dim()
            self.specaug = Augment()
        else:
            self.tr_set, self.dv_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \
                         load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                                      self.curriculum>0,
                                      **self.config['data'])
        self.verbose(msg)

        # Dev set names
        self.dv_names = []
        if type(self.dv_set) is list:
            for ds in self.config['data']['corpus']['dev_split']:
                self.dv_names.append(ds[0])
        else:
            self.dv_names = self.config['data']['corpus']['dev_split'][0]

        # Logger settings
        if type(self.dv_names) is str:
            self.best_wer = {
                'att': {
                    self.dv_names: 3.0
                },
                'ctc': {
                    self.dv_names: 3.0
                }
            }
        else:
            self.best_wer = {'att': {}, 'ctc': {}}
            for name in self.dv_names:
                self.best_wer['att'][name] = 3.0
                self.best_wer['ctc'][name] = 3.0

    def set_model(self):
        ''' Setup ASR model and optimizer '''
        # Model
        #print(self.feat_dim) #160
        batch_size = self.config['data']['corpus']['batch_size'] // 2
        self.model = ASR(self.feat_dim, self.vocab_size, batch_size,
                         **self.config['model']).to(self.device)

        self.verbose(self.model.create_msg())
        model_paras = [{'params': self.model.parameters()}]

        # Losses
        '''label smoothing'''
        if self.config['hparas']['label_smoothing']:
            self.seq_loss = LabelSmoothingLoss(31, 0.1)
            print('[INFO]  using label smoothing. ')
        else:
            self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        self.ctc_loss = torch.nn.CTCLoss(
            blank=0,
            zero_infinity=False)  # Note: zero_infinity=False is unstable?

        # Plug-ins
        self.emb_fuse = False
        self.emb_reg = ('emb'
                        in self.config) and (self.config['emb']['enable'])
        if self.emb_reg:
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(
                self.tokenizer, self.model.dec_dim,
                **self.config['emb']).to(self.device)
            model_paras.append({'params': self.emb_decoder.parameters()})
            self.emb_fuse = self.emb_decoder.apply_fuse
            if self.emb_fuse:
                self.seq_loss = torch.nn.NLLLoss(ignore_index=0)
            self.verbose(self.emb_decoder.create_msg())

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.config['hparas'])
        self.lr_scheduler = self.optimizer.lr_scheduler
        self.verbose(self.optimizer.create_msg())

        # Enable AMP if needed
        self.enable_apex()

        # Transfer Learning
        if self.transfer_learning:
            self.verbose('Apply transfer learning: ')
            self.verbose('      Train encoder layers: {}'.format(
                self.train_enc))
            self.verbose('      Train decoder:        {}'.format(
                self.train_dec))
            self.verbose('      Save name:            {}'.format(
                self.save_name))

        # Automatically load pre-trained model if self.paras.load is given
        self.load_ckpt()

    def exec(self):
        ''' Training End-to-end ASR system '''
        self.verbose('Total training steps {}.'.format(
            human_format(self.max_step)))
        if self.transfer_learning:
            self.model.encoder.fix_layers(self.fix_enc)
            if self.fix_dec and self.model.enable_att:
                self.model.decoder.fix_layers()
            if self.fix_dec and self.model.enable_ctc:
                self.model.fix_ctc_layer()

        self.n_epochs = 0
        self.timer.set()
        # Early stopping for CTC
        self.early_stopping = self.config['hparas']['early_stopping']
        stop_epoch = 10
        batch_size = self.config['data']['corpus']['batch_size']
        stop_step = len(self.tr_set) * stop_epoch // batch_size

        while self.step < self.max_step:
            ctc_loss, att_loss, emb_loss = None, None, None
            # Renew dataloader to enable random sampling

            if self.curriculum > 0 and self.n_epochs == self.curriculum:
                self.verbose(
                    'Curriculum learning ends after {} epochs, starting random sampling.'
                    .format(self.n_epochs))
                self.tr_set, _, _, _, _, _ = \
                         load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                                      False, **self.config['data'])

            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad
                tf_rate = self.optimizer.pre_step(self.step)
                total_loss = 0

                # Fetch data
                feat, feat_len, txt, txt_len = self.fetch_data(data,
                                                               train=True)

                self.timer.cnt('rd')
                # Forward model
                # Note: txt should NOT start w/ <sos>
                ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.model( feat, feat_len, max(txt_len), tf_rate=tf_rate,
                                    teacher=txt, get_dec_state=self.emb_reg)
                # Clear not used objects
                del att_align

                # Plugins
                if self.emb_reg:
                    emb_loss, fuse_output = self.emb_decoder(dec_state,
                                                             att_output,
                                                             label=txt)
                    total_loss += self.emb_decoder.weight * emb_loss
                else:
                    del dec_state
                # Early stopping for CTC: drop the CTC branch after stop_step
                if self.early_stopping:
                    if self.step > stop_step:
                        ctc_output = None
                        self.model.ctc_weight = 0
                #print(ctc_output.shape)
                # Compute all objectives
                if ctc_output is not None:
                    if self.paras.cudnn_ctc:
                        ctc_loss = self.ctc_loss(
                            ctc_output.transpose(0, 1),
                            txt.to_sparse().values().to(device='cpu',
                                                        dtype=torch.int32),
                            [ctc_output.shape[1]] * len(ctc_output),
                            #[int(encode_len.max()) for _ in encode_len],
                            txt_len.cpu().tolist())
                    else:
                        ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1),
                                                 txt, encode_len, txt_len)
                    total_loss += ctc_loss * self.model.ctc_weight
                    del encode_len

                if att_output is not None:
                    #print(att_output.shape)
                    b, t, _ = att_output.shape
                    att_output = fuse_output if self.emb_fuse else att_output
                    att_loss = self.seq_loss(att_output.view(b * t, -1),
                                             txt.view(-1))
                    # Sum each uttr and divide by length, then mean over batch
                    # att_loss = torch.mean(torch.sum(att_loss.view(b,t),dim=-1)/torch.sum(txt!=0,dim=-1).float())
                    total_loss += att_loss * (1 - self.model.ctc_weight)
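                # Note: the joint objective interpolates the two branches with
                # w = self.model.ctc_weight:
                #   total_loss = w * ctc_loss + (1 - w) * att_loss
                # plus the optional embedding-regularizer term added earlier.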

                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(total_loss)

                self.step += 1

                # Logger
                if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                    self.progress('Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'\
                            .format(total_loss.cpu().item(),grad_norm,self.timer.show()))
                    self.write_log('emb_loss', {'tr': emb_loss})
                    if att_output is not None:
                        self.write_log('loss', {'tr_att': att_loss})
                        self.write_log(self.WER, {
                            'tr_att':
                            cal_er(self.tokenizer, att_output, txt)
                        })
                        self.write_log(
                            'cer', {
                                'tr_att':
                                cal_er(self.tokenizer,
                                       att_output,
                                       txt,
                                       mode='cer')
                            })
                    if ctc_output is not None:
                        self.write_log('loss', {'tr_ctc': ctc_loss})
                        self.write_log(
                            self.WER, {
                                'tr_ctc':
                                cal_er(
                                    self.tokenizer, ctc_output, txt, ctc=True)
                            })
                        self.write_log(
                            'cer', {
                                'tr_ctc':
                                cal_er(self.tokenizer,
                                       ctc_output,
                                       txt,
                                       mode='cer',
                                       ctc=True)
                            })
                        self.write_log(
                            'ctc_text_train',
                            self.tokenizer.decode(
                                ctc_output[0].argmax(dim=-1).tolist(),
                                ignore_repeat=True))
                    # if self.step==1 or self.step % (self.PROGRESS_STEP * 5) == 0:
                    #     self.write_log('spec_train',feat_to_fig(feat[0].transpose(0,1).cpu().detach(), spec=True))
                    #del total_loss

                    if self.emb_fuse:
                        if self.emb_decoder.fuse_learnable:
                            self.write_log(
                                'fuse_lambda',
                                {'emb': self.emb_decoder.get_weight()})
                        self.write_log('fuse_temp',
                                       {'temp': self.emb_decoder.get_temp()})

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    if type(self.dv_set) is list:
                        for dv_id in range(len(self.dv_set)):
                            self.validate(self.dv_set[dv_id],
                                          self.dv_names[dv_id])
                    else:
                        self.validate(self.dv_set, self.dv_names)
                if self.step % (len(self.tr_set) //
                                batch_size) == 0:  # one epoch
                    print('Have finished epoch: ', self.n_epochs)
                    self.n_epochs += 1

                if self.lr_scheduler is None:
                    lr = self.optimizer.opt.param_groups[0]['lr']

                    if self.step == 1:
                        print(
                            '[INFO]    using lr scheduler defined by Daniel, init lr = ',
                            lr)

                    if self.step > 99999 and self.step % 2000 == 0:
                        lr = lr * 0.85
                        for param_group in self.optimizer.opt.param_groups:
                            param_group['lr'] = lr
                        print('[INFO]     at step:', self.step)
                        print('[INFO]   lr reduce to', lr)

                    #self.lr_scheduler.step(total_loss)
                # End of step
                # if self.step % EMPTY_CACHE_STEP == 0:
                # Empty cuda cache after every fixed amount of steps
                # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
                torch.cuda.empty_cache()
                self.timer.set()
                if self.step > self.max_step: break

            #update lr_scheduler

        self.log.close()
        print('[INFO] Finished training after', human_format(self.max_step),
              'steps.')

    def validate(self, _dv_set, _name):
        # Eval mode
        self.model.eval()
        if self.emb_decoder is not None: self.emb_decoder.eval()
        dev_wer = {'att': [], 'ctc': []}
        dev_cer = {'att': [], 'ctc': []}
        dev_er = {'att': [], 'ctc': []}

        for i, data in enumerate(_dv_set):
            self.progress('Valid step - {}/{}'.format(i + 1, len(_dv_set)))
            # Fetch data
            feat, feat_len, txt, txt_len = self.fetch_data(data)

            # Forward model
            with torch.no_grad():
                ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.model( feat, feat_len, int(max(txt_len)*self.DEV_STEP_RATIO),
                                    emb_decoder=self.emb_decoder)

            if att_output is not None:
                dev_wer['att'].append(
                    cal_er(self.tokenizer, att_output, txt, mode='wer'))
                dev_cer['att'].append(
                    cal_er(self.tokenizer, att_output, txt, mode='cer'))
                dev_er['att'].append(
                    cal_er(self.tokenizer, att_output, txt,
                           mode=self.val_mode))
            if ctc_output is not None:
                dev_wer['ctc'].append(
                    cal_er(self.tokenizer,
                           ctc_output,
                           txt,
                           mode='wer',
                           ctc=True))
                dev_cer['ctc'].append(
                    cal_er(self.tokenizer,
                           ctc_output,
                           txt,
                           mode='cer',
                           ctc=True))
                dev_er['ctc'].append(
                    cal_er(self.tokenizer,
                           ctc_output,
                           txt,
                           mode=self.val_mode,
                           ctc=True))

            # Show some examples on tensorboard
            if i == len(_dv_set) // 2:
                # Use a separate index so the outer batch counter is not shadowed
                for j in range(min(len(txt), self.DEV_N_EXAMPLE)):
                    if self.step == 1:
                        self.write_log('true_text_{}_{}'.format(_name, j),
                                       self.tokenizer.decode(txt[j].tolist()))
                    if att_output is not None:
                        self.write_log(
                            'att_align_{}_{}'.format(_name, j),
                            feat_to_fig(att_align[j, 0, :, :].cpu().detach()))
                        self.write_log(
                            'att_text_{}_{}'.format(_name, j),
                            self.tokenizer.decode(
                                att_output[j].argmax(dim=-1).tolist()))
                    if ctc_output is not None:
                        self.write_log(
                            'ctc_text_{}_{}'.format(_name, j),
                            self.tokenizer.decode(
                                ctc_output[j].argmax(dim=-1).tolist(),
                                ignore_repeat=True))

        # Ckpt if performance improves
        tasks = []
        if len(dev_er['att']) > 0:
            tasks.append('att')
        if len(dev_er['ctc']) > 0:
            tasks.append('ctc')

        for task in tasks:
            dev_er[task] = sum(dev_er[task]) / len(dev_er[task])
            dev_wer[task] = sum(dev_wer[task]) / len(dev_wer[task])
            dev_cer[task] = sum(dev_cer[task]) / len(dev_cer[task])
            if dev_er[task] < self.best_wer[task][_name]:
                self.best_wer[task][_name] = dev_er[task]
                self.save_checkpoint(
                    'best_{}_{}.pth'.format(
                        task, _name +
                        (self.save_name if self.transfer_learning else '')),
                    self.val_mode, dev_er[task], _name)
            if self.step >= self.max_step:
                self.save_checkpoint(
                    'last_{}_{}.pth'.format(
                        task, _name +
                        (self.save_name if self.transfer_learning else '')),
                    self.val_mode, dev_er[task], _name)
            self.write_log(self.WER,
                           {'dv_' + task + '_' + _name.lower(): dev_wer[task]})
            self.write_log('cer',
                           {'dv_' + task + '_' + _name.lower(): dev_cer[task]})
            # if self.transfer_learning:
            #     print('[{}] WER {:.4f} / CER {:.4f} on {}'.format(human_format(self.step), dev_wer[task], dev_cer[task], _name))

        # Resume training
        self.model.train()
        if self.transfer_learning:
            self.model.encoder.fix_layers(self.fix_enc)
            if self.fix_dec and self.model.enable_att:
                self.model.decoder.fix_layers()
            if self.fix_dec and self.model.enable_ctc:
                self.model.fix_ctc_layer()

        if self.emb_decoder is not None: self.emb_decoder.train()
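
This example plugs an S3PRL upstream in through torch.hub. A rough standalone sketch of the same extraction path, assuming the s3prl version this snippet targets ('wav2vec2' is only a placeholder upstream name):

import torch

upstream = torch.hub.load('s3prl/s3prl', 'wav2vec2')  # mirrors load_data() above
upstream.eval()

wavs = [torch.randn(16000), torch.randn(24000)]  # raw 16 kHz waveforms
with torch.no_grad():
    feats = upstream(wavs)  # per-utterance (T_i, D) features, as consumed by fetch_data()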
Example #20
0
class Solver(BaseSolver):
    ''' Solver for training'''
    def __init__(self, config, paras, mode):
        super().__init__(config, paras, mode)

        # ToDo : support tr/eval on different corpus
        assert self.config['data']['corpus']['name'] == self.src_config[
            'data']['corpus']['name']
        self.config['data']['corpus']['path'] = self.src_config['data'][
            'corpus']['path']
        self.config['data']['corpus']['bucketing'] = False

        # The following attributes should be identical to the training config
        self.config['data']['audio'] = self.src_config['data']['audio']
        self.config['data']['text'] = self.src_config['data']['text']
        self.config['hparas'] = self.src_config['hparas']
        self.config['model'] = self.src_config['model']

        # Output file
        self.output_file = str(self.ckpdir) + '_{}_{}.csv'

        # Override batch size for beam decoding
        self.greedy = self.config['decode']['beam_size'] == 1
        if not self.greedy:
            self.config['data']['corpus']['batch_size'] = 1

        self.step = 0

    def fetch_data(self, data):
        ''' Move data to device and compute text seq. length.
            For greedy decoding only (beam_decode & ctc_beam_decode otherwise)'''
        _, feat, feat_len, txt = data
        feat = feat.to(self.device)
        feat_len = feat_len.to(self.device)
        txt = txt.to(self.device)
        txt_len = torch.sum(txt != 0, dim=-1)

        return feat, feat_len, txt, txt_len

    def load_data(self):
        ''' Load data for training/validation, store tokenizer and input/output shape'''
        self.dv_set, self.tt_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \
            load_dataset(self.paras.njobs, self.paras.gpu,
                         self.paras.pin_memory, False, **self.config['data'])
        self.verbose(msg)

    def set_model(self):
        ''' Setup ASR model '''
        # Model
        init_adadelta = self.config['hparas']['optimizer'] == 'Adadelta'
        self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta,
                         **self.config['model']).to(self.device)

        # Plug-ins
        if ('emb' in self.config) and (self.config['emb']['enable']) \
                and (self.config['emb']['fuse'] > 0):
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(self.tokenizer,
                                                    self.model.dec_dim,
                                                    **self.config['emb'])

        # Load target model in eval mode
        self.load_ckpt()

        self.ctc_only = False
        if self.greedy:
            # Greedy decoding: attention-based if the ASR has a decoder, else use CTC
            self.decoder = copy.deepcopy(self.model).to(self.device)
        else:
            if (not self.model.enable_att) or self.config['decode'].get(
                    'ctc_weight', 0.0) == 1.0:
                # Pure CTC Beam Decoder
                assert self.config['decode']['beam_size'] <= self.config[
                    'decode']['vocab_candidate']
                self.decoder = CTCBeamDecoder(
                    self.model.to(self.device),
                    [1] + [r for r in range(3, self.vocab_size)],
                    self.config['decode']['beam_size'],
                    self.config['decode']['vocab_candidate'],
                    lm_path=self.config['decode']['lm_path'],
                    lm_config=self.config['decode']['lm_config'],
                    lm_weight=self.config['decode']['lm_weight'],
                    device=self.device)
                self.ctc_only = True
            else:
                # Joint CTC-Attention Beam Decoder
                self.decoder = BeamDecoder(self.model.cpu(), self.emb_decoder,
                                           **self.config['decode'])

        self.verbose(self.decoder.create_msg())
        # Keep self.emb_decoder: greedy_decode() below still references it
        del self.model

    def greedy_decode(self, dv_set):
        ''' Greedy Decoding '''
        results = []
        for i, data in enumerate(dv_set):
            self.progress('Valid step - {}/{}'.format(i + 1, len(dv_set)))
            # Fetch data
            feat, feat_len, txt, txt_len = self.fetch_data(data)
            # Forward model
            with torch.no_grad():
                ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.decoder( feat, feat_len, int(float(feat_len.max()) * self.config['decode']['max_len_ratio']),
                                    emb_decoder=self.emb_decoder)
            for j in range(len(txt)):
                idx = j + self.config['data']['corpus']['batch_size'] * i
                if att_output is not None:
                    hyp_seqs = att_output[j].argmax(dim=-1).tolist()
                else:
                    hyp_seqs = ctc_output[j].argmax(dim=-1).tolist()
                true_txt = txt[j]
                results.append((str(idx), [hyp_seqs], true_txt))
        return results

    def exec(self):
        ''' Testing End-to-end ASR system '''
        for s, ds in zip(['dev', 'test'], [self.dv_set, self.tt_set]):
            # Setup output
            self.cur_output_path = self.output_file.format(s, 'output')
            with open(self.cur_output_path, 'w', encoding='UTF-8') as f:
                f.write('idx\thyp\ttruth\n')

            if self.greedy:
                # Greedy decode
                self.verbose(
                    'Performing batch-wise greedy decoding on {} set, num of batch = {}.'
                    .format(s, len(ds)))
                results = self.greedy_decode(ds)
                self.verbose('Results will be stored at {}'.format(
                    self.cur_output_path))
                self.write_hyp(results, self.cur_output_path, '-')
            elif self.ctc_only:
                # CTC beam decode
                # Additional output to store all beams
                self.cur_beam_path = self.output_file.format(
                    s, 'beam-{}-{}'.format(self.config['decode']['beam_size'],
                                           self.config['decode']['lm_weight']))
                with open(self.cur_beam_path, 'w') as f:
                    f.write('idx\tbeam\thyp\ttruth\n')
                self.verbose(
                    'Performing instance-wise CTC beam decoding on {} set, num of batch = {}.'
                    .format(s, len(ds)))
                # Minimal function to pickle
                ctc_beam_decode_func = partial(ctc_beam_decode,
                                               model=copy.deepcopy(
                                                   self.decoder),
                                               device=self.device)
                # Parallel beam decode
                results = Parallel(n_jobs=self.paras.njobs)(
                    delayed(ctc_beam_decode_func)(data) for data in tqdm(ds))
                self.verbose('Results/Beams will be stored at {} / {}'.format(
                    self.cur_output_path, self.cur_beam_path))
                self.write_hyp(results, self.cur_output_path,
                               self.cur_beam_path)
            else:
                # Joint CTC-Attention beam decode
                # Additional output to store all beams
                self.cur_beam_path = self.output_file.format(
                    s, 'beam-{}-{}'.format(self.config['decode']['beam_size'],
                                           self.config['decode']['lm_weight']))
                with open(self.cur_beam_path, 'w') as f:
                    f.write('idx\tbeam\thyp\ttruth\n')
                self.verbose(
                    'Performing instance-wise beam decoding on {} set. (NOTE: use --njobs to speedup)'
                    .format(s))
                # Minimal function to pickle
                beam_decode_func = partial(beam_decode,
                                           model=copy.deepcopy(self.decoder),
                                           device=self.device)
                # Parallel beam decode
                results = Parallel(n_jobs=self.paras.njobs)(
                    delayed(beam_decode_func)(data) for data in tqdm(ds))
                self.verbose('Results/Beams will be stored at {} / {}.'.format(
                    self.cur_output_path, self.cur_beam_path))
                self.write_hyp(results, self.cur_output_path,
                               self.cur_beam_path)
        self.verbose('All done!')

    def write_hyp(self, results, best_path, beam_path):
        '''Record decoding results'''
        if self.greedy:
            # Ignore repeated symbols if decoding with CTC
            ignore_repeat = not self.decoder.enable_att
        else:
            ignore_repeat = False
        for name, hyp_seqs, truth in tqdm(results):
            if self.ctc_only and not self.greedy:
                new_hyp_seqs = [
                    self.tokenizer.decode(hyp, ignore_repeat=False)
                    for hyp in hyp_seqs[:-1]
                ]
                hyp_seqs = new_hyp_seqs + [
                    self.tokenizer.decode(hyp_seqs[-1], ignore_repeat=True)
                ]
            else:
                hyp_seqs = [self.tokenizer.decode(hyp) for hyp in hyp_seqs]

            truth = self.tokenizer.decode(truth)
            with open(best_path, 'a') as f:
                if len(hyp_seqs[0]) == 0:
                    # Set the sequence to a whitespace if it was empty
                    hyp_seqs[0] = ' '
                f.write('\t'.join([name, hyp_seqs[0], truth]) + '\n')
            if not self.greedy:
                with open(beam_path, 'a', encoding='UTF-8') as f:
                    for b, hyp in enumerate(hyp_seqs):
                        f.write('\t'.join([name, str(b), hyp, truth]) + '\n')
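
The parallel beam decoding above uses a common joblib pattern: functools.partial freezes the heavy arguments (the copied decoder) so each job only pickles its own batch. A minimal sketch of the same pattern with a stand-in model:

from functools import partial
from joblib import Parallel, delayed

def fake_model(x):
    # Stand-in for the deep-copied decoder used in exec() above
    return x * 2

def decode_one(data, model):
    # Placeholder for beam_decode / ctc_beam_decode
    return model(data)

decode_func = partial(decode_one, model=fake_model)
results = Parallel(n_jobs=2)(delayed(decode_func)(d) for d in range(8))
print(results)  # [0, 2, 4, 6, 8, 10, 12, 14]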
Example #21
0
class Solver(BaseSolver):
    ''' Solver for training language models'''

    def __init__(self, config, paras, mode):
        super().__init__(config, paras, mode)
        # Logger settings
        self.best_loss = 10

    def fetch_data(self, data):
        ''' Move data to device, insert <sos> and compute text seq. length'''
        txt = torch.cat(
            (torch.zeros((data.shape[0], 1), dtype=torch.long), data), dim=1).to(self.device)
        txt_len = torch.sum(data != 0, dim=-1)
        return txt, txt_len

    def load_data(self):
        ''' Load data for training/validation, store tokenizer and input/output shape'''
        self.tr_set, self.dv_set, self.vocab_size, self.tokenizer, msg = \
            load_textset(self.paras.njobs, self.paras.gpu,
                         self.paras.pin_memory, **self.config['data'])
        self.verbose(msg)

    def set_model(self):
        ''' Setup ASR model and optimizer '''
        # Model
        self.model = ASR(self.vocab_size, **self.config['model']).to(self.device)
        self.verbose(self.model.create_msg())
        # Losses
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        # Optimizer
        self.optimizer = Optimizer(
            self.model.parameters(), **self.config['hparas'])
        # Enable AMP if needed
        self.enable_apex()
        # Load pre-trained model if requested (manual load; load_ckpt would duplicate it)
        if self.paras.load:
            ckpt = torch.load(self.paras.load, map_location=self.device)
            ckpt = torch.load(self.paras.load, map_location=self.device)
            self.model.load_state_dict(ckpt['model'])
            self.optimizer.load_opt_state_dict(ckpt['optimizer'])
            self.step = ckpt['global_step']
            self.verbose('Load ckpt from {}, restarting at step {}'.format(
                self.paras.load, self.step))

    def exec(self):
        ''' Training language model '''
        self.verbose('Total training steps {}.'.format(
            human_format(self.max_step)))
        self.timer.set()

        while self.step < self.max_step:
            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad
                self.optimizer.pre_step(self.step)

                # Fetch data
                txt, txt_len = self.fetch_data(data)
                self.timer.cnt('rd')

                # Forward model
                pred, _ = self.model(txt[:, :-1], txt_len)

                # Compute all objectives
                lm_loss = self.seq_loss(
                    pred.view(-1, self.vocab_size), txt[:, 1:].reshape(-1))
                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(lm_loss)
                self.step += 1

                # Logger
                if self.step % self.PROGRESS_STEP == 0:
                    self.progress('Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                                  .format(lm_loss.cpu().item(), grad_norm, self.timer.show()))
                    self.write_log('entropy', {'tr': lm_loss})
                    self.write_log(
                        'perplexity', {'tr': torch.exp(lm_loss).cpu().item()})

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    self.validate()

                # End of step
                self.timer.set()
                if self.step > self.max_step:
                    break
        self.log.close()

    def validate(self):
        # Eval mode
        self.model.eval()
        dev_loss = []

        for i, data in enumerate(self.dv_set):
            self.progress('Valid step - {}/{}'.format(i+1, len(self.dv_set)))
            # Fetch data
            txt, txt_len = self.fetch_data(data)

            # Forward model
            with torch.no_grad():
                pred, _ = self.model(txt[:, :-1], txt_len)
            lm_loss = self.seq_loss(
                pred.view(-1, self.vocab_size), txt[:, 1:].reshape(-1))
            dev_loss.append(lm_loss)

        # Ckpt if performance improves
        dev_loss = sum(dev_loss)/len(dev_loss)
        dev_ppx = torch.exp(dev_loss).cpu().item()
        if dev_loss < self.best_loss:
            self.best_loss = dev_loss
            self.save_checkpoint('best_ppx.pth', 'perplexity', dev_ppx)
        self.write_log('entropy', {'dv': dev_loss})
        self.write_log('perplexity', {'dv': dev_ppx})

        # Show some example of last batch on tensorboard
        for i in range(min(len(txt), self.DEV_N_EXAMPLE)):
            if self.step == 1:
                self.write_log('true_text{}'.format(
                    i), self.tokenizer.decode(txt[i].tolist()))
            self.write_log('pred_text{}'.format(i), self.tokenizer.decode(
                pred[i].argmax(dim=-1).tolist()))

        # Resume training
        self.model.train()
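
This LM example uses the usual shift-by-one setup: the model reads txt[:, :-1] and is scored against txt[:, 1:], and perplexity is exp of the mean cross-entropy. A toy sketch of that objective (all sizes hypothetical):

import torch

V, B, T = 30, 2, 6
txt = torch.randint(1, V, (B, T))  # token ids; 0 is reserved for padding
logits = torch.randn(B, T - 1, V)  # stand-in for the model output on txt[:, :-1]

loss = torch.nn.CrossEntropyLoss(ignore_index=0)(
    logits.reshape(-1, V), txt[:, 1:].reshape(-1))
ppl = torch.exp(loss)  # perplexity, as logged to tensorboard above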