示例#1
0
def _join_parameter_configs(measurement_name, left_parameters,
                            right_parameters):
    """
    Outer-join two lists of measurement parameter configurations.

    Only used by :func:`_join_measurements` when join='outer'.

    Raises:
      ~pyhf.exceptions.InvalidWorkspaceOperation: Parameter configuration specifications are incompatible.

    Args:
        measurement_name (:obj:`str`): The name of the measurement being joined (only used to build the exception message).
        left_parameters (:obj:`list`): The left parameter configuration specification.
        right_parameters (:obj:`list`): The right parameter configuration specification.

    Returns:
        :obj:`list`: A joined list of parameter configurations. Each parameter configuration follows the :obj:`defs.json#/definitions/config` schema

    """
    merged_configs = _join_items('outer', left_parameters, right_parameters)

    # A parameter name that occurs more than once in the merged list is
    # reported as incompatible. Presumably _join_items('outer', ...) already
    # collapses identical configs, so repeats mean the two sides disagree.
    occurrences = collections.Counter(cfg['name'] for cfg in merged_configs)
    conflicting = [name for name, total in occurrences.items() if total > 1]

    if not conflicting:
        return merged_configs

    raise exceptions.InvalidWorkspaceOperation(
        f"Workspaces cannot have a measurement ({measurement_name}) with incompatible parameter configs: {conflicting}. You can also try a different join operation: {Workspace.valid_joins}."
    )
示例#2
0
def _join_channels(join, left_channels, right_channels, merge=False):
    """
    Combine the channel lists of two workspaces using the given join.

    For join='none' any channel name present on both sides is rejected.
    For join='outer' any channel name appearing more than once in the joined
    result is rejected as structurally incompatible.

    Raises:
      ~pyhf.exceptions.InvalidWorkspaceOperation: Channel specifications are incompatible.

    Args:
        join (:obj:`str`): The join operation to apply. See ~pyhf.workspace.Workspace for valid join operations.
        left_channels (:obj:`list`): The left channel specification.
        right_channels (:obj:`list`): The right channel specification.
        merge (:obj:`bool`): Whether to deeply merge channels or not (forwarded to ``_join_items`` as ``deep_merge_key='samples'``).

    Returns:
        :obj:`list`: A joined list of channels. Each channel follows the :obj:`defs.json#/definitions/channel` `schema <https://scikit-hep.org/pyhf/likelihood.html#channel>`__

    """
    samples_key = 'samples' if merge else None
    joined_channels = _join_items(join,
                                  left_channels,
                                  right_channels,
                                  deep_merge_key=samples_key)

    if join == 'outer':
        name_tally = collections.Counter(ch['name'] for ch in joined_channels)
        repeated = [name for name, total in name_tally.items() if total > 1]
        if repeated:
            raise exceptions.InvalidWorkspaceOperation(
                f"Workspaces cannot have channels in common with incompatible structure: {repeated}. You can also try a different join operation: {Workspace.valid_joins}."
            )
        return joined_channels

    if join == 'none':
        left_names = {ch['name'] for ch in left_channels}
        shared = left_names.intersection(ch['name'] for ch in right_channels)
        if shared:
            raise exceptions.InvalidWorkspaceOperation(
                f"Workspaces cannot have any channels in common with the same name: {shared}. You can also try a different join operation: {Workspace.valid_joins}."
            )

    return joined_channels
示例#3
0
def _join_observations(join, left_observations, right_observations):
    """
    Combine the observation lists of two workspaces using the given join.

    For join='none' any observation name present on both sides is rejected.
    For join='outer' any observation name appearing more than once in the
    joined result is rejected as structurally incompatible.

    Raises:
      ~pyhf.exceptions.InvalidWorkspaceOperation: Observation specifications are incompatible.

    Args:
        join (:obj:`str`): The join operation to apply. See ~pyhf.workspace.Workspace for valid join operations.
        left_observations (:obj:`list`): The left observation specification.
        right_observations (:obj:`list`): The right observation specification.

    Returns:
        :obj:`list`: A joined list of observations. Each observation follows the :obj:`defs.json#/definitions/observation` `schema <https://scikit-hep.org/pyhf/likelihood.html#observations>`__

    """
    joined_observations = _join_items(join, left_observations,
                                      right_observations)

    if join == 'outer':
        name_tally = collections.Counter(
            obs['name'] for obs in joined_observations)
        repeated = [name for name, total in name_tally.items() if total > 1]
        if repeated:
            raise exceptions.InvalidWorkspaceOperation(
                f"Workspaces cannot have observations in common with incompatible structure: {repeated}. You can also try a different join operation: {Workspace.valid_joins}."
            )
        return joined_observations

    if join == 'none':
        left_names = {obs['name'] for obs in left_observations}
        shared = left_names.intersection(
            obs['name'] for obs in right_observations)
        if shared:
            raise exceptions.InvalidWorkspaceOperation(
                f"Workspaces cannot have any observations in common with the same name: {shared}. You can also try a different join operation: {Workspace.valid_joins}."
            )

    return joined_observations
示例#4
0
def _join_versions(join, left_version, right_version):
    """
    Join two workspace versions.

    Raises:
      ~pyhf.exceptions.InvalidWorkspaceOperation: Versions are incompatible.

    Args:
        join (:obj:`str`): The join operation to apply. See ~pyhf.workspace.Workspace for valid join operations.
        left_version (:obj:`str`): The left workspace version.
        right_version (:obj:`str`): The right workspace version.

    Returns:
        :obj:`str`: The workspace version.

    """
    if left_version != right_version:
        raise exceptions.InvalidWorkspaceOperation(
            f"Workspaces of different versions cannot be combined: {left_version} != {right_version}"
        )
    return left_version
示例#5
0
    def _prune_and_rename(
        self,
        prune_modifiers=None,
        prune_modifier_types=None,
        prune_samples=None,
        prune_channels=None,
        prune_measurements=None,
        rename_modifiers=None,
        rename_samples=None,
        rename_channels=None,
        rename_measurements=None,
    ):
        """
        Return a new, pruned, renamed workspace. This will not modify the original workspace.

        Pruning removes pieces of the workspace whose name or type matches the
        user-provided lists. The pruned, renamed workspace must also be a valid
        workspace.

        A workspace is composed of many named components, such as channels and
        samples, as well as types of systematics (e.g. `histosys`). Components
        can be removed (pruned away) filtering on name or be renamed according
        to the provided :obj:`dict` mapping. Additionally, modifiers of
        specific types can be removed (pruned away).

        This function also handles specific peculiarities, such as
        renaming/removing a channel which needs to rename/remove the
        corresponding `observation`.

        Args:
            prune_modifiers: A :obj:`list` of modifiers to prune.
            prune_modifier_types: A :obj:`list` of modifier types to prune.
            prune_samples: A :obj:`list` of samples to prune.
            prune_channels: A :obj:`list` of channels to prune.
            prune_measurements: A :obj:`list` of measurements to prune.
            rename_modifiers: A :obj:`dict` mapping old modifier name to new modifier name.
            rename_samples: A :obj:`dict` mapping old sample name to new sample name.
            rename_channels: A :obj:`dict` mapping old channel name to new channel name.
            rename_measurements: A :obj:`dict` mapping old measurement name to new measurement name.

        Returns:
            ~pyhf.workspace.Workspace: A new workspace object with the specified components removed or renamed

        Raises:
          ~pyhf.exceptions.InvalidWorkspaceOperation: An item name to prune or rename does not exist in the workspace.

        """
        # avoid mutable defaults
        prune_modifiers = [] if prune_modifiers is None else prune_modifiers
        prune_modifier_types = ([] if prune_modifier_types is None else
                                prune_modifier_types)
        prune_samples = [] if prune_samples is None else prune_samples
        prune_channels = [] if prune_channels is None else prune_channels
        prune_measurements = [] if prune_measurements is None else prune_measurements
        rename_modifiers = {} if rename_modifiers is None else rename_modifiers
        rename_samples = {} if rename_samples is None else rename_samples
        rename_channels = {} if rename_channels is None else rename_channels
        rename_measurements = {} if rename_measurements is None else rename_measurements

        # self.modifiers is converted into a name -> type mapping once, instead
        # of being rebuilt for every name/type validated in the loops below.
        modifier_types_by_name = dict(self.modifiers)

        for modifier_type in prune_modifier_types:
            if modifier_type not in modifier_types_by_name.values():
                raise exceptions.InvalidWorkspaceOperation(
                    f"{modifier_type} is not one of the modifier types in this workspace."
                )

        for modifier_name in (*prune_modifiers, *rename_modifiers.keys()):
            if modifier_name not in modifier_types_by_name:
                raise exceptions.InvalidWorkspaceOperation(
                    f"{modifier_name} is not one of the modifiers in this workspace."
                )

        for sample_name in (*prune_samples, *rename_samples.keys()):
            if sample_name not in self.samples:
                raise exceptions.InvalidWorkspaceOperation(
                    f"{sample_name} is not one of the samples in this workspace."
                )

        for channel_name in (*prune_channels, *rename_channels.keys()):
            if channel_name not in self.channels:
                raise exceptions.InvalidWorkspaceOperation(
                    f"{channel_name} is not one of the channels in this workspace."
                )

        for measurement_name in (*prune_measurements,
                                 *rename_measurements.keys()):
            if measurement_name not in self.measurement_names:
                raise exceptions.InvalidWorkspaceOperation(
                    f"{measurement_name} is not one of the measurements in this workspace."
                )

        newspec = {
            'channels': [
                {
                    'name': rename_channels.get(channel['name'], channel['name']),
                    'samples': [
                        {
                            'name': rename_samples.get(sample['name'], sample['name']),
                            # NOTE(review): 'data' and the modifier sub-objects below are
                            # shared with self rather than deep-copied (only observations
                            # are deep-copied); presumably Workspace(...) copies its spec
                            # — confirm before relying on the result being independent.
                            'data': sample['data'],
                            'modifiers': [
                                dict(
                                    modifier,
                                    name=rename_modifiers.get(modifier['name'],
                                                              modifier['name']),
                                )
                                for modifier in sample['modifiers']
                                if modifier['name'] not in prune_modifiers
                                and modifier['type'] not in prune_modifier_types
                            ],
                        }
                        for sample in channel['samples']
                        if sample['name'] not in prune_samples
                    ],
                }
                for channel in self['channels']
                if channel['name'] not in prune_channels
            ],
            'measurements': [
                {
                    'name': rename_measurements.get(measurement['name'],
                                                    measurement['name']),
                    'config': {
                        # Parameter configs are filtered by modifier name only;
                        # prune_modifier_types does not remove them here.
                        'parameters': [
                            dict(
                                parameter,
                                name=rename_modifiers.get(parameter['name'],
                                                          parameter['name']),
                            )
                            for parameter in measurement['config']['parameters']
                            if parameter['name'] not in prune_modifiers
                        ],
                        # The POI is renamed alongside modifiers, since it names one.
                        'poi': rename_modifiers.get(measurement['config']['poi'],
                                                    measurement['config']['poi']),
                    },
                }
                for measurement in self['measurements']
                if measurement['name'] not in prune_measurements
            ],
            # Observations are keyed by channel name, so they follow channel
            # pruning and renaming.
            'observations': [
                dict(
                    copy.deepcopy(observation),
                    name=rename_channels.get(observation['name'],
                                             observation['name']),
                )
                for observation in self['observations']
                if observation['name'] not in prune_channels
            ],
            'version': self['version'],
        }
        return Workspace(newspec)
示例#6
0
def _join_measurements(join, left_measurements, right_measurements):
    """
    Join two workspace measurement specifications.

    For join='none', measurement names shared by both sides are rejected.
    For join='outer', measurements with the same name are merged into a single
    measurement: their POIs must agree, and their parameter configurations are
    combined with :func:`_join_parameter_configs`.

    Raises:
      ~pyhf.exceptions.InvalidWorkspaceOperation: Measurement specifications are incompatible.

    Args:
        join (:obj:`str`): The join operation to apply. See ~pyhf.workspace.Workspace for valid join operations.
        left_measurements (:obj:`list`): The left measurement specification.
        right_measurements (:obj:`list`): The right measurement specification.

    Returns:
        :obj:`list`: A joined list of measurements. Each measurement follows the :obj:`defs.json#/definitions/measurement` `schema <https://scikit-hep.org/pyhf/likelihood.html#measurements>`__

    """
    joined_measurements = _join_items(join, left_measurements,
                                      right_measurements)
    if join == 'none':
        common_measurements = {meas['name']
                               for meas in left_measurements
                               }.intersection(meas['name']
                                              for meas in right_measurements)
        if common_measurements:
            raise exceptions.InvalidWorkspaceOperation(
                f"Workspaces cannot have any measurements in common with the same name: {common_measurements}. You can also try a different join operation: {Workspace.valid_joins}."
            )

    elif join == 'outer':
        # Group every measurement object under its name; dict insertion order
        # keeps the first-seen order of names in the output.
        _measurement_mapping = {}
        for measurement in joined_measurements:
            _measurement_mapping.setdefault(measurement['name'],
                                            []).append(measurement)

        # First reject same-named measurements with different POIs, then merge
        # parameter configs for the remaining groups.
        incompatible_poi = [
            measurement_name
            for measurement_name, measurements in _measurement_mapping.items()
            if len({measurement['config']['poi']
                    for measurement in measurements}) > 1
        ]
        if incompatible_poi:
            raise exceptions.InvalidWorkspaceOperation(
                f"Workspaces cannot have the same measurements with incompatible POI: {incompatible_poi}. You can also try a different join operation: {Workspace.valid_joins}."
            )

        joined_measurements = []
        for measurement_name, measurements in _measurement_mapping.items():
            if len(measurements) == 1:
                joined_measurements.append(measurements[0])
                continue

            # _join_parameter_configs joins exactly two lists, so fold the
            # group pairwise. Unpacking more than two lists into it (possible
            # when one side already has duplicate names) raised TypeError.
            parameters = measurements[0]['config']['parameters']
            for other in measurements[1:]:
                parameters = _join_parameter_configs(
                    measurement_name,
                    parameters,
                    other['config']['parameters'],
                )

            joined_measurements.append({
                'name': measurement_name,
                'config': {
                    # all POIs in the group are equal (checked above)
                    'poi': measurements[0]['config']['poi'],
                    'parameters': parameters,
                },
            })
    return joined_measurements