def sparse_or_dense_matmul(sparse_or_dense_a,
                           dense_b,
                           validate_args=False,
                           name=None,
                           **kwargs):
  """Returns (batched) matmul of a SparseTensor (or Tensor) with a Tensor.

  Args:
    sparse_or_dense_a: `SparseTensor` or `Tensor` representing a (batch of)
      matrices.
    dense_b: `Tensor` representing a (batch of) matrices, with the same batch
      shape as `sparse_or_dense_a`. The shape must be compatible with the
      shape of `sparse_or_dense_a` and kwargs.
    validate_args: When `True`, additional assertions might be embedded in the
      graph.
      Default value: `False` (i.e., no graph assertions are added).
    name: Python `str` prefixed to ops created by this function.
      Default value: 'sparse_or_dense_matmul'.
    **kwargs: Keyword arguments to `tf.sparse_tensor_dense_matmul` or
      `tf.matmul`.

  Returns:
    product: A dense (batch of) matrix-shaped Tensor of the same batch shape
      and dtype as `sparse_or_dense_a` and `dense_b`. If `sparse_or_dense_a`
      or `dense_b` is adjointed through `kwargs` then the shape is adjusted
      accordingly.
  """
  with tf.name_scope(name or 'sparse_or_dense_matmul'):
    dense_b = tf.convert_to_tensor(
        dense_b, dtype_hint=tf.float32, name='dense_b')
    if validate_args:
      assert_a_rank_at_least_2 = assert_util.assert_rank_at_least(
          sparse_or_dense_a,
          rank=2,
          message='Input `sparse_or_dense_a` must have at least 2 dimensions.')
      assert_b_rank_at_least_2 = assert_util.assert_rank_at_least(
          dense_b,
          rank=2,
          message='Input `dense_b` must have at least 2 dimensions.')
      with tf.control_dependencies(
          [assert_a_rank_at_least_2, assert_b_rank_at_least_2]):
        sparse_or_dense_a = tf.identity(sparse_or_dense_a)
        dense_b = tf.identity(dense_b)
    if isinstance(sparse_or_dense_a,
                  (tf.SparseTensor, tf1.SparseTensorValue)):
      return _sparse_tensor_dense_matmul(sparse_or_dense_a, dense_b, **kwargs)
    else:
      return tf.matmul(sparse_or_dense_a, dense_b, **kwargs)
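
# Example usage (a minimal, hypothetical sketch; assumes `import tensorflow
# as tf` and that this module's helpers are in scope):
#
#   a = tf.sparse.from_dense([[1., 0.], [0., 2.]])
#   b = tf.constant([[3.], [4.]])
#   sparse_or_dense_matmul(a, b)                      # ==> [[3.], [8.]]
#   sparse_or_dense_matmul(tf.sparse.to_dense(a), b)  # dense path, same result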
def _lu_reconstruct_assertions(lower_upper, perm, validate_args):
  """Returns list of assertions related to `lu_reconstruct` assumptions."""
  assertions = []

  message = 'Input `lower_upper` must have at least 2 dimensions.'
  if tensorshape_util.rank(lower_upper.shape) is not None:
    if tensorshape_util.rank(lower_upper.shape) < 2:
      raise ValueError(message)
  elif validate_args:
    assertions.append(
        assert_util.assert_rank_at_least(lower_upper, rank=2, message=message))

  message = '`rank(lower_upper)` must equal `rank(perm) + 1`.'
  if (tensorshape_util.rank(lower_upper.shape) is not None and
      tensorshape_util.rank(perm.shape) is not None):
    if (tensorshape_util.rank(lower_upper.shape) !=
        tensorshape_util.rank(perm.shape) + 1):
      raise ValueError(message)
  elif validate_args:
    assertions.append(
        assert_util.assert_rank(
            lower_upper, rank=tf.rank(perm) + 1, message=message))

  message = '`lower_upper` must be square.'
  # Note: squareness concerns the trailing two (matrix) dimensions, so we
  # check `shape[-2:]` statically and fall back to a runtime assertion.
  if tensorshape_util.is_fully_defined(lower_upper.shape[-2:]):
    if lower_upper.shape[-2] != lower_upper.shape[-1]:
      raise ValueError(message)
  elif validate_args:
    m, n = tf.split(tf.shape(lower_upper)[-2:], num_or_size_splits=2)
    assertions.append(assert_util.assert_equal(m, n, message=message))

  return assertions
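
# Example usage (a hypothetical sketch; assumes `import tensorflow as tf`):
#
#   x = tf.constant([[4., 3.], [6., 3.]])
#   lower_upper, perm = tf.linalg.lu(x)
#   assertions = _lu_reconstruct_assertions(lower_upper, perm,
#                                           validate_args=True)
#   # With fully static shapes all checks run at trace time and `assertions`
#   # is empty; with unknown shapes it holds ops for `tf.control_dependencies`.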
def _parameter_control_dependencies(self, is_init):
  assertions = []

  # In init, we can always build shape and dtype checks because
  # we assume shape doesn't change for Variable backed args.
  if is_init:
    if not dtype_util.is_floating(self.cutpoints.dtype):
      raise TypeError('Argument `cutpoints` must have floating type.')
    if not dtype_util.is_floating(self.loc.dtype):
      raise TypeError('Argument `loc` must have floating type.')

    cutpoint_dims = tensorshape_util.rank(self.cutpoints.shape)
    msg = 'Argument `cutpoints` must have rank at least 1.'
    if cutpoint_dims is not None:
      if cutpoint_dims < 1:
        raise ValueError(msg)
    elif self.validate_args:
      cutpoints = tf.convert_to_tensor(self.cutpoints)
      assertions.append(
          assert_util.assert_rank_at_least(cutpoints, 1, message=msg))

  if not self.validate_args:
    return []

  if is_init != tensor_util.is_ref(self.cutpoints):
    cutpoints = tf.convert_to_tensor(self.cutpoints)
    assertions.append(distribution_util.assert_nondecreasing(
        cutpoints, message='Argument `cutpoints` must be non-decreasing.'))

  return assertions
def maybe_check_quadrature_param(param, name, validate_args):
  """Helper which checks validity of `loc` and `scale` init args."""
  with tf.name_scope("check_" + name):
    assertions = []
    if tensorshape_util.rank(param.shape) is not None:
      if tensorshape_util.rank(param.shape) == 0:
        raise ValueError("Mixing params must be a (batch of) vector; "
                         "{}.rank={} is not at least one.".format(
                             name, tensorshape_util.rank(param.shape)))
    elif validate_args:
      assertions.append(
          assert_util.assert_rank_at_least(
              param,
              1,
              message=("Mixing params must be a (batch of) vector; "
                       "{}.rank is not at least one.".format(name))))

    # TODO(jvdillon): Remove once we support k-mixtures.
    if tensorshape_util.with_rank_at_least(param.shape, 1)[-1] is not None:
      if tf.compat.dimension_value(param.shape[-1]) != 1:
        raise NotImplementedError(
            "Currently only bimixtures are supported; "
            "{}.shape[-1]={} is not 1.".format(
                name, tf.compat.dimension_value(param.shape[-1])))
    elif validate_args:
      assertions.append(
          assert_util.assert_equal(
              tf.shape(input=param)[-1],
              1,
              message=("Currently only bimixtures are supported; "
                       "{}.shape[-1] is not 1.".format(name))))

    if assertions:
      return distribution_util.with_dependencies(assertions, param)
    return param
def _parameter_control_dependencies(self, is_init):
  assertions = []

  # In init, we can always build shape and dtype checks because
  # we assume shape doesn't change for Variable backed args.
  if is_init and self._is_vector:
    msg = "Argument `loc` must be at least rank 1."
    if tensorshape_util.rank(self.loc.shape) is not None:
      if tensorshape_util.rank(self.loc.shape) < 1:
        raise ValueError(msg)
    elif self.validate_args:
      assertions.append(
          assert_util.assert_rank_at_least(self.loc, 1, message=msg))

  if not self.validate_args:
    assert not assertions  # Should never happen.
    return []

  if is_init != tensor_util.is_mutable(self.atol):
    assertions.append(
        assert_util.assert_non_negative(
            self.atol, message="Argument 'atol' must be non-negative"))
  if is_init != tensor_util.is_mutable(self.rtol):
    assertions.append(
        assert_util.assert_non_negative(
            self.rtol, message="Argument 'rtol' must be non-negative"))

  return assertions
def _maybe_assert_float_matrix(logu, validate_args):
  """Assertion check that the scores matrix is float type."""
  logu = tf.convert_to_tensor(logu, dtype_hint=tf.float32, name='logu')
  if not dtype_util.is_floating(logu.dtype):
    raise TypeError('Input argument must be `float` type.')

  assertions = []

  # Check that scores is a (batch of) matrix.
  msg = 'Input argument must be a (batch of) matrix.'
  rank = tensorshape_util.rank(logu.shape)
  if rank is not None:
    if rank < 2:
      raise ValueError(msg)
  elif validate_args:
    assertions.append(assert_util.assert_rank_at_least(logu, 2, message=msg))

  # Check that scores has the shape [..., N, M], M >= N.
  msg = ('Input argument must be a (batch of) matrix of the shape [N, M], '
         'M >= N.')
  if (rank is not None and
      tensorshape_util.is_fully_defined(logu.shape[-2:])):
    if logu.shape[-2] > logu.shape[-1]:
      raise ValueError(msg)
  elif validate_args:
    # Static shape is unavailable here, so read the dimensions at runtime.
    n, m = tf.unstack(tf.shape(logu)[-2:], num=2)
    assertions.append(assert_util.assert_greater_equal(m, n, message=msg))

  return assertions
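
# Example (a hypothetical sketch; assumes `import tensorflow as tf`):
#
#   logu = tf.random.normal([3, 5])  # N=3, M=5, so M >= N holds.
#   _maybe_assert_float_matrix(logu, validate_args=True)  # ==> []
#   _maybe_assert_float_matrix(tf.zeros([5, 3]), False)   # raises ValueError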
def _parameter_control_dependencies(self, is_init):
  assertions = []

  logits = self._logits
  probs = self._probs
  param, name = (probs, 'probs') if logits is None else (logits, 'logits')

  # In init, we can always build shape and dtype checks because
  # we assume shape doesn't change for Variable backed args.
  if is_init:
    if not dtype_util.is_floating(param.dtype):
      raise TypeError('Argument `{}` must have floating type.'.format(name))

    msg = 'Argument `{}` must have rank at least 1.'.format(name)
    shape_static = tensorshape_util.dims(param.shape)
    if shape_static is not None:
      if len(shape_static) < 1:
        raise ValueError(msg)
    elif self.validate_args:
      param = tf.convert_to_tensor(param)
      assertions.append(
          assert_util.assert_rank_at_least(param, 1, message=msg))
      with tf.control_dependencies(assertions):
        param = tf.identity(param)

    msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
    msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
        name, dtype_util.max(tf.int32))
    event_size = shape_static[-1] if shape_static is not None else None
    if event_size is not None:
      if event_size < 1:
        raise ValueError(msg1)
      if event_size > dtype_util.max(tf.int32):
        raise ValueError(msg2)
    elif self.validate_args:
      param = tf.convert_to_tensor(param)
      assertions.append(assert_util.assert_greater_equal(
          tf.shape(param)[-1], 1, message=msg1))
      # NOTE: For now, we leave out a runtime assertion that
      # `tf.shape(param)[-1] <= tf.int32.max`. An earlier `tf.shape` call
      # will fail before we get to this point.

  if not self.validate_args:
    assert not assertions  # Should never happen.
    return []

  if probs is not None:
    probs = param  # Reuse tensor conversion from above.
    if is_init != tensor_util.is_ref(probs):
      probs = tf.convert_to_tensor(probs)
      one = tf.ones([], dtype=probs.dtype)
      assertions.extend([
          assert_util.assert_non_negative(probs),
          assert_util.assert_less_equal(probs, one),
          assert_util.assert_near(
              tf.reduce_sum(probs, axis=-1), one,
              message='Argument `probs` must sum to 1.'),
      ])

  return assertions
def split_and_reshape(x):
  assertions = []
  message = 'Input must have at least one dimension.'
  if tensorshape_util.rank(x.shape) is not None:
    if tensorshape_util.rank(x.shape) == 0:
      raise ValueError(message)
  else:
    assertions.append(
        assert_util.assert_rank_at_least(x, 1, message=message))
  with tf.control_dependencies(assertions):
    x = tf.nest.pack_sequence_as(
        free_rv_event_shape, tf.split(x, flat_event_splits, axis=-1))

    def _reshape_map_part(part, event_shape):
      static_rank = tf.get_static_value(ps.rank_from_shape(event_shape))
      if static_rank == 1:
        return part
      new_shape = ps.concat([ps.shape(part)[:-1], event_shape], axis=-1)
      return tf.reshape(part, ps.cast(new_shape, tf.int32))

    x = tf.nest.map_structure(_reshape_map_part, x, free_rv_event_shape)
  return x
def _lu_solve_assertions(lower_upper, perm, rhs, validate_args):
  """Returns list of assertions related to `lu_solve` assumptions."""
  assertions = _lu_reconstruct_assertions(lower_upper, perm, validate_args)

  message = 'Input `rhs` must have at least 2 dimensions.'
  if rhs.shape.ndims is not None:
    if rhs.shape.ndims < 2:
      raise ValueError(message)
  elif validate_args:
    assertions.append(
        assert_util.assert_rank_at_least(rhs, rank=2, message=message))

  message = '`lower_upper.shape[-1]` must equal `rhs.shape[-2]`.'
  if (tf.compat.dimension_value(lower_upper.shape[-1]) is not None and
      tf.compat.dimension_value(rhs.shape[-2]) is not None):
    if lower_upper.shape[-1] != rhs.shape[-2]:
      raise ValueError(message)
  elif validate_args:
    assertions.append(
        assert_util.assert_equal(
            tf.shape(lower_upper)[-1], tf.shape(rhs)[-2], message=message))

  return assertions
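
# Example (a hypothetical sketch; assumes `import tensorflow as tf`):
#
#   x = tf.constant([[4., 3.], [6., 3.]])
#   lower_upper, perm = tf.linalg.lu(x)
#   rhs = tf.constant([[1.], [1.]])  # shape [2, 1]; 2 matches n, so OK.
#   _lu_solve_assertions(lower_upper, perm, rhs, validate_args=True)  # ==> []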
def _assertions(self, t):
  if not self.validate_args:
    return []
  is_matrix = assert_util.assert_rank_at_least(t, 2)
  is_square = assert_util.assert_equal(tf.shape(t)[-2], tf.shape(t)[-1])
  is_positive_definite = assert_util.assert_positive(
      tf.linalg.diag_part(t), message="Input must be positive definite.")
  return [is_matrix, is_square, is_positive_definite]
def _sample_control_dependencies(self, x):
  assertions = []
  message = 'Input must have at least one dimension.'
  if tensorshape_util.rank(x.shape) is not None:
    if tensorshape_util.rank(x.shape) == 0:
      raise ValueError(message)
  elif self.validate_args:
    assertions.append(
        assert_util.assert_rank_at_least(x, 1, message=message))
  return assertions
def _forward(self, x):
  if self.validate_args:
    is_matrix = assert_util.assert_rank_at_least(x, 2)
    shape = tf.shape(input=x)
    is_square = assert_util.assert_equal(shape[-2], shape[-1])
    x = distribution_util.with_dependencies([is_matrix, is_square], x)
  # For safety, explicitly zero-out the upper triangular part.
  x = tf.linalg.band_part(x, -1, 0)
  return tf.matmul(x, x, adjoint_b=True)
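
# Example of the forward map Y = X @ X^T (a hypothetical sketch; assumes
# `import tensorflow as tf` and a bijector instance `b` exposing `_forward`):
#
#   x = tf.constant([[2., 0.], [1., 3.]])  # lower triangular
#   b._forward(x)                          # ==> [[4., 2.], [2., 10.]]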
def _sample_control_dependencies(self, x):
  assertions = []
  if not self.validate_args:
    return assertions
  assertions.append(assert_util.assert_rank_at_least(x, 1))
  assertions.append(assert_util.assert_equal(
      self.event_shape_tensor(),
      tf.gather(tf.shape(x), tf.rank(x) - 1),
      message=('Argument `x` not defined in the same space '
               'R**k as this distribution')))
  return assertions
def _parameter_control_dependencies(self, is_init):
  assertions = []

  scores = self._scores
  param, name = (scores, 'scores')

  # In init, we can always build shape and dtype checks because
  # we assume shape doesn't change for Variable backed args.
  if is_init:
    if not dtype_util.is_floating(param.dtype):
      raise TypeError('Argument `{}` must have floating type.'.format(name))

    msg = 'Argument `{}` must have rank at least 1.'.format(name)
    shape_static = tensorshape_util.dims(param.shape)
    if shape_static is not None:
      if len(shape_static) < 1:
        raise ValueError(msg)
    elif self.validate_args:
      param = tf.convert_to_tensor(param)
      assertions.append(
          assert_util.assert_rank_at_least(param, 1, message=msg))
      with tf.control_dependencies(assertions):
        param = tf.identity(param)

    msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
    msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
        name, tf.int32.max)
    event_size = shape_static[-1] if shape_static is not None else None
    if event_size is not None:
      if event_size < 1:
        raise ValueError(msg1)
      if event_size > tf.int32.max:
        raise ValueError(msg2)
    elif self.validate_args:
      param = tf.convert_to_tensor(param)
      assertions.append(
          assert_util.assert_greater_equal(
              tf.shape(param)[-1], 1, message=msg1))
      # NOTE: For now, we leave out a runtime assertion that
      # `tf.shape(param)[-1] <= tf.int32.max`. An earlier `tf.shape` call
      # will fail before we get to this point.

  if not self.validate_args:
    assert not assertions  # Should never happen.
    return []

  if is_init != tensor_util.is_ref(scores):
    scores = tf.convert_to_tensor(scores)
    assertions.extend([
        assert_util.assert_positive(scores),
    ])

  return assertions
def __init__(self,
             samples,
             event_ndims=0,
             validate_args=False,
             allow_nan_stats=True,
             name='Empirical'):
  """Initialize `Empirical` distributions.

  Args:
    samples: Numeric `Tensor` of shape `[B1, ..., Bk, S, E1, ..., En]`,
      `k, n >= 0`. Samples or batches of samples on which the distribution
      is based. The first `k` dimensions index into a batch of independent
      distributions. The length of the `S` dimension determines the number
      of samples in each multiset. The last `n` dimensions represent samples
      for each distribution, where `n` is specified by the argument
      `event_ndims`.
    event_ndims: Python `int32`, default `0`. Number of dimensions for each
      event. When `0` this distribution has scalar samples. When `1` this
      distribution has vector-like samples.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value `NaN` to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if the rank of `samples` < `event_ndims + 1`.
  """
  parameters = locals()
  with tf.name_scope(name):
    self._samples = tf.convert_to_tensor(value=samples, name='samples')
    self._event_ndims = event_ndims
    self._samples_axis = (
        (tensorshape_util.rank(self.samples.shape) or tf.rank(self.samples))
        - self._event_ndims - 1)
    with tf.control_dependencies(
        [assert_util.assert_rank_at_least(self._samples, event_ndims + 1)]):
      samples_shape = distribution_util.prefer_static_shape(self._samples)
      self._num_samples = samples_shape[self._samples_axis]

    super(Empirical, self).__init__(
        dtype=self._samples.dtype,
        reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._samples],
        name=name)
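
# Example usage (a minimal, hypothetical sketch; assumes `import tensorflow
# as tf` and that this class is available as `Empirical`):
#
#   d = Empirical(samples=tf.constant([0., 1., 1., 2.]))
#   d.mean()  # ==> 1.0
#   d.mode()  # ==> 1.0
#   d = Empirical(samples=tf.constant([[0., 1.], [1., 2.]]), event_ndims=1)
#   d.mean()  # ==> [0.5, 1.5]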
def _log_prob(self, x):
  additional_assertions = []
  message = 'Input must have at least one dimension.'
  if x.shape.ndims is not None:
    if x.shape.ndims == 0:
      raise ValueError(message)
  elif self.validate_args:
    additional_assertions.append(
        assert_util.assert_rank_at_least(x, 1, message=message))

  with tf.control_dependencies(self._assertions + additional_assertions):
    event_sizes = [d.event_shape_tensor()[0] for d in self.distributions]
    xs = tf.split(x, event_sizes, axis=-1)
    return sum(
        tf.cast(d.log_prob(tf.cast(x, d.dtype)), self.dtype)
        for d, x in zip(self.distributions, xs))
def _maybe_validate_matrix(a, validate_args):
  """Checks that input is a `float` matrix."""
  assertions = []
  if not dtype_util.is_floating(a.dtype):
    raise TypeError('Input `a` must have `float`-like `dtype` '
                    '(saw {}).'.format(a.dtype.name))
  if a.shape.ndims is not None:
    if a.shape.ndims < 2:
      raise ValueError('Input `a` must have at least 2 dimensions '
                       '(saw: {}).'.format(a.shape.ndims))
  elif validate_args:
    assertions.append(assert_util.assert_rank_at_least(
        a, rank=2, message='Input `a` must have at least 2 dimensions.'))
  return assertions
def maybe_assert_categorical_param_correctness(
    is_init, validate_args, probs, logits):
  """Return assertions for `Categorical`-type distributions."""
  assertions = []

  # In init, we can always build shape and dtype checks because
  # we assume shape doesn't change for Variable backed args.
  if is_init:
    x, name = (probs, 'probs') if logits is None else (logits, 'logits')

    if not dtype_util.is_floating(x.dtype):
      raise TypeError('Argument `{}` must have floating type.'.format(name))

    msg = 'Argument `{}` must have rank at least 1.'.format(name)
    ndims = tensorshape_util.rank(x.shape)
    if ndims is not None:
      if ndims < 1:
        raise ValueError(msg)
    elif validate_args:
      x = tf.convert_to_tensor(x)
      probs = x if logits is None else None  # Retain tensor conversion.
      logits = x if probs is None else None
      assertions.append(assert_util.assert_rank_at_least(x, 1, message=msg))

  if not validate_args:
    assert not assertions  # Should never happen.
    return []

  if logits is not None:
    if is_init != tensor_util.is_mutable(logits):
      logits = tf.convert_to_tensor(logits)
      assertions.extend(
          distribution_util.assert_categorical_event_shape(logits))

  if probs is not None:
    if is_init != tensor_util.is_mutable(probs):
      probs = tf.convert_to_tensor(probs)
      assertions.extend([
          assert_util.assert_non_negative(probs),
          assert_util.assert_near(
              tf.reduce_sum(probs, axis=-1),
              np.array(1, dtype=dtype_util.as_numpy_dtype(probs.dtype)),
              message='Argument `probs` must sum to 1.')
      ])
      assertions.extend(
          distribution_util.assert_categorical_event_shape(probs))

  return assertions
def _prob(self, x):
  if self.validate_args:
    is_vector_check = assert_util.assert_rank_at_least(x, 1)
    right_vec_space_check = assert_util.assert_equal(
        self.event_shape_tensor(),
        tf.gather(tf.shape(input=x), tf.rank(x) - 1),
        message="Argument 'x' not defined in the same space R^k as this "
                "distribution")
    with tf.control_dependencies([is_vector_check]):
      with tf.control_dependencies([right_vec_space_check]):
        x = tf.identity(x)
  return tf.cast(
      tf.reduce_all(
          input_tensor=tf.abs(x - self.loc) <= self._slack, axis=-1),
      dtype=self.dtype)
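
# Example (a hypothetical sketch; assumes this is the `_prob` of
# `tfd.VectorDeterministic` with `tfd = tfp.distributions`):
#
#   d = tfd.VectorDeterministic(loc=[1., 2.], atol=0.1)
#   d.prob([1.05, 2.05])  # ==> 1.  (both coordinates within atol of loc)
#   d.prob([1.5, 2.0])    # ==> 0.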
def _assertions(self, x):
  if not self.validate_args:
    return []
  x_shape = tf.shape(x)
  is_matrix = assert_util.assert_rank_at_least(
      x, 2, message='Input must have rank at least 2.')
  is_square = assert_util.assert_equal(
      x_shape[-2], x_shape[-1], message='Input must be a square matrix.')
  diag_part_x = tf.linalg.diag_part(x)
  is_lower_triangular = assert_util.assert_equal(
      tf.linalg.band_part(x, 0, -1),  # Preserves triu, zeros rest.
      tf.linalg.diag(diag_part_x),
      message='Input must be lower triangular.')
  is_positive_diag = assert_util.assert_positive(
      diag_part_x,
      message='Input must have all positive diagonal entries.')
  return [is_matrix, is_square, is_lower_triangular, is_positive_diag]
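
# Example (hypothetical values; assumes `import tensorflow as tf` and a
# bijector instance `b` constructed with `validate_args=True`):
#
#   ok = tf.constant([[2., 0.], [1., 3.]])
#   bad = tf.constant([[2., 5.], [1., 3.]])  # nonzero above the diagonal
#   b._assertions(ok)   # ops that pass when run
#   b._assertions(bad)  # the lower-triangular assertion fires when run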
def _maybe_assert_valid_concentration(self, concentration, validate_args):
  """Checks the validity of the concentration parameter."""
  if not validate_args:
    return concentration
  return distribution_util.with_dependencies([
      assert_util.assert_positive(
          concentration,
          message="Concentration parameter must be positive."),
      assert_util.assert_rank_at_least(
          concentration, 1,
          message="Concentration parameter must have >=1 dimensions."),
      assert_util.assert_less(
          1, tf.shape(input=concentration)[-1],
          message="Concentration parameter must have event_size >= 2."),
  ], concentration)
def _parameter_control_dependencies(self, is_init):
  """Checks the validity of the concentration parameter."""
  assertions = []

  # In init, we can always build shape and dtype checks because
  # we assume shape doesn't change for Variable backed args.
  if is_init:
    if not dtype_util.is_floating(self.concentration.dtype):
      raise TypeError('Argument `concentration` must be float type.')

    msg = 'Argument `concentration` must have rank at least 1.'
    ndims = tensorshape_util.rank(self.concentration.shape)
    if ndims is not None:
      if ndims < 1:
        raise ValueError(msg)
    elif self.validate_args:
      assertions.append(
          assert_util.assert_rank_at_least(
              self.concentration, 1, message=msg))

    msg = 'Argument `concentration` must have `event_size` at least 2.'
    event_size = tf.compat.dimension_value(self.concentration.shape[-1])
    if event_size is not None:
      if event_size < 2:
        raise ValueError(msg)
    elif self.validate_args:
      assertions.append(
          assert_util.assert_less(
              1, tf.shape(self.concentration)[-1], message=msg))

  if not self.validate_args:
    assert not assertions  # Should never happen.
    return []

  if is_init != tensor_util.is_ref(self.concentration):
    assertions.append(
        assert_util.assert_positive(
            self.concentration,
            message='Argument `concentration` must be positive.'))

  return assertions
def _parameter_control_dependencies(self, is_init):
  assertions = []

  message = 'Rank of `samples` must be at least `event_ndims + 1`.'
  if is_init:
    samples_rank = self.samples.shape.rank
    if samples_rank is not None:
      if samples_rank < self._event_ndims + 1:
        raise ValueError(message)
    elif self._validate_args:
      assertions.append(
          assert_util.assert_rank_at_least(
              self._samples, self._event_ndims + 1, message=message))

  if not self._validate_args:
    assert not assertions  # Should never happen.
    return []

  return assertions
def _log_prob(self, x):
  assertions = []
  message = 'Input must have at least one dimension.'
  if tensorshape_util.rank(x.shape) is not None:
    if tensorshape_util.rank(x.shape) == 0:
      raise ValueError(message)
  elif self.validate_args:
    assertions.append(
        assert_util.assert_rank_at_least(x, 1, message=message))

  with tf.control_dependencies(assertions):
    event_tensors = self._distribution.event_shape_tensor()
    splits = [
        ps.maximum(1, ps.reduce_prod(s))
        for s in tf.nest.flatten(event_tensors)
    ]
    x = tf.nest.pack_sequence_as(
        event_tensors, tf.split(x, splits, axis=-1))

    def _reshape_part(part, dtype, event_shape):
      part = tf.cast(part, dtype)
      static_rank = tf.get_static_value(ps.rank_from_shape(event_shape))
      if static_rank == 1:
        return part
      new_shape = ps.concat([ps.shape(part)[:-1], event_shape], axis=-1)
      return tf.reshape(part, ps.cast(new_shape, tf.int32))

    if all(tensorshape_util.is_fully_defined(s)
           for s in tf.nest.flatten(self._distribution.event_shape)):
      x = tf.nest.map_structure(
          _reshape_part, x, self._distribution.dtype,
          self._distribution.event_shape)
    else:
      x = tf.nest.map_structure(
          _reshape_part, x, self._distribution.dtype,
          self._distribution.event_shape_tensor())
    return self._distribution.log_prob(x)
def _assertions(self, x):
  if not self.validate_args:
    return []
  shape = tf.shape(x)
  is_matrix = assert_util.assert_rank_at_least(
      x, 2, message="Input must have rank at least 2.")
  is_square = assert_util.assert_equal(
      shape[-2], shape[-1], message="Input must be a square matrix.")
  # Use `x.dtype` for the zeroed diagonal so the check also works for
  # non-float32 inputs.
  above_diagonal = tf.linalg.band_part(
      tf.linalg.set_diag(x, tf.zeros(shape[:-1], dtype=x.dtype)), 0, -1)
  is_lower_triangular = assert_util.assert_equal(
      above_diagonal, tf.zeros_like(above_diagonal),
      message="Input must be lower triangular.")
  # A lower triangular matrix is nonsingular iff all its diagonal entries
  # are nonzero.
  diag_part = tf.linalg.diag_part(x)
  is_nonsingular = assert_util.assert_none_equal(
      diag_part, tf.zeros_like(diag_part),
      message="Input must have all diagonal entries nonzero.")
  return [is_matrix, is_square, is_lower_triangular, is_nonsingular]
def _forward_log_det_jacobian(self, x):
  # Let Y be a symmetric, positive definite matrix and write:
  #   Y = X X.T
  # where X is lower-triangular.
  #
  # Observe that,
  #   dY[i,j]/dX[a,b]
  #   = d/dX[a,b] { X[i,:] X[j,:] }
  #   = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] }
  #
  # To compute the Jacobian dX/dY we must represent X,Y as vectors. Since Y
  # is symmetric and X is lower-triangular, we need vectors of dimension:
  #   d = p (p + 1) / 2
  # where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e.,
  #   k = { i (i + 1) / 2 + j   i>=j
  #       { undef               i<j
  # and assume zero-based indexes. When k is undef, the element is dropped.
  # Example:
  #           j  k
  #        0 1 2 3  /
  #     0 [ 0 . . . ]
  #  i  1 [ 1 2 . . ]
  #     2 [ 3 4 5 . ]
  #     3 [ 6 7 8 9 ]
  # Write vec[.] to indicate transforming a matrix to vector via k(i,j). (With
  # slight abuse: k(i,j)=undef means the element is dropped.)
  #
  # We now show d vec[Y] / d vec[X] is lower triangular. Assuming both are
  # defined, observe that k(i,j) < k(a,b) iff (1) i<a or (2) i=a and j<b.
  # In both cases dvec[Y]/dvec[X]@[k(i,j),k(a,b)] = 0 since:
  # (1) j<=i<a thus i,j!=a.
  # (2) i=a>j thus i,j!=a.
  #
  # Since the Jacobian is lower-triangular, we need only compute the product
  # of diagonal elements:
  #   d vec[Y] / d vec[X] @[k(i,j), k(i,j)]
  #   = X[j,j] + I[i=j] X[i,j]
  #   = 2 X[j,j].
  # Since there is a 2 X[j,j] term for every lower-triangular element of X we
  # conclude:
  #   |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}.
  diag = tf.linalg.diag_part(x)

  # We now ensure diag is columnar. Eg, if `diag = [1, 2, 3]` then the output
  # is `[[1], [2], [3]]` and if `diag = [[1, 2, 3], [4, 5, 6]]` then the
  # output is unchanged.
  diag = self._make_columnar(diag)

  if self.validate_args:
    is_matrix = assert_util.assert_rank_at_least(
        x, 2, message="Input must be a (batch of) matrix.")
    shape = tf.shape(input=x)
    is_square = assert_util.assert_equal(
        shape[-2], shape[-1],
        message="Input must be a (batch of) square matrix.")
    # Assuming lower-triangular means we only need check diag>0.
    is_positive_definite = assert_util.assert_positive(
        diag, message="Input must be positive definite.")
    x = distribution_util.with_dependencies(
        [is_matrix, is_square, is_positive_definite], x)

  # Create a vector equal to: [p, p-1, ..., 2, 1].
  if tf.compat.dimension_value(x.shape[-1]) is None:
    p_int = tf.shape(input=x)[-1]
    p_float = tf.cast(p_int, dtype=x.dtype)
  else:
    p_int = tf.compat.dimension_value(x.shape[-1])
    p_float = dtype_util.as_numpy_dtype(x.dtype)(p_int)
  exponents = tf.linspace(p_float, 1., p_int)

  sum_weighted_log_diag = tf.squeeze(
      tf.matmul(tf.math.log(diag), exponents[..., tf.newaxis]), axis=-1)
  fldj = p_float * np.log(2.) + sum_weighted_log_diag

  # We finally need to undo adding an extra column in non-scalar cases
  # where there is a single matrix as input.
  if tensorshape_util.rank(x.shape) is not None:
    if tensorshape_util.rank(x.shape) == 2:
      fldj = tf.squeeze(fldj, axis=-1)
    return fldj

  shape = tf.shape(input=fldj)
  maybe_squeeze_shape = tf.concat([
      shape[:-1],
      distribution_util.pick_vector(
          tf.equal(tf.rank(x), 2),
          np.array([], dtype=np.int32), shape[-1:])], 0)
  return tf.reshape(fldj, maybe_squeeze_shape)
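
# Worked check of the formula above (hypothetical values):
#
#   p = 2, X = [[2., 0.], [1., 3.]]
#   |Jac| = 2**p * prod_j X[j,j]**(p-j) = 4 * (2.**2 * 3.**1) = 48.
#   fldj  = p*log(2) + (2.*log(2.) + 1.*log(3.)) = log(48.) ≈ 3.871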
def __init__(self,
             loc,
             atol=None,
             rtol=None,
             is_vector=False,
             validate_args=False,
             allow_nan_stats=True,
             parameters=None,
             name="_BaseDeterministic"):
  """Initialize a batch of `_BaseDeterministic` distributions.

  The `atol` and `rtol` parameters allow for some slack in `pmf`, `cdf`
  computations, e.g. due to floating-point error.

  ```
  pmf(x; loc)
    = 1, if Abs(x - loc) <= atol + rtol * Abs(loc),
    = 0, otherwise.
  ```

  Args:
    loc: Numeric `Tensor`. The point (or batch of points) on which this
      distribution is supported.
    atol: Non-negative `Tensor` of same `dtype` as `loc` and broadcastable
      shape. The absolute tolerance for comparing closeness to `loc`.
      Default is `0`.
    rtol: Non-negative `Tensor` of same `dtype` as `loc` and broadcastable
      shape. The relative tolerance for comparing closeness to `loc`.
      Default is `0`.
    is_vector: Python `bool`. If `True`, this is for `VectorDeterministic`,
      else `Deterministic`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value `NaN` to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    parameters: Dict of locals to facilitate copy construction.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: If `loc` is a scalar.
  """
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([loc, atol, rtol],
                                    preferred_dtype=tf.float32)
    loc = tf.convert_to_tensor(value=loc, name="loc", dtype=dtype)
    if is_vector and validate_args:
      msg = "Argument loc must be at least rank 1."
      if tensorshape_util.rank(loc.shape) is not None:
        if tensorshape_util.rank(loc.shape) < 1:
          raise ValueError(msg)
      else:
        loc = distribution_util.with_dependencies(
            [assert_util.assert_rank_at_least(loc, 1, message=msg)], loc)
    self._loc = loc
    self._atol = _get_tol(atol, self._loc.dtype, validate_args)
    self._rtol = _get_tol(rtol, self._loc.dtype, validate_args)

    super(_BaseDeterministic, self).__init__(
        dtype=self._loc.dtype,
        reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._loc, self._atol, self._rtol],
        name=name)

    # Avoid using the large broadcast with self.loc if possible.
    if rtol is None:
      self._slack = self.atol
    else:
      self._slack = self.atol + self.rtol * tf.abs(self.loc)
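
# Example (a hypothetical sketch; assumes `tfd = tfp.distributions`, where
# `Deterministic` subclasses `_BaseDeterministic`):
#
#   d = tfd.Deterministic(loc=0., atol=0.1)
#   d.prob(0.05)  # ==> 1.  (|0.05 - 0.| <= 0.1)
#   d.prob(0.5)   # ==> 0.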