def test_step(batch_input_test, batch_latent_test):
    # Encoder pass: posterior mean, log standard deviations and the
    # Cholesky factor of the full posterior covariance
    batch_post_mean_test, batch_log_post_std_test, batch_post_cov_chol_test\
            = nn.encoder(batch_input_test)
    unscaled_replica_batch_loss_test_kld =\
            loss_kld_full(
                    batch_post_mean_test, batch_log_post_std_test,
                    batch_post_cov_chol_test,
                    prior_mean, prior_cov_inv,
                    identity_otimes_prior_cov_inv, 1)
    # JS-penalized posterior term: log-determinant plus weighted residual
    unscaled_replica_batch_loss_test_posterior =\
            (1-hyperp.penalty_js)/hyperp.penalty_js *\
            2*tf.reduce_sum(batch_log_post_std_test, axis=1) +\
            loss_weighted_post_cov_full_penalized_difference(
                    batch_latent_test, batch_post_mean_test,
                    batch_post_cov_chol_test,
                    (1-hyperp.penalty_js)/hyperp.penalty_js)
    unscaled_replica_batch_loss_test =\
            -(-unscaled_replica_batch_loss_test_kld
              -unscaled_replica_batch_loss_test_posterior)

    metrics.mean_loss_test(unscaled_replica_batch_loss_test)
    metrics.mean_loss_test_encoder(unscaled_replica_batch_loss_test_kld)
    metrics.mean_loss_test_posterior(
            unscaled_replica_batch_loss_test_posterior)
    metrics.mean_relative_error_latent_posterior(
            relative_error(batch_latent_test, batch_post_mean_test))
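
# `relative_error` is a helper defined elsewhere in the codebase; the metric
# above feeds its output into a running mean. A minimal sketch of what it is
# assumed to compute (the name and the reduction are assumptions, not the
# actual implementation), using the file-level `tf` import:
def relative_error_sketch(tensor_true, tensor_pred):
    # relative l2 error ||true - pred|| / ||true|| over all batch entries
    return tf.norm(tensor_true - tensor_pred, 2)/tf.norm(tensor_true, 2)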
def val_step(batch_input_val, batch_latent_val):
    batch_post_mean_val, batch_log_post_std_val, batch_post_cov_chol_val\
            = nn.encoder(batch_input_val)
    unscaled_replica_batch_loss_val_kld =\
            loss_kld_full(
                    batch_post_mean_val, batch_log_post_std_val,
                    batch_post_cov_chol_val,
                    prior_mean, prior_cov_inv,
                    identity_otimes_prior_cov_inv, 1)
    unscaled_replica_batch_loss_val_posterior =\
            (1-hyperp.penalty_js)/hyperp.penalty_js *\
            2*tf.reduce_sum(batch_log_post_std_val, axis=1) +\
            loss_weighted_post_cov_full_penalized_difference(
                    batch_latent_val, batch_post_mean_val,
                    batch_post_cov_chol_val,
                    (1-hyperp.penalty_js)/hyperp.penalty_js)
    unscaled_replica_batch_loss_val =\
            -(-unscaled_replica_batch_loss_val_kld
              -unscaled_replica_batch_loss_val_posterior)

    metrics.mean_loss_val(unscaled_replica_batch_loss_val)
    metrics.mean_loss_val_encoder(unscaled_replica_batch_loss_val_kld)
    metrics.mean_loss_val_posterior(
            unscaled_replica_batch_loss_val_posterior)
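
# `loss_weighted_post_cov_full_penalized_difference` is also defined
# elsewhere. Its name and the accompanying 2*sum(log_post_std) term suggest
# the negative posterior log-density up to constants; a single-sample sketch
# under that assumption (the function name here is illustrative):
def loss_weighted_post_cov_full_penalized_difference_sketch(
        latent, post_mean, post_cov_chol, weight):
    # Whiten the residual with a triangular solve, r = L^{-1}(u - mu),
    # then penalize its squared norm: weight * ||r||^2
    residual = tf.reshape(latent - post_mean, (-1, 1))
    whitened = tf.linalg.triangular_solve(post_cov_chol, residual, lower=True)
    return weight*tf.reduce_sum(tf.square(whitened))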
def val_step(batch_input_val, batch_latent_val):
    batch_likelihood_val = nn(batch_input_val)
    batch_post_mean_val, batch_log_post_std_val, batch_post_cov_chol_val\
            = nn.encoder(batch_input_val)
    batch_loss_val_vae =\
            loss_diagonal_weighted_penalized_difference(
                    batch_input_val, batch_likelihood_val,
                    noise_regularization_matrix, 1)
    batch_loss_val_kld =\
            loss_kld_full(
                    batch_post_mean_val, batch_log_post_std_val,
                    batch_post_cov_chol_val,
                    prior_mean, prior_cov_inv,
                    identity_otimes_prior_cov_inv, 1)
    batch_loss_val_posterior =\
            (1-hyperp.penalty_js)/hyperp.penalty_js *\
            2*tf.reduce_sum(batch_log_post_std_val, axis=1) +\
            loss_weighted_post_cov_full_penalized_difference(
                    batch_latent_val, batch_post_mean_val,
                    batch_post_cov_chol_val,
                    (1-hyperp.penalty_js)/hyperp.penalty_js)
    batch_loss_val = -(-batch_loss_val_vae
                       -batch_loss_val_kld
                       -batch_loss_val_posterior)

    metrics.mean_loss_val(batch_loss_val)
    metrics.mean_loss_val_posterior(batch_loss_val_posterior)
    metrics.mean_loss_val_vae(batch_loss_val_vae)
    metrics.mean_loss_val_encoder(batch_loss_val_kld)
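
# `loss_kld_full` above is defined elsewhere in the codebase. A minimal
# single-sample sketch of the closed-form KL divergence it is assumed to
# implement, KL( N(mu, L L^T) || N(prior_mean, prior_cov) ). The function
# name and the assumption that the diagonal of L is exp(log_post_std) are
# illustrative, not the actual code; the codebase versions reduce over the
# batch axis.
def loss_kld_full_sketch(post_mean, log_post_std, post_cov_chol,
                         prior_mean, prior_cov_inv, penalty):
    # trace term: tr( C_prior^{-1} L L^T )
    post_cov = tf.matmul(post_cov_chol, post_cov_chol, transpose_b=True)
    trace_term = tf.linalg.trace(tf.matmul(prior_cov_inv, post_cov))
    # quadratic term: (m - mu)^T C_prior^{-1} (m - mu)
    diff = tf.reshape(prior_mean - post_mean, (-1, 1))
    quad_term = tf.squeeze(
            tf.matmul(diff, tf.matmul(prior_cov_inv, diff), transpose_a=True))
    # log det of the prior covariance from its inverse; log det of the
    # posterior covariance from the Cholesky diagonal: 2*sum(log_post_std)
    log_det_prior = -tf.linalg.logdet(prior_cov_inv)
    log_det_post = 2*tf.reduce_sum(log_post_std)
    latent_dim = tf.cast(tf.size(post_mean), tf.float32)
    return penalty*0.5*(trace_term + quad_term - latent_dim
                        + log_det_prior - log_det_post)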
def train_step(batch_input_train, batch_latent_train):
    with tf.GradientTape() as tape:
        batch_post_mean_train, batch_log_post_std_train, batch_post_cov_chol_train\
                = nn.encoder(batch_input_train)
        # Push the posterior mean through the forward model to form the
        # predicted observables
        batch_input_pred_forward_model_train =\
                solve_forward_model(batch_post_mean_train)
        unscaled_replica_batch_loss_train_vae =\
                loss_trace_likelihood(batch_post_cov_chol_train,
                        identity_otimes_likelihood_matrix, 1) +\
                loss_diagonal_weighted_penalized_difference(
                        batch_input_train,
                        batch_input_pred_forward_model_train,
                        noise_regularization_matrix, 1)
        unscaled_replica_batch_loss_train_kld =\
                loss_kld_full(
                        batch_post_mean_train, batch_log_post_std_train,
                        batch_post_cov_chol_train,
                        prior_mean, prior_cov_inv,
                        identity_otimes_prior_cov_inv, 1)
        unscaled_replica_batch_loss_train_posterior =\
                (1-hyperp.penalty_js)/hyperp.penalty_js *\
                2*tf.reduce_sum(batch_log_post_std_train, axis=1) +\
                loss_weighted_post_cov_full_penalized_difference(
                        batch_latent_train, batch_post_mean_train,
                        batch_post_cov_chol_train,
                        (1-hyperp.penalty_js)/hyperp.penalty_js)
        unscaled_replica_batch_loss_train =\
                -(-unscaled_replica_batch_loss_train_vae
                  -unscaled_replica_batch_loss_train_kld
                  -unscaled_replica_batch_loss_train_posterior)
        # Scale by batch size so summing across replicas yields the mean
        scaled_replica_batch_loss_train = tf.reduce_sum(
                unscaled_replica_batch_loss_train * (1./hyperp.batch_size))
    gradients = tape.gradient(scaled_replica_batch_loss_train,
                              nn.trainable_variables)
    optimizer.apply_gradients(zip(gradients, nn.trainable_variables))

    metrics.mean_loss_train_vae(unscaled_replica_batch_loss_train_vae)
    metrics.mean_loss_train_encoder(unscaled_replica_batch_loss_train_kld)
    metrics.mean_loss_train_posterior(
            unscaled_replica_batch_loss_train_posterior)
    return scaled_replica_batch_loss_train
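
# `solve_forward_model` maps the batch of posterior means to predicted
# observables by solving the underlying forward problem, which is problem
# dependent. As an illustrative stand-in only (not the actual model), a
# linear forward operator applied row-wise to the batch:
def solve_forward_model_sketch(batch_post_mean, forward_operator):
    # batch_post_mean: (batch_size, latent_dim)
    # forward_operator: assumed dense (input_dim, latent_dim) matrix
    return tf.linalg.matmul(batch_post_mean, forward_operator,
                            transpose_b=True)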
def test_step(batch_input_test, batch_latent_test):
    batch_likelihood_test = nn(batch_input_test)
    batch_post_mean_test, batch_log_post_std_test, batch_post_cov_chol_test\
            = nn.encoder(batch_input_test)
    batch_input_pred_test = nn.decoder(batch_latent_test)
    unscaled_replica_batch_loss_test_vae =\
            loss_diagonal_weighted_penalized_difference(
                    batch_input_test, batch_likelihood_test,
                    noise_regularization_matrix, 1)
    unscaled_replica_batch_loss_test_kld =\
            loss_kld_full(
                    batch_post_mean_test, batch_log_post_std_test,
                    batch_post_cov_chol_test,
                    prior_mean, prior_cov_inv,
                    identity_otimes_prior_cov_inv, 1)
    unscaled_replica_batch_loss_test_posterior =\
            (1-hyperp.penalty_js)/hyperp.penalty_js *\
            2*tf.reduce_sum(batch_log_post_std_test, axis=1) +\
            loss_weighted_post_cov_full_penalized_difference(
                    batch_latent_test, batch_post_mean_test,
                    batch_post_cov_chol_test,
                    (1-hyperp.penalty_js)/hyperp.penalty_js)
    unscaled_replica_batch_loss_test =\
            -(-unscaled_replica_batch_loss_test_vae
              -unscaled_replica_batch_loss_test_kld
              -unscaled_replica_batch_loss_test_posterior)

    metrics.mean_loss_test(unscaled_replica_batch_loss_test)
    metrics.mean_loss_test_vae(unscaled_replica_batch_loss_test_vae)
    metrics.mean_loss_test_encoder(unscaled_replica_batch_loss_test_kld)
    metrics.mean_loss_test_posterior(
            unscaled_replica_batch_loss_test_posterior)
    metrics.mean_relative_error_input_vae(
            relative_error(batch_input_test, batch_likelihood_test))
    metrics.mean_relative_error_latent_posterior(
            relative_error(batch_latent_test, batch_post_mean_test))
    metrics.mean_relative_error_input_decoder(
            relative_error(batch_input_test, batch_input_pred_test))
def train_step(batch_input_train, batch_latent_train):
    with tf.GradientTape() as tape:
        batch_likelihood_train = nn(batch_input_train)
        batch_post_mean_train, batch_log_post_std_train, batch_post_cov_chol_train\
                = nn.encoder(batch_input_train)
        batch_loss_train_vae =\
                loss_diagonal_weighted_penalized_difference(
                        batch_input_train, batch_likelihood_train,
                        noise_regularization_matrix, 1)
        batch_loss_train_kld =\
                loss_kld_full(
                        batch_post_mean_train, batch_log_post_std_train,
                        batch_post_cov_chol_train,
                        prior_mean, prior_cov_inv,
                        identity_otimes_prior_cov_inv, 1)
        batch_loss_train_posterior =\
                (1-hyperp.penalty_js)/hyperp.penalty_js *\
                2*tf.reduce_sum(batch_log_post_std_train, axis=1) +\
                loss_weighted_post_cov_full_penalized_difference(
                        batch_latent_train, batch_post_mean_train,
                        batch_post_cov_chol_train,
                        (1-hyperp.penalty_js)/hyperp.penalty_js)
        batch_loss_train = -(-batch_loss_train_vae
                             -batch_loss_train_kld
                             -batch_loss_train_posterior)
        batch_loss_train_mean = tf.reduce_mean(batch_loss_train, axis=0)
    gradients = tape.gradient(batch_loss_train_mean, nn.trainable_variables)
    optimizer.apply_gradients(zip(gradients, nn.trainable_variables))

    metrics.mean_loss_train(batch_loss_train)
    metrics.mean_loss_train_posterior(batch_loss_train_posterior)
    metrics.mean_loss_train_vae(batch_loss_train_vae)
    metrics.mean_loss_train_encoder(batch_loss_train_kld)
    return gradients
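
# A minimal sketch of an epoch loop driving the steps above. The dataset
# names and `hyperp.num_epochs` are assumptions; the metric objects are
# assumed to be tf.keras.metrics.Mean instances, which accumulate across
# calls and therefore need resetting every epoch.
for epoch in range(hyperp.num_epochs):
    for batch_input_train, batch_latent_train in dataset_train:
        train_step(batch_input_train, batch_latent_train)
    for batch_input_val, batch_latent_val in dataset_val:
        val_step(batch_input_val, batch_latent_val)
    for batch_input_test, batch_latent_test in dataset_test:
        test_step(batch_input_test, batch_latent_test)
    print('Epoch %d: train %.4e, val %.4e, test %.4e'
          %(epoch, metrics.mean_loss_train.result(),
            metrics.mean_loss_val.result(),
            metrics.mean_loss_test.result()))
    for metric in (metrics.mean_loss_train, metrics.mean_loss_val,
                   metrics.mean_loss_test):
        metric.reset_states()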