def __init__(self, model: SignModel, config: dict) -> None:
    """
    Creates a new TrainManager for a model, specified as in configuration.

    Reads the ``training``, ``data`` and ``model`` sections of ``config``.
    It then does the following:

    - creates the model directory, logger and TensorBoard writer;
    - builds the gradient clipper, optimizer and learning-rate scheduler;
    - moves the model and the enabled loss functions to the GPU when
      ``use_cuda`` is set;
    - restores state from ``training.load_model`` if that key is present.

    :param model: torch module defining the model
    :param config: dictionary containing the training configurations
    :raises ValueError: if ``eval_metric``, ``early_stopping_metric`` or the
        data ``level`` has an unsupported value, or if a translation metric
        (bleu/chrf/rouge) drives early stopping while translation is disabled
    """
    train_config = config["training"]

    # files for logging and storing
    self.model_dir = make_model_dir(
        train_config["model_dir"],
        overwrite=train_config.get("overwrite", False),
    )
    self.logger = make_logger(model_dir=self.model_dir)
    self.logging_freq = train_config.get("logging_freq", 100)
    self.valid_report_file = "{}/validations.txt".format(self.model_dir)
    self.tb_writer = SummaryWriter(log_dir=self.model_dir + "/tensorboard/")

    # input: a list of feature sizes is summed into one total dimension
    # (presumably the input streams are concatenated - confirm with model code)
    data_feature_size = config["data"]["feature_size"]
    self.feature_size = (
        sum(data_feature_size)
        if isinstance(data_feature_size, list)
        else data_feature_size
    )
    self.dataset_version = config["data"].get("version", "phoenix_2014_trans")

    # model
    self.model = model
    self.txt_pad_index = self.model.txt_pad_index
    self.txt_bos_index = self.model.txt_bos_index
    self._log_parameters_list()

    # A task (recognition / translation) is enabled when its loss weight,
    # defaulting to 1.0, is strictly positive.
    self.do_recognition = train_config.get("recognition_loss_weight", 1.0) > 0.0
    self.do_translation = train_config.get("translation_loss_weight", 1.0) > 0.0

    # Get Recognition and Translation specific parameters
    # (these helpers presumably set the per-task loss functions used below).
    if self.do_recognition:
        self._get_recognition_params(train_config=train_config)
    if self.do_translation:
        self._get_translation_params(train_config=train_config)

    # optimization
    self.last_best_lr = train_config.get("learning_rate", -1)
    self.learning_rate_min = train_config.get("learning_rate_min", 1.0e-8)
    self.clip_grad_fun = build_gradient_clipper(config=train_config)
    self.optimizer = build_optimizer(
        config=train_config, parameters=model.parameters()
    )
    self.batch_multiplier = train_config.get("batch_multiplier", 1)

    # validation & early stopping
    self.validation_freq = train_config.get("validation_freq", 100)
    self.num_valid_log = train_config.get("num_valid_log", 5)
    self.ckpt_queue = queue.Queue(maxsize=train_config.get("keep_last_ckpts", 5))
    self.eval_metric = train_config.get("eval_metric", "bleu")
    if self.eval_metric not in ["bleu", "chrf", "wer", "rouge"]:
        raise ValueError(
            "Invalid setting for 'eval_metric': {}".format(self.eval_metric)
        )
    self.early_stopping_metric = train_config.get(
        "early_stopping_metric", "eval_metric"
    )

    # early_stopping_metric decides on how to find the early stopping point:
    # ckpts are written when there's a new high/low score for this metric.
    # BLEU/chrF/ROUGE are maximized; losses, perplexity and WER are minimized.
    if self.early_stopping_metric in ["ppl", "translation_loss", "recognition_loss"]:
        self.minimize_metric = True
    elif self.early_stopping_metric == "eval_metric":
        if self.eval_metric in ["bleu", "chrf", "rouge"]:
            # Explicit check rather than ``assert`` so the validation is not
            # stripped when Python runs with -O.
            if not self.do_translation:
                raise ValueError(
                    "eval_metric '{}' requires translation to be enabled "
                    "(translation_loss_weight > 0)".format(self.eval_metric)
                )
            self.minimize_metric = False
        else:
            # only "wer" can reach here (eval_metric was validated above)
            self.minimize_metric = True
    else:
        raise ValueError(
            "Invalid setting for 'early_stopping_metric': {}".format(
                self.early_stopping_metric
            )
        )

    # data_augmentation parameters
    self.frame_subsampling_ratio = config["data"].get(
        "frame_subsampling_ratio", None
    )
    self.random_frame_subsampling = config["data"].get(
        "random_frame_subsampling", None
    )
    self.random_frame_masking_ratio = config["data"].get(
        "random_frame_masking_ratio", None
    )

    # learning rate scheduling; its mode follows the early-stopping direction
    self.scheduler, self.scheduler_step_at = build_scheduler(
        config=train_config,
        scheduler_mode="min" if self.minimize_metric else "max",
        optimizer=self.optimizer,
        hidden_size=config["model"]["encoder"]["hidden_size"],
    )

    # data & batch handling
    self.level = config["data"]["level"]
    if self.level not in ["word", "bpe", "char"]:
        raise ValueError("Invalid segmentation level: {}".format(self.level))

    self.shuffle = train_config.get("shuffle", True)
    self.epochs = train_config["epochs"]
    self.batch_size = train_config["batch_size"]
    self.batch_type = train_config.get("batch_type", "sentence")
    self.eval_batch_size = train_config.get("eval_batch_size", self.batch_size)
    self.eval_batch_type = train_config.get("eval_batch_type", self.batch_type)

    self.use_cuda = train_config["use_cuda"]
    if self.use_cuda:
        self.model.cuda()
        if self.do_translation:
            self.translation_loss_function.cuda()
        if self.do_recognition:
            self.recognition_loss_function.cuda()

    # initialize training statistics
    self.steps = 0
    # stop training if this flag is True by reaching learning rate minimum
    self.stop = False
    self.total_txt_tokens = 0
    self.total_gls_tokens = 0
    self.best_ckpt_iteration = 0
    # initial values for best scores
    self.best_ckpt_score = np.inf if self.minimize_metric else -np.inf
    self.best_all_ckpt_scores = {}
    # comparison function for scores; reads best_ckpt_score at call time
    self.is_best = lambda score: (
        score < self.best_ckpt_score
        if self.minimize_metric
        else score > self.best_ckpt_score
    )

    # model parameters
    if "load_model" in train_config:
        model_load_path = train_config["load_model"]
        self.logger.info("Loading model from %s", model_load_path)
        self.init_from_checkpoint(
            model_load_path,
            reset_best_ckpt=train_config.get("reset_best_ckpt", False),
            reset_scheduler=train_config.get("reset_scheduler", False),
            reset_optimizer=train_config.get("reset_optimizer", False),
        )
def validate_on_data(
    model: SignModel,
    data: Dataset,
    batch_size: int,
    use_cuda: bool,
    sgn_dim: int,
    do_recognition: bool,
    recognition_loss_function: torch.nn.Module,
    recognition_loss_weight: int,
    do_translation: bool,
    translation_loss_function: torch.nn.Module,
    translation_loss_weight: int,
    translation_max_output_length: int,
    level: str,
    txt_pad_index: int,
    fusion_type: str,
    recognition_beam_size: int = 1,
    translation_beam_size: int = 1,
    translation_beam_alpha: int = -1,
    batch_type: str = "sentence",
    dataset_version: str = "phoenix_2014_trans",
    frame_subsampling_ratio: int = None,
) -> dict:
    """
    Generate gloss and/or text predictions for the given data and score them.

    The recognition and translation losses are accumulated only when the
    corresponding loss function is given, its weight is non-zero and there
    are reference tokens.

    :param model: model module
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param use_cuda: if True, use CUDA
    :param sgn_dim: Feature dimension of sgn frames
    :param do_recognition: flag for predicting glosses
    :param recognition_loss_function: recognition loss function (CTC)
    :param recognition_loss_weight: CTC loss weight
    :param do_translation: flag for predicting text
    :param translation_loss_function: translation loss function (XEntropy)
    :param translation_loss_weight: Translation loss weight
    :param translation_max_output_length: maximum length for generated
        hypotheses
    :param level: segmentation level, one of "char", "bpe", "word"
    :param txt_pad_index: txt padding token index
    :param fusion_type: forwarded unchanged to ``Batch`` and
        ``model.get_loss_for_batch``
    :param recognition_beam_size: beam size for validation (recognition,
        i.e. CTC). If 0 then greedy decoding (default).
    :param translation_beam_size: beam size for validation (translation).
        If 0 then greedy decoding (default).
    :param translation_beam_alpha: beam search alpha for length penalty
        (translation), disabled if set to -1 (default).
    :param batch_type: validation batch type (sentence or token)
    :param dataset_version: phoenix_2014 or phoenix_2014_trans; selects the
        gloss clean-up function (only used when ``do_recognition``)
    :param frame_subsampling_ratio: frame subsampling ratio
    :return: dict with the following keys:

        - "valid_scores": dict containing "wer" and "wer_scores" if
          recognition is enabled, and "bleu" (BLEU-4), "bleu_scores",
          "chrf" and "rouge" if translation is enabled
        - "all_attention_scores": attention scores in original data order
        - if recognition: "valid_recognition_loss" (summed over batches,
          or -1), "decoded_gls", "gls_ref", "gls_hyp"
        - if translation: "valid_translation_loss" (summed over batches,
          or -1), "valid_ppl" (exp of the loss per text token, or -1),
          "decoded_txt", "txt_ref", "txt_hyp"
    :raises ValueError: if ``do_recognition`` and ``dataset_version`` is
        unknown
    """
    # Resolve the gloss clean-up function before decoding, so an unknown
    # dataset version fails fast instead of after a full validation pass.
    gls_cln_fn = None
    if do_recognition:
        if dataset_version == "phoenix_2014_trans":
            gls_cln_fn = clean_phoenix_2014_trans
        elif dataset_version == "phoenix_2014":
            gls_cln_fn = clean_phoenix_2014
        else:
            raise ValueError("Unknown Dataset Version: " + dataset_version)

    valid_iter = make_data_iter(
        dataset=data,
        batch_size=batch_size,
        batch_type=batch_type,
        shuffle=False,
        train=False,
    )

    # disable dropout
    model.eval()
    # don't track gradients during validation
    with torch.no_grad():
        all_gls_outputs = []
        all_txt_outputs = []
        all_attention_scores = []
        total_recognition_loss = 0
        total_translation_loss = 0
        total_num_txt_tokens = 0
        total_num_gls_tokens = 0
        total_num_seqs = 0
        for valid_batch in valid_iter:
            batch = Batch(
                is_train=False,
                torch_batch=valid_batch,
                txt_pad_index=txt_pad_index,
                sgn_dim=sgn_dim,
                fusion_type=fusion_type,
                use_cuda=use_cuda,
                frame_subsampling_ratio=frame_subsampling_ratio,
            )
            # batches are sorted by sign length for the model; the returned
            # index restores the original order of the outputs below
            sort_reverse_index = batch.sort_by_sgn_lengths()

            batch_recognition_loss, batch_translation_loss = model.get_loss_for_batch(
                batch=batch,
                fusion_type=fusion_type,
                recognition_loss_function=recognition_loss_function
                if do_recognition
                else None,
                translation_loss_function=translation_loss_function
                if do_translation
                else None,
                recognition_loss_weight=recognition_loss_weight
                if do_recognition
                else None,
                translation_loss_weight=translation_loss_weight
                if do_translation
                else None,
            )
            if do_recognition:
                total_recognition_loss += batch_recognition_loss
                total_num_gls_tokens += batch.num_gls_tokens
            if do_translation:
                total_translation_loss += batch_translation_loss
                total_num_txt_tokens += batch.num_txt_tokens
            total_num_seqs += batch.num_seqs

            (
                batch_gls_predictions,
                batch_txt_predictions,
                batch_attention_scores,
            ) = model.run_batch(
                batch=batch,
                recognition_beam_size=recognition_beam_size
                if do_recognition
                else None,
                translation_beam_size=translation_beam_size
                if do_translation
                else None,
                translation_beam_alpha=translation_beam_alpha
                if do_translation
                else None,
                translation_max_output_length=translation_max_output_length
                if do_translation
                else None,
            )

            # sort outputs back to original order
            if do_recognition:
                all_gls_outputs.extend(
                    [batch_gls_predictions[sri] for sri in sort_reverse_index]
                )
            if do_translation:
                all_txt_outputs.extend(batch_txt_predictions[sort_reverse_index])
            all_attention_scores.extend(
                batch_attention_scores[sort_reverse_index]
                if batch_attention_scores is not None
                else []
            )

        if do_recognition:
            assert len(all_gls_outputs) == len(data)
            if (
                recognition_loss_function is not None
                and recognition_loss_weight != 0
                and total_num_gls_tokens > 0
            ):
                valid_recognition_loss = total_recognition_loss
            else:
                valid_recognition_loss = -1
            # decode back to symbols
            decoded_gls = model.gls_vocab.arrays_to_sentences(
                arrays=all_gls_outputs
            )

            # Construct gloss sequences for metrics
            gls_ref = [gls_cln_fn(" ".join(t)) for t in data.gls]
            gls_hyp = [gls_cln_fn(" ".join(t)) for t in decoded_gls]
            assert len(gls_ref) == len(gls_hyp)

            # GLS Metrics
            gls_wer_score = wer_list(hypotheses=gls_hyp, references=gls_ref)

        if do_translation:
            assert len(all_txt_outputs) == len(data)
            if (
                translation_loss_function is not None
                and translation_loss_weight != 0
                and total_num_txt_tokens > 0
            ):
                # total validation translation loss
                valid_translation_loss = total_translation_loss
                # exponent of token-level negative log prob
                valid_ppl = torch.exp(total_translation_loss / total_num_txt_tokens)
            else:
                valid_translation_loss = -1
                valid_ppl = -1
            # decode back to symbols
            decoded_txt = model.txt_vocab.arrays_to_sentences(
                arrays=all_txt_outputs
            )
            # evaluate with metric on full dataset
            join_char = " " if level in ["word", "bpe"] else ""
            # Construct text sequences for metrics
            txt_ref = [join_char.join(t) for t in data.txt]
            txt_hyp = [join_char.join(t) for t in decoded_txt]
            # post-process
            if level == "bpe":
                txt_ref = [bpe_postprocess(v) for v in txt_ref]
                txt_hyp = [bpe_postprocess(v) for v in txt_hyp]
            assert len(txt_ref) == len(txt_hyp)

            # TXT Metrics
            txt_bleu = bleu(references=txt_ref, hypotheses=txt_hyp)
            txt_chrf = chrf(references=txt_ref, hypotheses=txt_hyp)
            txt_rouge = rouge(references=txt_ref, hypotheses=txt_hyp)

    valid_scores = {}
    if do_recognition:
        valid_scores["wer"] = gls_wer_score["wer"]
        valid_scores["wer_scores"] = gls_wer_score
    if do_translation:
        valid_scores["bleu"] = txt_bleu["bleu4"]
        valid_scores["bleu_scores"] = txt_bleu
        valid_scores["chrf"] = txt_chrf
        valid_scores["rouge"] = txt_rouge

    results = {
        "valid_scores": valid_scores,
        "all_attention_scores": all_attention_scores,
    }

    if do_recognition:
        results["valid_recognition_loss"] = valid_recognition_loss
        results["decoded_gls"] = decoded_gls
        results["gls_ref"] = gls_ref
        results["gls_hyp"] = gls_hyp

    if do_translation:
        results["valid_translation_loss"] = valid_translation_loss
        results["valid_ppl"] = valid_ppl
        results["decoded_txt"] = decoded_txt
        results["txt_ref"] = txt_ref
        results["txt_hyp"] = txt_hyp

    return results