示例#1
0
 def test_match_r_ltmle3(self):
     """Check marginal outcomes for treatment plans [1, 1, 1] and [0, 0, 0].

     From the second time point on, each outcome model also adjusts for the
     previous exposure. Reference values are presumably from R's ``ltmle``
     (inferred from the test name) -- TODO confirm.
     """
     data = load_longitudinal_data()
     estimator = IterativeCondGFormula(
         data,
         exposures=['A1', 'A2', 'A3'],
         outcomes=['Y1', 'Y2', 'Y3'],
     )
     estimator.outcome_model(
         models=['A1 + L1', 'A2 + A1 + L2', 'A3 + A2 + L3'],
         print_results=False,
     )

     # (treatment plan, expected marginal outcome), fitted in this order on
     # the same estimator instance.
     cases = [
         ([1, 1, 1], 0.4334696),
         ([0, 0, 0], 0.6282985),
     ]
     for plan, expected in cases:
         estimator.fit(treatments=plan)
         npt.assert_allclose(estimator.marginal_outcome, expected, rtol=1e-5)
示例#2
0
    def test_match_r_custom_treatment(self):
        """Check marginal outcomes for the mixed plans [1, 0, 1] and [0, 1, 0].

        Each time point's outcome model adjusts only for the concurrent
        exposure and covariate. Reference values are presumably from R
        (inferred from the test name) -- TODO confirm.
        """
        exposure_cols = ['A1', 'A2', 'A3']
        outcome_cols = ['Y1', 'Y2', 'Y3']
        g_comp = IterativeCondGFormula(
            load_longitudinal_data(),
            exposures=exposure_cols,
            outcomes=outcome_cols,
        )
        g_comp.outcome_model(
            models=['A1 + L1', 'A2 + L2', 'A3 + L3'],
            print_results=False,
        )

        # Plan entries presumably map positionally onto A1, A2, A3.
        g_comp.fit(treatments=[1, 0, 1])
        npt.assert_allclose(g_comp.marginal_outcome, 0.4916937, rtol=1e-5)

        g_comp.fit(treatments=[0, 1, 0])
        npt.assert_allclose(g_comp.marginal_outcome, 0.5634683, rtol=1e-5)
示例#3
0
 def test_treatment_dimension_error2(self):
     """Passing a list of treatment lists to ``fit`` must raise ValueError."""
     longitudinal = load_longitudinal_data()
     model = IterativeCondGFormula(
         longitudinal,
         exposures=['A1', 'A2', 'A3'],
         outcomes=['Y1', 'Y2', 'Y3'],
     )
     model.outcome_model(models=['A1 + L1', 'A2 + L2', 'A3 + L3'], print_results=False)

     # Two plans nested in one argument, instead of a single flat plan.
     nested_plans = [[1, 1, 1], [0, 0, 0]]
     with pytest.raises(ValueError):
         model.fit(treatments=nested_plans)
示例#4
0
    def test_iterative_for_single_t(self, sim_t_fixed_data):
        """With one time point, the iterative estimator matches TimeFixedGFormula.

        Both estimators are fitted on ``sim_t_fixed_data`` with the same
        outcome-model formula. The marginal outcomes must then agree under
        ``npt.assert_allclose``'s default tolerance.
        """
        model_spec = 'A + W1_sq + W2 + W3'

        # Sequential-regression estimator restricted to a single time point.
        iterative = IterativeCondGFormula(sim_t_fixed_data, exposures=['A'], outcomes=['Y'])
        iterative.outcome_model([model_spec], print_results=False)
        iterative.fit(treatments=[1])

        # Time-fixed g-formula with the same model; treatment='all'
        # presumably means everyone treated, matching treatments=[1] above.
        fixed = TimeFixedGFormula(sim_t_fixed_data, exposure='A', outcome='Y')
        fixed.outcome_model(model=model_spec, print_results=False)
        fixed.fit(treatment='all')

        npt.assert_allclose(fixed.marginal_outcome, iterative.marginal_outcome)