Ejemplo n.º 1
0
def main() -> None:
    """Print how many tactics, and how many whole proofs, pass a context filter.

    Command-line arguments:
      filenames         one or more CSV files, each loaded with ``read_csvfile``.
      --context-filter  filter name handed to ``get_context_filter``
                        (default ``"default"``).

    Every ``TacticRow`` is checked with ``check_cfilter_row`` (given the row
    and the row that follows it). A proof counts as passing only when every
    tactic row seen since it started passed; a proof is closed by a non-tactic
    row whose command satisfies ``ending_proof``.
    """
    arg_parser = argparse.ArgumentParser(
        description="Tells you what percentage of tactics, "
        "and entire proofs, pass a predicate")
    arg_parser.add_argument("filenames", nargs="+", help="csv file names")
    arg_parser.add_argument("--context-filter",
                            dest="context_filter",
                            type=str,
                            default="default")
    cli = arg_parser.parse_args()

    context_filter = get_context_filter(cli.context_filter)

    tactics_passed, tactics_seen = 0, 0
    proofs_passed, proofs_seen = 0, 0

    for csv_path in cli.filenames:
        _options, csv_rows = read_csvfile(csv_path)
        inside_proof = False
        proof_is_clean = False
        # Most recent proof-starting command; kept only as a debugging aid.
        lemma_name = ""
        # NOTE(review): assumes pairwise yields (row, next_row) pairs; if so,
        # the final row is never examined as `current` — confirm.
        for current, following in pairwise(csv_rows):
            if isinstance(current, TacticRow):
                # The first tactic after non-tactic rows opens a new proof.
                if not inside_proof:
                    proof_is_clean = True
                inside_proof = True
                row_ok = check_cfilter_row(context_filter, current, following)
                tactics_seen += 1
                if row_ok:
                    tactics_passed += 1
                else:
                    proof_is_clean = False
                continue

            if ending_proof(current.command) and inside_proof:
                inside_proof = False
                proofs_seen += 1
                if proof_is_clean:
                    proofs_passed += 1
            elif possibly_starting_proof(current.command):
                lemma_name = current.command

    print("Filter {}: {}/{} tactics pass, {}/{} complete proofs pass".format(
        cli.context_filter, tactics_passed, tactics_seen,
        proofs_passed, proofs_seen))
Ejemplo n.º 2
0
def count_lengths(args: argparse.Namespace, filename: str):
    """Write a CSV of proof lengths for one source file and return its path.

    Reads scraped data from ``<prelude>/<filename>.scrape`` and the original
    commands from ``<prelude>/<filename>.lin`` (when ``args.post_linear``) or
    ``<prelude>/<filename>``. For every proof-starting command that is
    followed in the scrape by a ``ScrapedTactic``, one row
    ``[lemma statement, tactic count]`` is written to
    ``<prelude>/<filename>.csv``.

    Commands starting with a brace, bullet-only commands (``-``, ``+``,
    ``*``) and ``Proof.`` are not counted. When ``args.add_semis`` or
    ``args.post_linear`` is set, ``count_outside_matching`` adds extra steps
    for ``;`` separators.

    Returns:
        The path of the CSV file that was written.

    Raises:
        StopIteration: if a proof-starting command cannot be found in the
            remaining scraped data.
    """
    print(f"Counting {filename}")
    full_filename = args.prelude + "/" + filename
    scraped_commands = list(
        read_all_text_data(Path2(full_filename + ".scrape")))
    scraped_iter = iter(scraped_commands)
    if args.post_linear:
        original_commands = serapi_instance.load_commands_preserve(
            args, 0, full_filename + ".lin")
    else:
        original_commands = serapi_instance.load_commands_preserve(
            args, 0, full_filename)

    # The csv module requires output files to be opened with newline="";
    # otherwise rows get an extra "\r" on platforms that use "\r\n".
    with open(full_filename + ".csv", 'w', newline='') as fout:
        rowwriter = csv.writer(fout)
        lemma_statement = ""
        in_proof = False
        cur_len = 0
        for cmd in original_commands:
            if not in_proof:
                # Outside a proof only proof-starting commands matter.
                if not serapi_instance.possibly_starting_proof(cmd):
                    continue
                # Advance the scrape to the matching command.
                normalized_command = norm(cmd)
                cur_scraped = norm(next(scraped_iter))
                while cur_scraped != normalized_command:
                    cur_scraped = norm(next(scraped_iter))
                # It really starts a proof only if a scraped tactic follows.
                next_after_start = next(scraped_iter, "")
                if isinstance(next_after_start, ScrapedTactic):
                    lemma_statement = norm(cmd)
                    in_proof = True
                    cur_len = 0
                else:
                    # Push the look-ahead item back for the next search.
                    scraped_iter = itertools.chain([next_after_start],
                                                   scraped_iter)
            elif serapi_instance.ending_proof(cmd):
                rowwriter.writerow([lemma_statement.strip(), cur_len])
                cur_len = -1
                in_proof = False
            else:
                assert cur_len >= 0
                normalized = norm(cmd)
                # Braces and bullets are proof structure, not tactics. "-"
                # must come last in the class: "[*-+]" is the range "*".."+"
                # and would never match a "-" bullet.
                if re.match(r"[{}]|[*+-]*$", normalized):
                    continue
                if re.match(r"Proof\.", normalized):
                    continue
                cur_len += 1
                if args.add_semis or args.post_linear:
                    # Presumably counts ";" separators outside "{| ... |}"
                    # (helper defined elsewhere) — each adds one step.
                    cur_len += count_outside_matching(r"\{\|", r"\|\}", ";",
                                                      normalized)
    return full_filename + ".csv"
def generate_lifted(commands : List[str], coq : serapi_instance.SerapiInstance,
                    pbar : tqdm) \
    -> Iterator[str]:
    """Yield ``commands`` with each proof's body held back until the proof ends.

    A proof is detected by running each ``possibly_starting_proof`` command
    in ``coq`` and checking ``coq.proof_context``; the command is cancelled
    again right after the check. While a proof is open, commands for which
    ``lifted_vernac`` is true are yielded immediately, and all other commands
    are buffered. The buffer is yielded as a unit as soon as an
    ``ending_proof`` command is seen, including that ending command itself.
    Nested proofs get their own buffers on a stack. ``pbar`` advances by one
    for each command yielded.

    Raises:
        AssertionError: if some proof is still open once ``commands`` is
            exhausted.
        IndexError: if an ending command arrives while no proof is open.
    """
    buffers: List[List[str]] = []
    for cmd in commands:
        if serapi_instance.possibly_starting_proof(cmd):
            # Probe whether the command really opens a proof, then undo it.
            coq.run_stmt(cmd)
            if coq.proof_context:
                buffers.append([])
            coq.cancel_last()

        hold_back = bool(buffers) and not lifted_vernac(cmd)
        if hold_back:
            buffers[-1].append(cmd)
        else:
            pbar.update(1)
            yield cmd

        if serapi_instance.ending_proof(cmd):
            finished_body = buffers.pop()
            pbar.update(len(finished_body))
            yield from finished_body
    assert not buffers, f"Stack still contains {buffers}"
Ejemplo n.º 4
0
def replay_solution_vfile(args : argparse.Namespace, coq : serapi_instance.SerapiInstance,
                          model_name : str, filename : str, commands_in : List[str],
                          module_stack : List[str],
                          bar_idx : int) \
                          -> Tuple[List[str], List[str], List[DocumentBlock],
                                   int, int, int, int]:
    """Replay a saved solution file and pair each saved proof with the original.

    Reads ``{args.output_dir}/{escape_filename(filename)}.v``, passes its lines
    through ``check_solution_vfile_args`` (presumably validates the file's
    header against ``args``/``model_name`` — confirm), and runs each saved
    command in ``coq``. It builds document blocks as it goes: a
    ``VernacBlock`` for each run of non-proof commands, and a ``ProofBlock``
    for each proof. A ``ProofBlock`` holds both the saved proof's
    interactions and those from re-running the matching original proof taken
    from ``commands_in``.

    Args:
        module_stack: updated in place by ``update_module_stack`` as
            non-proof commands are consumed.
        bar_idx: selects the tqdm bar rows (``bar_idx * 2`` and
            ``bar_idx * 2 + 1``) and is passed to ``read_commands_preserve``.

    Returns:
        ``(saved commands, unconsumed remainder of commands_in, blocks,
        num_proofs, num_proofs_failed, num_proofs_completed,
        num_original_commands_run)``.

    Raises:
        SourceChangedException: if a saved non-proof command does not match
            the next original command.
        serapi_instance.CoqAnomaly: if cancelling out of a saved proof raises
            ``CoqExn``.
    """
    blocks_out: List[DocumentBlock] = []
    num_proofs = 0
    num_proofs_failed = 0
    num_proofs_completed = 0
    num_original_commands_run = 0
    in_proof = False
    # Set by a "Reset ..." command in the saved file; changes how the next
    # finished proof is handled (see NOTE below).
    skip_sync_next_lemma = False
    curLemma = ""
    curProofInters: List[TacticInteraction] = []
    curVernacCmds: List[str] = []
    with open(f"{args.output_dir}/{escape_filename(filename)}.v", 'r') as f:
        f_iter = check_solution_vfile_args(args, model_name, iter(f))
        svfile_commands = serapi_instance.read_commands_preserve(
            args, bar_idx, "".join(f_iter))
        commands_in_iter = iter(commands_in)
        for saved_command in tqdm(svfile_commands,
                                  unit="cmd",
                                  file=sys.stdout,
                                  desc="Replaying",
                                  disable=(not args.progress),
                                  leave=False,
                                  position=(bar_idx * 2),
                                  dynamic_ncols=True,
                                  bar_format=mybarfmt):
            # Snapshot the proof context before running the saved command.
            context_before = coq.fullContext if coq.full_context else FullContext(
                [])
            # Skip the command when a context exists with no subgoals left,
            # unless it is the proof-ending command.
            if not (coq.full_context != None and len(coq.fullContext.subgoals)
                    == 0 and not serapi_instance.ending_proof(saved_command)):
                coq.run_stmt(saved_command)
            if coq.full_context == None:
                # No proof is open after this command.
                if in_proof:
                    # The saved proof just ended: classify it by its last command.
                    in_proof = False
                    num_proofs += 1
                    if re.match("Qed\.", saved_command):
                        search_status = SearchStatus.SUCCESS
                        num_proofs_completed += 1
                    elif re.match("Admitted \(\*FAILURE\*\)\.", saved_command):
                        search_status = SearchStatus.FAILURE
                        num_proofs_failed += 1
                    else:
                        search_status = SearchStatus.INCOMPLETE
                    # Undo the saved proof, including its statement, so the
                    # original proof can be replayed from the same state.
                    coq.cancel_last()
                    try:
                        while coq.full_context != None:
                            coq.cancel_last()
                    except serapi_instance.CoqExn as e:
                        raise serapi_instance.CoqAnomaly(
                            f"While cancelling: {e}")

                    origProofInters = []
                    if not skip_sync_next_lemma:
                        # Replay the original proof and record each tactic with
                        # the context it ran in.
                        proof_cmds = list(
                            serapi_instance.next_proof(commands_in_iter))
                        coq.run_stmt(proof_cmds[0])
                        num_original_commands_run += len(proof_cmds)
                        for proof_cmd in tqdm(proof_cmds[1:],
                                              unit="tac",
                                              file=sys.stdout,
                                              desc="Running original proof",
                                              disable=(not args.progress),
                                              leave=False,
                                              position=(bar_idx * 2) + 1,
                                              dynamic_ncols=True,
                                              bar_format=mybarfmt):
                            context_before_orig = coq.fullContext
                            coq.run_stmt(proof_cmd)
                            origProofInters.append(
                                TacticInteraction(proof_cmd,
                                                  context_before_orig))
                        blocks_out.append(
                            ProofBlock(curLemma, ".".join(module_stack),
                                       search_status, curProofInters,
                                       origProofInters))
                        curVernacCmds = []
                    else:
                        # NOTE(review): `proof_cmds` is not assigned on this
                        # path. It is either the previous lemma's list (stale)
                        # or unbound (NameError) if no proof has been synced
                        # yet. Confirm the intended behavior.
                        for proof_cmd in proof_cmds:
                            coq.run_stmt(proof_cmd)
                        skip_sync_next_lemma = False

                else:
                    # A non-proof command: it must match the next original command.
                    if re.match("Reset .*\.", saved_command):
                        # A Reset consumes no original command.
                        skip_sync_next_lemma = True
                        continue
                    loaded_command = next(commands_in_iter)
                    update_module_stack(saved_command, module_stack)
                    # NOTE(review): "\s*" also matches the empty string, so a
                    # space is inserted between every character. Differing
                    # whitespace *amounts* compare equal, but whitespace
                    # present vs. absent does not. "\s+" may have been intended.
                    if not re.sub("\s*", " ", loaded_command.strip()) == \
                       re.sub("\s*", " ", saved_command.strip()):
                        raise SourceChangedException(
                            f"Command {loaded_command} doesn't match {saved_command}"
                        )
                    curVernacCmds.append(loaded_command)
            else:
                # A proof is open: record the saved tactic and its prior context.
                if not in_proof:
                    # First command of a new proof: flush pending vernac commands.
                    in_proof = True
                    curLemma = saved_command
                    blocks_out.append(VernacBlock(curVernacCmds))
                    curProofInters = []
                curProofInters.append(
                    TacticInteraction(saved_command, context_before))
        assert not in_proof
        if curVernacCmds:
            blocks_out.append(VernacBlock(curVernacCmds))
        return svfile_commands, list(commands_in_iter), blocks_out,\
            num_proofs, num_proofs_failed, num_proofs_completed, num_original_commands_run
Ejemplo n.º 5
0
def reinforce(args: argparse.Namespace) -> None:
    """Train a Q-estimator by reinforcement on lemmas from environment files.

    Builds a replay memory from the transitions in ``args.scrape_file``,
    loads the tactic predictor from ``args.predictor_weights``, and creates a
    ``FeaturesQEstimator``. The estimator is then restored from
    ``args.start_from`` or, failing that, pre-trained when ``args.pretrain``
    is set. Training then proceeds in one of two modes:

    * With ``args.proof``: only that lemma is trained on. Exactly one
      environment file must be given.
    * Otherwise: every lemma in ``args.environment_files`` that does not look
      proof-relevant is trained on.

    Weights are saved to ``args.out_weights`` after training, and also when
    SIGINT is received (just before exiting).
    """
    # Load the scraped (demonstrated) samples, the proof environment
    # commands, and the predictor
    replay_memory = assign_rewards(
        dataloader.tactic_transitions_from_file(args.scrape_file,
                                                args.buffer_size))

    predictor = predict_tactic.loadPredictorByFile(args.predictor_weights)

    q_estimator = FeaturesQEstimator(args.learning_rate,
                                     args.batch_step,
                                     args.gamma)
    # On Ctrl-C, save the current weights, then exit.
    signal.signal(
        signal.SIGINT,
        lambda signal, frame:
        progn(q_estimator.save_weights(
            args.out_weights, args),  # type: ignore
              exit()))
    if args.start_from:
        # The first saved element (the estimator's name) is not needed.
        _q_estimator_name, *saved = \
            torch.load(args.start_from)
        q_estimator.load_saved_state(*saved)
    elif args.pretrain:
        pre_train(args, q_estimator,
                  dataloader.tactic_transitions_from_file(
                      args.scrape_file, args.buffer_size * 10))

    # NOTE(review): reinforce_lemma receives this fixed gamma, while the
    # estimator above was built with args.gamma — confirm they may differ.
    epsilon = 0.3
    gamma = 0.9

    if args.proof is not None:
        assert len(args.environment_files) == 1, \
            "Can't use multiple env files with --proof!"
        env_commands = serapi_instance.load_commands_preserve(
            args, 0, args.prelude / args.environment_files[0])

        with serapi_instance.SerapiContext(
                ["sertop", "--implicit"],
                serapi_instance.get_module_from_filename(
                    args.environment_files[0]),
                str(args.prelude)) as coq:
            coq.quiet = True
            coq.verbose = args.verbose
            rest_commands, run_commands = coq.run_into_next_proof(env_commands)
            lemma_statement = run_commands[-1]
            # Skip whole proofs until the requested lemma is open.
            while coq.cur_lemma_name != args.proof:
                if not rest_commands:
                    eprint(f"Couldn't find lemma {args.proof}! Exiting...")
                    return
                rest_commands, _ = coq.finish_proof(rest_commands)
                rest_commands, run_commands = coq.run_into_next_proof(
                    rest_commands)
                lemma_statement = run_commands[-1]
            reinforce_lemma(args, predictor, q_estimator, coq,
                            lemma_statement,
                            epsilon, gamma, replay_memory)
            q_estimator.save_weights(args.out_weights, args)
    else:
        for env_file in args.environment_files:
            env_commands = serapi_instance.load_commands_preserve(
                args, 0, args.prelude / env_file)
            num_proofs = len([cmd for cmd in env_commands
                              if cmd.strip() == "Qed."
                              or cmd.strip() == "Defined."])
            rest_commands = env_commands
            # Everything run so far, replayed into a fresh Coq instance
            # after an anomaly forces a restart.
            all_run_commands: List[str] = []
            with tqdm(total=num_proofs, disable=(not args.progress),
                      leave=True, desc=env_file.stem) as pbar:
                while rest_commands:
                    with serapi_instance.SerapiContext(
                            ["sertop", "--implicit"],
                            serapi_instance.get_module_from_filename(
                                env_file),
                            str(args.prelude),
                            log_outgoing_messages=args.log_outgoing_messages) \
                            as coq:
                        coq.quiet = True
                        coq.verbose = args.verbose
                        for command in all_run_commands:
                            coq.run_stmt(command)

                        while rest_commands:
                            rest_commands, run_commands = \
                                coq.run_into_next_proof(rest_commands)
                            if not rest_commands:
                                break
                            all_run_commands += run_commands[:-1]
                            lemma_statement = run_commands[-1]

                            # Check if the definition is
                            # proof-relevant. If it is, then finishing
                            # subgoals doesn't necessarily mean you've
                            # solved the problem, so don't try to
                            # train on it.
                            proof_relevant = False
                            for cmd in rest_commands:
                                if serapi_instance.ending_proof(cmd):
                                    if cmd.strip() == "Defined.":
                                        proof_relevant = True
                                    break
                            proof_relevant = proof_relevant or \
                                bool(re.match(r"\s*Derive", lemma_statement))
                            for sample in replay_memory:
                                sample.graph_node = None
                            if not proof_relevant:
                                try:
                                    reinforce_lemma(args, predictor,
                                                    q_estimator,
                                                    coq,
                                                    lemma_statement,
                                                    epsilon, gamma,
                                                    replay_memory)
                                except serapi_instance.CoqAnomaly:
                                    if args.log_anomalies:
                                        with args.log_anomalies.open('a') as f:
                                            traceback.print_exc(file=f)
                                    if args.hardfail:
                                        eprint(
                                            "Hit an anomaly! "
                                            "Quitting due to --hardfail")
                                        raise
                                    eprint(
                                        "Hit an anomaly! "
                                        "Restarting coq instance")
                                    # Retry this lemma in a fresh instance.
                                    rest_commands.insert(0, lemma_statement)
                                    break
                            pbar.update(1)
                            rest_commands, run_commands = \
                                coq.finish_proof(rest_commands)
                            all_run_commands.append(lemma_statement)
                            all_run_commands += run_commands

            q_estimator.save_weights(args.out_weights, args)
def linearize_commands(args: argparse.Namespace, file_idx: int,
                       commands_sequence: Iterable[str],
                       coq: serapi_instance.SerapiInstance, filename: str,
                       relative_filename: str, skip_nochange_tac: bool,
                       known_failures: List[List[str]]):
    """Run ``commands_sequence`` in ``coq`` and yield a linearized version of it.

    Commands that leave no focused goals are run and yielded unchanged. When
    a command opens goals, it is cancelled and the whole proof, up to and
    including its ``ending_proof`` command, is collected. The statement is
    yielded, and the body is handed to ``linearize_proof``.

    There are two fallbacks to the original commands:

    * If ``[relative_filename, theorem_name]`` is in ``known_failures``, the
      original commands are replayed without attempting linearization.
    * If linearization raises one of the handled errors, the proof is
      aborted and its original commands are replayed and yielded instead.
      With ``args.hardfail`` the error is re-raised rather than handled.

    ``file_idx`` is not used by this function.

    Raises:
        CoqAnomaly: with ``[relative_filename, theorem_name]`` as its argument
            when Coq reports an anomaly while linearizing or replaying a proof.
        AssertionError: if the first command is missing or empty.
    """
    commands_iter = iter(commands_sequence)
    command = next(commands_iter, None)
    assert command, "Got an empty sequence!"
    while command:
        # Run up to the next proof
        while coq.count_fg_goals() == 0:
            coq.run_stmt(command)
            if coq.count_fg_goals() == 0:
                yield command
                command = next(commands_iter, None)
                if not command:
                    return

        # Cancel the proof starting command so that we're right before the proof
        coq.cancel_last()

        # Pull the entire proof from the lifter into command_batch
        command_batch = []
        while command and not serapi_instance.ending_proof(command):
            command_batch.append(command)
            command = next(commands_iter, None)
        # Get the QED on there too.
        if command:
            command_batch.append(command)

        # Now command_batch contains everything through the next
        # Qed/Defined.
        theorem_statement = serapi_instance.kill_comments(command_batch.pop(0))
        theorem_name = theorem_statement.split(":")[0].strip()
        coq.run_stmt(theorem_statement)
        yield theorem_statement
        if [relative_filename, theorem_name] in known_failures:
            eprint("Skipping {}".format(theorem_name), guard=args.verbose >= 1)
            for known_cmd in command_batch:
                coq.run_stmt(known_cmd)
                yield known_cmd
            command = next(commands_iter, None)
            continue

        # Extract the tactic from a leading "Proof with <tac>." so handle_with
        # can use it. This might not be super robust. The batch is empty when
        # the input ends right after the statement.
        match = (re.fullmatch(r"\s*Proof with (.*)\.\s*", command_batch[0])
                 if command_batch else None)
        with_tactic = match.group(1) if match and match.group(1) else ""

        orig = command_batch[:]
        command_batch = list(prelinear_desugar_tacs(command_batch))
        try:
            try:
                batch_handled = list(handle_with(command_batch, with_tactic))
                # Materialize fully first so nothing is yielded from a
                # linearization that later fails.
                linearized_commands = list(
                    linearize_proof(coq, theorem_name, batch_handled,
                                    args.verbose, skip_nochange_tac))
                yield from linearized_commands
            except (BadResponse, CoqExn, LinearizerCouldNotLinearize,
                    ParseError, TimeoutError, NoSuchGoalError):
                if args.verbose:
                    eprint("Aborting current proof linearization!")
                    eprint("Proof of:\n{}\nin file {}".format(
                        theorem_name, filename))
                    eprint()
                if args.hardfail:
                    raise
                # Restart the proof and replay the original commands.
                coq.run_stmt("Abort.")
                coq.run_stmt(theorem_statement)
                for orig_cmd in orig:
                    if orig_cmd:
                        coq.run_stmt(orig_cmd, timeout=360)
                        yield orig_cmd
        except CoqAnomaly:
            eprint(
                f"Anomaly! Raising with {[relative_filename, theorem_name]}",
                guard=args.verbose >= 1)
            raise CoqAnomaly([relative_filename, theorem_name])

        command = next(commands_iter, None)
def linearize_proof(coq: serapi_instance.SerapiInstance,
                    theorem_name: str,
                    command_batch: List[str],
                    verbose: int = 0,
                    skip_nochange_tac: bool = False) -> Iterable[str]:
    """Linearize a proof body, running each emitted command in ``coq``.

    Commands are popped from the front of ``command_batch``, which is mutated
    in place and also receives re-inserted continuation commands. Commands
    are handled as follows:

    * Sequenced tactics ``a; b`` and ``a; [b | c | ...]`` are split so that
      ``a`` runs alone. The continuation is queued for the first goal and
      kept on a pending stack for the remaining goals, which are focused
      with generated ``{`` / ``}``.
    * Goal selectors ``n: tac`` are folded into the pending stack.
    * Leading comments are yielded separately.
    * Commands containing ``||`` or ``&&`` outside parentheses are run as-is.

    ``theorem_name`` and ``skip_nochange_tac`` are not used in this body.

    Raises:
        LinearizerCouldNotLinearize: for goal selectors below 2, selectors
            beyond the known pending commands, and merges involving a
            ``<..>`` entry.
    """
    # One entry per open focus level, holding what to run on that level's
    # next goal: a command to repeat (str), a per-goal list of commands
    # (list), or None. An entry ending in "<..>" appears to mean "repeat for
    # the remaining goals".
    pending_commands_stack: List[Union[str, List[str], None]] = []
    while command_batch:
        # No focused goals: close finished focus levels, moving on to sibling
        # goals where they remain.
        while coq.count_fg_goals() == 0:
            indentation = "  " * (len(pending_commands_stack))
            if len(pending_commands_stack) == 0:
                # Outside any focus: run and yield only "Transparent" and
                # proof-ending commands; anything else left is dropped.
                while command_batch:
                    command = command_batch.pop(0)
                    if "Transparent" in command or \
                       serapi_instance.ending_proof(command):
                        coq.run_stmt(command)
                        yield command
                return
            coq.run_stmt("}")
            yield indentation + "}"
            if coq.count_fg_goals() > 0:
                # Another goal at this level: focus it and queue its command.
                coq.run_stmt("{")
                yield indentation + "{"
                pending_commands = pending_commands_stack[-1]
                if isinstance(pending_commands, list):
                    next_cmd, *rest_cmd = pending_commands
                    dotdotmatch = re.match("(.*)<\.\.>",
                                           next_cmd,
                                           flags=re.DOTALL)
                    # Validate the remaining per-goal commands ("<..>" entries
                    # are exempt).
                    for cmd in rest_cmd:
                        dotdotmatch = re.match("(.*)<\.\.>",
                                               cmd,
                                               flags=re.DOTALL)
                        if dotdotmatch:
                            continue
                        assert serapi_instance.isValidCommand(cmd), \
                            f"\"{cmd}\" is not a valid command"
                    # When rest_cmd is empty the loop above did not run, so
                    # dotdotmatch still refers to next_cmd here.
                    if (not rest_cmd) and dotdotmatch:
                        # NOTE(review): this assignment is overwritten right
                        # after the if/else (rest_cmd is empty, so the stack
                        # entry becomes None). The "<..>" command is thus not
                        # kept for later goals — confirm whether that is intended.
                        pending_commands_stack[-1] = [next_cmd]
                        assert serapi_instance.isValidCommand(dotdotmatch.group(1)), \
                            f"\"{dotdotmatch.group(1)}\" is not a valid command"
                        command_batch.insert(0, dotdotmatch.group(1))
                    else:
                        assert serapi_instance.isValidCommand(next_cmd), \
                            f"\"{next_cmd}\" is not a valid command"
                        command_batch.insert(0, next_cmd)
                    pending_commands_stack[-1] = rest_cmd if rest_cmd else None
                    pass
                elif pending_commands:
                    # NOTE(review): the message reports `command` (an earlier
                    # loop variable), not `pending_commands`.
                    assert serapi_instance.isValidCommand(pending_commands), \
                        f"\"{command}\" is not a valid command"
                    command_batch.insert(0, pending_commands)
            else:
                # This level is finished: pop it and, if it still had a
                # command list, merge that list into the parent level.
                popped = pending_commands_stack.pop()
                # NOTE(review): merging is skipped when only one level remains
                # after the pop — confirm "> 1" (rather than "> 0") is intended.
                if isinstance(popped, list) and len(popped) > 0 and len(
                        pending_commands_stack) > 1:
                    if pending_commands_stack[-1] is None:
                        pending_commands_stack[-1] = popped
                    elif isinstance(pending_commands_stack[-1], list):
                        if isinstance(popped, list) and "<..>" in popped[-1]:
                            raise LinearizerCouldNotLinearize
                        pending_commands_stack[
                            -1] = popped + pending_commands_stack[-1]
        command = command_batch.pop(0)
        assert serapi_instance.isValidCommand(command), \
            f"command is \"{command}\", command_batch is {command_batch}"
        # Peel off leading comments and yield them on their own.
        comment_before_command = ""
        command_proper = command
        while re.fullmatch("\s*\(\*.*", command_proper, flags=re.DOTALL):
            next_comment, command_proper = \
                split_to_next_matching("\(\*", "\*\)", command_proper)
            command_proper = command_proper[1:]
            comment_before_command += next_comment
        if comment_before_command:
            yield comment_before_command
        # Input bullets/braces are dropped; this function emits its own
        # "{"/"}". re.match anchors only at the start, so any command
        # beginning with -, +, *, { or } is skipped (presumably bullets were
        # already split off by earlier desugaring — confirm).
        if re.match("\s*[*+-]+\s*|\s*[{}]\s*", command):
            continue

        command = serapi_instance.kill_comments(command_proper)
        if verbose >= 2:
            eprint(f"Linearizing command \"{command}\"")

        # "n: tac" (goal selector): record tac as pending for goal n at the
        # current level instead of running it now.
        goal_selector_match = re.fullmatch(r"\s*(\d+)\s*:\s*(.*)\.\s*",
                                           command)
        if goal_selector_match:
            goal_num = int(goal_selector_match.group(1))
            rest = goal_selector_match.group(2)
            if goal_num < 2:
                raise LinearizerCouldNotLinearize
            if pending_commands_stack[-1] is None:
                completed_rest = rest + "."
                assert serapi_instance.isValidCommand(rest + "."),\
                    f"\"{completed_rest}\" is not a valid command in {command}"
                pending_commands_stack[-1] = ["idtac."] * (goal_num - 2) + [
                    completed_rest
                ]
            elif isinstance(pending_commands_stack[-1], str):
                pending_cmd = pending_commands_stack[-1]
                pending_commands_stack[-1] = [pending_cmd] * (goal_num - 2) + \
                    [rest + " ; " + pending_cmd] + [pending_cmd + "<..>"]
            else:
                assert isinstance(pending_commands_stack[-1], list)
                pending_cmd_lst = pending_commands_stack[-1]
                try:
                    old_selected_cmd = pending_cmd_lst[goal_num - 2]
                except IndexError:
                    raise LinearizerCouldNotLinearize
                # Append "; rest" to the command already pending for that goal.
                match = re.match("(.*)\.$", old_selected_cmd, re.DOTALL)
                assert match, f"\"{old_selected_cmd}\" doesn't match!"
                cmd_before_period = unwrap(match).group(1)
                new_selected_cmd = f"{cmd_before_period} ; {rest}."
                pending_cmd_lst[goal_num - 2] = new_selected_cmd
            continue

        # Commands with "||" or "&&" outside parentheses are run unchanged.
        if split_by_char_outside_matching("\(", "\)", "\|\||&&", command):
            coq.run_stmt(command)
            yield command
            continue

        # Strip one layer of leading parentheses.
        if re.match("\(", command.strip()):
            inside_parens, after_parens = split_to_next_matching(
                '\(', '\)', command)
            command = inside_parens.strip()[1:-1] + after_parens

        # Extend this to include "by \(" as an opener if you don't
        # desugar all the "by"s
        semi_match = split_by_char_outside_matching("try \(|\(|\{\|",
                                                    "\)|\|\}", "\s*;\s*",
                                                    command)
        if semi_match:
            # "base; rest": run base alone, then schedule rest per goal.
            base_command, rest = semi_match
            rest = rest.lstrip()[1:]
            coq.run_stmt(base_command + ".")
            indentation = "  " * (len(pending_commands_stack) + 1)
            yield indentation + base_command.strip() + "."

            if re.match("\(", rest) and not \
               split_by_char_outside_matching("\(", "\)", "\|\|", rest):
                inside_parens, after_parens = split_to_next_matching(
                    '\(', '\)', rest)
                rest = inside_parens[1:-1] + after_parens
            bracket_match = re.match("\[", rest.strip())
            if bracket_match:
                # "base; [c1 | c2 | ...]": one clause per resulting goal;
                # empty clauses become idtac.
                bracket_command, rest_after_bracket = \
                    split_to_next_matching('\[', '\]', rest)
                rest_after_bracket = rest_after_bracket.lstrip()[1:]
                clauses = multisplit_matching("\[", "\]", "(?<!\|)\|(?!\|)",
                                              bracket_command.strip()[1:-1])
                commands_list = [
                    cmd.strip() if cmd.strip().strip(".") != "" else "idtac" +
                    cmd.strip() for cmd in clauses
                ]
                assert commands_list, command
                # A clause ending in ".." is repeated to fill the goals not
                # covered by the other clauses.
                dotdotpat = re.compile(r"(.*)\.\.($|\W)")
                ending_dotdot_match = dotdotpat.match(commands_list[-1])
                if ending_dotdot_match:
                    commands_list = commands_list[:-1] + \
                        ([ending_dotdot_match.group(1)] *
                         (coq.count_fg_goals() -
                          len(commands_list) + 1))
                else:
                    starting_dotdot_match = dotdotpat.match(commands_list[0])
                    if starting_dotdot_match:
                        starting_tac = starting_dotdot_match.group(1)
                        commands_list = [starting_tac] * (coq.count_fg_goals() -
                                                          len(commands_list) + 1)\
                                                          + commands_list[1:]
                    else:
                        # NOTE(review): idx counts from commands_list[1], so
                        # the slices below look off by one (they replace
                        # element idx, not idx+1). They also repeat
                        # command_case with its ".." instead of
                        # middle_dotdot_match.group(1) — confirm.
                        for idx, command_case in enumerate(
                                commands_list[1:-1]):
                            middle_dotdot_match = dotdotpat.match(command_case)
                            if middle_dotdot_match:
                                commands_list = \
                                    commands_list[:idx] + \
                                    [command_case] * (coq.count_fg_goals() -
                                                      len(commands_list) + 1) + \
                                                      commands_list[idx+1:]
                                break
                if rest_after_bracket.strip():
                    command_remainders = [
                        cmd + ";" + rest_after_bracket for cmd in commands_list
                    ]
                else:
                    command_remainders = [cmd + "." for cmd in commands_list]
                assert serapi_instance.isValidCommand(command_remainders[0]), \
                    f"\"{command_remainders[0]}\" is not a valid command"
                # First goal's command runs next; the others wait on the stack.
                command_batch.insert(0, command_remainders[0])
                if coq.count_fg_goals() > 1:
                    for command in command_remainders[1:]:
                        assert serapi_instance.isValidCommand(command), \
                            f"\"{command}\" is not a valid command"
                    pending_commands_stack.append(command_remainders[1:])
                    coq.run_stmt("{")
                    yield indentation + "{"
            else:
                # Plain "base; rest": rest is run on every resulting goal.
                if coq.count_fg_goals() > 0:
                    assert serapi_instance.isValidCommand(rest), \
                        f"\"{rest}\" is not a valid command, from {command}"
                    command_batch.insert(0, rest)
                if coq.count_fg_goals() > 1:
                    assert serapi_instance.isValidCommand(rest), \
                        f"\"{rest}\" is not a valid command, from {command}"
                    pending_commands_stack.append(rest)
                    coq.run_stmt("{")
                    yield indentation + "{"
        elif coq.count_fg_goals() > 0:
            # A plain tactic: run it, and open a focus if it created goals.
            coq.run_stmt(command)
            indentation = "  " * (len(pending_commands_stack) +
                                  1) if command.strip() != "Proof." else ""
            yield indentation + command.strip()
            if coq.count_fg_goals() > 1:
                pending_commands_stack.append(None)
                coq.run_stmt("{")
                yield indentation + "{"
    pass
Ejemplo n.º 8
0
def linearize_proof(coq: serapi_instance.SerapiInstance,
                    theorem_name: str,
                    command_batch: List[str],
                    debug: bool = False,
                    skip_nochange_tac: bool = False) -> Iterable[str]:
    """Run a proof through Coq, yielding a semicolon-free linearized script.

    Commands are consumed from the front of ``command_batch`` (the list is
    mutated in place). Compound tactics of the form ``base ; rest`` and
    ``base ; [t1 | t2 | ...]`` are split apart: ``base`` is run once and the
    remaining tactic(s) are queued for each resulting subgoal, which is
    wrapped in ``{ ... }``. Leading comments are yielded verbatim; bullet and
    brace commands from the input are dropped, since the output uses its own
    braces. Once no goals and no open braces remain, only ``Transparent`` and
    proof-ending commands from the rest of the batch are kept.

    Args:
        coq: the running Coq instance; each yielded tactic/brace is also
            executed on it via ``run_stmt`` (yielded comments are not).
        theorem_name: not read by this function.
        command_batch: proof commands still to process; split-off tactics are
            pushed back onto its front.
        debug: if true, print each command before it is linearized.
        skip_nochange_tac: not read by this function.

    Yields:
        The linearized proof, one command (or comment) per item, indented by
        brace depth.
    """
    # One entry per open "{ ... }" level. None: nothing to replay on the
    # sibling goals; str: replay the same tactic on every sibling goal;
    # list: one tactic per remaining sibling goal, in order.
    pending_commands_stack: List[Union[str, List[str], None]] = []
    while command_batch:
        # Close finished subgoals (opening the next sibling and queueing its
        # pending tactic) until there is a focused goal to work on.
        while coq.count_fg_goals() == 0:
            indentation = "  " * (len(pending_commands_stack))
            if len(pending_commands_stack) == 0:
                # Proof is finished: keep only Transparent / proof-ending
                # commands from whatever is left, then stop.
                while command_batch:
                    command = command_batch.pop(0)
                    if "Transparent" in command or \
                       serapi_instance.ending_proof(command):
                        coq.run_stmt(command)
                        yield command
                return
            coq.run_stmt("}")
            yield indentation + "}"
            if coq.count_fg_goals() > 0:
                coq.run_stmt("{")
                yield indentation + "{"
                pending_commands = pending_commands_stack[-1]
                if isinstance(pending_commands, list):
                    next_cmd, *rest_cmd = pending_commands
                    # A trailing "<..>" marker (added by the goal-selector
                    # case below) tags the tactic meant for the remaining goals.
                    dotdotmatch = re.match(r"(.*)<\.\.>", next_cmd)
                    if (not rest_cmd) and dotdotmatch:
                        # NOTE(review): this assignment is overwritten just
                        # after the if/else, so a "<..>" tactic is replayed
                        # only once. Simply keeping it would also let it be
                        # merged into the enclosing level when this level is
                        # popped below -- confirm intended semantics first.
                        pending_commands_stack[-1] = [next_cmd]
                        command_batch.insert(0, dotdotmatch.group(1))
                    else:
                        command_batch.insert(0, next_cmd)
                    pending_commands_stack[-1] = rest_cmd if rest_cmd else None
                elif pending_commands:
                    command_batch.insert(0, pending_commands)
            else:
                # No sibling goals left at this level: pop it, handing any
                # unused per-goal tactics to the enclosing level (only when
                # at least two levels remain).
                popped = pending_commands_stack.pop()
                if isinstance(popped, list) and len(popped) > 0 and len(
                        pending_commands_stack) > 1:
                    if pending_commands_stack[-1] is None:
                        pending_commands_stack[-1] = popped
                    elif isinstance(pending_commands_stack[-1], list):
                        pending_commands_stack[
                            -1] = popped + pending_commands_stack[-1]
        command = command_batch.pop(0)
        assert serapi_instance.isValidCommand(command), \
            f"command is \"{command}\", command_batch is {command_batch}"
        # Peel off comments and emit them verbatim ahead of the command.
        comment_before_command = ""
        command_proper = command
        while "(*" in command_proper:
            next_comment, command_proper = \
                split_to_next_matching(r"\(\*", r"\*\)", command_proper)
            command_proper = command_proper[1:]
            comment_before_command += next_comment
        if comment_before_command:
            yield comment_before_command
        # Input bullets and braces are dropped; the output uses its own braces.
        # NOTE(review): this tests `command` (comments included), not
        # `command_proper`, so a bullet preceded by a comment does not match.
        if re.match(r"\s*[*+-]+\s*|\s*[{}]\s*", command):
            continue

        command = command_proper
        if debug:
            eprint(f"Linearizing command \"{command}\"")

        # "N: tac." goal selector: instead of running tac now, record it as
        # the tactic for goal N among the pending sibling goals (an existing
        # per-goal list is left unchanged).
        goal_selector_match = re.match(r"\s*(\d*)\s*:\s*(.*)\.\s*", command)
        if goal_selector_match:
            goal_num = int(goal_selector_match.group(1))
            rest = goal_selector_match.group(2)
            assert goal_num >= 2
            if pending_commands_stack[-1] is None:
                pending_commands_stack[-1] = ["idtac."] * (goal_num - 2) + [
                    rest + "."
                ]
            elif isinstance(pending_commands_stack[-1], str):
                pending_cmd = pending_commands_stack[-1]
                pending_commands_stack[-1] = [pending_cmd] * (goal_num - 2) + \
                    [rest + " ; " + pending_cmd] + [pending_cmd + "<..>"]
            else:
                assert isinstance(pending_commands_stack[-1], list)
            continue

        # Commands with "||" or "&&" outside parentheses are run unsplit.
        if split_by_char_outside_matching(r"\(", r"\)", r"\|\||&&", command):
            coq.run_stmt(command)
            yield command
            continue

        # Drop one layer of enclosing parentheses: "(tac) ..." -> "tac ...".
        if re.match(r"\(", command.strip()):
            inside_parens, after_parens = split_to_next_matching(
                r'\(', r'\)', command)
            command = inside_parens.strip()[1:-1] + after_parens

        # Extend this to include "by \(" as an opener if you don't
        # desugar all the "by"s
        semi_match = split_by_char_outside_matching(r"try \(|\(|\{\|",
                                                    r"\)|\|\}", r"\s*;\s*",
                                                    command)
        if semi_match:
            # "base ; rest": run base once, then queue rest for the subgoals.
            base_command, rest = semi_match
            rest = rest.lstrip()[1:]
            coq.run_stmt(base_command + ".")
            indentation = "  " * (len(pending_commands_stack) + 1)
            yield indentation + base_command.strip() + "."

            if re.match(r"\(", rest) and not \
               split_by_char_outside_matching(r"\(", r"\)", r"\|\|", rest):
                inside_parens, after_parens = split_to_next_matching(
                    r'\(', r'\)', rest)
                rest = inside_parens[1:-1] + after_parens
            bracket_match = re.match(r"\[", rest.strip())
            if bracket_match:
                # "base ; [t1 | t2 | ...] after": one branch per subgoal.
                bracket_command, rest_after_bracket = \
                    split_to_next_matching(r'\[', r'\]', rest)
                rest_after_bracket = rest_after_bracket.lstrip()[1:]
                clauses = multisplit_matching(r"\[", r"\]", r"(?<!\|)\|(?!\|)",
                                              bracket_command.strip()[1:-1])
                # Empty branches become "idtac".
                commands_list = [
                    cmd.strip() if cmd.strip().strip(".") != "" else "idtac" +
                    cmd.strip() for cmd in clauses
                ]
                # Expand a "tac .." branch so it covers every goal not
                # claimed by the other branches, dropping the ".." marker.
                dotdotpat = re.compile(r"(.*)\.\.($|\W)")
                ending_dotdot_match = dotdotpat.match(commands_list[-1])
                if ending_dotdot_match:
                    commands_list = commands_list[:-1] + \
                        ([ending_dotdot_match.group(1)] *
                         (coq.count_fg_goals() -
                          len(commands_list) + 1))
                else:
                    starting_dotdot_match = dotdotpat.match(commands_list[0])
                    if starting_dotdot_match:
                        starting_tac = starting_dotdot_match.group(1)
                        commands_list = [starting_tac] * (coq.count_fg_goals() -
                                                          len(commands_list) + 1)\
                                                          + commands_list[1:]
                    else:
                        # start=1 so idx is the branch's real position in
                        # commands_list (the slice skips the first element).
                        for idx, command_case in enumerate(
                                commands_list[1:-1], start=1):
                            middle_dotdot_match = dotdotpat.match(command_case)
                            if middle_dotdot_match:
                                num_repeats = (coq.count_fg_goals() -
                                               len(commands_list) + 1)
                                commands_list = (
                                    commands_list[:idx] +
                                    [middle_dotdot_match.group(1)] *
                                    num_repeats +
                                    commands_list[idx + 1:])
                                break
                if rest_after_bracket.strip():
                    command_remainders = [
                        cmd + ";" + rest_after_bracket for cmd in commands_list
                    ]
                else:
                    command_remainders = [cmd + "." for cmd in commands_list]
                # The first branch is processed next; the others are queued
                # for the sibling goals.
                command_batch.insert(0, command_remainders[0])
                if coq.count_fg_goals() > 1:
                    pending_commands_stack.append(command_remainders[1:])
                    coq.run_stmt("{")
                    yield indentation + "{"
            else:
                # Plain "base ; rest": the same rest for every subgoal.
                if coq.count_fg_goals() > 0:
                    command_batch.insert(0, rest)
                if coq.count_fg_goals() > 1:
                    pending_commands_stack.append(rest)
                    coq.run_stmt("{")
                    yield indentation + "{"
        elif coq.count_fg_goals() > 0:
            # Simple tactic: run it, and open a brace if it split the goal.
            coq.run_stmt(command)
            indentation = "  " * (len(pending_commands_stack) +
                                  1) if command.strip() != "Proof." else ""
            yield indentation + command.strip()
            if coq.count_fg_goals() > 1:
                pending_commands_stack.append(None)
                coq.run_stmt("{")
                yield indentation + "{"