def testMakeGraphFromSave(self):
    """Test ``CmdLineFwk.makeGraph`` when the graph is read from a file.

    Only the trivial path is covered: the graph is loaded from a saved
    file, so no actual graph building takes place. The scenarios are:

    - a saved non-empty graph is read back with one quantum;
    - a mismatching ``qgraph_id`` raises ``ValueError``;
    - a file holding a pickled object of the wrong type raises
      ``ValueError``;
    - a saved empty graph produces a ``UserWarning`` and ``None``.
    """
    fwk = CmdLineFwk()

    with makeTmpFile(suffix=".qgraph") as tmpname, \
            makeSQLiteRegistry() as registryConfig:

        def writeGraph(graph):
            # Overwrite the temporary file with the serialized graph.
            with open(tmpname, "wb") as stream:
                graph.save(stream)

        def buildArgs(**extra):
            # Fresh argument set pointing at the temporary graph file.
            return _makeArgs(qgraph=tmpname, registryConfig=registryConfig, **extra)

        # A non-empty graph written to disk must be read back intact.
        writeGraph(_makeQGraph())
        loaded = fwk.makeGraph(None, buildArgs())
        self.assertIsInstance(loaded, QuantumGraph)
        self.assertEqual(len(loaded), 1)

        # Requesting a graph ID different from the stored one is an error.
        with self.assertRaisesRegex(ValueError, "graphID does not match"):
            fwk.makeGraph(None, buildArgs(qgraph_id="R2-D2 is that you?"))

        # A file holding some other pickled object is rejected.
        with open(tmpname, "wb") as stream:
            pickle.dump({}, stream)
        with self.assertRaises(ValueError):
            fwk.makeGraph(None, buildArgs())

        # An empty graph loads, but makeGraph() warns and returns None.
        writeGraph(QuantumGraph({}))
        emptyArgs = buildArgs()
        with self.assertWarnsRegex(UserWarning, "QuantumGraph is empty"):
            loaded = fwk.makeGraph(None, emptyArgs)
        self.assertIs(loaded, None)
class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds a graph from three dummy tasks (labels ``R``, ``S`` and
    ``T``). Each task is run for two data IDs, ``{"A": 1, "B": 2}`` and
    ``{"A": 3, "B": 4}``, which gives six quanta in total. They form two
    independent chains, one per data ID.
    """

    def setUp(self):
        # Minimal dimension universe with two integer-keyed dimensions, "A"
        # and "B", used to build the data IDs of every quantum below.
        dimensionConfig = Config({
            "version": 1,
            "skypix": {
                "common": "htm7",
                "htm": {
                    "class": "lsst.sphgeom.HtmPixelization",
                    "max_level": 24,
                }
            },
            "elements": {
                "A": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                },
                "B": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                }
            },
            "packers": {}
        })
        universe = DimensionUniverse(config=dimensionConfig)

        def makeDataId(a, b):
            """Return a standardized data ID for dimension values ``a``/``b``."""
            return DataCoordinate.standardize({"A": a, "B": b}, universe=universe)

        # Mapping of TaskDef -> set of Quantum, used to construct the graph.
        quantumMap = {}
        tasks = []
        for task, label in ((Dummy1PipelineTask, "R"),
                            (Dummy2PipelineTask, "S"),
                            (Dummy3PipelineTask, "T")):
            taskConfig = task.ConfigClass()
            taskDef = TaskDef(f"__main__.{task.__qualname__}", taskConfig, task, label)
            tasks.append(taskDef)
            quantumSet = set()
            connections = taskDef.connections
            for a, b in ((1, 2), (3, 4)):
                if connections.initInputs:
                    initInputDSType = DatasetType(
                        connections.initInput.name,
                        tuple(),
                        storageClass=connections.initInput.storageClass,
                        universe=universe)
                    initRefs = [
                        DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))
                    ]
                else:
                    initRefs = None
                inputDSType = DatasetType(
                    connections.input.name,
                    connections.input.dimensions,
                    storageClass=connections.input.storageClass,
                    universe=universe,
                )
                inputRefs = [DatasetRef(inputDSType, makeDataId(a, b))]
                outputDSType = DatasetType(
                    connections.output.name,
                    connections.output.dimensions,
                    storageClass=connections.output.storageClass,
                    universe=universe,
                )
                outputRefs = [DatasetRef(outputDSType, makeDataId(a, b))]
                quantumSet.add(
                    Quantum(taskName=task.__qualname__,
                            dataId=makeDataId(a, b),
                            taskClass=task,
                            initInputs=initRefs,
                            inputs={inputDSType: inputRefs},
                            outputs={outputDSType: outputRefs}))
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap)
        self.universe = universe

    def _cleanGraphs(self, graph1, graph2):
        """Strip module qualification from task names of both graphs in place.

        This is a hack for the unit test. The qualified name is ``__main__``
        here, but when the graph is restored it is qualified with the
        unittest module name. Keeping only the last dotted component makes
        the two graphs comparable.
        """
        for saved, loaded in zip(graph1._quanta.keys(), graph2._quanta.keys()):
            saved.taskName = saved.taskName.split('.')[-1]
            loaded.taskName = loaded.taskName.split('.')[-1]

    def testTaskGraph(self):
        """Every TaskDef used to build the graph appears in its task graph."""
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        """Every input quantum is present among the graph's nodes."""
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        """Nodes are retrievable by NodeId; a foreign build ID is rejected."""
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = NodeId(15, BuildId("alternative build Id"))
        with self.assertRaises(IncompatibleGraphError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        """The graph survives a pickle round trip."""
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        """The graph's input quanta are those of the first task."""
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]], inputs)

    def testOutputtQuanta(self):
        """The graph's output quanta are those of the last task."""
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[-1]], outputs)

    def testLength(self):
        """Three tasks times two data IDs gives six nodes."""
        self.assertEqual(len(self.qGraph), 6)

    def testGetQuantaForTask(self):
        """Quanta looked up per task match those used to build the graph."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testFindTasksWithInput(self):
        """``Dummy1Output`` is consumed by the second task."""
        self.assertEqual(
            tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0],
            self.tasks[1])

    def testFindTasksWithOutput(self):
        """``Dummy1Output`` is produced by the first task."""
        self.assertEqual(
            self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")),
            self.tasks[0])

    def testTaskWithDSType(self):
        """Both producer and consumer of ``Dummy1Output`` are reported."""
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))),
            set(self.tasks[:2]))

    def testFindTaskDefByName(self):
        """A TaskDef can be found by its task's qualified class name."""
        self.assertEqual(
            self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0],
            self.tasks[0])

    def testFindTaskDefByLabel(self):
        """A TaskDef can be found by its label."""
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self):
        """Quanta using ``Dummy1Input`` are those of the first task."""
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")),
            self.quantumMap[self.tasks[0]])

    def testAllDatasetTypes(self):
        """All connection names of all tasks are reported as dataset types."""
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        """A one-node subset keeps that node's quantum and the build ID."""
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testIsConnected(self):
        """The full graph is disconnected; one chain of it is connected."""
        # False because there are two quantum chains for two distinct sets of
        # dimensions.
        self.assertFalse(self.qGraph.isConnected)
        # Make a subset from the first two nodes.
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset((allNodes[0], allNodes[1]))
        # True because we subset to only one chain of graphs.
        self.assertTrue(subset.isConnected)

    def testSubsetToConnected(self):
        """Splitting yields two connected three-node graphs covering all nodes."""
        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 2)
        self.assertTrue(connectedGraphs[0].isConnected)
        self.assertTrue(connectedGraphs[1].isConnected)

        self.assertEqual(len(connectedGraphs[0]), 3)
        self.assertEqual(len(connectedGraphs[1]), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Every node of the full graph must land in exactly one component
        # (the total count then equals the full graph length).
        count = 0
        for node in self.qGraph:
            if connectedGraphs[0].checkQuantumInGraph(node.quantum):
                count += 1
            if connectedGraphs[1].checkQuantumInGraph(node.quantum):
                count += 1
        self.assertEqual(len(self.qGraph), count)

        self.assertEqual(self.tasks, list(connectedGraphs[0].taskGraph))
        self.assertEqual(self.tasks, list(connectedGraphs[1].taskGraph))

    def testDetermineInputsToQuantumNode(self):
        """The second node's only upstream node is the first node."""
        allNodes = list(self.qGraph)
        node = self.qGraph.determineInputsToQuantumNode(allNodes[1])
        self.assertEqual({allNodes[0]}, node)

    def testDetermineOutputsOfQuantumNode(self):
        """The second node's downstream node is the third node."""
        allNodes = list(self.qGraph)
        node = next(iter(self.qGraph.determineOutputsOfQuantumNode(allNodes[1])))
        self.assertEqual(allNodes[2], node)

    def testDetermineConnectionsOfQuantum(self):
        """The middle node's connections form the first three-node chain."""
        allNodes = list(self.qGraph)
        connections = self.qGraph.determineConnectionsOfQuantumNode(allNodes[1])
        self.assertEqual(list(connections), list(self.qGraph.subset(allNodes[:3])))

    def testDetermineAnsestorsOfQuantumNode(self):
        """The third node's ancestors (plus itself) form the first chain."""
        allNodes = list(self.qGraph)
        ancestors = self.qGraph.determineAncestorsOfQuantumNode(allNodes[2])
        self.assertEqual(list(ancestors), list(self.qGraph.subset(allNodes[:3])))

    def testFindCycle(self):
        """The test graph has no cycles."""
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        """Round-trip through ``save``/``load`` on a file object, fully and
        for a single node."""
        with tempfile.TemporaryFile(suffix='.qgraph') as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node.
            tmpFile.seek(0)
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(0, ))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0],
                             restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId)))

    def testSaveLoadUri(self):
        """Round-trip through ``saveUri``/``loadUri`` on a local path.

        Also checks partial loads, and that an unknown node number, a
        mismatching graph ID and an unsupported file extension are errors.
        """
        uri = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
                uri = tmpFile.name
                self.qGraph.saveUri(uri)
                restore = QuantumGraph.loadUri(uri, self.universe)
                self._cleanGraphs(self.qGraph, restore)
                self.assertEqual(self.qGraph, restore)

                # Two distinct node numbers for the partial-load checks.
                nodeNumber, nodeNumber2 = random.sample(range(len(self.qGraph)), 2)

                restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, ),
                                                  graphID=self.qGraph._buildId)
                self.assertEqual(len(restoreSub), 1)
                self.assertEqual(list(restoreSub)[0],
                                 restore.getQuantumNodeByNodeId(NodeId(nodeNumber, restore.graphID)))

                # Verify that loading more than one node works.
                restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
                self.assertEqual(len(restoreSub), 2)
                self.assertEqual(set(restoreSub),
                                 {restore.getQuantumNodeByNodeId(NodeId(nodeNumber, restore._buildId)),
                                  restore.getQuantumNodeByNodeId(NodeId(nodeNumber2, restore._buildId))})

                # Verify an error when requesting a non-existent node number.
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, nodes=(99, ))

                # Verify that a graphID that does not match is an error.
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")
        finally:
            # Exceptions propagate unchanged; only the temp file is cleaned up.
            if uri is not None:
                os.remove(uri)

        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    @unittest.skipIf(not boto3, "Warning: boto3 AWS SDK not found!")
    @mock_s3
    def testSaveLoadUriS3(self):
        """Round-trip the graph through a mocked S3 bucket."""
        conn = boto3.resource('s3', region_name="us-east-1")
        conn.create_bucket(Bucket='testBucket')
        uri = "s3://testBucket/qgraph.qgraph"
        self.qGraph.saveUri(uri)
        restore = QuantumGraph.loadUri(uri, self.universe)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)
        restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(0, ))
        self.assertEqual(len(restoreSub), 1)
        self.assertEqual(list(restoreSub)[0],
                         restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId)))

    def testContains(self):
        """A node taken from the graph is reported as contained in it."""
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)