예제 #1
0
파일: test.py 프로젝트: OpenNMT/CTranslate2
def test_opennmt_tf_variables_conversion(tmpdir):
    import opennmt

    model_path = os.path.join(
        _TEST_DATA_DIR,
        "models",
        "transliteration-aren-all",
        "opennmt_tf",
        "v2",
        "checkpoint",
    )

    src_vocab = opennmt.data.Vocab.from_file(
        os.path.join(model_path, "ar.vocab"))
    tgt_vocab = opennmt.data.Vocab.from_file(
        os.path.join(model_path, "en.vocab"))
    _, variables = opennmt_tf.load_model(model_path)
    converter = ctranslate2.converters.OpenNMTTFConverter(
        ctranslate2.specs.TransformerSpec(6, 8),
        src_vocab,
        tgt_vocab,
        variables=variables,
    )
    output_dir = str(tmpdir.join("ctranslate2_model"))
    converter.convert(output_dir)
    translator = ctranslate2.Translator(output_dir)
    output = translator.translate_batch([["آ", "ت", "ز", "م", "و", "ن"]])
    assert output[0].hypotheses[0] == ["a", "t", "z", "m", "o", "n"]
예제 #2
0
파일: test.py 프로젝트: eduamf/CTranslate2
def test_opennmt_tf_variables_conversion(tmpdir):
    model_path = os.path.join(
        _TEST_DATA_DIR, "models", "transliteration-aren-all", "opennmt_tf", "v2", "checkpoint")
    _, variables, src_vocab, tgt_vocab = opennmt_tf.load_model(
        model_path,
        src_vocab=os.path.join(model_path, "ar.vocab"),
        tgt_vocab=os.path.join(model_path, "en.vocab"))
    converter = ctranslate2.converters.OpenNMTTFConverter(
        src_vocab=src_vocab, tgt_vocab=tgt_vocab, variables=variables)
    output_dir = str(tmpdir.join("ctranslate2_model"))
    converter.convert(output_dir, ctranslate2.specs.TransformerBase())
    translator = ctranslate2.Translator(output_dir)
    output = translator.translate_batch([["آ" ,"ت" ,"ز" ,"م" ,"و" ,"ن"]])
    assert output[0][0]["tokens"] == ["a", "t", "z", "m", "o", "n"]