class Data(DataCreate):
    """Hold a dataset together with its associated information.

    The Data class dynamically inherits from a transport specific data class
    and holds the data array, along with associated information (patterns,
    shape, axis labels) stored in the ``data_info`` meta data.
    """

    def __init__(self, name, exp):
        """ Create an empty data object.

        :params str name: The dataset name (stored in ``data_info``)
        :params exp: The experiment object; its ``meta_data`` is read for the
            transport type and written with current/next pattern information
        """
        super(Data, self).__init__(name)
        self.meta_data = MetaData()
        self.pattern_list = self.__get_available_pattern_list()
        self.data_info = MetaData()
        self.__initialise_data_info(name)
        self._preview = Preview(self)
        self.exp = exp
        self.group_name = None
        self.group = None
        self._plugin_data_obj = None
        self.tomo_raw_obj = None
        self.backing_file = None
        self.data = None
        self.next_shape = None
        self.orig_shape = None

    def __initialise_data_info(self, name):
        """ Initialise entries in the data_info meta data.

        :params str name: The dataset name
        """
        self.data_info.set_meta_data('name', name)
        self.data_info.set_meta_data('data_patterns', {})
        self.data_info.set_meta_data('shape', None)
        self.data_info.set_meta_data('nDims', None)

    def _set_plugin_data(self, plugin_data_obj):
        """ Encapsulate a PluginData object. """
        self._plugin_data_obj = plugin_data_obj

    def _clear_plugin_data(self):
        """ Set encapsulated PluginData object to None. """
        self._plugin_data_obj = None

    def _get_plugin_data(self):
        """ Get encapsulated PluginData object.

        :returns: The encapsulated PluginData object
        :raises Exception: if no PluginData object has been set
        """
        if self._plugin_data_obj is None:
            raise Exception("There is no PluginData object associated with "
                            "the Data object.")
        return self._plugin_data_obj

    def get_preview(self):
        """ Get the Preview instance associated with the data object """
        return self._preview

    def _get_transport_data(self):
        """ Import the data transport mechanism.

        The module name is built from the ``transport`` entry of the
        experiment meta data.

        :returns: whatever ``cu.import_class`` returns for that module name
        """
        transport = self.exp.meta_data.get_meta_data("transport")
        transport_data = "savu.data.transport_data." + transport + \
            "_transport_data"
        return cu.import_class(transport_data)

    def __deepcopy__(self, memo):
        """ Copy the data object.

        A fresh Data object with the same name and experiment is created and
        handed to ``dsu._deepcopy_data_object`` along with this object.
        """
        name = self.data_info.get_meta_data('name')
        return dsu._deepcopy_data_object(self, Data(name, self.exp))

    def get_data_patterns(self):
        """ Get data patterns associated with this data object.

        :returns: A dictionary of associated patterns.
        :rtype: dict
        """
        return self.data_info.get_meta_data('data_patterns')

    def set_shape(self, shape):
        """ Set the dataset shape and check it against ``nDims``. """
        self.data_info.set_meta_data('shape', shape)
        self.__check_dims()

    def set_original_shape(self, shape):
        """ Record ``shape`` as the original shape and as the current shape.
        """
        self.orig_shape = shape
        self.set_shape(shape)

    def get_shape(self):
        """ Get the dataset shape

        :returns: data shape
        :rtype: tuple
        """
        return self.data_info.get_meta_data('shape')

    def __check_dims(self):
        """ Check the ``shape`` and ``nDims`` entries in the data_info
        meta_data dictionary are equal.

        The check is skipped while ``nDims`` is unset (falsy).

        :raises Exception: if the shape length differs from ``nDims``
        """
        nDims = self.data_info.get_meta_data("nDims")
        shape = self.data_info.get_meta_data('shape')
        if nDims and len(shape) != nDims:
            error_msg = ("The number of axis labels, %d, does not "
                         "coincide with the number of data "
                         "dimensions %d." % (nDims, len(shape)))
            raise Exception(error_msg)

    def get_name(self):
        """ Get data name.

        :returns: the name associated with the dataset
        :rtype: str
        """
        return self.data_info.get_meta_data('name')

    def __get_available_pattern_list(self):
        """ Get a list of ALL pattern names that are currently allowed in the
        framework.
        """
        return dsu.get_available_pattern_types()

    def add_pattern(self, dtype, **kwargs):
        """ Add a pattern.

        :params str dtype: The *type* of pattern to add, which must be one of
            the names in ``self.pattern_list``
        :keyword tuple core_dir: Dimension indices of core dimensions
        :keyword tuple slice_dir: Dimension indices of slice dimensions
        :raises Exception: if ``dtype`` is not an available pattern, or if the
            pattern dimensions do not match the ``nDims`` entry
        """
        if dtype not in self.pattern_list:
            # The message must be interpolated here: passing the values as
            # extra Exception arguments leaves the '%s' placeholders unfilled.
            raise Exception("The data pattern '%s'does not exist. Please "
                            "choose from the following list: \n'%s'"
                            % (dtype, str(self.pattern_list)))

        nDims = 0
        for args in kwargs:
            nDims += len(kwargs[args])
            self.data_info.set_meta_data(['data_patterns', dtype, args],
                                         kwargs[args])

        self.__convert_pattern_directions(dtype)
        if self.get_shape():
            diff = len(self.get_shape()) - nDims
            if diff:
                # the pattern does not cover every data dimension
                pattern = {dtype: self.get_data_patterns()[dtype]}
                self._add_extra_dims_to_patterns(pattern)
                nDims += diff

        # NOTE(review): nDims is initialised to None in
        # __initialise_data_info; presumably get_meta_data then returns None
        # instead of raising KeyError, so the mismatch branch would be taken
        # when no axis labels have been set - confirm against MetaData.
        try:
            if nDims != self.data_info.get_meta_data("nDims"):
                actualDims = self.data_info.get_meta_data('nDims')
                err_msg = ("The pattern %s has an incorrect number of "
                           "dimensions: %d required but %d specified."
                           % (dtype, actualDims, nDims))
                raise Exception(err_msg)
        except KeyError:
            self.data_info.set_meta_data('nDims', nDims)

    def add_volume_patterns(self, x, y, z):
        """ Adds 3D volume patterns

        :params int x: dimension to be associated with x-axis
        :params int y: dimension to be associated with y-axis
        :params int z: dimension to be associated with z-axis
        """
        self.add_pattern("VOLUME_YZ", **self.__get_dirs_for_volume(y, z, x))
        self.add_pattern("VOLUME_XZ", **self.__get_dirs_for_volume(x, z, y))
        self.add_pattern("VOLUME_XY", **self.__get_dirs_for_volume(x, y, z))

    def __get_dirs_for_volume(self, dim1, dim2, sdir):
        """ Calculate core_dir and slice_dir for a 3D volume pattern.

        :params int dim1: first core dimension
        :params int dim2: second core dimension
        :params int sdir: the first slice dimension
        :returns: dict with keys ``core_dir`` and ``slice_dir``
        :rtype: dict
        """
        used = [dim1, dim2, sdir]
        # *** need to add this for other patterns
        # any remaining data dimensions become additional slice dimensions
        slice_dir = [sdir]
        slice_dir.extend(ddir for ddir in range(len(self.get_shape()))
                         if ddir not in used)
        return {'core_dir': (dim1, dim2), 'slice_dir': tuple(slice_dir)}

    def set_axis_labels(self, *args):
        """ Set the axis labels associated with each data dimension.

        :arg str: Each arg should be of the form ``name.unit``. If ``name`` is\
        a data_obj.meta_data entry, it will be output to the final .nxs file.

        ``nDims`` is set to the number of arguments given. Arguments that are
        not strings of the form ``name.unit`` are skipped.
        """
        self.data_info.set_meta_data('nDims', len(args))
        axis_labels = []
        for arg in args:
            try:
                axis = arg.split('.')
                axis_labels.append({axis[0]: axis[1]})
            except (AttributeError, IndexError, TypeError):
                # data arrives here, but that may be an error
                pass
        self.data_info.set_meta_data('axis_labels', axis_labels)

    def get_axis_labels(self):
        """ Get axis labels.

        :returns: Axis labels
        :rtype: list(dict)
        """
        return self.data_info.get_meta_data('axis_labels')

    def find_axis_label_dimension(self, name, contains=False):
        """ Get the dimension of the data associated with a particular
        axis_label.

        :param str name: The name of the axis_label
        :keyword bool contains: Set this flag to true if the name is only part
            of the axis_label name
        :returns: The associated axis number
        :rtype: int
        :raises Exception: if no axis label matches
        """
        axis_labels = self.data_info.get_meta_data('axis_labels')
        for i, labels in enumerate(axis_labels):
            if contains is True:
                # substring match against each label name
                if any(name in key for key in labels.keys()):
                    return i
            elif name in labels:
                return i
        raise Exception("Cannot find the specifed axis label.")

    def _finalise_patterns(self):
        """ Adds a main axis (fastest changing) to SINOGRAM and PROJECTON
        patterns.

        Nothing is done unless both patterns exist and the data has more than
        two dimensions.
        """
        check = self.__check_pattern('SINOGRAM') + \
            self.__check_pattern('PROJECTION')
        if check == 2 and len(self.get_shape()) > 2:
            self.__set_main_axis('SINOGRAM')
            self.__set_main_axis('PROJECTION')

    def __check_pattern(self, pattern_name):
        """ Check if a pattern exists.

        :returns: 1 if the pattern exists, 0 otherwise
        :rtype: int
        """
        return 1 if pattern_name in self.get_data_patterns() else 0

    def __convert_pattern_directions(self, dtype):
        """ Replace negative indices in pattern kwargs.

        Any existing ``main_dir`` entry is removed from the pattern first.
        """
        pattern = self.get_data_patterns()[dtype]
        if 'main_dir' in pattern:
            del pattern['main_dir']
        nDims = sum(len(i) for i in pattern.values())
        for p in pattern:
            pattern[p] = self.non_negative_directions(pattern[p], nDims)

    def non_negative_directions(self, ddirs, nDims):
        """ Replace negative indexing values with positive counterparts.

        :params tuple(int) ddirs: data dimension indices
        :params int nDims: The number of data dimensions
        :returns: non-negative data dimension indices
        :rtype: tuple(int)
        """
        return tuple(nDims + d if d < 0 else d for d in ddirs)

    def __set_main_axis(self, pname):
        """ Set the ``main_dir`` pattern kwarg to the fastest changing
        dimension

        :params str pname: 'SINOGRAM' or 'PROJECTION'
        """
        patterns = self.get_data_patterns()
        # string equality, not identity: 'is' only works by interning accident
        n1 = 'PROJECTION' if pname == 'SINOGRAM' else 'SINOGRAM'
        d1 = patterns[n1]['core_dir']
        d2 = patterns[pname]['slice_dir']
        tdir = set(d1).intersection(set(d2))

        # this is required when a single sinogram exists in the mm case, and a
        # dimension is added via parameter tuning.
        if not tdir:
            tdir = [d2[0]]

        self.data_info.set_meta_data(['data_patterns', pname, 'main_dir'],
                                     list(tdir)[0])

    def get_axis_label_keys(self):
        """ Get axis_label names

        :returns: A list containing associated axis names for each dimension
        :rtype: list(str)
        """
        axis_labels = self.data_info.get_meta_data('axis_labels')
        return [key for labels in axis_labels for key in labels.keys()]

    def _get_current_and_next_patterns(self, datasets_lists):
        """ Get the current and next patterns associated with a dataset
        throughout the processing chain.

        The result, a list of ``{'current': ..., 'next': ...}`` dicts (one per
        out dataset of the first entry), is stored in the experiment meta data
        under ``current_and_next``.
        """
        current_datasets = datasets_lists[0]
        remaining = datasets_lists[1:]
        patterns_list = [
            {'current': current_data['pattern'],
             'next': self.__find_next_pattern(remaining,
                                              current_data['name'])}
            for current_data in current_datasets['out_datasets']]
        self.exp.meta_data.set_meta_data('current_and_next', patterns_list)

    def __find_next_pattern(self, datasets_lists, current_name):
        """ Find the pattern of the first in dataset named ``current_name``.

        :returns: the pattern, or an empty list if the name is never used
        """
        for next_data_list in datasets_lists:
            for next_data in next_data_list['in_datasets']:
                if next_data['name'] == current_name:
                    return next_data['pattern']
        return []

    def get_slice_directions(self):
        """ Get pattern slice_dir of pattern currently associated with the
        dataset (if any).

        :returns: the slicing dimensions.
        :rtype: tuple(int)
        """
        return self._get_plugin_data().get_slice_directions()
class Experiment(object):
    """ One instance of this class is created at the beginning of the
    processing chain and remains until the end. It holds the current data
    object and a dictionary containing all metadata.
    """

    def __init__(self, options):
        """ Set up the meta data and the empty in/out data index.

        :param dict options: passed to MetaData; ``options["process_file"]``
            is used to populate the plugin list
        """
        self.meta_data = MetaData(options)
        self.meta_data_setup(options["process_file"])
        self.index = {"in_data": {}, "out_data": {}}

    def meta_data_setup(self, process_file):
        """ Load the experiment collection and populate the plugin list from
        ``process_file``.
        """
        self.meta_data.load_experiment_collection()
        self.meta_data.plugin_list = PluginList()
        self.meta_data.plugin_list.populate_plugin_list(process_file)

    def create_data_object(self, dtype, name, bases=None):
        """ Return the data object ``name`` from ``self.index[dtype]``,
        creating it first if it does not exist.

        :param str dtype: index key, ``"in_data"`` or ``"out_data"``
        :param str name: dataset name
        :keyword list bases: extra base classes for a newly created object;
            the transport data class is always added after them
        :returns: the (possibly new) data object
        """
        try:
            self.index[dtype][name]
        except KeyError:
            # Work on a copy: appending to the argument would mutate the
            # caller's list, and a shared default list would accumulate
            # transport classes from every previous call.
            bases = list(bases) if bases else []
            self.index[dtype][name] = Data(name)
            data_obj = self.index[dtype][name]
            bases.append(data_obj.get_transport_data(
                self.meta_data.get_meta_data("transport")))
            data_obj.add_base_classes(bases)
        return self.index[dtype][name]

    def set_nxs_filename(self):
        """ Build the output file name and store it in the meta data under
        ``nxs_filename``.

        The name is derived from the backing file of the first in_data entry,
        with a time stamp appended, inside the ``out_path`` directory.
        """
        # list(...) rather than .keys()[0] so this also works where keys()
        # is a view; IndexError is still raised when there is no in_data.
        name = list(self.index["in_data"])[0]
        filename = os.path.basename(
            self.index["in_data"][name].backing_file.filename)
        filename = os.path.splitext(filename)[0]
        filename = os.path.join(
            self.meta_data.get_meta_data("out_path"),
            "%s_processed_%s.nxs" % (filename,
                                     time.strftime("%Y%m%d%H%M%S")))
        self.meta_data.set_meta_data("nxs_filename", filename)

    def clear_data_objects(self):
        """ Empty both the in_data and out_data indexes. """
        self.index["out_data"] = {}
        self.index["in_data"] = {}

    def clear_out_data_objects(self):
        """ Empty the out_data index. """
        self.index["out_data"] = {}

    def set_out_data_to_in(self):
        """ Make the current out_data the new in_data and reset out_data. """
        self.index["in_data"] = self.index["out_data"]
        self.index["out_data"] = {}

    def barrier(self):
        """ Wait at an MPI barrier if the ``mpi`` meta data entry is True. """
        if self.meta_data.get_meta_data('mpi') is True:
            logging.debug("About to hit a barrier")
            MPI.COMM_WORLD.Barrier()
            logging.debug("Past the barrier")

    def log(self, log_tag, log_level=logging.DEBUG):
        """ Log the contents of the experiment at the specified level

        :param str log_tag: label included in the header line
        :keyword int log_level: logging level to use
        """
        logging.log(log_level, "Experimental Parameters for %s", log_tag)
        for key, value in self.index["in_data"].items():
            logging.log(log_level, "in data (%s) shape = %s", key,
                        value.get_shape())
        # report the out_data index here, not in_data a second time
        for key, value in self.index["out_data"].items():
            logging.log(log_level, "out data (%s) shape = %s", key,
                        value.get_shape())
class PluginData(object):
    """ The PluginData class contains plugin specific information about a Data
    object for the duration of a plugin.  An instance of the class is
    encapsulated inside the Data object during the plugin run
    """

    def __init__(self, data_obj, plugin=None):
        """ Attach this object to ``data_obj`` and reset all plugin state.

        :param data_obj: the Data object this instance describes
        :keyword plugin: the plugin using the data (optional)
        """
        self.data_obj = data_obj
        self.data_obj._set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = None
        self.pad_dict = None
        self.shape = None
        self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []
        self._plugin = plugin
        self.fixed_dims = False
        self.end_pad = False
        self.split = None

    def get_total_frames(self):
        """ Get the total number of frames to process.

        This is the product of the data shape over all slice dimensions of
        the current pattern.

        :returns: Number of frames
        :rtype: int
        """
        slice_dir = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]["slice_dir"]
        shape = self.data_obj.get_shape()
        total = 1
        for tslice in slice_dir:
            total *= shape[tslice]
        return total

    def __set_pattern(self, name):
        """ Set the pattern related information in the meta data dict.

        :param str name: name of a pattern held by the data object
        """
        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set_meta_data("name", name)
        self.meta_data.set_meta_data("core_dir", pattern['core_dir'])
        self.__set_slice_directions()

    def get_pattern_name(self):
        """ Get the pattern name.

        :returns: the pattern name
        :rtype: str
        :raises Exception: if the stored name is None
        """
        name = self.meta_data.get_meta_data("name")
        if name is None:
            raise Exception("The pattern name has not been set.")
        return name

    def get_pattern(self):
        """ Get the current pattern.

        :returns: dict of the pattern name against the pattern.
        :rtype: dict
        """
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def __set_shape(self):
        """ Set the shape of the plugin data processing chunk.

        The core shape holds the data shape in each core dimension; when more
        than one frame is processed at a time the frame count is inserted at
        the position of the first slice dimension.
        """
        core_dir = self.get_core_directions()
        slice_dir = self.get_slice_directions()
        dirs = list(set(core_dir + (slice_dir[0], )))
        slice_idx = dirs.index(slice_dir[0])

        data_shape = self.data_obj.get_shape()
        shape = [data_shape[core] for core in set(core_dir)]
        self.__set_core_shape(tuple(shape))

        nFrames = self._get_frame_chunk()
        if nFrames > 1:
            shape.insert(slice_idx, nFrames)
        self.shape = tuple(shape)

    def get_shape(self):
        """ Get the shape of the plugin data to be processed each time. """
        return self.shape

    def __set_core_shape(self, shape):
        """ Set the core shape to hold only the shape of the core dimensions
        """
        self.core_shape = shape

    def get_core_shape(self):
        """ Get the shape of the core dimensions only.

        :returns: shape of core dimensions
        :rtype: tuple
        """
        return self.core_shape

    def __check_dimensions(self, indices, core_dir, slice_dir, nDims):
        """ Exit the program if the numbers of indices or dimensions are
        inconsistent.

        Lengths are compared by value; an identity test on integers is only
        reliable for small cached values.
        """
        if len(indices) != len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")
        if (len(core_dir) + len(slice_dir)) != nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def __set_slice_directions(self):
        """ Set the slice directions in the pluginData meta data dictionary.
        """
        slice_dirs = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['slice_dir']
        self.meta_data.set_meta_data('slice_dir', slice_dirs)

    def get_slice_directions(self):
        """ Get the slice directions (slice_dir) of the dataset. """
        return self.meta_data.get_meta_data('slice_dir')

    def get_slice_dimension(self):
        """ Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.get_core_directions()
        slice_dir = self.get_slice_directions()[0]
        return list(set(core_dirs + (slice_dir, ))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """ Return the dimension of the data in the plugin that has the
        specified axis label.

        :param str label: axis label name
        :keyword bool contains: passed on to the data object's
            ``find_axis_label_dimension``
        """
        label_dim = \
            self.data_obj.find_axis_label_dimension(label, contains=contains)
        plugin_dims = self.get_core_directions()
        if self._get_frame_chunk() > 1:
            plugin_dims += (self.get_slice_directions()[0], )
        return list(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):
        """ Reorder the slice directions.  The fastest changing slice
        direction will always be the first one stated in the pattern key
        ``slice_dir``.  The input param is a tuple stating the desired order
        of slicing directions relative to the current order.

        :param tuple(int) order: indices into the current slice directions
        :raises Exception: if more indices are given than slice directions
        """
        slice_dirs = self.get_slice_directions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specifed.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        # Write to the pattern dictionary owned by the data object (this
        # class defines no get_current_pattern), then refresh the cached
        # copy so that get_slice_directions reports the new order.
        current = self.data_obj.get_data_patterns()[self.get_pattern_name()]
        current['slice_dir'] = new_slice_dirs
        self.__set_slice_directions()

    def get_core_directions(self):
        """ Get the core data directions

        :returns: value associated with pattern key ``core_dir``
        :rtype: tuple
        """
        return self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['core_dir']

    def set_fixed_directions(self, dims, values):
        """ Fix a data direction to the index in values list.

        The data object's shape is set to 1 in every fixed dimension and the
        plugin shape is recalculated.

        :param list(int) dims: Directions to fix
        :param list(int) values: Index of fixed directions
        :raises Exception: if any of ``dims`` is not a slice direction
        """
        slice_dirs = self.get_slice_directions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set_meta_data("fixed_directions", dims)
        self.meta_data.set_meta_data("fixed_directions_values", values)
        self.__set_slice_directions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        self.__set_shape()

    def _get_fixed_directions(self):
        """ Get the fixed data directions and their indices

        :returns: Fixed directions and their associated values (two empty
            lists if nothing has been fixed)
        :rtype: list(list(int), list(int))
        """
        fixed = []
        values = []
        if 'fixed_directions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get_meta_data("fixed_directions")
            values = self.meta_data.get_meta_data("fixed_directions_values")
        return [fixed, values]

    def _get_data_slice_list(self, plist):
        """ Convert a plugin data slice list to a slice list for the whole
        dataset, i.e. add in any missing dimensions.

        :param plist: slice list for the plugin data
        :returns: slice list with ``slice(None)`` inserted for each missing
            dimension
        :rtype: tuple
        """
        nDims = len(self.get_shape())
        all_dims = self.get_core_directions() + self.get_slice_directions()
        dlist = list(plist)
        # dimensions beyond those present in the plugin data shape
        for i in all_dims[nDims:]:
            dlist.insert(i, slice(None))
        return tuple(dlist)

    def _set_frame_chunk(self, nFrames):
        """ Set the number of frames to process at a time """
        self.meta_data.set_meta_data("nFrames", nFrames)

    def _get_frame_chunk(self):
        """ Get the number of frames to be processes at a time.

        If the number of frames is not divisible by the previewing ``chunk``
        value then amend the number of frames to gcd(frames, chunk)

        :returns: Number of frames to process
        :rtype: int
        """
        # NOTE(review): plugin_data_setup sets ``_plugin.chunk`` to True and
        # ``True > 1`` is False, so this branch only runs if ``chunk`` is set
        # to a larger value elsewhere - confirm the intended condition.
        if self._plugin and self._plugin.chunk > 1:
            frame_chunk = self.meta_data.get_meta_data("nFrames")
            chunk = self.data_obj.get_preview().get_starts_stops_steps(
                key='chunks')[self.get_slice_directions()[0]]
            self._set_frame_chunk(gcd(frame_chunk, chunk))
        return self.meta_data.get_meta_data("nFrames")

    def plugin_data_setup(self, pattern_name, chunk, fixed=False, split=None):
        """ Setup the PluginData object.

        :param str pattern_name: A pattern name
        :param int chunk: Number of frames to process at a time
        :keyword bool fixed: setting fixed=True will ensure the plugin \
            receives the same sized data array each time (padding if necessary)
        :keyword split: stored unchanged on ``self.split``
        """
        self.__set_pattern(pattern_name)
        chunks = \
            self.data_obj.get_preview().get_starts_stops_steps(key='chunks')
        # flag the plugin when the preview chunk value in the first slice
        # direction is not a multiple of the requested number of frames
        if self._plugin and (chunks[self.get_slice_directions()[0]] % chunk):
            self._plugin.chunk = True
        self._set_frame_chunk(chunk)
        self.__set_shape()
        self.fixed_dims = fixed
        self.split = split
class PluginData(object):
    """ The PluginData class contains plugin specific information about a Data
    object for the duration of a plugin.  An instance of the class is
    encapsulated inside the Data object during the plugin run
    """

    def __init__(self, data_obj, plugin=None):
        """ Attach this object to ``data_obj`` and reset all plugin state.

        :param data_obj: the Data object this instance describes
        :keyword plugin: the plugin using the data (optional)
        """
        self.data_obj = data_obj
        self.data_obj._set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = None
        # this flag determines which data is passed. If false then just the
        # data, if true then all data including dark and flat fields.
        self.selected_data = False
        self.shape = None
        self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []
        self._plugin = plugin

    def get_total_frames(self):
        """ Get the total number of frames to process.

        This is the product of the data shape over all slice dimensions of
        the current pattern.

        :returns: Number of frames
        :rtype: int
        """
        slice_dir = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]["slice_dir"]
        shape = self.data_obj.get_shape()
        total = 1
        for tslice in slice_dir:
            total *= shape[tslice]
        return total

    def __set_pattern(self, name):
        """ Set the pattern related information in the meta data dict.

        :param str name: name of a pattern held by the data object
        """
        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set_meta_data("name", name)
        self.meta_data.set_meta_data("core_dir", pattern['core_dir'])
        self.__set_slice_directions()

    def get_pattern_name(self):
        """ Get the pattern name.

        :returns: the pattern name
        :rtype: str
        :raises Exception: if the stored name is None
        """
        name = self.meta_data.get_meta_data("name")
        if name is None:
            raise Exception("The pattern name has not been set.")
        return name

    def get_pattern(self):
        """ Get the current pattern.

        :returns: dict of the pattern name against the pattern.
        :rtype: dict
        """
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def __set_shape(self):
        """ Set the shape of the plugin data processing chunk.

        The core shape holds the data shape in each core dimension; when more
        than one frame is processed at a time the frame count is inserted at
        the position of the first slice dimension.
        """
        core_dir = self.get_core_directions()
        slice_dir = self.get_slice_directions()
        dirs = list(set(core_dir + (slice_dir[0],)))
        slice_idx = dirs.index(slice_dir[0])

        data_shape = self.data_obj.get_shape()
        shape = [data_shape[core] for core in set(core_dir)]
        self.__set_core_shape(tuple(shape))

        nFrames = self._get_frame_chunk()
        if nFrames > 1:
            shape.insert(slice_idx, nFrames)
        self.shape = tuple(shape)

    def get_shape(self):
        """ Get the shape of the plugin data to be processed each time. """
        return self.shape

    def __set_core_shape(self, shape):
        """ Set the core shape to hold only the shape of the core dimensions
        """
        self.core_shape = shape

    def get_core_shape(self):
        """ Get the shape of the core dimensions only.

        :returns: shape of core dimensions
        :rtype: tuple
        """
        return self.core_shape

    def __check_dimensions(self, indices, core_dir, slice_dir, nDims):
        """ Exit the program if the numbers of indices or dimensions are
        inconsistent.

        Lengths are compared by value; an identity test on integers is only
        reliable for small cached values.
        """
        if len(indices) != len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")
        if (len(core_dir) + len(slice_dir)) != nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def __set_slice_directions(self):
        """ Set the slice directions in the pluginData meta data dictionary.
        """
        slice_dirs = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['slice_dir']
        self.meta_data.set_meta_data('slice_dir', slice_dirs)

    def get_slice_directions(self):
        """ Get the slice directions (slice_dir) of the dataset. """
        return self.meta_data.get_meta_data('slice_dir')

    def get_slice_dimension(self):
        """ Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.get_core_directions()
        slice_dir = self.get_slice_directions()[0]
        return list(set(core_dirs + (slice_dir,))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """ Return the dimension of the data in the plugin that has the
        specified axis label.

        :param str label: axis label name
        :keyword bool contains: passed on to the data object's
            ``find_axis_label_dimension``
        """
        label_dim = \
            self.data_obj.find_axis_label_dimension(label, contains=contains)
        plugin_dims = self.get_core_directions()
        if self._get_frame_chunk() > 1:
            plugin_dims += (self.get_slice_directions()[0],)
        return list(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):
        """ Reorder the slice directions.  The fastest changing slice
        direction will always be the first one stated in the pattern key
        ``slice_dir``.  The input param is a tuple stating the desired order
        of slicing directions relative to the current order.

        :param tuple(int) order: indices into the current slice directions
        :raises Exception: if more indices are given than slice directions
        """
        slice_dirs = self.get_slice_directions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specifed.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        # Write to the pattern dictionary owned by the data object (this
        # class defines no get_current_pattern), then refresh the cached
        # copy so that get_slice_directions reports the new order.
        current = self.data_obj.get_data_patterns()[self.get_pattern_name()]
        current['slice_dir'] = new_slice_dirs
        self.__set_slice_directions()

    def get_core_directions(self):
        """ Get the core data directions

        :returns: value associated with pattern key ``core_dir``
        :rtype: tuple
        """
        return self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['core_dir']

    def set_fixed_directions(self, dims, values):
        """ Fix a data direction to the index in values list.

        The data object's shape is set to 1 in every fixed dimension and the
        plugin shape is recalculated.

        :param list(int) dims: Directions to fix
        :param list(int) values: Index of fixed directions
        :raises Exception: if any of ``dims`` is not a slice direction
        """
        slice_dirs = self.get_slice_directions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set_meta_data("fixed_directions", dims)
        self.meta_data.set_meta_data("fixed_directions_values", values)
        self.__set_slice_directions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        self.__set_shape()

    def _get_fixed_directions(self):
        """ Get the fixed data directions and their indices

        :returns: Fixed directions and their associated values (two empty
            lists if nothing has been fixed)
        :rtype: list(list(int), list(int))
        """
        fixed = []
        values = []
        if 'fixed_directions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get_meta_data("fixed_directions")
            values = self.meta_data.get_meta_data("fixed_directions_values")
        return [fixed, values]

    def _set_frame_chunk(self, nFrames):
        """ Set the number of frames to process at a time """
        self.meta_data.set_meta_data("nFrames", nFrames)

    def _get_frame_chunk(self):
        """ Get the number of frames to be processes at a time.

        If the number of frames is not divisible by the previewing ``chunk``
        value then amend the number of frames to gcd(frames, chunk)

        :returns: Number of frames to process
        :rtype: int
        """
        # NOTE(review): plugin_data_setup sets ``_plugin.chunk`` to True and
        # ``True > 1`` is False, so this branch only runs if ``chunk`` is set
        # to a larger value elsewhere - confirm the intended condition.
        if self._plugin and self._plugin.chunk > 1:
            frame_chunk = self.meta_data.get_meta_data("nFrames")
            chunk = self.data_obj.get_preview().get_starts_stops_steps(
                key='chunks')[self.get_slice_directions()[0]]
            self._set_frame_chunk(gcd(frame_chunk, chunk))
        return self.meta_data.get_meta_data("nFrames")

    def plugin_data_setup(self, pattern_name, chunk):
        """ Setup the PluginData object.

        :param str pattern_name: A pattern name
        :param int chunk: Number of frames to process at a time
        """
        self.__set_pattern(pattern_name)
        chunks = \
            self.data_obj.get_preview().get_starts_stops_steps(key='chunks')
        # flag the plugin when the preview chunk value in the first slice
        # direction is not a multiple of the requested number of frames
        if self._plugin and (chunks[self.get_slice_directions()[0]] % chunk):
            self._plugin.chunk = True
        self._set_frame_chunk(chunk)
        self.__set_shape()
class Data(Pattern):
    """ The Data class dynamically inherits from relevant data structure
    classes at runtime and holds the data array.
    """

    def __init__(self, name):
        """ Create the meta data store, initialise the Pattern base and reset
        the data attributes.

        :param str name: name kept on ``self.name``
        """
        self.meta_data = MetaData()
        super(Data, self).__init__()
        self.name = name
        self.backing_file = None
        self.data = None

    def get_transport_data(self, transport):
        """ Import and return the transport data class for ``transport``.

        :param str transport: transport name used to build the module path
        """
        module_path = "".join(
            ("savu.data.transport_data.", transport, "_transport_data"))
        return self.import_class(module_path)

    def import_class(self, class_name):
        """ Import the module ``class_name`` and return the attribute whose
        name is the CamelCase form of the last module path component.

        :param str class_name: dotted module path
        """
        parts = class_name.split('.')
        # __import__ returns the top-level package; walk down to the module
        module = __import__(class_name)
        for attr in parts[1:]:
            module = getattr(module, attr)
        camel_name = ''.join(
            word.capitalize() for word in parts[-1].split('_'))
        return getattr(module, camel_name)

    def __deepcopy__(self, memo):
        """ Deep copies of this object are the object itself. """
        return self

    def add_base(self, ExtraBase):
        """ Replace this object's class with a new class of the same name
        that derives from the current class and ``ExtraBase``.
        """
        current = self.__class__
        metaclass = current.__class__
        self.__class__ = metaclass(current.__name__, (current, ExtraBase), {})
        # NOTE(review): this builds a separate ExtraBase instance and calls
        # its __init__ again; it does not initialise ``self`` - confirm.
        ExtraBase().__init__()

    def add_base_classes(self, bases):
        """ Apply ``add_base`` for every class in ``bases``, in order. """
        for extra in bases:
            self.add_base(extra)

    def set_shape(self, shape):
        """ Store ``shape`` in the meta data. """
        self.meta_data.set_meta_data('shape', shape)

    def get_shape(self):
        """ Get the stored shape, with every fixed direction set to 1.

        The stored shape is returned unchanged when the meta data lookup of
        ``fixed_directions`` raises KeyError.
        """
        shape = self.meta_data.get_meta_data('shape')
        try:
            fixed = self.meta_data.get_meta_data("fixed_directions")
        except KeyError:
            return shape
        reduced = list(shape)
        for axis in fixed:
            reduced[axis] = 1
        return tuple(reduced)

    def set_dist(self, dist):
        """ Store ``dist`` in the meta data. """
        self.meta_data.set_meta_data('dist', dist)

    def get_dist(self):
        """ Get the ``dist`` meta data entry. """
        return self.meta_data.get_meta_data('dist')

    def set_data_params(self, pattern, chunk_size, **kwargs):
        """ Set the current pattern name and the number of frames.

        Extra keyword arguments are accepted but not used here.
        """
        self.set_current_pattern_name(pattern)
        self.set_nFrames(chunk_size)
class Experiment(object):
    """ One instance of this class is created at the beginning of the
    processing chain and remains until the end. It holds the current data
    object and a dictionary containing all metadata.
    """

    def __init__(self, options):
        """ Set up the meta data, the plugin list and the empty data index.

        :param dict options: passed to MetaData; ``options["process_file"]``
            is used when populating the plugin list
        """
        self.meta_data = MetaData(options)
        self.meta_data_setup(options["process_file"])
        self.index = {"in_data": {}, "out_data": {}, "mapping": {}}
        self.nxs_file = None

    def get_meta_data(self):
        """ Get the experiment meta data object. """
        return self.meta_data

    def meta_data_setup(self, process_file):
        """ Create the plugin list.

        If the meta data has a ``run_type`` entry it must be ``'test'``, in
        which case the plugin list is taken from the ``plugin_list`` entry.
        If the lookup raises KeyError the list is populated from
        ``process_file``.

        :raises Exception: if ``run_type`` is set to anything but ``'test'``
        """
        self.meta_data.plugin_list = PluginList()
        try:
            rtype = self.meta_data.get_meta_data('run_type')
            # compare by value: identity with a string literal is unreliable
            if rtype == 'test':
                self.meta_data.plugin_list.plugin_list = \
                    self.meta_data.get_meta_data('plugin_list')
            else:
                raise Exception('the run_type is unknown in Experiment class')
        except KeyError:
            self.meta_data.plugin_list.populate_plugin_list(process_file)

    def create_data_object(self, dtype, name, bases=None):
        """ Return the data object ``name`` from ``self.index[dtype]``,
        creating it first if it does not exist.

        :param str dtype: index key, ``"in_data"`` or ``"out_data"``
        :param str name: dataset name
        :keyword list bases: extra base classes for a newly created object;
            the transport data class is always added after them
        :returns: the (possibly new) data object
        """
        try:
            self.index[dtype][name]
        except KeyError:
            # Work on a copy: appending to the argument would mutate the
            # caller's list, and a shared default list would accumulate
            # transport classes from every previous call.
            bases = list(bases) if bases else []
            self.index[dtype][name] = Data(name, self)
            data_obj = self.index[dtype][name]
            bases.append(data_obj.get_transport_data())
            data_obj.add_base_classes(bases)
        return self.index[dtype][name]

    def set_nxs_filename(self):
        """ Build the output file name, store it in the meta data under
        ``nxs_filename`` and open the file for writing.

        The name is derived from the backing file of the first in_data entry,
        with a time stamp appended, inside the ``out_path`` directory. The
        file is opened with the ``mpio`` driver when the ``mpi`` meta data
        entry is True.
        """
        # list(...) rather than .keys()[0] so this also works where keys()
        # is a view; IndexError is still raised when there is no in_data.
        name = list(self.index["in_data"])[0]
        filename = os.path.basename(
            self.index["in_data"][name].backing_file.filename)
        filename = os.path.splitext(filename)[0]
        filename = os.path.join(
            self.meta_data.get_meta_data("out_path"),
            "%s_processed_%s.nxs" % (filename,
                                     time.strftime("%Y%m%d%H%M%S")))
        self.meta_data.set_meta_data("nxs_filename", filename)
        if self.meta_data.get_meta_data("mpi") is True:
            self.nxs_file = h5py.File(filename, 'w', driver='mpio',
                                      comm=MPI.COMM_WORLD)
        else:
            self.nxs_file = h5py.File(filename, 'w')

    def remove_dataset(self, data_obj):
        """ Close the file of ``data_obj`` and remove it from out_data. """
        data_obj.close_file()
        del self.index["out_data"][data_obj.data_info.get_meta_data('name')]

    def clear_data_objects(self):
        """ Empty both the in_data and out_data indexes. """
        self.index["out_data"] = {}
        self.index["in_data"] = {}

    def clear_out_data_objects(self):
        """ Empty the out_data index. """
        self.index["out_data"] = {}

    def merge_out_data_to_in(self):
        """ Move every out dataset whose ``remove`` flag is False into
        in_data, then empty out_data.

        When an in dataset of the same name exists, its meta data dictionary
        is given to the out dataset before it replaces the in dataset.
        """
        in_data = self.index['in_data']
        for key, data in self.index["out_data"].items():
            if data.remove is False:
                if key in in_data:
                    data.meta_data.set_dictionary(
                        in_data[key].meta_data.get_dictionary())
                in_data[key] = data
        self.index["out_data"] = {}

    def reorganise_datasets(self, out_data_objs, link_type):
        """ Close replaced in datasets, drop unwanted out datasets, then copy
        the remaining out datasets to in_data and clear out_data.

        :param out_data_objs: out data objects to check for removal
        :param link_type: passed to each dataset's ``save_data``
        """
        out_data_list = self.index["out_data"]
        self.close_unwanted_files(out_data_list)
        self.remove_unwanted_data(out_data_objs)

        self.barrier()
        self.copy_out_data_to_in_data(link_type)

        self.barrier()
        self.clear_out_data_objects()

    def remove_unwanted_data(self, out_data_objs):
        """ Remove every dataset in ``out_data_objs`` whose ``remove`` flag is
        True.
        """
        for out_objs in out_data_objs:
            if out_objs.remove is True:
                self.remove_dataset(out_objs)

    def close_unwanted_files(self, out_data_list):
        """ Close the file of each in dataset whose name is in
        ``out_data_list``.
        """
        in_data = self.index["in_data"]
        for out_objs in out_data_list:
            if out_objs in in_data:
                in_data[out_objs].close_file()

    def copy_out_data_to_in_data(self, link_type):
        """ Save each out dataset and store a deep copy of it in in_data.

        :param link_type: passed to each dataset's ``save_data``
        """
        for key in self.index["out_data"]:
            output = self.index["out_data"][key]
            output.save_data(link_type)
            self.index["in_data"][key] = copy.deepcopy(output)

    def set_all_datasets(self, name):
        """ Get the names of all in datasets.

        :param name: not used
        :returns: list of the in_data keys
        :rtype: list
        """
        return list(self.index["in_data"])

    def barrier(self, communicator=MPI.COMM_WORLD):
        """ Wait at a barrier on ``communicator`` if the ``mpi`` meta data
        entry is True.

        :keyword communicator: communicator whose ``Barrier`` is called
        """
        comm_dict = {'comm': communicator}
        if self.meta_data.get_meta_data('mpi') is True:
            logging.debug("About to hit a barrier")
            comm_dict['comm'].Barrier()
            logging.debug("Past the barrier")

    def log(self, log_tag, log_level=logging.DEBUG):
        """ Log the contents of the experiment at the specified level

        :param str log_tag: label included in the header line
        :keyword int log_level: logging level to use
        """
        logging.log(log_level, "Experimental Parameters for %s", log_tag)
        for key, value in self.index["in_data"].items():
            logging.log(log_level, "in data (%s) shape = %s", key,
                        value.get_shape())
        # report the out_data index here, not in_data a second time
        for key, value in self.index["out_data"].items():
            logging.log(log_level, "out data (%s) shape = %s", key,
                        value.get_shape())
class Data(DataCreate):
    """The Data class dynamically inherits from transport specific data class
    and holds the data array, along with associated information.
    """

    def __init__(self, name, exp):
        """ Create an empty data object.

        :params str name: The name stored in the ``data_info`` meta data.
        :params exp: The experiment object this data object is attached to.
        """
        super(Data, self).__init__(name)
        self.meta_data = MetaData()
        self.pattern_list = self.__get_available_pattern_list()
        self.data_info = MetaData()
        self.__initialise_data_info(name)
        self._preview = Preview(self)
        self.exp = exp
        self.group_name = None
        self.group = None
        self._plugin_data_obj = None
        self.tomo_raw_obj = None
        self.backing_file = None
        self.data = None
        self.next_shape = None
        self.orig_shape = None

    def __initialise_data_info(self, name):
        """ Initialise entries in the data_info meta data.
        """
        self.data_info.set_meta_data('name', name)
        self.data_info.set_meta_data('data_patterns', {})
        self.data_info.set_meta_data('shape', None)
        self.data_info.set_meta_data('nDims', None)

    def _set_plugin_data(self, plugin_data_obj):
        """ Encapsulate a PluginData object.
        """
        self._plugin_data_obj = plugin_data_obj

    def _clear_plugin_data(self):
        """ Set encapsulated PluginData object to None.
        """
        self._plugin_data_obj = None

    def _get_plugin_data(self):
        """ Get encapsulated PluginData object.

        :raises Exception: if no PluginData object has been set.
        """
        if self._plugin_data_obj is None:
            raise Exception("There is no PluginData object associated with "
                            "the Data object.")
        return self._plugin_data_obj

    def get_preview(self):
        """ Get the Preview instance associated with the data object
        """
        return self._preview

    def _get_transport_data(self):
        """ Import the data transport mechanism

        :returns: instance of data transport
        :rtype: transport_data
        """
        transport = self.exp.meta_data.get_meta_data("transport")
        transport_data = "savu.data.transport_data." + transport + \
                         "_transport_data"
        return cu.import_class(transport_data)

    def __deepcopy__(self, memo):
        """ Copy the data object.
        """
        name = self.data_info.get_meta_data('name')
        return dsu._deepcopy_data_object(self, Data(name, self.exp))

    def get_data_patterns(self):
        """ Get data patterns associated with this data object.

        :returns: A dictionary of associated patterns.
        :rtype: dict
        """
        return self.data_info.get_meta_data('data_patterns')

    def set_shape(self, shape):
        """ Set the dataset shape.

        :raises Exception: if ``nDims`` is set and disagrees with ``shape``.
        """
        self.data_info.set_meta_data('shape', shape)
        self.__check_dims()

    def set_original_shape(self, shape):
        """ Record ``shape`` as the original shape and set it as the current
        dataset shape.
        """
        self.orig_shape = shape
        self.set_shape(shape)

    def get_shape(self):
        """ Get the dataset shape

        :returns: data shape
        :rtype: tuple
        """
        return self.data_info.get_meta_data('shape')

    def __check_dims(self):
        """ Check the ``shape`` and ``nDims`` entries in the data_info
        meta_data dictionary are equal.

        :raises Exception: if ``nDims`` is set and differs from the number of
            entries in ``shape``.
        """
        nDims = self.data_info.get_meta_data("nDims")
        shape = self.data_info.get_meta_data('shape')
        # a falsy nDims means the axis labels have not been set yet
        if nDims and len(shape) != nDims:
            error_msg = ("The number of axis labels, %d, does not "
                         "coincide with the number of data "
                         "dimensions %d." % (nDims, len(shape)))
            raise Exception(error_msg)

    def get_name(self):
        """ Get data name.

        :returns: the name associated with the dataset
        :rtype: str
        """
        return self.data_info.get_meta_data('name')

    def __get_available_pattern_list(self):
        """ Get a list of ALL pattern names that are currently allowed in the
        framework.
        """
        return dsu.get_available_pattern_types()

    def add_pattern(self, dtype, **kwargs):
        """ Add a pattern.

        :params str dtype: The *type* of pattern to add, which can be anything
            from the :const:`savu.data.data_structures.utils.pattern_list`
            :const:`pattern_list`
            :data:`savu.data.data_structures.utils.pattern_list`
            :data:`pattern_list`:
        :keyword tuple core_dir: Dimension indices of core dimensions
        :keyword tuple slice_dir: Dimension indices of slice dimensions
        :raises Exception: if ``dtype`` is not an available pattern, or the
            pattern covers a different number of dimensions to ``nDims``.
        """
        if dtype not in self.pattern_list:
            # format the message here: passing the values as extra Exception
            # arguments leaves the '%s' placeholders unfilled.
            raise Exception(
                "The data pattern '%s'does not exist. Please "
                "choose from the following list: \n'%s'"
                % (dtype, str(self.pattern_list)))

        nDims = 0
        for key, directions in kwargs.items():
            nDims += len(directions)
            self.data_info.set_meta_data(['data_patterns', dtype, key],
                                         directions)

        self.__convert_pattern_directions(dtype)

        shape = self.get_shape()
        if shape:
            diff = len(shape) - nDims
            if diff:
                # dimensions not named in the pattern are appended by the
                # helper, so count them towards the pattern's dimensions
                pattern = {dtype: self.get_data_patterns()[dtype]}
                self._add_extra_dims_to_patterns(pattern)
                nDims += diff

        try:
            required = self.data_info.get_meta_data("nDims")
        except KeyError:
            required = None

        if required is None:
            # nDims is initialised to None, so None (like a missing entry)
            # means "not yet known": record it rather than fail on '%d'.
            self.data_info.set_meta_data('nDims', nDims)
        elif nDims != required:
            err_msg = ("The pattern %s has an incorrect number of "
                       "dimensions: %d required but %d specified."
                       % (dtype, required, nDims))
            raise Exception(err_msg)

    def add_volume_patterns(self, x, y, z):
        """ Adds 3D volume patterns

        :params int x: dimension to be associated with x-axis
        :params int y: dimension to be associated with y-axis
        :params int z: dimension to be associated with z-axis
        """
        self.add_pattern("VOLUME_YZ", **self.__get_dirs_for_volume(y, z, x))
        self.add_pattern("VOLUME_XZ", **self.__get_dirs_for_volume(x, z, y))
        self.add_pattern("VOLUME_XY", **self.__get_dirs_for_volume(x, y, z))

    def __get_dirs_for_volume(self, dim1, dim2, sdir):
        """ Calculate core_dir and slice_dir for a 3D volume pattern.

        ``sdir`` is always the first slice direction; any other dimensions of
        the data are appended to the slice directions in ascending order.
        """
        all_dims = range(len(self.get_shape()))
        named = [dim1, dim2, sdir]
        # *** need to add this for other patterns
        slice_dir = [sdir] + [ddir for ddir in all_dims if ddir not in named]
        return {'core_dir': (dim1, dim2), 'slice_dir': tuple(slice_dir)}

    def set_axis_labels(self, *args):
        """ Set the axis labels associated with each data dimension.

        :arg str: Each arg should be of the form ``name.unit``. If ``name`` is\
        a data_obj.meta_data entry, it will be output to the final .nxs file.
        """
        self.data_info.set_meta_data('nDims', len(args))
        axis_labels = []
        for arg in args:
            try:
                axis = arg.split('.')
                axis_labels.append({axis[0]: axis[1]})
            except (AttributeError, IndexError, TypeError):
                # data arrives here, but that may be an error: the arg is not
                # a 'name.unit' string.  Skipped (best effort), but only for
                # the failures that malformed labels can produce.
                continue
        self.data_info.set_meta_data('axis_labels', axis_labels)

    def get_axis_labels(self):
        """ Get axis labels.

        :returns: Axis labels
        :rtype: list(dict)
        """
        return self.data_info.get_meta_data('axis_labels')

    def find_axis_label_dimension(self, name, contains=False):
        """ Get the dimension of the data associated with a particular
        axis_label.

        :param str name: The name of the axis_label
        :keyword bool contains: Set this flag to true if the name is only part
            of the axis_label name
        :returns: The associated axis number
        :rtype: int
        :raises Exception: if no axis label matches.
        """
        axis_labels = self.data_info.get_meta_data('axis_labels')
        for dim, label in enumerate(axis_labels):
            if contains is True:
                if any(name in key for key in label.keys()):
                    return dim
            elif name in label.keys():
                return dim
        raise Exception("Cannot find the specifed axis label.")

    def _finalise_patterns(self):
        """ Adds a main axis (fastest changing) to SINOGRAM and PROJECTON
        patterns.

        Nothing is done unless both patterns exist and the data has more than
        two dimensions.
        """
        check = (self.__check_pattern('SINOGRAM') +
                 self.__check_pattern('PROJECTION'))
        # compare by value: 'is' on an int only works by interpreter accident
        if check == 2 and len(self.get_shape()) > 2:
            self.__set_main_axis('SINOGRAM')
            self.__set_main_axis('PROJECTION')

    def __check_pattern(self, pattern_name):
        """ Check if a pattern exists.

        :returns: 1 if the pattern exists, 0 otherwise.
        """
        patterns = self.get_data_patterns()
        try:
            patterns[pattern_name]
        except KeyError:
            return 0
        return 1

    def __convert_pattern_directions(self, dtype):
        """ Replace negative indices in pattern kwargs.

        Any existing ``main_dir`` entry is removed first.
        """
        pattern = self.get_data_patterns()[dtype]
        if 'main_dir' in pattern.keys():
            del pattern['main_dir']

        nDims = sum([len(i) for i in pattern.values()])
        for p in pattern:
            pattern[p] = self.non_negative_directions(pattern[p], nDims)

    def non_negative_directions(self, ddirs, nDims):
        """ Replace negative indexing values with positive counterparts.

        :params tuple(int) ddirs: data dimension indices
        :params int nDims: The number of data dimensions
        :returns: non-negative data dimension indices
        :rtype: tuple(int)
        """
        return tuple(nDims + ddir if ddir < 0 else ddir for ddir in ddirs)

    def __set_main_axis(self, pname):
        """ Set the ``main_dir`` pattern kwarg to the fastest changing
        dimension
        """
        patterns = self.get_data_patterns()
        # compare by value: 'is' on a string literal is identity, not equality
        n1 = 'PROJECTION' if pname == 'SINOGRAM' else 'SINOGRAM'
        d1 = patterns[n1]['core_dir']
        d2 = patterns[pname]['slice_dir']
        tdir = set(d1).intersection(set(d2))

        # this is required when a single sinogram exists in the mm case, and a
        # dimension is added via parameter tuning.
        if not tdir:
            tdir = [d2[0]]

        self.data_info.set_meta_data(['data_patterns', pname, 'main_dir'],
                                     list(tdir)[0])

    def get_axis_label_keys(self):
        """ Get axis_label names

        :returns: A list containing associated axis names for each dimension
        :rtype: list(str)
        """
        axis_labels = self.data_info.get_meta_data('axis_labels')
        return [key for labels in axis_labels for key in labels.keys()]

    def _get_current_and_next_patterns(self, datasets_lists):
        """ Get the current and next patterns associated with a dataset
        throughout the processing chain.

        The result is stored in the experiment meta data under
        ``current_and_next``.
        """
        current_datasets = datasets_lists[0]
        patterns_list = []
        for current_data in current_datasets['out_datasets']:
            next_pattern = self.__find_next_pattern(datasets_lists[1:],
                                                    current_data['name'])
            patterns_list.append({'current': current_data['pattern'],
                                  'next': next_pattern})
        self.exp.meta_data.set_meta_data('current_and_next', patterns_list)

    def __find_next_pattern(self, datasets_lists, current_name):
        """ Return the pattern of the first in_dataset named ``current_name``
        in ``datasets_lists``, or an empty list if there is none.
        """
        for next_data_list in datasets_lists:
            for next_data in next_data_list['in_datasets']:
                if next_data['name'] == current_name:
                    return next_data['pattern']
        return []

    def get_slice_directions(self):
        """ Get pattern slice_dir of pattern currently associated with the
        dataset (if any).

        :returns: the slicing dimensions.
        :rtype: tuple(int)
        """
        return self._get_plugin_data().get_slice_directions()
class Experiment(object):
    """
    One instance of this class is created at the beginning of the
    processing chain and remains until the end.  It holds the current data
    object and a dictionary containing all metadata.
    """

    def __init__(self, options):
        """ Set up the meta data, the plugin list and the data index.

        :params dict options: passed to ``MetaData``; must contain a
            ``process_file`` entry.
        """
        self.meta_data = MetaData(options)
        self.__meta_data_setup(options["process_file"])
        self.index = {"in_data": {}, "out_data": {}, "mapping": {}}
        self.nxs_file = None

    def get_meta_data(self, entry):
        """ Get the meta data dictionary. """
        return self.meta_data.get_meta_data(entry)

    def __meta_data_setup(self, process_file):
        """ Attach a PluginList to the meta data.

        If a ``run_type`` of ``'test'`` is present the plugin list is taken
        from the ``plugin_list`` meta data entry; if a lookup raises KeyError
        the list is populated from ``process_file`` instead.

        :raises Exception: if ``run_type`` is present but is not ``'test'``.
        """
        self.meta_data.plugin_list = PluginList()
        try:
            rtype = self.meta_data.get_meta_data('run_type')
            # compare by value: 'is' only matched when the stored string
            # happened to be the very same object as the literal.
            if rtype == 'test':
                self.meta_data.plugin_list.plugin_list = \
                    self.meta_data.get_meta_data('plugin_list')
            else:
                raise Exception('the run_type is unknown in Experiment class')
        except KeyError:
            self.meta_data.plugin_list._populate_plugin_list(process_file)

    def create_data_object(self, dtype, name):
        """ Create a data object.

        Plugin developers should apply this method in loaders only.

        :params str dtype: either "in_data" or "out_data".
        :params str name: the name of the data object.
        :returns: the (new or already existing) data object.
        """
        # NOTE(review): the transport base classes are added only when the
        # object is newly created - confirm against the original layout.
        if name not in self.index[dtype]:
            data_obj = Data(name, self)
            self.index[dtype][name] = data_obj
            bases = [data_obj._get_transport_data()]
            cu.add_base_classes(data_obj, bases)
        return self.index[dtype][name]

    def _set_nxs_filename(self):
        """ Build the output .nxs file name from ``out_path``, store it in
        the meta data and open the file for writing (with the ``mpio`` driver
        when the ``mpi`` meta data entry is True).
        """
        folder = self.meta_data.get_meta_data('out_path')
        fname = os.path.basename(folder.split('_')[-1]) + '_processed.nxs'
        filename = os.path.join(folder, fname)
        self.meta_data.set_meta_data("nxs_filename", filename)

        if self.meta_data.get_meta_data("mpi") is True:
            self.nxs_file = h5py.File(filename, 'w', driver='mpio',
                                      comm=MPI.COMM_WORLD)
        else:
            self.nxs_file = h5py.File(filename, 'w')

    def __remove_dataset(self, data_obj):
        """ Close the file of ``data_obj`` and drop it from ``out_data``. """
        data_obj._close_file()
        del self.index["out_data"][data_obj.data_info.get_meta_data('name')]

    def _clear_data_objects(self):
        """ Empty both the ``in_data`` and ``out_data`` indexes. """
        self.index["out_data"] = {}
        self.index["in_data"] = {}

    def _merge_out_data_to_in(self):
        """ Move every out dataset not flagged ``remove`` into ``in_data``.

        When an in dataset of the same name exists, its meta data dictionary
        is handed to the out dataset before it is replaced.
        """
        in_data = self.index['in_data']
        for key, data in self.index["out_data"].items():
            if data.remove is False:
                if key in in_data:
                    data.meta_data._set_dictionary(
                        in_data[key].meta_data.get_dictionary())
                in_data[key] = data
        self.index["out_data"] = {}

    def _reorganise_datasets(self, out_data_objs, link_type):
        """ Close replaced in datasets, discard unwanted out datasets, then
        save and copy the remaining out datasets to ``in_data``.
        """
        out_data_list = self.index["out_data"]
        self.__close_unwanted_files(out_data_list)
        self.__remove_unwanted_data(out_data_objs)

        self._barrier()
        self.__copy_out_data_to_in_data(link_type)

        self._barrier()
        self.index['out_data'] = {}

    def __remove_unwanted_data(self, out_data_objs):
        """ Remove every out dataset whose ``remove`` flag is True. """
        for out_objs in out_data_objs:
            if out_objs.remove is True:
                self.__remove_dataset(out_objs)

    def __close_unwanted_files(self, out_data_list):
        """ Close the file of each in dataset that shares a name with an out
        dataset.
        """
        in_data = self.index["in_data"]
        for out_objs in out_data_list:
            if out_objs in in_data:
                in_data[out_objs]._close_file()

    def __copy_out_data_to_in_data(self, link_type):
        """ Save each out dataset and store a deep copy of it in ``in_data``.
        """
        for key in self.index["out_data"]:
            output = self.index["out_data"][key]
            output._save_data(link_type)
            self.index["in_data"][key] = copy.deepcopy(output)

    def _set_all_datasets(self, name):
        """ Return the names of all in datasets.

        ``name`` is not used; it is kept so existing callers keep working.
        """
        return list(self.index["in_data"].keys())

    def _barrier(self, communicator=MPI.COMM_WORLD):
        """ Wait at an MPI barrier when the ``mpi`` meta data entry is True.
        """
        comm_dict = {'comm': communicator}
        if self.meta_data.get_meta_data('mpi') is True:
            logging.debug("About to hit a _barrier %s", comm_dict)
            comm_dict['comm'].barrier()
            logging.debug("Past the _barrier")

    def log(self, log_tag, log_level=logging.DEBUG):
        """
        Log the contents of the experiment at the specified level
        """
        logging.log(log_level, "Experimental Parameters for %s", log_tag)
        for key, value in self.index["in_data"].items():
            logging.log(log_level, "in data (%s) shape = %s", key,
                        value.get_shape())
        # report the out datasets here: the loop previously walked in_data a
        # second time and labelled the entries "out data".
        for key, value in self.index["out_data"].items():
            logging.log(log_level, "out data (%s) shape = %s", key,
                        value.get_shape())
class PluginData(object):
    """ Holds the information a plugin needs to access a Data object: the
    chosen pattern, the slice/core directions, the frame chunk size and the
    resulting shape of the data handed to the plugin.
    """

    def __init__(self, data_obj):
        """ Attach this object to ``data_obj`` (via ``set_plugin_data``).
        """
        self.data_obj = data_obj
        self.data_obj.set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = None
        # this flag determines which data is passed. If false then just the
        # data, if true then all data including dark and flat fields.
        self.selected_data = False
        self.shape = None
        self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []

    def __current_pattern(self):
        """ Return the pattern dictionary of the currently set pattern. """
        return self.data_obj.get_data_patterns()[self.get_pattern_name()]

    def get_total_frames(self):
        """ Return the product of the data lengths along every slice
        direction of the current pattern.
        """
        data_shape = self.data_obj.get_shape()
        total = 1
        for tslice in self.__current_pattern()["slice_dir"]:
            total *= data_shape[tslice]
        return total

    def set_pattern(self, name):
        """ Select the pattern ``name`` and record its core and slice
        directions in the meta data.
        """
        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set_meta_data("name", name)
        self.meta_data.set_meta_data("core_dir", pattern['core_dir'])
        self.set_slice_directions()

    def get_pattern_name(self):
        """ Return the name of the current pattern.

        :raises Exception: if the stored name is None.
        """
        name = self.meta_data.get_meta_data("name")
        if name is None:
            raise Exception("The pattern name has not been set.")
        return name

    def get_pattern(self):
        """ Return ``{pattern_name: pattern_dict}`` for the current pattern.
        """
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def set_shape(self):
        """ Calculate the core shape and the shape of the data passed to the
        plugin (the core dimensions, plus the frame chunk in the position of
        the first slice direction when more than one frame is requested).
        """
        core_dir = self.get_core_directions()
        first_slice = self.get_slice_directions()[0]
        data_shape = self.data_obj.get_shape()

        # sorted() fixes the ascending dimension order explicitly instead of
        # relying on the iteration order of a set of ints.
        dirs = sorted(set(core_dir + (first_slice,)))
        slice_idx = dirs.index(first_slice)

        shape = [data_shape[core] for core in sorted(set(core_dir))]
        self.set_core_shape(tuple(shape))

        if self.get_frame_chunk() > 1:
            shape.insert(slice_idx, self.get_frame_chunk())
        self.shape = tuple(shape)

    def get_shape(self):
        """ Return the shape set by :meth:`set_shape` (None until then). """
        return self.shape

    def set_core_shape(self, shape):
        """ Store the shape of the core dimensions. """
        self.core_shape = shape

    def get_core_shape(self):
        """ Return the shape of the core dimensions. """
        return self.core_shape

    def check_dimensions(self, indices, core_dir, slice_dir, nDims):
        """ Exit the program (``sys.exit``) unless there is one index per
        slice direction and the core and slice directions together account
        for ``nDims`` dimensions.
        """
        # compare by value: 'is not' on ints tests object identity
        if len(indices) != len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")

        if (len(core_dir)+len(slice_dir)) != nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def set_slice_directions(self):
        """ Copy the slice directions of the current pattern into the meta
        data.
        """
        slice_dirs = self.__current_pattern()['slice_dir']
        self.meta_data.set_meta_data('slice_dir', slice_dirs)

    def get_slice_directions(self):
        """ Return the slice directions stored in the meta data. """
        return self.meta_data.get_meta_data('slice_dir')

    def get_slice_dimension(self):
        """
        Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.get_core_directions()
        slice_dir = self.get_slice_directions()[0]
        return sorted(set(core_dirs + (slice_dir,))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """
        Return the dimension of the data in the plugin that has the specified
        axis label.
        """
        label_dim = \
            self.data_obj.find_axis_label_dimension(label, contains=contains)
        plugin_dims = self.get_core_directions()
        if self.get_frame_chunk() > 1:
            plugin_dims += (self.get_slice_directions()[0],)
        return sorted(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):
        """
        Reorder the slice directions.  The fastest changing slice direction
        will always be the first one. The input param is a tuple stating the
        desired order of slicing directions relative to the current order.

        :raises Exception: if ``order`` has more entries than there are slice
            directions.
        """
        slice_dirs = self.get_slice_directions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specifed.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        # update the pattern held by the data object (there is no
        # get_current_pattern method on this class), then refresh the copy in
        # the meta data so get_slice_directions reflects the new order.
        self.__current_pattern()['slice_dir'] = new_slice_dirs
        self.set_slice_directions()

    def get_core_directions(self):
        """ Return the core directions of the current pattern. """
        return self.__current_pattern()['core_dir']

    def set_fixed_directions(self, dims, values):
        """ Fix the slice directions ``dims`` at ``values``: the data shape
        is set to 1 along each of them and the plugin shape recalculated.

        :raises Exception: if any of ``dims`` is not a slice direction.
        """
        slice_dirs = self.get_slice_directions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set_meta_data("fixed_directions", dims)
        self.meta_data.set_meta_data("fixed_directions_values", values)
        self.set_slice_directions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        self.set_shape()

    def get_fixed_directions(self):
        """ Return ``[dims, values]`` of the fixed directions; both are empty
        lists when no directions have been fixed.
        """
        fixed = []
        values = []
        if 'fixed_directions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get_meta_data("fixed_directions")
            values = self.meta_data.get_meta_data("fixed_directions_values")
        return [fixed, values]

    def set_frame_chunk(self, nFrames):
        """ Set the number of frames to process at a time. """
        self.meta_data.set_meta_data("nFrames", nFrames)

    def get_frame_chunk(self):
        """ Return the number of frames to process at a time. """
        return self.meta_data.get_meta_data("nFrames")

    def get_index(self, indices):
        """ Build a tuple of slices selecting one entry along each slice
        direction (given by ``indices``, in slice-direction order) and the
        full range along every other dimension.
        """
        # NOTE(review): the slice directions index the dimensions of the
        # Data object, so the full data shape is used here - confirm.  The
        # pattern lookups go through the methods that exist on this class.
        nDims = len(self.data_obj.get_shape())
        pattern = self.__current_pattern()
        core_dir = pattern["core_dir"]
        slice_dir = pattern["slice_dir"]
        self.check_dimensions(indices, core_dir, slice_dir, nDims)

        index = [slice(None)]*nDims
        for count, tdir in enumerate(slice_dir):
            index[tdir] = slice(indices[count], indices[count]+1, 1)

        return tuple(index)

    def plugin_data_setup(self, pattern_name, chunk):
        """ Set the pattern, the frame chunk and then the plugin shape. """
        self.set_pattern(pattern_name)
        self.set_frame_chunk(chunk)
        self.set_shape()

    def set_temp_pad_dict(self, pad_dict):
        """ Store a temporary padding dictionary in the meta data. """
        self.meta_data.set_meta_data('temp_pad_dict', pad_dict)

    def get_temp_pad_dict(self):
        """ Return the temporary padding dictionary, or None if unset. """
        return self.meta_data.get_dictionary().get('temp_pad_dict')

    def delete_temp_pad_dict(self):
        """ Remove the temporary padding dictionary from the meta data. """
        del self.meta_data.get_dictionary()['temp_pad_dict']
class Data(object):
    """
    The Data class dynamically inherits from relevant data structure classes
    at runtime and holds the data array.
    """

    def __init__(self, name, exp):
        """ Create an empty data object.

        :params str name: The name stored in the ``data_info`` meta data.
        :params exp: The experiment object this data object is attached to.
        """
        self.meta_data = MetaData()
        self.pattern_list = self.set_available_pattern_list()
        self.data_info = MetaData()
        self.initialise_data_info(name)
        self.exp = exp
        self.group_name = None
        self.group = None
        self._plugin_data_obj = None
        self.tomo_raw_obj = None
        self.data_mapping = None
        self.variable_length_flag = False
        self.dtype = None
        self.remove = False
        self.backing_file = None
        self.data = None
        self.next_shape = None
        self.mapping = None
        self.map_dim = []
        self.revert_shape = None

    def initialise_data_info(self, name):
        """ Initialise entries in the data_info meta data. """
        self.data_info.set_meta_data('name', name)
        self.data_info.set_meta_data('data_patterns', {})
        self.data_info.set_meta_data('shape', None)
        self.data_info.set_meta_data('nDims', None)

    def set_plugin_data(self, plugin_data_obj):
        """ Encapsulate a PluginData object. """
        self._plugin_data_obj = plugin_data_obj

    def clear_plugin_data(self):
        """ Set encapsulated PluginData object to None. """
        self._plugin_data_obj = None

    def get_plugin_data(self):
        """ Get encapsulated PluginData object.

        :raises Exception: if no PluginData object has been set.
        """
        if self._plugin_data_obj is None:
            raise Exception("There is no PluginData object associated with "
                            "the Data object.")
        return self._plugin_data_obj

    def set_tomo_raw(self, tomo_raw_obj):
        """ Encapsulate a TomoRaw object. """
        self.tomo_raw_obj = tomo_raw_obj

    def clear_tomo_raw(self):
        """ Set encapsulated TomoRaw object to None. """
        self.tomo_raw_obj = None

    def get_tomo_raw(self):
        """ Get encapsulated TomoRaw object.

        :raises Exception: if no TomoRaw object has been set.
        """
        if self.tomo_raw_obj is None:
            raise Exception("There is no TomoRaw object associated with "
                            "the Data object.")
        return self.tomo_raw_obj

    def get_transport_data(self):
        """ Import the transport data class named by the experiment's
        ``transport`` meta data entry.
        """
        transport = self.exp.meta_data.get_meta_data("transport")
        # SETTING UP THE TRANSPORT DATA
        transport_data = "savu.data.transport_data." + transport + \
                         "_transport_data"
        return import_class(transport_data)

    def __deepcopy__(self, memo):
        """ Copy the data object.

        The meta data, plugin data, tomo raw object, mapping data, group,
        backing file and data array are shared with the original; the other
        attributes are deep copied.
        """
        name = self.data_info.get_meta_data('name')
        new_obj = Data(name, self.exp)
        new_obj.add_base_classes(self.get_transport_data())
        new_obj.meta_data = self.meta_data
        new_obj.pattern_list = copy.deepcopy(self.pattern_list)
        new_obj.data_info = copy.deepcopy(self.data_info)
        new_obj.exp = self.exp
        new_obj._plugin_data_obj = self._plugin_data_obj
        new_obj.tomo_raw_obj = self.tomo_raw_obj
        new_obj.data_mapping = self.data_mapping
        new_obj.variable_length_flag = copy.deepcopy(self.variable_length_flag)
        new_obj.dtype = copy.deepcopy(self.dtype)
        new_obj.remove = copy.deepcopy(self.remove)
        new_obj.group_name = self.group_name
        new_obj.group = self.group
        new_obj.backing_file = self.backing_file
        new_obj.data = self.data
        new_obj.next_shape = copy.deepcopy(self.next_shape)
        new_obj.mapping = copy.deepcopy(self.mapping)
        new_obj.map_dim = copy.deepcopy(self.map_dim)
        # copy revert_shape itself (map_dim was being copied here by mistake)
        new_obj.revert_shape = copy.deepcopy(self.revert_shape)
        return new_obj

    def add_base(self, ExtraBase):
        """ Replace this object's class with a new class of the same name
        deriving from the current class and ``ExtraBase``.
        """
        cls = self.__class__
        self.__class__ = cls.__class__(cls.__name__, (cls, ExtraBase), {})
        # NOTE(review): this initialises a throwaway ExtraBase instance, not
        # self - confirm whether that is intended.
        ExtraBase().__init__()

    def add_base_classes(self, bases):
        """ Add one base class, or each of a list of base classes. """
        bases = bases if isinstance(bases, list) else [bases]
        for base in bases:
            self.add_base(base)

    def external_link(self):
        """ Return an h5py ExternalLink to this dataset's group in its
        backing file.
        """
        return h5py.ExternalLink(self.backing_file.filename, self.group_name)

    def create_dataset(self, *args, **kwargs):
        """
        Set up required information when an output dataset has been created by
        a plugin

        Either pass a single Data object to copy, or the ``shape`` and
        ``axis_labels`` keywords (and optionally ``patterns``).

        :keyword dtype: data type, ``np.float32`` by default.
        :keyword bool remove: remove the dataset from the plugin chain.
        :raises Exception: if no Data object is passed and ``shape`` or
            ``axis_labels`` is missing.
        """
        self.dtype = kwargs.get('dtype', np.float32)
        # remove from the plugin chain
        self.remove = kwargs.get('remove', False)
        # compare by value: 'is' on an int tests object identity
        if len(args) == 1:
            self.copy_dataset(args[0], removeDim=kwargs.get('removeDim', []))
            if args[0].tomo_raw_obj:
                self.set_tomo_raw(copy.deepcopy(args[0].get_tomo_raw()))
                self.get_tomo_raw().data_obj = self
        else:
            try:
                shape = kwargs['shape']
                self.create_axis_labels(kwargs['axis_labels'])
            except KeyError:
                raise Exception("Please state axis_labels and shape when "
                                "creating a new dataset")
            self.set_new_dataset_shape(shape)

            if 'patterns' in kwargs:
                self.copy_patterns(kwargs['patterns'])
        self.set_preview([])

    def set_new_dataset_shape(self, shape):
        """ Set the shape from a Data object, a single-entry dict (variable
        length data) or a tuple.
        """
        if isinstance(shape, Data):
            self.find_and_set_shape(shape)
        elif isinstance(shape, dict):
            self.set_variable_flag()
            self.set_shape((shape[next(iter(shape))] + ('var',)))
        else:
            pData = self.get_plugin_data()
            self.set_shape(shape + tuple(pData.extra_dims))
            if 'var' in shape:
                self.set_variable_flag()

    def copy_patterns(self, copy_data):
        """ Copy patterns from a Data object, or from a single-entry dict
        ``{data_obj: [pattern, ...]}`` where each entry is a pattern name or
        ``name.dims`` (dims being comma separated dimensions to remove).
        """
        # The patterns are deep copied because they are modified below (and
        # in set_data_patterns); the source dataset must keep its own.
        if isinstance(copy_data, Data):
            patterns = copy.deepcopy(copy_data.get_data_patterns())
        else:
            data = next(iter(copy_data))
            pattern_list = copy_data[data]

            all_patterns = copy.deepcopy(data.get_data_patterns())
            if len(pattern_list[0].split('.')) > 1:
                patterns = self.copy_patterns_removing_dimensions(
                    pattern_list, all_patterns, len(data.get_shape()))
            else:
                patterns = {}
                for pattern in pattern_list:
                    patterns[pattern] = all_patterns[pattern]
        self.set_data_patterns(patterns)

    def copy_patterns_removing_dimensions(self, pattern_list, all_patterns,
                                          nDims):
        """ Return the requested patterns with the listed dimensions removed.

        Each entry of ``pattern_list`` is ``name.dims``; ``*`` as the name
        selects every pattern.  The dimensions of the last entry are the ones
        removed.  A pattern left with an empty direction is dropped.  The
        dictionaries inside ``all_patterns`` are modified in place.
        """
        copy_patterns = {}
        for new_pattern in pattern_list:
            name, all_dims = new_pattern.split('.')
            # compare by value: 'is' on a string literal tests identity
            if name == '*':
                copy_patterns = all_patterns
            else:
                copy_patterns[name] = all_patterns[name]
            dims = tuple(map(int, all_dims.split(',')))
            dims = self.non_negative_directions(dims, nDims=nDims)

        patterns = {}
        for name, pattern_dict in copy_patterns.items():
            empty_flag = False
            for ddir in pattern_dict:
                s_dims = self.non_negative_directions(
                    pattern_dict[ddir], nDims=nDims)
                new_dims = tuple([sd for sd in s_dims if sd not in dims])
                pattern_dict[ddir] = new_dims
                if not new_dims:
                    empty_flag = True
            if empty_flag is False:
                patterns[name] = pattern_dict
        return patterns

    def copy_dataset(self, copy_data, **kwargs):
        """ Copy the axis labels, shape and patterns of ``copy_data`` (or of
        its mapping dataset, when it has one).
        """
        if copy_data.mapping:
            # copy label entries from meta data
            map_data = self.exp.index['mapping'][copy_data.get_name()]
            map_mData = map_data.meta_data
            map_axis_labels = map_data.data_info.get_meta_data('axis_labels')
            for axis_label in map_axis_labels:
                label_name = next(iter(axis_label))
                if label_name in map_mData.get_dictionary().keys():
                    map_label = map_mData.get_meta_data(label_name)
                    copy_data.meta_data.set_meta_data(label_name, map_label)
            copy_data = map_data
        patterns = copy.deepcopy(copy_data.get_data_patterns())
        self.copy_labels(copy_data)
        self.find_and_set_shape(copy_data)
        self.set_data_patterns(patterns)

    def create_axis_labels(self, axis_labels):
        """ Set the axis labels from a Data object, from a single-entry dict
        ``{data_obj: amendments}`` or from a sequence of ``name.unit``
        strings.
        """
        if isinstance(axis_labels, Data):
            self.copy_labels(axis_labels)
        elif isinstance(axis_labels, dict):
            data = next(iter(axis_labels))
            self.copy_labels(data)
            self.amend_axis_labels(axis_labels[data])
        else:
            self.set_axis_labels(*axis_labels)
            # if parameter tuning
            if self.get_plugin_data().multi_params_dict:
                self.add_extra_dims_labels()

    def copy_labels(self, copy_data):
        """ Copy ``nDims`` and the axis labels from ``copy_data``. """
        nDims = copy.copy(copy_data.data_info.get_meta_data('nDims'))
        axis_labels = \
            copy.copy(copy_data.data_info.get_meta_data('axis_labels'))
        self.data_info.set_meta_data('nDims', nDims)
        self.data_info.set_meta_data('axis_labels', axis_labels)
        # if parameter tuning
        if self.get_plugin_data().multi_params_dict:
            self.add_extra_dims_labels()

    def add_extra_dims_labels(self):
        """ Append an axis label for each entry of the plugin data's
        ``multi_params_dict`` and store the parameter values in the meta data.
        """
        params_dict = self.get_plugin_data().multi_params_dict
        # add multi_params axis labels from dictionary in pData
        nDims = self.data_info.get_meta_data('nDims')
        axis_labels = self.data_info.get_meta_data('axis_labels')
        axis_labels.extend([0]*len(params_dict))
        for key, value in params_dict.items():
            title = value['label'].encode('ascii', 'ignore')
            name, unit = title.split('.')
            axis_labels[nDims + key] = {name: unit}
            # add parameter values to the meta_data
            self.meta_data.set_meta_data(name, np.array(value['values']))
        # NOTE(review): extra_dims is not set in __init__ of this class -
        # presumably supplied elsewhere; confirm.
        self.data_info.set_meta_data('nDims', nDims + len(self.extra_dims))
        self.data_info.set_meta_data('axis_labels', axis_labels)

    def amend_axis_labels(self, *args):
        """ Amend the axis labels.

        ``args[0]`` is a sequence of strings: ``dim`` deletes the label of
        that dimension, ``dim.name.unit`` replaces the label (or inserts it
        when ``dim`` is not below ``nDims``).
        """
        axis_labels = self.data_info.get_meta_data('axis_labels')
        removed_dims = 0
        for arg in args[0]:
            label = arg.split('.')
            dim = int(label[0])
            # compare by value: 'is' on an int tests object identity
            if len(label) == 1:
                del axis_labels[dim + removed_dims]
                removed_dims += 1
                self.data_info.set_meta_data(
                    'nDims', self.data_info.get_meta_data('nDims') - 1)
            elif dim < 0:
                axis_labels[dim + removed_dims] = {label[1]: label[2]}
            elif dim < self.data_info.get_meta_data('nDims'):
                axis_labels[dim] = {label[1]: label[2]}
            else:
                axis_labels.insert(dim, {label[1]: label[2]})

    def set_data_patterns(self, patterns):
        """ Store ``patterns`` after adding any dimensions they do not name
        to their slice directions.
        """
        self.add_extra_dims_to_patterns(patterns)
        self.data_info.set_meta_data('data_patterns', patterns)

    def add_extra_dims_to_patterns(self, patterns):
        """ Append each data dimension missing from a pattern to that
        pattern's slice directions (modifies ``patterns`` in place).
        """
        all_dims = range(len(self.get_shape()))
        for p in patterns:
            pDims = patterns[p]['core_dir'] + patterns[p]['slice_dir']
            for dim in all_dims:
                if dim not in pDims:
                    patterns[p]['slice_dir'] += (dim,)

    def get_data_patterns(self):
        """ Get data patterns associated with this data object. """
        return self.data_info.get_meta_data('data_patterns')

    def set_shape(self, shape):
        """ Set the dataset shape and check it against ``nDims``. """
        self.data_info.set_meta_data('shape', shape)
        self.check_dims()

    def get_shape(self):
        """ Get the dataset shape. """
        return self.data_info.get_meta_data('shape')

    def set_preview(self, preview_list, **kwargs):
        """ Set the starts, stops, steps and chunks of the preview.

        An empty ``preview_list`` selects the whole dataset.

        :keyword revert: shape restored by :meth:`unset_preview`.
        """
        self.revert_shape = kwargs.get('revert', self.revert_shape)
        shape = self.get_shape()
        if preview_list:
            starts, stops, steps, chunks = \
                self.get_preview_indices(preview_list)
            shape_change = True
        else:
            starts, stops, steps, chunks = \
                [[0]*len(shape), shape, [1]*len(shape), [1]*len(shape)]
            shape_change = False
        self.set_starts_stops_steps(starts, stops, steps, chunks,
                                    shapeChange=shape_change)

    def unset_preview(self):
        """ Remove the preview and restore the shape saved as ``revert``. """
        self.set_preview([])
        self.set_shape(self.revert_shape)
        self.revert_shape = None

    def set_starts_stops_steps(self, starts, stops, steps, chunks,
                               shapeChange=True):
        """ Store the preview values and, if the shape changes or mapping
        data exists, update the dataset shape.
        """
        self.data_info.set_meta_data('starts', starts)
        self.data_info.set_meta_data('stops', stops)
        self.data_info.set_meta_data('steps', steps)
        self.data_info.set_meta_data('chunks', chunks)
        if shapeChange or self.mapping:
            self.set_reduced_shape(starts, stops, steps, chunks)

    def get_preview_indices(self, preview_list):
        """ Convert a list of ``start:stop:step:chunk`` strings (one per
        dimension) into lists of starts, stops, steps and chunks.
        """
        starts = len(preview_list)*[None]
        stops = len(preview_list)*[None]
        steps = len(preview_list)*[None]
        chunks = len(preview_list)*[None]
        for i in range(len(preview_list)):
            starts[i], stops[i], steps[i], chunks[i] = \
                self.convert_indices(preview_list[i].split(':'), i)
        return starts, stops, steps, chunks

    def convert_indices(self, idx, dim):
        """ Evaluate the preview expressions for dimension ``dim``; negative
        results are taken relative to the end of the dimension.
        """
        shape = self.get_shape()
        # mid, end, midmap and endmap look unused but are the names the
        # evaluated preview expressions may refer to - do not remove them.
        mid = shape[dim]/2
        end = shape[dim]

        if self.mapping:
            map_shape = self.exp.index['mapping'][self.get_name()].get_shape()
            midmap = map_shape[dim]/2
            endmap = map_shape[dim]

        # WARNING(review): eval() executes arbitrary expressions; the preview
        # strings must come from a trusted source.
        idx = [eval(equ) for equ in idx]
        idx = [idx[i] if idx[i] > -1 else shape[dim]+1+idx[i]
               for i in range(len(idx))]
        return idx

    def get_starts_stops_steps(self):
        """ Return the stored preview starts, stops, steps and chunks. """
        starts = self.data_info.get_meta_data('starts')
        stops = self.data_info.get_meta_data('stops')
        steps = self.data_info.get_meta_data('steps')
        chunks = self.data_info.get_meta_data('chunks')
        return starts, stops, steps, chunks

    def set_reduced_shape(self, starts, stops, steps, chunks):
        """ Store the current shape as ``orig_shape`` and set the shape given
        by the slice direction matrix of each dimension.
        """
        orig_shape = self.get_shape()
        self.data_info.set_meta_data('orig_shape', orig_shape)
        new_shape = []
        for dim in range(len(starts)):
            new_shape.append(np.prod((self.get_slice_dir_matrix(dim).shape)))
        self.set_shape(tuple(new_shape))

        # reduce shape of mapping data if it exists
        if self.mapping:
            self.set_mapping_reduced_shape(orig_shape, new_shape,
                                           self.get_name())

    def set_mapping_reduced_shape(self, orig_shape, new_shape, name):
        """ Reduce the shape of the mapping dataset ``name`` to match the
        reduced shape of this dataset.
        """
        map_obj = self.exp.index['mapping'][name]
        map_shape = np.array(map_obj.get_shape())
        diff = np.array(orig_shape) - map_shape[:len(orig_shape)]
        not_map_dim = np.where(diff == 0)[0]
        map_dim = np.where(diff != 0)[0]
        self.map_dim = map_dim
        map_obj.data_info.set_meta_data('full_map_dim_len', map_shape[map_dim])
        map_shape[not_map_dim] = np.array(new_shape)[not_map_dim]

        # assuming only one extra dimension added for now
        starts, stops, steps, chunks = self.get_starts_stops_steps()
        start = starts[map_dim] % map_shape[map_dim]
        stop = min(stops[map_dim], map_shape[map_dim])
        temp = len(np.arange(start, stop, steps[map_dim]))*chunks[map_dim]
        map_shape[len(orig_shape)] = np.ceil(new_shape[map_dim]/temp)
        map_shape[map_dim] = new_shape[map_dim]/map_shape[len(orig_shape)]
        map_obj.data_info.set_meta_data('map_dim_len', map_shape[map_dim])
        self.exp.index['mapping'][name].set_shape(tuple(map_shape))

    def find_and_set_shape(self, data):
        """ Set the shape to that of ``data`` plus the plugin data's extra
        dimensions.
        """
        pData = self.get_plugin_data()
        new_shape = copy.copy(data.get_shape()) + tuple(pData.extra_dims)
        self.set_shape(new_shape)

    def set_variable_flag(self):
        """ Flag the dataset as having a variable length dimension. """
        self.variable_length_flag = True

    def get_variable_flag(self):
        """ Return True if the dataset is flagged as variable length. """
        return self.variable_length_flag

    def set_variable_array_length(self, var_size):
        """ Set ``next_shape`` to the current shape with each string entry
        replaced, in order, by the entries of ``var_size``.
        """
        var_size = var_size if isinstance(var_size, list) else [var_size]
        shape = list(self.get_shape())
        count = 0
        for i in range(len(shape)):
            if isinstance(shape[i], str):
                shape[i] = var_size[count]
                count += 1
        self.next_shape = tuple(shape)

    def check_dims(self):
        """ Check the ``shape`` and ``nDims`` entries in the data_info
        meta_data agree (skipped for variable length data).

        :raises Exception: if they disagree.
        """
        nDims = self.data_info.get_meta_data("nDims")
        shape = self.data_info.get_meta_data('shape')
        if nDims and self.get_variable_flag() is False:
            if len(shape) != nDims:
                error_msg = ("The number of axis labels, %d, does not "
                             "coincide with the number of data "
                             "dimensions %d." % (nDims, len(shape)))
                raise Exception(error_msg)

    def set_name(self, name):
        """ Set the data name. """
        self.data_info.set_meta_data('name', name)

    def get_name(self):
        """ Get the data name. """
        return self.data_info.get_meta_data('name')

    def set_data_params(self, pattern, chunk_size, **kwargs):
        """ Set the current pattern name and the number of frames. """
        # NOTE(review): set_current_pattern_name and set_nFrames are not
        # defined in this class - presumably provided elsewhere; confirm.
        self.set_current_pattern_name(pattern)
        self.set_nFrames(chunk_size)

    def set_available_pattern_list(self):
        """ Return the list of pattern names allowed by this class. """
        pattern_list = ["SINOGRAM",
                        "PROJECTION",
                        "VOLUME_YZ",
                        "VOLUME_XZ",
                        "VOLUME_XY",
                        "VOLUME_3D",
                        "SPECTRUM",
                        "DIFFRACTION",
                        "CHANNEL",
                        "SPECTRUM_STACK",
                        "PROJECTION_STACK",
                        "METADATA"]
        return pattern_list

    def add_pattern(self, dtype, **kwargs):
        """ Add a pattern.

        :params str dtype: a name from the available pattern list.
        :keyword tuple core_dir: Dimension indices of core dimensions
        :keyword tuple slice_dir: Dimension indices of slice dimensions
        :raises Exception: if ``dtype`` is not an available pattern, or the
            pattern covers a different number of dimensions to ``nDims``.
        """
        if dtype not in self.pattern_list:
            # format the message here: passing the values as extra Exception
            # arguments leaves the '%s' placeholders unfilled.
            raise Exception("The data pattern '%s'does not exist. Please "
                            "choose from the following list: \n'%s'"
                            % (dtype, str(self.pattern_list)))

        nDims = 0
        for key, directions in kwargs.items():
            nDims += len(directions)
            self.data_info.set_meta_data(['data_patterns', dtype, key],
                                         directions)

        self.convert_pattern_directions(dtype)

        shape = self.get_shape()
        if shape:
            diff = len(shape) - nDims
            if diff:
                # dimensions not named in the pattern are appended by the
                # helper, so count them towards the pattern's dimensions
                pattern = {dtype: self.get_data_patterns()[dtype]}
                self.add_extra_dims_to_patterns(pattern)
                nDims += diff

        try:
            required = self.data_info.get_meta_data("nDims")
        except KeyError:
            required = None

        if required is None:
            # nDims is initialised to None, so None (like a missing entry)
            # means "not yet known": record it rather than fail on '%d'.
            self.data_info.set_meta_data('nDims', nDims)
        elif nDims != required:
            err_msg = ("The pattern %s has an incorrect number of "
                       "dimensions: %d required but %d specified."
                       % (dtype, required, nDims))
            raise Exception(err_msg)

    def add_volume_patterns(self, x, y, z):
        """ Add the VOLUME_YZ, VOLUME_XZ and VOLUME_XY patterns for the
        dimensions associated with the x, y and z axes.
        """
        self.add_pattern("VOLUME_YZ", **self.get_dirs_for_volume(y, z, x))
        self.add_pattern("VOLUME_XZ", **self.get_dirs_for_volume(x, z, y))
        self.add_pattern("VOLUME_XY", **self.get_dirs_for_volume(x, y, z))

    def get_dirs_for_volume(self, dim1, dim2, sdir):
        """ Calculate core_dir and slice_dir for a volume pattern.

        ``sdir`` is always the first slice direction; any other dimensions of
        the data are appended to the slice directions in ascending order.
        """
        all_dims = range(len(self.get_shape()))
        named = [dim1, dim2, sdir]
        # *** need to add this for other patterns
        slice_dir = [sdir] + [ddir for ddir in all_dims if ddir not in named]
        return {'core_dir': (dim1, dim2), 'slice_dir': tuple(slice_dir)}

    def set_axis_labels(self, *args):
        """ Set the axis labels; each arg should be of the form
        ``name.unit``.  ``nDims`` is set to the number of args.
        """
        self.data_info.set_meta_data('nDims', len(args))
        axis_labels = []
        for arg in args:
            try:
                axis = arg.split('.')
                axis_labels.append({axis[0]: axis[1]})
            except (AttributeError, IndexError, TypeError):
                # data arrives here, but that may be an error: the arg is not
                # a 'name.unit' string.  Skipped (best effort), but only for
                # the failures that malformed labels can produce.
                continue
        self.data_info.set_meta_data('axis_labels', axis_labels)

    def find_axis_label_dimension(self, name, contains=False):
        """ Return the dimension whose axis label is ``name`` (or, when
        ``contains`` is True, whose axis label contains ``name``).

        :raises Exception: if no axis label matches.
        """
        axis_labels = self.data_info.get_meta_data('axis_labels')
        for dim, label in enumerate(axis_labels):
            if contains is True:
                if any(name in key for key in label.keys()):
                    return dim
            elif name in label.keys():
                return dim
        raise Exception("Cannot find the specifed axis label.")

    def finalise_patterns(self):
        """ Add a main axis to the SINOGRAM and PROJECTION patterns when both
        exist.
        """
        check = (self.check_pattern('SINOGRAM') +
                 self.check_pattern('PROJECTION'))
        # compare by value: 'is' on an int tests object identity
        if check == 2:
            self.set_main_axis('SINOGRAM')
            self.set_main_axis('PROJECTION')

    def check_pattern(self, pattern_name):
        """ Return 1 if the pattern exists, 0 otherwise. """
        patterns = self.get_data_patterns()
        try:
            patterns[pattern_name]
        except KeyError:
            return 0
        return 1

    def convert_pattern_directions(self, dtype):
        """ Replace negative indices in the directions of pattern ``dtype``.
        """
        pattern = self.get_data_patterns()[dtype]
        nDims = sum([len(i) for i in pattern.values()])
        for p in pattern:
            pattern[p] = self.non_negative_directions(pattern[p], nDims)

    def non_negative_directions(self, ddirs, nDims):
        """ Return ``ddirs`` as a tuple with each negative index replaced by
        its non-negative counterpart for ``nDims`` dimensions.
        """
        return tuple(nDims + ddir if ddir < 0 else ddir for ddir in ddirs)

    def check_direction(self, tdir, dname):
        """ Check ``tdir`` is an integer and that patterns have been added.

        :raises TypeError: if ``tdir`` is not an int.
        :raises Exception: if no patterns exist.
        """
        if not isinstance(tdir, int):
            raise TypeError('The direction should be an integer.')

        patterns = self.get_data_patterns()
        if not patterns:
            # format the name into the message instead of passing it as a
            # second Exception argument
            raise Exception("Please add available patterns before setting the"
                            " direction %s" % dname)

    def set_main_axis(self, pname):
        """ Set ``main_dir`` of pattern ``pname`` to a dimension shared by
        its slice directions and the core directions of the other pattern
        (of SINOGRAM/PROJECTION).
        """
        patterns = self.get_data_patterns()
        # compare by value: 'is' on a string literal tests identity
        n1 = 'PROJECTION' if pname == 'SINOGRAM' else 'SINOGRAM'
        d1 = patterns[n1]['core_dir']
        d2 = patterns[pname]['slice_dir']
        tdir = set(d1).intersection(set(d2))
        self.data_info.set_meta_data(['data_patterns', pname, 'main_dir'],
                                     list(tdir)[0])

    def trim_input_data(self, **kwargs):
        """ Select the image key on the TomoRaw object, if there is one. """
        if self.tomo_raw_obj:
            self.get_tomo_raw().select_image_key(**kwargs)

    def trim_output_data(self, copy_obj, **kwargs):
        """ Remove the image key via the TomoRaw object, if there is one, and
        reset the preview.
        """
        if self.tomo_raw_obj:
            self.get_tomo_raw().remove_image_key(copy_obj, **kwargs)
            self.set_preview([])

    def get_axis_label_keys(self):
        """ Return the axis label names, in dimension order. """
        axis_labels = self.data_info.get_meta_data('axis_labels')
        return [key for labels in axis_labels for key in labels.keys()]

    def get_current_and_next_patterns(self, datasets_lists):
        """ Store, in the experiment meta data entry ``current_and_next``,
        the current and next pattern of each out dataset of the first entry
        of ``datasets_lists``.
        """
        current_datasets = datasets_lists[0]
        patterns_list = []
        for current_data in current_datasets['out_datasets']:
            next_pattern = self.find_next_pattern(datasets_lists[1:],
                                                  current_data['name'])
            patterns_list.append({'current': current_data['pattern'],
                                  'next': next_pattern})
        self.exp.meta_data.set_meta_data('current_and_next', patterns_list)

    def find_next_pattern(self, datasets_lists, current_name):
        """ Return the pattern of the first in_dataset named ``current_name``
        in ``datasets_lists``, or an empty list if there is none.
        """
        for next_data_list in datasets_lists:
            for next_data in next_data_list['in_datasets']:
                if next_data['name'] == current_name:
                    return next_data['pattern']
        return []