Esempio n. 1
0
 def _mean(self, **kwargs):
     """Return the mean of the logistic mixture, mapped back to the input domain.

     With more than one channel, the location of channel `i` is shifted by
     the (already shifted) locations of every earlier channel `j < i`, each
     scaled by its own coefficient:

       r ~ Logistic(loc_r, scale_r)
       g ~ Logistic(coef_rg * r + loc_g, scale_g)
       b ~ Logistic(coef_rb * r + coef_gb * g + loc_b, scale_b)

     No observed value is available here, so the channel locations are used
     in place of the channel values `r` and `g`.

     `**kwargs` is accepted but not used.
     """
     if self.n_channels == 1:
         mix_logits, locs, scales = self._params
     else:
         mix_logits, locs, scales, coeffs = self._params
         channel_locs = tf.split(locs, self.n_channels, axis=-1)
         pair_coeffs = tf.split(coeffs, self.n_coeffs, axis=-1)
         # (i, j) pairs with j < i, in the order the coefficients are consumed
         channel_pairs = [(i, j)
                          for i in range(self.n_channels)
                          for j in range(i)]
         for k, (i, j) in enumerate(channel_pairs):
             channel_locs[i] = (channel_locs[i] +
                                channel_locs[j] * pair_coeffs[k])
         locs = tf.concat(channel_locs, axis=-1)
     # mixture weights over the components
     mixture_weights = Categorical(logits=mix_logits)
     # Affine map of the parameters onto the pixel range: a location of -1
     # lands on `self.low` and a location of +1 lands on `self.high`.
     locs = self.low + 0.5 * (self.high - self.low) * (locs + 1.)
     scales = scales * 0.5 * (self.high - self.low)
     # continuous logistic shifted by half a pixel (no quantization here)
     shifted_logistic = TransformedDistribution(
         distribution=Logistic(loc=locs, scale=scales),
         bijector=Shift(shift=tf.cast(-0.5, self.dtype)))
     mixture = MixtureSameFamily(
         mixture_distribution=mixture_weights,
         components_distribution=Independent(shifted_logistic,
                                             reinterpreted_batch_ndims=1))
     pixel_mean = Independent(mixture, reinterpreted_batch_ndims=2).mean()
     # convert the pixel-range mean back to the configured input domain
     return _pixels_to(pixel_mean, self.inputs_domain, self.low, self.high)
Esempio n. 2
0
 def _log_prob(self, value: tf.Tensor):
     """Log-probability of `value` under the quantized logistic mixture.

     `value` is first passed through `_switch_domain`, which yields two
     tensors: one used to condition the channel locations and one that is
     scored by the final distribution.

     With more than one channel, a linear autoregressive dependency is
     created among the location parameters of the channels of a single
     pixel (the scale parameters stay independent). For R/G/B:

       r ~ Logistic(loc_r, scale_r)
       g ~ Logistic(coef_rg * r + loc_g, scale_g)
       b ~ Logistic(coef_rb * r + coef_gb * g + loc_b, scale_b)
     """
     params = self._params
     transformed_value, value = _switch_domain(value,
                                               inputs_domain=self.inputs_domain,
                                               low=self.low,
                                               high=self.high)
     if self.n_channels == 1:
         mix_logits, locs, scales = params
     else:
         # Insert an axis just before the channel axis of every channel
         # value (presumably the mixture-component axis -- TODO confirm).
         channel_values = [
             channel[..., tf.newaxis, :]
             for channel in tf.split(transformed_value,
                                     self.n_channels,
                                     axis=-1)
         ]
         mix_logits, locs, scales, coeffs = params
         channel_locs = tf.split(locs, self.n_channels, axis=-1)
         pair_coeffs = tf.split(coeffs, self.n_coeffs, axis=-1)
         # (i, j) pairs with j < i, in the order the coefficients are consumed
         channel_pairs = [(i, j)
                          for i in range(self.n_channels)
                          for j in range(i)]
         for k, (i, j) in enumerate(channel_pairs):
             channel_locs[i] = (channel_locs[i] +
                                channel_values[j] * pair_coeffs[k])
         locs = tf.concat(channel_locs, axis=-1)
     # mixture weights over the components
     mixture_weights = Categorical(logits=mix_logits)
     # Affine map of the parameters onto the pixel range: a location of -1
     # lands on `self.low` and a location of +1 lands on `self.high`.
     locs = self.low + 0.5 * (self.high - self.low) * (locs + 1.)
     scales = scales * 0.5 * (self.high - self.low)
     # logistic shifted by half a pixel, then quantized to [low, high]
     shifted_logistic = TransformedDistribution(
         distribution=Logistic(loc=locs, scale=scales),
         bijector=Shift(shift=tf.cast(-0.5, self.dtype)))
     quantized_logistic = QuantizedDistribution(
         distribution=shifted_logistic,
         low=self.low,
         high=self.high,
     )
     mixture = MixtureSameFamily(
         mixture_distribution=mixture_weights,
         components_distribution=Independent(quantized_logistic,
                                             reinterpreted_batch_ndims=1))
     return Independent(mixture, reinterpreted_batch_ndims=2).log_prob(value)
Esempio n. 3
0
 def new(params,
         temperature,
         probs_input=False,
         event_shape=(),
         validate_args=False,
         name='RelaxedBernoulliLayer'):
   """Build an Independent RelaxedBernoulli from a `params` tensor.

   `params` is reshaped to its leading dimensions followed by `event_shape`
   and is interpreted as probabilities when `probs_input` is truthy,
   otherwise as logits. All `event_shape` dimensions are reinterpreted as
   event dimensions.
   """
   params = tf.convert_to_tensor(value=params, name='params')
   event_shape = dist_util.expand_to_vector(
     tf.convert_to_tensor(value=event_shape,
                          name='event_shape',
                          dtype_hint=tf.int32),
     tensor_name='event_shape',
   )
   # keep every axis but the last, then append the event shape
   leading_dims = tf.shape(input=params)[:-1]
   params = tf.reshape(params, tf.concat([leading_dims, event_shape], axis=0))
   if probs_input:
     relaxed = RelaxedBernoulli(temperature=temperature,
                                logits=None,
                                probs=params,
                                validate_args=validate_args)
   else:
     relaxed = RelaxedBernoulli(temperature=temperature,
                                logits=params,
                                probs=None,
                                validate_args=validate_args)
   return Independent(
     relaxed,
     reinterpreted_batch_ndims=tf.size(input=event_shape),
     name=name,
   )
Esempio n. 4
0
 def new(params,
         event_shape=(),
         given_logits=True,
         validate_args=False,
         name='ZIBernoulliLayer'):
   """Build an Independent zero-inflated Bernoulli from a `params` tensor.

   The last axis of `params` is split into two equal halves. The first half
   parameterizes the Bernoulli (as logits when `given_logits` is truthy,
   otherwise as probabilities); the second half is always used as the
   logits of the zero-inflation rate. Each half is reshaped to the leading
   dimensions of `params` followed by `event_shape`.
   """
   params = tf.convert_to_tensor(value=params, name='params')
   event_shape = dist_util.expand_to_vector(
     tf.convert_to_tensor(value=event_shape,
                          name='event_shape',
                          dtype=tf.int32),
     tensor_name='event_shape',
   )
   leading_dims = tf.shape(input=params)[:-1]
   output_shape = tf.concat([leading_dims, event_shape], axis=0)
   bern_half, rate_half = tf.split(params, 2, axis=-1)
   bern_half = tf.reshape(bern_half, output_shape)
   if given_logits:
     bern = Bernoulli(logits=bern_half,
                      probs=None,
                      validate_args=validate_args)
   else:
     bern = Bernoulli(logits=None,
                      probs=bern_half,
                      validate_args=validate_args)
   zero_inflated = ZeroInflated(count_distribution=bern,
                                logits=tf.reshape(rate_half, output_shape),
                                validate_args=validate_args)
   n_event_dims = tf.size(input=event_shape)
   return Independent(zero_inflated,
                      reinterpreted_batch_ndims=n_event_dims,
                      name=name)
Esempio n. 5
0
def make_gaussian_out(p: tf.Tensor,
                      event_shape: Sequence[int]) -> Independent:
  """Turn a parameter tensor into an Independent Normal over `event_shape`.

  The last axis of `p` is split in two halves: the first becomes the
  location, the second is passed through softplus to obtain a positive
  scale. Both are reshaped to `(-1, *event_shape)`.
  """
  raw_loc, raw_scale = tf.split(p, 2, -1)
  target_shape = (-1, *event_shape)
  loc = tf.reshape(raw_loc, target_shape)
  scale = tf.nn.softplus(tf.reshape(raw_scale, target_shape))
  normal = Normal(loc=loc, scale=scale)
  # every event dimension is reinterpreted from batch to event
  return Independent(normal, len(event_shape))
Esempio n. 6
0
def observation(distribution: str):
    """Return `(n_params, layer)` for the requested observation model.

    `n_params` is 2 for 'qlogistic' and 1 for 'bernoulli'; `layer` is a
    `DistributionLambda` named 'image' that converts to a tensor by
    sampling. Any other name raises `NotImplementedError`.
    """
    if distribution == 'qlogistic':

        def _make_qlogistic(params):
            loc, raw_scale = tf.split(params, 2, -1)
            # Ensure scales are positive and do not collapse to near-zero
            scale = tf.nn.softplus(raw_scale) + tf.cast(
                tf.exp(-7.), tf.float32)
            return QuantizedLogistic(loc,
                                     scale,
                                     low=0,
                                     high=255,
                                     inputs_domain='sigmoid',
                                     reinterpreted_batch_ndims=3)

        layer = DistributionLambda(_make_qlogistic,
                                   convert_to_tensor_fn=Distribution.sample,
                                   name='image')
        return 2, layer

    if distribution == 'bernoulli':

        def _make_bernoulli(params):
            pixels = Bernoulli(logits=params, dtype=tf.float32)
            return Independent(pixels, len(IMAGE_SHAPE))

        layer = DistributionLambda(_make_bernoulli,
                                   convert_to_tensor_fn=Distribution.sample,
                                   name='image')
        return 1, layer

    raise NotImplementedError
Esempio n. 7
0
 def __init__(self, units, **kwargs):
     """Configure a `NormalLayer` posterior with a standard-normal prior.

     The prior is a zero-mean, unit-scale Normal whose parameter shape is
     `units`, with one batch dimension reinterpreted as the event
     dimension. Extra keyword arguments are forwarded to the parent
     constructor.
     """
     standard_normal = Independent(
         Normal(loc=tf.zeros(shape=units), scale=tf.ones(shape=units)), 1)
     super().__init__(units,
                      posterior=NormalLayer,
                      posterior_kwargs={'scale_activation': 'softplus1'},
                      prior=standard_normal,
                      **kwargs)
Esempio n. 8
0
def MixtureQLogistic(
        locs: tf.Tensor,
        scales: tf.Tensor,
        logits: Optional[tf.Tensor] = None,
        probs: Optional[tf.Tensor] = None,
        batch_ndims: int = 0,
        low: int = 0,
        bits: int = 8,
        name: str = 'MixtureQuantizedLogistic') -> MixtureSameFamily:
    """ Mixture of quantized logistic distributions

  Each component is a Logistic shifted by -0.5, quantized to the range
  `[low, 2^bits - 1]`, and wrapped in `Independent`.

  Parameters
  ----------
  locs : tf.Tensor
      locs of all logistic components, shape `[batch_size, n_components, event_size]`
  scales : tf.Tensor
      scales of all logistic components, shape `[batch_size, n_components, event_size]`
  logits, probs : tf.Tensor
      parameters of the mixture Categorical distribution, shape `[batch_size, n_components]`
  batch_ndims : int, optional
      number of batch dimensions of the components that are reinterpreted
      as event dimensions (passed to `Independent` as
      `reinterpreted_batch_ndims`), by default 0
  low : int, optional
      minimum quantized value, by default 0
  bits : int, optional
      number of bits for quantization, the maximum will be `2^bits - 1`, by default 8
  name : str, optional
      distribution name, by default 'MixtureQuantizedLogistic'

  Returns
  -------
  MixtureSameFamily
      the mixture of quantized logistic distributions

  Example
  -------
  ```
  d = MixtureQLogistic(np.ones((12, 3, 8)).astype('float32'),
                       np.ones((12, 3, 8)).astype('float32'),
                       logits=np.random.rand(12, 3).astype('float32'),
                       batch_ndims=1)
  ```

  Reference
  ---------
  Salimans, T., Karpathy, A., Chen, X., Kingma, D.P., 2017.
    PixelCNN++: Improving the PixelCNN with Discretized Logistic Mixture
    Likelihood and Other Modifications. arXiv:1701.05517 [cs, stat].

  """
    mixture_weights = Categorical(probs=probs, logits=logits)
    logistic = Logistic(loc=locs, scale=scales)
    # shift by half a unit before quantizing
    half_shift = Shift(shift=tf.cast(-0.5, logistic.dtype))
    shifted = TransformedDistribution(distribution=logistic,
                                      bijector=half_shift)
    max_value = 2**bits - 1.
    quantized = QuantizedDistribution(shifted, low=low, high=max_value)
    components = Independent(quantized, reinterpreted_batch_ndims=batch_ndims)
    return MixtureSameFamily(mixture_distribution=mixture_weights,
                             components_distribution=components,
                             name=name)
    def __init__(self,
                 data,
                 bandwidth=0.01,
                 kernel=Normal,
                 use_grid=False,
                 use_fft=False,
                 num_grid_points=1024,
                 reparameterize=False,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='KernelDensityEstimation'):
        """Build a kernel density estimate as a mixture of `kernel` components.

        When `use_fft` or `use_grid` is truthy, the components are centred
        on a grid of `num_grid_points` points (`self._generate_grid`) and
        weighted by the binned data (`self._linear_binning`); both flags
        take the same path here. Otherwise one component is centred on
        every data point, with uniform weights `1 / n` where `n` is the
        size of the first axis of `data`.

        `bandwidth` is used as the `scale` of every kernel. The remaining
        arguments are forwarded to the parent constructor.
        """

        def make_components(loc, scale):
            return Independent(kernel(loc=loc, scale=scale))

        with tf.name_scope(name) as name:
            self._use_fft = use_fft

            dtype = dtype_util.common_dtype([bandwidth, data], tf.float32)
            self._bandwidth = tensor_util.convert_nonref_to_tensor(
                bandwidth, name='bandwidth', dtype=dtype)
            self._data = tensor_util.convert_nonref_to_tensor(data,
                                                              name='data',
                                                              dtype=dtype)

            if use_fft or use_grid:
                # grid-based estimate: components sit on the grid points
                self._grid = self._generate_grid(num_grid_points)
                self._grid_data = self._linear_binning()
                weights = Categorical(probs=self._grid_data)
                components = make_components(loc=self._grid,
                                             scale=self._bandwidth)
            else:
                # exact estimate: one equally weighted component per sample
                self._grid = None
                self._grid_data = None
                n = self._data.shape[0]
                weights = Categorical(probs=[1 / n] * n)
                components = make_components(loc=self._data,
                                             scale=self._bandwidth)

            super(KernelDensityEstimation, self).__init__(
                mixture_distribution=weights,
                components_distribution=components,
                reparameterize=reparameterize,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                name=name)
Esempio n. 10
0
def model_fullcov(args: Arguments):
  """Build a `VariationalAutoencoder` named 'FullCov'.

  The 'latents' entry returned by `get_networks` is replaced with a
  posterior created from an `RVconf` using the 'mvntril' posterior and a
  standard-normal prior over the flattened latent event size.
  """
  nets = get_networks(args.ds,
                      zdim=args.zdim,
                      is_hierarchical=False,
                      is_semi_supervised=False)
  # flatten the latent event shape into a single dimension
  zdims = int(np.prod(nets['latents'].event_shape))
  standard_normal = Independent(
    Normal(tf.zeros([zdims]), tf.ones([zdims])), 1)
  latents_conf = RVconf(event_shape=zdims,
                        projection=True,
                        posterior='mvntril',
                        prior=standard_normal,
                        name='latents')
  nets['latents'] = latents_conf.create_posterior()
  return VariationalAutoencoder(**nets, name='FullCov')
Esempio n. 11
0
 def new(dists):
     """Combine two distributions into one Normal by precision weighting.

     `dists` is a pair `(q_e, q_d)`. The result has mean
     `(mu_e * prec_e + mu_d * prec_d) / (prec_e + prec_d)` and scale
     `sqrt(1 / (prec_e + prec_d))`, where each precision is the reciprocal
     of the corresponding variance. If `q_e` is an `Independent`, the
     result is wrapped with the same `reinterpreted_batch_ndims`.
     """
     q_e, q_d = dists
     mu_e, mu_d = q_e.mean(), q_d.mean()
     prec_e = 1 / q_e.variance()
     prec_d = 1 / q_d.variance()
     total_prec = prec_e + prec_d
     combined = Normal(loc=(mu_e * prec_e + mu_d * prec_d) / total_prec,
                       scale=tf.math.sqrt(1 / total_prec))
     if not isinstance(q_e, Independent):
         return combined
     return Independent(
         combined, reinterpreted_batch_ndims=q_e.reinterpreted_batch_ndims)
Esempio n. 12
0
    def set_prior(self,
                  loc=0.,
                  log_scale=np.log(np.expm1(1)),
                  mixture_logits=None):
        r""" Set the prior for the mixture density network

        Parameters
        ----------
        loc : Scalar or Tensor with shape `[n_components, event_size]`
            component locations; a scalar is broadcast to
            `[n_components, event_size]`.
        log_scale : Scalar or Tensor with shape
            `[n_components, event_size]` for 'none' and 'diag' covariance,
            and `[n_components, event_size * (event_size + 1) // 2]` for
            'full' covariance. The value is passed through softplus before
            being used as a scale, so the default `log(expm1(1))` yields a
            unit scale. A scalar is broadcast to the required shape.
        mixture_logits : Scalar or Tensor with shape `[n_components]`
            logits of the mixture Categorical; a scalar is broadcast.
            When `None`, the same logit is used for every component
            (a uniform mixture).

        Returns
        -------
        self, so calls can be chained.

        Raises
        ------
        ValueError
            if `self.covariance` is not one of 'diag', 'none' or 'full'.
        """
        n_components = self.n_components
        event_size = self.event_size
        if self.covariance == 'diag':
            scale_shape = [n_components, event_size]
            fn = lambda l, s: MultivariateNormalDiag(
                loc=l, scale_diag=tf.nn.softplus(s))
        elif self.covariance == 'none':
            scale_shape = [n_components, event_size]
            fn = lambda l, s: Independent(
                Normal(loc=l, scale=tf.math.softplus(s)), 1)
        elif self.covariance == 'full':
            scale_shape = [n_components, event_size * (event_size + 1) // 2]
            fn = lambda l, s: MultivariateNormalTriL(
                loc=l,
                scale_tril=FillScaleTriL(diag_shift=1e-5)(tf.math.softplus(s)))
        else:
            # Previously this fell through and failed later on an unbound
            # local; fail early with an explicit message instead.
            raise ValueError(
                f"Unsupported covariance type: {self.covariance!r}; "
                "expected 'diag', 'none' or 'full'")

        def _is_scalar(x):
            # Python numbers and rank-0 tensors both need broadcasting
            return isinstance(x, Number) or tf.rank(x) == 0

        # Each argument is broadcast based on ITS OWN rank. The `loc` check
        # used to test `log_scale`, so `loc` was filled (or left alone)
        # according to the shape of the wrong argument.
        if _is_scalar(loc):
            loc = tf.fill([n_components, event_size], loc)
        if _is_scalar(log_scale):
            log_scale = tf.fill(scale_shape, log_scale)
        if mixture_logits is None:
            if n_components > 1:
                p = 1. / n_components
                mixture_logits = np.log(p / (1. - p))
            else:
                # With a single component `p / (1 - p)` divides by zero;
                # any finite logit describes the same one-component mixture.
                mixture_logits = 0.
        if _is_scalar(mixture_logits):
            mixture_logits = tf.fill([n_components], mixture_logits)
        #
        loc = tf.cast(loc, self.dtype)
        log_scale = tf.cast(log_scale, self.dtype)
        mixture_logits = tf.cast(mixture_logits, self.dtype)
        self._prior = MixtureSameFamily(
            components_distribution=fn(loc, log_scale),
            mixture_distribution=Categorical(logits=mixture_logits),
            name="prior")
        return self
Esempio n. 13
0
 def __init__(self,
              loc: Union[tf.Tensor, np.ndarray],
              scale: Union[tf.Tensor, np.ndarray],
              low: Union[None, Number] = 0,
              high: Union[None, Number] = 2**8 - 1,
              inputs_domain: Literal['sigmoid', 'tanh',
                                     'pixel'] = 'sigmoid',
              reinterpreted_batch_ndims: Optional[int] = None,
              validate_args: bool = False,
              allow_nan_stats: bool = True,
              name: str = 'QuantizedLogistic'):
     """Build a quantized logistic distribution.

     When both `low` and `high` are given, `loc` and `scale` are mapped
     onto the range `[low, high]`: a location of -1 lands on `low`, a
     location of +1 lands on `high`, and `scale` is multiplied by half the
     range. The resulting Logistic is shifted by -0.5 and quantized with
     bounds `low` and `high`. If `reinterpreted_batch_ndims` is not `None`
     the quantized distribution is wrapped in `Independent`.
     """
     # must stay the first statement so only the constructor arguments
     # are captured
     parameters = dict(locals())
     with tf.name_scope(name) as name:
         dtype = dtype_util.common_dtype([loc, scale, low, high],
                                         dtype_hint=tf.float32)
         self._low, self._high = low, high
         has_range = low is not None and high is not None
         if has_range:
             half_range = 0.5 * (high - low)
             loc = low + half_range * (loc + 1.)
             scale = scale * half_range
         self._logistic = Logistic(loc=loc,
                                   scale=scale,
                                   validate_args=validate_args,
                                   allow_nan_stats=allow_nan_stats,
                                   name=name)
         half_pixel_shift = Shift(tf.cast(-0.5, dtype=dtype))
         quantized = QuantizedDistribution(
             distribution=TransformedDistribution(
                 distribution=self._logistic, bijector=half_pixel_shift),
             low=low,
             high=high,
             validate_args=validate_args,
             name=name)
         if reinterpreted_batch_ndims is None:
             self._dist = quantized
         else:
             self._dist = Independent(
                 quantized,
                 reinterpreted_batch_ndims=reinterpreted_batch_ndims)
         self.inputs_domain = inputs_domain
         super(QuantizedLogistic,
               self).__init__(dtype=dtype,
                              reparameterization_type=NOT_REPARAMETERIZED,
                              validate_args=validate_args,
                              allow_nan_stats=allow_nan_stats,
                              parameters=parameters,
                              name=name)
Esempio n. 14
0
 def __init__(self,
              units: int,
              prior_loc: float = 0.,
              prior_scale: float = 1.,
              projection: bool = True,
              name: str = "Latents",
              **kwargs):
     """Configure a `NormalLayer` posterior over `units` latent dimensions.

     The prior is an `Independent` Normal whose location and scale are
     filled with `prior_loc` and `prior_scale`. Extra keyword arguments are
     forwarded to the parent constructor.
     """
     event_shape = (int(units), )
     prior = Independent(Normal(loc=tf.fill((units, ), prior_loc),
                                scale=tf.fill((units, ), prior_scale)),
                         reinterpreted_batch_ndims=1)
     super().__init__(
         event_shape=event_shape,
         posterior=NormalLayer,
         posterior_kwargs={'scale_activation': 'softplus'},
         prior=prior,
         projection=projection,
         name=name,
         **kwargs,
     )
Esempio n. 15
0
 def encode(self,
            inputs,
            library=None,
            training=None,
            mask=None,
            sample_shape=(),
            **kwargs):
     """Encode `inputs` and set the prior of the last latent's KL term.

     The parent `encode` is called first (`**kwargs` is not forwarded).
     When `library` is given, its first flattened element is split in two
     along axis 1 into a mean and a variance, which define an Independent
     Normal prior; otherwise the prior is set to `None`. The prior is
     assigned to `qZ_X[-1].KL_divergence.prior`.
     """
     qZ_X = super().encode(inputs=inputs,
                           library=library,
                           training=training,
                           mask=mask,
                           sample_shape=sample_shape)
     pL = None
     if library is not None:
         first_library = tf.nest.flatten(library)[0]
         mean, var = tf.split(first_library, 2, axis=1)
         # the scale is the square root of the supplied variance
         pL = Independent(Normal(loc=mean, scale=tf.math.sqrt(var)), 1)
     qZ_X[-1].KL_divergence.prior = pL
     return qZ_X
Esempio n. 16
0
 def build(self, input_shape=None):
     """Create the posterior projection, the prior and the output layer.

     Does nothing beyond the parent `build` when `self._disable` is set.
     Otherwise builds, in order: a posterior convolution on the wrapped
     layer's output shape, a `DistributionLambda` turning its output into
     the posterior, a standard-normal prior over the inferred latent shape,
     and an output layer obtained from `_upsample_by_conv`.

     Raises `AssertionError` if `self.encoder` is `None`.
     """
     super().build(input_shape)
     if self._disable:
         return
     # output shape of the wrapped layer; used as the input shape below
     decoder_shape = self.layer.compute_output_shape(input_shape)
     # pair of layer classes selected by input rank (presumably a
     # convolution and its transposed counterpart -- TODO confirm)
     layer, layer_t = _NDIMS_CONV[self.input_ndim]
     # === 1. create projection layer
     assert self.encoder is not None, \
       'ParallelLatents require encoder to be specified'
     # posterior projection (assume encoder shape and decoder shape the same)
     self._conv_posterior = layer(**self._network_kw, name='ConvPosterior')
     self._conv_posterior.build(decoder_shape)
     # === 2. distribution
     params_shape = self._conv_posterior.compute_output_shape(decoder_shape)
     # every axis except the first is treated as an event dimension
     self._dist_posterior = DistributionLambda(
         make_distribution_fn=partial(_create_dist,
                                      event_ndims=len(params_shape) - 1,
                                      dtype=self.dtype),
         name=f'{self.name}_posterior')
     self._dist_posterior.build(params_shape)
     # dynamically infer the shape by running the posterior on a symbolic
     # input and converting the result to a tensor
     latents_shape = tf.convert_to_tensor(
         self._dist_posterior(keras.layers.Input(params_shape[1:]))).shape
     # drop the first axis (presumably the batch axis -- TODO confirm)
     self._latents_shape = latents_shape[1:]
     # create the prior N(0,I)
     self._prior = Independent(Normal(loc=tf.zeros(self.latents_shape,
                                                   dtype=self.dtype),
                                      scale=tf.ones(self.latents_shape,
                                                    dtype=self.dtype)),
                               reinterpreted_batch_ndims=len(
                                   self.latents_shape),
                               name=f'{self.name}_prior')
     # === 3. final output affine
     # NOTE: `input_shape` receives the full `latents_shape` (first axis
     # included), and the kernel/padding/strides mirror the posterior conv.
     self._conv_out = _upsample_by_conv(
         layer,
         layer_t,
         input_shape=latents_shape,
         output_shape=decoder_shape,
         kernel_size=self._conv_posterior.kernel_size,
         padding=self._conv_posterior.padding,
         strides=self._conv_posterior.strides)
Esempio n. 17
0
 def __init__(self,
              units: int,
              prior_loc: float = 0.,
              prior_scale: float = 1.,
              projection: bool = True,
              name: str = "Latents",
              **kwargs):
     """Configure a diagonal `MultivariateNormalLayer` posterior.

     The prior is an `Independent` Normal whose location and scale are
     filled with `prior_loc` and `prior_scale` over `units` dimensions.
     Extra keyword arguments are forwarded to the parent constructor.
     """
     event_shape = (int(units), )
     # NOTE: the prior is built as Independent(Normal); the earlier
     # MultivariateNormalDiag variant sketched here was not in use.
     prior = Independent(Normal(loc=tf.fill((units, ), prior_loc),
                                scale=tf.fill((units, ), prior_scale)),
                         reinterpreted_batch_ndims=1)
     posterior_kwargs = {
         'covariance': 'diag',
         'scale_activation': tf.nn.softplus,
     }
     super().__init__(
         event_shape=event_shape,
         posterior=MultivariateNormalLayer,
         posterior_kwargs=posterior_kwargs,
         prior=prior,
         projection=projection,
         name=name,
         **kwargs,
     )
Esempio n. 18
0
 def new(params,
         event_shape=(),
         dtype=None,
         validate_args=False,
         continuous=False,
         lims=(0.499, 0.501),
         name='BernoulliLayer'):
   """Build an Independent (Continuous)Bernoulli from a `params` tensor.

   `params` is reshaped to its leading dimensions followed by `event_shape`
   and used as logits. A `ContinuousBernoulli` (with `lims`) is created
   when `continuous` is truthy, otherwise a `Bernoulli`. The output dtype
   falls back to the base dtype of `params` when `dtype` is falsy. The
   returned `Independent` also carries `logits` and `probs` attributes
   copied from the wrapped distribution.
   """
   params = tf.convert_to_tensor(value=params, name='params')
   event_shape = dist_util.expand_to_vector(
     tf.convert_to_tensor(value=event_shape,
                          name='event_shape',
                          dtype_hint=tf.int32),
     tensor_name='event_shape',
   )
   leading_dims = tf.shape(input=params)[:-1]
   target_shape = tf.concat([leading_dims, event_shape], axis=0)
   logits = tf.reshape(params, target_shape)
   out_dtype = dtype or params.dtype.base_dtype
   if continuous:
     base = ContinuousBernoulli(logits=logits,
                                dtype=out_dtype,
                                lims=lims,
                                validate_args=validate_args)
   else:
     base = Bernoulli(logits=logits,
                      dtype=out_dtype,
                      validate_args=validate_args)
   dist = Independent(
     base,
     reinterpreted_batch_ndims=tf.size(input=event_shape),
     name=name,
   )
   # expose the wrapped distribution's parameters on the Independent wrapper
   dist.logits = dist.distribution._logits  # pylint: disable=protected-access
   dist.probs = dist.distribution._probs  # pylint: disable=protected-access
   return dist
Esempio n. 19
0
  def __init__(self,
               count_distribution,
               inflated_distribution=None,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="ZeroInflated"):
    """Initialize a zero-inflated distribution.

    A `ZeroInflated` is defined by a zero-inflation rate
    (`inflated_distribution`, representing the probabilities of excess
    zeros) and a `Distribution` object (`count_distribution`).

    Parameters
    ----------
    count_distribution : A `tfp.distributions.Distribution` instance.
      The rank of its `batch_shape` must be statically known and must not
      exceed the rank of the zero-inflation distribution's `batch_shape`.

    inflated_distribution: `tfp.distributions.Bernoulli`-like instance.
      Manages the probability of excess zeros, the zero-inflated rate.
      When omitted, an int32 `Bernoulli` named "ZeroInflatedRate" is built
      from `logits` / `probs`. If its `batch_shape` has more dimensions
      than `count_distribution.batch_shape`, the extra dimensions are
      reinterpreted as event dimensions via `Independent`.

    logits: An N-D `Tensor` representing the log-odds of a excess zeros
      A zero-inflation rate, where the probability of excess zeros is
      sigmoid(logits).
      Only one of `logits` or `probs` should be passed in.

    probs: An N-D `Tensor` representing the probability of a zero event.
      Each entry in the `Tensor` parameterizes an independent
      ZeroInflated distribution.
      Only one of `logits` or `probs` should be passed in.

    validate_args: Python `bool`, default `False`. If `True`, a runtime
      assertion that both distributions have equal batch shapes is
      recorded in `self._runtime_assertions`.

    allow_nan_stats: Boolean, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading to
      undefined statistics will return NaN for this statistic.

    name: A name for this distribution (optional).

    Raises
    ------
    TypeError
      if `count_distribution` or a supplied `inflated_distribution` is not
      a `Distribution` instance.
    ValueError
      if the rank of either batch shape is unknown, or if
      `count_distribution` has more batch dimensions than
      `inflated_distribution`.

    References
    ----------
    Liu, L. & Blei, D.M.. (2017). Zero-Inflated Exponential Family Embeddings.
    Proceedings of the 34th International Conference on Machine Learning,
    in PMLR 70:2140-2148

    """
    # must stay the first statement so only the constructor arguments
    # are captured
    parameters = dict(locals())
    self._runtime_assertions = []

    with tf.compat.v1.name_scope(name) as name:
      if not isinstance(count_distribution, distribution.Distribution):
        raise TypeError("count_distribution must be a Distribution instance"
                        " but saw: %s" % count_distribution)
      self._count_distribution = count_distribution

      if inflated_distribution is None:
        inflated_distribution = Bernoulli(
            logits=logits,
            probs=probs,
            dtype=tf.int32,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name="ZeroInflatedRate")
      elif not isinstance(inflated_distribution, distribution.Distribution):
        raise TypeError("inflated_distribution must be a Distribution instance"
                        " but saw: %s" % inflated_distribution)
      self._inflated_distribution = inflated_distribution

      if self._count_distribution.batch_shape.ndims is None:
        raise ValueError(
            "Expected to know rank(batch_shape) from count_disttribution")
      if self._inflated_distribution.batch_shape.ndims is None:
        raise ValueError(
            "Expected to know rank(batch_shape) from inflated_distribution")

      # Wrap the inflated distribution in Independent so that its
      # batch_shape has the same rank as that of count_distribution.
      inflated_batch_ndims = self._inflated_distribution.batch_shape.ndims
      count_batch_ndims = self._count_distribution.batch_shape.ndims
      if count_batch_ndims < inflated_batch_ndims:
        self._inflated_distribution = Independent(
            self._inflated_distribution,
            reinterpreted_batch_ndims=inflated_batch_ndims - count_batch_ndims,
            name="ZeroInflatedRate")
      elif count_batch_ndims > inflated_batch_ndims:
        # The message used to claim the count batch shape was "smaller" and
        # its two string pieces were joined without a space.
        raise ValueError(
            "count_distribution has %d-D batch_shape, which is larger "
            "than %d-D batch_shape of inflated_distribution" %
            (count_batch_ndims, inflated_batch_ndims))

      # Ensure that all batch and event ndims are consistent.
      if validate_args:
        self._runtime_assertions.append(
            tf.assert_equal(
                self._count_distribution.batch_shape_tensor(),
                self._inflated_distribution.batch_shape_tensor(),
                message=("dist batch shape must match logits|probs batch shape"))
        )

    # The result is only fully reparameterized when both underlying
    # distributions are.
    reparameterization_type = [
        self._count_distribution.reparameterization_type,
        self._inflated_distribution.reparameterization_type]
    if any(i == reparameterization.NOT_REPARAMETERIZED
           for i in reparameterization_type):
      reparameterization_type = reparameterization.NOT_REPARAMETERIZED
    else:
      reparameterization_type = reparameterization.FULLY_REPARAMETERIZED

    # We let the zero-inflated distribution access _graph_parents since its
    # arguably more like a baseclass.
    super(ZeroInflated, self).__init__(
        dtype=self._count_distribution.dtype,
        reparameterization_type=reparameterization_type,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=self._count_distribution._graph_parents +
        self._inflated_distribution._graph_parents,
        name=name)
Esempio n. 20
0
class ZeroInflated(distribution.Distribution):
  """Zero-inflated distribution.

  Combines a base `count_distribution` with a Bernoulli-like
  `inflated_distribution` whose success probability `pi` is the rate of
  excess zeros:

    p(x == 0) = pi + (1 - pi) * p_count(0)
    p(x != 0) = (1 - pi) * p_count(x)

  Methods supported include `log_prob`, `prob`, `mean`, `variance` and
  `sample`, plus the `denoised_mean` / `denoised_variance` shortcuts which
  return the statistics of the un-inflated count distribution.
  """

  def __init__(self,
               count_distribution,
               inflated_distribution=None,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="ZeroInflated"):
    """Initialize a zero-inflated distribution.

    A `ZeroInflated` is defined by a zero-inflation rate
    (`inflated_distribution`, the probabilities of excess zeros) and a
    `Distribution` object modelling the un-inflated values
    (`count_distribution`).

    Parameters
    ----------
    count_distribution : A `tfp.distributions.Distribution` instance.
      Its dtype, batch shape and event shape become those of this
      distribution. The rank of its `batch_shape` must be statically known.

    inflated_distribution: `tfp.distributions.Bernoulli`-like instance.
      Manages the probability of excess zeros, the zero-inflated rate.
      The rank of its `batch_shape` must be statically known and must not be
      smaller than the rank of `count_distribution.batch_shape`; surplus
      batch dimensions are reinterpreted as event dimensions by wrapping it
      in `Independent`. When `None`, a `Bernoulli` is created from
      `logits` / `probs`.

    logits: An N-D `Tensor` representing the log-odds of a excess zeros
      A zero-inflation rate, where the probability of excess zeros is
      sigmoid(logits).
      Only one of `logits` or `probs` should be passed in.
      Ignored when `inflated_distribution` is given.

    probs: An N-D `Tensor` representing the probability of a zero event.
      Each entry in the `Tensor` parameterizes an independent
      ZeroInflated distribution.
      Only one of `logits` or `probs` should be passed in.
      Ignored when `inflated_distribution` is given.

    validate_args: Python `bool`, default `False`. If `True`, a runtime
      assertion checks that the batch shape of `count_distribution` equals
      the batch shape of the (possibly wrapped) `inflated_distribution`.

    allow_nan_stats: Boolean, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading to
      undefined statistics will return NaN for this statistic.

    name: A name for this distribution (optional).

    Raises
    ------
    TypeError: if `count_distribution` or `inflated_distribution` is not a
      `Distribution` instance.
    ValueError: if the batch rank of either distribution is unknown, or if
      `count_distribution` has more batch dimensions than
      `inflated_distribution`.

    References
    ----------
    Liu, L. & Blei, D.M.. (2017). Zero-Inflated Exponential Family Embeddings.
    Proceedings of the 34th International Conference on Machine Learning,
    in PMLR 70:2140-2148

    """
    # Must stay the first statement: it snapshots the constructor arguments.
    parameters = dict(locals())
    self._runtime_assertions = []

    with tf.compat.v1.name_scope(name) as name:
      if not isinstance(count_distribution, distribution.Distribution):
        # The 1-tuple keeps the formatting correct even if a tuple is passed.
        raise TypeError("count_distribution must be a Distribution instance"
                        " but saw: %s" % (count_distribution,))
      self._count_distribution = count_distribution

      if inflated_distribution is None:
        inflated_distribution = Bernoulli(
            logits=logits,
            probs=probs,
            dtype=tf.int32,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name="ZeroInflatedRate")
      elif not isinstance(inflated_distribution, distribution.Distribution):
        raise TypeError("inflated_distribution must be a Distribution instance"
                        " but saw: %s" % (inflated_distribution,))
      self._inflated_distribution = inflated_distribution

      if self._count_distribution.batch_shape.ndims is None:
        raise ValueError(
            "Expected to know rank(batch_shape) from count_distribution")
      if self._inflated_distribution.batch_shape.ndims is None:
        raise ValueError(
            "Expected to know rank(batch_shape) from inflated_distribution")

      # Wrap the rate distribution in `Independent` so that its batch rank
      # matches the batch rank of the count distribution.
      inflated_batch_ndims = self._inflated_distribution.batch_shape.ndims
      count_batch_ndims = self._count_distribution.batch_shape.ndims
      if count_batch_ndims < inflated_batch_ndims:
        self._inflated_distribution = Independent(
            self._inflated_distribution,
            reinterpreted_batch_ndims=inflated_batch_ndims - count_batch_ndims,
            name="ZeroInflatedRate")
      elif count_batch_ndims > inflated_batch_ndims:
        raise ValueError(
            "count_distribution has %d-D batch_shape, which is larger "
            "than the %d-D batch_shape of inflated_distribution" %
            (count_batch_ndims, inflated_batch_ndims))

      # Optionally verify at run time that the batch shapes agree.
      if validate_args:
        self._runtime_assertions.append(
            tf.assert_equal(
                self._count_distribution.batch_shape_tensor(),
                self._inflated_distribution.batch_shape_tensor(),
                message=("dist batch shape must match logits|probs batch shape"))
        )

    # Samples are only reparameterized if both components are.
    component_types = (
        self._count_distribution.reparameterization_type,
        self._inflated_distribution.reparameterization_type)
    if any(t == reparameterization.NOT_REPARAMETERIZED
           for t in component_types):
      reparameterization_type = reparameterization.NOT_REPARAMETERIZED
    else:
      reparameterization_type = reparameterization.FULLY_REPARAMETERIZED

    # We let the zero-inflated distribution access _graph_parents since it is
    # arguably more like a baseclass.
    super(ZeroInflated, self).__init__(
        dtype=self._count_distribution.dtype,
        reparameterization_type=reparameterization_type,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=self._count_distribution._graph_parents +
        self._inflated_distribution._graph_parents,
        name=name)

  @property
  def logits(self):
    """Log-odds of an excess zero (the `1` outcome of the rate distribution)."""
    # Look through the `Independent` wrapper added in `__init__`, if any.
    if isinstance(self._inflated_distribution, Independent):
      return self._inflated_distribution.distribution.logits
    return self._inflated_distribution.logits

  @property
  def probs(self):
    """Probability of an excess zero (the `1` outcome of the rate distribution)."""
    # Look through the `Independent` wrapper added in `__init__`, if any.
    if isinstance(self._inflated_distribution, Independent):
      return self._inflated_distribution.distribution.probs
    return self._inflated_distribution.probs

  @property
  def count_distribution(self):
    """The un-inflated base distribution."""
    return self._count_distribution

  @property
  def inflated_distribution(self):
    """The rate distribution, possibly wrapped in `Independent`."""
    return self._inflated_distribution

  # Batch and event shapes are delegated to the count distribution.
  def _batch_shape_tensor(self):
    return self._count_distribution._batch_shape_tensor()

  def _batch_shape(self):
    return self._count_distribution._batch_shape()

  def _event_shape_tensor(self):
    return self._count_distribution._event_shape_tensor()

  def _event_shape(self):
    return self._count_distribution._event_shape()

  def _mean(self):
    """Return `(1 - pi) * d.mean`, with `d` the count distribution."""
    with tf.compat.v1.control_dependencies(self._runtime_assertions):
      # `_broadcast_rate` is defined elsewhere in the file; it is expected to
      # make the rate and the count statistics broadcast-compatible.
      probs, d_mean = _broadcast_rate(self.probs, self._count_distribution.mean())
      return (1 - probs) * d_mean

  def _variance(self):
    """
    (1 - pi) * (d.var + d.mean^2) - [(1 - pi) * d.mean]^2

    Note: mean(ZeroInflated) = (1 - pi) * d.mean
    where:
     - pi is zero-inflated rate
     - d is count distribution
    """
    with tf.compat.v1.control_dependencies(self._runtime_assertions):
      d = self._count_distribution

      probs, d_mean, d_variance = _broadcast_rate(
          self.probs, d.mean(), d.variance())
      second_moment = (1 - probs) * (d_variance + tf.square(d_mean))
      return second_moment - tf.square(self._mean())

  def _log_prob(self, x):
    """Return the log-probability of `x` under the zero-inflated model."""
    with tf.compat.v1.control_dependencies(self._runtime_assertions):
      x = tf.convert_to_tensor(x, name="x")
      d = self._count_distribution
      pi = self.probs

      d_prob = d.prob(x)
      d_log_prob = d.log_prob(x)

      # make pi and anything come out of count_distribution
      # broadcast-able
      pi, d_prob, d_log_prob = _broadcast_rate(pi, d_prob, d_log_prob)

      # Equation (13) of the reference, with u_{ij} = 1 - pi_{ij}.
      # `tf.math.log` is used because the `tf.log` alias only exists in TF1.
      y_0 = tf.math.log(pi + (1 - pi) * d_prob)
      y_1 = tf.math.log(1 - pi) + d_log_prob
      # `tf.equal` is an elementwise comparison regardless of whether
      # tensor `==` overloading is enabled.
      is_zero = tf.equal(x, tf.zeros_like(x))
      return tf.where(is_zero, y_0, y_1)

  def _prob(self, x):
    """Return `exp(log_prob(x))`."""
    return tf.exp(self._log_prob(x))

  def _sample_n(self, n, seed):
    """Draw `n` count samples and zero out those selected by the rate mask."""
    with tf.compat.v1.control_dependencies(self._runtime_assertions):
      seed = seed_stream.SeedStream(seed, salt="ZeroInflated")
      mask = self.inflated_distribution.sample(n, seed())
      samples = self.count_distribution.sample(n, seed())
      mask, samples = _broadcast_rate(mask, samples)
      # mask = 1 => new_sample = 0
      # mask = 0 => new_sample = sample
      return samples * tf.cast(1 - mask, samples.dtype)

  # ******************** shortcut for denoising ******************** #
  def denoised_mean(self):
    """Mean of the count distribution, ignoring zero inflation."""
    return self.count_distribution.mean()

  def denoised_variance(self):
    """Variance of the count distribution, ignoring zero inflation."""
    return self.count_distribution.variance()
Esempio n. 21
0
class ZeroInflated(distribution.Distribution):
  """Zero-inflated distribution.

  Combines a base `count_distribution` with a Bernoulli-like
  `inflated_distribution` whose success probability `pi` is the rate of
  excess zeros:

    p(x is zero) = pi + (1 - pi) * p_count(x)
    p(x is not zero) = (1 - pi) * p_count(x)

  where a value counts as zero when it is `<= 1e-8` (see `_log_prob`).

  Methods supported include `log_prob`, `prob`, `mean`, `variance` and
  `sample`, plus the `denoised_mean` / `denoised_variance` shortcuts which
  return the statistics of the un-inflated count distribution.
  """

  def __init__(self,
               count_distribution,
               inflated_distribution=None,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="ZeroInflated"):
    """Initialize a zero-inflated distribution.

    A `ZeroInflated` is defined by a zero-inflation rate
    (`inflated_distribution`, the probabilities of excess zeros) and a
    `Distribution` object modelling the un-inflated values
    (`count_distribution`).

    Parameters
    ----------
    count_distribution : A `tfp.distributions.Distribution` instance.
      Its dtype, batch shape and event shape become those of this
      distribution. The rank of its `batch_shape` must be statically known.

    inflated_distribution: `tfp.distributions.Bernoulli`-like instance.
      Manages the probability of excess zeros, the zero-inflated rate.
      The rank of its `batch_shape` must be statically known and must not be
      smaller than the rank of `count_distribution.batch_shape`; surplus
      batch dimensions are reinterpreted as event dimensions by wrapping it
      in `Independent`. When `None`, a `Bernoulli` is created from
      `logits` / `probs`.

    logits: An N-D `Tensor` representing the log-odds of a excess zeros
      A zero-inflation rate, where the probability of excess zeros is
      sigmoid(logits).
      Only one of `logits` or `probs` should be passed in.
      Ignored when `inflated_distribution` is given.

    probs: An N-D `Tensor` representing the probability of a zero event.
      Each entry in the `Tensor` parameterizes an independent
      ZeroInflated distribution.
      Only one of `logits` or `probs` should be passed in.
      Ignored when `inflated_distribution` is given.

    validate_args: Python `bool`, default `False`. If `True`, a runtime
      assertion checks that the batch shape of `count_distribution` equals
      the batch shape of the (possibly wrapped) `inflated_distribution`.

    allow_nan_stats: Boolean, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading to
      undefined statistics will return NaN for this statistic.

    name: A name for this distribution (optional).

    Raises
    ------
    TypeError: if `count_distribution` or `inflated_distribution` is not a
      `Distribution` instance.
    ValueError: if the batch rank of either distribution is unknown, or if
      `count_distribution` has more batch dimensions than
      `inflated_distribution`.

    References
    ----------
    Liu, L. & Blei, D.M.. (2017). Zero-Inflated Exponential Family Embeddings.
    Proceedings of the 34th International Conference on Machine Learning,
    in PMLR 70:2140-2148

    """
    # Must stay the first statement: it snapshots the constructor arguments.
    parameters = dict(locals())
    self._runtime_assertions = []

    with tf.compat.v1.name_scope(name) as name:
      if not isinstance(count_distribution, distribution.Distribution):
        # The 1-tuple keeps the formatting correct even if a tuple is passed.
        raise TypeError("count_distribution must be a Distribution instance"
                        " but saw: %s" % (count_distribution,))
      self._count_distribution = count_distribution

      if inflated_distribution is None:
        inflated_distribution = Bernoulli(logits=logits,
                                          probs=probs,
                                          dtype=tf.int32,
                                          validate_args=validate_args,
                                          allow_nan_stats=allow_nan_stats,
                                          name="ZeroInflatedRate")
      elif not isinstance(inflated_distribution, distribution.Distribution):
        raise TypeError("inflated_distribution must be a Distribution instance"
                        " but saw: %s" % (inflated_distribution,))
      self._inflated_distribution = inflated_distribution

      if self._count_distribution.batch_shape.ndims is None:
        raise ValueError(
            "Expected to know rank(batch_shape) from count_distribution")
      if self._inflated_distribution.batch_shape.ndims is None:
        raise ValueError(
            "Expected to know rank(batch_shape) from inflated_distribution")

      # Wrap the rate distribution in `Independent` so that its batch rank
      # matches the batch rank of the count distribution.
      inflated_batch_ndims = self._inflated_distribution.batch_shape.ndims
      count_batch_ndims = self._count_distribution.batch_shape.ndims
      if count_batch_ndims < inflated_batch_ndims:
        self._inflated_distribution = Independent(
            self._inflated_distribution,
            reinterpreted_batch_ndims=inflated_batch_ndims - count_batch_ndims,
            name="ZeroInflatedRate")
      elif count_batch_ndims > inflated_batch_ndims:
        raise ValueError(
            "count_distribution has %d-D batch_shape, which is larger "
            "than the %d-D batch_shape of inflated_distribution" %
            (count_batch_ndims, inflated_batch_ndims))

      # Optionally verify at run time that the batch shapes agree.
      if validate_args:
        self._runtime_assertions.append(
            tf.assert_equal(
                self._count_distribution.batch_shape_tensor(),
                self._inflated_distribution.batch_shape_tensor(),
                message=(
                    "dist batch shape must match logits|probs batch shape")))

    # Samples are only reparameterized if both components are.
    component_types = (
        self._count_distribution.reparameterization_type,
        self._inflated_distribution.reparameterization_type,
    )
    if any(t == reparameterization.NOT_REPARAMETERIZED
           for t in component_types):
      reparameterization_type = reparameterization.NOT_REPARAMETERIZED
    else:
      reparameterization_type = reparameterization.FULLY_REPARAMETERIZED

    # We let the zero-inflated distribution access _graph_parents since it is
    # arguably more like a baseclass.
    super(ZeroInflated,
          self).__init__(dtype=self._count_distribution.dtype,
                         reparameterization_type=reparameterization_type,
                         validate_args=validate_args,
                         allow_nan_stats=allow_nan_stats,
                         parameters=parameters,
                         graph_parents=self._count_distribution._graph_parents +
                         self._inflated_distribution._graph_parents,
                         name=name)

  @property
  def logits(self):
    """Log-odds of an excess zero (the `1` outcome of the rate distribution)."""
    # Look through the `Independent` wrapper added in `__init__`, if any.
    if isinstance(self._inflated_distribution, Independent):
      return self._inflated_distribution.distribution.logits_parameter()
    return self._inflated_distribution.logits_parameter()

  @property
  def probs(self):
    """Probability of an excess zero (the `1` outcome of the rate distribution)."""
    # Look through the `Independent` wrapper added in `__init__`, if any.
    if isinstance(self._inflated_distribution, Independent):
      return self._inflated_distribution.distribution.probs_parameter()
    return self._inflated_distribution.probs_parameter()

  @property
  def count_distribution(self):
    """The un-inflated base distribution."""
    return self._count_distribution

  @property
  def inflated_distribution(self):
    """The rate distribution, possibly wrapped in `Independent`."""
    return self._inflated_distribution

  # Batch and event shapes are delegated to the count distribution.
  def _batch_shape_tensor(self):
    return self._count_distribution._batch_shape_tensor()

  def _batch_shape(self):
    return self._count_distribution._batch_shape()

  def _event_shape_tensor(self):
    return self._count_distribution._event_shape_tensor()

  def _event_shape(self):
    return self._count_distribution._event_shape()

  def _mean(self):
    """Return `(1 - pi) * d.mean`, with `d` the count distribution."""
    with tf.compat.v1.control_dependencies(self._runtime_assertions):
      # `_broadcast_rate` is defined elsewhere in the file; it is expected to
      # make the rate and the count statistics broadcast-compatible.
      probs, d_mean = _broadcast_rate(self.probs,
                                      self._count_distribution.mean())
      return (1 - probs) * d_mean

  def _variance(self):
    """
    (1 - pi) * (d.var + d.mean^2) - [(1 - pi) * d.mean]^2

    Note: mean(ZeroInflated) = (1 - pi) * d.mean
    where:
     - pi is zero-inflated rate
     - d is count distribution
    """
    with tf.compat.v1.control_dependencies(self._runtime_assertions):
      d = self._count_distribution

      probs, d_mean, d_variance = _broadcast_rate(self.probs, d.mean(),
                                                  d.variance())
      second_moment = (1 - probs) * (d_variance + tf.math.square(d_mean))
      return second_moment - tf.math.square(self._mean())

  def _log_prob(self, x):
    """Return the log-probability of `x` under the zero-inflated model."""
    with tf.compat.v1.control_dependencies(self._runtime_assertions):
      # Convert first: Python scalars and lists have no `.dtype` attribute.
      x = tf.convert_to_tensor(x, name="x")
      eps = tf.cast(1e-8, x.dtype)
      d = self._count_distribution
      pi = self.probs

      log_prob = d.log_prob(x)
      prob = tf.math.exp(log_prob)

      # make pi and anything come out of count_distribution
      # broadcast-able
      pi, prob, log_prob = _broadcast_rate(pi, prob, log_prob)

      # Equation (13) of the reference, with u_{ij} = 1 - pi_{ij}.
      y_0 = tf.math.log(pi + (1 - pi) * prob)
      y_1 = tf.math.log(1 - pi) + log_prob
      # Values at or below `eps` are treated as zeros.
      return tf.where(x <= eps, y_0, y_1)

  def _prob(self, x):
    """Return `exp(log_prob(x))`."""
    return tf.math.exp(self._log_prob(x))

  def _sample_n(self, n, seed):
    """Draw `n` count samples and zero out those selected by the rate mask."""
    with tf.compat.v1.control_dependencies(self._runtime_assertions):
      seed = SeedStream(seed, salt="ZeroInflated")
      mask = self.inflated_distribution.sample(n, seed())
      samples = self.count_distribution.sample(n, seed())
      mask, samples = _broadcast_rate(mask, samples)
      # mask = 1 => new_sample = 0
      # mask = 0 => new_sample = sample
      return samples * tf.cast(1 - mask, samples.dtype)

  # ******************** shortcut for denoising ******************** #
  def denoised_mean(self):
    """Mean of the count distribution, ignoring zero inflation."""
    return self.count_distribution.mean()

  def denoised_variance(self):
    """Variance of the count distribution, ignoring zero inflation."""
    return self.count_distribution.variance()
Esempio n. 22
0
  def __init__(self,
               count_distribution,
               inflated_distribution=None,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="ZeroInflated"):
    """Initialize a zero-inflated distribution.

    A `ZeroInflated` is defined by a zero-inflation rate
    (`inflated_distribution`, the probabilities of excess zeros) and a
    `Distribution` object modelling the un-inflated values
    (`count_distribution`).

    Parameters
    ----------
    count_distribution : A `tfp.distributions.Distribution` instance.
      Its dtype is used as the dtype of this distribution. The rank of its
      `batch_shape` must be statically known.

    inflated_distribution: `tfp.distributions.Bernoulli`-like instance.
      Manages the probability of excess zeros, the zero-inflated rate.
      The rank of its `batch_shape` must be statically known and must not be
      smaller than the rank of `count_distribution.batch_shape`; surplus
      batch dimensions are reinterpreted as event dimensions by wrapping it
      in `Independent`. When `None`, a `Bernoulli` is created from
      `logits` / `probs`.

    logits: An N-D `Tensor` representing the log-odds of a excess zeros
      A zero-inflation rate, where the probability of excess zeros is
      sigmoid(logits).
      Only one of `logits` or `probs` should be passed in.
      Ignored when `inflated_distribution` is given.

    probs: An N-D `Tensor` representing the probability of a zero event.
      Each entry in the `Tensor` parameterizes an independent
      ZeroInflated distribution.
      Only one of `logits` or `probs` should be passed in.
      Ignored when `inflated_distribution` is given.

    validate_args: Python `bool`, default `False`. If `True`, a runtime
      assertion checks that the batch shape of `count_distribution` equals
      the batch shape of the (possibly wrapped) `inflated_distribution`.

    allow_nan_stats: Boolean, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading to
      undefined statistics will return NaN for this statistic.

    name: A name for this distribution (optional).

    Raises
    ------
    TypeError: if `count_distribution` or `inflated_distribution` is not a
      `Distribution` instance.
    ValueError: if the batch rank of either distribution is unknown, or if
      `count_distribution` has more batch dimensions than
      `inflated_distribution`.

    References
    ----------
    Liu, L. & Blei, D.M.. (2017). Zero-Inflated Exponential Family Embeddings.
    Proceedings of the 34th International Conference on Machine Learning,
    in PMLR 70:2140-2148

    """
    # Must stay the first statement: it snapshots the constructor arguments.
    parameters = dict(locals())
    self._runtime_assertions = []

    with tf.compat.v1.name_scope(name) as name:
      if not isinstance(count_distribution, distribution.Distribution):
        # The 1-tuple keeps the formatting correct even if a tuple is passed.
        raise TypeError("count_distribution must be a Distribution instance"
                        " but saw: %s" % (count_distribution,))
      self._count_distribution = count_distribution

      if inflated_distribution is None:
        inflated_distribution = Bernoulli(logits=logits,
                                          probs=probs,
                                          dtype=tf.int32,
                                          validate_args=validate_args,
                                          allow_nan_stats=allow_nan_stats,
                                          name="ZeroInflatedRate")
      elif not isinstance(inflated_distribution, distribution.Distribution):
        raise TypeError("inflated_distribution must be a Distribution instance"
                        " but saw: %s" % (inflated_distribution,))
      self._inflated_distribution = inflated_distribution

      if self._count_distribution.batch_shape.ndims is None:
        raise ValueError(
            "Expected to know rank(batch_shape) from count_distribution")
      if self._inflated_distribution.batch_shape.ndims is None:
        raise ValueError(
            "Expected to know rank(batch_shape) from inflated_distribution")

      # Wrap the rate distribution in `Independent` so that its batch rank
      # matches the batch rank of the count distribution.
      inflated_batch_ndims = self._inflated_distribution.batch_shape.ndims
      count_batch_ndims = self._count_distribution.batch_shape.ndims
      if count_batch_ndims < inflated_batch_ndims:
        self._inflated_distribution = Independent(
            self._inflated_distribution,
            reinterpreted_batch_ndims=inflated_batch_ndims - count_batch_ndims,
            name="ZeroInflatedRate")
      elif count_batch_ndims > inflated_batch_ndims:
        raise ValueError(
            "count_distribution has %d-D batch_shape, which is larger "
            "than the %d-D batch_shape of inflated_distribution" %
            (count_batch_ndims, inflated_batch_ndims))

      # Optionally verify at run time that the batch shapes agree.
      if validate_args:
        self._runtime_assertions.append(
            tf.assert_equal(
                self._count_distribution.batch_shape_tensor(),
                self._inflated_distribution.batch_shape_tensor(),
                message=(
                    "dist batch shape must match logits|probs batch shape")))

    # Samples are only reparameterized if both components are.
    component_types = (
        self._count_distribution.reparameterization_type,
        self._inflated_distribution.reparameterization_type,
    )
    if any(t == reparameterization.NOT_REPARAMETERIZED
           for t in component_types):
      reparameterization_type = reparameterization.NOT_REPARAMETERIZED
    else:
      reparameterization_type = reparameterization.FULLY_REPARAMETERIZED

    # We let the zero-inflated distribution access _graph_parents since it is
    # arguably more like a baseclass.
    super(ZeroInflated,
          self).__init__(dtype=self._count_distribution.dtype,
                         reparameterization_type=reparameterization_type,
                         validate_args=validate_args,
                         allow_nan_stats=allow_nan_stats,
                         parameters=parameters,
                         graph_parents=self._count_distribution._graph_parents +
                         self._inflated_distribution._graph_parents,
                         name=name)
Esempio n. 23
0
def _create_dist(params, event_ndims, dtype):
    """Build an `Independent` Normal from concatenated loc/scale parameters.

    `params` is split into two halves along the last axis: the first half is
    the location, the second is passed through softplus and offset by
    `exp(-7)` (cast to `dtype`) to give the scale. The last `event_ndims`
    batch dimensions of the Normal are reinterpreted as event dimensions.
    """
    mean, raw_scale = tf.split(params, 2, axis=-1)
    # Small additive constant applied on top of the softplus output.
    scale_offset = tf.cast(tf.exp(-7.), dtype)
    stddev = tf.nn.softplus(raw_scale) + scale_offset
    return Independent(Normal(mean, stddev),
                       reinterpreted_batch_ndims=event_ndims)
                np.isclose(p.numpy(),
                           tf.concat((p1, p2), axis=0).numpy()))
    except NotImplementedError:
        pass


# Shared random fixtures for the consistency checks below.
shape = (8, 2)
# Integer draws from [0, 20), stored as float32.
# NOTE(review): the lower bound 0 allows `total_count == 0` in the
# NegativeBinomial cases below -- confirm that is intended.
count = np.random.randint(0, 20, size=shape).astype('float32')
# `probs` and `logits` are two independent uniform [0, 1) draws, so
# `logits` is not the logit transform of `probs`.
probs = np.random.rand(*shape).astype('float32')
logits = np.random.rand(*shape).astype('float32')

# NOTE(review): `assert_consistent_statistics` is defined outside this
# excerpt; presumably it compares the statistics/shapes of the two
# distributions passed in -- confirm against its definition.

# Bernoulli: probs- vs logits-parameterized, then logits vs logits.
assert_consistent_statistics(Bernoulli(probs=probs), Bernoulli(logits=logits))
assert_consistent_statistics(Bernoulli(logits=logits),
                             Bernoulli(logits=logits))
# Same pair wrapped in Independent (last batch dim becomes an event dim).
assert_consistent_statistics(
    Independent(Bernoulli(probs=probs), reinterpreted_batch_ndims=1),
    Independent(Bernoulli(logits=logits), reinterpreted_batch_ndims=1))

# NegativeBinomial: logits- vs probs-parameterized, plain and Independent.
assert_consistent_statistics(
    NegativeBinomial(total_count=count, logits=logits),
    NegativeBinomial(total_count=count, probs=probs))
assert_consistent_statistics(
    Independent(NegativeBinomial(total_count=count, logits=logits),
                reinterpreted_batch_ndims=1),
    Independent(NegativeBinomial(total_count=count, probs=probs),
                reinterpreted_batch_ndims=1))
# ZeroInflated NegativeBinomial, with the zero-inflation rate given either
# as logits or as probs.
assert_consistent_statistics(
    ZeroInflated(NegativeBinomial(total_count=count, logits=logits),
                 logits=logits),
    ZeroInflated(NegativeBinomial(total_count=count, probs=probs),
                 probs=probs))
Esempio n. 25
0
    def __init__(self,
                 count_distribution,
                 inflated_distribution=None,
                 logits=None,
                 probs=None,
                 eps=1e-8,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="ZeroInflated"):
        """Initialize a zero-inflated distribution.

        The distribution pairs a zero-inflation rate (`inflated_distribution`,
        the probabilities of excess zeros) with a `Distribution` object for
        the un-inflated values (`count_distribution`).

        Arguments:
          count_distribution : A `tfp.distributions.Distribution` instance.
            Supplies the dtype and reparameterization type of this
            distribution. The rank of its `batch_shape` must be known.
          inflated_distribution: `tfp.distributions.Bernoulli`-like instance.
            Manages the probability of excess zeros, the zero-inflated rate.
            When it has fewer event dimensions than `count_distribution`, it
            is wrapped in `Independent` to make up the difference. When
            `None`, a `Bernoulli` is created from `logits` / `probs`.
          logits: An N-D `Tensor` representing the log-odds of excess zeros;
            the probability of excess zeros is sigmoid(logits).
            Only one of `logits` or `probs` should be passed in.
          probs: An N-D `Tensor` representing the probability of a zero event.
            Only one of `logits` or `probs` should be passed in.
          eps: Scalar converted to a tensor (dtype hinted by
            `count_distribution.dtype`) and stored as `self._eps`.
            NOTE(review): presumably the threshold under which a value is
            treated as zero -- confirm against the methods that read it.
          validate_args: Python `bool`, default `False`. Forwarded to the
            default `Bernoulli` and to the base-class constructor.
          allow_nan_stats: Boolean, default `True`. If `False`, raise an
            exception if a statistic (e.g. mean/mode/etc...) is undefined for
            any batch member. If `True`, batch members with valid parameters
            leading to undefined statistics will return NaN for this
            statistic.
          name: A name for this distribution (optional).

        Raises:
          TypeError: if `count_distribution` or `inflated_distribution` is
            not a `Distribution` instance.
          ValueError: if the batch rank of either distribution is unknown.

        References:
          Liu, L. & Blei, D.M.. (2017). Zero-Inflated Exponential Family
            Embeddings. Proceedings of the 34th International Conference on
            Machine Learning, in PMLR 70:2140-2148
        """
        # Must stay the first statement: it snapshots the constructor args.
        parameters = dict(locals())
        with tf.compat.v1.name_scope(name) as name:
            # --- validate the main count distribution
            count_is_dist = isinstance(count_distribution,
                                       distribution.Distribution)
            if not count_is_dist:
                raise TypeError(
                    "count_distribution must be a Distribution instance"
                    " but saw: %s" % count_distribution)

            # --- validate, or build, the zero-inflation rate distribution
            if inflated_distribution is not None:
                if not isinstance(inflated_distribution,
                                  distribution.Distribution):
                    raise TypeError(
                        "inflated_distribution must be a Distribution instance"
                        " but saw: %s" % inflated_distribution)
            else:
                inflated_distribution = Bernoulli(
                    logits=logits,
                    probs=probs,
                    dtype=tf.int32,
                    validate_args=validate_args,
                    allow_nan_stats=allow_nan_stats,
                    name="ZeroInflatedRate")

            # --- pad the rate's event rank up to the count's event rank
            rate_event_rank = len(inflated_distribution.event_shape)
            count_event_rank = len(count_distribution.event_shape)
            if count_event_rank > rate_event_rank:
                inflated_distribution = Independent(
                    inflated_distribution,
                    count_event_rank - rate_event_rank)

            self._count_distribution = count_distribution
            self._inflated_distribution = inflated_distribution

            # --- both batch ranks have to be statically known
            count_batch_rank = self._count_distribution.batch_shape.ndims
            if count_batch_rank is None:
                raise ValueError(
                    "Expected to know rank(batch_shape) from count_distribution"
                )
            rate_batch_rank = self._inflated_distribution.batch_shape.ndims
            if rate_batch_rank is None:
                raise ValueError(
                    "Expected to know rank(batch_shape) from inflated_distribution"
                )

            self._eps = tensor_util.convert_nonref_to_tensor(
                eps, dtype_hint=count_distribution.dtype, name='eps')

        # dtype and reparameterization type are inherited from the count
        # distribution alone.
        base_kwargs = dict(
            dtype=self._count_distribution.dtype,
            reparameterization_type=(
                self._count_distribution.reparameterization_type),
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            name=name,
        )
        super(ZeroInflated, self).__init__(**base_kwargs)
Esempio n. 26
0
def model_gmmvae3(args: Arguments):
  """Create a `GMMVAE` with 10 components and a standard-normal prior.

  The prior is a `Normal` with zero location and unit scale over `args.zdim`
  dimensions, wrapped in `Independent` so the last dimension is treated as
  the event. The remaining constructor arguments come from `get_networks`,
  called for dataset `args.ds` with the hierarchical and semi-supervised
  options both disabled.
  """
  latent_dim = args.zdim
  unit_normal = Normal(loc=tf.zeros([latent_dim]),
                       scale=tf.ones([latent_dim]))
  prior = Independent(unit_normal, 1)
  networks = get_networks(args.ds,
                          zdim=latent_dim,
                          is_hierarchical=False,
                          is_semi_supervised=False)
  return GMMVAE(n_components=10, prior=prior, analytic=False, **networks)
Esempio n. 27
0
 def __init__(self,
              loc,
              scale,
              logits=None,
              probs=None,
              covariance_type='diag',
              trainable=False,
              validate_args=False,
              allow_nan_stats=True,
              name=None):
     """Build a mixture of Gaussian components weighted by a `Categorical`.

     Args:
       loc: locations of the Gaussian components, forwarded as `loc` to the
         component distribution.
       scale: scale of the Gaussian components; how it is used depends on
         `covariance_type`.
       logits: forwarded as `logits` to the mixture `Categorical`.
       probs: forwarded as `probs` to the mixture `Categorical`.
       covariance_type: lower-cased and stripped before use, one of
         - 'diag': `MultivariateNormalDiag`; a rank-0 `scale` is passed as
           `scale_identity_multiplier`, anything else as `scale_diag`.
         - 'tril' or 'full': `MultivariateNormalTriL`; a rank-1 `scale`, or
           one whose last two dimensions differ, is first mapped through
           `FillScaleTriL` (with `diag_shift=1e-5`).
         - 'none': `Independent(Normal(loc, scale), 1)`.
       trainable: if true, `loc`, `scale` and whichever of `logits`/`probs`
         is given are wrapped in trainable `tf.Variable`s.
       validate_args: forwarded to every distribution created here.
       allow_nan_stats: forwarded to every distribution created here.
       name: name of the distribution; when `None` a name is derived from
         `covariance_type`.

     Raises:
       ValueError: if `covariance_type` is not one of the supported values.
       AssertionError: if the last batch dimension of the components differs
         from the number of categories of the mixture.
     """
     kw = dict(validate_args=validate_args, allow_nan_stats=allow_nan_stats)
     self._trainable = bool(trainable)
     self._llk_history = []
     if trainable:
         loc = tf.Variable(loc, trainable=True, name='loc')
         scale = tf.Variable(scale, trainable=True, name='scale')
         if logits is not None:
             logits = tf.Variable(logits, trainable=True, name='logits')
         if probs is not None:
             probs = tf.Variable(probs, trainable=True, name='probs')
     ### initialize mixture Categorical
     mixture = Categorical(logits=logits,
                           probs=probs,
                           name="MixtureWeights",
                           **kw)
     n_components = mixture._num_categories()
     ### initialize Gaussian components
     covariance_type = str(covariance_type).lower().strip()
     if name is None:
         name = 'Mixture%sGaussian' % \
           (covariance_type.capitalize() if covariance_type != 'none' else
            'Independent')
     ## create the components
     if covariance_type == 'diag':
         if tf.rank(scale) == 0:  # scalar
             extra_kw = dict(scale_identity_multiplier=scale)
         else:  # a tensor
             extra_kw = dict(scale_diag=scale)
         components = MultivariateNormalDiag(loc=loc,
                                             name=name,
                                             **kw,
                                             **extra_kw)
     elif covariance_type in ('tril', 'full'):
         # A flat vector (or a non-square trailing matrix) is treated as the
         # packed parameters of a lower-triangular scale matrix.
         if tf.rank(scale) == 1 or \
           (scale.shape[-1] != scale.shape[-2]):
             # `as_numpy_dtype` is a property holding the NumPy type itself;
             # it must not be called.
             np_dtype = tf.convert_to_tensor(scale).dtype.as_numpy_dtype
             scale_tril = FillScaleTriL(
                 diag_shift=np.array(1e-5, dtype=np_dtype))
             scale = scale_tril(scale)
         components = MultivariateNormalTriL(loc=loc,
                                             scale_tril=scale,
                                             name=name,
                                             **kw)
     elif covariance_type == 'none':
         components = Independent(distribution=Normal(loc=loc,
                                                      scale=scale,
                                                      **kw),
                                  reinterpreted_batch_ndims=1,
                                  name=name)
     else:
         raise ValueError("No support for covariance_type: '%s'" %
                          covariance_type)
     ### validate the n_components
     # The message uses only values that are always available: the mixture's
     # event shape is not indexed, and `%s` tolerates an unknown (None) static
     # dimension.
     n_from_components = components.batch_shape[-1]
     assert n_from_components == int(n_components), \
       "Number of components mismatch, mixture:%s, components:%s" % \
         (int(n_components), n_from_components)
     super().__init__(mixture_distribution=mixture,
                      components_distribution=components,
                      name=name,
                      **kw)