Example #1
def test(args, model, eval_examples, chunk_size=1000):
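    """Evaluate a single-BERT model on the chunked retrieval task, write raw scores to raw_result.csv, and return the metrics object."""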
    if not os.path.isdir(args.output_dir):
        os.makedirs(args.output_dir)
    retr_res_path = os.path.join(args.output_dir, "raw_result.csv")
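    # Note: the cache file is a relative path, so it is written to the current working directory rather than args.output_dir.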
    cache_file = "cached_single_test.dat"
    if args.overwrite or not os.path.isfile(cache_file):
        chunked_retrivial_examples = eval_examples.get_chunked_retrivial_task_examples(
            chunk_query_num=args.chunk_query_num, chunk_size=chunk_size)
        torch.save(chunked_retrivial_examples, cache_file)
    else:
        chunked_retrivial_examples = torch.load(cache_file)
    retrival_dataloader = DataLoader(chunked_retrivial_examples,
                                     batch_size=args.per_gpu_eval_batch_size)

    res = []
    for batch in tqdm(retrival_dataloader, desc="retrival evaluation"):
        nl_ids = batch[0]
        pl_ids = batch[1]
        labels = batch[2]
        with torch.no_grad():
            model.eval()
            inputs = format_batch_input_for_single_bert(
                batch, eval_examples, model)
            sim_score = model.get_sim_score(**inputs)
            for n, p, prd, lb in zip(nl_ids.tolist(), pl_ids.tolist(),
                                     sim_score, labels.tolist()):
                res.append((n, p, prd, lb))

    df = results_to_df(res)
    df.to_csv(retr_res_path)
    m = metrics(df, output_dir=args.output_dir)
    return m
Example #2
def test(args, model, eval_examples, cache_file, batch_size=1000):
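    """Evaluate a model that scores precomputed NL/PL embedding pairs, write raw scores to raw_result.csv, and return the metrics object."""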
    if not os.path.isdir(args.output_dir):
        os.makedirs(args.output_dir)
    retr_res_path = os.path.join(args.output_dir, "raw_result.csv")

    if args.overwrite or not os.path.isfile(cache_file):
        chunked_retrivial_examples = eval_examples.get_chunked_retrivial_task_examples(
            chunk_query_num=args.chunk_query_num, chunk_size=batch_size)
        torch.save(chunked_retrivial_examples, cache_file)
    else:
        chunked_retrivial_examples = torch.load(cache_file)
    retrival_dataloader = DataLoader(chunked_retrivial_examples,
                                     batch_size=args.per_gpu_eval_batch_size)
    res = []
    for batch in tqdm(retrival_dataloader, desc="retrival evaluation"):
        nl_ids = batch[0]
        pl_ids = batch[1]
        labels = batch[2]
        nl_embd, pl_embd = eval_examples.id_pair_to_embd_pair(nl_ids, pl_ids)

        with torch.no_grad():
            model.eval()
            nl_embd = nl_embd.to(model.device)
            pl_embd = pl_embd.to(model.device)
            sim_score = model.get_sim_score(text_hidden=nl_embd,
                                            code_hidden=pl_embd)
            for n, p, prd, lb in zip(nl_ids.tolist(), pl_ids.tolist(),
                                     sim_score, labels.tolist()):
                res.append((n, p, prd, lb))

    df = results_to_df(res)
    df.to_csv(retr_res_path)
    m = metrics(df, output_dir=args.output_dir)
    return m
Example #3
def evalute_retrivial_for_single_bert(model, eval_examples, batch_size,
                                      res_dir):
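    """Run retrieval evaluation for the single-BERT setup, write raw scores and a summary file, and return (precision@3, best F1, MAP@3)."""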
    if not os.path.isdir(res_dir):
        os.makedirs(res_dir)
    retr_res_path = os.path.join(res_dir, "raw_result.csv")
    summary_path = os.path.join(res_dir, "summary.txt")
    retrival_dataloader = eval_examples.get_retrivial_task_dataloader(
        batch_size)
    res = []
    for batch in tqdm(retrival_dataloader, desc="retrival evaluation"):
        nl_ids = batch[0]
        pl_ids = batch[1]
        labels = batch[2]
        inputs = format_batch_input_for_single_bert(batch, eval_examples,
                                                    model)
        sim_score = model.get_sim_score(**inputs)
        for n, p, prd, lb in zip(nl_ids.tolist(), pl_ids.tolist(), sim_score,
                                 labels.tolist()):
            res.append((n, p, prd, lb))
    df = results_to_df(res)
    df.to_csv(retr_res_path)
    m = metrics(df, output_dir=res_dir)

    pk = m.precision_at_K(3)
    best_f1, best_f2, details, _ = m.precision_recall_curve("pr_curve.png")
    map_at_3 = m.MAP_at_K(3)  # avoid shadowing the builtin `map`

    summary = "\nprecision@3={}, best_f1={}, best_f2={}, MAP={}\n".format(
        pk, best_f1, best_f2, map_at_3)
    with open(summary_path, 'w') as fout:
        fout.write(summary)
        fout.write(str(details))
    return pk, best_f1, map_at_3
Example #4
def evaluate_retrival(model, eval_examples, batch_size, res_dir):
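    """Run retrieval evaluation on precomputed NL/PL embedding pairs, write raw scores and a summary file, and return (precision@3, best F1, MAP@3)."""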
    if not os.path.isdir(res_dir):
        os.makedirs(res_dir)
    retr_res_path = os.path.join(res_dir, "raw_result.csv")
    summary_path = os.path.join(res_dir, "summary.txt")
    retrival_dataloader = eval_examples.get_retrivial_task_dataloader(
        batch_size)
    res = []
    for batch in tqdm(retrival_dataloader, desc="retrival evaluation"):
        nl_ids = batch[0]
        pl_ids = batch[1]
        labels = batch[2]
        nl_embd, pl_embd = eval_examples.id_pair_to_embd_pair(nl_ids, pl_ids)

        with torch.no_grad():
            model.eval()
            nl_embd = nl_embd.to(model.device)
            pl_embd = pl_embd.to(model.device)
            sim_score = model.get_sim_score(text_hidden=nl_embd,
                                            code_hidden=pl_embd)
            for n, p, prd, lb in zip(nl_ids.tolist(), pl_ids.tolist(),
                                     sim_score, labels.tolist()):
                res.append((n, p, prd, lb))

    df = results_to_df(res)
    df.to_csv(retr_res_path)
    m = metrics(df, output_dir=res_dir)

    pk = m.precision_at_K(3)
    best_f1, best_f2, details, _ = m.precision_recall_curve("pr_curve.png")
    map_at_3 = m.MAP_at_K(3)  # avoid shadowing the builtin `map`

    summary = "\nprecision@3={}, best_f1={}, best_f2={}, MAP={}\n".format(
        pk, best_f1, best_f2, map_at_3)
    with open(summary_path, 'w') as fout:
        fout.write(summary)
        fout.write(str(details))
    return pk, best_f1, map_at_3
Example #5
def evaluate_rnn_retrival(model: RNNTracer, eval_examples, batch_size,
                          res_dir):
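    """Run retrieval evaluation for an RNNTracer model, write raw scores and a metrics summary, and return (precision@3, best F1, MAP@3)."""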
    if not os.path.isdir(res_dir):
        os.makedirs(res_dir)
    retr_res_path = os.path.join(res_dir, "raw_result.csv")
    summary_path = os.path.join(res_dir, "summary.txt")
    # chunk size is 1000 as default
    retrival_dataloader = eval_examples.get_retrivial_task_dataloader(
        batch_size)
    res = []
    for batch in tqdm(retrival_dataloader, desc="retrival evaluation"):
        nl_ids = batch[0]
        pl_ids = batch[1]
        labels = batch[2]
        nl_embd = _id_to_embd(nl_ids, eval_examples.NL_index)
        pl_embd = _id_to_embd(pl_ids, eval_examples.PL_index)

        with torch.no_grad():
            model.eval()
            nl_embd = nl_embd.to(model.device)
            pl_embd = pl_embd.to(model.device)
            sim_score = model.get_sim_score(text_hidden=nl_embd,
                                            code_hidden=pl_embd)
            for n, p, prd, lb in zip(nl_ids.tolist(), pl_ids.tolist(),
                                     sim_score, labels.tolist()):
                res.append((n, p, prd, lb))
    df = results_to_df(res)
    df.to_csv(retr_res_path)
    m = metrics(df, output_dir=res_dir)
    m.write_summary(0)
    pk = m.precision_at_K(3)
    best_f1, best_f2, details, threshold = m.precision_recall_curve(
        "pr_curve.png")
    map_at_3 = m.MAP_at_K(3)  # avoid shadowing the builtin `map`
    return pk, best_f1, map_at_3
Example #6
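    # Fragment: earlier branches that construct other baseline models are omitted; the listing resumes at the LSI branch and the shared retrieval-evaluation loop.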
    elif args.model_path == "LSI":
        model = LSI()
        model.build_model(doc_tokens, num_topics=200)
    else:
        raise Exception("Model not found...")

    cache_file = os.path.join(out_root, "cached_test.dat")
    if args.overwrite or not os.path.isfile(cache_file):
        chunked_retrivial_examples = test_examples.get_chunked_retrivial_task_examples(chunk_size=1000)
        torch.save(chunked_retrivial_examples, cache_file)
    else:
        chunked_retrivial_examples = torch.load(cache_file)
    retrival_dataloader = DataLoader(chunked_retrivial_examples, batch_size=1000)
    start_time = time.time()
    res = []
    for batch in tqdm(retrival_dataloader, desc="retrival evaluation"):
        for s_id, t_id, label in zip(batch[0].tolist(), batch[1].tolist(), batch[2].tolist()):
            pred = model.get_link_scores(test_examples.NL_index[s_id], test_examples.PL_index[t_id])
            res.append((s_id, t_id, pred, label))
    df = pd.DataFrame(res, columns=["s_id", "t_id", "pred", "label"])
    exe_time = time.time() - start_time

    raw_res_file = os.path.join(args.output_dir, "raw_res.csv")
    df.to_csv(raw_res_file)
    m = metrics(df, output_dir=args.output_dir)
    m.write_summary(exe_time)