def split_args_in_chunks(args, n_chunks):
    """
    Split the positional args passed to ``make_transitions`` in similar batches.

    Every element of ``args`` is sliced with the same ``[start:stop]`` window,
    and one tuple of slices is yielded per chunk.

    Args:
        args: Sequence of sliceable values; the length of the first one is
            used as the total number of items.
        n_chunks: Number of chunks requested from ``similiar_chunks_indexes``.

    Yields:
        Tuple holding one slice of every element of ``args``.
    """
    total = len(args[0])
    step = int(numpy.ceil(total / n_chunks))
    for first, last in similiar_chunks_indexes(total, n_chunks):
        # When two or fewer items would remain after this chunk, fold them
        # into it and stop, so no tiny trailing chunk is produced.
        closing = first + step >= total - 2
        stop = total if closing else last
        yield tuple(value[first:stop] for value in args)
        if closing:
            return
def split_kwargs_in_chunks(kwargs, n_chunks):
    """
    Split the kwargs passed to ``make_transitions`` in similar batches.

    Every ``numpy.ndarray`` value is sliced with the chunk's ``[start:stop]``
    window; any other value is passed unchanged to every chunk.

    Args:
        kwargs: Mapping of keyword arguments. All array values are assumed to
            have the same length.
        n_chunks: Number of chunks requested from ``similiar_chunks_indexes``.

    Yields:
        A dict with the same keys as ``kwargs`` for every chunk.

    Raises:
        ValueError: If ``kwargs`` is empty.
    """
    if not kwargs:
        # Without this check ``next`` raises StopIteration inside the
        # generator, which Python turns into an opaque RuntimeError.
        raise ValueError("kwargs must contain at least one value to split.")
    # Measure the length from the first array value, since arrays are the only
    # values that get sliced; fall back to the first value otherwise.
    reference = next(
        (v for v in kwargs.values() if isinstance(v, numpy.ndarray)),
        next(iter(kwargs.values())),
    )
    n_values = len(reference)
    chunk_size = int(numpy.ceil(n_values / n_chunks))
    for start, end in similiar_chunks_indexes(n_values, n_chunks):
        # If two or fewer items would remain after this chunk, merge them into
        # it and stop, so the last chunk is never a tiny remainder.
        is_last = start + chunk_size >= n_values - 2
        stop = n_values if is_last else end
        yield {
            k: v[start:stop] if isinstance(v, numpy.ndarray) else v
            for k, v in kwargs.items()
        }
        if is_last:
            break
def split_states(self, n_chunks: int) -> Generator["States", None, None]:
    """
    Split this instance into smaller instances of the same class.

    The ``(start, end)`` windows come from
    ``similiar_chunks_indexes(self.n, n_chunks)``. For every window, each
    ``numpy.ndarray`` attribute is sliced with ``[start:end]`` and every other
    attribute is passed unchanged to the new instance.

    Args:
        n_chunks: Number of chunks requested from ``similiar_chunks_indexes``.

    Yields:
        A new instance of ``self.__class__`` whose ``batch_size`` is the
        number of rows contained in the window.
    """

    def chunk_size_of(start: int, end: int) -> int:
        # The first array attribute gives the real slice length, which is
        # shorter than ``end - start`` if the window overruns the array.
        for name in self._names:
            attr = self[name]
            if isinstance(attr, numpy.ndarray):
                return len(attr[start:end])
        # No array attributes: count the rows a length-``self.n`` batch has in
        # this window, so a shorter final window gets the right batch_size
        # instead of the full ceil(n / n_chunks).
        return len(range(self.n)[start:end])

    for start, end in similiar_chunks_indexes(self.n, n_chunks):
        data = {
            k: val[start:end] if isinstance(val, numpy.ndarray) else val
            for k, val in self.items()
        }
        yield self.__class__(batch_size=chunk_size_of(start, end), **data)