def make_windowed_adapt_kernel(*, kind, proposal_kernel_kwargs,
                               dual_averaging_kwargs, initial_running_variance,
                               chain_axis_names, shard_axis_names):
  """Constructs a windowed adaptation kernel.

  Builds an inner kernel with `make_slow_adapt_kernel` and wraps it in
  `WindowedAdaptation`. The wrapper's `num_adaptation_steps` comes from
  `dual_averaging_kwargs['num_adaptation_steps']`. The result is then
  optionally wrapped for sharded execution.

  Args:
    kind: Forwarded unchanged to `make_slow_adapt_kernel` (presumably selects
      the proposal kernel family -- confirm against that helper).
    proposal_kernel_kwargs: Forwarded unchanged to `make_slow_adapt_kernel`.
    dual_averaging_kwargs: Forwarded to `make_slow_adapt_kernel`. Its
      `'num_adaptation_steps'` entry is also passed to `WindowedAdaptation`.
    initial_running_variance: Forwarded unchanged to `make_slow_adapt_kernel`.
    chain_axis_names: If truthy, the windowed kernel is wrapped in
      `sharded.Sharded` with these axis names.
    shard_axis_names: If truthy, `experimental_with_shard_axes` is called on
      the (possibly `Sharded`-wrapped) kernel with these names.

  Returns:
    The assembled kernel.

  Raises:
    KeyError: presumably, when `dual_averaging_kwargs` is a dict lacking a
      `'num_adaptation_steps'` entry (raised after the inner kernel is built).
  """
  slow_kernel = make_slow_adapt_kernel(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      dual_averaging_kwargs=dual_averaging_kwargs,
      initial_running_variance=initial_running_variance)
  # Looked up only after the inner kernel exists, matching the original
  # left-to-right argument evaluation order.
  adaptation_steps = dual_averaging_kwargs['num_adaptation_steps']
  windowed = WindowedAdaptation(
      slow_kernel, num_adaptation_steps=adaptation_steps)

  # Chain-axis sharding is applied first; shard axes are layered on top.
  maybe_sharded = (sharded.Sharded(windowed, chain_axis_names)
                   if chain_axis_names else windowed)
  if not shard_axis_names:
    return maybe_sharded
  return maybe_sharded.experimental_with_shard_axes(shard_axis_names)
def run(seed):
  """Takes a single `RandomWalk` step through a `sharded.Sharded` wrapper.

  NOTE(review): `self` is a free variable here. This helper presumably lives
  inside a test method and reads that test's `axis_name`; confirm in context.

  Args:
    seed: Seed forwarded to `one_step`.

  Returns:
    The state returned by one `one_step` call that starts from the scalar
    tensor `0.`.
  """
  walk_kernel = sharded.Sharded(RandomWalk(), self.axis_name)
  start = tf.convert_to_tensor(0.)
  bootstrap = walk_kernel.bootstrap_results(start)
  # `one_step` yields (new_state, new_kernel_results); only the state is kept.
  next_state, _ = walk_kernel.one_step(start, bootstrap, seed=seed)
  return next_state