Example #1
def main():
    parser = argparse.ArgumentParser(
        description="A module for drawing and re-drawing reinforcement "
        "learning graphs")

    parser.add_argument("predictor_weights")
    parser.add_argument("estimator_weights")
    parser.add_argument("graph_json")
    parser.add_argument("--max-term-length", default=512, type=int)

    args = parser.parse_args()

    predictor = predict_tactic.loadPredictorByFile(args.predictor_weights)
    q_estimator_name, *saved = torch.load(str(args.estimator_weights))
    if q_estimator_name == "features evaluator":
        q_estimator = features_q_estimator.FeaturesQEstimator(0, 0, 0)
    elif q_estimator_name == "polyarg evaluator":
        q_estimator = polyarg_q_estimator.PolyargQEstimator(
            0, 0, 0,
            cast(features_polyarg_predictor.FeaturesPolyargPredictor,
                 predictor))
    else:
        # Guard against an unrecognized tag; otherwise q_estimator would be
        # unbound here and fail later with a NameError.
        raise ValueError(f"Unknown evaluator type: {q_estimator_name}")
    q_estimator.load_saved_state(*saved)

    graph = ReinforceGraph.load(args.graph_json)
    assignApproximateQScores(graph, args.max_term_length, predictor,
                             q_estimator)
    graph.draw(Path2(args.graph_json).stem)
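
# The "q_estimator_name, *saved" unpacking above relies on torch.load
# returning a sequence whose head is a tag string. A minimal sketch of the
# same starred-assignment idiom, with a plain tuple standing in for the
# saved torch state:
saved_state = ("features evaluator", {"weights": [0.1, 0.2]}, 3)
estimator_name, *saved = saved_state
assert estimator_name == "features evaluator"
assert saved == [{"weights": [0.1, 0.2]}, 3]  # the starred part is always a list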
Example #2
def main(arg_list: List[str]) -> None:
    global predictor
    parser = argparse.ArgumentParser(
        description="Produce an html report from the scrape file.")
    parser.add_argument("-j", "--threads", default=16, type=int)
    parser.add_argument("--prelude", default=".", type=Path2)
    parser.add_argument("--verbose", "-v", help="verbose output",
                        action='store_const', const=True, default=False)
    parser.add_argument("--progress", "-P", help="show progress of files",
                        action='store_const', const=True, default=False)
    parser.add_argument("--debug", default=False, const=True, action='store_const')
    parser.add_argument("--output", "-o", help="output data folder name",
                        default="static-report", type=Path2)
    parser.add_argument("--message", "-m", default=None)
    parser.add_argument('--context-filter', dest="context_filter", type=str,
                        default=None)
    parser.add_argument('--chunk-size', dest="chunk_size", type=int, default=4096)
    parser.add_argument('--weightsfile', default=None)
    parser.add_argument('--predictor', choices=list(static_predictors.keys()),
                        default=None)
    parser.add_argument("--num-predictions", dest="num_predictions", type=int, default=3)
    parser.add_argument('--skip-nochange-tac', action='store_true')
    parser.add_argument('filenames', nargs="+", help="proof file name (*.v)", type=Path2)
    args = parser.parse_args(arg_list)

    cur_commit = subprocess.check_output("git show --oneline | head -n 1",
                                         shell=True).decode('utf-8').strip()
    cur_date = datetime.datetime.now()

    if args.weightsfile:
        predictor = loadPredictorByFile(args.weightsfile)
    elif args.predictor:
        predictor = loadPredictorByName(args.predictor)
    else:
        print("You must specify either --weightsfile or --predictor!")
        parser.print_help()
        return

    if not args.output.exists():
        args.output.makedirs()

    context_filter = args.context_filter or dict(predictor.getOptions())["context_filter"]

    with multiprocessing.pool.ThreadPool(args.threads) as pool:
        file_results = [
            stats for stats in
            pool.imap_unordered(functools.partial(report_file, args,
                                                  predictor.training_args,
                                                  context_filter),
                                args.filenames)
            if stats]

    write_summary(args, predictor.getOptions() +
                  [("report type", "static"), ("predictor", args.predictor)],
                  cur_commit, cur_date, file_results)
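
# The core pattern above is a ThreadPool fan-out over filenames with a
# functools.partial callback, keeping only truthy results. A self-contained
# sketch with a toy worker standing in for report_file:
import functools
import multiprocessing.pool

def toy_report(prefix: str, filename: str):
    # Stand-in for report_file: returns None for inputs with no stats.
    return f"{prefix}:{filename}" if filename else None

with multiprocessing.pool.ThreadPool(4) as pool:
    results = [stats for stats in
               pool.imap_unordered(functools.partial(toy_report, "run1"),
                                   ["a.v", "", "b.v"])
               if stats]
# results holds "run1:a.v" and "run1:b.v", in completion order.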
Example #3
def get_predictor(parser: argparse.ArgumentParser,
                  args: argparse.Namespace) -> TacticPredictor:
    predictor: TacticPredictor
    if args.weightsfile:
        predictor = loadPredictorByFile(args.weightsfile)
    elif args.predictor:
        predictor = loadPredictorByName(args.predictor)
    else:
        print("You must specify either --weightsfile or --predictor!")
        parser.print_help()
        sys.exit(1)
    return predictor
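
# A usage sketch for get_predictor; loadPredictorByFile is project code, and
# the flag value here is a placeholder:
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--weightsfile", default=None)
parser.add_argument("--predictor", default=None)
args = parser.parse_args(["--weightsfile", "weights.dat"])
predictor = get_predictor(parser, args)  # prints help and exits if neither flag is given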
Example #4
def q_report(args: argparse.Namespace) -> None:
    num_originally_correct = 0
    num_correct = 0
    num_top3 = 0
    num_total = 0
    num_possible = 0

    predictor = predict_tactic.loadPredictorByFile(args.predictor_weights)
    q_estimator_name, *saved = torch.load(args.estimator_weights)
    # Mirror the check done elsewhere: this report only supports the
    # features evaluator.
    assert q_estimator_name == "features evaluator", q_estimator_name
    q_estimator = FeaturesQEstimator(0, 0, 0)
    q_estimator.load_saved_state(*saved)

    for filename in args.test_files:
        points = dataloader.scraped_tactics_from_file(
            str(filename) + ".scrape", None)
        for point in points:
            context = TacticContext(point.relevant_lemmas, point.prev_tactics,
                                    point.prev_hyps, point.prev_goal)
            predictions = [
                p.prediction for p in predictor.predictKTactics(
                    context, args.num_predictions)
            ]
            q_choices = zip(
                q_estimator([(context, prediction)
                             for prediction in predictions]), predictions)
            ordered_actions = [
                p[1]
                for p in sorted(q_choices, key=lambda q: q[0], reverse=True)
            ]

            num_total += 1
            if point.tactic.strip() in predictions:
                num_possible += 1

            if ordered_actions[0] == point.tactic.strip():
                num_correct += 1

            if point.tactic.strip() in ordered_actions[:3]:
                num_top3 += 1

            if predictions[0] == point.tactic.strip():
                num_originally_correct += 1

    print(f"num_correct: {num_correct}")
    print(f"num_originally_correct: {num_originally_correct}")
    print(f"num_top3: {num_top3}")
    print(f"num_total: {num_total}")
    print(f"num_possible: {num_possible}")
Example #5
def main(arg_list: List[str]) -> None:
    global jobs
    global num_jobs
    global net
    global gresult
    parser = argparse.ArgumentParser(
        description="try to match the file by predicting a tactic")
    parser.add_argument('-j', '--threads', default=16, type=int)
    parser.add_argument('--prelude', default=".", type=Path2)
    parser.add_argument('--debug', action='store_true')
    parser.add_argument("--verbose", "-v", help="verbose output",
                        action='store_true')
    parser.add_argument("--progress", "-P", help="show progress of files",
                        action='store_true')
    parser.add_argument('-o',
                        '--output',
                        help="output data folder name",
                        default="report",
                        type=Path2)
    parser.add_argument('-m', '--message', default=None)
    parser.add_argument(
        '--baseline',
        help="run in baseline mode, predicting {} every time".format(
            baseline_tactic),
        action='store_true')
    parser.add_argument('--context-filter',
                        dest="context_filter",
                        type=str,
                        default=None)
    parser.add_argument('--weightsfile', default=None)
    parser.add_argument('--predictor',
                        choices=list(static_predictors.keys()),
                        default=None)
    parser.add_argument('--skip-nochange-tac', action='store_true')
    parser.add_argument('filenames',
                        nargs="+",
                        help="proof file name (*.v)",
                        type=Path2)
    args = parser.parse_args(arg_list)

    coqargs = ["sertop", "--implicit"]
    includes = subprocess.Popen(
        ['make', '-C', str(args.prelude), 'print-includes'],
        stdout=subprocess.PIPE).communicate()[0].decode('utf-8')

    # Get some metadata
    cur_commit = subprocess.check_output("git show --oneline | head -n 1",
                                         shell=True).decode('utf-8').strip()
    cur_date = datetime.datetime.now()

    if not args.output.exists():
        args.output.makedirs()

    jobs = queue.Queue()
    workers = []
    num_jobs = len(args.filenames)
    for infname in args.filenames:
        jobs.put(infname)

    args.threads = min(args.threads, len(args.filenames))

    if args.weightsfile:
        net = loadPredictorByFile(args.weightsfile)
    elif args.predictor:
        net = loadPredictorByName(args.predictor)
    else:
        print("You must specify either --weightsfile or --predictor!")
        parser.print_help()
        return
    gresult = GlobalResult(net.getOptions())
    context_filter = args.context_filter or dict(
        net.getOptions())["context_filter"]

    for idx in range(args.threads):
        worker = Worker(idx, coqargs, includes, args.output, args.prelude,
                        args.debug, num_jobs, args.baseline,
                        args.skip_nochange_tac, context_filter, args)
        worker.start()
        workers.append(worker)

    for idx in range(args.threads):
        # finished_queue is a module-level queue that workers put their id
        # on when they complete (defined elsewhere in this module).
        finished_id = finished_queue.get()
        workers[finished_id].join()
        print("Thread {} finished ({} of {}).".format(finished_id, idx + 1,
                                                      args.threads))

    write_summary(args.output, num_jobs, cur_commit, args.message,
                  args.baseline, cur_date, gresult)
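
# The job distribution above is a standard queue-of-filenames worker pool.
# A condensed sketch with threading.Thread standing in for the project's
# Worker class and a local queue standing in for finished_queue:
import queue
import threading

job_queue: "queue.Queue[str]" = queue.Queue()
done_queue: "queue.Queue[int]" = queue.Queue()

def toy_worker(widx: int) -> None:
    # Drain the shared job queue, then report this worker's id as finished.
    while True:
        try:
            job = job_queue.get_nowait()
        except queue.Empty:
            break
        print(f"worker {widx} processed {job}")
    done_queue.put(widx)

for name in ["a.v", "b.v", "c.v"]:
    job_queue.put(name)
threads = [threading.Thread(target=toy_worker, args=(i,)) for i in range(2)]
for t in threads:
    t.start()
for i in range(len(threads)):
    print("Thread {} finished ({} of {}).".format(done_queue.get(), i + 1,
                                                  len(threads)))
for t in threads:
    t.join()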
Example #6
def reinforce_multithreaded(args: argparse.Namespace) -> None:

    def resume(resume_file: Path2, weights: Path2,
               q_estimator: QEstimator) -> \
      Tuple[List[LabeledTransition],
            List[Job],
            List[Tuple[str, ReinforceGraph]]]:
        eprint("Looks like there was a session in progress for these weights! "
               "Resuming")
        q_estimator_name, *saved = torch.load(str(weights))
        q_estimator.load_saved_state(*saved)
        replay_memory = []
        with resume_file.open('r') as f:
            num_samples = sum(1 for _ in f)
        if num_samples > args.buffer_max_size:
            # Keep a random subset; use a set for O(1) membership tests, and
            # sample from 1..num_samples since enumerate below is 1-based
            # (the original sampled 0-based indices, an off-by-one).
            samples_to_use = set(random.sample(range(1, num_samples + 1),
                                               args.buffer_max_size))
        else:
            samples_to_use = None
        with resume_file.open('r') as f:
            for (idx, line) in enumerate(f, start=1):
                if samples_to_use is not None and idx not in samples_to_use:
                    continue
                try:
                    replay_memory.append(LabeledTransition.from_dict(
                        json.loads(line)))
                except json.decoder.JSONDecodeError:
                    eprint(f"Problem loading line {idx}: {line}")
                    raise
        already_done = []
        graphs_done = []
        with weights.with_suffix('.done').open('r') as f:
            for line in f:
                next_done = json.loads(line)
                already_done.append((Path2(next_done[0]), next_done[1],
                                     next_done[2]))
                graphpath = (args.graphs_dir / next_done[1])\
                    .with_suffix(".png")
                # Path objects don't support "+"; build the sidecar json
                # filename explicitly.
                graph = ReinforceGraph.load(str(graphpath) + ".json")
                graphs_done.append((graphpath, graph))
        return replay_memory, already_done, graphs_done

    # Load the predictor
    predictor = cast(
        features_polyarg_predictor.FeaturesPolyargPredictor,
        predict_tactic.loadPredictorByFile(args.predictor_weights))

    q_estimator: QEstimator
    # Create an initial Q Estimator
    if args.estimator == "polyarg":
        q_estimator = PolyargQEstimator(args.learning_rate,
                                        args.batch_step,
                                        args.gamma,
                                        predictor)
    else:
        q_estimator = FeaturesQEstimator(args.learning_rate,
                                         args.batch_step,
                                         args.gamma)
    # This sets up a handler so that if the user hits Ctrl-C, we save
    # the weights as we have them and exit.
    signal.signal(
        signal.SIGINT,
        lambda signal, frame:
        progn(q_estimator.save_weights(
            args.out_weights, args),  # type: ignore
              exit()))

    resume_file = args.out_weights.with_suffix('.tmp')
    if resume_file.exists():
        replay_memory, already_done, graphs_done = resume(resume_file,
                                                          args.out_weights,
                                                          q_estimator)
    else:
        graphs_done = []
        # Load the scraped (demonstrated) samples and the proof
        # environment commands. Assigns them an estimated "original
        # predictor certainty" value for use as a feature.
        with print_time("Loading initial samples from labeled data"):
            replay_memory = assign_rewards(
                args, predictor,
                dataloader.tactic_transitions_from_file(
                    predictor.dataloader_args,
                    args.scrape_file, args.buffer_min_size))
        # Load in any starting weights
        if args.start_from:
            q_estimator_name, *saved = torch.load(args.start_from)
            q_estimator.load_saved_state(*saved)
        elif args.pretrain:
            # Pre-train the q scores to zero
            with print_time("Pretraining"):
                pre_train(args, predictor, q_estimator,
                          dataloader.tactic_transitions_from_file(
                              predictor.dataloader_args,
                              args.scrape_file, args.buffer_min_size * 3))
        already_done = []
        with args.out_weights.with_suffix('.tmp').open('w') as f:
            for sample in replay_memory:
                f.write(json.dumps(sample.to_dict()))
                f.write("\n")
        with args.out_weights.with_suffix('.done').open('w'):
            pass

    q_estimator.save_weights(args.out_weights, args)
    if args.num_episodes == 0:
        args.out_weights.with_suffix('.tmp').unlink()
        args.out_weights.with_suffix('.done').unlink()
        return

    ctxt = tmp.get_context('spawn')
    jobs: Queue[Tuple[Job, Optional[Demonstration]]] = ctxt.Queue()
    done: Queue[Tuple[Job, Tuple[str, ReinforceGraph]]] = ctxt.Queue()
    samples: Queue[LabeledTransition] = ctxt.Queue()

    for sample in replay_memory:
        samples.put(sample)

    with tmp.Pool() as pool:
        jobs_in_files = list(tqdm(
            pool.imap(functools.partial(get_proofs, args),
                      list(enumerate(args.environment_files))),
            total=len(args.environment_files), leave=False))
    unfiltered_jobs = [job for job_list in jobs_in_files for job in job_list
                       if job not in already_done]
    if args.proofs_file:
        with open(args.proofs_file, 'r') as f:
            proof_names = [line.strip() for line in f]
        all_jobs = [
            job for job in unfiltered_jobs if
            serapi_instance.lemma_name_from_statement(job[2]) in proof_names]
    elif args.proof:
        all_jobs = [
            job for job in unfiltered_jobs if
            serapi_instance.lemma_name_from_statement(job[2]) == args.proof] \
             * args.num_threads
    else:
        all_jobs = unfiltered_jobs

    all_jobs_and_dems: List[Tuple[Job, Optional[Demonstration]]]
    if args.demonstrate_from:
        all_jobs_and_dems = [(job, extract_solution(args,
                                                    args.demonstrate_from,
                                                    job))
                             for job in all_jobs]
    else:
        all_jobs_and_dems = [(job, None) for job in all_jobs]

    for job in all_jobs_and_dems:
        jobs.put(job)

    with Manager() as manager:
        manager = cast(multiprocessing.managers.SyncManager, manager)
        ns = manager.Namespace()
        ns.predictor = predictor
        ns.estimator = q_estimator
        lock = manager.Lock()

        training_worker = ctxt.Process(
            target=reinforce_training_worker,
            args=(args, len(replay_memory), lock, ns, samples))
        workers = [ctxt.Process(
            target=reinforce_worker,
            args=(widx,
                  args,
                  lock,
                  ns,
                  samples,
                  jobs,
                  done))
                   for widx in range(min(args.num_threads, len(all_jobs)))]
        training_worker.start()
        for worker in workers:
            worker.start()

        with tqdm(total=len(all_jobs) + len(already_done),
                  dynamic_ncols=True) as bar:
            bar.update(len(already_done))
            bar.refresh()
            for _ in range(len(all_jobs)):
                done_job, graph_job = done.get()
                if graph_job:
                    graphs_done.append(graph_job)
                bar.update()
                with args.out_weights.with_suffix(".done").open('a') as f:
                    f.write(json.dumps((str(done_job[0]),
                                        done_job[1],
                                        done_job[2])))
                    f.write("\n")  # resume() reads this file line by line
        for worker in workers:
            worker.kill()
        training_worker.kill()

        for graphpath, graph in graphs_done:
            assignApproximateQScores(graph, args.max_term_length, predictor,
                                     q_estimator)
            graph.draw(graphpath)
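
# The resume machinery depends on the .tmp file holding one JSON object per
# line. A round-trip sketch of that JSON-lines format, with plain dicts
# standing in for LabeledTransition.to_dict()/from_dict():
import json
from pathlib import Path

records = [{"prev_goal": "True", "action": "auto.", "reward": 1.0},
           {"prev_goal": "False -> True", "action": "intro.", "reward": 0.0}]
tmp_path = Path("resume_sketch.tmp")
with tmp_path.open("w") as f:
    for record in records:
        f.write(json.dumps(record))
        f.write("\n")  # newline-terminated so readers can json.loads per line
with tmp_path.open("r") as f:
    loaded = [json.loads(line) for line in f]
assert loaded == records
tmp_path.unlink()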
Example #7
def supervised_q(args: argparse.Namespace) -> None:
    replay_memory = []
    with open(args.tmp_file, 'r') as f:
        for line in tqdm(f, desc="Loading data"):
            replay_memory.append(
                LabeledTransition.from_dict(json.loads(line)))
    if args.max_tuples is not None:
        replay_memory = replay_memory[-args.max_tuples:]

    # Load the predictor
    predictor = cast(
        features_polyarg_predictor.FeaturesPolyargPredictor,
        predict_tactic.loadPredictorByFile(args.predictor_weights))

    q_estimator: QEstimator
    # Create an initial Q Estimator
    if args.estimator == "polyarg":
        q_estimator = PolyargQEstimator(args.learning_rate, args.epoch_step,
                                        args.gamma, predictor)
    else:
        q_estimator = FeaturesQEstimator(args.learning_rate, args.epoch_step,
                                         args.gamma)
    if args.start_from:
        q_estimator_name, *saved = torch.load(args.start_from)
        if args.estimator == "polyarg":
            assert q_estimator_name == "polyarg evaluator", \
                q_estimator_name
        else:
            assert q_estimator_name == "features evaluator", \
                q_estimator_name
        q_estimator.load_saved_state(*saved)

    training_start = time.time()
    training_samples = assign_scores(args,
                                     q_estimator,
                                     predictor,
                                     replay_memory,
                                     progress=True)
    input_tensors = q_estimator.get_input_tensors(training_samples)
    rescore_lr = args.learning_rate

    for epoch in range(1, args.num_epochs + 1):
        scores = torch.FloatTensor(
            [score for _, _, _, score in training_samples])
        batches: Sequence[Sequence[torch.Tensor]] = data.DataLoader(
            data.TensorDataset(*(input_tensors + [scores])),
            batch_size=args.batch_size,
            num_workers=0,
            shuffle=True,
            pin_memory=True,
            drop_last=True)

        epoch_loss = 0.
        eprint("Epoch {}: Learning rate {:.12f}".format(
            epoch, q_estimator.optimizer.param_groups[0]['lr']),
               guard=args.show_loss)
        for idx, batch in enumerate(batches, start=1):
            q_estimator.optimizer.zero_grad()
            # This unpacking assumes two input tensors plus the score
            # column, i.e. the FeaturesQEstimator tensor layout.
            word_features_batch, vec_features_batch, \
                expected_outputs_batch = batch
            outputs = q_estimator.model(word_features_batch,
                                        vec_features_batch)
            loss = q_estimator.criterion(outputs,
                                         maybe_cuda(expected_outputs_batch))
            loss.backward()
            q_estimator.optimizer.step()
            q_estimator.total_batches += 1
            epoch_loss += loss.item()
            if idx % args.print_every == 0:
                items_processed = idx * args.batch_size + \
                    (epoch - 1) * len(replay_memory)
                progress = items_processed / (len(replay_memory) *
                                              args.num_epochs)
                eprint("{} ({:7} {:5.2f}%) {:.4f}".format(
                    timeSince(training_start, progress), items_processed,
                    progress * 100, epoch_loss * (len(batches) / idx)),
                       guard=args.show_loss)
        q_estimator.adjuster.step()

        q_estimator.save_weights(args.out_weights, args)
        if epoch % args.score_every == 0 and epoch < args.num_epochs:
            training_samples = assign_scores(args,
                                             q_estimator,
                                             predictor,
                                             replay_memory,
                                             progress=True)
            rescore_lr *= args.rescore_gamma
            q_estimator.optimizer.param_groups[0]['lr'] = rescore_lr

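
# The epoch loop above is the usual zero_grad/forward/loss/backward/step
# cycle. A minimal self-contained PyTorch version with a toy linear model
# and random data (all hyperparameters are placeholders):
import torch
from torch.utils import data

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()
inputs, targets = torch.randn(64, 4), torch.randn(64, 1)
loader = data.DataLoader(data.TensorDataset(inputs, targets),
                         batch_size=16, shuffle=True, drop_last=True)

for epoch in range(1, 4):
    epoch_loss = 0.0
    for batch_inputs, batch_targets in loader:
        optimizer.zero_grad()
        loss = criterion(model(batch_inputs), batch_targets)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print(f"epoch {epoch}: loss {epoch_loss:.4f}")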
Example #8
def reinforce(args: argparse.Namespace) -> None:

    # Load the scraped (demonstrated) samples, the proof environment
    # commands, and the predictor
    replay_memory = assign_rewards(
        dataloader.tactic_transitions_from_file(args.scrape_file,
                                                args.buffer_size))

    predictor = predict_tactic.loadPredictorByFile(args.predictor_weights)

    q_estimator = FeaturesQEstimator(args.learning_rate,
                                     args.batch_step,
                                     args.gamma)
    signal.signal(
        signal.SIGINT,
        lambda signal, frame:
        progn(q_estimator.save_weights(
            args.out_weights, args),  # type: ignore
              exit()))
    if args.start_from:
        q_estimator_name, *saved = torch.load(args.start_from)
        q_estimator.load_saved_state(*saved)
    elif args.pretrain:
        pre_train(args, q_estimator,
                  dataloader.tactic_transitions_from_file(
                      args.scrape_file, args.buffer_size * 10))

    epsilon = 0.3
    gamma = 0.9

    if args.proof is not None:
        assert len(args.environment_files) == 1, \
            "Can't use multiple env files with --proof!"
        env_commands = serapi_instance.load_commands_preserve(
            args, 0, args.prelude / args.environment_files[0])
        num_proofs = len([cmd for cmd in env_commands
                          if cmd.strip() == "Qed."
                          or cmd.strip() == "Defined."])

        with serapi_instance.SerapiContext(
                ["sertop", "--implicit"],
                serapi_instance.get_module_from_filename(
                    args.environment_files[0]),
                str(args.prelude)) as coq:
            coq.quiet = True
            coq.verbose = args.verbose
            rest_commands, run_commands = coq.run_into_next_proof(env_commands)
            lemma_statement = run_commands[-1]
            while coq.cur_lemma_name != args.proof:
                if not rest_commands:
                    eprint("Couldn't find lemma {args.proof}! Exiting...")
                    return
                rest_commands, _ = coq.finish_proof(rest_commands)
                rest_commands, run_commands = coq.run_into_next_proof(
                    rest_commands)
                lemma_statement = run_commands[-1]
            reinforce_lemma(args, predictor, q_estimator, coq,
                            lemma_statement,
                            epsilon, gamma, replay_memory)
            q_estimator.save_weights(args.out_weights, args)
    else:
        for env_file in args.environment_files:
            env_commands = serapi_instance.load_commands_preserve(
                args, 0, args.prelude / env_file)
            num_proofs = len([cmd for cmd in env_commands
                              if cmd.strip() == "Qed."
                              or cmd.strip() == "Defined."])
            rest_commands = env_commands
            all_run_commands: List[str] = []
            with tqdm(total=num_proofs, disable=(not args.progress),
                      leave=True, desc=env_file.stem) as pbar:
                while rest_commands:
                    with serapi_instance.SerapiContext(
                            ["sertop", "--implicit"],
                            serapi_instance.get_module_from_filename(
                                env_file),
                            str(args.prelude),
                            log_outgoing_messages=args.log_outgoing_messages) \
                            as coq:
                        coq.quiet = True
                        coq.verbose = args.verbose
                        for command in all_run_commands:
                            coq.run_stmt(command)

                        while rest_commands:
                            rest_commands, run_commands = \
                                coq.run_into_next_proof(rest_commands)
                            if not rest_commands:
                                break
                            all_run_commands += run_commands[:-1]
                            lemma_statement = run_commands[-1]

                            # Check if the definition is
                            # proof-relevant. If it is, then finishing
                            # subgoals doesn't necessarily mean you've
                            # solved the problem, so don't try to
                            # train on it.
                            proof_relevant = False
                            for cmd in rest_commands:
                                if serapi_instance.ending_proof(cmd):
                                    if cmd.strip() == "Defined.":
                                        proof_relevant = True
                                    break
                            proof_relevant = proof_relevant or \
                                bool(re.match(r"\s*Derive", lemma_statement))
                            for sample in replay_memory:
                                sample.graph_node = None
                            if not proof_relevant:
                                try:
                                    reinforce_lemma(args, predictor,
                                                    q_estimator,
                                                    coq,
                                                    lemma_statement,
                                                    epsilon, gamma,
                                                    replay_memory)
                                except serapi_instance.CoqAnomaly:
                                    if args.log_anomalies:
                                        with args.log_anomalies.open('a') as f:
                                            traceback.print_exc(file=f)
                                    if args.hardfail:
                                        eprint(
                                            "Hit an anomaly! "
                                            "Quitting due to --hardfail")
                                        raise
                                    eprint(
                                        "Hit an anomaly! "
                                        "Restarting coq instance")
                                    rest_commands.insert(0, lemma_statement)
                                    break
                            pbar.update(1)
                            rest_commands, run_commands = \
                                coq.finish_proof(rest_commands)
                            all_run_commands.append(lemma_statement)
                            all_run_commands += run_commands

            q_estimator.save_weights(args.out_weights, args)
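
# The hard-coded epsilon = 0.3 above suggests epsilon-greedy action selection
# inside reinforce_lemma. A self-contained sketch of that policy, assuming
# (this is not shown in the original) that the Q-estimator scores each
# candidate tactic; epsilon_greedy and its toy scorer are hypothetical:
import random
from typing import Callable, List

def epsilon_greedy(candidates: List[str],
                   q_score: Callable[[str], float],
                   epsilon: float = 0.3) -> str:
    # Explore uniformly with probability epsilon; otherwise exploit the
    # highest-scoring candidate.
    if random.random() < epsilon:
        return random.choice(candidates)
    return max(candidates, key=q_score)

print(epsilon_greedy(["intro.", "auto.", "reflexivity."],
                     q_score=lambda tactic: len(tactic)))  # toy scorer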