예제 #1
0
    def set_model(self):
        '''Build the ASR model, its CTC objective and the optimizer.'''
        # Acoustic model on the target device; the flag tells the model
        # whether Adadelta-style initialization should be used.
        use_adadelta = (self.config['hparas']['optimizer'] == 'Adadelta')
        self.model = ASR(self.feat_dim, self.vocab_size, use_adadelta,
                         **self.config['model']).to(self.device)
        self.verbose(self.model.create_msg())
        param_groups = [{'params': self.model.parameters()}]

        # CTC objective (zero_infinity=False may be numerically unstable)
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

        # Evaluation unit depends on the corpus target
        if self.config['data']['corpus']['target'] == 'ipa':
            self.eval_target = 'phone'
        else:
            self.eval_target = 'char'

        # Optimizer over all parameter groups
        self.optimizer = Optimizer(param_groups, **self.config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Mixed-precision training, if configured
        self.enable_apex()

        # Optional transfer learning from a source checkpoint
        if self.paras.transfer:
            self.transfer_weight()

        # Resume from a checkpoint when a load path was given
        if self.paras.load:
            self.load_ckpt()
예제 #2
0
    def set_model(self):
        ''' Setup ASR model and optimizer '''

        # Model
        # Both networks are built from the same config['model'] kwargs.
        # self.model = RNNLM(self.vocab_size, **self.config['model']).to(self.device)
        self.model = Prediction(self.vocab_size,
                                **self.config['model']).to(self.device)
        self.rnnlm = RNNLM(self.vocab_size,
                           **self.config['model']).to(self.device)

        self.verbose(self.rnnlm.create_msg())
        # Losses — index 0 is treated as padding and ignored
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        # Optimizer — jointly optimizes both networks' parameters
        self.optimizer = Optimizer(
            list(self.model.parameters()) + list(self.rnnlm.parameters()),
            **self.config['hparas'])
        # Enable AMP if needed
        self.enable_apex()
        # load pre-trained model
        # NOTE(review): load_ckpt() is called and then the checkpoint is
        # restored again manually below — the two restores look redundant;
        # confirm which one is authoritative. Also note the manual path
        # restores self.model only, never self.rnnlm — TODO confirm intended.
        if self.paras.load:
            self.load_ckpt()
            ckpt = torch.load(self.paras.load, map_location=self.device)
            self.model.load_state_dict(ckpt['model'])
            self.optimizer.load_opt_state_dict(ckpt['optimizer'])
            self.step = ckpt['global_step']
            self.verbose('Load ckpt from {}, restarting at step {}'.format(
                self.paras.load, self.step))
예제 #3
0
    def set_model(self):
        '''Build the ASR model, losses, optional plug-ins and optimizer.'''
        # Acoustic model; half the corpus batch size is handed to the model
        batch_size = self.config['data']['corpus']['batch_size'] // 2
        self.model = ASR(self.feat_dim, self.vocab_size, batch_size,
                         **self.config['model']).to(self.device)
        self.verbose(self.model.create_msg())
        param_groups = [{'params': self.model.parameters()}]

        # Sequence loss: label smoothing when enabled, plain CE otherwise
        if self.config['hparas']['label_smoothing']:
            self.seq_loss = LabelSmoothingLoss(31, 0.1)
            print('[INFO]  using label smoothing. ')
        else:
            self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        # CTC objective (zero_infinity=False may be numerically unstable)
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

        # Optional embedding-regularizer plug-in
        self.emb_fuse = False
        self.emb_reg = ('emb' in self.config) and self.config['emb']['enable']
        if self.emb_reg:
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(
                self.tokenizer, self.model.dec_dim,
                **self.config['emb']).to(self.device)
            param_groups.append({'params': self.emb_decoder.parameters()})
            self.emb_fuse = self.emb_decoder.apply_fuse
            if self.emb_fuse:
                # Fused embedding output is log-probabilities, so use NLL
                self.seq_loss = torch.nn.NLLLoss(ignore_index=0)
            self.verbose(self.emb_decoder.create_msg())

        # Optimizer and its learning-rate schedule
        self.optimizer = Optimizer(param_groups, **self.config['hparas'])
        self.lr_scheduler = self.optimizer.lr_scheduler
        self.verbose(self.optimizer.create_msg())

        # Mixed-precision training, if configured
        self.enable_apex()

        # Report the transfer-learning setup, if active
        if self.transfer_learning:
            self.verbose('Apply transfer learning: ')
            self.verbose('      Train encoder layers: {}'.format(
                self.train_enc))
            self.verbose('      Train decoder:        {}'.format(
                self.train_dec))
            self.verbose('      Save name:            {}'.format(
                self.save_name))

        # Load a pre-trained model when self.paras.load is given
        self.load_ckpt()
예제 #4
0
    def set_model(self):
        '''Build the ASR model, losses, optimizer and beam-search decoder.'''
        # Fixed input/output dimensions for this setup
        self.feat_dim = 120
        self.vocab_size = 46
        init_adadelta = True
        # Acoustic model built from the *source* configuration
        self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta,
                         **self.src_config['model']).to(self.device)
        self.verbose(self.model.create_msg())

        # Optimize only the first N encoder layers when fine-tuning
        if self.finetune_first > 0:
            prefixes = ["encoder.layers.%d" % i
                        for i in range(self.finetune_first)]
            selected = [p for n, p in self.model.named_parameters()
                        if any(prefix in n for prefix in prefixes)]
            model_paras = [{"params": selected}]
        else:
            model_paras = [{'params': self.model.parameters()}]

        # Losses — index 0 is padding
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        # CTC objective (zero_infinity=False may be numerically unstable)
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

        # Optional embedding-regularizer plug-in
        self.emb_fuse = False
        self.emb_reg = ('emb' in self.config) and self.config['emb']['enable']
        if self.emb_reg:
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(
                self.tokenizer, self.model.dec_dim,
                **self.config['emb']).to(self.device)
            model_paras.append({'params': self.emb_decoder.parameters()})
            self.emb_fuse = self.emb_decoder.apply_fuse
            if self.emb_fuse:
                # Fused embedding output is log-probabilities, so use NLL
                self.seq_loss = torch.nn.NLLLoss(ignore_index=0)
            self.verbose(self.emb_decoder.create_msg())

        # Optimizer (hyper-parameters taken from the source configuration)
        self.optimizer = Optimizer(model_paras, **self.src_config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Mixed-precision training, if configured
        self.enable_apex()

        # Load a pre-trained model when self.paras.load is given
        self.load_ckpt()

        # Beam-search decoder wrapping the model and embedding decoder
        self.decoder = BeamDecoder(
            self.model, self.emb_decoder, **self.config['decode'])
        self.verbose(self.decoder.create_msg())
        self.decoder.to(self.device)
예제 #5
0
    def set_model(self):
        '''Build the frozen SSL feature extractor, classifier and optimizer.'''
        # Restore the self-supervised feature extractor from its checkpoint
        ckpt_path = self.config['model']['feat']['ckpt']
        self.verbose([' Load feat. extractor ckpt from ' + ckpt_path])
        if self.feature in ('apc', 'vqapc'):
            from model.apc import APC as Net
        elif self.feature == 'npc':
            from model.npc import NPC as Net
            if self.feat_spec is not None:
                self.verbose([' Using specific feature: ' + self.feat_spec])
        else:
            raise NotImplementedError
        self.feat_extractor = Net(input_size=self.audio_dim,
                                  **self.ssl_config['model']['paras'])
        map_loc = self.device if self.mode == 'train' else 'cpu'
        ckpt = torch.load(ckpt_path, map_location=map_loc)
        # Drop one leading 'module.' per key (DataParallel checkpoint prefix)
        ckpt['model'] = {key.replace('module.', '', 1): val
                         for key, val in ckpt['model'].items()}
        self.feat_extractor.load_state_dict(ckpt['model'])

        # Classifier head on top of the extracted codes
        self.model = CLF(feat_dim=self.feat_extractor.code_dim,
                         **self.config['model']['clf'])
        if self.gpu:
            # The extractor stays frozen in eval mode
            self.feat_extractor = self.feat_extractor.cuda()
            self.feat_extractor.eval()
            self.model = self.model.cuda()
        param_groups = [{'params': self.model.parameters()}]

        # Cross-entropy; phone classification pads with index 0
        ignore_idx = 0 if self.task == 'phn-clf' else -1
        self.loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_idx)
        if self.gpu:
            self.loss = self.loss.cuda()

        # Optimizer over the classifier parameters only
        self.optimizer = Optimizer(param_groups, **self.config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Resume classifier weights when a load path is given
        self.load_ckpt()
        self.model.train()
예제 #6
0
    def set_model(self):
        '''Build the VQ-VAE audio model, its objectives and the optimizer.'''
        # Model
        self.model = VQVAE(self.n_mels, self.linear_dim, self.vocab_size,
                           self.n_spkr, **self.config['model']).to(self.device)
        self.n_frames_per_step = self.model.n_frames_per_step
        self.verbose(self.model.create_msg())

        # Spectral reconstruction loss, pre-bound to converter settings
        hparas = self.config['hparas']
        self.freq_loss = partial(
            freq_loss,
            sample_rate=self.audio_converter.sr,
            n_mels=self.audio_converter.n_mels,
            loss=hparas['freq_loss_type'],
            differential_loss=hparas['differential_loss'],
            emphasize_linear_low=hparas['emphasize_linear_low'])
        self.ctc_loss = torch.nn.CTCLoss()
        self.stop_loss = torch.nn.BCEWithLogitsLoss()

        # Optimizer
        self.optimizer = Optimizer(self.model.parameters(), **hparas)
        self.verbose(self.optimizer.create_msg())
        # Report per-task loss weights and start steps (ToDo: unsup first?)
        self.verbose('           | ASR weight = {}\t| start step = {}'.format(
            self.asr_weight, 0))
        self.verbose('           | TTS weight = {}\t| start step = {}'.format(
            self.tts_weight, 0))
        self.verbose('           | Txt weight = {}\t| start step = {}'.format(
            self.unpair_text_weight, self.unpair_text_start_step))
        self.verbose('           | Sph weight = {}\t| start step = {}'.format(
            self.unpair_speech_weight, self.unpair_speech_start_step))
        # Restore model/optimizer/step from checkpoint when requested
        if self.paras.load:
            ckpt = torch.load(self.paras.load, map_location=self.device)
            self.model.load_state_dict(ckpt['model'])
            self.optimizer.load_opt_state_dict(ckpt['optimizer'])
            self.step = ckpt['global_step']
            self.verbose('Load ckpt from {}, restarting at step {}'.format(
                self.paras.load, self.step))
예제 #7
0
    def set_model(self):
        ''' Setup ASR model and optimizer '''
        # Model
        init_adadelta = self.config['hparas']['optimizer'] == 'Adadelta'
        self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta,
                         **self.config['model']).to(self.device)
        self.verbose(self.model.create_msg())
        model_paras = [{'params': self.model.parameters()}]

        # Losses — index 0 is treated as padding and ignored
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        # Note: zero_infinity=False is unstable?
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

        # Plug-ins
        self.emb_fuse = False
        self.emb_reg = ('emb'
                        in self.config) and (self.config['emb']['enable'])
        if self.emb_reg:
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(
                self.tokenizer, self.model.dec_dim,
                **self.config['emb']).to(self.device)
            model_paras.append({'params': self.emb_decoder.parameters()})
            self.emb_fuse = self.emb_decoder.apply_fuse
            if self.emb_fuse:
                # Fused embedding output is log-probabilities, so use NLL
                self.seq_loss = torch.nn.NLLLoss(ignore_index=0)
            self.verbose(self.emb_decoder.create_msg())

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Enable AMP if needed
        self.enable_apex()

        # NOTE(review): this unconditionally overwrites any user-supplied
        # --load argument with a fixed checkpoint path — presumably a
        # hardcoded demo/debug shortcut; confirm before production use.
        self.paras.load = 'ckpt/asr_example_sd0/best_att.pth'

        # Automatically load pre-trained model if self.paras.load is given
        self.load_ckpt()
    def set_model(self):
        '''Build the ASR model, its losses and the optimizer.'''
        # Acoustic model on the target device
        self.model = ASR(self.feat_dim, self.vocab_size,
                         **self.config['model']).to(self.device)
        self.verbose(self.model.create_msg())
        param_groups = [{'params': self.model.parameters()}]

        # Sequence loss ignores the padding index 0
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        # CTC objective (zero_infinity=False may be numerically unstable)
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

        # Optimizer
        self.optimizer = Optimizer(param_groups, **self.config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Mixed-precision training, if configured
        self.enable_apex()

        # Load a pre-trained model when self.paras.load is given
        self.load_ckpt()
예제 #9
0
    def set_model(self):
        '''Build the self-supervised model, its loss and the optimizer.'''
        # Select the network class by the configured method name
        self.method = self.config['model']['method']
        if self.method == 'npc':
            from model.npc import NPC as Net
        elif self.method in ('apc', 'vqapc'):
            self.n_future = self.config['model']['n_future']
            from model.apc import APC as Net
        else:
            raise NotImplementedError
        self.model = Net(input_size=self.audio_dim,
                         **self.config['model']['paras'])
        if self.gpu:
            self.model = self.model.cuda()
        self.verbose(self.model.create_msg())
        param_groups = [{'params': self.model.parameters()}]

        # L1 reconstruction loss. NPC keeps per-element values so that
        # zero-padded frames can be masked out manually later; the APC
        # family handles padding through the torch API and may reduce.
        if 'npc' in self.method:
            self.loss = torch.nn.L1Loss(reduction='none')
        else:
            self.loss = torch.nn.L1Loss()
        if self.gpu:
            self.loss = self.loss.cuda()

        # Optimizer
        self.optimizer = Optimizer(param_groups, **self.config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Load a pre-trained model when self.paras.load is given
        self.load_ckpt()

        # ToDo: Data Parallel?
        # self.model = torch.nn.DataParallel(self.model)
        self.model.train()
예제 #10
0
    def __init__(
        self,
        attack_configs,
        discriminator_configs,
        src_vocab,
        trg_vocab,
        data_iterator,
        save_to,
        device="cpu",
    ):
        """
        Initiate translation environments; needs a discriminator and translator.
        :param attack_configs: attack configures dictionary
        :param discriminator_configs: discriminator model/optimizer configures
        :param src_vocab: source-side vocabulary
        :param trg_vocab: target-side vocabulary
        :param data_iterator: use to provide data for environment initiate
        the directory of the src sentences
        :param save_to: directory for discriminator models and near-vocab caches
        :param device: (string) devices to allocate variables("cpu", "cuda:*")
        default as cpu
        """
        self.data_iterator = data_iterator
        discriminator_model_configs = discriminator_configs[
            "discriminator_model_configs"]
        discriminator_optim_configs = discriminator_configs[
            "discriminator_optimizer_configs"]
        self.victim_config_path = attack_configs["victim_configs"]
        self.victim_model_path = attack_configs["victim_model"]
        # determine devices
        self.device = device
        with open(self.victim_config_path.strip()) as v_f:
            print("open victim configs...%s" % self.victim_config_path)
            # Explicit Loader: bare yaml.load() is deprecated since PyYAML 5.1
            # and its legacy default Loader can execute arbitrary constructors.
            victim_configs = yaml.load(v_f, Loader=yaml.FullLoader)

        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab
        # Frozen victim translation model (evaluation only)
        self.translate_model = build_translate_model(victim_configs,
                                                     self.victim_model_path,
                                                     vocab_src=self.src_vocab,
                                                     vocab_trg=self.trg_vocab,
                                                     device=self.device)
        self.translate_model.eval()
        # Near-vocabulary tables used to generate perturbation candidates
        self.w2p, self.w2vocab = load_or_extract_near_vocab(
            config_path=self.victim_config_path,
            model_path=self.victim_model_path,
            init_perturb_rate=attack_configs["init_perturb_rate"],
            save_to=os.path.join(save_to, "near_vocab"),
            save_to_full=os.path.join(save_to, "full_near_vocab"),
            top_reserve=12,
            emit_as_id=True)
        #########################################################
        # to update discriminator
        # discriminator_data_configs = attack_configs["discriminator_data_configs"]
        self.discriminator = TransDiscriminator(
            n_src_words=self.src_vocab.max_n_words,
            n_trg_words=self.trg_vocab.max_n_words,
            **discriminator_model_configs)
        self.discriminator.to(self.device)

        # Warm-start the discriminator embeddings from the victim model
        load_embedding(self.discriminator,
                       model_path=self.victim_model_path,
                       device=self.device)

        self.optim_D = Optimizer(
            name=discriminator_optim_configs["optimizer"],
            model=self.discriminator,
            lr=discriminator_optim_configs["learning_rate"],
            grad_clip=discriminator_optim_configs["grad_clip"],
            optim_args=discriminator_optim_configs["optimizer_params"])
        self.criterion_D = nn.CrossEntropyLoss(
        )  # used in discriminator updates
        # Optional LR schedule for the discriminator optimizer
        self.scheduler_D = None  # default as None
        if discriminator_optim_configs['schedule_method'] is not None:
            if discriminator_optim_configs['schedule_method'] == "loss":
                self.scheduler_D = ReduceOnPlateauScheduler(
                    optimizer=self.optim_D,
                    **discriminator_optim_configs["scheduler_configs"])
            elif discriminator_optim_configs['schedule_method'] == "noam":
                self.scheduler_D = NoamScheduler(
                    optimizer=self.optim_D,
                    **discriminator_optim_configs['scheduler_configs'])
            elif discriminator_optim_configs["schedule_method"] == "rsqrt":
                self.scheduler_D = RsqrtScheduler(
                    optimizer=self.optim_D,
                    **discriminator_optim_configs["scheduler_configs"])
            else:
                WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".
                     format(discriminator_optim_configs['schedule_method']))
        ############################################################
        self._init_state()
        self.adversarial = attack_configs[
            "adversarial"]  # adversarial sample or reinforced samples
        self.r_s_weight = attack_configs["r_s_weight"]
        self.r_d_weight = attack_configs["r_d_weight"]
예제 #11
0
class Solver(BaseSolver):
    ''' Solver for training'''
    def __init__(self, config, paras, mode):
        super().__init__(config, paras, mode)
        # Logger settings
        # Best scores start at 3.0 (300% error) so the first validation
        # always improves on them and writes a checkpoint.
        self.best_wer = {'ctc': 3.0}
        self.best_per = {'ctc': 3.0}
        # Curriculum learning affects data loader
        self.curriculum = self.config['hparas']['curriculum']

    def load_data(self):
        ''' Load data for training/validation, store tokenizer and input/output shape'''
        self.tr_set, self.dv_set, self.feat_dim, self.vocab_size, self.tokenizer, msg= \
            load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                         self.curriculum > 0, **self.config['data'])
        self.verbose(msg)

    def transfer_weight(self):
        ''' Transfer model weights from a source checkpoint via token mapping. '''
        # pop() so the remaining transfer dict holds only mapping options
        ckpt_path = self.config['data']['transfer'].pop('src_ckpt')
        ckpt = torch.load(ckpt_path, map_location=self.device)

        # Optimizer-state transfer is currently disabled:
        #optim_ckpt = ckpt['optimizer']
        #for ctc_final_related_param in optim_ckpt['param_groups'][0]['params'][-2:]:
        #    optim_ckpt['state'].pop(ctc_final_related_param)

        #self.optimizer.load_opt_state_dict(optim_ckpt)

        # Load weights
        msg = self.model.transfer_with_mapping(ckpt,
                                               self.config['data']['transfer'],
                                               self.tokenizer)
        del ckpt

        self.verbose(msg)

    def set_model(self):
        ''' Setup ASR model and optimizer '''
        # Model
        init_adadelta = self.config['hparas']['optimizer'] == 'Adadelta'
        self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta,
                         **self.config['model']).to(self.device)
        self.verbose(self.model.create_msg())
        model_paras = [{'params': self.model.parameters()}]

        # Losses
        # Note: zero_infinity=False is unstable?
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

        # WER is only meaningful for character targets; IPA targets are
        # evaluated by phone error rate.
        self.eval_target = 'phone' if self.config['data']['corpus'][
            'target'] == 'ipa' else 'char'

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Enable AMP if needed
        self.enable_apex()

        if self.paras.transfer:
            self.transfer_weight()

        # Automatically load pre-trained model if self.paras.load is given
        if self.paras.load:
            self.load_ckpt()
        # ToDo: other training methods

    def exec(self):
        ''' Training End-to-end ASR system '''
        self.verbose('Total training steps {}.'.format(
            human_format(self.max_step)))
        ctc_loss = None
        n_epochs = 0
        self.timer.set()

        while self.step < self.max_step:
            # Renew dataloader to enable random sampling
            if self.curriculum > 0 and n_epochs == self.curriculum:
                self.verbose(
                    'Curriculum learning ends after {} epochs, starting random sampling.'
                    .format(n_epochs))
                self.tr_set, _, _, _, _, _ = \
                    load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                                 False, **self.config['data'])
            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad
                # zero grad here
                tf_rate = self.optimizer.pre_step(self.step)
                total_loss = 0

                # Fetch data
                feat, feat_len, txt, txt_len = self.fetch_data(data)
                self.timer.cnt('rd')

                # Forward model
                # Note: txt should NOT start w/ <sos>
                ctc_output, encode_len = self.model(feat, feat_len)

                # Compute all objectives
                # cuDNN CTC needs flattened int32 targets on CPU and
                # full-length (un-masked) input lengths.
                if self.paras.cudnn_ctc:
                    ctc_loss = self.ctc_loss(
                        ctc_output.transpose(0, 1),
                        txt.to_sparse().values().to(device='cpu',
                                                    dtype=torch.int32),
                        [ctc_output.shape[1]] * len(ctc_output),
                        txt_len.cpu().tolist())
                else:
                    ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1), txt,
                                             encode_len, txt_len)

                total_loss = ctc_loss

                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(total_loss)

                self.step += 1
                # Logger

                if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                    self.progress(
                        'Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'.
                        format(total_loss.cpu().item(), grad_norm,
                               self.timer.show()))
                    #self.write_log('wer', {'tr_ctc': cal_er(self.tokenizer, ctc_output, txt, ctc=True)})
                    # Greedy-decode each utterance up to its encoded length
                    ctc_output = [
                        x[:length].argmax(dim=-1)
                        for x, length in zip(ctc_output, encode_len)
                    ]
                    self.write_log(
                        'per', {
                            'tr_ctc':
                            cal_er(self.tokenizer,
                                   ctc_output,
                                   txt,
                                   mode='per',
                                   ctc=True)
                        })
                    self.write_log(
                        'wer', {
                            'tr_ctc':
                            cal_er(self.tokenizer,
                                   ctc_output,
                                   txt,
                                   mode='wer',
                                   ctc=True)
                        })
                    self.write_log('loss', {'tr_ctc': ctc_loss.cpu().item()})

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    self.validate()

                # End of step
                # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
                torch.cuda.empty_cache()
                self.timer.set()
                if self.step > self.max_step:
                    break
            n_epochs += 1
        #self.log.close()
    def validate(self):
        ''' Run validation, log metrics and checkpoint on improvement. '''
        # Eval mode
        self.model.eval()
        dev_per = {'ctc': []}
        dev_wer = {'ctc': []}

        for i, data in enumerate(self.dv_set):
            self.progress('Valid step - {}/{}'.format(i + 1, len(self.dv_set)))
            # Fetch data
            feat, feat_len, txt, txt_len = self.fetch_data(data)

            # Forward model
            with torch.no_grad():
                ctc_output, encode_len = self.model(feat, feat_len)

            # Greedy-decode each utterance up to its encoded length
            ctc_output = [
                x[:length].argmax(dim=-1)
                for x, length in zip(ctc_output, encode_len)
            ]
            dev_per['ctc'].append(
                cal_er(self.tokenizer, ctc_output, txt, mode='per', ctc=True))
            dev_wer['ctc'].append(
                cal_er(self.tokenizer, ctc_output, txt, mode='wer', ctc=True))

            # Show some example on tensorboard
            # NOTE(review): the inner loop reuses `i`, shadowing the outer
            # enumerate index for the rest of this iteration — harmless here
            # since `i` is not read again before the next outer iteration,
            # but fragile; consider renaming.
            if i == len(self.dv_set) // 2:
                for i in range(min(len(txt), self.DEV_N_EXAMPLE)):
                    #if self.step == 1:
                    self.write_log('true_text{}'.format(i),
                                   self.tokenizer.decode(txt[i].tolist()))
                    self.write_log(
                        'ctc_text{}'.format(i),
                        self.tokenizer.decode(ctc_output[i].tolist(),
                                              ignore_repeat=True))

        # Ckpt if performance improves
        for task in ['ctc']:
            dev_wer[task] = sum(dev_wer[task]) / len(dev_wer[task])
            dev_per[task] = sum(dev_per[task]) / len(dev_per[task])
            if dev_per[task] < self.best_per[task]:
                self.best_per[task] = dev_per[task]
                self.save_checkpoint('best_{}.pth'.format('per'), 'per',
                                     dev_per[task])
                self.log.log_other('dv_best_per', self.best_per['ctc'])
            if self.eval_target == 'char' and dev_wer[task] < self.best_wer[
                    task]:
                self.best_wer[task] = dev_wer[task]
                self.save_checkpoint('best_{}.pth'.format('wer'), 'wer',
                                     dev_wer[task])
                self.log.log_other('dv_best_wer', self.best_wer['ctc'])

            self.write_log('per', {'dv_' + task: dev_per[task]})
            if self.eval_target == 'char':
                self.write_log('wer', {'dv_' + task: dev_wer[task]})
        self.save_checkpoint('latest.pth',
                             'per',
                             dev_per['ctc'],
                             show_msg=False)
        if self.paras.save_every:
            # NOTE(review): extension '.path' looks like a typo for '.pth'
            # (every other checkpoint here uses '.pth') — confirm.
            self.save_checkpoint(f'{self.step}.path',
                                 'per',
                                 dev_per['ctc'],
                                 show_msg=False)

        # Resume training
        self.model.train()
예제 #12
0
class Solver(BaseSolver):
    ''' Solver for training'''
    def __init__(self, config, paras):
        super().__init__(config, paras)
        # Logger settings
        # Start high so the first validation always checkpoints.
        self.val_loss = 1000
        self.cur_epoch = 0

    def fetch_data(self, data):
        ''' Move data to device '''
        # Only the feature tensor moves to GPU; lengths stay on CPU.
        file_id, audio_feat, audio_len = data
        if self.gpu:
            audio_feat = audio_feat.cuda()
        return file_id, audio_feat, audio_len

    def load_data(self):
        ''' Load data for training/validation '''
        self.tr_set, self.dv_set, _, self.audio_dim, msg = \
            prepare_data(self.paras.njobs, self.paras.dev_njobs, self.paras.gpu,
                         self.paras.pin_memory, **self.config['data'])
        self.verbose(msg)

    def set_model(self):
        ''' Setup model and optimizer '''
        # Model
        self.method = self.config['model']['method']
        if self.method in ['apc','vqapc']:
            # APC predicts n_future frames ahead
            self.n_future = self.config['model']['n_future']
            from model.apc import APC as Net
        elif self.method == 'npc':
            from model.npc import NPC as Net
        else:
            raise NotImplementedError
        self.model = Net(input_size=self.audio_dim, **self.config['model']['paras'])
        if self.gpu:
            self.model = self.model.cuda()
        self.verbose(self.model.create_msg())
        model_paras = [{'params': self.model.parameters()}]

        # Loss
        if 'npc' in self.method:
            # Avoid reduction for NPC for zero-padding
            self.loss = torch.nn.L1Loss(reduction='none')
        else:
            # APC family have zero-padding with torch API
            self.loss = torch.nn.L1Loss()
        if self.gpu:
            self.loss = self.loss.cuda()

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Automatically load pre-trained model if self.paras.load is given
        self.load_ckpt()

        # ToDo:  Data Parallel?
        # self.model = torch.nn.DataParallel(self.model)
        self.model.train()

    def exec(self):
        ''' Training End-to-end ASR system '''
        self.verbose('Total training epoch {}.'.format(
            human_format(self.epoch)))
        self.timer.set()
        # aug_loss is currently unused — kept as a placeholder
        aug_loss = None
        ep_len = len(self.tr_set)

        for ep in range(self.epoch):
            # Pre-step, decay
            # Learning-rate decay once per epoch (skipping the first)
            if ep>0:
                self.optimizer.decay()

            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad
                self.optimizer.pre_step(self.step)

                # Fetch data
                _, audio_feat, audio_len = self.fetch_data(data)
                self.timer.cnt('rd')

                # Forward real data
                if 'npc' in self.method:
                    # NPC: input = target
                    pred, _ = self.model(audio_feat)
                    loss = self.loss(pred, audio_feat)
                    # Compute loss on valid part only
                    # (loss is unreduced here; mask out zero-padded frames)
                    effective_loss = 0
                    for i,a_len in enumerate(audio_len):
                        effective_loss += loss[i,:a_len,:].mean(dim=-1).sum()
                    loss = effective_loss/sum(audio_len)
                else:
                    # APC: input = shifted target
                    audio_len = [l-self.n_future for l in audio_len]
                    pred, _ = self.model(audio_feat[:,:-self.n_future,:], audio_len, testing=False)
                    loss = self.loss(pred, audio_feat[:,self.n_future:,:])
                self.timer.cnt('fw')
                # Backprop
                grad_norm = self.backward(loss)
                self.step += 1

                # Logger
                if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                    self.progress(' {:2.1f} % | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                                  .format(100*float(self.step%ep_len)/ep_len,
                                          loss.cpu().item(),
                                          grad_norm,
                                          self.timer.show()))
                    self.write_log('loss', {'tr': loss})

                if (self.step == 1) or (self.step % self.PLOT_STEP == 0):
                    # Perplexity of P(token)
                    # NOTE(review): report_ppx/report_usg look like VQ
                    # codebook statistics — presumably only available on the
                    # VQ model variants; confirm for plain APC/NPC.
                    g1_ppx, g2_ppx = self.model.report_ppx()
                    self.write_log('ppx', {'group 1':g1_ppx,
                                           'group 2':g2_ppx})
                    g1_usg, g2_usg = self.model.report_usg() # Empty cache
                    # Plots
                    if self.paras.draw:
                        g1_hist = draw(g1_usg, hist=True)
                        g2_hist = draw(g2_usg, hist=True)
                        self.write_log('VQ Group 1 Hist.',g1_hist)
                        self.write_log('VQ Group 2 Hist.',g2_hist)
                        # Some spectrograms
                        plt_idx = 0
                        self.write_log('Spectrogram (raw)', draw(audio_feat[plt_idx]))
                        self.write_log('Spectrogram (pred)', draw(pred[plt_idx]))

                # End of step
                self.timer.set()
            # End of epoch
            self.cur_epoch += 1
            self.validate()
        self.log.close()

    def validate(self):
        ''' Compute dev loss, checkpoint on improvement, resume training. '''
        # Eval mode
        self.model.eval()
        dev_loss = []
        for i, data in enumerate(self.dv_set):
            self.progress('Valid step - {}/{}'.format(i+1, len(self.dv_set)))
            # Fetch data
            _, audio_feat, audio_len = self.fetch_data(data)

            # Forward model
            with torch.no_grad():
                if 'npc' in self.method:
                    pred, _ = self.model(audio_feat, testing=True)
                    loss = self.loss(pred, audio_feat)
                    # Compute loss on valid part only
                    # NOTE(review): the inner loop reuses `i`, shadowing the
                    # outer enumerate index — harmless (outer `i` is not read
                    # afterwards) but fragile; consider renaming.
                    effective_loss = 0
                    for i,a_len in enumerate(audio_len):
                        effective_loss += loss[i,:a_len,:].mean(dim=-1).sum()
                    loss = effective_loss/sum(audio_len)
                else:
                    audio_len = [l-self.n_future for l in audio_len]
                    pred, _ = self.model(audio_feat[:,:-self.n_future,:], audio_len, testing=True)
                    loss = self.loss(pred, audio_feat[:,self.n_future:,:])
                dev_loss.append(loss.cpu().item())

        # Record metric
        dev_loss = sum(dev_loss)/len(dev_loss)
        self.write_log('loss', {'dev':dev_loss})
        if dev_loss < self.val_loss:
            self.val_loss = dev_loss
            self.save_checkpoint('best_loss.pth', 'loss', dev_loss)
        # Resume training
        self.model.train()
예제 #13
0
def train(FLAGS):
    """
    Main NMT training entry point: builds the data pipeline, model,
    optimizer and LR scheduler from a YAML config, then runs the
    update loop with periodic loss/BLEU validation, checkpointing,
    EMA weight averaging and early stopping.

    FLAGS:
        saveto: str
        reload: store_true
        config_path: str
        pretrain_path: str, default=""
        model_name: str
        log_path: str
    """

    # write log of training to file.
    write_log_to_file(os.path.join(FLAGS.log_path, "%s.log" % time.strftime("%Y%m%d-%H%M%S")))

    GlobalNames.USE_GPU = FLAGS.use_gpu

    # Bug fix: the original branches were swapped, selecting "cpu" when
    # use_gpu was requested (and "cuda:0" when it was not).
    if GlobalNames.USE_GPU:
        CURRENT_DEVICE = "cuda:0"
    else:
        CURRENT_DEVICE = "cpu"

    config_path = os.path.abspath(FLAGS.config_path)
    with open(config_path.strip()) as f:
        # Pass an explicit Loader: bare yaml.load() is deprecated and unsafe
        # on untrusted input (PyYAML >= 5.1 warns/raises without it).
        configs = yaml.load(f, Loader=yaml.FullLoader)

    INFO(pretty_configs(configs))

    # Add default configs
    configs = default_configs(configs)
    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    GlobalNames.SEED = training_configs['seed']

    set_seed(GlobalNames.SEED)

    best_model_prefix = os.path.join(FLAGS.saveto, FLAGS.model_name + GlobalNames.MY_BEST_MODEL_SUFFIX)

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Generate target dictionary
    vocab_src = Vocabulary(**data_configs["vocabularies"][0])
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][1])

    # update_cycle shards each logical batch into smaller pieces, so scale
    # the loader batch/buffer sizes accordingly.
    train_batch_size = training_configs["batch_size"] * max(1, training_configs["update_cycle"])
    train_buffer_size = training_configs["buffer_size"] * max(1, training_configs["update_cycle"])

    train_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['train_data'][0],
                        vocabulary=vocab_src,
                        max_len=data_configs['max_len'][0],
                        ),
        TextLineDataset(data_path=data_configs['train_data'][1],
                        vocabulary=vocab_tgt,
                        max_len=data_configs['max_len'][1],
                        ),
        shuffle=training_configs['shuffle']
    )

    valid_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['valid_data'][0],
                        vocabulary=vocab_src,
                        ),
        TextLineDataset(data_path=data_configs['valid_data'][1],
                        vocabulary=vocab_tgt,
                        )
    )

    training_iterator = DataIterator(dataset=train_bitext_dataset,
                                     batch_size=train_batch_size,
                                     use_bucket=training_configs['use_bucket'],
                                     buffer_size=train_buffer_size,
                                     batching_func=training_configs['batching_key'])

    valid_iterator = DataIterator(dataset=valid_bitext_dataset,
                                  batch_size=training_configs['valid_batch_size'],
                                  use_bucket=True, buffer_size=100000, numbering=True)

    bleu_scorer = SacreBLEUScorer(reference_path=data_configs["bleu_valid_reference"],
                                  num_refs=data_configs["num_refs"],
                                  lang_pair=data_configs["lang_pair"],
                                  sacrebleu_args=training_configs["bleu_valid_configs"]['sacrebleu_args'],
                                  postprocess=training_configs["bleu_valid_configs"]['postprocess']
                                  )

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    lrate = optimizer_configs['learning_rate']
    is_early_stop = False

    # ================================ Begin ======================================== #
    # Build Model & Optimizer
    # We would do steps below on after another
    #     1. build models & criterion
    #     2. move models & criterion to gpu if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    #     5. build learning rate scheduler if needed
    #     6. load checkpoints if needed

    # 0. Initial
    model_collections = Collections()
    checkpoint_saver = Saver(save_prefix="{0}.ckpt".format(os.path.join(FLAGS.saveto, FLAGS.model_name)),
                             num_max_keeping=training_configs['num_kept_checkpoints']
                             )
    best_model_saver = Saver(save_prefix=best_model_prefix, num_max_keeping=training_configs['num_kept_best_model'])

    # 1. Build Model & Criterion
    INFO('Building model...')
    timer.tic()
    nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                            n_tgt_vocab=vocab_tgt.max_n_words, **model_configs)
    INFO(nmt_model)

    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'])

    INFO(critic)
    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 2. Move to GPU
    if GlobalNames.USE_GPU:
        nmt_model = nmt_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    load_pretrained_model(nmt_model, FLAGS.pretrain_path, exclude_prefix=None, device=CURRENT_DEVICE)

    # 4. Build optimizer
    INFO('Building Optimizer...')
    optim = Optimizer(name=optimizer_configs['optimizer'],
                      model=nmt_model,
                      lr=lrate,
                      grad_clip=optimizer_configs['grad_clip'],
                      optim_args=optimizer_configs['optimizer_params']
                      )
    # 5. Build scheduler for optimizer if needed
    if optimizer_configs['schedule_method'] is not None:

        if optimizer_configs['schedule_method'] == "loss":

            scheduler = ReduceOnPlateauScheduler(optimizer=optim,
                                                 **optimizer_configs["scheduler_configs"]
                                                 )

        elif optimizer_configs['schedule_method'] == "noam":
            scheduler = NoamScheduler(optimizer=optim, **optimizer_configs['scheduler_configs'])
        else:
            WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".format(optimizer_configs['schedule_method']))
            scheduler = None
    else:
        scheduler = None

    # 6. build EMA
    if training_configs['ema_decay'] > 0.0:
        ema = ExponentialMovingAverage(named_params=nmt_model.named_parameters(), decay=training_configs['ema_decay'])
    else:
        ema = None

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # Reload from latest checkpoint
    if FLAGS.reload:
        checkpoint_saver.load_latest(model=nmt_model, optim=optim, lr_scheduler=scheduler,
                                     collections=model_collections)

    # ================================================================================== #
    # Prepare training

    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [0])[-1]
    bad_count = model_collections.get_collection("bad_count", [0])[-1]

    summary_writer = SummaryWriter(log_dir=FLAGS.log_path)

    cum_samples = 0
    cum_words = 0
    best_valid_loss = 1.0 * 1e10  # Max Float
    # Bug fix: the INFO line printed after BLEU validation reads valid_loss,
    # which was unbound if BLEU validation fired before the first loss
    # validation. Initialize it defensively.
    valid_loss = best_valid_loss
    saving_files = []

    # Timer for computing speed
    timer_for_speed = Timer()
    timer_for_speed.tic()

    INFO('Begin training...')

    while True:

        summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()
        training_progress_bar = tqdm(desc='  - (Epoch %d)   ' % eidx,
                                     total=len(training_iterator),
                                     unit="sents"
                                     )
        for batch in training_iter:

            uidx += 1

            # Step the LR scheduler before the update (loss-based schedulers
            # react to the best validation loss; step-based ones to uidx).
            if scheduler is None:
                pass
            elif optimizer_configs["schedule_method"] == "loss":
                scheduler.step(metric=best_valid_loss)
            else:
                scheduler.step(global_step=uidx)

            seqs_x, seqs_y = batch

            n_samples_t = len(seqs_x)
            n_words_t = sum(len(s) for s in seqs_y)

            cum_samples += n_samples_t
            cum_words += n_words_t

            training_progress_bar.update(n_samples_t)

            optim.zero_grad()

            # Prepare data: accumulate gradients over update_cycle shards
            # before a single optimizer step.
            for seqs_x_t, seqs_y_t in split_shard(seqs_x, seqs_y, split_size=training_configs['update_cycle']):
                x, y = prepare_data(seqs_x_t, seqs_y_t, cuda=GlobalNames.USE_GPU)

                loss = compute_forward(model=nmt_model,
                                       critic=critic,
                                       seqs_x=x,
                                       seqs_y=y,
                                       eval=False,
                                       normalization=n_samples_t,
                                       norm_by_words=training_configs["norm_by_words"])
            optim.step()

            if ema is not None:
                ema.step()

            # ================================================================================== #
            # Display some information
            if should_trigger_by_steps(uidx, eidx, every_n_step=training_configs['disp_freq']):
                # words per second and sents per second
                words_per_sec = cum_words / (timer.toc(return_seconds=True))
                sents_per_sec = cum_samples / (timer.toc(return_seconds=True))
                lrate = list(optim.get_lrate())[0]

                summary_writer.add_scalar("Speed(words/sec)", scalar_value=words_per_sec, global_step=uidx)
                summary_writer.add_scalar("Speed(sents/sen)", scalar_value=sents_per_sec, global_step=uidx)
                summary_writer.add_scalar("lrate", scalar_value=lrate, global_step=uidx)

                # Reset timer
                timer.tic()
                cum_words = 0
                cum_samples = 0

            # ================================================================================== #
            # Saving checkpoints
            if should_trigger_by_steps(uidx, eidx, every_n_step=training_configs['save_freq'], debug=FLAGS.debug):
                model_collections.add_to_collection("uidx", uidx)
                model_collections.add_to_collection("eidx", eidx)
                model_collections.add_to_collection("bad_count", bad_count)

                if not is_early_stop:

                    checkpoint_saver.save(global_step=uidx, model=nmt_model, optim=optim, lr_scheduler=scheduler,
                                          collections=model_collections, ema=ema)

            # ================================================================================== #
            # Loss Validation & Learning rate annealing
            if should_trigger_by_steps(global_step=uidx, n_epoch=eidx, every_n_step=training_configs['loss_valid_freq'],
                                       debug=FLAGS.debug):

                # Temporarily swap in the EMA-averaged weights for evaluation.
                if ema is not None:
                    origin_state_dict = deepcopy(nmt_model.state_dict())
                    nmt_model.load_state_dict(ema.state_dict(), strict=False)

                valid_loss = loss_validation(model=nmt_model,
                                             critic=critic,
                                             valid_iterator=valid_iterator,
                                             )

                model_collections.add_to_collection("history_losses", valid_loss)

                min_history_loss = np.array(model_collections.get_collection("history_losses")).min()

                summary_writer.add_scalar("loss", valid_loss, global_step=uidx)
                summary_writer.add_scalar("best_loss", min_history_loss, global_step=uidx)

                best_valid_loss = min_history_loss

                # Restore the raw (non-EMA) weights for further training.
                if ema is not None:
                    nmt_model.load_state_dict(origin_state_dict)
                    del origin_state_dict

            # ================================================================================== #
            # BLEU Validation & Early Stop

            if should_trigger_by_steps(global_step=uidx, n_epoch=eidx,
                                       every_n_step=training_configs['bleu_valid_freq'],
                                       min_step=training_configs['bleu_valid_warmup'],
                                       debug=FLAGS.debug):

                if ema is not None:
                    origin_state_dict = deepcopy(nmt_model.state_dict())
                    nmt_model.load_state_dict(ema.state_dict(), strict=False)

                valid_bleu = bleu_validation(uidx=uidx,
                                             valid_iterator=valid_iterator,
                                             batch_size=training_configs["bleu_valid_batch_size"],
                                             model=nmt_model,
                                             bleu_scorer=bleu_scorer,
                                             vocab_tgt=vocab_tgt,
                                             valid_dir=FLAGS.valid_path,
                                             max_steps=training_configs["bleu_valid_configs"]["max_steps"],
                                             beam_size=training_configs["bleu_valid_configs"]["beam_size"],
                                             alpha=training_configs["bleu_valid_configs"]["alpha"]
                                             )

                model_collections.add_to_collection(key="history_bleus", value=valid_bleu)

                best_valid_bleu = float(np.array(model_collections.get_collection("history_bleus")).max())

                summary_writer.add_scalar("bleu", valid_bleu, uidx)
                summary_writer.add_scalar("best_bleu", best_valid_bleu, uidx)

                # If model get new best valid bleu score
                if valid_bleu >= best_valid_bleu:
                    bad_count = 0

                    if is_early_stop is False:
                        # 1. save the best model
                        torch.save(nmt_model.state_dict(), best_model_prefix + ".final")

                        # 2. record all several best models
                        best_model_saver.save(global_step=uidx, model=nmt_model)
                else:
                    bad_count += 1

                    # At least one epoch should be traversed
                    if bad_count >= training_configs['early_stop_patience'] and eidx > 0:
                        is_early_stop = True
                        WARN("Early Stop!")

                summary_writer.add_scalar("bad_count", bad_count, uidx)

                if ema is not None:
                    nmt_model.load_state_dict(origin_state_dict)
                    del origin_state_dict

                INFO("{0} Loss: {1:.2f} BLEU: {2:.2f} lrate: {3:6f} patience: {4}".format(
                    uidx, valid_loss, valid_bleu, lrate, bad_count
                ))

        training_progress_bar.close()

        eidx += 1
        if eidx > training_configs["max_epochs"]:
            break
예제 #14
0
class VqvaeTrainer(BaseSolver):
    def __init__(self, config, paras, mode):
        '''Initialize the VQ-VAE trainer: training state plus loss-weighting
        and scheduling hyper-parameters read from the config.'''
        super().__init__(config, paras, mode)
        # Training state
        self.step = 0
        self.best_tts_loss = 100.0
        self.best_per = 2.0
        # Loss weights and the steps at which unpaired objectives kick in
        hp = self.config['hparas']
        raw_hp = config['hparas']
        self.asr_weight = hp['asr_weight']
        self.tts_weight = hp['tts_weight']
        self.unpair_text_start_step = raw_hp['unpair_text_start_step']
        self.unpair_text_weight = hp['unpair_text_weight']
        self.unpair_speech_start_step = raw_hp['unpair_speech_start_step']
        self.unpair_speech_weight = hp['unpair_speech_weight']

    def fetch_data(self, iter_name):
        '''Draw the next batch from the named iterator (e.g. 'pair_iter'),
        restarting the iterator from its underlying '*_set' loader when it is
        exhausted, and pad mel/linear spectrograms to a multiple of
        n_frames_per_step before moving tensors to the training device.'''
        # Load from iterator
        mel = None
        while mel is None:
            try:
                mel, aug_mel, linear, sid, text = next(getattr(
                    self, iter_name))
            except StopIteration:
                # Iterator exhausted: rebuild it from the matching dataset
                # attribute ('xxx_iter' -> 'xxx_set') and retry the loop.
                setattr(self, iter_name,
                        iter(getattr(self, iter_name.replace('iter', 'set'))))

        # Pad to match n_frames_per_step (at least 1 frame padded)
        # When the length is already a multiple, pad_len equals
        # n_frames_per_step, so a full extra step of padding is appended.
        pad_len = self.n_frames_per_step - (mel.shape[1] %
                                            self.n_frames_per_step)
        mel = torch.cat(
            [mel, SPEC_PAD_VALUE * torch.ones_like(mel)[:, :pad_len, :]],
            dim=1)
        linear = torch.cat(
            [linear, SPEC_PAD_VALUE * torch.ones_like(linear)[:, :pad_len, :]],
            dim=1)

        # NOTE(review): aug_mel is returned without the padding applied to
        # mel/linear above — confirm this is intentional for the consumer.
        return mel.to(self.device),\
               aug_mel.to(self.device),\
               linear.to(self.device),\
               text.to(self.device),\
               sid.to(self.device)

        #return mel.to(self.device, non_blocking=True),\
        #       aug_mel.to(self.device, non_blocking=True),\
        #       linear.to(self.device, non_blocking=True),\
        #       text.to(self.device, non_blocking=True),\
        #       sid.to(self.device, non_blocking=True)

    def load_data(self):
        ''' Load data for training/validation, store tokenizer and input/output shape'''
        self.verbose(['Loading data... large corpus may took a while.'])
        self.unpair_set, self.pair_set, self.dev_set, self.test_set, self.audio_converter, self.tokenizer, data_msg = \
                load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory, **self.config['data'])
        # Wrap loaders as iterators so fetch_data() can draw batches on demand
        self.pair_iter = iter(self.pair_set)
        self.unpair_iter = iter(self.unpair_set)
        self.dev_iter = iter(self.dev_set)
        # Feature statics
        self.n_mels, self.linear_dim = self.audio_converter.feat_dim
        self.vocab_size = self.tokenizer.vocab_size
        # Speaker count from the speaker-id map. Use a context manager so the
        # file handle is closed (the previous json.load(open(...)) leaked it).
        with open(self.config['data']['corpus']['spkr_map']) as f:
            self.n_spkr = len(json.load(f))
        self.verbose(data_msg)

    def set_model(self):
        ''' Setup Audio AE-model and optimizer '''
        hparas = self.config['hparas']

        # Model
        self.model = VQVAE(self.n_mels, self.linear_dim, self.vocab_size,
                           self.n_spkr, **self.config['model']).to(self.device)
        self.n_frames_per_step = self.model.n_frames_per_step
        self.verbose(self.model.create_msg())

        # Objective: spectrogram reconstruction, CTC for ASR, stop-token BCE
        self.freq_loss = partial(
            freq_loss,
            sample_rate=self.audio_converter.sr,
            n_mels=self.audio_converter.n_mels,
            loss=hparas['freq_loss_type'],
            differential_loss=hparas['differential_loss'],
            emphasize_linear_low=hparas['emphasize_linear_low'])
        self.stop_loss = torch.nn.BCEWithLogitsLoss()
        self.ctc_loss = torch.nn.CTCLoss()

        # Optimizer
        self.optimizer = Optimizer(self.model.parameters(), **hparas)
        self.verbose(self.optimizer.create_msg())
        ### ToDo : unsup first?
        self.verbose('           | ASR weight = {}\t| start step = {}'.format(
            self.asr_weight, 0))
        self.verbose('           | TTS weight = {}\t| start step = {}'.format(
            self.tts_weight, 0))
        self.verbose('           | Txt weight = {}\t| start step = {}'.format(
            self.unpair_text_weight, self.unpair_text_start_step))
        self.verbose('           | Sph weight = {}\t| start step = {}'.format(
            self.unpair_speech_weight, self.unpair_speech_start_step))
        # ToDo: load pre-trained model
        if self.paras.load:
            checkpoint = torch.load(self.paras.load, map_location=self.device)
            self.model.load_state_dict(checkpoint['model'])
            self.optimizer.load_opt_state_dict(checkpoint['optimizer'])
            self.step = checkpoint['global_step']
            self.verbose('Load ckpt from {}, restarting at step {}'.format(
                self.paras.load, self.step))

    def exec(self):
        '''Main training loop. Each step draws a paired batch and (on a
        schedule) an unpaired batch, alternating between a speech->text->speech
        cycle on even steps and a text->speech->text cycle on odd steps,
        combining paired ASR/TTS losses with unpaired reconstruction losses.
        Recovers from CUDA OOM by dropping gradients and retrying.'''
        self.verbose(
            ['Total training steps {}.'.format(human_format(self.max_step))])
        self.timer.set()
        unpair_speech_loss, unpair_text_loss, unsup_pred, unsup_trans, unsup_align = None, None, None, None, None
        ctc_nan_flag, ignore_speech_flag = 0, 0
        tok_usage, gt_usage = [], []
        # Counters reset at every progress log: CTC NaN/inf occurrences and
        # how often each unpaired objective actually ran.
        cnter = {'ctc_nan': 0, 'unp_sph': 0, 'unp_txt': 0}

        while self.step < self.max_step:
            # --------------------- Load data ----------------------- #
            # Unpair setting
            unpair_mel, unpair_aug_mel, unpair_linear, unpair_text, unpair_sid = None, None, None, None, None
            post_pred, asr_post_loss = None, None  # For ASR postnet only
            # Unpaired objectives only activate after their scheduled start step
            use_unpair_text = self.unpair_text_weight > 0 and self.step > self.unpair_text_start_step
            use_unpair_speech = self.unpair_speech_weight > 0 and self.step > self.unpair_speech_start_step

            tf_rate = self.optimizer.pre_step(
                self.step)  # Catch the returned tf_rate if needed
            # ToDo : change # of sup. step = 2 x # of unsup. step ?
            mel, aug_mel, linear, text, sid = self.fetch_data(
                iter_name='pair_iter')

            # Load unpaired data only when use_unpair_xxx == True
            # Even steps run the speech-first cycle, odd steps the text-first one.
            if self.step % 2 == 0:  #2
                # if True:
                # ASR first
                speech_first = True
                if use_unpair_speech:
                    unpair_mel, unpair_aug_mel, unpair_linear, unpair_text, unpair_sid = \
                                                    self.fetch_data(iter_name='unpair_iter')
            else:
                # TTS first
                speech_first = False
                if use_unpair_text:
                    cnter['unp_txt'] += 1
                    unpair_mel, unpair_aug_mel, unpair_linear, unpair_text, unpair_sid = \
                                                    self.fetch_data(iter_name='unpair_iter')

            total_loss = 0
            bs = len(mel)
            self.timer.cnt('rd')
            try:
                # ----------------------- Forward ------------------------ #
                if speech_first:
                    # Cycle : speech -> text -> speech
                    pair_prob, _, unpair_prob, unpair_latent, unpair_latent_len, pair_post_prob, _ = \
                                self.model.speech_to_text(paired_mel=aug_mel, unpaired_mel= unpair_aug_mel)

                    # Check to involve unsupervised Speech2Speech
                    if unpair_latent is not None:
                        # ASR output is the representataion for speech2speech
                        cnter['unp_sph'] += 1
                        ignore_speech_cycle = False
                        unpaired_teacher = unpair_mel
                    else:
                        # ASR output is all blank (cannot be passed to TTS) only paired text is used
                        ignore_speech_cycle = True
                        unpaired_teacher = None

                    # text -> speech
                    pair_mel_pred, pair_linear_pred, pair_align, _, \
                    unpair_mel_pred, unpair_linear_pred, unpair_align, _ =\
                                self.model.text_to_speech(paired_text = text,
                                                          paired_sid=sid,
                                                          unpaired_sid=unpair_sid,
                                                          unpaired_latent = unpair_latent,
                                                          unpaired_text= None,
                                                          unpaired_latent_len = unpair_latent_len,
                                                          paired_teacher = mel,
                                                          unpaired_teacher = unpaired_teacher,
                                                          tf_rate = tf_rate
                                                         )
                else:
                    # Cycle : text -> speech -> text
                    pair_mel_pred, pair_linear_pred, pair_align, _, \
                    unpair_mel_pred, unpair_linear_pred, unpair_align, _ =\
                                self.model.text_to_speech(paired_text=text,
                                                          paired_sid=sid,
                                                          unpaired_sid=unpair_sid,
                                                          unpaired_latent=None,
                                                          unpaired_text=unpair_text,
                                                          unpaired_latent_len=None,
                                                          paired_teacher=mel,
                                                          unpaired_teacher=None,
                                                          tf_rate=tf_rate
                                                         )
                    if use_unpair_text:
                        unpair_mel_pred = unpair_mel_pred.detach(
                        )  # Stop-grad for tts in text2text
                    pair_prob, _, unpair_prob, unpair_latent, unpair_latent_len, pair_post_prob, _ = \
                                self.model.speech_to_text(paired_mel=aug_mel,
                                                          unpaired_mel=unpair_mel_pred,  #None, #unpair_mel_pred, #None, #unpaired_mel= unpair_mel_pred,
                                                          using_fake_mel=use_unpair_text)

                # Paired ASR loss
                # NOTE(review): asr_loss is added to total_loss *before* the
                # NaN/inf check below, so the zeroing only affects logging —
                # confirm this ordering is intended.
                asr_loss = self.compute_ctcloss(aug_mel, pair_prob, text)
                if self.model.use_asr_postnet:
                    total_loss = total_loss + self.asr_weight * (
                        1 - self.model.asr_postnet_weight) * asr_loss
                    asr_post_loss = self.compute_ctcloss(aug_mel,
                                                         pair_post_prob,
                                                         text,
                                                         apply_log=False)
                    total_loss = total_loss + self.asr_weight * self.model.asr_postnet_weight * asr_post_loss
                else:
                    total_loss = total_loss + self.asr_weight * asr_loss
                if math.isnan(asr_loss) or math.isinf(asr_loss):
                    cnter['ctc_nan'] += 1
                    asr_loss = 0

                # Paired TTS loss
                mel_loss = self.freq_loss(pair_mel_pred, mel)
                linear_loss = self.freq_loss(pair_linear_pred, linear)
                tts_loss = mel_loss + linear_loss
                total_loss = total_loss + self.tts_weight * tts_loss

                # Unpaired loss
                if speech_first:
                    # Unpaired speech reconstruction loss
                    if not ignore_speech_cycle:
                        unpair_speech_loss = self.freq_loss(unpair_mel_pred, unpair_mel) +\
                                            self.freq_loss(unpair_linear_pred, unpair_linear)
                        #total_loss += self.unpair_speech_weight*unpair_speech_loss
                        if self.step > self.unpair_speech_start_step:
                            total_loss += self.unpair_speech_weight * unpair_speech_loss
                elif use_unpair_text:
                    # Unpaired text reconstruction loss
                    ctc_input = (unpair_prob + EPS).transpose(0, 1).log()
                    if self.paras.actual_len:
                        # Estimate frame lengths from the (non-padding) text
                        # length; presumably FRAME_PHN_RATIO frames per token
                        # — TODO confirm against the data pipeline.
                        asr_input_len = (unpair_text != 0).sum(
                            dim=-1) * FRAME_PHN_RATIO
                        asr_input_len = asr_input_len + asr_input_len % self.model.n_frames_per_step
                        ctc_len = 1 + (asr_input_len //
                                       self.model.time_reduce_factor)
                    else:
                        ctc_len = torch.LongTensor(
                            [unpair_prob.shape[1]] *
                            unpair_prob.shape[0]).to(device=self.device)
                    unpair_text_loss = self.ctc_loss(
                        ctc_input,
                        unpair_text.to_sparse().values(), ctc_len,
                        torch.sum(unpair_text != 0, dim=-1))
                    if math.isnan(unpair_text_loss) or math.isinf(
                            unpair_text_loss):
                        cnter['ctc_nan'] += 1
                        unpair_text_loss = 0
                    total_loss += self.unpair_text_weight * unpair_text_loss

                # VQ-loss
                # if vq_loss>0:
                #     total_loss += self.model.vq_weight*vq_loss
                # if commit_loss>0:
                #     total_loss += self.model.commit_weight*commit_loss

                # Statics (over unsup. speech only)
                if speech_first and use_unpair_speech:
                    unsup_pred = unpair_prob.argmax(dim=-1).cpu()
                    unsup_trans = unpair_text.cpu()
                    tok_usage += unsup_pred.flatten().tolist()
                    gt_usage += unsup_trans.flatten().tolist()
                    if unpair_align is not None:
                        unsup_align = unpair_align.detach().cpu()
                    else:
                        unsup_align = [None] * bs

                self.timer.cnt('fw')

                # ----------------------- Backward ------------------------ #
                grad_norm = self.backward(total_loss)
                # For debugging
                # if math.isnan(grad_norm):
                # import IPython
                # IPython.embed()
                self.step += 1

                # Log
                if (self.step == 1) or (self.step % self._PROGRESS_STEP == 0):
                    self.progress('Tr stat | Loss - {:.2f} (CTC-nan/unp-sph/unp-txt={}/{}/{}) | Grad. Norm - {:.2f} | {} '\
                                  .format(total_loss.cpu().item(), cnter['ctc_nan'], cnter['unp_sph'], cnter['unp_txt'],
                                          grad_norm, self.timer.show()))
                    self.write_log(
                        'txt_loss', {
                            'pair':
                            asr_loss.item() if asr_loss is not None else None,
                            'unpair':
                            unpair_text_loss.item()
                            if unpair_text_loss is not None else None,
                            'post':
                            asr_post_loss.item()
                            if asr_post_loss is not None else None
                        })
                    self.write_log(
                        'speech_loss', {
                            'pair':
                            tts_loss.item() if tts_loss is not None else None,
                            'unpair':
                            unpair_speech_loss.item()
                            if unpair_speech_loss is not None else None
                        })
                    #self.write_log('stop_err',{'tr':stop_err})
                    # if commit_loss>0:
                    #     self.write_log('commit',{'tr':commit_loss})
                    # if vq_loss>0:
                    #     self.write_log('commit',{'vq':vq_loss})
                    # self.write_log('temperature',{'temp':self.model.codebook.temp.data})
                    # self.write_log('ppx',{'tr':cal_ppx(p_code)})
                    for k in cnter.keys():
                        cnter[k] = 0
                    if (self.step == 1) or (self.step % ATTENTION_PLOT_STEP
                                            == 0):
                        align = pair_align.cpu()  # align shape BxDsxEs
                        sup_pred = pair_prob.argmax(dim=-1).cpu()
                        sup_trans = text.cpu()
                        if self.model.use_asr_postnet:
                            post_pred = pair_post_prob.argmax(dim=-1).cpu()
                        self.write_log(
                            'per', {
                                'pair': cal_per(sup_pred, sup_trans),
                                'unpair': cal_per(unsup_pred, unsup_trans),
                                'post': cal_per(post_pred, sup_trans)
                            })
                        self.write_log(
                            'unpair_hist',
                            data_to_bar(tok_usage, gt_usage, self.vocab_size,
                                        self.tokenizer._vocab_list))
                        for i in range(LISTEN_N_EXAMPLES):
                            self.write_log(
                                'pair_align{}'.format(i),
                                feat_to_fig(align[i].cpu().detach()))
                            if unsup_align is not None and unsup_align[
                                    i] is not None:
                                self.write_log(
                                    'unpair_align{}'.format(i),
                                    feat_to_fig(unsup_align[i].cpu().detach()))
                        tok_usage, gt_usage = [], []

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    self.validate()

                # End of step
                self.timer.set()
                if self.step > self.max_step: break

            except RuntimeError as e:
                # CUDA OOM: release gradients and cached memory, then retry
                # with the next batch instead of crashing.
                if 'out of memory' in str(e):
                    self.verbose('WARNING: ran out of memory, retrying batch')
                    for p in self.model.parameters():
                        if p.grad is not None:
                            del p.grad  # free some memory
                    torch.cuda.empty_cache()
                else:
                    print(repr(e))
                    errorout()

    def validate(self):
        """Run one pass over the dev set, checkpoint on improvement, log samples.

        Evaluates both directions of the model — ASR (speech->text, scored by
        PER) and TTS (text->speech, scored by mel+linear spectrogram loss) —
        then saves checkpoints according to `self.paras.store_best_per`, writes
        example outputs for the median batch to the logger, and restores train
        mode before returning.

        NOTE(review): assumes `self.dev_set` is non-empty; otherwise the
        sample_* variables below are never bound and the averaging divides
        by zero — confirm upstream guarantees at least one dev batch.
        """
        # Eval mode (disable dropout / freeze batch-norm statistics)
        self.model.eval()
        # Per-batch metric accumulators (dev_stop_err is currently unused)
        dev_tts_loss, dev_per, dev_post_per, dev_stop_err = [], [], [], []

        for i in range(len(self.dev_set)):
            self.progress('Valid step - {}/{}'.format(i + 1,
                                                      len(self.dev_set)))
            # Fetch data
            mel, aug_mel, linear, text, sid = self.fetch_data(
                iter_name='dev_iter')

            # Forward model
            with torch.no_grad():
                # test ASR: paired speech only, no unpaired branch
                pair_prob, _, _, _, _, pair_post_prob, _ = self.model.speech_to_text(
                    paired_mel=mel, unpaired_mel=None)
                dev_per.append(cal_per(pair_prob, text))
                # Postnet output only exists when the ASR postnet is enabled
                if pair_post_prob is not None:
                    dev_post_per.append((cal_per(pair_post_prob, text)))

                # test TTS (Note: absolute dec step now)
                # tf_rate=0.0 -> pure free-running decoding, teacher gives
                # only the number of decoder steps (mel.shape[1])
                pair_mel_pred, pair_linear_pred, pair_align, _, _, _, _, _ = \
                        self.model.text_to_speech(paired_text = text,
                                                  paired_sid=sid,
                                                  unpaired_sid=None,
                                                  unpaired_latent=None,
                                                  unpaired_text=None,
                                                  unpaired_latent_len=None,
                                                  paired_teacher=mel.shape[1],
                                                  unpaired_teacher=None,
                                                  tf_rate=0.0)
                dev_tts_loss.append(
                    self.freq_loss(pair_mel_pred, mel) +
                    self.freq_loss(pair_linear_pred, linear))

            if i == len(self.dev_set) // 2:
                # pick n longest samples in the median batch
                sample_txt = text.cpu()[:LISTEN_N_EXAMPLES]
                hyp = pair_prob.argmax(dim=-1).cpu()[:LISTEN_N_EXAMPLES]
                mel_p = pair_mel_pred.cpu()[:LISTEN_N_EXAMPLES]
                linear_p = pair_linear_pred.cpu()[:LISTEN_N_EXAMPLES]
                #post_mel_p = tts_pred.cpu()[:LISTEN_N_EXAMPLES,1] # PostNet product
                align_p = pair_align.cpu()[:LISTEN_N_EXAMPLES]
                sample_mel = mel.cpu()[:LISTEN_N_EXAMPLES]
                sample_linear = linear.cpu()[:LISTEN_N_EXAMPLES]

        # Ckpt if performance improves
        dev_tts_loss = sum(dev_tts_loss) / len(dev_tts_loss)
        dev_per = sum(dev_per) / len(dev_per)
        dev_post_per = sum(dev_post_per) / len(dev_post_per) if len(
            dev_post_per) > 0 else None
        #dev_stop_err = sum(dev_stop_err)/len(dev_stop_err)

        if self.paras.store_best_per:
            # PER-driven checkpointing only
            if dev_per < self.best_per:
                self.best_per = dev_per
                self.save_checkpoint('best_per.pth', dev_per)
            # NOTE(review): postnet PER is compared against (and overwrites)
            # the same self.best_per as the plain PER — confirm this mixing of
            # the two metrics in one scalar is intended.
            if (dev_post_per is not None) and (dev_post_per < self.best_per):
                self.best_per = dev_post_per
                self.save_checkpoint('best_post_per.pth', dev_post_per)
        else:
            # Track TTS loss and PER separately (skip step-1 warm-up ckpts)
            if dev_tts_loss < self.best_tts_loss:
                self.best_tts_loss = dev_tts_loss
                if self.step > 1:
                    self.save_checkpoint('tts_{}.pth'.format(self.step),
                                         dev_tts_loss)
            if dev_per < self.best_per:
                self.best_per = dev_per
                if self.step > 1:
                    self.save_checkpoint('asr_{}.pth'.format(self.step),
                                         dev_per)
            if (dev_post_per is not None) and (dev_post_per < self.best_per):
                self.best_per = dev_post_per
                self.save_checkpoint(
                    'best_post_per.pth', dev_post_per
                )  # Note: didnot recode best per from postnet or not

        if ((self.step > 1) and
            (self.step % CKPT_STEP == 0)) and not self.paras.store_best_per:
            # Regular ckpt
            self.save_checkpoint('step_{}.pth'.format(self.step), dev_tts_loss)

        # Logger
        # Write model output (no G-F-lim if picking per)
        for i, (m_p, l_p, a_p,
                h_p) in enumerate(zip(mel_p, linear_p, align_p, hyp)):
            self.write_log('hyp_text{}'.format(i),
                           self.tokenizer.decode(h_p.tolist()))
            self.write_log('mel_spec{}'.format(i), feat_to_fig(m_p))
            self.write_log('linear_spec{}'.format(i), feat_to_fig(l_p))
            self.write_log('dv_align{}'.format(i), feat_to_fig(a_p))
            if not self.paras.store_best_per:
                # Waveform synthesis (e.g. Griffin-Lim) is skipped in PER mode
                self.write_log('mel_wave{}'.format(i),
                               self.audio_converter.feat_to_wave(m_p))
                self.write_log('linear_wave{}'.format(i),
                               self.audio_converter.feat_to_wave(l_p))
        # Write ground truth (only once, at the very first validation)
        if self.step == 1:
            for i, (mel, linear, gt_txt) in enumerate(
                    zip(sample_mel, sample_linear, sample_txt)):
                self.write_log('truth_text{}'.format(i),
                               self.tokenizer.decode(gt_txt.tolist()))
                self.write_log('mel_spec{}_gt'.format(i), feat_to_fig(mel))
                self.write_log('mel_wave{}_gt'.format(i),
                               self.audio_converter.feat_to_wave(mel))
                self.write_log('linear_spec{}_gt'.format(i),
                               feat_to_fig(linear))
                self.write_log('linear_wave{}_gt'.format(i),
                               self.audio_converter.feat_to_wave(linear))

        self.write_log('speech_loss', {'dev': dev_tts_loss})
        self.write_log('per', {'dev': dev_per, 'dev_post': dev_post_per})
        self.write_log('codebook', (self.model.codebook.embedding.weight.data,
                                    self.tokenizer._vocab_list))
        #self.write_log('stop_err',{'dev':dev_stop_err})
        # Resume training
        self.model.train()

    def compute_ctcloss(self,
                        model_input,
                        model_output,
                        target,
                        apply_log=True):
        """Compute the CTC loss between model output and the target text.

        Args:
            model_input: padded acoustic features, used only to recover the
                true (unpadded) frame count when ``self.paras.actual_len``.
            model_output: (B, T, V) frame-wise distribution over tokens;
                probabilities when ``apply_log`` is True, log-probs otherwise.
            target: (B, L) zero-padded token id sequences.
            apply_log: convert ``model_output`` to log-probabilities first.

        Returns:
            Scalar CTC loss from ``self.ctc_loss``.
        """
        # torch.nn.CTCLoss expects time-major (T, B, V) log-probabilities;
        # EPS guards against log(0) when converting from probabilities.
        log_probs = ((model_output + EPS).log()
                     if apply_log else model_output).transpose(0, 1)

        if self.paras.actual_len:
            # A frame is padding iff every feature dim equals SPEC_PAD_VALUE;
            # count the remaining real frames per utterance.
            pad_hits = (model_input == SPEC_PAD_VALUE).long().sum(dim=-1)
            real_frames = torch.sum(pad_hits != model_input.shape[-1], dim=-1)
            input_lens = real_frames // self.model.time_reduce_factor
            labels = target
        else:
            # Flatten targets to their non-zero values and assume all
            # sequences span the full output length.
            labels = target.to_sparse().values()
            batch, steps = model_output.shape[0], model_output.shape[1]
            input_lens = torch.LongTensor([steps] *
                                          batch).to(device=self.device)
        # Target lengths: 0 is the padding index, so count non-zero tokens.
        target_lens = torch.sum(target != 0, dim=-1)
        return self.ctc_loss(log_probs, labels, input_lens, target_lens)
예제 #15
0
    def __init__(self, reinforce_configs,
                 annunciator_configs,
                 src_vocab, trg_vocab,
                 data_iterator,
                 save_to,
                 device="cpu",
                 ):
        """
        Initiate the translation attack environment; needs a scorer
        (Annunciator) and the victim translation model.

        :param reinforce_configs: attack configuration dictionary (victim
            config/model paths, perturbation rate, reward weights, objective)
        :param annunciator_configs: scorer (and optionally discriminator)
            configs that provide survival signals, plus their optimizer and
            scheduler settings
        :param src_vocab: source-language vocabulary
        :param trg_vocab: target-language vocabulary
        :param data_iterator: used to provide data for environment initiation
        :param save_to: path to save the model; near-vocabulary caches are
            also written under this directory
        :param device: (string) device to allocate variables ("cpu", "cuda:*"),
            default "cpu"
        """
        # environment devices
        self.device = device
        self.data_iterator = data_iterator
        scorer_model_configs = annunciator_configs["scorer_model_configs"]
        # discriminator_model_configs = annunciator_configs["discriminator_model_configs"]
        annunciator_optim_configs = annunciator_configs["annunciator_optimizer_configs"]

        # Load the victim NMT model's own YAML config from disk
        victim_config_path = reinforce_configs["victim_configs"]
        victim_model_path = reinforce_configs["victim_model"]
        with open(victim_config_path.strip()) as v_f:
            INFO("env open victim configs at %s" % victim_config_path)
            victim_configs = yaml.load(v_f, Loader=yaml.FullLoader)

        # to extract the embedding as representation
        # *vocab and *emb will provide pseudo-reinforced embedding to train annunciator
        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab
        # translation model for BLEU(take src_embs as inputs) and corresponding embedding layers
        self.src_emb, self.trg_emb, self.translate_model = build_translate_model(
            victim_configs, victim_model_path,
            vocab_src=self.src_vocab, vocab_trg=self.trg_vocab,
            device=self.device)

        # Roll-out horizon taken from the victim's max_len; presumably index 0
        # is the source side — TODO confirm against the victim data config.
        self.max_roll_out_step = victim_configs["data_configs"]["max_len"][0]
        self.src_emb.eval()  # source language embeddings (frozen, eval mode)
        self.trg_emb.eval()  # target language embeddings (frozen, eval mode)
        self.translate_model.eval()

        # the epsilon range used for action space when perturbation
        _, _, self.limit_dist = load_or_extract_near_vocab(
            config_path=victim_config_path, model_path=victim_model_path,
            init_perturb_rate=reinforce_configs["init_perturb_rate"],
            save_to=os.path.join(save_to, "near_vocab"),
            save_to_full=os.path.join(save_to, "full_near_vocab"),
            top_reserve=12, emit_as_id=True)

        #########################################################
        # scorer(an Annunciator object) provides intrinsic step rewards
        self.annunciator = TransScorer(
            victim_configs, victim_model_path, self.trg_emb,
            **scorer_model_configs)
        self.annunciator.to(self.device)
        # # discriminator(an Annunciator object) provides intrisic step rewards and terminal signal
        # self.discriminator = TransDiscriminator(
        #     victim_configs, victim_model_path,
        #     **discriminator_model_configs)
        # self.discriminator.to(self.device)
        # Annunciator update configs: stop updating once accuracy/MSE bounds
        # are met, validate every `valid_freq` steps, cap total update steps.
        self.acc_bound = annunciator_configs["acc_bound"]
        self.mse_bound = annunciator_configs["mse_bound"]
        self.min_update_steps = annunciator_configs["valid_freq"]
        self.max_update_steps = annunciator_configs["annunciator_update_steps"]
        # the optimizer and schedule used for Annunciator update.
        self.optim_A = Optimizer(
            name=annunciator_optim_configs["optimizer"],
            model=self.annunciator,
            lr=annunciator_optim_configs["learning_rate"],
            grad_clip=annunciator_optim_configs["grad_clip"],
            optim_args=annunciator_optim_configs["optimizer_params"])

        # Optional LR schedule for the Annunciator optimizer, selected by name
        self.scheduler_A = None  # default as None
        if annunciator_optim_configs['schedule_method'] is not None:
            if annunciator_optim_configs['schedule_method'] == "loss":
                self.scheduler_A = ReduceOnPlateauScheduler(optimizer=self.optim_A,
                                                            **annunciator_optim_configs["scheduler_configs"])
            elif annunciator_optim_configs['schedule_method'] == "noam":
                self.scheduler_A = NoamScheduler(optimizer=self.optim_A,
                                                 **annunciator_optim_configs['scheduler_configs'])
            elif annunciator_optim_configs["schedule_method"] == "rsqrt":
                self.scheduler_A = RsqrtScheduler(optimizer=self.optim_A,
                                                  **annunciator_optim_configs["scheduler_configs"])
            else:
                WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".format(
                    annunciator_optim_configs['schedule_method']))
        self.criterion_A = nn.CrossEntropyLoss()
        ############################################################
        self.adversarial = reinforce_configs["adversarial"]  # adversarial or reinforce as learning objects
        self.r_s_weight = reinforce_configs["r_s_weight"]
        self.r_i_weight = reinforce_configs["r_i_weight"]
예제 #16
0
class Solver(BaseSolver):
    ''' Solver for training an end-to-end ASR model (joint CTC + attention). '''

    def __init__(self, config, paras, mode):
        super().__init__(config, paras, mode)
        # Logger settings
        # Best dev WER per decoding branch; 3.0 is an upper bound no real
        # error rate should reach, so the first validation always improves it.
        self.best_wer = {'att': 3.0, 'ctc': 3.0}
        # Curriculum learning affects data loader
        self.curriculum = self.config['hparas']['curriculum']

    def fetch_data(self, data):
        ''' Move data to device and compute text seq. length'''
        _, feat, feat_len, txt = data
        feat = feat.to(self.device)
        feat_len = feat_len.to(self.device)
        txt = txt.to(self.device)
        # Index 0 is the padding token, so the non-zero count is the length
        txt_len = torch.sum(txt != 0, dim=-1)

        return feat, feat_len, txt, txt_len

    def load_data(self):
        ''' Load data for training/validation, store tokenizer and input/output shape'''
        self.tr_set, self.dv_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \
            load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                         self.curriculum > 0, **self.config['data'])
        self.verbose(msg)

    def set_model(self):
        ''' Setup ASR model, losses, optional embedding plug-in and optimizer '''
        # Model
        init_adadelta = self.config['hparas']['optimizer'] == 'Adadelta'
        self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta,
                         **self.config['model']).to(self.device)
        self.verbose(self.model.create_msg())
        model_paras = [{'params': self.model.parameters()}]

        # Losses
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        # Note: zero_infinity=False is unstable?
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

        # Plug-ins
        self.emb_fuse = False
        self.emb_reg = ('emb'
                        in self.config) and (self.config['emb']['enable'])
        # BUGFIX: validate() tests `self.emb_decoder is not None`, so the
        # attribute must exist even when the plug-in is disabled; without this
        # default, validation raised AttributeError for configs without 'emb'.
        self.emb_decoder = None
        if self.emb_reg:
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(
                self.tokenizer, self.model.dec_dim,
                **self.config['emb']).to(self.device)
            model_paras.append({'params': self.emb_decoder.parameters()})
            self.emb_fuse = self.emb_decoder.apply_fuse
            if self.emb_fuse:
                # Fused decoder output is already log-prob, so use NLL directly
                self.seq_loss = torch.nn.NLLLoss(ignore_index=0)
            self.verbose(self.emb_decoder.create_msg())

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Enable AMP if needed
        self.enable_apex()

        # Automatically load pre-trained model if self.paras.load is given
        self.load_ckpt()

        # ToDo: other training methods

    def exec(self):
        ''' Training End-to-end ASR system '''
        self.verbose('Total training steps {}.'.format(
            human_format(self.max_step)))
        ctc_loss, att_loss, emb_loss = None, None, None
        n_epochs = 0
        self.timer.set()

        while self.step < self.max_step:
            # Renew dataloader to enable random sampling once curriculum ends
            if self.curriculum > 0 and n_epochs == self.curriculum:
                self.verbose(
                    'Curriculum learning ends after {} epochs, starting random sampling.'
                    .format(n_epochs))
                self.tr_set, _, _, _, _, _ = \
                    load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                                 False, **self.config['data'])
            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad
                tf_rate = self.optimizer.pre_step(self.step)
                total_loss = 0

                # Fetch data
                feat, feat_len, txt, txt_len = self.fetch_data(data)
                self.timer.cnt('rd')

                # Forward model
                # Note: txt should NOT start w/ <sos>
                ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.model(feat, feat_len, max(txt_len), tf_rate=tf_rate,
                               teacher=txt, get_dec_state=self.emb_reg)

                # Plugins: embedding regularization loss (and optional fusion)
                if self.emb_reg:
                    emb_loss, fuse_output = self.emb_decoder(dec_state,
                                                             att_output,
                                                             label=txt)
                    total_loss += self.emb_decoder.weight * emb_loss

                # Compute all objectives
                if ctc_output is not None:
                    if self.paras.cudnn_ctc:
                        # cuDNN kernel requires int32 targets on CPU and
                        # full-length (untrimmed) input lengths
                        ctc_loss = self.ctc_loss(
                            ctc_output.transpose(0, 1),
                            txt.to_sparse().values().to(device='cpu',
                                                        dtype=torch.int32),
                            [ctc_output.shape[1]] * len(ctc_output),
                            txt_len.cpu().tolist())
                    else:
                        ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1),
                                                 txt, encode_len, txt_len)
                    total_loss += ctc_loss * self.model.ctc_weight

                if att_output is not None:
                    b, t, _ = att_output.shape
                    att_output = fuse_output if self.emb_fuse else att_output
                    att_loss = self.seq_loss(
                        att_output.contiguous().view(b * t, -1),
                        txt.contiguous().view(-1))
                    # CTC and attention losses are mixed by ctc_weight
                    total_loss += att_loss * (1 - self.model.ctc_weight)

                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(total_loss)
                self.step += 1

                # Logger
                if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                    self.progress(
                        'Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'.
                        format(total_loss.cpu().item(), grad_norm,
                               self.timer.show()))
                    self.write_log('loss', {
                        'tr_ctc': ctc_loss,
                        'tr_att': att_loss
                    })
                    self.write_log('emb_loss', {'tr': emb_loss})
                    self.write_log(
                        'wer', {
                            'tr_att':
                            cal_er(self.tokenizer, att_output, txt),
                            'tr_ctc':
                            cal_er(self.tokenizer, ctc_output, txt, ctc=True)
                        })
                    if self.emb_fuse:
                        if self.emb_decoder.fuse_learnable:
                            self.write_log(
                                'fuse_lambda',
                                {'emb': self.emb_decoder.get_weight()})
                        self.write_log('fuse_temp',
                                       {'temp': self.emb_decoder.get_temp()})

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    self.validate()

                # End of step
                # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
                torch.cuda.empty_cache()
                self.timer.set()
                if self.step > self.max_step:
                    break
            n_epochs += 1
        self.log.close()

    def validate(self):
        ''' Evaluate on the dev set, checkpoint when WER improves, log examples '''
        # Eval mode
        self.model.eval()
        if self.emb_decoder is not None:
            self.emb_decoder.eval()
        dev_wer = {'att': [], 'ctc': []}

        for i, data in enumerate(self.dv_set):
            self.progress('Valid step - {}/{}'.format(i + 1, len(self.dv_set)))
            # Fetch data
            feat, feat_len, txt, txt_len = self.fetch_data(data)

            # Forward model (decode up to DEV_STEP_RATIO x label length)
            with torch.no_grad():
                ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.model(feat, feat_len, int(max(txt_len)*self.DEV_STEP_RATIO),
                               emb_decoder=self.emb_decoder)

            dev_wer['att'].append(cal_er(self.tokenizer, att_output, txt))
            dev_wer['ctc'].append(
                cal_er(self.tokenizer, ctc_output, txt, ctc=True))

            # Show some example on tensorboard (median batch only)
            if i == len(self.dv_set) // 2:
                # Use a distinct index `j` so the outer batch index `i`
                # is not shadowed by the example loop.
                for j in range(min(len(txt), self.DEV_N_EXAMPLE)):
                    if self.step == 1:
                        self.write_log('true_text{}'.format(j),
                                       self.tokenizer.decode(txt[j].tolist()))
                    if att_output is not None:
                        self.write_log(
                            'att_align{}'.format(j),
                            feat_to_fig(att_align[j, 0, :, :].cpu().detach()))
                        self.write_log(
                            'att_text{}'.format(j),
                            self.tokenizer.decode(
                                att_output[j].argmax(dim=-1).tolist()))
                    if ctc_output is not None:
                        self.write_log(
                            'ctc_text{}'.format(j),
                            self.tokenizer.decode(
                                ctc_output[j].argmax(dim=-1).tolist(),
                                ignore_repeat=True))

        # Ckpt if performance improves
        for task in ['att', 'ctc']:
            dev_wer[task] = sum(dev_wer[task]) / len(dev_wer[task])
            if dev_wer[task] < self.best_wer[task]:
                self.best_wer[task] = dev_wer[task]
                self.save_checkpoint('best_{}.pth'.format(task), 'wer',
                                     dev_wer[task])
            self.write_log('wer', {'dv_' + task: dev_wer[task]})
        self.save_checkpoint('latest.pth',
                             'wer',
                             dev_wer['att'],
                             show_msg=False)

        # Resume training
        self.model.train()
        if self.emb_decoder is not None:
            self.emb_decoder.train()
예제 #17
0
def run():
    """Launch multi-process adversarial attacker training.

    Loads attack/attacker/discriminator configs, builds vocabularies and the
    parallel dataset used by the environments, creates the shared global
    attacker (with an optionally shared optimizer/scheduler), then spawns one
    training process per requested worker (on cuda:1..cuda:n) plus one
    validation process (on cuda:0) and joins them all.
    """
    # default actor threads as 1 (avoid OpenMP oversubscription per worker)
    os.environ["OMP_NUM_THREADS"] = "1"
    mp = _mp.get_context('spawn')
    args = parser.parse_args()
    # BUGFIX: makedirs(exist_ok=True) also creates missing parents and does
    # not race when several launches create the directory simultaneously.
    os.makedirs(args.save_to, exist_ok=True)
    with open(args.config_path, "r") as f, \
            open(os.path.join(args.save_to, "current_attack_configs.yaml"), "w") as current_configs:
        # BUGFIX: yaml.load without an explicit Loader is deprecated
        # (PyYAML>=5) and a TypeError in PyYAML>=6; FullLoader matches the
        # usage elsewhere in this project.
        configs = yaml.load(f, Loader=yaml.FullLoader)
        yaml.dump(configs, current_configs)
    attack_configs = configs["attack_configs"]
    attacker_configs = configs["attacker_configs"]
    attacker_model_configs = attacker_configs["attacker_model_configs"]
    attacker_optimizer_configs = attacker_configs["attacker_optimizer_configs"]
    discriminator_configs = configs["discriminator_configs"]
    # training_configs = configs["training_configs"]

    # initial best saver for global model
    global_saver = Saver(
        save_prefix="{0}.final".format(os.path.join(args.save_to, "ACmodel")),
        num_max_keeping=attack_configs["num_kept_checkpoints"])
    # the Global variable of USE_GPU is mainly used for environments
    GlobalNames.SEED = attack_configs["seed"]
    GlobalNames.USE_GPU = args.use_gpu
    torch.manual_seed(GlobalNames.SEED)

    # build vocabulary and data iterator for env
    with open(attack_configs["victim_configs"], "r") as victim_f:
        victim_configs = yaml.load(victim_f, Loader=yaml.FullLoader)
    data_configs = victim_configs["data_configs"]
    src_vocab = Vocabulary(**data_configs["vocabularies"][0])
    trg_vocab = Vocabulary(**data_configs["vocabularies"][1])
    data_set = ZipDataset(
        TextLineDataset(
            data_path=data_configs["train_data"][0],
            vocabulary=src_vocab,
        ),
        TextLineDataset(
            data_path=data_configs["train_data"][1],
            vocabulary=trg_vocab,
        ),
        shuffle=attack_configs["shuffle"]
    )  # we build the parallel data sets and iterate inside a thread

    # global model variables (trg network to save the results)
    # kept on CPU with shared memory so all worker processes sync through it
    global_attacker = attacker.Attacker(src_vocab.max_n_words,
                                        **attacker_model_configs)
    global_attacker = global_attacker.cpu()
    global_attacker.share_memory()
    if args.share_optim:
        # initiate optimizer and set to share mode
        optimizer = Optimizer(
            name=attacker_optimizer_configs["optimizer"],
            model=global_attacker,
            lr=attacker_optimizer_configs["learning_rate"],
            grad_clip=attacker_optimizer_configs["grad_clip"],
            optim_args=attacker_optimizer_configs["optimizer_params"])
        optimizer.optim.share_memory()
        # Build scheduler for optimizer if needed
        if attacker_optimizer_configs['schedule_method'] is not None:
            if attacker_optimizer_configs['schedule_method'] == "loss":
                scheduler = ReduceOnPlateauScheduler(
                    optimizer=optimizer,
                    **attacker_optimizer_configs["scheduler_configs"])
            elif attacker_optimizer_configs['schedule_method'] == "noam":
                scheduler = NoamScheduler(
                    optimizer=optimizer,
                    **attacker_optimizer_configs['scheduler_configs'])
            elif attacker_optimizer_configs["schedule_method"] == "rsqrt":
                scheduler = RsqrtScheduler(
                    optimizer=optimizer,
                    **attacker_optimizer_configs["scheduler_configs"])
            else:
                WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".
                     format(attacker_optimizer_configs['schedule_method']))
                scheduler = None
        else:
            scheduler = None
    else:
        # each worker builds its own local optimizer/scheduler
        optimizer = None
        scheduler = None

    # load from checkpoint for global model
    global_saver.load_latest(model=global_attacker,
                             optim=optimizer,
                             lr_scheduler=scheduler)

    if args.use_gpu:
        # collect available devices and distribute env on the available gpu
        device = "cuda"
        devices = []
        for i in range(torch.cuda.device_count()):
            devices += ["cuda:%d" % i]
        print("available gpus:", devices)
    else:
        device = "cpu"
        devices = [device]

    process = []
    counter = mp.Value("i", 0)
    lock = mp.Lock()  # for multiple attackers update

    # Pre-extract near-vocabulary candidates once, before workers start
    INFO("extract near candidates")
    _, _ = load_or_extract_near_vocab(
        config_path=attack_configs["victim_configs"],
        model_path=attack_configs["victim_model"],
        init_perturb_rate=attack_configs["init_perturb_rate"],
        save_to=os.path.join(args.save_to, "near_vocab"),
        save_to_full=os.path.join(args.save_to, "full_near_vocab"),
        top_reserve=12,
        emit_as_id=True)

    # run multiple training process of local attacker to update global one
    # (training workers occupy cuda:1..cuda:n; cuda:0 is reserved for dev)
    for rank in range(args.n):
        print("initialize training thread on cuda:%d" % (rank + 1))
        p = mp.Process(target=train,
                       args=(rank, "cuda:%d" % (rank + 1), args, counter, lock,
                             attack_configs, discriminator_configs, src_vocab,
                             trg_vocab, data_set, global_attacker,
                             attacker_configs, optimizer, scheduler,
                             global_saver))
        p.start()
        process.append(p)
    # run the dev thread for initiation
    print("initialize dev thread on cuda:0")
    p = mp.Process(target=valid,
                   args=(0, "cuda:0", args, attack_configs,
                         discriminator_configs, src_vocab, trg_vocab, data_set,
                         global_attacker, attacker_configs, counter))
    p.start()
    process.append(p)

    for p in process:
        p.join()
예제 #18
0
def train(rank,
          device,
          args,
          counter,
          lock,
          attack_configs,
          discriminator_configs,
          src_vocab,
          trg_vocab,
          data_set,
          global_attacker,
          attacker_configs,
          optimizer=None,
          scheduler=None,
          saver=None):
    """
    running train process
    #1# train the env_discriminator
    #2# run attacker AC based on rewards from trained env_discriminator
    #3# run training updates attacker AC
    #4#
    :param rank: (int) the rank of the process (from multiprocess)
    :param device: the device of the process
    :param counter: python multiprocess variable
    :param lock: python multiprocess variable
    :param args: global args
    :param attack_configs: attack settings
    :param discriminator_configs: discriminator settings
    :param src_vocab:
    :param trg_vocab:
    :param data_set: (data_iterator object) provide batched data labels
    :param global_attacker: the model to sync from
    :param attacker_configs: local attacker settings
    :param optimizer: uses shared optimizer for the attacker
            use local one if none
    :param scheduler: uses shared scheduler for the attacker,
            use local one if none
    :param saver: model saver
    :return:
    """
    trust_acc = acc_bound = discriminator_configs["acc_bound"]
    converged_bound = discriminator_configs["converged_bound"]
    patience = discriminator_configs["patience"]
    attacker_model_configs = attacker_configs["attacker_model_configs"]
    attacker_optimizer_configs = attacker_configs["attacker_optimizer_configs"]

    # this is for multi-processing, GlobalNames can not be direct inherited
    GlobalNames.USE_GPU = args.use_gpu
    GlobalNames.SEED = attack_configs["seed"]
    # offset the seed by rank so workers explore different trajectories
    torch.manual_seed(GlobalNames.SEED + rank)

    # initiate local saver and load checkpoint if possible
    local_saver = Saver(save_prefix="{0}.local".format(
        os.path.join(args.save_to, "train_env%d" % rank, "ACmodel")),
                        num_max_keeping=attack_configs["num_kept_checkpoints"])

    attack_iterator = DataIterator(dataset=data_set,
                                   batch_size=attack_configs["batch_size"],
                                   use_bucket=True,
                                   buffer_size=attack_configs["buffer_size"],
                                   numbering=True)

    summary_writer = SummaryWriter(
        log_dir=os.path.join(args.save_to, "train_env%d" % rank))
    local_attacker = attacker.Attacker(src_vocab.max_n_words,
                                       **attacker_model_configs)
    # build a process-local optimizer/scheduler only when no shared one is given
    if optimizer is None:
        optimizer = Optimizer(
            name=attacker_optimizer_configs["optimizer"],
            model=global_attacker,
            lr=attacker_optimizer_configs["learning_rate"],
            grad_clip=attacker_optimizer_configs["grad_clip"],
            optim_args=attacker_optimizer_configs["optimizer_params"])
        # Build scheduler for optimizer if needed
        if attacker_optimizer_configs['schedule_method'] is not None:
            if attacker_optimizer_configs['schedule_method'] == "loss":
                scheduler = ReduceOnPlateauScheduler(
                    optimizer=optimizer,
                    **attacker_optimizer_configs["scheduler_configs"])
            elif attacker_optimizer_configs['schedule_method'] == "noam":
                scheduler = NoamScheduler(
                    optimizer=optimizer,
                    **attacker_optimizer_configs['scheduler_configs'])
            elif attacker_optimizer_configs["schedule_method"] == "rsqrt":
                scheduler = RsqrtScheduler(
                    optimizer=optimizer,
                    **attacker_optimizer_configs["scheduler_configs"])
            else:
                WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".
                     format(attacker_optimizer_configs['schedule_method']))
                scheduler = None
        else:
            scheduler = None

    local_saver.load_latest(model=local_attacker,
                            optim=optimizer,
                            lr_scheduler=scheduler)

    attacker_iterator = attack_iterator.build_generator()
    env = Translate_Env(attack_configs=attack_configs,
                        discriminator_configs=discriminator_configs,
                        src_vocab=src_vocab,
                        trg_vocab=trg_vocab,
                        data_iterator=attacker_iterator,
                        save_to=args.save_to,
                        device=device)
    episode_count = 0
    episode_length = 0
    local_steps = 0  # optimization steps: for learning rate schedules
    patience_t = patience
    while True:  # infinite loop of data set
        # we will continue with a new iterator with refreshed environments
        # whenever the last iterator breaks with "StopIteration"
        attacker_iterator = attack_iterator.build_generator()
        env.reset_data_iter(attacker_iterator)
        padded_src = env.reset()
        padded_src = torch.from_numpy(padded_src)
        if device != "cpu":
            padded_src = padded_src.to(device)
        done = True
        discriminator_base_steps = local_steps

        while True:
            # check for update of discriminator
            if episode_count % attacker_configs["attacker_update_steps"] == 0:
                """ stop criterion:
                when updates a discriminator, we check for acc. If acc fails acc_bound,
                we reset the discriminator and try, until acc reaches the bound with patience.
                otherwise the training thread stops
                """
                try:
                    discriminator_base_steps, trust_acc = env.update_discriminator(
                        local_attacker,
                        discriminator_base_steps,
                        min_update_steps=discriminator_configs[
                            "acc_valid_freq"],
                        max_update_steps=discriminator_configs[
                            "discriminator_update_steps"],
                        accuracy_bound=acc_bound,
                        summary_writer=summary_writer)
                except StopIteration:
                    INFO("finish one training epoch, reset data_iterator")
                    break

                discriminator_base_steps += 1  # a flag to label the discriminator updates
                if trust_acc < converged_bound:  # GAN target reached
                    patience_t -= 1
                    INFO(
                        "discriminator reached GAN convergence bound: %d times"
                        % patience_t)
                else:  # reset patience if discriminator is refreshed
                    patience_t = patience

            # NOTE(review): the guard tests `saver` but the checkpoint is
            # written through `local_saver`; the shared `saver` is otherwise
            # unused here — confirm which saver is intended.
            if saver and local_steps % attack_configs["save_freq"] == 0:
                local_saver.save(global_step=local_steps,
                                 model=local_attacker,
                                 optim=optimizer,
                                 lr_scheduler=scheduler)

                if trust_acc < converged_bound:
                    # we only save the global params reaching acc_bound
                    torch.save(global_attacker.state_dict(),
                               os.path.join(args.save_to, "ACmodel.final"))

            if patience_t == 0:
                WARN("maximum patience reached. Training Thread should stop")
                break

            local_attacker.train()  # switch back to training mode

            # for a initial (reset) attacker from global parameters
            if done:
                INFO("sync from global model")
                local_attacker.load_state_dict(global_attacker.state_dict())
            # move the local attacker params back to device after updates
            local_attacker = local_attacker.to(device)
            values = []  # training critic: network outputs
            log_probs = []
            rewards = []  # actual rewards
            entropies = []

            local_steps += 1
            # roll out up to action_roll_steps environment steps
            try:
                for i in range(args.action_roll_steps):
                    episode_length += 1
                    attack_out, critic_out = local_attacker(
                        padded_src, padded_src[:, env.index - 1:env.index + 2])
                    logit_attack_out = torch.log(attack_out)
                    # policy entropy, used both for logging and entropy bonus
                    entropy = -(attack_out *
                                logit_attack_out).sum(dim=-1).mean()

                    summary_writer.add_scalar("action_entropy",
                                              scalar_value=entropy,
                                              global_step=local_steps)
                    entropies.append(entropy)  # for entropy loss
                    actions = attack_out.multinomial(num_samples=1).detach()
                    # only extract the log prob for chosen action (avg over batch)
                    log_attack_out = logit_attack_out.gather(-1,
                                                             actions).mean()
                    padded_src, reward, terminal_signal = env.step(
                        actions.squeeze())
                    done = terminal_signal or episode_length > args.max_episode_lengths

                    with lock:
                        counter.value += 1

                    if done:
                        episode_length = 0
                        padded_src = env.reset()

                    padded_src = torch.from_numpy(padded_src)
                    if device != "cpu":
                        padded_src = padded_src.to(device)

                    values.append(
                        critic_out.mean())  # list of torch variables (scalar)
                    log_probs.append(
                        log_attack_out)  # list of torch variables (scalar)
                    rewards.append(reward)  # list of reward variables

                    if done:
                        episode_count += 1
                        break
            except StopIteration:
                INFO("finish one training epoch, reset data_iterator")
                break

            R = torch.zeros(1, 1)
            gae = torch.zeros(1, 1)
            if device != "cpu":
                R = R.to(device)
                gae = gae.to(device)

            if not done:  # bootstrap the return from the critic
                value = local_attacker.get_critic(
                    padded_src, padded_src[:, env.index - 1:env.index + 2])
                R = value.mean().detach()

            values.append(R)
            policy_loss = 0
            value_loss = 0

            # accumulate losses backwards through the rollout (GAE-style)
            for i in reversed(range(len(rewards))):
                # value loss and policy loss must be clipped to stabilize training
                R = attack_configs["gamma"] * R + rewards[i]
                advantage = R - values[i]
                value_loss = value_loss + 0.5 * advantage.pow(2)

                delta_t = rewards[i] + attack_configs["gamma"] * \
                          values[i + 1] - values[i]
                gae = gae * attack_configs["gamma"] * attack_configs["tau"] + \
                      delta_t
                policy_loss = policy_loss - log_probs[i] * gae.detach() - \
                              attack_configs["entropy_coef"] * entropies[i]

            # update with optimizer
            optimizer.zero_grad()
            # we decay the loss according to discriminator's accuracy as a trust region constrain
            summary_writer.add_scalar("policy_loss",
                                      scalar_value=policy_loss * trust_acc,
                                      global_step=local_steps)
            summary_writer.add_scalar("value_loss",
                                      scalar_value=value_loss * trust_acc,
                                      global_step=local_steps)

            total_loss = trust_acc * policy_loss + \
                         trust_acc * attack_configs["value_coef"] * value_loss
            total_loss.backward()

            # step the scheduler only when one actually exists: an unknown
            # schedule_method above leaves scheduler as None while the
            # config value stays non-None, which previously crashed here
            if scheduler is not None and attacker_optimizer_configs[
                    "schedule_method"] is not None and attacker_optimizer_configs[
                        "schedule_method"] != "loss":
                scheduler.step(global_step=local_steps)

            # move the model params to CPU and
            # assign local gradients to the global model to update
            local_attacker.to("cpu").ensure_shared_grads(global_attacker)
            optimizer.step()

        if patience_t == 0:
            INFO("Reach maximum Discriminator patience, Finish")
            break
예제 #19
0
파일: tune.py 프로젝트: wangqi1996/njunmt
def tune(flags):
    """
    Fine-tune an NMT model from a (possibly frozen, pre-trained) checkpoint.

    flags:
        saveto: str
        reload: store_true
        config_path: str
        pretrain_path: str, default=""
        model_name: str
        log_path: str
    """

    # ================================================================================== #
    # Initialization for training on different devices
    # - CPU/GPU
    # - Single/Distributed
    Constants.USE_GPU = flags.use_gpu

    if flags.multi_gpu:
        dist.distributed_init(flags.shared_dir)
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        local_rank = dist.get_local_rank()
    else:
        # single-process defaults
        world_size = 1
        rank = 0
        local_rank = 0

    if Constants.USE_GPU:
        torch.cuda.set_device(local_rank)
        Constants.CURRENT_DEVICE = "cuda:{0}".format(local_rank)
    else:
        Constants.CURRENT_DEVICE = "cpu"

    # If not root_rank, close logging
    # else write log of training to file.
    if rank == 0:
        write_log_to_file(
            os.path.join(flags.log_path,
                         "%s.log" % time.strftime("%Y%m%d-%H%M%S")))
    else:
        close_logging()

    # ================================================================================== #
    # Parsing configuration files
    # - Load default settings
    # - Load pre-defined settings
    # - Load user-defined settings

    configs = prepare_configs(flags.config_path, flags.predefined_config)

    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    INFO(pretty_configs(configs))

    Constants.SEED = training_configs['seed']

    set_seed(Constants.SEED)

    timer = Timer()

    # ================================================================================== #
    # Load Data
    INFO('Loading data...')
    timer.tic()

    # Generate target dictionary
    vocab_src = Vocabulary.build_from_file(**data_configs['vocabularies'][0])
    vocab_tgt = Vocabulary.build_from_file(**data_configs['vocabularies'][1])

    # NOTE(review): special-token ids are taken from the *source* vocabulary
    # only — assumes src/tgt share the same eos/pad/bos ids; confirm.
    Constants.EOS = vocab_src.eos
    Constants.PAD = vocab_src.pad
    Constants.BOS = vocab_src.bos
    # bt tag dataset
    train_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['train_data'][0],
                        vocabulary=vocab_src,
                        max_len=data_configs['max_len'][0],
                        is_train_dataset=True),
        TextLineDataset(data_path=data_configs['train_data'][1],
                        vocabulary=vocab_tgt,
                        max_len=data_configs['max_len'][1],
                        is_train_dataset=True))

    training_iterator = DataIterator(
        dataset=train_bitext_dataset,
        batch_size=training_configs["batch_size"],
        use_bucket=training_configs['use_bucket'],
        buffer_size=training_configs['buffer_size'],
        batching_func=training_configs['batching_key'],
        world_size=world_size,
        rank=rank)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ================================ Begin ======================================== #
    # Build Model & Optimizer
    # We would do steps below on after another
    #     1. build models & criterion
    #     2. move models & criterion to gpu if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    #     5. build learning rate scheduler if needed
    #     6. load checkpoints if needed

    # 0. Initial
    lrate = optimizer_configs['learning_rate']
    model_collections = Collections()

    checkpoint_saver = Saver(
        save_prefix="{0}.ckpt".format(
            os.path.join(flags.saveto, flags.model_name)),
        num_max_keeping=training_configs['num_kept_checkpoints'])
    best_model_prefix = os.path.join(
        flags.saveto, flags.model_name + Constants.MY_BEST_MODEL_SUFFIX)
    best_model_saver = Saver(
        save_prefix=best_model_prefix,
        num_max_keeping=training_configs['num_kept_best_model'])

    # 1. Build Model & Criterion
    INFO('Building model...')
    timer.tic()
    nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                            n_tgt_vocab=vocab_tgt.max_n_words,
                            padding_idx=vocab_src.pad,
                            vocab_src=vocab_src,
                            vocab_tgt=vocab_tgt,
                            **model_configs)
    INFO(nmt_model)

    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'],
                          padding_idx=vocab_tgt.pad)

    INFO(critic)

    # 2. Move to GPU
    if Constants.USE_GPU:
        nmt_model = nmt_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    load_pretrained_model(nmt_model,
                          flags.pretrain_path,
                          exclude_prefix=flags.pretrain_exclude_prefix,
                          device=Constants.CURRENT_DEVICE)
    # froze_parameters: freeze the parameter subsets named in froze_config
    # so only the remaining parameters are tuned
    froze_params(nmt_model, flags.froze_config)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 4. Build optimizer
    INFO('Building Optimizer...')

    if not flags.multi_gpu:
        optim = Optimizer(name=optimizer_configs['optimizer'],
                          model=nmt_model,
                          lr=lrate,
                          grad_clip=optimizer_configs['grad_clip'],
                          optim_args=optimizer_configs['optimizer_params'],
                          update_cycle=training_configs['update_cycle'])
    else:
        optim = dist.DistributedOptimizer(
            name=optimizer_configs['optimizer'],
            model=nmt_model,
            lr=lrate,
            grad_clip=optimizer_configs['grad_clip'],
            optim_args=optimizer_configs['optimizer_params'],
            device_id=local_rank)

    # 5. Build scheduler for optimizer if needed
    scheduler = build_scheduler(
        schedule_method=optimizer_configs['schedule_method'],
        optimizer=optim,
        scheduler_configs=optimizer_configs['scheduler_configs'])

    # 6. build moving average
    if training_configs['moving_average_method'] is not None:
        ma = MovingAverage(
            moving_average_method=training_configs['moving_average_method'],
            named_params=nmt_model.named_parameters(),
            alpha=training_configs['moving_average_alpha'])
    else:
        ma = None

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # Reload from latest checkpoint
    if flags.reload:
        checkpoint_saver.load_latest(model=nmt_model,
                                     optim=optim,
                                     lr_scheduler=scheduler,
                                     collections=model_collections,
                                     ma=ma,
                                     device=Constants.CURRENT_DEVICE)

    # broadcast parameters and optimizer states so all ranks start identical
    if world_size > 1:
        INFO("Broadcasting model parameters...")
        dist.broadcast_parameters(params=nmt_model.state_dict())
        INFO("Broadcasting optimizer states...")
        dist.broadcast_optimizer_state(optimizer=optim.optim)
        INFO('Done.')

    # ================================================================================== #
    # Prepare training

    # resume bookkeeping counters from a reloaded checkpoint (defaults otherwise)
    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [1])[-1]
    bad_count = model_collections.get_collection("bad_count", [0])[-1]
    oom_count = model_collections.get_collection("oom_count", [0])[-1]
    is_early_stop = model_collections.get_collection("is_early_stop", [
        False,
    ])[-1]

    train_loss_meter = AverageMeter()
    sent_per_sec_meter = TimeMeter()
    tok_per_sec_meter = TimeMeter()

    # gradient-accumulation state: an optimizer step happens every
    # `update_cycle` mini-batches, normalized by `grad_denom` sentences
    update_cycle = training_configs['update_cycle']
    grad_denom = 0
    train_loss = 0.0
    cum_n_words = 0
    valid_loss = best_valid_loss = float('inf')

    # only the root rank writes tensorboard summaries
    if rank == 0:
        summary_writer = SummaryWriter(log_dir=flags.log_path)
    else:
        summary_writer = None

    sent_per_sec_meter.start()
    tok_per_sec_meter.start()

    INFO('Begin training...')

    while True:  # one iteration per epoch; exits via max_epochs below

        if summary_writer is not None:
            summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()

        if rank == 0:
            training_progress_bar = tqdm(desc=' - (Epc {}, Upd {}) '.format(
                eidx, uidx),
                                         total=len(training_iterator),
                                         unit="sents")
        else:
            training_progress_bar = None
        # INFO(Constants.USE_BT)
        for batch in training_iter:
            # bt attrib data
            seqs_x, seqs_y = batch

            batch_size = len(seqs_x)
            cum_n_words += sum(len(s) for s in seqs_y)

            try:
                # Prepare data
                x, y = prepare_data(seqs_x, seqs_y, cuda=Constants.USE_GPU)

                loss = compute_forward(
                    model=nmt_model,
                    critic=critic,
                    seqs_x=x,
                    seqs_y=y,
                    eval=False,
                    normalization=1.0,
                    norm_by_words=training_configs["norm_by_words"])

                update_cycle -= 1
                grad_denom += batch_size
                train_loss += loss

            except RuntimeError as e:
                # NOTE(review): OOM batches are skipped but any gradient
                # already accumulated for this cycle is kept — confirm this
                # partial-cycle behavior is intended.
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    oom_count += 1
                else:
                    raise e

            # When update_cycle becomes 0, it means end of one batch. Several things will be done:
            # - update parameters
            # - reset update_cycle and grad_denom, update uidx
            # - learning rate scheduling
            # - update moving average

            if update_cycle == 0:

                # 0. reduce variables across ranks before the shared update
                if world_size > 1:
                    grad_denom = dist.all_reduce_py(grad_denom)
                    train_loss = dist.all_reduce_py(train_loss)
                    cum_n_words = dist.all_reduce_py(cum_n_words)

                # 1. update parameters (gradients normalized by sentence count)
                optim.step(denom=grad_denom)
                optim.zero_grad()

                if training_progress_bar is not None:
                    training_progress_bar.update(grad_denom)
                    training_progress_bar.set_description(
                        ' - (Epc {}, Upd {}) '.format(eidx, uidx))

                    postfix_str = 'TrainLoss: {:.2f}, ValidLoss(best): {:.2f} ({:.2f}), '.format(
                        train_loss, valid_loss, best_valid_loss)
                    training_progress_bar.set_postfix_str(postfix_str)

                # 2. learning rate scheduling ("loss" schedules need a valid
                # loss and are stepped elsewhere)
                if scheduler is not None and optimizer_configs[
                        "schedule_method"] != "loss":
                    scheduler.step(global_step=uidx)

                # 3. update moving average
                if ma is not None and eidx >= training_configs[
                        'moving_average_start_epoch']:
                    ma.step()

                # 4. update meters
                train_loss_meter.update(train_loss, grad_denom)
                sent_per_sec_meter.update(grad_denom)
                tok_per_sec_meter.update(cum_n_words)

                # 5. reset accumulated variables, update uidx
                update_cycle = training_configs['update_cycle']
                grad_denom = 0
                uidx += 1
                cum_n_words = 0.0
                train_loss = 0.0

            else:
                # mid-cycle batch: skip the display/saving sections below
                continue

            # ================================================================================== #
            # Display some information
            if should_trigger_by_steps(
                    uidx, eidx, every_n_step=training_configs['disp_freq']):

                lrate = list(optim.get_lrate())[0]

                if summary_writer is not None:
                    summary_writer.add_scalar(
                        "Speed(sents/sec)",
                        scalar_value=sent_per_sec_meter.ave,
                        global_step=uidx)
                    summary_writer.add_scalar(
                        "Speed(words/sec)",
                        scalar_value=tok_per_sec_meter.ave,
                        global_step=uidx)
                    summary_writer.add_scalar(
                        "train_loss",
                        scalar_value=train_loss_meter.ave,
                        global_step=uidx)
                    summary_writer.add_scalar("lrate",
                                              scalar_value=lrate,
                                              global_step=uidx)
                    summary_writer.add_scalar("oom_count",
                                              scalar_value=oom_count,
                                              global_step=uidx)

                # Reset Meters
                sent_per_sec_meter.reset()
                tok_per_sec_meter.reset()
                train_loss_meter.reset()

            # ================================================================================== #
            # Saving checkpoints
            # if should_trigger_by_steps(uidx, eidx, every_n_step=training_configs['save_freq'], debug=flags.debug):
            #     model_collections.add_to_collection("uidx", uidx)
            #     model_collections.add_to_collection("eidx", eidx)
            #     model_collections.add_to_collection("bad_count", bad_count)
            #
            #     if not is_early_stop:
            #         if rank == 0:
            #             checkpoint_saver.save(global_step=uidx,
            #                                   model=nmt_model,
            #                                   optim=optim,
            #                                   lr_scheduler=scheduler,
            #                                   collections=model_collections,
            #                                   ma=ma)

        # NOTE(review): overwrites the same ".final" file at the end of every
        # epoch, and every rank writes it — confirm single-rank save is not
        # required here.
        torch.save(nmt_model.state_dict(), best_model_prefix + ".final")

        if training_progress_bar is not None:
            training_progress_bar.close()

        eidx += 1
        if eidx > training_configs["max_epochs"]:
            break
예제 #20
0
파일: nmt.py 프로젝트: whr94621/ODC-NMT
def train(FLAGS):
    """Train an NMT model end-to-end (data loading, optimization, validation,
    BLEU-based early stopping, checkpointing).

    FLAGS:
        saveto: str
        reload: store_true
        config_path: str
        pretrain_path: str, default=""
        model_name: str
        log_path: str
    """

    # ================================================================================== #
    # Initialization for training on different devices
    # - CPU/GPU
    # - Single/Distributed
    GlobalNames.USE_GPU = FLAGS.use_gpu

    if FLAGS.multi_gpu:

        if hvd is None or distributed is None:
            ERROR("Distributed training is disable. Please check the installation of Horovod.")

        hvd.init()
        world_size = hvd.size()
        rank = hvd.rank()
        local_rank = hvd.local_rank()
    else:
        world_size = 1
        rank = 0
        local_rank = 0

    if GlobalNames.USE_GPU:
        torch.cuda.set_device(local_rank)
        CURRENT_DEVICE = "cuda:{0}".format(local_rank)
    else:
        CURRENT_DEVICE = "cpu"

    # If not root_rank, close logging
    if rank != 0:
        close_logging()

    # write log of training to file.
    if rank == 0:
        write_log_to_file(os.path.join(FLAGS.log_path, "%s.log" % time.strftime("%Y%m%d-%H%M%S")))

    # ================================================================================== #
    # Parsing configuration files

    config_path = os.path.abspath(FLAGS.config_path)
    with open(config_path.strip()) as f:
        # NOTE(review): yaml.load without an explicit Loader is deprecated and
        # unsafe on untrusted input; prefer yaml.safe_load here — confirm the
        # config files use no custom python tags before switching.
        configs = yaml.load(f)

    INFO(pretty_configs(configs))

    # Add default configs
    configs = default_baseline_configs(configs)
    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    GlobalNames.SEED = training_configs['seed']

    set_seed(GlobalNames.SEED)

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Generate target dictionary
    vocab_src = Vocabulary(**data_configs["vocabularies"][0])
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][1])

    # Buffer must hold enough batches to feed a full gradient-accumulation cycle.
    actual_buffer_size = training_configs["buffer_size"] * max(1, training_configs["update_cycle"])

    train_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['train_data'][0],
                        vocabulary=vocab_src,
                        max_len=data_configs['max_len'][0],
                        ),
        TextLineDataset(data_path=data_configs['train_data'][1],
                        vocabulary=vocab_tgt,
                        max_len=data_configs['max_len'][1],
                        )
    )

    valid_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['valid_data'][0],
                        vocabulary=vocab_src,
                        ),
        TextLineDataset(data_path=data_configs['valid_data'][1],
                        vocabulary=vocab_tgt,
                        )
    )

    training_iterator = DataIterator(dataset=train_bitext_dataset,
                                     batch_size=training_configs["batch_size"],
                                     use_bucket=training_configs['use_bucket'],
                                     buffer_size=actual_buffer_size,
                                     batching_func=training_configs['batching_key'],
                                     world_size=world_size,
                                     rank=rank)

    valid_iterator = DataIterator(dataset=valid_bitext_dataset,
                                  batch_size=training_configs['valid_batch_size'],
                                  use_bucket=True, buffer_size=100000, numbering=True,
                                  world_size=world_size, rank=rank)

    bleu_scorer = SacreBLEUScorer(reference_path=data_configs["bleu_valid_reference"],
                                  num_refs=data_configs["num_refs"],
                                  lang_pair=data_configs["lang_pair"],
                                  sacrebleu_args=training_configs["bleu_valid_configs"]['sacrebleu_args'],
                                  postprocess=training_configs["bleu_valid_configs"]['postprocess']
                                  )

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    lrate = optimizer_configs['learning_rate']
    is_early_stop = False

    # ================================ Begin ======================================== #
    # Build Model & Optimizer
    # We would do steps below on after another
    #     1. build models & criterion
    #     2. move models & criterion to gpu if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    #     5. build learning rate scheduler if needed
    #     6. load checkpoints if needed

    # 0. Initial
    model_collections = Collections()
    best_model_prefix = os.path.join(FLAGS.saveto, FLAGS.model_name + GlobalNames.MY_BEST_MODEL_SUFFIX)

    checkpoint_saver = Saver(save_prefix="{0}.ckpt".format(os.path.join(FLAGS.saveto, FLAGS.model_name)),
                             num_max_keeping=training_configs['num_kept_checkpoints']
                             )
    best_model_saver = Saver(save_prefix=best_model_prefix, num_max_keeping=training_configs['num_kept_best_model'])

    INFO('Building model...')
    timer.tic()
    nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                            n_tgt_vocab=vocab_tgt.max_n_words, **model_configs)
    INFO(nmt_model)

    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'])

    INFO(critic)
    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 2. Move to GPU
    if GlobalNames.USE_GPU:
        nmt_model = nmt_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    load_pretrained_model(nmt_model, FLAGS.pretrain_path, exclude_prefix=None, device=CURRENT_DEVICE)

    # 4. Build optimizer
    INFO('Building Optimizer...')
    optim = Optimizer(name=optimizer_configs['optimizer'],
                      model=nmt_model,
                      lr=lrate,
                      grad_clip=optimizer_configs['grad_clip'],
                      optim_args=optimizer_configs['optimizer_params'],
                      distributed=True if world_size > 1 else False,
                      update_cycle=training_configs['update_cycle']
                      )
    # 5. Build scheduler for optimizer if needed
    if optimizer_configs['schedule_method'] is not None:

        if optimizer_configs['schedule_method'] == "loss":

            scheduler = ReduceOnPlateauScheduler(optimizer=optim,
                                                 **optimizer_configs["scheduler_configs"]
                                                 )

        elif optimizer_configs['schedule_method'] == "noam":
            scheduler = NoamScheduler(optimizer=optim, **optimizer_configs['scheduler_configs'])
        else:
            WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".format(optimizer_configs['schedule_method']))
            scheduler = None
    else:
        scheduler = None

    # 6. build moving average

    if training_configs['moving_average_method'] is not None:
        ma = MovingAverage(moving_average_method=training_configs['moving_average_method'],
                           named_params=nmt_model.named_parameters(),
                           alpha=training_configs['moving_average_alpha'])
    else:
        ma = None

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # Reload from latest checkpoint
    if FLAGS.reload:
        checkpoint_saver.load_latest(model=nmt_model, optim=optim, lr_scheduler=scheduler,
                                     collections=model_collections, ma=ma)

    # broadcast parameters and optimizer states
    if world_size > 1:
        hvd.broadcast_parameters(params=nmt_model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer=optim.optim, root_rank=0)

    # ================================================================================== #
    # Prepare training

    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [1])[-1]
    bad_count = model_collections.get_collection("bad_count", [0])[-1]
    oom_count = model_collections.get_collection("oom_count", [0])[-1]
    cum_n_samples = 0
    cum_n_words = 0
    best_valid_loss = 1.0 * 1e10  # Max Float
    # Fix: valid_loss is logged in the BLEU-validation branch below; initialize
    # it so that branch cannot raise NameError when it fires before the first
    # loss validation.
    valid_loss = best_valid_loss
    update_cycle = training_configs['update_cycle']
    grad_denom = 0

    if rank == 0:
        summary_writer = SummaryWriter(log_dir=FLAGS.log_path)
    else:
        summary_writer = None

    # Timer for computing speed
    timer_for_speed = Timer()
    timer_for_speed.tic()

    INFO('Begin training...')

    while True:

        if summary_writer is not None:
            summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()

        if rank == 0:
            training_progress_bar = tqdm(desc='  - (Epoch %d)   ' % eidx,
                                         total=len(training_iterator),
                                         unit="sents"
                                         )
        else:
            training_progress_bar = None

        for batch in training_iter:

            seqs_x, seqs_y = batch

            batch_size = len(seqs_x)

            cum_n_samples += batch_size
            cum_n_words += sum(len(s) for s in seqs_y)

            try:
                # Prepare data
                x, y = prepare_data(seqs_x, seqs_y, cuda=GlobalNames.USE_GPU)

                loss = compute_forward(model=nmt_model,
                                       critic=critic,
                                       seqs_x=x,
                                       seqs_y=y,
                                       eval=False,
                                       normalization=1.0,
                                       norm_by_words=training_configs["norm_by_words"])

                update_cycle -= 1
                grad_denom += batch_size

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    # Skip this batch; the accumulation cycle simply sees one
                    # fewer micro-batch.
                    print('| WARNING: ran out of memory, skipping batch')
                    oom_count += 1
                else:
                    # Fix: bare raise preserves the original traceback.
                    raise

            # When update_cycle becomes 0, it means end of one batch. Several things will be done:
            # - update parameters
            # - reset update_cycle and grad_denom
            # - update uidx
            # - update moving average

            if update_cycle == 0:
                if world_size > 1:
                    grad_denom = distributed.all_reduce(grad_denom)

                optim.step(denom=grad_denom)
                optim.zero_grad()

                if training_progress_bar is not None:
                    training_progress_bar.update(grad_denom)

                update_cycle = training_configs['update_cycle']
                grad_denom = 0

                uidx += 1

                if scheduler is None:
                    pass
                elif optimizer_configs["schedule_method"] == "loss":
                    scheduler.step(metric=best_valid_loss)
                else:
                    scheduler.step(global_step=uidx)

                if ma is not None and eidx >= training_configs['moving_average_start_epoch']:
                    ma.step()
            else:
                continue

            # ================================================================================== #
            # Display some information
            if should_trigger_by_steps(uidx, eidx, every_n_step=training_configs['disp_freq']):

                if world_size > 1:
                    cum_n_words = sum(distributed.all_gather(cum_n_words))
                    cum_n_samples = sum(distributed.all_gather(cum_n_samples))

                # words per second and sents per second
                words_per_sec = cum_n_words / (timer.toc(return_seconds=True))
                sents_per_sec = cum_n_samples / (timer.toc(return_seconds=True))
                lrate = list(optim.get_lrate())[0]

                if summary_writer is not None:
                    summary_writer.add_scalar("Speed(words/sec)", scalar_value=words_per_sec, global_step=uidx)
                    # Fix: tag typo "sents/sen" -> "sents/sec".
                    summary_writer.add_scalar("Speed(sents/sec)", scalar_value=sents_per_sec, global_step=uidx)
                    summary_writer.add_scalar("lrate", scalar_value=lrate, global_step=uidx)
                    summary_writer.add_scalar("oom_count", scalar_value=oom_count, global_step=uidx)

                # Reset timer
                timer.tic()
                cum_n_words = 0
                cum_n_samples = 0

            # ================================================================================== #
            # Loss Validation & Learning rate annealing
            if should_trigger_by_steps(global_step=uidx, n_epoch=eidx, every_n_step=training_configs['loss_valid_freq'],
                                       debug=FLAGS.debug):

                valid_loss = loss_validation(model=nmt_model,
                                             critic=critic,
                                             valid_iterator=valid_iterator,
                                             rank=rank,
                                             world_size=world_size
                                             )

                model_collections.add_to_collection("history_losses", valid_loss)

                min_history_loss = np.array(model_collections.get_collection("history_losses")).min()

                best_valid_loss = min_history_loss

                if summary_writer is not None:
                    summary_writer.add_scalar("loss", valid_loss, global_step=uidx)
                    summary_writer.add_scalar("best_loss", min_history_loss, global_step=uidx)

            # ================================================================================== #
            # BLEU Validation & Early Stop

            if should_trigger_by_steps(global_step=uidx, n_epoch=eidx,
                                       every_n_step=training_configs['bleu_valid_freq'],
                                       min_step=training_configs['bleu_valid_warmup'],
                                       debug=FLAGS.debug):

                valid_bleu = bleu_validation(uidx=uidx,
                                             valid_iterator=valid_iterator,
                                             batch_size=training_configs["bleu_valid_batch_size"],
                                             model=nmt_model,
                                             bleu_scorer=bleu_scorer,
                                             vocab_tgt=vocab_tgt,
                                             valid_dir=FLAGS.valid_path,
                                             max_steps=training_configs["bleu_valid_configs"]["max_steps"],
                                             beam_size=training_configs["bleu_valid_configs"]["beam_size"],
                                             alpha=training_configs["bleu_valid_configs"]["alpha"],
                                             world_size=world_size,
                                             rank=rank,
                                             )

                model_collections.add_to_collection(key="history_bleus", value=valid_bleu)

                best_valid_bleu = float(np.array(model_collections.get_collection("history_bleus")).max())

                if summary_writer is not None:
                    summary_writer.add_scalar("bleu", valid_bleu, uidx)
                    summary_writer.add_scalar("best_bleu", best_valid_bleu, uidx)

                # If model get new best valid bleu score (history already
                # contains the current score, so >= means "ties the best").
                if valid_bleu >= best_valid_bleu:
                    bad_count = 0

                    if is_early_stop is False:
                        if rank == 0:
                            # 1. save the best model
                            torch.save(nmt_model.state_dict(), best_model_prefix + ".final")

                            # 2. record all several best models
                            best_model_saver.save(global_step=uidx, model=nmt_model, ma=ma)
                else:
                    bad_count += 1

                    # At least one epoch should be traversed
                    if bad_count >= training_configs['early_stop_patience'] and eidx > 0:
                        is_early_stop = True
                        WARN("Early Stop!")

                if summary_writer is not None:
                    summary_writer.add_scalar("bad_count", bad_count, uidx)

                INFO("{0} Loss: {1:.2f} BLEU: {2:.2f} lrate: {3:6f} patience: {4}".format(
                    uidx, valid_loss, valid_bleu, lrate, bad_count
                ))

            # ================================================================================== #
            # Saving checkpoints
            if should_trigger_by_steps(uidx, eidx, every_n_step=training_configs['save_freq'], debug=FLAGS.debug):
                model_collections.add_to_collection("uidx", uidx)
                model_collections.add_to_collection("eidx", eidx)
                model_collections.add_to_collection("bad_count", bad_count)

                if not is_early_stop:
                    if rank == 0:
                        checkpoint_saver.save(global_step=uidx,
                                              model=nmt_model,
                                              optim=optim,
                                              lr_scheduler=scheduler,
                                              collections=model_collections,
                                              ma=ma)

        if training_progress_bar is not None:
            training_progress_bar.close()

        eidx += 1
        if eidx > training_configs["max_epochs"]:
            break
예제 #21
0
class Solver(BaseSolver):
    ''' Solver for interactive ASR decoding with optional on-the-fly finetuning. '''

    def __init__(self, config, paras, mode):
        super().__init__(config, paras, mode)

        # ToDo : support tr/eval on different corpus
        assert self.config['data']['corpus']['name'] == self.src_config['data']['corpus']['name']
        self.config['data']['corpus']['path'] = self.src_config['data']['corpus']['path']
        self.config['data']['corpus']['bucketing'] = False

        # The follow attribute should be identical to training config
        self.config['data']['audio'] = self.src_config['data']['audio']
        self.config['data']['corpus']['train_split'] = self.src_config['data']['corpus']['train_split']
        self.config['data']['text'] = self.src_config['data']['text']
        self.tokenizer = load_text_encoder(**self.config['data']['text'])
        self.config['model'] = self.src_config['model']
        # Number of leading encoder layers to finetune (0 = finetune everything).
        self.finetune_first = 5
        # Best dev WER seen so far per decoding head (init above any real WER).
        self.best_wer = {'att': 3.0, 'ctc': 3.0}

        # Output file
        self.output_file = str(self.ckpdir)+'_{}_{}.csv'

        # Override batch size for beam decoding
        self.greedy = self.config['decode']['beam_size'] == 1
        self.dealer = Datadealer(self.config['data']['audio'])
        self.ctc = self.config['decode']['ctc_weight'] == 1.0
        if not self.greedy:
            # Beam search decodes one utterance at a time.
            self.config['data']['corpus']['batch_size'] = 1
        else:
            # ToDo : implement greedy
            raise NotImplementedError

        # Logger settings
        self.logdir = os.path.join(paras.logdir, self.exp_name)
        self.log = SummaryWriter(
            self.logdir, flush_secs=self.TB_FLUSH_FREQ)
        self.timer = Timer()

    def fetch_data(self, data):
        ''' Move data to device and compute text seq. length'''
        _, feat, feat_len, txt = data
        feat = feat.to(self.device)
        feat_len = feat_len.to(self.device)
        txt = txt.to(self.device)
        # Token id 0 is padding, so non-zero count is the true text length.
        txt_len = torch.sum(txt != 0, dim=-1)

        return feat, feat_len, txt, txt_len

    def load_data(self, batch_size=7):
        ''' Load data for training/validation, store tokenizer and input/output shape.

        Temporarily overrides the configured corpus batch size with
        ``batch_size`` for this load, then restores the original value.
        '''
        prev_batch_size = self.config['data']['corpus']['batch_size']
        self.config['data']['corpus']['batch_size'] = batch_size
        self.tr_set, self.dv_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \
            load_dataset(self.paras.njobs, self.paras.gpu,
                         self.paras.pin_memory, False, **self.config['data'])
        self.config['data']['corpus']['batch_size'] = prev_batch_size
        self.verbose(msg)

    def set_model(self):
        ''' Setup ASR model, losses, optimizer and beam decoder. '''
        # Model
        # NOTE(review): feat_dim/vocab_size are hard-coded here rather than
        # taken from the loaded dataset — confirm they match the checkpoint.
        self.feat_dim = 120
        self.vocab_size = 46
        init_adadelta = True
        # init_adadelta = self.config['hparas']['optimizer'] == 'Adadelta'
        self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta, **
                         self.src_config['model']).to(self.device)
        self.verbose(self.model.create_msg())

        if self.finetune_first > 0:
            # Restrict optimization to the first N encoder layers.
            names = ["encoder.layers.%d" % i for i in range(self.finetune_first)]
            model_paras = [{"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in names)]}]
        else:
            model_paras = [{'params': self.model.parameters()}]

        # Losses
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        # Note: zero_infinity=False is unstable?
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

        # Plug-ins
        # Fix: default emb_decoder to None so BeamDecoder construction below
        # and validate() do not hit AttributeError when the embedding
        # regularizer is disabled.
        self.emb_decoder = None
        self.emb_fuse = False
        self.emb_reg = ('emb' in self.config) and (
            self.config['emb']['enable'])
        if self.emb_reg:
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(
                self.tokenizer, self.model.dec_dim, **self.config['emb']).to(self.device)
            model_paras.append({'params': self.emb_decoder.parameters()})
            self.emb_fuse = self.emb_decoder.apply_fuse
            if self.emb_fuse:
                # Fused output is already log-probabilities.
                self.seq_loss = torch.nn.NLLLoss(ignore_index=0)
            self.verbose(self.emb_decoder.create_msg())

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.src_config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Enable AMP if needed
        self.enable_apex()

        # Automatically load pre-trained model if self.paras.load is given
        self.load_ckpt()
        # Beam decoder
        self.decoder = BeamDecoder(
            self.model, self.emb_decoder, **self.config['decode'])
        self.verbose(self.decoder.create_msg())
        # del self.model
        # del self.emb_decoder
        self.decoder.to(self.device)

    def exec(self):
        ''' Interactive loop: read wav filenames from stdin and print hypotheses.

        Type "exit" to quit.
        '''
        while True:
            try:
                filename = input("Input wav file name: ")
                if filename == "exit":
                    return
                feat, feat_len = self.dealer(filename)
                feat = feat.to(self.device)
                feat_len = feat_len.to(self.device)
                # Decode
                with torch.no_grad():
                    hyps = self.decoder(feat, feat_len)

                hyp_seqs = [hyp.outIndex for hyp in hyps]
                hyp_txts = [self.tokenizer.decode(hyp, ignore_repeat=self.ctc) for hyp in hyp_seqs]
                for txt in hyp_txts:
                    print(txt)
            # Fix: was a bare `except:` which also swallowed
            # KeyboardInterrupt/SystemExit; narrow to Exception so Ctrl-C
            # still exits the loop.
            except Exception:
                print("Invalid file")

    def recognize(self, filename):
        ''' Decode a single wav file and return the best hypothesis text,
        or the string "Invalid file" on any failure. '''
        try:
            feat, feat_len = self.dealer(filename)
            feat = feat.to(self.device)
            feat_len = feat_len.to(self.device)
            # Decode
            with torch.no_grad():
                hyps = self.decoder(feat, feat_len)

            hyp_seqs = [hyp.outIndex for hyp in hyps]
            hyp_txts = [self.tokenizer.decode(hyp, ignore_repeat=self.ctc) for hyp in hyp_seqs]
            return hyp_txts[0]
        except Exception as e:
            print(e)
            app.logger.debug(e)
            return "Invalid file"

    def fetch_finetune_data(self, filename, fixed_text):
        ''' Build a single (feat, feat_len, text, text_len) example from a wav
        file and its corrected transcript, moved to the solver device. '''
        feat, feat_len = self.dealer(filename)
        feat = feat.to(self.device)
        feat_len = feat_len.to(self.device)
        text = self.tokenizer.encode(fixed_text)
        text = torch.tensor(text).to(self.device)
        text_len = len(text)
        return [feat, feat_len, text, text_len]

    def merge_batch(self, main_batch, attach_batch):
        ''' Append a single finetune example (attach_batch) to a regular
        training batch, truncating or zero-padding it to the batch's max
        feature/text lengths. Returns the merged (feat, feat_len, txt, txt_len). '''
        max_feat_len = max(main_batch[1])
        max_text_len = max(main_batch[3])
        if attach_batch[0].shape[1] > max_feat_len:
            # reduce extra long example
            attach_batch[0] = attach_batch[0][:, :max_feat_len]
            attach_batch[1][0] = max_feat_len
        else:
            # pad to max_feat_len
            padding = torch.zeros(1, max_feat_len - attach_batch[0].shape[1], attach_batch[0].shape[2], dtype=attach_batch[0].dtype).to(self.device)
            attach_batch[0] = torch.cat([attach_batch[0], padding], dim=1)
        if attach_batch[2].shape[0] > max_text_len:
            attach_batch[2] = attach_batch[2][:max_text_len]
            # NOTE(review): this writes the truncated length into the MAIN
            # batch's txt_len[0]; it looks like it was meant to update the
            # attached example's length instead — confirm.
            main_batch[3][0] = max_text_len
        else:
            padding = torch.zeros(max_text_len - attach_batch[2].shape[0], dtype=attach_batch[2].dtype).to(self.device)
            # Fix: removed a bare `except: pdb.set_trace()` left over from
            # debugging — it would hang non-interactive runs; let any error
            # propagate instead.
            attach_batch[2] = torch.cat([attach_batch[2], padding], dim=0).unsqueeze(0)
        new_batch = (
            torch.cat([main_batch[0], attach_batch[0]], dim=0),
            torch.cat([main_batch[1], attach_batch[1]], dim=0),
            torch.cat([main_batch[2], attach_batch[2]], dim=0),
            torch.cat([main_batch[3], torch.tensor([attach_batch[3]]).to(self.device)], dim=0)
        )
        return new_batch

    def finetune(self, filename, fixed_text, max_step=5):
        ''' Run a few training steps, mixing one user-corrected example into
        each regular batch, then validate and return the summary string. '''
        # Load data for finetune
        self.verbose('Total training steps {}.'.format(
            human_format(max_step)))
        ctc_loss, att_loss, emb_loss = None, None, None
        n_epochs = 0
        accum_count = 0
        self.timer.set()
        step = 0
        for data in self.tr_set:
            # Pre-step : update tf_rate/lr_rate and do zero_grad
            if max_step == 0:
                break
            tf_rate = self.optimizer.pre_step(400000)
            total_loss = 0

            # Fetch data
            finetune_data = self.fetch_finetune_data(filename, fixed_text)
            main_batch = self.fetch_data(data)
            new_batch = self.merge_batch(main_batch, finetune_data)
            feat, feat_len, txt, txt_len = new_batch
            self.timer.cnt('rd')

            # Forward model
            # Note: txt should NOT start w/ <sos>
            ctc_output, encode_len, att_output, att_align, dec_state = \
                self.model(feat, feat_len, max(txt_len), tf_rate=tf_rate,
                            teacher=txt, get_dec_state=self.emb_reg)

            # Plugins
            if self.emb_reg:
                emb_loss, fuse_output = self.emb_decoder(
                    dec_state, att_output, label=txt)
                total_loss += self.emb_decoder.weight*emb_loss

            # Compute all objectives
            if ctc_output is not None:
                if self.paras.cudnn_ctc:
                    ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1),
                                                txt.to_sparse().values().to(device='cpu', dtype=torch.int32),
                                                [ctc_output.shape[1]] *
                                                len(ctc_output),
                                                txt_len.cpu().tolist())
                else:
                    ctc_loss = self.ctc_loss(ctc_output.transpose(
                        0, 1), txt, encode_len, txt_len)
                total_loss += ctc_loss*self.model.ctc_weight

            if att_output is not None:
                b, t, _ = att_output.shape
                att_output = fuse_output if self.emb_fuse else att_output
                att_loss = self.seq_loss(
                    att_output.contiguous().view(b*t, -1), txt.contiguous().view(-1))
                total_loss += att_loss*(1-self.model.ctc_weight)

            self.timer.cnt('fw')

            # Backprop
            grad_norm = self.backward(total_loss)
            step += 1

            # Logger
            self.progress('Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                        .format(total_loss.cpu().item(), grad_norm, self.timer.show()))
            self.write_log(
                'loss', {'tr_ctc': ctc_loss, 'tr_att': att_loss})
            self.write_log('emb_loss', {'tr': emb_loss})
            self.write_log('wer', {'tr_att': cal_er(self.tokenizer, att_output, txt),
                                'tr_ctc': cal_er(self.tokenizer, ctc_output, txt, ctc=True)})
            if self.emb_fuse:
                if self.emb_decoder.fuse_learnable:
                    self.write_log('fuse_lambda', {
                                'emb': self.emb_decoder.get_weight()})
                self.write_log(
                    'fuse_temp', {'temp': self.emb_decoder.get_temp()})

            # End of step
            # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
            torch.cuda.empty_cache()
            self.timer.set()
            # NOTE(review): `step > max_step` runs max_step+1 steps in total;
            # `>=` would match the advertised count — confirm before changing.
            if step > max_step:
                break
        ret = self.validate()
        self.log.close()
        return ret

    def validate(self):
        ''' Evaluate on the dev set, log example decodes, track best WER per
        head, and return a human-readable summary. Restores train mode. '''
        # Eval mode
        self.model.eval()
        if self.emb_decoder is not None:
            self.emb_decoder.eval()
        dev_wer = {'att': [], 'ctc': []}

        for i, data in enumerate(self.dv_set):
            self.progress('Valid step - {}/{}'.format(i+1, len(self.dv_set)))
            # Fetch data
            feat, feat_len, txt, txt_len = self.fetch_data(data)

            # Forward model
            with torch.no_grad():
                ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.model(feat, feat_len, int(max(txt_len)*self.DEV_STEP_RATIO),
                               emb_decoder=self.emb_decoder)

            dev_wer['att'].append(cal_er(self.tokenizer, att_output, txt))
            dev_wer['ctc'].append(cal_er(self.tokenizer, ctc_output, txt, ctc=True))

            # Show some example on tensorboard (middle batch only).
            if i == len(self.dv_set)//2:
                # Fix: inner index renamed from `i` to avoid shadowing the
                # enumerate counter.
                for idx in range(min(len(txt), self.DEV_N_EXAMPLE)):
                    self.write_log('true_text{}'.format(
                        idx), self.tokenizer.decode(txt[idx].tolist()))
                    if att_output is not None:
                        self.write_log('att_align{}'.format(idx), feat_to_fig(
                            att_align[idx, 0, :, :].cpu().detach()))
                        self.write_log('att_text{}'.format(idx), self.tokenizer.decode(
                            att_output[idx].argmax(dim=-1).tolist()))
                    if ctc_output is not None:
                        self.write_log('ctc_text{}'.format(idx), self.tokenizer.decode(ctc_output[idx].argmax(dim=-1).tolist(),
                                                                                       ignore_repeat=True))

        # Skip save model here
        # Ckpt if performance improves
        to_prints = []
        for task in ['att', 'ctc']:
            dev_wer[task] = sum(dev_wer[task]) / len(dev_wer[task])
            if dev_wer[task] < self.best_wer[task]:
                to_print = f"WER of {task}: {dev_wer[task]} < prev best ({self.best_wer[task]})"
                self.best_wer[task] = dev_wer[task]
            else:
                to_print = f"WER of {task}: {dev_wer[task]} >= prev best ({self.best_wer[task]})"
            print(to_print, flush=True)
            to_prints.append(to_print)
        #         self.save_checkpoint('best_{}.pth'.format(task), 'wer', dev_wer[task])
            self.write_log('wer', {'dv_'+task: dev_wer[task]})
        # self.save_checkpoint('latest.pth', 'wer', dev_wer['att'], show_msg=False)

        # Resume training
        self.model.train()
        if self.emb_decoder is not None:
            self.emb_decoder.train()
        return '\n'.join(to_prints)
예제 #22
0
def train(flags):
    """
    flags:
        saveto: str
        reload: store_true
        config_path: str
        pretrain_path: str, default=""
        model_name: str
        log_path: str

    Full NMT training entry point: sets up (multi-)GPU/distributed state,
    data iterators, model + combination critic (optionally with a teacher
    for knowledge distillation / ODC), optimizer, scheduler and moving
    average, then runs the training loop with gradient accumulation,
    periodic loss/BLEU validation, best-k checkpointing and early stopping.
    """

    # ================================================================================== #
    # Initialization for training on different devices
    # - CPU/GPU
    # - Single/Distributed
    Constants.USE_GPU = flags.use_gpu

    if flags.multi_gpu:
        dist.distributed_init(flags.shared_dir)
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        local_rank = dist.get_local_rank()
    else:
        # single-process fallback: behave as rank 0 of a world of size 1
        world_size = 1
        rank = 0
        local_rank = 0

    if Constants.USE_GPU:
        torch.cuda.set_device(local_rank)
        Constants.CURRENT_DEVICE = "cuda:{0}".format(local_rank)
    else:
        Constants.CURRENT_DEVICE = "cpu"

    # If not root_rank, close logging
    # else write log of training to file.
    if rank == 0:
        write_log_to_file(
            os.path.join(flags.log_path,
                         "%s.log" % time.strftime("%Y%m%d-%H%M%S")))
    else:
        close_logging()

    # ================================================================================== #
    # Parsing configuration files
    # - Load default settings
    # - Load pre-defined settings
    # - Load user-defined settings

    configs = prepare_configs(flags.config_path, flags.predefined_config)

    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    INFO(pretty_configs(configs))

    # use odc
    # NOTE(review): ODC appears to be an online-distillation-from-checkpoints
    # scheme (teacher refreshed from best/averaged checkpoints below) — confirm.
    if training_configs['use_odc'] is True:
        ave_best_k = check_odc_config(training_configs)
    else:
        ave_best_k = 0

    Constants.SEED = training_configs['seed']

    set_seed(Constants.SEED)

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Generate target dictionary
    vocab_src = Vocabulary.build_from_file(**data_configs['vocabularies'][0])
    vocab_tgt = Vocabulary.build_from_file(**data_configs['vocabularies'][1])

    # NOTE(review): special-token ids are taken from the *source* vocabulary
    # and used globally — assumes src/tgt vocabularies share these ids; verify.
    Constants.EOS = vocab_src.eos
    Constants.PAD = vocab_src.pad
    Constants.BOS = vocab_src.bos

    train_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['train_data'][0],
                        vocabulary=vocab_src,
                        max_len=data_configs['max_len'][0],
                        is_train_dataset=True),
        TextLineDataset(data_path=data_configs['train_data'][1],
                        vocabulary=vocab_tgt,
                        max_len=data_configs['max_len'][1],
                        is_train_dataset=True))

    valid_bitext_dataset = ZipDataset(
        TextLineDataset(
            data_path=data_configs['valid_data'][0],
            vocabulary=vocab_src,
            is_train_dataset=False,
        ),
        TextLineDataset(data_path=data_configs['valid_data'][1],
                        vocabulary=vocab_tgt,
                        is_train_dataset=False))

    training_iterator = DataIterator(
        dataset=train_bitext_dataset,
        batch_size=training_configs["batch_size"],
        use_bucket=training_configs['use_bucket'],
        buffer_size=training_configs['buffer_size'],
        batching_func=training_configs['batching_key'],
        world_size=world_size,
        rank=rank)

    valid_iterator = DataIterator(
        dataset=valid_bitext_dataset,
        batch_size=training_configs['valid_batch_size'],
        use_bucket=True,
        buffer_size=100000,
        numbering=True,
        world_size=world_size,
        rank=rank)

    bleu_scorer = SacreBLEUScorer(
        reference_path=data_configs["bleu_valid_reference"],
        num_refs=data_configs["num_refs"],
        lang_pair=data_configs["lang_pair"],
        sacrebleu_args=training_configs["bleu_valid_configs"]
        ['sacrebleu_args'],
        postprocess=training_configs["bleu_valid_configs"]['postprocess'])

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ================================ Begin ======================================== #
    # Build Model & Optimizer
    # We would do steps below on after another
    #     1. build models & criterion
    #     2. move models & criterion to gpu if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    #     5. build learning rate scheduler if needed
    #     6. load checkpoints if needed

    # 0. Initial

    lrate = optimizer_configs['learning_rate']
    model_collections = Collections()

    checkpoint_saver = Saver(
        save_prefix="{0}.ckpt".format(
            os.path.join(flags.saveto, flags.model_name)),
        num_max_keeping=training_configs['num_kept_checkpoints'])

    best_model_prefix = os.path.join(
        flags.saveto, flags.model_name + Constants.MY_BEST_MODEL_SUFFIX)

    best_k_saver = BestKSaver(
        save_prefix="{0}.best_k_ckpt".format(
            os.path.join(flags.saveto, flags.model_name)),
        num_max_keeping=training_configs['num_kept_best_k_checkpoints'])

    # 1. Build Model & Criterion
    INFO('Building model...')
    timer.tic()
    nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                            n_tgt_vocab=vocab_tgt.max_n_words,
                            padding_idx=vocab_src.pad,
                            vocab_src=vocab_src,
                            **model_configs)
    INFO(nmt_model)

    # build teacher model
    teacher_model, teacher_model_path = get_teacher_model(
        training_configs, model_configs, vocab_src, vocab_tgt, flags)

    # build critic
    critic = CombinationCriterion(model_configs['loss_configs'],
                                  padding_idx=vocab_tgt.pad,
                                  teacher=teacher_model)
    # INFO(critic)
    critic.INFO()

    # 2. Move to GPU
    if Constants.USE_GPU:
        nmt_model = nmt_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    load_pretrained_model(nmt_model,
                          flags.pretrain_path,
                          exclude_prefix=None,
                          device=Constants.CURRENT_DEVICE)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 4. Build optimizer
    INFO('Building Optimizer...')

    if not flags.multi_gpu:
        optim = Optimizer(name=optimizer_configs['optimizer'],
                          model=nmt_model,
                          lr=lrate,
                          grad_clip=optimizer_configs['grad_clip'],
                          optim_args=optimizer_configs['optimizer_params'],
                          update_cycle=training_configs['update_cycle'])
    else:
        optim = dist.DistributedOptimizer(
            name=optimizer_configs['optimizer'],
            model=nmt_model,
            lr=lrate,
            grad_clip=optimizer_configs['grad_clip'],
            optim_args=optimizer_configs['optimizer_params'],
            device_id=local_rank)

    # 5. Build scheduler for optimizer if needed
    scheduler = build_scheduler(
        schedule_method=optimizer_configs['schedule_method'],
        optimizer=optim,
        scheduler_configs=optimizer_configs['scheduler_configs'])

    # 6. build moving average
    ma = build_ma(training_configs, nmt_model.named_parameters())

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # Reload from latest checkpoint
    if flags.reload:
        checkpoint_saver.load_latest(model=nmt_model,
                                     optim=optim,
                                     lr_scheduler=scheduler,
                                     collections=model_collections,
                                     ma=ma,
                                     device=Constants.CURRENT_DEVICE)

    # broadcast parameters and optimizer states
    if world_size > 1:
        INFO("Broadcasting model parameters...")
        dist.broadcast_parameters(params=nmt_model.state_dict())
        INFO("Broadcasting optimizer states...")
        dist.broadcast_optimizer_state(optimizer=optim.optim)
        INFO('Done.')

    # ================================================================================== #
    # Prepare training

    # eidx: epoch counter; uidx: parameter-update counter. Both are restored
    # from the collections when reloading so training resumes where it left off.
    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [1])[-1]
    bad_count = model_collections.get_collection("bad_count", [0])[-1]
    oom_count = model_collections.get_collection("oom_count", [0])[-1]
    is_early_stop = model_collections.get_collection("is_early_stop", [
        False,
    ])[-1]
    teacher_patience = model_collections.get_collection(
        "teacher_patience", [training_configs['teacher_patience']])[-1]

    train_loss_meter = AverageMeter()
    train_loss_dict_meter = AverageMeterDict(critic.get_critic_name())
    sent_per_sec_meter = TimeMeter()
    tok_per_sec_meter = TimeMeter()

    # Gradient-accumulation state: losses/counts are accumulated over
    # `update_cycle` mini-batches before a single optimizer step.
    update_cycle = training_configs['update_cycle']
    grad_denom = 0
    train_loss = 0.0
    cum_n_words = 0
    train_loss_dict = dict()
    valid_loss = best_valid_loss = float('inf')

    if rank == 0:
        summary_writer = SummaryWriter(log_dir=flags.log_path)
    else:
        summary_writer = None

    sent_per_sec_meter.start()
    tok_per_sec_meter.start()

    INFO('Begin training...')

    while True:

        if summary_writer is not None:
            summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()

        if rank == 0:
            training_progress_bar = tqdm(desc=' - (Epc {}, Upd {}) '.format(
                eidx, uidx),
                                         total=len(training_iterator),
                                         unit="sents")
        else:
            training_progress_bar = None

        for batch in training_iter:

            seqs_x, seqs_y = batch

            batch_size = len(seqs_x)
            cum_n_words += sum(len(s) for s in seqs_y)

            try:
                # Prepare data
                x, y = prepare_data(seqs_x, seqs_y, cuda=Constants.USE_GPU)

                loss, loss_dict = compute_forward(
                    model=nmt_model,
                    critic=critic,
                    seqs_x=x,
                    seqs_y=y,
                    eval=False,
                    normalization=1.0,
                    norm_by_words=training_configs["norm_by_words"])

                update_cycle -= 1
                grad_denom += batch_size
                train_loss += loss
                train_loss_dict = add_dict_value(train_loss_dict, loss_dict)

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    # Skip the offending batch. Note update_cycle is NOT
                    # decremented here, so the failed batch simply does not
                    # count toward the current accumulation cycle.
                    print('| WARNING: ran out of memory, skipping batch')
                    oom_count += 1
                else:
                    raise e

            # When update_cycle becomes 0, it means end of one batch. Several things will be done:
            # - update parameters
            # - reset update_cycle and grad_denom, update uidx
            # - learning rate scheduling
            # - update moving average

            if update_cycle == 0:

                # 0. reduce variables
                if world_size > 1:
                    grad_denom = dist.all_reduce_py(grad_denom)
                    train_loss = dist.all_reduce_py(train_loss)
                    train_loss_dict = dist.all_reduce_py(train_loss_dict)
                    cum_n_words = dist.all_reduce_py(cum_n_words)

                # 1. update parameters
                optim.step(denom=grad_denom)
                optim.zero_grad()

                if training_progress_bar is not None:
                    training_progress_bar.update(grad_denom)
                    training_progress_bar.set_description(
                        ' - (Epc {}, Upd {}) '.format(eidx, uidx))

                    postfix_str = 'TrainLoss: {:.2f}, ValidLoss(best): {:.2f} ({:.2f}), '.format(
                        train_loss, valid_loss, best_valid_loss)
                    for critic_name, loss_value in train_loss_dict.items():
                        postfix_str += (critic_name +
                                        ': {:.2f}, ').format(loss_value)
                    training_progress_bar.set_postfix_str(postfix_str)

                # 2. learning rate scheduling
                if scheduler is not None and optimizer_configs[
                        "schedule_method"] != "loss":
                    scheduler.step(global_step=uidx)

                # 3. update moving average
                if ma is not None and eidx >= training_configs[
                        'moving_average_start_epoch']:
                    ma.step()

                # 4. update meters
                train_loss_meter.update(train_loss, grad_denom)
                train_loss_dict_meter.update(train_loss_dict, grad_denom)
                sent_per_sec_meter.update(grad_denom)
                tok_per_sec_meter.update(cum_n_words)

                # 5. reset accumulated variables, update uidx
                update_cycle = training_configs['update_cycle']
                grad_denom = 0
                uidx += 1
                cum_n_words = 0.0
                train_loss = 0.0
                train_loss_dict = dict()

            else:
                # Accumulation cycle not finished yet: skip the periodic
                # display/validation/saving sections below.
                continue

            # ================================================================================== #
            # Display some information
            if should_trigger_by_steps(
                    uidx, eidx, every_n_step=training_configs['disp_freq']):

                lrate = list(optim.get_lrate())[0]

                if summary_writer is not None:
                    summary_writer.add_scalar(
                        "Speed(sents/sec)",
                        scalar_value=sent_per_sec_meter.ave,
                        global_step=uidx)
                    summary_writer.add_scalar(
                        "Speed(words/sec)",
                        scalar_value=tok_per_sec_meter.ave,
                        global_step=uidx)
                    summary_writer.add_scalar(
                        "train_loss",
                        scalar_value=train_loss_meter.ave,
                        global_step=uidx)
                    # add loss for every critic
                    if flags.display_loss_detail:
                        combination_loss = train_loss_dict_meter.value
                        for key, value in combination_loss.items():
                            summary_writer.add_scalar(key,
                                                      scalar_value=value,
                                                      global_step=uidx)
                    summary_writer.add_scalar("lrate",
                                              scalar_value=lrate,
                                              global_step=uidx)
                    summary_writer.add_scalar("oom_count",
                                              scalar_value=oom_count,
                                              global_step=uidx)

                # Reset Meters
                sent_per_sec_meter.reset()
                tok_per_sec_meter.reset()
                train_loss_meter.reset()
                train_loss_dict_meter.reset()

            # ================================================================================== #
            # Loss Validation & Learning rate annealing
            if should_trigger_by_steps(
                    global_step=uidx,
                    n_epoch=eidx,
                    every_n_step=training_configs['loss_valid_freq'],
                    debug=flags.debug):
                # cache_parameters presumably snapshots/restores weights
                # around evaluation (e.g. to apply MA weights) — confirm.
                with cache_parameters(nmt_model):

                    valid_loss, valid_loss_dict = loss_evaluation(
                        model=nmt_model,
                        critic=critic,
                        valid_iterator=valid_iterator,
                        rank=rank,
                        world_size=world_size)

                if scheduler is not None and optimizer_configs[
                        "schedule_method"] == "loss":
                    scheduler.step(metric=valid_loss)

                model_collections.add_to_collection("history_losses",
                                                    valid_loss)

                min_history_loss = np.array(
                    model_collections.get_collection("history_losses")).min()
                best_valid_loss = min_history_loss

                if summary_writer is not None:
                    summary_writer.add_scalar("loss",
                                              valid_loss,
                                              global_step=uidx)
                    summary_writer.add_scalar("best_loss",
                                              min_history_loss,
                                              global_step=uidx)

            # ================================================================================== #
            # BLEU Validation & Early Stop
            if should_trigger_by_steps(
                    global_step=uidx,
                    n_epoch=eidx,
                    every_n_step=training_configs['bleu_valid_freq'],
                    min_step=training_configs['bleu_valid_warmup'],
                    debug=flags.debug):

                with cache_parameters(nmt_model):

                    valid_bleu = bleu_evaluation(
                        uidx=uidx,
                        valid_iterator=valid_iterator,
                        batch_size=training_configs["bleu_valid_batch_size"],
                        model=nmt_model,
                        bleu_scorer=bleu_scorer,
                        vocab_src=vocab_src,
                        vocab_tgt=vocab_tgt,
                        valid_dir=flags.valid_path,
                        max_steps=training_configs["bleu_valid_configs"]
                        ["max_steps"],
                        beam_size=training_configs["bleu_valid_configs"]
                        ["beam_size"],
                        alpha=training_configs["bleu_valid_configs"]["alpha"],
                        world_size=world_size,
                        rank=rank,
                    )

                model_collections.add_to_collection(key="history_bleus",
                                                    value=valid_bleu)

                best_valid_bleu = float(
                    np.array(model_collections.get_collection(
                        "history_bleus")).max())

                if summary_writer is not None:
                    summary_writer.add_scalar("bleu", valid_bleu, uidx)
                    summary_writer.add_scalar("best_bleu", best_valid_bleu,
                                              uidx)

                # If model get new best valid bleu score
                # NOTE: valid_bleu was just appended to history, so
                # best_valid_bleu already includes it; `>=` fires exactly
                # when the current score ties or sets the historical best.
                if valid_bleu >= best_valid_bleu:
                    bad_count = 0

                    if is_early_stop is False:
                        if rank == 0:
                            # 1. save the best model
                            torch.save(nmt_model.state_dict(),
                                       best_model_prefix + ".final")

                else:
                    bad_count += 1

                    # At least one epoch should be traversed
                    if bad_count >= training_configs[
                            'early_stop_patience'] and eidx > 0:
                        is_early_stop = True
                        WARN("Early Stop!")
                        # NOTE(review): exit(0) terminates only this process;
                        # in multi-GPU runs other ranks may hang waiting on
                        # collectives — confirm intended behavior.
                        exit(0)

                if rank == 0:
                    best_k_saver.save(global_step=uidx,
                                      metric=valid_bleu,
                                      model=nmt_model,
                                      optim=optim,
                                      lr_scheduler=scheduler,
                                      collections=model_collections,
                                      ma=ma)

                # ODC
                if training_configs['use_odc'] is True:
                    if valid_bleu >= best_valid_bleu:
                        # NOTE(review): this `pass` is redundant — the code
                        # below it is the actual new-best handling.
                        pass

                        # choose method to generate teachers from checkpoints
                        # - best
                        # - ave_k_best
                        # - ma

                        if training_configs['teacher_choice'] == 'ma':
                            teacher_params = ma.export_ma_params()
                        elif training_configs['teacher_choice'] == 'best':
                            teacher_params = nmt_model.state_dict()
                        elif "ave_best" in training_configs['teacher_choice']:
                            if best_k_saver.num_saved >= ave_best_k:
                                teacher_params = average_checkpoints(
                                    best_k_saver.get_all_ckpt_path()
                                    [-ave_best_k:])
                            else:
                                # not enough checkpoints yet: fall back to
                                # the current model weights
                                teacher_params = nmt_model.state_dict()
                        else:
                            raise ValueError(
                                "can not support teacher choice %s" %
                                training_configs['teacher_choice'])
                        torch.save(teacher_params, teacher_model_path)
                        del teacher_params
                        teacher_patience = 0
                        critic.set_use_KD(False)
                    else:
                        # no improvement: once patience runs out, refresh the
                        # teacher from the saved snapshot and re-enable KD
                        teacher_patience += 1
                        if teacher_patience >= training_configs[
                                'teacher_refresh_warmup']:
                            teacher_params = torch.load(
                                teacher_model_path,
                                map_location=Constants.CURRENT_DEVICE)
                            teacher_model.load_state_dict(teacher_params,
                                                          strict=False)
                            del teacher_params
                            critic.reset_teacher(teacher_model)
                            critic.set_use_KD(True)

                if summary_writer is not None:
                    summary_writer.add_scalar("bad_count", bad_count, uidx)

                info_str = "{0} Loss: {1:.2f} BLEU: {2:.2f} lrate: {3:6f} patience: {4} ".format(
                    uidx, valid_loss, valid_bleu, lrate, bad_count)
                for key, value in valid_loss_dict.items():
                    info_str += (key + ': {0:.2f} '.format(value))
                INFO(info_str)

            # ================================================================================== #
            # Saving checkpoints
            if should_trigger_by_steps(
                    uidx,
                    eidx,
                    every_n_step=training_configs['save_freq'],
                    debug=flags.debug):
                model_collections.add_to_collection("uidx", uidx)
                model_collections.add_to_collection("eidx", eidx)
                model_collections.add_to_collection("bad_count", bad_count)
                model_collections.add_to_collection("teacher_patience",
                                                    teacher_patience)
                if not is_early_stop:
                    if rank == 0:
                        checkpoint_saver.save(global_step=uidx,
                                              model=nmt_model,
                                              optim=optim,
                                              lr_scheduler=scheduler,
                                              collections=model_collections,
                                              ma=ma)

        if training_progress_bar is not None:
            training_progress_bar.close()

        eidx += 1
        if eidx > training_configs["max_epochs"]:
            break
예제 #23
0
class Translate_Env(object):
    """
    wrap translate environment for multiple agents
    env needs parallel data to evaluate bleu_degredation
    state of the env is defined as the batched src labels and current target index
    environment yields rewards based on discriminator and finally by sentence-level BLEU
    :return: translation multiple sentences and return changed bleu
    """
    def __init__(
        self,
        attack_configs,
        discriminator_configs,
        src_vocab,
        trg_vocab,
        data_iterator,
        save_to,
        device="cpu",
    ):
        """
        initiate translation environments, needs a discriminator and translator
        :param attack_configs: attack configures dictionary
        :param discriminator_configs: discriminator model/optimizer configures
        :param src_vocab: source-side vocabulary
        :param trg_vocab: target-side vocabulary
        :param save_to: discriminator models
        :param data_iterator: use to provide data for environment initiate
        the directory of the src sentences
        :param device: (string) devices to allocate variables("cpu", "cuda:*")
        default as cpu
        """
        self.data_iterator = data_iterator
        discriminator_model_configs = discriminator_configs[
            "discriminator_model_configs"]
        discriminator_optim_configs = discriminator_configs[
            "discriminator_optimizer_configs"]
        # the frozen victim NMT model under attack: config + weights paths
        self.victim_config_path = attack_configs["victim_configs"]
        self.victim_model_path = attack_configs["victim_model"]
        # determine devices
        self.device = device
        with open(self.victim_config_path.strip()) as v_f:
            print("open victim configs...%s" % self.victim_config_path)
            # NOTE(review): yaml.load without an explicit Loader is deprecated
            # and unsafe on untrusted files; prefer yaml.safe_load — confirm
            # the config needs no custom tags before changing.
            victim_configs = yaml.load(v_f)

        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab
        # victim translator is used for inference only
        self.translate_model = build_translate_model(victim_configs,
                                                     self.victim_model_path,
                                                     vocab_src=self.src_vocab,
                                                     vocab_trg=self.trg_vocab,
                                                     device=self.device)
        self.translate_model.eval()
        # near-vocabulary tables used to generate word perturbations
        # (emit_as_id=True: tables are keyed/valued by token ids)
        self.w2p, self.w2vocab = load_or_extract_near_vocab(
            config_path=self.victim_config_path,
            model_path=self.victim_model_path,
            init_perturb_rate=attack_configs["init_perturb_rate"],
            save_to=os.path.join(save_to, "near_vocab"),
            save_to_full=os.path.join(save_to, "full_near_vocab"),
            top_reserve=12,
            emit_as_id=True)
        #########################################################
        # to update discriminator
        # discriminator_data_configs = attack_configs["discriminator_data_configs"]
        self.discriminator = TransDiscriminator(
            n_src_words=self.src_vocab.max_n_words,
            n_trg_words=self.trg_vocab.max_n_words,
            **discriminator_model_configs)
        self.discriminator.to(self.device)

        # warm-start discriminator embeddings from the victim model
        load_embedding(self.discriminator,
                       model_path=self.victim_model_path,
                       device=self.device)

        self.optim_D = Optimizer(
            name=discriminator_optim_configs["optimizer"],
            model=self.discriminator,
            lr=discriminator_optim_configs["learning_rate"],
            grad_clip=discriminator_optim_configs["grad_clip"],
            optim_args=discriminator_optim_configs["optimizer_params"])
        self.criterion_D = nn.CrossEntropyLoss(
        )  # used in discriminator updates
        self.scheduler_D = None  # default as None
        if discriminator_optim_configs['schedule_method'] is not None:
            if discriminator_optim_configs['schedule_method'] == "loss":
                self.scheduler_D = ReduceOnPlateauScheduler(
                    optimizer=self.optim_D,
                    **discriminator_optim_configs["scheduler_configs"])
            elif discriminator_optim_configs['schedule_method'] == "noam":
                self.scheduler_D = NoamScheduler(
                    optimizer=self.optim_D,
                    **discriminator_optim_configs['scheduler_configs'])
            elif discriminator_optim_configs["schedule_method"] == "rsqrt":
                self.scheduler_D = RsqrtScheduler(
                    optimizer=self.optim_D,
                    **discriminator_optim_configs["scheduler_configs"])
            else:
                WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".
                     format(discriminator_optim_configs['schedule_method']))
        ############################################################
        # populate the first batch of environment state
        self._init_state()
        self.adversarial = attack_configs[
            "adversarial"]  # adversarial sample or reinforced samples
        self.r_s_weight = attack_configs["r_s_weight"]
        self.r_d_weight = attack_configs["r_d_weight"]

    def _init_state(self):
        """Reset the environment to the start of a fresh batch.

        Pulls the next numbered batch from the data iterator, pads the
        source/target sides, translates the untouched sources, and records
        per-sentence baseline BLEU so later rewards can measure degradation.
        (index starts from the first label, i.e. no BOS/EOS position.)
        :return: padded source ids as a numpy array (the initial env state)
        """
        self.index = 1
        self.origin_bleu = []
        batch = next(self.data_iterator)
        assert len(batch) == 3, "must be provided with line index (check for data_iterator)"
        # training mode: the parallel target side is available
        _, seqs_x, self.seqs_y = batch
        self.sent_len = [len(src) for src in seqs_x]  # for terminal signals
        self.terminal_signal = [0 for _ in seqs_x]  # for terminal signals
        self.padded_src, self.padded_trg = self.prepare_data(
            seqs_x=seqs_x, seqs_y=self.seqs_y)
        self.origin_result = self.translate()
        # baseline BLEU of the unperturbed translation, per sentence
        for pos, ref_sent in enumerate(self.seqs_y):
            self.origin_bleu.append(
                bleu.sentence_bleu(references=[ref_sent],
                                   hypothesis=self.origin_result[pos],
                                   emulate_multibleu=True))
        return self.padded_src.cpu().numpy()

    def get_src_vocab(self):
        return self.src_vocab

    def reset(self):
        """Gym-style reset: re-initialize state and return the initial observation."""
        return self._init_state()

    def reset_data_iter(
            self, data_iter):  # reset data iterator with provided iterator
        self.data_iterator = data_iter
        return

    def reset_discriminator(self):
        """Re-initialize discriminator weights, then reload victim embeddings."""
        self.discriminator.reset()
        load_embedding(
            self.discriminator,
            model_path=self.victim_model_path,
            device=self.device,
        )

    def prepare_D_data(self, attacker, seqs_x, seqs_y, batch_first=True):
        """
        using global_attacker to generate training data for discriminator
        :param attacker: attacker model used to perturb the source batch
        :param seqs_x: list of sources (token-id lists, without BOS/EOS)
        :param seqs_y: corresponding targets (token-id lists, without BOS/EOS)
        :param batch_first: first dimension of seqs be batch
        :return: perturbed seqs_x tensor, seqs_y tensor, flags tensor
                 (flags marks which sentences were attacked)
        """
        def _np_pad_batch_2D(samples, pad, batch_first=True):
            # pack seqs into tensor with pads
            batch_size = len(samples)
            sizes = [len(s) for s in samples]
            max_size = max(sizes)
            x_np = np.full((batch_size, max_size),
                           fill_value=pad,
                           dtype='int64')
            for ii in range(batch_size):
                x_np[ii, :sizes[ii]] = samples[ii]
            if batch_first is False:
                x_np = np.transpose(x_np, [1, 0])
            x = torch.tensor(x_np).to(self.device)
            return x

        seqs_x = list(map(lambda s: [BOS] + s + [EOS], seqs_x))

        x = _np_pad_batch_2D(samples=seqs_x, pad=PAD, batch_first=batch_first)
        # training mode attack: randomly choose half of the seqs to attack
        attacker.eval()
        x, flags = attacker.seq_attack(x, self.w2vocab, training_mode=True)

        seqs_y = list(map(lambda s: [BOS] + s + [EOS], seqs_y))

        y = _np_pad_batch_2D(seqs_y, pad=PAD, batch_first=batch_first)
        # Fix: Tensor.to() is NOT in-place — the previous code discarded the
        # result, leaving `flags` potentially on a different device than x/y.
        flags = flags.to(self.device)

        return x, y, flags

    def prepare_data(self, seqs_x, seqs_y=None, batch_first=True):
        """
        Wrap sequences with BOS/EOS and pack them into padded id tensors.

        :param seqs_x: list of source token-id sequences
        :param seqs_y: optional list of target token-id sequences
        :param batch_first: if True tensors are [batch, time], else [time, batch]
        :return: padded source tensor, plus padded target tensor when
                 seqs_y is given
        """
        def _pad_to_tensor(samples):
            # Right-pad every sequence with PAD up to the batch maximum.
            lengths = [len(s) for s in samples]
            padded = np.full((len(samples), max(lengths)),
                             fill_value=PAD,
                             dtype='int64')
            for row, sample in enumerate(samples):
                padded[row, :len(sample)] = sample
            if not batch_first:
                padded = padded.T
            return torch.tensor(padded).to(self.device)

        x = _pad_to_tensor([[BOS] + s + [EOS] for s in seqs_x])
        if seqs_y is None:
            return x
        y = _pad_to_tensor([[BOS] + s + [EOS] for s in seqs_y])
        return x, y

    def acc_validation(self, attacker):
        """
        Estimate discriminator accuracy on 5 freshly drawn batches.

        :param attacker: attacker used to create perturbed/clean examples
        :return: classification accuracy as a python float
        """
        self.discriminator.eval()
        correct = 0
        total = 0
        for _ in range(5):
            try:
                batch = next(self.data_iterator)
            except StopIteration:
                batch = next(self.data_iterator)
            _, seqs_x, seqs_y = batch
            x, y, flags = self.prepare_D_data(attacker, seqs_x, seqs_y)
            # set components to evaluation mode
            self.discriminator.eval()
            with torch.no_grad():
                predictions = self.discriminator(x, y).argmax(dim=-1)
                correct += torch.eq(predictions, flags).sum()
                total += predictions.size(0)
        return (correct.float() / total).item()

    def compute_D_forward(self, seqs_x, seqs_y, gold_flags, evaluate=False):
        """
        Run the discriminator and its criterion on one batch.

        :param seqs_x: padded source batch
        :param seqs_y: padded target batch
        :param gold_flags: 1 for perturbed sequences, 0 otherwise
        :param evaluate: if True, compute the loss without gradients
        :return: scalar loss value (python float)
        """
        if evaluate:
            # evaluation mode: dropout off, no autograd graph
            self.discriminator.eval()
            self.criterion_D.eval()
            with torch.no_grad():
                class_probs = self.discriminator(seqs_x, seqs_y)
                loss = self.criterion_D(class_probs, gold_flags)
            return loss.item()

        # training mode: dropout on, backprop through the loss
        self.discriminator.train()
        self.criterion_D.train()
        with torch.enable_grad():
            class_probs = self.discriminator(seqs_x, seqs_y)
            loss = self.criterion_D(class_probs, gold_flags)
        torch.autograd.backward(loss)
        return loss.item()

    def update_discriminator(self,
                             attacker_model,
                             base_steps=0,
                             min_update_steps=20,
                             max_update_steps=300,
                             accuracy_bound=0.8,
                             summary_writer=None):
        """
        update discriminator
        :param attacker_model: attacker to generate training data for discriminator
        :param base_steps: used for saving
        :param min_update_steps: (integer) minimum update steps,
                    also the discriminator evaluate steps
        :param max_update_steps: (integer) maximum update steps
        :param accuracy_bound: (float) update until accuracy reaches the bound
                    (or max_update_steps)
        :param summary_writer: used to log discriminator learning information
        :return: steps and test accuracy as trust region
        """
        INFO("update discriminator")
        # NOTE(review): gradients are zeroed only once here (and again on OOM),
        # so they accumulate across update steps — confirm this is intended
        # rather than a missing per-step optim_D.zero_grad().
        self.optim_D.zero_grad()
        attacker_model = attacker_model.to(self.device)
        step = 0
        while True:
            # Draw the next batch; on exhaustion retry once (assumes the
            # iterator has been refreshed elsewhere — TODO confirm).
            try:
                batch = next(self.data_iterator)
            except StopIteration:
                batch = next(self.data_iterator)
            # update the discriminator
            step += 1
            if self.scheduler_D is not None:
                # override learning rate in self.optim_D
                self.scheduler_D.step(global_step=step)
            _, seqs_x, seqs_y = batch  # returned tensor type of the data
            try:
                # Build perturbed/clean pairs, compute loss (backward happens
                # inside compute_D_forward), then apply the optimizer step.
                x, y, flags = self.prepare_D_data(attacker_model, seqs_x,
                                                  seqs_y)
                loss = self.compute_D_forward(seqs_x=x,
                                              seqs_y=y,
                                              gold_flags=flags)
                self.optim_D.step()
                print("discriminator loss:", loss)
            except RuntimeError as e:
                # Skip OOM batches instead of aborting the whole update loop.
                if "out of memory" in str(e):
                    print("WARNING: out of memory, skipping batch")
                    self.optim_D.zero_grad()
                else:
                    raise e

            # valid for accuracy / check for break (if any)
            if step % min_update_steps == 0:
                acc = self.acc_validation(attacker_model)
                print("discriminator acc: %2f" % acc)
                summary_writer.add_scalar("discriminator",
                                          scalar_value=acc,
                                          global_step=base_steps + step)
                # Early exit once the discriminator is good enough.
                if accuracy_bound and acc > accuracy_bound:
                    INFO("discriminator reached training acc bound, updated.")
                    return base_steps + step, acc

            if step > max_update_steps:
                acc = self.acc_validation(attacker_model)
                print("discriminator acc: %2f" % acc)
                INFO("Reach maximum discriminator update. Finished.")
                return base_steps + step, acc  # stop updates

    def translate(self, inputs=None):
        """
        Translate (perturbed) source sequences with the victim model.

        :param inputs: padded source batch; if None, translate the perturbed
                       sequences stored in the environment (self.padded_src)
        :return: list of translation results — the top beam hypothesis per
                 source sentence, with PAD ids stripped
        """
        if inputs is None:
            inputs = self.padded_src
        # FIX: removed leftover debug `print(inputs.device)`.
        with torch.no_grad():
            perturbed_results = beam_search(
                self.translate_model,
                beam_size=5,
                max_steps=150,
                src_seqs=inputs,
                alpha=-1.0,
            )
        perturbed_results = perturbed_results.cpu().numpy().tolist()
        # only use the top result (first beam) for every source sentence
        result = []
        for sent in perturbed_results:
            sent = [wid for wid in sent[0] if wid != PAD]
            result.append(sent)

        return result

    def step(self, actions):
        """
        step update for the environment: finally update self.index
        this is defined as inference of the environments
        :param actions: whether to perturb (action distribution vector
                    in shape [batch, 1]) on current index
                 *  result of torch.argmax(actor_output_distribution, dim=-1)
                    test: actions = actor_output_distribution.argmax(dim=-1)
                    or train: actions = actor.output_distribution.multinomial(dim=-1)
                    can be on cpu or cuda.
        :return: updated states / reward / terminal signal from the environments
                 reward (float), terminal (boolean)
        """
        with torch.no_grad():
            terminal = False  # default is not terminated
            batch_size = actions.shape[0]
            reward = 0
            inputs = self.padded_src[:, self.index]
            inputs_mask = ~inputs.eq(PAD)
            target_of_step = []
            # modification on sequences (state): draw one random replacement
            # candidate for every batch element at the current index
            for batch_index in range(batch_size):
                word_id = inputs[batch_index]
                target_word_id = self.w2vocab[word_id.item()][np.random.choice(
                    len(self.w2vocab[word_id.item()]), 1)[0]]
                target_of_step += [target_word_id]
            if self.device != "cpu" and not actions.is_cuda:
                actions = actions.to(self.device)
            # BUG FIX: the PAD mask was previously applied only inside the
            # device-transfer branch above, so PAD positions could be
            # perturbed whenever `actions` already lived on the right device.
            actions = actions * inputs_mask  # PAD is neglect
            # override the state src with random choice from candidates
            self.padded_src[:, self.index] *= (1 - actions)
            replacement = torch.tensor(target_of_step)
            replacement = replacement.to(self.device)
            self.padded_src[:, self.index] += replacement * actions

            # update sequences' pointer
            self.index += 1
            """ run discriminator check for terminal signals, update local terminal list
            True: all sentences in the batch is defined as false by self.discriminator
            False: otherwise
            """
            # get discriminator distribution on the current src state
            discriminate_out = self.discriminator(self.padded_src,
                                                  self.padded_trg)
            # NOTE(review): `or` is python truthiness, not an element-wise
            # union — once self.terminal_signal is a non-empty list it is
            # never replaced. Confirm this is the intended update rule.
            self.terminal_signal = self.terminal_signal or discriminate_out.detach(
            ).argmax(dim=-1).cpu().numpy().tolist()
            signal = (1 - discriminate_out.argmax(dim=-1)).sum().item()
            if signal == 0 or self.index == self.padded_src.shape[1] - 1:
                terminal = True  # no need to further explore or reached EOS for all src
            """ collect rewards on the current state
            """
            # calculate intermediate survival rewards
            if not terminal:
                # survival rewards for survived objects
                distribution, discriminate_index = discriminate_out.max(dim=-1)
                distribution = distribution.detach().cpu().numpy()
                discriminate_index = (1 - discriminate_index).cpu().numpy()
                survival_value = distribution * discriminate_index * (
                    1 - np.array(self.terminal_signal))
                reward += survival_value.sum() * self.r_s_weight
            else:  # only penalty for overall intermediate termination
                reward = -1 * batch_size

            # only check for finished relative BLEU degradation when survival on the last label
            if self.index == self.padded_src.shape[1] - 1:
                # re-tokenize ignore the original UNK for victim model
                inputs = self.padded_src.cpu().numpy().tolist()
                new_inputs = []
                for indices in inputs:
                    # remove EOS, BOS, PAD
                    new_line = [
                        word_id for word_id in indices
                        if word_id not in [EOS, BOS, PAD]
                    ]
                    new_line = self.src_vocab.ids2sent(new_line)
                    if not hasattr(self.src_vocab.tokenizer, "bpe"):
                        new_line = new_line.strip().split()
                    else:
                        # re-apply BPE segmentation, leaving UNK tokens intact
                        new_token = []
                        for w in new_line.strip().split():
                            if w != self.src_vocab.id2token(UNK):
                                new_token.append(
                                    self.src_vocab.tokenizer.bpe.segment_word(
                                        w))
                            else:
                                new_token.append([w])
                        new_line = sum(new_token, [])
                    new_line = [self.src_vocab.token2id(t) for t in new_line]
                    new_inputs.append(new_line)
                # translate the re-tokenized, perturbed padded_src
                perturbed_result = self.translate(
                    self.prepare_data(seqs_x=new_inputs, ))
                # calculate final BLEU degradation:
                episodic_rewards = []
                for i, sent in enumerate(self.seqs_y):
                    # sentence is still surviving
                    if self.index >= self.sent_len[
                            i] - 1 and self.terminal_signal[i] == 0:
                        if self.origin_bleu[i] == 0:
                            # here we want to minimize noise from original bad cases
                            relative_degraded_value = 0
                        else:
                            relative_degraded_value = (
                                self.origin_bleu[i] - bleu.sentence_bleu(
                                    references=[sent],
                                    hypothesis=perturbed_result[i],
                                    emulate_multibleu=True))
                            relative_degraded_value /= self.origin_bleu[i]
                        if self.adversarial:
                            episodic_rewards.append(relative_degraded_value)
                        else:
                            episodic_rewards.append(-relative_degraded_value)
                    else:
                        episodic_rewards.append(0.0)
                reward += sum(episodic_rewards) * self.r_d_weight

            reward = reward / batch_size

        return self.padded_src.cpu().numpy(), reward, terminal,
예제 #24
0
    def _init_local_optims(self, rephraser_optimizer_configs):
        """ actor, critic, alpha optimizers and lr scheduler if necessary
        rephraser_optimizer_configs:
            optimizer: "adafactor"
            learning_rate: 0.01
            grad_clip: -1.0
            optimizer_params: ~
            schedule_method: rsqrt
            scheduler_configs:
                d_model: *dim
                warmup_steps: 100
        """
        if rephraser_optimizer_configs is None:
            # No config given: run without local optimizers/schedulers.
            self.actor_optimizer = None
            self.critic_optimizer = None
            self.log_alpha_optimizer = None
            # self.actor_icm_optimizer = None
            self.actor_scheduler = None
            self.critic_scheduler = None
            return

        def _build_optimizer(model):
            # One optimizer per component, sharing the same hyper-parameters.
            return Optimizer(
                name=rephraser_optimizer_configs["optimizer"],
                model=model,
                lr=rephraser_optimizer_configs["learning_rate"],
                grad_clip=rephraser_optimizer_configs["grad_clip"],
                optim_args=rephraser_optimizer_configs["optimizer_params"])

        self.actor_optimizer = _build_optimizer(self.actor)
        self.critic_optimizer = _build_optimizer(self.critic)
        # hardcoded entropy weight updates and icm updates
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                    lr=1e-4,
                                                    betas=(0.9, 0.999))
        # self.actor_icm_optimizer = torch.optim.Adam(self.actor.icm.parameters(), lr=1e-3, )

        # Build scheduler for optimizer if needed
        self.actor_scheduler = None
        self.critic_scheduler = None
        schedule_method = rephraser_optimizer_configs['schedule_method']
        if schedule_method is not None:
            scheduler_classes = {
                "loss": ReduceOnPlateauScheduler,
                "noam": NoamScheduler,
                "rsqrt": RsqrtScheduler,
            }
            scheduler_cls = scheduler_classes.get(schedule_method)
            if scheduler_cls is None:
                WARN(
                    "Unknown scheduler name {0}. Do not use lr_scheduling."
                    .format(schedule_method))
            else:
                scheduler_configs = rephraser_optimizer_configs[
                    "scheduler_configs"]
                self.actor_scheduler = scheduler_cls(
                    optimizer=self.actor_optimizer, **scheduler_configs)
                self.critic_scheduler = scheduler_cls(
                    optimizer=self.critic_optimizer, **scheduler_configs)
예제 #25
0
class Solver(BaseSolver):
    ''' Solver for training'''
    def __init__(self, config, paras, mode):
        """Construct the training solver and cache curriculum/validation
        settings from the hyper-parameter config."""
        super().__init__(config, paras, mode)

        hparas = self.config['hparas']
        # Curriculum learning affects data loader
        self.curriculum = hparas['curriculum']
        self.val_mode = hparas['val_mode'].lower()
        # Report PER instead of WER when validation mode asks for it.
        self.WER = 'per' if self.val_mode == 'per' else 'wer'

    def fetch_data(self, data, train=False):
        ''' Move data to device and compute text seq. length'''
        # feat: B x T x D (or a list of raw waveforms when an s3prl
        # upstream is configured — see below)
        _, feat, feat_len, txt = data

        if self.paras.upstream is not None:
            # feat is raw waveform; features are extracted on the fly by the
            # s3prl upstream model.
            # Deterministic runs keep extraction on CPU — TODO confirm why.
            device = 'cpu' if self.paras.deterministic else self.device
            self.upstream.to(device)
            self.specaug.to(device)

            def to_device(feat):
                # Move each waveform in the batch to the extraction device.
                return [f.to(device) for f in feat]

            def extract_feature(feat):
                # Run the upstream extractor; apply SpecAugment only during
                # training, when augmentation is enabled, and when the
                # upstream itself is not already an augmented variant.
                feat = self.upstream(to_device(feat))
                if train and self.config['data']['audio'][
                        'augment'] and 'aug' not in self.paras.upstream:
                    feat = [self.specaug(f) for f in feat]
                return feat

            # If the first utterance is too long, halve the batch to avoid
            # OOM (mirrors the HALF_BATCHSIZE_AUDIO_LEN convention).
            if HALF_BATCHSIZE_AUDIO_LEN < 3500 and train:
                first_len = extract_feature(feat[:1])[0].shape[0]
                if first_len > HALF_BATCHSIZE_AUDIO_LEN:
                    feat = feat[::2]
                    txt = txt[::2]

            if self.paras.upstream_trainable:
                # Fine-tune the upstream: keep gradients.
                self.upstream.train()
                feat = extract_feature(feat)
            else:
                # Frozen upstream: extract without building a graph.
                with torch.no_grad():
                    self.upstream.eval()
                    feat = extract_feature(feat)

            # Re-pad the extracted features and texts into batch tensors.
            feat_len = torch.LongTensor([len(f) for f in feat])
            feat = pad_sequence(feat, batch_first=True)
            txt = pad_sequence(txt, batch_first=True)

        feat = feat.to(self.device)
        feat_len = feat_len.to(self.device)
        txt = txt.to(self.device)
        # Text length = number of non-padding (non-zero) tokens.
        txt_len = torch.sum(txt != 0, dim=-1)

        return feat, feat_len, txt, txt_len

    def load_data(self):
        ''' Load data for training/validation, store tokenizer and input/output shape'''
        if self.paras.upstream is not None:
            # Raw-waveform pipeline: features come from an s3prl upstream
            # model loaded via torch.hub; feat_dim is taken from that model.
            print(f'[Solver] - using S3PRL {self.paras.upstream}')
            self.tr_set, self.dv_set, self.vocab_size, self.tokenizer, msg = \
                            load_wav_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                                        self.curriculum>0,
                                        **self.config['data'])
            self.upstream = torch.hub.load(
                's3prl/s3prl',
                self.paras.upstream,
                feature_selection=self.paras.upstream_feature_selection,
                refresh=self.paras.upstream_refresh,
                ckpt=self.paras.upstream_ckpt,
                force_reload=True,
            )
            self.feat_dim = self.upstream.get_output_dim()
            self.specaug = Augment()
        else:
            # Pre-extracted acoustic-feature pipeline.
            self.tr_set, self.dv_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \
                         load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                                      self.curriculum>0,
                                      **self.config['data'])
        self.verbose(msg)

        # Dev set names (one name per dev split when several are configured)
        self.dv_names = []
        if type(self.dv_set) is list:
            for ds in self.config['data']['corpus']['dev_split']:
                self.dv_names.append(ds[0])
        else:
            self.dv_names = self.config['data']['corpus']['dev_split'][0]

        # Logger settings: track the best WER per decoder head (att/ctc) and
        # per dev set; 3.0 is a sentinel "worse than any real WER" initial value.
        if type(self.dv_names) is str:
            self.best_wer = {
                'att': {
                    self.dv_names: 3.0
                },
                'ctc': {
                    self.dv_names: 3.0
                }
            }
        else:
            self.best_wer = {'att': {}, 'ctc': {}}
            for name in self.dv_names:
                self.best_wer['att'][name] = 3.0
                self.best_wer['ctc'][name] = 3.0

    def set_model(self):
        ''' Setup ASR model and optimizer '''
        # Model
        #print(self.feat_dim) #160
        # Half the corpus batch size is passed into the model — presumably to
        # match the half-batch path in fetch_data; TODO confirm.
        batch_size = self.config['data']['corpus']['batch_size'] // 2
        self.model = ASR(self.feat_dim, self.vocab_size, batch_size,
                         **self.config['model']).to(self.device)

        self.verbose(self.model.create_msg())
        model_paras = [{'params': self.model.parameters()}]

        # Losses
        '''label smoothing'''
        if self.config['hparas']['label_smoothing']:
            # NOTE(review): 31 looks like a hard-coded vocab/class count —
            # confirm it matches self.vocab_size for this corpus.
            self.seq_loss = LabelSmoothingLoss(31, 0.1)
            print('[INFO]  using label smoothing. ')
        else:
            self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        self.ctc_loss = torch.nn.CTCLoss(
            blank=0,
            zero_infinity=False)  # Note: zero_infinity=False is unstable?

        # Plug-ins: optional embedding regularizer / fusion decoder
        self.emb_fuse = False
        self.emb_reg = ('emb'
                        in self.config) and (self.config['emb']['enable'])
        if self.emb_reg:
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(
                self.tokenizer, self.model.dec_dim,
                **self.config['emb']).to(self.device)
            model_paras.append({'params': self.emb_decoder.parameters()})
            self.emb_fuse = self.emb_decoder.apply_fuse
            if self.emb_fuse:
                # Fused outputs are log-probabilities, so switch to NLL.
                self.seq_loss = torch.nn.NLLLoss(ignore_index=0)
            self.verbose(self.emb_decoder.create_msg())

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.config['hparas'])
        self.lr_scheduler = self.optimizer.lr_scheduler
        self.verbose(self.optimizer.create_msg())

        # Enable AMP if needed
        self.enable_apex()

        # Transfer Learning
        if self.transfer_learning:
            self.verbose('Apply transfer learning: ')
            self.verbose('      Train encoder layers: {}'.format(
                self.train_enc))
            self.verbose('      Train decoder:        {}'.format(
                self.train_dec))
            self.verbose('      Save name:            {}'.format(
                self.save_name))

        # Automatically load pre-trained model if self.paras.load is given
        self.load_ckpt()

    def exec(self):
        ''' Training End-to-end ASR system '''
        self.verbose('Total training steps {}.'.format(
            human_format(self.max_step)))
        if self.transfer_learning:
            # Freeze the layers transfer learning keeps fixed.
            self.model.encoder.fix_layers(self.fix_enc)
            if self.fix_dec and self.model.enable_att:
                self.model.decoder.fix_layers()
            if self.fix_dec and self.model.enable_ctc:
                self.model.fix_ctc_layer()

        self.n_epochs = 0
        self.timer.set()
        '''early stopping for ctc '''
        self.early_stoping = self.config['hparas']['early_stopping']
        stop_epoch = 10
        batch_size = self.config['data']['corpus']['batch_size']
        stop_step = len(self.tr_set) * stop_epoch // batch_size

        while self.step < self.max_step:
            ctc_loss, att_loss, emb_loss = None, None, None
            # Renew dataloader to enable random sampling

            # BUG FIX: this previously read the bare name `n_epochs`, which
            # raised NameError once curriculum learning actually ended; the
            # epoch counter lives on self.
            if self.curriculum > 0 and self.n_epochs == self.curriculum:
                self.verbose(
                    'Curriculum learning ends after {} epochs, starting random sampling.'
                    .format(self.n_epochs))
                self.tr_set, _, _, _, _, _ = \
                         load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                                      False, **self.config['data'])

            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad
                tf_rate = self.optimizer.pre_step(self.step)
                total_loss = 0

                # Fetch data
                feat, feat_len, txt, txt_len = self.fetch_data(data,
                                                               train=True)

                self.timer.cnt('rd')
                # Forward model
                # Note: txt should NOT start w/ <sos>
                ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.model( feat, feat_len, max(txt_len), tf_rate=tf_rate,
                                    teacher=txt, get_dec_state=self.emb_reg)
                # Clear not used objects
                del att_align

                # Plugins
                if self.emb_reg:
                    emb_loss, fuse_output = self.emb_decoder(dec_state,
                                                             att_output,
                                                             label=txt)
                    total_loss += self.emb_decoder.weight * emb_loss
                else:
                    del dec_state

                ''' early stopping ctc: drop the CTC head after stop_step '''
                if self.early_stoping:
                    if self.step > stop_step:
                        ctc_output = None
                        self.model.ctc_weight = 0

                # Compute all objectives
                if ctc_output is not None:
                    if self.paras.cudnn_ctc:
                        ctc_loss = self.ctc_loss(
                            ctc_output.transpose(0, 1),
                            txt.to_sparse().values().to(device='cpu',
                                                        dtype=torch.int32),
                            [ctc_output.shape[1]] * len(ctc_output),
                            txt_len.cpu().tolist())
                    else:
                        ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1),
                                                 txt, encode_len, txt_len)
                    total_loss += ctc_loss * self.model.ctc_weight
                    del encode_len

                if att_output is not None:
                    b, t, _ = att_output.shape
                    att_output = fuse_output if self.emb_fuse else att_output
                    att_loss = self.seq_loss(att_output.view(b * t, -1),
                                             txt.view(-1))
                    total_loss += att_loss * (1 - self.model.ctc_weight)

                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(total_loss)

                self.step += 1

                # Logger
                if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                    self.progress('Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'\
                            .format(total_loss.cpu().item(),grad_norm,self.timer.show()))
                    self.write_log('emb_loss', {'tr': emb_loss})
                    if att_output is not None:
                        self.write_log('loss', {'tr_att': att_loss})
                        self.write_log(self.WER, {
                            'tr_att':
                            cal_er(self.tokenizer, att_output, txt)
                        })
                        self.write_log(
                            'cer', {
                                'tr_att':
                                cal_er(self.tokenizer,
                                       att_output,
                                       txt,
                                       mode='cer')
                            })
                    if ctc_output is not None:
                        self.write_log('loss', {'tr_ctc': ctc_loss})
                        self.write_log(
                            self.WER, {
                                'tr_ctc':
                                cal_er(
                                    self.tokenizer, ctc_output, txt, ctc=True)
                            })
                        self.write_log(
                            'cer', {
                                'tr_ctc':
                                cal_er(self.tokenizer,
                                       ctc_output,
                                       txt,
                                       mode='cer',
                                       ctc=True)
                            })
                        self.write_log(
                            'ctc_text_train',
                            self.tokenizer.decode(
                                ctc_output[0].argmax(dim=-1).tolist(),
                                ignore_repeat=True))

                    if self.emb_fuse:
                        if self.emb_decoder.fuse_learnable:
                            self.write_log(
                                'fuse_lambda',
                                {'emb': self.emb_decoder.get_weight()})
                        self.write_log('fuse_temp',
                                       {'temp': self.emb_decoder.get_temp()})

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    if type(self.dv_set) is list:
                        for dv_id in range(len(self.dv_set)):
                            self.validate(self.dv_set[dv_id],
                                          self.dv_names[dv_id])
                    else:
                        self.validate(self.dv_set, self.dv_names)

                # Track epochs by counting steps per epoch.
                if self.step % (len(self.tr_set) //
                                batch_size) == 0:  # one epoch
                    print('Have finished epoch: ', self.n_epochs)
                    self.n_epochs += 1

                if self.lr_scheduler is None:
                    # Hand-rolled schedule: after 100k steps, decay the lr by
                    # 0.85 every 2000 steps.
                    lr = self.optimizer.opt.param_groups[0]['lr']

                    if self.step == 1:
                        print(
                            '[INFO]    using lr schedular defined by Daniel, init lr = ',
                            lr)

                    if self.step > 99999 and self.step % 2000 == 0:
                        lr = lr * 0.85
                        for param_group in self.optimizer.opt.param_groups:
                            param_group['lr'] = lr
                        print('[INFO]     at step:', self.step)
                        print('[INFO]   lr reduce to', lr)

                # End of step: empty cuda cache every step to limit fragmentation
                # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
                torch.cuda.empty_cache()
                self.timer.set()
                if self.step > self.max_step: break

        self.log.close()
        print('[INFO] Finished training after', human_format(self.max_step),
              'steps.')

    def validate(self, _dv_set, _name):
        """Evaluate the model on one dev set and checkpoint on improvement.

        Runs the model in eval mode over ``_dv_set``, accumulates WER/CER and
        the metric selected by ``self.val_mode`` for the attention and/or CTC
        heads, logs a few decoded examples from the middle batch, saves
        ``best_*``/``last_*`` checkpoints, and finally restores training mode
        (re-freezing transferred layers when transfer learning is active).

        :param _dv_set: dev-set dataloader (iterable of batches)
        :param _name: dev-set name, used in log tags and checkpoint filenames
        """
        # Eval mode
        self.model.eval()
        if self.emb_decoder is not None:
            self.emb_decoder.eval()
        dev_wer = {'att': [], 'ctc': []}
        dev_cer = {'att': [], 'ctc': []}
        dev_er = {'att': [], 'ctc': []}

        for i, data in enumerate(_dv_set):
            self.progress('Valid step - {}/{}'.format(i + 1, len(_dv_set)))
            # Fetch data
            feat, feat_len, txt, txt_len = self.fetch_data(data)

            # Forward model (no gradients needed for validation)
            with torch.no_grad():
                ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.model(feat, feat_len,
                               int(max(txt_len) * self.DEV_STEP_RATIO),
                               emb_decoder=self.emb_decoder)

            if att_output is not None:
                dev_wer['att'].append(
                    cal_er(self.tokenizer, att_output, txt, mode='wer'))
                dev_cer['att'].append(
                    cal_er(self.tokenizer, att_output, txt, mode='cer'))
                dev_er['att'].append(
                    cal_er(self.tokenizer, att_output, txt,
                           mode=self.val_mode))
            if ctc_output is not None:
                dev_wer['ctc'].append(
                    cal_er(self.tokenizer, ctc_output, txt,
                           mode='wer', ctc=True))
                dev_cer['ctc'].append(
                    cal_er(self.tokenizer, ctc_output, txt,
                           mode='cer', ctc=True))
                dev_er['ctc'].append(
                    cal_er(self.tokenizer, ctc_output, txt,
                           mode=self.val_mode, ctc=True))

            # Show some examples on tensorboard, taken from the middle batch.
            # BUGFIX: the inner loop previously reused `i` as its loop
            # variable, shadowing the enumerate index above; use a dedicated
            # name for the example index instead.
            if i == len(_dv_set) // 2:
                for ex in range(min(len(txt), self.DEV_N_EXAMPLE)):
                    if self.step == 1:
                        self.write_log('true_text_{}_{}'.format(_name, ex),
                                       self.tokenizer.decode(txt[ex].tolist()))
                    if att_output is not None:
                        self.write_log(
                            'att_align_{}_{}'.format(_name, ex),
                            feat_to_fig(att_align[ex, 0, :, :].cpu().detach()))
                        self.write_log(
                            'att_text_{}_{}'.format(_name, ex),
                            self.tokenizer.decode(
                                att_output[ex].argmax(dim=-1).tolist()))
                    if ctc_output is not None:
                        self.write_log(
                            'ctc_text_{}_{}'.format(_name, ex),
                            self.tokenizer.decode(
                                ctc_output[ex].argmax(dim=-1).tolist(),
                                ignore_repeat=True))

        # Ckpt if performance improves
        tasks = []
        if len(dev_er['att']) > 0:
            tasks.append('att')
        if len(dev_er['ctc']) > 0:
            tasks.append('ctc')

        for task in tasks:
            # Average the per-batch error rates for this task
            dev_er[task] = sum(dev_er[task]) / len(dev_er[task])
            dev_wer[task] = sum(dev_wer[task]) / len(dev_wer[task])
            dev_cer[task] = sum(dev_cer[task]) / len(dev_cer[task])
            if dev_er[task] < self.best_wer[task][_name]:
                self.best_wer[task][_name] = dev_er[task]
                self.save_checkpoint(
                    'best_{}_{}.pth'.format(
                        task, _name +
                        (self.save_name if self.transfer_learning else '')),
                    self.val_mode, dev_er[task], _name)
            if self.step >= self.max_step:
                self.save_checkpoint(
                    'last_{}_{}.pth'.format(
                        task, _name +
                        (self.save_name if self.transfer_learning else '')),
                    self.val_mode, dev_er[task], _name)
            self.write_log(self.WER,
                           {'dv_' + task + '_' + _name.lower(): dev_wer[task]})
            self.write_log('cer',
                           {'dv_' + task + '_' + _name.lower(): dev_cer[task]})

        # Resume training
        self.model.train()
        if self.transfer_learning:
            # Re-freeze transferred layers after the eval()/train() toggle
            self.model.encoder.fix_layers(self.fix_enc)
            if self.fix_dec and self.model.enable_att:
                self.model.decoder.fix_layers()
            if self.fix_dec and self.model.enable_ctc:
                self.model.fix_ctc_layer()

        if self.emb_decoder is not None:
            self.emb_decoder.train()
# Example #26
class Translate_Env(object):
    """
    wrap translate environment for multiple agents
    env needs parallel data to evaluate final bleu improvement

    stores the states as [current src embeddings, index], yields rewards at each step
    environment yields rewards based on scorer and finally by sentence-level BLEU
    :return: translation multiple sentences and return changed bleu
    """
    def __init__(self, reinforce_configs,
                 annunciator_configs,
                 src_vocab, trg_vocab,
                 data_iterator,
                 save_to,
                 device="cpu",
                 ):
        """
        initiate translation environments, needs a Scorer and translator
        :param reinforce_configs: attack configures dictionary
        :param annunciator_configs: discriminator or scorer configs(provide survive signals)
        :param src_vocab: vocabulary object of the source language
        :param trg_vocab: vocabulary object of the target language
        :param save_to: path to save the model
        :param data_iterator: use to provide data for environment initiate
        the directory of the src sentences
        :param device: (string) devices to allocate variables("cpu", "cuda:*")
        default as cpu
        """
        # environment devices
        self.device = device
        self.data_iterator = data_iterator
        scorer_model_configs = annunciator_configs["scorer_model_configs"]
        # discriminator_model_configs = annunciator_configs["discriminator_model_configs"]
        annunciator_optim_configs = annunciator_configs["annunciator_optimizer_configs"]

        # load the victim (attacked) translation model's configuration file
        victim_config_path = reinforce_configs["victim_configs"]
        victim_model_path = reinforce_configs["victim_model"]
        with open(victim_config_path.strip()) as v_f:
            INFO("env open victim configs at %s" % victim_config_path)
            victim_configs = yaml.load(v_f, Loader=yaml.FullLoader)

        # to extract the embedding as representation
        # *vocab and *emb will provide psudo-reinforced embedding to train annunciator
        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab
        # translation model for BLEU(take src_embs as inputs) and corresponding embedding layers
        self.src_emb, self.trg_emb, self.translate_model = build_translate_model(
            victim_configs, victim_model_path,
            vocab_src=self.src_vocab, vocab_trg=self.trg_vocab,
            device=self.device)

        self.max_roll_out_step = victim_configs["data_configs"]["max_len"][0]
        self.src_emb.eval()  # source language embeddings
        self.trg_emb.eval()  # target language embeddings
        self.translate_model.eval()

        # the epsilon range used for action space when perturbation
        _, _, self.limit_dist = load_or_extract_near_vocab(
            config_path=victim_config_path, model_path=victim_model_path,
            init_perturb_rate=reinforce_configs["init_perturb_rate"],
            save_to=os.path.join(save_to, "near_vocab"),
            save_to_full=os.path.join(save_to, "full_near_vocab"),
            top_reserve=12, emit_as_id=True)

        #########################################################
        # scorer(an Annunciator object) provides intrinsic step rewards
        self.annunciator = TransScorer(
            victim_configs, victim_model_path, self.trg_emb,
            **scorer_model_configs)
        self.annunciator.to(self.device)
        # # discriminator(an Annunciator object) provides intrisic step rewards and terminal signal
        # self.discriminator = TransDiscriminator(
        #     victim_configs, victim_model_path,
        #     **discriminator_model_configs)
        # self.discriminator.to(self.device)
        # Annunciator update configs
        self.acc_bound = annunciator_configs["acc_bound"]
        self.mse_bound = annunciator_configs["mse_bound"]
        self.min_update_steps = annunciator_configs["valid_freq"]
        self.max_update_steps = annunciator_configs["annunciator_update_steps"]
        # the optimizer and schedule used for Annunciator update.
        self.optim_A = Optimizer(
            name=annunciator_optim_configs["optimizer"],
            model=self.annunciator,
            lr=annunciator_optim_configs["learning_rate"],
            grad_clip=annunciator_optim_configs["grad_clip"],
            optim_args=annunciator_optim_configs["optimizer_params"])

        self.scheduler_A = None  # default as None
        if annunciator_optim_configs['schedule_method'] is not None:
            if annunciator_optim_configs['schedule_method'] == "loss":
                self.scheduler_A = ReduceOnPlateauScheduler(optimizer=self.optim_A,
                                                            **annunciator_optim_configs["scheduler_configs"])
            elif annunciator_optim_configs['schedule_method'] == "noam":
                self.scheduler_A = NoamScheduler(optimizer=self.optim_A,
                                                 **annunciator_optim_configs['scheduler_configs'])
            elif annunciator_optim_configs["schedule_method"] == "rsqrt":
                self.scheduler_A = RsqrtScheduler(optimizer=self.optim_A,
                                                  **annunciator_optim_configs["scheduler_configs"])
            else:
                WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".format(
                    annunciator_optim_configs['schedule_method']))
        self.criterion_A = nn.CrossEntropyLoss()
        ############################################################
        self.adversarial = reinforce_configs["adversarial"]  # adversarial or reinforce as learning objects
        self.r_s_weight = reinforce_configs["r_s_weight"]  # weight on the per-step survival reward (see step())
        self.r_i_weight = reinforce_configs["r_i_weight"]  # weight on punishment/episodic BLEU reward (see step())

    def _init_state(self, rephraser=None):
        """
        initiate batched sentences / origin_bleu / index (start from first label, no BOS/EOS)
        the initial state of the environment. (applied on the env's device)
        :param rephraser: optional agent used to pre-perturb half of the batch
        for self-improving learning; None disables the pre-perturbation
        :return: env states (the src, index)
        """
        self.index = 1  # step index for perturbation (1 skips the BOS position)
        self.origin_bleu = []  # saving origin BLEU

        batch = next(self.data_iterator)
        assert len(batch) == 3, "must be provided with line index (check for data_iterator)"

        # training, parallel trg is provided for evaluation (src grouped by similar length)
        _, seqs_x, self.seqs_y = batch
        self.sent_len = [len(x) for x in seqs_x]
        self.survival_signals = np.array([1] * len(seqs_x))  # the survival signals, 1 when true.

        # for reinforce inputs(embedding level).
        padded_src, padded_trg = self.prepare_data(
            seqs_x=seqs_x, seqs_y=self.seqs_y)
        self.x_emb = self.src_emb(padded_src).detach()  # float
        self.y_emb = self.trg_emb(padded_trg).detach()
        self.x_pad_indicator = padded_src.detach().eq(PAD)  # byte indicating PAD tokens
        self.y_pad_indicator = padded_trg.detach().eq(PAD)

        # randomly choose half of the sequence and perturbed by given agent
        # for self learning (rephraser can be on the cpu())
        if rephraser is not None:
            # self.x_emb, mask_to_UNK = rephraser.random_seq_perturb(
            #     self.x_emb, self.x_pad_indicator,
            #     half_mode=True, rand_act=False, enable_UNK=False)
            # self.x_emb = self.x_emb.to(self.device)
            # mask_to_UNK = mask_to_UNK.to(self.device)
            # # print("x_emb shape:", self.x_emb.shape, "mask_to_UNK shape:", mask_to_UNK.shape)
            # self.x_emb = self.x_emb*(1.-mask_to_UNK.float().unsqueeze(dim=2)) + \
            #              self.src_emb((UNK * mask_to_UNK).long())
            # self.x_emb = self.x_emb.detach()
            _, self.x_emb, _ = rephraser.random_seq_perturb(
                self.x_emb, self.x_pad_indicator, half_mode=True,
                rand_act=False)
            self.x_emb = self.x_emb.detach()

        # print(self.x_mask.shape, self.x_emb.shape)
        # translate the (possibly pre-perturbed) batch and record per-sentence
        # baseline BLEU used later for the episodic reward in step()
        self.origin_result = self.translate()
        # calculate BLEU scores for the top candidate
        for index, sent_t in enumerate(self.seqs_y):
            bleu_t = bleu.sentence_bleu(references=[sent_t],
                                        hypothesis=self.origin_result[index],
                                        emulate_multibleu=True)
            self.origin_bleu.append(bleu_t)

        INFO("initialize env on: %s"%self.x_emb.device)
        return self.x_emb.cpu().numpy()

    def get_src_vocab(self):
        # accessor for the source vocabulary used by this environment
        return self.src_vocab

    def reset(self, rephraser=None):
        """
        when the steps are exhausted.
        :param rephraser: rephraser is default None for no self-improving learning
        :return: reset environments' embedding
        """
        return self._init_state(rephraser)

    def reset_data_iter(self, data_iter):  # reset data iterator with provided iterator
        self.data_iterator = data_iter
        return

    def reset_annunciator(self):
        # a backup, would be deprecated
        self.annunciator.reset()

    def prepare_A_data(self, agent,
                       seqs_x, seqs_y,
                       batch_first=True,
                       half_mode=True,
                       rand_act=True):
        """
        use the current rephraser to generate data for Annunciator training
        perturbation will be applied to a random sequence step.
        (perturb all the former steps as the origin_emb, and perturb one more step as
        the perturbed_emb)
        such process will rephrase the entire batch.
        :param agent: prepare the data for scorer training (actor and critic)
        :param seqs_x: list of sources
        :param seqs_y: list of targets
        :param batch_first: first dimension of seqs be batch
        :param half_mode: if True, only a random (Bernoulli 0.5) half of the
        batch keeps its perturbation; flags mark which sequences were perturbed
        :param rand_act: sample the actions based on rephraser outputs
        :return: origin_x_emb, perturbed_x_emb, y_emb, x_mask, y_mask, flags
        """
        def _np_pad_batch_2D(samples, pad, batch_first=True):
            # pack seqs into tensor with pads
            batch_size = len(samples)
            sizes = [len(s) for s in samples]
            max_size = max(sizes)
            x_np = np.full((batch_size, max_size), fill_value=pad, dtype='int64')
            for ii in range(batch_size):
                x_np[ii, :sizes[ii]] = samples[ii]
            if batch_first is False:
                x_np = np.transpose(x_np, [1, 0])
            x = torch.tensor(x_np).to(self.device)
            return x
        seqs_x = list(map(lambda s: [BOS] + s + [EOS], seqs_x))
        x = _np_pad_batch_2D(samples=seqs_x, pad=PAD,
                             batch_first=batch_first)
        x_emb = self.src_emb(x).detach().to(self.device)
        x_pad_indicator = x.detach().eq(PAD).to(self.device)

        # # mere actor rollout
        # origin_x_emb, perturbed_x_emb, flags = rephraser.random_seq_perturb(
        #     x_emb, x_pad_indicator,
        #     half_mode=True, rand_act=rand_act)

        # actor rollout w/ critic's restriction
        with torch.no_grad():
            agent.actor.eval()
            agent.critic.eval()
            batch_size, max_seq_len = x_pad_indicator.shape
            perturbed_x_emb = x_emb.detach().clone()
            x_mask = 1 - x_pad_indicator.int()
            # walk every interior position; each step's action is applied only
            # where the critic's value is positive and the token is not PAD
            for t in range(1, max_seq_len-1):
                former_emb = perturbed_x_emb
                input_emb = former_emb[:, t-1:t+2, :]
                if rand_act:
                    actions, _ = agent.actor.sample_normal(
                        x_emb=former_emb, x_pad_indicator=x_pad_indicator,
                        label_emb=input_emb, reparamization=False)
                else:
                    mu, _ = agent.actor.forward(
                        x_emb=former_emb, x_pad_indicator=x_pad_indicator,
                        label_emb=input_emb)
                    actions = mu * agent.actor.action_range
                # actions shape [batch, emb_dim]
                critique = agent.critic(
                    x_emb=former_emb, x_pad_indicator=x_pad_indicator,
                    label_emb=input_emb, action=actions)
                # actions_masks shape [batch, 1]
                actions_mask = critique.gt(0).int() * x_mask[:, t].unsqueeze(dim=1)
                # mask unnecessary actions
                perturbed_x_emb[:,t,:] += actions * actions_mask

            flags = x_emb.new_ones(batch_size)
            if half_mode:
                flags = torch.bernoulli(0.5 * flags).to(x_emb.device)
            perturbed_x_emb = perturbed_x_emb * flags.unsqueeze(dim=1). unsqueeze(dim=2) \
                              + x_emb * (1-flags).unsqueeze(dim=1).unsqueeze(dim=2)

        origin_x_emb = x_emb
        seqs_y = list(map(lambda s: [BOS] + s + [EOS], seqs_y))
        y = _np_pad_batch_2D(seqs_y, pad=PAD,
                             batch_first=batch_first)
        y_emb = self.trg_emb(y).detach().to(self.device)
        y_pad_indicator = y.detach().eq(PAD).to(self.device)

        # NOTE(review): the result of detach().to() is discarded, so this line
        # is a no-op as written — presumably `perturbed_x_emb = ...` was
        # intended (harmless here since the tensor was built under no_grad);
        # confirm before relying on the detachment.
        perturbed_x_emb.detach().to(self.device)
        return origin_x_emb, perturbed_x_emb, y_emb, x_pad_indicator, y_pad_indicator, flags.long()

    def prepare_data(self, seqs_x, seqs_y=None, batch_first=True):
        """
        prepare the batched, padded data with BOS and EOS for translation.
        used in initialization.
        :param seqs_x: list of source token-id sequences
        :param seqs_y: optional list of target token-id sequences
        :param batch_first: first dimension of the returned tensors is batch
        Returns: padded data matrices (batch_size, max_seq_len)
        """
        def _np_pad_batch_2D(samples, pad, batch_first=True):
            # pack variable-length id sequences into one padded int64 tensor
            batch_size = len(samples)
            sizes = [len(s) for s in samples]
            max_size = max(sizes)
            x_np = np.full((batch_size, max_size), fill_value=pad, dtype='int64')
            for ii in range(batch_size):
                x_np[ii, :sizes[ii]] = samples[ii]
            if batch_first is False:
                x_np = np.transpose(x_np, [1, 0])
            x = torch.tensor(x_np).to(self.device)
            return x

        seqs_x = list(map(lambda s: [BOS] + s + [EOS], seqs_x))
        x = _np_pad_batch_2D(samples=seqs_x, pad=PAD,
                             batch_first=batch_first)
        if seqs_y is None:
            return x
        seqs_y = list(map(lambda s: [BOS] + s + [EOS], seqs_y))
        y = _np_pad_batch_2D(seqs_y, pad=PAD,
                             batch_first=batch_first)
        return x, y

    def ratio_validation(self, agent, overall_contrast=True):
        """
        validate the mse of the environments scorer for the given rephraser
        used for checkpoints and other checks
        :param agent: generates the data for validation.
        :param overall_contrast: if True return the origin/(origin+perturbed)
        density ratio; otherwise return the mean perturbed density score
        :return: the mse of the current scorer in environment.
        """
        # set victim encoder and scorer to evaluation mode
        self.annunciator.eval()
        # for i in range(5):
        # NOTE(review): retrying next() on the same exhausted iterator raises
        # StopIteration again unless data_iterator auto-resets — verify its
        # semantics.
        try:
            batch = next(self.data_iterator)
        except StopIteration:
            batch = next(self.data_iterator)
        seq_nums, seqs_x, seqs_y = batch

        origin_x_emb, perturbed_x_emb, y_emb, x_mask, y_mask, flags = self.prepare_A_data(
            agent, seqs_x, seqs_y, half_mode=False, rand_act=False)
        origin_density_score = self.annunciator.get_density_score(
            origin_x_emb, x_mask, seqs_y)
        perturbed_density_score = self.annunciator.get_density_score(
            perturbed_x_emb, x_mask, seqs_y)
        density_score = origin_density_score/(origin_density_score+perturbed_density_score)
        if overall_contrast:
            return density_score.mean().item()
        else:
            return perturbed_density_score.mean().item()

    def acc_validation(self, agent):
        """
        validate the acc of the environments discriminator by a given rephraser
        used for checkpoints
        :param agent generates data for validation
        :return the accuracy of the discriminator to evaluation mode
        """
        self.annunciator.eval()
        acc = 0
        sample_count = 0
        # accumulate accuracy over 5 fresh batches
        for i in range(5):
            try:
                batch = next(self.data_iterator)
            except StopIteration:
                batch = next(self.data_iterator)
            seq_nums, seqs_x, seqs_y = batch
            origin_x_emb, perturbed_x_emb, y_emb, x_pad_indicator, y_pad_indicator, flags = \
                self.prepare_A_data(agent, seqs_x, seqs_y, half_mode=True)
            with torch.no_grad():
                preds = self.annunciator(perturbed_x_emb, x_pad_indicator,
                                         y_emb, y_pad_indicator).argmax(dim=-1)
                acc += torch.eq(preds, flags).sum()
                sample_count += preds.shape[0]
        acc = acc.float() / sample_count
        return acc.item()

    # def compute_P_forward(self,
    #                       origin_x_emb, perturbed_x_emb, x_mask,
    #                       evaluate=False):
    #     """
    #     process the victim encoder embedding and get CE loss
    #     :param origin_x_emb: float tensor, input embeddings of input tokens
    #     :param perturbed_x_emb: float tensor, perturbed inputs embeddings
    #     :param x_mask: byte tensor, mask of the input tokens
    #     :return: loss value
    #     """
    #     if not evaluate:
    #         # set components to training mode(dropout layers)
    #         self.scorer.train()
    #         with torch.enable_grad():
    #             loss = self.scorer(origin_x_emb, perturbed_x_emb, x_mask).mean()
    #         torch.autograd.backward(loss)
    #         return loss.item()
    #     else:
    #         # set components to evaluation mode(dropout layers)
    #         self.scorer.eval()
    #         with torch.enable_grad():
    #             loss = self.scorer(origin_x_emb, perturbed_x_emb, x_mask).mean()
    #     return loss.item()

    def compute_A_forward(self, x_emb, y_emb, x_pad_indicator, y_pad_indicator, gold_flags,
                          evaluate=False):
        """get loss according to criterion
        :param gold_flags=1 if perturbed, otherwise 0
        :param evaluate: False during training mode
        :return loss value
        """
        if not evaluate:
            # set components to training mode(dropout layers)
            self.annunciator.train()
            self.criterion_A.train()
            with torch.enable_grad():
                class_probs = self.annunciator(
                    x_emb, x_pad_indicator,
                    y_emb, y_pad_indicator)
                loss = self.criterion_A(class_probs, gold_flags)
            # backward here; the caller is expected to run the optimizer step
            torch.autograd.backward(loss)
            return loss.item()
        else:
            # set components to evaluation mode(dropout layers)
            self.annunciator.eval()
            self.criterion_A.eval()
            with torch.no_grad():
                class_probs = self.annunciator(
                    x_emb, x_pad_indicator,
                    y_emb, y_pad_indicator)
                loss = self.criterion_A(class_probs, gold_flags)
        return loss.item()

    def update_annunciator(self,
                           agent,
                           base_steps=0,
                           min_update_steps=1,
                           max_update_steps=300,
                           accuracy_bound=0.8,
                           overall_update_weight=0.5,
                           summary_writer=None):
        """
        update discriminator using given rephraser
        :param agent: AC agent to generate training data for discriminator
        :param base_steps: used for saving
        :param min_update_steps: (integer) minimum update steps,
                    also the discriminator evaluate steps
        :param max_update_steps: (integer) maximum update steps
        :param accuracy_bound: (float) update until accuracy reaches the bound
                    (or max_update_steps)
        :param overall_update_weight: weight forwarded to the annunciator loss
        :param summary_writer: used to log discriminator learning information
        :return: steps and test accuracy as trust region
        """
        INFO("update annunciator")
        self.optim_A.zero_grad()
        agent.to(self.device)
        step = 0
        while True:
            try:
                batch = next(self.data_iterator)
            except StopIteration:
                batch = next(self.data_iterator)
            # update the discriminator
            step += 1
            if self.scheduler_A is not None:
                # override learning rate in self.optim_D
                self.scheduler_A.step(global_step=step)
            _, seqs_x, seqs_y = batch  # returned tensor type of the data
            try:
                x_emb, perturbed_x_emb, y_emb, x_pad_indicator, y_pad_indicator, flags = \
                    self.prepare_A_data(agent, seqs_x, seqs_y, half_mode=False, rand_act=True)
                # NOTE(review): here the annunciator is called with
                # (origin_emb, perturbed_emb, pad_indicator, seqs_y, weight)
                # and returns a loss, while compute_A_forward/acc_validation
                # call it with 4 args returning class scores — presumably
                # TransScorer.forward supports both; verify.
                loss = self.annunciator(x_emb, perturbed_x_emb, x_pad_indicator, seqs_y, overall_update_weight)
                # for name, p in self.annunciator.named_parameters():
                #     if "weight" in name:
                #         loss += torch.norm(p, 2)  # with l2-norm against overfitting
                torch.autograd.backward(loss)
                self.optim_A.step()
                print("annunciator loss:", loss)
            except RuntimeError as e:
                if "out of memory" in str(e):
                    print("WARNING: out of memory, skipping batch")
                    self.optim_A.zero_grad()
                else:
                    raise e

            # valid for accuracy / check for break (if any)
            if step % min_update_steps == 0:
                perturbed_density = self.ratio_validation(agent, overall_contrast=False)
                overall_density = self.ratio_validation(agent)
                if summary_writer is not None:
                    summary_writer.add_scalar("a_contrast_ratio", scalar_value=overall_density, global_step=base_steps+step)
                    summary_writer.add_scalar("a_ratio_src", scalar_value=perturbed_density, global_step=base_steps+step)
                print("overall density: %2f" % overall_density)
                if accuracy_bound and overall_density > accuracy_bound:
                    INFO("annunciator reached training bound, updated")
                    return base_steps+step, overall_density

            if step > max_update_steps:
                overall_density = self.ratio_validation(agent)
                perturbed_density = self.ratio_validation(agent, overall_contrast=False)
                print("overall density: %2f" % overall_density)
                INFO("Reach maximum annunciator update. Finished.")
                return base_steps+step, overall_density   # stop updates


    def translate(self, x_emb=None, x_mask=None):
        """
        translate by given embedding
        :param x_emb: if None, translate embeddings stored in the environments
        :param x_mask: input mask paired with embedding
        :return: list of translation results
        """
        if x_emb is None:  # original translation with original embedding
            x_emb = self.x_emb
            x_mask = self.x_pad_indicator

        with torch.no_grad():
            perturbed_results = beam_search(
                self.translate_model,
                beam_size=5, max_steps=150,
                src_embs=x_emb, src_mask=x_mask,
                alpha=-1.0)
        perturbed_results = perturbed_results.cpu().numpy().tolist()
        # only use the top result from the result
        result = []
        for sent in perturbed_results:
            sent = [wid for wid in sent[0] if wid != PAD]
            result.append(sent)
        return result

    def get_state(self):
        """
        retrieve states for the learning
        :return: the states of the environment
        """
        states = self.x_emb  # current sen embeddings, [batch_size, len, emb_dim]
        masks = 1. - self.x_pad_indicator.float()  # indicates valid tokens [batch, max_len]
        rephrase_positions = torch.tensor(np.array([self.index] * masks.shape[0])).unsqueeze(dim=-1).long()  # current state positions [batch, 1]
        survival_signals = torch.tensor(self.survival_signals).unsqueeze(dim=-1).float()  # [batch_size, 1]
        return states, masks, rephrase_positions, survival_signals

    def step(self, action):
        """
        step update for the environment: finally update self.index
        this is defined as inference of the environments
        states are returned in np.array
        :param action: tensor.variable as action input(in shape [batch, dim])
            on current index for step updates
        :return: updated states/ rewards/ terminal signal from the environments
                 reward (list of float), terminal_signal (list of boolean)
        """
        with torch.no_grad():
            self.annunciator.eval()
            batch_size, _ = action.shape
            batched_rewards = [0.] * batch_size
            if self.device != "cpu" and not action.is_cuda:
                WARN("mismatching action for gpu_env, move actions to %s"%self.device)
                action = action.to(self.device)

            # extract the step mask for actions and rewards
            inputs_mask = 1. - self.x_pad_indicator.float()
            inputs_mask = inputs_mask[:, self.index]  # slice at current step(index), mask of [batch]
            inputs_mask *= inputs_mask.new_tensor(self.survival_signals)  # mask those terminated

            # update current src embedding with action
            origin_emb = self.x_emb.clone().detach()
            # update embedding; cancel modification on PAD
            self.x_emb[:, self.index, :] += (action * inputs_mask.unsqueeze(dim=1))  # actions on PAD is masked

            # update survival_signals, which later determines whether rewards are valid for return
            # 1. mask survival by step and sent-len
            step_reward_mask = [int(self.index <= i) for i in self.sent_len]
            # 2. get batched sentence matching for survival signals on the current src state
            # # as the reward process (probs on ``survival'' as rewards)
            d_probs = self.annunciator.get_density_score(self.x_emb, self.x_pad_indicator, self.seqs_y)
            # print("dprobs:",d_probs)
            signals = d_probs.detach().lt(0.5).long().cpu().numpy().tolist()    # 1 as terminate
            # print("signals:", signals)

            if 1 in step_reward_mask:  # rollout reaches the sents length
                # 0 as survive, 1 as terminate
                probs = d_probs.detach().cpu().numpy()
                discriminate_index = d_probs.detach().lt(0.5).long()
                survival_mask = (1 - discriminate_index).cpu().numpy()
                survival_value = probs * survival_mask
                terminate_punishment = probs * discriminate_index.cpu().numpy()

                # looping for survival signals and step rewards
                origin_survival_signals = self.survival_signals.copy()
                for i in range(batch_size):
                    # update survivals signals
                    self.survival_signals[i] = self.survival_signals[i] * (1-signals[i]) * step_reward_mask[i]
                    if self.survival_signals[i]:
                        batched_rewards[i] += survival_value[i] * self.r_s_weight
                    elif origin_survival_signals[i]:
                        # punish once the survival signal flips
                        batched_rewards[i] -= terminate_punishment[i] * self.r_i_weight
            else:  # all dead, no need to calculate other rewards, it's ok to waste some samples
                # note: self.index is NOT advanced on this early-exit path
                return self.x_emb.cpu().numpy(), np.array(batched_rewards), self.survival_signals

            # additional episodic reward for surviving sequences (w/ finished sentence at current step)
            bleu_mask = [int(self.index == i) for i in self.sent_len]
            bleu_mask = [bleu_mask[i]*self.survival_signals[i] for i in range(batch_size)]
            if 1 in bleu_mask:
                # check for the finished line and mask out the others
                perturbed_results = self.translate(self.x_emb, self.x_pad_indicator)
                episodic_rewards = []
                for i, sent in enumerate(self.seqs_y):
                    if bleu_mask[i] == 1:
                        # BLEU drop relative to the baseline recorded in _init_state
                        degraded_value = (self.origin_bleu[i]-bleu.sentence_bleu(
                            references=[sent],
                            hypothesis=perturbed_results[i],
                            emulate_multibleu=True
                        ))
                        if self.adversarial:  # relative degradation
                            if self.origin_bleu[i] == 0:
                                relative_degraded_bleu = 0
                            else:
                                relative_degraded_bleu = degraded_value/self.origin_bleu[i]
                            episodic_rewards.append(relative_degraded_bleu)
                        else:  # absolute improvement
                            print("bleu variation:", self.origin_bleu[i],-degraded_value)
                            episodic_rewards.append(-degraded_value)
                    else:
                        episodic_rewards.append(0.0)
                # append additional episodic rewards
                batched_rewards = [batched_rewards[i]+episodic_rewards[i]*self.r_i_weight
                                   for i in range(batch_size)]

        # update sequences' pointer for rephrasing
        self.index += 1

        return self.x_emb.cpu().numpy(), np.array(batched_rewards), self.survival_signals
# Example #27
def train(FLAGS):
    """Full training loop for the language model.

    Builds data iterators, model, criterion, optimizer, optional LR
    scheduler and moving average, then trains with periodic logging,
    checkpointing, loss validation and early stopping.

    FLAGS:
        saveto: str
        reload: store_true
        config_path: str
        pretrain_path: str, default=""
        model_name: str
        log_path: str
    """

    # write log of training to file.
    write_log_to_file(
        os.path.join(FLAGS.log_path,
                     "%s.log" % time.strftime("%Y%m%d-%H%M%S")))

    GlobalNames.USE_GPU = FLAGS.use_gpu

    # BUG FIX: the two branches were inverted -- GPU runs selected "cpu"
    # (and CPU runs "cuda:0"), so init_parameters() below mapped pretrained
    # weights onto the wrong device.
    if GlobalNames.USE_GPU:
        CURRENT_DEVICE = "cuda:0"
    else:
        CURRENT_DEVICE = "cpu"

    config_path = os.path.abspath(FLAGS.config_path)
    with open(config_path.strip()) as f:
        # Explicit Loader: bare yaml.load() is deprecated and unsafe on
        # untrusted input.
        configs = yaml.load(f, Loader=yaml.FullLoader)

    INFO(pretty_configs(configs))

    # Add default configs
    configs = default_configs(configs)
    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    GlobalNames.SEED = training_configs['seed']

    set_seed(GlobalNames.SEED)

    best_model_prefix = os.path.join(
        FLAGS.saveto, FLAGS.model_name + GlobalNames.MY_BEST_MODEL_SUFFIX)

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Generate target dictionary
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][0])

    # Effective batch/buffer sizes account for gradient accumulation
    # (update_cycle splits each loaded batch into that many shards below).
    train_batch_size = training_configs["batch_size"] * max(
        1, training_configs["update_cycle"])
    train_buffer_size = training_configs["buffer_size"] * max(
        1, training_configs["update_cycle"])

    train_bitext_dataset = ZipDataset(TextLineDataset(
        data_path=data_configs['train_data'][0],
        vocabulary=vocab_tgt,
        max_len=data_configs['max_len'][0],
    ),
                                      shuffle=training_configs['shuffle'])

    valid_bitext_dataset = ZipDataset(
        TextLineDataset(
            data_path=data_configs['valid_data'][0],
            vocabulary=vocab_tgt,
        ))

    training_iterator = DataIterator(
        dataset=train_bitext_dataset,
        batch_size=train_batch_size,
        use_bucket=training_configs['use_bucket'],
        buffer_size=train_buffer_size,
        batching_func=training_configs['batching_key'])

    valid_iterator = DataIterator(
        dataset=valid_bitext_dataset,
        batch_size=training_configs['valid_batch_size'],
        use_bucket=True,
        buffer_size=100000,
        numbering=True)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    lrate = optimizer_configs['learning_rate']
    is_early_stop = False

    # ================================ Begin ======================================== #
    # Build Model & Optimizer
    # We would do steps below on after another
    #     1. build models & criterion
    #     2. move models & criterion to gpu if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    #     5. build learning rate scheduler if needed
    #     6. load checkpoints if needed

    # 0. Initial
    model_collections = Collections()
    checkpoint_saver = Saver(
        save_prefix="{0}.ckpt".format(
            os.path.join(FLAGS.saveto, FLAGS.model_name)),
        num_max_keeping=training_configs['num_kept_checkpoints'])
    best_model_saver = Saver(
        save_prefix=best_model_prefix,
        num_max_keeping=training_configs['num_kept_best_model'])

    # 1. Build Model & Criterion
    INFO('Building model...')
    timer.tic()
    lm_model = build_model(n_tgt_vocab=vocab_tgt.max_n_words, **model_configs)
    INFO(lm_model)

    params_total = sum(p.numel() for _, p in lm_model.named_parameters())
    # Count excludes any parameter whose name contains 'embedding'.
    params_with_embedding = sum(
        p.numel() for n, p in lm_model.named_parameters()
        if n.find('embedding') == -1)
    INFO('Total parameters: {}'.format(params_total))
    INFO('Total parameters (excluding word embeddings): {}'.format(
        params_with_embedding))

    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'])

    INFO(critic)
    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 2. Move to GPU
    if GlobalNames.USE_GPU:
        lm_model = lm_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    lm_model.init_parameters(FLAGS.pretrain_path, device=CURRENT_DEVICE)

    # 4. Build optimizer
    INFO('Building Optimizer...')
    optim = Optimizer(name=optimizer_configs['optimizer'],
                      model=lm_model,
                      lr=lrate,
                      grad_clip=optimizer_configs['grad_clip'],
                      optim_args=optimizer_configs['optimizer_params'])

    # 5. Build scheduler for optimizer if needed
    if optimizer_configs['schedule_method'] is not None:

        if optimizer_configs['schedule_method'] == "loss":

            scheduler = ReduceOnPlateauScheduler(
                optimizer=optim, **optimizer_configs["scheduler_configs"])

        elif optimizer_configs['schedule_method'] == "noam":
            scheduler = NoamScheduler(optimizer=optim,
                                      **optimizer_configs['scheduler_configs'])
        else:
            WARN(
                "Unknown scheduler name {0}. Do not use lr_scheduling.".format(
                    optimizer_configs['schedule_method']))
            scheduler = None
    else:
        scheduler = None

    # 6. build moving average

    if training_configs['moving_average_method'] is not None:
        ma = MovingAverage(
            moving_average_method=training_configs['moving_average_method'],
            named_params=lm_model.named_parameters(),
            alpha=training_configs['moving_average_alpha'])
    else:
        ma = None

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # Reload from latest checkpoint
    if FLAGS.reload:
        checkpoint_saver.load_latest(model=lm_model,
                                     optim=optim,
                                     lr_scheduler=scheduler,
                                     collections=model_collections,
                                     ma=ma)

    # ================================================================================== #
    # Prepare training

    # Restore counters from the (possibly reloaded) collections.
    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [0])[-1]
    bad_count = model_collections.get_collection("bad_count", [0])[-1]
    oom_count = model_collections.get_collection("oom_count", [0])[-1]

    summary_writer = SummaryWriter(log_dir=FLAGS.log_path)

    cum_samples = 0
    cum_words = 0
    valid_loss = best_valid_loss = float('inf')  # Max Float
    saving_files = []

    # Timer for computing speed
    timer_for_speed = Timer()
    timer_for_speed.tic()

    INFO('Begin training...')

    while True:
        summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()
        training_progress_bar = tqdm(desc=' - (Epc {}, Upd {}) '.format(
            eidx, uidx),
                                     total=len(training_iterator),
                                     unit="sents")
        for batch in training_iter:

            uidx += 1

            # Step-based schedulers advance every update; the "loss"
            # scheduler is stepped after validation instead.
            if optimizer_configs[
                    "schedule_method"] is not None and optimizer_configs[
                        "schedule_method"] != "loss":
                scheduler.step(global_step=uidx)

            seqs_y = batch

            n_samples_t = len(seqs_y)
            n_words_t = sum(len(s) for s in seqs_y)

            cum_samples += n_samples_t
            cum_words += n_words_t

            train_loss = 0.
            optim.zero_grad()
            try:
                # Gradient accumulation: split the batch into update_cycle
                # shards and accumulate gradients before one optim.step().
                for (seqs_y_t, ) in split_shard(
                        seqs_y, split_size=training_configs['update_cycle']):
                    y = prepare_data(seqs_y_t, cuda=GlobalNames.USE_GPU)

                    loss = compute_forward(
                        model=lm_model,
                        critic=critic,
                        # seqs_x=x,
                        seqs_y=y,
                        eval=False,
                        normalization=n_samples_t,
                        norm_by_words=training_configs["norm_by_words"])
                    # Per-token loss for display unless already word-normalized.
                    train_loss += (loss / y.size(1)
                                   if not training_configs["norm_by_words"]
                                   else loss)
                optim.step()

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    oom_count += 1
                    optim.zero_grad()
                else:
                    raise e

            if ma is not None and eidx >= training_configs[
                    'moving_average_start_epoch']:
                ma.step()

            training_progress_bar.update(n_samples_t)
            training_progress_bar.set_description(
                ' - (Epc {}, Upd {}) '.format(eidx, uidx))
            training_progress_bar.set_postfix_str(
                'TrainLoss: {:.2f}, ValidLoss(best): {:.2f} ({:.2f})'.format(
                    train_loss, valid_loss, best_valid_loss))
            summary_writer.add_scalar("train_loss",
                                      scalar_value=train_loss,
                                      global_step=uidx)

            # ================================================================================== #
            # Display some information
            if should_trigger_by_steps(
                    uidx, eidx, every_n_step=training_configs['disp_freq']):
                # words per second and sents per second
                words_per_sec = cum_words / (timer.toc(return_seconds=True))
                sents_per_sec = cum_samples / (timer.toc(return_seconds=True))
                lrate = list(optim.get_lrate())[0]

                summary_writer.add_scalar("Speed(words/sec)",
                                          scalar_value=words_per_sec,
                                          global_step=uidx)
                summary_writer.add_scalar("Speed(sents/sen)",
                                          scalar_value=sents_per_sec,
                                          global_step=uidx)
                summary_writer.add_scalar("lrate",
                                          scalar_value=lrate,
                                          global_step=uidx)
                summary_writer.add_scalar("oom_count",
                                          scalar_value=oom_count,
                                          global_step=uidx)

                # Reset timer
                timer.tic()
                cum_words = 0
                cum_samples = 0

            # ================================================================================== #
            # Saving checkpoints
            if should_trigger_by_steps(
                    uidx,
                    eidx,
                    every_n_step=training_configs['save_freq'],
                    debug=FLAGS.debug):
                model_collections.add_to_collection("uidx", uidx)
                model_collections.add_to_collection("eidx", eidx)
                model_collections.add_to_collection("bad_count", bad_count)

                if not is_early_stop:
                    checkpoint_saver.save(global_step=uidx,
                                          model=lm_model,
                                          optim=optim,
                                          lr_scheduler=scheduler,
                                          collections=model_collections,
                                          ma=ma)

            # ================================================================================== #
            # Loss Validation & Learning rate annealing
            if should_trigger_by_steps(
                    global_step=uidx,
                    n_epoch=eidx,
                    every_n_step=training_configs['loss_valid_freq'],
                    debug=FLAGS.debug):

                # Validate with the moving-average weights, then restore.
                if ma is not None:
                    origin_state_dict = deepcopy(lm_model.state_dict())
                    lm_model.load_state_dict(ma.export_ma_params(),
                                             strict=False)

                valid_loss = loss_validation(
                    model=lm_model,
                    critic=critic,
                    valid_iterator=valid_iterator,
                    norm_by_words=training_configs["norm_by_words"])

                model_collections.add_to_collection("history_losses",
                                                    valid_loss)

                min_history_loss = np.array(
                    model_collections.get_collection("history_losses")).min()

                summary_writer.add_scalar("loss", valid_loss, global_step=uidx)
                summary_writer.add_scalar("best_loss",
                                          min_history_loss,
                                          global_step=uidx)

                if ma is not None:
                    lm_model.load_state_dict(origin_state_dict)
                    del origin_state_dict

                # NOTE(review): this steps on the *previous* best loss
                # (best_valid_loss is refreshed a few lines below) --
                # confirm whether valid_loss was intended here.
                if optimizer_configs["schedule_method"] == "loss":
                    scheduler.step(metric=best_valid_loss)

                # If model get new best valid loss
                if valid_loss < best_valid_loss:
                    bad_count = 0

                    if is_early_stop is False:
                        # 1. save the best model
                        torch.save(lm_model.state_dict(),
                                   best_model_prefix + ".final")

                        # 2. record all several best models
                        best_model_saver.save(global_step=uidx, model=lm_model)
                else:
                    bad_count += 1

                    # At least one epoch should be traversed
                    if bad_count >= training_configs[
                            'early_stop_patience'] and eidx > 0:
                        is_early_stop = True
                        WARN("Early Stop!")

                best_valid_loss = min_history_loss

                summary_writer.add_scalar("bad_count", bad_count, uidx)

                INFO("{0} Loss: {1:.2f} lrate: {2:6f} patience: {3}".format(
                    uidx, valid_loss, lrate, bad_count))

        training_progress_bar.close()

        eidx += 1
        if eidx > training_configs["max_epochs"]:
            break
# Example #28
class SACAgent(object):
    """Soft Actor-Critic agent pairing a Rephraser actor with a CriticNet.

    Owns the actor/critic networks, their optimizers and (optional) LR
    schedulers, a learnable entropy temperature (log_alpha), and a Saver
    for checkpointing. Built for multi-process training: soft_update_lock
    serializes parameter reads/writes against soft target updates.
    """
    def __init__(self,
                 device="cpu",
                 d_word_vec=512,
                 d_model=256,
                 limit_dist=0.1,
                 dropout=0.0,
                 reparam_noise=1e-6,
                 **kwargs):
        self.device = device
        self.actor = Rephraser(d_word_vec=d_word_vec,
                               d_model=d_model,
                               limit_dist=limit_dist,
                               dropout=dropout,
                               reparam_noise=reparam_noise).to(device)
        self.critic = CriticNet(d_word_vec=d_word_vec,
                                d_model=d_model,
                                limit_dist=limit_dist,
                                dropout=dropout,
                                reparam_noise=reparam_noise).to(device)
        self.saver = Saver(save_prefix="{0}.ckpt".format(
            os.path.join(kwargs["save_to"], "ACmodel")),
                           num_max_keeping=kwargs["num_kept_checkpoints"])
        self.soft_update_lock = mp.Lock()
        # the entropy regularization weight for SAC learning
        self.learnable_temperature = kwargs["learnable_temperature"]
        self.target_entropy = -d_word_vec  # act_dim (d_word_vec) as the expected entropy base
        self.log_alpha = torch.tensor(np.log(kwargs["init_temperature"])).to(
            self.device)
        self.log_alpha.requires_grad = True
        # initialize the training mode for the Agent
        self.train()
        self._init_local_optims(kwargs["rephraser_optimizer_configs"])
        # self.load_model()  # always reload model if there is any in the path

    def to(self, device):
        # Move networks to `device`; returns self for chaining.
        self.actor.to(device)
        self.critic.to(device)
        # NOTE(review): Tensor.to() is NOT in-place -- this return value is
        # discarded, so log_alpha likely stays on its original device.
        # Confirm whether `self.log_alpha = self.log_alpha.to(device)` was
        # intended (mind the optimizer holding a reference to the old tensor).
        self.log_alpha.to(device)
        return self

    def share_memory(self):
        # global model needs to share memory with other threads
        self.actor.share_memory()
        self.critic.share_memory()

    @property
    def alpha(self):
        """Current entropy temperature: exp(log_alpha)."""
        return self.log_alpha.exp()

    def load_model(self, load_final_path: str = None):
        """
        load from path by self.saver
        :param load_final_path: final model path dir, final model doesn't have optim_params
        :return: training step count int (0 when loading a final model)
        """
        step = 0
        model_collections = Collections()
        if load_final_path:
            # self.saver.load_latest(
            #     actor_model=self.actor, critic_model=self.critic
            # )  # load from the latest ckpt model
            state_dict = torch.load(os.path.join(load_final_path))
            self.actor.load_state_dict(state_dict["actor_model"])
            self.critic.load_state_dict(state_dict["critic_model"])
        else:
            # Resume full training state (models, optims, schedulers, step).
            self.saver.load_latest(collections=model_collections,
                                   actor_model=self.actor,
                                   critic_model=self.critic,
                                   actor_optim=self.actor_optimizer,
                                   critic_optim=self.critic_optimizer,
                                   actor_scheduler=self.actor_scheduler,
                                   critic_scheduler=self.critic_scheduler)
            step = model_collections.get_collection("step", [0])[-1]
        return step

    def save_model(
            self,
            step=None,
            save_to_final=None):  # save model parameters, optims, lr_steps
        model_collections = Collections()
        if step is not None:
            model_collections.add_to_collection("step", step)
            self.saver.save(global_step=step,
                            collections=model_collections,
                            actor_model=self.actor,
                            critic_model=self.critic,
                            actor_optim=self.actor_optimizer,
                            critic_optim=self.critic_optimizer,
                            actor_scheduler=self.actor_scheduler,
                            critic_scheduler=self.critic_scheduler)
        else:  # only save the model parameters
            assert save_to_final is not None, "final model saving dir must be provided"
            collection = dict()
            collection["actor_model"] = self.actor.state_dict()
            collection["critic_model"] = self.critic.state_dict()
            torch.save(collection, os.path.join(save_to_final,
                                                "ACmodel.final"))
        return

    def _init_local_optims(self, rephraser_optimizer_configs):
        """ actor, critic, alpha optimizers and lr scheduler if necessary
        rephraser_optimizer_configs:
            optimizer: "adafactor"
            learning_rate: 0.01
            grad_clip: -1.0
            optimizer_params: ~
            schedule_method: rsqrt
            scheduler_configs:
                d_model: *dim
                warmup_steps: 100
        """
        # initiate local optimizer
        if rephraser_optimizer_configs is None:
            self.actor_optimizer = None
            self.critic_optimizer = None
            self.log_alpha_optimizer = None
            # self.actor_icm_optimizer = None
            self.actor_scheduler = None
            self.critic_scheduler = None
        else:
            self.actor_optimizer = Optimizer(
                name=rephraser_optimizer_configs["optimizer"],
                model=self.actor,
                lr=rephraser_optimizer_configs["learning_rate"],
                grad_clip=rephraser_optimizer_configs["grad_clip"],
                optim_args=rephraser_optimizer_configs["optimizer_params"])
            self.critic_optimizer = Optimizer(
                name=rephraser_optimizer_configs["optimizer"],
                model=self.critic,
                lr=rephraser_optimizer_configs["learning_rate"],
                grad_clip=rephraser_optimizer_configs["grad_clip"],
                optim_args=rephraser_optimizer_configs["optimizer_params"])
            # hardcoded entropy weight updates and icm updates
            self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                        lr=1e-4,
                                                        betas=(0.9, 0.999))
            # self.actor_icm_optimizer = torch.optim.Adam(self.actor.icm.parameters(), lr=1e-3, )

            # Build scheduler for optimizer if needed
            if rephraser_optimizer_configs['schedule_method'] is not None:
                if rephraser_optimizer_configs['schedule_method'] == "loss":
                    self.actor_scheduler = ReduceOnPlateauScheduler(
                        optimizer=self.actor_optimizer,
                        **rephraser_optimizer_configs["scheduler_configs"])
                    self.critic_scheduler = ReduceOnPlateauScheduler(
                        optimizer=self.critic_optimizer,
                        **rephraser_optimizer_configs["scheduler_configs"])
                elif rephraser_optimizer_configs['schedule_method'] == "noam":
                    self.actor_scheduler = NoamScheduler(
                        optimizer=self.actor_optimizer,
                        **rephraser_optimizer_configs["scheduler_configs"])
                    self.critic_scheduler = NoamScheduler(
                        optimizer=self.critic_optimizer,
                        **rephraser_optimizer_configs["scheduler_configs"])
                elif rephraser_optimizer_configs["schedule_method"] == "rsqrt":
                    self.actor_scheduler = RsqrtScheduler(
                        optimizer=self.actor_optimizer,
                        **rephraser_optimizer_configs["scheduler_configs"])
                    self.critic_scheduler = RsqrtScheduler(
                        optimizer=self.critic_optimizer,
                        **rephraser_optimizer_configs["scheduler_configs"])
                else:
                    WARN(
                        "Unknown scheduler name {0}. Do not use lr_scheduling."
                        .format(
                            rephraser_optimizer_configs['schedule_method']))
                    self.actor_scheduler = None
                    self.critic_scheduler = None
            else:
                self.actor_scheduler = None
                self.critic_scheduler = None

    def sync_from(self, sac_agent):
        """Copy actor/critic parameters from another agent under both locks."""
        with sac_agent.soft_update_lock, self.soft_update_lock:
            self.actor.sync_from(sac_agent.actor)
            self.critic.sync_from(sac_agent.critic)

    def train(self, training=True):
        # default training is true
        self.training = training
        self.actor.train(training)
        self.critic.train(training)
        return self

    def update_critic(self,
                      states,
                      masks,
                      rephrase_positions,
                      actions,
                      rewards,
                      survive_and_no_maxs,
                      target_critic,
                      update_step,
                      discount_factor,
                      summary_writer,
                      update_trust_region=0.8):
        """
        update critic by using a target_critic net (usually a global critic model) and a buffer
        SARSA for TD learning
        :param states:
        :param masks:
        :param rephrase_positions:
        :param actions: actions
        :param rewards: the rewards
        :param survive_and_no_maxs: able to rollout next step for TD learning

        :param target_critic: provides target value estimation(global model is usually on cpu)
        :param discount_factor: for discounted rewards update
        :param update_step: learning steps
        :param update_trust_region: discount for loss updates
        """

        label_emb = slice_by_indices(
            states, rephrase_positions,
            device=self.device)  # next_rephrase_positions to label emb
        next_action, log_probs = self.actor.sample_normal(states,
                                                          1. - masks,
                                                          label_emb,
                                                          reparamization=True)
        log_probs = log_probs.sum(dim=-1, keepdims=True)
        next_states = transition(states, masks, actions, rephrase_positions)
        next_rephrase_positions = rephrase_positions + 1
        next_label_emb = slice_by_indices(next_states,
                                          next_rephrase_positions,
                                          device=self.device)

        # # note that with intrinsic curiosity module, the rewards will add curiosity bonus
        # self.actor.icm.eval()
        # rephrase_feature = self.actor.preprocess(states, 1.-masks, label_emb)
        # next_rephrase_feature = self.actor.preprocess(next_states, 1.-masks, next_label_emb)
        # bonus = self.actor.icm.get_surprise_bonus(rephrase_feature, next_rephrase_feature, actions).detach()
        # bonus = 0.01 * bonus * survive_and_no_maxs
        # print("bonus:", bonus.sum())
        # rewards += bonus
        # print("rewards:", rewards.squeeze())

        # note that log_probs has the same dimension with the action. thus the log_prob of a whole action is the sum along dimensions.
        target_critic.eval()
        target_V = target_critic(
            next_states, 1. - masks, next_label_emb,
            next_action) - log_probs * self.alpha.detach()
        target_Q = rewards + (
            survive_and_no_maxs
        ) * discount_factor * target_V  #  we have next states for TD learning rollout
        target_Q = target_Q.detach()

        # get current Q estimates
        current_Q = self.critic(states, 1. - masks, label_emb, actions)
        critic_loss = F.mse_loss(current_Q, target_Q)
        critic_loss *= update_trust_region
        print("critic_loss", critic_loss.sum())

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        if self.critic_scheduler is not None:
            self.critic_scheduler.step(global_step=update_step)
        critic_loss.backward()
        self.critic_optimizer.step()

        # logging: entropy/target_entropy ratio, critic_loss,
        summary_writer.add_scalar("critic_loss",
                                  scalar_value=critic_loss,
                                  global_step=update_step)

    def update_actor_and_alpha(self,
                               states,
                               masks,
                               rephrase_positions,
                               target_critic,
                               update_step,
                               summary_writer,
                               update_trust_region=0.5):
        """
        :param states: tensor states from the buffer samples
        :param masks: indicats the valid token positions
        :param rephrase_positions: induce the next states by the given states
        :param update_trust_region: current annnunciator trust_acc (valid).
        served as a trust region for RL updates;
        also the weight of rewind or reinforce
        trust_acc * rewind_loss + (1-trust_acc) * policy_loss
        """
        self.actor.train()
        label_emb = slice_by_indices(states,
                                     rephrase_positions,
                                     device=self.device)
        actions, log_probs = self.actor.sample_normal(states,
                                                      1. - masks,
                                                      label_emb,
                                                      reparamization=True)
        log_probs = log_probs.sum(dim=-1, keepdims=True)
        actor_Q = self.critic(states, 1. - masks, label_emb, actions)
        policy_loss = (self.alpha.detach() * log_probs - actor_Q).mean()

        summary_writer.add_scalar('policy_loss', policy_loss, update_step)
        summary_writer.add_scalar('entropy_ratio',
                                  -log_probs.mean() / self.target_entropy,
                                  update_step)

        # the policy rewind loss, the rewind is determined by target value estimates (estimated survival + improvements)
        # negative means rewind needed.
        target_Q = target_critic(states, 1. - masks, label_emb,
                                 actions).detach()
        rewind_mask = target_Q.lt(0.).detach().float()  # [batch, 1]
        next_states = transition(states, masks, actions,
                                 rephrase_positions).detach()
        next_label_emb = slice_by_indices(next_states,
                                          rephrase_positions,
                                          device=self.device).detach()
        rewind_action, _ = self.actor.forward(next_states, 1. - masks,
                                              next_label_emb)
        target_actions = -actions.detach()
        rewind_loss = F.mse_loss(
            rewind_action * rewind_mask * self.actor.action_range,
            target_actions * rewind_mask)
        summary_writer.add_scalar('rewind_loss', rewind_loss, update_step)

        # the higher trust region means less indicative the perturbations are, policy should focus more on the rewind.
        actor_loss = (update_trust_region) * rewind_loss + (
            1. - update_trust_region) * policy_loss
        ## update the intrinsic reward module: action reconstruction and feature prediction mse
        # self.actor.icm.train()
        # next_states = transition(states, masks, actions, rephrase_positions)
        # next_label_emb = slice_by_indices(next_states, rephrase_positions, device=self.device)
        # rephrase_feature = self.actor.preprocess(states, 1.0-masks, label_emb)
        # next_rephrase_feature = self.actor.preprocess(next_states, 1.0-masks, next_label_emb)
        # if update_step<3000:
        #     # icm updates does not propagate to the policy on the early stage
        #     icm_loss = self.actor.icm(rephrase_feature.detach(), next_rephrase_feature.detach(), actions)
        # else:
        #     icm_loss = self.actor.icm(rephrase_feature, next_rephrase_feature, actions)
        # summary_writer.add_scalar("intrinsic_curiosity_loss", icm_loss, update_step)
        # # the 0.1 is the setting by Intrinsic curiosity learning
        # actor_loss = actor_loss + icm_loss

        # optimize the actor
        self.actor_optimizer.zero_grad()
        # self.actor_icm_optimizer.zero_grad()
        if self.actor_scheduler is not None:
            self.actor_scheduler.step(global_step=update_step)
        actor_loss.backward()
        self.actor_optimizer.step()
        # self.actor_icm_optimizer.step()

        if self.learnable_temperature:
            self.log_alpha_optimizer.zero_grad()
            alpha_loss = (self.alpha *
                          (-log_probs - self.target_entropy).detach()).mean()
            summary_writer.add_scalar('alpha_loss', alpha_loss, update_step)
            summary_writer.add_scalar('alpha', self.alpha, update_step)
            alpha_loss.backward()
            self.log_alpha_optimizer.step()

    def update_local_net(self,
                         local_agent_configs,
                         replay_buffer,
                         target_critic,
                         update_step,
                         discount_factor,
                         summary_writer,
                         update_trust_region=1.0):
        """
        :param local_agent_configs: provides agent update freq
        :param replay_buffer: provides the SARSA listed below
            [states, masks, actions, rephrase_positions, rewards, terminal_signals]
            states: the embedding as states. [batch, len, emb_dim] float
            masks: the indicator of valid token for embedding. [batch, len] float
                actions: the action embedding on the position. [batch, emb_dim] float
            rephrase_positions: the position to rephrase. [batch, 1]  long
            rewards: the rewards for the transition. [batch, 1] float
            terminal_signals: the terminal signals for the transition. [batch, 1] float
        :param target_critic: provides the global critic
        :param update_step: for lr scheduler and logging
        :param discount_factor: rollout-rewards discount
        :param summary_writer: logging
        :param update_trust_region: discount for loss updates
        """
        learn_batch_size = local_agent_configs["rephraser_learning_batch"]

        states, masks, \
        actions, rephrase_positions, \
        rewards, _, survive_and_no_maxs = replay_buffer.sample(learn_batch_size, device=self.device)
        INFO("update local agent critics on device: %s" % self.device)
        self.update_critic(states, masks, rephrase_positions, actions, rewards,
                           survive_and_no_maxs, target_critic, update_step,
                           discount_factor, summary_writer,
                           update_trust_region)
        if update_step % local_agent_configs["actor_update_freq"] == 0:
            INFO("update local agent policy on device: %s" % self.device)
            self.update_actor_and_alpha(states, masks, rephrase_positions,
                                        target_critic, update_step,
                                        summary_writer, update_trust_region)

    def soft_update_target_net(self, target_SACAgent, tau):
        # soft update the target network. first move to CPU, than move back to local
        # mind not to update global model while reading and synch local models.
        # Polyak averaging: target <- tau * local + (1 - tau) * target.
        self.to(target_SACAgent.device)
        with target_SACAgent.soft_update_lock:
            for param, target_param in zip(
                    self.critic.parameters(),
                    target_SACAgent.critic.parameters()):
                target_param.data.copy_(tau * param.data +
                                        (1 - tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(),
                                           target_SACAgent.actor.parameters()):
                target_param.data.copy_(tau * param.data +
                                        (1 - tau) * target_param.data)
        self.to(self.device)
# Example #29
class Solver(BaseSolver):
    ''' Solver for training language models'''

    def __init__(self, config, paras, mode):
        super().__init__(config, paras, mode)
        # Logger settings
        # Best dev-set entropy so far; start high so the first validation
        # always produces a checkpoint.
        self.best_loss = 10

    def fetch_data(self, data):
        ''' Move data to device, insert <sos> and compute text seq. length'''
        # Prepend an <sos> token (id 0) to every sequence, then move the
        # batch to the target device.
        txt = torch.cat((torch.zeros(
            (data.shape[0], 1), dtype=torch.long), data),
                        dim=1).to(self.device)
        # Sequence lengths count non-padding tokens of the original data
        # (i.e. excluding the prepended <sos>).
        txt_len = torch.sum(data != 0, dim=-1)
        return txt, txt_len

    def load_data(self):
        ''' Load data for training/validation, store tokenizer and input/output shape'''
        self.tr_set, self.dv_set, self.vocab_size, self.tokenizer, msg = \
            load_textset(self.paras.njobs, self.paras.gpu,
                         self.paras.pin_memory, **self.config['data'])
        self.verbose(msg)

    def set_model(self):
        ''' Setup ASR model and optimizer '''

        # Model
        # self.model = RNNLM(self.vocab_size, **self.config['model']).to(self.device)
        self.model = Prediction(self.vocab_size,
                                **self.config['model']).to(self.device)
        self.rnnlm = RNNLM(self.vocab_size,
                           **self.config['model']).to(self.device)

        self.verbose(self.rnnlm.create_msg())
        # Losses (ignore_index=0 skips padding positions in the target)
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        # Optimizer over both the encoder and the LM head
        self.optimizer = Optimizer(
            list(self.model.parameters()) + list(self.rnnlm.parameters()),
            **self.config['hparas'])
        # Enable AMP if needed
        self.enable_apex()
        # load pre-trained model
        # NOTE(review): the checkpoint appears to be loaded twice — once via
        # load_ckpt() and once manually below. Confirm whether load_ckpt()
        # already restores model/optimizer state; if so, the manual reload
        # is redundant (kept here to preserve existing behavior).
        if self.paras.load:
            self.load_ckpt()
            ckpt = torch.load(self.paras.load, map_location=self.device)
            self.model.load_state_dict(ckpt['model'])
            self.optimizer.load_opt_state_dict(ckpt['optimizer'])
            self.step = ckpt['global_step']
            self.verbose('Load ckpt from {}, restarting at step {}'.format(
                self.paras.load, self.step))

    def exec(self):
        ''' Training End-to-end ASR system '''
        self.verbose('Total training steps {}.'.format(
            human_format(self.max_step)))
        self.timer.set()

        while self.step < self.max_step:
            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad
                self.optimizer.pre_step(self.step)

                # Fetch data
                txt, txt_len = self.fetch_data(data)
                self.timer.cnt('rd')

                # Forward model: predict token t+1 from tokens <= t
                outputs, hidden = self.model(txt[:, :-1], txt_len)
                pred = self.rnnlm(outputs)

                # Compute all objectives (targets are shifted by one,
                # dropping the <sos> position)
                lm_loss = self.seq_loss(pred.view(-1, self.vocab_size),
                                        txt[:, 1:].reshape(-1))
                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(lm_loss)
                self.step += 1

                # Logger
                if self.step % self.PROGRESS_STEP == 0:
                    self.progress(
                        'Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'.
                        format(lm_loss.cpu().item(), grad_norm,
                               self.timer.show()))
                    self.write_log('entropy', {'tr': lm_loss})
                    self.write_log('perplexity',
                                   {'tr': torch.exp(lm_loss).cpu().item()})

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    self.validate()

                # End of step
                self.timer.set()
                # BUGFIX: use >= so training stops after exactly max_step
                # steps; the original `>` condition ran one extra step past
                # the limit before the outer while-loop could exit.
                if self.step >= self.max_step:
                    break
        self.log.close()

    def validate(self):
        '''Evaluate on the dev set, checkpoint on improvement, and log
        example decodings from the last batch.'''
        # Eval mode
        self.model.eval()
        self.rnnlm.eval()
        dev_loss = []

        for i, data in enumerate(self.dv_set):
            self.progress('Valid step - {}/{}'.format(i + 1, len(self.dv_set)))
            # Fetch data
            txt, txt_len = self.fetch_data(data)

            # Forward model
            with torch.no_grad():
                outputs, hidden = self.model(txt[:, :-1], txt_len)
                pred = self.rnnlm(outputs)
            lm_loss = self.seq_loss(pred.view(-1, self.vocab_size),
                                    txt[:, 1:].reshape(-1))
            dev_loss.append(lm_loss)

        # Ckpt if performance improves
        dev_loss = sum(dev_loss) / len(dev_loss)
        dev_ppx = torch.exp(dev_loss).cpu().item()
        if dev_loss < self.best_loss:
            self.best_loss = dev_loss
            self.save_checkpoint('best_ppx.pth', 'perplexity', dev_ppx)
        self.write_log('entropy', {'dv': dev_loss})
        self.write_log('perplexity', {'dv': dev_ppx})

        # Show some example of last batch on tensorboard
        for i in range(min(len(txt), self.DEV_N_EXAMPLE)):
            if self.step == 1:
                # Ground truth only needs to be logged once
                self.write_log('true_text{}'.format(i),
                               self.tokenizer.decode(txt[i].tolist()))
            self.write_log(
                'pred_text{}'.format(i),
                self.tokenizer.decode(pred[i].argmax(dim=-1).tolist()))

        # Resume training
        self.model.train()
        self.rnnlm.train()
예제 #30
0
class Solver(BaseSolver):
    ''' Solver for training'''
    def __init__(self, config, paras):
        """Initialize the downstream-classification solver.

        Reads the self-supervised model's own config file to determine the
        feature-extraction method, and mirrors its audio settings into this
        solver's data config so features match the pretrained extractor.
        """
        super().__init__(config, paras)
        # Logger settings
        self.best_dev_er = 1.0  # best dev error rate so far (lower is better)
        self.cur_epoch = 0
        # Configs following self-supervised learning
        self.task = self.paras.task
        assert self.task in ['phn-clf', 'spk-clf'], 'unsupported task'
        # BUGFIX: open the SSL config inside a `with` block so the file
        # handle is closed; the original `yaml.load(open(...))` leaked it.
        with open(self.config['model']['feat']['config'], 'r') as f:
            self.ssl_config = yaml.load(f, Loader=yaml.FullLoader)
        self.feature = self.ssl_config['model']['method']
        if self.feature == 'npc' and 'spec' in self.config['model']['feat']:
            # NPC has additional option to use unmasked feature
            self.feat_spec = self.config['model']['feat']['spec']
        else:
            self.feat_spec = None
        # Reuse the SSL model's audio front-end settings for this task.
        self.config['data']['audio'] = self.ssl_config['data']['audio']

    def fetch_data(self, data, train=True):
        ''' Move data to device '''
        file_id, audio_feat, audio_len, label = data
        if self.gpu:
            audio_feat = audio_feat.cuda()
            label = label.cuda()
        # Feature extraction is frozen — no gradients through the extractor.
        with torch.no_grad():
            if self.feat_spec is not None:
                # Unmasked feature taken from a specific NPC layer
                layer_idx = int(self.feat_spec.split('-')[-1])
                audio_feat = self.feat_extractor.get_unmasked_feat(
                    audio_feat, layer_idx)
            elif self.feature == 'npc':
                # Masked feature from NPC
                _, audio_feat = self.feat_extractor(audio_feat, testing=True)
            else:
                # Feature from an APC-based model
                _, audio_feat = self.feat_extractor(audio_feat,
                                                    audio_len,
                                                    testing=True)
            if self.task == 'spk-clf':
                # Mean-pool over valid frames -> one vector per utterance
                pooled = [
                    feat[:length].mean(dim=0)
                    for feat, length in zip(audio_feat, audio_len)
                ]
                audio_feat = torch.stack(pooled, dim=0)
        return file_id, audio_feat, audio_len, label

    def load_data(self):
        ''' Load data for training/validation '''
        # prepare_data returns the three splits, the feature dimension,
        # and a human-readable summary message.
        loaded = prepare_data(self.paras.njobs, self.paras.dev_njobs,
                              self.paras.gpu, self.paras.pin_memory,
                              **self.config['data'])
        self.tr_set, self.dv_set, self.tt_set, self.audio_dim, msg = loaded
        self.verbose(msg)

    def set_model(self):
        ''' Setup model and optimizer '''
        # Load SSL models for feature extraction
        self.verbose([' Load feat. extractor ckpt from '\
                        +self.config['model']['feat']['ckpt']])
        # Select the backbone class matching the SSL method; imports are
        # done locally so only the needed model module is loaded.
        if self.feature in ['apc', 'vqapc']:
            from model.apc import APC as Net
        elif self.feature == 'npc':
            from model.npc import NPC as Net
            if self.feat_spec is not None:
                self.verbose([' Using specific feature: ' + self.feat_spec])
        else:
            raise NotImplementedError
        self.feat_extractor = Net(input_size=self.audio_dim,
                                  **self.ssl_config['model']['paras'])
        # Checkpoints go to CPU outside training mode to avoid GPU use.
        ckpt = torch.load(
            self.config['model']['feat']['ckpt'],
            map_location=self.device if self.mode == 'train' else 'cpu')
        # Strip one leading 'module.' from each key — presumably the ckpt
        # was saved from a DataParallel-wrapped model (TODO confirm).
        ckpt['model'] = {k.replace('module.','',1):v \
                            for k,v in ckpt['model'].items()}
        self.feat_extractor.load_state_dict(ckpt['model'])

        # Classifier model (the only trainable part; extractor stays frozen
        # in eval mode below)
        self.model = CLF(feat_dim=self.feat_extractor.code_dim,
                         **self.config['model']['clf'])
        if self.gpu:
            self.feat_extractor = self.feat_extractor.cuda()
            self.feat_extractor.eval()
            self.model = self.model.cuda()
        model_paras = [{'params': self.model.parameters()}]

        # Losses
        # phn-clf ignores label id 0 (presumably padding frames); for
        # spk-clf, ignore_index=-1 effectively disables ignoring since
        # class ids are non-negative — TODO confirm label conventions.
        ignore_idx = 0 if self.task == 'phn-clf' else -1
        self.loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_idx)
        if self.gpu:
            self.loss = self.loss.cuda()

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.config['hparas'])
        self.verbose(self.optimizer.create_msg())

        self.load_ckpt()
        self.model.train()

    def exec(self):
        '''Train the downstream classifier (in train mode), validating after
        each epoch, then evaluate on the test split at the end.'''
        if self.paras.mode == 'train':
            self.verbose('Total training epoch {}.'.format(
                human_format(self.epoch)))
            self.timer.set()
            ep_len = len(self.tr_set)
            for ep in range(self.epoch):
                if ep > 0:
                    # Lr decay if needed
                    self.optimizer.decay()
                for data in self.tr_set:
                    # Pre-step :  do zero_grad
                    self.optimizer.pre_step(self.step)

                    # Fetch data
                    self.timer.cnt('rd')
                    _, audio_feat, audio_len, label = self.fetch_data(data)

                    # Forward
                    pred = self.model(audio_feat)
                    if self.task == 'phn-clf':
                        pred = pred.permute(0, 2, 1)  # BxCxT for phn clf
                    loss = self.loss(pred, label)
                    self.timer.cnt('fw')

                    # Backprop
                    grad_norm = self.backward(loss)
                    self.step += 1

                    # Logger
                    if (self.step == 1) or (self.step % self.PROGRESS_STEP
                                            == 0):
                        # step % ep_len gives progress within the epoch
                        self.progress(
                            ' {:2.1f} % | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                            .format(100 * float(self.step % ep_len) / ep_len,
                                    loss.cpu().item(), grad_norm,
                                    self.timer.show()))
                        self.write_log(self.task + '_loss', {'tr': loss})
                        if self.task == 'phn-clf':
                            # Frame-level phone error rate from helper
                            tr_er = cal_per(pred, label, audio_len)[0]
                        else:
                            # Utterance-level speaker misclassification rate
                            tr_er = (pred.argmax(dim=-1) != label)
                            tr_er = tr_er.sum().detach().cpu().float() / len(
                                label)
                        self.write_log(self.task + '_er', {'tr': tr_er})
                    # End of step
                    self.timer.set()
                # End of epoch
                self.cur_epoch += 1
                self.validate()

        # Test at the end
        self.validate(test=True)
        self.log.close()

    def validate(self, test=False):
        '''Evaluate on the dev split (or test split when test=True), log
        loss and error rate, and keep an in-RAM clone of the best dev
        model for final testing.'''
        # Eval mode
        self.model.eval()
        val_loss = []
        split = 'dev'
        val_hit, val_total = 0.0, 0.0
        ds = self.tt_set if test else self.dv_set

        # In training mode, best model is stored in RAM for test
        # ToDo: load ckpt
        if test:
            split = 'test'
            if self.paras.mode == 'train':
                self.model = self.best_model
                if self.gpu:
                    self.model = self.model.cuda()

        for i, data in enumerate(ds):
            self.progress('Valid step - {}/{}'.format(i + 1, len(ds)))
            # Fetch data
            _, audio_feat, audio_len, label = self.fetch_data(data)

            # Forward model
            with torch.no_grad():
                # Prediction
                pred = self.model(audio_feat)
                if self.task == 'phn-clf':
                    pred = pred.permute(0, 2, 1)  # BxCxT
                # Accumulate batch result
                val_loss.append(self.loss(pred, label))
                if self.task == 'phn-clf':
                    # Frame-level hits/totals over valid frames only
                    _, hit, total = cal_per(pred, label, audio_len)
                    val_hit += hit
                    val_total += total
                else:
                    # Utterance-level speaker accuracy
                    hit = (pred.argmax(dim=-1) == label).sum()
                    val_hit += hit.detach().cpu().float()
                    val_total += len(label)
                # Write testing prediction if needed
                if test and self.paras.write_test:
                    if self.task == 'phn-clf':
                        pred = pred.argmax(dim=1).detach().cpu()
                    label = label.cpu()
                    # Append (prediction, label) pairs per valid frame
                    with open(os.path.join(self.ckpdir, self.task + '.csv'),
                              'a') as f:
                        for p, l, a_len in zip(pred, label, audio_len):
                            for x, y in zip(p[:a_len].tolist(),
                                            l[:a_len].tolist()):
                                f.write('{}\t{}\n'.format(x, y))

        # Record metric, store ckpt by dev error rate
        val_loss = sum(val_loss) / len(val_loss)
        val_er = 1.0 - val_hit / val_total
        self.write_log(self.task + '_loss', {split: val_loss})
        self.write_log(self.task + '_er', {split: val_er})
        if split == 'dev' and self.best_dev_er > val_er:
            self.best_dev_er = val_er
            self.save_checkpoint('best.pth', self.task + '_er', val_er)
            # deepcopy of the CPU model so later training can't mutate it
            self.best_model = copy.deepcopy(self.model.cpu())  # Clone for test