Code Example #1

import numpy as np

from contextlib import closing
from distarray.globalapi import Context, Distribution

# NOTE: the import path below is an assumption based on the Savu layout this
# snippet appears to come from; MPI_setup and proxyize are assumed to be
# provided by the surrounding framework and the DistArray engines.
from savu.core.transport_control import TransportControl


class DistArrayTransport(TransportControl):

    def __init__(self):
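        # `targets` and `context` hold the DistArray engine handles; `history`
        # records each dataset's pattern per plugin, so __redistribute_data
        # can detect when a dataset's distribution needs to change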
        self.targets = None
        self.context = None
        self.n_processes = None
        self.count = None
        self.history = []

    def _transport_initialise(self, options):
        # self.exp is not available here
        MPI_setup(options)  # change this?
        with closing(Context()) as context:
            self.targets = context.targets  # set mpi logging here?

    def _transport_pre_plugin_list_run(self):
        self.n_processes = \
            self.exp.meta_data.plugin_list._get_n_processing_plugins()
        self.context = Context(targets=self.targets)
        # the context stays open across plugins; it is closed explicitly in
        # _transport_post_plugin_list_run

    def _transport_pre_plugin(self):
        # store all datasets and associated patterns
        self.__update_history(self.exp.index)
        self.__distribute_arrays(self.exp.index)

    def _transport_post_plugin(self):
        # if you wish to output datasets that have been removed from the index
        # then do that here (data.remove is True)
        pass

    def _transport_post_plugin_list_run(self):
        # convert distarrays to hdf5
        for data in self.exp.index['in_data'].values():
            name = data.get_name()
            fname = self.exp.meta_data.get('filename')[name]
            gname = self.exp.meta_data.get('group_name')[name]
            data.data.context.save_hdf5(fname, data.data, gname, mode='w')
            self.exp._get_experiment_collection()['saver_plugin']\
                ._open_read_only(data, fname, gname)
        self.context.close()

    def __update_history(self, data_index):
        for dtype, data_dict in data_index.items():
            for name, dobj in data_dict.items():
                pattern = dobj._get_plugin_data().get_pattern()
                self.history.append({name: pattern})

    def __distribute_arrays(self, data_index):
        if not self.history:
            # first pass: get the data directly from file, as before, and
            # output to a distributed array (expand later for other types,
            # or should the first set always be treated as an hdf5 dataset?)
            self.__load_data_from_hdf5(data_index['in_data'])
        else:
            self.__redistribute_data(data_index['in_data'])
        self.__create_out_data(data_index['out_data'])

    def __redistribute_data(self, data_list):
        """ Calculate the last two pattern distributions for each dataset
        and redistribute the data if they differ.
        """
        for data in data_list.values():
            patterns = self.__get_distribution_history(data.get_name())
            if len(patterns) == 2 and patterns[0] != patterns[1]:
                # *** temporarily creating an ndarray: pull the whole array
                # to the client and rescatter (create an empty distarray
                # and populate it directly instead?)
                temp = data.data.toarray()
                distribution = Distribution(
                    self.context, data.get_shape(), patterns[-1])  # currently redundant
                data.data = self.context.fromarray(temp, patterns[-1])

    def __load_data_from_hdf5(self, data_list):
        """ Create a distarray from the specified section of the HDF5 file. """
        for data in data_list.values():
            input_file = data.backing_file.filename
            dist = self.__calculate_distribution(
                data._get_plugin_data().get_pattern())
            distribution = \
                Distribution(self.context, data.get_shape(), dist=dist)
            data.data = self.context.load_hdf5(
                input_file, distribution=distribution, key=data.get_name())

    def __create_out_data(self, out_data):
        for data in out_data.values():
            dist = self.__calculate_distribution(
                data._get_plugin_data().get_pattern())
            distribution = Distribution(self.context, data.get_shape(), dist)
            data.data = self.context.zeros(distribution, dtype=np.int32)

    def __get_distribution_history(self, name):
        # indices of the (up to) last two history entries for this dataset
        hist = [i for i, entry in enumerate(self.history)
                if next(iter(entry)) == name][-2:]
        return [self.__calculate_distribution(
            next(iter(self.history[p].values()))) for p in hist]

    def __calculate_distribution(self, pattern):
        # 'b' (block-distributed) over slice dimensions, 'n' (not
        # distributed) over core dimensions
        pattern_dirs = next(iter(pattern.values()))
        core_dirs = pattern_dirs['core_dir']
        slice_dirs = pattern_dirs['slice_dir']
        n_dims = len(core_dirs + slice_dirs)
        dist = ['n']*n_dims
        for sl in slice_dirs:
            dist[sl] = 'b'
        return ''.join(dist)
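
    # Note: for a (hypothetical) 3D PROJECTION pattern with core_dir (0, 1)
    # and slice_dir (2,), __calculate_distribution returns 'nnb', i.e.
    # block-distributed over the sliced axis only.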

    def _transport_process(self, plugin):
        # hand the plugin's kernel to the engines; `plugin.process` is an
        # assumption about the plugin interface here
        return self.distributed_process(plugin.process)

    def distributed_process(self, kernel):
        # register the kernel with the engines, then run local_process on
        # each of them; the returned key identifies the distributed results
        self.context.register(kernel)
        iters_key = \
            self.context.apply(self.local_process, (), {'kernel': kernel})
        return iters_key

    @staticmethod
    def local_process(frames, output, params, kernel):
        # runs on each engine: apply the kernel locally and wrap the result
        # in a LocalArray over the output's distribution
        from distarray.localapi import LocalArray
        recon = kernel(frames, output, params)
        res = LocalArray(output.distribution, buf=recon)
        return proxyize(res)  # noqa (proxyize is injected on the engines)

    def testing(self):
        print("running the testing function")