def _parameter_control_dependencies(self, is_init):
  """Builds runtime validation ops for `low`, `high`, `loc` and `scale`.

  Args:
    is_init: Python `bool`; `True` when called at construction time, `False`
      when called at op-execution time. Each parameter is checked when
      `is_init != tensor_util.is_ref(param)` so that ref (Variable-backed)
      parameters are checked at runtime and plain tensors at init.

  Returns:
    A list of assertion ops; empty unless `self.validate_args` is set.
  """
  if not self.validate_args:
    return []
  assertions = []
  # `low`/`high` conversions are cached so the cross-parameter check at the
  # bottom can reuse them instead of converting twice.
  low = None
  high = None
  if is_init != tensor_util.is_ref(self.low):
    low = tf.convert_to_tensor(self.low)
    assertions.append(
        assert_util.assert_finite(low, message='`low` is not finite'))
  if is_init != tensor_util.is_ref(self.high):
    high = tf.convert_to_tensor(self.high)
    assertions.append(
        assert_util.assert_finite(high, message='`high` is not finite'))
  if is_init != tensor_util.is_ref(self.loc):
    assertions.append(
        assert_util.assert_finite(self.loc, message='`loc` is not finite'))
  if is_init != tensor_util.is_ref(self.scale):
    scale = tf.convert_to_tensor(self.scale)
    assertions.extend([
        assert_util.assert_positive(
            scale, message='`scale` must be positive'),
        assert_util.assert_finite(scale, message='`scale` is not finite'),
    ])
  # Cross-parameter constraint: the truncation interval must be non-empty.
  if (is_init != tensor_util.is_ref(self.low) or
      is_init != tensor_util.is_ref(self.high)):
    low = tf.convert_to_tensor(self.low) if low is None else low
    high = tf.convert_to_tensor(self.high) if high is None else high
    assertions.append(
        assert_util.assert_greater(
            high, low,
            message='TruncatedCauchy not defined when `low >= high`.'))
  return assertions
def _parameter_control_dependencies(self, is_init):
  """Builds validation ops for the scale parameterization, plus base checks.

  Exactly one of `self._scale_full` / `self._scale_tril` is expected to be
  set; the corresponding parameter is validated.

  Args:
    is_init: Python `bool`; `True` at construction time, `False` at
      op-execution time (see `tensor_util.is_ref` gating below).

  Returns:
    A list of assertion ops; empty unless `self.validate_args` is set.
  """
  assertions = super(Wishart, self)._parameter_control_dependencies(is_init)
  if not self.validate_args:
    assert not assertions  # Base class should also have produced none.
    return []
  if self._scale_full is None:
    if is_init != tensor_util.is_ref(self._scale_tril):
      shape = prefer_static.shape(self._scale_tril)
      # For a lower-triangular factor, a positive diagonal is sufficient for
      # the implied scale matrix to be positive definite.
      assertions.extend(
          [assert_util.assert_positive(
              tf.linalg.diag_part(self._scale_tril),
              message='`scale_tril` must be positive definite.'),
           assert_util.assert_equal(
               shape[-1],
               shape[-2],
               message='`scale_tril` must be square.')]
      )
  else:
    if is_init != tensor_util.is_ref(self._scale_full):
      assertions.append(distribution_util.assert_symmetric(self._scale_full))
  return assertions
def _parameter_control_dependencies(self, is_init):
  """Builds validation ops for `loc` and `scale`.

  Args:
    is_init: Python `bool`; `True` at construction time, `False` at
      op-execution time.

  Returns:
    A list of assertion ops; empty unless `self.validate_args` is set.

  Raises:
    ValueError: if the static shapes of `loc` and `scale` are incompatible.
  """
  assertions = []
  if is_init:
    # Static shape compatibility can be checked eagerly at init; surface a
    # readable error rather than letting `_batch_shape` raise its own.
    try:
      self._batch_shape()
    except ValueError:
      raise ValueError(
          'Arguments `loc` and `scale` must have compatible shapes; '
          'loc.shape={}, scale.shape={}.'.format(
              self.loc.shape, self.scale.shape))
    # We don't bother checking the shapes in the dynamic case because
    # all member functions access both arguments anyway.
  if not self.validate_args:
    assert not assertions  # Should never happen.
    return []
  if is_init != tensor_util.is_ref(self.scale):
    assertions.append(
        assert_util.assert_positive(
            self.scale, message='Argument `scale` must be positive.'))
  return assertions
def __init__(self,
             df,
             loc,
             scale,
             validate_args=False,
             allow_nan_stats=True,
             name="MultivariateStudentTLinearOperator"):
  """Construct Multivariate Student's t-distribution on `R^k`.

  The `batch_shape` is the broadcast shape between `df`, `loc` and `scale`
  arguments.

  The `event_shape` is given by last dimension of the matrix implied by
  `scale`. The last dimension of `loc` must broadcast with this.

  Additional leading dimensions (if any) will index batches.

  Args:
    df: A positive floating-point `Tensor`. Has shape `[B1, ..., Bb]` where
      `b >= 0`.
    loc: Floating-point `Tensor`. Has shape `[B1, ..., Bb, k]` where `k` is
      the event size.
    scale: Instance of `LinearOperator` with a floating `dtype` and shape
      `[B1, ..., Bb, k, k]`.
    validate_args: Python `bool`, default `False`. Whether to validate input
      with asserts. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
    allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/variance/etc...) is undefined for
      any batch member. If `True`, batch members with valid parameters
      leading to undefined statistics will return NaN for this statistic.
    name: The name to give Ops created by the initializer.

  Raises:
    TypeError: if not `scale.dtype.is_floating`.
    ValueError: if not `scale.is_positive_definite`.
  """
  parameters = dict(locals())
  if not dtype_util.is_floating(scale.dtype):
    raise TypeError("`scale` must have floating-point dtype.")
  # Positive-definiteness of the operator is only a static hint; it is
  # checked here (not at runtime) and only when `validate_args` is set.
  if validate_args and not scale.is_positive_definite:
    raise ValueError("`scale` must be positive definite.")
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([df, loc, scale], dtype_hint=tf.float32)
    # The identity op ties the runtime positivity assertion on `df` into the
    # graph when `validate_args` is set.
    with tf.control_dependencies([
        assert_util.assert_positive(df, message="`df` must be positive.")
    ] if validate_args else []):
      self._df = tf.identity(
          tf.convert_to_tensor(df, dtype=dtype), name="df")
    self._loc = tf.convert_to_tensor(loc, name="loc", dtype=dtype)
    self._scale = scale

    super(MultivariateStudentTLinearOperator, self).__init__(
        dtype=dtype,
        reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._df, self._loc] + self._scale.graph_parents,
        name=name,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats)
    self._parameters = parameters
def __init__(self,
             rate=None,
             log_rate=None,
             interpolate_nondiscrete=True,
             validate_args=False,
             allow_nan_stats=True,
             name="Poisson"):
  """Initialize a batch of Poisson distributions.

  Args:
    rate: Floating point tensor, the rate parameter. `rate` must be
      positive. Must specify exactly one of `rate` and `log_rate`.
    log_rate: Floating point tensor, the log of the rate parameter. Must
      specify exactly one of `rate` and `log_rate`.
    interpolate_nondiscrete: Python `bool`. When `False`, `log_prob` returns
      `-inf` (and `prob` returns `0`) for non-integer inputs. When `True`,
      `log_prob` evaluates the continuous function `k * log_rate -
      lgamma(k+1) - rate`, which matches the Poisson pmf at integer
      arguments `k` (note that this function is not itself a normalized
      probability log-density).
      Default value: `True`.
    validate_args: Python `bool`. When `True` distribution parameters are
      checked for validity despite possibly degrading runtime performance.
      When `False` invalid inputs may silently render incorrect outputs.
      Default value: `False`.
    allow_nan_stats: Python `bool`. When `True`, statistics (e.g., mean,
      mode, variance) use the value "`NaN`" to indicate the result is
      undefined. When `False`, an exception is raised if one or more of the
      statistic's batch members are undefined.
      Default value: `True`.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if none or both of `rate`, `log_rate` are specified.
    TypeError: if `rate` is not a float-type.
    TypeError: if `log_rate` is not a float-type.
  """
  parameters = dict(locals())
  with tf.compat.v2.name_scope(name) as name:
    if (rate is None) == (log_rate is None):
      raise ValueError("Must specify exactly one of `rate` and `log_rate`.")
    elif log_rate is None:
      rate = tf.convert_to_tensor(
          value=rate,
          name="rate",
          dtype=dtype_util.common_dtype([rate], preferred_dtype=tf.float32))
      if not rate.dtype.is_floating:
        # Fixed message grammar (was "is a not a float-type.").
        raise TypeError("rate.dtype ({}) is not a float-type.".format(
            rate.dtype.name))
      # Positivity of `rate` is only checked at runtime when `validate_args`
      # is set; the identity op ties the assertion into the graph.
      with tf.control_dependencies(
          [assert_util.assert_positive(rate)] if validate_args else []):
        self._rate = tf.identity(rate, name="rate")
        self._log_rate = tf.math.log(rate, name="log_rate")
    else:
      log_rate = tf.convert_to_tensor(
          value=log_rate,
          name="log_rate",
          dtype=dtype_util.common_dtype([log_rate], tf.float32))
      if not log_rate.dtype.is_floating:
        # Fixed message grammar (was "is a not a float-type.").
        raise TypeError("log_rate.dtype ({}) is not a float-type.".format(
            log_rate.dtype.name))
      # Any real-valued `log_rate` is valid, so no runtime assertion needed.
      self._rate = tf.exp(log_rate, name="rate")
      self._log_rate = tf.convert_to_tensor(value=log_rate, name="log_rate")

    self._interpolate_nondiscrete = interpolate_nondiscrete
    super(Poisson, self).__init__(
        dtype=self._rate.dtype,
        reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._rate],
        name=name)
def __init__(self,
             concentration,
             mixing_concentration,
             mixing_rate,
             validate_args=False,
             allow_nan_stats=True,
             name="GammaGamma"):
  """Initializes a batch of Gamma-Gamma distributions.

  The parameters `concentration` and `rate` must be shaped in a way that
  supports broadcasting (e.g.
  `concentration + mixing_concentration + mixing_rate` is a valid
  operation).

  Args:
    concentration: Floating point tensor, the concentration params of the
      distribution(s). Must contain only positive values.
    mixing_concentration: Floating point tensor, the concentration params of
      the mixing Gamma distribution(s). Must contain only positive values.
    mixing_rate: Floating point tensor, the rate params of the mixing Gamma
      distribution(s). Must contain only positive values.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    TypeError: if `concentration`, `mixing_concentration` and `mixing_rate`
      are different dtypes.
  """
  parameters = dict(locals())
  with tf.name_scope(name):
    dtype = dtype_util.common_dtype(
        [concentration, mixing_concentration, mixing_rate],
        preferred_dtype=tf.float32)
    concentration = tf.convert_to_tensor(
        value=concentration, name="concentration", dtype=dtype)
    mixing_concentration = tf.convert_to_tensor(
        value=mixing_concentration, name="mixing_concentration", dtype=dtype)
    mixing_rate = tf.convert_to_tensor(
        value=mixing_rate, name="mixing_rate", dtype=dtype)
    # Positivity of all three params is only checked at runtime when
    # `validate_args` is set; the identity ops tie the assertions in.
    with tf.control_dependencies([
        assert_util.assert_positive(concentration),
        assert_util.assert_positive(mixing_concentration),
        assert_util.assert_positive(mixing_rate),
    ] if validate_args else []):
      self._concentration = tf.identity(concentration, name="concentration")
      self._mixing_concentration = tf.identity(
          mixing_concentration, name="mixing_concentration")
      self._mixing_rate = tf.identity(mixing_rate, name="mixing_rate")

    tf.debugging.assert_same_float_dtype([
        self._concentration, self._mixing_concentration, self._mixing_rate
    ])

  super(GammaGamma, self).__init__(
      dtype=self._concentration.dtype,
      validate_args=validate_args,
      allow_nan_stats=allow_nan_stats,
      reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
      parameters=parameters,
      graph_parents=[
          self._concentration, self._mixing_concentration, self._mixing_rate
      ],
      name=name)
def __init__(self,
             temperature,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name="RelaxedBernoulli"):
  """Construct RelaxedBernoulli distributions.

  Args:
    temperature: An 0-D `Tensor`, representing the temperature of a set of
      RelaxedBernoulli distributions. The temperature should be positive.
    logits: An N-D `Tensor` representing the log-odds of a positive event.
      Each entry in the `Tensor` parametrizes an independent
      RelaxedBernoulli distribution where the probability of an event is
      sigmoid(logits). Only one of `logits` or `probs` should be passed in.
    probs: An N-D `Tensor` representing the probability of a positive event.
      Each entry in the `Tensor` parameterizes an independent Bernoulli
      distribution. Only one of `logits` or `probs` should be passed in.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: If both `probs` and `logits` are passed, or if neither.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([logits, probs, temperature], tf.float32)
    self._temperature = tf.convert_to_tensor(
        value=temperature, name="temperature", dtype=dtype)
    if validate_args:
      with tf.control_dependencies(
          [assert_util.assert_positive(temperature)]):
        self._temperature = tf.identity(self._temperature)
    # Mutual-exclusivity of `logits`/`probs` is enforced by this helper.
    self._logits, self._probs = distribution_util.get_logits_and_probs(
        logits=logits, probs=probs, validate_args=validate_args, dtype=dtype)
    # A RelaxedBernoulli is a sigmoid-transformed Logistic; the Logistic's
    # parameters encode the temperature scaling.
    super(RelaxedBernoulli, self).__init__(
        distribution=logistic.Logistic(
            self._logits / self._temperature,
            1. / self._temperature,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name + "/Logistic"),
        bijector=sigmoid_bijector.Sigmoid(validate_args=validate_args),
        validate_args=validate_args,
        name=name)
  self._parameters = parameters
def _maybe_assert_valid_sample(self, x): if not self.validate_args: return [] return [ assert_util.assert_positive(x, message='Sample must be positive.') ]
def _maybe_assert_valid_sample(self, x):
  """Checks `x` has this distribution's dtype; gates on positivity.

  Returns `x` (wrapped in an identity with a positivity assertion when
  `validate_args` is set).
  """
  dtype_util.assert_same_float_dtype(tensors=[x], dtype=self.dtype)
  if self.validate_args:
    with tf.control_dependencies([assert_util.assert_positive(x)]):
      return tf.identity(x)
  return x
def __init__(self,
             loc,
             scale,
             validate_args=False,
             allow_nan_stats=True,
             name="Gumbel"):
  """Construct Gumbel distributions with location and scale `loc` and `scale`.

  The parameters `loc` and `scale` must be shaped in a way that supports
  broadcasting (e.g. `loc + scale` is a valid operation).

  Args:
    loc: Floating point tensor, the means of the distribution(s).
    scale: Floating point tensor, the scales of the distribution(s). scale
      must contain only positive values.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
      Default value: `False`.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
      Default value: `True`.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: `'Gumbel'`.

  Raises:
    TypeError: if loc and scale are different dtypes.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([loc, scale], preferred_dtype=tf.float32)
    loc = tf.convert_to_tensor(value=loc, name="loc", dtype=dtype)
    scale = tf.convert_to_tensor(value=scale, name="scale", dtype=dtype)
    with tf.control_dependencies(
        [assert_util.assert_positive(scale)] if validate_args else []):
      loc = tf.identity(loc, name="loc")
      scale = tf.identity(scale, name="scale")
      tf.debugging.assert_same_float_dtype([loc, scale])
      self._gumbel_bijector = gumbel_bijector.Gumbel(
          loc=loc, scale=scale, validate_args=validate_args)

      # The uniform sampler generates samples in `[0, 1)`; a base sample of
      # exactly 0 would be pushed to an infinite value by the (inverted)
      # Gumbel bijector. To avoid this, we use
      # `np.finfo(dtype_util.as_numpy_dtype(dtype)).tiny` as the lower
      # endpoint, because it is the smallest, positive, "normal" number.
      super(Gumbel, self).__init__(
          distribution=uniform.Uniform(
              low=np.finfo(dtype_util.as_numpy_dtype(dtype)).tiny,
              high=tf.ones([], dtype=loc.dtype),
              allow_nan_stats=allow_nan_stats),
          # The Gumbel bijector encodes the quantile function as the
          # forward, and hence needs to be inverted.
          bijector=invert_bijector.Invert(self._gumbel_bijector),
          batch_shape=distribution_util.get_broadcast_shape(loc, scale),
          parameters=parameters,
          name=name)
def __init__(self,
             total_count,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name='NegativeBinomial'):
  """Construct NegativeBinomial distributions.

  Args:
    total_count: Non-negative floating-point `Tensor` with shape
      broadcastable to `[B1,..., Bb]` with `b >= 0` and the same dtype as
      `probs` or `logits`. Defines this as a batch of `N1 x ... x Nm`
      different Negative Binomial distributions. In practice, this
      represents the number of negative Bernoulli trials to stop at (the
      `total_count` of failures), but this is still a valid distribution
      when `total_count` is a non-integer.
    logits: Floating-point `Tensor` with shape broadcastable to
      `[B1, ..., Bb]` where `b >= 0` indicates the number of batch
      dimensions. Each entry represents logits for the probability of
      success for independent Negative Binomial distributions and must be in
      the open interval `(-inf, inf)`. Only one of `logits` or `probs`
      should be specified.
    probs: Positive floating-point `Tensor` with shape broadcastable to
      `[B1, ..., Bb]` where `b >= 0` indicates the number of batch
      dimensions. Each entry represents the probability of success for
      independent Negative Binomial distributions and must be in the open
      interval `(0, 1)`. Only one of `logits` or `probs` should be
      specified.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([total_count, logits, probs],
                                    dtype_hint=tf.float32)
    self._logits, self._probs = distribution_util.get_logits_and_probs(
        logits, probs, validate_args=validate_args, name=name, dtype=dtype)
    total_count = tf.convert_to_tensor(
        value=total_count, name='total_count', dtype=dtype)
    # NOTE(review): the docstring says `total_count` is non-negative, but
    # this runtime check requires it to be strictly positive — confirm
    # which contract is intended.
    with tf.control_dependencies(
        [assert_util.assert_positive(total_count)]
        if validate_args else []):
      self._total_count = tf.identity(total_count, name='total_count')

  super(NegativeBinomial, self).__init__(
      dtype=self._probs.dtype,
      reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
      validate_args=validate_args,
      allow_nan_stats=allow_nan_stats,
      parameters=parameters,
      graph_parents=[self._total_count, self._probs, self._logits],
      name=name)
def __init__(self,
             df,
             scale=None,
             scale_tril=None,
             input_output_cholesky=False,
             validate_args=False,
             allow_nan_stats=True,
             name="Wishart"):
  """Construct Wishart distributions.

  Args:
    df: `float` or `double` `Tensor`. Degrees of freedom, must be greater
      than or equal to dimension of the scale matrix.
    scale: `float` or `double` `Tensor`. The symmetric positive definite
      scale matrix of the distribution. Exactly one of `scale` and
      'scale_tril` must be passed.
    scale_tril: `float` or `double` `Tensor`. The Cholesky factorization of
      the symmetric positive definite scale matrix of the distribution.
      Exactly one of `scale` and 'scale_tril` must be passed.
    input_output_cholesky: Python `bool`. If `True`, functions whose input
      or output have the semantics of samples assume inputs are in Cholesky
      form and return outputs in Cholesky form. In particular, if this flag
      is `True`, input to `log_prob` is presumed of Cholesky form and output
      from `sample`, `mean`, and `mode` are of Cholesky form.  Setting this
      argument to `True` is purely a computational optimization and does not
      change the underlying distribution; for instance, `mean` returns the
      Cholesky of the mean, not the mean of Cholesky factors. The `variance`
      and `stddev` methods are unaffected by this flag.
      Default value: `False` (i.e., input/output does not have Cholesky
      semantics).
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if zero or both of 'scale' and 'scale_tril' are passed in.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    with tf.name_scope("init"):
      if (scale is None) == (scale_tril is None):
        raise ValueError("Must pass scale or scale_tril, but not both.")
      dtype = dtype_util.common_dtype([df, scale, scale_tril], tf.float32)
      df = tf.convert_to_tensor(value=df, name="df", dtype=dtype)
      if scale is not None:
        scale = tf.convert_to_tensor(value=scale, name="scale", dtype=dtype)
        if validate_args:
          scale = distribution_util.assert_symmetric(scale)
        # Internally the distribution is parameterized by the Cholesky
        # factor, so compute it from the full matrix.
        scale_tril = tf.linalg.cholesky(scale)
      else:  # scale_tril is not None
        scale_tril = tf.convert_to_tensor(
            value=scale_tril, name="scale_tril", dtype=dtype)
        if validate_args:
          # For a lower-triangular factor, a positive diagonal is sufficient
          # for positive definiteness of the implied scale matrix.
          scale_tril = distribution_util.with_dependencies([
              assert_util.assert_positive(
                  tf.linalg.diag_part(scale_tril),
                  message="scale_tril must be positive definite"),
              assert_util.assert_equal(
                  tf.shape(input=scale_tril)[-1],
                  tf.shape(input=scale_tril)[-2],
                  message="scale_tril must be square")
          ], scale_tril)

    super(Wishart, self).__init__(
        df=df,
        scale_operator=tf.linalg.LinearOperatorLowerTriangular(
            tril=scale_tril,
            is_non_singular=True,
            is_positive_definite=True,
            is_square=True),
        input_output_cholesky=input_output_cholesky,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
    self._parameters = parameters
def _forward_log_det_jacobian(self, x):
  """Computes `log|det J|` of `Y = X @ X.T` for lower-triangular `X`.

  Args:
    x: A (batch of) lower-triangular matrix `Tensor` with positive diagonal.

  Returns:
    A `Tensor` of forward log-det-Jacobian values, one per batch member.
  """
  # Let Y be a symmetric, positive definite matrix and write:
  #   Y = X X.T
  # where X is lower-triangular.
  #
  # Observe that,
  #   dY[i,j]/dX[a,b]
  #   = d/dX[a,b] { X[i,:] X[j,:] }
  #   = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] }
  #
  # To compute the Jacobian dX/dY we must represent X,Y as vectors. Since Y
  # is symmetric and X is lower-triangular, we need vectors of dimension:
  #   d = p (p + 1) / 2
  # where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e.,
  #   k = { i (i + 1) / 2 + j   i>=j
  #       { undef               i<j
  # and assume zero-based indexes. When k is undef, the element is dropped.
  # Example:
  #           j  k
  #        0 1 2 3  /
  #    0 [ 0 . . . ]
  # i  1 [ 1 2 . . ]
  #    2 [ 3 4 5 . ]
  #    3 [ 6 7 8 9 ]
  # Write vec[.] to indicate transforming a matrix to vector via k(i,j).
  # (With slight abuse: k(i,j)=undef means the element is dropped.)
  #
  # We now show d vec[Y] / d vec[X] is lower triangular. Assuming both are
  # defined, observe that k(i,j) < k(a,b) iff (1) i<a or (2) i=a and j<b.
  # In both cases dvec[Y]/dvec[X]@[k(i,j),k(a,b)] = 0 since:
  # (1) j<=i<a thus i,j!=a.
  # (2) i=a>j  thus i,j!=a.
  #
  # Since the Jacobian is lower-triangular, we need only compute the product
  # of diagonal elements:
  #   d vec[Y] / d vec[X] @[k(i,j), k(i,j)]
  #   = X[j,j] + I[i=j] X[i,j]
  #   = 2 X[j,j].
  # Since there is a 2 X[j,j] term for every lower-triangular element of X
  # we conclude:
  #   |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}.
  diag = tf.linalg.diag_part(x)

  # We now ensure diag is columnar. Eg, if `diag = [1, 2, 3]` then the
  # output is `[[1], [2], [3]]` and if `diag = [[1, 2, 3], [4, 5, 6]]` then
  # the output is unchanged.
  diag = self._make_columnar(diag)

  if self.validate_args:
    is_matrix = assert_util.assert_rank_at_least(
        x, 2, message="Input must be a (batch of) matrix.")
    shape = tf.shape(input=x)
    is_square = assert_util.assert_equal(
        shape[-2], shape[-1],
        message="Input must be a (batch of) square matrix.")
    # Assuming lower-triangular means we only need check diag>0.
    is_positive_definite = assert_util.assert_positive(
        diag, message="Input must be positive definite.")
    x = distribution_util.with_dependencies(
        [is_matrix, is_square, is_positive_definite], x)

  # Create a vector equal to: [p, p-1, ..., 2, 1].
  if tf.compat.dimension_value(x.shape[-1]) is None:
    p_int = tf.shape(input=x)[-1]
    p_float = tf.cast(p_int, dtype=x.dtype)
  else:
    p_int = tf.compat.dimension_value(x.shape[-1])
    p_float = dtype_util.as_numpy_dtype(x.dtype)(p_int)
  exponents = tf.linspace(p_float, 1., p_int)

  sum_weighted_log_diag = tf.squeeze(
      tf.matmul(tf.math.log(diag), exponents[..., tf.newaxis]), axis=-1)
  fldj = p_float * np.log(2.) + sum_weighted_log_diag

  # We finally need to undo adding an extra column in non-scalar cases
  # where there is a single matrix as input.
  if tensorshape_util.rank(x.shape) is not None:
    if tensorshape_util.rank(x.shape) == 2:
      fldj = tf.squeeze(fldj, axis=-1)
    return fldj

  shape = tf.shape(input=fldj)
  maybe_squeeze_shape = tf.concat([
      shape[:-1],
      distribution_util.pick_vector(
          tf.equal(tf.rank(x), 2),
          np.array([], dtype=np.int32), shape[-1:])], 0)
  return tf.reshape(fldj, maybe_squeeze_shape)
def _assertions(self, t): if not self.validate_args: return [] return [assert_util.assert_positive( t[..., 1:] - t[..., :-1], message='Forward transformation input must be strictly increasing.')]
def _parameter_control_dependencies(self, is_init):
  """Builds static and runtime checks for `temperature` and `logits`/`probs`.

  Args:
    is_init: Python `bool`; `True` at construction time, `False` at
      op-execution time.

  Returns:
    A list of assertion ops; empty unless `self.validate_args` is set.

  Raises:
    TypeError: if the active parameter is not floating-point.
    ValueError: if its static rank is < 1, or its final dimension is
      statically known to be < 1 or > `tf.int32.max`.
  """
  assertions = []

  logits = self._logits
  probs = self._probs
  param, name = (probs, 'probs') if logits is None else (logits, 'logits')

  # In init, we can always build shape and dtype checks because
  # we assume shape doesn't change for Variable backed args.
  if is_init:
    if not dtype_util.is_floating(param.dtype):
      # Fixed message grammar (was "must having floating type.").
      raise TypeError(
          'Argument `{}` must have floating type.'.format(name))

    msg = 'Argument `{}` must have rank at least 1.'.format(name)
    shape_static = tensorshape_util.dims(param.shape)
    if shape_static is not None:
      if len(shape_static) < 1:
        raise ValueError(msg)
    elif self.validate_args:
      param = tf.convert_to_tensor(param)
      assertions.append(
          assert_util.assert_rank_at_least(param, 1, message=msg))

    msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
    msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
        name, tf.int32.max)
    event_size = shape_static[-1] if shape_static is not None else None
    if event_size is not None:
      if event_size < 1:
        raise ValueError(msg1)
      if event_size > tf.int32.max:
        raise ValueError(msg2)
    elif self.validate_args:
      param = tf.convert_to_tensor(param)
      assertions.append(
          assert_util.assert_greater_equal(
              tf.shape(param)[-1:], 1, message=msg1))
      # NOTE: For now, we leave out a runtime assertion that
      # `tf.shape(param)[-1] <= tf.int32.max`.  An earlier `tf.shape` call
      # will fail before we get to this point.

  if not self.validate_args:
    assert not assertions  # Should never happen.
    return []

  if is_init != tensor_util.is_ref(self.temperature):
    assertions.append(assert_util.assert_positive(self.temperature))

  if probs is not None:
    probs = param  # reuse tensor conversion from above
    if is_init != tensor_util.is_ref(probs):
      probs = tf.convert_to_tensor(probs)
      one = tf.ones([], dtype=probs.dtype)
      assertions.extend([
          assert_util.assert_non_negative(probs),
          assert_util.assert_less_equal(probs, one),
          assert_util.assert_near(
              tf.reduce_sum(probs, axis=-1),
              one,
              message='Argument `probs` must sum to 1.'),
      ])

  return assertions
def __init__(
    self,
    temperature,
    logits=None,
    probs=None,
    validate_args=False,
    allow_nan_stats=True,
    name='ExpRelaxedOneHotCategorical'):
  """Initialize ExpRelaxedOneHotCategorical using class log-probabilities.

  Args:
    temperature: An 0-D `Tensor`, representing the temperature of a set of
      ExpRelaxedCategorical distributions. The temperature should be
      positive.
    logits: An N-D `Tensor`, `N >= 1`, representing the log probabilities of
      a set of ExpRelaxedCategorical distributions. The first `N - 1`
      dimensions index into a batch of independent distributions and the
      last dimension represents a vector of logits for each class. Only one
      of `logits` or `probs` should be passed in.
    probs: An N-D `Tensor`, `N >= 1`, representing the probabilities of a
      set of ExpRelaxedCategorical distributions. The first `N - 1`
      dimensions index into a batch of independent distributions and the
      last dimension represents a vector of probabilities for each class.
      Only one of `logits` or `probs` should be passed in.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([logits, probs, temperature], tf.float32)
    # `multidimensional=True`: the last axis of `logits`/`probs` indexes
    # classes rather than parameterizing independent scalars.
    self._logits, self._probs = distribution_util.get_logits_and_probs(
        name=name,
        logits=logits,
        probs=probs,
        validate_args=validate_args,
        multidimensional=True,
        dtype=dtype)

    with tf.control_dependencies(
        [assert_util.assert_positive(temperature)]
        if validate_args else []):
      self._temperature = tf.convert_to_tensor(
          temperature, name='temperature', dtype=dtype)
      self._temperature_2d = tf.reshape(
          self._temperature, [-1, 1], name='temperature_2d')

    logits_shape_static = tensorshape_util.with_rank_at_least(
        self._logits.shape, 1)
    if tensorshape_util.rank(logits_shape_static) is not None:
      # Static rank is known: batch rank is everything but the class axis.
      self._batch_rank = tf.convert_to_tensor(
          tensorshape_util.rank(logits_shape_static) - 1,
          dtype=tf.int32,
          name='batch_rank')
    else:
      with tf.name_scope('batch_rank'):
        self._batch_rank = tf.rank(self._logits) - 1

    with tf.name_scope('event_size'):
      self._event_size = tf.shape(self._logits)[-1]

  super(ExpRelaxedOneHotCategorical, self).__init__(
      dtype=dtype,
      reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
      validate_args=validate_args,
      allow_nan_stats=allow_nan_stats,
      parameters=parameters,
      graph_parents=[self._logits, self._probs, self._temperature],
      name=name)
def _parameter_control_dependencies(self, is_init):
  """Validate parameters.

  Checks that `bin_widths`, `bin_heights` and `knot_slopes` broadcast
  (statically when possible, dynamically otherwise), and — when
  `validate_args` is set — that widths/heights/slopes are positive and that
  total width equals total height.

  Args:
    is_init: Python `bool`; `True` at construction time, `False` at
      op-execution time.

  Returns:
    A list of assertion ops; empty unless `self.validate_args` is set.

  Raises:
    ValueError: if the parameters' static shapes cannot broadcast.
  """
  # Cache tensor conversions so each parameter is converted at most once.
  bw, bh, kd = None, None, None
  try:
    shape = tf.broadcast_static_shape(self.bin_widths.shape,
                                      self.bin_heights.shape)
  except ValueError as e:
    raise ValueError('`bin_widths`, `bin_heights` must broadcast: {}'.format(
        str(e)))
  bin_sizes_shape = shape
  try:
    shape = tf.broadcast_static_shape(shape[:-1], self.knot_slopes.shape[:-1])
  except ValueError as e:
    raise ValueError(
        '`bin_widths`, `bin_heights`, and `knot_slopes` must broadcast on '
        'batch axes: {}'.format(str(e)))

  assertions = []
  if (tensorshape_util.is_fully_defined(bin_sizes_shape[-1:]) and
      tensorshape_util.is_fully_defined(self.knot_slopes.shape[-1:])):
    # Innermost dims are statically known: check knot count eagerly.
    if tensorshape_util.rank(self.knot_slopes.shape) > 0:
      num_interior_knots = tensorshape_util.dims(bin_sizes_shape)[-1] - 1
      if tensorshape_util.dims(
          self.knot_slopes.shape)[-1] not in (1, num_interior_knots):
        raise ValueError(
            'Innermost axis of non-scalar `knot_slopes` must broadcast with '
            '{}; got {}.'.format(num_interior_knots, self.knot_slopes.shape))
  elif self.validate_args:
    if is_init != any(
        tensor_util.is_ref(t)
        for t in (self.bin_widths, self.bin_heights, self.knot_slopes)):
      bw = tf.convert_to_tensor(self.bin_widths) if bw is None else bw
      bh = tf.convert_to_tensor(self.bin_heights) if bh is None else bh
      kd = _ensure_at_least_1d(self.knot_slopes) if kd is None else kd
      shape = tf.broadcast_dynamic_shape(
          tf.shape((bw + bh)[..., :-1]), tf.shape(kd))
      assertions.append(
          assert_util.assert_greater(
              tf.shape(shape)[0],
              tf.zeros([], dtype=shape.dtype),
              message='`(bin_widths + bin_heights)[..., :-1]` must broadcast '
              'with `knot_slopes` to at least 1-D.'))

  if not self.validate_args:
    assert not assertions
    return assertions

  if (is_init != tensor_util.is_ref(self.bin_widths) or
      is_init != tensor_util.is_ref(self.bin_heights)):
    bw = tf.convert_to_tensor(self.bin_widths) if bw is None else bw
    bh = tf.convert_to_tensor(self.bin_heights) if bh is None else bh
    assertions += [
        assert_util.assert_near(
            tf.reduce_sum(bw, axis=-1),
            tf.reduce_sum(bh, axis=-1),
            message='`sum(bin_widths, axis=-1)` must equal '
            '`sum(bin_heights, axis=-1)`.'),
    ]
  if is_init != tensor_util.is_ref(self.bin_widths):
    bw = tf.convert_to_tensor(self.bin_widths) if bw is None else bw
    assertions += [
        assert_util.assert_positive(
            bw, message='`bin_widths` must be positive.'),
    ]
  if is_init != tensor_util.is_ref(self.bin_heights):
    bh = tf.convert_to_tensor(self.bin_heights) if bh is None else bh
    assertions += [
        assert_util.assert_positive(
            bh, message='`bin_heights` must be positive.'),
    ]
  if is_init != tensor_util.is_ref(self.knot_slopes):
    kd = _ensure_at_least_1d(self.knot_slopes) if kd is None else kd
    assertions += [
        assert_util.assert_positive(
            kd, message='`knot_slopes` must be positive.'),
    ]
  return assertions
def _maybe_assert_valid_y(self, y): if not self.validate_args: return y is_valid = assert_util.assert_positive( y, message="Inverse transformation input must be greater than 0.") return distribution_util.with_dependencies([is_valid], y)
def _maybe_assert_valid_y(self, y): if not self.validate_args: return [] return [assert_util.assert_positive( y, message='Inverse transformation input must be greater than 0.')]