예제 #1
0
def evaluate(args):
    """Evaluate a trained LoL model on the dev, 4-domain and test splits.

    Args:
        args: dict-like run options. Keys read directly here: "config",
            "data_path", "dataset", "run_id", "run_dev_testing",
            "except_domain" and "run_except_4d". The whole mapping is also
            forwarded to the data-loader, model-fn and learner builders.

    The learner is built with the path
    ``os.path.join(ENV_PATH, "lol_<dataset>_trained")``. The dev pass runs
    only when ``run_dev_testing`` is truthy. The 4-domain pass runs only when
    ``except_domain`` is not ``""`` and ``run_except_4d`` is truthy. The test
    pass always runs, in ``ModeKeys.TEST`` mode.
    """
    config = Config().from_file(args["config"])

    from data.interface_lol import prepare_data_loader_4_lol
    # Evaluation only needs the held-out loaders; the train split is ignored.
    _, dev_loader, test_loader, test_4d_loader = prepare_data_loader_4_lol(
        args, args["data_path"], training=False)

    # NOTE(review): model fn comes from models.interface_lol; confirm it is the
    # same interface that produced the checkpoint being evaluated.
    from models.interface_lol import build_model_fn_4_lol
    model_fn = build_model_fn_4_lol(args, config)

    learner = loll.LoLLearner()
    checkpoint_dir = os.path.join(
        ENV_PATH, "lol_{}_trained".format(args["dataset"]))
    estimator = learner.build(args, config, checkpoint_dir, args["run_id"])

    def run_pass(banner, mode, loader):
        # Announce the split, then build its spec and hand it to the learner.
        print(banner)
        estimator.evaluate(model_fn(mode, loader))

    if args["run_dev_testing"]:
        run_pass("Development Set ...", ModeKeys.EVAL, dev_loader)

    if args['except_domain'] != "" and args["run_except_4d"]:
        run_pass("Test Set on 4 domains...", ModeKeys.EVAL, test_4d_loader)

    run_pass("Test Set ...", ModeKeys.TEST, test_loader)
예제 #2
0
def evaluate(args):
    """Evaluate a trained meta-RL (meta_rl2) model; only "mwoz20" is supported.

    Args:
        args: dict-like run options. Keys read directly here: "dataset",
            "config", "data_path", "run_id", "run_dev_testing",
            "except_domain" and "run_except_4d". The whole mapping is also
            forwarded to the data-loader, model-fn and learner builders.

    The learner is built with the path
    ``os.path.join(ENV_PATH, "meta_rl2_trained")``. The dev pass runs only when
    ``run_dev_testing`` is truthy. The 4-domain pass runs only when
    ``except_domain`` is not ``""`` and ``run_except_4d`` is truthy. The test
    pass always runs, in ``ModeKeys.TEST`` mode.

    Raises:
        NotImplementedError: if ``args["dataset"]`` is not "mwoz20". It is
            raised before the config or data is loaded. It subclasses
            ``Exception``, so callers catching the previous generic
            ``Exception`` still catch it.
    """
    # Fail fast: reject unsupported datasets before touching config or data.
    if args["dataset"] != "mwoz20":
        raise NotImplementedError(
            "Unimplemented dataset: {!r} (only 'mwoz20' is supported)".format(
                args["dataset"]))

    config = Config().from_file(args["config"])

    from data.mwoz20.interface_meta_rl_lb import prepare_data_loader_4_meta_rl
    # Evaluation only needs the held-out loaders; the train split is ignored.
    (_, dev_loader, test_loader,
     test_4d_loader) = prepare_data_loader_4_meta_rl(args,
                                                     args["data_path"],
                                                     training=False)

    from models.mwoz20.interface_meta_rl2 import build_model_fn_4_meta_rl
    model_fn = build_model_fn_4_meta_rl(args, config)

    # One learner instance serves every evaluation pass below.
    scf = meta_rl2.MetaRLLearner().build(
        args, config, os.path.join(ENV_PATH, "meta_rl2_trained"),
        args["run_id"])

    # Dev set (optional).
    if args["run_dev_testing"]:
        print("Development Set ...")
        scf.evaluate(model_fn(ModeKeys.EVAL, dev_loader))

    # 4-domain test set (only when a domain was held out).
    if args['except_domain'] != "" and args["run_except_4d"]:
        print("Test Set on 4 domains...")
        scf.evaluate(model_fn(ModeKeys.EVAL, test_4d_loader))

    # Test set (always).
    print("Test Set ...")
    scf.evaluate(model_fn(ModeKeys.TEST, test_loader))
예제 #3
0
def train(args):
    """Train a LoL model, passing a dev-split eval spec alongside the train spec.

    Args:
        args: dict-like run options. Keys read directly here: "config",
            "data_path", "dataset" and "run_id". The whole mapping is also
            forwarded to the data-loader, model-fn and learner builders.

    The model fn comes from ``models.interface_lol_c2f``. The learner is built
    with the path ``os.path.join(ENV_PATH, "lol_<dataset>_trained")``.
    """
    config = Config().from_file(args["config"])

    from data.interface_lol import prepare_data_loader_4_lol
    loaders = prepare_data_loader_4_lol(args, args["data_path"], training=True)
    # Only the train and dev splits are used; both test loaders are discarded.
    train_loader, dev_loader, _, _ = loaders

    from models.interface_lol_c2f import build_model_fn_4_lol
    model_fn = build_model_fn_4_lol(args, config)
    specs = (model_fn(ModeKeys.TRAIN, train_loader),
             model_fn(ModeKeys.EVAL, dev_loader))

    learner = loll.LoLLearner()
    work_dir = os.path.join(ENV_PATH, "lol_{}_trained".format(args["dataset"]))
    estimator = learner.build(args, config, work_dir, args["run_id"])
    estimator.train(*specs)
예제 #4
0
def train(args):
    """Train a meta-RL (meta_rl2) model; only "mwoz20" is supported.

    Args:
        args: dict-like run options. Keys read directly here: "dataset",
            "config", "data_path" and "run_id". The whole mapping is also
            forwarded to the data-loader, model-fn and learner builders.

    The learner is built with the path
    ``os.path.join(ENV_PATH, "meta_rl2_trained")``. It is given a train spec
    and a dev-split eval spec.

    Raises:
        NotImplementedError: if ``args["dataset"]`` is not "mwoz20". It is
            raised before the config or data is loaded. It subclasses
            ``Exception``, so callers catching the previous generic
            ``Exception`` still catch it.
    """
    # Fail fast: reject unsupported datasets before touching config or data.
    if args["dataset"] != "mwoz20":
        raise NotImplementedError(
            "Unimplemented dataset: {!r} (only 'mwoz20' is supported)".format(
                args["dataset"]))

    config = Config().from_file(args["config"])

    from data.mwoz20.interface_meta_rl_lb import prepare_data_loader_4_meta_rl
    # Only the train and dev splits are used; both test loaders are discarded.
    (train_loader, dev_loader, _,
     _) = prepare_data_loader_4_meta_rl(args,
                                        args["data_path"],
                                        training=True)

    from models.mwoz20.interface_meta_rl2 import build_model_fn_4_meta_rl
    model_fn = build_model_fn_4_meta_rl(args, config)
    train_spec = model_fn(ModeKeys.TRAIN, train_loader)
    eval_spec = model_fn(ModeKeys.EVAL, dev_loader)

    scf = meta_rl2.MetaRLLearner().build(
        args, config, os.path.join(ENV_PATH, "meta_rl2_trained"),
        args["run_id"])
    scf.train(train_spec, eval_spec)