Example #1
 def _log_prob(self, value: tf.Tensor):
     """ expect `value` is output from ELU function """
     params = self._params
     transformed_value, value = _switch_domain(
         value,
         inputs_domain=self.inputs_domain,
         low=self.low,
         high=self.high)
     ## prepare the parameters
     if self.n_channels == 1:
         component_logits, locs, scales = params
     else:
         channel_tensors = tf.split(transformed_value,
                                    self.n_channels,
                                    axis=-1)
         # If there is more than one channel, we create a linear autoregressive
         # dependency among the location parameters of the channels of a single
         # pixel (the scale parameters within a pixel are independent). For a pixel
         # with R/G/B channels, the `r`, `g`, and `b` saturation values are
         # distributed as:
         #
         # r ~ Logistic(loc_r, scale_r)
         # g ~ Logistic(coef_rg * r + loc_g, scale_g)
         # b ~ Logistic(coef_rb * r + coef_gb * g + loc_b, scale_b)
         component_logits, locs, scales, coeffs = params
         loc_tensors = tf.split(locs, self.n_channels, axis=-1)
         coef_tensors = tf.split(coeffs, self.n_coeffs, axis=-1)
         coef_count = 0
         for i in range(self.n_channels):
             channel_tensors[i] = channel_tensors[i][..., tf.newaxis, :]
             for j in range(i):
                 loc_tensors[i] += channel_tensors[j] * coef_tensors[coef_count]
                 coef_count += 1
         locs = tf.concat(loc_tensors, axis=-1)
     ## create the distribution
     mixture_distribution = Categorical(logits=component_logits)
     # Convert distribution parameters for pixel values in
     # `[self._low, self._high]` for use with `QuantizedDistribution`
     locs = self.low + 0.5 * (self.high - self.low) * (locs + 1.)
     scales = scales * 0.5 * (self.high - self.low)
     logistic_dist = QuantizedDistribution(
         distribution=TransformedDistribution(
             distribution=Logistic(loc=locs, scale=scales),
             bijector=Shift(shift=tf.cast(-0.5, self.dtype))),
         low=self.low,
         high=self.high,
     )
     dist = MixtureSameFamily(mixture_distribution=mixture_distribution,
                              components_distribution=Independent(
                                  logistic_dist,
                                  reinterpreted_batch_ndims=1))
     dist = Independent(dist, reinterpreted_batch_ndims=2)
     return dist.log_prob(value)
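The nested loop above builds the PixelCNN++ linear autoregression over channel locations. A minimal scalar sketch of the same update for one RGB pixel (hypothetical coefficient and channel values, independent of the class above):

# Scalar sketch of the autoregressive loc update for a single RGB pixel.
# coef_rg, coef_rb, coef_gb stand in for the learned `coeffs`; r and g are
# the (transformed) channel values in [-1, 1].
r, g = 0.2, -0.5
loc_r, loc_g, loc_b = 0.0, 0.1, -0.1
coef_rg, coef_rb, coef_gb = 0.3, 0.1, 0.2
# Same traversal order as the loop: i=1 consumes coef_rg; i=2 consumes
# coef_rb, then coef_gb.
loc_g = loc_g + coef_rg * r
loc_b = loc_b + coef_rb * r + coef_gb * g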
Example #2
 def _mean(self, **kwargs):
     params = self._params
     if self.n_channels == 1:
         component_logits, locs, scales = params
     else:
         # r ~ Logistic(loc_r, scale_r)
         # g ~ Logistic(coef_rg * r + loc_g, scale_g)
         # b ~ Logistic(coef_rb * r + coef_gb * g + loc_b, scale_b)
         component_logits, locs, scales, coeffs = params
         loc_tensors = tf.split(locs, self.n_channels, axis=-1)
         coef_tensors = tf.split(coeffs, self.n_coeffs, axis=-1)
         coef_count = 0
         for i in range(self.n_channels):
             for j in range(i):
                 loc_tensors[i] += loc_tensors[j] * coef_tensors[coef_count]
                 coef_count += 1
         locs = tf.concat(loc_tensors, axis=-1)
     ## create the distribution
     mixture_distribution = Categorical(logits=component_logits)
     # Convert distribution parameters for pixel values in
     # `[self._low, self._high]` for use with `QuantizedDistribution`
     locs = self.low + 0.5 * (self.high - self.low) * (locs + 1.)
     scales = scales * 0.5 * (self.high - self.low)
     logistic_dist = TransformedDistribution(
         distribution=Logistic(loc=locs, scale=scales),
         bijector=Shift(shift=tf.cast(-0.5, self.dtype)))
     dist = MixtureSameFamily(mixture_distribution=mixture_distribution,
                              components_distribution=Independent(
                                  logistic_dist,
                                  reinterpreted_batch_ndims=1))
     mean = Independent(dist, reinterpreted_batch_ndims=2).mean()
     ## normalize the data back to the input domain
     return _pixels_to(mean, self.inputs_domain, self.low, self.high)
Example #3
def MixtureQLogistic(
        locs: tf.Tensor,
        scales: tf.Tensor,
        logits: Optional[tf.Tensor] = None,
        probs: Optional[tf.Tensor] = None,
        batch_ndims: int = 0,
        low: int = 0,
        bits: int = 8,
        name: str = 'MixtureQuantizedLogistic') -> MixtureSameFamily:
    """ Mixture of quantized logistic distribution

  Parameters
  ----------
  locs : tf.Tensor
      locs of all logistics components, shape `[batch_size, n_components, event_size]`
  scales : tf.Tensor
      scales of all logistics components, shape `[batch_size, n_components, event_size]`
  logits, probs : tf.Tensor
      probability for the mixture Categorical distribution, shape `[batch_size, n_components]`
  batch_ndims : int, optional
      number of rightmost batch dimensions to reinterpret as event dimensions (via `Independent`), by default 0
  low : int, optional
      minimum quantized value, by default 0
  bits : int, optional
      number of bits for quantization, the maximum will be `2^bits - 1`, by default 8
  name : str, optional
      distribution name, by default 'MixtureQuantizedLogistic'

  Returns
  -------
  MixtureSameFamily
      the mixture of quantized logistic distribution

  Example
  -------
  ```
  d = MixtureQLogistic(np.ones((12, 3, 8)).astype('float32'),
                       np.ones((12, 3, 8)).astype('float32'),
                       logits=np.random.rand(12, 3).astype('float32'),
                       batch_ndims=1)
  ```

  Reference
  ---------
  Salimans, T., Karpathy, A., Chen, X., Kingma, D.P., 2017.
    PixelCNN++: Improving the PixelCNN with Discretized Logistic Mixture
    Likelihood and Other Modifications. arXiv:1701.05517 [cs, stat].

  """
    cats = Categorical(probs=probs, logits=logits)
    dists = Logistic(loc=locs, scale=scales)
    dists = TransformedDistribution(
        distribution=dists, bijector=Shift(shift=tf.cast(-0.5, dists.dtype)))
    dists = QuantizedDistribution(dists, low=low, high=2**bits - 1.)
    dists = Independent(dists, reinterpreted_batch_ndims=batch_ndims)
    dists = MixtureSameFamily(mixture_distribution=cats,
                              components_distribution=dists,
                              name=name)
    return dists
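A hedged usage sketch for the factory above (shapes follow the docstring, values are arbitrary; assumes the TFP names used by the function are in scope):

import numpy as np
import tensorflow as tf

d = MixtureQLogistic(np.ones((12, 3, 8)).astype('float32'),
                     np.ones((12, 3, 8)).astype('float32'),
                     logits=np.random.rand(12, 3).astype('float32'),
                     batch_ndims=1)
x = d.sample()       # shape [12, 8]: quantized values in [0, 255]
ll = d.log_prob(x)   # shape [12]: one log-likelihood per batch element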
Example #4
  def test_pad_mixture_dimensions_mixture_same_family(self):
    gm = MixtureSameFamily(
        mixture_distribution=Categorical(probs=[0.3, 0.7]),
        components_distribution=MultivariateNormalDiag(
            loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1.0, 0.5]))

    x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    x_pad = distribution_util.pad_mixture_dimensions(
        x, gm, gm.mixture_distribution, tensorshape_util.rank(gm.event_shape))
    x_out, x_pad_out = self.evaluate([x, x_pad])

    self.assertAllEqual(x_pad_out.shape, [2, 2, 1])
    self.assertAllEqual(x_out.reshape([-1]), x_pad_out.reshape([-1]))
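The test checks that `pad_mixture_dimensions` appends one singleton axis per event dimension so that per-component values broadcast against the mixture's event shape. A minimal standalone sketch of the equivalent reshape (not the TFP utility itself):

import tensorflow as tf

x = tf.constant([[1.0, 2.0], [3.0, 4.0]])  # shape [2, 2]
# The MultivariateNormalDiag components have event rank 1, so one trailing
# singleton axis yields shape [2, 2, 1], as asserted above.
x_pad = x[..., tf.newaxis]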
Example #5
    def set_prior(self,
                  loc=0.,
                  log_scale=np.log(np.expm1(1)),
                  mixture_logits=None):
        r""" Set the prior for mixture density network

    loc : Scalar or Tensor with shape `[n_components, event_size]`
    log_scale : Scalar or Tensor with shape
      `[n_components, event_size]` for 'none' and 'diag' component, and
      `[n_components, event_size*(event_size +1)//2]` for 'full' component.
    mixture_logits : Scalar or Tensor with shape `[n_components]`
    """
        event_size = self.event_size
        if self.covariance == 'diag':
            scale_shape = [self.n_components, event_size]
            fn = lambda l, s: MultivariateNormalDiag(
                loc=l, scale_diag=tf.nn.softplus(s))
        elif self.covariance == 'none':
            scale_shape = [self.n_components, event_size]
            fn = lambda l, s: Independent(
                Normal(loc=l, scale=tf.math.softplus(s)), 1)
        elif self.covariance == 'full':
            scale_shape = [
                self.n_components, event_size * (event_size + 1) // 2
            ]
            fn = lambda l, s: MultivariateNormalTriL(
                loc=l,
                scale_tril=FillScaleTriL(diag_shift=1e-5)(tf.math.softplus(s)))
        #
        if isinstance(loc, Number) or tf.rank(loc) == 0:
            loc = tf.fill([self.n_components, self.event_size], loc)
        #
        if isinstance(log_scale, Number) or tf.rank(log_scale) == 0:
            log_scale = tf.fill(scale_shape, log_scale)
        #
        if mixture_logits is None:
            p = 1. / self.n_components
            mixture_logits = np.log(p / (1. - p))
        if isinstance(mixture_logits, Number) or tf.rank(mixture_logits) == 0:
            mixture_logits = tf.fill([self.n_components], mixture_logits)
        #
        loc = tf.cast(loc, self.dtype)
        log_scale = tf.cast(log_scale, self.dtype)
        mixture_logits = tf.cast(mixture_logits, self.dtype)
        self._prior = MixtureSameFamily(
            components_distribution=fn(loc, log_scale),
            mixture_distribution=Categorical(logits=mixture_logits),
            name="prior")
        return self
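For reference, the 'diag' branch above is equivalent to assembling the prior directly with TensorFlow Probability. A minimal standalone sketch with hypothetical sizes (note `np.log(np.expm1(1.))` is the softplus inverse of 1, so the default scale is 1):

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

n_components, event_size = 4, 2   # hypothetical sizes
loc = tf.zeros([n_components, event_size])
log_scale = tf.fill([n_components, event_size], float(np.log(np.expm1(1.))))
prior = tfd.MixtureSameFamily(
    components_distribution=tfd.MultivariateNormalDiag(
        loc=loc, scale_diag=tf.nn.softplus(log_scale)),  # scale_diag == 1
    mixture_distribution=tfd.Categorical(logits=tf.zeros([n_components])),  # uniform weights
    name='prior')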
Example #6
def model_gmmprior(args: Arguments):
  nets = get_networks(args.ds, zdim=args.zdim, is_hierarchical=False,
                      is_semi_supervised=False)
  latent_size = int(np.prod(nets['latents'].event_shape))
  n_components = 100
  loc = tf.compat.v1.get_variable(name="loc", shape=[n_components, latent_size])
  raw_scale_diag = tf.compat.v1.get_variable(
    name="raw_scale_diag", shape=[n_components, latent_size])
  mixture_logits = tf.compat.v1.get_variable(
    name="mixture_logits", shape=[n_components])
  nets['latents'].prior = MixtureSameFamily(
    components_distribution=MultivariateNormalDiag(
      loc=loc,
      scale_diag=tf.nn.softplus(raw_scale_diag) + tf.math.exp(-7.)),
    mixture_distribution=Categorical(logits=mixture_logits),
    name="prior")
  return VariationalAutoencoder(**nets, name='GMMPrior')
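`tf.compat.v1.get_variable` is TF1-style; under eager TF2 the same trainable GMM prior could be sketched with `tf.Variable` (assumed equivalent, shapes unchanged):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

n_components, latent_size = 100, 32   # hypothetical latent size
loc = tf.Variable(tf.random.normal([n_components, latent_size]), name='loc')
raw_scale_diag = tf.Variable(tf.random.normal([n_components, latent_size]),
                             name='raw_scale_diag')
mixture_logits = tf.Variable(tf.zeros([n_components]), name='mixture_logits')
prior = tfd.MixtureSameFamily(
    components_distribution=tfd.MultivariateNormalDiag(
        loc=loc,
        scale_diag=tf.nn.softplus(raw_scale_diag) + tf.math.exp(-7.)),
    mixture_distribution=tfd.Categorical(logits=mixture_logits),
    name='prior')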
Example #7
def model_fullcovgmm(args: Arguments):
  nets = get_networks(args.ds, zdim=args.zdim, is_hierarchical=False,
                      is_semi_supervised=False)
  latent_size = int(np.prod(nets['latents'].event_shape))
  n_components = 100
  loc = tf.compat.v1.get_variable(name="loc", shape=[n_components, latent_size])
  raw_scale_diag = tf.compat.v1.get_variable(
    name="raw_scale_diag", shape=[n_components, latent_size])
  mixture_logits = tf.compat.v1.get_variable(
    name="mixture_logits", shape=[n_components])
  nets['latents'] = RVconf(
    event_shape=latent_size,
    projection=True,
    posterior='mvntril',
    prior=MixtureSameFamily(
      components_distribution=MultivariateNormalDiag(
        loc=loc,
        scale_diag=tf.nn.softplus(raw_scale_diag) + tf.math.exp(-7.)),
      mixture_distribution=Categorical(logits=mixture_logits),
      name="prior"),
    name='latents').create_posterior()
  return VariationalAutoencoder(**nets, name='FullCov')