Пример #1
0
def main(args: DictConfig):
    """Relative feature importance (RFI) experiment on a stored DAG dataset.

    Loads an adjacency matrix and a data sample from
    ``{ROOT_PATH}/{args.data.relative_path}``, splits the data into train/test
    sets and, for every DAG variable taken in turn as the target, fits the
    configured predictors on all remaining variables. RFI is then computed in
    two settings:

    1. ``mb``: conditioning set G = Markov blanket of the target, features of
       interest = the remaining inputs;
    2. ``non_mb``: G = inputs outside the Markov blanket, features of interest =
       the Markov blanket.

    Test risks, RFI values and the sampler goodness-of-fit (mean test
    log-likelihood) are logged to MLflow per target (``step=var_ind``); their
    NaN-ignoring means over all targets are logged at
    ``step=len(dag.var_names)``.

    Args:
        args: Hydra/OmegaConf experiment config. It is made non-struct and an
            ``estimator.defaults`` entry is added to it in place.

    Returns:
        None. Returns early, without starting a run, when
        ``check_existing_hash`` reports an existing run for this config.
    """
    # Non-strict access to fields (needed to add 'defaults' below)
    OmegaConf.set_struct(args, False)

    # Adding default estimator params, so that they become part of the logged config
    default_names, _, _, default_values, _, _, _ = \
        inspect.getfullargspec(instantiate(args.estimator, context_size=0).__class__.__init__)
    if default_values is not None:
        # Defaults belong to the last len(default_values) positional arguments
        first_default = len(default_names) - len(default_values)
        args.estimator['defaults'] = {
            n: str(v)
            for (n, v) in zip(default_names[first_default:], default_values)
        }
    logger.info(OmegaConf.to_yaml(args, resolve=True))

    # Data-generating DAG
    data_path = hydra.utils.to_absolute_path(
        f'{ROOT_PATH}/{args.data.relative_path}')
    exp_name = args.data.relative_path.split('/')[-1]
    adjacency_matrix = np.load(
        f'{data_path}/DAG{args.data.sample_ind}.npy').astype(int)
    if exp_name == 'sachs_2005':
        var_names = np.load(f'{data_path}/sachs-header.npy')
    else:
        var_names = [f'x{i}' for i in range(len(adjacency_matrix))]
    dag = DirectedAcyclicGraph(adjacency_matrix, var_names)

    # Experiment tracking
    mlflow.set_tracking_uri(args.exp.mlflow_uri)
    mlflow.set_experiment(exp_name)

    # Checking if run exist
    if check_existing_hash(args, exp_name):
        logger.info('Skipping existing run.')
        return
    logger.info('No runs found - perfoming one.')

    # Loading Train-test data
    data = np.load(f'{data_path}/data{args.data.sample_ind}.npy')
    if args.data.standard_normalize:
        standard_normalizer = StandardScaler()
        data = standard_normalizer.fit_transform(data)
    data_train, data_test = train_test_split(data,
                                             test_size=args.data.test_ratio,
                                             random_state=args.data.split_seed)
    train_df = pd.DataFrame(data_train, columns=dag.var_names)
    test_df = pd.DataFrame(data_test, columns=dag.var_names)

    # Risk functions are looked up by name in sklearn.metrics. They do not depend
    # on the target variable, so resolve them once (a bad name fails before a run exists)
    metrics_module = importlib.import_module('sklearn.metrics')
    risks = {risk: getattr(metrics_module, risk) for risk in args.predictors.risks}

    # The context manager ends the run even when an exception is raised, so a
    # failed job does not leave an active run behind for the next one
    with mlflow.start_run():
        mlflow.log_params(flatten_dict(args))
        mlflow.log_param('data_generator/dag/n', len(var_names))
        mlflow.log_param('data_generator/dag/m', int(adjacency_matrix.sum()))
        mlflow.log_param('data/n_train', len(train_df))
        mlflow.log_param('data/n_test', len(test_df))

        # Saving artifacts
        artifact_uri = mlflow.get_artifact_uri()
        train_df.to_csv(
            hydra.utils.to_absolute_path(f'{artifact_uri}/train.csv'),
            index=False)
        test_df.to_csv(
            hydra.utils.to_absolute_path(f'{artifact_uri}/test.csv'),
            index=False)
        dag.plot_dag()
        plt.savefig(hydra.utils.to_absolute_path(f'{artifact_uri}/dag.png'))
        # Close saved figures so they neither accumulate nor get drawn over later
        plt.close('all')
        if len(dag.var_names) <= 20:
            df = pd.concat([train_df, test_df],
                           keys=['train',
                                 'test']).reset_index().drop(columns=['level_1'])
            g = sns.pairplot(df, plot_kws={'alpha': 0.25}, hue='level_0')
            g.fig.suptitle(exp_name)
            plt.savefig(
                hydra.utils.to_absolute_path(f'{artifact_uri}/data.png'))
            plt.close('all')

        # metric name -> list of per-target values (NaN where a target lacks it)
        metrics = {}

        for var_ind, target_var in enumerate(dag.var_names):

            var_results = {}

            # Considering all the variables for input
            input_vars = [var for var in dag.var_names if var != target_var]
            y_train = train_df.loc[:, target_var].values
            X_train = train_df.loc[:, input_vars].values
            y_test = test_df.loc[:, target_var].values
            X_test = test_df.loc[:, input_vars].values

            # Fitting predictive models and evaluating their test risks
            models = {}
            for pred_model in args.predictors.pred_models:
                logger.info(
                    f'Fitting {pred_model._target_} for target = {target_var} and inputs {input_vars}'
                )
                model = instantiate(pred_model)
                model.fit(X_train, y_train)
                y_pred = model.predict(X_test)
                models[pred_model._target_] = model
                for risk, risk_func in risks.items():
                    var_results[f'test_{risk}_{pred_model._target_}'] = risk_func(
                        y_test, y_pred)

            sampler = instantiate(args.estimator.sampler,
                                  X_train=X_train,
                                  fit_method=args.estimator.fit_method,
                                  fit_params=args.estimator.fit_params)

            # =================== Relative feature importance ===================
            markov_blanket = list(dag.get_markov_blanket(target_var))
            non_markov_blanket = [
                var for var in input_vars if var not in markov_blanket
            ]
            rfi_settings = [
                # 1. G = MB(target_var), FoI = input_vars / MB(target_var)
                (markov_blanket, non_markov_blanket, 'mb'),
                # 2. G = input_vars / MB(target_var), FoI = MB(target_var)
                (non_markov_blanket, markov_blanket, 'non_mb'),
            ]

            for (G_vars, fsoi_vars, prefix) in rfi_settings:
                # Column indices (into input_vars) of G and of the features of interest
                G = search_nonsorted(input_vars, G_vars)
                fsoi = search_nonsorted(input_vars, fsoi_vars)

                # GoF diagnostics of one conditional estimator per feature of interest
                rfi_gof_metrics = {}
                for f in fsoi:
                    estimator = sampler.train([f], G)

                    rfi_gof_results = {}
                    if estimator is not None:
                        rfi_gof_results[f'rfi/gof/{prefix}_mean_log_lik'] = \
                            estimator.log_prob(inputs=X_test[:, f], context=X_test[:, G]).mean()

                    # Append this feature's values; NaN where a known metric is missing
                    rfi_gof_metrics = {
                        k: rfi_gof_metrics.get(k, []) +
                        [rfi_gof_results.get(k, np.nan)]
                        for k in set(
                            list(rfi_gof_metrics.keys()) +
                            list(rfi_gof_results.keys()))
                    }

                # Feature importance
                if len(fsoi) > 0:
                    var_results[f'rfi/{prefix}_cond_size'] = len(G_vars)

                    for model_name, model in models.items():
                        for risk, risk_func in risks.items():

                            rfi_explainer = explainer.Explainer(
                                model.predict,
                                fsoi,
                                X_train,
                                sampler=sampler,
                                loss=risk_func,
                                fs_names=input_vars)
                            rfi_explanation = rfi_explainer.rfi(
                                X_test, y_test, G, nr_runs=args.exp.rfi.nr_runs)
                            var_results[f'rfi/{prefix}_mean_rfi_{risk}_{model_name}'] = \
                                np.abs(rfi_explanation.fi_vals(return_np=True)).mean()

                    # GoF is only meaningful with a non-empty conditioning set
                    var_results = {
                        **var_results,
                        **{
                            k: np.nanmean(v) if len(G_vars) > 0 else np.nan
                            for (k, v) in rfi_gof_metrics.items()
                        }
                    }

            # TODO  =================== Global SAGE ===================

            mlflow.log_metrics(var_results, step=var_ind)

            metrics = {
                k: metrics.get(k, []) + [var_results.get(k, np.nan)]
                for k in set(list(metrics.keys()) + list(var_results.keys()))
            }

        # Logging mean statistics (one step after the last target)
        mlflow.log_metrics({k: np.nanmean(v)
                            for (k, v) in metrics.items()},
                           step=len(dag.var_names))
Пример #2
0
def main(args: DictConfig):
    """Global SAGE experiment on a stored DAG dataset.

    Loads an adjacency matrix and a data sample from
    ``{ROOT_PATH}/{args.data.relative_path}`` and splits the data into
    train/test sets. For every DAG variable taken in turn as the target, it fits
    the configured predictors on all remaining variables. It then computes SAGE
    values of those inputs for every (model, risk) pair, using one fixed set of
    feature orderings.

    Test risks, per-feature mean SAGE values and the sampler's mean test
    log-likelihood are logged with ``step=var_ind`` to the MLflow experiment
    ``sage/<dataset name>``.

    Args:
        args: Hydra/OmegaConf experiment config. It is made non-struct, its
            ``exp.rfi`` section is removed, and ``estimator.defaults`` is
            added, all in place.

    Returns:
        None. Returns early, without starting a run, when
        ``check_existing_hash`` reports an existing run for this config.
    """
    # Non-strict access to fields
    OmegaConf.set_struct(args, False)
    # RFI settings are unused here; removing them keeps them out of the logged
    # params and the run-hash check. The default keeps configs without them valid.
    args.exp.pop('rfi', None)

    # Adding default estimator params, so that they become part of the logged config
    default_names, _, _, default_values, _, _, _ = \
        inspect.getfullargspec(instantiate(args.estimator, context_size=0).__class__.__init__)
    if default_values is not None:
        # Defaults belong to the last len(default_values) positional arguments
        first_default = len(default_names) - len(default_values)
        args.estimator['defaults'] = {
            n: str(v)
            for (n, v) in zip(default_names[first_default:], default_values)
        }
    logger.info(OmegaConf.to_yaml(args, resolve=True))

    # Data-generating DAG
    data_path = hydra.utils.to_absolute_path(
        f'{ROOT_PATH}/{args.data.relative_path}')
    exp_name = args.data.relative_path.split('/')[-1]
    adjacency_matrix = np.load(
        f'{data_path}/DAG{args.data.sample_ind}.npy').astype(int)
    if exp_name == 'sachs_2005':
        var_names = np.load(f'{data_path}/sachs-header.npy')
    else:
        var_names = [f'x{i}' for i in range(len(adjacency_matrix))]
    dag = DirectedAcyclicGraph(adjacency_matrix, var_names)

    # Experiment tracking
    exp_name = f'sage/{exp_name}'
    mlflow.set_tracking_uri(args.exp.mlflow_uri)
    mlflow.set_experiment(exp_name)

    # Checking if run exist
    if check_existing_hash(args, exp_name):
        logger.info('Skipping existing run.')
        return
    logger.info('No runs found - perfoming one.')

    # Loading Train-test data
    data = np.load(f'{data_path}/data{args.data.sample_ind}.npy')
    if args.data.standard_normalize:
        normalise_params = args.data.normalise_params if 'normalise_params' in args.data else {}
        standard_normalizer = StandardScaler(**normalise_params)
        data = standard_normalizer.fit_transform(data)
    data_train, data_test = train_test_split(data,
                                             test_size=args.data.test_ratio,
                                             random_state=args.data.split_seed)
    train_df = pd.DataFrame(data_train, columns=dag.var_names)
    test_df = pd.DataFrame(data_test, columns=dag.var_names)

    # Risk functions are looked up by name in sklearn.metrics. They do not depend
    # on the target variable, so resolve them once (a bad name fails before a run exists)
    metrics_module = importlib.import_module('sklearn.metrics')
    risks = {risk: getattr(metrics_module, risk) for risk in args.predictors.risks}

    # The context manager ends the run even when an exception is raised, so a
    # failed job does not leave an active run behind for the next one
    with mlflow.start_run():
        mlflow.log_params(flatten_dict(args))
        mlflow.log_param('data_generator/dag/n', len(var_names))
        mlflow.log_param('data_generator/dag/m', int(adjacency_matrix.sum()))
        mlflow.log_param('data/n_train', len(train_df))
        mlflow.log_param('data/n_test', len(test_df))

        # Saving artifacts
        artifact_uri = mlflow.get_artifact_uri()
        train_df.to_csv(
            hydra.utils.to_absolute_path(f'{artifact_uri}/train.csv'),
            index=False)
        test_df.to_csv(
            hydra.utils.to_absolute_path(f'{artifact_uri}/test.csv'),
            index=False)
        dag.plot_dag()
        plt.savefig(hydra.utils.to_absolute_path(f'{artifact_uri}/dag.png'))
        # Close saved figures so they neither accumulate nor get drawn over later
        plt.close('all')

        mlflow.log_param('features_sequence', str(list(dag.var_names)))

        for var_ind, target_var in enumerate(dag.var_names):

            var_results = {}

            # Considering all the variables for input
            input_vars = [var for var in dag.var_names if var != target_var]
            y_train, X_train = train_df.loc[:, target_var], train_df.loc[:, input_vars]
            y_test, X_test = test_df.loc[:, target_var], test_df.loc[:, input_vars]

            # Fitting predictive models and evaluating their test risks
            models = {}
            for pred_model in args.predictors.pred_models:
                logger.info(
                    f'Fitting {pred_model._target_} for target = {target_var} and inputs {input_vars}'
                )
                model = instantiate(pred_model)
                model.fit(X_train.values, y_train.values)
                y_pred = model.predict(X_test.values)
                models[pred_model._target_] = model
                for risk, risk_func in risks.items():
                    var_results[f'test_{risk}_{pred_model._target_}'] = risk_func(
                        y_test.values, y_pred)

            # =================== Global SAGE ===================
            logger.info(f'Analysing the importance of features: {input_vars}')

            sampler = instantiate(args.estimator.sampler,
                                  X_train=X_train,
                                  fit_method=args.estimator.fit_method,
                                  fit_params=args.estimator.fit_params)

            log_lik = []
            sage_explainer = explainer.Explainer(None,
                                                 input_vars,
                                                 X_train,
                                                 sampler=sampler,
                                                 loss=None)
            # Generating the same orderings across all the models and losses.
            # Seeding the global generator also makes any later np.random draws
            # in this iteration deterministic.
            np.random.seed(args.exp.sage.orderings_seed)
            fixed_orderings = [
                np.random.permutation(input_vars)
                for _ in range(args.exp.sage.nr_orderings)
            ]

            for model_name, model in models.items():
                for risk, risk_func in risks.items():
                    # The explainer is reused; only the explained model changes
                    sage_explainer.model = model.predict
                    explanation, test_log_lik = sage_explainer.sage(
                        X_test,
                        y_test,
                        loss=risk_func,
                        fixed_orderings=fixed_orderings,
                        nr_runs=args.exp.sage.nr_runs,
                        return_test_log_lik=True,
                        nr_resample_marginalize=args.exp.sage.nr_resample_marginalize)
                    log_lik.extend(test_log_lik)
                    fi = explanation.fi_vals().mean()

                    for input_var in input_vars:
                        var_results[f'sage/mean_{risk}_{model_name}_{input_var}'] = fi[input_var]

            # np.mean([]) would emit a RuntimeWarning and return NaN
            var_results['sage/mean_log_lik'] = np.mean(log_lik) if log_lik else np.nan
            var_results['sage/num_fitted_estimators'] = len(log_lik)

            mlflow.log_metrics(var_results, step=var_ind)
Пример #3
0
def main(args: DictConfig):
    """Census (UCI Adult) experiment: PFI, CFI and RFI of the model inputs.

    Loads the Adult dataset via ``shap.datasets.adult`` and stores the target
    in column ``'Salary'``. It optionally discretizes the configured columns
    and standard-normalizes the continuous inputs. It then fits the configured
    predictors on the inputs, excluding the sensitive variables when
    ``args.exp.exclude_sensetive`` is set. For every such input it logs to
    MLflow:

    * permutation feature importance (``pfi``, empty conditioning set),
    * conditional feature importance (``cfi``, conditioned on all other model
      inputs),
    * relative feature importance (``rfi``, conditioned on the sensitive
      variables),

    together with the sampler goodness-of-fit (mean test log-likelihood) of the
    CFI and RFI conditionals.

    Args:
        args: Hydra/OmegaConf experiment config. It is made non-struct and
            ``estimator.defaults`` is added to it in place.

    Returns:
        None. Returns early, without starting a run, when
        ``check_existing_hash`` reports an existing run for this config.
    """
    exp_name = 'census'

    # Non-strict access to fields
    OmegaConf.set_struct(args, False)

    # Dataset loading
    data_df, y = shap.datasets.adult()
    data_df['Salary'] = y

    # Optional binning of the configured columns (e.g. Capital Gain) into 50
    # equal-width ordinal bins
    if args.exp.discretize:
        # pandas >= 2.0 rejects sets as indexers, hence a sorted list. Plain column
        # assignment replaces the columns, so they become int-typed (and are thus
        # detected as categorical below); .loc[:, cols] = ... would keep float dtype.
        discretized_vars = sorted(set(list(args.exp.discretize)))
        discretizer = KBinsDiscretizer(n_bins=50,
                                       encode='ordinal',
                                       strategy='uniform')
        data_df[discretized_vars] = discretizer.fit_transform(
            data_df[discretized_vars].values).astype(int)

    target_var = {'Salary'}
    all_inputs_vars = set(data_df.columns) - target_var
    sensetive_vars = set(list(args.exp.sensetive_vars))
    if args.exp.exclude_sensetive:
        wo_sens_inputs_vars = all_inputs_vars - sensetive_vars
    else:
        wo_sens_inputs_vars = all_inputs_vars
    # Non-float columns are treated as categorical. The target is excluded: it may
    # be non-float itself, but it is never a model or sampler input.
    cat_vars = set(data_df.select_dtypes(exclude=[np.floating]).columns.values) - target_var
    cont_vars = all_inputs_vars - cat_vars
    logger.info(
        f'Target var: {target_var}, all_inputs: {all_inputs_vars}, sensetive_vars: {sensetive_vars}, '
        f'cat_vars: {cat_vars}')

    # Fixed column orders: sets are not valid pandas indexers, and their iteration
    # order changes between interpreter runs (string hash randomization)
    target_cols = sorted(target_var)
    all_inputs_list = sorted(all_inputs_vars)
    wo_sens_inputs_vars_list = sorted(wo_sens_inputs_vars)
    cont_list = sorted(cont_vars)

    # StandardScaler cannot be fitted on zero columns, hence the emptiness check
    if args.data.standard_normalize and cont_list:
        standard_normalizer = StandardScaler()
        data_df[cont_list] = standard_normalizer.fit_transform(
            data_df[cont_list].values)

    train_df, test_df = train_test_split(data_df,
                                         test_size=args.data.test_ratio,
                                         random_state=args.data.split_seed)
    y_train, X_train, X_train_wo_sens = (train_df[target_cols],
                                         train_df[all_inputs_list],
                                         train_df[wo_sens_inputs_vars_list])
    y_test, X_test, X_test_wo_sens = (test_df[target_cols],
                                      test_df[all_inputs_list],
                                      test_df[wo_sens_inputs_vars_list])

    # Adding default estimator params, so that they become part of the logged config
    default_names, _, _, default_values, _, _, _ = \
        inspect.getfullargspec(instantiate(args.estimator, context_size=0).__class__.__init__)
    if default_values is not None:
        # Defaults belong to the last len(default_values) positional arguments
        first_default = len(default_names) - len(default_values)
        args.estimator['defaults'] = {
            n: str(v)
            for (n, v) in zip(default_names[first_default:], default_values)
        }
    logger.info(OmegaConf.to_yaml(args, resolve=True))

    # Experiment tracking
    mlflow.set_tracking_uri(args.exp.mlflow_uri)
    mlflow.set_experiment(exp_name)

    # Checking if run exist
    if check_existing_hash(args, exp_name):
        logger.info('Skipping existing run.')
        return
    logger.info('No runs found - perfoming one.')

    # Initialising risks: sklearn.metrics function + the model method producing its input
    metrics_module = importlib.import_module('sklearn.metrics')
    risks = {}
    pred_funcs = {}
    for risk in args.predictors.risks:
        risks[risk['name']] = getattr(metrics_module, risk['name'])
        pred_funcs[risk['name']] = risk['method']

    def make_pred_func(fitted_model, method_name):
        """Return ``fitted_model.<method_name>``; for 'predict_proba' keep only column 1.

        Model and method are bound eagerly, so the result does not depend on
        any loop variable that is rebound later.
        """
        method = getattr(fitted_model, method_name)
        if method_name == 'predict_proba':
            return lambda X: method(X)[:, 1]
        return method

    # The context manager ends the run even when an exception is raised, so a
    # failed job does not leave an active run behind for the next one
    with mlflow.start_run():
        mlflow.log_params(flatten_dict(args))
        mlflow.log_param('data/n_train', len(train_df))
        mlflow.log_param('data/n_test', len(test_df))

        # Saving artifacts
        artifact_uri = mlflow.get_artifact_uri()
        train_df.to_csv(
            hydra.utils.to_absolute_path(f'{artifact_uri}/train_df.csv'),
            index=False)
        test_df.to_csv(
            hydra.utils.to_absolute_path(f'{artifact_uri}/test_df.csv'),
            index=False)

        pred_results = {}

        # Categorical features among the actual model inputs (sensitive ones included
        # when they are not excluded), as column indices of X_train_wo_sens
        cat_vars_wo_sens = cat_vars & wo_sens_inputs_vars
        categorical_feature = search_nonsorted(wo_sens_inputs_vars_list,
                                               sorted(cat_vars_wo_sens))

        # Fitting predictive models
        models = {}
        models_pred_funcs = {}  # (model_name, risk) -> prediction function
        for pred_model in args.predictors.pred_models:
            model_name = pred_model._target_
            logger.info(
                f'Fitting {pred_model._target_} for target = {target_var} and inputs {wo_sens_inputs_vars} '
                f'(categorical {cat_vars_wo_sens})')
            model = instantiate(pred_model, categorical_feature=categorical_feature)
            model.fit(X_train_wo_sens, y_train)
            models[model_name] = model
            for risk, risk_func in risks.items():
                pred_func = make_pred_func(model, pred_funcs[risk])
                models_pred_funcs[(model_name, risk)] = pred_func
                y_pred = pred_func(X_test_wo_sens)
                pred_results[f'test_{risk}_{pred_model._target_}'] = risk_func(
                    y_test, y_pred)

        mlflow.log_metrics(pred_results, step=0)

        sampler = instantiate(args.estimator.sampler,
                              X_train=X_train,
                              fit_method=args.estimator.fit_method,
                              fit_params=args.estimator.fit_params,
                              cat_inputs=sorted(cat_vars))

        mlflow.log_param('features_sequence', str(wo_sens_inputs_vars_list))
        mlflow.log_param('exp/cat_vars', str(sorted(cat_vars)))

        for i, fsoi_var in enumerate(wo_sens_inputs_vars_list):
            logger.info(f'Analysing the importance of feature: {fsoi_var}')

            interpret_results = {}
            R_j = list(wo_sens_inputs_vars_list)

            # Permutation feature importance
            G_pfi, name_pfi = [], 'pfi'
            sampler.train([fsoi_var], G_pfi)

            # Conditional feature importance (all other model inputs)
            G_cfi, name_cfi = [var for var in wo_sens_inputs_vars_list if var != fsoi_var], 'cfi'
            estimator = sampler.train([fsoi_var], G_cfi)
            if estimator is not None:
                test_inputs = X_test[sampler._order_fset([fsoi_var])].to_numpy()
                test_context = X_test[sampler._order_fset(G_cfi)].to_numpy()
                interpret_results['cfi/gof/mean_log_lik'] = estimator.log_prob(
                    inputs=test_inputs, context=test_context).mean()
            else:
                interpret_results['cfi/gof/mean_log_lik'] = np.nan

            # Relative feature importance (sensetive ignored vars)
            G_rfi, name_rfi = sorted(sensetive_vars), 'rfi'
            estimator = sampler.train([fsoi_var], G_rfi)
            if estimator is not None:
                test_inputs = X_test[sampler._order_fset([fsoi_var])].to_numpy()
                test_context = X_test[sampler._order_fset(G_rfi)].to_numpy()
                interpret_results['rfi/gof/mean_log_lik'] = estimator.log_prob(
                    inputs=test_inputs, context=test_context).mean()
            else:
                interpret_results['rfi/gof/mean_log_lik'] = np.nan

            for model_name in models:
                for risk, risk_func in risks.items():
                    # fs_names must follow the column order of X_train
                    rfi_explainer = explainer.Explainer(
                        models_pred_funcs[(model_name, risk)], [fsoi_var],
                        X_train,
                        sampler=sampler,
                        loss=risk_func,
                        fs_names=all_inputs_list)
                    for G, name in zip([G_pfi, G_cfi, G_rfi],
                                       [name_pfi, name_cfi, name_rfi]):
                        explanation = rfi_explainer.rfi(
                            X_test, y_test, G, R_j, nr_runs=args.exp.rfi.nr_runs)
                        interpret_results[
                            f'{name}/mean_{risk}_{model_name}'] = np.abs(
                                explanation.fi_vals().values).mean()

            mlflow.log_metrics(interpret_results, step=i)
Пример #4
0
# ---- Summary of the fitted linear model ----------------------------------
logging.info('Linear Model')
logging.info(input_var_names)
logging.info(model.coef_)
logging.debug('This is a debugging message.')

# ---- Relative feature importance -----------------------------------------
# Conditioning set G = {feature 1}; features of interest = feature indices 0..3.
G = np.array([1])
fsoi = np.arange(4, dtype=np.int16)

samplers_classes = [CNFSampler, GaussianSampler]

# Compare the RFI bar plots produced with each conditional sampler.
for sampler_class in samplers_classes:

    # Fit the sampler of the features of interest given G on the training data.
    sampler = sampler_class(X_train)
    sampler.train(fsoi, G)

    rfi_explainer = explainer.Explainer(
        model.predict, fsoi, X_train,
        sampler=sampler, loss=mean_squared_error, fs_names=input_var_names,
    )
    explanation = rfi_explainer.rfi(X_test, y_test, G)
    explanation.barplot()

    plt.title(f'{sampler.__class__.__name__}. G = {G}, N = {len(X_train) + len(X_test)}')
    plt.show()
Пример #5
0
def main(args: DictConfig):
    """Relative feature importance (RFI) experiment on data sampled from a random SEM.

    Builds a random DAG from ``args.data_generator.dag`` and instantiates the
    configured structural equation model on it. It samples train and test
    data, then, for every variable taken in turn as the target, fits the
    configured predictors on all remaining variables. RFI is computed in two
    settings:

    1. ``mb``: conditioning set G = Markov blanket of the target, features of
       interest = the remaining inputs;
    2. ``non_mb``: G = inputs outside the Markov blanket, features of interest =
       the Markov blanket.

    Besides the sampler's mean test log-likelihood, conditional KL,
    Hellinger and JS goodness-of-fit metrics against the SEM are logged when
    applicable. Metrics are logged per target (``step=var_ind``) and as
    NaN-ignoring means at ``step=len(dag.var_names)``. The MLflow experiment
    is named after ``args.data_generator.sem_type``.

    Args:
        args: Hydra/OmegaConf experiment config. It is made non-struct and an
            ``estimator.defaults`` entry is added to it in place.

    Returns:
        None. Returns early, without starting a run, when
        ``check_existing_hash`` reports an existing run for this config.
    """
    # Non-strict access to fields
    OmegaConf.set_struct(args, False)

    # Adding default estimator params, so that they become part of the logged config
    default_names, _, _, default_values, _, _, _ = \
        inspect.getfullargspec(instantiate(args.estimator, context_size=0).__class__.__init__)
    if default_values is not None:
        # Defaults belong to the last len(default_values) positional arguments
        first_default = len(default_names) - len(default_values)
        args.estimator['defaults'] = {
            n: str(v) for (n, v) in zip(default_names[first_default:], default_values)
        }
        # 'cat_context' is kept out of the logged defaults; not every estimator has it
        args.estimator['defaults'].pop('cat_context', None)
    logger.info(OmegaConf.to_yaml(args, resolve=True))

    # Data generator init
    dag = DirectedAcyclicGraph.random_dag(**args.data_generator.dag)
    sem = instantiate(args.data_generator.sem, dag=dag)

    # Experiment tracking
    mlflow.set_tracking_uri(args.exp.mlflow_uri)
    mlflow.set_experiment(args.data_generator.sem_type)

    # Checking if run exist
    if check_existing_hash(args, args.data_generator.sem_type):
        logger.info('Skipping existing run.')
        return
    logger.info('No runs found - perfoming one.')

    # Risk functions are looked up by name in sklearn.metrics. They do not depend
    # on the target variable, so resolve them once (a bad name fails before a run exists)
    metrics_module = importlib.import_module('sklearn.metrics')
    risks = {risk: getattr(metrics_module, risk) for risk in args.predictors.risks}

    # The context manager ends the run even when an exception is raised, so a
    # failed job does not leave an active run behind for the next one
    with mlflow.start_run():
        mlflow.log_params(flatten_dict(args))

        # Generating Train-test dataframes
        train_df = pd.DataFrame(sem.sample(size=args.data.n_train, seed=args.data.train_seed).numpy(),
                                columns=dag.var_names)
        test_df = pd.DataFrame(sem.sample(size=args.data.n_test, seed=args.data.test_seed).numpy(),
                               columns=dag.var_names)

        # Saving artifacts
        artifact_uri = mlflow.get_artifact_uri()
        train_df.to_csv(hydra.utils.to_absolute_path(f'{artifact_uri}/train.csv'), index=False)
        test_df.to_csv(hydra.utils.to_absolute_path(f'{artifact_uri}/test.csv'), index=False)
        sem.dag.plot_dag()
        plt.savefig(hydra.utils.to_absolute_path(f'{artifact_uri}/dag.png'))
        # Close saved figures so they neither accumulate nor get drawn over later
        plt.close('all')
        if len(dag.var_names) <= 20:
            df = pd.concat([train_df, test_df], keys=['train', 'test']).reset_index().drop(columns=['level_1'])
            g = sns.pairplot(df, plot_kws={'alpha': 0.25}, hue='level_0')
            g.fig.suptitle(sem.__class__.__name__)
            plt.savefig(hydra.utils.to_absolute_path(f'{artifact_uri}/data.png'))
            plt.close('all')

        # metric name -> list of per-target values (NaN where a target lacks it)
        metrics = {}

        for var_ind, target_var in enumerate(dag.var_names):

            var_results = {}

            # Considering all the variables for input
            input_vars = [var for var in dag.var_names if var != target_var]
            y_train, X_train = train_df.loc[:, target_var], train_df.loc[:, input_vars]
            y_test, X_test = test_df.loc[:, target_var], test_df.loc[:, input_vars]

            # Fitting predictive models and evaluating their test risks
            models = {}
            for pred_model in args.predictors.pred_models:
                logger.info(f'Fitting {pred_model._target_} for target = {target_var} and inputs {input_vars}')
                model = instantiate(pred_model)
                model.fit(X_train.values, y_train.values)
                y_pred = model.predict(X_test.values)
                models[pred_model._target_] = model
                for risk, risk_func in risks.items():
                    var_results[f'test_{risk}_{pred_model._target_}'] = risk_func(y_test.values, y_pred)

            sampler = instantiate(args.estimator.sampler, X_train=X_train,
                                  fit_method=args.estimator.fit_method, fit_params=args.estimator.fit_params)

            # =================== Relative feature importance ===================
            markov_blanket = list(sem.get_markov_blanket(target_var))
            non_markov_blanket = [var for var in input_vars if var not in markov_blanket]
            rfi_settings = [
                # 1. G = MB(target_var), FoI = input_vars / MB(target_var)
                (markov_blanket, non_markov_blanket, 'mb'),
                # 2. G = input_vars / MB(target_var), FoI = MB(target_var)
                (non_markov_blanket, markov_blanket, 'non_mb'),
            ]

            for (G_vars, fsoi_vars, prefix) in rfi_settings:
                # G and the features of interest are passed to the sampler and the
                # explainer as variable names (not column indices)
                G = G_vars
                fsoi = fsoi_vars
                G_set = set(G_vars)

                # GoF diagnostics of one conditional estimator per feature of interest
                rfi_gof_metrics = {}
                for f_var in fsoi_vars:
                    estimator = sampler.train([f_var], G)

                    rfi_gof_results = {}
                    if estimator is not None:
                        test_inputs = X_test[sampler._order_fset([f_var])].to_numpy()
                        test_context = X_test[sampler._order_fset(G)].to_numpy()

                        rfi_gof_results[f'rfi/gof/{prefix}_mean_log_lik'] = \
                            estimator.log_prob(inputs=test_inputs, context=test_context).mean()

                        # Advanced conditional GoF metrics against the SEM. They are computed
                        # for linear Gaussian SEMs with any G ('arbitrary'), otherwise only
                        # when G contains f_var's Markov blanket ('all'); presumably these
                        # are the cases where the true conditional is tractable.
                        if isinstance(sem, LinearGaussianNoiseSEM):
                            cond_mode = 'arbitrary'
                        elif sem.get_markov_blanket(f_var).issubset(G_set):
                            cond_mode = 'all'
                        else:
                            cond_mode = None

                        if cond_mode is not None:
                            rfi_gof_results[f'rfi/gof/{prefix}_kld'] = \
                                conditional_kl_divergence(estimator, sem, f_var, G_vars, args.exp, cond_mode, test_df)
                            rfi_gof_results[f'rfi/gof/{prefix}_hd'] = \
                                conditional_hellinger_distance(estimator, sem, f_var, G_vars, args.exp, cond_mode,
                                                               test_df)
                            rfi_gof_results[f'rfi/gof/{prefix}_jsd'] = \
                                conditional_js_divergence(estimator, sem, f_var, G_vars, args.exp, cond_mode, test_df)

                    # Append this feature's values; NaN where a known metric is missing
                    rfi_gof_metrics = {k: rfi_gof_metrics.get(k, []) + [rfi_gof_results.get(k, np.nan)]
                                       for k in set(list(rfi_gof_metrics.keys()) + list(rfi_gof_results.keys()))}

                # Feature importance
                if len(fsoi) > 0:
                    var_results[f'rfi/{prefix}_cond_size'] = len(G_vars)

                    for model_name, model in models.items():
                        for risk, risk_func in risks.items():

                            rfi_explainer = explainer.Explainer(model.predict, fsoi, X_train, sampler=sampler,
                                                                loss=risk_func, fs_names=input_vars)
                            rfi_explanation = rfi_explainer.rfi(X_test, y_test, G, nr_runs=args.exp.rfi.nr_runs)
                            var_results[f'rfi/{prefix}_mean_rfi_{risk}_{model_name}'] = \
                                np.abs(rfi_explanation.fi_vals().values).mean()

                    # GoF is only meaningful with a non-empty conditioning set
                    var_results = {**var_results,
                                   **{k: np.nanmean(v) if len(G_vars) > 0 else np.nan
                                      for (k, v) in rfi_gof_metrics.items()}}

            mlflow.log_metrics(var_results, step=var_ind)

            metrics = {k: metrics.get(k, []) + [var_results.get(k, np.nan)]
                       for k in set(list(metrics.keys()) + list(var_results.keys()))}

        # Logging mean statistics (one step after the last target)
        mlflow.log_metrics({k: np.nanmean(v) for (k, v) in metrics.items()}, step=len(dag.var_names))