コード例 #1
0
ファイル: main.py プロジェクト: pyronear/pyro-risks
def _train_pipeline(model: str, destination: str, ignore_prints: bool,
                    ignore_html: bool) -> None:
    """Announce the run, load the dataset and delegate to ``train_pipeline``.

    Args:
        model: model identifier, forwarded unchanged to ``train_pipeline``.
        destination: target location, echoed to the user and forwarded
            unchanged to ``train_pipeline``.
        ignore_prints: forwarded unchanged to ``train_pipeline``.
        ignore_html: forwarded unchanged to ``train_pipeline``.
    """
    click.echo(f"Train and save pipeline in {destination}")

    features, target = load_dataset()

    # Options coming straight from the caller, passed through as-is.
    forwarded = {
        "model": model,
        "destination": destination,
        "ignore_prints": ignore_prints,
        "ignore_html": ignore_html,
    }
    train_pipeline(X=features, y=target, **forwarded)
コード例 #2
0
ファイル: test_main.py プロジェクト: dataJSA/pyro-risks
    def test_evaluate_pipeline(self):
        """Run the ``evaluate`` CLI command on a freshly trained dummy pipeline.

        A constant ``DummyClassifier`` pipeline is trained into a temporary
        directory, the ``evaluate`` command is invoked on the resulting
        ``.joblib`` file, and the directory is then expected to contain at
        least one ``.png``, one ``.json`` and one ``.csv`` file.
        """
        runner = CliRunner()
        pattern = "/*.joblib"
        X, y = load_dataset()

        dummy_pipeline = Pipeline(
            [("dummy_classifier", DummyClassifier(strategy="constant", constant=0))]
        )

        with tempfile.TemporaryDirectory() as destination:
            threshold = destination + "/DUMMY_threshold.json"
            train_pipeline(
                X=X,
                y=y,
                model="DUMMY",
                pipeline=dummy_pipeline,
                destination=destination,
                ignore_prints=True,
                ignore_html=True,
            )
            pipeline_path = glob.glob(destination + pattern)
            # Fail with a readable message rather than an IndexError on
            # ``pipeline_path[0]`` when training wrote no pipeline file.
            self.assertTrue(
                pipeline_path,
                "train_pipeline did not write any .joblib file",
            )
            result = runner.invoke(
                main,
                [
                    "evaluate",
                    "--pipeline",
                    pipeline_path[0],
                    "--threshold",
                    threshold,
                    "--prefix",
                    "DUMMY",
                    "--destination",
                    destination,
                ],
            )
            # CliRunner captures exceptions raised by the command, so the exit
            # code must be checked explicitly; the command output is attached
            # to make a failure diagnosable.
            self.assertEqual(result.exit_code, 0, msg=result.output)

            files = glob.glob(destination + "/*")
            self.assertTrue(any(".png" in file for file in files))
            self.assertTrue(any(".json" in file for file in files))
            self.assertTrue(any(".csv" in file for file in files))
コード例 #3
0
ファイル: test_train.py プロジェクト: pyronear/pyro-risks
    def test_train_pipeline(self):
        """Train XGBOOST, RF and DUMMY pipelines into one temporary directory.

        After the three runs, the ``.joblib`` files found in the directory
        must include at least one path mentioning each model name.
        """
        features, target = load_dataset()
        joblib_glob = "/*.joblib"

        constant_pipeline = Pipeline(
            [("dummy_classifier", DummyClassifier(strategy="constant", constant=0))]
        )

        # (model name, extra keyword arguments) in training order; only the
        # DUMMY run supplies an explicit pipeline.
        training_runs = (
            ("XGBOOST", {}),
            ("RF", {}),
            ("DUMMY", {"pipeline": constant_pipeline}),
        )

        with tempfile.TemporaryDirectory() as destination:
            for model_name, extra_kwargs in training_runs:
                train_pipeline(
                    X=features,
                    y=target,
                    model=model_name,
                    destination=destination,
                    ignore_prints=True,
                    ignore_html=True,
                    **extra_kwargs,
                )

            saved_paths = glob.glob(destination + joblib_glob)
            for expected_tag in ("RF", "XGBOOST", "DUMMY"):
                self.assertTrue(any(expected_tag in path for path in saved_paths))