def _train_epoches(self, train_set, model, n_epochs, start_epoch, start_step, dev_set=None):

    log = self.logger

    las_print_loss_total = 0  # Reset every print_every

    step = start_step
    step_elapsed = 0
    prev_acc = 0.0
    count_no_improve = 0
    count_num_rollback = 0
    ckpt = None

    # ******************** [loop over epochs] ********************
    for epoch in range(start_epoch, n_epochs + 1):

        for param_group in self.optimizer.optimizer.param_groups:
            log.info('epoch:{} lr: {}'.format(epoch, param_group['lr']))
            lr_curr = param_group['lr']

        # ---------- construct batches ----------
        log.info('--- construct train set ---')
        train_set.construct_batches(is_train=True)
        if dev_set is not None:
            log.info('--- construct dev set ---')
            dev_set.construct_batches(is_train=True)

        # ---------- print info for each epoch ----------
        steps_per_epoch = len(train_set.iter_loader)
        total_steps = steps_per_epoch * n_epochs
        log.info("steps_per_epoch {}".format(steps_per_epoch))
        log.info("total_steps {}".format(total_steps))

        log.debug(" --------- Epoch: %d, Step: %d ---------" % (epoch, step))
        mem_kb, mem_mb, mem_gb = get_memory_alloc()
        mem_mb = round(mem_mb, 2)
        log.info('Memory used: {0:.2f} MB'.format(mem_mb))
        self.writer.add_scalar('Memory_MB', mem_mb, global_step=step)
        sys.stdout.flush()

        # ******************** [loop over batches] ********************
        model.train(True)
        trainiter = iter(train_set.iter_loader)
        for idx in range(steps_per_epoch):

            # load batch items
            batch_items = next(trainiter)

            # update macro count
            step += 1
            step_elapsed += 1

            # Get loss
            losses = self._train_batch(model, batch_items, train_set, step, total_steps)

            las_loss = losses['las_loss']
            las_print_loss_total += las_loss

            if step % self.print_every == 0 and step_elapsed > self.print_every:
                las_print_loss_avg = las_print_loss_total / self.print_every
                las_print_loss_total = 0

                log_msg = 'Progress: %d%%, Train las: %.4f' % (
                    step / total_steps * 100, las_print_loss_avg)
                log.info(log_msg)
                self.writer.add_scalar('train_las_loss', las_print_loss_avg, global_step=step)

            # Checkpoint
            if step % self.checkpoint_every == 0 or step == total_steps:

                # save criteria
                if dev_set is not None:
                    dev_accs, dev_losses = self._evaluate_batches(model, dev_set)
                    las_loss = dev_losses['las_loss']
                    las_acc = dev_accs['las_acc']

                    log_msg = 'Progress: %d%%, Dev las loss: %.4f, accuracy: %.4f' % (
                        step / total_steps * 100, las_loss, las_acc)
                    log.info(log_msg)
                    self.writer.add_scalar('dev_las_loss', las_loss, global_step=step)
                    self.writer.add_scalar('dev_las_acc', las_acc, global_step=step)

                    accuracy = las_acc

                    # save
                    if prev_acc < accuracy:
                        # save best model
                        ckpt = Checkpoint(model=model,
                                          optimizer=self.optimizer,
                                          epoch=epoch,
                                          step=step,
                                          input_vocab=train_set.vocab_src,
                                          output_vocab=train_set.vocab_src)
                        saved_path = ckpt.save(self.expt_dir)
                        log.info('saving at {} ... '.format(saved_path))
                        # reset
                        prev_acc = accuracy
                        count_no_improve = 0
                        count_num_rollback = 0
                    else:
                        count_no_improve += 1

                    # roll back
                    if count_no_improve > self.max_count_no_improve:
                        # resuming
                        latest_checkpoint_path = Checkpoint.get_latest_checkpoint(self.expt_dir)
                        if latest_checkpoint_path is not None:
                            resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
                            log.info('epoch:{} step: {} - rolling back {} ...'.format(
                                epoch, step, latest_checkpoint_path))
                            model = resume_checkpoint.model
                            self.optimizer = resume_checkpoint.optimizer
                            # A workaround to set optimizing parameters properly
                            resume_optim = self.optimizer.optimizer
                            defaults = resume_optim.param_groups[0]
                            defaults.pop('params', None)
                            defaults.pop('initial_lr', None)
                            self.optimizer.optimizer = resume_optim.__class__(
                                model.parameters(), **defaults)

                        # reset
                        count_no_improve = 0
                        count_num_rollback += 1

                    # update learning rate
                    if count_num_rollback > self.max_count_num_rollback:

                        # roll back
                        latest_checkpoint_path = Checkpoint.get_latest_checkpoint(self.expt_dir)
                        if latest_checkpoint_path is not None:
                            resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
                            log.info('epoch:{} step: {} - rolling back {} ...'.format(
                                epoch, step, latest_checkpoint_path))
                            model = resume_checkpoint.model
                            self.optimizer = resume_checkpoint.optimizer
                            # A workaround to set optimizing parameters properly
                            resume_optim = self.optimizer.optimizer
                            defaults = resume_optim.param_groups[0]
                            defaults.pop('params', None)
                            defaults.pop('initial_lr', None)
                            self.optimizer.optimizer = resume_optim.__class__(
                                model.parameters(), **defaults)

                        # decrease lr
                        for param_group in self.optimizer.optimizer.param_groups:
                            param_group['lr'] *= 0.5
                            lr_curr = param_group['lr']
                            log.info('reducing lr ...')
                            log.info('step:{} - lr: {}'.format(step, param_group['lr']))

                        # check early stop
                        if lr_curr < 0.125 * self.learning_rate:
                            log.info('early stop ...')
                            break

                        # reset
                        count_no_improve = 0
                        count_num_rollback = 0

                    model.train(mode=True)
                    if ckpt is None:
                        ckpt = Checkpoint(model=model,
                                          optimizer=self.optimizer,
                                          epoch=epoch,
                                          step=step,
                                          input_vocab=train_set.vocab_src,
                                          output_vocab=train_set.vocab_tgt)
                    ckpt.rm_old(self.expt_dir, keep_num=self.keep_num)
                    log.info('n_no_improve {}, num_rollback {}'.format(
                        count_no_improve, count_num_rollback))

                sys.stdout.flush()

        else:
            if dev_set is None:
                # save every epoch if no dev_set
                ckpt = Checkpoint(model=model,
                                  optimizer=self.optimizer,
                                  epoch=epoch,
                                  step=step,
                                  input_vocab=train_set.vocab_src,
                                  output_vocab=train_set.vocab_src)
                saved_path = ckpt.save_epoch(self.expt_dir, epoch)
                log.info('saving at {} ... '.format(saved_path))
                continue
            else:
                continue

        # break nested for loop
        break
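
# --------------------------------------------------------------------------
# Illustrative sketch (not part of the original trainer): the rollback
# branches above re-create the optimizer so that its parameter groups point
# at the freshly restored model. A hypothetical helper capturing that
# pattern for a plain torch.optim optimizer could look like the following;
# `rebuild_optimizer` is an assumed name, not an API of this codebase.
def rebuild_optimizer(optimizer, model):
    """Re-create `optimizer` with the same hyperparameters, bound to
    `model.parameters()` (used after a checkpointed model is reloaded)."""
    defaults = dict(optimizer.param_groups[0])
    defaults.pop('params', None)       # drop the old parameter references
    defaults.pop('initial_lr', None)   # avoid clashing with lr schedulers
    return optimizer.__class__(model.parameters(), **defaults)
# --------------------------------------------------------------------------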
def _train_epoches(self, train_sets, model, n_epochs, start_epoch, start_step, dev_sets=None):

    # load datasets (guard against dev_sets not being provided)
    train_set_asr = train_sets['asr']
    train_set_mt = train_sets['mt']
    dev_set_asr = dev_sets['asr'] if dev_sets is not None else None
    dev_set_mt = dev_sets['mt'] if dev_sets is not None else None

    log = self.logger

    print_loss_ae_total = 0  # Reset every print_every
    print_loss_asr_total = 0
    print_loss_mt_total = 0
    print_loss_kl_total = 0
    print_loss_l2_total = 0

    step = start_step
    step_elapsed = 0
    prev_acc = 0.0
    prev_bleu = 0.0
    count_no_improve = 0
    count_num_rollback = 0
    ckpt = None

    # loop over epochs
    for epoch in range(start_epoch, n_epochs + 1):

        # update lr
        if self.lr_warmup_steps != 0:
            self.optimizer.optimizer = self.lr_scheduler(
                self.optimizer.optimizer, step,
                init_lr=self.learning_rate_init,
                peak_lr=self.learning_rate,
                warmup_steps=self.lr_warmup_steps)

        # print lr
        for param_group in self.optimizer.optimizer.param_groups:
            log.info('epoch:{} lr: {}'.format(epoch, param_group['lr']))
            lr_curr = param_group['lr']

        # construct batches - allow re-shuffling of data
        log.info('--- construct train set ---')
        train_set_asr.construct_batches(is_train=True)
        train_set_mt.construct_batches(is_train=True)
        if dev_set_asr is not None:
            log.info('--- construct dev set ---')
            dev_set_asr.construct_batches(is_train=False)
            dev_set_mt.construct_batches(is_train=False)

        # print info
        steps_per_epoch_asr = len(train_set_asr.iter_loader)
        steps_per_epoch_mt = len(train_set_mt.iter_loader)
        steps_per_epoch = min(steps_per_epoch_asr, steps_per_epoch_mt)
        total_steps = steps_per_epoch * n_epochs
        log.info("steps_per_epoch {}".format(steps_per_epoch))
        log.info("total_steps {}".format(total_steps))

        log.info(" ---------- Epoch: %d, Step: %d ----------" % (epoch, step))
        mem_kb, mem_mb, mem_gb = get_memory_alloc()
        mem_mb = round(mem_mb, 2)
        log.info('Memory used: {0:.2f} MB'.format(mem_mb))
        self.writer.add_scalar('Memory_MB', mem_mb, global_step=step)
        sys.stdout.flush()

        # loop over batches
        model.train(True)
        trainiter_asr = iter(train_set_asr.iter_loader)
        trainiter_mt = iter(train_set_mt.iter_loader)
        for idx in range(steps_per_epoch):

            # load batch items
            batch_items_asr = next(trainiter_asr)
            batch_items_mt = next(trainiter_mt)

            # update macro count
            step += 1
            step_elapsed += 1

            if self.lr_warmup_steps != 0:
                self.optimizer.optimizer = self.lr_scheduler(
                    self.optimizer.optimizer, step,
                    init_lr=self.learning_rate_init,
                    peak_lr=self.learning_rate,
                    warmup_steps=self.lr_warmup_steps)

            # Get loss
            losses = self._train_batch(model, batch_items_asr, batch_items_mt, step, total_steps)

            loss_ae = losses['nll_loss_ae']
            loss_asr = losses['nll_loss_asr']
            loss_mt = losses['nll_loss_mt']
            loss_kl = losses['kl_loss']
            loss_l2 = losses['l2_loss']

            print_loss_ae_total += loss_ae
            print_loss_asr_total += loss_asr
            print_loss_mt_total += loss_mt
            print_loss_kl_total += loss_kl
            print_loss_l2_total += loss_l2

            if step % self.print_every == 0 and step_elapsed > self.print_every:

                print_loss_ae_avg = print_loss_ae_total / self.print_every
                print_loss_ae_total = 0
                print_loss_asr_avg = print_loss_asr_total / self.print_every
                print_loss_asr_total = 0
                print_loss_mt_avg = print_loss_mt_total / self.print_every
                print_loss_mt_total = 0
                print_loss_kl_avg = print_loss_kl_total / self.print_every
                print_loss_kl_total = 0
                print_loss_l2_avg = print_loss_l2_total / self.print_every
                print_loss_l2_total = 0

                log_msg = 'Progress: %d%%, Train nlll_ae: %.4f, nlll_asr: %.4f, ' % (
                    step / total_steps * 100, print_loss_ae_avg, print_loss_asr_avg)
                log_msg += 'Train nlll_mt: %.4f, l2: %.4f, kl_en: %.4f' % (
                    print_loss_mt_avg, print_loss_l2_avg, print_loss_kl_avg)
                log.info(log_msg)

                self.writer.add_scalar('train_loss_ae', print_loss_ae_avg, global_step=step)
                self.writer.add_scalar('train_loss_asr', print_loss_asr_avg, global_step=step)
                self.writer.add_scalar('train_loss_mt', print_loss_mt_avg, global_step=step)
                self.writer.add_scalar('train_loss_kl', print_loss_kl_avg, global_step=step)
                self.writer.add_scalar('train_loss_l2', print_loss_l2_avg, global_step=step)

            # Checkpoint
            if step % self.checkpoint_every == 0 or step == total_steps:

                # save criteria
                if dev_set_asr is not None:

                    losses, metrics = self._evaluate_batches(model, dev_set_asr, dev_set_mt)

                    loss_kl = losses['kl_loss']
                    loss_l2 = losses['l2_loss']

                    loss_ae = losses['nll_loss_ae']
                    accuracy_ae = metrics['accuracy_ae']
                    bleu_ae = metrics['bleu_ae']

                    loss_asr = losses['nll_loss_asr']
                    accuracy_asr = metrics['accuracy_asr']
                    bleu_asr = metrics['bleu_asr']

                    loss_mt = losses['nll_loss_mt']
                    accuracy_mt = metrics['accuracy_mt']
                    bleu_mt = metrics['bleu_mt']

                    log_msg = 'Progress: %d%%, Dev AE loss: %.4f, accuracy: %.4f, bleu: %.4f' % (
                        step / total_steps * 100, loss_ae, accuracy_ae, bleu_ae)
                    log.info(log_msg)
                    log_msg = 'Progress: %d%%, Dev ASR loss: %.4f, accuracy: %.4f, bleu: %.4f' % (
                        step / total_steps * 100, loss_asr, accuracy_asr, bleu_asr)
                    log.info(log_msg)
                    log_msg = 'Progress: %d%%, Dev MT loss: %.4f, accuracy: %.4f, bleu: %.4f' % (
                        step / total_steps * 100, loss_mt, accuracy_mt, bleu_mt)
                    log.info(log_msg)
                    log_msg = 'Progress: %d%%, Dev En KL loss: %.4f, L2 loss: %.4f' % (
                        step / total_steps * 100, loss_kl, loss_l2)
                    log.info(log_msg)

                    self.writer.add_scalar('dev_loss_l2', loss_l2, global_step=step)
                    self.writer.add_scalar('dev_loss_kl', loss_kl, global_step=step)
                    self.writer.add_scalar('dev_loss_ae', loss_ae, global_step=step)
                    self.writer.add_scalar('dev_acc_ae', accuracy_ae, global_step=step)
                    self.writer.add_scalar('dev_bleu_ae', bleu_ae, global_step=step)
                    self.writer.add_scalar('dev_loss_asr', loss_asr, global_step=step)
                    self.writer.add_scalar('dev_acc_asr', accuracy_asr, global_step=step)
                    self.writer.add_scalar('dev_bleu_asr', bleu_asr, global_step=step)
                    self.writer.add_scalar('dev_loss_mt', loss_mt, global_step=step)
                    self.writer.add_scalar('dev_acc_mt', accuracy_mt, global_step=step)
                    self.writer.add_scalar('dev_bleu_mt', bleu_mt, global_step=step)

                    # save criterion - weighted average of ASR (down-weighted) and MT scores
                    accuracy_ave = (accuracy_asr / 4.0 + accuracy_mt) / 2.0
                    bleu_ave = (bleu_asr / 4.0 + bleu_mt) / 2.0

                    if ((prev_acc < accuracy_ave) and (bleu_ave < 0.1)) or prev_bleu < bleu_ave:
                        # save best model - using bleu as metric
                        ckpt = Checkpoint(model=model,
                                          optimizer=self.optimizer,
                                          epoch=epoch,
                                          step=step,
                                          input_vocab=train_set_asr.vocab_src,
                                          output_vocab=train_set_asr.vocab_tgt)
                        saved_path = ckpt.save(self.expt_dir)
                        log.info('saving at {} ... '.format(saved_path))
                        # reset
                        prev_acc = accuracy_ave
                        prev_bleu = bleu_ave
                        count_no_improve = 0
                        count_num_rollback = 0
                    else:
                        count_no_improve += 1

                    # roll back
                    if count_no_improve > self.max_count_no_improve:

                        # no rollback configured - stop after max_count_no_improve
                        # evaluations without improvement
                        if self.max_count_num_rollback == 0:
                            break

                        # resuming
                        latest_checkpoint_path = Checkpoint.get_latest_checkpoint(self.expt_dir)
                        if latest_checkpoint_path is not None:
                            resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
                            log.info('epoch:{} step: {} - rolling back {} ...'.format(
                                epoch, step, latest_checkpoint_path))
                            model = resume_checkpoint.model
                            self.optimizer = resume_checkpoint.optimizer
                            # A workaround to set optimizing parameters properly
                            resume_optim = self.optimizer.optimizer
                            defaults = resume_optim.param_groups[0]
                            defaults.pop('params', None)
                            defaults.pop('initial_lr', None)
                            self.optimizer.optimizer = resume_optim.__class__(
                                model.parameters(), **defaults)

                        # reset
                        count_no_improve = 0
                        count_num_rollback += 1

                    # update learning rate
                    if count_num_rollback > self.max_count_num_rollback:

                        # roll back
                        latest_checkpoint_path = Checkpoint.get_latest_checkpoint(self.expt_dir)
                        if latest_checkpoint_path is not None:
                            resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
                            log.info('epoch:{} step: {} - rolling back {} ...'.format(
                                epoch, step, latest_checkpoint_path))
                            model = resume_checkpoint.model
                            self.optimizer = resume_checkpoint.optimizer
                            # A workaround to set optimizing parameters properly
                            resume_optim = self.optimizer.optimizer
                            defaults = resume_optim.param_groups[0]
                            defaults.pop('params', None)
                            defaults.pop('initial_lr', None)
                            self.optimizer.optimizer = resume_optim.__class__(
                                model.parameters(), **defaults)

                        # decrease lr
                        for param_group in self.optimizer.optimizer.param_groups:
                            param_group['lr'] *= 0.5
                            lr_curr = param_group['lr']
                            log.info('reducing lr ...')
                            log.info('step:{} - lr: {}'.format(step, param_group['lr']))

                        # check early stop
                        if lr_curr <= 0.125 * self.learning_rate:
                            log.info('early stop ...')
                            break

                        # reset
                        count_no_improve = 0
                        count_num_rollback = 0

                    model.train(mode=True)
                    if ckpt is not None:
                        ckpt.rm_old(self.expt_dir, keep_num=self.keep_num)
                    log.info('n_no_improve {}, num_rollback {}'.format(
                        count_no_improve, count_num_rollback))

                sys.stdout.flush()

        else:
            if dev_set_asr is None:
                # save every epoch if no dev_set
                ckpt = Checkpoint(model=model,
                                  optimizer=self.optimizer,
                                  epoch=epoch,
                                  step=step,
                                  input_vocab=train_set_asr.vocab_src,
                                  output_vocab=train_set_asr.vocab_tgt)
                saved_path = ckpt.save_epoch(self.expt_dir, epoch)
                log.info('saving at {} ... '.format(saved_path))
                continue
            else:
                continue

        # break nested for loop
        break
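
# --------------------------------------------------------------------------
# Illustrative sketch (an assumption, not the project's actual scheduler):
# `self.lr_scheduler` above is called as
#   lr_scheduler(optimizer, step, init_lr=..., peak_lr=..., warmup_steps=...)
# and returns the optimizer. A minimal version that ramps linearly from
# init_lr to peak_lr over warmup_steps and then holds peak_lr could look
# like this.
def linear_warmup_lr_scheduler(optimizer, step, init_lr, peak_lr, warmup_steps):
    """Set the learning rate of every param group in `optimizer` for `step`."""
    if step >= warmup_steps:
        lr = peak_lr
    else:
        lr = init_lr + (peak_lr - init_lr) * step / float(warmup_steps)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return optimizer
# --------------------------------------------------------------------------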
def _train_epochs(self, train_set, model, n_epochs, start_epoch, start_step, dev_set=None):

    log = self.logger

    print_loss_total = 0  # Reset every print_every

    step = start_step
    step_elapsed = 0
    prev_acc = 0.0
    prev_bleu = 0.0
    count_no_improve = 0
    count_num_rollback = 0
    ckpt = None

    # loop over epochs
    for epoch in range(start_epoch, n_epochs + 1):

        # update lr
        if self.lr_warmup_steps != 0:
            self.optimizer.optimizer = self.lr_scheduler(
                self.optimizer.optimizer, step,
                init_lr=self.learning_rate_init,
                peak_lr=self.learning_rate,
                warmup_steps=self.lr_warmup_steps)

        # print lr
        for param_group in self.optimizer.optimizer.param_groups:
            log.info('epoch:{} lr: {}'.format(epoch, param_group['lr']))
            lr_curr = param_group['lr']

        # construct batches - allow re-shuffling of data
        log.info('--- construct train set ---')
        train_set.construct_batches(is_train=True)
        if dev_set is not None:
            log.info('--- construct dev set ---')
            dev_set.construct_batches(is_train=False)

        # print info
        steps_per_epoch = len(train_set.iter_loader)
        total_steps = steps_per_epoch * n_epochs
        log.info("steps_per_epoch {}".format(steps_per_epoch))
        log.info("total_steps {}".format(total_steps))

        log.info(" ---------- Epoch: %d, Step: %d ----------" % (epoch, step))
        mem_kb, mem_mb, mem_gb = get_memory_alloc()
        mem_mb = round(mem_mb, 2)
        log.info('Memory used: {0:.2f} MB'.format(mem_mb))
        self.writer.add_scalar('Memory_MB', mem_mb, global_step=step)
        sys.stdout.flush()

        # loop over batches
        model.train(True)
        trainiter = iter(train_set.iter_loader)
        for idx in range(steps_per_epoch):

            # load batch items
            batch_items = next(trainiter)

            # update macro count
            step += 1
            step_elapsed += 1

            if self.lr_warmup_steps != 0:
                self.optimizer.optimizer = self.lr_scheduler(
                    self.optimizer.optimizer, step,
                    init_lr=self.learning_rate_init,
                    peak_lr=self.learning_rate,
                    warmup_steps=self.lr_warmup_steps)

            # Get loss
            loss = self._train_batch(model, batch_items, train_set, step, total_steps)
            print_loss_total += loss

            if step % self.print_every == 0 and step_elapsed > self.print_every:
                print_loss_avg = print_loss_total / self.print_every
                print_loss_total = 0
                log_msg = 'Progress: %d%%, Train nlll: %.4f' % (
                    step / total_steps * 100, print_loss_avg)
                log.info(log_msg)
                self.writer.add_scalar('train_loss', print_loss_avg, global_step=step)

            # Checkpoint
            if step % self.checkpoint_every == 0 or step == total_steps:

                # save criteria
                if dev_set is not None:
                    losses, metrics = self._evaluate_batches(model, dev_set)
                    loss = losses['nll_loss']
                    accuracy = metrics['accuracy']
                    bleu = metrics['bleu']

                    log_msg = 'Progress: %d%%, Dev loss: %.4f, accuracy: %.4f, bleu: %.4f' % (
                        step / total_steps * 100, loss, accuracy, bleu)
                    log.info(log_msg)
                    self.writer.add_scalar('dev_loss', loss, global_step=step)
                    self.writer.add_scalar('dev_acc', accuracy, global_step=step)
                    self.writer.add_scalar('dev_bleu', bleu, global_step=step)

                    # save condition
                    cond_acc = (prev_acc <= accuracy)
                    cond_bleu = (((prev_acc <= accuracy) and (bleu < 0.1)) or prev_bleu <= bleu)

                    # save
                    if self.eval_metric == 'tokacc':
                        save_cond = cond_acc
                    elif self.eval_metric == 'bleu':
                        save_cond = cond_bleu

                    if save_cond:
                        # save best model
                        ckpt = Checkpoint(model=model,
                                          optimizer=self.optimizer,
                                          epoch=epoch,
                                          step=step,
                                          input_vocab=train_set.vocab_src,
                                          output_vocab=train_set.vocab_tgt)
                        saved_path = ckpt.save(self.expt_dir)
                        log.info('saving at {} ... '.format(saved_path))
                        # reset
                        prev_acc = accuracy
                        prev_bleu = bleu
                        count_no_improve = 0
                        count_num_rollback = 0
                    else:
                        count_no_improve += 1

                    # roll back
                    if count_no_improve > self.max_count_no_improve:

                        # no rollback configured - stop after max_count_no_improve
                        # evaluations without improvement
                        if self.max_count_num_rollback == 0:
                            break

                        # resuming
                        latest_checkpoint_path = Checkpoint.get_latest_checkpoint(self.expt_dir)
                        if latest_checkpoint_path is not None:
                            resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
                            log.info('epoch:{} step: {} - rolling back {} ...'.format(
                                epoch, step, latest_checkpoint_path))
                            model = resume_checkpoint.model
                            self.optimizer = resume_checkpoint.optimizer
                            # A workaround to set optimizing parameters properly
                            resume_optim = self.optimizer.optimizer
                            defaults = resume_optim.param_groups[0]
                            defaults.pop('params', None)
                            defaults.pop('initial_lr', None)
                            self.optimizer.optimizer = resume_optim.__class__(
                                model.parameters(), **defaults)

                        # reset
                        count_no_improve = 0
                        count_num_rollback += 1

                    # update learning rate
                    if count_num_rollback > self.max_count_num_rollback:

                        # roll back
                        latest_checkpoint_path = Checkpoint.get_latest_checkpoint(self.expt_dir)
                        if latest_checkpoint_path is not None:
                            resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
                            log.info('epoch:{} step: {} - rolling back {} ...'.format(
                                epoch, step, latest_checkpoint_path))
                            model = resume_checkpoint.model
                            self.optimizer = resume_checkpoint.optimizer
                            # A workaround to set optimizing parameters properly
                            resume_optim = self.optimizer.optimizer
                            defaults = resume_optim.param_groups[0]
                            defaults.pop('params', None)
                            defaults.pop('initial_lr', None)
                            self.optimizer.optimizer = resume_optim.__class__(
                                model.parameters(), **defaults)

                        # decrease lr
                        for param_group in self.optimizer.optimizer.param_groups:
                            param_group['lr'] *= 0.5
                            lr_curr = param_group['lr']
                            log.info('reducing lr ...')
                            log.info('step:{} - lr: {}'.format(step, param_group['lr']))

                        # check early stop
                        if lr_curr <= 0.125 * self.learning_rate:
                            log.info('early stop ...')
                            break

                        # reset
                        count_no_improve = 0
                        count_num_rollback = 0

                    model.train(mode=True)
                    if ckpt is None:
                        ckpt = Checkpoint(model=model,
                                          optimizer=self.optimizer,
                                          epoch=epoch,
                                          step=step,
                                          input_vocab=train_set.vocab_src,
                                          output_vocab=train_set.vocab_tgt)
                    ckpt.rm_old(self.expt_dir, keep_num=self.keep_num)
                    log.info('n_no_improve {}, num_rollback {}'.format(
                        count_no_improve, count_num_rollback))

                sys.stdout.flush()

        else:
            if dev_set is None:
                # save every epoch if no dev_set
                ckpt = Checkpoint(model=model,
                                  optimizer=self.optimizer,
                                  epoch=epoch,
                                  step=step,
                                  input_vocab=train_set.vocab_src,
                                  output_vocab=train_set.vocab_tgt)
                saved_path = ckpt.save_epoch(self.expt_dir, epoch)
                log.info('saving at {} ... '.format(saved_path))
                continue
            else:
                continue

        # break nested for loop
        break
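
# --------------------------------------------------------------------------
# Illustrative sketch (assumed helper, mirroring the save condition above):
# with eval_metric == 'tokacc' a checkpoint is kept on non-decreasing token
# accuracy; with eval_metric == 'bleu' it is kept on non-decreasing BLEU,
# falling back to the accuracy criterion while BLEU is still below 0.1.
def should_save(eval_metric, prev_acc, prev_bleu, accuracy, bleu):
    cond_acc = prev_acc <= accuracy
    cond_bleu = (cond_acc and bleu < 0.1) or prev_bleu <= bleu
    return cond_acc if eval_metric == 'tokacc' else cond_bleu

# example: early in training BLEU may still be tiny, so accuracy decides
# should_save('bleu', prev_acc=0.70, prev_bleu=0.05, accuracy=0.72, bleu=0.04) -> True
# --------------------------------------------------------------------------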
def _train_epoches(self, train_set, model, n_epochs, start_epoch, start_step, dev_set=None):

    log = self.logger

    print_loss_total = 0  # Reset every print_every
    epoch_loss_total = 0  # Reset every epoch
    att_print_loss_total = 0  # Reset every print_every
    att_epoch_loss_total = 0  # Reset every epoch
    attcls_print_loss_total = 0  # Reset every print_every
    attcls_epoch_loss_total = 0  # Reset every epoch

    step = start_step
    step_elapsed = 0
    prev_acc = 0.0
    count_no_improve = 0
    count_num_rollback = 0
    ckpt = None

    # ******************** [loop over epochs] ********************
    for epoch in range(start_epoch, n_epochs + 1):

        for param_group in self.optimizer.optimizer.param_groups:
            print('epoch:{} lr: {}'.format(epoch, param_group['lr']))
            lr_curr = param_group['lr']

        # ---------- construct batches ----------
        # allow re-shuffling of data
        if train_set.attkey_path is None:
            print('--- construct train set ---')
            train_batches, vocab_size = train_set.construct_batches(is_train=True)
            if dev_set is not None:
                print('--- construct dev set ---')
                dev_batches, vocab_size = dev_set.construct_batches(is_train=False)
        else:
            print('--- construct train set ---')
            train_batches, vocab_size = train_set.construct_batches_with_ddfd_prob(is_train=True)
            if dev_set is not None:
                print('--- construct dev set ---')
                assert dev_set.attkey_path is not None, 'Dev set missing ddfd probabilities'
                dev_batches, vocab_size = dev_set.construct_batches_with_ddfd_prob(is_train=False)

        # ---------- print info for each epoch ----------
        steps_per_epoch = len(train_batches)
        total_steps = steps_per_epoch * n_epochs
        log.info("steps_per_epoch {}".format(steps_per_epoch))
        log.info("total_steps {}".format(total_steps))

        log.debug(" ----------------- Epoch: %d, Step: %d -----------------" % (epoch, step))
        mem_kb, mem_mb, mem_gb = get_memory_alloc()
        mem_mb = round(mem_mb, 2)
        print('Memory used: {0:.2f} MB'.format(mem_mb))
        self.writer.add_scalar('Memory_MB', mem_mb, global_step=step)
        sys.stdout.flush()

        # ******************** [loop over batches] ********************
        model.train(True)
        for batch in train_batches:

            # update macro count
            step += 1
            step_elapsed += 1

            # load data
            src_ids = batch['src_word_ids']
            src_lengths = batch['src_sentence_lengths']
            tgt_ids = batch['tgt_word_ids']
            tgt_lengths = batch['tgt_sentence_lengths']

            src_probs = None
            src_labs = None
            if 'src_ddfd_probs' in batch and model.additional_key_size > 0:
                src_probs = batch['src_ddfd_probs']
                src_probs = _convert_to_tensor(src_probs, self.use_gpu).unsqueeze(2)
            if 'src_ddfd_labs' in batch:
                src_labs = batch['src_ddfd_labs']
                src_labs = _convert_to_tensor(src_labs, self.use_gpu).unsqueeze(2)

            # sanity check src-tgt pair
            if step == 1:
                print('--- Check src tgt pair ---')
                log_msgs = check_srctgt(src_ids, tgt_ids,
                                        train_set.src_id2word, train_set.tgt_id2word)
                for log_msg in log_msgs:
                    sys.stdout.buffer.write(log_msg)

            # convert variable to tensor
            src_ids = _convert_to_tensor(src_ids, self.use_gpu)
            tgt_ids = _convert_to_tensor(tgt_ids, self.use_gpu)

            # Get loss
            loss, att_loss, attcls_loss = self._train_batch(
                src_ids, tgt_ids, model, step, total_steps,
                src_probs=src_probs, src_labs=src_labs)

            print_loss_total += loss
            epoch_loss_total += loss
            att_print_loss_total += att_loss
            att_epoch_loss_total += att_loss
            attcls_print_loss_total += attcls_loss
            attcls_epoch_loss_total += attcls_loss

            if step % self.print_every == 0 and step_elapsed > self.print_every:
                print_loss_avg = print_loss_total / self.print_every
                att_print_loss_avg = att_print_loss_total / self.print_every
                attcls_print_loss_avg = attcls_print_loss_total / self.print_every
                print_loss_total = 0
                att_print_loss_total = 0
                attcls_print_loss_total = 0

                log_msg = 'Progress: %d%%, Train nlll: %.4f, att: %.4f, attcls: %.4f' % (
                    step / total_steps * 100, print_loss_avg,
                    att_print_loss_avg, attcls_print_loss_avg)
                # print(log_msg)
                log.info(log_msg)
                self.writer.add_scalar('train_loss', print_loss_avg, global_step=step)
                self.writer.add_scalar('att_train_loss', att_print_loss_avg, global_step=step)
                self.writer.add_scalar('attcls_train_loss', attcls_print_loss_avg, global_step=step)

            # Checkpoint
            if step % self.checkpoint_every == 0 or step == total_steps:

                # save criteria
                if dev_set is not None:
                    dev_loss, accuracy, dev_attlosses = \
                        self._evaluate_batches(model, dev_batches, dev_set)
                    dev_attloss = dev_attlosses['att_loss']
                    dev_attclsloss = dev_attlosses['attcls_loss']
                    log_msg = 'Progress: %d%%, Dev loss: %.4f, accuracy: %.4f, att: %.4f, attcls: %.4f' % (
                        step / total_steps * 100, dev_loss, accuracy,
                        dev_attloss, dev_attclsloss)
                    log.info(log_msg)
                    self.writer.add_scalar('dev_loss', dev_loss, global_step=step)
                    self.writer.add_scalar('dev_acc', accuracy, global_step=step)
                    self.writer.add_scalar('att_dev_loss', dev_attloss, global_step=step)
                    self.writer.add_scalar('attcls_dev_loss', dev_attclsloss, global_step=step)

                    # save
                    if prev_acc < accuracy:
                        # save best model
                        ckpt = Checkpoint(model=model,
                                          optimizer=self.optimizer,
                                          epoch=epoch,
                                          step=step,
                                          input_vocab=train_set.vocab_src,
                                          output_vocab=train_set.vocab_tgt)
                        saved_path = ckpt.save(self.expt_dir)
                        print('saving at {} ... '.format(saved_path))
                        # reset
                        prev_acc = accuracy
                        count_no_improve = 0
                        count_num_rollback = 0
                    else:
                        count_no_improve += 1

                    # roll back
                    if count_no_improve > MAX_COUNT_NO_IMPROVE:
                        # resuming
                        latest_checkpoint_path = Checkpoint.get_latest_checkpoint(self.expt_dir)
                        if latest_checkpoint_path is not None:
                            resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
                            print('epoch:{} step: {} - rolling back {} ...'.format(
                                epoch, step, latest_checkpoint_path))
                            model = resume_checkpoint.model
                            self.optimizer = resume_checkpoint.optimizer
                            # A workaround to set optimizing parameters properly
                            resume_optim = self.optimizer.optimizer
                            defaults = resume_optim.param_groups[0]
                            defaults.pop('params', None)
                            defaults.pop('initial_lr', None)
                            self.optimizer.optimizer = resume_optim\
                                .__class__(model.parameters(), **defaults)
                            # start_epoch = resume_checkpoint.epoch
                            # step = resume_checkpoint.step

                        # reset
                        count_no_improve = 0
                        count_num_rollback += 1

                    # update learning rate
                    if count_num_rollback > MAX_COUNT_NUM_ROLLBACK:

                        # roll back
                        latest_checkpoint_path = Checkpoint.get_latest_checkpoint(self.expt_dir)
                        if latest_checkpoint_path is not None:
                            resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
                            print('epoch:{} step: {} - rolling back {} ...'.format(
                                epoch, step, latest_checkpoint_path))
                            model = resume_checkpoint.model
                            self.optimizer = resume_checkpoint.optimizer
                            # A workaround to set optimizing parameters properly
                            resume_optim = self.optimizer.optimizer
                            defaults = resume_optim.param_groups[0]
                            defaults.pop('params', None)
                            defaults.pop('initial_lr', None)
                            self.optimizer.optimizer = resume_optim\
                                .__class__(model.parameters(), **defaults)
                            start_epoch = resume_checkpoint.epoch
                            step = resume_checkpoint.step

                        # decrease lr
                        for param_group in self.optimizer.optimizer.param_groups:
                            param_group['lr'] *= 0.5
                            lr_curr = param_group['lr']
                            print('reducing lr ...')
                            print('step:{} - lr: {}'.format(step, param_group['lr']))

                        # check early stop
                        if lr_curr < 0.000125:
                            print('early stop ...')
                            break

                        # reset
                        count_no_improve = 0
                        count_num_rollback = 0

                    model.train(mode=True)
                    if ckpt is None:
                        ckpt = Checkpoint(model=model,
                                          optimizer=self.optimizer,
                                          epoch=epoch,
                                          step=step,
                                          input_vocab=train_set.vocab_src,
                                          output_vocab=train_set.vocab_tgt)
                        saved_path = ckpt.save(self.expt_dir)
                    ckpt.rm_old(self.expt_dir, keep_num=KEEP_NUM)
                    print('n_no_improve {}, num_rollback {}'.format(
                        count_no_improve, count_num_rollback))

                sys.stdout.flush()

        else:
            if dev_set is None:
                # save every epoch if no dev_set
                ckpt = Checkpoint(model=model,
                                  optimizer=self.optimizer,
                                  epoch=epoch,
                                  step=step,
                                  input_vocab=train_set.vocab_src,
                                  output_vocab=train_set.vocab_tgt)
                # saved_path = ckpt.save(self.expt_dir)
                saved_path = ckpt.save_epoch(self.expt_dir, epoch)
                print('saving at {} ... '.format(saved_path))
                continue
            else:
                continue

        # break nested for loop
        break

        if step_elapsed == 0:
            continue

        epoch_loss_avg = epoch_loss_total / min(steps_per_epoch, step - start_step)
        epoch_loss_total = 0
        log_msg = "Finished epoch %d: Train %s: %.4f" % (
            epoch, self.loss.name, epoch_loss_avg)
        log.info('\n')
        log.info(log_msg)
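
# --------------------------------------------------------------------------
# Illustrative sketch (assumption): `get_memory_alloc`, used by every trainer
# above, is taken to return the resident memory of the current process as
# (kb, mb, gb). One possible implementation via psutil is shown below; the
# project's actual helper may be implemented differently.
import os
import psutil

def get_memory_alloc_sketch():
    """Return the current process's resident set size as (KB, MB, GB)."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    mem_kb = rss_bytes / 1024.0
    mem_mb = mem_kb / 1024.0
    mem_gb = mem_mb / 1024.0
    return mem_kb, mem_mb, mem_gb
# --------------------------------------------------------------------------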