Example 1
  def Split(self, db: graph_tuple_database.Database) -> List[np.ndarray]:
    """Apply K-fold split on a graph database stratified over graph_y labels."""
    assert db.graph_y_dimensionality

    with prof.Profile(f"Loaded labels from {db.graph_count} graphs"):
      # Load all of the graph IDs and their labels.
      reader = graph_database_reader.BufferedGraphReader(db)

      graph_ids: List[int] = []
      graph_y: List[int] = []
      for graph in reader:
        graph_ids.append(graph.id)
        graph_y.append(np.argmax(graph.tuple.graph_y))
      graph_ids = np.array(graph_ids, dtype=np.int32)
      graph_y = np.array(graph_y, dtype=np.int64)

    splitter = model_selection.StratifiedKFold(
      n_splits=self.k, shuffle=True, random_state=FLAGS.seed
    )
    dataset_splits = splitter.split(graph_ids, graph_y)

    return [
      np.array(graph_ids[test], dtype=np.int32)
      for _, test in dataset_splits
    ]
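For reference, the stratified K-fold pattern used by Split() above can be reproduced with scikit-learn alone. A minimal self-contained sketch, using synthetic IDs and labels in place of the graph database (the names ids and labels are illustrative, not part of the repo):

import numpy as np
from sklearn import model_selection

# Synthetic stand-ins for the graph IDs and argmax'd graph_y labels loaded above.
ids = np.arange(100, dtype=np.int32)
labels = np.array([i % 3 for i in range(100)], dtype=np.int64)

splitter = model_selection.StratifiedKFold(n_splits=10, shuffle=True, random_state=0)

# Each fold's test indices select a label-balanced subset of the IDs.
folds = [ids[test] for _, test in splitter.split(ids, labels)]
assert sum(len(fold) for fold in folds) == len(ids)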
Example 2
def test_empty_batches_is_not_error(
    graph_db: graph_tuple_database.Database,
    logger: logging.Logger,
):
    """Test that empty batches are ignored.

  Regression test for <github.com/ChrisCummins/ProGraML/issues/43>. Empty batch
  generation was determined to be the cause of flaky model crashes.
  """
    class FlakyBatchModel(MockModel):
        """A mock model which returns ~50% empty batches."""
        def __init__(self, *args, **kwargs):
            super(FlakyBatchModel, self).__init__(*args, **kwargs)
            self._batch_count = 0

        def MakeBatch(
            self,
            epoch_type: epoch.Type,
            graphs: Iterable[graph_tuple_database.GraphTuple],
            ctx: progress.ProgressContext = progress.NullContext,
        ) -> batches.Data:
            self._batch_count += 1

            if self._batch_count == 1:
                # Always return an empty first batch.
                return batches.EmptyBatch()
            elif self._batch_count == 2:
                # Always return a real second batch (otherwise the epoch may end up with
                # nothing but empty batches).
                return super(FlakyBatchModel,
                             self).MakeBatch(epoch_type, graphs, ctx)

            # Return subsequent batches with 50% success rate.
            if random.random() < 0.5:
                return super(FlakyBatchModel,
                             self).MakeBatch(epoch_type, graphs, ctx)
            else:
                return batches.EmptyBatch()

    run_id = run_id_lib.RunId.GenerateUnique(
        f"mock{random.randint(0, int(1e6)):06}")

    model = FlakyBatchModel(
        logger=logger,
        graph_db=graph_db,
        run_id=run_id,
    )

    batch_iterator = batches.BatchIterator(
        batches=model.BatchIterator(
            epoch.Type.TRAIN,
            graph_database_reader.BufferedGraphReader(graph_db)),
        graph_count=graph_db.graph_count,
    )

    model.Initialize()
    results = model(epoch_type=epoch.Type.TRAIN,
                    batch_iterator=batch_iterator,
                    logger=logger)
    assert results.batch_count >= 1
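The behavior this test exercises, stripped of the model machinery, is a consumer loop that skips empty batches instead of crashing. A minimal sketch under that assumption, with a hypothetical Batch stand-in rather than the repo's batches.Data class:

from typing import Iterable, NamedTuple

class Batch(NamedTuple):
  """Hypothetical stand-in for a batch; graph_count == 0 marks an empty batch."""
  graph_count: int

def run_epoch(batch_stream: Iterable[Batch]) -> int:
  """Count processed batches, silently skipping empty ones (the behavior under test)."""
  processed = 0
  for batch in batch_stream:
    if batch.graph_count == 0:
      continue  # An empty batch is ignored rather than aborting the epoch.
    processed += 1
  return processed

stream = [Batch(0 if i % 2 else 32) for i in range(10)]
assert run_epoch(stream) == 5  # Only the five non-empty batches are processed.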
Example 3
def test_BufferedGraphReader_values_in_order(
  db_10000: graph_tuple_database.Database, buffer_size_mb: int
):
  """Test that the expected number of graphs are returned"""
  graphs = list(
    reader.BufferedGraphReader(db_10000, buffer_size_mb=buffer_size_mb)
  )
  assert len(graphs) == 10000
  # When read in order, the ir_ids should be equal to their position.
  assert all([g.ir_id == i for i, g in enumerate(graphs)])
Example 4
def batch_iterator(
    model: MockModel,
    graph_db: graph_tuple_database.Database,
) -> batches.BatchIterator:
    return batches.BatchIterator(
        batches=model.BatchIterator(
            epoch.Type.TRAIN,
            graph_database_reader.BufferedGraphReader(graph_db)),
        graph_count=graph_db.graph_count,
    )
Example 5
def test_benchmark_BufferedGraphReader_global_random(
  benchmark, db_10000: graph_tuple_database.Database, buffer_size_mb: int,
):
  """Benchmark global random database reads."""
  benchmark(
    list,
    reader.BufferedGraphReader(
      db_10000,
      buffer_size_mb=buffer_size_mb,
      order=reader.BufferedGraphReaderOrder.GLOBAL_RANDOM,
    ),
  )
Example 6
def test_benchmark_BufferedGraphReader_in_order(
  benchmark, db_10000: graph_tuple_database.Database, buffer_size_mb: int,
):
  """Benchmark in-order database reads."""
  benchmark(
    list,
    reader.BufferedGraphReader(
      db_10000,
      buffer_size_mb=buffer_size_mb,
      order=reader.BufferedGraphReaderOrder.IN_ORDER,
    ),
  )
Example 7
def test_benchmark_BufferedGraphReader_data_flow_steps(
  benchmark, db_10000: graph_tuple_database.Database, buffer_size_mb: int,
):
  """Benchmark ordered database reads."""
  benchmark(
    list,
    reader.BufferedGraphReader(
      db_10000,
      buffer_size_mb=buffer_size_mb,
      order=reader.BufferedGraphReaderOrder.DATA_FLOW_STEPS,
    ),
  )
Example 8
def test_BufferedGraphReader_next(
  db_10000: graph_tuple_database.Database,
  buffer_size_mb: int,
  order: reader.BufferedGraphReaderOrder,
):
  """Test using next() to read from BufferedGraphReader()."""
  db_reader = reader.BufferedGraphReader(
    db_10000, buffer_size_mb=buffer_size_mb, order=order
  )
  for _ in range(10000):
    next(db_reader)
  with test.Raises(StopIteration):
    next(db_reader)
Example 9
def test_BufferedGraphReader_limit(
  db_10000: graph_tuple_database.Database,
  buffer_size_mb: int,
  order: reader.BufferedGraphReaderOrder,
  limit: int,
):
  """Test using `limit` arg to limit number of returned rows."""
  graphs = list(
    reader.BufferedGraphReader(
      db_10000, limit=limit, buffer_size_mb=buffer_size_mb, order=order
    )
  )
  assert len(graphs) == min(limit, 10000)
Example 10
def test_BufferedGraphReader_random_orders(
  db_10000: graph_tuple_database.Database,
  buffer_size_mb: int,
  order: reader.BufferedGraphReaderOrder,
):
  """Test that random order return rows in randomized order."""
  graphs = list(
    reader.BufferedGraphReader(
      db_10000, buffer_size_mb=buffer_size_mb, order=order
    )
  )
  ir_ids = [g.ir_id for g in graphs]
  assert ir_ids != sorted(ir_ids)
Example 11
def test_BufferedGraphReader_filters(
  db_10000: graph_tuple_database.Database,
  buffer_size_mb: int,
  order: reader.BufferedGraphReaderOrder,
):
  """Test when using filters to limit results."""
  filters = [
    lambda: graph_tuple_database.GraphTuple.ir_id < 3000,
    lambda: graph_tuple_database.GraphTuple.data_flow_steps < 2000,
  ]
  graphs = list(
    reader.BufferedGraphReader(
      db_10000, filters=filters, buffer_size_mb=buffer_size_mb, order=order
    )
  )
  assert len(graphs) == 2000
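For context, each filter above is a zero-argument callable that yields a SQLAlchemy column expression only when the query is built. A minimal self-contained sketch of that deferred-filter pattern, using an illustrative table rather than the repo's actual GraphTuple schema (SQLAlchemy 1.x style API):

import sqlalchemy as sql
from sqlalchemy import orm
from sqlalchemy.ext import declarative

Base = declarative.declarative_base()

class Graph(Base):
  """An illustrative table standing in for graph_tuple_database.GraphTuple."""
  __tablename__ = "graphs"
  id = sql.Column(sql.Integer, primary_key=True)
  ir_id = sql.Column(sql.Integer)
  data_flow_steps = sql.Column(sql.Integer)

engine = sql.create_engine("sqlite://")
Base.metadata.create_all(engine)
session = orm.sessionmaker(bind=engine)()
session.add_all(Graph(ir_id=i, data_flow_steps=i) for i in range(10000))
session.commit()

# Filters are deferred: each callable is evaluated only when the query is composed.
filters = [
  lambda: Graph.ir_id < 3000,
  lambda: Graph.data_flow_steps < 2000,
]
query = session.query(Graph)
for f in filters:
  query = query.filter(f())
assert query.count() == 2000  # Both predicates are ANDed together.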
Example 12
  def GetGraphsForBatch(
    self, batch: log_database.Batch
  ) -> Iterable[graph_tuple_database.GraphTuple]:
    """Reconstruct the graphs for a batch.

    Returns:
      An iterable sequence of the unique graphs from a batch. Note that the
      order may not be the same as the order in which they appeared in the
      batch, and that duplicate graphs in the batch will only be returned once.
    """
    if not batch.details:
      raise OSError("Cannot re-create batch without detailed logs")

    filters = [lambda: graph_tuple_database.GraphTuple.id.in_(batch.graph_ids)]
    return graph_database_reader.BufferedGraphReader(
      self.graph_db, filters=filters
    )
Example 13
def test_BufferedGraphReader_data_flow_steps_order(
  db_10000: graph_tuple_database.Database, buffer_size_mb: int
):
  """Test that data flow max steps increases monotonically."""
  db_reader = reader.BufferedGraphReader(
    db_10000,
    buffer_size_mb=buffer_size_mb,
    order=reader.BufferedGraphReaderOrder.DATA_FLOW_STEPS,
  )
  current_steps = -1
  i = 0
  for i, graph in enumerate(db_reader):
    # Sanity check that test fixture set expected values for data flow steps.
    assert graph.data_flow_steps == graph.ir_id
    # Assert that data flow steps is monotonically increasing.
    assert graph.data_flow_steps >= current_steps
    current_steps = graph.data_flow_steps
  # Sanity check that the correct number of graphs were returned.
  assert i + 1 == 10000
Example 14
    def FinalizeKerasModel(self) -> None:
        """Finalize a newly instantiated keras model.

    To enable thread-safe use of the Keras model we must ensure that the
    computation graph is fully instantiated from the master thread before the
    first call to RunBatch(). Keras lazily instantiates parts of the graph
    which we can force by performing the necessary ops now:
      * training ops: make sure those are created by running the training loop
        on a small batch of data.
      * save/restore ops: make sure those are created by running save_model()
        and throwing away the generated file.

    Once we have performed those actions, we can freeze the computation graph
    to make explicit the fact that later operations are not permitted to modify
    the graph.
    """
        with self.graph.as_default():
            tf.compat.v1.keras.backend.set_session(self.session)
            # To enable thread-safe use of the Keras model we must ensure that
            # the computation graph is fully instantiated before the first call
            # to RunBatch(). Keras lazily instantiates parts of the graph (such as
            # training ops), so make sure those are created by running the training
            # loop now on a single graph.
            reader = graph_database_reader.BufferedGraphReader(
                self.graph_db, limit=self.warm_up_batch_size)
            batch = self.MakeBatch(epoch.Type.TRAIN, reader)
            assert batch.graph_count == self.warm_up_batch_size
            self.RunBatch(epoch.Type.TRAIN, batch)

            # Run private model methods that instantiate graph components.
            # See: https://stackoverflow.com/a/46801607
            self.model._make_predict_function()
            self.model._make_test_function()
            self.model._make_train_function()

            # Saving the graph also creates new ops, so run it now.
            with tempfile.TemporaryDirectory(prefix="ml4pl_lstm_") as d:
                self.model.save(pathlib.Path(d) / "delete_md.h5")

        # Finally, we have instantiated the graph, so freeze it to make any
        # implicit modification raise an error.
        self.graph.finalize()
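The closing graph.finalize() call is what turns any later, accidental graph modification into a hard error. A minimal sketch of that behavior in isolation, assuming TF1-style graph mode (tf.compat.v1) as in the snippet above:

import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # TF1-style graph mode, as assumed above.

graph = tf.Graph()
with graph.as_default():
  x = tf.compat.v1.placeholder(tf.float32, shape=[None])
  y = x * 2.0  # Ops may be added freely before finalize().

graph.finalize()  # Freeze the graph: adding ops from now on raises.

try:
  with graph.as_default():
    tf.compat.v1.placeholder(tf.float32)  # Attempting to add a new op...
except RuntimeError as err:
  print("graph is frozen:", err)  # ...fails with "Graph is finalized and cannot be modified."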
Example 15
    self.session = utils.SetAllowedGrowthOnKerasSession()
    self.graph = tf.compat.v1.get_default_graph()

    # To enable thread-safe use of a Keras model we must make sure to fix the
    # graph and session whenever we are going to use self.model.
    with self.graph.as_default():
      tf.compat.v1.keras.backend.set_session(self.session)
      self.model = self.CreateKerasModel()

      # To enable thread-safe use of the Keras model we must ensure that
      # the computation graph is fully instantiated before the first call
      # to RunBatch(). Keras lazily instantiates parts of the graph (such as
      # training ops), so make sure those are created by running the training
      # loop now on a single graph.
      reader = graph_database_reader.BufferedGraphReader(
        self.graph_db, limit=self.warm_up_batch_size
      )
      batch = self.MakeBatch(epoch.Type.TRAIN, reader)
      assert batch.graph_count == self.warm_up_batch_size
      self.RunBatch(epoch.Type.TRAIN, batch)

      # Run private model methods that instantiate graph components.
      # See: https://stackoverflow.com/a/46801607
      self.model._make_predict_function()
      self.model._make_test_function()
      self.model._make_train_function()

      # Saving the graph also creates new ops, so run it now.
      with tempfile.TemporaryDirectory(prefix="ml4pl_lstm_") as d:
        self.model.save(pathlib.Path(d) / "delete_md.h5")
Example 16
    def __init__(
        self,
        *args,
        padded_sequence_length: Optional[int] = None,
        graph2seq_encoder: Optional[graph2seq.EncoderBase] = None,
        batch_size: Optional[int] = None,
        **kwargs,
    ):
        super(LstmBase, self).__init__(*args, **kwargs)

        self.batch_size = batch_size or FLAGS.batch_size

        # Determine the size of padded sequences. Use the requested
        # padded_sequence_length, or the maximum encoded length if it is shorter.
        self.padded_sequence_length = (padded_sequence_length
                                       or FLAGS.padded_sequence_length)

        self.encoder = graph2seq_encoder or self.GetEncoder()

        # After instantiating the encoder, see if we can reduce the padded sequence
        # length.
        self.padded_sequence_length = min(self.padded_sequence_length,
                                          self.encoder.max_encoded_length)

        # Reset any previous Tensorflow session. This is required when running
        # consecutive LSTM models in the same process.
        tf.keras.backend.clear_session()

        # Create the Tensorflow session and graph for the model.
        self.session = utils.SetAllowedGrowthOnKerasSession()
        self.graph = tf.compat.v1.get_default_graph()

        # To enable thread-safe use of a Keras model we must make sure to fix the
        # graph and session whenever we are going to use self.model.
        with self.graph.as_default():
            tf.compat.v1.keras.backend.set_session(self.session)
            self.model = self.CreateKerasModel()

            # To enable thread-safe use of the Keras model we must ensure that
            # the computation graph is fully instantiated before the first call
            # to RunBatch(). Keras lazily instantiates parts of the graph (such as
            # training ops), so make sure those are created by running the training
            # loop now on a single graph.
            reader = graph_database_reader.BufferedGraphReader(
                self.graph_db, limit=self.warm_up_batch_size)
            batch = self.MakeBatch(epoch.Type.TRAIN, reader)
            assert batch.graph_count == self.warm_up_batch_size
            self.RunBatch(epoch.Type.TRAIN, batch)

            # Run private model methods that instantiate graph components.
            # See: https://stackoverflow.com/a/46801607
            self.model._make_predict_function()
            self.model._make_test_function()
            self.model._make_train_function()

            # Saving the graph also creates new ops, so run it now.
            with tempfile.TemporaryDirectory(prefix="ml4pl_lstm_") as d:
                self.model.save(pathlib.Path(d) / "delete_md.h5")

        # Finally, we have instantiated the graph, so freeze it to make any
        # implicit modification raise an error.
        self.graph.finalize()