def test_default_read_write(self):
    """Round-trip a LogisticRegression through DefaultParamsWriter and its reader.

    Saves an estimator with user-set ``maxIter`` and ``threshold``, loads it
    back and checks that the uid and the extracted param map match. Then
    changes ``threshold``, saves again to the same path via ``overwrite()``
    and repeats the same checks on the reloaded instance.
    """
    # The context manager removes the directory (and everything saved under
    # it) on exit, including when one of the assertions below fails.
    with tempfile.TemporaryDirectory() as temp_path:
        lr = LogisticRegression()
        lr.setMaxIter(50)
        lr.setThreshold(0.75)
        writer = DefaultParamsWriter(lr)

        savePath = temp_path + "/lr"
        writer.save(savePath)

        reader = DefaultParamsReadable.read()
        lr2 = reader.load(savePath)

        self.assertEqual(lr.uid, lr2.uid)
        self.assertEqual(lr.extractParamMap(), lr2.extractParamMap())

        # test overwrite
        lr.setThreshold(0.8)
        writer.overwrite().save(savePath)

        reader = DefaultParamsReadable.read()
        lr3 = reader.load(savePath)

        self.assertEqual(lr.uid, lr3.uid)
        self.assertEqual(lr.extractParamMap(), lr3.extractParamMap())
def test_default_read_write_default_params(self):
    """Check how saved metadata handles user-set versus default params.

    Steps exercised, in order:

    1. With ``threshold`` set by the user and ``predictionCol`` left at its
       default, the metadata produced by the writer contains a
       ``defaultParamMap`` section, and applying it back to the estimator
       keeps ``threshold`` set and ``predictionCol`` default-only.
    2. With the ``defaultParamMap`` section removed, applying the metadata
       raises ``AssertionError``.
    3. With ``sparkVersion`` additionally changed to ``"2.3.0"``, applying
       the metadata does not raise (prior to 2.4.0, metadata doesn't have
       ``defaultParamMap``).
    """
    lr = LogisticRegression()
    threshold = lr.getParam("threshold")
    prediction_col = lr.getParam("predictionCol")

    self.assertFalse(lr.isSet(threshold))

    lr.setMaxIter(50)
    lr.setThreshold(0.75)

    # `threshold` is set by user, default param `predictionCol` is not set by user.
    self.assertTrue(lr.isSet(threshold))
    self.assertFalse(lr.isSet(prediction_col))
    self.assertTrue(lr.hasDefault(prediction_col))

    writer = DefaultParamsWriter(lr)
    metadata = json.loads(writer._get_metadata_to_save(lr, self.sc))
    # assertIn reports the dict contents on failure, unlike assertTrue(x in y).
    self.assertIn("defaultParamMap", metadata)

    reader = DefaultParamsReadable.read()

    def parse(meta):
        # Serialize the dict to compact JSON and parse it with the reader.
        return reader._parseMetaData(
            json.dumps(meta, separators=(",", ":")),
        )

    reader.getAndSetParams(lr, parse(metadata))

    self.assertTrue(lr.isSet(threshold))
    self.assertFalse(lr.isSet(prediction_col))
    self.assertTrue(lr.hasDefault(prediction_col))

    # manually create metadata without `defaultParamMap` section.
    del metadata["defaultParamMap"]
    loadedMetadata = parse(metadata)
    with self.assertRaisesRegex(AssertionError, "`defaultParamMap` section not found"):
        reader.getAndSetParams(lr, loadedMetadata)

    # Prior to 2.4.0, metadata doesn't have `defaultParamMap`.
    metadata["sparkVersion"] = "2.3.0"
    reader.getAndSetParams(lr, parse(metadata))