def training(hyperp, options, filepaths, data_dict, prior_dict):
    """Build the VAE and train it, optionally under a mirrored multi-GPU strategy.

    Args:
        hyperp: hyperparameters; ``hyperp.batch_size`` is read here and the
            object is forwarded to ``VAE`` and the optimize routine.
        options: run options. Reads ``distributed_training`` (0 or 1),
            ``which_gpu`` / ``dist_which_gpus``, ``random_seed``,
            ``parameter_dimensions`` and the forward-model switches
            ``discrete_polynomial`` / ``discrete_exponential``.
        filepaths: path container forwarded to the loaders and optimizer.
        data_dict: must provide ``state_obs_train``, ``parameter_train``,
            ``state_obs_test``, ``parameter_test``, ``obs_dimensions``,
            ``obs_indices`` and ``noise_regularization_matrix``.
        prior_dict: must provide ``prior_mean`` and
            ``prior_covariance_inverse``.

    Raises:
        ValueError: if ``options.distributed_training`` is not 0 or 1, or if
            neither forward-model switch is enabled.
    """
    # Validate up front: any other value previously fell through every branch
    # and the function returned without training anything.
    if options.distributed_training not in (0, 1):
        raise ValueError(
            'options.distributed_training must be 0 or 1, got %r'
            % (options.distributed_training,))
    distributed = options.distributed_training == 1

    #=== GPU Settings ===#
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    if distributed:
        os.environ["CUDA_VISIBLE_DEVICES"] = options.dist_which_gpus
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = options.which_gpu

    #=== Construct Validation Set and Batches ===#
    input_and_latent_train, input_and_latent_val, input_and_latent_test,\
    num_batches_train, num_batches_val, num_batches_test\
    = form_train_val_test_tf_batches(
            data_dict["state_obs_train"], data_dict["parameter_train"],
            data_dict["state_obs_test"], data_dict["parameter_test"],
            hyperp.batch_size, options.random_seed)

    #=== Data and Latent Dimensions of Autoencoder ===#
    input_dimensions = data_dict["obs_dimensions"]
    latent_dimensions = options.parameter_dimensions

    #=== Load Forward Operator  ===#
    forward_operator = load_forward_operator_tf(options, filepaths)

    #=== Construct Forward Model ===#
    forward_model = SolveForward1D(options, filepaths, forward_operator,
                                   data_dict["obs_indices"])
    # If both switches are set, the exponential model takes precedence.
    if options.discrete_exponential:
        forward_model_solve = forward_model.discrete_exponential
    elif options.discrete_polynomial:
        forward_model_solve = forward_model.discrete_polynomial
    else:
        raise ValueError(
            'No forward model selected: set options.discrete_polynomial or '
            'options.discrete_exponential')

    #=== Neural Network Regularizers ===#
    kernel_initializer = tf.keras.initializers.RandomNormal(mean=0.0,
                                                            stddev=0.05)
    bias_initializer = 'zeros'

    if distributed:
        #=== Distributed Training ===#
        dist_strategy = tf.distribute.MirroredStrategy()
        # Variables (network weights, optimizer slots) must be created inside
        # the strategy scope so they are mirrored across devices.
        with dist_strategy.scope():
            nn = VAE(hyperp, options, input_dimensions, latent_dimensions,
                     kernel_initializer, bias_initializer,
                     positivity_constraint_log_exp)
            optimizer = tf.keras.optimizers.Adam()

        optimize_distributed(
            dist_strategy, hyperp, options, filepaths, nn, optimizer,
            input_and_latent_train, input_and_latent_val,
            input_and_latent_test, input_dimensions, latent_dimensions,
            num_batches_train, data_dict["noise_regularization_matrix"],
            prior_dict["prior_mean"], prior_dict["prior_covariance_inverse"],
            forward_model_solve)
    else:
        #=== Non-distributed Training ===#
        nn = VAE(hyperp, options, input_dimensions, latent_dimensions,
                 kernel_initializer, bias_initializer,
                 positivity_constraint_log_exp)
        optimizer = tf.keras.optimizers.Adam()

        optimize(hyperp, options, filepaths, nn, optimizer,
                 input_and_latent_train, input_and_latent_val,
                 input_and_latent_test, input_dimensions, latent_dimensions,
                 num_batches_train, data_dict["noise_regularization_matrix"],
                 prior_dict["prior_mean"],
                 prior_dict["prior_covariance_inverse"], forward_model_solve)
def _plot_posterior_draws(options, filepaths, nn, meta_space, sample_number,
                          parameter_test_sample, state_obs_test_sample,
                          num_draws):
    """Sample the learned posterior ``num_draws`` times and save figures per draw.

    Each draw writes ``<figure_parameter_pred>_<sample>_<draw>.png`` and
    ``<figure_parameter_cross_section>_<sample>_<draw>.png``; these are the
    frame names ``predict_and_plot`` passes to ``make_movie``. Figures that do
    not depend on the draw index (true parameter, posterior mean, state,
    variance) are rewritten under the same ``_<sample>.png`` name each draw.
    """
    import time

    #=== Plot Settings ===#
    cross_section_y = 0.0
    plot_parameter_min = 0
    plot_parameter_max = 6
    plot_variance_min = 0
    plot_variance_max = 1.3
    parameter_limits = (plot_parameter_min, plot_parameter_max)
    variance_limits = (plot_variance_min, plot_variance_max)
    filename_extension = '_%d.png' % (sample_number)

    for draw in range(num_draws):
        #=== Predictions ===#
        start_time_nn = time.time()
        posterior_mean_pred, posterior_cov_pred = nn.encoder(
            state_obs_test_sample)
        elapsed_time_nn = time.time() - start_time_nn
        print('Time taken for neural network inference: %.4f'
              % (elapsed_time_nn))
        posterior_pred_draw = nn.reparameterize(posterior_mean_pred,
                                                posterior_cov_pred)

        posterior_mean_pred = posterior_mean_pred.numpy().flatten()
        posterior_cov_pred = posterior_cov_pred.numpy().flatten()
        posterior_pred_draw = posterior_pred_draw.numpy().flatten()

        # The decoder is only evaluated when options.model_aware == 1; the
        # None sentinel keeps the state-prediction plot below from touching
        # an unbound name otherwise.
        state_obs_pred_draw = None
        if options.model_aware == 1:
            state_obs_pred_draw = nn.decoder(
                np.expand_dims(posterior_pred_draw, 0))
            state_obs_pred_draw = state_obs_pred_draw.numpy().flatten()

        #=== Plotting Prediction ===#
        print('================================')
        print('      Plotting Predictions      ')
        print('================================')
        filename_extension_draw = '_%d_%d.png' % (sample_number, draw)

        #=== Plot FEM Functions ===#
        plot_fem_function_fenics_2d(
            meta_space, parameter_test_sample, cross_section_y, '',
            filepaths.figure_parameter_test + filename_extension,
            (5, 5), parameter_limits, True)
        plot_fem_function_fenics_2d(
            meta_space, posterior_mean_pred, cross_section_y, '',
            filepaths.figure_posterior_mean + filename_extension,
            (5, 5), parameter_limits, False)
        plot_fem_function_fenics_2d(
            meta_space, posterior_pred_draw, cross_section_y, '',
            filepaths.figure_parameter_pred + filename_extension_draw,
            (5, 5), parameter_limits, True)
        if options.obs_type == 'full':
            plot_fem_function_fenics_2d(
                meta_space, state_obs_test_sample, cross_section_y,
                'True State',
                filepaths.figure_state_test + filename_extension, (5, 5))
            if state_obs_pred_draw is not None:
                plot_fem_function_fenics_2d(
                    meta_space, state_obs_pred_draw, cross_section_y,
                    'State Prediction',
                    filepaths.figure_state_pred + filename_extension, (5, 5))

        #=== Plot Cross-Section with Error Bounds ===#
        plot_cross_section(
            meta_space, parameter_test_sample, posterior_mean_pred,
            posterior_cov_pred, (-1, 1), cross_section_y, '',
            filepaths.figure_parameter_cross_section + filename_extension,
            parameter_limits)
        plot_cross_section(
            meta_space, parameter_test_sample, posterior_pred_draw,
            posterior_cov_pred, (-1, 1), cross_section_y, '',
            filepaths.figure_parameter_cross_section + filename_extension_draw,
            parameter_limits)

        #=== Plot Variation ===#
        plot_fem_function_fenics_2d(
            meta_space, np.exp(posterior_cov_pred), cross_section_y, '',
            filepaths.figure_posterior_covariance + filename_extension,
            (5, 5), variance_limits, False)

        print('Predictions plotted')


def predict_and_plot(hyperp, options, filepaths, *, regenerate_draws=False):
    """Load a trained VAE, save one test sample and build posterior-draw movies.

    The mesh settings on ``options`` are overwritten with fixed values before
    the mesh is constructed. The selected test sample is written as CSV to
    ``filepaths.poi_specific`` and ``filepaths.qoi_specific``. The movies are
    assembled from the per-draw frames
    ``<figure_parameter_pred>_<sample>_<draw>.png`` and
    ``<figure_parameter_cross_section>_<sample>_<draw>.png``.

    Args:
        hyperp: hyperparameters forwarded to ``DataHandler`` and ``VAE``.
        options: run options; mutated in place (mesh properties).
        filepaths: path container for data, trained weights and figures.
        regenerate_draws: when True, sample the posterior and (re)write the
            movie frames before assembling the movies. Defaults to False,
            which assembles the movies from frames already on disk.

    Raises:
        ValueError: if ``options.obs_type`` is neither ``'full'`` nor ``'obs'``.
    """
    #=== Mesh Properties ===#
    options.mesh_point_1 = [-1, -1]
    options.mesh_point_2 = [1, 1]
    options.nx = 50
    options.ny = 50
    options.num_obs_points = 10
    options.order_fe_space = 1
    options.order_meta_space = 1
    options.num_nodes = (options.nx + 1) * (options.ny + 1)

    #=== Construct Mesh ===#
    # Only the metadata function space is needed (for plotting draws).
    _, meta_space, _, _, _ = construct_mesh(options)

    #=== Load Observation Indices ===#
    if options.obs_type == 'full':
        obs_dimensions = options.parameter_dimensions
        # NOTE(review): previously left unbound, which crashed DataHandler
        # with NameError. None is passed for full observations; confirm
        # DataHandler accepts it.
        obs_indices = None
    elif options.obs_type == 'obs':
        obs_dimensions = options.num_obs_points
        print('Loading Boundary Indices')
        df_obs_indices = pd.read_csv(filepaths.project.obs_indices + '.csv')
        obs_indices = df_obs_indices.to_numpy()
    else:
        raise ValueError("options.obs_type must be 'full' or 'obs', got %r"
                         % (options.obs_type,))

    #=== Data and Latent Dimensions of Autoencoder ===#
    input_dimensions = obs_dimensions
    latent_dimensions = options.parameter_dimensions

    #=== Prepare Data ===#
    data = DataHandler(hyperp, options, filepaths, obs_indices,
                       options.parameter_dimensions, obs_dimensions,
                       options.parameter_dimensions)
    data.load_data_test()
    if options.add_noise:
        data.add_noise_qoi_test()
    parameter_test = data.poi_test
    state_obs_test = data.qoi_test

    #=== Load Trained Neural Network ===#
    nn = VAE(hyperp, options, input_dimensions, latent_dimensions, None, None,
             positivity_constraint_log_exp)
    nn.load_weights(filepaths.trained_nn)

    #=== Selecting Samples ===#
    sample_number = 128
    num_draws = 20
    parameter_test_sample = np.expand_dims(parameter_test[sample_number, :], 0)
    state_obs_test_sample = np.expand_dims(state_obs_test[sample_number, :], 0)

    #=== Saving Specific Sample ===#
    df_poi_specific = pd.DataFrame(
        {'poi_specific': parameter_test_sample.flatten()})
    df_poi_specific.to_csv(filepaths.poi_specific + '.csv', index=False)
    df_qoi_specific = pd.DataFrame(
        {'qoi_specific': state_obs_test_sample.flatten()})
    df_qoi_specific.to_csv(filepaths.qoi_specific + '.csv', index=False)

    #=== Predictions ===#
    if regenerate_draws:
        _plot_posterior_draws(options, filepaths, nn, meta_space,
                              sample_number, parameter_test_sample,
                              state_obs_test_sample, num_draws)

    #=== Make Movie ===#
    make_movie(filepaths.figure_parameter_pred + '_%d' % (sample_number),
               filepaths.directory_movie, 'parameter_pred', 2, 0, num_draws)

    make_movie(
        filepaths.figure_parameter_cross_section + '_%d' % (sample_number),
        filepaths.directory_movie, 'parameter_cross_section', 2, 0, num_draws)

    combine_movies(filepaths.directory_movie + '/parameter_pred',
                   filepaths.directory_movie + '/parameter_cross_section',
                   filepaths.directory_movie,
                   'parameter_pred_and_parameter_cross_section')
def predict_and_plot(hyperp, options, filepaths):
    """Plot the true and predicted parameter (and state) for one test sample.

    Loads the test set through ``DataHandler`` (adding noise when
    ``options.add_noise`` is set), restores the trained VAE from
    ``filepaths.trained_nn`` and plots test sample 105 on the mesh returned
    by ``load_mesh``. State figures are produced only when
    ``options.obs_type == 'full'``.

    Raises:
        ValueError: if ``options.obs_type`` is neither ``'full'`` nor ``'obs'``.
    """
    #=== Observation Dimensions ===#
    # This DataHandler takes no observation indices, so the index CSV is not
    # read here (it was previously loaded and never used).
    if options.obs_type == 'full':
        obs_dimensions = options.parameter_dimensions
    elif options.obs_type == 'obs':
        obs_dimensions = options.num_obs_points
    else:
        raise ValueError("options.obs_type must be 'full' or 'obs', got %r"
                         % (options.obs_type,))

    #=== Data and Latent Dimensions of Autoencoder ===#
    input_dimensions = obs_dimensions
    latent_dimensions = options.parameter_dimensions

    #=== Prepare Data ===#
    data = DataHandler(hyperp, options, filepaths,
                       options.parameter_dimensions, obs_dimensions)
    data.load_data_test()
    if options.add_noise:
        data.add_noise_qoi_test()
    parameter_test = data.poi_test
    state_obs_test = data.qoi_test

    #=== Load Trained Neural Network ===#
    nn = VAE(hyperp, options,
             input_dimensions, latent_dimensions,
             None, None,
             positivity_constraint_log_exp)
    nn.load_weights(filepaths.trained_nn)

    #=== Selecting Samples ===#
    sample_number = 105
    parameter_test_sample = np.expand_dims(parameter_test[sample_number, :], 0)
    state_obs_test_sample = np.expand_dims(state_obs_test[sample_number, :], 0)

    #=== Predictions ===#
    # The encoder's second output (presumably the posterior covariance
    # term) is discarded; only the first is plotted.
    parameter_pred_sample, _ = nn.encoder(state_obs_test_sample)
    # NOTE(review): the decoder is fed the *true* test parameter, not the
    # prediction above, so "State Prediction" shows the decoder's output on
    # the ground truth. Confirm this is intended.
    state_obs_pred_sample = nn.decoder(parameter_test_sample)
    parameter_pred_sample = parameter_pred_sample.numpy().flatten()
    state_obs_pred_sample = state_obs_pred_sample.numpy().flatten()

    #=== Plotting Prediction ===#
    print('================================')
    print('      Plotting Predictions      ')
    print('================================')
    #=== Load Mesh ===#
    nodes, elements, _, _, _, _, _, _ = load_mesh(filepaths.project)

    #=== Plot FEM Functions ===#
    plot_fem_function(filepaths.figure_parameter_test,
                      'True Parameter', 7.0,
                      nodes, elements,
                      parameter_test_sample)
    plot_fem_function(filepaths.figure_parameter_pred,
                      'Parameter Prediction', 7.0,
                      nodes, elements,
                      parameter_pred_sample)
    if options.obs_type == 'full':
        plot_fem_function(filepaths.figure_state_test,
                          'True State', 2.6,
                          nodes, elements,
                          state_obs_test_sample)
        plot_fem_function(filepaths.figure_state_pred,
                          'State Prediction', 2.6,
                          nodes, elements,
                          state_obs_pred_sample)

    print('Predictions plotted')
def predict_and_plot(hyperp, options, filepaths):
    """Plot posterior predictions of the trained VAE for one test sample.

    Overwrites the mesh settings on ``options`` (two rectangular holes,
    discretization 17), builds the function space ``Vh``, loads the test set
    (adding noise when ``options.add_noise`` is set), restores the VAE from
    ``filepaths.trained_nn`` and plots test sample 15: the true parameter,
    posterior mean, one posterior draw, a cross-section at y = 0.8 with error
    bounds and, for ``obs_type == 'full'``, the true/predicted state.
    """
    #=== Mesh Properties ===#
    options.hole_single_circle = False
    options.hole_two_rectangles = True
    options.discretization_domain = 17
    options.domain_length = 1
    options.domain_width = 1
    options.rect_1_point_1 = [0.25, 0.15]
    options.rect_1_point_2 = [0.5, 0.4]
    options.rect_2_point_1 = [0.6, 0.6]
    options.rect_2_point_2 = [0.75, 0.85]

    #=== Construct Mesh ===#
    # Only the function space is used for plotting.
    Vh, _, _ = construct_mesh(options)

    #=== Observation Dimensions ===#
    # This DataHandler takes no observation indices, so the index CSV is not
    # read here (it was previously loaded and never used).
    obs_dimensions = options.num_obs_points * options.num_time_steps

    #=== Data and Latent Dimensions of Autoencoder ===#
    input_dimensions = obs_dimensions
    latent_dimensions = options.parameter_dimensions

    #=== Prepare Data ===#
    data = DataHandler(hyperp, options, filepaths,
                       options.parameter_dimensions, obs_dimensions)
    # Alternative data source (disabled). sample_number below would likely
    # need to change too before enabling it:
    # data.load_data_specific()
    # if options.add_noise == 1:
    #     data.add_noise_qoi_specific()
    # parameter_test = data.poi_specific
    # state_obs_test = data.qoi_specific

    data.load_data_test()
    if options.add_noise:
        data.add_noise_qoi_test()
    parameter_test = data.poi_test
    state_obs_test = data.qoi_test

    #=== Load Trained Neural Network ===#
    nn = VAE(hyperp, options, input_dimensions, latent_dimensions, None, None,
             positivity_constraint_log_exp)
    nn.load_weights(filepaths.trained_nn)

    #=== Selecting Samples ===#
    sample_number = 15
    parameter_test_sample = np.expand_dims(parameter_test[sample_number, :], 0)
    state_obs_test_sample = np.expand_dims(state_obs_test[sample_number, :], 0)

    #=== Predictions ===#
    posterior_mean_pred, posterior_cov_pred = nn.encoder(state_obs_test_sample)
    posterior_pred_draw = nn.reparameterize(posterior_mean_pred,
                                            posterior_cov_pred)

    posterior_mean_pred = posterior_mean_pred.numpy().flatten()
    posterior_cov_pred = posterior_cov_pred.numpy().flatten()
    posterior_pred_draw = posterior_pred_draw.numpy().flatten()

    # The decoder is only evaluated when options.model_aware == 1. The None
    # sentinel prevents the NameError the state-prediction plot used to hit
    # for obs_type == 'full' with a non-model-aware network.
    state_obs_pred_draw = None
    if options.model_aware == 1:
        state_obs_pred_draw = nn.decoder(np.expand_dims(
            posterior_pred_draw, 0))
        state_obs_pred_draw = state_obs_pred_draw.numpy().flatten()

    #=== Plotting Prediction ===#
    print('================================')
    print('      Plotting Predictions      ')
    print('================================')

    #=== Plot FEM Functions ===#
    cross_section_y = 0.8
    filename_extension = '_%d.png' % (sample_number)
    plot_fem_function_fenics_2d(
        Vh, parameter_test_sample, cross_section_y, '',
        filepaths.figure_parameter_test + filename_extension, (5, 5), (0, 5),
        False)
    plot_fem_function_fenics_2d(
        Vh, posterior_mean_pred, cross_section_y, '',
        filepaths.figure_posterior_mean + filename_extension, (5, 5), (0, 5),
        True)
    plot_fem_function_fenics_2d(
        Vh, posterior_pred_draw, cross_section_y, '',
        filepaths.figure_parameter_pred + filename_extension, (5, 5), (0, 5),
        True)
    if options.obs_type == 'full':
        plot_fem_function_fenics_2d(
            Vh, state_obs_test_sample, cross_section_y, 'True State',
            filepaths.figure_state_test + filename_extension, (5, 5))
        if state_obs_pred_draw is not None:
            plot_fem_function_fenics_2d(
                Vh, state_obs_pred_draw, cross_section_y, 'State Prediction',
                filepaths.figure_state_pred + filename_extension, (5, 5))

    #=== Plot Cross-Section with Error Bounds ===#
    plot_cross_section(
        Vh, parameter_test_sample, posterior_mean_pred, posterior_cov_pred,
        (0, 1), cross_section_y, '',
        filepaths.figure_parameter_cross_section + filename_extension, (0, 5))

    print('Predictions plotted')