コード例 #1
0
    def test_grid_gbm_in_spark_pipeline(self):
        """Round-trip a GBM grid-search pipeline and its fitted model through disk.

        Steps:
        1. Save and reload an unfitted pipeline whose only stage is the grid search.
        2. Fit the reloaded pipeline.
        3. Save and reload the fitted model, then score the prostate data with it.
        """
        def as_file_uri(relative_path):
            # Spark paths here are given as explicit local-filesystem URIs.
            return "file://" + os.path.abspath(relative_path)

        csv_uri = "file://" + unit_test_utils.locate("smalldata/prostate/prostate.csv")
        prostate_frame = self._spark.read.csv(csv_uri, header=True, inferSchema=True)

        # NOTE(review): this snippet uses the "_seed" hyper-parameter key and `ratio=`,
        # while other snippets use "seed" and `splitRatio`; presumably an older API
        # version -- confirm before changing.
        grid_search = H2OGridSearch(
            labelCol="AGE",
            hyperParameters={"_seed": [1, 2, 3]},
            ratio=0.8,
            algo=H2OGBM(),
            strategy="RandomDiscrete",
            maxModels=3,
            maxRuntimeSecs=60,
            selectBestModelBy="RMSE",
        )

        pipeline_uri = as_file_uri("build/grid_gbm_pipeline")
        Pipeline(stages=[grid_search]).write().overwrite().save(pipeline_uri)
        fitted_pipeline = Pipeline.load(pipeline_uri).fit(prostate_frame)

        model_uri = as_file_uri("build/grid_gbm_pipeline_model")
        fitted_pipeline.write().overwrite().save(model_uri)

        # count() forces the transformation to actually run.
        PipelineModel.load(model_uri).transform(prostate_frame).count()
コード例 #2
0
def createGridForProblemSpecificTesting(algorithm):
    """Wrap ``algorithm`` in a small seed/ntrees grid search targeting ``CAPSULE``.

    ``algorithm`` is configured in place: its label column is set to ``CAPSULE``
    and its split ratio to 0.8. The returned grid covers 2 seeds x 3 tree counts
    and is itself seeded with 42.
    """
    algorithm.setLabelCol("CAPSULE")
    algorithm.setSplitRatio(0.8)

    searchSpace = {
        "seed": [1, 2],
        "ntrees": [3, 5, 10],
    }
    return H2OGridSearch(
        hyperParameters=searchSpace,
        seed=42,
        algo=algorithm,
    )
コード例 #3
0
def testGridSearchGetAlgoIsAbleToReturnAlgorithmOfVariousTypes():
    """getAlgo() must report the most recently configured algorithm type.

    The first algorithm is supplied through the constructor. Later ones replace
    it through setAlgo().
    """
    grid = H2OGridSearch(algo=H2ODRFClassifier())

    def currentAlgoName():
        return grid.getAlgo().__class__.__name__

    assert currentAlgoName() == "H2ODRFClassifier"

    replacements = (
        (H2OKMeans, "H2OKMeans"),
        (H2OGBMRegressor, "H2OGBMRegressor"),
    )
    for makeAlgo, expectedName in replacements:
        grid.setAlgo(makeAlgo())
        assert currentAlgoName() == expectedName
コード例 #4
0
def testGetAlgoViaSetter():
    """Regression test for SW-2276 (the 3rd getAlgo() call failed).

    This variant attaches the algorithm with setAlgo() after the grid is built.
    """
    grid = H2OGridSearch(
        hyperParameters={"seed": [1, 2, 3]},
        strategy="RandomDiscrete",
        maxModels=3,
        maxRuntimeSecs=60,
        selectBestModelBy="RMSE",
    )
    gbm = H2OGBM().setNtrees(100).setLabelCol("AGE").setSplitRatio(0.8)
    grid.setAlgo(gbm)

    # Two throw-away calls, so the call checked below is the third one.
    for _ in range(2):
        grid.getAlgo()
    assert grid.getAlgo().getNtrees() == 100
コード例 #5
0
def testGetGridModels(prostateDataset):
    """Fit a random grid over three GBM seeds and expect one model per seed."""
    expectedModelCount = 3
    gbm = H2OGBM(splitRatio=0.8, labelCol="AGE")
    grid = H2OGridSearch(
        hyperParameters={"seed": [1, 2, 3]},
        algo=gbm,
        strategy="RandomDiscrete",
        maxModels=expectedModelCount,
        maxRuntimeSecs=60,
        selectBestModelBy="RMSE",
    )

    grid.fit(prostateDataset)
    assert len(grid.getGridModels()) == expectedModelCount
コード例 #6
0
def testGetGridModelsNoParams(prostateDataset):
    """Check the parameter table of a grid with no hyper-parameters.

    It must have a single row and only the model-ID column.
    """
    expectedColumns = ['MOJO Model ID']
    gbm = H2OGBM(labelCol="AGE", splitRatio=0.8)
    grid = H2OGridSearch(
        algo=gbm,
        strategy="RandomDiscrete",
        maxModels=3,
        maxRuntimeSecs=60,
        selectBestModelBy="RMSE",
    )
    grid.fit(prostateDataset)

    modelParams = grid.getGridModelsParams()
    assert modelParams.count() == 1
    assert modelParams.columns == expectedColumns
    # collect() is called only to check that the result can be materialized.
    modelParams.collect()
コード例 #7
0
def testGetGridModelsParams(prostateDataset):
    """Check the parameter table of a grid over three seeds.

    It must have one row per model and a ``seed`` column after the model ID.
    """
    expectedColumns = ['MOJO Model ID', 'seed']
    gbm = H2OGBM(splitRatio=0.8, labelCol="AGE")
    grid = H2OGridSearch(
        hyperParameters={"seed": [1, 2, 3]},
        algo=gbm,
        strategy="RandomDiscrete",
        maxModels=3,
        maxRuntimeSecs=60,
        selectBestModelBy="RMSE",
    )
    grid.fit(prostateDataset)

    modelParams = grid.getGridModelsParams()
    assert modelParams.count() == 3
    assert modelParams.columns == expectedColumns
    # collect() is called only to check that the result can be materialized.
    modelParams.collect()
コード例 #8
0
def testGetAlgoViaConstructor():
    """Regression test for SW-2276 (the 3rd getAlgo() call failed).

    This variant passes the algorithm to the grid's constructor.
    """
    gbm = H2OGBM(labelCol="AGE", ntrees=100, splitRatio=0.8)
    grid = H2OGridSearch(
        hyperParameters={"seed": [1, 2, 3]},
        algo=gbm,
        strategy="RandomDiscrete",
        maxModels=3,
        maxRuntimeSecs=60,
        selectBestModelBy="RMSE",
    )

    # Two throw-away calls, so the call checked below is the third one.
    for _ in range(2):
        grid.getAlgo()
    assert grid.getAlgo().getNtrees() == 100
コード例 #9
0
def testGetGridModelsMetrics(prostateDataset):
    """Check the per-model metrics table of a fitted 3-seed GBM grid on ``AGE``.

    Expects one row per model and exactly the metric columns listed below, in
    that order.
    """
    expectedColumns = ['MOJO Model ID', 'MSE', 'MeanResidualDeviance', 'R2', 'RMSE']
    gbm = H2OGBM(labelCol="AGE", splitRatio=0.8)
    grid = H2OGridSearch(
        hyperParameters={"seed": [1, 2, 3]},
        algo=gbm,
        strategy="RandomDiscrete",
        maxModels=3,
        maxRuntimeSecs=60,
        selectBestModelBy="RMSE",
    )
    grid.fit(prostateDataset)

    metricsTable = grid.getGridModelsMetrics()
    assert metricsTable.count() == 3
    assert metricsTable.columns == expectedColumns
    # collect() is called only to check that the result can be materialized.
    metricsTable.collect()
コード例 #10
0
def gridSearchTester(algo, prostateDataset):
    """Round-trip a grid-search pipeline around ``algo`` and its fitted model through disk.

    ``setSplitRatio(0.8)`` is called on ``algo`` itself before it is wrapped.

    Steps:
    1. Save and reload the unfitted pipeline, then fit it on ``prostateDataset``.
    2. Save and reload the fitted model, then run it over the same data.
    """
    def toFileUri(relativePath):
        # Persistence targets are given as explicit local-filesystem URIs.
        return "file://" + os.path.abspath(relativePath)

    grid = H2OGridSearch(
        hyperParameters={"seed": [1, 2, 3]},
        algo=algo.setSplitRatio(0.8),
        strategy="RandomDiscrete",
        maxModels=3,
        maxRuntimeSecs=60,
        selectBestModelBy="RMSE",
    )

    pipelineUri = toFileUri("build/grid_pipeline")
    Pipeline(stages=[grid]).write().overwrite().save(pipelineUri)
    fittedPipeline = Pipeline.load(pipelineUri).fit(prostateDataset)

    modelUri = toFileUri("build/grid_pipeline_model")
    fittedPipeline.write().overwrite().save(modelUri)

    # count() forces the transformation to actually run.
    PipelineModel.load(modelUri).transform(prostateDataset).count()