def _build(self, unused_input=None):
    # pylint:disable=unused-variable,possibly-unused-variable
    # Reason:
    #   All endpoints are stored as attributes at the end of `_build`.
    #   Pylint cannot infer this, so it emits false unused-variable
    #   warnings unless we disable them here.

    # pylint:disable=invalid-name
    # Reason:
    #   Pylint considers the following variable names invalid, so we
    #   disable the warning:
    #   - Variables that hold classes.
    #   - Variables whose names contain A or B, indicating which side of
    #     the data they belong to.

    # ---------------------------------------------------------------------
    # ## Extract parameters from config
    # ---------------------------------------------------------------------

    config = self.config
    lr = config.get('lr', 3e-4)
    n_latent_shared = config['n_latent_shared']

    if 'n_latent' in config:
      n_latent_A = n_latent_B = config['n_latent']
    else:
      n_latent_A = config['vae_A']['n_latent']
      n_latent_B = config['vae_B']['n_latent']
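
    # A hypothetical config illustrating the keys read above (the values
    # are illustrative assumptions, not taken from this code):
    #   config = {
    #       'lr': 3e-4,
    #       'n_latent_shared': 8,
    #       'vae_A': {'n_latent': 64, ...},
    #       'vae_B': {'n_latent': 64, ...},
    #   }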

    # ---------------------------------------------------------------------
    # ## VAEs: Modules containing the parameters
    # ---------------------------------------------------------------------
    vae_A = VAE(config['vae_A'], name='vae_A')
    vae_A()
    vae_B = VAE(config['vae_B'], name='vae_B')
    vae_B()
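    # Calling each VAE once connects it into the graph, populating the
    # endpoint attributes (`x`, `mu`, `sigma`, `q_z_sample`, ...) read below.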

    vae_lr = tf.constant(lr)
    vae_vars = vae_A.vae_vars + vae_B.vae_vars
    vae_loss = vae_A.vae_loss + vae_B.vae_loss
    train_vae = tf.train.AdamOptimizer(learning_rate=vae_lr).minimize(
        vae_loss, var_list=vae_vars)
    vae_saver = tf.train.Saver(vae_vars, max_to_keep=100)
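    # `train_vae` optimizes only the two per-side VAE losses; `train_full`
    # below adds the alignment terms on top. `vae_saver` checkpoints the
    # same variables, so the two VAEs can be saved and restored on their own.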

    # ---------------------------------------------------------------------
    # ## Computation Flow
    # ---------------------------------------------------------------------

    # Tensor Endpoints
    x_A = vae_A.x
    x_B = vae_B.x
    q_z_sample_A = vae_A.q_z_sample
    q_z_sample_B = vae_B.q_z_sample
    mu_A, sigma_A = vae_A.mu, vae_A.sigma
    mu_B, sigma_B = vae_B.mu, vae_B.sigma
    x_prime_A = vae_A.x_prime
    x_prime_B = vae_B.x_prime
    x_from_prior_A = vae_A.x_from_prior
    x_from_prior_B = vae_B.x_from_prior
    x_A_to_B = vae_B.decoder(q_z_sample_A)
    x_B_to_A = vae_A.decoder(q_z_sample_B)
    x_A_to_B_direct = vae_B.decoder(mu_A)
    x_B_to_A_direct = vae_A.decoder(mu_B)
    z_hat = tf.placeholder(tf.float32, shape=(None, n_latent_shared))
    x_joint_A = vae_A.decoder(z_hat)
    x_joint_B = vae_B.decoder(z_hat)
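
    # `z_hat` is a shared latent fed in at sampling time: decoding the same
    # `z_hat` through both decoders yields a paired (x_joint_A, x_joint_B).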

    vae_loss_A = vae_A.vae_loss
    vae_loss_B = vae_B.vae_loss

    x_align_A = tf.placeholder(tf.float32, shape=(None, n_latent_A))
    x_align_B = tf.placeholder(tf.float32, shape=(None, n_latent_B))
    mu_align_A, sigma_align_A = vae_A.encoder(x_align_A)
    mu_align_B, sigma_align_B = vae_B.encoder(x_align_B)
    q_z_align_A = ds.Normal(loc=mu_align_A, scale=sigma_align_A)
    q_z_align_B = ds.Normal(loc=mu_align_B, scale=sigma_align_B)
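
    # Per-side posteriors over the shared latent, inferred from aligned
    # (paired) data; they feed the joint posterior and the transfer terms
    # below.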

    # Variational inference (VI) in the joint latent space

    mu_align, sigma_align = nn.product_two_guassian_pdfs(
        mu_align_A, sigma_align_A, mu_align_B, sigma_align_B)
    q_z_align = ds.Normal(loc=mu_align, scale=sigma_align)
    p_z_align = ds.Normal(loc=0., scale=1.)
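
    # Assuming `nn.product_two_guassian_pdfs` implements the standard
    # (renormalized) product of two Gaussian densities, elementwise:
    #   sigma^2 = sigma_A^2 * sigma_B^2 / (sigma_A^2 + sigma_B^2)
    #   mu = (mu_A * sigma_B^2 + mu_B * sigma_A^2) / (sigma_A^2 + sigma_B^2)
    # i.e. the joint posterior is a precision-weighted fusion of both sides.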

    # - KL
    KL_qp_align = ds.kl_divergence(q_z_align, p_z_align)
    KL_align = tf.reduce_sum(KL_qp_align, axis=-1)
    mean_KL_align = tf.reduce_mean(KL_align)
    prior_loss_align = mean_KL_align
    prior_loss_align_beta = config.get('prior_loss_align_beta', 0.0)
    scaled_prior_loss_align = prior_loss_align * prior_loss_align_beta
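
    # For diagonal Gaussians this KL has the closed form, per dimension,
    #   KL(N(mu, sigma) || N(0, 1)) = 0.5 * (mu^2 + sigma^2 - 1 - log sigma^2)
    # which `reduce_sum` then sums over latent dimensions and `reduce_mean`
    # averages over the batch.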

    # - Reconstruction (from the joint Gaussian posterior)
    q_z_sample_align = q_z_align.sample()
    x_prime_A_align = vae_A.decoder(q_z_sample_align)
    x_prime_B_align = vae_B.decoder(q_z_sample_align)

    mean_recons_A_align = tf.reduce_mean(tf.square(x_prime_A_align - x_align_A))
    mean_recons_B_align = tf.reduce_mean(tf.square(x_prime_B_align - x_align_B))
    mean_recons_A_align_beta = config.get('mean_recons_A_align_beta', 0.0)
    scaled_mean_recons_A_align = mean_recons_A_align * mean_recons_A_align_beta
    mean_recons_B_align_beta = config.get('mean_recons_B_align_beta', 0.0)
    scaled_mean_recons_B_align = mean_recons_B_align * mean_recons_B_align_beta
    scaled_mean_recons_align = (
        scaled_mean_recons_A_align + scaled_mean_recons_B_align)
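
    # Squared error corresponds, up to a constant, to a fixed-variance
    # Gaussian likelihood; each term only contributes when its `*_beta`
    # weight (default 0.0) is set in the config.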

    # - Reconstruction (from transfer)
    q_z_align_A_sample = q_z_align_A.sample()
    q_z_align_B_sample = q_z_align_B.sample()
    x_A_to_B_align = vae_B.decoder(q_z_align_A_sample)
    x_B_to_A_align = vae_A.decoder(q_z_align_B_sample)
    mean_recons_A_to_B_align = tf.reduce_mean(
        tf.square(x_A_to_B_align - x_align_B))
    mean_recons_B_to_A_align = tf.reduce_mean(
        tf.square(x_B_to_A_align - x_align_A))
    mean_recons_A_to_B_align_beta = config.get('mean_recons_A_to_B_align_beta',
                                               0.0)
    scaled_mean_recons_A_to_B_align = (
        mean_recons_A_to_B_align * mean_recons_A_to_B_align_beta)
    mean_recons_B_to_A_align_beta = config.get('mean_recons_B_to_A_align_beta',
                                               0.0)
    scaled_mean_recons_B_to_A_align = (
        mean_recons_B_to_A_align * mean_recons_B_to_A_align_beta)
    scaled_mean_recons_cross_A_B_align = (
        scaled_mean_recons_A_to_B_align + scaled_mean_recons_B_to_A_align)
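
    # Cross-domain terms: a sample from A's posterior is decoded with B's
    # decoder (and vice versa), penalizing the error against the paired
    # target from the other side.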

    # Full loss
    full_loss = (vae_loss_A + vae_loss_B + scaled_mean_recons_align +
                 scaled_mean_recons_cross_A_B_align)

    # train op
    full_lr = tf.constant(lr)
    train_full = tf.train.AdamOptimizer(learning_rate=full_lr).minimize(
        full_loss, var_list=vae_vars)
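
    # `train_full` updates the same `vae_vars` as `train_vae`, but against
    # the combined per-side ELBOs plus the scaled alignment losses.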

    # Add all endpoints as object attributes
    for k, v in iteritems(locals()):
      self.__dict__[k] = v
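
    # Hypothetical usage (a sketch only; `sess`, `model`, `batch_A`,
    # `batch_B` and the exact feeds are assumptions, not part of this
    # module):
    #   sess.run(model.train_full, feed_dict={
    #       model.x_A: batch_A, model.x_B: batch_B,
    #       model.x_align_A: batch_A, model.x_align_B: batch_B,
    #   })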