def set_model(self):
    ''' Build the CTC-only ASR network, its loss and its optimizer.

    Reads `self.config` ('hparas', 'model', 'data'), `self.feat_dim`,
    `self.vocab_size`, `self.device` and `self.paras`; sets `self.model`,
    `self.ctc_loss`, `self.eval_target` and `self.optimizer`.
    '''
    hparas = self.config['hparas']

    # Network: the ASR constructor is told whether Adadelta is in use
    uses_adadelta = hparas['optimizer'] == 'Adadelta'
    network = ASR(self.feat_dim, self.vocab_size, uses_adadelta,
                  **self.config['model'])
    self.model = network.to(self.device)
    self.verbose(self.model.create_msg())

    # Objective (index 0 is the CTC blank)
    # Note: zero_infinity=False is unstable?
    self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

    # Error-rate target: 'phone' for IPA corpora, 'char' otherwise
    if self.config['data']['corpus']['target'] == 'ipa':
        self.eval_target = 'phone'
    else:
        self.eval_target = 'char'

    # Optimizer over a single parameter group
    param_groups = [{'params': self.model.parameters()}]
    self.optimizer = Optimizer(param_groups, **hparas)
    self.verbose(self.optimizer.create_msg())

    # Enable AMP if needed
    self.enable_apex()

    # Optional weight transfer from a source checkpoint
    if self.paras.transfer:
        self.transfer_weight()

    # Resume from a checkpoint only when a path was given
    if self.paras.load:
        self.load_ckpt()
def set_model(self):
    ''' Build the Prediction network, the RNNLM, the loss and a joint optimizer.

    Both networks are created from `self.config['model']` and trained by a
    single optimizer. When `self.paras.load` is set, the checkpoint is
    restored through `self.load_ckpt()` and then read a second time here to
    restore the model, optimizer and global step.
    '''
    model_cfg = self.config['model']

    # Networks
    self.model = Prediction(self.vocab_size, **model_cfg).to(self.device)
    self.rnnlm = RNNLM(self.vocab_size, **model_cfg).to(self.device)
    self.verbose(self.rnnlm.create_msg())

    # Loss (targets equal to 0 are ignored)
    self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)

    # One optimizer over the parameters of both networks
    trainable = [*self.model.parameters(), *self.rnnlm.parameters()]
    self.optimizer = Optimizer(trainable, **self.config['hparas'])

    # Enable AMP if needed
    self.enable_apex()

    # Nothing to restore
    if not self.paras.load:
        return

    # NOTE(review): the checkpoint is loaded twice (load_ckpt, then manually)
    # and self.rnnlm is never restored by the manual part - confirm intent.
    self.load_ckpt()
    state = torch.load(self.paras.load, map_location=self.device)
    self.model.load_state_dict(state['model'])
    self.optimizer.load_opt_state_dict(state['optimizer'])
    self.step = state['global_step']
    self.verbose('Load ckpt from {}, restarting at step {}'.format(
        self.paras.load, self.step))
def set_model(self):
    ''' Build the ASR network, losses, optional embedding plug-in and optimizer.

    Sets `self.model`, `self.seq_loss`, `self.ctc_loss`, `self.emb_fuse`,
    `self.emb_reg`, `self.optimizer` and `self.lr_scheduler` (plus
    `self.emb_decoder` when the 'emb' plug-in is enabled), then calls
    `self.load_ckpt()`.
    '''
    hparas = self.config['hparas']

    # Network: receives half of the configured corpus batch size
    half_batch = self.config['data']['corpus']['batch_size'] // 2
    self.model = ASR(self.feat_dim, self.vocab_size, half_batch,
                     **self.config['model']).to(self.device)
    self.verbose(self.model.create_msg())
    param_groups = [{'params': self.model.parameters()}]

    # Sequence loss: label smoothing when requested, cross entropy otherwise
    if not hparas['label_smoothing']:
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
    else:
        # NOTE(review): 31 and 0.1 are hard-coded; presumably the class count
        # and smoothing factor - confirm against LabelSmoothingLoss.
        self.seq_loss = LabelSmoothingLoss(31, 0.1)
        print('[INFO] using label smoothing. ')
    # Note: zero_infinity=False is unstable?
    self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

    # Plug-ins
    self.emb_fuse = False
    self.emb_reg = ('emb' in self.config) and (self.config['emb']['enable'])
    if self.emb_reg:
        from src.plugin import EmbeddingRegularizer
        regularizer = EmbeddingRegularizer(
            self.tokenizer, self.model.dec_dim, **self.config['emb'])
        self.emb_decoder = regularizer.to(self.device)
        param_groups.append({'params': self.emb_decoder.parameters()})
        self.emb_fuse = self.emb_decoder.apply_fuse
        if self.emb_fuse:
            # Fusion replaces whichever sequence loss was chosen above
            self.seq_loss = torch.nn.NLLLoss(ignore_index=0)
        self.verbose(self.emb_decoder.create_msg())

    # Optimizer and its scheduler
    self.optimizer = Optimizer(param_groups, **hparas)
    self.lr_scheduler = self.optimizer.lr_scheduler
    self.verbose(self.optimizer.create_msg())

    # Enable AMP if needed
    self.enable_apex()

    # Report the transfer-learning setup
    if self.transfer_learning:
        self.verbose('Apply transfer learning: ')
        self.verbose(' Train encoder layers: {}'.format(self.train_enc))
        self.verbose(' Train decoder: {}'.format(self.train_dec))
        self.verbose(' Save name: {}'.format(self.save_name))

    # Automatically load pre-trained model if self.paras.load is given
    self.load_ckpt()
def set_model(self):
    ''' Setup ASR model, optimizer and beam decoder.

    The feature dimension (120) and vocabulary size (46) are fixed here.
    When `self.finetune_first` is positive, only parameters whose name
    contains "encoder.layers.<i>" for i < finetune_first are optimized;
    otherwise every model parameter is.

    Sets `self.model`, `self.seq_loss`, `self.ctc_loss`, `self.emb_fuse`,
    `self.emb_reg`, `self.emb_decoder`, `self.optimizer` and `self.decoder`.
    '''
    # Model
    self.feat_dim = 120
    self.vocab_size = 46
    init_adadelta = True
    self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta,
                     **self.src_config['model']).to(self.device)
    self.verbose(self.model.create_msg())

    if self.finetune_first > 0:
        # NOTE(review): substring match, so "encoder.layers.1" also matches
        # "encoder.layers.10" - confirm the encoder has fewer than 10 layers.
        names = ["encoder.layers.%d" % i for i in range(self.finetune_first)]
        model_paras = [{"params": [p for n, p in self.model.named_parameters()
                                   if any(nd in n for nd in names)]}]
    else:
        model_paras = [{'params': self.model.parameters()}]

    # Losses
    self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
    # Note: zero_infinity=False is unstable?
    self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

    # Plug-ins
    self.emb_fuse = False
    self.emb_reg = ('emb' in self.config) and (
        self.config['emb']['enable'])
    if self.emb_reg:
        from src.plugin import EmbeddingRegularizer
        self.emb_decoder = EmbeddingRegularizer(
            self.tokenizer, self.model.dec_dim,
            **self.config['emb']).to(self.device)
        model_paras.append({'params': self.emb_decoder.parameters()})
        self.emb_fuse = self.emb_decoder.apply_fuse
        if self.emb_fuse:
            self.seq_loss = torch.nn.NLLLoss(ignore_index=0)
        self.verbose(self.emb_decoder.create_msg())
    elif not hasattr(self, 'emb_decoder'):
        # Without this default the BeamDecoder construction below raised
        # AttributeError whenever the embedding plug-in was disabled.
        # NOTE(review): assumes BeamDecoder accepts None as "no embedding
        # decoder" - confirm.
        self.emb_decoder = None

    # Optimizer
    self.optimizer = Optimizer(model_paras, **self.src_config['hparas'])
    self.verbose(self.optimizer.create_msg())

    # Enable AMP if needed
    self.enable_apex()

    # Automatically load pre-trained model if self.paras.load is given
    self.load_ckpt()

    # Beam decoder
    self.decoder = BeamDecoder(
        self.model, self.emb_decoder, **self.config['decode'])
    self.verbose(self.decoder.create_msg())
    self.decoder.to(self.device)
def set_model(self):
    ''' Setup the SSL feature extractor, the classifier, loss and optimizer.

    The extractor class (APC or NPC) is chosen from `self.feature` and its
    weights are loaded from `self.config['model']['feat']['ckpt']`. Only
    the classifier's parameters are handed to the optimizer.

    Raises:
        NotImplementedError: if `self.feature` is not 'apc', 'vqapc' or 'npc'.
    '''
    feat_ckpt_path = self.config['model']['feat']['ckpt']

    # Load SSL models for feature extraction
    self.verbose([' Load feat. extractor ckpt from ' + feat_ckpt_path])
    if self.feature in ['apc', 'vqapc']:
        from model.apc import APC as Net
    elif self.feature == 'npc':
        from model.npc import NPC as Net
        if self.feat_spec is not None:
            self.verbose([' Using specific feature: ' + self.feat_spec])
    else:
        raise NotImplementedError

    self.feat_extractor = Net(input_size=self.audio_dim,
                              **self.ssl_config['model']['paras'])
    ckpt = torch.load(
        feat_ckpt_path,
        map_location=self.device if self.mode == 'train' else 'cpu')
    # Drop the first 'module.' from each key before loading
    # (presumably left by a DataParallel wrapper - TODO confirm)
    ckpt['model'] = {k.replace('module.', '', 1): v
                     for k, v in ckpt['model'].items()}
    self.feat_extractor.load_state_dict(ckpt['model'])

    # Classifier model
    self.model = CLF(feat_dim=self.feat_extractor.code_dim,
                     **self.config['model']['clf'])
    if self.gpu:
        self.feat_extractor = self.feat_extractor.cuda()
        self.model = self.model.cuda()
    # The extractor is not optimized here, so keep it in eval mode on every
    # device (previously eval() was only reached when running on GPU).
    self.feat_extractor.eval()
    model_paras = [{'params': self.model.parameters()}]

    # Losses: the ignored label index depends on the task
    ignore_idx = 0 if self.task == 'phn-clf' else -1
    self.loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_idx)
    if self.gpu:
        self.loss = self.loss.cuda()

    # Optimizer
    self.optimizer = Optimizer(model_paras, **self.config['hparas'])
    self.verbose(self.optimizer.create_msg())

    self.load_ckpt()
    self.model.train()
def set_model(self):
    ''' Build the VQVAE audio model, its objectives and its optimizer.

    Sets `self.model`, `self.n_frames_per_step`, `self.freq_loss`,
    `self.ctc_loss`, `self.stop_loss` and `self.optimizer`, reports the
    loss weights, and restores a checkpoint when `self.paras.load` is set.
    '''
    hparas = self.config['hparas']

    def report_weight(label, weight, start_step):
        # One summary line per objective: its weight and first active step
        self.verbose(' | {} weight = {}\t| start step = {}'.format(
            label, weight, start_step))

    # Model
    vqvae = VQVAE(self.n_mels, self.linear_dim, self.vocab_size,
                  self.n_spkr, **self.config['model'])
    self.model = vqvae.to(self.device)
    self.n_frames_per_step = self.model.n_frames_per_step
    self.verbose(self.model.create_msg())

    # Objective
    freq_loss_kwargs = {
        'sample_rate': self.audio_converter.sr,
        'n_mels': self.audio_converter.n_mels,
        'loss': hparas['freq_loss_type'],
        'differential_loss': hparas['differential_loss'],
        'emphasize_linear_low': hparas['emphasize_linear_low'],
    }
    self.freq_loss = partial(freq_loss, **freq_loss_kwargs)
    self.ctc_loss = torch.nn.CTCLoss()
    self.stop_loss = torch.nn.BCEWithLogitsLoss()

    # Optimizer
    self.optimizer = Optimizer(self.model.parameters(), **hparas)
    self.verbose(self.optimizer.create_msg())

    # ToDo : unsup first?
    report_weight('ASR', self.asr_weight, 0)
    report_weight('TTS', self.tts_weight, 0)
    report_weight('Txt', self.unpair_text_weight,
                  self.unpair_text_start_step)
    report_weight('Sph', self.unpair_speech_weight,
                  self.unpair_speech_start_step)

    # ToDo: load pre-trained model
    if not self.paras.load:
        return
    state = torch.load(self.paras.load, map_location=self.device)
    self.model.load_state_dict(state['model'])
    self.optimizer.load_opt_state_dict(state['optimizer'])
    self.step = state['global_step']
    self.verbose('Load ckpt from {}, restarting at step {}'.format(
        self.paras.load, self.step))
def set_model(self):
    ''' Setup ASR model and optimizer.

    Sets `self.model`, `self.seq_loss`, `self.ctc_loss`, `self.emb_fuse`,
    `self.emb_reg` and `self.optimizer` (plus `self.emb_decoder` when the
    'emb' plug-in is enabled), then loads a checkpoint. When no checkpoint
    path was supplied in `self.paras.load`, the example checkpoint
    'ckpt/asr_example_sd0/best_att.pth' is used.
    '''
    # Model
    init_adadelta = self.config['hparas']['optimizer'] == 'Adadelta'
    self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta,
                     **self.config['model']).to(self.device)
    self.verbose(self.model.create_msg())
    model_paras = [{'params': self.model.parameters()}]

    # Losses
    self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
    # Note: zero_infinity=False is unstable?
    self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

    # Plug-ins
    self.emb_fuse = False
    self.emb_reg = ('emb' in self.config) and (self.config['emb']['enable'])
    if self.emb_reg:
        from src.plugin import EmbeddingRegularizer
        self.emb_decoder = EmbeddingRegularizer(
            self.tokenizer, self.model.dec_dim,
            **self.config['emb']).to(self.device)
        model_paras.append({'params': self.emb_decoder.parameters()})
        self.emb_fuse = self.emb_decoder.apply_fuse
        if self.emb_fuse:
            self.seq_loss = torch.nn.NLLLoss(ignore_index=0)
        self.verbose(self.emb_decoder.create_msg())

    # Optimizer
    self.optimizer = Optimizer(model_paras, **self.config['hparas'])
    self.verbose(self.optimizer.create_msg())

    # Enable AMP if needed
    self.enable_apex()

    # Use the example checkpoint only as a fallback: the path used to be
    # assigned unconditionally, discarding any checkpoint the user asked for.
    if not self.paras.load:
        self.paras.load = 'ckpt/asr_example_sd0/best_att.pth'

    # Automatically load pre-trained model if self.paras.load is given
    self.load_ckpt()
def set_model(self):
    ''' Build the ASR network, both losses and the optimizer.

    Sets `self.model`, `self.seq_loss`, `self.ctc_loss` and
    `self.optimizer`, then calls `self.load_ckpt()`.
    '''
    # Network
    network = ASR(self.feat_dim, self.vocab_size, **self.config['model'])
    self.model = network.to(self.device)
    self.verbose(self.model.create_msg())

    # Objectives: index 0 is ignored by the sequence loss and is the CTC blank
    self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
    # Note: zero_infinity=False is unstable?
    self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

    # Optimizer over a single parameter group
    param_groups = [{'params': self.model.parameters()}]
    self.optimizer = Optimizer(param_groups, **self.config['hparas'])
    self.verbose(self.optimizer.create_msg())

    # Enable AMP if needed
    self.enable_apex()

    # Automatically load pre-trained model if self.paras.load is given
    self.load_ckpt()
def set_model(self):
    ''' Build the APC/NPC network, its L1 loss and its optimizer.

    `self.config['model']['method']` selects the network class. For
    'apc'/'vqapc', `self.n_future` is also read from the config.

    Raises:
        NotImplementedError: for any other method name.
    '''
    model_cfg = self.config['model']
    self.method = model_cfg['method']

    # Pick the network class
    if self.method == 'npc':
        from model.npc import NPC as Net
    elif self.method in ['apc', 'vqapc']:
        self.n_future = model_cfg['n_future']
        from model.apc import APC as Net
    else:
        raise NotImplementedError

    self.model = Net(input_size=self.audio_dim, **model_cfg['paras'])
    if self.gpu:
        self.model = self.model.cuda()
    self.verbose(self.model.create_msg())

    # Loss: NPC keeps per-element values so zero-padding can be excluded
    # later; the APC family uses the default mean reduction.
    reduction = 'none' if 'npc' in self.method else 'mean'
    self.loss = torch.nn.L1Loss(reduction=reduction)
    if self.gpu:
        self.loss = self.loss.cuda()

    # Optimizer
    param_groups = [{'params': self.model.parameters()}]
    self.optimizer = Optimizer(param_groups, **self.config['hparas'])
    self.verbose(self.optimizer.create_msg())

    # Automatically load pre-trained model if self.paras.load is given
    self.load_ckpt()

    # ToDo: Data Parallel?
    # self.model = torch.nn.DataParallel(self.model)
    self.model.train()
def __init__(
        self,
        attack_configs,
        discriminator_configs,
        src_vocab,
        trg_vocab,
        data_iterator,
        save_to,
        device="cpu",
):
    """
    Initiate the translation environment; needs a discriminator and a translator.

    :param attack_configs: attack configuration dictionary; the keys read here
        are "victim_configs", "victim_model", "init_perturb_rate",
        "adversarial", "r_s_weight" and "r_d_weight"
    :param discriminator_configs: dictionary holding
        "discriminator_model_configs" and "discriminator_optimizer_configs"
    :param src_vocab: source vocabulary
    :param trg_vocab: target vocabulary
    :param data_iterator: provides data for the environment
    :param save_to: directory under which the near-vocabulary files
        ("near_vocab", "full_near_vocab") are stored
    :param device: (string) device to allocate variables on
        ("cpu", "cuda:*"), default as cpu
    """
    self.data_iterator = data_iterator
    discriminator_model_configs = discriminator_configs[
        "discriminator_model_configs"]
    discriminator_optim_configs = discriminator_configs[
        "discriminator_optimizer_configs"]
    self.victim_config_path = attack_configs["victim_configs"]
    self.victim_model_path = attack_configs["victim_model"]

    # determine devices
    self.device = device

    with open(self.victim_config_path.strip()) as v_f:
        print("open victim configs...%s" % self.victim_config_path)
        # safe_load: yaml.load() without a Loader raises TypeError on
        # PyYAML >= 6 and can construct arbitrary objects on older versions.
        victim_configs = yaml.safe_load(v_f)

    self.src_vocab = src_vocab
    self.trg_vocab = trg_vocab

    # The victim translation model is only evaluated, never trained here
    self.translate_model = build_translate_model(victim_configs,
                                                 self.victim_model_path,
                                                 vocab_src=self.src_vocab,
                                                 vocab_trg=self.trg_vocab,
                                                 device=self.device)
    self.translate_model.eval()

    self.w2p, self.w2vocab = load_or_extract_near_vocab(
        config_path=self.victim_config_path,
        model_path=self.victim_model_path,
        init_perturb_rate=attack_configs["init_perturb_rate"],
        save_to=os.path.join(save_to, "near_vocab"),
        save_to_full=os.path.join(save_to, "full_near_vocab"),
        top_reserve=12,
        emit_as_id=True)

    #########################################################
    # to update discriminator
    self.discriminator = TransDiscriminator(
        n_src_words=self.src_vocab.max_n_words,
        n_trg_words=self.trg_vocab.max_n_words,
        **discriminator_model_configs)
    self.discriminator.to(self.device)
    load_embedding(self.discriminator,
                   model_path=self.victim_model_path,
                   device=self.device)

    self.optim_D = Optimizer(
        name=discriminator_optim_configs["optimizer"],
        model=self.discriminator,
        lr=discriminator_optim_configs["learning_rate"],
        grad_clip=discriminator_optim_configs["grad_clip"],
        optim_args=discriminator_optim_configs["optimizer_params"])
    self.criterion_D = nn.CrossEntropyLoss()  # used in discriminator updates

    # Learning-rate scheduler; stays None when no (or an unknown) method is set
    self.scheduler_D = None
    if discriminator_optim_configs['schedule_method'] is not None:
        if discriminator_optim_configs['schedule_method'] == "loss":
            self.scheduler_D = ReduceOnPlateauScheduler(
                optimizer=self.optim_D,
                **discriminator_optim_configs["scheduler_configs"])
        elif discriminator_optim_configs['schedule_method'] == "noam":
            self.scheduler_D = NoamScheduler(
                optimizer=self.optim_D,
                **discriminator_optim_configs['scheduler_configs'])
        elif discriminator_optim_configs["schedule_method"] == "rsqrt":
            self.scheduler_D = RsqrtScheduler(
                optimizer=self.optim_D,
                **discriminator_optim_configs["scheduler_configs"])
        else:
            WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".
                 format(discriminator_optim_configs['schedule_method']))
    ############################################################

    self._init_state()
    # adversarial sample or reinforced samples
    self.adversarial = attack_configs["adversarial"]
    self.r_s_weight = attack_configs["r_s_weight"]
    self.r_d_weight = attack_configs["r_d_weight"]
class Solver(BaseSolver):
    ''' Solver for training a CTC-only ASR model '''

    def __init__(self, config, paras, mode):
        super().__init__(config, paras, mode)
        # Logger settings: best validation error rates seen so far
        self.best_wer = {'ctc': 3.0}
        self.best_per = {'ctc': 3.0}
        # Curriculum learning affects data loader
        self.curriculum = self.config['hparas']['curriculum']

    def load_data(self):
        ''' Load data for training/validation, store tokenizer and input/output shape'''
        self.tr_set, self.dv_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \
            load_dataset(self.paras.njobs, self.paras.gpu,
                         self.paras.pin_memory, self.curriculum > 0,
                         **self.config['data'])
        self.verbose(msg)

    def transfer_weight(self):
        ''' Initialise the model from the checkpoint named by 'src_ckpt'.

        Note: 'src_ckpt' is popped from self.config['data']['transfer'], so
        the config is modified and this method can only be called once.
        '''
        ckpt_path = self.config['data']['transfer'].pop('src_ckpt')
        ckpt = torch.load(ckpt_path, map_location=self.device)

        # Load weights (the optimizer state is not transferred)
        msg = self.model.transfer_with_mapping(ckpt,
                                               self.config['data']['transfer'],
                                               self.tokenizer)
        del ckpt
        self.verbose(msg)

    def set_model(self):
        ''' Setup ASR model and optimizer '''
        # Model
        init_adadelta = self.config['hparas']['optimizer'] == 'Adadelta'
        self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta,
                         **self.config['model']).to(self.device)
        self.verbose(self.model.create_msg())
        model_paras = [{'params': self.model.parameters()}]

        # Losses
        # Note: zero_infinity=False is unstable?
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

        # Error-rate target: 'phone' for IPA corpora, 'char' otherwise
        self.eval_target = 'phone' if self.config['data']['corpus'][
            'target'] == 'ipa' else 'char'

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Enable AMP if needed
        self.enable_apex()

        if self.paras.transfer:
            self.transfer_weight()

        # Automatically load pre-trained model if self.paras.load is given
        if self.paras.load:
            self.load_ckpt()

        # ToDo: other training methods

    def exec(self):
        ''' Train the CTC ASR model until self.step reaches self.max_step '''
        self.verbose('Total training steps {}.'.format(
            human_format(self.max_step)))
        ctc_loss = None
        n_epochs = 0
        self.timer.set()

        while self.step < self.max_step:
            # Renew dataloader to enable random sampling
            if self.curriculum > 0 and n_epochs == self.curriculum:
                self.verbose(
                    'Curriculum learning ends after {} epochs, starting random sampling.'
                    .format(n_epochs))
                self.tr_set, _, _, _, _, _ = \
                    load_dataset(self.paras.njobs, self.paras.gpu,
                                 self.paras.pin_memory, False,
                                 **self.config['data'])

            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad
                # (the returned rate is not used by this CTC-only solver)
                self.optimizer.pre_step(self.step)

                # Fetch data
                feat, feat_len, txt, txt_len = self.fetch_data(data)
                self.timer.cnt('rd')

                # Forward model
                # Note: txt should NOT start w/ <sos>
                ctc_output, encode_len = self.model(feat, feat_len)

                # Compute all objectives
                if self.paras.cudnn_ctc:
                    ctc_loss = self.ctc_loss(
                        ctc_output.transpose(0, 1),
                        txt.to_sparse().values().to(device='cpu',
                                                    dtype=torch.int32),
                        [ctc_output.shape[1]] * len(ctc_output),
                        txt_len.cpu().tolist())
                else:
                    ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1), txt,
                                             encode_len, txt_len)
                total_loss = ctc_loss
                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(total_loss)
                self.step += 1

                # Logger
                if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                    self.progress(
                        'Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'.
                        format(total_loss.cpu().item(), grad_norm,
                               self.timer.show()))
                    # Greedy (argmax) predictions cut to each encoded length
                    ctc_output = [
                        x[:length].argmax(dim=-1)
                        for x, length in zip(ctc_output, encode_len)
                    ]
                    self.write_log(
                        'per', {
                            'tr_ctc':
                            cal_er(self.tokenizer, ctc_output, txt,
                                   mode='per', ctc=True)
                        })
                    self.write_log(
                        'wer', {
                            'tr_ctc':
                            cal_er(self.tokenizer, ctc_output, txt,
                                   mode='wer', ctc=True)
                        })
                    self.write_log('loss', {'tr_ctc': ctc_loss.cpu().item()})

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    self.validate()

                # End of step
                # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
                torch.cuda.empty_cache()
                self.timer.set()
                # '>=' matches the while-condition; with '>' one extra update
                # (max_step + 1 in total) was performed before stopping.
                if self.step >= self.max_step:
                    break
            n_epochs += 1

    def validate(self):
        ''' Evaluate on the dev set, log error rates and save checkpoints.

        Saves 'best_per.pth' / 'best_wer.pth' on improvement (WER only when
        the evaluation target is 'char'), always saves 'latest.pth', and
        saves a per-step checkpoint when self.paras.save_every is set.
        '''
        # Eval mode
        self.model.eval()
        dev_per = {'ctc': []}
        dev_wer = {'ctc': []}

        for i, data in enumerate(self.dv_set):
            self.progress('Valid step - {}/{}'.format(i + 1, len(self.dv_set)))
            # Fetch data
            feat, feat_len, txt, txt_len = self.fetch_data(data)

            # Forward model
            with torch.no_grad():
                ctc_output, encode_len = self.model(feat, feat_len)

            ctc_output = [
                x[:length].argmax(dim=-1)
                for x, length in zip(ctc_output, encode_len)
            ]
            dev_per['ctc'].append(
                cal_er(self.tokenizer, ctc_output, txt, mode='per', ctc=True))
            dev_wer['ctc'].append(
                cal_er(self.tokenizer, ctc_output, txt, mode='wer', ctc=True))

            # Show some example on tensorboard (from the middle batch only)
            if i == len(self.dv_set) // 2:
                # A separate index keeps the batch counter `i` untouched
                for ex in range(min(len(txt), self.DEV_N_EXAMPLE)):
                    self.write_log('true_text{}'.format(ex),
                                   self.tokenizer.decode(txt[ex].tolist()))
                    self.write_log(
                        'ctc_text{}'.format(ex),
                        self.tokenizer.decode(ctc_output[ex].tolist(),
                                              ignore_repeat=True))

        # Ckpt if performance improves
        for task in ['ctc']:
            dev_wer[task] = sum(dev_wer[task]) / len(dev_wer[task])
            dev_per[task] = sum(dev_per[task]) / len(dev_per[task])
            if dev_per[task] < self.best_per[task]:
                self.best_per[task] = dev_per[task]
                self.save_checkpoint('best_{}.pth'.format('per'), 'per',
                                     dev_per[task])
                self.log.log_other('dv_best_per', self.best_per['ctc'])
            if self.eval_target == 'char' and dev_wer[task] < self.best_wer[
                    task]:
                self.best_wer[task] = dev_wer[task]
                self.save_checkpoint('best_{}.pth'.format('wer'), 'wer',
                                     dev_wer[task])
                self.log.log_other('dv_best_wer', self.best_wer['ctc'])

            self.write_log('per', {'dv_' + task: dev_per[task]})
            if self.eval_target == 'char':
                self.write_log('wer', {'dv_' + task: dev_wer[task]})

        self.save_checkpoint('latest.pth', 'per', dev_per['ctc'],
                             show_msg=False)
        if self.paras.save_every:
            # '.pth' like every other checkpoint (was the typo '.path')
            self.save_checkpoint(f'{self.step}.pth', 'per', dev_per['ctc'],
                                 show_msg=False)

        # Resume training
        self.model.train()
class Solver(BaseSolver):
    ''' Solver for training an APC/NPC model '''

    def __init__(self, config, paras):
        super().__init__(config, paras)
        # Logger settings
        self.val_loss = 1000
        self.cur_epoch = 0

    def fetch_data(self, data):
        ''' Unpack one batch and move the audio features to GPU if enabled '''
        file_id, audio_feat, audio_len = data
        if not self.gpu:
            return file_id, audio_feat, audio_len
        return file_id, audio_feat.cuda(), audio_len

    def load_data(self):
        ''' Build the training/validation loaders and record the audio dimension '''
        loaded = prepare_data(self.paras.njobs, self.paras.dev_njobs,
                              self.paras.gpu, self.paras.pin_memory,
                              **self.config['data'])
        self.tr_set, self.dv_set, _, self.audio_dim, msg = loaded
        self.verbose(msg)

    def set_model(self):
        ''' Build the network chosen by config['model']['method'], loss and optimizer '''
        model_cfg = self.config['model']
        self.method = model_cfg['method']

        # Pick the network class
        if self.method == 'npc':
            from model.npc import NPC as Net
        elif self.method in ['apc', 'vqapc']:
            self.n_future = model_cfg['n_future']
            from model.apc import APC as Net
        else:
            raise NotImplementedError

        self.model = Net(input_size=self.audio_dim, **model_cfg['paras'])
        if self.gpu:
            self.model = self.model.cuda()
        self.verbose(self.model.create_msg())

        # Loss: NPC keeps per-element values so zero-padding can be excluded
        # later; the APC family uses the default mean reduction.
        reduction = 'none' if 'npc' in self.method else 'mean'
        self.loss = torch.nn.L1Loss(reduction=reduction)
        if self.gpu:
            self.loss = self.loss.cuda()

        # Optimizer
        param_groups = [{'params': self.model.parameters()}]
        self.optimizer = Optimizer(param_groups, **self.config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Automatically load pre-trained model if self.paras.load is given
        self.load_ckpt()

        # ToDo: Data Parallel?
        # self.model = torch.nn.DataParallel(self.model)
        self.model.train()

    def _valid_part_loss(self, elementwise_loss, audio_len):
        ''' Reduce an unreduced loss over the first `a_len` steps of each item.

        For every item the loss is averaged over the last axis and summed over
        its first `a_len` positions; the total is divided by sum(audio_len).
        '''
        total = sum(elementwise_loss[idx, :a_len, :].mean(dim=-1).sum()
                    for idx, a_len in enumerate(audio_len))
        return total / sum(audio_len)

    def exec(self):
        ''' Run training for self.epoch epochs, validating after each one '''
        self.verbose('Total training epoch {}.'.format(
            human_format(self.epoch)))
        self.timer.set()
        ep_len = len(self.tr_set)

        for ep in range(self.epoch):
            # Decay before every epoch except the first
            if ep > 0:
                self.optimizer.decay()

            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad
                self.optimizer.pre_step(self.step)

                # Fetch data
                _, audio_feat, audio_len = self.fetch_data(data)
                self.timer.cnt('rd')

                # Forward real data
                if 'npc' in self.method:
                    # NPC: input = target, loss on valid part only
                    pred, _ = self.model(audio_feat)
                    loss = self._valid_part_loss(
                        self.loss(pred, audio_feat), audio_len)
                else:
                    # APC: input = shifted target
                    audio_len = [l - self.n_future for l in audio_len]
                    model_in = audio_feat[:, :-self.n_future, :]
                    target = audio_feat[:, self.n_future:, :]
                    pred, _ = self.model(model_in, audio_len, testing=False)
                    loss = self.loss(pred, target)
                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(loss)
                self.step += 1

                # Logger
                on_first_step = self.step == 1
                if on_first_step or (self.step % self.PROGRESS_STEP == 0):
                    percent = 100 * float(self.step % ep_len) / ep_len
                    self.progress(
                        ' {:2.1f} % | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                        .format(percent, loss.cpu().item(), grad_norm,
                                self.timer.show()))
                    self.write_log('loss', {'tr': loss})

                if on_first_step or (self.step % self.PLOT_STEP == 0):
                    # Perplexity of P(token)
                    g1_ppx, g2_ppx = self.model.report_ppx()
                    self.write_log('ppx',
                                   {'group 1': g1_ppx, 'group 2': g2_ppx})
                    g1_usg, g2_usg = self.model.report_usg()
                    # Plots
                    if self.paras.draw:
                        g1_hist = draw(g1_usg, hist=True)
                        g2_hist = draw(g2_usg, hist=True)
                        self.write_log('VQ Group 1 Hist.', g1_hist)
                        self.write_log('VQ Group 2 Hist.', g2_hist)
                        # Some spectrograms (first item of the batch)
                        plt_idx = 0
                        self.write_log('Spectrogram (raw)',
                                       draw(audio_feat[plt_idx]))
                        self.write_log('Spectrogram (pred)',
                                       draw(pred[plt_idx]))

                # End of step
                self.timer.set()

            # End of epoch
            self.cur_epoch += 1
            self.validate()

        self.log.close()

    def validate(self):
        ''' Compute the mean dev loss, log it and checkpoint on improvement '''
        # Eval mode
        self.model.eval()
        batch_losses = []
        n_batches = len(self.dv_set)

        for batch_idx, data in enumerate(self.dv_set):
            self.progress('Valid step - {}/{}'.format(batch_idx + 1,
                                                      n_batches))
            # Fetch data
            _, audio_feat, audio_len = self.fetch_data(data)

            # Forward model
            with torch.no_grad():
                if 'npc' in self.method:
                    pred, _ = self.model(audio_feat, testing=True)
                    loss = self._valid_part_loss(
                        self.loss(pred, audio_feat), audio_len)
                else:
                    audio_len = [l - self.n_future for l in audio_len]
                    model_in = audio_feat[:, :-self.n_future, :]
                    target = audio_feat[:, self.n_future:, :]
                    pred, _ = self.model(model_in, audio_len, testing=True)
                    loss = self.loss(pred, target)
                batch_losses.append(loss.cpu().item())

        # Record metric
        dev_loss = sum(batch_losses) / len(batch_losses)
        self.write_log('loss', {'dev': dev_loss})
        if dev_loss < self.val_loss:
            self.val_loss = dev_loss
            self.save_checkpoint('best_loss.pth', 'loss', dev_loss)

        # Resume training
        self.model.train()
def train(FLAGS): """ FLAGS: saveto: str reload: store_true config_path: str pretrain_path: str, default="" model_name: str log_path: str """ # write log of training to file. write_log_to_file(os.path.join(FLAGS.log_path, "%s.log" % time.strftime("%Y%m%d-%H%M%S"))) GlobalNames.USE_GPU = FLAGS.use_gpu if GlobalNames.USE_GPU: CURRENT_DEVICE = "cpu" else: CURRENT_DEVICE = "cuda:0" config_path = os.path.abspath(FLAGS.config_path) with open(config_path.strip()) as f: configs = yaml.load(f) INFO(pretty_configs(configs)) # Add default configs configs = default_configs(configs) data_configs = configs['data_configs'] model_configs = configs['model_configs'] optimizer_configs = configs['optimizer_configs'] training_configs = configs['training_configs'] GlobalNames.SEED = training_configs['seed'] set_seed(GlobalNames.SEED) best_model_prefix = os.path.join(FLAGS.saveto, FLAGS.model_name + GlobalNames.MY_BEST_MODEL_SUFFIX) timer = Timer() # ================================================================================== # # Load Data INFO('Loading data...') timer.tic() # Generate target dictionary vocab_src = Vocabulary(**data_configs["vocabularies"][0]) vocab_tgt = Vocabulary(**data_configs["vocabularies"][1]) train_batch_size = training_configs["batch_size"] * max(1, training_configs["update_cycle"]) train_buffer_size = training_configs["buffer_size"] * max(1, training_configs["update_cycle"]) train_bitext_dataset = ZipDataset( TextLineDataset(data_path=data_configs['train_data'][0], vocabulary=vocab_src, max_len=data_configs['max_len'][0], ), TextLineDataset(data_path=data_configs['train_data'][1], vocabulary=vocab_tgt, max_len=data_configs['max_len'][1], ), shuffle=training_configs['shuffle'] ) valid_bitext_dataset = ZipDataset( TextLineDataset(data_path=data_configs['valid_data'][0], vocabulary=vocab_src, ), TextLineDataset(data_path=data_configs['valid_data'][1], vocabulary=vocab_tgt, ) ) training_iterator = DataIterator(dataset=train_bitext_dataset, 
batch_size=train_batch_size, use_bucket=training_configs['use_bucket'], buffer_size=train_buffer_size, batching_func=training_configs['batching_key']) valid_iterator = DataIterator(dataset=valid_bitext_dataset, batch_size=training_configs['valid_batch_size'], use_bucket=True, buffer_size=100000, numbering=True) bleu_scorer = SacreBLEUScorer(reference_path=data_configs["bleu_valid_reference"], num_refs=data_configs["num_refs"], lang_pair=data_configs["lang_pair"], sacrebleu_args=training_configs["bleu_valid_configs"]['sacrebleu_args'], postprocess=training_configs["bleu_valid_configs"]['postprocess'] ) INFO('Done. Elapsed time {0}'.format(timer.toc())) lrate = optimizer_configs['learning_rate'] is_early_stop = False # ================================ Begin ======================================== # # Build Model & Optimizer # We would do steps below on after another # 1. build models & criterion # 2. move models & criterion to gpu if needed # 3. load pre-trained model if needed # 4. build optimizer # 5. build learning rate scheduler if needed # 6. load checkpoints if needed # 0. Initial model_collections = Collections() checkpoint_saver = Saver(save_prefix="{0}.ckpt".format(os.path.join(FLAGS.saveto, FLAGS.model_name)), num_max_keeping=training_configs['num_kept_checkpoints'] ) best_model_saver = Saver(save_prefix=best_model_prefix, num_max_keeping=training_configs['num_kept_best_model']) # 1. Build Model & Criterion INFO('Building model...') timer.tic() nmt_model = build_model(n_src_vocab=vocab_src.max_n_words, n_tgt_vocab=vocab_tgt.max_n_words, **model_configs) INFO(nmt_model) critic = NMTCriterion(label_smoothing=model_configs['label_smoothing']) INFO(critic) INFO('Done. Elapsed time {0}'.format(timer.toc())) # 2. Move to GPU if GlobalNames.USE_GPU: nmt_model = nmt_model.cuda() critic = critic.cuda() # 3. Load pretrained model if needed load_pretrained_model(nmt_model, FLAGS.pretrain_path, exclude_prefix=None, device=CURRENT_DEVICE) # 4. 
Build optimizer INFO('Building Optimizer...') optim = Optimizer(name=optimizer_configs['optimizer'], model=nmt_model, lr=lrate, grad_clip=optimizer_configs['grad_clip'], optim_args=optimizer_configs['optimizer_params'] ) # 5. Build scheduler for optimizer if needed if optimizer_configs['schedule_method'] is not None: if optimizer_configs['schedule_method'] == "loss": scheduler = ReduceOnPlateauScheduler(optimizer=optim, **optimizer_configs["scheduler_configs"] ) elif optimizer_configs['schedule_method'] == "noam": scheduler = NoamScheduler(optimizer=optim, **optimizer_configs['scheduler_configs']) else: WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".format(optimizer_configs['schedule_method'])) scheduler = None else: scheduler = None # 6. build EMA if training_configs['ema_decay'] > 0.0: ema = ExponentialMovingAverage(named_params=nmt_model.named_parameters(), decay=training_configs['ema_decay']) else: ema = None INFO('Done. Elapsed time {0}'.format(timer.toc())) # Reload from latest checkpoint if FLAGS.reload: checkpoint_saver.load_latest(model=nmt_model, optim=optim, lr_scheduler=scheduler, collections=model_collections) # ================================================================================== # # Prepare training eidx = model_collections.get_collection("eidx", [0])[-1] uidx = model_collections.get_collection("uidx", [0])[-1] bad_count = model_collections.get_collection("bad_count", [0])[-1] summary_writer = SummaryWriter(log_dir=FLAGS.log_path) cum_samples = 0 cum_words = 0 best_valid_loss = 1.0 * 1e10 # Max Float saving_files = [] # Timer for computing speed timer_for_speed = Timer() timer_for_speed.tic() INFO('Begin training...') while True: summary_writer.add_scalar("Epoch", (eidx + 1), uidx) # Build iterator and progress bar training_iter = training_iterator.build_generator() training_progress_bar = tqdm(desc=' - (Epoch %d) ' % eidx, total=len(training_iterator), unit="sents" ) for batch in training_iter: uidx += 1 if scheduler is 
None: pass elif optimizer_configs["schedule_method"] == "loss": scheduler.step(metric=best_valid_loss) else: scheduler.step(global_step=uidx) seqs_x, seqs_y = batch n_samples_t = len(seqs_x) n_words_t = sum(len(s) for s in seqs_y) cum_samples += n_samples_t cum_words += n_words_t training_progress_bar.update(n_samples_t) optim.zero_grad() # Prepare data for seqs_x_t, seqs_y_t in split_shard(seqs_x, seqs_y, split_size=training_configs['update_cycle']): x, y = prepare_data(seqs_x_t, seqs_y_t, cuda=GlobalNames.USE_GPU) loss = compute_forward(model=nmt_model, critic=critic, seqs_x=x, seqs_y=y, eval=False, normalization=n_samples_t, norm_by_words=training_configs["norm_by_words"]) optim.step() if ema is not None: ema.step() # ================================================================================== # # Display some information if should_trigger_by_steps(uidx, eidx, every_n_step=training_configs['disp_freq']): # words per second and sents per second words_per_sec = cum_words / (timer.toc(return_seconds=True)) sents_per_sec = cum_samples / (timer.toc(return_seconds=True)) lrate = list(optim.get_lrate())[0] summary_writer.add_scalar("Speed(words/sec)", scalar_value=words_per_sec, global_step=uidx) summary_writer.add_scalar("Speed(sents/sen)", scalar_value=sents_per_sec, global_step=uidx) summary_writer.add_scalar("lrate", scalar_value=lrate, global_step=uidx) # Reset timer timer.tic() cum_words = 0 cum_samples = 0 # ================================================================================== # # Saving checkpoints if should_trigger_by_steps(uidx, eidx, every_n_step=training_configs['save_freq'], debug=FLAGS.debug): model_collections.add_to_collection("uidx", uidx) model_collections.add_to_collection("eidx", eidx) model_collections.add_to_collection("bad_count", bad_count) if not is_early_stop: checkpoint_saver.save(global_step=uidx, model=nmt_model, optim=optim, lr_scheduler=scheduler, collections=model_collections, ema=ema) # 
================================================================================== # # Loss Validation & Learning rate annealing if should_trigger_by_steps(global_step=uidx, n_epoch=eidx, every_n_step=training_configs['loss_valid_freq'], debug=FLAGS.debug): if ema is not None: origin_state_dict = deepcopy(nmt_model.state_dict()) nmt_model.load_state_dict(ema.state_dict(), strict=False) valid_loss = loss_validation(model=nmt_model, critic=critic, valid_iterator=valid_iterator, ) model_collections.add_to_collection("history_losses", valid_loss) min_history_loss = np.array(model_collections.get_collection("history_losses")).min() summary_writer.add_scalar("loss", valid_loss, global_step=uidx) summary_writer.add_scalar("best_loss", min_history_loss, global_step=uidx) best_valid_loss = min_history_loss if ema is not None: nmt_model.load_state_dict(origin_state_dict) del origin_state_dict # ================================================================================== # # BLEU Validation & Early Stop if should_trigger_by_steps(global_step=uidx, n_epoch=eidx, every_n_step=training_configs['bleu_valid_freq'], min_step=training_configs['bleu_valid_warmup'], debug=FLAGS.debug): if ema is not None: origin_state_dict = deepcopy(nmt_model.state_dict()) nmt_model.load_state_dict(ema.state_dict(), strict=False) valid_bleu = bleu_validation(uidx=uidx, valid_iterator=valid_iterator, batch_size=training_configs["bleu_valid_batch_size"], model=nmt_model, bleu_scorer=bleu_scorer, vocab_tgt=vocab_tgt, valid_dir=FLAGS.valid_path, max_steps=training_configs["bleu_valid_configs"]["max_steps"], beam_size=training_configs["bleu_valid_configs"]["beam_size"], alpha=training_configs["bleu_valid_configs"]["alpha"] ) model_collections.add_to_collection(key="history_bleus", value=valid_bleu) best_valid_bleu = float(np.array(model_collections.get_collection("history_bleus")).max()) summary_writer.add_scalar("bleu", valid_bleu, uidx) summary_writer.add_scalar("best_bleu", best_valid_bleu, uidx) 
# If model get new best valid bleu score if valid_bleu >= best_valid_bleu: bad_count = 0 if is_early_stop is False: # 1. save the best model torch.save(nmt_model.state_dict(), best_model_prefix + ".final") # 2. record all several best models best_model_saver.save(global_step=uidx, model=nmt_model) else: bad_count += 1 # At least one epoch should be traversed if bad_count >= training_configs['early_stop_patience'] and eidx > 0: is_early_stop = True WARN("Early Stop!") summary_writer.add_scalar("bad_count", bad_count, uidx) if ema is not None: nmt_model.load_state_dict(origin_state_dict) del origin_state_dict INFO("{0} Loss: {1:.2f} BLEU: {2:.2f} lrate: {3:6f} patience: {4}".format( uidx, valid_loss, valid_bleu, lrate, bad_count )) training_progress_bar.close() eidx += 1 if eidx > training_configs["max_epochs"]: break
class VqvaeTrainer(BaseSolver):
    ''' Trainer that alternates speech->text->speech and text->speech->text cycles
        over paired and (optionally) unpaired data with a single VQVAE model. '''

    def __init__(self, config, paras, mode):
        ''' Read loss weights and the start steps of the unpaired objectives from config['hparas']. '''
        super().__init__(config, paras, mode)
        # Init settings
        self.step = 0
        # Initial "best" values; validate() overwrites them when a dev metric is lower
        self.best_tts_loss = 100.0
        self.best_per = 2.0
        self.asr_weight = self.config['hparas']['asr_weight']
        self.tts_weight = self.config['hparas']['tts_weight']
        self.unpair_text_start_step = config['hparas'][
            'unpair_text_start_step']
        self.unpair_text_weight = self.config['hparas']['unpair_text_weight']
        self.unpair_speech_start_step = config['hparas'][
            'unpair_speech_start_step']
        self.unpair_speech_weight = self.config['hparas'][
            'unpair_speech_weight']

    @staticmethod
    def _loss_value(loss):
        ''' Convert a loss to a plain number for logging.

        exec() replaces NaN/inf CTC losses by the python int 0, which has no
        `.item()`; this helper accepts None, tensors and plain numbers alike.
        '''
        if loss is None:
            return None
        return loss.item() if hasattr(loss, 'item') else loss

    def fetch_data(self, iter_name):
        ''' Fetch the next batch from the iterator stored in attribute `iter_name`.

        When the iterator is exhausted it is rebuilt from the attribute whose name
        is `iter_name` with 'iter' replaced by 'set' (e.g. pair_iter -> pair_set).
        Returns (mel, aug_mel, linear, text, sid), all moved to self.device; note
        that the iterator itself yields them in the order mel, aug_mel, linear, sid, text.
        '''
        # Load from iterator
        mel = None
        while mel is None:
            try:
                mel, aug_mel, linear, sid, text = next(getattr(
                    self, iter_name))
            except StopIteration:
                setattr(self, iter_name,
                        iter(getattr(self, iter_name.replace('iter', 'set'))))
        # Pad to match n_frames_per_step (at least 1 frame padded)
        pad_len = self.n_frames_per_step - (mel.shape[1] %
                                            self.n_frames_per_step)
        # Allocate the padding explicitly: slicing it out of ones_like(mel) yields
        # fewer than pad_len frames whenever the batch is shorter than pad_len.
        mel_pad = mel.new_full((mel.shape[0], pad_len, mel.shape[2]),
                               SPEC_PAD_VALUE)
        linear_pad = linear.new_full(
            (linear.shape[0], pad_len, linear.shape[2]), SPEC_PAD_VALUE)
        mel = torch.cat([mel, mel_pad], dim=1)
        linear = torch.cat([linear, linear_pad], dim=1)
        return mel.to(self.device),\
            aug_mel.to(self.device),\
            linear.to(self.device),\
            text.to(self.device),\
            sid.to(self.device)

    def load_data(self):
        ''' Load data for training/validation, store tokenizer and input/output shape'''
        self.verbose(['Loading data... large corpus may took a while.'])
        self.unpair_set, self.pair_set, self.dev_set, self.test_set, self.audio_converter, self.tokenizer, data_msg = \
            load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory, **self.config['data'])
        self.pair_iter = iter(self.pair_set)
        self.unpair_iter = iter(self.unpair_set)
        self.dev_iter = iter(self.dev_set)
        # Feature statics
        self.n_mels, self.linear_dim = self.audio_converter.feat_dim
        self.vocab_size = self.tokenizer.vocab_size
        # Number of speakers = number of entries in the speaker-map json;
        # the context manager closes the file (the handle used to be leaked).
        with open(self.config['data']['corpus']['spkr_map']) as spkr_file:
            self.n_spkr = len(json.load(spkr_file))
        self.verbose(data_msg)

    def set_model(self):
        ''' Setup Audio AE-model and optimizer '''
        # Model
        self.model = VQVAE(self.n_mels, self.linear_dim, self.vocab_size,
                           self.n_spkr, **self.config['model']).to(self.device)
        self.n_frames_per_step = self.model.n_frames_per_step
        self.verbose(self.model.create_msg())
        # Objective
        self.freq_loss = partial(
            freq_loss,
            sample_rate=self.audio_converter.sr,
            n_mels=self.audio_converter.n_mels,
            loss=self.config['hparas']['freq_loss_type'],
            differential_loss=self.config['hparas']['differential_loss'],
            emphasize_linear_low=self.config['hparas']['emphasize_linear_low'])
        self.ctc_loss = torch.nn.CTCLoss()
        self.stop_loss = torch.nn.BCEWithLogitsLoss()
        # Optimizer
        self.optimizer = Optimizer(self.model.parameters(),
                                   **self.config['hparas'])
        self.verbose(self.optimizer.create_msg())
        # ToDo : unsup first?
        self.verbose(' | ASR weight = {}\t| start step = {}'.format(
            self.asr_weight, 0))
        self.verbose(' | TTS weight = {}\t| start step = {}'.format(
            self.tts_weight, 0))
        self.verbose(' | Txt weight = {}\t| start step = {}'.format(
            self.unpair_text_weight, self.unpair_text_start_step))
        self.verbose(' | Sph weight = {}\t| start step = {}'.format(
            self.unpair_speech_weight, self.unpair_speech_start_step))
        # ToDo: load pre-trained model
        if self.paras.load:
            ckpt = torch.load(self.paras.load, map_location=self.device)
            self.model.load_state_dict(ckpt['model'])
            self.optimizer.load_opt_state_dict(ckpt['optimizer'])
            self.step = ckpt['global_step']
            self.verbose('Load ckpt from {}, restarting at step {}'.format(
                self.paras.load, self.step))

    def exec(self):
        ''' Training loop, runs until self.step reaches self.max_step.

        Every step uses one paired batch. Even steps run the speech->text->speech
        cycle, odd steps the text->speech->text cycle; an unpaired batch is fetched
        only when the corresponding unpaired objective is active (weight > 0 and
        step past its start step). CUDA out-of-memory RuntimeErrors drop the
        current batch and continue with the next one.
        '''
        self.verbose(
            ['Total training steps {}.'.format(human_format(self.max_step))])
        self.timer.set()
        # These survive across steps so the logger can report the latest unpaired results
        unpair_speech_loss, unpair_text_loss, unsup_pred, unsup_trans, unsup_align = None, None, None, None, None
        tok_usage, gt_usage = [], []
        cnter = {'ctc_nan': 0, 'unp_sph': 0, 'unp_txt': 0}

        while self.step < self.max_step:
            # --------------------- Load data ----------------------- #
            # Unpair setting
            unpair_mel, unpair_aug_mel, unpair_linear, unpair_text, unpair_sid = None, None, None, None, None
            post_pred, asr_post_loss = None, None  # For ASR postnet only
            use_unpair_text = self.unpair_text_weight > 0 and self.step > self.unpair_text_start_step
            use_unpair_speech = self.unpair_speech_weight > 0 and self.step > self.unpair_speech_start_step
            tf_rate = self.optimizer.pre_step(
                self.step)  # Catch the returned tf_rate if needed

            # ToDo : change # of sup. step = 2 x # of unsup. step ?
            mel, aug_mel, linear, text, sid = self.fetch_data(
                iter_name='pair_iter')

            # Load unpaired data only when use_unpair_xxx == True
            if self.step % 2 == 0:
                # ASR first
                speech_first = True
                if use_unpair_speech:
                    unpair_mel, unpair_aug_mel, unpair_linear, unpair_text, unpair_sid = \
                        self.fetch_data(iter_name='unpair_iter')
            else:
                # TTS first
                speech_first = False
                if use_unpair_text:
                    cnter['unp_txt'] += 1
                    unpair_mel, unpair_aug_mel, unpair_linear, unpair_text, unpair_sid = \
                        self.fetch_data(iter_name='unpair_iter')

            total_loss = 0
            bs = len(mel)
            self.timer.cnt('rd')
            try:
                # ----------------------- Forward ------------------------ #
                if speech_first:
                    # Cycle : speech -> text -> speech
                    pair_prob, _, unpair_prob, unpair_latent, unpair_latent_len, pair_post_prob, _ = \
                        self.model.speech_to_text(paired_mel=aug_mel,
                                                  unpaired_mel=unpair_aug_mel)

                    # Check to involve unsupervised Speech2Speech
                    if unpair_latent is not None:
                        # ASR output is the representation for speech2speech
                        cnter['unp_sph'] += 1
                        ignore_speech_cycle = False
                        unpaired_teacher = unpair_mel
                    else:
                        # ASR output is all blank (cannot be passed to TTS) only paired text is used
                        ignore_speech_cycle = True
                        unpaired_teacher = None

                    # text -> speech
                    pair_mel_pred, pair_linear_pred, pair_align, _, \
                        unpair_mel_pred, unpair_linear_pred, unpair_align, _ = \
                        self.model.text_to_speech(paired_text=text,
                                                  paired_sid=sid,
                                                  unpaired_sid=unpair_sid,
                                                  unpaired_latent=unpair_latent,
                                                  unpaired_text=None,
                                                  unpaired_latent_len=unpair_latent_len,
                                                  paired_teacher=mel,
                                                  unpaired_teacher=unpaired_teacher,
                                                  tf_rate=tf_rate)
                else:
                    # Cycle : text -> speech -> text
                    pair_mel_pred, pair_linear_pred, pair_align, _, \
                        unpair_mel_pred, unpair_linear_pred, unpair_align, _ = \
                        self.model.text_to_speech(paired_text=text,
                                                  paired_sid=sid,
                                                  unpaired_sid=unpair_sid,
                                                  unpaired_latent=None,
                                                  unpaired_text=unpair_text,
                                                  unpaired_latent_len=None,
                                                  paired_teacher=mel,
                                                  unpaired_teacher=None,
                                                  tf_rate=tf_rate)
                    if use_unpair_text:
                        # Stop-grad for tts in text2text
                        unpair_mel_pred = unpair_mel_pred.detach()
                    pair_prob, _, unpair_prob, unpair_latent, unpair_latent_len, pair_post_prob, _ = \
                        self.model.speech_to_text(paired_mel=aug_mel,
                                                  unpaired_mel=unpair_mel_pred,
                                                  using_fake_mel=use_unpair_text)

                # Paired ASR loss
                asr_loss = self.compute_ctcloss(aug_mel, pair_prob, text)
                if self.model.use_asr_postnet:
                    total_loss = total_loss + self.asr_weight * (
                        1 - self.model.asr_postnet_weight) * asr_loss
                    asr_post_loss = self.compute_ctcloss(aug_mel,
                                                         pair_post_prob,
                                                         text,
                                                         apply_log=False)
                    total_loss = total_loss + self.asr_weight * self.model.asr_postnet_weight * asr_post_loss
                else:
                    total_loss = total_loss + self.asr_weight * asr_loss

                # Only the logged value is reset here; total_loss already contains
                # the non-finite term at this point.
                if math.isnan(asr_loss) or math.isinf(asr_loss):
                    cnter['ctc_nan'] += 1
                    asr_loss = 0

                # Paired TTS loss
                mel_loss = self.freq_loss(pair_mel_pred, mel)
                linear_loss = self.freq_loss(pair_linear_pred, linear)
                tts_loss = mel_loss + linear_loss
                total_loss = total_loss + self.tts_weight * tts_loss

                # Unpaired loss
                if speech_first:
                    # Unpaired speech reconstruction loss
                    if not ignore_speech_cycle:
                        unpair_speech_loss = self.freq_loss(unpair_mel_pred, unpair_mel) +\
                            self.freq_loss(unpair_linear_pred, unpair_linear)
                        if self.step > self.unpair_speech_start_step:
                            total_loss += self.unpair_speech_weight * unpair_speech_loss
                elif use_unpair_text:
                    # Unpaired text reconstruction loss
                    ctc_input = (unpair_prob + EPS).transpose(0, 1).log()
                    if self.paras.actual_len:
                        asr_input_len = (unpair_text != 0).sum(
                            dim=-1) * FRAME_PHN_RATIO
                        asr_input_len = asr_input_len + asr_input_len % self.model.n_frames_per_step
                        ctc_len = 1 + (asr_input_len //
                                       self.model.time_reduce_factor)
                    else:
                        ctc_len = torch.LongTensor(
                            [unpair_prob.shape[1]] *
                            unpair_prob.shape[0]).to(device=self.device)
                    unpair_text_loss = self.ctc_loss(
                        ctc_input,
                        unpair_text.to_sparse().values(), ctc_len,
                        torch.sum(unpair_text != 0, dim=-1))
                    if math.isnan(unpair_text_loss) or math.isinf(
                            unpair_text_loss):
                        cnter['ctc_nan'] += 1
                        unpair_text_loss = 0
                    total_loss += self.unpair_text_weight * unpair_text_loss

                # Statics (over unsup. speech only)
                if speech_first and use_unpair_speech:
                    unsup_pred = unpair_prob.argmax(dim=-1).cpu()
                    unsup_trans = unpair_text.cpu()
                    tok_usage += unsup_pred.flatten().tolist()
                    gt_usage += unsup_trans.flatten().tolist()
                    if unpair_align is not None:
                        unsup_align = unpair_align.detach().cpu()
                    else:
                        unsup_align = [None] * bs

                self.timer.cnt('fw')

                # ----------------------- Backward ------------------------ #
                grad_norm = self.backward(total_loss)
                self.step += 1

                # Log
                if (self.step == 1) or (self.step % self._PROGRESS_STEP == 0):
                    self.progress('Tr stat | Loss - {:.2f} (CTC-nan/unp-sph/unp-txt={}/{}/{}) | Grad. Norm - {:.2f} | {} '\
                                  .format(total_loss.cpu().item(), cnter['ctc_nan'], cnter['unp_sph'], cnter['unp_txt'],
                                          grad_norm, self.timer.show()))
                    # _loss_value() tolerates the int 0 assigned above after a
                    # NaN/inf CTC loss (calling .item() on it used to crash here).
                    self.write_log(
                        'txt_loss', {
                            'pair': self._loss_value(asr_loss),
                            'unpair': self._loss_value(unpair_text_loss),
                            'post': self._loss_value(asr_post_loss)
                        })
                    self.write_log(
                        'speech_loss', {
                            'pair': self._loss_value(tts_loss),
                            'unpair': self._loss_value(unpair_speech_loss)
                        })
                    # Counters are reported per logging interval
                    for k in cnter:
                        cnter[k] = 0

                if (self.step == 1) or (self.step % ATTENTION_PLOT_STEP == 0):
                    align = pair_align.cpu()  # align shape BxDsxEs
                    sup_pred = pair_prob.argmax(dim=-1).cpu()
                    sup_trans = text.cpu()
                    if self.model.use_asr_postnet:
                        post_pred = pair_post_prob.argmax(dim=-1).cpu()
                    self.write_log(
                        'per', {
                            'pair': cal_per(sup_pred, sup_trans),
                            'unpair': cal_per(unsup_pred, unsup_trans),
                            'post': cal_per(post_pred, sup_trans)
                        })
                    self.write_log(
                        'unpair_hist',
                        data_to_bar(tok_usage, gt_usage, self.vocab_size,
                                    self.tokenizer._vocab_list))
                    for i in range(LISTEN_N_EXAMPLES):
                        self.write_log(
                            'pair_align{}'.format(i),
                            feat_to_fig(align[i].cpu().detach()))
                        if unsup_align is not None and unsup_align[
                                i] is not None:
                            self.write_log(
                                'unpair_align{}'.format(i),
                                feat_to_fig(unsup_align[i].cpu().detach()))
                    # Token usage histogram restarts after each plot
                    tok_usage, gt_usage = [], []

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    self.validate()

                # End of step
                self.timer.set()
                if self.step > self.max_step:
                    break
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    self.verbose('WARNING: ran out of memory, retrying batch')
                    for p in self.model.parameters():
                        if p.grad is not None:
                            del p.grad  # free some memory
                    torch.cuda.empty_cache()
                else:
                    print(repr(e))
                    # NOTE(review): `errorout` is not defined in the visible part of
                    # this file - presumably it aborts training; confirm it exists.
                    errorout()

    def validate(self):
        ''' Evaluate on the dev set, checkpoint on improvement and log examples.

        The model is put in eval mode for the duration and back in train mode at
        the end. Examples for logging are taken from the middle dev batch.
        '''
        # Eval mode
        self.model.eval()
        dev_tts_loss, dev_per, dev_post_per = [], [], []

        for i in range(len(self.dev_set)):
            self.progress('Valid step - {}/{}'.format(i + 1,
                                                      len(self.dev_set)))
            # Fetch data
            mel, aug_mel, linear, text, sid = self.fetch_data(
                iter_name='dev_iter')

            # Forward model
            with torch.no_grad():
                # test ASR
                pair_prob, _, _, _, _, pair_post_prob, _ = self.model.speech_to_text(
                    paired_mel=mel, unpaired_mel=None)
                dev_per.append(cal_per(pair_prob, text))
                if pair_post_prob is not None:
                    dev_post_per.append((cal_per(pair_post_prob, text)))

                # test TTS (Note: absolute dec step now)
                # paired_teacher receives the number of frames here, not a spectrogram
                pair_mel_pred, pair_linear_pred, pair_align, _, _, _, _, _ = \
                    self.model.text_to_speech(paired_text=text,
                                              paired_sid=sid,
                                              unpaired_sid=None,
                                              unpaired_latent=None,
                                              unpaired_text=None,
                                              unpaired_latent_len=None,
                                              paired_teacher=mel.shape[1],
                                              unpaired_teacher=None,
                                              tf_rate=0.0)
                dev_tts_loss.append(
                    self.freq_loss(pair_mel_pred, mel) +
                    self.freq_loss(pair_linear_pred, linear))

            if i == len(self.dev_set) // 2:
                # pick n longest samples in the median batch
                sample_txt = text.cpu()[:LISTEN_N_EXAMPLES]
                hyp = pair_prob.argmax(dim=-1).cpu()[:LISTEN_N_EXAMPLES]
                mel_p = pair_mel_pred.cpu()[:LISTEN_N_EXAMPLES]
                linear_p = pair_linear_pred.cpu()[:LISTEN_N_EXAMPLES]
                align_p = pair_align.cpu()[:LISTEN_N_EXAMPLES]
                sample_mel = mel.cpu()[:LISTEN_N_EXAMPLES]
                sample_linear = linear.cpu()[:LISTEN_N_EXAMPLES]

        # Ckpt if performance improves
        dev_tts_loss = sum(dev_tts_loss) / len(dev_tts_loss)
        dev_per = sum(dev_per) / len(dev_per)
        dev_post_per = sum(dev_post_per) / len(dev_post_per) if len(
            dev_post_per) > 0 else None

        if self.paras.store_best_per:
            if dev_per < self.best_per:
                self.best_per = dev_per
                self.save_checkpoint('best_per.pth', dev_per)
            if (dev_post_per is not None) and (dev_post_per < self.best_per):
                self.best_per = dev_post_per
                self.save_checkpoint('best_post_per.pth', dev_post_per)
        else:
            if dev_tts_loss < self.best_tts_loss:
                self.best_tts_loss = dev_tts_loss
                if self.step > 1:
                    self.save_checkpoint('tts_{}.pth'.format(self.step),
                                         dev_tts_loss)
            if dev_per < self.best_per:
                self.best_per = dev_per
                if self.step > 1:
                    self.save_checkpoint('asr_{}.pth'.format(self.step),
                                         dev_per)
            if (dev_post_per is not None) and (dev_post_per < self.best_per):
                # Note: best_per does not record whether it came from the postnet
                self.best_per = dev_post_per
                self.save_checkpoint('best_post_per.pth', dev_post_per)

        if ((self.step > 1) and
                (self.step % CKPT_STEP == 0)) and not self.paras.store_best_per:
            # Regular ckpt
            self.save_checkpoint('step_{}.pth'.format(self.step),
                                 dev_tts_loss)

        # Logger
        # Write model output (no G-F-lim if picking per)
        for i, (m_p, l_p, a_p, h_p) in enumerate(
                zip(mel_p, linear_p, align_p, hyp)):
            self.write_log('hyp_text{}'.format(i),
                           self.tokenizer.decode(h_p.tolist()))
            self.write_log('mel_spec{}'.format(i), feat_to_fig(m_p))
            self.write_log('linear_spec{}'.format(i), feat_to_fig(l_p))
            self.write_log('dv_align{}'.format(i), feat_to_fig(a_p))
            if not self.paras.store_best_per:
                self.write_log('mel_wave{}'.format(i),
                               self.audio_converter.feat_to_wave(m_p))
                self.write_log('linear_wave{}'.format(i),
                               self.audio_converter.feat_to_wave(l_p))

        # Write ground truth (first validation only)
        if self.step == 1:
            for i, (gt_mel, gt_linear, gt_txt) in enumerate(
                    zip(sample_mel, sample_linear, sample_txt)):
                self.write_log('truth_text{}'.format(i),
                               self.tokenizer.decode(gt_txt.tolist()))
                self.write_log('mel_spec{}_gt'.format(i), feat_to_fig(gt_mel))
                self.write_log('mel_wave{}_gt'.format(i),
                               self.audio_converter.feat_to_wave(gt_mel))
                self.write_log('linear_spec{}_gt'.format(i),
                               feat_to_fig(gt_linear))
                self.write_log('linear_wave{}_gt'.format(i),
                               self.audio_converter.feat_to_wave(gt_linear))

        self.write_log('speech_loss', {'dev': dev_tts_loss})
        self.write_log('per', {'dev': dev_per, 'dev_post': dev_post_per})
        self.write_log('codebook', (self.model.codebook.embedding.weight.data,
                                    self.tokenizer._vocab_list))

        # Resume training
        self.model.train()

    def compute_ctcloss(self,
                        model_input,
                        model_output,
                        target,
                        apply_log=True):
        ''' CTC loss between `model_output` and `target`.

        apply_log   : when True, log(model_output + EPS) is fed to the loss;
                      when False, model_output is used as is.
        model_input : only read when self.paras.actual_len is set, to count the
                      frames that are not entirely SPEC_PAD_VALUE.
        Target lengths are the number of non-zero entries per row of `target`.
        '''
        if apply_log:
            ctc_input = (model_output + EPS).transpose(0, 1).log()
        else:
            ctc_input = model_output.transpose(0, 1)

        if self.paras.actual_len:
            # A frame counts as real unless every feature in it equals the pad value
            asr_input_len = torch.sum(
                (model_input == SPEC_PAD_VALUE).long().sum(dim=-1) !=
                model_input.shape[-1],
                dim=-1)
            ctc_len = asr_input_len // self.model.time_reduce_factor
            ctc_target = target
        else:
            # Flattened non-zero targets, every sequence uses the full output length
            ctc_target = target.to_sparse().values()
            ctc_len = torch.LongTensor(
                [model_output.shape[1]] *
                model_output.shape[0]).to(device=self.device)

        return self.ctc_loss(ctc_input, ctc_target, ctc_len,
                             torch.sum(target != 0, dim=-1))
def __init__(self, reinforce_configs, annunciator_configs,
             src_vocab, trg_vocab,
             data_iterator,
             save_to,
             device="cpu",
             ):
    """
    Set up the translation environment: the frozen victim translation model,
    the scorer ("annunciator") with its optimizer/scheduler, and reward weights.

    :param reinforce_configs: attack configuration dictionary
    :param annunciator_configs: scorer configuration dictionary (model,
        optimizer and update settings)
    :param src_vocab: source-side vocabulary handed to the victim model builder
    :param trg_vocab: target-side vocabulary handed to the victim model builder
    :param data_iterator: data provider stored on the environment
    :param save_to: directory under which the near-vocabulary files are stored
    :param device: (string) device to allocate variables on ("cpu", "cuda:*"),
        defaults to "cpu"
    """
    # environment devices
    self.device = device
    self.data_iterator = data_iterator

    scorer_kwargs = annunciator_configs["scorer_model_configs"]
    optim_cfg = annunciator_configs["annunciator_optimizer_configs"]
    victim_cfg_path = reinforce_configs["victim_configs"]
    victim_ckpt_path = reinforce_configs["victim_model"]

    # the path is stripped for opening only; the raw value is logged and reused below
    with open(victim_cfg_path.strip()) as cfg_file:
        INFO("env open victim configs at %s" % victim_cfg_path)
        victim_cfg = yaml.load(cfg_file, Loader=yaml.FullLoader)

    # vocabularies and embeddings are kept so that representations can be
    # extracted for annunciator training
    self.src_vocab = src_vocab
    self.trg_vocab = trg_vocab

    # victim translation model together with its two embedding layers
    translate_parts = build_translate_model(
        victim_cfg, victim_ckpt_path,
        vocab_src=self.src_vocab, vocab_trg=self.trg_vocab,
        device=self.device)
    self.src_emb, self.trg_emb, self.translate_model = translate_parts
    self.max_roll_out_step = victim_cfg["data_configs"]["max_len"][0]
    # all three victim modules are used in evaluation mode
    self.src_emb.eval()
    self.trg_emb.eval()
    self.translate_model.eval()

    # the epsilon range used for the action space when perturbing;
    # only the third returned value is kept
    perturb_rate = reinforce_configs["init_perturb_rate"]
    near_vocab_path = os.path.join(save_to, "near_vocab")
    full_near_vocab_path = os.path.join(save_to, "full_near_vocab")
    near_vocab_outputs = load_or_extract_near_vocab(
        config_path=victim_cfg_path,
        model_path=victim_ckpt_path,
        init_perturb_rate=perturb_rate,
        save_to=near_vocab_path,
        save_to_full=full_near_vocab_path,
        top_reserve=12, emit_as_id=True)
    _, _, self.limit_dist = near_vocab_outputs

    #########################################################
    # scorer (an Annunciator object) provides intrinsic step rewards
    self.annunciator = TransScorer(
        victim_cfg, victim_ckpt_path, self.trg_emb,
        **scorer_kwargs)
    self.annunciator.to(self.device)

    # Annunciator update configs
    self.acc_bound = annunciator_configs["acc_bound"]
    self.mse_bound = annunciator_configs["mse_bound"]
    self.min_update_steps = annunciator_configs["valid_freq"]
    self.max_update_steps = annunciator_configs["annunciator_update_steps"]

    # optimizer and (optional) schedule used for Annunciator updates
    self.optim_A = Optimizer(
        name=optim_cfg["optimizer"],
        model=self.annunciator,
        lr=optim_cfg["learning_rate"],
        grad_clip=optim_cfg["grad_clip"],
        optim_args=optim_cfg["optimizer_params"])

    self.scheduler_A = None  # stays None without a (known) schedule method
    schedule_method = optim_cfg['schedule_method']
    if schedule_method == "loss":
        self.scheduler_A = ReduceOnPlateauScheduler(
            optimizer=self.optim_A, **optim_cfg["scheduler_configs"])
    elif schedule_method == "noam":
        self.scheduler_A = NoamScheduler(
            optimizer=self.optim_A, **optim_cfg['scheduler_configs'])
    elif schedule_method == "rsqrt":
        self.scheduler_A = RsqrtScheduler(
            optimizer=self.optim_A, **optim_cfg["scheduler_configs"])
    elif schedule_method is not None:
        WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".format(
            optim_cfg['schedule_method']))

    self.criterion_A = nn.CrossEntropyLoss()

    ############################################################
    # adversarial or reinforce as learning objects
    self.adversarial = reinforce_configs["adversarial"]
    self.r_s_weight = reinforce_configs["r_s_weight"]
    self.r_i_weight = reinforce_configs["r_i_weight"]
class Solver(BaseSolver):
    ''' Solver for training'''

    def __init__(self, config, paras, mode):
        ''' Initialise best-WER trackers and read the curriculum setting from config['hparas']. '''
        super().__init__(config, paras, mode)
        # Logger settings
        # Initial "best" error rates; validate() lowers them when the dev WER improves
        self.best_wer = {'att': 3.0, 'ctc': 3.0}
        # Curriculum learning affects data loader
        self.curriculum = self.config['hparas']['curriculum']

    def fetch_data(self, data):
        ''' Move data to device and compute text seq. length

        The first element of `data` is ignored. Text length is the number of
        non-zero entries along the last dimension of `txt`.
        '''
        _, feat, feat_len, txt = data
        feat = feat.to(self.device)
        feat_len = feat_len.to(self.device)
        txt = txt.to(self.device)
        txt_len = torch.sum(txt != 0, dim=-1)

        return feat, feat_len, txt, txt_len

    def load_data(self):
        ''' Load data for training/validation, store tokenizer and input/output shape'''
        self.tr_set, self.dv_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \
            load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                         self.curriculum > 0, **self.config['data'])
        self.verbose(msg)

    def set_model(self):
        ''' Setup ASR model and optimizer '''
        # Model
        init_adadelta = self.config['hparas']['optimizer'] == 'Adadelta'
        self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta,
                         **self.config['model']).to(self.device)
        self.verbose(self.model.create_msg())
        model_paras = [{'params': self.model.parameters()}]

        # Losses
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        # Note: zero_infinity=False is unstable?
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

        # Plug-ins
        self.emb_fuse = False
        self.emb_reg = ('emb' in self.config) and (
            self.config['emb']['enable'])
        if self.emb_reg:
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(
                self.tokenizer, self.model.dec_dim,
                **self.config['emb']).to(self.device)
            model_paras.append({'params': self.emb_decoder.parameters()})
            self.emb_fuse = self.emb_decoder.apply_fuse
            if self.emb_fuse:
                # With fusion enabled the attention loss switches to NLLLoss
                self.seq_loss = torch.nn.NLLLoss(ignore_index=0)
            self.verbose(self.emb_decoder.create_msg())

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Enable AMP if needed
        self.enable_apex()

        # Automatically load pre-trained model if self.paras.load is given
        self.load_ckpt()

        # ToDo: other training methods

    def exec(self):
        ''' Training End-to-end ASR system

        Iterates over the training set until self.step reaches self.max_step.
        When curriculum learning is enabled, the training loader is rebuilt
        (with the curriculum flag off) once `self.curriculum` epochs are done.
        '''
        self.verbose('Total training steps {}.'.format(
            human_format(self.max_step)))
        ctc_loss, att_loss, emb_loss = None, None, None
        n_epochs = 0
        self.timer.set()

        while self.step < self.max_step:
            # Renew dataloader to enable random sampling
            if self.curriculum > 0 and n_epochs == self.curriculum:
                self.verbose(
                    'Curriculum learning ends after {} epochs, starting random sampling.'
                    .format(n_epochs))
                self.tr_set, _, _, _, _, _ = \
                    load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                                 False, **self.config['data'])
            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad
                tf_rate = self.optimizer.pre_step(self.step)
                total_loss = 0

                # Fetch data
                feat, feat_len, txt, txt_len = self.fetch_data(data)
                self.timer.cnt('rd')

                # Forward model
                # Note: txt should NOT start w/ <sos>
                ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.model(feat, feat_len, max(txt_len), tf_rate=tf_rate,
                               teacher=txt, get_dec_state=self.emb_reg)

                # Plugins
                if self.emb_reg:
                    emb_loss, fuse_output = self.emb_decoder(dec_state,
                                                             att_output,
                                                             label=txt)
                    total_loss += self.emb_decoder.weight * emb_loss

                # Compute all objectives
                if ctc_output is not None:
                    if self.paras.cudnn_ctc:
                        # This path passes flattened int32 CPU targets and
                        # python-list lengths instead of tensors
                        ctc_loss = self.ctc_loss(
                            ctc_output.transpose(0, 1),
                            txt.to_sparse().values().to(device='cpu',
                                                        dtype=torch.int32),
                            [ctc_output.shape[1]] * len(ctc_output),
                            txt_len.cpu().tolist())
                    else:
                        ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1),
                                                 txt, encode_len, txt_len)
                    total_loss += ctc_loss * self.model.ctc_weight

                if att_output is not None:
                    b, t, _ = att_output.shape
                    att_output = fuse_output if self.emb_fuse else att_output
                    att_loss = self.seq_loss(
                        att_output.contiguous().view(b * t, -1),
                        txt.contiguous().view(-1))
                    total_loss += att_loss * (1 - self.model.ctc_weight)

                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(total_loss)
                self.step += 1

                # Logger
                if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                    self.progress(
                        'Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'.
                        format(total_loss.cpu().item(), grad_norm,
                               self.timer.show()))
                    self.write_log('loss', {
                        'tr_ctc': ctc_loss,
                        'tr_att': att_loss
                    })
                    self.write_log('emb_loss', {'tr': emb_loss})
                    self.write_log(
                        'wer', {
                            'tr_att':
                            cal_er(self.tokenizer, att_output, txt),
                            'tr_ctc':
                            cal_er(self.tokenizer, ctc_output, txt, ctc=True)
                        })
                    if self.emb_fuse:
                        if self.emb_decoder.fuse_learnable:
                            self.write_log(
                                'fuse_lambda',
                                {'emb': self.emb_decoder.get_weight()})
                        self.write_log('fuse_temp',
                                       {'temp': self.emb_decoder.get_temp()})

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    self.validate()

                # End of step
                # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
                torch.cuda.empty_cache()
                self.timer.set()
                # Stop exactly at max_step: the outer loop tests `<`, so testing `>`
                # here let one additional batch be trained after the limit.
                if self.step >= self.max_step:
                    break
            n_epochs += 1
        self.log.close()

    def validate(self):
        ''' Evaluate on the dev set, log examples and save checkpoints.

        Saves 'best_att.pth' / 'best_ctc.pth' when the corresponding dev error
        rate improves and always refreshes 'latest.pth'. The model (and the
        embedding decoder, if any) is switched to eval mode for the duration
        and back to train mode afterwards.
        '''
        # set_model() only assigns emb_decoder when the plug-in is enabled, so
        # read it defensively (the base class may or may not define a default).
        emb_decoder = getattr(self, 'emb_decoder', None)

        # Eval mode
        self.model.eval()
        if emb_decoder is not None:
            emb_decoder.eval()
        dev_wer = {'att': [], 'ctc': []}

        for i, data in enumerate(self.dv_set):
            self.progress('Valid step - {}/{}'.format(i + 1, len(self.dv_set)))
            # Fetch data
            feat, feat_len, txt, txt_len = self.fetch_data(data)

            # Forward model
            with torch.no_grad():
                ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.model(feat, feat_len, int(max(txt_len)*self.DEV_STEP_RATIO),
                               emb_decoder=emb_decoder)

            dev_wer['att'].append(cal_er(self.tokenizer, att_output, txt))
            dev_wer['ctc'].append(
                cal_er(self.tokenizer, ctc_output, txt, ctc=True))

            # Show some example on tensorboard (taken from the middle batch)
            if i == len(self.dv_set) // 2:
                # Separate index name so the batch counter `i` is not shadowed
                for k in range(min(len(txt), self.DEV_N_EXAMPLE)):
                    if self.step == 1:
                        self.write_log('true_text{}'.format(k),
                                       self.tokenizer.decode(txt[k].tolist()))
                    if att_output is not None:
                        self.write_log(
                            'att_align{}'.format(k),
                            feat_to_fig(att_align[k, 0, :, :].cpu().detach()))
                        self.write_log(
                            'att_text{}'.format(k),
                            self.tokenizer.decode(
                                att_output[k].argmax(dim=-1).tolist()))
                    if ctc_output is not None:
                        self.write_log(
                            'ctc_text{}'.format(k),
                            self.tokenizer.decode(
                                ctc_output[k].argmax(dim=-1).tolist(),
                                ignore_repeat=True))

        # Ckpt if performance improves
        for task in ['att', 'ctc']:
            dev_wer[task] = sum(dev_wer[task]) / len(dev_wer[task])
            if dev_wer[task] < self.best_wer[task]:
                self.best_wer[task] = dev_wer[task]
                self.save_checkpoint('best_{}.pth'.format(task), 'wer',
                                     dev_wer[task])
            self.write_log('wer', {'dv_' + task: dev_wer[task]})
        self.save_checkpoint('latest.pth',
                             'wer',
                             dev_wer['att'],
                             show_msg=False)

        # Resume training
        self.model.train()
        if emb_decoder is not None:
            emb_decoder.train()
def run():
    """Spawn the asynchronous attacker trainers and one validation process.

    Steps, in order:

    1. Parse CLI arguments with the module-level ``parser`` and make sure
       ``args.save_to`` exists.
    2. Load the attack configuration and dump a copy of it to
       ``<save_to>/current_attack_configs.yaml``.
    3. Build the vocabularies and the zipped training dataset described by
       the victim model's configuration.
    4. Create the global attacker on CPU in shared memory and, when
       ``args.share_optim`` is set, a shared optimizer and scheduler.
    5. Restore the latest global checkpoint and extract the near-vocabulary
       candidates.
    6. Start ``args.n`` ``train`` processes and one ``valid`` process, then
       join them all.

    On GPU the trainers are placed on ``cuda:1`` .. ``cuda:n`` and the
    validator on ``cuda:0``; without ``args.use_gpu`` every process runs on
    ``"cpu"``.
    """
    # default actor threads as 1
    os.environ["OMP_NUM_THREADS"] = "1"
    mp = _mp.get_context('spawn')
    args = parser.parse_args()

    # exist_ok avoids the check-then-create race and also creates parents
    os.makedirs(args.save_to, exist_ok=True)

    config_copy_path = os.path.join(args.save_to, "current_attack_configs.yaml")
    with open(args.config_path, "r") as f, \
            open(config_copy_path, "w") as current_configs:
        # safe_load: yaml.load() without a Loader is rejected by PyYAML >= 6
        # and can construct arbitrary objects on older versions
        configs = yaml.safe_load(f)
        yaml.dump(configs, current_configs)

    attack_configs = configs["attack_configs"]
    attacker_configs = configs["attacker_configs"]
    attacker_model_configs = attacker_configs["attacker_model_configs"]
    attacker_optimizer_configs = attacker_configs["attacker_optimizer_configs"]
    discriminator_configs = configs["discriminator_configs"]

    # initial best saver for global model
    global_saver = Saver(
        save_prefix="{0}.final".format(os.path.join(args.save_to, "ACmodel")),
        num_max_keeping=attack_configs["num_kept_checkpoints"])

    # the Global variable of USE_GPU is mainly used for environments
    GlobalNames.SEED = attack_configs["seed"]
    GlobalNames.USE_GPU = args.use_gpu
    torch.manual_seed(GlobalNames.SEED)

    # build vocabulary and data iterator for env
    with open(attack_configs["victim_configs"], "r") as victim_f:
        victim_configs = yaml.safe_load(victim_f)
    data_configs = victim_configs["data_configs"]
    src_vocab = Vocabulary(**data_configs["vocabularies"][0])
    trg_vocab = Vocabulary(**data_configs["vocabularies"][1])
    data_set = ZipDataset(
        TextLineDataset(
            data_path=data_configs["train_data"][0],
            vocabulary=src_vocab,
        ),
        TextLineDataset(
            data_path=data_configs["train_data"][1],
            vocabulary=trg_vocab,
        ),
        shuffle=attack_configs["shuffle"]
    )
    # we build the parallel data sets and iterate inside a thread

    # global model variables (trg network to save the results);
    # kept on CPU in shared memory so every worker process can update it
    global_attacker = attacker.Attacker(src_vocab.max_n_words,
                                        **attacker_model_configs)
    global_attacker = global_attacker.cpu()
    global_attacker.share_memory()

    if args.share_optim:
        # initiate optimizer and set to share mode
        optimizer = Optimizer(
            name=attacker_optimizer_configs["optimizer"],
            model=global_attacker,
            lr=attacker_optimizer_configs["learning_rate"],
            grad_clip=attacker_optimizer_configs["grad_clip"],
            optim_args=attacker_optimizer_configs["optimizer_params"])
        optimizer.optim.share_memory()

        # Build scheduler for optimizer if needed
        schedule_method = attacker_optimizer_configs['schedule_method']
        if schedule_method is None:
            scheduler = None
        elif schedule_method == "loss":
            scheduler = ReduceOnPlateauScheduler(
                optimizer=optimizer,
                **attacker_optimizer_configs["scheduler_configs"])
        elif schedule_method == "noam":
            scheduler = NoamScheduler(
                optimizer=optimizer,
                **attacker_optimizer_configs['scheduler_configs'])
        elif schedule_method == "rsqrt":
            scheduler = RsqrtScheduler(
                optimizer=optimizer,
                **attacker_optimizer_configs["scheduler_configs"])
        else:
            WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".
                 format(schedule_method))
            scheduler = None
    else:
        # every training process builds its own optimizer / scheduler
        optimizer = None
        scheduler = None

    # load from checkpoint for global model
    global_saver.load_latest(model=global_attacker,
                             optim=optimizer,
                             lr_scheduler=scheduler)

    if args.use_gpu:
        # collect available devices (reported for information only)
        devices = ["cuda:%d" % i for i in range(torch.cuda.device_count())]
        print("available gpus:", devices)

    process = []
    counter = mp.Value("i", 0)
    lock = mp.Lock()  # for multiple attackers update

    INFO("extract near candidates")
    _, _ = load_or_extract_near_vocab(
        config_path=attack_configs["victim_configs"],
        model_path=attack_configs["victim_model"],
        init_perturb_rate=attack_configs["init_perturb_rate"],
        save_to=os.path.join(args.save_to, "near_vocab"),
        save_to_full=os.path.join(args.save_to, "full_near_vocab"),
        top_reserve=12,
        emit_as_id=True)

    # For single-process debugging, train(0, ...) / valid(args.n, ...) can be
    # called directly here with the same arguments instead of spawning.

    # run multiple training process of local attacker to update global one
    for rank in range(args.n):
        # cuda:0 is reserved for the validation process
        train_device = "cuda:%d" % (rank + 1) if args.use_gpu else "cpu"
        print("initialize training thread on %s" % train_device)
        p = mp.Process(target=train,
                       args=(rank, train_device, args, counter, lock,
                             attack_configs, discriminator_configs,
                             src_vocab, trg_vocab, data_set,
                             global_attacker, attacker_configs,
                             optimizer, scheduler, global_saver))
        p.start()
        process.append(p)

    # run the dev thread for initiation
    valid_device = "cuda:0" if args.use_gpu else "cpu"
    print("initialize dev thread on %s" % valid_device)
    p = mp.Process(target=valid,
                   args=(0, valid_device, args,
                         attack_configs, discriminator_configs,
                         src_vocab, trg_vocab, data_set,
                         global_attacker, attacker_configs, counter))
    p.start()
    process.append(p)

    for p in process:
        p.join()
def train(rank, device, args, counter, lock,
          attack_configs, discriminator_configs,
          src_vocab, trg_vocab, data_set,
          global_attacker, attacker_configs,
          optimizer=None, scheduler=None, saver=None):
    """Run one asynchronous actor-critic training process for the attacker.

    The process repeatedly
      1. updates the environment's discriminator (every
         ``attacker_update_steps`` episodes),
      2. rolls out the local attacker in the environment for up to
         ``args.action_roll_steps`` steps, collecting rewards,
      3. computes policy / value losses with generalized advantage
         estimation, and
      4. pushes the local gradients to the shared global attacker and steps
         the optimizer.

    It stops once the discriminator accuracy stayed below
    ``converged_bound`` for ``patience`` consecutive discriminator updates.

    :param rank: (int) the rank of the process (from multiprocess)
    :param device: the device of the process
    :param args: global args
    :param counter: python multiprocess variable, incremented per env step
    :param lock: python multiprocess variable guarding ``counter``
    :param attack_configs: attack settings
    :param discriminator_configs: discriminator settings
    :param src_vocab: source vocabulary
    :param trg_vocab: target vocabulary
    :param data_set: (data_iterator object) provide batched data labels
    :param global_attacker: the model to sync from
    :param attacker_configs: local attacker settings
    :param optimizer: shared optimizer for the attacker; a local optimizer
        (and scheduler) is built when None
    :param scheduler: shared scheduler for the attacker; only replaced by a
        local one when the optimizer is built locally
    :param saver: model saver; only its truthiness is used, to enable
        periodic local checkpoints
    :return: None
    """
    trust_acc = acc_bound = discriminator_configs["acc_bound"]
    converged_bound = discriminator_configs["converged_bound"]
    patience = discriminator_configs["patience"]
    attacker_model_configs = attacker_configs["attacker_model_configs"]
    attacker_optimizer_configs = attacker_configs["attacker_optimizer_configs"]
    schedule_method = attacker_optimizer_configs["schedule_method"]

    # this is for multi-processing, GlobalNames can not be direct inherited
    GlobalNames.USE_GPU = args.use_gpu
    GlobalNames.SEED = attack_configs["seed"]
    torch.manual_seed(GlobalNames.SEED + rank)

    # initiate local saver and load checkpoint if possible
    local_saver = Saver(
        save_prefix="{0}.local".format(
            os.path.join(args.save_to, "train_env%d" % rank, "ACmodel")),
        num_max_keeping=attack_configs["num_kept_checkpoints"])

    attack_iterator = DataIterator(dataset=data_set,
                                   batch_size=attack_configs["batch_size"],
                                   use_bucket=True,
                                   buffer_size=attack_configs["buffer_size"],
                                   numbering=True)

    summary_writer = SummaryWriter(
        log_dir=os.path.join(args.save_to, "train_env%d" % rank))
    local_attacker = attacker.Attacker(src_vocab.max_n_words,
                                       **attacker_model_configs)

    # build optimizer for attacker (it always updates the *global* model)
    if optimizer is None:
        optimizer = Optimizer(
            name=attacker_optimizer_configs["optimizer"],
            model=global_attacker,
            lr=attacker_optimizer_configs["learning_rate"],
            grad_clip=attacker_optimizer_configs["grad_clip"],
            optim_args=attacker_optimizer_configs["optimizer_params"])
        # Build scheduler for optimizer if needed
        if schedule_method is None:
            scheduler = None
        elif schedule_method == "loss":
            scheduler = ReduceOnPlateauScheduler(
                optimizer=optimizer,
                **attacker_optimizer_configs["scheduler_configs"])
        elif schedule_method == "noam":
            scheduler = NoamScheduler(
                optimizer=optimizer,
                **attacker_optimizer_configs['scheduler_configs'])
        elif schedule_method == "rsqrt":
            scheduler = RsqrtScheduler(
                optimizer=optimizer,
                **attacker_optimizer_configs["scheduler_configs"])
        else:
            WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".
                 format(schedule_method))
            scheduler = None

    local_saver.load_latest(model=local_attacker,
                            optim=optimizer,
                            lr_scheduler=scheduler)

    attacker_iterator = attack_iterator.build_generator()
    env = Translate_Env(attack_configs=attack_configs,
                        discriminator_configs=discriminator_configs,
                        src_vocab=src_vocab,
                        trg_vocab=trg_vocab,
                        data_iterator=attacker_iterator,
                        save_to=args.save_to,
                        device=device)
    episode_count = 0
    episode_length = 0
    local_steps = 0  # optimization steps: for learning rate schedules
    patience_t = patience

    while True:  # infinite loop of data set
        # we will continue with a new iterator with refreshed environments
        # whenever the last iterator breaks with "StopIteration"
        attacker_iterator = attack_iterator.build_generator()
        env.reset_data_iter(attacker_iterator)
        padded_src = env.reset()
        padded_src = torch.from_numpy(padded_src)
        if device != "cpu":
            padded_src = padded_src.to(device)
        done = True
        discriminator_base_steps = local_steps

        while True:
            # check for update of discriminator
            if episode_count % attacker_configs["attacker_update_steps"] == 0:
                # stop criterion:
                # when updates a discriminator, we check for acc. If acc
                # fails acc_bound, we reset the discriminator and try, until
                # acc reaches the bound with patience; otherwise the
                # training thread stops
                try:
                    discriminator_base_steps, trust_acc = env.update_discriminator(
                        local_attacker,
                        discriminator_base_steps,
                        min_update_steps=discriminator_configs["acc_valid_freq"],
                        max_update_steps=discriminator_configs["discriminator_update_steps"],
                        accuracy_bound=acc_bound,
                        summary_writer=summary_writer)
                except StopIteration:
                    INFO("finish one training epoch, reset data_iterator")
                    break

                discriminator_base_steps += 1  # a flag to label the discriminator updates
                if trust_acc < converged_bound:  # GAN target reached
                    patience_t -= 1
                    INFO("discriminator reached GAN convergence bound: %d times"
                         % patience_t)
                else:  # reset patience if discriminator is refreshed
                    patience_t = patience

            if saver and local_steps % attack_configs["save_freq"] == 0:
                local_saver.save(global_step=local_steps,
                                 model=local_attacker,
                                 optim=optimizer,
                                 lr_scheduler=scheduler)
                if trust_acc < converged_bound:
                    # we only save the global params reaching acc_bound
                    torch.save(global_attacker.state_dict(),
                               os.path.join(args.save_to, "ACmodel.final"))

            if patience_t == 0:
                WARN("maximum patience reached. Training Thread should stop")
                break

            local_attacker.train()  # switch back to training mode

            # for a initial (reset) attacker from global parameters
            if done:
                INFO("sync from global model")
                local_attacker.load_state_dict(global_attacker.state_dict())
            # move the local attacker params back to device after updates
            local_attacker = local_attacker.to(device)

            values = []  # training critic: network outputs
            log_probs = []
            rewards = []  # actual rewards
            entropies = []

            local_steps += 1
            # run sequences step of attack
            try:
                for _ in range(args.action_roll_steps):
                    episode_length += 1
                    attack_out, critic_out = local_attacker(
                        padded_src,
                        padded_src[:, env.index - 1:env.index + 2])
                    logit_attack_out = torch.log(attack_out)
                    entropy = -(attack_out * logit_attack_out).sum(dim=-1).mean()
                    summary_writer.add_scalar("action_entropy",
                                              scalar_value=entropy,
                                              global_step=local_steps)
                    entropies.append(entropy)  # for entropy loss

                    actions = attack_out.multinomial(num_samples=1).detach()
                    # only extract the log prob for chosen action (avg over batch)
                    log_attack_out = logit_attack_out.gather(-1, actions).mean()
                    padded_src, reward, terminal_signal = env.step(
                        actions.squeeze())
                    done = terminal_signal or \
                        episode_length > args.max_episode_lengths

                    with lock:
                        counter.value += 1

                    if done:
                        episode_length = 0
                        padded_src = env.reset()

                    # env.step / env.reset hand back arrays; convert for the
                    # next forward pass
                    padded_src = torch.from_numpy(padded_src)
                    if device != "cpu":
                        padded_src = padded_src.to(device)

                    values.append(critic_out.mean())  # list of torch variables (scalar)
                    log_probs.append(log_attack_out)  # list of torch variables (scalar)
                    rewards.append(reward)  # list of reward variables

                    if done:
                        episode_count += 1
                        break
            except StopIteration:
                INFO("finish one training epoch, reset data_iterator")
                break

            R = torch.zeros(1, 1)
            gae = torch.zeros(1, 1)
            if device != "cpu":
                R = R.to(device)
                gae = gae.to(device)

            if not done:
                # bootstrap the return from the critic for unfinished episodes
                value = local_attacker.get_critic(
                    padded_src, padded_src[:, env.index - 1:env.index + 2])
                R = value.mean().detach()

            # values needs one more entry than rewards for the TD residual
            values.append(R)
            policy_loss = 0
            value_loss = 0
            # collect values for training
            for i in reversed(range(len(rewards))):
                # value loss and policy loss must be clipped to stabilize training
                R = attack_configs["gamma"] * R + rewards[i]
                advantage = R - values[i]
                value_loss = value_loss + 0.5 * advantage.pow(2)

                delta_t = rewards[i] + \
                    attack_configs["gamma"] * values[i + 1] - values[i]
                gae = gae * attack_configs["gamma"] * attack_configs["tau"] + \
                    delta_t
                policy_loss = policy_loss - log_probs[i] * gae.detach() - \
                    attack_configs["entropy_coef"] * entropies[i]
            print("policy_loss", policy_loss)
            print("gae", gae)

            # update with optimizer
            optimizer.zero_grad()
            # we decay the loss according to discriminator's accuracy as a
            # trust region constrain
            summary_writer.add_scalar("policy_loss",
                                      scalar_value=policy_loss * trust_acc,
                                      global_step=local_steps)
            summary_writer.add_scalar("value_loss",
                                      scalar_value=value_loss * trust_acc,
                                      global_step=local_steps)

            total_loss = trust_acc * policy_loss + \
                trust_acc * attack_configs["value_coef"] * value_loss
            total_loss.backward()

            # scheduler is None when no (or an unknown) schedule method is
            # configured; "loss" schedulers are metric-driven and not stepped here
            if scheduler is not None and schedule_method != "loss":
                scheduler.step(global_step=local_steps)

            # move the model params to CPU and
            # assign local gradients to the global model to update
            local_attacker.to("cpu").ensure_shared_grads(global_attacker)
            optimizer.step()
            print("bingo!")

        if patience_t == 0:
            INFO("Reach maximum Discriminator patience, Finish")
            break
def tune(flags):
    """Train an NMT model without validation and save its final weights.

    The function sets up devices and logging, parses the configuration,
    builds the bitext training iterator, model, criterion, optimizer,
    scheduler and optional moving average, optionally reloads the latest
    checkpoint, and then loops over the training data with gradient
    accumulation until ``max_epochs`` is exceeded.  The model state dict is
    written to ``<best_model_prefix>.final``.

    flags:
        saveto: str
        reload: store_true
        config_path: str
        pretrain_path: str, default=""
        model_name: str
        log_path: str
    """
    # ------------------------------------------------------------------ #
    # Initialization for training on different devices
    # - CPU/GPU
    # - Single/Distributed
    Constants.USE_GPU = flags.use_gpu

    world_size, rank, local_rank = 1, 0, 0
    if flags.multi_gpu:
        dist.distributed_init(flags.shared_dir)
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        local_rank = dist.get_local_rank()

    if Constants.USE_GPU:
        torch.cuda.set_device(local_rank)
        Constants.CURRENT_DEVICE = "cuda:{0}".format(local_rank)
    else:
        Constants.CURRENT_DEVICE = "cpu"

    # Only the root rank logs; it also mirrors the log to a file.
    if rank == 0:
        log_file = os.path.join(flags.log_path,
                                "%s.log" % time.strftime("%Y%m%d-%H%M%S"))
        write_log_to_file(log_file)
    else:
        close_logging()

    # ------------------------------------------------------------------ #
    # Parsing configuration files
    # - Load default settings
    # - Load pre-defined settings
    # - Load user-defined settings
    configs = prepare_configs(flags.config_path, flags.predefined_config)

    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    INFO(pretty_configs(configs))

    Constants.SEED = training_configs['seed']
    set_seed(Constants.SEED)

    timer = Timer()

    # ------------------------------------------------------------------ #
    # Load Data
    INFO('Loading data...')
    timer.tic()

    # Generate target dictionary
    vocab_src = Vocabulary.build_from_file(**data_configs['vocabularies'][0])
    vocab_tgt = Vocabulary.build_from_file(**data_configs['vocabularies'][1])

    Constants.EOS = vocab_src.eos
    Constants.PAD = vocab_src.pad
    Constants.BOS = vocab_src.bos

    # bt tag dataset
    src_lines = TextLineDataset(data_path=data_configs['train_data'][0],
                                vocabulary=vocab_src,
                                max_len=data_configs['max_len'][0],
                                is_train_dataset=True)
    tgt_lines = TextLineDataset(data_path=data_configs['train_data'][1],
                                vocabulary=vocab_tgt,
                                max_len=data_configs['max_len'][1],
                                is_train_dataset=True)
    train_bitext_dataset = ZipDataset(src_lines, tgt_lines)

    training_iterator = DataIterator(
        dataset=train_bitext_dataset,
        batch_size=training_configs["batch_size"],
        use_bucket=training_configs['use_bucket'],
        buffer_size=training_configs['buffer_size'],
        batching_func=training_configs['batching_key'],
        world_size=world_size,
        rank=rank)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ------------------------------ Begin ------------------------------ #
    # Build Model & Optimizer
    # We would do steps below on after another
    #     1. build models & criterion
    #     2. move models & criterion to gpu if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    #     5. build learning rate scheduler if needed
    #     6. load checkpoints if needed

    # 0. Initial
    lrate = optimizer_configs['learning_rate']
    model_collections = Collections()

    checkpoint_prefix = "{0}.ckpt".format(
        os.path.join(flags.saveto, flags.model_name))
    checkpoint_saver = Saver(
        save_prefix=checkpoint_prefix,
        num_max_keeping=training_configs['num_kept_checkpoints'])

    best_model_prefix = os.path.join(
        flags.saveto, flags.model_name + Constants.MY_BEST_MODEL_SUFFIX)
    best_model_saver = Saver(
        save_prefix=best_model_prefix,
        num_max_keeping=training_configs['num_kept_best_model'])

    # 1. Build Model & Criterion
    INFO('Building model...')
    timer.tic()
    nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                            n_tgt_vocab=vocab_tgt.max_n_words,
                            padding_idx=vocab_src.pad,
                            vocab_src=vocab_src,
                            vocab_tgt=vocab_tgt,
                            **model_configs)
    INFO(nmt_model)

    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'],
                          padding_idx=vocab_tgt.pad)
    INFO(critic)

    # 2. Move to GPU
    if Constants.USE_GPU:
        nmt_model = nmt_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    load_pretrained_model(nmt_model,
                          flags.pretrain_path,
                          exclude_prefix=flags.pretrain_exclude_prefix,
                          device=Constants.CURRENT_DEVICE)

    # froze_parameters
    froze_params(nmt_model, flags.froze_config)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 4. Build optimizer
    INFO('Building Optimizer...')
    if flags.multi_gpu:
        optim = dist.DistributedOptimizer(
            name=optimizer_configs['optimizer'],
            model=nmt_model,
            lr=lrate,
            grad_clip=optimizer_configs['grad_clip'],
            optim_args=optimizer_configs['optimizer_params'],
            device_id=local_rank)
    else:
        optim = Optimizer(name=optimizer_configs['optimizer'],
                          model=nmt_model,
                          lr=lrate,
                          grad_clip=optimizer_configs['grad_clip'],
                          optim_args=optimizer_configs['optimizer_params'],
                          update_cycle=training_configs['update_cycle'])

    # 5. Build scheduler for optimizer if needed
    scheduler = build_scheduler(
        schedule_method=optimizer_configs['schedule_method'],
        optimizer=optim,
        scheduler_configs=optimizer_configs['scheduler_configs'])

    # 6. build moving average
    ma = None
    if training_configs['moving_average_method'] is not None:
        ma = MovingAverage(
            moving_average_method=training_configs['moving_average_method'],
            named_params=nmt_model.named_parameters(),
            alpha=training_configs['moving_average_alpha'])

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # Reload from latest checkpoint
    if flags.reload:
        checkpoint_saver.load_latest(model=nmt_model,
                                     optim=optim,
                                     lr_scheduler=scheduler,
                                     collections=model_collections,
                                     ma=ma,
                                     device=Constants.CURRENT_DEVICE)

    # broadcast parameters and optimizer states
    if world_size > 1:
        INFO("Broadcasting model parameters...")
        dist.broadcast_parameters(params=nmt_model.state_dict())
        INFO("Broadcasting optimizer states...")
        dist.broadcast_optimizer_state(optimizer=optim.optim)
        INFO('Done.')

    # ------------------------------------------------------------------ #
    # Prepare training
    # Counters are restored from the (possibly reloaded) collections.
    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [1])[-1]
    bad_count = model_collections.get_collection("bad_count", [0])[-1]
    oom_count = model_collections.get_collection("oom_count", [0])[-1]
    is_early_stop = model_collections.get_collection("is_early_stop", [
        False,
    ])[-1]

    train_loss_meter = AverageMeter()
    sent_per_sec_meter = TimeMeter()
    tok_per_sec_meter = TimeMeter()

    # Accumulators for one parameter update (update_cycle forward passes)
    update_cycle = training_configs['update_cycle']
    grad_denom = 0
    train_loss = 0.0
    cum_n_words = 0
    valid_loss = best_valid_loss = float('inf')

    summary_writer = SummaryWriter(log_dir=flags.log_path) if rank == 0 else None

    sent_per_sec_meter.start()
    tok_per_sec_meter.start()

    INFO('Begin training...')

    while True:
        if summary_writer is not None:
            summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()

        if rank == 0:
            bar_desc = ' - (Epc {}, Upd {}) '.format(eidx, uidx)
            training_progress_bar = tqdm(desc=bar_desc,
                                         total=len(training_iterator),
                                         unit="sents")
        else:
            training_progress_bar = None

        for seqs_x, seqs_y in training_iter:
            # bt attrib data
            batch_size = len(seqs_x)
            cum_n_words += sum(len(s) for s in seqs_y)

            try:
                # Prepare data
                x, y = prepare_data(seqs_x, seqs_y, cuda=Constants.USE_GPU)

                loss = compute_forward(
                    model=nmt_model,
                    critic=critic,
                    seqs_x=x,
                    seqs_y=y,
                    eval=False,
                    normalization=1.0,
                    norm_by_words=training_configs["norm_by_words"])

                update_cycle -= 1
                grad_denom += batch_size
                train_loss += loss

            except RuntimeError as e:
                # An OOM batch is skipped; anything else is fatal.
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    oom_count += 1
                else:
                    raise e

            # Keep accumulating until update_cycle forward passes are done.
            if update_cycle != 0:
                continue

            # When update_cycle becomes 0, it means end of one batch.
            # Several things will be done:
            # - update parameters
            # - reset update_cycle and grad_denom, update uidx
            # - learning rate scheduling
            # - update moving average

            # 0. reduce variables
            if world_size > 1:
                grad_denom = dist.all_reduce_py(grad_denom)
                train_loss = dist.all_reduce_py(train_loss)
                cum_n_words = dist.all_reduce_py(cum_n_words)

            # 1. update parameters
            optim.step(denom=grad_denom)
            optim.zero_grad()

            if training_progress_bar is not None:
                training_progress_bar.update(grad_denom)
                training_progress_bar.set_description(
                    ' - (Epc {}, Upd {}) '.format(eidx, uidx))

                postfix_str = 'TrainLoss: {:.2f}, ValidLoss(best): {:.2f} ({:.2f}), '.format(
                    train_loss, valid_loss, best_valid_loss)
                training_progress_bar.set_postfix_str(postfix_str)

            # 2. learning rate scheduling
            if scheduler is not None and \
                    optimizer_configs["schedule_method"] != "loss":
                scheduler.step(global_step=uidx)

            # 3. update moving average
            if ma is not None and \
                    eidx >= training_configs['moving_average_start_epoch']:
                ma.step()

            # 4. update meters
            train_loss_meter.update(train_loss, grad_denom)
            sent_per_sec_meter.update(grad_denom)
            tok_per_sec_meter.update(cum_n_words)

            # 5. reset accumulated variables, update uidx
            update_cycle = training_configs['update_cycle']
            grad_denom = 0
            uidx += 1
            cum_n_words = 0.0
            train_loss = 0.0

            # ---------------------------------------------------------- #
            # Display some information
            if should_trigger_by_steps(
                    uidx, eidx, every_n_step=training_configs['disp_freq']):

                lrate = list(optim.get_lrate())[0]

                if summary_writer is not None:
                    summary_writer.add_scalar(
                        "Speed(sents/sec)",
                        scalar_value=sent_per_sec_meter.ave,
                        global_step=uidx)
                    summary_writer.add_scalar(
                        "Speed(words/sec)",
                        scalar_value=tok_per_sec_meter.ave,
                        global_step=uidx)
                    summary_writer.add_scalar(
                        "train_loss",
                        scalar_value=train_loss_meter.ave,
                        global_step=uidx)
                    summary_writer.add_scalar("lrate",
                                              scalar_value=lrate,
                                              global_step=uidx)
                    summary_writer.add_scalar("oom_count",
                                              scalar_value=oom_count,
                                              global_step=uidx)

                # Reset Meters
                sent_per_sec_meter.reset()
                tok_per_sec_meter.reset()
                train_loss_meter.reset()

        # Periodic checkpointing is disabled in this routine; only the raw
        # model weights are written.
        # NOTE(review): placed at epoch end here; the flattened source does
        # not show whether this save ran per epoch or per update - confirm.
        torch.save(nmt_model.state_dict(), best_model_prefix + ".final")

        if training_progress_bar is not None:
            training_progress_bar.close()

        eidx += 1
        if eidx > training_configs["max_epochs"]:
            break
def train(FLAGS):
    """Train an NMT model with loss / BLEU validation and early stopping.

    The function sets up devices (optionally Horovod-distributed) and
    logging, loads the YAML configuration, builds the training / validation
    iterators, model, criterion, optimizer, scheduler and optional moving
    average, optionally reloads the latest checkpoint, and then trains with
    gradient accumulation.  Periodically it logs throughput, evaluates the
    validation loss and BLEU, keeps the best model by BLEU, applies early
    stopping and writes checkpoints, until ``max_epochs`` is exceeded.

    FLAGS:
        saveto: str
        reload: store_true
        config_path: str
        pretrain_path: str, default=""
        model_name: str
        log_path: str
    """
    # ================================================================== #
    # Initialization for training on different devices
    # - CPU/GPU
    # - Single/Distributed
    GlobalNames.USE_GPU = FLAGS.use_gpu

    if FLAGS.multi_gpu:
        if hvd is None or distributed is None:
            ERROR("Distributed training is disable. Please check the installation of Horovod.")

        hvd.init()
        world_size = hvd.size()
        rank = hvd.rank()
        local_rank = hvd.local_rank()
    else:
        world_size = 1
        rank = 0
        local_rank = 0

    if GlobalNames.USE_GPU:
        torch.cuda.set_device(local_rank)
        CURRENT_DEVICE = "cuda:{0}".format(local_rank)
    else:
        CURRENT_DEVICE = "cpu"

    # If not root_rank, close logging
    if rank != 0:
        close_logging()

    # write log of training to file.
    if rank == 0:
        write_log_to_file(os.path.join(FLAGS.log_path,
                                       "%s.log" % time.strftime("%Y%m%d-%H%M%S")))

    # ================================================================== #
    # Parsing configuration files
    config_path = os.path.abspath(FLAGS.config_path)

    with open(config_path.strip()) as f:
        # safe_load: yaml.load() without a Loader is rejected by PyYAML >= 6
        # and can construct arbitrary objects on older versions
        configs = yaml.safe_load(f)

    INFO(pretty_configs(configs))

    # Add default configs
    configs = default_baseline_configs(configs)
    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    GlobalNames.SEED = training_configs['seed']
    set_seed(GlobalNames.SEED)

    timer = Timer()

    # ================================================================== #
    # Load Data
    INFO('Loading data...')
    timer.tic()

    # Generate target dictionary
    vocab_src = Vocabulary(**data_configs["vocabularies"][0])
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][1])

    # The shuffle buffer is scaled with the number of accumulation steps
    actual_buffer_size = training_configs["buffer_size"] * \
        max(1, training_configs["update_cycle"])

    train_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['train_data'][0],
                        vocabulary=vocab_src,
                        max_len=data_configs['max_len'][0],
                        ),
        TextLineDataset(data_path=data_configs['train_data'][1],
                        vocabulary=vocab_tgt,
                        max_len=data_configs['max_len'][1],
                        )
    )

    valid_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['valid_data'][0],
                        vocabulary=vocab_src,
                        ),
        TextLineDataset(data_path=data_configs['valid_data'][1],
                        vocabulary=vocab_tgt,
                        )
    )

    training_iterator = DataIterator(dataset=train_bitext_dataset,
                                     batch_size=training_configs["batch_size"],
                                     use_bucket=training_configs['use_bucket'],
                                     buffer_size=actual_buffer_size,
                                     batching_func=training_configs['batching_key'],
                                     world_size=world_size,
                                     rank=rank)

    valid_iterator = DataIterator(dataset=valid_bitext_dataset,
                                  batch_size=training_configs['valid_batch_size'],
                                  use_bucket=True, buffer_size=100000, numbering=True,
                                  world_size=world_size, rank=rank)

    bleu_scorer = SacreBLEUScorer(reference_path=data_configs["bleu_valid_reference"],
                                  num_refs=data_configs["num_refs"],
                                  lang_pair=data_configs["lang_pair"],
                                  sacrebleu_args=training_configs["bleu_valid_configs"]['sacrebleu_args'],
                                  postprocess=training_configs["bleu_valid_configs"]['postprocess']
                                  )

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    lrate = optimizer_configs['learning_rate']
    is_early_stop = False

    # ============================ Begin =============================== #
    # Build Model & Optimizer
    # We would do steps below on after another
    #     1. build models & criterion
    #     2. move models & criterion to gpu if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    #     5. build learning rate scheduler if needed
    #     6. load checkpoints if needed

    # 0. Initial
    model_collections = Collections()
    best_model_prefix = os.path.join(FLAGS.saveto,
                                     FLAGS.model_name + GlobalNames.MY_BEST_MODEL_SUFFIX)

    checkpoint_saver = Saver(save_prefix="{0}.ckpt".format(os.path.join(FLAGS.saveto, FLAGS.model_name)),
                             num_max_keeping=training_configs['num_kept_checkpoints']
                             )
    best_model_saver = Saver(save_prefix=best_model_prefix,
                             num_max_keeping=training_configs['num_kept_best_model'])

    # 1. Build Model & Criterion
    INFO('Building model...')
    timer.tic()
    nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                            n_tgt_vocab=vocab_tgt.max_n_words, **model_configs)
    INFO(nmt_model)

    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'])

    INFO(critic)
    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 2. Move to GPU
    if GlobalNames.USE_GPU:
        nmt_model = nmt_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    load_pretrained_model(nmt_model, FLAGS.pretrain_path, exclude_prefix=None,
                          device=CURRENT_DEVICE)

    # 4. Build optimizer
    INFO('Building Optimizer...')
    optim = Optimizer(name=optimizer_configs['optimizer'],
                      model=nmt_model,
                      lr=lrate,
                      grad_clip=optimizer_configs['grad_clip'],
                      optim_args=optimizer_configs['optimizer_params'],
                      distributed=True if world_size > 1 else False,
                      update_cycle=training_configs['update_cycle']
                      )

    # 5. Build scheduler for optimizer if needed
    if optimizer_configs['schedule_method'] is not None:
        if optimizer_configs['schedule_method'] == "loss":
            scheduler = ReduceOnPlateauScheduler(optimizer=optim,
                                                 **optimizer_configs["scheduler_configs"]
                                                 )
        elif optimizer_configs['schedule_method'] == "noam":
            scheduler = NoamScheduler(optimizer=optim, **optimizer_configs['scheduler_configs'])
        else:
            WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".format(optimizer_configs['schedule_method']))
            scheduler = None
    else:
        scheduler = None

    # 6. build moving average
    if training_configs['moving_average_method'] is not None:
        ma = MovingAverage(moving_average_method=training_configs['moving_average_method'],
                           named_params=nmt_model.named_parameters(),
                           alpha=training_configs['moving_average_alpha'])
    else:
        ma = None

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # Reload from latest checkpoint
    if FLAGS.reload:
        checkpoint_saver.load_latest(model=nmt_model, optim=optim, lr_scheduler=scheduler,
                                     collections=model_collections, ma=ma)

    # broadcast parameters and optimizer states
    if world_size > 1:
        hvd.broadcast_parameters(params=nmt_model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer=optim.optim, root_rank=0)

    # ================================================================== #
    # Prepare training
    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [1])[-1]
    bad_count = model_collections.get_collection("bad_count", [0])[-1]
    oom_count = model_collections.get_collection("oom_count", [0])[-1]

    cum_n_samples = 0
    cum_n_words = 0
    best_valid_loss = 1.0 * 1e10  # Max Float
    # BLEU validation may trigger before the first loss validation and its
    # summary line formats valid_loss, so it needs a value from the start
    valid_loss = best_valid_loss
    update_cycle = training_configs['update_cycle']
    grad_denom = 0

    if rank == 0:
        summary_writer = SummaryWriter(log_dir=FLAGS.log_path)
    else:
        summary_writer = None

    # Timer for computing speed
    timer_for_speed = Timer()
    timer_for_speed.tic()

    INFO('Begin training...')

    while True:

        if summary_writer is not None:
            summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()

        if rank == 0:
            training_progress_bar = tqdm(desc=' - (Epoch %d) ' % eidx,
                                         total=len(training_iterator),
                                         unit="sents"
                                         )
        else:
            training_progress_bar = None

        for batch in training_iter:

            seqs_x, seqs_y = batch

            batch_size = len(seqs_x)

            cum_n_samples += batch_size
            cum_n_words += sum(len(s) for s in seqs_y)

            try:
                # Prepare data
                x, y = prepare_data(seqs_x, seqs_y, cuda=GlobalNames.USE_GPU)

                loss = compute_forward(model=nmt_model,
                                       critic=critic,
                                       seqs_x=x,
                                       seqs_y=y,
                                       eval=False,
                                       normalization=1.0,
                                       norm_by_words=training_configs["norm_by_words"])

                update_cycle -= 1
                grad_denom += batch_size

            except RuntimeError as e:
                # An OOM batch is skipped; anything else is fatal.
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    oom_count += 1
                else:
                    raise e

            # When update_cycle becomes 0, it means end of one batch.
            # Several things will be done:
            # - update parameters
            # - reset update_cycle and grad_denom
            # - update uidx
            # - update moving average
            if update_cycle == 0:
                if world_size > 1:
                    grad_denom = distributed.all_reduce(grad_denom)

                optim.step(denom=grad_denom)
                optim.zero_grad()

                if training_progress_bar is not None:
                    training_progress_bar.update(grad_denom)

                update_cycle = training_configs['update_cycle']
                grad_denom = 0

                uidx += 1

                if scheduler is None:
                    pass
                elif optimizer_configs["schedule_method"] == "loss":
                    scheduler.step(metric=best_valid_loss)
                else:
                    scheduler.step(global_step=uidx)

                if ma is not None and eidx >= training_configs['moving_average_start_epoch']:
                    ma.step()
            else:
                # still accumulating gradients for the current update
                continue

            # ========================================================== #
            # Display some information
            if should_trigger_by_steps(uidx, eidx, every_n_step=training_configs['disp_freq']):

                if world_size > 1:
                    cum_n_words = sum(distributed.all_gather(cum_n_words))
                    cum_n_samples = sum(distributed.all_gather(cum_n_samples))

                # words per second and sents per second, measured with the
                # dedicated speed timer since the last display
                elapsed_seconds = timer_for_speed.toc(return_seconds=True)
                words_per_sec = cum_n_words / elapsed_seconds
                sents_per_sec = cum_n_samples / elapsed_seconds
                lrate = list(optim.get_lrate())[0]

                if summary_writer is not None:
                    summary_writer.add_scalar("Speed(words/sec)", scalar_value=words_per_sec, global_step=uidx)
                    # NOTE(review): "sents/sen" looks like a typo for
                    # "sents/sec"; tag kept so existing run histories line up
                    summary_writer.add_scalar("Speed(sents/sen)", scalar_value=sents_per_sec, global_step=uidx)
                    summary_writer.add_scalar("lrate", scalar_value=lrate, global_step=uidx)
                    summary_writer.add_scalar("oom_count", scalar_value=oom_count, global_step=uidx)

                # Reset timer
                timer_for_speed.tic()
                cum_n_words = 0
                cum_n_samples = 0

            # ========================================================== #
            # Loss Validation & Learning rate annealing
            if should_trigger_by_steps(global_step=uidx, n_epoch=eidx,
                                       every_n_step=training_configs['loss_valid_freq'],
                                       debug=FLAGS.debug):

                valid_loss = loss_validation(model=nmt_model,
                                             critic=critic,
                                             valid_iterator=valid_iterator,
                                             rank=rank,
                                             world_size=world_size
                                             )

                model_collections.add_to_collection("history_losses", valid_loss)

                min_history_loss = np.array(model_collections.get_collection("history_losses")).min()
                best_valid_loss = min_history_loss

                if summary_writer is not None:
                    summary_writer.add_scalar("loss", valid_loss, global_step=uidx)
                    summary_writer.add_scalar("best_loss", min_history_loss, global_step=uidx)

            # ========================================================== #
            # BLEU Validation & Early Stop
            if should_trigger_by_steps(global_step=uidx, n_epoch=eidx,
                                       every_n_step=training_configs['bleu_valid_freq'],
                                       min_step=training_configs['bleu_valid_warmup'],
                                       debug=FLAGS.debug):

                valid_bleu = bleu_validation(uidx=uidx,
                                             valid_iterator=valid_iterator,
                                             batch_size=training_configs["bleu_valid_batch_size"],
                                             model=nmt_model,
                                             bleu_scorer=bleu_scorer,
                                             vocab_tgt=vocab_tgt,
                                             valid_dir=FLAGS.valid_path,
                                             max_steps=training_configs["bleu_valid_configs"]["max_steps"],
                                             beam_size=training_configs["bleu_valid_configs"]["beam_size"],
                                             alpha=training_configs["bleu_valid_configs"]["alpha"],
                                             world_size=world_size,
                                             rank=rank,
                                             )

                model_collections.add_to_collection(key="history_bleus", value=valid_bleu)

                best_valid_bleu = float(np.array(model_collections.get_collection("history_bleus")).max())

                if summary_writer is not None:
                    summary_writer.add_scalar("bleu", valid_bleu, uidx)
                    summary_writer.add_scalar("best_bleu", best_valid_bleu, uidx)

                # If model get new best valid bleu score
                if valid_bleu >= best_valid_bleu:
                    bad_count = 0

                    if is_early_stop is False:
                        if rank == 0:
                            # 1. save the best model
                            torch.save(nmt_model.state_dict(), best_model_prefix + ".final")

                            # 2. record all several best models
                            best_model_saver.save(global_step=uidx, model=nmt_model, ma=ma)
                else:
                    bad_count += 1

                    # At least one epoch should be traversed
                    if bad_count >= training_configs['early_stop_patience'] and eidx > 0:
                        is_early_stop = True
                        WARN("Early Stop!")

                if summary_writer is not None:
                    summary_writer.add_scalar("bad_count", bad_count, uidx)

                INFO("{0} Loss: {1:.2f} BLEU: {2:.2f} lrate: {3:6f} patience: {4}".format(
                    uidx, valid_loss, valid_bleu, lrate, bad_count
                ))

            # ========================================================== #
            # Saving checkpoints
            if should_trigger_by_steps(uidx, eidx, every_n_step=training_configs['save_freq'], debug=FLAGS.debug):
                model_collections.add_to_collection("uidx", uidx)
                model_collections.add_to_collection("eidx", eidx)
                model_collections.add_to_collection("bad_count", bad_count)

                if not is_early_stop:
                    if rank == 0:
                        checkpoint_saver.save(global_step=uidx,
                                              model=nmt_model,
                                              optim=optim,
                                              lr_scheduler=scheduler,
                                              collections=model_collections,
                                              ma=ma)

        if training_progress_bar is not None:
            training_progress_bar.close()

        eidx += 1
        if eidx > training_configs["max_epochs"]:
            break
class Solver(BaseSolver):
    ''' Solver for interactive recognition and on-the-fly fine-tuning of an ASR model.

    The solver rebuilds the model described by ``self.src_config`` (the training
    configuration), wraps it in a beam-search decoder, and offers:
      * ``exec`` / ``recognize`` - decode a single wav file,
      * ``finetune``             - a few update steps on one (wav, text) pair
                                   merged into regular training batches,
      * ``validate``             - error-rate evaluation on the dev set.
    '''

    def __init__(self, config, paras, mode):
        ''' Copy the data/model settings of the training config and set up logging.

        Raises:
            NotImplementedError: if ``config['decode']['beam_size']`` is 1
                                 (greedy decoding is not implemented).
        '''
        super().__init__(config, paras, mode)

        # ToDo : support tr/eval on different corpus
        assert self.config['data']['corpus']['name'] == self.src_config['data']['corpus']['name']
        self.config['data']['corpus']['path'] = self.src_config['data']['corpus']['path']
        self.config['data']['corpus']['bucketing'] = False

        # The following attributes should be identical to the training config
        self.config['data']['audio'] = self.src_config['data']['audio']
        self.config['data']['corpus']['train_split'] = self.src_config['data']['corpus']['train_split']
        self.config['data']['text'] = self.src_config['data']['text']
        self.tokenizer = load_text_encoder(**self.config['data']['text'])
        self.config['model'] = self.src_config['model']

        # Number of leading encoder layers whose parameters are handed to the
        # optimizer (see set_model); <= 0 means "all parameters".
        self.finetune_first = 5
        # Best dev error rates seen so far, one entry per decoding branch
        self.best_wer = {'att': 3.0, 'ctc': 3.0}

        # Output file
        self.output_file = str(self.ckpdir)+'_{}_{}.csv'

        # Override batch size for beam decoding
        self.greedy = self.config['decode']['beam_size'] == 1
        self.dealer = Datadealer(self.config['data']['audio'])
        # Pure-CTC decoding: repeated tokens are collapsed when decoding text
        self.ctc = self.config['decode']['ctc_weight'] == 1.0
        if not self.greedy:
            self.config['data']['corpus']['batch_size'] = 1
        else:
            # ToDo : implement greedy
            raise NotImplementedError

        # Logger settings
        self.logdir = os.path.join(paras.logdir, self.exp_name)
        self.log = SummaryWriter(
            self.logdir, flush_secs=self.TB_FLUSH_FREQ)
        self.timer = Timer()

    def fetch_data(self, data):
        ''' Move data to device and compute text seq. length'''
        _, feat, feat_len, txt = data
        feat = feat.to(self.device)
        feat_len = feat_len.to(self.device)
        txt = txt.to(self.device)
        # Index 0 is treated as padding: count the non-zero tokens per row
        txt_len = torch.sum(txt != 0, dim=-1)

        return feat, feat_len, txt, txt_len

    def load_data(self, batch_size=7):
        ''' Load data for training/validation, store tokenizer and input/output shape

        The configured batch size is overridden by `batch_size` only for the
        duration of the dataset construction and restored afterwards, even
        when `load_dataset` raises.
        '''
        corpus_config = self.config['data']['corpus']
        prev_batch_size = corpus_config['batch_size']
        corpus_config['batch_size'] = batch_size
        try:
            self.tr_set, self.dv_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \
                load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory,
                             False, **self.config['data'])
        finally:
            corpus_config['batch_size'] = prev_batch_size
        self.verbose(msg)

    def set_model(self):
        ''' Setup ASR model, losses, optimizer and beam decoder '''
        # Model
        # NOTE: feature dimension and vocabulary size are hard-coded here and
        # overwrite whatever load_data() stored in the same attributes.
        self.feat_dim = 120
        self.vocab_size = 46
        init_adadelta = True
        # init_adadelta = self.config['hparas']['optimizer'] == 'Adadelta'
        self.model = ASR(self.feat_dim, self.vocab_size, init_adadelta,
                         **self.src_config['model']).to(self.device)
        self.verbose(self.model.create_msg())

        if self.finetune_first > 0:
            # Only the first `finetune_first` encoder layers are optimized.
            # The trailing dot keeps "encoder.layers.1." from also matching
            # "encoder.layers.10.", "encoder.layers.11.", ...
            names = ["encoder.layers.%d." % i for i in range(self.finetune_first)]
            model_paras = [{"params": [p for n, p in self.model.named_parameters()
                                       if any(nd in n for nd in names)]}]
        else:
            model_paras = [{'params': self.model.parameters()}]

        # Losses
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        # Note: zero_infinity=False is unstable?
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

        # Plug-ins
        self.emb_fuse = False
        self.emb_reg = ('emb' in self.config) and (
            self.config['emb']['enable'])
        if self.emb_reg:
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(
                self.tokenizer, self.model.dec_dim, **self.config['emb']).to(self.device)
            model_paras.append({'params': self.emb_decoder.parameters()})
            self.emb_fuse = self.emb_decoder.apply_fuse
            if self.emb_fuse:
                self.seq_loss = torch.nn.NLLLoss(ignore_index=0)
            self.verbose(self.emb_decoder.create_msg())

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.src_config['hparas'])
        self.verbose(self.optimizer.create_msg())

        # Enable AMP if needed
        self.enable_apex()

        # Automatically load pre-trained model if self.paras.load is given
        self.load_ckpt()

        # Beam decoder
        # NOTE(review): self.emb_decoder is only assigned above when the 'emb'
        # plug-in is enabled - presumably the base class defaults it to None; confirm.
        self.decoder = BeamDecoder(
            self.model, self.emb_decoder, **self.config['decode'])
        self.verbose(self.decoder.create_msg())
        # del self.model
        # del self.emb_decoder
        self.decoder.to(self.device)

    def exec(self):
        ''' Testing End-to-end ASR system

        Repeatedly asks for a wav file name on stdin, decodes it and prints
        every hypothesis. Stops when the user types "exit", when stdin is
        closed, or when the prompt is interrupted.
        '''
        while True:
            try:
                filename = input("Input wav file name: ")
            except (EOFError, KeyboardInterrupt):
                # No more input can arrive: leave the loop instead of
                # printing "Invalid file" forever.
                return
            if filename == "exit":
                return

            try:
                feat, feat_len = self.dealer(filename)
                feat = feat.to(self.device)
                feat_len = feat_len.to(self.device)

                # Decode
                with torch.no_grad():
                    hyps = self.decoder(feat, feat_len)

                hyp_seqs = [hyp.outIndex for hyp in hyps]
                hyp_txts = [self.tokenizer.decode(hyp, ignore_repeat=self.ctc)
                            for hyp in hyp_seqs]
                for txt in hyp_txts:
                    print(txt)
            except Exception:
                # Best effort: a bad file must not end the interactive session
                print("Invalid file")

    def recognize(self, filename):
        ''' Decode one wav file and return the text of the first hypothesis.

        Any failure is printed, logged through `app.logger`, and reported to
        the caller as the string "Invalid file".
        '''
        try:
            feat, feat_len = self.dealer(filename)
            feat = feat.to(self.device)
            feat_len = feat_len.to(self.device)

            # Decode
            with torch.no_grad():
                hyps = self.decoder(feat, feat_len)

            hyp_seqs = [hyp.outIndex for hyp in hyps]
            hyp_txts = [self.tokenizer.decode(hyp, ignore_repeat=self.ctc)
                        for hyp in hyp_seqs]
            return hyp_txts[0]
        except Exception as e:
            print(e)
            app.logger.debug(e)
            return "Invalid file"

    def fetch_finetune_data(self, filename, fixed_text):
        ''' Build a single fine-tuning sample from a wav file and its corrected text.

        Returns:
            list ``[feat, feat_len, text, text_len]`` where `text` is the 1-D
            tensor of encoded tokens and `text_len` its length as a python int.
        '''
        feat, feat_len = self.dealer(filename)
        feat = feat.to(self.device)
        feat_len = feat_len.to(self.device)
        text = self.tokenizer.encode(fixed_text)
        text = torch.tensor(text).to(self.device)
        text_len = len(text)
        return [feat, feat_len, text, text_len]

    def merge_batch(self, main_batch, attach_batch):
        ''' Append the single sample `attach_batch` to `main_batch`.

        The attached features/text are clipped or zero-padded to the longest
        feature/text length of the main batch so they can be concatenated
        along the batch dimension.

        Args:
            main_batch:   ``(feat, feat_len, txt, txt_len)`` as returned by `fetch_data`.
            attach_batch: mutable list ``[feat, feat_len, text, text_len]`` as
                          returned by `fetch_finetune_data`; updated in place.
        Returns:
            tuple ``(feat, feat_len, txt, txt_len)`` with one more row than `main_batch`.
        '''
        max_feat_len = int(max(main_batch[1]))
        max_text_len = int(max(main_batch[3]))

        # Features: attach_batch[0] carries a leading batch dimension of size 1
        if attach_batch[0].shape[1] > max_feat_len:
            # reduce extra long example
            attach_batch[0] = attach_batch[0][:, :max_feat_len]
            attach_batch[1][0] = max_feat_len
        else:
            # pad to max_feat_len
            padding = torch.zeros(1, max_feat_len - attach_batch[0].shape[1],
                                  attach_batch[0].shape[2],
                                  dtype=attach_batch[0].dtype).to(self.device)
            attach_batch[0] = torch.cat([attach_batch[0], padding], dim=1)

        # Text: attach_batch[2] is 1-D here; it gets its batch dimension below
        text = attach_batch[2]
        if text.shape[0] > max_text_len:
            # Clip the attached text and record ITS new length; the lengths
            # of the main batch must stay untouched.
            text = text[:max_text_len]
            attach_batch[3] = max_text_len
        else:
            padding = torch.zeros(max_text_len - text.shape[0],
                                  dtype=text.dtype).to(self.device)
            text = torch.cat([text, padding], dim=0)
        attach_batch[2] = text.unsqueeze(0)

        new_batch = (
            torch.cat([main_batch[0], attach_batch[0]], dim=0),
            torch.cat([main_batch[1], attach_batch[1]], dim=0),
            torch.cat([main_batch[2], attach_batch[2]], dim=0),
            torch.cat([main_batch[3],
                       torch.tensor([attach_batch[3]]).to(self.device)], dim=0)
        )
        return new_batch

    def finetune(self, filename, fixed_text, max_step=5):
        ''' Run up to `max_step` update steps with the given sample mixed into each batch.

        Every training batch drawn from `self.tr_set` is extended with the
        (wav, text) pair before the forward pass. After the last step the
        model is evaluated on the dev set.

        Args:
            filename:   path of the wav file to adapt to.
            fixed_text: reference transcription for that file.
            max_step:   number of update steps; 0 skips training entirely.
        Returns:
            str: the report produced by `validate`.
        '''
        # Load data for finetune
        self.verbose('Total training steps {}.'.format(
            human_format(max_step)))
        ctc_loss, att_loss, emb_loss = None, None, None
        self.timer.set()
        step = 0

        for data in self.tr_set:
            if max_step == 0:
                break
            # Pre-step : update tf_rate/lr_rate and do zero_grad
            tf_rate = self.optimizer.pre_step(400000)
            total_loss = 0

            # Fetch data (re-read every step: merge_batch modifies the sample in place)
            finetune_data = self.fetch_finetune_data(filename, fixed_text)
            main_batch = self.fetch_data(data)
            new_batch = self.merge_batch(main_batch, finetune_data)
            feat, feat_len, txt, txt_len = new_batch
            self.timer.cnt('rd')

            # Forward model
            # Note: txt should NOT start w/ <sos>
            ctc_output, encode_len, att_output, att_align, dec_state = \
                self.model(feat, feat_len, max(txt_len), tf_rate=tf_rate,
                           teacher=txt, get_dec_state=self.emb_reg)

            # Plugins
            if self.emb_reg:
                emb_loss, fuse_output = self.emb_decoder(
                    dec_state, att_output, label=txt)
                total_loss += self.emb_decoder.weight*emb_loss

            # Compute all objectives
            if ctc_output is not None:
                if self.paras.cudnn_ctc:
                    ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1),
                                             txt.to_sparse().values().to(device='cpu', dtype=torch.int32),
                                             [ctc_output.shape[1]] * len(ctc_output),
                                             txt_len.cpu().tolist())
                else:
                    ctc_loss = self.ctc_loss(ctc_output.transpose(
                        0, 1), txt, encode_len, txt_len)
                total_loss += ctc_loss*self.model.ctc_weight

            if att_output is not None:
                b, t, _ = att_output.shape
                att_output = fuse_output if self.emb_fuse else att_output
                att_loss = self.seq_loss(
                    att_output.contiguous().view(b*t, -1), txt.contiguous().view(-1))
                total_loss += att_loss*(1-self.model.ctc_weight)

            self.timer.cnt('fw')

            # Backprop
            grad_norm = self.backward(total_loss)
            step += 1

            # Logger
            self.progress('Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                          .format(total_loss.cpu().item(), grad_norm, self.timer.show()))
            self.write_log(
                'loss', {'tr_ctc': ctc_loss, 'tr_att': att_loss})
            self.write_log('emb_loss', {'tr': emb_loss})
            self.write_log('wer', {'tr_att': cal_er(self.tokenizer, att_output, txt),
                                   'tr_ctc': cal_er(self.tokenizer, ctc_output, txt, ctc=True)})
            if self.emb_fuse:
                if self.emb_decoder.fuse_learnable:
                    self.write_log('fuse_lambda', {
                                   'emb': self.emb_decoder.get_weight()})
                self.write_log(
                    'fuse_temp', {'temp': self.emb_decoder.get_temp()})

            # End of step
            # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
            torch.cuda.empty_cache()
            self.timer.set()
            # `step` already counts the update just performed, so stop as soon
            # as `max_step` updates are done (not one later).
            if step >= max_step:
                break

        ret = self.validate()
        self.log.close()
        return ret

    def validate(self):
        ''' Evaluate on the dev set and report the error rate of both branches.

        Returns:
            str: one line per task ('att', 'ctc') comparing the new error rate
                 with the best one seen so far (which is updated on improvement).
        '''
        # Eval mode
        self.model.eval()
        if self.emb_decoder is not None:
            self.emb_decoder.eval()
        dev_wer = {'att': [], 'ctc': []}

        for i, data in enumerate(self.dv_set):
            self.progress('Valid step - {}/{}'.format(i+1, len(self.dv_set)))
            # Fetch data
            feat, feat_len, txt, txt_len = self.fetch_data(data)

            # Forward model
            with torch.no_grad():
                ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.model(feat, feat_len, int(max(txt_len)*self.DEV_STEP_RATIO),
                               emb_decoder=self.emb_decoder)

            dev_wer['att'].append(cal_er(self.tokenizer, att_output, txt))
            dev_wer['ctc'].append(cal_er(self.tokenizer, ctc_output, txt, ctc=True))

            # Show some example on tensorboard (taken from the middle batch)
            if i == len(self.dv_set)//2:
                for k in range(min(len(txt), self.DEV_N_EXAMPLE)):
                    self.write_log('true_text{}'.format(k),
                                   self.tokenizer.decode(txt[k].tolist()))
                    if att_output is not None:
                        self.write_log('att_align{}'.format(k), feat_to_fig(
                            att_align[k, 0, :, :].cpu().detach()))
                        self.write_log('att_text{}'.format(k), self.tokenizer.decode(
                            att_output[k].argmax(dim=-1).tolist()))
                    if ctc_output is not None:
                        self.write_log('ctc_text{}'.format(k),
                                       self.tokenizer.decode(ctc_output[k].argmax(dim=-1).tolist(),
                                                             ignore_repeat=True))

        # Skip save model here
        # Report whether performance improves
        to_prints = []
        for task in ['att', 'ctc']:
            dev_wer[task] = sum(dev_wer[task]) / len(dev_wer[task])
            if dev_wer[task] < self.best_wer[task]:
                to_print = f"WER of {task}: {dev_wer[task]} < prev best ({self.best_wer[task]})"
                self.best_wer[task] = dev_wer[task]
            else:
                to_print = f"WER of {task}: {dev_wer[task]} >= prev best ({self.best_wer[task]})"
            print(to_print, flush=True)
            to_prints.append(to_print)
            # self.save_checkpoint('best_{}.pth'.format(task), 'wer', dev_wer[task])
            self.write_log('wer', {'dv_'+task: dev_wer[task]})
        # self.save_checkpoint('latest.pth', 'wer', dev_wer['att'], show_msg=False)

        # Resume training
        self.model.train()
        if self.emb_decoder is not None:
            self.emb_decoder.train()
        return '\n'.join(to_prints)
def train(flags):
    """Train an NMT model, optionally distributed and with online distillation (ODC).

    The function builds data iterators, model, criterion, optimizer, scheduler
    and moving average from the configuration files, optionally reloads the
    latest checkpoint, and then loops over epochs until ``max_epochs`` is
    exceeded or early stopping terminates the process.

    flags:
        saveto: str
        reload: store_true
        config_path: str
        pretrain_path: str, default=""
        model_name: str
        log_path: str

    The following attributes of ``flags`` are read as well: ``use_gpu``,
    ``multi_gpu``, ``shared_dir``, ``predefined_config``, ``valid_path``,
    ``display_loss_detail`` and ``debug``.
    """

    # ================================================================================== #
    # Initialization for training on different devices
    # - CPU/GPU
    # - Single/Distributed
    Constants.USE_GPU = flags.use_gpu

    if flags.multi_gpu:
        dist.distributed_init(flags.shared_dir)
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        local_rank = dist.get_local_rank()
    else:
        world_size = 1
        rank = 0
        local_rank = 0

    if Constants.USE_GPU:
        torch.cuda.set_device(local_rank)
        Constants.CURRENT_DEVICE = "cuda:{0}".format(local_rank)
    else:
        Constants.CURRENT_DEVICE = "cpu"

    # If not root_rank, close logging
    # else write log of training to file.
    if rank == 0:
        write_log_to_file(
            os.path.join(flags.log_path,
                         "%s.log" % time.strftime("%Y%m%d-%H%M%S")))
    else:
        close_logging()

    # ================================================================================== #
    # Parsing configuration files
    # - Load default settings
    # - Load pre-defined settings
    # - Load user-defined settings

    configs = prepare_configs(flags.config_path, flags.predefined_config)

    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    INFO(pretty_configs(configs))

    # use odc
    if training_configs['use_odc'] is True:
        ave_best_k = check_odc_config(training_configs)
    else:
        ave_best_k = 0

    Constants.SEED = training_configs['seed']

    set_seed(Constants.SEED)

    timer = Timer()

    # ================================================================================== #
    # Load Data

    INFO('Loading data...')
    timer.tic()

    # Generate target dictionary
    vocab_src = Vocabulary.build_from_file(**data_configs['vocabularies'][0])
    vocab_tgt = Vocabulary.build_from_file(**data_configs['vocabularies'][1])

    Constants.EOS = vocab_src.eos
    Constants.PAD = vocab_src.pad
    Constants.BOS = vocab_src.bos

    train_bitext_dataset = ZipDataset(
        TextLineDataset(data_path=data_configs['train_data'][0],
                        vocabulary=vocab_src,
                        max_len=data_configs['max_len'][0],
                        is_train_dataset=True),
        TextLineDataset(data_path=data_configs['train_data'][1],
                        vocabulary=vocab_tgt,
                        max_len=data_configs['max_len'][1],
                        is_train_dataset=True))

    valid_bitext_dataset = ZipDataset(
        TextLineDataset(
            data_path=data_configs['valid_data'][0],
            vocabulary=vocab_src,
            is_train_dataset=False,
        ),
        TextLineDataset(data_path=data_configs['valid_data'][1],
                        vocabulary=vocab_tgt,
                        is_train_dataset=False))

    training_iterator = DataIterator(
        dataset=train_bitext_dataset,
        batch_size=training_configs["batch_size"],
        use_bucket=training_configs['use_bucket'],
        buffer_size=training_configs['buffer_size'],
        batching_func=training_configs['batching_key'],
        world_size=world_size,
        rank=rank)

    valid_iterator = DataIterator(
        dataset=valid_bitext_dataset,
        batch_size=training_configs['valid_batch_size'],
        use_bucket=True,
        buffer_size=100000,
        numbering=True,
        world_size=world_size,
        rank=rank)

    bleu_scorer = SacreBLEUScorer(
        reference_path=data_configs["bleu_valid_reference"],
        num_refs=data_configs["num_refs"],
        lang_pair=data_configs["lang_pair"],
        sacrebleu_args=training_configs["bleu_valid_configs"]
        ['sacrebleu_args'],
        postprocess=training_configs["bleu_valid_configs"]['postprocess'])

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # ================================ Begin ======================================== #
    # Build Model & Optimizer
    # We would do steps below on after another
    #     1. build models & criterion
    #     2. move models & criterion to gpu if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    #     5. build learning rate scheduler if needed
    #     6. load checkpoints if needed

    # 0. Initial
    lrate = optimizer_configs['learning_rate']
    model_collections = Collections()

    checkpoint_saver = Saver(
        save_prefix="{0}.ckpt".format(
            os.path.join(flags.saveto, flags.model_name)),
        num_max_keeping=training_configs['num_kept_checkpoints'])

    best_model_prefix = os.path.join(
        flags.saveto, flags.model_name + Constants.MY_BEST_MODEL_SUFFIX)

    best_k_saver = BestKSaver(
        save_prefix="{0}.best_k_ckpt".format(
            os.path.join(flags.saveto, flags.model_name)),
        num_max_keeping=training_configs['num_kept_best_k_checkpoints'])

    # 1. Build Model & Criterion
    INFO('Building model...')
    timer.tic()
    nmt_model = build_model(n_src_vocab=vocab_src.max_n_words,
                            n_tgt_vocab=vocab_tgt.max_n_words,
                            padding_idx=vocab_src.pad,
                            vocab_src=vocab_src,
                            **model_configs)
    INFO(nmt_model)

    # build teacher model
    teacher_model, teacher_model_path = get_teacher_model(
        training_configs, model_configs, vocab_src, vocab_tgt, flags)

    # build critic
    critic = CombinationCriterion(model_configs['loss_configs'],
                                  padding_idx=vocab_tgt.pad,
                                  teacher=teacher_model)
    # INFO(critic)
    critic.INFO()

    # 2. Move to GPU
    if Constants.USE_GPU:
        nmt_model = nmt_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    load_pretrained_model(nmt_model,
                          flags.pretrain_path,
                          exclude_prefix=None,
                          device=Constants.CURRENT_DEVICE)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 4. Build optimizer
    INFO('Building Optimizer...')

    if not flags.multi_gpu:
        optim = Optimizer(name=optimizer_configs['optimizer'],
                          model=nmt_model,
                          lr=lrate,
                          grad_clip=optimizer_configs['grad_clip'],
                          optim_args=optimizer_configs['optimizer_params'],
                          update_cycle=training_configs['update_cycle'])
    else:
        optim = dist.DistributedOptimizer(
            name=optimizer_configs['optimizer'],
            model=nmt_model,
            lr=lrate,
            grad_clip=optimizer_configs['grad_clip'],
            optim_args=optimizer_configs['optimizer_params'],
            device_id=local_rank)

    # 5. Build scheduler for optimizer if needed
    scheduler = build_scheduler(
        schedule_method=optimizer_configs['schedule_method'],
        optimizer=optim,
        scheduler_configs=optimizer_configs['scheduler_configs'])

    # 6. build moving average
    ma = build_ma(training_configs, nmt_model.named_parameters())

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # Reload from latest checkpoint
    if flags.reload:
        checkpoint_saver.load_latest(model=nmt_model,
                                     optim=optim,
                                     lr_scheduler=scheduler,
                                     collections=model_collections,
                                     ma=ma,
                                     device=Constants.CURRENT_DEVICE)

    # broadcast parameters and optimizer states
    if world_size > 1:
        INFO("Broadcasting model parameters...")
        dist.broadcast_parameters(params=nmt_model.state_dict())
        INFO("Broadcasting optimizer states...")
        dist.broadcast_optimizer_state(optimizer=optim.optim)
        INFO('Done.')

    # ================================================================================== #
    # Prepare training

    # Counters are restored from the (possibly reloaded) collections; the
    # second argument is the default used when the collection is empty.
    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [1])[-1]
    bad_count = model_collections.get_collection("bad_count", [0])[-1]
    oom_count = model_collections.get_collection("oom_count", [0])[-1]
    is_early_stop = model_collections.get_collection("is_early_stop", [
        False,
    ])[-1]
    teacher_patience = model_collections.get_collection(
        "teacher_patience", [training_configs['teacher_patience']])[-1]

    train_loss_meter = AverageMeter()
    train_loss_dict_meter = AverageMeterDict(critic.get_critic_name())
    sent_per_sec_meter = TimeMeter()
    tok_per_sec_meter = TimeMeter()

    update_cycle = training_configs['update_cycle']
    grad_denom = 0
    train_loss = 0.0
    cum_n_words = 0
    train_loss_dict = dict()
    valid_loss = best_valid_loss = float('inf')
    # Like `valid_loss`, this is reported by the BLEU-validation branch, which
    # may trigger before the first loss validation has produced a value.
    valid_loss_dict = dict()

    if rank == 0:
        summary_writer = SummaryWriter(log_dir=flags.log_path)
    else:
        summary_writer = None

    sent_per_sec_meter.start()
    tok_per_sec_meter.start()

    INFO('Begin training...')

    while True:

        if summary_writer is not None:
            summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()

        if rank == 0:
            training_progress_bar = tqdm(desc=' - (Epc {}, Upd {}) '.format(
                eidx, uidx),
                                         total=len(training_iterator),
                                         unit="sents")
        else:
            training_progress_bar = None

        for batch in training_iter:

            seqs_x, seqs_y = batch

            batch_size = len(seqs_x)
            cum_n_words += sum(len(s) for s in seqs_y)

            try:
                # Prepare data
                x, y = prepare_data(seqs_x, seqs_y, cuda=Constants.USE_GPU)

                loss, loss_dict = compute_forward(
                    model=nmt_model,
                    critic=critic,
                    seqs_x=x,
                    seqs_y=y,
                    eval=False,
                    normalization=1.0,
                    norm_by_words=training_configs["norm_by_words"])

                update_cycle -= 1
                grad_denom += batch_size
                train_loss += loss
                train_loss_dict = add_dict_value(train_loss_dict, loss_dict)

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    oom_count += 1
                else:
                    raise e

            # When update_cycle becomes 0, it means end of one batch. Several things will be done:
            # - update parameters
            # - reset update_cycle and grad_denom, update uidx
            # - learning rate scheduling
            # - update moving average

            if update_cycle == 0:

                # 0. reduce variables
                if world_size > 1:
                    grad_denom = dist.all_reduce_py(grad_denom)
                    train_loss = dist.all_reduce_py(train_loss)
                    train_loss_dict = dist.all_reduce_py(train_loss_dict)
                    cum_n_words = dist.all_reduce_py(cum_n_words)

                # 1. update parameters
                optim.step(denom=grad_denom)
                optim.zero_grad()

                if training_progress_bar is not None:
                    training_progress_bar.update(grad_denom)
                    training_progress_bar.set_description(
                        ' - (Epc {}, Upd {}) '.format(eidx, uidx))

                    postfix_str = 'TrainLoss: {:.2f}, ValidLoss(best): {:.2f} ({:.2f}), '.format(
                        train_loss, valid_loss, best_valid_loss)
                    for critic_name, loss_value in train_loss_dict.items():
                        postfix_str += (critic_name +
                                        ': {:.2f}, ').format(loss_value)
                    training_progress_bar.set_postfix_str(postfix_str)

                # 2. learning rate scheduling
                if scheduler is not None and optimizer_configs[
                        "schedule_method"] != "loss":
                    scheduler.step(global_step=uidx)

                # 3. update moving average
                if ma is not None and eidx >= training_configs[
                        'moving_average_start_epoch']:
                    ma.step()

                # 4. update meters
                train_loss_meter.update(train_loss, grad_denom)
                train_loss_dict_meter.update(train_loss_dict, grad_denom)
                sent_per_sec_meter.update(grad_denom)
                tok_per_sec_meter.update(cum_n_words)

                # 5. reset accumulated variables, update uidx
                update_cycle = training_configs['update_cycle']
                grad_denom = 0
                uidx += 1
                cum_n_words = 0.0
                train_loss = 0.0
                train_loss_dict = dict()

            else:
                # Still accumulating gradients: nothing below applies yet
                continue

            # ================================================================================== #
            # Display some information
            if should_trigger_by_steps(
                    uidx, eidx, every_n_step=training_configs['disp_freq']):

                lrate = list(optim.get_lrate())[0]

                if summary_writer is not None:
                    summary_writer.add_scalar(
                        "Speed(sents/sec)",
                        scalar_value=sent_per_sec_meter.ave,
                        global_step=uidx)
                    summary_writer.add_scalar(
                        "Speed(words/sec)",
                        scalar_value=tok_per_sec_meter.ave,
                        global_step=uidx)
                    summary_writer.add_scalar(
                        "train_loss",
                        scalar_value=train_loss_meter.ave,
                        global_step=uidx)
                    # add loss for every critic
                    if flags.display_loss_detail:
                        combination_loss = train_loss_dict_meter.value
                        for key, value in combination_loss.items():
                            summary_writer.add_scalar(key,
                                                      scalar_value=value,
                                                      global_step=uidx)
                    summary_writer.add_scalar("lrate",
                                              scalar_value=lrate,
                                              global_step=uidx)
                    summary_writer.add_scalar("oom_count",
                                              scalar_value=oom_count,
                                              global_step=uidx)

                # Reset Meters
                sent_per_sec_meter.reset()
                tok_per_sec_meter.reset()
                train_loss_meter.reset()
                train_loss_dict_meter.reset()

            # ================================================================================== #
            # Loss Validation & Learning rate annealing
            if should_trigger_by_steps(
                    global_step=uidx,
                    n_epoch=eidx,
                    every_n_step=training_configs['loss_valid_freq'],
                    debug=flags.debug):

                with cache_parameters(nmt_model):
                    valid_loss, valid_loss_dict = loss_evaluation(
                        model=nmt_model,
                        critic=critic,
                        valid_iterator=valid_iterator,
                        rank=rank,
                        world_size=world_size)

                if scheduler is not None and optimizer_configs[
                        "schedule_method"] == "loss":
                    scheduler.step(metric=valid_loss)

                model_collections.add_to_collection("history_losses",
                                                    valid_loss)

                min_history_loss = np.array(
                    model_collections.get_collection("history_losses")).min()
                best_valid_loss = min_history_loss

                if summary_writer is not None:
                    summary_writer.add_scalar("loss",
                                              valid_loss,
                                              global_step=uidx)
                    summary_writer.add_scalar("best_loss",
                                              min_history_loss,
                                              global_step=uidx)

            # ================================================================================== #
            # BLEU Validation & Early Stop

            if should_trigger_by_steps(
                    global_step=uidx,
                    n_epoch=eidx,
                    every_n_step=training_configs['bleu_valid_freq'],
                    min_step=training_configs['bleu_valid_warmup'],
                    debug=flags.debug):

                with cache_parameters(nmt_model):
                    valid_bleu = bleu_evaluation(
                        uidx=uidx,
                        valid_iterator=valid_iterator,
                        batch_size=training_configs["bleu_valid_batch_size"],
                        model=nmt_model,
                        bleu_scorer=bleu_scorer,
                        vocab_src=vocab_src,
                        vocab_tgt=vocab_tgt,
                        valid_dir=flags.valid_path,
                        max_steps=training_configs["bleu_valid_configs"]
                        ["max_steps"],
                        beam_size=training_configs["bleu_valid_configs"]
                        ["beam_size"],
                        alpha=training_configs["bleu_valid_configs"]["alpha"],
                        world_size=world_size,
                        rank=rank,
                    )

                model_collections.add_to_collection(key="history_bleus",
                                                    value=valid_bleu)

                best_valid_bleu = float(
                    np.array(model_collections.get_collection(
                        "history_bleus")).max())

                if summary_writer is not None:
                    summary_writer.add_scalar("bleu", valid_bleu, uidx)
                    summary_writer.add_scalar("best_bleu", best_valid_bleu,
                                              uidx)

                # If model get new best valid bleu score
                if valid_bleu >= best_valid_bleu:
                    bad_count = 0

                    if is_early_stop is False:
                        if rank == 0:
                            # 1. save the best model
                            torch.save(nmt_model.state_dict(),
                                       best_model_prefix + ".final")

                else:
                    bad_count += 1

                    # At least one epoch should be traversed
                    if bad_count >= training_configs[
                            'early_stop_patience'] and eidx > 0:
                        is_early_stop = True
                        WARN("Early Stop!")
                        exit(0)

                if rank == 0:
                    best_k_saver.save(global_step=uidx,
                                      metric=valid_bleu,
                                      model=nmt_model,
                                      optim=optim,
                                      lr_scheduler=scheduler,
                                      collections=model_collections,
                                      ma=ma)

                # ODC
                if training_configs['use_odc'] is True:
                    if valid_bleu >= best_valid_bleu:
                        # New best score: refresh the stored teacher parameters
                        # and switch knowledge distillation off.

                        # choose method to generate teachers from checkpoints
                        # - best
                        # - ave_k_best
                        # - ma
                        if training_configs['teacher_choice'] == 'ma':
                            teacher_params = ma.export_ma_params()
                        elif training_configs['teacher_choice'] == 'best':
                            teacher_params = nmt_model.state_dict()
                        elif "ave_best" in training_configs['teacher_choice']:
                            if best_k_saver.num_saved >= ave_best_k:
                                teacher_params = average_checkpoints(
                                    best_k_saver.get_all_ckpt_path()
                                    [-ave_best_k:])
                            else:
                                teacher_params = nmt_model.state_dict()
                        else:
                            raise ValueError(
                                "can not support teacher choice %s" %
                                training_configs['teacher_choice'])
                        torch.save(teacher_params, teacher_model_path)
                        del teacher_params
                        teacher_patience = 0
                        critic.set_use_KD(False)
                    else:
                        teacher_patience += 1
                        if teacher_patience >= training_configs[
                                'teacher_refresh_warmup']:
                            # Score has not improved for long enough: load the
                            # stored teacher and distil from it.
                            teacher_params = torch.load(
                                teacher_model_path,
                                map_location=Constants.CURRENT_DEVICE)
                            teacher_model.load_state_dict(teacher_params,
                                                          strict=False)
                            del teacher_params
                            critic.reset_teacher(teacher_model)
                            critic.set_use_KD(True)

                if summary_writer is not None:
                    summary_writer.add_scalar("bad_count", bad_count, uidx)

                info_str = "{0} Loss: {1:.2f} BLEU: {2:.2f} lrate: {3:6f} patience: {4} ".format(
                    uidx, valid_loss, valid_bleu, lrate, bad_count)
                for key, value in valid_loss_dict.items():
                    info_str += (key + ': {0:.2f} '.format(value))
                INFO(info_str)

            # ================================================================================== #
            # Saving checkpoints
            if should_trigger_by_steps(
                    uidx,
                    eidx,
                    every_n_step=training_configs['save_freq'],
                    debug=flags.debug):
                model_collections.add_to_collection("uidx", uidx)
                model_collections.add_to_collection("eidx", eidx)
                model_collections.add_to_collection("bad_count", bad_count)
                model_collections.add_to_collection("teacher_patience",
                                                    teacher_patience)

                if not is_early_stop:
                    if rank == 0:
                        checkpoint_saver.save(global_step=uidx,
                                              model=nmt_model,
                                              optim=optim,
                                              lr_scheduler=scheduler,
                                              collections=model_collections,
                                              ma=ma)

        if training_progress_bar is not None:
            training_progress_bar.close()

        eidx += 1
        if eidx > training_configs["max_epochs"]:
            break
class Translate_Env(object):
    """
    Wrap the translation environment for multiple agents.

    The environment needs parallel data to evaluate BLEU degradation.
    Its state is the batched (padded) source ids together with the index of
    the position currently being considered for perturbation.
    Rewards are produced at every step from the discriminator output and, on
    the last position, from sentence-level BLEU of the victim translation.
    """

    def __init__(
        self,
        attack_configs,
        discriminator_configs,
        src_vocab,
        trg_vocab,
        data_iterator,
        save_to,
        device="cpu",
    ):
        """
        Initiate the translation environment (victim translator + discriminator).

        :param attack_configs: attack configuration dictionary; the keys read
            here are "victim_configs", "victim_model", "init_perturb_rate",
            "adversarial", "r_s_weight" and "r_d_weight"
        :param discriminator_configs: dictionary providing
            "discriminator_model_configs" and "discriminator_optimizer_configs"
        :param src_vocab: source vocabulary
        :param trg_vocab: target vocabulary
        :param data_iterator: provides the batches the environment runs on
        :param save_to: directory used for the near-vocabulary caches
        :param device: (string) device to allocate variables ("cpu", "cuda:*"),
            default is cpu
        """
        self.data_iterator = data_iterator
        discriminator_model_configs = discriminator_configs[
            "discriminator_model_configs"]
        discriminator_optim_configs = discriminator_configs[
            "discriminator_optimizer_configs"]
        self.victim_config_path = attack_configs["victim_configs"]
        self.victim_model_path = attack_configs["victim_model"]
        # determine devices
        self.device = device

        with open(self.victim_config_path.strip()) as v_f:
            print("open victim configs...%s" % self.victim_config_path)
            # An explicit Loader is required by recent PyYAML releases and
            # avoids the unsafe legacy default loader.
            victim_configs = yaml.load(v_f, Loader=yaml.FullLoader)

        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab
        self.translate_model = build_translate_model(
            victim_configs,
            self.victim_model_path,
            vocab_src=self.src_vocab,
            vocab_trg=self.trg_vocab,
            device=self.device)
        self.translate_model.eval()

        self.w2p, self.w2vocab = load_or_extract_near_vocab(
            config_path=self.victim_config_path,
            model_path=self.victim_model_path,
            init_perturb_rate=attack_configs["init_perturb_rate"],
            save_to=os.path.join(save_to, "near_vocab"),
            save_to_full=os.path.join(save_to, "full_near_vocab"),
            top_reserve=12,
            emit_as_id=True)

        #########################################################
        # to update discriminator
        self.discriminator = TransDiscriminator(
            n_src_words=self.src_vocab.max_n_words,
            n_trg_words=self.trg_vocab.max_n_words,
            **discriminator_model_configs)
        self.discriminator.to(self.device)
        load_embedding(self.discriminator,
                       model_path=self.victim_model_path,
                       device=self.device)

        self.optim_D = Optimizer(
            name=discriminator_optim_configs["optimizer"],
            model=self.discriminator,
            lr=discriminator_optim_configs["learning_rate"],
            grad_clip=discriminator_optim_configs["grad_clip"],
            optim_args=discriminator_optim_configs["optimizer_params"])
        self.criterion_D = nn.CrossEntropyLoss()  # used in discriminator updates

        self.scheduler_D = None  # default as None
        schedule_method = discriminator_optim_configs['schedule_method']
        if schedule_method is not None:
            if schedule_method == "loss":
                self.scheduler_D = ReduceOnPlateauScheduler(
                    optimizer=self.optim_D,
                    **discriminator_optim_configs["scheduler_configs"])
            elif schedule_method == "noam":
                self.scheduler_D = NoamScheduler(
                    optimizer=self.optim_D,
                    **discriminator_optim_configs['scheduler_configs'])
            elif schedule_method == "rsqrt":
                self.scheduler_D = RsqrtScheduler(
                    optimizer=self.optim_D,
                    **discriminator_optim_configs["scheduler_configs"])
            else:
                WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".
                     format(schedule_method))
        ############################################################

        self._init_state()
        # adversarial sample or reinforced samples
        self.adversarial = attack_configs["adversarial"]
        self.r_s_weight = attack_configs["r_s_weight"]
        self.r_d_weight = attack_configs["r_d_weight"]

    # ------------------------------------------------------------------ #
    # private helpers
    # ------------------------------------------------------------------ #
    def _next_batch(self):
        """
        Fetch the next batch, retrying once when the iterator signals its end.

        NOTE(review): the single retry presumably relies on the project's data
        iterator restarting after it raised StopIteration - confirm against
        the iterator implementation.
        """
        try:
            return next(self.data_iterator)
        except StopIteration:
            return next(self.data_iterator)

    def _pad_batch(self, samples, pad, batch_first=True):
        """
        Pack variable-length id sequences into one int64 tensor on self.device.

        :param samples: list of id sequences
        :param pad: id used to fill the positions after each sequence's end
        :param batch_first: if False the result is transposed to [len, batch]
        :return: padded tensor
        """
        sizes = [len(s) for s in samples]
        x_np = np.full((len(samples), max(sizes)),
                       fill_value=pad,
                       dtype='int64')
        for row, (size, sample) in enumerate(zip(sizes, samples)):
            x_np[row, :size] = sample
        if batch_first is False:
            x_np = np.transpose(x_np, [1, 0])
        return torch.tensor(x_np).to(self.device)

    @staticmethod
    def _merge_terminal_signals(previous, current):
        """
        Element-wise OR of two 0/1 lists: once a sentence is flagged as
        terminated it stays flagged.
        """
        return [1 if (old or new) else 0 for old, new in zip(previous, current)]

    # ------------------------------------------------------------------ #
    # environment state
    # ------------------------------------------------------------------ #
    def _init_state(self):
        """
        Initiate batched sentences / origin_bleu / index (starts from the first
        label, no BOS/EOS): the initial state of the environment.

        :return: env state (the padded src ids as a numpy array)
        """
        self.index = 1
        self.origin_bleu = []
        batch = next(self.data_iterator)
        assert len(
            batch
        ) == 3, "must be provided with line index (check for data_iterator)"
        # training, parallel trg is provided
        _, seqs_x, self.seqs_y = batch
        self.sent_len = [len(x) for x in seqs_x]  # for terminal signals
        self.terminal_signal = [0] * len(seqs_x)  # for terminal signals

        self.padded_src, self.padded_trg = self.prepare_data(
            seqs_x=seqs_x, seqs_y=self.seqs_y)
        self.origin_result = self.translate()

        # calculate BLEU scores for the top candidate
        for index, sent_t in enumerate(self.seqs_y):
            bleu_t = bleu.sentence_bleu(references=[sent_t],
                                        hypothesis=self.origin_result[index],
                                        emulate_multibleu=True)
            self.origin_bleu.append(bleu_t)
        return self.padded_src.cpu().numpy()

    def get_src_vocab(self):
        """Return the source vocabulary."""
        return self.src_vocab

    def reset(self):
        """Re-initialise the environment on the next batch and return its state."""
        return self._init_state()

    def reset_data_iter(self, data_iter):
        """Replace the data iterator with the provided one."""
        self.data_iterator = data_iter
        return

    def reset_discriminator(self):
        """Reset the discriminator and reload its embeddings from the victim model."""
        self.discriminator.reset()
        load_embedding(self.discriminator,
                       model_path=self.victim_model_path,
                       device=self.device)

    # ------------------------------------------------------------------ #
    # data preparation
    # ------------------------------------------------------------------ #
    def prepare_D_data(self, attacker, seqs_x, seqs_y, batch_first=True):
        """
        Use the attacker to generate training data for the discriminator.

        :param attacker: model providing `seq_attack`; it is switched to eval mode
        :param seqs_x: list of sources
        :param seqs_y: corresponding targets
        :param batch_first: first dimension of seqs be batch
        :return: perturbed seqs_x, seqs_y, flags (all moved to self.device)
        """
        seqs_x = [[BOS] + s + [EOS] for s in seqs_x]
        x = self._pad_batch(samples=seqs_x, pad=PAD, batch_first=batch_first)

        # training mode attack: randomly choose half of the seqs to attack
        attacker.eval()
        x, flags = attacker.seq_attack(x, self.w2vocab, training_mode=True)

        seqs_y = [[BOS] + s + [EOS] for s in seqs_y]
        y = self._pad_batch(seqs_y, pad=PAD, batch_first=batch_first)

        # Tensor.to() is not in-place: keep the returned tensor.
        flags = flags.to(self.device)
        return x, y, flags

    def prepare_data(self, seqs_x, seqs_y=None, batch_first=True):
        """
        Wrap the sequences with BOS/EOS and pad them into tensors.

        :param seqs_x: list of source id sequences
        :param seqs_y: optional list of target id sequences
        :param batch_first: first dimension of the result is the batch
        :return: padded x, or the tuple (padded x, padded y) when seqs_y is given
        """
        seqs_x = [[BOS] + s + [EOS] for s in seqs_x]
        x = self._pad_batch(samples=seqs_x, pad=PAD, batch_first=batch_first)
        if seqs_y is None:
            return x

        seqs_y = [[BOS] + s + [EOS] for s in seqs_y]
        y = self._pad_batch(seqs_y, pad=PAD, batch_first=batch_first)
        return x, y

    # ------------------------------------------------------------------ #
    # discriminator training
    # ------------------------------------------------------------------ #
    def acc_validation(self, attacker):
        """
        Estimate discriminator accuracy on five batches of attacker-generated data.

        :param attacker: attacker used to build the perturbed samples
        :return: accuracy as a python float
        """
        self.discriminator.eval()
        acc = 0
        sample_count = 0
        for _ in range(5):
            batch = self._next_batch()
            seq_nums, seqs_x, seqs_y = batch
            x, y, flags = self.prepare_D_data(attacker, seqs_x, seqs_y)

            # set components to evaluation mode
            self.discriminator.eval()
            with torch.no_grad():
                preds = self.discriminator(x, y).argmax(dim=-1)
                acc += torch.eq(preds, flags).sum()
                sample_count += preds.size(0)
        acc = acc.float() / sample_count
        return acc.item()

    def compute_D_forward(self, seqs_x, seqs_y, gold_flags, evaluate=False):
        """
        Get the discriminator loss according to the criterion; in training mode
        the loss is also back-propagated.

        :param seqs_x: padded source tensor
        :param seqs_y: padded target tensor
        :param gold_flags: 1 if perturbed, otherwise 0
        :param evaluate: if True, no gradient is computed
        :return: loss value
        """
        if not evaluate:
            # set components to training mode (dropout layers)
            self.discriminator.train()
            self.criterion_D.train()
            with torch.enable_grad():
                class_probs = self.discriminator(seqs_x, seqs_y)
                loss = self.criterion_D(class_probs, gold_flags)
            torch.autograd.backward(loss)
            return loss.item()
        else:
            # set components to evaluation mode (dropout layers)
            self.discriminator.eval()
            self.criterion_D.eval()
            with torch.no_grad():
                class_probs = self.discriminator(seqs_x, seqs_y)
                loss = self.criterion_D(class_probs, gold_flags)
            return loss.item()

    def update_discriminator(self,
                             attacker_model,
                             base_steps=0,
                             min_update_steps=20,
                             max_update_steps=300,
                             accuracy_bound=0.8,
                             summary_writer=None):
        """
        Update the discriminator.

        :param attacker_model: attacker to generate training data for discriminator
        :param base_steps: used for saving
        :param min_update_steps: (integer) minimum update steps, also the
            discriminator evaluate steps
        :param max_update_steps: (integer) maximum update steps
        :param accuracy_bound: (float) update until accuracy reaches the bound
            (or max_update_steps)
        :param summary_writer: used to log discriminator learning information;
            logging is skipped when it is None
        :return: steps and test accuracy as trust region
        """
        INFO("update discriminator")
        self.optim_D.zero_grad()
        attacker_model = attacker_model.to(self.device)
        step = 0
        while True:
            batch = self._next_batch()

            # update the discriminator
            step += 1
            if self.scheduler_D is not None:
                # override learning rate in self.optim_D
                self.scheduler_D.step(global_step=step)
            _, seqs_x, seqs_y = batch  # returned tensor type of the data
            try:
                x, y, flags = self.prepare_D_data(attacker_model, seqs_x,
                                                  seqs_y)
                loss = self.compute_D_forward(seqs_x=x,
                                              seqs_y=y,
                                              gold_flags=flags)
                self.optim_D.step()
                print("discriminator loss:", loss)
            except RuntimeError as e:
                if "out of memory" in str(e):
                    print("WARNING: out of memory, skipping batch")
                    self.optim_D.zero_grad()
                else:
                    raise

            # valid for accuracy / check for break (if any)
            if step % min_update_steps == 0:
                acc = self.acc_validation(attacker_model)
                print("discriminator acc: %2f" % acc)
                # summary_writer defaults to None, so logging must be optional
                if summary_writer is not None:
                    summary_writer.add_scalar("discriminator",
                                              scalar_value=acc,
                                              global_step=base_steps + step)
                if accuracy_bound and acc > accuracy_bound:
                    INFO("discriminator reached training acc bound, updated.")
                    return base_steps + step, acc

            if step > max_update_steps:
                acc = self.acc_validation(attacker_model)
                print("discriminator acc: %2f" % acc)
                INFO("Reach maximum discriminator update. Finished.")
                return base_steps + step, acc  # stop updates

    # ------------------------------------------------------------------ #
    # translation and stepping
    # ------------------------------------------------------------------ #
    def translate(self, inputs=None):
        """
        Translate with the victim model and keep the top beam candidate.

        :param inputs: if None, translate the (possibly perturbed) sequences
            stored in the environment (self.padded_src)
        :return: list of translation results (PAD ids removed)
        """
        if inputs is None:
            inputs = self.padded_src
        with torch.no_grad():
            print(inputs.device)
            perturbed_results = beam_search(
                self.translate_model,
                beam_size=5,
                max_steps=150,
                src_seqs=inputs,
                alpha=-1.0,
            )
        perturbed_results = perturbed_results.cpu().numpy().tolist()

        # only use the top result from the result
        result = []
        for sent in perturbed_results:
            result.append([wid for wid in sent[0] if wid != PAD])
        return result

    def step(self, actions):
        """
        Step update for the environment: finally updates self.index.
        This is defined as inference of the environment.

        :param actions: whether to perturb (action vector over the batch) on
            the current index.
            test: actions = actor_output_distribution.argmax(dim=-1)
            train: actions = actor.output_distribution.multinomial(dim=-1)
            can be on cpu or cuda.
        :return: updated states / rewards / terminal signal from the environment
            reward (float), terminal_signal (boolean)
        """
        with torch.no_grad():
            terminal = False  # default is not terminated
            batch_size = actions.shape[0]
            reward = 0
            inputs = self.padded_src[:, self.index]
            inputs_mask = ~inputs.eq(PAD)

            # modification on sequences (state): draw one random candidate
            # replacement for every word at the current position
            target_of_step = []
            for batch_index in range(batch_size):
                word_id = inputs[batch_index]
                candidates = self.w2vocab[word_id.item()]
                target_word_id = candidates[np.random.choice(
                    len(candidates), 1)[0]]
                target_of_step += [target_word_id]

            if self.device != "cpu" and not actions.is_cuda:
                actions = actions.to(self.device)
            actions *= inputs_mask  # PAD is neglect

            # override the state src with random choice from candidates
            self.padded_src[:, self.index] *= (1 - actions)
            replacement = torch.tensor(target_of_step)
            replacement = replacement.to(self.device)
            self.padded_src[:, self.index] += replacement * actions

            # update sequences' pointer
            self.index += 1

            # Run the discriminator, check for terminal signals and update the
            # local terminal list.
            # get discriminator distribution on the current src state
            discriminate_out = self.discriminator(self.padded_src,
                                                  self.padded_trg)
            # `list or list` would always keep the (non-empty) old list, so
            # the flags are merged element-wise instead.
            self.terminal_signal = self._merge_terminal_signals(
                self.terminal_signal,
                discriminate_out.detach().argmax(
                    dim=-1).cpu().numpy().tolist())
            signal = (1 - discriminate_out.argmax(dim=-1)).sum().item()
            if signal == 0 or self.index == self.padded_src.shape[1] - 1:
                terminal = True  # no need to further explore or reached EOS for all src

            # collect rewards on the current state
            # calculate intermediate survival rewards
            if not terminal:
                # survival rewards for survived objects
                distribution, discriminate_index = discriminate_out.max(dim=-1)
                distribution = distribution.detach().cpu().numpy()
                discriminate_index = (1 - discriminate_index).cpu().numpy()
                survival_value = distribution * discriminate_index * (
                    1 - np.array(self.terminal_signal))
                reward += survival_value.sum() * self.r_s_weight
            else:
                # only penalty for overall intermediate termination
                reward = -1 * batch_size

            # only check for finished relative BLEU degradation when survival
            # on the last label
            if self.index == self.padded_src.shape[1] - 1:
                # re-tokenize ignore the original UNK for victim model
                inputs = self.padded_src.cpu().numpy().tolist()
                new_inputs = []
                for indices in inputs:
                    # remove EOS, BOS, PAD
                    new_line = [
                        word_id for word_id in indices
                        if word_id not in [EOS, BOS, PAD]
                    ]
                    new_line = self.src_vocab.ids2sent(new_line)
                    if not hasattr(self.src_vocab.tokenizer, "bpe"):
                        new_line = new_line.strip().split()
                    else:
                        new_token = []
                        for w in new_line.strip().split():
                            if w != self.src_vocab.id2token(UNK):
                                new_token.append(
                                    self.src_vocab.tokenizer.bpe.segment_word(
                                        w))
                            else:
                                new_token.append([w])
                        new_line = sum(new_token, [])
                    new_line = [self.src_vocab.token2id(t) for t in new_line]
                    new_inputs.append(new_line)

                # translate calculate padded_src
                perturbed_result = self.translate(
                    self.prepare_data(seqs_x=new_inputs, ))

                # calculate final BLEU degradation:
                episodic_rewards = []
                for i, sent in enumerate(self.seqs_y):
                    # sentence is still surviving
                    if self.index >= self.sent_len[
                            i] - 1 and self.terminal_signal[i] == 0:
                        if self.origin_bleu[i] == 0:
                            # here we want to minimize noise from original bad cases
                            relative_degraded_value = 0
                        else:
                            relative_degraded_value = (
                                self.origin_bleu[i] - bleu.sentence_bleu(
                                    references=[sent],
                                    hypothesis=perturbed_result[i],
                                    emulate_multibleu=True))
                            relative_degraded_value /= self.origin_bleu[i]
                        if self.adversarial:
                            episodic_rewards.append(relative_degraded_value)
                        else:
                            episodic_rewards.append(-relative_degraded_value)
                    else:
                        episodic_rewards.append(0.0)
                reward += sum(episodic_rewards) * self.r_d_weight

            reward = reward / batch_size
        return self.padded_src.cpu().numpy(), reward, terminal,
def _init_local_optims(self, rephraser_optimizer_configs):
    """
    Build the actor / critic / alpha optimizers and, if requested, their
    learning-rate schedulers.

    Expected layout of rephraser_optimizer_configs (None disables everything):
        optimizer: "adafactor"
        learning_rate: 0.01
        grad_clip: -1.0
        optimizer_params: ~
        schedule_method: rsqrt
        scheduler_configs:
            d_model: *dim
            warmup_steps: 100
    """
    configs = rephraser_optimizer_configs

    # Schedulers stay None unless a known schedule method is configured below.
    self.actor_scheduler = None
    self.critic_scheduler = None

    # No configuration: no local optimizers at all.
    if configs is None:
        self.actor_optimizer = None
        self.critic_optimizer = None
        self.log_alpha_optimizer = None
        return

    # Actor and critic share every optimizer setting except the model.
    shared_settings = dict(
        name=configs["optimizer"],
        lr=configs["learning_rate"],
        grad_clip=configs["grad_clip"],
        optim_args=configs["optimizer_params"],
    )
    self.actor_optimizer = Optimizer(model=self.actor, **shared_settings)
    self.critic_optimizer = Optimizer(model=self.critic, **shared_settings)

    # hardcoded entropy weight updates
    self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                lr=1e-4,
                                                betas=(0.9, 0.999))

    # Build scheduler for optimizer if needed
    method = configs['schedule_method']
    if method is None:
        return

    if method == "loss":
        scheduler_cls = ReduceOnPlateauScheduler
    elif method == "noam":
        scheduler_cls = NoamScheduler
    elif method == "rsqrt":
        scheduler_cls = RsqrtScheduler
    else:
        WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".format(
            method))
        return

    self.actor_scheduler = scheduler_cls(optimizer=self.actor_optimizer,
                                         **configs["scheduler_configs"])
    self.critic_scheduler = scheduler_cls(optimizer=self.critic_optimizer,
                                          **configs["scheduler_configs"])
class Solver(BaseSolver):
    ''' Solver for training'''

    def __init__(self, config, paras, mode):
        ''' Read the training hyper-parameters used by this solver from config'''
        super().__init__(config, paras, mode)
        # Curriculum learning affects data loader
        self.curriculum = self.config['hparas']['curriculum']
        self.val_mode = self.config['hparas']['val_mode'].lower()
        # Name of the error-rate metric used when writing logs
        self.WER = 'per' if self.val_mode == 'per' else 'wer'

    def _apply_transfer_freeze(self):
        ''' Re-apply the transfer-learning layer freezing (no-op otherwise)'''
        if self.transfer_learning:
            self.model.encoder.fix_layers(self.fix_enc)
            if self.fix_dec and self.model.enable_att:
                self.model.decoder.fix_layers()
            if self.fix_dec and self.model.enable_ctc:
                self.model.fix_ctc_layer()

    def fetch_data(self, data, train=False):
        ''' Move data to device and compute text seq. length

        When an upstream model is configured, `feat` holds raw waveforms that
        are first run through the upstream (and SpecAugment while training),
        then padded into a batch.
        Returns feat, feat_len, txt, txt_len.
        '''
        # feat: B x T x D
        _, feat, feat_len, txt = data

        if self.paras.upstream is not None:
            # feat is raw waveform
            device = 'cpu' if self.paras.deterministic else self.device
            self.upstream.to(device)
            self.specaug.to(device)

            def to_device(feat):
                return [f.to(device) for f in feat]

            def extract_feature(feat):
                feat = self.upstream(to_device(feat))
                if train and self.config['data']['audio'][
                        'augment'] and 'aug' not in self.paras.upstream:
                    feat = [self.specaug(f) for f in feat]
                return feat

            # Keep every other utterance when the first one is longer than
            # HALF_BATCHSIZE_AUDIO_LEN (training only).
            if HALF_BATCHSIZE_AUDIO_LEN < 3500 and train:
                first_len = extract_feature(feat[:1])[0].shape[0]
                if first_len > HALF_BATCHSIZE_AUDIO_LEN:
                    feat = feat[::2]
                    txt = txt[::2]

            if self.paras.upstream_trainable:
                self.upstream.train()
                feat = extract_feature(feat)
            else:
                with torch.no_grad():
                    self.upstream.eval()
                    feat = extract_feature(feat)

            feat_len = torch.LongTensor([len(f) for f in feat])
            feat = pad_sequence(feat, batch_first=True)
            txt = pad_sequence(txt, batch_first=True)

        feat = feat.to(self.device)
        feat_len = feat_len.to(self.device)
        txt = txt.to(self.device)
        # index 0 is treated as padding when counting the text length
        txt_len = torch.sum(txt != 0, dim=-1)

        return feat, feat_len, txt, txt_len

    def load_data(self):
        ''' Load data for training/validation, store tokenizer and input/output shape'''
        if self.paras.upstream is not None:
            print(f'[Solver] - using S3PRL {self.paras.upstream}')
            self.tr_set, self.dv_set, self.vocab_size, self.tokenizer, msg = \
                load_wav_dataset(self.paras.njobs, self.paras.gpu,
                                 self.paras.pin_memory, self.curriculum > 0,
                                 **self.config['data'])
            self.upstream = torch.hub.load(
                's3prl/s3prl',
                self.paras.upstream,
                feature_selection=self.paras.upstream_feature_selection,
                refresh=self.paras.upstream_refresh,
                ckpt=self.paras.upstream_ckpt,
                force_reload=True,
            )
            self.feat_dim = self.upstream.get_output_dim()
            self.specaug = Augment()
        else:
            self.tr_set, self.dv_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \
                load_dataset(self.paras.njobs, self.paras.gpu,
                             self.paras.pin_memory, self.curriculum > 0,
                             **self.config['data'])
        self.verbose(msg)

        # Dev set names
        self.dv_names = []
        if isinstance(self.dv_set, list):
            for ds in self.config['data']['corpus']['dev_split']:
                self.dv_names.append(ds[0])
        else:
            self.dv_names = self.config['data']['corpus']['dev_split'][0]

        # Logger settings: best error rate so far per task and dev set
        if isinstance(self.dv_names, str):
            self.best_wer = {
                'att': {
                    self.dv_names: 3.0
                },
                'ctc': {
                    self.dv_names: 3.0
                }
            }
        else:
            self.best_wer = {'att': {}, 'ctc': {}}
            for name in self.dv_names:
                self.best_wer['att'][name] = 3.0
                self.best_wer['ctc'][name] = 3.0

    def set_model(self):
        ''' Setup ASR model and optimizer '''
        # Model
        batch_size = self.config['data']['corpus']['batch_size'] // 2
        self.model = ASR(self.feat_dim, self.vocab_size, batch_size,
                         **self.config['model']).to(self.device)
        self.verbose(self.model.create_msg())
        model_paras = [{'params': self.model.parameters()}]

        # Losses
        if self.config['hparas']['label_smoothing']:
            # NOTE(review): 31 is hard-coded here - presumably the vocabulary
            # size; confirm it matches self.vocab_size.
            self.seq_loss = LabelSmoothingLoss(31, 0.1)
            print('[INFO]  using label smoothing. ')
        else:
            self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        # Note: zero_infinity=False is unstable?
        self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False)

        # Plug-ins
        self.emb_fuse = False
        self.emb_reg = ('emb' in self.config) and (
            self.config['emb']['enable'])
        if self.emb_reg:
            from src.plugin import EmbeddingRegularizer
            self.emb_decoder = EmbeddingRegularizer(
                self.tokenizer, self.model.dec_dim,
                **self.config['emb']).to(self.device)
            model_paras.append({'params': self.emb_decoder.parameters()})
            self.emb_fuse = self.emb_decoder.apply_fuse
            if self.emb_fuse:
                self.seq_loss = torch.nn.NLLLoss(ignore_index=0)
            self.verbose(self.emb_decoder.create_msg())

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.config['hparas'])
        self.lr_scheduler = self.optimizer.lr_scheduler
        self.verbose(self.optimizer.create_msg())

        # Enable AMP if needed
        self.enable_apex()

        # Transfer Learning
        if self.transfer_learning:
            self.verbose('Apply transfer learning: ')
            self.verbose('      Train encoder layers: {}'.format(
                self.train_enc))
            self.verbose('      Train decoder:        {}'.format(
                self.train_dec))
            self.verbose('      Save name:            {}'.format(
                self.save_name))

        # Automatically load pre-trained model if self.paras.load is given
        self.load_ckpt()

    def exec(self):
        ''' Training End-to-end ASR system '''
        self.verbose('Total training steps {}.'.format(
            human_format(self.max_step)))
        self._apply_transfer_freeze()

        self.n_epochs = 0
        self.timer.set()

        # early stopping for ctc
        self.early_stoping = self.config['hparas']['early_stopping']
        stop_epoch = 10
        batch_size = self.config['data']['corpus']['batch_size']
        stop_step = len(self.tr_set) * stop_epoch // batch_size
        # Number of steps counted as one epoch; at least 1 so the modulo
        # below cannot divide by zero on a short loader.
        steps_per_epoch = max(1, len(self.tr_set) // batch_size)

        while self.step < self.max_step:
            ctc_loss, att_loss, emb_loss = None, None, None
            # Renew dataloader to enable random sampling
            if self.curriculum > 0 and self.n_epochs == self.curriculum:
                self.verbose(
                    'Curriculum learning ends after {} epochs, starting random sampling.'
                    .format(self.n_epochs))
                self.tr_set, _, _, _, _, _ = \
                    load_dataset(self.paras.njobs, self.paras.gpu,
                                 self.paras.pin_memory, False,
                                 **self.config['data'])

            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad
                tf_rate = self.optimizer.pre_step(self.step)
                total_loss = 0

                # Fetch data
                feat, feat_len, txt, txt_len = self.fetch_data(data,
                                                               train=True)
                self.timer.cnt('rd')

                # Forward model
                # Note: txt should NOT start w/ <sos>
                ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.model(feat, feat_len, max(txt_len), tf_rate=tf_rate,
                               teacher=txt, get_dec_state=self.emb_reg)
                # Clear not used objects
                del att_align

                # Plugins
                if self.emb_reg:
                    emb_loss, fuse_output = self.emb_decoder(dec_state,
                                                             att_output,
                                                             label=txt)
                    total_loss += self.emb_decoder.weight * emb_loss
                else:
                    del dec_state

                # early stopping ctc: drop the CTC objective after stop_step
                if self.early_stoping:
                    if self.step > stop_step:
                        ctc_output = None
                        self.model.ctc_weight = 0

                # Compute all objectives
                if ctc_output is not None:
                    if self.paras.cudnn_ctc:
                        ctc_loss = self.ctc_loss(
                            ctc_output.transpose(0, 1),
                            txt.to_sparse().values().to(device='cpu',
                                                        dtype=torch.int32),
                            [ctc_output.shape[1]] * len(ctc_output),
                            txt_len.cpu().tolist())
                    else:
                        ctc_loss = self.ctc_loss(ctc_output.transpose(0, 1),
                                                 txt, encode_len, txt_len)
                    total_loss += ctc_loss * self.model.ctc_weight
                    del encode_len

                if att_output is not None:
                    b, t, _ = att_output.shape
                    att_output = fuse_output if self.emb_fuse else att_output
                    att_loss = self.seq_loss(att_output.view(b * t, -1),
                                             txt.view(-1))
                    total_loss += att_loss * (1 - self.model.ctc_weight)

                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(total_loss)
                self.step += 1

                # Logger
                if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                    self.progress('Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                                  .format(total_loss.cpu().item(), grad_norm,
                                          self.timer.show()))
                    self.write_log('emb_loss', {'tr': emb_loss})
                    if att_output is not None:
                        self.write_log('loss', {'tr_att': att_loss})
                        self.write_log(self.WER, {
                            'tr_att': cal_er(self.tokenizer, att_output, txt)
                        })
                        self.write_log(
                            'cer', {
                                'tr_att':
                                cal_er(self.tokenizer,
                                       att_output,
                                       txt,
                                       mode='cer')
                            })
                    if ctc_output is not None:
                        self.write_log('loss', {'tr_ctc': ctc_loss})
                        self.write_log(
                            self.WER, {
                                'tr_ctc':
                                cal_er(self.tokenizer,
                                       ctc_output,
                                       txt,
                                       ctc=True)
                            })
                        self.write_log(
                            'cer', {
                                'tr_ctc':
                                cal_er(self.tokenizer,
                                       ctc_output,
                                       txt,
                                       mode='cer',
                                       ctc=True)
                            })
                        self.write_log(
                            'ctc_text_train',
                            self.tokenizer.decode(
                                ctc_output[0].argmax(dim=-1).tolist(),
                                ignore_repeat=True))

                    if self.emb_fuse:
                        if self.emb_decoder.fuse_learnable:
                            self.write_log(
                                'fuse_lambda',
                                {'emb': self.emb_decoder.get_weight()})
                        self.write_log('fuse_temp',
                                       {'temp': self.emb_decoder.get_temp()})

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    if isinstance(self.dv_set, list):
                        for dv_id in range(len(self.dv_set)):
                            self.validate(self.dv_set[dv_id],
                                          self.dv_names[dv_id])
                    else:
                        self.validate(self.dv_set, self.dv_names)

                if self.step % steps_per_epoch == 0:  # one epoch
                    print('Have finished epoch: ', self.n_epochs)
                    self.n_epochs += 1

                # Manual step decay, used only when the optimizer provides no
                # lr scheduler.
                if self.lr_scheduler is None:
                    lr = self.optimizer.opt.param_groups[0]['lr']
                    if self.step == 1:
                        print(
                            '[INFO]    using lr schedular defined by Daniel, init lr = ',
                            lr)
                    if self.step > 99999 and self.step % 2000 == 0:
                        lr = lr * 0.85
                        for param_group in self.optimizer.opt.param_groups:
                            param_group['lr'] = lr
                        print('[INFO]     at step:', self.step)
                        print('[INFO]   lr reduce to', lr)

                # End of step
                # Empty cuda cache after every step
                # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
                torch.cuda.empty_cache()
                self.timer.set()
                if self.step > self.max_step:
                    break

        self.log.close()
        print('[INFO] Finished training after', human_format(self.max_step),
              'steps.')

    def validate(self, _dv_set, _name):
        ''' Evaluate on one dev set, log error rates and save improved checkpoints

        _dv_set is the dev data loader and _name the name used in log keys
        and checkpoint file names.
        '''
        # Eval mode
        self.model.eval()
        if self.emb_decoder is not None:
            self.emb_decoder.eval()
        dev_wer = {'att': [], 'ctc': []}
        dev_cer = {'att': [], 'ctc': []}
        dev_er = {'att': [], 'ctc': []}

        for i, data in enumerate(_dv_set):
            self.progress('Valid step - {}/{}'.format(i + 1, len(_dv_set)))
            # Fetch data
            feat, feat_len, txt, txt_len = self.fetch_data(data)

            # Forward model
            with torch.no_grad():
                ctc_output, encode_len, att_output, att_align, dec_state = \
                    self.model(feat, feat_len,
                               int(max(txt_len) * self.DEV_STEP_RATIO),
                               emb_decoder=self.emb_decoder)

            if att_output is not None:
                dev_wer['att'].append(
                    cal_er(self.tokenizer, att_output, txt, mode='wer'))
                dev_cer['att'].append(
                    cal_er(self.tokenizer, att_output, txt, mode='cer'))
                dev_er['att'].append(
                    cal_er(self.tokenizer,
                           att_output,
                           txt,
                           mode=self.val_mode))
            if ctc_output is not None:
                dev_wer['ctc'].append(
                    cal_er(self.tokenizer,
                           ctc_output,
                           txt,
                           mode='wer',
                           ctc=True))
                dev_cer['ctc'].append(
                    cal_er(self.tokenizer,
                           ctc_output,
                           txt,
                           mode='cer',
                           ctc=True))
                dev_er['ctc'].append(
                    cal_er(self.tokenizer,
                           ctc_output,
                           txt,
                           mode=self.val_mode,
                           ctc=True))

            # Show some example on tensorboard (taken from the middle batch)
            if i == len(_dv_set) // 2:
                for ex in range(min(len(txt), self.DEV_N_EXAMPLE)):
                    if self.step == 1:
                        self.write_log('true_text_{}_{}'.format(_name, ex),
                                       self.tokenizer.decode(txt[ex].tolist()))
                    if att_output is not None:
                        self.write_log(
                            'att_align_{}_{}'.format(_name, ex),
                            feat_to_fig(att_align[ex,
                                                  0, :, :].cpu().detach()))
                        self.write_log(
                            'att_text_{}_{}'.format(_name, ex),
                            self.tokenizer.decode(
                                att_output[ex].argmax(dim=-1).tolist()))
                    if ctc_output is not None:
                        self.write_log(
                            'ctc_text_{}_{}'.format(_name, ex),
                            self.tokenizer.decode(
                                ctc_output[ex].argmax(dim=-1).tolist(),
                                ignore_repeat=True))

        # Ckpt if performance improves
        tasks = []
        if len(dev_er['att']) > 0:
            tasks.append('att')
        if len(dev_er['ctc']) > 0:
            tasks.append('ctc')

        for task in tasks:
            dev_er[task] = sum(dev_er[task]) / len(dev_er[task])
            dev_wer[task] = sum(dev_wer[task]) / len(dev_wer[task])
            dev_cer[task] = sum(dev_cer[task]) / len(dev_cer[task])
            ckpt_tag = _name + (self.save_name
                                if self.transfer_learning else '')
            if dev_er[task] < self.best_wer[task][_name]:
                self.best_wer[task][_name] = dev_er[task]
                self.save_checkpoint('best_{}_{}.pth'.format(task, ckpt_tag),
                                     self.val_mode, dev_er[task], _name)
            if self.step >= self.max_step:
                self.save_checkpoint('last_{}_{}.pth'.format(task, ckpt_tag),
                                     self.val_mode, dev_er[task], _name)
            self.write_log(self.WER,
                           {'dv_' + task + '_' + _name.lower(): dev_wer[task]})
            self.write_log('cer',
                           {'dv_' + task + '_' + _name.lower(): dev_cer[task]})

        # Resume training
        self.model.train()
        # model.train() is followed by re-applying the transfer-learning freeze
        self._apply_transfer_freeze()
        if self.emb_decoder is not None:
            self.emb_decoder.train()
class Translate_Env(object):
    """
    Translation environment shared by multiple reinforcement-learning agents.

    The environment needs parallel data to evaluate the final BLEU variation.
    Its state is the current (possibly perturbed) source embeddings together
    with the index of the position being perturbed.  Every step yields a
    reward from the scorer (the "annunciator"); once a sentence is finished a
    sentence-level BLEU reward is added.
    """

    def __init__(self,
                 reinforce_configs,
                 annunciator_configs,
                 src_vocab, trg_vocab,
                 data_iterator,
                 save_to,
                 device="cpu",
                 ):
        """
        Build the victim translation model, the scorer and the scorer optimizer.

        :param reinforce_configs: attack configuration dictionary
        :param annunciator_configs: scorer configs (provides survival signals)
        :param src_vocab: source vocabulary handed to the victim model builder
        :param trg_vocab: target vocabulary handed to the victim model builder
        :param data_iterator: provides the batches the environment runs on
        :param save_to: directory used to store the near-vocabulary files
        :param device: (string) device to allocate variables on
            ("cpu", "cuda:*"), default "cpu"
        """
        # environment devices
        self.device = device
        self.data_iterator = data_iterator

        scorer_model_configs = annunciator_configs["scorer_model_configs"]
        annunciator_optim_configs = annunciator_configs["annunciator_optimizer_configs"]

        victim_config_path = reinforce_configs["victim_configs"]
        victim_model_path = reinforce_configs["victim_model"]
        with open(victim_config_path.strip()) as v_f:
            INFO("env open victim configs at %s" % victim_config_path)
            victim_configs = yaml.load(v_f, Loader=yaml.FullLoader)

        # *vocab and *emb provide pseudo-reinforced embeddings to train the annunciator
        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab

        # translation model for BLEU (takes src_embs as inputs) and its embedding layers
        self.src_emb, self.trg_emb, self.translate_model = build_translate_model(
            victim_configs, victim_model_path,
            vocab_src=self.src_vocab, vocab_trg=self.trg_vocab,
            device=self.device)

        self.max_roll_out_step = victim_configs["data_configs"]["max_len"][0]
        self.src_emb.eval()  # source language embeddings
        self.trg_emb.eval()  # target language embeddings
        self.translate_model.eval()

        # the epsilon range used for the action space when perturbing
        _, _, self.limit_dist = load_or_extract_near_vocab(
            config_path=victim_config_path,
            model_path=victim_model_path,
            init_perturb_rate=reinforce_configs["init_perturb_rate"],
            save_to=os.path.join(save_to, "near_vocab"),
            save_to_full=os.path.join(save_to, "full_near_vocab"),
            top_reserve=12,
            emit_as_id=True)

        #########################################################
        # scorer (an Annunciator object) provides intrinsic step rewards
        self.annunciator = TransScorer(
            victim_configs, victim_model_path, self.trg_emb,
            **scorer_model_configs)
        self.annunciator.to(self.device)

        # Annunciator update configs
        self.acc_bound = annunciator_configs["acc_bound"]
        self.mse_bound = annunciator_configs["mse_bound"]
        self.min_update_steps = annunciator_configs["valid_freq"]
        self.max_update_steps = annunciator_configs["annunciator_update_steps"]

        # the optimizer and schedule used for the Annunciator update
        self.optim_A = Optimizer(
            name=annunciator_optim_configs["optimizer"],
            model=self.annunciator,
            lr=annunciator_optim_configs["learning_rate"],
            grad_clip=annunciator_optim_configs["grad_clip"],
            optim_args=annunciator_optim_configs["optimizer_params"])

        self.scheduler_A = None  # default as None
        if annunciator_optim_configs['schedule_method'] is not None:
            if annunciator_optim_configs['schedule_method'] == "loss":
                self.scheduler_A = ReduceOnPlateauScheduler(
                    optimizer=self.optim_A,
                    **annunciator_optim_configs["scheduler_configs"])
            elif annunciator_optim_configs['schedule_method'] == "noam":
                self.scheduler_A = NoamScheduler(
                    optimizer=self.optim_A,
                    **annunciator_optim_configs['scheduler_configs'])
            elif annunciator_optim_configs["schedule_method"] == "rsqrt":
                self.scheduler_A = RsqrtScheduler(
                    optimizer=self.optim_A,
                    **annunciator_optim_configs["scheduler_configs"])
            else:
                WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".format(
                    annunciator_optim_configs['schedule_method']))
        self.criterion_A = nn.CrossEntropyLoss()

        ############################################################
        # adversarial or reinforce as learning objective
        self.adversarial = reinforce_configs["adversarial"]
        self.r_s_weight = reinforce_configs["r_s_weight"]
        self.r_i_weight = reinforce_configs["r_i_weight"]

    # ------------------------------------------------------------------ #
    # private helpers
    # ------------------------------------------------------------------ #
    def _pad_batch_2d(self, samples, pad, batch_first=True):
        """Pack variable-length id sequences into one padded int64 tensor on self.device."""
        sizes = [len(s) for s in samples]
        x_np = np.full((len(samples), max(sizes)), fill_value=pad, dtype='int64')
        for row, (sample, size) in enumerate(zip(samples, sizes)):
            x_np[row, :size] = sample
        if batch_first is False:
            x_np = np.transpose(x_np, [1, 0])
        return torch.tensor(x_np).to(self.device)

    def _next_batch(self):
        """Fetch the next batch, retrying once when the iterator signals exhaustion."""
        try:
            return next(self.data_iterator)
        except StopIteration:
            # presumably the iterator restarts itself after raising
            # StopIteration -- TODO confirm against the data iterator
            return next(self.data_iterator)

    # ------------------------------------------------------------------ #
    # state handling
    # ------------------------------------------------------------------ #
    def _init_state(self, rephraser=None):
        """
        Initialise the batched sentences / original BLEU / step index
        (starting from the first label, skipping BOS) on the env's device.

        :param rephraser: optional agent; when given, the source embeddings
            are perturbed by ``rephraser.random_seq_perturb`` before the
            reference translation is produced (self-improving learning)
        :return: the source embeddings as a numpy array
        """
        self.index = 1  # step index for perturbation
        self.origin_bleu = []  # saving origin BLEU
        batch = next(self.data_iterator)
        assert len(batch) == 3, "must be provided with line index (check for data_iterator)"

        # training, parallel trg is provided for evaluation
        _, seqs_x, self.seqs_y = batch
        self.sent_len = [len(x) for x in seqs_x]
        # the survival signals, 1 when the sentence is still alive
        self.survival_signals = np.array([1] * len(seqs_x))

        # for reinforce inputs (embedding level)
        padded_src, padded_trg = self.prepare_data(
            seqs_x=seqs_x, seqs_y=self.seqs_y)
        self.x_emb = self.src_emb(padded_src).detach()
        self.y_emb = self.trg_emb(padded_trg).detach()
        self.x_pad_indicator = padded_src.detach().eq(PAD)  # True on PAD tokens
        self.y_pad_indicator = padded_trg.detach().eq(PAD)

        if rephraser is not None:
            _, self.x_emb, _ = rephraser.random_seq_perturb(
                self.x_emb, self.x_pad_indicator,
                half_mode=True, rand_act=False)
            self.x_emb = self.x_emb.detach()

        self.origin_result = self.translate()
        # calculate BLEU scores for the top candidate
        for index, sent_t in enumerate(self.seqs_y):
            bleu_t = bleu.sentence_bleu(references=[sent_t],
                                        hypothesis=self.origin_result[index],
                                        emulate_multibleu=True)
            self.origin_bleu.append(bleu_t)
        INFO("initialize env on: %s"%self.x_emb.device)
        return self.x_emb.cpu().numpy()

    def get_src_vocab(self):
        """Return the source vocabulary given at construction."""
        return self.src_vocab

    def reset(self, rephraser=None):
        """
        Re-initialise the environment with a fresh batch.

        :param rephraser: default None for no self-improving learning
        :return: the environment's source embeddings (numpy array)
        """
        return self._init_state(rephraser)

    def reset_data_iter(self, data_iter):
        """Replace the data iterator with the provided one."""
        self.data_iterator = data_iter
        return

    def reset_annunciator(self):
        """Reset the annunciator (a backup, would be deprecated)."""
        self.annunciator.reset()

    # ------------------------------------------------------------------ #
    # data preparation
    # ------------------------------------------------------------------ #
    def prepare_A_data(self, agent, seqs_x, seqs_y,
                       batch_first=True, half_mode=True, rand_act=True):
        """
        Use the current agent to generate data for Annunciator training.

        The actor is rolled out over every inner position of the padded
        source; an action is applied only where the critic's output is
        positive and the position is not PAD.

        :param agent: object exposing ``actor`` and ``critic``
        :param seqs_x: list of source id sequences (without BOS/EOS)
        :param seqs_y: list of target id sequences (without BOS/EOS)
        :param batch_first: first dimension of the padded ids is the batch
        :param half_mode: when True, each sentence keeps its perturbation
            with probability 0.5 and is otherwise restored to the original
        :param rand_act: sample the actions instead of using the actor mean
        :return: origin_x_emb, perturbed_x_emb, y_emb, x_pad_indicator,
            y_pad_indicator, flags (1 where the sentence is perturbed)
        """
        seqs_x = [[BOS] + s + [EOS] for s in seqs_x]
        x = self._pad_batch_2d(samples=seqs_x, pad=PAD, batch_first=batch_first)
        x_emb = self.src_emb(x).detach().to(self.device)
        x_pad_indicator = x.detach().eq(PAD).to(self.device)

        # actor rollout w/ critic's restriction
        with torch.no_grad():
            agent.actor.eval()
            agent.critic.eval()
            batch_size, max_seq_len = x_pad_indicator.shape
            perturbed_x_emb = x_emb.detach().clone()
            x_mask = 1 - x_pad_indicator.int()
            for t in range(1, max_seq_len-1):
                former_emb = perturbed_x_emb
                input_emb = former_emb[:, t-1:t+2, :]
                if rand_act:
                    actions, _ = agent.actor.sample_normal(
                        x_emb=former_emb, x_pad_indicator=x_pad_indicator,
                        label_emb=input_emb, reparamization=False)
                else:
                    mu, _ = agent.actor.forward(
                        x_emb=former_emb, x_pad_indicator=x_pad_indicator,
                        label_emb=input_emb)
                    actions = mu * agent.actor.action_range
                critique = agent.critic(
                    x_emb=former_emb, x_pad_indicator=x_pad_indicator,
                    label_emb=input_emb, action=actions)
                # keep an action only where the critic is positive and the
                # position holds a real token
                actions_mask = critique.gt(0).int() * x_mask[:, t].unsqueeze(dim=1)
                perturbed_x_emb[:, t, :] += actions * actions_mask

        flags = x_emb.new_ones(batch_size)
        if half_mode:
            flags = torch.bernoulli(0.5 * flags).to(x_emb.device)
        # sentences with flag 0 fall back to their original embeddings
        perturbed_x_emb = perturbed_x_emb * flags.unsqueeze(dim=1).unsqueeze(dim=2) \
            + x_emb * (1-flags).unsqueeze(dim=1).unsqueeze(dim=2)
        origin_x_emb = x_emb

        seqs_y = [[BOS] + s + [EOS] for s in seqs_y]
        y = self._pad_batch_2d(seqs_y, pad=PAD, batch_first=batch_first)
        y_emb = self.trg_emb(y).detach().to(self.device)
        y_pad_indicator = y.detach().eq(PAD).to(self.device)
        # the original discarded this result; keep the detached tensor
        perturbed_x_emb = perturbed_x_emb.detach().to(self.device)
        return origin_x_emb, perturbed_x_emb, y_emb, x_pad_indicator, y_pad_indicator, flags.long()

    def prepare_data(self, seqs_x, seqs_y=None, batch_first=True):
        """
        Prepare the batched, padded ids with BOS and EOS for translation.
        Used in initialization.

        :param seqs_x: list of source id sequences
        :param seqs_y: optional list of target id sequences
        :param batch_first: first dimension of the result is the batch
        :return: padded source ids, or (source ids, target ids) when
            ``seqs_y`` is given
        """
        seqs_x = [[BOS] + s + [EOS] for s in seqs_x]
        x = self._pad_batch_2d(samples=seqs_x, pad=PAD, batch_first=batch_first)
        if seqs_y is None:
            return x
        seqs_y = [[BOS] + s + [EOS] for s in seqs_y]
        y = self._pad_batch_2d(seqs_y, pad=PAD, batch_first=batch_first)
        return x, y

    # ------------------------------------------------------------------ #
    # annunciator validation / training
    # ------------------------------------------------------------------ #
    def ratio_validation(self, agent, overall_contrast=True):
        """
        Validate the environment's scorer on one batch perturbed by ``agent``.

        :param agent: generates the perturbed data for validation
        :param overall_contrast: when True return the mean of
            origin / (origin + perturbed) density scores, otherwise the mean
            perturbed density score
        :return: float
        """
        self.annunciator.eval()
        batch = self._next_batch()
        seq_nums, seqs_x, seqs_y = batch
        origin_x_emb, perturbed_x_emb, y_emb, x_mask, y_mask, flags = self.prepare_A_data(
            agent, seqs_x, seqs_y, half_mode=False, rand_act=False)
        # pure evaluation: no autograd graph is needed (step() also scores
        # under no_grad)
        with torch.no_grad():
            origin_density_score = self.annunciator.get_density_score(
                origin_x_emb, x_mask, seqs_y)
            perturbed_density_score = self.annunciator.get_density_score(
                perturbed_x_emb, x_mask, seqs_y)
            density_score = origin_density_score/(origin_density_score+perturbed_density_score)
        if overall_contrast:
            return density_score.mean().item()
        return perturbed_density_score.mean().item()

    def acc_validation(self, agent):
        """
        Validate the accuracy of the annunciator as a classifier over 5 batches.

        NOTE(review): this calls the annunciator with
        (x_emb, x_pad, y_emb, y_pad), unlike ``update_annunciator`` --
        confirm it still matches the annunciator in use.

        :param agent: generates data for validation
        :return: accuracy of predicting the perturbation flags (float)
        """
        self.annunciator.eval()
        acc = 0
        sample_count = 0
        for _ in range(5):
            batch = self._next_batch()
            seq_nums, seqs_x, seqs_y = batch
            origin_x_emb, perturbed_x_emb, y_emb, x_pad_indicator, y_pad_indicator, flags = \
                self.prepare_A_data(agent, seqs_x, seqs_y, half_mode=True)
            with torch.no_grad():
                preds = self.annunciator(perturbed_x_emb, x_pad_indicator,
                                         y_emb, y_pad_indicator).argmax(dim=-1)
                acc += torch.eq(preds, flags).sum()
                sample_count += preds.shape[0]
        acc = acc.float() / sample_count
        return acc.item()

    def compute_A_forward(self, x_emb, y_emb, x_pad_indicator, y_pad_indicator,
                          gold_flags, evaluate=False):
        """
        Compute the loss of the annunciator according to ``criterion_A``.

        :param gold_flags: 1 if perturbed, otherwise 0
        :param evaluate: False during training; in that case the loss is
            also back-propagated
        :return: loss value (float)
        """
        if not evaluate:
            # set components to training mode (dropout layers)
            self.annunciator.train()
            self.criterion_A.train()
            with torch.enable_grad():
                class_probs = self.annunciator(
                    x_emb, x_pad_indicator,
                    y_emb, y_pad_indicator)
                loss = self.criterion_A(class_probs, gold_flags)
            torch.autograd.backward(loss)
            return loss.item()

        # set components to evaluation mode (dropout layers)
        self.annunciator.eval()
        self.criterion_A.eval()
        with torch.no_grad():
            class_probs = self.annunciator(
                x_emb, x_pad_indicator,
                y_emb, y_pad_indicator)
            loss = self.criterion_A(class_probs, gold_flags)
        return loss.item()

    def update_annunciator(self, agent, base_steps=0,
                           min_update_steps=1, max_update_steps=300,
                           accuracy_bound=0.8, overall_update_weight=0.5,
                           summary_writer=None):
        """
        Update the annunciator using data generated by the given agent.

        :param agent: AC agent generating the training data
        :param base_steps: offset added to the step count for logging/return
        :param min_update_steps: (integer) validate every this many steps
        :param max_update_steps: (integer) maximum update steps
        :param accuracy_bound: (float) stop once the validated contrast
            ratio exceeds this bound (or max_update_steps is passed)
        :param overall_update_weight: forwarded to the annunciator's forward
        :param summary_writer: optional writer for learning information
        :return: (base_steps + steps taken, last validated contrast ratio)
        """
        INFO("update annunciator")
        agent.to(self.device)
        step = 0
        while True:
            batch = self._next_batch()
            # update the annunciator
            step += 1
            if self.scheduler_A is not None:
                # override learning rate in self.optim_A
                self.scheduler_A.step(global_step=step)
            _, seqs_x, seqs_y = batch
            # Gradients accumulate across backward() calls, so they must be
            # cleared before every update, not only once before the loop.
            self.optim_A.zero_grad()
            try:
                x_emb, perturbed_x_emb, y_emb, x_pad_indicator, y_pad_indicator, flags = \
                    self.prepare_A_data(agent, seqs_x, seqs_y, half_mode=False, rand_act=True)
                loss = self.annunciator(x_emb, perturbed_x_emb, x_pad_indicator,
                                        seqs_y, overall_update_weight)
                torch.autograd.backward(loss)
                self.optim_A.step()
                print("annunciator loss:", loss)
            except RuntimeError as e:
                if "out of memory" in str(e):
                    print("WARNING: out of memory, skipping batch")
                    self.optim_A.zero_grad()
                else:
                    raise e

            # validate / check for break (if any)
            if step % min_update_steps == 0:
                perturbed_density = self.ratio_validation(agent, overall_contrast=False)
                overall_density = self.ratio_validation(agent)
                if summary_writer is not None:
                    summary_writer.add_scalar("a_contrast_ratio",
                                              scalar_value=overall_density,
                                              global_step=base_steps+step)
                    summary_writer.add_scalar("a_ratio_src",
                                              scalar_value=perturbed_density,
                                              global_step=base_steps+step)
                print("overall density: %2f" % overall_density)
                if accuracy_bound and overall_density > accuracy_bound:
                    INFO("annunciator reached training bound, updated")
                    return base_steps+step, overall_density

            if step > max_update_steps:
                # only the contrast ratio is reported here, so the extra
                # (unused) perturbed-density rollout is not run
                overall_density = self.ratio_validation(agent)
                print("overall density: %2f" % overall_density)
                INFO("Reach maximum annunciator update. Finished.")
                return base_steps+step, overall_density  # stop updates

    # ------------------------------------------------------------------ #
    # environment interface
    # ------------------------------------------------------------------ #
    def translate(self, x_emb=None, x_mask=None):
        """
        Translate the given embeddings with beam search.

        :param x_emb: source embeddings; if None, the embeddings and PAD
            indicator stored in the environment are used
        :param x_mask: PAD indicator paired with ``x_emb``
        :return: list of token-id lists (top beam only, PAD removed)
        """
        if x_emb is None:  # original translation with original embedding
            x_emb = self.x_emb
            x_mask = self.x_pad_indicator

        with torch.no_grad():
            perturbed_results = beam_search(
                self.translate_model,
                beam_size=5, max_steps=150,
                src_embs=x_emb, src_mask=x_mask,
                alpha=-1.0)
        perturbed_results = perturbed_results.cpu().numpy().tolist()
        # only use the top result of every beam
        return [[wid for wid in sent[0] if wid != PAD]
                for sent in perturbed_results]

    def get_state(self):
        """
        Retrieve the states for learning.

        :return: (states, masks, rephrase_positions, survival_signals) where
            states are the stored source embeddings, masks are 1.0 on
            non-PAD tokens, rephrase_positions is a [batch, 1] long tensor
            holding the current index and survival_signals is a [batch, 1]
            float tensor
        """
        states = self.x_emb
        masks = 1. - self.x_pad_indicator.float()
        rephrase_positions = torch.tensor(
            np.array([self.index] * masks.shape[0])).unsqueeze(dim=-1).long()
        survival_signals = torch.tensor(self.survival_signals).unsqueeze(dim=-1).float()
        return states, masks, rephrase_positions, survival_signals

    def step(self, action):
        """
        Apply ``action`` at the current index and advance the environment.

        :param action: tensor of shape [batch, dim] added to the embeddings
            at the current index (masked on PAD and on terminated sentences)
        :return: (updated source embeddings as numpy array,
            rewards as numpy array, survival signals)
        """
        with torch.no_grad():
            self.annunciator.eval()
            batch_size, _ = action.shape
            batched_rewards = [0.] * batch_size
            if self.device != "cpu" and not action.is_cuda:
                WARN("mismatching action for gpu_env, move actions to %s"%self.device)
                action = action.to(self.device)

            # extract the step mask for actions and rewards
            inputs_mask = 1. - self.x_pad_indicator.float()
            inputs_mask = inputs_mask[:, self.index]  # mask of [batch] at current index
            inputs_mask *= inputs_mask.new_tensor(self.survival_signals)  # mask those terminated

            # update embedding; actions on PAD / terminated sentences are masked
            self.x_emb[:, self.index, :] += (action * inputs_mask.unsqueeze(dim=1))

            # 1. mask survival by step and sentence length
            step_reward_mask = [int(self.index <= i) for i in self.sent_len]
            # 2. density score of the current src state decides survival
            d_probs = self.annunciator.get_density_score(
                self.x_emb, self.x_pad_indicator, self.seqs_y)
            signals = d_probs.detach().lt(0.5).long().cpu().numpy().tolist()  # 1 as terminate

            if 1 not in step_reward_mask:
                # all dead, no need to calculate other rewards
                return self.x_emb.cpu().numpy(), np.array(batched_rewards), self.survival_signals

            # 0 as survive, 1 as terminate
            probs = d_probs.detach().cpu().numpy()
            discriminate_index = d_probs.detach().lt(0.5).long()
            survival_mask = (1 - discriminate_index).cpu().numpy()
            survival_value = probs * survival_mask
            terminate_punishment = probs * discriminate_index.cpu().numpy()
            # looping for survival signals and step rewards
            origin_survival_signals = self.survival_signals.copy()
            for i in range(batch_size):
                self.survival_signals[i] = \
                    self.survival_signals[i] * (1-signals[i]) * step_reward_mask[i]
                if self.survival_signals[i]:
                    batched_rewards[i] += survival_value[i] * self.r_s_weight
                elif origin_survival_signals[i]:
                    # punish once the survival signal flips
                    batched_rewards[i] -= terminate_punishment[i] * self.r_i_weight

            # additional episodic reward for surviving sequences that finish
            # at the current step
            bleu_mask = [int(self.index == i) for i in self.sent_len]
            bleu_mask = [bleu_mask[i]*self.survival_signals[i] for i in range(batch_size)]
            if 1 in bleu_mask:
                perturbed_results = self.translate(self.x_emb, self.x_pad_indicator)
                episodic_rewards = []
                for i, sent in enumerate(self.seqs_y):
                    if bleu_mask[i] != 1:
                        episodic_rewards.append(0.0)
                        continue
                    degraded_value = (self.origin_bleu[i]-bleu.sentence_bleu(
                        references=[sent],
                        hypothesis=perturbed_results[i],
                        emulate_multibleu=True
                    ))
                    if self.adversarial:  # relative degradation
                        if self.origin_bleu[i] == 0:
                            relative_degraded_bleu = 0
                        else:
                            relative_degraded_bleu = degraded_value/self.origin_bleu[i]
                        episodic_rewards.append(relative_degraded_bleu)
                    else:  # absolute improvement
                        print("bleu variation:", self.origin_bleu[i],-degraded_value)
                        episodic_rewards.append(-degraded_value)
                # append additional episodic rewards
                batched_rewards = [batched_rewards[i]+episodic_rewards[i]*self.r_i_weight
                                   for i in range(batch_size)]

            # update sequences' pointer for rephrasing
            self.index += 1
            return self.x_emb.cpu().numpy(), np.array(batched_rewards), self.survival_signals
def train(FLAGS):
    """
    Train a language model with periodic checkpointing and loss validation.

    FLAGS attributes read here:
        saveto: str, directory for checkpoints and best models
        reload: bool, resume from the latest checkpoint
        config_path: str, YAML configuration file
        pretrain_path: str, passed to ``lm_model.init_parameters``
        model_name: str, prefix of the saved files
        log_path: str, directory for the text log and summaries
        use_gpu: bool, run the model and criterion on GPU
        debug: bool, forwarded to ``should_trigger_by_steps``
    """
    # write log of training to file.
    write_log_to_file(
        os.path.join(FLAGS.log_path,
                     "%s.log" % time.strftime("%Y%m%d-%H%M%S")))

    GlobalNames.USE_GPU = FLAGS.use_gpu

    # The original branches were swapped ("cpu" when using the GPU).
    CURRENT_DEVICE = "cuda:0" if GlobalNames.USE_GPU else "cpu"

    config_path = os.path.abspath(FLAGS.config_path)
    with open(config_path.strip()) as f:
        # An explicit Loader is required by recent PyYAML releases.
        configs = yaml.load(f, Loader=yaml.FullLoader)

    INFO(pretty_configs(configs))

    # Add default configs
    configs = default_configs(configs)
    data_configs = configs['data_configs']
    model_configs = configs['model_configs']
    optimizer_configs = configs['optimizer_configs']
    training_configs = configs['training_configs']

    GlobalNames.SEED = training_configs['seed']
    set_seed(GlobalNames.SEED)

    best_model_prefix = os.path.join(
        FLAGS.saveto, FLAGS.model_name + GlobalNames.MY_BEST_MODEL_SUFFIX)

    timer = Timer()

    # ================================================================================== #
    # Load Data
    INFO('Loading data...')
    timer.tic()

    # Generate target dictionary
    vocab_tgt = Vocabulary(**data_configs["vocabularies"][0])

    update_cycle = max(1, training_configs["update_cycle"])
    train_batch_size = training_configs["batch_size"] * update_cycle
    train_buffer_size = training_configs["buffer_size"] * update_cycle

    train_bitext_dataset = ZipDataset(
        TextLineDataset(
            data_path=data_configs['train_data'][0],
            vocabulary=vocab_tgt,
            max_len=data_configs['max_len'][0],
        ),
        shuffle=training_configs['shuffle'])

    valid_bitext_dataset = ZipDataset(
        TextLineDataset(
            data_path=data_configs['valid_data'][0],
            vocabulary=vocab_tgt,
        ))

    training_iterator = DataIterator(
        dataset=train_bitext_dataset,
        batch_size=train_batch_size,
        use_bucket=training_configs['use_bucket'],
        buffer_size=train_buffer_size,
        batching_func=training_configs['batching_key'])

    valid_iterator = DataIterator(
        dataset=valid_bitext_dataset,
        batch_size=training_configs['valid_batch_size'],
        use_bucket=True,
        buffer_size=100000,
        numbering=True)

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    lrate = optimizer_configs['learning_rate']
    is_early_stop = False

    # ================================ Begin ======================================== #
    # Build Model & Optimizer
    # We do the steps below one after another
    #     1. build models & criterion
    #     2. move models & criterion to gpu if needed
    #     3. load pre-trained model if needed
    #     4. build optimizer
    #     5. build learning rate scheduler if needed
    #     6. build moving average / load checkpoints if needed

    # 0. Initial
    model_collections = Collections()
    checkpoint_saver = Saver(
        save_prefix="{0}.ckpt".format(
            os.path.join(FLAGS.saveto, FLAGS.model_name)),
        num_max_keeping=training_configs['num_kept_checkpoints'])
    best_model_saver = Saver(
        save_prefix=best_model_prefix,
        num_max_keeping=training_configs['num_kept_best_model'])

    # 1. Build Model & Criterion
    INFO('Building model...')
    timer.tic()
    lm_model = build_model(n_tgt_vocab=vocab_tgt.max_n_words, **model_configs)
    INFO(lm_model)

    params_total = sum(p.numel() for n, p in lm_model.named_parameters())
    # parameters whose name does not contain 'embedding'
    params_without_embedding = sum(
        p.numel() for n, p in lm_model.named_parameters()
        if n.find('embedding') == -1)
    INFO('Total parameters: {}'.format(params_total))
    INFO('Total parameters (excluding word embeddings): {}'.format(
        params_without_embedding))

    critic = NMTCriterion(label_smoothing=model_configs['label_smoothing'])
    INFO(critic)
    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # 2. Move to GPU
    if GlobalNames.USE_GPU:
        lm_model = lm_model.cuda()
        critic = critic.cuda()

    # 3. Load pretrained model if needed
    lm_model.init_parameters(FLAGS.pretrain_path, device=CURRENT_DEVICE)

    # 4. Build optimizer
    INFO('Building Optimizer...')
    optim = Optimizer(name=optimizer_configs['optimizer'],
                      model=lm_model,
                      lr=lrate,
                      grad_clip=optimizer_configs['grad_clip'],
                      optim_args=optimizer_configs['optimizer_params'])

    # 5. Build scheduler for optimizer if needed
    schedule_method = optimizer_configs['schedule_method']
    scheduler = None
    if schedule_method is not None:
        if schedule_method == "loss":
            scheduler = ReduceOnPlateauScheduler(
                optimizer=optim, **optimizer_configs["scheduler_configs"])
        elif schedule_method == "noam":
            scheduler = NoamScheduler(optimizer=optim,
                                      **optimizer_configs['scheduler_configs'])
        else:
            WARN(
                "Unknown scheduler name {0}. Do not use lr_scheduling.".format(
                    optimizer_configs['schedule_method']))

    # 6. build moving average
    if training_configs['moving_average_method'] is not None:
        ma = MovingAverage(
            moving_average_method=training_configs['moving_average_method'],
            named_params=lm_model.named_parameters(),
            alpha=training_configs['moving_average_alpha'])
    else:
        ma = None

    INFO('Done. Elapsed time {0}'.format(timer.toc()))

    # Reload from latest checkpoint
    if FLAGS.reload:
        checkpoint_saver.load_latest(model=lm_model,
                                     optim=optim,
                                     lr_scheduler=scheduler,
                                     collections=model_collections,
                                     ma=ma)

    # ================================================================================== #
    # Prepare training
    eidx = model_collections.get_collection("eidx", [0])[-1]
    uidx = model_collections.get_collection("uidx", [0])[-1]
    bad_count = model_collections.get_collection("bad_count", [0])[-1]
    oom_count = model_collections.get_collection("oom_count", [0])[-1]

    summary_writer = SummaryWriter(log_dir=FLAGS.log_path)

    cum_samples = 0
    cum_words = 0
    valid_loss = best_valid_loss = float('inf')  # Max Float

    # Timer for computing speed; started right before training so the first
    # reading does not include model/optimizer construction time.
    timer_for_speed = Timer()
    timer_for_speed.tic()

    INFO('Begin training...')

    while True:
        summary_writer.add_scalar("Epoch", (eidx + 1), uidx)

        # Build iterator and progress bar
        training_iter = training_iterator.build_generator()
        training_progress_bar = tqdm(desc=' - (Epc {}, Upd {}) '.format(
            eidx, uidx),
                                     total=len(training_iterator),
                                     unit="sents")

        for batch in training_iter:
            uidx += 1

            # Step-based schedulers only; an unknown schedule name leaves
            # scheduler as None and must not be stepped.
            if scheduler is not None and schedule_method != "loss":
                scheduler.step(global_step=uidx)

            seqs_y = batch
            n_samples_t = len(seqs_y)
            n_words_t = sum(len(s) for s in seqs_y)

            cum_samples += n_samples_t
            cum_words += n_words_t

            train_loss = 0.
            optim.zero_grad()
            try:
                # Prepare data
                for (seqs_y_t, ) in split_shard(
                        seqs_y, split_size=training_configs['update_cycle']):
                    y = prepare_data(seqs_y_t, cuda=GlobalNames.USE_GPU)
                    loss = compute_forward(
                        model=lm_model,
                        critic=critic,
                        seqs_y=y,
                        eval=False,
                        normalization=n_samples_t,
                        norm_by_words=training_configs["norm_by_words"])
                    train_loss += loss / y.size(
                        1) if not training_configs["norm_by_words"] else loss
                optim.step()
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    oom_count += 1
                    optim.zero_grad()
                else:
                    raise e

            if ma is not None and eidx >= training_configs[
                    'moving_average_start_epoch']:
                ma.step()

            training_progress_bar.update(n_samples_t)
            training_progress_bar.set_description(
                ' - (Epc {}, Upd {}) '.format(eidx, uidx))
            training_progress_bar.set_postfix_str(
                'TrainLoss: {:.2f}, ValidLoss(best): {:.2f} ({:.2f})'.format(
                    train_loss, valid_loss, best_valid_loss))
            summary_writer.add_scalar("train_loss",
                                      scalar_value=train_loss,
                                      global_step=uidx)

            # ================================================================================== #
            # Display some information
            if should_trigger_by_steps(
                    uidx, eidx, every_n_step=training_configs['disp_freq']):
                # words per second and sents per second
                elapsed = timer_for_speed.toc(return_seconds=True)
                words_per_sec = cum_words / elapsed
                sents_per_sec = cum_samples / elapsed
                lrate = list(optim.get_lrate())[0]

                summary_writer.add_scalar("Speed(words/sec)",
                                          scalar_value=words_per_sec,
                                          global_step=uidx)
                summary_writer.add_scalar("Speed(sents/sen)",
                                          scalar_value=sents_per_sec,
                                          global_step=uidx)
                summary_writer.add_scalar("lrate",
                                          scalar_value=lrate,
                                          global_step=uidx)
                summary_writer.add_scalar("oom_count",
                                          scalar_value=oom_count,
                                          global_step=uidx)

                # Reset timer
                timer_for_speed.tic()
                cum_words = 0
                cum_samples = 0

            # ================================================================================== #
            # Saving checkpoints
            if should_trigger_by_steps(
                    uidx,
                    eidx,
                    every_n_step=training_configs['save_freq'],
                    debug=FLAGS.debug):
                model_collections.add_to_collection("uidx", uidx)
                model_collections.add_to_collection("eidx", eidx)
                model_collections.add_to_collection("bad_count", bad_count)

                if not is_early_stop:
                    checkpoint_saver.save(global_step=uidx,
                                          model=lm_model,
                                          optim=optim,
                                          lr_scheduler=scheduler,
                                          collections=model_collections,
                                          ma=ma)

            # ================================================================================== #
            # Loss Validation & Learning rate annealing
            if should_trigger_by_steps(
                    global_step=uidx,
                    n_epoch=eidx,
                    every_n_step=training_configs['loss_valid_freq'],
                    debug=FLAGS.debug):

                if ma is not None:
                    # validate with the averaged parameters, restore afterwards
                    origin_state_dict = deepcopy(lm_model.state_dict())
                    lm_model.load_state_dict(ma.export_ma_params(),
                                             strict=False)

                valid_loss = loss_validation(
                    model=lm_model,
                    critic=critic,
                    valid_iterator=valid_iterator,
                    norm_by_words=training_configs["norm_by_words"])

                model_collections.add_to_collection("history_losses",
                                                    valid_loss)

                min_history_loss = np.array(
                    model_collections.get_collection("history_losses")).min()

                summary_writer.add_scalar("loss",
                                          valid_loss,
                                          global_step=uidx)
                summary_writer.add_scalar("best_loss",
                                          min_history_loss,
                                          global_step=uidx)

                if ma is not None:
                    lm_model.load_state_dict(origin_state_dict)
                    del origin_state_dict

                if optimizer_configs["schedule_method"] == "loss":
                    scheduler.step(metric=best_valid_loss)

                # If model get new best valid loss
                if valid_loss < best_valid_loss:
                    bad_count = 0

                    if is_early_stop is False:
                        # 1. save the best model
                        torch.save(lm_model.state_dict(),
                                   best_model_prefix + ".final")
                        # 2. record all several best models
                        best_model_saver.save(global_step=uidx,
                                              model=lm_model)
                else:
                    bad_count += 1

                    # At least one epoch should be traversed
                    if bad_count >= training_configs[
                            'early_stop_patience'] and eidx > 0:
                        is_early_stop = True
                        WARN("Early Stop!")

                best_valid_loss = min_history_loss

                summary_writer.add_scalar("bad_count", bad_count, uidx)

                INFO("{0} Loss: {1:.2f} lrate: {2:6f} patience: {3}".format(
                    uidx, valid_loss, lrate, bad_count))

        training_progress_bar.close()

        eidx += 1
        if eidx > training_configs["max_epochs"]:
            break
class SACAgent(object):
    """Soft actor-critic agent made of a ``Rephraser`` actor and a ``CriticNet`` critic.

    The agent owns its actor/critic networks, a learnable log-temperature
    (``log_alpha``), the optimizers/schedulers built from
    ``rephraser_optimizer_configs`` and a ``Saver`` for checkpoints.

    Required keyword arguments (read from ``kwargs``): ``save_to``,
    ``num_kept_checkpoints``, ``learnable_temperature``, ``init_temperature``
    and ``rephraser_optimizer_configs`` (may be ``None`` to skip optimizers).
    """

    def __init__(self,
                 device="cpu",
                 d_word_vec=512,
                 d_model=256,
                 limit_dist=0.1,
                 dropout=0.0,
                 reparam_noise=1e-6,
                 **kwargs):
        # ``self.device`` is the agent's home device; ``to()`` deliberately does
        # not change it so ``soft_update_target_net`` can move the agent back.
        self.device = device
        self.actor = Rephraser(d_word_vec=d_word_vec,
                               d_model=d_model,
                               limit_dist=limit_dist,
                               dropout=dropout,
                               reparam_noise=reparam_noise).to(device)
        self.critic = CriticNet(d_word_vec=d_word_vec,
                                d_model=d_model,
                                limit_dist=limit_dist,
                                dropout=dropout,
                                reparam_noise=reparam_noise).to(device)
        self.saver = Saver(
            save_prefix="{0}.ckpt".format(
                os.path.join(kwargs["save_to"], "ACmodel")),
            num_max_keeping=kwargs["num_kept_checkpoints"])
        self.soft_update_lock = mp.Lock()

        # the entropy regularization weight for SAC learning
        self.learnable_temperature = kwargs["learnable_temperature"]
        # act_dim (d_word_vec) as the expected entropy base
        self.target_entropy = -d_word_vec
        self.log_alpha = torch.tensor(np.log(kwargs["init_temperature"])).to(
            self.device)
        self.log_alpha.requires_grad = True

        # initialize the training mode for the Agent
        self.train()
        self._init_local_optims(kwargs["rephraser_optimizer_configs"])
        # NOTE: checkpoints are not reloaded automatically; call load_model().

    def to(self, device):
        """Move actor, critic and the temperature to ``device``; return ``self``."""
        self.actor.to(device)
        self.critic.to(device)
        # Tensor.to() is out-of-place, so its result must be kept. Rebinding
        # ``.data`` (instead of the attribute) preserves the tensor object that
        # ``log_alpha_optimizer`` was built on.
        self.log_alpha.data = self.log_alpha.data.to(device)
        return self

    def share_memory(self):
        """Put actor and critic parameters in shared memory (global model use)."""
        self.actor.share_memory()
        self.critic.share_memory()

    @property
    def alpha(self):
        """Entropy temperature, i.e. ``exp(log_alpha)``."""
        return self.log_alpha.exp()

    def load_model(self, load_final_path: str = None):
        """
        load from path by self.saver
        :param load_final_path: final model path dir, final model doesn't have optim_params
        :return: training step count int (0 when loading a final model)
        """
        step = 0
        model_collections = Collections()
        if load_final_path:
            # map_location keeps loading possible when the file was written on
            # a device that is not available here; load_state_dict copies the
            # values into the existing parameters either way.
            state_dict = torch.load(os.path.join(load_final_path),
                                    map_location=self.device)
            self.actor.load_state_dict(state_dict["actor_model"])
            self.critic.load_state_dict(state_dict["critic_model"])
        else:
            # load from the latest ckpt, including optimizers and schedulers
            self.saver.load_latest(collections=model_collections,
                                   actor_model=self.actor,
                                   critic_model=self.critic,
                                   actor_optim=self.actor_optimizer,
                                   critic_optim=self.critic_optimizer,
                                   actor_scheduler=self.actor_scheduler,
                                   critic_scheduler=self.critic_scheduler)
            step = model_collections.get_collection("step", [0])[-1]
        return step

    def save_model(self, step=None, save_to_final=None):
        """Save a checkpoint.

        :param step: when given, save a full training checkpoint (models,
            optims, schedulers, step) through ``self.saver``
        :param save_to_final: directory for the parameters-only final model
            (``ACmodel.final``); required when ``step`` is None
        """
        model_collections = Collections()
        if step is not None:
            model_collections.add_to_collection("step", step)
            self.saver.save(global_step=step,
                            collections=model_collections,
                            actor_model=self.actor,
                            critic_model=self.critic,
                            actor_optim=self.actor_optimizer,
                            critic_optim=self.critic_optimizer,
                            actor_scheduler=self.actor_scheduler,
                            critic_scheduler=self.critic_scheduler)
        else:
            # only save the model parameters
            assert save_to_final is not None, "final model saving dir must be provided"
            collection = {
                "actor_model": self.actor.state_dict(),
                "critic_model": self.critic.state_dict(),
            }
            torch.save(collection,
                       os.path.join(save_to_final, "ACmodel.final"))

    def _init_local_optims(self, rephraser_optimizer_configs):
        """
        actor, critic, alpha optimizers and lr scheduler if necessary
        rephraser_optimizer_configs:
            optimizer: "adafactor"
            learning_rate: 0.01
            grad_clip: -1.0
            optimizer_params: ~
            schedule_method: rsqrt
            scheduler_configs:
                d_model: *dim
                warmup_steps: 100
        Passing ``None`` leaves every optimizer and scheduler unset (None).
        """
        self.actor_scheduler = None
        self.critic_scheduler = None
        if rephraser_optimizer_configs is None:
            self.actor_optimizer = None
            self.critic_optimizer = None
            self.log_alpha_optimizer = None
            return

        self.actor_optimizer = Optimizer(
            name=rephraser_optimizer_configs["optimizer"],
            model=self.actor,
            lr=rephraser_optimizer_configs["learning_rate"],
            grad_clip=rephraser_optimizer_configs["grad_clip"],
            optim_args=rephraser_optimizer_configs["optimizer_params"])
        self.critic_optimizer = Optimizer(
            name=rephraser_optimizer_configs["optimizer"],
            model=self.critic,
            lr=rephraser_optimizer_configs["learning_rate"],
            grad_clip=rephraser_optimizer_configs["grad_clip"],
            optim_args=rephraser_optimizer_configs["optimizer_params"])
        # hardcoded entropy weight updates
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                    lr=1e-4,
                                                    betas=(0.9, 0.999))

        # Build scheduler for optimizer if needed
        schedule_method = rephraser_optimizer_configs['schedule_method']
        if schedule_method is None:
            return
        # Scheduler classes are looked up only in the matching branch.
        if schedule_method == "loss":
            scheduler_cls = ReduceOnPlateauScheduler
        elif schedule_method == "noam":
            scheduler_cls = NoamScheduler
        elif schedule_method == "rsqrt":
            scheduler_cls = RsqrtScheduler
        else:
            WARN("Unknown scheduler name {0}. Do not use lr_scheduling.".
                 format(schedule_method))
            return
        self.actor_scheduler = scheduler_cls(
            optimizer=self.actor_optimizer,
            **rephraser_optimizer_configs["scheduler_configs"])
        self.critic_scheduler = scheduler_cls(
            optimizer=self.critic_optimizer,
            **rephraser_optimizer_configs["scheduler_configs"])

    def sync_from(self, sac_agent):
        """Copy actor and critic from ``sac_agent`` while holding both agents' locks."""
        with sac_agent.soft_update_lock, self.soft_update_lock:
            self.actor.sync_from(sac_agent.actor)
            self.critic.sync_from(sac_agent.critic)

    def train(self, training=True):
        """Set training mode on the agent and both networks; return ``self``."""
        self.training = training
        self.actor.train(training)
        self.critic.train(training)
        return self

    def update_critic(self,
                      states,
                      masks,
                      rephrase_positions,
                      actions,
                      rewards,
                      survive_and_no_maxs,
                      target_critic,
                      update_step,
                      discount_factor,
                      summary_writer,
                      update_trust_region=0.8):
        """
        update critic by using a target_critic net (usually a global critic model)
        and a buffer SARSA for TD learning
        :param states:
        :param masks:
        :param rephrase_positions:
        :param actions: actions
        :param rewards: the rewards
        :param survive_and_no_maxs: able to rollout next step for TD learning
        :param target_critic: provides target value estimation (global model is usually on cpu)
        :param update_step: learning steps
        :param discount_factor: for discounted rewards update
        :param summary_writer: receives the ``critic_loss`` scalar
        :param update_trust_region: discount for loss updates
        """
        label_emb = slice_by_indices(states,
                                     rephrase_positions,
                                     device=self.device)
        # NOTE(review): the "next" action is sampled from the current states
        # and label embedding, not from next_states; confirm this is intended.
        next_action, log_probs = self.actor.sample_normal(
            states, 1. - masks, label_emb, reparamization=True)
        # log_probs has the same dimension as the action, so the log-prob of
        # a whole action is the sum along the last dimension.
        log_probs = log_probs.sum(dim=-1, keepdims=True)

        next_states = transition(states, masks, actions, rephrase_positions)
        next_rephrase_positions = rephrase_positions + 1
        next_label_emb = slice_by_indices(next_states,
                                          next_rephrase_positions,
                                          device=self.device)
        # NOTE: the intrinsic-curiosity reward bonus is disabled here.

        target_critic.eval()
        target_V = target_critic(next_states, 1. - masks, next_label_emb,
                                 next_action) - log_probs * self.alpha.detach()
        # we have next states for TD learning rollout
        target_Q = rewards + (
            survive_and_no_maxs) * discount_factor * target_V
        target_Q = target_Q.detach()

        # get current Q estimates
        current_Q = self.critic(states, 1. - masks, label_emb, actions)
        critic_loss = F.mse_loss(current_Q, target_Q)
        critic_loss *= update_trust_region
        print("critic_loss", critic_loss.sum())

        # Optimize the critic (the scheduler is stepped before the backward pass)
        self.critic_optimizer.zero_grad()
        if self.critic_scheduler is not None:
            self.critic_scheduler.step(global_step=update_step)
        critic_loss.backward()
        self.critic_optimizer.step()

        # logging
        summary_writer.add_scalar("critic_loss",
                                  scalar_value=critic_loss,
                                  global_step=update_step)

    def update_actor_and_alpha(self,
                               states,
                               masks,
                               rephrase_positions,
                               target_critic,
                               update_step,
                               summary_writer,
                               update_trust_region=0.5):
        """
        :param states: tensor states from the buffer samples
        :param masks: indicates the valid token positions
        :param rephrase_positions: induce the next states by the given states
        :param target_critic: critic used to decide where a rewind is needed
        :param update_step: for lr scheduler and logging
        :param summary_writer: logging
        :param update_trust_region: current annunciator trust_acc (valid).
            served as a trust region for RL updates; also the weight of rewind or reinforce
            trust_acc * rewind_loss + (1-trust_acc) * policy_loss
        """
        self.actor.train()
        label_emb = slice_by_indices(states,
                                     rephrase_positions,
                                     device=self.device)
        actions, log_probs = self.actor.sample_normal(states,
                                                      1. - masks,
                                                      label_emb,
                                                      reparamization=True)
        log_probs = log_probs.sum(dim=-1, keepdims=True)
        actor_Q = self.critic(states, 1. - masks, label_emb, actions)
        policy_loss = (self.alpha.detach() * log_probs - actor_Q).mean()
        summary_writer.add_scalar('policy_loss', policy_loss, update_step)
        summary_writer.add_scalar('entropy_ratio',
                                  -log_probs.mean() / self.target_entropy,
                                  update_step)

        # the policy rewind loss, the rewind is determined by target value estimates
        # (estimated survival + improvements). negative means rewind needed.
        target_Q = target_critic(states, 1. - masks, label_emb,
                                 actions).detach()
        rewind_mask = target_Q.lt(0.).detach().float()  # [batch, 1]
        next_states = transition(states, masks, actions,
                                 rephrase_positions).detach()
        next_label_emb = slice_by_indices(next_states,
                                          rephrase_positions,
                                          device=self.device).detach()
        rewind_action, _ = self.actor.forward(next_states, 1. - masks,
                                              next_label_emb)
        target_actions = -actions.detach()
        rewind_loss = F.mse_loss(
            rewind_action * rewind_mask * self.actor.action_range,
            target_actions * rewind_mask)
        summary_writer.add_scalar('rewind_loss', rewind_loss, update_step)

        # the higher trust region means less indicative the perturbations are,
        # policy should focus more on the rewind.
        actor_loss = (update_trust_region) * rewind_loss + (
            1. - update_trust_region) * policy_loss
        # NOTE: the intrinsic-curiosity (ICM) loss is disabled here.

        # optimize the actor (the scheduler is stepped before the backward pass)
        self.actor_optimizer.zero_grad()
        if self.actor_scheduler is not None:
            self.actor_scheduler.step(global_step=update_step)
        actor_loss.backward()
        self.actor_optimizer.step()

        if self.learnable_temperature:
            self.log_alpha_optimizer.zero_grad()
            # only alpha receives gradient; the entropy gap is a constant here
            alpha_loss = (self.alpha *
                          (-log_probs - self.target_entropy).detach()).mean()
            summary_writer.add_scalar('alpha_loss', alpha_loss, update_step)
            summary_writer.add_scalar('alpha', self.alpha, update_step)
            alpha_loss.backward()
            self.log_alpha_optimizer.step()

    def update_local_net(self,
                         local_agent_configs,
                         replay_buffer,
                         target_critic,
                         update_step,
                         discount_factor,
                         summary_writer,
                         update_trust_region=1.0):
        """
        :param local_agent_configs: provides agent update freq
        :param replay_buffer: provides the SARSA listed below
            [states, masks, actions, rephrase_positions, rewards, terminal_signals]
            states: the embedding as states. [batch, len, emb_dim] float
            masks: the indicator of valid token for embedding. [batch, len] float
            actions: the action embedding on the position. [batch, emb_dim] float
            rephrase_positions: the position to rephrase. [batch, 1] long
            rewards: the rewards for the transition. [batch, 1] float
            terminal_signals: the terminal signals for the transition. [batch, 1] float
        :param target_critic: provides the global critic
        :param update_step: for lr scheduler and logging
        :param discount_factor: rollout-rewards discount
        :param summary_writer: logging
        :param update_trust_region: discount for loss updates
        """
        learn_batch_size = local_agent_configs["rephraser_learning_batch"]
        # the sixth sampled field is not used by the updates below
        states, masks, \
        actions, rephrase_positions, \
        rewards, _, survive_and_no_maxs = replay_buffer.sample(learn_batch_size, device=self.device)

        INFO("update local agent critics on device: %s" % self.device)
        self.update_critic(states, masks, rephrase_positions, actions,
                           rewards, survive_and_no_maxs, target_critic,
                           update_step, discount_factor, summary_writer,
                           update_trust_region)
        # the policy is updated only every ``actor_update_freq`` steps
        if update_step % local_agent_configs["actor_update_freq"] == 0:
            INFO("update local agent policy on device: %s" % self.device)
            self.update_actor_and_alpha(states, masks, rephrase_positions,
                                        target_critic, update_step,
                                        summary_writer, update_trust_region)

    def soft_update_target_net(self, target_SACAgent, tau):
        """Polyak-average this agent's weights into ``target_SACAgent``.

        ``target <- tau * local + (1 - tau) * target`` for critic and actor.
        The local agent is first moved to the target's device and moved back
        to its home device afterwards.
        """
        # mind not to update global model while reading and synch local models.
        self.to(target_SACAgent.device)
        with target_SACAgent.soft_update_lock:
            for param, target_param in zip(
                    self.critic.parameters(),
                    target_SACAgent.critic.parameters()):
                target_param.data.copy_(tau * param.data +
                                        (1 - tau) * target_param.data)
            for param, target_param in zip(
                    self.actor.parameters(),
                    target_SACAgent.actor.parameters()):
                target_param.data.copy_(tau * param.data +
                                        (1 - tau) * target_param.data)
        self.to(self.device)
class Solver(BaseSolver):
    ''' Solver for training language models'''

    def __init__(self, config, paras, mode):
        super().__init__(config, paras, mode)
        # Logger settings
        # Best dev loss seen so far; a checkpoint is saved whenever it improves
        self.best_loss = 10

    def fetch_data(self, data):
        ''' Move data to device, insert <sos> and compute text seq. length'''
        # Index 0 is prepended as <sos>; it is also the index counted as
        # padding below and ignored by the loss.
        sos = torch.zeros((data.shape[0], 1), dtype=torch.long)
        txt = torch.cat((sos, data), dim=1).to(self.device)
        txt_len = torch.sum(data != 0, dim=-1)
        return txt, txt_len

    def load_data(self):
        ''' Load data for training/validation, store tokenizer and input/output shape'''
        self.tr_set, self.dv_set, self.vocab_size, self.tokenizer, msg = \
            load_textset(self.paras.njobs, self.paras.gpu,
                         self.paras.pin_memory, **self.config['data'])
        self.verbose(msg)

    def set_model(self):
        ''' Setup the language model (Prediction + RNNLM) and its optimizer '''
        # Model
        self.model = Prediction(self.vocab_size,
                                **self.config['model']).to(self.device)
        self.rnnlm = RNNLM(self.vocab_size,
                           **self.config['model']).to(self.device)
        self.verbose(self.rnnlm.create_msg())

        # Losses
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)

        # Optimizer (covers the parameters of both modules)
        self.optimizer = Optimizer(
            list(self.model.parameters()) + list(self.rnnlm.parameters()),
            **self.config['hparas'])

        # Enable AMP if needed
        self.enable_apex()

        # load pre-trained model
        if self.paras.load:
            self.load_ckpt()
            # NOTE(review): the checkpoint is read a second time here and only
            # ``self.model`` is restored from it; ``self.rnnlm`` is not.
            # Confirm against BaseSolver.load_ckpt and the checkpoint layout.
            ckpt = torch.load(self.paras.load, map_location=self.device)
            self.model.load_state_dict(ckpt['model'])
            self.optimizer.load_opt_state_dict(ckpt['optimizer'])
            self.step = ckpt['global_step']
            self.verbose('Load ckpt from {}, restarting at step {}'.format(
                self.paras.load, self.step))

    def exec(self):
        ''' Train the language model until ``self.max_step`` steps are done '''
        self.verbose('Total training steps {}.'.format(
            human_format(self.max_step)))
        self.timer.set()

        while self.step < self.max_step:
            for data in self.tr_set:
                # Pre-step : update tf_rate/lr_rate and do zero_grad
                self.optimizer.pre_step(self.step)

                # Fetch data
                txt, txt_len = self.fetch_data(data)
                self.timer.cnt('rd')

                # Forward model
                outputs, hidden = self.model(txt[:, :-1], txt_len)
                pred = self.rnnlm(outputs)

                # Compute all objectives
                lm_loss = self.seq_loss(pred.view(-1, self.vocab_size),
                                        txt[:, 1:].reshape(-1))
                self.timer.cnt('fw')

                # Backprop
                grad_norm = self.backward(lm_loss)
                self.step += 1

                # Logger
                if self.step % self.PROGRESS_STEP == 0:
                    self.progress(
                        'Tr stat | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'.
                        format(lm_loss.cpu().item(), grad_norm,
                               self.timer.show()))
                    self.write_log('entropy', {'tr': lm_loss})
                    self.write_log('perplexity',
                                   {'tr': torch.exp(lm_loss).cpu().item()})

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    self.validate()

                # End of step
                self.timer.set()
                # Stop as soon as the budget is reached; with ``>`` one extra
                # step was trained whenever the epoch did not end exactly here.
                if self.step >= self.max_step:
                    break
        self.log.close()

    def validate(self):
        ''' Evaluate on the dev set, log entropy/perplexity, keep the best ckpt '''
        # Eval mode
        self.model.eval()
        self.rnnlm.eval()
        dev_loss = []

        for i, data in enumerate(self.dv_set):
            self.progress('Valid step - {}/{}'.format(i + 1,
                                                      len(self.dv_set)))
            # Fetch data
            txt, txt_len = self.fetch_data(data)

            # Forward model; nothing here needs gradients
            with torch.no_grad():
                outputs, hidden = self.model(txt[:, :-1], txt_len)
                pred = self.rnnlm(outputs)
                lm_loss = self.seq_loss(pred.view(-1, self.vocab_size),
                                        txt[:, 1:].reshape(-1))
            dev_loss.append(lm_loss)

        # Ckpt if performance improves
        dev_loss = sum(dev_loss) / len(dev_loss)
        dev_ppx = torch.exp(dev_loss).cpu().item()
        if dev_loss < self.best_loss:
            self.best_loss = dev_loss
            self.save_checkpoint('best_ppx.pth', 'perplexity', dev_ppx)
        self.write_log('entropy', {'dv': dev_loss})
        self.write_log('perplexity', {'dv': dev_ppx})

        # Show some example of last batch on tensorboard
        for i in range(min(len(txt), self.DEV_N_EXAMPLE)):
            # The reference text is written only at the first validation
            if self.step == 1:
                self.write_log('true_text{}'.format(i),
                               self.tokenizer.decode(txt[i].tolist()))
            self.write_log(
                'pred_text{}'.format(i),
                self.tokenizer.decode(pred[i].argmax(dim=-1).tolist()))

        # Resume training
        self.model.train()
        self.rnnlm.train()
class Solver(BaseSolver):
    ''' Solver for training a downstream classifier on frozen SSL features'''

    def __init__(self, config, paras):
        super().__init__(config, paras)
        # Logger settings
        self.best_dev_er = 1.0
        self.cur_epoch = 0
        # Configs following self-supervised learning
        self.task = self.paras.task
        assert self.task in ['phn-clf', 'spk-clf'], 'unsupported task'
        # Close the config file deterministically instead of leaking the handle
        with open(self.config['model']['feat']['config'], 'r') as f:
            self.ssl_config = yaml.load(f, Loader=yaml.FullLoader)
        self.feature = self.ssl_config['model']['method']
        if self.feature == 'npc' and 'spec' in self.config['model']['feat']:
            # NPC has additional option to use unmasked feature
            self.feat_spec = self.config['model']['feat']['spec']
        else:
            self.feat_spec = None
        # Audio settings are taken from the SSL model's config
        self.config['data']['audio'] = self.ssl_config['data']['audio']

    def fetch_data(self, data, train=True):
        ''' Move data to device and extract features with the frozen extractor

        Returns (file_id, audio_feat, audio_len, label). ``train`` is accepted
        for interface compatibility and is not used.
        '''
        file_id, audio_feat, audio_len, label = data
        if self.gpu:
            audio_feat = audio_feat.cuda()
            label = label.cuda()
        # Extract feature
        with torch.no_grad():
            if self.feat_spec is not None:
                # Get unmasked feature from particular NPC layer
                # (layer index is the trailing '-<n>' of the spec string)
                n_layer_feat = int(self.feat_spec.split('-')[-1])
                audio_feat = self.feat_extractor.get_unmasked_feat(
                    audio_feat, n_layer_feat)
            elif self.feature == 'npc':
                # Get masked feature from NPC
                _, audio_feat = self.feat_extractor(audio_feat, testing=True)
            else:
                # Get feature from APC based model
                _, audio_feat = self.feat_extractor(audio_feat,
                                                    audio_len,
                                                    testing=True)
        # Mean pool feature for spkr classification
        # (each utterance is averaged over its first ``a_len`` frames)
        if self.task == 'spk-clf':
            single_feat = [
                a_feat[:a_len].mean(dim=0)
                for a_feat, a_len in zip(audio_feat, audio_len)
            ]
            audio_feat = torch.stack(single_feat, dim=0)
        return file_id, audio_feat, audio_len, label

    def load_data(self):
        ''' Load data for training/validation '''
        self.tr_set, self.dv_set, self.tt_set, self.audio_dim, msg = \
            prepare_data(self.paras.njobs, self.paras.dev_njobs,
                         self.paras.gpu, self.paras.pin_memory,
                         **self.config['data'])
        self.verbose(msg)

    def set_model(self):
        ''' Setup model and optimizer '''
        # Load SSL models for feature extraction
        self.verbose([' Load feat. extractor ckpt from '\
                      +self.config['model']['feat']['ckpt']])
        if self.feature in ['apc', 'vqapc']:
            from model.apc import APC as Net
        elif self.feature == 'npc':
            from model.npc import NPC as Net
            if self.feat_spec is not None:
                self.verbose([' Using specific feature: ' + self.feat_spec])
        else:
            raise NotImplementedError
        self.feat_extractor = Net(input_size=self.audio_dim,
                                  **self.ssl_config['model']['paras'])
        ckpt = torch.load(
            self.config['model']['feat']['ckpt'],
            map_location=self.device if self.mode == 'train' else 'cpu')
        # Drop the first 'module.' in each key (presumably a DataParallel
        # prefix; confirm against how the ckpt was saved)
        ckpt['model'] = {k.replace('module.', '', 1): v
                         for k, v in ckpt['model'].items()}
        self.feat_extractor.load_state_dict(ckpt['model'])

        # Classifier model
        self.model = CLF(feat_dim=self.feat_extractor.code_dim,
                         **self.config['model']['clf'])
        # The extractor is only used under no_grad for feature extraction, so
        # it is put in eval mode regardless of device (previously GPU only).
        self.feat_extractor.eval()
        if self.gpu:
            self.feat_extractor = self.feat_extractor.cuda()
            self.model = self.model.cuda()
        # Only the classifier's parameters are optimized
        model_paras = [{'params': self.model.parameters()}]

        # Losses
        ignore_idx = 0 if self.task == 'phn-clf' else -1
        self.loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_idx)
        if self.gpu:
            self.loss = self.loss.cuda()

        # Optimizer
        self.optimizer = Optimizer(model_paras, **self.config['hparas'])
        self.verbose(self.optimizer.create_msg())
        self.load_ckpt()
        self.model.train()

    def exec(self):
        ''' Train the classifier (in train mode), then evaluate on the test set '''
        if self.paras.mode == 'train':
            self.verbose('Total training epoch {}.'.format(
                human_format(self.epoch)))
            self.timer.set()
            ep_len = len(self.tr_set)
            for ep in range(self.epoch):
                if ep > 0:
                    # Lr decay if needed
                    self.optimizer.decay()
                for data in self.tr_set:
                    # Pre-step : do zero_grad
                    self.optimizer.pre_step(self.step)

                    # Fetch data
                    self.timer.cnt('rd')
                    _, audio_feat, audio_len, label = self.fetch_data(data)

                    # Forward
                    pred = self.model(audio_feat)
                    if self.task == 'phn-clf':
                        pred = pred.permute(0, 2, 1)  # BxCxT for phn clf
                    loss = self.loss(pred, label)
                    self.timer.cnt('fw')

                    # Backprop
                    grad_norm = self.backward(loss)
                    self.step += 1

                    # Logger
                    if (self.step == 1) or (self.step % self.PROGRESS_STEP
                                            == 0):
                        self.progress(
                            ' {:2.1f} % | Loss - {:.2f} | Grad. Norm - {:.2f} | {}'
                            .format(100 * float(self.step % ep_len) / ep_len,
                                    loss.cpu().item(), grad_norm,
                                    self.timer.show()))
                        self.write_log(self.task + '_loss', {'tr': loss})
                        if self.task == 'phn-clf':
                            tr_er = cal_per(pred, label, audio_len)[0]
                        else:
                            tr_er = (pred.argmax(dim=-1) != label)
                            tr_er = tr_er.sum().detach().cpu().float() / len(
                                label)
                        self.write_log(self.task + '_er', {'tr': tr_er})
                    # End of step
                    self.timer.set()
                # End of epoch
                self.cur_epoch += 1
                self.validate()
        # Test at the end
        self.validate(test=True)
        self.log.close()

    def validate(self, test=False):
        ''' Evaluate on the dev set, or on the test set when ``test`` is True

        On dev, the best model by error rate is checkpointed and kept in RAM
        for the final test.
        '''
        # Eval mode
        self.model.eval()
        val_loss = []
        split = 'dev'
        val_hit, val_total = 0.0, 0.0
        ds = self.tt_set if test else self.dv_set

        # In training mode, best model is stored in RAM for test
        # ToDo: load ckpt
        if test:
            split = 'test'
            if self.paras.mode == 'train':
                # Fall back to the current model when no dev-best snapshot
                # exists (e.g. dev never improved) instead of raising.
                best_model = getattr(self, 'best_model', None)
                if best_model is not None:
                    self.model = best_model
                if self.gpu:
                    self.model = self.model.cuda()

        for i, data in enumerate(ds):
            self.progress('Valid step - {}/{}'.format(i + 1, len(ds)))
            # Fetch data
            _, audio_feat, audio_len, label = self.fetch_data(data)

            # Forward model
            with torch.no_grad():
                # Prediction
                pred = self.model(audio_feat)
                if self.task == 'phn-clf':
                    pred = pred.permute(0, 2, 1)  # BxCxT
                # Accumulate batch result
                val_loss.append(self.loss(pred, label))
                if self.task == 'phn-clf':
                    _, hit, total = cal_per(pred, label, audio_len)
                    val_hit += hit
                    val_total += total
                else:
                    hit = (pred.argmax(dim=-1) == label).sum()
                    val_hit += hit.detach().cpu().float()
                    val_total += len(label)

            # Write testing prediction if needed
            # (appended per batch as "<pred>\t<label>" lines, phn-clf only)
            if test and self.paras.write_test:
                if self.task == 'phn-clf':
                    pred = pred.argmax(dim=1).detach().cpu()
                    label = label.cpu()
                    with open(os.path.join(self.ckpdir, self.task + '.csv'),
                              'a') as f:
                        for p, l, a_len in zip(pred, label, audio_len):
                            for x, y in zip(p[:a_len].tolist(),
                                            l[:a_len].tolist()):
                                f.write('{}\t{}\n'.format(x, y))

        # Record metric, store ckpt by dev error rate
        val_loss = sum(val_loss) / len(val_loss)
        val_er = 1.0 - val_hit / val_total
        self.write_log(self.task + '_loss', {split: val_loss})
        self.write_log(self.task + '_er', {split: val_er})
        if split == 'dev' and self.best_dev_er > val_er:
            self.best_dev_er = val_er
            self.save_checkpoint('best.pth', self.task + '_er', val_er)
            # Clone for test; note that .cpu() also moves self.model to CPU
            self.best_model = copy.deepcopy(self.model.cpu())

        # Resume training
        if self.gpu:
            self.model = self.model.cuda()
        self.model.train()