def test_fuzz_BuildConfusionMatrix():
    """Fuzz confusion matrix construction with random shapes and values."""
    # Random problem size: at least one instance, between 2 and 5 classes.
    instance_count = random.randint(1, 100)
    class_count = random.randint(2, 5)

    true_labels = np.random.rand(instance_count, class_count)
    model_outputs = np.random.rand(instance_count, class_count)

    matrix = log_analysis.BuildConfusionMatrix(true_labels, model_outputs)

    # The matrix is square over the classes, and every instance lands in
    # exactly one cell, so the total count equals the instance count.
    assert matrix.shape == (class_count, class_count)
    assert matrix.sum() == instance_count
def test_BuildConfusionMatrix():
    """Test confusion matrix with known values."""
    # One-hot true classes for three instances: classes 0, 2, 2.
    true_one_hot = np.array(
        [
            [1, 0, 0],
            [0, 0, 1],
            [0, 0, 1],
        ],
        dtype=np.int32,
    )
    # Raw model scores; argmax per row yields predicted classes 1, 0, 2.
    scores = np.array(
        [
            [0.1, 0.5, 0],
            [0, -0.5, -0.3],
            [0, 0, 0.8],
        ],
        dtype=np.float32,
    )

    confusion_matrix = log_analysis.BuildConfusionMatrix(
        targets=true_one_hot, predictions=scores
    )

    assert confusion_matrix.shape == (3, 3)
    assert confusion_matrix.sum() == 3
    # Rows index the true class, columns the predicted class.
    expected = np.array(
        [
            [0, 1, 0],
            [0, 0, 0],
            [1, 0, 1],
        ]
    )
    assert np.array_equal(confusion_matrix, expected)
def Run(self): """Export the batch graphs.""" # Get the full list of batches to export. with self.log_db.Session() as session: batch_ids = self.FilterBatchesQuery( session.query(log_database.Batch.id) ).all() # A confusion matrix for the entire set of batches. cm = np.zeros((2, 2), dtype=np.int64) name_prefix = f"{self.checkpoint.run_id}.{self.checkpoint.epoch_num}.{self.epoch_type.name.lower()}" # Split the batches into chunks. batch_id_chunks = labtypes.Chunkify( batch_ids, self.export_batches_per_query ) # Read the batches in a background thread. batches = ppar.ThreadedIterator( map(self.LoadBatch, batch_id_chunks), max_queue_size=5 ) # Process the batches in a background thread. graphs_batches = ppar.ThreadedIterator( map(self.BuildGraphsFromBatches, batches), max_queue_size=5 ) for batch in graphs_batches: for batch_id, graphs in batch: for graph_id, ingraph, outgraph in graphs: self.ctx.i += 1 name = f"{name_prefix}.{batch_id:04d}.{graph_id:04d}" statspath = self.outdir / "graph_stats" / f"{name}.json" ingraph_path = self.outdir / "in_graphs" / f"{name}.GraphTuple.pickle" outgraph_path = ( self.outdir / "out_graphs" / f"{name}.GraphTuple.pickle" ) # Write the input and output graph tuples. ingraph.ToFile(ingraph_path) outgraph.ToFile(outgraph_path) # Write the graph-level stats. graph_cm = log_analysis.BuildConfusionMatrix( ingraph.node_y, outgraph.node_y ) fs.Write( statspath, jsonutil.format_json( { "accuracy": (graph_cm[0][0] + graph_cm[1][1]) / graph_cm.sum(), "node_count": ingraph.node_count, "edge_count": ingraph.edge_count, "confusion_matrix": graph_cm.tolist(), } ).encode("utf-8"), ) cm = np.add(cm, graph_cm) # Write the epoch-level stats. fs.Write( self.outdir / f"{name_prefix}.json", jsonutil.format_json( {"graph_count": self.ctx.i, "confusion_matrix": cm.tolist(),} ).encode("utf-8"), ) self.ctx.i = self.ctx.n