示例#1
0
def main(argv):
    """Build the dataset: export the legacy IR database, import graphs,
    encode them with inst2vec, and derive vocabularies and labels."""
    init_app(argv)
    path = pathlib.Path(pathflag.path())
    db = _mysql.connect(host=FLAGS.host,
                        user=FLAGS.user,
                        passwd=FLAGS.pwd,
                        db=FLAGS.db)

    # Create the output tree up front. mkdir() raises if a directory already
    # exists, which guards against clobbering a previous export.
    (path / "ir").mkdir(parents=True)
    for subdir in ("graphs", "train", "val", "test"):
        (path / subdir).mkdir()

    # Export the legacy IR database.
    progress.Run(ImportIrDatabase(path, db))

    # Import the classifyapp dataset.
    ImportClassifyAppDataset(pathlib.Path(FLAGS.classifyapp), path)

    # Add inst2vec encoding features to graphs.
    logging.info("Encoding graphs with inst2vec")
    progress.Run(Inst2vecEncodeGraphs(path))

    logging.info("Creating vocabularies")
    subprocess.check_call([str(CREATE_VOCAB), "--path", str(path)])

    logging.info("Creating data flow analysis labels")
    subprocess.check_call([str(CREATE_LABELS), str(path)])
示例#2
0
def main(argv):
    """Main entry point.

    For every program graph in the test split that has generated data-flow
    features, print one line per features entry:
    `<features file name> <index> <step count>`.
    """
    init_app(argv)
    path = pathlib.Path(FLAGS.path)

    graphs_path = path / "test"
    labels_path = path / "labels" / FLAGS.analysis

    graph_suffix = "ProgramGraph.pb"
    for graph_path in graphs_path.iterdir():
        # Bug fix: only process files that actually end with the expected
        # suffix. The previous unconditional slice produced a garbage stem
        # (and hence a bogus features path) for any other file in the
        # directory.
        if not graph_path.name.endswith(graph_suffix):
            continue
        stem = graph_path.name[:-len(graph_suffix)]
        name = f"{stem}ProgramGraphFeaturesList.pb"
        features_path = labels_path / name
        # There is no guarantee that we have generated features for this
        # program graph, so we check for its existence. As a *very* defensive
        # measure, we also check for the existence of the graph file that we
        # enumerated at the start of this function. This check can be removed
        # later, it is only useful during development when you might be
        # modifying the dataset at the same time as having test jobs running.
        if not graph_path.is_file() or not features_path.is_file():
            continue

        graph = pbutil.FromFile(graph_path, program_graph_pb2.ProgramGraph())
        # Skip empty graphs and graphs that exceed the node-count budget.
        if not graph.node or len(graph.node) > FLAGS.max_graph_node_count:
            continue

        features_list = pbutil.FromFile(
            features_path,
            program_graph_features_pb2.ProgramGraphFeaturesList())

        for j, features in enumerate(features_list.graph):
            # The step count defaults to 0 when the feature is absent.
            step_count_feature = features.features.feature[
                "data_flow_step_count"].int64_list.value
            step_count = step_count_feature[0] if len(
                step_count_feature) else 0
            print(features_path.name, j, step_count)
示例#3
0
def main(argv):
    """Encode the ProgramGraph proto read from stdin with inst2vec."""
    init_app(argv)
    encoder = inst2vec_encoder.Inst2vecEncoder()

    graph = ParseStdinOrDie(ProgramGraph())
    # The auxiliary IR text is optional: pass None when no --ir was given.
    ir = None
    if FLAGS.ir:
        with open(FLAGS.ir) as ir_file:
            ir = ir_file.read()
    encoder.Encode(graph, ir)
    WriteStdout(graph)
示例#4
0
def main(argv):
    """Main entry point.

    Parses --input as `<features_list_path>:<index>`, runs the model from the
    --model checkpoint on that single features-list entry, and prints the
    resulting graph.
    """
    init_app(argv)
    dataflow.PatchWarnings()

    # Bug fix: split on the *last* colon so a path that itself contains
    # colons (e.g. a Windows drive letter) still parses as path:index.
    # A plain split(":") raises ValueError on such inputs.
    features_list_path, features_list_index = FLAGS.input.rsplit(":", 1)
    graph = TestOne(
        features_list_path=Path(features_list_path),
        features_list_index=int(features_list_index),
        checkpoint_path=Path(FLAGS.model),
    )
    print(graph)
示例#5
0
def main(argv):
    """Main entry point."""
    init_app(argv)

    path = pathlib.Path(FLAGS.path)

    # The vocabulary depends on which graph representation is in use.
    model_name = "cdfg" if FLAGS.cdfg else "programl"
    vocab = vocabulary.LoadVocabulary(
        path,
        model_name=model_name,
        max_items=FLAGS.max_vocab_size,
        target_cumfreq=FLAGS.target_vocab_cumfreq,
    )

    # CDFG doesn't use positional embeddings.
    if FLAGS.cdfg:
        FLAGS.use_position_embeddings = False

    # Arguments common to both the training and the test runs.
    shared_args = dict(
        path=path,
        analysis=FLAGS.analysis,
        vocab=vocab,
        limit_max_data_flow_steps=FLAGS.limit_max_data_flow_steps,
        batch_size=FLAGS.batch_size,
        use_cdfg=FLAGS.cdfg,
    )

    if FLAGS.test_only:
        log_dir = FLAGS.restore_from
    else:
        log_dir = ggnn.TrainDataflowGGNN(
            train_graph_counts=[int(x) for x in FLAGS.train_graph_counts],
            val_graph_count=FLAGS.val_graph_count,
            val_seed=FLAGS.val_seed,
            run_id=FLAGS.run_id,
            restore_from=FLAGS.restore_from,
            **shared_args,
        )

    if FLAGS.test:
        ggnn.TestDataflowGGNN(log_dir=log_dir, **shared_args)
示例#6
0
def main(argv):
    """Load training logs and dump them to stdout as CSV or a text table."""
    init_app(argv)

    logs_path = Path(FLAGS.path)
    fmt = FLAGS.fmt

    with progress.Profile("loading logs"):
        frame = LogsToDataFrame(logs_path)

    # LogsToDataFrame returns None when there is nothing to report.
    if frame is None:
        print("No logs found", file=sys.stderr)
        sys.exit(1)

    if fmt == "csv":
        frame.to_csv(sys.stdout, header=True)
        return
    if fmt == "txt":
        print(tabulate(frame, headers="keys", tablefmt="psql",
                       showindex="never"))
        return
    raise app.UsageError(f"Unknown --fmt: {fmt}")
示例#7
0
def main(argv):
    """Run the inst2vec encoder on a dataset, a directory, or stdin."""
    init_app(argv)
    encoder = inst2vec_encoder.Inst2vecEncoder()

    # Batch modes short-circuit; --dataset takes precedence over --directory.
    if FLAGS.dataset:
        encoder.RunOnDataset(Path(FLAGS.dataset))
        return
    if FLAGS.directory:
        encoder.RunOnDirectory(Path(FLAGS.directory))
        return

    # Single-graph mode: read a proto from stdin, optionally with IR text.
    graph = ParseStdinOrDie(program_graph_pb2.ProgramGraph())
    ir = None
    if FLAGS.ir:
        with open(FLAGS.ir) as ir_file:
            ir = ir_file.read()
    encoder.Encode(graph, ir)
    WriteStdout(graph)
示例#8
0
def main(argv):
    """Main entry point."""
    init_app(argv)
    path = pathlib.Path(FLAGS.path)

    # Load the published inst2vec vocabulary.
    with vocabulary.VocabularyZipFile.CreateFromPublishedResults() as inst2vec:
        vocab = inst2vec.dictionary

    # In test-only mode we reuse the checkpoint directory from --restore_from
    # instead of training a new model.
    log_dir = (
        FLAGS.restore_from
        if FLAGS.test_only
        else TrainDataflowLSTM(
            path=path,
            vocab=vocab,
            val_seed=FLAGS.val_seed,
            restore_from=FLAGS.restore_from,
        )
    )

    if FLAGS.test:
        TestDataflowLSTM(path=path, vocab=vocab, log_dir=log_dir)
示例#9
0
def main(argv):
    """Read a ProgramGraph proto from stdin, pickle it as NetworkX to stdout."""
    init_app(argv)
    graph_proto = ParseStdinOrDie(program_graph_pb2.ProgramGraph())
    nx_graph = nx_format.ProgramGraphToNetworkX(graph_proto)
    pickle.dump(nx_graph, sys.stdout.buffer)
示例#10
0
def main(argv):
    """Run the vocabulary test job over the dataset at --path."""
    init_app(argv)
    dataset_root = pathlib.Path(pathflag.path())
    progress.Run(TestVocab(dataset_root))