Example #1
    def MakeBatch(
        self,
        epoch_type: epoch.Type,
        graphs: Iterable[graph_tuple_database.GraphTuple],
        ctx: progress.ProgressContext = progress.NullContext,
    ) -> batches.Data:
        """Create a batch of graph or node labels for the ZeroR baseline."""
        del epoch_type  # Unused.
        del ctx  # Unused.

        batch_size = 0
        graph_ids = []
        targets = []

        # Limit the batch to at most --zero_r_batch_size target elements.
        while batch_size < FLAGS.zero_r_batch_size:
            # Read the next graph.
            try:
                graph = next(graphs)
            except StopIteration:
                # We have run out of graphs.
                if len(graph_ids) == 0:
                    return batches.EndOfBatches()
                break

            # Add the graph data to the batch.
            graph_ids.append(graph.id)
            if self.graph_db.node_y_dimensionality:
                batch_size += graph.tuple.node_y.size
                targets.append(graph.tuple.node_y)
            else:
                batch_size += graph.tuple.graph_y.size
                targets.append(graph.tuple.graph_y)

        return batches.Data(graph_ids=graph_ids,
                            data=np.vstack(targets) if targets else None)
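For orientation, here is a hedged sketch of the loop that might drive a MakeBatch() implementation like the one above, following the contract spelled out in Example #4's docstring (a batch with an empty graph ID list is discarded; EndOfBatches ends the epoch). The DrainBatches name and the model parameter are illustrative assumptions, not part of the source:

def DrainBatches(model, epoch_type, graphs):
  """Yield usable batches from `model` until its graph iterator is exhausted."""
  while True:
    batch = model.MakeBatch(epoch_type, graphs)
    if isinstance(batch, batches.EndOfBatches):
      return  # No graphs left: the epoch is over.
    if batch.graph_ids:  # Empty batches are discarded, per the contract.
      yield batch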
Example #2
  def MakeBatch(
    self,
    epoch_type: epoch.Type,
    graphs: Iterable[graph_tuple_database.GraphTuple],
    ctx: progress.ProgressContext = progress.NullContext,
  ) -> batches.Data:
    """Generate a fake batch of data."""
    del epoch_type  # Unused.
    del ctx  # Unused.

    graph_ids = []
    while len(graph_ids) < 100:
      try:
        graph_ids.append(next(graphs).id)
      except StopIteration:
        if not graph_ids:
          return batches.EndOfBatches()
        break
    self.make_batch_count += 1
    return batches.Data(graph_ids=graph_ids, data=123)
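This variant is a test stub: it returns a fixed dummy payload and counts invocations in make_batch_count. A sketch of how it might be exercised in a test; MockModel and graph_list are assumed names for illustration only:

model = MockModel()                  # any object exposing the MakeBatch() above
batch = model.MakeBatch(epoch.Type.TRAIN, iter(graph_list))
assert len(batch.graph_ids) <= 100   # the stub caps batches at 100 graphs
assert batch.data == 123             # the payload is a fixed dummy value
assert model.make_batch_count == 1   # invocations are counted for assertions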
Example #3
    def MakeBatch(
        self,
        epoch_type: epoch.Type,
        graph_iterator: Iterable[graph_tuple_database.GraphTuple],
        ctx: progress.ProgressContext = progress.NullContext,
    ) -> batches.Data:
        """Create a mini-batch of LSTM data."""
        del epoch_type  # Unused.

        graphs = self.GetBatchOfGraphs(graph_iterator)
        if not graphs:
            return batches.EndOfBatches()

        # Encode the graphs in the batch.
        encoded_sequences: List[np.array] = self.encoder.Encode(graphs,
                                                                ctx=ctx)
        graph_x: List[np.array] = []
        graph_y: List[np.array] = []
        for graph in graphs:
            graph_x.append(graph.tuple.graph_x)
            graph_y.append(graph.tuple.graph_y)

        # Pad and truncate encoded sequences.
        encoded_sequences = tf.keras.preprocessing.sequence.pad_sequences(
            encoded_sequences,
            maxlen=self.padded_sequence_length,
            dtype="int32",
            padding="pre",
            truncating="post",
            value=self.padding_element,
        )

        return batches.Data(
            graph_ids=[graph.id for graph in graphs],
            data=GraphLstmBatch(
                encoded_sequences=np.vstack(encoded_sequences),
                graph_x=np.vstack(graph_x),
                graph_y=np.vstack(graph_y),
            ),
        )
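The tf.keras.preprocessing.sequence.pad_sequences() call above does the shape normalization: sequences shorter than maxlen are left-padded with the padding element ("pre"), and longer ones are truncated from the right ("post"). A small standalone illustration with arbitrary values:

import tensorflow as tf

padded = tf.keras.preprocessing.sequence.pad_sequences(
  [[1, 2], [3, 4, 5, 6]], maxlen=3, dtype="int32",
  padding="pre", truncating="post", value=0,
)
# padded == [[0, 1, 2],   # short sequence left-padded with `value`
#            [3, 4, 5]]   # long sequence truncated from the right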
Example #4
  def MakeBatch(
    self,
    epoch_type: epoch.Type,
    graphs: Iterable[graph_tuple_database.GraphTuple],
    ctx: progress.ProgressContext = progress.NullContext,
  ) -> batches.Data:
    """Create a mini-batch of data from an iterator of graphs.

    Returns:
      A single batch of data for feeding into RunBatch(). A batch consists of a
      list of graph IDs and a model-defined blob of data. If the list of graph
      IDs is empty, the batch is discarded and not fed into RunBatch().
    """
    # TODO(github.com/ChrisCummins/ProGraML/issues/24): The new graph batcher
    # implementation is not well suited for reading the graph IDs, hence this
    # somewhat clumsy iterator wrapper. A neater approach would be to create
    # a graph batcher which returns a list of graphs in the batch.
    class GraphIterator(object):
      """A wrapper around a graph iterator which records graph IDs."""

      def __init__(self, graphs: Iterable[graph_tuple_database.GraphTuple]):
        self.input_graphs = graphs
        self.graphs_read: List[graph_tuple_database.GraphTuple] = []

      def __iter__(self):
        return self

      def __next__(self):
        graph: graph_tuple_database.GraphTuple = next(self.input_graphs)
        self.graphs_read.append(graph)
        return graph.tuple

    graph_iterator = GraphIterator(graphs)

    # Create a disjoint graph out of one or more input graphs.
    batcher = graph_batcher.GraphBatcher.CreateFromFlags(
      graph_iterator, ctx=ctx
    )

    try:
      disjoint_graph = next(batcher)
    except StopIteration:
      # We have run out of graphs.
      return batches.EndOfBatches()

    # Workaround for the fact that graph batcher may read one more graph than
    # actually gets included in the batch.
    if batcher.last_graph:
      graphs = graph_iterator.graphs_read[:-1]
    else:
      graphs = graph_iterator.graphs_read

    # Discard single-graph batches during training when there are graph
    # features. This is because we use batch normalization on incoming features,
    # and batch normalization requires > 1 items to normalize.
    if (
      len(graphs) <= 1
      and epoch_type == epoch.Type.TRAIN
      and disjoint_graph.graph_x_dimensionality
    ):
      return batches.EmptyBatch()

    return batches.Data(
      graph_ids=[graph.id for graph in graphs],
      data=GgnnBatchData(disjoint_graph=disjoint_graph, graphs=graphs),
    )
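The GraphIterator wrapper above is an instance of a general pattern: interpose on an iterator so the wrapper records exactly which items a downstream consumer pulled. A stripped-down sketch of the same idea, with generic names that are not from the source:

class RecordingIterator:
  """Records every item a consumer pulls from the wrapped iterator."""

  def __init__(self, items):
    self.items = iter(items)
    self.seen = []

  def __iter__(self):
    return self

  def __next__(self):
    item = next(self.items)
    self.seen.append(item)
    return item

it = RecordingIterator(range(10))
first_three = [x for _, x in zip(range(3), it)]  # consumer stops early
assert first_three == [0, 1, 2]
assert it.seen == [0, 1, 2]  # the wrapper saw exactly what was consumed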
Example #5
  def MakeBatch(
    self,
    epoch_type: epoch.Type,
    graph_iterator: Iterable[graph_tuple_database.GraphTuple],
    ctx: progress.ProgressContext = progress.NullContext,
  ) -> batches.Data:
    """Create a mini-batch of LSTM data."""
    del epoch_type  # Unused.

    graphs = self.GetBatchOfGraphs(graph_iterator)
    if not graphs:
      return batches.EndOfBatches()

    # Encode the graphs in the batch.
    encoded_sequences: List[np.array] = self.encoder.Encode(graphs, ctx=ctx)
    graph_x: List[np.array] = []
    graph_y: List[np.array] = []
    for graph in graphs:
      graph_x.append(graph.tuple.graph_x)
      graph_y.append(graph.tuple.graph_y)

    # Pad and truncate encoded sequences.
    encoded_sequences = tf.keras.preprocessing.sequence.pad_sequences(
      encoded_sequences,
      maxlen=self.padded_sequence_length,
      dtype="int32",
      padding="pre",
      truncating="post",
      value=self.padding_element,
    )

    return batches.Data(
      graph_ids=[graph.id for graph in graphs],
      data=GraphLstmBatch(
        encoded_sequences=np.vstack(encoded_sequences),
        graph_x=np.vstack(graph_x),
        graph_y=np.vstack(graph_y),
      ),
    )
Example #6
  def MakeBatch(
    self,
    epoch_type: epoch.Type,
    graphs: Iterable[graph_tuple_database.GraphTuple],
    ctx: progress.ProgressContext = progress.NullContext,
  ) -> batches.Data:
    """Create a mini-batch of LSTM data."""
    del epoch_type  # Unused.

    graphs = self.GetBatchOfGraphs(graphs)
    # For node classification we require every batch to contain exactly
    # batch_size graphs. In the future we could work around this by padding
    # an incomplete batch with arrays of zeros.
    if not graphs or len(graphs) != self.batch_size:
      return batches.EndOfBatches()

    # Encode the graphs in the batch.

    # A list of arrays of shape (node_count, 1)
    encoded_sequences: List[np.array] = []
    # A list of arrays of shape (node_count, 1)
    segment_ids: List[np.array] = []
    # A list of arrays of shape (node_mask_count, 2)
    selector_vectors: List[np.array] = []
    # A list of arrays of shape (node_mask_count, node_y_dimensionality)
    node_y: List[np.array] = []
    all_node_indices: List[np.array] = []
    targets: List[np.array] = []

    try:
      encoded_graphs = self.encoder.Encode(graphs, ctx=ctx)
    except ValueError as e:
      ctx.Error("%s", e)
      # TODO(github.com/ChrisCummins/ProGraML/issues/38): To debug a possible
      # error in the LSTM, batch construction has temporarily been made
      # resilient to encoder errors by returning an empty batch. This is
      # probably not a good idea to keep, as it means the LSTM will silently
      # skip data.
      return batches.Data(graph_ids=[], data=None)

    node_offset = 0
    # Convert ProgramGraphSeq protos to arrays of numeric values.
    for graph, seq in zip(graphs, encoded_graphs):
      # Skip empty graphs.
      if not seq.encoded:
        continue

      encoded_sequences.append(np.array(seq.encoded, dtype=np.int32))

      # Construct a list of segment IDs using the encoded node lengths,
      # e.g. for encoded node lengths [2, 3, 1], produce segment IDs:
      # [0, 0, 1, 1, 1, 2].
      out_of_range_segment = self.padded_node_sequence_length - 1
      segment_ids.append(
        np.concatenate(
          [
            np.ones(encoded_node_length, dtype=np.int32)
            * min(segment_id, out_of_range_segment)
            for segment_id, encoded_node_length in enumerate(
              seq.encoded_node_length
            )
          ]
        )
      )
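      # (Equivalently: np.repeat(np.minimum(np.arange(3), out_of_range_segment),
      # [2, 3, 1]) for the example above.)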

      # Get the list of graph node indices that produced the serialized encoded
      # graph representation. We use this to construct predictions for the
      # "full" graph through padding.
      node_indices = np.array(seq.node, dtype=np.int32)

      # Offset the node index list and concatenate.
      all_node_indices.append(node_indices + node_offset)

      # Sanity check that the node indices are in-range for this graph
      # (valid 0-based indices must be strictly less than the node count).
      assert max(node_indices) < len(graph.tuple.node_x)

      # Use only the 'binary selector' feature and convert to an array of
      # 1 hot binary vectors.
      node_selectors = graph.tuple.node_x[:, 1][node_indices]
      node_selector_vectors = np.zeros((node_selectors.size, 2), dtype=np.int32)
      node_selector_vectors[np.arange(node_selectors.size), node_selectors] = 1
      selector_vectors.append(node_selector_vectors)
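      # For example, node_selectors [0, 1, 0] yields the one-hot rows
      # [[1, 0], [0, 1], [1, 0]].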

      # Select the node targets for only the active nodes.
      node_y.append(graph.tuple.node_y[node_indices])
      targets.append(graph.tuple.node_y)

      # Increment our node offset for concatenating the list of node indices.
      node_offset += graph.node_count

    # Pad and truncate encoded sequences.
    encoded_sequences = tf.keras.preprocessing.sequence.pad_sequences(
      encoded_sequences,
      maxlen=self.padded_sequence_length,
      dtype="int32",
      padding="pre",
      truncating="post",
      value=self.padding_element,
    )

    # Determine an out-of-range segment ID to pad the segment IDs to.
    segment_id_padding_element = (
      max(max(s) if s.size else 0 for s in segment_ids) + 1
    )

    segment_ids = tf.keras.preprocessing.sequence.pad_sequences(
      segment_ids,
      maxlen=self.padded_sequence_length,
      dtype="int32",
      padding="pre",
      truncating="post",
      value=segment_id_padding_element,
    )

    padded_node_sequence_length = min(
      self.padded_node_sequence_length, max(len(s) for s in selector_vectors)
    )

    # Pad the selector vectors to the same shape as the segment IDs.
    selector_vectors = tf.keras.preprocessing.sequence.pad_sequences(
      selector_vectors,
      maxlen=padded_node_sequence_length,
      dtype="int32",
      padding="pre",
      truncating="post",
      value=np.array((0, 0), dtype=np.int32),
    )

    node_y = tf.keras.preprocessing.sequence.pad_sequences(
      node_y,
      maxlen=padded_node_sequence_length,
      dtype="int32",
      padding="pre",
      truncating="post",
      value=np.zeros(self.graph_db.node_y_dimensionality, dtype=np.int64),
    )

    all_node_indices = tf.keras.preprocessing.sequence.pad_sequences(
      all_node_indices,
      maxlen=padded_node_sequence_length,
      dtype="int32",
      padding="pre",
      truncating="post",
      value=-1,
    )

    return batches.Data(
      graph_ids=[graph.id for graph in graphs],
      data=NodeLstmBatch(
        encoded_sequences=encoded_sequences,
        segment_ids=segment_ids,
        selector_vectors=selector_vectors,
        node_y=node_y,
        node_indices=np.concatenate(all_node_indices),
        targets=np.vstack(targets),
      ),
    )
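One subtlety in the batch above: padded positions in segment_ids map to a deliberately out-of-range segment, so they cannot contribute to any real node when token states are later reduced per segment. A minimal sketch of that intent, assuming a segment reduction such as tf.math.unsorted_segment_mean (the downstream consumer is not shown in these snippets):

import tensorflow as tf

token_states = tf.constant([[1.0], [2.0], [3.0], [9.0]])  # last row is padding
segment_ids = tf.constant([0, 0, 1, 2])  # 2 is the out-of-range pad segment
node_states = tf.math.unsorted_segment_mean(token_states, segment_ids, 3)
# node_states[0] == [1.5] and node_states[1] == [3.0] are real node states;
# node_states[2] collects only padding and can be discarded.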