Beispiel #1
0
    def fixupQuanta(self, graph: QuantumGraph) -> QuantumGraph:
        """Add ordering edges between quanta of one task, sorted by key.

        The quanta of the task labelled ``self.taskLabel`` are grouped by
        ``self._key(quantum)``. The keys are sorted (descending when
        ``self.reverse`` is true) and every quantum of one key group gets an
        edge to every quantum of the next key group. Quanta sharing a key are
        not ordered relative to each other.

        Parameters
        ----------
        graph : `QuantumGraph`
            Graph to modify; its underlying ``graph.graph`` is edited in place.

        Returns
        -------
        graph : `QuantumGraph`
            The same object that was passed in.

        Raises
        ------
        ValueError
            Raised if no task with label ``self.taskLabel`` is found.
        """
        taskDef = graph.findTaskDefByLabel(self.taskLabel)
        if taskDef is None:
            raise ValueError(f"Cannot find task with label {self.taskLabel}")

        # Group the quanta of the task by their ordering key.
        keyQuanta = defaultdict(list)
        for q in graph.quantaForTask(taskDef):
            keyQuanta[self._key(q)].append(q)
        keys = sorted(keyQuanta, reverse=self.reverse)
        networkGraph = graph.graph

        for prev_key, key in zip(keys, keys[1:]):
            for prev_node in keyQuanta[prev_key]:
                for node in keyQuanta[key]:
                    # Remove any existing edges between the two nodes, but
                    # don't fail if there are not any. Each direction is
                    # tried on its own: with a single ``try`` a missing
                    # first edge would skip the removal of the second one.
                    for edge in ((node, prev_node), (prev_node, node)):
                        try:
                            networkGraph.remove_edge(*edge)
                        except nx.NetworkXException:
                            pass
                    networkGraph.add_edge(prev_node, node)
        return graph
Beispiel #2
0
 def testSaveLoad(self):
     """Save the graph to a temporary file and load it back, whole and
     restricted to a single node.
     """
     with tempfile.TemporaryFile(suffix=".qgraph") as stream:
         self.qGraph.save(stream)
         stream.seek(0)
         reloaded = QuantumGraph.load(stream, self.universe)
         self.assertEqual(self.qGraph, reloaded)

         # Rewind and load only the first node of the original graph.
         stream.seek(0)
         firstId = list(self.qGraph)[0].nodeId
         single = QuantumGraph.load(stream, self.universe, nodes=(firstId,))
         self.assertEqual(len(single), 1)
         onlyNode = list(single)[0]
         self.assertEqual(onlyNode, reloaded.getQuantumNodeByNodeId(firstId))
Beispiel #3
0
 def testSaveLoadUriS3(self):
     """Save the graph to an S3 URI and load it back, whole and restricted
     to a single node.
     """
     # The bucket must exist before anything can be written to it.
     s3 = boto3.resource("s3", region_name="us-east-1")
     s3.create_bucket(Bucket="testBucket")
     graphUri = "s3://testBucket/qgraph.qgraph"

     self.qGraph.saveUri(graphUri)
     reloaded = QuantumGraph.loadUri(graphUri, self.universe)
     self.assertEqual(self.qGraph, reloaded)

     # Load again, asking only for the first node of the original graph.
     firstId = list(self.qGraph)[0].nodeId
     single = QuantumGraph.loadUri(graphUri, self.universe, nodes=(firstId,))
     self.assertEqual(len(single), 1)
     onlyNode = list(single)[0]
     self.assertEqual(onlyNode, reloaded.getQuantumNodeByNodeId(firstId))
Beispiel #4
0
 def testSaveLoadUriS3(self):
     """Save the graph to an S3 URI and load it back, whole and restricted
     to node number 0.
     """
     # Test loading a quantum graph from a mock s3 store
     conn = boto3.resource('s3', region_name="us-east-1")
     conn.create_bucket(Bucket='testBucket')
     # Plain literal: there is nothing to interpolate, so no f-string.
     uri = "s3://testBucket/qgraph.qgraph"
     self.qGraph.saveUri(uri)
     restore = QuantumGraph.loadUri(uri, self.universe)
     self._cleanGraphs(self.qGraph, restore)
     self.assertEqual(self.qGraph, restore)
     # Load only node number 0 and compare it with the same node of the
     # fully restored graph.
     restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(0, ))
     self.assertEqual(len(restoreSub), 1)
     self.assertEqual(
         list(restoreSub)[0],
         restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId)))
Beispiel #5
0
 def testSaveLoad(self):
     """Save the graph to a temporary file and load it back, whole and
     restricted to node number 0.
     """
     with tempfile.TemporaryFile(suffix='.qgraph') as stream:
         self.qGraph.save(stream)
         stream.seek(0)
         reloaded = QuantumGraph.load(stream, self.universe)
         self._cleanGraphs(self.qGraph, reloaded)
         self.assertEqual(self.qGraph, reloaded)

         # Rewind and load only node number 0.
         stream.seek(0)
         single = QuantumGraph.load(stream, self.universe, nodes=(0, ))
         self.assertEqual(len(single), 1)
         expectedNode = reloaded.getQuantumNodeByNodeId(NodeId(0, reloaded._buildId))
         self.assertEqual(list(single)[0], expectedNode)
    def testMakeGraphFromPickle(self):
        """Check CmdLineFwk.makeGraph when the graph comes from a pickle.

        Only the trivial path is covered; no actual graph building happens.
        """
        framework = CmdLineFwk()
        factory = TaskFactoryMock()

        with makeTmpFile() as picklePath:

            # An empty pickled graph loads, but a warning is expected.
            emptyGraph = QuantumGraph()
            with open(picklePath, "wb") as stream:
                pickle.dump(emptyGraph, stream)
            cmdArgs = _makeArgs(qgraph=picklePath)
            with self.assertWarnsRegex(UserWarning, "QuantumGraph is empty"):
                loaded = framework.makeGraph(None, factory, cmdArgs)
            self.assertIsInstance(loaded, QuantumGraph)
            self.assertEqual(len(loaded), 0)

            # A pickle holding something other than a graph is rejected.
            with open(picklePath, "wb") as stream:
                pickle.dump({}, stream)
            cmdArgs = _makeArgs(qgraph=picklePath)
            with self.assertRaises(TypeError):
                framework.makeGraph(None, factory, cmdArgs)
Beispiel #7
0
    def makeGraph(self, pipeline, args):
        """Build a graph from command line arguments.

        The graph is either loaded from ``args.qgraph`` or built from
        ``pipeline``. A non-empty graph is optionally saved as a whole
        (``args.save_qgraph``), saved one quantum per file
        (``args.save_single_quanta``, formatted with the node number) and
        dumped in dot format (``args.qgraph_dot``).

        Parameters
        ----------
        pipeline : `~lsst.pipe.base.Pipeline`
            Pipeline, can be empty or ``None`` if graph is read from a file.
        args : `argparse.Namespace`
            Parsed command line

        Returns
        -------
        graph : `~lsst.pipe.base.QuantumGraph` or `None`
            If resulting graph is empty then `None` is returned.

        Raises
        ------
        ValueError
            Raised if a non-empty pipeline is given together with
            ``args.qgraph``.
        """
        # Reject the inconsistent combination up front, before any registry
        # access or graph loading is attempted.
        if args.qgraph and pipeline:
            raise ValueError("Pipeline must not be given when quantum graph is read from file.")

        registry, collections, run = _ButlerFactory.makeRegistryAndCollections(args)

        if args.qgraph:
            # click passes empty tuple as default value for qgraph_node_id
            nodes = args.qgraph_node_id or None
            qgraph = QuantumGraph.loadUri(args.qgraph, registry.dimensions,
                                          nodes=nodes, graphID=args.qgraph_id)
        else:
            # make execution plan (a.k.a. DAG) for pipeline
            graphBuilder = GraphBuilder(registry,
                                        skipExisting=args.skip_existing)
            qgraph = graphBuilder.makeGraph(pipeline, collections, run, args.data_query)

        # count quanta in graph and give a warning if it's empty and return None
        nQuanta = len(qgraph)
        if nQuanta == 0:
            warnings.warn("QuantumGraph is empty", stacklevel=2)
            return None

        _LOG.info("QuantumGraph contains %d quanta for %d tasks, graph ID: %r",
                  nQuanta, len(qgraph.taskGraph), qgraph.graphID)

        if args.save_qgraph:
            qgraph.saveUri(args.save_qgraph)

        if args.save_single_quanta:
            # One single-node graph per quantum; the target URI is a format
            # string that receives the node number.
            for quantumNode in qgraph:
                sqgraph = qgraph.subset(quantumNode)
                uri = args.save_single_quanta.format(quantumNode.nodeId.number)
                sqgraph.saveUri(uri)

        if args.qgraph_dot:
            graph2dot(qgraph, args.qgraph_dot)

        return qgraph
Beispiel #8
0
    def testMakeGraphFromSave(self):
        """Check CmdLineFwk.makeGraph when the graph is read from a file.

        Only the trivial path is covered; no actual graph building happens.
        """
        framework = CmdLineFwk()

        with makeTmpFile(suffix=".qgraph") as graphPath, makeSQLiteRegistry() as registryConfig:

            # A saved non-empty graph comes back as a one-quantum graph.
            savedGraph = _makeQGraph()
            with open(graphPath, "wb") as stream:
                savedGraph.save(stream)
            cmdArgs = _makeArgs(qgraph=graphPath, registryConfig=registryConfig)
            loaded = framework.makeGraph(None, cmdArgs)
            self.assertIsInstance(loaded, QuantumGraph)
            self.assertEqual(len(loaded), 1)

            # Asking for a graph ID that is not the stored one is an error.
            cmdArgs = _makeArgs(
                qgraph=graphPath,
                qgraph_id="R2-D2 is that you?",
                registryConfig=registryConfig,
            )
            with self.assertRaisesRegex(ValueError, "graphID does not match"):
                framework.makeGraph(None, cmdArgs)

            # A file that holds something other than a saved graph is
            # rejected.
            with open(graphPath, "wb") as stream:
                pickle.dump({}, stream)
            cmdArgs = _makeArgs(qgraph=graphPath, registryConfig=registryConfig)
            with self.assertRaises(ValueError):
                framework.makeGraph(None, cmdArgs)

            # Reading a saved empty graph works, but makeGraph() warns and
            # returns None.
            emptyGraph = QuantumGraph(dict())
            with open(graphPath, "wb") as stream:
                emptyGraph.save(stream)
            cmdArgs = _makeArgs(qgraph=graphPath, registryConfig=registryConfig)
            with self.assertWarnsRegex(UserWarning, "QuantumGraph is empty"):
                loaded = framework.makeGraph(None, cmdArgs)
            self.assertIs(loaded, None)
Beispiel #9
0
    def testSaveLoadUri(self):
        """Save the graph to a local file URI and load it back.

        Covers loading the whole graph, loading one and two selected nodes,
        the errors for an unknown node number and a mismatched graph ID, and
        the error for saving to a URI with an unsupported extension.
        """
        uri = None
        try:
            with tempfile.NamedTemporaryFile(delete=False,
                                             suffix=".qgraph") as tmpFile:
                uri = tmpFile.name
                self.qGraph.saveUri(uri)
                restore = QuantumGraph.loadUri(uri, self.universe)
                self._cleanGraphs(self.qGraph, restore)
                self.assertEqual(self.qGraph, restore)
                # Pick two distinct node numbers in one step. Unlike a
                # retry loop this cannot spin forever; it raises if the
                # graph has fewer than two nodes.
                nodeNumber, nodeNumber2 = random.sample(range(len(self.qGraph)), 2)
                restoreSub = QuantumGraph.loadUri(uri,
                                                  self.universe,
                                                  nodes=(nodeNumber, ),
                                                  graphID=self.qGraph._buildId)
                self.assertEqual(len(restoreSub), 1)
                self.assertEqual(
                    list(restoreSub)[0],
                    restore.getQuantumNodeByNodeId(
                        NodeId(nodeNumber, restore.graphID)))
                # verify that more than one node works
                restoreSub = QuantumGraph.loadUri(uri,
                                                  self.universe,
                                                  nodes=(nodeNumber,
                                                         nodeNumber2))
                self.assertEqual(len(restoreSub), 2)
                self.assertEqual(
                    set(restoreSub),
                    set((restore.getQuantumNodeByNodeId(
                        NodeId(nodeNumber, restore._buildId)),
                         restore.getQuantumNodeByNodeId(
                             NodeId(nodeNumber2, restore._buildId)))))
                # verify an error when requesting a non-existent node number
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, nodes=(99, ))

                # verify a graphID that does not match will be an error
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri,
                                         self.universe,
                                         graphID="NOTRIGHT")
        finally:
            # The file was created with delete=False, so remove it here
            # whether or not the assertions above passed.
            if uri is not None:
                os.remove(uri)

        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")
Beispiel #10
0
def _makeQGraph():
    """Build a minimal QuantumGraph holding a single quantum.

    The graph is only meant to be serialized by the tests; its one quantum
    is not usable for anything else.

    Returns
    -------
    qgraph : `~lsst.pipe.base.QuantumGraph`
    """
    # TaskDef needs a task name that can really be imported, so use one that
    # is sure to exist.
    taskName = "lsst.pipe.base.Struct"
    taskDef = TaskDef(taskName=taskName, config=SimpleConfig())
    quantum = Quantum(
        taskName=taskName,
        inputs={FakeTaskDef("A"): FakeDSRef("A", (1, 2))},  # type: ignore
    )
    return QuantumGraph({taskDef: {quantum}})
Beispiel #11
0
    def test_datastore_records(self):
        """Check datastore records on a freshly built and a reloaded graph."""
        with temporaryDirectory() as root:
            # A file-backed (not in-memory) datastore is requested here.
            butler, qgraph1 = simpleQGraph.makeSimpleQGraph(
                root=root, inMemory=False, makeDatastoreRecords=True
            )

            # Round-trip the graph through an in-memory buffer.
            buffer = io.BytesIO()
            qgraph1.save(buffer)
            buffer.seek(0)
            qgraph2 = QuantumGraph.load(buffer, universe=butler.dimensions)
            del buffer

            datastore_name = "FileDatastore@<butlerRoot>"
            expected_paths = ["test/add_dataset0/add_dataset0_INSTR_det0_test.pickle"]

            for qgraph in (qgraph1, qgraph2):
                self.assertEqual(len(qgraph), 5)
                for index, qnode in enumerate(qgraph):
                    quantum = qnode.quantum
                    self.assertIsNotNone(quantum.datastore_records)

                    # Every quantum but the first is expected to carry no
                    # records at all.
                    if index != 0:
                        self.assertEqual(quantum.datastore_records, {})
                        continue

                    # Only the first quantum has a pre-existing input.
                    self.assertEqual(set(quantum.datastore_records.keys()), {datastore_name})
                    per_dataset = dict(quantum.datastore_records[datastore_name].records)
                    self.assertEqual(len(per_dataset), 1)
                    _, tables = per_dataset.popitem()
                    file_records = tables["file_datastore_records"]
                    self.assertEqual([record.path for record in file_records], expected_paths)
Beispiel #12
0
 def setUp(self):
     """Build the dimension universe and the quantum graph used by the tests.

     The universe has two integer-keyed elements, ``A`` and ``B``. Each of
     the four dummy tasks (labels R, S, T, U) gets a quantum for each of
     the data IDs ``(A=1, B=2)`` and ``(A=3, B=4)``. The results are stored
     as ``self.tasks``, ``self.quantumMap``, ``self.qGraph`` and
     ``self.universe``.
     """
     storageCls = "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage"

     def intKeyedElement():
         # Return a fresh definition per call so "A" and "B" share no state.
         return {
             "keys": [{"name": "id", "type": "int"}],
             "storage": {"cls": storageCls},
         }

     universeConfig = Config(
         {
             "version": 1,
             "namespace": "pipe_base_test",
             "skypix": {
                 "common": "htm7",
                 "htm": {"class": "lsst.sphgeom.HtmPixelization", "max_level": 24},
             },
             "elements": {"A": intKeyedElement(), "B": intKeyedElement()},
             "packers": {},
         }
     )
     universe = DimensionUniverse(config=universeConfig)

     def dataIdFor(a, b):
         # A new data coordinate is built on every call.
         return DataCoordinate.standardize({"A": a, "B": b}, universe=universe)

     # Mapping of TaskDef to the set of quanta belonging to that task.
     quantumMap = {}
     tasks = []
     labelledTasks = (
         (Dummy1PipelineTask, "R"),
         (Dummy2PipelineTask, "S"),
         (Dummy3PipelineTask, "T"),
         (Dummy4PipelineTask, "U"),
     )
     for task, label in labelledTasks:
         taskConfig = task.ConfigClass()
         taskDef = TaskDef(get_full_type_name(task), taskConfig, task, label)
         tasks.append(taskDef)
         connections = taskDef.connections
         quanta = set()
         for a, b in ((1, 2), (3, 4)):
             # Init-input refs exist only for tasks that declare init inputs.
             initRefs = None
             if connections.initInputs:
                 initType = DatasetType(
                     connections.initInput.name,
                     tuple(),
                     storageClass=connections.initInput.storageClass,
                     universe=universe,
                 )
                 initRefs = [DatasetRef(initType, DataCoordinate.makeEmpty(universe))]

             inputType = DatasetType(
                 connections.input.name,
                 connections.input.dimensions,
                 storageClass=connections.input.storageClass,
                 universe=universe,
             )
             inputRefs = [DatasetRef(inputType, dataIdFor(a, b))]

             outputType = DatasetType(
                 connections.output.name,
                 connections.output.dimensions,
                 storageClass=connections.output.storageClass,
                 universe=universe,
             )
             outputRefs = [DatasetRef(outputType, dataIdFor(a, b))]

             quanta.add(
                 Quantum(
                     taskName=task.__qualname__,
                     dataId=dataIdFor(a, b),
                     taskClass=task,
                     initInputs=initRefs,
                     inputs={inputType: inputRefs},
                     outputs={outputType: outputRefs},
                 )
             )
         quantumMap[taskDef] = quanta

     self.tasks = tasks
     self.quantumMap = quantumMap
     self.qGraph = QuantumGraph(quantumMap, metadata=METADATA)
     self.universe = universe
Beispiel #13
0
class QuantumGraphTestCase(unittest.TestCase):
    """Exercise the query and manipulation methods of a quantum graph."""

    def setUp(self):
        """Build the dimension universe and the quantum graph under test.

        The universe has two integer-keyed elements, ``A`` and ``B``. Each
        of the four dummy tasks (labels R, S, T, U) gets a quantum for each
        of the data IDs ``(A=1, B=2)`` and ``(A=3, B=4)``. The results are
        stored as ``self.tasks``, ``self.quantumMap``, ``self.qGraph`` and
        ``self.universe``.
        """
        storageCls = "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage"

        def intKeyedElement():
            # Return a fresh definition per call so "A" and "B" share no
            # state.
            return {
                "keys": [{"name": "id", "type": "int"}],
                "storage": {"cls": storageCls},
            }

        universeConfig = Config(
            {
                "version": 1,
                "namespace": "pipe_base_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {"class": "lsst.sphgeom.HtmPixelization", "max_level": 24},
                },
                "elements": {"A": intKeyedElement(), "B": intKeyedElement()},
                "packers": {},
            }
        )
        universe = DimensionUniverse(config=universeConfig)

        def dataIdFor(a, b):
            # A new data coordinate is built on every call.
            return DataCoordinate.standardize({"A": a, "B": b}, universe=universe)

        # Mapping of TaskDef to the set of quanta belonging to that task.
        quantumMap = {}
        tasks = []
        labelledTasks = (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
            (Dummy4PipelineTask, "U"),
        )
        for task, label in labelledTasks:
            taskConfig = task.ConfigClass()
            taskDef = TaskDef(get_full_type_name(task), taskConfig, task, label)
            tasks.append(taskDef)
            connections = taskDef.connections
            quanta = set()
            for a, b in ((1, 2), (3, 4)):
                # Init-input refs exist only for tasks that declare init
                # inputs.
                initRefs = None
                if connections.initInputs:
                    initType = DatasetType(
                        connections.initInput.name,
                        tuple(),
                        storageClass=connections.initInput.storageClass,
                        universe=universe,
                    )
                    initRefs = [DatasetRef(initType, DataCoordinate.makeEmpty(universe))]

                inputType = DatasetType(
                    connections.input.name,
                    connections.input.dimensions,
                    storageClass=connections.input.storageClass,
                    universe=universe,
                )
                inputRefs = [DatasetRef(inputType, dataIdFor(a, b))]

                outputType = DatasetType(
                    connections.output.name,
                    connections.output.dimensions,
                    storageClass=connections.output.storageClass,
                    universe=universe,
                )
                outputRefs = [DatasetRef(outputType, dataIdFor(a, b))]

                quanta.add(
                    Quantum(
                        taskName=task.__qualname__,
                        dataId=dataIdFor(a, b),
                        taskClass=task,
                        initInputs=initRefs,
                        inputs={inputType: inputRefs},
                        outputs={outputType: outputRefs},
                    )
                )
            quantumMap[taskDef] = quanta

        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap, metadata=METADATA)
        self.universe = universe

    def testTaskGraph(self):
        """Every task definition appears in the task graph."""
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        """Every quantum from the input mapping appears in the graph."""
        quantaInGraph = {node.quantum for node in self.qGraph.graph}
        for quanta in self.quantumMap.values():
            for quantum in quanta:
                self.assertIn(quantum, quantaInGraph)

    def testGetQuantumNodeByNodeId(self):
        """Look up a node by a known ID and fail on an unknown one."""
        firstInput = tuple(self.qGraph.inputQuanta)[0]
        found = self.qGraph.getQuantumNodeByNodeId(firstInput.nodeId)
        self.assertEqual(found, firstInput)

        unknownId = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(unknownId)

    def testPickle(self):
        """A pickled and unpickled graph compares equal to the original."""
        restored: QuantumGraph = pickle.loads(pickle.dumps(self.qGraph))
        self.assertEqual(self.qGraph, restored)

    def testInputQuanta(self):
        """Input quanta are exactly those of the first and the fourth task."""
        found = {node.quantum for node in self.qGraph.inputQuanta}
        expected = self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]]
        self.assertEqual(expected, found)

    def testOutputtQuanta(self):
        """Output quanta are exactly those of the third and the fourth task."""
        found = {node.quantum for node in self.qGraph.outputQuanta}
        expected = self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]]
        self.assertEqual(expected, found)

    def testLength(self):
        """The graph holds two quanta for every task."""
        self.assertEqual(len(self.qGraph), len(self.tasks) * 2)

    def testGetQuantaForTask(self):
        """Quanta returned for a task match the input mapping."""
        for taskDef in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(taskDef), self.quantumMap[taskDef])

    def testGetNodesForTask(self):
        """Nodes returned for a task wrap the quanta of the input mapping."""
        for taskDef in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(taskDef)
            self.assertEqual({node.quantum for node in nodes}, self.quantumMap[taskDef])

    def testFindTasksWithInput(self):
        """The first task found consuming Dummy1Output is the second task."""
        consumers = tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))
        self.assertEqual(consumers[0], self.tasks[1])

    def testFindTasksWithOutput(self):
        """The task producing Dummy1Output is the first task."""
        producer = self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output"))
        self.assertEqual(producer, self.tasks[0])

    def testTaskWithDSType(self):
        """The first two tasks are the ones referencing Dummy1Output."""
        related = set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output")))
        self.assertEqual(related, set(self.tasks[:2]))

    def testFindTaskDefByName(self):
        """Lookup by task class name returns the first task first."""
        matches = self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)
        self.assertEqual(matches[0], self.tasks[0])

    def testFindTaskDefByLabel(self):
        """Lookup by label "R" returns the first task."""
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self):
        """Quanta referencing Dummy1Input are those of the first task."""
        found = self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input"))
        self.assertEqual(found, self.quantumMap[self.tasks[0]])

    def testAllDatasetTypes(self):
        """All connection names except init-outputs are reported."""
        reported = set(self.qGraph.allDatasetTypes)
        expected = {
            connection.name
            for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections)
            for connection in conClass.allConnections.values()  # type: ignore
            if not isinstance(connection, cT.InitOutput)
        }
        self.assertEqual(reported, expected)

    def testSubset(self):
        """A one-node subset keeps the quantum and the build ID."""
        firstNode = list(self.qGraph)[0]
        subset = self.qGraph.subset(firstNode)
        self.assertEqual(len(subset), 1)
        self.assertEqual(firstNode.quantum, list(subset)[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testSubsetToConnected(self):
        """Split the graph into its four connected components."""
        # False because there are two quantum chains for two distinct sets
        # of dimensions
        self.assertFalse(self.qGraph.isConnected)

        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 4)
        for index in range(4):
            self.assertTrue(connectedGraphs[index].isConnected)

        # tasks[3] is expected to be on its own, so its components hold one
        # quantum; the others hold three.
        for component in connectedGraphs:
            expectedSize = 1 if self.tasks[3] in component.taskGraph else 3
            self.assertEqual(len(component), expectedSize)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Count, over all four components, how often each quantum of the
        # full graph is found; the total must equal the graph size.
        count = 0
        for node in self.qGraph:
            for index in range(4):
                if connectedGraphs[index].checkQuantumInGraph(node.quantum):
                    count += 1
        self.assertEqual(len(self.qGraph), count)

        # Task sets keyed by task-graph size; not used further in the
        # visible part of this test.
        taskSets = {len(tg := s.taskGraph): set(tg) for s in connectedGraphs}
#!/usr/bin/env python

from lsst.pipe.base import QuantumGraph
from lsst.daf.butler import DimensionUniverse, Butler
# Butler for the repository whose registry is queried and exported from
# further below.
butler = Butler('/repo/main')
# Dimension universe passed to the graph loader.
du = DimensionUniverse()
# Quantum graph to walk; ``runner`` reads this module-level name directly.
qgraph = QuantumGraph.loadUri('/home/krughoff/public_html/data/two_ccd_processccd.qgraph', du)
# Accumulates every dataset reference collected by ``runner``.
exports = set()
def runner(nodes, exports, visited=None):
    """Collect dataset references of ``nodes`` and of all their ancestors.

    Ancestors are looked up in the module-level ``qgraph``.

    Parameters
    ----------
    nodes : iterable
        Quantum nodes to start from.
    exports : `set`
        Updated in place with every reference found in the inputs, outputs
        and init-inputs of each visited node's quantum.
    visited : `set`, optional
        Nodes that were already processed; updated in place. A new set is
        created only when `None` is given, so a caller-supplied set, even an
        empty one, is the set that gets filled.
    """
    if visited is None:
        visited = set()
    for node in nodes:
        if node in visited:
            continue
        quantum = node.quantum
        for refs in quantum.inputs.values():
            exports.update(refs)
        for refs in quantum.outputs.values():
            exports.update(refs)
        exports.update(quantum.initInputs.values())
        before = qgraph.determineAncestorsOfQuantumNode(node)
        # Mark the node before recursing so it is never processed twice.
        visited.add(node)
        if before:
            runner(before, exports, visited)
# Start from every quantum node of the 'calibrate' task and walk up through
# all of its ancestors, collecting dataset references into ``exports``.
runner(qgraph.getNodesForTask(qgraph.findTaskDefByLabel('calibrate')), exports)

# Query the collections once instead of once per exported reference; the
# result is materialized so it can be reused for every lookup.
all_collections = list(butler.registry.queryCollections())
resolved_refs = [butler.registry.findDataset(datasetType=ex.datasetType, dataId=ex.dataId,
                                             collections=all_collections) for ex in exports]

# Export the resolved datasets together with the listed calibration
# collections.
with butler.export(filename='export.yaml', directory='rsp_data_export', transfer='copy') as export:
    export.saveDatasets(resolved_refs)
    export.saveCollection("HSC/calib")
    export.saveCollection("HSC/calib/DM-28636")
    export.saveCollection("HSC/calib/gen2/20180117")
    export.saveCollection("HSC/calib/gen2/20180117/unbounded")
Beispiel #15
0
class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph
    """
    def setUp(self):
        """Build the graph under test: tasks R, S and T, two quanta each.

        The task definitions, the TaskDef to quanta mapping, the graph and
        the dimension universe are stored on the test case.
        """
        tableStorage = "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage"
        dimensionConfig = Config({
            "version": 1,
            "skypix": {
                "common": "htm7",
                "htm": {
                    "class": "lsst.sphgeom.HtmPixelization",
                    "max_level": 24,
                },
            },
            # "A" and "B" are configured identically
            "elements": {
                elementName: {
                    "keys": [{"name": "id", "type": "int"}],
                    "storage": {"cls": tableStorage},
                }
                for elementName in ("A", "B")
            },
            "packers": {},
        })
        universe = DimensionUniverse(config=dimensionConfig)

        def makeDataId(a, b):
            # Data ID over both elements of the universe built above
            return DataCoordinate.standardize({"A": a, "B": b}, universe=universe)

        # need to make a mapping of TaskDef to set of quantum
        quantumMap = {}
        tasks = []
        labelledTasks = (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
        )
        for task, label in labelledTasks:
            taskDef = TaskDef(f"__main__.{task.__qualname__}", task.ConfigClass(), task, label)
            tasks.append(taskDef)
            connections = taskDef.connections
            quanta = set()
            for a, b in ((1, 2), (3, 4)):
                # Init-input references exist only for tasks declaring them
                initRefs = None
                if connections.initInputs:
                    initInputDSType = DatasetType(
                        connections.initInput.name,
                        tuple(),
                        storageClass=connections.initInput.storageClass,
                        universe=universe)
                    initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]

                inputDSType = DatasetType(
                    connections.input.name,
                    connections.input.dimensions,
                    storageClass=connections.input.storageClass,
                    universe=universe,
                )
                inputRefs = [DatasetRef(inputDSType, makeDataId(a, b))]

                outputDSType = DatasetType(
                    connections.output.name,
                    connections.output.dimensions,
                    storageClass=connections.output.storageClass,
                    universe=universe,
                )
                outputRefs = [DatasetRef(outputDSType, makeDataId(a, b))]

                quanta.add(Quantum(
                    taskName=task.__qualname__,
                    dataId=makeDataId(a, b),
                    taskClass=task,
                    initInputs=initRefs,
                    inputs={inputDSType: inputRefs},
                    outputs={outputDSType: outputRefs},
                ))
            quantumMap[taskDef] = quanta

        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap)
        self.universe = universe

    def _cleanGraphs(self, graph1, graph2):
        """Strip the module qualification from the task names of two graphs.

        This is a hack for the unit test: the qualified name is ``__main__``
        here but is qualified to the unittest module name when restored, so
        only the part after the last dot is kept. Both graphs are updated in
        place.
        """
        for saved, loaded in zip(graph1._quanta.keys(), graph2._quanta.keys()):
            for quantum in (saved, loaded):
                quantum.taskName = quantum.taskName.rpartition('.')[2]

    def testTaskGraph(self):
        # Every TaskDef used to build the graph must appear in its task graph
        taskGraph = self.qGraph.taskGraph
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, taskGraph)

    def testGraph(self):
        # Every quantum handed to the constructor must be held by some node
        quantaInGraph = set()
        for graphNode in self.qGraph.graph:
            quantaInGraph.add(graphNode.quantum)
        for quanta in self.quantumMap.values():
            for quantum in quanta:
                self.assertIn(quantum, quantaInGraph)

    def testGetQuantumNodeByNodeId(self):
        # Looking a node up through its own id must give the node back
        firstInput = tuple(self.qGraph.inputQuanta)[0]
        found = self.qGraph.getQuantumNodeByNodeId(firstInput.nodeId)
        self.assertEqual(found, firstInput)
        # An id that carries another build id must be refused
        foreignId = NodeId(15, BuildId("alternative build Id"))
        with self.assertRaises(IncompatibleGraphError):
            self.qGraph.getQuantumNodeByNodeId(foreignId)

    def testPickle(self):
        # Round trip through pickle, then compare with the original graph
        restored: QuantumGraph = pickle.loads(pickle.dumps(self.qGraph))
        self._cleanGraphs(self.qGraph, restored)
        self.assertEqual(self.qGraph, restored)

    def testInputQuanta(self):
        # The input nodes must hold exactly the quanta of the first task
        found = {graphNode.quantum for graphNode in self.qGraph.inputQuanta}
        expected = self.quantumMap[self.tasks[0]]
        self.assertEqual(expected, found)

    def testOutputtQuanta(self):
        # The output nodes must hold exactly the quanta of the last task
        found = {graphNode.quantum for graphNode in self.qGraph.outputQuanta}
        expected = self.quantumMap[self.tasks[-1]]
        self.assertEqual(expected, found)

    def testLength(self):
        # The length of the graph under test is expected to be six
        expectedLength = 6
        self.assertEqual(len(self.qGraph), expectedLength)

    def testGetQuantaForTask(self):
        # The graph must return, per task, the quanta that task was given
        for taskDef in self.tasks:
            found = self.qGraph.getQuantaForTask(taskDef)
            expected = self.quantumMap[taskDef]
            self.assertEqual(found, expected)

    def testFindTasksWithInput(self):
        # The first task reported as consuming "Dummy1Output" must be the
        # second task of the pipeline
        consumers = tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))
        self.assertEqual(consumers[0], self.tasks[1])

    def testFindTasksWithOutput(self):
        # The task reported as producing "Dummy1Output" must be the first one
        producer = self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output"))
        self.assertEqual(producer, self.tasks[0])

    def testTaskWithDSType(self):
        # "Dummy1Output" must be associated with the first two tasks only
        users = set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output")))
        expected = set(self.tasks[:2])
        self.assertEqual(users, expected)

    def testFindTaskDefByName(self):
        # The first match for the name of the first dummy task must be the
        # first TaskDef
        matches = self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)
        self.assertEqual(matches[0], self.tasks[0])

    def testFindTaskDefByLabel(self):
        # Label "R" must resolve to the first TaskDef
        labelled = self.qGraph.findTaskDefByLabel("R")
        self.assertEqual(labelled, self.tasks[0])

    def testFindQuantaWIthDSType(self):
        # The quanta using "Dummy1Input" must be those of the first task
        found = self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input"))
        expected = self.quantumMap[self.tasks[0]]
        self.assertEqual(found, expected)

    def testAllDatasetTypes(self):
        # The graph must report the name of every connection declared by the
        # three dummy connection classes, and nothing else
        found = set(self.qGraph.allDatasetTypes)
        connectionClasses = (Dummy1Connections, Dummy2Connections, Dummy3Connections)
        expected = {
            connection.name
            for conClass in connectionClasses
            for connection in conClass.allConnections.values()  # type: ignore
        }
        self.assertEqual(found, expected)

    def testSubset(self):
        # A subset built from one node must contain only that node's
        # quantum and keep the build id of the parent graph
        firstNode = list(self.qGraph)[0]
        subset = self.qGraph.subset(firstNode)
        self.assertEqual(len(subset), 1)
        subsetNodes = list(subset)
        self.assertEqual(firstNode.quantum, subsetNodes[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testIsConnected(self):
        # The full graph holds two quantum chains, one per distinct set of
        # dimensions, so it must not be reported as connected
        self.assertFalse(self.qGraph.isConnected)
        # A subset restricted to nodes of a single chain must be connected
        graphNodes = list(self.qGraph)
        chainPair = (graphNodes[0], graphNodes[1])
        chainSubset = self.qGraph.subset(chainPair)
        self.assertTrue(chainSubset.isConnected)

    def testSubsetToConnected(self):
        # The graph must split into two different connected components of
        # three quanta each
        components = self.qGraph.subsetToConnected()
        self.assertEqual(len(components), 2)
        first, second = components[0], components[1]
        self.assertTrue(first.isConnected)
        self.assertTrue(second.isConnected)

        self.assertEqual(len(first), 3)
        self.assertEqual(len(second), 3)

        self.assertNotEqual(first, second)

        # One hit is counted per (node, component) membership; the total
        # must match the number of nodes in the full graph
        memberships = 0
        for graphNode in self.qGraph:
            for component in (first, second):
                if component.checkQuantumInGraph(graphNode.quantum):
                    memberships += 1
        self.assertEqual(len(self.qGraph), memberships)

        # Both components must span the complete list of tasks
        for component in (first, second):
            self.assertEqual(self.tasks, list(component.taskGraph))

        # The inputs of the second node must be exactly the first node; the
        # query is issued twice and must give the same answer both times
        allNodes = list(self.qGraph)
        for _ in range(2):
            inputs = self.qGraph.determineInputsToQuantumNode(allNodes[1])
            self.assertEqual({allNodes[0]}, inputs)

    def testDetermineOutputsOfQuantumNode(self):
        # The first output reported for the second node must be the third
        graphNodes = list(self.qGraph)
        outputs = self.qGraph.determineOutputsOfQuantumNode(graphNodes[1])
        firstOutput = next(iter(outputs))
        self.assertEqual(graphNodes[2], firstOutput)

    def testDetermineConnectionsOfQuantum(self):
        # The connections of the second node must match the subset made of
        # the first three nodes
        graphNodes = list(self.qGraph)
        found = list(self.qGraph.determineConnectionsOfQuantumNode(graphNodes[1]))
        expected = list(self.qGraph.subset(graphNodes[:3]))
        self.assertEqual(found, expected)

    def testDetermineAnsestorsOfQuantumNode(self):
        # The ancestors of the third node must match the subset made of the
        # first three nodes
        graphNodes = list(self.qGraph)
        found = list(self.qGraph.determineAncestorsOfQuantumNode(graphNodes[2]))
        expected = list(self.qGraph.subset(graphNodes[:3]))
        self.assertEqual(found, expected)

    def testFindCycle(self):
        # No cycle is expected in the graph under test
        cycle = self.qGraph.findCycle()
        self.assertFalse(cycle)

    def testSaveLoad(self):
        with tempfile.TemporaryFile(suffix='.qgraph') as graphFile:
            # Full round trip through the temporary file
            self.qGraph.save(graphFile)
            graphFile.seek(0)
            restored = QuantumGraph.load(graphFile, self.universe)
            self._cleanGraphs(self.qGraph, restored)
            self.assertEqual(self.qGraph, restored)
            # Rewind and load node number 0 only
            graphFile.seek(0)
            singleNode = QuantumGraph.load(graphFile, self.universe, nodes=(0, ))
            self.assertEqual(len(singleNode), 1)
            loadedNode = list(singleNode)[0]
            expectedNode = restored.getQuantumNodeByNodeId(NodeId(0, restored._buildId))
            self.assertEqual(loadedNode, expectedNode)

    def testSaveLoadUri(self):
        # Round trip through a URI: full load, one-node and two-node loads,
        # then the error paths for an unknown node number, a graph id that
        # does not match and an unsupported file extension
        uri = None
        try:
            with tempfile.NamedTemporaryFile(delete=False,
                                             suffix=".qgraph") as tmpFile:
                uri = tmpFile.name
                self.qGraph.saveUri(uri)
                restore = QuantumGraph.loadUri(uri, self.universe)
                self._cleanGraphs(self.qGraph, restore)
                self.assertEqual(self.qGraph, restore)
                # Draw two distinct random node numbers in one go; this
                # cannot spin forever the way a retry loop would on a
                # single-node graph
                nodeNumber, nodeNumber2 = random.sample(range(len(self.qGraph)), 2)
                restoreSub = QuantumGraph.loadUri(uri,
                                                  self.universe,
                                                  nodes=(nodeNumber, ),
                                                  graphID=self.qGraph._buildId)
                self.assertEqual(len(restoreSub), 1)
                self.assertEqual(
                    list(restoreSub)[0],
                    restore.getQuantumNodeByNodeId(
                        NodeId(nodeNumber, restore.graphID)))
                # verify that more than one node works
                restoreSub = QuantumGraph.loadUri(uri,
                                                  self.universe,
                                                  nodes=(nodeNumber,
                                                         nodeNumber2))
                self.assertEqual(len(restoreSub), 2)
                self.assertEqual(
                    set(restoreSub),
                    {restore.getQuantumNodeByNodeId(NodeId(nodeNumber, restore._buildId)),
                     restore.getQuantumNodeByNodeId(NodeId(nodeNumber2, restore._buildId))})
                # verify an error when requesting a non existent node number
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, nodes=(99, ))

                # verify a graphID that does not match will be an error
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri,
                                         self.universe,
                                         graphID="NOTRIGHT")
        finally:
            # The file was created with delete=False, so remove it here
            # whether or not the assertions above succeeded
            if uri is not None:
                os.remove(uri)

        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    @unittest.skipIf(not boto3, "Warning: boto3 AWS SDK not found!")
    @mock_s3
    def testSaveLoadUriS3(self):
        # Test saving a quantum graph to, and loading it from, a mock S3
        # store
        conn = boto3.resource('s3', region_name="us-east-1")
        conn.create_bucket(Bucket='testBucket')
        # Plain literal: the string has no placeholder to interpolate
        uri = "s3://testBucket/qgraph.qgraph"
        self.qGraph.saveUri(uri)
        restore = QuantumGraph.loadUri(uri, self.universe)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)
        # Load node number 0 only and compare it with the same node of the
        # fully restored graph
        restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(0, ))
        self.assertEqual(len(restoreSub), 1)
        self.assertEqual(
            list(restoreSub)[0],
            restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId)))

    def testContains(self):
        # The first node produced by iterating the graph must pass the
        # ``in`` test against that same graph
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)