예제 #1
0
 def checkGraphInvariants(self, graph):
     """Assert a series of invariants on the given dimension graph.

     Parameters
     ----------
     graph
         Graph under test.  It is compared against a `DimensionGraph`
         rebuilt from ``self.universe`` and ``graph.required``, and against
         the result of decoding its own encoded form.

     Notes
     -----
     The checks run in this order: element ordering, the relationship
     between ``required``, ``implied`` and ``dimensions``, the primary key
     traversal order, and an encode/decode round trip.
     """
     ordered = list(graph.elements)
     for position, current in enumerate(ordered):
         # Graphs compare like sets, so the graph of any member element
         # must be less than or equal to the graph that contains it.
         self.assertLessEqual(current.graph, graph)
         # Element comparisons follow the ordering within a
         # DimensionUniverse (topological, with deterministic tiebreakers),
         # so everything iterated earlier sorts before this element...
         for earlier in ordered[:position]:
             self.assertLess(earlier, current)
             self.assertLessEqual(earlier, current)
         # ...and everything iterated later sorts after it.
         for later in ordered[position + 1:]:
             self.assertGreater(later, current)
             self.assertGreaterEqual(later, current)
         if isinstance(current, Dimension):
             self.assertEqual(current.graph.required, current.required)

     # The required dimensions alone must be enough to rebuild the graph.
     rebuilt = DimensionGraph(self.universe, graph.required)
     self.assertEqual(rebuilt, graph)

     # Required dimensions are exactly those that no element of the graph
     # lists among its own graph's implied dimensions.
     neverImplied = [
         dimension for dimension in graph.dimensions
         if all(dimension not in member.graph.implied
                for member in graph.elements)
     ]
     self.assertCountEqual(graph.required, neverImplied)
     self.assertCountEqual(graph.implied, graph.dimensions - graph.required)

     # ``dimensions`` holds the elements that are Dimension instances, and
     # is also the concatenation of ``required`` and ``implied``.
     dimensionElements = [
         member for member in graph.elements
         if isinstance(member, Dimension)
     ]
     self.assertCountEqual(graph.dimensions, dimensionElements)
     requiredThenImplied = itertools.chain(graph.required, graph.implied)
     self.assertCountEqual(graph.dimensions, requiredThenImplied)

     # Primary key traversal order: an element must come after everything
     # it requires, and an implied element must be implied by at least one
     # element visited so far.
     visited = NamedValueSet()
     for current in graph.primaryKeyTraversalOrder:
         with self.subTest(required=graph.required,
                           implied=graph.implied,
                           element=current):
             visited.add(current)
             self.assertLessEqual(current.graph.required, visited)
             if current in graph.implied:
                 self.assertTrue(
                     any(current in prior.implied for prior in visited))
     self.assertCountEqual(visited, graph.elements)

     # Round-trip the graph through its bytes encoding.
     payload = graph.encode()
     self.assertEqual(len(payload), self.universe.getEncodeLength())
     decoded = DimensionGraph.decode(payload, universe=self.universe)
     self.assertEqual(decoded, graph)
예제 #2
0
    def fromPipeline(cls, pipeline, *,
                     registry: Registry) -> PipelineDatasetTypes:
        """Collect the dataset types of every task in a pipeline and sort
        them into categories.

        Parameters
        ----------
        pipeline : `Pipeline` or iterable of task definitions
            A `Pipeline` is first expanded with ``toExpandedPipeline``; any
            other argument is iterated directly and each item is handed to
            `TaskDatasetTypes.fromTaskDef`.
        registry : `Registry`
            Forwarded to `TaskDatasetTypes.fromTaskDef`; its ``dimensions``
            attribute is also used as the universe when composite
            `DatasetType` objects are built here.

        Returns
        -------
        types : `PipelineDatasetTypes`
            Frozen sets of dataset types for the whole pipeline, plus a
            read-only mapping from task label to that task's
            `TaskDatasetTypes`.

        Raises
        ------
        ValueError
            Raised if a dataset type is a prerequisite for one task and a
            regular input or an output for any task, if a component input
            has different dimensions than the output that produces its
            parent, or if two sets define the same dataset type name
            differently.
        """
        inputsSeen = NamedValueSet()
        outputsSeen = NamedValueSet()
        initInputsSeen = NamedValueSet()
        initOutputsSeen = NamedValueSet()
        prerequisites = NamedValueSet()
        perTask = {}

        if isinstance(pipeline, Pipeline):
            taskDefs = pipeline.toExpandedPipeline()
        else:
            taskDefs = pipeline

        # Accumulate the per-task classification into pipeline-wide sets.
        for taskDef in taskDefs:
            taskTypes = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
            initInputsSeen |= taskTypes.initInputs
            initOutputsSeen |= taskTypes.initOutputs
            inputsSeen |= taskTypes.inputs
            prerequisites |= taskTypes.prerequisites
            outputsSeen |= taskTypes.outputs
            perTask[taskDef.label] = taskTypes

        # A prerequisite may not also appear as a regular input or as an
        # output anywhere in the pipeline.
        for others, template in (
            (inputsSeen, "{} marked as both prerequisites and regular inputs"),
            (outputsSeen, "{} marked as both prerequisites and outputs"),
        ):
            if not prerequisites.isdisjoint(others):
                raise ValueError(
                    template.format(
                        {clash.name for clash in others & prerequisites}))

        # A component input whose parent composite is produced by some task
        # is treated as an intermediate rather than an overall input.
        componentInputs = NamedValueSet()
        compositeOutputs = NamedValueSet()
        outputNameMapping = {produced.name: produced
                             for produced in outputsSeen}
        for dsType in inputsSeen:
            name, component = dsType.nameAndComponent()
            # Skip anything that is not a component, and components whose
            # parent is not among the outputs.
            if component is None or name not in outputNameMapping:
                continue
            if outputNameMapping[name].dimensions != dsType.dimensions:
                raise ValueError(
                    f"Component dataset type {dsType.name} has different "
                    f"dimensions ({dsType.dimensions}) than its parent "
                    f"({outputNameMapping[name].dimensions}).")
            composite = DatasetType(
                name,
                dsType.dimensions,
                outputNameMapping[name].storageClass,
                universe=registry.dimensions)
            componentInputs.add(dsType)
            compositeOutputs.add(composite)

        def checkConsistency(a: NamedValueSet, b: NamedValueSet):
            # Entries sharing a name in both sets must compare equal.
            for name in a.names & b.names:
                if a[name] != b[name]:
                    raise ValueError(
                        f"Conflicting definitions for dataset type: {a[name]} != {b[name]}."
                    )

        for first, second in (
            (initInputsSeen, initOutputsSeen),
            (inputsSeen, outputsSeen),
            (inputsSeen, compositeOutputs),
            (outputsSeen, compositeOutputs),
        ):
            checkConsistency(first, second)

        classified = {
            "initInputs": initInputsSeen - initOutputsSeen,
            "initIntermediates": initInputsSeen & initOutputsSeen,
            "initOutputs": initOutputsSeen - initInputsSeen,
            "inputs": inputsSeen - outputsSeen - componentInputs,
            "intermediates": inputsSeen & outputsSeen | componentInputs,
            "outputs": outputsSeen - inputsSeen - compositeOutputs,
            "prerequisites": prerequisites,
        }
        for datasetTypes in classified.values():
            datasetTypes.freeze()

        # MappingProxyType gives a read-only view of the per-task dict.
        return cls(**classified, byTask=MappingProxyType(perTask))
예제 #3
0
        def makeDatasetTypesSet(connectionType, freeze=True):
            """Build the set of `DatasetType` objects for one kind of
            connection.

            Parameters
            ----------
            connectionType : `str`
                Name of the connection type to gather; passed to
                ``iterConnections`` together with ``taskDef.connections``.
            freeze : `bool`, optional
                If `True` (the default), call `NamedValueSet.freeze` on the
                set before returning it.

            Returns
            -------
            datasetTypes : `NamedValueSet`
                One dataset type per connection yielded for
                ``connectionType``.

            Raises
            ------
            LookupError
                Raised if a connection uses the ``"skypix"`` placeholder
                dimension but the registry lookup for its dataset type
                fails.
            ValueError
                Raised if the non-skypix dimensions declared by a connection
                differ from the registry's, or if the dataset type built
                from a connection differs from the registry definition.

            Notes
            -----
            This function is a closure over the variables ``registry`` and
            ``taskDef``.
            """
            collected = NamedValueSet()
            for c in iterConnections(taskDef.connections, connectionType):
                declared = set(getattr(c, "dimensions", set()))
                if "skypix" not in declared:
                    # Component dataset types are not explicitly in the
                    # registry, so a failed lookup is not an error here;
                    # the dataset type is built from the connection alone.
                    try:
                        registryDatasetType = registry.getDatasetType(c.name)
                    except KeyError:
                        _, componentName = DatasetType.splitDatasetTypeName(
                            c.name)
                        # Only a component needs a (placeholder) parent
                        # storage class.
                        if componentName:
                            parentStorageClass = \
                                DatasetType.PlaceholderParentStorageClass
                        else:
                            parentStorageClass = None
                        datasetType = c.makeDatasetType(
                            registry.dimensions,
                            parentStorageClass=parentStorageClass)
                        # Nothing in the registry to compare against, so the
                        # check below compares the new type with itself.
                        registryDatasetType = datasetType
                    else:
                        datasetType = c.makeDatasetType(
                            registry.dimensions,
                            parentStorageClass=registryDatasetType.
                            parentStorageClass)

                    if registryDatasetType and datasetType != registryDatasetType:
                        raise ValueError(
                            f"Supplied dataset type ({datasetType}) inconsistent with "
                            f"registry definition ({registryDatasetType}) "
                            f"for {taskDef.label}.")
                else:
                    # "skypix" is a placeholder, so the registry's version
                    # of the dataset type is the one that gets used.
                    try:
                        datasetType = registry.getDatasetType(c.name)
                    except LookupError as err:
                        raise LookupError(
                            f"DatasetType '{c.name}' referenced by "
                            f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension "
                            f"placeholder, but does not already exist in the registry.  "
                            f"Note that reference catalog names are now used as the dataset "
                            f"type name instead of 'ref_cat'.") from err
                    # Compare everything except the skypix dimension(s).
                    withoutPlaceholder = declared - {"skypix"}
                    rest1 = set(
                        registry.dimensions.extract(withoutPlaceholder).names)
                    rest2 = {
                        dim.name for dim in datasetType.dimensions
                        if not isinstance(dim, SkyPixDimension)
                    }
                    if rest1 != rest2:
                        raise ValueError(
                            f"Non-skypix dimensions for dataset type {c.name} declared in "
                            f"connections ({rest1}) are inconsistent with those in "
                            f"registry's version of this dataset ({rest2}).")
                collected.add(datasetType)
            if freeze:
                collected.freeze()
            return collected