예제 #1
0
def test_train_base_config_argparse():
    """Compare ``--base_config test/bidaf`` parsing against the JSON file.

    The parsed TRAIN-mode config must equal a ``NestedNamespace`` populated
    from ``base_config/test/bidaf.json`` and then passed through
    ``args.set_gpu_env``.
    """
    parsed = args.config(
        argv=["--base_config", "test/bidaf"],
        mode=Mode.TRAIN,
    )

    # Build the reference config directly from the base-config JSON file.
    with open("base_config/test/bidaf.json", "r") as json_file:
        raw_config = json.load(json_file)
    expected = NestedNamespace()
    expected.load_from_json(raw_config)
    args.set_gpu_env(expected)

    assert parsed == expected
예제 #2
0
파일: machine.py 프로젝트: mercileesb/claf
# -*- coding: utf-8 -*-

import json

from claf.config import args
from claf.config.registry import Registry
from claf.learn.mode import Mode
from claf import utils as common_utils

if __name__ == "__main__":
    registry = Registry()

    # Parse machine-mode options; the settings for the selected machine are
    # read from the attribute named after it (empty dict when missing).
    machine_config = args.config(mode=Mode.MACHINE)
    machine_name = machine_config.name
    machine_settings = getattr(machine_config, machine_name, {})

    # Fetch the machine factory from the registry and instantiate it.
    machine_factory = registry.get(f"machine:{machine_name}")
    claf_machine = machine_factory(machine_settings)

    # Question/answer loop; it has no exit condition of its own.
    while True:
        prompt = f"{getattr(machine_config, 'user_input', 'Question')}"
        question = common_utils.get_user_input(prompt)

        raw_answer = claf_machine.get_answer(question)
        # Pretty-print the answer as JSON, keeping non-ASCII text readable.
        answer = json.dumps(raw_answer, indent=4, ensure_ascii=False)

        response_label = getattr(machine_config, 'system_response', 'Answer')
        print(f"{response_label}: {answer}")
예제 #3
0
def test_train_argparse():
    """``--seed_num 4`` in TRAIN mode must give a ``seed_num`` equal to 4."""
    cli_args = ["--seed_num", "4"]
    parsed = args.config(argv=cli_args, mode=Mode.TRAIN)

    assert parsed.seed_num == 4
예제 #4
0
def test_machine_argparse():
    """Smoke test: parse ``--machine_config ko_wiki`` in MACHINE mode and print it."""
    cli_args = ["--machine_config", "ko_wiki"]
    parsed = args.config(argv=cli_args, mode=Mode.MACHINE)
    print(parsed)
예제 #5
0
def test_predict_argparse():
    """Smoke test: parse a single positional argument in PREDICT mode and print it."""
    cli_args = ["checkpoint_path"]
    parsed = args.config(argv=cli_args, mode=Mode.PREDICT)
    print(parsed)
예제 #6
0
def test_eval_argparse():
    """Smoke test: parse two positional arguments in EVAL mode and print the result."""
    cli_args = ["data_path", "checkpoint_path"]
    parsed = args.config(argv=cli_args, mode=Mode.EVAL)
    print(parsed)
예제 #7
0
# -*- coding: utf-8 -*-

from claf.config import args
from claf.learn.experiment import Experiment
from claf.learn.mode import Mode


if __name__ == "__main__":
    # Parse the TRAIN-mode options, then build the experiment and run it.
    train_config = args.config(mode=Mode.TRAIN)
    experiment = Experiment(Mode.TRAIN, train_config)
    experiment()
예제 #8
0
파일: eval.py 프로젝트: zzozzolev/claf
# -*- coding: utf-8 -*-


from claf.config import args
from claf.learn.experiment import Experiment
from claf.learn.mode import Mode


if __name__ == "__main__":
    config = args.config(mode=Mode.EVAL)

    # Use INFER_EVAL when config.inference_latency is truthy, EVAL otherwise.
    mode = Mode.INFER_EVAL if config.inference_latency else Mode.EVAL

    experiment = Experiment(mode, config)
    experiment()
예제 #9
0
파일: predict.py 프로젝트: zzozzolev/claf
# -*- coding: utf-8 -*-


from claf.config import args
from claf.learn.experiment import Experiment
from claf.learn.mode import Mode


if __name__ == "__main__":
    # Parse the PREDICT-mode options, run the experiment, and print its result.
    predict_config = args.config(mode=Mode.PREDICT)
    experiment = Experiment(Mode.PREDICT, predict_config)
    result = experiment()

    print(f"Predict: {result}")