コード例 #1
0
from deeplearning.ml4pl.testing import random_networkx_generator
from deeplearning.ml4pl.testing import testing_databases
from labm8.py import decorators
from labm8.py import test

# Module-level alias for the flags object exposed as `test.FLAGS`.
FLAGS = test.FLAGS

###############################################################################
# Fixtures.
###############################################################################


@test.Fixture(scope="session",
              params=testing_databases.GetDatabaseUrls(),
              namer=testing_databases.DatabaseUrlNamer("graph_db"))
def empty_db(request) -> graph_tuple_database.Database:
    """Session-scoped fixture yielding an empty graph tuple database.

    Parametrized over ``testing_databases.GetDatabaseUrls()``. The current
    database URL (``request.param``) and the ``graph_tuple_database.Database``
    class are handed to ``testing_databases.YieldDatabase``, and everything it
    yields is delegated through unchanged.
    """
    database_class = graph_tuple_database.Database
    database_url = request.param
    yield from testing_databases.YieldDatabase(database_class, database_url)


@test.Fixture(
    scope="function",
    params=testing_databases.GetDatabaseUrls(),
    namer=testing_databases.DatabaseUrlNamer("graph_db"),
)
def db_session(request) -> graph_tuple_database.Database.SessionType:
    """A test fixture which yields an empty graph proto database session."""
    with testing_databases.DatabaseContext(graph_tuple_database.Database,
コード例 #2
0
ファイル: split_test.py プロジェクト: monperrus/ProGraML
from labm8.py import test

# Module-level alias for the flags object exposed as `test.FLAGS`.
FLAGS = test.FLAGS


def CreateRandomString(
    min_length: int = 1,
    max_length: int = 1024,
    *,
    alphabet: str = string.ascii_lowercase,
) -> str:
    """Generate a random string.

    The length is drawn with random.randint(min_length, max_length), so both
    bounds are inclusive. Each character is then drawn independently from
    `alphabet` with random.choice. The module-level `random` state is used, so
    results are reproducible under random.seed().

    Args:
        min_length: Minimum length of the string (inclusive).
        max_length: Maximum length of the string (inclusive).
        alphabet: Characters to sample from. Defaults to lowercase ASCII
            letters, which was the previously hard-coded alphabet.

    Returns:
        The generated string.

    Raises:
        ValueError: If min_length exceeds max_length, or alphabet is empty.
    """
    # Fail with a clear message instead of randint's cryptic
    # "empty range for randrange()" error (same exception type).
    if min_length > max_length:
        raise ValueError(
            f"min_length ({min_length}) must not exceed max_length "
            f"({max_length})")
    if not alphabet:
        raise ValueError("alphabet must be non-empty")
    length = random.randint(min_length, max_length)
    # One random.choice call per character keeps the RNG draw sequence
    # identical to the previous implementation for seeded tests.
    return "".join(random.choice(alphabet) for _ in range(length))


@test.Fixture(
    scope="session",
    params=testing_databases.GetDatabaseUrls(),
    namer=testing_databases.DatabaseUrlNamer("ir_db"),
)
def ir_db(request) -> ir_database.Database:
    """A test fixture which yields an IR database."""
    with testing_databases.DatabaseContext(ir_database.Database,
                                           request.param) as db:
        rows = []
        for i in range(250):
            ir = ir_database.IntermediateRepresentation.CreateFromText(
                source=random.choice([
                    "pact17_opencl_devmap",
                    "poj-104:train",
                    "poj-104:val",
                    "poj-104:test",
                ]),
                relpath=str(i),
コード例 #3
0
  assert graph_tuple.graph_x_dimensionality == graph_x_dimensionality
  assert graph_tuple.graph_y_dimensionality == graph_y_dimensionality
  if with_data_flow:
    assert graph_tuple.data_flow_steps >= 1
    assert graph_tuple.data_flow_root_node >= 0
    assert graph_tuple.data_flow_positive_node_count >= 1
  else:
    assert graph_tuple.data_flow_steps is None
    assert graph_tuple.data_flow_root_node is None
    assert graph_tuple.data_flow_positive_node_count is None


@test.Fixture(
  scope="function",
  params=testing_databases.GetDatabaseUrls(),
  namer=testing_databases.DatabaseUrlNamer("db"),
)
def db(request) -> graph_tuple_database.Database:
  """Function-scoped fixture yielding an empty graph tuple database.

  Parametrized over ``testing_databases.GetDatabaseUrls()``; the current URL
  (``request.param``) is passed with ``graph_tuple_database.Database`` to
  ``testing_databases.YieldDatabase``, whose yields are delegated unchanged.
  """
  database_class = graph_tuple_database.Database
  db_url = request.param
  yield from testing_databases.YieldDatabase(database_class, db_url)


@test.Fixture(
  scope="function",
  params=(1, 100),
)
def graph_count(request) -> int:
  """Parametrized fixture providing a graph count.

  Parametrized with ``params=(1, 100)``; returns the current parameter.
  """
  count: int = request.param
  return count


@test.Fixture(scope="function", params=(1, 3))
コード例 #4
0
ファイル: ggnn_test.py プロジェクト: tehranixyz/ProGraML
# Module-level alias for the flags object exposed as `test.FLAGS`.
FLAGS = test.FLAGS

# For testing models, always use --strict_graph_segmentation.
# This is assigned at module import time, i.e. before any test in this file
# runs.
FLAGS.strict_graph_segmentation = True


###############################################################################
# Fixtures.
###############################################################################


@test.Fixture(
  scope="session",
  params=testing_databases.GetDatabaseUrls(),
  namer=testing_databases.DatabaseUrlNamer("log_db"),
)
def log_db(request) -> log_database.Database:
  """Session-scoped fixture yielding an empty log database.

  Parametrized over ``testing_databases.GetDatabaseUrls()``; the current URL
  (``request.param``) is passed with ``log_database.Database`` to
  ``testing_databases.YieldDatabase``, whose yields are delegated unchanged.
  """
  database_class = log_database.Database
  database_url = request.param
  yield from testing_databases.YieldDatabase(database_class, database_url)


@test.Fixture(scope="session")
def logger(log_db: log_database.Database) -> logging.Logger:
  """Session-scoped fixture yielding a logger that writes to ``log_db``.

  The logger is constructed with ``max_buffer_length=128`` and used as a
  context manager; the object yielded is whatever its ``__enter__`` returns,
  and the context is exited when the fixture is torn down.
  """
  buffered_logger = logging.Logger(log_db, max_buffer_length=128)
  with buffered_logger as active_logger:
    yield active_logger