def _loop_body(step,
               previous_step_results,
               accumulated_traced_results,
               num_steps_traced):
  """Take one step in dynamics and accumulate marginal likelihood."""
  step_has_observation = (
      # The second of these conditions subsumes the first, but both are
      # useful because the first can often be evaluated statically.
      ps.equal(num_transitions_per_observation, 1) |
      ps.equal(step % num_transitions_per_observation, 0))
  observation_idx = step // num_transitions_per_observation
  current_observation = tf.nest.map_structure(
      lambda x, step=step: tf.gather(x, observation_idx), observations)

  new_step_results = _filter_one_step(
      step=step,
      previous_step_results=previous_step_results,
      observation=current_observation,
      transition_fn=transition_fn,
      observation_fn=observation_fn,
      proposal_fn=proposal_fn,
      resample_criterion_fn=resample_criterion_fn,
      resample_fn=resample_fn,
      has_observation=step_has_observation,
      seed=seed)

  return _update_loop_variables(
      step=step,
      current_step_results=new_step_results,
      accumulated_traced_results=accumulated_traced_results,
      trace_fn=trace_fn,
      step_indices_to_trace=step_indices_to_trace,
      num_steps_traced=num_steps_traced)

def _scan(level, elems):
  """Perform scan on `elems`."""
  elem_length = prefer_static.shape(elems[0])[0]

  # Apply `fn` to reduce adjacent pairs to a single entry.
  a = [elem[0:-1:2] for elem in elems]
  b = [elem[1::2] for elem in elems]
  reduced_elems = lowered_fn(a, b)

  def handle_base_case_elem_length_two():
    return [tf.concat([elem[0:1], reduced_elem], axis=0)
            for (reduced_elem, elem) in zip(reduced_elems, elems)]

  def handle_base_case_elem_length_three():
    reduced_reduced_elems = lowered_fn(
        reduced_elems, [elem[2:3] for elem in elems])
    return [
        tf.concat([elem[0:1], reduced_elem, reduced_reduced_elem], axis=0)
        for (reduced_reduced_elem, reduced_elem, elem)
        in zip(reduced_reduced_elems, reduced_elems, elems)]

  # Base case of recursion: assumes `elem_length` is 2 or 3.
  at_base_case = prefer_static.logical_or(
      prefer_static.equal(elem_length, 2),
      prefer_static.equal(elem_length, 3))
  base_value = lambda: prefer_static.cond(  # pylint: disable=g-long-lambda
      prefer_static.equal(elem_length, 2),
      handle_base_case_elem_length_two,
      handle_base_case_elem_length_three)

  if level <= 0:
    return base_value()

  def recursive_case():
    """Evaluate the next step of the recursion."""
    odd_elems = _scan(level - 1, reduced_elems)

    def even_length_case():
      return lowered_fn([odd_elem[:-1] for odd_elem in odd_elems],
                        [elem[2::2] for elem in elems])

    def odd_length_case():
      return lowered_fn([odd_elem for odd_elem in odd_elems],
                        [elem[2::2] for elem in elems])

    results = prefer_static.cond(
        prefer_static.equal(elem_length % 2, 0),
        even_length_case,
        odd_length_case)

    # The first element of a scan is the same as the first element
    # of the original `elems`.
    even_elems = [tf.concat([elem[0:1], result], axis=0)
                  for (elem, result) in zip(elems, results)]
    return list(map(_interleave, even_elems, odd_elems))

  return prefer_static.cond(at_base_case, base_value, recursive_case)

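# Illustrative walk-through, not part of the original source: a single `elems`
# tensor of length 6 along the leading axis, [x0, x1, x2, x3, x4, x5], scanned
# with an associative `fn` written as `+`:
#
#   a = [x0, x2, x4], b = [x1, x3, x5]
#   reduced_elems = [x0+x1, x2+x3, x4+x5]
#   odd_elems = _scan(level - 1, reduced_elems)
#             = [x0+x1, x0+x1+x2+x3, x0+...+x5]   # scan values at positions 1, 3, 5
#   results   = lowered_fn(odd_elems[:-1], [x2, x4])   # even_length_case
#             = [x0+x1+x2, x0+...+x4]             # scan values at positions 2, 4
#   even_elems = [x0, x0+x1+x2, x0+...+x4]        # prepend the original first element
#   _interleave(even_elems, odd_elems)
#             = [x0, x0+x1, x0+x1+x2, ..., x0+...+x5]  # the full inclusive scan
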
def _initialize_loop_variables(initial_step_results,
                               num_timesteps,
                               trace_fn,
                               step_indices_to_trace):
  """Initialize arrays and other quantities passed through the filter loop."""

  # Create arrays to store traced values (particles, likelihoods, etc).
  num_steps_to_trace = (num_timesteps
                        if step_indices_to_trace is None
                        else ps.size0(step_indices_to_trace))
  traced_results = trace_fn(initial_step_results)
  trace_arrays = tf.nest.map_structure(
      lambda x: tf.TensorArray(dtype=x.dtype, size=num_steps_to_trace),
      traced_results)
  # If we are supposed to trace at step 0, write the traced values.
  num_steps_traced, trace_arrays = ps.cond(
      (True if step_indices_to_trace is None
       else ps.equal(step_indices_to_trace[0], 0)),
      lambda: (1,  # pylint: disable=g-long-lambda
               tf.nest.map_structure(lambda ta, x: ta.write(0, x),
                                     trace_arrays,
                                     traced_results)),
      lambda: (0, trace_arrays))

  return ParticleFilterLoopVariables(
      step=1,
      previous_step_results=initial_step_results,
      accumulated_traced_results=trace_arrays,
      num_steps_traced=num_steps_traced)

def _interleave(a, b, axis):
  """Interleaves two `Tensor`s along the given axis."""
  # [a b c ...] [d e f ...] -> [a d b e c f ...]
  num_elems_a = ps.shape(a)[axis]
  num_elems_b = ps.shape(b)[axis]

  # Note that interleaving implies rank(a)==rank(b).
  axis = ps.where(axis >= 0, axis, ps.rank(a) + axis)
  axis = (int(axis)  # Avoid ndarray values.
          if tf.get_static_value(axis) is not None
          else axis)

  def _interleave_with_b(a):
    return tf.reshape(
        # Work around lack of support for Tensor axes in `tf.stack` by using
        # `concat` and `expand_dims` instead.
        tf.concat([tf.expand_dims(a, axis=axis + 1),
                   tf.expand_dims(b, axis=axis + 1)],
                  axis=axis + 1),
        ps.concat([
            ps.shape(a)[:axis],
            [2 * num_elems_b],
            ps.shape(a)[axis + 1:]
        ], axis=0))
  return ps.cond(
      ps.equal(num_elems_a, num_elems_b + 1),
      lambda: tf.concat([  # pylint: disable=g-long-lambda
          _interleave_with_b(_slice_along_axis(a, None, -1, axis=axis)),
          _slice_along_axis(a, -1, None, axis=axis)], axis=axis),
      lambda: _interleave_with_b(a))

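# Illustrative sketch, not part of the original source: expected behavior on
# small inputs (assuming `ps` is TFP's internal `prefer_static` module).
#
#   a = tf.constant([1, 3, 5])   # three elements
#   b = tf.constant([2, 4])      # two elements, one fewer than `a`
#   _interleave(a, b, axis=0)    # ==> [1, 2, 3, 4, 5]
#
# When `a` has exactly one more element than `b` along `axis`, the trailing
# element of `a` is appended after the interleaved pairs; otherwise the two
# inputs must have equal length along `axis`.
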
def _update_loop_variables(step,
                           current_step_results,
                           accumulated_traced_results,
                           trace_fn,
                           step_indices_to_trace,
                           num_steps_traced):
  """Update the loop state to reflect a step of filtering."""

  # Write particles, indices, and likelihoods to their respective arrays.
  trace_this_step = True
  if step_indices_to_trace is not None:
    trace_this_step = ps.equal(
        step_indices_to_trace[ps.minimum(
            num_steps_traced,
            ps.cast(ps.size0(step_indices_to_trace) - 1, dtype=np.int32))],
        step)
  num_steps_traced, accumulated_traced_results = ps.cond(
      trace_this_step,
      lambda: (num_steps_traced + 1,  # pylint: disable=g-long-lambda
               tf.nest.map_structure(
                   lambda x, y: x.write(num_steps_traced, y),
                   accumulated_traced_results,
                   trace_fn(current_step_results))),
      lambda: (num_steps_traced, accumulated_traced_results))

  return ParticleFilterLoopVariables(
      step=step + 1,
      previous_step_results=current_step_results,
      accumulated_traced_results=accumulated_traced_results,
      num_steps_traced=num_steps_traced)

def _is_increasing(self, **kwargs):
  # desc(desc)=>asc, asc(asc)=>asc, other cases=>desc.
  is_increasing = True
  for b in self._bijectors:
    is_increasing = ps.equal(
        is_increasing,
        b._internal_is_increasing(**kwargs.get(b.name, {})))  # pylint: disable=protected-access
  return is_increasing

def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    # Avoid computing intermediates needed to construct the assertions.
    return []
  assertions = []
  if is_init != tensor_util.is_ref(self._batch_shape_unexpanded):
    implicit_dim_mask = ps.equal(self._batch_shape_unexpanded, -1)
    assertions.append(
        assert_util.assert_rank(
            self._batch_shape_unexpanded, 1,
            message='New shape must be a vector.'))
    assertions.append(
        assert_util.assert_less_equal(
            tf.math.count_nonzero(implicit_dim_mask, dtype=tf.int32),
            1,
            message='At most one dimension can be unknown.'))
    assertions.append(
        assert_util.assert_non_negative(
            self._batch_shape_unexpanded + 1,
            message='Shape elements must be >=-1.'))
    # Check that the old and new shapes are the same size.
    expanded_new_shape, original_size = self._calculate_new_shape()
    new_size = ps.reduce_prod(expanded_new_shape)
    assertions.append(
        assert_util.assert_equal(
            new_size,
            tf.cast(original_size, new_size.dtype),
            message='Shape sizes do not match.'))
  return assertions

def test_step_indices_to_trace(self):
  num_particles = 1024
  (particles_1_3,
   log_weights_1_3,
   parent_indices_1_3,
   incremental_log_marginal_likelihood_1_3) = self.evaluate(
       tfp.experimental.mcmc.particle_filter(
           observations=tf.convert_to_tensor([1., 3., 5., 7., 9.]),
           initial_state_prior=tfd.Normal(0., 1.),
           transition_fn=lambda _, state: tfd.Normal(state, 10.),
           observation_fn=lambda _, state: tfd.Normal(state, 0.1),
           num_particles=num_particles,
           trace_criterion_fn=lambda s, r: ps.logical_or(  # pylint: disable=g-long-lambda
               ps.equal(r.steps, 2), ps.equal(r.steps, 4)),
           static_trace_allocation_size=2,
           seed=test_util.test_seed()))
  self.assertLen(particles_1_3, 2)
  self.assertLen(log_weights_1_3, 2)
  self.assertLen(parent_indices_1_3, 2)
  self.assertLen(incremental_log_marginal_likelihood_1_3, 2)
  means = np.sum(np.exp(log_weights_1_3) * particles_1_3, axis=1)
  self.assertAllClose(means, [3., 7.], atol=1.)

  (final_particles,
   final_log_weights,
   final_cumulative_lp) = self.evaluate(
       tfp.experimental.mcmc.particle_filter(
           observations=tf.convert_to_tensor([1., 3., 5., 7., 9.]),
           initial_state_prior=tfd.Normal(0., 1.),
           transition_fn=lambda _, state: tfd.Normal(state, 10.),
           observation_fn=lambda _, state: tfd.Normal(state, 0.1),
           num_particles=num_particles,
           trace_fn=lambda s, r: (s.particles,  # pylint: disable=g-long-lambda
                                  s.log_weights,
                                  r.accumulated_log_marginal_likelihood),
           trace_criterion_fn=None,
           seed=test_util.test_seed()))
  self.assertLen(final_particles, num_particles)
  self.assertLen(final_log_weights, num_particles)
  self.assertEqual(final_cumulative_lp.shape, ())
  means = np.sum(np.exp(final_log_weights) * final_particles)
  self.assertAllClose(means, 9., atol=1.5)

def _loop_body(step,
               previous_step_results,
               accumulated_step_results,
               state_history):
  """Take one step in dynamics and accumulate marginal likelihood."""
  step_has_observation = (
      # The second of these conditions subsumes the first, but both are
      # useful because the first can often be evaluated statically.
      prefer_static.equal(num_transitions_per_observation, 1) |
      prefer_static.equal(step % num_transitions_per_observation, 0))
  observation_idx = step // num_transitions_per_observation
  current_observation = tf.nest.map_structure(
      lambda x, step=step: tf.gather(x, observation_idx), observations)

  history_to_pass_into_fns = {}
  if num_steps_observation_history_to_pass:
    history_to_pass_into_fns['observation_history'] = _gather_history(
        observations,
        observation_idx,
        num_steps_observation_history_to_pass)
  if num_steps_state_history_to_pass:
    history_to_pass_into_fns['state_history'] = state_history

  new_step_results = _filter_one_step(
      step=step,
      previous_particles=previous_step_results.particles,
      log_weights=previous_step_results.log_weights,
      observation=current_observation,
      transition_fn=functools.partial(
          transition_fn, **history_to_pass_into_fns),
      observation_fn=functools.partial(
          observation_fn, **history_to_pass_into_fns),
      proposal_fn=(None if proposal_fn is None else functools.partial(
          proposal_fn, **history_to_pass_into_fns)),
      resample_criterion_fn=resample_criterion_fn,
      has_observation=step_has_observation,
      seed=seed)

  return _update_loop_variables(
      step, new_step_results, accumulated_step_results, state_history)

def _compute_observation_log_weights(step,
                                     particles,
                                     observations,
                                     observation_fn,
                                     num_transitions_per_observation=1):
  """Computes particle importance weights from an observation step.

  Args:
    step: int `Tensor` current step.
    particles: Nested structure of `Tensor`s, each of shape
      `concat([[num_particles, b1, ..., bN], event_shape])`, where
      `b1, ..., bN` are optional batch dimensions and `event_shape` may
      differ across `Tensor`s.
    observations: Nested structure of `Tensor`s, each of shape
      `concat([[num_observations, b1, ..., bN], event_shape])`
      where `b1, ..., bN` are optional batch dimensions and `event_shape` may
      differ across `Tensor`s.
    observation_fn: callable with signature
      `observation_dist = observation_fn(step, particles)`, producing
      a batch of distributions over the `observation` at the given `step`,
      one for each particle.
    num_transitions_per_observation: optional int `Tensor` number of times
      to apply the transition model between successive observation steps.
      Default value: `1`.

  Returns:
    log_weights: `Tensor` of shape `concat([num_particles, b1, ..., bN])`.
  """
  with tf.name_scope('compute_observation_log_weights'):
    step_has_observation = (
        # The second of these conditions subsumes the first, but both are
        # useful because the first can often be evaluated statically.
        ps.equal(num_transitions_per_observation, 1) |
        ps.equal(step % num_transitions_per_observation, 0))
    observation_idx = step // num_transitions_per_observation
    observation = tf.nest.map_structure(
        lambda x, step=step: tf.gather(x, observation_idx), observations)

    log_weights = observation_fn(step, particles).log_prob(observation)
    return ps.where(step_has_observation,
                    log_weights,
                    tf.zeros_like(log_weights))

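# Illustrative sketch, not part of the original source: exercising the weight
# computation above with a simple Normal observation model (`tfd` is assumed
# to be `tfp.distributions`).
#
#   particles = tf.constant([0., 1., 2.])    # three particles
#   observations = tf.constant([0.5, 1.5])   # two observation steps
#   lw = _compute_observation_log_weights(
#       step=0,
#       particles=particles,
#       observations=observations,
#       observation_fn=lambda step, p: tfd.Normal(loc=p, scale=1.))
#   # lw[k] == Normal(particles[k], 1.).log_prob(0.5). On steps with no
#   # observation (possible when num_transitions_per_observation > 1), the
#   # returned weights are all zero.
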
def _interleave(a, b):
  """Interleaves two `Tensor`s along their first axis."""
  # [a b c ...] [d e f ...] -> [a d b e c f ...]
  num_elems_a = prefer_static.shape(a)[0]
  num_elems_b = prefer_static.shape(b)[0]

  def _interleave_with_b(a):
    return tf.reshape(
        tf.stack([a, b], axis=1),
        prefer_static.concat([[2 * num_elems_b],
                              prefer_static.shape(a)[1:]], axis=0))
  return prefer_static.cond(
      prefer_static.equal(num_elems_a, num_elems_b + 1),
      lambda: tf.concat([_interleave_with_b(a[:-1]), a[-1:]], axis=0),
      lambda: _interleave_with_b(a))

def _canonicalize_steps_to_trace(step_indices_to_trace, num_timesteps):
  """Canonicalizes `3` -> `[3]`, `[-2, -1]` -> `[N - 2, N - 1]`, etc."""
  step_indices_to_trace = tf.convert_to_tensor(
      step_indices_to_trace, dtype_hint=tf.int32)  # Warning: breaks gradients.
  traced_steps_have_rank_zero = ps.equal(
      ps.rank_from_shape(ps.shape(step_indices_to_trace)), 0)
  # Canonicalize negative step indices as positive.
  step_indices_to_trace = ps.where(step_indices_to_trace < 0,
                                   num_timesteps + step_indices_to_trace,
                                   step_indices_to_trace)
  # Canonicalize scalars as length-one vectors.
  return (ps.reshape(step_indices_to_trace,
                     [ps.size(step_indices_to_trace)]),
          traced_steps_have_rank_zero)

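# Illustrative sketch, not part of the original source: expected
# canonicalization, matching the docstring above.
#
#   _canonicalize_steps_to_trace(3, num_timesteps=5)
#   # ==> ([3], True)    # a scalar becomes a length-one vector; the flag
#   #                    # records that the input had rank zero.
#   _canonicalize_steps_to_trace([-2, -1], num_timesteps=5)
#   # ==> ([3, 4], False)
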
def _validate_elem_length(max_num_levels, elems_flat):
  """Checks that elems all have the same length, and returns that length."""
  assertions = []

  elem_length = prefer_static.shape(elems_flat[0])[0]

  # The default size limit will overflow a 32-bit int, so make sure we're
  # using 64-bit.
  size_limit = 2**(prefer_static.cast(max_num_levels, np.int64) + 1)
  enough_levels = prefer_static.less(
      prefer_static.cast(elem_length, np.int64), size_limit)
  enough_levels_ = tf.get_static_value(enough_levels)
  if enough_levels_ is None:
    assertions.append(
        tf.debugging.assert_equal(
            enough_levels, True,
            message='Input `Tensor`s must have first axis dimension less than'
                    ' `2**(max_num_levels + 1)`'
                    ' (saw: {} which is not less than 2**{} == {})'.format(
                        elem_length, max_num_levels, size_limit)))
  elif not enough_levels_:
    raise ValueError(
        'Input `Tensor`s must have first axis dimension less than'
        ' `2**(max_num_levels + 1)`'
        ' (saw: {} which is not less than 2**{} == {})'.format(
            elem_length, max_num_levels, size_limit))

  is_consistent = prefer_static.reduce_all([
      prefer_static.equal(prefer_static.shape(elem)[0], elem_length)
      for elem in elems_flat[1:]])

  is_consistent_ = tf.get_static_value(is_consistent)
  if is_consistent_ is None:
    assertions.append(
        tf.debugging.assert_equal(
            is_consistent, True,
            message='Input `Tensor`s must have the same first dimension.'
                    ' (saw: {})'.format([elem.shape for elem in elems_flat])))
  elif not is_consistent_:
    raise ValueError(
        'Input `Tensor`s must have the same first dimension.'
        ' (saw: {})'.format([elem.shape for elem in elems_flat]))
  return elem_length, assertions

def _matmul(self, x, adjoint=False, adjoint_arg=False):
  x1, x2 = self._x1_x2()
  if (self._num_matmul_parts is None or
      prefer_static.equal(self._num_matmul_parts, 1)):
    return tf.matmul(
        self._kernel().matrix(x1, x2), x,
        adjoint_a=adjoint, adjoint_b=adjoint_arg)
  if adjoint or adjoint_arg:
    raise NotImplementedError(
        '`adjoint`, `adjoint_arg` NYI when `num_matmul_parts` specified.')
  return _chunked_matmul(
      kernel_fn=self.kernel_fn,
      kernel_args=self.kernel_args,
      x1=x1,
      x2=x2,
      x=x,
      num_matmul_parts=self._num_matmul_parts,
      operator_shape=self.shape_tensor())

def _calculate_new_shape(self):
  # Try to get the old shape statically if available.
  original_shape = self._distribution.batch_shape
  if not tensorshape_util.is_fully_defined(original_shape):
    original_shape = self._distribution.batch_shape_tensor()
  # This is not a check for falseness, it's a check for exactly that shape.
  if original_shape == ():  # pylint: disable=g-explicit-bool-comparison
    # Force the size to be an integer, not a float, when the shape contains no
    # dtype information.
    original_size = 1
  else:
    original_size = ps.reduce_prod(original_shape)
  original_size = ps.cast(original_size, tf.int32)
  # Compute the new shape, filling in the `-1` dimension if present.
  new_shape = self._batch_shape_unexpanded
  implicit_dim_mask = ps.equal(new_shape, -1)
  size_implicit_dim = (
      original_size // ps.maximum(1, -ps.reduce_prod(new_shape)))
  expanded_new_shape = ps.where(  # Assumes exactly one `-1`.
      implicit_dim_mask, size_implicit_dim, new_shape)
  # Return the original size on the side because one caller would otherwise
  # have to recompute it.
  return expanded_new_shape, original_size

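# Illustrative worked example, not part of the original source: how the
# implicit `-1` dimension is resolved. If the base distribution has batch
# shape [2, 3, 4] (so `original_size == 24`) and `_batch_shape_unexpanded` is
# [6, -1], then `-ps.reduce_prod([6, -1]) == 6`, so
# `size_implicit_dim == 24 // 6 == 4` and `expanded_new_shape == [6, 4]`.
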
def one_step(self, state, kernel_results, seed=None):
  """Takes one Sequential Monte Carlo inference step.

  Args:
    state: instance of `tfp.experimental.mcmc.WeightedParticles` representing
      the current particles with (log) weights. The `log_weights` must be
      a float `Tensor` of shape `[num_particles, b1, ..., bN]`. The
      `particles` may be any structure of `Tensor`s, each of which must have
      shape `concat([log_weights.shape, event_shape])` for some
      `event_shape`, which may vary across components.
    kernel_results: instance of
      `tfp.experimental.mcmc.SequentialMonteCarloResults` representing results
      from a previous step.
    seed: Optional seed for reproducible sampling.

  Returns:
    state: instance of `tfp.experimental.mcmc.WeightedParticles` representing
      new particles with (log) weights.
    kernel_results: instance of
      `tfp.experimental.mcmc.SequentialMonteCarloResults`.
  """
  with tf.name_scope(self.name):
    with tf.name_scope('one_step'):
      seed = samplers.sanitize_seed(seed)
      proposal_seed, resample_seed = samplers.split_seed(seed)

      state = WeightedParticles(*state)  # Canonicalize.
      num_particles = ps.size0(state.log_weights)

      # Propose new particles and update weights for this step, unless it's
      # the initial step, in which case, use the user-provided initial
      # particles and weights.
      proposed_state = self.propose_and_update_log_weights_fn(
          # Propose state[t] from state[t - 1].
          ps.maximum(0, kernel_results.steps - 1),
          state,
          seed=proposal_seed)
      is_initial_step = ps.equal(kernel_results.steps, 0)
      # TODO(davmre): this `where` assumes the state size didn't change.
      state = tf.nest.map_structure(
          lambda a, b: tf.where(is_initial_step, a, b),
          state, proposed_state)

      normalized_log_weights = tf.nn.log_softmax(state.log_weights, axis=0)
      # Every entry of `log_weights` differs from `normalized_log_weights`
      # by the same normalizing constant. We extract that constant by
      # examining an arbitrary entry.
      incremental_log_marginal_likelihood = (
          state.log_weights[0] - normalized_log_weights[0])

      do_resample = self.resample_criterion_fn(state)

      # Some batch elements may require resampling and others not, so
      # we first do the resampling for all elements, then select whether to
      # use the resampled values for each batch element according to
      # `do_resample`. If there were no batching, we might prefer to use
      # `tf.cond` to avoid the resampling computation on steps where it's not
      # needed---but we're ultimately interested in adaptive resampling
      # for statistical (not computational) purposes, so this isn't a
      # dealbreaker.
      resampled_particles, resample_indices = weighted_resampling.resample(
          state.particles, state.log_weights, self.resample_fn,
          seed=resample_seed)

      uniform_weights = tf.fill(
          ps.shape(state.log_weights),
          value=-tf.math.log(
              tf.cast(num_particles, state.log_weights.dtype)))
      (resampled_particles,
       resample_indices,
       log_weights) = tf.nest.map_structure(
           lambda r, p: ps.where(do_resample, r, p),
           (resampled_particles, resample_indices, uniform_weights),
           (state.particles,
            _dummy_indices_like(resample_indices),
            normalized_log_weights))

      return (
          WeightedParticles(particles=resampled_particles,
                            log_weights=log_weights),
          SequentialMonteCarloResults(
              steps=kernel_results.steps + 1,
              parent_indices=resample_indices,
              incremental_log_marginal_likelihood=(
                  incremental_log_marginal_likelihood),
              accumulated_log_marginal_likelihood=(
                  kernel_results.accumulated_log_marginal_likelihood +
                  incremental_log_marginal_likelihood),
              seed=seed))

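# Illustrative note, not part of the original source: the constant extracted
# inside `one_step` as `incremental_log_marginal_likelihood` equals
# `logsumexp` of the unnormalized log weights at this step, since
# `log_softmax(x) == x - logsumexp(x)`. For example, with two particles whose
# unnormalized log weights are [0., log(3.)], `tf.nn.log_softmax` returns
# [log(1/4), log(3/4)], and the per-entry difference is
# 0. - log(1/4) == log(4.) == logsumexp([0., log(3.)]).
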
def _has_nonzero_rank(self, override_shape):
  return prefer_static.logical_not(
      prefer_static.equal(
          prefer_static.rank_from_shape(override_shape), self._zero))

def _forward_log_det_jacobian(self, x):
  # Let Y be a symmetric, positive definite matrix and write:
  #   Y = X X.T
  # where X is lower-triangular.
  #
  # Observe that,
  #   dY[i,j]/dX[a,b]
  #   = d/dX[a,b] { X[i,:] X[j,:] }
  #   = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] }
  #
  # To compute the Jacobian dX/dY we must represent X,Y as vectors. Since Y is
  # symmetric and X is lower-triangular, we need vectors of dimension:
  #   d = p (p + 1) / 2
  # where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e.,
  #   k = { i (i + 1) / 2 + j   i>=j
  #       { undef               i<j
  # and assume zero-based indexes. When k is undef, the element is dropped.
  # Example:
  #           j      k
  #        0 1 2 3  /
  #    0 [ 0 . . . ]
  # i  1 [ 1 2 . . ]
  #    2 [ 3 4 5 . ]
  #    3 [ 6 7 8 9 ]
  # Write vec[.] to indicate transforming a matrix to vector via k(i,j). (With
  # slight abuse: k(i,j)=undef means the element is dropped.)
  #
  # We now show d vec[Y] / d vec[X] is lower triangular. Assuming both are
  # defined, observe that k(i,j) < k(a,b) iff (1) i<a or (2) i=a and j<b.
  # In both cases dvec[Y]/dvec[X]@[k(i,j),k(a,b)] = 0 since:
  # (1) j<=i<a thus i,j!=a.
  # (2) i=a>j  thus i,j!=a.
  #
  # Since the Jacobian is lower-triangular, we need only compute the product
  # of diagonal elements:
  #   d vec[Y] / d vec[X] @[k(i,j), k(i,j)]
  #   = X[j,j] + I[i=j] X[i,j]
  #   = 2 X[j,j].
  # Since there is a 2 X[j,j] term for every lower-triangular element of X we
  # conclude:
  #   |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}.
  diag = tf.linalg.diag_part(x)

  # We now ensure diag is columnar. Eg, if `diag = [1, 2, 3]` then the output
  # is `[[1], [2], [3]]` and if `diag = [[1, 2, 3], [4, 5, 6]]` then the
  # output is unchanged.
  diag = self._make_columnar(diag)

  with tf.control_dependencies(self._assertions(x)):
    # Create a vector equal to: [p, p-1, ..., 2, 1].
    if tf.compat.dimension_value(x.shape[-1]) is None:
      p_int = tf.shape(x)[-1]
      p_float = tf.cast(p_int, dtype=x.dtype)
    else:
      p_int = tf.compat.dimension_value(x.shape[-1])
      p_float = dtype_util.as_numpy_dtype(x.dtype)(p_int)
    exponents = tf.linspace(p_float, 1., p_int)

    sum_weighted_log_diag = tf.squeeze(
        tf.matmul(tf.math.log(diag), exponents[..., tf.newaxis]), axis=-1)
    fldj = p_float * np.log(2.) + sum_weighted_log_diag

    # We finally need to undo adding an extra column in non-scalar cases
    # where there is a single matrix as input.
    if tensorshape_util.rank(x.shape) is not None:
      if tensorshape_util.rank(x.shape) == 2:
        fldj = tf.squeeze(fldj, axis=-1)
      return fldj

    shape = ps.shape(fldj)
    maybe_squeeze_shape = ps.concat([
        shape[:-1],
        distribution_util.pick_vector(
            ps.equal(ps.rank(x), 2),
            np.array([], dtype=np.int32), shape[-1:])], 0)
    return tf.reshape(fldj, maybe_squeeze_shape)

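# Illustrative numeric check, not part of the original source: for a 2x2
# lower-triangular input X = [[a, 0], [b, c]] with a, c > 0 (so p == 2), the
# formula derived above gives
#   |Jac(d vec[Y]/d vec[X])| = 2**2 * a**2 * c,
# i.e. fldj == 2 * log(2.) + 2 * log(a) + log(c), which matches
# differentiating vec[Y] = (a**2, a*b, b**2 + c**2) with respect to
# vec[X] = (a, b, c) directly.
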
def _is_scalar_from_shape_tensor(shape):
  """Returns `True` `Tensor` if `Tensor` shape implies a scalar."""
  return prefer_static.equal(prefer_static.rank_from_shape(shape), 0)

def __init__(self,
             distribution,
             bijector,
             batch_shape=None,
             event_shape=None,
             kwargs_split_fn=_default_kwargs_split_fn,
             validate_args=False,
             parameters=None,
             name=None):
  """Construct a Transformed Distribution.

  Args:
    distribution: The base distribution instance to transform. Typically an
      instance of `Distribution`.
    bijector: The object responsible for calculating the transformation.
      Typically an instance of `Bijector`.
    batch_shape: `integer` vector `Tensor` which overrides `distribution`
      `batch_shape`; valid only if `distribution.is_scalar_batch()`.
    event_shape: `integer` vector `Tensor` which overrides `distribution`
      `event_shape`; valid only if `distribution.is_scalar_event()`.
    kwargs_split_fn: Python `callable` which takes a kwargs `dict` and returns
      a tuple of kwargs `dict`s for each of the `distribution` and `bijector`
      parameters respectively.
      Default value: `_default_kwargs_split_fn` (i.e.,
        `lambda kwargs: (kwargs.get('distribution_kwargs', {}),
                         kwargs.get('bijector_kwargs', {}))`)
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    parameters: Locals dict captured by subclass constructor, to be used for
      copy/slice re-instantiation operations.
    name: Python `str` name prefixed to Ops created by this class. Default:
      `bijector.name + distribution.name`.
  """
  parameters = dict(locals()) if parameters is None else parameters
  name = name or (("" if bijector is None else bijector.name) +
                  (distribution.name or ""))
  with tf.name_scope(name) as name:
    self._kwargs_split_fn = (_default_kwargs_split_fn
                             if kwargs_split_fn is None
                             else kwargs_split_fn)
    # For convenience we define some handy constants.
    self._zero = tf.constant(0, dtype=tf.int32, name="zero")
    self._empty = tf.constant([], dtype=tf.int32, name="empty")

    # We will keep track of a static and dynamic version of
    # self._is_{batch,event}_override. This way we can do more prior to graph
    # execution, including possibly raising Python exceptions.

    self._override_batch_shape = self._maybe_validate_shape_override(
        batch_shape, distribution.is_scalar_batch(), validate_args,
        "batch_shape")
    self._is_batch_override = prefer_static.logical_not(
        prefer_static.equal(
            prefer_static.rank_from_shape(self._override_batch_shape),
            self._zero))
    self._is_maybe_batch_override = bool(
        tf.get_static_value(self._override_batch_shape) is None or
        tf.get_static_value(self._override_batch_shape).size != 0)

    self._override_event_shape = self._maybe_validate_shape_override(
        event_shape, distribution.is_scalar_event(), validate_args,
        "event_shape")
    self._is_event_override = prefer_static.logical_not(
        prefer_static.equal(
            prefer_static.rank_from_shape(self._override_event_shape),
            self._zero))
    self._is_maybe_event_override = bool(
        tf.get_static_value(self._override_event_shape) is None or
        tf.get_static_value(self._override_event_shape).size != 0)

    # To convert a scalar distribution into a multivariate distribution we
    # will draw dims from the sample dims, which are otherwise iid. This is
    # easy to do except in the case that the base distribution has batch dims
    # and we're overriding event shape. When that case happens the event dims
    # will incorrectly be to the left of the batch dims. In this case we'll
    # cyclically permute left the new dims.
    self._needs_rotation = prefer_static.reduce_all([
        self._is_event_override,
        prefer_static.logical_not(self._is_batch_override),
        prefer_static.logical_not(distribution.is_scalar_batch())])
    override_event_ndims = prefer_static.rank_from_shape(
        self._override_event_shape)
    self._rotate_ndims = _pick_scalar_condition(
        self._needs_rotation, override_event_ndims, 0)
    # We'll be reducing the head dims (if at all), i.e., this will be []
    # if we don't need to reduce.
    self._reduce_event_indices = tf.range(
        self._rotate_ndims - override_event_ndims, self._rotate_ndims)

    self._distribution = distribution
    self._bijector = bijector
    super(TransformedDistribution, self).__init__(
        dtype=self._distribution.dtype,
        reparameterization_type=self._distribution.reparameterization_type,
        validate_args=validate_args,
        allow_nan_stats=self._distribution.allow_nan_stats,
        parameters=parameters,
        # We let TransformedDistribution access _graph_parents since this
        # class is more like a baseclass than derived.
        graph_parents=(distribution._graph_parents +  # pylint: disable=protected-access
                       bijector.graph_parents),
        name=name)

def _get_search_direction(state):
  """Computes the search direction to follow at the current state.

  On the `k`-th iteration of the main L-BFGS algorithm, the state has
  collected the most recent `m` correction pairs in position_deltas and
  gradient_deltas, where `k = state.num_iterations` and
  `m = min(k, num_correction_pairs)`.

  Assuming these, the code below is an implementation of the L-BFGS two-loop
  recursion algorithm given by [Nocedal and Wright(2006)][1]:

  ```None
  q_direction = objective_gradient
  for i in reversed(range(m)):  # First loop.
    inv_rho[i] = gradient_deltas[i]^T * position_deltas[i]
    alpha[i] = position_deltas[i]^T * q_direction / inv_rho[i]
    q_direction = q_direction - alpha[i] * gradient_deltas[i]

  kth_inv_hessian_factor = (gradient_deltas[-1]^T * position_deltas[-1] /
                            gradient_deltas[-1]^T * gradient_deltas[-1])
  r_direction = kth_inv_hessian_factor * I * q_direction

  for i in range(m):  # Second loop.
    beta = gradient_deltas[i]^T * r_direction / inv_rho[i]
    r_direction = r_direction + position_deltas[i] * (alpha[i] - beta)

  return -r_direction  # Approximates - H_k * objective_gradient.
  ```

  Args:
    state: A `LBfgsOptimizerResults` tuple with the current state of the
      search procedure.

  Returns:
    A real `Tensor` of the same shape as the `state.position`. The direction
    along which to perform line search.
  """
  # The number of correction pairs that have been collected so far.
  num_elements = ps.minimum(
      state.num_iterations,  # TODO(b/162733947): Change loop state -> closure.
      ps.shape(state.position_deltas)[0])

  def _two_loop_algorithm():
    """L-BFGS two-loop algorithm."""
    # Correction pairs are always appended to the end, so only the latest
    # `num_elements` vectors have valid position/gradient deltas. Vectors
    # that haven't been computed yet are zero.
    position_deltas = state.position_deltas
    gradient_deltas = state.gradient_deltas

    # Pre-compute all `inv_rho[i]`s.
    inv_rhos = tf.reduce_sum(
        gradient_deltas * position_deltas, axis=-1)

    def first_loop(acc, args):
      _, q_direction = acc
      position_delta, gradient_delta, inv_rho = args
      alpha = tf.math.divide_no_nan(
          tf.reduce_sum(position_delta * q_direction, axis=-1), inv_rho)
      direction_delta = alpha[..., tf.newaxis] * gradient_delta
      return (alpha, q_direction - direction_delta)

    # Run first loop body computing and collecting `alpha[i]`s, while also
    # computing the updated `q_direction` at each step.
    zero = tf.zeros_like(inv_rhos[-num_elements])
    alphas, q_directions = tf.scan(
        first_loop, [position_deltas, gradient_deltas, inv_rhos],
        initializer=(zero, state.objective_gradient), reverse=True)

    # We use `H^0_k = gamma_k * I` as an estimate for the initial inverse
    # hessian for the k-th iteration; then `r_direction = H^0_k * q_direction`.
    gamma_k = inv_rhos[-1] / tf.reduce_sum(
        gradient_deltas[-1] * gradient_deltas[-1], axis=-1)
    r_direction = gamma_k[..., tf.newaxis] * q_directions[-num_elements]

    def second_loop(r_direction, args):
      alpha, position_delta, gradient_delta, inv_rho = args
      beta = tf.math.divide_no_nan(
          tf.reduce_sum(gradient_delta * r_direction, axis=-1), inv_rho)
      direction_delta = (alpha - beta)[..., tf.newaxis] * position_delta
      return r_direction + direction_delta

    # Finally, run second loop body computing the updated `r_direction` at
    # each step.
    r_directions = tf.scan(
        second_loop, [alphas, position_deltas, gradient_deltas, inv_rhos],
        initializer=r_direction)
    return -r_directions[-1]

  return ps.cond(ps.equal(num_elements, 0),
                 lambda: -state.objective_gradient,
                 _two_loop_algorithm)

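# Illustrative note, not part of the original source: with no correction pairs
# stored (`num_elements == 0`), the `ps.cond` above falls back to steepest
# descent, `-objective_gradient`. With a single stored pair (m == 1), the
# two-loop recursion reduces to the standard BFGS inverse-Hessian product
#   -H_1 * g,  H_1 = (I - rho s y^T) (gamma I) (I - rho y s^T) + rho s s^T,
# with s = position_deltas[-1], y = gradient_deltas[-1], rho = 1 / (y^T s),
# and gamma = (y^T s) / (y^T y), as in Nocedal and Wright (2006).
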
def _scan(level, elems):
  """Perform scan on `elems`."""
  elem_length = ps.shape(elems[0])[axis]

  # Apply `fn` to reduce adjacent pairs to a single entry.
  a = [slice_elem(elem, 0, -1, step=2) for elem in elems]
  b = [slice_elem(elem, 1, None, step=2) for elem in elems]
  reduced_elems = lowered_fn(a, b)

  def handle_base_case_elem_length_two():
    return [tf.concat([slice_elem(elem, 0, 1), reduced_elem], axis=axis)
            for (reduced_elem, elem) in zip(reduced_elems, elems)]

  def handle_base_case_elem_length_three():
    reduced_reduced_elems = lowered_fn(
        reduced_elems, [slice_elem(elem, 2, 3) for elem in elems])
    return [
        tf.concat([slice_elem(elem, 0, 1),  # pylint: disable=g-complex-comprehension
                   reduced_elem,
                   reduced_reduced_elem], axis=axis)
        for (reduced_reduced_elem, reduced_elem, elem)
        in zip(reduced_reduced_elems, reduced_elems, elems)]

  # Base case of recursion: assumes `elem_length` is 2 or 3.
  at_base_case = ps.logical_or(
      ps.equal(elem_length, 2),
      ps.equal(elem_length, 3))
  base_value = lambda: ps.cond(  # pylint: disable=g-long-lambda
      ps.equal(elem_length, 2),
      handle_base_case_elem_length_two,
      handle_base_case_elem_length_three)

  if level <= 0:
    return base_value()

  def recursive_case():
    """Evaluate the next step of the recursion."""
    odd_elems = _scan(level - 1, reduced_elems)

    def even_length_case():
      return lowered_fn(
          [slice_elem(odd_elem, 0, -1) for odd_elem in odd_elems],
          [slice_elem(elem, 2, None, 2) for elem in elems])

    def odd_length_case():
      return lowered_fn([odd_elem for odd_elem in odd_elems],
                        [slice_elem(elem, 2, None, 2) for elem in elems])

    results = ps.cond(
        ps.equal(elem_length % 2, 0),
        even_length_case,
        odd_length_case)

    # The first element of a scan is the same as the first element
    # of the original `elems`.
    even_elems = [tf.concat([slice_elem(elem, 0, 1), result], axis=axis)
                  for (elem, result) in zip(elems, results)]
    return list(map(lambda a, b: _interleave(a, b, axis=axis),
                    even_elems, odd_elems))

  return ps.cond(at_base_case, base_value, recursive_case)

def _is_odd_integer(x):
  return ps.equal(x, ps.round(x)) & ps.not_equal(2. * ps.floor(x / 2.), x)

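# Illustrative sketch, not part of the original source: expected outputs.
#
#   _is_odd_integer(tf.constant([1., 2., 2.5, -3.]))
#   # ==> [True, False, False, True]
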