def _join_parameter_configs(measurement_name, left_parameters, right_parameters):
    """
    Merge the parameter config lists of two same-named measurements.

    This helper is only used by :func:`_join_measurements` when join='outer'.

    Raises:
        ~pyhf.exceptions.InvalidWorkspaceOperation: The two parameter configuration specifications conflict.

    Args:
        measurement_name (:obj:`str`): Name of the measurement being joined (only used to build the error message).
        left_parameters (:obj:`list`): The left parameter configuration specification.
        right_parameters (:obj:`list`): The right parameter configuration specification.

    Returns:
        :obj:`list`: The merged parameter configurations. Each entry follows the :obj:`defs.json#/definitions/config` schema.

    """
    merged_configs = _join_items('outer', left_parameters, right_parameters)

    # Any parameter name that occurs more than once in the outer-joined list is
    # reported as incompatible. (Presumably `_join_items` keeps only one copy of
    # exactly-identical entries, so duplicates here differ in content.)
    occurrences = collections.Counter(config['name'] for config in merged_configs)
    clashing_names = [name for name, seen in occurrences.items() if seen > 1]

    if not clashing_names:
        return merged_configs

    raise exceptions.InvalidWorkspaceOperation(
        f"Workspaces cannot have a measurement ({measurement_name}) with incompatible parameter configs: {clashing_names}. You can also try a different join operation: {Workspace.valid_joins}."
    )
def _join_channels(join, left_channels, right_channels, merge=False):
    """
    Combine the channel lists of two workspaces.

    Raises:
        ~pyhf.exceptions.InvalidWorkspaceOperation: The channel specifications cannot be combined.

    Args:
        join (:obj:`str`): The join operation to apply. See ~pyhf.workspace.Workspace for valid join operations.
        left_channels (:obj:`list`): The left channel specification.
        right_channels (:obj:`list`): The right channel specification.
        merge (:obj:`bool`): If true, same-named channels are deep-merged on their ``samples``.

    Returns:
        :obj:`list`: The combined channels. Each channel follows the :obj:`defs.json#/definitions/channel` `schema <https://scikit-hep.org/pyhf/likelihood.html#channel>`__

    """
    merge_key = 'samples' if merge else None
    combined = _join_items(join, left_channels, right_channels, deep_merge_key=merge_key)

    if join == 'none':
        # A plain 'none' join forbids any channel name from appearing on both sides.
        left_names = {chan['name'] for chan in left_channels}
        shared_names = left_names.intersection(chan['name'] for chan in right_channels)
        if shared_names:
            raise exceptions.InvalidWorkspaceOperation(
                f"Workspaces cannot have any channels in common with the same name: {shared_names}. You can also try a different join operation: {Workspace.valid_joins}."
            )
    elif join == 'outer':
        # After an outer join, a name that still occurs more than once is treated
        # as a structural conflict between the two inputs.
        name_tally = collections.Counter(chan['name'] for chan in combined)
        conflicting = [name for name, hits in name_tally.items() if hits > 1]
        if conflicting:
            raise exceptions.InvalidWorkspaceOperation(
                f"Workspaces cannot have channels in common with incompatible structure: {conflicting}. You can also try a different join operation: {Workspace.valid_joins}."
            )

    return combined
def _join_observations(join, left_observations, right_observations):
    """
    Combine the observation lists of two workspaces.

    Raises:
        ~pyhf.exceptions.InvalidWorkspaceOperation: The observation specifications cannot be combined.

    Args:
        join (:obj:`str`): The join operation to apply. See ~pyhf.workspace.Workspace for valid join operations.
        left_observations (:obj:`list`): The left observation specification.
        right_observations (:obj:`list`): The right observation specification.

    Returns:
        :obj:`list`: The combined observations. Each observation follows the :obj:`defs.json#/definitions/observation` `schema <https://scikit-hep.org/pyhf/likelihood.html#observations>`__

    """
    combined = _join_items(join, left_observations, right_observations)

    if join == 'none':
        # With join='none', no observation name may be present in both inputs.
        left_names = {entry['name'] for entry in left_observations}
        shared_names = left_names.intersection(entry['name'] for entry in right_observations)
        if shared_names:
            raise exceptions.InvalidWorkspaceOperation(
                f"Workspaces cannot have any observations in common with the same name: {shared_names}. You can also try a different join operation: {Workspace.valid_joins}."
            )
    elif join == 'outer':
        # With join='outer', a name that survives the join more than once is a conflict.
        name_tally = collections.Counter(entry['name'] for entry in combined)
        conflicting = [name for name, hits in name_tally.items() if hits > 1]
        if conflicting:
            raise exceptions.InvalidWorkspaceOperation(
                f"Workspaces cannot have observations in common with incompatible structure: {conflicting}. You can also try a different join operation: {Workspace.valid_joins}."
            )

    return combined
def _join_versions(join, left_version, right_version): """ Join two workspace versions. Raises: ~pyhf.exceptions.InvalidWorkspaceOperation: Versions are incompatible. Args: join (:obj:`str`): The join operation to apply. See ~pyhf.workspace.Workspace for valid join operations. left_version (:obj:`str`): The left workspace version. right_version (:obj:`str`): The right workspace version. Returns: :obj:`str`: The workspace version. """ if left_version != right_version: raise exceptions.InvalidWorkspaceOperation( f"Workspaces of different versions cannot be combined: {left_version} != {right_version}" ) return left_version
def _prune_and_rename(
    self,
    prune_modifiers=None,
    prune_modifier_types=None,
    prune_samples=None,
    prune_channels=None,
    prune_measurements=None,
    rename_modifiers=None,
    rename_samples=None,
    rename_channels=None,
    rename_measurements=None,
):
    """
    Return a new, pruned, renamed workspace specification. This will not modify the original workspace.

    Pruning removes pieces of the workspace whose name or type matches the
    user-provided lists. The pruned, renamed workspace must also be a valid
    workspace.

    A workspace is composed of many named components, such as channels and
    samples, as well as types of systematics (e.g. `histosys`). Components can
    be removed (pruned away) by filtering on name, or be renamed according to
    the provided :obj:`dict` mapping. Additionally, modifiers of specific types
    can be removed (pruned away).

    This function also handles specific peculiarities, such as
    renaming/removing a channel which needs to rename/remove the corresponding
    `observation`.

    Args:
        prune_modifiers: A :obj:`list` of modifiers to prune.
        prune_modifier_types: A :obj:`list` of modifier types to prune.
        prune_samples: A :obj:`list` of samples to prune.
        prune_channels: A :obj:`list` of channels to prune.
        prune_measurements: A :obj:`list` of measurements to prune.
        rename_modifiers: A :obj:`dict` mapping old modifier name to new modifier name.
        rename_samples: A :obj:`dict` mapping old sample name to new sample name.
        rename_channels: A :obj:`dict` mapping old channel name to new channel name.
        rename_measurements: A :obj:`dict` mapping old measurement name to new measurement name.

    Returns:
        ~pyhf.workspace.Workspace: A new workspace object with the specified components removed or renamed

    Raises:
        ~pyhf.exceptions.InvalidWorkspaceOperation: An item name to prune or rename does not exist in the workspace.

    """
    # avoid mutable defaults
    prune_modifiers = [] if prune_modifiers is None else prune_modifiers
    prune_modifier_types = [] if prune_modifier_types is None else prune_modifier_types
    prune_samples = [] if prune_samples is None else prune_samples
    prune_channels = [] if prune_channels is None else prune_channels
    prune_measurements = [] if prune_measurements is None else prune_measurements
    rename_modifiers = {} if rename_modifiers is None else rename_modifiers
    rename_samples = {} if rename_samples is None else rename_samples
    rename_channels = {} if rename_channels is None else rename_channels
    rename_measurements = {} if rename_measurements is None else rename_measurements

    # ``self.modifiers`` is consumed here as (name, type) pairs. Collect the types
    # from every pair: ``dict(self.modifiers).values()`` keeps only the last type
    # seen for each name, so if one name carries several types (e.g. a normsys
    # and a histosys sharing a name) a type that really exists could be rejected.
    # Both collections are built once rather than once per loop iteration.
    known_modifier_names = {name for name, _ in self.modifiers}
    known_modifier_types = {modifier_type for _, modifier_type in self.modifiers}

    for modifier_type in prune_modifier_types:
        if modifier_type not in known_modifier_types:
            raise exceptions.InvalidWorkspaceOperation(
                f"{modifier_type} is not one of the modifier types in this workspace."
            )

    for modifier_name in (*prune_modifiers, *rename_modifiers.keys()):
        if modifier_name not in known_modifier_names:
            raise exceptions.InvalidWorkspaceOperation(
                f"{modifier_name} is not one of the modifiers in this workspace."
            )

    for sample_name in (*prune_samples, *rename_samples.keys()):
        if sample_name not in self.samples:
            raise exceptions.InvalidWorkspaceOperation(
                f"{sample_name} is not one of the samples in this workspace."
            )

    for channel_name in (*prune_channels, *rename_channels.keys()):
        if channel_name not in self.channels:
            raise exceptions.InvalidWorkspaceOperation(
                f"{channel_name} is not one of the channels in this workspace."
            )

    for measurement_name in (*prune_measurements, *rename_measurements.keys()):
        if measurement_name not in self.measurement_names:
            raise exceptions.InvalidWorkspaceOperation(
                f"{measurement_name} is not one of the measurements in this workspace."
            )

    newspec = {
        'channels': [
            {
                'name': rename_channels.get(channel['name'], channel['name']),
                'samples': [
                    {
                        'name': rename_samples.get(sample['name'], sample['name']),
                        'data': sample['data'],
                        'modifiers': [
                            dict(
                                modifier,
                                name=rename_modifiers.get(
                                    modifier['name'], modifier['name']
                                ),
                            )
                            for modifier in sample['modifiers']
                            if modifier['name'] not in prune_modifiers
                            and modifier['type'] not in prune_modifier_types
                        ],
                    }
                    for sample in channel['samples']
                    if sample['name'] not in prune_samples
                ],
            }
            for channel in self['channels']
            if channel['name'] not in prune_channels
        ],
        'measurements': [
            {
                'name': rename_measurements.get(
                    measurement['name'], measurement['name']
                ),
                'config': {
                    # Parameter configs are only pruned by modifier *name*; pruning
                    # by type leaves them in place.
                    'parameters': [
                        dict(
                            parameter,
                            name=rename_modifiers.get(
                                parameter['name'], parameter['name']
                            ),
                        )
                        for parameter in measurement['config']['parameters']
                        if parameter['name'] not in prune_modifiers
                    ],
                    'poi': rename_modifiers.get(
                        measurement['config']['poi'], measurement['config']['poi']
                    ),
                },
            }
            for measurement in self['measurements']
            if measurement['name'] not in prune_measurements
        ],
        # Observations are keyed by channel name, so they follow channel pruning/renaming.
        'observations': [
            dict(
                copy.deepcopy(observation),
                name=rename_channels.get(observation['name'], observation['name']),
            )
            for observation in self['observations']
            if observation['name'] not in prune_channels
        ],
        'version': self['version'],
    }
    return Workspace(newspec)
def _join_measurements(join, left_measurements, right_measurements):
    """
    Combine the measurement lists of two workspaces.

    Raises:
        ~pyhf.exceptions.InvalidWorkspaceOperation: The measurement specifications cannot be combined.

    Args:
        join (:obj:`str`): The join operation to apply. See ~pyhf.workspace.Workspace for valid join operations.
        left_measurements (:obj:`list`): The left measurement specification.
        right_measurements (:obj:`list`): The right measurement specification.

    Returns:
        :obj:`list`: The combined measurements. Each measurement follows the :obj:`defs.json#/definitions/measurement` `schema <https://scikit-hep.org/pyhf/likelihood.html#measurements>`__

    """
    combined = _join_items(join, left_measurements, right_measurements)

    if join == 'none':
        # With join='none', no measurement name may appear in both inputs.
        left_names = {meas['name'] for meas in left_measurements}
        shared_names = left_names.intersection(meas['name'] for meas in right_measurements)
        if shared_names:
            raise exceptions.InvalidWorkspaceOperation(
                f"Workspaces cannot have any measurements in common with the same name: {shared_names}. You can also try a different join operation: {Workspace.valid_joins}."
            )
    elif join == 'outer':
        # Bucket the joined measurements by name, keeping first-seen order.
        groups = {}
        for meas in combined:
            groups.setdefault(meas['name'], []).append(meas)

        # Same-named measurements must first agree on their POI...
        poi_clashes = [
            name
            for name, members in groups.items()
            if len({member['config']['poi'] for member in members}) > 1
        ]
        if poi_clashes:
            raise exceptions.InvalidWorkspaceOperation(
                f"Workspaces cannot have the same measurements with incompatible POI: {poi_clashes}."
            )

        def _collapse(name, members):
            # ...then each name is reduced to one measurement, joining parameter configs when needed.
            if len(members) == 1:
                return members[0]
            leading_config = members[0]['config']
            return {
                'name': name,
                'config': {
                    'poi': leading_config['poi'],
                    'parameters': _join_parameter_configs(
                        name,
                        *(member['config']['parameters'] for member in members),
                    ),
                },
            }

        combined = [_collapse(name, members) for name, members in groups.items()]

    return combined