def preprocess_file_commands(args: argparse.Namespace, file_idx: int,
                             commands: List[str], coqargs: List[str],
                             includes: str, filename: str,
                             relative_filename: str,
                             skip_nochange_tac: bool) -> List[str]:
    """Linearize the given Coq commands inside a fresh serapi session.

    Runs the lift -> linearize -> desugar pipeline over `commands`,
    showing a per-file progress bar. On serapi protocol errors the file
    name is logged and the exception re-raised; on a serapi timeout the
    original (un-linearized) commands are returned unchanged.
    """
    try:
        with serapi_instance.SerapiContext(coqargs, includes,
                                           args.prelude) as coq:
            coq.debug = args.debug
            # One bar per file; `position` keeps parallel bars separated.
            progress = tqdm(file=sys.stdout, disable=not args.progress,
                            position=(file_idx * 2), desc="Linearizing",
                            leave=False, total=len(commands),
                            dynamic_ncols=True, bar_format=mybarfmt)
            with progress as pbar:
                # Lazily compose the three linearization stages, then
                # force the whole pipeline with list().
                lifted = generate_lifted(commands, coq, pbar)
                linearized = linearize_commands(args, file_idx, lifted, coq,
                                                filename, relative_filename,
                                                skip_nochange_tac)
                linear_cmds = list(postlinear_desugar_tacs(linearized))
        return linear_cmds
    except (CoqExn, BadResponse, AckError, CompletedError):
        eprint("In file {}".format(filename))
        raise
    except serapi_instance.TimeoutError:
        eprint("Timed out while lifting commands! Skipping linearization...")
        return commands
def scrape_file(coqargs: List[str], args: argparse.Namespace, includes: str,
                file_tuple: Tuple[int, str]) -> Optional[str]:
    """Scrape one Coq file into a `.scrape` file of processed statements.

    Returns the path of the finished scrape file, the path of the
    partial `.scrape.partial` file if a command timed out, or None when
    scraping failed and hard-fail mode is off.
    """
    # Linearization can recurse deeply on long files.
    sys.setrecursionlimit(4500)
    file_idx, filename = file_tuple
    full_filename = args.prelude + "/" + filename
    result_file = full_filename + ".scrape"
    temp_file = full_filename + ".scrape.partial"
    if args.cont:
        # Resume mode: if a finished scrape already exists, reuse it.
        # The open() is only an existence check; FileNotFoundError falls
        # through to a fresh scrape via contextlib.suppress.
        with contextlib.suppress(FileNotFoundError):
            with open(result_file, 'r') as f:
                if args.verbose:
                    eprint(f"Found existing scrape at {result_file}! Using it")
                return result_file
    try:
        if args.linearize:
            # Try a cached linearization first; otherwise linearize now
            # and cache the result for the next run.
            commands = serapi_instance.try_load_lin(args, file_idx,
                                                    full_filename)
            if not commands:
                commands = linearize_semicolons.preprocess_file_commands(
                    args, file_idx,
                    serapi_instance.load_commands_preserve(
                        args, 0, full_filename),
                    coqargs, args.prelude,
                    full_filename, filename,
                    args.skip_nochange_tac)
                serapi_instance.save_lin(commands, full_filename)
        else:
            with Path2(full_filename).open(mode='r') as vf:
                commands = serapi_instance.read_commands_preserve(
                    args, file_idx, vf.read())
        with serapi_instance.SerapiContext(
                coqargs,
                serapi_instance.get_module_from_filename(filename),
                args.prelude,
                args.relevant_lemmas == "hammer") as coq:
            coq.verbose = args.verbose
            try:
                # Write to a temp file and rename only on success, so an
                # interrupted scrape never leaves a corrupt result file.
                with open(temp_file, 'w') as f:
                    for command in tqdm(commands, file=sys.stdout,
                                        disable=(not args.progress),
                                        position=file_idx * 2,
                                        desc="Scraping file", leave=False,
                                        dynamic_ncols=True,
                                        bar_format=mybarfmt):
                        process_statement(args, coq, command, f)
                shutil.move(temp_file, result_file)
                return result_file
            except serapi_instance.TimeoutError:
                # Keep the partial output around for inspection.
                eprint("Command in {} timed out.".format(filename))
                return temp_file
    except Exception as e:
        eprint("FAILED: In file {}:".format(filename))
        eprint(e)
        # Hard-fail if requested, or when scraping a single file (a
        # silent None would hide the only result).
        if args.hardfail or len(args.inputs) == 1 or args.hardfail_scrape:
            raise e
        return None
def preprocess_file_commands(args: argparse.Namespace, file_idx: int,
                             commands: List[str], coqargs: List[str],
                             prelude: str, filename: str,
                             relative_filename: str,
                             skip_nochange_tac: bool) -> List[str]:
    """Linearize `commands`, retrying with a growing failure list.

    Each retry restarts the serapi session. When linearization raises a
    CoqAnomaly whose msg is a command sequence, that sequence is added
    to `failures` (seeded from compcert_failures) so the next pass can
    avoid linearizing it. A CoqAnomaly with a plain-string msg is
    considered unrecoverable and re-raised.
    """
    try:
        failed = True
        failures = list(compcert_failures)
        while failed:
            # Fresh coq process per attempt — the previous one may be in
            # a broken state after an anomaly.
            with serapi_instance.SerapiContext(
                    coqargs,
                    serapi_instance.get_module_from_filename(filename),
                    prelude) as coq:
                coq.verbose = args.verbose
                coq.quiet = True
                with tqdm(file=sys.stdout, disable=not args.progress,
                          position=(file_idx * 2),
                          desc="Linearizing", leave=False,
                          total=len(commands), dynamic_ncols=True,
                          bar_format=mybarfmt) as pbar:
                    try:
                        # Optimistically mark success; flipped back below
                        # if the pipeline anomalies out.
                        failed = False
                        result = list(
                            postlinear_desugar_tacs(
                                linearize_commands(
                                    args, file_idx,
                                    generate_lifted(commands, coq, pbar),
                                    coq, filename, relative_filename,
                                    skip_nochange_tac, failures)))
                    except CoqAnomaly as e:
                        if isinstance(e.msg, str):
                            # Not a recorded failing-command list; give up.
                            raise
                        failed = True
                        # e.msg carries the command block that anomalied;
                        # remember it so the retry skips it.
                        failures.append(cast(List[str], e.msg))
        return result
    except (CoqExn, BadResponse, AckError, CompletedError):
        eprint("In file {}".format(filename))
        raise
    except serapi_instance.TimeoutError:
        eprint("Timed out while lifting commands! Skipping linearization...")
        return commands
def scrape_file(coqargs: List[str], args: argparse.Namespace, includes: str,
                file_tuple: Tuple[int, str]) -> Optional[str]:
    """Scrape one Coq file into a `.scrape` file of processed statements.

    Returns the scrape file path on success (or after a per-command
    timeout, in which case the file is partial), or None when scraping
    failed and --hardfail is off.

    Fixes over the previous revision:
    - The function could fall off the end and implicitly return None
      while being annotated `-> str`; the annotation is now
      Optional[str] and the None return is explicit.
    - `raise e` inside the handler truncated the traceback; bare
      `raise` preserves it.
    """
    file_idx, filename = file_tuple
    full_filename = args.prelude + "/" + filename
    result_file = full_filename + ".scrape"
    if args.cont:
        # Resume mode: reuse an existing scrape. The open() is only an
        # existence check; FileNotFoundError is suppressed and we
        # re-scrape.
        with contextlib.suppress(FileNotFoundError):
            with open(result_file, 'r') as f:
                if args.verbose:
                    eprint(f"Found existing scrape at {result_file}! Using it")
                return result_file
    try:
        # Use a cached linearization if present, else compute and cache.
        commands = serapi_instance.try_load_lin(args, file_idx,
                                                full_filename)
        if not commands:
            commands = linearize_semicolons.preprocess_file_commands(
                args, file_idx,
                serapi_instance.load_commands(full_filename),
                coqargs, includes,
                full_filename, filename,
                args.skip_nochange_tac)
            serapi_instance.save_lin(commands, full_filename)

        with serapi_instance.SerapiContext(coqargs, includes,
                                           args.prelude) as coq:
            coq.debug = args.debug
            try:
                with open(result_file, 'w') as f:
                    for command in tqdm(commands, file=sys.stdout,
                                        disable=(not args.progress),
                                        position=file_idx * 2,
                                        desc="Scraping file", leave=False,
                                        dynamic_ncols=True,
                                        bar_format=mybarfmt):
                        process_statement(coq, command, f)
            except serapi_instance.TimeoutError:
                # Best-effort: keep whatever was scraped so far.
                eprint("Command in {} timed out.".format(filename))
            return result_file
    except Exception as e:
        eprint("FAILED: In file {}:".format(filename))
        eprint(e)
        if args.hardfail:
            raise
        return None
def search_file(args: argparse.Namespace, coqargs: List[str], includes: str,
                predictor: TacticPredictor, bar_idx: int) -> None:
    """Run proof search over every proof in args.filename.

    Walks the linearized file, and for each lemma: runs the search
    (`attempt_search`), records the outcome in a solution .v file and in
    document blocks, then replays the original proof to advance the
    session. Coq anomalies restart the coq instance and resume from the
    commands already run; hitting the same anomaly twice on one lemma
    adds it to a skip list. Writes HTML and CSV reports at the end.
    """
    global obligation_number
    obligation_number = 0
    num_proofs = 0
    num_proofs_failed = 0
    num_proofs_completed = 0
    commands_run: List[str] = []
    blocks_out: List[DocumentBlock] = []
    # Length of commands_run the last time we hit an anomaly; used to
    # detect hitting the same anomaly twice without making progress.
    commands_caught_up = 0
    lemmas_to_skip: List[str] = []

    if args.resume:
        # If a matching CSV already exists, show a completed bar and
        # return without re-searching.
        try:
            check_csv_args(args, args.filename)
            with tqdm(total=1, unit="cmd", file=sys.stdout,
                      desc=args.filename.name + " (Resumed)",
                      disable=(not args.progress), leave=True,
                      position=(bar_idx * 2),
                      dynamic_ncols=True, bar_format=mybarfmt) as pbar:
                pbar.update(1)
            if not args.progress:
                print(f"Resumed {args.filename} from existing state")
            return
        except FileNotFoundError:
            pass
        except ArgsMismatchException as e:
            if not args.progress:
                eprint(f"Arguments in csv for {args.filename} "
                       f"didn't match current arguments! {e} "
                       f"Overwriting (interrupt to cancel).")

    commands_in = linearize_semicolons.get_linearized(args, coqargs, includes,
                                                      bar_idx,
                                                      str(args.filename))
    num_commands_total = len(commands_in)
    lemma_statement = ""
    module_stack: List[str] = []

    def run_to_next_proof(coq: serapi_instance.SerapiInstance,
                          pbar: tqdm) -> str:
        """Run vernacular commands until the next proof (or end of file).

        Consumes from commands_in; records the vernac run as a
        VernacBlock and appends it to the solution .v file. Returns the
        last command executed — the lemma statement when a proof was
        opened. NOTE(review): assumes commands_in is non-empty on entry
        (the caller's loop guarantees this); otherwise next_in_command
        would be unbound.
        """
        nonlocal commands_run
        nonlocal commands_in
        nonlocal blocks_out
        nonlocal module_stack
        vernacs: List[str] = []
        assert not coq.full_context
        while not coq.full_context and len(commands_in) > 0:
            next_in_command = commands_in.pop(0)
            # Longer timeout for vernac stuff (especially Requires).
            coq.run_stmt(next_in_command, timeout=60)
            if not coq.full_context:
                vernacs.append(next_in_command)
                # Track Module/Section nesting for qualified lemma names.
                update_module_stack(next_in_command, module_stack)
            pbar.update(1)
        if len(vernacs) > 0:
            blocks_out.append(VernacBlock(vernacs))
            commands_run += vernacs
            append_to_solution_vfile(args.output_dir, args.filename,
                                     vernacs)
        return next_in_command

    def run_to_next_vernac(coq: serapi_instance.SerapiInstance,
                           pbar: tqdm,
                           initial_full_context: FullContext,
                           lemma_statement: str) -> List[TacticInteraction]:
        """Replay the original proof of `lemma_statement` to completion.

        Consumes tactics from commands_in until the proof closes. On any
        exception, the lemma statement and consumed tactics are pushed
        back onto commands_in so a restarted session can retry, then the
        exception is re-raised. Proofs ending in `Defined.` are also
        replayed into the solution .v file (after a Reset), since later
        commands may depend on the transparent body.
        """
        nonlocal commands_run
        nonlocal commands_in
        coq.run_stmt(lemma_statement)
        original_tactics: List[TacticInteraction] = []
        lemma_name = serapi_instance.lemma_name_from_statement(
            lemma_statement)
        try:
            while coq.full_context != None:
                next_in_command = commands_in.pop(0)
                context_before = coq.fullContext
                original_tactics.append(
                    TacticInteraction(next_in_command, context_before))
                coq.run_stmt(next_in_command)
                pbar.update(1)
            body_tactics = [t.tactic for t in original_tactics]
            if next_in_command.strip() == "Defined.":
                append_to_solution_vfile(
                    args.output_dir, args.filename,
                    [f"Reset {lemma_name}.", lemma_statement] + body_tactics)
            commands_run.append(lemma_statement)
            commands_run += body_tactics
        except:
            # Restore the consumed commands before propagating, so the
            # anomaly-restart path can replay this proof.
            commands_in = [lemma_statement] + \
                [t.tactic for t in original_tactics] \
                + commands_in
            raise
        return original_tactics

    def add_proof_block(status: SearchStatus,
                        solution: Optional[List[TacticInteraction]],
                        initial_full_context: FullContext,
                        original_tactics: List[TacticInteraction]) -> None:
        """Record one proof's outcome as a ProofBlock in blocks_out."""
        nonlocal num_proofs_failed
        nonlocal num_proofs_completed
        nonlocal blocks_out
        empty_context = FullContext([])
        if solution:
            num_proofs_completed += 1
            blocks_out.append(
                ProofBlock(lemma_statement, ".".join(module_stack), status,
                           [TacticInteraction("Proof.", initial_full_context)]
                           + solution
                           + [TacticInteraction("Qed.", empty_context)],
                           original_tactics))
        else:
            # No solution found: record an Admitted proof.
            blocks_out.append(
                ProofBlock(lemma_statement, ".".join(module_stack), status, [
                    TacticInteraction("Proof.", initial_full_context),
                    TacticInteraction("Admitted.", initial_full_context)
                ], original_tactics))

    if not args.progress:
        print("Loaded {} commands for file {}".format(len(commands_in),
                                                      args.filename))
    with tqdm(total=num_commands_total, unit="cmd", file=sys.stdout,
              desc=args.filename.name,
              disable=(not args.progress), leave=True,
              position=(bar_idx * 2),
              dynamic_ncols=True, bar_format=mybarfmt) as pbar:
        # Outer loop: one iteration per coq session. A CoqAnomaly kills
        # the session; we restart and fast-forward through commands_run.
        while len(commands_in) > 0:
            try:
                with serapi_instance.SerapiContext(
                        coqargs, includes, args.prelude,
                        use_hammer=args.use_hammer) as coq:
                    if args.progress:
                        pbar.reset()
                    # Fast-forward the fresh session through everything
                    # already processed.
                    for command in commands_run:
                        pbar.update(1)
                        coq.run_stmt(command)
                    coq.debug = args.debug
                    if args.resume and len(commands_run) == 0:
                        # First session in resume mode: replay any
                        # existing solution .v file to pick up where a
                        # previous run left off.
                        model_name = dict(
                            predictor.getOptions())["predictor"]
                        try:
                            commands_run, commands_in, blocks_out, \
                                num_proofs, num_proofs_failed, \
                                num_proofs_completed, \
                                num_original_commands_run = \
                                replay_solution_vfile(args, coq, model_name,
                                                      args.filename,
                                                      commands_in,
                                                      module_stack,
                                                      bar_idx)
                            pbar.update(num_original_commands_run)
                        except FileNotFoundError:
                            make_new_solution_vfile(args, model_name,
                                                    args.filename)
                            pass
                        except (ArgsMismatchException,
                                SourceChangedException) as e:
                            eprint(
                                f"Arguments in solution vfile for "
                                f"{args.filename} "
                                f"didn't match current arguments, "
                                f"or sources mismatch! "
                                f"{e}")
                            if args.overwrite_mismatch:
                                eprint("Overwriting.")
                                make_new_solution_vfile(
                                    args, model_name, args.filename)
                                # Force a session restart through the
                                # anomaly-handling path below.
                                raise serapi_instance.CoqAnomaly(
                                    "Replaying")
                            else:
                                raise SourceChangedException
                    if len(commands_run) > 0 and (args.verbose
                                                  or args.debug):
                        eprint("Caught up with commands:\n{}\n...\n{}".format(
                            commands_run[0].strip(),
                            commands_run[-1].strip()))
                    # Inner loop: one iteration per lemma.
                    while len(commands_in) > 0:
                        lemma_statement = run_to_next_proof(coq, pbar)
                        if len(commands_in) == 0:
                            # File ended with vernacular; no more proofs.
                            break
                        num_proofs += 1
                        initial_context = coq.fullContext
                        # Search for a proof, unless this lemma
                        # previously anomalied twice.
                        if lemma_statement in lemmas_to_skip:
                            search_status = SearchStatus.FAILURE
                            tactic_solution: Optional[
                                List[TacticInteraction]] = []
                        else:
                            search_status, tactic_solution = \
                                attempt_search(args, lemma_statement,
                                               ".".join(module_stack),
                                               coq, bar_idx)
                        # Cancel back to just before the proof so the
                        # original proof can be replayed.
                        try:
                            while coq.full_context != None:
                                coq.cancel_last()
                        except serapi_instance.CoqExn as e:
                            raise serapi_instance.CoqAnomaly(
                                f"While cancelling: {e}")
                        if tactic_solution:
                            append_to_solution_vfile(
                                args.output_dir, args.filename,
                                [lemma_statement, "Proof."] +
                                [tac.tactic for tac in tactic_solution]
                                + ["Qed."])
                        else:
                            if search_status == SearchStatus.FAILURE:
                                num_proofs_failed += 1
                                admitted = "Admitted (*FAILURE*)."
                            else:
                                admitted = "Admitted (*INCOMPLETE*)."
                            append_to_solution_vfile(
                                args.output_dir, args.filename,
                                [lemma_statement, "Proof.\n", admitted])
                        # Replay the original proof to move the session
                        # past this lemma.
                        original_tactics = run_to_next_vernac(
                            coq, pbar, initial_context, lemma_statement)
                        add_proof_block(search_status, tactic_solution,
                                        initial_context, original_tactics)
            except serapi_instance.CoqAnomaly as e:
                if lemma_statement:
                    commands_in.insert(0, lemma_statement)
                if commands_caught_up == len(commands_run):
                    # No progress since the last anomaly: same lemma is
                    # anomalying repeatedly.
                    eprint(f"Hit the same anomaly twice!")
                    if lemma_statement in lemmas_to_skip:
                        raise e
                    else:
                        lemmas_to_skip.append(lemma_statement)
                commands_caught_up = len(commands_run)
                if args.hardfail:
                    raise e
                if args.verbose or args.debug:
                    eprint(
                        f"Hit a coq anomaly {e.msg}! "
                        f"Restarting coq instance.")
            except Exception as e:
                eprint(f"FAILED: in file {args.filename}, {repr(e)}")
                raise
    write_html(args, args.output_dir, args.filename, blocks_out)
    write_csv(args, args.filename, blocks_out)
def process_file(self, args: argparse.Namespace, file_idx: int,
                 filename: str) -> None:
    """Evaluate the predictor on every tactic of one Coq file.

    For each in-proof command, asks the model (or the baseline) for
    `num_predictions` tactics, runs each prediction, grades it against
    the actual proof step, and records the results. Writes a per-file
    CSV and an interactive HTML report, then publishes the FileResult.

    Fixes over the previous revision (baseline branch was dead code):
    - `predictions_and_certanties` (typo) was assigned but all later
      code reads `predictions_and_certainties` -> NameError whenever
      self.baseline was set. Spelling corrected.
    - The baseline value was a flat list `[tactic, 1, tactic, 1, ...]`,
      but downstream code unpacks (prediction, certainty) pairs; it is
      now a list of tuples.
    - `loss` was unbound in the baseline branch; it is now 0 there.
    """
    global gresult
    fresult = FileResult(filename)
    if self.debug:
        print("Preprocessing...")
    commands = self.get_commands(args, file_idx, filename)
    command_results: List[CommandResult] = []

    with serapi_instance.SerapiContext(self.coqargs, self.includes,
                                       self.prelude) as coq:
        coq.debug = self.debug
        nb_commands = len(commands)
        for i in range(nb_commands):
            command = commands[i]
            # Only predict on proof steps, not on the "Proof." line itself.
            in_proof = (coq.proof_context
                        and not re.match(".*Proof.*", command.strip()))
            if re.match("[{}]", command):
                # Bullet braces: just run them, nothing to predict.
                coq.run_stmt(command)
                continue
            if in_proof:
                prev_tactics = coq.prev_tactics
                initial_context = coq.proof_context
                assert initial_context
                hyps = coq.hypotheses
                goals = coq.goals
                if self.baseline:
                    # Constant-tactic baseline; no model, so no loss.
                    predictions_and_certainties = \
                        [(baseline_tactic + ".", 1)] * num_predictions
                    loss = 0
                else:
                    predictions_and_certainties, loss = \
                        net.predictKTacticsWithLoss(
                            TacticContext(prev_tactics, hyps, goals),
                            num_predictions, command)
                prediction_runs = [run_prediction(coq, prediction)
                                   for prediction, certainty
                                   in predictions_and_certainties]
                try:
                    coq.run_stmt(command)
                    actual_result_context = coq.proof_context
                    actual_result_goal = coq.goals
                    actual_result_hypothesis = coq.hypotheses
                    assert isinstance(actual_result_context, str)
                except (AckError, CompletedError, CoqExn, BadResponse,
                        ParseError, LexError, TimeoutError):
                    print("In file {}:".format(filename))
                    raise
                # Grade each prediction against the real proof step.
                prediction_results = [
                    (prediction,
                     evaluate_prediction(fresult, initial_context,
                                         command, actual_result_context,
                                         prediction_run),
                     certainty)
                    for prediction_run, (prediction, certainty)
                    in zip(prediction_runs, predictions_and_certainties)]
                assert net.training_args
                if self.cfilter({"goal": format_goal(goals),
                                 "hyps": format_hypothesis(hyps)},
                                command,
                                {"goal": format_goal(actual_result_goal),
                                 "hyps": format_hypothesis(
                                     actual_result_hypothesis)},
                                net.training_args):
                    fresult.add_command_result(
                        [pred for pred, ctxt, ex in prediction_runs],
                        [grade for pred, grade, certainty
                         in prediction_results],
                        command, loss)
                    command_results.append((command, hyps, goals,
                                            prediction_results))
                else:
                    command_results.append((command,))
            else:
                try:
                    coq.run_stmt(command)
                except (AckError, CompletedError, CoqExn, BadResponse,
                        ParseError, LexError, TimeoutError):
                    print("In file {}:".format(filename))
                    raise
                command_results.append((command,))

    write_csv(fresult.details_filename(), self.output_dir,
              gresult.options, command_results)

    # Build the interactive HTML report with yattag.
    doc, tag, text, line = Doc().ttl()
    with tag('html'):
        details_header(tag, doc, text, filename)
        with tag('div', id='overlay', onclick='event.stopPropagation();'):
            with tag('div', id='predicted'):
                pass
            with tag('div', id='context'):
                pass
            with tag('div', id='stats'):
                pass
        with tag('body', onclick='deselectTactic()',
                 onload='setSelectedIdx()'), tag('pre'):
            for idx, command_result in enumerate(command_results):
                if len(command_result) == 1:
                    # Plain (non-graded) command.
                    with tag('code', klass='plaincommand'):
                        text(command_result[0])
                else:
                    command, hyps, goal, prediction_results = \
                        cast(TacticResult, command_result)
                    predictions, grades, certainties = \
                        zip(*prediction_results)
                    # First prediction that actually ran becomes the one
                    # shown as the search result.
                    search_index = 0
                    for pidx, prediction_result in enumerate(
                            prediction_results):
                        prediction, grade, certainty = prediction_result
                        if (grade != "failedcommand"
                                and grade != "superfailedcommand"):
                            search_index = pidx
                            break
                    with tag('span',
                             ('data-hyps', "\n".join(hyps)),
                             ('data-goal', shorten_whitespace(goal)),
                             ('data-num-total', str(fresult.num_tactics)),
                             ('data-predictions',
                              to_list_string(cast(List[str], predictions))),
                             ('data-num-predicteds',
                              to_list_string([
                                  fresult.predicted_tactic_frequency.get(
                                      get_stem(prediction), 0)
                                  for prediction in cast(List[str],
                                                         predictions)])),
                             ('data-num-corrects',
                              to_list_string([
                                  fresult.correctly_predicted_frequency.get(
                                      get_stem(prediction), 0)
                                  for prediction in cast(List[str],
                                                         predictions)])),
                             ('data-certainties',
                              to_list_string(cast(List[float],
                                                  certainties))),
                             ('data-num-actual-corrects',
                              fresult.correctly_predicted_frequency.get(
                                  get_stem(command), 0)),
                             # NOTE(review): no default here, unlike the
                             # lookups above — may emit None; confirm
                             # intended.
                             ('data-num-actual-in-file',
                              fresult.actual_tactic_frequency.get(
                                  get_stem(command))),
                             ('data-actual-tactic',
                              strip_comments(command)),
                             ('data-grades',
                              to_list_string(cast(List[str], grades))),
                             ('data-search-idx', search_index),
                             id='command-' + str(idx),
                             onmouseover='hoverTactic({})'.format(idx),
                             onmouseout='unhoverTactic()',
                             onclick='selectTactic({}); '
                             'event.stopPropagation();'.format(idx)):
                        doc.stag("br")
                        # Renamed from a second `idx` that shadowed the
                        # outer loop variable.
                        for tac_idx, prediction_result in enumerate(
                                prediction_results):
                            prediction, grade, certainty = \
                                prediction_result
                            if search_index == tac_idx:
                                with tag('code', klass=grade):
                                    text(" " + command.strip())
                            else:
                                with tag('span', klass=grade):
                                    doc.asis(" ⬤")
    with open("{}/{}.html".format(self.output_dir,
                                  fresult.details_filename()),
              "w") as fout:
        fout.write(doc.getvalue())
    gresult.add_file_result(fresult)
    rows.put(fresult)
def reinforce(args: argparse.Namespace) -> None:
    """Run Q-estimator reinforcement over the given environment files.

    Loads demonstrated transitions from the scrape file, the tactic
    predictor, and a FeaturesQEstimator (optionally restored from
    --start-from or pretrained), then either reinforces a single named
    proof (--proof) or every non-proof-relevant proof in each
    environment file. Weights are saved on completion and on SIGINT.

    Fixes over the previous revision:
    - The "Couldn't find lemma ..." message was missing its f-prefix,
      printing the literal text "{args.proof}".
    - The --hardfail anomaly message concatenated without a space
      ("...anomaly!Quitting...").
    """
    # Load the scraped (demonstrated) samples, the proof environment
    # commands, and the predictor.
    replay_memory = assign_rewards(
        dataloader.tactic_transitions_from_file(args.scrape_file,
                                                args.buffer_size))
    predictor = predict_tactic.loadPredictorByFile(args.predictor_weights)
    q_estimator = FeaturesQEstimator(args.learning_rate,
                                     args.batch_step,
                                     args.gamma)
    # Save weights before dying on Ctrl-C.
    signal.signal(
        signal.SIGINT,
        lambda signal, frame: progn(
            q_estimator.save_weights(args.out_weights, args),  # type: ignore
            exit()))
    if args.start_from:
        q_estimator_name, *saved = torch.load(args.start_from)
        q_estimator.load_saved_state(*saved)
    elif args.pretrain:
        pre_train(args, q_estimator,
                  dataloader.tactic_transitions_from_file(
                      args.scrape_file, args.buffer_size * 10))
    epsilon = 0.3
    gamma = 0.9

    if args.proof is not None:
        # Single-proof mode: scan the (single) env file for the named
        # lemma, reinforce it, and save.
        assert len(args.environment_files) == 1, \
            "Can't use multiple env files with --proof!"
        env_commands = serapi_instance.load_commands_preserve(
            args, 0, args.prelude / args.environment_files[0])
        num_proofs = len([cmd for cmd in env_commands
                          if cmd.strip() == "Qed."
                          or cmd.strip() == "Defined."])
        with serapi_instance.SerapiContext(
                ["sertop", "--implicit"],
                serapi_instance.get_module_from_filename(
                    args.environment_files[0]),
                str(args.prelude)) as coq:
            coq.quiet = True
            coq.verbose = args.verbose
            rest_commands, run_commands = \
                coq.run_into_next_proof(env_commands)
            lemma_statement = run_commands[-1]
            while coq.cur_lemma_name != args.proof:
                if not rest_commands:
                    eprint(f"Couldn't find lemma {args.proof}! Exiting...")
                    return
                rest_commands, _ = coq.finish_proof(rest_commands)
                rest_commands, run_commands = \
                    coq.run_into_next_proof(rest_commands)
                lemma_statement = run_commands[-1]
            reinforce_lemma(args, predictor, q_estimator, coq,
                            lemma_statement, epsilon, gamma, replay_memory)
            q_estimator.save_weights(args.out_weights, args)
    else:
        for env_file in args.environment_files:
            env_commands = serapi_instance.load_commands_preserve(
                args, 0, args.prelude / env_file)
            num_proofs = len([cmd for cmd in env_commands
                              if cmd.strip() == "Qed."
                              or cmd.strip() == "Defined."])
            rest_commands = env_commands
            all_run_commands: List[str] = []
            with tqdm(total=num_proofs, disable=(not args.progress),
                      leave=True, desc=env_file.stem) as pbar:
                # Outer loop: one iteration per coq session; an anomaly
                # restarts the session and replays all_run_commands.
                while rest_commands:
                    with serapi_instance.SerapiContext(
                            ["sertop", "--implicit"],
                            serapi_instance.get_module_from_filename(
                                env_file),
                            str(args.prelude),
                            log_outgoing_messages=(
                                args.log_outgoing_messages)) as coq:
                        coq.quiet = True
                        coq.verbose = args.verbose
                        # Fast-forward through already-processed commands.
                        for command in all_run_commands:
                            coq.run_stmt(command)
                        while rest_commands:
                            rest_commands, run_commands = \
                                coq.run_into_next_proof(rest_commands)
                            if not rest_commands:
                                break
                            all_run_commands += run_commands[:-1]
                            lemma_statement = run_commands[-1]
                            # Check if the definition is proof-relevant.
                            # If it is, then finishing subgoals doesn't
                            # necessarily mean you've solved the problem,
                            # so don't try to train on it.
                            proof_relevant = False
                            for cmd in rest_commands:
                                if serapi_instance.ending_proof(cmd):
                                    if cmd.strip() == "Defined.":
                                        proof_relevant = True
                                    break
                            proof_relevant = proof_relevant or \
                                bool(re.match(r"\s*Derive",
                                              lemma_statement))
                            # Clear stale graph links from earlier lemmas.
                            for sample in replay_memory:
                                sample.graph_node = None
                            if not proof_relevant:
                                try:
                                    reinforce_lemma(args, predictor,
                                                    q_estimator, coq,
                                                    lemma_statement,
                                                    epsilon, gamma,
                                                    replay_memory)
                                except serapi_instance.CoqAnomaly:
                                    if args.log_anomalies:
                                        with args.log_anomalies.open(
                                                'a') as f:
                                            traceback.print_exc(file=f)
                                    if args.hardfail:
                                        eprint("Hit an anomaly! "
                                               "Quitting due to --hardfail")
                                        raise
                                    eprint("Hit an anomaly! "
                                           "Restarting coq instance")
                                    # Retry this lemma in a new session.
                                    rest_commands.insert(0,
                                                         lemma_statement)
                                    break
                            pbar.update(1)
                            rest_commands, run_commands = \
                                coq.finish_proof(rest_commands)
                            all_run_commands.append(lemma_statement)
                            all_run_commands += run_commands
        # NOTE(review): saved once after all env files; confirm a
        # per-file checkpoint isn't expected here.
        q_estimator.save_weights(args.out_weights, args)