def _inverse_log_det_jacobian(self, y):
  # If event_ndims = 2,
  # F^{-1}(y) = (-y, y), so DF^{-1}(y) = (-1, 1),
  # so Log|DF^{-1}(y)| = Log[1, 1] = [0, 0].
  with tf.control_dependencies(self._assertions(y)):
    zero = tf.zeros([], dtype=dtype_util.base_dtype(y.dtype))
    return zero, zero
def _extend_support(self, x, scale, f, alt):
  """Returns `f(x)` if x is in the support, and `alt` otherwise.

  Given `f` which is defined on the support of this distribution
  (e.g. x > scale), extend the function definition to the real
  line by defining `f(x) = alt` for `x < scale`.

  Args:
    x: Floating-point Tensor to evaluate `f` at.
    scale: Floating-point Tensor used to verify the validity of `x`.
    f: Lambda that takes in a tensor and returns a tensor. This represents
      the function whose domain of definition we want to extend.
    alt: Python or numpy literal representing the value to use for extending
      the domain.

  Returns:
    Tensor representing an extension of `f(x)`.
  """
  if self.validate_args:
    return f(x)
  scale = tf.convert_to_tensor(self.scale) if scale is None else scale
  is_invalid = x < scale
  # We need to do this to ensure gradients are sound.
  y = f(tf.where(is_invalid, scale, x))
  if alt == 0.:
    alt = tf.zeros([], dtype=y.dtype)
  elif alt == 1.:
    alt = tf.ones([], dtype=y.dtype)
  else:
    alt = dtype_util.as_numpy_dtype(self.dtype)(alt)
  return tf.where(is_invalid, alt, y)
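# The clamp-before-`f` pattern above is the "double-where" trick: evaluating
# `f` only at in-support points keeps NaN/Inf terms out of the gradient of the
# discarded branch. A minimal standalone sketch, using `tf.math.log` as a
# stand-in for `f` and `scale = 1.` (both illustrative choices, not the
# distribution's actual parameters):
import tensorflow as tf

x = tf.constant([0., 2.])   # 0. lies outside the assumed support x >= 1.
scale = tf.constant(1.)

def extended_log(x):
  is_invalid = x < scale
  # Clamp invalid inputs to `scale` *before* calling the function, so the
  # branch that the outer tf.where later discards never contributes NaN/Inf
  # terms to the gradient.
  y = tf.math.log(tf.where(is_invalid, scale, x))
  return tf.where(is_invalid, tf.zeros([], y.dtype), y)

with tf.GradientTape() as tape:
  tape.watch(x)
  y = extended_log(x)
print(tape.gradient(y, x))  # [0., 0.5]; without the clamp, [nan, 0.5].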
def _mode(self):
  scale_x_zeros = self.bijector.scale.matvec(
      tf.zeros(self._mode_mean_shape(), self.dtype))

  if self.loc is None:
    return scale_x_zeros

  return tf.identity(self.loc) + scale_x_zeros
def _assert_valid_sample(self, x):
  if not self.validate_args:
    return x
  return distribution_util.with_dependencies([
      assert_util.assert_non_positive(x),
      assert_util.assert_near(
          tf.zeros([], dtype=self.dtype),
          tf.reduce_logsumexp(x, axis=[-1])),
  ], x)
def _rotate(self, samples):
  """Applies a Householder rotation to `samples`."""
  event_dim = (
      tf.compat.dimension_value(self.event_shape[0]) or
      self._event_shape_tensor()[0])
  basis = tf.concat([[1.], tf.zeros([event_dim - 1], dtype=self.dtype)],
                    axis=0)
  u = tf.math.l2_normalize(basis - self.mean_direction, axis=-1)
  return samples - 2 * tf.reduce_sum(samples * u, axis=-1, keepdims=True) * u
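# The reflection above is the standard Householder construction: with
# u proportional to (e1 - mu) and of unit norm, the map x -> x - 2<x, u>u
# reflects the basis vector e1 onto the (unit-norm) mean direction mu, and is
# its own inverse. A small NumPy sketch with an assumed mean direction:
import numpy as np

mu = np.array([0.6, 0.8, 0.0])   # assumed unit-norm mean direction
e1 = np.array([1.0, 0.0, 0.0])   # mode of the standardized samples

u = e1 - mu
u /= np.linalg.norm(u)

def rotate(x):
  # Householder reflection: x - 2 <x, u> u.
  return x - 2.0 * np.dot(x, u) * u

print(rotate(e1))          # [0.6, 0.8, 0.]: e1 maps onto mu.
print(rotate(rotate(e1)))  # [1., 0., 0.]: the reflection is an involution.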
def body(m, pchol, perm, matrix_diag):
  """Body of a single `tf.while_loop` iteration."""
  # Here is roughly a numpy, non-batched version of what's going to happen.
  # (See also Algorithm 1 of Harbrecht et al.)
  # 1: maxi = np.argmax(matrix_diag[perm[m:]]) + m
  # 2: maxval = matrix_diag[perm][maxi]
  # 3: perm[m], perm[maxi] = perm[maxi], perm[m]
  # 4: row = matrix[perm[m]][perm[m + 1:]]
  # 5: row -= np.sum(pchol[:m][perm[m + 1:]] * pchol[:m][perm[m]], axis=-2)
  # 6: pivot = np.sqrt(maxval); row /= pivot
  # 7: row = np.concatenate([[[pivot]], row], -1)
  # 8: matrix_diag[perm[m:]] -= row**2
  # 9: pchol[m, perm[m:]] = row

  # Find the maximal position of the (remaining) permuted diagonal.
  # Steps 1, 2 above.
  permuted_diag = batch_gather(matrix_diag, perm[..., m:])
  maxi = tf.argmax(
      permuted_diag, axis=-1, output_type=tf.int64)[..., tf.newaxis]
  maxval = batch_gather(permuted_diag, maxi)
  maxi = maxi + m
  maxval = maxval[..., 0]
  # Update perm: Swap perm[...,m] with perm[...,maxi]. Step 3 above.
  perm = _swap_m_with_i(perm, m, maxi)
  # Step 4.
  row = batch_gather(matrix, perm[..., m:m + 1], axis=-2)
  row = batch_gather(row, perm[..., m + 1:])
  # Step 5.
  prev_rows = pchol[..., :m, :]
  prev_rows_perm_m_onward = batch_gather(prev_rows, perm[..., m + 1:])
  prev_rows_pivot_col = batch_gather(prev_rows, perm[..., m:m + 1])
  row -= tf.reduce_sum(
      prev_rows_perm_m_onward * prev_rows_pivot_col,
      axis=-2)[..., tf.newaxis, :]
  # Step 6.
  pivot = tf.sqrt(maxval)[..., tf.newaxis, tf.newaxis]
  # Step 7.
  row = tf.concat([pivot, row / pivot], axis=-1)
  # TODO(b/130899118): Pad grad fails with int64 paddings.
  # Step 8.
  paddings = tf.concat([
      tf.zeros([prefer_static.rank(pchol) - 1, 2], dtype=tf.int32),
      [[tf.cast(m, tf.int32), 0]]
  ], axis=0)
  diag_update = tf.pad(row**2, paddings=paddings)[..., 0, :]
  reverse_perm = _invert_permutation(perm)
  matrix_diag -= batch_gather(diag_update, reverse_perm)
  # Step 9.
  row = tf.pad(row, paddings=paddings)
  # TODO(bjp): Defer the reverse permutation all-at-once at the end?
  row = batch_gather(row, reverse_perm)
  pchol_shape = pchol.shape
  pchol = tf.concat([pchol[..., :m, :], row, pchol[..., m + 1:, :]],
                    axis=-2)
  tensorshape_util.set_shape(pchol, pchol_shape)
  return m + 1, pchol, perm, matrix_diag
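# A non-batched NumPy version of the complete loop sketched in the pseudocode
# above (Algorithm 1 of Harbrecht et al.). The helper name and the test matrix
# are illustrative only; the TF body above is the batched equivalent of one
# iteration of this `for` loop.
import numpy as np

def pivoted_cholesky_np(matrix, rank):
  """Non-batched partial pivoted Cholesky; matrix ~= pchol.T @ pchol."""
  n = matrix.shape[0]
  diag = np.diag(matrix).astype(float).copy()   # residual diagonal
  perm = np.arange(n)
  pchol = np.zeros((rank, n))
  for m in range(rank):
    # Steps 1-3: pivot on the largest remaining residual diagonal entry.
    maxi = np.argmax(diag[perm[m:]]) + m
    perm[[m, maxi]] = perm[[maxi, m]]
    pivot = np.sqrt(diag[perm[m]])
    # Steps 4-6: form the new row and divide by the pivot.
    row = matrix[perm[m], perm[m + 1:]]
    row = row - np.sum(
        pchol[:m, perm[m + 1:]] * pchol[:m, perm[m]][:, None], axis=0)
    pchol[m, perm[m]] = pivot
    pchol[m, perm[m + 1:]] = row / pivot
    # Step 8: downdate the residual diagonal.
    diag[perm[m:]] -= pchol[m, perm[m:]]**2
  return pchol

a = np.array([[4., 2.], [2., 3.]])
pchol = pivoted_cholesky_np(a, rank=2)
print(np.allclose(pchol.T @ pchol, a))  # True: exact at full rank.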
def _mean(self):
  shape = tensorshape_util.concatenate(self.batch_shape, self.event_shape)
  has_static_shape = tensorshape_util.is_fully_defined(shape)
  if not has_static_shape:
    shape = tf.concat([
        self.batch_shape_tensor(),
        self.event_shape_tensor(),
    ], 0)

  if self.loc is None:
    return tf.zeros(shape, self.dtype)

  if has_static_shape and shape == self.loc.shape:
    return tf.identity(self.loc)

  # Add dummy tensor of zeros to broadcast. This is only necessary if shape
  # != self.loc.shape, but we could not determine if this is the case.
  return tf.identity(self.loc) + tf.zeros(shape, self.dtype)
def _cdf(self, x):
  low = tf.convert_to_tensor(self.low)
  high = tf.convert_to_tensor(self.high)
  broadcast_shape = tf.broadcast_dynamic_shape(
      tf.shape(x), self._batch_shape_tensor(low=low, high=high))
  zeros = tf.zeros(broadcast_shape, dtype=self.dtype)
  ones = tf.ones(broadcast_shape, dtype=self.dtype)
  result_if_not_big = tf.where(
      x < low, zeros, (x - low) / self._range(low=low, high=high))
  return tf.where(x >= high, ones, result_if_not_big)
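# The two nested `tf.where`s above implement the piecewise uniform CDF:
# 0 below `low`, (x - low) / (high - low) in between, and 1 above `high`.
# A tiny NumPy check with assumed bounds:
import numpy as np

low, high = 1.0, 3.0
x = np.array([0.0, 2.0, 5.0])
cdf = np.where(x >= high, 1.0,
               np.where(x < low, 0.0, (x - low) / (high - low)))
print(cdf)  # [0., 0.5, 1.]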
def _forward_log_det_jacobian(self, x):
  if self.log_scale is not None:
    return self.log_scale
  elif self.scale is not None:
    return tf.math.log(tf.abs(self.scale))
  else:
    # is_constant_jacobian = True for this bijector, hence the
    # `log_det_jacobian` need only be specified for a single input, as this
    # will be tiled to match `event_ndims`.
    return tf.zeros([], dtype=x.dtype)
def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []
  assertions = []
  if (self.scale is not None and
      is_init != tensor_util.is_ref(self.scale)):
    assertions.append(
        assert_util.assert_none_equal(
            self.scale,
            tf.zeros([], dtype=self._scale.dtype),
            message='Argument `scale` must be non-zero.'))
  return assertions
def _reduce_multiple_steps():
  """Perform `reduce_max` operation when `num_steps` > 1."""

  def forward_step(previous_step_pair, log_prob_observation):
    log_prob_previous = previous_step_pair[0]
    log_prob = (
        log_prob_previous[..., tf.newaxis] +
        self._log_trans +
        log_prob_observation[..., tf.newaxis, :])
    most_likely_given_successor = tf.argmax(log_prob, axis=-2)
    max_log_p_given_successor = tf.reduce_max(
        input_tensor=log_prob, axis=-2)
    return (max_log_p_given_successor, most_likely_given_successor)

  forward_log_probs, all_most_likely_given_successor = tf.scan(
      forward_step,
      observation_log_probs[1:],
      initializer=(log_prob,
                   tf.zeros(tf.shape(log_prob), dtype=tf.int64)),
      name="forward_log_probs")

  most_likely_end = tf.argmax(forward_log_probs[-1], axis=-1)

  # We require the operation that gives C from A and B where
  #   C[i...j] = A[i...j, B[i...j]]
  # and A = most_likely_given_successor
  #     B = most_likely_successor.
  # tf.gather requires indices of known shape, so instead we use a reduction
  # with tf.one_hot(B) to pick out the B-indexed elements of A.
  def backward_step(most_likely_successor, most_likely_given_successor):
    return tf.reduce_sum(
        input_tensor=(most_likely_given_successor *
                      tf.one_hot(most_likely_successor,
                                 self._num_states,
                                 dtype=tf.int64)),
        axis=-1)

  backward_scan = tf.scan(
      backward_step,
      all_most_likely_given_successor,
      most_likely_end,
      reverse=True)
  most_likely_sequences = tf.concat([backward_scan, [most_likely_end]],
                                    axis=0)
  return distribution_util.move_dimension(most_likely_sequences, 0, -1)
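# The one_hot-and-reduce_sum trick used in `backward_step` is a batched gather
# C[i] = A[i, B[i]] that avoids tf.gather's need for statically known index
# shapes. An assumed toy example:
import tensorflow as tf

a = tf.constant([[10, 11, 12],
                 [20, 21, 22]], dtype=tf.int64)   # A: per-batch rows
b = tf.constant([2, 0], dtype=tf.int64)           # B: one index per row

# C[i] = A[i, B[i]] via multiplication with a one-hot mask.
c = tf.reduce_sum(a * tf.one_hot(b, depth=3, dtype=tf.int64), axis=-1)
print(c)  # [12, 20], i.e. the same result as tf.gather(a, b, batch_dims=1).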
def _compute_quantiles():
  """Helper to build quantiles."""
  # Omit {0, 1} since they might lead to Inf/NaN.
  zero = tf.zeros([], dtype=dist.dtype)
  edges = tf.linspace(zero, 1., quadrature_size + 3)[1:-1]
  # Expand edges so it broadcasts across batch dims.
  edges = tf.reshape(
      edges,
      shape=tf.concat([[-1], tf.ones([batch_ndims], dtype=tf.int32)],
                      axis=0))
  quantiles = dist.quantile(edges)
  # Cyclically permute left by one.
  perm = tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0)
  quantiles = tf.transpose(a=quantiles, perm=perm)
  return quantiles
def _inverse(self, y):
  ndims = prefer_static.rank(y)
  shifted_y = tf.pad(
      tf.slice(
          y,
          tf.zeros(ndims, dtype=tf.int32),
          prefer_static.shape(y) -
          tf.one_hot(ndims + self.axis, ndims, dtype=tf.int32)
      ),  # Remove the last entry of y in the chosen dimension.
      paddings=tf.one_hot(
          tf.one_hot(ndims + self.axis, ndims, on_value=0, off_value=-1),
          2,
          dtype=tf.int32
      )  # Insert zeros at the beginning of the chosen dimension.
  )
  return y - shifted_y
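# This inverse is plain differencing: pad a zero at the front of the chosen
# axis, drop the last entry, and subtract, which recovers the increments of a
# cumulative sum. A NumPy sketch along the last axis (assuming the forward
# transform is `cumsum`):
import numpy as np

x = np.array([1., 2., 3., 4.])
y = np.cumsum(x)                            # forward: [1., 3., 6., 10.]

shifted_y = np.concatenate([[0.], y[:-1]])  # zero-pad front, drop last entry
print(y - shifted_y)                        # [1., 2., 3., 4.] == x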
def _scan_multiple_steps(): """Perform `scan` operation when `num_steps` > 1.""" transition_log_probs = self._log_trans def forward_step(log_probs, _): return _log_vector_matrix(log_probs, transition_log_probs) dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32) forward_log_probs = tf.scan(forward_step, dummy_index, initializer=initial_log_probs, name="forward_log_probs") return tf.concat([[initial_log_probs], forward_log_probs], axis=0)
def _scan_multiple_steps(): """Take multiple steps with tf.scan.""" dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32) if seed is not None: # Force parallel_iterations to 1 to ensure reproducibility # b/139210489 hidden_states = tf.scan(generate_step, dummy_index, initializer=init_state, parallel_iterations=1) else: # Invoke default parallel_iterations behavior hidden_states = tf.scan(generate_step, dummy_index, initializer=init_state) # TODO(b/115618503): add/use prepend_initializer to tf.scan return tf.concat([[init_state], hidden_states], axis=0)
def _inverse(self, y):
  n = prefer_static.shape(y)[-1]
  batch_shape = prefer_static.shape(y)[:-2]

  # Extract the reciprocal of the row norms from the diagonal.
  diag = tf.linalg.diag_part(y)[..., tf.newaxis]

  # Set the diagonal to 0s.
  y = tf.linalg.set_diag(
      y, tf.zeros(tf.concat([batch_shape, [n]], axis=-1), dtype=y.dtype))

  # Multiply with the norm (or divide by its reciprocal) to recover the
  # unconstrained reals in the (strictly) lower triangular part.
  x = y / diag

  # Remove the first row and last column before inverting the FillTriangular
  # transformation.
  return fill_triangular.FillTriangular().inverse(x[..., 1:, :-1])
def __getitem__(self, slices):
  # Because slicing is parameterization-dependent, we only implement slicing
  # for instances of TD, not subclasses thereof.
  if type(self) is not TransformedDistribution:  # pylint: disable=unidiomatic-typecheck
    return super(TransformedDistribution, self).__getitem__(slices)

  if tensorshape_util.rank(self.distribution.batch_shape) is None:
    raise NotImplementedError(
        "Slicing TransformedDistribution with underlying distribution of "
        "unknown rank is not yet implemented")
  overrides = {}
  if (tensorshape_util.rank(self.distribution.batch_shape) == 0 and
      self.parameters.get("batch_shape", None) is not None):
    overrides["batch_shape"] = tf.shape(
        tf.zeros(self.parameters["batch_shape"])[slices])
  elif self.parameters.get("distribution", None) is not None:
    overrides["distribution"] = self.distribution[slices]
  return self.copy(**overrides)
def __init__(self,
             scale,
             validate_args=False,
             allow_nan_stats=True,
             name='Horseshoe'):
  """Construct a Horseshoe distribution with `scale`.

  Args:
    scale: Floating point tensor; the scales of the distribution(s).
      Must contain only positive values.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs. Default value: `False` (i.e., do not validate args).
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
      Default value: `True`.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: 'Horseshoe'.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([scale], dtype_hint=tf.float32)
    self._scale = tensor_util.convert_nonref_to_tensor(
        scale, name='scale', dtype=dtype)
    self._half_cauchy = half_cauchy.HalfCauchy(
        loc=tf.zeros([], dtype=dtype),
        scale=tf.ones([], dtype=dtype),
        allow_nan_stats=True)
    super(Horseshoe, self).__init__(
        dtype=dtype,
        reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name)
def _forward(self, x):
  x = tf.convert_to_tensor(x, name='x')
  batch_shape = prefer_static.shape(x)[:-1]

  # Pad zeros on the top row and right column.
  y = fill_triangular.FillTriangular().forward(x)
  rank = prefer_static.rank(y)
  paddings = tf.concat(
      [tf.zeros(shape=(rank - 2, 2), dtype=tf.int32),
       tf.constant([[1, 0], [0, 1]], dtype=tf.int32)],
      axis=0)
  y = tf.pad(y, paddings)

  # Set diagonal to 1s.
  n = prefer_static.shape(y)[-1]
  diag = tf.ones(tf.concat([batch_shape, [n]], axis=-1), dtype=x.dtype)
  y = tf.linalg.set_diag(y, diag)

  # Normalize each row to have Euclidean (L2) norm 1.
  y /= tf.norm(y, axis=-1)[..., tf.newaxis]
  return y
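# A non-batched NumPy sketch of the same construction: fill the strictly lower
# triangle with unconstrained reals, put 1s on the diagonal, and normalize each
# row. The rows then have unit norm, so y @ y.T has a unit diagonal, i.e. it is
# a valid correlation matrix. (np.tril_indices is used here for simplicity; the
# packing order of FillTriangular differs, but the property shown is the same.)
import numpy as np

x = np.array([0.5, -1.0, 2.0])           # n*(n-1)/2 unconstrained reals, n = 3

y = np.eye(3)
y[np.tril_indices(3, k=-1)] = x          # strictly lower triangle
y /= np.linalg.norm(y, axis=-1, keepdims=True)   # unit-norm rows

corr = y @ y.T
print(np.allclose(np.diag(corr), 1.))    # True: unit diagonal.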
def _assertions(self, x):
  if not self.validate_args:
    return []
  shape = tf.shape(x)
  is_matrix = assert_util.assert_rank_at_least(
      x, 2, message="Input must have rank at least 2.")
  is_square = assert_util.assert_equal(
      shape[-2], shape[-1], message="Input must be a square matrix.")
  above_diagonal = tf.linalg.band_part(
      tf.linalg.set_diag(x, tf.zeros(shape[:-1], dtype=x.dtype)),
      0, -1)
  is_lower_triangular = assert_util.assert_equal(
      above_diagonal, tf.zeros_like(above_diagonal),
      message="Input must be lower triangular.")
  # A lower triangular matrix is nonsingular iff all its diagonal entries are
  # nonzero.
  diag_part = tf.linalg.diag_part(x)
  is_nonsingular = assert_util.assert_none_equal(
      diag_part, tf.zeros_like(diag_part),
      message="Input must have all diagonal entries nonzero.")
  return [is_matrix, is_square, is_lower_triangular, is_nonsingular]
def _sample_n(self, n, seed=None):
  seed = SeedStream(seed, salt='vom_mises_fisher')
  # The sampling strategy relies on the fact that vMF variates are symmetric
  # about the mean direction. Accordingly, if we have a sampling strategy for
  # the away-from-mean angle, then we can uniformly sample the remaining
  # dimensions on the S^{dim-2} sphere, and rotate these samples from a
  # (1, 0, 0, ..., 0)-mode distribution into the target orientation.
  #
  # This is easy to imagine on the 1-sphere (S^1; in 2-D space): sample a
  # von-Mises distributed `x` value in [-1, 1], then uniformly select what
  # amounts to an "up" or "down" additional degree of freedom after unit
  # normalizing, followed by a final rotation to the desired mean direction
  # from a basis of (1, 0).
  #
  # On S^2 (in 3-D), selecting a vMF `x` identifies a circle in `yz` on the
  # unit sphere over which the distribution is uniform, in particular the
  # circle where x = \hat{x} intersects the unit sphere. We pick a point on
  # that circle, then rotate to the desired mean direction from a basis of
  # (1, 0, 0).
  event_dim = (
      tf.compat.dimension_value(self.event_shape[0]) or
      self._event_shape_tensor()[0])

  sample_batch_shape = tf.concat([[n], self._batch_shape_tensor()], axis=0)
  dim = tf.cast(event_dim - 1, self.dtype)
  if event_dim == 3:
    samples_dim0 = self._sample_3d(n, seed=seed)
  else:
    # Wood'94 provides a rejection algorithm to sample the x coordinate.
    # Wood'94 definition of b:
    # b = (-2 * kappa + tf.sqrt(4 * kappa**2 + dim**2)) / dim
    # https://stats.stackexchange.com/questions/156729 suggests:
    b = dim / (2 * self.concentration +
               tf.sqrt(4 * self.concentration**2 + dim**2))
    # TODO(bjp): Integrate any useful numerical tricks from hyperspherical VAE
    #     https://github.com/nicola-decao/s-vae-tf/
    x = (1 - b) / (1 + b)
    c = self.concentration * x + dim * tf.math.log1p(-x**2)
    beta = beta_lib.Beta(dim / 2, dim / 2)

    def cond_fn(w, should_continue):
      del w
      return tf.reduce_any(should_continue)

    def body_fn(w, should_continue):
      z = beta.sample(sample_shape=sample_batch_shape, seed=seed())
      # set_shape needed here because of b/139013403
      z.set_shape(w.shape)
      w = tf.where(should_continue,
                   (1 - (1 + b) * z) / (1 - (1 - b) * z),
                   w)
      w = tf.debugging.check_numerics(w, 'w')
      unif = tf.random.uniform(
          sample_batch_shape, seed=seed(), dtype=self.dtype)
      # set_shape needed here because of b/139013403
      unif.set_shape(w.shape)
      should_continue = tf.logical_and(
          should_continue,
          self.concentration * w + dim * tf.math.log1p(-x * w) - c <
          tf.math.log(unif))
      return w, should_continue

    w = tf.zeros(sample_batch_shape, dtype=self.dtype)
    should_continue = tf.ones(sample_batch_shape, dtype=tf.bool)
    samples_dim0 = tf.while_loop(
        cond=cond_fn, body=body_fn, loop_vars=(w, should_continue))[0]
    samples_dim0 = samples_dim0[..., tf.newaxis]
  if not self._allow_nan_stats:
    # Verify samples are w/in -1, 1, with useful error output tensors (top
    # value rather than all values).
    with tf.control_dependencies([
        assert_util.assert_less_equal(
            samples_dim0,
            dtype_util.as_numpy_dtype(self.dtype)(1.01),
            data=[tf.math.top_k(tf.reshape(samples_dim0, [-1]))[0]]),
        assert_util.assert_greater_equal(
            samples_dim0,
            dtype_util.as_numpy_dtype(self.dtype)(-1.01),
            data=[-tf.math.top_k(tf.reshape(-samples_dim0, [-1]))[0]])
    ]):
      samples_dim0 = tf.identity(samples_dim0)
  samples_otherdims_shape = tf.concat(
      [sample_batch_shape, [event_dim - 1]], axis=0)
  unit_otherdims = tf.math.l2_normalize(
      tf.random.normal(samples_otherdims_shape, seed=seed(),
                       dtype=self.dtype),
      axis=-1)
  samples = tf.concat([
      samples_dim0,  # we must avoid sqrt(1 - (>1)**2)
      tf.sqrt(tf.maximum(1 - samples_dim0**2, 0.)) * unit_otherdims
  ], axis=-1)
  samples = tf.math.l2_normalize(samples, axis=-1)
  if not self._allow_nan_stats:
    samples = tf.debugging.check_numerics(samples, 'samples')
  # Runtime assert that samples are unit length.
  if not self._allow_nan_stats:
    worst, idx = tf.math.top_k(
        tf.reshape(tf.abs(1 - tf.linalg.norm(samples, axis=-1)), [-1]))
    with tf.control_dependencies([
        assert_util.assert_near(
            dtype_util.as_numpy_dtype(self.dtype)(0),
            worst,
            data=[worst, idx,
                  tf.gather(tf.reshape(samples, [-1, event_dim]), idx)],
            atol=1e-4,
            summarize=100)
    ]):
      samples = tf.identity(samples)
  # The samples generated are symmetric around a mode at (1, 0, 0, ...., 0).
  # Now, we move the mode to `self.mean_direction` using a rotation matrix.
  if not self._allow_nan_stats:
    # Assert that the basis vector rotates to the mean direction, as expected.
    basis = tf.cast(tf.concat([[1.], tf.zeros([event_dim - 1])], axis=0),
                    self.dtype)
    with tf.control_dependencies([
        assert_util.assert_less(
            tf.linalg.norm(self._rotate(basis) - self.mean_direction,
                           axis=-1),
            dtype_util.as_numpy_dtype(self.dtype)(1e-5))
    ]):
      return self._rotate(samples)
  return self._rotate(samples)
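# A non-batched NumPy sketch of the Wood '94 rejection sampler used above for
# the coordinate of the sample along the mean direction. The values of `kappa`
# (concentration) and `dim` (event_dim - 1) are assumed for illustration:
import numpy as np

rng = np.random.default_rng(0)
kappa, dim = 2.0, 4.0

b = dim / (2 * kappa + np.sqrt(4 * kappa**2 + dim**2))
x0 = (1 - b) / (1 + b)
c = kappa * x0 + dim * np.log1p(-x0**2)

def sample_w():
  # Propose w via a Beta(dim/2, dim/2) variate; accept with the Wood '94 test.
  while True:
    z = rng.beta(dim / 2, dim / 2)
    w = (1 - (1 + b) * z) / (1 - (1 - b) * z)
    if kappa * w + dim * np.log1p(-x0 * w) - c >= np.log(rng.uniform()):
      return w

print(sample_w())  # a value in [-1, 1], pulled toward +1 as kappa grows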
def _zeros_like(input, dtype=None, name=None):  # pylint: disable=redefined-builtin
  s = _shape(input)
  s_ = tf.get_static_value(s)
  if s_ is not None:
    return np.zeros(s_, _numpy_dtype(dtype or input.dtype))
  return tf.zeros(s, dtype or input.dtype, name)
def _entropy(self):
  return tf.zeros(self.batch_shape_tensor(), dtype=self.dtype)
def _mode(self):
  return tf.zeros(self.batch_shape_tensor(), dtype=self.dtype)
def __init__(self,
             loc=None,
             scale=None,
             validate_args=False,
             allow_nan_stats=True,
             name='MultivariateNormalLinearOperator'):
  """Construct Multivariate Normal distribution on `R^k`.

  The `batch_shape` is the broadcast shape between `loc` and `scale`
  arguments.

  The `event_shape` is given by the last dimension of the matrix implied by
  `scale`. The last dimension of `loc` (if provided) must broadcast with this.

  Recall that `covariance = scale @ scale.T`.

  Additional leading dimensions (if any) will index batches.

  Args:
    loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
      implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
      `b >= 0` and `k` is the event size.
    scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape
      `[B1, ..., Bb, k, k]`.
    validate_args: Python `bool`, default `False`. Whether to validate input
      with asserts. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
    allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading to
      undefined statistics will return NaN for this statistic.
    name: The name to give Ops created by the initializer.

  Raises:
    ValueError: if `scale` is unspecified.
    TypeError: if not `scale.dtype.is_floating`
  """
  parameters = dict(locals())
  if scale is None:
    raise ValueError('Missing required `scale` parameter.')
  if not dtype_util.is_floating(scale.dtype):
    raise TypeError('`scale` parameter must have floating-point dtype.')

  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([loc, scale], dtype_hint=tf.float32)
    # Since expand_dims doesn't preserve constant-ness, we obtain the
    # non-dynamic value if possible.
    loc = tensor_util.convert_nonref_to_tensor(loc, dtype=dtype, name='loc')
    batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
        loc, scale)
    super(MultivariateNormalLinearOperator, self).__init__(
        distribution=normal.Normal(
            loc=tf.zeros([], dtype=dtype),
            scale=tf.ones([], dtype=dtype)),
        bijector=affine_linear_operator_bijector.AffineLinearOperator(
            shift=loc, scale=scale, validate_args=validate_args),
        batch_shape=batch_shape,
        event_shape=event_shape,
        validate_args=validate_args,
        name=name)
    self._parameters = parameters
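# The construction above mirrors the usual affine representation of the
# multivariate normal: if Z ~ N(0, I), then loc + scale @ Z has mean `loc` and
# covariance `scale @ scale.T`. A quick NumPy check with assumed `loc` and
# `scale` values:
import numpy as np

rng = np.random.default_rng(0)
loc = np.array([1.0, -2.0])
scale = np.array([[2.0, 0.0],
                  [0.5, 1.0]])                 # lower-triangular scale factor

z = rng.standard_normal((100_000, 2))
x = loc + z @ scale.T                          # affine push-forward of N(0, I)

print(x.mean(axis=0))              # ~[ 1., -2.]
print(np.cov(x, rowvar=False))     # ~scale @ scale.T = [[4., 1.], [1., 1.25]]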
def _inverse_log_det_jacobian(self, y):
  return tf.zeros([], dtype=y.dtype)
def _forward_log_det_jacobian(self, x):
  return tf.zeros([], dtype=x.dtype)
def _pad(x):
  """Prepends and appends a zero to every vector in a batch of vectors."""
  shape = tf.concat([tf.shape(x)[:-1], [1]], axis=0)
  z = tf.zeros(shape, dtype=x.dtype)
  return tf.concat([z, x, z], axis=-1)
def __init__(self,
             loc,
             scale,
             skewness=None,
             tailweight=None,
             distribution=None,
             validate_args=False,
             allow_nan_stats=True,
             name="SinhArcsinh"):
  """Construct SinhArcsinh distribution on `(-inf, inf)`.

  Arguments `(loc, scale, skewness, tailweight)` must have broadcastable
  shape (indexing batch dimensions). They must all have the same `dtype`.

  Args:
    loc: Floating-point `Tensor`.
    scale: `Tensor` of same `dtype` as `loc`.
    skewness: Skewness parameter. Default is `0.0` (no skew).
    tailweight: Tailweight parameter. Default is `1.0` (unchanged
      tailweight).
    distribution: `tf.Distribution`-like instance. Distribution that is
      transformed to produce this distribution.
      Default is `tfd.Normal(0., 1.)`.
      Must be a scalar-batch, scalar-event distribution. Typically
      `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
      a function of non-trainable parameters. WARNING: If you backprop
      through a `SinhArcsinh` sample and `distribution` is not
      `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
      the gradient will be incorrect!
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
  """
  parameters = dict(locals())

  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([loc, scale, skewness, tailweight],
                                    tf.float32)
    self._loc = tensor_util.convert_nonref_to_tensor(
        loc, name="loc", dtype=dtype)
    self._scale = tensor_util.convert_nonref_to_tensor(
        scale, name="scale", dtype=dtype)
    tailweight = 1. if tailweight is None else tailweight
    has_default_skewness = skewness is None
    skewness = 0. if has_default_skewness else skewness
    self._tailweight = tensor_util.convert_nonref_to_tensor(
        tailweight, name="tailweight", dtype=dtype)
    self._skewness = tensor_util.convert_nonref_to_tensor(
        skewness, name="skewness", dtype=dtype)

    batch_shape = distribution_util.get_broadcast_shape(
        self._loc, self._scale, self._tailweight, self._skewness)

    # Recall, with Z a random variable,
    #   Y := loc + scale * F(Z),
    #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) * C
    #   C := 2 / F_0(2)
    #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
    if distribution is None:
      distribution = normal.Normal(
          loc=tf.zeros([], dtype=dtype),
          scale=tf.ones([], dtype=dtype),
          allow_nan_stats=allow_nan_stats,
          validate_args=validate_args)
    else:
      asserts = distribution_util.maybe_check_scalar_distribution(
          distribution, dtype, validate_args)
      if asserts:
        self._loc = distribution_util.with_dependencies(asserts, self._loc)

    # Make the SAS bijector, 'F'.
    f = sinh_arcsinh_bijector.SinhArcsinh(
        skewness=self._skewness,
        tailweight=self._tailweight,
        validate_args=validate_args)

    # Make the AffineScalar bijector, Z --> loc + scale * Z (2 / F_0(2))
    affine = affine_scalar_bijector.AffineScalar(
        shift=self._loc,
        scale=self._scale,
        validate_args=validate_args)

    bijector = chain_bijector.Chain([affine, f])
    super(SinhArcsinh, self).__init__(
        distribution=distribution,
        bijector=bijector,
        batch_shape=batch_shape,
        validate_args=validate_args,
        name=name)
  self._parameters = parameters
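# The bijector chain implements the sinh-arcsinh transform documented in the
# comment above: Y = loc + scale * sinh((arcsinh(Z) + skewness) * tailweight),
# rescaled by C = 2 / sinh(arcsinh(2) * tailweight). A NumPy sketch with
# assumed parameter values:
import numpy as np

loc, scale = 0.0, 1.0
skewness, tailweight = 0.5, 1.5          # assumed parameters

def sinh_arcsinh(z, skewness, tailweight):
  # F(Z) = C * sinh((arcsinh(Z) + skewness) * tailweight), with C chosen so
  # that the skewness-free transform F_0 maps 2 to 2.
  c = 2.0 / np.sinh(np.arcsinh(2.0) * tailweight)
  return c * np.sinh((np.arcsinh(z) + skewness) * tailweight)

rng = np.random.default_rng(0)
z = rng.standard_normal(100_000)
y = loc + scale * sinh_arcsinh(z, skewness, tailweight)

# Positive skewness shifts mass to the right; tailweight > 1 fattens the tails.
print(np.mean(y > 0.))   # ~0.70, versus 0.5 for the untransformed normal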
def _create_scale_operator(self, identity_multiplier, diag, tril,
                           perturb_diag, perturb_factor, shift,
                           validate_args, dtype):
  """Construct `scale` from various components.

  Args:
    identity_multiplier: floating point rank 0 `Tensor` representing a
      scaling done to the identity matrix.
    diag: Floating-point `Tensor` representing the diagonal matrix. `diag`
      has shape `[N1, N2, ... k]`, which represents a k x k diagonal matrix.
    tril: Floating-point `Tensor` representing the lower triangular matrix.
      `tril` has shape `[N1, N2, ... k, k]`, which represents a k x k lower
      triangular matrix.
    perturb_diag: Floating-point `Tensor` representing the diagonal matrix of
      the low rank update.
    perturb_factor: Floating-point `Tensor` representing factor matrix.
    shift: Floating-point `Tensor` representing `shift` in
      `scale @ X + shift`.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    dtype: `DType` for arg `Tensor` conversions.

  Returns:
    scale. In the case of scaling by a constant, scale is a floating point
    `Tensor`. Otherwise, scale is a `LinearOperator`.

  Raises:
    ValueError: if all of `tril`, `diag` and `identity_multiplier` are
      `None`.
  """
  identity_multiplier = _as_tensor(identity_multiplier,
                                   "identity_multiplier", dtype)
  diag = _as_tensor(diag, "diag", dtype)
  tril = _as_tensor(tril, "tril", dtype)
  perturb_diag = _as_tensor(perturb_diag, "perturb_diag", dtype)
  perturb_factor = _as_tensor(perturb_factor, "perturb_factor", dtype)

  # If possible, use the low rank update to infer the shape of
  # the identity matrix, when scale represents a scaled identity matrix
  # with a low rank update.
  shape_hint = None
  if perturb_factor is not None:
    shape_hint = distribution_util.dimension_size(perturb_factor, axis=-2)

  if self._is_only_identity_multiplier:
    if validate_args:
      return distribution_util.with_dependencies([
          assert_util.assert_none_equal(
              identity_multiplier,
              tf.zeros([], identity_multiplier.dtype),
              ["identity_multiplier should be non-zero."])
      ], identity_multiplier)
    return identity_multiplier

  scale = distribution_util.make_tril_scale(
      loc=shift,
      scale_tril=tril,
      scale_diag=diag,
      scale_identity_multiplier=identity_multiplier,
      validate_args=validate_args,
      assert_positive=False,
      shape_hint=shape_hint)

  if perturb_factor is not None:
    return tf.linalg.LinearOperatorLowRankUpdate(
        scale,
        u=perturb_factor,
        diag_update=perturb_diag,
        is_diag_update_positive=perturb_diag is None,
        is_non_singular=True,  # Implied by is_positive_definite=True.
        is_self_adjoint=True,
        is_positive_definite=True,
        is_square=True)

  return scale