def _build(self, unused_input=None):
  # pylint:disable=unused-variable,possibly-unused-variable
  # Reason:
  #   All endpoints are stored as attributes at the end of `_build`.
  #   Pylint cannot infer this, so it emits a false alarm of
  #   unused-variable unless we disable the warning.

  # pylint:disable=invalid-name
  # Reason:
  #   The following kinds of variables have names that pylint considers
  #   invalid, so we disable the warning:
  #   - Variables that are classes.
  #   - Variables whose names contain A or B, indicating which side of the
  #     data they belong to.

  # ---------------------------------------------------------------------
  # ## Extract parameters from config
  # ---------------------------------------------------------------------
  config = self.config
  lr = config.get('lr', 3e-4)
  n_latent_shared = config['n_latent_shared']
  if 'n_latent' in config:
    n_latent_A = n_latent_B = config['n_latent']
  else:
    n_latent_A = config['vae_A']['n_latent']
    n_latent_B = config['vae_B']['n_latent']

  # ---------------------------------------------------------------------
  # ## VAE containing Modules with parameters
  # ---------------------------------------------------------------------
  vae_A = VAE(config['vae_A'], name='vae_A')
  vae_A()
  vae_B = VAE(config['vae_B'], name='vae_B')
  vae_B()

  vae_lr = tf.constant(lr)
  vae_vars = vae_A.vae_vars + vae_B.vae_vars
  vae_loss = vae_A.vae_loss + vae_B.vae_loss
  train_vae = tf.train.AdamOptimizer(learning_rate=vae_lr).minimize(
      vae_loss, var_list=vae_vars)
  vae_saver = tf.train.Saver(vae_vars, max_to_keep=100)

  # ---------------------------------------------------------------------
  # ## Computation Flow
  # ---------------------------------------------------------------------
  # Tensor Endpoints
  x_A = vae_A.x
  x_B = vae_B.x
  q_z_sample_A = vae_A.q_z_sample
  q_z_sample_B = vae_B.q_z_sample
  mu_A, sigma_A = vae_A.mu, vae_A.sigma
  mu_B, sigma_B = vae_B.mu, vae_B.sigma
  x_prime_A = vae_A.x_prime
  x_prime_B = vae_B.x_prime
  x_from_prior_A = vae_A.x_from_prior
  x_from_prior_B = vae_B.x_from_prior
  x_A_to_B = vae_B.decoder(q_z_sample_A)
  x_B_to_A = vae_A.decoder(q_z_sample_B)
  x_A_to_B_direct = vae_B.decoder(mu_A)
  x_B_to_A_direct = vae_A.decoder(mu_B)
  z_hat = tf.placeholder(tf.float32, shape=(None, n_latent_shared))
  x_joint_A = vae_A.decoder(z_hat)
  x_joint_B = vae_B.decoder(z_hat)

  vae_loss_A = vae_A.vae_loss
  vae_loss_B = vae_B.vae_loss

  x_align_A = tf.placeholder(tf.float32, shape=(None, n_latent_A))
  x_align_B = tf.placeholder(tf.float32, shape=(None, n_latent_B))
  mu_align_A, sigma_align_A = vae_A.encoder(x_align_A)
  mu_align_B, sigma_align_B = vae_B.encoder(x_align_B)
  q_z_align_A = ds.Normal(loc=mu_align_A, scale=sigma_align_A)
  q_z_align_B = ds.Normal(loc=mu_align_B, scale=sigma_align_B)

  # VI in joint space
  mu_align, sigma_align = nn.product_two_guassian_pdfs(
      mu_align_A, sigma_align_A, mu_align_B, sigma_align_B)
  q_z_align = ds.Normal(loc=mu_align, scale=sigma_align)
  p_z_align = ds.Normal(loc=0., scale=1.)
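  # Note: assuming `nn.product_two_guassian_pdfs` implements the standard
  # product of two Gaussian pdfs, `q_z_align` is the precision-weighted
  # combination of the two per-side posteriors:
  #   sigma_align^2 = 1 / (1/sigma_align_A^2 + 1/sigma_align_B^2)
  #   mu_align      = sigma_align^2 * (mu_align_A / sigma_align_A^2 +
  #                                    mu_align_B / sigma_align_B^2)
  # i.e. q(z | x_A, x_B) is proportional to q(z | x_A) * q(z | x_B),
  # renormalized to a Gaussian.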
  # - KL
  KL_qp_align = ds.kl_divergence(q_z_align, p_z_align)
  KL_align = tf.reduce_sum(KL_qp_align, axis=-1)
  mean_KL_align = tf.reduce_mean(KL_align)
  prior_loss_align = mean_KL_align
  prior_loss_align_beta = config.get('prior_loss_align_beta', 0.0)
  scaled_prior_loss_align = prior_loss_align * prior_loss_align_beta

  # - Reconstruction (from joint Gaussian)
  q_z_sample_align = q_z_align.sample()
  x_prime_A_align = vae_A.decoder(q_z_sample_align)
  x_prime_B_align = vae_B.decoder(q_z_sample_align)
  mean_recons_A_align = tf.reduce_mean(
      tf.square(x_prime_A_align - x_align_A))
  mean_recons_B_align = tf.reduce_mean(
      tf.square(x_prime_B_align - x_align_B))
  mean_recons_A_align_beta = config.get('mean_recons_A_align_beta', 0.0)
  scaled_mean_recons_A_align = mean_recons_A_align * mean_recons_A_align_beta
  mean_recons_B_align_beta = config.get('mean_recons_B_align_beta', 0.0)
  scaled_mean_recons_B_align = mean_recons_B_align * mean_recons_B_align_beta
  scaled_mean_recons_align = (
      scaled_mean_recons_A_align + scaled_mean_recons_B_align)

  # - Reconstruction (from transfer)
  q_z_align_A_sample = q_z_align_A.sample()
  q_z_align_B_sample = q_z_align_B.sample()
  x_A_to_B_align = vae_B.decoder(q_z_align_A_sample)
  x_B_to_A_align = vae_A.decoder(q_z_align_B_sample)
  mean_recons_A_to_B_align = tf.reduce_mean(
      tf.square(x_A_to_B_align - x_align_B))
  mean_recons_B_to_A_align = tf.reduce_mean(
      tf.square(x_B_to_A_align - x_align_A))
  mean_recons_A_to_B_align_beta = config.get(
      'mean_recons_A_to_B_align_beta', 0.0)
  scaled_mean_recons_A_to_B_align = (
      mean_recons_A_to_B_align * mean_recons_A_to_B_align_beta)
  mean_recons_B_to_A_align_beta = config.get(
      'mean_recons_B_to_A_align_beta', 0.0)
  scaled_mean_recons_B_to_A_align = (
      mean_recons_B_to_A_align * mean_recons_B_to_A_align_beta)
  scaled_mean_recons_cross_A_B_align = (
      scaled_mean_recons_A_to_B_align + scaled_mean_recons_B_to_A_align)

  # Full loss
  full_loss = (vae_loss_A + vae_loss_B + scaled_mean_recons_align +
               scaled_mean_recons_cross_A_B_align)

  # train op
  full_lr = tf.constant(lr)
  train_full = tf.train.AdamOptimizer(learning_rate=full_lr).minimize(
      full_loss, var_list=vae_vars)

  # Add all endpoints as object attributes
  for k, v in iteritems(locals()):
    self.__dict__[k] = v
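# A minimal training-loop sketch (hypothetical names: `model`, `batch_A`,
# `batch_B`, `n_steps`; assumes `_build` belongs to a module object `model`
# that exposes the endpoints above -- `train_full`, `x_align_A`,
# `x_align_B` -- as attributes via the `locals()` loop at the end of
# `_build`):
#
#   with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     for _ in range(n_steps):
#       # `batch_A` / `batch_B` are aligned numpy arrays of shape
#       # [batch_size, n_latent_A] / [batch_size, n_latent_B].
#       sess.run(model.train_full,
#                feed_dict={model.x_align_A: batch_A,
#                           model.x_align_B: batch_B})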