def main():
    parser = argparse.ArgumentParser(
        description="A module for drawing and re-drawing reinforcement "
                    "learning graphs")
    parser.add_argument("predictor_weights")
    parser.add_argument("estimator_weights")
    parser.add_argument("graph_json")
    parser.add_argument("--max-term-length", default=512, type=int)
    args = parser.parse_args()

    predictor = predict_tactic.loadPredictorByFile(args.predictor_weights)
    q_estimator_name, *saved = torch.load(str(args.estimator_weights))
    if q_estimator_name == "features evaluator":
        q_estimator = features_q_estimator.FeaturesQEstimator(0, 0, 0)
    elif q_estimator_name == "polyarg evaluator":
        q_estimator = polyarg_q_estimator.PolyargQEstimator(
            0, 0, 0,
            cast(features_polyarg_predictor.FeaturesPolyargPredictor,
                 predictor))
    else:
        raise ValueError(f"Unrecognized evaluator type: {q_estimator_name}")
    # Restore the trained estimator weights loaded from the weights file.
    q_estimator.load_saved_state(*saved)

    graph = ReinforceGraph.load(args.graph_json)
    assignApproximateQScores(graph, args.max_term_length,
                             predictor, q_estimator)
    graph.draw(Path2(args.graph_json).stem)
def main(arg_list: List[str]) -> None:
    global predictor
    parser = argparse.ArgumentParser(
        description="Produce an html report from the scrape file.")
    parser.add_argument("-j", "--threads", default=16, type=int)
    parser.add_argument("--prelude", default=".", type=Path2)
    parser.add_argument("--verbose", "-v", help="verbose output",
                        action='store_const', const=True, default=False)
    parser.add_argument("--progress", "-P", help="show progress of files",
                        action='store_const', const=True, default=False)
    parser.add_argument("--debug", default=False, const=True,
                        action='store_const')
    parser.add_argument("--output", "-o", help="output data folder name",
                        default="static-report", type=Path2)
    parser.add_argument("--message", "-m", default=None)
    parser.add_argument('--context-filter', dest="context_filter",
                        type=str, default=None)
    parser.add_argument('--chunk-size', dest="chunk_size",
                        type=int, default=4096)
    parser.add_argument('--weightsfile', default=None)
    parser.add_argument('--predictor', choices=list(static_predictors.keys()),
                        default=None)
    parser.add_argument("--num-predictions", dest="num_predictions",
                        type=int, default=3)
    parser.add_argument('--skip-nochange-tac', default=False, const=True,
                        action='store_const', dest='skip_nochange_tac')
    parser.add_argument('filenames', nargs="+", help="proof file name (*.v)",
                        type=Path2)
    args = parser.parse_args(arg_list)

    cur_commit = subprocess.check_output(
        ["git show --oneline | head -n 1"],
        shell=True).decode('utf-8').strip()
    cur_date = datetime.datetime.now()

    if args.weightsfile:
        predictor = loadPredictorByFile(args.weightsfile)
    elif args.predictor:
        predictor = loadPredictorByName(args.predictor)
    else:
        print("You must specify either --weightsfile or --predictor!")
        parser.print_help()
        return

    if not args.output.exists():
        args.output.makedirs()

    context_filter = args.context_filter or \
        dict(predictor.getOptions())["context_filter"]

    with multiprocessing.pool.ThreadPool(args.threads) as pool:
        file_results = [
            stats for stats in
            pool.imap_unordered(
                functools.partial(report_file, args,
                                  predictor.training_args,
                                  context_filter),
                args.filenames)
            if stats]

    write_summary(args,
                  predictor.getOptions() +
                  [("report type", "static"),
                   ("predictor", args.predictor)],
                  cur_commit, cur_date, file_results)
def get_predictor(parser: argparse.ArgumentParser,
                  args: argparse.Namespace) -> TacticPredictor:
    predictor: TacticPredictor
    if args.weightsfile:
        predictor = loadPredictorByFile(args.weightsfile)
    elif args.predictor:
        predictor = loadPredictorByName(args.predictor)
    else:
        print("You must specify either --weightsfile or --predictor!")
        parser.print_help()
        sys.exit(1)
    return predictor
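# A minimal usage sketch (hypothetical, not part of the original module):
# get_predictor expects the caller to have already registered both
# --weightsfile and --predictor on its parser, mirroring the way the other
# entry points in this file declare those flags.
def _example_get_predictor(arg_list: List[str]) -> TacticPredictor:
    parser = argparse.ArgumentParser(description="get_predictor usage sketch")
    parser.add_argument('--weightsfile', default=None)
    parser.add_argument('--predictor',
                        choices=list(static_predictors.keys()), default=None)
    args = parser.parse_args(arg_list)
    return get_predictor(parser, args)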
def q_report(args: argparse.Namespace) -> None:
    num_originally_correct = 0
    num_correct = 0
    num_top3 = 0
    num_total = 0
    num_possible = 0

    predictor = predict_tactic.loadPredictorByFile(args.predictor_weights)
    q_estimator_name, *saved = torch.load(args.estimator_weights)
    q_estimator = FeaturesQEstimator(0, 0, 0)
    q_estimator.load_saved_state(*saved)

    for filename in args.test_files:
        points = dataloader.scraped_tactics_from_file(
            str(filename) + ".scrape", None)
        for point in points:
            context = TacticContext(point.relevant_lemmas,
                                    point.prev_tactics,
                                    point.prev_hyps,
                                    point.prev_goal)
            predictions = [p.prediction for p in
                           predictor.predictKTactics(
                               context, args.num_predictions)]
            q_choices = zip(q_estimator([(context, prediction)
                                         for prediction in predictions]),
                            predictions)
            ordered_actions = [p[1] for p in
                               sorted(q_choices,
                                      key=lambda q: q[0],
                                      reverse=True)]

            num_total += 1
            if point.tactic.strip() in predictions:
                num_possible += 1
            if ordered_actions[0] == point.tactic.strip():
                num_correct += 1
            if point.tactic.strip() in ordered_actions[:3]:
                num_top3 += 1
            if predictions[0] == point.tactic.strip():
                num_originally_correct += 1

    print(f"num_correct: {num_correct}")
    print(f"num_originally_correct: {num_originally_correct}")
    print(f"num_top3: {num_top3}")
    print(f"num_total: {num_total}")
    print(f"num_possible: {num_possible}")
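# A hypothetical helper (not in the original module) that turns q_report's raw
# counters into percentages, to make the reranking comparison easier to read:
# num_originally_correct is the predictor's own top-1, num_correct is top-1
# after reranking by the Q estimator, and num_possible is an upper bound on
# both (the recorded tactic appeared somewhere in the k predictions).
def _summarize_q_report(num_correct: int, num_originally_correct: int,
                        num_top3: int, num_total: int,
                        num_possible: int) -> None:
    if num_total == 0:
        print("No data points scored.")
        return
    print(f"reranked top-1 accuracy: {num_correct / num_total:.2%}")
    print(f"original top-1 accuracy: {num_originally_correct / num_total:.2%}")
    print(f"reranked top-3 accuracy: {num_top3 / num_total:.2%}")
    print(f"upper bound (tactic in predictions): "
          f"{num_possible / num_total:.2%}")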
def main(arg_list: List[str]) -> None:
    global jobs
    global num_jobs
    global net
    global gresult
    parser = argparse.ArgumentParser(
        description="try to match the file by predicting a tactic")
    parser.add_argument('-j', '--threads', default=16, type=int)
    parser.add_argument('--prelude', default=".", type=Path2)
    parser.add_argument('--debug', default=False, const=True,
                        action='store_const')
    parser.add_argument("--verbose", "-v", help="verbose output",
                        action='store_const', const=True, default=False)
    parser.add_argument("--progress", "-P", help="show progress of files",
                        action='store_const', const=True, default=False)
    parser.add_argument('-o', '--output', help="output data folder name",
                        default="report", type=Path2)
    parser.add_argument('-m', '--message', default=None)
    parser.add_argument('--baseline',
                        help="run in baseline mode, predicting {} every time"
                        .format(baseline_tactic),
                        default=False, const=True, action='store_const')
    parser.add_argument('--context-filter', dest="context_filter",
                        type=str, default=None)
    parser.add_argument('--weightsfile', default=None)
    parser.add_argument('--predictor', choices=list(static_predictors.keys()),
                        default=None)
    parser.add_argument('--skip-nochange-tac', default=False, const=True,
                        action='store_const', dest='skip_nochange_tac')
    parser.add_argument('filenames', nargs="+", help="proof file name (*.v)",
                        type=Path2)
    args = parser.parse_args(arg_list)

    coqargs = ["sertop", "--implicit"]
    includes = subprocess.Popen(
        ['make', '-C', str(args.prelude), 'print-includes'],
        stdout=subprocess.PIPE).communicate()[0].decode('utf-8')

    # Get some metadata
    cur_commit = subprocess.check_output(
        ["git show --oneline | head -n 1"],
        shell=True).decode('utf-8').strip()
    cur_date = datetime.datetime.now()

    if not args.output.exists():
        args.output.makedirs()

    jobs = queue.Queue()
    workers = []
    num_jobs = len(args.filenames)
    for infname in args.filenames:
        jobs.put(infname)

    args.threads = min(args.threads, len(args.filenames))

    if args.weightsfile:
        net = loadPredictorByFile(args.weightsfile)
    elif args.predictor:
        net = loadPredictorByName(args.predictor)
    else:
        print("You must specify either --weightsfile or --predictor!")
        parser.print_help()
        return
    gresult = GlobalResult(net.getOptions())
    context_filter = args.context_filter or \
        dict(net.getOptions())["context_filter"]

    for idx in range(args.threads):
        worker = Worker(idx, coqargs, includes, args.output,
                        args.prelude, args.debug, num_jobs,
                        args.baseline, args.skip_nochange_tac,
                        context_filter, args)
        worker.start()
        workers.append(worker)

    for idx in range(args.threads):
        finished_id = finished_queue.get()
        workers[finished_id].join()
        print("Thread {} finished ({} of {}).".format(
            finished_id, idx + 1, args.threads))

    write_summary(args.output, num_jobs, cur_commit, args.message,
                  args.baseline, cur_date, gresult)
def reinforce_multithreaded(args: argparse.Namespace) -> None:

    def resume(resume_file: Path2, weights: Path2,
               q_estimator: QEstimator) -> \
            Tuple[List[LabeledTransition],
                  List[Job],
                  List[Tuple[str, ReinforceGraph]]]:
        eprint("Looks like there was a session in progress for these "
               "weights! Resuming")
        q_estimator_name, *saved = torch.load(str(weights))
        q_estimator.load_saved_state(*saved)
        replay_memory = []
        with resume_file.open('r') as f:
            num_samples = sum(1 for _ in f)
        if num_samples > args.buffer_max_size:
            samples_to_use = random.sample(range(num_samples),
                                           args.buffer_max_size)
        else:
            samples_to_use = None
        with resume_file.open('r') as f:
            for (idx, line) in enumerate(f, start=1):
                if num_samples > args.buffer_max_size and \
                   idx not in samples_to_use:
                    continue
                try:
                    replay_memory.append(LabeledTransition.from_dict(
                        json.loads(line)))
                except json.decoder.JSONDecodeError:
                    eprint(f"Problem loading line {idx}: {line}")
                    raise
        already_done = []
        graphs_done = []
        with weights.with_suffix('.done').open('r') as f:
            for line in f:
                next_done = json.loads(line)
                already_done.append((Path2(next_done[0]),
                                     next_done[1], next_done[2]))
                graphpath = (args.graphs_dir / next_done[1])\
                    .with_suffix(".png")
                graph = ReinforceGraph.load(str(graphpath) + ".json")
                graphs_done.append((graphpath, graph))
        return replay_memory, already_done, graphs_done

    # Load the predictor
    predictor = cast(features_polyarg_predictor.FeaturesPolyargPredictor,
                     predict_tactic.loadPredictorByFile(
                         args.predictor_weights))

    q_estimator: QEstimator
    # Create an initial Q Estimator
    if args.estimator == "polyarg":
        q_estimator = PolyargQEstimator(args.learning_rate,
                                        args.batch_step,
                                        args.gamma,
                                        predictor)
    else:
        q_estimator = FeaturesQEstimator(args.learning_rate,
                                         args.batch_step,
                                         args.gamma)
    # This sets up a handler so that if the user hits Ctrl-C, we save
    # the weights as we have them and exit.
    signal.signal(
        signal.SIGINT,
        lambda signal, frame:
        progn(q_estimator.save_weights(
            args.out_weights, args),  # type: ignore
              exit()))

    resume_file = args.out_weights.with_suffix('.tmp')
    if resume_file.exists():
        replay_memory, already_done, graphs_done = resume(
            resume_file, args.out_weights, q_estimator)
    else:
        graphs_done = []
        # Load the scraped (demonstrated) samples and the proof
        # environment commands. Assigns them an estimated "original
        # predictor certainty" value for use as a feature.
        with print_time("Loading initial samples from labeled data"):
            replay_memory = assign_rewards(
                args, predictor,
                dataloader.tactic_transitions_from_file(
                    predictor.dataloader_args,
                    args.scrape_file, args.buffer_min_size))
        # Load in any starting weights
        if args.start_from:
            q_estimator_name, *saved = torch.load(args.start_from)
            q_estimator.load_saved_state(*saved)
        elif args.pretrain:
            # Pre-train the q scores to zero
            with print_time("Pretraining"):
                pre_train(args, predictor, q_estimator,
                          dataloader.tactic_transitions_from_file(
                              predictor.dataloader_args,
                              args.scrape_file,
                              args.buffer_min_size * 3))
        already_done = []
        with args.out_weights.with_suffix('.tmp').open('w') as f:
            for sample in replay_memory:
                f.write(json.dumps(sample.to_dict()))
                f.write("\n")
        with args.out_weights.with_suffix('.done').open('w'):
            pass

    q_estimator.save_weights(args.out_weights, args)
    if args.num_episodes == 0:
        args.out_weights.with_suffix('.tmp').unlink()
        args.out_weights.with_suffix('.done').unlink()
        return

    ctxt = tmp.get_context('spawn')
    jobs: Queue[Tuple[Job, Optional[Demonstration]]] = ctxt.Queue()
    done: Queue[Tuple[Job, Tuple[str, ReinforceGraph]]] = ctxt.Queue()
    samples: Queue[LabeledTransition] = ctxt.Queue()

    for sample in replay_memory:
        samples.put(sample)

    with tmp.Pool() as pool:
        jobs_in_files = list(tqdm(pool.imap(
            functools.partial(get_proofs, args),
            list(enumerate(args.environment_files))),
            total=len(args.environment_files),
            leave=False))
    unfiltered_jobs = [job for job_list in jobs_in_files
                       for job in job_list if job not in already_done]
    if args.proofs_file:
        with open(args.proofs_file, 'r') as f:
            proof_names = [line.strip() for line in f]
        all_jobs = [job for job in unfiltered_jobs
                    if serapi_instance.lemma_name_from_statement(job[2])
                    in proof_names]
    elif args.proof:
        all_jobs = [job for job in unfiltered_jobs
                    if serapi_instance.lemma_name_from_statement(job[2])
                    == args.proof] * args.num_threads
    else:
        all_jobs = unfiltered_jobs

    all_jobs_and_dems: List[Tuple[Job, Optional[Demonstration]]]
    if args.demonstrate_from:
        all_jobs_and_dems = [(job, extract_solution(args,
                                                    args.demonstrate_from,
                                                    job))
                             for job in all_jobs]
    else:
        all_jobs_and_dems = [(job, None) for job in all_jobs]
    for job in all_jobs_and_dems:
        jobs.put(job)

    with Manager() as manager:
        manager = cast(multiprocessing.managers.SyncManager, manager)
        ns = manager.Namespace()
        ns.predictor = predictor
        ns.estimator = q_estimator
        lock = manager.Lock()

        training_worker = ctxt.Process(
            target=reinforce_training_worker,
            args=(args, len(replay_memory), lock, ns, samples))
        workers = [ctxt.Process(
            target=reinforce_worker,
            args=(widx, args, lock, ns, samples, jobs, done))
            for widx in range(min(args.num_threads, len(all_jobs)))]
        training_worker.start()
        for worker in workers:
            worker.start()

        with tqdm(total=len(all_jobs) + len(already_done),
                  dynamic_ncols=True) as bar:
            bar.update(len(already_done))
            bar.refresh()
            for _ in range(len(all_jobs)):
                done_job, graph_job = done.get()
                if graph_job:
                    graphs_done.append(graph_job)
                bar.update()
                with args.out_weights.with_suffix(".done").open('a') as f:
                    f.write(json.dumps((str(done_job[0]),
                                        done_job[1],
                                        done_job[2])))
                    # Newline-terminate each record so the resume path can
                    # parse one JSON object per line.
                    f.write("\n")
        for worker in workers:
            worker.kill()
        training_worker.kill()

    for graphpath, graph in graphs_done:
        assignApproximateQScores(graph, args.max_term_length,
                                 predictor, q_estimator)
        graph.draw(graphpath)
def supervised_q(args: argparse.Namespace) -> None:
    replay_memory = []
    with open(args.tmp_file, 'r') as f:
        for idx, line in enumerate(tqdm(f, desc="Loading data")):
            replay_memory.append(LabeledTransition.from_dict(
                json.loads(line)))
    if args.max_tuples is not None:
        replay_memory = replay_memory[-args.max_tuples:]

    # Load the predictor
    predictor = cast(features_polyarg_predictor.FeaturesPolyargPredictor,
                     predict_tactic.loadPredictorByFile(
                         args.predictor_weights))

    q_estimator: QEstimator
    # Create an initial Q Estimator
    if args.estimator == "polyarg":
        q_estimator = PolyargQEstimator(args.learning_rate,
                                        args.epoch_step,
                                        args.gamma,
                                        predictor)
    else:
        q_estimator = FeaturesQEstimator(args.learning_rate,
                                         args.epoch_step,
                                         args.gamma)
    if args.start_from:
        q_estimator_name, *saved = torch.load(args.start_from)
        if args.estimator == "polyarg":
            assert q_estimator_name == "polyarg evaluator", q_estimator_name
        else:
            assert q_estimator_name == "features evaluator", q_estimator_name
        q_estimator.load_saved_state(*saved)

    training_start = time.time()
    training_samples = assign_scores(args, q_estimator, predictor,
                                     replay_memory, progress=True)
    input_tensors = q_estimator.get_input_tensors(training_samples)
    rescore_lr = args.learning_rate

    for epoch in range(1, args.num_epochs + 1):
        scores = torch.FloatTensor(
            [score for _, _, _, score in training_samples])
        batches: Sequence[Sequence[torch.Tensor]] = data.DataLoader(
            data.TensorDataset(*(input_tensors + [scores])),
            batch_size=args.batch_size,
            num_workers=0,
            shuffle=True, pin_memory=True,
            drop_last=True)

        epoch_loss = 0.
        eprint("Epoch {}: Learning rate {:.12f}".format(
            epoch, q_estimator.optimizer.param_groups[0]['lr']),
            guard=args.show_loss)

        for idx, batch in enumerate(batches, start=1):
            q_estimator.optimizer.zero_grad()
            word_features_batch, vec_features_batch, \
                expected_outputs_batch = batch
            outputs = q_estimator.model(word_features_batch,
                                        vec_features_batch)
            loss = q_estimator.criterion(
                outputs, maybe_cuda(expected_outputs_batch))
            loss.backward()
            q_estimator.optimizer.step()
            q_estimator.total_batches += 1
            epoch_loss += loss.item()
            if idx % args.print_every == 0:
                items_processed = idx * args.batch_size + \
                    (epoch - 1) * len(replay_memory)
                progress = items_processed / (len(replay_memory) *
                                              args.num_epochs)
                eprint("{} ({:7} {:5.2f}%) {:.4f}".format(
                    timeSince(training_start, progress),
                    items_processed, progress * 100,
                    epoch_loss * (len(batches) / idx)),
                    guard=args.show_loss)
        q_estimator.adjuster.step()

        q_estimator.save_weights(args.out_weights, args)
        if epoch % args.score_every == 0 and epoch < args.num_epochs:
            training_samples = assign_scores(args, q_estimator, predictor,
                                             replay_memory, progress=True)
            rescore_lr *= args.rescore_gamma
            q_estimator.optimizer.param_groups[0]['lr'] = rescore_lr
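# Illustrative sketch only (hypothetical helper, not in the original module):
# the sequence of rescore learning rates that supervised_q writes back into
# the optimizer, given the same score_every / rescore_gamma schedule. The
# estimator's own adjuster also steps once per epoch, so this captures only
# the rescoring component of the overall schedule.
def _rescore_lr_schedule(base_lr: float, rescore_gamma: float,
                         score_every: int, num_epochs: int) -> List[float]:
    lrs = []
    lr = base_lr
    for epoch in range(1, num_epochs + 1):
        lrs.append(lr)
        if epoch % score_every == 0 and epoch < num_epochs:
            lr *= rescore_gamma
    return lrs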
def reinforce(args: argparse.Namespace) -> None:
    # Load the scraped (demonstrated) samples, the proof environment
    # commands, and the predictor
    replay_memory = assign_rewards(
        dataloader.tactic_transitions_from_file(args.scrape_file,
                                                args.buffer_size))
    predictor = predict_tactic.loadPredictorByFile(args.predictor_weights)
    q_estimator = FeaturesQEstimator(args.learning_rate,
                                     args.batch_step,
                                     args.gamma)
    signal.signal(
        signal.SIGINT,
        lambda signal, frame:
        progn(q_estimator.save_weights(
            args.out_weights, args),  # type: ignore
              exit()))
    if args.start_from:
        q_estimator_name, *saved = torch.load(args.start_from)
        q_estimator.load_saved_state(*saved)
    elif args.pretrain:
        pre_train(args, q_estimator,
                  dataloader.tactic_transitions_from_file(
                      args.scrape_file, args.buffer_size * 10))
    epsilon = 0.3
    gamma = 0.9

    if args.proof is not None:
        assert len(args.environment_files) == 1, \
            "Can't use multiple env files with --proof!"
        env_commands = serapi_instance.load_commands_preserve(
            args, 0, args.prelude / args.environment_files[0])
        num_proofs = len([cmd for cmd in env_commands
                          if cmd.strip() == "Qed."
                          or cmd.strip() == "Defined."])

        with serapi_instance.SerapiContext(
                ["sertop", "--implicit"],
                serapi_instance.get_module_from_filename(
                    args.environment_files[0]),
                str(args.prelude)) as coq:
            coq.quiet = True
            coq.verbose = args.verbose
            rest_commands, run_commands = coq.run_into_next_proof(
                env_commands)
            lemma_statement = run_commands[-1]
            while coq.cur_lemma_name != args.proof:
                if not rest_commands:
                    eprint(f"Couldn't find lemma {args.proof}! Exiting...")
                    return
                rest_commands, _ = coq.finish_proof(rest_commands)
                rest_commands, run_commands = coq.run_into_next_proof(
                    rest_commands)
                lemma_statement = run_commands[-1]
            reinforce_lemma(args, predictor, q_estimator, coq,
                            lemma_statement, epsilon, gamma, replay_memory)
            q_estimator.save_weights(args.out_weights, args)
    else:
        for env_file in args.environment_files:
            env_commands = serapi_instance.load_commands_preserve(
                args, 0, args.prelude / env_file)
            num_proofs = len([cmd for cmd in env_commands
                              if cmd.strip() == "Qed."
                              or cmd.strip() == "Defined."])
            rest_commands = env_commands
            all_run_commands: List[str] = []
            with tqdm(total=num_proofs, disable=(not args.progress),
                      leave=True, desc=env_file.stem) as pbar:
                while rest_commands:
                    with serapi_instance.SerapiContext(
                            ["sertop", "--implicit"],
                            serapi_instance.get_module_from_filename(
                                env_file),
                            str(args.prelude),
                            log_outgoing_messages=args.log_outgoing_messages) \
                            as coq:
                        coq.quiet = True
                        coq.verbose = args.verbose
                        for command in all_run_commands:
                            coq.run_stmt(command)

                        while rest_commands:
                            rest_commands, run_commands = \
                                coq.run_into_next_proof(rest_commands)
                            if not rest_commands:
                                break
                            all_run_commands += run_commands[:-1]
                            lemma_statement = run_commands[-1]

                            # Check if the definition is
                            # proof-relevant. If it is, then finishing
                            # subgoals doesn't necessarily mean you've
                            # solved the problem, so don't try to
                            # train on it.
                            proof_relevant = False
                            for cmd in rest_commands:
                                if serapi_instance.ending_proof(cmd):
                                    if cmd.strip() == "Defined.":
                                        proof_relevant = True
                                    break
                            proof_relevant = proof_relevant or \
                                bool(re.match(r"\s*Derive", lemma_statement))
                            for sample in replay_memory:
                                sample.graph_node = None
                            if not proof_relevant:
                                try:
                                    reinforce_lemma(args, predictor,
                                                    q_estimator, coq,
                                                    lemma_statement,
                                                    epsilon, gamma,
                                                    replay_memory)
                                except serapi_instance.CoqAnomaly:
                                    if args.log_anomalies:
                                        with args.log_anomalies.open('a') as f:
                                            traceback.print_exc(file=f)
                                    if args.hardfail:
                                        eprint("Hit an anomaly! "
                                               "Quitting due to --hardfail")
                                        raise
                                    eprint("Hit an anomaly! "
                                           "Restarting coq instance")
                                    rest_commands.insert(0, lemma_statement)
                                    break
                            pbar.update(1)
                            rest_commands, run_commands = \
                                coq.finish_proof(rest_commands)
                            all_run_commands.append(lemma_statement)
                            all_run_commands += run_commands

        q_estimator.save_weights(args.out_weights, args)