def test_seq_copy(train_params: str, translate_params: str, use_prepared_data: bool, use_source_factors: bool):
    """
    Task: copy short sequences of digits.

    Builds a temporary digit dataset with ``sort_target=False`` (optionally with source
    factors), then trains and translates on it via ``check_train_translate`` with output
    comparison switched off.
    """
    dataset_config = {
        "prefix": "test_seq_copy",
        "train_line_count": _TRAIN_LINE_COUNT,
        "train_line_count_empty": _TRAIN_LINE_COUNT_EMPTY,
        "train_max_length": _LINE_MAX_LENGTH,
        "dev_line_count": _DEV_LINE_COUNT,
        "dev_max_length": _LINE_MAX_LENGTH,
        "test_line_count": _TEST_LINE_COUNT,
        "test_line_count_empty": _TEST_LINE_COUNT_EMPTY,
        "test_max_length": _TEST_MAX_LENGTH,
        "sort_target": False,
        "with_source_factors": use_source_factors,
    }
    with tmp_digits_dataset(**dataset_config) as dataset:
        # Extra length on top of the longest line; presumably room for the special
        # begin/end symbols (C.SPACE_FOR_XOS) — confirm against the constants module.
        seq_len_limit = _LINE_MAX_LENGTH + C.SPACE_FOR_XOS
        # TODO: comparing translation and scoring scores is temporarily disabled because it
        # sometimes produces inconsistent results for --batch-size > 1 (see issue #639 on github).
        check_train_translate(train_params=train_params,
                              translate_params=translate_params,
                              data=dataset,
                              use_prepared_data=use_prepared_data,
                              max_seq_len=seq_len_limit,
                              compare_output=False)
def test_seq_sort(name, train_params, translate_params, use_prepared_data, n_source_factors, n_target_factors,
                  perplexity_thresh, bleu_thresh):
    """
    Task: sort short sequences of digits.

    Generates a temporary digit dataset with ``sort_target=True``, trains and translates on it,
    then checks that the best validation perplexity is at most ``perplexity_thresh`` and that
    the BLEU of both the regular and the restricted test outputs is at least ``bleu_thresh``.
    """
    with tmp_digits_dataset("test_seq_sort.",
                            _TRAIN_LINE_COUNT, _TRAIN_LINE_COUNT_EMPTY, _LINE_MAX_LENGTH,
                            _DEV_LINE_COUNT, _LINE_MAX_LENGTH,
                            _TEST_LINE_COUNT, _TEST_LINE_COUNT_EMPTY, _TEST_MAX_LENGTH,
                            sort_target=True,
                            seed_train=_SEED_TRAIN_DATA,
                            seed_dev=_SEED_DEV_DATA,
                            with_n_source_factors=n_source_factors,
                            with_n_target_factors=n_target_factors) as dataset:
        # NOTE(review): `seed` is not a parameter of this function, so it must resolve to a
        # module-level name defined outside this view — confirm it exists.
        results = check_train_translate(train_params=train_params,
                                        translate_params=translate_params,
                                        data=dataset,
                                        use_prepared_data=use_prepared_data,
                                        max_seq_len=_LINE_MAX_LENGTH,
                                        compare_output=True,
                                        seed=seed)

        # Best (lowest) validation perplexity across all checkpoints in the metrics file.
        val_key = C.PERPLEXITY + '-val'
        metrics_path = os.path.join(results['model'], C.METRICS_NAME)
        best_perplexity = min(entry[val_key] for entry in sockeye.utils.read_metrics_file(metrics_path))

        # Each test output record carries its hypothesis under 'translation'.
        plain_hyps = [record['translation'] for record in results['test_outputs']]
        restricted_hyps = [record['translation'] for record in results['test_outputs_restricted']]
        references = results['test_targets']

        bleu_score = sockeye.evaluate.raw_corpus_bleu(hypotheses=plain_hyps, references=references)
        chrf_score = sockeye.evaluate.raw_corpus_chrf(hypotheses=plain_hyps, references=references)
        restricted_bleu = sockeye.evaluate.raw_corpus_bleu(hypotheses=restricted_hyps, references=references)

        logger.info("test: %s", name)
        logger.info("perplexity=%f, bleu=%f, bleu_restrict=%f chrf=%f",
                    best_perplexity, bleu_score, restricted_bleu, chrf_score)

        assert best_perplexity <= perplexity_thresh
        assert bleu_score >= bleu_thresh
        assert restricted_bleu >= bleu_thresh
def test_seq_copy(name, train_params, translate_params, use_prepared_data, perplexity_thresh, bleu_thresh):
    """
    Task: copy short sequences of digits.

    Generates a temporary digit dataset with ``sort_target=False`` and no source factors,
    trains and translates on it, then checks:

    - the best validation perplexity is at most ``perplexity_thresh``;
    - the BLEU of the test outputs is at least ``bleu_thresh``;
    - the BLEU of the restricted test outputs is at least ``bleu_thresh``, but only when
      ``check_train_translate`` produced them (key ``'test_outputs_restricted'``).
    """
    with tmp_digits_dataset(prefix="test_seq_copy",
                            train_line_count=_TRAIN_LINE_COUNT,
                            train_line_count_empty=_TRAIN_LINE_COUNT_EMPTY,
                            train_max_length=_LINE_MAX_LENGTH,
                            dev_line_count=_DEV_LINE_COUNT,
                            dev_max_length=_LINE_MAX_LENGTH,
                            test_line_count=_TEST_LINE_COUNT,
                            test_line_count_empty=_TEST_LINE_COUNT_EMPTY,
                            test_max_length=_TEST_MAX_LENGTH,
                            sort_target=False,
                            with_n_source_factors=0) as data:
        # NOTE(review): `seed` is not a parameter of this function, so it must resolve to a
        # module-level name defined outside this view — confirm it exists.
        data = check_train_translate(train_params=train_params,
                                     translate_params=translate_params,
                                     data=data,
                                     use_prepared_data=use_prepared_data,
                                     max_seq_len=_LINE_MAX_LENGTH,
                                     compare_output=True,
                                     seed=seed)

        # Best (lowest) validation perplexity across all entries in the metrics file.
        metrics = sockeye.utils.read_metrics_file(os.path.join(data['model'], C.METRICS_NAME))
        perplexity = min(m[C.PERPLEXITY + '-val'] for m in metrics)

        # Corpus-level metrics; each output record holds its hypothesis under 'translation'.
        references = data['test_targets']
        hypotheses = [output['translation'] for output in data['test_outputs']]
        bleu = sockeye.evaluate.raw_corpus_bleu(hypotheses=hypotheses, references=references)
        chrf = sockeye.evaluate.raw_corpus_chrf(hypotheses=hypotheses, references=references)

        # Restricted outputs are optional; None marks "not produced".
        bleu_restrict = None
        if 'test_outputs_restricted' in data:
            hypotheses_restricted = [output['translation'] for output in data['test_outputs_restricted']]
            bleu_restrict = sockeye.evaluate.raw_corpus_bleu(hypotheses=hypotheses_restricted,
                                                             references=references)

        logger.info("================")
        logger.info("test results: %s", name)
        if bleu_restrict is None:
            # '%f' cannot format None: the record would fail in LogRecord.getMessage, be
            # dropped, and (under pytest's log capture) fail the test. Log without that slot.
            logger.info("perplexity=%f, bleu=%f, bleu_restrict=n/a chrf=%f", perplexity, bleu, chrf)
        else:
            logger.info("perplexity=%f, bleu=%f, bleu_restrict=%f chrf=%f",
                        perplexity, bleu, bleu_restrict, chrf)
        logger.info("================\n")

        assert perplexity <= perplexity_thresh
        assert bleu >= bleu_thresh
        if bleu_restrict is not None:
            assert bleu_restrict >= bleu_thresh
def test_seq_copy(train_params: str, translate_params: str, use_prepared_data: bool, use_source_factors: bool):
    """
    Task: copy short sequences of digits.

    Builds a temporary digit dataset with ``sort_target=False`` (optionally with source
    factors) and runs ``check_train_translate`` on it, leaving output comparison at that
    helper's default.
    """
    dataset_config = {
        "prefix": "test_seq_copy",
        "train_line_count": _TRAIN_LINE_COUNT,
        "train_max_length": _LINE_MAX_LENGTH,
        "dev_line_count": _DEV_LINE_COUNT,
        "dev_max_length": _LINE_MAX_LENGTH,
        "test_line_count": _TEST_LINE_COUNT,
        "test_line_count_empty": _TEST_LINE_COUNT_EMPTY,
        "test_max_length": _TEST_MAX_LENGTH,
        "sort_target": False,
        "with_source_factors": use_source_factors,
    }
    with tmp_digits_dataset(**dataset_config) as dataset:
        # Longest line plus C.SPACE_FOR_XOS; presumably room for the special begin/end
        # symbols — confirm against the constants module.
        seq_len_limit = _LINE_MAX_LENGTH + C.SPACE_FOR_XOS
        check_train_translate(train_params=train_params,
                              translate_params=translate_params,
                              data=dataset,
                              use_prepared_data=use_prepared_data,
                              max_seq_len=seq_len_limit)