Пример #1
0
def make_windowed_adapt_kernel(*, kind, proposal_kernel_kwargs,
                               dual_averaging_kwargs, initial_running_variance,
                               chain_axis_names, shard_axis_names):
    """Build a windowed adaptation kernel, optionally configured for sharding.

    The inner kernel comes from `make_slow_adapt_kernel` and is wrapped in
    `WindowedAdaptation`. The number of adaptation steps is read from
    `dual_averaging_kwargs['num_adaptation_steps']`.

    Args:
      kind: Forwarded unchanged to `make_slow_adapt_kernel`.
      proposal_kernel_kwargs: Forwarded unchanged to `make_slow_adapt_kernel`.
      dual_averaging_kwargs: Forwarded to `make_slow_adapt_kernel`; must also
        contain the key 'num_adaptation_steps'.
      initial_running_variance: Forwarded unchanged to
        `make_slow_adapt_kernel`.
      chain_axis_names: If truthy, the kernel is wrapped in
        `sharded.Sharded` over these axis names.
      shard_axis_names: If truthy, the kernel's
        `experimental_with_shard_axes` is applied with these names.

    Returns:
      The (possibly sharded) windowed adaptation kernel.

    Raises:
      KeyError: If 'num_adaptation_steps' is missing from
        `dual_averaging_kwargs`.
    """
    slow_kernel = make_slow_adapt_kernel(
        kind=kind,
        proposal_kernel_kwargs=proposal_kernel_kwargs,
        dual_averaging_kwargs=dual_averaging_kwargs,
        initial_running_variance=initial_running_variance)
    num_steps = dual_averaging_kwargs['num_adaptation_steps']
    adapt_kernel = WindowedAdaptation(slow_kernel,
                                      num_adaptation_steps=num_steps)

    # Chain-axis sharding wraps the kernel first; shard axes are attached
    # afterwards so they apply to the (possibly) sharded kernel.
    if chain_axis_names:
        adapt_kernel = sharded.Sharded(adapt_kernel, chain_axis_names)
    return (adapt_kernel.experimental_with_shard_axes(shard_axis_names)
            if shard_axis_names else adapt_kernel)
Пример #2
0
 def run(seed):
     """Take one step of a sharded `RandomWalk` kernel and return the new state.

     The kernel is `RandomWalk()` wrapped in `sharded.Sharded` over
     `self.axis_name` (captured from the enclosing scope). The step starts
     from `tf.convert_to_tensor(0.)`.

     Args:
       seed: Passed through as the `seed` argument of `one_step`.

     Returns:
       The state produced by `one_step`; its kernel results are discarded.
     """
     sharded_kernel = sharded.Sharded(RandomWalk(), self.axis_name)
     initial_state = tf.convert_to_tensor(0.)
     kernel_results = sharded_kernel.bootstrap_results(initial_state)
     next_state, _ = sharded_kernel.one_step(
         initial_state, kernel_results, seed=seed)
     return next_state