def init_from_checkpoint(
    self,
    path: str,
    reset_best_ckpt: bool = False,
    reset_scheduler: bool = False,
    reset_optimizer: bool = False,
) -> None:
    """
    Restore the trainer state from a checkpoint file.

    The checkpoint (see `self._save_checkpoint`) holds the model parameters,
    the optimizer and scheduler states, the step/token counters and the
    best-checkpoint bookkeeping. The model parameters and counters are always
    restored; the other groups can be skipped with the ``reset_*`` flags.

    :param path: path to checkpoint
    :param reset_best_ckpt: if True, do not restore the best-checkpoint
        tracking from the checkpoint (use for domain adaptation with a new dev
        set or when using a new metric for fine-tuning)
    :param reset_scheduler: if True, do not restore the learning rate scheduler
        state stored in the checkpoint
    :param reset_optimizer: if True, do not restore the optimizer state stored
        in the checkpoint
    """
    ckpt = load_checkpoint(path=path, use_cuda=self.use_cuda)

    # Model parameters are restored unconditionally.
    self.model.load_state_dict(ckpt["model_state"])

    if reset_optimizer:
        self.logger.info("Reset optimizer.")
    else:
        self.optimizer.load_state_dict(ckpt["optimizer_state"])

    if reset_scheduler:
        self.logger.info("Reset scheduler.")
    else:
        saved_scheduler_state = ckpt["scheduler_state"]
        # Restore only when the checkpoint has a state AND a scheduler exists.
        if saved_scheduler_state is not None and self.scheduler is not None:
            self.scheduler.load_state_dict(saved_scheduler_state)

    # The checkpoint keys are named exactly like the trainer attributes.
    for counter_name in ("steps", "total_txt_tokens", "total_gls_tokens"):
        setattr(self, counter_name, ckpt[counter_name])

    if reset_best_ckpt:
        self.logger.info("Reset tracking of the best checkpoint.")
    else:
        for tracker_name in (
            "best_ckpt_score",
            "best_all_ckpt_scores",
            "best_ckpt_iteration",
        ):
            setattr(self, tracker_name, ckpt[tracker_name])

    # move parameters to cuda
    if self.use_cuda:
        self.model.cuda()
def feat_test(
    cfg_file,
    ckpt: str,
    output_path: str = None,
    logger: logging.Logger = None,
) -> None:
    """
    Evaluate a feature-based model checkpoint on the dev and test sets.

    The CTC (recognition) beam sizes and the translation beam sizes/alphas
    configured under ``testing`` are swept on the dev set. The best setting
    (lowest WER, highest BLEU) is then used to decode the test set. Results
    are logged; if ``output_path`` is given, the dev/test hypotheses and the
    pickled result dictionaries are written to files prefixed by it.

    :param cfg_file: path to configuration file
    :param ckpt: path to checkpoint to load; if None, the checkpoint returned
        by ``get_latest_checkpoint(cfg["training"]["model_dir"])`` is used
    :param output_path: prefix for output files; nothing is written when None
    :param logger: log output to this logger (creates new logger if not set)
    :raises ValueError: if no test data is configured, or if a beam-size or
        alpha list that would be swept is empty
    :raises FileNotFoundError: if ``ckpt`` is None and no checkpoint is found
    """
    if logger is None:
        logger = logging.getLogger(__name__)
        if not logger.handlers:
            FORMAT = "%(asctime)-15s - %(message)s"
            logging.basicConfig(format=FORMAT)
            logger.setLevel(level=logging.DEBUG)

    cfg = load_config(cfg_file)

    if "test" not in cfg["data"]:
        raise ValueError("Test data must be specified in config.")

    # when checkpoint is not specified, take latest (best) from model dir
    if ckpt is None:
        model_dir = cfg["training"]["model_dir"]
        ckpt = get_latest_checkpoint(model_dir)
        if ckpt is None:
            raise FileNotFoundError(
                "No checkpoint found in directory {}.".format(model_dir))

    batch_size = cfg["training"]["batch_size"]
    batch_type = cfg["training"].get("batch_type", "sentence")
    use_cuda = cfg["training"].get("use_cuda", False)
    level = cfg["data"]["level"]
    dataset_version = cfg["data"].get("version", "phoenix_2014_trans")
    translation_max_output_length = cfg["training"].get(
        "translation_max_output_length", None)

    # A sub-task is active when its training loss weight is positive.
    do_recognition = cfg["training"].get("recognition_loss_weight", 1.0) > 0.0
    do_translation = cfg["training"].get("translation_loss_weight", 1.0) > 0.0
    do_anchoring = cfg["training"].get("anchoring_loss_weight", 1.0) > 0.0

    # A list of feature sizes is summed into a single input dimension.
    feature_size = cfg["data"]["feature_size"]
    sgn_dim = sum(feature_size) if isinstance(feature_size, list) else feature_size

    # Data Augmentation Parameters
    frame_subsampling_ratio = cfg["data"].get("frame_subsampling_ratio", None)
    # Note (Cihan): we are not using 'random_frame_subsampling' and
    # 'random_frame_masking_ratio' in testing as they are just for training.

    # Decoding sweep settings (a beam size of 1 means greedy decoding).
    testing_cfg = cfg.get("testing", {})
    recognition_beam_sizes = testing_cfg.get("recognition_beam_sizes", [1])
    translation_beam_sizes = testing_cfg.get("translation_beam_sizes", [1])
    translation_beam_alphas = testing_cfg.get("translation_beam_alphas", [-1])
    max_recognition_beam_size = testing_cfg.get("max_recognition_beam_size", None)
    if max_recognition_beam_size is not None:
        recognition_beam_sizes = list(range(1, max_recognition_beam_size + 1))

    # Fail fast on an empty sweep: it would otherwise leave the "best" result
    # unset and only crash after the (expensive) evaluation has run.
    if do_recognition and not recognition_beam_sizes:
        raise ValueError("Recognition beam sizes to evaluate must not be empty.")
    if do_translation and not (translation_beam_sizes and translation_beam_alphas):
        raise ValueError(
            "Translation beam sizes and alphas to evaluate must not be empty.")

    # load dev data
    _, dev_data, _, gls_vocab, txt_vocab = load_feat_data(
        data_cfg=cfg["data"], sets=['dev'], dev_size=1, do_anchoring=do_anchoring)

    # load model state from disk
    model_checkpoint = load_checkpoint(ckpt, use_cuda=use_cuda)

    # build model and load parameters into it
    model = build_feat_model(
        cfg=cfg["model"],
        gls_vocab=gls_vocab,
        txt_vocab=txt_vocab,
        sgn_dim=sgn_dim,
        do_recognition=do_recognition,
        do_translation=do_translation,
        do_anchoring=do_anchoring,
    )
    model.load_state_dict(model_checkpoint["model_state"])
    if use_cuda:
        model.cuda()

    # Loss functions stay None for disabled sub-tasks; that is also the value
    # passed to validate_on_feat_data for them.
    recognition_loss_function = None
    translation_loss_function = None
    anchoring_loss_function = None
    if do_recognition:
        recognition_loss_function = torch.nn.CTCLoss(
            blank=model.gls_vocab.stoi[SIL_TOKEN], zero_infinity=True)
        if use_cuda:
            recognition_loss_function.cuda()
    if do_translation:
        translation_loss_function = XentLoss(
            pad_index=txt_vocab.stoi[PAD_TOKEN], smoothing=0.0)
        if use_cuda:
            translation_loss_function.cuda()
    if do_anchoring:
        anchoring_loss_function = AnchoringLoss()
        if use_cuda:
            anchoring_loss_function.cuda()

    # NOTE (Cihan): Currently Hardcoded to be 0 for TensorFlow decoding
    assert model.gls_vocab.stoi[SIL_TOKEN] == 0

    # Best dev settings; -1 / None mark a disabled sub-task in logs and files.
    dev_recognition_results = None
    dev_best_recognition_beam_size = -1
    dev_best_recognition_result = None
    dev_translation_results = None
    dev_best_translation_beam_size = -1
    dev_best_translation_alpha = -1
    dev_best_translation_result = None

    if do_recognition:
        # Dev Recognition CTC Beam Search Results
        dev_recognition_results = {}
        dev_best_wer_score = float("inf")
        for rbw in recognition_beam_sizes:
            logger.info("-" * 60)
            valid_start_time = time.time()
            logger.info("[DEV] partition [RECOGNITION] experiment [BW]: %d", rbw)
            # NOTE(review): unlike the translation sweep and the test run, no
            # anchoring arguments are passed here, so validate_on_feat_data's
            # defaults apply -- confirm this is intended.
            dev_recognition_results[rbw] = validate_on_feat_data(
                model=model,
                data=dev_data,
                batch_size=batch_size,
                use_cuda=use_cuda,
                batch_type=batch_type,
                dataset_version=dataset_version,
                sgn_dim=sgn_dim,
                txt_pad_index=txt_vocab.stoi[PAD_TOKEN],
                # Recognition Parameters
                do_recognition=do_recognition,
                recognition_loss_function=recognition_loss_function,
                recognition_loss_weight=1,
                recognition_beam_size=rbw,
                # Translation Parameters
                do_translation=do_translation,
                translation_loss_function=translation_loss_function,
                translation_loss_weight=1 if do_translation else None,
                translation_max_output_length=(
                    translation_max_output_length if do_translation else None),
                level=level if do_translation else None,
                translation_beam_size=1 if do_translation else None,
                translation_beam_alpha=-1 if do_translation else None,
                frame_subsampling_ratio=frame_subsampling_ratio,
            )
            logger.info("finished in %.4fs ", time.time() - valid_start_time)
            wer = dev_recognition_results[rbw]["valid_scores"]["wer"]
            # The first result is always accepted so a best result always exists.
            if dev_best_recognition_result is None or wer < dev_best_wer_score:
                dev_best_wer_score = wer
                dev_best_recognition_beam_size = rbw
                dev_best_recognition_result = dev_recognition_results[rbw]
                rec_scores = dev_best_recognition_result["valid_scores"]
                logger.info("*" * 60)
                logger.info(
                    "[DEV] partition [RECOGNITION] results:\n\t"
                    "New Best CTC Decode Beam Size: %d\n\t"
                    "WER %3.2f\t(DEL: %3.2f,\tINS: %3.2f,\tSUB: %3.2f)",
                    dev_best_recognition_beam_size,
                    rec_scores["wer"],
                    rec_scores["wer_scores"]["del_rate"],
                    rec_scores["wer_scores"]["ins_rate"],
                    rec_scores["wer_scores"]["sub_rate"],
                )
                logger.info("*" * 60)

    if do_translation:
        logger.info("=" * 60)
        dev_translation_results = {}
        dev_best_bleu_score = float("-inf")
        for tbw in translation_beam_sizes:
            dev_translation_results[tbw] = {}
            for ta in translation_beam_alphas:
                dev_translation_results[tbw][ta] = validate_on_feat_data(
                    model=model,
                    data=dev_data,
                    batch_size=batch_size,
                    use_cuda=use_cuda,
                    level=level,
                    sgn_dim=sgn_dim,
                    batch_type=batch_type,
                    dataset_version=dataset_version,
                    do_recognition=do_recognition,
                    recognition_loss_function=recognition_loss_function,
                    recognition_loss_weight=1 if do_recognition else None,
                    recognition_beam_size=1 if do_recognition else None,
                    do_translation=do_translation,
                    translation_loss_function=translation_loss_function,
                    translation_loss_weight=1,
                    translation_max_output_length=translation_max_output_length,
                    txt_pad_index=txt_vocab.stoi[PAD_TOKEN],
                    translation_beam_size=tbw,
                    translation_beam_alpha=ta,
                    frame_subsampling_ratio=frame_subsampling_ratio,
                    do_anchoring=do_anchoring,
                    anchoring_loss_function=anchoring_loss_function,
                    anchoring_loss_weight=1 if do_anchoring else None,
                )
                bleu = dev_translation_results[tbw][ta]["valid_scores"]["bleu"]
                # The first result is always accepted so a best result exists.
                if dev_best_translation_result is None or bleu > dev_best_bleu_score:
                    dev_best_bleu_score = bleu
                    dev_best_translation_beam_size = tbw
                    dev_best_translation_alpha = ta
                    dev_best_translation_result = dev_translation_results[tbw][ta]
                    txt_scores = dev_best_translation_result["valid_scores"]
                    # Alpha is formatted with %s: %d would truncate float alphas.
                    logger.info(
                        "[DEV] partition [Translation] results:\n\t"
                        "New Best Translation Beam Size: %d and Alpha: %s\n\t"
                        "BLEU-4 %.2f\t(BLEU-1: %.2f,\tBLEU-2: %.2f,\tBLEU-3: %.2f,\tBLEU-4: %.2f)\n\t"
                        "CHRF %.2f\t"
                        "ROUGE %.2f",
                        dev_best_translation_beam_size,
                        dev_best_translation_alpha,
                        txt_scores["bleu"],
                        txt_scores["bleu_scores"]["bleu1"],
                        txt_scores["bleu_scores"]["bleu2"],
                        txt_scores["bleu_scores"]["bleu3"],
                        txt_scores["bleu_scores"]["bleu4"],
                        txt_scores["chrf"],
                        txt_scores["rouge"],
                    )
                    logger.info("-" * 60)

    logger.info("*" * 60)
    _log_feat_test_results(
        logger,
        "DEV",
        recognition_beam_size=dev_best_recognition_beam_size,
        translation_beam_size=dev_best_translation_beam_size,
        translation_beam_alpha=dev_best_translation_alpha,
        recognition_result=dev_best_recognition_result,
        translation_result=dev_best_translation_result,
    )
    logger.info("*" * 60)

    def _write_to_file(file_path: str, sequence_ids: List[str], hypotheses: List[str]):
        # One "<sequence id>|<hypothesis>" line per sample.
        with open(file_path, mode="w", encoding="utf-8") as out_file:
            for seq, hyp in zip(sequence_ids, hypotheses):
                out_file.write(seq + "|" + hyp + "\n")

    def _gls_output_path(partition: str) -> str:
        # Gloss hypothesis file name, tagged with the best CTC beam size.
        return "{}.BW_{:03d}.{}.gls".format(
            output_path, dev_best_recognition_beam_size, partition)

    def _txt_output_path(partition: str) -> str:
        # Text hypothesis file name, tagged with the best beam size and alpha.
        # NOTE(review): unless the config lists beam sizes <= -1 this condition
        # is always true, so the second format looks unreachable (perhaps the
        # check was meant for the alpha); kept to preserve existing file names.
        if dev_best_translation_beam_size > -1:
            # "{}" instead of "{:1d}": float alphas no longer raise ValueError,
            # integer alphas give exactly the same name as before.
            return "{}.BW_{:02d}.A_{}.{}.txt".format(
                output_path,
                dev_best_translation_beam_size,
                dev_best_translation_alpha,
                partition,
            )
        return "{}.BW_{:02d}.{}.txt".format(
            output_path, dev_best_translation_beam_size, partition)

    if output_path is not None:
        if do_recognition:
            _write_to_file(
                _gls_output_path("dev"),
                list(dev_data.sequence),
                dev_best_recognition_result["gls_hyp"],
            )
        if do_translation:
            _write_to_file(
                _txt_output_path("dev"),
                list(dev_data.sequence),
                dev_best_translation_result["txt_hyp"],
            )

    # Drop the reference to the dev set before the test set is loaded.
    del dev_data

    # load test data
    _, _, test_data, gls_vocab, txt_vocab = load_feat_data(
        data_cfg=cfg["data"], sets=['test'], dev_size=1, do_anchoring=do_anchoring)

    # Decode the test set with the best dev settings.
    test_best_result = validate_on_feat_data(
        model=model,
        data=test_data,
        batch_size=batch_size,
        use_cuda=use_cuda,
        batch_type=batch_type,
        dataset_version=dataset_version,
        sgn_dim=sgn_dim,
        txt_pad_index=txt_vocab.stoi[PAD_TOKEN],
        do_recognition=do_recognition,
        recognition_loss_function=recognition_loss_function,
        recognition_loss_weight=1 if do_recognition else None,
        recognition_beam_size=(
            dev_best_recognition_beam_size if do_recognition else None),
        do_translation=do_translation,
        translation_loss_function=translation_loss_function,
        translation_loss_weight=1 if do_translation else None,
        translation_max_output_length=(
            translation_max_output_length if do_translation else None),
        level=level if do_translation else None,
        translation_beam_size=(
            dev_best_translation_beam_size if do_translation else None),
        translation_beam_alpha=(
            dev_best_translation_alpha if do_translation else None),
        frame_subsampling_ratio=frame_subsampling_ratio,
        do_anchoring=do_anchoring,
        anchoring_loss_function=anchoring_loss_function,
        anchoring_loss_weight=1 if do_anchoring else None,
    )

    _log_feat_test_results(
        logger,
        "TEST",
        recognition_beam_size=dev_best_recognition_beam_size,
        translation_beam_size=dev_best_translation_beam_size,
        translation_beam_alpha=dev_best_translation_alpha,
        recognition_result=test_best_result if do_recognition else None,
        translation_result=test_best_result if do_translation else None,
    )
    logger.info("*" * 60)

    if output_path is not None:
        if do_recognition:
            _write_to_file(
                _gls_output_path("test"),
                list(test_data.sequence),
                test_best_result["gls_hyp"],
            )
        if do_translation:
            _write_to_file(
                _txt_output_path("test"),
                list(test_data.sequence),
                test_best_result["txt_hyp"],
            )

        # Raw result dictionaries (None for disabled sub-tasks).
        with open(output_path + ".dev_results.pkl", "wb") as out:
            pickle.dump(
                {
                    "recognition_results": dev_recognition_results,
                    "translation_results": dev_translation_results,
                },
                out,
            )
        with open(output_path + ".test_results.pkl", "wb") as out:
            pickle.dump(test_best_result, out)


def _log_feat_test_results(
    logger,
    partition,
    recognition_beam_size,
    translation_beam_size,
    translation_beam_alpha,
    recognition_result,
    translation_result,
):
    """
    Log the combined recognition & translation scores of one partition.

    :param logger: logger to write to
    :param partition: tag shown in the message, e.g. "DEV" or "TEST"
    :param recognition_beam_size: CTC beam size used (-1 if recognition is off)
    :param translation_beam_size: translation beam size used (-1 if off)
    :param translation_beam_alpha: translation alpha used (-1 if off)
    :param recognition_result: result dict from `validate_on_feat_data` with
        the WER scores, or None to report them as -1
    :param translation_result: result dict with the BLEU/CHRF/ROUGE scores,
        or None to report them as -1
    """
    if recognition_result is None:
        wer_args = (-1,) * 4
    else:
        rec_scores = recognition_result["valid_scores"]
        wer_args = (
            rec_scores["wer"],
            rec_scores["wer_scores"]["del_rate"],
            rec_scores["wer_scores"]["ins_rate"],
            rec_scores["wer_scores"]["sub_rate"],
        )
    if translation_result is None:
        txt_args = (-1,) * 7
    else:
        txt_scores = translation_result["valid_scores"]
        txt_args = (
            txt_scores["bleu"],
            txt_scores["bleu_scores"]["bleu1"],
            txt_scores["bleu_scores"]["bleu2"],
            txt_scores["bleu_scores"]["bleu3"],
            txt_scores["bleu_scores"]["bleu4"],
            txt_scores["chrf"],
            txt_scores["rouge"],
        )
    # Alpha is formatted with %s: %d would truncate float alphas.
    message = (
        "[" + partition + "] partition [Recognition & Translation] results:\n\t"
        + (
            "Best CTC Decode Beam Size: %d\n\t"
            "Best Translation Beam Size: %d and Alpha: %s\n\t"
            "WER %3.2f\t(DEL: %3.2f,\tINS: %3.2f,\tSUB: %3.2f)\n\t"
            "BLEU-4 %.2f\t(BLEU-1: %.2f,\tBLEU-2: %.2f,\tBLEU-3: %.2f,\tBLEU-4: %.2f)\n\t"
            "CHRF %.2f\t"
            "ROUGE %.2f"
        )
    )
    logger.info(
        message,
        recognition_beam_size,
        translation_beam_size,
        translation_beam_alpha,
        *wer_args,
        *txt_args,
    )