コード例 #1
0
def predict_time_series(
    day0: pd.Timestamp,
    dep_var: str,
    mr_model: MRBRT,
    dep_trans_out: Callable[[pd.Series], pd.Series],
    diff: bool,
) -> pd.Series:
    """Predict a daily series from a fitted MR-BRT time model.

    Predictions are made on the integer day grid ``t = 0 .. max(t)``, where the
    upper bound is read from the data the model was fitted on.

    Args:
        day0: Date corresponding to ``t == 0``.
        dep_var: Name of the returned series.
        mr_model: Fitted model whose ``data`` contains a ``t`` covariate.
        dep_trans_out: Transform applied to the prediction (after the
            cumulative sum when ``diff`` is True).
        diff: If True the model predicts first differences, which are
            cumulatively summed back before the output transform.

    Returns:
        Series named ``dep_var`` indexed by date (index named ``date``).
    """
    fit_data = mr_model.data.to_df()

    t = np.arange(0, fit_data['t'].max() + 1)
    grid = MRData()
    grid.load_df(pd.DataFrame({'t': t}), col_covs='t')
    pred_values = mr_model.predict(grid)
    if diff:
        # The model predicts day-over-day changes; accumulate them.
        pred_values = pred_values.cumsum()
    pred_values = dep_trans_out(pred_values)

    pred_df = pd.DataFrame({
        't': t,
        dep_var: pred_values,
    })
    # Vectorized day offsets instead of a per-row ``apply``.
    pred_df['date'] = day0 + pd.to_timedelta(pred_df['t'], unit='D')
    return pred_df.set_index('date')[dep_var]
コード例 #2
0
def test_remove_nan_in_covs(df, covs):
    # Blank the covariates of the rows labelled up to 0 (the first row,
    # assuming a default RangeIndex); loading should warn and drop it.
    df.loc[:0, covs] = np.nan

    mr_data = MRData()
    with pytest.warns(Warning):
        mr_data.load_df(df, col_obs='obs', col_obs_se='obs_se', col_covs=covs)

    expected_num_obs = df.shape[0] - 1
    assert mr_data.num_obs == expected_num_obs
コード例 #3
0
def test_covs(df, covs):
    mr_data = MRData()
    mr_data.load_df(df, col_obs='obs', col_obs_se='obs_se', col_covs=covs)

    # One covariate beyond those requested -- presumably the intercept that
    # MRData adds on load (TODO confirm).
    expected = (len(covs) if covs is not None else 0) + 1
    assert mr_data.num_covs == expected
コード例 #4
0
def data(df):
    """Load ``df`` into an ``MRData`` with three covariates and study ids.

    Adds a ``study_id`` column to ``df`` in place (assumes ``df`` has exactly
    five rows, matching the ids below).
    """
    df['study_id'] = np.array([0, 0, 1, 1, 2])
    cov_names = ['cov' + str(idx) for idx in range(3)]

    mr_data = MRData()
    mr_data.load_df(
        df,
        col_obs='obs',
        col_obs_se='obs_se',
        col_covs=cov_names,
        col_study_id='study_id',
    )
    return mr_data
コード例 #5
0
def test_obs(df, obs, obs_se):
    mr_data = MRData()
    mr_data.load_df(df, col_obs=obs, col_obs_se=obs_se,
                    col_covs=['cov0', 'cov1', 'cov2'])

    num_rows = df.shape[0]
    assert mr_data.obs.size == num_rows
    assert mr_data.obs_se.size == num_rows
    # Without an observation column every observation should be NaN.
    if obs is None:
        assert all(np.isnan(mr_data.obs))
コード例 #6
0
 def predict(
     self,
     data: MRData = None,
     slope_quantile: Dict[str, float] = None,
     ref_cov: Tuple[str, Any] = None,
 ):
     """Predict as the sum of the two stage models' predictions.

     Args:
         data: Data to predict on; defaults to ``self.data1``. NOTE: it is
             sorted in place via ``_sort_by_data_id`` before predicting.
         slope_quantile: Forwarded to ``self.model1.predict`` only.
         ref_cov: Forwarded to ``self.model1.predict`` only.

     Returns:
         ``self.model2.predict(data)`` plus the ``self.model1`` prediction.
     """
     target = self.data1 if data is None else data
     target._sort_by_data_id()
     first_stage = self.model1.predict(
         target, slope_quantile=slope_quantile, ref_cov=ref_cov
     )
     second_stage = self.model2.predict(target)
     return second_stage + first_stage
コード例 #7
0
ファイル: test_data.py プロジェクト: jiaweih/MRTool
def test_load_xr(xarray):
    def same_values(actual, expected):
        # Compare sorted values so the check does not depend on element order.
        return np.allclose(np.sort(actual), np.sort(expected))

    mr_data = MRData()
    mr_data.load_xr(xarray, var_obs='y', var_obs_se='y_se',
                    var_covs=['sdi'], coord_study_id='location_id')

    assert same_values(mr_data.obs, xarray['y'].data.flatten())
    assert same_values(mr_data.obs_se, xarray['y_se'].data.flatten())
    assert same_values(mr_data.covs['sdi'], xarray['sdi'].data.flatten())
    # The study ids should be exactly the location_id coordinate values.
    assert same_values(mr_data.studies, xarray.coords['location_id'])
コード例 #8
0
def test_is_empty(df):
    mr_data = MRData()
    # A freshly constructed object holds no data ...
    assert mr_data.is_empty()

    mr_data.load_df(df, col_obs='obs', col_obs_se='obs_se',
                    col_covs=['cov0', 'cov1', 'cov2'])
    # ... loading fills it ...
    assert not mr_data.is_empty()

    # ... and reset() empties it again.
    mr_data.reset()
    assert mr_data.is_empty()
コード例 #9
0
def model_intercept(data: pd.DataFrame,
                    dep_var: str,
                    prediction: pd.Series,
                    weight_data: pd.DataFrame = None,
                    dep_var_se: str = None,
                    dep_trans_in: Callable[[pd.Series],
                                           pd.Series] = lambda x: x,
                    dep_se_trans_in: Callable[[pd.Series],
                                              pd.Series] = lambda x: x,
                    dep_trans_out: Callable[[pd.Series],
                                            pd.Series] = lambda x: x,
                    verbose: bool = True):
    """Shift ``prediction`` by an intercept fitted to the residuals of ``data``.

    ``data`` and ``prediction`` are both mapped through ``dep_trans_in``, the
    residuals ``data - prediction`` (aligned on date) are fitted with an
    intercept-only MR-BRT model, and the fitted intercept is added to the
    transformed prediction.

    Args:
        data: Wide input reshaped with ``reshape_data_long``; not modified.
        dep_var: Dependent-variable column.
        prediction: Prediction to shift; presumably indexed by date so it
            aligns with ``data`` after ``set_index('date')`` (TODO confirm).
            Not modified.
        weight_data: Optional frame supplying ``dep_var_se``; its dates must
            match ``data``. Without it every residual has ``se = 1``.
        dep_var_se: Standard-error column in ``weight_data``.
        dep_trans_in: Transform applied to ``data[dep_var]`` and ``prediction``.
        dep_se_trans_in: Transform applied to the standard errors.
        dep_trans_out: Transform applied to the shifted prediction.
        verbose: Currently unused.

    Returns:
        ``dep_trans_out(dep_trans_in(prediction) + intercept)``.

    Raises:
        ValueError: if the dates of ``data`` and ``weight_data`` differ.
    """
    data = data.copy()
    data[dep_var] = dep_trans_in(data[dep_var])
    prediction = dep_trans_in(prediction)
    data = reshape_data_long(data, dep_var)
    if weight_data is not None:
        weight_data = reshape_data_long(weight_data, dep_var_se)
        if (data['date'] != weight_data['date']).any():
            raise ValueError(
                'Dates in `data` and `weight_data` not identical.')
        data['se'] = dep_se_trans_in(weight_data[dep_var_se])
    else:
        data['se'] = 1.
    # Residuals against the prediction, aligned on date.
    data = data.set_index('date').sort_index()
    data[dep_var] = data[dep_var] - prediction
    data = data.reset_index().dropna()
    data['intercept'] = 1

    mr_data = MRData()
    mr_data.load_df(
        data,
        col_obs=dep_var,
        col_obs_se='se',
        col_covs=['intercept'],
        col_study_id='date',
    )
    intercept_model = LinearCovModel(
        'intercept',
        use_re=False,
    )
    mr_model = MRBRT(mr_data, [intercept_model])
    mr_model.fit_model()

    # The only fixed effect is the (non-spline) intercept, so take its single beta.
    intercept = mr_model.beta_soln[0]

    # Out-of-place add: with the default identity ``dep_trans_in`` this is
    # still the caller's Series, and ``+=`` would mutate it.
    prediction = prediction + intercept
    return dep_trans_out(prediction)
コード例 #10
0
ファイル: test_covmodel.py プロジェクト: vishalbelsare/MRTool
def mrdata(seed=123):
    """Build a small random ``MRData``: 10 rows, covariates ``cov0``/``cov1``.

    Reseeds NumPy's global RNG with ``seed`` so the data is reproducible;
    study ids are drawn from ``{0, 1, 2}``.
    """
    np.random.seed(seed)
    n_rows = 10
    frame = pd.DataFrame({
        'obs': np.random.randn(n_rows),
        'obs_se': np.full(n_rows, 0.1),
        'cov0': np.ones(n_rows),
        'cov1': np.random.randn(n_rows),
        'study_id': np.random.choice(range(3), n_rows),
    })

    result = MRData()
    result.load_df(frame, col_obs='obs', col_obs_se='obs_se',
                   col_covs=['cov0', 'cov1'], col_study_id='study_id')
    return result
コード例 #11
0
 def get_pred_data(self) -> MRData:
     """Build prediction data that varies only the continuous covariate.

     ``self.cont_cov_name`` takes ``self.pred_cont``; every other covariate
     of ``self.model.data`` except ``intercept`` is set to zeros of length
     ``self.num_points`` (all share one zeros array).
     """
     excluded = [self.cont_cov_name, 'intercept']
     zeros = np.zeros(self.num_points)
     covs = {self.cont_cov_name: self.pred_cont}
     for name in self.model.data.covs:
         if name not in excluded:
             covs[name] = zeros
     return MRData(covs=covs)
コード例 #12
0
ファイル: continuous.py プロジェクト: rmbarber/mrtool
 def get_signal(self,
                alt_cov: List[np.ndarray],
                ref_cov: List[np.ndarray]) -> np.ndarray:
     """Evaluate the signal model at the given exposures.

     Args:
         alt_cov: Arrays matched by position to ``self.alt_cov_names``.
         ref_cov: Arrays matched by position to ``self.ref_cov_names``;
             a name shared with the alternatives overrides its value.

     Returns:
         ``self.signal_model.predict`` on an ``MRData`` of those covariates.
     """
     named_alt = {name: alt_cov[idx] for idx, name in enumerate(self.alt_cov_names)}
     named_ref = {name: ref_cov[idx] for idx, name in enumerate(self.ref_cov_names)}
     return self.signal_model.predict(MRData(covs={**named_alt, **named_ref}))
コード例 #13
0
def test_has_covs(df):
    mr_data = MRData()
    mr_data.load_df(df, col_obs='obs', col_obs_se='obs_se',
                    col_covs=['cov0', 'cov1', 'cov2'])

    # Loaded covariates are reported present, singly or together ...
    for present in (['cov0'], ['cov0', 'cov1']):
        assert mr_data.has_covs(present)
    # ... while one that was never loaded is not.
    assert not mr_data.has_covs(['cov3'])
コード例 #14
0
def test_normalize_covs(df, covs):
    mr_data = MRData()
    mr_data.load_df(df, col_obs='obs', col_obs_se='obs_se',
                    col_covs=['cov0', 'cov1', 'cov2'])

    # Covariates that were normalized must then report as normalized.
    mr_data.normalize_covs(covs)
    assert mr_data.is_cov_normalized(covs)
コード例 #15
0
def create_mr_data(model_data: pd.DataFrame,
                   dep_var: str,
                   dep_var_se: str,
                   fe_vars: List[str],
                   group_var: str,
                   pred: bool = False,
                   **kwargs):
    """Wrap columns of ``model_data`` in an ``MRData`` object.

    Args:
        model_data: Source frame.
        dep_var: Observation column (ignored when ``pred`` is True).
        dep_var_se: Observation standard-error column (ignored when ``pred``).
        fe_vars: Columns passed as covariates.
        group_var: Column passed as the study id.
        pred: Build prediction data (covariates and study ids only).
        **kwargs: Accepted for call compatibility; unused.

    Returns:
        The constructed ``MRData``.
    """
    mr_args = {}
    if not pred:
        mr_args['obs'] = model_data[dep_var].values
        mr_args['obs_se'] = model_data[dep_var_se].values
    mr_args['covs'] = {name: model_data[name].values for name in fe_vars}
    mr_args['study_id'] = model_data[group_var].values
    return MRData(**mr_args)
コード例 #16
0
def test_study_id(df, study_id):
    col_study_id = None if study_id is None else 'study_id'
    if col_study_id is not None:
        df['study_id'] = study_id

    mr_data = MRData()
    mr_data.load_df(df,
                    col_obs='obs',
                    col_obs_se='obs_se',
                    col_covs=['cov0', 'cov1', 'cov2'],
                    col_study_id=col_study_id)

    if col_study_id is not None:
        assert np.allclose(mr_data.study_id, np.array([0, 0, 1, 1, 2]))
        assert mr_data.num_studies == 3
        assert np.allclose(mr_data.studies, np.array([0, 1, 2]))
        assert np.allclose(mr_data.study_sizes, np.array([2, 2, 1]))
    else:
        # Without a study column every row falls into one 'Unknown' study.
        assert np.all(mr_data.study_id == 'Unknown')
        assert mr_data.num_studies == 1
        assert mr_data.studies[0] == 'Unknown'
コード例 #17
0
def test_assert_has_covs(df):
    mr_data = MRData()
    mr_data.load_df(df, col_obs='obs', col_obs_se='obs_se',
                    col_covs=['cov0', 'cov1', 'cov2'])

    # 'cov3' was never loaded, so the guard must reject it.
    with pytest.raises(ValueError):
        mr_data._assert_has_covs('cov3')
コード例 #18
0
    def __init__(self, t, y,
                 spline_options=None):
        """Set up an MR-BRT spline fit of ``y`` against ``t``.

        Args:
            t (np.ndarray): Independent variable.
            y (np.ndarray): Dependent variable.
            spline_options (dict | None, optional):
                Extra keyword arguments for the spline ``LinearCovModel``.
        """
        self.t = t
        self.y = y
        if spline_options is None:
            spline_options = {}
        self.spline_options = spline_options

        # Single study; observation standard errors are 1 / exp(y).
        frame = pd.DataFrame({
            'y': self.y,
            'y_se': 1.0 / np.exp(self.y),
            't': self.t,
            'study_id': 1,
        })
        data = MRData(df=frame, col_obs='y', col_obs_se='y_se',
                      col_covs=['t'], col_study_id='study_id',
                      add_intercept=True)

        # Intercept with a random effect whose gamma prior is bounded to
        # [0, 0] (presumably switching the random-effect variance off).
        intercept = LinearCovModel(alt_cov='intercept', use_re=True,
                                   prior_gamma_uniform=np.array([0.0, 0.0]),
                                   name='intercept')
        time_model = LinearCovModel(alt_cov='t', use_re=False, use_spline=True,
                                    **self.spline_options, name='time')

        self.mr_model = MRBRT(data, cov_models=[intercept, time_model])
        self.spline = time_model.create_spline(data)
        self.spline_coef = None
コード例 #19
0
 def create_design_mat_from_xarray(self, covs: List[xr.DataArray]) -> tuple:
     """Build a design matrix from a list of covariate ``DataArray``s.

     If a covariate carries a ``year_id`` coordinate, that coordinate is added
     as an extra covariate. It is renamed ``year_id_`` for the merge
     (presumably to avoid clashing with the coordinate) and restored to
     ``year_id`` by ``strip("_")``. NOTE: ``strip`` also removes leading or
     trailing underscores from every other covariate name.

     Args:
         covs: Covariate arrays to merge. The caller's list is not modified.

     Returns:
         Tuple ``(design_mat, coords, dims, shape)``: the result of
         ``self.create_design_mat`` on the flattened covariates, plus the
         coords, dims and shape of the first merged variable.
     """
     var_coord = "variable"
     # Work on a copy: appending to the caller's list would leak the extra
     # year_id array back to them and grow the list on every call.
     covs = list(covs)
     for cov in covs:
         if "year_id" in cov.coords:
             year_id = cov.year_id
             year_id.name = "year_id_"
             covs.append(year_id)
             break
     da = xr.merge(covs).to_array()
     data = MRData(covs={
         cov.strip("_"): da.values[i].ravel()
         for i, cov in enumerate(da.coords[var_coord].values)
     })
     del da.coords[var_coord]
     return self.create_design_mat(data), da[0].coords, da[0].dims, da[0].shape
コード例 #20
0
ファイル: test_data.py プロジェクト: jiaweih/MRTool
def test_get_covs(df):
    names = ('cov0', 'cov1', 'cov2')
    mr_data = MRData()
    mr_data.load_df(df, col_obs='obs', col_obs_se='obs_se',
                    col_covs=list(names))

    # A single name yields a one-column matrix ...
    for name in names:
        expected_column = df[name].to_numpy()[:, None]
        assert np.allclose(mr_data.get_covs(name), expected_column)

    # ... and a list of names yields the matching multi-column matrix.
    combined = mr_data.get_covs(list(names))
    assert np.allclose(combined, df[list(names)].to_numpy())
コード例 #21
0
ファイル: continuous.py プロジェクト: rmbarber/mrtool
 def get_pred_data(self, num_points: int = 100) -> MRData:
     """Build prediction data for the final model along an exposure grid.

     Args:
         num_points: Grid size passed to ``get_pred_exposures``; ``-1``
             uses ``self.ref_exposures`` instead.

     Returns:
         ``MRData`` with ``signal`` (plus ``linear = alt - ref`` when
         ``self.j_shaped``), evaluated against the reference exposure
         ``self.exposure_lend``; every other final-model covariate is zero.
     """
     if num_points != -1:
         alt_cov = self.get_pred_exposures(num_points=num_points)
     else:
         alt_cov = self.ref_exposures
     n_grid = alt_cov.size
     ref_cov = np.repeat(self.exposure_lend, n_grid)
     zero_cov = np.zeros(n_grid)
     signal = self.get_signal(
         alt_cov=[alt_cov for _name in self.alt_cov_names],
         ref_cov=[ref_cov for _name in self.ref_cov_names],
     )

     other_covs = {}
     for name in self.final_model.data.covs:
         if name not in ('signal', 'linear'):
             other_covs[name] = zero_cov

     covs = {'signal': signal}
     if self.j_shaped:
         covs['linear'] = alt_cov - ref_cov
     covs.update(other_covs)
     return MRData(covs=covs)
コード例 #22
0
def test_data_id(df, study_id):
    col_study_id = None if study_id is None else 'study_id'
    if col_study_id is not None:
        df['study_id'] = study_id

    mr_data = MRData()
    mr_data.load_df(df,
                    col_obs='obs',
                    col_obs_se='obs_se',
                    col_covs=['cov0', 'cov1', 'cov2'],
                    col_study_id=col_study_id)

    # Sorting by data id should restore the DataFrame's row order.
    mr_data._sort_by_data_id()
    for column in ('obs', 'obs_se'):
        assert np.allclose(getattr(mr_data, column), df[column])
    for cov_name in [f'cov{idx}' for idx in range(3)]:
        assert np.allclose(mr_data.covs[cov_name], df[cov_name])
コード例 #23
0
    def __init__(self,
                 data: pd.DataFrame,
                 dep_var: str,
                 spline_var: str,
                 indep_vars: List[str],
                 n_i_knots: int,
                 ensemble_knots: np.array = None,
                 spline_options: Dict = None,
                 observed_var: str = None,
                 pseudo_se_multiplier: float = 1.,
                 se_default: float = 1.,
                 log: bool = True,
                 verbose: bool = True):
        """Set up an MR-BeRT spline-ensemble model over ``spline_var``.

        Args:
            data: Model data with ``dep_var``, ``spline_var``, ``obs_se`` and
                (if given) ``observed_var``. Copied, not modified.
            dep_var: Observation column.
            spline_var: Column modelled with the ensemble spline.
            indep_vars: Other covariates; only ``'intercept'`` is supported.
            n_i_knots: Knot count passed to ``get_ensemble_knots`` when
                ``ensemble_knots`` is not given.
            ensemble_knots: Pre-computed knot sets; the first one seeds the
                spline covariate model.
            spline_options: Extra keyword arguments for the spline
                ``LinearCovModel``; must not contain ``spline_knots``.
                ``None`` (default) means no extra options.
            observed_var: Boolean column flagging observed rows; unobserved
                rows get ``obs_se`` multiplied by ``pseudo_se_multiplier``.
                If omitted every row is treated as observed.
            pseudo_se_multiplier: See ``observed_var``.
            se_default: Currently unused.
            log: Bound the intercept's uniform beta prior to ``[-inf, 0]``
                if True, otherwise to ``[0, inf]``.
            verbose: Emit debug logging.

        Raises:
            ValueError: if ``observed_var`` is not boolean, ``indep_vars``
                contains anything besides ``'intercept'``, or
                ``spline_options`` specifies ``spline_knots``.
        """
        # Avoid a shared mutable default argument.
        if spline_options is None:
            spline_options = {}

        # set up model data
        if verbose:
            logger.debug('Setting up model data.')
        data = data.copy()
        if observed_var:
            if data[observed_var].dtype != 'bool':
                raise ValueError(
                    f'Observed variable ({observed_var}) is not boolean.')
            data.loc[~data[observed_var], 'obs_se'] *= pseudo_se_multiplier
        else:
            observed_var = 'observed'
            data[observed_var] = True

        # create mrbrt object (single study)
        data['study_id'] = 1
        if verbose:
            logger.debug('Building MRData.')
        mr_data = MRData(df=data,
                         col_obs=dep_var,
                         col_obs_se='obs_se',
                         col_covs=indep_vars + [spline_var],
                         col_study_id='study_id')
        self.data = data

        # cov models
        if verbose:
            logger.debug('Making covariate models.')
        cov_models = []
        if 'intercept' in indep_vars:
            if log:
                prior_beta_uniform = {
                    'prior_beta_uniform': np.array([-np.inf, 0.])
                }
            else:
                prior_beta_uniform = {
                    'prior_beta_uniform': np.array([0., np.inf])
                }
            cov_models += [
                LinearCovModel(alt_cov='intercept',
                               use_re=True,
                               prior_gamma_uniform=np.array([0., 0.]),
                               name='intercept',
                               **prior_beta_uniform)
            ]
        bad_vars = [i for i in indep_vars if i != 'intercept']
        if bad_vars:
            raise ValueError(
                f"Unsupported independent variable(s) entered: {'; '.join(bad_vars)}"
            )

        # get random knot placement
        if verbose:
            logger.debug('Getting random knot placement.')
        if 'spline_knots' in spline_options:
            raise ValueError(
                'Using random spline, do not manually specify knots.')
        if ensemble_knots is None:
            ensemble_knots = self.get_ensemble_knots(n_i_knots,
                                                     data[spline_var].values,
                                                     data[observed_var].values,
                                                     spline_options)

        # spline cov model
        if verbose:
            logger.debug('Setting up spline covariate model.')
        spline_model = LinearCovModel(alt_cov=spline_var,
                                      use_re=False,
                                      use_spline=True,
                                      **spline_options,
                                      prior_spline_num_constraint_points=100,
                                      spline_knots=ensemble_knots[0],
                                      name=spline_var)

        # var names
        self.indep_vars = [i for i in indep_vars if i != 'intercept']
        self.spline_var = spline_var

        # model
        if verbose:
            logger.debug('Building MRBeRT model.')
        self.mr_model = MRBeRT(mr_data,
                               ensemble_cov_model=spline_model,
                               ensemble_knots=ensemble_knots,
                               cov_models=cov_models)
        self.submodel_fits = None
        self.coef_dicts = None
コード例 #24
0
    def __init__(self,
                 t,
                 y,
                 spline_options=None,
                 se_power=1.0,
                 space='ln daily',
                 max_iter=50):
        """Set up an MR-BRT spline fit of a daily series over time.

        Args:
            t (np.ndarray): Independent variable.
            y (np.ndarray): Dependent variable, assumed to be daily cases.
            spline_options (dict | None, optional):
                Extra keyword arguments for the spline ``LinearCovModel``.
            se_power (float):
                Exponent in [0, 1]; standard errors are
                ``1 / exp(y) ** se_power`` (unit errors when 0).
            space (str):
                One of 'daily', 'ln daily', 'cumul', 'ln cumul'. The
                cumulative spaces cumsum ``y``; the log spaces keep only the
                positive points and take their log.
            max_iter (int):
                Maximum number of iterations (stored on ``self.max_iter``).
        """
        self.space = space
        assert self.space in ['daily', 'ln daily', 'cumul', 'ln cumul'], \
            "spline_space must be one of 'daily', 'ln daily', 'cumul', 'ln cumul' space."

        if self.space in ('ln daily', 'ln cumul'):
            series = np.cumsum(y) if self.space == 'ln cumul' else y
            keep = series > 0.0
            self.t = t[keep]
            self.y = np.log(series[keep])
        elif self.space == 'daily':
            self.t = t
            self.y = y
        else:  # 'cumul'
            self.t = t
            self.y = np.cumsum(y)

        self.spline_options = spline_options if spline_options is not None else {}
        self.se_power = se_power

        assert 0 <= self.se_power <= 1, "spline se_power has to be between 0 and 1."
        y_se = (np.ones(self.t.size) if self.se_power == 0
                else 1.0 / np.exp(self.y)**self.se_power)

        # create mrbrt object (single study)
        frame = pd.DataFrame({
            'y': self.y,
            'y_se': y_se,
            't': self.t,
            'study_id': 1,
        })
        data = MRData(df=frame, col_obs='y', col_obs_se='y_se',
                      col_covs=['t'], col_study_id='study_id',
                      add_intercept=True)

        # Intercept with a random effect whose gamma prior is bounded to [0, 0].
        intercept = LinearCovModel(alt_cov='intercept', use_re=True,
                                   prior_gamma_uniform=np.array([0.0, 0.0]),
                                   name='intercept')
        time_model = LinearCovModel(alt_cov='t', use_re=False, use_spline=True,
                                    **self.spline_options, name='time')

        self.mr_model = MRBRT(data, cov_models=[intercept, time_model])
        self.spline = time_model.create_spline(data)
        self.spline_coef = None
        self.max_iter = max_iter
コード例 #25
0
def estimate_time_series(
    data: pd.DataFrame,
    spline_options: Dict,
    n_knots: int,
    dep_var: str,
    dep_trans_in: Callable[[pd.Series], pd.Series] = lambda x: x,
    weight_data: pd.DataFrame = None,
    dep_var_se: str = None,
    dep_se_trans_in: Callable[[pd.Series], pd.Series] = lambda x: x,
    diff: bool = False,
    num_submodels: int = 25,
    single_random_knot: bool = False,
    min_interval_days: int = 7,
    dep_trans_out: Callable[[pd.Series], pd.Series] = lambda x: x,
    split_l_interval: bool = False,
    split_r_interval: bool = False,
    verbose: bool = False,
) -> Tuple[pd.DataFrame, pd.Series, MRBeRT]:
    """Fit a spline (or spline ensemble) over time and predict a smooth series.

    Args:
        data: Input with a ``date`` column and ``dep_var``.
        spline_options: Extra keyword arguments for the spline ``LinearCovModel``.
        n_knots: Number of spline knots (before any interval splitting).
        dep_var: Dependent-variable column.
        dep_trans_in: Transform applied to ``dep_var`` before fitting.
        weight_data: Optional frame supplying ``dep_var_se``; its dates must
            match ``data``. Without it every observation has ``se = 1``.
        dep_var_se: Standard-error column in ``weight_data``.
        dep_se_trans_in: Transform applied to the standard errors.
        diff: Fit on first differences (the day0->day1 difference is
            dropped); predictions are cumulated back.
        num_submodels: >1 fits an ``MRBeRT`` ensemble over sampled knots;
            otherwise a single ``MRBRT`` is fitted.
        single_random_knot: With one submodel, sample one random knot set
            instead of evenly spaced knots on [0, 1].
        min_interval_days: Minimum knot spacing in days, converted to the
            [0, 1] knot scale by dividing by the maximum ``t``.
        dep_trans_out: Transform applied to the prediction.
        split_l_interval: Add a knot at the midpoint of the first interval.
        split_r_interval: Add a knot at the midpoint of the last interval.
        verbose: Log progress.

    Returns:
        ``(data, smooth_data, mr_model)``: the fitted observations (``y``,
        ``se`` indexed by date), the date-indexed prediction, and the fitted
        model (``MRBeRT`` or ``MRBRT`` depending on ``num_submodels``).

    Raises:
        ValueError: if ``weight_data`` dates differ from ``data``, or an
            interval split is requested together with an ensemble.
    """
    if verbose: logger.info('Formatting data.')
    data = data.copy()
    data[dep_var] = dep_trans_in(data[dep_var])
    if diff:
        if verbose:
            logger.info(
                'For diff model, drop day1 (i.e., if day0 is > 0, day0->day1 diff would be hugely negative).'
            )
        data[dep_var] = data[dep_var].diff()
        # Differencing the (already differenced) column is NaN in its first
        # two rows, so this mask also blanks row 1 -- the day0->day1 diff.
        # Assumes one row per date in date order (TODO confirm).
        data[dep_var] = data[dep_var][data[dep_var].diff().notnull()]
    if data[[dep_var]].shape[1] > 1:
        reshape = True
        data = reshape_data_long(data, dep_var)
        if weight_data is not None:
            weight_data = reshape_data_long(weight_data, dep_var_se)
    else:
        reshape = False
    if weight_data is not None:
        if (data['date'] != weight_data['date']).any():
            raise ValueError(
                'Dates in `data` and `weight_data` not identical.')
        data['se'] = dep_se_trans_in(weight_data[dep_var_se])
    else:
        data['se'] = 1.
    data = data.rename(columns={dep_var: 'y'})
    day0 = data['date'].min()
    keep_vars = ['date', 'y', 'se']
    data = data.loc[:, keep_vars]
    start_len = len(data)
    data = data.dropna()
    end_len = len(data)
    if start_len != end_len and not reshape:
        if verbose: logger.debug('NAs in data')
    data['t'] = (data['date'] - day0).dt.days

    col_args = {
        'col_obs': 'y',
        'col_obs_se': 'se',
        'col_covs': ['t'],
        #'col_study_id':'date',
    }
    if verbose: logger.info('Getting base knots.')
    min_interval = min_interval_days / data['t'].max()
    if num_submodels == 1 and single_random_knot:
        spline_knots = get_ensemble_knots(n_knots, min_interval, 1)[0]
    else:
        spline_knots = np.linspace(0., 1., n_knots)

    if split_l_interval or split_r_interval:
        if num_submodels > 1:
            raise ValueError(
                'Would need to set up functionality to split segments for ensemble.'
            )
        if split_l_interval:
            n_knots += 1
            # Insert *after* the first knot so the knots stay sorted and the
            # domain's lower bound remains spline_knots[0] (index 0 produced
            # [mid, k0, k1, ...]).
            spline_knots = np.insert(spline_knots, 1, spline_knots[:2].mean())
        if split_r_interval:
            n_knots += 1
            # Insert before the last knot (mirror image of the left split).
            spline_knots = np.insert(spline_knots, -1,
                                     spline_knots[-2:].mean())

    if verbose: logger.info('Creating model data.')
    mr_data = MRData()
    mr_data.load_df(data, **col_args)
    spline_model = LinearCovModel('t',
                                  use_re=False,
                                  use_spline=True,
                                  use_spline_intercept=True,
                                  spline_knots=spline_knots,
                                  **spline_options)
    if num_submodels > 1:
        if verbose: logger.info('Sampling knots.')
        ensemble_knots = get_ensemble_knots(n_knots, min_interval,
                                            num_submodels)

        if verbose: logger.info('Initializing model.')
        mr_model = MRBeRT(mr_data, spline_model, ensemble_knots)
    else:
        if verbose: logger.info('Initializing model.')
        mr_model = MRBRT(mr_data, [spline_model])

    if verbose: logger.info('Fitting model.')
    mr_model.fit_model()

    if num_submodels > 1:
        if verbose: logger.info('Scoring submodels.')
        mr_model.score_model()

    data = data.set_index('date')[['y', 'se']]

    if verbose: logger.info('Making prediction.')
    smooth_data = predict_time_series(
        day0=day0,
        dep_var=dep_var,
        mr_model=mr_model,
        dep_trans_out=dep_trans_out,
        diff=diff,
    )

    return data, smooth_data, mr_model
コード例 #26
0
def test_assert_not_empty():
    # A freshly constructed MRData holds no data, so the guard must raise.
    empty = MRData()
    with pytest.raises(ValueError):
        empty._assert_not_empty()
コード例 #27
0
ファイル: plots.py プロジェクト: rmbarber/mrtool
def plot_risk_function(mrbrt, pair, beta_samples, gamma_samples, alt_cov_names=None,
    ref_cov_names=None, continuous_variables=None, plot_note=None, plots_dir=None,
    write_file=False):
    """Plot predicted relative risk.

    The exposure grid spans the knot range of the ensemble spline in the first
    submodel, with the reference exposures fixed at the first knot. Fixed-effect
    draws are trimmed by 1% at each tail (ranked at the last grid point) and
    refilled by resampling, so the draw count is unchanged.

    Args:
        mrbrt (mrtool.MRBRT):
            MRBeRT object.
        pair (str):
            risk_outcome pair. eg. 'redmeat_colorectal'
        beta_samples (np.ndarray):
            Beta samples generated using `sample_soln` function in MRBRT
        gamma_samples (np.ndarray):
            Gamma samples generated using `sample_soln` function in MRBRT
        alt_cov_names (List[str], optional):
            Name of the alternative exposures, if `None` use `['b_0', 'b_1']`.
            Default to `None`.
        ref_cov_names (List[str], optional):
            Name of the reference exposures, if `None` use `['a_0', 'a_1']`.
            Default to `None`.
        continuous_variables (List[str], optional):
            Continuous covariate names; each is fixed at its median in the
            model data. Default to `None` (no continuous covariates).
        plot_note (str):
            The notes intended to be written on the title.
        plots_dir (str):
            Directory where to save the plot.
        write_file (bool):
            Specify `True` if the plot is expected to be saved on disk.
            If True, `plots_dir` should be specified too.
    """
    continuous_variables = [] if continuous_variables is None else continuous_variables
    data_df = mrbrt.data.to_df()
    sub = mrbrt.sub_models[0]
    knots = sub.get_cov_model(mrbrt.ensemble_cov_model_name).spline.knots
    min_cov = knots[0]
    max_cov = knots[-1]
    dose_grid = np.linspace(min_cov, max_cov)
    col_covs = sub.cov_names
    pred_df = pd.DataFrame(dict(zip(col_covs, np.zeros(len(col_covs)))),
        index=np.arange(len(dose_grid)))

    alt_cov_names = ['b_0', 'b_1'] if alt_cov_names is None else alt_cov_names
    ref_cov_names = ['a_0', 'a_1'] if ref_cov_names is None else ref_cov_names
    pred_df['intercept'] = 1
    pred_df[alt_cov_names[0]] = dose_grid
    pred_df[alt_cov_names[1]] = dose_grid
    pred_df[ref_cov_names[0]] = knots[0]
    pred_df[ref_cov_names[1]] = knots[0]

    # if it's continuous variables, take median
    for var in continuous_variables:
        pred_df[var] = np.median(data_df[var])

    pred_data = MRData()
    pred_data.load_df(pred_df, col_covs=col_covs)

    y_draws_fe = mrbrt.create_draws(pred_data, beta_samples, gamma_samples, random_study=False)

    num_samples = y_draws_fe.shape[1]
    n_trim = int(num_samples * 0.01)
    sort_index = np.argsort(y_draws_fe[-1])
    # Explicit end index: with fewer than 100 draws n_trim is 0 and the old
    # `[0:-0]` slice was empty, which made `np.random.choice(0, ...)` raise.
    trimmed_draws = y_draws_fe[:, sort_index[n_trim:num_samples - n_trim]]
    patch_index = np.random.choice(trimmed_draws.shape[1],
        y_draws_fe.shape[1] - trimmed_draws.shape[1], replace=True)
    y_draws_fe = np.hstack((trimmed_draws, trimmed_draws[:, patch_index]))

    y_mean_fe = np.mean(y_draws_fe, axis=1)
    y_lower_fe = np.percentile(y_draws_fe, 2.5, axis=1)
    y_upper_fe = np.percentile(y_draws_fe, 97.5, axis=1)

    plt.rcParams['axes.edgecolor'] = '0.15'
    plt.rcParams['axes.linewidth'] = 0.5

    plt.plot(dose_grid, np.exp(y_lower_fe), c='gray')
    plt.plot(dose_grid, np.exp(y_upper_fe), c='gray')
    plt.plot(dose_grid, np.exp(y_mean_fe), c='red')
    # np.ptp (function form): the ndarray.ptp method was removed in NumPy 2.0.
    mean_range = np.ptp(np.exp(y_mean_fe))
    plt.ylim([np.exp(y_lower_fe).min() - mean_range*0.1,
              np.exp(y_upper_fe).max() + mean_range*0.1])
    plt.ylabel('RR', fontsize=10)
    plt.xlabel("Exposure", fontsize=10)

    if plot_note is not None:
        plt.title(plot_note)

    # save plot
    if write_file:
        assert plots_dir is not None, "plots_dir is not specified!"
        outfile = os.path.join(plots_dir, f'{pair}_risk_function.pdf')
        plt.savefig(outfile, bbox_inches='tight')
        print(f"Risk function plot saved at {outfile}")
    else:
        plt.show()
    plt.close()