Example #1
 def new(params,
         event_shape=(),
         given_logits=True,
         validate_args=False,
         name='ZIBernoulliLayer'):
   """Create the distribution instance from a `params` vector."""
   params = tf.convert_to_tensor(value=params, name='params')
   event_shape = dist_util.expand_to_vector(
     tf.convert_to_tensor(value=event_shape,
                          name='event_shape',
                          dtype=tf.int32),
     tensor_name='event_shape',
   )
   output_shape = tf.concat(
     [tf.shape(input=params)[:-1], event_shape],
     axis=0,
   )
   (bernoulli_params, rate_params) = tf.split(params, 2, axis=-1)
   bernoulli_params = tf.reshape(bernoulli_params, output_shape)
   bern = Bernoulli(logits=bernoulli_params if given_logits else None,
                    probs=bernoulli_params if not given_logits else None,
                    validate_args=validate_args)
   zibern = ZeroInflated(count_distribution=bern,
                         logits=tf.reshape(rate_params, output_shape),
                         validate_args=validate_args)
   return Independent(zibern,
                      reinterpreted_batch_ndims=tf.size(input=event_shape),
                      name=name)
Example #2
def observation(distribution: str):
    if distribution == 'qlogistic':
        n_params = 2
        obs = DistributionLambda(
            lambda params: QuantizedLogistic(
                *[
                    # loc
                    p if i == 0 else
                    # Ensure scales are positive and do not collapse to near-zero
                    tf.nn.softplus(p) + tf.cast(tf.exp(-7.), tf.float32)
                    for i, p in enumerate(tf.split(params, 2, -1))
                ],
                low=0,
                high=255,
                inputs_domain='sigmoid',
                reinterpreted_batch_ndims=3),
            convert_to_tensor_fn=Distribution.sample,
            name='image')
    elif distribution == 'bernoulli':
        n_params = 1
        obs = DistributionLambda(lambda params: Independent(
            Bernoulli(logits=params, dtype=tf.float32), len(IMAGE_SHAPE)),
                                 convert_to_tensor_fn=Distribution.sample,
                                 name='image')
    else:
        raise NotImplementedError
    return n_params, obs
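
A minimal usage sketch for the 'bernoulli' branch, assuming only the standard tensorflow_probability imports and a hypothetical IMAGE_SHAPE = (28, 28, 1): applying the returned layer to decoder logits yields an Independent(Bernoulli) with one log-probability per image.

import tensorflow as tf
import tensorflow_probability as tfp

IMAGE_SHAPE = (28, 28, 1)                       # assumed image shape
obs_layer = tfp.layers.DistributionLambda(
    lambda params: tfp.distributions.Independent(
        tfp.distributions.Bernoulli(logits=params, dtype=tf.float32),
        reinterpreted_batch_ndims=len(IMAGE_SHAPE)),
    convert_to_tensor_fn=tfp.distributions.Distribution.sample,
    name='image')
logits = tf.random.normal([2, *IMAGE_SHAPE])    # simulated decoder output
px = obs_layer(logits)                          # distribution over 28x28x1 images
print(px.log_prob(tf.zeros([2, *IMAGE_SHAPE])).shape)  # (2,)
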
Example #3
File: utils.py Project: trungnt13/odin-ai
 def llk_pixels(model: VariationalModel, valid_ds: tf.data.Dataset):
     llk = []
     for x, y in valid_ds.take(5):
         px, _ = _call(model, x, y, decode=True)
         px = as_tuple(px)[0]
         if hasattr(px, 'distribution'):
             px = px.distribution
         if isinstance(px, Bernoulli):
             px = Bernoulli(logits=px.logits)
         elif isinstance(px, Normal):
             px = Normal(loc=px.loc, scale=px.scale)
         elif isinstance(px, QuantizedLogistic):
             px = QuantizedLogistic(loc=px.loc,
                                    scale=px.scale,
                                    low=px.low,
                                    high=px.high,
                                    inputs_domain=px.inputs_domain,
                                    reinterpreted_batch_ndims=None)
         else:
             return  # nothing to do
         llk.append(px.log_prob(x))
     # average the log-likelihood over all collected examples, then over channels
     llk_image = tf.reduce_mean(tf.reduce_mean(tf.concat(llk, 0), axis=0),
                                axis=-1)
     llk = tf.reshape(llk_image, -1)
     tf.summary.histogram('valid/llk_pixels', llk, step=model.step)
     # show the image heatmap of llk pixels
     fig = plt.figure(figsize=(3, 3))
     ax = plt.gca()
     im = ax.pcolormesh(llk_image.numpy(),
                        cmap='Spectral',
                        vmin=np.min(llk),
                        vmax=np.max(llk))
     ax.axis('off')
     ax.margins(0.)
     # color bar
     ticks = np.linspace(np.min(llk), np.max(llk), 5)
     cbar = plt.colorbar(im, ax=ax, fraction=0.04, pad=0.02, ticks=ticks)
     cbar.ax.set_yticklabels([f'{i:.2f}' for i in ticks])
     cbar.ax.tick_params(labelsize=6)
     plt.tight_layout()
     tf.summary.image('llk_heatmap', vs.plot_to_image(fig, dpi=100))
Example #4
 def elbo_components(self, inputs, training=None, mask=None, **kwargs):
   px, qz = self(inputs, training=training, mask=mask)
   # === 1. reconstruction log-likelihood
   llk = {}
   for p, x in zip(as_tuple(px), as_tuple(inputs)):
     name = p.name.split('_')[1]
     if hasattr(p, 'distribution'):
       p = p.distribution
     if isinstance(p, Bernoulli):
       p = Bernoulli(logits=p.logits)
     elif isinstance(p, Normal):
       p = Normal(loc=p.loc, scale=p.scale)
     elif isinstance(p, QuantizedLogistic):
       p = QuantizedLogistic(loc=p.loc, scale=p.scale,
                             low=p.low, high=p.high,
                             inputs_domain=p.inputs_domain,
                             reinterpreted_batch_ndims=None)
     lk = p.log_prob(x)
     if self.R != 0.:
       lk = tf.minimum(lk, self.R)
     lk = tf.reduce_sum(lk, tf.range(1, x.shape.rank))
     llk[f'llk_{name}'] = lk
   # === 2. latent capacity
   kl = {}
   for q in as_tuple(qz):
     name = q.name.split('_')[1]
     kl_q = q.KL_divergence(analytic=self.analytic)
     if self.C > 0:
       zdim = int(np.prod(q.event_shape))
       C = tf.constant(self.C * zdim, dtype=self.dtype)
       if self.random_capacity:
         C = C * tf.random.uniform(shape=[], minval=0., maxval=1.,
                                   dtype=self.dtype)
       kl_q = tf.math.abs(kl_q - C)
     kl_q = self.beta * kl_q
     kl[f'kl_{name}'] = kl_q
   return llk, kl
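
The method returns the reconstruction and KL terms separately; below is a minimal sketch (not from the source) of how such components are typically combined into a per-example ELBO, i.e. the summed log-likelihoods minus the summed (already beta-scaled) KL divergences:

import tensorflow as tf

def elbo_from_components(llk: dict, kl: dict) -> tf.Tensor:
  # Every entry has shape (batch,); sum the reconstruction log-likelihoods
  # and subtract the KL terms to obtain the evidence lower bound per example.
  return tf.add_n(list(llk.values())) - tf.add_n(list(kl.values()))
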
Example #5
 def new(params,
         event_shape=(),
         dtype=None,
         validate_args=False,
         continuous=False,
         lims=(0.499, 0.501),
         name='BernoulliLayer'):
   """Create the distribution instance from a `params` vector."""
   params = tf.convert_to_tensor(value=params, name='params')
   event_shape = dist_util.expand_to_vector(
     tf.convert_to_tensor(value=event_shape,
                          name='event_shape',
                          dtype_hint=tf.int32),
     tensor_name='event_shape',
   )
   new_shape = tf.concat(
     [tf.shape(input=params)[:-1], event_shape],
     axis=0,
   )
   if continuous:
     dist = ContinuousBernoulli(logits=tf.reshape(params, new_shape),
                                dtype=dtype or params.dtype.base_dtype,
                                lims=lims,
                                validate_args=validate_args)
   else:
     dist = Bernoulli(logits=tf.reshape(params, new_shape),
                      dtype=dtype or params.dtype.base_dtype,
                      validate_args=validate_args)
   dist = Independent(
     dist,
     reinterpreted_batch_ndims=tf.size(input=event_shape),
     name=name,
   )
   dist.logits = dist.distribution._logits  # pylint: disable=protected-access
   dist.probs = dist.distribution._probs  # pylint: disable=protected-access
   return dist
Example #6
    def __init__(self,
                 count_distribution,
                 inflated_distribution=None,
                 logits=None,
                 probs=None,
                 eps=1e-8,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="ZeroInflated"):
        r"""Initialize a zero-inflated distribution.

    A `ZeroInflated` is defined by a zero-inflation rate (`inflated_distribution`,
    representing the probabilities of excess zeros) and a `Distribution` object
    having matching dtype, batch shape, event shape, and continuity
    properties (the count distribution).

    Arguments:
      count_distribution : A `tfp.distributions.Distribution` instance.
        The instance must have `batch_shape` matching the zero-inflation
        distribution.
      inflated_distribution: `tfp.distributions.Bernoulli`-like instance.
        Manages the probability of excess zeros, the zero-inflated rate.
        Must have either scalar `batch_shape` or `batch_shape` matching
        `count_distribution.batch_shape`.
      logits: An N-D `Tensor` representing the log-odds of an excess zero,
        i.e. the zero-inflation rate; the probability of an excess zero is
        `sigmoid(logits)`.
        Only one of `logits` or `probs` should be passed in.
      probs: An N-D `Tensor` representing the probability of a zero event.
        Each entry in the `Tensor` parameterizes an independent
        ZeroInflated distribution.
        Only one of `logits` or `probs` should be passed in.
      validate_args: Python `bool`, default `False`. If `True`, raise a runtime
        error if batch or event ranks are inconsistent between the zero-inflation
        rate and the count distribution. This is only checked if the ranks cannot
        be determined statically at graph construction time.
      allow_nan_stats: Boolean, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: A name for this distribution (optional).

    References:
      Liu, L. & Blei, D.M.. (2017). Zero-Inflated Exponential Family Embeddings.
        Proceedings of the 34th International Conference on Machine Learning,
        in PMLR 70:2140-2148

    """
        parameters = dict(locals())
        with tf.compat.v1.name_scope(name) as name:
            # main count distribution
            if not isinstance(count_distribution, distribution.Distribution):
                raise TypeError(
                    "count_distribution must be a Distribution instance"
                    " but saw: %s" % count_distribution)
            # Zero inflation distribution
            if inflated_distribution is None:
                inflated_distribution = Bernoulli(
                    logits=logits,
                    probs=probs,
                    dtype=tf.int32,
                    validate_args=validate_args,
                    allow_nan_stats=allow_nan_stats,
                    name="ZeroInflatedRate")
            elif not isinstance(inflated_distribution,
                                distribution.Distribution):
                raise TypeError(
                    "inflated_distribution must be a Distribution instance"
                    " but saw: %s" % inflated_distribution)
            # Matching the event shape
            inflated_ndim = len(inflated_distribution.event_shape)
            count_ndim = len(count_distribution.event_shape)
            if inflated_ndim < count_ndim:
                inflated_distribution = Independent(inflated_distribution,
                                                    count_ndim - inflated_ndim)
            self._count_distribution = count_distribution
            self._inflated_distribution = inflated_distribution
            #
            if self._count_distribution.batch_shape.ndims is None:
                raise ValueError(
                    "Expected to know rank(batch_shape) from count_distribution"
                )
            if self._inflated_distribution.batch_shape.ndims is None:
                raise ValueError(
                    "Expected to know rank(batch_shape) from inflated_distribution"
                )
            self._eps = tensor_util.convert_nonref_to_tensor(
                eps, dtype_hint=count_distribution.dtype, name='eps')
        # We let the zero-inflated distribution access _graph_parents since it's
        # arguably more like a base class.
        super(ZeroInflated, self).__init__(
            dtype=self._count_distribution.dtype,
            reparameterization_type=self._count_distribution.
            reparameterization_type,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            name=name,
        )
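
A minimal usage sketch, assuming the `ZeroInflated` class defined above is in scope and using a standard tfp Poisson as the count distribution; `sigmoid(logits)` gives the probability of an excess zero at each position:

import tensorflow as tf
from tensorflow_probability import distributions as tfd

# `ZeroInflated` as defined in this example is assumed to be importable.
rate = tf.fill([8, 5], 3.0)             # Poisson rates, batch_shape = (8, 5)
zero_logits = tf.zeros([8, 5])          # sigmoid(0) = 0.5 excess-zero probability
zip_dist = ZeroInflated(count_distribution=tfd.Poisson(rate=rate),
                        logits=zero_logits)
print(zip_dist.log_prob(tf.zeros([8, 5])).shape)  # expected: (8, 5)
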
Example #7
  def __init__(self,
               count_distribution,
               inflated_distribution=None,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="ZeroInflated"):
    """Initialize a zero-inflated distribution.

    A `ZeroInflated` is defined by a zero-inflation rate (`inflated_distribution`,
    representing the probabilities of excess zeros) and a `Distribution` object
    having matching dtype, batch shape, event shape, and continuity
    properties (the count distribution).

    Parameters
    ----------
    count_distribution : A `tfp.distributions.Distribution` instance.
      The instance must have `batch_shape` matching the zero-inflation
      distribution.

    inflated_distribution: `tfp.distributions.Bernoulli`-like instance.
      Manages the probability of excess zeros, the zero-inflated rate.
      Must have either scalar `batch_shape` or `batch_shape` matching
      `count_distribution.batch_shape`.

    logits: An N-D `Tensor` representing the log-odds of an excess zero,
      i.e. the zero-inflation rate; the probability of an excess zero is
      `sigmoid(logits)`.
      Only one of `logits` or `probs` should be passed in.

    probs: An N-D `Tensor` representing the probability of a zero event.
      Each entry in the `Tensor` parameterizes an independent
      ZeroInflated distribution.
      Only one of `logits` or `probs` should be passed in.

    validate_args: Python `bool`, default `False`. If `True`, raise a runtime
      error if batch or event ranks are inconsistent between the zero-inflation
      rate and the count distribution. This is only checked if the ranks cannot
      be determined statically at graph construction time.

    allow_nan_stats: Boolean, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading to
      undefined statistics will return NaN for this statistic.

    name: A name for this distribution (optional).

    References
    ----------
    Liu, L. & Blei, D.M.. (2017). Zero-Inflated Exponential Family Embeddings.
    Proceedings of the 34th International Conference on Machine Learning,
    in PMLR 70:2140-2148

    """
    parameters = dict(locals())
    self._runtime_assertions = []

    with tf.compat.v1.name_scope(name) as name:
      if not isinstance(count_distribution, distribution.Distribution):
        raise TypeError("count_distribution must be a Distribution instance"
                        " but saw: %s" % count_distribution)
      self._count_distribution = count_distribution

      if inflated_distribution is None:
        inflated_distribution = Bernoulli(logits=logits,
                                          probs=probs,
                                          dtype=tf.int32,
                                          validate_args=validate_args,
                                          allow_nan_stats=allow_nan_stats,
                                          name="ZeroInflatedRate")
      elif not isinstance(inflated_distribution, distribution.Distribution):
        raise TypeError("inflated_distribution must be a Distribution instance"
                        " but saw: %s" % inflated_distribution)
      self._inflated_distribution = inflated_distribution

      if self._count_distribution.batch_shape.ndims is None:
        raise ValueError(
            "Expected to know rank(batch_shape) from count_disttribution")
      if self._inflated_distribution.batch_shape.ndims is None:
        raise ValueError(
            "Expected to know rank(batch_shape) from inflated_distribution")

      # Wrap in an Independent distribution so that the batch_shape of
      # inflated_distribution matches the batch_shape of count_distribution.
      inflated_batch_ndims = self._inflated_distribution.batch_shape.ndims
      count_batch_ndims = self._count_distribution.batch_shape.ndims
      if count_batch_ndims < inflated_batch_ndims:
        self._inflated_distribution = Independent(
            self._inflated_distribution,
            reinterpreted_batch_ndims=inflated_batch_ndims - count_batch_ndims,
            name="ZeroInflatedRate")
      elif count_batch_ndims > inflated_batch_ndims:
        raise ValueError(
            "count_distribution has a %d-D batch_shape, which is larger "
            "than the %d-D batch_shape of inflated_distribution" %
            (count_batch_ndims, inflated_batch_ndims))

      # Ensure that all batch and event ndims are consistent.
      if validate_args:
        self._runtime_assertions.append(
            tf.assert_equal(
                self._count_distribution.batch_shape_tensor(),
                self._inflated_distribution.batch_shape_tensor(),
                message=(
                    "dist batch shape must match logits|probs batch shape")))

    # We let the zero-inflated distribution access _graph_parents since it's
    # arguably more like a base class.
    reparameterization_type = [
        self._count_distribution.reparameterization_type,
        self._inflated_distribution.reparameterization_type
    ]
    if any(i == reparameterization.NOT_REPARAMETERIZED
           for i in reparameterization_type):
      reparameterization_type = reparameterization.NOT_REPARAMETERIZED
    else:
      reparameterization_type = reparameterization.FULLY_REPARAMETERIZED

    super(ZeroInflated,
          self).__init__(dtype=self._count_distribution.dtype,
                         reparameterization_type=reparameterization_type,
                         validate_args=validate_args,
                         allow_nan_stats=allow_nan_stats,
                         parameters=parameters,
                         graph_parents=self._count_distribution._graph_parents +
                         self._inflated_distribution._graph_parents,
                         name=name)
Example #8
            p1 = getattr(d1, name)
            p2 = getattr(d2, name)
            p = getattr(d, name)
            assert np.all(
                np.isclose(p.numpy(),
                           tf.concat((p1, p2), axis=0).numpy()))
    except NotImplementedError:
        pass


shape = (8, 2)
count = np.random.randint(0, 20, size=shape).astype('float32')
probs = np.random.rand(*shape).astype('float32')
logits = np.random.rand(*shape).astype('float32')

assert_consistent_statistics(Bernoulli(probs=probs), Bernoulli(logits=logits))
assert_consistent_statistics(Bernoulli(logits=logits),
                             Bernoulli(logits=logits))
assert_consistent_statistics(
    Independent(Bernoulli(probs=probs), reinterpreted_batch_ndims=1),
    Independent(Bernoulli(logits=logits), reinterpreted_batch_ndims=1))

assert_consistent_statistics(
    NegativeBinomial(total_count=count, logits=logits),
    NegativeBinomial(total_count=count, probs=probs))
assert_consistent_statistics(
    Independent(NegativeBinomial(total_count=count, logits=logits),
                reinterpreted_batch_ndims=1),
    Independent(NegativeBinomial(total_count=count, probs=probs),
                reinterpreted_batch_ndims=1))
assert_consistent_statistics(