Example #1
Votes: 0
File: data.py  Project: FedeMPouzols/Savu
class Data(DataCreate):
    """The Data class dynamically inherits from transport specific data class
    and holds the data array, along with associated information.
    """

    def __init__(self, name, exp):
        super(Data, self).__init__(name)
        self.meta_data = MetaData()
        self.pattern_list = self.__get_available_pattern_list()
        self.data_info = MetaData()
        self.__initialise_data_info(name)
        self._preview = Preview(self)
        self.exp = exp
        self.group_name = None
        self.group = None
        self._plugin_data_obj = None
        self.tomo_raw_obj = None
        self.backing_file = None
        self.data = None
        self.next_shape = None
        self.orig_shape = None

    def __initialise_data_info(self, name):
        """ Initialise entries in the data_info meta data.
        """
        self.data_info.set_meta_data('name', name)
        self.data_info.set_meta_data('data_patterns', {})
        self.data_info.set_meta_data('shape', None)
        self.data_info.set_meta_data('nDims', None)

    def _set_plugin_data(self, plugin_data_obj):
        """ Encapsulate a PluginData object.
        """
        self._plugin_data_obj = plugin_data_obj

    def _clear_plugin_data(self):
        """ Set encapsulated PluginData object to None.
        """
        self._plugin_data_obj = None

    def _get_plugin_data(self):
        """ Get encapsulated PluginData object.

        :returns: the PluginData object currently associated with this data
        :raises Exception: if no PluginData object is associated
        """
        if self._plugin_data_obj is None:
            raise Exception("There is no PluginData object associated with "
                            "the Data object.")
        return self._plugin_data_obj

    def get_preview(self):
        """ Get the Preview instance associated with the data object
        """
        return self._preview

    def _get_transport_data(self):
        """ Import the data transport mechanism

        :returns: instance of data transport
        :rtype: transport_data
        """
        transport = self.exp.meta_data.get_meta_data("transport")
        transport_data = "savu.data.transport_data." + transport + \
                         "_transport_data"
        return cu.import_class(transport_data)

    def __deepcopy__(self, memo):
        """ Copy the data object.
        """
        name = self.data_info.get_meta_data('name')
        return dsu._deepcopy_data_object(self, Data(name, self.exp))

    def get_data_patterns(self):
        """ Get data patterns associated with this data object.

        :returns: A dictionary of associated patterns.
        :rtype: dict
        """
        return self.data_info.get_meta_data('data_patterns')

    def set_shape(self, shape):
        """ Set the dataset shape.

        :param tuple shape: new shape; checked against the stored nDims
        """
        self.data_info.set_meta_data('shape', shape)
        self.__check_dims()

    def set_original_shape(self, shape):
        """ Record the original dataset shape and set it as current shape.
        """
        self.orig_shape = shape
        self.set_shape(shape)

    def get_shape(self):
        """ Get the dataset shape

        :returns: data shape
        :rtype: tuple
        """
        return self.data_info.get_meta_data('shape')

    def __check_dims(self):
        """ Check the ``shape`` and ``nDims`` entries in the data_info
        meta_data dictionary are equal.

        :raises Exception: if they differ
        """
        nDims = self.data_info.get_meta_data("nDims")
        shape = self.data_info.get_meta_data('shape')
        if nDims:
            if len(shape) != nDims:
                error_msg = ("The number of axis labels, %d, does not "
                             "coincide with the number of data "
                             "dimensions %d." % (nDims, len(shape)))
                raise Exception(error_msg)

    def get_name(self):
        """ Get data name.

        :returns: the name associated with the dataset
        :rtype: str
        """
        return self.data_info.get_meta_data('name')

    def __get_available_pattern_list(self):
        """ Get a list of ALL pattern names that are currently allowed in the
        framework.
        """
        return dsu.get_available_pattern_types()

    def add_pattern(self, dtype, **kwargs):
        """ Add a pattern.

        :params str dtype: The *type* of pattern to add, which can be
            anything from
            :const:`savu.data.data_structures.utils.pattern_list`
        :keyword tuple core_dir: Dimension indices of core dimensions
        :keyword tuple slice_dir: Dimension indices of slice dimensions
        :raises Exception: if ``dtype`` is not a recognised pattern name, or
            the pattern dimensions do not match the data dimensions
        """
        if dtype not in self.pattern_list:
            # BUG FIX: the original passed the format arguments to Exception
            # instead of interpolating them into the message (and was missing
            # a space after '%s').
            raise Exception("The data pattern '%s' does not exist. Please "
                            "choose from the following list: \n'%s'"
                            % (dtype, str(self.pattern_list)))

        nDims = 0
        for args in kwargs:
            nDims += len(kwargs[args])
            self.data_info.set_meta_data(['data_patterns', dtype, args],
                                         kwargs[args])

        self.__convert_pattern_directions(dtype)
        if self.get_shape():
            # pad the pattern with extra dimensions if the data has more
            # dimensions than the pattern specifies
            diff = len(self.get_shape()) - nDims
            if diff:
                pattern = {dtype: self.get_data_patterns()[dtype]}
                self._add_extra_dims_to_patterns(pattern)
                nDims += diff
        # NOTE(review): if 'nDims' is still None here the '%d' formatting
        # below raises TypeError - presumably axis labels are always set
        # before patterns are added; confirm.
        try:
            if nDims != self.data_info.get_meta_data("nDims"):
                actualDims = self.data_info.get_meta_data('nDims')
                err_msg = ("The pattern %s has an incorrect number of "
                           "dimensions: %d required but %d specified."
                           % (dtype, actualDims, nDims))
                raise Exception(err_msg)
        except KeyError:
            self.data_info.set_meta_data('nDims', nDims)

    def add_volume_patterns(self, x, y, z):
        """ Adds 3D volume patterns

        :params int x: dimension to be associated with x-axis
        :params int y: dimension to be associated with y-axis
        :params int z: dimension to be associated with z-axis
        """
        self.add_pattern("VOLUME_YZ", **self.__get_dirs_for_volume(y, z, x))
        self.add_pattern("VOLUME_XZ", **self.__get_dirs_for_volume(x, z, y))
        self.add_pattern("VOLUME_XY", **self.__get_dirs_for_volume(x, y, z))

    def __get_dirs_for_volume(self, dim1, dim2, sdir):
        """ Calculate core_dir and slice_dir for a 3D volume pattern.
        """
        all_dims = range(len(self.get_shape()))
        vol_dict = {'core_dir': (dim1, dim2)}
        slice_dir = [sdir]
        # *** need to add this for other patterns
        for ddir in all_dims:
            if ddir not in [dim1, dim2, sdir]:
                slice_dir.append(ddir)
        vol_dict['slice_dir'] = tuple(slice_dir)
        return vol_dict

    def set_axis_labels(self, *args):
        """ Set the axis labels associated with each data dimension.

        :arg str: Each arg should be of the form ``name.unit``. If ``name`` is\
        a data_obj.meta_data entry, it will be output to the final .nxs file.
        """
        self.data_info.set_meta_data('nDims', len(args))
        axis_labels = []
        for arg in args:
            try:
                axis = arg.split('.')
                axis_labels.append({axis[0]: axis[1]})
            except (IndexError, AttributeError):
                # bare except narrowed: an arg without a '.' (IndexError) or
                # a non-string arg (AttributeError) is skipped, as before
                pass
        self.data_info.set_meta_data('axis_labels', axis_labels)

    def get_axis_labels(self):
        """ Get axis labels.

        :returns: Axis labels
        :rtype: list(dict)
        """
        return self.data_info.get_meta_data('axis_labels')

    def find_axis_label_dimension(self, name, contains=False):
        """ Get the dimension of the data associated with a particular
        axis_label.

        :param str name: The name of the axis_label
        :keyword bool contains: Set this flag to true if the name is only part
            of the axis_label name
        :returns: The associated axis number
        :rtype: int
        :raises Exception: if no matching axis label exists
        """
        axis_labels = self.data_info.get_meta_data('axis_labels')
        for i, labels in enumerate(axis_labels):
            if contains is True:
                for names in labels.keys():
                    if name in names:
                        return i
            else:
                if name in labels.keys():
                    return i
        raise Exception("Cannot find the specified axis label.")

    def _finalise_patterns(self):
        """ Adds a main axis (fastest changing) to SINOGRAM and PROJECTION
        patterns.
        """
        check = 0
        check += self.__check_pattern('SINOGRAM')
        check += self.__check_pattern('PROJECTION')

        # BUG FIX: the original compared integers with 'is', which tests
        # identity, not equality; the dead 'elif check is 1: pass' branch
        # has been removed.
        if check == 2 and len(self.get_shape()) > 2:
            self.__set_main_axis('SINOGRAM')
            self.__set_main_axis('PROJECTION')

    def __check_pattern(self, pattern_name):
        """ Check if a pattern exists.

        :returns: 1 if the pattern exists, 0 otherwise
        :rtype: int
        """
        return 1 if pattern_name in self.get_data_patterns() else 0

    def __convert_pattern_directions(self, dtype):
        """ Replace negative indices in pattern kwargs.
        """
        pattern = self.get_data_patterns()[dtype]
        if 'main_dir' in pattern.keys():
            del pattern['main_dir']

        nDims = sum([len(i) for i in pattern.values()])
        for p in pattern:
            pattern[p] = self.non_negative_directions(pattern[p], nDims)

    def non_negative_directions(self, ddirs, nDims):
        """ Replace negative indexing values with positive counterparts.

        :params tuple(int) ddirs: data dimension indices
        :params int nDims: The number of data dimensions
        :returns: non-negative data dimension indices
        :rtype: tuple(int)
        """
        return tuple(nDims + d if d < 0 else d for d in ddirs)

    def __set_main_axis(self, pname):
        """ Set the ``main_dir`` pattern kwarg to the fastest changing
        dimension
        """
        patterns = self.get_data_patterns()
        # BUG FIX: the original compared strings with 'is' (identity), which
        # relies on interning; use '==' for equality.
        n1 = 'PROJECTION' if pname == 'SINOGRAM' else 'SINOGRAM'
        d1 = patterns[n1]['core_dir']
        d2 = patterns[pname]['slice_dir']
        tdir = set(d1).intersection(set(d2))

        # this is required when a single sinogram exists in the mm case, and a
        # dimension is added via parameter tuning.
        if not tdir:
            tdir = [d2[0]]

        self.data_info.set_meta_data(['data_patterns', pname, 'main_dir'],
                                     list(tdir)[0])

    def get_axis_label_keys(self):
        """ Get axis_label names

        :returns: A list containing associated axis names for each dimension
        :rtype: list(str)
        """
        axis_labels = self.data_info.get_meta_data('axis_labels')
        return [key for labels in axis_labels for key in labels.keys()]

    def _get_current_and_next_patterns(self, datasets_lists):
        """ Get the current and next patterns associated with a dataset
        throughout the processing chain.
        """
        current_datasets = datasets_lists[0]
        patterns_list = []
        for current_data in current_datasets['out_datasets']:
            current_name = current_data['name']
            current_pattern = current_data['pattern']
            next_pattern = self.__find_next_pattern(datasets_lists[1:],
                                                    current_name)
            patterns_list.append({'current': current_pattern,
                                  'next': next_pattern})
        self.exp.meta_data.set_meta_data('current_and_next', patterns_list)

    def __find_next_pattern(self, datasets_lists, current_name):
        """ Find the first downstream in_dataset entry matching
        ``current_name`` and return its pattern ([] if none found).
        """
        next_pattern = []
        for next_data_list in datasets_lists:
            for next_data in next_data_list['in_datasets']:
                if next_data['name'] == current_name:
                    return next_data['pattern']
        return next_pattern

    def get_slice_directions(self):
        """ Get pattern slice_dir of pattern currently associated with the
        dataset (if any).

        :returns: the slicing dimensions.
        :rtype: tuple(int)
        """
        return self._get_plugin_data().get_slice_directions()
Example #2
Votes: 0
class Experiment(object):
    """
    One instance of this class is created at the beginning of the
    processing chain and remains until the end.  It holds the current data
    object and a dictionary containing all metadata.
    """

    def __init__(self, options):
        self.meta_data = MetaData(options)
        self.meta_data_setup(options["process_file"])
        # maps dataset names to Data objects, split by direction
        self.index = {"in_data": {}, "out_data": {}}

    def meta_data_setup(self, process_file):
        """ Load the experiment collection and populate the plugin list
        from the given process file.
        """
        self.meta_data.load_experiment_collection()
        self.meta_data.plugin_list = PluginList()
        self.meta_data.plugin_list.populate_plugin_list(process_file)

    def create_data_object(self, dtype, name, bases=None):
        """ Get (or create and register) the Data object index[dtype][name].

        :param str dtype: index key, e.g. "in_data" or "out_data"
        :param str name: dataset name
        :keyword list bases: extra base classes to inject into the new Data
            object (the transport data class is always appended)
        :returns: the Data object registered under (dtype, name)
        """
        # BUG FIX: the original used a mutable default argument (bases=[]),
        # which is shared across calls and accumulates transport classes.
        bases = [] if bases is None else bases
        try:
            self.index[dtype][name]
        except KeyError:
            self.index[dtype][name] = Data(name)
            data_obj = self.index[dtype][name]
            bases.append(data_obj.get_transport_data(
                self.meta_data.get_meta_data("transport")))
            data_obj.add_base_classes(bases)
        return self.index[dtype][name]

    def set_nxs_filename(self):
        """ Build the output .nxs filename from the first input dataset's
        backing file and store it in the meta data.
        """
        # list(...) instead of .keys()[0]: works on both Python 2 and 3
        name = list(self.index["in_data"])[0]
        filename = os.path.basename(
            self.index["in_data"][name].backing_file.filename)
        filename = os.path.splitext(filename)[0]
        filename = os.path.join(self.meta_data.get_meta_data("out_path"),
                                "%s_processed_%s.nxs" % (filename,
                                time.strftime("%Y%m%d%H%M%S")))
        self.meta_data.set_meta_data("nxs_filename", filename)

    def clear_data_objects(self):
        """ Remove all input and output data objects. """
        self.index["out_data"] = {}
        self.index["in_data"] = {}

    def clear_out_data_objects(self):
        """ Remove all output data objects. """
        self.index["out_data"] = {}

    def set_out_data_to_in(self):
        """ Promote the output data objects to input and clear the output
        index (used between plugins in the chain).
        """
        self.index["in_data"] = self.index["out_data"]
        self.index["out_data"] = {}

    def barrier(self):
        """ Synchronise all MPI processes (no-op when not running MPI). """
        if self.meta_data.get_meta_data('mpi') is True:
            logging.debug("About to hit a barrier")
            MPI.COMM_WORLD.Barrier()
            logging.debug("Past the barrier")

    def log(self, log_tag, log_level=logging.DEBUG):
        """
        Log the contents of the experiment at the specified level
        """
        logging.log(log_level, "Experimental Parameters for %s", log_tag)
        for key, value in self.index["in_data"].items():
            logging.log(log_level, "in data (%s) shape = %s", key,
                        value.get_shape())
        # BUG FIX: the original iterated "in_data" here while labelling the
        # entries "out data"; iterate the output index instead.
        for key, value in self.index["out_data"].items():
            logging.log(log_level, "out data (%s) shape = %s", key,
                        value.get_shape())
Example #3
Votes: 0
class PluginData(object):
    """ The PluginData class contains plugin specific information about a Data
    object for the duration of a plugin.  An instance of the class is
    encapsulated inside the Data object during the plugin run
    """
    def __init__(self, data_obj, plugin=None):
        self.data_obj = data_obj
        self.data_obj._set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = None
        self.pad_dict = None
        # (stale comment about a selected-data flag removed: the flag does
        # not exist in this version of the class)
        self.shape = None
        self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []
        self._plugin = plugin
        self.fixed_dims = False
        self.end_pad = False
        self.split = None

    def get_total_frames(self):
        """ Get the total number of frames to process.

        :returns: Number of frames
        :rtype: int
        """
        # product of the data extents over all slice dimensions
        frames = 1
        slice_dir = \
            self.data_obj.get_data_patterns()[
                self.get_pattern_name()]["slice_dir"]
        for tslice in slice_dir:
            frames *= self.data_obj.get_shape()[tslice]
        return frames

    def __set_pattern(self, name):
        """ Set the pattern related information in the meta data dict.
        """
        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set_meta_data("name", name)
        self.meta_data.set_meta_data("core_dir", pattern['core_dir'])
        self.__set_slice_directions()

    def get_pattern_name(self):
        """ Get the pattern name.

        :returns: the pattern name
        :rtype: str
        :raises Exception: if no pattern has been set
        """
        name = self.meta_data.get_meta_data("name")
        if name is not None:
            return name
        else:
            raise Exception("The pattern name has not been set.")

    def get_pattern(self):
        """ Get the current pattern.

        :returns: dict of the pattern name against the pattern.
        :rtype: dict
        """
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def __set_shape(self):
        """ Set the shape of the plugin data processing chunk.
        """
        core_dir = self.get_core_directions()
        slice_dir = self.get_slice_directions()
        dirs = list(set(core_dir + (slice_dir[0], )))
        slice_idx = dirs.index(slice_dir[0])
        shape = []
        for core in set(core_dir):
            shape.append(self.data_obj.get_shape()[core])
        self.__set_core_shape(tuple(shape))
        # a frame-chunk dimension is only present when processing >1 frame
        if self._get_frame_chunk() > 1:
            shape.insert(slice_idx, self._get_frame_chunk())
        self.shape = tuple(shape)

    def get_shape(self):
        """ Get the shape of the plugin data to be processed each time.
        """
        return self.shape

    def __set_core_shape(self, shape):
        """ Set the core shape to hold only the shape of the core dimensions
        """
        self.core_shape = shape

    def get_core_shape(self):
        """ Get the shape of the core dimensions only.

        :returns: shape of core dimensions
        :rtype: tuple
        """
        return self.core_shape

    def __check_dimensions(self, indices, core_dir, slice_dir, nDims):
        """ Sanity-check that the indices and dimension counts agree.
        """
        # BUG FIX: the original used 'is not' to compare integers, which
        # tests identity rather than equality and fails for large ints.
        if len(indices) != len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")

        if (len(core_dir) + len(slice_dir)) != nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def __set_slice_directions(self):
        """ Set the slice directions in the pluginData meta data dictionary.
        """
        slice_dirs = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['slice_dir']
        self.meta_data.set_meta_data('slice_dir', slice_dirs)

    def get_slice_directions(self):
        """ Get the slice directions (slice_dir) of the dataset.
        """
        return self.meta_data.get_meta_data('slice_dir')

    def get_slice_dimension(self):
        """
        Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.get_core_directions()
        slice_dir = self.get_slice_directions()[0]
        return list(set(core_dirs + (slice_dir, ))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """
        Return the dimension of the data in the plugin that has the specified
        axis label.
        """
        label_dim = \
            self.data_obj.find_axis_label_dimension(label, contains=contains)
        plugin_dims = self.get_core_directions()
        if self._get_frame_chunk() > 1:
            plugin_dims += (self.get_slice_directions()[0], )
        return list(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):
        """
        Reorder the slice directions.  The fastest changing slice direction
        will always be the first one stated in the pattern key ``slice_dir``.
        The input param is a tuple stating the desired order of slicing
        directions relative to the current order.
        """
        slice_dirs = self.get_slice_directions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specified.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        # BUG FIX: the original called self.get_current_pattern(), which does
        # not exist on this class (AttributeError).  Update the pattern entry
        # directly and refresh the cached slice directions.
        self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['slice_dir'] = new_slice_dirs
        self.__set_slice_directions()

    def get_core_directions(self):
        """ Get the core data directions

        :returns: value associated with pattern key ``core_dir``
        :rtype: tuple
        """
        core_dir = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['core_dir']
        return core_dir

    def set_fixed_directions(self, dims, values):
        """ Fix a data direction to the index in values list.

        :param list(int) dims: Directions to fix
        :param list(int) value: Index of fixed directions
        """
        slice_dirs = self.get_slice_directions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set_meta_data("fixed_directions", dims)
        self.meta_data.set_meta_data("fixed_directions_values", values)
        self.__set_slice_directions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        self.__set_shape()

    def _get_fixed_directions(self):
        """ Get the fixed data directions and their indices

        :returns: Fixed directions and their associated values
        :rtype: list(list(int), list(int))
        """
        fixed = []
        values = []
        if 'fixed_directions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get_meta_data("fixed_directions")
            values = self.meta_data.get_meta_data("fixed_directions_values")
        return [fixed, values]

    def _get_data_slice_list(self, plist):
        """ Convert a plugin data slice list to a slice list for the whole
        dataset, i.e. add in any missing dimensions.
        """
        nDims = len(self.get_shape())
        all_dims = self.get_core_directions() + self.get_slice_directions()
        extra_dims = all_dims[nDims:]
        dlist = list(plist)
        for i in extra_dims:
            dlist.insert(i, slice(None))
        return tuple(dlist)

    def _set_frame_chunk(self, nFrames):
        """ Set the number of frames to process at a time """
        self.meta_data.set_meta_data("nFrames", nFrames)

    def _get_frame_chunk(self):
        """ Get the number of frames to be processes at a time.

        If the number of frames is not divisible by the previewing ``chunk``
        value then amend the number of frames to gcd(frames, chunk)

        :returns: Number of frames to process
        :rtype: int
        """
        if self._plugin and self._plugin.chunk > 1:
            frame_chunk = self.meta_data.get_meta_data("nFrames")
            chunk = self.data_obj.get_preview().get_starts_stops_steps(
                key='chunks')[self.get_slice_directions()[0]]
            self._set_frame_chunk(gcd(frame_chunk, chunk))
        return self.meta_data.get_meta_data("nFrames")

    def plugin_data_setup(self, pattern_name, chunk, fixed=False, split=None):
        """ Setup the PluginData object.

        :param str pattern_name: A pattern name
        :param int chunk: Number of frames to process at a time
        :keyword bool fixed: setting fixed=True will ensure the plugin \
            receives the same sized data array each time (padding if necessary)

        """
        self.__set_pattern(pattern_name)
        chunks = \
            self.data_obj.get_preview().get_starts_stops_steps(key='chunks')
        if self._plugin and (chunks[self.get_slice_directions()[0]] % chunk):
            self._plugin.chunk = True
        self._set_frame_chunk(chunk)
        self.__set_shape()
        self.fixed_dims = fixed
        self.split = split
Example #4
Votes: 0
class PluginData(object):
    """ The PluginData class contains plugin specific information about a Data
    object for the duration of a plugin.  An instance of the class is
    encapsulated inside the Data object during the plugin run
    """

    def __init__(self, data_obj, plugin=None):
        self.data_obj = data_obj
        self.data_obj._set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = None
        # this flag determines which data is passed. If false then just the
        # data, if true then all data including dark and flat fields.
        self.selected_data = False
        self.shape = None
        self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []
        self._plugin = plugin

    def get_total_frames(self):
        """ Get the total number of frames to process.

        :returns: Number of frames
        :rtype: int
        """
        # product of the data extents over all slice dimensions
        frames = 1
        slice_dir = \
            self.data_obj.get_data_patterns()[
                self.get_pattern_name()]["slice_dir"]
        for tslice in slice_dir:
            frames *= self.data_obj.get_shape()[tslice]
        return frames

    def __set_pattern(self, name):
        """ Set the pattern related information in the meta data dict.
        """
        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set_meta_data("name", name)
        self.meta_data.set_meta_data("core_dir", pattern['core_dir'])
        self.__set_slice_directions()

    def get_pattern_name(self):
        """ Get the pattern name.

        :returns: the pattern name
        :rtype: str
        :raises Exception: if no pattern has been set
        """
        name = self.meta_data.get_meta_data("name")
        if name is not None:
            return name
        else:
            raise Exception("The pattern name has not been set.")

    def get_pattern(self):
        """ Get the current pattern.

        :returns: dict of the pattern name against the pattern.
        :rtype: dict
        """
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def __set_shape(self):
        """ Set the shape of the plugin data processing chunk.
        """
        core_dir = self.get_core_directions()
        slice_dir = self.get_slice_directions()
        dirs = list(set(core_dir + (slice_dir[0],)))
        slice_idx = dirs.index(slice_dir[0])
        shape = []
        for core in set(core_dir):
            shape.append(self.data_obj.get_shape()[core])
        self.__set_core_shape(tuple(shape))
        # a frame-chunk dimension is only present when processing >1 frame
        if self._get_frame_chunk() > 1:
            shape.insert(slice_idx, self._get_frame_chunk())
        self.shape = tuple(shape)

    def get_shape(self):
        """ Get the shape of the plugin data to be processed each time.
        """
        return self.shape

    def __set_core_shape(self, shape):
        """ Set the core shape to hold only the shape of the core dimensions
        """
        self.core_shape = shape

    def get_core_shape(self):
        """ Get the shape of the core dimensions only.

        :returns: shape of core dimensions
        :rtype: tuple
        """
        return self.core_shape

    def __check_dimensions(self, indices, core_dir, slice_dir, nDims):
        """ Sanity-check that the indices and dimension counts agree.
        """
        # BUG FIX: the original used 'is not' to compare integers, which
        # tests identity rather than equality and fails for large ints.
        if len(indices) != len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")

        if (len(core_dir)+len(slice_dir)) != nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def __set_slice_directions(self):
        """ Set the slice directions in the pluginData meta data dictionary.
        """
        slice_dirs = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['slice_dir']
        self.meta_data.set_meta_data('slice_dir', slice_dirs)

    def get_slice_directions(self):
        """ Get the slice directions (slice_dir) of the dataset.
        """
        return self.meta_data.get_meta_data('slice_dir')

    def get_slice_dimension(self):
        """
        Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.get_core_directions()
        slice_dir = self.get_slice_directions()[0]
        return list(set(core_dirs + (slice_dir,))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """
        Return the dimension of the data in the plugin that has the specified
        axis label.
        """
        label_dim = \
            self.data_obj.find_axis_label_dimension(label, contains=contains)
        plugin_dims = self.get_core_directions()
        if self._get_frame_chunk() > 1:
            plugin_dims += (self.get_slice_directions()[0],)
        return list(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):
        """
        Reorder the slice directions.  The fastest changing slice direction
        will always be the first one stated in the pattern key ``slice_dir``.
        The input param is a tuple stating the desired order of slicing
        directions relative to the current order.
        """
        slice_dirs = self.get_slice_directions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specified.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        # BUG FIX: the original called self.get_current_pattern(), which does
        # not exist on this class (AttributeError).  Update the pattern entry
        # directly and refresh the cached slice directions.
        self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['slice_dir'] = new_slice_dirs
        self.__set_slice_directions()

    def get_core_directions(self):
        """ Get the core data directions

        :returns: value associated with pattern key ``core_dir``
        :rtype: tuple
        """
        core_dir = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['core_dir']
        return core_dir

    def set_fixed_directions(self, dims, values):
        """ Fix a data direction to the index in values list.

        :param list(int) dims: Directions to fix
        :param list(int) value: Index of fixed directions
        """
        slice_dirs = self.get_slice_directions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set_meta_data("fixed_directions", dims)
        self.meta_data.set_meta_data("fixed_directions_values", values)
        self.__set_slice_directions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        self.__set_shape()

    def _get_fixed_directions(self):
        """ Get the fixed data directions and their indices

        :returns: Fixed directions and their associated values
        :rtype: list(list(int), list(int))
        """
        fixed = []
        values = []
        if 'fixed_directions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get_meta_data("fixed_directions")
            values = self.meta_data.get_meta_data("fixed_directions_values")
        return [fixed, values]

    def _set_frame_chunk(self, nFrames):
        """ Set the number of frames to process at a time """
        self.meta_data.set_meta_data("nFrames", nFrames)

    def _get_frame_chunk(self):
        """ Get the number of frames to be processes at a time.

        If the number of frames is not divisible by the previewing ``chunk``
        value then amend the number of frames to gcd(frames, chunk)

        :returns: Number of frames to process
        :rtype: int
        """
        if self._plugin and self._plugin.chunk > 1:
            frame_chunk = self.meta_data.get_meta_data("nFrames")
            chunk = self.data_obj.get_preview().get_starts_stops_steps(
                key='chunks')[self.get_slice_directions()[0]]
            self._set_frame_chunk(gcd(frame_chunk, chunk))
        return self.meta_data.get_meta_data("nFrames")

    def plugin_data_setup(self, pattern_name, chunk):
        """ Setup the PluginData object.

        :param str pattern_name: A pattern name
        :param int chunk: Number of frames to process at a time
        """
        self.__set_pattern(pattern_name)
        chunks = \
            self.data_obj.get_preview().get_starts_stops_steps(key='chunks')
        if self._plugin and (chunks[self.get_slice_directions()[0]] % chunk):
            self._plugin.chunk = True
        self._set_frame_chunk(chunk)
        self.__set_shape()
Example #5
Votes: 0
class Data(Pattern):
    """
    The Data class dynamically inherits from relevant data structure classes
    at runtime and holds the data array.
    """

    def __init__(self, name):
        self.meta_data = MetaData()
        super(Data, self).__init__()
        self.name = name
        self.backing_file = None
        self.data = None

    def get_transport_data(self, transport):
        """ Import and return the transport-specific data class.

        :param str transport: transport mechanism name, e.g. 'hdf5'
        """
        transport_data = "savu.data.transport_data." + transport + \
            "_transport_data"
        return self.import_class(transport_data)

    def import_class(self, class_name):
        """ Import a class given its fully-qualified module path.

        The class name is derived by CamelCasing the final,
        underscore-separated component of the module path,
        e.g. 'a.b.hdf5_transport_data' -> Hdf5TransportData.
        """
        mod = __import__(class_name)
        components = class_name.split('.')
        for comp in components[1:]:
            mod = getattr(mod, comp)
        # the CamelCased name contains no dots, so the original's extra
        # split('.') was a no-op and has been removed
        module2class = ''.join(x.capitalize()
                               for x in components[-1].split('_'))
        return getattr(mod, module2class)

    def __deepcopy__(self, memo):
        # Data objects are shared, not duplicated, when containers holding
        # them are deep-copied.
        return self

    def add_base(self, ExtraBase):
        """ Dynamically add ExtraBase to this instance's class hierarchy. """
        cls = self.__class__
        self.__class__ = cls.__class__(cls.__name__, (cls, ExtraBase), {})
        # initialise the newly-added base on *this* instance; the original
        # code called ExtraBase().__init__(), which initialised a throwaway
        # instance and left self untouched
        ExtraBase.__init__(self)

    def add_base_classes(self, bases):
        """ Add each base class in ``bases`` to this instance. """
        for base in bases:
            self.add_base(base)

    def set_shape(self, shape):
        """ Store the data shape in the meta data. """
        self.meta_data.set_meta_data('shape', shape)

    def get_shape(self):
        """ Return the data shape, with any fixed directions set to 1. """
        shape = self.meta_data.get_meta_data('shape')
        try:
            dirs = self.meta_data.get_meta_data("fixed_directions")
            shape = list(shape)
            for ddir in dirs:
                shape[ddir] = 1
            shape = tuple(shape)
        except KeyError:
            # no fixed directions have been set
            pass
        return shape

    def set_dist(self, dist):
        """ Store the data distribution in the meta data. """
        self.meta_data.set_meta_data('dist', dist)

    def get_dist(self):
        """ Return the data distribution. """
        return self.meta_data.get_meta_data('dist')

    def set_data_params(self, pattern, chunk_size, **kwargs):
        """ Set the current pattern name and the frame chunk size. """
        self.set_current_pattern_name(pattern)
        self.set_nFrames(chunk_size)
示例#6
0
class Experiment(object):
    """
    One instance of this class is created at the beginning of the
    processing chain and remains until the end.  It holds the current data
    object and a dictionary containing all metadata.
    """

    def __init__(self, options):
        self.meta_data = MetaData(options)
        self.meta_data_setup(options["process_file"])
        self.index = {"in_data": {}, "out_data": {}, "mapping": {}}
        self.nxs_file = None

    def get_meta_data(self):
        """ Return the experiment MetaData object. """
        return self.meta_data

    def meta_data_setup(self, process_file):
        """ Create the plugin list, either from the 'plugin_list' meta data
        entry (test runs) or by parsing the process file.
        """
        self.meta_data.plugin_list = PluginList()

        try:
            rtype = self.meta_data.get_meta_data('run_type')
            # '==' compares values; the original's 'is' compared object
            # identity, which is unreliable for strings
            if rtype == 'test':
                self.meta_data.plugin_list.plugin_list = \
                    self.meta_data.get_meta_data('plugin_list')
            else:
                raise Exception('the run_type is unknown in Experiment class')
        except KeyError:
            self.meta_data.plugin_list.populate_plugin_list(process_file)

    def create_data_object(self, dtype, name, bases=[]):
        """ Create (or fetch an existing) Data object in the index.

        :params str dtype: either "in_data" or "out_data"
        :params str name: dataset name
        :params list bases: extra base classes to mix in
        """
        # copy before use: the original appended the transport class to the
        # caller's list (and to the shared mutable default), so repeated
        # calls accumulated transport classes
        bases = list(bases)
        try:
            self.index[dtype][name]
        except KeyError:
            self.index[dtype][name] = Data(name, self)
            data_obj = self.index[dtype][name]
            bases.append(data_obj.get_transport_data())
            data_obj.add_base_classes(bases)
        return self.index[dtype][name]

    def set_nxs_filename(self):
        """ Build the timestamped output .nxs filename and open the file
        (with the mpio driver when running under MPI).
        """
        name = self.index["in_data"].keys()[0]
        filename = os.path.basename(self.index["in_data"][name].
                                    backing_file.filename)
        filename = os.path.splitext(filename)[0]
        filename = os.path.join(self.meta_data.get_meta_data("out_path"),
                                "%s_processed_%s.nxs" %
                                (filename, time.strftime("%Y%m%d%H%M%S")))
        self.meta_data.set_meta_data("nxs_filename", filename)
        if self.meta_data.get_meta_data("mpi") is True:
            self.nxs_file = h5py.File(filename, 'w', driver='mpio',
                                      comm=MPI.COMM_WORLD)
        else:
            self.nxs_file = h5py.File(filename, 'w')

    def remove_dataset(self, data_obj):
        """ Close a dataset's file and delete it from the output index. """
        data_obj.close_file()
        del self.index["out_data"][data_obj.data_info.get_meta_data('name')]

    def clear_data_objects(self):
        """ Reset both the input and output dataset indices. """
        self.index["out_data"] = {}
        self.index["in_data"] = {}

    def clear_out_data_objects(self):
        """ Reset the output dataset index only. """
        self.index["out_data"] = {}

    def merge_out_data_to_in(self):
        """ Promote output datasets not flagged for removal to the input
        index, preserving any existing input meta data.
        """
        for key, data in self.index["out_data"].iteritems():
            if data.remove is False:
                if key in self.index['in_data'].keys():
                    data.meta_data.set_dictionary(
                        self.index['in_data'][key].meta_data.get_dictionary())
                self.index['in_data'][key] = data
        self.index["out_data"] = {}

    def reorganise_datasets(self, out_data_objs, link_type):
        """ Close, prune and promote datasets at the end of a plugin run. """
        out_data_list = self.index["out_data"]
        self.close_unwanted_files(out_data_list)
        self.remove_unwanted_data(out_data_objs)

        self.barrier()
        self.copy_out_data_to_in_data(link_type)

        self.barrier()
        self.clear_out_data_objects()

    def remove_unwanted_data(self, out_data_objs):
        """ Remove output datasets flagged with remove=True. """
        for out_objs in out_data_objs:
            if out_objs.remove is True:
                self.remove_dataset(out_objs)

    def close_unwanted_files(self, out_data_list):
        """ Close input files that are about to be replaced by outputs. """
        for out_objs in out_data_list:
            if out_objs in self.index["in_data"].keys():
                self.index["in_data"][out_objs].close_file()

    def copy_out_data_to_in_data(self, link_type):
        """ Save each output dataset and deep-copy it into the input index. """
        for key in self.index["out_data"]:
            output = self.index["out_data"][key]
            output.save_data(link_type)
            self.index["in_data"][key] = copy.deepcopy(output)

    def set_all_datasets(self, name):
        """ Return the list of all input dataset names.

        NOTE(review): the 'name' parameter is unused; it is kept for
        interface compatibility with existing callers.
        """
        data_names = []
        for key in self.index["in_data"].keys():
            data_names.append(key)
        return data_names

    def barrier(self, communicator=MPI.COMM_WORLD):
        """ Synchronise all MPI processes (no-op when 'mpi' is not set). """
        comm_dict = {'comm': communicator}
        if self.meta_data.get_meta_data('mpi') is True:
            logging.debug("About to hit a barrier")
            comm_dict['comm'].Barrier()
            logging.debug("Past the barrier")

    def log(self, log_tag, log_level=logging.DEBUG):
        """
        Log the contents of the experiment at the specified level
        """
        logging.log(log_level, "Experimental Parameters for %s", log_tag)
        for key, value in self.index["in_data"].iteritems():
            logging.log(log_level, "in data (%s) shape = %s", key,
                        value.get_shape())
        # iterate the *out* data index here: the original looped over
        # in_data twice, so output shapes were never logged
        for key, value in self.index["out_data"].iteritems():
            logging.log(log_level, "out data (%s) shape = %s", key,
                        value.get_shape())
示例#7
0
文件: data.py 项目: FedeMPouzols/Savu
class Data(DataCreate):
    """The Data class dynamically inherits from transport specific data class
    and holds the data array, along with associated information.
    """
    def __init__(self, name, exp):
        super(Data, self).__init__(name)
        self.meta_data = MetaData()
        self.pattern_list = self.__get_available_pattern_list()
        self.data_info = MetaData()
        self.__initialise_data_info(name)
        self._preview = Preview(self)
        self.exp = exp
        self.group_name = None
        self.group = None
        self._plugin_data_obj = None
        self.tomo_raw_obj = None
        self.backing_file = None
        self.data = None
        self.next_shape = None
        self.orig_shape = None

    def __initialise_data_info(self, name):
        """ Initialise entries in the data_info meta data.
        """
        self.data_info.set_meta_data('name', name)
        self.data_info.set_meta_data('data_patterns', {})
        self.data_info.set_meta_data('shape', None)
        self.data_info.set_meta_data('nDims', None)

    def _set_plugin_data(self, plugin_data_obj):
        """ Encapsulate a PluginData object.
        """
        self._plugin_data_obj = plugin_data_obj

    def _clear_plugin_data(self):
        """ Set encapsulated PluginData object to None.
        """
        self._plugin_data_obj = None

    def _get_plugin_data(self):
        """ Get encapsulated PluginData object.
        """
        if self._plugin_data_obj is not None:
            return self._plugin_data_obj
        else:
            raise Exception("There is no PluginData object associated with "
                            "the Data object.")

    def get_preview(self):
        """ Get the Preview instance associated with the data object
        """
        return self._preview

    def _get_transport_data(self):
        """ Import the data transport mechanism

        :returns: instance of data transport
        :rtype: transport_data
        """
        transport = self.exp.meta_data.get_meta_data("transport")
        transport_data = "savu.data.transport_data." + transport + \
                         "_transport_data"
        return cu.import_class(transport_data)

    def __deepcopy__(self, memo):
        """ Copy the data object.
        """
        name = self.data_info.get_meta_data('name')
        return dsu._deepcopy_data_object(self, Data(name, self.exp))

    def get_data_patterns(self):
        """ Get data patterns associated with this data object.

        :returns: A dictionary of associated patterns.
        :rtype: dict
        """
        return self.data_info.get_meta_data('data_patterns')

    def set_shape(self, shape):
        """ Set the dataset shape.
        """
        self.data_info.set_meta_data('shape', shape)
        self.__check_dims()

    def set_original_shape(self, shape):
        """ Record the original (pre-preview) shape and set it as current.
        """
        self.orig_shape = shape
        self.set_shape(shape)

    def get_shape(self):
        """ Get the dataset shape

        :returns: data shape
        :rtype: tuple
        """
        shape = self.data_info.get_meta_data('shape')
        return shape

    def __check_dims(self):
        """ Check the ``shape`` and ``nDims`` entries in the data_info
        meta_data dictionary are equal.
        """
        nDims = self.data_info.get_meta_data("nDims")
        shape = self.data_info.get_meta_data('shape')
        if nDims:
            if len(shape) != nDims:
                error_msg = ("The number of axis labels, %d, does not "
                             "coincide with the number of data "
                             "dimensions %d." % (nDims, len(shape)))
                raise Exception(error_msg)

    def get_name(self):
        """ Get data name.

        :returns: the name associated with the dataset
        :rtype: str
        """
        return self.data_info.get_meta_data('name')

    def __get_available_pattern_list(self):
        """ Get a list of ALL pattern names that are currently allowed in the
        framework.
        """
        pattern_list = dsu.get_available_pattern_types()
        return pattern_list

    def add_pattern(self, dtype, **kwargs):
        """ Add a pattern.

        :params str dtype: The *type* of pattern to add, which can be anything
            from the :const:`savu.data.data_structures.utils.pattern_list`
            :const:`pattern_list`
            :data:`savu.data.data_structures.utils.pattern_list`
            :data:`pattern_list`:
        :keyword tuple core_dir: Dimension indices of core dimensions
        :keyword tuple slice_dir: Dimension indices of slice dimensions
        """
        if dtype in self.pattern_list:
            nDims = 0
            for args in kwargs:
                nDims += len(kwargs[args])
                self.data_info.set_meta_data(['data_patterns', dtype, args],
                                             kwargs[args])

            self.__convert_pattern_directions(dtype)
            if self.get_shape():
                diff = len(self.get_shape()) - nDims
                if diff:
                    pattern = {dtype: self.get_data_patterns()[dtype]}
                    self._add_extra_dims_to_patterns(pattern)
                    nDims += diff
            try:
                if nDims != self.data_info.get_meta_data("nDims"):
                    actualDims = self.data_info.get_meta_data('nDims')
                    err_msg = ("The pattern %s has an incorrect number of "
                               "dimensions: %d required but %d specified." %
                               (dtype, actualDims, nDims))
                    raise Exception(err_msg)
            except KeyError:
                self.data_info.set_meta_data('nDims', nDims)
        else:
            # format the message before raising: the original passed the
            # format arguments straight to Exception, so they were never
            # substituted into the string (and a space was missing)
            raise Exception(
                "The data pattern '%s' does not exist. Please "
                "choose from the following list: \n'%s'"
                % (dtype, str(self.pattern_list)))

    def add_volume_patterns(self, x, y, z):
        """ Adds 3D volume patterns

        :params int x: dimension to be associated with x-axis
        :params int y: dimension to be associated with y-axis
        :params int z: dimension to be associated with z-axis
        """
        self.add_pattern("VOLUME_YZ", **self.__get_dirs_for_volume(y, z, x))
        self.add_pattern("VOLUME_XZ", **self.__get_dirs_for_volume(x, z, y))
        self.add_pattern("VOLUME_XY", **self.__get_dirs_for_volume(x, y, z))

    def __get_dirs_for_volume(self, dim1, dim2, sdir):
        """ Calculate core_dir and slice_dir for a 3D volume pattern.
        """
        all_dims = range(len(self.get_shape()))
        vol_dict = {}
        vol_dict['core_dir'] = (dim1, dim2)
        slice_dir = [sdir]
        # *** need to add this for other patterns
        for ddir in all_dims:
            if ddir not in [dim1, dim2, sdir]:
                slice_dir.append(ddir)
        vol_dict['slice_dir'] = tuple(slice_dir)
        return vol_dict

    def set_axis_labels(self, *args):
        """ Set the axis labels associated with each data dimension.

        :arg str: Each arg should be of the form ``name.unit``. If ``name`` is\
        a data_obj.meta_data entry, it will be output to the final .nxs file.
        """
        self.data_info.set_meta_data('nDims', len(args))
        axis_labels = []
        for arg in args:
            try:
                axis = arg.split('.')
                axis_labels.append({axis[0]: axis[1]})
            except:
                # data arrives here, but that may be an error
                pass
        self.data_info.set_meta_data('axis_labels', axis_labels)

    def get_axis_labels(self):
        """ Get axis labels.

        :returns: Axis labels
        :rtype: list(dict)
        """
        return self.data_info.get_meta_data('axis_labels')

    def find_axis_label_dimension(self, name, contains=False):
        """ Get the dimension of the data associated with a particular
        axis_label.

        :param str name: The name of the axis_label
        :keyword bool contains: Set this flag to true if the name is only part
            of the axis_label name
        :returns: The associated axis number
        :rtype: int
        """
        axis_labels = self.data_info.get_meta_data('axis_labels')
        for i in range(len(axis_labels)):
            if contains is True:
                for names in axis_labels[i].keys():
                    if name in names:
                        return i
            else:
                if name in axis_labels[i].keys():
                    return i
        raise Exception("Cannot find the specifed axis label.")

    def _finalise_patterns(self):
        """ Adds a main axis (fastest changing) to SINOGRAM and PROJECTON
        patterns.
        """
        check = 0
        check += self.__check_pattern('SINOGRAM')
        check += self.__check_pattern('PROJECTION')

        # '==' replaces the original's 'is': integer identity comparison is
        # an implementation detail of the small-int cache
        if check == 2 and len(self.get_shape()) > 2:
            self.__set_main_axis('SINOGRAM')
            self.__set_main_axis('PROJECTION')
        elif check == 1:
            pass

    def __check_pattern(self, pattern_name):
        """ Check if a pattern exists.
        """
        patterns = self.get_data_patterns()
        try:
            patterns[pattern_name]
        except KeyError:
            return 0
        return 1

    def __convert_pattern_directions(self, dtype):
        """ Replace negative indices in pattern kwargs.
        """
        pattern = self.get_data_patterns()[dtype]
        if 'main_dir' in pattern.keys():
            del pattern['main_dir']

        nDims = sum([len(i) for i in pattern.values()])
        for p in pattern:
            ddirs = pattern[p]
            pattern[p] = self.non_negative_directions(ddirs, nDims)

    def non_negative_directions(self, ddirs, nDims):
        """ Replace negative indexing values with positive counterparts.

        :params tuple(int) ddirs: data dimension indices
        :params int nDims: The number of data dimensions
        :returns: non-negative data dimension indices
        :rtype: tuple(int)
        """
        index = [i for i in range(len(ddirs)) if ddirs[i] < 0]
        list_ddirs = list(ddirs)
        for i in index:
            list_ddirs[i] = nDims + ddirs[i]
        return tuple(list_ddirs)

    def __set_main_axis(self, pname):
        """ Set the ``main_dir`` pattern kwarg to the fastest changing
        dimension
        """
        patterns = self.get_data_patterns()
        # '==' replaces the original's string identity comparison
        n1 = 'PROJECTION' if pname == 'SINOGRAM' else 'SINOGRAM'
        d1 = patterns[n1]['core_dir']
        d2 = patterns[pname]['slice_dir']
        tdir = set(d1).intersection(set(d2))

        # this is required when a single sinogram exists in the mm case, and a
        # dimension is added via parameter tuning.
        if not tdir:
            tdir = [d2[0]]

        self.data_info.set_meta_data(['data_patterns', pname, 'main_dir'],
                                     list(tdir)[0])

    def get_axis_label_keys(self):
        """ Get axis_label names

        :returns: A list containing associated axis names for each dimension
        :rtype: list(str)
        """
        axis_labels = self.data_info.get_meta_data('axis_labels')
        axis_label_keys = []
        for labels in axis_labels:
            for key in labels.keys():
                axis_label_keys.append(key)
        return axis_label_keys

    def _get_current_and_next_patterns(self, datasets_lists):
        """ Get the current and next patterns associated with a dataset
        throughout the processing chain.
        """
        current_datasets = datasets_lists[0]
        patterns_list = []
        for current_data in current_datasets['out_datasets']:
            current_name = current_data['name']
            current_pattern = current_data['pattern']
            next_pattern = self.__find_next_pattern(datasets_lists[1:],
                                                    current_name)
            patterns_list.append({
                'current': current_pattern,
                'next': next_pattern
            })
        self.exp.meta_data.set_meta_data('current_and_next', patterns_list)

    def __find_next_pattern(self, datasets_lists, current_name):
        """ Return the pattern of the next plugin that consumes the named
        dataset, or [] if none does.
        """
        next_pattern = []
        for next_data_list in datasets_lists:
            for next_data in next_data_list['in_datasets']:
                if next_data['name'] == current_name:
                    next_pattern = next_data['pattern']
                    return next_pattern
        return next_pattern

    def get_slice_directions(self):
        """ Get pattern slice_dir of pattern currently associated with the
        dataset (if any).

        :returns: the slicing dimensions.
        :rtype: tuple(int)
        """
        return self._get_plugin_data().get_slice_directions()
class Experiment(object):
    """
    One instance of this class is created at the beginning of the
    processing chain and remains until the end.  It holds the current data
    object and a dictionary containing all metadata.
    """

    def __init__(self, options):
        self.meta_data = MetaData(options)
        self.__meta_data_setup(options["process_file"])
        self.index = {"in_data": {}, "out_data": {}, "mapping": {}}
        self.nxs_file = None

    def get_meta_data(self, entry):
        """ Get the meta data dictionary. """
        return self.meta_data.get_meta_data(entry)

    def __meta_data_setup(self, process_file):
        """ Create the plugin list, either from the 'plugin_list' meta data
        entry (test runs) or by parsing the process file.
        """
        self.meta_data.plugin_list = PluginList()

        try:
            rtype = self.meta_data.get_meta_data('run_type')
            # '==' compares values; the original's 'is' compared object
            # identity, which is unreliable for strings
            if rtype == 'test':
                self.meta_data.plugin_list.plugin_list = \
                    self.meta_data.get_meta_data('plugin_list')
            else:
                raise Exception('the run_type is unknown in Experiment class')
        except KeyError:
            self.meta_data.plugin_list._populate_plugin_list(process_file)

    def create_data_object(self, dtype, name):
        """ Create a data object.

        Plugin developers should apply this method in loaders only.

        :params str dtype: either "in_data" or "out_data".
        """
        bases = []
        try:
            self.index[dtype][name]
        except KeyError:
            self.index[dtype][name] = Data(name, self)
            data_obj = self.index[dtype][name]
            bases.append(data_obj._get_transport_data())
            cu.add_base_classes(data_obj, bases)
        return self.index[dtype][name]

    def _set_nxs_filename(self):
        """ Build the output .nxs filename from the output path and open the
        file (with the mpio driver when running under MPI).
        """
        folder = self.meta_data.get_meta_data('out_path')
        fname = os.path.basename(folder.split('_')[-1]) + '_processed.nxs'
        filename = os.path.join(folder, fname)
        self.meta_data.set_meta_data("nxs_filename", filename)

        if self.meta_data.get_meta_data("mpi") is True:
            self.nxs_file = h5py.File(filename, 'w', driver='mpio',
                                      comm=MPI.COMM_WORLD)
        else:
            self.nxs_file = h5py.File(filename, 'w')

    def __remove_dataset(self, data_obj):
        """ Close a dataset's file and delete it from the output index. """
        data_obj._close_file()
        del self.index["out_data"][data_obj.data_info.get_meta_data('name')]

    def _clear_data_objects(self):
        """ Reset both the input and output dataset indices. """
        self.index["out_data"] = {}
        self.index["in_data"] = {}

    def _merge_out_data_to_in(self):
        """ Promote output datasets not flagged for removal to the input
        index, preserving any existing input meta data.
        """
        for key, data in self.index["out_data"].iteritems():
            if data.remove is False:
                if key in self.index['in_data'].keys():
                    data.meta_data._set_dictionary(
                        self.index['in_data'][key].meta_data.get_dictionary())
                self.index['in_data'][key] = data
        self.index["out_data"] = {}

    def _reorganise_datasets(self, out_data_objs, link_type):
        """ Close, prune and promote datasets at the end of a plugin run. """
        out_data_list = self.index["out_data"]
        self.__close_unwanted_files(out_data_list)
        self.__remove_unwanted_data(out_data_objs)

        self._barrier()
        self.__copy_out_data_to_in_data(link_type)

        self._barrier()
        self.index['out_data'] = {}

    def __remove_unwanted_data(self, out_data_objs):
        """ Remove output datasets flagged with remove=True. """
        for out_objs in out_data_objs:
            if out_objs.remove is True:
                self.__remove_dataset(out_objs)

    def __close_unwanted_files(self, out_data_list):
        """ Close input files that are about to be replaced by outputs. """
        for out_objs in out_data_list:
            if out_objs in self.index["in_data"].keys():
                self.index["in_data"][out_objs]._close_file()

    def __copy_out_data_to_in_data(self, link_type):
        """ Save each output dataset and deep-copy it into the input index. """
        for key in self.index["out_data"]:
            output = self.index["out_data"][key]
            output._save_data(link_type)
            self.index["in_data"][key] = copy.deepcopy(output)

    def _set_all_datasets(self, name):
        """ Return the list of all input dataset names.

        NOTE(review): the 'name' parameter is unused; it is kept for
        interface compatibility with existing callers.
        """
        data_names = []
        for key in self.index["in_data"].keys():
            data_names.append(key)
        return data_names

    def _barrier(self, communicator=MPI.COMM_WORLD):
        """ Synchronise all MPI processes (no-op when 'mpi' is not set). """
        comm_dict = {'comm': communicator}
        if self.meta_data.get_meta_data('mpi') is True:
            logging.debug("About to hit a _barrier %s", comm_dict)
            comm_dict['comm'].barrier()
            logging.debug("Past the _barrier")

    def log(self, log_tag, log_level=logging.DEBUG):
        """
        Log the contents of the experiment at the specified level
        """
        logging.log(log_level, "Experimental Parameters for %s", log_tag)
        for key, value in self.index["in_data"].iteritems():
            logging.log(log_level, "in data (%s) shape = %s", key,
                        value.get_shape())
        # iterate the *out* data index here: the original looped over
        # in_data twice, so output shapes were never logged
        for key, value in self.index["out_data"].iteritems():
            logging.log(log_level, "out data (%s) shape = %s", key,
                        value.get_shape())
示例#9
0
class PluginData(object):
    """ Holds the plugin-facing view of a Data object: the pattern in use,
    its shape, the frame chunk size, and any fixed (sliced-out) dimensions.
    """

    def __init__(self, data_obj):
        self.data_obj = data_obj
        self.data_obj.set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = None
        # this flag determines which data is passed. If false then just the
        # data, if true then all data including dark and flat fields.
        self.selected_data = False
        self.shape = None
        self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []

    def get_total_frames(self):
        """ Return the total number of frames: the product of the lengths of
        the slice dimensions. """
        temp = 1
        slice_dir = \
            self.data_obj.get_data_patterns()[
                self.get_pattern_name()]["slice_dir"]
        for tslice in slice_dir:
            temp *= self.data_obj.get_shape()[tslice]
        return temp

    def set_pattern(self, name):
        """ Select the named pattern and initialise its slice directions. """
        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set_meta_data("name", name)
        self.meta_data.set_meta_data("core_dir", pattern['core_dir'])
        self.set_slice_directions()

    def get_pattern_name(self):
        """ Return the current pattern name.

        :raises Exception: if no pattern has been set.
        """
        name = self.meta_data.get_meta_data("name")
        if name is not None:
            return name
        else:
            raise Exception("The pattern name has not been set.")

    def get_pattern(self):
        """ Return {pattern_name: pattern_dict} for the current pattern. """
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def set_shape(self):
        """ Compute the plugin-facing shape from the core dimensions,
        inserting a frame-chunk dimension when chunking is active. """
        core_dir = self.get_core_directions()
        slice_dir = self.get_slice_directions()
        dirs = list(set(core_dir + (slice_dir[0],)))
        slice_idx = dirs.index(slice_dir[0])
        shape = []
        for core in set(core_dir):
            shape.append(self.data_obj.get_shape()[core])
        self.set_core_shape(tuple(shape))
        if self.get_frame_chunk() > 1:
            shape.insert(slice_idx, self.get_frame_chunk())
        self.shape = tuple(shape)

    def get_shape(self):
        """ Return the plugin-facing data shape. """
        return self.shape

    def set_core_shape(self, shape):
        """ Store the shape of the core dimensions only. """
        self.core_shape = shape

    def get_core_shape(self):
        """ Return the shape of the core dimensions only. """
        return self.core_shape

    def check_dimensions(self, indices, core_dir, slice_dir, nDims):
        """ Exit with an error message if the supplied indices/dimensions
        are inconsistent.

        The original used 'is not' for these integer comparisons, which
        tests object identity and fails for values outside CPython's
        small-int cache; '!=' is the correct operator.
        """
        if len(indices) != len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")

        if (len(core_dir)+len(slice_dir)) != nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def set_slice_directions(self):
        """ Copy the current pattern's slice directions into the meta data. """
        slice_dirs = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['slice_dir']
        self.meta_data.set_meta_data('slice_dir', slice_dirs)

    def get_slice_directions(self):
        """ Return the current slice directions. """
        return self.meta_data.get_meta_data('slice_dir')

    def get_slice_dimension(self):
        """
        Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.get_core_directions()
        slice_dir = self.get_slice_directions()[0]
        return list(set(core_dirs + (slice_dir,))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """
        Return the dimension of the data in the plugin that has the specified
        axis label.
        """
        label_dim = \
            self.data_obj.find_axis_label_dimension(label, contains=contains)
        plugin_dims = self.get_core_directions()
        if self.get_frame_chunk() > 1:
            plugin_dims += (self.get_slice_directions()[0],)
        return list(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):
        """
        Reorder the slice directions.  The fastest changing slice direction
        will always be the first one. The input param is a tuple stating the
        desired order of slicing directions relative to the current order.
        """
        slice_dirs = self.get_slice_directions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specifed.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        self.get_current_pattern()['slice_dir'] = new_slice_dirs

    def get_core_directions(self):
        """ Return the core dimensions of the current pattern. """
        core_dir = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['core_dir']
        return core_dir

    def set_fixed_directions(self, dims, values):
        """ Fix the given slice dimensions at the given values, collapsing
        them to length 1 in the data shape.

        :raises Exception: if a requested dimension is not a slice direction.
        """
        slice_dirs = self.get_slice_directions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set_meta_data("fixed_directions", dims)
        self.meta_data.set_meta_data("fixed_directions_values", values)
        self.set_slice_directions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        self.set_shape()

    def get_fixed_directions(self):
        """ Return [fixed_dims, fixed_values] ([[], []] when none are set). """
        fixed = []
        values = []
        if 'fixed_directions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get_meta_data("fixed_directions")
            values = self.meta_data.get_meta_data("fixed_directions_values")
        return [fixed, values]

    def set_frame_chunk(self, nFrames):
        # number of frames to process at a time
        self.meta_data.set_meta_data("nFrames", nFrames)

    def get_frame_chunk(self):
        """ Return the number of frames to process at a time. """
        return self.meta_data.get_meta_data("nFrames")

    def get_index(self, indices):
        """ Build a tuple of slice objects addressing one frame, given one
        index per slice dimension. """
        shape = self.get_shape()
        nDims = len(shape)
        name = self.get_current_pattern_name()
        ddirs = self.get_data_patterns()
        core_dir = ddirs[name]["core_dir"]
        slice_dir = ddirs[name]["slice_dir"]

        self.check_dimensions(indices, core_dir, slice_dir, nDims)

        index = [slice(None)]*nDims
        count = 0
        for tdir in slice_dir:
            index[tdir] = slice(indices[count], indices[count]+1, 1)
            count += 1

        return tuple(index)

    def plugin_data_setup(self, pattern_name, chunk):
        """ Initialise the pattern, frame chunk and shape in one call. """
        self.set_pattern(pattern_name)
        self.set_frame_chunk(chunk)
        self.set_shape()

    def set_temp_pad_dict(self, pad_dict):
        """ Store a temporary padding dictionary. """
        self.meta_data.set_meta_data('temp_pad_dict', pad_dict)

    def get_temp_pad_dict(self):
        """ Return the temporary padding dictionary, or None if unset. """
        if 'temp_pad_dict' in self.meta_data.get_dictionary().keys():
            return self.meta_data.get_dictionary()['temp_pad_dict']

    def delete_temp_pad_dict(self):
        """ Remove the temporary padding dictionary. """
        del self.meta_data.get_dictionary()['temp_pad_dict']
示例#10
0
class Data(object):
    """
    The Data class dynamically inherits from relevant data structure classes
    at runtime and holds the data array, along with the meta data describing
    its patterns, axis labels, shape and preview state.
    """

    def __init__(self, name, exp):
        self.meta_data = MetaData()
        self.pattern_list = self.set_available_pattern_list()
        self.data_info = MetaData()
        self.initialise_data_info(name)
        self.exp = exp
        self.group_name = None
        self.group = None
        self._plugin_data_obj = None
        self.tomo_raw_obj = None
        self.data_mapping = None
        self.variable_length_flag = False
        self.dtype = None
        # flag: remove this dataset from the plugin chain when True
        self.remove = False
        self.backing_file = None
        self.data = None
        self.next_shape = None
        self.mapping = None
        self.map_dim = []
        # shape to restore when a preview is unset
        self.revert_shape = None

    def initialise_data_info(self, name):
        """ Initialise entries in the data_info meta data. """
        self.data_info.set_meta_data('name', name)
        self.data_info.set_meta_data('data_patterns', {})
        self.data_info.set_meta_data('shape', None)
        self.data_info.set_meta_data('nDims', None)

    def set_plugin_data(self, plugin_data_obj):
        """ Encapsulate a PluginData object. """
        self._plugin_data_obj = plugin_data_obj

    def clear_plugin_data(self):
        """ Set the encapsulated PluginData object to None. """
        self._plugin_data_obj = None

    def get_plugin_data(self):
        """ Get the encapsulated PluginData object.

        :raises Exception: if no PluginData object is associated.
        """
        if self._plugin_data_obj is not None:
            return self._plugin_data_obj
        else:
            raise Exception("There is no PluginData object associated with "
                            "the Data object.")

    def set_tomo_raw(self, tomo_raw_obj):
        """ Encapsulate a TomoRaw object. """
        self.tomo_raw_obj = tomo_raw_obj

    def clear_tomo_raw(self):
        """ Set the encapsulated TomoRaw object to None. """
        self.tomo_raw_obj = None

    def get_tomo_raw(self):
        """ Get the encapsulated TomoRaw object.

        :raises Exception: if no TomoRaw object is associated.
        """
        if self.tomo_raw_obj is not None:
            return self.tomo_raw_obj
        else:
            raise Exception("There is no TomoRaw object associated with "
                            "the Data object.")

    def get_transport_data(self):
        """ Import and return the transport-specific data class. """
        transport = self.exp.meta_data.get_meta_data("transport")
        # build the module path for the transport mechanism, e.g.
        # savu.data.transport_data.hdf5_transport_data
        # (a stray no-op string statement was removed here)
        transport_data = "savu.data.transport_data." + transport + \
                         "_transport_data"
        return import_class(transport_data)

    def __deepcopy__(self, memo):
        """ Deep-copy this data object.

        Meta data, patterns and scalar state are copied; the experiment,
        data array, backing file and encapsulated objects are shared.
        """
        name = self.data_info.get_meta_data('name')
        new_obj = Data(name, self.exp)
        new_obj.add_base_classes(self.get_transport_data())
        new_obj.meta_data = self.meta_data
        new_obj.pattern_list = copy.deepcopy(self.pattern_list)
        new_obj.data_info = copy.deepcopy(self.data_info)
        new_obj.exp = self.exp
        new_obj._plugin_data_obj = self._plugin_data_obj
        new_obj.tomo_raw_obj = self.tomo_raw_obj
        new_obj.data_mapping = self.data_mapping
        new_obj.variable_length_flag = copy.deepcopy(self.variable_length_flag)
        new_obj.dtype = copy.deepcopy(self.dtype)
        new_obj.remove = copy.deepcopy(self.remove)
        new_obj.group_name = self.group_name
        new_obj.group = self.group
        new_obj.backing_file = self.backing_file
        new_obj.data = self.data
        new_obj.next_shape = copy.deepcopy(self.next_shape)
        new_obj.mapping = copy.deepcopy(self.mapping)
        new_obj.map_dim = copy.deepcopy(self.map_dim)
        # bug fix: previously copied self.map_dim into revert_shape
        new_obj.revert_shape = copy.deepcopy(self.revert_shape)
        return new_obj

    def add_base(self, ExtraBase):
        """ Add ExtraBase as a runtime base class of this instance. """
        cls = self.__class__
        self.__class__ = cls.__class__(cls.__name__, (cls, ExtraBase), {})
        # NOTE(review): this initialises a throwaway ExtraBase instance, not
        # self; ExtraBase.__init__(self) may be intended — confirm before
        # changing, as transport classes may rely on the current behaviour.
        ExtraBase().__init__()

    def add_base_classes(self, bases):
        """ Add each class in *bases* (a class or list of classes) as a
        runtime base of this instance. """
        bases = bases if isinstance(bases, list) else [bases]
        for base in bases:
            self.add_base(base)

    def external_link(self):
        """ Return an h5py ExternalLink to this dataset's backing file. """
        return h5py.ExternalLink(self.backing_file.filename, self.group_name)

    def create_dataset(self, *args, **kwargs):
        """
        Set up required information when an output dataset has been created by
        a plugin.

        Either pass a single Data object to copy from, or supply 'shape' and
        'axis_labels' keyword arguments (optionally 'patterns').
        """
        self.dtype = kwargs.get('dtype', np.float32)
        # remove from the plugin chain
        self.remove = kwargs.get('remove', False)
        # '==' not 'is': identity comparison of int literals is
        # implementation-dependent
        if len(args) == 1:
            self.copy_dataset(args[0], removeDim=kwargs.get('removeDim', []))
            if args[0].tomo_raw_obj:
                self.set_tomo_raw(copy.deepcopy(args[0].get_tomo_raw()))
                self.get_tomo_raw().data_obj = self
        else:
            try:
                shape = kwargs['shape']
                self.create_axis_labels(kwargs['axis_labels'])
            except KeyError:
                raise Exception("Please state axis_labels and shape when "
                                "creating a new dataset")
            self.set_new_dataset_shape(shape)

            if 'patterns' in kwargs:
                self.copy_patterns(kwargs['patterns'])
        self.set_preview([])

    def set_new_dataset_shape(self, shape):
        """ Set the dataset shape from a Data object, a dict (variable
        length) or a plain tuple. """
        if isinstance(shape, Data):
            self.find_and_set_shape(shape)
        elif type(shape) is dict:
            self.set_variable_flag()
            self.set_shape((shape[shape.keys()[0]] + ('var',)))
        else:
            pData = self.get_plugin_data()
            self.set_shape(shape + tuple(pData.extra_dims))
            if 'var' in shape:
                self.set_variable_flag()

    def copy_patterns(self, copy_data):
        """ Copy data patterns from a Data object, or from a dict mapping a
        Data object to a list of pattern names (optionally with dimensions
        to remove, e.g. 'SINOGRAM.0'). """
        if isinstance(copy_data, Data):
            patterns = copy_data.get_data_patterns()
        else:
            data = copy_data.keys()[0]
            pattern_list = copy_data[data]

            all_patterns = data.get_data_patterns()
            if len(pattern_list[0].split('.')) > 1:
                patterns = self.copy_patterns_removing_dimensions(
                    pattern_list, all_patterns, len(data.get_shape()))
            else:
                patterns = {}
                for pattern in pattern_list:
                    patterns[pattern] = all_patterns[pattern]
        self.set_data_patterns(patterns)

    def copy_patterns_removing_dimensions(self, pattern_list, all_patterns,
                                          nDims):
        """ Copy the named patterns while removing the stated dimensions
        ('name.d1,d2' syntax; '*' copies all patterns). Patterns left with
        no dimensions in a direction are dropped. """
        copy_patterns = {}
        for new_pattern in pattern_list:
            name, all_dims = new_pattern.split('.')
            # '==' not 'is': string identity depends on interning
            if name == '*':
                copy_patterns = all_patterns
            else:
                copy_patterns[name] = all_patterns[name]
            dims = tuple(map(int, all_dims.split(',')))
            dims = self.non_negative_directions(dims, nDims=nDims)

        patterns = {}
        for name, pattern_dict in copy_patterns.iteritems():
            empty_flag = False
            for ddir in pattern_dict:
                s_dims = self.non_negative_directions(
                    pattern_dict[ddir], nDims=nDims)
                new_dims = tuple([sd for sd in s_dims if sd not in dims])
                pattern_dict[ddir] = new_dims
                if not new_dims:
                    empty_flag = True
            if empty_flag is False:
                patterns[name] = pattern_dict
        return patterns

    def copy_dataset(self, copy_data, **kwargs):
        """ Copy patterns, labels and shape from *copy_data*, following its
        mapping dataset if one exists. """
        if copy_data.mapping:
            # copy label entries from meta data
            map_data = self.exp.index['mapping'][copy_data.get_name()]
            map_mData = map_data.meta_data
            map_axis_labels = map_data.data_info.get_meta_data('axis_labels')
            for axis_label in map_axis_labels:
                if axis_label.keys()[0] in map_mData.get_dictionary().keys():
                    map_label = map_mData.get_meta_data(axis_label.keys()[0])
                    copy_data.meta_data.set_meta_data(axis_label.keys()[0],
                                                      map_label)
            copy_data = map_data
        patterns = copy.deepcopy(copy_data.get_data_patterns())
        self.copy_labels(copy_data)
        self.find_and_set_shape(copy_data)
        self.set_data_patterns(patterns)

    def create_axis_labels(self, axis_labels):
        """ Create axis labels from a Data object, a {data: amendments}
        dict, or a list of 'name.unit' strings. """
        if isinstance(axis_labels, Data):
            self.copy_labels(axis_labels)
        elif isinstance(axis_labels, dict):
            data = axis_labels.keys()[0]
            self.copy_labels(data)
            self.amend_axis_labels(axis_labels[data])
        else:
            self.set_axis_labels(*axis_labels)
            # if parameter tuning
            if self.get_plugin_data().multi_params_dict:
                self.add_extra_dims_labels()

    def copy_labels(self, copy_data):
        """ Copy nDims and axis labels from *copy_data*. """
        nDims = copy.copy(copy_data.data_info.get_meta_data('nDims'))
        axis_labels = \
            copy.copy(copy_data.data_info.get_meta_data('axis_labels'))
        self.data_info.set_meta_data('nDims', nDims)
        self.data_info.set_meta_data('axis_labels', axis_labels)
        # if parameter tuning
        if self.get_plugin_data().multi_params_dict:
            self.add_extra_dims_labels()

    def add_extra_dims_labels(self):
        """ Append axis labels for parameter-tuning (multi-params) dims and
        record the parameter values in the meta data. """
        params_dict = self.get_plugin_data().multi_params_dict
        # add multi_params axis labels from dictionary in pData
        nDims = self.data_info.get_meta_data('nDims')
        axis_labels = self.data_info.get_meta_data('axis_labels')
        axis_labels.extend([0]*len(params_dict))
        for key, value in params_dict.iteritems():
            title = value['label'].encode('ascii', 'ignore')
            name, unit = title.split('.')
            axis_labels[nDims + key] = {name: unit}
            # add parameter values to the meta_data
            self.meta_data.set_meta_data(name, np.array(value['values']))
        # NOTE(review): self.extra_dims is not set anywhere in this class —
        # presumably populated by a runtime base class; confirm.
        self.data_info.set_meta_data('nDims', nDims + len(self.extra_dims))
        self.data_info.set_meta_data('axis_labels', axis_labels)

    def amend_axis_labels(self, *args):
        """ Amend axis labels: 'dim' removes a label, 'dim.name.unit'
        replaces (negative dim) or inserts/overwrites (positive dim). """
        axis_labels = self.data_info.get_meta_data('axis_labels')
        removed_dims = 0
        for arg in args[0]:
            label = arg.split('.')
            # '==' not 'is': identity comparison of int literals is
            # implementation-dependent
            if len(label) == 1:
                del axis_labels[int(label[0]) + removed_dims]
                removed_dims += 1
                self.data_info.set_meta_data(
                    'nDims', self.data_info.get_meta_data('nDims') - 1)
            else:
                if int(label[0]) < 0:
                    axis_labels[int(label[0]) + removed_dims] = \
                        {label[1]: label[2]}
                else:
                    if int(label[0]) < self.data_info.get_meta_data('nDims'):
                        axis_labels[int(label[0])] = {label[1]: label[2]}
                    else:
                        axis_labels.insert(int(label[0]), {label[1]: label[2]})

    def set_data_patterns(self, patterns):
        """ Store the data patterns, padding slice dims for extra dims. """
        self.add_extra_dims_to_patterns(patterns)
        self.data_info.set_meta_data('data_patterns', patterns)

    def add_extra_dims_to_patterns(self, patterns):
        """ Append any dimensions missing from each pattern to its
        slice_dir entry (in place). """
        all_dims = range(len(self.get_shape()))
        for p in patterns:
            pDims = patterns[p]['core_dir'] + patterns[p]['slice_dir']
            for dim in all_dims:
                if dim not in pDims:
                    patterns[p]['slice_dir'] += (dim,)

    def get_data_patterns(self):
        """ Return the data patterns dictionary. """
        return self.data_info.get_meta_data('data_patterns')

    def set_shape(self, shape):
        """ Set the dataset shape and validate it against nDims. """
        self.data_info.set_meta_data('shape', shape)
        self.check_dims()

    def get_shape(self):
        """ Return the dataset shape. """
        shape = self.data_info.get_meta_data('shape')
        return shape

    def set_preview(self, preview_list, **kwargs):
        """ Apply a preview (data reduction) to the dataset.

        :param list preview_list: one 'start:stop:step:chunk' string per
            dimension; an empty list resets to the full shape.
        :keyword revert: shape to restore on unset_preview.
        """
        self.revert_shape = kwargs.get('revert', self.revert_shape)
        shape = self.get_shape()
        if preview_list:
            starts, stops, steps, chunks = \
                self.get_preview_indices(preview_list)
            shape_change = True
        else:
            starts, stops, steps, chunks = \
                [[0]*len(shape), shape, [1]*len(shape), [1]*len(shape)]
            shape_change = False
        self.set_starts_stops_steps(starts, stops, steps, chunks,
                                    shapeChange=shape_change)

    def unset_preview(self):
        """ Remove the preview and restore the reverted shape. """
        self.set_preview([])
        self.set_shape(self.revert_shape)
        self.revert_shape = None

    def set_starts_stops_steps(self, starts, stops, steps, chunks,
                               shapeChange=True):
        """ Store slicing parameters and recompute the reduced shape when
        the shape changes or a mapping exists. """
        self.data_info.set_meta_data('starts', starts)
        self.data_info.set_meta_data('stops', stops)
        self.data_info.set_meta_data('steps', steps)
        self.data_info.set_meta_data('chunks', chunks)
        if shapeChange or self.mapping:
            self.set_reduced_shape(starts, stops, steps, chunks)

    def get_preview_indices(self, preview_list):
        """ Convert 'start:stop:step:chunk' strings into four lists. """
        starts = len(preview_list)*[None]
        stops = len(preview_list)*[None]
        steps = len(preview_list)*[None]
        chunks = len(preview_list)*[None]
        for i in range(len(preview_list)):
            starts[i], stops[i], steps[i], chunks[i] = \
                self.convert_indices(preview_list[i].split(':'), i)
        return starts, stops, steps, chunks

    def convert_indices(self, idx, dim):
        """ Evaluate preview index expressions for dimension *dim* and
        convert negatives to absolute indices. """
        shape = self.get_shape()
        # mid/end (and midmap/endmap) are intentionally defined here: the
        # eval'd preview expressions below may reference them by name
        mid = shape[dim]/2
        end = shape[dim]

        if self.mapping:
            map_shape = self.exp.index['mapping'][self.get_name()].get_shape()
            midmap = map_shape[dim]/2
            endmap = map_shape[dim]

        # SECURITY: eval on preview strings — only safe for trusted process
        # lists; do not expose to untrusted input
        idx = [eval(equ) for equ in idx]
        idx = [idx[i] if idx[i] > -1 else shape[dim]+1+idx[i] for i in
               range(len(idx))]
        return idx

    def get_starts_stops_steps(self):
        """ Return the stored starts, stops, steps and chunks lists. """
        starts = self.data_info.get_meta_data('starts')
        stops = self.data_info.get_meta_data('stops')
        steps = self.data_info.get_meta_data('steps')
        chunks = self.data_info.get_meta_data('chunks')
        return starts, stops, steps, chunks

    def set_reduced_shape(self, starts, stops, steps, chunks):
        """ Recompute and set the shape after preview reduction, saving the
        original shape in the meta data. """
        orig_shape = self.get_shape()
        self.data_info.set_meta_data('orig_shape', orig_shape)
        new_shape = []
        for dim in range(len(starts)):
            new_shape.append(np.prod((self.get_slice_dir_matrix(dim).shape)))
        self.set_shape(tuple(new_shape))

        # reduce shape of mapping data if it exists
        if self.mapping:
            self.set_mapping_reduced_shape(orig_shape, new_shape,
                                           self.get_name())

    def set_mapping_reduced_shape(self, orig_shape, new_shape, name):
        """ Reduce the shape of the associated mapping dataset to match the
        previewed shape (assumes a single extra mapping dimension). """
        map_obj = self.exp.index['mapping'][name]
        map_shape = np.array(map_obj.get_shape())
        diff = np.array(orig_shape) - map_shape[:len(orig_shape)]
        not_map_dim = np.where(diff == 0)[0]
        map_dim = np.where(diff != 0)[0]
        self.map_dim = map_dim
        map_obj.data_info.set_meta_data('full_map_dim_len', map_shape[map_dim])
        map_shape[not_map_dim] = np.array(new_shape)[not_map_dim]

        # assuming only one extra dimension added for now
        starts, stops, steps, chunks = self.get_starts_stops_steps()
        start = starts[map_dim] % map_shape[map_dim]
        stop = min(stops[map_dim], map_shape[map_dim])

        temp = len(np.arange(start, stop, steps[map_dim]))*chunks[map_dim]
        map_shape[len(orig_shape)] = np.ceil(new_shape[map_dim]/temp)
        map_shape[map_dim] = new_shape[map_dim]/map_shape[len(orig_shape)]
        map_obj.data_info.set_meta_data('map_dim_len', map_shape[map_dim])
        self.exp.index['mapping'][name].set_shape(tuple(map_shape))

    def find_and_set_shape(self, data):
        """ Set this shape from *data*'s shape plus any extra dims. """
        pData = self.get_plugin_data()
        new_shape = copy.copy(data.get_shape()) + tuple(pData.extra_dims)
        self.set_shape(new_shape)

    def set_variable_flag(self):
        """ Flag this dataset as having variable-length dimensions. """
        self.variable_length_flag = True

    def get_variable_flag(self):
        """ Return True if the dataset has variable-length dimensions. """
        return self.variable_length_flag

    def set_variable_array_length(self, var_size):
        """ Resolve the variable ('var') dimensions to concrete sizes,
        storing the result in next_shape. """
        var_size = var_size if isinstance(var_size, list) else [var_size]
        shape = list(self.get_shape())
        count = 0
        for i in range(len(shape)):
            if isinstance(shape[i], str):
                shape[i] = var_size[count]
                count += 1
        self.next_shape = tuple(shape)

    def check_dims(self):
        """ Verify that the shape length matches nDims (fixed-length data
        only).

        :raises Exception: on a mismatch.
        """
        nDims = self.data_info.get_meta_data("nDims")
        shape = self.data_info.get_meta_data('shape')
        if nDims:
            if self.get_variable_flag() is False:
                if len(shape) != nDims:
                    error_msg = ("The number of axis labels, %d, does not "
                                 "coincide with the number of data "
                                 "dimensions %d." % (nDims, len(shape)))
                    raise Exception(error_msg)

    def set_name(self, name):
        """ Set the dataset name. """
        self.data_info.set_meta_data('name', name)

    def get_name(self):
        """ Return the dataset name. """
        return self.data_info.get_meta_data('name')

    def set_data_params(self, pattern, chunk_size, **kwargs):
        """ Set the current pattern and frame chunk size. """
        self.set_current_pattern_name(pattern)
        self.set_nFrames(chunk_size)

    def set_available_pattern_list(self):
        """ Return the list of recognised pattern names. """
        pattern_list = ["SINOGRAM",
                        "PROJECTION",
                        "VOLUME_YZ",
                        "VOLUME_XZ",
                        "VOLUME_XY",
                        "VOLUME_3D",
                        "SPECTRUM",
                        "DIFFRACTION",
                        "CHANNEL",
                        "SPECTRUM_STACK",
                        "PROJECTION_STACK",
                        "METADATA"]
        return pattern_list

    def add_pattern(self, dtype, **kwargs):
        """ Add a data pattern (e.g. core_dir/slice_dir keyword tuples) and
        validate its dimension count.

        :raises Exception: if the pattern name is unknown or the dimension
            count conflicts with nDims.
        """
        if dtype in self.pattern_list:
            nDims = 0
            for args in kwargs:
                nDims += len(kwargs[args])
                self.data_info.set_meta_data(['data_patterns', dtype, args],
                                             kwargs[args])
            self.convert_pattern_directions(dtype)
            if self.get_shape():
                diff = len(self.get_shape()) - nDims
                if diff:
                    pattern = {dtype: self.get_data_patterns()[dtype]}
                    self.add_extra_dims_to_patterns(pattern)
                    nDims += diff
            try:
                if nDims != self.data_info.get_meta_data("nDims"):
                    actualDims = self.data_info.get_meta_data('nDims')
                    err_msg = ("The pattern %s has an incorrect number of "
                               "dimensions: %d required but %d specified."
                               % (dtype, actualDims, nDims))
                    raise Exception(err_msg)
            except KeyError:
                self.data_info.set_meta_data('nDims', nDims)
        else:
            # bug fix: the message was never %-formatted (args were passed
            # to Exception instead of substituted)
            raise Exception("The data pattern '%s' does not exist. Please "
                            "choose from the following list: \n'%s'"
                            % (dtype, str(self.pattern_list)))

    def add_volume_patterns(self, x, y, z):
        """ Add the three orthogonal VOLUME patterns for dims x, y, z. """
        self.add_pattern("VOLUME_YZ", **self.get_dirs_for_volume(y, z, x))
        self.add_pattern("VOLUME_XZ", **self.get_dirs_for_volume(x, z, y))
        self.add_pattern("VOLUME_XY", **self.get_dirs_for_volume(x, y, z))

    def get_dirs_for_volume(self, dim1, dim2, sdir):
        """ Build the core/slice direction dict for a volume pattern. """
        all_dims = range(len(self.get_shape()))
        vol_dict = {}
        vol_dict['core_dir'] = (dim1, dim2)
        slice_dir = [sdir]
        # *** need to add this for other patterns
        for ddir in all_dims:
            if ddir not in [dim1, dim2, sdir]:
                slice_dir.append(ddir)
        vol_dict['slice_dir'] = tuple(slice_dir)
        return vol_dict

    def set_axis_labels(self, *args):
        """ Set axis labels from 'name.unit' strings; nDims = len(args). """
        self.data_info.set_meta_data('nDims', len(args))
        axis_labels = []
        for arg in args:
            try:
                axis = arg.split('.')
                axis_labels.append({axis[0]: axis[1]})
            except:
                # data arrives here, but that may be an error
                pass
        self.data_info.set_meta_data('axis_labels', axis_labels)

    def find_axis_label_dimension(self, name, contains=False):
        """ Return the dimension whose axis label equals (or contains, if
        *contains* is True) *name*.

        :raises Exception: if no matching label is found.
        """
        axis_labels = self.data_info.get_meta_data('axis_labels')
        for i in range(len(axis_labels)):
            if contains is True:
                for names in axis_labels[i].keys():
                    if name in names:
                        return i
            else:
                if name in axis_labels[i].keys():
                    return i
        raise Exception("Cannot find the specifed axis label.")

    def finalise_patterns(self):
        """ Set main axes when both SINOGRAM and PROJECTION exist. """
        check = 0
        check += self.check_pattern('SINOGRAM')
        check += self.check_pattern('PROJECTION')
        # '==' not 'is': identity comparison of int literals is
        # implementation-dependent
        if check == 2:
            self.set_main_axis('SINOGRAM')
            self.set_main_axis('PROJECTION')
        elif check == 1:
            pass

    def check_pattern(self, pattern_name):
        """ Return 1 if the pattern exists, otherwise 0. """
        patterns = self.get_data_patterns()
        try:
            patterns[pattern_name]
        except KeyError:
            return 0
        return 1

    def convert_pattern_directions(self, dtype):
        """ Convert negative direction indices of a pattern in place. """
        pattern = self.get_data_patterns()[dtype]
        nDims = sum([len(i) for i in pattern.values()])
        for p in pattern:
            ddirs = pattern[p]
            pattern[p] = self.non_negative_directions(ddirs, nDims)

    def non_negative_directions(self, ddirs, nDims):
        """ Return *ddirs* with negative indices mapped to nDims + index. """
        index = [i for i in range(len(ddirs)) if ddirs[i] < 0]
        list_ddirs = list(ddirs)
        for i in index:
            list_ddirs[i] = nDims + ddirs[i]
        return tuple(list_ddirs)

    def check_direction(self, tdir, dname):
        """ Validate a direction index.

        :raises TypeError: if *tdir* is not an integer.
        :raises Exception: if no patterns have been added yet.
        """
        if not isinstance(tdir, int):
            raise TypeError('The direction should be an integer.')

        patterns = self.get_data_patterns()
        if not patterns:
            raise Exception("Please add available patterns before setting the"
                            " direction ", dname)

    def set_main_axis(self, pname):
        """ Record the main direction of *pname*: the dimension that is a
        core dir of the complementary pattern and a slice dir of pname. """
        patterns = self.get_data_patterns()
        # '==' not 'is': string identity depends on interning
        n1 = 'PROJECTION' if pname == 'SINOGRAM' else 'SINOGRAM'
        d1 = patterns[n1]['core_dir']
        d2 = patterns[pname]['slice_dir']
        tdir = set(d1).intersection(set(d2))
        self.data_info.set_meta_data(['data_patterns', pname, 'main_dir'],
                                     list(tdir)[0])

    def trim_input_data(self, **kwargs):
        """ Select the image key on raw input data, if present. """
        if self.tomo_raw_obj:
            self.get_tomo_raw().select_image_key(**kwargs)

    def trim_output_data(self, copy_obj, **kwargs):
        """ Remove the image key on raw output data, if present. """
        if self.tomo_raw_obj:
            self.get_tomo_raw().remove_image_key(copy_obj, **kwargs)
            self.set_preview([])

    def get_axis_label_keys(self):
        """ Return a flat list of all axis label names. """
        axis_labels = self.data_info.get_meta_data('axis_labels')
        axis_label_keys = []
        for labels in axis_labels:
            for key in labels.keys():
                axis_label_keys.append(key)
        return axis_label_keys

    def get_current_and_next_patterns(self, datasets_lists):
        """ Store, in the experiment meta data, each output dataset's
        current pattern and the pattern its next consumer uses. """
        current_datasets = datasets_lists[0]
        patterns_list = []
        for current_data in current_datasets['out_datasets']:
            current_name = current_data['name']
            current_pattern = current_data['pattern']
            next_pattern = self.find_next_pattern(datasets_lists[1:],
                                                  current_name)
            patterns_list.append({'current': current_pattern,
                                  'next': next_pattern})
        self.exp.meta_data.set_meta_data('current_and_next', patterns_list)

    def find_next_pattern(self, datasets_lists, current_name):
        """ Return the pattern of the first later plugin consuming
        *current_name*, or [] if none does. """
        next_pattern = []
        for next_data_list in datasets_lists:
            for next_data in next_data_list['in_datasets']:
                if next_data['name'] == current_name:
                    next_pattern = next_data['pattern']
                    return next_pattern
        return next_pattern
示例#11
0
class Experiment(object):
    """
    One instance of this class is created at the beginning of the
    processing chain and remains until the end.  It holds the current data
    object and a dictionary containing all metadata.
    """
    def __init__(self, options):
        """ Create the experiment meta data, the dataset index and the
        (initially unset) nxs file handle. """
        self.index = {"in_data": {}, "out_data": {}, "mapping": {}}
        self.nxs_file = None
        self.meta_data = MetaData(options)
        self.__meta_data_setup(options["process_file"])

    def get_meta_data(self, entry):
        """ Return the value stored under *entry* in the experiment
        meta data. """
        return self.meta_data.get_meta_data(entry)

    def __meta_data_setup(self, process_file):
        """ Attach a PluginList, populated either from the 'plugin_list'
        meta data entry (test runs) or from *process_file*.

        :raises Exception: if 'run_type' is present but not 'test'.
        """
        self.meta_data.plugin_list = PluginList()

        try:
            rtype = self.meta_data.get_meta_data('run_type')
            # bug fix: '==' not 'is' — identity comparison with a string
            # literal is implementation-dependent
            if rtype == 'test':
                self.meta_data.plugin_list.plugin_list = \
                    self.meta_data.get_meta_data('plugin_list')
            else:
                raise Exception('the run_type is unknown in Experiment class')
        except KeyError:
            self.meta_data.plugin_list._populate_plugin_list(process_file)

    def create_data_object(self, dtype, name):
        """ Create a data object, or return the existing one of that name.

        Plugin developers should apply this method in loaders only.

        :params str dtype: either "in_data" or "out_data".
        """
        if name not in self.index[dtype]:
            data_obj = Data(name, self)
            self.index[dtype][name] = data_obj
            cu.add_base_classes(data_obj, [data_obj._get_transport_data()])
        return self.index[dtype][name]

    def _set_nxs_filename(self):
        """ Build the output '*_processed.nxs' path, record it in the meta
        data and open the file (with the MPI driver when running MPI). """
        out_path = self.meta_data.get_meta_data('out_path')
        nxs_name = os.path.basename(out_path.split('_')[-1]) + '_processed.nxs'
        filename = os.path.join(out_path, nxs_name)
        self.meta_data.set_meta_data("nxs_filename", filename)

        if self.meta_data.get_meta_data("mpi") is True:
            self.nxs_file = h5py.File(
                filename, 'w', driver='mpio', comm=MPI.COMM_WORLD)
        else:
            self.nxs_file = h5py.File(filename, 'w')

    def __remove_dataset(self, data_obj):
        """ Close the backing file and drop *data_obj* from out_data. """
        data_obj._close_file()
        name = data_obj.data_info.get_meta_data('name')
        del self.index["out_data"][name]

    def _clear_data_objects(self):
        """ Forget all input and output data objects. """
        for entry in ("out_data", "in_data"):
            self.index[entry] = {}

    def _merge_out_data_to_in(self):
        """ Promote every kept output dataset to an input dataset, carrying
        over the meta data of any input it replaces. """
        in_data = self.index['in_data']
        for name, dataset in self.index["out_data"].iteritems():
            if dataset.remove is False:
                if name in in_data:
                    dataset.meta_data._set_dictionary(
                        in_data[name].meta_data.get_dictionary())
                in_data[name] = dataset
        self.index["out_data"] = {}

    def _reorganise_datasets(self, out_data_objs, link_type):
        """ After a plugin run: close clashing input files, prune removed
        outputs, then save and promote the remaining outputs to inputs
        (with MPI barriers between the stages). """
        self.__close_unwanted_files(self.index["out_data"])
        self.__remove_unwanted_data(out_data_objs)

        self._barrier()
        self.__copy_out_data_to_in_data(link_type)

        self._barrier()
        self.index['out_data'] = {}

    def __remove_unwanted_data(self, out_data_objs):
        """ Delete every output dataset flagged for removal. """
        for data_obj in out_data_objs:
            if data_obj.remove is True:
                self.__remove_dataset(data_obj)

    def __close_unwanted_files(self, out_data_list):
        """ Close the file of each input dataset that an output dataset is
        about to replace. """
        for name in out_data_list:
            if name in self.index["in_data"]:
                self.index["in_data"][name]._close_file()

    def __copy_out_data_to_in_data(self, link_type):
        """ Save each output dataset and deep-copy it into the input
        index under the same name. """
        for name, output in self.index["out_data"].items():
            output._save_data(link_type)
            self.index["in_data"][name] = copy.deepcopy(output)

    def _set_all_datasets(self, name):
        """ Return the names of all input datasets.

        The *name* parameter is unused but kept for interface
        compatibility with callers.
        """
        return list(self.index["in_data"].keys())

    def _barrier(self, communicator=MPI.COMM_WORLD):
        """ Synchronise all processes on *communicator* when running under
        MPI; a no-op otherwise. """
        if self.meta_data.get_meta_data('mpi') is not True:
            return
        comm_dict = {'comm': communicator}
        logging.debug("About to hit a _barrier %s", comm_dict)
        communicator.barrier()
        logging.debug("Past the _barrier")

    def log(self, log_tag, log_level=logging.DEBUG):
        """
        Log the contents of the experiment at the specified level.

        :param str log_tag: tag identifying this log entry.
        :param log_level: logging level (default: logging.DEBUG).
        """
        logging.log(log_level, "Experimental Parameters for %s", log_tag)
        for key, value in self.index["in_data"].iteritems():
            logging.log(log_level, "in data (%s) shape = %s", key,
                        value.get_shape())
        # bug fix: the second loop previously iterated "in_data" again, so
        # the output datasets were never logged
        for key, value in self.index["out_data"].iteritems():
            logging.log(log_level, "out data (%s) shape = %s", key,
                        value.get_shape())