def generate(self, ds):
    """Yield dataset splits.

    Parameters
    ----------
    ds: Dataset
      Input dataset

    Returns
    -------
    generator
      The generator yields every possible split according to the splitter
      configuration. All generated datasets have a boolean 'lastsplit'
      attribute in their dataset attribute collection indicating whether
      this particular dataset is the last one.
    """
    # local bindings
    noslicing = self.__noslicing
    count = self.__count
    splattr = self.get_space()
    ignore = self.__splitattr_ignore

    # get attribute and source collection from dataset
    splattr, collection = ds.get_attr(splattr)
    splattr_data = splattr.value

    cfgs = self.__splitattr_values
    if cfgs is None:
        cfgs = splattr.unique
    if __debug__:
        debug('SPL', 'Determined %i split specifications' % len(cfgs))

    if ignore is not None:
        # remove the values that should be ignored
        cfgs = [c for c in cfgs if c not in ignore]
        if __debug__:
            debug('SPL',
                  '%i split specifications left after removing ignored ones'
                  % len(cfgs))

    n_cfgs = len(cfgs)

    if self.__reverse:
        if __debug__:
            debug('SPL', 'Reversing split order')
        cfgs = cfgs[::-1]

    # split the data
    for isplit, split in enumerate(cfgs):
        if count is not None and isplit >= count:
            # maximum number of splits reached
            if __debug__:
                debug('SPL',
                      'Discard remaining splits as maximum of %i is reached'
                      % count)
            break

        # safeguard against 'split' being `None` -- in which case a single
        # boolean would be the result of the comparison below, and not
        # a boolean vector from element-wise comparison
        if split is None:
            split = [None]

        # boolean mask of the samples selected for this split
        filter_ = splattr_data == split

        if not noslicing:
            # check whether we can do slicing instead of advanced
            # indexing -- if we can split the dataset without causing
            # the data to be copied, it is quicker and leaner.
            # However, it only works if the samples to be split form a
            # contiguous chunk or follow a regular step size
            filter_ = mask2slice(filter_)

        if collection is ds.sa:
            if __debug__:
                debug('SPL', 'Split along samples axis')
            split_ds = ds[filter_]
        elif collection is ds.fa:
            if __debug__:
                debug('SPL', 'Split along feature axis')
            split_ds = ds[:, filter_]
        else:
            raise RuntimeError("This should never happen.")

        # is this the last split?
        if count is None:
            lastsplit = (isplit == n_cfgs - 1)
        else:
            lastsplit = (isplit == count - 1)

        if 'lastsplit' not in split_ds.a:
            # if not yet known -- add it
            split_ds.a['lastsplit'] = lastsplit
        else:
            # otherwise just assign a new value
            split_ds.a.lastsplit = lastsplit

        yield split_ds
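
# Usage sketch (illustrative only, not part of the method above). The import
# paths, dataset_wizard() call, and the Splitter('chunks') constructor call are
# assumptions about the surrounding library; only the generate() protocol, the
# per-split datasets, and the 'lastsplit' flag are taken from the code itself.
if __name__ == "__main__":
    from mvpa2.datasets import dataset_wizard
    from mvpa2.generators.splitters import Splitter

    # toy dataset: four samples grouped into two chunks
    ds = dataset_wizard(samples=[[1, 2], [3, 4], [5, 6], [7, 8]],
                        targets=[0, 0, 1, 1],
                        chunks=[0, 0, 1, 1])

    # split along the 'chunks' sample attribute -- one dataset per unique value
    spl = Splitter('chunks')
    for part in spl.generate(ds):
        # each part holds only the samples of one chunk; the final one is
        # marked via the 'lastsplit' dataset attribute
        print('%i samples, lastsplit=%s' % (part.nsamples, part.a.lastsplit))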