Ejemplo n.º 1
0
def SanitizedAutoCorrelationMean(x,
                                 axis,
                                 reduce_axis,
                                 max_lags=None,
                                 **kwargs):
    """Averages `SanitizedAutoCorrelation` over slices of `x` along `reduce_axis`.

    `x` is split along `reduce_axis`, `SanitizedAutoCorrelation` is evaluated on
    each slice, and the results are folded one at a time into a running mean
    via `fun_mcmc.running_mean_step`.

    Args:
        x: Tensor whose shape is read through `x.shape` to size the running
            mean.
        axis: Axis forwarded to `SanitizedAutoCorrelation`. It indexes into a
            single slice, i.e. into the shape of `x` with `reduce_axis`
            removed.
        reduce_axis: Axis of `x` to average over. Negative values count from
            the end.
        max_lags: If not `None`, forwarded to `SanitizedAutoCorrelation`, and
            the running mean is sized `max_lags + 1` along `axis`.
        **kwargs: Forwarded unchanged to `SanitizedAutoCorrelation`.

    Returns:
        The `mean` field of the final running-mean state.
    """
    shape_arr = np.array(list(x.shape))
    rank = len(shape_arr)
    # Normalize a negative axis first: otherwise the filter below would not
    # drop it and the running mean would be initialized with the full rank.
    if reduce_axis < 0:
        reduce_axis += rank
    axes = [i for i in range(rank) if i != reduce_axis]
    mean_shape = shape_arr[axes]
    if max_lags is not None:
        mean_shape[axis] = max_lags + 1
    mean_state = fun_mcmc.running_mean_init(mean_shape, x.dtype)
    # Move `reduce_axis` to the front while keeping the remaining axes in their
    # original order, so every slice is laid out the way `mean_shape` and
    # `axis` expect. (Swapping it with axis 0 would reorder the remaining axes
    # whenever `reduce_axis >= 2`.)
    x = tf.transpose(x, [reduce_axis] + axes)
    x_arr = tf.TensorArray(x.dtype, x.shape[0]).unstack(x)
    # NOTE(review): `state.num_points` is used as the index of the next slice
    # to read; this presumes it counts the slices folded in so far - confirm
    # against fun_mcmc.
    mean_state, _ = fun_mcmc.trace(
        state=mean_state,
        fn=lambda state: fun_mcmc.running_mean_step(  # pylint: disable=g-long-lambda
            state,
            SanitizedAutoCorrelation(x_arr.read(state.num_points),
                                     axis,
                                     max_lags=max_lags,
                                     **kwargs)),
        num_steps=x.shape[0],
        trace_fn=lambda *_: ())
    return mean_state.mean
Ejemplo n.º 2
0
 def kernel(rms, idx):
   """Folds `data[idx]` into the running mean and advances the index.

   `data` and `window_size` are taken from the enclosing scope.

   Returns:
     A pair `((rms, idx + 1), rms.mean)`, with `rms` being the updated state.
   """
   sample = data[idx]
   rms, unused_extra = fun_mcmc.running_mean_step(
       rms, sample, window_size=window_size)
   next_state = (rms, idx + 1)
   return next_state, rms.mean
Ejemplo n.º 3
0
 def kernel(rms, idx):
   """Folds `data[idx]` into the running mean and advances the index.

   `data` and `aggregation` are taken from the enclosing scope; `aggregation`
   is forwarded as the `axis` argument of `fun_mcmc.running_mean_step`.

   Returns:
     A pair `((rms, idx + 1), ())`, with `rms` being the updated state.
   """
   sample = data[idx]
   rms, unused_extra = fun_mcmc.running_mean_step(rms, sample, axis=aggregation)
   next_state = (rms, idx + 1)
   return next_state, ()
Ejemplo n.º 4
0
 def kernel(sda_state, rms_state):
   """Takes one dual-averages step and folds its state into the running mean.

   `loss_fn` is taken from the enclosing scope.

   Returns:
     A pair `((sda_state, rms_state), rms_state.mean)` built from the updated
     states.
   """
   # NOTE(review): the literal 1. is the third positional argument of
   # simple_dual_averages_step - presumably a step size; confirm.
   sda_state, unused_sda_extra = fun_mcmc.simple_dual_averages_step(
       sda_state, loss_fn, 1.)
   iterate = sda_state.state
   rms_state, unused_rms_extra = fun_mcmc.running_mean_step(rms_state, iterate)
   next_state = (sda_state, rms_state)
   return next_state, rms_state.mean