Ejemplo n.º 1
0
 def test_error_no_fit(self, data):
     """Calling ``plot`` on a SurvivalGFormula that was never fit raises ValueError."""
     unfitted = SurvivalGFormula(data, idvar='id', exposure='art',
                                 outcome='d', time='t')
     # No outcome_model()/fit() call happens before plotting.
     with pytest.raises(ValueError):
         unfitted.plot(treatment='all')
Ejemplo n.º 2
0
def causal_check():
    """Display IPTW and SurvivalGFormula plots for visual inspection.

    Builds restricted spline covariates for 'cd40' and 'age0', fits a
    stabilized IPTW model and shows its love, KDE and boxplot figures.
    It then expands the sample data to one row per individual per time
    unit, fits SurvivalGFormula under treatment='all' and shows its plot
    twice (default style, then custom styling). Each figure is displayed
    with ``plt.show()``; nothing is returned.
    """
    # Check IPTW plots
    data = load_sample_data(False)
    spline_specs = ((['cd4_rs1', 'cd4_rs2'], 'cd40'),
                    (['age_rs1', 'age_rs2'], 'age0'))
    for spline_cols, source_col in spline_specs:
        data[spline_cols] = spline(data, source_col, n_knots=3, term=2,
                                   restricted=True)

    ipt = IPTW(data, treatment='art', stabilized=True)
    ipt.regression_models(
        'male + age0 + age_rs1 + age_rs2 + cd40 + cd4_rs1 + cd4_rs2 + dvl0')
    ipt.fit()

    ipt.plot_love()
    plt.tight_layout()
    plt.show()
    # Same order as before: kde(), kde(logit), boxplot(), boxplot(logit)
    for plot_fn in (ipt.plot_kde, ipt.plot_boxplot):
        for plot_kwargs in ({}, {'measure': 'logit'}):
            plot_fn(**plot_kwargs)
            plt.show()

    # Check SurvivalGFormula plots
    df = load_sample_data(False).drop(columns=['cd4_wk45'])
    df['t'] = np.round(df['t']).astype(int)
    # Repeat each row t times so every id gets one row per time unit;
    # an id's rows stay contiguous.
    df = pd.DataFrame(np.repeat(df.values, df['t'], axis=0),
                      columns=df.columns)
    df['t'] = df.groupby('id')['t'].cumcount() + 1
    # The event indicator 'd' is set only on the last row of each id with dead == 1.
    last_row_of_id = df['id'] != df['id'].shift(-1)
    df.loc[(df['dead'] == 1) & last_row_of_id, 'd'] = 1
    df['d'] = df['d'].fillna(0)
    df['t_sq'] = df['t'] ** 2
    df['t_cu'] = df['t'] ** 3

    sgf = SurvivalGFormula(df, idvar='id', exposure='art',
                           outcome='d', time='t')
    sgf.outcome_model(
        model='art + male + age0 + cd40 + dvl0 + t + t_sq + t_cu')
    sgf.fit(treatment='all')
    for plot_kwargs in ({}, {'c': 'r', 'linewidth': 3, 'alpha': 0.8}):
        sgf.plot(**plot_kwargs)
        plt.show()
Ejemplo n.º 3
0
 def test_error_continuous_a(self, data):
     """Using 'cd40' (presumably a continuous covariate) as exposure raises ValueError."""
     constructor_kwargs = dict(idvar='id', exposure='cd40',
                               outcome='d', time='t')
     with pytest.raises(ValueError):
         SurvivalGFormula(data, **constructor_kwargs)
Ejemplo n.º 4
0
    def test_treat_custom(self, data):
        """A custom treatment expression reproduces reference marginal outcomes."""
        sgf = SurvivalGFormula(data, idvar='id', exposure='art',
                               outcome='d', time='t')
        sgf.outcome_model(
            model='art + male + age0 + cd40 + dvl0 + t + t_sq + t_cu',
            print_results=False)
        # Treat only rows with age0 >= 25 and male == 0.
        sgf.fit(treatment="((g['age0']>=25) & (g['male']==0))")

        marginal = sgf.marginal_outcome
        # (positional index, reference value): first, last, and 10th entries
        reference = ((0, 0.015090), (-1, 0.137336), (9, 0.088886))
        for position, expected in reference:
            npt.assert_allclose(marginal.iloc[position], expected, atol=1e-5)
Ejemplo n.º 5
0
    def test_treat_natural(self, data):
        """The 'natural' treatment plan reproduces reference marginal outcomes."""
        sgf = SurvivalGFormula(data, idvar='id', exposure='art',
                               outcome='d', time='t')
        sgf.outcome_model(
            model='art + male + age0 + cd40 + dvl0 + t + t_sq + t_cu',
            print_results=False)
        sgf.fit(treatment='natural')

        marginal = sgf.marginal_outcome
        # (positional index, reference value): first, last, and 10th entries
        reference = ((0, 0.015030), (-1, 0.135694), (9, 0.088382))
        for position, expected in reference:
            npt.assert_allclose(marginal.iloc[position], expected, atol=1e-5)
Ejemplo n.º 6
0
def causal_check():
    """Run diagnostics and display plots for several causal estimators.

    Adds restricted spline covariates for 'cd40' and 'age0' to the sample
    data, then:

    * runs TimeFixedGFormula diagnostics,
    * fits IPTW (stabilized) and shows its plots and diagnostics,
    * fits AIPTW and TMLE and shows their diagnostics and plots,
    * expands the data to one row per individual per time unit and shows
      SurvivalGFormula plots for treatment='all'.

    Every figure is displayed with ``plt.show()``; nothing is returned.
    """
    covariates = 'male + age0 + age_rs1 + age_rs2 + cd40 + cd4_rs1 + cd4_rs2 + dvl0'
    outcome_formula = 'art + male + age0 + age_rs1 + age_rs2 + cd40 + cd4_rs1 + cd4_rs2 + dvl0'

    data = load_sample_data(False).drop(columns=['cd4_wk45'])
    spline_specs = ((['cd4_rs1', 'cd4_rs2'], 'cd40'),
                    (['age_rs1', 'age_rs2'], 'age0'))
    for spline_cols, source_col in spline_specs:
        data[spline_cols] = spline(data, source_col, n_knots=3, term=2,
                                   restricted=True)

    # Check TimeFixedGFormula diagnostics
    g = TimeFixedGFormula(data, exposure='art', outcome='dead')
    g.outcome_model(model=outcome_formula)
    g.run_diagnostics(decimal=3)

    # Check IPTW plots
    ipt = IPTW(data, treatment='art', outcome='dead')
    ipt.treatment_model(covariates, stabilized=True)
    ipt.marginal_structural_model('art')
    ipt.fit()
    ipt.plot_love()
    plt.tight_layout()
    plt.show()
    # Same order as before: kde(), kde(logit), boxplot(), boxplot(logit)
    for plot_fn in (ipt.plot_kde, ipt.plot_boxplot):
        for plot_kwargs in ({}, {'measure': 'logit'}):
            plot_fn(**plot_kwargs)
            plt.show()
    ipt.run_diagnostics()

    # Check AIPTW, then TMLE, diagnostics: both share the same call sequence
    for estimator_cls in (AIPTW, TMLE):
        estimator = estimator_cls(data, exposure='art', outcome='dead')
        estimator.exposure_model(covariates)
        estimator.outcome_model(outcome_formula)
        estimator.fit()
        estimator.run_diagnostics()
        for target in ('exposure', 'outcome'):
            estimator.plot_kde(to_plot=target)
            plt.show()
        estimator.plot_love()
        plt.show()

    # Check SurvivalGFormula plots
    df = load_sample_data(False).drop(columns=['cd4_wk45'])
    df['t'] = np.round(df['t']).astype(int)
    # Repeat each row t times so every id gets one row per time unit;
    # an id's rows stay contiguous.
    df = pd.DataFrame(np.repeat(df.values, df['t'], axis=0),
                      columns=df.columns)
    df['t'] = df.groupby('id')['t'].cumcount() + 1
    # The event indicator 'd' is set only on the last row of each id with dead == 1.
    last_row_of_id = df['id'] != df['id'].shift(-1)
    df.loc[(df['dead'] == 1) & last_row_of_id, 'd'] = 1
    df['d'] = df['d'].fillna(0)
    df['t_sq'] = df['t'] ** 2
    df['t_cu'] = df['t'] ** 3

    sgf = SurvivalGFormula(df, idvar='id', exposure='art',
                           outcome='d', time='t')
    sgf.outcome_model(
        model='art + male + age0 + cd40 + dvl0 + t + t_sq + t_cu')
    sgf.fit(treatment='all')
    for plot_kwargs in ({}, {'c': 'r', 'linewidth': 3, 'alpha': 0.8}):
        sgf.plot(**plot_kwargs)
        plt.show()
Ejemplo n.º 7
0
                                         term=2,
                                         restricted=True)
# Spline terms for baseline covariates 'cd40' and 'age0'
# (n_knots=3, term=2, restricted=True), two new columns each. They are
# used as flexible covariates in the outcome model below.
df[['cd4_rs1', 'cd4_rs2']] = spline(df,
                                    'cd40',
                                    n_knots=3,
                                    term=2,
                                    restricted=True)
df[['age_rs1', 'age_rs2']] = spline(df,
                                    'age0',
                                    n_knots=3,
                                    term=2,
                                    restricted=True)

# 'dead' is dropped from the data given to SurvivalGFormula; the outcome
# column is 'd'.
sgf = SurvivalGFormula(df.drop(columns=['dead']),
                       idvar='id',
                       exposure='art',
                       outcome='d',
                       time='t')
# NOTE(review): t_rs1..t_rs3 are presumably created by the truncated spline
# call on 't' just above this chunk - confirm.
sgf.outcome_model(model='art + male + age0 + age_rs1 + age_rs2 + cd40 + '
                  'cd4_rs1 + cd4_rs2 + dvl0 + t + t_rs1 + t_rs2 + t_rs3',
                  print_results=False)

# Fit and plot under treatment='all' (blue), then refit and plot under
# treatment='none' (red). Presumably both curves land on the current
# matplotlib axes, since no new figure is created in between - confirm.
sgf.fit(treatment='all')
sgf.plot(c='b')
sgf.fit(treatment='none')
sgf.plot(c='r')
plt.ylabel('Probability of death')
plt.tight_layout()
# Path is relative to the current working directory.
plt.savefig("../images/survival_gf_cif.png", format='png', dpi=300)
plt.close()