コード例 #1
0
    def setUp(self) -> None:
        """Prepare the fixtures shared by the tests.

        Seeds NumPy's global RNG with ``RANDOM_SEED``, stores the four values
        returned by ``load_nhefs_survival()`` as ``self.X``, ``self.a``,
        ``self.t`` and ``self.y``, and fills ``self.estimators`` with the
        named survival estimators exercised by the tests.
        """
        np.random.seed(RANDOM_SEED)
        dataset = load_nhefs_survival()
        self.X, self.a, self.t, self.y = dataset

        def stabilized_ipw():
            # A fresh weight model per estimator, so no instance is shared.
            return IPW(LogisticRegression(max_iter=4000), use_stabilized=True)

        def outcome_learner():
            # A fresh logistic-regression survival learner per estimator.
            return LogisticRegression(max_iter=4000)

        # Entries are inserted in a fixed order; the keys name each setup.
        estimators = {}
        estimators['observed_non_parametric'] = MarginalSurvival()
        estimators['observed_parametric'] = MarginalSurvival(
            survival_model=LogisticRegression(max_iter=2000))
        estimators['ipw_non_parametric'] = WeightedSurvival(
            weight_model=stabilized_ipw(), survival_model=None)
        estimators['ipw_parametric'] = WeightedSurvival(
            weight_model=stabilized_ipw(), survival_model=outcome_learner())
        estimators['ipw_parametric_pipeline'] = WeightedSurvival(
            weight_model=stabilized_ipw(),
            survival_model=Pipeline([
                ('transform', PolynomialFeatures(degree=2)),
                ('LR', LogisticRegression(max_iter=1000, C=2)),
            ]))
        estimators['standardization_non_stratified'] = StandardizedSurvival(
            survival_model=outcome_learner(), stratify=False)
        estimators['standardization_stratified'] = StandardizedSurvival(
            survival_model=outcome_learner(), stratify=True)
        self.estimators = estimators
コード例 #2
0
    def test_weighted_kaplan_meier_curves(self):
        """Compare the default weighted curves with the lifelines-backed ones.

        A ``WeightedSurvival`` built with ``survival_model=None`` and one
        built with ``lifelines.KaplanMeierFitter()`` must produce population
        outcomes that agree to 8 decimal places.
        """
        def make_weight_model():
            # Identical, but separate, weight model for each estimator.
            return IPW(LogisticRegression(max_iter=10000, C=10),
                       use_stabilized=True)

        default_estimator = WeightedSurvival(weight_model=make_weight_model(),
                                             survival_model=None)
        default_estimator.fit(self.X, self.a)
        default_curves = default_estimator.estimate_population_outcome(
            self.X, self.a, self.t, self.y)

        lifelines_estimator = WeightedSurvival(
            weight_model=make_weight_model(),
            survival_model=lifelines.KaplanMeierFitter())
        lifelines_estimator.fit(self.X, self.a)
        lifelines_curves = lifelines_estimator.estimate_population_outcome(
            self.X, self.a, self.t, self.y)

        np.testing.assert_array_almost_equal(default_curves,
                                             lifelines_curves,
                                             decimal=8)
コード例 #3
0
    def test_ipw_null_effect(self):
        """Check the IPW-weighted estimate on synthetic dataset 'B'.

        The value returned by ``fit_synthetic_data`` for ``WeightedSurvival``
        must be within ``TEST_DATA_DRUG_EFFECTS_DELTA`` of the oracle
        difference recorded for dataset 'B'.
        """
        estimated_diff = fit_synthetic_data(
            model_cls=WeightedSurvival,
            params=dict(
                weight_model=IPW(LogisticRegression(max_iter=10000),
                                 use_stabilized=True),
            ),
            test_data=TEST_DATA_TTE_DRUG_EFFECTS['B'],
        )

        self.assertAlmostEqual(
            estimated_diff,
            TEST_DATA_DRUG_EFFECTS_B_ORACLE_DIFF,
            delta=TEST_DATA_DRUG_EFFECTS_DELTA,
        )
コード例 #4
0
    def test_unnamed_input(self):
        """Smoke-test fitting and estimation with inputs that carry no names.

        The covariates' index name and the names of the treatment, outcome
        and time series are cleared before being handed to
        ``WeightedSurvival``; the test passes if ``fit`` and
        ``estimate_population_outcome`` complete without raising.
        """
        test_data = TEST_DATA_TTE_DRUG_EFFECTS['A']
        # Work on copies: the names are cleared in place below, and the
        # fixture is a module-level object that other tests read as well.
        # Without copying, the column subset can share its Index object with
        # the fixture and the column Series may be cached by it, so the
        # renames could leak into the shared data.
        # NOTE(review): assumes the fixture is a pandas DataFrame - confirm.
        X = test_data[['x_0', 'x_1']].copy()
        a = test_data['a'].copy()
        y = test_data['y'].copy()
        t = test_data['t'].copy()

        X.index.name = None
        a.name = None
        y.name = None
        t.name = None
        ipw = WeightedSurvival(weight_model=IPW(LogisticRegression()),
                               survival_model=LogisticRegression())
        ipw.fit(X, a, t, y)
        # Only the absence of an exception matters; the result is discarded.
        _ = ipw.estimate_population_outcome(X=X, a=a, t=t, y=y)
コード例 #5
0
    def test_weighted_standardization_non_stratified(self):
        """Check the non-stratified weighted standardization on dataset 'A'.

        The value returned by ``fit_synthetic_data`` for
        ``WeightedStandardizedSurvival`` (``stratify=False``) must be within
        ``TEST_DATA_DRUG_EFFECTS_DELTA`` of the oracle difference recorded
        for dataset 'A'.
        """
        estimated_diff = fit_synthetic_data(
            model_cls=WeightedStandardizedSurvival,
            params=dict(
                weight_model=IPW(LogisticRegression(max_iter=2000),
                                 use_stabilized=True),
                survival_model=LogisticRegression(max_iter=4000, C=5),
                stratify=False,
            ),
            test_data=TEST_DATA_TTE_DRUG_EFFECTS['A'],
        )

        self.assertAlmostEqual(
            estimated_diff,
            TEST_DATA_DRUG_EFFECTS_A_ORACLE_DIFF,
            delta=TEST_DATA_DRUG_EFFECTS_DELTA,
        )
コード例 #6
0
    def test_fit_kwargs(self):
        """Check that ``fit_kwargs`` are taken into account when fitting.

        Fitting a ``WeightedStandardizedSurvival`` that wraps
        ``lifelines.CoxPHFitter`` without ``fit_kwargs`` is expected to emit
        a ``StatisticalWarning``; fitting with ``fit_kwargs={'robust': True}``
        is expected not to emit one.
        """
        import warnings  # local import: only this test needs it

        ipw = IPW(learner=LogisticRegression(max_iter=1000))
        weighted_standardized_survival = WeightedStandardizedSurvival(
            survival_model=lifelines.CoxPHFitter(), weight_model=ipw)

        # Without fit_kwargs - should raise StatisticalWarning with a
        # suggestion to pass robust=True in fit
        with self.assertWarns(lifelines.exceptions.StatisticalWarning):
            weighted_standardized_survival.fit(self.X, self.a, self.t, self.y)

        # With fit_kwargs - should not raise StatisticalWarning (might raise
        # other warnings, though). unittest has no assertNotWarns, so the
        # warnings are recorded and inspected explicitly. Wrapping assertWarns
        # in assertRaises(AssertionError) would also pass if fit() itself
        # raised an AssertionError, hiding a real failure.
        with warnings.catch_warnings(record=True) as recorded:
            # "always" ensures a repeated warning is not silently deduplicated.
            warnings.simplefilter('always')
            weighted_standardized_survival.fit(self.X,
                                               self.a,
                                               self.t,
                                               self.y,
                                               fit_kwargs={'robust': True})

        statistical_warnings = [
            str(entry.message) for entry in recorded
            if issubclass(entry.category,
                          lifelines.exceptions.StatisticalWarning)
        ]
        self.assertEqual(
            statistical_warnings, [],
            msg='StatisticalWarning emitted despite fit_kwargs')