def _black_box_predict(bb_model, x, problem):
    """Returns the black-box outputs that RL-LIM is asked to mimic.

    Regression models contribute their point predictions. Classification
    models contribute the predicted probability of class 1 (column 1 of
    `predict_proba`). This keeps the training, probe and test targets on the
    same scale.

    Args:
      bb_model: fitted LightGBM regressor or classifier.
      x: feature matrix to predict on.
      problem: 'regression' or 'classification'.

    Returns:
      Black-box outputs, one per row of `x`.
    """
    if problem == 'classification':
        # NOTE(review): assumes a binary classifier (column 1 = positive class).
        return bb_model.predict_proba(x)[:, 1]
    return bb_model.predict(x)


def main(args):
    """Main function of RL-LIM for Facebook data experiments.

    Loads and preprocesses the Facebook dataset and trains a LightGBM
    black-box model. It then trains RL-LIM (with Ridge local models) to
    reproduce the black-box outputs, and reports overall performance and
    fidelity on the test set.

    Args:
      args: parsed arguments with fields problem ('regression' or
        'classification'), train_rate, probe_rate, seed, normalization,
        hidden_dim, iterations, num_layers, batch_size, batch_size_inner,
        hyper_lambda, checkpoint_file_name and n_exp.

    Returns:
      pandas DataFrame of local explanations (an 'intercept' column followed
      by one coefficient column per feature) for the first `n_exp` test
      samples. It is also printed.

    Raises:
      ValueError: if `args.problem` is neither 'regression' nor
        'classification'.
    """

    # Problem specification -- validate before doing any expensive work,
    # otherwise an unknown value would only surface later as an unbound
    # `bb_model`.
    problem = args.problem
    if problem not in ('regression', 'classification'):
        raise ValueError(
            "problem must be 'regression' or 'classification', got %r" %
            (problem,))

    # The ratio between training and probe datasets
    dict_rate = {'train': args.train_rate, 'probe': args.probe_rate}

    # Random seed
    seed = args.seed

    # Network parameters
    parameters = dict()
    parameters['hidden_dim'] = args.hidden_dim
    parameters['iterations'] = args.iterations
    parameters['num_layers'] = args.num_layers
    parameters['batch_size'] = args.batch_size
    parameters['batch_size_inner'] = args.batch_size_inner
    parameters['lambda'] = args.hyper_lambda

    # Checkpoint file name
    checkpoint_file_name = args.checkpoint_file_name

    # Number of sample explanations
    n_exp = args.n_exp

    # Loads data (the loader is called for its side effect; the preprocessing
    # step below reads 'train.csv', 'probe.csv' and 'test.csv').
    data_loading.load_facebook_data(dict_rate, seed)

    print('Finished data loading.')

    # Preprocesses data
    # Normalization methods: either 'minmax' or 'standard'
    normalization = args.normalization

    # Extracts features and labels & normalizes features
    x_train, y_train, x_probe, _, x_test, y_test, col_names = \
        data_loading.preprocess_data(normalization,
                                     'train.csv', 'probe.csv', 'test.csv')

    print('Finished data preprocess.')

    # Initializes and trains the black-box model
    if problem == 'regression':
        bb_model = lightgbm.LGBMRegressor()
    else:
        bb_model = lightgbm.LGBMClassifier()

    bb_model = bb_model.fit(x_train, y_train)

    print('Finished black-box model training.')

    # Constructs auxiliary datasets: black-box outputs become RL-LIM targets
    y_train_hat = _black_box_predict(bb_model, x_train, problem)
    y_probe_hat = _black_box_predict(bb_model, x_probe, problem)

    print('Finished auxiliary dataset construction.')

    # Trains interpretable baseline on the black-box outputs
    baseline = linear_model.Ridge(alpha=1)
    baseline.fit(x_train, y_train_hat)

    print('Finished interpretable baseline training.')

    # Trains instance-wise weight estimator
    # Defines locally interpretable model
    interp_model = linear_model.Ridge(alpha=1)

    # Initializes RL-LIM
    rllim_class = rllim.Rllim(x_train, y_train_hat, x_probe, y_probe_hat,
                              parameters, interp_model, baseline,
                              checkpoint_file_name)

    # Trains RL-LIM
    rllim_class.rllim_train()

    print('Finished instance-wise weight estimator training.')

    # Interpretable inference
    # Trains locally interpretable models and outputs
    # instance-wise explanations (test_coef) and
    # interpretable predictions (test_y_fit)
    test_y_fit, test_coef = rllim_class.rllim_interpreter(
        x_train, y_train_hat, x_test, interp_model)

    print('Finished instance-wise predictions and local explanations.')

    # Overall performance against the true test labels
    mae = rllim_metrics.overall_performance_metrics(y_test,
                                                    test_y_fit,
                                                    metric='mae')
    print('Overall performance of RL-LIM in terms of MAE: ' +
          str(np.round(mae, 4)))

    # Black-box model predictions. These must be on the same scale as the
    # targets RL-LIM was trained on (probabilities for classification),
    # otherwise fidelity compares probabilities against hard labels.
    y_test_hat = _black_box_predict(bb_model, x_test, problem)

    # Fidelity in terms of MAE
    mae = rllim_metrics.fidelity_metrics(y_test_hat, test_y_fit, metric='mae')
    print('Fidelity of RL-LIM in terms of MAE: ' + str(np.round(mae, 4)))

    # Fidelity in terms of R2 Score
    r2 = rllim_metrics.fidelity_metrics(y_test_hat, test_y_fit, metric='r2')
    print('Fidelity of RL-LIM in terms of R2 Score: ' + str(np.round(r2, 4)))

    # Instance-wise explanations
    # Local explanations of (at most) n_exp samples; the slice may be shorter
    # than n_exp when the test set is small, so index by the actual row count.
    local_explanations = test_coef[:n_exp, :]

    final_col_names = np.concatenate((np.asarray(['intercept']), col_names),
                                     axis=0)
    explanations_df = pd.DataFrame(data=local_explanations,
                                   index=range(local_explanations.shape[0]),
                                   columns=final_col_names)
    print(explanations_df)
    return explanations_df
def main(args):
    """Runs RL-LIM end to end on a synthetic dataset.

    The steps are:
      1. Load synthetic train/probe/test splits, including the ground-truth
         local coefficients for the test split.
      2. Fit a global Ridge baseline.
      3. Train the RL-LIM instance-wise weight estimator.
      4. Report fidelity (MAE) and Absolute Weight Difference (AWD) against
         the ground-truth local dynamics, and plot both.

    Args:
      args: parsed arguments with fields data_name, train_no, probe_no,
        test_no, dim_no, seed, hidden_dim, iterations, num_layers,
        batch_size, batch_size_inner, hyper_lambda and checkpoint_file_name.
    """
    # NOTE(review): this module defines `main` more than once; only the last
    # definition is bound once the module finishes executing.

    dataset = args.data_name

    # Sample counts per split plus the feature dimension
    sizes = {
        'train': args.train_no,
        'probe': args.probe_no,
        'test': args.test_no,
        'dim': args.dim_no,
    }

    # Estimator / training hyper-parameters consumed by rllim.Rllim
    hparams = {
        'hidden_dim': args.hidden_dim,
        'iterations': args.iterations,
        'num_layers': args.num_layers,
        'batch_size': args.batch_size,
        'batch_size_inner': args.batch_size_inner,
        'lambda': args.hyper_lambda,
    }

    ckpt_path = args.checkpoint_file_name
    rng_seed = args.seed

    # c_te holds the ground-truth local coefficients of the test samples
    (x_tr, y_tr, x_pr, y_pr,
     x_te, y_te, c_te) = data_loading.load_synthetic_data(dataset, sizes,
                                                          rng_seed)
    print('Finish ' + str(dataset) + ' data loading')

    # Global linear baseline fitted directly on the synthetic labels
    baseline_model = linear_model.Ridge(alpha=1)
    baseline_model.fit(x_tr, y_tr)
    print('Finished interpretable baseline training.')

    # Local surrogate shared by RL-LIM training and inference
    local_model = linear_model.Ridge(alpha=1)

    explainer = rllim.Rllim(x_tr, y_tr, x_pr, y_pr, hparams,
                            local_model, baseline_model, ckpt_path)
    explainer.rllim_train()
    print('Finished instance-wise weight estimator training.')

    # Per-sample local fits: predictions and explanation coefficients
    y_fit, coef_hat = explainer.rllim_interpreter(x_tr, y_tr, x_te,
                                                  local_model)
    print('Finished interpretable predictions and local explanations.')

    fidelity_mae = rllim_metrics.fidelity_metrics(y_te, y_fit, metric='mae')
    print('fidelity of RL-LIM in terms of MAE: ' +
          str(np.round(fidelity_mae, 4)))

    # Absolute Weight Difference between true and recovered local dynamics
    weight_gap = rllim_metrics.awd_metric(c_te, coef_hat)
    print('AWD of RL-LIM: ' + str(np.round(weight_gap, 4)))

    # Fidelity plot first, then AWD plot
    for criterion in ('Fidelity', 'AWD'):
        rllim_metrics.plot_result(x_te,
                                  dataset,
                                  y_te,
                                  y_fit,
                                  c_te,
                                  coef_hat,
                                  metric='mae',
                                  criteria=criterion)