Example 1
    def get_train_test_data(self, context_vars: Tuple[str, ...], target_var: str, n_train=10 ** 3, n_test=10 ** 2,
                            seed=None, as_dataframes=False, **kwargs) -> \
            Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], Tuple[pd.DataFrame, pd.DataFrame]]:

        assert all(inp in self.var_names for inp in context_vars), 'Unknown context variable'
        assert target_var in self.var_names, 'Unknown target variable'

        train_seed = seed
        test_seed = 2 * seed if seed is not None else None
        logger.info(f'Sampling {n_train} train observations and {n_test} test observations '
                    f'with seeds: {train_seed} and {test_seed}')
        train = self.sem.sample(n_train, seed=train_seed).numpy()
        test = self.sem.sample(n_test, seed=test_seed).numpy()

        inputs_ind = search_nonsorted(self.var_names, context_vars)
        target_ind = search_nonsorted(self.var_names, target_var)
        if not as_dataframes:
            return train[:, inputs_ind], train[:, target_ind], test[:, inputs_ind], test[:, target_ind]
        else:
            train_df = pd.DataFrame(train[:, inputs_ind], columns=context_vars)
            train_df[target_var] = train[:, target_ind]

            test_df = pd.DataFrame(test[:, inputs_ind], columns=context_vars)
            test_df[target_var] = test[:, target_ind]
            return train_df, test_df
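
For orientation, a usage sketch of the method above. The `generator` object and the variable names 'x1', 'x2', 'x3' are illustrative assumptions, not taken from the source:

# Hypothetical: `generator` is an instance of the class defining get_train_test_data
X_train, y_train, X_test, y_test = generator.get_train_test_data(
    context_vars=('x1', 'x2'), target_var='x3', n_train=1000, n_test=100, seed=42)

# With as_dataframes=True, each frame holds the context columns plus the target
train_df, test_df = generator.get_train_test_data(
    context_vars=('x1', 'x2'), target_var='x3', seed=42, as_dataframes=True)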
Example 2
 def get_spouses(self, node: str) -> set:
     """Spouses of `node`: other parents of its children."""
     node_ind = search_nonsorted(self.var_names, [node])[0]
     children = tuple(self.DAG.successors(node_ind))
     spouses = tuple(
         par for child in children
         for par in self.DAG.predecessors(child) if par != node_ind
     )
     # Avoid shadowing the `node` argument when mapping indices back to names
     return {self.var_names[spouse_ind] for spouse_ind in spouses}
Example 3
 def conditional_distribution(
         self,
         node: str,
         context: Dict[str, Tensor] = None) -> Distribution:
     node_ind = search_nonsorted(self.dag.var_names, [node])
     if context is None or len(context) == 0:  # Unconditional (marginal) distribution
         return Normal(
             self.joint_mean[node_ind].item(),
             torch.sqrt(self.joint_cov[node_ind, node_ind]).item())
     else:  # Conditional distribution, derived from the joint Gaussian
         context_ind = search_nonsorted(self.dag.var_names,
                                        list(context.keys()))
         cond_dist = GaussianConditionalEstimator()
         cond_dist.fit_mean_cov(self.joint_mean.numpy(),
                                self.joint_cov.numpy(),
                                inp_ind=node_ind,
                                cont_ind=context_ind)
         # Stack context values into (n_samples, n_context_vars), ordered as context_ind
         context_values = np.stack(list(context.values())).T
         return cond_dist.conditional_distribution(context_values)
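
A hedged usage sketch of conditional_distribution; the `sem` wrapper object and the variable names are assumptions for illustration:

import torch

# Hypothetical `sem` exposing the method above over variables ('x1', 'x2', 'x3')
marginal = sem.conditional_distribution('x2')            # unconditional Normal
conditional = sem.conditional_distribution(
    'x2', context={'x1': torch.zeros(10), 'x3': torch.ones(10)})
# roughly: the Gaussian p(x2 | x1, x3) evaluated at the 10 supplied context rows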
Example 4
 def get_children(self, node: str) -> set:
     node_ind = search_nonsorted(self.var_names, [node])[0]
     children = tuple(self.DAG.successors(node_ind))
     return {self.var_names[child_ind] for child_ind in children}
Example 5
 def get_parents(self, node: str) -> set:
     node_ind = search_nonsorted(self.var_names, [node])[0]
     parents = tuple(self.DAG.predecessors(node_ind))
     return {self.var_names[parent_ind] for parent_ind in parents}
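
Examples 2, 4, and 5 share one pattern: map a node name to its integer index, query the networkx DiGraph by index, then map the resulting indices back to names. A self-contained toy graph (an illustration, not from the source) makes the expected outputs concrete:

import networkx as nx

var_names = ['x0', 'x1', 'x2', 'x3']
DAG = nx.DiGraph([(0, 2), (1, 2), (2, 3)])  # x0 -> x2 <- x1, x2 -> x3

# With these helpers on such a wrapper one would expect:
#   get_parents('x2')  -> {'x0', 'x1'}
#   get_children('x2') -> {'x3'}
#   get_spouses('x0')  -> {'x1'}   (x1 co-parents the shared child x2)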
Example 6
def main(args: DictConfig):

    # Non-strict access to fields
    OmegaConf.set_struct(args, False)

    # Adding default estimator params
    default_names, _, _, default_values, _, _, _ = \
        inspect.getfullargspec(instantiate(args.estimator, context_size=0).__class__.__init__)
    if default_values is not None:
        args.estimator['defaults'] = {
            n: str(v)
            for (n, v) in zip(default_names[-len(default_values):], default_values)
        }
    logger.info(OmegaConf.to_yaml(args, resolve=True))

    # Data-generating DAG
    data_path = hydra.utils.to_absolute_path(
        f'{ROOT_PATH}/{args.data.relative_path}')
    exp_name = args.data.relative_path.split('/')[-1]
    adjacency_matrix = np.load(
        f'{data_path}/DAG{args.data.sample_ind}.npy').astype(int)
    if exp_name == 'sachs_2005':
        var_names = np.load(f'{data_path}/sachs-header.npy')
    else:
        var_names = [f'x{i}' for i in range(len(adjacency_matrix))]
    dag = DirectedAcyclicGraph(adjacency_matrix, var_names)

    # Experiment tracking
    mlflow.set_tracking_uri(args.exp.mlflow_uri)
    mlflow.set_experiment(exp_name)

    # Checking if the run already exists
    if check_existing_hash(args, exp_name):
        logger.info('Skipping existing run.')
        return
    else:
        logger.info('No runs found - performing one.')

    # Loading Train-test data
    data = np.load(f'{data_path}/data{args.data.sample_ind}.npy')
    if args.data.standard_normalize:
        standard_normalizer = StandardScaler()
        data = standard_normalizer.fit_transform(data)
    data_train, data_test = train_test_split(data,
                                             test_size=args.data.test_ratio,
                                             random_state=args.data.split_seed)
    train_df = pd.DataFrame(data_train, columns=dag.var_names)
    test_df = pd.DataFrame(data_test, columns=dag.var_names)

    mlflow.start_run()
    mlflow.log_params(flatten_dict(args))
    mlflow.log_param('data_generator/dag/n', len(var_names))
    mlflow.log_param('data_generator/dag/m', int(adjacency_matrix.sum()))
    mlflow.log_param('data/n_train', len(train_df))
    mlflow.log_param('data/n_test', len(test_df))

    # Saving artifacts
    train_df.to_csv(
        hydra.utils.to_absolute_path(f'{mlflow.get_artifact_uri()}/train.csv'),
        index=False)
    test_df.to_csv(
        hydra.utils.to_absolute_path(f'{mlflow.get_artifact_uri()}/test.csv'),
        index=False)
    dag.plot_dag()
    plt.savefig(
        hydra.utils.to_absolute_path(f'{mlflow.get_artifact_uri()}/dag.png'))
    if len(dag.var_names) <= 20:
        df = pd.concat([train_df, test_df],
                       keys=['train', 'test']).reset_index().drop(columns=['level_1'])
        g = sns.pairplot(df, plot_kws={'alpha': 0.25}, hue='level_0')
        g.fig.suptitle(exp_name)
        plt.savefig(
            hydra.utils.to_absolute_path(
                f'{mlflow.get_artifact_uri()}/data.png'))

    metrics = {}

    for var_ind, target_var in enumerate(dag.var_names):

        var_results = {}

        # Considering all the variables for input
        input_vars = [var for var in dag.var_names if var != target_var]
        y_train, X_train = train_df.loc[:, target_var].values, train_df.loc[:, input_vars].values
        y_test, X_test = test_df.loc[:, target_var].values, test_df.loc[:, input_vars].values

        # Initialising risks
        risks = {}
        for risk in args.predictors.risks:
            risks[risk] = getattr(importlib.import_module('sklearn.metrics'),
                                  risk)

        # Fitting predictive model
        models = {}
        for pred_model in args.predictors.pred_models:
            logger.info(
                f'Fitting {pred_model._target_} for target = {target_var} and inputs {input_vars}'
            )
            model = instantiate(pred_model)
            model.fit(X_train, y_train)
            y_pred = model.predict(X_test)
            models[pred_model._target_] = model
            for risk, risk_func in risks.items():
                var_results[f'test_{risk}_{pred_model._target_}'] = risk_func(
                    y_test, y_pred)

        sampler = instantiate(args.estimator.sampler,
                              X_train=X_train,
                              fit_method=args.estimator.fit_method,
                              fit_params=args.estimator.fit_params)

        # =================== Relative feature importance ===================
        mb_vars = list(dag.get_markov_blanket(target_var))

        # 1. G = MB(target_var), FoI = input_vars \ MB(target_var)
        G_vars_1 = mb_vars
        fsoi_vars_1 = [var for var in input_vars if var not in mb_vars]
        prefix_1 = 'mb'

        # 2. G = input_vars \ MB(target_var), FoI = MB(target_var)
        fsoi_vars_2 = mb_vars
        G_vars_2 = [var for var in input_vars if var not in mb_vars]
        prefix_2 = 'non_mb'

        for (G_vars, fsoi_vars, prefix) in zip([G_vars_1, G_vars_2],
                                               [fsoi_vars_1, fsoi_vars_2],
                                               [prefix_1, prefix_2]):
            G = search_nonsorted(input_vars, G_vars)
            fsoi = search_nonsorted(input_vars, fsoi_vars)

            rfi_gof_metrics = {}
            for f, f_var in zip(fsoi, fsoi_vars):
                estimator = sampler.train([f], G)

                # GoF diagnostics
                rfi_gof_results = {}
                if estimator is not None:

                    rfi_gof_results[f'rfi/gof/{prefix}_mean_log_lik'] = \
                        estimator.log_prob(inputs=X_test[:, f], context=X_test[:, G]).mean()

                rfi_gof_metrics = {
                    k: rfi_gof_metrics.get(k, []) +
                    [rfi_gof_results.get(k, np.nan)]
                    for k in set(
                        list(rfi_gof_metrics.keys()) +
                        list(rfi_gof_results.keys()))
                }

            # Feature importance
            if len(fsoi) > 0:
                var_results[f'rfi/{prefix}_cond_size'] = len(G_vars)

                for model_name, model in models.items():
                    for risk, risk_func in risks.items():

                        rfi_explainer = explainer.Explainer(
                            model.predict,
                            fsoi,
                            X_train,
                            sampler=sampler,
                            loss=risk_func,
                            fs_names=input_vars)
                        mb_explanation = rfi_explainer.rfi(
                            X_test, y_test, G, nr_runs=args.exp.rfi.nr_runs)
                        var_results[f'rfi/{prefix}_mean_rfi_{risk}_{model_name}'] = \
                            np.abs(mb_explanation.fi_vals(return_np=True)).mean()

                var_results = {
                    **var_results,
                    **{
                        k: np.nanmean(v) if len(G_vars) > 0 else np.nan
                        for (k, v) in rfi_gof_metrics.items()
                    }
                }

        # TODO  =================== Global SAGE ===================

        mlflow.log_metrics(var_results, step=var_ind)

        metrics = {
            k: metrics.get(k, []) + [var_results.get(k, np.nan)]
            for k in set(list(metrics.keys()) + list(var_results.keys()))
        }

    # Logging mean statistics
    mlflow.log_metrics({k: np.nanmean(v) for (k, v) in metrics.items()},
                       step=len(dag.var_names))
    mlflow.end_run()
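
A side note on the accumulation idiom used for `rfi_gof_metrics` and `metrics` above: each round merges a per-round dict into a dict of lists, appending NaN for keys that existed before but are absent this round, so `np.nanmean` can average per key at the end. A minimal self-contained sketch:

import numpy as np

metrics = {}
for round_results in [{'a': 1.0, 'b': 3.0}, {'a': 2.0}]:
    metrics = {
        k: metrics.get(k, []) + [round_results.get(k, np.nan)]
        for k in set(list(metrics.keys()) + list(round_results.keys()))
    }
# metrics == {'a': [1.0, 2.0], 'b': [3.0, nan]}
# np.nanmean ignores the padding: {'a': 1.5, 'b': 3.0}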
Example 7
def main(args: DictConfig):

    exp_name = 'census'

    # Non-strict access to fields
    OmegaConf.set_struct(args, False)

    # Dataset loading
    data_df, y = shap.datasets.adult()
    data_df['Salary'] = y

    # Binning for Capital Gain
    if args.exp.discretize:
        discretized_vars = set(list(args.exp.discretize))
        discretizer = KBinsDiscretizer(n_bins=50,
                                       encode='ordinal',
                                       strategy='uniform')
        data_df.loc[:, discretized_vars] = discretizer.fit_transform(
            data_df.loc[:, discretized_vars].values).astype(int)

    target_var = {'Salary'}
    all_inputs_vars = set(data_df.columns) - target_var
    sensetive_vars = set(list(args.exp.sensetive_vars))
    if args.exp.exclude_sensetive:
        wo_sens_inputs_vars = all_inputs_vars - sensetive_vars
    else:
        wo_sens_inputs_vars = all_inputs_vars
    cat_vars = set(data_df.select_dtypes(exclude=[np.floating]).columns.values)
    cont_vars = all_inputs_vars - cat_vars
    logger.info(
        f'Target var: {target_var}, all_inputs: {all_inputs_vars}, sensetive_vars: {sensetive_vars}, '
        f'cat_vars: {cat_vars}')

    if args.data.standard_normalize:
        standard_normalizer = StandardScaler()
        data_df.loc[:, cont_vars] = standard_normalizer.fit_transform(
            data_df[cont_vars].values)

    train_df, test_df = train_test_split(data_df,
                                         test_size=args.data.test_ratio,
                                         random_state=args.data.split_seed)
    y_train, X_train, X_train_wo_sens = train_df[target_var], train_df[all_inputs_vars], train_df[wo_sens_inputs_vars]
    y_test, X_test, X_test_wo_sens = test_df[target_var], test_df[all_inputs_vars], test_df[wo_sens_inputs_vars]

    # Adding default estimator params
    default_names, _, _, default_values, _, _, _ = \
        inspect.getfullargspec(instantiate(args.estimator, context_size=0).__class__.__init__)
    if default_values is not None:
        args.estimator['defaults'] = {
            n: str(v)
            for (n, v) in zip(default_names[-len(default_values):], default_values)
        }
    logger.info(OmegaConf.to_yaml(args, resolve=True))

    # Experiment tracking
    mlflow.set_tracking_uri(args.exp.mlflow_uri)
    mlflow.set_experiment(exp_name)

    # Checking if the run already exists
    if check_existing_hash(args, exp_name):
        logger.info('Skipping existing run.')
        return
    else:
        logger.info('No runs found - performing one.')

    mlflow.start_run()
    mlflow.log_params(flatten_dict(args))
    mlflow.log_param('data/n_train', len(train_df))
    mlflow.log_param('data/n_test', len(test_df))

    # Saving artifacts
    train_df.to_csv(hydra.utils.to_absolute_path(f'{mlflow.get_artifact_uri()}/train_df.csv'), index=False)
    test_df.to_csv(hydra.utils.to_absolute_path(f'{mlflow.get_artifact_uri()}/test_df.csv'), index=False)
    # df = pd.concat([train_df, test_df], keys=['train', 'test']).reset_index().drop(columns=['level_1'])
    # g = sns.pairplot(df, plot_kws={'alpha': 0.25}, hue='level_0')
    # g.fig.suptitle(exp_name)
    # plt.savefig(hydra.utils.to_absolute_path(f'{mlflow.get_artifact_uri()}/data.png'))

    pred_results = {}

    # Initialising risks
    risks = {}
    pred_funcs = {}
    for risk in args.predictors.risks:
        risks[risk['name']] = getattr(
            importlib.import_module('sklearn.metrics'), risk['name'])
        pred_funcs[risk['name']] = risk['method']

    # Fitting predictive model
    models = {}
    models_pred_funcs = {}
    for pred_model in args.predictors.pred_models:
        cat_vars_wo_sens = cat_vars - sensetive_vars
        logger.info(
            f'Fitting {pred_model._target_} for target = {target_var} and inputs {wo_sens_inputs_vars} '
            f'(categorical {cat_vars_wo_sens})')
        model = instantiate(pred_model,
                            categorical_feature=search_nonsorted(
                                list(wo_sens_inputs_vars),
                                list(cat_vars_wo_sens)))
        model.fit(X_train_wo_sens, y_train)
        models[pred_model._target_] = model
        for risk, risk_func in risks.items():
            # Bind the current model and method name as default arguments: a plain
            # closure captures the loop variables by reference and would resolve
            # them to their final values when called later
            if pred_funcs[risk] == 'predict_proba':
                models_pred_funcs[risk] = \
                    lambda X, m=model, p=pred_funcs[risk]: getattr(m, p)(X)[:, 1]
            else:
                models_pred_funcs[risk] = \
                    lambda X, m=model, p=pred_funcs[risk]: getattr(m, p)(X)
            y_pred = models_pred_funcs[risk](X_test_wo_sens)
            pred_results[f'test_{risk}_{pred_model._target_}'] = risk_func(
                y_test, y_pred)

    mlflow.log_metrics(pred_results, step=0)

    sampler = instantiate(args.estimator.sampler,
                          X_train=X_train,
                          fit_method=args.estimator.fit_method,
                          fit_params=args.estimator.fit_params,
                          cat_inputs=list(cat_vars))

    wo_sens_inputs_vars_list = sorted(list(wo_sens_inputs_vars))
    mlflow.log_param('features_sequence', str(wo_sens_inputs_vars_list))
    mlflow.log_param('exp/cat_vars', str(sorted(list(cat_vars))))

    for i, fsoi_var in enumerate(wo_sens_inputs_vars_list):
        logger.info(f'Analysing the importance of feature: {fsoi_var}')

        interpret_results = {}
        # fsoi = search_nonsorted(list(all_inputs_vars), [fsoi_var])
        R_j = list(wo_sens_inputs_vars)

        # Permutation feature importance
        G_pfi, name_pfi = [], 'pfi'
        sampler.train([fsoi_var], G_pfi)

        # Conditional feature importance
        G_cfi, name_cfi = list(wo_sens_inputs_vars - {fsoi_var}), 'cfi'
        estimator = sampler.train([fsoi_var], G_cfi)
        test_inputs = X_test[sampler._order_fset([fsoi_var])].to_numpy()
        test_context = X_test[sampler._order_fset(G_cfi)].to_numpy()
        interpret_results['cfi/gof/mean_log_lik'] = estimator.log_prob(
            inputs=test_inputs, context=test_context).mean()

        # Relative feature importance (conditioning only on the sensitive vars)
        G_rfi, name_rfi = list(sensetive_vars), 'rfi'
        estimator = sampler.train([fsoi_var], G_rfi)
        test_inputs = X_test[sampler._order_fset([fsoi_var])].to_numpy()
        test_context = X_test[sampler._order_fset(G_rfi)].to_numpy()
        if estimator is not None:
            interpret_results['rfi/gof/mean_log_lik'] = estimator.log_prob(
                inputs=test_inputs, context=test_context).mean()
        else:
            interpret_results['rfi/gof/mean_log_lik'] = np.nan

        for model_name, model in models.items():
            for risk, risk_func in risks.items():
                rfi_explainer = explainer.Explainer(
                    models_pred_funcs[risk], [fsoi_var],
                    X_train,
                    sampler=sampler,
                    loss=risk_func,
                    fs_names=list(all_inputs_vars))
                for G, name in zip([G_pfi, G_cfi, G_rfi],
                                   [name_pfi, name_cfi, name_rfi]):
                    mb_explanation = rfi_explainer.rfi(
                        X_test, y_test, G, R_j, nr_runs=args.exp.rfi.nr_runs)
                    interpret_results[
                        f'{name}/mean_{risk}_{model_name}'] = np.abs(
                            mb_explanation.fi_vals().values).mean()

        mlflow.log_metrics(interpret_results, step=i)

    mlflow.end_run()
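
The per-risk prediction lambdas above were rewritten to bind `model` and the method name as default arguments; a plain closure captures loop variables by reference and would resolve them to their final values when called later. The pitfall in isolation:

# Closures see the loop variable's final value:
fns = [lambda: i for i in range(3)]
print([f() for f in fns])        # [2, 2, 2]

# Default arguments are evaluated at definition time, freezing the value:
fns = [lambda i=i: i for i in range(3)]
print([f() for f in fns])        # [0, 1, 2]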
Example 8
def main(args: DictConfig):

    # Non-strict access to fields
    OmegaConf.set_struct(args, False)

    # Adding default estimator params
    default_names, _, _, default_values, _, _, _ = \
        inspect.getfullargspec(instantiate(args.estimator, context_size=0).__class__.__init__)
    if default_values is not None:
        args.estimator['defaults'] = {
            n: str(v) for (n, v) in zip(default_names[-len(default_values):], default_values)
        }
        args.estimator['defaults'].pop('cat_context')
    logger.info(OmegaConf.to_yaml(args, resolve=True))

    # Data generator init
    dag = DirectedAcyclicGraph.random_dag(**args.data_generator.dag)
    # if 'interpolation_switch' in args.data_generator.sem:
    #     args.data_generator.sem.interpolation_switch = args.data.n_train + args.data.n_test
    sem = instantiate(args.data_generator.sem, dag=dag)

    # Experiment tracking
    mlflow.set_tracking_uri(args.exp.mlflow_uri)
    mlflow.set_experiment(args.data_generator.sem_type)

    # Checking if the run already exists
    if check_existing_hash(args, args.data_generator.sem_type):
        logger.info('Skipping existing run.')
        return
    else:
        logger.info('No runs found - performing one.')

    mlflow.start_run()
    mlflow.log_params(flatten_dict(args))

    # Generating Train-test dataframes
    train_df = pd.DataFrame(sem.sample(size=args.data.n_train, seed=args.data.train_seed).numpy(), columns=dag.var_names)
    test_df = pd.DataFrame(sem.sample(size=args.data.n_test, seed=args.data.test_seed).numpy(), columns=dag.var_names)

    # Saving artifacts
    train_df.to_csv(hydra.utils.to_absolute_path(f'{mlflow.get_artifact_uri()}/train.csv'), index=False)
    test_df.to_csv(hydra.utils.to_absolute_path(f'{mlflow.get_artifact_uri()}/test.csv'), index=False)
    sem.dag.plot_dag()
    plt.savefig(hydra.utils.to_absolute_path(f'{mlflow.get_artifact_uri()}/dag.png'))
    if len(dag.var_names) <= 20:
        df = pd.concat([train_df, test_df], keys=['train', 'test']).reset_index().drop(columns=['level_1'])
        g = sns.pairplot(df, plot_kws={'alpha': 0.25}, hue='level_0')
        g.fig.suptitle(sem.__class__.__name__)
        plt.savefig(hydra.utils.to_absolute_path(f'{mlflow.get_artifact_uri()}/data.png'))

    metrics = {}

    for var_ind, target_var in enumerate(dag.var_names):

        var_results = {}

        # Considering all the variables for input
        input_vars = [var for var in dag.var_names if var != target_var]
        y_train, X_train = train_df.loc[:, target_var].values, train_df.loc[:, input_vars].values
        y_test, X_test = test_df.loc[:, target_var].values, test_df.loc[:, input_vars].values

        # Initialising risks
        risks = {}
        for risk in args.predictors.risks:
            risks[risk] = getattr(importlib.import_module('sklearn.metrics'), risk)

        # Fitting predictive model
        models = {}
        for pred_model in args.predictors.pred_models:
            logger.info(f'Fitting {pred_model._target_} for target = {target_var} and inputs {input_vars}')
            model = instantiate(pred_model)
            model.fit(X_train, y_train)
            y_pred = model.predict(X_test)
            models[pred_model._target_] = model
            for risk, risk_func in risks.items():
                var_results[f'test_{risk}_{pred_model._target_}'] = risk_func(y_test, y_pred)

        sampler = instantiate(args.estimator.sampler, X_train=X_train,
                              fit_method=args.estimator.fit_method, fit_params=args.estimator.fit_params)

        # =================== Relative feature importance ===================
        mb_vars = list(sem.get_markov_blanket(target_var))

        # 1. G = MB(target_var), FoI = input_vars \ MB(target_var)
        G_vars_1 = mb_vars
        fsoi_vars_1 = [var for var in input_vars if var not in mb_vars]
        prefix_1 = 'mb'

        # 2. G = input_vars \ MB(target_var), FoI = MB(target_var)
        fsoi_vars_2 = mb_vars
        G_vars_2 = [var for var in input_vars if var not in mb_vars]
        prefix_2 = 'non_mb'

        for (G_vars, fsoi_vars, prefix) in zip([G_vars_1, G_vars_2], [fsoi_vars_1, fsoi_vars_2], [prefix_1, prefix_2]):
            G = search_nonsorted(input_vars, G_vars)
            fsoi = search_nonsorted(input_vars, fsoi_vars)

            rfi_gof_metrics = {}
            for f, f_var in zip(fsoi, fsoi_vars):
                estimator = sampler.train([f], G)

                # GoF diagnostics
                rfi_gof_results = {}
                if estimator is not None:

                    rfi_gof_results[f'rfi/gof/{prefix}_mean_log_lik'] = \
                        estimator.log_prob(inputs=X_test[:, f], context=X_test[:, G]).mean()

                    # Advanced conditional GoF metrics; note that 'arbitrary'
                    # takes precedence when both conditions hold
                    if sem.get_markov_blanket(f_var).issubset(set(G_vars)):
                        cond_mode = 'all'
                    if isinstance(sem, LinearGaussianNoiseSEM):
                        cond_mode = 'arbitrary'

                    if sem.get_markov_blanket(f_var).issubset(set(G_vars)) or isinstance(sem, LinearGaussianNoiseSEM):
                        rfi_gof_results[f'rfi/gof/{prefix}_kld'] = \
                            conditional_kl_divergence(estimator, sem, f_var, G_vars, args.exp, cond_mode, test_df)
                        rfi_gof_results[f'rfi/gof/{prefix}_hd'] = \
                            conditional_hellinger_distance(estimator, sem, f_var, G_vars, args.exp, cond_mode, test_df)
                        rfi_gof_results[f'rfi/gof/{prefix}_jsd'] = \
                            conditional_js_divergence(estimator, sem, f_var, G_vars, args.exp, cond_mode, test_df)

                rfi_gof_metrics = {k: rfi_gof_metrics.get(k, []) + [rfi_gof_results.get(k, np.nan)]
                                   for k in set(list(rfi_gof_metrics.keys()) + list(rfi_gof_results.keys()))}

            # Feature importance
            if len(fsoi) > 0:
                var_results[f'rfi/{prefix}_cond_size'] = len(G_vars)

                for model_name, model in models.items():
                    for risk, risk_func in risks.items():

                        rfi_explainer = explainer.Explainer(model.predict, fsoi, X_train, sampler=sampler, loss=risk_func,
                                                            fs_names=input_vars)
                        mb_explanation = rfi_explainer.rfi(X_test, y_test, G, nr_runs=args.exp.rfi.nr_runs)
                        var_results[f'rfi/{prefix}_mean_rfi_{risk}_{model_name}'] = \
                            np.abs(mb_explanation.fi_vals(return_np=True)).mean()

                var_results = {**var_results,
                               **{k: np.nanmean(v) if len(G_vars) > 0 else np.nan for (k, v) in rfi_gof_metrics.items()}}

        # TODO  =================== Global SAGE ===================

        mlflow.log_metrics(var_results, step=var_ind)

        metrics = {k: metrics.get(k, []) + [var_results.get(k, np.nan)]
                   for k in set(list(metrics.keys()) + list(var_results.keys()))}

    # Logging mean statistics
    mlflow.log_metrics({k: np.nanmean(v) for (k, v) in metrics.items()}, step=len(dag.var_names))
    mlflow.end_run()
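
`search_nonsorted` is called throughout these examples but never shown. From its call sites it maps query names to their positions in an unsorted reference list and returns an integer index array. A minimal sketch consistent with that usage; the actual implementation may well differ (e.g. a vectorised lookup):

import numpy as np

def search_nonsorted(reference, queries):
    """Index of each query item in `reference`, which need not be sorted."""
    if isinstance(queries, str):  # tolerate a single name, as in Example 1
        queries = [queries]
    reference = list(reference)
    return np.array([reference.index(q) for q in queries], dtype=int)

# search_nonsorted(['c', 'a', 'b'], ['b', 'c'])  -> array([2, 0])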