Esempio n. 1
0
    def fixupQuanta(self, graph: QuantumGraph) -> QuantumGraph:
        """Impose a sequential order on the quanta of a single task.

        The quanta of the task labelled ``self.taskLabel`` are grouped by
        ``self._key``. The group keys are sorted (descending when
        ``self.reverse`` is true), and every quantum of a group is made to
        depend on every quantum of the preceding group by adding an edge
        ``prev_node -> node`` to ``graph.graph``.

        Parameters
        ----------
        graph : `QuantumGraph`
            Graph to modify; its underlying ``graph.graph`` is changed in
            place.

        Returns
        -------
        graph : `QuantumGraph`
            The same ``graph`` object that was passed in.

        Raises
        ------
        ValueError
            Raised if ``graph`` has no task with label ``self.taskLabel``.
        """
        taskDef = graph.findTaskDefByLabel(self.taskLabel)
        if taskDef is None:
            raise ValueError(f"Cannot find task with label {self.taskLabel}")
        # NOTE(review): the items yielded here are used directly as nodes of
        # graph.graph below; this assumes quantaForTask yields the same node
        # objects stored in that networkx graph -- confirm.
        keyQuanta = defaultdict(list)
        for q in graph.quantaForTask(taskDef):
            keyQuanta[self._key(q)].append(q)
        keys = sorted(keyQuanta.keys(), reverse=self.reverse)
        networkGraph = graph.graph

        for prev_key, key in zip(keys, keys[1:]):
            for prev_node in keyQuanta[prev_key]:
                for node in keyQuanta[key]:
                    # Remove any existing edge between the two nodes, in both
                    # directions, without failing when there is none (order
                    # matters in a directed graph). Each direction gets its
                    # own try so that a missing edge one way does not skip
                    # the removal of the edge the other way.
                    for edge in ((node, prev_node), (prev_node, node)):
                        try:
                            networkGraph.remove_edge(*edge)
                        except nx.NetworkXException:
                            pass
                    networkGraph.add_edge(prev_node, node)
        return graph
Esempio n. 2
0
class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph."""

    def setUp(self):
        """Build the four-task `QuantumGraph` shared by every test.

        Tasks labelled ``R``, ``S``, ``T`` and ``U`` (``Dummy1PipelineTask``
        to ``Dummy4PipelineTask``) each receive two quanta, one for the data
        ID ``{A: 1, B: 2}`` and one for ``{A: 3, B: 4}``.
        """
        # Minimal dimension universe with two integer-keyed elements, A and B.
        dimensionConfig = Config(
            {
                "version": 1,
                "namespace": "pipe_base_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe = DimensionUniverse(config=dimensionConfig)
        # Mapping of TaskDef to the set of quanta for that task.
        quantumMap = {}
        tasks = []
        for task, label in (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
            (Dummy4PipelineTask, "U"),
        ):
            # Distinct name so the dimension config above is not shadowed.
            taskConfig = task.ConfigClass()
            taskDef = TaskDef(get_full_type_name(task), taskConfig, task, label)
            tasks.append(taskDef)
            connections = taskDef.connections

            # Dataset types depend only on the task's connections, so build
            # them once per task rather than once per data ID.
            if connections.initInputs:
                initInputDSType = DatasetType(
                    connections.initInput.name,
                    tuple(),
                    storageClass=connections.initInput.storageClass,
                    universe=universe,
                )
            else:
                initInputDSType = None
            inputDSType = DatasetType(
                connections.input.name,
                connections.input.dimensions,
                storageClass=connections.input.storageClass,
                universe=universe,
            )
            outputDSType = DatasetType(
                connections.output.name,
                connections.output.dimensions,
                storageClass=connections.output.storageClass,
                universe=universe,
            )

            quantumSet = set()
            for a, b in ((1, 2), (3, 4)):
                # The quantum and its input/output refs share one data ID.
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                if initInputDSType is not None:
                    initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None
                quantumSet.add(
                    Quantum(
                        taskName=task.__qualname__,
                        dataId=dataId,
                        taskClass=task,
                        initInputs=initRefs,
                        inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                        outputs={outputDSType: [DatasetRef(outputDSType, dataId)]},
                    )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap, metadata=METADATA)
        self.universe = universe

    def testTaskGraph(self):
        for taskDef in self.quantumMap.keys():
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        # A freshly generated UUID cannot belong to any node in the graph.
        wrongNode = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs)

    def testOutputtQuanta(self):
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs)

    def testLength(self):
        # setUp creates two quanta per task.
        self.assertEqual(len(self.qGraph), 2 * len(self.tasks))

    def testGetQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNodesForTask(self):
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self):
        self.assertEqual(
            tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0], self.tasks[1]
        )

    def testFindTasksWithOutput(self):
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))), set(self.tasks[:2])
        )

    def testFindTaskDefByName(self):
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0], self.tasks[0])

    def testFindTaskDefByLabel(self):
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self):
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")), self.quantumMap[self.tasks[0]]
        )

    def testAllDatasetTypes(self):
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                # InitOutput connections are not expected among the graph's
                # dataset types.
                if not isinstance(connection, cT.InitOutput):
                    truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testSubsetToConnected(self):
        # False because there are two quantum chains for two distinct sets of
        # dimensions (plus the quanta of task U, see below)
        self.assertFalse(self.qGraph.isConnected)

        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 4)
        for cg in connectedGraphs:
            self.assertTrue(cg.isConnected)

        # Quanta of task[3] (U) are both inputs and outputs of the full graph
        # (see testInputQuanta/testOutputtQuanta), so each is expected to sit
        # in a subgraph on its own; the remaining subgraphs are R -> S -> T
        # chains of three quanta.
        for cg in connectedGraphs:
            if self.tasks[3] in cg.taskGraph:
                self.assertEqual(len(cg), 1)
                self.assertEqual(set(cg.taskGraph), {self.tasks[3]})
            else:
                self.assertEqual(len(cg), 3)
                self.assertEqual(set(cg.taskGraph), set(self.tasks[:3]))

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # The subgraphs partition the full graph: every quantum must appear in
        # exactly one of them.
        count = 0
        for node in self.qGraph:
            containing = [cg for cg in connectedGraphs if cg.checkQuantumInGraph(node.quantum)]
            self.assertEqual(len(containing), 1)
            count += len(containing)
        self.assertEqual(len(self.qGraph), count)
Esempio n. 3
0
class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph."""

    def setUp(self):
        """Build the three-task `QuantumGraph` shared by every test.

        Tasks labelled ``R``, ``S`` and ``T`` (``Dummy1PipelineTask`` to
        ``Dummy3PipelineTask``) each receive two quanta, one for the data ID
        ``{A: 1, B: 2}`` and one for ``{A: 3, B: 4}``.
        """
        # Minimal dimension universe with two integer-keyed elements, A and B.
        dimensionConfig = Config({
            "version": 1,
            "skypix": {
                "common": "htm7",
                "htm": {
                    "class": "lsst.sphgeom.HtmPixelization",
                    "max_level": 24,
                }
            },
            "elements": {
                "A": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls":
                        "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                },
                "B": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls":
                        "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                }
            },
            "packers": {}
        })
        universe = DimensionUniverse(config=dimensionConfig)
        # Mapping of TaskDef to the set of quanta for that task.
        quantumMap = {}
        tasks = []
        for task, label in ((Dummy1PipelineTask, "R"),
                            (Dummy2PipelineTask, "S"),
                            (Dummy3PipelineTask, "T")):
            # Distinct name so the dimension config above is not shadowed.
            taskConfig = task.ConfigClass()
            taskDef = TaskDef(f"__main__.{task.__qualname__}", taskConfig, task,
                              label)
            tasks.append(taskDef)
            quantumSet = set()
            connections = taskDef.connections
            for a, b in ((1, 2), (3, 4)):
                # The quantum and its input/output refs share one data ID.
                dataId = DataCoordinate.standardize({"A": a, "B": b},
                                                    universe=universe)
                if connections.initInputs:
                    initInputDSType = DatasetType(
                        connections.initInput.name,
                        tuple(),
                        storageClass=connections.initInput.storageClass,
                        universe=universe)
                    initRefs = [
                        DatasetRef(initInputDSType,
                                   DataCoordinate.makeEmpty(universe))
                    ]
                else:
                    initRefs = None
                inputDSType = DatasetType(
                    connections.input.name,
                    connections.input.dimensions,
                    storageClass=connections.input.storageClass,
                    universe=universe,
                )
                outputDSType = DatasetType(
                    connections.output.name,
                    connections.output.dimensions,
                    storageClass=connections.output.storageClass,
                    universe=universe,
                )
                quantumSet.add(
                    Quantum(taskName=task.__qualname__,
                            dataId=dataId,
                            taskClass=task,
                            initInputs=initRefs,
                            inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                            outputs={outputDSType: [DatasetRef(outputDSType, dataId)]}))
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap)
        self.universe = universe

    def _cleanGraphs(self, graph1, graph2):
        """Strip module qualifiers from quantum task names, in place.

        This is a hack for the unit test since the qualified name will be
        different as it will be __main__ here, but qualified to the unittest
        module name when restored. Both graphs are modified.
        """
        # NOTE(review): quanta are paired by position in ``_quanta``; this
        # assumes both graphs iterate their quanta in the same order.
        for saved, loaded in zip(graph1._quanta.keys(), graph2._quanta.keys()):
            saved.taskName = saved.taskName.split('.')[-1]
            loaded.taskName = loaded.taskName.split('.')[-1]

    def testTaskGraph(self):
        for taskDef in self.quantumMap.keys():
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        # A NodeId carrying a different build id must be rejected.
        wrongNode = NodeId(15, BuildId("alternative build Id"))
        with self.assertRaises(IncompatibleGraphError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]], inputs)

    def testOutputtQuanta(self):
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[-1]], outputs)

    def testLength(self):
        # Three tasks with two quanta each.
        self.assertEqual(len(self.qGraph), 6)

    def testGetQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task),
                             self.quantumMap[task])

    def testFindTasksWithInput(self):
        self.assertEqual(
            tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0],
            self.tasks[1])

    def testFindTasksWithOutput(self):
        self.assertEqual(
            self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")),
            self.tasks[0])

    def testTaskWithDSType(self):
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))),
            set(self.tasks[:2]))

    def testFindTaskDefByName(self):
        self.assertEqual(
            self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0],
            self.tasks[0])

    def testFindTaskDefByLabel(self):
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self):
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")),
            self.quantumMap[self.tasks[0]])

    def testAllDatasetTypes(self):
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections,
                         Dummy3Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testIsConnected(self):
        # False because there are two quantum chains for two distinct sets of
        # dimensions
        self.assertFalse(self.qGraph.isConnected)
        # Subset to the first two nodes of the graph
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset((allNodes[0], allNodes[1]))
        # True because we subset to only one chain of graphs
        self.assertTrue(subset.isConnected)

    def testSubsetToConnected(self):
        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 2)
        for cg in connectedGraphs:
            self.assertTrue(cg.isConnected)
            # Each chain holds one quantum of each of the three tasks.
            self.assertEqual(len(cg), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        count = 0
        for node in self.qGraph:
            for cg in connectedGraphs:
                if cg.checkQuantumInGraph(node.quantum):
                    count += 1
        self.assertEqual(len(self.qGraph), count)

        self.assertEqual(self.tasks, list(connectedGraphs[0].taskGraph))
        self.assertEqual(self.tasks, list(connectedGraphs[1].taskGraph))
        allNodes = list(self.qGraph)
        node = self.qGraph.determineInputsToQuantumNode(allNodes[1])
        self.assertEqual({allNodes[0]}, node)

    def testDetermineOutputsOfQuantumNode(self):
        allNodes = list(self.qGraph)
        node = next(iter(self.qGraph.determineOutputsOfQuantumNode(allNodes[1])))
        self.assertEqual(allNodes[2], node)

    def testDetermineConnectionsOfQuantum(self):
        allNodes = list(self.qGraph)
        connections = self.qGraph.determineConnectionsOfQuantumNode(allNodes[1])
        self.assertEqual(list(connections),
                         list(self.qGraph.subset(allNodes[:3])))

    def testDetermineAnsestorsOfQuantumNode(self):
        allNodes = list(self.qGraph)
        ancestors = self.qGraph.determineAncestorsOfQuantumNode(allNodes[2])
        self.assertEqual(list(ancestors),
                         list(self.qGraph.subset(allNodes[:3])))

    def testFindCycle(self):
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        with tempfile.TemporaryFile(suffix='.qgraph') as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(0, ))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(
                list(restoreSub)[0],
                restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId)))

    def testSaveLoadUri(self):
        uri = None
        try:
            with tempfile.NamedTemporaryFile(delete=False,
                                             suffix=".qgraph") as tmpFile:
                uri = tmpFile.name
                self.qGraph.saveUri(uri)
                restore = QuantumGraph.loadUri(uri, self.universe)
                self._cleanGraphs(self.qGraph, restore)
                self.assertEqual(self.qGraph, restore)

                # Two distinct node numbers: the first is loaded on its own,
                # then both are loaded together.
                nodeNumber, nodeNumber2 = random.sample(range(len(self.qGraph)), 2)
                restoreSub = QuantumGraph.loadUri(uri,
                                                  self.universe,
                                                  nodes=(nodeNumber, ),
                                                  graphID=self.qGraph._buildId)
                self.assertEqual(len(restoreSub), 1)
                self.assertEqual(
                    list(restoreSub)[0],
                    restore.getQuantumNodeByNodeId(
                        NodeId(nodeNumber, restore.graphID)))
                # verify that more than one node works
                restoreSub = QuantumGraph.loadUri(uri,
                                                  self.universe,
                                                  nodes=(nodeNumber,
                                                         nodeNumber2))
                self.assertEqual(len(restoreSub), 2)
                self.assertEqual(
                    set(restoreSub),
                    {
                        restore.getQuantumNodeByNodeId(
                            NodeId(nodeNumber, restore._buildId)),
                        restore.getQuantumNodeByNodeId(
                            NodeId(nodeNumber2, restore._buildId)),
                    })
                # verify an error when requesting a non existent node number
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, nodes=(99, ))

                # verify a graphID that does not match will be an error
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri,
                                         self.universe,
                                         graphID="NOTRIGHT")
        finally:
            # The temporary file was created with delete=False, so it has to
            # be removed explicitly whether or not the assertions passed.
            if uri is not None:
                os.remove(uri)

        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    @unittest.skipIf(not boto3, "Warning: boto3 AWS SDK not found!")
    @mock_s3
    def testSaveLoadUriS3(self):
        # Test loading a quantum graph from an mock s3 store
        conn = boto3.resource('s3', region_name="us-east-1")
        conn.create_bucket(Bucket='testBucket')
        uri = "s3://testBucket/qgraph.qgraph"
        self.qGraph.saveUri(uri)
        restore = QuantumGraph.loadUri(uri, self.universe)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)
        restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(0, ))
        self.assertEqual(len(restoreSub), 1)
        self.assertEqual(
            list(restoreSub)[0],
            restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId)))

    def testContains(self):
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)