def filter_max_length(x, y, max_length=max_len):
  """Filter by new max length."""
  return tf.logical_and(tf.size(x) <= max_length,
                        tf.size(y) <= max_length)
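# Hypothetical usage sketch: `max_len` above is assumed to be a module-level
# constant defined before the function; here an explicit bound is passed
# instead, and the toy generator/dataset below are illustrative only.
import tensorflow as tf

def gen():
  yield [1, 2, 3], [4, 5]
  yield list(range(100)), [6]          # source side too long; gets dropped

ds = tf.data.Dataset.from_generator(
    gen,
    output_signature=(tf.TensorSpec([None], tf.int64),
                      tf.TensorSpec([None], tf.int64)))
ds = ds.filter(lambda x, y: filter_max_length(x, y, max_length=40))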
def _log_prob(self, x):
  if self.input_output_cholesky:
    x_sqrt = x
  else:
    # Complexity: O(nbk**3)
    x_sqrt = tf.linalg.cholesky(x)

  batch_shape = self.batch_shape_tensor()
  event_shape = self.event_shape_tensor()
  x_ndims = tf.rank(x_sqrt)
  num_singleton_axes_to_prepend = (
      tf.maximum(tf.size(batch_shape) + 2, x_ndims) - x_ndims)
  x_with_prepended_singletons_shape = tf.concat([
      tf.ones([num_singleton_axes_to_prepend], dtype=tf.int32),
      tf.shape(x_sqrt)
  ], 0)
  x_sqrt = tf.reshape(x_sqrt, x_with_prepended_singletons_shape)
  ndims = tf.rank(x_sqrt)
  # sample_ndims = ndims - batch_ndims - event_ndims
  sample_ndims = ndims - tf.size(batch_shape) - 2
  sample_shape = tf.shape(x_sqrt)[:sample_ndims]

  # We need to be able to pre-multiply each matrix by its corresponding
  # batch scale matrix. Since a Distribution Tensor supports multiple
  # samples per batch, this means we need to reshape the input matrix `x`
  # so that the first b dimensions are batch dimensions and the last two
  # are of shape [dimension, dimensions*number_of_samples]. Doing these
  # gymnastics allows us to do a batch_solve.
  #
  # After we're done with sqrt_solve (the batch operation) we need to undo
  # this reshaping so what we're left with is a Tensor partitionable by
  # sample, batch, event dimensions.

  # Complexity: O(nbk**2) since transpose must access every element.
  scale_sqrt_inv_x_sqrt = x_sqrt
  perm = tf.concat([tf.range(sample_ndims, ndims),
                    tf.range(0, sample_ndims)], 0)
  scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt, perm=perm)
  last_dim_size = (
      tf.cast(self.dimension, dtype=tf.int32) *
      tf.reduce_prod(x_with_prepended_singletons_shape[:sample_ndims]))
  shape = tf.concat(
      [x_with_prepended_singletons_shape[sample_ndims:-2],
       [tf.cast(self.dimension, dtype=tf.int32), last_dim_size]],
      axis=0)
  scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)

  # Complexity: O(nbM*k) where M is the complexity of the operator solving a
  # vector system. For LinearOperatorLowerTriangular, each solve is O(k**2) so
  # this step has complexity O(nbk^3).
  scale_sqrt_inv_x_sqrt = self.scale_operator.solve(scale_sqrt_inv_x_sqrt)

  # Undo make batch-op ready.
  # Complexity: O(nbk**2)
  shape = tf.concat(
      [tf.shape(scale_sqrt_inv_x_sqrt)[:-2], event_shape, sample_shape],
      axis=0)
  scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)
  perm = tf.concat([
      tf.range(ndims - sample_ndims, ndims),
      tf.range(0, ndims - sample_ndims)
  ], 0)
  scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt, perm=perm)

  # Write V = SS', X = LL'. Then:
  #   tr[inv(V) X] = tr[inv(S)' inv(S) L L']
  #                = tr[inv(S) L L' inv(S)']
  #                = tr[(inv(S) L) (inv(S) L)']
  #                = sum_{ik} (inv(S) L)_{ik}**2
  # The second equality follows from the cyclic permutation property.
  # Complexity: O(nbk**2)
  trace_scale_inv_x = tf.reduce_sum(
      tf.square(scale_sqrt_inv_x_sqrt), axis=[-2, -1])

  # Complexity: O(nbk)
  half_log_det_x = tf.reduce_sum(
      tf.math.log(tf.linalg.diag_part(x_sqrt)), axis=[-1])

  # Complexity: O(nbk**2)
  log_prob = ((self.df - self.dimension - 1.) * half_log_det_x -
              0.5 * trace_scale_inv_x -
              self.log_normalization())

  # Set shape hints.
  # Try to merge what we know from the input x with what we know from the
  # parameters of this distribution.
  if tensorshape_util.rank(x.shape) is not None and tensorshape_util.rank(
      self.batch_shape) is not None:
    tensorshape_util.set_shape(
        log_prob, tf.broadcast_static_shape(x.shape[:-2], self.batch_shape))

  return log_prob
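# Quick numerical check of the trace identity used above (assumed small random
# SPD matrices; plain NumPy, independent of the class internals).
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(4, 4))
v = a @ a.T + 4. * np.eye(4)               # scale matrix V = S S'
x = a @ a.T + np.eye(4)                    # observation X = L L'
s = np.linalg.cholesky(v)
l = np.linalg.cholesky(x)
lhs = np.trace(np.linalg.inv(v) @ x)
rhs = np.sum(np.linalg.solve(s, l) ** 2)   # sum_{ik} (inv(S) L)_{ik}**2
np.testing.assert_allclose(lhs, rhs, rtol=1e-10)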
def retry_init(proposal_fn, target_fn, *args, max_trials=50, seed=None,
               name=None, **kwargs):
  """Tries an MCMC initialization proposal until it gets a valid state.

  In this case, "valid" is defined as the value of `target_fn` being finite.
  This corresponds to an MCMC workflow where `target_fn` computes the
  log-probability one wants to sample from, in which case "finite `target_fn`"
  means "finite and positive probability state".

  If `target_fn` returns a Tensor of size greater than 1, the results are
  assumed to be independent of each other, so that different batch members can
  be accepted individually.

  The method is bounded rejection sampling.  The bound serves to avoid wasting
  computation on hopeless initialization procedures.  In interactive MCMC, one
  would presumably rather come up with a better initialization proposal than
  wait for an unbounded number of attempts with a bad one.  If unbounded
  re-trials are desired, set `max_trials` to `None`.

  Note: XLA and @jax.jit do not support assertions, so this function can
  return invalid states on those platforms without raising an error (unless
  `max_trials` is set to `None`).

  Args:
    proposal_fn: A function accepting a `seed` keyword argument and no other
      required arguments which generates proposed initial states.
    target_fn: A function accepting the return value of `proposal_fn` and
      returning a floating-point Tensor.
    *args: Additional arguments passed to `proposal_fn`.
    max_trials: Size-1 integer `Tensor` or None.  Maximum number of calls to
      `proposal_fn` to attempt.  If acceptable states are not found in this
      many trials, `retry_init` signals an error.  If `None`, there is no
      limit, and `retry_init` skips the control flow cost of checking for
      success.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'mcmc_retry_init').
    **kwargs: Additional keyword arguments passed to `proposal_fn`.

  Returns:
    states: An acceptable result from `proposal_fn`.

  #### Example

  One popular MCMC initialization scheme is to start the chains near 0 in
  unconstrained space.  There are models where the unconstraining
  transformation cannot exactly capture the space of valid states, such that
  this initialization has some material but not overwhelming chance of
  failure.  In this case, we can use `retry_init` to compensate.

  ```python
  @tfp.distributions.JointDistributionCoroutine
  def model():
    ...

  raw_init_dist = tfp.experimental.mcmc.init_near_unconstrained_zero(model)
  init_states = tfp.experimental.mcmc.retry_init(
      proposal_fn=raw_init_dist.sample,
      target_fn=model.log_prob,
      sample_shape=[100],
      seed=[4, 8])
  states = tfp.mcmc.sample_chain(
      current_state=init_states,
      ...)
  ```
  """
  def trial(seed):
    values = proposal_fn(*args, seed=seed, **kwargs)
    log_probs = target_fn(values)
    success = tf.math.is_finite(log_probs)
    return values, success

  with tf.name_scope(name or 'mcmc_retry_init'):
    values, successes, _ = brs.batched_las_vegas_algorithm(
        trial, max_trials=max_trials, seed=seed)
    if max_trials is None:
      # We were authorized to compute until success, so no need to check for
      # failure.
      return values
    else:
      num_states = tf.size(successes)
      num_successes = tf.reduce_sum(tf.cast(successes, tf.int32))
      msg = ('Failed to find acceptable initial states after {} trials;\n'
             '{} of {} states have non-finite log probability').format(
                 max_trials, num_states - num_successes, num_states)
      with tf.control_dependencies([
          tf.debugging.assert_equal(
              successes, tf.ones_like(successes), message=msg)
      ]):
        return tf.nest.map_structure(tf.identity, values)
def _inverse_event_shape_tensor(self, output_shape):
  perm = self._make_perm(tf.size(output_shape), tf.argsort(self.perm))
  return tf.gather(output_shape, perm)
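# Standalone sketch of the permutation inversion used above (assumed concrete
# values; the real method also accounts for leading batch dimensions via
# `_make_perm`).
import tensorflow as tf

perm = tf.constant([2, 0, 1])            # forward permutation of the event dims
output_shape = tf.constant([5, 2, 3])    # shape after the forward transpose
inverse_perm = tf.argsort(perm)          # ==> [1, 2, 0]
input_shape = tf.gather(output_shape, inverse_perm)   # ==> [2, 3, 5]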
def _parameter_control_dependencies(self, is_init):
  assertions = []

  if is_init and not dtype_util.is_integer(self.mixture_distribution.dtype):
    raise ValueError(
        '`mixture_distribution.dtype` ({}) is not over integers'.format(
            dtype_util.name(self.mixture_distribution.dtype)))

  if tensorshape_util.rank(self.mixture_distribution.event_shape) is not None:
    if tensorshape_util.rank(self.mixture_distribution.event_shape) != 0:
      raise ValueError('`mixture_distribution` must have scalar `event_dim`s')
  elif self.validate_args:
    assertions += [
        assert_util.assert_equal(
            tf.size(self.mixture_distribution.event_shape_tensor()),
            0,
            message='`mixture_distribution` must have scalar `event_dim`s'),
    ]

  # pylint: disable=protected-access
  mixture_dist_param = (self.mixture_distribution._probs
                        if self.mixture_distribution._logits is None
                        else self.mixture_distribution._logits)
  km = tf.compat.dimension_value(
      tensorshape_util.with_rank_at_least(mixture_dist_param.shape, 1)[-1])
  kc = tf.compat.dimension_value(
      tensorshape_util.with_rank_at_least(
          self.components_distribution.batch_shape, 1)[-1])
  component_bst = None
  if km is not None and kc is not None:
    if km != kc:
      raise ValueError('`mixture_distribution` components ({}) does not '
                       'equal `components_distribution.batch_shape[-1]` '
                       '({})'.format(km, kc))
  elif self.validate_args:
    if km is None:
      mixture_dist_param = tf.convert_to_tensor(mixture_dist_param)
      km = tf.shape(mixture_dist_param)[-1]
    if kc is None:
      component_bst = self.components_distribution.batch_shape_tensor()
      kc = component_bst[-1]
    assertions += [
        assert_util.assert_equal(
            km,
            kc,
            message=('`mixture_distribution` components does not equal '
                     '`components_distribution.batch_shape[-1]`')),
    ]

  mdbs = self.mixture_distribution.batch_shape
  cdbs = tensorshape_util.with_rank_at_least(
      self.components_distribution.batch_shape, 1)[:-1]
  if (tensorshape_util.is_fully_defined(mdbs) and
      tensorshape_util.is_fully_defined(cdbs)):
    if tensorshape_util.rank(mdbs) != 0 and mdbs != cdbs:
      raise ValueError(
          '`mixture_distribution.batch_shape` (`{}`) is not '
          'compatible with `components_distribution.batch_shape` '
          '(`{}`)'.format(tensorshape_util.as_list(mdbs),
                          tensorshape_util.as_list(cdbs)))
  elif self.validate_args:
    if not tensorshape_util.is_fully_defined(mdbs):
      mixture_dist_param = tf.convert_to_tensor(mixture_dist_param)
      mdbs = tf.shape(mixture_dist_param)[:-1]
    if not tensorshape_util.is_fully_defined(cdbs):
      if component_bst is None:
        component_bst = self.components_distribution.batch_shape_tensor()
      cdbs = component_bst[:-1]
    assertions += [
        assert_util.assert_equal(
            distribution_utils.pick_vector(
                tf.equal(tf.shape(mdbs)[0], 0), cdbs, mdbs),
            cdbs,
            message=('`mixture_distribution.batch_shape` is not '
                     'compatible with `components_distribution.batch_shape`'))
    ]

  return assertions
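# Illustration of the shape contract enforced above (assumed toy shapes):
# `mixture_distribution.batch_shape` must be compatible with
# `components_distribution.batch_shape[:-1]`, and the number of mixture
# categories must equal `components_distribution.batch_shape[-1]`.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

mix = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[[0.3, 0.7]] * 4),  # batch [4], 2 components
    components_distribution=tfd.Normal(loc=tf.zeros([4, 2]),       # batch [4, 2]
                                       scale=tf.ones([4, 2])),
    validate_args=True)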
def _bates_cdf(total_count, low, high, dtype, value):
  """Compute the Bates cdf.

  Internally, the (standard, unnormalized) cdf is computed by the formula

  ```none
  cdf = sum_{k=0}^j (-1)^k (n choose k) (nx - k)^n
  ```

  where
  * `n = total_count`,
  * `x = value` the value to compute the cumulative probability of, and
  * `j = floor(nx)`.

  This is shifted to `[low, high]` and normalized.  Since the pdf is symmetric,
  we have `cdf(x) = 1 - cdf(1 - x)` for `x > .5`, hence we only compute the
  left half, which keeps the number of terms lower.

  Computation is batched, using `tf.math.segment_sum()`.  For this reason this
  is not compatible with `tf.vectorized_map()`.

  All input parameters should have compatible dtypes and shapes.

  Args:
    total_count: `Tensor` with integer values, as given to the `Bates`
      constructor.
    low: Float `Tensor`, as given to the `Bates` constructor.
    high: Float `Tensor`, as given to the `Bates` constructor.
    dtype: The dtype of the output.
    value: Float `Tensor`. Input value to `cdf()`.

  Returns:
    cdf: Float `Tensor`. See above formula.
  """
  total_count = tf.cast(total_count, dtype)
  low = tf.convert_to_tensor(low)
  high = tf.convert_to_tensor(high)

  # Warn the user if they try to compute a cdf with high `total_count`.  This
  # warning is here instead of `_parameter_control_dependencies()` because
  # nested calls to `_name_and_control_scope` (e.g. `log_survival_function`)
  # can result in multiple warnings being added and multiple tensor
  # conversions.  Also `sample()` does not have the same numerical issues.
  with tf.control_dependencies([_stability_limit_tensor(total_count, dtype)]):
    # Center and adjust `value` using limits and symmetry.
    value_centered = (value - low) / (high - low)
    value_adj = tf.clip_by_value(value_centered, 0., 1.)
    value_adj = tf.where(value_adj < .5, value_adj, 1. - value_adj)
    value_adj = tf.where(tf.math.is_finite(value_adj), value_adj, 0.)
    # Flatten to make segments; need to broadcast before flattening.
    shape = ps.broadcast_shape(ps.shape(value_adj), ps.shape(total_count))
    total_count_b = ps.broadcast_to(total_count, shape)
    total_count_x_value_adj_b = total_count * value_adj
    total_count_f = tf.reshape(total_count_b, [-1])
    total_count_x_value_adj_f = tf.reshape(total_count_x_value_adj_b, [-1])
    # Create segmented terms of summation.
    num_terms_f = tf.cast(tf.math.floor(total_count_x_value_adj_f + 1),
                          dtype=tf.int32)
    term_idx_s = tf.cast(_segmented_range(num_terms_f), dtype)  # aka `k`
    total_count_s = tf.repeat(total_count_f, num_terms_f)
    total_count_x_value_adj_s = tf.repeat(total_count_x_value_adj_f,
                                          num_terms_f)
    terms = (tf.cast(-1., dtype) ** term_idx_s
             * (1. / ((total_count_s + 1.) * tf.math.exp(
                 tfp_math.lbeta(total_count_s - term_idx_s + 1.,
                                term_idx_s + 1.))))
             * (total_count_x_value_adj_s - term_idx_s) ** total_count_s)
    # Segment sum.
    segment_ids = tf.repeat(tf.range(tf.size(num_terms_f)), num_terms_f)
    cdf_s = tf.math.segment_sum(terms, segment_ids)
    # Reshape back.
    cdf = tf.reshape(cdf_s, shape)
    # Normalize.
    cdf = cdf / tf.math.exp(tf.math.lgamma(total_count_b + tf.cast(1., dtype)))
    # cdf symmetry adjustment: cdf(x) = 1 - cdf(1 - x) for x > 0.5
    cdf = tf.where(value_centered > .5, 1. - cdf, cdf)
    # Fix out-of-support queries.
    cdf = tf.where(value_centered < 0., tf.cast(0., dtype), cdf)
    cdf = tf.where(value_centered > 1., tf.cast(1., dtype), cdf)
    cdf = tf.where(tf.math.is_finite(value_centered), cdf, np.nan)
    return cdf
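# Cross-check of the docstring formula against Monte Carlo for the standard
# Bates distribution on [0, 1] (assumed small n; illustrative reference only).
import numpy as np
from math import comb, factorial, floor

def bates_cdf_reference(n, x):
  """CDF of the mean of n iid Uniform(0, 1) variables (Irwin-Hall form)."""
  nx = n * x
  return sum((-1.) ** k * comb(n, k) * (nx - k) ** n
             for k in range(floor(nx) + 1)) / factorial(n)

n, x = 3, 0.4
samples = np.random.default_rng(0).uniform(size=(200_000, n)).mean(axis=1)
np.testing.assert_allclose(bates_cdf_reference(n, x),
                           np.mean(samples <= x), atol=5e-3)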
def _might_have_nonzero_size(sample_shape):
  static_size = tf.get_static_value(tf.size(sample_shape))
  return (static_size is None) or static_size >= 1
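# Minimal illustration (assumed inputs): with a statically known sample shape
# the size resolves at trace time; if `tf.get_static_value` returns None (e.g.
# a data-dependent shape inside `tf.function`), the helper conservatively
# answers True.
import tensorflow as tf

assert _might_have_nonzero_size(tf.constant([3, 2]))            # 2 entries >= 1
assert not _might_have_nonzero_size(tf.constant([], tf.int32))  # 0 entries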
def _observation_log_probs(self, observations, mask):
  """Compute and shape tensor of log probs associated with observations."""

  # Let E be the underlying event shape
  #     M the number of steps in the HMM
  #     N the number of states of the HMM
  #
  # Then the incoming observations have shape
  #
  # observations : batch_o [M] E
  #
  # and the mask (if present) has shape
  #
  # mask : batch_m [M]
  #
  # Let this HMM distribution have batch shape batch_d
  # We need to broadcast all three of these batch shapes together
  # into the shape batch.
  #
  # We need to move the step dimension to the first dimension to make
  # them suitable for folding or scanning over.
  #
  # When we call `log_prob` for our observations we need to
  # do this for each state the observation could correspond to.
  # We do this by expanding the dimensions by 1 so we end up with:
  #
  # observations : [M] batch [1] [E]
  #
  # After calling `log_prob` we get
  #
  # observation_log_probs : [M] batch [N]
  #
  # We wish to use `mask` to select from this so we also
  # reshape and broadcast it up to shape
  #
  # mask : [M] batch [N]
  observation_distribution = self.observation_distribution
  underlying_event_rank = tf.size(
      observation_distribution.event_shape_tensor())
  observation_tensor_shape = tf.shape(observations)
  observation_batch_shape = observation_tensor_shape[
      :-1 - underlying_event_rank]
  observation_event_shape = observation_tensor_shape[
      -1 - underlying_event_rank:]

  if mask is not None:
    mask_tensor_shape = tf.shape(mask)
    mask_batch_shape = mask_tensor_shape[:-1]

  batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                           self.batch_shape_tensor())

  if mask is not None:
    batch_shape = tf.broadcast_dynamic_shape(batch_shape,
                                             mask_batch_shape)
  observations = tf.broadcast_to(
      observations,
      tf.concat([batch_shape, observation_event_shape], axis=0))
  observation_rank = tf.rank(observations)
  observations = distribution_util.move_dimension(
      observations, observation_rank - underlying_event_rank - 1, 0)
  observations = tf.expand_dims(observations,
                                observation_rank - underlying_event_rank)
  observation_log_probs = observation_distribution.log_prob(observations)

  if mask is not None:
    mask = tf.broadcast_to(
        mask, tf.concat([batch_shape, [self._num_steps]], axis=0))
    mask = distribution_util.move_dimension(mask, -1, 0)
    observation_log_probs = tf.where(mask[..., tf.newaxis],
                                     tf.zeros_like(observation_log_probs),
                                     observation_log_probs)

  return observation_log_probs
def posterior_marginals(self, observations, mask=None,
                        name='posterior_marginals'):
  """Compute marginal posterior distribution for each state.

  This function computes, for each time step, the marginal conditional
  probability that the hidden Markov model was in each possible state given
  the observations that were made at each time step.  So if the hidden states
  are `z[0],...,z[num_steps - 1]` and the observations are
  `x[0], ..., x[num_steps - 1]`, then this function computes
  `P(z[i] | x[0], ..., x[num_steps - 1])` for all `i` from `0` to
  `num_steps - 1`.

  This operation is sometimes called smoothing.  It uses a form of the
  forward-backward algorithm.

  Note: the behavior of this function is undefined if the `observations`
  argument represents impossible observations from the model.

  Args:
    observations: A tensor representing a batch of observations made on the
      hidden Markov model.  The rightmost dimension of this tensor gives the
      steps in a sequence of observations from a single sample from the
      hidden Markov model.  The size of this dimension should match the
      `num_steps` parameter of the hidden Markov model object.  The other
      dimensions are the dimensions of the batch and these are broadcast with
      the hidden Markov model's parameters.
    mask: optional bool-type `Tensor` with rightmost dimension matching
      `num_steps` indicating which observations the result of this function
      should be conditioned on.  When the mask has value `True` the
      corresponding observations aren't used.  If `mask` is `None` then all of
      the observations are used.  The `mask` dimensions left of the last are
      broadcast with the HMM batch as well as with the observations.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: "posterior_marginals".

  Returns:
    posterior_marginal: A `Categorical` distribution object representing the
      marginal probability of the hidden Markov model being in each state at
      each step.  The rightmost dimension of the `Categorical` distribution's
      batch will equal the `num_steps` parameter providing one marginal
      distribution for each step.  The other dimensions are the dimensions
      corresponding to the batch of observations.

  Raises:
    ValueError: if rightmost dimension of `observations` does not have size
      `num_steps`.
  """
  with self._name_and_control_scope(name):
    observation_tensor_shape = tf.shape(observations)
    observation_distribution = self.observation_distribution
    underlying_event_rank = tf.size(
        observation_distribution.event_shape_tensor())
    mask_tensor_shape = tf.shape(mask) if mask is not None else None
    num_states = self.transition_distribution.batch_shape_tensor()[-1]

    with self._observation_mask_shape_preconditions(
        observation_tensor_shape, mask_tensor_shape, underlying_event_rank):
      observation_log_probs = self._observation_log_probs(
          observations, mask)
      log_init = _extract_log_probs(num_states, self.initial_distribution)
      log_prob = log_init + observation_log_probs[0]
      log_transition = _extract_log_probs(num_states,
                                          self.transition_distribution)
      log_adjoint_prob = tf.zeros_like(log_prob)

      def _scan_multiple_steps_forwards():
        def forward_step(log_previous_step, log_prob_observation):
          return _log_vector_matrix(log_previous_step,
                                    log_transition) + log_prob_observation

        forward_log_probs = tf.scan(forward_step,
                                    observation_log_probs[1:],
                                    initializer=log_prob,
                                    name='forward_log_probs')
        return tf.concat([[log_prob], forward_log_probs], axis=0)

      forward_log_probs = prefer_static.cond(
          self._num_steps > 1,
          _scan_multiple_steps_forwards,
          lambda: tf.convert_to_tensor([log_prob]))

      total_log_prob = tf.reduce_logsumexp(forward_log_probs[-1], axis=-1)

      def _scan_multiple_steps_backwards():
        """Perform `scan` operation when `num_steps` > 1."""
        def backward_step(log_previous_step, log_prob_observation):
          return _log_matrix_vector(
              log_transition, log_prob_observation + log_previous_step)

        backward_log_adjoint_probs = tf.scan(
            backward_step,
            observation_log_probs[1:],
            initializer=log_adjoint_prob,
            reverse=True,
            name='backward_log_adjoint_probs')

        return tf.concat([backward_log_adjoint_probs, [log_adjoint_prob]],
                         axis=0)

      backward_log_adjoint_probs = prefer_static.cond(
          self._num_steps > 1,
          _scan_multiple_steps_backwards,
          lambda: tf.convert_to_tensor([log_adjoint_prob]))

      log_likelihoods = forward_log_probs + backward_log_adjoint_probs

      marginal_log_probs = distribution_util.move_dimension(
          log_likelihoods - total_log_prob[..., tf.newaxis], 0, -2)

      return categorical.Categorical(logits=marginal_log_probs)
def filter_max_length(x, y, max_length=max_len):
  return tf.logical_and(tf.size(x) <= max_length,
                        tf.size(y) <= max_length)
def _sample_n(self, n, seed=None):
  strm = SeedStream(seed, salt='HiddenMarkovModel')

  transition_batch_shape = self.transition_distribution.batch_shape_tensor()
  num_states = transition_batch_shape[-1]

  batch_shape = self.batch_shape_tensor()
  batch_size = tf.reduce_prod(batch_shape)

  # The batch sizes of the underlying initial distributions and
  # transition distributions might not match the batch size of
  # the HMM distribution.
  # As a result we need to ask for more samples from the
  # underlying distributions and then reshape the results into
  # the correct batch size for the HMM.
  init_repeat = (
      tf.reduce_prod(batch_shape) //
      tf.reduce_prod(self._initial_distribution.batch_shape_tensor()))
  init_state = self._initial_distribution.sample(n * init_repeat,
                                                 seed=strm())
  init_state = tf.reshape(init_state, [n, batch_size])
  # init_state :: n batch_size

  transition_repeat = (
      tf.reduce_prod(batch_shape) //
      tf.reduce_prod(transition_batch_shape[:-1]))

  init_shape = init_state.shape

  def generate_step(state, _):
    """Take a single step in Markov chain."""

    gen = self._transition_distribution.sample(n * transition_repeat,
                                               seed=strm())
    # gen :: (n * transition_repeat) transition_batch

    new_states = tf.reshape(gen, [n, batch_size, num_states])
    # new_states :: n batch_size num_states

    old_states_one_hot = tf.one_hot(state, num_states, dtype=tf.int32)
    # old_states_one_hot :: n batch_size num_states

    result = tf.reduce_sum(old_states_one_hot * new_states, axis=-1)
    # We know that `generate_step` must preserve the shape of the
    # tensor of states because the transition matrix must be square.
    # But TensorFlow might not know this, so we explicitly tell it
    # that the result has the same shape.
    result.set_shape(init_shape)
    return result

  def _scan_multiple_steps():
    """Take multiple steps with tf.scan."""
    dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
    if seed is not None:
      # Force parallel_iterations to 1 to ensure reproducibility.
      # b/139210489
      hidden_states = tf.scan(generate_step,
                              dummy_index,
                              initializer=init_state,
                              parallel_iterations=1)
    else:
      # Invoke default parallel_iterations behavior.
      hidden_states = tf.scan(generate_step,
                              dummy_index,
                              initializer=init_state)

    # TODO(b/115618503): add/use prepend_initializer to tf.scan.
    return tf.concat([[init_state], hidden_states], axis=0)

  hidden_states = prefer_static.cond(
      self._num_steps > 1,
      _scan_multiple_steps,
      lambda: init_state[tf.newaxis, ...])

  hidden_one_hot = tf.one_hot(hidden_states,
                              num_states,
                              dtype=self._observation_distribution.dtype)
  # hidden_one_hot :: num_steps n batch_size num_states

  # The observation distribution batch size might not match
  # the required batch size so as with the initial and
  # transition distributions we generate more samples and
  # reshape.
  observation_repeat = (
      batch_size //
      tf.reduce_prod(self._observation_distribution.batch_shape_tensor()[:-1]))

  possible_observations = self._observation_distribution.sample(
      [self._num_steps, observation_repeat * n], seed=strm())

  inner_shape = self._observation_distribution.event_shape_tensor()
  # possible_observations :: num_steps (observation_repeat * n)
  #                          observation_batch[:-1] num_states inner_shape

  possible_observations = tf.reshape(
      possible_observations,
      tf.concat([[self._num_steps, n], batch_shape, [num_states], inner_shape],
                axis=0))
  # possible_observations :: steps n batch_size num_states inner_shape

  hidden_one_hot = tf.reshape(
      hidden_one_hot,
      tf.concat([[self._num_steps, n], batch_shape, [num_states],
                 tf.ones_like(inner_shape)],
                axis=0))
  # hidden_one_hot :: steps n batch_size num_states "inner_shape"

  observations = tf.reduce_sum(hidden_one_hot * possible_observations,
                               axis=-1 - tf.size(inner_shape))
  # observations :: steps n batch_size inner_shape

  observations = distribution_util.move_dimension(
      observations, 0, 1 + tf.size(batch_shape))
  # returned :: n batch_shape steps inner_shape

  return observations
def segment_diff(x,
                 segment_ids,
                 order=1,
                 exclusive=False,
                 dtype=None,
                 name=None):
  """Computes difference of successive elements in a segment.

  For a complete description of segment_* ops see documentation of
  `tf.segment_max`.  This op extends the `diff` functionality to segmented
  inputs.

  The behaviour of this op is the same as that of the op `diff` within each
  segment.  The result is effectively a concatenation of the results of `diff`
  applied to each segment.

  #### Example

  ```python
  x = tf.constant([2, 5, 1, 7, 9] + [32, 10, 12, 3] + [4, 8, 5])
  segments = tf.constant([0, 0, 0, 0, 0] + [1, 1, 1, 1] + [2, 2, 2])
  # First order diff. Expected result: [3, -4, 6, 2, -22, 2, -9, 4, -3]
  dx1 = segment_diff(
      x, segment_ids=segments, order=1, exclusive=True)
  # Non-exclusive, second order diff.
  # Expected result: [2, 5, -1, 2, 8, 32, 10, -20, -7, 4, 8, 1]
  dx2 = segment_diff(
      x, segment_ids=segments, order=2, exclusive=False)
  ```

  Args:
    x: A rank 1 `Tensor` of any dtype for which arithmetic operations are
      permitted.
    segment_ids: A `Tensor`.  Must be one of the following types: int32,
      int64.  A 1-D tensor whose size is equal to the size of `x`.  Values
      should be sorted and can be repeated.
    order: Positive Python int.  The order of the difference to compute.
      `order = 1` corresponds to the difference between successive elements.
      Default value: 1
    exclusive: Python bool.  See description above.
      Default value: False
    dtype: Optional `tf.Dtype`.  If supplied, the dtype for `x` to use when
      converting to `Tensor`.
      Default value: None which maps to the default dtype inferred by TF.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: None which is mapped to the default name 'segment_diff'.

  Returns:
    diffs: A `Tensor` of the same dtype as `x`.  Assuming that each segment is
      of length greater than or equal to order, if `exclusive` is True, then
      the size is `n-order*k` where `n` is the size of x, `k` is the number of
      different segment ids supplied if `segment_ids` is not None or 1 if
      `segment_ids` is None.  If any of the segments is of length less than
      the order, then the size is:
      `n-sum(min(order, length(segment_j)), j)` where the sum is over
      segments.  If `exclusive` is False, then the size is `n`.
  """
  with tf.compat.v1.name_scope(name, default_name='segment_diff', values=[x]):
    x = tf.convert_to_tensor(x, dtype=dtype)
    raw_diffs = diff_ops.diff(x, order=order, exclusive=exclusive)
    if segment_ids is None:
      return raw_diffs
    # If segment ids are supplied, raw_diffs are incorrect at locations:
    # p, p+1, ... min(p+order-1, m_p-1) where p is the index of the first
    # element of a segment other than the very first segment (which is
    # already correct). m_p is the segment length.
    # Find positions where the segments begin.
    has_segment_changed = tf.concat(
        [[False], tf.not_equal(segment_ids[1:] - segment_ids[:-1], 0)], axis=0)
    # Shape [k, 1]
    segment_start_index = tf.cast(tf.where(has_segment_changed),
                                  dtype=tf.int32)
    segment_end_index = tf.concat(
        [tf.reshape(segment_start_index, [-1])[1:], [tf.size(segment_ids)]],
        axis=0)
    segment_end_index = tf.reshape(segment_end_index, [-1, 1])
    # The indices of locations that need to be adjusted. This needs to be
    # constructed in steps. First we generate p, p+1, ... p+order-1.
    # Shape [num_segments-1, order]
    fix_indices = (
        segment_start_index + tf.range(order, dtype=segment_start_index.dtype))
    in_bounds = tf.where(fix_indices < segment_end_index)
    # Keep only the ones in bounds.
    fix_indices = tf.reshape(tf.gather_nd(fix_indices, in_bounds), [-1, 1])

    needs_fix = tf.scatter_nd(
        fix_indices,
        # Unfortunately, scatter_nd doesn't support bool on GPUs so we need to
        # do ints here and then convert to bool.
        tf.reshape(tf.ones_like(fix_indices, dtype=tf.int32), [-1]),
        shape=tf.shape(x))
    # If exclusive is False, then needs_fix means we need to replace the
    # values in raw_diffs at those locations with the values in x.
    needs_fix = tf.cast(needs_fix, dtype=tf.bool)
    if not exclusive:
      return tf.where(needs_fix, x, raw_diffs)

    # If exclusive is True, we have to be more careful. The raw_diffs
    # computation has removed the first 'order' elements. After removing the
    # corresponding elements from needs_fix, we use it to remove the elements
    # from raw_diffs.
    return tf.boolean_mask(raw_diffs, tf.logical_not(needs_fix[order:]))
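# Reference check of the segmented-diff semantics with plain NumPy (values
# copied from the docstring example above; illustrative only).
import numpy as np

x_np = np.array([2, 5, 1, 7, 9, 32, 10, 12, 3, 4, 8, 5])
seg_np = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2])
expected = np.concatenate(
    [np.diff(s) for s in np.split(x_np, np.where(np.diff(seg_np))[0] + 1)])
# expected == [3, -4, 6, 2, -22, 2, -9, 4, -3], matching `dx1` above.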
def _replace_event_shape_in_shape_tensor(input_shape, event_shape_in,
                                         event_shape_out, validate_args):
  """Replaces the rightmost dims in a `Tensor` representing a shape.

  Args:
    input_shape: a rank-1 `Tensor` of integers
    event_shape_in: the event shape expected to be present in rightmost dims
      of `shape_in`.
    event_shape_out: the event shape with which to replace `event_shape_in` in
      the rightmost dims of `input_shape`.
    validate_args: Python `bool` indicating whether arguments should
      be checked for correctness.

  Returns:
    output_shape: A rank-1 integer `Tensor` with the same contents as
      `input_shape` except for the event dims, which are replaced with
      `event_shape_out`.
    output_tensorshape: A `TensorShape` carrying the statically known part of
      `output_shape`.
  """
  output_tensorshape, is_validated = _replace_event_shape_in_tensorshape(
      tensorshape_util.constant_value_as_shape(input_shape),
      event_shape_in,
      event_shape_out)

  # TODO(b/124240153): Remove map(tf.identity, deps) once tf.function
  # correctly supports control_dependencies.
  validation_dependencies = (
      map(tf.identity, (event_shape_in, event_shape_out))
      if validate_args else ())

  if (tensorshape_util.is_fully_defined(output_tensorshape) and
      (is_validated or not validate_args)):
    with tf.control_dependencies(validation_dependencies):
      output_shape = tf.convert_to_tensor(
          tensorshape_util.as_list(output_tensorshape),
          name='output_shape',
          dtype_hint=tf.int32)
    return output_shape, output_tensorshape

  with tf.control_dependencies(validation_dependencies):
    event_shape_in_ndims = (
        tf.size(event_shape_in)
        if tensorshape_util.num_elements(event_shape_in.shape) is None
        else tensorshape_util.num_elements(event_shape_in.shape))
    input_non_event_shape, input_event_shape = tf.split(
        input_shape, num_or_size_splits=[-1, event_shape_in_ndims])

  additional_assertions = []
  if is_validated:
    pass
  elif validate_args:
    # Check that `input_event_shape` and `event_shape_in` are compatible in
    # the sense that they have equal entries in any position that isn't a
    # `-1` in `event_shape_in`. Note that our validations at construction
    # time ensure there is at most one such entry in `event_shape_in`.
    mask = event_shape_in >= 0
    explicit_input_event_shape = tf.boolean_mask(input_event_shape, mask=mask)
    explicit_event_shape_in = tf.boolean_mask(event_shape_in, mask=mask)
    additional_assertions.append(
        assert_util.assert_equal(
            explicit_input_event_shape,
            explicit_event_shape_in,
            message='Input `event_shape` does not match `event_shape_in`.'))
    # We don't explicitly additionally verify
    # `tf.size(input_shape) > tf.size(event_shape_in)` since `tf.split`
    # already makes this assertion.

  with tf.control_dependencies(additional_assertions):
    output_shape = tf.concat([input_non_event_shape, event_shape_out],
                             axis=0,
                             name='output_shape')

  return output_shape, output_tensorshape
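# Conceptual sketch of the shape surgery performed above (assumed concrete
# shapes, with the validation machinery elided): the rightmost
# `event_shape_in` dims of `input_shape` are swapped for `event_shape_out`.
import tensorflow as tf

input_shape = tf.constant([3, 4, 2])     # e.g. batch [3], event [4, 2]
event_shape_out = tf.constant([8])
non_event, _ = tf.split(input_shape, [-1, 2])          # event_shape_in has 2 dims
new_shape = tf.concat([non_event, event_shape_out], axis=0)   # ==> [3, 8]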
def _map_payoff_to_sim_times(indices, payoff, num_samples):
  """Maps the swaption payoffs to short rate simulation times.

  Swaption payoffs are calculated on bermudan swaption's expiries.  However,
  for the LSM/PDE algorithms, we need quantities such as short rate
  simulations and/or swaption payoffs at the union of all exercise times in
  the batch of swaptions.  This function takes the payoff of individual
  swaption at their respective exercise times and maps it to all simulation
  times.  This is done by setting the payoff to -1 whenever the simulation
  time is not equal to the swaption exercise time.

  Args:
    indices: A `Tensor` of shape `batch_shape + num_exercise_times` containing
      the index of exercise time in the vector of simulation times.
    payoff: A real tensor of shape
      `[num_samples] + batch_shape + num_exercise_times` containing the
      exercise value of the underlying swap on each exercise time.
    num_samples: A scalar `Tensor` specifying the number of samples on which
      swaption payoff is computed.

  Returns:
    A tuple of `Tensors`.  The first tensor is an integer `Tensor` of shape
    `[num_samples] + batch_shape + [num_simulation_times]` and contains `1`
    if the corresponding simulation time is one of the exercise times for the
    swaption.  The second `Tensor` is a real `Tensor` of same shape and
    contains the exercise value of the swaption if the corresponding
    simulation time is an exercise time for the swaption or -1 otherwise.
  """
  indices = tf.expand_dims(indices, axis=0)
  indices = tf.repeat(indices, num_samples, axis=0)
  index_list = []
  tensor_shape = tf.shape(indices)
  tensor_rank = indices.shape.rank
  output_shape = tf.concat(
      [tf.shape(indices)[:-1], [tf.math.reduce_max(indices) + 1]], axis=0)
  num_elements = tf.size(indices)
  # Construct `index_list` which contains the indices at which swaption
  # payoff would be needed.
  for dim in range(tensor_rank - 1):
    idx = tf.range(0, tensor_shape[dim], dtype=indices.dtype)
    idx = tf.tile(
        tf.repeat(idx, tf.math.reduce_prod(tensor_shape[dim + 1:])),
        [tf.math.reduce_prod(tensor_shape[:dim])])
    index_list.append(idx)

  index_list.append(tf.reshape(indices, [-1]))
  # We need to transform `payoff` from the initial shape of
  # [num_samples, batch_shape, num_exercise_times] to a new `Tensor` with
  # shape = [num_samples, batch_shape, num_simulation_times] such that
  # payoff_new[..., indices] = payoff
  # We achieve this by first creating a `payoff_new` as a SparseTensor with
  # nonzero values at appropriate indices based on the payoff_new.shape and
  # then converting the sparse tensor to dense tensor.
  sparse_indices = tf.cast(tf.stack(index_list, axis=-1), dtype=tf.int64)
  is_exercise_time = tf.sparse.to_dense(
      tf.sparse.SparseTensor(
          sparse_indices,
          tf.ones(shape=num_elements, dtype=tf.int64),
          tf.cast(output_shape, dtype=tf.int64)),
      validate_indices=False)
  payoff = tf.sparse.to_dense(
      tf.sparse.SparseTensor(
          sparse_indices,
          tf.reshape(payoff, [-1]),
          tf.cast(output_shape, dtype=tf.int64)),
      validate_indices=False)
  return is_exercise_time, payoff
def _num_elements(losses):
  """Computes the number of elements in `losses` tensor."""
  with backend.name_scope('num_elements') as scope:
    return tf.cast(tf.size(losses, name=scope), dtype=losses.dtype)
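# Sketch of typical usage (assumed loss values): because the count is returned
# in the losses' dtype, it can be used directly as an averaging denominator.
import tensorflow as tf

losses = tf.constant([[0.5, 1.5], [2.0, 0.0]])
mean_loss = tf.reduce_sum(losses) / _num_elements(losses)   # ==> 1.0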
def posterior_mode(self, observations, mask=None, name='posterior_mode'):
  """Compute maximum likelihood sequence of hidden states.

  When this function is provided with a sequence of observations
  `x[0], ..., x[num_steps - 1]`, it returns the sequence of hidden states
  `z[0], ..., z[num_steps - 1]`, drawn from the underlying Markov chain, that
  is most likely to yield those observations.

  It uses the [Viterbi algorithm](
  https://en.wikipedia.org/wiki/Viterbi_algorithm).

  Note: the behavior of this function is undefined if the `observations`
  argument represents impossible observations from the model.

  Note: if there isn't a unique most likely sequence then one
  of the equally most likely sequences is chosen.

  Args:
    observations: A tensor representing a batch of observations made on the
      hidden Markov model.  The rightmost dimensions of this tensor correspond
      to the dimensions of the observation distributions of the underlying
      Markov chain.  The next dimension from the right indexes the steps in a
      sequence of observations from a single sample from the hidden Markov
      model.  The size of this dimension should match the `num_steps`
      parameter of the hidden Markov model object.  The other dimensions are
      the dimensions of the batch and these are broadcast with the hidden
      Markov model's parameters.
    mask: optional bool-type `Tensor` with rightmost dimension matching
      `num_steps` indicating which observations the result of this function
      should be conditioned on.  When the mask has value `True` the
      corresponding observations aren't used.  If `mask` is `None` then all of
      the observations are used.  The `mask` dimensions left of the last are
      broadcast with the HMM batch as well as with the observations.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: "posterior_mode".

  Returns:
    posterior_mode: A `Tensor` representing the most likely sequence of hidden
      states.  The rightmost dimension of this tensor will equal the
      `num_steps` parameter providing one hidden state for each step.  The
      other dimensions are those of the batch.

  Raises:
    ValueError: if the `observations` tensor does not consist of sequences of
      `num_steps` observations.

  #### Examples

  ```python
  tfd = tfp.distributions

  # A simple weather model.

  # Represent a cold day with 0 and a hot day with 1.
  # Suppose the first day of a sequence has a 0.8 chance of being cold.

  initial_distribution = tfd.Categorical(probs=[0.8, 0.2])

  # Suppose a cold day has a 30% chance of being followed by a hot day
  # and a hot day has a 20% chance of being followed by a cold day.

  transition_distribution = tfd.Categorical(probs=[[0.7, 0.3],
                                                   [0.2, 0.8]])

  # Suppose additionally that on each day the temperature is
  # normally distributed with mean and standard deviation 0 and 5 on
  # a cold day and mean and standard deviation 15 and 10 on a hot day.

  observation_distribution = tfd.Normal(loc=[0., 15.], scale=[5., 10.])

  # This gives the hidden Markov model:

  model = tfd.HiddenMarkovModel(
      initial_distribution=initial_distribution,
      transition_distribution=transition_distribution,
      observation_distribution=observation_distribution,
      num_steps=7)

  # Suppose we observe gradually rising temperatures over a week:
  temps = [-2., 0., 2., 4., 6., 8., 10.]

  # We can now compute the most probable sequence of hidden states:

  model.posterior_mode(temps)

  # The result is [0 0 0 0 0 1 1] telling us that the transition
  # from "cold" to "hot" most likely happened between the
  # 5th and 6th days.
  ```
  """
  with self._name_and_control_scope(name):
    observations = tf.convert_to_tensor(observations, name='observations')
    if mask is not None:
      mask = tf.convert_to_tensor(mask, name='mask', dtype_hint=tf.bool)
    num_states = self.transition_distribution.batch_shape_tensor()[-1]
    observation_distribution = self.observation_distribution
    underlying_event_rank = tf.size(
        observation_distribution.event_shape_tensor())
    observation_tensor_shape = tf.shape(observations)
    mask_tensor_shape = tf.shape(mask) if mask is not None else None

    with self._observation_mask_shape_preconditions(
        observation_tensor_shape, mask_tensor_shape, underlying_event_rank):
      observation_log_probs = self._observation_log_probs(
          observations, mask)
      log_init = _extract_log_probs(num_states, self.initial_distribution)
      log_trans = _extract_log_probs(num_states, self.transition_distribution)
      log_prob = log_init + observation_log_probs[0]

      def _reduce_multiple_steps():
        """Perform `reduce_max` operation when `num_steps` > 1."""

        def forward_step(previous_step_pair, log_prob_observation):
          log_prob_previous = previous_step_pair[0]
          log_prob = (log_prob_previous[..., tf.newaxis] +
                      log_trans +
                      log_prob_observation[..., tf.newaxis, :])
          most_likely_given_successor = tf.argmax(log_prob, axis=-2)
          max_log_p_given_successor = tf.reduce_max(log_prob, axis=-2)
          return (max_log_p_given_successor, most_likely_given_successor)

        forward_log_probs, all_most_likely_given_successor = tf.scan(
            forward_step,
            observation_log_probs[1:],
            initializer=(log_prob,
                         tf.zeros(tf.shape(log_prob), dtype=tf.int64)),
            name='forward_log_probs')

        most_likely_end = tf.argmax(forward_log_probs[-1], axis=-1)

        # We require the operation that gives C from A and B where
        # C[i...j] = A[i...j, B[i...j]]
        # and A = most_likely_given_successor
        #     B = most_likely_successor.
        # tf.gather requires indices of known shape so instead we use
        # reduction with tf.one_hot(B) to pick out elements from B
        def backward_step(most_likely_successor, most_likely_given_successor):
          return tf.reduce_sum(
              (most_likely_given_successor *
               tf.one_hot(most_likely_successor, num_states, dtype=tf.int64)),
              axis=-1)

        backward_scan = tf.scan(backward_step,
                                all_most_likely_given_successor,
                                most_likely_end,
                                reverse=True)
        most_likely_sequences = tf.concat([backward_scan, [most_likely_end]],
                                          axis=0)
        return distribution_util.move_dimension(most_likely_sequences, 0, -1)

      return prefer_static.cond(
          self.num_steps > 1,
          _reduce_multiple_steps,
          lambda: tf.argmax(log_prob, axis=-1)[..., tf.newaxis])
def _prefer_static_event_ndims(distribution):
  if distribution.event_shape.ndims is not None:
    return distribution.event_shape.ndims
  else:
    return tf.size(distribution.event_shape_tensor())
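# Illustration (assumed distribution): for `MultivariateNormalDiag` the event
# rank is statically known, so the Python-int branch is taken; only
# distributions whose event shape is unknown at trace time fall back to the
# dynamic `tf.size(...)` branch.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

mvn = tfd.MultivariateNormalDiag(loc=tf.zeros([3]))
assert _prefer_static_event_ndims(mvn) == 1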
def _parameter_control_dependencies(self, is_init):
  assertions = []

  # Check num_steps is a scalar that's at least 1.
  if is_init != tensor_util.is_ref(self.num_steps):
    num_steps = tf.convert_to_tensor(self.num_steps)
    num_steps_ = tf.get_static_value(num_steps)
    if num_steps_ is not None:
      if np.ndim(num_steps_) != 0:
        raise ValueError(
            '`num_steps` must be a scalar but it has rank {}'.format(
                np.ndim(num_steps_)))
      if num_steps_ < 1:
        raise ValueError('`num_steps` must be at least 1.')
    elif self.validate_args:
      message = '`num_steps` must be a scalar'
      assertions.append(
          assert_util.assert_rank_at_most(self.num_steps, 0, message=message))
      assertions.append(
          assert_util.assert_greater_equal(
              num_steps, 1, message='`num_steps` must be at least 1.'))

  # Check that the initial distribution has scalar events over the
  # integers.
  if is_init and not dtype_util.is_integer(self.initial_distribution.dtype):
    raise ValueError(
        '`initial_distribution.dtype` ({}) is not over integers'.format(
            dtype_util.name(self.initial_distribution.dtype)))

  if tensorshape_util.rank(self.initial_distribution.event_shape) is not None:
    if tensorshape_util.rank(self.initial_distribution.event_shape) != 0:
      raise ValueError('`initial_distribution` must have scalar `event_dim`s')
  elif self.validate_args:
    assertions += [
        assert_util.assert_equal(
            tf.size(self.initial_distribution.event_shape_tensor()),
            0,
            message='`initial_distribution` must have scalar `event_dim`s'),
    ]

  # Check that the transition distribution is over the integers.
  if (is_init and
      not dtype_util.is_integer(self.transition_distribution.dtype)):
    raise ValueError(
        '`transition_distribution.dtype` ({}) is not over integers'.format(
            dtype_util.name(self.transition_distribution.dtype)))

  # Check observations have non-scalar batches.
  # The graph version of this assertion is incorporated as
  # a control dependency of the transition/observation
  # compatibility test.
  if tensorshape_util.rank(self.observation_distribution.batch_shape) == 0:
    raise ValueError("`observation_distribution` can't have scalar batches")

  # Check transitions have non-scalar batches.
  # The graph version of this assertion is incorporated as
  # a control dependency of the transition/observation
  # compatibility test.
  if tensorshape_util.rank(self.transition_distribution.batch_shape) == 0:
    raise ValueError("`transition_distribution` can't have scalar batches")

  # Check compatibility of transition distribution and observation
  # distribution.
  tdbs = self.transition_distribution.batch_shape
  odbs = self.observation_distribution.batch_shape
  if (tensorshape_util.dims(tdbs) is not None and
      tf.compat.dimension_value(odbs[-1]) is not None):
    if (tf.compat.dimension_value(tdbs[-1]) !=
        tf.compat.dimension_value(odbs[-1])):
      raise ValueError('`transition_distribution` and '
                       '`observation_distribution` must agree on last '
                       'dimension of batch size')
  elif self.validate_args:
    tdbs = self.transition_distribution.batch_shape_tensor()
    odbs = self.observation_distribution.batch_shape_tensor()
    transition_precondition = assert_util.assert_greater(
        tf.size(tdbs), 0,
        message=('`transition_distribution` can\'t have scalar '
                 'batches'))
    observation_precondition = assert_util.assert_greater(
        tf.size(odbs), 0,
        message=('`observation_distribution` can\'t have scalar '
                 'batches'))
    with tf.control_dependencies([
        transition_precondition,
        observation_precondition]):
      assertions += [
          assert_util.assert_equal(
              tdbs[-1],
              odbs[-1],
              message=('`transition_distribution` and '
                       '`observation_distribution` '
                       'must agree on last dimension of batch size'))
      ]

  return assertions
def pack_batch(x: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
  """Internal function to map over.

  Consumes a batch of input examples and produces a variable number of output
  examples.

  Args:
    x: a single example

  Returns:
    a mapping from feature name to packed `tf.Tensor`s
  """
  # Note: `feature_lengths`, `empty_example`, `keys_etc` and
  # `_write_packed_example` are closed-over names from the enclosing packing
  # routine.
  keys = list(feature_lengths)
  partial = empty_example.copy()
  first_key, *_ = keys
  dynamic_batch_size = tf.shape(x[first_key])[0]
  outputs = {}
  for k in keys:
    outputs[k] = tf.TensorArray(
        tf.int32, size=0, dynamic_size=True,
        element_shape=[feature_lengths[k]])
    outputs[k + "_positions"] = tf.TensorArray(
        tf.int32, size=0, dynamic_size=True,
        element_shape=[feature_lengths[k]])

  for i in tf.range(0, dynamic_batch_size):
    tf.autograph.experimental.set_loop_options(
        shape_invariants=[
            (partial, {k: tf.TensorShape([None]) for k in keys_etc}),
            (outputs, {k: tf.TensorShape(None) for k in keys_etc})]
    )

    can_append = True
    one_example = {}
    for k in keys:
      val = tf.cast(x[k][i], tf.int32)
      val = val[:tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))]
      one_example[k] = val
    for k in keys:
      can_append = tf.logical_and(
          can_append,
          tf.less_equal(
              tf.size(partial[k]) + tf.size(one_example[k]),
              feature_lengths[k]))

    if not can_append:
      partial, outputs = _write_packed_example(partial, outputs)

    new_partial = {}
    for k in keys:
      new_seq = one_example[k][:feature_lengths[k]]
      new_seq_len = tf.size(new_seq)
      new_partial[k] = tf.concat([partial[k], new_seq], 0)
      new_partial[k + "_positions"] = tf.concat(
          [partial[k + "_positions"],
           tf.range(new_seq_len, dtype=tf.int32)], 0)
    partial = new_partial

  partial, outputs = _write_packed_example(partial, outputs)
  packed = {k: outputs[k].stack() for k in keys_etc}
  for k in keys:
    packed[k + "_segment_ids"] = (
        tf.cumsum(
            tf.cast(tf.equal(packed[k + "_positions"], 0), tf.int32), axis=1) *
        tf.cast(tf.not_equal(packed[k], 0), tf.int32))
  return packed
def __init__(self, perm=None, rightmost_transposed_ndims=None,
             validate_args=False, name='transpose'):
  """Instantiates the `Transpose` bijector.

  Args:
    perm: Positive `int32` vector-shaped `Tensor` representing permutation of
      rightmost dims (for forward transformation).  Note that the `0`th index
      represents the first of the rightmost dims and the largest value must be
      `rightmost_transposed_ndims - 1` and corresponds to `tf.rank(x) - 1`.
      Only one of `perm` and `rightmost_transposed_ndims` can (and must) be
      specified.
      Default value:
      `tf.range(start=rightmost_transposed_ndims, limit=-1, delta=-1)`.
    rightmost_transposed_ndims: Positive `int32` scalar-shaped `Tensor`
      representing the number of rightmost dimensions to permute.
      Only one of `perm` and `rightmost_transposed_ndims` can (and must) be
      specified.
      Default value: `tf.size(perm)`.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    name: Python `str` name given to ops managed by this object.

  Raises:
    ValueError: if both or neither `perm` and `rightmost_transposed_ndims` are
      specified.
    NotImplementedError: if `rightmost_transposed_ndims` is not known prior to
      graph execution.
  """
  with tf.name_scope(name) as name:
    if (rightmost_transposed_ndims is None) == (perm is None):
      raise ValueError('Must specify exactly one of '
                       '`rightmost_transposed_ndims` and `perm`.')
    if rightmost_transposed_ndims is not None:
      rightmost_transposed_ndims = tf.convert_to_tensor(
          rightmost_transposed_ndims,
          dtype=np.int32,
          name='rightmost_transposed_ndims')
      rightmost_transposed_ndims_ = tf.get_static_value(
          rightmost_transposed_ndims)
      assertions = _maybe_validate_rightmost_transposed_ndims(
          rightmost_transposed_ndims, validate_args)
      if assertions:
        with tf.control_dependencies(assertions):
          rightmost_transposed_ndims = tf.identity(
              rightmost_transposed_ndims)
      perm_start = (
          distribution_util.prefer_static_value(rightmost_transposed_ndims) -
          1)
      perm = tf.range(start=perm_start, limit=-1, delta=-1, name='perm')
    else:  # perm is not None:
      perm = tf.convert_to_tensor(perm, dtype=np.int32, name='perm')
      rightmost_transposed_ndims = tf.size(
          perm, name='rightmost_transposed_ndims')
      rightmost_transposed_ndims_ = tf.get_static_value(
          rightmost_transposed_ndims)
      assertions = _maybe_validate_perm(perm, validate_args)
      if assertions:
        with tf.control_dependencies(assertions):
          perm = tf.identity(perm)

    # TODO(b/110828604): If bijector base class ever supports dynamic
    # `min_event_ndims`, then this class already works dynamically and the
    # following five lines can be removed.
    if rightmost_transposed_ndims_ is None:
      raise NotImplementedError('`rightmost_transposed_ndims` must be '
                                'known prior to graph execution.')
    else:
      rightmost_transposed_ndims_ = int(rightmost_transposed_ndims_)

    self._perm = perm
    self._rightmost_transposed_ndims = rightmost_transposed_ndims
    super(Transpose, self).__init__(
        forward_min_event_ndims=rightmost_transposed_ndims_,
        graph_parents=[perm, rightmost_transposed_ndims],
        is_constant_jacobian=True,
        validate_args=validate_args,
        name=name)
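# Usage sketch for the bijector defined above (via the public
# `tfp.bijectors.Transpose` alias; assumed toy input).
import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors

x = tf.reshape(tf.range(6.), [2, 3])
b = tfb.Transpose(perm=[1, 0])
y = b.forward(x)          # shape [3, 2]
x_back = b.inverse(y)     # recovers the original [2, 3] tensor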
def _forward_event_shape_tensor(self, input_shape):
  perm = self._make_perm(tf.size(input_shape), self.perm)
  return tf.gather(input_shape, perm)
def _mvnormal_quasi(sample_shape,
                    mean,
                    random_type,
                    seed,
                    covariance_matrix=None,
                    scale_matrix=None,
                    validate_args=False,
                    dtype=None,
                    **kwargs):
  """Returns normal draws using low-discrepancy sequences."""
  (mean, scale_matrix, batch_shape, dim, dtype) = _process_mean_scale(
      mean, scale_matrix, covariance_matrix, dtype)
  # Reverse elements of the batch shape
  batch_shape_reverse = tf.reverse(batch_shape, [0])
  # Transposed shape of the output
  output_shape_t = tf.concat([batch_shape_reverse, sample_shape], -1)
  # Number of quasi random samples
  num_samples = tf.reduce_prod(output_shape_t) // dim
  # Number of initial low discrepancy sequence numbers to skip
  if 'skip' in kwargs:
    skip = kwargs['skip']
  else:
    skip = 0
  if random_type == RandomType.SOBOL:
    # TODO(b/182621549): For Sobol sequences, dimension should be known at
    # graph construction time.
    dim = tf.get_static_value(dim)
    if dim is None:
      raise ValueError('For Sobol sequences, dimension should be known at '
                       'graph construction time.')
    # Shape [num_samples, dim] of the Sobol samples
    low_discrepancy_seq = sobol.sample(
        dim=dim,
        num_results=num_samples,
        skip=skip,
        dtype=dtype)
  else:  # HALTON or HALTON_RANDOMIZED random_type
    if 'randomization_params' in kwargs:
      randomization_params = kwargs['randomization_params']
    else:
      randomization_params = None
    randomized = random_type == RandomType.HALTON_RANDOMIZED
    # Shape [num_samples, dim] of the Halton samples
    low_discrepancy_seq, _ = halton.sample(
        dim=dim,
        sequence_indices=tf.range(skip, skip + num_samples),
        randomized=randomized,
        randomization_params=randomization_params,
        seed=seed,
        validate_args=validate_args,
        dtype=dtype)

  # Transpose to the shape [dim, num_samples]
  low_discrepancy_seq = tf.transpose(low_discrepancy_seq)
  size_sample = tf.size(sample_shape)
  size_batch = tf.size(batch_shape)
  # Permutation for `output_shape_t` to the output shape
  permutation = tf.concat([
      tf.range(size_batch, size_batch + size_sample),
      tf.range(size_batch - 1, -1, -1)
  ], -1)
  # Reshape the low-discrepancy samples to the correct output shape
  low_discrepancy_seq = tf.transpose(
      tf.reshape(low_discrepancy_seq, output_shape_t), permutation)
  # Apply inverse Normal CDF to the low-discrepancy samples to obtain the
  # corresponding Normal samples
  samples = tf.math.erfinv((low_discrepancy_seq - 0.5) * 2) * _SQRT_2
  if scale_matrix is None:
    return mean + samples
  else:
    return mean + tf.linalg.matvec(scale_matrix, samples)
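# Standalone check of the uniform-to-normal map used above: the standard
# normal quantile is sqrt(2) * erfinv(2u - 1) (assumed example probabilities;
# `_SQRT_2` in the function plays the role of the explicit constant here).
import numpy as np
import tensorflow as tf

u = tf.constant([0.025, 0.5, 0.975], dtype=tf.float64)
z = tf.math.erfinv(2. * u - 1.) * np.sqrt(2.)
# z ~= [-1.96, 0., 1.96], the familiar two-sided 95% normal quantiles.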
def __init__(self, perm=None, rightmost_transposed_ndims=None,
             validate_args=False, name='transpose'):
  """Instantiates the `Transpose` bijector.

  Args:
    perm: Positive `int32` vector-shaped `Tensor` representing permutation of
      rightmost dims (for forward transformation).  Note that the `0`th index
      represents the first of the rightmost dims and the largest value must be
      `rightmost_transposed_ndims - 1` and corresponds to `tf.rank(x) - 1`.
      Only one of `perm` and `rightmost_transposed_ndims` can (and must) be
      specified.  The number of elements in a permutation must have a value
      that can be determined statically.
      Default value:
      `tf.range(start=rightmost_transposed_ndims, limit=-1, delta=-1)`.
    rightmost_transposed_ndims: Positive `int32` scalar-shaped `Tensor`
      representing the number of rightmost dimensions to permute.
      Only one of `perm` and `rightmost_transposed_ndims` can (and must) be
      specified.  If `rightmost_transposed_ndims` is specified, the rightmost
      dims are reversed.  This argument must have a value that can be
      determined statically.
      Default value: `tf.size(perm)`.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    name: Python `str` name given to ops managed by this object.

  Raises:
    ValueError: if both or neither `perm` and `rightmost_transposed_ndims` are
      specified.
    NotImplementedError: if `rightmost_transposed_ndims` is not known prior to
      graph execution.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    # We need to determine `forward_min_event_ndims` statically, which
    # requires that we know `rightmost_transposed_ndims` statically.
    # So the corresponding assertions go here rather than in
    # `_parameter_control_dependencies`.
    if (rightmost_transposed_ndims is None) == (perm is None):
      raise ValueError('Must specify exactly one of '
                       '`rightmost_transposed_ndims` and `perm`.')
    if rightmost_transposed_ndims is not None:
      rightmost_transposed_ndims = tensor_util.convert_nonref_to_tensor(
          rightmost_transposed_ndims, dtype_hint=np.int32)
      if not dtype_util.is_integer(rightmost_transposed_ndims.dtype):
        raise TypeError('`rightmost_transposed_ndims` must be integer type.')
      rightmost_transposed_ndims_ = tf.get_static_value(
          rightmost_transposed_ndims)
      if rightmost_transposed_ndims_ is None:
        raise NotImplementedError('`rightmost_transposed_ndims` must be '
                                  'known prior to graph execution.')
      msg = '`rightmost_transposed_ndims` must be non-negative.'
      if rightmost_transposed_ndims_ < 0:
        raise ValueError(
            msg[:-1] + ', saw: {}.'.format(rightmost_transposed_ndims_))
      perm_start = (
          distribution_util.prefer_static_value(rightmost_transposed_ndims) -
          1)
      perm = tf.range(start=perm_start, limit=-1, delta=-1, name='perm')
    else:  # perm is not None:
      perm = tensor_util.convert_nonref_to_tensor(
          perm, dtype_hint=np.int32, name='perm')
      rightmost_transposed_ndims = tf.size(
          perm, name='rightmost_transposed_ndims')
      rightmost_transposed_ndims_ = tf.get_static_value(
          rightmost_transposed_ndims)

    # TODO(b/110828604): If bijector base class ever supports dynamic
    # `min_event_ndims`, then this class already works dynamically and the
    # following five lines can be removed.
    if rightmost_transposed_ndims_ is None:
      raise NotImplementedError('`rightmost_transposed_ndims` must be '
                                'known prior to graph execution.')
    else:
      rightmost_transposed_ndims_ = int(rightmost_transposed_ndims_)

    self._perm = perm
    self._rightmost_transposed_ndims = rightmost_transposed_ndims
    self._initial_rightmost_transposed_ndims = rightmost_transposed_ndims_
    super(Transpose, self).__init__(
        forward_min_event_ndims=rightmost_transposed_ndims_,
        is_constant_jacobian=True,
        validate_args=validate_args,
        parameters=parameters,
        name=name)
def grad_fn(*dresults, **kwargs):
  """Adjoint sensitivity method to compute gradients."""
  dresults = tf.nest.pack_sequence_as(results, dresults)
  dstates = dresults.states
  # The signature grad_fn(*dresults, variables=None) is not valid Python 2
  # so use kwargs instead.
  variables = kwargs.pop('variables', [])
  assert not kwargs  # This assert should never fail.
  # TODO(b/138304303): Support complex types.
  with tf.name_scope('{}Gradients'.format(self._name)):
    get_dtype = lambda x: x.dtype

    def error_if_complex(dtype):
      if dtype.is_complex:
        raise NotImplementedError('The adjoint sensitivity method does '
                                  'not support complex dtypes.')

    state_dtypes = tf.nest.map_structure(get_dtype, initial_state)
    tf.nest.map_structure(error_if_complex, state_dtypes)
    common_state_dtype = dtype_util.common_dtype(initial_state)
    real_dtype = dtype_util.real_dtype(common_state_dtype)

    # We add initial_time to ensure that we know where to stop.
    result_times = tf.concat(
        [[tf.cast(initial_time, real_dtype)], results.times], 0)
    num_result_times = tf.size(result_times)

    # The first two components correspond to the reverse and adjoint states;
    # the last component is the adjoint state for variables.
    terminal_augmented_state = tuple([
        rk_util.nest_constant(initial_state, 0.0),
        rk_util.nest_constant(initial_state, 0.0),
        tuple(
            rk_util.nest_constant(variable, 0.0) for variable in variables)
    ])

    # The XLA compiler does not compile code which slices/indexes using
    # integer `Tensor`s. `TensorArray`s are used to get around this.
    result_time_array = tf.TensorArray(
        results.times.dtype,
        clear_after_read=False,
        size=num_result_times,
        element_shape=[]).unstack(result_times)

    # TensorArray shape should not include the time dimension, hence shape[1:].
    result_state_arrays = [
        tf.TensorArray(  # pylint: disable=g-complex-comprehension
            dtype=component.dtype,
            size=num_result_times - 1,
            element_shape=component.shape[1:]).unstack(component)
        for component in tf.nest.flatten(results.states)
    ]
    result_state_arrays = tf.nest.pack_sequence_as(
        results.states, result_state_arrays)
    dresult_state_arrays = [
        tf.TensorArray(  # pylint: disable=g-complex-comprehension
            dtype=component.dtype,
            size=num_result_times - 1,
            element_shape=component.shape[1:]).unstack(component)
        for component in tf.nest.flatten(dstates)
    ]
    dresult_state_arrays = tf.nest.pack_sequence_as(
        results.states, dresult_state_arrays)

    def augmented_ode_fn(backward_time, augmented_state):
      """Dynamics function for the augmented system.

      Describes a differential equation that evolves the augmented state
      backwards in time to compute gradients using the adjoint method.
      The augmented state consists of 3 components `(state, adjoint_state,
      vars)`, all evaluated at time `backward_time`:

      state: represents the solution of the user provided `ode_fn`. The
        structure coincides with the `initial_state`.
      adjoint_state: represents the solution of the adjoint sensitivity
        differential equation as discussed below. Has the same structure
        and shape as `state`.
      vars: represent the solution of the adjoint equation for variable
        gradients. Represented as a `Tuple(Tensor, ...)` with as many
        tensors as there are `variables`.

      The adjoint sensitivity equation describes the gradient of the
      solution with respect to the value of the solution at a previous
      time t. Its dynamics are given by

      d/dt[adj(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), z)

      which is computed as:

      d/dt[adj(t)]_i = -1 * sum_j(adj(t)_j * d/dz_i[ode_fn(t, z)_j])
      d/dt[adj(t)]_i = -1 * d/dz_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]

      where in the last line we moved adj(t)_j under the derivative by
      removing the gradient from it.

      The adjoint equation for the gradient with respect to every
      `tf.Variable` theta follows:

      d/dt[grad_theta(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), theta)
                          = -1 * d/d theta_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]

      Args:
        backward_time: Floating `Tensor` representing the current time.
        augmented_state: `Tuple(state, adjoint_state, variable_grads)`

      Returns:
        negative_derivatives: Structure of `Tensor`s equal to the backwards
          time derivative of the `state` component.
        adjoint_ode: Structure of `Tensor`s equal to the backwards time
          derivative of the `adjoint_state` component.
        adjoint_variables_ode: Structure of `Tensor`s equal to the backwards
          time derivative of the `vars` component.
      """
      # The negative sign disappears after the change of variables.
      # The ODE solver cannot handle the case initial_time > final_time
      # and hence a change of variables backward_time = -time is used.
      time = -backward_time
      state, adjoint_state, _ = augmented_state

      with tf.GradientTape() as tape:
        tape.watch(variables)
        tape.watch(state)
        derivatives = ode_fn(time, state)
        adjoint_no_grad = tf.nest.map_structure(
            tf.stop_gradient, adjoint_state)
        negative_derivatives = rk_util.weighted_sum([-1.0], [derivatives])

        def dot_prod(tensor_a, tensor_b):
          return tf.reduce_sum(tensor_a * tensor_b)

        # See docstring for details.
        adjoint_dot_derivatives = tf.nest.map_structure(
            dot_prod, adjoint_no_grad, derivatives)
        adjoint_dot_derivatives = tf.squeeze(
            tf.add_n(tf.nest.flatten(adjoint_dot_derivatives)))

      adjoint_ode, adjoint_variables_ode = tape.gradient(
          adjoint_dot_derivatives, (state, tuple(variables)),
          unconnected_gradients=tf.UnconnectedGradients.ZERO)
      return negative_derivatives, adjoint_ode, adjoint_variables_ode

    def reverse_to_result_time(n, augmented_state, _):
      """Integrates the augmented system backwards in time."""
      lower_bound_of_integration = result_time_array.read(n)
      upper_bound_of_integration = result_time_array.read(n - 1)
      _, adjoint_state, adjoint_variable_state = augmented_state
      initial_state = _read_solution_components(
          result_state_arrays, input_state_structure, n - 1)
      initial_adjoint = _read_solution_components(
          dresult_state_arrays, input_state_structure, n - 1)
      initial_adjoint_state = rk_util.weighted_sum(
          [1.0, 1.0], [adjoint_state, initial_adjoint])
      initial_augmented_state = (initial_state, initial_adjoint_state,
                                 adjoint_variable_state)
      # TODO(b/143624114).
      augmented_results = self._solve(
          ode_fn=augmented_ode_fn,
          initial_time=-lower_bound_of_integration,
          initial_state=initial_augmented_state,
          solution_times=[-upper_bound_of_integration],
          batch_ndims=batch_ndims)
      # Results added an extra time dim of size 1; squeeze it.
      select_result = lambda x: tf.squeeze(x, [0])
      result_state = augmented_results.states
      result_state = tf.nest.map_structure(select_result, result_state)
      status = augmented_results.diagnostics.status
      return n - 1, result_state, status

    _, augmented_state, _ = tf.while_loop(
        lambda n, _, status: (n >= 1) & tf.equal(status, 0),
        reverse_to_result_time,
        (num_result_times - 1, terminal_augmented_state, 0),
        back_prop=False)
    _, adjoint_state, adjoint_variables = augmented_state
    return adjoint_state, list(adjoint_variables)
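# Standalone sketch (assumed, not the solver's internal API) of the
# vector-Jacobian trick described in the docstring above: rather than forming
# jacobian(ode_fn(t, z), z), differentiate sum_j(no_grad_adj_j * ode_fn(t, z)_j)
# with the adjoint detached via tf.stop_gradient. The dynamics and values are
# made up for illustration.
import tensorflow as tf

def toy_ode_fn(t, z):
  # Arbitrary toy dynamics chosen only for illustration.
  del t
  return tf.stack([z[1], -z[0]])

z = tf.constant([1.0, 0.0])
adj = tf.constant([0.5, -0.25])  # stand-in for the adjoint state
with tf.GradientTape() as tape:
  tape.watch(z)
  dot = tf.reduce_sum(tf.stop_gradient(adj) * toy_ode_fn(0.0, z))
# d/dt[adj(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), z)
adjoint_time_derivative = -tape.gradient(dot, z)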
def _sample_paths(self, times, time_step, num_samples, random_type, skip,
                  seed):
  """Returns a sample of paths from the process."""
  # Note: all the notations below are the same as in [2].
  times, keep_mask = _prepare_grid(times, time_step)
  # Add zeros as a starting location.
  dt = times[1:] - times[:-1]
  if dt.shape.is_fully_defined():
    steps_num = dt.shape.as_list()[-1]
  else:
    steps_num = tf.shape(dt)[-1]

  # In order to use a low-discrepancy random_type we need to generate the
  # sequence of independent random normals upfront. We also precompute random
  # numbers for the stateless random type in order to ensure independent
  # samples for multiple function calls with different seeds.
  if random_type in (random.RandomType.SOBOL,
                     random.RandomType.HALTON,
                     random.RandomType.HALTON_RANDOMIZED,
                     random.RandomType.STATELESS,
                     random.RandomType.STATELESS_ANTITHETIC):
    normal_draws = utils.generate_mc_normal_draws(
        num_normal_draws=self._dim,
        num_time_steps=steps_num,
        num_sample_paths=num_samples,
        random_type=random_type,
        seed=seed,
        dtype=self._dtype,
        skip=skip)
  else:
    normal_draws = None

  cond_fn = lambda i, *args: i < tf.size(dt)

  def body_fn(i, written_count, current_x, current_y, x_paths, y_paths):
    """Simulate qG-HJM process to the next time point."""
    if normal_draws is None:
      normals = random.mv_normal_sample(
          (num_samples,),
          mean=tf.zeros((self._dim,), dtype=self._dtype),
          random_type=random_type,
          seed=seed)
    else:
      normals = normal_draws[i]

    if self._sqrt_rho is not None:
      normals = tf.linalg.matvec(self._sqrt_rho, normals)

    vol = self._volatility(times[i + 1], current_x)

    next_x = (current_x
              + (current_y - self._mean_reversion * current_x) * dt[i]
              + vol * normals * tf.math.sqrt(dt[i]))
    next_y = current_y + (vol**2
                          - 2.0 * self._mean_reversion * current_y) * dt[i]

    # Update `x_paths` and `y_paths`.
    x_paths = utils.maybe_update_along_axis(
        tensor=x_paths,
        do_update=True,
        ind=written_count + 1,
        axis=1,
        new_tensor=tf.expand_dims(next_x, axis=1))
    y_paths = utils.maybe_update_along_axis(
        tensor=y_paths,
        do_update=True,
        ind=written_count + 1,
        axis=1,
        new_tensor=tf.expand_dims(next_y, axis=1))
    written_count += 1
    return (i + 1, written_count, next_x, next_y, x_paths, y_paths)

  x_paths = tf.zeros((num_samples, times.shape.as_list()[0], self._factors),
                     dtype=self._dtype)
  y_paths = tf.zeros((num_samples, times.shape.as_list()[0], self._factors),
                     dtype=self._dtype)

  initial_x = tf.zeros((num_samples, self._factors), dtype=self._dtype)
  initial_y = tf.zeros((num_samples, self._factors), dtype=self._dtype)

  _, _, _, _, x_paths, y_paths = tf.while_loop(
      cond_fn, body_fn, (0, 0, initial_x, initial_y, x_paths, y_paths))

  f_0_t = self._instant_forward_rate_fn(times)  # shape=(num_times,)
  rate_paths = tf.math.reduce_sum(
      x_paths, axis=-1) + f_0_t  # shape=(num_samples, num_times)

  discount_factor_paths = tf.math.exp(-rate_paths[:, :-1] * dt)
  discount_factor_paths = tf.concat(
      [tf.ones((num_samples, 1), dtype=self._dtype), discount_factor_paths],
      axis=1)  # shape=(num_samples, num_times)
  discount_factor_paths = utils.cumprod_using_matvec(discount_factor_paths)

  return (
      tf.boolean_mask(rate_paths, keep_mask, axis=1),
      tf.boolean_mask(discount_factor_paths, keep_mask, axis=1),
      tf.boolean_mask(x_paths, keep_mask, axis=1),
      tf.boolean_mask(y_paths, keep_mask, axis=1)
  )
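# Minimal sketch, assuming a left-point Euler discretisation of the discount
# factor: D(t_k) ~= prod_{i<k} exp(-r(t_i) * dt_i). This mirrors the cumulative
# product performed above, here with `tf.math.cumprod` in place of the
# matvec-based helper; the rates and time grid below are made up for
# illustration.
import tensorflow as tf

rate_paths = tf.constant([[0.010, 0.012, 0.011, 0.013]])  # [num_samples, num_times]
dt = tf.constant([0.5, 0.5, 0.5])                         # time increments
step_factors = tf.concat(
    [tf.ones_like(rate_paths[:, :1]), tf.exp(-rate_paths[:, :-1] * dt)],
    axis=1)
discount_factor_paths = tf.math.cumprod(step_factors, axis=1)  # [num_samples, num_times]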
def _solve(
    self,
    ode_fn,
    initial_time,
    initial_state,
    solution_times,
    jacobian_fn=None,
    jacobian_sparsity=None,
    batch_ndims=None,
    previous_solver_internal_state=None,
):
  # Static assertions.
  del jacobian_fn, jacobian_sparsity  # not used by DormandPrince
  if batch_ndims is not None and batch_ndims != 0:
    raise NotImplementedError('For homogeneous batching use `batch_ndims=0`.')
  solution_times_by_solver = isinstance(solution_times, base.ChosenBySolver)

  with tf.name_scope(self._name):
    # (2) Convert to tensors, determine dtypes.
    get_dtype = lambda x: x.dtype
    error_if_wrong_dtype = functools.partial(
        util.error_if_not_real_or_complex, identifier='initial_state')

    initial_state = tf.nest.map_structure(tf.convert_to_tensor, initial_state)
    tf.nest.map_structure(error_if_wrong_dtype, initial_state)

    state_dtypes = tf.nest.map_structure(get_dtype, initial_state)
    common_state_dtype = dtype_util.common_dtype(initial_state)
    real_dtype = dtype_util.real_dtype(common_state_dtype)

    initial_time = tf.cast(initial_time, real_dtype)
    max_num_steps = self._max_num_steps
    max_ode_fn_evals = self._max_num_steps
    if max_num_steps is not None:
      max_num_steps = tf.convert_to_tensor(max_num_steps, dtype=tf.int32)
      max_ode_fn_evals = max_num_steps * self.ODE_FN_EVALS_PER_STEP
    step_size = tf.convert_to_tensor(self._first_step_size, dtype=real_dtype)
    rtol = tf.convert_to_tensor(tf.cast(self._rtol, real_dtype))
    atol = tf.convert_to_tensor(tf.cast(self._atol, real_dtype))
    safety = tf.convert_to_tensor(self._safety_factor, dtype=real_dtype)
    # Use i(d)factor notation for increasing and decreasing factors.
    ifactor, dfactor = self._max_step_size_factor, self._min_step_size_factor
    ifactor = tf.convert_to_tensor(ifactor, dtype=real_dtype)
    dfactor = tf.convert_to_tensor(dfactor, dtype=real_dtype)

    solver_internal_state = previous_solver_internal_state
    if solver_internal_state is None:
      initial_derivative = ode_fn(initial_time, initial_state)
      initial_derivative = tf.nest.map_structure(tf.convert_to_tensor,
                                                 initial_derivative)
      solver_internal_state = _RungeKuttaSolverInternalState(
          current_state=initial_state,
          current_derivative=initial_derivative,
          last_step_start=initial_time,
          current_time=initial_time,
          step_size=step_size,
          interpolating_coefficients=[initial_state] * self.ORDER)

    num_solution_times = 0
    if solution_times_by_solver:
      final_time = tf.cast(solution_times.final_time, real_dtype)
      times_array = tf.TensorArray(
          real_dtype,
          size=num_solution_times,
          dynamic_size=True,
          element_shape=tf.TensorShape([]))
    else:
      solution_times = tf.cast(solution_times, real_dtype)
      util.error_if_not_vector(solution_times, 'solution_times')
      num_solution_times = tf.size(solution_times)
      times_array = tf.TensorArray(
          real_dtype,
          size=num_solution_times,
          dynamic_size=False,
          element_shape=[]).unstack(solution_times)

    solutions_arrays = [
        tf.TensorArray(dtype=component_dtype,
                       size=num_solution_times,
                       dynamic_size=solution_times_by_solver)
        for component_dtype in tf.nest.flatten(state_dtypes)
    ]
    solutions_arrays = tf.nest.pack_sequence_as(initial_state,
                                                solutions_arrays)

    rk_step = functools.partial(
        self._step,
        max_ode_fn_evals=max_ode_fn_evals,
        ode_fn=ode_fn,
        atol=atol,
        rtol=rtol,
        safety=safety,
        ifactor=ifactor,
        dfactor=dfactor)
    advance_to_solution_time = functools.partial(
        _advance_to_solution_time,
        times_array=times_array,
        step_fn=rk_step,
        validate_args=self._validate_args)

    assert_ops = self._assert_ops(
        ode_fn=ode_fn,
        initial_time=initial_time,
        initial_state=initial_state,
        solution_times=solution_times,
        previous_solver_state=previous_solver_internal_state,
        rtol=rtol,
        atol=atol,
        first_step_size=step_size,
        safety_factor=safety,
        min_step_size_factor=dfactor,
        max_step_size_factor=ifactor,
        max_num_steps=max_num_steps,
        solution_times_by_solver=solution_times_by_solver)
    with tf.control_dependencies(assert_ops):
      ode_evals_by_now = 1 if self._validate_args else 0
      ode_evals_by_now += 1 if previous_solver_internal_state is None else 0
      diagnostics = _DopriDiagnostics(
          num_ode_fn_evaluations=ode_evals_by_now,
          num_jacobian_evaluations=0,
          num_matrix_factorizations=0,
          status=0)

      if solution_times_by_solver:
        r = _dense_solutions_to_final_time(
            final_time=final_time,
            solver_state=solver_internal_state,
            diagnostics=diagnostics,
            step_fn=rk_step,
            ode_fn=ode_fn,
            times_array=times_array,
            solutions_arrays=solutions_arrays,
            validate_args=self._validate_args)
        solver_internal_state, diagnostics, times_array, solutions_arrays = r
      else:

        def iterate_cond(time_id, *_):
          return time_id < num_solution_times

        [_, solver_internal_state, diagnostics, solutions_arrays
        ] = tf.while_loop(iterate_cond, advance_to_solution_time, [
            0, solver_internal_state, diagnostics, solutions_arrays
        ], back_prop=False)

      times = times_array.stack()
      stack_components = lambda x: x.stack()
      states = tf.nest.map_structure(stack_components, solutions_arrays)
      return base.Results(
          times=times,
          states=states,
          diagnostics=diagnostics,
          solver_internal_state=solver_internal_state)
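# Hedged usage sketch of the public entry point that eventually reaches
# `_solve` above (keyword defaults may differ across TFP versions). Solves
# dy/dt = -y/2 and returns a `Results` tuple with `times`, `states` and
# `diagnostics`.
import tensorflow as tf
import tensorflow_probability as tfp

ode_fn = lambda t, y: -0.5 * y
results = tfp.math.ode.DormandPrince().solve(
    ode_fn,
    initial_time=0.0,
    initial_state=tf.constant(1.0),
    solution_times=[0.5, 1.0, 2.0])
# results.states approximates exp(-0.5 * results.times).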
def filter_max_len(x, y, max_len=max_len):
  """Keeps only (x, y) pairs whose token counts are at most `max_len`."""
  return tf.logical_and(tf.size(x) <= max_len, tf.size(y) <= max_len)
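# Hedged sketch of how such a predicate is typically applied to a `tf.data`
# pipeline of (source, target) token sequences; the toy pairs are made up for
# illustration, and `max_len` is passed explicitly rather than relying on the
# default bound above.
import tensorflow as tf

pairs = [([1, 2, 3], [4, 5]), (list(range(60)), [6, 7, 8])]
dataset = tf.data.Dataset.from_generator(
    lambda: iter(pairs),
    output_signature=(tf.TensorSpec([None], tf.int32),
                      tf.TensorSpec([None], tf.int32)))
dataset = dataset.filter(
    lambda x, y: filter_max_len(x, y, max_len=40))  # drops the 60-token pair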
def _binary_crossover(population,
                      population_size,
                      mutants,
                      crossover_prob,
                      seed):
  """Performs recombination by binary crossover for the current population.

  Let v_i denote the i'th component of the member v and m_i the corresponding
  component of the mutant vector for v. Then the i'th component of the
  crossed-over vector w is determined by setting
  w_i = (m_i with probability crossover_prob, else v_i). In addition, DE
  requires that at least one of the components is crossed over (otherwise we
  end up with no change). This is done by choosing an index, say k, at random
  at which a forced crossover is performed (i.e. w_k = m_k). This is the
  scheme implemented in this function.

  Args:
    population: A Python list of `Tensor`s where each `Tensor` in the list
      must be of rank at least 1 and all the elements must have a common
      first dimension. The base population to cross over.
    population_size: A scalar integer `Tensor`. The number of elements in the
      population (i.e. size of the first dimension of any member of
      `population`).
    mutants: A Python list of `Tensor`s with the same structure as
      `population`. The mutated population.
    crossover_prob: A positive real scalar `Tensor` bounded above by 1.0. The
      probability of a crossover being performed for each axis.
    seed: `int` or None. The random seed for this `Op`. If `None`, no seed is
      applied.

  Returns:
    A list of `Tensor`s of the same structure, dtype and shape as
    `population`. The recombined population.
  """
  sizes = [tf.cast(tf.size(x), dtype=tf.float64) for x in population]
  seed_stream = tfp_util.SeedStream(seed, salt='binary_crossover')
  force_crossover_group = distributions.Categorical(sizes).sample(
      [population_size, 1], seed=seed_stream())
  recombinants = []
  for i, population_part in enumerate(population):
    pop_part_flat = tf.reshape(population_part, [population_size, -1])
    mutant_part_flat = tf.reshape(mutants[i], [population_size, -1])
    part_size = tf.size(population_part) // population_size
    force_crossovers = tf.one_hot(
        tf.random.uniform([population_size],
                          minval=0,
                          maxval=part_size,
                          dtype=tf.int32,
                          seed=seed_stream()),
        part_size,
        on_value=True,
        off_value=False,
        dtype=tf.bool)  # Tensor of shape [population_size, size]
    group_mask = tf.math.equal(force_crossover_group, i)
    force_crossovers &= group_mask
    do_binary_crossover = tf.random.uniform(
        [population_size, part_size],
        dtype=crossover_prob.dtype.base_dtype,
        seed=seed_stream()) < crossover_prob
    do_binary_crossover |= force_crossovers
    recombinant_flat = tf1.where(do_binary_crossover, mutant_part_flat,
                                 pop_part_flat)
    recombinant = tf.reshape(recombinant_flat, tf.shape(population_part))
    recombinants.append(recombinant)
  return recombinants
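# Toy sketch of the per-member selection rule described in the docstring above
# (standalone, not calling the helper): each coordinate takes the mutant value
# with probability `crossover_prob`, and one randomly chosen coordinate is
# always forced to come from the mutant. Values are illustrative only.
import tensorflow as tf

member = tf.constant([1.0, 2.0, 3.0, 4.0])
mutant = tf.constant([9.0, 8.0, 7.0, 6.0])
crossover_prob = 0.3
forced = tf.one_hot(
    tf.random.uniform([], maxval=4, dtype=tf.int32), 4,
    on_value=True, off_value=False)
take_mutant = (tf.random.uniform([4]) < crossover_prob) | forced
recombinant = tf.where(take_mutant, mutant, member)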
def _axis_size(x, axis=None):
  """Get number of elements of `x` in `axis`, as type `x.dtype`."""
  if axis is None:
    return tf.cast(tf.size(x), x.dtype)
  return tf.cast(tf.reduce_prod(tf.gather(tf.shape(x), axis)), x.dtype)
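# Hedged examples of the helper's behaviour, assuming it is in scope as above;
# the input tensor is made up for illustration.
import tensorflow as tf

x = tf.zeros([2, 3, 4])
_axis_size(x)               # -> 24.0, total number of elements
_axis_size(x, axis=[1])     # -> 3.0
_axis_size(x, axis=[0, 2])  # -> 8.0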
def _sample_control_dependencies(self, x):
  """Helper which validates the sample arg, e.g., input to `log_prob`."""
  x_ndims = (
      tf.rank(x) if tensorshape_util.rank(x.shape) is None else
      tensorshape_util.rank(x.shape))
  event_ndims = (
      tf.size(self.event_shape_tensor())
      if tensorshape_util.rank(self.event_shape) is None else
      tensorshape_util.rank(self.event_shape))
  batch_ndims = (
      tf.size(self._batch_shape_unexpanded)
      if tensorshape_util.rank(self.batch_shape) is None else
      tensorshape_util.rank(self.batch_shape))
  expected_batch_event_ndims = batch_ndims + event_ndims

  if (isinstance(x_ndims, int) and
      isinstance(expected_batch_event_ndims, int)):
    if x_ndims < expected_batch_event_ndims:
      raise NotImplementedError(
          'Broadcasting is not supported; too few batch and event dims '
          '(expected at least {}, saw {}).'.format(
              expected_batch_event_ndims, x_ndims))
    ndims_assertion = []
  elif self.validate_args:
    ndims_assertion = [
        assert_util.assert_greater_equal(
            x_ndims,
            expected_batch_event_ndims,
            message=('Broadcasting is not supported; too few '
                     'batch and event dims.'),
            name='assert_batch_and_event_ndims_large_enough'),
    ]

  if (tensorshape_util.is_fully_defined(self.batch_shape) and
      tensorshape_util.is_fully_defined(self.event_shape)):
    expected_batch_event_shape = np.int32(
        tensorshape_util.concatenate(self.batch_shape, self.event_shape))
  else:
    expected_batch_event_shape = tf.concat([
        self.batch_shape_tensor(),
        self.event_shape_tensor(),
    ], axis=0)

  sample_ndims = x_ndims - expected_batch_event_ndims
  if isinstance(sample_ndims, int):
    sample_ndims = max(sample_ndims, 0)
  if (isinstance(sample_ndims, int) and
      tensorshape_util.is_fully_defined(x.shape[sample_ndims:])):
    actual_batch_event_shape = np.int32(x.shape[sample_ndims:])
  else:
    sample_ndims = tf.maximum(sample_ndims, 0)
    actual_batch_event_shape = tf.shape(x)[sample_ndims:]

  assertions = []
  if (isinstance(expected_batch_event_shape, np.ndarray) and
      isinstance(actual_batch_event_shape, np.ndarray)):
    if any(expected_batch_event_shape != actual_batch_event_shape):
      raise NotImplementedError('Broadcasting is not supported; '
                                'unexpected batch and event shape '
                                '(expected {}, saw {}).'.format(
                                    expected_batch_event_shape,
                                    actual_batch_event_shape))
    # We need to set the final runtime assertions to `ndims_assertion` since
    # it's possible this assertion was created. We could add a condition to
    # only do so if `self.validate_args == True`, however this is redundant
    # as `ndims_assertion` already encodes this information.
    assertions.extend(ndims_assertion)
  elif self.validate_args:
    # We need to make the `ndims_assertion` a control dep because otherwise
    # TF itself might raise an exception owing to this assertion being
    # ill-defined, i.e., one cannot even compare Tensors of different rank.
    with tf.control_dependencies(ndims_assertion):
      shape_assertion = assert_util.assert_equal(
          expected_batch_event_shape,
          actual_batch_event_shape,
          message=('Broadcasting is not supported; '
                   'unexpected batch and event shape.'),
          name='assert_batch_and_event_shape_same')
    assertions.append(shape_assertion)

  return assertions
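# Hedged sketch of how such an assertion list is typically consumed: bind the
# returned ops as control dependencies so validation runs before the main
# computation. `dist`, `x` and `unchecked_log_prob` are hypothetical
# stand-ins, not names from the source above.
import tensorflow as tf

def checked_log_prob(dist, x, unchecked_log_prob):
  assertions = dist._sample_control_dependencies(x)  # pylint: disable=protected-access
  with tf.control_dependencies(assertions):
    return tf.identity(unchecked_log_prob(x))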