Code Example #1
def main() -> None:
    parser = argparse.ArgumentParser(
        description="Tells you what percentage of tactics, "
        "and entire proofs, pass a predicate")

    parser.add_argument("filenames", nargs="+", help="csv file names")
    parser.add_argument("--context-filter",
                        dest="context_filter",
                        type=str,
                        default="default")
    args = parser.parse_args()

    cfilter = get_context_filter(args.context_filter)

    num_tactics_pass = 0
    num_proofs_pass = 0
    num_tactics_total = 0
    num_proofs_total = 0

    for filename in args.filenames:
        # print("In file: {}".format(filename))
        options, rows = read_csvfile(filename)
        in_proof = False
        current_proof_perfect = False
        cur_lemma_name = ""
        for row, nextrow in pairwise(rows):
            if isinstance(row, TacticRow):
                if not in_proof:
                    current_proof_perfect = True
                in_proof = True
                passes = check_cfilter_row(cfilter, row, nextrow)
                num_tactics_total += 1
                if not passes:
                    current_proof_perfect = False
                    # print("{} doesn't pass.".format(row.command.strip()))
                else:
                    # print("{} passes!".format(row.command.strip()))
                    num_tactics_pass += 1
            elif ending_proof(row.command) and in_proof:
                in_proof = False
                num_proofs_total += 1
                if current_proof_perfect:
                    num_proofs_pass += 1
                    # print("Proof : {},\n in {}, passed!".format(cur_lemma_name, filename))
                else:
                    # print("Proof : {},\n in {}, didn't pass!"
                    #       .format(cur_lemma_name, filename))
                    pass
            else:
                if possibly_starting_proof(row.command):
                    cur_lemma_name = row.command

    print("Filter {}: {}/{} tactics pass, {}/{} complete proofs pass".format(
        args.context_filter, num_tactics_pass, num_tactics_total,
        num_proofs_pass, num_proofs_total))
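
The loop above depends on a pairwise helper that pairs each CSV row with the row that follows it, so that check_cfilter_row can inspect a tactic's successor. The helper is defined elsewhere in the project; a minimal sketch of the conventional implementation (equivalent to itertools.pairwise in Python 3.10+) might look like this:

import itertools
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")

def pairwise(iterable: Iterable[T]) -> Iterator[Tuple[T, T]]:
    # Yield overlapping pairs (s0, s1), (s1, s2), ... of the input sequence.
    first, second = itertools.tee(iterable)
    next(second, None)
    return zip(first, second)

With this definition the final row never appears in the first slot, so the project's own helper may handle the end of the sequence differently.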
Code Example #2
File: data.py   Project: rashchedrin/proverbot9001
def get_text_data(arg_values : Namespace) -> RawDataset:
    def _print(*args, **kwargs):
        eprint(*args, **kwargs, guard=arg_values.verbose)

    start = time.time()
    _print("Reading dataset...", end="")
    sys.stdout.flush()
    raw_data = RawDataset(list(read_text_data(arg_values.scrape_file)))
    filtered_data = RawDataset(list(itertools.islice(
        filter_data(raw_data,
                    get_context_filter(arg_values.context_filter),
                    arg_values),
        arg_values.max_tuples)))
    _print("{:.2f}s".format(time.time() - start))
    _print("Got {} input-output pairs ".format(len(filtered_data)))
    return filtered_data
Code Example #3
def get_text_data(data_path: str,
                  context_filter_name: str,
                  max_tuples: Optional[int] = None,
                  verbose: bool = False) -> RawDataset:
    def _print(*args, **kwargs):
        if verbose:
            print(*args, **kwargs)

    _print("Reading dataset...")
    raw_data = read_text_data(data_path)
    filtered_data = list(
        itertools.islice(
            filter_data(raw_data, get_context_filter(context_filter_name)),
            max_tuples))
    _print("Got {} input-output pairs ".format(len(filtered_data)))
    return filtered_data
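
Code Examples #2 and #3 are two revisions of the same loader: the first pulls its settings out of an argparse Namespace, while the second takes explicit parameters, which makes it easier to call from tests or notebooks. A usage sketch of the explicit variant, in which the scrape path and filter name are placeholders rather than values taken from the project:

# Placeholder path and filter name; substitute your own scraped-data file
# and one of the project's registered context filters.
dataset = get_text_data("data/example.scrape",
                        context_filter_name="default",
                        max_tuples=10000,
                        verbose=True)
print("Loaded {} input-output pairs".format(len(dataset)))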
Code Example #4
    # __init__ of a worker Thread subclass; the enclosing class is not shown here.
    def __init__(self, workerid: int, coqargs: List[str], includes: str,
                 output_dir: str, prelude: str, debug: bool, num_jobs: int,
                 baseline: bool, skip_nochange_tac: bool, context_filter: str,
                 full_args: argparse.Namespace) -> None:
        threading.Thread.__init__(self, daemon=True)
        self.coqargs = coqargs
        self.includes = includes
        self.workerid = workerid
        self.output_dir = output_dir
        self.prelude = prelude
        self.debug = debug
        self.num_jobs = num_jobs
        self.baseline = baseline
        self.cfilter = get_context_filter(context_filter)
        self.skip_nochange_tac = skip_nochange_tac
        self.full_args = full_args
Code Example #5
File: data.py   Project: Desperatesonic/proverbot9001
def get_evaluation_data(arg_values: Namespace) -> StateEvaluationDataset:
    def _print(*args, **kwargs):
        eprint(*args, **kwargs, guard=arg_values.verbose)

    start = time.time()
    _print("Reading dataset...", end="")
    sys.stdout.flush()
    raw_data = read_all_text_data(arg_values.scrape_file)
    distanced_data = get_state_distances(raw_data)
    preprocessed_interactions = preprocess_data_eval(arg_values,
                                                     distanced_data)
    filtered_interactions = StateEvaluationDataset([
        StateScore(point.state, point.score) for point in itertools.islice(
            filter_eval_data(preprocessed_interactions,
                             get_context_filter(arg_values.context_filter),
                             arg_values), arg_values.max_tuples)
    ])
    _print("{:.2f}s".format(time.time() - start))
    _print("Got {} input-output pairs ".format(len(filtered_interactions)))
    return filtered_interactions
Code Example #6
def generate_evaluation_details(args: argparse.Namespace, idx: int,
                                filename: Path2,
                                evaluator: StateEvaluator) -> FileSummary:
    scrape_path = args.prelude / filename.with_suffix(".v.scrape")
    interactions = list(read_all_text_data(scrape_path))
    context_filter = get_context_filter(args.context_filter)
    json_rows: List[Dict[str, Any]] = []

    num_points = 0
    num_close = 0
    num_correct = 0
    num_proofs = 0

    doc, tag, text, line = Doc().ttl()

    def write_highlighted(vernac: str) -> None:
        nonlocal text
        nonlocal tag
        substrings = syntax_highlight(vernac)

        for substring in substrings:
            if isinstance(substring, ColoredString):
                with tag('span', style=f'color:{substring.color}'):
                    text(substring.contents)
            else:
                text(substring)

    def write_vernac(block: VernacBlock):
        nonlocal tag
        for command in block.commands:
            with tag('code', klass='plaincommand'):
                write_highlighted(command)

    def generate_proof_evaluation_details(block: ProofBlock, region_idx: int):
        nonlocal num_proofs
        nonlocal num_close
        nonlocal num_correct
        nonlocal json_rows
        num_proofs += 1

        nonlocal num_points

        distanced_tactics = label_distances(block.proof_interactions)

        proof_length = len(distanced_tactics)
        num_points += proof_length

        with tag('div', klass='region'):
            nonlocal evaluator
            for idx, (interaction,
                      distance_from_end) in enumerate(distanced_tactics, 1):
                if interaction.tactic.strip() in [
                        "Proof.", "Qed.", "Defined."
                ]:
                    with tag('code', klass='plaincommand'):
                        write_highlighted(interaction.tactic.strip("\n"))
                    doc.stag('br')
                else:
                    predicted_distance_from_end = evaluator.scoreState(
                        interaction.context_before)
                    grade = grade_prediction(distance_from_end,
                                             predicted_distance_from_end)
                    if grade == "goodcommand":
                        num_correct += 1
                    elif grade == "okaycommand":
                        num_close += 1

                    num_points += 1
                    json_rows.append({
                        "lemma": block.lemma_statement,
                        "hyps": interaction.context_before.hypotheses,
                        "goal": interaction.context_before.goal,
                        "actual-distance": distance_from_end,
                        "predicted-distance": predicted_distance_from_end,
                        "grade": grade
                    })
                    with tag('span',
                             ('data-hyps', "\n".join(interaction.context_before.hypotheses)),
                             ('data-goal', interaction.context_before.goal),
                             ('data-actual-distance', str(distance_from_end)),
                             ('data-predicted-distance', str(predicted_distance_from_end)),
                             ('data-region', region_idx),
                             ('data-index', idx),
                             klass='tactic'), \
                             tag('code', klass=grade):
                        text(interaction.tactic)
                    doc.stag('br')

    def write_lemma_button(lemma_statement: str, region_idx: int):
        nonlocal tag
        nonlocal text
        lemma_name = \
            serapi_instance.lemma_name_from_statement(lemma_statement)
        with tag('button', klass='collapsible',
                 id=f'collapsible-{region_idx}'):
            with tag('code', klass='buttontext'):
                write_highlighted(lemma_statement.strip())

    def grade_prediction(correct_number: int, predicted_number: float) -> str:
        distance = abs(correct_number - predicted_number)
        if distance < 1:
            return "goodcommand"
        elif distance < 5:
            return "okaycommand"
        else:
            return "badcommand"

    with tag('html'):
        header(tag, doc, text, details_css, details_javascript,
               "Proverbot9001 Report")
        with tag('body', onload='init()'), tag('pre'):
            for idx, block in enumerate(get_blocks(interactions)):
                if isinstance(block, VernacBlock):
                    write_vernac(block)
                else:
                    assert isinstance(block, ProofBlock)
                    write_lemma_button(block.lemma_statement, idx)
                    generate_proof_evaluation_details(block, idx)

    base = Path2(os.path.dirname(os.path.abspath(__file__)))
    for extra_filename in extra_files:
        (base.parent / "reports" / extra_filename).copyfile(args.output /
                                                            extra_filename)

    with (args.output /
          filename.with_suffix(".html").name).open(mode='w') as fout:
        fout.write(doc.getvalue())

    with (args.output /
          filename.with_suffix(".json").name).open(mode='w') as fout:
        for row in json_rows:
            fout.write(json.dumps(row))
            fout.write("\n")

    return FileSummary(filename, num_close, num_correct, num_points,
                       num_proofs)
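
The nested grade_prediction helper above buckets a predicted distance-to-end-of-proof against the true distance: within 1 step is "goodcommand", within 5 steps is "okaycommand", and anything further is "badcommand". If it were lifted to module scope, the thresholds would behave as in this illustrative sketch:

assert grade_prediction(3, 3.4) == "goodcommand"   # off by less than 1
assert grade_prediction(3, 6.0) == "okaycommand"   # off by less than 5
assert grade_prediction(3, 9.0) == "badcommand"    # off by 5 or more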
Code Example #7
def report_file(args : argparse.Namespace,
                training_args : argparse.Namespace,
                context_filter_str : str,
                filename : Path2) -> Optional['ResultStats']:

    def make_predictions(num_predictions : int,
                         tactic_interactions : List[ScrapedTactic]) -> \
        Tuple[Iterable[Tuple[ScrapedTactic, List[Prediction]]], float]:
        if len(tactic_interactions) == 0:
            return [], 0
        chunk_size = args.chunk_size
        total_loss = 0.
        for tactic_interaction in tactic_interactions:
            assert isinstance(tactic_interaction.goal, str)
        inputs = [TacticContext(tactic_interaction.prev_tactics,
                                tactic_interaction.hypotheses,
                                format_goal(tactic_interaction.goal))
                  for tactic_interaction in tactic_interactions]
        corrects = [tactic_interaction.tactic
                    for tactic_interaction in tactic_interactions]
        predictions : List[List[Prediction]] = []
        for inputs_chunk, corrects_chunk in zip(chunks(inputs, chunk_size),
                                                chunks(corrects, chunk_size)):
            predictions_chunk, loss = predictor.predictKTacticsWithLoss_batch(
                inputs_chunk, args.num_predictions, corrects_chunk)
            predictions += predictions_chunk
            total_loss += loss
        del inputs
        del corrects
        return list(zip(tactic_interactions, predictions)), \
            total_loss / math.ceil(len(tactic_interactions) / chunk_size)

    def merge_indexed(lic : Sequence[Tuple[int, T1]], lib : Sequence[Tuple[int,T2]]) \
        -> Iterable[Union[T1, T2]]:
        lic = list(reversed(lic))
        lib = list(reversed(lib))
        while lic and lib:
            lst : List[Tuple[int, Any]] = (lic if lic[-1][0] < lib[-1][0] else lib) # type: ignore
            yield lst.pop()[1]
        yield from list(reversed([c for _, c in lic]))
        yield from list(reversed([b for _, b in lib]))
    def get_should_filter(data : MixedDataset) -> Iterable[Tuple[ScrapedCommand, bool]]:
        list_data : List[ScrapedCommand] = list(data)
        extended_list : List[Optional[ScrapedCommand]] = \
            cast(List[Optional[ScrapedCommand]], list_data[1:])  + [None]
        for point, nextpoint in zip(list_data, extended_list):
            if isinstance(point, ScrapedTactic) \
               and not re.match(r"\s*[{}]\s*", point.tactic):
                if isinstance(nextpoint, ScrapedTactic):
                    yield(point, not context_filter({"goal":format_goal(point.goal),
                                                     "hyps":point.hypotheses},
                                                    point.tactic,
                                                    {"goal":format_goal(nextpoint.goal),
                                                     "hyps":nextpoint.hypotheses},
                                                    training_args))
                else:
                    yield(point, not context_filter({"goal":format_goal(point.goal),
                                                     "hyps":point.hypotheses},
                                                    point.tactic,
                                                    {"goal":"",
                                                     "hyps":""},
                                                    training_args))
            else:
                yield (point, True)
    try:
        scrape_path = args.prelude / filename.with_suffix(".v.scrape")
        interactions = list(read_text_data_singlethreaded(scrape_path))
        print("Loaded {} interactions for file {}".format(len(interactions), filename))
    except FileNotFoundError:
        print("Couldn't find file {}, skipping...".format(scrape_path))
        return None
    context_filter = get_context_filter(context_filter_str)

    command_results : List[CommandResult] = []
    stats = ResultStats(str(filename))
    indexed_filter_aware_interactions = list(enumerate(get_should_filter(interactions)))
    for idx, (interaction, should_filter) in indexed_filter_aware_interactions:
        assert isinstance(idx, int)
        if not should_filter:
            assert isinstance(interaction, ScrapedTactic), interaction
    indexed_filter_aware_prediction_contexts, indexed_filter_aware_pass_through = \
        multipartition(indexed_filter_aware_interactions,
                       lambda indexed_filter_aware_interaction:
                       indexed_filter_aware_interaction[1][1])
    indexed_prediction_contexts: List[Tuple[int, ScrapedTactic]] = \
        [(idx, cast(ScrapedTactic, obj)) for (idx, (obj, filtered))
         in indexed_filter_aware_prediction_contexts]
    indexed_pass_through = [(idx, cast(Union[ScrapedTactic, str], obj))
                            for (idx, (obj, filtered))
                            in indexed_filter_aware_pass_through]
    for idx, prediction_context in indexed_prediction_contexts:
        assert isinstance(idx, int)
        assert isinstance(prediction_context, ScrapedTactic)
    prediction_interactions, loss = \
        make_predictions(args.num_predictions,
                         [prediction_context for idx, prediction_context
                          in indexed_prediction_contexts])
    indexed_prediction_interactions = \
        [(idx, prediction_interaction)
         for (idx, prediction_context), prediction_interaction
         in zip(indexed_prediction_contexts, prediction_interactions)]
    interactions_with_predictions = \
        list(merge_indexed(indexed_prediction_interactions, indexed_pass_through))

    for inter in interactions_with_predictions:
        if isinstance(inter, tuple) and not isinstance(inter, ScrapedTactic):
            assert len(inter) == 2, inter
            scraped, predictions_and_certainties \
                = inter #cast(Tuple[ScrapedTactic, List[Prediction]], inter)
            (prev_tactics, hyps, goal, correct_tactic) = scraped
            prediction_results = [PredictionResult(prediction,
                                                   grade_prediction(scraped,
                                                                    prediction),
                                                   certainty)
                                  for prediction, certainty in
                                  predictions_and_certainties]
            command_results.append(TacticResult(correct_tactic, hyps, goal,
                                                prediction_results))
            stats.add_tactic(prediction_results,
                             correct_tactic)
        elif isinstance(inter, ScrapedTactic):
            command_results.append(
                TacticResult(inter.tactic, inter.hypotheses, inter.goal, []))
        else:
            command_results.append((inter,))

    stats.set_loss(loss)

    print("Finished grading file {}".format(filename))

    write_html(args.output, filename, command_results, stats)
    write_csv(args.output, filename, args, command_results, stats)
    print("Finished output for file {}".format(filename))
    return stats
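
The nested merge_indexed helper in the example above interleaves two index-tagged streams back into their original order, which is how the predicted interactions and the pass-through commands are recombined before grading. If it were a standalone function, it would behave as in this illustrative sketch:

# Merge two index-tagged streams back into index order.
merged = list(merge_indexed([(0, "a"), (2, "c")], [(1, "b"), (3, "d")]))
assert merged == ["a", "b", "c", "d"]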
Code Example #8
def count_proofs(args : argparse.Namespace, filename : str) \
    -> Tuple[int, int]:
    eprint(f"Counting {filename}", guard=args.debug)
    scrapefile = args.prelude + "/" + filename + ".scrape"
    interactions = list(read_all_text_data(scrapefile))
    filter_func = get_context_filter(args.context_filter)

    count = 0
    total_count = 0
    cur_proof_counts = False
    cur_lemma_stmt = ""
    extended_interactions : List[Optional[ScrapedCommand]] = \
        cast(List[Optional[ScrapedCommand]], interactions[1:])  + [None]
    for inter, next_inter in zip(interactions, extended_interactions):
        if isinstance(inter, ScrapedTactic):
            goal_before = inter.goal
            hyps_before = inter.hypotheses
            command = inter.tactic
        else:
            goal_before = ""
            hyps_before = []
            command = inter

        if next_inter and isinstance(next_inter, ScrapedTactic):
            goal_after = next_inter.goal
            hyps_after = next_inter.hypotheses
        else:
            goal_after = ""
            hyps_after = []

        entering_proof = bool((not goal_before) and goal_after)
        exiting_proof = bool(goal_before and not goal_after)

        if entering_proof:
            cur_lemma_stmt = next_inter.prev_tactics[0]
            cur_proof_counts = False if args.some else True
            continue

        if cur_lemma_stmt:
            if filter_func(
                {
                    "goal": format_goal(goal_before),
                    "hyps": hyps_before
                }, command, {
                    "goal": format_goal(goal_after),
                    "hyps": hyps_after
                }, args):
                if args.some and not cur_proof_counts:
                    cur_proof_counts = True
            else:
                if args.all and cur_proof_counts:
                    cur_lemma_name = serapi_instance.lemma_name_from_statement(
                        cur_lemma_stmt)
                    eprint(
                        f"Eliminating proof {cur_lemma_name} "
                        f"because tactic {command.strip()} doesn't match",
                        guard=args.debug)
                    cur_proof_counts = False

        if exiting_proof:
            if cur_proof_counts:
                cur_lemma_name = serapi_instance.lemma_name_from_statement(
                    cur_lemma_stmt)
                if args.print_name:
                    print(cur_lemma_name)
                if args.print_stmt:
                    print(re.sub("\n", "\\n", cur_lemma_stmt))
                eprint(f"Proof of {cur_lemma_name} counts", guard=args.debug)
                count += 1
            total_count += 1
            cur_lemma_stmt = ""
    if not args.print_name and not args.print_stmt:
        print(f"{filename}: {count}/{total_count} "
              f"({stringified_percent(count, total_count)}%)")
    return count, total_count
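
Code Examples #7 and #8 also show the calling convention of the predicate returned by get_context_filter: a dict describing the context before the tactic (goal text and hypotheses), the tactic itself, a dict describing the context after it, and the argument namespace. The sketch below is a purely illustrative predicate with that shape, not a filter registered in the project:

import argparse
from typing import Any, Dict

def goal_changed(context_before: Dict[str, Any], tactic: str,
                 context_after: Dict[str, Any],
                 args: argparse.Namespace) -> bool:
    # Keep a tactic only if it visibly changed the goal text.
    return context_before["goal"].strip() != context_after["goal"].strip()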