def poisson_and_mixture_distributions(self):
    """Builds the component and mixing distributions from the quadrature.

    The quadrature function stored on the instance is evaluated at the
    current `loc` and `scale`. The resulting grid is used as the Poisson
    log-rates, and the log of the resulting weights is used as the
    Categorical logits.

    Returns:
      A `(dist, mixture_dist)` pair: the `poisson.Poisson` components and the
      `categorical.Categorical` distribution that mixes them.

    Raises:
      TypeError: if the grid and the weights returned by the quadrature
        function do not share the same base dtype.
    """
    loc_tensor = tf.convert_to_tensor(self.loc)
    scale_tensor = tf.convert_to_tensor(self.scale)
    grid, weights = tuple(
        self._quadrature_fn(
            loc_tensor, scale_tensor, self.quadrature_size,
            self.validate_args))

    grid_dtype = grid.dtype
    weights_dtype = weights.dtype
    if not dtype_util.base_equal(grid_dtype, weights_dtype):
        raise TypeError(
            'Quadrature grid dtype ({}) does not match quadrature '
            'probs dtype ({}).'.format(
                dtype_util.name(grid_dtype), dtype_util.name(weights_dtype)))

    # Both distributions share the same validation / NaN-reporting flags.
    shared_kwargs = {
        'validate_args': self.validate_args,
        'allow_nan_stats': self.allow_nan_stats,
    }
    components = poisson.Poisson(log_rate=grid, **shared_kwargs)
    mixing = categorical.Categorical(
        logits=tf.math.log(weights), **shared_kwargs)
    return components, mixing
def __init__(self,
             loc,
             scale,
             quadrature_size=8,
             quadrature_fn=quadrature_scheme_lognormal_quantiles,
             validate_args=False,
             allow_nan_stats=True,
             name="PoissonLogNormalQuadratureCompound"):
    """Constructs the `PoissonLogNormalQuadratureCompound`.

    Note: `probs` returned by (optional) `quadrature_fn` are presumed to be
    either a length-`quadrature_size` vector or a batch of vectors in 1-to-1
    correspondence with the returned `grid`. (I.e., broadcasting is only
    partially supported.)

    Args:
      loc: `float`-like (batch of) scalar `Tensor`; the location parameter of
        the LogNormal prior.
      scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of
        the LogNormal prior.
      quadrature_size: Python `int` scalar representing the number of
        quadrature points.
      quadrature_fn: Python callable taking `loc`, `scale`,
        `quadrature_size`, `validate_args` and returning `tuple(grid, probs)`
        representing the LogNormal grid and corresponding normalized weight.
        Default value: `quadrature_scheme_lognormal_quantiles`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if `quadrature_grid` and `quadrature_probs` have different
        base `dtype`.
    """
    # Must be the first statement so that only the constructor arguments are
    # captured in `parameters`.
    parameters = dict(locals())
    # `name` is rebound to the full scope string and forwarded to the base
    # class below.
    with tf.name_scope(name, values=[loc, scale]) as name:
        dtype = dtype_util.common_dtype([loc, scale], tf.float32)
        # A `None` parameter is passed through unconverted to `quadrature_fn`.
        if loc is not None:
            loc = tf.convert_to_tensor(loc, name="loc", dtype=dtype)
        if scale is not None:
            scale = tf.convert_to_tensor(scale, dtype=dtype, name="scale")
        self._quadrature_grid, self._quadrature_probs = tuple(
            quadrature_fn(loc, scale, quadrature_size, validate_args))

        dt = self._quadrature_grid.dtype
        if dt.base_dtype != self._quadrature_probs.dtype.base_dtype:
            raise TypeError(
                "Quadrature grid dtype ({}) does not match quadrature "
                "probs dtype ({}).".format(
                    dt.name, self._quadrature_probs.dtype.name))

        # The quadrature grid is used directly as the Poisson log-rate.
        self._distribution = poisson.Poisson(
            log_rate=self._quadrature_grid,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats)

        # The log of the normalized weights serves as the mixture logits.
        self._mixture_distribution = categorical.Categorical(
            logits=tf.math.log(self._quadrature_probs),
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats)

        self._loc = loc
        self._scale = scale
        self._quadrature_size = quadrature_size
        # Retained so the quadrature can be recomputed later from the stored
        # parameters; `poisson_and_mixture_distributions` reads
        # `self._quadrature_fn`.
        self._quadrature_fn = quadrature_fn

        super(PoissonLogNormalQuadratureCompound, self).__init__(
            dtype=dt,
            reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=[loc, scale],
            name=name)