Example #1
def get_mean_field_elbo(model, target, num_mc_samples, model_args,
                        model_obs_kwargs, vi_kwargs):
  if FLAGS.reparameterise_variational and 'cVIP' in FLAGS.method:
    combined_kwargs = model_obs_kwargs.copy()
    combined_kwargs.update(vi_kwargs)
    variational_model, variational_parameters = make_variational_model_special(
        model, *model_args, **combined_kwargs)
  else:
    variational_model, variational_parameters = program_transformations.make_variational_model(
        model, *model_args, **model_obs_kwargs)

  log_joint_q = make_log_joint_fn(variational_model)

  def target_q(**parameters):
    return log_joint_q(*model_args, **parameters)


  def loop_body(mc_sample):  # pylint: disable=unused-argument
    with tape() as variational_tape:
      _ = variational_model(*model_args)

    # Evaluate the ELBO terms outside the tape context, so that re-running
    # the variational model inside `target_q` is not intercepted again.
    params = variational_tape.values()
    energy = target(*params)
    entropy = tf.negative(target_q(**variational_tape))
    return energy + entropy

  elbo = tf.reduce_sum(pfor(loop_body, num_mc_samples)) / num_mc_samples
  tf.summary.scalar('elbo', elbo)
  return elbo, variational_parameters
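The heart of this example is the pfor-based Monte-Carlo average. Below is a minimal, self-contained sketch of just that pattern, using a deterministic stand-in for the per-sample ELBO term; `mc_average` and the lambda are illustrative names, and the private `pfor` import path matches the one these snippets rely on.

import tensorflow as tf
from tensorflow.python.ops.parallel_for.control_flow_ops import pfor

def mc_average(single_term_fn, num_terms):
    # pfor traces `single_term_fn` once and evaluates all iterations in
    # parallel; summing and dividing yields the average, as in the ELBO above.
    return tf.reduce_sum(pfor(single_term_fn, num_terms)) / num_terms

avg = mc_average(lambda i: tf.cast(i, tf.float32)**2, 100)  # mean of 0..99 squared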
Example #2
def network_builder(x):
    """Wraps the function `network` to compute per-example features and Jacobians."""
    def loop_fn(i):
        x_i = tf.expand_dims(tf.gather(x, i), 0)
        features = network(x_i)
        jac = pfor.jacobian(features, params, use_pfor=use_pfor)
        return features, jac

    if use_pfor:
        features, jac = pfor.pfor(loop_fn, x.shape[0])
    else:
        loop_fn_dtypes = [tf.float32, [tf.float32] * len(params)]
        features, jac = pfor.for_loop(loop_fn, loop_fn_dtypes, x.shape[0])
        raise NotImplementedError(
            'use_pfor=False + per_example=True is not yet working.')
    features = _collapse_first_dim(features)
    features.set_shape(network(x).shape)
    jac = [_collapse_first_dim(y) for y in jac]
    for p, j in zip(params, jac):
        j.set_shape(features.shape.as_list() + p.shape.as_list())
    # Note: setting rank=2 so that we use matmul for covariance below
    # instead of batch_matmul.
    return features, jac
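`pfor.jacobian` above refers to the helper in TensorFlow's `parallel_for.gradients` module, which vectorizes the per-row gradient computations. A minimal sketch with a hypothetical linear network, assuming graph mode (the helper builds on `tf.gradients`):

import tensorflow.compat.v1 as tf
from tensorflow.python.ops.parallel_for import gradients

tf.disable_eager_execution()  # jacobian() builds on graph-mode tf.gradients

w = tf.ones([3, 2])
x = tf.ones([4, 3])                    # 4 examples, 3 input features
features = tf.matmul(x, w)             # shape [4, 2]
jac = gradients.jacobian(features, w)  # Tensor of shape [4, 2, 3, 2]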
Example #3
  def iid_sample_fn(*args, **kwargs):
    """Draws iid samples from `fn`."""

    with tf.name_scope('iid_sample_fn'):

      seed = kwargs.pop('seed', None)
      if samplers.is_stateful_seed(seed):
        kwargs = dict(kwargs, seed=SeedStream(seed, salt='iid_sample')())
        def pfor_loop_body(_):
          with tf.name_scope('iid_sample_fn_stateful_body'):
            return sample_fn(*args, **kwargs)
      else:
        # If a stateless seed arg is passed, split it into `n` different
        # stateless seeds, so that we don't just get a bunch of copies of the
        # same sample.
        if not JAX_MODE:
          warnings.warn(
              'Saw Tensor seed {}, implying stateless sampling. Autovectorized '
              'functions that use stateless sampling may be quite slow because '
              'the current implementation falls back to an explicit loop. This '
              'will be fixed in the future. For now, you will likely see '
              'better performance from stateful sampling, which you can invoke '
              'by passing a Python `int` seed.'.format(seed))
        seed = samplers.split_seed(seed, n=n, salt='iid_sample_stateless')
        def pfor_loop_body(i):
          with tf.name_scope('iid_sample_fn_stateless_body'):
            return sample_fn(*args, seed=tf.gather(seed, i), **kwargs)

      draws = parallel_for.pfor(pfor_loop_body, n)
      return tf.nest.map_structure(unflatten, draws, expand_composites=True)
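The stateless branch above hinges on `samplers.split_seed`, which derives `n` independent stateless seeds from a single seed so that the loop iterations do not repeat the same draw. A small sketch of that mechanism, assuming TFP's internal `samplers` module:

import tensorflow as tf
from tensorflow_probability.python.internal import samplers

n = 5
seed = samplers.sanitize_seed(42)  # canonicalize an int into a stateless seed
seeds = samplers.split_seed(seed, n=n, salt='iid_sample_stateless')
# Each row of `seeds` produces an independent draw:
draws = tf.stack([samplers.normal([], seed=seeds[i]) for i in range(n)])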
Example #4
    def iid_sample_fn(*args, **kwargs):
        """Draws iid samples from `fn`."""

        pfor_loop_body = lambda _: sample_fn(*args, **kwargs)

        seed = kwargs.pop('seed', None)
        try:  # Assume that `seed` is a valid stateful seed (Python `int`).
            kwargs = dict(kwargs, seed=SeedStream(seed, salt='iid_sample')())
            pfor_loop_body = lambda _: sample_fn(*args, **kwargs)
        except TypeError as e:
            # If a stateless seed arg is passed, split it into `n` different stateless
            # seeds, so that we don't just get a bunch of copies of the same sample.
            if TENSOR_SEED_MSG_PREFIX not in str(e):
                raise
            warnings.warn(
                'Saw non-`int` seed {}, implying stateless sampling. '
                'Autovectorized functions that use stateless sampling '
                'may be quite slow because the current implementation '
                'falls back to an explicit loop. This will be fixed in the '
                'future. For now, you will likely see better performance '
                'from stateful sampling, which you can invoke by passing a '
                'traditional Python `int` seed.'.format(seed))
            seed = samplers.split_seed(seed, n=n, salt='iid_sample_stateless')
            pfor_loop_body = (
                lambda i: sample_fn(*args, seed=tf.gather(seed, i), **kwargs))

        draws = parallel_for.pfor(pfor_loop_body, n)
        return tf.nest.map_structure(unflatten, draws, expand_composites=True)
Example #5
def transform_mcmc_states(states, transform_fn, num_chains=1):
    """Apply a joint transformation to each of a set of MCMC samples.

    Args:
        states: list of `Tensors`, such as returned from
            `tfp.mcmc.sample_chain`, where the `i`th element has shape
            `concat([[num_results], rv_shapes[i]])`, or shape
            `concat([[num_results], [num_chains], rv_shapes[i]])` if
            `num_chains > 1`.
        transform_fn: callable that takes as argument a single state of the
            chain, i.e., a list of `Tensors` where the `i`th element has shape
            `rv_shapes[i]` representing a single rv value, and returns a
            transformed state, i.e., a list of `Tensors` where the `i`th
            element has shape `transformed_rv_shapes[i]`.
        num_chains: Python `int` number of parallel chains in `states`.

    Returns:
        transformed_states: list of `Tensors` representing samples from a
            transformed model, where the `i`th element has shape
            `concat([[num_results], transformed_rv_shapes[i]])`.
    """

    num_samples = states[0].shape[0].value  # static (TF1-style) sample count

    def loop_body(sample_idx):
        return transform_fn(
            [tf.gather(rv_states, sample_idx) for rv_states in states])

    return pfor(loop_body, num_samples)
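A hypothetical call, assuming two random variables with 50 single-chain samples each; the transform exponentiates the first variable and passes the second through unchanged. A TF1 runtime is assumed, since the function reads `.shape[0].value`:

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # `.shape[0].value` requires TF1-style shapes

states = [tf.zeros([50, 3]), tf.zeros([50])]  # 50 samples each of two rvs
transformed = transform_mcmc_states(
    states, transform_fn=lambda s: [tf.exp(s[0]), s[1]])
# transformed[0] has shape [50, 3]; transformed[1] has shape [50].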
Example #6
  def test_sampled_scale_follows_correct_distribution(self):
    strm = test_util.test_seed_stream()
    prior = tfd.InverseGamma(concentration=0.1, scale=0.1)

    num_timesteps = 100
    observed_samples = tf.random.normal([2, num_timesteps], seed=strm()) * 3.
    is_missing = tf.random.uniform([2, num_timesteps], seed=strm()) > 0.9

    # Check that posterior variance samples have the moments of the correct
    # InverseGamma distribution.
    posterior_scale_samples = parallel_for.pfor(
        lambda i: gibbs_sampler._resample_scale(  # pylint: disable=g-long-lambda
            prior=prior,
            observed_residuals=observed_samples,
            is_missing=is_missing,
            seed=strm()), 10000)

    concentration = prior.concentration + tf.reduce_sum(
        1 - tf.cast(is_missing, tf.float32), axis=-1)/2.
    scale = prior.scale + tf.reduce_sum(
        (observed_samples * tf.cast(~is_missing, tf.float32))**2, axis=-1)/2.
    posterior_scale_samples_, concentration_, scale_ = self.evaluate(
        (posterior_scale_samples, concentration, scale))
    self.assertAllClose(np.mean(posterior_scale_samples_**2, axis=0),
                        scale_ / (concentration_ - 1), atol=0.05)
    self.assertAllClose(
        np.std(posterior_scale_samples_**2, axis=0),
        scale_ / ((concentration_ - 1) * np.sqrt(concentration_ - 2)),
        atol=0.05)
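The assertions compare moments of the squared scale samples against the conjugate InverseGamma posterior, using the identities mean = scale / (concentration - 1) and std = scale / ((concentration - 1) * sqrt(concentration - 2)). A quick sanity check of those identities against TFP (eager mode assumed):

import numpy as np
import tensorflow_probability as tfp

dist = tfp.distributions.InverseGamma(concentration=3., scale=2.)
np.testing.assert_allclose(dist.mean(), 2. / (3. - 1.))
np.testing.assert_allclose(
    dist.stddev(), 2. / ((3. - 1.) * np.sqrt(3. - 2.)))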
Example #7
def transform_mcmc_states(states, transform_fn):
    """Transforms all states using the provided transform function."""

    num_samples = FLAGS.num_samples
    num_chains = FLAGS.num_chains

    def loop_body(sample_idx):
        def loop_body_chain(chain_idx):
            # Nested pfor over chains, inside the outer pfor over samples.
            return transform_fn([
                tf.gather(tf.gather(rv_states, sample_idx), chain_idx)
                for rv_states in states
            ])

        if num_chains == 1:
            return tf.nest.map_structure(lambda x: tf.expand_dims(x, 0),
                                         loop_body_chain(0))

        return pfor(loop_body_chain, num_chains)

    if num_samples == 1:
        return tf.nest.map_structure(lambda x: tf.expand_dims(x, 0),
                                     loop_body(0))

    return pfor(loop_body, num_samples)
Example #8
    def loop_body(sample_idx):
        def loop_body_chain(chain_idx):
            # Nested pfor over chains, inside the outer pfor over samples.
            return transform_fn([
                tf.gather(tf.gather(rv_states, sample_idx), chain_idx)
                for rv_states in states
            ])

        return pfor(loop_body_chain, num_chains)
Example #9
  def test_pfor_with_closure_multi_out(self):
    val = np.arange(7.)[:, np.newaxis]
    tf_val = tf.constant(val)

    def tf_fn(x):
      return tf.gather(tf_val, x)**2, tf.gather(tf_val, x)

    def np_fn(x):
      return nptf.gather(val, x)**2, nptf.gather(val, x)

    self.assertAllEqual(
        self.evaluate(tf_pfor.pfor(tf_fn, 7)),
        np_pfor.pfor(np_fn, 7))
Example #10
  def test_sampled_weights_follow_correct_distribution(self):
    seed = test_util.test_seed(sampler_type='stateless')
    design_seed, true_weights_seed, sampled_weights_seed = samplers.split_seed(
        seed, 3, 'test_sampled_weights_follow_correct_distribution')
    num_timesteps = 10
    num_features = 2
    batch_shape = [3, 1]
    design_matrix = samplers.normal(
        batch_shape + [num_timesteps, num_features], seed=design_seed)
    true_weights = samplers.normal(
        batch_shape + [num_features, 1], seed=true_weights_seed) * 10.0
    targets = tf.matmul(design_matrix, true_weights)
    is_missing = tf.convert_to_tensor([False, False, False, True, True,
                                       False, False, True, False, False],
                                      dtype=tf.bool)
    prior_scale = tf.convert_to_tensor(5.)
    likelihood_scale = tf.convert_to_tensor(0.1)

    # Analytically compute the true posterior distribution on weights.
    valid_design_matrix = tf.boolean_mask(design_matrix, ~is_missing, axis=-2)
    valid_targets = tf.boolean_mask(targets, ~is_missing, axis=-2)
    num_valid_observations = tf.shape(valid_design_matrix)[-2]
    weights_posterior_mean, weights_posterior_cov, _ = linear_gaussian_update(
        prior_mean=tf.zeros([num_features, 1]),
        prior_cov=tf.eye(num_features) * prior_scale**2,
        observation_matrix=tfl.LinearOperatorFullMatrix(valid_design_matrix),
        observation_noise=tfd.MultivariateNormalDiag(
            loc=tf.zeros([num_valid_observations]),
            scale_diag=likelihood_scale * tf.ones([num_valid_observations])),
        x_observed=valid_targets)

    # Check that the empirical moments of sampled weights match the true values.
    sampled_weights = parallel_for.pfor(
        lambda i: gibbs_sampler._resample_weights(  # pylint: disable=g-long-lambda
            design_matrix=design_matrix,
            target_residuals=targets[..., 0],
            observation_noise_scale=likelihood_scale,
            weights_prior_scale=prior_scale,
            is_missing=is_missing,
            seed=sampled_weights_seed),
        10000)
    sampled_weights_mean = tf.reduce_mean(sampled_weights, axis=0)
    centered_weights = sampled_weights - weights_posterior_mean[..., 0]
    sampled_weights_cov = tf.reduce_mean(centered_weights[..., :, tf.newaxis] *
                                         centered_weights[..., tf.newaxis, :],
                                         axis=0)

    (sampled_weights_mean_, weights_posterior_mean_,
     sampled_weights_cov_, weights_posterior_cov_) = self.evaluate((
         sampled_weights_mean, weights_posterior_mean[..., 0],
         sampled_weights_cov, weights_posterior_cov))
    self.assertAllClose(sampled_weights_mean_, weights_posterior_mean_,
                        atol=0.01, rtol=0.05)
    self.assertAllClose(sampled_weights_cov_, weights_posterior_cov_,
                        atol=0.01, rtol=0.05)
Example #11
def vectorized_sample(model, model_args, num_samples):
    """Draw multiple joint samples from an Ed2 model."""
    def loop_body(i):  # trace the model to draw a single joint sample
        with ed.tape() as model_tape:
            model(*model_args)
        # pfor works with Tensors only, so extract RV values
        values = collections.OrderedDict(
            (k, rv.value) for k, rv in model_tape.items())
        return values

    return pfor(loop_body, num_samples)
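A hypothetical usage, assuming the Edward2 tracing library the snippet imports as `ed` (the exact import path varies across TFP/Edward2 versions); `toy_model` is illustrative:

from tensorflow_probability import edward2 as ed  # path varies by version

def toy_model():  # a two-variable model; each rv must be named
    z = ed.Normal(loc=0., scale=1., name='z')
    return ed.Normal(loc=z, scale=1., name='x')

samples = vectorized_sample(toy_model, model_args=(), num_samples=100)
# samples['z'] and samples['x'] are Tensors of shape [100].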
Example #12
    def loop_body(sample_idx):
        def loop_body_chain(chain_idx):
            # Nested pfor over chains, inside the outer pfor over samples.
            return transform_fn([
                tf.gather(tf.gather(rv_states, sample_idx), chain_idx)
                for rv_states in states
            ])

        if num_chains == 1:
            return tf.nest.map_structure(lambda x: tf.expand_dims(x, 0),
                                         loop_body_chain(0))

        return pfor(loop_body_chain, num_chains)
Example #13
    def vectorized_log_joint_fn(*args, **kwargs):
        x1 = args[0] if args else list(kwargs.values())[0]

        num_inputs = x1.shape[0]
        if not x1.shape.is_fully_defined():
            num_inputs = tf.shape(x1)[0]

        def loop_body(i):
            sliced_args = [tf.gather(v, i) for v in args]
            sliced_kwargs = {k: tf.gather(v, i) for k, v in kwargs.items()}
            return log_joint_fn(*sliced_args, **sliced_kwargs)

        result = pfor(loop_body, num_inputs)
        if not tf.is_tensor(num_inputs):  # only set shape when statically known
            result.set_shape([num_inputs])
        return result
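This closure presumably lives inside a factory that captures `log_joint_fn`. A hedged sketch of such a wrapper (`make_vectorized_log_joint_fn` is a hypothetical name), followed by a toy call:

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.python.ops.parallel_for.control_flow_ops import pfor

tfd = tfp.distributions

def make_vectorized_log_joint_fn(log_joint_fn):
    def vectorized_log_joint_fn(*args, **kwargs):
        x1 = args[0] if args else list(kwargs.values())[0]
        num_inputs = x1.shape[0]
        if not x1.shape.is_fully_defined():
            num_inputs = tf.shape(x1)[0]

        def loop_body(i):
            # Evaluate the log joint on the i-th slice of every argument.
            sliced_args = [tf.gather(v, i) for v in args]
            sliced_kwargs = {k: tf.gather(v, i) for k, v in kwargs.items()}
            return log_joint_fn(*sliced_args, **sliced_kwargs)

        return pfor(loop_body, num_inputs)
    return vectorized_log_joint_fn

# Toy call: one scalar log-density per batch row.
log_prob = make_vectorized_log_joint_fn(tfd.Normal(0., 1.).log_prob)
batched = log_prob(tf.zeros([8]))  # shape [8]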
Example #14
def transform_mcmc_states(states, transform_fn):
    """Transforms all states using the provided transform function."""

    num_samples = FLAGS.num_samples
    num_chains = FLAGS.num_chains

    def loop_body(sample_idx):
        def loop_body_chain(chain_idx):
            # Nested pfor over chains, inside the outer pfor over samples.
            return transform_fn([
                tf.gather(tf.gather(rv_states, sample_idx), chain_idx)
                for rv_states in states
            ])

        return pfor(loop_body_chain, num_chains)

    return pfor(loop_body, num_samples)
Example #15
  def test_pfor(self):
    self.assertAllEqual(
        self.evaluate(tf_pfor.pfor(lambda x: tf.ones([]), 7)),
        np_pfor.pfor(lambda x: nptf.ones([]), 7))
Example #16
    def vtransf(many_chains_sample):
        def loop_body(c):
            return transform(
                [tf.gather(rv_states, c) for rv_states in many_chains_sample])

        return pfor(loop_body, FLAGS.num_chains)