コード例 #1
0
    def test__evaluate_pipeline_test_split_none(self, evaluate_signal_mock):
        """``test_split=None`` evaluates the signal once per split mode.

        The mocked evaluator should be called twice for the single signal:
        first with ``True`` and then with ``False`` in the ``test_split``
        position. The pipeline result should hold one score per call, in the
        same order.
        """
        test_split = None
        detrend = False
        # Order matters: the True run is expected before the False run.
        split_modes = (True, False)

        # ANY as the third field lets one mocked score stand in for both runs.
        evaluate_signal_mock.return_value = self.set_score(1, ANY, ANY)

        returned = benchmark._evaluate_pipeline(
            self.pipeline, self.name, self.dataset, [self.signal],
            self.hyper, self.metrics, self.distributed, test_split, detrend,
        )

        assert returned == [
            self.set_score(1, ANY, mode) for mode in split_modes
        ]
        assert evaluate_signal_mock.call_args_list == [
            call(self.pipeline, self.name, self.dataset, self.signal,
                 self.hyper, self.metrics, mode, detrend)
            for mode in split_modes
        ]
コード例 #2
0
    def test__evaluate_pipeline_no_test_split(self, evaluate_signal_mock):
        """``test_split=False`` evaluates the signal exactly once.

        The pipeline result is expected to be a one-element list holding the
        score produced by the mocked evaluator. The evaluator must receive
        ``test_split`` and ``detrend`` unchanged.
        """
        detrend = False
        test_split = False

        mocked_score = self.set_score(1, ANY, test_split)
        evaluate_signal_mock.return_value = mocked_score

        result = benchmark._evaluate_pipeline(
            self.pipeline, self.name, self.dataset, [self.signal],
            self.hyper, self.metrics, self.distributed, test_split, detrend,
        )

        assert result == [mocked_score]

        evaluate_signal_mock.assert_called_once_with(
            self.pipeline, self.name, self.dataset, self.signal,
            self.hyper, self.metrics, test_split, detrend,
        )