def __init__(self, bijectors, block_sizes=None, validate_args=False,
             name=None):
  """Creates the bijector.

  Args:
    bijectors: A non-empty list of bijectors.
    block_sizes: A 1-D integer `Tensor` with each element signifying the
      length of the block of the input vector to pass to the corresponding
      bijector. The length of `block_sizes` must be equal to the length of
      `bijectors`. If left as None, a vector of 1's is used.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    name: Python `str`, name given to ops managed by this object. Default:
      E.g., `Blockwise([Exp(), Softplus()]).name ==
      'blockwise_of_exp_and_softplus'`.

  Raises:
    NotImplementedError: If a bijector with `event_ndims` > 1 or one that
      reshapes events is passed.
    ValueError: If `bijectors` list is empty.
    ValueError: If size of `block_sizes` does not equal the length of
      `bijectors` or is not a vector.
  """
  if not name:
    name = 'blockwise_of_' + '_and_'.join([b.name for b in bijectors])
    name = name.replace('/', '')
  with tf.name_scope(name) as name:
    super(Blockwise, self).__init__(
        forward_min_event_ndims=1,
        validate_args=validate_args,
        name=name)

    if not bijectors:
      raise ValueError('`bijectors` must not be empty.')

    for bijector in bijectors:
      if (bijector.forward_min_event_ndims > 1 or
          (bijector.inverse_min_event_ndims !=
           bijector.forward_min_event_ndims)):
        # TODO(siege): In the future, it can be reasonable to support N-D
        # bijectors by concatenating along some specific axis, broadcasting
        # low-D bijectors appropriately.
        raise NotImplementedError('Only scalar and vector event-shape '
                                  'bijectors that do not alter the '
                                  'shape are supported at this time.')

    self._bijectors = bijectors

    if block_sizes is None:
      block_sizes = tf.ones(len(bijectors), dtype=tf.int32)
    self._block_sizes = tf.convert_to_tensor(
        block_sizes, name='block_sizes', dtype_hint=tf.int32)

    self._block_sizes = _validate_block_sizes(
        self._block_sizes, bijectors, validate_args)
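
# A minimal usage sketch of `Blockwise` (not part of the source above;
# assumes the standard TensorFlow Probability import aliases, and the
# constants are illustrative). Per `block_sizes`, `Exp` acts on the first
# coordinate and `Softplus` on the next two.
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

blockwise = tfb.Blockwise([tfb.Exp(), tfb.Softplus()], block_sizes=[1, 2])
y = blockwise.forward(tf.constant([0.5, -1.0, 2.0]))
# y[0] == exp(0.5); y[1:] == softplus([-1.0, 2.0])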
def _slice_single_param(param, param_event_ndims, slices, dist_batch_shape):
  """Slices a single parameter of a distribution.

  Args:
    param: A `Tensor`, the original parameter to slice.
    param_event_ndims: `int` event parameterization rank for this parameter.
    slices: A `tuple` of normalized slices.
    dist_batch_shape: The distribution's batch shape `Tensor`.

  Returns:
    new_param: A `Tensor`, batch-sliced according to slices.
  """
  # Extend param shape with ones on the left to match dist_batch_shape.
  param_shape = tf.shape(input=param)
  insert_ones = tf.ones(
      [tf.size(input=dist_batch_shape) + param_event_ndims - tf.rank(param)],
      dtype=param_shape.dtype)
  new_param_shape = tf.concat([insert_ones, param_shape], axis=0)
  full_batch_param = tf.reshape(param, new_param_shape)
  param_slices = []
  # We separately track the batch axis from the parameter axis because we want
  # them to align for positive indexing, and be offset by param_event_ndims for
  # negative indexing.
  param_dim_idx = 0
  batch_dim_idx = 0
  for slc in slices:
    if slc is tf.newaxis:
      param_slices.append(slc)
      continue
    if slc is Ellipsis:
      if batch_dim_idx < 0:
        raise ValueError('Found multiple `...` in slices {}'.format(slices))
      param_slices.append(slc)
      # Switch over to negative indexing for the broadcast check.
      num_remaining_non_newaxis_slices = sum(
          [s is not tf.newaxis for s in slices[slices.index(Ellipsis) + 1:]])
      batch_dim_idx = -num_remaining_non_newaxis_slices
      param_dim_idx = batch_dim_idx - param_event_ndims
      continue
    # Find the batch dimension sizes for both parameter and distribution.
    param_dim_size = new_param_shape[param_dim_idx]
    batch_dim_size = dist_batch_shape[batch_dim_idx]
    is_broadcast = batch_dim_size > param_dim_size
    # Slices are denoted by start:stop:step.
    if isinstance(slc, slice):
      start, stop, step = slc.start, slc.stop, slc.step
      if start is not None:
        start = tf.where(is_broadcast, 0, start)
      if stop is not None:
        stop = tf.where(is_broadcast, 1, stop)
      if step is not None:
        step = tf.where(is_broadcast, 1, step)
      param_slices.append(slice(start, stop, step))
    else:  # int, or int Tensor, e.g. d[d.batch_shape_tensor()[0] // 2]
      param_slices.append(tf.where(is_broadcast, 0, slc))
    param_dim_idx += 1
    batch_dim_idx += 1
  param_slices.extend([ALL_SLICE] * param_event_ndims)
  return full_batch_param.__getitem__(param_slices)
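
# A quick illustration of the behavior this helper enables (a sketch, assuming
# the standard `tfp.distributions` alias): batch-slicing a distribution while
# a broadcast parameter (here, the scalar `scale`) stays intact.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

dist = tfd.Normal(loc=tf.zeros([5, 3]), scale=1.)  # `scale` broadcasts.
print(dist[2:4].batch_shape)  # (2, 3)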
def _mode(self):
  concentration1 = tf.convert_to_tensor(self.concentration1)
  concentration0 = tf.convert_to_tensor(self.concentration0)
  mode = (concentration1 - 1.) / (concentration1 + concentration0 - 2.)
  with tf.control_dependencies([] if self.allow_nan_stats else [  # pylint: disable=g-long-ternary
      assert_util.assert_less(
          tf.ones([], dtype=self.dtype),
          concentration1,
          message="Mode undefined for concentration1 <= 1."),
      assert_util.assert_less(
          tf.ones([], dtype=self.dtype),
          concentration0,
          message="Mode undefined for concentration0 <= 1.")
  ]):
    return tf.where(
        (concentration1 > 1.) & (concentration0 > 1.),
        mode,
        dtype_util.as_numpy_dtype(self.dtype)(np.nan))
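
# The closed form above is the Beta mode, (a - 1) / (a + b - 2). A quick
# numeric check (a sketch, assuming the standard `tfp.distributions` alias):
import tensorflow_probability as tfp

tfd = tfp.distributions

print(tfd.Beta(3., 2.).mode())  # (3 - 1) / (3 + 2 - 2) = 2/3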
def __init__(self,
             loc,
             scale,
             validate_args=False,
             allow_nan_stats=True,
             name='Gumbel'):
  """Construct Gumbel distributions with location `loc` and scale `scale`.

  The parameters `loc` and `scale` must be shaped in a way that supports
  broadcasting (e.g. `loc + scale` is a valid operation).

  Args:
    loc: Floating point tensor, the means of the distribution(s).
    scale: Floating point tensor, the scales of the distribution(s).
      scale must contain only positive values.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
      Default value: `False`.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value `NaN` to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
      Default value: `True`.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: `'Gumbel'`.

  Raises:
    TypeError: if loc and scale are different dtypes.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([loc, scale], dtype_hint=tf.float32)
    loc = tensor_util.convert_nonref_to_tensor(
        loc, name='loc', dtype=dtype)
    scale = tensor_util.convert_nonref_to_tensor(
        scale, name='scale', dtype=dtype)
    dtype_util.assert_same_float_dtype([loc, scale])
    # Positive scale is asserted by the incorporated Gumbel bijector.
    self._gumbel_bijector = gumbel_bijector.Gumbel(
        loc=loc, scale=scale, validate_args=validate_args)

    # Because the uniform sampler generates samples in `[0, 1)`, samples
    # would otherwise lie in `[-inf, inf)` instead of `(-inf, inf)`. To fix
    # this, we use `np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny`
    # as the lower bound because it is the smallest, positive, 'normal'
    # number.
    super(Gumbel, self).__init__(
        distribution=uniform.Uniform(
            low=np.finfo(dtype_util.as_numpy_dtype(dtype)).tiny,
            high=tf.ones([], dtype=dtype),
            allow_nan_stats=allow_nan_stats),
        # The Gumbel bijector's forward is the Gumbel CDF, so it is inverted
        # here; the inverted bijector sends uniform samples through the
        # quantile function.
        bijector=invert_bijector.Invert(self._gumbel_bijector),
        batch_shape=distribution_util.get_broadcast_shape(loc, scale),
        parameters=parameters,
        name=name)
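
# A minimal usage sketch of the resulting distribution (assumes the standard
# `tfp.distributions` alias; the parameter values are illustrative).
import tensorflow_probability as tfp

tfd = tfp.distributions

dist = tfd.Gumbel(loc=[0., 1.], scale=2.)  # `loc` and `scale` broadcast.
samples = dist.sample(3)  # shape [3, 2]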
def _maybe_assert_valid_sample(self, x, dtype):
  if not self.validate_args:
    return x
  one = tf.ones([], dtype=dtype)
  return distribution_util.with_dependencies([
      assert_util.assert_non_negative(x),
      assert_util.assert_less_equal(x, one),
      assert_util.assert_near(one, tf.reduce_sum(x, axis=[-1])),
  ], x)
def _entropy(self):
  df = tf.convert_to_tensor(self.df)
  scale = tf.convert_to_tensor(self.scale)
  v = tf.ones(self._batch_shape_tensor(df=df, scale=scale),
              dtype=self.dtype)[..., tf.newaxis]
  u = v * df[..., tf.newaxis]
  beta_arg = tf.concat([u, v], -1) / 2.
  return (tf.math.log(tf.abs(scale)) + 0.5 * tf.math.log(df) +
          tf.math.lbeta(beta_arg) +
          0.5 * (df + 1.) *
          (tf.math.digamma(0.5 * (df + 1.)) - tf.math.digamma(0.5 * df)))
def _cdf(self, x):
  low = tf.convert_to_tensor(self.low)
  high = tf.convert_to_tensor(self.high)
  broadcast_shape = tf.broadcast_dynamic_shape(
      tf.shape(x), self._batch_shape_tensor(low=low, high=high))
  zeros = tf.zeros(broadcast_shape, dtype=self.dtype)
  ones = tf.ones(broadcast_shape, dtype=self.dtype)
  result_if_not_big = tf.where(
      x < low, zeros, (x - low) / self._range(low=low, high=high))
  return tf.where(x >= high, ones, result_if_not_big)
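
# The piecewise logic above is the Uniform CDF: 0 below `low`, 1 at or above
# `high`, and (x - low) / (high - low) in between. A quick check (a sketch,
# assuming the standard `tfp.distributions` alias):
import tensorflow_probability as tfp

tfd = tfp.distributions

dist = tfd.Uniform(low=0., high=4.)
print(dist.cdf([-1., 1., 5.]))  # [0., 0.25, 1.]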
def _maybe_assert_valid(self, x):
  if not self.validate_args:
    return x
  return distribution_util.with_dependencies([
      assert_util.assert_non_negative(
          x, message="sample must be non-negative"),
      assert_util.assert_less_equal(
          x,
          tf.ones([], self.concentration0.dtype),
          message="sample must be no larger than `1`."),
  ], x)
def _maybe_assert_valid_sample(self, x):
  """Checks the validity of a sample."""
  if not self.validate_args:
    return []
  return [
      assert_util.assert_positive(x, message='samples must be positive'),
      assert_util.assert_near(
          tf.ones([], dtype=self.dtype),
          tf.reduce_sum(x, axis=-1),
          message='sample last-dimension must sum to `1`'),
  ]
def softplus_inverse(x, name=None):
  """Computes the inverse softplus, i.e., x = softplus_inverse(softplus(x)).

  Mathematically this op is equivalent to:

  ```none
  softplus_inverse = log(exp(x) - 1.)
  ```

  Args:
    x: `Tensor`. Non-negative (not enforced), floating-point.
    name: A name for the operation (optional).

  Returns:
    `Tensor`. Has the same type/shape as input `x`.
  """
  with tf.name_scope(name or 'softplus_inverse'):
    x = tf.convert_to_tensor(x, name='x')
    # We begin by deriving a more numerically stable softplus_inverse:
    # x = softplus(y) = Log[1 + exp{y}], (which means x > 0).
    # ==> exp{x} = 1 + exp{y}                                (1)
    # ==> y = Log[exp{x} - 1]                                (2)
    #       = Log[(exp{x} - 1) / exp{x}] + Log[exp{x}]
    #       = Log[(1 - exp{-x}) / 1] + Log[exp{x}]
    #       = Log[1 - exp{-x}] + x                           (3)
    # (2) is the "obvious" inverse, but (3) is more stable than (2) for large
    # x. For small x (e.g. x = 1e-10), (3) will become -inf since 1 - exp{-x}
    # will be zero. To fix this, we use 1 - exp{-x} approx x for small x > 0.
    #
    # In addition to the numerically stable derivation above, we clamp
    # small/large values to be congruent with the logic in:
    # tensorflow/core/kernels/softplus_op.h
    #
    # Finally, we set the input to one whenever the input is too large or too
    # small. This ensures that no unchosen codepath is +/- inf. This is
    # necessary to ensure the gradient doesn't get NaNs. Recall that the
    # gradient of `where` behaves like `pred*pred_true + (1-pred)*pred_false`
    # thus an `inf` in an unselected path results in `0*inf=nan`. We are
    # careful to overwrite `x` with ones only when we will never actually use
    # this value. Note that we use ones and not zeros since
    # `log(expm1(0.)) = -inf`.
    threshold = np.log(np.finfo(dtype_util.as_numpy_dtype(x.dtype)).eps) + 2.
    is_too_small = x < np.exp(threshold)
    is_too_large = x > -threshold
    too_small_value = tf.math.log(x)
    too_large_value = x
    # This `where` will ultimately be a NOP because we won't select this
    # codepath whenever we used the surrogate `ones_like`.
    x = tf.where(is_too_small | is_too_large, tf.ones([], x.dtype), x)
    y = x + tf.math.log(-tf.math.expm1(-x))  # == log(expm1(x))
    return tf.where(is_too_small, too_small_value,
                    tf.where(is_too_large, too_large_value, y))
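
# Round-trip check of the stability claims above (a sketch; this function is
# exposed as `tfp.math.softplus_inverse`):
import tensorflow as tf
import tensorflow_probability as tfp

x = tf.constant([1e-10, 1.0, 50.0])
y = tfp.math.softplus_inverse(tf.math.softplus(x))
# y recovers x (up to float precision) even at the extremes, where the naive
# log(exp(x) - 1.) would lose precision or overflow.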
def _mean(self):
  # Let
  #   W = (w1,...,wk), with wj ~ iid Exponential(0, 1).
  # Then this distribution is
  #   X = loc + LW,
  # and then E[X] = loc + L1, where 1 is the vector of ones.
  scale_x_ones = self.bijector.scale.matvec(
      tf.ones(self._mode_mean_shape(), self.dtype))
  if self.loc is None:
    return scale_x_ones
  return tf.identity(self.loc) + scale_x_ones
def _compute_quantiles():
  """Helper to build quantiles."""
  # Omit {0, 1} since they might lead to Inf/NaN.
  zero = tf.zeros([], dtype=dist.dtype)
  edges = tf.linspace(zero, 1., quadrature_size + 3)[1:-1]
  # Expand edges so it broadcasts across batch dims.
  edges = tf.reshape(
      edges,
      shape=tf.concat([[-1], tf.ones([batch_ndims], dtype=tf.int32)],
                      axis=0))
  quantiles = dist.quantile(edges)
  # Cyclically permute left by one.
  perm = tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0)
  quantiles = tf.transpose(a=quantiles, perm=perm)
  return quantiles
def _variance(self):
  df = tf.convert_to_tensor(self.df)
  scale = tf.convert_to_tensor(self.scale)
  # We need to put the tf.where inside the outer tf.where to ensure we never
  # hit a NaN in the gradient.
  denom = tf.where(df > 2., df - 2., tf.ones_like(df))
  # Abs(scale) superfluous.
  var = (tf.ones(self._batch_shape_tensor(df=df, scale=scale),
                 dtype=self.dtype) * tf.square(scale) * df / denom)
  # When 1 < df <= 2, variance is infinite.
  result_where_defined = tf.where(
      df > 2., var, dtype_util.as_numpy_dtype(self.dtype)(np.inf))

  if self.allow_nan_stats:
    return tf.where(df > 1., result_where_defined,
                    dtype_util.as_numpy_dtype(self.dtype)(np.nan))
  else:
    return distribution_util.with_dependencies([
        assert_util.assert_less(
            tf.ones([], dtype=self.dtype),
            df,
            message='variance not defined for components of df <= 1'),
    ], result_where_defined)
def _mode(self):
  df = tf.convert_to_tensor(self.df)
  mode = df - 2.
  if self.allow_nan_stats:
    assertions = []
  else:
    assertions = [
        assert_util.assert_less(
            2. * tf.ones([], self.dtype),
            df,
            message='Mode not defined when df <= 2.')
    ]
  with tf.control_dependencies(assertions):
    return tf.where(df > 2., mode,
                    dtype_util.as_numpy_dtype(self.dtype)(np.nan))
def _expand_mix_distribution_probs(self):
  p = self.mixture_distribution.probs_parameter()  # [B, deg]
  deg = tf.compat.dimension_value(
      tensorshape_util.with_rank_at_least(p.shape, 1)[-1])
  if deg is None:
    deg = tf.shape(p)[-1]
  event_ndims = tensorshape_util.rank(self.event_shape)
  if event_ndims is None:
    event_ndims = tf.shape(self.event_shape_tensor())[0]
  expand_shape = tf.concat([
      self.mixture_distribution.batch_shape_tensor(),
      tf.ones([event_ndims], dtype=tf.int32),
      [deg],
  ], axis=0)
  return tf.reshape(p, shape=expand_shape)
def _mean(self):
  concentration = tf.convert_to_tensor(self.concentration)
  lim = tf.ones([], dtype=self.dtype)
  valid = concentration < lim
  safe_conc = tf.where(valid, concentration, tf.constant(.5, self.dtype))
  result = lambda: self.loc + self.scale / (1 - safe_conc)
  if self.allow_nan_stats:
    return tf.where(valid, result(),
                    tf.constant(float('nan'), self.dtype))
  with tf.control_dependencies([
      assert_util.assert_less(
          concentration,
          lim,
          message='`mean` is undefined when `concentration >= 1`')
  ]):
    return result()
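
# The closed form above is the Generalized Pareto mean,
# loc + scale / (1 - concentration), defined for concentration < 1. A quick
# check (a sketch, assuming the standard `tfp.distributions` alias):
import tensorflow_probability as tfp

tfd = tfp.distributions

dist = tfd.GeneralizedPareto(loc=0., scale=2., concentration=0.5)
print(dist.mean())  # 0. + 2. / (1. - 0.5) = 4.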
def _broadcast_event_and_samples(event, samples, event_ndims):
  """Broadcasts the event or samples."""
  # This is the shape of self.samples, without the samples axis, i.e. the
  # shape of the result of a call to dist.sample(). This way we can broadcast
  # it with event to get a properly-sized event, then add the singleton dim
  # back at -event_ndims - 1.
  samples_shape = tf.concat([
      tf.shape(samples)[:-event_ndims - 1],
      tf.shape(samples)[tf.rank(samples) - event_ndims:]
  ], axis=0)
  event = event * tf.ones(samples_shape, dtype=event.dtype)
  event = tf.expand_dims(event, axis=-event_ndims - 1)
  samples = samples * tf.ones_like(event, dtype=samples.dtype)

  return event, samples
def _mode(self):
  concentration = tf.convert_to_tensor(self.concentration)
  rate = tf.convert_to_tensor(self.rate)
  mode = (concentration - 1.) / rate
  if self.allow_nan_stats:
    assertions = []
  else:
    assertions = [
        assert_util.assert_less(
            tf.ones([], self.dtype),
            concentration,
            message='Mode not defined when any concentration <= 1.')
    ]
  with tf.control_dependencies(assertions):
    return tf.where(concentration > 1., mode,
                    dtype_util.as_numpy_dtype(self.dtype)(np.nan))
def _mean(self):
  concentration = tf.convert_to_tensor(self.concentration)
  scale = tf.convert_to_tensor(self.scale)
  mean = scale / (concentration - 1.)
  if self.allow_nan_stats:
    assertions = []
  else:
    assertions = [
        assert_util.assert_less(
            tf.ones([], self.dtype),
            concentration,
            message='mean undefined when any concentration <= 1')
    ]
  with tf.control_dependencies(assertions):
    return tf.where(concentration > 1., mean,
                    dtype_util.as_numpy_dtype(self.dtype)(np.nan))
def _mean(self):
  concentration = tf.convert_to_tensor(self.concentration)
  mixing_concentration = tf.convert_to_tensor(self.mixing_concentration)
  mixing_rate = tf.convert_to_tensor(self.mixing_rate)
  mean = concentration * mixing_rate / (mixing_concentration - 1.)
  if self.allow_nan_stats:
    return tf.where(mixing_concentration > 1., mean,
                    dtype_util.as_numpy_dtype(self.dtype)(np.nan))
  else:
    with tf.control_dependencies([
        assert_util.assert_less(
            tf.ones([], self.dtype),
            mixing_concentration,
            message='mean undefined when `mixing_concentration` <= 1'),
    ]):
      return tf.identity(mean)
def _mode(self):
  concentration = tf.convert_to_tensor(self.concentration)
  k = tf.cast(tf.shape(concentration)[-1], self.dtype)
  total_concentration = tf.reduce_sum(concentration, axis=-1)
  mode = (concentration - 1.) / (total_concentration[..., tf.newaxis] - k)
  if self.allow_nan_stats:
    return tf.where(
        tf.reduce_all(concentration > 1., axis=-1, keepdims=True),
        mode,
        dtype_util.as_numpy_dtype(self.dtype)(np.nan))
  assertions = [
      assert_util.assert_less(
          tf.ones([], self.dtype),
          concentration,
          message='Mode undefined when any concentration <= 1')
  ]
  with tf.control_dependencies(assertions):
    return tf.identity(mode)
def _variance(self):
  concentration = tf.convert_to_tensor(self.concentration)
  mixing_concentration = tf.convert_to_tensor(self.mixing_concentration)
  mixing_rate = tf.convert_to_tensor(self.mixing_rate)
  variance = (
      tf.square(concentration * mixing_rate / (mixing_concentration - 1.)) /
      (mixing_concentration - 2.))
  if self.allow_nan_stats:
    return tf.where(mixing_concentration > 2., variance,
                    dtype_util.as_numpy_dtype(self.dtype)(np.nan))
  else:
    with tf.control_dependencies([
        assert_util.assert_less(
            tf.ones([], self.dtype) * 2.,
            mixing_concentration,
            message='variance undefined when `mixing_concentration` <= 2'),
    ]):
      return tf.identity(variance)
def __init__(self,
             scale,
             validate_args=False,
             allow_nan_stats=True,
             name='Horseshoe'):
  """Construct a Horseshoe distribution with `scale`.

  Args:
    scale: Floating point tensor; the scales of the distribution(s). Must
      contain only positive values.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
      Default value: `False` (i.e., do not validate args).
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
      Default value: `True`.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: 'Horseshoe'.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([scale], dtype_hint=tf.float32)
    self._scale = tensor_util.convert_nonref_to_tensor(
        scale, name='scale', dtype=dtype)
    self._half_cauchy = half_cauchy.HalfCauchy(
        loc=tf.zeros([], dtype=dtype),
        scale=tf.ones([], dtype=dtype),
        allow_nan_stats=True)
    super(Horseshoe, self).__init__(
        dtype=dtype,
        reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name)
def _forward(self, x):
  x = tf.convert_to_tensor(x, name='x')
  batch_shape = prefer_static.shape(x)[:-1]

  # Pad zeros on the top row and right column.
  y = fill_triangular.FillTriangular().forward(x)
  rank = prefer_static.rank(y)
  paddings = tf.concat([
      tf.zeros(shape=(rank - 2, 2), dtype=tf.int32),
      tf.constant([[1, 0], [0, 1]], dtype=tf.int32)
  ], axis=0)
  y = tf.pad(y, paddings)

  # Set diagonal to 1s.
  n = prefer_static.shape(y)[-1]
  diag = tf.ones(tf.concat([batch_shape, [n]], axis=-1), dtype=x.dtype)
  y = tf.linalg.set_diag(y, diag)

  # Normalize each row to have Euclidean (L2) norm 1.
  y /= tf.norm(y, axis=-1)[..., tf.newaxis]
  return y
def _sample_n(self, n, seed=None):
  # The sampling method comes from the fact that if:
  #   X ~ Normal(0, 1)
  #   Z ~ Chi2(df)
  #   Y = X / sqrt(Z / df)
  # then:
  #   Y ~ StudentT(df).
  df = tf.convert_to_tensor(self.df)
  loc = tf.convert_to_tensor(self.loc)
  scale = tf.convert_to_tensor(self.scale)
  batch_shape = self._batch_shape_tensor(df=df, loc=loc, scale=scale)
  shape = tf.concat([[n], batch_shape], 0)
  seed = SeedStream(seed, 'student_t')

  normal_sample = tf.random.normal(shape, dtype=self.dtype, seed=seed())
  df = df * tf.ones(batch_shape, dtype=self.dtype)
  gamma_sample = tf.random.gamma([n],
                                 0.5 * df,
                                 beta=0.5,
                                 dtype=self.dtype,
                                 seed=seed())
  samples = normal_sample * tf.math.rsqrt(gamma_sample / df)
  return samples * scale + loc  # Abs(scale) not wanted.
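
# The same identity as a standalone sketch with raw TF ops (an illustrative
# helper, not part of the class above). Chi2(df) is Gamma(df / 2, rate=1 / 2).
import tensorflow as tf

def student_t_sample(n, df, loc=0., scale=1., seed=None):
  x = tf.random.normal([n], seed=seed)                # X ~ Normal(0, 1)
  z = tf.random.gamma([n], alpha=0.5 * df, beta=0.5)  # Z ~ Chi2(df)
  return loc + scale * x * tf.math.rsqrt(z / df)      # Y ~ StudentT(df)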
def _broadcast_cat_event_and_params(event, params, base_dtype):
  """Broadcasts the event or distribution parameters."""
  if dtype_util.is_integer(event.dtype):
    pass
  elif dtype_util.is_floating(event.dtype):
    # When `validate_args=True` we've already ensured int/float casting
    # is closed.
    event = tf.cast(event, dtype=tf.int32)
  else:
    raise TypeError('`value` should have integer `dtype` or '
                    '`self.dtype` ({})'.format(base_dtype))
  shape_known_statically = (
      tensorshape_util.rank(params.shape) is not None and
      tensorshape_util.is_fully_defined(params.shape[:-1]) and
      tensorshape_util.is_fully_defined(event.shape))
  if not shape_known_statically or params.shape[:-1] != event.shape:
    params = params * tf.ones_like(event[..., tf.newaxis],
                                   dtype=params.dtype)
    params_shape = tf.shape(params)[:-1]
    event = event * tf.ones(params_shape, dtype=event.dtype)
    if tensorshape_util.rank(params.shape) is not None:
      tensorshape_util.set_shape(event, params.shape[:-1])
  return event, params
def _replicate(n, tensor):
  """Replicate the input tensor n times along a new (major) dimension."""
  # TODO(axch) Does this already exist somewhere? Should it get contributed?
  multiples = tf.concat([[n], tf.ones([tf.rank(tensor)], dtype=n.dtype)],
                        axis=0)
  return tf.tile(tensor[tf.newaxis], multiples)
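
# A quick illustration of `_replicate` (the constants are illustrative):
import tensorflow as tf

t = tf.constant([[1., 2.], [3., 4.]])       # shape [2, 2]
replicated = _replicate(tf.constant(3), t)  # shape [3, 2, 2]
# Equivalent to tf.stack([t, t, t]) or tf.tile(t[tf.newaxis], [3, 1, 1]).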
def _parameter_control_dependencies(self, is_init):
  assertions = []

  logits = self._logits
  probs = self._probs
  param, name = (probs, 'probs') if logits is None else (logits, 'logits')

  # In init, we can always build shape and dtype checks because
  # we assume shape doesn't change for Variable backed args.
  if is_init:
    if not dtype_util.is_floating(param.dtype):
      raise TypeError('Argument `{}` must have floating type.'.format(name))

    msg = 'Argument `{}` must have rank at least 1.'.format(name)
    shape_static = tensorshape_util.dims(param.shape)
    if shape_static is not None:
      if len(shape_static) < 1:
        raise ValueError(msg)
    elif self.validate_args:
      param = tf.convert_to_tensor(param)
      assertions.append(
          assert_util.assert_rank_at_least(param, 1, message=msg))
      with tf.control_dependencies(assertions):
        param = tf.identity(param)

    msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
    msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
        name, tf.int32.max)
    event_size = shape_static[-1] if shape_static is not None else None
    if event_size is not None:
      if event_size < 1:
        raise ValueError(msg1)
      if event_size > tf.int32.max:
        raise ValueError(msg2)
    elif self.validate_args:
      param = tf.convert_to_tensor(param)
      assertions.append(
          assert_util.assert_greater_equal(
              tf.shape(param)[-1], 1, message=msg1))
      # NOTE: For now, we leave out a runtime assertion that
      # `tf.shape(param)[-1] <= tf.int32.max`. An earlier `tf.shape` call
      # will fail before we get to this point.

  if not self.validate_args:
    assert not assertions  # Should never happen.
    return []

  if probs is not None:
    probs = param  # reuse tensor conversion from above
    if is_init != tensor_util.is_ref(probs):
      probs = tf.convert_to_tensor(probs)
    one = tf.ones([], dtype=probs.dtype)
    assertions.extend([
        assert_util.assert_non_negative(probs),
        assert_util.assert_less_equal(probs, one),
        assert_util.assert_near(
            tf.reduce_sum(probs, axis=-1),
            one,
            message='Argument `probs` must sum to 1.'),
    ])

  return assertions
def _ones_like(input, dtype=None, name=None):  # pylint: disable=redefined-builtin
  s = _shape(input)
  s_ = tf.get_static_value(s)
  if s_ is not None:
    return np.ones(s_, dtype_util.as_numpy_dtype(dtype or input.dtype))
  # Default to the input's dtype (not the shape tensor's integer dtype), so
  # the dynamic branch matches the static branch above.
  return tf.ones(s, dtype or input.dtype, name)
def _mean_of_covariance_given_quadrature_component(self, diag_only):
  p = self.mixture_distribution.probs_parameter()

  # To compute E[Cov(Z|V)], we'll add matrices within three categories:
  # scaled-identity, diagonal, and full. Then we'll combine these at the end.
  scale_identity_multiplier = None
  diag = None
  full = None

  for k, aff in enumerate(self.interpolated_affine):
    s = aff.scale  # Just in case aff.scale has side-effects, we'll call once.
    if (s is None or
        isinstance(s, tf.linalg.LinearOperatorIdentity)):
      scale_identity_multiplier = add(scale_identity_multiplier,
                                      p[..., k, tf.newaxis])
    elif isinstance(s, tf.linalg.LinearOperatorScaledIdentity):
      scale_identity_multiplier = add(
          scale_identity_multiplier,
          (p[..., k, tf.newaxis] * tf.square(s.multiplier)))
    elif isinstance(s, tf.linalg.LinearOperatorDiag):
      diag = add(diag, (p[..., k, tf.newaxis] * tf.square(s.diag_part())))
    else:
      x = (p[..., k, tf.newaxis, tf.newaxis] *
           s.matmul(s.to_dense(), adjoint_arg=True))
      if diag_only:
        x = tf.linalg.diag_part(x)
      full = add(full, x)

  # We must now account for the fact that the base distribution might have a
  # non-unity variance. Recall that, since X ~ iid Law(X_0),
  #   `Cov(SX+m) = S Cov(X) S.T = S S.T Diag(Var(X_0))`.
  # We can scale by `Var(X)` (vs `Cov(X)`) since X corresponds to `d` iid
  # samples from a scalar-event distribution.
  v = self.distribution.variance()
  if scale_identity_multiplier is not None:
    scale_identity_multiplier = scale_identity_multiplier * v
  if diag is not None:
    diag = diag * v[..., tf.newaxis]
  if full is not None:
    full = full * v[..., tf.newaxis]

  if diag_only:
    # Apparently we don't need the full matrix, just the diagonal.
    r = add(diag, full)
    if r is None and scale_identity_multiplier is not None:
      ones = tf.ones(self.event_shape_tensor(), dtype=self.dtype)
      return scale_identity_multiplier[..., tf.newaxis] * ones
    return add(r, scale_identity_multiplier)

  # `None` indicates we don't know if the result is positive-definite.
  is_positive_definite = (True if all(
      aff.scale.is_positive_definite for aff in self.endpoint_affine)
                          else None)

  to_add = []
  if diag is not None:
    to_add.append(
        tf.linalg.LinearOperatorDiag(
            diag=diag, is_positive_definite=is_positive_definite))
  if full is not None:
    to_add.append(
        tf.linalg.LinearOperatorFullMatrix(
            matrix=full, is_positive_definite=is_positive_definite))
  if scale_identity_multiplier is not None:
    to_add.append(
        tf.linalg.LinearOperatorScaledIdentity(
            num_rows=self.event_shape_tensor()[0],
            multiplier=scale_identity_multiplier,
            is_positive_definite=is_positive_definite))

  return (linop_add_lib.add_operators(to_add)[0].to_dense()
          if to_add else None)