def _check_metadata_compatibility(self, metadata):
    """Check dataset metadata against the shapes reported by the model.

    The input and output shapes come from ``self.get_model_shapes()``; the
    feature and label column names come from the corresponding getters.
    The actual validation is delegated to
    ``util.check_shape_compatibility``, so any exception it raises
    propagates to the caller unchanged.

    Args:
        metadata: Dataset metadata forwarded as-is to the util helper.
    """
    model_inputs, model_outputs = self.get_model_shapes()
    feature_cols = self.getFeatureCols()
    label_cols = self.getLabelCols()
    util.check_shape_compatibility(
        metadata,
        feature_cols,
        label_cols,
        input_shapes=model_inputs,
        output_shapes=model_outputs,
    )
def test_check_shape_compatibility(self):
    """Exercise ``util.check_shape_compatibility`` on Spark-derived metadata.

    Builds a small DataFrame with two scalar feature columns, one
    12-element vector feature column, one scalar label column and one
    12-element vector label column. It then verifies that two shape
    configurations are accepted and that three mismatching ones raise
    ``ValueError``.
    """
    feature_columns = ['x1', 'x2', 'features']
    label_columns = ['y1', 'y_embedding']

    fields = [
        StructField('x1', DoubleType()),
        StructField('x2', IntegerType()),
        StructField('features', VectorUDT()),
        StructField('y1', FloatType()),
        StructField('y_embedding', VectorUDT()),
    ]
    schema = StructType(fields)

    # Ten references to one identical row; both vector columns hold 12 values.
    row = [1.0, 1, DenseVector([1.0] * 12), 1.0, DenseVector([1.0] * 12)]
    data = [row] * 10

    # NOTE(review): the session name does not match this test's name; it
    # looks carried over from another test -- confirm before renaming.
    with spark_session('test_df_cache') as spark:
        df = create_test_data_from_schema(spark, data, schema)
        metadata = util._get_metadata(df)

        # (input_shapes, output_shapes) pairs expected to pass. 3*4, 3*2*2
        # and 2*3*2 all equal the 12 elements stored per vector; -1 is
        # presumably a wildcard/batch dimension -- confirm against util.
        accepted = [
            ([[1], [1], [-1, 3, 4]], [[1], [-1, 3, 4]]),
            ([[1], [1], [3, 2, 2]], [[1, 1], [-1, 2, 3, 2]]),
        ]
        for good_inputs, good_outputs in accepted:
            util.check_shape_compatibility(
                metadata, feature_columns, label_columns,
                good_inputs, good_outputs)

        # The failing cases are paired with the most recent passing
        # counterpart so that only one side is wrong in each case.
        last_inputs, last_outputs = accepted[-1]
        rejected = [
            ([[1], [1], [-1, 3, 5]], last_outputs),
            ([[2], [1], [-1, 3, 4]], last_outputs),
            (last_inputs, [[7], [-1, 3, 4]]),
        ]
        for bad_inputs, bad_outputs in rejected:
            with pytest.raises(ValueError):
                util.check_shape_compatibility(
                    metadata, feature_columns, label_columns,
                    bad_inputs, bad_outputs)
def _check_metadata_compatibility(self, metadata):
    """Check dataset metadata against the configured input/label shapes.

    Column names and shapes are read from this object's getters and passed
    to ``util.check_shape_compatibility``, which performs the validation;
    any exception it raises propagates to the caller unchanged.

    Args:
        metadata: Dataset metadata forwarded as-is to the util helper.
    """
    feature_cols = self.getFeatureCols()
    label_cols = self.getLabelCols()
    expected_inputs = self.getInputShapes()
    expected_labels = self.getLabelShapes()
    util.check_shape_compatibility(
        metadata,
        feature_cols,
        label_cols,
        input_shapes=expected_inputs,
        label_shapes=expected_labels,
    )