def train_and_validate(self, train_data: Dataset, valid_data: Dataset) -> None:
    """
    Train the model and validate it from time to time on the validation set.

    :param train_data: training data
    :param valid_data: validation data
    """
    train_iter = make_data_iter(
        train_data,
        batch_size=self.batch_size,
        batch_type=self.batch_type,
        train=True,
        shuffle=self.shuffle,
    )

    epoch_no = None
    for epoch_no in range(self.epochs):
        self.logger.info("EPOCH %d", epoch_no + 1)

        if self.scheduler is not None and self.scheduler_step_at == "epoch":
            self.scheduler.step(epoch=epoch_no)

        self.model.train()
        start = time.time()
        total_valid_duration = 0
        count = self.batch_multiplier - 1

        if self.do_recognition:
            processed_gls_tokens = self.total_gls_tokens
            epoch_recognition_loss = 0
        if self.do_translation:
            processed_txt_tokens = self.total_txt_tokens
            epoch_translation_loss = 0

        for batch in iter(train_iter):
            # create a Batch object from the torchtext batch
            batch = Batch(
                is_train=True,
                torch_batch=batch,
                txt_pad_index=self.txt_pad_index,
                sgn_dim=self.feature_size,
                use_cuda=self.use_cuda,
                frame_subsampling_ratio=self.frame_subsampling_ratio,
                random_frame_subsampling=self.random_frame_subsampling,
                random_frame_masking_ratio=self.random_frame_masking_ratio,
            )

            # only update every batch_multiplier batches (gradient accumulation), see
            # https://medium.com/@davidlmorton/increasing-mini-batch-size-without-increasing-memory-6794e10db672
            update = count == 0

            recognition_loss, translation_loss = self._train_batch(
                batch, update=update
            )

            if self.do_recognition:
                self.tb_writer.add_scalar(
                    "train/train_recognition_loss", recognition_loss, self.steps
                )
                epoch_recognition_loss += recognition_loss.detach().cpu().numpy()
            if self.do_translation:
                self.tb_writer.add_scalar(
                    "train/train_translation_loss", translation_loss, self.steps
                )
                epoch_translation_loss += translation_loss.detach().cpu().numpy()

            count = self.batch_multiplier if update else count
            count -= 1

            if (
                self.scheduler is not None
                and self.scheduler_step_at == "step"
                and update
            ):
                self.scheduler.step()

            # log learning progress
            if self.steps % self.logging_freq == 0 and update:
                elapsed = time.time() - start - total_valid_duration

                log_out = "[Epoch: {:03d} Step: {:08d}] ".format(
                    epoch_no + 1, self.steps
                )
                if self.do_recognition:
                    elapsed_gls_tokens = (
                        self.total_gls_tokens - processed_gls_tokens
                    )
                    processed_gls_tokens = self.total_gls_tokens
                    log_out += "Batch Recognition Loss: {:10.6f} => ".format(
                        recognition_loss
                    )
                    log_out += "Gls Tokens per Sec: {:8.0f} || ".format(
                        elapsed_gls_tokens / elapsed
                    )
                if self.do_translation:
                    elapsed_txt_tokens = (
                        self.total_txt_tokens - processed_txt_tokens
                    )
                    processed_txt_tokens = self.total_txt_tokens
                    log_out += "Batch Translation Loss: {:10.6f} => ".format(
                        translation_loss
                    )
                    log_out += "Txt Tokens per Sec: {:8.0f} || ".format(
                        elapsed_txt_tokens / elapsed
                    )
                log_out += "Lr: {:.6f}".format(
                    self.optimizer.param_groups[0]["lr"]
                )
                self.logger.info(log_out)
                start = time.time()
                total_valid_duration = 0

            # validate on the entire dev set
            if self.steps % self.validation_freq == 0 and update:
                valid_start_time = time.time()
                # TODO (Cihan): There must be a better way of passing
                #   these recognition-only and translation-only parameters!
                #   Maybe have a NamedTuple with optional fields?
                #   Hmm... Future Cihan's problem.
                val_res = validate_on_data(
                    model=self.model,
                    data=valid_data,
                    batch_size=self.eval_batch_size,
                    use_cuda=self.use_cuda,
                    batch_type=self.eval_batch_type,
                    dataset_version=self.dataset_version,
                    sgn_dim=self.feature_size,
                    txt_pad_index=self.txt_pad_index,
                    # fusion_type is a required parameter of validate_on_data;
                    # assumed here to be stored on the trainer as self.fusion_type
                    fusion_type=self.fusion_type,
                    # Recognition Parameters
                    do_recognition=self.do_recognition,
                    recognition_loss_function=self.recognition_loss_function
                    if self.do_recognition
                    else None,
                    recognition_loss_weight=self.recognition_loss_weight
                    if self.do_recognition
                    else None,
                    recognition_beam_size=self.eval_recognition_beam_size
                    if self.do_recognition
                    else None,
                    # Translation Parameters
                    do_translation=self.do_translation,
                    translation_loss_function=self.translation_loss_function
                    if self.do_translation
                    else None,
                    translation_max_output_length=self.translation_max_output_length
                    if self.do_translation
                    else None,
                    level=self.level if self.do_translation else None,
                    translation_loss_weight=self.translation_loss_weight
                    if self.do_translation
                    else None,
                    translation_beam_size=self.eval_translation_beam_size
                    if self.do_translation
                    else None,
                    translation_beam_alpha=self.eval_translation_beam_alpha
                    if self.do_translation
                    else None,
                    frame_subsampling_ratio=self.frame_subsampling_ratio,
                )
                self.model.train()

                if self.do_recognition:
                    # log recognition loss and WER
                    self.tb_writer.add_scalar(
                        "valid/valid_recognition_loss",
                        val_res["valid_recognition_loss"],
                        self.steps,
                    )
                    self.tb_writer.add_scalar(
                        "valid/wer", val_res["valid_scores"]["wer"], self.steps
                    )
                    self.tb_writer.add_scalars(
                        "valid/wer_scores",
                        val_res["valid_scores"]["wer_scores"],
                        self.steps,
                    )

                if self.do_translation:
                    # log translation loss, perplexity, and scores
                    self.tb_writer.add_scalar(
                        "valid/valid_translation_loss",
                        val_res["valid_translation_loss"],
                        self.steps,
                    )
                    self.tb_writer.add_scalar(
                        "valid/valid_ppl", val_res["valid_ppl"], self.steps
                    )
                    self.tb_writer.add_scalar(
                        "valid/chrf", val_res["valid_scores"]["chrf"], self.steps
                    )
                    self.tb_writer.add_scalar(
                        "valid/rouge", val_res["valid_scores"]["rouge"], self.steps
                    )
                    self.tb_writer.add_scalar(
                        "valid/bleu", val_res["valid_scores"]["bleu"], self.steps
                    )
                    self.tb_writer.add_scalars(
                        "valid/bleu_scores",
                        val_res["valid_scores"]["bleu_scores"],
                        self.steps,
                    )

                if self.early_stopping_metric == "recognition_loss":
                    assert self.do_recognition
                    ckpt_score = val_res["valid_recognition_loss"]
                elif self.early_stopping_metric == "translation_loss":
                    assert self.do_translation
                    ckpt_score = val_res["valid_translation_loss"]
                elif self.early_stopping_metric in ["ppl", "perplexity"]:
                    assert self.do_translation
                    ckpt_score = val_res["valid_ppl"]
                else:
                    ckpt_score = val_res["valid_scores"][self.eval_metric]

                new_best = False
                if self.is_best(ckpt_score):
                    self.best_ckpt_score = ckpt_score
                    self.best_all_ckpt_scores = val_res["valid_scores"]
                    self.best_ckpt_iteration = self.steps
                    self.logger.info(
                        "Hooray! New best validation result [%s]!",
                        self.early_stopping_metric,
                    )
                    if self.ckpt_queue.maxsize > 0:
                        self.logger.info("Saving new checkpoint.")
                        new_best = True
                        self._save_checkpoint()

                if (
                    self.scheduler is not None
                    and self.scheduler_step_at == "validation"
                ):
                    prev_lr = self.scheduler.optimizer.param_groups[0]["lr"]
                    self.scheduler.step(ckpt_score)
                    now_lr = self.scheduler.optimizer.param_groups[0]["lr"]

                    if prev_lr != now_lr:
                        if self.last_best_lr != prev_lr:
                            self.stop = True

                # append to validation report
                self._add_report(
                    valid_scores=val_res["valid_scores"],
                    valid_recognition_loss=val_res["valid_recognition_loss"]
                    if self.do_recognition
                    else None,
                    valid_translation_loss=val_res["valid_translation_loss"]
                    if self.do_translation
                    else None,
                    valid_ppl=val_res["valid_ppl"] if self.do_translation else None,
                    eval_metric=self.eval_metric,
                    new_best=new_best,
                )

                valid_duration = time.time() - valid_start_time
                total_valid_duration += valid_duration
                self.logger.info(
                    "Validation result at epoch %3d, step %8d: duration: %.4fs\n\t"
                    "Recognition Beam Size: %d\t"
                    "Translation Beam Size: %d\t"
                    "Translation Beam Alpha: %d\n\t"
                    "Recognition Loss: %4.5f\t"
                    "Translation Loss: %4.5f\t"
                    "PPL: %4.5f\n\t"
                    "Eval Metric: %s\n\t"
                    "WER %3.2f\t(DEL: %3.2f,\tINS: %3.2f,\tSUB: %3.2f)\n\t"
                    "BLEU-4 %.2f\t(BLEU-1: %.2f,\tBLEU-2: %.2f,\t"
                    "BLEU-3: %.2f,\tBLEU-4: %.2f)\n\t"
                    "CHRF %.2f\t"
                    "ROUGE %.2f",
                    epoch_no + 1,
                    self.steps,
                    valid_duration,
                    self.eval_recognition_beam_size if self.do_recognition else -1,
                    self.eval_translation_beam_size if self.do_translation else -1,
                    self.eval_translation_beam_alpha if self.do_translation else -1,
                    val_res["valid_recognition_loss"] if self.do_recognition else -1,
                    val_res["valid_translation_loss"] if self.do_translation else -1,
                    val_res["valid_ppl"] if self.do_translation else -1,
                    self.eval_metric.upper(),
                    # WER
                    val_res["valid_scores"]["wer"] if self.do_recognition else -1,
                    val_res["valid_scores"]["wer_scores"]["del_rate"]
                    if self.do_recognition
                    else -1,
                    val_res["valid_scores"]["wer_scores"]["ins_rate"]
                    if self.do_recognition
                    else -1,
                    val_res["valid_scores"]["wer_scores"]["sub_rate"]
                    if self.do_recognition
                    else -1,
                    # BLEU
                    val_res["valid_scores"]["bleu"] if self.do_translation else -1,
                    val_res["valid_scores"]["bleu_scores"]["bleu1"]
                    if self.do_translation
                    else -1,
                    val_res["valid_scores"]["bleu_scores"]["bleu2"]
                    if self.do_translation
                    else -1,
                    val_res["valid_scores"]["bleu_scores"]["bleu3"]
                    if self.do_translation
                    else -1,
                    val_res["valid_scores"]["bleu_scores"]["bleu4"]
                    if self.do_translation
                    else -1,
                    # Other
                    val_res["valid_scores"]["chrf"] if self.do_translation else -1,
                    val_res["valid_scores"]["rouge"] if self.do_translation else -1,
                )

                valid_seq = [s for s in valid_data.sequence]
                self._log_examples(
                    sequences=valid_seq,
                    gls_references=val_res["gls_ref"]
                    if self.do_recognition
                    else None,
                    gls_hypotheses=val_res["gls_hyp"]
                    if self.do_recognition
                    else None,
                    txt_references=val_res["txt_ref"]
                    if self.do_translation
                    else None,
                    txt_hypotheses=val_res["txt_hyp"]
                    if self.do_translation
                    else None,
                )

                # store validation set outputs and references
                if self.do_recognition:
                    self._store_outputs(
                        "dev.hyp.gls", valid_seq, val_res["gls_hyp"], "gls"
                    )
                    self._store_outputs(
                        "references.dev.gls", valid_seq, val_res["gls_ref"]
                    )
                if self.do_translation:
                    self._store_outputs(
                        "dev.hyp.txt", valid_seq, val_res["txt_hyp"], "txt"
                    )
                    self._store_outputs(
                        "references.dev.txt", valid_seq, val_res["txt_ref"]
                    )

            if self.stop:
                break

        if self.stop:
            if (
                self.scheduler is not None
                and self.scheduler_step_at == "validation"
                and self.last_best_lr != prev_lr
            ):
                self.logger.info(
                    "Training ended since there were no improvements in "
                    "the last learning rate step: %f",
                    prev_lr,
                )
            else:
                self.logger.info(
                    "Training ended since minimum lr %f was reached.",
                    self.learning_rate_min,
                )
            break

        self.logger.info(
            "Epoch %3d: Total Training Recognition Loss %.2f "
            "Total Training Translation Loss %.2f",
            epoch_no + 1,
            epoch_recognition_loss if self.do_recognition else -1,
            epoch_translation_loss if self.do_translation else -1,
        )
    else:
        self.logger.info("Training ended after %3d epochs.", epoch_no + 1)

    self.logger.info(
        "Best validation result at step %8d: %6.2f %s.",
        self.best_ckpt_iteration,
        self.best_ckpt_score,
        self.early_stopping_metric,
    )

    # close the TensorBoard writer
    self.tb_writer.close()
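
# ---------------------------------------------------------------------------
# Illustration only. The `count` / `batch_multiplier` bookkeeping in
# `train_and_validate` above is a gradient-accumulation scheme: losses are
# backpropagated for every batch, but the optimizer only steps once per
# `batch_multiplier` batches, emulating a larger effective batch size without
# extra memory. Below is a minimal sketch of the same pattern in isolation;
# `model`, `loader`, `criterion`, and `optimizer` are generic placeholders and
# this helper is not called anywhere in the trainer.
# ---------------------------------------------------------------------------
def _gradient_accumulation_sketch(
    model, loader, criterion, optimizer, batch_multiplier: int = 2
) -> None:
    count = batch_multiplier - 1
    for inputs, targets in loader:
        update = count == 0
        # scale the loss so the accumulated gradient matches one large batch
        loss = criterion(model(inputs), targets) / batch_multiplier
        loss.backward()
        if update:
            # step once per `batch_multiplier` batches, then clear gradients
            optimizer.step()
            optimizer.zero_grad()
        count = batch_multiplier if update else count
        count -= 1
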
def validate_on_data(
    model: SignModel,
    data: Dataset,
    batch_size: int,
    use_cuda: bool,
    sgn_dim: int,
    do_recognition: bool,
    recognition_loss_function: torch.nn.Module,
    recognition_loss_weight: int,
    do_translation: bool,
    translation_loss_function: torch.nn.Module,
    translation_loss_weight: int,
    translation_max_output_length: int,
    level: str,
    txt_pad_index: int,
    fusion_type: str,
    recognition_beam_size: int = 1,
    translation_beam_size: int = 1,
    translation_beam_alpha: int = -1,
    batch_type: str = "sentence",
    dataset_version: str = "phoenix_2014_trans",
    frame_subsampling_ratio: int = None,
) -> dict:
    """
    Generate translations for the given data.
    If a loss function is given and references are available,
    also compute the loss.

    :param model: model module
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param use_cuda: if True, use CUDA
    :param sgn_dim: feature dimension of the sgn frames
    :param do_recognition: flag for predicting glosses
    :param recognition_loss_function: recognition loss function (CTC)
    :param recognition_loss_weight: CTC loss weight
    :param do_translation: flag for predicting text
    :param translation_loss_function: translation loss function (cross-entropy)
    :param translation_loss_weight: translation loss weight
    :param translation_max_output_length: maximum length for generated hypotheses
    :param level: segmentation level, one of "char", "bpe", "word"
    :param txt_pad_index: txt padding token index
    :param fusion_type: fusion strategy, forwarded to batch construction
        and loss computation
    :param recognition_beam_size: beam size for validation (recognition,
        i.e. CTC). If 1 (the default), greedy decoding is used.
    :param translation_beam_size: beam size for validation (translation).
        If 1 (the default), greedy decoding is used.
    :param translation_beam_alpha: beam search alpha for the length penalty
        (translation); disabled if set to -1 (default).
    :param batch_type: validation batch type ("sentence" or "token")
    :param dataset_version: "phoenix_2014" or "phoenix_2014_trans"
    :param frame_subsampling_ratio: frame subsampling ratio
    :return: a results dictionary with the keys
        - "valid_scores": metric scores on the validation set,
        - "all_attention_scores": attention scores for the hypotheses,
        plus, if do_recognition: "valid_recognition_loss", "decoded_gls",
        "gls_ref", "gls_hyp"; and, if do_translation:
        "valid_translation_loss", "valid_ppl", "decoded_txt",
        "txt_ref", "txt_hyp".
    """
    valid_iter = make_data_iter(
        dataset=data,
        batch_size=batch_size,
        batch_type=batch_type,
        shuffle=False,
        train=False,
    )

    # disable dropout
    model.eval()
    # don't track gradients during validation
    with torch.no_grad():
        all_gls_outputs = []
        all_txt_outputs = []
        all_attention_scores = []
        total_recognition_loss = 0
        total_translation_loss = 0
        total_num_txt_tokens = 0
        total_num_gls_tokens = 0
        total_num_seqs = 0
        for valid_batch in iter(valid_iter):
            batch = Batch(
                is_train=False,
                torch_batch=valid_batch,
                txt_pad_index=txt_pad_index,
                sgn_dim=sgn_dim,
                fusion_type=fusion_type,
                use_cuda=use_cuda,
                frame_subsampling_ratio=frame_subsampling_ratio,
            )
            sort_reverse_index = batch.sort_by_sgn_lengths()

            batch_recognition_loss, batch_translation_loss = model.get_loss_for_batch(
                batch=batch,
                fusion_type=fusion_type,
                recognition_loss_function=recognition_loss_function
                if do_recognition
                else None,
                translation_loss_function=translation_loss_function
                if do_translation
                else None,
                recognition_loss_weight=recognition_loss_weight
                if do_recognition
                else None,
                translation_loss_weight=translation_loss_weight
                if do_translation
                else None,
            )
            if do_recognition:
                total_recognition_loss += batch_recognition_loss
                total_num_gls_tokens += batch.num_gls_tokens
            if do_translation:
                total_translation_loss += batch_translation_loss
                total_num_txt_tokens += batch.num_txt_tokens
            total_num_seqs += batch.num_seqs

            (
                batch_gls_predictions,
                batch_txt_predictions,
                batch_attention_scores,
            ) = model.run_batch(
                batch=batch,
                recognition_beam_size=recognition_beam_size
                if do_recognition
                else None,
                translation_beam_size=translation_beam_size
                if do_translation
                else None,
                translation_beam_alpha=translation_beam_alpha
                if do_translation
                else None,
                translation_max_output_length=translation_max_output_length
                if do_translation
                else None,
            )

            # sort the outputs back to the original order
            if do_recognition:
                all_gls_outputs.extend(
                    [batch_gls_predictions[sri] for sri in sort_reverse_index]
                )
            if do_translation:
                all_txt_outputs.extend(batch_txt_predictions[sort_reverse_index])
            all_attention_scores.extend(
                batch_attention_scores[sort_reverse_index]
                if batch_attention_scores is not None
                else []
            )

        if do_recognition:
            assert len(all_gls_outputs) == len(data)
            if (
                recognition_loss_function is not None
                and recognition_loss_weight != 0
                and total_num_gls_tokens > 0
            ):
                valid_recognition_loss = total_recognition_loss
            else:
                valid_recognition_loss = -1

            # decode back to symbols
            decoded_gls = model.gls_vocab.arrays_to_sentences(arrays=all_gls_outputs)

            # gloss clean-up function
            if dataset_version == "phoenix_2014_trans":
                gls_cln_fn = clean_phoenix_2014_trans
            elif dataset_version == "phoenix_2014":
                gls_cln_fn = clean_phoenix_2014
            else:
                raise ValueError("Unknown Dataset Version: " + dataset_version)

            # construct gloss sequences for the metrics
            gls_ref = [gls_cln_fn(" ".join(t)) for t in data.gls]
".join(t)) for t in data.gls] gls_hyp = [gls_cln_fn(" ".join(t)) for t in decoded_gls] assert len(gls_ref) == len(gls_hyp) # GLS Metrics gls_wer_score = wer_list(hypotheses=gls_hyp, references=gls_ref) if do_translation: assert len(all_txt_outputs) == len(data) if (translation_loss_function is not None and translation_loss_weight != 0 and total_num_txt_tokens > 0): # total validation translation loss valid_translation_loss = total_translation_loss # exponent of token-level negative log prob valid_ppl = torch.exp(total_translation_loss / total_num_txt_tokens) else: valid_translation_loss = -1 valid_ppl = -1 # decode back to symbols decoded_txt = model.txt_vocab.arrays_to_sentences( arrays=all_txt_outputs) # evaluate with metric on full dataset join_char = " " if level in ["word", "bpe"] else "" # Construct text sequences for metrics txt_ref = [join_char.join(t) for t in data.txt] txt_hyp = [join_char.join(t) for t in decoded_txt] # post-process if level == "bpe": txt_ref = [bpe_postprocess(v) for v in txt_ref] txt_hyp = [bpe_postprocess(v) for v in txt_hyp] assert len(txt_ref) == len(txt_hyp) # TXT Metrics txt_bleu = bleu(references=txt_ref, hypotheses=txt_hyp) txt_chrf = chrf(references=txt_ref, hypotheses=txt_hyp) txt_rouge = rouge(references=txt_ref, hypotheses=txt_hyp) valid_scores = {} if do_recognition: valid_scores["wer"] = gls_wer_score["wer"] valid_scores["wer_scores"] = gls_wer_score if do_translation: valid_scores["bleu"] = txt_bleu["bleu4"] valid_scores["bleu_scores"] = txt_bleu valid_scores["chrf"] = txt_chrf valid_scores["rouge"] = txt_rouge results = { "valid_scores": valid_scores, "all_attention_scores": all_attention_scores, } if do_recognition: results["valid_recognition_loss"] = valid_recognition_loss results["decoded_gls"] = decoded_gls results["gls_ref"] = gls_ref results["gls_hyp"] = gls_hyp if do_translation: results["valid_translation_loss"] = valid_translation_loss results["valid_ppl"] = valid_ppl results["decoded_txt"] = decoded_txt results["txt_ref"] = txt_ref results["txt_hyp"] = txt_hyp return results