def _get_estimator(self, model):
    """Build a :py:obj:`KerasImageFileEstimator` backed by the given Keras model.

    The model is saved to a uniquely named ``.h5`` file under ``self.temp_dir``
    and the returned estimator is configured to read that file.

    :param model: object exposing ``save(path)``; presumably a Keras model.
    :return: the configured :py:obj:`KerasImageFileEstimator`.
    """
    # A random UUID suffix keeps files from different calls from colliding.
    unique_suffix = str(uuid.uuid4())
    saved_model_path = os.path.join(
        self.temp_dir, 'model-{}.h5'.format(unique_suffix))
    model.save(saved_model_path)

    return KerasImageFileEstimator(
        inputCol=self.input_col,
        outputCol=self.output_col,
        labelCol=self.one_hot_col,
        imageLoader=_load_image_from_uri,
        kerasOptimizer='adam',
        kerasLoss='categorical_crossentropy',
        modelFile=saved_model_path)
.appName('ImageFeatureSelector') \ .config('spark.executor.memory', '2G') \ .config('spark.executor.cores', '2') \ .config('spark.driver.memory', '3G') \ .config('spark.driver.cores', '1') \ .getOrCreate() train_df = spark.createDataFrame(load_train_data(imagenet_path)) pre_trained_model = InceptionV3(weights="imagenet") pre_trained_model.save('/tmp/model-full.h5') estimator = KerasImageFileEstimator(inputCol="uri", outputCol="prediction", labelCol="one_hot_label", imageLoader=load_image_from_uri, kerasOptimizer='adam', kerasLoss='categorical_crossentropy', modelFile='/tmp/model-full-tmp.h5' # local file path for model ) param_grid = (ParamGridBuilder().addGrid(estimator.kerasFitParams, [{"batch_size": 32, "verbose": 0}, {"batch_size": 64, "verbose": 0}]).build()) binary_evaluator = BinaryClassificationEvaluator(rawPredictionCol="prediction", labelCol="label") cv = CrossValidator(estimator=estimator, estimatorParamMaps=param_grid, evaluator=binary_evaluator, numFolds=2) cv_model = cv.fit(train_df) print(cv_model)
def test_validate_params(self):
    """Test that `KerasImageFileEstimator._validateParams` method works as expected.

    Parameters are supplied in stages. Validation is expected to raise
    ``ValueError`` with a matching message while parameters are missing or
    when a non-tunable parameter is overridden, and to return a truthy value
    once everything has been set.
    """
    # `assertRaisesRegexp` is the Python 2 spelling; it was deprecated in
    # Python 3.2 and removed in 3.12. Prefer the modern name and only fall
    # back to the legacy alias where the modern one does not exist.
    assert_raises_regex = (getattr(self, 'assertRaisesRegex', None)
                           or self.assertRaisesRegexp)

    kifest = KerasImageFileEstimator()

    # should raise an error to define required parameters
    # assuming at least one param without default value
    assert_raises_regex(ValueError, 'defined', kifest._validateParams, {})
    kifest.setParams(imageLoader=_load_image_from_uri, inputCol='c1', labelCol='c2')
    kifest.setParams(modelFile='/path/to/file.ext')

    # should raise an error to define or tune parameters
    # assuming at least one tunable param without default value
    assert_raises_regex(ValueError, 'tuned', kifest._validateParams, {})
    kifest.setParams(kerasOptimizer='adam', kerasLoss='mse', kerasFitParams={})
    kifest.setParams(outputCol='c3', outputMode='vector')

    # should raise an error to not override
    assert_raises_regex(ValueError, 'not tuned',
                        kifest._validateParams, {kifest.imageLoader: None})

    # should pass test on supplying all parameters
    self.assertTrue(kifest._validateParams({}))