Example #1
0
def graph_db(request) -> graph_tuple_database.Database:
    """Test fixture returning a database populated with random graph tuples.

    The (graph, node) label dimensionalities are taken from the fixture's
    parametrization via request.param.
    """
    graph_y_dim, node_y_dim = request.param
    url = testing_databases.GetDatabaseUrls()[0]
    database = graph_tuple_database.Database(url)
    random_graph_tuple_database_generator.PopulateDatabaseWithRandomGraphTuples(
        database,
        graph_count=100,
        graph_y_dimensionality=graph_y_dim,
        node_y_dimensionality=node_y_dim,
    )
    return database
Example #2
0
    def graph_db(self) -> graph_tuple_database.Database:
        """Return the graph database for a run.

        The database is reconstructed from the --graph_db flag value that was
        recorded for the run. Raises ValueError if no such flag was logged.
        """
        # Use the cached database if one has already been resolved.
        if self._graph_db:
            return self._graph_db

        with self.log_db.Session() as session:
            query = session.query(log_database.Parameter).filter(
                log_database.Parameter.type_num ==
                log_database.ParameterType.FLAG.value,
                log_database.Parameter.run_id == str(self.run_ids[0]),
                log_database.Parameter.name == "graph_db",
            )
            graph_param: log_database.Parameter = query.scalar()
            if not graph_param:
                raise ValueError("Unable to determine graph_db flag")
            graph_db_url = str(graph_param.value)

        return graph_tuple_database.Database(graph_db_url, must_exist=True)
Example #3
0
def test_empty_graph_database(logger: logging.Logger, tempdir: pathlib.Path):
    """Constructing a model over an empty graph database must raise an error."""
    url = f"sqlite:///{tempdir}/empty.db"
    with test.Raises(ValueError):
        MockModel(logger, graph_tuple_database.Database(url))
Example #4
0
def Main():
  """Main entry point.

  Runs every requested model on every requested dataset, configuring the
  global flags for each (dataset, model) combination before invoking the run.
  """
  db_stem = FLAGS.db_stem
  models = FLAGS.model
  tag_suffix = FLAGS.tag_suffix
  datasets = FLAGS.dataset

  # Flags that are the same regardless of model or dataset.
  FLAGS.log_db = flags_parsers.DatabaseFlag(
    log_database.Database,
    f"{db_stem}_dataflow_logs",
    must_exist=False,
  )
  FLAGS.ir_db = flags_parsers.DatabaseFlag(
    ir_database.Database, f"{db_stem}_ir", must_exist=True
  )
  FLAGS.test_on = "improvement_and_last"
  FLAGS.max_train_per_epoch = 5000
  FLAGS.max_val_per_epoch = 1000

  for dataset in datasets:
    graph_db = graph_tuple_database.Database(
      f"{db_stem}_{dataset}", must_exist=True
    )
    FLAGS.graph_db = flags_parsers.DatabaseFlag(
      graph_tuple_database.Database, graph_db.url, must_exist=True,
    )

    # Pick the score averaging method from the node label dimensionality:
    # 3-D labels (alias_sets) get weighted averaging; 2-D labels are binary
    # node classification and get binary prec/rec/f1 scores.
    node_dim = graph_db.node_y_dimensionality
    if node_dim == 3:
      FLAGS.batch_scores_averaging_method = "weighted"
    elif node_dim == 2:
      FLAGS.batch_scores_averaging_method = "binary"
    else:
      raise ValueError(
        f"Unknown node dimensionality: {graph_db.node_y_dimensionality}"
      )

    # Liveness labels identifiers; every other dataset labels statements.
    if dataset == "liveness":
      node_encoder = lstm.NodeEncoder.IDENTIFIER
    else:
      node_encoder = lstm.NodeEncoder.STATEMENT
    FLAGS.nodes = flags_parsers.EnumFlag(lstm.NodeEncoder, node_encoder)

    for model in models:
      FLAGS.tag = f"{dataset}_{model}_{tag_suffix}"

      if model == "zero_r":
        FLAGS.epoch_count = 1
        FLAGS.graph_reader_order = "in_order"
        run.Run(zero_r.ZeroR)
        continue

      if model in ("lstm_ir", "lstm_inst2vec"):
        # Both LSTM variants share their configuration except for the
        # IR-to-sequence encoder.
        if model == "lstm_ir":
          ir2seq_type = lstm.Ir2SeqType.LLVM
        else:
          ir2seq_type = lstm.Ir2SeqType.INST2VEC
        FLAGS.epoch_count = 50
        FLAGS.ir2seq = flags_parsers.EnumFlag(lstm.Ir2SeqType, ir2seq_type)
        FLAGS.graph_reader_order = "batch_random"
        FLAGS.padded_sequence_length = 15000
        FLAGS.batch_size = 64
        run.Run(lstm.GraphLstm)
        continue

      if model == "ggnn":
        FLAGS.layer_timesteps = ["30"]
        FLAGS.graph_batch_node_count = 15000
        FLAGS.graph_reader_order = "global_random"
        FLAGS.epoch_count = 300
        run.Run(ggnn.Ggnn)
        continue

      raise app.UsageError(f"Unknown model: {model}")