Example #1
    def __init__(self,
                 output_path,
                 config,
                 hydra_featurizer: HydraFeaturizer,
                 model: BaseModel,
                 note=""):
        self.config = config
        self.model = model
        self.eval_history_file = os.path.join(output_path, "eval.log")
        self.bad_case_dir = os.path.join(output_path, "bad_cases")
        if "DEBUG" not in config:
            os.mkdir(self.bad_case_dir)
            with open(self.eval_history_file, "w", encoding="utf8") as f:
                f.write(note.rstrip() + "\n")

        self.eval_data = {}
        for eval_path in config["dev_data_path"].split("|") + \
                config["test_data_path"].split("|"):
            if "use_content" in config and config["use_content"] == "True":
                # Tag the cache file by whether table content was filtered.
                content_tag = ("filtered"
                               if "filter_content" in config
                               and config["filter_content"] == "True"
                               else "unfilt")
                processed_data_path = eval_path + "_{}_{}_{}.pickle".format(
                    config["base_class"], config["base_name"], content_tag)
            else:
                processed_data_path = eval_path + "_{}_{}.pickle".format(
                    config["base_class"], config["base_name"])
            if os.path.exists(processed_data_path):
                # Reuse the cached featurized dataset.
                with open(processed_data_path, "rb") as f:
                    eval_data = pickle.load(f)
                print("Loaded processed data from " + processed_data_path)
            else:
                eval_data = SQLDataset(eval_path, config, hydra_featurizer,
                                       True)
                # Cache the featurized dataset for later runs.
                with open(processed_data_path, "wb") as f:
                    pickle.dump(eval_data, f)
                print("Dumped processed data to " + processed_data_path)
            self.eval_data[os.path.basename(eval_path)] = eval_data

            print("Eval Data file {0} loaded, sample num = {1}".format(
                eval_path, len(eval_data)))
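
The compute-or-load caching around featurization above can be factored into a small helper. This is a minimal sketch for illustration only, not part of the original code; load_or_build and build_fn are hypothetical names:

import os
import pickle

def load_or_build(cache_path, build_fn):
    # Reuse the cached featurized dataset when it exists.
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    # Otherwise build it once and cache it for later runs.
    data = build_fn()
    with open(cache_path, "wb") as f:
        pickle.dump(data, f)
    return data

With such a helper, the loop body above would reduce to eval_data = load_or_build(processed_data_path, lambda: SQLDataset(eval_path, config, hydra_featurizer, True)).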
Example #2
    def __init__(self,
                 output_path,
                 config,
                 hydra_featurizer: HydraFeaturizer,
                 model: BaseModel,
                 note=""):
        self.config = config
        self.model = model
        self.eval_history_file = os.path.join(output_path, "eval.log")
        self.bad_case_dir = os.path.join(output_path, "bad_cases")
        if "DEBUG" not in config:
            os.mkdir(self.bad_case_dir)
            with open(self.eval_history_file, "w", encoding="utf8") as f:
                f.write(note.rstrip() + "\n")

        self.eval_data = {}
        for eval_path in config["dev_data_path"].split("|") + \
                config["test_data_path"].split("|"):
            eval_data = SQLDataset(eval_path, config, hydra_featurizer, True)
            self.eval_data[os.path.basename(eval_path)] = eval_data

            print("Eval Data file {0} loaded, sample num = {1}".format(
                eval_path, len(eval_data)))
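
Both constructors read dev_data_path and test_data_path as '|'-delimited lists of evaluation files. A minimal config dict satisfying the keys referenced above (the values are placeholders, not taken from the original project):

config = {
    "dev_data_path": "data/dev_a.jsonl|data/dev_b.jsonl",  # '|'-separated list
    "test_data_path": "data/test.jsonl",
    "base_class": "bert",   # placeholder value
    "base_name": "base",    # placeholder value
}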
Example #3
    processed_data_path = config["train_data_path"] + "_{}_{}".format(
        config["base_class"], config["base_name"])
    if is_meta:
        processed_data_path += "_meta_train.pickle"
    else:
        processed_data_path += ".pickle"

    if os.path.exists(processed_data_path):
        # Reuse the cached featurized dataset.
        with open(processed_data_path, "rb") as f:
            train_data = pickle.load(f)
        print("Loaded processed data from " + processed_data_path)
    else:
        if is_meta:
            train_data = featurizer.load_meta_data(config["train_data_path"],
                                                   config)
        else:
            train_data = SQLDataset(config["train_data_path"], config,
                                    featurizer, True)
        # Cache the featurized dataset for later runs.
        with open(processed_data_path, "wb") as f:
            pickle.dump(train_data, f)
        print("Dumped processed data to " + processed_data_path)

    if is_meta:
        print("Using meta training.")
        print("start training")
        epoch = 0
        while True:
            spt_sets, qry_sets = sample_column_wise_meta_data(
                train_data, config)
            cur_loss = model.meta_train_one_task(spt_sets, qry_sets,
                                                 model_path)
            if epoch % int(int(config["epochs"]) / 50) == 0 and epoch != 0:
                currentDT = datetime.datetime.now()
                print("[{2}] epoch {0}, task_loss={1:.4f}".format(
                    epoch, cur_loss, currentDT))
            epoch += 1  # assumed continuation; the source snippet is truncated here
Example #4
    out_file = f"output/test_out_ko_from_table_not_h_4_beam-{args.beam_size}_top-{args.topk}.jsonl"
    label_file = "WikiSQL/data/test.jsonl"
    db_file = "WikiSQL/data/test.db"
    model_out_file = f"output/test_model_out_ko_from_table_not_h_4_beam-{args.beam_size}_top-{args.topk}.pkl"

    ###================================================================================================###

    # All Best
    model_path = "output/20210505_235209"
    epoch = 4

    engine = DBEngine(db_file)
    config = utils.read_conf(os.path.join(model_path, "model.conf"))
    # config["DEBUG"] = 1
    featurizer = HydraFeaturizer(config)
    pred_data = SQLDataset(in_file, config, featurizer, False)
    print("num of samples: {0}".format(len(pred_data.input_features)))

    ##======================EG + TOP_k=============================##

    model = create_model(config, is_train=False)
    model.load(model_path, epoch)

    if "DEBUG" in config:
        model_out_file = model_out_file + ".partial"

    if os.path.exists(model_out_file):
        # Reuse cached model outputs if a previous run saved them.
        with open(model_out_file, "rb") as f:
            model_outputs = pickle.load(f)
    else:
        model_outputs = model.dataset_inference(pred_data)
        with open(model_out_file, "wb") as f:
            pickle.dump(model_outputs, f)
def execute_one_test(dataset, shot, model_moment, epoch):
    os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3"
    model_path = "output/" + model_moment

    in_file = "data/wiki{}_content.jsonl".format(
        dataset) if shot == "orig" else "data/wiki{}_{}_content.jsonl".format(
            dataset, shot)
    db_file = "WikiSQL/data/{}.db".format(dataset)
    label_file = "WikiSQL/data/{}.jsonl".format(
        dataset) if shot == "orig" else "WikiSQL/data_{}/{}.jsonl".format(
            shot, dataset)
    out_path = "predictions/{}_{}_{}_{}".format(model_moment, epoch, dataset,
                                                shot)
    if not os.path.exists(out_path):
        os.mkdir(out_path)
    out_file = os.path.join(out_path, "out.jsonl")
    eg_out_file = os.path.join(out_path, "out_eg.jsonl")
    model_out_file = os.path.join(out_path, "model_out.pkl")
    test_result_file = os.path.join(out_path, "result.txt")

    engine = DBEngine(db_file)
    config = utils.read_conf(os.path.join(model_path, "model.conf"))
    # config["DEBUG"] = 1
    featurizer = HydraFeaturizer(config)
    pred_data = SQLDataset(in_file, config, featurizer, False)
    print("num of samples: {0}".format(len(pred_data.input_features)))

    model = create_model(config, is_train=False)
    model.load(model_path, epoch)

    if "DEBUG" in config:
        model_out_file = model_out_file + ".partial"
    model_outputs = model.dataset_inference(pred_data)
    with open(model_out_file, "wb") as f:
        pickle.dump(model_outputs, f)
    # model_outputs = pickle.load(open(model_out_file, "rb"))

    print("===HydraNet===")
    pred_sqls = model.predict_SQL(pred_data, model_outputs=model_outputs)
    with open(out_file, "w") as g:
        for pred_sql in pred_sqls:
            # print(pred_sql)
            result = {"query": {}}
            result["query"]["agg"] = int(pred_sql[0])
            result["query"]["sel"] = int(pred_sql[1])
            result["query"]["conds"] = [(int(cond[0]), int(cond[1]),
                                         str(cond[2])) for cond in pred_sql[2]]
            g.write(json.dumps(result) + "\n")
    normal_res = print_metric(label_file, out_file, db_file)

    print("===HydraNet+EG===")
    pred_sqls = model.predict_SQL_with_EG(engine,
                                          pred_data,
                                          model_outputs=model_outputs)
    with open(eg_out_file, "w") as g:
        for pred_sql in pred_sqls:
            # print(pred_sql)
            result = {"query": {}}
            result["query"]["agg"] = int(pred_sql[0])
            result["query"]["sel"] = int(pred_sql[1])
            result["query"]["conds"] = [(int(cond[0]), int(cond[1]),
                                         str(cond[2])) for cond in pred_sql[2]]
            g.write(json.dumps(result) + "\n")
    eg_res = print_metric(label_file, eg_out_file, db_file)

    with open(test_result_file, "w") as g:
        g.write("normal results:\n" + normal_res + "eg results:\n" + eg_res)