def train_model_energy_gradient(i=0, out_dir=None, mode='training'):
    """Train an energy plus gradient model. Uses precomputed feature and model representation.

    Configuration is read from ``out_dir``: ``<mode>_config.json``, ``model_config.json``,
    ``scaler_config.json`` and the split indices ``train_index.npy`` / ``test_index.npy``.
    Training data (``geometries.xyz``, ``energies.json``, ``forces.json``) is read from the
    parent directory of ``out_dir``. History and plots are written to ``out_dir/fit_stats``;
    scaler weights, model weights, the saved model and ``fit_error.json`` go to ``out_dir``.

    Args:
        i (int, optional): Model index, used in plot file names. The default is 0.
        out_dir (str, optional): Directory for fit output. Must be given; the default is None.
        mode (str, optional): Fit-mode to take from hyperparameters. The default is 'training'.

    Raises:
        ValueError: Wrong input shape (geometries do not match the model's atom count).

    Returns:
        error_val (list): Validation MAE for (energy, gradient), computed on
            inverse-transformed (unscaled) predictions against the raw data.
    """
    i = int(i)
    # Load everything from folder
    training_config = load_json_file(
        os.path.join(out_dir, mode + "_config.json"))
    model_config = load_json_file(os.path.join(out_dir, "model_config.json"))
    i_train = np.load(os.path.join(out_dir, "train_index.npy"))
    i_val = np.load(os.path.join(out_dir, "test_index.npy"))
    scaler_config = load_json_file(os.path.join(out_dir, "scaler_config.json"))

    # Info from Config
    num_atoms = int(model_config["config"]["atoms"])
    unit_label_energy = training_config['unit_energy']
    unit_label_grad = training_config['unit_gradient']
    energies_only = model_config["config"]['energy_only']
    epo = training_config['epo']
    batch_size = training_config['batch_size']
    epostep = training_config['epostep']
    initialize_weights = training_config['initialize_weights']
    learning_rate = training_config['learning_rate']
    loss_weights = training_config['loss_weights']
    use_callbacks = list(training_config["callbacks"])

    # Load data.
    data_dir = os.path.dirname(out_dir)
    xyz = read_xyz_file(os.path.join(data_dir, "geometries.xyz"))
    x = np.array([geo[1] for geo in xyz])
    # ndim guard: ragged geometries would give a 1D object array and x.shape[1]
    # would raise IndexError instead of the documented ValueError.
    if x.ndim < 2 or x.shape[1] != num_atoms:
        raise ValueError(
            f"Mismatch shape between model ({num_atoms} atoms) and data {x.shape}")
    y1 = np.array(load_json_file(os.path.join(data_dir, "energies.json")))
    y2 = np.array(load_json_file(os.path.join(data_dir, "forces.json")))
    print("INFO: Shape of y", y1.shape, y2.shape)
    y = [y1, y2]

    # Fit stats dir
    dir_save = os.path.join(out_dir, "fit_stats")
    os.makedirs(dir_save, exist_ok=True)

    # Callbacks (e.g. learning rate schedule) from serialized dict configs.
    # A comprehension with its own loop name keeps the geometry array ``x``
    # intact; a plain ``for x in ...`` loop would overwrite it.
    cbks = [
        tf.keras.utils.deserialize_keras_object(cb_config)
        for cb_config in use_callbacks if isinstance(cb_config, dict)
    ]

    # Index train test split
    print("Info: Train-Test split at Train:", len(i_train), "Test", len(i_val),
          "Total", len(x))

    # Make all Model
    assert model_config[
        "class_name"] == "EnergyGradientModel", "Training script only for EnergyGradientModel"
    out_model = EnergyGradientModel(**model_config["config"])
    out_model.precomputed_features = True
    out_model.output_as_dict = True
    out_model.energy_only = energies_only

    # Look for loading weights
    if not initialize_weights:
        out_model.load_weights(os.path.join(out_dir, "model_weights.h5"))
        print("Info: Load old weights at:",
              os.path.join(out_dir, "model_weights.h5"))
        print("Info: Transferring weights...")
    else:
        print("Info: Making new initialized weights.")

    # Scale x,y; the scaler is fitted on the training split only.
    scaler = EnergyGradientStandardScaler(**scaler_config["config"])
    scaler.fit(x[i_train], [y[0][i_train], y[1][i_train]])
    x_rescale, y_rescale = scaler.transform(x, y)
    y1, y2 = y_rescale

    # Model + Model precompute layer +feat
    feat_x, feat_grad = out_model.precompute_feature_in_chunks(
        x_rescale, batch_size=batch_size)

    # Train Test split
    xtrain = [feat_x[i_train], feat_grad[i_train]]
    ytrain = [y1[i_train], y2[i_train]]
    xval = [feat_x[i_val], feat_grad[i_val]]
    yval = [y1[i_val], y2[i_val]]

    # Optimizer and metrics; the MAE metrics are rescaled by the scaler std so
    # they report errors in unscaled units.
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    lr_metric = get_lr_metric(optimizer)
    mae_energy = ScaledMeanAbsoluteError(scaling_shape=scaler.energy_std.shape)
    mae_force = ScaledMeanAbsoluteError(
        scaling_shape=scaler.gradient_std.shape)
    mae_energy.set_scale(scaler.energy_std)
    mae_force.set_scale(scaler.gradient_std)
    if energies_only:
        train_loss = {'energy': 'mean_squared_error', 'force': ZeroEmptyLoss()}
    else:
        train_loss = {
            'energy': 'mean_squared_error',
            'force': 'mean_squared_error'
        }
    out_model.compile(optimizer=optimizer,
                      loss=train_loss,
                      loss_weights=loss_weights,
                      metrics={
                          'energy': [mae_energy, lr_metric, r2_metric],
                          'force': [mae_force, lr_metric, r2_metric]
                      })

    scaler.print_params_info()

    print("")
    print("Start fit.")
    out_model.summary()
    hist = out_model.fit(x=xtrain,
                         y={
                             'energy': ytrain[0],
                             'force': ytrain[1]
                         },
                         epochs=epo,
                         batch_size=batch_size,
                         callbacks=cbks,
                         validation_freq=epostep,
                         validation_data=(xval, {
                             'energy': yval[0],
                             'force': yval[1]
                         }),
                         verbose=2)
    print("End fit.")
    print("")
    # Predict gradients from here on, regardless of the training mode.
    out_model.energy_only = False

    outname = os.path.join(dir_save, "history.json")
    outhist = {
        a: np.array(b, dtype=np.float64).tolist()
        for a, b in hist.history.items()
    }
    with open(outname, 'w') as f:
        json.dump(outhist, f)

    print("Info: Saving auto-scaler to file...")
    scaler.save_weights(os.path.join(out_dir, "scaler_weights.npy"))

    # Plot and Save
    yval_plot = [y[0][i_val], y[1][i_val]]
    ytrain_plot = [y[0][i_train], y[1][i_train]]
    # Convert back scaler. These unscaled predictions are reused for the error
    # summary below, so each split is only evaluated once.
    pval = out_model.predict(xval)
    ptrain = out_model.predict(xtrain)
    _, pval = scaler.inverse_transform(y=[pval['energy'], pval['force']])
    _, ptrain = scaler.inverse_transform(y=[ptrain['energy'], ptrain['force']])

    print("Info: Predicted Energy shape:", ptrain[0].shape)
    print("Info: Predicted Gradient shape:", ptrain[1].shape)
    print("Info: Plot fit stats...")

    # Plot
    plot_loss_curves([
        hist.history['energy_mean_absolute_error'],
        hist.history['force_mean_absolute_error']
    ], [
        hist.history['val_energy_mean_absolute_error'],
        hist.history['val_force_mean_absolute_error']
    ],
                     label_curves=["energy", "force"],
                     val_step=epostep,
                     save_plot_to_file=True,
                     dir_save=dir_save,
                     filename='fit' + str(i),
                     filetypeout='.png',
                     unit_loss=unit_label_energy,
                     loss_name="MAE",
                     plot_title="Energy")

    plot_learning_curve(hist.history['energy_lr'],
                        filename='fit' + str(i),
                        dir_save=dir_save)

    plot_scatter_prediction(pval[0],
                            yval_plot[0],
                            save_plot_to_file=True,
                            dir_save=dir_save,
                            filename='fit' + str(i) + "_energy",
                            filetypeout='.png',
                            unit_actual=unit_label_energy,
                            unit_predicted=unit_label_energy,
                            plot_title="Prediction Energy")

    plot_scatter_prediction(pval[1],
                            yval_plot[1],
                            save_plot_to_file=True,
                            dir_save=dir_save,
                            filename='fit' + str(i) + "_grad",
                            filetypeout='.png',
                            unit_actual=unit_label_grad,
                            unit_predicted=unit_label_grad,
                            plot_title="Prediction Gradient")

    plot_error_vec_mean(
        [pval[1], ptrain[1]], [yval_plot[1], ytrain_plot[1]],
        label_curves=["Validation gradients", "Training Gradients"],
        unit_predicted=unit_label_grad,
        filename='fit' + str(i) + "_grad",
        dir_save=dir_save,
        save_plot_to_file=True,
        filetypeout='.png',
        x_label='Gradients xyz * #atoms * #states ',
        plot_title="Gradient mean error")

    plot_error_vec_max([pval[1], ptrain[1]], [yval_plot[1], ytrain_plot[1]],
                       label_curves=["Validation", "Training"],
                       unit_predicted=unit_label_grad,
                       filename='fit' + str(i) + "_grad",
                       dir_save=dir_save,
                       save_plot_to_file=True,
                       filetypeout='.png',
                       x_label='Gradients xyz * #atoms * #states ',
                       plot_title="Gradient max error")

    # Cross-check: run the full model (features computed inside the model) on
    # the rescaled training geometries and compare with the precomputed path.
    out_model.precomputed_features = False
    out_model.output_as_dict = False
    ptrain2 = out_model.predict(x_rescale[i_train])
    _, ptrain2 = scaler.inverse_transform(y=[ptrain2[0], ptrain2[1]])
    print("Info: Max error precomputed and full gradient computation:")
    print("Energy", np.max(np.abs(ptrain[0] - ptrain2[0])))
    print("Gradient", np.max(np.abs(ptrain[1] - ptrain2[1])))
    error_val = [
        np.mean(np.abs(pval[0] - y[0][i_val])),
        np.mean(np.abs(pval[1] - y[1][i_val]))
    ]
    error_train = [
        np.mean(np.abs(ptrain[0] - y[0][i_train])),
        np.mean(np.abs(ptrain[1] - y[1][i_train]))
    ]
    print("error_val:", error_val)
    print("error_train:", error_train)
    error_dict = {
        "train": [error_train[0].tolist(), error_train[1].tolist()],
        "valid": [error_val[0].tolist(), error_val[1].tolist()]
    }
    save_json_file(error_dict, os.path.join(out_dir, "fit_error.json"))

    print("Info: Saving model to file...")
    # Save the model in its full (non-precomputed) form.
    out_model.precomputed_features = False
    out_model.save_weights(os.path.join(out_dir, "model_weights.h5"))
    out_model.save(os.path.join(out_dir, "model_tf"))

    return error_val
def train_model_energy(i=0, out_dir=None, mode='training'):
    r"""Train an energy model (SchnetEnergy). Always requires a scaler.

    Configuration is read from ``out_dir``: ``<mode>_config.json``, ``model_config.json``,
    ``scaler_config.json`` and the split indices ``train_index.npy`` / ``test_index.npy``.
    Training data (``geometries.xyz``, ``energies.json``) is read from the parent directory
    of ``out_dir``. Atom numbers, coordinates and distance-based neighbour indices are
    built per geometry and fed to the model as ragged tensors. History and plots are
    written to ``out_dir/fit_stats``; scaler weights, model weights, the saved model and
    ``fit_error.json`` go to ``out_dir``.

    Args:
        i (int, optional): Model index, used in plot file names. The default is 0.
        out_dir (str, optional): Directory for this training. Must be given; the default is None.
        mode (str, optional): Fit-mode to take from hyper-parameters. The default is 'training'.

    Raises:
        AssertionError: If the model config is not a ``SchnetEnergy`` model.

    Returns:
        error_val (float): Validation MAE of the energy, computed on
            inverse-transformed (unscaled) predictions against the raw data.
    """
    i = int(i)

    # Load everything from folder
    training_config = load_json_file(
        os.path.join(out_dir, mode + "_config.json"))
    model_config = load_json_file(os.path.join(out_dir, "model_config.json"))
    i_train = np.load(os.path.join(out_dir, "train_index.npy"))
    i_val = np.load(os.path.join(out_dir, "test_index.npy"))
    scaler_config = load_json_file(os.path.join(out_dir, "scaler_config.json"))

    # training parameters
    unit_label_energy = training_config['unit_energy']
    epo = training_config['epo']
    batch_size = training_config['batch_size']
    epostep = training_config['epostep']
    initialize_weights = training_config['initialize_weights']
    learning_rate = training_config['learning_rate']
    use_callbacks = training_config['callbacks']
    range_dist = model_config["config"]["schnet_kwargs"]["gauss_args"][
        "distance"]

    # Load data.
    data_dir = os.path.dirname(out_dir)
    xyz = read_xyz_file(os.path.join(data_dir, "geometries.xyz"))
    coords = [np.array(x[1]) for x in xyz]
    atoms = [np.array([global_proton_dict[at] for at in x[0]]) for x in xyz]
    # Neighbour indices within the model's cutoff distance, per geometry.
    range_indices = [
        define_adjacency_from_distance(coordinates_to_distancematrix(x),
                                       max_distance=range_dist)[1]
        for x in coords
    ]
    y = load_json_file(os.path.join(data_dir, "energies.json"))
    y = np.array(y)

    # Fit stats dir
    dir_save = os.path.join(out_dir, "fit_stats")
    os.makedirs(dir_save, exist_ok=True)

    # Callbacks (e.g. learning rate schedule) from serialized dict configs.
    cbks = [
        tf.keras.utils.deserialize_keras_object(cb_config)
        for cb_config in use_callbacks if isinstance(cb_config, dict)
    ]

    # Make Model
    # Only works for Energy model here
    assert model_config[
        "class_name"] == "SchnetEnergy", "Training script only for SchnetEnergy"
    out_model = SchnetEnergy(**model_config["config"])

    # Look for loading weights
    if not initialize_weights:
        out_model.load_weights(os.path.join(out_dir, "model_weights.h5"))
        print("Info: Load old weights at:",
              os.path.join(out_dir, "model_weights.h5"))
    else:
        print("Info: Making new initialized weights.")

    # Recalculate standardization on the training split only.
    scaler = EnergyStandardScaler(**scaler_config["config"])
    scaler.fit(x=None, y=y[i_train])
    _, y1 = scaler.transform(x=None, y=y)

    # Train Test split
    xtrain = [
        ragged_tensor_from_nested_numpy([atoms[i] for i in i_train]),
        ragged_tensor_from_nested_numpy([coords[i] for i in i_train]),
        ragged_tensor_from_nested_numpy([range_indices[i] for i in i_train])
    ]
    xval = [
        ragged_tensor_from_nested_numpy([atoms[i] for i in i_val]),
        ragged_tensor_from_nested_numpy([coords[i] for i in i_val]),
        ragged_tensor_from_nested_numpy([range_indices[i] for i in i_val])
    ]
    ytrain = y1[i_train]
    yval = y1[i_val]

    # Compile model
    # The scaled metric reports MAE with the std scaling undone.
    scaled_metric = ScaledMeanAbsoluteError(
        scaling_shape=scaler.energy_std.shape)
    scaled_metric.set_scale(scaler.energy_std)
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    lr_metric = get_lr_metric(optimizer)
    out_model.compile(optimizer=optimizer,
                      loss='mean_squared_error',
                      metrics=[scaled_metric, lr_metric, r2_metric])

    scaler.print_params_info()

    out_model.summary()
    print("")
    print("Start fit.")
    hist = out_model.fit(x=xtrain,
                         y=ytrain,
                         epochs=epo,
                         batch_size=batch_size,
                         callbacks=cbks,
                         validation_freq=epostep,
                         validation_data=(xval, yval),
                         verbose=2)
    print("End fit.")
    print("")

    outname = os.path.join(dir_save, "history.json")
    outhist = {
        a: np.array(b, dtype=np.float64).tolist()
        for a, b in hist.history.items()
    }
    with open(outname, 'w') as f:
        json.dump(outhist, f)

    print("Info: Saving auto-scaler to file...")
    scaler.save_weights(os.path.join(out_dir, "scaler_weights.npy"))

    # Plot and Save
    yval_plot = y[i_val]
    ytrain_plot = y[i_train]
    # Convert back scaler. These unscaled predictions are reused for the error
    # summary below, so each split is only evaluated once.
    pval = out_model.predict(xval)
    ptrain = out_model.predict(xtrain)
    _, pval = scaler.inverse_transform(y=pval)
    _, ptrain = scaler.inverse_transform(y=ptrain)

    print("Info: Predicted Energy shape:", ptrain.shape)
    print("Info: Plot fit stats...")

    # Plot
    plot_loss_curves(hist.history['mean_absolute_error'],
                     hist.history['val_mean_absolute_error'],
                     val_step=epostep,
                     save_plot_to_file=True,
                     dir_save=dir_save,
                     filename='fit' + str(i),
                     filetypeout='.png',
                     unit_loss=unit_label_energy,
                     loss_name="MAE",
                     plot_title="Energy")

    plot_scatter_prediction(pval,
                            yval_plot,
                            save_plot_to_file=True,
                            dir_save=dir_save,
                            filename='fit' + str(i),
                            filetypeout='.png',
                            unit_actual=unit_label_energy,
                            unit_predicted=unit_label_energy,
                            plot_title="Prediction")

    plot_learning_curve(hist.history['lr'],
                        filename='fit' + str(i),
                        dir_save=dir_save)

    # Save fitting Error MAE
    error_val = np.mean(np.abs(pval - yval_plot))
    error_train = np.mean(np.abs(ptrain - ytrain_plot))
    print("error_val:", error_val)
    print("error_train:", error_train)
    error_dict = {"train": error_train.tolist(), "valid": error_val.tolist()}
    save_json_file(error_dict, os.path.join(out_dir, "fit_error.json"))

    print("Info: Saving model to file...")
    out_model.save_weights(os.path.join(out_dir, "model_weights.h5"))
    out_model.save(os.path.join(out_dir, "model_tf"))

    return error_val
def train_model_nac(i=0, out_dir=None, mode='training'):
    """Train NAC model. Uses precomputed feature and model representation.

    Configuration is read from ``out_dir``: ``<mode>_config.json``, ``model_config.json``,
    ``scaler_config.json`` and the split indices ``train_index.npy`` / ``test_index.npy``.
    Training data (``geometries.xyz``, ``couplings.npy``) is read from the parent directory
    of ``out_dir``. Optionally runs ``pre_epo`` epochs with MSE loss before switching to the
    phaseless loss (if ``phase_less_loss`` is set). History and plots are written to
    ``out_dir/fit_stats``; scaler weights, model weights, the saved model and
    ``fit_error.json`` go to ``out_dir``.

    Args:
        i (int, optional): Model index, used in plot file names. The default is 0.
        out_dir (str, optional): Directory for fit output. Must be given; the default is None.
        mode (str, optional): Fit-mode to take from hyperparameters. The default is 'training'.

    Raises:
        ValueError: Wrong input shape (geometries do not match the model's atom count).

    Returns:
        error_val (float): Validation MAE for NAC, computed on inverse-transformed
            (unscaled) predictions against the raw couplings.
    """
    i = int(i)
    # Load everything from folder
    training_config = load_json_file(os.path.join(out_dir, mode + "_config.json"))
    model_config = load_json_file(os.path.join(out_dir, "model_config.json"))
    i_train = np.load(os.path.join(out_dir, "train_index.npy"))
    i_val = np.load(os.path.join(out_dir, "test_index.npy"))
    scaler_config = load_json_file(os.path.join(out_dir, "scaler_config.json"))

    # Model
    num_outstates = int(model_config["config"]['states'])
    num_atoms = int(model_config["config"]['atoms'])
    unit_label_nac = training_config['unit_nac']
    phase_less_loss = training_config['phase_less_loss']
    epo = training_config['epo']
    batch_size = training_config['batch_size']
    epostep = training_config['epostep']
    pre_epo = training_config['pre_epo']
    initialize_weights = training_config['initialize_weights']
    learning_rate = training_config['learning_rate']
    use_callbacks = list(training_config["callbacks"])

    # Data Check here:
    data_dir = os.path.dirname(out_dir)
    xyz = read_xyz_file(os.path.join(data_dir, "geometries.xyz"))
    x = np.array([geo[1] for geo in xyz])
    # ndim guard: ragged geometries would give a 1D object array and x.shape[1]
    # would raise IndexError instead of the documented ValueError.
    if x.ndim < 2 or x.shape[1] != num_atoms:
        raise ValueError(f"Mismatch shape between model ({num_atoms} atoms) and data {x.shape}")
    y_in = np.load(os.path.join(data_dir, "couplings.npy"))
    print("INFO: Shape of y", y_in.shape)

    # Set stat dir
    dir_save = os.path.join(out_dir, "fit_stats")
    os.makedirs(dir_save, exist_ok=True)

    # Callbacks (e.g. learning rate schedule) from serialized dict configs.
    # A comprehension with its own loop name keeps the geometry array ``x``
    # intact; a plain ``for x in ...`` loop would overwrite it.
    cbks = [tf.keras.utils.deserialize_keras_object(cb_config)
            for cb_config in use_callbacks if isinstance(cb_config, dict)]

    # Make all Models
    assert model_config["class_name"] == "NACModel2", "Training script only for NACModel2"
    out_model = NACModel2(**model_config["config"])
    out_model.precomputed_features = True

    if not initialize_weights:
        out_model.load_weights(os.path.join(out_dir, "model_weights.h5"))
        print("Info: Load old weights at:", os.path.join(out_dir, "model_weights.h5"))
        print("Info: Transferring weights...")
    else:
        print("Info: Making new initialized weights..")

    # Scaler is fitted on the training split only.
    scaler = NACStandardScaler(**scaler_config["config"])
    scaler.fit(x[i_train], y_in[i_train])
    x_rescale, y = scaler.transform(x=x, y=y_in)

    # Calculate features
    feat_x, feat_grad = out_model.precompute_feature_in_chunks(x_rescale, batch_size=batch_size)

    xtrain = [feat_x[i_train], feat_grad[i_train]]
    ytrain = y[i_train]
    xval = [feat_x[i_val], feat_grad[i_val]]
    yval = y[i_val]

    # Set Scaling: the metric reports MAE with the std scaling undone.
    scaled_metric = ScaledMeanAbsoluteError(scaling_shape=scaler.nac_std.shape)
    scaled_metric.set_scale(scaler.nac_std)
    scaler.print_params_info()
    print("")

    # Compile model
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    lr_metric = get_lr_metric(optimizer)
    out_model.compile(loss='mean_squared_error',
                      optimizer=optimizer,
                      metrics=[scaled_metric, lr_metric, r2_metric])

    # Pre -fit
    print("")
    print("Start fit.")
    if pre_epo > 0:
        print("Start Pre-fit without phaseless-loss.")
        print("Used loss:", out_model.loss)
        out_model.summary()
        out_model.fit(x=xtrain, y=ytrain, epochs=pre_epo, batch_size=batch_size, validation_freq=epostep,
                      validation_data=(xval, yval), verbose=2)
        print("End fit.")
        print("")

    print("Start fit.")
    if phase_less_loss:
        print("Recompiling with phaseless loss.")
        out_model.compile(
            loss=NACphaselessLoss(number_state=num_outstates, shape_nac=(num_atoms, 3), name='phaseless_loss'),
            optimizer=optimizer,
            metrics=[scaled_metric, lr_metric, r2_metric])
        print("Used loss:", out_model.loss)

    out_model.summary()
    hist = out_model.fit(x=xtrain, y=ytrain, epochs=epo, batch_size=batch_size, callbacks=cbks, validation_freq=epostep,
                         validation_data=(xval, yval), verbose=2)
    print("End fit.")
    print("")

    print("Info: Saving history...")
    outname = os.path.join(dir_save, "history.json")
    outhist = {a: np.array(b, dtype=np.float64).tolist() for a, b in hist.history.items()}
    with open(outname, 'w') as f:
        json.dump(outhist, f)

    print("Info: Saving auto-scaler to file...")
    scaler.save_weights(os.path.join(out_dir, "scaler_weights.npy"))

    # Plot stats
    yval_plot = y_in[i_val]
    ytrain_plot = y_in[i_train]
    # Revert standardization. These unscaled predictions are reused for the
    # error summary below, so each split is only evaluated once.
    pval = out_model.predict(xval)
    ptrain = out_model.predict(xtrain)
    _, pval = scaler.inverse_transform(y=pval)
    _, ptrain = scaler.inverse_transform(y=ptrain)

    print("Info: Predicted NAC shape:", ptrain.shape)
    print("Info: Plot fit stats...")

    plot_loss_curves(hist.history['mean_absolute_error'],
                     hist.history['val_mean_absolute_error'],
                     label_curves="NAC",
                     val_step=epostep, save_plot_to_file=True, dir_save=dir_save,
                     filename='fit' + str(i) + "_nac", filetypeout='.png', unit_loss=unit_label_nac,
                     loss_name="MAE",
                     plot_title="NAC")

    plot_learning_curve(hist.history['lr'], filename='fit' + str(i), dir_save=dir_save)

    plot_scatter_prediction(pval, yval_plot, save_plot_to_file=True, dir_save=dir_save,
                            filename='fit' + str(i) + "_nac",
                            filetypeout='.png', unit_actual=unit_label_nac, unit_predicted=unit_label_nac,
                            plot_title="Prediction NAC")

    plot_error_vec_mean([pval, ptrain], [yval_plot, ytrain_plot],
                        label_curves=["Validation NAC", "Training NAC"], unit_predicted=unit_label_nac,
                        filename='fit' + str(i) + "_nac", dir_save=dir_save, save_plot_to_file=True,
                        filetypeout='.png', x_label='NACs xyz * #atoms * #states ',
                        plot_title="NAC mean error")

    plot_error_vec_max([pval, ptrain], [yval_plot, ytrain_plot],
                       label_curves=["Validation", "Training"],
                       unit_predicted=unit_label_nac, filename='fit' + str(i) + "_nac",
                       dir_save=dir_save, save_plot_to_file=True, filetypeout='.png',
                       x_label='NACs xyz * #atoms * #states ', plot_title="NAC max error")

    print("Info: saving fitting error...")
    # Cross-check: run the full model (features computed inside the model) on
    # the rescaled training geometries. Undo the scaling with the same
    # inverse_transform used for ``ptrain`` so both paths are compared alike.
    out_model.precomputed_features = False
    ptrain2 = out_model.predict(x_rescale[i_train])
    _, ptrain2 = scaler.inverse_transform(y=ptrain2)
    print("Info: MAE between precomputed and full keras model:")
    print("NAC", np.mean(np.abs(ptrain - ptrain2)))
    # Save fitting Error MAE
    error_val = np.mean(np.abs(pval - yval_plot))
    error_train = np.mean(np.abs(ptrain - ytrain_plot))
    print("error_val:", error_val)
    print("error_train:", error_train)
    error_dict = {"train": error_train.tolist(), "valid": error_val.tolist()}
    save_json_file(error_dict, os.path.join(out_dir, "fit_error.json"))

    # Save Weights
    print("Info: Saving weights...")
    # Save the model in its full (non-precomputed) form.
    out_model.precomputed_features = False
    out_model.save_weights(os.path.join(out_dir, "model_weights.h5"))
    out_model.save(os.path.join(out_dir, "model_tf"))

    return error_val