Example #1
import argparse
import concurrent.futures
import gzip
import os
import pickle
import shutil
import sys

# Note: the project-local names used below (Config, Log, ProofTrace,
# ProofTraceKernel, ProofTraceTokenizer, dump_trace) are assumed to be
# provided by the surrounding module of this repository; their exact import
# paths are not shown in this snippet.


def extract():
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        'config_path',
        type=str,
        help="path to the config file",
    )
    parser.add_argument(
        '--dataset_size',
        type=str,
        help="config override",
    )

    args = parser.parse_args()

    config = Config.from_file(args.config_path)

    if args.dataset_size is not None:
        config.override(
            'prooftrace_dataset_size',
            args.dataset_size,
        )

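    # The recursion limit is raised here, presumably because constructing
    # ProofTrace objects recurses over deep proof structures.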
    sys.setrecursionlimit(4096)

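    # The kernel loads the raw proof trace dataset (of the configured size)
    # from disk; the tokenizer will later assign ids to terms and types.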
    kernel = ProofTraceKernel(
        os.path.expanduser(config.get('prooftrace_dataset_dir')),
        config.get('prooftrace_dataset_size'),
    )
    tokenizer = ProofTraceTokenizer()

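    # Phase 1: cross-step detection. Build one ProofTrace per named theorem
    # and record, for every step, which traces it appears in; steps shared by
    # more than one trace are registered as shared in the kernel.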
    Log.out("Starting cross steps detection")

    traces = [ProofTrace(kernel, k) for k in kernel._names.keys()]

    Log.out("Prooftraces computed", {
        "traces_count": len(traces),
    })

    cross_steps = {}
    for tr in traces:
        for th in tr._steps.keys():
            if th not in cross_steps:
                cross_steps[th] = []
            if tr._index not in cross_steps[th]:
                cross_steps[th].append(tr._index)

    cross_step_count = 0
    for th in cross_steps:
        if len(cross_steps[th]) > 1:
            cross_step_count += 1
            kernel.add_shared(th, cross_steps[th])

    Log.out("Cross steps detection", {
        "cross_step_count": cross_step_count,
    })

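    # Phase 2: shared premises detection. Rebuild the traces (now aware of
    # the shared steps) and ask the kernel to name every premise they
    # reference, counting how many get promoted to shared premises.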
    Log.out("Starting shared premises detection")

    traces = [ProofTrace(kernel, k) for k in kernel._names.keys()]

    Log.out("Prooftraces computed", {
        "traces_count": len(traces),
    })

    shared_premise_count = 0
    for tr in traces:
        for th in tr._premises.keys():
            if kernel.name_shared_premise(th):
                shared_premise_count += 1

    Log.out("Shared premises detection", {
        "shared_premise_count": shared_premise_count,
    })

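    # Phase 3: min-cut. Reset the shared map and repeatedly cut traces whose
    # length exceeds 4/5 of `prooftrace_max_demo_length`: each cut index is
    # named as a premise and the affected traces (originals plus cuts) are
    # rebuilt, until no trace exceeds the threshold.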
    Log.out("Starting min_cut operations")

    kernel._shared = {}
    traces = [ProofTrace(kernel, k) for k in kernel._names.keys()]
    traces = [tr for tr in traces if len(tr._steps) > 0]

    excess = [
        tr for tr in traces
        if tr.len() > config.get('prooftrace_max_demo_length') * 4 / 5
    ]
    Log.out("Min-cut initialization", {
        'excess': len(excess),
    })

    while len(excess) > 0:
        orig = []
        cut = []

        for tr in excess:
            orig.append(tr._index)
            cut += tr.min_cut(
                config.get('prooftrace_max_demo_length') * 1 / 8,
                config.get('prooftrace_max_demo_length') * 1 / 2,
            )

        for idx in cut:
            kernel.name_cut_premise(idx)

        refresh = orig + cut
        traces = [ProofTrace(kernel, k) for k in refresh]
        excess = [
            tr for tr in traces
            if tr.len() > config.get('prooftrace_max_demo_length') * 4 / 5
        ]

        Log.out("Min-cut processing loop", {
            'excess': len(excess),
            'orig': len(orig),
            'cut': len(cut),
        })

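    # Phase 4: stitching. Traces with fewer than 32 steps are dropped as
    # standalone premises, presumably so that their steps get inlined back
    # into the traces that reference them.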
    Log.out("Stitching small prooftraces")

    traces = [ProofTrace(kernel, k) for k in kernel._names.keys()]
    traces = [tr for tr in traces if len(tr._steps) > 0]

    for tr in traces:
        if tr.len() < 32:
            # Log.out("Remove small prooftrace", {
            #     'name': tr.name(),
            #     'index': tr._index,
            # })
            kernel.remove_premise(tr._index)

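    # Phase 5: final generation. Rebuild all non-empty traces, sort them by
    # index, localize them, and tokenize their terms and types with the
    # shared tokenizer.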
    Log.out("Starting final prooftraces generation")

    traces = [ProofTrace(kernel, k) for k in kernel._names.keys()]
    traces = [tr for tr in traces if len(tr._steps) > 0]
    traces = sorted(traces, key=lambda tr: tr._index)

    # Finally we localize the resulting traces.
    for tr in traces:
        tr.localize()

    Log.out("Prooftraces computed, filtered, localized and sorted", {
        "traces_count": len(traces),
    })

    for tr in traces:
        tr.tokenize(tokenizer)

    Log.out(
        "Pre-tokenized prooftraces", {
            "term_token_count": len(tokenizer._term_tokens),
            "type_token_count": len(tokenizer._type_tokens),
        })

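    # Persist the tokenizer next to the dataset so that downstream consumers
    # reuse the exact same term/type token ids.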
    with gzip.open(
            os.path.join(
                os.path.expanduser(config.get('prooftrace_dataset_dir')),
                config.get('prooftrace_dataset_size'),
                'traces.tokenizer',
            ), 'wb') as f:
        pickle.dump(tokenizer, f, protocol=pickle.HIGHEST_PROTOCOL)

    Log.out(
        "Dumped tokenizer", {
            "term_token_count": len(tokenizer._term_tokens),
            "type_token_count": len(tokenizer._type_tokens),
        })

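    # Log size histograms (premise, substitution, term and step counts, plus
    # overall length) for the final set of traces, bucketed by powers of two.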
    Log.histogram(
        "ProofTraces Premises", [len(tr._premises) for tr in traces],
        buckets=[64, 128, 256, 512, 1024, 2048, 4096],
        labels=["0064", "0128", "0256", "0512", "1024", "2048", "4096"])
    Log.histogram(
        "ProofTraces Substs", [len(tr._substs) for tr in traces],
        buckets=[64, 128, 256, 512, 1024, 2048, 4096],
        labels=["0064", "0128", "0256", "0512", "1024", "2048", "4096"])
    Log.histogram(
        "ProofTraces SubstTypes", [len(tr._subst_types) for tr in traces],
        buckets=[64, 128, 256, 512, 1024, 2048, 4096],
        labels=["0064", "0128", "0256", "0512", "1024", "2048", "4096"])
    Log.histogram(
        "ProofTraces Terms", [len(tr._terms) for tr in traces],
        buckets=[64, 128, 256, 512, 1024, 2048, 4096],
        labels=["0064", "0128", "0256", "0512", "1024", "2048", "4096"])
    Log.histogram(
        "ProofTraces Steps", [len(tr._steps) for tr in traces],
        buckets=[64, 128, 256, 512, 1024, 2048, 4096],
        labels=["0064", "0128", "0256", "0512", "1024", "2048", "4096"])
    Log.histogram(
        "ProofTraces Length", [tr.len() for tr in traces],
        buckets=[64, 128, 256, 512, 1024, 2048, 4096],
        labels=["0064", "0128", "0256", "0512", "1024", "2048", "4096"])
    Log.out("Starting action generation")

    traces_path_train = os.path.join(
        os.path.expanduser(config.get('prooftrace_dataset_dir')),
        config.get('prooftrace_dataset_size'),
        "train_traces",
    )
    traces_path_test = os.path.join(
        os.path.expanduser(config.get('prooftrace_dataset_dir')),
        config.get('prooftrace_dataset_size'),
        "test_traces",
    )

    if os.path.isdir(traces_path_train):
        shutil.rmtree(traces_path_train)
    os.mkdir(traces_path_train)
    if os.path.isdir(traces_path_test):
        shutil.rmtree(traces_path_test)
    os.mkdir(traces_path_test)

    executor = concurrent.futures.ProcessPoolExecutor(max_workers=8)

    map_args = []
    for i, tr in enumerate(traces):
        map_args.append([config, tokenizer, tr, i, len(traces)])

    trace_lengths = list(
        executor.map(dump_trace, map_args, chunksize=8)
    )

    Log.histogram(
        "ProofTraces Length",
        trace_lengths,
        buckets=[64, 128, 256, 512, 1024, 2048, 4096],
        labels=["0064", "0128", "0256", "0512", "1024", "2048", "4096"])

    Log.out(
        "Dumped all traces", {
            "traces_path_train": traces_path_train,
            "traces_path_test": traces_path_test,
            "trace_count": len(traces),
        })