Ejemplo n.º 1
0
    def test_default_read_write(self):
        """Round-trip a LogisticRegression through DefaultParamsWriter/DefaultParamsReadable.

        Checks that the uid and the extracted param map survive a save/load, and
        that ``writer.overwrite()`` replaces a previously saved instance.
        """
        # A context-managed temp dir is removed even when an assertion fails;
        # a bare tempfile.mkdtemp() would leave the saved files behind.
        with tempfile.TemporaryDirectory() as temp_path:
            lr = LogisticRegression()
            lr.setMaxIter(50)
            lr.setThreshold(.75)
            writer = DefaultParamsWriter(lr)

            savePath = temp_path + "/lr"
            writer.save(savePath)

            reader = DefaultParamsReadable.read()
            lr2 = reader.load(savePath)

            self.assertEqual(lr.uid, lr2.uid)
            self.assertEqual(lr.extractParamMap(), lr2.extractParamMap())

            # test overwrite
            lr.setThreshold(.8)
            writer.overwrite().save(savePath)

            reader = DefaultParamsReadable.read()
            lr3 = reader.load(savePath)

            self.assertEqual(lr.uid, lr3.uid)
            self.assertEqual(lr.extractParamMap(), lr3.extractParamMap())
Ejemplo n.º 2
0
 def saveImpl(instance, stages, sc, path):
     """
     Persist a :py:class:`Pipeline` or :py:class:`PipelineModel` together with its stages.

     - metadata (the ordered stage UIDs plus ``language``) goes to path/metadata
     - each stage goes to path/stages/IDX_UID, with the exact subpath chosen by
       ``PipelineSharedReadWrite.getStagePath``
     """
     numStages = len(stages)
     metadataParams = {'stageUids': [s.uid for s in stages], 'language': 'Python'}
     DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=metadataParams)
     stagesDir = os.path.join(path, "stages")
     for idx, s in enumerate(stages):
         # Obtain the writer first, then compute the target path (original evaluation order).
         stageWriter = s.write()
         stageWriter.save(PipelineSharedReadWrite.getStagePath(s.uid, idx, numStages, stagesDir))
Ejemplo n.º 3
0
 def saveImpl(self, path):
     """
     Write the instance's params as JSON metadata under ``path``.

     Values that ``json`` cannot encode directly are converted first: pandas
     DataFrames via ``DataFrame.to_json()`` and ``datetime`` values via ``str()``.
     Every other value is stored unchanged.
     """
     def _toJsonable(value):
         # Convert one param value into something the metadata writer can serialize.
         if isinstance(value, pd.DataFrame):
             return value.to_json()
         if isinstance(value, dt.datetime):
             return str(value)
         return value

     paramMap = self.instance.extractParamMap()
     jsonParams = {param.name: _toJsonable(value) for param, value in paramMap.items()}
     DefaultParamsWriter.saveMetadata(self.instance, path, self.sc, paramMap=jsonParams)
Ejemplo n.º 4
0
 def saveImpl(
     instance: Union[Pipeline, PipelineModel],
     stages: List["PipelineStage"],
     sc: SparkContext,
     path: str,
 ) -> None:
     """
     Persist a :py:class:`Pipeline` or :py:class:`PipelineModel` together with its stages.

     - metadata (the ordered stage UIDs plus ``language``) goes to path/metadata
     - each stage goes to path/stages/IDX_UID, with the exact subpath chosen by
       ``PipelineSharedReadWrite.getStagePath``
     """
     numStages = len(stages)
     metadataParams = {"stageUids": [s.uid for s in stages], "language": "Python"}
     DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=metadataParams)
     stagesDir = os.path.join(path, "stages")
     for idx, s in enumerate(stages):
         # Stages are typed as PipelineStage; cast to MLWritable to access write().
         stageWriter = cast(MLWritable, s).write()
         stageWriter.save(
             PipelineSharedReadWrite.getStagePath(s.uid, idx, numStages, stagesDir)
         )
Ejemplo n.º 5
0
    def test_default_read_write_default_params(self):
        """Check that user-set vs. default params survive a metadata round trip.

        Also checks that metadata lacking ``defaultParamMap`` is rejected, unless
        its ``sparkVersion`` is older than 2.4.0.
        """
        lr = LogisticRegression()
        self.assertFalse(lr.isSet(lr.getParam("threshold")))

        lr.setMaxIter(50)
        lr.setThreshold(0.75)

        # `threshold` is set by user, default param `predictionCol` is not set by user.
        self.assertTrue(lr.isSet(lr.getParam("threshold")))
        self.assertFalse(lr.isSet(lr.getParam("predictionCol")))
        self.assertTrue(lr.hasDefault(lr.getParam("predictionCol")))

        writer = DefaultParamsWriter(lr)
        metadata = json.loads(writer._get_metadata_to_save(lr, self.sc))
        self.assertIn("defaultParamMap", metadata)

        reader = DefaultParamsReadable.read()

        def reparse(meta):
            # Re-serialize with compact separators and parse back through the reader.
            return reader._parseMetaData(json.dumps(meta, separators=(",", ":")))

        reader.getAndSetParams(lr, reparse(metadata))

        self.assertTrue(lr.isSet(lr.getParam("threshold")))
        self.assertFalse(lr.isSet(lr.getParam("predictionCol")))
        self.assertTrue(lr.hasDefault(lr.getParam("predictionCol")))

        # manually create metadata without `defaultParamMap` section.
        del metadata["defaultParamMap"]
        # Parse outside the assertRaises block so only getAndSetParams is expected to fail.
        loadedMetadata = reparse(metadata)
        with self.assertRaisesRegex(AssertionError, "`defaultParamMap` section not found"):
            reader.getAndSetParams(lr, loadedMetadata)

        # Prior to 2.4.0, metadata doesn't have `defaultParamMap`; loading must not raise.
        metadata["sparkVersion"] = "2.3.0"
        reader.getAndSetParams(lr, reparse(metadata))
Ejemplo n.º 6
0
    def test_default_read_write(self):
        """Round-trip a LogisticRegression through DefaultParamsWriter/DefaultParamsReadable.

        Checks that the uid and the extracted param map survive a save/load, and
        that ``writer.overwrite()`` replaces a previously saved instance.
        """
        # A context-managed temp dir is removed even when an assertion fails;
        # a bare tempfile.mkdtemp() would leave the saved files behind.
        with tempfile.TemporaryDirectory() as temp_path:
            lr = LogisticRegression()
            lr.setMaxIter(50)
            lr.setThreshold(0.75)
            writer = DefaultParamsWriter(lr)

            savePath = temp_path + "/lr"
            writer.save(savePath)

            reader = DefaultParamsReadable.read()
            lr2 = reader.load(savePath)

            self.assertEqual(lr.uid, lr2.uid)
            self.assertEqual(lr.extractParamMap(), lr2.extractParamMap())

            # test overwrite
            lr.setThreshold(0.8)
            writer.overwrite().save(savePath)

            reader = DefaultParamsReadable.read()
            lr3 = reader.load(savePath)

            self.assertEqual(lr.uid, lr3.uid)
            self.assertEqual(lr.extractParamMap(), lr3.extractParamMap())
Ejemplo n.º 7
0
    def test_default_read_write_default_params(self):
        """Check that user-set vs. default params survive a metadata round trip.

        Also checks that metadata lacking ``defaultParamMap`` is rejected, unless
        its ``sparkVersion`` is older than 2.4.0.
        """
        lr = LogisticRegression()
        self.assertFalse(lr.isSet(lr.getParam("threshold")))

        lr.setMaxIter(50)
        lr.setThreshold(.75)

        # `threshold` is set by user, default param `predictionCol` is not set by user.
        self.assertTrue(lr.isSet(lr.getParam("threshold")))
        self.assertFalse(lr.isSet(lr.getParam("predictionCol")))
        self.assertTrue(lr.hasDefault(lr.getParam("predictionCol")))

        writer = DefaultParamsWriter(lr)
        metadata = json.loads(writer._get_metadata_to_save(lr, self.sc))
        self.assertIn("defaultParamMap", metadata)

        reader = DefaultParamsReadable.read()

        def reparse(meta):
            # Re-serialize with compact separators and parse back through the reader.
            return reader._parseMetaData(json.dumps(meta, separators=(',', ':')))

        reader.getAndSetParams(lr, reparse(metadata))

        self.assertTrue(lr.isSet(lr.getParam("threshold")))
        self.assertFalse(lr.isSet(lr.getParam("predictionCol")))
        self.assertTrue(lr.hasDefault(lr.getParam("predictionCol")))

        # manually create metadata without `defaultParamMap` section.
        del metadata['defaultParamMap']
        # Parse outside the assertRaises block so only getAndSetParams is expected to fail.
        loadedMetadata = reparse(metadata)
        # assertRaisesRegexp was a deprecated alias and no longer exists in Python 3.12+.
        with self.assertRaisesRegex(AssertionError, "`defaultParamMap` section not found"):
            reader.getAndSetParams(lr, loadedMetadata)

        # Prior to 2.4.0, metadata doesn't have `defaultParamMap`; loading must not raise.
        metadata['sparkVersion'] = '2.3.0'
        reader.getAndSetParams(lr, reparse(metadata))
Ejemplo n.º 8
0
    def saveMetadata(instance, path, sc, logger, extraMetadata=None):
        """
        Save the metadata of an xgboost.spark._SparkXGBEstimator or
        xgboost.spark._SparkXGBModel.

        Params in ``instance._paramMap`` are written as JSON metadata, except
        ``callbacks`` and ``xgb_model``, which are handled separately:

        - ``callbacks`` (if not None) is cloudpickled, base64-encoded and stored in
          the extra metadata under ``serialized_callbacks``; a warning is logged
          because this format may not load under different dependency versions.
        - ``xgb_model`` (if not None) is serialized with ``serialize_booster`` and
          written as parquet to ``os.path.join(path, _INIT_BOOSTER_SAVE_PATH)``;
          the extra metadata records that relative path under ``init_booster``.

        :param instance: estimator or model whose params are saved.
        :param path: destination directory.
        :param sc: SparkContext forwarded to ``DefaultParamsWriter.saveMetadata``.
        :param logger: logger used for the callbacks warning.
        :param extraMetadata: optional extra metadata entries. The caller's dict is
            copied and is not modified by this function.
        """
        instance._validate_params()
        # These params are persisted out of band below, not as plain JSON values.
        skipParams = {"callbacks", "xgb_model"}
        jsonParams = {
            p.name: v
            for p, v in instance._paramMap.items()  # pylint: disable=protected-access
            if p.name not in skipParams
        }

        # Copy so the keys added below never leak into a dict owned by the caller.
        extraMetadata = dict(extraMetadata or {})
        callbacks = instance.getOrDefault(instance.callbacks)
        if callbacks is not None:
            logger.warning(
                "The callbacks parameter is saved using cloudpickle and it "
                "is not a fully self-contained format. It may fail to load "
                "with different versions of dependencies.")
            serialized_callbacks = base64.encodebytes(
                cloudpickle.dumps(callbacks)).decode("ascii")
            extraMetadata["serialized_callbacks"] = serialized_callbacks
        init_booster = instance.getOrDefault(instance.xgb_model)
        if init_booster is not None:
            extraMetadata["init_booster"] = _INIT_BOOSTER_SAVE_PATH
        DefaultParamsWriter.saveMetadata(instance,
                                         path,
                                         sc,
                                         extraMetadata=extraMetadata,
                                         paramMap=jsonParams)
        if init_booster is not None:
            ser_init_booster = serialize_booster(init_booster)
            save_path = os.path.join(path, _INIT_BOOSTER_SAVE_PATH)
            _get_spark_session().createDataFrame(
                [(ser_init_booster, )],
                ["init_booster"]).write.parquet(save_path)