def dtc_loss(self, qZ_X, qZ_Xprime=None, training=None):
  r""" Discriminated total correlation loss Algorithm(2)

  Minimize the probability of:
   - `q(z)` misclassified as `D(z)[:, 0]`
   - `q(z')` misclassified as `D(z')[:, 1]`

  Arguments:
    qZ_X : `Tensor` or `Distribution`.
      Samples of the latents from the first batch.
    qZ_Xprime : `Tensor` or `Distribution` (optional).
      Samples of the latents from the second batch; these will be permuted.
      If not given, reuse `qZ_X`.

  Return:
    scalar - loss value for training the discriminator
  """
  # we don't want the gradient to be propagated to the encoder
  z = self._to_samples(qZ_X)
  z = tf.stop_gradient(z)
  z_logits = self(z, training=training)
  d_z = -tf.nn.log_softmax(z_logits, axis=-1)  # must be negative here
  # for the second batch X_prime (reuse z from the first batch if not given)
  if qZ_Xprime is not None:
    z = self._to_samples(qZ_Xprime)
    z = tf.stop_gradient(z)
  z_perm = permute_dims(z)
  zperm_logits = self(z_perm, training=training)
  d_zperm = -tf.nn.log_softmax(zperm_logits, axis=-1)
  # reduce the negative of d_z and the positive of d_zperm;
  # this is equal to cross_entropy(d_z, zeros) + cross_entropy(d_zperm, ones)
  loss = 0.5 * (tf.reduce_mean(d_z[..., 0]) +
                tf.reduce_mean(d_zperm[..., -1]))
  return loss
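# --------------------------------------------------------------------------
# NOTE: `permute_dims` is called above but not defined in this excerpt.  The
# sketch below is only an assumption of what it does, following the
# permutation step of FactorVAE (Kim & Mnih, 2018, Algorithm 1): each latent
# dimension is shuffled independently across the batch axis, which destroys
# the dependencies between dimensions.  The name `permute_dims_sketch` is
# hypothetical; the library's actual implementation may differ.
import tensorflow as tf

def permute_dims_sketch(z):
  """Shuffle each latent dimension independently across the batch axis."""
  batch_size = tf.shape(z)[0]
  permuted = []
  for i in range(z.shape[-1]):  # assumes a static number of latent dimensions
    idx = tf.random.shuffle(tf.range(batch_size))
    permuted.append(tf.gather(z[:, i], idx))
  return tf.stack(permuted, axis=-1)  # [batch_size, zdim]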
def _elbo(self, X, pX_Z, qZ_X, analytic, reverse, sample_shape, mask, training):
  # don't take KL of qC_X
  llk, div = super()._elbo(X, pX_Z, qZ_X, analytic, reverse, sample_shape)
  z_prime = [permute_dims(q) for q in qZ_X]
  pX_Zprime = self.decode(z_prime, training=training)
  qZ_Xprime = self.encode(pX_Zprime, training=training)
  div['mmd'] = self.gamma * maximum_mean_discrepancy(
      qZ=qZ_Xprime,
      pZ=qZ_X[0].KL_divergence.prior,
      q_sample_shape=None,
      p_sample_shape=100)
  return llk, div
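# --------------------------------------------------------------------------
# NOTE: `maximum_mean_discrepancy` above takes distributions and sample
# shapes; the sketch below is only an illustrative assumption of the quantity
# it estimates, using plain sample tensors and a Gaussian (RBF) kernel.  The
# names `_rbf_kernel_sketch`, `mmd_sketch` and the `sigma` bandwidth are
# hypothetical and not part of the library's API.
import tensorflow as tf

def _rbf_kernel_sketch(x, y, sigma=1.0):
  # pairwise squared Euclidean distances -> Gaussian (RBF) kernel matrix
  sq_dist = tf.reduce_sum(tf.square(x[:, None, :] - y[None, :, :]), axis=-1)
  return tf.exp(-sq_dist / (2.0 * sigma ** 2))

def mmd_sketch(qz_samples, pz_samples, sigma=1.0):
  # biased estimate of squared MMD: E[k(q, q)] + E[k(p, p)] - 2 E[k(q, p)]
  k_qq = tf.reduce_mean(_rbf_kernel_sketch(qz_samples, qz_samples, sigma))
  k_pp = tf.reduce_mean(_rbf_kernel_sketch(pz_samples, pz_samples, sigma))
  k_qp = tf.reduce_mean(_rbf_kernel_sketch(qz_samples, pz_samples, sigma))
  return k_qq + k_pp - 2.0 * k_qp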
def dtc_loss(self,
             qZ_X: Distribution,
             qZ_Xprime: Optional[Distribution] = None,
             training: Optional[bool] = None) -> tf.Tensor:
  r""" Discriminated total correlation loss Algorithm(2)

  Minimize the probability of:
   - `q(z)` misclassified as `D(z)[:, 0]`
   - `q(z')` misclassified as `D(z')[:, 1]`

  Arguments:
    qZ_X : `Tensor` or `Distribution`.
      Samples of the latents from the first batch.
    qZ_Xprime : `Tensor` or `Distribution` (optional).
      Samples of the latents from the second batch; these will be permuted.
      If not given, reuse `qZ_X`.

  Return:
    scalar - loss value for training the discriminator
  """
  # we don't want the gradient to be propagated to the encoder
  z = self._to_samples(qZ_X, stop_grad=True)
  z_logits = self._tc_logits(self(z, training=training))
  # the built-in log-sigmoid gives more numerically stable results than
  # computing the log-sum-exp yourself
  d_z = -tf.math.log_sigmoid(z_logits)  # must be negative here
  # for X_prime
  if qZ_Xprime is not None:
    z = self._to_samples(qZ_Xprime, stop_grad=True)
  z_perm = permute_dims(z)
  zperm_logits = self._tc_logits(self(z_perm, training=training))
  d_zperm = -tf.math.log_sigmoid(zperm_logits)  # also negative here
  # reduce the negative of d_z and the positive of d_zperm;
  # this is equal to cross_entropy(d_z, zeros) + cross_entropy(d_zperm, ones)
  loss = 0.5 * (tf.reduce_mean(d_z) +
                tf.reduce_mean(zperm_logits + d_zperm))
  return loss
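# --------------------------------------------------------------------------
# NOTE: a small numerical check (not part of the original code) showing that
# the single-logit expression used in `dtc_loss` above equals the standard
# discriminator binary cross-entropy, with samples of q(z) labeled 1 and the
# permuted batch labeled 0.  `logits_z` and `logits_zperm` are placeholder
# tensors standing in for the discriminator outputs.
import tensorflow as tf

logits_z = tf.random.normal([8])
logits_zperm = tf.random.normal([8])

# the expression used in dtc_loss: -log_sigmoid(x) and x - log_sigmoid(x)
manual = 0.5 * (tf.reduce_mean(-tf.math.log_sigmoid(logits_z)) +
                tf.reduce_mean(logits_zperm - tf.math.log_sigmoid(logits_zperm)))

# the same loss written as sigmoid cross-entropy with explicit labels
xent = 0.5 * (tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                  labels=tf.ones_like(logits_z), logits=logits_z)) +
              tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                  labels=tf.zeros_like(logits_zperm), logits=logits_zperm)))

tf.debugging.assert_near(manual, xent, atol=1e-5)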