def optimize_distributed(dist_strategy, hyperp, options, filepaths, nn,
                         optimizer, input_and_latent_train,
                         input_and_latent_val, input_and_latent_test,
                         input_dimensions, latent_dimension, num_batches_train,
                         noise_regularization_matrix, measurement_matrix,
                         prior_mean, prior_cov_inv, forward_matrix,
                         solve_forward_model):

    #=== Kronecker Product of Identity and Prior Covariance Inverse ===#
    identity_otimes_prior_cov_inv =\
            tf.linalg.LinearOperatorKronecker(
                    [tf.linalg.LinearOperatorFullMatrix(tf.eye(latent_dimension)),
                    tf.linalg.LinearOperatorFullMatrix(prior_cov_inv)])

    #=== Kronecker Product of Identity and Likelihood Matrix ===#
    if measurement_matrix.shape == (1, 1):
        likelihood_matrix = tf.linalg.matmul(
            tf.transpose(forward_matrix),
            noise_regularization_matrix * forward_matrix)
    else:
        likelihood_matrix = tf.linalg.matmul(
            tf.transpose(tf.linalg.matmul(measurement_matrix, forward_matrix)),
            tf.linalg.matmul(
                tf.linalg.diag(tf.squeeze(noise_regularization_matrix)),
                tf.linalg.matmul(measurement_matrix, forward_matrix)))
    identity_otimes_likelihood_matrix =\
            tf.linalg.LinearOperatorKronecker(
                    [tf.linalg.LinearOperatorFullMatrix(tf.eye(latent_dimension)),
                    tf.linalg.LinearOperatorFullMatrix(likelihood_matrix)])
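    # Note (added comment): both factors are wrapped as LinearOperators, so the
    # Kronecker product of the identity with a matrix stays a lazy
    # block-diagonal operator (the matrix repeated latent_dimension times along
    # the diagonal) rather than a materialized dense matrix when used in the
    # trace and KLD losses below.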

    #=== Check Number of Parallel Computations ===#
    print('Number of Replicas in Sync: %d' %
          (dist_strategy.num_replicas_in_sync))

    #=== Distribute Data ===#
    dist_input_and_latent_train =\
            dist_strategy.experimental_distribute_dataset(input_and_latent_train)
    dist_input_and_latent_val = dist_strategy.experimental_distribute_dataset(
        input_and_latent_val)
    dist_input_and_latent_test = dist_strategy.experimental_distribute_dataset(
        input_and_latent_test)

    #=== Metrics ===#
    metrics = Metrics(dist_strategy)

    #=== Creating Directory for Trained Neural Network ===#
    if not os.path.exists(filepaths.directory_trained_nn):
        os.makedirs(filepaths.directory_trained_nn)

    #=== Tensorboard ===# "tensorboard --logdir=tensorboard"
    if os.path.exists(filepaths.directory_tensorboard):
        shutil.rmtree(filepaths.directory_tensorboard)
    summary_writer = tf.summary.create_file_writer(
        filepaths.directory_tensorboard)

    #=== Display Neural Network Architecture ===#
    with dist_strategy.scope():
        nn.build((hyperp.batch_size, input_dimensions))
        nn.summary()

###############################################################################
#                   Training, Validation and Testing Step                     #
###############################################################################
    with dist_strategy.scope():
        #=== Training Step ===#
        def train_step(batch_input_train, batch_latent_train):
            with tf.GradientTape() as tape:
                batch_post_mean_train, batch_log_post_std_train, batch_post_cov_chol_train\
                        = nn.encoder(batch_input_train)
                batch_input_pred_forward_model_train =\
                        solve_forward_model(batch_post_mean_train)

                unscaled_replica_batch_loss_train_vae =\
                        loss_trace_likelihood(batch_post_cov_chol_train,
                                identity_otimes_likelihood_matrix,
                                1) +\
                        loss_diagonal_weighted_penalized_difference(
                                batch_input_train, batch_input_pred_forward_model_train,
                                noise_regularization_matrix,
                                1)
                unscaled_replica_batch_loss_train_kld =\
                        loss_kld_full(
                                batch_post_mean_train, batch_log_post_std_train,
                                batch_post_cov_chol_train,
                                prior_mean, prior_cov_inv, identity_otimes_prior_cov_inv,
                                1)
                unscaled_replica_batch_loss_train_posterior =\
                        (1-hyperp.penalty_js)/hyperp.penalty_js *\
                        2*tf.reduce_sum(batch_log_post_std_train,axis=1) +\
                        loss_weighted_post_cov_full_penalized_difference(
                                batch_latent_train, batch_post_mean_train,
                                batch_post_cov_chol_train,
                                (1-hyperp.penalty_js)/hyperp.penalty_js)

                unscaled_replica_batch_loss_train =\
                        -(-unscaled_replica_batch_loss_train_vae\
                          -unscaled_replica_batch_loss_train_kld\
                          -unscaled_replica_batch_loss_train_posterior)
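                # Scale by the global batch size (hyperp.batch_size) rather
                # than the per-replica size, so summing the per-replica losses
                # across replicas recovers the mean over the global batch.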
                scaled_replica_batch_loss_train = tf.reduce_sum(
                    unscaled_replica_batch_loss_train *
                    (1. / hyperp.batch_size))

            gradients = tape.gradient(scaled_replica_batch_loss_train,
                                      nn.trainable_variables)
            optimizer.apply_gradients(zip(gradients, nn.trainable_variables))
            metrics.mean_loss_train_vae(-unscaled_replica_batch_loss_train_vae)
            metrics.mean_loss_train_encoder(
                unscaled_replica_batch_loss_train_kld)
            metrics.mean_loss_train_posterior(
                unscaled_replica_batch_loss_train_posterior)

            return scaled_replica_batch_loss_train

        @tf.function
        def dist_train_step(batch_input_train, batch_latent_train):
            # strategy.run replaces the older experimental_run_v2 API.
            per_replica_losses = dist_strategy.run(
                train_step, args=(batch_input_train, batch_latent_train))
            return dist_strategy.reduce(tf.distribute.ReduceOp.SUM,
                                        per_replica_losses,
                                        axis=None)

        #=== Validation Step ===#
        def val_step(batch_input_val, batch_latent_val):
            batch_post_mean_val, batch_log_post_std_val, batch_post_cov_chol_val\
                    = nn.encoder(batch_input_val)

            unscaled_replica_batch_loss_val_kld =\
                    loss_kld_full(
                            batch_post_mean_val, batch_log_post_std_val,
                            batch_post_cov_chol_val,
                            prior_mean, prior_cov_inv, identity_otimes_prior_cov_inv,
                            1)
            unscaled_replica_batch_loss_val_posterior =\
                    (1-hyperp.penalty_js)/hyperp.penalty_js *\
                    2*tf.reduce_sum(batch_log_post_std_val,axis=1) +\
                    loss_weighted_post_cov_full_penalized_difference(
                            batch_latent_val, batch_post_mean_val,
                            batch_post_cov_chol_val,
                            (1-hyperp.penalty_js)/hyperp.penalty_js)

            unscaled_replica_batch_loss_val =\
                    -(-unscaled_replica_batch_loss_val_kld\
                      -unscaled_replica_batch_loss_val_posterior)

            metrics.mean_loss_val(unscaled_replica_batch_loss_val)
            metrics.mean_loss_val_encoder(unscaled_replica_batch_loss_val_kld)
            metrics.mean_loss_val_posterior(
                unscaled_replica_batch_loss_val_posterior)

        @tf.function
        def dist_val_step(batch_input_val, batch_latent_val):
            return dist_strategy.run(
                val_step, args=(batch_input_val, batch_latent_val))

        #=== Test Step ===#
        def test_step(batch_input_test, batch_latent_test):
            batch_post_mean_test, batch_log_post_std_test, batch_post_cov_chol_test\
                    = nn.encoder(batch_input_test)

            unscaled_replica_batch_loss_test_kld =\
                    loss_kld_full(
                            batch_post_mean_test, batch_log_post_std_test,
                            batch_post_cov_chol_test,
                            prior_mean, prior_cov_inv, identity_otimes_prior_cov_inv,
                            1)
            unscaled_replica_batch_loss_test_posterior =\
                    (1-hyperp.penalty_js)/hyperp.penalty_js *\
                    2*tf.reduce_sum(batch_log_post_std_test,axis=1) +\
                    loss_weighted_post_cov_full_penalized_difference(
                            batch_latent_test, batch_post_mean_test,
                            batch_post_cov_chol_test,
                            (1-hyperp.penalty_js)/hyperp.penalty_js)

            unscaled_replica_batch_loss_test =\
                    -(-unscaled_replica_batch_loss_test_kld\
                      -unscaled_replica_batch_loss_test_posterior)

            metrics.mean_loss_test(unscaled_replica_batch_loss_test)
            metrics.mean_loss_test_encoder(
                unscaled_replica_batch_loss_test_kld)
            metrics.mean_loss_test_posterior(
                unscaled_replica_batch_loss_test_posterior)

            metrics.mean_relative_error_latent_posterior(
                relative_error(batch_latent_test, batch_post_mean_test))

        @tf.function
        def dist_test_step(batch_input_test, batch_latent_test):
            return dist_strategy.run(
                test_step, args=(batch_input_test, batch_latent_test))

###############################################################################
#                             Train Neural Network                            #
###############################################################################

    print('Beginning Training')
    for epoch in range(hyperp.num_epochs):
        print('================================')
        print('            Epoch %d            ' % (epoch))
        print('================================')
        print('Project: ' + filepaths.case_name + '\n' + 'nn: ' +
              filepaths.nn_name + '\n')
        print('GPUs: ' + options.dist_which_gpus + '\n')
        print('Optimizing %d batches of size %d:' %
              (num_batches_train, hyperp.batch_size))
        start_time_epoch = time.time()
        batch_counter = 0
        total_loss_train = 0
        for batch_input_train, batch_latent_train in dist_input_and_latent_train:
            start_time_batch = time.time()
            #=== Compute Train Step ===#
            batch_loss_train = dist_train_step(batch_input_train,
                                               batch_latent_train)
            total_loss_train += batch_loss_train
            elapsed_time_batch = time.time() - start_time_batch
            if batch_counter == 0:
                print('Time per Batch: %.4f' % (elapsed_time_batch))
            batch_counter += 1
        metrics.mean_loss_train = total_loss_train / batch_counter

        #=== Computing Validation Metrics ===#
        for batch_input_val, batch_latent_val in dist_input_and_latent_val:
            dist_val_step(batch_input_val, batch_latent_val)

        #=== Computing Test Metrics ===#
        for batch_input_test, batch_latent_test in dist_input_and_latent_test:
            dist_test_step(batch_input_test, batch_latent_test)

        #=== Tensorboard Tracking Training Metrics, Weights and Gradients ===#
        metrics.update_tensorboard(summary_writer, epoch)

        #=== Update Storage Arrays ===#
        metrics.update_storage_arrays()

        #=== Display Epoch Iteration Information ===#
        elapsed_time_epoch = time.time() - start_time_epoch
        print('Time per Epoch: %.4f\n' % (elapsed_time_epoch))
        print('Train Loss: Full: %.3e, VAE: %.3e, KLD: %.3e, Posterior: %.3e'\
                %(metrics.mean_loss_train,
                  metrics.mean_loss_train_vae.result(),
                  metrics.mean_loss_train_encoder.result(),
                  metrics.mean_loss_train_posterior.result()))
        print('Val Loss: Full: %.3e, KLD: %.3e, Posterior: %.3e'\
                %(metrics.mean_loss_val.result(),
                  metrics.mean_loss_val_encoder.result(),
                  metrics.mean_loss_val_posterior.result()))
        print('Test Loss: Full: %.3e, KLD: %.3e, Posterior: %.3e'\
                %(metrics.mean_loss_test.result(),
                  metrics.mean_loss_test_encoder.result(),
                  metrics.mean_loss_test_posterior.result()))
        print('Rel Errors: Posterior Mean: %.3e\n'\
                %(metrics.mean_relative_error_latent_posterior.result()))
        start_time_epoch = time.time()

        #=== Resetting Metrics ===#
        metrics.reset_metrics()

        #=== Save Current Model and Metrics ===#
        if epoch % 5 == 0:
            nn.save_weights(filepaths.trained_nn)
            metrics.save_metrics(filepaths)
            dump_attrdict_as_yaml(hyperp, filepaths.directory_trained_nn,
                                  'hyperp')
            dump_attrdict_as_yaml(options, filepaths.directory_trained_nn,
                                  'options')
            print('Current Model and Metrics Saved')

    #=== Save Final Model ===#
    nn.save_weights(filepaths.trained_nn)
    metrics.save_metrics(filepaths)
    dump_attrdict_as_yaml(hyperp, filepaths.directory_trained_nn, 'hyperp')
    dump_attrdict_as_yaml(options, filepaths.directory_trained_nn, 'options')
    print('Final Model and Metrics Saved')
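
#=== Standalone Sketch (Assumption): tf.distribute Pattern Used Above ===#
# A minimal, self-contained illustration of the distribution mechanics that
# optimize_distributed relies on: distribute a dataset across replicas, run a
# per-replica step, and SUM-reduce the per-replica losses. Every name below is
# a toy placeholder, not part of the original project; `tensorflow as tf` is
# assumed imported at module level, as the rest of this file requires.
def _sketch_distributed_loop():
    strategy = tf.distribute.MirroredStrategy()
    global_batch_size = 8
    dataset = tf.data.Dataset.from_tensor_slices(
        tf.random.normal([64, 4])).batch(global_batch_size)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)

    def replica_step(batch):
        # Per-replica loss, scaled by the *global* batch size so that the
        # cross-replica SUM reproduces the mean over the global batch.
        return tf.reduce_sum(tf.reduce_mean(batch, axis=1)) / global_batch_size

    @tf.function
    def dist_step(batch):
        per_replica_losses = strategy.run(replica_step, args=(batch,))
        return strategy.reduce(tf.distribute.ReduceOp.SUM,
                               per_replica_losses,
                               axis=None)

    for batch in dist_dataset:
        _ = dist_step(batch)
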
def optimize(hyperp, options, filepaths, nn, optimizer, input_and_latent_train,
             input_and_latent_val, input_and_latent_test, input_dimensions,
             latent_dimension, num_batches_train, noise_regularization_matrix,
             prior_mean, prior_covariance_cholesky_inverse):

    #=== Define Metrics ===#
    metrics = Metrics()

    #=== Creating Directory for Trained Neural Network ===#
    if not os.path.exists(filepaths.directory_trained_nn):
        os.makedirs(filepaths.directory_trained_nn)

    #=== Tensorboard ===# "tensorboard --logdir=tensorboard"
    if os.path.exists(filepaths.directory_tensorboard):
        shutil.rmtree(filepaths.directory_tensorboard)
    summary_writer = tf.summary.create_file_writer(
        filepaths.directory_tensorboard)

    #=== Display Neural Network Architecture ===#
    nn.build((hyperp.batch_size, input_dimensions))
    nn.summary()

    ###############################################################################
    #                   Training, Validation and Testing Step                     #
    ###############################################################################
    #=== Train Step ===#
    @tf.function
    def train_step(batch_input_train, batch_latent_train):
        with tf.GradientTape() as tape:
            batch_likelihood_train = nn(batch_input_train)
            batch_post_mean_train, batch_log_post_var_train = nn.encoder(
                batch_input_train)
            batch_posterior_sample_train = nn.iaf_chain_encoder(
                (batch_post_mean_train, batch_log_post_var_train),
                sample_flag=True,
                infer_flag=False)
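            # The IAF chain is queried twice: with sample_flag=True it returns
            # a posterior draw, and with infer_flag=True (below) it returns the
            # density term that enters the loss.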

            batch_loss_train_vae =\
                    loss_diagonal_weighted_penalized_difference(
                            batch_input_train, batch_likelihood_train,
                            noise_regularization_matrix,
                            1)
            batch_loss_train_iaf_encoder =\
                    nn.iaf_chain_encoder((batch_post_mean_train,
                                          batch_log_post_var_train),
                                          sample_flag = False,
                                          infer_flag = True)
            batch_loss_train_prior =\
                    loss_diagonal_weighted_penalized_difference(
                            prior_mean, batch_posterior_sample_train,
                            prior_covariance_cholesky_inverse,
                            1)
            batch_loss_train_posterior =\
                    loss_penalized_difference(
                            batch_latent_train, batch_posterior_sample_train,
                            (1-hyperp.penalty_js)/hyperp.penalty_js)

            batch_loss_train = -(-batch_loss_train_vae\
                                 -batch_loss_train_iaf_encoder\
                                 -batch_loss_train_prior\
                                 -batch_loss_train_posterior)
            batch_loss_train_mean = tf.reduce_mean(batch_loss_train, axis=0)

        gradients = tape.gradient(batch_loss_train_mean,
                                  nn.trainable_variables)
        optimizer.apply_gradients(zip(gradients, nn.trainable_variables))
        metrics.mean_loss_train(batch_loss_train)
        metrics.mean_loss_train_vae(batch_loss_train_vae)
        metrics.mean_loss_train_encoder(batch_loss_train_iaf_encoder)
        metrics.mean_loss_train_prior(batch_loss_train_prior)
        metrics.mean_loss_train_posterior(batch_loss_train_posterior)
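
        # The gradients are returned so the training loop can log per-variable
        # gradient norms and evaluate the relative gradient-norm stopping rule.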

        return gradients

    #=== Validation Step ===#
    @tf.function
    def val_step(batch_input_val, batch_latent_val):
        batch_likelihood_val = nn(batch_input_val)
        batch_post_mean_val, batch_log_post_var_val = nn.encoder(
            batch_input_val)
        batch_posterior_sample_val = nn.iaf_chain_encoder(
            (batch_post_mean_val, batch_log_post_var_val),
            sample_flag=True,
            infer_flag=False)

        batch_loss_val_vae =\
                loss_diagonal_weighted_penalized_difference(
                        batch_input_val, batch_likelihood_val,
                        noise_regularization_matrix,
                        1)
        batch_loss_val_iaf_encoder =\
                nn.iaf_chain_encoder((batch_post_mean_val,
                                      batch_log_post_var_val),
                                      sample_flag = False,
                                      infer_flag = True)
        batch_loss_val_prior =\
                loss_diagonal_weighted_penalized_difference(
                        prior_mean, batch_posterior_sample_val,
                        prior_covariance_cholesky_inverse,
                        1)
        batch_loss_val_posterior =\
                loss_penalized_difference(
                        batch_latent_val, batch_posterior_sample_val,
                        (1-hyperp.penalty_js)/hyperp.penalty_js)

        batch_loss_val = -(-batch_loss_val_vae\
                           -batch_loss_val_iaf_encoder\
                           -batch_loss_val_prior\
                           -batch_loss_val_posterior)

        metrics.mean_loss_val(batch_loss_val)
        metrics.mean_loss_val_vae(batch_loss_val_vae)
        metrics.mean_loss_val_encoder(batch_loss_val_iaf_encoder)
        metrics.mean_loss_val_prior(batch_loss_val_prior)
        metrics.mean_loss_val_posterior(batch_loss_val_posterior)

    #=== Test Step ===#
    @tf.function
    def test_step(batch_input_test, batch_latent_test):
        batch_likelihood_test = nn(batch_input_test)
        batch_post_mean_test, batch_log_post_var_test = nn.encoder(
            batch_input_test)
        batch_posterior_sample_test = nn.iaf_chain_encoder(
            (batch_post_mean_test, batch_log_post_var_test),
            sample_flag=True,
            infer_flag=False)
        batch_input_pred_test = nn.decoder(batch_latent_test)

        batch_loss_test_vae =\
                loss_diagonal_weighted_penalized_difference(
                        batch_input_test, batch_likelihood_test,
                        noise_regularization_matrix,
                        1)
        batch_loss_test_iaf_encoder =\
                nn.iaf_chain_encoder((batch_post_mean_test,
                                      batch_log_post_var_test),
                                      sample_flag = False,
                                      infer_flag = True)
        batch_loss_test_prior =\
                loss_diagonal_weighted_penalized_difference(
                        prior_mean, batch_posterior_sample_test,
                        prior_covariance_cholesky_inverse,
                        1)
        batch_loss_test_posterior =\
                loss_penalized_difference(
                        batch_latent_test, batch_posterior_sample_test,
                        (1-hyperp.penalty_js)/hyperp.penalty_js)

        batch_loss_test = -(-batch_loss_test_vae\
                            -batch_loss_test_iaf_encoder\
                            -batch_loss_test_prior\
                            -batch_loss_test_posterior)

        metrics.mean_loss_test(batch_loss_test)
        metrics.mean_loss_test_vae(batch_loss_test_vae)
        metrics.mean_loss_test_encoder(batch_loss_test_iaf_encoder)
        metrics.mean_loss_test_prior(batch_loss_test_prior)
        metrics.mean_loss_test_posterior(batch_loss_test_posterior)

        metrics.mean_relative_error_input_vae(
            relative_error(batch_input_test, batch_likelihood_test))
        metrics.mean_relative_error_latent_posterior(
            relative_error(batch_latent_test, batch_posterior_sample_test))
        metrics.mean_relative_error_input_decoder(
            relative_error(batch_input_test, batch_input_pred_test))

###############################################################################
#                             Train Neural Network                            #
###############################################################################

    print('Beginning Training')
    for epoch in range(hyperp.num_epochs):
        print('================================')
        print('            Epoch %d            ' % (epoch))
        print('================================')
        print('Project: ' + filepaths.case_name + '\n' + 'nn: ' +
              filepaths.nn_name + '\n')
        print('GPU: ' + options.which_gpu + '\n')
        print('Optimizing %d batches of size %d:' %
              (num_batches_train, hyperp.batch_size))
        start_time_epoch = time.time()
        for batch_num, (
                batch_input_train,
                batch_latent_train) in input_and_latent_train.enumerate():
            start_time_batch = time.time()
            #=== Computing Train Step ===#
            gradients = train_step(batch_input_train, batch_latent_train)
            elapsed_time_batch = time.time() - start_time_batch
            if batch_num == 0:
                print('Time per Batch: %.4f' % (elapsed_time_batch))

        #=== Computing Relative Errors Validation ===#
        for batch_input_val, batch_latent_val in input_and_latent_val:
            val_step(batch_input_val, batch_latent_val)

        #=== Computing Relative Errors Test ===#
        for batch_input_test, batch_latent_test in input_and_latent_test:
            test_step(batch_input_test, batch_latent_test)

        #=== Update Current Relative Gradient Norm ===#
        with summary_writer.as_default():
            for w in nn.weights:
                tf.summary.histogram(w.name, w, step=epoch)
            l2_norm = lambda t: tf.sqrt(tf.reduce_sum(tf.pow(t, 2)))
            sum_gradient_norms = 0.0
            for gradient, variable in zip(gradients, nn.trainable_variables):
                tf.summary.histogram("gradients_norm/" + variable.name,
                                     l2_norm(gradient),
                                     step=epoch)
                sum_gradient_norms += l2_norm(gradient)
            if epoch == 0:
                initial_sum_gradient_norms = sum_gradient_norms
        metrics.relative_gradient_norm =\
                sum_gradient_norms / initial_sum_gradient_norms

        #=== Track Training Metrics, Weights and Gradients ===#
        metrics.update_tensorboard(summary_writer, epoch)

        #=== Update Storage Arrays ===#
        metrics.update_storage_arrays()

        #=== Display Epoch Iteration Information ===#
        elapsed_time_epoch = time.time() - start_time_epoch
        print('Time per Epoch: %.4f\n' % (elapsed_time_epoch))
        print('Train Loss: Full: %.3e, VAE: %.3e, IAF_post: %.3e, prior: %.3e'\
                %(metrics.mean_loss_train.result(),
                  metrics.mean_loss_train_vae.result(),
                  metrics.mean_loss_train_encoder.result(),
                  metrics.mean_loss_train_prior.result()))
        print('Val Loss: Full: %.3e, VAE: %.3e, IAF_post: %.3e, prior: %.3e'\
                %(metrics.mean_loss_val.result(),
                  metrics.mean_loss_val_vae.result(),
                  metrics.mean_loss_val_encoder.result(),
                  metrics.mean_loss_val_prior.result()))
        print('Test Loss: Full: %.3e, VAE: %.3e, IAF_post: %.3e, prior: %.3e'\
                %(metrics.mean_loss_test.result(),
                  metrics.mean_loss_test_vae.result(),
                  metrics.mean_loss_test_encoder.result(),
                  metrics.mean_loss_test_prior.result()))
        print('Rel Errors: VAE: %.3e, Post Draw: %.3e, Decoder: %.3e\n'\
                %(metrics.mean_relative_error_input_vae.result(),
                  metrics.mean_relative_error_latent_posterior.result(),
                  metrics.mean_relative_error_input_decoder.result()))
        print('Relative Gradient Norm: %.4f\n' %
              (metrics.relative_gradient_norm))
        start_time_epoch = time.time()

        #=== Resetting Metrics ===#
        metrics.reset_metrics()

        #=== Saving Current Model and Metrics ===#
        if epoch % 100 == 0:
            nn.save_weights(filepaths.trained_nn)
            metrics.save_metrics(filepaths)
            dump_attrdict_as_yaml(hyperp, filepaths.directory_trained_nn,
                                  'hyperp')
            dump_attrdict_as_yaml(options, filepaths.directory_trained_nn,
                                  'options')
            print('Current Model and Metrics Saved')

        #=== Gradient Norm Termination Condition ===#
        if metrics.relative_gradient_norm < 1e-6:
            print('Gradient norm tolerance reached, breaking training loop')
            break

    #=== Save Final Model ===#
    nn.save_weights(filepaths.trained_nn)
    metrics.save_metrics(filepaths)
    dump_attrdict_as_yaml(hyperp, filepaths.directory_trained_nn, 'hyperp')
    dump_attrdict_as_yaml(options, filepaths.directory_trained_nn, 'options')
    print('Final Model and Metrics Saved')
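
#=== Standalone Sketch (Assumption): Relative Gradient-Norm Stopping Rule ===#
# A minimal, self-contained illustration of the termination criterion used in
# the training loop above: the summed L2 norms of the current gradients are
# normalized by their value at the first epoch, and training stops once the
# ratio falls below a tolerance. The toy model, data, and helper names are
# placeholders, not part of the original project; `tf` is assumed imported at
# module level as in the rest of this file.
def _summed_gradient_norm(gradients):
    return tf.add_n([tf.sqrt(tf.reduce_sum(tf.square(g))) for g in gradients])

def _sketch_gradient_norm_stopping(num_epochs=50, tolerance=1e-6):
    x = tf.random.normal([32, 3])
    y = tf.random.normal([32, 1])
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    optimizer = tf.keras.optimizers.Adam(1e-2)
    initial_norm = None
    for _ in range(num_epochs):
        with tf.GradientTape() as tape:
            loss = tf.reduce_mean(tf.square(model(x) - y))
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        current_norm = _summed_gradient_norm(gradients)
        if initial_norm is None:
            initial_norm = current_norm
        if current_norm / initial_norm < tolerance:
            print('Gradient norm tolerance reached, breaking training loop')
            break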