Example 1
    def MakeBatch(
        self,
        epoch_type: epoch.Type,
        graphs: Iterable[graph_tuple_database.GraphTuple],
        ctx: progress.ProgressContext = progress.NullContext,
    ) -> batches.Data:
        """Create a mini-batch of data from an iterator of graphs.

    Returns:
      A single batch of data for feeding into RunBatch(). A batch consists of a
      list of graph IDs and a model-defined blob of data. If the list of graph
      IDs is empty, the batch is discarded and not fed into RunBatch().
    """

        # TODO(github.com/ChrisCummins/ProGraML/issues/24): The new graph batcher
        # implementation is not well suited for reading the graph IDs, hence this
        # somewhat clumsy iterator wrapper. A neater approach would be to create
        # a graph batcher which returns a list of graphs in the batch.
        class GraphIterator(object):
            """A wrapper around a graph iterator which records graph IDs."""
            def __init__(self,
                         graphs: Iterable[graph_tuple_database.GraphTuple]):
                self.input_graphs = graphs
                self.graphs_read: List[graph_tuple_database.GraphTuple] = []

            def __iter__(self):
                return self

            def __next__(self):
                graph: graph_tuple_database.GraphTuple = next(
                    self.input_graphs)
                self.graphs_read.append(graph)
                return graph.tuple

        graph_iterator = GraphIterator(graphs)

        # Create a disjoint graph out of one or more input graphs.
        batcher = graph_batcher.GraphBatcher.CreateFromFlags(graph_iterator,
                                                             ctx=ctx)

        try:
            disjoint_graph = next(batcher)
        except StopIteration:
            # We have run out of graphs, return an empty batch.
            return batches.Data(graph_ids=[], data=None)

        # Workaround for the fact that graph batcher may read one more graph than
        # actually gets included in the batch.
        if batcher.last_graph:
            graphs = graph_iterator.graphs_read[:-1]
        else:
            graphs = graph_iterator.graphs_read

        return batches.Data(
            graph_ids=[graph.id for graph in graphs],
            data=GgnnBatchData(disjoint_graph=disjoint_graph, graphs=graphs),
        )
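
The GraphIterator class above is an instance of a general pattern: wrap an iterator so the caller can later see exactly which items a consumer pulled from it. A minimal standalone sketch of the same pattern, with illustrative names not taken from the source:

from typing import Iterator, List


class RecordingIterator:
    """Wrap an iterator and record every item it yields."""

    def __init__(self, items: Iterator[int]):
        self.items = items
        self.read: List[int] = []

    def __iter__(self):
        return self

    def __next__(self) -> int:
        item = next(self.items)  # Propagates StopIteration when exhausted.
        self.read.append(item)
        return item


it = RecordingIterator(iter([1, 2, 3]))
first_two = [next(it), next(it)]
assert it.read == [1, 2]  # Only the consumed items were recorded.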
Example 2
    def MakeBatch(
        self,
        epoch_type: epoch.Type,
        graphs: Iterable[graph_tuple_database.GraphTuple],
        ctx: progress.ProgressContext = progress.NullContext,
    ) -> batches.Data:
        del epoch_type  # Unused.
        del ctx  # Unused.

        batch_size = 0
        graph_ids = []
        targets = []

        # Limit the batch to FLAGS.zero_r_batch_size elements.
        while batch_size < FLAGS.zero_r_batch_size:
            # Read the next graph.
            try:
                graph = next(graphs)
            except StopIteration:
                # We have run out of graphs.
                if len(graph_ids) == 0:
                    return batches.EndOfBatches()
                break

            # Add the graph data to the batch.
            graph_ids.append(graph.id)
            if self.graph_db.node_y_dimensionality:
                batch_size += graph.tuple.node_y.size
                targets.append(graph.tuple.node_y)
            else:
                batch_size += graph.tuple.graph_y.size
                targets.append(graph.tuple.graph_y)

        return batches.Data(graph_ids=graph_ids,
                           data=np.vstack(targets) if targets else None)
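
For reference, np.vstack stacks the collected per-graph target vectors row-wise into a single 2-D array; a small sketch with made-up label vectors:

import numpy as np

# Two hypothetical graphs, each with a 2-class graph_y label vector.
targets = [np.array([0, 1]), np.array([1, 0])]
stacked = np.vstack(targets)
assert stacked.shape == (2, 2)  # One row per graph, one column per class.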
Example 3
def test_RollingResults_iteration_count(weight: float):
    """Test aggreation of model iteration count and convergence."""
    rolling_results = batches.RollingResults()

    data = batches.Data(graph_ids=[1], data=None)
    results = batches.Results.Create(
        np.random.rand(1, 10),
        np.random.rand(1, 10),
        iteration_count=1,
        model_converged=True,
    )

    for _ in range(10):
        rolling_results.Update(data, results, weight=weight)

    assert rolling_results.iteration_count == 1
    assert rolling_results.model_converged == 1
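
The assertions hold for any weight because a weighted mean of a constant is that constant. The sketch below illustrates the arithmetic; it is an assumption about how RollingResults aggregates, not its actual implementation:

# Weighted running mean of a constant: sum(w * 1) / sum(w) == 1 for any w > 0.
weight = 0.25  # Arbitrary positive weight.
weighted_sum = total_weight = 0.0
for _ in range(10):
    weighted_sum += weight * 1  # iteration_count == 1 on every update.
    total_weight += weight
assert weighted_sum / total_weight == 1.0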
Example 4
    def _CreateBatchDataAndResults(
            self) -> Tuple[batches.Data, batches.Results]:
        """Create a random batch data and results instance."""
        graph_ids = [
            random.choice(self.graph_ids)
            for _ in range(random.randint(1, 200))
        ]

        if self.node_y_dimensionality:
            # Generate per-node predictions/targets.
            if self.graph_db:
                with self.graph_db.Session() as session:
                    id_to_node_count = {
                        row.id: row.node_count
                        for row in session.query(
                            graph_tuple_database.GraphTuple.id,
                            graph_tuple_database.GraphTuple.node_count,
                        ).filter(
                            graph_tuple_database.GraphTuple.id.in_(graph_ids))
                    }
                node_counts = [id_to_node_count[id] for id in graph_ids]
            else:
                node_counts = [
                    random.randint(5, 50) for _ in range(len(graph_ids))
                ]
            target_count = sum(node_counts)
            y_dimensionality = self.node_y_dimensionality
        else:
            # Generate graph predictions/targets.
            target_count = len(graph_ids)
            y_dimensionality = self.graph_y_dimensionality

        data = batches.Data(
            graph_ids=graph_ids,
            data=list(range(random.randint(10000, 100000))),
        )
        results = batches.Results.Create(
            targets=np.random.rand(target_count, y_dimensionality),
            predictions=np.random.rand(target_count, y_dimensionality),
            iteration_count=random.randint(1, 3),
            model_converged=random.choice([False, True]),
            learning_rate=random.random(),
            loss=random.random(),
        )
        return data, results
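
The shapes produced by the branch above differ only in the row count: node-level models get one target row per node across the batch, graph-level models one row per graph. A toy illustration (values are made up):

import numpy as np

node_counts = [5, 7]  # Two graphs in the batch.
node_y_dimensionality = 2

# Node-level: one row per node, summed across the batch.
node_targets = np.random.rand(sum(node_counts), node_y_dimensionality)
assert node_targets.shape == (12, 2)

# Graph-level: one row per graph.
graph_targets = np.random.rand(len(node_counts), node_y_dimensionality)
assert graph_targets.shape == (2, 2)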
Example 5
    def MakeBatch(
        self,
        epoch_type: epoch.Type,
        graphs: Iterable[graph_tuple_database.GraphTuple],
        ctx: progress.ProgressContext = progress.NullContext,
    ) -> batches.Data:
        """Generate a fake batch of data."""
        del epoch_type  # Unused.
        del ctx  # Unused.

        graph_ids = []
        while len(graph_ids) < 100:
            try:
                graph_ids.append(next(graphs).id)
            except StopIteration:
                break
        self.make_batch_count += 1
        return batches.Data(graph_ids=graph_ids, data=123)
Example 6
    def MakeBatch(
        self,
        epoch_type: epoch.Type,
        graph_iterator: Iterable[graph_tuple_database.GraphTuple],
        ctx: progress.ProgressContext = progress.NullContext,
    ) -> batches.Data:
        """Create a mini-batch of LSTM data."""
        del epoch_type  # Unused.

        graphs = self.GetBatchOfGraphs(graph_iterator)
        if not graphs:
            return batches.EndOfBatches()

        # Encode the graphs in the batch.
        encoded_sequences: List[np.array] = self.encoder.Encode(graphs,
                                                                ctx=ctx)
        graph_x: List[np.array] = []
        graph_y: List[np.array] = []
        for graph in graphs:
            graph_x.append(graph.tuple.graph_x)
            graph_y.append(graph.tuple.graph_y)

        # Pad and truncate encoded sequences.
        encoded_sequences = tf.keras.preprocessing.sequence.pad_sequences(
            encoded_sequences,
            maxlen=self.padded_sequence_length,
            dtype="int32",
            padding="pre",
            truncating="post",
            value=self.padding_element,
        )

        return batches.Data(
            graph_ids=[graph.id for graph in graphs],
            data=GraphLstmBatch(
                encoded_sequences=np.vstack(encoded_sequences),
                graph_x=np.vstack(graph_x),
                graph_y=np.vstack(graph_y),
            ),
        )
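
For reference, pad_sequences with these arguments left-pads short sequences with the padding element and drops the tail of over-long ones; a small self-contained sketch:

import tensorflow as tf

padded = tf.keras.preprocessing.sequence.pad_sequences(
    [[1, 2], [3, 4, 5, 6]],
    maxlen=3,
    dtype="int32",
    padding="pre",      # Prepend the padding value to short sequences.
    truncating="post",  # Drop elements from the end of long sequences.
    value=0,
)
assert padded.tolist() == [[0, 1, 2], [3, 4, 5]]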
Example 7
  def MakeBatch(
    self,
    epoch_type: epoch.Type,
    graph_iterator: Iterable[graph_tuple_database.GraphTuple],
    ctx: progress.ProgressContext = progress.NullContext,
  ) -> batches.Data:
    """Create a mini-batch of LSTM data."""
    del epoch_type  # Unused.

    graphs = self.GetBatchOfGraphs(graph_iterator)
    if not graphs:
      return batches.EndOfBatches()

    # Encode the graphs in the batch.
    encoded_sequences: List[np.array] = self.encoder.Encode(graphs, ctx=ctx)
    graph_x: List[np.array] = []
    graph_y: List[np.array] = []
    for graph in graphs:
      graph_x.append(graph.tuple.graph_x)
      graph_y.append(graph.tuple.graph_y)

    # Pad and truncate encoded sequences.
    encoded_sequences = tf.keras.preprocessing.sequence.pad_sequences(
      encoded_sequences,
      maxlen=self.padded_sequence_length,
      dtype="int32",
Example 8
  def MakeBatch(
    self,
    epoch_type: epoch.Type,
    graphs: Iterable[graph_tuple_database.GraphTuple],
    ctx: progress.ProgressContext = progress.NullContext,
  ) -> batches.Data:
    """Create a mini-batch of LSTM data."""
    del epoch_type  # Unused.

    graphs = self.GetBatchOfGraphs(graphs)
    # For node classification we require all batches to be of batch_size.
    # In the future we could work around this by padding an incomplete
    # batch with arrays of zeros.
    if not graphs or len(graphs) != self.batch_size:
      return batches.EndOfBatches()

    # Encode the graphs in the batch.

    # A list of arrays of shape (node_count, 1)
    encoded_sequences: List[np.array] = []
    # A list of arrays of shape (node_count, 1)
    segment_ids: List[np.array] = []
    # A list of arrays of shape (node_mask_count, 2)
    selector_vectors: List[np.array] = []
    # A list of arrays of shape (node_mask_count, node_y_dimensionality)
    node_y: List[np.array] = []
    all_node_indices: List[np.array] = []
    targets: List[np.array] = []

    try:
      encoded_graphs = self.encoder.Encode(graphs, ctx=ctx)
    except ValueError as e:
      ctx.Error("%s", e)
      # TODO(github.com/ChrisCummins/ProGraML/issues/38): to debug a possible
      # error in the LSTM I have temporarily made the batch construction
      # resilient to encoder errors by returning an empty batch.
      # Graph encoding failed, so return an empty batch. This is probably
      # not a good idea to keep, as it means the LSTM will silently skip
      # data.
      return batches.Data(graph_ids=[], data=None)

    node_offset = 0
    # Convert ProgramGraphSeq protos to arrays of numeric values.
    for graph, seq in zip(graphs, encoded_graphs):
      # Skip empty graphs.
      if not seq.encoded:
        continue

      encoded_sequences.append(np.array(seq.encoded, dtype=np.int32))

      # Construct a list of segment IDs using the encoded node lengths,
      # e.g. for encoded node lengths [2, 3, 1], produce segment IDs:
      # [0, 0, 1, 1, 1, 2].
      out_of_range_segment = self.padded_node_sequence_length - 1
      segment_ids.append(
        np.concatenate(
          [
            np.ones(encoded_node_length, dtype=np.int32)
            * min(segment_id, out_of_range_segment)
            for segment_id, encoded_node_length in enumerate(
              seq.encoded_node_length
            )
          ]
        )
      )

      # Get the list of graph node indices that produced the serialized encoded
      # graph representation. We use this to construct predictions for the
      # "full" graph through padding.
      node_indices = np.array(seq.node, dtype=np.int32)

      # Offset the node index list and concatenate.
      all_node_indices.append(node_indices + node_offset)

      # Sanity check that the node indices are in-range for this graph
      # (indices are zero-based, so the maximum must be strictly less than
      # the node count).
      assert len(graph.tuple.node_x) > max(node_indices)

      # Use only the 'binary selector' feature and convert to an array of
      # 1 hot binary vectors.
      node_selectors = graph.tuple.node_x[:, 1][node_indices]
      node_selector_vectors = np.zeros((node_selectors.size, 2), dtype=np.int32)
      node_selector_vectors[np.arange(node_selectors.size), node_selectors] = 1
      selector_vectors.append(node_selector_vectors)

      # Select the node targets for only the active nodes.
      node_y.append(graph.tuple.node_y[node_indices])
      targets.append(graph.tuple.node_y)

      # Increment our node offset for concatenating the lists of node indices.
      node_offset += graph.node_count

    # Pad and truncate encoded sequences.
    encoded_sequences = tf.keras.preprocessing.sequence.pad_sequences(
      encoded_sequences,
      maxlen=self.padded_sequence_length,
      dtype="int32",
      padding="pre",
      truncating="post",
      value=self.padding_element,
    )

    # Determine an out-of-range segment ID to pad the segment IDs to.
    segment_id_padding_element = (
      max(max(s) if s.size else 0 for s in segment_ids) + 1
    )

    segment_ids = tf.keras.preprocessing.sequence.pad_sequences(
      segment_ids,
      maxlen=self.padded_sequence_length,
      dtype="int32",
      padding="pre",
      truncating="post",
      value=segment_id_padding_element,
    )

    padded_node_sequence_length = min(
      self.padded_node_sequence_length, max(len(s) for s in selector_vectors)
    )

    # Pad the selector vectors to the same shape as the segment IDs.
    selector_vectors = tf.keras.preprocessing.sequence.pad_sequences(
      selector_vectors,
      maxlen=padded_node_sequence_length,
      dtype="int32",
      padding="pre",
      truncating="post",
      value=np.array((0, 0), dtype=np.int32),
    )

    node_y = tf.keras.preprocessing.sequence.pad_sequences(
      node_y,
      maxlen=padded_node_sequence_length,
      dtype="int32",
      padding="pre",
      truncating="post",
      value=np.zeros(self.graph_db.node_y_dimensionality, dtype=np.int64),
    )

    all_node_indices = tf.keras.preprocessing.sequence.pad_sequences(
      all_node_indices,
      maxlen=padded_node_sequence_length,
      dtype="int32",
      padding="pre",
      truncating="post",
      value=-1,
    )

    return batches.Data(
      graph_ids=[graph.id for graph in graphs],
      data=NodeLstmBatch(
        encoded_sequences=encoded_sequences,
        segment_ids=segment_ids,
        selector_vectors=selector_vectors,
        node_y=node_y,
        node_indices=np.concatenate(all_node_indices),
        targets=np.vstack(targets),
      ),
    )
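
Two of the array manipulations above are easy to misread. The segment-ID construction expands per-node token counts into one segment ID per token, and the selector encoding scatters ones into a zero matrix to build one-hot rows. A standalone sketch with toy values (none taken from the source):

import numpy as np

# Segment IDs: encoded node lengths [2, 3, 1] expand to [0, 0, 1, 1, 1, 2],
# i.e. token i belongs to the node whose segment ID appears at position i.
lengths = [2, 3, 1]
segment_ids = np.concatenate(
    [np.ones(n, dtype=np.int32) * i for i, n in enumerate(lengths)]
)
assert segment_ids.tolist() == [0, 0, 1, 1, 1, 2]

# One-hot selector vectors: scatter a 1 into column `selector` of each row.
selectors = np.array([0, 1, 1])
one_hot = np.zeros((selectors.size, 2), dtype=np.int32)
one_hot[np.arange(selectors.size), selectors] = 1
assert one_hot.tolist() == [[1, 0], [0, 1], [0, 1]]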