Esempio n. 1
0
def process_statement(args: argparse.Namespace,
                      coq: serapi_instance.SerapiInstance, command: str,
                      result_file: TextIO) -> None:
    """Write ``command`` as one JSON line to ``result_file``, then run it in ``coq``.

    When a proof is open (``coq.proof_context`` is truthy) the line is a JSON
    object holding the relevant lemmas, the previous tactics, the proof
    context (``context.to_dict()``) and the tactic itself.  Otherwise the
    line is just the JSON-encoded command string.  The command is then run
    with ``timeout=600``.

    Which lemmas count as relevant is chosen by ``args.relevant_lemmas``:
    ``"local"`` uses ``coq.local_lemmas`` except its last entry (newlines
    flattened to spaces), ``"hammer"`` uses ``coq.get_hammer_premises()`` and
    ``"searchabout"`` uses ``coq.get_lemmas_about_head()``.

    Raises:
        AssertionError: if ``args.relevant_lemmas`` is any other value.
    """
    if coq.proof_context:
        prev_tactics = coq.prev_tactics
        context = coq.proof_context
        mode = args.relevant_lemmas
        if mode == "local":
            relevant_lemmas = [
                lemma.replace("\n", " ") for lemma in coq.local_lemmas[:-1]
            ]
        elif mode == "hammer":
            relevant_lemmas = coq.get_hammer_premises()
        elif mode == "searchabout":
            relevant_lemmas = coq.get_lemmas_about_head()
        else:
            # Explicit raise instead of `assert False`: an assert is stripped
            # under `python -O`, which would leave `relevant_lemmas` unbound.
            raise AssertionError(mode)

        result_file.write(
            json.dumps({
                "relevant_lemmas": relevant_lemmas,
                "prev_tactics": prev_tactics,
                "context": context.to_dict(),
                "tactic": command
            }))
    else:
        result_file.write(json.dumps(command))
    result_file.write("\n")

    coq.run_stmt(command, timeout=600)
Esempio n. 2
0
 def run_to_next_vernac(coq: serapi_instance.SerapiInstance, pbar: tqdm,
                        initial_full_context: FullContext,
                        lemma_statement: str) -> List[TacticInteraction]:
     """Run ``lemma_statement`` and then its original proof from ``commands_in``.

     Commands are popped from the front of the enclosing ``commands_in`` and
     run while ``coq.full_context`` is not None.  Each one, including the
     command that closes the proof, is recorded as a ``TacticInteraction``
     with the context read before it ran, and ``pbar`` is advanced once per
     command.  If the last command is ``Defined.``, a ``Reset <lemma>.`` line,
     the statement and the body are appended to the solution file.  The
     statement and body are then added to the enclosing ``commands_run``.

     On any exception, the statement and every command consumed so far are
     pushed back onto the front of ``commands_in`` and the exception is
     re-raised.

     ``initial_full_context`` is not read in this body.

     Returns:
         The recorded interactions, in execution order.
     """
     nonlocal commands_run
     nonlocal commands_in
     coq.run_stmt(lemma_statement)
     original_tactics: List[TacticInteraction] = []
     lemma_name = serapi_instance.lemma_name_from_statement(lemma_statement)
     try:
         # NOTE(review): the loop condition reads `full_context` while the body
         # reads `fullContext`; confirm both name the same state.
         while coq.full_context != None:
             next_in_command = commands_in.pop(0)
             context_before = coq.fullContext
             original_tactics.append(
                 TacticInteraction(next_in_command, context_before))
             coq.run_stmt(next_in_command)
             pbar.update(1)
         body_tactics = [t.tactic for t in original_tactics]
         # NOTE(review): if the statement did not open a proof the loop never
         # runs and `next_in_command` is unbound here (UnboundLocalError,
         # which the handler below then restores from and re-raises).
         if next_in_command.strip() == "Defined.":
             append_to_solution_vfile(
                 args.output_dir, args.filename,
                 [f"Reset {lemma_name}.", lemma_statement] + body_tactics)
         commands_run.append(lemma_statement)
         commands_run += body_tactics
     except:
         # Restore the consumed input so the caller can retry/recover.
         commands_in = [lemma_statement] + \
             [t.tactic for t in original_tactics] \
             + commands_in
         raise
     return original_tactics
Esempio n. 3
0
def process_statement(coq: serapi_instance.SerapiInstance, command: str,
                      result_file: TextIO) -> None:
    """Write ``command`` (plus proof state, inside a proof) to ``result_file``, then run it.

    Commands starting with ``{`` or ``}`` (after optional whitespace) are
    not written, but they are still run.  Inside a proof, the previous
    tactics, hypotheses and goals are written via ``format_context``,
    followed by ``format_tactic(command)``.  Outside a proof, the command is
    written with newlines escaped as a backslash followed by ``n``, and then
    a ``-----`` separator line.
    """
    # Raw string: "\s" is an invalid escape in a plain literal
    # (DeprecationWarning; SyntaxWarning since Python 3.12).
    if not re.match(r"\s*[{}]\s*", command):
        if coq.proof_context:
            prev_tactics = coq.prev_tactics
            prev_hyps = coq.hypotheses
            prev_goal = coq.goals
            result_file.write(
                format_context(prev_tactics, prev_hyps, prev_goal, ""))
            result_file.write(format_tactic(command))
        else:
            subbed_command = command.replace("\n", "\\n")
            result_file.write(subbed_command + "\n-----\n")
    coq.run_stmt(command)
def generate_lifted(commands : List[str], coq : serapi_instance.SerapiInstance,
                    pbar : tqdm) \
    -> Iterator[str]:
    """Yield ``commands`` with lifted vernacular moved ahead of the proof containing it.

    A command accepted by ``serapi_instance.possibly_starting_proof`` is
    probed: it is run in ``coq`` and cancelled again, and a new buffer is
    pushed if a proof context was open in between.  While any buffer is
    open, commands go into the innermost buffer unless ``lifted_vernac``
    accepts them, in which case they are yielded at once.  A command
    matching ``serapi_instance.ending_proof`` pops the innermost buffer and
    yields its contents.  ``pbar`` advances by one per yielded command.

    Raises:
        AssertionError: if buffers are still open once ``commands`` runs out.
    """
    lemma_stack: List[List[str]] = []
    for cmd in commands:
        if serapi_instance.possibly_starting_proof(cmd):
            # Probe whether this command actually opens a proof, then undo it.
            coq.run_stmt(cmd)
            opens_proof = bool(coq.proof_context)
            coq.cancel_last()
            if opens_proof:
                lemma_stack.append([])
        held_back = bool(lemma_stack) and not lifted_vernac(cmd)
        if held_back:
            lemma_stack[-1].append(cmd)
        else:
            pbar.update(1)
            yield cmd
        if serapi_instance.ending_proof(cmd):
            buffered = lemma_stack.pop()
            pbar.update(len(buffered))
            yield from buffered
    assert len(lemma_stack) == 0, f"Stack still contains {lemma_stack}"
Esempio n. 5
0
 def run_to_next_proof(coq: serapi_instance.SerapiInstance,
                       pbar: tqdm) -> str:
     """Run commands from ``commands_in`` until one opens a proof.

     Commands are popped from the front of the enclosing ``commands_in`` and
     run with ``timeout=60`` until ``coq.full_context`` becomes truthy or the
     input runs out.  Every command updates ``module_stack`` and advances
     ``pbar``.  Commands after which no proof is open are collected.  If any
     were collected, they are appended to ``blocks_out`` as one
     ``VernacBlock``, added to ``commands_run`` and written to the solution
     file.

     Returns:
         The last command run (the proof-opening one, unless input ran out).
     """
     nonlocal commands_run
     nonlocal commands_in
     nonlocal blocks_out
     nonlocal module_stack
     vernacs: List[str] = []
     assert not coq.full_context
     while not coq.full_context and len(commands_in) > 0:
         next_in_command = commands_in.pop(0)
         # Longer timeout for vernac stuff (especially requires)
         coq.run_stmt(next_in_command, timeout=60)
         if not coq.full_context:
             vernacs.append(next_in_command)
         update_module_stack(next_in_command, module_stack)
         pbar.update(1)
     if len(vernacs) > 0:
         blocks_out.append(VernacBlock(vernacs))
         commands_run += vernacs
         append_to_solution_vfile(args.output_dir, args.filename, vernacs)
     # NOTE(review): if `commands_in` was empty on entry the loop never ran
     # and `next_in_command` is unbound here (UnboundLocalError).
     return next_in_command
def run_prediction(coq: serapi_instance.SerapiInstance,
                   prediction: str) -> Tuple[str, str, Optional[Exception]]:
    """Try ``prediction`` in ``coq`` and immediately undo it.

    Leading ``-``/``+``/``*`` characters are stripped from the prediction
    first.  ``coq.quiet`` is set to True for the attempt and set to False
    afterwards, whatever happens.

    Returns:
        ``(stripped_prediction, proof_context, None)`` on success, or
        ``(stripped_prediction, "", error)`` if Coq raised a ``ParseError``,
        ``LexError``, ``BadResponse``, ``CoqExn`` or ``TimeoutError``.

    Raises:
        AssertionError: if the proof context after running is not a ``str``.
    """
    stripped = prediction.lstrip("-+*")
    coq.quiet = True
    try:
        coq.run_stmt(stripped)
        resulting_context = coq.proof_context
        coq.cancel_last()
        assert isinstance(resulting_context, str)
    except (ParseError, LexError, BadResponse, CoqExn, TimeoutError) as err:
        return (stripped, "", err)
    else:
        return (stripped, resulting_context, None)
    finally:
        coq.quiet = False
Esempio n. 7
0
def tryPrediction(
        args: argparse.Namespace, coq: serapi_instance.SerapiInstance,
        g: SearchGraph,
        predictionNode: LabeledNode) -> Tuple[FullContext, int, int, int]:
    """Run the tactic stored on ``predictionNode`` and fix up goal bracing.

    The tactic runs with ``timeout=30`` when ``coq.use_hammer`` is truthy and
    ``timeout=5`` otherwise.  While no goal is focused and the proof is not
    complete, a ``}`` is run and the node is colored blue.  Afterwards a
    ``{`` is run if more than one goal is focused, or if at least one goal is
    focused after closing a brace.  ``coq.quiet`` is set to True and left
    that way.  ``args`` is not read.

    Returns:
        ``(coq.fullContext, statements_run, braces_closed, braces_opened)``
        where ``braces_opened`` is 0 or 1.
    """
    coq.quiet = True
    coq.run_stmt(predictionNode.prediction,
                 timeout=(30 if coq.use_hammer else 5))
    num_stmts = 1
    subgoals_closed = 0
    # Pop out of every finished "{ }" block.
    while coq.count_fg_goals() == 0 and not completed_proof(coq):
        g.setNodeColor(predictionNode, "blue")
        coq.run_stmt("}")
        subgoals_closed += 1
        num_stmts += 1
    needs_open = (coq.count_fg_goals() > 1 or
                  (coq.count_fg_goals() > 0 and subgoals_closed > 0))
    subgoals_opened = int(needs_open)
    if needs_open:
        coq.run_stmt("{")
        num_stmts += 1
    return coq.fullContext, num_stmts, subgoals_closed, subgoals_opened
Esempio n. 8
0
def replay_solution_vfile(args : argparse.Namespace, coq : serapi_instance.SerapiInstance,
                          model_name : str, filename : str, commands_in : List[str],
                          module_stack : List[str],
                          bar_idx : int) \
                          -> Tuple[List[str], List[str], List[DocumentBlock],
                                   int, int, int, int]:
    """Replay a previously written solution ``.v`` file and rebuild its blocks.

    Reads ``{args.output_dir}/{escape_filename(filename)}.v`` and runs each
    saved command in ``coq``.  Each time a saved proof closes, that proof is
    cancelled.  The matching original proof from ``commands_in`` is then
    re-run, and both are recorded in a ``ProofBlock``.  Saved commands
    outside proofs are checked against ``commands_in`` (modulo whitespace)
    and grouped into ``VernacBlock``s.

    Args:
        args: parsed options; ``output_dir`` and ``progress`` are read here,
            and ``args`` is passed on to helpers.
        coq: the Coq session every command is run in.
        model_name: passed to ``check_solution_vfile_args``.
        filename: source file name used to locate the solution file.
        commands_in: the original source commands.
        module_stack: updated in place via ``update_module_stack``; joined
            with "." to label each ``ProofBlock``.
        bar_idx: chooses the tqdm positions (``bar_idx * 2`` and
            ``bar_idx * 2 + 1``) and is passed to the command reader.

    Returns:
        ``(saved_commands, unconsumed_original_commands, blocks, num_proofs,
        num_proofs_failed, num_proofs_completed, num_original_commands_run)``.

    Raises:
        SourceChangedException: if a saved vernacular command does not match
            the corresponding original command.
        serapi_instance.CoqAnomaly: if cancelling a replayed proof fails.
    """
    blocks_out: List[DocumentBlock] = []
    num_proofs = 0
    num_proofs_failed = 0
    num_proofs_completed = 0
    num_original_commands_run = 0
    in_proof = False
    # Set by a saved "Reset ..." command: the next proof that closes is not
    # paired with a fresh proof from commands_in.
    skip_sync_next_lemma = False
    curLemma = ""
    curProofInters: List[TacticInteraction] = []
    curVernacCmds: List[str] = []
    with open(f"{args.output_dir}/{escape_filename(filename)}.v", 'r') as f:
        f_iter = check_solution_vfile_args(args, model_name, iter(f))
        svfile_commands = serapi_instance.read_commands_preserve(
            args, bar_idx, "".join(f_iter))
        commands_in_iter = iter(commands_in)
        for saved_command in tqdm(svfile_commands,
                                  unit="cmd",
                                  file=sys.stdout,
                                  desc="Replaying",
                                  disable=(not args.progress),
                                  leave=False,
                                  position=(bar_idx * 2),
                                  dynamic_ncols=True,
                                  bar_format=mybarfmt):
            context_before = coq.fullContext if coq.full_context else FullContext(
                [])
            # Once a proof has no subgoals left, saved commands other than the
            # proof-ending one are not sent to Coq (they are still recorded).
            if not (coq.full_context is not None
                    and len(coq.fullContext.subgoals) == 0
                    and not serapi_instance.ending_proof(saved_command)):
                coq.run_stmt(saved_command)
            if coq.full_context is None:
                if in_proof:
                    # A replayed proof has just been closed.
                    in_proof = False
                    num_proofs += 1
                    if re.match(r"Qed\.", saved_command):
                        search_status = SearchStatus.SUCCESS
                        num_proofs_completed += 1
                    elif re.match(r"Admitted \(\*FAILURE\*\)\.", saved_command):
                        search_status = SearchStatus.FAILURE
                        num_proofs_failed += 1
                    else:
                        search_status = SearchStatus.INCOMPLETE
                    # Undo the whole replayed proof, statement included...
                    coq.cancel_last()
                    try:
                        while coq.full_context is not None:
                            coq.cancel_last()
                    except serapi_instance.CoqExn as e:
                        raise serapi_instance.CoqAnomaly(
                            f"While cancelling: {e}") from e

                    origProofInters = []
                    if not skip_sync_next_lemma:
                        # ...then re-run the original proof from commands_in.
                        proof_cmds = list(
                            serapi_instance.next_proof(commands_in_iter))
                        coq.run_stmt(proof_cmds[0])
                        num_original_commands_run += len(proof_cmds)
                        for proof_cmd in tqdm(proof_cmds[1:],
                                              unit="tac",
                                              file=sys.stdout,
                                              desc="Running original proof",
                                              disable=(not args.progress),
                                              leave=False,
                                              position=(bar_idx * 2) + 1,
                                              dynamic_ncols=True,
                                              bar_format=mybarfmt):
                            context_before_orig = coq.fullContext
                            coq.run_stmt(proof_cmd)
                            origProofInters.append(
                                TacticInteraction(proof_cmd,
                                                  context_before_orig))
                        blocks_out.append(
                            ProofBlock(curLemma, ".".join(module_stack),
                                       search_status, curProofInters,
                                       origProofInters))
                        curVernacCmds = []
                    else:
                        # NOTE(review): `proof_cmds` is not reassigned on this
                        # path, so the original proof synced most recently is
                        # re-run; UnboundLocalError if none was synced yet.
                        for proof_cmd in proof_cmds:
                            coq.run_stmt(proof_cmd)
                        skip_sync_next_lemma = False

                else:
                    if re.match(r"Reset .*\.", saved_command):
                        skip_sync_next_lemma = True
                        continue
                    loaded_command = next(commands_in_iter)
                    update_module_stack(saved_command, module_stack)
                    # Compare modulo whitespace: each whitespace run becomes a
                    # single space.
                    if re.sub(r"\s+", " ", loaded_command.strip()) != \
                       re.sub(r"\s+", " ", saved_command.strip()):
                        raise SourceChangedException(
                            f"Command {loaded_command} doesn't match {saved_command}"
                        )
                    curVernacCmds.append(loaded_command)
            else:
                if not in_proof:
                    in_proof = True
                    curLemma = saved_command
                    blocks_out.append(VernacBlock(curVernacCmds))
                    curProofInters = []
                curProofInters.append(
                    TacticInteraction(saved_command, context_before))
        assert not in_proof
        if curVernacCmds:
            blocks_out.append(VernacBlock(curVernacCmds))
        return svfile_commands, list(commands_in_iter), blocks_out,\
            num_proofs, num_proofs_failed, num_proofs_completed, num_original_commands_run
Esempio n. 9
0
def reinforce_lemma(args: argparse.Namespace,
                    predictor: tactic_predictor.TacticPredictor,
                    estimator: q_estimator.QEstimator,
                    coq: serapi_instance.SerapiInstance,
                    lemma_statement: str,
                    epsilon: float,
                    gamma: float,
                    memory: List[LabeledTransition]) -> None:
    """Run ``args.num_episodes`` exploration episodes on the current lemma and train ``estimator``.

    Each episode takes up to ``args.episode_length`` steps.  At each step
    ``predictor`` proposes ``args.num_predictions`` tactics.  With
    probability ``epsilon`` they are tried in random order; otherwise they
    are tried in decreasing order of ``estimator`` score.  Candidates are
    handled as follows:

    - A tactic that raises ``ParseError``/``CoqExn``/``TimeoutError`` is
      recorded with reward -500.
    - A tactic whose resulting context ``contextSurjective`` relates to a
      context already seen this episode is cancelled and recorded with
      reward -50.
    - The first other tactic is kept and rewarded via ``assign_reward``.

    When ``coq.goals`` becomes ``""`` the node is marked QED and the
    episode's transitions are repeated in ``memory``.  After each episode a
    batch sampled from ``memory`` is scored with ``assign_scores`` (given
    ``gamma``) and used to train ``estimator``.  The lemma is then restarted
    by running ``Admitted.``, ``Reset <lemma>.`` and ``lemma_statement``.  The
    search graph is drawn to ``args.graphs_dir/<lemma_name>.png`` at the end.

    Args:
        memory: replay buffer; extended in place with every transition.
    """
    lemma_name = coq.cur_lemma_name
    graph = ReinforceGraph(lemma_name)
    for episode in trange(args.num_episodes, disable=(not args.progress),
                          leave=False):
        cur_node = graph.start_node
        # Contexts visited on this episode's path; revisits are penalized.
        proof_contexts_seen = [unwrap(coq.proof_context)]
        episode_memory = []
        for t in range(args.episode_length):
            with print_time("Getting predictions", guard=args.verbose):
                context_before = coq.tactic_context(coq.local_lemmas[:-1])
                proof_context_before = unwrap(coq.proof_context)
                predictions = predictor.predictKTactics(
                    context_before, args.num_predictions)
            if random.random() < epsilon:
                # Explore: try the predictions in a random order.
                ordered_actions = [p.prediction for p in
                                   random.sample(predictions,
                                                 len(predictions))]
            else:
                # Exploit: try predictions by decreasing estimated score.
                with print_time("Picking actions using q_estimator",
                                guard=args.verbose):
                    q_choices = zip(estimator(
                        [(context_before, prediction.prediction)
                         for prediction in predictions]),
                                    [p.prediction for p in predictions])
                    ordered_actions = [p[1] for p in
                                       sorted(q_choices,
                                              key=lambda q: q[0],
                                              reverse=True)]

            with print_time("Running actions", guard=args.verbose):
                action = None
                for try_action in ordered_actions:
                    try:
                        coq.run_stmt(try_action)
                        proof_context_after = unwrap(coq.proof_context)
                        if any([serapi_instance.contextSurjective(
                                proof_context_after, path_context)
                                for path_context in proof_contexts_seen]):
                            # Revisited context: undo and penalize (-50).
                            coq.cancel_last()
                            transition = assign_failed_reward(
                                context_before.relevant_lemmas,
                                context_before.prev_tactics,
                                proof_context_before,
                                proof_context_after,
                                try_action,
                                -50)
                            assert transition.reward < 2000
                            memory.append(transition)
                            if args.ghosts:
                                ghost_node = graph.addGhostTransition(
                                    cur_node, try_action)
                                transition.graph_node = ghost_node
                            continue
                        action = try_action
                        break
                    except (serapi_instance.ParseError,
                            serapi_instance.CoqExn,
                            serapi_instance.TimeoutError):
                        # Tactic rejected by Coq: penalize (-500), the
                        # "after" context is the unchanged "before" one.
                        transition = assign_failed_reward(
                            context_before.relevant_lemmas,
                            context_before.prev_tactics,
                            proof_context_before,
                            proof_context_before,
                            try_action,
                            -500)
                        assert transition.reward < 2000
                        memory.append(transition)
                        if args.ghosts:
                            ghost_node = graph.addGhostTransition(cur_node,
                                                                  try_action)
                            transition.graph_node = ghost_node
                        pass
                if action is None:
                    # We'll hit this case if we tried all of the
                    # predictions, and none worked
                    graph.setNodeColor(cur_node, "red")
                    break  # Break from episode

            # Record the accepted step; proof_context_after is from the
            # successful run of `action` above.
            transition = assign_reward(context_before.relevant_lemmas,
                                       context_before.prev_tactics,
                                       proof_context_before,
                                       proof_context_after,
                                       action)
            cur_node = graph.addTransition(cur_node, action,
                                           transition.reward)
            transition.graph_node = cur_node
            assert transition.reward < 2000
            episode_memory.append(transition)
            memory.append(transition)
            proof_contexts_seen.append(proof_context_after)

            if coq.goals == "":
                # Proof finished: mark it and over-sample this episode's
                # transitions (success_repetitions copies in total).
                graph.mkQED(cur_node)
                memory += (episode_memory * (args.success_repetitions - 1))
                break

        with print_time("Assigning scores", guard=args.verbose):
            transition_samples = sample_batch(memory,
                                              args.batch_size)
            training_samples = assign_scores(transition_samples,
                                             estimator, predictor,
                                             args.num_predictions,
                                             gamma,
                                             # Passing this graph
                                             # in so we can
                                             # maintain a record
                                             # of the most recent
                                             # q score estimates
                                             # in the graph
                                             graph)
        with print_time("Training", guard=args.verbose):
            estimator.train(training_samples)

        # Clean up episode
        coq.run_stmt("Admitted.")
        coq.run_stmt(f"Reset {lemma_name}.")
        coq.run_stmt(lemma_statement)
    graphpath = (args.graphs_dir / lemma_name).with_suffix(".png")
    graph.draw(str(graphpath))
    pass
def linearize_commands(args: argparse.Namespace, file_idx: int,
                       commands_sequence: Iterable[str],
                       coq: serapi_instance.SerapiInstance, filename: str,
                       relative_filename: str, skip_nochange_tac: bool,
                       known_failures: List[List[str]]):
    """Yield a file's commands, with each proof linearized where possible.

    Commands outside proofs are run and yielded unchanged.  For each proof,
    the statement is run and yielded.  If ``[relative_filename,
    theorem_name]`` is in ``known_failures``, the original proof commands are
    run and yielded as-is.  Otherwise the proof is desugared
    (``prelinear_desugar_tacs``), has any ``Proof with tac.`` applied
    (``handle_with``), and is linearized by ``linearize_proof``.  If
    linearization raises one of the handled errors, the proof is aborted and
    the original commands are replayed instead, unless ``args.hardfail`` is
    set, in which case the error is re-raised.

    ``file_idx`` is not read; ``filename`` is only used in verbose messages.

    Raises:
        AssertionError: if ``commands_sequence`` is empty.
        CoqAnomaly: carrying ``[relative_filename, theorem_name]`` if Coq
            raises ``CoqAnomaly`` while a proof is linearized or replayed.
    """
    commands_iter = iter(commands_sequence)
    command = next(commands_iter, None)
    assert command, "Got an empty sequence!"
    while command:
        # Run up to the next proof
        while coq.count_fg_goals() == 0:
            coq.run_stmt(command)
            if coq.count_fg_goals() == 0:
                yield command
                command = next(commands_iter, None)
                if not command:
                    return

        # Cancel the proof starting command so that we're right before the proof
        coq.cancel_last()

        # Pull the entire proof from the lifter into command_batch
        command_batch = []
        while command and not serapi_instance.ending_proof(command):
            command_batch.append(command)
            command = next(commands_iter, None)
        # Get the QED on there too.
        if command:
            command_batch.append(command)

        # Now command_batch contains everything through the next
        # Qed/Defined.
        theorem_statement = serapi_instance.kill_comments(command_batch.pop(0))
        theorem_name = theorem_statement.split(":")[0].strip()
        coq.run_stmt(theorem_statement)
        yield theorem_statement
        if [relative_filename, theorem_name] in known_failures:
            eprint("Skipping {}".format(theorem_name), guard=args.verbose >= 1)
            for known_cmd in command_batch:
                coq.run_stmt(known_cmd)
                yield known_cmd
            command = next(commands_iter, None)
            continue

        # This might not be super robust?
        # The input may end right after the statement, leaving no proof body.
        with_match = (re.fullmatch(r"\s*Proof with (.*)\.\s*", command_batch[0])
                      if command_batch else None)
        if with_match and with_match.group(1):
            with_tactic = with_match.group(1)
        else:
            with_tactic = ""

        orig = command_batch[:]
        command_batch = list(prelinear_desugar_tacs(command_batch))
        try:
            try:
                batch_handled = list(handle_with(command_batch, with_tactic))
                linearized_commands = list(
                    linearize_proof(coq, theorem_name, batch_handled,
                                    args.verbose, skip_nochange_tac))
                yield from linearized_commands
            except (BadResponse, CoqExn, LinearizerCouldNotLinearize,
                    ParseError, TimeoutError, NoSuchGoalError):
                if args.verbose:
                    eprint("Aborting current proof linearization!")
                    eprint("Proof of:\n{}\nin file {}".format(
                        theorem_name, filename))
                    eprint()
                if args.hardfail:
                    raise
                # Fall back to the original, unlinearized proof.
                coq.run_stmt("Abort.")
                coq.run_stmt(theorem_statement)
                for orig_cmd in orig:
                    if orig_cmd:
                        coq.run_stmt(orig_cmd, timeout=360)
                        yield orig_cmd
        except CoqAnomaly as e:
            eprint(
                f"Anomaly! Raising with {[relative_filename, theorem_name]}",
                guard=args.verbose >= 1)
            raise CoqAnomaly([relative_filename, theorem_name]) from e

        command = next(commands_iter, None)
def linearize_proof(coq: serapi_instance.SerapiInstance,
                    theorem_name: str,
                    command_batch: List[str],
                    verbose: int = 0,
                    skip_nochange_tac: bool = False) -> Iterable[str]:
    """Run a proof in ``coq`` while yielding a semicolon-free version of it.

    Commands are consumed from the front of ``command_batch``, which is
    mutated in place.  Each command is handled as follows:

    - ``base ; rest``: ``base`` is run and yielded on its own.  ``rest``, or
      the matching clause of a ``[ a | b | ... ]`` list, is queued for each
      resulting goal, and each goal is wrapped in ``{`` / ``}``.
    - Goal selectors ``n: tac``: folded into the commands pending for goal
      ``n``.
    - Leading comments: yielded separately.
    - Commands matching the bullet/brace pattern: skipped.

    Every yielded command has already been run in ``coq``.

    Args:
        coq: the Coq session.
        theorem_name: name of the theorem (not read in this body).
        command_batch: the proof commands after the statement; consumed.
        verbose: at ``>= 2`` each command is echoed via ``eprint``.
        skip_nochange_tac: accepted for compatibility; not read in this body.

    Yields:
        The linearized (indented) commands in execution order.

    Raises:
        LinearizerCouldNotLinearize: for goal selectors or ``<..>`` markers
            this routine cannot place.
    """
    # One entry per open "{" block, describing what to run on the block's
    # remaining sibling goals when each is focused:
    #   None      -> nothing queued; later commands in the batch handle them
    #   str       -> queue this command for every remaining goal
    #   List[str] -> queue the i-th entry for the i-th remaining goal; an entry
    #                ending in "<..>" appears to mark a command meant for all
    #                further goals
    pending_commands_stack: List[Union[str, List[str], None]] = []
    while command_batch:
        # Close finished blocks and focus the next pending goal, if any.
        while coq.count_fg_goals() == 0:
            indentation = "  " * (len(pending_commands_stack))
            if len(pending_commands_stack) == 0:
                # No open blocks and no goals: only "Transparent" and
                # proof-ending commands among the rest are still run.
                while command_batch:
                    command = command_batch.pop(0)
                    if "Transparent" in command or \
                       serapi_instance.ending_proof(command):
                        coq.run_stmt(command)
                        yield command
                return
            coq.run_stmt("}")
            yield indentation + "}"
            if coq.count_fg_goals() > 0:
                coq.run_stmt("{")
                yield indentation + "{"
                pending_commands = pending_commands_stack[-1]
                if isinstance(pending_commands, list):
                    next_cmd, *rest_cmd = pending_commands
                    dotdotmatch = re.match(r"(.*)<\.\.>",
                                           next_cmd,
                                           flags=re.DOTALL)
                    for cmd in rest_cmd:
                        if re.match(r"(.*)<\.\.>", cmd, flags=re.DOTALL):
                            continue
                        assert serapi_instance.isValidCommand(cmd), \
                            f"\"{cmd}\" is not a valid command"
                    if (not rest_cmd) and dotdotmatch:
                        # NOTE(review): the entry is cleared below even for a
                        # trailing "<..>" marker, so the repeated command is
                        # queued for one goal only; confirm this is intended.
                        assert serapi_instance.isValidCommand(dotdotmatch.group(1)), \
                            f"\"{dotdotmatch.group(1)}\" is not a valid command"
                        command_batch.insert(0, dotdotmatch.group(1))
                    else:
                        assert serapi_instance.isValidCommand(next_cmd), \
                            f"\"{next_cmd}\" is not a valid command"
                        command_batch.insert(0, next_cmd)
                    pending_commands_stack[-1] = rest_cmd if rest_cmd else None
                elif pending_commands:
                    assert serapi_instance.isValidCommand(pending_commands), \
                        f"\"{pending_commands}\" is not a valid command"
                    command_batch.insert(0, pending_commands)
            else:
                # The block has no goals left: drop it.  With more than one
                # block still open, unused list entries are merged into the
                # enclosing block's pending commands (unless it holds a str).
                popped = pending_commands_stack.pop()
                if isinstance(popped, list) and len(popped) > 0 and len(
                        pending_commands_stack) > 1:
                    if pending_commands_stack[-1] is None:
                        pending_commands_stack[-1] = popped
                    elif isinstance(pending_commands_stack[-1], list):
                        if isinstance(popped, list) and "<..>" in popped[-1]:
                            raise LinearizerCouldNotLinearize
                        pending_commands_stack[
                            -1] = popped + pending_commands_stack[-1]
        command = command_batch.pop(0)
        assert serapi_instance.isValidCommand(command), \
            f"command is \"{command}\", command_batch is {command_batch}"
        # Peel off leading comments and yield them on their own.
        comment_before_command = ""
        command_proper = command
        while re.fullmatch(r"\s*\(\*.*", command_proper, flags=re.DOTALL):
            next_comment, command_proper = \
                split_to_next_matching(r"\(\*", r"\*\)", command_proper)
            command_proper = command_proper[1:]
            comment_before_command += next_comment
        if comment_before_command:
            yield comment_before_command
        # Input bullets/braces are dropped; braces are emitted by this routine.
        if re.match(r"\s*[*+-]+\s*|\s*[{}]\s*", command):
            continue

        command = serapi_instance.kill_comments(command_proper)
        if verbose >= 2:
            eprint(f"Linearizing command \"{command}\"")

        # Goal selector "n: tac." -- attach tac to the pending work of goal n.
        goal_selector_match = re.fullmatch(r"\s*(\d+)\s*:\s*(.*)\.\s*",
                                           command)
        if goal_selector_match:
            goal_num = int(goal_selector_match.group(1))
            rest = goal_selector_match.group(2)
            if goal_num < 2:
                raise LinearizerCouldNotLinearize
            if pending_commands_stack[-1] is None:
                completed_rest = rest + "."
                assert serapi_instance.isValidCommand(rest + "."),\
                    f"\"{completed_rest}\" is not a valid command in {command}"
                pending_commands_stack[-1] = ["idtac."] * (goal_num - 2) + [
                    completed_rest
                ]
            elif isinstance(pending_commands_stack[-1], str):
                # NOTE(review): here the selected tactic is placed *before*
                # the pending one, unlike the list case below; confirm.
                pending_cmd = pending_commands_stack[-1]
                pending_commands_stack[-1] = [pending_cmd] * (goal_num - 2) + \
                    [rest + " ; " + pending_cmd] + [pending_cmd + "<..>"]
            else:
                assert isinstance(pending_commands_stack[-1], list)
                pending_cmd_lst = pending_commands_stack[-1]
                try:
                    old_selected_cmd = pending_cmd_lst[goal_num - 2]
                except IndexError:
                    raise LinearizerCouldNotLinearize
                match = re.match(r"(.*)\.$", old_selected_cmd, re.DOTALL)
                assert match, f"\"{old_selected_cmd}\" doesn't match!"
                cmd_before_period = unwrap(match).group(1)
                new_selected_cmd = f"{cmd_before_period} ; {rest}."
                pending_cmd_lst[goal_num - 2] = new_selected_cmd
            continue

        # Top-level "||" / "&&": run the command unchanged.
        if split_by_char_outside_matching(r"\(", r"\)", r"\|\||&&", command):
            coq.run_stmt(command)
            yield command
            continue

        # Unwrap a leading parenthesized group.
        if re.match(r"\(", command.strip()):
            inside_parens, after_parens = split_to_next_matching(
                r'\(', r'\)', command)
            command = inside_parens.strip()[1:-1] + after_parens

        # Extend this to include "by \(" as an opener if you don't
        # desugar all the "by"s
        semi_match = split_by_char_outside_matching(r"try \(|\(|\{\|",
                                                    r"\)|\|\}", r"\s*;\s*",
                                                    command)
        if semi_match:
            base_command, rest = semi_match
            rest = rest.lstrip()[1:]
            coq.run_stmt(base_command + ".")
            indentation = "  " * (len(pending_commands_stack) + 1)
            yield indentation + base_command.strip() + "."

            if re.match(r"\(", rest) and not \
               split_by_char_outside_matching(r"\(", r"\)", r"\|\|", rest):
                inside_parens, after_parens = split_to_next_matching(
                    r'\(', r'\)', rest)
                rest = inside_parens[1:-1] + after_parens
            bracket_match = re.match(r"\[", rest.strip())
            if bracket_match:
                bracket_command, rest_after_bracket = \
                    split_to_next_matching(r'\[', r'\]', rest)
                rest_after_bracket = rest_after_bracket.lstrip()[1:]
                clauses = multisplit_matching(r"\[", r"\]", r"(?<!\|)\|(?!\|)",
                                              bracket_command.strip()[1:-1])
                # Empty clauses become "idtac".
                commands_list = [
                    cmd.strip() if cmd.strip().strip(".") != "" else "idtac" +
                    cmd.strip() for cmd in clauses
                ]
                assert commands_list, command
                # A clause written "tac.." is repeated to cover the goals not
                # matched by explicit clauses.
                dotdotpat = re.compile(r"(.*)\.\.($|\W)")
                ending_dotdot_match = dotdotpat.match(commands_list[-1])
                if ending_dotdot_match:
                    commands_list = commands_list[:-1] + \
                        ([ending_dotdot_match.group(1)] *
                         (coq.count_fg_goals() -
                          len(commands_list) + 1))
                else:
                    starting_dotdot_match = dotdotpat.match(commands_list[0])
                    if starting_dotdot_match:
                        starting_tac = starting_dotdot_match.group(1)
                        commands_list = [starting_tac] * \
                            (coq.count_fg_goals() - len(commands_list) + 1) \
                            + commands_list[1:]
                    else:
                        # idx is the clause's real position in commands_list.
                        for idx, command_case in enumerate(
                                commands_list[1:-1], start=1):
                            middle_dotdot_match = dotdotpat.match(command_case)
                            if middle_dotdot_match:
                                commands_list = \
                                    commands_list[:idx] + \
                                    [middle_dotdot_match.group(1)] * \
                                    (coq.count_fg_goals() -
                                     len(commands_list) + 1) + \
                                    commands_list[idx + 1:]
                                break
                if rest_after_bracket.strip():
                    command_remainders = [
                        cmd + ";" + rest_after_bracket for cmd in commands_list
                    ]
                else:
                    command_remainders = [cmd + "." for cmd in commands_list]
                assert serapi_instance.isValidCommand(command_remainders[0]), \
                    f"\"{command_remainders[0]}\" is not a valid command"
                command_batch.insert(0, command_remainders[0])
                if coq.count_fg_goals() > 1:
                    for remainder in command_remainders[1:]:
                        assert serapi_instance.isValidCommand(remainder), \
                            f"\"{remainder}\" is not a valid command"
                    pending_commands_stack.append(command_remainders[1:])
                    coq.run_stmt("{")
                    yield indentation + "{"
            else:
                if coq.count_fg_goals() > 0:
                    assert serapi_instance.isValidCommand(rest), \
                        f"\"{rest}\" is not a valid command, from {command}"
                    command_batch.insert(0, rest)
                if coq.count_fg_goals() > 1:
                    assert serapi_instance.isValidCommand(rest), \
                        f"\"{rest}\" is not a valid command, from {command}"
                    pending_commands_stack.append(rest)
                    coq.run_stmt("{")
                    yield indentation + "{"
        elif coq.count_fg_goals() > 0:
            coq.run_stmt(command)
            indentation = "  " * (len(pending_commands_stack) +
                                  1) if command.strip() != "Proof." else ""
            yield indentation + command.strip()
            if coq.count_fg_goals() > 1:
                pending_commands_stack.append(None)
                coq.run_stmt("{")
                yield indentation + "{"
Esempio n. 12
0
def linearize_proof(coq: serapi_instance.SerapiInstance,
                    theorem_name: str,
                    command_batch: List[str],
                    debug: bool = False,
                    skip_nochange_tac: bool = False) -> Iterable[str]:
    """Replay a proof in ``coq`` while rewriting it without semicolons.

    This is a generator: commands are sent to ``coq`` via ``run_stmt`` and
    the emitted proof text (indented by brace nesting depth) is yielded so
    the caller can collect the linearized proof script.

    A tactic of the form ``base ; rest`` is split: ``base`` is run on its
    own and ``rest`` is queued for each resulting subgoal.  When more than
    one foreground goal exists, each subgoal is wrapped in ``{`` ... ``}``.
    ``base ; [a | b | c]`` dispatches one clause per subgoal.  A clause that
    ends with ``..`` (last clause), starts the list, or sits in the middle is
    repeated so that the clauses cover every foreground goal.  Goal
    selectors ``n: tac.`` (``n >= 2``) are folded into the commands queued
    for later goals.  Comments in front of a command are yielded on their
    own.  Commands that start with a bullet or a brace are skipped.

    Args:
        coq: Running SerAPI session the proof is replayed in.
        theorem_name: Not used by this function.
        command_batch: Proof commands to linearize.  Consumed in place:
            commands are popped from the front, and generated sub-commands
            are pushed back onto the front.
        debug: If true, print each command being linearized via ``eprint``.
        skip_nochange_tac: Not used by this function.

    Yields:
        Lines of the linearized proof script.
    """
    # One entry per open "{ ... }" nesting level, describing what is still
    # queued for the remaining subgoals of that level:
    #   None      -- nothing queued (goals came from a plain tactic);
    #   str       -- the "rest" of a "base ; rest" tactic, run on every goal;
    #   List[str] -- one queued command per remaining goal, front first.
    pending_commands_stack: List[Union[str, List[str], None]] = []
    while command_batch:
        # Close finished subgoal blocks and open the next sibling subgoal,
        # queueing whatever is pending for it.
        while coq.count_fg_goals() == 0:
            indentation = "  " * (len(pending_commands_stack))
            if len(pending_commands_stack) == 0:
                # No open blocks and no goals: of what is left in the batch,
                # only "Transparent" and proof-ending commands are run and
                # emitted; everything else is discarded.
                while command_batch:
                    command = command_batch.pop(0)
                    if "Transparent" in command or \
                       serapi_instance.ending_proof(command):
                        coq.run_stmt(command)
                        yield command
                return
            coq.run_stmt("}")
            yield indentation + "}"
            if coq.count_fg_goals() > 0:
                coq.run_stmt("{")
                yield indentation + "{"
                pending_commands = pending_commands_stack[-1]
                if isinstance(pending_commands, list):
                    next_cmd, *rest_cmd = pending_commands
                    dotdotmatch = re.match(r"(.*)<\.\.>", next_cmd)
                    if (not rest_cmd) and dotdotmatch:
                        # NOTE(review): this store is overwritten by the
                        # assignment after the if/else, so a "<..>" entry
                        # is used for one goal only.  Left unchanged because
                        # leftover lists are merged into the parent level
                        # when a level is popped (below); confirm intent.
                        pending_commands_stack[-1] = [next_cmd]
                        command_batch.insert(0, dotdotmatch.group(1))
                    else:
                        command_batch.insert(0, next_cmd)
                    pending_commands_stack[-1] = rest_cmd if rest_cmd else None
                    pass
                elif pending_commands:
                    command_batch.insert(0, pending_commands)
            else:
                # This nesting level has no goals left: drop it, possibly
                # folding its unused per-goal commands into the parent level.
                popped = pending_commands_stack.pop()
                # NOTE(review): merging requires at least two remaining
                # levels; confirm whether `> 0` was intended.
                if isinstance(popped, list) and len(popped) > 0 and len(
                        pending_commands_stack) > 1:
                    if pending_commands_stack[-1] is None:
                        pending_commands_stack[-1] = popped
                    elif isinstance(pending_commands_stack[-1], list):
                        pending_commands_stack[
                            -1] = popped + pending_commands_stack[-1]
        command = command_batch.pop(0)
        assert serapi_instance.isValidCommand(command), \
            f"command is \"{command}\", command_batch is {command_batch}"
        # Peel off any comments in front of the command; they are emitted
        # on their own.
        comment_before_command = ""
        command_proper = command
        while "(*" in command_proper:
            next_comment, command_proper = \
                split_to_next_matching(r"\(\*", r"\*\)", command_proper)
            command_proper = command_proper[1:]
            comment_before_command += next_comment
        if comment_before_command:
            yield comment_before_command
        # Skip commands that start with a bullet or a brace.  Note this tests
        # the raw command (comments included), not command_proper.
        if re.match(r"\s*[*+-]+\s*|\s*[{}]\s*", command):
            continue

        command = command_proper
        if debug:
            eprint(f"Linearizing command \"{command}\"")

        # Goal selector "n: tac." -- instead of running it now, record tac
        # as the command for goal n in the current level's pending queue.
        goal_selector_match = re.match(r"\s*(\d*)\s*:\s*(.*)\.\s*", command)
        if goal_selector_match:
            goal_num = int(goal_selector_match.group(1))
            rest = goal_selector_match.group(2)
            assert goal_num >= 2
            if pending_commands_stack[-1] is None:
                pending_commands_stack[-1] = ["idtac."] * (goal_num - 2) + [
                    rest + "."
                ]
            elif isinstance(pending_commands_stack[-1], str):
                pending_cmd = pending_commands_stack[-1]
                pending_commands_stack[-1] = [pending_cmd] * (goal_num - 2) + \
                    [rest + " ; " + pending_cmd] + [pending_cmd + "<..>"]
            else:
                assert isinstance(pending_commands_stack[-1], list)
            continue

        # Commands that split_by_char_outside_matching finds a "||" or "&&"
        # in are run unchanged.
        if split_by_char_outside_matching(r"\(", r"\)", r"\|\||&&", command):
            coq.run_stmt(command)
            yield command
            continue

        # A command starting with "(" has its leading parenthesized group
        # unwrapped: "(tac) rest" -> "tac rest".
        if re.match(r"\(", command.strip()):
            inside_parens, after_parens = split_to_next_matching(
                r'\(', r'\)', command)
            command = inside_parens.strip()[1:-1] + after_parens

        # Extend this to include "by \(" as an opener if you don't
        # desugar all the "by"s
        semi_match = split_by_char_outside_matching(r"try \(|\(|\{\|",
                                                    r"\)|\|\}", r"\s*;\s*",
                                                    command)
        if semi_match:
            base_command, rest = semi_match
            # Drop the first non-space character of the remainder
            # (presumably the ";" separator itself).
            rest = rest.lstrip()[1:]
            coq.run_stmt(base_command + ".")
            indentation = "  " * (len(pending_commands_stack) + 1)
            yield indentation + base_command.strip() + "."

            if re.match(r"\(", rest) and not \
               split_by_char_outside_matching(r"\(", r"\)", r"\|\|", rest):
                inside_parens, after_parens = split_to_next_matching(
                    r'\(', r'\)', rest)
                rest = inside_parens[1:-1] + after_parens
            bracket_match = re.match(r"\[", rest.strip())
            if bracket_match:
                # "base ; [a | b | c] rest_after_bracket": one clause per goal.
                bracket_command, rest_after_bracket = \
                    split_to_next_matching(r'\[', r'\]', rest)
                rest_after_bracket = rest_after_bracket.lstrip()[1:]
                clauses = multisplit_matching(r"\[", r"\]", r"(?<!\|)\|(?!\|)",
                                              bracket_command.strip()[1:-1])
                # Empty clauses become "idtac".
                commands_list = [
                    cmd.strip() if cmd.strip().strip(".") != "" else "idtac" +
                    cmd.strip() for cmd in clauses
                ]
                # A clause marked with ".." is repeated so that the clause
                # list covers every foreground goal: the repeated count is
                # goals - (number of other clauses).
                dotdotpat = re.compile(r"(.*)\.\.($|\W)")
                ending_dotdot_match = dotdotpat.match(commands_list[-1])
                if ending_dotdot_match:
                    commands_list = commands_list[:-1] + \
                        ([ending_dotdot_match.group(1)] *
                         (coq.count_fg_goals() -
                          len(commands_list) + 1))
                else:
                    starting_dotdot_match = dotdotpat.match(commands_list[0])
                    if starting_dotdot_match:
                        starting_tac = starting_dotdot_match.group(1)
                        commands_list = [starting_tac] * (coq.count_fg_goals() -
                                                          len(commands_list) + 1)\
                                                          + commands_list[1:]
                    else:
                        # start=1 so idx is the clause's real position in
                        # commands_list (we iterate over commands_list[1:-1]).
                        for idx, command_case in enumerate(
                                commands_list[1:-1], start=1):
                            middle_dotdot_match = dotdotpat.match(command_case)
                            if middle_dotdot_match:
                                # Repeat the tactic without its ".." marker,
                                # as in the starting/ending cases above.
                                middle_tac = middle_dotdot_match.group(1)
                                commands_list = \
                                    commands_list[:idx] + \
                                    [middle_tac] * (coq.count_fg_goals() -
                                                    len(commands_list) + 1) + \
                                    commands_list[idx+1:]
                                break
                if rest_after_bracket.strip():
                    command_remainders = [
                        cmd + ";" + rest_after_bracket for cmd in commands_list
                    ]
                else:
                    command_remainders = [cmd + "." for cmd in commands_list]
                # Run the first clause now; queue the others for the
                # sibling goals and open a block for the first goal.
                command_batch.insert(0, command_remainders[0])
                if coq.count_fg_goals() > 1:
                    pending_commands_stack.append(command_remainders[1:])
                    coq.run_stmt("{")
                    yield indentation + "{"
            else:
                # Plain "base ; rest": rest runs on the current goal now and
                # is queued for every sibling goal.
                if coq.count_fg_goals() > 0:
                    command_batch.insert(0, rest)
                if coq.count_fg_goals() > 1:
                    pending_commands_stack.append(rest)
                    coq.run_stmt("{")
                    yield indentation + "{"
        elif coq.count_fg_goals() > 0:
            # No semicolon: run the command as-is; if it leaves several
            # goals, open a block with nothing queued for the siblings.
            coq.run_stmt(command)
            indentation = "  " * (len(pending_commands_stack) +
                                  1) if command.strip() != "Proof." else ""
            yield indentation + command.strip()
            if coq.count_fg_goals() > 1:
                pending_commands_stack.append(None)
                coq.run_stmt("{")
                yield indentation + "{"
    pass