Esempio n. 1
0
def SanitizedAutoCorrelationMean(x,
                                 axis,
                                 reduce_axis,
                                 max_lags=None,
                                 **kwargs):
    """Averages `SanitizedAutoCorrelation` over the slices of `x` along an axis.

    `x` is unstacked along `reduce_axis`. For each resulting slice the
    auto-correlation along `axis` is computed with `SanitizedAutoCorrelation`
    and folded into a running mean with `fun_mc.running_mean_step`, driven by
    `fun_mc.trace` for one step per slice.

    Args:
      x: Tensor whose shape is statically known (its shape is converted to a
        NumPy array, and its leading dimension after transposition is used as
        the number of steps).
      axis: Auto-correlation axis, indexed within a single slice, i.e. in the
        shape of `x` with `reduce_axis` removed.
      reduce_axis: Axis of `x` to average over. Negative values count from the
        end.
      max_lags: Optional maximum lag forwarded to `SanitizedAutoCorrelation`.
        When set, the mean's size along `axis` is taken to be `max_lags + 1`.
      **kwargs: Extra keyword arguments forwarded to
        `SanitizedAutoCorrelation`.

    Returns:
      The running mean of the per-slice auto-correlations. Its shape is the
      shape of `x` without `reduce_axis`, with `axis` resized to
      `max_lags + 1` when `max_lags` is given.

    Raises:
      ValueError: If `reduce_axis` is out of range for the rank of `x`.
    """
    shape_arr = np.array(list(x.shape))
    ndims = len(shape_arr)
    if not -ndims <= reduce_axis < ndims:
        raise ValueError(
            'reduce_axis={} is out of range for a tensor of rank {}.'.format(
                reduce_axis, ndims))
    # Normalize negative indices. Otherwise `reduce_axis` would not be removed
    # from `axes` below, and the mean would keep a spurious extra dimension.
    reduce_axis %= ndims
    # Remaining axes, in their original relative order.
    axes = [i for i in range(ndims) if i != reduce_axis]
    # Advanced indexing returns a copy, so editing `mean_shape` leaves
    # `shape_arr` untouched.
    mean_shape = shape_arr[axes]
    if max_lags is not None:
        mean_shape[axis] = max_lags + 1
    mean_state = fun_mc.running_mean_init(mean_shape, x.dtype)
    # Move `reduce_axis` to the front while keeping the other axes in order, so
    # every slice read below has exactly the layout of `mean_shape`. Merely
    # swapping axes 0 and `reduce_axis` would reorder the remaining axes (and
    # mismatch `mean_shape` / `axis`) whenever `reduce_axis >= 2`.
    x = tf.transpose(x, [reduce_axis] + axes)
    x_arr = tf.TensorArray(x.dtype, x.shape[0]).unstack(x)
    mean_state, _ = fun_mc.trace(
        state=mean_state,
        # `state.num_points` is used as the index of the next slice; presumably
        # it counts the points folded into the running mean so far.
        fn=lambda state: fun_mc.running_mean_step(  # pylint: disable=g-long-lambda
            state,
            SanitizedAutoCorrelation(x_arr.read(state.num_points),
                                     axis,
                                     max_lags=max_lags,
                                     **kwargs)),
        num_steps=x.shape[0],
        trace_fn=lambda *_: ())
    return mean_state.mean
    def testBasic(self):
        """`fun_mc.trace` hands back its final state as a `tf.Tensor`."""

        def step_fn(state):
            # Returns a pair: the next state and an extra per-step output.
            return state + 1., 2 * state

        final_state, _ = fun_mc.trace(state=0., fn=step_fn, num_steps=5)
        self.assertIsInstance(final_state, tf.Tensor)
Esempio n. 3
0
def train_p(q, u, x_pos, step_size, opt_p):
    """Train P using the standard CD objective.

    Negative samples are drawn from `q`. When `FLAGS.use_mcmc` is set they are
    refined by `FLAGS.mcmc_num_steps` HMC steps targeting the unnormalized
    log-density `-u(x)`. The energy `u` then takes one optimizer step on
    `mean(u(x_pos)) - mean(u(x_neg_p))
    + FLAGS.p_center_regularizer * mean(u(x_pos))**2`.

    Args:
        q: `ModelQ`.
        u: A callable representing the energy function.
        x_pos: A batch of positive examples.
        step_size: Step size to use for HMC.
        opt_p: A `tf.optimizer.Optimizer`.

    Returns:
        x_neg_q: Negative samples sampled from `q`.
        x_neg_p: Negative samples used to train `p`, possibly generated via
            HMC (equal to `x_neg_q` when `FLAGS.use_mcmc` is off).
        p_accept: Acceptance rate of HMC (0.0 when `FLAGS.use_mcmc` is off).
        step_size: The new step size, possibly adapted to adjust the
            acceptance rate (0.0 when `FLAGS.use_mcmc` is off).
        pos_e: Mean energy of the positive samples across the batch.
        pos_e_updated: Mean energy of the positive samples across the batch,
            after the parameter update.
        neg_e_q: Mean energy of `x_neg_q` across the batch.
        neg_e_p: Mean energy of `x_neg_p` across the batch.
        neg_e_p_updated: Mean energy of `x_neg_p` across the batch, after the
            parameter update.
    """
    def create_momentum_sample_fn(state):
        # NOTE(review): `seed` is accepted but ignored, so momentum draws use
        # TF's global RNG rather than any seed passed in by fun_mc; confirm
        # this is intended.
        # NOTE(review): no `kinetic_energy_fn` is passed to the HMC step below.
        # If fun_mc's default kinetic energy assumes unit-variance momentum, a
        # non-unit `FLAGS.mcmc_momentum_stddev` would make HMC target the wrong
        # distribution; verify.
        sample_fn = lambda seed: tf.random.normal(  # pylint: disable=g-long-lambda
            tf.shape(state),
            stddev=FLAGS.mcmc_momentum_stddev)
        return sample_fn

    _, x_neg_q, _ = q.sample_with_log_prob(FLAGS.batch_size,
                                           temp=FLAGS.q_temperature)
    neg_e_q = tf.reduce_mean(u(x_neg_q))

    def p_log_prob(x):
        # Unnormalized log-density of `p`: the negated energy.
        return -u(x)

    if FLAGS.use_mcmc:

        def log_prob_non_transformed(x):
            # HMC target in the original `x` space. `(x,)` is returned as the
            # extra output so the final sample can be read back from the state.
            p_log_p = p_log_prob(x)

            return p_log_p, (x, )

        # TODO(siege): Why aren't we actually using NeuTra?
        # def log_prob_transformed(z):
        #   x, logdet = q.reverse(z)
        #   p_log_p = p_log_prob(x)

        #   return p_log_p + logdet, (x,)

        def kernel(hmc_state, step_size, step):
            """One HMC step, optionally followed by step-size adaptation."""
            hmc_state, hmc_extra = fun_mc.hamiltonian_monte_carlo_step(
                hmc_state,
                step_size=step_size,
                num_integrator_steps=FLAGS.mcmc_leapfrog_steps,
                momentum_sample_fn=create_momentum_sample_fn(hmc_state.state),
                target_log_prob_fn=log_prob_non_transformed)

            # exp(min(0, log r)) == min(1, r): the Metropolis-Hastings
            # acceptance probability, averaged over the batch.
            mean_p_accept = tf.reduce_mean(
                tf.exp(tf.minimum(0., hmc_extra.log_accept_ratio)))

            if FLAGS.mcmc_adapt_step_size:
                # Presumably nudges `step_size` so that the acceptance
                # probability moves toward the 0.9 set point.
                step_size = fun_mc.sign_adaptation(step_size,
                                                   output=mean_p_accept,
                                                   set_point=0.9)

            return (hmc_state, step_size, step + 1), hmc_extra

        hmc_state, is_accepted = fun_mc.trace(
            state=(fun_mc.hamiltonian_monte_carlo_init(
                x_neg_q, log_prob_non_transformed), step_size, 0),
            fn=kernel,
            num_steps=FLAGS.mcmc_num_steps,
            trace_fn=lambda _, hmc_extra: hmc_extra.is_accepted)

        # The final kernel state is the tuple (hmc_state, step_size, step).
        # `state_extra` presumably holds the `(x,)` extra returned by
        # `log_prob_non_transformed`.
        x_neg_p = hmc_state[0].state_extra[0]
        step_size = hmc_state[1]

        # Fraction of accepted proposals over all traced steps and chains.
        p_accept = tf.reduce_mean(tf.cast(is_accepted, tf.float32))
    else:
        x_neg_p = x_neg_q
        p_accept = 0.0
        # NOTE(review): this discards the caller's `step_size`; confirm callers
        # do not feed the returned value back expecting it to be preserved.
        step_size = 0.0

    with tf.GradientTape() as tape:
        # Trainable variables are watched automatically; this explicit watch
        # is redundant but harmless.
        tape.watch(u.trainable_variables)
        pos_e = tf.reduce_mean(u(x_pos))
        neg_e_p = tf.reduce_mean(u(x_neg_p))
        # Contrastive-divergence loss: lower the energy of positives and raise
        # it for negatives. The squared term penalizes the mean positive
        # energy, presumably to keep the energy scale anchored near zero.
        loss = pos_e - neg_e_p + tf.square(pos_e) * FLAGS.p_center_regularizer

    # Read after the forward pass so that any variables created lazily on the
    # first call of `u` are included.
    variables = u.trainable_variables
    grads = tape.gradient(loss, variables)
    grads_and_vars = list(zip(grads, variables))
    opt_p.apply_gradients(grads_and_vars)

    # Re-evaluate energies with the updated parameters, for monitoring.
    pos_e_updated = tf.reduce_mean(u(x_pos))
    neg_e_p_updated = tf.reduce_mean(u(x_neg_p))

    return (x_neg_q, x_neg_p, p_accept, step_size, pos_e, pos_e_updated,
            neg_e_q, neg_e_p, neg_e_p_updated)