def train(cfg_file: str) -> None:
    """
    Run a complete training job and evaluate the resulting best checkpoint.

    The configuration is loaded, the random seed is fixed, data and
    vocabularies are loaded, the model is built and handed to a
    ``TrainManager``. A copy of the configuration and both vocabularies are
    written out, the model is trained with validation, and finally
    :func:`test` is invoked on the best checkpoint.

    :param cfg_file: path to configuration yaml file
    """
    cfg = load_config(cfg_file)

    # set the random seed
    set_seed(seed=cfg["training"].get("random_seed", 42))

    train_data, dev_data, test_data, gls_vocab, txt_vocab = load_data(
        data_cfg=cfg["data"])

    # A task is enabled as long as its loss weight is strictly positive.
    do_recognition = cfg["training"].get("recognition_loss_weight", 1.0) > 0.0
    do_translation = cfg["training"].get("translation_loss_weight", 1.0) > 0.0

    # build model and load parameters into it
    model_cfg = cfg["model"]
    feature_size = cfg["data"]["feature_size"]
    if isinstance(feature_size, list):
        # a list of feature sizes is collapsed into one input dimension
        sgn_dim = sum(feature_size)
    else:
        sgn_dim = feature_size
    # NOTE(review): test() in this file adds extra input dims to sgn_dim when
    # the fusion type is 'early_fusion'; this function does not — confirm the
    # two are meant to differ.
    model = build_model(
        cfg=model_cfg,
        gls_vocab=gls_vocab,
        txt_vocab=txt_vocab,
        sgn_dim=sgn_dim,
        do_recognition=do_recognition,
        do_translation=do_translation,
    )

    # for training management, e.g. early stopping and model selection
    trainer = TrainManager(model=model, config=cfg)

    # store copy of original training config in model dir
    shutil.copy2(cfg_file, trainer.model_dir + "/config.yaml")

    # log all entries of config, a data summary and the model itself
    log_cfg(cfg, trainer.logger)
    log_data_info(
        train_data=train_data,
        valid_data=dev_data,
        test_data=test_data,
        gls_vocab=gls_vocab,
        txt_vocab=txt_vocab,
        logging_function=trainer.logger.info,
    )
    trainer.logger.info(str(model))

    # store the vocabs (gloss vocabulary first, then text vocabulary)
    for vocab, path_template in (
        (gls_vocab, "{}/gls.vocab"),
        (txt_vocab, "{}/txt.vocab"),
    ):
        vocab.to_file(path_template.format(cfg["training"]["model_dir"]))

    # train the model
    trainer.train_and_validate(train_data=train_data, valid_data=dev_data)

    # Delete to speed things up as we don't need training data anymore;
    # test() loads the data it needs itself.
    del train_data, dev_data, test_data

    # predict with the best model on validation and test
    ckpt = "{}/{}.ckpt".format(trainer.model_dir, trainer.best_ckpt_iteration)
    output_name = "best.IT_{:08d}".format(trainer.best_ckpt_iteration)
    output_path = os.path.join(trainer.model_dir, output_name)

    # keep the logger alive, but release the trainer before evaluation
    logger = trainer.logger
    del trainer

    test(cfg_file, ckpt=ckpt, output_path=output_path, logger=logger)
def test(cfg_file, ckpt: str, output_path: str = None, logger: logging.Logger = None) -> None:
    """
    Main test function. Loads a model from a checkpoint, searches the
    configured decoding settings on the dev set, evaluates the test set with
    the best dev settings and, if ``output_path`` is given, writes the
    hypotheses and pickled results next to it.

    :param cfg_file: path to configuration file
    :param ckpt: path to checkpoint to load; if None, the latest checkpoint
        in the configured model directory is used
    :param output_path: path prefix for output files (nothing is written
        if None)
    :param logger: log output to this logger (creates new logger if not set)
    :raises ValueError: if the config has no test data, or if no dev result
        could be selected for an enabled task (e.g. empty beam-size list or
        scores that never compare better than the initial sentinel)
    :raises FileNotFoundError: if ``ckpt`` is None and no checkpoint is found
    """
    if logger is None:
        logger = logging.getLogger(__name__)
        if not logger.handlers:
            FORMAT = "%(asctime)-15s - %(message)s"
            logging.basicConfig(format=FORMAT)
            logger.setLevel(level=logging.DEBUG)

    cfg = load_config(cfg_file)

    if "test" not in cfg["data"]:
        raise ValueError("Test data must be specified in config.")

    # when checkpoint is not specified, take latest (best) from model dir
    if ckpt is None:
        model_dir = cfg["training"]["model_dir"]
        ckpt = get_latest_checkpoint(model_dir)
        if ckpt is None:
            raise FileNotFoundError(
                "No checkpoint found in directory {}.".format(model_dir))

    batch_size = cfg["training"]["batch_size"]
    batch_type = cfg["training"].get("batch_type", "sentence")
    use_cuda = cfg["training"].get("use_cuda", False)
    level = cfg["data"]["level"]
    dataset_version = cfg["data"].get("version", "phoenix_2014_trans")
    translation_max_output_length = cfg["training"].get(
        "translation_max_output_length", None)

    # load the data
    _, dev_data, test_data, gls_vocab, txt_vocab = load_data(
        data_cfg=cfg["data"])

    # load model state from disk
    model_checkpoint = load_checkpoint(ckpt, use_cuda=use_cuda)

    # build model and load parameters into it
    do_recognition = cfg["training"].get("recognition_loss_weight", 1.0) > 0.0
    do_translation = cfg["training"].get("translation_loss_weight", 1.0) > 0.0

    # Base input dimension; a list of feature sizes is summed into one value.
    feature_size = cfg["data"]["feature_size"]
    sgn_dim = sum(feature_size) if isinstance(feature_size, list) \
        else feature_size

    # Fusion type handed to validate_on_data.
    model_fusion_type = cfg["model"]["fusion_type"]
    # The model-building decision originally read a top-level "fusion_type"
    # key while everything else reads cfg["model"]["fusion_type"]. Keep the
    # top-level key when present and fall back to the model section instead
    # of failing with a KeyError.
    # NOTE(review): train() in this file builds the model without add_dim —
    # confirm checkpoints trained with early fusion match this input size.
    build_fusion_type = cfg.get("fusion_type", model_fusion_type)
    # Extra input dims for early fusion (constant kept from the original;
    # presumably the size of additional fused features — TODO confirm).
    add_dim = 2 * 84 + 2 * 21 + 2 * 13 \
        if build_fusion_type == "early_fusion" else 0

    model = build_model(
        cfg=cfg["model"],
        gls_vocab=gls_vocab,
        txt_vocab=txt_vocab,
        sgn_dim=sgn_dim + add_dim,
        do_recognition=do_recognition,
        do_translation=do_translation,
    )
    model.load_state_dict(model_checkpoint["model_state"])

    if use_cuda:
        model.cuda()

    # Data Augmentation Parameters
    frame_subsampling_ratio = cfg["data"].get("frame_subsampling_ratio", None)
    # Note (Cihan): we are not using 'random_frame_subsampling' and
    #   'random_frame_masking_ratio' in testing as they are just for training.

    # whether to use beam search for decoding, 0: greedy decoding
    testing_cfg = cfg.get("testing") or {}
    recognition_beam_sizes = testing_cfg.get("recognition_beam_sizes", [1])
    translation_beam_sizes = testing_cfg.get("translation_beam_sizes", [1])
    translation_beam_alphas = testing_cfg.get("translation_beam_alphas", [-1])
    # 'max_recognition_beam_size' overrides the explicit list with 1..max.
    max_recognition_beam_size = testing_cfg.get(
        "max_recognition_beam_size", None)
    if max_recognition_beam_size is not None:
        recognition_beam_sizes = list(
            range(1, max_recognition_beam_size + 1))

    txt_pad_index = txt_vocab.stoi[PAD_TOKEN]

    # Loss functions stay None for a disabled task.
    recognition_loss_function = None
    translation_loss_function = None
    if do_recognition:
        recognition_loss_function = torch.nn.CTCLoss(
            blank=model.gls_vocab.stoi[SIL_TOKEN], zero_infinity=True)
        if use_cuda:
            recognition_loss_function.cuda()
    if do_translation:
        translation_loss_function = XentLoss(
            pad_index=txt_pad_index, smoothing=0.0)
        if use_cuda:
            translation_loss_function.cuda()

    # NOTE (Cihan): Currently Hardcoded to be 0 for TensorFlow decoding
    assert model.gls_vocab.stoi[SIL_TOKEN] == 0

    def _evaluate(data, recognition_beam_size, translation_beam_size,
                  translation_beam_alpha):
        # Single entry point for validate_on_data: only the data partition
        # and the decoding settings vary between calls; task-specific
        # arguments are None whenever the corresponding task is disabled.
        return validate_on_data(
            model=model,
            data=data,
            batch_size=batch_size,
            use_cuda=use_cuda,
            batch_type=batch_type,
            dataset_version=dataset_version,
            sgn_dim=sgn_dim,
            txt_pad_index=txt_pad_index,
            fusion_type=model_fusion_type,
            # Recognition Parameters
            do_recognition=do_recognition,
            recognition_loss_function=recognition_loss_function,
            recognition_loss_weight=1 if do_recognition else None,
            recognition_beam_size=recognition_beam_size,
            # Translation Parameters
            do_translation=do_translation,
            translation_loss_function=translation_loss_function,
            translation_loss_weight=1 if do_translation else None,
            translation_max_output_length=translation_max_output_length
            if do_translation else None,
            level=level if do_translation else None,
            translation_beam_size=translation_beam_size,
            translation_beam_alpha=translation_beam_alpha,
            frame_subsampling_ratio=frame_subsampling_ratio,
        )

    def _log_summary(partition, recognition_scores, translation_scores):
        # Combined recognition/translation report; a disabled task is
        # reported as -1 throughout.
        wer_scores = recognition_scores["wer_scores"] \
            if do_recognition else None
        bleu_scores = translation_scores["bleu_scores"] \
            if do_translation else None
        logger.info(
            "[%s] partition [Recognition & Translation] results:\n\t"
            "Best CTC Decode Beam Size: %d\n\t"
            "Best Translation Beam Size: %d and Alpha: %d\n\t"
            "WER %3.2f\t(DEL: %3.2f,\tINS: %3.2f,\tSUB: %3.2f)\n\t"
            "BLEU-4 %.2f\t(BLEU-1: %.2f,\tBLEU-2: %.2f,\tBLEU-3: %.2f,\tBLEU-4: %.2f)\n\t"
            "CHRF %.2f\t"
            "ROUGE %.2f",
            partition,
            dev_best_recognition_beam_size if do_recognition else -1,
            dev_best_translation_beam_size if do_translation else -1,
            dev_best_translation_alpha if do_translation else -1,
            recognition_scores["wer"] if do_recognition else -1,
            wer_scores["del_rate"] if do_recognition else -1,
            wer_scores["ins_rate"] if do_recognition else -1,
            wer_scores["sub_rate"] if do_recognition else -1,
            translation_scores["bleu"] if do_translation else -1,
            bleu_scores["bleu1"] if do_translation else -1,
            bleu_scores["bleu2"] if do_translation else -1,
            bleu_scores["bleu3"] if do_translation else -1,
            bleu_scores["bleu4"] if do_translation else -1,
            translation_scores["chrf"] if do_translation else -1,
            translation_scores["rouge"] if do_translation else -1,
        )

    # Dev Recognition CTC Beam Search Results
    dev_recognition_results = {}
    dev_best_recognition_beam_size = 1
    dev_best_recognition_result = None
    if do_recognition:
        dev_best_wer_score = float("inf")
        for rbw in recognition_beam_sizes:
            logger.info("-" * 60)
            valid_start_time = time.time()
            logger.info("[DEV] partition [RECOGNITION] experiment [BW]: %d",
                        rbw)
            rbw_result = _evaluate(
                dev_data,
                recognition_beam_size=rbw,
                translation_beam_size=1 if do_translation else None,
                translation_beam_alpha=-1 if do_translation else None,
            )
            dev_recognition_results[rbw] = rbw_result
            logger.info("finished in %.4fs ", time.time() - valid_start_time)

            rbw_scores = rbw_result["valid_scores"]
            if rbw_scores["wer"] < dev_best_wer_score:
                dev_best_wer_score = rbw_scores["wer"]
                dev_best_recognition_beam_size = rbw
                dev_best_recognition_result = rbw_result
                logger.info("*" * 60)
                logger.info(
                    "[DEV] partition [RECOGNITION] results:\n\t"
                    "New Best CTC Decode Beam Size: %d\n\t"
                    "WER %3.2f\t(DEL: %3.2f,\tINS: %3.2f,\tSUB: %3.2f)",
                    dev_best_recognition_beam_size,
                    rbw_scores["wer"],
                    rbw_scores["wer_scores"]["del_rate"],
                    rbw_scores["wer_scores"]["ins_rate"],
                    rbw_scores["wer_scores"]["sub_rate"],
                )
                logger.info("*" * 60)

        # Without this check the code below would fail with an opaque
        # UnboundLocalError when nothing was selected.
        if dev_best_recognition_result is None:
            raise ValueError(
                "No dev recognition result could be selected: "
                "'recognition_beam_sizes' is empty or no WER was better "
                "than the initial value.")

    # Dev Translation Beam Search Results (beam size x alpha grid)
    dev_translation_results = {}
    dev_best_translation_beam_size = 1
    dev_best_translation_alpha = 1
    dev_best_translation_result = None
    if do_translation:
        logger.info("=" * 60)
        dev_best_bleu_score = float("-inf")
        for tbw in translation_beam_sizes:
            dev_translation_results[tbw] = {}
            for ta in translation_beam_alphas:
                ta_result = _evaluate(
                    dev_data,
                    recognition_beam_size=1 if do_recognition else None,
                    translation_beam_size=tbw,
                    translation_beam_alpha=ta,
                )
                dev_translation_results[tbw][ta] = ta_result

                ta_scores = ta_result["valid_scores"]
                if ta_scores["bleu"] > dev_best_bleu_score:
                    dev_best_bleu_score = ta_scores["bleu"]
                    dev_best_translation_beam_size = tbw
                    dev_best_translation_alpha = ta
                    dev_best_translation_result = ta_result
                    logger.info(
                        "[DEV] partition [Translation] results:\n\t"
                        "New Best Translation Beam Size: %d and Alpha: %d\n\t"
                        "BLEU-4 %.2f\t(BLEU-1: %.2f,\tBLEU-2: %.2f,\tBLEU-3: %.2f,\tBLEU-4: %.2f)\n\t"
                        "CHRF %.2f\t"
                        "ROUGE %.2f",
                        dev_best_translation_beam_size,
                        dev_best_translation_alpha,
                        ta_scores["bleu"],
                        ta_scores["bleu_scores"]["bleu1"],
                        ta_scores["bleu_scores"]["bleu2"],
                        ta_scores["bleu_scores"]["bleu3"],
                        ta_scores["bleu_scores"]["bleu4"],
                        ta_scores["chrf"],
                        ta_scores["rouge"],
                    )
                    logger.info("-" * 60)

        if dev_best_translation_result is None:
            raise ValueError(
                "No dev translation result could be selected: "
                "'translation_beam_sizes'/'translation_beam_alphas' is "
                "empty or no BLEU was better than the initial value.")

    logger.info("*" * 60)
    _log_summary(
        "DEV",
        dev_best_recognition_result["valid_scores"]
        if do_recognition else None,
        dev_best_translation_result["valid_scores"]
        if do_translation else None,
    )
    logger.info("*" * 60)

    # Evaluate the test partition once, with the settings chosen on dev.
    test_best_result = _evaluate(
        test_data,
        recognition_beam_size=dev_best_recognition_beam_size
        if do_recognition else None,
        translation_beam_size=dev_best_translation_beam_size
        if do_translation else None,
        translation_beam_alpha=dev_best_translation_alpha
        if do_translation else None,
    )

    _log_summary(
        "TEST",
        test_best_result["valid_scores"],
        test_best_result["valid_scores"],
    )
    logger.info("*" * 60)

    def _write_to_file(file_path: str, sequence_ids: List[str],
                       hypotheses: List[str]):
        # One "<sequence id>|<hypothesis>" line per example.
        with open(file_path, mode="w", encoding="utf-8") as out_file:
            for seq, hyp in zip(sequence_ids, hypotheses):
                out_file.write(seq + "|" + hyp + "\n")

    if output_path is not None:
        if do_recognition:
            dev_gls_output_path_set = "{}.BW_{:03d}.{}.gls".format(
                output_path, dev_best_recognition_beam_size, "dev")
            _write_to_file(
                dev_gls_output_path_set,
                list(dev_data.sequence),
                dev_best_recognition_result["gls_hyp"],
            )
            test_gls_output_path_set = "{}.BW_{:03d}.{}.gls".format(
                output_path, dev_best_recognition_beam_size, "test")
            _write_to_file(
                test_gls_output_path_set,
                list(test_data.sequence),
                test_best_result["gls_hyp"],
            )

        if do_translation:
            if dev_best_translation_beam_size > -1:
                dev_txt_output_path_set = "{}.BW_{:02d}.A_{:1d}.{}.txt".format(
                    output_path,
                    dev_best_translation_beam_size,
                    dev_best_translation_alpha,
                    "dev",
                )
                test_txt_output_path_set = "{}.BW_{:02d}.A_{:1d}.{}.txt".format(
                    output_path,
                    dev_best_translation_beam_size,
                    dev_best_translation_alpha,
                    "test",
                )
            else:
                dev_txt_output_path_set = "{}.BW_{:02d}.{}.txt".format(
                    output_path, dev_best_translation_beam_size, "dev")
                test_txt_output_path_set = "{}.BW_{:02d}.{}.txt".format(
                    output_path, dev_best_translation_beam_size, "test")

            _write_to_file(
                dev_txt_output_path_set,
                list(dev_data.sequence),
                dev_best_translation_result["txt_hyp"],
            )
            _write_to_file(
                test_txt_output_path_set,
                list(test_data.sequence),
                test_best_result["txt_hyp"],
            )

        # Full dev search results (None for a disabled task) and the single
        # test result are pickled alongside the text outputs.
        with open(output_path + ".dev_results.pkl", "wb") as out:
            pickle.dump(
                {
                    "recognition_results": dev_recognition_results
                    if do_recognition else None,
                    "translation_results": dev_translation_results
                    if do_translation else None,
                },
                out,
            )
        with open(output_path + ".test_results.pkl", "wb") as out:
            pickle.dump(test_best_result, out)