def _output_template(self, fname):
    """ Dump the template parameters of every active plugin to a YAML file.

    Parameters returned by ``__get_template_params`` are stored under the
    key path ``[position, plugin name, ...]``, where position is the 1-based
    index of the plugin in ``self.plist.plugin_list``.  Entries with a truthy
    ``isyaml`` flag get one extra level: the ``isyaml`` value itself for
    local parameters, or ``'all'`` otherwise.  The remaining non-local
    parameters go under a top-level ``'all'`` key, which is merged into the
    local tree just before the result is written.

    :param str fname: path of the YAML file to write (it is overwritten).
    """
    plugins = self.plist.plugin_list
    active_positions = [pos for pos, entry in enumerate(plugins)
                        if entry['active']]
    local_tree = MetaData(ordered=True)
    global_tree = MetaData(ordered=True)

    for pos in active_positions:
        entry = plugins[pos]
        template_params = self.__get_template_params(entry['data'], [])
        plugin_name = entry['name']
        step = pos + 1
        for ptype, isyaml, key, value in template_params:
            is_local = ptype == 'local'
            if isyaml:
                scope = isyaml if is_local else 'all'
                local_tree.set([step, plugin_name, scope, key], value)
            elif is_local:
                local_tree.set([step, plugin_name, key], value)
            else:
                global_tree.set(['all', plugin_name, key], value)

    with open(fname, 'w') as stream:
        # merge the shared ('all') entries into the local tree, then dump it
        local_tree.get_dictionary().update(global_tree.get_dictionary())
        yu.dump_yaml(local_tree.get_dictionary(), stream)
def _output_template(self, fname, process_fname):
    """ Dump the template parameters of every active plugin to a YAML file.

    The absolute path of ``process_fname`` is recorded first under the
    ``'process_list'`` key.  Parameters returned by
    ``__get_template_params`` are then stored under the key path
    ``[position, plugin name, ...]``, where position is the 1-based index of
    the plugin in ``self.plist.plugin_list``.  Entries with a truthy
    ``isyaml`` flag get one extra level: the ``isyaml`` value itself for
    local parameters, or ``'all'`` otherwise.  The remaining non-local
    parameters go under a top-level ``'all'`` key, which is merged into the
    local tree just before the result is written.

    :param str fname: path of the YAML file to write (it is overwritten).
    :param str process_fname: process list file referenced by the template.
    """
    plugins = self.plist.plugin_list
    active_positions = [pos for pos, entry in enumerate(plugins)
                        if entry['active']]
    local_tree = MetaData(ordered=True)
    global_tree = MetaData(ordered=True)
    local_tree.set(['process_list'], os.path.abspath(process_fname))

    for pos in active_positions:
        entry = plugins[pos]
        template_params = self.__get_template_params(entry['data'], [])
        plugin_name = entry['name']
        step = pos + 1
        for ptype, isyaml, key, value in template_params:
            is_local = ptype == 'local'
            if isyaml:
                scope = isyaml if is_local else 'all'
                local_tree.set([step, plugin_name, scope, key], value)
            elif is_local:
                local_tree.set([step, plugin_name, key], value)
            else:
                global_tree.set(['all', plugin_name, key], value)

    with open(fname, 'w') as stream:
        # merge the shared ('all') entries into the local tree, then dump it
        local_tree.get_dictionary().update(global_tree.get_dictionary())
        yu.dump_yaml(local_tree.get_dictionary(), stream)
class PluginData(object):
    """ The PluginData class contains plugin specific information about a Data
    object for the duration of a plugin. An instance of the class is
    encapsulated inside the Data object during the plugin run
    """

    def __init__(self, data_obj, plugin=None):
        self.data_obj = data_obj
        self._preview = None
        self.data_obj._set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = None
        self.pad_dict = None
        self.shape = None
        self.shape_transfer = None
        self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []
        self._plugin = plugin
        self.fixed_dims = True
        self.split = None
        self.boundary_padding = None
        self.no_squeeze = False
        self.pre_tuning_shape = None
        self._frame_limit = None

    def _get_preview(self):
        """ Return the stored preview object (initialised to ``None``). """
        return self._preview

    def get_total_frames(self):
        """ Get the total number of frames to process (all MPI processes).

        :returns: Number of frames
        :rtype: int
        """
        temp = 1
        slice_dir = \
            self.data_obj.get_data_patterns()[
                self.get_pattern_name()]["slice_dims"]
        for tslice in slice_dir:
            temp *= self.data_obj.get_shape()[tslice]
        return temp

    def __set_pattern(self, name):
        """ Set the pattern related information in the meta data dict.

        :param str name: key into ``data_obj.get_data_patterns()``.
        """
        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set("name", name)
        self.meta_data.set("core_dims", pattern['core_dims'])
        self.__set_slice_dimensions()

    def get_pattern_name(self):
        """ Get the pattern name.

        :returns: the pattern name
        :rtype: str
        :raises Exception: if the pattern name has not been set.
        """
        try:
            name = self.meta_data.get("name")
            return name
        except KeyError:
            raise Exception("The pattern name has not been set.")

    def get_pattern(self):
        """ Get the current pattern.

        :returns: dict of the pattern name against the pattern.
        :rtype: dict
        """
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def __set_shape(self):
        """ Set the shape of the plugin data processing chunk.

        The chunk holds the core dimensions in ascending dimension order.
        When more than one frame is processed at a time (or squeezing is
        disabled), the first slice dimension, sized to max_frames_process,
        is inserted at its sorted position.
        """
        core_dir = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()
        # sorted(set(...)) states the ascending ordering explicitly instead of
        # relying on the iteration order of a set.
        dirs = sorted(set(core_dir + (slice_dir[0], )))
        slice_idx = dirs.index(slice_dir[0])
        dshape = self.data_obj.get_shape()
        shape = [dshape[core] for core in sorted(set(core_dir))]
        self.__set_core_shape(tuple(shape))
        mfp = self._get_max_frames_process()
        if mfp > 1 or self._get_no_squeeze():
            shape.insert(slice_idx, mfp)
        self.shape = tuple(shape)

    def _set_shape_transfer(self, slice_size):
        """ Set the shape of the data transferred each time.

        :param slice_size: number of frames transferred in each slice
            dimension (any sequence). It is padded with 1s for the extra
            dimensions added by parameter tuning.
        """
        dshape = self.data_obj.get_shape()
        shape_before_tuning = self._get_shape_before_tuning()
        add = [1] * (len(dshape) - len(shape_before_tuning))
        # list() so tuples are accepted as well as lists
        slice_size = list(slice_size) + add
        core_dir = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()
        shape = [None] * len(dshape)
        for dim in core_dir:
            shape[dim] = dshape[dim]
        for i, dim in enumerate(slice_dir):
            shape[dim] = slice_size[i]
        self.shape_transfer = tuple(shape)

    def get_shape(self):
        """ Get the shape of the data (without padding) that is passed to the
        plugin process_frames method.
        """
        return self.shape

    def _set_padded_shape(self):
        pass

    def get_padded_shape(self):
        """ Get the shape of the data (with padding) that is passed to the
        plugin process_frames method.
        """
        return self.shape

    def get_shape_transfer(self):
        """ Get the shape of the plugin data to be transferred each time.
        """
        return self.shape_transfer

    def __set_core_shape(self, shape):
        """ Set the core shape to hold only the shape of the core dimensions
        """
        self.core_shape = shape

    def get_core_shape(self):
        """ Get the shape of the core dimensions only.

        :returns: shape of core dimensions
        :rtype: tuple
        """
        return self.core_shape

    def _set_shape_before_tuning(self, shape):
        """ Set the shape of the full dataset used during each run of the \
        plugin (i.e. ignore extra dimensions due to parameter tuning). """
        self.pre_tuning_shape = shape

    def _get_shape_before_tuning(self):
        """ Return the shape of the full dataset used during each run of the \
        plugin (i.e. ignore extra dimensions due to parameter tuning). """
        return self.pre_tuning_shape if self.pre_tuning_shape else\
            self.data_obj.get_shape()

    def __check_dimensions(self, indices, core_dir, slice_dir, nDims):
        """ Exit if the indices or dimension counts are inconsistent. """
        # Compare by value: ``is not`` tests object identity and only works
        # by accident for CPython's cached small integers.
        if len(indices) != len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")
        if (len(core_dir) + len(slice_dir)) != nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def __set_slice_dimensions(self):
        """ Set the slice dimensions in the pluginData meta data dictionary.
        """
        slice_dirs = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['slice_dims']
        self.meta_data.set('slice_dims', slice_dirs)

    def get_slice_dimension(self):
        """
        Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()[0]
        return sorted(set(core_dirs + (slice_dir, ))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """
        Return the dimension of the data in the plugin that has the specified
        axis label.
        """
        label_dim = self.data_obj.get_data_dimension_by_axis_label(
            label, contains=contains)
        plugin_dims = tuple(self.data_obj.get_core_dimensions())
        if self._get_max_frames_process() > 1 or self.max_frames == 'multiple':
            # label_dim is a *data* dimension, so add the data dimension of
            # the first slice dim. get_slice_dimension() is a position inside
            # the plugin chunk and must not be mixed in here.
            plugin_dims += (self.data_obj.get_slice_dimensions()[0], )
        return sorted(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):
        """
        Reorder the slice dimensions. The fastest changing slice dimension
        will always be the first one stated in the pattern key
        ``slice_dims``. The input param is a tuple stating the desired order
        of slicing dimensions relative to the current order.
        """
        slice_dirs = self.data_obj.get_slice_dimensions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specifed.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        # This version stores slice dims under the 'slice_dims' pattern key
        # (see __set_slice_dimensions). Update both the pattern and the
        # cached meta data entry.
        pattern = self.data_obj.get_data_patterns()[self.get_pattern_name()]
        pattern['slice_dims'] = new_slice_dirs
        self.meta_data.set('slice_dims', new_slice_dirs)

    def get_core_dimensions(self):
        """
        Return the position of the core dimensions in relation to the data
        handed to the plugin.
        """
        core_dims = self.data_obj.get_core_dimensions()
        first_slice_dim = (self.data_obj.get_slice_dimensions()[0], )
        plugin_dims = np.sort(core_dims + first_slice_dim)
        return np.searchsorted(plugin_dims, np.sort(core_dims))

    def set_fixed_dimensions(self, dims, values):
        """ Fix a data direction to the index in values list.

        :param list(int) dims: Directions to fix
        :param list(int) value: Index of fixed directions
        """
        slice_dirs = self.data_obj.get_slice_dimensions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set("fixed_dimensions", dims)
        self.meta_data.set("fixed_dimensions_values", values)
        self.__set_slice_dimensions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        self.__set_shape()

    def _get_fixed_dimensions(self):
        """ Get the fixed data directions and their indices

        :returns: Fixed directions and their associated values
        :rtype: list(list(int), list(int))
        """
        fixed = []
        values = []
        if 'fixed_dimensions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get("fixed_dimensions")
            values = self.meta_data.get("fixed_dimensions_values")
        return [fixed, values]

    def _get_data_slice_list(self, plist):
        """ Convert a plugin data slice list to a slice list for the whole
        dataset, i.e. add in any missing dimensions.
        """
        nDims = len(self.get_shape())
        # Concatenate the data dimension tuples. get_core_dimensions() on
        # self returns a numpy array of plugin-relative positions, and adding
        # the int from get_slice_dimension() to it was element-wise arithmetic.
        all_dims = tuple(self.data_obj.get_core_dimensions()) + \
            tuple(self.data_obj.get_slice_dimensions())
        extra_dims = all_dims[nDims:]
        dlist = list(plist)
        for i in extra_dims:
            dlist.insert(i, slice(None))
        return tuple(dlist)

    def _get_max_frames_process(self):
        """ Get the number of frames to process for each run of process_frames.

        If the number of frames is not divisible by the previewing ``chunk``
        value then amend the number of frames to gcd(frames, chunk)

        :returns: Number of frames to process
        :rtype: int
        """
        # NOTE(review): plugin_data_setup sets ``self._plugin.chunk = True``
        # and ``True > 1`` is False, so this branch only runs if chunk is set
        # elsewhere to an int > 1 -- confirm intent.
        if self._plugin and self._plugin.chunk > 1:
            frame_chunk = self.meta_data.get("max_frames_process")
            # this version exposes slice dims via data_obj (there is no
            # get_slice_directions method here)
            chunk = self.data_obj.get_preview().get_starts_stops_steps(
                key='chunks')[self.data_obj.get_slice_dimensions()[0]]
            self.meta_data.set('max_frames_process', gcd(frame_chunk, chunk))
        return self.meta_data.get("max_frames_process")

    def _get_max_frames_transfer(self):
        """ Get the number of frames to transfer for each run of
        process_frames. """
        return self.meta_data.get('max_frames_transfer')

    def _set_no_squeeze(self):
        self.no_squeeze = True

    def _get_no_squeeze(self):
        return self.no_squeeze

    def _get_max_frames_parameters(self):
        """ Collect the counts used to judge how evenly frames are spread
        across processes.

        :returns: dict with keys 'shape', 'sdir', 'total_frames',
            'mpi_procs' and 'frames_per_process'.
        :rtype: dict
        """
        fixed, _ = self._get_fixed_dimensions()
        sdir = \
            [s for s in self.data_obj.get_slice_dimensions() if s not in fixed]
        shape = self.data_obj.get_shape()
        shape_before_tuning = self._get_shape_before_tuning()
        diff = len(shape) - len(shape_before_tuning)
        if diff:
            # Ignore the extra dimensions due to parameter tuning. They are
            # assumed to be the trailing slice dims -- TODO confirm.
            shape = shape_before_tuning
            sdir = sdir[:-diff]
        frames = np.prod([shape[d] for d in sdir])
        base_names = [p.__name__ for p in self._plugin.__class__.__bases__]
        processes = self.data_obj.exp.meta_data.get('processes')
        if 'GpuPlugin' in base_names:
            n_procs = len([n for n in processes if 'GPU' in n])
        else:
            n_procs = len(processes)
        f_per_p = np.ceil(frames / n_procs)
        params_dict = {
            'shape': shape,
            'sdir': sdir,
            'total_frames': frames,
            'mpi_procs': n_procs,
            'frames_per_process': f_per_p
        }
        return params_dict

    def __log_max_frames(self, mft, mfp, check=True):
        """ Log the frame counts, store max_frames_process and (optionally)
        check the distribution of frames across processes. """
        # lazy %-args: formatting only happens if DEBUG is enabled, and a
        # falsy/None mft cannot raise a formatting TypeError here
        logging.debug("Setting max frames transfer for plugin %s to %s",
                      self._plugin, mft)
        logging.debug("Setting max frames process for plugin %s to %s",
                      self._plugin, mfp)
        self.meta_data.set('max_frames_process', mfp)
        # __check_distribution divides by mft, so skip it when no transfer
        # size was calculated (__set_max_frames treats a falsy mft as such)
        if check and mft:
            self.__check_distribution(mft)

    def __check_distribution(self, mft):
        """ Warn when frames are unevenly distributed across processes,
        i.e. when ((total_frames / mft) / mpi_procs) % 1 is in (0, 0.85). """
        self.params = self._get_max_frames_parameters()
        warn_threshold = 0.85
        nprocs = self.params['mpi_procs']
        nframes = self.params['total_frames']
        temp = (((nframes / mft) / float(nprocs)) % 1)
        if temp != 0.0 and temp < warn_threshold:
            # logging.warn is a deprecated alias of logging.warning
            logging.warning(
                'UNEVEN FRAME DISTRIBUTION: shape %s, nframes %s ' +
                'sdir %s, nprocs %s', self.params['shape'], nframes,
                self.params['sdir'], nprocs)

    def _set_padding_dict(self):
        """ Replace a raw padding dict with a Padding object configured from
        it (the raw dict is kept in ``self.pad_dict``). """
        if self.padding and not isinstance(self.padding, Padding):
            self.pad_dict = copy.deepcopy(self.padding)
            self.padding = Padding(self)
            for key, value in self.pad_dict.items():
                getattr(self.padding, key)(value)

    def plugin_data_setup(self, pattern, nFrames, split=None):
        """ Setup the PluginData object.

        # add more information into here via a decorator!
        :param str pattern: A pattern name
        :param int nFrames: How many frames to process at a time.  Choose from\
            'single', 'multiple' or an integer (an integer should only ever \
            be passed in exceptional circumstances). A two-element list \
            ``[nFrames, frame_limit]`` is also accepted.
        :keyword split: stored on ``self.split``.
        """
        self.__set_pattern(pattern)
        chunks = \
            self.data_obj.get_preview().get_starts_stops_steps(key='chunks')

        if isinstance(nFrames, list):
            nFrames, self._frame_limit = nFrames

        self.__set_max_frames(nFrames)
        mft = self.meta_data.get('max_frames_transfer')
        if self._plugin and mft \
                and (chunks[self.data_obj.get_slice_dimensions()[0]] % mft):
            self._plugin.chunk = True
        self.__set_shape()
        self.split = split

    def __set_max_frames(self, nFrames):
        """ Calculate and store max frames transfer/process for nFrames. """
        self.max_frames = nFrames
        self.__perform_checks(nFrames)
        td = self.data_obj._get_transport_data()
        mft, mft_shape = td._calc_max_frames_transfer(nFrames)
        self.meta_data.set('max_frames_transfer', mft)
        if mft:
            self._set_shape_transfer(mft_shape)
        mfp = td._calc_max_frames_process(nFrames)
        self.meta_data.set('max_frames_process', mfp)
        self.__log_max_frames(mft, mfp)

        # Retain the shape if the first slice dimension has length 1
        if mfp == 1 and nFrames == 'multiple':
            self._set_no_squeeze()

    def __perform_checks(self, nFrames):
        """ Raise if nFrames is neither an int nor 'single'/'multiple'. """
        options = ['single', 'multiple']
        if not isinstance(nFrames, int) and nFrames not in options:
            # parenthesised so the three literals form ONE message; split
            # across unparenthesised lines only the first would be assigned
            e_str = ("The value of nFrames is not recognised. Please choose "
                     "from 'single' and 'multiple' (or an integer in "
                     "exceptional circumstances).")
            raise Exception(e_str)

    def get_frame_limit(self):
        return self._frame_limit
class PluginData(object):
    """ The PluginData class contains plugin specific information about a Data
    object for the duration of a plugin. An instance of the class is
    encapsulated inside the Data object during the plugin run
    """

    def __init__(self, data_obj, plugin=None):
        self.data_obj = data_obj
        self.data_obj._set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = None
        # this flag determines which data is passed. If false then just the
        # data, if true then all data including dark and flat fields.
        self.selected_data = False
        self.shape = None
        self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []
        self._plugin = plugin

    def get_total_frames(self):
        """ Get the total number of frames to process.

        :returns: Number of frames
        :rtype: int
        """
        temp = 1
        slice_dir = \
            self.data_obj.get_data_patterns()[
                self.get_pattern_name()]["slice_dir"]
        for tslice in slice_dir:
            temp *= self.data_obj.get_shape()[tslice]
        return temp

    def __set_pattern(self, name):
        """ Set the pattern related information in the meta data dict.

        :param str name: key into ``data_obj.get_data_patterns()``.
        """
        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set_meta_data("name", name)
        self.meta_data.set_meta_data("core_dir", pattern['core_dir'])
        self.__set_slice_directions()

    def get_pattern_name(self):
        """ Get the pattern name.

        :returns: the pattern name
        :rtype: str
        :raises Exception: if the pattern name has not been set.
        """
        name = self.meta_data.get_meta_data("name")
        if name is not None:
            return name
        else:
            raise Exception("The pattern name has not been set.")

    def get_pattern(self):
        """ Get the current pattern.

        :returns: dict of the pattern name against the pattern.
        :rtype: dict
        """
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def __set_shape(self):
        """ Set the shape of the plugin data processing chunk.

        The chunk holds the core dimensions in ascending dimension order.
        When more than one frame is processed at a time, the first slice
        direction, sized to the frame chunk, is inserted at its sorted
        position.
        """
        core_dir = self.get_core_directions()
        slice_dir = self.get_slice_directions()
        # sorted(set(...)) states the ascending ordering explicitly instead of
        # relying on the iteration order of a set.
        dirs = sorted(set(core_dir + (slice_dir[0],)))
        slice_idx = dirs.index(slice_dir[0])
        dshape = self.data_obj.get_shape()
        shape = [dshape[core] for core in sorted(set(core_dir))]
        self.__set_core_shape(tuple(shape))
        frame_chunk = self._get_frame_chunk()
        if frame_chunk > 1:
            shape.insert(slice_idx, frame_chunk)
        self.shape = tuple(shape)

    def get_shape(self):
        """ Get the shape of the plugin data to be processed each time.
        """
        return self.shape

    def __set_core_shape(self, shape):
        """ Set the core shape to hold only the shape of the core dimensions
        """
        self.core_shape = shape

    def get_core_shape(self):
        """ Get the shape of the core dimensions only.

        :returns: shape of core dimensions
        :rtype: tuple
        """
        return self.core_shape

    def __check_dimensions(self, indices, core_dir, slice_dir, nDims):
        """ Exit if the indices or dimension counts are inconsistent. """
        # Compare by value: ``is not`` tests object identity and only works
        # by accident for CPython's cached small integers.
        if len(indices) != len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")
        if (len(core_dir) + len(slice_dir)) != nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def __set_slice_directions(self):
        """ Set the slice directions in the pluginData meta data dictionary.
        """
        slice_dirs = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['slice_dir']
        self.meta_data.set_meta_data('slice_dir', slice_dirs)

    def get_slice_directions(self):
        """ Get the slice directions (slice_dir) of the dataset.
        """
        return self.meta_data.get_meta_data('slice_dir')

    def get_slice_dimension(self):
        """
        Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.get_core_directions()
        slice_dir = self.get_slice_directions()[0]
        return sorted(set(core_dirs + (slice_dir,))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """
        Return the dimension of the data in the plugin that has the specified
        axis label.
        """
        label_dim = \
            self.data_obj.find_axis_label_dimension(label, contains=contains)
        plugin_dims = self.get_core_directions()
        if self._get_frame_chunk() > 1:
            plugin_dims += (self.get_slice_directions()[0],)
        return sorted(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):
        """
        Reorder the slice directions. The fastest changing slice direction
        will always be the first one stated in the pattern key ``slice_dir``.
        The input param is a tuple stating the desired order of slicing
        directions relative to the current order.
        """
        slice_dirs = self.get_slice_directions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specifed.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        # There is no get_current_pattern() method. Write to the current
        # pattern directly and refresh the cached 'slice_dir' meta data that
        # get_slice_directions() reads.
        pattern = self.data_obj.get_data_patterns()[self.get_pattern_name()]
        pattern['slice_dir'] = new_slice_dirs
        self.meta_data.set_meta_data('slice_dir', new_slice_dirs)

    def get_core_directions(self):
        """ Get the core data directions

        :returns: value associated with pattern key ``core_dir``
        :rtype: tuple
        """
        core_dir = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['core_dir']
        return core_dir

    def set_fixed_directions(self, dims, values):
        """ Fix a data direction to the index in values list.

        :param list(int) dims: Directions to fix
        :param list(int) value: Index of fixed directions
        """
        slice_dirs = self.get_slice_directions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set_meta_data("fixed_directions", dims)
        self.meta_data.set_meta_data("fixed_directions_values", values)
        self.__set_slice_directions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        self.__set_shape()

    def _get_fixed_directions(self):
        """ Get the fixed data directions and their indices

        :returns: Fixed directions and their associated values
        :rtype: list(list(int), list(int))
        """
        fixed = []
        values = []
        if 'fixed_directions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get_meta_data("fixed_directions")
            values = self.meta_data.get_meta_data("fixed_directions_values")
        return [fixed, values]

    def _set_frame_chunk(self, nFrames):
        """ Set the number of frames to process at a time """
        self.meta_data.set_meta_data("nFrames", nFrames)

    def _get_frame_chunk(self):
        """ Get the number of frames to be processed at a time.

        If the number of frames is not divisible by the previewing ``chunk``
        value then amend the number of frames to gcd(frames, chunk)

        :returns: Number of frames to process
        :rtype: int
        """
        # NOTE(review): plugin_data_setup sets ``self._plugin.chunk = True``
        # and ``True > 1`` is False, so this branch only runs if chunk is set
        # elsewhere to an int > 1 -- confirm intent.
        if self._plugin and self._plugin.chunk > 1:
            frame_chunk = self.meta_data.get_meta_data("nFrames")
            chunk = self.data_obj.get_preview().get_starts_stops_steps(
                key='chunks')[self.get_slice_directions()[0]]
            self._set_frame_chunk(gcd(frame_chunk, chunk))
        return self.meta_data.get_meta_data("nFrames")

    def plugin_data_setup(self, pattern_name, chunk):
        """ Setup the PluginData object.

        :param str pattern_name: A pattern name
        :param int chunk: Number of frames to process at a time
        """
        self.__set_pattern(pattern_name)
        chunks = \
            self.data_obj.get_preview().get_starts_stops_steps(key='chunks')
        if self._plugin and (chunks[self.get_slice_directions()[0]] % chunk):
            self._plugin.chunk = True
        self._set_frame_chunk(chunk)
        self.__set_shape()
class PluginData(object):
    """ The PluginData class contains plugin specific information about a Data
    object for the duration of a plugin. An instance of the class is
    encapsulated inside the Data object during the plugin run
    """

    def __init__(self, data_obj, plugin=None):
        self.data_obj = data_obj
        self.data_obj._set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = None
        self.pad_dict = None
        self.shape = None
        self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []
        self._plugin = plugin
        self.fixed_dims = False
        self.end_pad = False
        self.split = None

    def get_total_frames(self):
        """ Get the total number of frames to process.

        :returns: Number of frames
        :rtype: int
        """
        temp = 1
        slice_dir = \
            self.data_obj.get_data_patterns()[
                self.get_pattern_name()]["slice_dir"]
        for tslice in slice_dir:
            temp *= self.data_obj.get_shape()[tslice]
        return temp

    def __set_pattern(self, name):
        """ Set the pattern related information in the meta data dict.

        :param str name: key into ``data_obj.get_data_patterns()``.
        """
        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set_meta_data("name", name)
        self.meta_data.set_meta_data("core_dir", pattern['core_dir'])
        self.__set_slice_directions()

    def get_pattern_name(self):
        """ Get the pattern name.

        :returns: the pattern name
        :rtype: str
        :raises Exception: if the pattern name has not been set.
        """
        name = self.meta_data.get_meta_data("name")
        if name is not None:
            return name
        else:
            raise Exception("The pattern name has not been set.")

    def get_pattern(self):
        """ Get the current pattern.

        :returns: dict of the pattern name against the pattern.
        :rtype: dict
        """
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def __set_shape(self):
        """ Set the shape of the plugin data processing chunk.

        The chunk holds the core dimensions in ascending dimension order.
        When more than one frame is processed at a time, the first slice
        direction, sized to the frame chunk, is inserted at its sorted
        position.
        """
        core_dir = self.get_core_directions()
        slice_dir = self.get_slice_directions()
        # sorted(set(...)) states the ascending ordering explicitly instead of
        # relying on the iteration order of a set.
        dirs = sorted(set(core_dir + (slice_dir[0], )))
        slice_idx = dirs.index(slice_dir[0])
        dshape = self.data_obj.get_shape()
        shape = [dshape[core] for core in sorted(set(core_dir))]
        self.__set_core_shape(tuple(shape))
        frame_chunk = self._get_frame_chunk()
        if frame_chunk > 1:
            shape.insert(slice_idx, frame_chunk)
        self.shape = tuple(shape)

    def get_shape(self):
        """ Get the shape of the plugin data to be processed each time.
        """
        return self.shape

    def __set_core_shape(self, shape):
        """ Set the core shape to hold only the shape of the core dimensions
        """
        self.core_shape = shape

    def get_core_shape(self):
        """ Get the shape of the core dimensions only.

        :returns: shape of core dimensions
        :rtype: tuple
        """
        return self.core_shape

    def __check_dimensions(self, indices, core_dir, slice_dir, nDims):
        """ Exit if the indices or dimension counts are inconsistent. """
        # Compare by value: ``is not`` tests object identity and only works
        # by accident for CPython's cached small integers.
        if len(indices) != len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")
        if (len(core_dir) + len(slice_dir)) != nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def __set_slice_directions(self):
        """ Set the slice directions in the pluginData meta data dictionary.
        """
        slice_dirs = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['slice_dir']
        self.meta_data.set_meta_data('slice_dir', slice_dirs)

    def get_slice_directions(self):
        """ Get the slice directions (slice_dir) of the dataset.
        """
        return self.meta_data.get_meta_data('slice_dir')

    def get_slice_dimension(self):
        """
        Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.get_core_directions()
        slice_dir = self.get_slice_directions()[0]
        return sorted(set(core_dirs + (slice_dir, ))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """
        Return the dimension of the data in the plugin that has the specified
        axis label.
        """
        label_dim = \
            self.data_obj.find_axis_label_dimension(label, contains=contains)
        plugin_dims = self.get_core_directions()
        if self._get_frame_chunk() > 1:
            plugin_dims += (self.get_slice_directions()[0], )
        return sorted(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):
        """
        Reorder the slice directions. The fastest changing slice direction
        will always be the first one stated in the pattern key ``slice_dir``.
        The input param is a tuple stating the desired order of slicing
        directions relative to the current order.
        """
        slice_dirs = self.get_slice_directions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specifed.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        # There is no get_current_pattern() method. Write to the current
        # pattern directly and refresh the cached 'slice_dir' meta data that
        # get_slice_directions() reads.
        pattern = self.data_obj.get_data_patterns()[self.get_pattern_name()]
        pattern['slice_dir'] = new_slice_dirs
        self.meta_data.set_meta_data('slice_dir', new_slice_dirs)

    def get_core_directions(self):
        """ Get the core data directions

        :returns: value associated with pattern key ``core_dir``
        :rtype: tuple
        """
        core_dir = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['core_dir']
        return core_dir

    def set_fixed_directions(self, dims, values):
        """ Fix a data direction to the index in values list.

        :param list(int) dims: Directions to fix
        :param list(int) value: Index of fixed directions
        """
        slice_dirs = self.get_slice_directions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set_meta_data("fixed_directions", dims)
        self.meta_data.set_meta_data("fixed_directions_values", values)
        self.__set_slice_directions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        self.__set_shape()

    def _get_fixed_directions(self):
        """ Get the fixed data directions and their indices

        :returns: Fixed directions and their associated values
        :rtype: list(list(int), list(int))
        """
        fixed = []
        values = []
        if 'fixed_directions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get_meta_data("fixed_directions")
            values = self.meta_data.get_meta_data("fixed_directions_values")
        return [fixed, values]

    def _get_data_slice_list(self, plist):
        """ Convert a plugin data slice list to a slice list for the whole
        dataset, i.e. add in any missing dimensions.
        """
        nDims = len(self.get_shape())
        all_dims = self.get_core_directions() + self.get_slice_directions()
        extra_dims = all_dims[nDims:]
        dlist = list(plist)
        for i in extra_dims:
            dlist.insert(i, slice(None))
        return tuple(dlist)

    def _set_frame_chunk(self, nFrames):
        """ Set the number of frames to process at a time """
        self.meta_data.set_meta_data("nFrames", nFrames)

    def _get_frame_chunk(self):
        """ Get the number of frames to be processed at a time.

        If the number of frames is not divisible by the previewing ``chunk``
        value then amend the number of frames to gcd(frames, chunk)

        :returns: Number of frames to process
        :rtype: int
        """
        # NOTE(review): plugin_data_setup sets ``self._plugin.chunk = True``
        # and ``True > 1`` is False, so this branch only runs if chunk is set
        # elsewhere to an int > 1 -- confirm intent.
        if self._plugin and self._plugin.chunk > 1:
            frame_chunk = self.meta_data.get_meta_data("nFrames")
            chunk = self.data_obj.get_preview().get_starts_stops_steps(
                key='chunks')[self.get_slice_directions()[0]]
            self._set_frame_chunk(gcd(frame_chunk, chunk))
        return self.meta_data.get_meta_data("nFrames")

    def plugin_data_setup(self, pattern_name, chunk, fixed=False, split=None):
        """ Setup the PluginData object.

        :param str pattern_name: A pattern name
        :param int chunk: Number of frames to process at a time
        :keyword bool fixed: setting fixed=True will ensure the plugin \
            receives the same sized data array each time (padding if necessary)
        :keyword split: stored on ``self.split``.
        """
        self.__set_pattern(pattern_name)
        chunks = \
            self.data_obj.get_preview().get_starts_stops_steps(key='chunks')
        if self._plugin and (chunks[self.get_slice_directions()[0]] % chunk):
            self._plugin.chunk = True
        self._set_frame_chunk(chunk)
        self.__set_shape()
        self.fixed_dims = fixed
        self.split = split
class PluginData(object):
    """ The PluginData class contains plugin specific information about a Data
    object for the duration of a plugin.  An instance of the class is
    encapsulated inside the Data object during the plugin run
    """

    def __init__(self, data_obj, plugin=None):
        """
        :param data_obj: The Data object this instance describes; the new
            instance registers itself on it via ``_set_plugin_data``.
        :keyword plugin: The plugin instance that is using the data.
        """
        self.data_obj = data_obj
        self._preview = None
        self.data_obj._set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = None
        self.pad_dict = None
        self.shape = None
        self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []
        self._plugin = plugin
        self.fixed_dims = True
        self.split = None
        self.boundary_padding = None
        self.no_squeeze = False
        self.pre_tuning_shape = None
        self._frame_limit = None
        self._increase_rank = 0

    def _get_preview(self):
        return self._preview

    def get_total_frames(self):
        """ Get the total number of frames to process (all MPI processes).

        :returns: Number of frames (product of the data shape over the
            current pattern's slice dimensions)
        :rtype: int
        """
        temp = 1
        slice_dir = \
            self.data_obj.get_data_patterns()[
                self.get_pattern_name()]["slice_dims"]
        for tslice in slice_dir:
            temp *= self.data_obj.get_shape()[tslice]
        return temp

    def __set_pattern(self, name, first_sdim=None):
        """ Set the pattern related information in the meta data dict.

        :param str name: pattern name
        :keyword str first_sdim: axis label of the dimension to move to the
            front of the slice dimensions
        """
        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set("name", name)
        self.meta_data.set("core_dims", pattern['core_dims'])
        self.__set_slice_dimensions(first_sdim=first_sdim)

    def get_pattern_name(self):
        """ Get the pattern name.

        :returns: the pattern name
        :rtype: str
        :raises Exception: if the pattern name has not been set.
        """
        try:
            name = self.meta_data.get("name")
            return name
        except KeyError:
            raise Exception("The pattern name has not been set.")

    def get_pattern(self):
        """ Get the current pattern.

        :returns: dict of the pattern name against the pattern.
        :rtype: dict
        """
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def __set_shape(self):
        """ Set the shape of the plugin data processing chunk.

        The shape holds the core dimensions (in ascending dimension order),
        with the number of frames processed inserted at the position of the
        first slice dimension when more than one frame is processed (or when
        squeezing has been disabled).
        """
        core_dir = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()
        # sorted() makes the ascending-dimension order explicit instead of
        # relying on set iteration order.
        dirs = sorted(set(core_dir + (slice_dir[0], )))
        slice_idx = dirs.index(slice_dir[0])
        dshape = self.data_obj.get_shape()
        shape = [dshape[core] for core in sorted(set(core_dir))]
        self.__set_core_shape(tuple(shape))

        mfp = self._get_max_frames_process()
        if mfp > 1 or self._get_no_squeeze():
            shape.insert(slice_idx, mfp)
        self.shape = tuple(shape)

    def _set_shape_transfer(self, slice_size):
        """ Build the full-rank shape of the data transferred each time.

        :param list(int) slice_size: number of frames in each slice dimension
            (in slice-dimension order); padded with 1s for any dimensions
            added by parameter tuning.
        :returns: transfer shape (core dims keep their full length)
        :rtype: tuple
        """
        dshape = self.data_obj.get_shape()
        shape_before_tuning = self._get_shape_before_tuning()
        add = [1] * (len(dshape) - len(shape_before_tuning))
        slice_size = list(slice_size) + add

        core_dir = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()
        shape = [None] * len(dshape)
        for dim in core_dir:
            shape[dim] = dshape[dim]
        for i, dim in enumerate(slice_dir):
            shape[dim] = slice_size[i]
        return tuple(shape)

    def __get_slice_size(self, mft):
        """ Calculate the number of frames transfer in each dimension given
        mft.

        Frames are allocated to the slice dimensions in order (fastest
        changing first); whatever cannot be satisfied by a dimension is
        carried on to the next one.

        :param int mft: max frames transfer
        :returns: number of frames in each slice dimension (also stored in
            the ``size_list`` meta data entry)
        :rtype: list(int)
        """
        dshape = list(self.data_obj.get_shape())
        if 'fixed_dimensions' in self.meta_data.get_dictionary():
            for d in self.meta_data.get('fixed_dimensions'):
                dshape[d] = 1

        dshape = [dshape[i] for i in self.meta_data.get('slice_dims')]
        size_list = [1] * len(dshape)
        i = 0
        while mft > 1 and i < len(size_list):
            size_list[i] = min(dshape[i], mft)
            # Divide by this dimension's share only: dividing by the whole
            # product (as before) over-counted earlier dimensions and
            # under-allocated frames when 3+ slice dimensions were used.
            mft //= size_list[i] if size_list[i] > 1 else 1
            i += 1

        self.meta_data.set('size_list', size_list)
        return size_list

    def set_bytes_per_frame(self):
        """ Calculate the shape and size (in bytes) of a single frame.

        A frame spans the core dimensions of the current pattern.

        :returns: frame shape and number of bytes per frame
        :rtype: tuple(list(int), number)
        """
        nBytes = self.data_obj.get_itemsize()
        dims = list(self.get_pattern().values())[0]['core_dims']
        frame_shape = [self.data_obj.get_shape()[d] for d in dims]
        b_per_f = np.prod(frame_shape) * nBytes
        return frame_shape, b_per_f

    def get_shape(self):
        """ Get the shape of the data (without padding) that is passed to the
        plugin process_frames method.
        """
        return self.shape

    def _set_padded_shape(self):
        pass

    def get_padded_shape(self):
        """ Get the shape of the data (with padding) that is passed to the
        plugin process_frames method.
        """
        return self.shape

    def get_shape_transfer(self):
        """ Get the shape of the plugin data to be transferred each time.
        """
        return self.meta_data.get('transfer_shape')

    def __set_core_shape(self, shape):
        """ Set the core shape to hold only the shape of the core dimensions
        """
        self.core_shape = shape

    def get_core_shape(self):
        """ Get the shape of the core dimensions only.

        :returns: shape of core dimensions
        :rtype: tuple
        """
        return self.core_shape

    def _set_shape_before_tuning(self, shape):
        """ Set the shape of the full dataset used during each run of the \
        plugin (i.e. ignore extra dimensions due to parameter tuning). """
        self.pre_tuning_shape = shape

    def _get_shape_before_tuning(self):
        """ Return the shape of the full dataset used during each run of the \
        plugin (i.e. ignore extra dimensions due to parameter tuning). """
        return self.pre_tuning_shape if self.pre_tuning_shape else\
            self.data_obj.get_shape()

    def __check_dimensions(self, indices, core_dir, slice_dir, nDims):
        """ Exit if the indices or dimension counts are inconsistent. """
        # '!=' compares values; 'is not' compared object identity.
        if len(indices) != len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")

        if (len(core_dir) + len(slice_dir)) != nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def __set_slice_dimensions(self, first_sdim=None):
        """ Set the slice dimensions in the pluginData meta data dictionary.\
        Reorder pattern slice_dims to ensure first_sdim is at the front.

        :keyword str first_sdim: axis label of the dimension to move to the
            front (this also updates the stored pattern).
        """
        pattern = self.data_obj.get_data_patterns()[self.get_pattern_name()]
        slice_dims = pattern['slice_dims']

        if first_sdim:
            slice_dims = list(slice_dims)
            first_sdim = \
                self.data_obj.get_data_dimension_by_axis_label(first_sdim)
            slice_dims.insert(0, slice_dims.pop(slice_dims.index(first_sdim)))
            pattern['slice_dims'] = tuple(slice_dims)

        self.meta_data.set('slice_dims', tuple(slice_dims))

    def get_slice_dimension(self):
        """
        Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()[0]
        return sorted(set(core_dirs + (slice_dir, ))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """
        Return the dimension of the data in the plugin that has the specified
        axis label.

        :param str label: axis label name
        :keyword bool contains: match labels that only contain ``label``
        :returns: position of the labelled dimension in the plugin data
        :rtype: int
        """
        label_dim = self.data_obj.get_data_dimension_by_axis_label(
                label, contains=contains)
        plugin_dims = tuple(self.data_obj.get_core_dimensions())
        if self._get_max_frames_process() > 1 or self.max_frames == 'multiple':
            # Add the first slice dimension as a *data* dimension index; the
            # plugin-relative index from get_slice_dimension() cannot be mixed
            # with data dimension numbers (it broke for >3D data).
            plugin_dims += (self.data_obj.get_slice_dimensions()[0], )
        return sorted(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):  # should this function be deleted?
        """
        Reorder the slice dimensions.  The fastest changing slice dimension
        will always be the first one stated in the pattern key ``slice_dir``.
        The input param is a tuple stating the desired order of slicing
        dimensions relative to the current order.

        :param tuple(int) order: indices into the current slice dimensions
        :raises Exception: if more indices than slice dimensions are given
        """
        slice_dirs = self.data_obj.get_slice_dimensions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specifed.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        # get_pattern() builds a fresh dict, so writing into it (as this
        # method previously did, under a 'slice_dir' key) had no effect.
        # Update the stored pattern and the cached meta data entry instead.
        pattern = self.data_obj.get_data_patterns()[self.get_pattern_name()]
        pattern['slice_dims'] = new_slice_dirs
        self.meta_data.set('slice_dims', new_slice_dirs)

    def get_core_dimensions(self):
        """
        Return the position of the core dimensions in relation to the data
        handed to the plugin.
        """
        core_dims = self.data_obj.get_core_dimensions()
        first_slice_dim = (self.data_obj.get_slice_dimensions()[0], )
        plugin_dims = np.sort(core_dims + first_slice_dim)
        return np.searchsorted(plugin_dims, np.sort(core_dims))

    def set_fixed_dimensions(self, dims, values):
        """ Fix a data direction to the index in values list.

        :param list(int) dims: Directions to fix
        :param list(int) value: Index of fixed directions
        :raises Exception: if a dimension in ``dims`` is not a slice dimension
        """
        slice_dirs = self.data_obj.get_slice_dimensions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set("fixed_dimensions", dims)
        self.meta_data.set("fixed_dimensions_values", values)
        self.__set_slice_dimensions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        #self.__set_shape()

    def _get_fixed_dimensions(self):
        """ Get the fixed data directions and their indices

        :returns: Fixed directions and their associated values (two empty
            lists if none have been fixed)
        :rtype: list(list(int), list(int))
        """
        fixed = []
        values = []
        if 'fixed_dimensions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get("fixed_dimensions")
            values = self.meta_data.get("fixed_dimensions_values")
        return [fixed, values]

    def _get_data_slice_list(self, plist):
        """ Convert a plugin data slice list to a slice list for the whole
        dataset, i.e. add in any missing dimensions.
        """
        nDims = len(self.get_shape())
        # NOTE(review): get_core_dimensions() returns a numpy array, so '+'
        # here adds the int from get_slice_dimension() element-wise rather
        # than concatenating dimensions; extra_dims is therefore usually
        # empty and plist is returned unchanged.  TODO confirm the intended
        # dimension list before changing this.
        all_dims = self.get_core_dimensions() + self.get_slice_dimension()
        extra_dims = all_dims[nDims:]
        dlist = list(plist)
        for i in extra_dims:
            dlist.insert(i, slice(None))
        return tuple(dlist)

    def _get_max_frames_process(self):
        """ Get the number of frames to process for each run of process_frames.

        If the number of frames is not divisible by the previewing ``chunk``
        value then amend the number of frames to gcd(frames, chunk)

        :returns: Number of frames to process
        :rtype: int
        """
        # NOTE(review): plugin_data_transfer_setup sets plugin.chunk = True,
        # and True > 1 is False, so this branch only runs for integer
        # chunk values > 1 -- confirm that is intended.
        if self._plugin and self._plugin.chunk > 1:
            frame_chunk = self.meta_data.get("max_frames_process")
            # get_slice_directions() does not exist on this class; use the
            # data object's slice dimensions (as plugin_data_transfer_setup
            # does).
            chunk = self.data_obj.get_preview().get_starts_stops_steps(
                key='chunks')[self.data_obj.get_slice_dimensions()[0]]
            self.meta_data.set('max_frames_process', gcd(frame_chunk, chunk))
        return self.meta_data.get("max_frames_process")

    def _get_max_frames_transfer(self):
        """ Get the number of frames to transfer for each run of
        process_frames. """
        return self.meta_data.get('max_frames_transfer')

    def _set_no_squeeze(self):
        self.no_squeeze = True

    def _get_no_squeeze(self):
        return self.no_squeeze

    def _set_rank_inc(self, n):
        """ Increase the rank of the array passed to the plugin by n.

        :param int n: Rank increment.
        """
        self._increase_rank = n

    def _get_rank_inc(self):
        """ Return the increased rank value

        :returns: Rank increment
        :rtype: int
        """
        return self._increase_rank

    def _set_meta_data(self):
        """ Record shape, frame counts and per-process sizes in meta data. """
        fixed, _ = self._get_fixed_dimensions()
        sdir = \
            [s for s in self.data_obj.get_slice_dimensions() if s not in fixed]
        shape = self.data_obj.get_shape()
        shape_before_tuning = self._get_shape_before_tuning()

        # Ignore dimensions added by parameter tuning.
        diff = len(shape) - len(shape_before_tuning)
        if diff:
            shape = shape_before_tuning
            sdir = sdir[:-diff]

        if 'fix_total_frames' in self.meta_data.get_dictionary():
            frames = self.meta_data.get('fix_total_frames')
        else:
            frames = np.prod([shape[d] for d in sdir])

        base_names = [p.__name__ for p in self._plugin.__class__.__bases__]
        processes = self.data_obj.exp.meta_data.get('processes')

        if 'GpuPlugin' in base_names:
            n_procs = len([n for n in processes if 'GPU' in n])
        else:
            n_procs = len(processes)

        # Fixing f_per_p to be just the first slice dimension for now due to
        # slow performance from HDF5 when not slicing multiple dimensions
        # concurrently
        f_per_p = np.ceil(shape[sdir[0]] / n_procs)
        self.meta_data.set('shape', shape)
        self.meta_data.set('sdir', sdir)
        self.meta_data.set('total_frames', frames)
        self.meta_data.set('mpi_procs', n_procs)
        self.meta_data.set('frames_per_process', f_per_p)
        frame_shape, b_per_f = self.set_bytes_per_frame()
        self.meta_data.set('bytes_per_frame', b_per_f)
        self.meta_data.set('bytes_per_process', b_per_f * f_per_p)
        self.meta_data.set('frame_shape', frame_shape)

    def __log_max_frames(self, mft, mfp, check=True):
        """ Log and store max frames process; optionally check distribution.
        """
        logging.debug("Setting max frames transfer for plugin %s to %d" %
                      (self._plugin, mft))
        logging.debug("Setting max frames process for plugin %s to %d" %
                      (self._plugin, mfp))
        self.meta_data.set('max_frames_process', mfp)
        if check:
            self.__check_distribution(mft)
        # (((total_frames/mft)/mpi_procs) % 1)

    def __check_distribution(self, mft):
        """ Warn when frames divide unevenly across MPI processes. """
        warn_threshold = 0.85
        nprocs = self.meta_data.get('mpi_procs')
        nframes = self.meta_data.get('total_frames')
        temp = (((nframes / mft) / float(nprocs)) % 1)
        if temp != 0.0 and temp < warn_threshold:
            shape = self.meta_data.get('shape')
            sdir = self.meta_data.get('sdir')
            logging.warning(
                'UNEVEN FRAME DISTRIBUTION: shape %s, nframes %s ' +
                'sdir %s, nprocs %s', shape, nframes, sdir, nprocs)

    def _set_padding_dict(self):
        """ Convert a raw padding dict into a Padding instance. """
        if self.padding and not isinstance(self.padding, Padding):
            self.pad_dict = copy.deepcopy(self.padding)
            self.padding = Padding(self)
            for key in list(self.pad_dict.keys()):
                getattr(self.padding, key)(self.pad_dict[key])

    def plugin_data_setup(self, pattern, nFrames, split=None, slice_axis=None,
                          getall=None):
        """ Setup the PluginData object.

        :param str pattern: A pattern name
        :param int nFrames: How many frames to process at a time.  Choose from
            'single', 'multiple', 'fixed_multiple' or an integer (an integer
            should only ever be passed in exceptional circumstances)
        :keyword str slice_axis: An axis label associated with the fastest
            changing (first) slice dimension.
        :keyword list[pattern, axis_label] getall: A list of two values.  If
            the requested pattern doesn't exist then use all of "axis_label"
            dimension of "pattern" as this is equivalent to one slice of the
            original pattern.
        """

        if pattern not in self.data_obj.get_data_patterns() and getall:
            pattern, nFrames = self.__set_getall_pattern(getall, nFrames)

        # slice_axis is first slice dimension
        self.__set_pattern(pattern, first_sdim=slice_axis)
        if isinstance(nFrames, list):
            nFrames, self._frame_limit = nFrames
        self.max_frames = nFrames
        self.split = split

    def __set_getall_pattern(self, getall, nFrames):
        """ Set framework changes required to get all of a pattern of lower
        rank.
        """
        pattern, slice_axis = getall
        dim = self.data_obj.get_data_dimension_by_axis_label(slice_axis)
        # ensure data remains the same shape when 'getall' dim has length 1
        self._set_no_squeeze()
        if nFrames == 'multiple' or (isinstance(nFrames, int) and nFrames > 1):
            self._set_rank_inc(1)
        nFrames = self.data_obj.get_shape()[dim]
        return pattern, nFrames

    def plugin_data_transfer_setup(self, copy=None, calc=None):
        """ Set up the plugin data transfer frame parameters.

        If copy=pData (another PluginData instance) then copy its max frames
        transfer/process.  If calc=pData then derive them from its values.
        Otherwise calculate them via the transport layer.
        """
        chunks = \
            self.data_obj.get_preview().get_starts_stops_steps(key='chunks')

        if not copy and not calc:
            mft, mft_shape, mfp = self._calculate_max_frames()
        elif calc:
            max_mft = calc.meta_data.get('max_frames_transfer')
            max_mfp = calc.meta_data.get('max_frames_process')
            nProc = int(np.ceil(max_mft / float(max_mfp)))
            # NOTE(review): if max_frames is 'multiple' this multiplies a
            # string -- confirm callers only reach here with 'single'/ints.
            mfp = 1 if self.max_frames == 'single' else self.max_frames
            mft = nProc * mfp
            mft_shape = self._set_shape_transfer(self.__get_slice_size(mft))
        else:
            mft = copy._get_max_frames_transfer()
            mft_shape = self._set_shape_transfer(self.__get_slice_size(mft))
            mfp = copy._get_max_frames_process()

        self.__set_max_frames(mft, mft_shape, mfp)

        if self._plugin and mft \
                and (chunks[self.data_obj.get_slice_dimensions()[0]] % mft):
            self._plugin.chunk = True
        self.__set_shape()

    def _calculate_max_frames(self):
        """ Ask the transport layer for max frames transfer/process.

        :returns: (mft, transfer shape or None when mft is falsy, mfp)
        :rtype: tuple
        """
        nFrames = self.max_frames
        self.__perform_checks(nFrames)
        td = self.data_obj._get_transport_data()
        mft, size_list = td._calc_max_frames_transfer(nFrames)
        self.meta_data.set('size_list', size_list)
        mfp = td._calc_max_frames_process(nFrames)
        # Previously unbound (UnboundLocalError) when mft was falsy.
        mft_shape = None
        if mft:
            mft_shape = self._set_shape_transfer(list(size_list))
        return mft, mft_shape, mfp

    def __set_max_frames(self, mft, mft_shape, mfp):
        """ Store max frames transfer/process and the transfer shape. """
        self.meta_data.set('max_frames_transfer', mft)
        self.meta_data.set('transfer_shape', mft_shape)
        self.meta_data.set('max_frames_process', mfp)
        self.__log_max_frames(mft, mfp)
        # Retain the shape if the first slice dimension has length 1
        if mfp == 1 and self.max_frames == 'multiple':
            self._set_no_squeeze()

    def _get_plugin_data_size_params(self):
        """ Return item size, frame shape, total frames and total bytes. """
        nBytes = self.data_obj.get_itemsize()
        frame_shape = self.meta_data.get('frame_shape')
        total_frames = self.meta_data.get('total_frames')
        tbytes = nBytes * np.prod(frame_shape) * total_frames

        params = {
            'nBytes': nBytes,
            'frame_shape': frame_shape,
            'total_frames': total_frames,
            'transfer_bytes': tbytes
        }
        return params

    def __perform_checks(self, nFrames):
        """ Raise if nFrames is neither an integer nor a known option.

        Any Python or numpy integer is accepted (previously only values
        compatible with int64 passed); bools are still rejected.
        """
        options = ['single', 'multiple']
        is_integer = isinstance(nFrames, (int, np.integer)) and \
            not isinstance(nFrames, bool)
        if not is_integer and nFrames not in options:
            e_str = (
                "The value of nFrames is not recognised.  Please choose "
                + "from 'single' and 'multiple' (or an integer in exceptional "
                + "circumstances).")
            raise Exception(e_str)

    def get_frame_limit(self):
        return self._frame_limit

    def get_current_frame_idx(self):
        """ Returns the index of the frames currently being processed.

        Indices beyond the last frame are clamped to the last frame.
        """
        global_index = self._plugin.get_global_frame_index()
        count = self._plugin.get_process_frames_counter()
        mfp = self.meta_data.get('max_frames_process')
        start = global_index[count] * mfp
        index = np.arange(start, start + mfp)
        nFrames = self.get_total_frames()
        index[index >= nFrames] = nFrames - 1
        return index
class Data(DataCreate):
    """The Data class dynamically inherits from transport specific data class
    and holds the data array, along with associated information.
    """

    def __init__(self, name, exp):
        """
        :param str name: dataset name
        :param exp: the experiment object this dataset belongs to
        """
        super(Data, self).__init__(name)
        self.meta_data = MetaData()
        self.pattern_list = self.__get_available_pattern_list()
        self.data_info = MetaData()
        self.__initialise_data_info(name)
        self._preview = Preview(self)
        self.exp = exp
        self.group_name = None
        self.group = None
        self._plugin_data_obj = None
        self.raw = None
        self.backing_file = None
        self.data = None
        self.next_shape = None
        self.orig_shape = None
        self.previous_pattern = None
        self.transport_data = None

    def __initialise_data_info(self, name):
        """ Initialise entries in the data_info meta data.
        """
        self.data_info.set('name', name)
        self.data_info.set('data_patterns', {})
        self.data_info.set('shape', None)
        self.data_info.set('nDims', None)

    def _set_plugin_data(self, plugin_data_obj):
        """ Encapsulate a PluginData object.
        """
        self._plugin_data_obj = plugin_data_obj

    def _clear_plugin_data(self):
        """ Set encapsulated PluginData object to None.
        """
        self._plugin_data_obj = None

    def _get_plugin_data(self):
        """ Get encapsulated PluginData object.

        :raises Exception: if no PluginData object is attached.
        """
        if self._plugin_data_obj is not None:
            return self._plugin_data_obj
        else:
            raise Exception("There is no PluginData object associated with "
                            "the Data object.")

    def get_preview(self):
        """ Get the Preview instance associated with the data object
        """
        return self._preview

    def _set_transport_data(self, transport):
        """ Import the data transport mechanism

        :param str transport: transport name, used to build the module path
            ``savu.data.transport_data.<transport>_transport_data``
        """
        transport_data = "savu.data.transport_data." + transport + \
            "_transport_data"
        transport_data = cu.import_class(transport_data)
        self.transport_data = transport_data(self)
        self.data_info.set('transport', transport)

    def _get_transport_data(self):
        return self.transport_data

    def __deepcopy__(self, memo):
        """ Copy the data object.
        """
        name = self.data_info.get('name')
        return dsu._deepcopy_data_object(self, Data(name, self.exp))

    def get_data_patterns(self):
        """ Get data patterns associated with this data object.

        :returns: A dictionary of associated patterns.
        :rtype: dict
        """
        return self.data_info.get('data_patterns')

    def _set_previous_pattern(self, pattern):
        self.previous_pattern = pattern

    def get_previous_pattern(self):
        return self.previous_pattern

    def set_shape(self, shape):
        """ Set the dataset shape.

        :raises Exception: if the shape conflicts with the number of axis
            labels already set.
        """
        self.data_info.set('shape', shape)
        self.__check_dims()

    def set_original_shape(self, shape):
        """ Set the original data shape before previewing
        """
        self.orig_shape = shape
        self.set_shape(shape)

    def get_shape(self):
        """ Get the dataset shape

        :returns: data shape
        :rtype: tuple
        """
        shape = self.data_info.get('shape')
        return shape

    def __check_dims(self):
        """ Check the ``shape`` and ``nDims`` entries in the data_info
        meta_data dictionary are equal.
        """
        nDims = self.data_info.get("nDims")
        shape = self.data_info.get('shape')
        if nDims:
            if len(shape) != nDims:
                error_msg = ("The number of axis labels, %d, does not "
                             "coincide with the number of data "
                             "dimensions %d." % (nDims, len(shape)))
                raise Exception(error_msg)

    def _set_name(self, name):
        self.data_info.set('name', name)

    def get_name(self):
        """ Get data name.

        :returns: the name associated with the dataset
        :rtype: str
        """
        return self.data_info.get('name')

    def __get_available_pattern_list(self):
        """ Get a list of ALL pattern names that are currently allowed in the
        framework.
        """
        pattern_list = dsu.get_available_pattern_types()
        return pattern_list

    def add_pattern(self, dtype, **kwargs):
        """ Add a pattern.

        :params str dtype: The *type* of pattern to add, which can be anything
            from the :const:`savu.data.data_structures.utils.pattern_list`
            :const:`pattern_list`
            :data:`savu.data.data_structures.utils.pattern_list`
            :data:`pattern_list`:
        :keyword tuple core_dir: Dimension indices of core dimensions
        :keyword tuple slice_dir: Dimension indices of slice dimensions
        :raises Exception: if the pattern type is unknown or its number of
            dimensions disagrees with the dataset's.
        """
        if dtype not in self.pattern_list:
            # Previously the arguments were passed to Exception unformatted.
            raise Exception("The data pattern '%s' does not exist.  Please "
                            "choose from the following list: \n'%s'"
                            % (dtype, str(self.pattern_list)))

        nDims = 0
        for key, dims in kwargs.items():
            nDims += len(dims)
            self.data_info.set(['data_patterns', dtype, key], dims)

        self.__convert_pattern_directions(dtype)
        if self.get_shape():
            diff = len(self.get_shape()) - nDims
            if diff:
                pattern = {dtype: self.get_data_patterns()[dtype]}
                self._set_data_patterns(pattern)
                nDims += diff

        try:
            actual_dims = self.data_info.get('nDims')
        except KeyError:
            actual_dims = None

        # nDims is initialised to None; treat that like "not yet set" rather
        # than failing while formatting the error message with '%d'.
        if actual_dims is None:
            self.data_info.set('nDims', nDims)
        elif nDims != actual_dims:
            err_msg = ("The pattern %s has an incorrect number of "
                       "dimensions: %d required but %d specified."
                       % (dtype, actual_dims, nDims))
            raise Exception(err_msg)

    def add_volume_patterns(self, x, y, z):
        """ Adds 3D volume patterns

        :params int x: dimension to be associated with x-axis
        :params int y: dimension to be associated with y-axis
        :params int z: dimension to be associated with z-axis
        """
        self.add_pattern("VOLUME_YZ", **self.__get_dirs_for_volume(y, z, x))
        self.add_pattern("VOLUME_XZ", **self.__get_dirs_for_volume(x, z, y))
        self.add_pattern("VOLUME_XY", **self.__get_dirs_for_volume(x, y, z))

    def __get_dirs_for_volume(self, dim1, dim2, sdir):
        """ Calculate core_dir and slice_dir for a 3D volume pattern.

        :returns: dict with ``core_dims`` (dim1, dim2) and ``slice_dims``
            (sdir first, then every other dimension in order)
        :rtype: dict
        """
        all_dims = range(len(self.get_shape()))
        vol_dict = {}
        vol_dict['core_dims'] = (dim1, dim2)

        slice_dir = [sdir]
        # *** need to add this for other patterns
        for ddir in all_dims:
            if ddir not in [dim1, dim2, sdir]:
                slice_dir.append(ddir)

        vol_dict['slice_dims'] = tuple(slice_dir)
        return vol_dict

    def set_axis_labels(self, *args):
        """ Set the axis labels associated with each data dimension.

        :arg str: Each arg should be of the form ``name.unit``. If ``name`` is\
        a data_obj.meta_data entry, it will be output to the final .nxs file.
        """
        self.data_info.set('nDims', len(args))
        axis_labels = []
        for arg in args:
            if isinstance(arg, dict):
                axis_labels.append(arg)
            else:
                try:
                    axis = arg.split('.')
                    axis_labels.append({axis[0]: axis[1]})
                except (AttributeError, IndexError, TypeError):
                    # data arrives here, but that may be an error
                    pass
        self.data_info.set('axis_labels', axis_labels)

    def get_axis_labels(self):
        """ Get axis labels.

        :returns: Axis labels
        :rtype: list(dict)
        """
        return self.data_info.get('axis_labels')

    def get_data_dimension_by_axis_label(self, name, contains=False):
        """ Get the dimension of the data associated with a particular
        axis_label.

        :param str name: The name of the axis_label
        :keyword bool contains: Set this flag to true if the name is only part
            of the axis_label name
        :returns: The associated axis number
        :rtype: int
        :raises Exception: if no axis label matches.
        """
        axis_labels = self.data_info.get('axis_labels')
        for i, labels in enumerate(axis_labels):
            if contains is True:
                if any(name in key for key in labels.keys()):
                    return i
            elif name in labels.keys():
                return i
        raise Exception("Cannot find the specifed axis label.")

    def _finalise_patterns(self):
        """ Adds a main axis (fastest changing) to SINOGRAM and PROJECTON
        patterns.
        """
        check = self.__check_pattern('SINOGRAM') + \
            self.__check_pattern('PROJECTION')
        # '==' compares values; 'is' compared int identity.
        if check == 2 and len(self.get_shape()) > 2:
            self.__set_main_axis('SINOGRAM')
            self.__set_main_axis('PROJECTION')

    def __check_pattern(self, pattern_name):
        """ Check if a pattern exists.

        :returns: 1 if the pattern exists, else 0
        :rtype: int
        """
        return 1 if pattern_name in self.get_data_patterns() else 0

    def __convert_pattern_directions(self, dtype):
        """ Replace negative indices in pattern kwargs.
        """
        pattern = self.get_data_patterns()[dtype]
        if 'main_dir' in pattern.keys():
            del pattern['main_dir']

        nDims = sum([len(i) for i in pattern.values()])
        for p in pattern:
            ddirs = pattern[p]
            pattern[p] = self._non_negative_directions(ddirs, nDims)

    def _non_negative_directions(self, ddirs, nDims):
        """ Replace negative indexing values with positive counterparts.

        :params tuple(int) ddirs: data dimension indices
        :params int nDims: The number of data dimensions
        :returns: non-negative data dimension indices
        :rtype: tuple(int)
        """
        return tuple(nDims + d if d < 0 else d for d in ddirs)

    def __set_main_axis(self, pname):
        """ Set the ``main_dir`` pattern kwarg to the fastest changing
        dimension
        """
        patterns = self.get_data_patterns()
        # '==' compares string values; 'is' relied on string interning.
        n1 = 'PROJECTION' if pname == 'SINOGRAM' else 'SINOGRAM'
        d1 = patterns[n1]['core_dims']
        d2 = patterns[pname]['slice_dims']
        tdir = set(d1).intersection(set(d2))

        # this is required when a single sinogram exists in the mm case, and a
        # dimension is added via parameter tuning.
        if not tdir:
            tdir = [d2[0]]

        self.data_info.set(['data_patterns', pname, 'main_dir'],
                           list(tdir)[0])

    def get_axis_label_keys(self):
        """ Get axis_label names

        :returns: A list containing associated axis names for each dimension
        :rtype: list(str)
        """
        axis_labels = self.data_info.get('axis_labels')
        return [key for labels in axis_labels for key in labels.keys()]

    def amend_axis_label_values(self, slice_list):
        """ Amend all axis label values based on the slice_list parameter.\
        This is required if the data is reduced.
        """
        axis_labels = self.get_axis_labels()
        for i in range(len(slice_list)):
            # dict views are not subscriptable in Python 3
            label = list(axis_labels[i].keys())[0]
            if label in self.meta_data.get_dictionary().keys():
                values = self.meta_data.get(label)
                preview_sl = [slice(None)] * len(values.shape)
                preview_sl[0] = slice_list[i]
                # numpy requires a tuple (not a list) for multi-dim slicing
                self.meta_data.set(label, values[tuple(preview_sl)])

    def get_core_dimensions(self):
        """ Get the core data dimensions associated with the current pattern.

        :returns: value associated with pattern key ``core_dims``
        :rtype: tuple
        """
        pattern = self._get_plugin_data().get_pattern()
        return list(pattern.values())[0]['core_dims']

    def get_slice_dimensions(self):
        """ Get the slice data dimensions associated with the current pattern.

        :returns: value associated with pattern key ``slice_dims``
        :rtype: tuple
        """
        pattern = self._get_plugin_data().get_pattern()
        return list(pattern.values())[0]['slice_dims']

    def get_itemsize(self):
        """ Returns bytes per entry

        If no dtype has been set, ``set_dtype(None)`` is called first.
        """
        dtype = self.get_dtype()
        # Test for None explicitly: a plain numpy dtype has len() == 0 and is
        # therefore falsy, so 'not dtype' wrongly treated it as unset.
        # NOTE(review): assumes get_dtype() returns None when unset -- confirm.
        if dtype is None:
            self.set_dtype(None)
            dtype = self.get_dtype()
        return dtype.itemsize
class Data(DataCreate):
    """The Data class dynamically inherits from transport specific data class
    and holds the data array, along with associated information.
    """

    def __init__(self, name, exp):
        super(Data, self).__init__(name)
        self.meta_data = MetaData()
        self.pattern_list = self.__get_available_pattern_list()
        self.data_info = MetaData()
        self.__initialise_data_info(name)
        self._preview = Preview(self)
        self.exp = exp
        self.group_name = None
        self.group = None
        self._plugin_data_obj = None
        self.raw = None
        self.backing_file = None
        self.data = None
        self.next_shape = None
        self.orig_shape = None
        self.previous_pattern = None
        self.transport_data = None

    def __initialise_data_info(self, name):
        """ Initialise entries in the data_info meta data.
        """
        self.data_info.set('name', name)
        self.data_info.set('data_patterns', {})
        self.data_info.set('shape', None)
        self.data_info.set('nDims', None)

    def _set_plugin_data(self, plugin_data_obj):
        """ Encapsulate a PluginData object.
        """
        self._plugin_data_obj = plugin_data_obj

    def _clear_plugin_data(self):
        """ Set encapsulated PluginData object to None.
        """
        self._plugin_data_obj = None

    def _get_plugin_data(self):
        """ Get encapsulated PluginData object.

        :raises Exception: if no PluginData object is associated.
        """
        if self._plugin_data_obj is None:
            raise Exception("There is no PluginData object associated with "
                            "the Data object.")
        return self._plugin_data_obj

    def get_preview(self):
        """ Get the Preview instance associated with the data object
        """
        return self._preview

    def _set_transport_data(self, transport):
        """ Import the data transport mechanism.

        Imports ``savu.data.transport_data.<transport>_transport_data`` and
        stores an instance of it (built with this Data object) in
        ``self.transport_data``.
        """
        transport_data = "savu.data.transport_data." + transport + \
            "_transport_data"
        transport_data = cu.import_class(transport_data)
        self.transport_data = transport_data(self)

    def _get_transport_data(self):
        return self.transport_data

    def __deepcopy__(self, memo):
        """ Copy the data object.
        """
        name = self.data_info.get('name')
        return dsu._deepcopy_data_object(self, Data(name, self.exp))

    def get_data_patterns(self):
        """ Get data patterns associated with this data object.

        :returns: A dictionary of associated patterns.
        :rtype: dict
        """
        return self.data_info.get('data_patterns')

    def _set_previous_pattern(self, pattern):
        self.previous_pattern = pattern

    def get_previous_pattern(self):
        return self.previous_pattern

    def set_shape(self, shape):
        """ Set the dataset shape.
        """
        self.data_info.set('shape', shape)
        self.__check_dims()

    def set_original_shape(self, shape):
        """ Set the original data shape before previewing
        """
        self.orig_shape = shape
        self.set_shape(shape)

    def get_shape(self):
        """ Get the dataset shape

        :returns: data shape
        :rtype: tuple
        """
        return self.data_info.get('shape')

    def __check_dims(self):
        """ Check the ``shape`` and ``nDims`` entries in the data_info
        meta_data dictionary are equal.
        """
        nDims = self.data_info.get("nDims")
        shape = self.data_info.get('shape')
        if nDims and len(shape) != nDims:
            error_msg = ("The number of axis labels, %d, does not "
                         "coincide with the number of data "
                         "dimensions %d." % (nDims, len(shape)))
            raise Exception(error_msg)

    def _set_name(self, name):
        self.data_info.set('name', name)

    def get_name(self):
        """ Get data name.

        :returns: the name associated with the dataset
        :rtype: str
        """
        return self.data_info.get('name')

    def __get_available_pattern_list(self):
        """ Get a list of ALL pattern names that are currently allowed in the
        framework.
        """
        return dsu.get_available_pattern_types()

    def add_pattern(self, dtype, **kwargs):
        """ Add a pattern.

        :params str dtype: The *type* of pattern to add, which can be anything
            from :data:`savu.data.data_structures.utils.pattern_list`
        :keyword tuple core_dims: Dimension indices of core dimensions
        :keyword tuple slice_dims: Dimension indices of slice dimensions
        :raises Exception: if ``dtype`` is not an available pattern, or the
            pattern's number of dimensions disagrees with ``nDims``.
        """
        if dtype not in self.pattern_list:
            # the arguments must be %-formatted into the message
            raise Exception(
                "The data pattern '%s' does not exist. Please "
                "choose from the following list: \n'%s'"
                % (dtype, str(self.pattern_list)))

        nDims = 0
        for key, dims in kwargs.items():
            nDims += len(dims)
            self.data_info.set(['data_patterns', dtype, key], dims)
        self.__convert_pattern_directions(dtype)

        if self.get_shape():
            diff = len(self.get_shape()) - nDims
            if diff:
                pattern = {dtype: self.get_data_patterns()[dtype]}
                self._set_data_patterns(pattern)
                nDims += diff

        try:
            actualDims = self.data_info.get('nDims')
        except KeyError:
            actualDims = None
        # nDims is initialised to None, so the first pattern added before
        # set_axis_labels() defines the number of dimensions
        if actualDims is None:
            self.data_info.set('nDims', nDims)
        elif nDims != actualDims:
            err_msg = ("The pattern %s has an incorrect number of "
                       "dimensions: %d required but %d specified."
                       % (dtype, actualDims, nDims))
            raise Exception(err_msg)

    def add_volume_patterns(self, x, y, z):
        """ Adds 3D volume patterns

        :params int x: dimension to be associated with x-axis
        :params int y: dimension to be associated with y-axis
        :params int z: dimension to be associated with z-axis
        """
        self.add_pattern("VOLUME_YZ", **self.__get_dirs_for_volume(y, z, x))
        self.add_pattern("VOLUME_XZ", **self.__get_dirs_for_volume(x, z, y))
        self.add_pattern("VOLUME_XY", **self.__get_dirs_for_volume(x, y, z))

    def __get_dirs_for_volume(self, dim1, dim2, sdir):
        """ Calculate core_dims and slice_dims for a 3D volume pattern.

        ``sdir`` is the first (fastest changing) slice dimension; any other
        non-core dimensions follow in increasing order.
        """
        # TODO: need to add this for other patterns
        other_dims = [d for d in range(len(self.get_shape()))
                      if d not in (dim1, dim2, sdir)]
        return {'core_dims': (dim1, dim2),
                'slice_dims': tuple([sdir] + other_dims)}

    def set_axis_labels(self, *args):
        """ Set the axis labels associated with each data dimension.

        :arg str: Each arg should be of the form ``name.unit``. If ``name`` is\
        a data_obj.meta_data entry, it will be output to the final .nxs file.
        Dict arguments are used as-is.
        """
        self.data_info.set('nDims', len(args))
        axis_labels = []
        for arg in args:
            if isinstance(arg, dict):
                axis_labels.append(arg)
                continue
            try:
                axis = arg.split('.')
                axis_labels.append({axis[0]: axis[1]})
            except (AttributeError, IndexError, TypeError):
                # non-string args (data arrives here, but that may be an
                # error) and labels without a unit are skipped
                pass
        self.data_info.set('axis_labels', axis_labels)

    def get_axis_labels(self):
        """ Get axis labels.

        :returns: Axis labels
        :rtype: list(dict)
        """
        return self.data_info.get('axis_labels')

    def get_data_dimension_by_axis_label(self, name, contains=False):
        """ Get the dimension of the data associated with a particular
        axis_label.

        :param str name: The name of the axis_label
        :keyword bool contains: Set this flag to true if the name is only part
            of the axis_label name
        :returns: The associated axis number
        :rtype: int
        :raises Exception: if no matching axis label exists.
        """
        axis_labels = self.data_info.get('axis_labels')
        for i, labels in enumerate(axis_labels):
            if contains is True:
                if any(name in label for label in labels):
                    return i
            elif name in labels:
                return i
        raise Exception("Cannot find the specifed axis label.")

    def _finalise_patterns(self):
        """ Adds a main axis (fastest changing) to SINOGRAM and PROJECTION
        patterns.
        """
        check = self.__check_pattern('SINOGRAM') + \
            self.__check_pattern('PROJECTION')
        # the main axis needs both patterns to be present
        if check == 2 and len(self.get_shape()) > 2:
            self.__set_main_axis('SINOGRAM')
            self.__set_main_axis('PROJECTION')

    def __check_pattern(self, pattern_name):
        """ Check if a pattern exists.

        :returns: 1 if the pattern exists, otherwise 0
        :rtype: int
        """
        patterns = self.get_data_patterns()
        try:
            patterns[pattern_name]
        except KeyError:
            return 0
        return 1

    def __convert_pattern_directions(self, dtype):
        """ Replace negative indices in pattern kwargs.

        Any existing ``main_dir`` entry is removed first so that it is not
        counted towards the number of data dimensions.
        """
        pattern = self.get_data_patterns()[dtype]
        if 'main_dir' in pattern:
            del pattern['main_dir']

        nDims = sum(len(dims) for dims in pattern.values())
        for key in pattern:
            pattern[key] = self._non_negative_directions(pattern[key], nDims)

    def _non_negative_directions(self, ddirs, nDims):
        """ Replace negative indexing values with positive counterparts.

        :params tuple(int) ddirs: data dimension indices
        :params int nDims: The number of data dimensions
        :returns: non-negative data dimension indices
        :rtype: tuple(int)
        """
        return tuple(nDims + d if d < 0 else d for d in ddirs)

    def __set_main_axis(self, pname):
        """ Set the ``main_dir`` pattern kwarg to the fastest changing
        dimension.
        """
        patterns = self.get_data_patterns()
        # compare strings by value: ``is`` only works if both are interned
        n1 = 'PROJECTION' if pname == 'SINOGRAM' else 'SINOGRAM'
        d1 = patterns[n1]['core_dims']
        d2 = patterns[pname]['slice_dims']
        tdir = set(d1).intersection(set(d2))

        # this is required when a single sinogram exists in the mm case, and a
        # dimension is added via parameter tuning.
        if not tdir:
            tdir = [d2[0]]

        self.data_info.set(['data_patterns', pname, 'main_dir'],
                           list(tdir)[0])

    def get_axis_label_keys(self):
        """ Get axis_label names

        :returns: A list containing associated axis names for each dimension
        :rtype: list(str)
        """
        return [key for labels in self.data_info.get('axis_labels')
                for key in labels]

    def amend_axis_label_values(self, slice_list):
        """ Amend all axis label values based on the slice_list parameter.\
        This is required if the data is reduced.

        For every axis whose label is also a ``meta_data`` entry, the first
        axis of the stored values is sliced with the matching entry of
        ``slice_list``.
        """
        axis_labels = self.get_axis_labels()
        for i, dim_slice in enumerate(slice_list):
            # dict views are not indexable on Python 3
            label = list(axis_labels[i].keys())[0]
            if label in self.meta_data.get_dictionary():
                values = self.meta_data.get(label)
                preview_sl = [slice(None)] * len(values.shape)
                preview_sl[0] = dim_slice
                # NumPy requires a tuple (not a list) of slices
                self.meta_data.set(label, values[tuple(preview_sl)])

    def get_core_dimensions(self):
        """
        Get the core data dimensions associated with the current pattern.

        :returns: value associated with pattern key ``core_dims``
        :rtype: tuple
        """
        pattern = self._get_plugin_data().get_pattern()
        return list(pattern.values())[0]['core_dims']

    def get_slice_dimensions(self):
        """
        Get the slice data dimensions associated with the current pattern.

        :returns: value associated with pattern key ``slice_dims``
        :rtype: tuple
        """
        pattern = self._get_plugin_data().get_pattern()
        return list(pattern.values())[0]['slice_dims']

    def _add_raw_data_obj(self, data_obj):
        """ Wrap ``data_obj`` in a NoImageKey data type.

        ``data_obj.raw`` must be an ImageKey or NoImageKey instance.  After
        the call, ``data_obj.data`` is a NoImageKey whose projection dimension
        is this object's ``rotation_angle`` axis.  If the raw data is
        ImageKey-based, its image key is converted and passed on as a fake key.

        :raises Exception: if the raw data type is not recognised.
        """
        from savu.data.data_structures.data_types.data_plus_darks_and_flats\
            import ImageKey, NoImageKey

        if not isinstance(data_obj.raw, (ImageKey, NoImageKey)):
            raise Exception('Raw data type not recognised.')

        proj_dim = self.get_data_dimension_by_axis_label('rotation_angle')
        data_obj.data = NoImageKey(data_obj, None, proj_dim)
        data_obj.data._set_dark_and_flat()

        if isinstance(data_obj.raw, ImageKey):
            data_obj.raw._convert_to_noimagekey(data_obj.data)
            data_obj.data._set_fake_key(data_obj.raw.get_image_key())
class PluginData(object):
    """ The PluginData class contains plugin specific information about a Data
    object for the duration of a plugin.  An instance of the class is
    encapsulated inside the Data object during the plugin run
    """

    def __init__(self, data_obj, plugin=None):
        self.data_obj = data_obj
        self._preview = None
        self.data_obj._set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = None
        self.pad_dict = None
        self.shape = None
        self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []
        self._plugin = plugin
        self.fixed_dims = True
        self.split = None
        self.boundary_padding = None
        self.no_squeeze = False
        self.pre_tuning_shape = None
        self._frame_limit = None
        # The real value is set by plugin_data_setup().  It is initialised
        # here so that the attribute always exists, because several methods
        # read it.
        self.max_frames = None

    def _get_preview(self):
        return self._preview

    def get_total_frames(self):
        """ Get the total number of frames to process (all MPI processes).

        This is the product of the data shape over the slice dimensions of
        the current pattern.

        :returns: Number of frames
        :rtype: int
        """
        shape = self.data_obj.get_shape()
        slice_dims = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]["slice_dims"]
        total = 1
        for dim in slice_dims:
            total *= shape[dim]
        return total

    def __set_pattern(self, name):
        """ Set the pattern related information in the meta data dict.
        """
        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set("name", name)
        self.meta_data.set("core_dims", pattern['core_dims'])
        self.__set_slice_dimensions()

    def get_pattern_name(self):
        """ Get the pattern name.

        :returns: the pattern name
        :rtype: str
        :raises Exception: if the pattern name has not been set.
        """
        try:
            return self.meta_data.get("name")
        except KeyError:
            raise Exception("The pattern name has not been set.")

    def get_pattern(self):
        """ Get the current pattern.

        :returns: dict of the pattern name against the pattern.
        :rtype: dict
        """
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def __set_shape(self):
        """ Set the shape of the plugin data processing chunk.

        The core shape holds the extents of the core dimensions.  The
        processing shape also includes the first slice dimension (length
        max_frames_process) when more than one frame is processed or
        squeezing is disabled.
        """
        core_dir = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()
        # plugin dimensions are kept in increasing data-dimension order
        dirs = sorted(set(core_dir + (slice_dir[0],)))
        slice_idx = dirs.index(slice_dir[0])
        dshape = self.data_obj.get_shape()
        shape = [dshape[core] for core in sorted(set(core_dir))]
        self.__set_core_shape(tuple(shape))

        mfp = self._get_max_frames_process()
        if mfp > 1 or self._get_no_squeeze():
            shape.insert(slice_idx, mfp)
        self.shape = tuple(shape)

    def _set_shape_transfer(self, slice_size):
        """ Return the shape of each data transfer.

        Core dimensions keep their full extent.  Slice dimensions take the
        values in ``slice_size``, padded with 1 for any dimensions added by
        parameter tuning.
        """
        dshape = self.data_obj.get_shape()
        shape_before_tuning = self._get_shape_before_tuning()
        add = [1]*(len(dshape) - len(shape_before_tuning))
        slice_size = list(slice_size) + add

        core_dir = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()
        shape = [None]*len(dshape)
        for dim in core_dir:
            shape[dim] = dshape[dim]
        for i, dim in enumerate(slice_dir):
            shape[dim] = slice_size[i]
        return tuple(shape)

    def __get_slice_size(self, mft):
        """ Calculate the number of frames transfer in each dimension given
        mft.

        The result is also stored as ``size_list`` in the meta data.
        """
        dshape = list(self.data_obj.get_shape())
        if 'fixed_dimensions' in self.meta_data.get_dictionary():
            for d in self.meta_data.get('fixed_dimensions'):
                dshape[d] = 1

        dshape = [dshape[i] for i in self.meta_data.get('slice_dims')]
        size_list = [1]*len(dshape)
        i = 0
        # NOTE(review): if mft exceeds what the slice dimensions can hold,
        # this loop runs past the end of size_list (IndexError) - confirm.
        while mft > 1:
            size_list[i] = min(dshape[i], mft)
            n_frames = np.prod(size_list)
            if n_frames > 1:
                mft -= n_frames
            i += 1

        self.meta_data.set('size_list', size_list)
        return size_list

    def set_bytes_per_frame(self):
        """ Return the size of a single frame in bytes.

        :returns: the frame shape (extents of the core dimensions) and the
            number of bytes in one frame
        :rtype: tuple
        """
        nBytes = self.data_obj.get_itemsize()
        # dict views are not indexable on Python 3
        dims = list(self.get_pattern().values())[0]['core_dims']
        frame_shape = [self.data_obj.get_shape()[d] for d in dims]
        b_per_f = np.prod(frame_shape)*nBytes
        return frame_shape, b_per_f

    def get_shape(self):
        """ Get the shape of the data (without padding) that is passed to the
        plugin process_frames method.
        """
        return self.shape

    def _set_padded_shape(self):
        pass

    def get_padded_shape(self):
        """ Get the shape of the data (with padding) that is passed to the
        plugin process_frames method.
        """
        return self.shape

    def get_shape_transfer(self):
        """ Get the shape of the plugin data to be transferred each time.
        """
        return self.meta_data.get('transfer_shape')

    def __set_core_shape(self, shape):
        """ Set the core shape to hold only the shape of the core dimensions
        """
        self.core_shape = shape

    def get_core_shape(self):
        """ Get the shape of the core dimensions only.

        :returns: shape of core dimensions
        :rtype: tuple
        """
        return self.core_shape

    def _set_shape_before_tuning(self, shape):
        """ Set the shape of the full dataset used during each run of the \
        plugin (i.e. ignore extra dimensions due to parameter tuning). """
        self.pre_tuning_shape = shape

    def _get_shape_before_tuning(self):
        """ Return the shape of the full dataset used during each run of the \
        plugin (i.e. ignore extra dimensions due to parameter tuning). """
        if self.pre_tuning_shape:
            return self.pre_tuning_shape
        return self.data_obj.get_shape()

    def __check_dimensions(self, indices, core_dir, slice_dir, nDims):
        # compare ints by value, not identity
        if len(indices) != len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")

        if (len(core_dir)+len(slice_dir)) != nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def __set_slice_dimensions(self):
        """ Set the slice dimensions in the pluginData meta data dictionary.
        """
        slice_dirs = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['slice_dims']
        self.meta_data.set('slice_dims', slice_dirs)

    def get_slice_dimension(self):
        """
        Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()[0]
        return sorted(set(core_dirs + (slice_dir,))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """
        Return the dimension of the data in the plugin that has the specified
        axis label.
        """
        label_dim = self.data_obj.get_data_dimension_by_axis_label(
            label, contains=contains)
        plugin_dims = self.data_obj.get_core_dimensions()
        if self._get_max_frames_process() > 1 or self.max_frames == 'multiple':
            # add the *data* slice dimension; get_slice_dimension() is a
            # position among plugin dims and cannot be mixed with data dims
            plugin_dims += (self.data_obj.get_slice_dimensions()[0],)
        return sorted(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):
        """
        Reorder the slice dimensions.  The fastest changing slice dimension
        will always be the first one stated in the pattern key ``slice_dims``.
        The input param is a tuple stating the desired order of slicing
        dimensions relative to the current order.

        :raises Exception: if more dimensions are given than exist.
        """
        pattern = self.data_obj.get_data_patterns()[self.get_pattern_name()]
        slice_dirs = pattern['slice_dims']
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specifed.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        pattern['slice_dims'] = tuple(ordered + remaining)
        # keep the meta data copy in step with the pattern
        self.__set_slice_dimensions()

    def get_core_dimensions(self):
        """
        Return the position of the core dimensions in relation to the data
        handed to the plugin.
        """
        core_dims = self.data_obj.get_core_dimensions()
        first_slice_dim = (self.data_obj.get_slice_dimensions()[0],)
        plugin_dims = np.sort(core_dims + first_slice_dim)
        return np.searchsorted(plugin_dims, np.sort(core_dims))

    def set_fixed_dimensions(self, dims, values):
        """ Fix a data direction to the index in values list.

        :param list(int) dims: Directions to fix
        :param list(int) value: Index of fixed directions
        :raises Exception: if any of ``dims`` is not a slice dimension.
        """
        slice_dirs = self.data_obj.get_slice_dimensions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set("fixed_dimensions", dims)
        self.meta_data.set("fixed_dimensions_values", values)
        self.__set_slice_dimensions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        # NOTE: the plugin shape is not recomputed here (a call to
        # self.__set_shape() was previously commented out).

    def _get_fixed_dimensions(self):
        """ Get the fixed data directions and their indices

        :returns: Fixed directions and their associated values
        :rtype: list(list(int), list(int))
        """
        fixed = []
        values = []
        if 'fixed_dimensions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get("fixed_dimensions")
            values = self.meta_data.get("fixed_dimensions_values")
        return [fixed, values]

    def _get_data_slice_list(self, plist):
        """ Convert a plugin data slice list to a slice list for the whole
        dataset, i.e. add in any missing dimensions.
        """
        nDims = len(self.get_shape())
        # NOTE(review): get_core_dimensions() returns a numpy array, so this
        # ``+`` adds element-wise rather than concatenating - confirm intent.
        all_dims = self.get_core_dimensions() + self.get_slice_dimension()
        extra_dims = all_dims[nDims:]
        dlist = list(plist)
        for i in extra_dims:
            dlist.insert(i, slice(None))
        return tuple(dlist)

    def _get_max_frames_process(self):
        """ Get the number of frames to process for each run of process_frames.

        If the number of frames is not divisible by the previewing ``chunk``
        value then amend the number of frames to gcd(frames, chunk)

        :returns: Number of frames to process
        :rtype: int
        """
        if self._plugin and self._plugin.chunk > 1:
            frame_chunk = self.meta_data.get("max_frames_process")
            # chunks are indexed by data dimension (this class has no
            # get_slice_directions method)
            slice_dim = self.data_obj.get_slice_dimensions()[0]
            chunk = self.data_obj.get_preview().get_starts_stops_steps(
                key='chunks')[slice_dim]
            self.meta_data.set('max_frames_process', gcd(frame_chunk, chunk))
        return self.meta_data.get("max_frames_process")

    def _get_max_frames_transfer(self):
        """ Get the number of frames to transfer for each run of
        process_frames. """
        return self.meta_data.get('max_frames_transfer')

    def _set_no_squeeze(self):
        self.no_squeeze = True

    def _get_no_squeeze(self):
        return self.no_squeeze

    def _set_meta_data(self):
        """ Record frame-distribution information in the meta data.

        Sets ``shape``, ``sdir``, ``total_frames``, ``mpi_procs``,
        ``frames_per_process``, ``bytes_per_frame``, ``bytes_per_process``
        and ``frame_shape``.
        """
        fixed, _ = self._get_fixed_dimensions()
        sdir = [s for s in self.data_obj.get_slice_dimensions()
                if s not in fixed]
        shape = self.data_obj.get_shape()
        shape_before_tuning = self._get_shape_before_tuning()

        diff = len(shape) - len(shape_before_tuning)
        if diff:
            # ignore dimensions added by parameter tuning
            shape = shape_before_tuning
            sdir = sdir[:-diff]

        if 'fix_total_frames' in self.meta_data.get_dictionary():
            frames = self.meta_data.get('fix_total_frames')
        else:
            frames = np.prod([shape[d] for d in sdir])

        base_names = [p.__name__ for p in self._plugin.__class__.__bases__]
        processes = self.data_obj.exp.meta_data.get('processes')

        if 'GpuPlugin' in base_names:
            n_procs = len([n for n in processes if 'GPU' in n])
        else:
            n_procs = len(processes)

        # float division: integer division (Python 2) would defeat the ceil
        f_per_p = np.ceil(frames/float(n_procs))
        self.meta_data.set('shape', shape)
        self.meta_data.set('sdir', sdir)
        self.meta_data.set('total_frames', frames)
        self.meta_data.set('mpi_procs', n_procs)
        self.meta_data.set('frames_per_process', f_per_p)
        frame_shape, b_per_f = self.set_bytes_per_frame()
        self.meta_data.set('bytes_per_frame', b_per_f)
        self.meta_data.set('bytes_per_process', b_per_f*f_per_p)
        self.meta_data.set('frame_shape', frame_shape)

    def __log_max_frames(self, mft, mfp, check=True):
        logging.debug("Setting max frames transfer for plugin %s to %d" %
                      (self._plugin, mft))
        logging.debug("Setting max frames process for plugin %s to %d" %
                      (self._plugin, mfp))
        self.meta_data.set('max_frames_process', mfp)
        if check:
            self.__check_distribution(mft)
        # (((total_frames/mft)/mpi_procs) % 1)

    def __check_distribution(self, mft):
        """ Warn if the frames will be unevenly distributed across the MPI
        processes. """
        warn_threshold = 0.85
        nprocs = self.meta_data.get('mpi_procs')
        nframes = self.meta_data.get('total_frames')
        temp = (((nframes/mft)/float(nprocs)) % 1)
        if temp != 0.0 and temp < warn_threshold:
            shape = self.meta_data.get('shape')
            sdir = self.meta_data.get('sdir')
            # logging.warn is a deprecated alias of logging.warning
            logging.warning('UNEVEN FRAME DISTRIBUTION: shape %s, nframes %s ' +
                            'sdir %s, nprocs %s', shape, nframes, sdir, nprocs)

    def _set_padding_dict(self):
        """ Convert a plain padding dict into a Padding object. """
        if self.padding and not isinstance(self.padding, Padding):
            self.pad_dict = copy.deepcopy(self.padding)
            self.padding = Padding(self)
            for key in self.pad_dict.keys():
                getattr(self.padding, key)(self.pad_dict[key])

    def plugin_data_setup(self, pattern, nFrames, split=None):
        """ Setup the PluginData object.

        :param str pattern: A pattern name
        :param int nFrames: How many frames to process at a time.  Choose from\
            'single', 'multiple', 'fixed_multiple' or an integer (an integer \
            should only ever be passed in exceptional circumstances).  A list \
            ``[nFrames, frame_limit]`` also sets the frame limit.
        """
        self.__set_pattern(pattern)
        if isinstance(nFrames, list):
            nFrames, self._frame_limit = nFrames
        self.max_frames = nFrames
        self.split = split

    def plugin_data_transfer_setup(self, copy=None, calc=None):
        """ Set up the plugin data transfer frame parameters.

        If copy=pData (another PluginData instance) then copy its max frames
        values.  If calc=pData then derive them from the number of
        process_frames calls per transfer in ``calc``.  Otherwise calculate
        them via the transport layer.
        """
        chunks = \
            self.data_obj.get_preview().get_starts_stops_steps(key='chunks')

        if not copy and not calc:
            mft, mft_shape, mfp = self._calculate_max_frames()
        elif calc:
            max_mft = calc.meta_data.get('max_frames_transfer')
            max_mfp = calc.meta_data.get('max_frames_process')
            nProc = int(np.ceil(max_mft/float(max_mfp)))
            # NOTE(review): if max_frames is 'multiple' this multiplies a
            # string - confirm only 'single' or ints reach this branch.
            mfp = 1 if self.max_frames == 'single' else self.max_frames
            mft = nProc*mfp
            mft_shape = self._set_shape_transfer(self.__get_slice_size(mft))
        elif copy:
            mft = copy._get_max_frames_transfer()
            mft_shape = self._set_shape_transfer(self.__get_slice_size(mft))
            mfp = copy._get_max_frames_process()

        self.__set_max_frames(mft, mft_shape, mfp)

        if self._plugin and mft \
                and (chunks[self.data_obj.get_slice_dimensions()[0]] % mft):
            self._plugin.chunk = True
        self.__set_shape()

    def _calculate_max_frames(self):
        """ Calculate the max frames transfer/process via the transport layer.

        :returns: max frames transfer, transfer shape (None when the max
            frames transfer is falsy) and max frames process
        :rtype: tuple
        """
        nFrames = self.max_frames
        self.__perform_checks(nFrames)
        td = self.data_obj._get_transport_data()
        mft, size_list = td._calc_max_frames_transfer(nFrames)
        self.meta_data.set('size_list', size_list)
        mfp = td._calc_max_frames_process(nFrames)
        # previously unbound (UnboundLocalError) when mft was falsy
        mft_shape = None
        if mft:
            mft_shape = self._set_shape_transfer(list(size_list))
        return mft, mft_shape, mfp

    def __set_max_frames(self, mft, mft_shape, mfp):
        self.meta_data.set('max_frames_transfer', mft)
        self.meta_data.set('transfer_shape', mft_shape)
        self.meta_data.set('max_frames_process', mfp)
        self.__log_max_frames(mft, mfp)
        # Retain the shape if the first slice dimension has length 1
        if mfp == 1 and self.max_frames == 'multiple':
            self._set_no_squeeze()

    def _get_plugin_data_size_params(self):
        nBytes = self.data_obj.get_itemsize()
        frame_shape = self.meta_data.get('frame_shape')
        total_frames = self.meta_data.get('total_frames')
        tbytes = nBytes*np.prod(frame_shape)*total_frames

        params = {'nBytes': nBytes, 'frame_shape': frame_shape,
                  'total_frames': total_frames, 'transfer_bytes': tbytes}
        return params

    def __perform_checks(self, nFrames):
        options = ['single', 'multiple']
        if not isinstance(nFrames, int) and nFrames not in options:
            # parentheses make the literals one message (without them only
            # the first line would be assigned)
            e_str = ("The value of nFrames is not recognised. Please choose "
                     "from 'single' and 'multiple' (or an integer in "
                     "exceptional circumstances).")
            raise Exception(e_str)

    def get_frame_limit(self):
        return self._frame_limit

    def get_current_frame_idx(self):
        """ Returns the index of the frames currently being processed.

        Indices past the last frame are clamped to the last frame.
        """
        global_index = self._plugin.get_global_frame_index()
        count = self._plugin.get_process_frames_counter()
        mfp = self.meta_data.get('max_frames_process')
        start = global_index[count]*mfp
        index = np.arange(start, start + mfp)
        nFrames = self.get_total_frames()
        index[index >= nFrames] = nFrames - 1
        return index
class PluginData(object):
    """ Plugin specific information about a Data object (older interface that
    uses ``core_dir``/``slice_dir`` pattern keys and
    ``get_meta_data``/``set_meta_data`` meta data accessors).
    """

    def __init__(self, data_obj):
        self.data_obj = data_obj
        self.data_obj.set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = None
        # this flag determines which data is passed. If false then just the
        # data, if true then all data including dark and flat fields.
        self.selected_data = False
        self.shape = None
        self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []

    def get_total_frames(self):
        """ Return the product of the data shape over the slice directions. """
        shape = self.data_obj.get_shape()
        slice_dir = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]["slice_dir"]
        total = 1
        for tslice in slice_dir:
            total *= shape[tslice]
        return total

    def set_pattern(self, name):
        """ Store the pattern name, core and slice directions. """
        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set_meta_data("name", name)
        self.meta_data.set_meta_data("core_dir", pattern['core_dir'])
        self.set_slice_directions()

    def get_pattern_name(self):
        """ Return the pattern name.

        :raises Exception: if the pattern name has not been set.
        """
        name = self.meta_data.get_meta_data("name")
        if name is None:
            raise Exception("The pattern name has not been set.")
        return name

    def get_pattern(self):
        """ Return {pattern name: pattern dict} for the current pattern. """
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def set_shape(self):
        """ Set the plugin shape (core shape, plus the frame chunk inserted
        at the first slice direction's position when it is > 1). """
        core_dir = self.get_core_directions()
        slice_dir = self.get_slice_directions()
        # plugin dimensions are kept in increasing data-dimension order
        dirs = sorted(set(core_dir + (slice_dir[0],)))
        slice_idx = dirs.index(slice_dir[0])
        dshape = self.data_obj.get_shape()
        shape = [dshape[core] for core in sorted(set(core_dir))]
        self.set_core_shape(tuple(shape))
        if self.get_frame_chunk() > 1:
            shape.insert(slice_idx, self.get_frame_chunk())
        self.shape = tuple(shape)

    def get_shape(self):
        return self.shape

    def set_core_shape(self, shape):
        self.core_shape = shape

    def get_core_shape(self):
        return self.core_shape

    def check_dimensions(self, indices, core_dir, slice_dir, nDims):
        """ Exit if the indices or dimension counts are inconsistent. """
        # compare ints by value, not identity
        if len(indices) != len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")

        if (len(core_dir)+len(slice_dir)) != nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def set_slice_directions(self):
        """ Copy the pattern's slice directions into the meta data. """
        slice_dirs = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['slice_dir']
        self.meta_data.set_meta_data('slice_dir', slice_dirs)

    def get_slice_directions(self):
        return self.meta_data.get_meta_data('slice_dir')

    def get_slice_dimension(self):
        """
        Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.get_core_directions()
        slice_dir = self.get_slice_directions()[0]
        return sorted(set(core_dirs + (slice_dir,))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """
        Return the dimension of the data in the plugin that has the specified
        axis label.
        """
        label_dim = \
            self.data_obj.find_axis_label_dimension(label, contains=contains)
        plugin_dims = self.get_core_directions()
        if self.get_frame_chunk() > 1:
            plugin_dims += (self.get_slice_directions()[0],)
        return sorted(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):
        """
        Reorder the slice directions.  The fastest changing slice direction
        will always be the first one. The input param is a tuple stating the
        desired order of slicing directions relative to the current order.

        :raises Exception: if more directions are given than exist.
        """
        slice_dirs = self.get_slice_directions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specifed.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        # PluginData has no get_current_pattern(); update the pattern dict
        # directly and keep the meta data copy in step
        pattern = self.data_obj.get_data_patterns()[self.get_pattern_name()]
        pattern['slice_dir'] = new_slice_dirs
        self.set_slice_directions()

    def get_core_directions(self):
        """ Return the core directions of the current pattern. """
        return self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['core_dir']

    def set_fixed_directions(self, dims, values):
        """ Fix slice directions ``dims`` at the indices in ``values``.

        :raises Exception: if any of ``dims`` is not a slice direction.
        """
        slice_dirs = self.get_slice_directions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set_meta_data("fixed_directions", dims)
        self.meta_data.set_meta_data("fixed_directions_values", values)
        self.set_slice_directions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        self.set_shape()

    def get_fixed_directions(self):
        """ Return [fixed directions, fixed indices] (empty lists if unset).
        """
        fixed = []
        values = []
        if 'fixed_directions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get_meta_data("fixed_directions")
            values = self.meta_data.get_meta_data("fixed_directions_values")
        return [fixed, values]

    def set_frame_chunk(self, nFrames):
        # number of frames to process at a time
        self.meta_data.set_meta_data("nFrames", nFrames)

    def get_frame_chunk(self):
        return self.meta_data.get_meta_data("nFrames")

    def get_index(self, indices):
        """ Build a slice tuple selecting one element at ``indices`` in each
        slice direction and everything in the other dimensions. """
        shape = self.get_shape()
        nDims = len(shape)
        # PluginData has neither get_current_pattern_name() nor
        # get_data_patterns(); use the pattern via data_obj
        pattern = self.data_obj.get_data_patterns()[self.get_pattern_name()]
        core_dir = pattern["core_dir"]
        slice_dir = pattern["slice_dir"]
        self.check_dimensions(indices, core_dir, slice_dir, nDims)

        index = [slice(None)]*nDims
        for count, tdir in enumerate(slice_dir):
            index[tdir] = slice(indices[count], indices[count]+1, 1)
        return tuple(index)

    def plugin_data_setup(self, pattern_name, chunk):
        """ Set the pattern and frame chunk, then compute the plugin shape. """
        self.set_pattern(pattern_name)
        self.set_frame_chunk(chunk)
        self.set_shape()

    def set_temp_pad_dict(self, pad_dict):
        self.meta_data.set_meta_data('temp_pad_dict', pad_dict)

    def get_temp_pad_dict(self):
        """ Return the temporary padding dict, or None if not set. """
        return self.meta_data.get_dictionary().get('temp_pad_dict')

    def delete_temp_pad_dict(self):
        del self.meta_data.get_dictionary()['temp_pad_dict']