コード例 #1
0
def test_transformer_ensemble_inference(test_dir):
    """Decode the dev source with a three-member ensemble of the ODC-loss model.

    All three ensemble members point at the same
    ``<model_name><MY_BEST_MODEL_SUFFIX>.final`` checkpoint under
    ``<test_dir>/save``. Decoding runs on CPU with batch size 3, beam size 3
    and at most 20 decoding steps.

    NOTE(review): presumably the checkpoint is produced by an earlier
    training test; confirm the test ordering.

    Args:
        test_dir: Directory whose ``save`` sub-directory holds the checkpoint.
    """
    from src.bin import ensemble_translate
    from src.utils.common_utils import GlobalNames

    cfg = "./unittests/configs/test_odc_loss.yaml"
    save_dir = os.path.join(test_dir, "save")
    name = test_utils.get_model_name(cfg)

    checkpoint = os.path.join(
        save_dir, name + GlobalNames.MY_BEST_MODEL_SUFFIX + ".final")
    # The ensemble is formed from three copies of the same checkpoint.
    checkpoints = [checkpoint] * 3

    ensemble_translate.run(
        model_name=name,
        source_path="./unittests/data/dev/zh.0",
        batch_size=3,
        beam_size=3,
        model_path=checkpoints,
        use_gpu=False,
        config_path=cfg,
        saveto=save_dir,
        max_steps=20,
    )
コード例 #2
0
def test_transformer_inference(test_dir, use_gpu=False):
    """Translate the dev source with the trained Transformer checkpoint.

    Loads ``<model_name><MY_BEST_MODEL_SUFFIX>.final`` from ``<test_dir>/save``.
    It then decodes ``./unittests/data/dev/zh.0`` with batch size 3, beam
    size 3, length-penalty alpha 0.6 and at most 20 decoding steps.
    A reference path is also passed to ``translate.run``; presumably it is
    used for evaluation (e.g. BLEU). TODO confirm.

    Args:
        test_dir: Directory whose ``save`` sub-directory holds the checkpoint.
        use_gpu: Forwarded to ``translate.run``. It was previously ignored
            (always ``False``); the default keeps that behaviour.
    """
    from src.bin import translate
    from src.utils.common_utils import GlobalNames
    config_path = "./unittests/configs/test_transformer.yaml"

    saveto = os.path.join(test_dir, "save")
    model_name = test_utils.get_model_name(config_path)
    model_path = os.path.join(
        saveto, model_name + GlobalNames.MY_BEST_MODEL_SUFFIX + ".final")
    source_path = "./unittests/data/dev/zh.0"
    # The directory is "unittests" (as in every other path here), not the
    # misspelled "unitests", which made the reference data unreachable.
    reference_path = "./unittests/data/dev/en"
    batch_size = 3
    beam_size = 3
    alpha = 0.6

    translate.run(model_name=model_name,
                  source_path=source_path,
                  reference_path=reference_path,
                  batch_size=batch_size,
                  beam_size=beam_size,
                  model_path=model_path,
                  use_gpu=use_gpu,
                  config_path=config_path,
                  saveto=saveto,
                  max_steps=20,
                  alpha=alpha)
コード例 #3
0
def test_transformer_train(test_dir):
    """Run ``train_weighted_ocd`` in debug mode with the ODC-loss config.

    The ``save``, ``log`` and ``valid`` sub-directories of ``test_dir`` are
    passed as ``saveto``, ``log_path`` and ``valid_path`` respectively.

    Args:
        test_dir: Root directory for this test's outputs.
    """
    from src.bin import train_weighted_ocd

    cfg = "./unittests/configs/test_odc_loss.yaml"
    name = test_utils.get_model_name(cfg)
    save_dir, log_dir, valid_dir = (
        os.path.join(test_dir, sub) for sub in ("save", "log", "valid"))

    train_weighted_ocd.run(
        model_name=name,
        config_path=cfg,
        saveto=save_dir,
        log_path=log_dir,
        valid_path=valid_dir,
        debug=True,
    )
コード例 #4
0
def test_transformer_train(test_dir, use_gpu=False):
    """Run ``train.run`` in debug mode with the Transformer test config.

    The ``save``, ``log`` and ``valid`` sub-directories of ``test_dir`` are
    passed as ``saveto``, ``log_path`` and ``valid_path`` respectively.

    NOTE(review): ``use_gpu`` is accepted but not forwarded. Whether
    ``train.run`` takes such an argument is not visible here; confirm.

    Args:
        test_dir: Root directory for this test's outputs.
        use_gpu: Currently unused.
    """
    from src.bin import train

    cfg = "./unittests/configs/test_transformer.yaml"
    run_kwargs = dict(
        model_name=test_utils.get_model_name(cfg),
        config_path=cfg,
        saveto=os.path.join(test_dir, "save"),
        log_path=os.path.join(test_dir, "log"),
        valid_path=os.path.join(test_dir, "valid"),
        debug=True,
    )
    train.run(**run_kwargs)
コード例 #5
0
def test_dl4mt_inference(test_dir):
    """Translate the dev source with the DL4MT checkpoint on CPU.

    Loads ``<model_name><MY_BEST_MODEL_SUFFIX>`` from ``<test_dir>/save``.
    It decodes ``./unittests/data/dev/zh.0`` with batch size 3, beam size 3
    and at most 20 decoding steps.

    Args:
        test_dir: Directory whose ``save`` sub-directory holds the checkpoint.
    """
    from src.bin import translate
    from src.utils.common_utils import GlobalNames

    cfg = "./unittests/configs/test_dl4mt.yaml"
    save_dir = os.path.join(test_dir, "save")
    name = test_utils.get_model_name(cfg)
    # NOTE(review): unlike the Transformer tests in this file, no ".final"
    # suffix is appended. Confirm this matches how the DL4MT checkpoint is saved.
    checkpoint = os.path.join(save_dir, name + GlobalNames.MY_BEST_MODEL_SUFFIX)

    translate.run(
        model_name=name,
        source_path="./unittests/data/dev/zh.0",
        batch_size=3,
        beam_size=3,
        model_path=checkpoint,
        use_gpu=False,
        config_path=cfg,
        saveto=save_dir,
        max_steps=20,
    )
コード例 #6
0
ファイル: test_transformer.py プロジェクト: wangqi1996/njunmt
def test_transformer_greedy_search(test_dir, use_gpu=False):
    """Decode with the trained Transformer using a beam size of 1 (greedy).

    Loads ``<model_name><MY_BEST_MODEL_SUFFIX>.final`` from ``<test_dir>/save``.
    It decodes ``./unittests/data/dev.de`` with batch size 3, length-penalty
    alpha 0.6 and at most 10 decoding steps.

    Args:
        test_dir: Directory whose ``save`` sub-directory holds the checkpoint.
        use_gpu: Forwarded to ``translate.run``. It was previously ignored
            (always ``False``); the default keeps that behaviour.
    """
    from src.bin import translate
    from src.utils.common_utils import Constants
    config_path = "./unittests/configs/test_transformer.yaml"

    saveto = os.path.join(test_dir, "save")
    model_name = test_utils.get_model_name(config_path)
    model_path = os.path.join(
        saveto, model_name + Constants.MY_BEST_MODEL_SUFFIX + ".final")
    source_path = "./unittests/data/dev.de"
    batch_size = 3
    # A single beam makes beam search equivalent to greedy decoding.
    beam_size = 1
    alpha = 0.6

    translate.run(model_name=model_name,
                  source_path=source_path,
                  batch_size=batch_size,
                  beam_size=beam_size,
                  model_path=model_path,
                  use_gpu=use_gpu,
                  config_path=config_path,
                  saveto=saveto,
                  max_steps=10,
                  alpha=alpha)