def _mean(self, **kwargs):
  """Mean of the (un-quantized) logistic mixture, mapped to `inputs_domain`.

  For multi-channel inputs the location of channel ``i`` is shifted by a
  linear combination of the already-adjusted locations of channels
  ``0 .. i-1``, mirroring the autoregressive structure::

    r ~ Logistic(loc_r, scale_r)
    g ~ Logistic(coef_rg * r + loc_g, scale_g)
    b ~ Logistic(coef_rb * r + coef_gb * g + loc_b, scale_b)

  Locations/scales are rescaled from their ``[-1, 1]`` parameterisation to
  ``[self.low, self.high]``, shifted by -0.5, and the mean of the resulting
  mixture is converted back with `_pixels_to`.

  Extra keyword arguments are accepted and ignored.
  """
  if self.n_channels == 1:
    component_logits, locs, scales = self._params
  else:
    component_logits, locs, scales, coeffs = self._params
    per_channel_locs = tf.split(locs, self.n_channels, axis=-1)
    per_pair_coefs = tf.split(coeffs, self.n_coeffs, axis=-1)
    # coefficients are consumed in (target, source) order with source < target
    pair_index = 0
    for target in range(self.n_channels):
      for source in range(target):
        per_channel_locs[target] += (per_channel_locs[source] *
                                     per_pair_coefs[pair_index])
        pair_index += 1
    locs = tf.concat(per_channel_locs, axis=-1)

  # Rescale parameters from [-1, 1] onto the pixel range [low, high].
  half_range = 0.5 * (self.high - self.low)
  locs = self.low + half_range * (locs + 1.)
  scales = scales * half_range

  shifted_logistic = TransformedDistribution(
      distribution=Logistic(loc=locs, scale=scales),
      bijector=Shift(shift=tf.cast(-0.5, self.dtype)))
  mixture = MixtureSameFamily(
      mixture_distribution=Categorical(logits=component_logits),
      components_distribution=Independent(shifted_logistic,
                                          reinterpreted_batch_ndims=1))
  pixel_mean = Independent(mixture, reinterpreted_batch_ndims=2).mean()
  # map the pixel-space mean back to the caller's input domain
  return _pixels_to(pixel_mean, self.inputs_domain, self.low, self.high)
def _log_prob(self, value: tf.Tensor):
  """Log-probability of `value` under the quantized logistic mixture.

  `value` is passed through `_switch_domain` (using `self.inputs_domain`,
  `self.low`, `self.high`). That returns a transformed copy, used for the
  autoregressive channel dependency, and the value that is actually scored.
  The original note says `value` is expected to be the output of an ELU
  function. NOTE(review): confirm against callers.

  With more than one channel, the location of channel ``i`` is shifted by
  ``sum_j coef_ji * x_j`` over the observed earlier channels ``j < i``
  (the scales within a pixel stay independent)::

    r ~ Logistic(loc_r, scale_r)
    g ~ Logistic(coef_rg * r + loc_g, scale_g)
    b ~ Logistic(coef_rb * r + coef_gb * g + loc_b, scale_b)
  """
  params = self._params
  transformed_value, value = _switch_domain(value,
                                            inputs_domain=self.inputs_domain,
                                            low=self.low,
                                            high=self.high)
  if self.n_channels == 1:
    component_logits, locs, scales = params
  else:
    observed = tf.split(transformed_value, self.n_channels, axis=-1)
    component_logits, locs, scales, coeffs = params
    per_channel_locs = tf.split(locs, self.n_channels, axis=-1)
    per_pair_coefs = tf.split(coeffs, self.n_coeffs, axis=-1)
    pair_index = 0
    for target in range(self.n_channels):
      # insert a component axis so observed channels broadcast against locs
      observed[target] = observed[target][..., tf.newaxis, :]
      for source in range(target):
        per_channel_locs[target] += (observed[source] *
                                     per_pair_coefs[pair_index])
        pair_index += 1
    locs = tf.concat(per_channel_locs, axis=-1)

  # Rescale parameters from [-1, 1] onto the pixel range [low, high] for
  # use with `QuantizedDistribution`.
  half_range = 0.5 * (self.high - self.low)
  locs = self.low + half_range * (locs + 1.)
  scales = scales * half_range

  quantized = QuantizedDistribution(
      distribution=TransformedDistribution(
          distribution=Logistic(loc=locs, scale=scales),
          bijector=Shift(shift=tf.cast(-0.5, self.dtype))),
      low=self.low,
      high=self.high,
  )
  mixture = MixtureSameFamily(
      mixture_distribution=Categorical(logits=component_logits),
      components_distribution=Independent(quantized,
                                          reinterpreted_batch_ndims=1))
  return Independent(mixture, reinterpreted_batch_ndims=2).log_prob(value)
def new(params,
        temperature,
        probs_input=False,
        event_shape=(),
        validate_args=False,
        name='RelaxedBernoulliLayer'):
  """Create an `Independent(RelaxedBernoulli)` from a flat `params` tensor.

  The last axis of `params` is reshaped to `event_shape`. The values are
  treated as probabilities when `probs_input` is true, otherwise as logits.
  All `event_shape` axes are reinterpreted as event dimensions.
  """
  params = tf.convert_to_tensor(value=params, name='params')
  event_shape = dist_util.expand_to_vector(
      tf.convert_to_tensor(value=event_shape,
                           name='event_shape',
                           dtype_hint=tf.int32),
      tensor_name='event_shape',
  )
  batch_shape = tf.shape(input=params)[:-1]
  params = tf.reshape(params, tf.concat([batch_shape, event_shape], axis=0))

  if probs_input:
    parameterisation = dict(logits=None, probs=params)
  else:
    parameterisation = dict(logits=params, probs=None)
  relaxed = RelaxedBernoulli(temperature=temperature,
                             validate_args=validate_args,
                             **parameterisation)
  return Independent(relaxed,
                     reinterpreted_batch_ndims=tf.size(input=event_shape),
                     name=name)
def new(params,
        event_shape=(),
        given_logits=True,
        validate_args=False,
        name='ZIBernoulliLayer'):
  """Create an `Independent(ZeroInflated(Bernoulli))` from `params`.

  The last axis of `params` is split into two equal halves:
  * the Bernoulli parameters (logits if `given_logits`, else probabilities);
  * the zero-inflation logits.
  Each half is reshaped to `batch_shape + event_shape`, and all
  `event_shape` axes are reinterpreted as event dimensions.
  """
  params = tf.convert_to_tensor(value=params, name='params')
  event_shape = dist_util.expand_to_vector(
      tf.convert_to_tensor(value=event_shape, name='event_shape',
                           dtype=tf.int32),
      tensor_name='event_shape',
  )
  target_shape = tf.concat([tf.shape(input=params)[:-1], event_shape], axis=0)
  count_params, inflation_logits = tf.split(params, 2, axis=-1)
  count_params = tf.reshape(count_params, target_shape)

  if given_logits:
    bernoulli = Bernoulli(logits=count_params,
                          probs=None,
                          validate_args=validate_args)
  else:
    bernoulli = Bernoulli(logits=None,
                          probs=count_params,
                          validate_args=validate_args)
  inflated = ZeroInflated(count_distribution=bernoulli,
                          logits=tf.reshape(inflation_logits, target_shape),
                          validate_args=validate_args)
  return Independent(inflated,
                     reinterpreted_batch_ndims=tf.size(input=event_shape),
                     name=name)
def make_gaussian_out(p: tf.Tensor,
                      event_shape: Sequence[int]) -> Independent:
  """Diagonal Gaussian parameterised by the two halves of `p`'s last axis.

  The first half is the location, the second the pre-softplus scale. Both are
  reshaped to ``(-1, *event_shape)``, and all `event_shape` axes become event
  dimensions.
  """
  target_shape = (-1,) + tuple(event_shape)
  raw_loc, raw_scale = tf.split(p, 2, -1)
  gaussian = Normal(loc=tf.reshape(raw_loc, target_shape),
                    scale=tf.nn.softplus(tf.reshape(raw_scale, target_shape)))
  return Independent(gaussian, len(event_shape))
def observation(distribution: str):
  """Return ``(n_params, output_layer)`` for an image likelihood.

  * ``'qlogistic'``: quantized logistic over ``[0, 255]`` with two parameters
    per pixel (loc and pre-softplus scale).
  * ``'bernoulli'``: independent Bernoulli over ``IMAGE_SHAPE`` with one
    logit per pixel.

  Raises:
    NotImplementedError: for any other `distribution`.
  """
  if distribution == 'qlogistic':

    def _quantized_logistic(params):
      loc, raw_scale = tf.split(params, 2, -1)
      # keep scales strictly positive so they cannot collapse to ~zero
      scale = tf.nn.softplus(raw_scale) + tf.cast(tf.exp(-7.), tf.float32)
      return QuantizedLogistic(loc,
                               scale,
                               low=0,
                               high=255,
                               inputs_domain='sigmoid',
                               reinterpreted_batch_ndims=3)

    return 2, DistributionLambda(_quantized_logistic,
                                 convert_to_tensor_fn=Distribution.sample,
                                 name='image')

  if distribution == 'bernoulli':

    def _bernoulli(params):
      return Independent(Bernoulli(logits=params, dtype=tf.float32),
                         len(IMAGE_SHAPE))

    return 1, DistributionLambda(_bernoulli,
                                 convert_to_tensor_fn=Distribution.sample,
                                 name='image')

  raise NotImplementedError
def __init__(self, units, **kwargs):
  """Latent layer: `NormalLayer` posterior with the 'softplus1' scale
  activation and a standard-normal prior over `units` dimensions."""
  standard_normal = Independent(
      Normal(loc=tf.zeros(shape=units), scale=tf.ones(shape=units)), 1)
  super().__init__(units,
                   posterior=NormalLayer,
                   posterior_kwargs={'scale_activation': 'softplus1'},
                   prior=standard_normal,
                   **kwargs)
def MixtureQLogistic(
    locs: tf.Tensor,
    scales: tf.Tensor,
    logits: Optional[tf.Tensor] = None,
    probs: Optional[tf.Tensor] = None,
    batch_ndims: int = 0,
    low: int = 0,
    bits: int = 8,
    name: str = 'MixtureQuantizedLogistic') -> MixtureSameFamily:
  """Mixture of quantized logistic distributions.

  Parameters
  ----------
  locs : tf.Tensor
      locations of all logistic components, e.g. shape
      `[batch_size, n_components, event_size]`
  scales : tf.Tensor
      scales of all logistic components, same shape as `locs`
  logits, probs : tf.Tensor
      parameters of the mixture `Categorical`, e.g. shape
      `[batch_size, n_components]`; pass only one of them
  batch_ndims : int, optional
      number of rightmost batch dimensions of the quantized components that
      are reinterpreted as event dimensions (forwarded to `Independent` as
      `reinterpreted_batch_ndims`). For the shapes above use 1, so that the
      component axis becomes the rightmost batch axis. By default 0.
  low : int, optional
      minimum quantized value, by default 0
  bits : int, optional
      number of bits for quantization; the maximum is `2^bits - 1`,
      by default 8
  name : str, optional
      distribution name, by default 'MixtureQuantizedLogistic'

  Returns
  -------
  MixtureSameFamily
      the mixture of quantized logistic distributions

  Example
  -------
  ```
  d = MixtureQLogistic(np.ones((12, 3, 8)).astype('float32'),
                       np.ones((12, 3, 8)).astype('float32'),
                       logits=np.random.rand(12, 3).astype('float32'),
                       batch_ndims=1)
  ```

  Reference
  ---------
  Salimans, T., Karpathy, A., Chen, X., Kingma, D.P., 2017.
      PixelCNN++: Improving the PixelCNN with Discretized Logistic
      Mixture Likelihood and Other Modifications. arXiv:1701.05517 [cs, stat].
  """
  categorical = Categorical(probs=probs, logits=logits)
  logistic = Logistic(loc=locs, scale=scales)
  # `QuantizedDistribution` gives integer k the mass on (k - 1, k]; shifting
  # by -0.5 centres that bin on k, i.e. (k - 0.5, k + 0.5].
  shifted = TransformedDistribution(
      distribution=logistic,
      bijector=Shift(shift=tf.cast(-0.5, logistic.dtype)))
  quantized = QuantizedDistribution(shifted, low=low, high=2**bits - 1.)
  components = Independent(quantized, reinterpreted_batch_ndims=batch_ndims)
  return MixtureSameFamily(mixture_distribution=categorical,
                           components_distribution=components,
                           name=name)
def __init__(self,
             data,
             bandwidth=0.01,
             kernel=Normal,
             use_grid=False,
             use_fft=False,
             num_grid_points=1024,
             reparameterize=False,
             validate_args=False,
             allow_nan_stats=True,
             name='KernelDensityEstimation'):
  """Kernel density estimate expressed as a `MixtureSameFamily` of kernels.

  Args:
    data: samples; the first axis indexes samples (one kernel per sample
      when no grid is used).
    bandwidth: scale passed to every kernel.
    kernel: distribution class, called as `kernel(loc=..., scale=...)`.
    use_grid, use_fft: if either is set, kernels are centred on
      `self._generate_grid(num_grid_points)` and weighted by
      `self._linear_binning()` instead of one kernel per sample. Both flags
      currently take the same code path.
    num_grid_points: number of grid points when a grid is used.
    reparameterize, validate_args, allow_nan_stats, name: forwarded to
      `MixtureSameFamily`.
  """

  def _components(loc, scale):
    # `Independent` with its default ndims keeps only the leading batch axis
    # (the mixture axis) as batch and turns the rest into event axes.
    return Independent(kernel(loc=loc, scale=scale))

  with tf.name_scope(name) as name:
    self._use_fft = use_fft
    dtype = dtype_util.common_dtype([bandwidth, data], tf.float32)
    self._bandwidth = tensor_util.convert_nonref_to_tensor(
        bandwidth, name='bandwidth', dtype=dtype)
    self._data = tensor_util.convert_nonref_to_tensor(data,
                                                      name='data',
                                                      dtype=dtype)
    if use_fft or use_grid:
      self._grid = self._generate_grid(num_grid_points)
      self._grid_data = self._linear_binning()
      mixture_distribution = Categorical(probs=self._grid_data)
      components_distribution = _components(loc=self._grid,
                                            scale=self._bandwidth)
    else:
      self._grid = None
      self._grid_data = None
      # Uniform weights in the kernels' dtype. A Python list of floats would
      # become float32 and clash with float64 data when MixtureSameFamily
      # adds mixture and component log-probs. Falls back to the dynamic
      # sample count when the static one is unknown.
      num_samples = self._data.shape[0]
      if num_samples is None:
        num_samples = tf.shape(self._data)[0]
      uniform_probs = (tf.ones([num_samples], dtype=dtype) /
                       tf.cast(num_samples, dtype))
      mixture_distribution = Categorical(probs=uniform_probs)
      components_distribution = _components(loc=self._data,
                                            scale=self._bandwidth)

    super(KernelDensityEstimation, self).__init__(
        mixture_distribution=mixture_distribution,
        components_distribution=components_distribution,
        reparameterize=reparameterize,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
def model_fullcov(args: Arguments):
  """VAE whose flattened latent uses a full-covariance ('mvntril') posterior
  and an isotropic standard-normal prior."""
  nets = get_networks(args.ds,
                      zdim=args.zdim,
                      is_hierarchical=False,
                      is_semi_supervised=False)
  latent_size = int(np.prod(nets['latents'].event_shape))
  standard_normal = Independent(
      Normal(tf.zeros([latent_size]), tf.ones([latent_size])), 1)
  latents_conf = RVconf(event_shape=latent_size,
                        projection=True,
                        posterior='mvntril',
                        prior=standard_normal,
                        name='latents')
  nets['latents'] = latents_conf.create_posterior()
  return VariationalAutoencoder(**nets, name='FullCov')
def new(dists):
  """Precision-weighted product of two Gaussian-like distributions.

  For ``dists = (q_e, q_d)`` returns ``Normal(loc, scale)`` with::

    precision = 1 / var_e + 1 / var_d
    loc       = (mu_e / var_e + mu_d / var_d) / precision
    scale     = sqrt(1 / precision)

  If `q_e` is an `Independent`, the result is wrapped with the same
  `reinterpreted_batch_ndims`.
  """
  encoder_q, decoder_q = dists
  mean_e = encoder_q.mean()
  mean_d = decoder_q.mean()
  precision_e = 1 / encoder_q.variance()
  precision_d = 1 / decoder_q.variance()
  total_precision = precision_e + precision_d
  fused = Normal(
      loc=(mean_e * precision_e + mean_d * precision_d) / total_precision,
      scale=tf.math.sqrt(1 / total_precision))
  if not isinstance(encoder_q, Independent):
    return fused
  return Independent(
      fused, reinterpreted_batch_ndims=encoder_q.reinterpreted_batch_ndims)
def set_prior(self,
              loc=0.,
              log_scale=np.log(np.expm1(1)),
              mixture_logits=None):
  r"""Set the prior for the mixture density network.

  Args:
    loc: scalar, or tensor of shape `[n_components, event_size]`. Scalars
      are broadcast with `tf.fill`.
    log_scale: scalar, or tensor of shape `[n_components, event_size]` for
      'none' and 'diag' covariance and
      `[n_components, event_size * (event_size + 1) // 2]` for 'full'.
      Despite the name it is passed through `softplus`, so the default
      `log(expm1(1))` yields a unit scale.
    mixture_logits: scalar or tensor of shape `[n_components]`. `None`
      gives equal mixture weights.

  Returns:
    self, with `self._prior` set to a `MixtureSameFamily`.

  Raises:
    ValueError: if `self.covariance` is not 'diag', 'none' or 'full'.
  """
  event_size = self.event_size
  n_components = self.n_components
  covariance = self.covariance

  if covariance == 'diag':
    scale_shape = [n_components, event_size]

    def make_components(l, s):
      return MultivariateNormalDiag(loc=l, scale_diag=tf.nn.softplus(s))
  elif covariance == 'none':
    scale_shape = [n_components, event_size]

    def make_components(l, s):
      return Independent(Normal(loc=l, scale=tf.math.softplus(s)), 1)
  elif covariance == 'full':
    scale_shape = [n_components, event_size * (event_size + 1) // 2]

    def make_components(l, s):
      return MultivariateNormalTriL(
          loc=l,
          scale_tril=FillScaleTriL(diag_shift=1e-5)(tf.math.softplus(s)))
  else:
    raise ValueError("Unsupported covariance %r; expected one of "
                     "'diag', 'none' or 'full'" % (covariance,))

  def _is_scalar(x):
    return isinstance(x, Number) or tf.rank(x) == 0

  # Broadcast scalar parameters to full shapes; tensors are used as given.
  # (Each parameter is checked on its own; previously `loc` was tested via
  # `log_scale`, so a tensor `loc` with a scalar `log_scale` crashed.)
  if _is_scalar(loc):
    loc = tf.fill([n_components, event_size], loc)
  if _is_scalar(log_scale):
    log_scale = tf.fill(scale_shape, log_scale)
  if mixture_logits is None:
    p = 1. / n_components
    # Any constant gives uniform weights. A single component has p == 1,
    # where p / (1 - p) would divide by zero.
    mixture_logits = 0. if n_components == 1 else np.log(p / (1. - p))
  if _is_scalar(mixture_logits):
    mixture_logits = tf.fill([n_components], mixture_logits)

  loc = tf.cast(loc, self.dtype)
  log_scale = tf.cast(log_scale, self.dtype)
  mixture_logits = tf.cast(mixture_logits, self.dtype)
  self._prior = MixtureSameFamily(
      components_distribution=make_components(loc, log_scale),
      mixture_distribution=Categorical(logits=mixture_logits),
      name="prior")
  return self
def __init__(self,
             loc: Union[tf.Tensor, np.ndarray],
             scale: Union[tf.Tensor, np.ndarray],
             low: Union[None, Number] = 0,
             high: Union[None, Number] = 2**8 - 1,
             inputs_domain: Literal['sigmoid', 'tanh', 'pixel'] = 'sigmoid',
             reinterpreted_batch_ndims: Optional[int] = None,
             validate_args: bool = False,
             allow_nan_stats: bool = True,
             name: str = 'QuantizedLogistic'):
  """Quantized logistic distribution over ``[low, high]``.

  When both `low` and `high` are given, `loc` and `scale` are read as
  living in ``[-1, 1]`` and are rescaled to ``[low, high]``. The logistic is
  shifted by -0.5 before quantization. If `reinterpreted_batch_ndims` is not
  None, the quantized distribution is wrapped in `Independent`.
  `inputs_domain` is only stored on the instance here.
  """
  # must run before any other local is bound so that only the call
  # arguments are captured
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([loc, scale, low, high],
                                    dtype_hint=tf.float32)
    self._low = low
    self._high = high
    # Rescale parameters for pixel values in [low, high] for use with
    # `QuantizedDistribution`.
    if low is not None and high is not None:
      half_range = 0.5 * (high - low)
      loc = low + half_range * (loc + 1.)
      scale = scale * half_range
    self._logistic = Logistic(loc=loc,
                              scale=scale,
                              validate_args=validate_args,
                              allow_nan_stats=allow_nan_stats,
                              name=name)
    shifted = TransformedDistribution(
        distribution=self._logistic,
        bijector=Shift(tf.cast(-0.5, dtype=dtype)))
    quantized = QuantizedDistribution(distribution=shifted,
                                      low=low,
                                      high=high,
                                      validate_args=validate_args,
                                      name=name)
    if reinterpreted_batch_ndims is None:
      self._dist = quantized
    else:
      self._dist = Independent(
          quantized, reinterpreted_batch_ndims=reinterpreted_batch_ndims)
    self.inputs_domain = inputs_domain
    super(QuantizedLogistic,
          self).__init__(dtype=dtype,
                         reparameterization_type=NOT_REPARAMETERIZED,
                         validate_args=validate_args,
                         allow_nan_stats=allow_nan_stats,
                         parameters=parameters,
                         name=name)
def __init__(self,
             units: int,
             prior_loc: float = 0.,
             prior_scale: float = 1.,
             projection: bool = True,
             name: str = "Latents",
             **kwargs):
  """Diagonal-Gaussian latent layer.

  Uses a `NormalLayer` posterior (softplus scale) and an
  `Independent(Normal(prior_loc, prior_scale))` prior over `units`
  dimensions.
  """
  event_shape = (int(units), )
  fill_shape = (units, )
  prior = Independent(Normal(loc=tf.fill(fill_shape, prior_loc),
                             scale=tf.fill(fill_shape, prior_scale)),
                      reinterpreted_batch_ndims=1)
  super().__init__(
      event_shape=event_shape,
      posterior=NormalLayer,
      posterior_kwargs={'scale_activation': 'softplus'},
      prior=prior,
      projection=projection,
      name=name,
      **kwargs,
  )
def encode(self,
           inputs,
           library=None,
           training=None,
           mask=None,
           sample_shape=(),
           **kwargs):
  """Encode `inputs`, then set the library-size prior on the last latent.

  When `library` is given, its first (flattened) tensor is split in two along
  axis 1 into a mean and a variance, giving the prior
  ``Independent(Normal(mean, sqrt(var)), 1)``. Otherwise the prior is set to
  None. Extra `**kwargs` are not forwarded to the parent `encode`.
  """
  qZ_X = super().encode(inputs=inputs,
                        library=library,
                        training=training,
                        mask=mask,
                        sample_shape=sample_shape)
  library_prior = None
  if library is not None:
    library_tensor = tf.nest.flatten(library)[0]
    library_mean, library_var = tf.split(library_tensor, 2, axis=1)
    library_prior = Independent(
        Normal(loc=library_mean, scale=tf.math.sqrt(library_var)), 1)
  qZ_X[-1].KL_divergence.prior = library_prior
  return qZ_X
def build(self, input_shape=None):
  """Create the layers and distributions used by the parallel latents.

  After the parent `build` (and unless the wrapper is disabled) this creates:
  the posterior projection convolution, the posterior `DistributionLambda`,
  an N(0, I) prior over the inferred latent shape, and the output
  up-sampling convolution.
  """
  super().build(input_shape)
  if self._disable:
    return
  decoder_shape = self.layer.compute_output_shape(input_shape)
  conv_cls, conv_transpose_cls = _NDIMS_CONV[self.input_ndim]

  # --- 1. posterior projection
  assert self.encoder is not None, \
    'ParallelLatents require encoder to be specified'
  # assumes encoder and decoder output shapes are the same
  self._conv_posterior = conv_cls(**self._network_kw, name='ConvPosterior')
  self._conv_posterior.build(decoder_shape)

  # --- 2. posterior distribution and prior
  params_shape = self._conv_posterior.compute_output_shape(decoder_shape)
  make_posterior = partial(_create_dist,
                           event_ndims=len(params_shape) - 1,
                           dtype=self.dtype)
  self._dist_posterior = DistributionLambda(
      make_distribution_fn=make_posterior, name=f'{self.name}_posterior')
  self._dist_posterior.build(params_shape)
  # infer the latent shape by calling the posterior on a symbolic input
  symbolic_sample = self._dist_posterior(keras.layers.Input(params_shape[1:]))
  latents_shape = tf.convert_to_tensor(symbolic_sample).shape
  self._latents_shape = latents_shape[1:]
  prior_shape = self.latents_shape
  self._prior = Independent(
      Normal(loc=tf.zeros(prior_shape, dtype=self.dtype),
             scale=tf.ones(prior_shape, dtype=self.dtype)),
      reinterpreted_batch_ndims=len(prior_shape),
      name=f'{self.name}_prior')

  # --- 3. final output projection back to the decoder shape
  self._conv_out = _upsample_by_conv(
      conv_cls,
      conv_transpose_cls,
      input_shape=latents_shape,
      output_shape=decoder_shape,
      kernel_size=self._conv_posterior.kernel_size,
      padding=self._conv_posterior.padding,
      strides=self._conv_posterior.strides)
def __init__(self,
             units: int,
             prior_loc: float = 0.,
             prior_scale: float = 1.,
             projection: bool = True,
             name: str = "Latents",
             **kwargs):
  """Latent layer with a diagonal `MultivariateNormalLayer` posterior.

  The posterior uses a softplus scale activation. The prior is
  `Independent(Normal(prior_loc, prior_scale))` over `units` dimensions.
  """
  event_shape = (int(units), )
  fill_shape = (units, )
  prior = Independent(Normal(loc=tf.fill(fill_shape, prior_loc),
                             scale=tf.fill(fill_shape, prior_scale)),
                      reinterpreted_batch_ndims=1)
  super().__init__(
      event_shape=event_shape,
      posterior=MultivariateNormalLayer,
      posterior_kwargs={'covariance': 'diag',
                        'scale_activation': tf.nn.softplus},
      prior=prior,
      projection=projection,
      name=name,
      **kwargs,
  )
def new(params,
        event_shape=(),
        dtype=None,
        validate_args=False,
        continuous=False,
        lims=(0.499, 0.501),
        name='BernoulliLayer'):
  """Create an `Independent` Bernoulli (or ContinuousBernoulli) from logits.

  The last axis of `params` is reshaped to `event_shape`. The sample dtype
  defaults to the dtype of `params`. The underlying `_logits` / `_probs`
  are copied onto the returned wrapper as `.logits` / `.probs`.
  """
  params = tf.convert_to_tensor(value=params, name='params')
  event_shape = dist_util.expand_to_vector(
      tf.convert_to_tensor(value=event_shape,
                           name='event_shape',
                           dtype_hint=tf.int32),
      tensor_name='event_shape',
  )
  full_shape = tf.concat([tf.shape(input=params)[:-1], event_shape], axis=0)
  logits = tf.reshape(params, full_shape)
  sample_dtype = dtype or params.dtype.base_dtype

  if continuous:
    base = ContinuousBernoulli(logits=logits,
                               dtype=sample_dtype,
                               lims=lims,
                               validate_args=validate_args)
  else:
    base = Bernoulli(logits=logits,
                     dtype=sample_dtype,
                     validate_args=validate_args)

  dist = Independent(base,
                     reinterpreted_batch_ndims=tf.size(input=event_shape),
                     name=name)
  # expose the wrapped distribution's raw parameters on the wrapper
  dist.logits = base._logits  # pylint: disable=protected-access
  dist.probs = base._probs  # pylint: disable=protected-access
  return dist
def __init__(self,
             count_distribution,
             inflated_distribution=None,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name="ZeroInflated"):
  """Initialize a zero-inflated distribution.

  A `ZeroInflated` combines a zero-inflation rate (`inflated_distribution`,
  the probability of an excess zero) with a `count_distribution`.

  Parameters
  ----------
  count_distribution : a `tfp.distributions.Distribution` instance. Its
      batch shape is the batch shape of this distribution.
  inflated_distribution : optional `tfp.distributions.Bernoulli`-like
      instance giving the probability of an excess zero. If `None`, an int32
      `Bernoulli` is built from `logits` / `probs`. It must have at least as
      many batch dimensions as `count_distribution`. Any extra rightmost
      batch dimensions are reinterpreted as event dimensions via
      `Independent`.
  logits : N-D `Tensor`, log-odds of an excess zero (the probability of an
      excess zero is `sigmoid(logits)`). Only one of `logits` or `probs`
      should be passed in.
  probs : N-D `Tensor`, probability of an excess zero. Only one of `logits`
      or `probs` should be passed in.
  validate_args : Python `bool`, default `False`. If `True`, add a runtime
      assertion that the batch shapes of both distributions match.
  allow_nan_stats : Boolean, default `True`. If `False`, raise an exception
      if a statistic (e.g. mean/mode/etc...) is undefined for any batch
      member. If `True`, batch members with valid parameters leading to
      undefined statistics will return NaN for this statistic.
  name : a name for this distribution (optional).

  Raises
  ------
  TypeError : if either distribution argument is not a `Distribution`.
  ValueError : if a batch rank is not statically known, or
      `count_distribution` has more batch dimensions than
      `inflated_distribution`.

  References
  ----------
  Liu, L. & Blei, D.M.. (2017). Zero-Inflated Exponential Family Embeddings.
      Proceedings of the 34th International Conference on Machine Learning,
      in PMLR 70:2140-2148
  """
  parameters = dict(locals())
  self._runtime_assertions = []

  with tf.compat.v1.name_scope(name) as name:
    if not isinstance(count_distribution, distribution.Distribution):
      raise TypeError("count_distribution must be a Distribution instance"
                      " but saw: %s" % count_distribution)
    self._count_distribution = count_distribution

    if inflated_distribution is None:
      inflated_distribution = Bernoulli(logits=logits,
                                        probs=probs,
                                        dtype=tf.int32,
                                        validate_args=validate_args,
                                        allow_nan_stats=allow_nan_stats,
                                        name="ZeroInflatedRate")
    elif not isinstance(inflated_distribution, distribution.Distribution):
      raise TypeError("inflated_distribution must be a Distribution instance"
                      " but saw: %s" % inflated_distribution)
    self._inflated_distribution = inflated_distribution

    if self._count_distribution.batch_shape.ndims is None:
      raise ValueError(
          "Expected to know rank(batch_shape) from count_distribution")
    if self._inflated_distribution.batch_shape.ndims is None:
      raise ValueError(
          "Expected to know rank(batch_shape) from inflated_distribution")

    # Align batch ranks: surplus rightmost batch dims of the inflation
    # distribution become event dims so both parts share a batch shape.
    inflated_batch_ndims = self._inflated_distribution.batch_shape.ndims
    count_batch_ndims = self._count_distribution.batch_shape.ndims
    if count_batch_ndims < inflated_batch_ndims:
      self._inflated_distribution = Independent(
          self._inflated_distribution,
          reinterpreted_batch_ndims=inflated_batch_ndims - count_batch_ndims,
          name="ZeroInflatedRate")
    elif count_batch_ndims > inflated_batch_ndims:
      raise ValueError("count_distribution has %d-D batch_shape, which is "
                       "larger than the %d-D batch_shape of "
                       "inflated_distribution" %
                       (count_batch_ndims, inflated_batch_ndims))

    # Ensure that the batch shapes agree at runtime.
    if validate_args:
      self._runtime_assertions.append(
          tf.assert_equal(
              self._count_distribution.batch_shape_tensor(),
              self._inflated_distribution.batch_shape_tensor(),
              message=("dist batch shape must match logits|probs batch shape"
                      )))

    # The mixture is only fully reparameterized if both parts are.
    reparameterization_type = [
        self._count_distribution.reparameterization_type,
        self._inflated_distribution.reparameterization_type
    ]
    if any(i == reparameterization.NOT_REPARAMETERIZED
           for i in reparameterization_type):
      reparameterization_type = reparameterization.NOT_REPARAMETERIZED
    else:
      reparameterization_type = reparameterization.FULLY_REPARAMETERIZED

    # The zero-inflated distribution reuses both parts' _graph_parents since
    # it is arguably more like a base class.
    super(ZeroInflated, self).__init__(
        dtype=self._count_distribution.dtype,
        reparameterization_type=reparameterization_type,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=self._count_distribution._graph_parents +
        self._inflated_distribution._graph_parents,
        name=name)
class ZeroInflated(distribution.Distribution):
  """Zero-inflated distribution.

  Mixes a point mass at zero (the "inflation", with probability `pi`) with a
  `count_distribution`::

    p(x = 0) = pi + (1 - pi) * p_count(0)
    p(x)     = (1 - pi) * p_count(x)          for x != 0

  Methods supported include `log_prob`, `prob`, `mean`, `variance` and
  `sample`, plus `denoised_mean` / `denoised_variance`, which return the
  statistics of `count_distribution` alone.
  """

  def __init__(self,
               count_distribution,
               inflated_distribution=None,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="ZeroInflated"):
    """Initialize a zero-inflated distribution.

    Parameters
    ----------
    count_distribution : a `tfp.distributions.Distribution` instance. Its
        batch shape is the batch shape of this distribution.
    inflated_distribution : optional `tfp.distributions.Bernoulli`-like
        instance giving the probability of an excess zero. If `None`, an
        int32 `Bernoulli` is built from `logits` / `probs`. It must have at
        least as many batch dimensions as `count_distribution`. Any extra
        rightmost batch dimensions are reinterpreted as event dimensions via
        `Independent`.
    logits : N-D `Tensor`, log-odds of an excess zero (the probability of an
        excess zero is `sigmoid(logits)`). Only one of `logits` or `probs`
        should be passed in.
    probs : N-D `Tensor`, probability of an excess zero. Only one of
        `logits` or `probs` should be passed in.
    validate_args : Python `bool`, default `False`. If `True`, add a runtime
        assertion that the batch shapes of both distributions match.
    allow_nan_stats : Boolean, default `True`. If `False`, raise an exception
        if a statistic (e.g. mean/mode/etc...) is undefined for any batch
        member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
    name : a name for this distribution (optional).

    Raises
    ------
    TypeError : if either distribution argument is not a `Distribution`.
    ValueError : if a batch rank is not statically known, or
        `count_distribution` has more batch dimensions than
        `inflated_distribution`.

    References
    ----------
    Liu, L. & Blei, D.M.. (2017). Zero-Inflated Exponential Family
        Embeddings. Proceedings of the 34th International Conference on
        Machine Learning, in PMLR 70:2140-2148
    """
    parameters = dict(locals())
    self._runtime_assertions = []

    with tf.compat.v1.name_scope(name) as name:
      if not isinstance(count_distribution, distribution.Distribution):
        raise TypeError("count_distribution must be a Distribution instance"
                        " but saw: %s" % count_distribution)
      self._count_distribution = count_distribution

      if inflated_distribution is None:
        inflated_distribution = Bernoulli(logits=logits,
                                          probs=probs,
                                          dtype=tf.int32,
                                          validate_args=validate_args,
                                          allow_nan_stats=allow_nan_stats,
                                          name="ZeroInflatedRate")
      elif not isinstance(inflated_distribution, distribution.Distribution):
        raise TypeError("inflated_distribution must be a Distribution instance"
                        " but saw: %s" % inflated_distribution)
      self._inflated_distribution = inflated_distribution

      if self._count_distribution.batch_shape.ndims is None:
        raise ValueError(
            "Expected to know rank(batch_shape) from count_distribution")
      if self._inflated_distribution.batch_shape.ndims is None:
        raise ValueError(
            "Expected to know rank(batch_shape) from inflated_distribution")

      # Align batch ranks: surplus rightmost batch dims of the inflation
      # distribution become event dims so both parts share a batch shape.
      inflated_batch_ndims = self._inflated_distribution.batch_shape.ndims
      count_batch_ndims = self._count_distribution.batch_shape.ndims
      if count_batch_ndims < inflated_batch_ndims:
        self._inflated_distribution = Independent(
            self._inflated_distribution,
            reinterpreted_batch_ndims=inflated_batch_ndims - count_batch_ndims,
            name="ZeroInflatedRate")
      elif count_batch_ndims > inflated_batch_ndims:
        raise ValueError("count_distribution has %d-D batch_shape, which is "
                         "larger than the %d-D batch_shape of "
                         "inflated_distribution" %
                         (count_batch_ndims, inflated_batch_ndims))

      # Ensure that the batch shapes agree at runtime.
      if validate_args:
        self._runtime_assertions.append(
            tf.assert_equal(
                self._count_distribution.batch_shape_tensor(),
                self._inflated_distribution.batch_shape_tensor(),
                message=(
                    "dist batch shape must match logits|probs batch shape")))

      # The mixture is only fully reparameterized if both parts are.
      reparameterization_type = [
          self._count_distribution.reparameterization_type,
          self._inflated_distribution.reparameterization_type
      ]
      if any(i == reparameterization.NOT_REPARAMETERIZED
             for i in reparameterization_type):
        reparameterization_type = reparameterization.NOT_REPARAMETERIZED
      else:
        reparameterization_type = reparameterization.FULLY_REPARAMETERIZED

      # The zero-inflated distribution reuses both parts' _graph_parents
      # since it is arguably more like a base class.
      super(ZeroInflated, self).__init__(
          dtype=self._count_distribution.dtype,
          reparameterization_type=reparameterization_type,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          graph_parents=self._count_distribution._graph_parents +
          self._inflated_distribution._graph_parents,
          name=name)

  @property
  def logits(self):
    """Log-odds of a `1` outcome (vs `0`), as passed to the rate distribution.

    May be `None` if the rate was parameterised by `probs`.
    """
    if isinstance(self._inflated_distribution, Independent):
      return self._inflated_distribution.distribution.logits
    return self._inflated_distribution.logits

  @property
  def probs(self):
    """Probability of a `1` outcome (vs `0`), as passed to the rate
    distribution.

    May be `None` if the rate was parameterised by `logits`; see
    `_inflation_probs` for a value that is always available.
    """
    if isinstance(self._inflated_distribution, Independent):
      return self._inflated_distribution.distribution.probs
    return self._inflated_distribution.probs

  def _inflation_probs(self):
    """Probability of an excess zero, derived from logits when necessary."""
    rate = self._inflated_distribution
    if isinstance(rate, Independent):
      rate = rate.distribution
    # TFP distributions expose `probs_parameter()`, which works whether they
    # were built from `probs` or `logits`; `.probs` is None in the latter case.
    if hasattr(rate, 'probs_parameter'):
      return rate.probs_parameter()
    return rate.probs

  @property
  def count_distribution(self):
    return self._count_distribution

  @property
  def inflated_distribution(self):
    return self._inflated_distribution

  def _batch_shape_tensor(self):
    return self._count_distribution._batch_shape_tensor()

  def _batch_shape(self):
    return self._count_distribution._batch_shape()

  def _event_shape_tensor(self):
    return self._count_distribution._event_shape_tensor()

  def _event_shape(self):
    return self._count_distribution._event_shape()

  def _mean(self):
    """(1 - pi) * mean(count_distribution)."""
    with tf.compat.v1.control_dependencies(self._runtime_assertions):
      # These should all be the same shape by virtue of matching
      # batch_shape and event_shape.
      probs, d_mean = _broadcast_rate(self._inflation_probs(),
                                      self._count_distribution.mean())
      return (1 - probs) * d_mean

  def _variance(self):
    """(1 - pi) * (d.var + d.mean^2) - [(1 - pi) * d.mean]^2

    where pi is the zero-inflation rate and d is the count distribution
    (note mean(ZeroInflated) = (1 - pi) * d.mean).
    """
    with tf.compat.v1.control_dependencies(self._runtime_assertions):
      # These should all be the same shape by virtue of matching
      # batch_shape and event_shape.
      d = self._count_distribution
      probs, d_mean, d_variance = _broadcast_rate(self._inflation_probs(),
                                                  d.mean(), d.variance())
      return (1 - probs) * \
        (d_variance + tf.square(d_mean)) - \
        tf.square(self._mean())

  def _log_prob(self, x):
    with tf.compat.v1.control_dependencies(self._runtime_assertions):
      x = tf.convert_to_tensor(x, name="x")
      d = self._count_distribution
      pi = self._inflation_probs()
      d_prob = d.prob(x)
      d_log_prob = d.log_prob(x)
      # make pi and everything coming out of count_distribution
      # broadcast-able
      pi, d_prob, d_log_prob = _broadcast_rate(pi, d_prob, d_log_prob)
      # Equation (13) of the reference: u_{ij} = 1 - pi_{ij}
      # (`tf.log` no longer exists in TF2; use `tf.math.*`.)
      y_0 = tf.math.log(pi + (1 - pi) * d_prob)
      y_1 = tf.math.log1p(-pi) + d_log_prob
      return tf.where(tf.equal(x, 0), y_0, y_1)

  def _prob(self, x):
    return tf.exp(self._log_prob(x))

  def _sample_n(self, n, seed):
    with tf.compat.v1.control_dependencies(self._runtime_assertions):
      seed = seed_stream.SeedStream(seed, salt="ZeroInflated")
      mask = self.inflated_distribution.sample(n, seed())
      samples = self.count_distribution.sample(n, seed())
      mask, samples = _broadcast_rate(mask, samples)
      # mask = 1 => new_sample = 0
      # mask = 0 => new_sample = sample
      return samples * tf.cast(1 - mask, samples.dtype)

  # ******************** shortcut for denoising ******************** #
  def denoised_mean(self):
    """Mean of the count distribution alone (zero inflation removed)."""
    return self.count_distribution.mean()

  def denoised_variance(self):
    """Variance of the count distribution alone (zero inflation removed)."""
    return self.count_distribution.variance()
class ZeroInflated(distribution.Distribution):
  """Zero-inflated distribution.

  The `zero-inflated` object implements batched zero-inflated distributions.
  The model combines a zero-inflation rate (the `inflated_distribution`, a
  Bernoulli-like distribution whose `1` outcome marks an "excess zero") with a
  `count_distribution` that generates the remaining values.

  Methods implemented here: `log_prob`, `prob`, `mean`, `variance` and
  `sample`, plus the `denoised_mean` / `denoised_variance` shortcuts that
  return the statistics of the underlying count distribution.
  """

  def __init__(self,
               count_distribution,
               inflated_distribution=None,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="ZeroInflated"):
    """Initialize a zero-inflated distribution.

    A `ZeroInflated` is defined by a zero-inflation rate
    (`inflated_distribution`, representing the probabilities of excess zeros)
    and a `Distribution` object having matching dtype, batch shape, event
    shape, and continuity properties (the dist).

    Parameters
    ----------
    count_distribution : A `tfp.distributions.Distribution` instance.
      The rank of its `batch_shape` must be known statically and must not
      exceed the `batch_shape` rank of the zero-inflation distribution.
    inflated_distribution : `tfp.distributions.Bernoulli`-like instance.
      Manages the probability of excess zeros, the zero-inflated rate.
      If its `batch_shape` has more dimensions than
      `count_distribution.batch_shape`, the extra rightmost dimensions are
      reinterpreted as event dimensions (via `Independent`); if it has fewer,
      a `ValueError` is raised. When `None`, an int32 `Bernoulli` is built
      from `logits` or `probs`.
    logits : An N-D `Tensor` representing the log-odds of an excess zero;
      the probability of an excess zero is sigmoid(logits).
      Only one of `logits` or `probs` should be passed in.
    probs : An N-D `Tensor` representing the probability of an excess zero.
      Each entry in the `Tensor` parameterizes an independent ZeroInflated
      distribution. Only one of `logits` or `probs` should be passed in.
    validate_args : Python `bool`, default `False`.
      Forwarded to the default `Bernoulli` and to the base class. If `True`,
      also adds a runtime assertion that the batch shapes of the count and
      zero-inflation distributions are equal.
    allow_nan_stats : Boolean, default `True`.
      If `False`, raise an exception if a statistic (e.g. mean/mode/etc...)
      is undefined for any batch member. If `True`, batch members with valid
      parameters leading to undefined statistics will return NaN for this
      statistic.
    name : A name for this distribution (optional).

    Raises
    ------
    TypeError
      If `count_distribution` or `inflated_distribution` is not a
      `Distribution` instance.
    ValueError
      If either `batch_shape` rank is unknown, or `count_distribution` has
      more batch dimensions than `inflated_distribution`.

    References
    ----------
    Liu, L. & Blei, D.M.. (2017). Zero-Inflated Exponential Family
      Embeddings. Proceedings of the 34th International Conference on
      Machine Learning, in PMLR 70:2140-2148
    """
    # Must stay the first statement: only the constructor arguments should
    # be recorded in `parameters`.
    parameters = dict(locals())
    self._runtime_assertions = []
    with tf.compat.v1.name_scope(name) as name:
      if not isinstance(count_distribution, distribution.Distribution):
        raise TypeError("count_distribution must be a Distribution instance"
                        " but saw: %s" % count_distribution)
      self._count_distribution = count_distribution

      if inflated_distribution is None:
        inflated_distribution = Bernoulli(logits=logits,
                                          probs=probs,
                                          dtype=tf.int32,
                                          validate_args=validate_args,
                                          allow_nan_stats=allow_nan_stats,
                                          name="ZeroInflatedRate")
      elif not isinstance(inflated_distribution, distribution.Distribution):
        raise TypeError("inflated_distribution must be a Distribution instance"
                        " but saw: %s" % inflated_distribution)
      self._inflated_distribution = inflated_distribution

      if self._count_distribution.batch_shape.ndims is None:
        raise ValueError(
            "Expected to know rank(batch_shape) from count_distribution")
      if self._inflated_distribution.batch_shape.ndims is None:
        raise ValueError(
            "Expected to know rank(batch_shape) from inflated_distribution")

      # Align batch ranks: the extra rightmost batch dimensions of the
      # zero-inflation distribution are reinterpreted as event dimensions so
      # its batch_shape rank matches that of count_distribution.
      inflated_batch_ndims = self._inflated_distribution.batch_shape.ndims
      count_batch_ndims = self._count_distribution.batch_shape.ndims
      if count_batch_ndims < inflated_batch_ndims:
        self._inflated_distribution = Independent(
            self._inflated_distribution,
            reinterpreted_batch_ndims=inflated_batch_ndims - count_batch_ndims,
            name="ZeroInflatedRate")
      elif count_batch_ndims > inflated_batch_ndims:
        raise ValueError(
            "count_distribution has %d-D batch_shape, which is larger "
            "than the %d-D batch_shape of inflated_distribution" %
            (count_batch_ndims, inflated_batch_ndims))

      # Ensure that all batch and event ndims are consistent.
      if validate_args:
        self._runtime_assertions.append(
            tf.assert_equal(
                self._count_distribution.batch_shape_tensor(),
                self._inflated_distribution.batch_shape_tensor(),
                message=(
                    "dist batch shape must match logits|probs batch shape")))

      # The mixture is only fully reparameterized when both of its parts are.
      reparameterization_type = [
          self._count_distribution.reparameterization_type,
          self._inflated_distribution.reparameterization_type
      ]
      if any(i == reparameterization.NOT_REPARAMETERIZED
             for i in reparameterization_type):
        reparameterization_type = reparameterization.NOT_REPARAMETERIZED
      else:
        reparameterization_type = reparameterization.FULLY_REPARAMETERIZED

      # We let the zero-inflated distribution access _graph_parents since it
      # is arguably more like a base class.
      super(ZeroInflated, self).__init__(
          dtype=self._count_distribution.dtype,
          reparameterization_type=reparameterization_type,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          graph_parents=self._count_distribution._graph_parents +
          self._inflated_distribution._graph_parents,
          name=name)

  @property
  def logits(self):
    """Log-odds of a `1` outcome (vs `0`)."""
    if isinstance(self._inflated_distribution, Independent):
      return self._inflated_distribution.distribution.logits_parameter()
    return self._inflated_distribution.logits_parameter()

  @property
  def probs(self):
    """Probability of a `1` outcome (vs `0`)."""
    if isinstance(self._inflated_distribution, Independent):
      return self._inflated_distribution.distribution.probs_parameter()
    return self._inflated_distribution.probs_parameter()

  @property
  def count_distribution(self):
    return self._count_distribution

  @property
  def inflated_distribution(self):
    return self._inflated_distribution

  def _batch_shape_tensor(self):
    return self._count_distribution._batch_shape_tensor()

  def _batch_shape(self):
    return self._count_distribution._batch_shape()

  def _event_shape_tensor(self):
    return self._count_distribution._event_shape_tensor()

  def _event_shape(self):
    return self._count_distribution._event_shape()

  def _mean(self):
    """(1 - pi) * d.mean, with pi the zero-inflated rate."""
    with tf.compat.v1.control_dependencies(self._runtime_assertions):
      # These should all be the same shape by virtue of matching
      # batch_shape and event_shape.
      probs, d_mean = _broadcast_rate(self.probs,
                                      self._count_distribution.mean())
      return (1 - probs) * d_mean

  def _variance(self):
    """
    (1 - pi) * (d.var + d.mean^2) - [(1 - pi) * d.mean]^2

    Note: mean(ZeroInflated) = (1 - pi) * d.mean
    where:
     - pi is zero-inflated rate
     - d is count distribution
    """
    with tf.compat.v1.control_dependencies(self._runtime_assertions):
      # These should all be the same shape by virtue of matching
      # batch_shape and event_shape.
      d = self._count_distribution
      probs, d_mean, d_variance = _broadcast_rate(self.probs, d.mean(),
                                                  d.variance())
      return (1 - probs) * \
        (d_variance + tf.square(d_mean)) - \
          tf.math.square(self._mean())

  def _log_prob(self, x):
    """Log-likelihood of `x` under the zero-inflated model.

    For values `x <= 1e-8` (treated as zero):
      log(pi + (1 - pi) * p(x))
    otherwise:
      log(1 - pi) + log p(x)
    where `pi` is the zero-inflated rate and `p` the count distribution.
    """
    with tf.compat.v1.control_dependencies(self._runtime_assertions):
      # Convert first so that `x.dtype` is also available for Python lists
      # and scalars.
      x = tf.convert_to_tensor(x, name="x")
      eps = tf.cast(1e-8, x.dtype)
      d = self._count_distribution
      pi = self.probs
      log_prob = d.log_prob(x)
      prob = tf.math.exp(log_prob)
      # make pi and anything come out of count_distribution
      # broadcast-able
      pi, prob, log_prob = _broadcast_rate(pi, prob, log_prob)
      # This equation is validated
      # Equation (13) reference: u_{ij} = 1 - pi_{ij}
      y_0 = tf.math.log(pi + (1 - pi) * prob)
      y_1 = tf.math.log(1 - pi) + log_prob
      return tf.where(x <= eps, y_0, y_1)

  def _prob(self, x):
    return tf.math.exp(self._log_prob(x))

  def _sample_n(self, n, seed):
    """Draw `n` samples: count samples zeroed wherever the gate fires."""
    with tf.compat.v1.control_dependencies(self._runtime_assertions):
      seed = SeedStream(seed, salt="ZeroInflated")
      mask = self.inflated_distribution.sample(n, seed())
      samples = self.count_distribution.sample(n, seed())
      mask, samples = _broadcast_rate(mask, samples)
      # mask = 1 => new_sample = 0
      # mask = 0 => new_sample = sample
      return samples * tf.cast(1 - mask, samples.dtype)

  # ******************** shortcut for denoising ******************** #
  def denoised_mean(self):
    """Mean of the underlying count distribution (no zero inflation)."""
    return self.count_distribution.mean()

  def denoised_variance(self):
    """Variance of the underlying count distribution (no zero inflation)."""
    return self.count_distribution.variance()
def __init__(self,
             count_distribution,
             inflated_distribution=None,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name="ZeroInflated"):
  """Initialize a zero-inflated distribution.

  A `ZeroInflated` is defined by a zero-inflation rate
  (`inflated_distribution`, representing the probabilities of excess zeros)
  and a `Distribution` object having matching dtype, batch shape, event
  shape, and continuity properties (the dist).

  Parameters
  ----------
  count_distribution : A `tfp.distributions.Distribution` instance.
    The rank of its `batch_shape` must be known statically and must not
    exceed the `batch_shape` rank of the zero-inflation distribution.
  inflated_distribution : `tfp.distributions.Bernoulli`-like instance.
    Manages the probability of excess zeros, the zero-inflated rate.
    If its `batch_shape` has more dimensions than
    `count_distribution.batch_shape`, the extra rightmost dimensions are
    reinterpreted as event dimensions (via `Independent`); if it has fewer,
    a `ValueError` is raised. When `None`, an int32 `Bernoulli` is built
    from `logits` or `probs`.
  logits : An N-D `Tensor` representing the log-odds of an excess zero;
    the probability of an excess zero is sigmoid(logits).
    Only one of `logits` or `probs` should be passed in.
  probs : An N-D `Tensor` representing the probability of an excess zero.
    Each entry in the `Tensor` parameterizes an independent ZeroInflated
    distribution. Only one of `logits` or `probs` should be passed in.
  validate_args : Python `bool`, default `False`.
    Forwarded to the default `Bernoulli` and to the base class. If `True`,
    also adds a runtime assertion that the batch shapes of the count and
    zero-inflation distributions are equal.
  allow_nan_stats : Boolean, default `True`.
    If `False`, raise an exception if a statistic (e.g. mean/mode/etc...)
    is undefined for any batch member. If `True`, batch members with valid
    parameters leading to undefined statistics will return NaN for this
    statistic.
  name : A name for this distribution (optional).

  Raises
  ------
  TypeError
    If `count_distribution` or `inflated_distribution` is not a
    `Distribution` instance.
  ValueError
    If either `batch_shape` rank is unknown, or `count_distribution` has
    more batch dimensions than `inflated_distribution`.

  References
  ----------
  Liu, L. & Blei, D.M.. (2017). Zero-Inflated Exponential Family
    Embeddings. Proceedings of the 34th International Conference on
    Machine Learning, in PMLR 70:2140-2148
  """
  # Must stay the first statement: only the constructor arguments should be
  # recorded in `parameters`.
  parameters = dict(locals())
  self._runtime_assertions = []
  with tf.compat.v1.name_scope(name) as name:
    if not isinstance(count_distribution, distribution.Distribution):
      raise TypeError("count_distribution must be a Distribution instance"
                      " but saw: %s" % count_distribution)
    self._count_distribution = count_distribution

    if inflated_distribution is None:
      inflated_distribution = Bernoulli(logits=logits,
                                        probs=probs,
                                        dtype=tf.int32,
                                        validate_args=validate_args,
                                        allow_nan_stats=allow_nan_stats,
                                        name="ZeroInflatedRate")
    elif not isinstance(inflated_distribution, distribution.Distribution):
      raise TypeError("inflated_distribution must be a Distribution instance"
                      " but saw: %s" % inflated_distribution)
    self._inflated_distribution = inflated_distribution

    if self._count_distribution.batch_shape.ndims is None:
      raise ValueError(
          "Expected to know rank(batch_shape) from count_distribution")
    if self._inflated_distribution.batch_shape.ndims is None:
      raise ValueError(
          "Expected to know rank(batch_shape) from inflated_distribution")

    # Align batch ranks: the extra rightmost batch dimensions of the
    # zero-inflation distribution are reinterpreted as event dimensions so
    # its batch_shape rank matches that of count_distribution.
    inflated_batch_ndims = self._inflated_distribution.batch_shape.ndims
    count_batch_ndims = self._count_distribution.batch_shape.ndims
    if count_batch_ndims < inflated_batch_ndims:
      self._inflated_distribution = Independent(
          self._inflated_distribution,
          reinterpreted_batch_ndims=inflated_batch_ndims - count_batch_ndims,
          name="ZeroInflatedRate")
    elif count_batch_ndims > inflated_batch_ndims:
      raise ValueError(
          "count_distribution has %d-D batch_shape, which is larger "
          "than the %d-D batch_shape of inflated_distribution" %
          (count_batch_ndims, inflated_batch_ndims))

    # Ensure that all batch and event ndims are consistent.
    if validate_args:
      self._runtime_assertions.append(
          tf.assert_equal(
              self._count_distribution.batch_shape_tensor(),
              self._inflated_distribution.batch_shape_tensor(),
              message=(
                  "dist batch shape must match logits|probs batch shape")))

    # The mixture is only fully reparameterized when both of its parts are.
    reparameterization_type = [
        self._count_distribution.reparameterization_type,
        self._inflated_distribution.reparameterization_type
    ]
    if any(i == reparameterization.NOT_REPARAMETERIZED
           for i in reparameterization_type):
      reparameterization_type = reparameterization.NOT_REPARAMETERIZED
    else:
      reparameterization_type = reparameterization.FULLY_REPARAMETERIZED

    # We let the zero-inflated distribution access _graph_parents since it
    # is arguably more like a base class.
    super(ZeroInflated, self).__init__(
        dtype=self._count_distribution.dtype,
        reparameterization_type=reparameterization_type,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=self._count_distribution._graph_parents +
        self._inflated_distribution._graph_parents,
        name=name)
def _create_dist(params, event_ndims, dtype):
  """Turn a packed parameter tensor into an `Independent` Normal.

  The last axis of `params` is split into two equal halves: the first half
  is the location, the second half is mapped through softplus and shifted
  by `exp(-7)` to give a scale bounded below by that floor.

  Args:
    params: tensor whose last axis holds `[loc, raw_scale]` concatenated;
      `tf.split` requires that axis to have even size.
    event_ndims: number of rightmost batch dimensions of the Normal that are
      reinterpreted as event dimensions.
    dtype: dtype the `exp(-7)` scale floor is cast to; presumably the dtype
      of `params` (TODO confirm at call sites).

  Returns:
    `Independent(Normal(loc, scale), reinterpreted_batch_ndims=event_ndims)`.
  """
  loc, raw_scale = tf.split(params, num_or_size_splits=2, axis=-1)
  scale_floor = tf.cast(tf.exp(-7.), dtype)
  positive_scale = tf.nn.softplus(raw_scale) + scale_floor
  return Independent(Normal(loc, positive_scale),
                     reinterpreted_batch_ndims=event_ndims)
np.isclose(p.numpy(), tf.concat((p1, p2), axis=0).numpy())) except NotImplementedError: pass shape = (8, 2) count = np.random.randint(0, 20, size=shape).astype('float32') probs = np.random.rand(*shape).astype('float32') logits = np.random.rand(*shape).astype('float32') assert_consistent_statistics(Bernoulli(probs=probs), Bernoulli(logits=logits)) assert_consistent_statistics(Bernoulli(logits=logits), Bernoulli(logits=logits)) assert_consistent_statistics( Independent(Bernoulli(probs=probs), reinterpreted_batch_ndims=1), Independent(Bernoulli(logits=logits), reinterpreted_batch_ndims=1)) assert_consistent_statistics( NegativeBinomial(total_count=count, logits=logits), NegativeBinomial(total_count=count, probs=probs)) assert_consistent_statistics( Independent(NegativeBinomial(total_count=count, logits=logits), reinterpreted_batch_ndims=1), Independent(NegativeBinomial(total_count=count, probs=probs), reinterpreted_batch_ndims=1)) assert_consistent_statistics( ZeroInflated(NegativeBinomial(total_count=count, logits=logits), logits=logits), ZeroInflated(NegativeBinomial(total_count=count, probs=probs), probs=probs))
def __init__(self,
             count_distribution,
             inflated_distribution=None,
             logits=None,
             probs=None,
             eps=1e-8,
             validate_args=False,
             allow_nan_stats=True,
             name="ZeroInflated"):
  r"""Build a zero-inflated distribution on top of a count distribution.

  A `ZeroInflated` is defined by a zero-inflation rate
  (`inflated_distribution`, representing the probabilities of excess zeros)
  and a `Distribution` object having matching dtype, batch shape, event
  shape, and continuity properties (the dist).

  Arguments:
    count_distribution : A `tfp.distributions.Distribution` instance; the
      rank of its `batch_shape` must be known statically.
    inflated_distribution: `tfp.distributions.Bernoulli`-like instance.
      Manages the probability of excess zeros, the zero-inflated rate.
      When its event rank is lower than that of `count_distribution`, its
      rightmost batch dimensions are reinterpreted as event dimensions via
      `Independent` until the event ranks match. When `None`, an int32
      `Bernoulli` is built from `logits` or `probs`.
    logits: An N-D `Tensor` representing the log-odds of an excess zero;
      the probability of an excess zero is sigmoid(logits).
      Only one of `logits` or `probs` should be passed in.
    probs: An N-D `Tensor` representing the probability of an excess zero.
      Each entry in the `Tensor` parameterizes an independent ZeroInflated
      distribution. Only one of `logits` or `probs` should be passed in.
    eps: Python `float` or `Tensor`, default `1e-8`. Converted to a tensor
      (dtype hint: `count_distribution.dtype`) and stored as `self._eps`.
      NOTE(review): presumably the tolerance used elsewhere in the class to
      decide which observed values count as zero; confirm.
    validate_args: Python `bool`, default `False`. Forwarded to the default
      `Bernoulli` and to the base class.
    allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading
      to undefined statistics will return NaN for this statistic.
    name: A name for this distribution (optional).

  Raises:
    TypeError: if `count_distribution` or `inflated_distribution` is not a
      `Distribution` instance.
    ValueError: if the rank of either `batch_shape` is unknown.

  References:
    Liu, L. & Blei, D.M.. (2017). Zero-Inflated Exponential Family
      Embeddings. Proceedings of the 34th International Conference on
      Machine Learning, in PMLR 70:2140-2148
  """
  # Must stay the first statement: only the constructor arguments should be
  # recorded in `parameters`.
  parameters = dict(locals())
  with tf.compat.v1.name_scope(name) as name:
    if not isinstance(count_distribution, distribution.Distribution):
      raise TypeError(
          "count_distribution must be a Distribution instance"
          " but saw: %s" % count_distribution)

    # Either validate the user-supplied gate or build the default Bernoulli.
    if inflated_distribution is not None:
      if not isinstance(inflated_distribution, distribution.Distribution):
        raise TypeError(
            "inflated_distribution must be a Distribution instance"
            " but saw: %s" % inflated_distribution)
    else:
      inflated_distribution = Bernoulli(logits=logits,
                                        probs=probs,
                                        dtype=tf.int32,
                                        validate_args=validate_args,
                                        allow_nan_stats=allow_nan_stats,
                                        name="ZeroInflatedRate")

    # Raise the gate's event rank to match the count distribution's.
    missing_event_dims = (len(count_distribution.event_shape) -
                          len(inflated_distribution.event_shape))
    if missing_event_dims > 0:
      inflated_distribution = Independent(
          inflated_distribution, reinterpreted_batch_ndims=missing_event_dims)

    self._count_distribution = count_distribution
    self._inflated_distribution = inflated_distribution

    if self._count_distribution.batch_shape.ndims is None:
      raise ValueError(
          "Expected to know rank(batch_shape) from count_distribution")
    if self._inflated_distribution.batch_shape.ndims is None:
      raise ValueError(
          "Expected to know rank(batch_shape) from inflated_distribution")

    self._eps = tensor_util.convert_nonref_to_tensor(
        eps, dtype_hint=count_distribution.dtype, name='eps')

    # The reparameterization type mirrors the count distribution only; the
    # gate's own reparameterization type is not consulted here.
    count_reparameterization = self._count_distribution.reparameterization_type
    super(ZeroInflated, self).__init__(
        dtype=self._count_distribution.dtype,
        reparameterization_type=count_reparameterization,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name,
    )
def model_gmmvae3(args: Arguments):
  """Build a `GMMVAE` with 10 mixture components and a standard-normal prior.

  The prior is a `zdim`-dimensional standard Normal whose single batch axis
  is reinterpreted as the event axis. Encoder/decoder networks come from
  `get_networks` for dataset `args.ds`, flat (non-hierarchical) and
  unsupervised. `analytic=False` is forwarded to `GMMVAE`; presumably it
  selects a sampled rather than closed-form KL term (TODO confirm).

  Args:
    args: experiment arguments; reads `args.zdim` and `args.ds`.

  Returns:
    The constructed `GMMVAE` model.
  """
  latent_dim = args.zdim
  standard_normal = Normal(loc=tf.zeros([latent_dim]),
                           scale=tf.ones([latent_dim]))
  prior = Independent(standard_normal, 1)
  networks = get_networks(args.ds,
                          zdim=args.zdim,
                          is_hierarchical=False,
                          is_semi_supervised=False)
  return GMMVAE(n_components=10, prior=prior, analytic=False, **networks)
def __init__(self,
             loc,
             scale,
             logits=None,
             probs=None,
             covariance_type='diag',
             trainable=False,
             validate_args=False,
             allow_nan_stats=True,
             name=None):
  """Build a Gaussian mixture whose components share one covariance form.

  Arguments:
    loc: component means, passed to the component distribution. Its
      second-to-last axis presumably indexes components (checked below via
      `components.batch_shape[-1]`).
    scale: component scales. Interpretation depends on `covariance_type`:
      scalar -> `scale_identity_multiplier`, otherwise `scale_diag` for
      'diag'; for 'tril'/'full' a rank-1 or non-square input is turned into
      a lower-triangular matrix with `FillScaleTriL`; passed as `scale` of a
      `Normal` for 'none'.
    logits, probs: parameters of the `Categorical` mixture weights; only one
      should be given.
    covariance_type: one of 'diag', 'tril', 'full' or 'none'
      (case-insensitive, surrounding whitespace ignored).
    trainable: if `True`, `loc`, `scale` and whichever of `logits`/`probs`
      is given are wrapped in trainable `tf.Variable`s.
    validate_args, allow_nan_stats: forwarded to the mixture, the components
      and the base class.
    name: distribution name; defaults to 'Mixture<Type>Gaussian'
      ('MixtureIndependentGaussian' for 'none').

  Raises:
    ValueError: for an unsupported `covariance_type`.
    AssertionError: if the number of mixture weights differs from the
      number of components.
  """
  kw = dict(validate_args=validate_args, allow_nan_stats=allow_nan_stats)
  self._trainable = bool(trainable)
  self._llk_history = []
  if trainable:
    loc = tf.Variable(loc, trainable=True, name='loc')
    scale = tf.Variable(scale, trainable=True, name='scale')
    if logits is not None:
      logits = tf.Variable(logits, trainable=True, name='logits')
    if probs is not None:
      probs = tf.Variable(probs, trainable=True, name='probs')
  ### initialize mixture Categorical
  mixture = Categorical(logits=logits,
                        probs=probs,
                        name="MixtureWeights",
                        **kw)
  n_components = mixture._num_categories()
  ### initialize Gaussian components
  covariance_type = str(covariance_type).lower().strip()
  if name is None:
    name = 'Mixture%sGaussian' % \
      (covariance_type.capitalize() if covariance_type != 'none' else
       'Independent')
  ## create the components
  if covariance_type == 'diag':
    if tf.rank(scale) == 0:  # scalar
      extra_kw = dict(scale_identity_multiplier=scale)
    else:  # a tensor
      extra_kw = dict(scale_diag=scale)
    components = MultivariateNormalDiag(loc=loc, name=name, **kw, **extra_kw)
  elif covariance_type in ('tril', 'full'):
    if tf.rank(scale) == 1 or \
      (scale.shape[-1] != scale.shape[-2]):
      # `as_numpy_dtype` is the numpy type itself; pass it as the dtype
      # rather than instantiating a numpy scalar from it.
      scale_dtype = tf.convert_to_tensor(scale).dtype.as_numpy_dtype
      scale_tril = FillScaleTriL(diag_shift=np.array(1e-5, dtype=scale_dtype))
      scale = scale_tril(scale)
    components = MultivariateNormalTriL(loc=loc,
                                        scale_tril=scale,
                                        name=name,
                                        **kw)
  elif covariance_type == 'none':
    components = Independent(distribution=Normal(loc=loc, scale=scale, **kw),
                             reinterpreted_batch_ndims=1,
                             name=name)
  else:
    raise ValueError("No support for covariance_type: '%s'" % covariance_type)
  ### validate the n_components
  # A Categorical has a scalar event shape, so the count must come from
  # `n_components`; `%s` also tolerates an unknown (None) dimension.
  assert (components.batch_shape[-1] == int(n_components)), \
    "Number of components mismatch, mixture:%s, components:%s" % \
      (int(n_components), components.batch_shape[-1])
  super().__init__(mixture_distribution=mixture,
                   components_distribution=components,
                   name=name,
                   **kw)