Esempio n. 1
0
def test_seq_copy(train_params, translate_params, perplexity_thresh,
                  bleu_thresh):
    """Task: copy short sequences of digits, then check the quality metrics.

    Train and dev digit files are written with ``generate_digits_file``
    into a temporary working directory. ``run_train_translate`` is then
    run on those files. The perplexity and BLEU values it returns are
    compared against the supplied thresholds.

    :param train_params: Training arguments, forwarded unchanged to
        ``run_train_translate``.
    :param translate_params: Translation arguments, forwarded unchanged to
        ``run_train_translate``.
    :param perplexity_thresh: Inclusive upper bound for the reported
        perplexity.
    :param bleu_thresh: Inclusive lower bound for the reported BLEU.
    """
    with TemporaryDirectory(prefix="test_seq_copy.") as work_dir:
        file_names = ("train.src", "train.tgt", "dev.src", "dev.tgt")
        (train_source_path, train_target_path,
         dev_source_path, dev_target_path) = (
             os.path.join(work_dir, file_name) for file_name in file_names)

        # Both splits use the same maximum line length; only the line
        # count differs between train and dev.
        for source_path, target_path, line_count in (
                (train_source_path, train_target_path, _TRAIN_LINE_COUNT),
                (dev_source_path, dev_target_path, _DEV_LINE_COUNT)):
            generate_digits_file(source_path, target_path, line_count,
                                 _LINE_MAX_LENGTH)

        # One position beyond the longest generated line. This is presumably
        # headroom for an extra token added by the pipeline; TODO confirm.
        max_seq_len = _LINE_MAX_LENGTH + 1
        metrics = run_train_translate(train_params, translate_params,
                                      train_source_path, train_target_path,
                                      dev_source_path, dev_target_path,
                                      max_seq_len=max_seq_len,
                                      work_dir=work_dir)
        perplexity, bleu = metrics
        assert perplexity <= perplexity_thresh
        assert bleu >= bleu_thresh
Esempio n. 2
0
def test_seq_copy(train_params, translate_params):
    """Task: copy short sequences of digits (smoke/integration run only).

    Train and dev digit files are generated in a temporary directory, and
    ``run_train_translate`` is run end to end on them. Nothing is asserted
    about the result: the test passes as long as no exception is raised.

    :param train_params: Training arguments, forwarded unchanged to
        ``run_train_translate``.
    :param translate_params: Translation arguments, forwarded unchanged to
        ``run_train_translate``.
    """
    with TemporaryDirectory(prefix="test_seq_copy") as work_dir:
        # Map each data file name to its location inside the scratch dir.
        paths = {
            name: os.path.join(work_dir, name)
            for name in ("train.src", "train.tgt", "dev.src", "dev.tgt")
        }

        generate_digits_file(paths["train.src"], paths["train.tgt"],
                             _TRAIN_LINE_COUNT, _LINE_MAX_LENGTH)
        generate_digits_file(paths["dev.src"], paths["dev.tgt"],
                             _DEV_LINE_COUNT, _LINE_MAX_LENGTH)

        # The returned metrics (perplexity and BLEU) are deliberately
        # discarded; this variant only checks that the pipeline runs.
        run_train_translate(
            train_params,
            translate_params,
            paths["train.src"],
            paths["train.tgt"],
            paths["dev.src"],
            paths["dev.tgt"],
            max_seq_len=_LINE_MAX_LENGTH + 1,
            work_dir=work_dir,
        )