예제 #1
0
    def test_input_w_just_2_points_raises_exception(self):
        """``_format_input`` rejects a frame that has only two rows."""
        two_rows = pd.DataFrame(np.random.randn(2, 2), columns=['x1', 'x2'])
        causal_impact = CausalImpact(two_rows, [0, 0], [1, 1], {})
        params = causal_impact.params

        with pytest.raises(ValueError) as excinfo:
            # Positional order: data, pre, post, model_args,
            # ucm_model (None), post_period_response (None), alpha.
            causal_impact._format_input(
                params["data"], params["pre_period"], params["post_period"],
                params["model_args"], None, None, params["alpha"])
        assert str(excinfo.value) == 'data must have at least 3 time points'
예제 #2
0
    def test_input_covariates_w_nan_value_raises(self):
        """A NaN in covariate column ``x1`` makes ``_format_input`` raise."""
        rows = [[1, 1, 2], [1, 2, 3], [1, 3, 4], [1, np.nan, 5], [1, 6, 7]]
        frame = pd.DataFrame(np.array(rows), columns=['y', 'x1', 'x2'])
        causal_impact = CausalImpact(frame, [0, 3], [3, 4], {})
        params = causal_impact.params

        with pytest.raises(ValueError) as excinfo:
            causal_impact._format_input(
                params["data"], params["pre_period"], params["post_period"],
                params["model_args"], None, None, params["alpha"])
        assert str(excinfo.value) == 'covariates must not contain null values'
예제 #3
0
    def test_input_w_time_column(self):
        """A ``time`` column becomes the index and period bounds become timestamps.

        The expected output of ``_format_input`` is the input frame indexed
        by ``time``, the string period bounds converted with
        ``pd.to_datetime``, ``model_args`` and ``alpha`` taken from the
        object's params, and ``None`` for ``ucm_model`` and
        ``post_period_response``.
        """
        frame = pd.DataFrame(np.random.randn(100, 2), columns=['x1', 'x2'])
        frame['time'] = pd.date_range(start='2018-01-01', periods=100)
        frame = frame[['time', 'x1', 'x2']]
        pre_bounds = ['2018-01-01', '2018-02-10']
        post_bounds = ['2018-02-11', '2018-4-10']

        causal_impact = CausalImpact(frame, pre_bounds, post_bounds, {})
        params = causal_impact.params

        expected = {
            "data": frame.set_index('time'),
            "pre_period": list(map(pd.to_datetime, pre_bounds)),
            "post_period": list(map(pd.to_datetime, post_bounds)),
            "model_args": params['model_args'],
            "ucm_model": None,
            "post_period_response": None,
            "alpha": params['alpha'],
        }
        result = causal_impact._format_input(
            params["data"], params["pre_period"], params["post_period"],
            params["model_args"], None, None, params["alpha"])

        # DataFrames need pandas' comparison helper; the rest compare with ==.
        assert_frame_equal(result["data"], expected["data"])
        assert result["model_args"] == expected["model_args"]

        skipped = {"model_args", "data"}
        remaining_result = {k: v for k, v in result.items()
                            if k not in skipped}
        remaining_expected = {k: v for k, v in expected.items()
                              if k not in skipped}
        assert remaining_result == remaining_expected
예제 #4
0
    def test_float_index_pre_period_contains_int(self):
        """Integer period bounds are accepted against a float-valued index.

        ``_format_input`` is expected to hand back every param unchanged,
        with ``ucm_model`` and ``post_period_response`` set to ``None``.
        """
        frame = pd.DataFrame(np.random.randn(200, 3), columns=['y', 'x1', 'x2'])
        frame = frame.set_index(np.array([float(i) for i in range(200)]))
        # NOTE(review): pre_period ends at 3 and post_period starts at 3.
        # This test expects that to be accepted, although the conflict
        # message elsewhere ("post period must start at least 1 observation
        # after the end of the pre_period") suggests otherwise. Confirm the
        # intended boundary.
        causal_impact = CausalImpact(frame, [0, 3], [3, 4], {})
        params = causal_impact.params

        expected = dict(
            data=params['data'],
            pre_period=params['pre_period'],
            post_period=params['post_period'],
            model_args=params['model_args'],
            ucm_model=None,
            post_period_response=None,
            alpha=params['alpha'],
        )
        result = causal_impact._format_input(
            params["data"], params["pre_period"], params["post_period"],
            params["model_args"], None, None, params["alpha"])

        assert_frame_equal(result["data"], expected["data"])
        assert result["model_args"] == expected["model_args"]

        skipped = {"model_args", "data"}
        remaining_result = {k: v for k, v in result.items()
                            if k not in skipped}
        remaining_expected = {k: v for k, v in expected.items()
                              if k not in skipped}
        assert remaining_result == remaining_expected
예제 #5
0
    def test_pre_period_in_conflict_w_post_period(self):
        """``_format_input`` rejects pre/post periods that conflict.

        Two distinct cases are exercised:

        * the post period starts before the pre period ends
          (``[0, 10]`` vs ``[9, 20]``);
        * the post period ends before it starts (``[11, 9]``).
        """
        data = pd.DataFrame(np.random.randn(20, 2), columns=['x1', 'x2'])

        def format_input_error(pre_period, post_period):
            """Return the ValueError message ``_format_input`` raises."""
            # The model is built outside ``pytest.raises`` so that only the
            # validation done by ``_format_input`` can satisfy the check.
            causal_impact = CausalImpact(data, pre_period, post_period, {})
            params = causal_impact.params
            with pytest.raises(ValueError) as excinfo:
                causal_impact._format_input(params["data"],
                                            params["pre_period"],
                                            params["post_period"],
                                            params["model_args"],
                                            None, None,
                                            params["alpha"])
            return str(excinfo.value)

        assert format_input_error([0, 10], [9, 20]) == (
            'post period must start at least 1 observation after the end of '
            'the pre_period')

        assert format_input_error([0, 10], [11, 9]) == (
            'post_period[1] must not be earlier than post_period[0]')