Example #1
class Experiment(object):
    """
    One instance of this class is created at the beginning of the
    processing chain and remains until the end.  It holds the current data
    object and a dictionary containing all metadata.
    """
    def __init__(self, options):
        self.meta_data = MetaData(options)
        self.__meta_data_setup(options["process_file"])
        self.experiment_collection = {}
        self.index = {"in_data": {}, "out_data": {}}
        self.initial_datasets = None
        self.plugin = None

    def get(self, entry):
        """ Get the meta data dictionary. """
        return self.meta_data.get(entry)

    def __meta_data_setup(self, process_file):
        self.meta_data.plugin_list = PluginList()

        try:
            rtype = self.meta_data.get('run_type')
            if rtype == 'test':
                self.meta_data.plugin_list.plugin_list = \
                    self.meta_data.get('plugin_list')
            else:
                raise Exception('The run_type is unknown in the Experiment '
                                'class.')
        except KeyError:
            self.meta_data.plugin_list._populate_plugin_list(process_file)

    def create_data_object(self, dtype, name):
        """ Create a data object.

        Plugin developers should use this method in loaders only.

        :param str dtype: either "in_data" or "out_data".
        """
        try:
            self.index[dtype][name]
        except KeyError:
            self.index[dtype][name] = Data(name, self)
            data_obj = self.index[dtype][name]
            data_obj._set_transport_data(self.meta_data.get('transport'))
        return self.index[dtype][name]
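
    # A minimal usage sketch (hypothetical loader code): inside a loader's
    # setup method, with ``self.exp`` holding this Experiment instance, the
    # dataset registries would typically be populated as
    #
    #     in_data = self.exp.create_data_object("in_data", "tomo")
    #     out_data = self.exp.create_data_object("out_data", "tomo")
    #
    # where the dataset name "tomo" is chosen for illustration only.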

    def _experiment_setup(self):
        """ Setup an experiment collection.
        """
        n_loaders = self.meta_data.plugin_list._get_n_loaders()
        plugin_list = self.meta_data.plugin_list
        plist = plugin_list.plugin_list

        # load the loader plugins
        self._set_loaders()

        # load the saver plugin and save the plugin list
        self.experiment_collection = {'plugin_dict': [], 'datasets': []}

        self._barrier()
        if self.meta_data.get('process') == \
                len(self.meta_data.get('processes'))-1:
            plugin_list._save_plugin_list(self.meta_data.get('nxs_filename'))
        self._barrier()

        n_plugins = plugin_list._get_n_processing_plugins()
        count = 0
        # first run through of the plugin setup methods
        for plugin_dict in plist[n_loaders:n_loaders + n_plugins]:
            data = self.__plugin_setup(plugin_dict, count)
            self.experiment_collection['datasets'].append(data)
            self.experiment_collection['plugin_dict'].append(plugin_dict)
            self._merge_out_data_to_in()
            count += 1
        self._reset_datasets()

    def _set_loaders(self):
        n_loaders = self.meta_data.plugin_list._get_n_loaders()
        plugin_list = self.meta_data.plugin_list.plugin_list
        for i in range(n_loaders):
            pu.plugin_loader(self, plugin_list[i])
        self.initial_datasets = copy.deepcopy(self.index['in_data'])

    def _reset_datasets(self):
        self.index['in_data'] = self.initial_datasets

    def __plugin_setup(self, plugin_dict, count):
        """ Determine plugin specific information.
        """
        plugin_id = plugin_dict["id"]
        logging.debug("Loading plugin %s", plugin_id)
        # Run main_setup method
        plugin = pu.plugin_loader(self, plugin_dict)
        plugin._revert_preview(plugin.get_in_datasets())
        # Populate the metadata
        plugin._clean_up()
        data = self.index['out_data'].copy()
        return data

    def _get_experiment_collection(self):
        return self.experiment_collection

    def _set_experiment_for_current_plugin(self, count):
        datasets_list = self.meta_data.plugin_list._get_datasets_list()[count:]
        exp_coll = self._get_experiment_collection()
        self.index['out_data'] = exp_coll['datasets'][count]
        if datasets_list:
            self._get_current_and_next_patterns(datasets_list)
        self.meta_data.set('nPlugin', count)

    def _get_current_and_next_patterns(self, datasets_lists):
        """ Get the current and next patterns associated with a dataset
        throughout the processing chain.
        """
        current_datasets = datasets_lists[0]
        patterns_list = []
        for current_data in current_datasets['out_datasets']:
            current_name = current_data['name']
            current_pattern = current_data['pattern']
            next_pattern = self.__find_next_pattern(datasets_lists[1:],
                                                    current_name)
            patterns_list.append({
                'current': current_pattern,
                'next': next_pattern
            })
        self.meta_data.set('current_and_next', patterns_list)

    def __find_next_pattern(self, datasets_lists, current_name):
        next_pattern = []
        for next_data_list in datasets_lists:
            for next_data in next_data_list['in_datasets']:
                if next_data['name'] == current_name:
                    next_pattern = next_data['pattern']
                    return next_pattern
        return next_pattern

    def _set_nxs_filename(self):
        folder = self.meta_data.get('out_path')
        fname = self.meta_data.get('datafile_name') + '_processed.nxs'
        filename = os.path.join(folder, fname)
        self.meta_data.set('nxs_filename', filename)

        if self.meta_data.get('process') == 1:
            if self.meta_data.get('bllog'):
                log_folder_name = self.meta_data.get('bllog')
                with open(log_folder_name, 'a') as log_folder:
                    log_folder.write(os.path.abspath(filename) + '\n')

    def _clear_data_objects(self):
        self.index["out_data"] = {}
        self.index["in_data"] = {}

    def _merge_out_data_to_in(self):
        for key, data in self.index["out_data"].iteritems():
            if data.remove is False:
                self.index['in_data'][key] = data
        self.index["out_data"] = {}

    def _finalise_experiment_for_current_plugin(self):
        finalise = {}
        # populate nexus file with out_dataset information and determine which
        # datasets to remove from the framework.
        finalise['remove'] = []
        finalise['keep'] = []

        for key, data in self.index['out_data'].items():
            if data.remove is True:
                finalise['remove'].append(data)
            else:
                finalise['keep'].append(data)

        # find in datasets to replace
        finalise['replace'] = []
        for out_name in self.index['out_data'].keys():
            if out_name in self.index['in_data'].keys():
                finalise['replace'].append(self.index['in_data'][out_name])

        return finalise

    def _reorganise_datasets(self, finalise):
        # unreplicate replicated in_datasets
        self.__unreplicate_data()

        # delete all datasets for removal
        for data in finalise['remove']:
            del self.index["out_data"][data.data_info.get('name')]

        # Add remaining output datasets to input datasets
        for name, data in self.index['out_data'].items():
            data.get_preview().set_preview([])
            self.index["in_data"][name] = copy.deepcopy(data)
        self.index['out_data'] = {}

    def __unreplicate_data(self):
        in_data_list = self.index['in_data']
        from savu.data.data_structures.data_types.replicate import Replicate
        for in_data in in_data_list.values():
            if isinstance(in_data.data, Replicate):
                in_data.data = in_data.data.reset()

    def _set_all_datasets(self, name):
        data_names = []
        for key in self.index["in_data"].keys():
            if 'itr_clone' not in key:
                data_names.append(key)
        return data_names

    def _barrier(self, communicator=MPI.COMM_WORLD):
        comm_dict = {'comm': communicator}
        if self.meta_data.get('mpi') is True:
            logging.debug("About to hit a _barrier %s", comm_dict)
            comm_dict['comm'].barrier()
            logging.debug("Past the _barrier")

    def log(self, log_tag, log_level=logging.DEBUG):
        """
        Log the contents of the experiment at the specified level
        """
        logging.log(log_level, "Experimental Parameters for %s", log_tag)
        for key, value in self.index["in_data"].items():
            logging.log(log_level, "in data (%s) shape = %s", key,
                        value.get_shape())
        for key, value in self.index["out_data"].items():
            logging.log(log_level, "out data (%s) shape = %s", key,
                        value.get_shape())
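
# A standalone sketch of the dataset-index lifecycle handled by
# _merge_out_data_to_in above, using plain dicts and a stand-in object in
# place of real savu Data instances (all names below are hypothetical):

class _FakeData(object):
    def __init__(self, remove=False):
        self.remove = remove

index = {"in_data": {"tomo": _FakeData()},
         "out_data": {"recon": _FakeData(), "tmp": _FakeData(remove=True)}}

# Outputs flagged with remove=True are discarded; the rest become the inputs
# of the next plugin, exactly as in _merge_out_data_to_in above.
for key, data in list(index["out_data"].items()):
    if data.remove is False:
        index["in_data"][key] = data
index["out_data"] = {}

assert sorted(index["in_data"]) == ["recon", "tomo"]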
Example #2
class PluginData(object):
    """ The PluginData class contains plugin specific information about a Data
    object for the duration of a plugin.  An instance of the class is
    encapsulated inside the Data object during the plugin run
    """
    def __init__(self, data_obj, plugin=None):
        self.data_obj = data_obj
        self._preview = None
        self.data_obj._set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = None
        self.pad_dict = None
        self.shape = None
        self.shape_transfer = None
        self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []
        self._plugin = plugin
        self.fixed_dims = True
        self.split = None
        self.boundary_padding = None
        self.no_squeeze = False
        self.pre_tuning_shape = None
        self._frame_limit = None

    def _get_preview(self):
        return self._preview

    def get_total_frames(self):
        """ Get the total number of frames to process (all MPI processes).

        :returns: Number of frames
        :rtype: int
        """
        temp = 1
        slice_dir = \
            self.data_obj.get_data_patterns()[
                self.get_pattern_name()]["slice_dims"]
        for tslice in slice_dir:
            temp *= self.data_obj.get_shape()[tslice]
        return temp
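
    # A worked example (hypothetical numbers): for a dataset of shape
    # (91, 135, 160, 160) whose current pattern has slice_dims (0, 1), the
    # loop above multiplies the slice-dimension lengths, giving
    # 91 * 135 = 12285 frames in total.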

    def __set_pattern(self, name):
        """ Set the pattern related information in the meta data dict.
        """
        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set("name", name)
        self.meta_data.set("core_dims", pattern['core_dims'])
        self.__set_slice_dimensions()

    def get_pattern_name(self):
        """ Get the pattern name.

        :returns: the pattern name
        :rtype: str
        """
        try:
            name = self.meta_data.get("name")
            return name
        except KeyError:
            raise Exception("The pattern name has not been set.")

    def get_pattern(self):
        """ Get the current pattern.

        :returns: dict of the pattern name against the pattern.
        :rtype: dict
        """
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def __set_shape(self):
        """ Set the shape of the plugin data processing chunk.
        """
        core_dir = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()
        dirs = list(set(core_dir + (slice_dir[0], )))
        slice_idx = dirs.index(slice_dir[0])
        dshape = self.data_obj.get_shape()
        shape = []
        for core in set(core_dir):
            shape.append(dshape[core])
        self.__set_core_shape(tuple(shape))

        mfp = self._get_max_frames_process()
        if mfp > 1 or self._get_no_squeeze():
            shape.insert(slice_idx, mfp)
        self.shape = tuple(shape)

    def _set_shape_transfer(self, slice_size):
        dshape = self.data_obj.get_shape()
        shape_before_tuning = self._get_shape_before_tuning()
        add = [1] * (len(dshape) - len(shape_before_tuning))
        slice_size = slice_size + add

        core_dir = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()
        shape = [None] * len(dshape)
        for dim in core_dir:
            shape[dim] = dshape[dim]
        i = 0
        for dim in slice_dir:
            shape[dim] = slice_size[i]
            i += 1
        self.shape_transfer = tuple(shape)

    def get_shape(self):
        """ Get the shape of the data (without padding) that is passed to the
        plugin process_frames method.
        """
        return self.shape

    def _set_padded_shape(self):
        pass

    def get_padded_shape(self):
        """ Get the shape of the data (with padding) that is passed to the
        plugin process_frames method.
        """
        return self.shape

    def get_shape_transfer(self):
        """ Get the shape of the plugin data to be transferred each time.
        """
        return self.shape_transfer

    def __set_core_shape(self, shape):
        """ Set the core shape to hold only the shape of the core dimensions
        """
        self.core_shape = shape

    def get_core_shape(self):
        """ Get the shape of the core dimensions only.

        :returns: shape of core dimensions
        :rtype: tuple
        """
        return self.core_shape

    def _set_shape_before_tuning(self, shape):
        """ Set the shape of the full dataset used during each run of the \
        plugin (i.e. ignore extra dimensions due to parameter tuning). """
        self.pre_tuning_shape = shape

    def _get_shape_before_tuning(self):
        """ Return the shape of the full dataset used during each run of the \
        plugin (i.e. ignore extra dimensions due to parameter tuning). """
        return self.pre_tuning_shape if self.pre_tuning_shape else\
            self.data_obj.get_shape()

    def __check_dimensions(self, indices, core_dir, slice_dir, nDims):
        if len(indices) != len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")

        if (len(core_dir) + len(slice_dir)) != nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def __set_slice_dimensions(self):
        """ Set the slice dimensions in the pluginData meta data dictionary.
        """
        slice_dirs = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['slice_dims']
        self.meta_data.set('slice_dims', slice_dirs)

    def get_slice_dimension(self):
        """
        Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()[0]
        return list(set(core_dirs + (slice_dir, ))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """
        Return the dimension of the data in the plugin that has the specified
        axis label.
        """
        label_dim = self.data_obj.get_data_dimension_by_axis_label(
            label, contains=contains)
        plugin_dims = self.data_obj.get_core_dimensions()
        if self._get_max_frames_process() > 1 or self.max_frames == 'multiple':
            plugin_dims += (self.get_slice_dimension(), )
        return list(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):
        """
        Reorder the slice dimensions.  The fastest changing slice dimension
        will always be the first one stated in the pattern key ``slice_dir``.
        The input param is a tuple stating the desired order of slicing
        dimensions relative to the current order.
        """
        slice_dirs = self.get_slice_directions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specifed.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        self.get_current_pattern()['slice_dir'] = new_slice_dirs

    def get_core_dimensions(self):
        """
        Return the position of the core dimensions in relation to the data
        handed to the plugin.
        """
        core_dims = self.data_obj.get_core_dimensions()
        first_slice_dim = (self.data_obj.get_slice_dimensions()[0], )
        plugin_dims = np.sort(core_dims + first_slice_dim)
        return np.searchsorted(plugin_dims, np.sort(core_dims))

    def set_fixed_dimensions(self, dims, values):
        """ Fix a data direction to the index in values list.

        :param list(int) dims: Directions to fix
        :param list(int) values: Indices of the fixed directions
        """
        slice_dirs = self.data_obj.get_slice_dimensions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set("fixed_dimensions", dims)
        self.meta_data.set("fixed_dimensions_values", values)
        self.__set_slice_dimensions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        self.__set_shape()

    def _get_fixed_dimensions(self):
        """ Get the fixed data directions and their indices

        :returns: Fixed directions and their associated values
        :rtype: list(list(int), list(int))
        """
        fixed = []
        values = []
        if 'fixed_dimensions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get("fixed_dimensions")
            values = self.meta_data.get("fixed_dimensions_values")
        return [fixed, values]

    def _get_data_slice_list(self, plist):
        """ Convert a plugin data slice list to a slice list for the whole
        dataset, i.e. add in any missing dimensions.
        """
        nDims = len(self.get_shape())
        all_dims = self.get_core_dimensions() + self.get_slice_dimension()
        extra_dims = all_dims[nDims:]
        dlist = list(plist)
        for i in extra_dims:
            dlist.insert(i, slice(None))
        return tuple(dlist)

    def _get_max_frames_process(self):
        """ Get the number of frames to process for each run of process_frames.

        If the number of frames is not divisible by the previewing ``chunk``
        value then amend the number of frames to gcd(frames, chunk)

        :returns: Number of frames to process
        :rtype: int
        """
        if self._plugin and self._plugin.chunk > 1:
            frame_chunk = self.meta_data.get("max_frames_process")
            chunk = self.data_obj.get_preview().get_starts_stops_steps(
                key='chunks')[self.get_slice_directions()[0]]
            self.meta_data.set('max_frames_process', gcd(frame_chunk, chunk))
        return self.meta_data.get("max_frames_process")

    def _get_max_frames_transfer(self):
        """ Get the number of frames to transfer for each run of
        process_frames. """
        return self.meta_data.get('max_frames_transfer')

    def _set_no_squeeze(self):
        self.no_squeeze = True

    def _get_no_squeeze(self):
        return self.no_squeeze

    def _get_max_frames_parameters(self):
        fixed, _ = self._get_fixed_dimensions()
        sdir = \
            [s for s in self.data_obj.get_slice_dimensions() if s not in fixed]
        shape = self.data_obj.get_shape()
        shape_before_tuning = self._get_shape_before_tuning()

        diff = len(shape) - len(shape_before_tuning)
        if diff:
            shape = shape_before_tuning
            sdir = sdir[:-diff]

        frames = np.prod([shape[d] for d in sdir])
        base_names = [p.__name__ for p in self._plugin.__class__.__bases__]
        processes = self.data_obj.exp.meta_data.get('processes')

        if 'GpuPlugin' in base_names:
            n_procs = len([n for n in processes if 'GPU' in n])
        else:
            n_procs = len(processes)

        f_per_p = np.ceil(frames / n_procs)
        params_dict = {
            'shape': shape,
            'sdir': sdir,
            'total_frames': frames,
            'mpi_procs': n_procs,
            'frames_per_process': f_per_p
        }
        return params_dict

    def __log_max_frames(self, mft, mfp, check=True):
        logging.debug("Setting max frames transfer for plugin %s to %d",
                      self._plugin, mft)
        logging.debug("Setting max frames process for plugin %s to %d",
                      self._plugin, mfp)
        self.meta_data.set('max_frames_process', mfp)
        if check:
            self.__check_distribution(mft)
        # (((total_frames/mft)/mpi_procs) % 1)

    def __check_distribution(self, mft):
        self.params = self._get_max_frames_parameters()
        warn_threshold = 0.85
        nprocs = self.params['mpi_procs']
        nframes = self.params['total_frames']
        temp = (((nframes / mft) / float(nprocs)) % 1)
        if temp != 0.0 and temp < warn_threshold:
            logging.warning(
                'UNEVEN FRAME DISTRIBUTION: shape %s, nframes %s, '
                'sdir %s, nprocs %s', self.params['shape'], nframes,
                self.params['sdir'], nprocs)
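
    # A worked example of the check above (hypothetical numbers): with
    # nframes = 100, mft = 10 and nprocs = 4, temp = ((100 / 10) / 4.0) % 1
    # = 2.5 % 1 = 0.5, which is non-zero and below the 0.85 threshold, so
    # the uneven-distribution warning fires.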

    def _set_padding_dict(self):
        if self.padding and not isinstance(self.padding, Padding):
            self.pad_dict = copy.deepcopy(self.padding)
            self.padding = Padding(self)
            for key in self.pad_dict.keys():
                getattr(self.padding, key)(self.pad_dict[key])

    def plugin_data_setup(self, pattern, nFrames, split=None):
        """ Setup the PluginData object.

        # add more information into here via a decorator!
        :param str pattern: A pattern name
        :param int nFrames: How many frames to process at a time.  Choose from\
            'single', 'multiple', 'fixed_multiple' or an integer (an integer \
            should only ever be passed in exceptional circumstances)
        """
        self.__set_pattern(pattern)
        chunks = \
            self.data_obj.get_preview().get_starts_stops_steps(key='chunks')

        if isinstance(nFrames, list):
            nFrames, self._frame_limit = nFrames

        self.__set_max_frames(nFrames)
        mft = self.meta_data.get('max_frames_transfer')
        if self._plugin and mft \
                and (chunks[self.data_obj.get_slice_dimensions()[0]] % mft):
            self._plugin.chunk = True
        self.__set_shape()
        self.split = split

    def __set_max_frames(self, nFrames):
        self.max_frames = nFrames
        self.__perform_checks(nFrames)
        td = self.data_obj._get_transport_data()
        mft, mft_shape = td._calc_max_frames_transfer(nFrames)
        self.meta_data.set('max_frames_transfer', mft)
        if mft:
            self._set_shape_transfer(mft_shape)
        mfp = td._calc_max_frames_process(nFrames)
        self.meta_data.set('max_frames_process', mfp)
        self.__log_max_frames(mft, mfp)

        # Retain the shape if the first slice dimension has length 1
        if mfp == 1 and nFrames == 'multiple':
            self._set_no_squeeze()

    def __perform_checks(self, nFrames):
        options = ['single', 'multiple']
        if not isinstance(nFrames, int) and nFrames not in options:
            e_str = "The value of nFrames is not recognised.  Please choose "
            "from 'single' and 'multiple' (or an integer in exceptional "
            "circumstances)."
            raise Exception(e_str)

    def get_frame_limit(self):
        return self._frame_limit
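
# A standalone sketch of the frame-amendment rule in _get_max_frames_process
# above: when the plugin is chunking, the number of frames processed per call
# is reduced to gcd(frames, chunk) so that frames and previewing chunks stay
# aligned.  The numbers below are hypothetical.

from math import gcd  # fractions.gcd in older Python versions

frame_chunk = 8  # current max_frames_process
chunk = 6        # previewing ``chunk`` value in the first slice dimension
assert gcd(frame_chunk, chunk) == 2  # amended number of frames per call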
Example #3
class Experiment(object):
    """
    One instance of this class is created at the beginning of the
    processing chain and remains until the end.  It holds the current data
    object and a dictionary containing all metadata.
    """

    def __init__(self, options):
        self.meta_data = MetaData(options)
        self.__set_system_params()
        self.checkpoint = Checkpointing(self)
        self.__meta_data_setup(options["process_file"])
        self.experiment_collection = {}
        self.index = {"in_data": {}, "out_data": {}}
        self.initial_datasets = None
        self.plugin = None
        self._transport = None
        self._barrier_count = 0

    def get(self, entry):
        """ Get the meta data dictionary. """
        return self.meta_data.get(entry)

    def __meta_data_setup(self, process_file):
        self.meta_data.plugin_list = PluginList()
        try:
            rtype = self.meta_data.get('run_type')
            if rtype == 'test':
                self.meta_data.plugin_list.plugin_list = \
                    self.meta_data.get('plugin_list')
            else:
                raise Exception('The run_type is unknown in the Experiment '
                                'class.')
        except KeyError:
            template = self.meta_data.get('template')
            self.meta_data.plugin_list._populate_plugin_list(process_file,
                                                             template=template)

    def create_data_object(self, dtype, name, override=True):
        """ Create a data object.

        Plugin developers should use this method in loaders only.

        :param str dtype: either "in_data" or "out_data".
        """
        if name not in self.index[dtype].keys() or override:
            self.index[dtype][name] = Data(name, self)
            data_obj = self.index[dtype][name]
            data_obj._set_transport_data(self.meta_data.get('transport'))
        return self.index[dtype][name]

    def _experiment_setup(self, transport):
        """ Setup an experiment collection.
        """
        n_loaders = self.meta_data.plugin_list._get_n_loaders()
        plugin_list = self.meta_data.plugin_list
        plist = plugin_list.plugin_list
        self.__set_transport(transport)
        # load the loader plugins
        self._set_loaders()
        # load the saver plugin and save the plugin list
        self.experiment_collection = {'plugin_dict': [],
                                      'datasets': []}
        self._barrier()
        self._check_checkpoint()
        self._barrier()
        checkpoint = self.meta_data.get('checkpoint')
        if self.meta_data.get('process') == \
                len(self.meta_data.get('processes'))-1 and not checkpoint:
            plugin_list._save_plugin_list(self.meta_data.get('nxs_filename'))
            # links the input data to the nexus file
            self._add_input_data_to_nxs_file(transport)
        # Barrier 13
        self._barrier()

        n_plugins = plugin_list._get_n_processing_plugins()
        count = 0
        # first run through of the plugin setup methods
        for plugin_dict in plist[n_loaders:n_loaders+n_plugins]:
            data = self.__plugin_setup(plugin_dict, count)
            self.experiment_collection['datasets'].append(data)
            self.experiment_collection['plugin_dict'].append(plugin_dict)
            self._merge_out_data_to_in()
            count += 1
        self._reset_datasets()

    def __set_transport(self, transport):
        self._transport = transport

    def _get_transport(self):
        return self._transport

    def __set_system_params(self):
        sys_file = self.meta_data.get('system_params')
        import sys
        if sys_file is None:
            # look in conda environment to see which version is being used
            savu_path = sys.modules['savu'].__path__[0]
            sys_files = os.path.join(
                    os.path.dirname(savu_path), 'system_files')
            subdirs = os.listdir(sys_files)
            sys_folder = 'dls' if len(subdirs) > 1 else subdirs[0]
            fname = 'system_parameters.yml'
            sys_file = os.path.join(sys_files, sys_folder, fname)
        logging.info('Using the system parameters file: %s', sys_file)
        self.meta_data.set('system_params', yaml.read_yaml(sys_file))

    def _check_checkpoint(self):
        # if checkpointing has been set but the nxs file doesn't contain an
        # entry then remove checkpointing (as the previous run didn't get far
        # enough to require it).
        if self.meta_data.get('checkpoint'):
            with h5py.File(self.meta_data.get('nxs_filename'), 'r') as f:
                if 'entry' not in f:
                    self.meta_data.set('checkpoint', None)

    def _set_loaders(self):
        n_loaders = self.meta_data.plugin_list._get_n_loaders()
        plugin_list = self.meta_data.plugin_list.plugin_list
        for i in range(n_loaders):
            pu.plugin_loader(self, plugin_list[i])
        self.initial_datasets = copy.deepcopy(self.index['in_data'])

    def _add_input_data_to_nxs_file(self, transport):
        # save the loaded data to file
        h5 = Hdf5Utils(self)
        for name, data in self.index['in_data'].items():
            self.meta_data.set(['link_type', name], 'input_data')
            self.meta_data.set(['group_name', name], name)
            self.meta_data.set(['filename', name], data.backing_file)
            transport._populate_nexus_file(data)
            h5._link_datafile_to_nexus_file(data)

    def _reset_datasets(self):
        self.index['in_data'] = self.initial_datasets

    def __plugin_setup(self, plugin_dict, count):
        """ Determine plugin specific information.
        """
        plugin_id = plugin_dict["id"]
        logging.debug("Loading plugin %s", plugin_id)
        # Run main_setup method
        plugin = pu.plugin_loader(self, plugin_dict)
        plugin._revert_preview(plugin.get_in_datasets())
        # Populate the metadata
        plugin._clean_up()
        data = self.index['out_data'].copy()
        return data

    def _get_experiment_collection(self):
        return self.experiment_collection

    def _set_experiment_for_current_plugin(self, count):
        datasets_list = self.meta_data.plugin_list._get_datasets_list()[count:]
        exp_coll = self._get_experiment_collection()
        self.index['out_data'] = exp_coll['datasets'][count]
        if datasets_list:
            self._get_current_and_next_patterns(datasets_list)
        self.meta_data.set('nPlugin', count)

    def _get_current_and_next_patterns(self, datasets_lists):
        """ Get the current and next patterns associated with a dataset
        throughout the processing chain.
        """
        current_datasets = datasets_lists[0]
        patterns_list = {}
        for current_data in current_datasets['out_datasets']:
            current_name = current_data['name']
            current_pattern = current_data['pattern']
            next_pattern = self.__find_next_pattern(datasets_lists[1:],
                                                    current_name)
            patterns_list[current_name] = \
                {'current': current_pattern, 'next': next_pattern}
        self.meta_data.set('current_and_next', patterns_list)

    def __find_next_pattern(self, datasets_lists, current_name):
        next_pattern = []
        for next_data_list in datasets_lists:
            for next_data in next_data_list['in_datasets']:
                if next_data['name'] == current_name:
                    next_pattern = next_data['pattern']
                    return next_pattern
        return next_pattern

    def _set_nxs_filename(self):
        folder = self.meta_data.get('out_path')
        fname = self.meta_data.get('datafile_name') + '_processed.nxs'
        filename = os.path.join(folder, fname)
        self.meta_data.set('nxs_filename', filename)

        if self.meta_data.get('process') == 1:
            if self.meta_data.get('bllog'):
                log_folder_name = self.meta_data.get('bllog')
                with open(log_folder_name, 'a') as log_folder:
                    log_folder.write(os.path.abspath(filename) + '\n')

    def _clear_data_objects(self):
        self.index["out_data"] = {}
        self.index["in_data"] = {}

    def _merge_out_data_to_in(self):
        for key, data in self.index["out_data"].iteritems():
            if data.remove is False:
                self.index['in_data'][key] = data
        self.index["out_data"] = {}

    def _finalise_experiment_for_current_plugin(self):
        finalise = {}
        # populate nexus file with out_dataset information and determine which
        # datasets to remove from the framework.
        finalise['remove'] = []
        finalise['keep'] = []

        for key, data in self.index['out_data'].items():
            if data.remove is True:
                finalise['remove'].append(data)
            else:
                finalise['keep'].append(data)

        # find in datasets to replace
        finalise['replace'] = []
        for out_name in self.index['out_data'].keys():
            if out_name in self.index['in_data'].keys():
                finalise['replace'].append(self.index['in_data'][out_name])

        return finalise

    def _reorganise_datasets(self, finalise):
        # unreplicate replicated in_datasets
        self.__unreplicate_data()

        # delete all datasets for removal
        for data in finalise['remove']:
            del self.index["out_data"][data.data_info.get('name')]

        # Add remaining output datasets to input datasets
        for name, data in self.index['out_data'].items():
            data.get_preview().set_preview([])
            self.index["in_data"][name] = copy.deepcopy(data)
        self.index['out_data'] = {}

    def __unreplicate_data(self):
        in_data_list = self.index['in_data']
        from savu.data.data_structures.data_types.replicate import Replicate
        for in_data in in_data_list.values():
            if isinstance(in_data.data, Replicate):
                in_data.data = in_data.data.reset()

    def _set_all_datasets(self, name):
        data_names = []
        for key in self.index["in_data"].keys():
            if 'itr_clone' not in key:
                data_names.append(key)
        return data_names

    def _barrier(self, communicator=MPI.COMM_WORLD, msg=''):
        comm_dict = {'comm': communicator}
        if self.meta_data.get('mpi') is True:
            logging.debug("Barrier %d: %d processes expected: %s",
                          self._barrier_count, communicator.size, msg)
            comm_dict['comm'].barrier()
        self._barrier_count += 1

    def log(self, log_tag, log_level=logging.DEBUG):
        """
        Log the contents of the experiment at the specified level
        """
        logging.log(log_level, "Experimental Parameters for %s", log_tag)
        for key, value in self.index["in_data"].items():
            logging.log(log_level, "in data (%s) shape = %s", key,
                        value.get_shape())
        for key, value in self.index["out_data"].items():
            logging.log(log_level, "out data (%s) shape = %s", key,
                        value.get_shape())
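
# A minimal standalone sketch of the counted-barrier pattern used by
# Experiment._barrier above, assuming mpi4py is installed; run it with e.g.
# ``mpirun -n 2 python barrier_sketch.py`` (the script name is hypothetical).

from mpi4py import MPI

def counted_barrier(count, communicator=MPI.COMM_WORLD, msg=''):
    # Every process must reach the same numbered barrier before any may pass.
    print("Barrier %d: %d processes expected: %s"
          % (count, communicator.size, msg))
    communicator.barrier()
    return count + 1

barrier_count = 0
barrier_count = counted_barrier(barrier_count, msg='after setup')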
Example #4
class PluginData(object):
    """ The PluginData class contains plugin specific information about a Data
    object for the duration of a plugin.  An instance of the class is
    encapsulated inside the Data object during the plugin run
    """
    def __init__(self, data_obj, plugin=None):
        self.data_obj = data_obj
        self._preview = None
        self.data_obj._set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = None
        self.pad_dict = None
        self.shape = None
        self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []
        self._plugin = plugin
        self.fixed_dims = True
        self.split = None
        self.boundary_padding = None
        self.no_squeeze = False
        self.pre_tuning_shape = None
        self._frame_limit = None
        self._increase_rank = 0

    def _get_preview(self):
        return self._preview

    def get_total_frames(self):
        """ Get the total number of frames to process (all MPI processes).

        :returns: Number of frames
        :rtype: int
        """
        temp = 1
        slice_dir = \
            self.data_obj.get_data_patterns()[
                self.get_pattern_name()]["slice_dims"]
        for tslice in slice_dir:
            temp *= self.data_obj.get_shape()[tslice]
        return temp

    def __set_pattern(self, name, first_sdim=None):
        """ Set the pattern related information in the meta data dict.
        """
        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set("name", name)
        self.meta_data.set("core_dims", pattern['core_dims'])
        self.__set_slice_dimensions(first_sdim=first_sdim)

    def get_pattern_name(self):
        """ Get the pattern name.

        :returns: the pattern name
        :rtype: str
        """
        try:
            name = self.meta_data.get("name")
            return name
        except KeyError:
            raise Exception("The pattern name has not been set.")

    def get_pattern(self):
        """ Get the current pattern.

        :returns: dict of the pattern name against the pattern.
        :rtype: dict
        """
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def __set_shape(self):
        """ Set the shape of the plugin data processing chunk.
        """
        core_dir = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()
        dirs = list(set(core_dir + (slice_dir[0], )))
        slice_idx = dirs.index(slice_dir[0])
        dshape = self.data_obj.get_shape()
        shape = []
        for core in set(core_dir):
            shape.append(dshape[core])
        self.__set_core_shape(tuple(shape))

        mfp = self._get_max_frames_process()
        if mfp > 1 or self._get_no_squeeze():
            shape.insert(slice_idx, mfp)
        self.shape = tuple(shape)

    def _set_shape_transfer(self, slice_size):
        dshape = self.data_obj.get_shape()
        shape_before_tuning = self._get_shape_before_tuning()
        add = [1] * (len(dshape) - len(shape_before_tuning))
        slice_size = slice_size + add

        core_dir = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()
        shape = [None] * len(dshape)
        for dim in core_dir:
            shape[dim] = dshape[dim]
        i = 0
        for dim in slice_dir:
            shape[dim] = slice_size[i]
            i += 1
        return tuple(shape)

    def __get_slice_size(self, mft):
        """ Calculate the number of frames transfer in each dimension given
            mft. """
        dshape = list(self.data_obj.get_shape())

        if 'fixed_dimensions' in list(self.meta_data.get_dictionary().keys()):
            fixed_dims = self.meta_data.get('fixed_dimensions')
            for d in fixed_dims:
                dshape[d] = 1

        dshape = [dshape[i] for i in self.meta_data.get('slice_dims')]
        size_list = [1] * len(dshape)
        i = 0

        while (mft > 1 and i < len(size_list)):
            size_list[i] = min(dshape[i], mft)
            mft //= np.prod(size_list) if np.prod(size_list) > 1 else 1
            i += 1

        self.meta_data.set('size_list', size_list)
        return size_list
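
    # A worked example (hypothetical numbers): with slice-dimension lengths
    # dshape = [4, 3] and mft = 8, the loop first fills the fastest slice
    # dimension, size_list[0] = min(4, 8) = 4, leaving mft = 8 // 4 = 2, and
    # then size_list[1] = min(3, 2) = 2, giving size_list = [4, 2]: transfer
    # four frames along the first slice dimension by two along the second.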

    def set_bytes_per_frame(self):
        """ Return the size of a single frame in bytes. """
        nBytes = self.data_obj.get_itemsize()
        dims = list(self.get_pattern().values())[0]['core_dims']
        frame_shape = [self.data_obj.get_shape()[d] for d in dims]
        b_per_f = np.prod(frame_shape) * nBytes
        return frame_shape, b_per_f

    def get_shape(self):
        """ Get the shape of the data (without padding) that is passed to the
        plugin process_frames method.
        """
        return self.shape

    def _set_padded_shape(self):
        pass

    def get_padded_shape(self):
        """ Get the shape of the data (with padding) that is passed to the
        plugin process_frames method.
        """
        return self.shape

    def get_shape_transfer(self):
        """ Get the shape of the plugin data to be transferred each time.
        """
        return self.meta_data.get('transfer_shape')

    def __set_core_shape(self, shape):
        """ Set the core shape to hold only the shape of the core dimensions
        """
        self.core_shape = shape

    def get_core_shape(self):
        """ Get the shape of the core dimensions only.

        :returns: shape of core dimensions
        :rtype: tuple
        """
        return self.core_shape

    def _set_shape_before_tuning(self, shape):
        """ Set the shape of the full dataset used during each run of the \
        plugin (i.e. ignore extra dimensions due to parameter tuning). """
        self.pre_tuning_shape = shape

    def _get_shape_before_tuning(self):
        """ Return the shape of the full dataset used during each run of the \
        plugin (i.e. ignore extra dimensions due to parameter tuning). """
        return self.pre_tuning_shape if self.pre_tuning_shape else\
            self.data_obj.get_shape()

    def __check_dimensions(self, indices, core_dir, slice_dir, nDims):
        if len(indices) != len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")

        if (len(core_dir) + len(slice_dir)) != nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def __set_slice_dimensions(self, first_sdim=None):
        """ Set the slice dimensions in the pluginData meta data dictionary.\
        Reorder pattern slice_dims to ensure first_sdim is at the front.
        """
        pattern = self.data_obj.get_data_patterns()[self.get_pattern_name()]
        slice_dims = pattern['slice_dims']

        if first_sdim:
            slice_dims = list(slice_dims)
            first_sdim = \
                self.data_obj.get_data_dimension_by_axis_label(first_sdim)
            slice_dims.insert(0, slice_dims.pop(slice_dims.index(first_sdim)))
            pattern['slice_dims'] = tuple(slice_dims)

        self.meta_data.set('slice_dims', tuple(slice_dims))

    def get_slice_dimension(self):
        """
        Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()[0]
        return list(set(core_dirs + (slice_dir, ))).index(slice_dir)

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """
        Return the dimension of the data in the plugin that has the specified
        axis label.
        """
        label_dim = self.data_obj.get_data_dimension_by_axis_label(
            label, contains=contains)
        plugin_dims = self.data_obj.get_core_dimensions()
        if self._get_max_frames_process() > 1 or self.max_frames == 'multiple':
            plugin_dims += (self.get_slice_dimension(), )
        return list(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):  # should this function be deleted?
        """
        Reorder the slice dimensions.  The fastest changing slice dimension
        will always be the first one stated in the pattern key ``slice_dir``.
        The input param is a tuple stating the desired order of slicing
        dimensions relative to the current order.
        """
        slice_dirs = self.data_obj.get_slice_dimensions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specifed.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        self.get_pattern()['slice_dir'] = new_slice_dirs

    def get_core_dimensions(self):
        """
        Return the position of the core dimensions in relation to the data
        handed to the plugin.
        """
        core_dims = self.data_obj.get_core_dimensions()
        first_slice_dim = (self.data_obj.get_slice_dimensions()[0], )
        plugin_dims = np.sort(core_dims + first_slice_dim)
        return np.searchsorted(plugin_dims, np.sort(core_dims))
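
    # A worked example (hypothetical dimensions): with core_dims = (1, 2) and
    # first slice dimension 0, plugin_dims = np.sort((1, 2) + (0,)) gives
    # [0, 1, 2], and np.searchsorted([0, 1, 2], [1, 2]) returns [1, 2]: the
    # core dimensions sit at positions 1 and 2 of the data handed to the
    # plugin.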

    def set_fixed_dimensions(self, dims, values):
        """ Fix a data direction to the index in values list.

        :param list(int) dims: Directions to fix
        :param list(int) values: Indices of the fixed directions
        """
        slice_dirs = self.data_obj.get_slice_dimensions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set("fixed_dimensions", dims)
        self.meta_data.set("fixed_dimensions_values", values)
        self.__set_slice_dimensions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        #self.__set_shape()

    def _get_fixed_dimensions(self):
        """ Get the fixed data directions and their indices

        :returns: Fixed directions and their associated values
        :rtype: list(list(int), list(int))
        """
        fixed = []
        values = []
        if 'fixed_dimensions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get("fixed_dimensions")
            values = self.meta_data.get("fixed_dimensions_values")
        return [fixed, values]

    def _get_data_slice_list(self, plist):
        """ Convert a plugin data slice list to a slice list for the whole
        dataset, i.e. add in any missing dimensions.
        """
        nDims = len(self.get_shape())
        all_dims = self.get_core_dimensions() + self.get_slice_dimension()
        extra_dims = all_dims[nDims:]
        dlist = list(plist)
        for i in extra_dims:
            dlist.insert(i, slice(None))
        return tuple(dlist)

    def _get_max_frames_process(self):
        """ Get the number of frames to process for each run of process_frames.

        If the number of frames is not divisible by the previewing ``chunk``
        value then amend the number of frames to gcd(frames, chunk)

        :returns: Number of frames to process
        :rtype: int
        """
        if self._plugin and self._plugin.chunk > 1:
            frame_chunk = self.meta_data.get("max_frames_process")
            chunk = self.data_obj.get_preview().get_starts_stops_steps(
                key='chunks')[self.get_slice_directions()[0]]
            self.meta_data.set('max_frames_process', gcd(frame_chunk, chunk))
        return self.meta_data.get("max_frames_process")

    def _get_max_frames_transfer(self):
        """ Get the number of frames to transfer for each run of
        process_frames. """
        return self.meta_data.get('max_frames_transfer')

    def _set_no_squeeze(self):
        self.no_squeeze = True

    def _get_no_squeeze(self):
        return self.no_squeeze

    def _set_rank_inc(self, n):
        """ Increase the rank of the array passed to the plugin by n.
        
        :param int n: Rank increment.
        """
        self._increase_rank = n

    def _get_rank_inc(self):
        """ Return the increased rank value
        
        :returns: Rank increment
        :rtype: int
        """
        return self._increase_rank

    def _set_meta_data(self):
        fixed, _ = self._get_fixed_dimensions()
        sdir = \
            [s for s in self.data_obj.get_slice_dimensions() if s not in fixed]
        shape = self.data_obj.get_shape()
        shape_before_tuning = self._get_shape_before_tuning()

        diff = len(shape) - len(shape_before_tuning)
        if diff:
            shape = shape_before_tuning
            sdir = sdir[:-diff]

        if 'fix_total_frames' in list(self.meta_data.get_dictionary().keys()):
            frames = self.meta_data.get('fix_total_frames')
        else:
            frames = np.prod([shape[d] for d in sdir])

        base_names = [p.__name__ for p in self._plugin.__class__.__bases__]
        processes = self.data_obj.exp.meta_data.get('processes')

        if 'GpuPlugin' in base_names:
            n_procs = len([n for n in processes if 'GPU' in n])
        else:
            n_procs = len(processes)

        # Fixing f_per_p to be just the first slice dimension for now due to
        # slow performance from HDF5 when not slicing multiple dimensions
        # concurrently
        #f_per_p = np.ceil(frames/n_procs)
        f_per_p = np.ceil(shape[sdir[0]] / n_procs)
        self.meta_data.set('shape', shape)
        self.meta_data.set('sdir', sdir)
        self.meta_data.set('total_frames', frames)
        self.meta_data.set('mpi_procs', n_procs)
        self.meta_data.set('frames_per_process', f_per_p)
        frame_shape, b_per_f = self.set_bytes_per_frame()
        self.meta_data.set('bytes_per_frame', b_per_f)
        self.meta_data.set('bytes_per_process', b_per_f * f_per_p)
        self.meta_data.set('frame_shape', frame_shape)

    def __log_max_frames(self, mft, mfp, check=True):
        logging.debug("Setting max frames transfer for plugin %s to %d",
                      self._plugin, mft)
        logging.debug("Setting max frames process for plugin %s to %d",
                      self._plugin, mfp)
        self.meta_data.set('max_frames_process', mfp)
        if check:
            self.__check_distribution(mft)
        # (((total_frames/mft)/mpi_procs) % 1)

    def __check_distribution(self, mft):
        warn_threshold = 0.85
        nprocs = self.meta_data.get('mpi_procs')
        nframes = self.meta_data.get('total_frames')
        temp = (((nframes / mft) / float(nprocs)) % 1)
        if temp != 0.0 and temp < warn_threshold:
            shape = self.meta_data.get('shape')
            sdir = self.meta_data.get('sdir')
            logging.warning(
                'UNEVEN FRAME DISTRIBUTION: shape %s, nframes %s, '
                'sdir %s, nprocs %s', shape, nframes, sdir, nprocs)

    def _set_padding_dict(self):
        if self.padding and not isinstance(self.padding, Padding):
            self.pad_dict = copy.deepcopy(self.padding)
            self.padding = Padding(self)
            for key in list(self.pad_dict.keys()):
                getattr(self.padding, key)(self.pad_dict[key])

    def plugin_data_setup(self,
                          pattern,
                          nFrames,
                          split=None,
                          slice_axis=None,
                          getall=None):
        """ Setup the PluginData object.

        :param str pattern: A pattern name
        :param int nFrames: How many frames to process at a time.  Choose from
            'single', 'multiple', 'fixed_multiple' or an integer (an integer
            should only ever be passed in exceptional circumstances)
        :keyword str slice_axis: An axis label associated with the fastest
            changing (first) slice dimension.
        :keyword list[pattern, axis_label] getall: A list of two values.  If
            the requested pattern doesn't exist then use all of "axis_label"
            dimension of "pattern" as this is equivalent to one slice of the
            original pattern.
        """

        if pattern not in self.data_obj.get_data_patterns() and getall:
            pattern, nFrames = self.__set_getall_pattern(getall, nFrames)

        # slice_axis is first slice dimension
        self.__set_pattern(pattern, first_sdim=slice_axis)
        if isinstance(nFrames, list):
            nFrames, self._frame_limit = nFrames
        self.max_frames = nFrames
        self.split = split

    def __set_getall_pattern(self, getall, nFrames):
        """ Set framework changes required to get all of a pattern of lower
        rank.
        """
        pattern, slice_axis = getall
        dim = self.data_obj.get_data_dimension_by_axis_label(slice_axis)
        # ensure data remains the same shape when 'getall' dim has length 1
        self._set_no_squeeze()
        if nFrames == 'multiple' or (isinstance(nFrames, int) and nFrames > 1):
            self._set_rank_inc(1)
        nFrames = self.data_obj.get_shape()[dim]
        return pattern, nFrames

    def plugin_data_transfer_setup(self, copy=None, calc=None):
        """ Set up the plugin data transfer frame parameters.
        If copy=pData (another PluginData instance) then copy """
        chunks = \
            self.data_obj.get_preview().get_starts_stops_steps(key='chunks')

        if not copy and not calc:
            mft, mft_shape, mfp = self._calculate_max_frames()
        elif calc:
            max_mft = calc.meta_data.get('max_frames_transfer')
            max_mfp = calc.meta_data.get('max_frames_process')
            nProc = int(np.ceil(max_mft / float(max_mfp)))
            mfp = 1 if self.max_frames == 'single' else self.max_frames
            mft = nProc * mfp
            mft_shape = self._set_shape_transfer(self.__get_slice_size(mft))
        elif copy:
            mft = copy._get_max_frames_transfer()
            mft_shape = self._set_shape_transfer(self.__get_slice_size(mft))
            mfp = copy._get_max_frames_process()

        self.__set_max_frames(mft, mft_shape, mfp)

        if self._plugin and mft \
                and (chunks[self.data_obj.get_slice_dimensions()[0]] % mft):
            self._plugin.chunk = True
        self.__set_shape()

    def _calculate_max_frames(self):
        nFrames = self.max_frames
        self.__perform_checks(nFrames)
        td = self.data_obj._get_transport_data()
        mft, size_list = td._calc_max_frames_transfer(nFrames)
        self.meta_data.set('size_list', size_list)
        mfp = td._calc_max_frames_process(nFrames)
        mft_shape = None
        if mft:
            mft_shape = self._set_shape_transfer(list(size_list))
        return mft, mft_shape, mfp

    def __set_max_frames(self, mft, mft_shape, mfp):
        self.meta_data.set('max_frames_transfer', mft)
        self.meta_data.set('transfer_shape', mft_shape)
        self.meta_data.set('max_frames_process', mfp)
        self.__log_max_frames(mft, mfp)
        # Retain the shape if the first slice dimension has length 1
        if mfp == 1 and self.max_frames == 'multiple':
            self._set_no_squeeze()

    def _get_plugin_data_size_params(self):
        nBytes = self.data_obj.get_itemsize()
        frame_shape = self.meta_data.get('frame_shape')
        total_frames = self.meta_data.get('total_frames')
        tbytes = nBytes * np.prod(frame_shape) * total_frames

        params = {
            'nBytes': nBytes,
            'frame_shape': frame_shape,
            'total_frames': total_frames,
            'transfer_bytes': tbytes
        }
        return params

    def __perform_checks(self, nFrames):
        options = ['single', 'multiple']
        if not np.issubdtype(type(nFrames),
                             np.int64) and nFrames not in options:
            e_str = (
                "The value of nFrames is not recognised.  Please choose " +
                "from 'single' and 'multiple' (or an integer in exceptional " +
                "circumstances).")
            raise Exception(e_str)

    def get_frame_limit(self):
        return self._frame_limit

    def get_current_frame_idx(self):
        """ Returns the index of the frames currently being processed.
        """
        global_index = self._plugin.get_global_frame_index()
        count = self._plugin.get_process_frames_counter()
        mfp = self.meta_data.get('max_frames_process')
        start = global_index[count] * mfp
        index = np.arange(start, start + mfp)
        nFrames = self.get_total_frames()
        index[index >= nFrames] = nFrames - 1
        return index
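
# A standalone sketch of the index arithmetic in get_current_frame_idx above,
# with hypothetical numbers: the block with global index 12 of mfp = 8 frames
# in a 100-frame dataset, with out-of-range indices clamped to the last frame.

import numpy as np

mfp, nFrames, block = 8, 100, 12
start = block * mfp
index = np.arange(start, start + mfp)
index[index >= nFrames] = nFrames - 1
assert list(index) == [96, 97, 98, 99, 99, 99, 99, 99]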
Example #5
class Data(DataCreate):
    """The Data class dynamically inherits from transport specific data class
    and holds the data array, along with associated information.
    """

    def __init__(self, name, exp):
        super(Data, self).__init__(name)
        self.meta_data = MetaData()
        self.pattern_list = self.__get_available_pattern_list()
        self.data_info = MetaData()
        self.__initialise_data_info(name)
        self._preview = Preview(self)
        self.exp = exp
        self.group_name = None
        self.group = None
        self._plugin_data_obj = None
        self.raw = None
        self.backing_file = None
        self.data = None
        self.next_shape = None
        self.orig_shape = None
        self.previous_pattern = None
        self.transport_data = None

    def __initialise_data_info(self, name):
        """ Initialise entries in the data_info meta data.
        """
        self.data_info.set('name', name)
        self.data_info.set('data_patterns', {})
        self.data_info.set('shape', None)
        self.data_info.set('nDims', None)

    def _set_plugin_data(self, plugin_data_obj):
        """ Encapsulate a PluginData object.
        """
        self._plugin_data_obj = plugin_data_obj

    def _clear_plugin_data(self):
        """ Set encapsulated PluginData object to None.
        """
        self._plugin_data_obj = None

    def _get_plugin_data(self):
        """ Get encapsulated PluginData object.
        """
        if self._plugin_data_obj is not None:
            return self._plugin_data_obj
        else:
            raise Exception("There is no PluginData object associated with "
                            "the Data object.")

    def get_preview(self):
        """ Get the Preview instance associated with the data object
        """
        return self._preview

    def _set_transport_data(self, transport):
        """ Import and instantiate the data transport mechanism.

        :param str transport: name of the transport mechanism, e.g. 'hdf5'
        """
        transport_data = "savu.data.transport_data." + transport + \
                         "_transport_data"
        transport_data = cu.import_class(transport_data)
        self.transport_data = transport_data(self)
        self.data_info.set('transport', transport)
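
    # A sketch of the dynamic import above, assuming the 'hdf5' transport is
    # configured:
    #
    #     transport = 'hdf5'
    #     -> "savu.data.transport_data.hdf5_transport_data"
    #
    # cu.import_class loads that module path and the resulting class is
    # instantiated with this Data object.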

    def _get_transport_data(self):
        return self.transport_data

    def __deepcopy__(self, memo):
        """ Copy the data object.
        """
        name = self.data_info.get('name')
        return dsu._deepcopy_data_object(self, Data(name, self.exp))

    def get_data_patterns(self):
        """ Get data patterns associated with this data object.

        :returns: A dictionary of associated patterns.
        :rtype: dict
        """
        return self.data_info.get('data_patterns')

    def _set_previous_pattern(self, pattern):
        self.previous_pattern = pattern

    def get_previous_pattern(self):
        return self.previous_pattern

    def set_shape(self, shape):
        """ Set the dataset shape.
        """
        self.data_info.set('shape', shape)
        self.__check_dims()

    def set_original_shape(self, shape):
        """ Set the original data shape before previewing
        """
        self.orig_shape = shape
        self.set_shape(shape)

    def get_shape(self):
        """ Get the dataset shape

        :returns: data shape
        :rtype: tuple
        """
        shape = self.data_info.get('shape')
        return shape

    def __check_dims(self):
        """ Check that the length of the ``shape`` entry in the data_info
        meta_data dictionary matches the ``nDims`` entry.
        """
        nDims = self.data_info.get("nDims")
        shape = self.data_info.get('shape')
        if nDims:
            if len(shape) != nDims:
                error_msg = ("The number of axis labels, %d, does not "
                             "coincide with the number of data "
                             "dimensions %d." % (nDims, len(shape)))
                raise Exception(error_msg)

    def _set_name(self, name):
        self.data_info.set('name', name)

    def get_name(self):
        """ Get data name.

        :returns: the name associated with the dataset
        :rtype: str
        """
        return self.data_info.get('name')

    def __get_available_pattern_list(self):
        """ Get a list of ALL pattern names that are currently allowed in the
        framework.
        """
        pattern_list = dsu.get_available_pattern_types()
        return pattern_list

    def add_pattern(self, dtype, **kwargs):
        """ Add a pattern.

        :params str dtype: The *type* of pattern to add, which can be anything
            from :data:`savu.data.data_structures.utils.pattern_list`
        :keyword tuple core_dir: Dimension indices of core dimensions
        :keyword tuple slice_dir: Dimension indices of slice dimensions
        """
        if dtype in self.pattern_list:
            nDims = 0
            for args in kwargs:
                nDims += len(kwargs[args])
                self.data_info.set(['data_patterns', dtype, args],
                                   kwargs[args])

            self.__convert_pattern_directions(dtype)
            if self.get_shape():
                diff = len(self.get_shape()) - nDims
                if diff:
                    pattern = {dtype: self.get_data_patterns()[dtype]}
                    self._set_data_patterns(pattern)
                    nDims += diff
            try:
                if nDims != self.data_info.get("nDims"):
                    actualDims = self.data_info.get('nDims')
                    err_msg = ("The pattern %s has an incorrect number of "
                               "dimensions: %d required but %d specified."
                               % (dtype, actualDims, nDims))
                    raise Exception(err_msg)
            except KeyError:
                self.data_info.set('nDims', nDims)
        else:
            raise Exception("The data pattern '%s'does not exist. Please "
                            "choose from the following list: \n'%s'",
                            dtype, str(self.pattern_list))

    def add_volume_patterns(self, x, y, z):
        """ Adds 3D volume patterns

        :params int x: dimension to be associated with x-axis
        :params int y: dimension to be associated with y-axis
        :params int z: dimension to be associated with z-axis
        """
        self.add_pattern("VOLUME_YZ", **self.__get_dirs_for_volume(y, z, x))
        self.add_pattern("VOLUME_XZ", **self.__get_dirs_for_volume(x, z, y))
        self.add_pattern("VOLUME_XY", **self.__get_dirs_for_volume(x, y, z))

    def __get_dirs_for_volume(self, dim1, dim2, sdir):
        """ Calculate core_dir and slice_dir for a 3D volume pattern.
        """
        all_dims = range(len(self.get_shape()))
        vol_dict = {}
        vol_dict['core_dims'] = (dim1, dim2)
        slice_dir = [sdir]
        # *** need to add this for other patterns
        for ddir in all_dims:
            if ddir not in [dim1, dim2, sdir]:
                slice_dir.append(ddir)
        vol_dict['slice_dims'] = tuple(slice_dir)
        return vol_dict
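
    # A worked sketch, assuming a 4D dataset: __get_dirs_for_volume(1, 2, 0)
    # returns {'core_dims': (1, 2), 'slice_dims': (0, 3)} -- the stated slice
    # dimension comes first, followed by any remaining dimensions.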

    def set_axis_labels(self, *args):
        """ Set the axis labels associated with each data dimension.

        :arg str: Each arg should be of the form ``name.unit``. If ``name`` is\
        a data_obj.meta_data entry, it will be output to the final .nxs file.
        """
        self.data_info.set('nDims', len(args))
        axis_labels = []
        for arg in args:
            if isinstance(arg, dict):
                axis_labels.append(arg)
            else:
                try:
                    axis = arg.split('.')
                    axis_labels.append({axis[0]: axis[1]})
                except:
                    # data arrives here, but that may be an error
                    pass
        self.data_info.set('axis_labels', axis_labels)
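
    # A usage sketch with hypothetical labels: each 'name.unit' string becomes
    # a single-entry dict, so
    #
    #     data_obj.set_axis_labels('rotation_angle.degrees',
    #                              'detector_y.pixels', 'detector_x.pixels')
    #
    # sets nDims to 3 and stores the axis_labels entry
    # [{'rotation_angle': 'degrees'}, {'detector_y': 'pixels'},
    #  {'detector_x': 'pixels'}].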

    def get_axis_labels(self):
        """ Get axis labels.

        :returns: Axis labels
        :rtype: list(dict)
        """
        return self.data_info.get('axis_labels')

    def get_data_dimension_by_axis_label(self, name, contains=False):
        """ Get the dimension of the data associated with a particular
        axis_label.

        :param str name: The name of the axis_label
        :keyword bool contains: Set this flag to true if the name is only part
            of the axis_label name
        :returns: The associated axis number
        :rtype: int
        """
        axis_labels = self.data_info.get('axis_labels')
        for i in range(len(axis_labels)):
            if contains is True:
                for names in axis_labels[i].keys():
                    if name in names:
                        return i
            else:
                if name in axis_labels[i].keys():
                    return i
        raise Exception("Cannot find the specifed axis label.")

    def _finalise_patterns(self):
        """ Adds a main axis (fastest changing) to SINOGRAM and PROJECTON
        patterns.
        """
        check = 0
        check += self.__check_pattern('SINOGRAM')
        check += self.__check_pattern('PROJECTION')

        if check == 2 and len(self.get_shape()) > 2:
            self.__set_main_axis('SINOGRAM')
            self.__set_main_axis('PROJECTION')
        elif check == 1:
            pass

    def __check_pattern(self, pattern_name):
        """ Check if a pattern exists.
        """
        patterns = self.get_data_patterns()
        try:
            patterns[pattern_name]
        except KeyError:
            return 0
        return 1

    def __convert_pattern_directions(self, dtype):
        """ Replace negative indices in pattern kwargs.
        """
        pattern = self.get_data_patterns()[dtype]
        if 'main_dir' in pattern.keys():
            del pattern['main_dir']

        nDims = sum([len(i) for i in pattern.values()])
        for p in pattern:
            ddirs = pattern[p]
            pattern[p] = self._non_negative_directions(ddirs, nDims)

    def _non_negative_directions(self, ddirs, nDims):
        """ Replace negative indexing values with positive counterparts.

        :params tuple(int) ddirs: data dimension indices
        :params int nDims: The number of data dimensions
        :returns: non-negative data dimension indices
        :rtype: tuple(int)
        """
        index = [i for i in range(len(ddirs)) if ddirs[i] < 0]
        list_ddirs = list(ddirs)
        for i in index:
            list_ddirs[i] = nDims + ddirs[i]
        return tuple(list_ddirs)
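
    # A worked sketch: for a 3-dimensional dataset,
    #
    #     >>> data_obj._non_negative_directions((0, -1), 3)
    #     (0, 2)
    #
    # since the negative index resolves to nDims + (-1) = 2.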

    def __set_main_axis(self, pname):
        """ Set the ``main_dir`` pattern kwarg to the fastest changing
        dimension
        """
        patterns = self.get_data_patterns()
        n1 = 'PROJECTION' if pname == 'SINOGRAM' else 'SINOGRAM'
        d1 = patterns[n1]['core_dims']
        d2 = patterns[pname]['slice_dims']
        tdir = set(d1).intersection(set(d2))

        # this is required when a single sinogram exists in the mm case, and a
        # dimension is added via parameter tuning.
        if not tdir:
            tdir = [d2[0]]

        self.data_info.set(['data_patterns', pname, 'main_dir'], list(tdir)[0])

    def get_axis_label_keys(self):
        """ Get axis_label names

        :returns: A list containing associated axis names for each dimension
        :rtype: list(str)
        """
        axis_labels = self.data_info.get('axis_labels')
        axis_label_keys = []
        for labels in axis_labels:
            for key in labels.keys():
                axis_label_keys.append(key)
        return axis_label_keys

    def amend_axis_label_values(self, slice_list):
        """ Amend all axis label values based on the slice_list parameter.\
        This is required if the data is reduced.
        """
        axis_labels = self.get_axis_labels()
        for i in range(len(slice_list)):
            label = list(axis_labels[i].keys())[0]
            if label in self.meta_data.get_dictionary().keys():
                values = self.meta_data.get(label)
                preview_sl = [slice(None)]*len(values.shape)
                preview_sl[0] = slice_list[i]
                self.meta_data.set(label, values[tuple(preview_sl)])

    def get_core_dimensions(self):
        """ Get the core data dimensions associated with the current pattern.

        :returns: value associated with pattern key ``core_dims``
        :rtype: tuple
        """
        return list(
            self._get_plugin_data().get_pattern().values())[0]['core_dims']

    def get_slice_dimensions(self):
        """ Get the slice data dimensions associated with the current pattern.

        :returns: value associated with pattern key ``slice_dims``
        :rtype: tuple
        """
        return list(
            self._get_plugin_data().get_pattern().values())[0]['slice_dims']

    def get_itemsize(self):
        """ Returns the number of bytes per data entry. """
        dtype = self.get_dtype()
        if not dtype:
            self.set_dtype(None)
            dtype = self.get_dtype()
        return dtype.itemsize
Example #6
class Experiment(object):
    """
    One instance of this class is created at the beginning of the
    processing chain and remains until the end.  It holds the current data
    object and a dictionary containing all metadata.
    """
    def __init__(self, options):
        self.meta_data = MetaData(options)
        self.__set_system_params()
        self.checkpoint = Checkpointing(self)
        self.__meta_data_setup(options["process_file"])
        self.collection = {}
        self.index = {"in_data": {}, "out_data": {}}
        self.initial_datasets = None
        self.plugin = None
        self._transport = None
        self._barrier_count = 0
        self._dataset_names_complete = False

    def get(self, entry):
        """ Get the meta data dictionary. """
        return self.meta_data.get(entry)

    def __meta_data_setup(self, process_file):
        self.meta_data.plugin_list = PluginList()
        try:
            rtype = self.meta_data.get('run_type')
            if rtype == 'test':
                self.meta_data.plugin_list.plugin_list = \
                    self.meta_data.get('plugin_list')
            else:
                raise Exception('the run_type is unknown in Experiment class')
        except KeyError:
            template = self.meta_data.get('template')
            self.meta_data.plugin_list._populate_plugin_list(process_file,
                                                             template=template)
        self.meta_data.set("nPlugin", 0)  # initialise

    def create_data_object(self, dtype, name, override=True):
        """ Create a data object.

        Plugin developers should apply this method in loaders only.

        :params str dtype: either "in_data" or "out_data".
        """
        if name not in list(self.index[dtype].keys()) or override:
            self.index[dtype][name] = Data(name, self)
            data_obj = self.index[dtype][name]
            data_obj._set_transport_data(self.meta_data.get('transport'))
        return self.index[dtype][name]
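
    # A usage sketch inside a loader, with a hypothetical dataset name:
    #
    #     data_obj = exp.create_data_object('in_data', 'tomo')
    #
    # creates the 'tomo' entry in exp.index['in_data'] (replacing any existing
    # entry, since override defaults to True) and attaches the configured
    # transport mechanism to the new Data object.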

    def _setup(self, transport):
        self._set_nxs_file()
        self._set_transport(transport)
        self.collection = {'plugin_dict': [], 'datasets': []}

        self._barrier()
        self._check_checkpoint()
        self._barrier()

    def _finalise_setup(self, plugin_list):
        checkpoint = self.meta_data.get('checkpoint')
        self._set_dataset_names_complete()
        # save the plugin list - one process, first time only
        if self.meta_data.get('process') == \
                len(self.meta_data.get('processes'))-1 and not checkpoint:
            # links the input data to the nexus file
            plugin_list._save_plugin_list(self.meta_data.get('nxs_filename'))
            self._add_input_data_to_nxs_file(self._get_transport())
        self._set_dataset_names_complete()

    def _set_initial_datasets(self):
        self.initial_datasets = copy.deepcopy(self.index['in_data'])

    def _update(self, plugin_dict):
        data = self.index['out_data'].copy()
        # clear output metadata after first setup
        for d in list(data.values()):
            d.meta_data._set_dictionary({})
        self.collection['datasets'].append(data)
        self.collection['plugin_dict'].append(plugin_dict)

    def _set_transport(self, transport):
        self._transport = transport

    def _get_transport(self):
        return self._transport

    def __set_system_params(self):
        sys_file = self.meta_data.get('system_params')
        import sys
        if sys_file is None:
            # look in conda environment to see which version is being used
            savu_path = sys.modules['savu'].__path__[0]
            sys_files = os.path.join(os.path.dirname(savu_path),
                                     'system_files')
            subdirs = os.listdir(sys_files)
            sys_folder = 'dls' if len(subdirs) > 1 else subdirs[0]
            fname = 'system_parameters.yml'
            sys_file = os.path.join(sys_files, sys_folder, fname)
        logging.info('Using the system parameters file: %s', sys_file)
        self.meta_data.set('system_params', yaml.read_yaml(sys_file))

    def _check_checkpoint(self):
        # if checkpointing has been set but the nxs file doesn't contain an
        # entry then remove checkpointing (as the previous run didn't get far
        # enough to require it).
        if self.meta_data.get('checkpoint'):
            with h5py.File(self.meta_data.get('nxs_filename'), 'r') as f:
                if 'entry' not in f:
                    self.meta_data.set('checkpoint', None)

    def _add_input_data_to_nxs_file(self, transport):
        # save the loaded data to file
        h5 = Hdf5Utils(self)
        for name, data in self.index['in_data'].items():
            self.meta_data.set(['link_type', name], 'input_data')
            self.meta_data.set(['group_name', name], name)
            self.meta_data.set(['filename', name], data.backing_file)
            transport._populate_nexus_file(data)
            h5._link_datafile_to_nexus_file(data)

    def _set_dataset_names_complete(self):
        """ Missing in/out_datasets fields have been populated
        """
        self._dataset_names_complete = True

    def _get_dataset_names_complete(self):
        return self._dataset_names_complete

    def _reset_datasets(self):
        self.index['in_data'] = self.initial_datasets

    def _get_collection(self):
        return self.collection

    def _set_experiment_for_current_plugin(self, count):
        datasets_list = self.meta_data.plugin_list._get_datasets_list()[count:]
        exp_coll = self._get_collection()
        self.index['out_data'] = exp_coll['datasets'][count]
        if datasets_list:
            self._get_current_and_next_patterns(datasets_list)
        self.meta_data.set('nPlugin', count)

    def _get_current_and_next_patterns(self, datasets_lists):
        """ Get the current and next patterns associated with a dataset
        throughout the processing chain.
        """
        current_datasets = datasets_lists[0]
        patterns_list = {}
        for current_data in current_datasets['out_datasets']:
            current_name = current_data['name']
            current_pattern = current_data['pattern']
            next_pattern = self.__find_next_pattern(datasets_lists[1:],
                                                    current_name)
            patterns_list[current_name] = \
                {'current': current_pattern, 'next': next_pattern}
        self.meta_data.set('current_and_next', patterns_list)

    def __find_next_pattern(self, datasets_lists, current_name):
        next_pattern = []
        for next_data_list in datasets_lists:
            for next_data in next_data_list['in_datasets']:
                if next_data['name'] == current_name:
                    next_pattern = next_data['pattern']
                    return next_pattern
        return next_pattern

    def _set_nxs_file(self):
        folder = self.meta_data.get('out_path')
        fname = self.meta_data.get('datafile_name') + '_processed.nxs'
        filename = os.path.join(folder, fname)
        self.meta_data.set('nxs_filename', filename)

        if self.meta_data.get('process') == 1:
            if self.meta_data.get('bllog'):
                log_folder_name = self.meta_data.get('bllog')
                log_folder = open(log_folder_name, 'a')
                log_folder.write(os.path.abspath(filename) + '\n')
                log_folder.close()

        self._create_nxs_entry()

    def _create_nxs_entry(self):  # what if the file already exists?!
        logging.debug("Testing nexus file")
        import h5py
        if self.meta_data.get('process') == len(
                self.meta_data.get('processes')) - 1 and not self.checkpoint:
            with h5py.File(self.meta_data.get('nxs_filename'),
                           'w') as nxs_file:
                entry_group = nxs_file.create_group('entry')
                entry_group.attrs['NX_class'] = 'NXentry'

    def _clear_data_objects(self):
        self.index["out_data"] = {}
        self.index["in_data"] = {}

    def _merge_out_data_to_in(self):
        for key, data in self.index["out_data"].items():
            if data.remove is False:
                self.index['in_data'][key] = data
        self.index["out_data"] = {}

    def _finalise_experiment_for_current_plugin(self):
        finalise = {'remove': [], 'keep': []}
        # populate nexus file with out_dataset information and determine which
        # datasets to remove from the framework.

        for key, data in self.index['out_data'].items():
            if data.remove is True:
                finalise['remove'].append(data)
            else:
                finalise['keep'].append(data)

        # find in datasets to replace
        finalise['replace'] = []
        for out_name in list(self.index['out_data'].keys()):
            if out_name in list(self.index['in_data'].keys()):
                finalise['replace'].append(self.index['in_data'][out_name])

        return finalise

    def _reorganise_datasets(self, finalise):
        # unreplicate replicated in_datasets
        self.__unreplicate_data()

        # delete all datasets for removal
        for data in finalise['remove']:
            del self.index["out_data"][data.data_info.get('name')]

        # Add remaining output datasets to input datasets
        for name, data in self.index['out_data'].items():
            data.get_preview().set_preview([])
            self.index["in_data"][name] = copy.deepcopy(data)
        self.index['out_data'] = {}

    def __unreplicate_data(self):
        in_data_list = self.index['in_data']
        from savu.data.data_structures.data_types.replicate import Replicate
        for in_data in list(in_data_list.values()):
            if isinstance(in_data.data, Replicate):
                in_data.data = in_data.data._reset()

    def _set_all_datasets(self, name):
        data_names = []
        for key in list(self.index["in_data"].keys()):
            if 'itr_clone' not in key:
                data_names.append(key)
        return data_names

    def _barrier(self, communicator=MPI.COMM_WORLD, msg=''):
        comm_dict = {'comm': communicator}
        if self.meta_data.get('mpi') is True:
            logging.debug("Barrier %d: %d processes expected: %s",
                          self._barrier_count, communicator.size, msg)
            comm_dict['comm'].barrier()
        self._barrier_count += 1

    def log(self, log_tag, log_level=logging.DEBUG):
        """
        Log the contents of the experiment at the specified level
        """
        logging.log(log_level, "Experimental Parameters for %s", log_tag)
        for key, value in self.index["in_data"].items():
            logging.log(log_level, "in data (%s) shape = %s", key,
                        value.get_shape())
        for key, value in self.index["in_data"].items():
            logging.log(log_level, "out data (%s) shape = %s", key,
                        value.get_shape())
Example #7
File: data.py  Project: mattronix/Savu
class Data(DataCreate):
    """The Data class dynamically inherits from transport specific data class
    and holds the data array, along with associated information.
    """
    def __init__(self, name, exp):
        super(Data, self).__init__(name)
        self.meta_data = MetaData()
        self.pattern_list = self.__get_available_pattern_list()
        self.data_info = MetaData()
        self.__initialise_data_info(name)
        self._preview = Preview(self)
        self.exp = exp
        self.group_name = None
        self.group = None
        self._plugin_data_obj = None
        self.raw = None
        self.backing_file = None
        self.data = None
        self.next_shape = None
        self.orig_shape = None
        self.previous_pattern = None
        self.transport_data = None

    def __initialise_data_info(self, name):
        """ Initialise entries in the data_info meta data.
        """
        self.data_info.set('name', name)
        self.data_info.set('data_patterns', {})
        self.data_info.set('shape', None)
        self.data_info.set('nDims', None)

    def _set_plugin_data(self, plugin_data_obj):
        """ Encapsulate a PluginData object.
        """
        self._plugin_data_obj = plugin_data_obj

    def _clear_plugin_data(self):
        """ Set encapsulated PluginData object to None.
        """
        self._plugin_data_obj = None

    def _get_plugin_data(self):
        """ Get encapsulated PluginData object.
        """
        if self._plugin_data_obj is not None:
            return self._plugin_data_obj
        else:
            raise Exception("There is no PluginData object associated with "
                            "the Data object.")

    def get_preview(self):
        """ Get the Preview instance associated with the data object
        """
        return self._preview

    def _set_transport_data(self, transport):
        """ Import and instantiate the data transport mechanism.

        :param str transport: name of the transport mechanism, e.g. 'hdf5'
        """
        transport_data = "savu.data.transport_data." + transport + \
                         "_transport_data"
        transport_data = cu.import_class(transport_data)
        self.transport_data = transport_data(self)

    def _get_transport_data(self):
        return self.transport_data

    def __deepcopy__(self, memo):
        """ Copy the data object.
        """
        name = self.data_info.get('name')
        return dsu._deepcopy_data_object(self, Data(name, self.exp))

    def get_data_patterns(self):
        """ Get data patterns associated with this data object.

        :returns: A dictionary of associated patterns.
        :rtype: dict
        """
        return self.data_info.get('data_patterns')

    def _set_previous_pattern(self, pattern):
        self.previous_pattern = pattern

    def get_previous_pattern(self):
        return self.previous_pattern

    def set_shape(self, shape):
        """ Set the dataset shape.
        """
        self.data_info.set('shape', shape)
        self.__check_dims()

    def set_original_shape(self, shape):
        """ Set the original data shape before previewing
        """
        self.orig_shape = shape
        self.set_shape(shape)

    def get_shape(self):
        """ Get the dataset shape

        :returns: data shape
        :rtype: tuple
        """
        shape = self.data_info.get('shape')
        return shape

    def __check_dims(self):
        """ Check that the length of the ``shape`` entry in the data_info
        meta_data dictionary matches the ``nDims`` entry.
        """
        nDims = self.data_info.get("nDims")
        shape = self.data_info.get('shape')
        if nDims:
            if len(shape) != nDims:
                error_msg = ("The number of axis labels, %d, does not "
                             "coincide with the number of data "
                             "dimensions %d." % (nDims, len(shape)))
                raise Exception(error_msg)

    def _set_name(self, name):
        self.data_info.set('name', name)

    def get_name(self):
        """ Get data name.

        :returns: the name associated with the dataset
        :rtype: str
        """
        return self.data_info.get('name')

    def __get_available_pattern_list(self):
        """ Get a list of ALL pattern names that are currently allowed in the
        framework.
        """
        pattern_list = dsu.get_available_pattern_types()
        return pattern_list

    def add_pattern(self, dtype, **kwargs):
        """ Add a pattern.

        :params str dtype: The *type* of pattern to add, which can be anything
            from :data:`savu.data.data_structures.utils.pattern_list`
        :keyword tuple core_dir: Dimension indices of core dimensions
        :keyword tuple slice_dir: Dimension indices of slice dimensions
        """
        if dtype in self.pattern_list:
            nDims = 0
            for args in kwargs:
                nDims += len(kwargs[args])
                self.data_info.set(['data_patterns', dtype, args],
                                   kwargs[args])

            self.__convert_pattern_directions(dtype)
            if self.get_shape():
                diff = len(self.get_shape()) - nDims
                if diff:
                    pattern = {dtype: self.get_data_patterns()[dtype]}
                    self._set_data_patterns(pattern)
                    nDims += diff
            try:
                if nDims != self.data_info.get("nDims"):
                    actualDims = self.data_info.get('nDims')
                    err_msg = ("The pattern %s has an incorrect number of "
                               "dimensions: %d required but %d specified." %
                               (dtype, actualDims, nDims))
                    raise Exception(err_msg)
            except KeyError:
                self.data_info.set('nDims', nDims)
        else:
            raise Exception(
                "The data pattern '%s' does not exist. Please "
                "choose from the following list: \n'%s'"
                % (dtype, str(self.pattern_list)))

    def add_volume_patterns(self, x, y, z):
        """ Adds 3D volume patterns

        :params int x: dimension to be associated with x-axis
        :params int y: dimension to be associated with y-axis
        :params int z: dimension to be associated with z-axis
        """
        self.add_pattern("VOLUME_YZ", **self.__get_dirs_for_volume(y, z, x))
        self.add_pattern("VOLUME_XZ", **self.__get_dirs_for_volume(x, z, y))
        self.add_pattern("VOLUME_XY", **self.__get_dirs_for_volume(x, y, z))

    def __get_dirs_for_volume(self, dim1, dim2, sdir):
        """ Calculate core_dir and slice_dir for a 3D volume pattern.
        """
        all_dims = range(len(self.get_shape()))
        vol_dict = {}
        vol_dict['core_dims'] = (dim1, dim2)
        slice_dir = [sdir]
        # *** need to add this for other patterns
        for ddir in all_dims:
            if ddir not in [dim1, dim2, sdir]:
                slice_dir.append(ddir)
        vol_dict['slice_dims'] = tuple(slice_dir)
        return vol_dict

    def set_axis_labels(self, *args):
        """ Set the axis labels associated with each data dimension.

        :arg str: Each arg should be of the form ``name.unit``. If ``name`` is\
        a data_obj.meta_data entry, it will be output to the final .nxs file.
        """
        self.data_info.set('nDims', len(args))
        axis_labels = []
        for arg in args:
            if isinstance(arg, dict):
                axis_labels.append(arg)
            else:
                try:
                    axis = arg.split('.')
                    axis_labels.append({axis[0]: axis[1]})
                except:
                    # data arrives here, but that may be an error
                    pass
        self.data_info.set('axis_labels', axis_labels)

    def get_axis_labels(self):
        """ Get axis labels.

        :returns: Axis labels
        :rtype: list(dict)
        """
        return self.data_info.get('axis_labels')

    def get_data_dimension_by_axis_label(self, name, contains=False):
        """ Get the dimension of the data associated with a particular
        axis_label.

        :param str name: The name of the axis_label
        :keyword bool contains: Set this flag to true if the name is only part
            of the axis_label name
        :returns: The associated axis number
        :rtype: int
        """
        axis_labels = self.data_info.get('axis_labels')
        for i in range(len(axis_labels)):
            if contains is True:
                for names in axis_labels[i].keys():
                    if name in names:
                        return i
            else:
                if name in axis_labels[i].keys():
                    return i
        raise Exception("Cannot find the specifed axis label.")

    def _finalise_patterns(self):
        """ Adds a main axis (fastest changing) to SINOGRAM and PROJECTON
        patterns.
        """
        check = 0
        check += self.__check_pattern('SINOGRAM')
        check += self.__check_pattern('PROJECTION')

        if check == 2 and len(self.get_shape()) > 2:
            self.__set_main_axis('SINOGRAM')
            self.__set_main_axis('PROJECTION')
        elif check == 1:
            pass

    def __check_pattern(self, pattern_name):
        """ Check if a pattern exists.
        """
        patterns = self.get_data_patterns()
        try:
            patterns[pattern_name]
        except KeyError:
            return 0
        return 1

    def __convert_pattern_directions(self, dtype):
        """ Replace negative indices in pattern kwargs.
        """
        pattern = self.get_data_patterns()[dtype]
        if 'main_dir' in pattern.keys():
            del pattern['main_dir']

        nDims = sum([len(i) for i in pattern.values()])
        for p in pattern:
            ddirs = pattern[p]
            pattern[p] = self._non_negative_directions(ddirs, nDims)

    def _non_negative_directions(self, ddirs, nDims):
        """ Replace negative indexing values with positive counterparts.

        :params tuple(int) ddirs: data dimension indices
        :params int nDims: The number of data dimensions
        :returns: non-negative data dimension indices
        :rtype: tuple(int)
        """
        index = [i for i in range(len(ddirs)) if ddirs[i] < 0]
        list_ddirs = list(ddirs)
        for i in index:
            list_ddirs[i] = nDims + ddirs[i]
        return tuple(list_ddirs)

    def __set_main_axis(self, pname):
        """ Set the ``main_dir`` pattern kwarg to the fastest changing
        dimension
        """
        patterns = self.get_data_patterns()
        n1 = 'PROJECTION' if pname == 'SINOGRAM' else 'SINOGRAM'
        d1 = patterns[n1]['core_dims']
        d2 = patterns[pname]['slice_dims']
        tdir = set(d1).intersection(set(d2))

        # this is required when a single sinogram exists in the mm case, and a
        # dimension is added via parameter tuning.
        if not tdir:
            tdir = [d2[0]]

        self.data_info.set(['data_patterns', pname, 'main_dir'], list(tdir)[0])

    def get_axis_label_keys(self):
        """ Get axis_label names

        :returns: A list containing associated axis names for each dimension
        :rtype: list(str)
        """
        axis_labels = self.data_info.get('axis_labels')
        axis_label_keys = []
        for labels in axis_labels:
            for key in labels.keys():
                axis_label_keys.append(key)
        return axis_label_keys

    def amend_axis_label_values(self, slice_list):
        """ Amend all axis label values based on the slice_list parameter.\
        This is required if the data is reduced.
        """
        axis_labels = self.get_axis_labels()
        for i in range(len(slice_list)):
            label = list(axis_labels[i].keys())[0]
            if label in self.meta_data.get_dictionary().keys():
                values = self.meta_data.get(label)
                preview_sl = [slice(None)] * len(values.shape)
                preview_sl[0] = slice_list[i]
                self.meta_data.set(label, values[tuple(preview_sl)])

    def get_core_dimensions(self):
        """ Get the core data dimensions associated with the current pattern.

        :returns: value associated with pattern key ``core_dims``
        :rtype: tuple
        """
        return list(
            self._get_plugin_data().get_pattern().values())[0]['core_dims']

    def get_slice_dimensions(self):
        """ Get the slice data dimensions associated with the current pattern.

        :returns: value associated with pattern key ``slice_dims``
        :rtype: tuple
        """
        return list(
            self._get_plugin_data().get_pattern().values())[0]['slice_dims']

    def _add_raw_data_obj(self, data_obj):
        from savu.data.data_structures.data_types.data_plus_darks_and_flats\
            import ImageKey, NoImageKey

        if not isinstance(data_obj.raw, (ImageKey, NoImageKey)):
            raise Exception('Raw data type not recognised.')

        proj_dim = self.get_data_dimension_by_axis_label('rotation_angle')
        data_obj.data = NoImageKey(data_obj, None, proj_dim)
        data_obj.data._set_dark_and_flat()

        if isinstance(data_obj.raw, ImageKey):
            data_obj.raw._convert_to_noimagekey(data_obj.data)
            data_obj.data._set_fake_key(data_obj.raw.get_image_key())
Example #8
class PluginData(object):
    """ The PluginData class contains plugin specific information about a Data
    object for the duration of a plugin.  An instance of the class is
    encapsulated inside the Data object during the plugin run
    """

    def __init__(self, data_obj, plugin=None):
        self.data_obj = data_obj
        self._preview = None
        self.data_obj._set_plugin_data(self)
        self.meta_data = MetaData()
        self.padding = None
        self.pad_dict = None
        self.shape = None
        self.core_shape = None
        self.multi_params = {}
        self.extra_dims = []
        self._plugin = plugin
        self.fixed_dims = True
        self.split = None
        self.boundary_padding = None
        self.no_squeeze = False
        self.pre_tuning_shape = None
        self._frame_limit = None

    def _get_preview(self):
        return self._preview

    def get_total_frames(self):
        """ Get the total number of frames to process (all MPI processes).

        :returns: Number of frames
        :rtype: int
        """
        temp = 1
        slice_dir = \
            self.data_obj.get_data_patterns()[
                self.get_pattern_name()]["slice_dims"]
        for tslice in slice_dir:
            temp *= self.data_obj.get_shape()[tslice]
        return temp
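
    # A worked sketch with a hypothetical data shape of (91, 135, 160): for a
    # pattern whose slice_dims are (0,) the total is 91 frames, and for
    # slice_dims (0, 1) it would be 91 * 135 = 12285 frames shared across all
    # MPI processes.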

    def __set_pattern(self, name):
        """ Set the pattern related information in the meta data dict.
        """
        pattern = self.data_obj.get_data_patterns()[name]
        self.meta_data.set("name", name)
        self.meta_data.set("core_dims", pattern['core_dims'])
        self.__set_slice_dimensions()

    def get_pattern_name(self):
        """ Get the pattern name.

        :returns: the pattern name
        :rtype: str
        """
        try:
            name = self.meta_data.get("name")
            return name
        except KeyError:
            raise Exception("The pattern name has not been set.")

    def get_pattern(self):
        """ Get the current pattern.

        :returns: dict of the pattern name against the pattern.
        :rtype: dict
        """
        pattern_name = self.get_pattern_name()
        return {pattern_name: self.data_obj.get_data_patterns()[pattern_name]}

    def __set_shape(self):
        """ Set the shape of the plugin data processing chunk.
        """
        core_dir = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()
        dirs = list(set(core_dir + (slice_dir[0],)))
        slice_idx = dirs.index(slice_dir[0])
        dshape = self.data_obj.get_shape()
        shape = []
        for core in set(core_dir):
            shape.append(dshape[core])
        self.__set_core_shape(tuple(shape))

        mfp = self._get_max_frames_process()
        if mfp > 1 or self._get_no_squeeze():
            shape.insert(slice_idx, mfp)
        self.shape = tuple(shape)

    def _set_shape_transfer(self, slice_size):
        dshape = self.data_obj.get_shape()
        shape_before_tuning = self._get_shape_before_tuning()
        add = [1]*(len(dshape) - len(shape_before_tuning))
        slice_size = slice_size + add

        core_dir = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()
        shape = [None]*len(dshape)
        for dim in core_dir:
            shape[dim] = dshape[dim]
        i = 0
        for dim in slice_dir:
            shape[dim] = slice_size[i]
            i += 1            
        return tuple(shape)
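
    # A worked sketch, assuming no parameter-tuning dimensions: with data
    # shape (91, 135, 160), core dims (1, 2), slice dims (0,) and
    # slice_size [8], the transfer shape is (8, 135, 160).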

    def __get_slice_size(self, mft):
        """ Calculate the number of frames to transfer in each dimension,
            given mft. """
        dshape = list(self.data_obj.get_shape())

        if 'fixed_dimensions' in self.meta_data.get_dictionary().keys():
            fixed_dims = self.meta_data.get('fixed_dimensions')
            for d in fixed_dims:
                dshape[d] = 1

        dshape = [dshape[i] for i in self.meta_data.get('slice_dims')]
        size_list = [1]*len(dshape)
        i = 0
        
        while mft > 1:
            size_list[i] = min(dshape[i], mft)
            mft -= np.prod(size_list) if np.prod(size_list) > 1 else 0
            i += 1

        self.meta_data.set('size_list', size_list)
        return size_list
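
    # A worked sketch: with slice-dimension lengths [10, 5] and mft = 4 the
    # first pass sets size_list[0] = min(10, 4) = 4, mft drops to zero, and
    # the result is [4, 1]: transfer four frames along the first (fastest
    # changing) slice dimension only.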

    def set_bytes_per_frame(self):
        """ Calculate the core-dimension frame shape and the size of a single
        frame in bytes. """
        nBytes = self.data_obj.get_itemsize()
        dims = list(self.get_pattern().values())[0]['core_dims']
        frame_shape = [self.data_obj.get_shape()[d] for d in dims]
        b_per_f = np.prod(frame_shape)*nBytes
        return frame_shape, b_per_f
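
    # A worked sketch: a float32 dataset (4 bytes per entry) whose core
    # dimensions have shape (2048, 2048) gives a frame of
    # 2048 * 2048 * 4 = 16777216 bytes (16 MiB).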

    def get_shape(self):
        """ Get the shape of the data (without padding) that is passed to the
        plugin process_frames method.
        """
        return self.shape

    def _set_padded_shape(self):
        pass

    def get_padded_shape(self):
        """ Get the shape of the data (with padding) that is passed to the
        plugin process_frames method.
        """
        return self.shape

    def get_shape_transfer(self):
        """ Get the shape of the plugin data to be transferred each time.
        """
        return self.meta_data.get('transfer_shape')

    def __set_core_shape(self, shape):
        """ Set the core shape to hold only the shape of the core dimensions
        """
        self.core_shape = shape

    def get_core_shape(self):
        """ Get the shape of the core dimensions only.

        :returns: shape of core dimensions
        :rtype: tuple
        """
        return self.core_shape

    def _set_shape_before_tuning(self, shape):
        """ Set the shape of the full dataset used during each run of the \
        plugin (i.e. ignore extra dimensions due to parameter tuning). """
        self.pre_tuning_shape = shape

    def _get_shape_before_tuning(self):
        """ Return the shape of the full dataset used during each run of the \
        plugin (i.e. ignore extra dimensions due to parameter tuning). """
        return self.pre_tuning_shape if self.pre_tuning_shape else\
            self.data_obj.get_shape()

    def __check_dimensions(self, indices, core_dir, slice_dir, nDims):
        if len(indices) != len(slice_dir):
            sys.exit("Incorrect number of indices specified when accessing "
                     "data.")

        if (len(core_dir)+len(slice_dir)) != nDims:
            sys.exit("Incorrect number of data dimensions specified.")

    def __set_slice_dimensions(self):
        """ Set the slice dimensions in the pluginData meta data dictionary.
        """
        slice_dirs = self.data_obj.get_data_patterns()[
            self.get_pattern_name()]['slice_dims']
        self.meta_data.set('slice_dims', slice_dirs)

    def get_slice_dimension(self):
        """
        Return the position of the slice dimension in relation to the data
        handed to the plugin.
        """
        core_dirs = self.data_obj.get_core_dimensions()
        slice_dir = self.data_obj.get_slice_dimensions()[0]
        return list(set(core_dirs + (slice_dir,))).index(slice_dir)
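
    # A worked sketch: with core dims (1, 2) and first slice dim 0, the
    # combined set {0, 1, 2} lists in ascending order for these small indices,
    # so the slice dimension sits at position 0 of the data handed to the
    # plugin.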

    def get_data_dimension_by_axis_label(self, label, contains=False):
        """
        Return the dimension of the data in the plugin that has the specified
        axis label.
        """
        label_dim = self.data_obj.get_data_dimension_by_axis_label(
                label, contains=contains)
        plugin_dims = self.data_obj.get_core_dimensions()
        if self._get_max_frames_process() > 1 or self.max_frames == 'multiple':
            plugin_dims += (self.get_slice_dimension(),)
        return list(set(plugin_dims)).index(label_dim)

    def set_slicing_order(self, order):
        """
        Reorder the slice dimensions.  The fastest changing slice dimension
        will always be the first one stated in the pattern key ``slice_dims``.
        The input param is a tuple stating the desired order of slicing
        dimensions relative to the current order.
        """
        slice_dirs = self.data_obj.get_slice_dimensions()
        if len(slice_dirs) < len(order):
            raise Exception("Incorrect number of dimensions specified.")
        ordered = [slice_dirs[o] for o in order]
        remaining = [s for s in slice_dirs if s not in ordered]
        new_slice_dirs = tuple(ordered + remaining)
        self.get_pattern()[self.get_pattern_name()]['slice_dims'] = \
            new_slice_dirs

    def get_core_dimensions(self):
        """
        Return the position of the core dimensions in relation to the data
        handed to the plugin.
        """
        core_dims = self.data_obj.get_core_dimensions()
        first_slice_dim = (self.data_obj.get_slice_dimensions()[0],)
        plugin_dims = np.sort(core_dims + first_slice_dim)
        return np.searchsorted(plugin_dims, np.sort(core_dims))
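
    # A worked sketch: with core_dims (1, 2) and first slice dim 0,
    # plugin_dims sorts to [0, 1, 2] and np.searchsorted places the core
    # dimensions at positions [1, 2] in the plugin's view of the data.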

    def set_fixed_dimensions(self, dims, values):
        """ Fix a data direction to the index in values list.

        :param list(int) dims: Directions to fix
        :param list(int) value: Index of fixed directions
        """
        slice_dirs = self.data_obj.get_slice_dimensions()
        if set(dims).difference(set(slice_dirs)):
            raise Exception("You are trying to fix a direction that is not"
                            " a slicing direction")
        self.meta_data.set("fixed_dimensions", dims)
        self.meta_data.set("fixed_dimensions_values", values)
        self.__set_slice_dimensions()
        shape = list(self.data_obj.get_shape())
        for dim in dims:
            shape[dim] = 1
        self.data_obj.set_shape(tuple(shape))
        #self.__set_shape()

    def _get_fixed_dimensions(self):
        """ Get the fixed data directions and their indices

        :returns: Fixed directions and their associated values
        :rtype: list(list(int), list(int))
        """
        fixed = []
        values = []
        if 'fixed_dimensions' in self.meta_data.get_dictionary():
            fixed = self.meta_data.get("fixed_dimensions")
            values = self.meta_data.get("fixed_dimensions_values")
        return [fixed, values]

    def _get_data_slice_list(self, plist):
        """ Convert a plugin data slice list to a slice list for the whole
        dataset, i.e. add in any missing dimensions.
        """
        nDims = len(self.get_shape())
        all_dims = tuple(self.get_core_dimensions()) + \
            (self.get_slice_dimension(),)
        extra_dims = all_dims[nDims:]
        dlist = list(plist)
        for i in extra_dims:
            dlist.insert(i, slice(None))
        return tuple(dlist)

    def _get_max_frames_process(self):
        """ Get the number of frames to process for each run of process_frames.

        If the number of frames is not divisible by the previewing ``chunk``
        value then amend the number of frames to gcd(frames, chunk)

        :returns: Number of frames to process
        :rtype: int
        """
        if self._plugin and self._plugin.chunk > 1:
            frame_chunk = self.meta_data.get("max_frames_process")
            chunk = self.data_obj.get_preview().get_starts_stops_steps(
                key='chunks')[self.data_obj.get_slice_dimensions()[0]]
            self.meta_data.set('max_frames_process', gcd(frame_chunk, chunk))
        return self.meta_data.get("max_frames_process")
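
    # A worked sketch of the gcd adjustment: if max_frames_process is 8 but
    # the preview 'chunks' value for the first slice dimension is 6, the
    # frame count is amended to gcd(8, 6) = 2 so that a block of frames never
    # straddles a chunk boundary.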

    def _get_max_frames_transfer(self):
        """ Get the number of frames to transfer for each run of
        process_frames. """
        return self.meta_data.get('max_frames_transfer')

    def _set_no_squeeze(self):
        self.no_squeeze = True

    def _get_no_squeeze(self):
        return self.no_squeeze

    def _set_meta_data(self):
        fixed, _ = self._get_fixed_dimensions()
        sdir = \
            [s for s in self.data_obj.get_slice_dimensions() if s not in fixed]
        shape = self.data_obj.get_shape()
        shape_before_tuning = self._get_shape_before_tuning()

        diff = len(shape) - len(shape_before_tuning)
        if diff:
            shape = shape_before_tuning
            sdir = sdir[:-diff]
        
        if 'fix_total_frames' in self.meta_data.get_dictionary().keys():
            frames = self.meta_data.get('fix_total_frames')
        else:
            frames = np.prod([shape[d] for d in sdir])

        base_names = [p.__name__ for p in self._plugin.__class__.__bases__]
        processes = self.data_obj.exp.meta_data.get('processes')

        if 'GpuPlugin' in base_names:
            n_procs = len([n for n in processes if 'GPU' in n])
        else:
            n_procs = len(processes)

        f_per_p = np.ceil(frames/n_procs)
        self.meta_data.set('shape', shape)
        self.meta_data.set('sdir', sdir)
        self.meta_data.set('total_frames', frames)
        self.meta_data.set('mpi_procs', n_procs)
        self.meta_data.set('frames_per_process', f_per_p)
        frame_shape, b_per_f = self.set_bytes_per_frame()
        self.meta_data.set('bytes_per_frame', b_per_f)
        self.meta_data.set('bytes_per_process', b_per_f*f_per_p)
        self.meta_data.set('frame_shape', frame_shape)
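
    # A worked sketch of the arithmetic above, with hypothetical values:
    # 180 total frames over 8 MPI processes gives frames_per_process =
    # ceil(180 / 8) = 23; with 4-byte entries and a frame shape of (135, 160)
    # that is 135 * 160 * 4 = 86400 bytes_per_frame, or roughly 1.99 MB per
    # process.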

    def __log_max_frames(self, mft, mfp, check=True):
        logging.debug("Setting max frames transfer for plugin %s to %d" %
                      (self._plugin, mft))
        logging.debug("Setting max frames process for plugin %s to %d" %
                      (self._plugin, mfp))
        self.meta_data.set('max_frames_process', mfp)
        if check:
            self.__check_distribution(mft)
        # (((total_frames/mft)/mpi_procs) % 1)

    def __check_distribution(self, mft):
        warn_threshold = 0.85
        nprocs = self.meta_data.get('mpi_procs')
        nframes = self.meta_data.get('total_frames')
        temp = (((nframes/mft)/float(nprocs)) % 1)
        if temp != 0.0 and temp < warn_threshold:
            shape = self.meta_data.get('shape')
            sdir = self.meta_data.get('sdir')
            logging.warning('UNEVEN FRAME DISTRIBUTION: shape %s, nframes %s '
                            'sdir %s, nprocs %s', shape, nframes, sdir, nprocs)
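
    # A worked sketch: with 100 frames, mft = 10 and 4 processes,
    # ((100 / 10) / 4.0) % 1 = 0.5, which is non-zero and below the 0.85
    # threshold, so the uneven-distribution warning fires.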

    def _set_padding_dict(self):
        if self.padding and not isinstance(self.padding, Padding):
            self.pad_dict = copy.deepcopy(self.padding)
            self.padding = Padding(self)
            for key in self.pad_dict.keys():
                getattr(self.padding, key)(self.pad_dict[key])

    def plugin_data_setup(self, pattern, nFrames, split=None):
        """ Setup the PluginData object.

        :param str pattern: A pattern name
        :param int nFrames: How many frames to process at a time.  Choose from\
            'single', 'multiple', 'fixed_multiple' or an integer (an integer \
            should only ever be passed in exceptional circumstances)
        """
        self.__set_pattern(pattern)
        if isinstance(nFrames, list):
            nFrames, self._frame_limit = nFrames
        self.max_frames = nFrames
        self.split = split

    def plugin_data_transfer_setup(self, copy=None, calc=None):
        """ Set up the plugin data transfer frame parameters.  If copy is set
        to another PluginData instance then copy its transfer parameters; if
        calc is set then calculate the parameters relative to that instance.
        """
        chunks = \
            self.data_obj.get_preview().get_starts_stops_steps(key='chunks')

        if not copy and not calc:
            mft, mft_shape, mfp = self._calculate_max_frames()
        elif calc:
            max_mft = calc.meta_data.get('max_frames_transfer')             
            max_mfp = calc.meta_data.get('max_frames_process')
            max_nProc = int(np.ceil(max_mft/float(max_mfp)))
            nProc = max_nProc
            mfp = 1 if self.max_frames == 'single' else self.max_frames
            mft = nProc*mfp
            mft_shape = self._set_shape_transfer(self.__get_slice_size(mft))
        elif copy:
            mft = copy._get_max_frames_transfer()
            mft_shape = self._set_shape_transfer(self.__get_slice_size(mft))
            mfp = copy._get_max_frames_process()

        self.__set_max_frames(mft, mft_shape, mfp)

        if self._plugin and mft \
                and (chunks[self.data_obj.get_slice_dimensions()[0]] % mft):
            self._plugin.chunk = True
        self.__set_shape()

    def _calculate_max_frames(self):
        nFrames = self.max_frames
        self.__perform_checks(nFrames)
        td = self.data_obj._get_transport_data()
        mft, size_list = td._calc_max_frames_transfer(nFrames)
        self.meta_data.set('size_list', size_list)
        mfp = td._calc_max_frames_process(nFrames)
        mft_shape = None
        if mft:
            mft_shape = self._set_shape_transfer(list(size_list))
        return mft, mft_shape, mfp

    def __set_max_frames(self, mft, mft_shape, mfp):
        self.meta_data.set('max_frames_transfer', mft)
        self.meta_data.set('transfer_shape', mft_shape)
        self.meta_data.set('max_frames_process', mfp)
        self.__log_max_frames(mft, mfp)
        # Retain the shape if the first slice dimension has length 1
        if mfp == 1 and self.max_frames == 'multiple':
            self._set_no_squeeze()

    def _get_plugin_data_size_params(self):
        nBytes = self.data_obj.get_itemsize()
        frame_shape = self.meta_data.get('frame_shape')
        total_frames = self.meta_data.get('total_frames')
        tbytes = nBytes*np.prod(frame_shape)*total_frames
        
        params = {'nBytes': nBytes, 'frame_shape': frame_shape,
                  'total_frames': total_frames, 'transfer_bytes': tbytes}
        return params

    def __perform_checks(self, nFrames):
        options = ['single', 'multiple']
        if not isinstance(nFrames, int) and nFrames not in options:
            e_str = ("The value of nFrames is not recognised.  Please choose "
                     "from 'single' and 'multiple' (or an integer in "
                     "exceptional circumstances).")
            raise Exception(e_str)

    def get_frame_limit(self):
        return self._frame_limit

    def get_current_frame_idx(self):
        """ Returns the index of the frames currently being processed.
        """
        global_index = self._plugin.get_global_frame_index()
        count = self._plugin.get_process_frames_counter()
        mfp = self.meta_data.get('max_frames_process')
        start = global_index[count]*mfp
        index = np.arange(start, start + mfp)
        nFrames = self.get_total_frames()
        index[index >= nFrames] = nFrames - 1
        return index