示例#1
0
def PopulateWithTestSet(
    db: graph_tuple_database.Database,
    graph_count: int,
    node_x_dimensionality: int = 2,
    node_y_dimensionality: int = 0,
    graph_x_dimensionality: int = 0,
    graph_y_dimensionality: int = 0,
    with_data_flow: bool = False,
    split_count: int = 0,
):
    """Fill a database with randomly labelled graph tuples of "real" programs.

    Graph tuples are taken from
    random_graph_tuple_generator.EnumerateTestSet(n=graph_count), cycling over
    the enumerated tuples until graph_count of them have been consumed. Each
    one receives freshly generated random features and labels, is converted
    to a database row, and all rows are added to the database in a single
    committing session.

    Args:
        db: The database to add the rows to.
        graph_count: The number of rows to generate.
        node_x_dimensionality: If non-zero, node_x is a random integer array
            of 0/1 values with shape (node_count, node_x_dimensionality).
            If zero, node_x is None.
        node_y_dimensionality: If non-zero, node_y is a random float array
            with shape (node_count, node_y_dimensionality). If zero, node_y
            is None.
        graph_x_dimensionality: If non-zero, graph_x is a one-dimensional
            random integer array of this length with values in [0, 50]. If
            zero, graph_x is None.
        graph_y_dimensionality: If non-zero, graph_y is a random float array
            with shape (graph_tuple.graph_count, graph_y_dimensionality). If
            zero, graph_y is None.
        with_data_flow: If true, the data_flow_steps, data_flow_root_node and
            data_flow_positive_node_count attributes of every row are set to
            random values.
        split_count: If non-zero, every row gets a random integer split drawn
            from [0, split_count], both ends inclusive. If zero, the split
            is None.

    Returns:
        A DatabaseAndRows built from db and the list of generated rows.
    """
    # Cycling lets a short test set still supply graph_count tuples.
    test_set = random_graph_tuple_generator.EnumerateTestSet(n=graph_count)
    selected = itertools.islice(itertools.cycle(test_set), graph_count)

    rows = []
    # IR ids are 1-based.
    for ir_id, graph in enumerate(selected, start=1):
        # Features and labels stay None unless a dimensionality is requested.
        # The draws below are kept in a fixed order (node_x, node_y, graph_x,
        # graph_y) so that seeded runs are reproducible.
        node_x = None
        if node_x_dimensionality:
            node_x = np.random.randint(
                low=0,
                high=2,
                size=(graph.node_count, node_x_dimensionality),
            )

        node_y = None
        if node_y_dimensionality:
            node_y = np.random.rand(graph.node_count, node_y_dimensionality)

        graph_x = None
        if graph_x_dimensionality:
            graph_x = np.random.randint(
                low=0, high=51, size=graph_x_dimensionality)

        graph_y = None
        if graph_y_dimensionality:
            graph_y = np.random.rand(graph.graph_count,
                                     graph_y_dimensionality)

        labelled = graph.SetFeaturesAndLabels(
            node_x=node_x,
            node_y=node_y,
            graph_x=graph_x,
            graph_y=graph_y,
            copy=False,
        )

        # random.randint() includes both end points.
        split = None
        if split_count:
            split = random.randint(0, split_count)

        row = graph_tuple_database.GraphTuple.CreateFromGraphTuple(
            labelled,
            ir_id=ir_id,
            split=split,
        )

        if with_data_flow:
            row.data_flow_steps = random.randint(1, 50)
            last_node = row.node_count - 1
            row.data_flow_root_node = random.randint(0, last_node)
            row.data_flow_positive_node_count = random.randint(1, last_node)

        rows.append(row)

    with db.Session(commit=True) as session:
        session.add_all(rows)

    return DatabaseAndRows(db, rows)
def test_EnumerateTestSet():
  """Check that the default test set of "real" protos has 100 entries."""
  # Exhaust the enumeration and count what it produced.
  proto_count = sum(
      1 for _ in random_graph_tuple_generator.EnumerateTestSet())
  assert proto_count == 100