Пример #1
0
    def test_binarizer(self):
        """Check Binarizer param listing, defaults, setters, copy() and constructor kwargs.

        Targets the Binarizer API that exposes both single-column
        (inputCol/outputCol/threshold) and multi-column
        (inputCols/outputCols/thresholds) params.
        """
        b0 = Binarizer()
        self.assertListEqual(b0.params, [
            b0.inputCol, b0.inputCols, b0.outputCol, b0.outputCols,
            b0.threshold, b0.thresholds
        ])
        # A freshly constructed Binarizer has no explicitly set params.
        # Use `not`, not `~`: `~True == -2` and `~False == -1` are both truthy,
        # so the bitwise form made this assertion impossible to fail.
        self.assertTrue(all(not b0.isSet(p) for p in b0.params))
        # threshold is unset but carries a default, which getThreshold() returns.
        self.assertTrue(b0.hasDefault(b0.threshold))
        self.assertEqual(b0.getThreshold(), 0.0)
        b0.setParams(inputCol="input", outputCol="output").setThreshold(1.0)
        # Only the single-column params were set above; inputCols, outputCols
        # and thresholds remain unset, so not every param is set.
        self.assertFalse(all(b0.isSet(p) for p in b0.params))
        self.assertEqual(b0.getThreshold(), 1.0)
        self.assertEqual(b0.getInputCol(), "input")
        self.assertEqual(b0.getOutputCol(), "output")

        # copy() with an extra param map overrides threshold on the copy while
        # keeping the same uid and the same set of params.
        b0c = b0.copy({b0.threshold: 2.0})
        self.assertEqual(b0c.uid, b0.uid)
        self.assertListEqual(b0c.params, b0.params)
        self.assertEqual(b0c.getThreshold(), 2.0)

        # Params passed as constructor kwargs behave like the setters, and a new
        # instance gets its own uid.
        b1 = Binarizer(threshold=2.0, inputCol="input", outputCol="output")
        self.assertNotEqual(b1.uid, b0.uid)
        self.assertEqual(b1.getThreshold(), 2.0)
        self.assertEqual(b1.getInputCol(), "input")
        self.assertEqual(b1.getOutputCol(), "output")
Пример #2
0
    def test_binarizer(self):
        """Check Binarizer param listing, defaults, setters, copy() and constructor kwargs.

        Targets the single-column Binarizer API (inputCol/outputCol/threshold).
        """
        b0 = Binarizer()
        self.assertListEqual(b0.params, [b0.inputCol, b0.outputCol, b0.threshold])
        # A freshly constructed Binarizer has no explicitly set params.
        # Use `not`, not `~`: `~True == -2` and `~False == -1` are both truthy,
        # so the bitwise form made this assertion impossible to fail.
        self.assertTrue(all(not b0.isSet(p) for p in b0.params))
        # threshold is unset but carries a default, which getThreshold() returns.
        self.assertTrue(b0.hasDefault(b0.threshold))
        self.assertEqual(b0.getThreshold(), 0.0)
        b0.setParams(inputCol="input", outputCol="output").setThreshold(1.0)
        # Every param of this Binarizer has now been explicitly set.
        self.assertTrue(all(b0.isSet(p) for p in b0.params))
        self.assertEqual(b0.getThreshold(), 1.0)
        self.assertEqual(b0.getInputCol(), "input")
        self.assertEqual(b0.getOutputCol(), "output")

        # copy() with an extra param map overrides threshold on the copy while
        # keeping the same uid and the same set of params.
        b0c = b0.copy({b0.threshold: 2.0})
        self.assertEqual(b0c.uid, b0.uid)
        self.assertListEqual(b0c.params, b0.params)
        self.assertEqual(b0c.getThreshold(), 2.0)

        # Params passed as constructor kwargs behave like the setters, and a new
        # instance gets its own uid.
        b1 = Binarizer(threshold=2.0, inputCol="input", outputCol="output")
        self.assertNotEqual(b1.uid, b0.uid)
        self.assertEqual(b1.getThreshold(), 2.0)
        self.assertEqual(b1.getInputCol(), "input")
        self.assertEqual(b1.getOutputCol(), "output")
# Binarize `review_scores_rating` against a 0.97 threshold; the result is
# written to a new `high_rating` column of `transformedBinnedDF`.
# `binarizer` and `transformedBinnedDF` are checked by the test cell below.
binarizer = (
    Binarizer()
    .setInputCol("review_scores_rating")
    .setOutputCol("high_rating")
    .setThreshold(0.97)
)
transformedBinnedDF = binarizer.transform(airbnbDF)

display(transformedBinnedDF)

# COMMAND ----------

# TEST - Run this cell to test your solution.
# Each dbTest call checks one property of the `binarizer` and
# `transformedBinnedDF` objects built in the solution cell.
# NOTE(review): dbTest is defined elsewhere; presumably its arguments are
# (test id, expected value, actual value) — confirm.
from pyspark.ml.feature import Binarizer

# isinstance is the idiomatic type check and avoids constructing a throwaway
# Binarizer just to obtain its type.
dbTest("ML1-P-05-01-01", True, isinstance(binarizer, Binarizer))
dbTest("ML1-P-05-01-02", True,
       binarizer.getInputCol() == 'review_scores_rating')
dbTest("ML1-P-05-01-03", True, binarizer.getOutputCol() == 'high_rating')
dbTest("ML1-P-05-01-04", True, "high_rating" in transformedBinnedDF.columns)

print("Tests passed!")

# COMMAND ----------

# MAGIC %md-sandbox
# MAGIC ### Step 2: Regular Expressions on Strings
# MAGIC
# MAGIC Clean the column `price` by creating two new columns:<br><br>
# MAGIC
# MAGIC 1. `price`: a new column that contains a cleaned version of price.  This can be done using the regular expression replacement of `"[\$,]"` with `""`.  Cast the column as a decimal.
# MAGIC 2. `raw_price`: a copy of the column `price` in its current (uncleaned) form
# MAGIC
# MAGIC <img alt="Hint" title="Hint" style="vertical-align: text-bottom; position: relative; height:1.75em; top:0.3em" src="https://files.training.databricks.com/static/images/icon-light-bulb.svg"/>&nbsp;**Hint:** See the <a href="http://spark.apache.org/docs/latest/api/python/pyspark.sql.html?highlight=regexp_replace#pyspark.sql.functions.regexp_replace" target="_blank">`regexp_replace` Docs</a> for more details.
Пример #4
0
# pyspark.ml.feature module

#
# Minimal Binarizer demo on a one-row DataFrame with a single float column.
# The bare expressions below only echo a value in an interactive session
# (REPL/notebook); when run as a script their results are discarded.
from pyspark.ml.feature import Binarizer
# NOTE(review): `sparksession` is assumed to be a SparkSession created elsewhere.
df = sparksession.createDataFrame([(0.5,)], ["values"])
df.collect()
# NOTE(review): per the Spark docs, values strictly greater than the threshold
# map to 1.0 and others to 0.0, so 0.5 should become 0.0 here — confirm.
binarizer = Binarizer(threshold=1.0, inputCol="values", outputCol="features")
df2 = binarizer.transform(df)
df2.dtypes
# Presumably re-collects `df` to show that transform() returned a new
# DataFrame rather than modifying the input.
df.collect()
df2.collect()
binarizer.getOutputCol()

# Apply a Binarizer to `rawData`, a DataFrame loaded elsewhere.
# NOTE(review): assumes rawData has a numeric `srv_diff_host_rate` column — confirm.
rawData.take(1)
binarizer2 = Binarizer(threshold=0.5, inputCol="srv_diff_host_rate", outputCol="features")
# NOTE(review): the transformed DataFrame is not assigned, so its result is
# discarded outside an interactive session.
binarizer2.transform(rawData)

# Inspect Param metadata on the first `binarizer` (threshold=1.0 demo).
binarizer.explainParam('inputCol')
binarizer.inputCol
binarizer.params

# Display the `count` column of rawData.
rawData.select(['count']).show()


# Inspect rawData's column types, then index its `y_label` column.
rawData.dtypes
from pyspark.ml.feature import StringIndexer
# NOTE(review): `y_label` is presumably a string label column — confirm.
# StringIndexer is an estimator: fit() on rawData returns a model, and that
# model's transform() adds the `indexed_y_label` output column.
stringIndexer = StringIndexer(inputCol="y_label", outputCol='indexed_y_label')
model = stringIndexer.fit(rawData)
td = model.transform(rawData)
td.dtypes