Example #1
 def __init__(self, options):
     self.meta_data = MetaData(options)
     self.__meta_data_setup(options["process_file"])
     self.experiment_collection = {}
     self.index = {"in_data": {}, "out_data": {}}
     self.initial_datasets = None
     self.plugin = None
Example #2
 def __init__(self, options):
     self.meta_data = MetaData(options)
     self.__set_system_params()
     self.checkpoint = Checkpointing(self)
     self.__meta_data_setup(options["process_file"])
     self.collection = {}
     self.index = {"in_data": {}, "out_data": {}}
     self.initial_datasets = None
     self.plugin = None
     self._transport = None
     self._barrier_count = 0
Example #3
 def __init__(self, data_obj, plugin=None):
     self.data_obj = data_obj
     self.data_obj._set_plugin_data(self)
     self.meta_data = MetaData()
     self.padding = None
     # this flag determines which data is passed. If false then just the
     # data, if true then all data including dark and flat fields.
     self.selected_data = False
     self.shape = None
     self.core_shape = None
     self.multi_params = {}
     self.extra_dims = []
     self._plugin = plugin
Example #4
 def __init__(self, exp, name='Checkpointing'):
     self._exp = exp
     self._h5 = Hdf5Utils(self._exp)
     self._filename = '_checkpoint.h5'
     self._file = None
     self._start_values = (0, 0, 0)
     self._completed_plugins = 0
     self._level = None
     self._proc_idx = 0
     self._trans_idx = 0
     self._comm = None
     self._timer = None
     self._set_timer()
     self.meta_data = MetaData()
Example #5
 def __init__(self, name, exp):
     super(Data, self).__init__(name)
     self.meta_data = MetaData()
     self.pattern_list = self.__get_available_pattern_list()
     self.data_info = MetaData()
     self.__initialise_data_info(name)
     self._preview = Preview(self)
     self.exp = exp
     self.group_name = None
     self.group = None
     self._plugin_data_obj = None
     self.tomo_raw_obj = None
     self.backing_file = None
     self.data = None
     self.next_shape = None
     self.orig_shape = None
Example #6
 def __init__(self, data_obj, plugin=None):
     self.data_obj = data_obj
     self._preview = None
     self.data_obj._set_plugin_data(self)
     self.meta_data = MetaData()
     self.padding = None
     self.pad_dict = None
     self.shape = None
     self.shape_transfer = None
     self.core_shape = None
     self.multi_params = {}
     self.extra_dims = []
     self._plugin = plugin
     self.fixed_dims = True
     self.split = None
     self.boundary_padding = None
     self.no_squeeze = False
     self.pre_tuning_shape = None
Example #7
 def __init__(self, name, exp):
     super(Data, self).__init__(name)
     self.meta_data = MetaData()
     self.pattern_list = self.__get_available_pattern_list()
     self.data_info = MetaData()
     self.__initialise_data_info(name)
     self._preview = Preview(self)
     self.exp = exp
     self.group_name = None
     self.group = None
     self._plugin_data_obj = None
     self.raw = None
     self.backing_file = None
     self.data = None
     self.next_shape = None
     self.orig_shape = None
     self.previous_pattern = None
     self.transport_data = None
Example #8
 def __init__(self, name, exp):
     self.meta_data = MetaData()
     self.pattern_list = self.set_available_pattern_list()
     self.data_info = MetaData()
     self.initialise_data_info(name)
     self.exp = exp
     self.group_name = None
     self.group = None
     self._plugin_data_obj = None
     self.tomo_raw_obj = None
     self.data_mapping = None
     self.variable_length_flag = False
     self.dtype = None
     self.remove = False
     self.backing_file = None
     self.data = None
     self.next_shape = None
     self.mapping = None
     self.map_dim = []
     self.revert_shape = None
Example #9
 def __init__(self, options):
     self.meta_data = MetaData(options)
     self.__set_system_params()
     self.checkpoint = Checkpointing(self)
     self.__meta_data_setup(options["process_file"])
     self.experiment_collection = {}
     self.index = {"in_data": {}, "out_data": {}}
     self.initial_datasets = None
     self.plugin = None
     self._transport = None
     self._barrier_count = 0
Example #10
 def __init__(self, data_obj):
     self.data_obj = data_obj
     self.data_obj.set_plugin_data(self)
     self.meta_data = MetaData()
     self.padding = None
     # this flag determines which data is passed. If false then just the
     # data, if true then all data including dark and flat fields.
     self.selected_data = False
     self.shape = None
     self.core_shape = None
     self.multi_params = {}
     self.extra_dims = []
Example #11
 def __init__(self, exp, name='Checkpointing'):
     self._exp = exp
     self._h5 = Hdf5Utils(self._exp)
     self._filename = '_checkpoint.h5'
     self._file = None
     self._start_values = (0, 0, 0)
     self._completed_plugins = 0
     self._level = None
     self._proc_idx = 0
     self._trans_idx = 0
     self._comm = None
     self._timer = None
     self._set_timer()
     self.meta_data = MetaData()
Example #12
 def __init__(self, data_obj, plugin=None):
     self.data_obj = data_obj
     self._preview = None
     self.data_obj._set_plugin_data(self)
     self.meta_data = MetaData()
     self.padding = None
     self.pad_dict = None
     self.shape = None
     self.core_shape = None
     self.multi_params = {}
     self.extra_dims = []
     self._plugin = plugin
     self.fixed_dims = True
     self.split = None
     self.boundary_padding = None
     self.no_squeeze = False
     self.pre_tuning_shape = None
     self._frame_limit = None
Example #13
 def __init__(self, name, exp):
     super(Data, self).__init__(name)
     self.meta_data = MetaData()
     self.pattern_list = self.__get_available_pattern_list()
     self.data_info = MetaData()
     self.__initialise_data_info(name)
     self._preview = Preview(self)
     self.exp = exp
     self.group_name = None
     self.group = None
     self._plugin_data_obj = None
     self.tomo_raw_obj = None
     self.data_mapping = None
     self.backing_file = None
     self.data = None
     self.next_shape = None
     self.mapping = None
     self.map_dim = []
Example #14
class PluginData(object):
    """ The PluginData class contains plugin specific information about a Data
    object for the duration of a plugin.  An instance of the class is
    encapsulated inside the Data object during the plugin run
    """
    def __init__(self, data_obj, plugin=None):
        self.data_obj = data_obj
        self.data_obj._set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = None
        self.pad_dict = None
        self.shape = None
        self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []
        self._plugin = plugin
        self.fixed_dims = False
        self.end_pad = False
        self.split = None

    def get_total_frames(self):
        """ Get the total number of frames to process.

        :returns: Number of frames
        :rtype: int
        """
        temp = 1
        slice_dir = \
            self.data_obj.get_data_patterns()[
                self.get_pattern_name()]["slice_dir"]
        for tslice in slice_dir:
            temp *= self.data_obj.get_shape()[tslice]
        return temp

    def __set_pattern(self, name):
        """ Set the pattern related information int the meta data dict.
        """

        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set_meta_data("name", name)
        self.meta_data.set_meta_data("core_dir", pattern['core_dir'])
        self.__set_slice_directions()

    def get_pattern_name(self):
        """ Get the pattern name.

        :returns: the pattern name
        :rtype: str
        """
        name = self.meta_data.get_meta_data("name")
        if name is not None:
            return name
        else:
            raise Exception("The pattern name has not been set.")

    def get_pattern(self):
        """ Get the current pattern.

        :returns: dict of the pattern name against the pattern.
        :rtype: dict
        """
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def __set_shape(self):
        """ Set the shape of the plugin data processing chunk.
        """
        core_dir = self.get_core_directions()
        slice_dir = self.get_slice_directions()
        dirs = list(set(core_dir + (slice_dir[0], )))
        slice_idx = dirs.index(slice_dir[0])
        shape = []
        for core in set(core_dir):
            shape.append(self.data_obj.get_shape()[core])
        self.__set_core_shape(tuple(shape))
        if self._get_frame_chunk() > 1:
            shape.insert(slice_idx, self._get_frame_chunk())
        self.shape = tuple(shape)

    def get_shape(self):
        """ Get the shape of the plugin data to be processed each time.
        """
        return self.shape

    def __set_core_shape(self, shape):
        """ Set the core shape to hold only the shape of the core dimensions
        """
        self.core_shape = shape

    def get_core_shape(self):
        """ Get the shape of the core dimensions only.

        :returns: shape of core dimensions
        :rtype: tuple
        """
        return self.core_shape

    def __check_dimensions(self, indices, core_dir, slice_dir, nDims):
        if len(indices) != len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")

        if (len(core_dir) + len(slice_dir)) != nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def __set_slice_directions(self):
        """ Set the slice directions in the pluginData meta data dictionary.
        """
        slice_dirs = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['slice_dir']
        self.meta_data.set_meta_data('slice_dir', slice_dirs)

    def get_slice_directions(self):
        """ Get the slice directions (slice_dir) of the dataset.
        """
        return self.meta_data.get_meta_data('slice_dir')

    def get_slice_dimension(self):
        """
        Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.get_core_directions()
        slice_dir = self.get_slice_directions()[0]
        return list(set(core_dirs + (slice_dir, ))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """
        Return the dimension of the data in the plugin that has the specified
        axis label.
        """
        label_dim = \
            self.data_obj.find_axis_label_dimension(label, contains=contains)
        plugin_dims = self.get_core_directions()
        if self._get_frame_chunk() > 1:
            plugin_dims += (self.get_slice_directions()[0], )
        return list(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):
        """
        Reorder the slice directions.  The fastest changing slice direction
        will always be the first one stated in the pattern key ``slice_dir``.
        The input param is a tuple stating the desired order of slicing
        directions relative to the current order.
        """
        slice_dirs = self.get_slice_directions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specifed.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['slice_dir'] = new_slice_dirs

    def get_core_directions(self):
        """ Get the core data directions

        :returns: value associated with pattern key ``core_dir``
        :rtype: tuple
        """
        core_dir = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['core_dir']
        return core_dir

    def set_fixed_directions(self, dims, values):
        """ Fix a data direction to the index in values list.

        :param list(int) dims: Directions to fix
        :param list(int) values: Indices of the fixed directions
        """
        slice_dirs = self.get_slice_directions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set_meta_data("fixed_directions", dims)
        self.meta_data.set_meta_data("fixed_directions_values", values)
        self.__set_slice_directions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        self.__set_shape()

    def _get_fixed_directions(self):
        """ Get the fixed data directions and their indices

        :returns: Fixed directions and their associated values
        :rtype: list(list(int), list(int))
        """
        fixed = []
        values = []
        if 'fixed_directions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get_meta_data("fixed_directions")
            values = self.meta_data.get_meta_data("fixed_directions_values")
        return [fixed, values]

    def _get_data_slice_list(self, plist):
        """ Convert a plugin data slice list to a slice list for the whole
        dataset, i.e. add in any missing dimensions.
        """
        nDims = len(self.get_shape())
        all_dims = self.get_core_directions() + self.get_slice_directions()
        extra_dims = all_dims[nDims:]
        dlist = list(plist)
        for i in extra_dims:
            dlist.insert(i, slice(None))
        return tuple(dlist)

    def _set_frame_chunk(self, nFrames):
        """ Set the number of frames to process at a time """
        self.meta_data.set_meta_data("nFrames", nFrames)

    def _get_frame_chunk(self):
        """ Get the number of frames to be processes at a time.

        If the number of frames is not divisible by the previewing ``chunk``
        value then amend the number of frames to gcd(frames, chunk)

        :returns: Number of frames to process
        :rtype: int
        """
        if self._plugin and self._plugin.chunk > 1:
            frame_chunk = self.meta_data.get_meta_data("nFrames")
            chunk = self.data_obj.get_preview().get_starts_stops_steps(
                key='chunks')[self.get_slice_directions()[0]]
            self._set_frame_chunk(gcd(frame_chunk, chunk))
        return self.meta_data.get_meta_data("nFrames")

    def plugin_data_setup(self, pattern_name, chunk, fixed=False, split=None):
        """ Setup the PluginData object.

        :param str pattern_name: A pattern name
        :param int chunk: Number of frames to process at a time
        :keyword bool fixed: setting fixed=True will ensure the plugin \
            receives the same sized data array each time (padding if necessary)

        """
        self.__set_pattern(pattern_name)
        chunks = \
            self.data_obj.get_preview().get_starts_stops_steps(key='chunks')
        if self._plugin and (chunks[self.get_slice_directions()[0]] % chunk):
            self._plugin.chunk = True
        self._set_frame_chunk(chunk)
        self.__set_shape()
        self.fixed_dims = fixed
        self.split = split
Example #15
class PluginData(object):
    """ The PluginData class contains plugin specific information about a Data
    object for the duration of a plugin.  An instance of the class is
    encapsulated inside the Data object during the plugin run
    """
    def __init__(self, data_obj, plugin=None):
        self.data_obj = data_obj
        self._preview = None
        self.data_obj._set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = None
        self.pad_dict = None
        self.shape = None
        self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []
        self._plugin = plugin
        self.fixed_dims = True
        self.split = None
        self.boundary_padding = None
        self.no_squeeze = False
        self.pre_tuning_shape = None
        self._frame_limit = None
        self._increase_rank = 0

    def _get_preview(self):
        return self._preview

    def get_total_frames(self):
        """ Get the total number of frames to process (all MPI processes).

        :returns: Number of frames
        :rtype: int
        """
        temp = 1
        slice_dir = \
            self.data_obj.get_data_patterns()[
                self.get_pattern_name()]["slice_dims"]
        for tslice in slice_dir:
            temp *= self.data_obj.get_shape()[tslice]
        return temp

    def __set_pattern(self, name, first_sdim=None):
        """ Set the pattern related information in the meta data dict.
        """
        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set("name", name)
        self.meta_data.set("core_dims", pattern['core_dims'])
        self.__set_slice_dimensions(first_sdim=first_sdim)

    def get_pattern_name(self):
        """ Get the pattern name.

        :returns: the pattern name
        :rtype: str
        """
        try:
            name = self.meta_data.get("name")
            return name
        except KeyError:
            raise Exception("The pattern name has not been set.")

    def get_pattern(self):
        """ Get the current pattern.

        :returns: dict of the pattern name against the pattern.
        :rtype: dict
        """
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def __set_shape(self):
        """ Set the shape of the plugin data processing chunk.
        """
        core_dir = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()
        dirs = list(set(core_dir + (slice_dir[0], )))
        slice_idx = dirs.index(slice_dir[0])
        dshape = self.data_obj.get_shape()
        shape = []
        for core in set(core_dir):
            shape.append(dshape[core])
        self.__set_core_shape(tuple(shape))

        mfp = self._get_max_frames_process()
        if mfp > 1 or self._get_no_squeeze():
            shape.insert(slice_idx, mfp)
        self.shape = tuple(shape)

    def _set_shape_transfer(self, slice_size):
        dshape = self.data_obj.get_shape()
        shape_before_tuning = self._get_shape_before_tuning()
        add = [1] * (len(dshape) - len(shape_before_tuning))
        slice_size = slice_size + add

        core_dir = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()
        shape = [None] * len(dshape)
        for dim in core_dir:
            shape[dim] = dshape[dim]
        i = 0
        for dim in slice_dir:
            shape[dim] = slice_size[i]
            i += 1
        return tuple(shape)

    def __get_slice_size(self, mft):
        """ Calculate the number of frames transfer in each dimension given
            mft. """
        dshape = list(self.data_obj.get_shape())

        if 'fixed_dimensions' in list(self.meta_data.get_dictionary().keys()):
            fixed_dims = self.meta_data.get('fixed_dimensions')
            for d in fixed_dims:
                dshape[d] = 1

        dshape = [dshape[i] for i in self.meta_data.get('slice_dims')]
        size_list = [1] * len(dshape)
        i = 0

        while (mft > 1 and i < len(size_list)):
            size_list[i] = min(dshape[i], mft)
            mft //= np.prod(size_list) if np.prod(size_list) > 1 else 1
            i += 1

        self.meta_data.set('size_list', size_list)
        return size_list

    def set_bytes_per_frame(self):
        """ Return the size of a single frame in bytes. """
        nBytes = self.data_obj.get_itemsize()
        dims = list(self.get_pattern().values())[0]['core_dims']
        frame_shape = [self.data_obj.get_shape()[d] for d in dims]
        b_per_f = np.prod(frame_shape) * nBytes
        return frame_shape, b_per_f

    def get_shape(self):
        """ Get the shape of the data (without padding) that is passed to the
        plugin process_frames method.
        """
        return self.shape

    def _set_padded_shape(self):
        pass

    def get_padded_shape(self):
        """ Get the shape of the data (with padding) that is passed to the
        plugin process_frames method.
        """
        return self.shape

    def get_shape_transfer(self):
        """ Get the shape of the plugin data to be transferred each time.
        """
        return self.meta_data.get('transfer_shape')

    def __set_core_shape(self, shape):
        """ Set the core shape to hold only the shape of the core dimensions
        """
        self.core_shape = shape

    def get_core_shape(self):
        """ Get the shape of the core dimensions only.

        :returns: shape of core dimensions
        :rtype: tuple
        """
        return self.core_shape

    def _set_shape_before_tuning(self, shape):
        """ Set the shape of the full dataset used during each run of the \
        plugin (i.e. ignore extra dimensions due to parameter tuning). """
        self.pre_tuning_shape = shape

    def _get_shape_before_tuning(self):
        """ Return the shape of the full dataset used during each run of the \
        plugin (i.e. ignore extra dimensions due to parameter tuning). """
        return self.pre_tuning_shape if self.pre_tuning_shape else\
            self.data_obj.get_shape()

    def __check_dimensions(self, indices, core_dir, slice_dir, nDims):
        if len(indices) != len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")

        if (len(core_dir) + len(slice_dir)) != nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def __set_slice_dimensions(self, first_sdim=None):
        """ Set the slice dimensions in the pluginData meta data dictionary.\
        Reorder pattern slice_dims to ensure first_sdim is at the front.
        """
        pattern = self.data_obj.get_data_patterns()[self.get_pattern_name()]
        slice_dims = pattern['slice_dims']

        if first_sdim:
            slice_dims = list(slice_dims)
            first_sdim = \
                self.data_obj.get_data_dimension_by_axis_label(first_sdim)
            slice_dims.insert(0, slice_dims.pop(slice_dims.index(first_sdim)))
            pattern['slice_dims'] = tuple(slice_dims)

        self.meta_data.set('slice_dims', tuple(slice_dims))

    def get_slice_dimension(self):
        """
        Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()[0]
        return list(set(core_dirs + (slice_dir, ))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """
        Return the dimension of the data in the plugin that has the specified
        axis label.
        """
        label_dim = self.data_obj.get_data_dimension_by_axis_label(
            label, contains=contains)
        plugin_dims = self.data_obj.get_core_dimensions()
        if self._get_max_frames_process() > 1 or self.max_frames == 'multiple':
            plugin_dims += (self.get_slice_dimension(), )
        return list(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):  # should this function be deleted?
        """
        Reorder the slice dimensions.  The fastest changing slice dimension
        will always be the first one stated in the pattern key ``slice_dims``.
        The input param is a tuple stating the desired order of slicing
        dimensions relative to the current order.
        """
        slice_dirs = self.data_obj.get_slice_dimensions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specifed.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['slice_dims'] = new_slice_dirs

    def get_core_dimensions(self):
        """
        Return the position of the core dimensions in relation to the data
        handed to the plugin.
        """
        core_dims = self.data_obj.get_core_dimensions()
        first_slice_dim = (self.data_obj.get_slice_dimensions()[0], )
        plugin_dims = np.sort(core_dims + first_slice_dim)
        return np.searchsorted(plugin_dims, np.sort(core_dims))

    def set_fixed_dimensions(self, dims, values):
        """ Fix a data direction to the index in values list.

        :param list(int) dims: Directions to fix
        :param list(int) values: Indices of the fixed directions
        """
        slice_dirs = self.data_obj.get_slice_dimensions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set("fixed_dimensions", dims)
        self.meta_data.set("fixed_dimensions_values", values)
        self.__set_slice_dimensions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        #self.__set_shape()

    def _get_fixed_dimensions(self):
        """ Get the fixed data directions and their indices

        :returns: Fixed directions and their associated values
        :rtype: list(list(int), list(int))
        """
        fixed = []
        values = []
        if 'fixed_dimensions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get("fixed_dimensions")
            values = self.meta_data.get("fixed_dimensions_values")
        return [fixed, values]

    def _get_data_slice_list(self, plist):
        """ Convert a plugin data slice list to a slice list for the whole
        dataset, i.e. add in any missing dimensions.
        """
        nDims = len(self.get_shape())
        all_dims = self.data_obj.get_core_dimensions() + \
            self.data_obj.get_slice_dimensions()
        extra_dims = all_dims[nDims:]
        dlist = list(plist)
        for i in extra_dims:
            dlist.insert(i, slice(None))
        return tuple(dlist)

    def _get_max_frames_process(self):
        """ Get the number of frames to process for each run of process_frames.

        If the number of frames is not divisible by the previewing ``chunk``
        value then amend the number of frames to gcd(frames, chunk)

        :returns: Number of frames to process
        :rtype: int
        """
        if self._plugin and self._plugin.chunk > 1:
            frame_chunk = self.meta_data.get("max_frames_process")
            chunk = self.data_obj.get_preview().get_starts_stops_steps(
                key='chunks')[self.data_obj.get_slice_dimensions()[0]]
            self.meta_data.set('max_frames_process', gcd(frame_chunk, chunk))
        return self.meta_data.get("max_frames_process")

    def _get_max_frames_transfer(self):
        """ Get the number of frames to transfer for each run of
        process_frames. """
        return self.meta_data.get('max_frames_transfer')

    def _set_no_squeeze(self):
        self.no_squeeze = True

    def _get_no_squeeze(self):
        return self.no_squeeze

    def _set_rank_inc(self, n):
        """ Increase the rank of the array passed to the plugin by n.
        
        :param int n: Rank increment.
        """
        self._increase_rank = n

    def _get_rank_inc(self):
        """ Return the increased rank value
        
        :returns: Rank increment
        :rtype: int
        """
        return self._increase_rank

    def _set_meta_data(self):
        fixed, _ = self._get_fixed_dimensions()
        sdir = \
            [s for s in self.data_obj.get_slice_dimensions() if s not in fixed]
        shape = self.data_obj.get_shape()
        shape_before_tuning = self._get_shape_before_tuning()

        diff = len(shape) - len(shape_before_tuning)
        if diff:
            shape = shape_before_tuning
            sdir = sdir[:-diff]

        if 'fix_total_frames' in list(self.meta_data.get_dictionary().keys()):
            frames = self.meta_data.get('fix_total_frames')
        else:
            frames = np.prod([shape[d] for d in sdir])

        base_names = [p.__name__ for p in self._plugin.__class__.__bases__]
        processes = self.data_obj.exp.meta_data.get('processes')

        if 'GpuPlugin' in base_names:
            n_procs = len([n for n in processes if 'GPU' in n])
        else:
            n_procs = len(processes)

        # Fixing f_per_p to be just the first slice dimension for now due to
        # slow performance from HDF5 when not slicing multiple dimensions
        # concurrently
        #f_per_p = np.ceil(frames/n_procs)
        f_per_p = np.ceil(shape[sdir[0]] / n_procs)
        self.meta_data.set('shape', shape)
        self.meta_data.set('sdir', sdir)
        self.meta_data.set('total_frames', frames)
        self.meta_data.set('mpi_procs', n_procs)
        self.meta_data.set('frames_per_process', f_per_p)
        frame_shape, b_per_f = self.set_bytes_per_frame()
        self.meta_data.set('bytes_per_frame', b_per_f)
        self.meta_data.set('bytes_per_process', b_per_f * f_per_p)
        self.meta_data.set('frame_shape', frame_shape)

    def __log_max_frames(self, mft, mfp, check=True):
        logging.debug("Setting max frames transfer for plugin %s to %d" %
                      (self._plugin, mft))
        logging.debug("Setting max frames process for plugin %s to %d" %
                      (self._plugin, mfp))
        self.meta_data.set('max_frames_process', mfp)
        if check:
            self.__check_distribution(mft)
        # (((total_frames/mft)/mpi_procs) % 1)

    def __check_distribution(self, mft):
        warn_threshold = 0.85
        nprocs = self.meta_data.get('mpi_procs')
        nframes = self.meta_data.get('total_frames')
        temp = (((nframes / mft) / float(nprocs)) % 1)
        if temp != 0.0 and temp < warn_threshold:
            shape = self.meta_data.get('shape')
            sdir = self.meta_data.get('sdir')
            logging.warning(
                'UNEVEN FRAME DISTRIBUTION: shape %s, nframes %s ' +
                'sdir %s, nprocs %s', shape, nframes, sdir, nprocs)

    def _set_padding_dict(self):
        if self.padding and not isinstance(self.padding, Padding):
            self.pad_dict = copy.deepcopy(self.padding)
            self.padding = Padding(self)
            for key in list(self.pad_dict.keys()):
                getattr(self.padding, key)(self.pad_dict[key])

    def plugin_data_setup(self,
                          pattern,
                          nFrames,
                          split=None,
                          slice_axis=None,
                          getall=None):
        """ Setup the PluginData object.

        :param str pattern: A pattern name
        :param int nFrames: How many frames to process at a time.  Choose from
            'single', 'multiple', 'fixed_multiple' or an integer (an integer
            should only ever be passed in exceptional circumstances)
        :keyword str slice_axis: An axis label associated with the fastest
            changing (first) slice dimension.
        :keyword list[pattern, axis_label] getall: A list of two values.  If
            the requested pattern doesn't exist then use all of "axis_label"
            dimension of "pattern" as this is equivalent to one slice of the
            original pattern.
        """

        if pattern not in self.data_obj.get_data_patterns() and getall:
            pattern, nFrames = self.__set_getall_pattern(getall, nFrames)

        # slice_axis is first slice dimension
        self.__set_pattern(pattern, first_sdim=slice_axis)
        if isinstance(nFrames, list):
            nFrames, self._frame_limit = nFrames
        self.max_frames = nFrames
        self.split = split

    def __set_getall_pattern(self, getall, nFrames):
        """ Set framework changes required to get all of a pattern of lower
        rank.
        """
        pattern, slice_axis = getall
        dim = self.data_obj.get_data_dimension_by_axis_label(slice_axis)
        # ensure data remains the same shape when 'getall' dim has length 1
        self._set_no_squeeze()
        if nFrames == 'multiple' or (isinstance(nFrames, int) and nFrames > 1):
            self._set_rank_inc(1)
        nFrames = self.data_obj.get_shape()[dim]
        return pattern, nFrames

    def plugin_data_transfer_setup(self, copy=None, calc=None):
        """ Set up the plugin data transfer frame parameters.
        If copy=pData (another PluginData instance) then copy """
        chunks = \
            self.data_obj.get_preview().get_starts_stops_steps(key='chunks')

        if not copy and not calc:
            mft, mft_shape, mfp = self._calculate_max_frames()
        elif calc:
            max_mft = calc.meta_data.get('max_frames_transfer')
            max_mfp = calc.meta_data.get('max_frames_process')
            max_nProc = int(np.ceil(max_mft / float(max_mfp)))
            nProc = max_nProc
            mfp = 1 if self.max_frames == 'single' else self.max_frames
            mft = nProc * mfp
            mft_shape = self._set_shape_transfer(self.__get_slice_size(mft))
        elif copy:
            mft = copy._get_max_frames_transfer()
            mft_shape = self._set_shape_transfer(self.__get_slice_size(mft))
            mfp = copy._get_max_frames_process()

        self.__set_max_frames(mft, mft_shape, mfp)

        if self._plugin and mft \
                and (chunks[self.data_obj.get_slice_dimensions()[0]] % mft):
            self._plugin.chunk = True
        self.__set_shape()

    def _calculate_max_frames(self):
        nFrames = self.max_frames
        self.__perform_checks(nFrames)
        td = self.data_obj._get_transport_data()
        mft, size_list = td._calc_max_frames_transfer(nFrames)
        self.meta_data.set('size_list', size_list)
        mfp = td._calc_max_frames_process(nFrames)
        mft_shape = None
        if mft:
            mft_shape = self._set_shape_transfer(list(size_list))
        return mft, mft_shape, mfp

    def __set_max_frames(self, mft, mft_shape, mfp):
        self.meta_data.set('max_frames_transfer', mft)
        self.meta_data.set('transfer_shape', mft_shape)
        self.meta_data.set('max_frames_process', mfp)
        self.__log_max_frames(mft, mfp)
        # Retain the shape if the first slice dimension has length 1
        if mfp == 1 and self.max_frames == 'multiple':
            self._set_no_squeeze()

    def _get_plugin_data_size_params(self):
        nBytes = self.data_obj.get_itemsize()
        frame_shape = self.meta_data.get('frame_shape')
        total_frames = self.meta_data.get('total_frames')
        tbytes = nBytes * np.prod(frame_shape) * total_frames

        params = {
            'nBytes': nBytes,
            'frame_shape': frame_shape,
            'total_frames': total_frames,
            'transfer_bytes': tbytes
        }
        return params

    def __perform_checks(self, nFrames):
        options = ['single', 'multiple']
        if not np.issubdtype(type(nFrames),
                             np.int64) and nFrames not in options:
            e_str = (
                "The value of nFrames is not recognised.  Please choose " +
                "from 'single' and 'multiple' (or an integer in exceptional " +
                "circumstances).")
            raise Exception(e_str)

    def get_frame_limit(self):
        return self._frame_limit

    def get_current_frame_idx(self):
        """ Returns the index of the frames currently being processed.
        """
        global_index = self._plugin.get_global_frame_index()
        count = self._plugin.get_process_frames_counter()
        mfp = self.meta_data.get('max_frames_process')
        start = global_index[count] * mfp
        index = np.arange(start, start + mfp)
        nFrames = self.get_total_frames()
        index[index >= nFrames] = nFrames - 1
        return index
Example #16
 def __init__(self, options):
     self.meta_data = MetaData(options)
     self.__meta_data_setup(options["process_file"])
     self.index = {"in_data": {}, "out_data": {}, "mapping": {}}
     self.nxs_file = None
Example #17
class Data(Pattern):
    """
    The Data class dynamically inherits from relevant data structure classes
    at runtime and holds the data array.
    """

    def __init__(self, name):
        self.meta_data = MetaData()
        super(Data, self).__init__()
        self.name = name
        self.backing_file = None
        self.data = None

    def get_transport_data(self, transport):
        transport_data = "savu.data.transport_data." + transport + \
            "_transport_data"
        return self.import_class(transport_data)

    def import_class(self, class_name):
        name = class_name
        mod = __import__(name)
        components = name.split('.')
        for comp in components[1:]:
            mod = getattr(mod, comp)
        temp = name.split('.')[-1]
        module2class = ''.join(x.capitalize() for x in temp.split('_'))
        return getattr(mod, module2class.split('.')[-1])

    def __deepcopy__(self, memo):
        return self

    def add_base(self, ExtraBase):
        cls = self.__class__
        self.__class__ = cls.__class__(cls.__name__, (cls, ExtraBase), {})
        ExtraBase.__init__(self)

    def add_base_classes(self, bases):
        for base in bases:
            self.add_base(base)

    def set_shape(self, shape):
        self.meta_data.set_meta_data('shape', shape)

    def get_shape(self):
        shape = self.meta_data.get_meta_data('shape')
        try:
            dirs = self.meta_data.get_meta_data("fixed_directions")
            shape = list(shape)
            for ddir in dirs:
                shape[ddir] = 1
            shape = tuple(shape)
        except KeyError:
            pass
        return shape

    def set_dist(self, dist):
        self.meta_data.set_meta_data('dist', dist)

    def get_dist(self):
        return self.meta_data.get_meta_data('dist')

    def set_data_params(self, pattern, chunk_size, **kwargs):
        self.set_current_pattern_name(pattern)
        self.set_nFrames(chunk_size)
Example #18
class Experiment(object):
    """
    One instance of this class is created at the beginning of the
    processing chain and remains until the end.  It holds the current data
    object and a dictionary containing all metadata.
    """
    def __init__(self, options):
        self.meta_data = MetaData(options)
        self.__meta_data_setup(options["process_file"])
        self.experiment_collection = {}
        self.index = {"in_data": {}, "out_data": {}}
        self.initial_datasets = None
        self.plugin = None

    def get(self, entry):
        """ Get the meta data dictionary. """
        return self.meta_data.get(entry)

    def __meta_data_setup(self, process_file):
        self.meta_data.plugin_list = PluginList()

        try:
            rtype = self.meta_data.get('run_type')
            if rtype == 'test':
                self.meta_data.plugin_list.plugin_list = \
                    self.meta_data.get('plugin_list')
            else:
                raise Exception('the run_type is unknown in Experiment class')
        except KeyError:
            self.meta_data.plugin_list._populate_plugin_list(process_file)

    def create_data_object(self, dtype, name):
        """ Create a data object.

        Plugin developers should apply this method in loaders only.

        :param str dtype: either "in_data" or "out_data".
        """
        try:
            self.index[dtype][name]
        except KeyError:
            self.index[dtype][name] = Data(name, self)
            data_obj = self.index[dtype][name]
            data_obj._set_transport_data(self.meta_data.get('transport'))
        return self.index[dtype][name]

    def _experiment_setup(self):
        """ Setup an experiment collection.
        """
        n_loaders = self.meta_data.plugin_list._get_n_loaders()
        plugin_list = self.meta_data.plugin_list
        plist = plugin_list.plugin_list

        # load the loader plugins
        self._set_loaders()

        # load the saver plugin and save the plugin list
        self.experiment_collection = {'plugin_dict': [], 'datasets': []}

        self._barrier()
        if self.meta_data.get('process') == \
                len(self.meta_data.get('processes'))-1:
            plugin_list._save_plugin_list(self.meta_data.get('nxs_filename'))
        self._barrier()

        n_plugins = plugin_list._get_n_processing_plugins()
        count = 0
        # first run through of the plugin setup methods
        for plugin_dict in plist[n_loaders:n_loaders + n_plugins]:
            data = self.__plugin_setup(plugin_dict, count)
            self.experiment_collection['datasets'].append(data)
            self.experiment_collection['plugin_dict'].append(plugin_dict)
            self._merge_out_data_to_in()
            count += 1
        self._reset_datasets()

    def _set_loaders(self):
        n_loaders = self.meta_data.plugin_list._get_n_loaders()
        plugin_list = self.meta_data.plugin_list.plugin_list
        for i in range(n_loaders):
            pu.plugin_loader(self, plugin_list[i])
        self.initial_datasets = copy.deepcopy(self.index['in_data'])

    def _reset_datasets(self):
        self.index['in_data'] = self.initial_datasets

    def __plugin_setup(self, plugin_dict, count):
        """ Determine plugin specific information.
        """
        plugin_id = plugin_dict["id"]
        logging.debug("Loading plugin %s", plugin_id)
        # Run main_setup method
        plugin = pu.plugin_loader(self, plugin_dict)
        plugin._revert_preview(plugin.get_in_datasets())
        # Populate the metadata
        plugin._clean_up()
        data = self.index['out_data'].copy()
        return data

    def _get_experiment_collection(self):
        return self.experiment_collection

    def _set_experiment_for_current_plugin(self, count):
        datasets_list = self.meta_data.plugin_list._get_datasets_list()[count:]
        exp_coll = self._get_experiment_collection()
        self.index['out_data'] = exp_coll['datasets'][count]
        if datasets_list:
            self._get_current_and_next_patterns(datasets_list)
        self.meta_data.set('nPlugin', count)

    def _get_current_and_next_patterns(self, datasets_lists):
        """ Get the current and next patterns associated with a dataset
        throughout the processing chain.
        """
        current_datasets = datasets_lists[0]
        patterns_list = []
        for current_data in current_datasets['out_datasets']:
            current_name = current_data['name']
            current_pattern = current_data['pattern']
            next_pattern = self.__find_next_pattern(datasets_lists[1:],
                                                    current_name)
            patterns_list.append({
                'current': current_pattern,
                'next': next_pattern
            })
        self.meta_data.set('current_and_next', patterns_list)

    def __find_next_pattern(self, datasets_lists, current_name):
        next_pattern = []
        for next_data_list in datasets_lists:
            for next_data in next_data_list['in_datasets']:
                if next_data['name'] == current_name:
                    next_pattern = next_data['pattern']
                    return next_pattern
        return next_pattern

    def _set_nxs_filename(self):
        folder = self.meta_data.get('out_path')
        fname = self.meta_data.get('datafile_name') + '_processed.nxs'
        filename = os.path.join(folder, fname)
        self.meta_data.set('nxs_filename', filename)

        if self.meta_data.get('process') == 1:
            if self.meta_data.get('bllog'):
                log_folder_name = self.meta_data.get('bllog')
                log_folder = open(log_folder_name, 'a')
                log_folder.write(os.path.abspath(filename) + '\n')
                log_folder.close()

    def _clear_data_objects(self):
        self.index["out_data"] = {}
        self.index["in_data"] = {}

    def _merge_out_data_to_in(self):
        for key, data in self.index["out_data"].iteritems():
            if data.remove is False:
                self.index['in_data'][key] = data
        self.index["out_data"] = {}

    def _finalise_experiment_for_current_plugin(self):
        finalise = {}
        # populate nexus file with out_dataset information and determine which
        # datasets to remove from the framework.
        finalise['remove'] = []
        finalise['keep'] = []

        for key, data in self.index['out_data'].iteritems():
            if data.remove is True:
                finalise['remove'].append(data)
            else:
                finalise['keep'].append(data)

        # find in datasets to replace
        finalise['replace'] = []
        for out_name in self.index['out_data'].keys():
            if out_name in self.index['in_data'].keys():
                finalise['replace'].append(self.index['in_data'][out_name])

        return finalise

    def _reorganise_datasets(self, finalise):
        # unreplicate replicated in_datasets
        self.__unreplicate_data()

        # delete all datasets for removal
        for data in finalise['remove']:
            del self.index["out_data"][data.data_info.get('name')]

        # Add remaining output datasets to input datasets
        for name, data in self.index['out_data'].iteritems():
            data.get_preview().set_preview([])
            self.index["in_data"][name] = copy.deepcopy(data)
        self.index['out_data'] = {}

    def __unreplicate_data(self):
        in_data_list = self.index['in_data']
        from savu.data.data_structures.data_types.replicate import Replicate
        for in_data in in_data_list.values():
            if isinstance(in_data.data, Replicate):
                in_data.data = in_data.data.reset()

    def _set_all_datasets(self, name):
        data_names = []
        for key in self.index["in_data"].keys():
            if 'itr_clone' not in key:
                data_names.append(key)
        return data_names

    def _barrier(self, communicator=MPI.COMM_WORLD):
        comm_dict = {'comm': communicator}
        if self.meta_data.get('mpi') is True:
            logging.debug("About to hit a _barrier %s", comm_dict)
            comm_dict['comm'].barrier()
            logging.debug("Past the _barrier")

    def log(self, log_tag, log_level=logging.DEBUG):
        """
        Log the contents of the experiment at the specified level
        """
        logging.log(log_level, "Experimental Parameters for %s", log_tag)
        for key, value in self.index["in_data"].iteritems():
            logging.log(log_level, "in data (%s) shape = %s", key,
                        value.get_shape())
        for key, value in self.index["in_data"].iteritems():
            logging.log(log_level, "out data (%s) shape = %s", key,
                        value.get_shape())
Example #19
class Experiment(object):
    """
    One instance of this class is created at the beginning of the
    processing chain and remains until the end.  It holds the current data
    object and a dictionary containing all metadata.
    """

    def __init__(self, options):
        self.meta_data = MetaData(options)
        self.__set_system_params()
        self.checkpoint = Checkpointing(self)
        self.__meta_data_setup(options["process_file"])
        self.experiment_collection = {}
        self.index = {"in_data": {}, "out_data": {}}
        self.initial_datasets = None
        self.plugin = None
        self._transport = None
        self._barrier_count = 0

    def get(self, entry):
        """ Get the meta data dictionary. """
        return self.meta_data.get(entry)

    def __meta_data_setup(self, process_file):
        self.meta_data.plugin_list = PluginList()
        try:
            rtype = self.meta_data.get('run_type')
            if rtype == 'test':
                self.meta_data.plugin_list.plugin_list = \
                    self.meta_data.get('plugin_list')
            else:
                raise Exception('the run_type is unknown in Experiment class')
        except KeyError:
            template = self.meta_data.get('template')
            self.meta_data.plugin_list._populate_plugin_list(process_file,
                                                             template=template)

    def create_data_object(self, dtype, name, override=True):
        """ Create a data object.

        Plugin developers should apply this method in loaders only.

        :param str dtype: either "in_data" or "out_data".
        """
        if name not in self.index[dtype].keys() or override:
            self.index[dtype][name] = Data(name, self)
            data_obj = self.index[dtype][name]
            data_obj._set_transport_data(self.meta_data.get('transport'))
        return self.index[dtype][name]

    def _experiment_setup(self, transport):
        """ Setup an experiment collection.
        """
        n_loaders = self.meta_data.plugin_list._get_n_loaders()
        plugin_list = self.meta_data.plugin_list
        plist = plugin_list.plugin_list
        self.__set_transport(transport)
        # load the loader plugins
        self._set_loaders()
        # load the saver plugin and save the plugin list
        self.experiment_collection = {'plugin_dict': [],
                                      'datasets': []}
        self._barrier()
        self._check_checkpoint()
        self._barrier()
        checkpoint = self.meta_data.get('checkpoint')
        if self.meta_data.get('process') == \
                len(self.meta_data.get('processes'))-1 and not checkpoint:
            plugin_list._save_plugin_list(self.meta_data.get('nxs_filename'))
            # links the input data to the nexus file
            self._add_input_data_to_nxs_file(transport)
        # Barrier 13
        self._barrier()

        n_plugins = plugin_list._get_n_processing_plugins()
        count = 0
        # first run through of the plugin setup methods
        for plugin_dict in plist[n_loaders:n_loaders+n_plugins]:
            data = self.__plugin_setup(plugin_dict, count)
            self.experiment_collection['datasets'].append(data)
            self.experiment_collection['plugin_dict'].append(plugin_dict)
            self._merge_out_data_to_in()
            count += 1
        self._reset_datasets()

    def __set_transport(self, transport):
        self._transport = transport

    def _get_transport(self):
        return self._transport

    def __set_system_params(self):
        sys_file = self.meta_data.get('system_params')
        import sys
        if sys_file is None:
            # look in conda environment to see which version is being used
            savu_path = sys.modules['savu'].__path__[0]
            sys_files = os.path.join(
                    os.path.dirname(savu_path), 'system_files')
            subdirs = os.listdir(sys_files)
            sys_folder = 'dls' if len(subdirs) > 1 else subdirs[0]
            fname = 'system_parameters.yml'
            sys_file = os.path.join(sys_files, sys_folder, fname)
        logging.info('Using the system parameters file: %s', sys_file)
        self.meta_data.set('system_params', yaml.read_yaml(sys_file))

    def _check_checkpoint(self):
        # if checkpointing has been set but the nxs file doesn't contain an
        # entry then remove checkpointing (as the previous run didn't get far
        # enough to require it).
        if self.meta_data.get('checkpoint'):
            with h5py.File(self.meta_data.get('nxs_filename'), 'r') as f:
                if 'entry' not in f:
                    self.meta_data.set('checkpoint', None)

    def _set_loaders(self):
        n_loaders = self.meta_data.plugin_list._get_n_loaders()
        plugin_list = self.meta_data.plugin_list.plugin_list
        for i in range(n_loaders):
            pu.plugin_loader(self, plugin_list[i])
        self.initial_datasets = copy.deepcopy(self.index['in_data'])

    def _add_input_data_to_nxs_file(self, transport):
        # save the loaded data to file
        h5 = Hdf5Utils(self)
        for name, data in self.index['in_data'].iteritems():
            self.meta_data.set(['link_type', name], 'input_data')
            self.meta_data.set(['group_name', name], name)
            self.meta_data.set(['filename', name], data.backing_file)
            transport._populate_nexus_file(data)
            h5._link_datafile_to_nexus_file(data)

    def _reset_datasets(self):
        self.index['in_data'] = self.initial_datasets

    def __plugin_setup(self, plugin_dict, count):
        """ Determine plugin specific information.
        """
        plugin_id = plugin_dict["id"]
        logging.debug("Loading plugin %s", plugin_id)
        # Run main_setup method
        plugin = pu.plugin_loader(self, plugin_dict)
        plugin._revert_preview(plugin.get_in_datasets())
        # Populate the metadata
        plugin._clean_up()
        data = self.index['out_data'].copy()
        return data

    def _get_experiment_collection(self):
        return self.experiment_collection

    def _set_experiment_for_current_plugin(self, count):
        datasets_list = self.meta_data.plugin_list._get_datasets_list()[count:]
        exp_coll = self._get_experiment_collection()
        self.index['out_data'] = exp_coll['datasets'][count]
        if datasets_list:
            self._get_current_and_next_patterns(datasets_list)
        self.meta_data.set('nPlugin', count)

    def _get_current_and_next_patterns(self, datasets_lists):
        """ Get the current and next patterns associated with a dataset
        throughout the processing chain.
        """
        current_datasets = datasets_lists[0]
        patterns_list = {}
        for current_data in current_datasets['out_datasets']:
            current_name = current_data['name']
            current_pattern = current_data['pattern']
            next_pattern = self.__find_next_pattern(datasets_lists[1:],
                                                    current_name)
            patterns_list[current_name] = \
                {'current': current_pattern, 'next': next_pattern}
        self.meta_data.set('current_and_next', patterns_list)

    def __find_next_pattern(self, datasets_lists, current_name):
        next_pattern = []
        for next_data_list in datasets_lists:
            for next_data in next_data_list['in_datasets']:
                if next_data['name'] == current_name:
                    next_pattern = next_data['pattern']
                    return next_pattern
        return next_pattern

    def _set_nxs_filename(self):
        folder = self.meta_data.get('out_path')
        fname = self.meta_data.get('datafile_name') + '_processed.nxs'
        filename = os.path.join(folder, fname)
        self.meta_data.set('nxs_filename', filename)

        if self.meta_data.get('process') == 1:
            if self.meta_data.get('bllog'):
                log_folder_name = self.meta_data.get('bllog')
                log_folder = open(log_folder_name, 'a')
                log_folder.write(os.path.abspath(filename) + '\n')
                log_folder.close()

    def _clear_data_objects(self):
        self.index["out_data"] = {}
        self.index["in_data"] = {}

    def _merge_out_data_to_in(self):
        for key, data in self.index["out_data"].iteritems():
            if data.remove is False:
                self.index['in_data'][key] = data
        self.index["out_data"] = {}

    def _finalise_experiment_for_current_plugin(self):
        finalise = {}
        # populate nexus file with out_dataset information and determine which
        # datasets to remove from the framework.
        finalise['remove'] = []
        finalise['keep'] = []

        for key, data in self.index['out_data'].iteritems():
            if data.remove is True:
                finalise['remove'].append(data)
            else:
                finalise['keep'].append(data)

        # find in datasets to replace
        finalise['replace'] = []
        for out_name in self.index['out_data'].keys():
            if out_name in self.index['in_data'].keys():
                finalise['replace'].append(self.index['in_data'][out_name])

        return finalise

    def _reorganise_datasets(self, finalise):
        # unreplicate replicated in_datasets
        self.__unreplicate_data()

        # delete all datasets for removal
        for data in finalise['remove']:
            del self.index["out_data"][data.data_info.get('name')]

        # Add remaining output datasets to input datasets
        for name, data in self.index['out_data'].items():
            data.get_preview().set_preview([])
            self.index["in_data"][name] = copy.deepcopy(data)
        self.index['out_data'] = {}

    def __unreplicate_data(self):
        in_data_list = self.index['in_data']
        from savu.data.data_structures.data_types.replicate import Replicate
        for in_data in in_data_list.values():
            if isinstance(in_data.data, Replicate):
                in_data.data = in_data.data.reset()

    def _set_all_datasets(self, name):
        data_names = []
        for key in self.index["in_data"].keys():
            if 'itr_clone' not in key:
                data_names.append(key)
        return data_names

    def _barrier(self, communicator=MPI.COMM_WORLD, msg=''):
        comm_dict = {'comm': communicator}
        if self.meta_data.get('mpi') is True:
            logging.debug("Barrier %d: %d processes expected: %s",
                          self._barrier_count, communicator.size, msg)
            comm_dict['comm'].barrier()
        self._barrier_count += 1

    def log(self, log_tag, log_level=logging.DEBUG):
        """
        Log the contents of the experiment at the specified level
        """
        logging.log(log_level, "Experimental Parameters for %s", log_tag)
        for key, value in self.index["in_data"].iteritems():
            logging.log(log_level, "in data (%s) shape = %s", key,
                        value.get_shape())
        for key, value in self.index["in_data"].iteritems():
            logging.log(log_level, "out data (%s) shape = %s", key,
                        value.get_shape())
Exemplo n.º 20
0
class PluginData(object):
    """ The PluginData class contains plugin specific information about a Data
    object for the duration of a plugin.  An instance of the class is
    encapsulated inside the Data object during the plugin run
    """
    def __init__(self, data_obj, plugin=None):
        self.data_obj = data_obj
        self._preview = None
        self.data_obj._set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = None
        self.pad_dict = None
        self.shape = None
        self.shape_transfer = None
        self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []
        self._plugin = plugin
        self.fixed_dims = True
        self.split = None
        self.boundary_padding = None
        self.no_squeeze = False
        self.pre_tuning_shape = None
        self._frame_limit = None

    def _get_preview(self):
        return self._preview

    def get_total_frames(self):
        """ Get the total number of frames to process (all MPI processes).

        :returns: Number of frames
        :rtype: int
        """
        temp = 1
        slice_dir = \
            self.data_obj.get_data_patterns()[
                self.get_pattern_name()]["slice_dims"]
        for tslice in slice_dir:
            temp *= self.data_obj.get_shape()[tslice]
        return temp

    def __set_pattern(self, name):
        """ Set the pattern related information in the meta data dict.
        """
        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set("name", name)
        self.meta_data.set("core_dims", pattern['core_dims'])
        self.__set_slice_dimensions()

    def get_pattern_name(self):
        """ Get the pattern name.

        :returns: the pattern name
        :rtype: str
        """
        try:
            name = self.meta_data.get("name")
            return name
        except KeyError:
            raise Exception("The pattern name has not been set.")

    def get_pattern(self):
        """ Get the current pattern.

        :returns: dict of the pattern name against the pattern.
        :rtype: dict
        """
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def __set_shape(self):
        """ Set the shape of the plugin data processing chunk.
        """
        core_dir = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()
        dirs = list(set(core_dir + (slice_dir[0], )))
        slice_idx = dirs.index(slice_dir[0])
        dshape = self.data_obj.get_shape()
        shape = []
        for core in set(core_dir):
            shape.append(dshape[core])
        self.__set_core_shape(tuple(shape))

        mfp = self._get_max_frames_process()
        if mfp > 1 or self._get_no_squeeze():
            shape.insert(slice_idx, mfp)
        self.shape = tuple(shape)

    def _set_shape_transfer(self, slice_size):
        dshape = self.data_obj.get_shape()
        shape_before_tuning = self._get_shape_before_tuning()
        add = [1] * (len(dshape) - len(shape_before_tuning))
        slice_size = slice_size + add

        core_dir = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()
        shape = [None] * len(dshape)
        for dim in core_dir:
            shape[dim] = dshape[dim]
        i = 0
        for dim in slice_dir:
            shape[dim] = slice_size[i]
            i += 1
        self.shape_transfer = tuple(shape)

    def get_shape(self):
        """ Get the shape of the data (without padding) that is passed to the
        plugin process_frames method.
        """
        return self.shape

    def _set_padded_shape(self):
        pass

    def get_padded_shape(self):
        """ Get the shape of the data (with padding) that is passed to the
        plugin process_frames method.
        """
        return self.shape

    def get_shape_transfer(self):
        """ Get the shape of the plugin data to be transferred each time.
        """
        return self.shape_transfer

    def __set_core_shape(self, shape):
        """ Set the core shape to hold only the shape of the core dimensions
        """
        self.core_shape = shape

    def get_core_shape(self):
        """ Get the shape of the core dimensions only.

        :returns: shape of core dimensions
        :rtype: tuple
        """
        return self.core_shape

    def _set_shape_before_tuning(self, shape):
        """ Set the shape of the full dataset used during each run of the \
        plugin (i.e. ignore extra dimensions due to parameter tuning). """
        self.pre_tuning_shape = shape

    def _get_shape_before_tuning(self):
        """ Return the shape of the full dataset used during each run of the \
        plugin (i.e. ignore extra dimensions due to parameter tuning). """
        return self.pre_tuning_shape if self.pre_tuning_shape else\
            self.data_obj.get_shape()

    def __check_dimensions(self, indices, core_dir, slice_dir, nDims):
        if len(indices) != len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")

        if (len(core_dir) + len(slice_dir)) != nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def __set_slice_dimensions(self):
        """ Set the slice dimensions in the pluginData meta data dictionary.
        """
        slice_dirs = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['slice_dims']
        self.meta_data.set('slice_dims', slice_dirs)

    def get_slice_dimension(self):
        """
        Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()[0]
        return list(set(core_dirs + (slice_dir, ))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """
        Return the dimension of the data in the plugin that has the specified
        axis label.
        """
        label_dim = self.data_obj.get_data_dimension_by_axis_label(
            label, contains=contains)
        plugin_dims = self.data_obj.get_core_dimensions()
        if self._get_max_frames_process() > 1 or self.max_frames == 'multiple':
            plugin_dims += (self.get_slice_dimension(), )
        return list(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):
        """
        Reorder the slice dimensions.  The fastest changing slice dimension
        will always be the first one stated in the pattern key ``slice_dir``.
        The input param is a tuple stating the desired order of slicing
        dimensions relative to the current order.
        """
        slice_dirs = self.get_slice_directions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specifed.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        self.get_current_pattern()['slice_dir'] = new_slice_dirs

    def get_core_dimensions(self):
        """
        Return the position of the core dimensions in relation to the data
        handed to the plugin.
        """
        core_dims = self.data_obj.get_core_dimensions()
        first_slice_dim = (self.data_obj.get_slice_dimensions()[0], )
        plugin_dims = np.sort(core_dims + first_slice_dim)
        return np.searchsorted(plugin_dims, np.sort(core_dims))

    def set_fixed_dimensions(self, dims, values):
        """ Fix a data direction to the index in values list.

        :param list(int) dims: Directions to fix
        :param list(int) value: Index of fixed directions
        """
        slice_dirs = self.data_obj.get_slice_dimensions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set("fixed_dimensions", dims)
        self.meta_data.set("fixed_dimensions_values", values)
        self.__set_slice_dimensions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        self.__set_shape()

    def _get_fixed_dimensions(self):
        """ Get the fixed data directions and their indices

        :returns: Fixed directions and their associated values
        :rtype: list(list(int), list(int))
        """
        fixed = []
        values = []
        if 'fixed_dimensions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get("fixed_dimensions")
            values = self.meta_data.get("fixed_dimensions_values")
        return [fixed, values]

    def _get_data_slice_list(self, plist):
        """ Convert a plugin data slice list to a slice list for the whole
        dataset, i.e. add in any missing dimensions.
        """
        nDims = len(self.get_shape())
        all_dims = self.get_core_dimensions() + self.get_slice_dimension()
        extra_dims = all_dims[nDims:]
        dlist = list(plist)
        for i in extra_dims:
            dlist.insert(i, slice(None))
        return tuple(dlist)

    def _get_max_frames_process(self):
        """ Get the number of frames to process for each run of process_frames.

        If the number of frames is not divisible by the previewing ``chunk``
        value then amend the number of frames to gcd(frames, chunk)

        :returns: Number of frames to process
        :rtype: int
        """
        if self._plugin and self._plugin.chunk > 1:
            frame_chunk = self.meta_data.get("max_frames_process")
            chunk = self.data_obj.get_preview().get_starts_stops_steps(
                key='chunks')[self.get_slice_directions()[0]]
            self.meta_data.set('max_frames_process', gcd(frame_chunk, chunk))
        return self.meta_data.get("max_frames_process")

    def _get_max_frames_transfer(self):
        """ Get the number of frames to transfer for each run of
        process_frames. """
        return self.meta_data.get('max_frames_transfer')

    def _set_no_squeeze(self):
        self.no_squeeze = True

    def _get_no_squeeze(self):
        return self.no_squeeze

    def _get_max_frames_parameters(self):
        fixed, _ = self._get_fixed_dimensions()
        sdir = \
            [s for s in self.data_obj.get_slice_dimensions() if s not in fixed]
        shape = self.data_obj.get_shape()
        shape_before_tuning = self._get_shape_before_tuning()

        diff = len(shape) - len(shape_before_tuning)
        if diff:
            shape = shape_before_tuning
            sdir = sdir[:-diff]

        frames = np.prod([shape[d] for d in sdir])
        base_names = [p.__name__ for p in self._plugin.__class__.__bases__]
        processes = self.data_obj.exp.meta_data.get('processes')

        if 'GpuPlugin' in base_names:
            n_procs = len([n for n in processes if 'GPU' in n])
        else:
            n_procs = len(processes)

        f_per_p = np.ceil(frames / n_procs)
        params_dict = {
            'shape': shape,
            'sdir': sdir,
            'total_frames': frames,
            'mpi_procs': n_procs,
            'frames_per_process': f_per_p
        }
        return params_dict

    def __log_max_frames(self, mft, mfp, check=True):
        logging.debug("Setting max frames transfer for plugin %s to %d" %
                      (self._plugin, mft))
        logging.debug("Setting max frames process for plugin %s to %d" %
                      (self._plugin, mfp))
        self.meta_data.set('max_frames_process', mfp)
        if check:
            self.__check_distribution(mft)
        # (((total_frames/mft)/mpi_procs) % 1)

    def __check_distribution(self, mft):
        self.params = self._get_max_frames_parameters()
        warn_threshold = 0.85
        nprocs = self.params['mpi_procs']
        nframes = self.params['total_frames']
        temp = (((nframes / mft) / float(nprocs)) % 1)
        if temp != 0.0 and temp < warn_threshold:
            logging.warning(
                'UNEVEN FRAME DISTRIBUTION: shape %s, nframes %s '
                'sdir %s, nprocs %s', self.params['shape'], nframes,
                self.params['sdir'], nprocs)

    def _set_padding_dict(self):
        if self.padding and not isinstance(self.padding, Padding):
            self.pad_dict = copy.deepcopy(self.padding)
            self.padding = Padding(self)
            for key in self.pad_dict.keys():
                getattr(self.padding, key)(self.pad_dict[key])

    def plugin_data_setup(self, pattern, nFrames, split=None):
        """ Setup the PluginData object.

        # add more information into here via a decorator!
        :param str pattern: A pattern name
        :param int nFrames: How many frames to process at a time.  Choose from\
            'single', 'multiple', 'fixed_multiple' or an integer (an integer \
            should only ever be passed in exceptional circumstances)
        """
        self.__set_pattern(pattern)
        chunks = \
            self.data_obj.get_preview().get_starts_stops_steps(key='chunks')

        if isinstance(nFrames, list):
            nFrames, self._frame_limit = nFrames

        self.__set_max_frames(nFrames)
        mft = self.meta_data.get('max_frames_transfer')
        if self._plugin and mft \
                and (chunks[self.data_obj.get_slice_dimensions()[0]] % mft):
            self._plugin.chunk = True
        self.__set_shape()
        self.split = split

    def __set_max_frames(self, nFrames):
        self.max_frames = nFrames
        self.__perform_checks(nFrames)
        td = self.data_obj._get_transport_data()
        mft, mft_shape = td._calc_max_frames_transfer(nFrames)
        self.meta_data.set('max_frames_transfer', mft)
        if mft:
            self._set_shape_transfer(mft_shape)
        mfp = td._calc_max_frames_process(nFrames)
        self.meta_data.set('max_frames_process', mfp)
        self.__log_max_frames(mft, mfp)

        # Retain the shape if the first slice dimension has length 1
        if mfp == 1 and nFrames == 'multiple':
            self._set_no_squeeze()

    def __perform_checks(self, nFrames):
        options = ['single', 'multiple']
        if not isinstance(nFrames, int) and nFrames not in options:
            e_str = ("The value of nFrames is not recognised. Please choose "
                     "from 'single' and 'multiple' (or an integer in "
                     "exceptional circumstances).")
            raise Exception(e_str)

    def get_frame_limit(self):
        return self._frame_limit
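The gcd amendment performed in _get_max_frames_process above can be checked in isolation. A short sketch with hypothetical numbers (math.gcd assumed; older Python 2 code imported gcd from fractions):

from math import gcd

max_frames_process = 8  # frames requested per process_frames call
preview_chunk = 12      # previewing 'chunk' value on the slice dimension

# 12 is not divisible by 8, so the frame count drops to gcd(8, 12) = 4,
# which divides both the request and the chunk boundary.
if preview_chunk > 1:
    max_frames_process = gcd(max_frames_process, preview_chunk)
print(max_frames_process)  # 4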
Exemplo n.º 21
0
class Checkpointing(object):
    """ Contains all checkpointing associated methods.
    """

    def __init__(self, exp, name='Checkpointing'):
        self._exp = exp
        self._h5 = Hdf5Utils(self._exp)
        self._filename = '_checkpoint.h5'
        self._file = None
        self._start_values = (0, 0, 0)
        self._completed_plugins = 0
        self._level = None
        self._proc_idx = 0
        self._trans_idx = 0
        self._comm = None
        self._timer = None
        self._set_timer()
        self.meta_data = MetaData()

    def _initialise(self, comm):
        """ Create a new checkpoint file """
        with self._h5._open_backing_h5(self._file, 'a', mpi=False) as f:
            self._create_dataset(f, 'transfer_idx', 0)
            self._create_dataset(f, 'process_idx', 0)
            self._create_dataset(
                    f, 'completed_plugins', self._completed_plugins)
        msg = "%s initialise." % self.__class__.__name__
        self._exp._barrier(communicator=comm, msg=msg)

    def _create_dataset(self, f, name, val):
        if name in f.keys():
            f[name][...] = val
        else:
            f.create_dataset(name, data=val, dtype=np.int16)

    def __set_checkpoint_info(self):
        mData = self._exp.meta_data.get
        proc = 'process%d' % mData('process')
        self._folder = os.path.join(mData('out_path'), 'checkpoint')
        self._file = os.path.join(self._folder, proc + self._filename)

        if self._exp.meta_data.get('process') == 0:
            if not os.path.exists(self._folder):
                os.makedirs(self._folder)
        self._exp._barrier(msg='Creating checkpoint folder.')

    def _set_checkpoint_info_from_file(self, level):
        self._level = level
        self.__set_checkpoint_info()
        self.__does_file_exist(self._file, level)

        with self._h5._open_backing_h5(self._file, 'r', mpi=False) as f:
            self._completed_plugins = \
                f['completed_plugins'][...] if 'completed_plugins' in f else 0
            self._proc_idx = f['process_idx'][...] if 'process_idx' in f and \
                level == 'subplugin' else 0
            self._trans_idx = f['transfer_idx'][...] if 'transfer_idx' in f \
                and level == 'subplugin' else 0
            # for testing
            self.__set_start_values(
                    self._completed_plugins, self._trans_idx, self._proc_idx)
            self.__set_dataset_metadata(f, 'in_data')
            self.__set_dataset_metadata(f, 'out_data')

        self.__load_data()
        msg = "%s _set_checkpoint_info_from_file" % self.__class__.__name__
        self._exp._barrier(msg=msg)

    def __does_file_exist(self, thefile, level):
        if not os.path.exists(thefile):
            if level == 'plugin':
                proc0 = os.path.join(self._folder, 'process0' + self._filename)
                self.__does_file_exist(proc0, None)
                copyfile(proc0, self._file)
                return
            raise Exception("No checkpoint file found.")

    def __set_dataset_metadata(self, f, dtype):
        self.meta_data.set(dtype, {})
        if dtype not in f.keys():
            return
        entry = f[dtype]
        for name, gp in entry.items():
            data_entry = gp.require_group('meta_data')
            for key, value in data_entry.items():
                self.meta_data.set([dtype, name, key], value[key][...])

    def _get_dataset_metadata(self, dtype, name):
        return self._data_meta_data(dtype)

    def set_completed_plugins(self, n):
        self._completed_plugins = n

    def __load_data(self):
        self._exp.meta_data.set('checkpoint_loader', True)
        temp = self._exp.meta_data.get('data_file')
        nxsfile = self._exp.meta_data.get('nxs_filename')
        self._exp.meta_data.set('data_file', nxsfile)
        pid = 'savu.plugins.loaders.savu_nexus_loader'
        pu.plugin_loader(self._exp, {'id': pid, 'data': {}})
        self._exp.meta_data.delete('checkpoint_loader')
        self._exp.meta_data.set('data_file', temp)

    def output_plugin_checkpoint(self):
        self._completed_plugins += 1
        self.__write_plugin_checkpoint()
        self._reset_indices()

    def get_checkpoint_plugin(self):
        checkpoint_flag = self._exp.meta_data.get('checkpoint')
        if not checkpoint_flag:
            self.__set_checkpoint_info()
            self._initialise(MPI.COMM_WORLD)
        else:
            self._set_checkpoint_info_from_file(checkpoint_flag)
        return self._completed_plugins

    def is_time_to_checkpoint(self, transport, ti, pi):
        interval = self._exp.meta_data.get(
                ['system_params', 'checkpoint_interval'])
        end = time.time()
        if (end - self._get_timer()) > interval:
            self.__write_subplugin_checkpoint(ti, pi)
            self._set_timer()
            transport._transport_checkpoint()
            return transport._transport_kill_signal()
        return False

    def _get_checkpoint_params(self):
        return self._level, self._completed_plugins

    def __write_subplugin_checkpoint(self, ti, pi):
        with self._h5._open_backing_h5(self._file, 'a', mpi=False) as f:
            f['transfer_idx'][...] = ti
            f['process_idx'][...] = pi

    def __write_plugin_checkpoint(self):
        with self._h5._open_backing_h5(self._file, 'a', mpi=False) as f:
            f['completed_plugins'][...] = self._completed_plugins
            f['transfer_idx'][...] = 0
            f['process_idx'][...] = 0

    def _reset_indices(self):
        self._trans_idx = 0
        self._proc_idx = 0

    def get_trans_idx(self):
        return self._trans_idx

    def get_proc_idx(self):
        return self._proc_idx

    def get_level(self):
        return self._level

    def _set_timer(self):
        self._timer = time.time()

    def _get_timer(self):
        return self._timer

    def __set_start_values(self, v1, v2, v3):
        self._start_values = (copy.copy(v1), copy.copy(v2), copy.copy(v3))

    def get_start_values(self):
        return self._start_values
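A standalone sketch of the checkpoint file layout maintained by the class above (plain h5py without MPI; the filename is illustrative): three small integer datasets created once and overwritten in place as processing advances.

import h5py
import numpy as np

def create_or_update(f, name, val):
    # mirrors _create_dataset: overwrite in place if present, else create
    if name in f:
        f[name][...] = val
    else:
        f.create_dataset(name, data=val, dtype=np.int16)

with h5py.File('process0_checkpoint.h5', 'a') as f:
    for name in ('completed_plugins', 'transfer_idx', 'process_idx'):
        create_or_update(f, name, 0)
    # a plugin checkpoint bumps the plugin counter and resets the
    # sub-plugin indices, as in __write_plugin_checkpoint above
    f['completed_plugins'][...] = 1
    f['transfer_idx'][...] = 0
    f['process_idx'][...] = 0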
Exemplo n.º 22
0
class Experiment(object):
    """
    One instance of this class is created at the beginning of the 
    processing chain and remains until the end.  It holds the current data
    object and a dictionary containing all metadata.
    """

    def __init__(self, options):
        self.meta_data = MetaData(options)
        self.meta_data_setup(options["process_file"])
        self.index = {"in_data": {}, "out_data": {}}

    def meta_data_setup(self, process_file):
        self.meta_data.load_experiment_collection()
        self.meta_data.plugin_list = PluginList()
        self.meta_data.plugin_list.populate_plugin_list(process_file)

    def create_data_object(self, dtype, name, bases=None):
        bases = list(bases) if bases else []
        try:
            self.index[dtype][name]
        except KeyError:
            self.index[dtype][name] = Data(name)
            data_obj = self.index[dtype][name]
            bases.append(data_obj.get_transport_data(
                self.meta_data.get_meta_data("transport")))
            data_obj.add_base_classes(bases)
        return self.index[dtype][name]


    def set_nxs_filename(self):
        name = list(self.index["in_data"])[0]
        filename = os.path.basename(
            self.index["in_data"][name].backing_file.filename)
        filename = os.path.splitext(filename)[0]
        filename = os.path.join(self.meta_data.get_meta_data("out_path"),
                                "%s_processed_%s.nxs" % (filename,
                                time.strftime("%Y%m%d%H%M%S")))
        self.meta_data.set_meta_data("nxs_filename", filename)

    def clear_data_objects(self):
        self.index["out_data"] = {}
        self.index["in_data"] = {}

    def clear_out_data_objects(self):
        self.index["out_data"] = {}

    def set_out_data_to_in(self):
        self.index["in_data"] = self.index["out_data"]
        self.index["out_data"] = {}

    def barrier(self):
        if self.meta_data.get_meta_data('mpi') is True:
            logging.debug("About to hit a barrier")
            MPI.COMM_WORLD.Barrier()
            logging.debug("Past the barrier")

    def log(self, log_tag, log_level=logging.DEBUG):
        """
        Log the contents of the experiment at the specified level
        """
        logging.log(log_level, "Experimental Parameters for %s", log_tag)
        for key, value in self.index["in_data"].iteritems():
            logging.log(log_level, "in data (%s) shape = %s", key,
                        value.get_shape())
        for key, value in self.index["in_data"].iteritems():
            logging.log(log_level, "out data (%s) shape = %s", key,
                        value.get_shape())
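For reference, the guarded barrier in Experiment.barrier above reduces to the following mpi4py pattern (a sketch; run under mpirun, with the flag standing in for the 'mpi' metadata entry):

from mpi4py import MPI

mpi_enabled = MPI.COMM_WORLD.size > 1  # stands in for the 'mpi' option
if mpi_enabled:
    MPI.COMM_WORLD.Barrier()  # every rank blocks until all ranks arrive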
Exemplo n.º 23
0
class Data(DataCreate):
    """The Data class dynamically inherits from transport specific data class
    and holds the data array, along with associated information.
    """

    def __init__(self, name, exp):
        super(Data, self).__init__(name)
        self.meta_data = MetaData()
        self.pattern_list = self.__get_available_pattern_list()
        self.data_info = MetaData()
        self.__initialise_data_info(name)
        self._preview = Preview(self)
        self.exp = exp
        self.group_name = None
        self.group = None
        self._plugin_data_obj = None
        self.tomo_raw_obj = None
        self.backing_file = None
        self.data = None
        self.next_shape = None
        self.orig_shape = None

    def __initialise_data_info(self, name):
        """ Initialise entries in the data_info meta data.
        """
        self.data_info.set_meta_data('name', name)
        self.data_info.set_meta_data('data_patterns', {})
        self.data_info.set_meta_data('shape', None)
        self.data_info.set_meta_data('nDims', None)

    def _set_plugin_data(self, plugin_data_obj):
        """ Encapsulate a PluginData object.
        """
        self._plugin_data_obj = plugin_data_obj

    def _clear_plugin_data(self):
        """ Set encapsulated PluginData object to None.
        """
        self._plugin_data_obj = None

    def _get_plugin_data(self):
        """ Get encapsulated PluginData object.
        """
        if self._plugin_data_obj is not None:
            return self._plugin_data_obj
        else:
            raise Exception("There is no PluginData object associated with "
                            "the Data object.")

    def get_preview(self):
        """ Get the Preview instance associated with the data object
        """
        return self._preview

    def _get_transport_data(self):
        """ Import the data transport mechanism

        :returns: instance of data transport
        :rtype: transport_data
        """
        transport = self.exp.meta_data.get_meta_data("transport")
        transport_data = "savu.data.transport_data." + transport + \
                         "_transport_data"
        return cu.import_class(transport_data)

    def __deepcopy__(self, memo):
        """ Copy the data object.
        """
        name = self.data_info.get_meta_data('name')
        return dsu._deepcopy_data_object(self, Data(name, self.exp))

    def get_data_patterns(self):
        """ Get data patterns associated with this data object.

        :returns: A dictionary of associated patterns.
        :rtype: dict
        """
        return self.data_info.get_meta_data('data_patterns')

    def set_shape(self, shape):
        """ Set the dataset shape.
        """
        self.data_info.set_meta_data('shape', shape)
        self.__check_dims()

    def set_original_shape(self, shape):
        self.orig_shape = shape
        self.set_shape(shape)

    def get_shape(self):
        """ Get the dataset shape

        :returns: data shape
        :rtype: tuple
        """
        shape = self.data_info.get_meta_data('shape')
        return shape

    def __check_dims(self):
        """ Check the ``shape`` and ``nDims`` entries in the data_info
        meta_data dictionary are equal.
        """
        nDims = self.data_info.get_meta_data("nDims")
        shape = self.data_info.get_meta_data('shape')
        if nDims:
            if len(shape) != nDims:
                error_msg = ("The number of axis labels, %d, does not "
                             "coincide with the number of data "
                             "dimensions %d." % (nDims, len(shape)))
                raise Exception(error_msg)

    def get_name(self):
        """ Get data name.

        :returns: the name associated with the dataset
        :rtype: str
        """
        return self.data_info.get_meta_data('name')

    def __get_available_pattern_list(self):
        """ Get a list of ALL pattern names that are currently allowed in the
        framework.
        """
        pattern_list = dsu.get_available_pattern_types()
        return pattern_list

    def add_pattern(self, dtype, **kwargs):
        """ Add a pattern.

        :params str dtype: The *type* of pattern to add, which can be anything
            from :data:`savu.data.data_structures.utils.pattern_list`
        :keyword tuple core_dir: Dimension indices of core dimensions
        :keyword tuple slice_dir: Dimension indices of slice dimensions
        """
        if dtype in self.pattern_list:
            nDims = 0
            for args in kwargs:
                nDims += len(kwargs[args])
                self.data_info.set_meta_data(['data_patterns', dtype, args],
                                             kwargs[args])
            
            self.__convert_pattern_directions(dtype)
            if self.get_shape():
                diff = len(self.get_shape()) - nDims
                if diff:
                    pattern = {dtype: self.get_data_patterns()[dtype]}
                    self._add_extra_dims_to_patterns(pattern)
                    nDims += diff
            try:
                if nDims != self.data_info.get_meta_data("nDims"):
                    actualDims = self.data_info.get_meta_data('nDims')
                    err_msg = ("The pattern %s has an incorrect number of "
                               "dimensions: %d required but %d specified."
                               % (dtype, actualDims, nDims))
                    raise Exception(err_msg)
            except KeyError:
                self.data_info.set_meta_data('nDims', nDims)
        else:
            raise Exception("The data pattern '%s'does not exist. Please "
                            "choose from the following list: \n'%s'",
                            dtype, str(self.pattern_list))

    def add_volume_patterns(self, x, y, z):
        """ Adds 3D volume patterns

        :params int x: dimension to be associated with x-axis
        :params int y: dimension to be associated with y-axis
        :params int z: dimension to be associated with z-axis
        """
        self.add_pattern("VOLUME_YZ", **self.__get_dirs_for_volume(y, z, x))
        self.add_pattern("VOLUME_XZ", **self.__get_dirs_for_volume(x, z, y))
        self.add_pattern("VOLUME_XY", **self.__get_dirs_for_volume(x, y, z))

    def __get_dirs_for_volume(self, dim1, dim2, sdir):
        """ Calculate core_dir and slice_dir for a 3D volume pattern.
        """
        all_dims = range(len(self.get_shape()))
        vol_dict = {}
        vol_dict['core_dir'] = (dim1, dim2)
        slice_dir = [sdir]
        # *** need to add this for other patterns
        for ddir in all_dims:
            if ddir not in [dim1, dim2, sdir]:
                slice_dir.append(ddir)
        vol_dict['slice_dir'] = tuple(slice_dir)
        return vol_dict

    def set_axis_labels(self, *args):
        """ Set the axis labels associated with each data dimension.

        :arg str: Each arg should be of the form ``name.unit``. If ``name`` is\
        a data_obj.meta_data entry, it will be output to the final .nxs file.
        """
        self.data_info.set_meta_data('nDims', len(args))
        axis_labels = []
        for arg in args:
            try:
                axis = arg.split('.')
                axis_labels.append({axis[0]: axis[1]})
            except IndexError:
                # arg was not of the form 'name.unit'; this may be an error
                pass
        self.data_info.set_meta_data('axis_labels', axis_labels)

    def get_axis_labels(self):
        """ Get axis labels.

        :returns: Axis labels
        :rtype: list(dict)
        """
        return self.data_info.get_meta_data('axis_labels')

    def find_axis_label_dimension(self, name, contains=False):
        """ Get the dimension of the data associated with a particular
        axis_label.

        :param str name: The name of the axis_label
        :keyword bool contains: Set this flag to true if the name is only part
            of the axis_label name
        :returns: The associated axis number
        :rtype: int
        """
        axis_labels = self.data_info.get_meta_data('axis_labels')
        for i in range(len(axis_labels)):
            if contains is True:
                for names in axis_labels[i].keys():
                    if name in names:
                        return i
            else:
                if name in axis_labels[i].keys():
                    return i
        raise Exception("Cannot find the specifed axis label.")

    def _finalise_patterns(self):
        """ Adds a main axis (fastest changing) to SINOGRAM and PROJECTON
        patterns.
        """
        check = 0
        check += self.__check_pattern('SINOGRAM')
        check += self.__check_pattern('PROJECTION')

        if check == 2 and len(self.get_shape()) > 2:
            self.__set_main_axis('SINOGRAM')
            self.__set_main_axis('PROJECTION')
        elif check == 1:
            pass

    def __check_pattern(self, pattern_name):
        """ Check if a pattern exists.
        """
        patterns = self.get_data_patterns()
        try:
            patterns[pattern_name]
        except KeyError:
            return 0
        return 1

    def __convert_pattern_directions(self, dtype):
        """ Replace negative indices in pattern kwargs.
        """
        pattern = self.get_data_patterns()[dtype]
        if 'main_dir' in pattern.keys():
            del pattern['main_dir']

        nDims = sum([len(i) for i in pattern.values()])
        for p in pattern:
            ddirs = pattern[p]
            pattern[p] = self.non_negative_directions(ddirs, nDims)

    def non_negative_directions(self, ddirs, nDims):
        """ Replace negative indexing values with positive counterparts.

        :params tuple(int) ddirs: data dimension indices
        :params int nDims: The number of data dimensions
        :returns: non-negative data dimension indices
        :rtype: tuple(int)
        """
        index = [i for i in range(len(ddirs)) if ddirs[i] < 0]
        list_ddirs = list(ddirs)
        for i in index:
            list_ddirs[i] = nDims + ddirs[i]
        return tuple(list_ddirs)

    def __set_main_axis(self, pname):
        """ Set the ``main_dir`` pattern kwarg to the fastest changing
        dimension
        """
        patterns = self.get_data_patterns()
        n1 = 'PROJECTION' if pname == 'SINOGRAM' else 'SINOGRAM'
        d1 = patterns[n1]['core_dir']
        d2 = patterns[pname]['slice_dir']
        tdir = set(d1).intersection(set(d2))

        # this is required when a single sinogram exists in the mm case, and a
        # dimension is added via parameter tuning.
        if not tdir:
            tdir = [d2[0]]

        self.data_info.set_meta_data(['data_patterns', pname, 'main_dir'],
                                     list(tdir)[0])

    def get_axis_label_keys(self):
        """ Get axis_label names

        :returns: A list containing associated axis names for each dimension
        :rtype: list(str)
        """
        axis_labels = self.data_info.get_meta_data('axis_labels')
        axis_label_keys = []
        for labels in axis_labels:
            for key in labels.keys():
                axis_label_keys.append(key)
        return axis_label_keys

    def _get_current_and_next_patterns(self, datasets_lists):
        """ Get the current and next patterns associated with a dataset
        throughout the processing chain.
        """
        current_datasets = datasets_lists[0]
        patterns_list = []
        for current_data in current_datasets['out_datasets']:
            current_name = current_data['name']
            current_pattern = current_data['pattern']
            next_pattern = self.__find_next_pattern(datasets_lists[1:],
                                                    current_name)
            patterns_list.append({'current': current_pattern,
                                  'next': next_pattern})
        self.exp.meta_data.set_meta_data('current_and_next', patterns_list)

    def __find_next_pattern(self, datasets_lists, current_name):
        next_pattern = []
        for next_data_list in datasets_lists:
            for next_data in next_data_list['in_datasets']:
                if next_data['name'] == current_name:
                    next_pattern = next_data['pattern']
                    return next_pattern
        return next_pattern

    def get_slice_directions(self):
        """ Get pattern slice_dir of pattern currently associated with the
        dataset (if any).

        :returns: the slicing dimensions.
        :rtype: tuple(int)
        """
        return self._get_plugin_data().get_slice_directions()
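A quick usage example of non_negative_directions above (a standalone copy of the same logic):

def non_negative_directions(ddirs, nDims):
    # map each negative index onto its positive counterpart
    return tuple(d if d >= 0 else nDims + d for d in ddirs)

# for a 3D dataset, dimension -1 is dimension 2
assert non_negative_directions((0, -1), 3) == (0, 2)
assert non_negative_directions((1, 2), 3) == (1, 2)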
Exemplo n.º 24
0
class PluginData(object):
    """ The PluginData class contains plugin specific information about a Data
    object for the duration of a plugin.  An instance of the class is
    encapsulated inside the Data object during the plugin run
    """

    def __init__(self, data_obj, plugin=None):
        self.data_obj = data_obj
        self._preview = None
        self.data_obj._set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = None
        self.pad_dict = None
        self.shape = None
        self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []
        self._plugin = plugin
        self.fixed_dims = True
        self.split = None
        self.boundary_padding = None
        self.no_squeeze = False
        self.pre_tuning_shape = None
        self._frame_limit = None

    def _get_preview(self):
        return self._preview

    def get_total_frames(self):
        """ Get the total number of frames to process (all MPI processes).

        :returns: Number of frames
        :rtype: int
        """
        temp = 1
        slice_dir = \
            self.data_obj.get_data_patterns()[
                self.get_pattern_name()]["slice_dims"]
        for tslice in slice_dir:
            temp *= self.data_obj.get_shape()[tslice]
        return temp

    def __set_pattern(self, name):
        """ Set the pattern related information in the meta data dict.
        """
        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set("name", name)
        self.meta_data.set("core_dims", pattern['core_dims'])
        self.__set_slice_dimensions()

    def get_pattern_name(self):
        """ Get the pattern name.

        :returns: the pattern name
        :rtype: str
        """
        try:
            name = self.meta_data.get("name")
            return name
        except KeyError:
            raise Exception("The pattern name has not been set.")

    def get_pattern(self):
        """ Get the current pattern.

        :returns: dict of the pattern name against the pattern.
        :rtype: dict
        """
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def __set_shape(self):
        """ Set the shape of the plugin data processing chunk.
        """
        core_dir = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()
        dirs = list(set(core_dir + (slice_dir[0],)))
        slice_idx = dirs.index(slice_dir[0])
        dshape = self.data_obj.get_shape()
        shape = []
        for core in set(core_dir):
            shape.append(dshape[core])
        self.__set_core_shape(tuple(shape))

        mfp = self._get_max_frames_process()
        if mfp > 1 or self._get_no_squeeze():
            shape.insert(slice_idx, mfp)
        self.shape = tuple(shape)

    def _set_shape_transfer(self, slice_size):
        dshape = self.data_obj.get_shape()
        shape_before_tuning = self._get_shape_before_tuning()
        add = [1]*(len(dshape) - len(shape_before_tuning))
        slice_size = slice_size + add

        core_dir = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()
        shape = [None]*len(dshape)
        for dim in core_dir:
            shape[dim] = dshape[dim]
        i = 0
        for dim in slice_dir:
            shape[dim] = slice_size[i]
            i += 1            
        return tuple(shape)

    def __get_slice_size(self, mft):
        """ Calculate the number of frames transfer in each dimension given
            mft. """
        dshape = list(self.data_obj.get_shape())

        if 'fixed_dimensions' in self.meta_data.get_dictionary().keys():
            fixed_dims = self.meta_data.get('fixed_dimensions')
            for d in fixed_dims:
                dshape[d] = 1

        dshape = [dshape[i] for i in self.meta_data.get('slice_dims')]
        size_list = [1]*len(dshape)
        i = 0

        while mft > 1:
            size_list[i] = min(dshape[i], mft)
            mft -= np.prod(size_list) if np.prod(size_list) > 1 else 0
            i += 1

        self.meta_data.set('size_list', size_list)
        return size_list

    def set_bytes_per_frame(self):
        """ Return the size of a single frame in bytes. """
        nBytes = self.data_obj.get_itemsize()
        dims = list(self.get_pattern().values())[0]['core_dims']
        frame_shape = [self.data_obj.get_shape()[d] for d in dims]
        b_per_f = np.prod(frame_shape)*nBytes
        return frame_shape, b_per_f

    def get_shape(self):
        """ Get the shape of the data (without padding) that is passed to the
        plugin process_frames method.
        """
        return self.shape

    def _set_padded_shape(self):
        pass

    def get_padded_shape(self):
        """ Get the shape of the data (with padding) that is passed to the
        plugin process_frames method.
        """
        return self.shape

    def get_shape_transfer(self):
        """ Get the shape of the plugin data to be transferred each time.
        """
        return self.meta_data.get('transfer_shape')

    def __set_core_shape(self, shape):
        """ Set the core shape to hold only the shape of the core dimensions
        """
        self.core_shape = shape

    def get_core_shape(self):
        """ Get the shape of the core dimensions only.

        :returns: shape of core dimensions
        :rtype: tuple
        """
        return self.core_shape

    def _set_shape_before_tuning(self, shape):
        """ Set the shape of the full dataset used during each run of the \
        plugin (i.e. ignore extra dimensions due to parameter tuning). """
        self.pre_tuning_shape = shape

    def _get_shape_before_tuning(self):
        """ Return the shape of the full dataset used during each run of the \
        plugin (i.e. ignore extra dimensions due to parameter tuning). """
        return self.pre_tuning_shape if self.pre_tuning_shape else\
            self.data_obj.get_shape()

    def __check_dimensions(self, indices, core_dir, slice_dir, nDims):
        if len(indices) != len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")

        if (len(core_dir)+len(slice_dir)) != nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def __set_slice_dimensions(self):
        """ Set the slice dimensions in the pluginData meta data dictionary.
        """
        slice_dirs = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['slice_dims']
        self.meta_data.set('slice_dims', slice_dirs)

    def get_slice_dimension(self):
        """
        Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()[0]
        return list(set(core_dirs + (slice_dir,))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """
        Return the dimension of the data in the plugin that has the specified
        axis label.
        """
        label_dim = self.data_obj.get_data_dimension_by_axis_label(
                label, contains=contains)
        plugin_dims = self.data_obj.get_core_dimensions()
        if self._get_max_frames_process() > 1 or self.max_frames == 'multiple':
            plugin_dims += (self.get_slice_dimension(),)
        return list(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):
        """
        Reorder the slice dimensions.  The fastest changing slice dimension
        will always be the first one stated in the pattern key ``slice_dir``.
        The input param is a tuple stating the desired order of slicing
        dimensions relative to the current order.
        """
        slice_dirs = self.get_slice_directions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specifed.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        self.get_current_pattern()['slice_dir'] = new_slice_dirs

    def get_core_dimensions(self):
        """
        Return the position of the core dimensions in relation to the data
        handed to the plugin.
        """
        core_dims = self.data_obj.get_core_dimensions()
        first_slice_dim = (self.data_obj.get_slice_dimensions()[0],)
        plugin_dims = np.sort(core_dims + first_slice_dim)
        return np.searchsorted(plugin_dims, np.sort(core_dims))

    def set_fixed_dimensions(self, dims, values):
        """ Fix a data direction to the index in values list.

        :param list(int) dims: Directions to fix
        :param list(int) value: Index of fixed directions
        """
        slice_dirs = self.data_obj.get_slice_dimensions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set("fixed_dimensions", dims)
        self.meta_data.set("fixed_dimensions_values", values)
        self.__set_slice_dimensions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        #self.__set_shape()

    def _get_fixed_dimensions(self):
        """ Get the fixed data directions and their indices

        :returns: Fixed directions and their associated values
        :rtype: list(list(int), list(int))
        """
        fixed = []
        values = []
        if 'fixed_dimensions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get("fixed_dimensions")
            values = self.meta_data.get("fixed_dimensions_values")
        return [fixed, values]

    def _get_data_slice_list(self, plist):
        """ Convert a plugin data slice list to a slice list for the whole
        dataset, i.e. add in any missing dimensions.
        """
        nDims = len(self.get_shape())
        all_dims = self.get_core_dimensions() + self.get_slice_dimension()
        extra_dims = all_dims[nDims:]
        dlist = list(plist)
        for i in extra_dims:
            dlist.insert(i, slice(None))
        return tuple(dlist)

    def _get_max_frames_process(self):
        """ Get the number of frames to process for each run of process_frames.

        If the number of frames is not divisible by the previewing ``chunk``
        value then amend the number of frames to gcd(frames, chunk)

        :returns: Number of frames to process
        :rtype: int
        """
        if self._plugin and self._plugin.chunk > 1:
            frame_chunk = self.meta_data.get("max_frames_process")
            chunk = self.data_obj.get_preview().get_starts_stops_steps(
                key='chunks')[self.get_slice_directions()[0]]
            self.meta_data.set('max_frames_process', gcd(frame_chunk, chunk))
        return self.meta_data.get("max_frames_process")

    def _get_max_frames_transfer(self):
        """ Get the number of frames to transfer for each run of
        process_frames. """
        return self.meta_data.get('max_frames_transfer')

    def _set_no_squeeze(self):
        self.no_squeeze = True

    def _get_no_squeeze(self):
        return self.no_squeeze

    def _set_meta_data(self):
        fixed, _ = self._get_fixed_dimensions()
        sdir = \
            [s for s in self.data_obj.get_slice_dimensions() if s not in fixed]
        shape = self.data_obj.get_shape()
        shape_before_tuning = self._get_shape_before_tuning()

        diff = len(shape) - len(shape_before_tuning)
        if diff:
            shape = shape_before_tuning
            sdir = sdir[:-diff]
        
        if 'fix_total_frames' in self.meta_data.get_dictionary().keys():
            frames = self.meta_data.get('fix_total_frames')
        else:
            frames = np.prod([shape[d] for d in sdir])

        base_names = [p.__name__ for p in self._plugin.__class__.__bases__]
        processes = self.data_obj.exp.meta_data.get('processes')

        if 'GpuPlugin' in base_names:
            n_procs = len([n for n in processes if 'GPU' in n])
        else:
            n_procs = len(processes)

        f_per_p = np.ceil(frames/n_procs)
        self.meta_data.set('shape', shape)
        self.meta_data.set('sdir', sdir)
        self.meta_data.set('total_frames', frames)
        self.meta_data.set('mpi_procs', n_procs)
        self.meta_data.set('frames_per_process', f_per_p)
        frame_shape, b_per_f = self.set_bytes_per_frame()
        self.meta_data.set('bytes_per_frame', b_per_f)
        self.meta_data.set('bytes_per_process', b_per_f*f_per_p)
        self.meta_data.set('frame_shape', frame_shape)

    def __log_max_frames(self, mft, mfp, check=True):
        logging.debug("Setting max frames transfer for plugin %s to %d" %
                      (self._plugin, mft))
        logging.debug("Setting max frames process for plugin %s to %d" %
                      (self._plugin, mfp))
        self.meta_data.set('max_frames_process', mfp)
        if check:
            self.__check_distribution(mft)
        # (((total_frames/mft)/mpi_procs) % 1)

    def __check_distribution(self, mft):
        warn_threshold = 0.85
        nprocs = self.meta_data.get('mpi_procs')
        nframes = self.meta_data.get('total_frames')
        temp = (((nframes/mft)/float(nprocs)) % 1)
        if temp != 0.0 and temp < warn_threshold:
            shape = self.meta_data.get('shape')
            sdir = self.meta_data.get('sdir')
            logging.warning('UNEVEN FRAME DISTRIBUTION: shape %s, nframes %s '
                            'sdir %s, nprocs %s', shape, nframes, sdir, nprocs)

    def _set_padding_dict(self):
        if self.padding and not isinstance(self.padding, Padding):
            self.pad_dict = copy.deepcopy(self.padding)
            self.padding = Padding(self)
            for key in self.pad_dict.keys():
                getattr(self.padding, key)(self.pad_dict[key])

    def plugin_data_setup(self, pattern, nFrames, split=None):
        """ Setup the PluginData object.

        :param str pattern: A pattern name
        :param int nFrames: How many frames to process at a time.  Choose from\
            'single', 'multiple', 'fixed_multiple' or an integer (an integer \
            should only ever be passed in exceptional circumstances)
        """
        self.__set_pattern(pattern)
        if isinstance(nFrames, list):
            nFrames, self._frame_limit = nFrames
        self.max_frames = nFrames
        self.split = split

    def plugin_data_transfer_setup(self, copy=None, calc=None):
        """ Set up the plugin data transfer frame parameters.
        If copy=pData (another PluginData instance) then copy """
        chunks = \
            self.data_obj.get_preview().get_starts_stops_steps(key='chunks')

        if not copy and not calc:
            mft, mft_shape, mfp = self._calculate_max_frames()
        elif calc:
            max_mft = calc.meta_data.get('max_frames_transfer')             
            max_mfp = calc.meta_data.get('max_frames_process')
            max_nProc = int(np.ceil(max_mft/float(max_mfp)))
            nProc = max_nProc
            mfp = 1 if self.max_frames == 'single' else self.max_frames
            mft = nProc*mfp
            mft_shape = self._set_shape_transfer(self.__get_slice_size(mft))
        elif copy:
            mft = copy._get_max_frames_transfer()
            mft_shape = self._set_shape_transfer(self.__get_slice_size(mft))
            mfp = copy._get_max_frames_process()

        self.__set_max_frames(mft, mft_shape, mfp)

        if self._plugin and mft \
                and (chunks[self.data_obj.get_slice_dimensions()[0]] % mft):
            self._plugin.chunk = True
        self.__set_shape()

    def _calculate_max_frames(self):
        nFrames = self.max_frames
        self.__perform_checks(nFrames)
        td = self.data_obj._get_transport_data()
        mft, size_list = td._calc_max_frames_transfer(nFrames)
        self.meta_data.set('size_list', size_list)
        mfp = td._calc_max_frames_process(nFrames)
        mft_shape = None
        if mft:
            mft_shape = self._set_shape_transfer(list(size_list))
        return mft, mft_shape, mfp

    def __set_max_frames(self, mft, mft_shape, mfp):
        self.meta_data.set('max_frames_transfer', mft)
        self.meta_data.set('transfer_shape', mft_shape)
        self.meta_data.set('max_frames_process', mfp)
        self.__log_max_frames(mft, mfp)
        # Retain the shape if the first slice dimension has length 1
        if mfp == 1 and self.max_frames == 'multiple':
            self._set_no_squeeze()

    def _get_plugin_data_size_params(self):
        nBytes = self.data_obj.get_itemsize()
        frame_shape = self.meta_data.get('frame_shape')
        total_frames = self.meta_data.get('total_frames')
        tbytes = nBytes*np.prod(frame_shape)*total_frames
        
        params = {'nBytes': nBytes, 'frame_shape': frame_shape,
                  'total_frames': total_frames, 'transfer_bytes': tbytes}
        return params

    def __perform_checks(self, nFrames):
        options = ['single', 'multiple']
        if not isinstance(nFrames, int) and nFrames not in options:
            e_str = ("The value of nFrames is not recognised. Please choose "
                     "from 'single' and 'multiple' (or an integer in "
                     "exceptional circumstances).")
            raise Exception(e_str)

    def get_frame_limit(self):
        return self._frame_limit

    def get_current_frame_idx(self):
        """ Returns the index of the frames currently being processed.
        """
        global_index = self._plugin.get_global_frame_index()
        count = self._plugin.get_process_frames_counter()
        mfp = self.meta_data.get('max_frames_process')
        start = global_index[count]*mfp
        index = np.arange(start, start + mfp)
        nFrames = self.get_total_frames()
        index[index >= nFrames] = nFrames - 1
        return index
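
A minimal runnable sketch of the clamping above, with assumed values (four
frames per process, starting at frame 6 of an 8-frame dataset):

import numpy as np

mfp, start, nFrames = 4, 6, 8
index = np.arange(start, start + mfp)   # [6 7 8 9]
index[index >= nFrames] = nFrames - 1   # [6 7 7 7] - the last frame repeats
print(index)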
Exemplo n.º 25
0
    def _output_template(self, fname, process_fname):
        plist = self.plist.plugin_list
        index = [i for i in range(len(plist)) if plist[i]['active']]

        local_dict = MetaData(ordered=True)
        global_dict = MetaData(ordered=True)
        local_dict.set(['process_list'], os.path.abspath(process_fname))

        for i in index:
            params = self.__get_template_params(plist[i]['data'], [])
            name = plist[i]['name']
            for p in params:
                ptype, isyaml, key, value = p
                if isyaml:
                    data_name = isyaml if ptype == 'local' else 'all'
                    local_dict.set([i+1, name, data_name, key], value)
                elif ptype == 'local':
                    local_dict.set([i+1, name, key], value)
                else:
                    global_dict.set(['all', name, key], value)

        with open(fname, 'w') as stream:
            local_dict.get_dictionary().update(global_dict.get_dictionary())
            yu.dump_yaml(local_dict.get_dictionary(), stream)
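
Assuming MetaData.set builds nested dictionaries from its key list (as its
use above suggests), the dumped template is keyed by plugin position, then
plugin name, then parameter; the layout below is illustrative only:

# process_list: /abs/path/to/process_list.nxs
# 1:
#   SomeTemplateLoader:
#     yaml_file: <value>
# all:
#   SomePlugin:
#     <global_param>: <value>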
Exemplo n.º 26
0
 def __init__(self, name):
     self.meta_data = MetaData()
     super(Data, self).__init__()
     self.name = name
     self.backing_file = None
     self.data = None
Exemplo n.º 27
0
def get_data_dict(paths):
    all_data_dict = {}
    yu = YamlConverter()
    for p in paths:
        yu.parameters['yaml_file'] = p
        all_data_dict.update(yu.setup(template=True))
    return all_data_dict


def get_string_result(key, data_dict, mData, res=None):
    if res is None:
        res = []
    if '*' in key:
        for data in list(data_dict.keys()):
            res = get_string_result(
                    key.replace('*', data), data_dict, mData, res)
    else:
        res.append(mData.get(key.split('.')))
    return res

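A self-contained sketch of the '*' expansion above, using a stub in place
of MetaData (the real MetaData.get also accepts a list of keys, as seen
elsewhere in this document); the dataset names are hypothetical:

class _StubMeta(object):
    def __init__(self, d):
        self._d = d

    def get(self, keys):
        val = self._d
        for k in keys:
            val = val[k]
        return val

data_dict = {'tomo': None, 'flat': None}
mData = _StubMeta({'tomo': {'shape': '(10, 20)'},
                   'flat': {'shape': '(1, 20)'}})
print(get_string_result('*.shape', data_dict, mData, res=[]))
# -> ['(10, 20)', '(1, 20)']
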

if __name__ == '__main__':
    arg = arg_parser()
    # find all Hdf5TemplateLoaders
    plugin = 'Hdf5TemplateLoader'
    param = 'yaml_file'
    paths = check_yaml_path(get_parameter_value(arg.nxsfile, plugin, param))
    data_dict = get_data_dict(paths)
    res = get_string_result(arg.key, data_dict, MetaData(data_dict))
    print(','.join(res))

    
Exemplo n.º 28
0
class PluginData(object):
    """ The PluginData class contains plugin specific information about a Data
    object for the duration of a plugin.  An instance of the class is
    encapsulated inside the Data object during the plugin run
    """

    def __init__(self, data_obj, plugin=None):
        self.data_obj = data_obj
        self.data_obj._set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = None
        # this flag determines which data is passed. If false then just the
        # data, if true then all data including dark and flat fields.
        self.selected_data = False
        self.shape = None
        self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []
        self._plugin = plugin

    def get_total_frames(self):
        """ Get the total number of frames to process.

        :returns: Number of frames
        :rtype: int
        """
        temp = 1
        slice_dir = \
            self.data_obj.get_data_patterns()[
                self.get_pattern_name()]["slice_dir"]
        for tslice in slice_dir:
            temp *= self.data_obj.get_shape()[tslice]
        return temp

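    # Worked example (assumed values): for data of shape (91, 135, 160)
    # and a pattern with slice_dir = (0,), the total is 91 frames; with
    # slice_dir = (0, 1) it would be 91 * 135 = 12285 frames.
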
    def __set_pattern(self, name):
        """ Set the pattern related information int the meta data dict.
        """
        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set_meta_data("name", name)
        self.meta_data.set_meta_data("core_dir", pattern['core_dir'])
        self.__set_slice_directions()

    def get_pattern_name(self):
        """ Get the pattern name.

        :returns: the pattern name
        :rtype: str
        """
        name = self.meta_data.get_meta_data("name")
        if name is not None:
            return name
        else:
            raise Exception("The pattern name has not been set.")

    def get_pattern(self):
        """ Get the current pattern.

        :returns: dict of the pattern name against the pattern.
        :rtype: dict
        """
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def __set_shape(self):
        """ Set the shape of the plugin data processing chunk.
        """
        core_dir = self.get_core_directions()
        slice_dir = self.get_slice_directions()
        dirs = list(set(core_dir + (slice_dir[0],)))
        slice_idx = dirs.index(slice_dir[0])
        shape = []
        for core in set(core_dir):
            shape.append(self.data_obj.get_shape()[core])
        self.__set_core_shape(tuple(shape))
        if self._get_frame_chunk() > 1:
            shape.insert(slice_idx, self._get_frame_chunk())
        self.shape = tuple(shape)

    def get_shape(self):
        """ Get the shape of the plugin data to be processed each time.
        """
        return self.shape

    def __set_core_shape(self, shape):
        """ Set the core shape to hold only the shape of the core dimensions
        """
        self.core_shape = shape

    def get_core_shape(self):
        """ Get the shape of the core dimensions only.

        :returns: shape of core dimensions
        :rtype: tuple
        """
        return self.core_shape

    def __check_dimensions(self, indices, core_dir, slice_dir, nDims):
        if len(indices) != len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")

        if (len(core_dir) + len(slice_dir)) != nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def __set_slice_directions(self):
        """ Set the slice directions in the pluginData meta data dictionary.
        """
        slice_dirs = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['slice_dir']
        self.meta_data.set_meta_data('slice_dir', slice_dirs)

    def get_slice_directions(self):
        """ Get the slice directions (slice_dir) of the dataset.
        """
        return self.meta_data.get_meta_data('slice_dir')

    def get_slice_dimension(self):
        """
        Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.get_core_directions()
        slice_dir = self.get_slice_directions()[0]
        return list(set(core_dirs + (slice_dir,))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """
        Return the dimension of the data in the plugin that has the specified
        axis label.
        """
        label_dim = \
            self.data_obj.find_axis_label_dimension(label, contains=contains)
        plugin_dims = self.get_core_directions()
        if self._get_frame_chunk() > 1:
            plugin_dims += (self.get_slice_directions()[0],)
        return list(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):
        """
        Reorder the slice directions.  The fastest changing slice direction
        will always be the first one stated in the pattern key ``slice_dir``.
        The input param is a tuple stating the desired order of slicing
        directions relative to the current order.
        """
        slice_dirs = self.get_slice_directions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specifed.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        self.get_current_pattern()['slice_dir'] = new_slice_dirs

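    # Sketch of the reordering above with assumed values: slice_dirs of
    # (1, 3, 4) and order = (2, 0) give ordered = [4, 1] and
    # remaining = [3], so the new slice_dir tuple is (4, 1, 3).
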
    def get_core_directions(self):
        """ Get the core data directions

        :returns: value associated with pattern key ``core_dir``
        :rtype: tuple
        """
        core_dir = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['core_dir']
        return core_dir

    def set_fixed_directions(self, dims, values):
        """ Fix a data direction to the index in values list.

        :param list(int) dims: Directions to fix
        :param list(int) values: Indices of the fixed directions
        """
        slice_dirs = self.get_slice_directions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set_meta_data("fixed_directions", dims)
        self.meta_data.set_meta_data("fixed_directions_values", values)
        self.__set_slice_directions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        self.__set_shape()

    def _get_fixed_directions(self):
        """ Get the fixed data directions and their indices

        :returns: Fixed directions and their associated values
        :rtype: list(list(int), list(int))
        """
        fixed = []
        values = []
        if 'fixed_directions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get_meta_data("fixed_directions")
            values = self.meta_data.get_meta_data("fixed_directions_values")
        return [fixed, values]

    def _set_frame_chunk(self, nFrames):
        """ Set the number of frames to process at a time """
        self.meta_data.set_meta_data("nFrames", nFrames)

    def _get_frame_chunk(self):
        """ Get the number of frames to be processes at a time.

        If the number of frames is not divisible by the previewing ``chunk``
        value then amend the number of frames to gcd(frames, chunk)

        :returns: Number of frames to process
        :rtype: int
        """
        if self._plugin and self._plugin.chunk > 1:
            frame_chunk = self.meta_data.get_meta_data("nFrames")
            chunk = self.data_obj.get_preview().get_starts_stops_steps(
                key='chunks')[self.get_slice_directions()[0]]
            self._set_frame_chunk(gcd(frame_chunk, chunk))
        return self.meta_data.get_meta_data("nFrames")

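    # Worked example of the gcd amendment above (gcd is assumed to be in
    # scope, e.g. math.gcd): with nFrames = 8 and a preview chunk value of
    # 6, the frame count becomes gcd(8, 6) = 2, which divides both evenly.
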
    def plugin_data_setup(self, pattern_name, chunk):
        """ Setup the PluginData object.

        :param str pattern_name: A pattern name
        :param int chunk: Number of frames to process at a time
        """
        self.__set_pattern(pattern_name)
        chunks = \
            self.data_obj.get_preview().get_starts_stops_steps(key='chunks')
        if self._plugin and (chunks[self.get_slice_directions()[0]] % chunk):
            self._plugin.chunk = True
        self._set_frame_chunk(chunk)
        self.__set_shape()
Exemplo n.º 29
0
class Checkpointing(object):
    """ Contains all checkpointing associated methods.
    """
    def __init__(self, exp, name='Checkpointing'):
        self._exp = exp
        self._h5 = Hdf5Utils(self._exp)
        self._filename = '_checkpoint.h5'
        self._file = None
        self._start_values = (0, 0, 0)
        self._completed_plugins = 0
        self._level = None
        self._proc_idx = 0
        self._trans_idx = 0
        self._comm = None
        self._timer = None
        self._set_timer()
        self.meta_data = MetaData()

    def _initialise(self, comm):
        """ Create a new checkpoint file """
        with self._h5._open_backing_h5(self._file, 'a', mpi=False) as f:
            self._create_dataset(f, 'transfer_idx', 0)
            self._create_dataset(f, 'process_idx', 0)
            self._create_dataset(f, 'completed_plugins',
                                 self._completed_plugins)
        msg = "%s initialise." % self.__class__.__name__
        self._exp._barrier(communicator=comm, msg=msg)

    def _create_dataset(self, f, name, val):
        if name in list(f.keys()):
            f[name][...] = val
        else:
            f.create_dataset(name, data=val, dtype=np.int16)

    def __set_checkpoint_info(self):
        mData = self._exp.meta_data.get
        proc = 'process%d' % mData('process')
        self._folder = os.path.join(mData('out_path'), 'checkpoint')
        self._file = os.path.join(self._folder, proc + self._filename)

        if self._exp.meta_data.get('process') == 0:
            if not os.path.exists(self._folder):
                os.makedirs(self._folder)
        self._exp._barrier(msg='Creating checkpoint folder.')

    def _set_checkpoint_info_from_file(self, level):
        self._level = level
        self.__set_checkpoint_info()
        self.__does_file_exist(self._file, level)

        with self._h5._open_backing_h5(self._file, 'r', mpi=False) as f:
            self._completed_plugins = \
                f['completed_plugins'][...] if 'completed_plugins' in f else 0
            self._proc_idx = f['process_idx'][...] if 'process_idx' in f and \
                level == 'subplugin' else 0
            self._trans_idx = f['transfer_idx'][...] if 'transfer_idx' in f \
                and level == 'subplugin' else 0
            # for testing
            self.__set_start_values(self._completed_plugins, self._trans_idx,
                                    self._proc_idx)
            self.__set_dataset_metadata(f, 'in_data')
            self.__set_dataset_metadata(f, 'out_data')

        self.__load_data()
        msg = "%s _set_checkpoint_info_from_file" % self.__class__.__name__
        self._exp._barrier(msg=msg)

    def __does_file_exist(self, thefile, level):
        if not os.path.exists(thefile):
            if level == 'plugin':
                proc0 = os.path.join(self._folder, 'process0' + self._filename)
                self.__does_file_exist(proc0, None)
                copyfile(proc0, self._file)
                return
            raise Exception("No checkpoint file found.")

    def __set_dataset_metadata(self, f, dtype):
        self.meta_data.set(dtype, {})
        if dtype not in list(f.keys()):
            return
        entry = f[dtype]
        for name, gp in entry.items():
            data_entry = gp.require_group('meta_data')
            for key, value in data_entry.items():
                self.meta_data.set([dtype, name, key], value[key][...])

    def _get_dataset_metadata(self, dtype, name):
        return self._data_meta_data(dtype)

    def set_completed_plugins(self, n):
        self._completed_plugins = n

    def __load_data(self):
        self._exp.meta_data.set('checkpoint_loader', True)
        temp = self._exp.meta_data.get('data_file')
        nxsfile = self._exp.meta_data.get('nxs_filename')
        self._exp.meta_data.set('data_file', nxsfile)
        pid = 'savu.plugins.loaders.savu_nexus_loader'
        pu.plugin_loader(self._exp, {'id': pid, 'data': {}})
        self._exp.meta_data.delete('checkpoint_loader')
        self._exp.meta_data.set('data_file', temp)

    def output_plugin_checkpoint(self):
        self._completed_plugins += 1
        self.__write_plugin_checkpoint()
        self._reset_indices()

    def get_checkpoint_plugin(self):
        checkpoint_flag = self._exp.meta_data.get('checkpoint')
        if not checkpoint_flag:
            self.__set_checkpoint_info()
            self._initialise(MPI.COMM_WORLD)
        else:
            self._set_checkpoint_info_from_file(checkpoint_flag)
        return self._completed_plugins

    def is_time_to_checkpoint(self, transport, ti, pi):
        interval = self._exp.meta_data.get(
            ['system_params', 'checkpoint_interval'])
        end = time.time()
        if (end - self._get_timer()) > interval:
            self.__write_subplugin_checkpoint(ti, pi)
            self._set_timer()
            transport._transport_checkpoint()
            return transport._transport_kill_signal()
        return False

    def _get_checkpoint_params(self):
        return self._level, self._completed_plugins

    def __write_subplugin_checkpoint(self, ti, pi):
        with self._h5._open_backing_h5(self._file, 'a', mpi=False) as f:
            f['transfer_idx'][...] = ti
            f['process_idx'][...] = pi

    def __write_plugin_checkpoint(self):
        with self._h5._open_backing_h5(self._file, 'a', mpi=False) as f:
            f['completed_plugins'][...] = self._completed_plugins
            f['transfer_idx'][...] = 0
            f['process_idx'][...] = 0

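    # The checkpoint file therefore holds three small integer datasets; a
    # hedged sketch of inspecting one with plain h5py (the file name below
    # follows the pattern set in __set_checkpoint_info):
    #
    #     with h5py.File('process0_checkpoint.h5', 'r') as f:
    #         print(f['completed_plugins'][...],
    #               f['transfer_idx'][...], f['process_idx'][...])
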
    def _reset_indices(self):
        self._trans_idx = 0
        self._proc_idx = 0

    def get_trans_idx(self):
        return self._trans_idx

    def get_proc_idx(self):
        return self._proc_idx

    def get_level(self):
        return self._level

    def _set_timer(self):
        self._timer = time.time()

    def _get_timer(self):
        return self._timer

    def __set_start_values(self, v1, v2, v3):
        self._start_values = (copy.copy(v1), copy.copy(v2), copy.copy(v3))

    def get_start_values(self):
        return self._start_values
Exemplo n.º 30
0
 def __init__(self, options):
     self.index = {"in_data": {}, "out_data": {}, "mapping": {}}
     self.meta_data = MetaData(get_options())
     self.nxs_file = None
Exemplo n.º 31
0
class PluginData(object):

    def __init__(self, data_obj):
        self.data_obj = data_obj
        self.data_obj.set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = None
        # this flag determines which data is passed. If false then just the
        # data, if true then all data including dark and flat fields.
        self.selected_data = False
        self.shape = None
        self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []

    def get_total_frames(self):
        temp = 1
        slice_dir = \
            self.data_obj.get_data_patterns()[
                self.get_pattern_name()]["slice_dir"]
        for tslice in slice_dir:
            temp *= self.data_obj.get_shape()[tslice]
        return temp

    def set_pattern(self, name):
        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set_meta_data("name", name)
        self.meta_data.set_meta_data("core_dir", pattern['core_dir'])
        self.set_slice_directions()

    def get_pattern_name(self):
        name = self.meta_data.get_meta_data("name")
        if name is not None:
            return name
        else:
            raise Exception("The pattern name has not been set.")

    def get_pattern(self):
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def set_shape(self):
        core_dir = self.get_core_directions()
        slice_dir = self.get_slice_directions()
        dirs = list(set(core_dir + (slice_dir[0],)))
        slice_idx = dirs.index(slice_dir[0])
        shape = []
        for core in set(core_dir):
            shape.append(self.data_obj.get_shape()[core])
        self.set_core_shape(tuple(shape))
        if self.get_frame_chunk() > 1:
            shape.insert(slice_idx, self.get_frame_chunk())
        self.shape = tuple(shape)

    def get_shape(self):
        return self.shape

    def set_core_shape(self, shape):
        self.core_shape = shape

    def get_core_shape(self):
        return self.core_shape

    def check_dimensions(self, indices, core_dir, slice_dir, nDims):
        if len(indices) != len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")

        if (len(core_dir) + len(slice_dir)) != nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def set_slice_directions(self):
        slice_dirs = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['slice_dir']
        self.meta_data.set_meta_data('slice_dir', slice_dirs)

    def get_slice_directions(self):
        return self.meta_data.get_meta_data('slice_dir')

    def get_slice_dimension(self):
        """
        Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.get_core_directions()
        slice_dir = self.get_slice_directions()[0]
        return list(set(core_dirs + (slice_dir,))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """
        Return the dimension of the data in the plugin that has the specified
        axis label.
        """
        label_dim = \
            self.data_obj.find_axis_label_dimension(label, contains=contains)
        plugin_dims = self.get_core_directions()
        if self.get_frame_chunk() > 1:
            plugin_dims += (self.get_slice_directions()[0],)
        return list(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):
        """
        Reorder the slice directions.  The fastest changing slice direction
        will always be the first one. The input param is a tuple stating the
        desired order of slicing directions relative to the current order.
        """
        slice_dirs = self.get_slice_directions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specifed.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        self.get_current_pattern()['slice_dir'] = new_slice_dirs

    def get_core_directions(self):
        core_dir = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['core_dir']
        return core_dir

    def set_fixed_directions(self, dims, values):
        slice_dirs = self.get_slice_directions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set_meta_data("fixed_directions", dims)
        self.meta_data.set_meta_data("fixed_directions_values", values)
        self.set_slice_directions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        self.set_shape()

    def get_fixed_directions(self):
        fixed = []
        values = []
        if 'fixed_directions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get_meta_data("fixed_directions")
            values = self.meta_data.get_meta_data("fixed_directions_values")
        return [fixed, values]

    def set_frame_chunk(self, nFrames):
        # number of frames to process at a time
        self.meta_data.set_meta_data("nFrames", nFrames)

    def get_frame_chunk(self):
        return self.meta_data.get_meta_data("nFrames")

    def get_index(self, indices):
        shape = self.get_shape()
        nDims = len(shape)
        name = self.get_current_pattern_name()
        ddirs = self.get_data_patterns()
        core_dir = ddirs[name]["core_dir"]
        slice_dir = ddirs[name]["slice_dir"]

        self.check_dimensions(indices, core_dir, slice_dir, nDims)

        index = [slice(None)]*nDims
        count = 0
        for tdir in slice_dir:
            index[tdir] = slice(indices[count], indices[count]+1, 1)
            count += 1

        return tuple(index)

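    # Illustration of the slice tuple built above (assumed values): for a
    # 3D dataset with slice_dir = (0,) and indices = (5,), the core
    # dimensions keep slice(None) and the result is
    #     (slice(5, 6, 1), slice(None), slice(None)).
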
    def plugin_data_setup(self, pattern_name, chunk):
        self.set_pattern(pattern_name)
        self.set_frame_chunk(chunk)
        self.set_shape()

    def set_temp_pad_dict(self, pad_dict):
        self.meta_data.set_meta_data('temp_pad_dict', pad_dict)

    def get_temp_pad_dict(self):
        if 'temp_pad_dict' in self.meta_data.get_dictionary().keys():
            return self.meta_data.get_dictionary()['temp_pad_dict']

    def delete_temp_pad_dict(self):
        del self.meta_data.get_dictionary()['temp_pad_dict']
Exemplo n.º 32
0
class Experiment(object):
    """
    One instance of this class is created at the beginning of the
    processing chain and remains until the end.  It holds the current data
    object and a dictionary containing all metadata.
    """
    def __init__(self, options):
        self.meta_data = MetaData(options)
        self.__meta_data_setup(options["process_file"])
        self.index = {"in_data": {}, "out_data": {}, "mapping": {}}
        self.nxs_file = None

    def get_meta_data(self, entry):
        """ Get the meta data dictionary. """
        return self.meta_data.get_meta_data(entry)

    def __meta_data_setup(self, process_file):
        self.meta_data.plugin_list = PluginList()

        try:
            rtype = self.meta_data.get_meta_data('run_type')
            if rtype == 'test':
                self.meta_data.plugin_list.plugin_list = \
                    self.meta_data.get_meta_data('plugin_list')
            else:
                raise Exception('the run_type is unknown in Experiment class')
        except KeyError:
            self.meta_data.plugin_list._populate_plugin_list(process_file)

    def create_data_object(self, dtype, name):
        """ Create a data object.

        Plugin developers should apply this method in loaders only.

        :params str dtype: either "in_data" or "out_data".
        """
        bases = []
        try:
            self.index[dtype][name]
        except KeyError:
            self.index[dtype][name] = Data(name, self)
            data_obj = self.index[dtype][name]
            bases.append(data_obj._get_transport_data())
            cu.add_base_classes(data_obj, bases)
        return self.index[dtype][name]

    def _set_nxs_filename(self):
        folder = self.meta_data.get_meta_data('out_path')
        fname = os.path.basename(folder.split('_')[-1]) + '_processed.nxs'
        filename = os.path.join(folder, fname)
        self.meta_data.set_meta_data("nxs_filename", filename)

        if self.meta_data.get_meta_data("mpi") is True:
            self.nxs_file = h5py.File(filename,
                                      'w',
                                      driver='mpio',
                                      comm=MPI.COMM_WORLD)
        else:
            self.nxs_file = h5py.File(filename, 'w')

    def __remove_dataset(self, data_obj):
        data_obj._close_file()
        del self.index["out_data"][data_obj.data_info.get_meta_data('name')]

    def _clear_data_objects(self):
        self.index["out_data"] = {}
        self.index["in_data"] = {}

    def _merge_out_data_to_in(self):
        for key, data in self.index["out_data"].iteritems():
            if data.remove is False:
                if key in self.index['in_data'].keys():
                    data.meta_data._set_dictionary(
                        self.index['in_data'][key].meta_data.get_dictionary())
                self.index['in_data'][key] = data
        self.index["out_data"] = {}

    def _reorganise_datasets(self, out_data_objs, link_type):
        out_data_list = self.index["out_data"]
        self.__close_unwanted_files(out_data_list)
        self.__remove_unwanted_data(out_data_objs)

        self._barrier()
        self.__copy_out_data_to_in_data(link_type)

        self._barrier()
        self.index['out_data'] = {}

    def __remove_unwanted_data(self, out_data_objs):
        for out_objs in out_data_objs:
            if out_objs.remove is True:
                self.__remove_dataset(out_objs)

    def __close_unwanted_files(self, out_data_list):
        for out_objs in out_data_list:
            if out_objs in self.index["in_data"].keys():
                self.index["in_data"][out_objs]._close_file()

    def __copy_out_data_to_in_data(self, link_type):
        for key in self.index["out_data"]:
            output = self.index["out_data"][key]
            output._save_data(link_type)
            self.index["in_data"][key] = copy.deepcopy(output)

    def _set_all_datasets(self, name):
        data_names = []
        for key in self.index["in_data"].keys():
            data_names.append(key)
        return data_names

    def _barrier(self, communicator=MPI.COMM_WORLD):
        comm_dict = {'comm': communicator}
        if self.meta_data.get_meta_data('mpi') is True:
            logging.debug("About to hit a _barrier %s", comm_dict)
            comm_dict['comm'].barrier()
            logging.debug("Past the _barrier")

    def log(self, log_tag, log_level=logging.DEBUG):
        """
        Log the contents of the experiment at the specified level
        """
        logging.log(log_level, "Experimental Parameters for %s", log_tag)
        for key, value in self.index["in_data"].iteritems():
            logging.log(log_level, "in data (%s) shape = %s", key,
                        value.get_shape())
        for key, value in self.index["in_data"].iteritems():
            logging.log(log_level, "out data (%s) shape = %s", key,
                        value.get_shape())
Exemplo n.º 33
0
class Data(object):
    """
    The Data class dynamically inherits from relevant data structure classes
    at runtime and holds the data array.
    """

    def __init__(self, name, exp):
        self.meta_data = MetaData()
        self.pattern_list = self.set_available_pattern_list()
        self.data_info = MetaData()
        self.initialise_data_info(name)
        self.exp = exp
        self.group_name = None
        self.group = None
        self._plugin_data_obj = None
        self.tomo_raw_obj = None
        self.data_mapping = None
        self.variable_length_flag = False
        self.dtype = None
        self.remove = False
        self.backing_file = None
        self.data = None
        self.next_shape = None
        self.mapping = None
        self.map_dim = []
        self.revert_shape = None

    def initialise_data_info(self, name):
        self.data_info.set_meta_data('name', name)
        self.data_info.set_meta_data('data_patterns', {})
        self.data_info.set_meta_data('shape', None)
        self.data_info.set_meta_data('nDims', None)

    def set_plugin_data(self, plugin_data_obj):
        self._plugin_data_obj = plugin_data_obj

    def clear_plugin_data(self):
        self._plugin_data_obj = None

    def get_plugin_data(self):
        if self._plugin_data_obj is not None:
            return self._plugin_data_obj
        else:
            raise Exception("There is no PluginData object associated with "
                            "the Data object.")

    def set_tomo_raw(self, tomo_raw_obj):
        self.tomo_raw_obj = tomo_raw_obj

    def clear_tomo_raw(self):
        self.tomo_raw_obj = None

    def get_tomo_raw(self):
        if self.tomo_raw_obj is not None:
            return self.tomo_raw_obj
        else:
            raise Exception("There is no TomoRaw object associated with "
                            "the Data object.")

    def get_transport_data(self):
        transport = self.exp.meta_data.get_meta_data("transport")
        "SETTING UP THE TRANSPORT DATA"
        transport_data = "savu.data.transport_data." + transport + \
                         "_transport_data"
        return import_class(transport_data)

    def __deepcopy__(self, memo):
        name = self.data_info.get_meta_data('name')
        new_obj = Data(name, self.exp)
        new_obj.add_base_classes(self.get_transport_data())
        new_obj.meta_data = self.meta_data
        new_obj.pattern_list = copy.deepcopy(self.pattern_list)
        new_obj.data_info = copy.deepcopy(self.data_info)
        new_obj.exp = self.exp
        new_obj._plugin_data_obj = self._plugin_data_obj
        new_obj.tomo_raw_obj = self.tomo_raw_obj
        new_obj.data_mapping = self.data_mapping
        new_obj.variable_length_flag = copy.deepcopy(self.variable_length_flag)
        new_obj.dtype = copy.deepcopy(self.dtype)
        new_obj.remove = copy.deepcopy(self.remove)
        new_obj.group_name = self.group_name
        new_obj.group = self.group
        new_obj.backing_file = self.backing_file
        new_obj.data = self.data
        new_obj.next_shape = copy.deepcopy(self.next_shape)
        new_obj.mapping = copy.deepcopy(self.mapping)
        new_obj.map_dim = copy.deepcopy(self.map_dim)
        new_obj.revert_shape = copy.deepcopy(self.revert_shape)
        return new_obj

    def add_base(self, ExtraBase):
        cls = self.__class__
        self.__class__ = cls.__class__(cls.__name__, (cls, ExtraBase), {})
        ExtraBase.__init__(self)  # initialise the mixed-in base on self

    def add_base_classes(self, bases):
        bases = bases if isinstance(bases, list) else [bases]
        for base in bases:
            self.add_base(base)

    def external_link(self):
        return h5py.ExternalLink(self.backing_file.filename, self.group_name)

    def create_dataset(self, *args, **kwargs):
        """
        Set up required information when an output dataset has been created by
        a plugin
        """
        self.dtype = kwargs.get('dtype', np.float32)
        # remove from the plugin chain
        self.remove = kwargs.get('remove', False)
        if len(args) == 1:
            self.copy_dataset(args[0], removeDim=kwargs.get('removeDim', []))
            if args[0].tomo_raw_obj:
                self.set_tomo_raw(copy.deepcopy(args[0].get_tomo_raw()))
                self.get_tomo_raw().data_obj = self
        else:
            try:
                shape = kwargs['shape']
                self.create_axis_labels(kwargs['axis_labels'])
            except KeyError:
                raise Exception("Please state axis_labels and shape when "
                                "creating a new dataset")
            self.set_new_dataset_shape(shape)

            if 'patterns' in kwargs:
                self.copy_patterns(kwargs['patterns'])
        self.set_preview([])

    def set_new_dataset_shape(self, shape):
        if isinstance(shape, Data):
            self.find_and_set_shape(shape)
        elif isinstance(shape, dict):
            self.set_variable_flag()
            self.set_shape((shape[shape.keys()[0]] + ('var',)))
        else:
            pData = self.get_plugin_data()
            self.set_shape(shape + tuple(pData.extra_dims))
            if 'var' in shape:
                self.set_variable_flag()

    def copy_patterns(self, copy_data):
        if isinstance(copy_data, Data):
            patterns = copy_data.get_data_patterns()
        else:
            data = copy_data.keys()[0]
            pattern_list = copy_data[data]

            all_patterns = data.get_data_patterns()
            if len(pattern_list[0].split('.')) > 1:
                patterns = self.copy_patterns_removing_dimensions(
                    pattern_list, all_patterns, len(data.get_shape()))
            else:
                patterns = {}
                for pattern in pattern_list:
                    patterns[pattern] = all_patterns[pattern]
        self.set_data_patterns(patterns)

    def copy_patterns_removing_dimensions(self, pattern_list, all_patterns,
                                          nDims):
        copy_patterns = {}
        for new_pattern in pattern_list:
            name, all_dims = new_pattern.split('.')
            if name == '*':
                copy_patterns = all_patterns
            else:
                copy_patterns[name] = all_patterns[name]
            dims = tuple(map(int, all_dims.split(',')))
            dims = self.non_negative_directions(dims, nDims=nDims)

        patterns = {}
        for name, pattern_dict in copy_patterns.iteritems():
            empty_flag = False
            for ddir in pattern_dict:
                s_dims = self.non_negative_directions(
                    pattern_dict[ddir], nDims=nDims)
                new_dims = tuple([sd for sd in s_dims if sd not in dims])
                pattern_dict[ddir] = new_dims
                if not new_dims:
                    empty_flag = True
            if empty_flag is False:
                patterns[name] = pattern_dict
        return patterns

    def copy_dataset(self, copy_data, **kwargs):
        if copy_data.mapping:
            # copy label entries from meta data
            map_data = self.exp.index['mapping'][copy_data.get_name()]
            map_mData = map_data.meta_data
            map_axis_labels = map_data.data_info.get_meta_data('axis_labels')
            for axis_label in map_axis_labels:
                if axis_label.keys()[0] in map_mData.get_dictionary().keys():
                    map_label = map_mData.get_meta_data(axis_label.keys()[0])
                    copy_data.meta_data.set_meta_data(axis_label.keys()[0],
                                                      map_label)
            copy_data = map_data
        patterns = copy.deepcopy(copy_data.get_data_patterns())
        self.copy_labels(copy_data)
        self.find_and_set_shape(copy_data)
        self.set_data_patterns(patterns)

    def create_axis_labels(self, axis_labels):
        if isinstance(axis_labels, Data):
            self.copy_labels(axis_labels)
        elif isinstance(axis_labels, dict):
            data = axis_labels.keys()[0]
            self.copy_labels(data)          
            self.amend_axis_labels(axis_labels[data])
        else:
            self.set_axis_labels(*axis_labels)
            # if parameter tuning
            if self.get_plugin_data().multi_params_dict:
                self.add_extra_dims_labels()

    def copy_labels(self, copy_data):
        nDims = copy.copy(copy_data.data_info.get_meta_data('nDims'))
        axis_labels = \
            copy.copy(copy_data.data_info.get_meta_data('axis_labels'))
        self.data_info.set_meta_data('nDims', nDims)
        self.data_info.set_meta_data('axis_labels', axis_labels)
        # if parameter tuning
        if self.get_plugin_data().multi_params_dict:
            self.add_extra_dims_labels()

    def add_extra_dims_labels(self):
        params_dict = self.get_plugin_data().multi_params_dict
        # add multi_params axis labels from dictionary in pData
        nDims = self.data_info.get_meta_data('nDims')
        axis_labels = self.data_info.get_meta_data('axis_labels')
        axis_labels.extend([0]*len(params_dict))
        for key, value in params_dict.iteritems():
            title = value['label'].encode('ascii', 'ignore')
            name, unit = title.split('.')
            axis_labels[nDims + key] = {name: unit}
            # add parameter values to the meta_data
            self.meta_data.set_meta_data(name, np.array(value['values']))
        self.data_info.set_meta_data('nDims', nDims + len(self.extra_dims))
        self.data_info.set_meta_data('axis_labels', axis_labels)

    def amend_axis_labels(self, *args):
        axis_labels = self.data_info.get_meta_data('axis_labels')
        removed_dims = 0
        for arg in args[0]:
            label = arg.split('.')
            if len(label) == 1:
                del axis_labels[int(label[0]) + removed_dims]
                removed_dims += 1
                self.data_info.set_meta_data(
                    'nDims', self.data_info.get_meta_data('nDims') - 1)
            else:
                if int(label[0]) < 0:
                    axis_labels[int(label[0]) + removed_dims] = \
                        {label[1]: label[2]}
                else:
                    if int(label[0]) < self.data_info.get_meta_data('nDims'):
                        axis_labels[int(label[0])] = {label[1]: label[2]}
                    else:
                        axis_labels.insert(int(label[0]), {label[1]: label[2]})

    def set_data_patterns(self, patterns):
        self.add_extra_dims_to_patterns(patterns)
        self.data_info.set_meta_data('data_patterns', patterns)

    def add_extra_dims_to_patterns(self, patterns):
        all_dims = range(len(self.get_shape()))
        for p in patterns:
            pDims = patterns[p]['core_dir'] + patterns[p]['slice_dir']
            for dim in all_dims:
                if dim not in pDims:
                    patterns[p]['slice_dir'] += (dim,)

    def get_data_patterns(self):
        return self.data_info.get_meta_data('data_patterns')

    def set_shape(self, shape):
        self.data_info.set_meta_data('shape', shape)
        self.check_dims()

    def get_shape(self):
        shape = self.data_info.get_meta_data('shape')
        return shape

    def set_preview(self, preview_list, **kwargs):
        self.revert_shape = kwargs.get('revert', self.revert_shape)
        shape = self.get_shape()
        if preview_list:
            starts, stops, steps, chunks = \
                self.get_preview_indices(preview_list)
            shape_change = True
        else:
            starts, stops, steps, chunks = \
                [[0]*len(shape), shape, [1]*len(shape), [1]*len(shape)]
            shape_change = False
        self.set_starts_stops_steps(starts, stops, steps, chunks,
                                    shapeChange=shape_change)

    def unset_preview(self):
        self.set_preview([])
        self.set_shape(self.revert_shape)
        self.revert_shape = None

    def set_starts_stops_steps(self, starts, stops, steps, chunks,
                               shapeChange=True):
        self.data_info.set_meta_data('starts', starts)
        self.data_info.set_meta_data('stops', stops)
        self.data_info.set_meta_data('steps', steps)
        self.data_info.set_meta_data('chunks', chunks)
        if shapeChange or self.mapping:
            self.set_reduced_shape(starts, stops, steps, chunks)

    def get_preview_indices(self, preview_list):
        starts = len(preview_list)*[None]
        stops = len(preview_list)*[None]
        steps = len(preview_list)*[None]
        chunks = len(preview_list)*[None]
        for i in range(len(preview_list)):
            starts[i], stops[i], steps[i], chunks[i] = \
                self.convert_indices(preview_list[i].split(':'), i)
        return starts, stops, steps, chunks

    def convert_indices(self, idx, dim):
        shape = self.get_shape()
        mid = shape[dim]/2
        end = shape[dim]

        if self.mapping:
            map_shape = self.exp.index['mapping'][self.get_name()].get_shape()
            midmap = map_shape[dim]/2
            endmap = map_shape[dim]

        idx = [eval(equ) for equ in idx]
        idx = [idx[i] if idx[i] > -1 else shape[dim]+1+idx[i] for i in
               range(len(idx))]
        return idx

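    # The preview strings are evaluated with 'mid' and 'end' (and, for
    # mapped data, 'midmap'/'endmap') in scope, so for a dimension of
    # length 10 an entry such as 'mid-2:mid+3:1:1' evaluates to
    #     [3, 8, 1, 1].
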
    def get_starts_stops_steps(self):
        starts = self.data_info.get_meta_data('starts')
        stops = self.data_info.get_meta_data('stops')
        steps = self.data_info.get_meta_data('steps')
        chunks = self.data_info.get_meta_data('chunks')
        return starts, stops, steps, chunks

    def set_reduced_shape(self, starts, stops, steps, chunks):
        orig_shape = self.get_shape()
        self.data_info.set_meta_data('orig_shape', orig_shape)
        new_shape = []
        for dim in range(len(starts)):
            new_shape.append(np.prod((self.get_slice_dir_matrix(dim).shape)))
        self.set_shape(tuple(new_shape))

        # reduce shape of mapping data if it exists
        if self.mapping:
            self.set_mapping_reduced_shape(orig_shape, new_shape,
                                           self.get_name())

    def set_mapping_reduced_shape(self, orig_shape, new_shape, name):
        map_obj = self.exp.index['mapping'][name]
        map_shape = np.array(map_obj.get_shape())
        diff = np.array(orig_shape) - map_shape[:len(orig_shape)]
        not_map_dim = np.where(diff == 0)[0]
        map_dim = np.where(diff != 0)[0]
        self.map_dim = map_dim
        map_obj.data_info.set_meta_data('full_map_dim_len', map_shape[map_dim])
        map_shape[not_map_dim] = np.array(new_shape)[not_map_dim]

        # assuming only one extra dimension added for now
        starts, stops, steps, chunks = self.get_starts_stops_steps()
        start = starts[map_dim] % map_shape[map_dim]
        stop = min(stops[map_dim], map_shape[map_dim])

        temp = len(np.arange(start, stop, steps[map_dim]))*chunks[map_dim]
        map_shape[len(orig_shape)] = np.ceil(new_shape[map_dim]/temp)
        map_shape[map_dim] = new_shape[map_dim]/map_shape[len(orig_shape)]
        map_obj.data_info.set_meta_data('map_dim_len', map_shape[map_dim])
        self.exp.index['mapping'][name].set_shape(tuple(map_shape))

    def find_and_set_shape(self, data):
        pData = self.get_plugin_data()
        new_shape = copy.copy(data.get_shape()) + tuple(pData.extra_dims)
        self.set_shape(new_shape)

    def set_variable_flag(self):
        self.variable_length_flag = True

    def get_variable_flag(self):
        return self.variable_length_flag

    def set_variable_array_length(self, var_size):
        var_size = var_size if isinstance(var_size, list) else [var_size]
        shape = list(self.get_shape())
        count = 0
        for i in range(len(shape)):
            if isinstance(shape[i], str):
                shape[i] = var_size[count]
                count += 1
        self.next_shape = tuple(shape)

    def check_dims(self):
        nDims = self.data_info.get_meta_data("nDims")
        shape = self.data_info.get_meta_data('shape')
        if nDims:
            if self.get_variable_flag() is False:
                if len(shape) != nDims:
                    error_msg = ("The number of axis labels, %d, does not "
                                 "coincide with the number of data "
                                 "dimensions %d." % (nDims, len(shape)))
                    raise Exception(error_msg)

    def set_name(self, name):
        self.data_info.set_meta_data('name', name)

    def get_name(self):
        return self.data_info.get_meta_data('name')

    def set_data_params(self, pattern, chunk_size, **kwargs):
        self.set_current_pattern_name(pattern)
        self.set_nFrames(chunk_size)

    def set_available_pattern_list(self):
        pattern_list = ["SINOGRAM",
                        "PROJECTION",
                        "VOLUME_YZ",
                        "VOLUME_XZ",
                        "VOLUME_XY",
                        "VOLUME_3D",
                        "SPECTRUM",
                        "DIFFRACTION",
                        "CHANNEL",
                        "SPECTRUM_STACK",
                        "PROJECTION_STACK",
                        "METADATA"]
        return pattern_list

    def add_pattern(self, dtype, **kwargs):
        if dtype in self.pattern_list:
            nDims = 0
            for args in kwargs:
                nDims += len(kwargs[args])
                self.data_info.set_meta_data(['data_patterns', dtype, args],
                                             kwargs[args])
            self.convert_pattern_directions(dtype)
            if self.get_shape():
                diff = len(self.get_shape()) - nDims
                if diff:
                    pattern = {dtype: self.get_data_patterns()[dtype]}
                    self.add_extra_dims_to_patterns(pattern)
                    nDims += diff
            try:
                if nDims != self.data_info.get_meta_data("nDims"):
                    actualDims = self.data_info.get_meta_data('nDims')
                    err_msg = ("The pattern %s has an incorrect number of "
                               "dimensions: %d required but %d specified."
                               % (dtype, actualDims, nDims))
                    raise Exception(err_msg)
            except KeyError:
                self.data_info.set_meta_data('nDims', nDims)
        else:
            raise Exception("The data pattern '%s'does not exist. Please "
                            "choose from the following list: \n'%s'",
                            dtype, str(self.pattern_list))

    def add_volume_patterns(self, x, y, z):
        self.add_pattern("VOLUME_YZ", **self.get_dirs_for_volume(y, z, x))
        self.add_pattern("VOLUME_XZ", **self.get_dirs_for_volume(x, z, y))
        self.add_pattern("VOLUME_XY", **self.get_dirs_for_volume(x, y, z))

    def get_dirs_for_volume(self, dim1, dim2, sdir):
        all_dims = range(len(self.get_shape()))
        vol_dict = {}
        vol_dict['core_dir'] = (dim1, dim2)
        slice_dir = [sdir]
        # *** need to add this for other patterns
        for ddir in all_dims:
            if ddir not in [dim1, dim2, sdir]:
                slice_dir.append(ddir)
        vol_dict['slice_dir'] = tuple(slice_dir)
        return vol_dict

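    # Example (assumed dimensions): for a 3D dataset with x, y, z = 2, 1,
    # 0, add_volume_patterns gives VOLUME_XY a core_dir of (2, 1) and a
    # slice_dir of (0,), with remaining dimensions appended to slice_dir.
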
    def set_axis_labels(self, *args):
        self.data_info.set_meta_data('nDims', len(args))
        axis_labels = []
        for arg in args:
            try:
                axis = arg.split('.')
                axis_labels.append({axis[0]: axis[1]})
            except:
                # data arrives here, but that may be an error
                pass
        self.data_info.set_meta_data('axis_labels', axis_labels)

    def find_axis_label_dimension(self, name, contains=False):
        axis_labels = self.data_info.get_meta_data('axis_labels')
        for i in range(len(axis_labels)):
            if contains is True:
                for names in axis_labels[i].keys():
                    if name in names:
                        return i
            else:
                if name in axis_labels[i].keys():
                    return i
        raise Exception("Cannot find the specifed axis label.")

    def finalise_patterns(self):
        check = 0
        check += self.check_pattern('SINOGRAM')
        check += self.check_pattern('PROJECTION')
        if check == 2:
            self.set_main_axis('SINOGRAM')
            self.set_main_axis('PROJECTION')

    def check_pattern(self, pattern_name):
        patterns = self.get_data_patterns()
        try:
            patterns[pattern_name]
        except KeyError:
            return 0
        return 1

    def convert_pattern_directions(self, dtype):
        pattern = self.get_data_patterns()[dtype]
        nDims = sum([len(i) for i in pattern.values()])
        for p in pattern:
            ddirs = pattern[p]
            pattern[p] = self.non_negative_directions(ddirs, nDims)

    def non_negative_directions(self, ddirs, nDims):
        index = [i for i in range(len(ddirs)) if ddirs[i] < 0]
        list_ddirs = list(ddirs)
        for i in index:
            list_ddirs[i] = nDims + ddirs[i]
        return tuple(list_ddirs)

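    # Example: with nDims = 3 the direction tuple (0, -1) maps to (0, 2),
    # since negative entries count back from the final dimension.
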
    def check_direction(self, tdir, dname):
        if not isinstance(tdir, int):
            raise TypeError('The direction should be an integer.')

        patterns = self.get_data_patterns()
        if not patterns:
            raise Exception("Please add available patterns before setting the"
                            " direction ", dname)

    def set_main_axis(self, pname):
        patterns = self.get_data_patterns()
        n1 = 'PROJECTION' if pname == 'SINOGRAM' else 'SINOGRAM'
        d1 = patterns[n1]['core_dir']
        d2 = patterns[pname]['slice_dir']
        tdir = set(d1).intersection(set(d2))
        self.data_info.set_meta_data(['data_patterns', pname, 'main_dir'],
                                     list(tdir)[0])

    def trim_input_data(self, **kwargs):
        if self.tomo_raw_obj:
            self.get_tomo_raw().select_image_key(**kwargs)

    def trim_output_data(self, copy_obj, **kwargs):
        if self.tomo_raw_obj:
            self.get_tomo_raw().remove_image_key(copy_obj, **kwargs)
            self.set_preview([])

    def get_axis_label_keys(self):
        axis_labels = self.data_info.get_meta_data('axis_labels')
        axis_label_keys = []
        for labels in axis_labels:
            for key in labels.keys():
                axis_label_keys.append(key)
        return axis_label_keys

    def get_current_and_next_patterns(self, datasets_lists):
        current_datasets = datasets_lists[0]
        patterns_list = []
        for current_data in current_datasets['out_datasets']:
            current_name = current_data['name']
            current_pattern = current_data['pattern']
            next_pattern = self.find_next_pattern(datasets_lists[1:],
                                                  current_name)
            patterns_list.append({'current': current_pattern,
                                  'next': next_pattern})
        self.exp.meta_data.set_meta_data('current_and_next', patterns_list)

    def find_next_pattern(self, datasets_lists, current_name):
        next_pattern = []
        for next_data_list in datasets_lists:
            for next_data in next_data_list['in_datasets']:
                if next_data['name'] == current_name:
                    next_pattern = next_data['pattern']
                    return next_pattern
        return next_pattern
Exemplo n.º 34
0
class Experiment(object):
    """
    One instance of this class is created at the beginning of the
    processing chain and remains until the end.  It holds the current data
    object and a dictionary containing all metadata.
    """
    def __init__(self, options):
        self.meta_data = MetaData(options)
        self.__set_system_params()
        self.checkpoint = Checkpointing(self)
        self.__meta_data_setup(options["process_file"])
        self.collection = {}
        self.index = {"in_data": {}, "out_data": {}}
        self.initial_datasets = None
        self.plugin = None
        self._transport = None
        self._barrier_count = 0
        self._dataset_names_complete = False

    def get(self, entry):
        """ Get the meta data dictionary. """
        return self.meta_data.get(entry)

    def __meta_data_setup(self, process_file):
        self.meta_data.plugin_list = PluginList()
        try:
            rtype = self.meta_data.get('run_type')
            if rtype == 'test':
                self.meta_data.plugin_list.plugin_list = \
                    self.meta_data.get('plugin_list')
            else:
                raise Exception('the run_type is unknown in Experiment class')
        except KeyError:
            template = self.meta_data.get('template')
            self.meta_data.plugin_list._populate_plugin_list(process_file,
                                                             template=template)
        self.meta_data.set("nPlugin", 0)  # initialise

    def create_data_object(self, dtype, name, override=True):
        """ Create a data object.

        Plugin developers should apply this method in loaders only.

        :params str dtype: either "in_data" or "out_data".
        """
        if name not in list(self.index[dtype].keys()) or override:
            self.index[dtype][name] = Data(name, self)
            data_obj = self.index[dtype][name]
            data_obj._set_transport_data(self.meta_data.get('transport'))
        return self.index[dtype][name]

    def _setup(self, transport):
        self._set_nxs_file()
        self._set_transport(transport)
        self.collection = {'plugin_dict': [], 'datasets': []}

        self._barrier()
        self._check_checkpoint()
        self._barrier()

    def _finalise_setup(self, plugin_list):
        checkpoint = self.meta_data.get('checkpoint')
        self._set_dataset_names_complete()
        # save the plugin list - one process, first time only
        if self.meta_data.get('process') == \
                len(self.meta_data.get('processes'))-1 and not checkpoint:
            # links the input data to the nexus file
            plugin_list._save_plugin_list(self.meta_data.get('nxs_filename'))
            self._add_input_data_to_nxs_file(self._get_transport())

    def _set_initial_datasets(self):
        self.initial_datasets = copy.deepcopy(self.index['in_data'])

    def _update(self, plugin_dict):
        data = self.index['out_data'].copy()
        # clear output metadata after first setup
        for d in list(data.values()):
            d.meta_data._set_dictionary({})
        self.collection['datasets'].append(data)
        self.collection['plugin_dict'].append(plugin_dict)

    def _set_transport(self, transport):
        self._transport = transport

    def _get_transport(self):
        return self._transport

    def __set_system_params(self):
        sys_file = self.meta_data.get('system_params')
        import sys
        if sys_file is None:
            # look in conda environment to see which version is being used
            savu_path = sys.modules['savu'].__path__[0]
            sys_files = os.path.join(os.path.dirname(savu_path),
                                     'system_files')
            subdirs = os.listdir(sys_files)
            sys_folder = 'dls' if len(subdirs) > 1 else subdirs[0]
            fname = 'system_parameters.yml'
            sys_file = os.path.join(sys_files, sys_folder, fname)
        logging.info('Using the system parameters file: %s', sys_file)
        self.meta_data.set('system_params', yaml.read_yaml(sys_file))

    def _check_checkpoint(self):
        # if checkpointing has been set but the nxs file doesn't contain an
        # entry then remove checkpointing (as the previous run didn't get far
        # enough to require it).
        if self.meta_data.get('checkpoint'):
            with h5py.File(self.meta_data.get('nxs_filename'), 'r') as f:
                if 'entry' not in f:
                    self.meta_data.set('checkpoint', None)

    def _add_input_data_to_nxs_file(self, transport):
        # save the loaded data to file
        h5 = Hdf5Utils(self)
        for name, data in self.index['in_data'].items():
            self.meta_data.set(['link_type', name], 'input_data')
            self.meta_data.set(['group_name', name], name)
            self.meta_data.set(['filename', name], data.backing_file)
            transport._populate_nexus_file(data)
            h5._link_datafile_to_nexus_file(data)

    def _set_dataset_names_complete(self):
        """ Flag that any missing in/out_datasets fields have now been
        populated.
        """
        self._dataset_names_complete = True

    def _get_dataset_names_complete(self):
        return self._dataset_names_complete

    def _reset_datasets(self):
        self.index['in_data'] = self.initial_datasets

    def _get_collection(self):
        return self.collection

    def _set_experiment_for_current_plugin(self, count):
        datasets_list = self.meta_data.plugin_list._get_datasets_list()[count:]
        exp_coll = self._get_collection()
        self.index['out_data'] = exp_coll['datasets'][count]
        if datasets_list:
            self._get_current_and_next_patterns(datasets_list)
        self.meta_data.set('nPlugin', count)

    def _get_current_and_next_patterns(self, datasets_lists):
        """ Get the current and next patterns associated with a dataset
        throughout the processing chain.
        """
        current_datasets = datasets_lists[0]
        patterns_list = {}
        for current_data in current_datasets['out_datasets']:
            current_name = current_data['name']
            current_pattern = current_data['pattern']
            next_pattern = self.__find_next_pattern(datasets_lists[1:],
                                                    current_name)
            patterns_list[current_name] = \
                {'current': current_pattern, 'next': next_pattern}
        self.meta_data.set('current_and_next', patterns_list)

    def __find_next_pattern(self, datasets_lists, current_name):
        next_pattern = []
        for next_data_list in datasets_lists:
            for next_data in next_data_list['in_datasets']:
                if next_data['name'] == current_name:
                    next_pattern = next_data['pattern']
                    return next_pattern
        return next_pattern
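    # The resulting 'current_and_next' metadata entry has the form
    # (names and patterns illustrative):
    #     {'tomo': {'current': 'SINOGRAM', 'next': 'PROJECTION'}}
    # where 'next' is [] if no later plugin consumes the dataset.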

    def _set_nxs_file(self):
        folder = self.meta_data.get('out_path')
        fname = self.meta_data.get('datafile_name') + '_processed.nxs'
        filename = os.path.join(folder, fname)
        self.meta_data.set('nxs_filename', filename)

        if self.meta_data.get('process') == 1:
            if self.meta_data.get('bllog'):
                log_folder_name = self.meta_data.get('bllog')
                with open(log_folder_name, 'a') as log_folder:
                    log_folder.write(os.path.abspath(filename) + '\n')

        self._create_nxs_entry()

    def _create_nxs_entry(self):  # what if the file already exists?!
        logging.debug("Testing nexus file")
        if self.meta_data.get('process') == len(
                self.meta_data.get('processes')) - 1 and not self.checkpoint:
            with h5py.File(self.meta_data.get('nxs_filename'),
                           'w') as nxs_file:
                entry_group = nxs_file.create_group('entry')
                entry_group.attrs['NX_class'] = 'NXentry'

    def _clear_data_objects(self):
        self.index["out_data"] = {}
        self.index["in_data"] = {}

    def _merge_out_data_to_in(self):
        for key, data in self.index["out_data"].items():
            if data.remove is False:
                self.index['in_data'][key] = data
        self.index["out_data"] = {}

    def _finalise_experiment_for_current_plugin(self):
        finalise = {'remove': [], 'keep': []}
        # populate nexus file with out_dataset information and determine which
        # datasets to remove from the framework.

        for key, data in self.index['out_data'].items():
            if data.remove is True:
                finalise['remove'].append(data)
            else:
                finalise['keep'].append(data)

        # find in datasets to replace
        finalise['replace'] = []
        for out_name in list(self.index['out_data'].keys()):
            if out_name in list(self.index['in_data'].keys()):
                finalise['replace'].append(self.index['in_data'][out_name])

        return finalise

    def _reorganise_datasets(self, finalise):
        # unreplicate replicated in_datasets
        self.__unreplicate_data()

        # delete all datasets for removal
        for data in finalise['remove']:
            del self.index["out_data"][data.data_info.get('name')]

        # Add remaining output datasets to input datasets
        for name, data in self.index['out_data'].items():
            data.get_preview().set_preview([])
            self.index["in_data"][name] = copy.deepcopy(data)
        self.index['out_data'] = {}

    def __unreplicate_data(self):
        in_data_list = self.index['in_data']
        from savu.data.data_structures.data_types.replicate import Replicate
        for in_data in list(in_data_list.values()):
            if isinstance(in_data.data, Replicate):
                in_data.data = in_data.data._reset()

    def _set_all_datasets(self, name):
        return [key for key in self.index["in_data"]
                if 'itr_clone' not in key]

    def _barrier(self, communicator=MPI.COMM_WORLD, msg=''):
        comm_dict = {'comm': communicator}
        if self.meta_data.get('mpi') is True:
            logging.debug("Barrier %d: %d processes expected: %s",
                          self._barrier_count, communicator.size, msg)
            comm_dict['comm'].barrier()
        self._barrier_count += 1
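        # Note: every process must call _barrier() the same number of times
        # and in the same order, otherwise the run deadlocks; the counter
        # makes mismatched calls visible in the debug log.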

    def log(self, log_tag, log_level=logging.DEBUG):
        """
        Log the contents of the experiment at the specified level
        """
        logging.log(log_level, "Experimental Parameters for %s", log_tag)
        for key, value in self.index["in_data"].items():
            logging.log(log_level, "in data (%s) shape = %s", key,
                        value.get_shape())
        for key, value in self.index["in_data"].items():
            logging.log(log_level, "out data (%s) shape = %s", key,
                        value.get_shape())
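
A minimal standalone sketch (plain classes and dictionaries, not the Savu
API) of the dataset bookkeeping performed by _merge_out_data_to_in above:
outputs flagged for removal are dropped and every remaining output becomes
an input for the next plugin.

class FakeData(object):
    """Stand-in for a Data object; only the 'remove' flag matters here."""
    def __init__(self, remove=False):
        self.remove = remove

index = {'in_data': {},
         'out_data': {'tomo': FakeData(), 'junk': FakeData(remove=True)}}

for key, data in index['out_data'].items():
    if data.remove is False:
        index['in_data'][key] = data
index['out_data'] = {}

print(sorted(index['in_data']))  # ['tomo'] - 'junk' was discarded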
Exemplo n.º 35
0
class Experiment(object):
    """
    One instance of this class is created at the beginning of the
    processing chain and remains until the end.  It holds the current data
    object and a dictionary containing all metadata.
    """

    def __init__(self, options):
        self.meta_data = MetaData(options)
        self.__meta_data_setup(options["process_file"])
        self.index = {"in_data": {}, "out_data": {}, "mapping": {}}
        self.nxs_file = None

    def get_meta_data(self, entry):
        """ Get the meta data dictionary. """
        return self.meta_data.get_meta_data(entry)

    def __meta_data_setup(self, process_file):
        self.meta_data.plugin_list = PluginList()

        try:
            rtype = self.meta_data.get_meta_data('run_type')
            if rtype == 'test':
                self.meta_data.plugin_list.plugin_list = \
                    self.meta_data.get_meta_data('plugin_list')
            else:
                raise Exception('the run_type is unknown in Experiment class')
        except KeyError:
            self.meta_data.plugin_list._populate_plugin_list(process_file)

    def create_data_object(self, dtype, name):
        """ Create a data object.

        Plugin developers should use this method in loaders only.

        :params str dtype: either "in_data" or "out_data".
        """
        bases = []
        if name not in self.index[dtype]:
            self.index[dtype][name] = Data(name, self)
            data_obj = self.index[dtype][name]
            bases.append(data_obj._get_transport_data())
            cu.add_base_classes(data_obj, bases)
        return self.index[dtype][name]

    def _set_nxs_filename(self):
        folder = self.meta_data.get_meta_data('out_path')
        fname = os.path.basename(folder.split('_')[-1]) + '_processed.nxs'
        filename = os.path.join(folder, fname)
        self.meta_data.set_meta_data("nxs_filename", filename)

        if self.meta_data.get_meta_data("mpi") is True:
            self.nxs_file = h5py.File(filename, 'w', driver='mpio',
                                      comm=MPI.COMM_WORLD)
        else:
            self.nxs_file = h5py.File(filename, 'w')

    def __remove_dataset(self, data_obj):
        data_obj._close_file()
        del self.index["out_data"][data_obj.data_info.get_meta_data('name')]

    def _clear_data_objects(self):
        self.index["out_data"] = {}
        self.index["in_data"] = {}

    def _merge_out_data_to_in(self):
        for key, data in self.index["out_data"].items():
            if data.remove is False:
                if key in self.index['in_data']:
                    data.meta_data._set_dictionary(
                        self.index['in_data'][key].meta_data.get_dictionary())
                self.index['in_data'][key] = data
        self.index["out_data"] = {}

    def _reorganise_datasets(self, out_data_objs, link_type):
        out_data_list = self.index["out_data"]
        self.__close_unwanted_files(out_data_list)
        self.__remove_unwanted_data(out_data_objs)

        self._barrier()
        self.__copy_out_data_to_in_data(link_type)

        self._barrier()
        self.index['out_data'] = {}

    def __remove_unwanted_data(self, out_data_objs):
        for out_objs in out_data_objs:
            if out_objs.remove is True:
                self.__remove_dataset(out_objs)

    def __close_unwanted_files(self, out_data_list):
        for out_objs in out_data_list:
            if out_objs in self.index["in_data"]:
                self.index["in_data"][out_objs]._close_file()

    def __copy_out_data_to_in_data(self, link_type):
        for key in self.index["out_data"]:
            output = self.index["out_data"][key]
            output._save_data(link_type)
            self.index["in_data"][key] = copy.deepcopy(output)

    def _set_all_datasets(self, name):
        return list(self.index["in_data"].keys())

    def _barrier(self, communicator=MPI.COMM_WORLD):
        comm_dict = {'comm': communicator}
        if self.meta_data.get_meta_data('mpi') is True:
            logging.debug("About to hit a _barrier %s", comm_dict)
            comm_dict['comm'].barrier()
            logging.debug("Past the _barrier")

    def log(self, log_tag, log_level=logging.DEBUG):
        """
        Log the contents of the experiment at the specified level
        """
        logging.log(log_level, "Experimental Parameters for %s", log_tag)
        for key, value in self.index["in_data"].iteritems():
            logging.log(log_level, "in data (%s) shape = %s", key,
                        value.get_shape())
        for key, value in self.index["in_data"].iteritems():
            logging.log(log_level, "out data (%s) shape = %s", key,
                        value.get_shape())
Exemplo n.º 36
0
    def _output_template(self, fname, process_fname):
        plist = self.plist.plugin_list
        index = [i for i in range(len(plist)) if plist[i]['active']]

        local_dict = MetaData(ordered=True)
        global_dict = MetaData(ordered=True)
        local_dict.set(['process_list'], os.path.abspath(process_fname))

        for i in index:
            params = self.__get_template_params(plist[i]['data'], [])
            name = plist[i]['name']
            for p in params:
                ptype, isyaml, key, value = p
                if isyaml:
                    data_name = isyaml if ptype == 'local' else 'all'
                    local_dict.set([i + 1, name, data_name, key], value)
                elif ptype == 'local':
                    local_dict.set([i + 1, name, key], value)
                else:
                    global_dict.set(['all', name, key], value)

        with open(fname, 'w') as stream:
            local_dict.get_dictionary().update(global_dict.get_dictionary())
            yu.dump_yaml(local_dict.get_dictionary(), stream)
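
A short standalone sketch (plain dictionaries and PyYAML, not the MetaData
class) of the template merge above: plugin-local and global parameters are
collected separately and merged into a single dictionary before being
dumped to YAML. All names and values here are illustrative.

import yaml

local_params = {'process_list': '/path/to/process_list.nxs',
                'plugin_1': {'MyPlugin': {'param_a': 1}}}
global_params = {'all': {'MyPlugin': {'param_b': 2}}}

# dict.update merges in place, mirroring the use of update() above
local_params.update(global_params)
print(yaml.safe_dump(local_params, default_flow_style=False))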
Exemplo n.º 37
0
class Experiment(object):
    """
    One instance of this class is created at the beginning of the
    processing chain and remains until the end.  It holds the current data
    object and a dictionary containing all metadata.
    """

    def __init__(self, options):
        self.meta_data = MetaData(options)
        self.meta_data_setup(options["process_file"])
        self.index = {"in_data": {}, "out_data": {}, "mapping": {}}
        self.nxs_file = None

    def get_meta_data(self):
        return self.meta_data

    def meta_data_setup(self, process_file):
        self.meta_data.plugin_list = PluginList()

        try:
            rtype = self.meta_data.get_meta_data('run_type')
            if rtype == 'test':
                self.meta_data.plugin_list.plugin_list = \
                    self.meta_data.get_meta_data('plugin_list')
            else:
                raise Exception('the run_type is unknown in Experiment class')
        except KeyError:
            self.meta_data.plugin_list.populate_plugin_list(process_file)

    def create_data_object(self, dtype, name, bases=None):
        # avoid the shared-state pitfall of a mutable default argument
        bases = [] if bases is None else list(bases)
        if name not in self.index[dtype]:
            self.index[dtype][name] = Data(name, self)
            data_obj = self.index[dtype][name]
            bases.append(data_obj.get_transport_data())
            data_obj.add_base_classes(bases)
        return self.index[dtype][name]

    def set_nxs_filename(self):
        name = list(self.index["in_data"].keys())[0]
        filename = os.path.basename(self.index["in_data"][name].
                                    backing_file.filename)
        filename = os.path.splitext(filename)[0]
        filename = os.path.join(self.meta_data.get_meta_data("out_path"),
                                "%s_processed_%s.nxs" %
                                (filename, time.strftime("%Y%m%d%H%M%S")))
        self.meta_data.set_meta_data("nxs_filename", filename)
        if self.meta_data.get_meta_data("mpi") is True:
            self.nxs_file = h5py.File(filename, 'w', driver='mpio',
                                      comm=MPI.COMM_WORLD)
        else:
            self.nxs_file = h5py.File(filename, 'w')

    def remove_dataset(self, data_obj):
        data_obj.close_file()
        del self.index["out_data"][data_obj.data_info.get_meta_data('name')]

    def clear_data_objects(self):
        self.index["out_data"] = {}
        self.index["in_data"] = {}

    def clear_out_data_objects(self):
        self.index["out_data"] = {}

    def merge_out_data_to_in(self):
        for key, data in self.index["out_data"].items():
            if data.remove is False:
                if key in self.index['in_data']:
                    data.meta_data.set_dictionary(
                        self.index['in_data'][key].meta_data.get_dictionary())
                self.index['in_data'][key] = data
        self.index["out_data"] = {}

    def reorganise_datasets(self, out_data_objs, link_type):
        out_data_list = self.index["out_data"]
        self.close_unwanted_files(out_data_list)
        self.remove_unwanted_data(out_data_objs)

        self.barrier()
        self.copy_out_data_to_in_data(link_type)

        self.barrier()
        self.clear_out_data_objects()

    def remove_unwanted_data(self, out_data_objs):
        for out_objs in out_data_objs:
            if out_objs.remove is True:
                self.remove_dataset(out_objs)

    def close_unwanted_files(self, out_data_list):
        for out_objs in out_data_list:
            if out_objs in self.index["in_data"]:
                self.index["in_data"][out_objs].close_file()

    def copy_out_data_to_in_data(self, link_type):
        for key in self.index["out_data"]:
            output = self.index["out_data"][key]
            output.save_data(link_type)
            self.index["in_data"][key] = copy.deepcopy(output)

    def set_all_datasets(self, name):
        return list(self.index["in_data"].keys())

    def barrier(self, communicator=MPI.COMM_WORLD):
        comm_dict = {'comm': communicator}
        if self.meta_data.get_meta_data('mpi') is True:
            logging.debug("About to hit a barrier")
            comm_dict['comm'].Barrier()
            logging.debug("Past the barrier")

    def log(self, log_tag, log_level=logging.DEBUG):
        """
        Log the contents of the experiment at the specified level
        """
        logging.log(log_level, "Experimental Parameters for %s", log_tag)
        for key, value in self.index["in_data"].items():
            logging.log(log_level, "in data (%s) shape = %s", key,
                        value.get_shape())
        for key, value in self.index["out_data"].items():
            logging.log(log_level, "out data (%s) shape = %s", key,
                        value.get_shape())
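
A minimal standalone sketch (assumes mpi4py is installed; launch with
mpirun for more than one process) of the guarded-barrier idiom used in the
barrier methods above: synchronise only when MPI mode is active, so the
same code also runs serially.

from mpi4py import MPI

def guarded_barrier(running_mpi, comm=MPI.COMM_WORLD, msg=''):
    """Block until all processes arrive, but only in MPI mode."""
    if running_mpi:
        print("Barrier: %d processes expected: %s" % (comm.size, msg))
        comm.barrier()

guarded_barrier(MPI.COMM_WORLD.size > 1, msg='after setup')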
Exemplo n.º 38
0
 def __init__(self, options):
     self.meta_data = MetaData(options)
     self.meta_data_setup(options["process_file"])
     self.index = {"in_data": {}, "out_data": {}}
Exemplo n.º 39
0
class Data(DataCreate):
    """The Data class dynamically inherits from transport specific data class
    and holds the data array, along with associated information.
    """
    def __init__(self, name, exp):
        super(Data, self).__init__(name)
        self.meta_data = MetaData()
        self.pattern_list = self.__get_available_pattern_list()
        self.data_info = MetaData()
        self.__initialise_data_info(name)
        self._preview = Preview(self)
        self.exp = exp
        self.group_name = None
        self.group = None
        self._plugin_data_obj = None
        self.raw = None
        self.backing_file = None
        self.data = None
        self.next_shape = None
        self.orig_shape = None
        self.previous_pattern = None
        self.transport_data = None

    def __initialise_data_info(self, name):
        """ Initialise entries in the data_info meta data.
        """
        self.data_info.set('name', name)
        self.data_info.set('data_patterns', {})
        self.data_info.set('shape', None)
        self.data_info.set('nDims', None)

    def _set_plugin_data(self, plugin_data_obj):
        """ Encapsulate a PluginData object.
        """
        self._plugin_data_obj = plugin_data_obj

    def _clear_plugin_data(self):
        """ Set encapsulated PluginData object to None.
        """
        self._plugin_data_obj = None

    def _get_plugin_data(self):
        """ Get encapsulated PluginData object.
        """
        if self._plugin_data_obj is not None:
            return self._plugin_data_obj
        else:
            raise Exception("There is no PluginData object associated with "
                            "the Data object.")

    def get_preview(self):
        """ Get the Preview instance associated with the data object
        """
        return self._preview

    def _set_transport_data(self, transport):
        """ Import the data transport mechanism

        :returns: instance of data transport
        :rtype: transport_data
        """
        transport_data = "savu.data.transport_data." + transport + \
                         "_transport_data"
        transport_data = cu.import_class(transport_data)
        self.transport_data = transport_data(self)

    def _get_transport_data(self):
        return self.transport_data

    def __deepcopy__(self, memo):
        """ Copy the data object.
        """
        name = self.data_info.get('name')
        return dsu._deepcopy_data_object(self, Data(name, self.exp))

    def get_data_patterns(self):
        """ Get data patterns associated with this data object.

        :returns: A dictionary of associated patterns.
        :rtype: dict
        """
        return self.data_info.get('data_patterns')

    def _set_previous_pattern(self, pattern):
        self.previous_pattern = pattern

    def get_previous_pattern(self):
        return self.previous_pattern

    def set_shape(self, shape):
        """ Set the dataset shape.
        """
        self.data_info.set('shape', shape)
        self.__check_dims()

    def set_original_shape(self, shape):
        """ Set the original data shape before previewing
        """
        self.orig_shape = shape
        self.set_shape(shape)

    def get_shape(self):
        """ Get the dataset shape

        :returns: data shape
        :rtype: tuple
        """
        shape = self.data_info.get('shape')
        return shape

    def __check_dims(self):
        """ Check the ``shape`` and ``nDims`` entries in the data_info
        meta_data dictionary are equal.
        """
        nDims = self.data_info.get("nDims")
        shape = self.data_info.get('shape')
        if nDims:
            if len(shape) != nDims:
                error_msg = ("The number of axis labels, %d, does not "
                             "coincide with the number of data "
                             "dimensions %d." % (nDims, len(shape)))
                raise Exception(error_msg)

    def _set_name(self, name):
        self.data_info.set('name', name)

    def get_name(self):
        """ Get data name.

        :returns: the name associated with the dataset
        :rtype: str
        """
        return self.data_info.get('name')

    def __get_available_pattern_list(self):
        """ Get a list of ALL pattern names that are currently allowed in the
        framework.
        """
        pattern_list = dsu.get_available_pattern_types()
        return pattern_list

    def add_pattern(self, dtype, **kwargs):
        """ Add a pattern.

        :params str dtype: The *type* of pattern to add, which can be anything
            from :data:`savu.data.data_structures.utils.pattern_list`.
        :keyword tuple core_dims: Dimension indices of core dimensions
        :keyword tuple slice_dims: Dimension indices of slice dimensions
        """
        if dtype in self.pattern_list:
            nDims = 0
            for args in kwargs:
                nDims += len(kwargs[args])
                self.data_info.set(['data_patterns', dtype, args],
                                   kwargs[args])

            self.__convert_pattern_directions(dtype)
            if self.get_shape():
                diff = len(self.get_shape()) - nDims
                if diff:
                    pattern = {dtype: self.get_data_patterns()[dtype]}
                    self._set_data_patterns(pattern)
                    nDims += diff
            try:
                if nDims != self.data_info.get("nDims"):
                    actualDims = self.data_info.get('nDims')
                    err_msg = ("The pattern %s has an incorrect number of "
                               "dimensions: %d required but %d specified." %
                               (dtype, actualDims, nDims))
                    raise Exception(err_msg)
            except KeyError:
                self.data_info.set('nDims', nDims)
        else:
            raise Exception(
                "The data pattern '%s' does not exist. Please "
                "choose from the following list: \n'%s'" %
                (dtype, str(self.pattern_list)))
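    # A hypothetical usage sketch for 3D data ordered as
    # (rotation_angle, detector_y, detector_x); indices are illustrative:
    #
    #     data_obj.add_pattern('SINOGRAM', core_dims=(0, 2), slice_dims=(1,))
    #     data_obj.add_pattern('PROJECTION', core_dims=(1, 2),
    #                          slice_dims=(0,))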

    def add_volume_patterns(self, x, y, z):
        """ Adds 3D volume patterns

        :params int x: dimension to be associated with x-axis
        :params int y: dimension to be associated with y-axis
        :params int z: dimension to be associated with z-axis
        """
        self.add_pattern("VOLUME_YZ", **self.__get_dirs_for_volume(y, z, x))
        self.add_pattern("VOLUME_XZ", **self.__get_dirs_for_volume(x, z, y))
        self.add_pattern("VOLUME_XY", **self.__get_dirs_for_volume(x, y, z))

    def __get_dirs_for_volume(self, dim1, dim2, sdir):
        """ Calculate core_dir and slice_dir for a 3D volume pattern.
        """
        all_dims = range(len(self.get_shape()))
        vol_dict = {}
        vol_dict['core_dims'] = (dim1, dim2)
        slice_dir = [sdir]
        # *** need to add this for other patterns
        for ddir in all_dims:
            if ddir not in [dim1, dim2, sdir]:
                slice_dir.append(ddir)
        vol_dict['slice_dims'] = tuple(slice_dir)
        return vol_dict

    def set_axis_labels(self, *args):
        """ Set the axis labels associated with each data dimension.

        :arg str: Each arg should be of the form ``name.unit``. If ``name`` is\
        a data_obj.meta_data entry, it will be output to the final .nxs file.
        """
        self.data_info.set('nDims', len(args))
        axis_labels = []
        for arg in args:
            if isinstance(arg, dict):
                axis_labels.append(arg)
            else:
                try:
                    axis = arg.split('.')
                    axis_labels.append({axis[0]: axis[1]})
                except IndexError:
                    # arg was not of the form 'name.unit'; data arrives
                    # here, but that may be an error
                    pass
        self.data_info.set('axis_labels', axis_labels)
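    # e.g. for a standard 3D tomography dataset (labels illustrative):
    #
    #     data_obj.set_axis_labels('rotation_angle.degrees',
    #                              'detector_y.pixels', 'detector_x.pixels')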

    def get_axis_labels(self):
        """ Get axis labels.

        :returns: Axis labels
        :rtype: list(dict)
        """
        return self.data_info.get('axis_labels')

    def get_data_dimension_by_axis_label(self, name, contains=False):
        """ Get the dimension of the data associated with a particular
        axis_label.

        :param str name: The name of the axis_label
        :keyword bool contains: Set this flag to true if the name is only part
            of the axis_label name
        :returns: The associated axis number
        :rtype: int
        """
        axis_labels = self.data_info.get('axis_labels')
        for i in range(len(axis_labels)):
            if contains is True:
                for names in axis_labels[i].keys():
                    if name in names:
                        return i
            else:
                if name in axis_labels[i].keys():
                    return i
        raise Exception("Cannot find the specifed axis label.")

    def _finalise_patterns(self):
        """ Adds a main axis (fastest changing) to SINOGRAM and PROJECTON
        patterns.
        """
        check = 0
        check += self.__check_pattern('SINOGRAM')
        check += self.__check_pattern('PROJECTION')

        if check == 2 and len(self.get_shape()) > 2:
            self.__set_main_axis('SINOGRAM')
            self.__set_main_axis('PROJECTION')

    def __check_pattern(self, pattern_name):
        """ Check if a pattern exists.
        """
        return 1 if pattern_name in self.get_data_patterns() else 0

    def __convert_pattern_directions(self, dtype):
        """ Replace negative indices in pattern kwargs.
        """
        pattern = self.get_data_patterns()[dtype]
        if 'main_dir' in pattern:
            del pattern['main_dir']

        nDims = sum([len(i) for i in pattern.values()])
        for p in pattern:
            ddirs = pattern[p]
            pattern[p] = self._non_negative_directions(ddirs, nDims)

    def _non_negative_directions(self, ddirs, nDims):
        """ Replace negative indexing values with positive counterparts.

        :params tuple(int) ddirs: data dimension indices
        :params int nDims: The number of data dimensions
        :returns: non-negative data dimension indices
        :rtype: tuple(int)
        """
        index = [i for i in range(len(ddirs)) if ddirs[i] < 0]
        list_ddirs = list(ddirs)
        for i in index:
            list_ddirs[i] = nDims + ddirs[i]
        return tuple(list_ddirs)
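    # Worked example: with nDims=3, ddirs=(-1, 0) becomes (2, 0), since the
    # negative index is replaced by 3 + (-1) = 2.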

    def __set_main_axis(self, pname):
        """ Set the ``main_dir`` pattern kwarg to the fastest changing
        dimension
        """
        patterns = self.get_data_patterns()
        n1 = 'PROJECTION' if pname == 'SINOGRAM' else 'SINOGRAM'
        d1 = patterns[n1]['core_dims']
        d2 = patterns[pname]['slice_dims']
        tdir = set(d1).intersection(set(d2))

        # this is required when a single sinogram exists in the mm case, and a
        # dimension is added via parameter tuning.
        if not tdir:
            tdir = [d2[0]]

        self.data_info.set(['data_patterns', pname, 'main_dir'], list(tdir)[0])

    def get_axis_label_keys(self):
        """ Get axis_label names

        :returns: A list containing associated axis names for each dimension
        :rtype: list(str)
        """
        axis_labels = self.data_info.get('axis_labels')
        axis_label_keys = []
        for labels in axis_labels:
            for key in labels.keys():
                axis_label_keys.append(key)
        return axis_label_keys

    def amend_axis_label_values(self, slice_list):
        """ Amend all axis label values based on the slice_list parameter.\
        This is required if the data is reduced.
        """
        axis_labels = self.get_axis_labels()
        for i in range(len(slice_list)):
            label = list(axis_labels[i].keys())[0]
            if label in self.meta_data.get_dictionary():
                values = self.meta_data.get(label)
                preview_sl = [slice(None)] * len(values.shape)
                preview_sl[0] = slice_list[i]
                self.meta_data.set(label, values[tuple(preview_sl)])

    def get_core_dimensions(self):
        """ Get the core data dimensions associated with the current pattern.

        :returns: value associated with pattern key ``core_dims``
        :rtype: tuple
        """
        pattern = self._get_plugin_data().get_pattern()
        return list(pattern.values())[0]['core_dims']

    def get_slice_dimensions(self):
        """ Get the slice data dimensions associated with the current pattern.

        :returns: value associated with pattern key ``slice_dims``
        :rtype: tuple
        """
        pattern = self._get_plugin_data().get_pattern()
        return list(pattern.values())[0]['slice_dims']

    def _add_raw_data_obj(self, data_obj):
        from savu.data.data_structures.data_types.data_plus_darks_and_flats\
            import ImageKey, NoImageKey

        if not isinstance(data_obj.raw, (ImageKey, NoImageKey)):
            raise Exception('Raw data type not recognised.')

        proj_dim = self.get_data_dimension_by_axis_label('rotation_angle')
        data_obj.data = NoImageKey(data_obj, None, proj_dim)
        data_obj.data._set_dark_and_flat()

        if isinstance(data_obj.raw, ImageKey):
            data_obj.raw._convert_to_noimagekey(data_obj.data)
            data_obj.data._set_fake_key(data_obj.raw.get_image_key())