Example #1
    def dtc_loss(self, qZ_X, qZ_Xprime=None, training=None):
        r""" Discriminated total correlation loss Algorithm(2)

    Minimize the probability of:
     - `q(z)` misclassified as `D(z)[:, 0]`
     - `q(z')` misclassified as `D(z')[:, 1]`

    Arguments:
      qZ_X : `Tensor` or `Distribution`.
        Samples of the latents from first batch
      qZ_Xprime : `Tensor` or `Distribution` (optional).
        Samples of the latents from second batch, this will be permuted.
        If not given, then reuse `qZ_X`.

    Return:
      scalar - loss value for training the discriminator
    """
        # we don't want the gradient to be propagated to the encoder
        z = self._to_samples(qZ_X)
        z = tf.stop_gradient(z)
        z_logits = self(z, training=training)
        d_z = -tf.nn.log_softmax(z_logits, axis=-1)  # must be negative here
        # for X_prime
        if qZ_Xprime is not None:
            z = self._to_samples(qZ_Xprime)
            z = tf.stop_gradient(z)
        z_perm = permute_dims(z)
        zperm_logits = self(z_perm, training=training)
        d_zperm = -tf.nn.log_softmax(zperm_logits, axis=-1)
        # reduce the negative of d_z, and the positive of d_zperm
        # this is equal to cross_entropy(d_z, zeros) + cross_entropy(d_zperm, ones)
        loss = 0.5 * (tf.reduce_mean(d_z[..., 0]) +
                      tf.reduce_mean(d_zperm[..., -1]))
        return loss
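
Both this example and the ones below call `permute_dims`, which is imported from the surrounding library rather than defined here. The following is a minimal sketch of the assumed behaviour, following the FactorVAE permutation trick: every latent dimension is shuffled independently across the batch, which keeps the marginals intact while breaking correlations between dimensions. `permute_dims_sketch` is a hypothetical name used only for illustration.

import tensorflow as tf

def permute_dims_sketch(z):
  # z : Tensor of shape [batch_size, zdim]
  z = tf.convert_to_tensor(z)
  batch_ids = tf.range(tf.shape(z)[0], dtype=tf.int32)
  columns = []
  for i in range(z.shape[-1]):
    # shuffle the batch indices separately for every latent dimension
    columns.append(tf.gather(z[:, i], tf.random.shuffle(batch_ids)))
  return tf.stack(columns, axis=-1)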
Example #2
 def _elbo(self, X, pX_Z, qZ_X, analytic, reverse, sample_shape, mask,
           training):
   # don't take KL of qC_X
   llk, div = super()._elbo(X, pX_Z, qZ_X, analytic, reverse, sample_shape)
   z_prime = [permute_dims(q) for q in qZ_X]
   pX_Zprime = self.decode(z_prime, training=training)
   qZ_Xprime = self.encode(pX_Zprime, training=training)
   div['mmd'] = self.gamma * maximum_mean_discrepancy(
       qZ=qZ_Xprime,
       pZ=qZ_X[0].KL_divergence.prior,
       q_sample_shape=None,
       p_sample_shape=100)
   return llk, div
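
For reference, here is a minimal sketch of a maximum mean discrepancy estimate between two sample sets, assuming an RBF kernel. The `maximum_mean_discrepancy` helper used above is a library function that additionally accepts `Distribution` objects and the `q_sample_shape`/`p_sample_shape` arguments, so this only approximates what it computes.

import tensorflow as tf

def rbf_kernel(x, y, sigma=1.0):
  # pairwise squared Euclidean distances between rows of x and rows of y
  sq = (tf.reduce_sum(x ** 2, axis=-1)[:, None]
        - 2.0 * tf.matmul(x, y, transpose_b=True)
        + tf.reduce_sum(y ** 2, axis=-1)[None, :])
  return tf.exp(-sq / (2.0 * sigma ** 2))

def mmd_sketch(q_samples, p_samples, sigma=1.0):
  # biased V-statistic estimate: E[k(q, q)] - 2 E[k(q, p)] + E[k(p, p)]
  k_qq = tf.reduce_mean(rbf_kernel(q_samples, q_samples, sigma))
  k_pp = tf.reduce_mean(rbf_kernel(p_samples, p_samples, sigma))
  k_qp = tf.reduce_mean(rbf_kernel(q_samples, p_samples, sigma))
  return k_qq + k_pp - 2.0 * k_qp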
Example #3
    def dtc_loss(self,
                 qZ_X: Distribution,
                 qZ_Xprime: Optional[Distribution] = None,
                 training: Optional[bool] = None) -> tf.Tensor:
        r""" Discriminated total correlation loss Algorithm(2)

    Minimize the probability of:
     - `q(z)` misclassified as `D(z)[:, 0]`
     - `q(z')` misclassified as `D(z')[:, 1]`

    Arguments:
      qZ_X : `Tensor` or `Distribution`.
        Samples of the latents from first batch
      qZ_Xprime : `Tensor` or `Distribution` (optional).
        Samples of the latents from second batch, this will be permuted.
        If not given, then reuse `qZ_X`.

    Return:
      scalar - loss value for training the discriminator
    """
        # we don't want the gradient to be propagated to the encoder
        z = self._to_samples(qZ_X, stop_grad=True)
        z_logits = self._tc_logits(self(z, training=training))
        # the built-in log_sigmoid gives more numerically stable results than
        # computing the log-sum-exp yourself
        d_z = -tf.math.log_sigmoid(z_logits)  # must be negative here
        # for X_prime
        if qZ_Xprime is not None:
            z = self._to_samples(qZ_Xprime, stop_grad=True)
        z_perm = permute_dims(z)
        zperm_logits = self._tc_logits(self(z_perm, training=training))
        d_zperm = -tf.math.log_sigmoid(zperm_logits)  # also negative here
        # reduce the negative of d_z, and the positive of d_zperm
        # this is equal to cross_entropy(d_z, zeros) + cross_entropy(d_zperm, ones)
        loss = 0.5 * (tf.reduce_mean(d_z) +
                      tf.reduce_mean(zperm_logits + d_zperm))
        return loss
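
A hedged usage sketch of how a discriminator exposing `dtc_loss` might be updated in a custom training step. `vae`, `discriminator` and `disc_optimizer` are hypothetical objects, not part of the snippets above, and the library's actual training loop may differ.

import tensorflow as tf

@tf.function
def discriminator_step(x, x_prime):
  # encode two independent batches; gradients to the encoder are already
  # blocked inside dtc_loss via tf.stop_gradient
  qZ_X = vae.encode(x, training=True)
  qZ_Xprime = vae.encode(x_prime, training=True)
  with tf.GradientTape() as tape:
    loss = discriminator.dtc_loss(qZ_X, qZ_Xprime, training=True)
  grads = tape.gradient(loss, discriminator.trainable_variables)
  disc_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))
  return loss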