Example #1
 def testOperators(self):
     ab = NamedValueSet({self.a, self.b})
     bc = NamedValueSet({self.b, self.c})
     self.checkOperator(ab & bc, {self.b})
     self.checkOperator(ab | bc, {self.a, self.b, self.c})
     self.checkOperator(ab ^ bc, {self.a, self.c})
     self.checkOperator(ab - bc, {self.a})
Example #2
 def testConstruction(self):
     for arg in ({self.a, self.b}, (self.a, self.b)):
         for nvs in (NamedValueSet(arg), NamedValueSet(arg).freeze()):
             self.assertEqual(len(nvs), 2)
             self.assertEqual(nvs.names, {"a", "b"})
             self.assertCountEqual(nvs, {self.a, self.b})
             self.assertCountEqual(nvs.asMapping().items(),
                                   [(self.a.name, self.a),
                                    (self.b.name, self.b)])
Example #3
 def checkGraphInvariants(self, graph):
     elements = list(graph.elements)
     for n, element in enumerate(elements):
         # Ordered comparisons on graphs behave like sets.
         self.assertLessEqual(element.graph, graph)
         # Ordered comparisons on elements correspond to the ordering within
         # a DimensionUniverse (topological, with deterministic
         # tiebreakers).
         for other in elements[:n]:
             self.assertLess(other, element)
             self.assertLessEqual(other, element)
         for other in elements[n + 1:]:
             self.assertGreater(other, element)
             self.assertGreaterEqual(other, element)
         if isinstance(element, Dimension):
             self.assertEqual(element.graph.required, element.required)
     self.assertEqual(DimensionGraph(self.universe, graph.required), graph)
     self.assertCountEqual(graph.required, [
         dimension for dimension in graph.dimensions
         if not any(dimension in other.graph.implied
                    for other in graph.elements)
     ])
     self.assertCountEqual(graph.implied, graph.dimensions - graph.required)
     self.assertCountEqual(graph.dimensions, [
         element
         for element in graph.elements if isinstance(element, Dimension)
     ])
     self.assertCountEqual(graph.dimensions,
                           itertools.chain(graph.required, graph.implied))
     # Check primary key traversal order: each element should follow any
     # elements it requires, and any element that is implied by another in
     # the graph should follow at least one of those.
     seen = NamedValueSet()
     for element in graph.primaryKeyTraversalOrder:
         with self.subTest(required=graph.required,
                           implied=graph.implied,
                           element=element):
             seen.add(element)
             self.assertLessEqual(element.graph.required, seen)
             if element in graph.implied:
                 self.assertTrue(any(element in s.implied for s in seen))
     self.assertCountEqual(seen, graph.elements)
     # Test encoding and decoding of DimensionGraphs to bytes.
     encoded = graph.encode()
     self.assertEqual(len(encoded), self.universe.getEncodeLength())
     self.assertEqual(
         DimensionGraph.decode(encoded, universe=self.universe), graph)
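
The required/implied bookkeeping that checkGraphInvariants asserts is plain set algebra. Below is a minimal, self-contained sketch of the same partition using ordinary Python sets and hypothetical element names; it illustrates the invariant only, not the real DimensionGraph machinery.

# Hypothetical graph: each element name maps to the set of dimension names it
# implies; nothing here implies "instrument" or "visit".
implied_by_element = {
    "instrument": set(),
    "visit": {"physical_filter"},
    "physical_filter": set(),
}
dimensions = set(implied_by_element)

# A dimension is "implied" if any element in the graph implies it; the rest
# are "required", mirroring the assertCountEqual checks above.
implied = {d for d in dimensions
           if any(d in s for s in implied_by_element.values())}
required = dimensions - implied

assert required == {"instrument", "visit"}
assert implied == {"physical_filter"}
assert required | implied == dimensions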
Example #4
    def refresh(self, get_dataset_type: Callable[[int], DatasetType]) -> None:
        """Load all collection summary information from the database.

        Parameters
        ----------
        get_dataset_type : `Callable`
            Function that takes an `int` dataset_type_id value and returns a
            `DatasetType` instance.
        """
        # Set up the SQL query we'll use to fetch all of the summary
        # information at once.
        columns = [
            self._tables.datasetType.columns[self._collectionKeyName].label(
                self._collectionKeyName),
            self._tables.datasetType.columns.dataset_type_id.label(
                "dataset_type_id"),
        ]
        fromClause = self._tables.datasetType
        for dimension, table in self._tables.dimensions.items():
            columns.append(table.columns[dimension.name].label(dimension.name))
            fromClause = fromClause.join(
                table,
                onclause=(
                    self._tables.datasetType.columns[self._collectionKeyName]
                    == table.columns[self._collectionKeyName]),
                isouter=True,
            )
        sql = sqlalchemy.sql.select(columns).select_from(fromClause)
        # Run the query and construct CollectionSummary objects from the result
        # rows.  This will never include CHAINED collections or collections
        # with no datasets.
        summaries: Dict[Any, CollectionSummary] = {}
        for row in self._db.query(sql):
            # Collection key should never be None/NULL; it's what we join on.
            # Extract that and then turn it into a collection name.
            collectionKey = row[self._collectionKeyName]
            # dataset_type_id should also never be None/NULL; it's in the first
            # table we joined.
            datasetType = get_dataset_type(row["dataset_type_id"])
            # See if we have a summary already for this collection; if not,
            # make one.
            summary = summaries.get(collectionKey)
            if summary is None:
                summary = CollectionSummary(
                    datasetTypes=NamedValueSet([datasetType]),
                    dimensions=GovernorDimensionRestriction.makeEmpty(
                        self._dimensions.universe),
                )
                summaries[collectionKey] = summary
            else:
                summary.datasetTypes.add(datasetType)
            # Update the dimensions with the values in this row that aren't
            # None/NULL (many will be None in general, because these columns
            # enter the query via LEFT OUTER JOIN).
            for dimension in self._tables.dimensions:
                value = row[dimension.name]
                if value is not None:
                    summary.dimensions.add(dimension, value)
        self._cache = summaries
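
refresh only needs get_dataset_type to map an integer dataset_type_id to a `DatasetType`. A minimal sketch of such a callback follows, assuming a hypothetical dataset_types_by_id mapping loaded elsewhere from the dataset_type table; the commented manager variable is likewise a stand-in, not part of the real API.

from functools import lru_cache

# Hypothetical mapping {dataset_type_id: DatasetType}, assumed to be filled in
# elsewhere from the dataset_type table.
dataset_types_by_id: dict = {}

@lru_cache(maxsize=None)
def get_dataset_type(dataset_type_id: int):
    # Cache lookups so repeated result rows for the same dataset type are cheap.
    return dataset_types_by_id[dataset_type_id]

# manager.refresh(get_dataset_type)  # 'manager' is a hypothetical instance of
#                                    # the class that defines refresh().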
Example #5
 def testGetItem(self):
     nvs = NamedValueSet({self.a, self.b, self.c})
     self.assertEqual(nvs["a"], self.a)
     self.assertEqual(nvs[self.a], self.a)
     self.assertEqual(nvs["b"], self.b)
     self.assertEqual(nvs[self.b], self.b)
     self.assertIn("a", nvs)
     self.assertIn(self.b, nvs)
Example #6
def _makeTableSpecs(
        datasets: Type[DatasetRecordStorageManager]) -> _TablesTuple:
    """Construct specifications for tables used by the monolithic datastore
    bridge classes.

    Parameters
    ----------
    datasets : subclass of `DatasetRecordStorageManager`
        Manager class for datasets; used only to create foreign key fields.

    Returns
    -------
    specs : `_TablesTuple`
        A named tuple containing `ddl.TableSpec` instances.
    """
    # We want the dataset_location and dataset_location_trash tables
    # to have the same definition, aside from the behavior of their link
    # to the dataset table: the trash table has no foreign key constraint.
    dataset_location_spec = ddl.TableSpec(
        doc=
        ("A table that provides information on whether a dataset is stored in "
         "one or more Datastores.  The presence or absence of a record in this "
         "table itself indicates whether the dataset is present in that "
         "Datastore. "),
        fields=NamedValueSet([
            ddl.FieldSpec(
                name="datastore_name",
                dtype=sqlalchemy.String,
                length=256,
                primaryKey=True,
                nullable=False,
                doc="Name of the Datastore this entry corresponds to.",
            ),
        ]),
    )
    dataset_location = copy.deepcopy(dataset_location_spec)
    datasets.addDatasetForeignKey(dataset_location, primaryKey=True)
    dataset_location_trash = copy.deepcopy(dataset_location_spec)
    datasets.addDatasetForeignKey(dataset_location_trash,
                                  primaryKey=True,
                                  constraint=False)
    return _TablesTuple(
        dataset_location=dataset_location,
        dataset_location_trash=dataset_location_trash,
    )
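
The deepcopy-then-mutate pattern above keeps the two table specs from sharing mutable field collections. The sketch below demonstrates the same pattern with a stand-in dataclass rather than the real ddl.TableSpec.

import copy
from dataclasses import dataclass, field
from typing import List

@dataclass
class FakeTableSpec:
    # Stand-in for ddl.TableSpec: just a mutable list of field names.
    fields: List[str] = field(default_factory=list)

base = FakeTableSpec(fields=["datastore_name"])

# Copy first, then mutate each copy independently; mutating 'base' directly
# (or reusing one shallow copy) would make the two specs share a fields list.
with_fk = copy.deepcopy(base)
with_fk.fields.append("dataset_id (with FK constraint)")
without_fk = copy.deepcopy(base)
without_fk.fields.append("dataset_id (no FK constraint)")

assert base.fields == ["datastore_name"]
assert with_fk.fields != without_fk.fields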
Example #7
 def testEquality(self):
     s = {self.a, self.b, self.c}
     nvs = NamedValueSet(s)
     self.assertEqual(nvs, s)
     self.assertEqual(s, nvs)
Example #8
 def testNoNameConstruction(self):
     with self.assertRaises(AttributeError):
         NamedValueSet([self.a, "a"])
Example #9
 def frozen(s: NamedValueSet) -> NamedValueSet:
     s.freeze()
     return s
Example #10
    @classmethod
    def fromPipeline(cls, pipeline, *,
                     registry: Registry) -> PipelineDatasetTypes:
        """Extract and classify the dataset types from all tasks in a
        `Pipeline`.

        Parameters
        ----------
        pipeline : `Pipeline`
            An ordered collection of tasks that can be run together.
        registry : `Registry`
            Registry used to construct normalized `DatasetType` objects and
            retrieve those that are incomplete.

        Returns
        -------
        types : `PipelineDatasetTypes`
            The dataset types used by this `Pipeline`.

        Raises
        ------
        ValueError
            Raised if Tasks are inconsistent about which datasets are marked
            prerequisite.  This indicates that the Tasks cannot be run as part
            of the same `Pipeline`.
        """
        allInputs = NamedValueSet()
        allOutputs = NamedValueSet()
        allInitInputs = NamedValueSet()
        allInitOutputs = NamedValueSet()
        prerequisites = NamedValueSet()
        byTask = dict()
        if isinstance(pipeline, Pipeline):
            pipeline = pipeline.toExpandedPipeline()
        for taskDef in pipeline:
            thisTask = TaskDatasetTypes.fromTaskDef(taskDef, registry=registry)
            allInitInputs |= thisTask.initInputs
            allInitOutputs |= thisTask.initOutputs
            allInputs |= thisTask.inputs
            prerequisites |= thisTask.prerequisites
            allOutputs |= thisTask.outputs
            byTask[taskDef.label] = thisTask
        if not prerequisites.isdisjoint(allInputs):
            raise ValueError(
                "{} marked as both prerequisites and regular inputs".format(
                    {dt.name
                     for dt in allInputs & prerequisites}))
        if not prerequisites.isdisjoint(allOutputs):
            raise ValueError(
                "{} marked as both prerequisites and outputs".format(
                    {dt.name
                     for dt in allOutputs & prerequisites}))
        # Make sure that components which are marked as inputs get treated as
        # intermediates if there is an output which produces the composite
        # containing the component.
        intermediateComponents = NamedValueSet()
        intermediateComposites = NamedValueSet()
        outputNameMapping = {dsType.name: dsType for dsType in allOutputs}
        for dsType in allInputs:
            # get the name of a possible component
            name, component = dsType.nameAndComponent()
            # If there is a component name, this is a component DatasetType;
            # if there is an output which produces the parent of this
            # component, treat this input as an intermediate.
            if component is not None:
                if name in outputNameMapping:
                    if outputNameMapping[name].dimensions != dsType.dimensions:
                        raise ValueError(
                            f"Component dataset type {dsType.name} has different "
                            f"dimensions ({dsType.dimensions}) than its parent "
                            f"({outputNameMapping[name].dimensions}).")
                    composite = DatasetType(
                        name,
                        dsType.dimensions,
                        outputNameMapping[name].storageClass,
                        universe=registry.dimensions)
                    intermediateComponents.add(dsType)
                    intermediateComposites.add(composite)

        def checkConsistency(a: NamedValueSet, b: NamedValueSet):
            common = a.names & b.names
            for name in common:
                if a[name] != b[name]:
                    raise ValueError(
                        f"Conflicting definitions for dataset type: {a[name]} != {b[name]}."
                    )

        checkConsistency(allInitInputs, allInitOutputs)
        checkConsistency(allInputs, allOutputs)
        checkConsistency(allInputs, intermediateComposites)
        checkConsistency(allOutputs, intermediateComposites)

        def frozen(s: NamedValueSet) -> NamedValueSet:
            s.freeze()
            return s

        return cls(
            initInputs=frozen(allInitInputs - allInitOutputs),
            initIntermediates=frozen(allInitInputs & allInitOutputs),
            initOutputs=frozen(allInitOutputs - allInitInputs),
            inputs=frozen(allInputs - allOutputs - intermediateComponents),
            intermediates=frozen(allInputs & allOutputs
                                 | intermediateComponents),
            outputs=frozen(allOutputs - allInputs - intermediateComposites),
            prerequisites=frozen(prerequisites),
            byTask=MappingProxyType(
                byTask
            ),  # MappingProxyType -> frozen view of dict for immutability
        )
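
The set algebra in the final return statement can be checked with ordinary sets. The sketch below uses hypothetical dataset type names and ignores the component/composite handling for brevity.

# Hypothetical dataset type names gathered from two tasks.
allInputs = {"raw", "calexp"}
allOutputs = {"calexp", "coadd"}

inputs = allInputs - allOutputs         # consumed but never produced
intermediates = allInputs & allOutputs  # produced by one task, consumed by another
outputs = allOutputs - allInputs        # produced but never consumed

assert inputs == {"raw"}
assert intermediates == {"calexp"}
assert outputs == {"coadd"}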
Example #11
        def makeDatasetTypesSet(connectionType, freeze=True):
            """Constructs a set of true `DatasetType` objects

            Parameters
            ----------
            connectionType : `str`
                Name of the connection type to produce a set for; corresponds
                to an attribute of type `list` on the connection class
                instance.
            freeze : `bool`, optional
                If `True`, call `NamedValueSet.freeze` on the object returned.

            Returns
            -------
            datasetTypes : `NamedValueSet`
                A set of all datasetTypes which correspond to the input
                connection type specified in the connection class of this
                `PipelineTask`

            Notes
            -----
            This function is a closure over the variables ``registry`` and
            ``taskDef``.
            """
            datasetTypes = NamedValueSet()
            for c in iterConnections(taskDef.connections, connectionType):
                dimensions = set(getattr(c, 'dimensions', set()))
                if "skypix" in dimensions:
                    try:
                        datasetType = registry.getDatasetType(c.name)
                    except LookupError as err:
                        raise LookupError(
                            f"DatasetType '{c.name}' referenced by "
                            f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension "
                            f"placeholder, but does not already exist in the registry.  "
                            f"Note that reference catalog names are now used as the dataset "
                            f"type name instead of 'ref_cat'.") from err
                    rest1 = set(
                        registry.dimensions.extract(dimensions -
                                                    set(["skypix"])).names)
                    rest2 = set(dim.name for dim in datasetType.dimensions
                                if not isinstance(dim, SkyPixDimension))
                    if rest1 != rest2:
                        raise ValueError(
                            f"Non-skypix dimensions for dataset type {c.name} declared in "
                            f"connections ({rest1}) are inconsistent with those in "
                            f"registry's version of this dataset ({rest2}).")
                else:
                    # Component dataset types are not explicitly in the
                    # registry.  This complicates consistency checks with
                    # registry and requires we work out the composite storage
                    # class.
                    registryDatasetType = None
                    try:
                        registryDatasetType = registry.getDatasetType(c.name)
                    except KeyError:
                        compositeName, componentName = DatasetType.splitDatasetTypeName(
                            c.name)
                        parentStorageClass = DatasetType.PlaceholderParentStorageClass \
                            if componentName else None
                        datasetType = c.makeDatasetType(
                            registry.dimensions,
                            parentStorageClass=parentStorageClass)
                        registryDatasetType = datasetType
                    else:
                        datasetType = c.makeDatasetType(
                            registry.dimensions,
                            parentStorageClass=registryDatasetType.parentStorageClass)

                    if registryDatasetType and datasetType != registryDatasetType:
                        raise ValueError(
                            f"Supplied dataset type ({datasetType}) inconsistent with "
                            f"registry definition ({registryDatasetType}) "
                            f"for {taskDef.label}.")
                datasetTypes.add(datasetType)
            if freeze:
                datasetTypes.freeze()
            return datasetTypes
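
The "skypix" branch reduces to comparing two sets of non-skypix dimension names. A stand-alone sketch with hypothetical dimension names (no Registry involved) follows.

# Dimensions declared on a connection, using the "skypix" placeholder, versus
# the dimensions on the registry's version of the dataset type, where the
# placeholder has been resolved to a concrete pixelization.
declared = {"skypix", "band"}
registry_dims = {"htm7", "band"}
skypix_dims = {"htm7"}  # the registry dimensions that are SkyPixDimensions

rest1 = declared - {"skypix"}
rest2 = registry_dims - skypix_dims
if rest1 != rest2:
    raise ValueError(f"Non-skypix dimensions {rest1} != {rest2}")

assert rest1 == rest2 == {"band"}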