コード例 #1
0
    def _output_template(self, fname):
        """ Collect the template parameters of every active plugin in the
        plugin list and dump them to ``fname`` through ``yu.dump_yaml``.

        :param str fname: path of the file that is opened for writing
        """
        plugins = self.plist.plugin_list
        active = [pos for pos, entry in enumerate(plugins) if entry['active']]

        local_dict = MetaData(ordered=True)
        global_dict = MetaData(ordered=True)

        for pos in active:
            entry = plugins[pos]
            params = self.__get_template_params(entry['data'], [])
            name = entry['name']
            for ptype, isyaml, key, value in params:
                is_local = ptype == 'local'
                if isyaml:
                    # entries flagged by isyaml always go to the local dict
                    data_name = isyaml if is_local else 'all'
                    local_dict.set([pos+1, name, data_name, key], value)
                elif is_local:
                    # keyed by the 1-based position in the plugin list
                    local_dict.set([pos+1, name, key], value)
                else:
                    global_dict.set(['all', name, key], value)

        with open(fname, 'w') as stream:
            # fold the global entries into the local dictionary before dumping
            merged = local_dict.get_dictionary()
            merged.update(global_dict.get_dictionary())
            yu.dump_yaml(local_dict.get_dictionary(), stream)
コード例 #2
0
    def _output_template(self, fname, process_fname):
        """ Collect the template parameters of every active plugin in the
        plugin list and dump them to ``fname`` through ``yu.dump_yaml``.

        The absolute path of ``process_fname`` is stored under the
        ``process_list`` key of the dumped dictionary.

        :param str fname: path of the file that is opened for writing
        :param str process_fname: path recorded under ``process_list``
        """
        plugins = self.plist.plugin_list
        active = [pos for pos, entry in enumerate(plugins) if entry['active']]

        local_dict = MetaData(ordered=True)
        global_dict = MetaData(ordered=True)
        local_dict.set(['process_list'], os.path.abspath(process_fname))

        for pos in active:
            entry = plugins[pos]
            params = self.__get_template_params(entry['data'], [])
            name = entry['name']
            for ptype, isyaml, key, value in params:
                is_local = ptype == 'local'
                if isyaml:
                    # entries flagged by isyaml always go to the local dict
                    data_name = isyaml if is_local else 'all'
                    local_dict.set([pos+1, name, data_name, key], value)
                elif is_local:
                    # keyed by the 1-based position in the plugin list
                    local_dict.set([pos+1, name, key], value)
                else:
                    global_dict.set(['all', name, key], value)

        with open(fname, 'w') as stream:
            # fold the global entries into the local dictionary before dumping
            merged = local_dict.get_dictionary()
            merged.update(global_dict.get_dictionary())
            yu.dump_yaml(local_dict.get_dictionary(), stream)
コード例 #3
0
ファイル: plugin_data.py プロジェクト: DimitriosBellos/Savu
class PluginData(object):
    """ The PluginData class contains plugin specific information about a Data
    object for the duration of a plugin.  An instance of the class is
    encapsulated inside the Data object during the plugin run
    """
    def __init__(self, data_obj, plugin=None):
        """ Create the plugin data and register it with ``data_obj``.

        :param data_obj: The Data object this instance is attached to
        :param plugin: The plugin using the data (optional)
        """
        self.data_obj = data_obj
        self._preview = None
        self.data_obj._set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = None
        self.pad_dict = None
        self.shape = None
        self.shape_transfer = None
        self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []
        self._plugin = plugin
        self.fixed_dims = True
        self.split = None
        self.boundary_padding = None
        self.no_squeeze = False
        self.pre_tuning_shape = None
        self._frame_limit = None

    def _get_preview(self):
        """ Return the stored preview (``None`` until one is assigned). """
        return self._preview

    def get_total_frames(self):
        """ Get the total number of frames to process (all MPI processes).

        :returns: Number of frames
        :rtype: int
        """
        temp = 1
        slice_dir = \
            self.data_obj.get_data_patterns()[
                self.get_pattern_name()]["slice_dims"]
        for tslice in slice_dir:
            temp *= self.data_obj.get_shape()[tslice]
        return temp

    def __set_pattern(self, name):
        """ Set the pattern related information in the meta data dict.

        :param str name: A key of the data object's pattern dictionary
        """
        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set("name", name)
        self.meta_data.set("core_dims", pattern['core_dims'])
        self.__set_slice_dimensions()

    def get_pattern_name(self):
        """ Get the pattern name.

        :returns: the pattern name
        :rtype: str
        :raises Exception: if the meta data lookup raises a KeyError
        """
        try:
            name = self.meta_data.get("name")
            return name
        except KeyError:
            raise Exception("The pattern name has not been set.")

    def get_pattern(self):
        """ Get the current pattern.

        :returns: dict of the pattern name against the pattern.
        :rtype: dict
        """
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def __set_shape(self):
        """ Set the shape of the plugin data processing chunk.

        The core shape is stored first; the first slice dimension is only
        added to the full shape when more than one frame is processed at a
        time or when squeezing has been disabled.
        """
        core_dir = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()
        dirs = list(set(core_dir + (slice_dir[0], )))
        slice_idx = dirs.index(slice_dir[0])
        dshape = self.data_obj.get_shape()
        shape = []
        for core in set(core_dir):
            shape.append(dshape[core])
        self.__set_core_shape(tuple(shape))

        mfp = self._get_max_frames_process()
        if mfp > 1 or self._get_no_squeeze():
            shape.insert(slice_idx, mfp)
        self.shape = tuple(shape)

    def _set_shape_transfer(self, slice_size):
        """ Set the shape of the data transferred each time.

        Core dimensions take their length from the data object's shape and
        slice dimensions take theirs, in order, from ``slice_size``.

        :param list slice_size: length of each slice dimension; it is
            extended with ones for any dimensions the data shape has in
            addition to the shape before tuning
        """
        dshape = self.data_obj.get_shape()
        shape_before_tuning = self._get_shape_before_tuning()
        add = [1] * (len(dshape) - len(shape_before_tuning))
        slice_size = slice_size + add

        core_dir = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()
        shape = [None] * len(dshape)
        for dim in core_dir:
            shape[dim] = dshape[dim]
        i = 0
        for dim in slice_dir:
            shape[dim] = slice_size[i]
            i += 1
        self.shape_transfer = tuple(shape)

    def get_shape(self):
        """ Get the shape of the data (without padding) that is passed to the
        plugin process_frames method.
        """
        return self.shape

    def _set_padded_shape(self):
        """ Placeholder: does nothing in this class. """
        pass

    def get_padded_shape(self):
        """ Get the shape of the data (with padding) that is passed to the
        plugin process_frames method.

        In this class the value returned is the same as ``get_shape``.
        """
        return self.shape

    def get_shape_transfer(self):
        """ Get the shape of the plugin data to be transferred each time.
        """
        return self.shape_transfer

    def __set_core_shape(self, shape):
        """ Set the core shape to hold only the shape of the core dimensions
        """
        self.core_shape = shape

    def get_core_shape(self):
        """ Get the shape of the core dimensions only.

        :returns: shape of core dimensions
        :rtype: tuple
        """
        return self.core_shape

    def _set_shape_before_tuning(self, shape):
        """ Set the shape of the full dataset used during each run of the \
        plugin (i.e. ignore extra dimensions due to parameter tuning). """
        self.pre_tuning_shape = shape

    def _get_shape_before_tuning(self):
        """ Return the shape of the full dataset used during each run of the \
        plugin (i.e. ignore extra dimensions due to parameter tuning).

        Falls back to the data object's shape when no shape has been set.
        """
        return self.pre_tuning_shape if self.pre_tuning_shape else\
            self.data_obj.get_shape()

    def __check_dimensions(self, indices, core_dir, slice_dir, nDims):
        """ Exit if the number of indices or dimensions is inconsistent.

        :param indices: one index per slice dimension
        :param core_dir: the core dimensions
        :param slice_dir: the slice dimensions
        :param int nDims: the expected total number of dimensions
        """
        # Compare by value: ``is not`` tests object identity, which is not
        # reliable for equal integers.
        if len(indices) != len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")

        if (len(core_dir) + len(slice_dir)) != nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def __set_slice_dimensions(self):
        """ Set the slice dimensions in the pluginData meta data dictionary.
        """
        slice_dirs = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['slice_dims']
        self.meta_data.set('slice_dims', slice_dirs)

    def get_slice_dimension(self):
        """
        Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()[0]
        return list(set(core_dirs + (slice_dir, ))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """
        Return the dimension of the data in the plugin that has the specified
        axis label.

        :param label: the axis label, passed on to the data object
        :keyword bool contains: passed on to the data object lookup
        """
        label_dim = self.data_obj.get_data_dimension_by_axis_label(
            label, contains=contains)
        plugin_dims = self.data_obj.get_core_dimensions()
        # self.max_frames is only assigned in __set_max_frames, i.e. after
        # plugin_data_setup has been called.
        if self._get_max_frames_process() > 1 or self.max_frames == 'multiple':
            plugin_dims += (self.get_slice_dimension(), )
        return list(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):
        """
        Reorder the slice dimensions.  The fastest changing slice dimension
        will always be the first one stated in the pattern key ``slice_dir``.
        The input param is a tuple stating the desired order of slicing
        dimensions relative to the current order.
        """
        # NOTE(review): get_slice_directions and get_current_pattern are not
        # defined in this class as shown here - confirm they are provided
        # elsewhere.
        slice_dirs = self.get_slice_directions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specifed.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        self.get_current_pattern()['slice_dir'] = new_slice_dirs

    def get_core_dimensions(self):
        """
        Return the position of the core dimensions in relation to the data
        handed to the plugin.

        :returns: indices of the sorted core dimensions within the sorted
            core dimensions plus first slice dimension
        """
        core_dims = self.data_obj.get_core_dimensions()
        first_slice_dim = (self.data_obj.get_slice_dimensions()[0], )
        plugin_dims = np.sort(core_dims + first_slice_dim)
        return np.searchsorted(plugin_dims, np.sort(core_dims))

    def set_fixed_dimensions(self, dims, values):
        """ Fix a data direction to the index in values list.

        The fixed dimensions are given length 1 in the data object's shape
        and the plugin shape is recalculated.

        :param list(int) dims: Directions to fix
        :param list(int) value: Index of fixed directions
        :raises Exception: if a dimension in ``dims`` is not a slice dimension
        """
        slice_dirs = self.data_obj.get_slice_dimensions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set("fixed_dimensions", dims)
        self.meta_data.set("fixed_dimensions_values", values)
        self.__set_slice_dimensions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        self.__set_shape()

    def _get_fixed_dimensions(self):
        """ Get the fixed data directions and their indices

        :returns: Fixed directions and their associated values (two empty
            lists if nothing has been fixed)
        :rtype: list(list(int), list(int))
        """
        fixed = []
        values = []
        if 'fixed_dimensions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get("fixed_dimensions")
            values = self.meta_data.get("fixed_dimensions_values")
        return [fixed, values]

    def _get_data_slice_list(self, plist):
        """ Convert a plugin data slice list to a slice list for the whole
        dataset, i.e. add in any missing dimensions.

        :param plist: the plugin data slice list
        :returns: ``plist`` with ``slice(None)`` inserted for each extra
            dimension
        :rtype: tuple
        """
        nDims = len(self.get_shape())
        # NOTE(review): get_core_dimensions() returns a numpy array and
        # get_slice_dimension() an int, so this ``+`` is an element-wise
        # addition rather than a concatenation - confirm this is intended.
        all_dims = self.get_core_dimensions() + self.get_slice_dimension()
        extra_dims = all_dims[nDims:]
        dlist = list(plist)
        for i in extra_dims:
            dlist.insert(i, slice(None))
        return tuple(dlist)

    def _get_max_frames_process(self):
        """ Get the number of frames to process for each run of process_frames.

        If the number of frames is not divisible by the previewing ``chunk``
        value then amend the number of frames to gcd(frames, chunk)

        :returns: Number of frames to process
        :rtype: int
        """
        # NOTE(review): get_slice_directions is not defined in this class as
        # shown here - confirm it is provided elsewhere.
        if self._plugin and self._plugin.chunk > 1:
            frame_chunk = self.meta_data.get("max_frames_process")
            chunk = self.data_obj.get_preview().get_starts_stops_steps(
                key='chunks')[self.get_slice_directions()[0]]
            self.meta_data.set('max_frames_process', gcd(frame_chunk, chunk))
        return self.meta_data.get("max_frames_process")

    def _get_max_frames_transfer(self):
        """ Get the number of frames to transfer for each run of
        process_frames. """
        return self.meta_data.get('max_frames_transfer')

    def _set_no_squeeze(self):
        """ Flag that the slice dimension must be kept in the plugin shape.
        """
        self.no_squeeze = True

    def _get_no_squeeze(self):
        """ Return the no_squeeze flag. """
        return self.no_squeeze

    def _get_max_frames_parameters(self):
        """ Gather the values used to reason about frame distribution.

        :returns: dict with the keys ``shape``, ``sdir``, ``total_frames``,
            ``mpi_procs`` and ``frames_per_process``
        :rtype: dict
        """
        fixed, _ = self._get_fixed_dimensions()
        sdir = \
            [s for s in self.data_obj.get_slice_dimensions() if s not in fixed]
        shape = self.data_obj.get_shape()
        shape_before_tuning = self._get_shape_before_tuning()

        # Ignore the trailing dimensions that the current shape has in
        # addition to the shape before tuning.
        diff = len(shape) - len(shape_before_tuning)
        if diff:
            shape = shape_before_tuning
            sdir = sdir[:-diff]

        frames = np.prod([shape[d] for d in sdir])
        base_names = [p.__name__ for p in self._plugin.__class__.__bases__]
        processes = self.data_obj.exp.meta_data.get('processes')

        if 'GpuPlugin' in base_names:
            n_procs = len([n for n in processes if 'GPU' in n])
        else:
            n_procs = len(processes)

        # Force a true division so that the ceil is not applied to an
        # already floored integer quotient.
        f_per_p = np.ceil(frames / float(n_procs))
        params_dict = {
            'shape': shape,
            'sdir': sdir,
            'total_frames': frames,
            'mpi_procs': n_procs,
            'frames_per_process': f_per_p
        }
        return params_dict

    def __log_max_frames(self, mft, mfp, check=True):
        """ Log the max frames values and store max_frames_process.

        :param mft: max frames transfer (may be falsy)
        :param mfp: max frames process
        :keyword bool check: run the frame distribution check if True
        """
        # Arguments are handed to logging unformatted so that a value of
        # None cannot raise a TypeError while the message is built.
        logging.debug("Setting max frames transfer for plugin %s to %s",
                      self._plugin, mft)
        logging.debug("Setting max frames process for plugin %s to %s",
                      self._plugin, mfp)
        self.meta_data.set('max_frames_process', mfp)
        if check:
            self.__check_distribution(mft)
        # (((total_frames/mft)/mpi_procs) % 1)

    def __check_distribution(self, mft):
        """ Warn when the frames do not divide evenly between processes.

        :param mft: max frames transfer
        """
        self.params = self._get_max_frames_parameters()
        warn_threshold = 0.85
        nprocs = self.params['mpi_procs']
        nframes = self.params['total_frames']
        # fractional part of the number of transfers per process
        temp = (((nframes / mft) / float(nprocs)) % 1)
        if temp != 0.0 and temp < warn_threshold:
            logging.warning(
                'UNEVEN FRAME DISTRIBUTION: shape %s, nframes %s ' +
                'sdir %s, nprocs %s', self.params['shape'], nframes,
                self.params['sdir'], nprocs)

    def _set_padding_dict(self):
        """ Replace a plain padding value with a Padding object.

        The original value is deep-copied to ``pad_dict``; each of its keys
        is used as the name of a Padding method, which is called with the
        associated value.
        """
        if self.padding and not isinstance(self.padding, Padding):
            self.pad_dict = copy.deepcopy(self.padding)
            self.padding = Padding(self)
            for key in self.pad_dict.keys():
                getattr(self.padding, key)(self.pad_dict[key])

    def plugin_data_setup(self, pattern, nFrames, split=None):
        """ Setup the PluginData object.

        # add more information into here via a decorator!
        :param str pattern: A pattern name
        :param int nFrames: How many frames to process at a time.  Choose from\
            'single', 'multiple', 'fixed_multiple' or an integer (an integer \
            should only ever be passed in exceptional circumstances).  A \
            two element list is unpacked into nFrames and a frame limit.
        :keyword split: stored unchanged on the instance
        """
        self.__set_pattern(pattern)
        chunks = \
            self.data_obj.get_preview().get_starts_stops_steps(key='chunks')

        if isinstance(nFrames, list):
            nFrames, self._frame_limit = nFrames

        self.__set_max_frames(nFrames)
        mft = self.meta_data.get('max_frames_transfer')
        # Flag the plugin when the chunk value of the first slice dimension
        # is not a multiple of max frames transfer.
        if self._plugin and mft \
                and (chunks[self.data_obj.get_slice_dimensions()[0]] % mft):
            self._plugin.chunk = True
        self.__set_shape()
        self.split = split

    def __set_max_frames(self, nFrames):
        """ Calculate and store the max frames transfer and process values.

        :param nFrames: the requested number of frames
        """
        self.max_frames = nFrames
        self.__perform_checks(nFrames)
        td = self.data_obj._get_transport_data()
        mft, mft_shape = td._calc_max_frames_transfer(nFrames)
        self.meta_data.set('max_frames_transfer', mft)
        if mft:
            self._set_shape_transfer(mft_shape)
        mfp = td._calc_max_frames_process(nFrames)
        self.meta_data.set('max_frames_process', mfp)
        self.__log_max_frames(mft, mfp)

        # Retain the shape if the first slice dimension has length 1
        if mfp == 1 and nFrames == 'multiple':
            self._set_no_squeeze()

    def __perform_checks(self, nFrames):
        """ Check that nFrames is an integer or a recognised option.

        :param nFrames: the requested number of frames
        :raises Exception: if nFrames is neither an int nor one of
            'single' and 'multiple'
        """
        options = ['single', 'multiple']
        if not isinstance(nFrames, int) and nFrames not in options:
            # The parentheses are required: without them only the first
            # literal is assigned and the rest of the message is lost.
            e_str = (
                "The value of nFrames is not recognised.  Please choose "
                "from 'single' and 'multiple' (or an integer in exceptional "
                "circumstances).")
            raise Exception(e_str)

    def get_frame_limit(self):
        """ Return the frame limit (``None`` unless plugin_data_setup was
        given a list). """
        return self._frame_limit
コード例 #4
0
ファイル: plugin_data.py プロジェクト: rcatwood/Savu
class PluginData(object):
    """ The PluginData class contains plugin specific information about a Data
    object for the duration of a plugin.  An instance of the class is
    encapsulated inside the Data object during the plugin run
    """

    def __init__(self, data_obj, plugin=None):
        """ Create the plugin data and register it with ``data_obj``.

        :param data_obj: The Data object this instance is attached to
        :param plugin: The plugin using the data (optional)
        """
        self.data_obj = data_obj
        self.data_obj._set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = None
        # this flag determines which data is passed. If false then just the
        # data, if true then all data including dark and flat fields.
        self.selected_data = False
        self.shape = None
        self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []
        self._plugin = plugin

    def get_total_frames(self):
        """ Get the total number of frames to process.

        :returns: Number of frames
        :rtype: int
        """
        temp = 1
        slice_dir = \
            self.data_obj.get_data_patterns()[
                self.get_pattern_name()]["slice_dir"]
        for tslice in slice_dir:
            temp *= self.data_obj.get_shape()[tslice]
        return temp

    def __set_pattern(self, name):
        """ Set the pattern related information in the meta data dict.

        :param str name: A key of the data object's pattern dictionary
        """
        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set_meta_data("name", name)
        self.meta_data.set_meta_data("core_dir", pattern['core_dir'])
        self.__set_slice_directions()

    def get_pattern_name(self):
        """ Get the pattern name.

        :returns: the pattern name
        :rtype: str
        :raises Exception: if the stored name is None
        """
        name = self.meta_data.get_meta_data("name")
        if name is not None:
            return name
        else:
            raise Exception("The pattern name has not been set.")

    def get_pattern(self):
        """ Get the current pattern.

        :returns: dict of the pattern name against the pattern.
        :rtype: dict
        """
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def __set_shape(self):
        """ Set the shape of the plugin data processing chunk.

        The core shape is stored first; the first slice direction is only
        added to the full shape when more than one frame is processed at a
        time.
        """
        core_dir = self.get_core_directions()
        slice_dir = self.get_slice_directions()
        dirs = list(set(core_dir + (slice_dir[0],)))
        slice_idx = dirs.index(slice_dir[0])
        shape = []
        for core in set(core_dir):
            shape.append(self.data_obj.get_shape()[core])
        self.__set_core_shape(tuple(shape))
        if self._get_frame_chunk() > 1:
            shape.insert(slice_idx, self._get_frame_chunk())
        self.shape = tuple(shape)

    def get_shape(self):
        """ Get the shape of the plugin data to be processed each time.
        """
        return self.shape

    def __set_core_shape(self, shape):
        """ Set the core shape to hold only the shape of the core dimensions
        """
        self.core_shape = shape

    def get_core_shape(self):
        """ Get the shape of the core dimensions only.

        :returns: shape of core dimensions
        :rtype: tuple
        """
        return self.core_shape

    def __check_dimensions(self, indices, core_dir, slice_dir, nDims):
        """ Exit if the number of indices or dimensions is inconsistent.

        :param indices: one index per slice direction
        :param core_dir: the core directions
        :param slice_dir: the slice directions
        :param int nDims: the expected total number of dimensions
        """
        # Compare by value: ``is not`` tests object identity, which is not
        # reliable for equal integers.
        if len(indices) != len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")

        if (len(core_dir) + len(slice_dir)) != nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def __set_slice_directions(self):
        """ Set the slice directions in the pluginData meta data dictionary.
        """
        slice_dirs = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['slice_dir']
        self.meta_data.set_meta_data('slice_dir', slice_dirs)

    def get_slice_directions(self):
        """ Get the slice directions (slice_dir) of the dataset.
        """
        return self.meta_data.get_meta_data('slice_dir')

    def get_slice_dimension(self):
        """
        Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.get_core_directions()
        slice_dir = self.get_slice_directions()[0]
        return list(set(core_dirs + (slice_dir,))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """
        Return the dimension of the data in the plugin that has the specified
        axis label.

        :param label: the axis label, passed on to the data object
        :keyword bool contains: passed on to the data object lookup
        """
        label_dim = \
            self.data_obj.find_axis_label_dimension(label, contains=contains)
        plugin_dims = self.get_core_directions()
        # the first slice direction is only part of the plugin data when
        # more than one frame is processed at a time
        if self._get_frame_chunk() > 1:
            plugin_dims += (self.get_slice_directions()[0],)
        return list(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):
        """
        Reorder the slice directions.  The fastest changing slice direction
        will always be the first one stated in the pattern key ``slice_dir``.
        The input param is a tuple stating the desired order of slicing
        directions relative to the current order.
        """
        slice_dirs = self.get_slice_directions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specifed.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        # NOTE(review): get_current_pattern is not defined in this class as
        # shown here - confirm it is provided elsewhere.
        self.get_current_pattern()['slice_dir'] = new_slice_dirs

    def get_core_directions(self):
        """ Get the core data directions

        :returns: value associated with pattern key ``core_dir``
        :rtype: tuple
        """
        core_dir = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['core_dir']
        return core_dir

    def set_fixed_directions(self, dims, values):
        """ Fix a data direction to the index in values list.

        The fixed directions are given length 1 in the data object's shape
        and the plugin shape is recalculated.

        :param list(int) dims: Directions to fix
        :param list(int) value: Index of fixed directions
        :raises Exception: if a direction in ``dims`` is not a slice direction
        """
        slice_dirs = self.get_slice_directions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set_meta_data("fixed_directions", dims)
        self.meta_data.set_meta_data("fixed_directions_values", values)
        self.__set_slice_directions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        self.__set_shape()

    def _get_fixed_directions(self):
        """ Get the fixed data directions and their indices

        :returns: Fixed directions and their associated values (two empty
            lists if nothing has been fixed)
        :rtype: list(list(int), list(int))
        """
        fixed = []
        values = []
        if 'fixed_directions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get_meta_data("fixed_directions")
            values = self.meta_data.get_meta_data("fixed_directions_values")
        return [fixed, values]

    def _set_frame_chunk(self, nFrames):
        """ Set the number of frames to process at a time """
        self.meta_data.set_meta_data("nFrames", nFrames)

    def _get_frame_chunk(self):
        """ Get the number of frames to be processed at a time.

        If the number of frames is not divisible by the previewing ``chunk``
        value then amend the number of frames to gcd(frames, chunk)

        :returns: Number of frames to process
        :rtype: int
        """
        if self._plugin and self._plugin.chunk > 1:
            frame_chunk = self.meta_data.get_meta_data("nFrames")
            chunk = self.data_obj.get_preview().get_starts_stops_steps(
                key='chunks')[self.get_slice_directions()[0]]
            self._set_frame_chunk(gcd(frame_chunk, chunk))
        return self.meta_data.get_meta_data("nFrames")

    def plugin_data_setup(self, pattern_name, chunk):
        """ Setup the PluginData object.

        :param str pattern_name: A pattern name
        :param int chunk: Number of frames to process at a time
        """
        self.__set_pattern(pattern_name)
        chunks = \
            self.data_obj.get_preview().get_starts_stops_steps(key='chunks')
        # Flag the plugin when the chunk value of the first slice direction
        # is not a multiple of the requested number of frames.
        if self._plugin and (chunks[self.get_slice_directions()[0]] % chunk):
            self._plugin.chunk = True
        self._set_frame_chunk(chunk)
        self.__set_shape()
コード例 #5
0
class PluginData(object):
    """ The PluginData class contains plugin specific information about a Data
    object for the duration of a plugin.  An instance of the class is
    encapsulated inside the Data object during the plugin run
    """
    def __init__(self, data_obj, plugin=None):
        """ Create the plugin data and register it with ``data_obj``.

        :param data_obj: The Data object this instance is attached to
        :param plugin: The plugin using the data (optional)
        """
        self.data_obj = data_obj
        self.data_obj._set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = None
        self.pad_dict = None
        self.shape = None
        self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []
        self._plugin = plugin
        self.fixed_dims = False
        self.end_pad = False
        self.split = None

    def get_total_frames(self):
        """ Get the total number of frames to process.

        :returns: Number of frames
        :rtype: int
        """
        temp = 1
        slice_dir = \
            self.data_obj.get_data_patterns()[
                self.get_pattern_name()]["slice_dir"]
        for tslice in slice_dir:
            temp *= self.data_obj.get_shape()[tslice]
        return temp

    def __set_pattern(self, name):
        """ Set the pattern related information in the meta data dict.

        :param str name: A key of the data object's pattern dictionary
        """

        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set_meta_data("name", name)
        self.meta_data.set_meta_data("core_dir", pattern['core_dir'])
        self.__set_slice_directions()

    def get_pattern_name(self):
        """ Get the pattern name.

        :returns: the pattern name
        :rtype: str
        :raises Exception: if the stored name is None
        """
        name = self.meta_data.get_meta_data("name")
        if name is not None:
            return name
        else:
            raise Exception("The pattern name has not been set.")

    def get_pattern(self):
        """ Get the current pattern.

        :returns: dict of the pattern name against the pattern.
        :rtype: dict
        """
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def __set_shape(self):
        """ Set the shape of the plugin data processing chunk.

        The core shape is stored first; the first slice direction is only
        added to the full shape when more than one frame is processed at a
        time.
        """
        core_dir = self.get_core_directions()
        slice_dir = self.get_slice_directions()
        dirs = list(set(core_dir + (slice_dir[0], )))
        slice_idx = dirs.index(slice_dir[0])
        shape = []
        for core in set(core_dir):
            shape.append(self.data_obj.get_shape()[core])
        self.__set_core_shape(tuple(shape))
        if self._get_frame_chunk() > 1:
            shape.insert(slice_idx, self._get_frame_chunk())
        self.shape = tuple(shape)

    def get_shape(self):
        """ Get the shape of the plugin data to be processed each time.
        """
        return self.shape

    def __set_core_shape(self, shape):
        """ Set the core shape to hold only the shape of the core dimensions
        """
        self.core_shape = shape

    def get_core_shape(self):
        """ Get the shape of the core dimensions only.

        :returns: shape of core dimensions
        :rtype: tuple
        """
        return self.core_shape

    def __check_dimensions(self, indices, core_dir, slice_dir, nDims):
        """ Exit if the number of indices or dimensions is inconsistent.

        :param indices: one index per slice direction
        :param core_dir: the core directions
        :param slice_dir: the slice directions
        :param int nDims: the expected total number of dimensions
        """
        # Compare by value: ``is not`` tests object identity, which is not
        # reliable for equal integers.
        if len(indices) != len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")

        if (len(core_dir) + len(slice_dir)) != nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def __set_slice_directions(self):
        """ Set the slice directions in the pluginData meta data dictionary.
        """
        slice_dirs = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['slice_dir']
        self.meta_data.set_meta_data('slice_dir', slice_dirs)

    def get_slice_directions(self):
        """ Get the slice directions (slice_dir) of the dataset.
        """
        return self.meta_data.get_meta_data('slice_dir')

    def get_slice_dimension(self):
        """
        Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.get_core_directions()
        slice_dir = self.get_slice_directions()[0]
        return list(set(core_dirs + (slice_dir, ))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """
        Return the dimension of the data in the plugin that has the specified
        axis label.

        :param label: the axis label, passed on to the data object
        :keyword bool contains: passed on to the data object lookup
        """
        label_dim = \
            self.data_obj.find_axis_label_dimension(label, contains=contains)
        plugin_dims = self.get_core_directions()
        # the first slice direction is only part of the plugin data when
        # more than one frame is processed at a time
        if self._get_frame_chunk() > 1:
            plugin_dims += (self.get_slice_directions()[0], )
        return list(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):
        """
        Reorder the slice directions.  The fastest changing slice direction
        will always be the first one stated in the pattern key ``slice_dir``.
        The input param is a tuple stating the desired order of slicing
        directions relative to the current order.
        """
        slice_dirs = self.get_slice_directions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specifed.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        # NOTE(review): get_current_pattern is not defined in this class as
        # shown here - confirm it is provided elsewhere.
        self.get_current_pattern()['slice_dir'] = new_slice_dirs

    def get_core_directions(self):
        """ Get the core data directions

        :returns: value associated with pattern key ``core_dir``
        :rtype: tuple
        """
        core_dir = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['core_dir']
        return core_dir

    def set_fixed_directions(self, dims, values):
        """ Fix a data direction to the index in values list.

        The fixed directions are given length 1 in the data object's shape
        and the plugin shape is recalculated.

        :param list(int) dims: Directions to fix
        :param list(int) value: Index of fixed directions
        :raises Exception: if a direction in ``dims`` is not a slice direction
        """
        slice_dirs = self.get_slice_directions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set_meta_data("fixed_directions", dims)
        self.meta_data.set_meta_data("fixed_directions_values", values)
        self.__set_slice_directions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        self.__set_shape()

    def _get_fixed_directions(self):
        """ Get the fixed data directions and their indices

        :returns: Fixed directions and their associated values (two empty
            lists if nothing has been fixed)
        :rtype: list(list(int), list(int))
        """
        fixed = []
        values = []
        if 'fixed_directions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get_meta_data("fixed_directions")
            values = self.meta_data.get_meta_data("fixed_directions_values")
        return [fixed, values]

    def _get_data_slice_list(self, plist):
        """ Convert a plugin data slice list to a slice list for the whole
        dataset, i.e. add in any missing dimensions.

        :param plist: the plugin data slice list
        :returns: ``plist`` with ``slice(None)`` inserted for each extra
            dimension
        :rtype: tuple
        """
        nDims = len(self.get_shape())
        all_dims = self.get_core_directions() + self.get_slice_directions()
        # every direction beyond those in the plugin shape is missing from
        # plist and is taken in full
        extra_dims = all_dims[nDims:]
        dlist = list(plist)
        for i in extra_dims:
            dlist.insert(i, slice(None))
        return tuple(dlist)

    def _set_frame_chunk(self, nFrames):
        """ Set the number of frames to process at a time """
        self.meta_data.set_meta_data("nFrames", nFrames)

    def _get_frame_chunk(self):
        """ Get the number of frames to be processed at a time.

        If the number of frames is not divisible by the previewing ``chunk``
        value then amend the number of frames to gcd(frames, chunk)

        :returns: Number of frames to process
        :rtype: int
        """
        if self._plugin and self._plugin.chunk > 1:
            frame_chunk = self.meta_data.get_meta_data("nFrames")
            chunk = self.data_obj.get_preview().get_starts_stops_steps(
                key='chunks')[self.get_slice_directions()[0]]
            self._set_frame_chunk(gcd(frame_chunk, chunk))
        return self.meta_data.get_meta_data("nFrames")

    def plugin_data_setup(self, pattern_name, chunk, fixed=False, split=None):
        """ Setup the PluginData object.

        :param str pattern_name: A pattern name
        :param int chunk: Number of frames to process at a time
        :keyword bool fixed: setting fixed=True will ensure the plugin \
            receives the same sized data array each time (padding if necessary)
        :keyword split: stored unchanged on the instance

        """
        self.__set_pattern(pattern_name)
        chunks = \
            self.data_obj.get_preview().get_starts_stops_steps(key='chunks')
        # Flag the plugin when the chunk value of the first slice direction
        # is not a multiple of the requested number of frames.
        if self._plugin and (chunks[self.get_slice_directions()[0]] % chunk):
            self._plugin.chunk = True
        self._set_frame_chunk(chunk)
        self.__set_shape()
        self.fixed_dims = fixed
        self.split = split
コード例 #6
0
class PluginData(object):
    """ The PluginData class contains plugin specific information about a Data
    object for the duration of a plugin.  An instance of the class is
    encapsulated inside the Data object during the plugin run
    """
    def __init__(self, data_obj, plugin=None):
        """ Create the object and register it with ``data_obj``.

        :param data_obj: the Data object this instance is attached to (via
            ``data_obj._set_plugin_data``)
        :keyword plugin: the plugin associated with this object. Default: None
        """
        # links to the owning objects
        self.data_obj = data_obj
        self._preview = None
        self.data_obj._set_plugin_data(self)
        self._plugin = plugin
        self.meta_data = MetaData()

        # shape information
        self.shape = None
        self.core_shape = None
        self.pre_tuning_shape = None
        self.extra_dims = []
        self.fixed_dims = True
        self.no_squeeze = False
        self._increase_rank = 0

        # padding information
        self.padding = None
        self.pad_dict = None
        self.boundary_padding = None

        # frame handling
        self.multi_params = {}
        self.split = None
        self._frame_limit = None

    def _get_preview(self):
        return self._preview

    def get_total_frames(self):
        """ Get the total number of frames to process (all MPI processes).

        :returns: Number of frames
        :rtype: int
        """
        temp = 1
        slice_dir = \
            self.data_obj.get_data_patterns()[
                self.get_pattern_name()]["slice_dims"]
        for tslice in slice_dir:
            temp *= self.data_obj.get_shape()[tslice]
        return temp

    def __set_pattern(self, name, first_sdim=None):
        """ Set the pattern related information in the meta data dict.
        """
        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set("name", name)
        self.meta_data.set("core_dims", pattern['core_dims'])
        self.__set_slice_dimensions(first_sdim=first_sdim)

    def get_pattern_name(self):
        """ Get the pattern name.

        :returns: the pattern name
        :rtype: str
        """
        try:
            name = self.meta_data.get("name")
            return name
        except KeyError:
            raise Exception("The pattern name has not been set.")

    def get_pattern(self):
        """ Get the current pattern.

        :returns: dict of the pattern name against the pattern.
        :rtype: dict
        """
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def __set_shape(self):
        """ Set the shape of the plugin data processing chunk.
        """
        core_dir = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()
        dirs = list(set(core_dir + (slice_dir[0], )))
        slice_idx = dirs.index(slice_dir[0])
        dshape = self.data_obj.get_shape()
        shape = []
        for core in set(core_dir):
            shape.append(dshape[core])
        self.__set_core_shape(tuple(shape))

        mfp = self._get_max_frames_process()
        if mfp > 1 or self._get_no_squeeze():
            shape.insert(slice_idx, mfp)
        self.shape = tuple(shape)

    def _set_shape_transfer(self, slice_size):
        dshape = self.data_obj.get_shape()
        shape_before_tuning = self._get_shape_before_tuning()
        add = [1] * (len(dshape) - len(shape_before_tuning))
        slice_size = slice_size + add

        core_dir = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()
        shape = [None] * len(dshape)
        for dim in core_dir:
            shape[dim] = dshape[dim]
        i = 0
        for dim in slice_dir:
            shape[dim] = slice_size[i]
            i += 1
        return tuple(shape)

    def __get_slice_size(self, mft):
        """ Calculate the number of frames transfer in each dimension given
            mft. """
        dshape = list(self.data_obj.get_shape())

        if 'fixed_dimensions' in list(self.meta_data.get_dictionary().keys()):
            fixed_dims = self.meta_data.get('fixed_dimensions')
            for d in fixed_dims:
                dshape[d] = 1

        dshape = [dshape[i] for i in self.meta_data.get('slice_dims')]
        size_list = [1] * len(dshape)
        i = 0

        while (mft > 1 and i < len(size_list)):
            size_list[i] = min(dshape[i], mft)
            mft //= np.prod(size_list) if np.prod(size_list) > 1 else 1
            i += 1

        self.meta_data.set('size_list', size_list)
        return size_list

    def set_bytes_per_frame(self):
        """ Return the size of a single frame in bytes. """
        nBytes = self.data_obj.get_itemsize()
        dims = list(self.get_pattern().values())[0]['core_dims']
        frame_shape = [self.data_obj.get_shape()[d] for d in dims]
        b_per_f = np.prod(frame_shape) * nBytes
        return frame_shape, b_per_f

    def get_shape(self):
        """ Get the shape of the data (without padding) that is passed to the
        plugin process_frames method.
        """
        return self.shape

    def _set_padded_shape(self):
        pass

    def get_padded_shape(self):
        """ Get the shape of the data (with padding) that is passed to the
        plugin process_frames method.
        """
        return self.shape

    def get_shape_transfer(self):
        """ Get the shape of the plugin data to be transferred each time.
        """
        return self.meta_data.get('transfer_shape')

    def __set_core_shape(self, shape):
        """ Set the core shape to hold only the shape of the core dimensions
        """
        self.core_shape = shape

    def get_core_shape(self):
        """ Get the shape of the core dimensions only.

        :returns: shape of core dimensions
        :rtype: tuple
        """
        return self.core_shape

    def _set_shape_before_tuning(self, shape):
        """ Set the shape of the full dataset used during each run of the \
        plugin (i.e. ignore extra dimensions due to parameter tuning). """
        self.pre_tuning_shape = shape

    def _get_shape_before_tuning(self):
        """ Return the shape of the full dataset used during each run of the \
        plugin (i.e. ignore extra dimensions due to parameter tuning). """
        return self.pre_tuning_shape if self.pre_tuning_shape else\
            self.data_obj.get_shape()

    def __check_dimensions(self, indices, core_dir, slice_dir, nDims):
        if len(indices) is not len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")

        if (len(core_dir) + len(slice_dir)) is not nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def __set_slice_dimensions(self, first_sdim=None):
        """ Set the slice dimensions in the pluginData meta data dictionary.\
        Reorder pattern slice_dims to ensure first_sdim is at the front.
        """
        pattern = self.data_obj.get_data_patterns()[self.get_pattern_name()]
        slice_dims = pattern['slice_dims']

        if first_sdim:
            slice_dims = list(slice_dims)
            first_sdim = \
                self.data_obj.get_data_dimension_by_axis_label(first_sdim)
            slice_dims.insert(0, slice_dims.pop(slice_dims.index(first_sdim)))
            pattern['slice_dims'] = tuple(slice_dims)

        self.meta_data.set('slice_dims', tuple(slice_dims))

    def get_slice_dimension(self):
        """
        Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()[0]
        return list(set(core_dirs + (slice_dir, ))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """
        Return the dimension of the data in the plugin that has the specified
        axis label.
        """
        label_dim = self.data_obj.get_data_dimension_by_axis_label(
            label, contains=contains)
        plugin_dims = self.data_obj.get_core_dimensions()
        if self._get_max_frames_process() > 1 or self.max_frames == 'multiple':
            plugin_dims += (self.get_slice_dimension(), )
        return list(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):  # should this function be deleted?
        """
        Reorder the slice dimensions.  The fastest changing slice dimension
        will always be the first one stated in the pattern key ``slice_dir``.
        The input param is a tuple stating the desired order of slicing
        dimensions relative to the current order.
        """
        slice_dirs = self.data_obj.get_slice_dimensions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specifed.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        self.get_pattern()['slice_dir'] = new_slice_dirs

    def get_core_dimensions(self):
        """
        Return the position of the core dimensions in relation to the data
        handed to the plugin.
        """
        core_dims = self.data_obj.get_core_dimensions()
        first_slice_dim = (self.data_obj.get_slice_dimensions()[0], )
        plugin_dims = np.sort(core_dims + first_slice_dim)
        return np.searchsorted(plugin_dims, np.sort(core_dims))

    def set_fixed_dimensions(self, dims, values):
        """ Fix a data direction to the index in values list.

        :param list(int) dims: Directions to fix
        :param list(int) value: Index of fixed directions
        """
        slice_dirs = self.data_obj.get_slice_dimensions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set("fixed_dimensions", dims)
        self.meta_data.set("fixed_dimensions_values", values)
        self.__set_slice_dimensions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        #self.__set_shape()

    def _get_fixed_dimensions(self):
        """ Get the fixed data directions and their indices

        :returns: Fixed directions and their associated values
        :rtype: list(list(int), list(int))
        """
        fixed = []
        values = []
        if 'fixed_dimensions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get("fixed_dimensions")
            values = self.meta_data.get("fixed_dimensions_values")
        return [fixed, values]

    def _get_data_slice_list(self, plist):
        """ Convert a plugin data slice list to a slice list for the whole
        dataset, i.e. add in any missing dimensions.

        :param plist: slice list for the plugin data; it is copied, not
            modified
        :returns: the copy with ``slice(None)`` inserted for each extra
            dimension
        :rtype: tuple
        """
        nDims = len(self.get_shape())
        # NOTE(review): in this class get_core_dimensions() returns the
        # result of np.searchsorted and get_slice_dimension() returns a
        # single index, so ``+`` is presumably element-wise addition rather
        # than concatenation of two dimension tuples - confirm intent.
        all_dims = self.get_core_dimensions() + self.get_slice_dimension()
        extra_dims = all_dims[nDims:]
        dlist = list(plist)
        for i in extra_dims:
            dlist.insert(i, slice(None))
        return tuple(dlist)

    def _get_max_frames_process(self):
        """ Get the number of frames to process for each run of process_frames.

        If the number of frames is not divisible by the previewing ``chunk``
        value then amend the number of frames to gcd(frames, chunk)

        :returns: Number of frames to process
        :rtype: int
        """
        if self._plugin and self._plugin.chunk > 1:
            frame_chunk = self.meta_data.get("max_frames_process")
            chunk = self.data_obj.get_preview().get_starts_stops_steps(
                key='chunks')[self.get_slice_directions()[0]]
            self.meta_data.set('max_frames_process', gcd(frame_chunk, chunk))
        return self.meta_data.get("max_frames_process")

    def _get_max_frames_transfer(self):
        """ Get the number of frames to transfer for each run of
        process_frames. """
        return self.meta_data.get('max_frames_transfer')

    def _set_no_squeeze(self):
        self.no_squeeze = True

    def _get_no_squeeze(self):
        return self.no_squeeze

    def _set_rank_inc(self, n):
        """ Increase the rank of the array passed to the plugin by n.
        
        :param int n: Rank increment.
        """
        self._increase_rank = n

    def _get_rank_inc(self):
        """ Return the increased rank value
        
        :returns: Rank increment
        :rtype: int
        """
        return self._increase_rank

    def _set_meta_data(self):
        """ Populate the meta data with frame and process information.

        Entries written: ``shape``, ``sdir``, ``total_frames``,
        ``mpi_procs``, ``frames_per_process``, ``bytes_per_frame``,
        ``bytes_per_process`` and ``frame_shape``.
        """
        fixed, _ = self._get_fixed_dimensions()
        # slice dimensions that have not been fixed
        sdir = \
            [s for s in self.data_obj.get_slice_dimensions() if s not in fixed]
        shape = self.data_obj.get_shape()
        shape_before_tuning = self._get_shape_before_tuning()

        # if the data has more dimensions than the pre-tuning shape, use the
        # pre-tuning shape and drop that many trailing slice dimensions
        diff = len(shape) - len(shape_before_tuning)
        if diff:
            shape = shape_before_tuning
            sdir = sdir[:-diff]

        # an explicit 'fix_total_frames' entry overrides the calculated value
        if 'fix_total_frames' in list(self.meta_data.get_dictionary().keys()):
            frames = self.meta_data.get('fix_total_frames')
        else:
            frames = np.prod([shape[d] for d in sdir])

        # names of the direct base classes of the plugin's class
        base_names = [p.__name__ for p in self._plugin.__class__.__bases__]
        processes = self.data_obj.exp.meta_data.get('processes')

        # only entries containing 'GPU' are counted for a GpuPlugin subclass
        if 'GpuPlugin' in base_names:
            n_procs = len([n for n in processes if 'GPU' in n])
        else:
            n_procs = len(processes)

        # Fixing f_per_p to be just the first slice dimension for now due to
        # slow performance from HDF5 when not slicing multiple dimensions
        # concurrently
        #f_per_p = np.ceil(frames/n_procs)
        f_per_p = np.ceil(shape[sdir[0]] / n_procs)
        self.meta_data.set('shape', shape)
        self.meta_data.set('sdir', sdir)
        self.meta_data.set('total_frames', frames)
        self.meta_data.set('mpi_procs', n_procs)
        self.meta_data.set('frames_per_process', f_per_p)
        frame_shape, b_per_f = self.set_bytes_per_frame()
        self.meta_data.set('bytes_per_frame', b_per_f)
        self.meta_data.set('bytes_per_process', b_per_f * f_per_p)
        self.meta_data.set('frame_shape', frame_shape)

    def __log_max_frames(self, mft, mfp, check=True):
        logging.debug("Setting max frames transfer for plugin %s to %d" %
                      (self._plugin, mft))
        logging.debug("Setting max frames process for plugin %s to %d" %
                      (self._plugin, mfp))
        self.meta_data.set('max_frames_process', mfp)
        if check:
            self.__check_distribution(mft)
        # (((total_frames/mft)/mpi_procs) % 1)

    def __check_distribution(self, mft):
        warn_threshold = 0.85
        nprocs = self.meta_data.get('mpi_procs')
        nframes = self.meta_data.get('total_frames')
        temp = (((nframes / mft) / float(nprocs)) % 1)
        if temp != 0.0 and temp < warn_threshold:
            shape = self.meta_data.get('shape')
            sdir = self.meta_data.get('sdir')
            logging.warning(
                'UNEVEN FRAME DISTRIBUTION: shape %s, nframes %s ' +
                'sdir %s, nprocs %s', shape, nframes, sdir, nprocs)

    def _set_padding_dict(self):
        if self.padding and not isinstance(self.padding, Padding):
            self.pad_dict = copy.deepcopy(self.padding)
            self.padding = Padding(self)
            for key in list(self.pad_dict.keys()):
                getattr(self.padding, key)(self.pad_dict[key])

    def plugin_data_setup(self,
                          pattern,
                          nFrames,
                          split=None,
                          slice_axis=None,
                          getall=None):
        """ Setup the PluginData object.

        :param str pattern: A pattern name
        :param int nFrames: How many frames to process at a time.  Choose from
            'single', 'multiple', 'fixed_multiple' or an integer (an integer
            should only ever be passed in exceptional circumstances)
        :keyword str slice_axis: An axis label associated with the fastest
            changing (first) slice dimension.
        :keyword list[pattern, axis_label] getall: A list of two values.  If
        the requested pattern doesn't exist then use all of "axis_label"
        dimension of "pattern" as this is equivalent to one slice of the
        original pattern.
        """

        if pattern not in self.data_obj.get_data_patterns() and getall:
            pattern, nFrames = self.__set_getall_pattern(getall, nFrames)

        # slice_axis is first slice dimension
        self.__set_pattern(pattern, first_sdim=slice_axis)
        if isinstance(nFrames, list):
            nFrames, self._frame_limit = nFrames
        self.max_frames = nFrames
        self.split = split

    def __set_getall_pattern(self, getall, nFrames):
        """ Set framework changes required to get all of a pattern of lower
        rank.
        """
        pattern, slice_axis = getall
        dim = self.data_obj.get_data_dimension_by_axis_label(slice_axis)
        # ensure data remains the same shape when 'getall' dim has length 1
        self._set_no_squeeze()
        if nFrames == 'multiple' or (isinstance(nFrames, int) and nFrames > 1):
            self._set_rank_inc(1)
        nFrames = self.data_obj.get_shape()[dim]
        return pattern, nFrames

    def plugin_data_transfer_setup(self, copy=None, calc=None):
        """ Set up the plugin data transfer frame parameters.

        :keyword copy: another PluginData instance whose max frames transfer
            and max frames process values are reused. Default: None
        :keyword calc: an object whose ``max_frames_transfer`` and
            ``max_frames_process`` meta data entries are used to derive the
            values for this object. Default: None

        ``calc`` takes precedence over ``copy``; when neither is given the
        values come from ``_calculate_max_frames``.
        """
        chunks = \
            self.data_obj.get_preview().get_starts_stops_steps(key='chunks')

        if not copy and not calc:
            mft, mft_shape, mfp = self._calculate_max_frames()
        elif calc:
            max_mft = calc.meta_data.get('max_frames_transfer')
            max_mfp = calc.meta_data.get('max_frames_process')
            # number of process_frames runs per transfer in ``calc``
            max_nProc = int(np.ceil(max_mft / float(max_mfp)))
            nProc = max_nProc
            mfp = 1 if self.max_frames == 'single' else self.max_frames
            mft = nProc * mfp
            mft_shape = self._set_shape_transfer(self.__get_slice_size(mft))
        elif copy:
            mft = copy._get_max_frames_transfer()
            mft_shape = self._set_shape_transfer(self.__get_slice_size(mft))
            mfp = copy._get_max_frames_process()

        self.__set_max_frames(mft, mft_shape, mfp)

        # flag the plugin when the previewing chunk of the first slice
        # dimension is not a multiple of the frames transferred
        if self._plugin and mft \
                and (chunks[self.data_obj.get_slice_dimensions()[0]] % mft):
            self._plugin.chunk = True
        self.__set_shape()

    def _calculate_max_frames(self):
        nFrames = self.max_frames
        self.__perform_checks(nFrames)
        td = self.data_obj._get_transport_data()
        mft, size_list = td._calc_max_frames_transfer(nFrames)
        self.meta_data.set('size_list', size_list)
        mfp = td._calc_max_frames_process(nFrames)
        if mft:
            mft_shape = self._set_shape_transfer(list(size_list))
        return mft, mft_shape, mfp

    def __set_max_frames(self, mft, mft_shape, mfp):
        self.meta_data.set('max_frames_transfer', mft)
        self.meta_data.set('transfer_shape', mft_shape)
        self.meta_data.set('max_frames_process', mfp)
        self.__log_max_frames(mft, mfp)
        # Retain the shape if the first slice dimension has length 1
        if mfp == 1 and self.max_frames == 'multiple':
            self._set_no_squeeze()

    def _get_plugin_data_size_params(self):
        nBytes = self.data_obj.get_itemsize()
        frame_shape = self.meta_data.get('frame_shape')
        total_frames = self.meta_data.get('total_frames')
        tbytes = nBytes * np.prod(frame_shape) * total_frames

        params = {
            'nBytes': nBytes,
            'frame_shape': frame_shape,
            'total_frames': total_frames,
            'transfer_bytes': tbytes
        }
        return params

    def __perform_checks(self, nFrames):
        options = ['single', 'multiple']
        if not np.issubdtype(type(nFrames),
                             np.int64) and nFrames not in options:
            e_str = (
                "The value of nFrames is not recognised.  Please choose " +
                "from 'single' and 'multiple' (or an integer in exceptional " +
                "circumstances).")
            raise Exception(e_str)

    def get_frame_limit(self):
        return self._frame_limit

    def get_current_frame_idx(self):
        """ Returns the index of the frames currently being processed.
        """
        global_index = self._plugin.get_global_frame_index()
        count = self._plugin.get_process_frames_counter()
        mfp = self.meta_data.get('max_frames_process')
        start = global_index[count] * mfp
        index = np.arange(start, start + mfp)
        nFrames = self.get_total_frames()
        index[index >= nFrames] = nFrames - 1
        return index
コード例 #7
0
ファイル: data.py プロジェクト: DiamondLightSource/Savu
class Data(DataCreate):
    """The Data class dynamically inherits from transport specific data class
    and holds the data array, along with associated information.
    """

    def __init__(self, name, exp):
        """ Create a Data object.

        :param name: passed to the parent constructor and stored in the
            ``name`` entry of data_info
        :param exp: stored unchanged in the ``exp`` attribute
        """
        super(Data, self).__init__(name)
        # meta data containers
        self.meta_data = MetaData()
        self.pattern_list = self.__get_available_pattern_list()
        self.data_info = MetaData()
        self.__initialise_data_info(name)
        self._preview = Preview(self)
        self.exp = exp

        # everything below is populated later
        self.group_name = self.group = None
        self._plugin_data_obj = None
        self.raw = None
        self.backing_file = None
        self.data = None
        self.next_shape = self.orig_shape = None
        self.previous_pattern = None
        self.transport_data = None

    def __initialise_data_info(self, name):
        """ Initialise entries in the data_info meta data.

        :param name: value for the ``name`` entry; ``data_patterns`` starts
            as an empty dict, ``shape`` and ``nDims`` as None
        """
        initial = (('name', name), ('data_patterns', {}),
                   ('shape', None), ('nDims', None))
        for key, value in initial:
            self.data_info.set(key, value)

    def _set_plugin_data(self, plugin_data_obj):
        """ Encapsulate a PluginData object.

        :param plugin_data_obj: stored in ``_plugin_data_obj``
        """
        self._plugin_data_obj = plugin_data_obj

    def _clear_plugin_data(self):
        """ Drop the encapsulated PluginData object by resetting the
        attribute to None.
        """
        self._plugin_data_obj = None

    def _get_plugin_data(self):
        """ Get encapsulated PluginData object.

        :raises Exception: if no PluginData object is currently set
        """
        if self._plugin_data_obj is None:
            raise Exception("There is no PluginData object associated with "
                            "the Data object.")
        return self._plugin_data_obj

    def get_preview(self):
        """ Return the Preview instance created for this data object. """
        return self._preview

    def _set_transport_data(self, transport):
        """ Import the data transport mechanism.

        The class is imported from the module path
        ``savu.data.transport_data.<transport>_transport_data``, instantiated
        with this object and stored in ``transport_data``; the transport name
        is recorded in data_info.  Nothing is returned.

        :param str transport: name of the transport mechanism
        """
        module_path = "savu.data.transport_data." + transport + \
                      "_transport_data"
        transport_cls = cu.import_class(module_path)
        self.transport_data = transport_cls(self)
        self.data_info.set('transport', transport)

    def _get_transport_data(self):
        """ Return the value held in the ``transport_data`` attribute. """
        return self.transport_data

    def __deepcopy__(self, memo):
        """ Copy the data object.

        A new Data object with the same name and experiment is created and
        handed, together with this object, to the copy helper.

        :param memo: unused
        """
        new_obj = Data(self.data_info.get('name'), self.exp)
        return dsu._deepcopy_data_object(self, new_obj)

    def get_data_patterns(self):
        """ Return the ``data_patterns`` entry of data_info.

        :returns: A dictionary of associated patterns.
        :rtype: dict
        """
        return self.data_info.get('data_patterns')

    def _set_previous_pattern(self, pattern):
        """ Store ``pattern`` in the ``previous_pattern`` attribute. """
        self.previous_pattern = pattern

    def get_previous_pattern(self):
        """ Return the value held in ``previous_pattern``. """
        return self.previous_pattern

    def set_shape(self, shape):
        """ Store the dataset shape in data_info, then check it against the
        stored number of dimensions.

        :param shape: the dataset shape
        """
        self.data_info.set('shape', shape)
        self.__check_dims()

    def set_original_shape(self, shape):
        """ Set the original data shape before previewing.

        The value is kept in ``orig_shape`` and also set as the current
        dataset shape.
        """
        self.orig_shape = shape
        self.set_shape(shape)

    def get_shape(self):
        """ Return the ``shape`` entry of data_info.

        :returns: data shape
        :rtype: tuple
        """
        return self.data_info.get('shape')

    def __check_dims(self):
        """ Check the ``shape`` and ``nDims`` entries in the data_info
        meta_data dictionary are equal.

        Nothing is checked while ``nDims`` is falsy.

        :raises Exception: if the length of ``shape`` differs from ``nDims``
        """
        nDims = self.data_info.get("nDims")
        shape = self.data_info.get('shape')
        if nDims and len(shape) != nDims:
            raise Exception("The number of axis labels, %d, does not "
                            "coincide with the number of data "
                            "dimensions %d." % (nDims, len(shape)))

    def _set_name(self, name):
        """ Store ``name`` in the ``name`` entry of data_info. """
        self.data_info.set('name', name)

    def get_name(self):
        """ Return the ``name`` entry of data_info.

        :returns: the name associated with the dataset
        :rtype: str
        """
        return self.data_info.get('name')

    def __get_available_pattern_list(self):
        """ Get a list of ALL pattern names that are currently allowed in the
        framework, as reported by ``dsu.get_available_pattern_types``.
        """
        return dsu.get_available_pattern_types()

    def add_pattern(self, dtype, **kwargs):
        """ Add a pattern.

        :params str dtype: The *type* of pattern to add, which can be anything
            from the :const:`savu.data.data_structures.utils.pattern_list`
            :const:`pattern_list`
            :data:`savu.data.data_structures.utils.pattern_list`
            :data:`pattern_list`:
        :keyword tuple core_dims: Dimension indices of core dimensions
        :keyword tuple slice_dims: Dimension indices of slice dimensions
        :raises Exception: if ``dtype`` is not in ``pattern_list`` or if the
            number of dimensions given differs from the stored ``nDims``
        """
        if dtype not in self.pattern_list:
            # interpolate the values into the message; previously they were
            # passed to Exception as extra arguments and never formatted
            raise Exception("The data pattern '%s'does not exist. Please "
                            "choose from the following list: \n'%s'"
                            % (dtype, str(self.pattern_list)))

        nDims = 0
        for args in kwargs:
            nDims += len(kwargs[args])
            self.data_info.set(['data_patterns', dtype, args],
                               kwargs[args])

        self.__convert_pattern_directions(dtype)
        if self.get_shape():
            diff = len(self.get_shape()) - nDims
            if diff:
                pattern = {dtype: self.get_data_patterns()[dtype]}
                self._set_data_patterns(pattern)
                nDims += diff
        try:
            if nDims != self.data_info.get("nDims"):
                actualDims = self.data_info.get('nDims')
                err_msg = ("The pattern %s has an incorrect number of "
                           "dimensions: %d required but %d specified."
                           % (dtype, actualDims, nDims))
                raise Exception(err_msg)
        except KeyError:
            # no nDims entry yet: record the number found here
            self.data_info.set('nDims', nDims)

    def add_volume_patterns(self, x, y, z):
        """ Adds 3D volume patterns ``VOLUME_YZ``, ``VOLUME_XZ`` and
        ``VOLUME_XY`` (in that order).

        :params int x: dimension to be associated with x-axis
        :params int y: dimension to be associated with y-axis
        :params int z: dimension to be associated with z-axis
        """
        # (pattern name, (first core dim, second core dim, slice dim))
        volumes = (("VOLUME_YZ", (y, z, x)),
                   ("VOLUME_XZ", (x, z, y)),
                   ("VOLUME_XY", (x, y, z)))
        for name, (dim1, dim2, sdir) in volumes:
            self.add_pattern(
                name, **self.__get_dirs_for_volume(dim1, dim2, sdir))

    def __get_dirs_for_volume(self, dim1, dim2, sdir):
        """ Calculate core_dims and slice_dims for a 3D volume pattern.

        :returns: ``core_dims`` is ``(dim1, dim2)``; ``slice_dims`` starts
            with ``sdir`` followed by every remaining data dimension
        :rtype: dict
        """
        named = [dim1, dim2, sdir]
        # *** need to add this for other patterns
        others = [d for d in range(len(self.get_shape())) if d not in named]
        return {'core_dims': (dim1, dim2),
                'slice_dims': tuple([sdir] + others)}

    def set_axis_labels(self, *args):
        """ Set the axis labels associated with each data dimension.

        ``nDims`` is set to the number of arguments given.  Dict arguments
        are stored as they are; other arguments are split on ``.`` and stored
        as ``{name: unit}``.  Arguments that cannot be split that way are
        skipped.

        :arg str: Each arg should be of the form ``name.unit``. If ``name`` is\
        a data_obj.meta_data entry, it will be output to the final .nxs file.
        """
        self.data_info.set('nDims', len(args))
        axis_labels = []
        for arg in args:
            if isinstance(arg, dict):
                axis_labels.append(arg)
                continue
            try:
                axis = arg.split('.')
                axis_labels.append({axis[0]: axis[1]})
            except (AttributeError, IndexError):
                # Best effort, as before: skip values that are not strings
                # (no ``split``) or that contain no '.' separator.  Only
                # these two errors are swallowed, so interrupts and
                # unrelated failures are no longer hidden.
                # data arrives here, but that may be an error
                pass
        self.data_info.set('axis_labels', axis_labels)

    def get_axis_labels(self):
        """ Return the ``axis_labels`` entry of data_info.

        :returns: Axis labels
        :rtype: list(dict)
        """
        return self.data_info.get('axis_labels')

    def get_data_dimension_by_axis_label(self, name, contains=False):
        """ Get the dimension of the data associated with a particular
        axis_label.  The first matching dimension is returned.

        :param str name: The name of the axis_label
        :keyword bool contains: Set this flag to true if the name is only part
            of the axis_label name
        :returns: The associated axis number
        :rtype: int
        :raises Exception: if no axis label matches
        """
        for dim, labels in enumerate(self.data_info.get('axis_labels')):
            if contains is True:
                if any(name in key for key in labels.keys()):
                    return dim
            elif name in labels.keys():
                return dim
        raise Exception("Cannot find the specifed axis label.")

    def _finalise_patterns(self):
        """ Adds a main axis (fastest changing) to SINOGRAM and PROJECTION
        patterns.

        Nothing is done unless both patterns exist and the data has more
        than two dimensions.
        """
        found = self.__check_pattern('SINOGRAM') + \
            self.__check_pattern('PROJECTION')

        # compare by value: ``is`` tests identity and is not a valid way to
        # compare integers
        if found == 2 and len(self.get_shape()) > 2:
            self.__set_main_axis('SINOGRAM')
            self.__set_main_axis('PROJECTION')

    def __check_pattern(self, pattern_name):
        """ Check if a pattern exists.

        :returns: 1 if ``pattern_name`` is a key of the data patterns, else 0
        :rtype: int
        """
        return 1 if pattern_name in self.get_data_patterns() else 0

    def __convert_pattern_directions(self, dtype):
        """ Replace negative indices in pattern kwargs.

        Any ``main_dir`` entry is removed first; the number of dimensions is
        taken as the total length of the remaining entries.
        """
        pattern = self.get_data_patterns()[dtype]
        pattern.pop('main_dir', None)

        nDims = sum(len(dims) for dims in pattern.values())
        for key in pattern:
            pattern[key] = self._non_negative_directions(pattern[key], nDims)

    def _non_negative_directions(self, ddirs, nDims):
        """ Replace negative indexing values with positive counterparts.

        Each negative entry ``d`` becomes ``nDims + d``; other entries are
        kept as they are.

        :params tuple(int) ddirs: data dimension indices
        :params int nDims: The number of data dimensions
        :returns: non-negative data dimension indices
        :rtype: tuple(int)
        """
        return tuple(nDims + d if d < 0 else d for d in ddirs)

    def __set_main_axis(self, pname):
        """ Set the ``main_dir`` pattern kwarg to the fastest changing
        dimension.

        The value is a dimension shared by the core dimensions of the other
        pattern (PROJECTION for SINOGRAM and vice versa) and the slice
        dimensions of ``pname``.

        :param str pname: 'SINOGRAM' or 'PROJECTION'
        """
        patterns = self.get_data_patterns()
        # compare by value: ``is`` only matched when both strings happened
        # to be the same interned object
        n1 = 'PROJECTION' if pname == 'SINOGRAM' else 'SINOGRAM'
        d1 = patterns[n1]['core_dims']
        d2 = patterns[pname]['slice_dims']
        tdir = set(d1).intersection(set(d2))

        # this is required when a single sinogram exists in the mm case, and a
        # dimension is added via parameter tuning.
        if not tdir:
            tdir = [d2[0]]

        self.data_info.set(['data_patterns', pname, 'main_dir'], list(tdir)[0])

    def get_axis_label_keys(self):
        """ Get axis_label names: the keys of every axis label dict, in
        order.

        :returns: A list containing associated axis names for each dimension
        :rtype: list(str)
        """
        axis_labels = self.data_info.get('axis_labels')
        return [key for labels in axis_labels for key in labels.keys()]

    def amend_axis_label_values(self, slice_list):
        """ Amend all axis label values based on the slice_list parameter.\
        This is required if the data is reduced.

        For each dimension whose axis label name is also a meta_data entry,
        the stored values are replaced by the values sliced along their first
        dimension.

        :param slice_list: one slice per data dimension
        """
        axis_labels = self.get_axis_labels()
        for i, dim_slice in enumerate(slice_list):
            # dict views cannot be indexed on Python 3
            label = list(axis_labels[i].keys())[0]
            if label in self.meta_data.get_dictionary().keys():
                values = self.meta_data.get(label)
                preview_sl = [slice(None)]*len(values.shape)
                preview_sl[0] = dim_slice
                # a multi-dimensional index must be a tuple; a list of
                # slices is not accepted by current NumPy versions
                self.meta_data.set(label, values[tuple(preview_sl)])

    def get_core_dimensions(self):
        """ Get the core data dimensions associated with the current pattern.

        :returns: value associated with pattern key ``core_dims``
        :rtype: tuple
        :raises Exception: if no PluginData object is associated with this
            object
        """
        # list() makes the lookup work on Python 3 as well, where dict
        # views cannot be indexed
        pattern = list(self._get_plugin_data().get_pattern().values())[0]
        return pattern['core_dims']

    def get_slice_dimensions(self):
        """ Get the slice data dimensions associated with the current pattern.

        :returns: value associated with pattern key ``slice_dims``
        :rtype: tuple
        :raises Exception: if no PluginData object is associated with this
            object
        """
        # list() makes the lookup work on Python 3 as well, where dict
        # views cannot be indexed
        pattern = list(self._get_plugin_data().get_pattern().values())[0]
        return pattern['slice_dims']

    def get_itemsize(self):
        """ Returns bytes per entry.

        If no dtype is currently set, ``set_dtype(None)`` is called first.

        :returns: the ``itemsize`` attribute of the value returned by
            ``get_dtype``
        """
        dtype = self.get_dtype()
        if not dtype:
            # presumably this installs a default dtype - get_dtype and
            # set_dtype are defined elsewhere; TODO confirm
            self.set_dtype(None)
            dtype = self.get_dtype()
        return self.get_dtype().itemsize
コード例 #8
0
ファイル: data.py プロジェクト: mattronix/Savu
class Data(DataCreate):
    """The Data class dynamically inherits from transport specific data class
    and holds the data array, along with associated information.
    """

    def __init__(self, name, exp):
        """ Create an empty data object.

        :param str name: Name stored in the ``data_info`` meta data
        :param exp: Object stored as ``self.exp`` (and handed to copies made
            by ``__deepcopy__``)
        """
        super(Data, self).__init__(name)
        self.meta_data = MetaData()
        self.pattern_list = self.__get_available_pattern_list()
        self.data_info = MetaData()
        self.__initialise_data_info(name)
        self._preview = Preview(self)
        self.exp = exp
        self.group_name = None
        self.group = None
        self._plugin_data_obj = None
        self.raw = None
        self.backing_file = None
        self.data = None
        self.next_shape = None
        self.orig_shape = None
        self.previous_pattern = None
        self.transport_data = None

    def __initialise_data_info(self, name):
        """ Initialise entries in the data_info meta data.

        :param str name: Value stored under the ``name`` key
        """
        self.data_info.set('name', name)
        self.data_info.set('data_patterns', {})
        self.data_info.set('shape', None)
        self.data_info.set('nDims', None)

    def _set_plugin_data(self, plugin_data_obj):
        """ Encapsulate a PluginData object.
        """
        self._plugin_data_obj = plugin_data_obj

    def _clear_plugin_data(self):
        """ Set encapsulated PluginData object to None.
        """
        self._plugin_data_obj = None

    def _get_plugin_data(self):
        """ Get encapsulated PluginData object.

        :raises Exception: if no PluginData object is currently attached
        """
        if self._plugin_data_obj is None:
            raise Exception("There is no PluginData object associated with "
                            "the Data object.")
        return self._plugin_data_obj

    def get_preview(self):
        """ Get the Preview instance associated with the data object
        """
        return self._preview

    def _set_transport_data(self, transport):
        """ Import the data transport mechanism and attach an instance of it.

        The class is imported from
        ``savu.data.transport_data.<transport>_transport_data`` and
        instantiated with this data object; the instance is stored in
        ``self.transport_data`` (nothing is returned).

        :param str transport: Prefix of the transport data module name
        """
        transport_data = "savu.data.transport_data." + transport + \
                         "_transport_data"
        transport_data = cu.import_class(transport_data)
        self.transport_data = transport_data(self)

    def _get_transport_data(self):
        """ Get the transport data instance (None until it has been set). """
        return self.transport_data

    def __deepcopy__(self, memo):
        """ Copy the data object.

        A new ``Data`` with the same name and ``exp`` is created and passed,
        with this object, to ``dsu._deepcopy_data_object``.
        """
        name = self.data_info.get('name')
        return dsu._deepcopy_data_object(self, Data(name, self.exp))

    def get_data_patterns(self):
        """ Get data patterns associated with this data object.

        :returns: A dictionary of associated patterns.
        :rtype: dict
        """
        return self.data_info.get('data_patterns')

    def _set_previous_pattern(self, pattern):
        """ Store ``pattern`` as the previous pattern. """
        self.previous_pattern = pattern

    def get_previous_pattern(self):
        """ Get the stored previous pattern (None if never set). """
        return self.previous_pattern

    def set_shape(self, shape):
        """ Set the dataset shape and check it against ``nDims``.

        :raises Exception: if ``nDims`` is set and differs from
            ``len(shape)``
        """
        self.data_info.set('shape', shape)
        self.__check_dims()

    def set_original_shape(self, shape):
        """ Set the original data shape before previewing; this also sets
        the current shape.
        """
        self.orig_shape = shape
        self.set_shape(shape)

    def get_shape(self):
        """ Get the dataset shape

        :returns: data shape
        :rtype: tuple
        """
        shape = self.data_info.get('shape')
        return shape

    def __check_dims(self):
        """ Check the ``shape`` and ``nDims`` entries in the data_info
        meta_data dictionary are equal.

        Nothing is checked while ``nDims`` is falsy (i.e. not yet set).

        :raises Exception: if the number of dimensions differs
        """
        nDims = self.data_info.get("nDims")
        shape = self.data_info.get('shape')
        if nDims:
            if len(shape) != nDims:
                error_msg = ("The number of axis labels, %d, does not "
                             "coincide with the number of data "
                             "dimensions %d." % (nDims, len(shape)))
                raise Exception(error_msg)

    def _set_name(self, name):
        """ Store ``name`` in the data_info meta data. """
        self.data_info.set('name', name)

    def get_name(self):
        """ Get data name.

        :returns: the name associated with the dataset
        :rtype: str
        """
        return self.data_info.get('name')

    def __get_available_pattern_list(self):
        """ Get a list of ALL pattern names that are currently allowed in the
        framework.
        """
        pattern_list = dsu.get_available_pattern_types()
        return pattern_list

    def add_pattern(self, dtype, **kwargs):
        """ Add a pattern.

        :params str dtype: The *type* of pattern to add, which can be anything
            from the :const:`savu.data.data_structures.utils.pattern_list`
            :const:`pattern_list`
            :data:`savu.data.data_structures.utils.pattern_list`
            :data:`pattern_list`:
        :keyword tuple core_dims: Dimension indices of core dimensions
        :keyword tuple slice_dims: Dimension indices of slice dimensions
        :raises Exception: if ``dtype`` is not an available pattern, or the
            pattern's dimension count disagrees with the stored ``nDims``
        """
        if dtype not in self.pattern_list:
            # The message must be formatted with %; passing the values as
            # extra Exception arguments leaves the placeholders unfilled.
            raise Exception(
                "The data pattern '%s' does not exist. Please "
                "choose from the following list: \n'%s'"
                % (dtype, str(self.pattern_list)))

        nDims = 0
        for args in kwargs:
            nDims += len(kwargs[args])
            self.data_info.set(['data_patterns', dtype, args],
                               kwargs[args])

        self.__convert_pattern_directions(dtype)
        if self.get_shape():
            diff = len(self.get_shape()) - nDims
            if diff:
                pattern = {dtype: self.get_data_patterns()[dtype]}
                self._set_data_patterns(pattern)
                nDims += diff

        try:
            actualDims = self.data_info.get('nDims')
        except KeyError:
            actualDims = None

        # 'nDims' is initialised to None, so "not set yet" has to be
        # recognised by value as well as by a missing key; otherwise the
        # error message below would be formatted with None.
        if actualDims is None:
            self.data_info.set('nDims', nDims)
        elif nDims != actualDims:
            err_msg = ("The pattern %s has an incorrect number of "
                       "dimensions: %d required but %d specified." %
                       (dtype, actualDims, nDims))
            raise Exception(err_msg)

    def add_volume_patterns(self, x, y, z):
        """ Adds 3D volume patterns

        :params int x: dimension to be associated with x-axis
        :params int y: dimension to be associated with y-axis
        :params int z: dimension to be associated with z-axis
        """
        self.add_pattern("VOLUME_YZ", **self.__get_dirs_for_volume(y, z, x))
        self.add_pattern("VOLUME_XZ", **self.__get_dirs_for_volume(x, z, y))
        self.add_pattern("VOLUME_XY", **self.__get_dirs_for_volume(x, y, z))

    def __get_dirs_for_volume(self, dim1, dim2, sdir):
        """ Calculate core_dims and slice_dims for a 3D volume pattern.

        :param int dim1: first core dimension
        :param int dim2: second core dimension
        :param int sdir: first slice dimension
        :returns: dict with keys ``core_dims`` and ``slice_dims``; the slice
            dimensions are ``sdir`` followed by every remaining dimension
        :rtype: dict
        """
        all_dims = range(len(self.get_shape()))
        vol_dict = {}
        vol_dict['core_dims'] = (dim1, dim2)
        slice_dir = [sdir]
        # *** need to add this for other patterns
        for ddir in all_dims:
            if ddir not in [dim1, dim2, sdir]:
                slice_dir.append(ddir)
        vol_dict['slice_dims'] = tuple(slice_dir)
        return vol_dict

    def set_axis_labels(self, *args):
        """ Set the axis labels associated with each data dimension.

        ``nDims`` is set to the number of arguments given.

        :arg str: Each arg should be of the form ``name.unit``. If ``name`` is\
        a data_obj.meta_data entry, it will be output to the final .nxs file.
        A dict argument is stored unchanged.
        """
        import logging

        self.data_info.set('nDims', len(args))
        axis_labels = []
        for arg in args:
            if isinstance(arg, dict):
                axis_labels.append(arg)
                continue
            try:
                axis = arg.split('.')
                axis_labels.append({axis[0]: axis[1]})
            except (AttributeError, IndexError, TypeError):
                # Malformed labels are still skipped (best effort), but only
                # the failures a bad label can cause are caught, and the skip
                # is reported instead of being silent.
                logging.warning("Ignoring axis label %r: expected a dict or "
                                "a string of the form 'name.unit'.", arg)
        self.data_info.set('axis_labels', axis_labels)

    def get_axis_labels(self):
        """ Get axis labels.

        :returns: Axis labels
        :rtype: list(dict)
        """
        return self.data_info.get('axis_labels')

    def get_data_dimension_by_axis_label(self, name, contains=False):
        """ Get the dimension of the data associated with a particular
        axis_label.

        :param str name: The name of the axis_label
        :keyword bool contains: Set this flag to true if the name is only part
            of the axis_label name
        :returns: The associated axis number
        :rtype: int
        :raises Exception: if no axis label matches
        """
        axis_labels = self.data_info.get('axis_labels')
        for i in range(len(axis_labels)):
            if contains is True:
                for names in axis_labels[i].keys():
                    if name in names:
                        return i
            else:
                if name in axis_labels[i].keys():
                    return i
        raise Exception("Cannot find the specifed axis label.")

    def _finalise_patterns(self):
        """ Adds a main axis (fastest changing) to SINOGRAM and PROJECTION
        patterns.

        Nothing is done unless both patterns exist and the data has more
        than two dimensions.
        """
        check = 0
        check += self.__check_pattern('SINOGRAM')
        check += self.__check_pattern('PROJECTION')

        # Compare by value: identity (`is`) on integers only works by
        # accident of interpreter caching.
        if check == 2 and len(self.get_shape()) > 2:
            self.__set_main_axis('SINOGRAM')
            self.__set_main_axis('PROJECTION')

    def __check_pattern(self, pattern_name):
        """ Check if a pattern exists.

        :returns: 1 if ``pattern_name`` is a key of the data patterns, else 0
        :rtype: int
        """
        patterns = self.get_data_patterns()
        try:
            patterns[pattern_name]
        except KeyError:
            return 0
        return 1

    def __convert_pattern_directions(self, dtype):
        """ Replace negative indices in pattern kwargs.

        Any existing ``main_dir`` entry of the pattern is removed first.
        """
        pattern = self.get_data_patterns()[dtype]
        if 'main_dir' in pattern.keys():
            del pattern['main_dir']

        nDims = sum([len(i) for i in pattern.values()])
        for p in pattern:
            ddirs = pattern[p]
            pattern[p] = self._non_negative_directions(ddirs, nDims)

    def _non_negative_directions(self, ddirs, nDims):
        """ Replace negative indexing values with positive counterparts.

        :params tuple(int) ddirs: data dimension indices
        :params int nDims: The number of data dimensions
        :returns: non-negative data dimension indices
        :rtype: tuple(int)
        """
        index = [i for i in range(len(ddirs)) if ddirs[i] < 0]
        list_ddirs = list(ddirs)
        for i in index:
            list_ddirs[i] = nDims + ddirs[i]
        return tuple(list_ddirs)

    def __set_main_axis(self, pname):
        """ Set the ``main_dir`` pattern kwarg to the fastest changing
        dimension

        :param str pname: ``'SINOGRAM'`` or ``'PROJECTION'``
        """
        patterns = self.get_data_patterns()
        # String equality, not identity: `is` on strings depends on
        # interning.
        n1 = 'PROJECTION' if pname == 'SINOGRAM' else 'SINOGRAM'
        d1 = patterns[n1]['core_dims']
        d2 = patterns[pname]['slice_dims']
        tdir = set(d1).intersection(set(d2))

        # this is required when a single sinogram exists in the mm case, and a
        # dimension is added via parameter tuning.
        if not tdir:
            tdir = [d2[0]]

        self.data_info.set(['data_patterns', pname, 'main_dir'], list(tdir)[0])

    def get_axis_label_keys(self):
        """ Get axis_label names

        :returns: A list containing associated axis names for each dimension
        :rtype: list(str)
        """
        axis_labels = self.data_info.get('axis_labels')
        axis_label_keys = []
        for labels in axis_labels:
            for key in labels.keys():
                axis_label_keys.append(key)
        return axis_label_keys

    def amend_axis_label_values(self, slice_list):
        """ Amend all axis label values based on the slice_list parameter.\
        This is required if the data is reduced.

        For each entry ``i`` of ``slice_list`` the name of axis label ``i``
        is looked up; if ``meta_data`` holds an entry under that name, the
        entry is replaced by itself sliced along its first axis.

        :param list slice_list: One slice object per data dimension
        """
        axis_labels = self.get_axis_labels()
        for i in range(len(slice_list)):
            # list(d)[0] rather than d.keys()[0]: dict views cannot be
            # indexed on Python 3.
            label = list(axis_labels[i])[0]
            if label in self.meta_data.get_dictionary().keys():
                values = self.meta_data.get(label)
                preview_sl = [slice(None)] * len(values.shape)
                preview_sl[0] = slice_list[i]
                # A multi-dimensional index must be a tuple; NumPy no longer
                # accepts a list of slices.
                self.meta_data.set(label, values[tuple(preview_sl)])

    def get_core_dimensions(self):
        """ Get the core data dimensions associated with the current pattern.

        :returns: value associated with pattern key ``core_dims``
        :rtype: tuple
        """
        pattern = self._get_plugin_data().get_pattern()
        # list() keeps the first-value lookup valid on Python 3 dict views.
        return list(pattern.values())[0]['core_dims']

    def get_slice_dimensions(self):
        """ Get the slice data dimensions associated with the current pattern.

        :returns: value associated with pattern key ``slice_dims``
        :rtype: tuple
        """
        pattern = self._get_plugin_data().get_pattern()
        # list() keeps the first-value lookup valid on Python 3 dict views.
        return list(pattern.values())[0]['slice_dims']

    def _add_raw_data_obj(self, data_obj):
        """ Give ``data_obj`` a ``NoImageKey`` data type built from its raw
        data.

        The dimension labelled ``rotation_angle`` on *this* object is passed
        to ``NoImageKey``.  If the raw data is an ``ImageKey`` it is
        additionally converted and its image key is copied across.

        :param data_obj: data object whose ``raw`` attribute is an
            ``ImageKey`` or ``NoImageKey`` instance
        :raises Exception: if ``data_obj.raw`` is of neither type
        """
        from savu.data.data_structures.data_types.data_plus_darks_and_flats\
            import ImageKey, NoImageKey

        if not isinstance(data_obj.raw, (ImageKey, NoImageKey)):
            raise Exception('Raw data type not recognised.')

        proj_dim = self.get_data_dimension_by_axis_label('rotation_angle')
        data_obj.data = NoImageKey(data_obj, None, proj_dim)
        data_obj.data._set_dark_and_flat()

        if isinstance(data_obj.raw, ImageKey):
            data_obj.raw._convert_to_noimagekey(data_obj.data)
            data_obj.data._set_fake_key(data_obj.raw.get_image_key())
コード例 #9
0
class PluginData(object):
    """ The PluginData class contains plugin specific information about a Data
    object for the duration of a plugin.  An instance of the class is
    encapsulated inside the Data object during the plugin run
    """

    def __init__(self, data_obj, plugin=None):
        """ Create the object and register it with ``data_obj``.

        :param data_obj: The Data object this instance describes; it is
            told about this instance via ``_set_plugin_data``
        :param plugin: Optional plugin object stored as ``self._plugin``
        """
        self.data_obj = data_obj
        self._preview = None
        self.data_obj._set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = None
        self.pad_dict = None
        self.shape = None
        self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []
        self._plugin = plugin
        self.fixed_dims = True
        self.split = None
        self.boundary_padding = None
        self.no_squeeze = False
        self.pre_tuning_shape = None
        self._frame_limit = None

    def _get_preview(self):
        """ Return the stored preview (None unless set elsewhere). """
        return self._preview

    def get_total_frames(self):
        """ Get the total number of frames to process (all MPI processes).

        Computed as the product of the data shape over the slice dimensions
        of the current pattern.

        :returns: Number of frames
        :rtype: int
        """
        temp = 1
        slice_dir = \
            self.data_obj.get_data_patterns()[
                self.get_pattern_name()]["slice_dims"]
        for tslice in slice_dir:
            temp *= self.data_obj.get_shape()[tslice]
        return temp

    def __set_pattern(self, name):
        """ Set the pattern related information in the meta data dict.
        """
        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set("name", name)
        self.meta_data.set("core_dims", pattern['core_dims'])
        self.__set_slice_dimensions()

    def get_pattern_name(self):
        """ Get the pattern name.

        :returns: the pattern name
        :rtype: str
        :raises Exception: if the meta data has no ``name`` entry
        """
        try:
            name = self.meta_data.get("name")
            return name
        except KeyError:
            raise Exception("The pattern name has not been set.")

    def get_pattern(self):
        """ Get the current pattern.

        :returns: dict of the pattern name against the pattern.
        :rtype: dict
        """
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def __set_shape(self):
        """ Set the shape of the plugin data processing chunk.

        The core shape is stored first; the number of frames to process is
        then inserted at the slice position when it is greater than one or
        when squeezing has been disabled.
        """
        core_dir = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()
        dirs = list(set(core_dir + (slice_dir[0],)))
        slice_idx = dirs.index(slice_dir[0])
        dshape = self.data_obj.get_shape()
        shape = []
        for core in set(core_dir):
            shape.append(dshape[core])
        self.__set_core_shape(tuple(shape))

        mfp = self._get_max_frames_process()
        if mfp > 1 or self._get_no_squeeze():
            shape.insert(slice_idx, mfp)
        self.shape = tuple(shape)

    def _set_shape_transfer(self, slice_size):
        """ Build the shape of the data transferred in one go.

        :param list slice_size: Number of frames in each slice dimension; it
            is padded with ones for any dimensions the data has gained
            relative to the pre-tuning shape
        :returns: full data shape with core dimensions at full size and
            slice dimensions at the sizes given
        :rtype: tuple
        """
        dshape = self.data_obj.get_shape()
        shape_before_tuning = self._get_shape_before_tuning()
        add = [1]*(len(dshape) - len(shape_before_tuning))
        slice_size = slice_size + add

        core_dir = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()
        shape = [None]*len(dshape)
        for dim in core_dir:
            shape[dim] = dshape[dim]
        for i, dim in enumerate(slice_dir):
            shape[dim] = slice_size[i]
        return tuple(shape)

    def __get_slice_size(self, mft):
        """ Calculate the number of frames transfer in each dimension given
            mft.

        The result is also stored in the meta data under ``size_list``.
        """
        dshape = list(self.data_obj.get_shape())

        if 'fixed_dimensions' in self.meta_data.get_dictionary().keys():
            fixed_dims = self.meta_data.get('fixed_dimensions')
            for d in fixed_dims:
                dshape[d] = 1

        dshape = [dshape[i] for i in self.meta_data.get('slice_dims')]
        size_list = [1]*len(dshape)
        i = 0

        # NOTE(review): mft is only reduced when the running product exceeds
        # one, so length-1 slice dimensions advance i without shrinking mft
        # and can run off the end of size_list - confirm the intended
        # arithmetic before changing it.
        while(mft > 1):
            size_list[i] = min(dshape[i], mft)
            mft -= np.prod(size_list) if np.prod(size_list) > 1 else 0
            i += 1

        self.meta_data.set('size_list', size_list)
        return size_list

    def set_bytes_per_frame(self):
        """ Return the shape and the size in bytes of a single frame.

        :returns: ``(frame_shape, bytes_per_frame)`` where ``frame_shape``
            lists the data shape over the core dimensions
        :rtype: tuple
        """
        nBytes = self.data_obj.get_itemsize()
        # list() keeps the first-value lookup valid on Python 3 dict views.
        dims = list(self.get_pattern().values())[0]['core_dims']
        frame_shape = [self.data_obj.get_shape()[d] for d in dims]
        b_per_f = np.prod(frame_shape)*nBytes
        return frame_shape, b_per_f

    def get_shape(self):
        """ Get the shape of the data (without padding) that is passed to the
        plugin process_frames method.
        """
        return self.shape

    def _set_padded_shape(self):
        """ Placeholder; does nothing in this class. """
        pass

    def get_padded_shape(self):
        """ Get the shape of the data (with padding) that is passed to the
        plugin process_frames method.

        This implementation returns the unpadded ``self.shape``.
        """
        return self.shape

    def get_shape_transfer(self):
        """ Get the shape of the plugin data to be transferred each time.
        """
        return self.meta_data.get('transfer_shape')

    def __set_core_shape(self, shape):
        """ Set the core shape to hold only the shape of the core dimensions
        """
        self.core_shape = shape

    def get_core_shape(self):
        """ Get the shape of the core dimensions only.

        :returns: shape of core dimensions
        :rtype: tuple
        """
        return self.core_shape

    def _set_shape_before_tuning(self, shape):
        """ Set the shape of the full dataset used during each run of the \
        plugin (i.e. ignore extra dimensions due to parameter tuning). """
        self.pre_tuning_shape = shape

    def _get_shape_before_tuning(self):
        """ Return the shape of the full dataset used during each run of the \
        plugin (i.e. ignore extra dimensions due to parameter tuning).

        Falls back to the current data shape when no pre-tuning shape has
        been stored. """
        return self.pre_tuning_shape if self.pre_tuning_shape else\
            self.data_obj.get_shape()

    def __check_dimensions(self, indices, core_dir, slice_dir, nDims):
        """ Exit the program if the index or dimension counts are
        inconsistent.

        :param indices: one index per slice dimension
        :param core_dir: core dimensions
        :param slice_dir: slice dimensions
        :param int nDims: expected total number of dimensions
        """
        # Lengths are compared by value; `is not` on integers compares
        # object identity and gives wrong answers for larger numbers.
        if len(indices) != len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")

        if (len(core_dir)+len(slice_dir)) != nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def __set_slice_dimensions(self):
        """ Set the slice dimensions in the pluginData meta data dictionary.
        """
        slice_dirs = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['slice_dims']
        self.meta_data.set('slice_dims', slice_dirs)

    def get_slice_dimension(self):
        """
        Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()[0]
        return list(set(core_dirs + (slice_dir,))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """
        Return the dimension of the data in the plugin that has the specified
        axis label.

        :param str label: axis label name
        :keyword bool contains: passed through to the Data object lookup
        """
        label_dim = self.data_obj.get_data_dimension_by_axis_label(
                label, contains=contains)
        plugin_dims = self.data_obj.get_core_dimensions()
        # NOTE(review): get_slice_dimension() is a position within the plugin
        # dimensions whereas the core entries are data dimension indices -
        # confirm that mixing the two here is intended.
        if self._get_max_frames_process() > 1 or self.max_frames == 'multiple':
            plugin_dims += (self.get_slice_dimension(),)
        return list(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):
        """
        Reorder the slice dimensions.  The fastest changing slice dimension
        will always be the first one stated in the pattern key ``slice_dims``.
        The input param is a tuple stating the desired order of slicing
        dimensions relative to the current order.

        :raises Exception: if ``order`` has more entries than there are
            slice dimensions
        """
        # The helpers previously called here (get_slice_directions,
        # get_current_pattern) and the 'slice_dir' key do not exist on this
        # class; the current equivalents are used instead.
        slice_dirs = self.data_obj.get_slice_dimensions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specifed.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        pattern = self.data_obj.get_data_patterns()[self.get_pattern_name()]
        pattern['slice_dims'] = new_slice_dirs
        # Keep the cached copy in the meta data in step with the pattern.
        self.__set_slice_dimensions()

    def get_core_dimensions(self):
        """
        Return the position of the core dimensions in relation to the data
        handed to the plugin.
        """
        core_dims = self.data_obj.get_core_dimensions()
        first_slice_dim = (self.data_obj.get_slice_dimensions()[0],)
        plugin_dims = np.sort(core_dims + first_slice_dim)
        return np.searchsorted(plugin_dims, np.sort(core_dims))

    def set_fixed_dimensions(self, dims, values):
        """ Fix a data direction to the index in values list.

        The data object's shape is set to 1 in every fixed dimension.

        :param list(int) dims: Directions to fix
        :param list(int) values: Index of fixed directions
        :raises Exception: if any entry of ``dims`` is not a slice dimension
        """
        slice_dirs = self.data_obj.get_slice_dimensions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set("fixed_dimensions", dims)
        self.meta_data.set("fixed_dimensions_values", values)
        self.__set_slice_dimensions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        # The plugin chunk shape (self.shape) is not recomputed here.

    def _get_fixed_dimensions(self):
        """ Get the fixed data directions and their indices

        :returns: Fixed directions and their associated values (two empty
            lists if nothing has been fixed)
        :rtype: list(list(int), list(int))
        """
        fixed = []
        values = []
        if 'fixed_dimensions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get("fixed_dimensions")
            values = self.meta_data.get("fixed_dimensions_values")
        return [fixed, values]

    def _get_data_slice_list(self, plist):
        """ Convert a plugin data slice list to a slice list for the whole
        dataset, i.e. add in any missing dimensions.
        """
        nDims = len(self.get_shape())
        # NOTE(review): get_core_dimensions() returns a NumPy array and
        # get_slice_dimension() an int, so this `+` is an element-wise
        # addition rather than a concatenation - confirm the intent.
        all_dims = self.get_core_dimensions() + self.get_slice_dimension()
        extra_dims = all_dims[nDims:]
        dlist = list(plist)
        for i in extra_dims:
            dlist.insert(i, slice(None))
        return tuple(dlist)

    def _get_max_frames_process(self):
        """ Get the number of frames to process for each run of process_frames.

        If the number of frames is not divisible by the previewing ``chunk``
        value then amend the number of frames to gcd(frames, chunk)

        :returns: Number of frames to process
        :rtype: int
        """
        # NOTE(review): plugin_data_transfer_setup sets chunk to True, and
        # ``True > 1`` is False - confirm whether a truthiness test was
        # intended here.
        if self._plugin and self._plugin.chunk > 1:
            frame_chunk = self.meta_data.get("max_frames_process")
            # The first slice dimension comes from the Data object; this
            # class has no get_slice_directions method.
            chunk = self.data_obj.get_preview().get_starts_stops_steps(
                key='chunks')[self.data_obj.get_slice_dimensions()[0]]
            self.meta_data.set('max_frames_process', gcd(frame_chunk, chunk))
        return self.meta_data.get("max_frames_process")

    def _get_max_frames_transfer(self):
        """ Get the number of frames to transfer for each run of
        process_frames. """
        return self.meta_data.get('max_frames_transfer')

    def _set_no_squeeze(self):
        """ Keep the slice dimension in the plugin shape even when only one
        frame is processed at a time. """
        self.no_squeeze = True

    def _get_no_squeeze(self):
        """ Return the no-squeeze flag. """
        return self.no_squeeze

    def _set_meta_data(self):
        """ Populate the meta data with frame and process counts.

        Entries written: ``shape``, ``sdir``, ``total_frames``,
        ``mpi_procs``, ``frames_per_process``, ``bytes_per_frame``,
        ``bytes_per_process`` and ``frame_shape``.
        """
        fixed, _ = self._get_fixed_dimensions()
        sdir = \
            [s for s in self.data_obj.get_slice_dimensions() if s not in fixed]
        shape = self.data_obj.get_shape()
        shape_before_tuning = self._get_shape_before_tuning()

        # Ignore dimensions gained since the pre-tuning shape was recorded.
        diff = len(shape) - len(shape_before_tuning)
        if diff:
            shape = shape_before_tuning
            sdir = sdir[:-diff]

        if 'fix_total_frames' in self.meta_data.get_dictionary().keys():
            frames = self.meta_data.get('fix_total_frames')
        else:
            frames = np.prod([shape[d] for d in sdir])

        base_names = [p.__name__ for p in self._plugin.__class__.__bases__]
        processes = self.data_obj.exp.meta_data.get('processes')

        # Only process names containing 'GPU' are counted for plugins
        # deriving directly from GpuPlugin.
        if 'GpuPlugin' in base_names:
            n_procs = len([n for n in processes if 'GPU' in n])
        else:
            n_procs = len(processes)

        # Force true division so that ceil() rounds up on Python 2 as well,
        # where int/int would already have been floored.
        f_per_p = np.ceil(frames/float(n_procs))
        self.meta_data.set('shape', shape)
        self.meta_data.set('sdir', sdir)
        self.meta_data.set('total_frames', frames)
        self.meta_data.set('mpi_procs', n_procs)
        self.meta_data.set('frames_per_process', f_per_p)
        frame_shape, b_per_f = self.set_bytes_per_frame()
        self.meta_data.set('bytes_per_frame', b_per_f)
        self.meta_data.set('bytes_per_process', b_per_f*f_per_p)
        self.meta_data.set('frame_shape', frame_shape)

    def __log_max_frames(self, mft, mfp, check=True):
        """ Log the frame counts, store ``max_frames_process`` and
        optionally check the frame distribution.

        :param mft: max frames transfer
        :param mfp: max frames process
        :keyword bool check: run the distribution check when True
        """
        # Lazy %s arguments: nothing is formatted unless debug logging is
        # enabled, and a non-integer value cannot raise here.
        logging.debug("Setting max frames transfer for plugin %s to %s",
                      self._plugin, mft)
        logging.debug("Setting max frames process for plugin %s to %s",
                      self._plugin, mfp)
        self.meta_data.set('max_frames_process', mfp)
        if check:
            self.__check_distribution(mft)
        # (((total_frames/mft)/mpi_procs) % 1)

    def __check_distribution(self, mft):
        """ Warn if the frames do not divide evenly between the processes.

        :param mft: max frames transfer; nothing is checked when it is falsy
        """
        if not mft:
            # Other methods of this class already allow for a falsy mft;
            # there is nothing to divide by in that case.
            return
        warn_threshold = 0.85
        nprocs = self.meta_data.get('mpi_procs')
        nframes = self.meta_data.get('total_frames')
        temp = (((nframes/mft)/float(nprocs)) % 1)
        if temp != 0.0 and temp < warn_threshold:
            shape = self.meta_data.get('shape')
            sdir = self.meta_data.get('sdir')
            logging.warning(
                'UNEVEN FRAME DISTRIBUTION: shape %s, nframes %s ' +
                'sdir %s, nprocs %s', shape, nframes, sdir, nprocs)

    def _set_padding_dict(self):
        """ Convert a padding dictionary into a ``Padding`` object.

        Each key of the dictionary names a method of the new ``Padding``
        instance, which is called with the associated value.  A deep copy of
        the dictionary is kept in ``self.pad_dict``.
        """
        if self.padding and not isinstance(self.padding, Padding):
            self.pad_dict = copy.deepcopy(self.padding)
            self.padding = Padding(self)
            for key in self.pad_dict.keys():
                getattr(self.padding, key)(self.pad_dict[key])

    def plugin_data_setup(self, pattern, nFrames, split=None):
        """ Setup the PluginData object.

        :param str pattern: A pattern name
        :param int nFrames: How many frames to process at a time.  Choose from\
            'single', 'multiple', 'fixed_multiple' or an integer (an integer \
            should only ever be passed in exceptional circumstances).  A\
            two-item list is unpacked into the frame count and a frame limit
        :keyword split: stored unchanged as ``self.split``
        """
        self.__set_pattern(pattern)
        if isinstance(nFrames, list):
            nFrames, self._frame_limit = nFrames
        self.max_frames = nFrames
        self.split = split

    def plugin_data_transfer_setup(self, copy=None, calc=None):
        """ Set up the plugin data transfer frame parameters.

        :keyword copy: another PluginData instance whose transfer and
            process frame counts are copied
        :keyword calc: another PluginData instance whose frame counts are
            used to derive the counts for this one (takes precedence over
            ``copy``)
        """
        chunks = \
            self.data_obj.get_preview().get_starts_stops_steps(key='chunks')

        if not copy and not calc:
            mft, mft_shape, mfp = self._calculate_max_frames()
        elif calc:
            max_mft = calc.meta_data.get('max_frames_transfer')
            max_mfp = calc.meta_data.get('max_frames_process')
            max_nProc = int(np.ceil(max_mft/float(max_mfp)))
            nProc = max_nProc
            mfp = 1 if self.max_frames == 'single' else self.max_frames
            mft = nProc*mfp
            mft_shape = self._set_shape_transfer(self.__get_slice_size(mft))
        elif copy:
            mft = copy._get_max_frames_transfer()
            mft_shape = self._set_shape_transfer(self.__get_slice_size(mft))
            mfp = copy._get_max_frames_process()

        self.__set_max_frames(mft, mft_shape, mfp)

        # Flag the plugin when the preview chunk size of the first slice
        # dimension is not a multiple of the transfer size.
        if self._plugin and mft \
                and (chunks[self.data_obj.get_slice_dimensions()[0]] % mft):
            self._plugin.chunk = True
        self.__set_shape()

    def _calculate_max_frames(self):
        """ Ask the transport layer for the transfer and process frame counts.

        :returns: ``(mft, mft_shape, mfp)``; ``mft_shape`` is None when
            ``mft`` is falsy
        :rtype: tuple
        """
        nFrames = self.max_frames
        self.__perform_checks(nFrames)
        td = self.data_obj._get_transport_data()
        mft, size_list = td._calc_max_frames_transfer(nFrames)
        self.meta_data.set('size_list', size_list)
        mfp = td._calc_max_frames_process(nFrames)
        # Bound on every path so the return below cannot fail on an
        # unassigned name.
        mft_shape = None
        if mft:
            mft_shape = self._set_shape_transfer(list(size_list))
        return mft, mft_shape, mfp

    def __set_max_frames(self, mft, mft_shape, mfp):
        """ Store the frame counts and transfer shape in the meta data. """
        self.meta_data.set('max_frames_transfer', mft)
        self.meta_data.set('transfer_shape', mft_shape)
        self.meta_data.set('max_frames_process', mfp)
        self.__log_max_frames(mft, mfp)
        # Retain the shape if the first slice dimension has length 1
        if mfp == 1 and self.max_frames == 'multiple':
            self._set_no_squeeze()

    def _get_plugin_data_size_params(self):
        """ Collect size information about the plugin data.

        :returns: dict with keys ``nBytes``, ``frame_shape``,
            ``total_frames`` and ``transfer_bytes``
        :rtype: dict
        """
        nBytes = self.data_obj.get_itemsize()
        frame_shape = self.meta_data.get('frame_shape')
        total_frames = self.meta_data.get('total_frames')
        tbytes = nBytes*np.prod(frame_shape)*total_frames

        params = {'nBytes': nBytes, 'frame_shape': frame_shape,
                  'total_frames': total_frames, 'transfer_bytes': tbytes}
        return params

    def __perform_checks(self, nFrames):
        """ Validate the ``nFrames`` setting.

        :raises Exception: if ``nFrames`` is neither an integer nor one of
            ``'single'`` / ``'multiple'``
        """
        options = ['single', 'multiple']
        if not isinstance(nFrames, int) and nFrames not in options:
            # Parenthesised so that all three fragments form one message.
            e_str = ("The value of nFrames is not recognised.  Please choose "
                     "from 'single' and 'multiple' (or an integer in "
                     "exceptional circumstances).")
            raise Exception(e_str)

    def get_frame_limit(self):
        """ Return the frame limit (None unless one was passed to
        ``plugin_data_setup``). """
        return self._frame_limit

    def get_current_frame_idx(self):
        """ Returns the index of the frames currently being processed.

        Indices beyond the last frame are clamped to the last frame.
        """
        global_index = self._plugin.get_global_frame_index()
        count = self._plugin.get_process_frames_counter()
        mfp = self.meta_data.get('max_frames_process')
        start = global_index[count]*mfp
        index = np.arange(start, start + mfp)
        nFrames = self.get_total_frames()
        index[index >= nFrames] = nFrames - 1
        return index
コード例 #10
0
ファイル: data_structures.py プロジェクト: r-atwood/Savu
class PluginData(object):

    def __init__(self, data_obj):
        """ Attach this object to a data object and reset its state.

        :param data_obj: the data object being described; it is handed a
            reference back to this instance through ``set_plugin_data``.
        """
        self.data_obj = data_obj
        self.data_obj.set_plugin_data(self)
        self.meta_data = MetaData()
        # Shapes are unknown until set_shape() is called.
        self.shape = None
        self.core_shape = None
        self.padding = None
        # Flag selecting which data is passed: False means just the data,
        # True means all data including dark and flat fields.
        self.selected_data = False
        self.extra_dims = []
        self.multi_params = {}

    def get_total_frames(self):
        temp = 1
        slice_dir = \
            self.data_obj.get_data_patterns()[
                self.get_pattern_name()]["slice_dir"]
        for tslice in slice_dir:
            temp *= self.data_obj.get_shape()[tslice]
        return temp

    def set_pattern(self, name):
        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set_meta_data("name", name)
        self.meta_data.set_meta_data("core_dir", pattern['core_dir'])
        self.set_slice_directions()

    def get_pattern_name(self):
        name = self.meta_data.get_meta_data("name")
        if name is not None:
            return name
        else:
            raise Exception("The pattern name has not been set.")

    def get_pattern(self):
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def set_shape(self):
        core_dir = self.get_core_directions()
        slice_dir = self.get_slice_directions()
        dirs = list(set(core_dir + (slice_dir[0],)))
        slice_idx = dirs.index(slice_dir[0])
        shape = []
        for core in set(core_dir):
            shape.append(self.data_obj.get_shape()[core])
        self.set_core_shape(tuple(shape))
        if self.get_frame_chunk() > 1:
            shape.insert(slice_idx, self.get_frame_chunk())
        self.shape = tuple(shape)

    def get_shape(self):
        return self.shape

    def set_core_shape(self, shape):
        self.core_shape = shape

    def get_core_shape(self):
        return self.core_shape

    def check_dimensions(self, indices, core_dir, slice_dir, nDims):
        if len(indices) is not len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")

        if (len(core_dir)+len(slice_dir)) is not nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def set_slice_directions(self):
        slice_dirs = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['slice_dir']
        self.meta_data.set_meta_data('slice_dir', slice_dirs)

    def get_slice_directions(self):
        return self.meta_data.get_meta_data('slice_dir')

    def get_slice_dimension(self):
        """
        Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.get_core_directions()
        slice_dir = self.get_slice_directions()[0]
        return list(set(core_dirs + (slice_dir,))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """
        Return the dimension of the data in the plugin that has the specified
        axis label.
        """
        label_dim = \
            self.data_obj.find_axis_label_dimension(label, contains=contains)
        plugin_dims = self.get_core_directions()
        if self.get_frame_chunk() > 1:
            plugin_dims += (self.get_slice_directions()[0],)
        return list(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):
        """
        Reorder the slice directions.  The fastest changing slice direction
        will always be the first one. The input param is a tuple stating the
        desired order of slicing directions relative to the current order.
        """
        slice_dirs = self.get_slice_directions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specifed.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        self.get_current_pattern()['slice_dir'] = new_slice_dirs

    def get_core_directions(self):
        core_dir = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['core_dir']
        return core_dir

    def set_fixed_directions(self, dims, values):
        slice_dirs = self.get_slice_directions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set_meta_data("fixed_directions", dims)
        self.meta_data.set_meta_data("fixed_directions_values", values)
        self.set_slice_directions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        self.set_shape()

    def get_fixed_directions(self):
        fixed = []
        values = []
        if 'fixed_directions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get_meta_data("fixed_directions")
            values = self.meta_data.get_meta_data("fixed_directions_values")
        return [fixed, values]

    def set_frame_chunk(self, nFrames):
        # number of frames to process at a time
        self.meta_data.set_meta_data("nFrames", nFrames)

    def get_frame_chunk(self):
        return self.meta_data.get_meta_data("nFrames")

    def get_index(self, indices):
        shape = self.get_shape()
        nDims = len(shape)
        name = self.get_current_pattern_name()
        ddirs = self.get_data_patterns()
        core_dir = ddirs[name]["core_dir"]
        slice_dir = ddirs[name]["slice_dir"]

        self.check_dimensions(indices, core_dir, slice_dir, nDims)

        index = [slice(None)]*nDims
        count = 0
        for tdir in slice_dir:
            index[tdir] = slice(indices[count], indices[count]+1, 1)
            count += 1

        return tuple(index)

    def plugin_data_setup(self, pattern_name, chunk):
        self.set_pattern(pattern_name)
        self.set_frame_chunk(chunk)
        self.set_shape()

    def set_temp_pad_dict(self, pad_dict):
        self.meta_data.set_meta_data('temp_pad_dict', pad_dict)

    def get_temp_pad_dict(self):
        if 'temp_pad_dict' in self.meta_data.get_dictionary().keys():
            return self.meta_data.get_dictionary()['temp_pad_dict']

    def delete_temp_pad_dict(self):
        del self.meta_data.get_dictionary()['temp_pad_dict']