def fixupQuanta(self, graph: QuantumGraph) -> QuantumGraph:
    """Serialize the quanta of one task by chaining them in key order.

    The quanta of the task labelled ``self.taskLabel`` are grouped by
    ``self._key(quantum)``.  The group keys are sorted (descending when
    ``self.reverse`` is true), and every node of one group gets an edge to
    every node of the next group.  Any pre-existing edge between such a
    pair, in either direction, is removed first so that only the new
    ordering edge remains.

    Parameters
    ----------
    graph : `QuantumGraph`
        Graph to modify; its underlying ``graph.graph`` is updated in place.

    Returns
    -------
    graph : `QuantumGraph`
        The same ``graph`` object that was passed in.

    Raises
    ------
    ValueError
        Raised if no task with label ``self.taskLabel`` is in the graph.
    """
    taskDef = graph.findTaskDefByLabel(self.taskLabel)
    if taskDef is None:
        raise ValueError(f"Cannot find task with label {self.taskLabel}")

    keyQuanta = defaultdict(list)
    for q in graph.quantaForTask(taskDef):
        keyQuanta[self._key(q)].append(q)
    keys = sorted(keyQuanta.keys(), reverse=self.reverse)

    networkGraph = graph.graph
    for prev_key, key in zip(keys, keys[1:]):
        for prev_node in keyQuanta[prev_key]:
            for node in keyQuanta[key]:
                # Remove any existing edge between the two nodes in either
                # direction (order matters in a directed graph), but don't
                # fail if there is none.  Each direction gets its own try:
                # with a single shared try, a missing node->prev_node edge
                # raised before prev_node->node was ever attempted.
                for src, dst in ((node, prev_node), (prev_node, node)):
                    try:
                        networkGraph.remove_edge(src, dst)
                    except nx.NetworkXException:
                        pass
                networkGraph.add_edge(prev_node, node)
    return graph
def testSaveLoad(self):
    """Save the graph to a temporary file and read it back.

    The graph is first reloaded in full, then again restricted to a single
    node identified by its ``nodeId``.
    """
    with tempfile.TemporaryFile(suffix=".qgraph") as stream:
        self.qGraph.save(stream)
        stream.seek(0)
        loaded = QuantumGraph.load(stream, self.universe)
        self.assertEqual(self.qGraph, loaded)

        # Rewind and read back only the first node of the graph.
        stream.seek(0)
        firstId = list(self.qGraph)[0].nodeId
        partial = QuantumGraph.load(stream, self.universe, nodes=(firstId,))
        self.assertEqual(len(partial), 1)
        expectedNode = loaded.getQuantumNodeByNodeId(firstId)
        self.assertEqual(list(partial)[0], expectedNode)
def testSaveLoadUriS3(self):
    """Round-trip the graph through a mocked S3 bucket.

    Loads the whole graph back, then a subset containing only the first
    node, identified by its ``nodeId``.
    """
    s3 = boto3.resource("s3", region_name="us-east-1")
    s3.create_bucket(Bucket="testBucket")
    target = "s3://testBucket/qgraph.qgraph"
    self.qGraph.saveUri(target)

    loaded = QuantumGraph.loadUri(target, self.universe)
    self.assertEqual(self.qGraph, loaded)

    # Partial load of just the first node.
    firstId = list(self.qGraph)[0].nodeId
    partial = QuantumGraph.loadUri(target, self.universe, nodes=(firstId,))
    self.assertEqual(len(partial), 1)
    expectedNode = loaded.getQuantumNodeByNodeId(firstId)
    self.assertEqual(list(partial)[0], expectedNode)
def testSaveLoadUriS3(self):
    """Round-trip the graph through a mocked S3 bucket.

    Loads the whole graph back, then a subset containing only node
    number 0.
    """
    # Test loading a quantum graph from a mock s3 store
    conn = boto3.resource("s3", region_name="us-east-1")
    conn.create_bucket(Bucket="testBucket")
    # Plain string literal: the URI contains no interpolated values.
    uri = "s3://testBucket/qgraph.qgraph"
    self.qGraph.saveUri(uri)
    restore = QuantumGraph.loadUri(uri, self.universe)
    self._cleanGraphs(self.qGraph, restore)
    self.assertEqual(self.qGraph, restore)
    restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(0,))
    self.assertEqual(len(restoreSub), 1)
    self.assertEqual(
        list(restoreSub)[0],
        restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId)),
    )
def testSaveLoad(self):
    """Save the graph to a temporary file and read it back.

    The graph is reloaded in full, then again restricted to node number 0.
    """
    with tempfile.TemporaryFile(suffix='.qgraph') as stream:
        self.qGraph.save(stream)
        stream.seek(0)
        loaded = QuantumGraph.load(stream, self.universe)
        # Normalize task names so the saved and loaded graphs compare equal.
        self._cleanGraphs(self.qGraph, loaded)
        self.assertEqual(self.qGraph, loaded)

        # Rewind and read back only node number 0.
        stream.seek(0)
        partial = QuantumGraph.load(stream, self.universe, nodes=(0,))
        self.assertEqual(len(partial), 1)
        expectedNode = loaded.getQuantumNodeByNodeId(NodeId(0, loaded._buildId))
        self.assertEqual(list(partial)[0], expectedNode)
def testMakeGraphFromPickle(self):
    """Tests for CmdLineFwk.makeGraph method.

    Only the trivial path is exercised: the graph is read from a pickle
    file, so no real graph building takes place.
    """
    fwk = CmdLineFwk()
    taskFactory = TaskFactoryMock()
    with makeTmpFile() as path:

        def pickleToPath(obj):
            # Overwrite the temporary file with ``obj`` and return
            # command-line arguments that point at it.
            with open(path, "wb") as stream:
                pickle.dump(obj, stream)
            return _makeArgs(qgraph=path)

        # An empty graph loads, and makeGraph warns that it is empty.
        args = pickleToPath(QuantumGraph())
        with self.assertWarnsRegex(UserWarning, "QuantumGraph is empty"):
            loaded = fwk.makeGraph(None, taskFactory, args)
        self.assertIsInstance(loaded, QuantumGraph)
        self.assertEqual(len(loaded), 0)

        # A pickled object of the wrong type is rejected.
        args = pickleToPath({})
        with self.assertRaises(TypeError):
            fwk.makeGraph(None, taskFactory, args)
def makeGraph(self, pipeline, args):
    """Build a graph from command line arguments.

    Parameters
    ----------
    pipeline : `~lsst.pipe.base.Pipeline`
        Pipeline, can be empty or ``None`` if graph is read from a file.
    args : `argparse.Namespace`
        Parsed command line

    Returns
    -------
    graph : `~lsst.pipe.base.QuantumGraph` or `None`
        If resulting graph is empty then `None` is returned.

    Raises
    ------
    ValueError
        Raised if a non-empty ``pipeline`` is given together with
        ``args.qgraph``.
    """
    # Reject the conflicting combination before doing any expensive work
    # (registry setup, reading a possibly large or remote graph file).
    if args.qgraph and pipeline:
        raise ValueError("Pipeline must not be given when quantum graph is read from file.")

    registry, collections, run = _ButlerFactory.makeRegistryAndCollections(args)

    if args.qgraph:
        # click passes empty tuple as default value for qgraph_node_id
        nodes = args.qgraph_node_id or None
        qgraph = QuantumGraph.loadUri(args.qgraph, registry.dimensions, nodes=nodes, graphID=args.qgraph_id)
    else:
        # make execution plan (a.k.a. DAG) for pipeline
        graphBuilder = GraphBuilder(registry, skipExisting=args.skip_existing)
        qgraph = graphBuilder.makeGraph(pipeline, collections, run, args.data_query)

    # count quanta in graph and give a warning if it's empty and return None
    nQuanta = len(qgraph)
    if nQuanta == 0:
        warnings.warn("QuantumGraph is empty", stacklevel=2)
        return None

    _LOG.info("QuantumGraph contains %d quanta for %d tasks, graph ID: %r",
              nQuanta, len(qgraph.taskGraph), qgraph.graphID)

    if args.save_qgraph:
        qgraph.saveUri(args.save_qgraph)

    if args.save_single_quanta:
        # One single-node graph per quantum; the URI template receives the
        # node number.
        for quantumNode in qgraph:
            sqgraph = qgraph.subset(quantumNode)
            uri = args.save_single_quanta.format(quantumNode.nodeId.number)
            sqgraph.saveUri(uri)

    if args.qgraph_dot:
        graph2dot(qgraph, args.qgraph_dot)

    return qgraph
def testMakeGraphFromSave(self):
    """Tests for CmdLineFwk.makeGraph method.

    Covers only reading a saved graph file, so no real graph building
    takes place.
    """
    fwk = CmdLineFwk()
    with makeTmpFile(suffix=".qgraph") as tmpname, makeSQLiteRegistry() as registryConfig:

        def argsFor(**extra):
            # Command-line arguments that read the graph from ``tmpname``.
            return _makeArgs(qgraph=tmpname, registryConfig=registryConfig, **extra)

        # A saved non-empty graph is read back intact.
        graph = _makeQGraph()
        with open(tmpname, "wb") as stream:
            graph.save(stream)
        args = argsFor()
        loaded = fwk.makeGraph(None, args)
        self.assertIsInstance(loaded, QuantumGraph)
        self.assertEqual(len(loaded), 1)

        # Requesting a graph ID that differs from the file's is an error.
        args = argsFor(qgraph_id="R2-D2 is that you?")
        with self.assertRaisesRegex(ValueError, "graphID does not match"):
            fwk.makeGraph(None, args)

        # A file holding some other pickled object is rejected.
        with open(tmpname, "wb") as stream:
            pickle.dump({}, stream)
        args = argsFor()
        with self.assertRaises(ValueError):
            fwk.makeGraph(None, args)

        # A saved empty graph loads, but makeGraph() warns and returns None.
        emptyGraph = QuantumGraph({})
        with open(tmpname, "wb") as stream:
            emptyGraph.save(stream)
        args = argsFor()
        with self.assertWarnsRegex(UserWarning, "QuantumGraph is empty"):
            result = fwk.makeGraph(None, args)
        self.assertIs(result, None)
def testSaveLoadUri(self):
    """Round-trip the graph through a local file URI.

    Checks a full load, single- and two-node partial loads, and that an
    unknown node number or a mismatched graph ID raise ``ValueError``.
    Saving to a URI with an unrecognized extension must raise ``TypeError``.
    """
    uri = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
            uri = tmpFile.name
            self.qGraph.saveUri(uri)
            restore = QuantumGraph.loadUri(uri, self.universe)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)

            # Pick two distinct node numbers up front.  random.sample never
            # repeats an element, so no retry loop is needed, and it raises
            # instead of looping forever if the graph has fewer than 2 nodes.
            nodeNumber, nodeNumber2 = random.sample(range(len(self.qGraph)), 2)

            restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber,),
                                              graphID=self.qGraph._buildId)
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(
                list(restoreSub)[0],
                restore.getQuantumNodeByNodeId(NodeId(nodeNumber, restore.graphID)),
            )

            # verify that more than one node works
            restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
            self.assertEqual(len(restoreSub), 2)
            self.assertEqual(
                set(restoreSub),
                {
                    restore.getQuantumNodeByNodeId(NodeId(nodeNumber, restore._buildId)),
                    restore.getQuantumNodeByNodeId(NodeId(nodeNumber2, restore._buildId)),
                },
            )

            # verify an error when requesting a nonexistent node number
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

            # verify a graphID that does not match will be an error
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")
    finally:
        # Exceptions propagate unchanged; only ensure the file is removed.
        if uri is not None:
            os.remove(uri)

    with self.assertRaises(TypeError):
        self.qGraph.saveUri("test.notgraph")
def _makeQGraph():
    """Make a trivial QuantumGraph holding a single quantum.

    The graph is only meant to be saved and read back; its quantum is not
    usable for anything else.

    Returns
    -------
    qgraph : `~lsst.pipe.base.QuantumGraph`
    """
    # TaskDef needs a task name that can really be imported; use one that
    # is guaranteed to exist.
    taskName = "lsst.pipe.base.Struct"
    taskDef = TaskDef(taskName=taskName, config=SimpleConfig())
    quantum = Quantum(taskName=taskName, inputs={FakeTaskDef("A"): FakeDSRef("A", (1, 2))})  # type: ignore
    return QuantumGraph({taskDef: {quantum}})
def test_datastore_records(self):
    """Test for generating datastore records."""
    with temporaryDirectory() as root:
        # FileDatastore is required for this test, hence inMemory=False.
        butler, original = simpleQGraph.makeSimpleQGraph(
            root=root, inMemory=False, makeDatastoreRecords=True)

        # Round-trip the graph through an in-memory buffer.
        stream = io.BytesIO()
        original.save(stream)
        stream.seek(0)
        reloaded = QuantumGraph.load(stream, universe=butler.dimensions)
        del stream

        datastore_name = "FileDatastore@<butlerRoot>"
        expected_paths = ["test/add_dataset0/add_dataset0_INSTR_det0_test.pickle"]
        for qgraph in (original, reloaded):
            self.assertEqual(len(qgraph), 5)
            for index, qnode in enumerate(qgraph):
                quantum = qnode.quantum
                self.assertIsNotNone(quantum.datastore_records)
                if index != 0:
                    self.assertEqual(quantum.datastore_records, {})
                    continue
                # Only the first quantum has a pre-existing input.
                self.assertEqual(set(quantum.datastore_records.keys()), {datastore_name})
                by_dataset = dict(quantum.datastore_records[datastore_name].records)
                self.assertEqual(len(by_dataset), 1)
                _, by_table = by_dataset.popitem()
                file_records = by_table["file_datastore_records"]
                self.assertEqual([record.path for record in file_records], expected_paths)
def setUp(self):
    """Build a four-task quantum graph over a two-dimension test universe."""

    def tableElement():
        # A dimension element with a single integer "id" key, table-stored.
        return {
            "keys": [
                {
                    "name": "id",
                    "type": "int",
                }
            ],
            "storage": {
                "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
            },
        }

    dimensionConfig = Config(
        {
            "version": 1,
            "namespace": "pipe_base_test",
            "skypix": {
                "common": "htm7",
                "htm": {
                    "class": "lsst.sphgeom.HtmPixelization",
                    "max_level": 24,
                },
            },
            "elements": {"A": tableElement(), "B": tableElement()},
            "packers": {},
        }
    )
    universe = DimensionUniverse(config=dimensionConfig)

    def dataIdFor(a, b):
        return DataCoordinate.standardize({"A": a, "B": b}, universe=universe)

    # Map each TaskDef to its set of quanta: one quantum per (A, B) data ID.
    tasks = []
    quantumMap = {}
    taskLabels = (
        (Dummy1PipelineTask, "R"),
        (Dummy2PipelineTask, "S"),
        (Dummy3PipelineTask, "T"),
        (Dummy4PipelineTask, "U"),
    )
    for taskClass, label in taskLabels:
        taskConfig = taskClass.ConfigClass()
        taskDef = TaskDef(get_full_type_name(taskClass), taskConfig, taskClass, label)
        tasks.append(taskDef)
        conns = taskDef.connections
        quanta = set()
        for a, b in ((1, 2), (3, 4)):
            initRefs = None
            if conns.initInputs:
                initType = DatasetType(
                    conns.initInput.name,
                    tuple(),
                    storageClass=conns.initInput.storageClass,
                    universe=universe,
                )
                initRefs = [DatasetRef(initType, DataCoordinate.makeEmpty(universe))]
            inputType = DatasetType(
                conns.input.name,
                conns.input.dimensions,
                storageClass=conns.input.storageClass,
                universe=universe,
            )
            inputRefs = [DatasetRef(inputType, dataIdFor(a, b))]
            outputType = DatasetType(
                conns.output.name,
                conns.output.dimensions,
                storageClass=conns.output.storageClass,
                universe=universe,
            )
            outputRefs = [DatasetRef(outputType, dataIdFor(a, b))]
            quanta.add(
                Quantum(
                    taskName=taskClass.__qualname__,
                    dataId=dataIdFor(a, b),
                    taskClass=taskClass,
                    initInputs=initRefs,
                    inputs={inputType: inputRefs},
                    outputs={outputType: outputRefs},
                )
            )
        quantumMap[taskDef] = quanta

    self.tasks = tasks
    self.quantumMap = quantumMap
    self.qGraph = QuantumGraph(quantumMap, metadata=METADATA)
    self.universe = universe
class QuantumGraphTestCase(unittest.TestCase): """Tests the various functions of a quantum graph""" def setUp(self): config = Config( { "version": 1, "namespace": "pipe_base_test", "skypix": { "common": "htm7", "htm": { "class": "lsst.sphgeom.HtmPixelization", "max_level": 24, }, }, "elements": { "A": { "keys": [ { "name": "id", "type": "int", } ], "storage": { "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage", }, }, "B": { "keys": [ { "name": "id", "type": "int", } ], "storage": { "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage", }, }, }, "packers": {}, } ) universe = DimensionUniverse(config=config) # need to make a mapping of TaskDef to set of quantum quantumMap = {} tasks = [] for task, label in ( (Dummy1PipelineTask, "R"), (Dummy2PipelineTask, "S"), (Dummy3PipelineTask, "T"), (Dummy4PipelineTask, "U"), ): config = task.ConfigClass() taskDef = TaskDef(get_full_type_name(task), config, task, label) tasks.append(taskDef) quantumSet = set() connections = taskDef.connections for a, b in ((1, 2), (3, 4)): if connections.initInputs: initInputDSType = DatasetType( connections.initInput.name, tuple(), storageClass=connections.initInput.storageClass, universe=universe, ) initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))] else: initRefs = None inputDSType = DatasetType( connections.input.name, connections.input.dimensions, storageClass=connections.input.storageClass, universe=universe, ) inputRefs = [ DatasetRef(inputDSType, DataCoordinate.standardize({"A": a, "B": b}, universe=universe)) ] outputDSType = DatasetType( connections.output.name, connections.output.dimensions, storageClass=connections.output.storageClass, universe=universe, ) outputRefs = [ DatasetRef(outputDSType, DataCoordinate.standardize({"A": a, "B": b}, universe=universe)) ] quantumSet.add( Quantum( taskName=task.__qualname__, dataId=DataCoordinate.standardize({"A": a, "B": b}, universe=universe), taskClass=task, 
initInputs=initRefs, inputs={inputDSType: inputRefs}, outputs={outputDSType: outputRefs}, ) ) quantumMap[taskDef] = quantumSet self.tasks = tasks self.quantumMap = quantumMap self.qGraph = QuantumGraph(quantumMap, metadata=METADATA) self.universe = universe def testTaskGraph(self): for taskDef in self.quantumMap.keys(): self.assertIn(taskDef, self.qGraph.taskGraph) def testGraph(self): graphSet = {q.quantum for q in self.qGraph.graph} for quantum in chain.from_iterable(self.quantumMap.values()): self.assertIn(quantum, graphSet) def testGetQuantumNodeByNodeId(self): inputQuanta = tuple(self.qGraph.inputQuanta) node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId) self.assertEqual(node, inputQuanta[0]) wrongNode = uuid.uuid4() with self.assertRaises(KeyError): self.qGraph.getQuantumNodeByNodeId(wrongNode) def testPickle(self): stringify = pickle.dumps(self.qGraph) restore: QuantumGraph = pickle.loads(stringify) self.assertEqual(self.qGraph, restore) def testInputQuanta(self): inputs = {q.quantum for q in self.qGraph.inputQuanta} self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs) def testOutputtQuanta(self): outputs = {q.quantum for q in self.qGraph.outputQuanta} self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs) def testLength(self): self.assertEqual(len(self.qGraph), 2 * len(self.tasks)) def testGetQuantaForTask(self): for task in self.tasks: self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task]) def testGetNodesForTask(self): for task in self.tasks: nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task) quanta_in_node = set(n.quantum for n in nodes) self.assertEqual(quanta_in_node, self.quantumMap[task]) def testFindTasksWithInput(self): self.assertEqual( tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0], self.tasks[1] ) def testFindTasksWithOutput(self): 
self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0]) def testTaskWithDSType(self): self.assertEqual( set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))), set(self.tasks[:2]) ) def testFindTaskDefByName(self): self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0], self.tasks[0]) def testFindTaskDefByLabel(self): self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0]) def testFindQuantaWIthDSType(self): self.assertEqual( self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")), self.quantumMap[self.tasks[0]] ) def testAllDatasetTypes(self): allDatasetTypes = set(self.qGraph.allDatasetTypes) truth = set() for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections): for connection in conClass.allConnections.values(): # type: ignore if not isinstance(connection, cT.InitOutput): truth.add(connection.name) self.assertEqual(allDatasetTypes, truth) def testSubset(self): allNodes = list(self.qGraph) subset = self.qGraph.subset(allNodes[0]) self.assertEqual(len(subset), 1) subsetList = list(subset) self.assertEqual(allNodes[0].quantum, subsetList[0].quantum) self.assertEqual(self.qGraph._buildId, subset._buildId) def testSubsetToConnected(self): # False because there are two quantum chains for two distinct sets of # dimensions self.assertFalse(self.qGraph.isConnected) connectedGraphs = self.qGraph.subsetToConnected() self.assertEqual(len(connectedGraphs), 4) self.assertTrue(connectedGraphs[0].isConnected) self.assertTrue(connectedGraphs[1].isConnected) self.assertTrue(connectedGraphs[2].isConnected) self.assertTrue(connectedGraphs[3].isConnected) # Split out task[3] because it is expected to be on its own for cg in connectedGraphs: if self.tasks[3] in cg.taskGraph: self.assertEqual(len(cg), 1) else: self.assertEqual(len(cg), 3) self.assertNotEqual(connectedGraphs[0], connectedGraphs[1]) count = 0 for node in self.qGraph: if 
connectedGraphs[0].checkQuantumInGraph(node.quantum): count += 1 if connectedGraphs[1].checkQuantumInGraph(node.quantum): count += 1 if connectedGraphs[2].checkQuantumInGraph(node.quantum): count += 1 if connectedGraphs[3].checkQuantumInGraph(node.quantum): count += 1 self.assertEqual(len(self.qGraph), count) taskSets = {len(tg := s.taskGraph): set(tg) for s in connectedGraphs}
#!/usr/bin/env python
"""Export the datasets needed to rerun the ``calibrate`` quanta of a graph.

Every input, output and init-input dataset of the ``calibrate`` quanta and
of all their ancestor quanta is resolved against the repository. The
resolved datasets are exported together with a fixed list of calibration
collections.
"""
from lsst.daf.butler import Butler
from lsst.pipe.base import QuantumGraph

REPO = '/repo/main'
QGRAPH_URI = '/home/krughoff/public_html/data/two_ccd_processccd.qgraph'
TASK_LABEL = 'calibrate'
CALIB_COLLECTIONS = (
    "HSC/calib",
    "HSC/calib/DM-28636",
    "HSC/calib/gen2/20180117",
    "HSC/calib/gen2/20180117/unbounded",
)

butler = Butler(REPO)
# Load the graph with the repository's own dimension universe, so its data
# IDs are consistent with the registry that resolves them below.
qgraph = QuantumGraph.loadUri(QGRAPH_URI, butler.registry.dimensions)
exports = set()


def runner(nodes, exports, visited=None):
    """Add the dataset refs of ``nodes`` and all their ancestors to ``exports``.

    Parameters
    ----------
    nodes : iterable of quantum nodes
        Starting nodes from the module-level ``qgraph``.
    exports : `set`
        Updated in place with every input, output and init-input ref.
    visited : `set`, optional
        Nodes already processed; updated in place.  A caller-supplied set
        (even an empty one) is shared rather than replaced.
    """
    if visited is None:
        visited = set()
    # Explicit work stack instead of recursion, so deep pipelines cannot
    # hit the interpreter recursion limit.
    pending = list(nodes)
    while pending:
        node = pending.pop()
        if node in visited:
            continue
        visited.add(node)
        quantum = node.quantum
        for refs in quantum.inputs.values():
            exports.update(refs)
        for refs in quantum.outputs.values():
            exports.update(refs)
        exports.update(quantum.initInputs.values())
        pending.extend(qgraph.determineAncestorsOfQuantumNode(node))


taskDef = qgraph.findTaskDefByLabel(TASK_LABEL)
if taskDef is None:
    raise SystemExit(f"No task labelled {TASK_LABEL!r} in {QGRAPH_URI}")
runner(qgraph.getNodesForTask(taskDef), exports)

# Query the collection list once; the result may be a one-shot iterator, so
# materialize it before reusing it for every lookup.
collections = list(butler.registry.queryCollections())
resolved_refs = []
missing = 0
for ex in exports:
    ref = butler.registry.findDataset(datasetType=ex.datasetType, dataId=ex.dataId,
                                      collections=collections)
    # findDataset returns None for datasets absent from the repository
    # (e.g. outputs not produced yet); they cannot be exported.
    if ref is None:
        missing += 1
    else:
        resolved_refs.append(ref)
if missing:
    print(f"Skipping {missing} dataset(s) not found in {REPO}")

with butler.export(filename='export.yaml', directory='rsp_data_export', transfer='copy') as export:
    export.saveDatasets(resolved_refs)
    for collection in CALIB_COLLECTIONS:
        export.saveCollection(collection)
class QuantumGraphTestCase(unittest.TestCase): """Tests the various functions of a quantum graph """ def setUp(self): config = Config({ "version": 1, "skypix": { "common": "htm7", "htm": { "class": "lsst.sphgeom.HtmPixelization", "max_level": 24, } }, "elements": { "A": { "keys": [{ "name": "id", "type": "int", }], "storage": { "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage", }, }, "B": { "keys": [{ "name": "id", "type": "int", }], "storage": { "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage", }, } }, "packers": {} }) universe = DimensionUniverse(config=config) # need to make a mapping of TaskDef to set of quantum quantumMap = {} tasks = [] for task, label in ((Dummy1PipelineTask, "R"), (Dummy2PipelineTask, "S"), (Dummy3PipelineTask, "T")): config = task.ConfigClass() taskDef = TaskDef(f"__main__.{task.__qualname__}", config, task, label) tasks.append(taskDef) quantumSet = set() connections = taskDef.connections for a, b in ((1, 2), (3, 4)): if connections.initInputs: initInputDSType = DatasetType( connections.initInput.name, tuple(), storageClass=connections.initInput.storageClass, universe=universe) initRefs = [ DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe)) ] else: initRefs = None inputDSType = DatasetType( connections.input.name, connections.input.dimensions, storageClass=connections.input.storageClass, universe=universe, ) inputRefs = [ DatasetRef( inputDSType, DataCoordinate.standardize({ "A": a, "B": b }, universe=universe)) ] outputDSType = DatasetType( connections.output.name, connections.output.dimensions, storageClass=connections.output.storageClass, universe=universe, ) outputRefs = [ DatasetRef( outputDSType, DataCoordinate.standardize({ "A": a, "B": b }, universe=universe)) ] quantumSet.add( Quantum(taskName=task.__qualname__, dataId=DataCoordinate.standardize( { "A": a, "B": b }, universe=universe), taskClass=task, initInputs=initRefs, inputs={inputDSType: inputRefs}, 
outputs={outputDSType: outputRefs})) quantumMap[taskDef] = quantumSet self.tasks = tasks self.quantumMap = quantumMap self.qGraph = QuantumGraph(quantumMap) self.universe = universe def _cleanGraphs(self, graph1, graph2): # This is a hack for the unit test since the qualified name will be # different as it will be __main__ here, but qualified to the # unittest module name when restored # Updates in place for saved, loaded in zip(graph1._quanta.keys(), graph2._quanta.keys()): saved.taskName = saved.taskName.split('.')[-1] loaded.taskName = loaded.taskName.split('.')[-1] def testTaskGraph(self): for taskDef in self.quantumMap.keys(): self.assertIn(taskDef, self.qGraph.taskGraph) def testGraph(self): graphSet = {q.quantum for q in self.qGraph.graph} for quantum in chain.from_iterable(self.quantumMap.values()): self.assertIn(quantum, graphSet) def testGetQuantumNodeByNodeId(self): inputQuanta = tuple(self.qGraph.inputQuanta) node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId) self.assertEqual(node, inputQuanta[0]) wrongNode = NodeId(15, BuildId("alternative build Id")) with self.assertRaises(IncompatibleGraphError): self.qGraph.getQuantumNodeByNodeId(wrongNode) def testPickle(self): stringify = pickle.dumps(self.qGraph) restore: QuantumGraph = pickle.loads(stringify) self._cleanGraphs(self.qGraph, restore) self.assertEqual(self.qGraph, restore) def testInputQuanta(self): inputs = {q.quantum for q in self.qGraph.inputQuanta} self.assertEqual(self.quantumMap[self.tasks[0]], inputs) def testOutputtQuanta(self): outputs = {q.quantum for q in self.qGraph.outputQuanta} self.assertEqual(self.quantumMap[self.tasks[-1]], outputs) def testLength(self): self.assertEqual(len(self.qGraph), 6) def testGetQuantaForTask(self): for task in self.tasks: self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task]) def testFindTasksWithInput(self): self.assertEqual( tuple( self.qGraph.findTasksWithInput( DatasetTypeName("Dummy1Output")))[0], self.tasks[1]) def 
testFindTasksWithOutput(self): self.assertEqual( self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0]) def testTaskWithDSType(self): self.assertEqual( set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))), set(self.tasks[:2])) def testFindTaskDefByName(self): self.assertEqual( self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0], self.tasks[0]) def testFindTaskDefByLabel(self): self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0]) def testFindQuantaWIthDSType(self): self.assertEqual( self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")), self.quantumMap[self.tasks[0]]) def testAllDatasetTypes(self): allDatasetTypes = set(self.qGraph.allDatasetTypes) truth = set() for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections): for connection in conClass.allConnections.values(): # type: ignore truth.add(connection.name) self.assertEqual(allDatasetTypes, truth) def testSubset(self): allNodes = list(self.qGraph) subset = self.qGraph.subset(allNodes[0]) self.assertEqual(len(subset), 1) subsetList = list(subset) self.assertEqual(allNodes[0].quantum, subsetList[0].quantum) self.assertEqual(self.qGraph._buildId, subset._buildId) def testIsConnected(self): # False because there are two quantum chains for two distinct sets of # dimensions self.assertFalse(self.qGraph.isConnected) # make a broken subset allNodes = list(self.qGraph) subset = self.qGraph.subset((allNodes[0], allNodes[1])) # True because we subset to only one chain of graphs self.assertTrue(subset.isConnected) def testSubsetToConnected(self): connectedGraphs = self.qGraph.subsetToConnected() self.assertEqual(len(connectedGraphs), 2) self.assertTrue(connectedGraphs[0].isConnected) self.assertTrue(connectedGraphs[1].isConnected) self.assertEqual(len(connectedGraphs[0]), 3) self.assertEqual(len(connectedGraphs[1]), 3) self.assertNotEqual(connectedGraphs[0], connectedGraphs[1]) count = 0 for node in self.qGraph: if 
connectedGraphs[0].checkQuantumInGraph(node.quantum): count += 1 if connectedGraphs[1].checkQuantumInGraph(node.quantum): count += 1 self.assertEqual(len(self.qGraph), count) self.assertEqual(self.tasks, list(connectedGraphs[0].taskGraph)) self.assertEqual(self.tasks, list(connectedGraphs[1].taskGraph)) allNodes = list(self.qGraph) node = self.qGraph.determineInputsToQuantumNode(allNodes[1]) self.assertEqual(set([allNodes[0]]), node) node = self.qGraph.determineInputsToQuantumNode(allNodes[1]) self.assertEqual(set([allNodes[0]]), node) def testDetermineOutputsOfQuantumNode(self): allNodes = list(self.qGraph) node = next( iter(self.qGraph.determineOutputsOfQuantumNode(allNodes[1]))) self.assertEqual(allNodes[2], node) def testDetermineConnectionsOfQuantum(self): allNodes = list(self.qGraph) connections = self.qGraph.determineConnectionsOfQuantumNode( allNodes[1]) self.assertEqual(list(connections), list(self.qGraph.subset(allNodes[:3]))) def testDetermineAnsestorsOfQuantumNode(self): allNodes = list(self.qGraph) ansestors = self.qGraph.determineAncestorsOfQuantumNode(allNodes[2]) self.assertEqual(list(ansestors), list(self.qGraph.subset(allNodes[:3]))) def testFindCycle(self): self.assertFalse(self.qGraph.findCycle()) def testSaveLoad(self): with tempfile.TemporaryFile(suffix='.qgraph') as tmpFile: self.qGraph.save(tmpFile) tmpFile.seek(0) restore = QuantumGraph.load(tmpFile, self.universe) self._cleanGraphs(self.qGraph, restore) self.assertEqual(self.qGraph, restore) # Load in just one node tmpFile.seek(0) restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(0, )) self.assertEqual(len(restoreSub), 1) self.assertEqual( list(restoreSub)[0], restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId))) def testSaveLoadUri(self): uri = None try: with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile: uri = tmpFile.name self.qGraph.saveUri(uri) restore = QuantumGraph.loadUri(uri, self.universe) self._cleanGraphs(self.qGraph, restore) 
self.assertEqual(self.qGraph, restore) nodeNumber = random.randint(0, len(self.qGraph) - 1) restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, ), graphID=self.qGraph._buildId) self.assertEqual(len(restoreSub), 1) self.assertEqual( list(restoreSub)[0], restore.getQuantumNodeByNodeId( NodeId(nodeNumber, restore.graphID))) # verify that more than one node works nodeNumber2 = random.randint(0, len(self.qGraph) - 1) # ensure it is a different node number while nodeNumber2 == nodeNumber: nodeNumber2 = random.randint(0, len(self.qGraph) - 1) restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2)) self.assertEqual(len(restoreSub), 2) self.assertEqual( set(restoreSub), set((restore.getQuantumNodeByNodeId( NodeId(nodeNumber, restore._buildId)), restore.getQuantumNodeByNodeId( NodeId(nodeNumber2, restore._buildId))))) # verify an error when requesting a non existant node number with self.assertRaises(ValueError): QuantumGraph.loadUri(uri, self.universe, nodes=(99, )) # verify a graphID that does not match will be an error with self.assertRaises(ValueError): QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT") except Exception as e: raise e finally: if uri is not None: os.remove(uri) with self.assertRaises(TypeError): self.qGraph.saveUri("test.notgraph") @unittest.skipIf(not boto3, "Warning: boto3 AWS SDK not found!") @mock_s3 def testSaveLoadUriS3(self): # Test loading a quantum graph from an mock s3 store conn = boto3.resource('s3', region_name="us-east-1") conn.create_bucket(Bucket='testBucket') uri = f"s3://testBucket/qgraph.qgraph" self.qGraph.saveUri(uri) restore = QuantumGraph.loadUri(uri, self.universe) self._cleanGraphs(self.qGraph, restore) self.assertEqual(self.qGraph, restore) restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(0, )) self.assertEqual(len(restoreSub), 1) self.assertEqual( list(restoreSub)[0], restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId))) def testContains(self): 
firstNode = next(iter(self.qGraph)) self.assertIn(firstNode, self.qGraph)