def _shuffle_kwargs(rng: np.random.Generator, kwargs: dict) -> dict: # We must shuffle all the lists, and lists of the same size must have the same shuffling. # This way entangled lists of (shard, shard_metadata) are still in the right order. # First, let's generate the shuffled indices per list size list_sizes = set( len(value) for value in kwargs.values() if isinstance(value, list)) indices_per_size = {} for size in list_sizes: indices_per_size[size] = list(range(size)) rng.shuffle(indices_per_size[size]) # Now let's copy the kwargs and shuffle the lists based on their sizes shuffled_kwargs = dict(kwargs) for key, value in shuffled_kwargs.items(): if isinstance(value, list): shuffled_kwargs[key] = [ value[i] for i in indices_per_size[len(value)] ] return shuffled_kwargs
def random_layer_iterator(self, rng: np.random.Generator) -> Iterator[int]:
    """Return an iterator over the keys of ``self.layer_map`` in random order.

    Args:
        rng: NumPy random generator used to shuffle the keys.

    Returns:
        A one-shot iterator that yields each key of ``self.layer_map``
        exactly once, in an order drawn from ``rng``.
    """
    # Copy the keys into a fresh list so they can be shuffled in place
    # without touching ``self.layer_map``.
    # NOTE(review): keys are assumed to be ints, per the return annotation.
    layer_ins = list(self.layer_map.keys())
    rng.shuffle(layer_ins)
    # Wrap in iter() so the result really is an Iterator: a bare list is only
    # Iterable, and calling next() on it would raise TypeError.
    return iter(layer_ins)