Example 1
import random

import numpy as np

# Assumed import: log_analysis is the module under test; adjust the path to
# wherever it lives in your project.
import log_analysis


def test_fuzz_BuildConfusionMatrix():
    """Fuzz confusion matrix construction."""
    num_instances = random.randint(1, 100)
    y_dimensionality = random.randint(2, 5)

    targets = np.random.rand(num_instances, y_dimensionality)
    predictions = np.random.rand(num_instances, y_dimensionality)

    confusion_matrix = log_analysis.BuildConfusionMatrix(targets, predictions)

    assert confusion_matrix.shape == (y_dimensionality, y_dimensionality)
    assert confusion_matrix.sum() == num_instances
Example 2
def test_BuildConfusionMatrix():
    """Test confusion matrix with known values."""
    confusion_matrix = log_analysis.BuildConfusionMatrix(
        targets=np.array([
            np.array([1, 0, 0], dtype=np.int32),
            np.array([0, 0, 1], dtype=np.int32),
            np.array([0, 0, 1], dtype=np.int32),
        ]),
        predictions=np.array([
            np.array([0.1, 0.5, 0], dtype=np.float32),
            np.array([0, -0.5, -0.3], dtype=np.float32),
            np.array([0, 0, 0.8], dtype=np.float32),
        ]),
    )

    assert confusion_matrix.shape == (3, 3)
    assert confusion_matrix.sum() == 3
    assert np.array_equal(confusion_matrix,
                          np.array([
                              [0, 1, 0],
                              [0, 0, 0],
                              [1, 0, 1],
                          ]))
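
Between them, these two tests pin down the expected behaviour: each row of targets and predictions is reduced to a class index by argmax, rows of the matrix count true classes, and columns count predicted classes. The sketch below is a hypothetical reconstruction consistent with both tests, not the actual log_analysis implementation:

import numpy as np

def build_confusion_matrix_sketch(targets: np.ndarray,
                                  predictions: np.ndarray) -> np.ndarray:
    """Sketch of a confusion matrix builder consistent with the tests above.

    Each row of `targets` and `predictions` is reduced to a class index by
    argmax, and the (true class, predicted class) cell is incremented.
    """
    num_classes = targets.shape[1]
    cm = np.zeros((num_classes, num_classes), dtype=np.int32)
    for target, prediction in zip(targets, predictions):
        cm[np.argmax(target), np.argmax(prediction)] += 1
    return cm

Under this reading, the fuzz test's invariants hold by construction: every instance increments exactly one cell, so the matrix is square in the number of classes and sums to num_instances.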
Example 3
  def Run(self):
    """Export the batch graphs."""
    # Get the full list of batches to export.
    with self.log_db.Session() as session:
      batch_ids = self.FilterBatchesQuery(
        session.query(log_database.Batch.id)
      ).all()

    # A confusion matrix for the entire set of batches.
    cm = np.zeros((2, 2), dtype=np.int64)

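    # Filename prefix identifying the run, epoch number, and epoch type.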
    name_prefix = f"{self.checkpoint.run_id}.{self.checkpoint.epoch_num}.{self.epoch_type.name.lower()}"

    # Split the batches into chunks.
    batch_id_chunks = labtypes.Chunkify(
      batch_ids, self.export_batches_per_query
    )
    # Read the batches in a background thread.
    batches = ppar.ThreadedIterator(
      map(self.LoadBatch, batch_id_chunks), max_queue_size=5
    )
    # Process the batches in a background thread.
    graphs_batches = ppar.ThreadedIterator(
      map(self.BuildGraphsFromBatches, batches), max_queue_size=5
    )

    for batch in graphs_batches:
      for batch_id, graphs in batch:
        for graph_id, ingraph, outgraph in graphs:
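          # Advance the progress counter once per exported graph.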
          self.ctx.i += 1

          name = f"{name_prefix}.{batch_id:04d}.{graph_id:04d}"

          statspath = self.outdir / "graph_stats" / f"{name}.json"
          ingraph_path = self.outdir / "in_graphs" / f"{name}.GraphTuple.pickle"
          outgraph_path = (
            self.outdir / "out_graphs" / f"{name}.GraphTuple.pickle"
          )

          # Write the input and output graph tuples.
          ingraph.ToFile(ingraph_path)
          outgraph.ToFile(outgraph_path)

          # Write the graph-level stats.
          graph_cm = log_analysis.BuildConfusionMatrix(
            ingraph.node_y, outgraph.node_y
          )
          fs.Write(
            statspath,
            jsonutil.format_json(
              {
                "accuracy": (graph_cm[0][0] + graph_cm[1][1]) / graph_cm.sum(),
                "node_count": ingraph.node_count,
                "edge_count": ingraph.edge_count,
                "confusion_matrix": graph_cm.tolist(),
              }
            ).encode("utf-8"),
          )

          cm = np.add(cm, graph_cm)

    # Write the epoch-level stats.
    fs.Write(
      self.outdir / f"{name_prefix}.json",
      jsonutil.format_json(
        {"graph_count": self.ctx.i, "confusion_matrix": cm.tolist(),}
      ).encode("utf-8"),
    )

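    # Set the progress counter to its total to mark the export as complete.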
    self.ctx.i = self.ctx.n
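
The per-graph stats written by Run compute accuracy as the diagonal of the 2x2 confusion matrix over its total, (graph_cm[0][0] + graph_cm[1][1]) / graph_cm.sum(). The same figure can be recovered for the whole epoch from the aggregated matrix cm. The helper below is an illustrative sketch, not part of the original module:

import numpy as np

def accuracy_from_confusion_matrix(cm: np.ndarray) -> float:
    """Fraction of instances on the diagonal, i.e. classified correctly."""
    return float(np.trace(cm) / cm.sum())

The trace generalizes the hard-coded 2x2 sum, so the same helper also applies to the 3x3 matrices of Examples 1 and 2.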