def one_step(self, current_state, previous_kernel_results, seed=None):
  """Takes one step of the TransitionKernel.

  Steps every replica with the inner kernel, proposes a permutation of
  replica states via `self.swap_proposal_fn`, accepts or rejects each
  proposed swap independently per batch member, and assembles
  `ReplicaExchangeMCKernelResults`.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). Only used here to decide
      whether the returned state is a list or a single `Tensor`.
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s representing the
      next state(s) of the Markov chain(s).
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.
      This includes replica states.

  Raises:
    TypeError: if `make_kernel_fn` raises a `TypeError` mentioning an
      'argument' (re-raised with a migration hint about the removed `seed`
      argument), or if `swap_proposal_fn` raises a `TypeError` unrelated to
      `step_count`.
  """
  # The code below propagates one step states of shape
  #  [n_replica] + batch_shape + event_shape.
  #
  # The step is done in three parts:
  #  1) Call one_step to transition states via a tempered version of
  #     self.target_log_prob_fn (see _replica_target_log_prob).
  #  2) Permute values in states
  #  3) Update state-dependent values, such as log_probs.
  #
  # We chose to swap states, rather than temperatures, because...
  # (i)  If swapping temperatures, you *still* have to swap log_probs to
  #      determine acceptance, as well as states (for kernel results).
  #      So it's just as difficult to swap temperatures.
  # (ii) If swapping temperatures, you have to take care to swap any user-
  #      supplied temperature related things (like step size).
  #      A-priori, we don't know what else will need to be swapped!
  # (iii)In both cases, the kernel results need to be updated in a non-trivial
  #      manner....so we either special-case, or use bootstrap.

  with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
    # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
    inverse_temperatures = tf.convert_to_tensor(
        previous_kernel_results.inverse_temperatures,
        name='inverse_temperatures')

    # Build the tempered target that the inner kernel will sample from.
    target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
        target_log_prob_fn=self.target_log_prob_fn,
        inverse_temperatures=inverse_temperatures,
        untempered_log_prob_fn=self.untempered_log_prob_fn,
        tempered_log_prob_fn=self.tempered_log_prob_fn,
    )
    # TODO(b/159636942): Clean up the helpful error msg after 2020-11-10.
    try:
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel)
    except TypeError as e:
      # Only translate errors that look like a signature mismatch; anything
      # else is re-raised untouched.
      if 'argument' not in str(e):
        raise
      raise TypeError(
          '`ReplicaExchangeMC`s `make_kernel_fn` no longer receives a `seed` '
          'argument. `TransitionKernel` instances now receive seeds via '
          '`one_step`.')

    seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
    # Independent seeds for: the inner kernel step, the swap proposal, and
    # the uniform draws used in the swap accept/reject test.
    inner_seed, swap_seed, logu_seed = samplers.split_seed(seed, n=3)

    # Step the inner TransitionKernel.
    [
        pre_swap_replica_states,
        pre_swap_replica_results,
    ] = inner_kernel.one_step(
        previous_kernel_results.post_swap_replica_states,
        previous_kernel_results.post_swap_replica_results,
        seed=inner_seed)
    pre_swap_replica_target_log_prob = _get_field(
        # These are tempered log probs (have been divided by temperature).
        pre_swap_replica_results, 'target_log_prob')

    dtype = pre_swap_replica_target_log_prob.dtype
    replica_and_batch_shape = ps.shape(
        pre_swap_replica_target_log_prob)
    # Leading dim is the replica axis; the rest is the batch shape.
    batch_shape = replica_and_batch_shape[1:]
    replica_and_batch_rank = ps.rank(pre_swap_replica_target_log_prob)
    num_replica = ps.size0(inverse_temperatures)

    inverse_temperatures = bu.left_justified_broadcast_to(
        inverse_temperatures, replica_and_batch_shape)

    # Now that each replica has done one_step, it is time to consider swaps.

    # swap.shape = [n_replica], and is a "once only" permutation, meaning it
    # is achievable by a sequence of pairwise permutations, where each element
    # is moved at most once.
    # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
    # 1, keeping 2 fixed.  This exact same swap is considered for *every*
    # batch member.  Of course some batch members may accept and some reject.
    try:
      swaps = tf.cast(
          self.swap_proposal_fn(  # pylint: disable=not-callable
              num_replica,
              batch_shape=batch_shape,
              seed=swap_seed,
              step_count=previous_kernel_results.step_count),
          dtype=tf.int32)
    except TypeError as e:
      # Backward compatibility: older proposal fns do not take `step_count`.
      if 'step_count' not in str(e):
        raise
      warnings.warn(
          'The `swap_proposal_fn` given to ReplicaExchangeMC did not accept '
          'the `step_count` argument. Falling back to omitting the '
          'argument. This fallback will be removed after 24-Oct-2020.'
      )
      swaps = tf.cast(
          self.swap_proposal_fn(  # pylint: disable=not-callable
              num_replica,
              batch_shape=batch_shape,
              seed=swap_seed),
          dtype=tf.int32)

    # The identity permutation, expanded to line up with `swaps`.
    null_swaps = bu.left_justified_expand_dims_like(
        tf.range(num_replica, dtype=swaps.dtype), swaps)
    swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                          self.validate_args)

    # Un-temper the log probs for use in the swap acceptance ratio.
    if self.tempered_log_prob_fn is None:
      # Efficient way of re-evaluating target_log_prob_fn on the
      # pre_swap_replica_states.
      untempered_negative_energy_ignoring_ulp = (
          # Since untempered_log_prob_fn is None, we may assume
          # inverse_temperatures > 0 (else the target is improper).
          pre_swap_replica_target_log_prob / inverse_temperatures)
    else:
      # The untempered_log_prob_fn does not factor into the acceptance ratio.
      # Proof: Suppose the tempered target is
      #   p_k(x) = f(x)^{beta_k} g(x),
      # So f(x) is tempered, and g(x) is not.  Then, the acceptance ratio for
      # a 1 <--> 2 swap is...
      #   (p_1(x_2) p_2(x_1)) / (p_1(x_1) p_2(x_2))
      # which depends only on f(x), since terms involving g(x) cancel.
      untempered_negative_energy_ignoring_ulp = self.tempered_log_prob_fn(
          *pre_swap_replica_states)

    # Since `swaps` is its own inverse permutation we automatically know the
    # swap counterpart: range(num_replica). We use this idea to compute the
    # acceptance in a vectorized manner at the cost of wasting roughly half
    # our computation. Although we could use `unique` to solve this problem,
    # we expect the cost of `unique` to be higher than the dozens of wasted
    # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
    # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

    # Note: diffs would normally be "proposed - current" however energy is
    # flipped since `energy == -log_prob`.
    # Note: The untempered_log_prob_fn (if provided) is not included in
    # untempered_pre_swap_replica_target_log_prob, and hence does not factor
    # into energy_diff. Why? Because, it cancels out in the acceptance ratio.
    energy_diff = (untempered_negative_energy_ignoring_ulp -
                   mcmc_util.index_remapping_gather(
                       untempered_negative_energy_ignoring_ulp,
                       swaps, name='gather_swap_tlp'))
    swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
        inverse_temperatures, swaps, name='gather_swap_temps')
    inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

    # If i and j are swapping, log_accept_ratio[i] and log_accept_ratio[j]
    # are equal.
    log_accept_ratio = (energy_diff * bu.left_justified_expand_dims_to(
        inverse_temp_diff, replica_and_batch_rank))

    # Treat any non-finite ratio (NaN/inf) as a certain rejection.
    log_accept_ratio = tf.where(tf.math.is_finite(log_accept_ratio),
                                log_accept_ratio,
                                tf.constant(-np.inf, dtype=dtype))

    # Produce log[Uniform] draws that are identical at swapped indices.
    log_uniform = tf.math.log(
        samplers.uniform(shape=replica_and_batch_shape,
                         dtype=dtype,
                         seed=logu_seed))
    # min(swaps[i], i) maps both members of a proposed pair to the same
    # (smaller) index, so the pair shares one uniform draw and therefore
    # makes the same accept/reject decision.
    anchor_swaps = tf.minimum(swaps, null_swaps)
    log_uniform = mcmc_util.index_remapping_gather(
        log_uniform, anchor_swaps)

    is_swap_accepted_mask = tf.less(log_uniform,
                                    log_accept_ratio,
                                    name='is_swap_accepted_mask')

    def _swap_tensor(x):
      # Take the permuted value where the swap was accepted, else keep `x`.
      return mcmc_util.choose(
          is_swap_accepted_mask,
          mcmc_util.index_remapping_gather(x, swaps),
          x)

    post_swap_replica_states = [
        _swap_tensor(s) for s in pre_swap_replica_states
    ]

    expanded_null_swaps = bu.left_justified_broadcast_to(
        null_swaps, replica_and_batch_shape)
    is_swap_proposed = _compute_swap_notmatrix(
        # Broadcast both so they have shape [num_replica] + batch_shape.
        # This (i) makes them have same shape as is_swap_accepted, and
        # (ii) keeps shape consistent if someday swaps has a batch shape.
        expanded_null_swaps,
        bu.left_justified_broadcast_to(swaps, replica_and_batch_shape))

    # To get is_swap_accepted in ordered position, we use
    # _compute_swap_notmatrix on current and next replica positions.
    post_swap_replica_position = _swap_tensor(expanded_null_swaps)

    is_swap_accepted = _compute_swap_notmatrix(
        post_swap_replica_position,
        expanded_null_swaps)

    # Either expose every replica, or only replica 0 (the untempered chain
    # when inverse_temperatures[0] == 1 -- assumption, confirm with caller).
    if self._state_includes_replicas:
      post_swap_states = post_swap_replica_states
    else:
      post_swap_states = [s[0] for s in post_swap_replica_states]

    post_swap_replica_results = _set_swapped_fields_to_nan(
        _swap_log_prob_and_maybe_grads(pre_swap_replica_results,
                                       post_swap_replica_states,
                                       inner_kernel))

    if mcmc_util.is_list_like(current_state):
      # We *always* canonicalize the states in the kernel results.
      states = post_swap_states
    else:
      states = post_swap_states[0]

    post_swap_kernel_results = ReplicaExchangeMCKernelResults(
        post_swap_replica_states=post_swap_replica_states,
        pre_swap_replica_results=pre_swap_replica_results,
        post_swap_replica_results=post_swap_replica_results,
        is_swap_proposed=is_swap_proposed,
        is_swap_accepted=is_swap_accepted,
        is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
        is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
        # Store the original pkr.inverse_temperatures in case it's a
        # `tf.Variable`.
        inverse_temperatures=previous_kernel_results.inverse_temperatures,
        swaps=swaps,
        step_count=previous_kernel_results.step_count + 1,
        seed=seed,
        potential_energy=-untempered_negative_energy_ignoring_ulp,
    )

    return states, post_swap_kernel_results
def one_step(self, current_state, previous_kernel_results):
  """Runs one NUTS transition via iterative tree doubling.

  Args:
    current_state: `Tensor` or nested structure of `Tensor`s holding the
      current chain state. A non-nested `Tensor` is wrapped in a list
      internally and unwrapped again on return.
    previous_kernel_results: Kernel results from the previous step (or from
      `bootstrap_results`). Fields read here: `target_log_prob`,
      `grads_target_log_prob`, `step_size`, `momentum_state_memory`,
      `leapfrogs_taken`.

  Returns:
    result_state: The next state, with the same list/non-list structure as
      `current_state`.
    kernel_results: `NUTSKernelResults` for this step.
  """
  with tf.name_scope(self.name + '.one_step'):
    unwrap_state_list = not tf.nest.is_nested(current_state)
    if unwrap_state_list:
      current_state = [current_state]

    current_target_log_prob = previous_kernel_results.target_log_prob
    [
        init_momentum,
        init_energy,
        log_slice_sample
    ] = self._start_trajectory_batched(current_state, current_target_log_prob)

    def _copy(v):
      # Multiplies `v` by ones of shape [2, 1, ..., 1] (rank(v) trailing
      # ones), broadcasting to two stacked copies with shape [2] + shape(v).
      # Presumably one copy per trajectory direction -- TODO confirm.
      return v * prefer_static.ones(
          prefer_static.pad(
              [2], paddings=[[0, prefer_static.rank(v)]], constant_values=1),
          dtype=v.dtype)

    initial_state = TreeDoublingState(
        momentum=init_momentum,
        state=current_state,
        target=current_target_log_prob,
        target_grad_parts=previous_kernel_results.grads_target_log_prob)
    initial_step_state = tf.nest.map_structure(_copy, initial_state)

    # The initial candidate weight depends on the sampling scheme: a zero
    # (log-space) weight for multinomial sampling, else a count of one.
    if MULTINOMIAL_SAMPLE:
      init_weight = tf.zeros_like(init_energy)
    else:
      init_weight = tf.ones_like(init_energy, dtype=TREE_COUNT_DTYPE)

    candidate_state = TreeDoublingStateCandidate(
        state=current_state,
        target=current_target_log_prob,
        target_grad_parts=previous_kernel_results.grads_target_log_prob,
        energy=init_energy,
        weight=init_weight)
    initial_step_metastate = TreeDoublingMetaState(
        candidate_state=candidate_state,
        is_accepted=tf.zeros_like(init_energy, dtype=tf.bool),
        energy_diff_sum=tf.zeros_like(init_energy),
        leapfrog_count=tf.zeros_like(init_energy, dtype=TREE_COUNT_DTYPE),
        continue_tree=tf.ones_like(init_energy, dtype=tf.bool),
        not_divergence=tf.ones_like(init_energy, dtype=tf.bool))

    # Convert the write/read instruction into TensorArray so that it is
    # compatible with XLA.
    write_instruction = tf.TensorArray(
        TREE_COUNT_DTYPE,
        size=2**(self.max_tree_depth - 1),
        clear_after_read=False).unstack(self._write_instruction)
    read_instruction = tf.TensorArray(
        tf.int32,
        size=2**(self.max_tree_depth - 1),
        clear_after_read=False).unstack(self._read_instruction)

    current_step_meta_info = OneStepMetaInfo(
        log_slice_sample=log_slice_sample,
        init_energy=init_energy,
        write_instruction=write_instruction,
        read_instruction=read_instruction)

    # Double the tree until `max_tree_depth` is reached or every batch
    # member has stopped (`continue_tree` all False).
    _, _, new_step_metastate = tf.while_loop(
        cond=lambda iter_, state, metastate: (  # pylint: disable=g-long-lambda
            ((iter_ < self.max_tree_depth) &
             tf.reduce_any(metastate.continue_tree))),
        body=lambda iter_, state, metastate: self.loop_tree_doubling(  # pylint: disable=g-long-lambda
            previous_kernel_results.step_size,
            previous_kernel_results.momentum_state_memory,
            current_step_meta_info,
            iter_,
            state,
            metastate),
        loop_vars=(
            tf.zeros([], dtype=tf.int32, name='iter'),
            initial_step_state,
            initial_step_metastate),
        parallel_iterations=TF_WHILE_PARALLEL_ITERATIONS,
    )

    kernel_results = NUTSKernelResults(
        target_log_prob=new_step_metastate.candidate_state.target,
        grads_target_log_prob=(
            new_step_metastate.candidate_state.target_grad_parts),
        momentum_state_memory=previous_kernel_results.momentum_state_memory,
        step_size=previous_kernel_results.step_size,
        # Log of the accumulated energy-diff sum averaged over the number
        # of leapfrog steps taken in this transition.
        log_accept_ratio=tf.math.log(
            new_step_metastate.energy_diff_sum /
            tf.cast(new_step_metastate.leapfrog_count,
                    dtype=new_step_metastate.energy_diff_sum.dtype)),
        # TODO(junpenglao): return non-cumulated leapfrogs_taken once
        # benchmarking is done.
        leapfrogs_taken=(
            previous_kernel_results.leapfrogs_taken +
            new_step_metastate.leapfrog_count * self.unrolled_leapfrog_steps
        ),
        is_accepted=new_step_metastate.is_accepted,
        # `continue_tree` still True after the loop means the depth limit,
        # not a U-turn or divergence, ended the trajectory.
        reach_max_depth=new_step_metastate.continue_tree,
        has_divergence=~new_step_metastate.not_divergence,
        energy=new_step_metastate.candidate_state.energy)

    result_state = new_step_metastate.candidate_state.state
    if unwrap_state_list:
      result_state = result_state[0]

    return result_state, kernel_results
def _build_sub_tree(self,
                    directions,
                    integrator,
                    current_step_meta_info,
                    nsteps,
                    initial_state,
                    continue_tree,
                    not_divergence,
                    momentum_state_memory,
                    name=None):
  """Builds one sub-tree by repeatedly calling `self._loop_build_sub_tree`.

  The loop runs while `iter_ < nsteps` and at least one batch member still
  has `continue_tree` set.

  Args:
    directions: Trajectory direction(s), forwarded to
      `_loop_build_sub_tree`.
    integrator: Leapfrog integrator, forwarded to `_loop_build_sub_tree`.
    current_step_meta_info: `OneStepMetaInfo`; its `init_energy` fixes the
      batch shape and dtype used here.
    nsteps: Maximum number of loop iterations.
    initial_state: `TreeDoublingState` the sub-tree starts from.
    continue_tree: Boolean `Tensor`; per-batch flag to keep extending.
    not_divergence: Boolean `Tensor`; per-batch "no divergence yet" flag.
    momentum_state_memory: Forwarded through the loop unchanged in shape.
    name: Unused (the name scope is always 'build_sub_tree').

  Returns:
    Tuple `(candidate_tree_state, final_state, final_not_divergence,
    final_continue_tree, energy_diff_tree_sum, momentum_tree_cumsum,
    leapfrogs_taken)`.
  """
  with tf.name_scope('build_sub_tree'):
    batch_shape = prefer_static.shape(current_step_meta_info.init_energy)
    # We never want to select the initial state: give it -inf log weight
    # (multinomial) or zero count otherwise.
    if MULTINOMIAL_SAMPLE:
      init_weight = tf.fill(
          batch_shape,
          tf.constant(-np.inf,
                      dtype=current_step_meta_info.init_energy.dtype))
    else:
      init_weight = tf.zeros(batch_shape, dtype=TREE_COUNT_DTYPE)

    init_momentum_cumsum = [tf.zeros_like(x) for x in initial_state.momentum]
    initial_state_candidate = TreeDoublingStateCandidate(
        state=initial_state.state,
        target=initial_state.target,
        target_grad_parts=initial_state.target_grad_parts,
        energy=initial_state.target,
        weight=init_weight)
    energy_diff_sum = tf.zeros_like(current_step_meta_info.init_energy,
                                    name='energy_diff_sum')
    # Loop-var order must match the lambda signatures and `loop_vars` below.
    [
        _,
        energy_diff_tree_sum,
        momentum_tree_cumsum,
        leapfrogs_taken,
        final_state,
        candidate_tree_state,
        final_continue_tree,
        final_not_divergence,
        momentum_state_memory,
    ] = tf.while_loop(
        cond=lambda iter_, energy_diff_sum, init_momentum_cumsum,  # pylint: disable=g-long-lambda
                    leapfrogs_taken, state, state_c, continue_tree,
                    not_divergence, momentum_state_memory: (
                        (iter_ < nsteps) & tf.reduce_any(continue_tree)),
        body=lambda iter_, energy_diff_sum, init_momentum_cumsum,  # pylint: disable=g-long-lambda
                    leapfrogs_taken, state, state_c, continue_tree,
                    not_divergence, momentum_state_memory: (
                        self._loop_build_sub_tree(
                            directions, integrator, current_step_meta_info,
                            iter_, energy_diff_sum, init_momentum_cumsum,
                            leapfrogs_taken, state, state_c, continue_tree,
                            not_divergence, momentum_state_memory)),
        loop_vars=(
            tf.zeros([], dtype=tf.int32, name='iter'),
            energy_diff_sum,
            init_momentum_cumsum,
            tf.zeros(batch_shape, dtype=TREE_COUNT_DTYPE),
            initial_state,
            initial_state_candidate,
            continue_tree,
            not_divergence,
            momentum_state_memory,
        ),
        parallel_iterations=TF_WHILE_PARALLEL_ITERATIONS,
    )

  return (
      candidate_tree_state,
      final_state,
      final_not_divergence,
      final_continue_tree,
      energy_diff_tree_sum,
      momentum_tree_cumsum,
      leapfrogs_taken,
  )
def __init__(self,
             loc,
             scale,
             validate_args=False,
             allow_nan_stats=True,
             name='Moyal'):
  """Construct Moyal distributions with location `loc` and scale `scale`.

  The distribution is a `Uniform` pushed through the inverse of the
  `MoyalCDF` bijector, i.e. sampling by inverse-CDF transform.
  `loc` and `scale` must be broadcast-compatible (e.g. `loc + scale` is a
  valid operation).

  Args:
    loc: Floating point tensor; the location parameter(s) of the
      distribution(s).
    scale: Floating point tensor; the scale parameter(s) of the
      distribution(s). Must contain only positive values.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
      Default value: `False`.
    allow_nan_stats: Python `bool`, default `True`. When `True`,
      statistics (e.g., mean, mode, variance) use the value `NaN` to
      indicate the result is undefined. When `False`, an exception is raised
      if one or more of the statistic's batch members are undefined.
      Default value: `True`.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: `'Moyal'`.

  Raises:
    TypeError: if `loc` and `scale` are different dtypes.

  #### References

  [1] J.E. Moyal, "XXX. Theory of ionization fluctuations",
     The London, Edinburgh, and Dublin Philosophical Magazine
     and Journal of Science.
     https://www.tandfonline.com/doi/abs/10.1080/14786440308521076
  [2] G. Cordeiro, J. Nobre, R. Pescim, E. Ortega,
      "The beta Moyal: a useful skew distribution",
      https://www.arpapress.com/Volumes/Vol10Issue2/IJRRAS_10_2_02.pdf
  """
  # Must be the first statement: captures exactly the constructor arguments.
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([loc, scale], dtype_hint=tf.float32)
    loc = tensor_util.convert_nonref_to_tensor(loc, name='loc', dtype=dtype)
    scale = tensor_util.convert_nonref_to_tensor(
        scale, name='scale', dtype=dtype)
    dtype_util.assert_same_float_dtype([loc, scale])

    # Positivity of `scale` is checked by the MoyalCDF bijector itself.
    self._moyal_bijector = moyal_cdf_bijector.MoyalCDF(
        loc=loc, scale=scale, validate_args=validate_args)

    # A uniform draw of exactly 0 would map to an infinite sample under the
    # inverse CDF. Starting the interval at the smallest positive normal
    # number for `dtype` instead of 0 keeps samples finite at that end.
    smallest_positive_normal = np.finfo(
        dtype_util.as_numpy_dtype(dtype)).tiny
    # TODO(b/137665504): Use batch-adding meta-distribution to set the
    # batch shape instead of tf.ones.
    base_uniform = uniform.Uniform(
        low=smallest_positive_normal,
        high=tf.ones([], dtype=dtype),
        allow_nan_stats=allow_nan_stats)

    # MoyalCDF maps samples to probabilities (its forward is the CDF), so
    # the inverse is what turns uniform draws into Moyal samples.
    inverse_cdf = invert_bijector.Invert(self._moyal_bijector,
                                         validate_args=validate_args)

    super(Moyal, self).__init__(
        distribution=base_uniform,
        bijector=inverse_cdf,
        parameters=parameters,
        name=name)
def sample_lkj(
    num_samples,
    dimension,
    concentration,
    cholesky_space=False,
    seed=None,
    name=None):
  """Returns a Tensor of samples from an LKJ distribution.

  Args:
    num_samples: Python `int`. The number of samples to draw.
    dimension: Python `int`. The dimension of correlation matrices.
    concentration: `Tensor` representing the concentration of the LKJ
      distribution. Must have a floating-point dtype.
    cholesky_space: Python `bool`. Whether to take samples from LKJ or
      Chol(LKJ).
    seed: PRNG seed; split via `samplers.split_seed` into one seed per
      random draw made below.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., `'sample_lkj'`).

  Returns:
    samples: A Tensor of correlation matrices (or Cholesky factors of
      correlation matrices if `cholesky_space = True`) with shape
      `[n] + B + [D, D]`, where `B` is the shape of the `concentration`
      parameter, and `D` is the `dimension`.

  Raises:
    ValueError: If `dimension` is negative.
    TypeError: If `concentration` does not have a floating-point dtype.
  """
  if dimension < 0:
    raise ValueError(
        'Cannot sample negative-dimension correlation matrices.')
  # Notation below: B is the batch shape, i.e., tf.shape(concentration)

  # We need 1 seed for beta corr12, and 2 per loop iter.
  num_seeds = 1 + 2 * max(0, dimension - 2)
  seeds = list(samplers.split_seed(seed, n=num_seeds, salt='sample_lkj'))

  # Honor a caller-supplied `name`; fall back to 'sample_lkj'.
  with tf.name_scope(name or 'sample_lkj'):
    concentration = tf.convert_to_tensor(concentration)
    if not dtype_util.is_floating(concentration.dtype):
      raise TypeError(
          'The concentration argument should have floating type, not '
          '{}'.format(dtype_util.name(concentration.dtype)))

    concentration = _replicate(num_samples, concentration)
    concentration_shape = ps.shape(concentration)
    if dimension <= 1:
      # For any dimension <= 1, there is only one possible correlation matrix.
      shape = ps.concat([concentration_shape, [dimension, dimension]],
                        axis=0)
      return tf.ones(shape=shape, dtype=concentration.dtype)
    beta_conc = concentration + (dimension - 2.) / 2.
    beta_dist = beta.Beta(concentration1=beta_conc, concentration0=beta_conc)

    # Note that the sampler below deviates from [1], by doing the sampling in
    # cholesky space. This does not change the fundamental logic of the
    # sampler, but does speed up the sampling.

    # This is the correlation coefficient between the first two dimensions.
    # This is also `r` in reference [1].
    corr12 = 2. * beta_dist.sample(seed=seeds.pop()) - 1.

    # Below we construct the Cholesky of the initial 2x2 correlation matrix,
    # which is of the form:
    # [[1, 0], [r, sqrt(1 - r**2)]], where r is the correlation between the
    # first two dimensions.
    # This is the top-left corner of the cholesky of the final sample.
    first_row = tf.concat([
        tf.ones_like(corr12)[..., tf.newaxis],
        tf.zeros_like(corr12)[..., tf.newaxis]
    ], axis=-1)
    second_row = tf.concat([
        corr12[..., tf.newaxis],
        tf.sqrt(1 - corr12**2)[..., tf.newaxis]
    ], axis=-1)

    chol_result = tf.concat([
        first_row[..., tf.newaxis, :],
        second_row[..., tf.newaxis, :]
    ], axis=-2)

    for n in range(2, dimension):
      # Loop invariant: on entry, result has shape B + [n, n]
      beta_conc = beta_conc - 0.5
      # norm is y in reference [1].
      norm = beta.Beta(
          concentration1=n / 2.,
          concentration0=beta_conc
      ).sample(seed=seeds.pop())
      # distance shape: B + [1] for broadcast
      distance = tf.sqrt(norm)[..., tf.newaxis]
      # direction is u in reference [1].
      # direction shape: B + [n]
      direction = _uniform_unit_norm(
          n, concentration_shape, concentration.dtype, seed=seeds.pop())
      # raw_correlation is w in reference [1].
      raw_correlation = distance * direction  # shape: B + [n]

      # This is the next row in the cholesky of the result,
      # which differs from the construction in reference [1].
      # In the reference, the new row `z` = chol_result @ raw_correlation^T
      # = C @ raw_correlation^T (where as short hand we use C = chol_result).
      # We prove that the below equation is the right row to add to the
      # cholesky, by showing equality with reference [1].
      # Let S be the sample constructed so far, and let `z` be as in
      # reference [1]. Then at this iteration, the new sample S' will be
      # [[S z^T]
      #  [z 1]]
      # In our case we have the cholesky decomposition factor C, so
      # we want our new row x (same size as z) to satisfy:
      #  [[S z^T]  [[C 0]    [[C^T  x^T]         [[CC^T  Cx^T]
      #   [z 1]] =  [x k]]  * [0     k]]  =       [xC^T  xx^T + k**2]]
      # Since C @ raw_correlation^T = z = C @ x^T, and C is invertible,
      # we have that x = raw_correlation. Also 1 = xx^T + k**2, so k
      # = sqrt(1 - xx^T) = sqrt(1 - |raw_correlation|**2) = sqrt(1 -
      # distance**2).
      new_row = tf.concat(
          [raw_correlation, tf.sqrt(1. - norm[..., tf.newaxis])], axis=-1)

      # Finally add this new row, by growing the cholesky of the result.
      chol_result = tf.concat([
          chol_result,
          tf.zeros_like(chol_result[..., 0][..., tf.newaxis])
      ], axis=-1)

      chol_result = tf.concat(
          [chol_result, new_row[..., tf.newaxis, :]], axis=-2)

    # Internal invariant check; `str()` so the message itself cannot raise.
    assert not seeds, 'Did not use all seeds: ' + str(len(seeds))
    if cholesky_space:
      return chol_result

    result = tf.matmul(chol_result, chol_result, transpose_b=True)
    # The diagonal for a correlation matrix should always be ones. Due to
    # numerical instability the matmul might not achieve that, so manually set
    # these to ones.
    result = tf.linalg.set_diag(
        result, tf.ones(shape=ps.shape(result)[:-1], dtype=result.dtype))
    # This sampling algorithm can produce near-PSD matrices on which standard
    # algorithms such as `tf.cholesky` or `tf.linalg.self_adjoint_eigvals`
    # fail. Specifically, as documented in b/116828694, around 2% of trials
    # of 900,000 5x5 matrices (distributed according to 9 different
    # concentration parameter values) contained at least one matrix on which
    # the Cholesky decomposition failed.
    return result
def __init__(self,
             distribution,
             bijector,
             batch_shape=None,
             event_shape=None,
             kwargs_split_fn=_default_kwargs_split_fn,
             validate_args=False,
             parameters=None,
             name=None):
  """Construct a Transformed Distribution.

  Args:
    distribution: The base distribution instance to transform. Typically an
      instance of `Distribution`.
    bijector: The object responsible for calculating the transformation.
      Typically an instance of `Bijector`.
    batch_shape: `integer` vector `Tensor` which overrides `distribution`
      `batch_shape`; valid only if `distribution.is_scalar_batch()`.
    event_shape: `integer` vector `Tensor` which overrides `distribution`
      `event_shape`; valid only if `distribution.is_scalar_event()`.
    kwargs_split_fn: Python `callable` which takes a kwargs `dict` and returns
      a tuple of kwargs `dict`s for each of the `distribution` and `bijector`
      parameters respectively. `None` is treated as the default.
      Default value: `_default_kwargs_split_fn` (i.e.,
          `lambda kwargs: (kwargs.get('distribution_kwargs', {}),
                           kwargs.get('bijector_kwargs', {}))`)
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    parameters: Locals dict captured by subclass constructor, to be used for
      copy/slice re-instantiation operations. When `None`, this
      constructor's own arguments are captured instead.
    name: Python `str` name prefixed to Ops created by this class. Default:
      `bijector.name + distribution.name`.
  """
  # Must run before any other local is bound so it captures only arguments.
  parameters = dict(locals()) if parameters is None else parameters
  # NOTE(review): `bijector is None` is tolerated here, but
  # `bijector.graph_parents` below would then raise AttributeError --
  # confirm whether a None bijector is meant to be supported.
  name = name or (("" if bijector is None else bijector.name) +
                  (distribution.name or ""))
  with tf.name_scope(name) as name:
    self._kwargs_split_fn = (_default_kwargs_split_fn
                             if kwargs_split_fn is None
                             else kwargs_split_fn)
    # For convenience we define some handy constants.
    self._zero = tf.constant(0, dtype=tf.int32, name="zero")
    self._empty = tf.constant([], dtype=tf.int32, name="empty")

    # We will keep track of a static and dynamic version of
    # self._is_{batch,event}_override. This way we can do more prior to graph
    # execution, including possibly raising Python exceptions.

    self._override_batch_shape = self._maybe_validate_shape_override(
        batch_shape, distribution.is_scalar_batch(), validate_args,
        "batch_shape")
    # Possibly-dynamic flag: True iff the override shape has nonzero rank.
    self._is_batch_override = prefer_static.logical_not(
        prefer_static.equal(
            prefer_static.rank_from_shape(self._override_batch_shape),
            self._zero))
    # Static Python bool: True unless the override is statically known to be
    # empty (unknown values are conservatively treated as an override).
    self._is_maybe_batch_override = bool(
        tf.get_static_value(self._override_batch_shape) is None or
        tf.get_static_value(self._override_batch_shape).size != 0)

    self._override_event_shape = self._maybe_validate_shape_override(
        event_shape, distribution.is_scalar_event(), validate_args,
        "event_shape")
    self._is_event_override = prefer_static.logical_not(
        prefer_static.equal(
            prefer_static.rank_from_shape(self._override_event_shape),
            self._zero))
    self._is_maybe_event_override = bool(
        tf.get_static_value(self._override_event_shape) is None or
        tf.get_static_value(self._override_event_shape).size != 0)

    # To convert a scalar distribution into a multivariate distribution we
    # will draw dims from the sample dims, which are otherwise iid. This is
    # easy to do except in the case that the base distribution has batch dims
    # and we're overriding event shape. When that case happens the event dims
    # will incorrectly be to the left of the batch dims. In this case we'll
    # cyclically permute left the new dims.
    self._needs_rotation = prefer_static.reduce_all([
        self._is_event_override,
        prefer_static.logical_not(self._is_batch_override),
        prefer_static.logical_not(distribution.is_scalar_batch())])
    override_event_ndims = prefer_static.rank_from_shape(
        self._override_event_shape)
    # Number of dims to rotate: the override event rank when rotation is
    # needed, otherwise 0.
    self._rotate_ndims = _pick_scalar_condition(
        self._needs_rotation, override_event_ndims, 0)
    # We'll be reducing the head dims (if at all), i.e., this will be []
    # if we don't need to reduce.
    self._reduce_event_indices = tf.range(
        self._rotate_ndims - override_event_ndims, self._rotate_ndims)

    self._distribution = distribution
    self._bijector = bijector
    super(TransformedDistribution, self).__init__(
        dtype=self._distribution.dtype,
        reparameterization_type=self._distribution.reparameterization_type,
        validate_args=validate_args,
        allow_nan_stats=self._distribution.allow_nan_stats,
        parameters=parameters,
        # We let TransformedDistribution access _graph_parents since this class
        # is more like a baseclass than derived.
        graph_parents=(distribution._graph_parents +  # pylint: disable=protected-access
                       bijector.graph_parents),
        name=name)
def _z(self, x):
  """Return `x` centered at `self.loc` and divided by `self.scale`."""
  with tf.name_scope("standardize"):
    centered = x - self.loc
    return centered / self.scale
def maybe_update_along_axis(*,
                            tensor,
                            new_tensor,
                            axis,
                            ind,
                            do_update,
                            dtype=None,
                            name=None):
  """Replace `tensor` entries with `new_tensor` along a given axis.

  This updates the elements of `tensor` selected by
  `numpy.take(tensor, ind, axis)` with the corresponding elements of
  `new_tensor`.

  # Example
  ```python
  tensor = tf.ones([5, 4, 3, 2])
  new_tensor = tf.zeros([5, 4, 3, 2])
  updated_tensor = maybe_update_along_axis(tensor=tensor,
                                           new_tensor=new_tensor,
                                           axis=1,
                                           ind=2,
                                           do_update=True)
  # Returns a `Tensor` of ones where
  # `updated_tensor[:, 2, :, :].numpy() == 0`
  ```
  If `do_update` is `False`, no update happens, except when the static size
  of `tensor` along `axis` is 1 (or 0). In that case `new_tensor` is returned
  as is. This is useful when, for example, aggregating samples of an Ito
  process.

  Args:
    tensor: A `Tensor` of any shape and `dtype`. Its rank must be statically
      known.
    new_tensor: A `Tensor` of the same `dtype` as `tensor` and of shape
      broadcastable with `tensor`.
    axis: A Python integer. The axis of `tensor` along which the elements have
      to be updated. Negative values count from the last axis, as in NumPy.
    ind: An int32 scalar `Tensor` that denotes an index on the `axis` which
      defines the updated slice of `tensor` (see example above).
    do_update: A bool scalar `Tensor`. If `False`, the output is the same as
      `tensor`, unless the dimension of the `tensor` along the `axis` is equal
      to 1.
    dtype: The `dtype` of the input `Tensor`s.
      Default value: `None` which means that default dtypes inferred by
        TensorFlow are used.
    name: Python string. The name to give this op.
      Default value: `None` which maps to `maybe_update_along_axis`.

  Returns:
    A `Tensor` of the same shape and `dtype` as `tensor`.
  """
  name = name or 'maybe_update_along_axis'
  with tf.name_scope(name):
    tensor = tf.convert_to_tensor(tensor, dtype=dtype, name='tensor')
    dtype = tensor.dtype
    new_tensor = tf.convert_to_tensor(new_tensor, dtype=dtype,
                                      name='new_tensor')
    ind = tf.convert_to_tensor(ind, name='ind')
    do_update = tf.convert_to_tensor(do_update, name='do_update')
    # Static lookup first: raises (as before) for unknown rank or an
    # out-of-range axis, and already accepts negative `axis`.
    size_along_axis = tensor.shape.as_list()[axis]
    mask_size = tensor.shape.rank
    # Normalize a negative axis so the paddings below are non-negative;
    # `tf.pad` rejects negative paddings.
    if axis < 0:
      axis += mask_size

    def _write_update_to_result():
      size_along_axis_dynamic = tf.shape(tensor)[axis]
      one_hot = tf.one_hot(ind, depth=size_along_axis_dynamic)
      # Mask shape is all ones except `size_along_axis_dynamic` at `axis`,
      # so it broadcasts against `tensor`.
      mask_shape = tf.pad([size_along_axis_dynamic],
                          paddings=[[axis, mask_size - axis - 1]],
                          constant_values=1)
      mask = tf.reshape(one_hot > 0, mask_shape)
      return tf.where(mask, new_tensor, tensor)

    # Update only if size_along_axis > 1 or if the shape is dynamic
    if size_along_axis is None or size_along_axis > 1:
      return tf.cond(do_update, _write_update_to_result, lambda: tensor)
    return new_tensor
def generate_mc_normal_draws(num_normal_draws,
                             num_time_steps,
                             num_sample_paths,
                             random_type,
                             skip=0,
                             seed=None,
                             dtype=None,
                             name=None):
  """Generates normal random samples to be consumed by a Monte Carlo algorithm.

  Many Monte Carlo (MC) algorithms can be rewritten so that all necessary
  random (or quasi-random) variables are drawn in advance as a `Tensor` of
  shape `[num_time_steps, num_sample_paths, num_normal_draws]`, where
  `num_time_steps` is the number of time steps the Monte Carlo algorithm
  performs, `num_sample_paths` is the number of sample paths, and
  `num_normal_draws` is the number of independent normal draws per sample
  path. For example, in order to use quasi-random numbers in a Monte Carlo
  algorithm, the samples have to be drawn in advance.

  The returned `Tensor`, say `x`, is laid out so that for a quasi-random
  `random_type` the entries `x[i]` correspond to different dimensions of the
  quasi-random sequence, which makes it usable inside a Monte Carlo loop.

  Args:
    num_normal_draws: A scalar int32 `Tensor`. The number of independent normal
      draws at each time step for each sample path. Should be a graph
      compilation constant.
    num_time_steps: A scalar int32 `Tensor`. The number of time steps at which
      to draw the independent normal samples. Should be a graph compilation
      constant.
    num_sample_paths: A scalar int32 `Tensor`. The number of trajectories
      (e.g., Monte Carlo paths) for which to draw the independent normal
      samples. Should be a graph compilation constant.
    random_type: Enum value of `tff.math.random.RandomType`. The type of
      (quasi)-random number generator to use to generate the paths.
    skip: `int32` 0-d `Tensor`. The number of initial points of the Sobol or
      Halton sequence to skip. Used only when `random_type` is 'SOBOL',
      'HALTON', or 'HALTON_RANDOMIZED', otherwise ignored. `None` is treated
      as `0`.
      Default value: `0`.
    seed: Seed for the random number generator. The seed is only relevant if
      `random_type` is one of `[STATELESS, PSEUDO, HALTON_RANDOMIZED,
      PSEUDO_ANTITHETIC, STATELESS_ANTITHETIC]`. For `PSEUDO`,
      `PSEUDO_ANTITHETIC` and `HALTON_RANDOMIZED` the seed should be a Python
      integer. For `STATELESS` and `STATELESS_ANTITHETIC` it must be supplied
      as an integer `Tensor` of shape `[2]`.
      Default value: `None` which means no seed is set.
    dtype: The `dtype` of the output `Tensor`.
      Default value: `None` which maps to `float32`.
    name: Python string. The name to give this op.
      Default value: `None` which maps to `generate_mc_normal_draws`.

  Returns:
    A `Tensor` of shape `[num_time_steps, num_sample_paths, num_normal_draws]`.
  """
  scope_name = 'generate_mc_normal_draws' if name is None else name
  skip = 0 if skip is None else skip
  with tf.name_scope(scope_name):
    dtype = tf.float32 if dtype is None else dtype
    # For quasi-random draws every (time step, normal draw) pair needs its own
    # sequence dimension, so one sample spans
    # `num_time_steps * num_normal_draws` dimensions.
    zero_mean = tf.zeros([num_time_steps * num_normal_draws],
                         dtype=dtype,
                         name='total_dimension')
    draws = random.mv_normal_sample([num_sample_paths],
                                    mean=zero_mean,
                                    random_type=random_type,
                                    seed=seed,
                                    skip=skip)
    # Split the flat dimension into [num_time_steps, num_normal_draws].
    draws = tf.reshape(
        draws, [num_sample_paths, num_time_steps, num_normal_draws])
    # Put time first: [num_time_steps, num_sample_paths, num_normal_draws].
    return tf.transpose(draws, [1, 0, 2])
def __init__(self,
             df,
             scale_operator,
             input_output_cholesky=False,
             validate_args=False,
             allow_nan_stats=True,
             name=None):
  """Construct Wishart distributions.

  Args:
    df: `float` or `double` tensor, the degrees of freedom of the
      distribution(s). `df` must be greater than or equal to `k`.
    scale_operator: `float` or `double` instance of `LinearOperator`.
    input_output_cholesky: Python `bool`. If `True`, functions whose input or
      output have the semantics of samples assume inputs are in Cholesky form
      and return outputs in Cholesky form. In particular, if this flag is
      `True`, input to `log_prob` is presumed of Cholesky form and output from
      `sample`, `mean`, and `mode` are of Cholesky form. Setting this
      argument to `True` is purely a computational optimization and does not
      change the underlying distribution; for instance, `mean` returns the
      Cholesky of the mean, not the mean of Cholesky factors. The `variance`
      and `stddev` methods are unaffected by this flag.
      Default value: `False` (i.e., input/output does not have Cholesky
      semantics).
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    TypeError: if scale is not floating-type
    TypeError: if scale.dtype != df.dtype
    ValueError: if `scale_operator` is not square.
    ValueError: if df < k, where scale operator event shape is `(k, k)`
  """
  parameters = dict(locals())
  self._input_output_cholesky = input_output_cholesky
  with tf.name_scope(name) as name:
    with tf.name_scope("init"):
      if not dtype_util.is_floating(scale_operator.dtype):
        raise TypeError(
            "scale_operator.dtype=%s is not a floating-point type" %
            scale_operator.dtype)
      if not scale_operator.is_square:
        raise ValueError("scale_operator must be square.")

      self._scale_operator = scale_operator
      self._df = tf.convert_to_tensor(
          df, dtype=scale_operator.dtype, name="df")
      dtype_util.assert_same_float_dtype(
          [self._df, self._scale_operator])
      # Use the static event size when available; otherwise fall back to the
      # operator's dynamic domain dimension.
      if tf.compat.dimension_value(self._scale_operator.shape[-1]) is None:
        self._dimension = tf.cast(
            self._scale_operator.domain_dimension_tensor(),
            dtype=self._scale_operator.dtype,
            name="dimension")
      else:
        self._dimension = tf.convert_to_tensor(
            tf.compat.dimension_value(self._scale_operator.shape[-1]),
            dtype=self._scale_operator.dtype,
            name="dimension")

      df_val = tf.get_static_value(self._df)
      dim_val = tf.get_static_value(self._dimension)
      if df_val is not None and dim_val is not None:
        # Both values are known statically: check eagerly. `np.any` handles
        # scalar and batched `df` alike.
        df_val = np.asarray(df_val)
        if np.any(df_val < dim_val):
          raise ValueError(
              "Degrees of freedom (df = %s) cannot be less than "
              "dimension of scale matrix (scale.dimension = %s)"
              % (df_val, dim_val))
      elif validate_args:
        # Otherwise defer the check to a runtime assertion on `df`.
        assertions = assert_util.assert_less_equal(
            self._dimension,
            self._df,
            message=("Degrees of freedom (df = %s) cannot be "
                     "less than dimension of scale matrix "
                     "(scale.dimension = %s)" %
                     (self._df, self._dimension)))
        self._df = distribution_util.with_dependencies(
            [assertions], self._df)
    super(_WishartLinearOperator, self).__init__(
        dtype=self._scale_operator.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=([self._df, self._dimension] +
                       self._scale_operator.graph_parents),
        name=name)
def __init__(self,
             distributions,
             dtype_override=None,
             validate_args=False,
             allow_nan_stats=False,
             name='Blockwise'):
  """Construct the `Blockwise` distribution.

  Args:
    distributions: Python `list` of `tfp.distributions.Distribution`
      instances. All distribution instances must have the same `batch_shape`
      and all must have `event_ndims==1`, i.e., be vector-variate
      distributions.
    dtype_override: samples of `distributions` will be cast to this `dtype`.
      If unspecified, all `distributions` must have the same `dtype`.
      Default value: `None` (i.e., do not cast).
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `False`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if the `dtype` of `distributions` cannot be resolved to a
      single value (normally reported earlier by
      `_maybe_validate_distributions`).
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    self._assertions = _maybe_validate_distributions(
        distributions, dtype_override, validate_args)

    if dtype_override is not None:
      dtype = dtype_override
    else:
      # Distinct base dtypes of the components, ignoring dtype-less ones.
      dtypes = {
          dtype_util.base_dtype(d.dtype)
          for d in distributions
          if d.dtype is not None
      }
      if not dtypes:
        dtype = tf.float32
      elif len(dtypes) == 1:
        dtype = dtypes.pop()
      else:
        # Shouldn't be here: we already threw an exception in
        # `_maybe_validate_distributions`.
        raise ValueError('Internal Error: unable to resolve `dtype`.')

    # The blockwise distribution inherits a reparameterization type only when
    # every component agrees on it.
    reparameterization_types = {
        d.reparameterization_type for d in distributions
    }
    reparameterization_type = (
        reparameterization_types.pop()
        if len(reparameterization_types) == 1
        else reparameterization.NOT_REPARAMETERIZED)

    self._distributions = distributions
    super(Blockwise, self).__init__(
        dtype=dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=reparameterization_type,
        parameters=parameters,
        graph_parents=_model_flatten(
            d._graph_parents for d in distributions),  # pylint: disable=protected-access
        name=name)
def __init__(self,
             df,
             scale=None,
             scale_tril=None,
             input_output_cholesky=False,
             validate_args=False,
             allow_nan_stats=True,
             name="Wishart"):
  """Construct Wishart distributions.

  The distribution is parameterized either by a full scale matrix or by its
  lower-triangular Cholesky factor; exactly one of the two must be given.
  A full `scale` is Cholesky-factorized here, so the parent class always
  receives a lower-triangular scale operator.

  Args:
    df: `float` or `double` `Tensor`. Degrees of freedom, must be greater than
      or equal to dimension of the scale matrix.
    scale: `float` or `double` `Tensor`. The symmetric positive definite
      scale matrix of the distribution. Exactly one of `scale` and
      `scale_tril` must be passed.
    scale_tril: `float` or `double` `Tensor`. The Cholesky factorization
      of the symmetric positive definite scale matrix of the distribution.
      Exactly one of `scale` and `scale_tril` must be passed.
    input_output_cholesky: Python `bool`. If `True`, functions whose input or
      output have the semantics of samples assume inputs are in Cholesky form
      and return outputs in Cholesky form. In particular, if this flag is
      `True`, input to `log_prob` is presumed of Cholesky form and output from
      `sample`, `mean`, and `mode` are of Cholesky form. Setting this
      argument to `True` is purely a computational optimization and does not
      change the underlying distribution; for instance, `mean` returns the
      Cholesky of the mean, not the mean of Cholesky factors. The `variance`
      and `stddev` methods are unaffected by this flag.
      Default value: `False` (i.e., input/output does not have Cholesky
      semantics).
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if zero or both of `scale` and `scale_tril` are passed in.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    with tf.name_scope("init"):
      # Reject both "neither given" and "both given".
      if (scale is None) == (scale_tril is None):
        raise ValueError(
            "Must pass scale or scale_tril, but not both.")

      dtype = dtype_util.common_dtype([df, scale, scale_tril], tf.float32)
      df = tf.convert_to_tensor(df, name="df", dtype=dtype)

      if scale is None:
        # Cholesky factor supplied directly.
        tril = tf.convert_to_tensor(scale_tril, name="scale_tril",
                                    dtype=dtype)
        if validate_args:
          tril_checks = [
              assert_util.assert_positive(
                  tf.linalg.diag_part(tril),
                  message="scale_tril must be positive definite"),
              assert_util.assert_equal(
                  tf.shape(tril)[-1],
                  tf.shape(tril)[-2],
                  message="scale_tril must be square"),
          ]
          tril = distribution_util.with_dependencies(tril_checks, tril)
      else:
        # Full scale matrix supplied: factorize it.
        full_scale = tf.convert_to_tensor(scale, name="scale", dtype=dtype)
        if validate_args:
          full_scale = distribution_util.assert_symmetric(full_scale)
        tril = tf.linalg.cholesky(full_scale)

    tril_operator = tf.linalg.LinearOperatorLowerTriangular(
        tril=tril,
        is_non_singular=True,
        is_positive_definite=True,
        is_square=True)
    super(Wishart, self).__init__(
        df=df,
        scale_operator=tril_operator,
        input_output_cholesky=input_output_cholesky,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
  self._parameters = parameters
def __init__(self, shift=None, scale_identity_multiplier=None, scale_diag=None, scale_tril=None, scale_perturb_factor=None, scale_perturb_diag=None, adjoint=False, validate_args=False, name="affine", dtype=None): """Instantiates the `Affine` bijector. This `Bijector` is initialized with `shift` `Tensor` and `scale` arguments, giving the forward operation: ```none Y = g(X) = scale @ X + shift ``` where the `scale` term is logically equivalent to: ```python scale = ( scale_identity_multiplier * tf.diag(tf.ones(d)) + tf.diag(scale_diag) + scale_tril + scale_perturb_factor @ diag(scale_perturb_diag) @ tf.transpose([scale_perturb_factor]) ) ``` If none of `scale_identity_multiplier`, `scale_diag`, or `scale_tril` are specified then `scale += IdentityMatrix`. Otherwise specifying a `scale` argument has the semantics of `scale += Expand(arg)`, i.e., `scale_diag != None` means `scale += tf.diag(scale_diag)`. Args: shift: Floating-point `Tensor`. If this is set to `None`, no shift is applied. scale_identity_multiplier: floating point rank 0 `Tensor` representing a scaling done to the identity matrix. When `scale_identity_multiplier = scale_diag = scale_tril = None` then `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is added to `scale`. scale_diag: Floating-point `Tensor` representing the diagonal matrix. `scale_diag` has shape `[N1, N2, ... k]`, which represents a k x k diagonal matrix. When `None` no diagonal term is added to `scale`. scale_tril: Floating-point `Tensor` representing the lower triangular matrix. `scale_tril` has shape `[N1, N2, ... k, k]`, which represents a k x k lower triangular matrix. When `None` no `scale_tril` term is added to `scale`. The upper triangular elements above the diagonal are ignored. scale_perturb_factor: Floating-point `Tensor` representing factor matrix with last two dimensions of shape `(k, r)`. When `None`, no rank-r update is added to `scale`. 
scale_perturb_diag: Floating-point `Tensor` representing the diagonal matrix. `scale_perturb_diag` has shape `[N1, N2, ... r]`, which represents an `r x r` diagonal matrix. When `None` low rank updates will take the form `scale_perturb_factor * scale_perturb_factor.T`. adjoint: Python `bool` indicating whether to use the `scale` matrix as specified or its adjoint. Default value: `False`. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str` name given to ops managed by this object. dtype: `tf.DType` to prefer when converting args to `Tensor`s. Else, we fall back to a common dtype inferred from the args, finally falling back to float32. Raises: ValueError: if `perturb_diag` is specified but not `perturb_factor`. TypeError: if `shift` has different `dtype` from `scale` arguments. """ # Ambiguous definition of low rank update. if scale_perturb_diag is not None and scale_perturb_factor is None: raise ValueError("When scale_perturb_diag is specified, " "scale_perturb_factor must be specified.") # Special case, only handling a scaled identity matrix. We don't know its # dimensions, so this is special cased. # We don't check identity_multiplier, since below we set it to 1. if all # other scale args are None. self._is_only_identity_multiplier = (scale_tril is None and scale_diag is None and scale_perturb_factor is None) with tf.name_scope(name) as name: self._name = name self._validate_args = validate_args if dtype is None: dtype = dtype_util.common_dtype([ shift, scale_identity_multiplier, scale_diag, scale_tril, scale_perturb_diag, scale_perturb_factor ], tf.float32) if shift is not None: shift = tf.convert_to_tensor(shift, name="shift", dtype=dtype) self._shift = shift # When no args are specified, pretend the scale matrix is the identity # matrix. 
if (self._is_only_identity_multiplier and scale_identity_multiplier is None): scale_identity_multiplier = tf.convert_to_tensor(1., dtype=dtype) # self._create_scale_operator returns a LinearOperator in all cases # except if self._is_only_identity_multiplier; in which case it # returns a scalar Tensor. scale = self._create_scale_operator( identity_multiplier=scale_identity_multiplier, diag=scale_diag, tril=scale_tril, perturb_diag=scale_perturb_diag, perturb_factor=scale_perturb_factor, shift=shift, validate_args=validate_args, dtype=dtype) if scale is not None and not self._is_only_identity_multiplier: if (shift is not None and not dtype_util.base_equal(shift.dtype, scale.dtype)): raise TypeError( "shift.dtype({}) is incompatible with scale.dtype({})." .format(shift.dtype, scale.dtype)) self._scale = scale self._adjoint = adjoint super(Affine, self).__init__(forward_min_event_ndims=1, is_constant_jacobian=True, dtype=dtype, validate_args=validate_args, name=name)
def bootstrap_results(self, init_state):
  """Returns an object with the same type as returned by `one_step`.

  If the kernel was not configured with `state_includes_replicas`, each part
  of `init_state` is broadcast along a new leading axis of size
  `num_replica = size0(inverse_temperatures)`. Otherwise `init_state` is used
  as-is, after checking (when both sizes are statically known) that its
  leading dimension matches the number of inverse temperatures. The inner
  kernel is then built on the tempered target and bootstrapped. The swap
  bookkeeping is initialized as an always-accepted identity ("null") swap.

  Args:
    init_state: `Tensor` or Python `list` of `Tensor`s representing the
      initial state(s) of the Markov chain(s).

  Returns:
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.
      This includes replica states.

  Raises:
    ValueError: if the state includes replicas and the statically known
      number of replicas in `init_state` differs from the number of inverse
      temperatures.
    TypeError: if `make_kernel_fn` raises a `TypeError` mentioning an
      "argument" (re-raised with a message about the removed `seed`
      argument).
  """
  with tf.name_scope(
      mcmc_util.make_name(self.name, 'remc', 'bootstrap_results')):
    init_state, unused_is_multipart_state = mcmc_util.prepare_state_parts(
        init_state)

    inverse_temperatures = tf.convert_to_tensor(
        self.inverse_temperatures,
        name='inverse_temperatures')

    if self._state_includes_replicas:
      # Static consistency check only; skipped when either size is unknown.
      it_n_replica = inverse_temperatures.shape[0]
      state_n_replica = init_state[0].shape[0]
      if ((it_n_replica is not None) and (state_n_replica is not None) and
          (it_n_replica != state_n_replica)):
        raise ValueError(
            'Number of replicas implied by initial state ({}) must equal '
            'number of replicas implied by inverse_temperatures ({}), but '
            'did not'.format(state_n_replica, it_n_replica))

    # We will now replicate each of a possible batch of initial states, one
    # for each inverse_temperature. So if init_state=[x, y] of shapes
    # [Sx, Sy] then the new shape is [(T, Sx), (T, Sy)] where (a, b) means
    # concatenation and T=shape(inverse_temperature).
    num_replica = ps.size0(inverse_temperatures)
    replica_shape = ps.convert_to_shape_tensor([num_replica])

    if self._state_includes_replicas:
      replica_states = init_state
    else:
      replica_states = [
          tf.broadcast_to(  # pylint: disable=g-complex-comprehension
              x,
              ps.concat([replica_shape, ps.shape(x)], axis=0),
              name='replica_states')
          for x in init_state
      ]

    target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
        target_log_prob_fn=self.target_log_prob_fn,
        inverse_temperatures=inverse_temperatures,
        untempered_log_prob_fn=self.untempered_log_prob_fn,
        tempered_log_prob_fn=self.tempered_log_prob_fn,
    )

    # TODO(b/159636942): Clean up the helpful error msg after 2020-11-10.
    try:
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel)
    except TypeError as e:
      # Only rewrite errors that look like a call-signature mismatch.
      if 'argument' not in str(e):
        raise
      raise TypeError(
          '`ReplicaExchangeMC`s `make_kernel_fn` no longer receives a second '
          '(`seed`) argument. `TransitionKernel` instances now receive seeds '
          'via `one_step`.')

    replica_results = inner_kernel.bootstrap_results(replica_states)

    # Tempered log probs of each replica, read from the inner kernel results.
    pre_swap_replica_target_log_prob = _get_field(
        replica_results, 'target_log_prob')

    replica_and_batch_shape = ps.shape(
        pre_swap_replica_target_log_prob)
    batch_shape = replica_and_batch_shape[1:]

    # NOTE(review): this broadcast result is not used below (the kernel
    # results store `self.inverse_temperatures`); it may only serve as an
    # implicit shape-compatibility check -- confirm before removing.
    inverse_temperatures = bu.left_justified_broadcast_to(
        inverse_temperatures, replica_and_batch_shape)

    # Pretend we did a "null swap", which will always be accepted.
    swaps = bu.left_justified_broadcast_to(
        tf.range(num_replica), replica_and_batch_shape)
    # is_swap_accepted.shape = [n_replica, n_replica] + batch_shape.
    # `tf.eye` yields batch_shape + [n, n]; rotating by 2 moves the two
    # matrix axes to the front.
    is_swap_accepted = distribution_util.rotate_transpose(
        tf.eye(num_replica, batch_shape=batch_shape, dtype=tf.bool),
        shift=2)

    return ReplicaExchangeMCKernelResults(
        post_swap_replica_states=replica_states,
        pre_swap_replica_results=replica_results,
        post_swap_replica_results=_set_swapped_fields_to_nan(replica_results),
        is_swap_proposed=is_swap_accepted,
        is_swap_accepted=is_swap_accepted,
        is_swap_proposed_adjacent=_sub_diag(is_swap_accepted),
        is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
        inverse_temperatures=self.inverse_temperatures,
        swaps=swaps,
        step_count=tf.zeros(shape=(), dtype=tf.int32),
        seed=samplers.zeros_seed(),
        potential_energy=tf.zeros_like(pre_swap_replica_target_log_prob),
    )
def state_y(self, t: types.RealTensor, name: str = None) -> types.RealTensor:
  """Computes the state variable `y(t)` for the Gaussian HJM Model.

  For the Gaussian HJM model, the state parameter `y(t)` can be computed
  analytically. Writing `k_i` for the mean reversion of factor `i`:

    y_ij(t) = exp(-(k_i + k_j) * t) *
        int_0^t rho_ij * sigma_i(u) * sigma_j(u) * exp((k_i + k_j) * u) du

  The volatilities are treated as piecewise constant between the jump
  locations. Each piece is integrated in closed form, and the pieces are
  accumulated with a cumulative sum.

  Args:
    t: A rank 1 real `Tensor` of shape `[num_times]` specifying the time `t`.
    name: Python string. The name to give to the ops created by this function.
      Default value: `None` which maps to the default name `state_y`.

  Returns:
    A real `Tensor` of shape `[dim, dim, num_times]`, where `dim` is
    `self._dim` (presumably the number of factors), containing the computed
    y_ij(t).
  """
  name = name or 'state_y'
  with tf.name_scope(name):
    t = tf.convert_to_tensor(t, dtype=self._dtype)
    t_shape = tf.shape(t)
    # Give every factor its own copy of the times: shape [dim] + t_shape.
    t = tf.broadcast_to(t, tf.concat([[self._dim], t_shape], axis=0))
    # Index of the volatility interval each time falls into, per factor.
    time_index = tf.searchsorted(self._jump_locations, t)
    # create a matrix k2(i,j) = k(i) + k(j)
    mr2 = tf.expand_dims(self._mean_reversion, axis=-1)
    # Add a dimension corresponding to `num_times`
    mr2 = tf.expand_dims(mr2 + tf.transpose(mr2), axis=-1)

    def _integrate_volatility_squared(vol, l_limit, u_limit):
      # Closed-form integral over [l_limit, u_limit] of
      # rho_ij * sigma_i * sigma_j * exp((k_i + k_j) * u) for constant sigma.
      # create sigma2_ij = sigma_i * sigma_j
      vol = tf.expand_dims(vol, axis=-2)
      vol_squared = tf.expand_dims(self._rho, axis=-1) * (
          vol * tf.transpose(vol, perm=[1, 0, 2]))
      return vol_squared / mr2 * (tf.math.exp(mr2 * u_limit) - tf.math.exp(
          mr2 * l_limit))

    # With no volatility jumps there are no intervals to integrate between
    # knots, so use an empty last axis.
    is_constant_vol = tf.math.equal(tf.shape(self._jump_values_vol)[-1], 0)
    v_squared_between_vol_knots = tf.cond(
        is_constant_vol,
        lambda: tf.zeros(shape=(self._dim, self._dim, 0), dtype=self._dtype),
        lambda: _integrate_volatility_squared(  # pylint: disable=g-long-lambda
            self._jump_values_vol, self._padded_knots, self._jump_locations))
    # Integral accumulated up to each knot; a zero is prepended for u = 0.
    v_squared_at_vol_knots = tf.concat([
        tf.zeros((self._dim, self._dim, 1), dtype=self._dtype),
        utils.cumsum_using_matvec(v_squared_between_vol_knots)
    ], axis=-1)

    # Left end-points of the volatility intervals.
    vn = tf.concat([self._zero_padding, self._jump_locations], axis=1)

    # Partial integral from the last knot before `t` up to `t` ...
    v_squared_t = _integrate_volatility_squared(
        self._volatility(t), tf.gather(vn, time_index, batch_dims=1), t)
    # ... plus the accumulated integral up to that knot.
    v_squared_t += tf.gather(v_squared_at_vol_knots, time_index,
                             batch_dims=-1)

    return tf.math.exp(-mr2 * t) * v_squared_t
def __init__(self,
             df,
             kernel,
             index_points=None,
             mean_fn=None,
             observation_noise_variance=0.,
             marginal_fn=None,
             cholesky_fn=None,
             jitter=1e-6,
             validate_args=False,
             allow_nan_stats=False,
             name='StudentTProcess'):
  """Instantiate a StudentTProcess Distribution.

  Args:
    df: Positive Floating-point `Tensor` representing the degrees of freedom.
      Must be greater than 2.
    kernel: `PositiveSemidefiniteKernel`-like instance representing the
      TP's covariance function.
    index_points: `float` `Tensor` representing finite (batch of) vector(s) of
      points in the index set over which the TP is defined. Shape has the form
      `[b1, ..., bB, e, f1, ..., fF]` where `F` is the number of feature
      dimensions and must equal `kernel.feature_ndims` and `e` is the number
      (size) of index points in each batch. Ultimately this distribution
      corresponds to a `e`-dimensional multivariate Student's T. The batch
      shape must be broadcastable with `kernel.batch_shape` and any batch dims
      yielded by `mean_fn`.
    mean_fn: Python `callable` that acts on `index_points` to produce a (batch
      of) vector(s) of mean values at `index_points`. Takes a `Tensor` of shape
      `[b1, ..., bB, f1, ..., fF]` and returns a `Tensor` whose shape is
      broadcastable with `[b1, ..., bB]`.
      Default value: `None` implies constant zero function.
    observation_noise_variance: `float` `Tensor` representing (batch of)
      scalar variance(s) of the noise in the Normal likelihood
      distribution of the model. If batched, the batch shape must be
      broadcastable with the shapes of all other batched parameters
      (`kernel.batch_shape`, `index_points`, etc.).
      Default value: `0.`
    marginal_fn: A Python callable that takes a location, covariance matrix,
      optional `validate_args`, `allow_nan_stats` and `name` arguments, and
      returns a multivariate normal subclass of `tfd.Distribution`.
      Default value: `None`, in which case a Cholesky-factorizing function is
      created using `make_cholesky_factored_marginal_fn` and the
      `jitter` argument.
    cholesky_fn: Callable which takes a single (batch) matrix argument and
      returns a Cholesky-like lower triangular factor.
      Default value: `None`, in which case `make_cholesky_with_jitter_fn` is
      used with the `jitter` parameter. At most one of `cholesky_fn` and
      `marginal_fn` should be set.
    jitter: `float` scalar `Tensor` added to the diagonal of the covariance
      matrix to ensure positive definiteness of the covariance matrix.
      This argument is ignored if `cholesky_fn` is set.
      Default value: `1e-6`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
      Default value: `False`.
    allow_nan_stats: Python `bool`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
      Default value: `False`.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: "StudentTProcess".

  Raises:
    ValueError: if `mean_fn` is not `None` and is not callable.
    ValueError: if both `marginal_fn` and `cholesky_fn` are set.
  """
  parameters = dict(locals())
  # Validate pure-Python arguments before creating any tensors.
  if mean_fn is not None and not callable(mean_fn):
    raise ValueError('`mean_fn` must be a Python callable')
  if marginal_fn is not None and cholesky_fn is not None:
    raise ValueError(
        'At most one of `marginal_fn` and `cholesky_fn` should be set.')

  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype(
        [df, kernel, index_points, observation_noise_variance, jitter],
        tf.float32)
    df = tensor_util.convert_nonref_to_tensor(df, dtype=dtype, name='df')
    observation_noise_variance = tensor_util.convert_nonref_to_tensor(
        observation_noise_variance,
        dtype=dtype,
        name='observation_noise_variance')
    index_points = tensor_util.convert_nonref_to_tensor(
        index_points, dtype=dtype, name='index_points')
    jitter = tensor_util.convert_nonref_to_tensor(
        jitter, dtype=dtype, name='jitter')

    self._kernel = kernel
    self._index_points = index_points
    # Default to a constant zero function, using the common dtype of the
    # parameters computed above.
    if mean_fn is None:
      mean_fn = lambda x: tf.zeros([1], dtype=dtype)

    self._df = df
    self._observation_noise_variance = observation_noise_variance
    self._mean_fn = mean_fn
    self._jitter = jitter
    self._cholesky_fn = cholesky_fn

    if marginal_fn is None:
      # Build the marginal from a Cholesky function, defaulting to a
      # jittered Cholesky when none was supplied.
      if self._cholesky_fn is None:
        self._cholesky_fn = cholesky_util.make_cholesky_with_jitter_fn(
            jitter)
      self._marginal_fn = make_cholesky_factored_marginal_fn(
          self._cholesky_fn)
    else:
      self._marginal_fn = marginal_fn

    with tf.name_scope('init'):
      super(StudentTProcess, self).__init__(
          dtype=dtype,
          reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          name=name)
def __init__(self,
             short_position: types.BoolTensor,
             currency: Union[types.CurrencyProtoType,
                             List[types.CurrencyProtoType]],
             expiry_date: types.DateTensor,
             equity: List[str],
             contract_amount: types.FloatTensor,
             strike: types.FloatTensor,
             is_call_option: List[bool],
             business_day_convention: types.BusinessDayConventionProtoType,
             calendar: types.BankHolidaysProtoType,
             settlement_days: Optional[types.IntTensor] = 0,
             discount_curve_type: curve_types_lib.CurveType = None,
             discount_curve_mask: types.IntTensor = None,
             equity_mask: types.IntTensor = None,
             config: Union[AmericanOptionConfig, Dict[str, Any]] = None,
             batch_names: Optional[types.StringTensor] = None,
             dtype: Optional[types.Dtype] = None,
             name: Optional[str] = None):
  """Initializes the batch of American Equity Options.

  Args:
    short_position: A boolean `Tensor`. Whether the price is computed for the
      contract holder. `True` entries map to a `+1` position multiplier and
      `False` entries to `-1`.
    currency: The denominated currency.
    expiry_date: A `DateTensor` specifying the dates on which the options
      expire.
    equity: A string name of the underlyings.
    contract_amount: A `Tensor` of real dtype and shape compatible with
      `short_position`.
    strike: `Tensor` of real dtype and shape compatible with
      `short_position`. Option strikes.
    is_call_option: A bool `Tensor` of shape compatible with
      `short_position`. Indicates which options are of call type.
    business_day_convention: A business count convention.
    calendar: A calendar to specify the weekend mask and bank holidays.
      Currently ignored; a Saturday/Sunday weekend calendar is used instead.
    settlement_days: An integer `Tensor` of the shape broadcastable with the
      shape of `expiry_date`.
    discount_curve_type: An optional instance of `CurveType` or a list of
      those. If supplied as a list and `discount_curve_mask` is not supplied,
      the size of the list should be the same as the number of priced
      instruments. Defines discount curves for the instruments.
      Default value: `None`, meaning that discount curves are inferred
      from `currency` and `config`.
    discount_curve_mask: An optional integer `Tensor` of values ranging from
      `0` to `len(discount_curve_type) - 1` and of shape `batch_shape`.
      Identifies a mapping between `discount_curve_type` list and the
      underlying instruments.
      Default value: `None`.
    equity_mask: An optional integer `Tensor` of values ranging from
      `0` to `len(equity) - 1` and of shape `batch_shape`. Identifies
      a mapping between `equity` list and the underlying instruments.
      Default value: `None`.
    config: Optional `AmericanOptionConfig` or a dictionary. If dictionary,
      then the keys should be the same as the field names of
      `AmericanOptionConfig`.
    batch_names: A string `Tensor` of instrument names. Should be of shape
      `batch_shape + [2]` specifying name and instrument type. This is useful
      when the `from_protos` method is used and the user needs to identify
      which instruments got batched together.
    dtype: `tf.Dtype` of the input and output real `Tensor`s.
      Default value: `None` which maps to `float64`.
    name: Python str. The name to give to the ops created by this class.
      Default value: `None` which maps to 'AmericanOption'.

  Raises:
    ValueError: if both `currency` and `equity` have more than one element
      and their lengths differ.
  """
  self._name = name or "AmericanOption"
  with tf.name_scope(self._name):
    if batch_names is not None:
      self._names = tf.convert_to_tensor(batch_names, name="batch_names")
    else:
      self._names = None
    self._dtype = dtype or tf.float64
    ones = tf.constant(1, dtype=self._dtype)
    # Encode the position as a +1 / -1 multiplier.
    self._short_position = tf.where(
        short_position, ones, -ones, name="short_position")
    self._contract_amount = tf.convert_to_tensor(
        contract_amount, dtype=self._dtype, name="contract_amount")
    self._strike = tf.convert_to_tensor(strike, dtype=self._dtype,
                                        name="strike")
    self._is_call_option = tf.convert_to_tensor(
        is_call_option, dtype=tf.bool, name="is_call_option")
    settlement_days = tf.convert_to_tensor(settlement_days)
    # Business day roll convention and the end of month flag
    roll_convention, eom = market_data_utils.get_business_day_convention(
        business_day_convention)
    # TODO(b/160446193): Calendar is ignored at the moment
    calendar = dateslib.create_holiday_calendar(
        weekend_mask=dateslib.WeekendMask.SATURDAY_SUNDAY)
    # Integer tensors are interpreted as encoded dates.
    if isinstance(expiry_date, types.IntTensor):
      self._expiry_date = dateslib.dates_from_tensor(expiry_date)
    else:
      self._expiry_date = dateslib.convert_to_date_tensor(expiry_date)
    self._settlement_days = settlement_days
    self._roll_convention = roll_convention

    # Get discount and reference curves
    self._currency = cashflow_streams.to_list(currency)
    self._equity = cashflow_streams.to_list(equity)
    # A single currency (or equity) may be shared by all options; otherwise
    # the two lists must line up.
    if len(self._currency) != len(self._equity):
      if len(self._currency) > 1 and len(self._equity) > 1:
        raise ValueError(
            "Number of currencies and equities should be the same "
            "but it is {0} and {1}".format(len(self._currency),
                                           len(self._equity)))

    config = _process_config(config)
    [
        self._model,
        self._num_samples,
        self._seed,
        self._num_exercise_times,
        self._num_calibration_samples
    ] = _get_config_values(config)

    if discount_curve_type is None:
      discount_curve_type = []
      # Use a distinct loop name so the `currency` argument is not rebound.
      for curve_currency in self._currency:
        if curve_currency in config.discounting_curve:
          curve_type = config.discounting_curve[curve_currency]
        else:
          # Default discounting curve
          curve_type = curve_types_lib.RiskFreeCurve(
              currency=curve_currency)
        discount_curve_type.append(curve_type)

    # Get masks for discount curves and vol surfaces
    [
        self._discount_curve_type,
        self._discount_curve_mask
    ] = cashflow_streams.process_curve_types(discount_curve_type,
                                             discount_curve_mask)
    [
        self._equity,
        self._equity_mask,
    ] = equity_utils.process_equities(self._equity, equity_mask)
    # Get batch shape from the already-converted strikes.
    self._batch_shape = tf.shape(self._strike)
def __init__(
    self,
    ndims=2,
    curvature=0.03,
    name='banana',
    pretty_name='Banana',
):
  """Builds the banana-shaped target distribution.

  The target is a diagonal Gaussian with scales `[10, 1, ..., 1]` pushed
  through a `MaskedAutoregressiveFlow`. The flow's only non-zero shift acts on
  the second coordinate and equals `curvature * (x[0]**2 - 100)`.

  Args:
    ndims: Python integer. Dimensionality of the distribution. Must be at
      least 2.
    curvature: Python float. Controls the strength of the curvature of the
      distribution.
    name: Python `str` name prefixed to Ops created by this class.
    pretty_name: A Python `str`. The pretty name of this model.

  Raises:
    ValueError: If ndims < 2.
  """
  if ndims < 2:
    raise ValueError(f'ndims must be at least 2, saw: {ndims}')

  with tf.name_scope(name):

    def bijector_fn(x):
      """Banana transform."""
      leading_shape = tf.shape(x)[:-1]
      head_zeros = tf.zeros(tf.concat([leading_shape, [1]], axis=0))
      bend = curvature * (tf.square(x[..., :1]) - 100)
      tail_zeros = tf.zeros(tf.concat([leading_shape, [ndims - 2]], axis=0))
      return tfb.Shift(tf.concat([head_zeros, bend, tail_zeros], axis=-1))

    base_scales = [10.] + [1.] * (ndims - 1)
    base = tfd.MultivariateNormalDiag(
        loc=tf.zeros(ndims), scale_diag=base_scales)
    flow = tfb.MaskedAutoregressiveFlow(bijector_fn=bijector_fn)
    banana = tfd.TransformedDistribution(base, bijector=flow)

    # The second dimension is a normal plus a scaled, shifted Chi2 with one
    # degree of freedom. x[0] has variance 100, so the `- 100` in the shift
    # cancels the Chi2 mean and every coordinate has zero mean. A Chi2 with
    # one degree of freedom has variance 2, which gives the second coordinate
    # variance 1 + 2 * curvature**2 * 10**4.
    second_std = onp.sqrt(1. + 2 * curvature**2 * 10.**4)
    ground_truth_std = onp.array([10.] + [second_std] + [1.] * (ndims - 2))

    identity_transformation = model.Model.SampleTransformation(
        fn=lambda params: params,
        pretty_name='Identity',
        ground_truth_mean=onp.zeros(ndims),
        ground_truth_standard_deviation=ground_truth_std,
    )
    sample_transformations = {'identity': identity_transformation}

  self._banana = banana

  super(Banana, self).__init__(
      default_event_space_bijector=tfb.Identity(),
      event_shape=banana.event_shape,
      dtype=banana.dtype,
      name=name,
      pretty_name=pretty_name,
      sample_transformations=sample_transformations,
  )
def _z(self, x, loc=None, scale=None):
  """Standardizes `x` as `(x - loc) / scale`.

  Args:
    x: Value(s) to standardize.
    loc: Optional location; `self.loc` is used when `None`.
    scale: Optional scale; `self.scale` is used when `None`.

  Returns:
    `(x - loc) / scale`, computed under the 'standardize' name scope.
  """
  if loc is None:
    loc = self.loc
  loc = tf.convert_to_tensor(loc)
  if scale is None:
    scale = self.scale
  scale = tf.convert_to_tensor(scale)
  with tf.name_scope('standardize'):
    return (x - loc) / scale
def _convolution_batch_nhwbc(x, kernel, rank, strides, padding, dilations,
                             name):
  """Specialization of batch conv to NHWBC data format.

  Convolves `x` with a *batch* of kernels using a single
  `tf.nn.depthwise_conv2d` call. The batch dimension `B` of the kernel is
  folded into the channel dimension. Each (batch member, input channel) pair
  then becomes its own depthwise input channel, so every batch member's
  kernel only ever sees its own slice of `x`. Summing over the input-channel
  axis afterwards recovers an ordinary (non-depthwise) convolution per batch
  member.

  Shape conventions below come from the inline shape comments in the body:
    kernel: B + F + [c, c']   (e.g. [b, fh, fw, c, c'])
    x:      N + D + B + [c]
    result: N + D' + B + [c']

  Args:
    x: Input `Tensor`, laid out as `N + D + B + [c]` (sample dims, spatial
      dims, kernel-batch dims, channels).
    kernel: Kernel `Tensor`, laid out as `B + F + [c, c']`.
    rank: Number of spatial dimensions. The reshaped input must be 4-D for
      `tf.nn.depthwise_conv2d`, so this path presumably only supports
      `rank == 2`. NOTE(review): confirm callers guarantee this.
    strides: Convolution strides; normalized via `prepare_strides`.
    padding: Padding spec; normalized via `prepare_conv_args`.
    dilations: Dilation spec; normalized via `prepare_conv_args`.
    name: Optional name scope; defaults to 'conv2d_nhwbc'.

  Returns:
    `Tensor` shaped `N + D' + B + [c']`, where `D'` is the spatial output
    shape produced by `tf.nn.depthwise_conv2d`.
  """
  with tf.name_scope(name or 'conv2d_nhwbc'):
    # Prepare arguments.
    [
        rank,
        _,  # strides
        padding,
        dilations,
        _,  # data_format
    ] = prepare_conv_args(rank, strides, padding, dilations)
    # Strides are normalized separately to length `rank + 2` (presumably
    # batch + spatial + channel, as `depthwise_conv2d` expects).
    strides = prepare_strides(strides, rank + 2, arg_name='strides')

    dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
    x = tf.convert_to_tensor(x, dtype=dtype, name='x')
    kernel = tf.convert_to_tensor(kernel, dtype=dtype, name='kernel')

    # Step 1: Transpose and double flatten kernel.
    # kernel.shape = B + F + [c, c']. Eg: [b, fh, fw, c, c']
    kernel_shape = prefer_static.shape(kernel)
    # Split off the trailing `rank + 2` dims (F + [c, c']); the rest is B.
    kernel_batch_shape, kernel_event_shape = prefer_static.split(
        kernel_shape,
        num_or_size_splits=[-1, rank + 2])
    kernel_batch_size = prefer_static.reduce_prod(kernel_batch_shape)
    kernel_ndims = prefer_static.rank(kernel)
    kernel_batch_ndims = kernel_ndims - rank - 2
    # Move the spatial dims F to the front and keep [c, c'] last, so that the
    # batch dims B sit directly before c and can be merged with it.
    perm = prefer_static.concat([
        prefer_static.range(kernel_batch_ndims, kernel_batch_ndims + rank),
        prefer_static.range(0, kernel_batch_ndims),
        prefer_static.range(kernel_batch_ndims + rank, kernel_ndims),
    ], axis=0)  # Eg, [1, 2, 0, 3, 4]
    kernel = tf.transpose(kernel, perm=perm)  # F + B + [c, c']
    # Merge B and c into one depthwise "input channel" axis; c' becomes the
    # depthwise channel multiplier.
    kernel = tf.reshape(
        kernel,
        shape=prefer_static.concat([
            kernel_event_shape[:rank],
            [kernel_batch_size * kernel_event_shape[-2],
             kernel_event_shape[-1]],
        ], axis=0))  # F + [bc, c']

    # Step 2: Double flatten x.
    # x.shape = N + D + B + [c]
    x_shape = prefer_static.shape(x)
    [
        x_sample_shape,
        x_rank_shape,
        x_batch_shape,
        x_channel_shape,
    ] = prefer_static.split(
        x_shape,
        num_or_size_splits=[-1, rank, kernel_batch_ndims, 1])
    # Collapse all sample dims N into one leading dim and merge B with c so
    # the channel layout matches the kernel's [bc] axis.
    x = tf.reshape(
        x,  # N + D + B + [c]
        shape=prefer_static.concat([
            [prefer_static.reduce_prod(x_sample_shape)],
            x_rank_shape,
            [prefer_static.reduce_prod(x_batch_shape) *
             prefer_static.reduce_prod(x_channel_shape)],
        ], axis=0))  # [n] + D + [bc]

    # Step 3: Apply convolution.
    y = tf.nn.depthwise_conv2d(
        x, kernel,
        strides=strides,
        padding=padding,
        data_format='NHWC',
        dilations=dilations)
    # SAME: y.shape = [n, h, w, bcc']
    # VALID: y.shape = [n, h-fh+1, w-fw+1, bcc']

    # Step 4: Reshape/reduce for output.
    # Depthwise output channels are ordered (input channel, multiplier), so
    # the trailing bcc' axis unflattens to B + [c, c'].
    y_shape = prefer_static.shape(y)
    y = tf.reshape(
        y,
        shape=prefer_static.concat([
            x_sample_shape,
            y_shape[1:-1],
            kernel_batch_shape,
            kernel_event_shape[-2:],
        ], axis=0))  # N + D' + B + [c, c']
    # Summing over input channels c turns the per-channel depthwise results
    # into a standard convolution output.
    y = tf.reduce_sum(y, axis=-2)  # N + D' + B + [c']

    return y
def _inv_z(self, z, loc=None, scale=None):
  """Reconstructs input `x` from its standardized version `z`.

  Inverts the standardization `z = (x - loc) / scale` by computing
  `x = z * scale + loc`.

  Args:
    z: Standardized value(s).
    loc: Optional location override. `self.loc` is used when `None`, which
      preserves the original single-argument behavior.
    scale: Optional scale override. `self.scale` is used when `None`.

  Returns:
    `z * scale + loc`, computed under the "reconstruct" name scope.
  """
  with tf.name_scope("reconstruct"):
    # Resolve scale before loc, matching the original evaluation order of
    # `z * self.scale + self.loc`.
    scale = self.scale if scale is None else scale
    loc = self.loc if loc is None else loc
    return z * scale + loc
def _interp_regular_1d_grid_impl(x,
                                 x_ref_min,
                                 x_ref_max,
                                 y_ref,
                                 axis=-1,
                                 batch_y_ref=False,
                                 fill_value='constant_extension',
                                 fill_value_below=None,
                                 fill_value_above=None,
                                 grid_regularizing_transform=None,
                                 name=None):
  """1-D interpolation that works with/without batching.

  The reference grid is regular in `g(x)`, where `g` is
  `grid_regularizing_transform` (the identity if `None`). Both interpolation
  and `'extrapolate'` fills are linear in `g`-space. With the default
  identity transform this is plain linear interpolation/extrapolation in `x`.

  Args:
    x: Points at which to evaluate the interpolant.
    x_ref_min: Left end of the reference grid. Must be a scalar unless
      `batch_y_ref` is `True`.
    x_ref_max: Right end of the reference grid. Must be a scalar unless
      `batch_y_ref` is `True`.
    y_ref: Reference values; dimension `axis` indexes the grid points.
    axis: Dimension of `y_ref` indexing the grid.
    batch_y_ref: Python `bool`. If `True`, `x`, `x_ref_min` and `x_ref_max`
      carry batch dimensions that line up with those of `y_ref`.
    fill_value: Value for `x` outside `[x_ref_min, x_ref_max]`: either
      'constant_extension', 'extrapolate', or a value used directly.
    fill_value_below: Optional override of `fill_value` for `x < x_ref_min`.
    fill_value_above: Optional override of `fill_value` for `x > x_ref_max`.
    grid_regularizing_transform: Optional callable `g` such that the
      reference grid is regular in `g(x)`. Defaults to the identity.
    name: Name scope for created ops.

  Returns:
    Interpolated values `y`. `NaN` is returned wherever `x` (equivalently
    `g(x)`) is `NaN`.

  Raises:
    ValueError: If a string fill value is neither 'constant_extension' nor
      'extrapolate'.
  """
  # Note: we do *not* make the no-batch version a special case of the batch
  # version, because that would be an inefficient use of batch_gather with
  # unnecessarily broadcast args.
  with tf.name_scope(name or 'interp_regular_1d_grid_impl'):

    # Arg checking.
    allowed_fv_st = ('constant_extension', 'extrapolate')
    for fv in (fill_value, fill_value_below, fill_value_above):
      if isinstance(fv, str) and fv not in allowed_fv_st:
        raise ValueError(
            'A fill value ({}) was not an allowed string ({})'.format(
                fv, allowed_fv_st))

    # Separate value fills for below/above incurs extra cost, so keep track of
    # whether this is needed.
    need_separate_fills = (
        fill_value_above is not None or fill_value_below is not None or
        fill_value == 'extrapolate'  # always requires separate below/above
    )
    if need_separate_fills and fill_value_above is None:
      fill_value_above = fill_value
    if need_separate_fills and fill_value_below is None:
      fill_value_below = fill_value

    dtype = dtype_util.common_dtype([x, x_ref_min, x_ref_max, y_ref],
                                    dtype_hint=tf.float32)
    x = tf.convert_to_tensor(x, name='x', dtype=dtype)

    x_ref_min = tf.convert_to_tensor(x_ref_min, name='x_ref_min', dtype=dtype)
    x_ref_max = tf.convert_to_tensor(x_ref_max, name='x_ref_max', dtype=dtype)
    if not batch_y_ref:
      _assert_ndims_statically(x_ref_min, expect_ndims=0)
      _assert_ndims_statically(x_ref_max, expect_ndims=0)

    y_ref = tf.convert_to_tensor(y_ref, name='y_ref', dtype=dtype)

    if batch_y_ref:
      # If we're batching,
      #   x.shape ~ [A1,...,AN, D],  x_ref_min/max.shape ~ [A1,...,AN]
      # So to add together we'll append a singleton.
      # If not batching, x_ref_min/max are scalar, so this isn't an issue,
      # moreover, if not batching, x can be scalar, and expanding x_ref_min/max
      # would cause a bad expansion of x when added to x (confused yet?).
      x_ref_min = x_ref_min[..., tf.newaxis]
      x_ref_max = x_ref_max[..., tf.newaxis]

    axis = tf.convert_to_tensor(axis, name='axis', dtype=tf.int32)
    axis = prefer_static.non_negative_axis(axis, tf.rank(y_ref))
    _assert_ndims_statically(axis, expect_ndims=0)

    ny = tf.cast(tf.shape(y_ref)[axis], dtype)

    # Map [x_ref_min, x_ref_max] to [0, ny - 1].
    # This is the (fractional) index of x.
    if grid_regularizing_transform is None:
      g = lambda x: x
    else:
      g = grid_regularizing_transform
    # Evaluate the transform once. The grid is regular in g-space, so both
    # the fractional index below and the extrapolation slopes use these.
    gx = g(x)
    gx_ref_min = g(x_ref_min)
    gx_ref_max = g(x_ref_max)
    fractional_idx = ((gx - gx_ref_min) / (gx_ref_max - gx_ref_min))
    x_idx_unclipped = fractional_idx * (ny - 1)

    # Wherever x is NaN, x_idx_unclipped will be NaN as well.
    # Keep track of the nan indices here (so we can impute NaN later).
    # Also eliminate any NaN indices, since the indices are cast to int32
    # below and int32 has no NaN.
    nan_idx = tf.math.is_nan(x_idx_unclipped)
    zero = tf.zeros((), dtype=dtype)
    x_idx_unclipped = tf.where(nan_idx, zero, x_idx_unclipped)
    x_idx = tf.clip_by_value(x_idx_unclipped, zero, ny - 1)

    # Get the index above and below x_idx.
    # Naively we could set idx_below = floor(x_idx), idx_above = ceil(x_idx),
    # however, this results in idx_below == idx_above whenever x is on a grid.
    # This in turn results in y_ref_below == y_ref_above, and then the gradient
    # at this point is zero.  So here we 'jitter' one of idx_below, idx_above,
    # so that they are at different values.  This jittering does not affect the
    # interpolated value, but does make the gradient nonzero (unless of course
    # the y_ref values are the same).
    idx_below = tf.floor(x_idx)
    idx_above = tf.minimum(idx_below + 1, ny - 1)
    idx_below = tf.maximum(idx_above - 1, 0)

    # These are the values of y_ref corresponding to above/below indices.
    idx_below_int32 = tf.cast(idx_below, dtype=tf.int32)
    idx_above_int32 = tf.cast(idx_above, dtype=tf.int32)
    if batch_y_ref:
      # If y_ref.shape ~ [A1,...,AN, C, B1,...,BN],
      # and x.shape, x_ref_min/max.shape ~ [A1,...,AN, D]
      # Then y_ref_below.shape ~ [A1,...,AN, D, B1,...,BN]
      y_ref_below = _batch_gather_with_broadcast(y_ref, idx_below_int32, axis)
      y_ref_above = _batch_gather_with_broadcast(y_ref, idx_above_int32, axis)
    else:
      # Here, y_ref_below.shape =
      #   y_ref.shape[:axis] + x.shape + y_ref.shape[axis + 1:]
      y_ref_below = tf.gather(y_ref, idx_below_int32, axis=axis)
      y_ref_above = tf.gather(y_ref, idx_above_int32, axis=axis)

    # Use t to get a convex combination of the below/above values.
    t = x_idx - idx_below

    # x, and tensors shaped like x, need to be added to, and selected with
    # (using tf.where) the output y.  This requires appending singletons.
    # Make functions appropriate for batch/no-batch.
    if batch_y_ref:
      # In the batch case, the output shape is going to be
      #   Broadcast(y_ref.shape[:axis], x.shape[:-1]) +
      #   x.shape[-1:] +  y_ref.shape[axis+1:]
      expand_x_fn = _make_expand_x_fn_for_batch_interpolation(y_ref, axis)
    else:
      # In the non-batch case, the output shape is going to be
      #   y_ref.shape[:axis] + x.shape + y_ref.shape[axis+1:]
      expand_x_fn = _make_expand_x_fn_for_non_batch_interpolation(y_ref, axis)

    t = expand_x_fn(t)
    nan_idx = expand_x_fn(nan_idx, broadcast=True)
    x_idx_unclipped = expand_x_fn(x_idx_unclipped, broadcast=True)

    y = t * y_ref_above + (1 - t) * y_ref_below

    # Now begins a long excursion to fill values outside [x_min, x_max].

    # Re-insert NaN wherever x was NaN.
    y = tf.where(nan_idx, tf.constant(np.nan, y.dtype), y)

    if not need_separate_fills:
      if fill_value == 'constant_extension':
        pass  # Already handled by clipping x_idx_unclipped.
      else:
        y = tf.where((x_idx_unclipped < 0) | (x_idx_unclipped > ny - 1),
                     fill_value, y)
    else:
      # Fill values below x_ref_min <==> x_idx_unclipped < 0.
      if fill_value_below == 'constant_extension':
        pass  # Already handled by the clipping that created x_idx_unclipped.
      elif fill_value_below == 'extrapolate':
        if batch_y_ref:
          # For every batch member, gather the first two elements of y across
          # `axis`.
          y_0 = tf.gather(y_ref, [0], axis=axis)
          y_1 = tf.gather(y_ref, [1], axis=axis)
        else:
          # If not batching, we want to gather the first two elements, just like
          # above.  However, these results need to be replicated for every
          # member of x.  An easy way to do that is to gather using
          # indices = zeros/ones(x.shape).
          y_0 = tf.gather(
              y_ref, tf.zeros(tf.shape(x), dtype=tf.int32), axis=axis)
          y_1 = tf.gather(
              y_ref, tf.ones(tf.shape(x), dtype=tf.int32), axis=axis)
        # Grid spacing in g-space. Using raw-x spacing here would give the
        # wrong slope whenever `grid_regularizing_transform` is not the
        # identity; with the identity these expressions are unchanged.
        g_delta = (gx_ref_max - gx_ref_min) / (ny - 1)
        x_factor = expand_x_fn((gx - gx_ref_min) / g_delta, broadcast=True)
        y = tf.where(x_idx_unclipped < 0, y_0 + x_factor * (y_1 - y_0), y)
      else:
        y = tf.where(x_idx_unclipped < 0, fill_value_below, y)
      # Fill values above x_ref_max <==> x_idx_unclipped > ny - 1.
      if fill_value_above == 'constant_extension':
        pass  # Already handled by the clipping that created x_idx_unclipped.
      elif fill_value_above == 'extrapolate':
        ny_int32 = tf.shape(y_ref)[axis]
        if batch_y_ref:
          y_n1 = tf.gather(y_ref, [ny_int32 - 1], axis=axis)
          y_n2 = tf.gather(y_ref, [ny_int32 - 2], axis=axis)
        else:
          y_n1 = tf.gather(
              y_ref, tf.fill(tf.shape(x), ny_int32 - 1), axis=axis)
          y_n2 = tf.gather(
              y_ref, tf.fill(tf.shape(x), ny_int32 - 2), axis=axis)
        # See the note on g_delta in the `fill_value_below` branch.
        g_delta = (gx_ref_max - gx_ref_min) / (ny - 1)
        x_factor = expand_x_fn((gx - gx_ref_max) / g_delta, broadcast=True)
        y = tf.where(x_idx_unclipped > ny - 1,
                     y_n1 + x_factor * (y_n1 - y_n2), y)
      else:
        y = tf.where(x_idx_unclipped > ny - 1, fill_value_above, y)

    return y
def __init__(self, validate_args=False, name="tanh"):
  """Creates the `Tanh` bijector.

  Args:
    validate_args: Python `bool`, forwarded to the base bijector constructor.
    name: Python `str`. A name scope is opened with this name, and the
      resulting scope name is passed to the base constructor.
  """
  with tf.name_scope(name) as scope_name:
    super(Tanh, self).__init__(
        forward_min_event_ndims=0,
        validate_args=validate_args,
        name=scope_name)
def batch_interp_regular_nd_grid(x,
                                 x_ref_min,
                                 x_ref_max,
                                 y_ref,
                                 axis,
                                 fill_value='constant_extension',
                                 name=None):
  """Multi-linear interpolation on a regular (constant spacing) grid.

  Given [a batch of] reference values, this function computes a multi-linear
  interpolant and evaluates it on [a batch of] new `x` values.

  The interpolant is built from reference values indexed by `nd` dimensions
  of `y_ref`, starting at `axis`.

  For example, take the case of a `2-D` scalar valued function and no leading
  batch dimensions. In this case, `y_ref.shape = [C1, C2]` and `y_ref[i, j]`
  is the reference value corresponding to grid point

  ```
  [x_ref_min[0] + i * (x_ref_max[0] - x_ref_min[0]) / (C1 - 1),
   x_ref_min[1] + j * (x_ref_max[1] - x_ref_min[1]) / (C2 - 1)]
  ```

  In the general case, dimensions to the left of `axis` in `y_ref` are
  broadcast with leading dimensions in `x`, `x_ref_min`, `x_ref_max`.

  Args:
    x: Numeric `Tensor` The x-coordinates of the interpolated output values
      for each batch. Shape `[..., D, nd]`, designating [a batch of] `D`
      coordinates in `nd` space. `D` must be `>= 1` and is not a batch dim.
    x_ref_min: `Tensor` of same `dtype` as `x`. The minimum values of the
      (implicitly defined) reference `x_ref`. Shape `[..., nd]`.
    x_ref_max: `Tensor` of same `dtype` as `x`. The maximum values of the
      (implicitly defined) reference `x_ref`. Shape `[..., nd]`.
    y_ref: `Tensor` of same `dtype` as `x`. The reference output values.
      Shape `[..., C1, ..., Cnd, B1,...,BM]`, designating [a batch of]
      reference values indexed by `nd` dimensions, of a shape `[B1,...,BM]`
      valued function (for `M >= 0`).
    axis: Scalar integer `Tensor`. Dimensions `[axis, axis + nd)` of `y_ref`
      index the interpolation table. E.g. `3-D` interpolation of a scalar
      valued function requires `axis=-3` and a `3-D` matrix valued function
      requires `axis=-5`.
    fill_value: Determines what values output should take for `x` values that
      are below `x_ref_min` or above `x_ref_max`. Either a scalar `Tensor`,
      or the string 'constant_extension' ==> Extend as constant function.
      Default value: `'constant_extension'`
    name: A name to prepend to created ops.
      Default value: `None`, which maps to `'interp_regular_nd_grid'`.

  Returns:
    y_interp: Interpolation between members of `y_ref`, at points `x`.
      `Tensor` of same `dtype` as `x`, and shape `[..., D, B1, ..., BM].`

  Raises:
    ValueError: If `rank(x) < 2` is determined statically.
    ValueError: If `axis` is not a scalar is determined statically.
    ValueError: If `axis + nd > rank(y_ref)` is determined statically.
    ValueError: If `x_ref_min.shape[-1]` is not known statically.
    ValueError: If the rank of `x`, after broadcasting against the batch
      dimensions of `x_ref_min`, `x_ref_max` and `y_ref`, is not known
      statically.
    ValueError: If `fill_value` is a string other than 'constant_extension'.

  #### Examples

  Interpolate a function of one variable.

  ```python
  y_ref = tf.exp(tf.linspace(start=0., stop=10., num=20))

  tfp.math.batch_interp_regular_nd_grid(
      # x.shape = [3, 1], x_ref_min/max.shape = [1].  Trailing `1` for `1-D`.
      x=[[6.0], [0.5], [3.3]], x_ref_min=[0.], x_ref_max=[10.], y_ref=y_ref,
      axis=0)
  ==> approx [exp(6.0), exp(0.5), exp(3.3)]
  ```

  Interpolate a scalar function of two variables.

  ```python
  x_ref_min = [0., 0.]
  x_ref_max = [2 * np.pi, 2 * np.pi]

  # Build y_ref.
  x0s, x1s = tf.meshgrid(
      tf.linspace(x_ref_min[0], x_ref_max[0], num=100),
      tf.linspace(x_ref_min[1], x_ref_max[1], num=100),
      indexing='ij')

  def func(x0, x1):
    return tf.sin(x0) * tf.cos(x1)

  y_ref = func(x0s, x1s)

  x = np.pi * tf.random.uniform(shape=(10, 2))

  tfp.math.batch_interp_regular_nd_grid(x, x_ref_min, x_ref_max, y_ref, axis=-2)
  ==> tf.sin(x[:, 0]) * tf.cos(x[:, 1])
  ```
  """
  with tf.name_scope(name or 'interp_regular_nd_grid'):
    dtype = dtype_util.common_dtype([x, x_ref_min, x_ref_max, y_ref],
                                    dtype_hint=tf.float32)

    # Arg checking.
    if isinstance(fill_value, str):
      if fill_value != 'constant_extension':
        raise ValueError(
            'A fill value ({}) was not an allowed string ({})'.format(
                fill_value, 'constant_extension'))
    else:
      fill_value = tf.convert_to_tensor(
          fill_value, name='fill_value', dtype=dtype)
      _assert_ndims_statically(fill_value, expect_ndims=0)

    # x.shape = [..., nd].
    x = tf.convert_to_tensor(x, name='x', dtype=dtype)
    _assert_ndims_statically(x, expect_ndims_at_least=2)

    # y_ref.shape = [..., C1,...,Cnd, B1,...,BM]
    y_ref = tf.convert_to_tensor(y_ref, name='y_ref', dtype=dtype)

    # x_ref_min.shape = [..., nd]
    x_ref_min = tf.convert_to_tensor(x_ref_min, name='x_ref_min', dtype=dtype)
    x_ref_max = tf.convert_to_tensor(x_ref_max, name='x_ref_max', dtype=dtype)
    _assert_ndims_statically(
        x_ref_min, expect_ndims_at_least=1, expect_static=True)
    _assert_ndims_statically(
        x_ref_max, expect_ndims_at_least=1, expect_static=True)

    # nd is the number of dimensions indexing the interpolation table, it's the
    # 'nd' in the function name.
    nd = tf.compat.dimension_value(x_ref_min.shape[-1])
    if nd is None:
      raise ValueError('`x_ref_min.shape[-1]` must be known statically.')
    tensorshape_util.assert_is_compatible_with(
        x_ref_max.shape[-1:], x_ref_min.shape[-1:])

    # Convert axis and check it statically.
    axis = tf.convert_to_tensor(axis, dtype=tf.int32, name='axis')
    axis = prefer_static.non_negative_axis(axis, tf.rank(y_ref))
    tensorshape_util.assert_has_rank(axis.shape, 0)
    axis_ = tf.get_static_value(axis)
    y_ref_rank_ = tf.get_static_value(tf.rank(y_ref))
    if axis_ is not None and y_ref_rank_ is not None:
      if axis_ + nd > y_ref_rank_:
        raise ValueError(
            'Since dims `[axis, axis + nd)` index the interpolation table, we '
            'must have `axis + nd <= rank(y_ref)`. Found: '
            '`axis`: {}, rank(y_ref): {}, and inferred `nd` from trailing '
            'dimensions of `x_ref_min` to be {}.'.format(
                axis_, y_ref_rank_, nd))

    x_batch_shape = tf.shape(x)[:-2]
    x_ref_min_batch_shape = tf.shape(x_ref_min)[:-1]
    x_ref_max_batch_shape = tf.shape(x_ref_max)[:-1]
    y_ref_batch_shape = tf.shape(y_ref)[:axis]

    # Do a brute-force broadcast of batch dims (add zeros).
    batch_shape = y_ref_batch_shape
    for tensor in [x_batch_shape, x_ref_min_batch_shape, x_ref_max_batch_shape]:
      batch_shape = tf.broadcast_dynamic_shape(batch_shape, tensor)

    def _batch_of_zeros_with_rightmost_singletons(n_singletons):
      """Return Tensor of zeros with some singletons on the rightmost dims."""
      ones = tf.ones(shape=[n_singletons], dtype=tf.int32)
      return tf.zeros(shape=tf.concat([batch_shape, ones], axis=0), dtype=dtype)

    x += _batch_of_zeros_with_rightmost_singletons(n_singletons=2)
    x_ref_min += _batch_of_zeros_with_rightmost_singletons(n_singletons=1)
    x_ref_max += _batch_of_zeros_with_rightmost_singletons(n_singletons=1)
    y_ref += _batch_of_zeros_with_rightmost_singletons(
        n_singletons=tf.rank(y_ref) - axis)

    # `batch_dims` must be a Python int. Read the rank *after* broadcasting,
    # because the broadcast above can add leading dims to `x`. Fail with a
    # clear message instead of an opaque `None - 2` TypeError.
    x_rank_ = tf.get_static_value(tf.rank(x))
    if x_rank_ is None:
      raise ValueError(
          'The rank of `x` (after broadcasting against the batch dimensions '
          'of `x_ref_min`, `x_ref_max` and `y_ref`) must be known statically.')

    return _batch_interp_with_gather_nd(
        x=x,
        x_ref_min=x_ref_min,
        x_ref_max=x_ref_max,
        y_ref=y_ref,
        nd=nd,
        fill_value=fill_value,
        batch_dims=x_rank_ - 2)
def __init__(self,
             target_log_prob_fn,
             step_size,
             max_tree_depth=10,
             max_energy_diff=1000.,
             unrolled_leapfrog_steps=1,
             seed=None,
             name=None):
  """Initializes this transition kernel.

  Args:
    target_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the target distribution.
    step_size: `Tensor` or Python `list` of `Tensor`s representing the step
      size for the leapfrog integrator. Must broadcast with the shape of
      `current_state`. Larger step sizes lead to faster progress, but
      too-large step sizes make rejection exponentially more likely. When
      possible, it's often helpful to match per-variable step sizes to the
      standard deviations of the target distribution in each variable.
      A non-nested value is wrapped in a single-element list.
    max_tree_depth: Maximum depth of the tree implicitly built by NUTS. The
      maximum number of leapfrog steps is bounded by `2**max_tree_depth`,
      i.e. the number of nodes in a binary tree `max_tree_depth` nodes deep.
      The default setting of 10 takes up to 1024 leapfrog steps. Must be
      known statically.
    max_energy_diff: Scalar threshold on the energy difference at each
      leapfrog step; leapfrog steps whose energy difference exceeds this
      threshold are treated as divergent. Defaults to 1000.
    unrolled_leapfrog_steps: The number of leapfrogs to unroll per tree
      expansion step. Applies a direct linear multiplier to the maximum
      trajectory length implied by max_tree_depth. Defaults to 1.
    seed: Python integer to seed the random number generator.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'NoUTurnSampler').

  Raises:
    ValueError: If `max_tree_depth` is not known statically or is `< 1`.
  """
  with tf.name_scope(name or 'NoUTurnSampler') as name:
    # Process `max_tree_depth` argument.
    max_tree_depth = tf.get_static_value(max_tree_depth)
    if max_tree_depth is None or max_tree_depth < 1:
      raise ValueError(
          'max_tree_depth must be known statically and >= 1 but was '
          '{}'.format(max_tree_depth))
    self._max_tree_depth = max_tree_depth

    # Compute parameters derived from `max_tree_depth`.
    instruction_array = build_tree_uturn_instruction(
        max_tree_depth, init_memory=-1)
    [write_instruction_numpy,
     read_instruction_numpy] = generate_efficient_write_read_instruction(
         instruction_array)

    # TensorArray version of the read/write instruction need to be created
    # within the function call to be compatible with XLA. Here we store the
    # numpy version of the instruction and convert it to TensorArray later.
    self._write_instruction = write_instruction_numpy
    self._read_instruction = read_instruction_numpy

    # Process all other arguments.
    self._target_log_prob_fn = target_log_prob_fn
    if not tf.nest.is_nested(step_size):
      step_size = [step_size]
    step_size = [
        tf.convert_to_tensor(s, dtype_hint=tf.float32) for s in step_size
    ]
    self._step_size = step_size

    self._parameters = dict(
        target_log_prob_fn=target_log_prob_fn,
        step_size=step_size,
        max_tree_depth=max_tree_depth,
        max_energy_diff=max_energy_diff,
        unrolled_leapfrog_steps=unrolled_leapfrog_steps,
        seed=seed,
        name=name,
    )
    self._seed_stream = SeedStream(seed, salt='nuts_one_step')
    self._unrolled_leapfrog_steps = unrolled_leapfrog_steps
    self._name = name
    self._max_energy_diff = max_energy_diff
def __init__(self,
             dim: int,
             mean_reversion: types.RealTensor,
             volatility: Union[types.RealTensor,
                               Callable[..., types.RealTensor]],
             initial_discount_rate_fn,
             corr_matrix: types.RealTensor = None,
             dtype: tf.DType = None,
             name: str = None):
  """Initializes the HJM model.

  Args:
    dim: A Python scalar which corresponds to the number of factors comprising
      the model.
    mean_reversion: A real positive `Tensor` of shape `[dim]`. Corresponds to
      the mean reversion rate of each factor.
    volatility: A real positive `Tensor` of the same `dtype` and shape as
      `mean_reversion` or a callable with the following properties:
      (a) The callable should accept a scalar `Tensor` `t` and return a 1-D
      `Tensor` of shape `[dim]`. The function returns instantaneous
      volatility `sigma(t)`.
      When `volatility` is specified as a real `Tensor`, each factor is
      assumed to have a constant instantaneous volatility. Corresponds to the
      instantaneous volatility of each factor.
    initial_discount_rate_fn: A Python callable that accepts expiry time as a
      real `Tensor` of the same `dtype` as `mean_reversion` and returns a
      `Tensor` of shape `input_shape`. Corresponds to the zero coupon bond
      yield at the present time for the input expiry time.
    corr_matrix: A `Tensor` of shape `[dim, dim]` and the same `dtype` as
      `mean_reversion`. Corresponds to the correlation matrix `Rho`.
      Default value: `None`, which maps to the identity matrix.
    dtype: The default dtype to use when converting values to `Tensor`s. It
      is applied to `mean_reversion`, a `Tensor`-valued `volatility`,
      `corr_matrix` and the outputs of `initial_discount_rate_fn`.
      Default value: `None` which maps to `tf.float32`.
    name: Python string. The name to give to the ops created by this class.
      Default value: `None` which maps to the default name
      `gaussian_hjm_model`.
  """
  self._name = name or 'gaussian_hjm_model'
  with tf.name_scope(self._name):
    # Resolve the dtype once. Every conversion below uses `self._dtype` (not
    # the raw `dtype` argument), so `dtype=None` consistently means
    # `tf.float32` for all model parameters.
    self._dtype = dtype or tf.float32
    self._dim = dim
    self._factors = dim

    def _instant_forward_rate_fn(t):
      """Instantaneous forward rate: -d/dt log P(0, t) = d/dt (r(t) * t)."""
      t = tf.convert_to_tensor(t, dtype=self._dtype)

      def _log_zero_coupon_bond(x):
        r = tf.convert_to_tensor(
            initial_discount_rate_fn(x), dtype=self._dtype)
        return -r * x

      rate = -gradient.fwd_gradient(
          _log_zero_coupon_bond,
          t,
          use_gradient_tape=True,
          unconnected_gradients=tf.UnconnectedGradients.ZERO)
      return rate

    def _initial_discount_rate_fn(t):
      return tf.convert_to_tensor(
          initial_discount_rate_fn(t), dtype=self._dtype)

    self._instant_forward_rate_fn = _instant_forward_rate_fn
    self._initial_discount_rate_fn = _initial_discount_rate_fn
    self._mean_reversion = tf.convert_to_tensor(
        mean_reversion, dtype=self._dtype, name='mean_reversion')

    self._batch_shape = []
    self._batch_rank = 0

    # Setup volatility
    if callable(volatility):
      self._volatility = volatility
    else:
      volatility = tf.convert_to_tensor(volatility, dtype=self._dtype)
      # No jumps: a piecewise-constant function with a single piece per
      # factor represents a constant volatility.
      jump_locations = [[]] * dim
      volatility = tf.expand_dims(volatility, axis=-1)
      self._volatility = piecewise.PiecewiseConstantFunc(
          jump_locations=jump_locations, values=volatility, dtype=self._dtype)

    if corr_matrix is None:
      corr_matrix = tf.eye(dim, dim, dtype=self._dtype)
    self._rho = tf.convert_to_tensor(
        corr_matrix, dtype=self._dtype, name='rho')
    self._sqrt_rho = tf.linalg.cholesky(self._rho)

    # Volatility function
    def _vol_fn(t, state):
      """Volatility function of Gaussian-HJM."""
      del state
      volatility = self._volatility(tf.expand_dims(t, -1))  # shape=(dim, 1)
      return self._sqrt_rho * volatility

    # Drift function
    def _drift_fn(t, state):
      """Drift function of Gaussian-HJM."""
      x = state
      # shape = [self._factors, self._factors]
      y = self.state_y(tf.expand_dims(t, axis=-1))[..., 0]
      drift = tf.math.reduce_sum(y, axis=-1) - self._mean_reversion * x
      return drift

    self._exact_discretization_setup(dim)
    # Deliberately skips `QuasiGaussianHJM.__init__` and calls the next
    # initializer in the MRO directly with the drift/volatility defined here.
    super(quasi_gaussian_hjm.QuasiGaussianHJM,
          self).__init__(dim, _drift_fn, _vol_fn, self._dtype, self._name)
def loop_tree_doubling(self, step_size, momentum_state_memory,
                       current_step_meta_info, iter_, initial_step_state,
                       initial_step_metastate):
  """Main loop for tree doubling.

  Performs one doubling iteration of the NUTS trajectory:
    1. draws a random direction per batch member;
    2. builds a new subtree from the corresponding end of the current
       trajectory (`2**iter_` integrator steps, per the inline comment);
    3. decides whether the subtree's candidate replaces the current one;
    4. updates the trajectory end points and the continue/U-turn flags.

  Args:
    step_size: List of per-state-part step sizes. They are negated where the
      sampled direction is "backward".
    momentum_state_memory: Memory forwarded unchanged to
      `self._build_sub_tree` (presumably used there for U-turn checks —
      NOTE(review): confirm).
    current_step_meta_info: Per-step metadata. Its `init_energy` field is read
      here to obtain the batch shape.
    iter_: Integer iteration counter. The subtree length is `1 << iter_`.
    initial_step_state: Nested structure whose leaves hold the two trajectory
      ends along their first axis (`v[0]` and `v[1]`).
    initial_step_metastate: Meta state from the previous iteration
      (presumably a `TreeDoublingMetaState`, the type returned here).

  Returns:
    `(iter_ + 1, new_step_state, new_step_metastate)`. NOTE(review): this
    looks like a `tf.while_loop` body signature; confirm against the caller.
  """
  with tf.name_scope('loop_tree_doubling'):
    batch_shape = prefer_static.shape(current_step_meta_info.init_energy)
    # One random boolean per batch member: an int in {0, 1} cast to bool.
    direction = tf.cast(
        tf.random.uniform(
            shape=batch_shape,
            minval=0,
            maxval=2,
            dtype=tf.int32,
            seed=self._seed_stream()),
        dtype=tf.bool)

    # Start the new subtree from end `v[1]` when direction is True, else from
    # `v[0]`.
    tree_start_states = tf.nest.map_structure(
        lambda v: tf.where(  # pylint: disable=g-long-lambda
            _rightmost_expand_to_rank(direction, prefer_static.rank(v[1])),
            v[1], v[0]),
        initial_step_state)

    directions_expanded = [
        _rightmost_expand_to_rank(direction, prefer_static.rank(state))
        for state in tree_start_states.state
    ]

    # Negating the step size where direction is False integrates the
    # trajectory backward in time.
    integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn,
        step_sizes=[
            tf.where(d, ss, -ss)
            for d, ss in zip(directions_expanded, step_size)
        ],
        num_steps=self.unrolled_leapfrog_steps)

    [
        candidate_tree_state,
        tree_final_states,
        final_not_divergence,
        continue_tree_final,
        energy_diff_tree_sum,
        momentum_tree_cumsum,
        leapfrogs_taken
    ] = self._build_sub_tree(
        directions_expanded,
        integrator,
        current_step_meta_info,
        # num_steps_at_this_depth = 2**iter_ = 1 << iter_
        tf.bitwise.left_shift(1, iter_),
        tree_start_states,
        initial_step_metastate.continue_tree,
        initial_step_metastate.not_divergence,
        momentum_state_memory)

    last_candidate_state = initial_step_metastate.candidate_state
    tree_weight = candidate_tree_state.weight
    if MULTINOMIAL_SAMPLE:
      # Weights are kept in log space in this mode.
      weight_sum = log_add_exp(tree_weight, last_candidate_state.weight)
      log_accept_thresh = tree_weight - last_candidate_state.weight
    else:
      weight_sum = tree_weight + last_candidate_state.weight
      log_accept_thresh = tf.math.log(
          tf.cast(tree_weight, tf.float32) /
          tf.cast(last_candidate_state.weight, tf.float32))
    # A NaN threshold (e.g. from -inf - -inf or 0 / 0) is replaced by 0.
    # Since u <= 0 always, this makes the comparison below always True.
    log_accept_thresh = tf.where(
        tf.math.is_nan(log_accept_thresh),
        tf.zeros([], log_accept_thresh.dtype),
        log_accept_thresh)
    # u = log(1 - U) with U ~ Uniform[0, 1), i.e. the log of a uniform draw in
    # (0, 1]. Hence P(u <= t) = min(1, exp(t)).
    u = tf.math.log1p(-tf.random.uniform(
        shape=batch_shape,
        dtype=log_accept_thresh.dtype,
        seed=self._seed_stream()))
    is_sample_accepted = u <= log_accept_thresh

    # Only take the subtree's candidate if the subtree itself did not
    # terminate.
    choose_new_state = is_sample_accepted & continue_tree_final

    new_candidate_state = TreeDoublingStateCandidate(
        state=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _rightmost_expand_to_rank(
                    choose_new_state, prefer_static.rank(s0)),
                s0, s1)
            for s0, s1 in zip(candidate_tree_state.state,
                              last_candidate_state.state)
        ],
        target=tf.where(
            _rightmost_expand_to_rank(
                choose_new_state,
                prefer_static.rank(candidate_tree_state.target)),
            candidate_tree_state.target, last_candidate_state.target),
        target_grad_parts=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _rightmost_expand_to_rank(
                    choose_new_state, prefer_static.rank(grad0)),
                grad0, grad1)
            for grad0, grad1 in zip(
                candidate_tree_state.target_grad_parts,
                last_candidate_state.target_grad_parts)
        ],
        # Energy has the same batch rank as target, so the target's rank is
        # used for the expansion.
        energy=tf.where(
            _rightmost_expand_to_rank(
                choose_new_state,
                prefer_static.rank(candidate_tree_state.target)),
            candidate_tree_state.energy, last_candidate_state.energy),
        weight=weight_sum)

    # Update left right information of the trajectory, and check trajectory
    # level U turn
    # The end that was NOT extended: `v[0]` when direction is True, else `v[1]`.
    tree_otherend_states = tf.nest.map_structure(
        lambda v: tf.where(  # pylint: disable=g-long-lambda
            _rightmost_expand_to_rank(direction, prefer_static.rank(v[1])),
            v[0], v[1]),
        initial_step_state)

    # Re-stack the two ends. When direction is True the new subtree end goes
    # to index 1, otherwise to index 0. The other end keeps its slot.
    new_step_state = tf.nest.pack_sequence_as(initial_step_state, [
        tf.stack([  # pylint: disable=g-complex-comprehension
            tf.where(
                _rightmost_expand_to_rank(direction, prefer_static.rank(l)),
                r, l),
            tf.where(
                _rightmost_expand_to_rank(direction, prefer_static.rank(l)),
                l, r),
        ], axis=0)
        for l, r in zip(tf.nest.flatten(tree_final_states),
                        tf.nest.flatten(tree_otherend_states))
    ])

    if GENERALIZED_UTURN:
      state_diff = momentum_tree_cumsum
    else:
      state_diff = [s[1] - s[0] for s in new_step_state.state]

    no_u_turns_trajectory = has_not_u_turn(
        state_diff,
        [m[0] for m in new_step_state.momentum],
        [m[1] for m in new_step_state.momentum],
        log_prob_rank=len(batch_shape))

    new_step_metastate = TreeDoublingMetaState(
        candidate_state=new_candidate_state,
        is_accepted=choose_new_state | initial_step_metastate.is_accepted,
        energy_diff_sum=(
            energy_diff_tree_sum + initial_step_metastate.energy_diff_sum),
        # Keep doubling only if the subtree did not stop and the whole
        # trajectory has not made a U-turn.
        continue_tree=continue_tree_final & no_u_turns_trajectory,
        not_divergence=final_not_divergence,
        leapfrog_count=(initial_step_metastate.leapfrog_count +
                        leapfrogs_taken))

    return iter_ + 1, new_step_state, new_step_metastate
def sample_paths(self,
                 times: types.RealTensor,
                 num_samples: types.IntTensor,
                 time_step: types.RealTensor = None,
                 num_time_steps: types.IntTensor = None,
                 random_type: random.RandomType = None,
                 seed: types.IntTensor = None,
                 skip: types.IntTensor = 0,
                 name: str = None) -> types.RealTensor:
  """Returns a sample of short rate paths from the HJM process.

  Uses Euler sampling for simulating the short rate paths.

  Args:
    times: A real positive `Tensor` of shape `(num_times,)`. The times at
      which the path points are to be evaluated.
    num_samples: Positive scalar `int32` `Tensor`. The number of paths to
      draw.
    time_step: Scalar real `Tensor`. Maximal distance between time grid points
      in Euler scheme. Used only when Euler scheme is applied.
      Default value: `None`.
    num_time_steps: An optional Scalar integer `Tensor` - a total number of
      time steps performed by the algorithm. The maximal distance between
      points in grid is bounded by
      `times[-1] / (num_time_steps - times.shape[0])`.
      Either this or `time_step` should be supplied.
      Default value: `None`.
    random_type: Enum value of `RandomType`. The type of (quasi)-random
      number generator to use to generate the paths.
      Default value: `None` which maps to the standard pseudo-random numbers.
    seed: Seed for the random number generator. The seed is
      only relevant if `random_type` is one of
      `[STATELESS, PSEUDO, HALTON_RANDOMIZED, PSEUDO_ANTITHETIC,
        STATELESS_ANTITHETIC]`. For `PSEUDO`, `PSEUDO_ANTITHETIC` and
      `HALTON_RANDOMIZED` the seed should be a Python integer. For
      `STATELESS` and `STATELESS_ANTITHETIC` it must be supplied as an
      integer `Tensor` of shape `[2]`.
      Default value: `None` which means no seed is set.
    skip: `int32` 0-d `Tensor`. The number of initial points of the Sobol or
      Halton sequence to skip. Used only when `random_type` is 'SOBOL',
      'HALTON', or 'HALTON_RANDOMIZED', otherwise ignored.
      Default value: `0`.
    name: Python string. The name to give this op.
      Default value: `None`, which maps to the model name followed by
      `'_sample_path'`.

  Returns:
    A tuple containing four elements.

    * The first element is a `Tensor` of shape `[num_samples, num_times]`
    containing the simulated short rate paths.
    * The second element is a `Tensor` of shape `[num_samples, num_times]`
    containing the simulated discount factor paths.
    * The third element is a `Tensor` of shape `[num_samples, num_times, dim]`
    containing the simulated values of the state variable `x`.
    * The fourth element is a `Tensor` of shape
    `[num_samples, num_times, dim^2]` containing the simulated values of the
    state variable `y`.

  Raises:
    ValueError:
      (a) If `times` has rank different from `1`.
      (b) If Euler scheme is used but `times` is not supplied.
  """
  name = name or self._name + '_sample_path'
  with tf.name_scope(name):
    times = tf.convert_to_tensor(times, self._dtype)
    if times.shape.rank != 1:
      raise ValueError('`times` should be a rank 1 Tensor. '
                       'Rank is {} instead.'.format(times.shape.rank))
    # NOTE: `_sample_paths` takes its arguments in a different order than
    # this public method; they are passed positionally here.
    return self._sample_paths(times, time_step, num_time_steps, num_samples,
                              random_type, skip, seed)
def _loop_build_sub_tree(self,
                         directions,
                         integrator,
                         current_step_meta_info,
                         iter_,
                         energy_diff_sum_previous,
                         momentum_cumsum_previous,
                         leapfrogs_taken,
                         prev_tree_state,
                         candidate_tree_state,
                         continue_tree_previous,
                         not_divergent_previous,
                         momentum_state_memory):
  """Base case in tree doubling: takes a single leapfrog step.

  This is a loop body. It performs the following steps:

  * Advances the trajectory edge by one leapfrog step.
  * Records momentum/state into the U-turn memory.
  * Checks for a U-turn and for divergence.
  * Stochastically replaces the current candidate sample with the new point.

  The returned tuple mirrors the loop-carried arguments `iter_` through
  `momentum_state_memory`, so it can be fed straight back in.

  Args:
    directions: Trajectory direction(s); forwarded to
      `has_not_u_turn_at_odd_step`.
    integrator: Callable taking `(momentum, state, target, target_grad_parts)`
      and returning the same four quantities after one leapfrog step.
    current_step_meta_info: Object exposing `write_instruction`,
      `read_instruction`, `init_energy` and, when slice sampling is used,
      `log_slice_sample`.
    iter_: Loop counter. Its parity selects between a memory-write step
      (even) and a U-turn-check step (odd).
    energy_diff_sum_previous: Running sum of `min(1, exp(energy_diff))` over
      steps taken while the tree was still growing.
    momentum_cumsum_previous: List of running momentum sums, one per state
      part.
    leapfrogs_taken: Count of leapfrog steps taken while the tree was still
      growing.
    prev_tree_state: `TreeDoublingState` at the current trajectory edge.
    candidate_tree_state: `TreeDoublingStateCandidate` selected so far.
    continue_tree_previous: Boolean `Tensor`; whether the tree was still
      being extended before this step.
    not_divergent_previous: Boolean `Tensor`; accumulated non-divergence flag.
    momentum_state_memory: `MomentumStateSwap` holding saved momenta/states
      used for U-turn checks.

  Returns:
    Tuple `(iter_ + 1, energy_diff_sum, momentum_cumsum, leapfrogs_taken,
    next_tree_state, next_candidate_tree_state, continue_tree_next,
    not_divergent, momentum_state_memory)`.
  """
  with tf.name_scope('loop_build_sub_tree'):
    # Take one leapfrog step in the direction v and check divergence.
    [
        next_momentum_parts,
        next_state_parts,
        next_target,
        next_target_grad_parts
    ] = integrator(prev_tree_state.momentum,
                   prev_tree_state.state,
                   prev_tree_state.target,
                   prev_tree_state.target_grad_parts)

    next_tree_state = TreeDoublingState(
        momentum=next_momentum_parts,
        state=next_state_parts,
        target=next_target,
        target_grad_parts=next_target_grad_parts)
    momentum_cumsum = [p0 + p1 for p0, p1 in zip(momentum_cumsum_previous,
                                                  next_momentum_parts)]
    # If the tree has not yet terminated previously, we count this leapfrog.
    leapfrogs_taken = tf.where(
        continue_tree_previous, leapfrogs_taken + 1, leapfrogs_taken)

    write_instruction = current_step_meta_info.write_instruction
    read_instruction = current_step_meta_info.read_instruction
    init_energy = current_step_meta_info.init_energy

    # When `iter_` is even, save momentum/state into the slot chosen by
    # `write_instruction`. When `iter_` is odd, the write goes to the spare
    # placeholder slot at index `self.max_tree_depth` instead, and a U-turn
    # check is performed below.
    write_index = tf.where(tf.equal(iter_ % 2, 0),
                           write_instruction.gather([iter_ // 2]),
                           self.max_tree_depth)

    if GENERALIZED_UTURN:
      state_to_write = momentum_cumsum
    else:
      state_to_write = next_state_parts

    momentum_state_memory = MomentumStateSwap(
        momentum_swap=[
            tf.tensor_scatter_nd_update(old, [write_index], [new])
            for old, new in zip(momentum_state_memory.momentum_swap,
                                next_momentum_parts)
        ],
        state_swap=[
            tf.tensor_scatter_nd_update(old, [write_index], [new])
            for old, new in zip(momentum_state_memory.state_swap,
                                state_to_write)
        ])
    batch_shape = prefer_static.shape(next_target)
    has_not_u_turn_at_even_step = tf.ones(batch_shape, dtype=tf.bool)

    read_index = read_instruction.gather([iter_ // 2])[0]
    # No U-turn check on even `iter_`; on odd `iter_` compare against the
    # saved entries selected by `read_index`.
    no_u_turns_within_tree = tf.cond(
        tf.equal(iter_ % 2, 0),
        lambda: has_not_u_turn_at_even_step,
        lambda: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
            read_index,
            directions,
            momentum_state_memory,
            next_momentum_parts,
            state_to_write,
            has_not_u_turn_at_even_step,
            log_prob_rank=prefer_static.rank(next_target)))

    energy = compute_hamiltonian(next_target, next_momentum_parts)
    # A NaN energy is mapped to -inf so the point can never be accepted.
    current_energy = tf.where(tf.math.is_nan(energy),
                              tf.constant(-np.inf, dtype=energy.dtype),
                              energy)
    energy_diff = current_energy - init_energy

    if MULTINOMIAL_SAMPLE:
      not_divergent = -energy_diff < self.max_energy_diff
      weight_sum = log_add_exp(candidate_tree_state.weight, energy_diff)
      log_accept_thresh = energy_diff - weight_sum
    else:
      log_slice_sample = current_step_meta_info.log_slice_sample
      not_divergent = log_slice_sample - energy_diff < self.max_energy_diff
      # Uniform sampling on the trajectory within the subtree across valid
      # samples.
      is_valid = log_slice_sample <= energy_diff
      weight_sum = tf.where(is_valid,
                            candidate_tree_state.weight + 1,
                            candidate_tree_state.weight)
      # NOTE(review): this branch fixes `log_accept_thresh` (and therefore
      # `u`) to float32 regardless of the model dtype; confirm intended.
      log_accept_thresh = tf.where(
          is_valid,
          -tf.math.log(tf.cast(weight_sum, dtype=tf.float32)),
          tf.constant(-np.inf, dtype=tf.float32))
    u = tf.math.log1p(-tf.random.uniform(
        shape=batch_shape,
        dtype=log_accept_thresh.dtype,
        seed=self._seed_stream()))
    is_sample_accepted = u <= log_accept_thresh

    next_candidate_tree_state = TreeDoublingStateCandidate(
        state=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _rightmost_expand_to_rank(
                    is_sample_accepted, prefer_static.rank(s0)),
                s0, s1)
            for s0, s1 in zip(next_state_parts, candidate_tree_state.state)
        ],
        target=tf.where(
            _rightmost_expand_to_rank(
                is_sample_accepted, prefer_static.rank(next_target)),
            next_target, candidate_tree_state.target),
        target_grad_parts=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _rightmost_expand_to_rank(
                    is_sample_accepted, prefer_static.rank(grad0)),
                grad0, grad1)
            for grad0, grad1 in zip(next_target_grad_parts,
                                    candidate_tree_state.target_grad_parts)
        ],
        # On rejection keep the retained candidate's own energy, so that
        # `energy` stays consistent with the `state`/`target` kept above.
        # Falling back to `init_energy` here would mismatch them.
        energy=tf.where(
            _rightmost_expand_to_rank(
                is_sample_accepted, prefer_static.rank(next_target)),
            current_energy, candidate_tree_state.energy),
        weight=weight_sum)

    continue_tree = not_divergent & continue_tree_previous
    continue_tree_next = no_u_turns_within_tree & continue_tree

    # Divergence is only recorded for batch members that were still building
    # their tree before this step.
    not_divergent_tokeep = tf.where(
        continue_tree_previous,
        not_divergent,
        tf.ones(batch_shape, dtype=tf.bool))

    # min(1., exp(energy_diff)).
    exp_energy_diff = tf.clip_by_value(tf.exp(energy_diff), 0., 1.)
    energy_diff_sum = tf.where(continue_tree,
                               energy_diff_sum_previous + exp_energy_diff,
                               energy_diff_sum_previous)

    return (
        iter_ + 1,
        energy_diff_sum,
        momentum_cumsum,
        leapfrogs_taken,
        next_tree_state,
        next_candidate_tree_state,
        continue_tree_next,
        not_divergent_previous & not_divergent_tokeep,
        momentum_state_memory,
    )
def __init__(self,
             df,
             kernel,
             index_points,
             mean_fn=None,
             jitter=1e-6,
             validate_args=False,
             allow_nan_stats=False,
             name='StudentTProcess'):
  """Instantiate a StudentTProcess Distribution.

  Args:
    df: Positive Floating-point `Tensor` representing the degrees of freedom.
      Must be greater than 2. This is checked at runtime only when
      `validate_args=True`.
    kernel: `PositiveSemidefiniteKernel`-like instance representing the
      TP's covariance function.
    index_points: `float` `Tensor` representing finite (batch of) vector(s) of
      points in the index set over which the TP is defined. Shape has the
      form `[b1, ..., bB, e, f1, ..., fF]` where `F` is the number of feature
      dimensions and must equal `kernel.feature_ndims` and `e` is the number
      (size) of index points in each batch. Ultimately this distribution
      corresponds to a `e`-dimensional multivariate Student's T. The batch
      shape must be broadcastable with `kernel.batch_shape` and any batch
      dims yielded by `mean_fn`.
    mean_fn: Python `callable` that acts on `index_points` to produce a (batch
      of) vector(s) of mean values at `index_points`. Takes a `Tensor` of
      shape `[b1, ..., bB, f1, ..., fF]` and returns a `Tensor` whose shape is
      broadcastable with `[b1, ..., bB]`.
      Default value: `None` implies constant zero function.
    jitter: `float` scalar `Tensor` added to the diagonal of the covariance
      matrix to ensure positive definiteness of the covariance matrix.
      Default value: `1e-6`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
      Default value: `False`.
    allow_nan_stats: Python `bool`, default `False`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
      Default value: `False`.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: "StudentTProcess".

  Raises:
    ValueError: if `mean_fn` is not `None` and is not callable.
  """
  # Must remain the first statement so that only the constructor arguments
  # (and `self`) are captured.
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype(
        [df, index_points, jitter], tf.float32)
    df = tf.convert_to_tensor(df, dtype=dtype, name='df')
    index_points = tf.convert_to_tensor(
        index_points, dtype=dtype, name='index_points')
    jitter = tf.convert_to_tensor(jitter, dtype=dtype, name='jitter')

    # With `validate_args`, `self._df` carries a runtime `df > 2` assertion.
    with tf.control_dependencies([
        assert_util.assert_greater(
            df, tf.cast(2., df.dtype),
            message='`df` must be greater than 2.')
    ] if validate_args else []):
      self._df = tf.identity(df)
    self._kernel = kernel
    self._index_points = index_points

    if mean_fn is None:
      # Default to a constant zero function in the common dtype of `df`,
      # `index_points` and `jitter` computed above. The `[1]`-shaped output
      # is presumably broadcast against the event shape downstream.
      mean_fn = lambda x: tf.zeros([1], dtype=dtype)
    elif not callable(mean_fn):
      raise ValueError('`mean_fn` must be a Python callable')
    self._mean_fn = mean_fn
    self._jitter = jitter

    with tf.name_scope('init'):
      kernel_matrix = _add_diagonal_shift(
          kernel.matrix(self.index_points, self.index_points), jitter)
      self._covariance_matrix = kernel_matrix
      # A multivariate Student-t with scale `S` has covariance
      # `df / (df - 2) * S @ S^T`. Rescaling the kernel matrix by
      # `(df - 2) / df` before the Cholesky makes that covariance equal
      # `kernel_matrix`.
      scale = tf.linalg.LinearOperatorLowerTriangular(
          tf.linalg.cholesky(
              ((self.df - 2) / self.df)[..., tf.newaxis, tf.newaxis] *
              kernel_matrix),
          is_non_singular=True,
          name='StudentTProcessScaleLinearOperator')
      # NOTE(review): the parent receives the raw `df`, not `self._df`, so
      # the `validate_args` assertion above is not attached to it. Confirm
      # whether this is intentional.
      super(StudentTProcess, self).__init__(
          df=df,
          loc=mean_fn(index_points),
          scale=scale,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
      self._parameters = parameters
      self._graph_parents = [index_points, jitter]