def testNonEmptyConstantTensor(self):
  """`prefer_static_shape` of a constant tensor is a NumPy array of its dims."""
  expected_dims = [2, 3, 4]
  static_shape = distribution_util.prefer_static_shape(tf.zeros((2, 3, 4)))
  # The result must be resolved statically (a NumPy array, not a Tensor).
  self.assertIsInstance(static_shape, np.ndarray)
  self.assertAllEqual(expected_dims, static_shape)
def test_increment_log_prob(self):
  """Sharded `IncrementLogProb` must match the unsharded log prob and grad."""
  root = tfd.JointDistributionCoroutine.Root
  prior_mean = 3.
  x_size = 100

  def custom_ll(w, x):
    # Summed Normal(w, 1.) log-density of the observations `x`.
    likelihood = tfd.Normal(w, 1.)
    return tf.reduce_sum(likelihood.log_prob(x))

  def value_and_grad(density, w):
    # Unnormalized log prob of `density` at `w`, together with its gradient.
    def ulp_fn(w):
      zeros = tf.zeros([x_size, 0])
      return density.unnormalized_log_prob(w, zeros)

    ulp, g = tfp.math.value_and_gradient(ulp_fn, (w,))
    return ulp, g

  def ulp_grad(w, x):
    # Model whose log-prob increment is wrapped in `Sharded`.
    @joint_density_coroutine.JointDensityCoroutine
    def sharded_model():
      w = yield root(tfd.Normal(prior_mean, 1.))
      yield root(
          sharded.Sharded(
              increment_log_prob.IncrementLogProb(custom_ll(w, x)),
              shard_axis_name=self.axis_name))

    return value_and_grad(sharded_model, w)

  def true_ulp_grad(w, x):
    # Reference model: the same increment without the `Sharded` wrapper.
    @joint_density_coroutine.JointDensityCoroutine
    def model():
      w = yield root(tfd.Normal(prior_mean, 1.))
      yield root(increment_log_prob.IncrementLogProb(custom_ll(w, x)))

    return value_and_grad(model, w)

  def test_w_x(w, x):
    per_device_x = tf.reshape(x, [test_lib.NUM_DEVICES, -1])
    sharded_x = self.shard_values(per_device_x)

    per_replica_out = self.strategy_run(
        ulp_grad, (w, sharded_x), in_axes=(None, 0))
    lp, g = self.evaluate(self.per_replica_to_tensor(per_replica_out))
    true_lp, true_g = self.evaluate(true_ulp_grad(w, x))

    self.assertAllClose(true_lp, lp[0])
    self.assertAllClose(true_g[0], g[0][0])

  w = tf.constant(4.)

  # Case 1: all-zero observations.
  test_w_x(w, tf.zeros([x_size]))

  # Case 2: observations sampled from a Normal(0, 1).
  sampler = tfd.Normal(loc=tf.zeros([x_size]), scale=tf.ones([x_size]))
  random_x = self.evaluate(sampler.sample(seed=self.key))
  test_w_x(w, random_x)
def run(key):
  """Samples a sharded `Independent(Normal(0, 1))` using `key` as the seed."""
  base = tfd.Normal(tf.zeros(1), tf.ones(1))
  event_dist = tfd.Independent(base, 1)
  sharded_dist = tfp_dist.Sharded(event_dist, shard_axis_name=self.axis_name)
  return sharded_dist.sample(seed=key)
def _mean(self):
  """Returns zeros shaped like `self.batch_shape_tensor()`."""
  batch_shape = self.batch_shape_tensor()
  return tf.zeros(batch_shape)
def _loop_tree_doubling(self, step_size, momentum_state_memory,
                        current_step_meta_info, iter_, initial_step_state,
                        initial_step_metastate, seed):
  """Main loop for tree doubling.

  Performs one doubling iteration: samples a direction per batch member,
  builds a subtree of `1 << iter_` steps via `_build_sub_tree` starting from
  the selected end of the trajectory, decides whether the subtree's candidate
  replaces the current candidate, and updates the trajectory ends, cumulative
  statistics and the continue flag.

  Args:
    step_size: Iterable of step sizes, zipped against the expanded directions
      (one per state part); each is negated where the direction is `False`.
    momentum_state_memory: Forwarded unchanged to `_build_sub_tree`.
    current_step_meta_info: Structure exposing `init_energy`, whose shape is
      used as the batch shape; also forwarded to `_build_sub_tree`.
    iter_: Index of the current doubling iteration.
    initial_step_state: Nested structure whose leaves are indexed with `[0]`
      and `[1]` to pick one of the two trajectory ends.
    initial_step_metastate: Structure exposing `candidate_state`,
      `continue_tree`, `not_divergence`, `energy_diff_sum`, `momentum_sum`,
      `is_accepted` and `leapfrog_count` from the previous iteration.
    seed: PRNG seed; split four ways (direction, subtree, acceptance, and the
      seed returned for the next iteration).

  Returns:
    A tuple `(iter_ + 1, next_seed, new_step_state, new_step_metastate)`.
  """
  with tf.name_scope('loop_tree_doubling'):
    (direction_seed,
     subtree_seed,
     acceptance_seed,
     next_seed) = samplers.split_seed(seed, n=4)
    batch_shape = ps.shape(current_step_meta_info.init_energy)
    # Draw an integer in {0, 1} per batch member and treat it as a boolean
    # direction flag.
    direction = tf.cast(
        samplers.uniform(
            shape=batch_shape,
            minval=0,
            maxval=2,
            dtype=tf.int32,
            seed=direction_seed),
        dtype=tf.bool)

    # Start the subtree from end `[1]` where `direction` is True, else `[0]`.
    tree_start_states = tf.nest.map_structure(
        lambda v: bu.where_left_justified_mask(direction, v[1], v[0]),
        initial_step_state)

    directions_expanded = [
        bu.left_justified_expand_dims_like(direction, state)
        for state in tree_start_states.state
    ]

    # The sign of each step size encodes the integration direction.
    integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn,
        step_sizes=[
            tf.where(d, ss, -ss)
            for d, ss in zip(directions_expanded, step_size)
        ],
        num_steps=self.unrolled_leapfrog_steps)

    [
        candidate_tree_state,
        tree_final_states,
        final_not_divergence,
        continue_tree_final,
        energy_diff_tree_sum,
        momentum_subtree_cumsum,
        leapfrogs_taken
    ] = self._build_sub_tree(
        directions_expanded,
        integrator,
        current_step_meta_info,
        # num_steps_at_this_depth = 2**iter_ = 1 << iter_
        tf.bitwise.left_shift(1, iter_),
        tree_start_states,
        initial_step_metastate.continue_tree,
        initial_step_metastate.not_divergence,
        momentum_state_memory,
        seed=subtree_seed)

    last_candidate_state = initial_step_metastate.candidate_state

    energy_diff_sum = (energy_diff_tree_sum +
                       initial_step_metastate.energy_diff_sum)
    if MULTINOMIAL_SAMPLE:
      # Weights are combined with `log_add_exp`, so a subtree that must not
      # continue gets weight -inf.
      tree_weight = tf.where(
          continue_tree_final,
          candidate_tree_state.weight,
          tf.constant(-np.inf, dtype=candidate_tree_state.weight.dtype))
      weight_sum = log_add_exp(tree_weight, last_candidate_state.weight)
      log_accept_thresh = tree_weight - last_candidate_state.weight
    else:
      # Weights are counts here, so a subtree that must not continue gets 0.
      tree_weight = tf.where(
          continue_tree_final,
          candidate_tree_state.weight,
          tf.zeros([], dtype=TREE_COUNT_DTYPE))
      weight_sum = tree_weight + last_candidate_state.weight
      log_accept_thresh = tf.math.log(
          tf.cast(tree_weight, tf.float32) /
          tf.cast(last_candidate_state.weight, tf.float32))
    # Replace a NaN threshold with 0.
    log_accept_thresh = tf.where(
        tf.math.is_nan(log_accept_thresh),
        tf.zeros([], log_accept_thresh.dtype),
        log_accept_thresh)
    # `u` is log(1 - U) for U drawn by `samplers.uniform`.
    u = tf.math.log1p(-samplers.uniform(
        shape=batch_shape,
        dtype=log_accept_thresh.dtype,
        seed=acceptance_seed))
    is_sample_accepted = u <= log_accept_thresh

    # Only adopt the subtree's candidate if the subtree may continue.
    choose_new_state = is_sample_accepted & continue_tree_final

    new_candidate_state = TreeDoublingStateCandidate(
        state=[
            bu.where_left_justified_mask(choose_new_state, s0, s1)
            for s0, s1 in zip(candidate_tree_state.state,
                              last_candidate_state.state)
        ],
        target=bu.where_left_justified_mask(
            choose_new_state,
            candidate_tree_state.target,
            last_candidate_state.target),
        target_grad_parts=[
            bu.where_left_justified_mask(choose_new_state, grad0, grad1)
            for grad0, grad1 in zip(
                candidate_tree_state.target_grad_parts,
                last_candidate_state.target_grad_parts)
        ],
        energy=bu.where_left_justified_mask(
            choose_new_state,
            candidate_tree_state.energy,
            last_candidate_state.energy),
        weight=weight_sum)

    # Copy the previous candidate's static shapes onto the selected tensors.
    for new_candidate_state_temp, old_candidate_state_temp in zip(
        new_candidate_state.state, last_candidate_state.state):
      tensorshape_util.set_shape(new_candidate_state_temp,
                                 old_candidate_state_temp.shape)

    for new_candidate_grad_temp, old_candidate_grad_temp in zip(
        new_candidate_state.target_grad_parts,
        last_candidate_state.target_grad_parts):
      tensorshape_util.set_shape(new_candidate_grad_temp,
                                 old_candidate_grad_temp.shape)

    # Update left right information of the trajectory, and check trajectory
    # level U turn
    tree_otherend_states = tf.nest.map_structure(
        lambda v: bu.where_left_justified_mask(direction, v[0], v[1]),
        initial_step_state)

    # Re-stack the two trajectory ends: the subtree's final state takes the
    # slot of the end that was extended, the untouched end keeps the other.
    new_step_state = tf.nest.pack_sequence_as(
        initial_step_state,
        [
            tf.stack(
                [  # pylint: disable=g-complex-comprehension
                    bu.where_left_justified_mask(direction, right, left),
                    bu.where_left_justified_mask(direction, left, right),
                ],
                axis=0)
            for left, right in zip(tf.nest.flatten(tree_final_states),
                                   tf.nest.flatten(tree_otherend_states))
        ])

    # Accumulate the subtree's momentum sums into the running totals.
    momentum_tree_cumsum = []
    for p0, p1 in zip(initial_step_metastate.momentum_sum,
                      momentum_subtree_cumsum):
      momentum_part_temp = p0 + p1
      tensorshape_util.set_shape(momentum_part_temp, p0.shape)
      momentum_tree_cumsum.append(momentum_part_temp)

    for new_state_temp, old_state_temp in zip(
        tf.nest.flatten(new_step_state),
        tf.nest.flatten(initial_step_state)):
      tensorshape_util.set_shape(new_state_temp, old_state_temp.shape)

    # The U-turn check uses either the summed momentum or the difference
    # between the two trajectory ends, depending on the module-level flag.
    if GENERALIZED_UTURN:
      state_diff = momentum_tree_cumsum
    else:
      state_diff = [s[1] - s[0] for s in new_step_state.state]

    no_u_turns_trajectory = has_not_u_turn(
        state_diff,
        [m[0] for m in new_step_state.momentum],
        [m[1] for m in new_step_state.momentum],
        log_prob_rank=ps.rank_from_shape(batch_shape),
        shard_axis_names=self.experimental_shard_axis_names)

    new_step_metastate = TreeDoublingMetaState(
        candidate_state=new_candidate_state,
        is_accepted=choose_new_state | initial_step_metastate.is_accepted,
        momentum_sum=momentum_tree_cumsum,
        energy_diff_sum=energy_diff_sum,
        continue_tree=continue_tree_final & no_u_turns_trajectory,
        not_divergence=final_not_divergence,
        leapfrog_count=(initial_step_metastate.leapfrog_count +
                        leapfrogs_taken))

    return iter_ + 1, next_seed, new_step_state, new_step_metastate
def null_input(self):
  """Returns a float32 all-zeros tensor of shape `[1, self._num_tokens]`."""
  shape = [1, self._num_tokens]
  return tf.zeros(shape, dtype=tf.float32)
def sample_annealed_importance_chain(
    num_steps,
    proposal_log_prob_fn,
    target_log_prob_fn,
    current_state,
    make_kernel_fn,
    parallel_iterations=10,
    seed=None,
    name=None):
  """Runs annealed importance sampling (AIS) to estimate normalizing constants.

  This function uses an MCMC transition operator (e.g., Hamiltonian Monte Carlo)
  to sample from a series of distributions that slowly interpolates between
  an initial 'proposal' distribution:

  `exp(proposal_log_prob_fn(x) - proposal_log_normalizer)`

  and the target distribution:

  `exp(target_log_prob_fn(x) - target_log_normalizer)`,

  accumulating importance weights along the way. The product of these
  importance weights gives an unbiased estimate of the ratio of the
  normalizing constants of the initial distribution and the target
  distribution:

  `E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer)`.

  At iteration `iter_` the kernel targets the convex combination
  `beta * target_log_prob + (1 - beta) * proposal_log_prob` with
  `beta = (iter_ + 1) / num_steps`.

  Note: When running in graph mode, `proposal_log_prob_fn` and
  `target_log_prob_fn` are called exactly three times (although this may be
  reduced to two times in the future).

  Args:
    num_steps: Integer number of Markov chain updates to run. More
      iterations means more expense, but smoother annealing between q
      and p, which in turn means exponentially lower variance for the
      normalizing constant estimator.
    proposal_log_prob_fn: Python callable that returns the log density of the
      initial distribution.
    target_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the target distribution.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions index
      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
    make_kernel_fn: Python `callable` which returns a `TransitionKernel`-like
      object. Must take one argument representing the `TransitionKernel`'s
      `target_log_prob_fn`. The `target_log_prob_fn` argument represents the
      `TransitionKernel`'s target log distribution.  Note:
      `sample_annealed_importance_chain` creates a new `target_log_prob_fn`
      which is an interpolation between the supplied `target_log_prob_fn` and
      `proposal_log_prob_fn`; it is this interpolated function which is used as
      an argument to `make_kernel_fn`.
    parallel_iterations: The number of iterations allowed to run in parallel.
      It must be a positive integer. See `tf.while_loop` for more details.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details. When `None`,
      no `seed` keyword is passed to the kernel's `one_step`.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'sample_annealed_importance_chain').

  Returns:
    next_state: `Tensor` or Python list of `Tensor`s representing the
      state(s) of the Markov chain(s) at the final iteration. Has same shape as
      input `current_state`.
    ais_weights: Tensor with the estimated weight(s). Has shape matching
      `target_log_prob_fn(current_state)`.
    kernel_results: `collections.namedtuple` of internal calculations used to
      advance the chain.

  #### Examples

  ##### Estimate the normalizing constant of a log-gamma distribution.

  ```python
  tfd = tfp.distributions

  # Run 100 AIS chains in parallel
  num_chains = 100
  dims = 20
  dtype = np.float32

  proposal = tfd.MultivariateNormalDiag(
     loc=tf.zeros([dims], dtype=dtype))

  target = tfd.TransformedDistribution(
    distribution=tfd.Sample(
        tfd.Gamma(concentration=dtype(2), rate=dtype(3)),
        sample_shape=[dims]),
    bijector=tfp.bijectors.Invert(tfp.bijectors.Exp()))

  chains_state, ais_weights, kernels_results = (
      tfp.mcmc.sample_annealed_importance_chain(
          num_steps=1000,
          proposal_log_prob_fn=proposal.log_prob,
          target_log_prob_fn=target.log_prob,
          current_state=proposal.sample(num_chains),
          make_kernel_fn=lambda tlp_fn: tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn=tlp_fn,
            step_size=0.2,
            num_leapfrog_steps=2)))

  log_estimated_normalizer = (tf.reduce_logsumexp(ais_weights)
                              - np.log(num_chains))
  log_true_normalizer = tf.math.lgamma(2.) - 2. * tf.math.log(3.)
  ```

  ##### Estimate marginal likelihood of a Bayesian regression model.

  ```python
  tfd = tfp.distributions

  def make_prior(dims, dtype):
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros(dims, dtype))

  def make_likelihood(weights, x):
    return tfd.MultivariateNormalDiag(
        loc=tf.tensordot(weights, x, axes=[[0], [-1]]))

  # Run 100 AIS chains in parallel
  num_chains = 100
  dims = 10
  dtype = np.float32

  # Make training data.
  x = np.random.randn(num_chains, dims).astype(dtype)
  true_weights = np.random.randn(dims).astype(dtype)
  y = np.dot(x, true_weights) + np.random.randn(num_chains)

  # Setup model.
  prior = make_prior(dims, dtype)
  def target_log_prob_fn(weights):
    return prior.log_prob(weights) + make_likelihood(weights, x).log_prob(y)

  proposal = tfd.MultivariateNormalDiag(
      loc=tf.zeros(dims, dtype))

  weight_samples, ais_weights, kernel_results = (
      tfp.mcmc.sample_annealed_importance_chain(
        num_steps=1000,
        proposal_log_prob_fn=proposal.log_prob,
        target_log_prob_fn=target_log_prob_fn,
        current_state=tf.zeros([num_chains, dims], dtype),
        make_kernel_fn=lambda tlp_fn: tfp.mcmc.HamiltonianMonteCarlo(
          target_log_prob_fn=tlp_fn,
          step_size=0.1,
          num_leapfrog_steps=2)))
  log_normalizer_estimate = (tf.reduce_logsumexp(ais_weights)
                             - np.log(num_chains))
  ```

  """
  # Record whether the caller supplied a seed before it is sanitized; this
  # decides below whether seeds are split and forwarded to `one_step`.
  is_seeded = seed is not None
  seed = samplers.sanitize_seed(seed, salt='mcmc.sample_ais_chain')

  with tf.name_scope(name or 'sample_annealed_importance_chain'):
    num_steps = tf.convert_to_tensor(
        value=num_steps, dtype=tf.int32, name='num_steps')
    if mcmc_util.is_list_like(current_state):
      current_state = [
          tf.convert_to_tensor(s, name='current_state')
          for s in current_state
      ]
    else:
      current_state = tf.convert_to_tensor(
          value=current_state, name='current_state')

    def _make_convex_combined_log_prob_fn(iter_):
      """Builds the annealed log-prob used as the kernel target at `iter_`."""
      def _fn(*args):
        p = tf.identity(proposal_log_prob_fn(*args), name='proposal_log_prob')
        t = tf.identity(target_log_prob_fn(*args), name='target_log_prob')
        dtype = dtype_util.base_dtype(p.dtype)
        # beta rises to 1 at the last iteration (iter_ == num_steps - 1).
        beta = tf.cast(iter_ + 1, dtype) / tf.cast(num_steps, dtype)
        return tf.identity(beta * t + (1. - beta) * p,
                           name='convex_combined_log_prob')
      return _fn

    def _loop_body(iter_, seed, ais_weights, current_state, kernel_results):
      """Closure which implements `tf.while_loop` body."""
      # Without a user seed, the sanitized seed is carried along unchanged.
      iter_seed, next_seed = samplers.split_seed(
          seed,
          salt='ais_chain.seeded_one_step') if is_seeded else (seed, seed)

      x = (current_state if mcmc_util.is_list_like(current_state)
           else [current_state])
      proposal_log_prob = proposal_log_prob_fn(*x)
      target_log_prob = target_log_prob_fn(*x)
      # Each step adds (target - proposal) / num_steps, evaluated at the state
      # before the transition.
      ais_weights += ((target_log_prob - proposal_log_prob) /
                      tf.cast(num_steps, ais_weights.dtype))
      kernel = make_kernel_fn(_make_convex_combined_log_prob_fn(iter_))
      # TODO(b/147676843): Should we warn if the kernel is not calibrated?
      one_step_kwargs = dict(seed=iter_seed) if is_seeded else {}
      next_state, inner_results = kernel.one_step(
          current_state, kernel_results.inner_results, **one_step_kwargs)
      kernel_results = AISResults(
          proposal_log_prob=proposal_log_prob,
          target_log_prob=target_log_prob,
          inner_results=inner_results,
      )
      return [
          iter_ + 1, next_seed, ais_weights, next_state, kernel_results
      ]

    def _bootstrap_results(init_state):
      """Creates first version of `previous_kernel_results`."""
      kernel = make_kernel_fn(_make_convex_combined_log_prob_fn(iter_=0))
      inner_results = kernel.bootstrap_results(init_state)

      mh_results = _find_inner_mh_results(inner_results)

      convex_combined_log_prob = mh_results.accepted_results.target_log_prob
      dtype = dtype_util.as_numpy_dtype(convex_combined_log_prob.dtype)
      shape = tf.shape(convex_combined_log_prob)
      # The log-prob fields are NaN placeholders of the right shape and dtype
      # until the first loop iteration fills them in.
      proposal_log_prob = tf.fill(shape, dtype(np.nan),
                                  name='bootstrap_proposal_log_prob')
      # NOTE(review): the op name below looks like a typo for
      # 'bootstrap_target_log_prob'; left as-is since it is a runtime name.
      target_log_prob = tf.fill(shape, dtype(np.nan),
                                name='target_target_log_prob')

      return AISResults(
          proposal_log_prob=proposal_log_prob,
          target_log_prob=target_log_prob,
          inner_results=inner_results,
      )

    previous_kernel_results = _bootstrap_results(current_state)
    inner_results = previous_kernel_results.inner_results
    mh_results = _find_inner_mh_results(inner_results)

    # Start the weights at zero, shaped as the broadcast of the proposed and
    # accepted log-prob shapes.
    ais_weights = tf.zeros(
        shape=tf.broadcast_dynamic_shape(
            tf.shape(mh_results.proposed_results.target_log_prob),
            tf.shape(mh_results.accepted_results.target_log_prob)),
        dtype=mh_results.proposed_results.target_log_prob.dtype)

    [_, _, ais_weights, current_state, kernel_results] = tf.while_loop(
        cond=lambda iter_, *args: iter_ < num_steps,
        body=_loop_body,
        loop_vars=[
            np.int32(0),  # iter_
            seed,
            ais_weights,
            current_state,
            previous_kernel_results,
        ],
        parallel_iterations=parallel_iterations)

    return [current_state, ais_weights, kernel_results]
def _forward_log_det_jacobian(self, x):
  """Returns a scalar zero log-det-Jacobian with the dtype of `x`."""
  scalar_shape = []
  return tf.zeros(scalar_shape, dtype=x.dtype)
def sample_sequential_monte_carlo(
    prior_log_prob_fn,
    likelihood_log_prob_fn,
    current_state,
    min_num_steps=2,
    max_num_steps=25,
    max_stage=100,
    make_kernel_fn=make_rwmh_kernel_fn,
    tuning_fn=simple_heuristic_tuning,
    make_tempered_target_log_prob_fn=default_make_tempered_target_log_prob_fn,
    ess_threshold_ratio=0.5,
    parallel_iterations=10,
    seed=None,
    name=None):
  """Runs Sequential Monte Carlo to sample from the posterior distribution.

  This function uses an MCMC transition operator (e.g., Hamiltonian Monte Carlo)
  to sample from a series of distributions that slowly interpolates between
  an initial 'prior' distribution:

    `exp(prior_log_prob_fn(x))`

  and the target 'posterior' distribution:

    `exp(prior_log_prob_fn(x) + likelihood_log_prob_fn(x))`,

  by mutating a collection of MC samples (i.e., particles). The approach is also
  known as Particle Filter in some literature. The current implementation is
  largely based on Del Moral et al [1], which adapts the tempering sequence
  adaptively (based on the effective sample size) and the scaling of the
  mutation kernel (based on the sample covariance of the particles) at each
  stage.

  Args:
    prior_log_prob_fn: Python callable that returns the log density of the
      prior distribution.
    likelihood_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the likelihood distribution.
    current_state: Nested structure of `Tensor`s, each of shape
      `concat([[num_particles, b1, ..., bN], latent_part_event_shape])`, where
      `b1, ..., bN` are optional batch dimensions. Each batch represents an
      independent SMC run.
    min_num_steps: The minimal number of kernel transition steps in one mutation
      of the MC samples.
    max_num_steps: The maximum number of kernel transition steps in one mutation
      of the MC samples. Note that the actual number of steps in one mutation is
      tuned during sampling and likely lower than the max_num_step.
    max_stage: Integer upper bound on the number of stages run while
      increasing the inverse temperature from 0 to 1.
    make_kernel_fn: Python `callable` which returns a `TransitionKernel`-like
      object. It is called as
      `make_kernel_fn(target_log_prob_fn, current_state, scalings, seed=...)`,
      where `target_log_prob_fn` is the tempered target built by
      `make_tempered_target_log_prob_fn` from `prior_log_prob_fn`,
      `likelihood_log_prob_fn` and the current inverse temperature.
    tuning_fn: Python `callable` which takes the number of steps, the log
      scaling, and the log acceptance ratio from the last mutation and output
      the number of steps and log scaling for the next mutation.
    make_tempered_target_log_prob_fn: Python `callable` that takes the
      `prior_log_prob_fn`, `likelihood_log_prob_fn`, and `inverse_temperatures`
      and creates a `target_log_prob_fn` `callable` that pass to
      `make_kernel_fn`.
    ess_threshold_ratio: Target ratio for effective sample size; the effective
      sample size threshold is `num_particles * ess_threshold_ratio` cast to
      int32.
    parallel_iterations: The number of iterations allowed to run in parallel.
      It must be a positive integer. See `tf.while_loop` for more details.
    seed: Python integer or TFP seedstream to seed the random number generator.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'sample_sequential_monte_carlo').

  Returns:
    n_stage: Number of the mutation stage SMC ran.
    final_state: `Tensor` or Python `list` of `Tensor`s representing the
      final state(s) of the Markov chain(s). The output are the posterior
      samples.
    final_kernel_results: `collections.namedtuple` of internal calculations used
      to advance the chain.

  #### References

  [1] Del Moral, Pierre, Arnaud Doucet, and Ajay Jasra. An adaptive sequential
      Monte Carlo method for approximate Bayesian computation.
      _Statistics and Computing_, 22.5(1009-1020), 2012.

  """

  with tf.name_scope(name or 'sample_sequential_monte_carlo'):
    seed_stream = SeedStream(seed, salt='smc_seed')

    # Work on a list of state parts internally; unwrap again before returning.
    unwrap_state_list = not tf.nest.is_nested(current_state)
    if unwrap_state_list:
      current_state = [current_state]
    current_state = [
        tf.convert_to_tensor(s, dtype_hint=tf.float32) for s in current_state
    ]

    # Initial preprocessing at Stage 0
    likelihood_log_prob = likelihood_log_prob_fn(*current_state)

    # Total number of event elements across state parts, i.e. every dimension
    # to the right of the likelihood's (particle + batch) dimensions.
    likelihood_rank = ps.rank(likelihood_log_prob)
    dimension = ps.reduce_sum([
        ps.reduce_prod(ps.shape(x)[likelihood_rank:])
        for x in current_state
    ])

    # We infer the particle shapes from the resulting likelihood:
    # [num_particles, b1, ..., bN]
    particle_shape = ps.shape(likelihood_log_prob)
    num_particles, batch_shape = particle_shape[0], particle_shape[1:]
    effective_sample_size_threshold = tf.cast(
        num_particles * ess_threshold_ratio, tf.int32)

    # TODO(b/152412213): Revisit this default parameter.
    # Default to the optimal scaling of a random walk kernel for a d-dimensional
    # normal distributed targets: 2.38 ** 2 / d.
    # For more detail see:
    # Roberts GO, Gelman A, Gilks WR. Weak convergence and optimal scaling of
    # random walk Metropolis algorithms. _The annals of applied probability_.
    # 1997;7(1):110-20.
    scale_start = (
        tf.constant(2.38**2, dtype=likelihood_log_prob.dtype) /
        tf.constant(dimension, dtype=likelihood_log_prob.dtype))

    # Start at inverse temperature 0, with the initial scaling capped at 1.
    inverse_temperature = tf.zeros(batch_shape, dtype=likelihood_log_prob.dtype)
    scalings = ps.ones_like(likelihood_log_prob) * ps.minimum(scale_start, 1.)
    kernel = make_kernel_fn(
        make_tempered_target_log_prob_fn(
            prior_log_prob_fn,
            likelihood_log_prob_fn,
            inverse_temperature),
        current_state,
        scalings,
        seed=seed_stream)
    pkr = kernel.bootstrap_results(current_state)
    _, kernel_target_log_prob = gather_mh_like_result(pkr)

    particle_info = ParticleInfo(
        log_accept_prob=ps.zeros_like(likelihood_log_prob),
        log_scalings=tf.math.log(scalings),
        tempered_log_prob=kernel_target_log_prob,
        likelihood_log_prob=likelihood_log_prob,
    )

    current_pkr = SMCResults(
        num_steps=tf.convert_to_tensor(
            max_num_steps, dtype=tf.int32, name='num_steps'),
        inverse_temperature=inverse_temperature,
        log_marginal_likelihood=tf.zeros_like(inverse_temperature),
        particle_info=particle_info)

    def update_weights_temperature(inverse_temperature, likelihood_log_prob):
      """Calculate the next inverse temperature and update weights."""
      # Subtract the per-batch maximum over particles (axis 0).
      likelihood_diff = likelihood_log_prob - tf.reduce_max(
          likelihood_log_prob, axis=0)

      def _body_fn(new_beta, upper_beta, lower_beta, eff_size, log_weights):
        """One iteration of the temperature and weight update."""
        # Bisection: try the midpoint, then shrink the bracket from whichever
        # side the resulting effective sample size indicates.
        new_beta = (lower_beta + upper_beta) / 2.0
        log_weights = (new_beta - inverse_temperature) * likelihood_diff
        log_weights_norm = tf.math.log_softmax(log_weights, axis=0)
        # Effective sample size: 1 / sum(w_i ** 2) for normalized weights w.
        eff_size = tf.cast(
            tf.exp(-tf.math.reduce_logsumexp(2 * log_weights_norm, axis=0)),
            tf.int32)
        upper_beta = tf.where(
            eff_size < effective_sample_size_threshold,
            new_beta, upper_beta)
        lower_beta = tf.where(
            eff_size < effective_sample_size_threshold,
            lower_beta, new_beta)
        return new_beta, upper_beta, lower_beta, eff_size, log_weights

      def _cond_fn(new_beta, upper_beta, lower_beta, eff_size, *_):  # pylint: disable=unused-argument
        # TODO(junpenglao): revisit threshold below to be dtype specific.
        threshold = 1e-6
        # Keep bisecting while some bracket is wider than the threshold and
        # some effective sample size differs from the target.
        return (
            tf.math.reduce_any(upper_beta - lower_beta > threshold) &
            tf.math.reduce_any(eff_size != effective_sample_size_threshold)
        )

      # The bracket starts at [inverse_temperature, 2].
      (new_beta, upper_beta, lower_beta, eff_size, log_weights) = tf.while_loop(  # pylint: disable=unused-variable
          cond=_cond_fn,
          body=_body_fn,
          loop_vars=(
              tf.zeros_like(inverse_temperature),
              tf.fill(
                  ps.shape(inverse_temperature),
                  tf.constant(2, inverse_temperature.dtype)),
              inverse_temperature,
              tf.zeros_like(inverse_temperature, dtype=tf.int32),
              tf.zeros_like(likelihood_diff)),
          parallel_iterations=parallel_iterations
      )

      # If the search landed at or beyond 1, use the weights for a step
      # straight to inverse temperature 1.
      log_weights = tf.where(new_beta < 1.,
                             log_weights,
                             (1. - inverse_temperature) * likelihood_diff)
      marginal_loglike_ = reduce_logmeanexp(
          (new_beta - inverse_temperature) * likelihood_log_prob, axis=0)
      new_inverse_temperature = tf.clip_by_value(new_beta, 0., 1.)

      return marginal_loglike_, new_inverse_temperature, log_weights

    def mutate(
        current_state,
        log_scalings,
        num_steps,
        inverse_temperature):
      """Mutate the state using a Transition kernel."""
      with tf.name_scope('mutate_states'):
        scalings = tf.exp(log_scalings)
        kernel = make_kernel_fn(
            make_tempered_target_log_prob_fn(
                prior_log_prob_fn,
                likelihood_log_prob_fn,
                inverse_temperature),
            current_state,
            scalings,
            seed=seed_stream)
        pkr = kernel.bootstrap_results(current_state)
        kernel_log_accept_ratio, _ = gather_mh_like_result(pkr)

        def mutate_onestep(i, state, pkr, log_accept_prob_sum):
          """Takes one kernel step and accumulates the log accept prob."""
          next_state, next_kernel_results = kernel.one_step(
              state, pkr)
          # NOTE(review): the accept ratio is read from the incoming `pkr`
          # (the results before this step), not from `next_kernel_results` --
          # confirm this is intended.
          kernel_log_accept_ratio, _ = gather_mh_like_result(pkr)
          log_accept_prob = tf.minimum(kernel_log_accept_ratio, 0.)
          log_accept_prob_sum = log_add_exp(
              log_accept_prob_sum, log_accept_prob)
          return i + 1, next_state, next_kernel_results, log_accept_prob_sum

        (
            _,
            next_state,
            next_kernel_results,
            log_accept_prob_sum
        ) = tf.while_loop(
            cond=lambda i, *args: i < num_steps,
            body=mutate_onestep,
            loop_vars=(
                tf.zeros([], dtype=tf.int32),
                current_state,
                pkr,
                # we accumulate the acceptance probability in log space.
                tf.fill(
                    ps.shape(kernel_log_accept_ratio),
                    tf.constant(-np.inf, kernel_log_accept_ratio.dtype))
            ),
            parallel_iterations=parallel_iterations
        )
        _, kernel_target_log_prob = gather_mh_like_result(next_kernel_results)
        # NOTE(review): the sum has `num_steps` terms but is divided by
        # `num_steps + 1` -- confirm the intended normalization.
        avg_log_accept_prob_per_particle = log_accept_prob_sum - tf.math.log(
            tf.cast(num_steps + 1, log_accept_prob_sum.dtype))

        return (next_state,
                avg_log_accept_prob_per_particle,
                kernel_target_log_prob)

    # One SMC steps.
    def smc_body_fn(stage, state, smc_kernel_result):
      """Run one stage of SMC with constant temperature."""
      (
          new_marginal,
          new_inv_temperature,
          log_weights
      ) = update_weights_temperature(
          smc_kernel_result.inverse_temperature,
          smc_kernel_result.particle_info.likelihood_log_prob)
      # TODO(b/152412213) Use a tf.scan to better collect debug info.
      if PRINT_DEBUG:
        tf.print(
            'Stage:', stage,
            'Beta:', new_inv_temperature,
            'n_steps:', smc_kernel_result.num_steps,
            'accept:', tf.exp(reduce_logmeanexp(
                smc_kernel_result.particle_info.log_accept_prob, axis=0)),
            'scaling:', tf.exp(reduce_logmeanexp(
                smc_kernel_result.particle_info.log_scalings, axis=0))
            )
      # Resample the particles and their per-particle info with the new
      # weights.
      (resampled_state,
       resampled_particle_info), _ = resample_particle_and_info(
           (state, smc_kernel_result.particle_info),
           log_weights,
           seed=seed_stream)
      next_num_steps, next_log_scalings = tuning_fn(
          smc_kernel_result.num_steps,
          resampled_particle_info.log_scalings,
          resampled_particle_info.log_accept_prob)
      # Skip tuning at stage 0.
      next_num_steps = tf.where(stage == 0,
                                smc_kernel_result.num_steps,
                                next_num_steps)
      next_log_scalings = tf.where(stage == 0,
                                   resampled_particle_info.log_scalings,
                                   next_log_scalings)
      next_num_steps = tf.clip_by_value(
          next_num_steps, min_num_steps, max_num_steps)

      next_state, log_accept_prob, tempered_log_prob = mutate(
          resampled_state,
          next_log_scalings,
          next_num_steps,
          new_inv_temperature)
      next_pkr = SMCResults(
          num_steps=next_num_steps,
          inverse_temperature=new_inv_temperature,
          log_marginal_likelihood=(
              new_marginal + smc_kernel_result.log_marginal_likelihood),
          particle_info=ParticleInfo(
              log_accept_prob=log_accept_prob,
              log_scalings=next_log_scalings,
              tempered_log_prob=tempered_log_prob,
              likelihood_log_prob=likelihood_log_prob_fn(*next_state),
          ))
      return stage + 1, next_state, next_pkr

    # Run stages until every inverse temperature has reached 1, or
    # `max_stage` stages have been run.
    (
        n_stage,
        final_state,
        final_kernel_results
    ) = tf.while_loop(
        cond=lambda i, state, pkr: (  # pylint: disable=g-long-lambda
            (i < max_stage) &
            tf.reduce_any(pkr.inverse_temperature < 1.)),
        body=smc_body_fn,
        loop_vars=(
            tf.zeros([], dtype=tf.int32),
            current_state,
            current_pkr),
        parallel_iterations=parallel_iterations
    )
    if unwrap_state_list:
      final_state = final_state[0]
    return n_stage, final_state, final_kernel_results
def loop_tree_doubling(self, step_size, momentum_state_memory,
                       current_step_meta_info, iter_, initial_step_state,
                       initial_step_metastate):
  """Main loop for tree doubling.

  Performs one doubling iteration: samples a direction per batch member,
  builds a subtree of `1 << iter_` steps via `_build_sub_tree` starting from
  the selected end of the trajectory, decides whether the subtree's candidate
  replaces the current candidate, and updates the trajectory ends and the
  continue flag.

  Args:
    step_size: Iterable of step sizes, zipped against the expanded directions
      (one per state part); each is negated where the direction is `False`.
    momentum_state_memory: Forwarded unchanged to `_build_sub_tree`.
    current_step_meta_info: Structure exposing `init_energy`, whose shape is
      used as the batch shape; also forwarded to `_build_sub_tree`.
    iter_: Index of the current doubling iteration.
    initial_step_state: Nested structure whose leaves are indexed with `[0]`
      and `[1]` to pick one of the two trajectory ends.
    initial_step_metastate: Structure exposing `candidate_state`,
      `continue_tree`, `not_divergence`, `energy_diff_sum`, `is_accepted` and
      `leapfrog_count` from the previous iteration.

  Returns:
    A tuple `(iter_ + 1, new_step_state, new_step_metastate)`.
  """
  with tf.name_scope('loop_tree_doubling'):
    init_energy = current_step_meta_info.init_energy
    batch_shape = prefer_static.shape(init_energy)
    # Draw an integer in {0, 1} per batch member and treat it as a boolean
    # direction flag.
    direction = tf.cast(
        tf.random.uniform(
            shape=batch_shape,
            minval=0,
            maxval=2,
            dtype=tf.int32,
            seed=self._seed_stream()),
        dtype=tf.bool)

    # Start the subtree from end `[1]` where `direction` is True, else `[0]`.
    tree_start_states = tf.nest.map_structure(
        lambda v: tf.where(  # pylint: disable=g-long-lambda
            _rightmost_expand_to_rank(direction, prefer_static.rank(v[1])),
            v[1], v[0]),
        initial_step_state)

    directions_expanded = [
        _rightmost_expand_to_rank(direction, prefer_static.rank(state))
        for state in tree_start_states.state
    ]

    # The sign of each step size encodes the integration direction.
    integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn,
        step_sizes=[
            tf.where(d, ss, -ss)
            for d, ss in zip(directions_expanded, step_size)
        ],
        num_steps=self.unrolled_leapfrog_steps)

    [
        candidate_tree_state,
        tree_final_states,
        final_not_divergence,
        continue_tree_final,
        energy_diff_tree_sum,
        momentum_tree_cumsum,
        leapfrogs_taken
    ] = self._build_sub_tree(
        directions_expanded,
        integrator,
        current_step_meta_info,
        # num_steps_at_this_depth = 2**iter_ = 1 << iter_
        tf.bitwise.left_shift(1, iter_),
        tree_start_states,
        initial_step_metastate.continue_tree,
        initial_step_metastate.not_divergence,
        momentum_state_memory)

    last_candidate_state = initial_step_metastate.candidate_state
    tree_weight = candidate_tree_state.weight
    if MULTINOMIAL_SAMPLE:
      # Weights are combined with `log_add_exp` in this branch.
      weight_sum = log_add_exp(tree_weight, last_candidate_state.weight)
      log_accept_thresh = tree_weight - last_candidate_state.weight
    else:
      # Weights are added directly, and their ratio gives the threshold.
      weight_sum = tree_weight + last_candidate_state.weight
      log_accept_thresh = tf.math.log(
          tf.cast(tree_weight, tf.float32) /
          tf.cast(last_candidate_state.weight, tf.float32))
    # Replace a NaN threshold with 0.
    log_accept_thresh = tf.where(
        tf.math.is_nan(log_accept_thresh),
        tf.zeros([], log_accept_thresh.dtype),
        log_accept_thresh)
    # `u` is log(1 - U) for U drawn by `tf.random.uniform`.
    u = tf.math.log1p(-tf.random.uniform(
        shape=batch_shape,
        dtype=log_accept_thresh.dtype,
        seed=self._seed_stream()))
    is_sample_accepted = u <= log_accept_thresh

    # Only adopt the subtree's candidate if the subtree may continue.
    choose_new_state = is_sample_accepted & continue_tree_final

    new_candidate_state = TreeDoublingStateCandidate(
        state=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _rightmost_expand_to_rank(
                    choose_new_state, prefer_static.rank(s0)),
                s0, s1)
            for s0, s1 in zip(candidate_tree_state.state,
                              last_candidate_state.state)
        ],
        target=tf.where(
            _rightmost_expand_to_rank(
                choose_new_state,
                prefer_static.rank(candidate_tree_state.target)),
            candidate_tree_state.target,
            last_candidate_state.target),
        target_grad_parts=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _rightmost_expand_to_rank(
                    choose_new_state, prefer_static.rank(grad0)),
                grad0, grad1)
            for grad0, grad1 in zip(
                candidate_tree_state.target_grad_parts,
                last_candidate_state.target_grad_parts)
        ],
        # Expand the mask to the rank of the tensor actually being selected
        # (`energy`), rather than borrowing the rank of `target`.
        energy=tf.where(
            _rightmost_expand_to_rank(
                choose_new_state,
                prefer_static.rank(candidate_tree_state.energy)),
            candidate_tree_state.energy,
            last_candidate_state.energy),
        weight=weight_sum)

    # Update left right information of the trajectory, and check trajectory
    # level U turn
    tree_otherend_states = tf.nest.map_structure(
        lambda v: tf.where(  # pylint: disable=g-long-lambda
            _rightmost_expand_to_rank(direction, prefer_static.rank(v[1])),
            v[0], v[1]),
        initial_step_state)

    # Re-stack the two trajectory ends: the subtree's final state takes the
    # slot of the end that was extended, the untouched end keeps the other.
    new_step_state = tf.nest.pack_sequence_as(
        initial_step_state,
        [
            tf.stack(
                [  # pylint: disable=g-complex-comprehension
                    tf.where(
                        _rightmost_expand_to_rank(
                            direction, prefer_static.rank(left)),
                        right, left),
                    tf.where(
                        _rightmost_expand_to_rank(
                            direction, prefer_static.rank(left)),
                        left, right),
                ],
                axis=0)
            for left, right in zip(tf.nest.flatten(tree_final_states),
                                   tf.nest.flatten(tree_otherend_states))
        ])

    # The U-turn check uses either the summed momentum or the difference
    # between the two trajectory ends, depending on the module-level flag.
    if GENERALIZED_UTURN:
      state_diff = momentum_tree_cumsum
    else:
      state_diff = [s[1] - s[0] for s in new_step_state.state]

    # `batch_shape` is a Tensor when the shape is not statically known, and
    # `len()` of a symbolic Tensor raises. Take the rank of `init_energy`
    # instead: it equals the length of its shape and resolves statically
    # whenever the rank is known.
    no_u_turns_trajectory = has_not_u_turn(
        state_diff,
        [m[0] for m in new_step_state.momentum],
        [m[1] for m in new_step_state.momentum],
        log_prob_rank=prefer_static.rank(init_energy))

    new_step_metastate = TreeDoublingMetaState(
        candidate_state=new_candidate_state,
        is_accepted=choose_new_state | initial_step_metastate.is_accepted,
        energy_diff_sum=(
            energy_diff_tree_sum + initial_step_metastate.energy_diff_sum),
        continue_tree=continue_tree_final & no_u_turns_trajectory,
        not_divergence=final_not_divergence,
        leapfrog_count=(initial_step_metastate.leapfrog_count +
                        leapfrogs_taken))

    return iter_ + 1, new_step_state, new_step_metastate
def _inverse_log_det_jacobian(self, y):
  """Returns a scalar zero log-det-Jacobian with the dtype of `y`."""
  scalar_shape = []
  return tf.zeros(scalar_shape, dtype=y.dtype)
def one_step(self, current_state, previous_kernel_results):
  """Runs one tree-doubling transition and returns the new state.

  Args:
    current_state: A tensor or a nested structure (per `tf.nest.is_nested`) of
      state parts. A bare tensor is wrapped in a list internally and unwrapped
      again before returning.
    previous_kernel_results: Results of the previous step. This method reads
      its `target_log_prob`, `grads_target_log_prob`, `step_size`,
      `momentum_state_memory` and `leapfrogs_taken` fields.

  Returns:
    A `(next_state, kernel_results)` pair, where `kernel_results` is a
    `NUTSKernelResults`.
  """
  with tf.name_scope(self.name + '.one_step'):
    state_was_bare = not tf.nest.is_nested(current_state)
    if state_was_bare:
      current_state = [current_state]

    prev = previous_kernel_results
    target_log_prob = prev.target_log_prob
    init_momentum, init_energy, log_slice_sample = (
        self._start_trajectory_batched(current_state, target_log_prob))

    def _duplicate(part):
      # Multiplying by ones of shape [2, 1, ..., 1] broadcasts `part` into two
      # stacked copies along a new leading axis.
      pair_shape = prefer_static.pad(
          [2], paddings=[[0, prefer_static.rank(part)]], constant_values=1)
      return part * prefer_static.ones(pair_shape, dtype=part.dtype)

    start_state = TreeDoublingState(
        momentum=init_momentum,
        state=current_state,
        target=target_log_prob,
        target_grad_parts=prev.grads_target_log_prob)
    initial_step_state = tf.nest.map_structure(_duplicate, start_state)

    init_weight = (
        tf.zeros_like(init_energy) if MULTINOMIAL_SAMPLE
        else tf.ones_like(init_energy, dtype=TREE_COUNT_DTYPE))

    start_candidate = TreeDoublingStateCandidate(
        state=current_state,
        target=target_log_prob,
        target_grad_parts=prev.grads_target_log_prob,
        energy=init_energy,
        weight=init_weight)

    initial_step_metastate = TreeDoublingMetaState(
        candidate_state=start_candidate,
        is_accepted=tf.zeros_like(init_energy, dtype=tf.bool),
        energy_diff_sum=tf.zeros_like(init_energy),
        leapfrog_count=tf.zeros_like(init_energy, dtype=TREE_COUNT_DTYPE),
        continue_tree=tf.ones_like(init_energy, dtype=tf.bool),
        not_divergence=tf.ones_like(init_energy, dtype=tf.bool))

    # The write/read instructions are wrapped in TensorArrays so that they are
    # compatible with XLA.
    write_array = tf.TensorArray(
        TREE_COUNT_DTYPE,
        size=2**(self.max_tree_depth - 1),
        clear_after_read=False)
    write_instruction = write_array.unstack(self._write_instruction)
    read_array = tf.TensorArray(
        tf.int32,
        size=2**(self.max_tree_depth - 1),
        clear_after_read=False)
    read_instruction = read_array.unstack(self._read_instruction)

    current_step_meta_info = OneStepMetaInfo(
        log_slice_sample=log_slice_sample,
        init_energy=init_energy,
        write_instruction=write_instruction,
        read_instruction=read_instruction)

    def _keep_doubling(depth, unused_state, metastate):
      # Stop at the depth cap or once no batch member wants to continue.
      return ((depth < self.max_tree_depth) &
              tf.reduce_any(metastate.continue_tree))

    def _double_tree(depth, state, metastate):
      return self.loop_tree_doubling(
          prev.step_size,
          prev.momentum_state_memory,
          current_step_meta_info,
          depth,
          state,
          metastate)

    loop_result = tf.while_loop(
        cond=_keep_doubling,
        body=_double_tree,
        loop_vars=(tf.zeros([], dtype=tf.int32, name='iter'),
                   initial_step_state,
                   initial_step_metastate),
        parallel_iterations=TF_WHILE_PARALLEL_ITERATIONS,
    )
    final_metastate = loop_result[2]
    final_candidate = final_metastate.candidate_state

    leapfrog_count_as_float = tf.cast(
        final_metastate.leapfrog_count,
        dtype=final_metastate.energy_diff_sum.dtype)
    log_accept_ratio = tf.math.log(
        final_metastate.energy_diff_sum / leapfrog_count_as_float)
    # TODO(junpenglao): return non-cumulated leapfrogs_taken once
    # benchmarking is done.
    cumulative_leapfrogs = (
        prev.leapfrogs_taken +
        final_metastate.leapfrog_count * self.unrolled_leapfrog_steps)
    has_divergence = ~final_metastate.not_divergence

    kernel_results = NUTSKernelResults(
        target_log_prob=final_candidate.target,
        grads_target_log_prob=final_candidate.target_grad_parts,
        momentum_state_memory=prev.momentum_state_memory,
        step_size=prev.step_size,
        log_accept_ratio=log_accept_ratio,
        leapfrogs_taken=cumulative_leapfrogs,
        is_accepted=final_metastate.is_accepted,
        reach_max_depth=final_metastate.continue_tree,
        has_divergence=has_divergence,
        energy=final_candidate.energy)

    next_state = final_candidate.state
    if state_was_bare:
      next_state = next_state[0]
    return next_state, kernel_results
def transport_implicit_gradients(derivative_cost, transport_matrix, eps, b,
                                 d_p):
  """Application of the transpose of the Jacobians dP/dx and dP/db.

  This is applied to a perturbation of the size of the transport matrix.
  Required to back-propagate through Sinkhorn's output.

  The indexing below treats `transport_matrix` as rank 3 and `b` as rank 2,
  both with a leading batch dimension.

  Args:
   derivative_cost: the derivative of the cost function.
   transport_matrix: the obtained transport matrix tensor.
   eps: the value of the entropic regularization parameter.
   b: the target weights.
   d_p: the perturbation of the transport matrix.

  Returns:
   A list of two tensors that correspond to the application of the transpose
   of dP/dx and dP/db on dP.
  """
  # Build constants in the dtype of the transport matrix so that non-float32
  # inputs do not fail on a dtype mismatch in `tf.concat` or the subtraction.
  dtype = transport_matrix.dtype
  batch_size = tf.shape(b)[0]
  m = tf.shape(b)[1]

  invmargin1 = tf.math.reciprocal(tf.reduce_sum(transport_matrix, axis=2))
  m1 = invmargin1[:, 1:, tf.newaxis] * transport_matrix[:, 1:, :]
  # The first row is replaced by zeros.
  m1 = tf.concat(
      [tf.zeros([tf.shape(m1)[0], 1, tf.shape(m1)[2]], dtype=dtype), m1],
      axis=1)

  invmargin2 = tf.math.reciprocal(tf.reduce_sum(transport_matrix, axis=1))
  m2 = invmargin2[:, :, tf.newaxis] * tf.transpose(transport_matrix, [0, 2, 1])
  eye_m = tf.eye(m, batch_shape=[batch_size], dtype=dtype)
  schur = eye_m - tf.linalg.matmul(m2, m1)

  def _solve_against_schur(d_p):
    """Returns `(d_p_p, u_f, inv_m1_transpose_u_f, inv_u_g)` for `d_p`."""
    d_p_p = d_p * transport_matrix
    u_f = tf.reduce_sum(d_p_p, axis=2) / eps
    u_g = tf.reduce_sum(d_p_p, axis=1) / eps
    m1_transpose_u_f = tf.linalg.matvec(m1, u_f, transpose_a=True)
    # Both right-hand sides are solved against the transposed Schur
    # complement in a single call.
    to_invert = tf.concat(
        [m1_transpose_u_f[:, :, tf.newaxis], u_g[:, :, tf.newaxis]], axis=2)
    inverses = tf.linalg.solve(tf.transpose(schur, [0, 2, 1]), to_invert)
    return d_p_p, u_f, inverses[:, :, 0], inverses[:, :, 1]

  # The solve is shared by both Jacobian products, so it is built only once.
  d_p_p, u_f, inv_m1_transpose_u_f, inv_u_g = _solve_against_schur(d_p)

  def jac_b_p_transpose():
    """Transpose of the Jacobian of the transport w.r.t. the target weights."""
    jac_2 = -inv_m1_transpose_u_f + inv_u_g
    return eps * jac_2 / b

  def jac_x_p_transpose():
    """Transpose of the Jacobian of the transport w.r.t. the inputs."""
    c_x = -tf.reduce_sum(derivative_cost * d_p_p, axis=2) / eps
    jac_1 = u_f + tf.linalg.matvec(
        m2, inv_m1_transpose_u_f - inv_u_g, transpose_a=True)
    jac_2 = -inv_m1_transpose_u_f + inv_u_g
    jac_1 = jac_1 * tf.reduce_sum(m1 * derivative_cost, axis=2)
    jac_2 = tf.linalg.matvec(
        tf.transpose(m2, [0, 2, 1]) * derivative_cost, jac_2)
    return c_x + jac_1 + jac_2

  return [jac_x_p_transpose(), jac_b_p_transpose()]
def testNonEmptyConstantTensor(self):
  """Checks `prefer_static_value` returns a numpy array for a zeros tensor."""
  dims = (2, 3, 4)
  value = distribution_util.prefer_static_value(tf.zeros(dims))
  self.assertIsInstance(value, np.ndarray)
  self.assertAllEqual(np.zeros(dims), value)
def _batch_of_zeros_with_rightmost_singletons(n_singletons):
  """Return Tensor of zeros shaped `batch_shape + [1] * n_singletons`.

  `batch_shape` and `dtype` are taken from the enclosing scope.
  """
  trailing_ones = tf.ones(shape=[n_singletons], dtype=tf.int32)
  full_shape = tf.concat([batch_shape, trailing_ones], axis=0)
  return tf.zeros(shape=full_shape, dtype=dtype)
def _entropy(self):
  """Returns zeros of the batch shape, in this distribution's dtype."""
  batch_shape = self.batch_shape_tensor()
  return tf.zeros(batch_shape, dtype=self.dtype)
def _batch_interp_with_gather_nd(x, x_ref_min, x_ref_max, y_ref, nd,
                                 fill_value, batch_dims):
  """N-D interpolation that works with leading batch dims.

  Args:
    x: Points to interpolate at, shape `[A1, ..., An, D, nd]` with
      `n = batch_dims`.
    x_ref_min: Lower edge of the reference grid, shape `[A1, ..., An, nd]`.
    x_ref_max: Upper edge of the reference grid, shape `[A1, ..., An, nd]`.
    y_ref: Interpolation table, shape
      `[A1, ..., An, C1, ..., Cnd, B1, ..., BM]`.
    nd: Python int, number of interpolation dimensions.
    fill_value: If this is a numeric tensor (per
      `tf.debugging.is_numeric_tensor`), it replaces the result wherever `x`
      falls outside the reference grid in any dimension. Otherwise
      out-of-range points use the clipped indices.
    batch_dims: Python int, number of leading batch dimensions.

  Returns:
    The interpolated values. Entries are NaN wherever the fractional index of
    `x` was NaN.
  """
  dtype = x.dtype

  # In this function,
  # x.shape = [A1, ..., An, D, nd], where n = batch_dims
  # and
  # y_ref.shape = [A1, ..., An, C1, C2,..., Cnd, B1,...,BM]
  # y_ref[A1, ..., An, i1,...,ind] is a shape [B1,...,BM] Tensor with the value
  # at index [i1,...,ind] in the interpolation table.
  # x_ref_min and x_ref_max have shapes [A1, ..., An, nd].

  # ny[k] is number of y reference points in interp dim k.
  ny = tf.cast(tf.shape(y_ref)[batch_dims:batch_dims + nd], dtype)

  # Map [x_ref_min, x_ref_max] to [0, ny - 1].
  # This is the (fractional) index of x.
  # x_idx_unclipped[A1, ..., An, d, k] is the fractional index into dim k of
  # interpolation table for the dth x value.
  x_ref_min_expanded = tf.expand_dims(x_ref_min, axis=-2)
  x_ref_max_expanded = tf.expand_dims(x_ref_max, axis=-2)
  x_idx_unclipped = (ny - 1) * (x - x_ref_min_expanded) / (
      x_ref_max_expanded - x_ref_min_expanded)

  # Wherever x is NaN, x_idx_unclipped will be NaN as well.
  # Keep track of the nan indices here (so we can impute NaN later).
  # Also replace any NaN indices with 0, since the indices are later cast to
  # int32, which has no NaN representation.
  nan_idx = tf.math.is_nan(x_idx_unclipped)
  x_idx_unclipped = tf.where(nan_idx, 0., x_idx_unclipped)

  # x_idx.shape = [A1, ..., An, D, nd]
  x_idx = tf.clip_by_value(x_idx_unclipped, tf.zeros((), dtype=dtype), ny - 1)

  # Get the index above and below x_idx.
  # Naively we could set idx_below = floor(x_idx), idx_above = ceil(x_idx),
  # however, this results in idx_below == idx_above whenever x is on a grid.
  # This in turn results in y_ref_below == y_ref_above, and then the gradient
  # at this point is zero.  So here we 'jitter' one of idx_below, idx_above,
  # so that they are at different values.  This jittering does not affect the
  # interpolated value, but does make the gradient nonzero (unless of course
  # the y_ref values are the same).
  idx_below = tf.floor(x_idx)
  idx_above = tf.minimum(idx_below + 1, ny - 1)
  idx_below = tf.maximum(idx_above - 1, 0)

  # These are the values of y_ref corresponding to above/below indices.
  # idx_below_int32.shape = x.shape[:-1] + [nd]
  idx_below_int32 = tf.cast(idx_below, dtype=tf.int32)
  idx_above_int32 = tf.cast(idx_above, dtype=tf.int32)

  # idx_below_list is a length nd list of shape x.shape[:-1] int32 tensors.
  idx_below_list = tf.unstack(idx_below_int32, axis=-1)
  idx_above_list = tf.unstack(idx_above_int32, axis=-1)

  # Use t to get a convex combination of the below/above values.
  # t.shape = [A1, ..., An, D, nd]
  t = x_idx - idx_below

  # x, and tensors shaped like x, need to be added to, and selected with
  # (using tf.where) the output y.  This requires appending singletons.
  def _expand_x_fn(tensor):
    # Reshape tensor to tensor.shape + [1] * M.
    extended_shape = tf.concat([
        tf.shape(tensor),
        tf.ones_like(tf.shape(y_ref)[batch_dims + nd:])
    ], axis=0)
    return tf.reshape(tensor, extended_shape)

  # Now, t.shape = [A1, ..., An, D, nd] + [1] * (rank(y_ref) - nd - batch_dims)
  t = _expand_x_fn(t)
  s = 1 - t

  # Re-insert NaN wherever x was NaN.
  nan_idx = _expand_x_fn(nan_idx)
  t = tf.where(nan_idx, tf.constant(np.nan, dtype), t)

  terms = []
  # Our work above has located x's fractional index inside a cube of above/below
  # indices. The distance to the below indices is t, and to the above indices
  # is s.
  # Drawing lines from x to the cube walls, we get 2**nd smaller cubes. Each
  # term in the result is a product of a reference point, gathered from y_ref,
  # multiplied by a volume.  The volume is that of the cube opposite to the
  # reference point.  E.g. if the reference point is below x in every axis, the
  # volume is that of the cube with corner above x in every axis,
  # s[0]*...*s[nd-1].
  # We could probably do this with one massive gather, but that would be very
  # unreadable and un-debuggable.  It also would create a large Tensor.
  for zero_ones_list in _binary_count(nd):
    gather_from_y_ref_idx = []
    opposite_volume_t_idx = []
    opposite_volume_s_idx = []
    for k, zero_or_one in enumerate(zero_ones_list):
      if zero_or_one == 0:
        # If the kth iterate has zero_or_one = 0,
        # Will gather from the 'below' reference point along axis k.
        gather_from_y_ref_idx.append(idx_below_list[k])
        # Now append the index to gather for computing opposite_volume.
        # This could be done by initializing opposite_volume to 1, then here:
        #  opposite_volume *= tf.gather(s, indices=k, axis=tf.rank(x) - 1)
        # but that puts a gather in the 'inner loop.'  Better to append the
        # index and do one larger gather down below.
        opposite_volume_s_idx.append(k)
      else:
        gather_from_y_ref_idx.append(idx_above_list[k])
        # Append an index to gather, having the same effect as
        #   opposite_volume *= tf.gather(t, indices=k, axis=tf.rank(x) - 1)
        opposite_volume_t_idx.append(k)

    # Compute opposite_volume (volume of cube opposite the ref point):
    # Recall t.shape = s.shape = [D, nd] + [1, ..., 1]
    # Gather from t and s along the 'nd' axis, which is rank(x) - 1.
    ov_axis = tf.rank(x) - 1
    opposite_volume = (tf.reduce_prod(
        tf.gather(t,
                  indices=tf.cast(opposite_volume_t_idx, dtype=tf.int32),
                  axis=ov_axis),
        axis=ov_axis) * tf.reduce_prod(tf.gather(
            s,
            indices=tf.cast(opposite_volume_s_idx, dtype=tf.int32),
            axis=ov_axis),
                                       axis=ov_axis))  # pyformat: disable

    y_ref_pt = tf.gather_nd(y_ref,
                            tf.stack(gather_from_y_ref_idx, axis=-1),
                            batch_dims=batch_dims)

    terms.append(y_ref_pt * opposite_volume)

  y = tf.math.add_n(terms)

  if tf.debugging.is_numeric_tensor(fill_value):
    # Recall x_idx_unclipped.shape = [D, nd],
    # so here we check if it was out of bounds in any of the nd dims.
    # Thus, oob_idx.shape = [D].
    oob_idx = tf.reduce_any(
        (x_idx_unclipped < 0) | (x_idx_unclipped > ny - 1), axis=-1)

    # Now, y.shape = [D, B1,...,BM], so we'll have to broadcast oob_idx.
    oob_idx = _expand_x_fn(oob_idx)  # Shape [D, 1,...,1]
    # OR-ing with an all-False tensor of y's shape broadcasts the mask.
    oob_idx |= tf.fill(tf.shape(y), False)
    y = tf.where(oob_idx, fill_value, y)
  return y
def __init__(self,
             num_timesteps,
             coefficients,
             level_scale,
             initial_state_prior,
             observation_noise_scale=0.,
             name=None,
             **linear_gaussian_ssm_kwargs):
  """Build a state space model implementing an autoregressive process.

  Args:
    num_timesteps: Scalar `int` `Tensor` number of timesteps to model
      with this distribution.
    coefficients: `float` `Tensor` of shape `concat(batch_shape, [order])`
      defining the autoregressive coefficients. The coefficients are defined
      backwards in time:
      `coefficients[0] * level[t] + coefficients[1] * level[t-1] + ... +
      coefficients[order-1] * level[t-order+1]`. The last dimension must be
      statically known.
    level_scale: Scalar (any additional dimensions are treated as batch
      dimensions) `float` `Tensor` indicating the standard deviation of the
      transition noise at each step.
    initial_state_prior: instance of `tfd.MultivariateNormal`
      representing the prior distribution on latent states.  Must have
      event shape `[order]`. Its dtype is used for all other parameters.
    observation_noise_scale: Scalar (any additional dimensions are
      treated as batch dimensions) `float` `Tensor` indicating the standard
      deviation of the observation noise.
      Default value: 0.
    name: Python `str` name prefixed to ops created by this class.
      Default value: "AutoregressiveStateSpaceModel".
    **linear_gaussian_ssm_kwargs: Optional additional keyword arguments
      passed to the base `tfd.LinearGaussianStateSpaceModel` constructor.

  Raises:
    ValueError: if the last dimension of `coefficients` is not static.
  """
  # Snapshot the constructor arguments before any other local is bound.
  parameters = dict(locals())
  parameters.update(linear_gaussian_ssm_kwargs)
  del parameters['linear_gaussian_ssm_kwargs']
  with tf.name_scope(name or 'AutoregressiveStateSpaceModel') as name:
    # The initial state prior determines the dtype of sampled values.
    # Other model parameters must have the same dtype.
    dtype = initial_state_prior.dtype

    def _to_tensor(value, tensor_name):
      return tf.convert_to_tensor(value=value, name=tensor_name, dtype=dtype)

    coefficients = _to_tensor(coefficients, 'coefficients')
    level_scale = _to_tensor(level_scale, 'level_scale')
    observation_noise_scale = _to_tensor(
        observation_noise_scale, 'observation_noise_scale')

    static_order = tf.compat.dimension_value(coefficients.shape[-1])
    if static_order is None:
      raise ValueError('Autoregressive coefficients must have static shape.')

    self._order = static_order
    self._coefficients = coefficients
    self._level_scale = level_scale

    transition_matrix = make_ar_transition_matrix(coefficients)
    # Only the first latent dimension receives transition noise; the rest
    # get a zero scale.
    transition_scale_diag = tf.stack(
        [level_scale] + [tf.zeros_like(level_scale)] * (self.order - 1),
        axis=-1)
    transition_noise = tfd.MultivariateNormalDiag(
        scale_diag=transition_scale_diag)
    # The observation matrix is a one followed by `order - 1` zeros.
    observation_matrix = tf.concat(
        [tf.ones([1, 1], dtype=dtype),
         tf.zeros([1, self.order - 1], dtype=dtype)],
        axis=-1)
    observation_noise = tfd.MultivariateNormalDiag(
        scale_diag=observation_noise_scale[..., tf.newaxis])

    super(AutoregressiveStateSpaceModel, self).__init__(
        num_timesteps=num_timesteps,
        transition_matrix=transition_matrix,
        transition_noise=transition_noise,
        observation_matrix=observation_matrix,
        observation_noise=observation_noise,
        initial_state_prior=initial_state_prior,
        name=name,
        **linear_gaussian_ssm_kwargs)
    self._parameters = parameters
def normal_cdf(r):
  """Evaluates the CDF of a zero-mean, unit-scale `tfd.Normal` at `r`."""
  r = tf.convert_to_tensor(value=r, name='r')
  base_dtype = r.dtype.base_dtype
  loc = tf.zeros([], base_dtype)
  scale = tf.ones([], base_dtype)
  standard_normal = tfd.Normal(loc=loc, scale=scale)
  return standard_normal.cdf(r)
def params_model_fn(out_channels, size, in_channels, dtype):
  """Yields a root `LogNormal` named 'kernel'.

  The distribution has a zero `loc` of shape
  `list(size) + [in_channels, out_channels]` and a scale of 1.
  """
  kernel_shape = list(size) + [in_channels, out_channels]
  kernel_loc = tf.zeros(kernel_shape, dtype)
  kernel_prior = tfd.LogNormal(kernel_loc, 1., name='kernel')
  yield Root(kernel_prior)
def reduce_audio_in_batch(tensor, hparams=None, is_training=True):
  """Mixes several single-note clips into one averaged clip with labels.

  Each of the first `hparams.timbre_training_max_instruments` notes is
  delayed by a random number of zero samples. The clips are right-padded to a
  common length, averaged, padded to a fixed length and encoded as wav data.

  Args:
    tensor: Mapping with keys 'pitch', 'audio' (a sparse tensor, densified
      here) and 'instrument_family', each indexable by note.
    hparams: Hyperparameter object. Required despite the `None` default: this
      function reads `timbre_training_max_instruments`,
      `timbre_max_start_offset`, `timbre_max_len`, `timbre_num_classes` and
      `sample_rate` from it.
    is_training: Unused; kept for interface compatibility.

  Returns:
    A dict with keys 'audio' (wav-encoded string tensor), 'note_croppings'
    (int32 tensor) and 'instrument_families' (int32 one-hot tensor).
  """
  # NOTE(review): `len(...)` on tensors and the Python `if` on `end_idx` look
  # like they require eager execution -- confirm against the caller.
  instrument_count = hparams.timbre_training_max_instruments
  note_cropping_list = []
  instrument_family_list = []
  samples_list = []

  # Densify once; the original recomputed this twice per loop iteration.
  dense_audio = tf.sparse.to_dense(tensor['audio'])

  for i in range(instrument_count):
    pitch = tensor['pitch'][i]
    note_audio = dense_audio[i]

    # Move the audio so there are different attack times.
    start_idx = tf.random.uniform((),
                                  minval=0,
                                  maxval=hparams.timbre_max_start_offset,
                                  dtype='int64')
    samples = K.concatenate([tf.zeros(start_idx), note_audio])

    end_idx = (
        start_idx +
        tf.py_function(_get_approx_note_length, [note_audio], tf.int64))
    if hparams.timbre_max_len and end_idx > hparams.timbre_max_len:
      # Truncate notes that would run past the maximum clip length.
      samples = tf.slice(samples, begin=[0], size=[hparams.timbre_max_len])
      end_idx = hparams.timbre_max_len

    samples_list.append(samples)

    instrument_family = tensor['instrument_family'][i]
    note_cropping_list.append(
        timbre_dataset_util.NoteCropping(pitch=pitch,
                                         start_idx=start_idx,
                                         end_idx=end_idx))
    instrument_family_list.append(
        tf.one_hot(tf.cast(instrument_family, tf.int32),
                   hparams.timbre_num_classes))

  # Pad the end of the shorter audio clips.
  max_length = max((len(samples) for samples in samples_list), default=0)
  samples_list = [
      tf.pad(samples, [[0, max_length - len(samples)]])
      for samples in samples_list
  ]

  combined_samples = (
      tf.reduce_sum(tf.convert_to_tensor(samples_list), axis=0) /
      instrument_count)

  # Ensure all audios in batches are the same length.
  if hparams.timbre_max_len:
    pad_length = hparams.timbre_max_len
  else:
    pad_length = hparams.timbre_max_start_offset + 5 * hparams.sample_rate
  combined_samples = tf.pad(
      combined_samples, [[0, pad_length - tf.shape(combined_samples)[0]]])

  note_croppings = tf.convert_to_tensor(note_cropping_list, dtype=tf.int32)
  instrument_families = tf.convert_to_tensor(instrument_family_list,
                                             dtype=tf.int32)

  wav_data = tf.py_function(
      lambda x: audio_io.samples_to_wav_data(
          x.numpy(), sample_rate=hparams.sample_rate),
      [combined_samples],
      tf.string)

  return dict(
      audio=wav_data,
      note_croppings=note_croppings,
      instrument_families=instrument_families,
  )
def soft_multivariate_quantiles(x, quantiles, quantile_width=None, **kwargs):
  """Computes soft multivariate quantiles via optimal transport.

  Transport multivariate input values in x onto 2^d + 1 weighted points,
  {-1, 1}^d plus the origin. Target weights are adjusted so
  that those values in x that are transported to the origin (the first target
  point) correspond to those concentrating around the quantile of interest.

  Args:
   x: Tensor<float> of shape [batch, N, d] with statically known dimensions.
   quantiles: Tensor<float> of shape [r, d], r targeted quantiles of dimension
     d.
   quantile_width: (float) mass given to the bucket supposed to attract points
     whose value concentrate around the desired quantile value. Bigger width
     means that we allow the soft quantile to be a mixture of
     more points further away from the quantile. If None, the width is set at
     2/N where N is the number of values considered (`x.shape[1]`).
   **kwargs: see sinkhorn.autodiff_sinkhorn for possible extra parameters.

  Returns:
    A Tensor<float> of shape [batch, r, d] of multivariate quantiles per batch.
  """
  quantiles = tf.constant(quantiles, tf.float32)
  batch_size = x.shape[0]
  # Keep the static integer for shapes and a float copy for arithmetic; the
  # original passed the float tensor to `tf.ones` and `tf.reshape` as a
  # dimension.
  n = x.shape[1]
  n_float = tf.cast(n, tf.float32)
  d = x.shape[2]
  if quantile_width is None:
    quantile_width = 2 / n_float
  num_quantiles = tf.shape(quantiles)[0]
  hypercube_vertices = tf.constant(
      list(itertools.product([-1, 1], repeat=d)), tf.float32)

  # Weights attached to vertices for each quantile; after the product over the
  # last axis this has shape [num_quantiles, 2^d].
  weights = quantiles[:, tf.newaxis, :]**(
      0.5 * (1 - hypercube_vertices))[tf.newaxis, Ellipsis]
  weights *= (1 - quantiles)[:, tf.newaxis, :]**(
      0.5 * (1 + hypercube_vertices))[tf.newaxis, Ellipsis]
  weights = (1 - quantile_width) * tf.reduce_prod(weights, axis=2)

  # Adding weights for the quantile itself (in position 0).
  weights = tf.concat(
      (quantile_width * tf.ones((num_quantiles, 1)), weights), axis=1)
  # Augmenting and formatting as [batch_size, 2^d + 1, num_quantiles].
  weights = tf.reshape(
      tf.tile(tf.transpose(weights), [batch_size, 1]),
      [batch_size, 2**d + 1, num_quantiles])

  # Set target locations, by adding the point at 0 that will absorb the
  # quantile, then augment with batch_size.
  y = tf.concat(
      (tf.zeros((1, d), dtype=tf.float32), hypercube_vertices), axis=0)
  y = tf.reshape(tf.tile(y, [batch_size, 1]), [batch_size, 2**d + 1, d])

  # Center x.
  x_mean = tf.reduce_mean(x, axis=1)
  x = x - x_mean[:, tf.newaxis, :]
  transports = sinkhorn.autodiff_sinkhorn(
      x, y,
      tf.ones([batch_size, n, num_quantiles], dtype=tf.float32) / n_float,
      weights, **kwargs)

  # Recover convex combinations resulting from transporting to the central
  # point in all batches and quantile variations.
  transports = 1 / quantile_width * tf.reshape(
      transports[:, :, 0, :], [batch_size, n, -1])
  # Apply these convex combinations to data points + recenter.
  all_soft_quantiles = tf.reduce_sum(
      transports[:, :, :, tf.newaxis] * x[:, :, tf.newaxis, :],
      axis=1) + x_mean[:, tf.newaxis, :]
  # Reshape those quantiles after having applied convex combinations.
  return tf.reshape(all_soft_quantiles, [batch_size, num_quantiles, d])
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Runs one seeded tree-doubling transition and returns the new state.

  Args:
    current_state: A tensor or list-like of state parts. It is flattened with
      `tf.nest.flatten` and the result is packed back into the same structure.
    previous_kernel_results: Results of the previous step. This method reads
      its `target_log_prob`, `grads_target_log_prob`, `step_size` and
      `momentum_state_memory` fields.
    seed: Optional seed; it is sanitized, split for the trajectory start and
      the doubling loop, and stored in the returned kernel results.

  Returns:
    A `(next_state, kernel_results)` pair, where `kernel_results` is a
    `NUTSKernelResults`.

  Raises:
    TypeError: if `current_state` is nested but not a flat list-like.
  """
  seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
  start_trajectory_seed, loop_seed = samplers.split_seed(seed)

  with tf.name_scope(self.name + '.one_step'):
    state_structure = current_state
    current_state = tf.nest.flatten(current_state)
    if tf.nest.is_nested(state_structure):
      is_flat_list = (
          mcmc_util.is_list_like(state_structure) and
          len(current_state) == len(state_structure))
      if not is_flat_list:
        # TODO(b/170865194): Support dictionaries and other non-list-like
        # state.
        raise TypeError(
            'NUTS does not currently support nested or '
            'non-list-like state structures (saw: {}).'.format(
                state_structure))

    prev = previous_kernel_results
    target_log_prob = prev.target_log_prob
    init_momentum, init_energy, log_slice_sample = (
        self._start_trajectory_batched(current_state,
                                       target_log_prob,
                                       seed=start_trajectory_seed))

    def _duplicate(part):
      # Multiplying by ones of shape [2, 1, ..., 1] broadcasts `part` into two
      # stacked copies along a new leading axis.
      pair_shape = ps.pad(
          [2], paddings=[[0, ps.rank(part)]], constant_values=1)
      return part * ps.ones(pair_shape, dtype=part.dtype)

    start_state = TreeDoublingState(
        momentum=init_momentum,
        state=current_state,
        target=target_log_prob,
        target_grad_parts=prev.grads_target_log_prob)
    initial_step_state = tf.nest.map_structure(_duplicate, start_state)

    if MULTINOMIAL_SAMPLE:
      init_weight = tf.zeros_like(init_energy)  # log(exp(H0 - H0))
    else:
      init_weight = tf.ones_like(init_energy, dtype=TREE_COUNT_DTYPE)

    start_candidate = TreeDoublingStateCandidate(
        state=current_state,
        target=target_log_prob,
        target_grad_parts=prev.grads_target_log_prob,
        energy=init_energy,
        weight=init_weight)

    initial_step_metastate = TreeDoublingMetaState(
        candidate_state=start_candidate,
        is_accepted=tf.zeros_like(init_energy, dtype=tf.bool),
        momentum_sum=init_momentum,
        energy_diff_sum=tf.zeros_like(init_energy),
        leapfrog_count=tf.zeros_like(init_energy, dtype=TREE_COUNT_DTYPE),
        continue_tree=tf.ones_like(init_energy, dtype=tf.bool),
        not_divergence=tf.ones_like(init_energy, dtype=tf.bool))

    # The write/read instructions are wrapped in TensorArrays so that they are
    # compatible with XLA.
    write_array = tf.TensorArray(
        TREE_COUNT_DTYPE,
        size=len(self._write_instruction),
        clear_after_read=False)
    write_instruction = write_array.unstack(self._write_instruction)
    read_array = tf.TensorArray(
        tf.int32,
        size=len(self._read_instruction),
        clear_after_read=False)
    read_instruction = read_array.unstack(self._read_instruction)

    current_step_meta_info = OneStepMetaInfo(
        log_slice_sample=log_slice_sample,
        init_energy=init_energy,
        write_instruction=write_instruction,
        read_instruction=read_instruction)

    def _keep_doubling(depth, unused_seed, unused_state, metastate):
      # Stop at the depth cap or once no batch member wants to continue.
      return ((depth < self.max_tree_depth) &
              tf.reduce_any(metastate.continue_tree))

    def _double_tree(depth, iteration_seed, state, metastate):
      return self._loop_tree_doubling(
          prev.step_size,
          prev.momentum_state_memory,
          current_step_meta_info,
          depth,
          state,
          metastate,
          iteration_seed)

    loop_result = tf.while_loop(
        cond=_keep_doubling,
        body=_double_tree,
        loop_vars=(tf.zeros([], dtype=tf.int32, name='iter'),
                   loop_seed,
                   initial_step_state,
                   initial_step_metastate),
        parallel_iterations=self.parallel_iterations,
    )
    final_metastate = loop_result[3]
    final_candidate = final_metastate.candidate_state

    leapfrog_count_as_float = tf.cast(
        final_metastate.leapfrog_count,
        dtype=final_metastate.energy_diff_sum.dtype)
    log_accept_ratio = tf.math.log(
        final_metastate.energy_diff_sum / leapfrog_count_as_float)
    leapfrogs_taken = (
        final_metastate.leapfrog_count * self.unrolled_leapfrog_steps)
    has_divergence = ~final_metastate.not_divergence

    kernel_results = NUTSKernelResults(
        target_log_prob=final_candidate.target,
        grads_target_log_prob=final_candidate.target_grad_parts,
        momentum_state_memory=prev.momentum_state_memory,
        step_size=prev.step_size,
        log_accept_ratio=log_accept_ratio,
        leapfrogs_taken=leapfrogs_taken,
        is_accepted=final_metastate.is_accepted,
        reach_max_depth=final_metastate.continue_tree,
        has_divergence=has_divergence,
        energy=final_candidate.energy,
        seed=seed,
    )

    next_state = tf.nest.pack_sequence_as(
        state_structure, final_candidate.state)
    return next_state, kernel_results
def _sample_n(self, n, seed=None):
  """Draws `n` samples by sampling the mean-axis coordinate, then rotating.

  The coordinate along the first basis vector is drawn either by
  `self._sample_3d` (when the event size is 3) or by a rejection loop. The
  remaining coordinates are a normalized Gaussian draw scaled to give unit
  length. The result is passed through `self._rotate`.

  Args:
    n: Number of samples; becomes the leading dimension of the sample shape.
    seed: Seed, split into one seed for the first coordinate and one for the
      remaining coordinates.

  Returns:
    The result of `self._rotate` applied to the unit-length samples.
  """
  dim0_seed, otherdims_seed = samplers.split_seed(
      seed, salt='von_mises_fisher')
  # The sampling strategy relies on the fact that vMF variates are symmetric
  # about the mean direction. Accordingly, if we have a sampling strategy for
  # the away-from-mean angle, then we can uniformly sample the remaining
  # dimensions on the S^{dim-2} sphere, and rotate these samples from a
  # (1, 0, 0, ..., 0)-mode distribution into the target orientation.
  #
  # This is easy to imagine on the 1-sphere (S^1; in 2-D space): sample a
  # von-Mises distributed `x` value in [-1, 1], then uniformly select what
  # amounts to a "up" or "down" additional degree of freedom after unit
  # normalizing, followed by a final rotation to the desired mean direction
  # from a basis of (1, 0).
  #
  # On S^2 (in 3-D), selecting a vMF `x` identifies a circle in `yz` on the
  # unit sphere over which the distribution is uniform, in particular the
  # circle where x = \hat{x} intersects the unit sphere. We pick a point on
  # that circle, then rotate to the desired mean direction from a basis of
  # (1, 0, 0).
  mean_direction = tf.convert_to_tensor(self.mean_direction)
  concentration = tf.convert_to_tensor(self.concentration)
  # Prefer the static event size; fall back to the dynamic one.
  event_dim = (
      tf.compat.dimension_value(self.event_shape[0]) or
      self._event_shape_tensor(mean_direction=mean_direction)[0])

  sample_batch_shape = ps.concat(
      [[n],
       self._batch_shape_tensor(mean_direction=mean_direction,
                                concentration=concentration)],
      axis=0)
  dim = tf.cast(event_dim - 1, self.dtype)
  if event_dim == 3:
    # NOTE(review): this branch adds no trailing axis, so `_sample_3d`
    # presumably already returns one -- confirm.
    samples_dim0 = self._sample_3d(n,
                                   mean_direction=mean_direction,
                                   concentration=concentration,
                                   seed=dim0_seed)
  else:
    # Wood'94 provides a rejection algorithm to sample the x coordinate.
    # Wood'94 definition of b:
    # b = (-2 * kappa + tf.sqrt(4 * kappa**2 + dim**2)) / dim
    # https://stats.stackexchange.com/questions/156729 suggests:
    b = dim / (2 * concentration + tf.sqrt(4 * concentration**2 + dim**2))
    # TODO(bjp): Integrate any useful numerical tricks from hyperspherical VAE
    #     https://github.com/nicola-decao/s-vae-tf/
    x = (1 - b) / (1 + b)
    c = concentration * x + dim * tf.math.log1p(-x**2)
    beta = beta_lib.Beta(dim / 2, dim / 2)

    def cond_fn(w, should_continue, seed):
      # Loop until every entry has left the rejection loop.
      del w, seed
      return tf.reduce_any(should_continue)

    def body_fn(w, should_continue, seed):
      """While loop body for sampling the angle `w`."""
      beta_seed, unif_seed, next_seed = samplers.split_seed(seed, n=3)
      z = beta.sample(sample_shape=sample_batch_shape, seed=beta_seed)
      # set_shape needed here because of b/139013403
      tensorshape_util.set_shape(z, w.shape)
      # Only entries still in the loop receive a new proposal.
      w = tf.where(should_continue,
                   (1. - (1. + b) * z) / (1. - (1. - b) * z),
                   w)
      if not self.allow_nan_stats:
        w = tf.debugging.check_numerics(w, 'w')
      unif = samplers.uniform(
          sample_batch_shape, seed=unif_seed, dtype=self.dtype)
      # set_shape needed here because of b/139013403
      tensorshape_util.set_shape(unif, w.shape)
      should_continue = should_continue & (
          concentration * w + dim * tf.math.log1p(-x * w) - c <
          # Use log1p(-unif) to prevent log(0) and ensure that log(1) is
          # possible.
          tf.math.log1p(-unif))
      return w, should_continue, next_seed

    w = tf.zeros(sample_batch_shape, dtype=self.dtype)
    should_continue = tf.ones(sample_batch_shape, dtype=tf.bool)

    samples_dim0, _, _ = tf.while_loop(
        cond=cond_fn,
        body=body_fn,
        loop_vars=(w, should_continue, dim0_seed))
    samples_dim0 = samples_dim0[..., tf.newaxis]

  if not self._allow_nan_stats:
    # Verify samples are w/in -1, 1, with useful error output tensors (top
    # value rather than all values).
    with tf.control_dependencies([
        assert_util.assert_less_equal(
            samples_dim0,
            dtype_util.as_numpy_dtype(self.dtype)(1.01)),
        assert_util.assert_greater_equal(
            samples_dim0,
            dtype_util.as_numpy_dtype(self.dtype)(-1.01)),
    ]):
      samples_dim0 = tf.identity(samples_dim0)
  samples_otherdims_shape = ps.concat(
      [sample_batch_shape, [event_dim - 1]], axis=0)
  unit_otherdims = tf.math.l2_normalize(
      samplers.normal(
          samples_otherdims_shape, seed=otherdims_seed, dtype=self.dtype),
      axis=-1)
  samples = tf.concat(
      [
          samples_dim0,  # we must avoid sqrt(1 - (>1)**2)
          tf.sqrt(tf.maximum(1 - samples_dim0**2, 0.)) * unit_otherdims
      ],
      axis=-1)
  samples = tf.math.l2_normalize(samples, axis=-1)
  if not self.allow_nan_stats:
    samples = tf.debugging.check_numerics(samples, 'samples')

  # Runtime assert that samples are unit length.
  if not self.allow_nan_stats:
    # `top_k` with its default k picks the single largest norm deviation.
    worst, _ = tf.math.top_k(
        tf.reshape(tf.abs(1 - tf.linalg.norm(samples, axis=-1)), [-1]))
    with tf.control_dependencies([
        assert_util.assert_near(
            dtype_util.as_numpy_dtype(self.dtype)(0),
            worst,
            atol=1e-4,
            summarize=100)
    ]):
      samples = tf.identity(samples)
  # The samples generated are symmetric around a mode at (1, 0, 0, ...., 0).
  # Now, we move the mode to `self.mean_direction` using a rotation matrix.
  if not self.allow_nan_stats:
    # Assert that the basis vector rotates to the mean direction, as expected.
    basis = tf.cast(
        tf.concat([[1.], tf.zeros([event_dim - 1])], axis=0), self.dtype)
    with tf.control_dependencies([
        assert_util.assert_less(
            tf.linalg.norm(
                self._rotate(basis, mean_direction=mean_direction) -
                mean_direction,
                axis=-1),
            dtype_util.as_numpy_dtype(self.dtype)(1e-5))
    ]):
      return self._rotate(samples, mean_direction=mean_direction)
  return self._rotate(samples, mean_direction=mean_direction)
def _build_sub_tree(self,
                    directions,
                    integrator,
                    current_step_meta_info,
                    nsteps,
                    initial_state,
                    continue_tree,
                    not_divergence,
                    momentum_state_memory,
                    seed,
                    name=None):
  """Grows one sub-tree by iterating `_loop_build_sub_tree` in a while loop.

  Initializes the candidate state, the accumulators and the counters, then
  runs `self._loop_build_sub_tree` inside `tf.while_loop` for as long as
  fewer than `nsteps` iterations have run and at least one entry of
  `continue_tree` is still true.

  Args:
    directions: Forwarded unchanged to `_loop_build_sub_tree`.
    integrator: Forwarded unchanged to `_loop_build_sub_tree`.
    current_step_meta_info: Structure whose `init_energy` field provides the
      batch shape (and dtype) of the accumulators; also forwarded to
      `_loop_build_sub_tree`.
    nsteps: Upper bound on the number of loop iterations.
    initial_state: Starting state of the sub-tree. Its `state`, `target`,
      `target_grad_parts` and `momentum` fields are read here.
    continue_tree: Initial value of the loop's continue flags.
    not_divergence: Initial value of the loop's non-divergence flags.
    momentum_state_memory: Initial value of the loop's momentum/state memory.
    seed: Initial seed threaded through the loop.
    name: Unused by this method.

  Returns:
    A tuple `(candidate_tree_state, final_state, final_not_divergence,
    final_continue_tree, energy_diff_tree_sum, momentum_tree_cumsum,
    leapfrogs_taken)` holding the loop's final values. The final
    `momentum_state_memory` is not part of the returned tuple.
  """
  with tf.name_scope('build_sub_tree'):
    batch_shape = ps.shape(current_step_meta_info.init_energy)
    # The initial state must never be selected as the candidate, so it gets
    # the weight that makes it unselectable under the active scheme.
    if not MULTINOMIAL_SAMPLE:
      init_weight = tf.zeros(batch_shape, dtype=TREE_COUNT_DTYPE)
    else:
      neg_inf = tf.constant(
          -np.inf, dtype=current_step_meta_info.init_energy.dtype)
      init_weight = tf.fill(batch_shape, neg_inf)

    zero_momentum_cumsum = [
        tf.zeros_like(part) for part in initial_state.momentum
    ]
    start_candidate = TreeDoublingStateCandidate(
        state=initial_state.state,
        target=initial_state.target,
        target_grad_parts=initial_state.target_grad_parts,
        energy=initial_state.target,
        weight=init_weight)
    zero_energy_diff_sum = tf.zeros_like(
        current_step_meta_info.init_energy, name='energy_diff_sum')

    def _keep_going(iter_, loop_seed, diff_sum, momentum_cumsum, n_leapfrogs,
                    state, state_c, still_continue, no_divergence, memory):
      """Loop condition: under the step budget and some chain still active."""
      return (iter_ < nsteps) & tf.reduce_any(still_continue)

    def _advance(iter_, loop_seed, diff_sum, momentum_cumsum, n_leapfrogs,
                 state, state_c, still_continue, no_divergence, memory):
      """Loop body: delegates one iteration to `_loop_build_sub_tree`."""
      return self._loop_build_sub_tree(
          directions,
          integrator,
          current_step_meta_info,
          iter_,
          diff_sum,
          momentum_cumsum,
          n_leapfrogs,
          state,
          state_c,
          still_continue,
          no_divergence,
          memory,
          loop_seed)

    # Positional order here must match the parameter order of the two local
    # functions above.
    loop_vars = (
        tf.zeros([], dtype=tf.int32, name='iter'),
        seed,
        zero_energy_diff_sum,
        zero_momentum_cumsum,
        tf.zeros(batch_shape, dtype=TREE_COUNT_DTYPE),
        initial_state,
        start_candidate,
        continue_tree,
        not_divergence,
        momentum_state_memory,
    )
    loop_result = tf.while_loop(
        cond=_keep_going,
        body=_advance,
        loop_vars=loop_vars,
        parallel_iterations=self.parallel_iterations)
    (
        _,
        _,
        energy_diff_tree_sum,
        momentum_tree_cumsum,
        leapfrogs_taken,
        final_state,
        candidate_tree_state,
        final_continue_tree,
        final_not_divergence,
        momentum_state_memory,
    ) = loop_result

  return (
      candidate_tree_state,
      final_state,
      final_not_divergence,
      final_continue_tree,
      energy_diff_tree_sum,
      momentum_tree_cumsum,
      leapfrogs_taken,
  )
def __init__(self,
             kernel,
             index_points=None,
             mean_fn=None,
             observation_noise_variance=0.,
             jitter=1e-6,
             validate_args=False,
             allow_nan_stats=False,
             name='GaussianProcess'):
  """Instantiate a GaussianProcess Distribution.

  Args:
    kernel: `PositiveSemidefiniteKernel`-like instance representing the
      GP's covariance function.
    index_points: `float` `Tensor` representing finite (batch of) vector(s)
      of points in the index set over which the GP is defined. Shape has the
      form `[b1, ..., bB, e, f1, ..., fF]` where `F` is the number of feature
      dimensions and must equal `kernel.feature_ndims` and `e` is the number
      (size) of index points in each batch. Ultimately this distribution
      corresponds to an `e`-dimensional multivariate normal. The batch shape
      must be broadcastable with `kernel.batch_shape` and any batch dims
      yielded by `mean_fn`.
    mean_fn: Python `callable` that acts on `index_points` to produce a
      (batch of) vector(s) of mean values at `index_points`. Takes a `Tensor`
      of shape `[b1, ..., bB, f1, ..., fF]` and returns a `Tensor` whose
      shape is broadcastable with `[b1, ..., bB]`.
      Default value: `None` implies constant zero function.
    observation_noise_variance: `float` `Tensor` representing (batch of)
      scalar variance(s) of the noise in the Normal likelihood distribution
      of the model. If batched, the batch shape must be broadcastable with
      the shapes of all other batched parameters (`kernel.batch_shape`,
      `index_points`, etc.).
      Default value: `0.`
    jitter: `float` scalar `Tensor` added to the diagonal of the covariance
      matrix to ensure positive definiteness of the covariance matrix.
      Default value: `1e-6`.
    validate_args: Python `bool`. When `True` distribution parameters are
      checked for validity despite possibly degrading runtime performance.
      When `False` invalid inputs may silently render incorrect outputs.
      Default value: `False`.
    allow_nan_stats: Python `bool`. When `True`, statistics (e.g., mean,
      mode, variance) use the value "`NaN`" to indicate the result is
      undefined. When `False`, an exception is raised if one or more of the
      statistic's batch members are undefined.
      Default value: `False`.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: "GaussianProcess".

  Raises:
    ValueError: if `mean_fn` is not `None` and is not callable.
  """
  # Snapshot the local namespace before any other local is bound, so the
  # dict holds only what was passed to the constructor.
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype(
        [index_points, observation_noise_variance, jitter], tf.float32)
    if index_points is not None:
      index_points = tf.convert_to_tensor(
          index_points, dtype=dtype, name='index_points')
    jitter = tf.convert_to_tensor(jitter, dtype=dtype, name='jitter')
    observation_noise_variance = tf.convert_to_tensor(
        observation_noise_variance,
        dtype=dtype,
        name='observation_noise_variance')

    self._kernel = kernel
    self._index_points = index_points

    if mean_fn is None:
      # No mean function supplied: fall back to a constant zero function in
      # the common dtype computed above.
      mean_fn = lambda x: tf.zeros([1], dtype=dtype)
    elif not callable(mean_fn):
      raise ValueError('`mean_fn` must be a Python callable')
    self._mean_fn = mean_fn

    self._observation_noise_variance = observation_noise_variance
    self._jitter = jitter

    # `index_points` is the only candidate that may be absent; the other two
    # were converted to tensors above.
    graph_parents = [
        parent
        for parent in (observation_noise_variance, jitter, index_points)
        if parent is not None
    ]

    with tf.name_scope('init'):
      super(GaussianProcess, self).__init__(
          dtype=dtype,
          reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          graph_parents=graph_parents,
          name=name)
def ulp_fn(w):
  """Returns `model.unnormalized_log_prob` at `w` and a `[x_size, 0]` zeros tensor."""
  # `model` and `x_size` are taken from the enclosing scope.
  return model.unnormalized_log_prob(w, tf.zeros([x_size, 0]))
def _interp_regular_1d_grid_impl(x,
                                 x_ref_min,
                                 x_ref_max,
                                 y_ref,
                                 axis=-1,
                                 batch_y_ref=False,
                                 fill_value='constant_extension',
                                 fill_value_below=None,
                                 fill_value_above=None,
                                 grid_regularizing_transform=None,
                                 name=None):
  """1-D interpolation that works with/without batching.

  Maps `x` from `[x_ref_min, x_ref_max]` onto a fractional index into `y_ref`
  along `axis`, linearly interpolates between the two neighbouring reference
  values, re-inserts NaN wherever the fractional index was NaN, and finally
  applies the requested fill policy outside the reference range.

  Args:
    x: Values to interpolate at. Converted to a `Tensor` of the common dtype
      of `x`, `x_ref_min`, `x_ref_max` and `y_ref` (hint: `tf.float32`).
    x_ref_min: Lower end of the reference grid. Must be statically scalar
      when `batch_y_ref` is `False`; when `True` a trailing singleton axis is
      appended so it can be combined with `x`.
    x_ref_max: Upper end of the reference grid; handled like `x_ref_min`.
    y_ref: Reference values; its size along `axis` is the number of grid
      points (`ny` below).
    axis: Scalar integer axis of `y_ref` that holds the grid. Negative values
      are normalized against `tf.rank(y_ref)`.
    batch_y_ref: Python `bool`. Selects the batched gather / expansion
      helpers instead of the plain `tf.gather` path.
    fill_value: `'constant_extension'`, `'extrapolate'`, or a value used to
      fill the output where `x` falls outside the reference range.
    fill_value_below: Optional override of `fill_value` for points below the
      range. Same allowed forms as `fill_value`.
    fill_value_above: Optional override of `fill_value` for points above the
      range. Same allowed forms as `fill_value`.
    grid_regularizing_transform: Optional callable applied to `x`,
      `x_ref_min` and `x_ref_max` before the fractional index is computed.
      `None` means the identity.
    name: Optional name scope; defaults to `'interp_regular_1d_grid_impl'`.

  Returns:
    y: The interpolated values.

  Raises:
    ValueError: If any fill value is a string other than
      `'constant_extension'` or `'extrapolate'`.
  """
  # Note: we do *not* make the no-batch version a special case of the batch
  # version, because that would be an inefficient use of batch_gather with
  # unnecessarily broadcast args.
  with tf.name_scope(name or 'interp_regular_1d_grid_impl'):

    # Arg checking.
    allowed_fv_st = ('constant_extension', 'extrapolate')
    for fv in (fill_value, fill_value_below, fill_value_above):
      if isinstance(fv, str) and fv not in allowed_fv_st:
        raise ValueError(
            'A fill value ({}) was not an allowed string ({})'.format(
                fv, allowed_fv_st))

    # Separate value fills for below/above incurs extra cost, so keep track of
    # whether this is needed.
    need_separate_fills = (
        fill_value_above is not None or fill_value_below is not None or
        fill_value == 'extrapolate'  # always requires separate below/above
    )
    # Once separate fills are needed, whichever side was left unspecified
    # inherits the common `fill_value`.
    if need_separate_fills and fill_value_above is None:
      fill_value_above = fill_value
    if need_separate_fills and fill_value_below is None:
      fill_value_below = fill_value

    dtype = dtype_util.common_dtype([x, x_ref_min, x_ref_max, y_ref],
                                    dtype_hint=tf.float32)
    x = tf.convert_to_tensor(x, name='x', dtype=dtype)

    x_ref_min = tf.convert_to_tensor(x_ref_min, name='x_ref_min', dtype=dtype)
    x_ref_max = tf.convert_to_tensor(x_ref_max, name='x_ref_max', dtype=dtype)
    if not batch_y_ref:
      _assert_ndims_statically(x_ref_min, expect_ndims=0)
      _assert_ndims_statically(x_ref_max, expect_ndims=0)

    y_ref = tf.convert_to_tensor(y_ref, name='y_ref', dtype=dtype)

    if batch_y_ref:
      # If we're batching,
      #   x.shape ~ [A1,...,AN, D],  x_ref_min/max.shape ~ [A1,...,AN]
      # So to add together we'll append a singleton.
      # If not batching, x_ref_min/max are scalar, so this isn't an issue,
      # moreover, if not batching, x can be scalar, and expanding x_ref_min/max
      # would cause a bad expansion of x when added to x (confused yet?).
      x_ref_min = x_ref_min[..., tf.newaxis]
      x_ref_max = x_ref_max[..., tf.newaxis]

    axis = tf.convert_to_tensor(axis, name='axis', dtype=tf.int32)
    axis = prefer_static.non_negative_axis(axis, tf.rank(y_ref))
    _assert_ndims_statically(axis, expect_ndims=0)

    # Number of grid points, cast to the float dtype so it can be used in the
    # index arithmetic below.
    ny = tf.cast(tf.shape(y_ref)[axis], dtype)

    # Map [x_ref_min, x_ref_max] to [0, ny - 1].
    # This is the (fractional) index of x.
    if grid_regularizing_transform is None:
      g = lambda x: x
    else:
      g = grid_regularizing_transform
    fractional_idx = ((g(x) - g(x_ref_min)) / (g(x_ref_max) - g(x_ref_min)))
    x_idx_unclipped = fractional_idx * (ny - 1)

    # Wherever x is NaN, x_idx_unclipped will be NaN as well.
    # Keep track of the nan indices here (so we can impute NaN later).
    # Also eliminate any NaN indices, since there is no NaN in int32 (the
    # indices are cast to int32 below).
    nan_idx = tf.math.is_nan(x_idx_unclipped)
    zero = tf.zeros((), dtype=dtype)
    x_idx_unclipped = tf.where(nan_idx, zero, x_idx_unclipped)
    x_idx = tf.clip_by_value(x_idx_unclipped, zero, ny - 1)

    # Get the index above and below x_idx.
    # Naively we could set idx_below = floor(x_idx), idx_above = ceil(x_idx),
    # however, this results in idx_below == idx_above whenever x is on a grid.
    # This in turn results in y_ref_below == y_ref_above, and then the gradient
    # at this point is zero.  So here we 'jitter' one of idx_below, idx_above,
    # so that they are at different values.  This jittering does not affect the
    # interpolated value, but does make the gradient nonzero (unless of course
    # the y_ref values are the same).
    idx_below = tf.floor(x_idx)
    idx_above = tf.minimum(idx_below + 1, ny - 1)
    idx_below = tf.maximum(idx_above - 1, 0)

    # These are the values of y_ref corresponding to above/below indices.
    idx_below_int32 = tf.cast(idx_below, dtype=tf.int32)
    idx_above_int32 = tf.cast(idx_above, dtype=tf.int32)
    if batch_y_ref:
      # If y_ref.shape ~ [A1,...,AN, C, B1,...,BN],
      # and x.shape, x_ref_min/max.shape ~ [A1,...,AN, D]
      # Then y_ref_below.shape ~ [A1,...,AN, D, B1,...,BN]
      y_ref_below = _batch_gather_with_broadcast(y_ref, idx_below_int32, axis)
      y_ref_above = _batch_gather_with_broadcast(y_ref, idx_above_int32, axis)
    else:
      # Here, y_ref_below.shape =
      #   y_ref.shape[:axis] + x.shape + y_ref.shape[axis + 1:]
      y_ref_below = tf.gather(y_ref, idx_below_int32, axis=axis)
      y_ref_above = tf.gather(y_ref, idx_above_int32, axis=axis)

    # Use t to get a convex combination of the below/above values.
    t = x_idx - idx_below

    # x, and tensors shaped like x, need to be added to, and selected with
    # (using tf.where) the output y.  This requires appending singletons.
    # Make functions appropriate for batch/no-batch.
    if batch_y_ref:
      # In the batch case, the output shape is going to be
      #   Broadcast(y_ref.shape[:axis], x.shape[:-1]) +
      #   x.shape[-1:] + y_ref.shape[axis+1:]
      expand_x_fn = _make_expand_x_fn_for_batch_interpolation(
          y_ref, axis)
    else:
      # In the non-batch case, the output shape is going to be
      #   y_ref.shape[:axis] + x.shape + y_ref.shape[axis+1:]
      expand_x_fn = _make_expand_x_fn_for_non_batch_interpolation(
          y_ref, axis)

    t = expand_x_fn(t)
    nan_idx = expand_x_fn(nan_idx, broadcast=True)
    x_idx_unclipped = expand_x_fn(x_idx_unclipped, broadcast=True)

    y = t * y_ref_above + (1 - t) * y_ref_below

    # Now begins a long excursion to fill values outside [x_min, x_max].

    # Re-insert NaN wherever x was NaN.
    y = tf.where(nan_idx, tf.constant(np.nan, y.dtype), y)

    if not need_separate_fills:
      if fill_value == 'constant_extension':
        pass  # Already handled by clipping x_idx_unclipped.
      else:
        y = tf.where(
            (x_idx_unclipped < 0) | (x_idx_unclipped > ny - 1),
            fill_value, y)
    else:
      # Fill values below x_ref_min <==> x_idx_unclipped < 0.
      if fill_value_below == 'constant_extension':
        pass  # Already handled by clipping x_idx_unclipped into x_idx.
      elif fill_value_below == 'extrapolate':
        if batch_y_ref:
          # For every batch member, gather the first two elements of y across
          # `axis`.
          y_0 = tf.gather(y_ref, [0], axis=axis)
          y_1 = tf.gather(y_ref, [1], axis=axis)
        else:
          # If not batching, we want to gather the first two elements, just like
          # above.  However, these results need to be replicated for every
          # member of x.  An easy way to do that is to gather using
          # indices = zeros/ones(x.shape).
          y_0 = tf.gather(
              y_ref, tf.zeros(tf.shape(x), dtype=tf.int32), axis=axis)
          y_1 = tf.gather(
              y_ref, tf.ones(tf.shape(x), dtype=tf.int32), axis=axis)
        # Grid spacing, and the (signed) number of spacings x lies from
        # x_ref_min; used to continue the line through the first two points.
        x_delta = (x_ref_max - x_ref_min) / (ny - 1)
        x_factor = expand_x_fn((x - x_ref_min) / x_delta, broadcast=True)
        y = tf.where(x_idx_unclipped < 0, y_0 + x_factor * (y_1 - y_0), y)
      else:
        y = tf.where(x_idx_unclipped < 0, fill_value_below, y)
      # Fill values above x_ref_max <==> x_idx_unclipped > ny - 1.
      if fill_value_above == 'constant_extension':
        pass  # Already handled by clipping x_idx_unclipped into x_idx.
      elif fill_value_above == 'extrapolate':
        ny_int32 = tf.shape(y_ref)[axis]
        if batch_y_ref:
          # For every batch member, gather the last two elements of y across
          # `axis`.
          y_n1 = tf.gather(y_ref, [tf.shape(y_ref)[axis] - 1], axis=axis)
          y_n2 = tf.gather(y_ref, [tf.shape(y_ref)[axis] - 2], axis=axis)
        else:
          # Same replication trick as for the lower side, with indices filled
          # with ny - 1 and ny - 2.
          y_n1 = tf.gather(
              y_ref, tf.fill(tf.shape(x), ny_int32 - 1), axis=axis)
          y_n2 = tf.gather(
              y_ref, tf.fill(tf.shape(x), ny_int32 - 2), axis=axis)
        # Grid spacing, and the (signed) number of spacings x lies from
        # x_ref_max; used to continue the line through the last two points.
        x_delta = (x_ref_max - x_ref_min) / (ny - 1)
        x_factor = expand_x_fn((x - x_ref_max) / x_delta, broadcast=True)
        y = tf.where(x_idx_unclipped > ny - 1,
                     y_n1 + x_factor * (y_n1 - y_n2), y)
      else:
        y = tf.where(x_idx_unclipped > ny - 1, fill_value_above, y)

    return y
def set_model(self, model):
    """Sets Keras model and creates summary ops.

    Stores the model, initializes the writer and, in graph mode only, builds
    the histogram ops and the merged summary op. When `embeddings_freq` is
    truthy and `embeddings_data` is not `None`, it additionally creates one
    variable per selected layer to hold that layer's output, the assign ops
    that fill those variables slice by slice, a saver for them, and the
    TensorBoard projector configuration.

    Args:
        model: The Keras model this callback is attached to.

    Raises:
        ImportError: If embeddings are requested but
            `tensorboard.plugins.projector` cannot be imported. The
            underlying import failure is attached as the cause.
    """
    self.model = model
    self._init_writer(model)
    # histogram summaries only enabled in graph mode
    if not tf.executing_eagerly():
        self._make_histogram_ops(model)
        self.merged = tf.compat.v1.summary.merge_all()

    # Embeddings are only visualized when both embeddings_freq and
    # embeddings_data are available; otherwise there is nothing left to do.
    if not (self.embeddings_freq and self.embeddings_data is not None):
        return

    # Avoid circular dependency.
    from keras.engine import (
        training_utils_v1,
    )  # pylint: disable=g-import-not-at-top

    self.embeddings_data = training_utils_v1.standardize_input_data(
        self.embeddings_data, model.input_names
    )

    # If embedding_layer_names are not provided, get all of the embedding
    # layers from the model (matched by class name).
    embeddings_layer_names = self.embeddings_layer_names
    if not embeddings_layer_names:
        embeddings_layer_names = [
            layer.name
            for layer in self.model.layers
            if type(layer).__name__ == "Embedding"
        ]

    self.assign_embeddings = []
    embeddings_vars = {}

    # Placeholders fed at run time: the start row and the number of rows of
    # the slice of each embedding variable that is being assigned.
    self.batch_id = batch_id = tf.compat.v1.placeholder(tf.int32)
    self.step = step = tf.compat.v1.placeholder(tf.int32)

    for layer in self.model.layers:
        if layer.name not in embeddings_layer_names:
            continue
        embedding_input = self.model.get_layer(layer.name).output
        # Flatten everything after the leading axis into one dimension.
        embedding_size = np.prod(embedding_input.shape[1:])
        embedding_input = tf.reshape(
            embedding_input, (step, int(embedding_size))
        )
        shape = (
            self.embeddings_data[0].shape[0],
            int(embedding_size),
        )
        embedding = tf.Variable(
            tf.zeros(shape), name=layer.name + "_embedding"
        )
        embeddings_vars[layer.name] = embedding
        batch = tf.compat.v1.assign(
            embedding[batch_id : batch_id + step], embedding_input
        )
        self.assign_embeddings.append(batch)

    self.saver = tf.compat.v1.train.Saver(list(embeddings_vars.values()))

    # Create embeddings_metadata dictionary
    if isinstance(self.embeddings_metadata, str):
        # A single path applies to every embedding variable.
        embeddings_metadata = {
            layer_name: self.embeddings_metadata
            for layer_name in embeddings_vars
        }
    else:
        # If embedding_metadata is already a dictionary
        embeddings_metadata = self.embeddings_metadata

    try:
        from tensorboard.plugins import projector
    except ImportError as err:
        # Chain the original failure so the root cause stays visible.
        raise ImportError(
            "Failed to import TensorBoard. Please make sure that "
            "TensorBoard integration is complete."
        ) from err

    # TODO(psv): Add integration tests to test embedding visualization
    # with TensorBoard callback. We are unable to write a unit test for this
    # because TensorBoard dependency assumes TensorFlow package is installed.
    config = projector.ProjectorConfig()
    for layer_name, tensor in embeddings_vars.items():
        embedding = config.embeddings.add()
        embedding.tensor_name = tensor.name

        if (
            embeddings_metadata is not None
            and layer_name in embeddings_metadata
        ):
            embedding.metadata_path = embeddings_metadata[layer_name]

    projector.visualize_embeddings(self.writer, config)
def testNonEmptyConstantTensor(self):
  # A constant of shape [2, 3, 4] must report rank 3.
  static_rank = distribution_util.prefer_static_rank(tf.zeros([2, 3, 4]))
  # The numpy-array type check is only made outside eager execution.
  if not tf.executing_eagerly():
    self.assertIsInstance(static_rank, np.ndarray)
  self.assertEqual(3, static_rank)