Example #1
 def elbo_components(self, inputs, training=None, mask=None):
     ## unsupervised ELBO
     X, y, mask = prepare_ssl_inputs(inputs,
                                     mask=mask,
                                     n_unsupervised_inputs=1)
     if mask is not None:
         mask = tf.reshape(mask, (-1, ))
     llk, kl = super(AnnealingVAE, self).elbo_components(X[0],
                                                         mask=mask,
                                                         training=training)
     P, Q = self.last_outputs
     px_z = P[:-1]  # reconstruction distributions p(x|z)
     py_z = P[-1]  # predicted label distribution p(y|z)
     Q = as_tuple(Q)  # q(z|x)
     ## supervised loss
     llk[f"llk_{self.labels.name}"] = _get_llk_y(py_z, y, mask, self.alpha)
     ## MI objective
     # temporarily point self.labels at the q-network while estimating MI,
     # then restore the p-network afterwards
     self.labels = self.labels_q
     mi_y, mi_z = self._mi_loss(Q, py_z, training=training, mask=mask)
     self.labels = self.labels_p
     ## maximizing the MI
     llk[f'mi_{self.labels.name}'] = mi_y
     for z, mi in zip(as_tuple(self.latents), mi_z):
         llk[f'mi_{z.name}'] = mi
     return llk, kl
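
Note: all ten examples share the same first step. prepare_ssl_inputs normalizes the batch into a list of inputs X, a list of labels y, and a mask marking which rows of the batch are labeled. The stand-in below is inferred purely from how the return values are used in these examples; the library's real helper may differ.

import tensorflow as tf

def prepare_ssl_inputs_sketch(inputs, mask=None, n_unsupervised_inputs=1):
    # Assumption: `inputs` is a tensor or a tuple whose first
    # n_unsupervised_inputs entries are the inputs X and whose remaining
    # entries are the labels y; `mask` flags the labeled examples.
    if not isinstance(inputs, (tuple, list)):
        inputs = [inputs]
    X = list(inputs[:n_unsupervised_inputs])
    y = list(inputs[n_unsupervised_inputs:])
    if mask is not None:
        mask = tf.reshape(tf.cast(mask, tf.bool), (-1,))
    return X, y, mask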
Example #2
 def elbo_components(self, inputs, training=None, mask=None):
     ## unsupervised ELBO
     X, y, mask = prepare_ssl_inputs(inputs,
                                     mask=mask,
                                     n_unsupervised_inputs=1)
     if mask is not None:
         mask = tf.reshape(mask, (-1, ))
     llk, kl = super(AnnealingVAE,
                     self).elbo_components(X[0],
                                           mask=mask,
                                           training=training)
     P, Q = self.last_outputs
     kl[f'kl_{self.latents_y.name}'] = self.beta * Q[-1].KL_divergence(
         analytic=self.analytic,
         free_bits=self.free_bits,
         reverse=self.reverse)
     py_z = P[-1]
     ## supervised loss
     llk[f"llk_{self.labels.name}"] = _get_llk_y(
         py_z, y, mask, self.alpha)
     ## MI objective
     mi_y, mi_z = self._mi_loss(Q,
                                py_z,
                                training=training,
                                mask=mask,
                                which_latents_sampling=[1])
     ## maximizing the MI
     llk[f'mi_{self.labels.name}'] = mi_y
     for z, mi in zip(as_tuple(self.latents), mi_z):
         llk[f'mi_{z.name}'] = mi
     return llk, kl
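
Note: Example #2 adds a separate, beta-scaled KL term for the label latent and passes analytic, free_bits, and reverse through to KL_divergence. The free-bits trick (Kingma et al., 2016) keeps the KL from collapsing below a floor; below is a minimal sketch of the idea only, an assumption about the intent rather than the library's actual KL_divergence implementation.

import tensorflow as tf

def free_bits_kl(kl, free_bits=0.5):
    # gradients vanish once a dimension's KL drops below the floor, so the
    # posterior retains at least `free_bits` nats of information
    return tf.maximum(kl, free_bits)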
Example #3
 def encode(self, inputs, training=None, mask=None, **kwargs):
     X, y, mask = prepare_ssl_inputs(inputs,
                                     mask=mask,
                                     n_unsupervised_inputs=1)
     # don't condition on the labels, only accept inputs
     X = X[0]
     qz_x = super().encode(X, training=training, mask=None, **kwargs)
     return qz_x
Example #4
 def encode(self, inputs, training=None, mask=None, **kwargs):
   X, y, mask = prepare_ssl_inputs(inputs, mask=mask, n_unsupervised_inputs=1)
  X = X[0]  # only a single input is supported
   # encode normally
   h_x = self.encoder(X, training=training, mask=mask)
   qz_x = self.latents(h_x, training=training, mask=mask)
   qzc_x = self.denotations(h_x, training=training, mask=mask)
   # prepare the label embedding
   z_c = tf.convert_to_tensor(qzc_x)
   qy_zx = self.classify(z_c, training=training)
   return (qz_x, qzc_x, qy_zx)
Example #5
 def encode(self, inputs, training=None, mask=None, **kwargs):
   X, y, mask = prepare_ssl_inputs(inputs, mask=mask, n_unsupervised_inputs=1)
  X = X[0]  # only a single input is supported
   # prepare the label embedding
   qy_x = self.classify(X, training=training)
   h_y = self.y_to_qz(qy_x, training=training)
   # encode normally
   h_x = self.encoder(X, training=training, mask=mask)
   h_x = self.flatten(h_x)
   h_x = self.x_to_qz(h_x, training=training)
   # combine into q(z|xy)
   h_xy = self.concat([h_x, h_y])
   # conditional embedding y
   h_xy = self.xy_to_qz_net(h_xy, training=training, mask=mask)
   qz_xy = self.latents(h_xy, training=training, mask=mask)
   qz_xy.qy_x = qy_x
   return qz_xy
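
Note: Examples #5 and #8 build q(z|x,y) by embedding the classifier output q(y|x) and merging it with the encoder features, by concatenation in #5 and by summation in #8 (summation requires both embeddings to have the same width). Below is a minimal self-contained sketch of this conditioning pattern; the class and layer names are hypothetical, not the library's.

import tensorflow as tf
from tensorflow.keras import layers

class ConditionalEncoder(tf.keras.Model):
    # builds q(z|x,y) by merging encoder features with a label embedding

    def __init__(self, n_classes=10, zdim=32):
        super().__init__()
        self.encoder = layers.Dense(256, activation='relu')
        self.classifier = layers.Dense(n_classes, activation='softmax')
        self.y_to_qz = layers.Dense(64, activation='relu')
        self.x_to_qz = layers.Dense(64, activation='relu')
        self.latents = layers.Dense(zdim)  # stands in for a distribution layer

    def call(self, X, training=None):
        qy_x = self.classifier(X)              # label belief q(y|x)
        h_y = self.y_to_qz(qy_x)               # embed the label belief
        h_x = self.x_to_qz(self.encoder(X))    # embed the input features
        h_xy = tf.concat([h_x, h_y], axis=-1)  # concatenate, as in Example #5
        return self.latents(h_xy), qy_x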
Example #6
def _prepare_elbo(self, inputs, training=None, mask=None):
  X, y, mask = prepare_ssl_inputs(inputs, mask=mask, n_unsupervised_inputs=1)
  X_u, X_l, y_l = split_ssl_inputs(X, y, mask)
  # for simplicity, only one input and one label are supported
  X_u, X_l = X_u[0], X_l[0]
  if len(y_l) > 0:
    y_l = y_l[0]
  else:
    y_l = None
  # marginalize the unsupervised data
  if self.marginalize:
    X_u, y_u = marginalize_categorical_labels(
      X=X_u,
      n_classes=self.n_classes,
      dtype=self.dtype,
    )
  else:
    y_u = None
  return X_u, y_u, X_l, y_l
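
Note: marginalize_categorical_labels in Example #6 corresponds to the standard trick from Kingma et al. (2014) for the M2 model: each unlabeled input is replicated once per class and paired with the matching one-hot label, so the unknown label can be summed out of the ELBO. Below is a sketch of that tiling; the real helper may differ in argument handling.

import tensorflow as tf

def marginalize_categorical_labels_sketch(X, n_classes, dtype=tf.float32):
    batch = tf.shape(X)[0]
    # repeat every example once per class: [x1, x1, ..., x2, x2, ...]
    X_rep = tf.repeat(X, repeats=n_classes, axis=0)
    # pair each copy with one candidate one-hot label: [0, 1, ..., 0, 1, ...]
    y = tf.tile(tf.eye(n_classes, dtype=dtype), multiples=[batch, 1])
    return X_rep, y

With n_classes=3 and a batch of 2 this yields 6 rows, pairing each input with labels 0, 1, and 2, so the ELBO can be evaluated under every candidate label and weighted by q(y|x).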
Example #7
 def elbo_components(self, inputs, training=None, mask=None):
     X, y, mask = prepare_ssl_inputs(inputs,
                                     mask=mask,
                                     n_unsupervised_inputs=1)
     X = X[0]
     mask = tf.reshape(mask, (-1, ))
     X_u = tf.boolean_mask(X, tf.logical_not(mask), axis=0)
     X_l = tf.boolean_mask(X, mask, axis=0)
     y_l = tf.boolean_mask(y[0], mask, axis=0)
     ## supervised
     llk_l, kl_l = super(AnnealingVAE,
                         self).elbo_components(X_l, training=training)
     P_l, Q_l = self.last_outputs
     ## unsupervised
     llk_u, kl_u = super(AnnealingVAE,
                         self).elbo_components(X_u, training=training)
     P_u, Q_u = self.last_outputs
     ## merge the losses
     llk = {}
     for k, v in llk_l.items():
         llk[k] = tf.concat([v, llk_u[k]], axis=0)
     kl = {}
     for k, v in kl_l.items():
         kl[k] = tf.concat([v, kl_u[k]], axis=0)
     ## supervised loss
     py_z = P_l[-1]
     llk[f"llk_{self.labels.name}"] = tf.reduce_mean(self.alpha *
                                                     py_z.log_prob(y_l))
     ## minimizing D(q(y|z_u)||p(y|z_l)) objective
     # calculate the pair-wise distance between q(y|z) and p(y|z)
     qy_z = P_u[-1]
     y = tf.convert_to_tensor(qy_z)
     tf.assert_equal(
         tf.shape(X_u), tf.shape(X_l),
         'Require the number of labeled examples to equal the number of '
         'unlabeled examples')
     kl[f'kl_{self.labels.name}'] = self.alpha * tf.reduce_mean(
         qy_z.log_prob(y) - py_z.log_prob(y))
     # llk_q = tf.expand_dims(qy_z.log_prob(y), axis=-1)
     # llk_p = py_z.log_prob(tf.expand_dims(y, axis=-2))
     # mi_y = tf.reduce_mean(llk_q - llk_p)
     # kl[f'kl_{self.labels.name}'] = self.mi_coef * mi_y
     ## return
     return llk, kl
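
Note: Example #7 runs two full ELBO passes, one over the labeled rows and one over the unlabeled rows, selecting each subset with tf.boolean_mask before concatenating the per-term losses. A toy illustration of the split, independent of the library:

import tensorflow as tf

X = tf.reshape(tf.range(8, dtype=tf.float32), (4, 2))
mask = tf.constant([True, False, True, False])           # True = labeled row
X_l = tf.boolean_mask(X, mask, axis=0)                   # labeled examples
X_u = tf.boolean_mask(X, tf.logical_not(mask), axis=0)   # unlabeled examples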
Example #8
 def encode(self, inputs, training=None, mask=None, **kwargs):
   X, y, mask = prepare_ssl_inputs(inputs, mask=mask, n_unsupervised_inputs=1)
  X = X[0]  # only a single input is supported
   # prepare the label embedding
   qy_x = self.classify(X, training=training)
   h_y = self.y_to_qz(qy_x, training=training)
   # encode normally
   h_x = self.encoder(X, training=training, mask=mask)
   h_x = bk.flatten(h_x, n_outdim=2)
   h_x = self.x_to_qz(h_x, training=training)
   # combine into q(z|xy)
   h_xy = h_x + h_y
   if self.batchnorm:
     h_xy = self.qz_xy_norm(h_xy, training=training)
   if 0.0 < self.dropout < 1.0:
     h_xy = self.qz_xy_drop(h_xy, training=training)
   # conditional embedding y
   h_xy = self.xy_to_qz_net(h_xy, training=training, mask=mask)
   qz_xy = self.latents(h_xy, training=training, mask=mask)
   return (qz_xy, qy_x)
Example #9
 def encode(self, inputs, training=None, mask=None, **kwargs):
     X, y, mask = prepare_ssl_inputs(inputs,
                                     mask=mask,
                                     n_unsupervised_inputs=1)
     # don't condition on the labels, only accept inputs
     X = X[0]
     h_e = super().encode(X,
                          training=training,
                          mask=None,
                          only_encoding=True,
                          **kwargs)
     qz_x = self.latents(h_e,
                         training=training,
                         mask=None,
                         sample_shape=self.sample_shape)
     if self.encoder_y is not None:
         # tied encoder
         h_y = h_e if len(self.encoder_y.layers) == 1 else X
         qzy_x = self.encoder_y(h_y, training=training, mask=None)
         return as_tuple(qz_x) + (qzy_x, )
     return qz_x
Example #10
 def encode(self, inputs, training=None, mask=None, **kwargs):
     X, y, mask = prepare_ssl_inputs(inputs,
                                     mask=mask,
                                     n_unsupervised_inputs=1)
     return super().encode(X[0], training=training, mask=None, **kwargs)
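
Note: taken together, the encode variants above share one calling convention: inputs is the pair (X, y) and mask is a boolean vector marking which rows carry real labels. A hypothetical call follows, with model standing in for any of the semi-supervised VAE classes shown; this is an assumption, not the library's documented API.

import tensorflow as tf

X_batch = tf.random.normal((32, 784))               # raw inputs
y_batch = tf.one_hot(tf.zeros([32], tf.int32), 10)  # labels (dummy rows allowed)
is_labeled = tf.random.uniform((32,)) < 0.5         # True where y_batch is real

# qz = model.encode((X_batch, y_batch), training=True, mask=is_labeled)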