class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph"""

    def setUp(self):
        """Build the shared fixture: a dimension universe, tasks and a graph.

        Stores on ``self``:

        - ``tasks``: one `TaskDef` per dummy pipeline task, labelled
          "R", "S", "T" and "U" in that order.
        - ``quantumMap``: mapping of each `TaskDef` to a set of two `Quantum`
          objects, one per data ID (A=1, B=2) and (A=3, B=4).
        - ``qGraph``: the `QuantumGraph` built from ``quantumMap`` and
          ``METADATA``.
        - ``universe``: the `DimensionUniverse` used for every data ID.
        """
        # Minimal dimension configuration: two elements, "A" and "B", each
        # keyed by a single integer "id".
        config = Config(
            {
                "version": 1,
                "namespace": "pipe_base_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe = DimensionUniverse(config=config)
        # need to make a mapping of TaskDef to set of quantum
        quantumMap = {}
        tasks = []
        for task, label in (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
            (Dummy4PipelineTask, "U"),
        ):
            # NOTE: ``config`` is rebound here; the universe configuration
            # above is no longer needed once ``universe`` exists.
            config = task.ConfigClass()
            taskDef = TaskDef(get_full_type_name(task), config, task, label)
            tasks.append(taskDef)
            quantumSet = set()
            connections = taskDef.connections
            # Two quanta per task, one for each (A, B) data ID pair.
            for a, b in ((1, 2), (3, 4)):
                if connections.initInputs:
                    # Init-input dataset types are created with no dimensions
                    # and referenced with an empty data coordinate.
                    initInputDSType = DatasetType(
                        connections.initInput.name,
                        tuple(),
                        storageClass=connections.initInput.storageClass,
                        universe=universe,
                    )
                    initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None
                inputDSType = DatasetType(
                    connections.input.name,
                    connections.input.dimensions,
                    storageClass=connections.input.storageClass,
                    universe=universe,
                )
                inputRefs = [
                    DatasetRef(inputDSType, DataCoordinate.standardize({"A": a, "B": b}, universe=universe))
                ]
                outputDSType = DatasetType(
                    connections.output.name,
                    connections.output.dimensions,
                    storageClass=connections.output.storageClass,
                    universe=universe,
                )
                outputRefs = [
                    DatasetRef(outputDSType, DataCoordinate.standardize({"A": a, "B": b}, universe=universe))
                ]
                quantumSet.add(
                    Quantum(
                        taskName=task.__qualname__,
                        dataId=DataCoordinate.standardize({"A": a, "B": b}, universe=universe),
                        taskClass=task,
                        initInputs=initRefs,
                        inputs={inputDSType: inputRefs},
                        outputs={outputDSType: outputRefs},
                    )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap, metadata=METADATA)
        self.universe = universe

    def testTaskGraph(self):
        """Every `TaskDef` in the fixture appears in ``qGraph.taskGraph``."""
        for taskDef in self.quantumMap.keys():
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        """Every fixture quantum is carried by some node of ``qGraph.graph``."""
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        """Look a node up by its ``nodeId``; a fresh UUID raises `KeyError`."""
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        # A newly generated UUID is used as an identifier the graph should
        # not know about.
        wrongNode = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        """A graph compares equal to itself after a pickle round trip."""
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        """``inputQuanta`` holds the quanta of the first and fourth tasks."""
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs)

    def testOutputtQuanta(self):
        """``outputQuanta`` holds the quanta of the third and fourth tasks."""
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs)

    def testLength(self):
        """The graph length is two quanta for each task in the fixture."""
        self.assertEqual(len(self.qGraph), 2 * len(self.tasks))

    def testGetQuantaForTask(self):
        """``getQuantaForTask`` returns exactly the quanta mapped to a task."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNodesForTask(self):
        """The nodes returned for a task carry exactly that task's quanta."""
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = set(n.quantum for n in nodes)
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self):
        """The first task found consuming "Dummy1Output" is the second task."""
        self.assertEqual(
            tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0], self.tasks[1]
        )

    def testFindTasksWithOutput(self):
        """The task producing "Dummy1Output" is the first task."""
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        """"Dummy1Output" is associated with the first two tasks only."""
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))), set(self.tasks[:2])
        )

    def testFindTaskDefByName(self):
        """Lookup by the task's qualified name yields the first `TaskDef`."""
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0], self.tasks[0])

    def testFindTaskDefByLabel(self):
        """Lookup by label "R" yields the first `TaskDef`."""
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self):
        """Quanta touching "Dummy1Input" are those of the first task."""
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")), self.quantumMap[self.tasks[0]]
        )

    def testAllDatasetTypes(self):
        """``allDatasetTypes`` matches the connection names of all tasks.

        Connections that are instances of ``cT.InitOutput`` are left out of
        the expected set.
        """
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                if not isinstance(connection, cT.InitOutput):
                    truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        """A one-node subset keeps that node's quantum and the build ID."""
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testSubsetToConnected(self):
        """Splitting the graph yields four connected subgraphs.

        Subgraphs containing the fourth task are expected to hold a single
        quantum; the others are expected to hold three.
        """
        # False because there are two quantum chains for two distinct sets of
        # dimensions
        self.assertFalse(self.qGraph.isConnected)

        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 4)
        self.assertTrue(connectedGraphs[0].isConnected)
        self.assertTrue(connectedGraphs[1].isConnected)
        self.assertTrue(connectedGraphs[2].isConnected)
        self.assertTrue(connectedGraphs[3].isConnected)

        # Split out task[3] because it is expected to be on its own
        for cg in connectedGraphs:
            if self.tasks[3] in cg.taskGraph:
                self.assertEqual(len(cg), 1)
            else:
                self.assertEqual(len(cg), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Count subgraph memberships over all nodes; the total equals the
        # graph length only if no quantum is missing or counted twice overall.
        count = 0
        for node in self.qGraph:
            if connectedGraphs[0].checkQuantumInGraph(node.quantum):
                count += 1
            if connectedGraphs[1].checkQuantumInGraph(node.quantum):
                count += 1
            if connectedGraphs[2].checkQuantumInGraph(node.quantum):
                count += 1
            if connectedGraphs[3].checkQuantumInGraph(node.quantum):
                count += 1
        self.assertEqual(len(self.qGraph), count)

        # Maps the number of tasks in each subgraph to the set of those tasks.
        # NOTE(review): ``taskSets`` is not used in the visible source; the
        # method appears to be cut off here -- confirm against the full file.
        taskSets = {len(tg := s.taskGraph): set(tg) for s in connectedGraphs}
class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    The fixture built in `setUp` contains three dummy tasks (labels "R", "S"
    and "T"), each with two quanta: one for data ID (A=1, B=2) and one for
    (A=3, B=4).
    """

    def setUp(self):
        """Build the shared fixture: a dimension universe, tasks and a graph.

        Stores on ``self``:

        - ``tasks``: one `TaskDef` per dummy pipeline task, in label order.
        - ``quantumMap``: mapping of each `TaskDef` to its set of two quanta.
        - ``qGraph``: the `QuantumGraph` built from ``quantumMap``.
        - ``universe``: the `DimensionUniverse` used for every data ID.
        """
        # Both dimension elements share the same definition: a single integer
        # "id" key backed by table storage.
        def makeElement():
            return {
                "keys": [
                    {
                        "name": "id",
                        "type": "int",
                    }
                ],
                "storage": {
                    "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                },
            }

        universeConfig = Config(
            {
                "version": 1,
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": makeElement(),
                    "B": makeElement(),
                },
                "packers": {},
            }
        )
        universe = DimensionUniverse(config=universeConfig)

        # need to make a mapping of TaskDef to set of quantum
        quantumMap = {}
        tasks = []
        for task, label in (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
        ):
            taskConfig = task.ConfigClass()
            taskDef = TaskDef(f"__main__.{task.__qualname__}", taskConfig, task, label)
            tasks.append(taskDef)
            connections = taskDef.connections

            quantumSet = set()
            for a, b in ((1, 2), (3, 4)):
                # The same data ID is used for the quantum and for its input
                # and output dataset references, so build it once.
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)

                if connections.initInputs:
                    # Init-input dataset types are created with no dimensions
                    # and referenced with an empty data coordinate.
                    initInputDSType = DatasetType(
                        connections.initInput.name,
                        tuple(),
                        storageClass=connections.initInput.storageClass,
                        universe=universe,
                    )
                    initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None

                inputDSType = DatasetType(
                    connections.input.name,
                    connections.input.dimensions,
                    storageClass=connections.input.storageClass,
                    universe=universe,
                )
                outputDSType = DatasetType(
                    connections.output.name,
                    connections.output.dimensions,
                    storageClass=connections.output.storageClass,
                    universe=universe,
                )
                quantumSet.add(
                    Quantum(
                        taskName=task.__qualname__,
                        dataId=dataId,
                        taskClass=task,
                        initInputs=initRefs,
                        inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                        outputs={outputDSType: [DatasetRef(outputDSType, dataId)]},
                    )
                )
            quantumMap[taskDef] = quantumSet

        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap)
        self.universe = universe

    def _cleanGraphs(self, graph1, graph2):
        """Strip module qualification from task names of two graphs in place.

        Parameters
        ----------
        graph1, graph2
            Graphs whose ``_quanta`` keys have their ``taskName`` reduced to
            the text after the last ".".
        """
        # This is a hack for the unit test since the qualified name will be
        # different as it will be __main__ here, but qualified to the
        # unittest module name when restored
        # Updates in place
        for saved, loaded in zip(graph1._quanta.keys(), graph2._quanta.keys()):
            saved.taskName = saved.taskName.split(".")[-1]
            loaded.taskName = loaded.taskName.split(".")[-1]

    def testTaskGraph(self):
        """Every `TaskDef` in the fixture appears in ``qGraph.taskGraph``."""
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        """Every fixture quantum is carried by some node of ``qGraph.graph``."""
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        """Look a node up by ``nodeId``; a foreign build ID is rejected."""
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        # A node ID carrying a different build ID must raise.
        wrongNode = NodeId(15, BuildId("alternative build Id"))
        with self.assertRaises(IncompatibleGraphError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        """A graph compares equal to itself after a pickle round trip."""
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        """``inputQuanta`` holds exactly the quanta of the first task."""
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]], inputs)

    def testOutputtQuanta(self):
        """``outputQuanta`` holds exactly the quanta of the last task."""
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[-1]], outputs)

    def testLength(self):
        """The graph length is two quanta for each task in the fixture."""
        # Derived from the fixture rather than hard-coded so it stays correct
        # if the task list in setUp changes.
        self.assertEqual(len(self.qGraph), 2 * len(self.tasks))

    def testGetQuantaForTask(self):
        """``getQuantaForTask`` returns exactly the quanta mapped to a task."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testFindTasksWithInput(self):
        """The first task found consuming "Dummy1Output" is the second task."""
        found = tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))
        self.assertEqual(found[0], self.tasks[1])

    def testFindTasksWithOutput(self):
        """The task producing "Dummy1Output" is the first task."""
        self.assertEqual(
            self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")),
            self.tasks[0],
        )

    def testTaskWithDSType(self):
        """"Dummy1Output" is associated with the first two tasks only."""
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))),
            set(self.tasks[:2]),
        )

    def testFindTaskDefByName(self):
        """Lookup by the task's qualified name yields the first `TaskDef`."""
        self.assertEqual(
            self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0],
            self.tasks[0],
        )

    def testFindTaskDefByLabel(self):
        """Lookup by label "R" yields the first `TaskDef`."""
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self):
        """Quanta touching "Dummy1Input" are those of the first task."""
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")),
            self.quantumMap[self.tasks[0]],
        )

    def testAllDatasetTypes(self):
        """``allDatasetTypes`` matches the connection names of all tasks."""
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = {
            connection.name
            for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections)
            for connection in conClass.allConnections.values()  # type: ignore
        }
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        """A one-node subset keeps that node's quantum and the build ID."""
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testIsConnected(self):
        """The full graph is disconnected; a two-node subset is connected."""
        # False because there are two quantum chains for two distinct sets of
        # dimensions
        self.assertFalse(self.qGraph.isConnected)
        # make a broken subset
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset((allNodes[0], allNodes[1]))
        # True because we subset to only one chain of graphs
        self.assertTrue(subset.isConnected)

    def testSubsetToConnected(self):
        """Splitting yields two connected three-quantum subgraphs.

        Each subgraph is expected to contain all tasks of the fixture, and
        the subgraph memberships summed over all nodes must equal the length
        of the full graph.
        """
        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 2)
        for subgraph in connectedGraphs:
            self.assertTrue(subgraph.isConnected)
            self.assertEqual(len(subgraph), 3)
            self.assertEqual(self.tasks, list(subgraph.taskGraph))
        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Count one for every (node, subgraph) pair where the subgraph
        # contains the node's quantum.
        count = sum(
            1
            for node in self.qGraph
            for subgraph in connectedGraphs
            if subgraph.checkQuantumInGraph(node.quantum)
        )
        self.assertEqual(len(self.qGraph), count)

    def testDetermineInputsToQuantumNode(self):
        """The inputs to the second node are exactly the first node."""
        allNodes = list(self.qGraph)
        # The query is issued twice, as the original test did, so a lookup
        # whose first call changed the answer of the second would fail here.
        for _ in range(2):
            inputs = self.qGraph.determineInputsToQuantumNode(allNodes[1])
            self.assertEqual({allNodes[0]}, inputs)

    def testDetermineOutputsOfQuantumNode(self):
        """The first output of the second node is the third node."""
        allNodes = list(self.qGraph)
        node = next(iter(self.qGraph.determineOutputsOfQuantumNode(allNodes[1])))
        self.assertEqual(allNodes[2], node)

    def testDetermineConnectionsOfQuantum(self):
        """Connections of the second node match a subset of the first three."""
        allNodes = list(self.qGraph)
        connections = self.qGraph.determineConnectionsOfQuantumNode(allNodes[1])
        self.assertEqual(list(connections), list(self.qGraph.subset(allNodes[:3])))

    def testDetermineAnsestorsOfQuantumNode(self):
        """Ancestors of the third node match a subset of the first three."""
        allNodes = list(self.qGraph)
        ancestors = self.qGraph.determineAncestorsOfQuantumNode(allNodes[2])
        self.assertEqual(list(ancestors), list(self.qGraph.subset(allNodes[:3])))

    def testFindCycle(self):
        """The fixture graph reports no cycle."""
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        """Round-trip the graph through a temporary file object."""
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(0,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(
                list(restoreSub)[0],
                restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId)),
            )

    def testSaveLoadUri(self):
        """Round-trip the graph through a URI, including partial loads.

        Also checks that unknown node numbers and a mismatched ``graphID``
        raise `ValueError`, and that saving to a URI ending in ".notgraph"
        raises `TypeError`.
        """
        uri = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
                uri = tmpFile.name
                self.qGraph.saveUri(uri)

                restore = QuantumGraph.loadUri(uri, self.universe)
                self._cleanGraphs(self.qGraph, restore)
                self.assertEqual(self.qGraph, restore)

                # Draw two distinct node numbers in one step instead of
                # re-drawing the second until it differs from the first.
                nodeNumber, nodeNumber2 = random.sample(range(len(self.qGraph)), 2)

                restoreSub = QuantumGraph.loadUri(
                    uri, self.universe, nodes=(nodeNumber,), graphID=self.qGraph._buildId
                )
                self.assertEqual(len(restoreSub), 1)
                self.assertEqual(
                    list(restoreSub)[0],
                    restore.getQuantumNodeByNodeId(NodeId(nodeNumber, restore.graphID)),
                )

                # verify that more than one node works
                restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
                self.assertEqual(len(restoreSub), 2)
                self.assertEqual(
                    set(restoreSub),
                    {
                        restore.getQuantumNodeByNodeId(NodeId(nodeNumber, restore._buildId)),
                        restore.getQuantumNodeByNodeId(NodeId(nodeNumber2, restore._buildId)),
                    },
                )

                # verify an error when requesting a nonexistent node number
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

                # verify a graphID that does not match will be an error
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")
        finally:
            # The file was created with delete=False, so remove it by hand;
            # exceptions from the body propagate unchanged.
            if uri is not None:
                os.remove(uri)

        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    @unittest.skipIf(not boto3, "Warning: boto3 AWS SDK not found!")
    @mock_s3
    def testSaveLoadUriS3(self):
        """Round-trip the graph through a mocked S3 bucket."""
        # Test loading a quantum graph from an mock s3 store
        conn = boto3.resource("s3", region_name="us-east-1")
        conn.create_bucket(Bucket="testBucket")
        uri = "s3://testBucket/qgraph.qgraph"
        self.qGraph.saveUri(uri)
        restore = QuantumGraph.loadUri(uri, self.universe)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)
        restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(0,))
        self.assertEqual(len(restoreSub), 1)
        self.assertEqual(
            list(restoreSub)[0],
            restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId)),
        )

    def testContains(self):
        """A node obtained by iterating the graph is ``in`` the graph."""
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)