def get_mean_field_elbo(model, target, num_mc_samples, model_args,
                        model_obs_kwargs, vi_kwargs):
  """Builds a Monte Carlo estimate of the ELBO for a variational model.

  The variational model is derived from `model`. When
  `FLAGS.reparameterise_variational` is set and `FLAGS.method` contains
  'cVIP', `make_variational_model_special` is used with `vi_kwargs` merged
  over `model_obs_kwargs`; otherwise
  `program_transformations.make_variational_model` is used with
  `model_obs_kwargs` only.

  Each Monte Carlo sample runs the variational model under a tape, scores the
  recorded values with `target` and with the variational log joint, and
  contributes `target(values) - log_q(values)`. The per-sample terms are
  vectorised with `pfor` and averaged.

  Args:
    model: model function the variational family is built from.
    target: callable taking the taped values as positional arguments and
      returning the target log density.
    num_mc_samples: number of Monte Carlo samples in the estimate.
    model_args: positional arguments forwarded to the (variational) model.
    model_obs_kwargs: keyword arguments used when constructing the
      variational model.
    vi_kwargs: extra keyword arguments, only used on the special
      reparameterised 'cVIP' path (they override `model_obs_kwargs`).

  Returns:
    A tuple `(elbo, variational_parameters)`. `elbo` is also recorded as the
    scalar summary 'elbo'.
  """
  use_special_construction = (
      FLAGS.reparameterise_variational and 'cVIP' in FLAGS.method)
  if use_special_construction:
    merged_kwargs = {**model_obs_kwargs, **vi_kwargs}
    variational_model, variational_parameters = (
        make_variational_model_special(model, *model_args, **merged_kwargs))
  else:
    variational_model, variational_parameters = (
        program_transformations.make_variational_model(
            model, *model_args, **model_obs_kwargs))

  log_joint_q = make_log_joint_fn(variational_model)

  def single_sample_elbo(_):
    # Draw one joint sample from q and record every random variable's value.
    with tape() as q_tape:
      variational_model(*model_args)
    log_target = target(*q_tape.values())
    log_q = log_joint_q(*model_args, **q_tape)
    return log_target - log_q

  per_sample_elbo = pfor(single_sample_elbo, num_mc_samples)
  elbo = tf.reduce_sum(per_sample_elbo) / num_mc_samples
  tf.summary.scalar('elbo', elbo)
  return elbo, variational_parameters
def network_builder(x):
  """Wraps the function `network` to compute per-example Jacobians.

  For each row `x[i]`, `network` is evaluated on a batch of one and the
  Jacobian of its output with respect to every tensor in `params` is taken.
  The per-example results are vectorised with `pfor.pfor`.

  Args:
    x: input tensor whose leading dimension indexes examples. `x.shape[0]` is
      used as the iteration count.

  Returns:
    A tuple `(features, jac)`:
      features: per-example outputs with the `pfor` dimension collapsed by
        `_collapse_first_dim`, given the static shape of `network(x)`.
      jac: list aligned with `params`; entry `j` has static shape
        `features.shape + params[j].shape`.

  Raises:
    NotImplementedError: if `use_pfor` is False. The explicit `for_loop`
      path does not work yet. The previous implementation built that loop
      over the wrong (undefined) tensor `data` before raising, so we now fail
      fast instead.
  """
  if not use_pfor:
    raise NotImplementedError(
        'use_pfor=False + per_example=True is not yet working.'
    )

  def loop_fn(i):
    # Keep a leading batch dim of 1 so `network` sees its usual input rank.
    x_i = tf.expand_dims(tf.gather(x, i), 0)
    features = network(x_i)
    jac = pfor.jacobian(features, params, use_pfor=use_pfor)
    return features, jac

  features, jac = pfor.pfor(loop_fn, x.shape[0])
  features = _collapse_first_dim(features)
  # `pfor` loses static shape information; restore it from a direct call.
  features.set_shape(network(x).shape)
  jac = [_collapse_first_dim(y) for y in jac]
  for p, j in zip(params, jac):
    j.set_shape(features.shape.as_list() + p.shape.as_list())
  return features, jac
def iid_sample_fn(*args, **kwargs):
  """Draws iid samples from `fn`.

  `sample_fn(*args, **kwargs)` is vectorised over `n` iterations with
  `parallel_for.pfor`, and `unflatten` is applied to every leaf of the
  result.

  Seed handling:
    * a stateful seed is replaced by one fresh seed drawn from
      `SeedStream(seed, salt='iid_sample')`;
    * any other (stateless, Tensor) seed is split into `n` seeds so each
      iteration gets its own, rather than `n` copies of the same sample.
  """
  with tf.name_scope('iid_sample_fn'):
    seed = kwargs.pop('seed', None)
    seed_is_stateful = samplers.is_stateful_seed(seed)

    if seed_is_stateful:
      kwargs = dict(kwargs, seed=SeedStream(seed, salt='iid_sample')())

      def draw_one(_):
        with tf.name_scope('iid_sample_fn_stateful_body'):
          return sample_fn(*args, **kwargs)
    else:
      # Stateless seeds are slow under pfor outside JAX; tell the user.
      if not JAX_MODE:
        warnings.warn(
            'Saw Tensor seed {}, implying stateless sampling. Autovectorized '
            'functions that use stateless sampling may be quite slow because '
            'the current implementation falls back to an explicit loop. This '
            'will be fixed in the future. For now, you will likely see '
            'better performance from stateful sampling, which you can invoke '
            'by passing a Python `int` seed.'.format(seed))
      seed = samplers.split_seed(seed, n=n, salt='iid_sample_stateless')

      def draw_one(i):
        with tf.name_scope('iid_sample_fn_stateless_body'):
          return sample_fn(*args, seed=tf.gather(seed, i), **kwargs)

    draws = parallel_for.pfor(draw_one, n)
    return tf.nest.map_structure(unflatten, draws, expand_composites=True)
def iid_sample_fn(*args, **kwargs):
  """Draws iid samples from `fn`.

  `sample_fn` is vectorised over `n` iterations with `parallel_for.pfor`, and
  `unflatten` is applied to every leaf of the result.

  The `seed` keyword is first tried as a stateful seed via `SeedStream`. If
  that raises a `TypeError` whose message contains `TENSOR_SEED_MSG_PREFIX`,
  the seed is treated as stateless and split into `n` seeds, one per
  iteration. Any other `TypeError` is re-raised.
  """
  seed = kwargs.pop('seed', None)
  try:
    # Assume that `seed` is a valid stateful seed (Python `int`).
    stateful_seed = SeedStream(seed, salt='iid_sample')()
  except TypeError as e:
    # If a stateless seed arg is passed, split it into `n` different stateless
    # seeds, so that we don't just get a bunch of copies of the same sample.
    if TENSOR_SEED_MSG_PREFIX not in str(e):
      raise
    warnings.warn(
        'Saw non-`int` seed {}, implying stateless sampling. '
        'Autovectorized functions that use stateless sampling '
        'may be quite slow because the current implementation '
        'falls back to an explicit loop. This will be fixed in the '
        'future. For now, you will likely see better performance '
        'from stateful sampling, which you can invoke by passing a '
        'traditional Python `int` seed.'.format(seed))
    per_iteration_seeds = samplers.split_seed(
        seed, n=n, salt='iid_sample_stateless')

    def pfor_loop_body(i):
      return sample_fn(
          *args, seed=tf.gather(per_iteration_seeds, i), **kwargs)
  else:
    kwargs = dict(kwargs, seed=stateful_seed)

    def pfor_loop_body(_):
      return sample_fn(*args, **kwargs)

  draws = parallel_for.pfor(pfor_loop_body, n)
  return tf.nest.map_structure(unflatten, draws, expand_composites=True)
def transform_mcmc_states(states, transform_fn, num_chains=1):
  """Apply a joint transformation to each of a set of MCMC samples.

  Args:
    states: list of `Tensors`, such as returned from `tfp.mcmc.sample_chain`,
      where the `i`th element has shape `concat([[num_results],
      rv_shapes[i]])`, or shape `concat([[num_results], [num_chains],
      rv_shapes[i]])` if num_chains > 1.
    transform_fn: callable that takes as argument a single state of the chain,
      i.e., a list of `Tensors` where the `i`th element has shape
      `rv_shapes[i]` representing a single rv value, and returns a transformed
      state, i.e., a list of `Tensors` where the `i`th element has shape
      `transformed_rv_shapes[i]`.
    num_chains: number of chains stacked along axis 1 of every state. With
      `num_chains > 1`, `transform_fn` is applied to each chain separately,
      so it always receives a single unbatched state.

  Returns:
    transformed_states: list of `Tensors` representing samples from a
      transformed model, where the `i`th element has shape
      `concat([[num_results], transformed_rv_shapes[i]])`, or
      `concat([[num_results], [num_chains], transformed_rv_shapes[i]])` if
      num_chains > 1.
  """
  # `tf.compat.dimension_value` works for both TF1 `Dimension` and TF2 ints.
  num_samples = tf.compat.dimension_value(states[0].shape[0])
  if num_samples is None:
    # Static length unknown: fall back to the dynamic shape.
    num_samples = tf.shape(states[0])[0]

  def loop_body(sample_idx):
    sample_states = [tf.gather(rv_states, sample_idx) for rv_states in states]
    if num_chains <= 1:
      return transform_fn(sample_states)

    # Each sample still carries a chain dimension. Transform chain-by-chain
    # so `transform_fn` only ever sees a single state, as documented.
    def chain_body(chain_idx):
      return transform_fn(
          [tf.gather(rv_state, chain_idx) for rv_state in sample_states])

    return pfor(chain_body, num_chains)

  return pfor(loop_body, num_samples)
def test_sampled_scale_follows_correct_distribution(self):
  """Resampled scales should match the expected InverseGamma moments.

  Squared scale draws (variances) are compared against the mean and standard
  deviation of an InverseGamma whose parameters are computed below from the
  prior and the non-missing observations.
  """
  seed_stream = test_util.test_seed_stream()
  prior = tfd.InverseGamma(concentration=0.1, scale=0.1)
  num_timesteps = 100
  num_draws = 10000
  observed_samples = tf.random.normal(
      [2, num_timesteps], seed=seed_stream()) * 3.
  is_missing = tf.random.uniform(
      [2, num_timesteps], seed=seed_stream()) > 0.9

  def draw_scale(_):
    return gibbs_sampler._resample_scale(
        prior=prior,
        observed_residuals=observed_samples,
        is_missing=is_missing,
        seed=seed_stream())

  posterior_scale_samples = parallel_for.pfor(draw_scale, num_draws)

  # Expected posterior parameters: concentration grows by half the number of
  # observed points, scale by half the sum of squared observed residuals.
  num_observed = tf.reduce_sum(1 - tf.cast(is_missing, tf.float32), axis=-1)
  concentration = prior.concentration + num_observed / 2.
  sum_sq_observed = tf.reduce_sum(
      (observed_samples * tf.cast(~is_missing, tf.float32))**2, axis=-1)
  scale = prior.scale + sum_sq_observed / 2.

  scale_samples_, concentration_, scale_ = self.evaluate(
      (posterior_scale_samples, concentration, scale))
  variance_samples_ = scale_samples_**2
  expected_mean = scale_ / (concentration_ - 1)
  expected_std = scale_ / (
      (concentration_ - 1) * np.sqrt(concentration_ - 2))

  self.assertAllClose(np.mean(variance_samples_, axis=0), expected_mean,
                      atol=0.05)
  self.assertAllClose(np.std(variance_samples_, axis=0), expected_std,
                      atol=0.05)
def transform_mcmc_states(states, transform_fn):
  """Transforms all states using the provided transform function.

  Every tensor in `states` is indexed as `rv_states[sample_idx][chain_idx]`
  for `sample_idx < FLAGS.num_samples` and `chain_idx < FLAGS.num_chains`.
  `transform_fn` receives one list of per-RV values per (sample, chain) pair.

  Returns:
    The structure returned by `transform_fn`, with leading
    `[num_samples, num_chains]` dimensions stacked on. When either count is 1,
    `pfor` is skipped and the dimension is added with `tf.expand_dims`.
  """
  num_samples = FLAGS.num_samples
  num_chains = FLAGS.num_chains

  def add_leading_dim(structure):
    return tf.nest.map_structure(lambda x: tf.expand_dims(x, 0), structure)

  def loop_body(sample_idx):
    def loop_body_chain(chain_idx):
      return transform_fn([
          tf.gather(tf.gather(rv_states, sample_idx), chain_idx)
          for rv_states in states
      ])

    if num_chains == 1:
      return add_leading_dim(loop_body_chain(0))
    return pfor(loop_body_chain, num_chains)

  if num_samples == 1:
    return add_leading_dim(loop_body(0))
  return pfor(loop_body, num_samples)
def loop_body(sample_idx):
  """Applies `transform_fn` to every chain of MCMC sample `sample_idx`.

  Each state tensor is indexed as `rv_states[sample_idx][chain_idx]`.
  Chains are vectorised with `pfor` over `num_chains` iterations.
  """
  def loop_body_chain(chain_idx):
    return transform_fn([
        tf.gather(tf.gather(rv_states, sample_idx), chain_idx)
        for rv_states in states
    ])

  return pfor(loop_body_chain, num_chains)
def test_pfor_with_closure_multi_out(self):
  """TF and NumPy `pfor` agree for a closure returning two outputs."""
  values = np.arange(7.)[:, np.newaxis]
  tf_values = tf.constant(values)
  num_iters = 7

  def tf_body(idx):
    row = tf.gather(tf_values, idx)
    return row**2, row

  def np_body(idx):
    row = nptf.gather(values, idx)
    return row**2, row

  tf_result = self.evaluate(tf_pfor.pfor(tf_body, num_iters))
  np_result = np_pfor.pfor(np_body, num_iters)
  self.assertAllEqual(tf_result, np_result)
def test_sampled_weights_follow_correct_distribution(self):
  """Sampled regression weights should match the analytic Gaussian posterior.

  Draws `num_draws` weight samples with `gibbs_sampler._resample_weights` and
  compares their empirical mean and covariance with the posterior computed by
  `linear_gaussian_update` on the non-missing observations.
  """
  seed = test_util.test_seed(sampler_type='stateless')
  design_seed, true_weights_seed, sampled_weights_seed = samplers.split_seed(
      seed, 3, 'test_sampled_weights_follow_correct_distribution')
  num_timesteps = 10
  num_features = 2
  num_draws = 10000
  batch_shape = [3, 1]
  design_matrix = samplers.normal(
      batch_shape + [num_timesteps, num_features], seed=design_seed)
  true_weights = samplers.normal(
      batch_shape + [num_features, 1], seed=true_weights_seed) * 10.0
  targets = tf.matmul(design_matrix, true_weights)
  is_missing = tf.convert_to_tensor([False, False, False, True, True,
                                     False, False, True, False, False],
                                    dtype=tf.bool)
  prior_scale = tf.convert_to_tensor(5.)
  likelihood_scale = tf.convert_to_tensor(0.1)

  # Analytically compute the true posterior distribution on weights.
  valid_design_matrix = tf.boolean_mask(design_matrix, ~is_missing, axis=-2)
  valid_targets = tf.boolean_mask(targets, ~is_missing, axis=-2)
  num_valid_observations = tf.shape(valid_design_matrix)[-2]
  weights_posterior_mean, weights_posterior_cov, _ = linear_gaussian_update(
      prior_mean=tf.zeros([num_features, 1]),
      prior_cov=tf.eye(num_features) * prior_scale**2,
      observation_matrix=tfl.LinearOperatorFullMatrix(valid_design_matrix),
      observation_noise=tfd.MultivariateNormalDiag(
          loc=tf.zeros([num_valid_observations]),
          scale_diag=likelihood_scale * tf.ones([num_valid_observations])),
      x_observed=valid_targets)

  # Give every draw its own stateless seed. Reusing `sampled_weights_seed`
  # for each pfor iteration would make all draws identical copies, so the
  # moment checks below would effectively test a single sample.
  draw_seeds = samplers.split_seed(
      sampled_weights_seed, n=num_draws,
      salt='test_sampled_weights_follow_correct_distribution_draws')

  # Check that the empirical moments of sampled weights match the true values.
  sampled_weights = parallel_for.pfor(
      lambda i: gibbs_sampler._resample_weights(  # pylint: disable=g-long-lambda
          design_matrix=design_matrix,
          target_residuals=targets[..., 0],
          observation_noise_scale=likelihood_scale,
          weights_prior_scale=prior_scale,
          is_missing=is_missing,
          seed=tf.gather(draw_seeds, i)),
      num_draws)
  sampled_weights_mean = tf.reduce_mean(sampled_weights, axis=0)
  centered_weights = sampled_weights - weights_posterior_mean[..., 0]
  sampled_weights_cov = tf.reduce_mean(centered_weights[..., :, tf.newaxis] *
                                       centered_weights[..., tf.newaxis, :],
                                       axis=0)

  (sampled_weights_mean_, weights_posterior_mean_,
   sampled_weights_cov_, weights_posterior_cov_) = self.evaluate((
       sampled_weights_mean, weights_posterior_mean[..., 0],
       sampled_weights_cov, weights_posterior_cov))
  self.assertAllClose(sampled_weights_mean_, weights_posterior_mean_,
                      atol=0.01, rtol=0.05)
  self.assertAllClose(sampled_weights_cov_, weights_posterior_cov_,
                      atol=0.01, rtol=0.05)
def vectorized_sample(model, model_args, num_samples):
  """Draw multiple joint samples from an Ed2 model.

  Each `pfor` iteration runs `model(*model_args)` under `ed.tape()` and
  returns an `OrderedDict` mapping each recorded name to its RV's `.value`.
  `pfor` stacks these per-iteration dicts across `num_samples` iterations.
  """
  def draw_one(_):
    with ed.tape() as trace:
      model(*model_args)
    # pfor only handles Tensors, so keep each RV's value rather than the RV.
    return collections.OrderedDict(
        [(name, rv.value) for name, rv in trace.items()])

  return pfor(draw_one, num_samples)
def loop_body(sample_idx):
  """Applies `transform_fn` to every chain of MCMC sample `sample_idx`.

  Each state tensor is indexed as `rv_states[sample_idx][chain_idx]`. With a
  single chain, `pfor` is skipped and the chain dimension is re-added with
  `tf.expand_dims`. Otherwise chains are vectorised with `pfor`.
  """
  def loop_body_chain(chain_idx):
    return transform_fn([
        tf.gather(tf.gather(rv_states, sample_idx), chain_idx)
        for rv_states in states
    ])

  if num_chains == 1:
    return tf.nest.map_structure(lambda x: tf.expand_dims(x, 0),
                                 loop_body_chain(0))
  return pfor(loop_body_chain, num_chains)
def vectorized_log_joint_fn(*args, **kwargs):
  """Evaluates `log_joint_fn` independently on each slice of a batch.

  Every positional and keyword argument is sliced along its leading
  dimension with `tf.gather`. Slice `i` of all arguments is passed to
  `log_joint_fn`, and the per-slice results are vectorised with `pfor`.

  The batch size is taken from the first positional argument, or from the
  first keyword argument if there are no positional ones.

  Returns:
    Rank-1 tensor with one `log_joint_fn` value per slice. Its static length
    is set when the batch argument's leading dimension is statically known.

  Raises:
    ValueError: if called with no arguments at all.
  """
  if args:
    x1 = args[0]
  elif kwargs:
    # `dict.values()` is a view in Python 3 and cannot be indexed.
    x1 = next(iter(kwargs.values()))
  else:
    raise ValueError(
        'vectorized_log_joint_fn needs at least one argument to infer the '
        'batch size from.')

  # Prefer the static batch size; only fall back to the dynamic shape when
  # the leading dimension itself is unknown.
  static_num_inputs = tf.compat.dimension_value(x1.shape[0])
  if static_num_inputs is not None:
    num_inputs = static_num_inputs
  else:
    num_inputs = tf.shape(x1)[0]

  def loop_body(i):
    sliced_args = [tf.gather(v, i) for v in args]
    sliced_kwargs = {k: tf.gather(v, i) for k, v in kwargs.items()}
    return log_joint_fn(*sliced_args, **sliced_kwargs)

  result = pfor(loop_body, num_inputs)
  # `set_shape` needs a static value; `None` leaves the length unknown.
  result.set_shape([static_num_inputs])
  return result
def transform_mcmc_states(states, transform_fn):
  """Transforms all states using the provided transform function.

  Every tensor in `states` is indexed as `rv_states[sample_idx][chain_idx]`
  for `sample_idx < FLAGS.num_samples` and `chain_idx < FLAGS.num_chains`.
  `transform_fn` receives one list of per-RV values per (sample, chain) pair,
  and the results are vectorised with nested `pfor` calls.

  Returns:
    The structure returned by `transform_fn`, with leading
    `[num_samples, num_chains]` dimensions stacked on by `pfor`.
  """
  num_samples = FLAGS.num_samples
  num_chains = FLAGS.num_chains

  def loop_body(sample_idx):
    def loop_body_chain(chain_idx):
      return transform_fn([
          tf.gather(tf.gather(rv_states, sample_idx), chain_idx)
          for rv_states in states
      ])

    return pfor(loop_body_chain, num_chains)

  return pfor(loop_body, num_samples)
def test_pfor(self):
  """TF and NumPy `pfor` agree for a loop body that ignores its index."""
  num_iters = 7
  tf_result = self.evaluate(tf_pfor.pfor(lambda _: tf.ones([]), num_iters))
  np_result = np_pfor.pfor(lambda _: nptf.ones([]), num_iters)
  self.assertAllEqual(tf_result, np_result)
def vtransf(many_chains_sample):
  """Applies `transform` to each chain of a multi-chain sample.

  Every tensor in `many_chains_sample` is sliced along its leading dimension,
  and the slices are vectorised with `pfor` over `FLAGS.num_chains`
  iterations.
  """
  def per_chain(chain_idx):
    chain_state = [
        tf.gather(state, chain_idx) for state in many_chains_sample
    ]
    return transform(chain_state)

  return pfor(per_chain, FLAGS.num_chains)