    def val_step(batch_input_val, batch_latent_val):
        batch_likelihood_val = nn(batch_input_val)
        batch_post_mean_val, batch_log_post_var_val = nn.encoder(
            batch_input_val)

        batch_loss_val_vae =\
                loss_diagonal_weighted_penalized_difference(
                        batch_input_val, batch_likelihood_val,
                        noise_regularization_matrix,
                        1)
        batch_loss_val_kld =\
                loss_kld(
                        batch_post_mean_val, batch_log_post_var_val,
                        prior_mean, prior_cov_inv,
                        1)
        batch_loss_val_posterior =\
                (1-hyperp.penalty_js)/hyperp.penalty_js *\
                tf.reduce_sum(batch_log_post_var_val, axis=1) +\
                loss_diagonal_weighted_penalized_difference(
                        batch_latent_val, batch_post_mean_val,
                        1/tf.math.exp(batch_log_post_var_val/2),
                        (1-hyperp.penalty_js)/hyperp.penalty_js)

        batch_loss_val = -(-batch_loss_val_vae\
                           -batch_loss_val_kld\
                           -batch_loss_val_posterior)

        metrics.mean_loss_val(batch_loss_val)
        metrics.mean_loss_val_posterior(batch_loss_val_posterior)
        metrics.mean_loss_val_vae(batch_loss_val_vae)
        metrics.mean_loss_val_encoder(batch_loss_val_kld)
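
The `loss_kld` calls in these snippets evaluate the Kullback-Leibler divergence between the diagonal-Gaussian posterior N(mean, diag(exp(log_var))) and the Gaussian prior. As a rough guide to what such a helper computes, here is a minimal sketch (an illustration, not the repository's implementation) that assumes `prior_cov_inv` holds the diagonal of the inverse prior covariance and drops additive constants:

import tensorflow as tf  # the sketches below assume this import

def loss_kld_sketch(post_mean, log_post_var, prior_mean, prior_cov_inv, penalty):
    # KL(N(mu, diag(exp(log_var))) || N(m0, C0)) for diagonal C0, up to
    # additive constants:
    # 0.5 * [tr(C0^-1 Sigma) + (m0 - mu)^T C0^-1 (m0 - mu) - log det Sigma]
    trace_term = tf.reduce_sum(prior_cov_inv * tf.math.exp(log_post_var), axis=1)
    quadratic_term = tf.reduce_sum(
        prior_cov_inv * tf.square(prior_mean - post_mean), axis=1)
    log_det_term = -tf.reduce_sum(log_post_var, axis=1)
    return penalty * 0.5 * (trace_term + quadratic_term + log_det_term)
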
    def test_step(batch_input_test, batch_latent_test):
        batch_post_mean_test, batch_log_post_var_test = nn.encoder(
            batch_input_test)

        unscaled_replica_batch_loss_test_kld =\
                loss_kld(
                        batch_post_mean_test, batch_log_post_var_test,
                        prior_mean, prior_cov_inv,
                        1)
        unscaled_replica_batch_loss_test_posterior =\
                (1-hyperp.penalty_js)/hyperp.penalty_js *\
                tf.reduce_sum(batch_log_post_var_test, axis=1) +\
                loss_diagonal_weighted_penalized_difference(
                        batch_latent_test, batch_post_mean_test,
                        1/tf.math.exp(batch_log_post_var_test/2),
                        (1-hyperp.penalty_js)/hyperp.penalty_js)

        unscaled_replica_batch_loss_test =\
                -(-unscaled_replica_batch_loss_test_kld\
                  -unscaled_replica_batch_loss_test_posterior)

        metrics.mean_loss_test(unscaled_replica_batch_loss_test)
        metrics.mean_loss_test_encoder(
            unscaled_replica_batch_loss_test_kld)
        metrics.mean_loss_test_posterior(
            unscaled_replica_batch_loss_test_posterior)

        metrics.mean_relative_error_latent_posterior(
            relative_error(batch_latent_test, batch_post_mean_test))
    def val_step(batch_input_val, batch_latent_val):
        batch_post_mean_val, batch_log_post_var_val = nn.encoder(
            batch_input_val)

        unscaled_replica_batch_loss_val_kld =\
                loss_kld(
                        batch_post_mean_val, batch_log_post_var_val,
                        prior_mean, prior_cov_inv,
                        1)
        unscaled_replica_batch_loss_val_posterior =\
                (1-hyperp.penalty_js)/hyperp.penalty_js *\
                tf.reduce_sum(batch_log_post_var_val, axis=1) +\
                loss_diagonal_weighted_penalized_difference(
                        batch_latent_val, batch_post_mean_val,
                        1/tf.math.exp(batch_log_post_var_val/2),
                        (1-hyperp.penalty_js)/hyperp.penalty_js)

        unscaled_replica_batch_loss_val =\
                -(-unscaled_replica_batch_loss_val_kld\
                  -unscaled_replica_batch_loss_val_posterior)

        metrics.mean_loss_val(unscaled_replica_batch_loss_val)
        metrics.mean_loss_val_encoder(unscaled_replica_batch_loss_val_kld)
        metrics.mean_loss_val_posterior(
            unscaled_replica_batch_loss_val_posterior)
    def val_step(batch_input_val, batch_latent_val):
        batch_likelihood_val = nn(batch_input_val)
        batch_post_mean_val, batch_log_post_var_val, batch_post_cov_chol_val\
                = nn.encoder(batch_input_val)

        unscaled_replica_batch_loss_val_vae =\
                loss_diagonal_weighted_penalized_difference(
                        batch_input_val, batch_likelihood_val,
                        noise_regularization_matrix,
                        1)
        unscaled_replica_batch_loss_val_kld =\
                loss_kld_full(
                        batch_post_mean_val, batch_log_post_var_val,
                        batch_post_cov_chol_val,
                        prior_mean, prior_cov_inv, identity_otimes_prior_cov_inv,
                        1)
        unscaled_replica_batch_loss_val_posterior =\
                (1-hyperp.penalty_js)/hyperp.penalty_js *\
                2*tf.reduce_sum(batch_log_post_var_val, axis=1) +\
                loss_weighted_post_cov_full_penalized_difference(
                        batch_latent_val, batch_post_mean_val,
                        batch_post_cov_chol_val,
                        (1-hyperp.penalty_js)/hyperp.penalty_js)

        unscaled_replica_batch_loss_val =\
                -(-unscaled_replica_batch_loss_val_vae\
                  -unscaled_replica_batch_loss_val_kld\
                  -unscaled_replica_batch_loss_val_posterior)

        metrics.mean_loss_val(unscaled_replica_batch_loss_val)
        metrics.mean_loss_val_vae(unscaled_replica_batch_loss_val_vae)
        metrics.mean_loss_val_encoder(unscaled_replica_batch_loss_val_kld)
        metrics.mean_loss_val_posterior(
            unscaled_replica_batch_loss_val_posterior)
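
Several snippets call `loss_diagonal_weighted_penalized_difference` with a per-component weight: the noise regularization matrix for the data misfit, the inverse posterior standard deviation for the latent misfit, or (in the IAF examples below) the inverse prior Cholesky factor. Judging from those call sites, it behaves like a penalized, diagonally weighted squared error; a hypothetical sketch, not the repository's code:

def loss_diagonal_weighted_penalized_difference_sketch(true, pred, weight, penalty):
    # Per-sample weighted misfit penalty * ||diag(weight) (true - pred)||^2,
    # with `weight` broadcast across the batch dimension.
    return penalty * tf.reduce_sum(tf.square(weight * (true - pred)), axis=1)
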
    def test_step(batch_input_test, batch_latent_test):
        batch_likelihood_test = nn(batch_input_test)
        batch_post_mean_test, batch_log_post_var_test = nn.encoder(
            batch_input_test)
        batch_posterior_sample_test = nn.iaf_chain_encoder(
            (batch_post_mean_test, batch_log_post_var_test),
            sample_flag=True,
            infer_flag=False)
        batch_input_pred_test = nn.decoder(batch_latent_test)

        batch_loss_test_vae =\
                loss_diagonal_weighted_penalized_difference(
                        batch_input_test, batch_likelihood_test,
                        noise_regularization_matrix,
                        1)
        batch_loss_test_iaf_encoder =\
                nn.iaf_chain_encoder((batch_post_mean_test,
                                      batch_log_post_var_test),
                                      sample_flag=False,
                                      infer_flag=True)
        batch_loss_test_prior =\
                loss_diagonal_weighted_penalized_difference(
                        prior_mean, batch_posterior_sample_test,
                        prior_covariance_cholesky_inverse,
                        1)
        batch_loss_test_posterior =\
                loss_penalized_difference(
                        batch_latent_test, batch_posterior_sample_test,
                        (1-hyperp.penalty_js)/hyperp.penalty_js)

        batch_loss_test = -(-batch_loss_test_vae\
                            -batch_loss_test_iaf_encoder\
                            -batch_loss_test_prior\
                            -batch_loss_test_posterior)

        metrics.mean_loss_test(batch_loss_test)
        metrics.mean_loss_test_vae(batch_loss_test_vae)
        metrics.mean_loss_test_encoder(batch_loss_test_iaf_encoder)
        metrics.mean_loss_test_prior(batch_loss_test_prior)
        metrics.mean_loss_test_posterior(batch_loss_test_posterior)

        metrics.mean_relative_error_input_vae(
            relative_error(batch_input_test, batch_likelihood_test))
        metrics.mean_relative_error_latent_posterior(
            relative_error(batch_latent_test, batch_posterior_sample_test))
        metrics.mean_relative_error_input_decoder(
            relative_error(batch_input_test, batch_input_pred_test))
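
The relative-error metrics compare predictions to targets in a scale-free way. A plausible definition of `relative_error` (an assumption; the actual helper lives elsewhere in the codebase):

def relative_error_sketch(true, pred):
    # Batched relative l2 error: ||true - pred||_2 / ||true||_2 per sample.
    return tf.norm(true - pred, axis=1) / tf.norm(true, axis=1)
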
    def train_step(batch_input_train, batch_latent_train):
        with tf.GradientTape() as tape:
            batch_likelihood_train = nn(batch_input_train)
            batch_post_mean_train, batch_log_post_var_train = nn.encoder(
                batch_input_train)
            batch_posterior_sample_train = nn.iaf_chain_encoder(
                (batch_post_mean_train, batch_log_post_var_train),
                sample_flag=True,
                infer_flag=False)

            batch_loss_train_vae =\
                    loss_diagonal_weighted_penalized_difference(
                            batch_input_train, batch_likelihood_train,
                            noise_regularization_matrix,
                            1)
            batch_loss_train_iaf_encoder =\
                    nn.iaf_chain_encoder((batch_post_mean_train,
                                          batch_log_post_var_train),
                                          sample_flag=False,
                                          infer_flag=True)
            batch_loss_train_prior =\
                    loss_diagonal_weighted_penalized_difference(
                            prior_mean, batch_posterior_sample_train,
                            prior_covariance_cholesky_inverse,
                            1)
            batch_loss_train_posterior =\
                    loss_penalized_difference(
                            batch_latent_train, batch_posterior_sample_train,
                            (1-hyperp.penalty_js)/hyperp.penalty_js)

            batch_loss_train = -(-batch_loss_train_vae\
                                 -batch_loss_train_iaf_encoder\
                                 -batch_loss_train_prior\
                                 -batch_loss_train_posterior)
            batch_loss_train_mean = tf.reduce_mean(batch_loss_train, axis=0)

        gradients = tape.gradient(batch_loss_train_mean,
                                  nn.trainable_variables)
        optimizer.apply_gradients(zip(gradients, nn.trainable_variables))
        metrics.mean_loss_train(batch_loss_train)
        metrics.mean_loss_train_vae(batch_loss_train_vae)
        metrics.mean_loss_train_encoder(batch_loss_train_iaf_encoder)
        metrics.mean_loss_train_prior(batch_loss_train_prior)
        metrics.mean_loss_train_posterior(batch_loss_train_posterior)

        return gradients
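
Since `train_step` applies the gradients and accumulates the metrics as side effects, a caller only needs to loop over batches, then read and reset the metrics once per epoch. A minimal driver sketch, assuming the metrics are `tf.keras.metrics.Mean` objects; `dataset_train` and `num_epochs` are hypothetical names:

train_step_graph = tf.function(train_step)  # trace once for graph execution

for epoch in range(num_epochs):
    for batch_input_train, batch_latent_train in dataset_train:
        train_step_graph(batch_input_train, batch_latent_train)
    print('Epoch %d, mean train loss: %.4e'
          % (epoch, metrics.mean_loss_train.result()))
    metrics.mean_loss_train.reset_states()
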
    def train_step(batch_input_train, batch_latent_train):
        with tf.GradientTape() as tape:
            batch_post_mean_train, batch_log_post_var_train = nn.encoder(
                batch_input_train)
            batch_input_pred_forward_model_train =\
                    solve_forward_model(
                        nn.reparameterize(batch_post_mean_train,
                                          batch_log_post_var_train))

            unscaled_replica_batch_loss_train_vae =\
                    loss_diagonal_weighted_penalized_difference(
                            batch_input_train,
                            batch_input_pred_forward_model_train,
                            noise_regularization_matrix,
                            1)
            unscaled_replica_batch_loss_train_kld =\
                    loss_kld(
                            batch_post_mean_train, batch_log_post_var_train,
                            prior_mean, prior_cov_inv,
                            1)
            unscaled_replica_batch_loss_train_posterior =\
                    (1-hyperp.penalty_js)/hyperp.penalty_js *\
                    tf.reduce_sum(batch_log_post_var_train, axis=1) +\
                    loss_diagonal_weighted_penalized_difference(
                            batch_latent_train, batch_post_mean_train,
                            1/tf.math.exp(batch_log_post_var_train/2),
                            (1-hyperp.penalty_js)/hyperp.penalty_js)

            unscaled_replica_batch_loss_train =\
                    -(-unscaled_replica_batch_loss_train_vae\
                      -unscaled_replica_batch_loss_train_kld\
                      -unscaled_replica_batch_loss_train_posterior)
            scaled_replica_batch_loss_train = tf.reduce_sum(
                unscaled_replica_batch_loss_train *
                (1. / hyperp.batch_size))

        gradients = tape.gradient(scaled_replica_batch_loss_train,
                                  nn.trainable_variables)
        optimizer.apply_gradients(zip(gradients, nn.trainable_variables))
        metrics.mean_loss_train_vae(unscaled_replica_batch_loss_train_vae)
        metrics.mean_loss_train_encoder(
            unscaled_replica_batch_loss_train_kld)
        metrics.mean_loss_train_posterior(
            unscaled_replica_batch_loss_train_posterior)

        return scaled_replica_batch_loss_train
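
This variant draws a latent sample with `nn.reparameterize` before pushing it through the forward model, so the data-misfit gradient flows back through the sampling step. A standard reparameterization-trick implementation, sketched here as an assumption about what such a method does:

def reparameterize(mean, log_var):
    # z = mean + sigma * eps with eps ~ N(0, I): differentiable with respect
    # to mean and log_var, unlike sampling from N(mean, sigma^2) directly.
    eps = tf.random.normal(shape=tf.shape(mean))
    return mean + tf.math.exp(log_var / 2) * eps
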
    def val_step(batch_input_val, batch_latent_val):
        batch_likelihood_val = nn(batch_input_val)
        batch_post_mean_val, batch_log_post_var_val = nn.encoder(
            batch_input_val)
        batch_posterior_sample_val = nn.iaf_chain_encoder(
            (batch_post_mean_val, batch_log_post_var_val),
            sample_flag=True,
            infer_flag=False)

        batch_loss_val_vae =\
                loss_diagonal_weighted_penalized_difference(
                        batch_input_val, batch_likelihood_val,
                        noise_regularization_matrix,
                        1)
        batch_loss_val_iaf_encoder =\
                nn.iaf_chain_encoder((batch_post_mean_val,
                                      batch_log_post_var_val),
                                      sample_flag=False,
                                      infer_flag=True)
        batch_loss_val_prior =\
                loss_diagonal_weighted_penalized_difference(
                        prior_mean, batch_posterior_sample_val,
                        prior_covariance_cholesky_inverse,
                        1)
        batch_loss_val_posterior =\
                loss_penalized_difference(
                        batch_latent_val, batch_posterior_sample_val,
                        (1-hyperp.penalty_js)/hyperp.penalty_js)

        batch_loss_val = -(-batch_loss_val_vae\
                           -batch_loss_val_iaf_encoder\
                           -batch_loss_val_prior\
                           -batch_loss_val_posterior)

        metrics.mean_loss_val(batch_loss_val)
        metrics.mean_loss_val_vae(batch_loss_val_vae)
        metrics.mean_loss_val_encoder(batch_loss_val_iaf_encoder)
        metrics.mean_loss_val_prior(batch_loss_val_prior)
        metrics.mean_loss_val_posterior(batch_loss_val_posterior)
    def test_step(batch_input_test, batch_latent_test):
        batch_likelihood_test = nn(batch_input_test)
        batch_post_mean_test, batch_log_post_var_test = nn.encoder(
            batch_input_test)
        batch_input_pred_test = nn.decoder(batch_latent_test)

        batch_loss_test_vae =\
                loss_diagonal_weighted_penalized_difference(
                        batch_input_test, batch_likelihood_test,
                        noise_regularization_matrix,
                        1)
        batch_loss_test_kld =\
                loss_kld(
                        batch_post_mean_test, batch_log_post_var_test,
                        prior_mean, prior_cov_inv,
                        1)
        batch_loss_test_posterior =\
                (1-hyperp.penalty_js)/hyperp.penalty_js *\
                tf.reduce_sum(batch_log_post_var_test, axis=1) +\
                loss_diagonal_weighted_penalized_difference(
                        batch_latent_test, batch_post_mean_test,
                        1/tf.math.exp(batch_log_post_var_test/2),
                        (1-hyperp.penalty_js)/hyperp.penalty_js)

        batch_loss_test = -(-batch_loss_test_vae\
                            -batch_loss_test_kld\
                            -batch_loss_test_posterior)

        metrics.mean_loss_test(batch_loss_test)
        metrics.mean_loss_test_vae(batch_loss_test_vae)
        metrics.mean_loss_test_encoder(batch_loss_test_kld)
        metrics.mean_loss_test_posterior(batch_loss_test_posterior)

        metrics.mean_relative_error_input_vae(
            relative_error(batch_input_test, batch_likelihood_test))
        metrics.mean_relative_error_latent_posterior(
            relative_error(batch_latent_test, batch_post_mean_test))
        metrics.mean_relative_error_input_decoder(
            relative_error(batch_input_test, batch_input_pred_test))
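
In the diagonal-posterior snippets, the posterior term pairs the log-variance sum with a misfit weighted by the inverse posterior standard deviation, so the two pieces together are proportional to the Gaussian negative log-likelihood of the true latent under N(post_mean, diag(exp(log_var))); the factor (1-hyperp.penalty_js)/hyperp.penalty_js sets its weight against the remaining terms. Written as one hypothetical helper:

def posterior_gaussian_nll_sketch(latent, post_mean, log_post_var, penalty):
    # Up to additive constants and an overall factor of 1/2:
    # sum(log_var) + ||(latent - post_mean) / sigma||^2, sigma = exp(log_var/2)
    misfit = tf.reduce_sum(
        tf.square((latent - post_mean) / tf.math.exp(log_post_var / 2)), axis=1)
    return penalty * (tf.reduce_sum(log_post_var, axis=1) + misfit)
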
Example #10
    def train_step(batch_input_train, batch_latent_train):
        with tf.GradientTape() as tape:
            batch_likelihood_train = nn(batch_input_train)
            batch_post_mean_train, batch_log_post_var_train = nn.encoder(
                batch_input_train)

            batch_loss_train_vae =\
                    loss_diagonal_weighted_penalized_difference(
                            batch_input_train, batch_likelihood_train,
                            noise_regularization_matrix,
                            1)
            batch_loss_train_kld =\
                    loss_kld(
                            batch_post_mean_train, batch_log_post_var_train,
                            prior_mean, prior_cov_inv,
                            1)
            batch_loss_train_posterior =\
                    (1-hyperp.penalty_js)/hyperp.penalty_js *\
                    tf.reduce_sum(batch_log_post_var_train, axis=1) +\
                    loss_diagonal_weighted_penalized_difference(
                            batch_latent_train, batch_post_mean_train,
                            1/tf.math.exp(batch_log_post_var_train/2),
                            (1-hyperp.penalty_js)/hyperp.penalty_js)

            batch_loss_train = -(-batch_loss_train_vae\
                                 -batch_loss_train_kld\
                                 -batch_loss_train_posterior)
            batch_loss_train_mean = tf.reduce_mean(batch_loss_train, axis=0)

        gradients = tape.gradient(batch_loss_train_mean,
                                  nn.trainable_variables)
        optimizer.apply_gradients(zip(gradients, nn.trainable_variables))
        metrics.mean_loss_train(batch_loss_train)
        metrics.mean_loss_train_posterior(batch_loss_train_posterior)
        metrics.mean_loss_train_vae(batch_loss_train_vae)
        metrics.mean_loss_train_encoder(batch_loss_train_kld)

        return gradients
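
Returning the gradients, rather than only applying them, lets the caller monitor optimization health. For example (hypothetical usage):

gradients = train_step(batch_input_train, batch_latent_train)
gradient_norm = tf.linalg.global_norm([g for g in gradients if g is not None])
tf.print('global gradient norm:', gradient_norm)
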
Example #11
    def train_step(batch_input_train, batch_latent_train):
        with tf.GradientTape() as tape:
            batch_post_mean_train, batch_log_post_std_train, batch_post_cov_chol_train\
                    = nn.encoder(batch_input_train)
            batch_input_pred_forward_model_train =\
                    solve_forward_model(batch_post_mean_train)

            batch_loss_train_vae =\
                    loss_trace_likelihood(batch_post_cov_chol_train,
                            identity_otimes_likelihood_matrix,
                            1) +\
                    loss_diagonal_weighted_penalized_difference(
                            batch_input_train, batch_input_pred_forward_model_train,
                            noise_regularization_matrix,
                            1)
            batch_loss_train_kld =\
                    loss_kld_full(
                            batch_post_mean_train, batch_log_post_std_train,
                            batch_post_cov_chol_train,
                            prior_mean, prior_cov_inv, identity_otimes_prior_cov_inv,
                            1)

            batch_loss_train_posterior =\
                    (1-hyperp.penalty_js)/hyperp.penalty_js *\
                    2*tf.reduce_sum(batch_log_post_std_train, axis=1) +\
                    loss_weighted_post_cov_full_penalized_difference(
                            batch_latent_train, batch_post_mean_train,
                            batch_post_cov_chol_train,
                            (1-hyperp.penalty_js)/hyperp.penalty_js)

            batch_loss_train = -(-batch_loss_train_vae\
                                 -batch_loss_train_kld\
                                 -batch_loss_train_posterior)
            batch_loss_train_mean = tf.reduce_mean(batch_loss_train, axis=0)

        gradients = tape.gradient(batch_loss_train_mean,
                                  nn.trainable_variables)
        optimizer.apply_gradients(zip(gradients, nn.trainable_variables))
        metrics.mean_loss_train(batch_loss_train)
        metrics.mean_loss_train_vae(batch_loss_train_vae)
        metrics.mean_loss_train_encoder(batch_loss_train_kld)
        metrics.mean_loss_train_posterior(batch_loss_train_posterior)

        return gradients
    def test_step(batch_input_test, batch_latent_test):
        batch_likelihood_test = nn(batch_input_test)
        batch_post_mean_test, batch_log_post_var_test, batch_post_cov_chol_test\
                = nn.encoder(batch_input_test)
        batch_input_pred_test = nn.decoder(batch_latent_test)

        unscaled_replica_batch_loss_test_vae =\
                loss_diagonal_weighted_penalized_difference(
                        batch_input_test, batch_likelihood_test,
                        noise_regularization_matrix,
                        1)
        unscaled_replica_batch_loss_test_kld =\
                loss_kld_full(
                        batch_post_mean_test, batch_log_post_var_test,
                        batch_post_cov_chol_test,
                        prior_mean, prior_cov_inv, identity_otimes_prior_cov_inv,
                        1)
        unscaled_replica_batch_loss_test_posterior =\
                (1-hyperp.penalty_js)/hyperp.penalty_js *\
                2*tf.reduce_sum(batch_log_post_var_test, axis=1) +\
                loss_weighted_post_cov_full_penalized_difference(
                        batch_latent_test, batch_post_mean_test,
                        batch_post_cov_chol_test,
                        (1-hyperp.penalty_js)/hyperp.penalty_js)

        unscaled_replica_batch_loss_test =\
                -(-unscaled_replica_batch_loss_test_vae\
                  -unscaled_replica_batch_loss_test_kld\
                  -unscaled_replica_batch_loss_test_posterior)

        metrics.mean_loss_test(unscaled_replica_batch_loss_test)
        metrics.mean_loss_test_vae(unscaled_replica_batch_loss_test_vae)
        metrics.mean_loss_test_encoder(
            unscaled_replica_batch_loss_test_kld)
        metrics.mean_loss_test_posterior(
            unscaled_replica_batch_loss_test_posterior)

        metrics.mean_relative_error_input_vae(
            relative_error(batch_input_test, batch_likelihood_test))
        metrics.mean_relative_error_latent_posterior(
            relative_error(batch_latent_test, batch_post_mean_test))
        metrics.mean_relative_error_input_decoder(
            relative_error(batch_input_test, batch_input_pred_test))
    def train_step(batch_input_train, batch_latent_train):
        with tf.GradientTape() as tape:
            batch_likelihood_train = nn(batch_input_train)
            batch_post_mean_train, batch_log_post_var_train, batch_post_cov_chol_train\
                    = nn.encoder(batch_input_train)

            unscaled_replica_batch_loss_train_vae =\
                    loss_diagonal_weighted_penalized_difference(
                            batch_input_train, batch_likelihood_train,
                            noise_regularization_matrix,
                            1)
            unscaled_replica_batch_loss_train_kld =\
                    loss_kld_full(
                            batch_post_mean_train, batch_log_post_var_train,
                            batch_post_cov_chol_train,
                            prior_mean, prior_cov_inv, identity_otimes_prior_cov_inv,
                            1)
            unscaled_replica_batch_loss_train_posterior =\
                    (1-hyperp.penalty_js)/hyperp.penalty_js *\
                    2*tf.reduce_sum(batch_log_post_var_train, axis=1) +\
                    loss_weighted_post_cov_full_penalized_difference(
                            batch_latent_train, batch_post_mean_train,
                            batch_post_cov_chol_train,
                            (1-hyperp.penalty_js)/hyperp.penalty_js)

            unscaled_replica_batch_loss_train =\
                    -(-unscaled_replica_batch_loss_train_vae\
                      -unscaled_replica_batch_loss_train_kld\
                      -unscaled_replica_batch_loss_train_posterior)
            scaled_replica_batch_loss_train =\
                    tf.reduce_sum(unscaled_replica_batch_loss_train * (1./hyperp.batch_size))

        gradients = tape.gradient(scaled_replica_batch_loss_train,
                                  nn.trainable_variables)
        optimizer.apply_gradients(zip(gradients, nn.trainable_variables))
        metrics.mean_loss_train_vae(unscaled_replica_batch_loss_train_vae)
        metrics.mean_loss_train_encoder(
            unscaled_replica_batch_loss_train_kld)
        metrics.mean_loss_train_posterior(
            unscaled_replica_batch_loss_train_posterior)

        return scaled_replica_batch_loss_train
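
The `unscaled_replica`/`scaled_replica` naming and the 1./hyperp.batch_size factor follow the tf.distribute convention: each replica divides its summed loss by the global batch size so that the implicit cross-replica sum recovers the true batch mean (the same scaling tf.nn.compute_average_loss performs, assuming hyperp.batch_size is the global batch size). A hypothetical distributed driver around such a train_step; it assumes `nn` and `optimizer` were built under `strategy.scope()` and that `dist_dataset_train` is a distributed dataset:

strategy = tf.distribute.MirroredStrategy()

@tf.function
def distributed_train_step(batch_input_train, batch_latent_train):
    # Run one replica step per device, then sum the pre-scaled replica losses.
    per_replica_losses = strategy.run(
        train_step, args=(batch_input_train, batch_latent_train))
    return strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

for batch_input_train, batch_latent_train in dist_dataset_train:
    batch_loss_train = distributed_train_step(
        batch_input_train, batch_latent_train)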