Example #1
0
def split_args_in_chunks(args, n_chunks):
    """Yield the positional args of ``make_transitions`` split into similar batches.

    Every element of ``args`` is sliced with the same chunk boundaries, so each
    yielded tuple lines up index-for-index with the original arguments. When the
    remaining tail would be tiny, it is folded into the current chunk so the last
    batch is never degenerate.
    """
    total = len(args[0])  # all args are assumed to share this length
    approx_batch = int(numpy.ceil(total / n_chunks))
    for lo, hi in similiar_chunks_indexes(total, n_chunks):
        if lo + approx_batch >= total - 2:
            # Final chunk: absorb everything that remains and stop iterating.
            yield tuple(value[lo:total] for value in args)
            return
        yield tuple(value[lo:hi] for value in args)
Example #2
0
def split_kwargs_in_chunks(kwargs, n_chunks):
    """Yield the kwargs of ``make_transitions`` split into similar batches.

    Only ``numpy.ndarray`` values are sliced; every other value is passed
    through unchanged in each chunk. A tail that would leave a size-1 final
    chunk is merged into the preceding one.
    """
    # Take the common length from any value — all arrays are assumed equal-length.
    total = len(next(iter(kwargs.values())))
    approx_batch = int(numpy.ceil(total / n_chunks))
    for lo, hi in similiar_chunks_indexes(total, n_chunks):
        is_last = lo + approx_batch >= total - 2  # do not allow the last chunk to have size 1
        stop = total if is_last else hi
        yield {
            key: (value[lo:stop] if isinstance(value, numpy.ndarray) else value)
            for key, value in kwargs.items()
        }
        if is_last:
            return
Example #3
0
    def split_states(self, n_chunks: int) -> Generator["States", None, None]:
        """
        Yield ``n_chunks`` new instances of this class, each holding one \
        contiguous slice of this instance's data.
        """
        def _chunk_length(state, lo, hi):
            # Measure the slice length on the first array attribute found;
            # fall back to an even-split estimate if no array attribute exists.
            for attr_name in state._names:
                value = state[attr_name]
                if isinstance(value, numpy.ndarray):
                    return len(value[lo:hi])
            return int(numpy.ceil(self.n / n_chunks))

        for lo, hi in similiar_chunks_indexes(self.n, n_chunks):
            sliced = {
                key: (value[lo:hi] if isinstance(value, numpy.ndarray) else value)
                for key, value in self.items()
            }
            yield self.__class__(batch_size=_chunk_length(self, lo, hi), **sliced)