def _sample_n(self, n, seed, **kwargs):
    """Draws `n` samples from the wrapped distribution with a replica-aware seed.

    The caller's seed is sanitized with the salt 'sharded_sample'. The replica
    id (cast to int32) is then folded into it. The result is that the seed
    handed to the wrapped distribution depends on `self.replica_id`.
    Presumably this decorrelates draws across replicas; that depends on
    `samplers.fold_in` semantics, so confirm there.

    Args:
      n: Forwarded as `sample_shape` to `self.distribution.sample`.
      seed: Any seed accepted by `samplers.sanitize_seed`.
      **kwargs: Extra keyword arguments forwarded unchanged to `sample`.

    Returns:
      The value returned by `self.distribution.sample`.
    """
    salted_seed = samplers.sanitize_seed(seed, salt='sharded_sample')
    replica_index = tf.cast(self.replica_id, tf.int32)
    replica_seed = samplers.fold_in(salted_seed, replica_index)
    return self.distribution.sample(
        sample_shape=n, seed=replica_seed, **kwargs)
def _sample_n(self, n, seed, **kwargs):
    """Delegates sampling to the parent class using a replica-dependent seed.

    The incoming seed is sanitized with the salt 'sharded_independent_sample'.
    `self.replica_id` (cast to int32) is then folded into it, and the result
    is passed to the next `_sample_n` in the MRO after `ShardedIndependent`.

    Args:
      n: Number of samples, passed positionally to the parent `_sample_n`.
      seed: Any seed accepted by `samplers.sanitize_seed`.
      **kwargs: Extra keyword arguments forwarded unchanged.

    Returns:
      The value returned by the parent class's `_sample_n`.
    """
    base_seed = samplers.sanitize_seed(
        seed, salt='sharded_independent_sample')
    replica_index = tf.cast(self.replica_id, tf.int32)
    replica_seed = samplers.fold_in(base_seed, replica_index)
    # Explicit two-argument super(): the enclosing class body is not visible
    # here, so the zero-argument form is not relied upon.
    parent = super(ShardedIndependent, self)
    return parent._sample_n(n, replica_seed, **kwargs)
def fold_in(seed, axes):
    """Folds the index along each named axis into `seed`, in order.

    For every name in `axes`, the current index along that axis is looked up
    with `get_axis_index`, cast to int32, and folded into the running seed
    with `samplers.fold_in`. An empty `axes` returns `seed` unchanged.

    Args:
      seed: Seed accepted by `samplers.fold_in`.
      axes: Iterable of axis names. Each element is passed to
        `get_axis_index`. NOTE(review): a bare string would be iterated
        character by character; callers presumably pass a list/tuple.

    Returns:
      The seed after all axis indices have been folded in.
    """
    folded = seed
    for axis_name in axes:
        index = tf.cast(get_axis_index(axis_name), tf.int32)
        folded = samplers.fold_in(folded, index)
    return folded
def _sample_n(self, n, seed, **kwargs):
    """Draws `n` samples with a seed that depends on every shard axis index.

    The seed is sanitized with the salt 'sharded_sample'. The process then
    walks `self.experimental_shard_axis_names` in order. For each name it
    folds in the int32-cast result of `distribute_lib.get_axis_index`. The
    final seed is handed to the wrapped distribution.

    Args:
      n: Forwarded as `sample_shape` to `self.distribution.sample`.
      seed: Any seed accepted by `samplers.sanitize_seed`.
      **kwargs: Extra keyword arguments forwarded unchanged to `sample`.

    Returns:
      The value returned by `self.distribution.sample`.
    """
    running_seed = samplers.sanitize_seed(seed, salt='sharded_sample')
    for shard_axis in self.experimental_shard_axis_names:
        position = distribute_lib.get_axis_index(shard_axis)
        running_seed = samplers.fold_in(
            running_seed, tf.cast(position, tf.int32))
    return self.distribution.sample(
        sample_shape=n, seed=running_seed, **kwargs)