Ejemplo n.º 1
0
 def _check_metadata_compatibility(self, metadata):
     """Validate the model's input/output shapes against column metadata.

     The shapes are obtained from ``self.get_model_shapes()`` and handed,
     together with the feature and label column names, to
     ``util.check_shape_compatibility`` for validation.

     Args:
         metadata: Column metadata for the training data. Presumably the
             dict produced by ``util._get_metadata(df)`` — confirm against
             the callers.

     Returns:
         None. The method is called only for its validation side effect.

     Raises:
         Whatever ``util.check_shape_compatibility`` raises on a mismatch
         (presumably ``ValueError`` — confirm in ``util``).
     """
     # Fetch the model shapes first, then the column names; this matches
     # the original left-to-right evaluation order of the arguments.
     model_in, model_out = self.get_model_shapes()
     features = self.getFeatureCols()
     labels = self.getLabelCols()
     util.check_shape_compatibility(
         metadata,
         features,
         labels,
         input_shapes=model_in,
         output_shapes=model_out,
     )
Ejemplo n.º 2
0
    def test_check_shape_compatibility(self):
        """check_shape_compatibility accepts matching shapes and rejects others.

        The DataFrame has three feature columns (two scalars and a
        12-element vector) and two label columns (a scalar and a 12-element
        vector).

        Shape lists whose per-column element counts line up with those
        columns must be accepted. For example, the vector columns take
        ``[-1, 3, 4]``, ``[3, 2, 2]`` or ``[-1, 2, 3, 2]``, each with 12
        elements once the -1 is set aside.

        Shape lists with a wrong element count for any column must raise
        ``ValueError``.
        """
        feature_cols = ['x1', 'x2', 'features']
        label_cols = ['y1', 'y_embedding']

        schema = StructType([
            StructField('x1', DoubleType()),
            StructField('x2', IntegerType()),
            StructField('features', VectorUDT()),
            StructField('y1', FloatType()),
            StructField('y_embedding', VectorUDT()),
        ])
        # A single row repeated ten times (the same list object each time).
        row = [1.0, 1, DenseVector([1.0] * 12), 1.0, DenseVector([1.0] * 12)]
        data = [row] * 10

        with spark_session('test_df_cache') as spark:
            df = create_test_data_from_schema(spark, data, schema)
            metadata = util._get_metadata(df)

            def check(in_shapes, out_shapes):
                # Positional input/output shapes, as in the direct calls.
                util.check_shape_compatibility(metadata, feature_cols,
                                               label_cols, in_shapes,
                                               out_shapes)

            # Accepted: the element counts match every column.
            check([[1], [1], [-1, 3, 4]], [[1], [-1, 3, 4]])

            good_inputs = [[1], [1], [3, 2, 2]]
            good_outputs = [[1, 1], [-1, 2, 3, 2]]
            check(good_inputs, good_outputs)

            rejected = [
                # 3 * 5 = 15 elements, but 'features' holds 12.
                ([[1], [1], [-1, 3, 5]], good_outputs),
                # 'x1' is a scalar column, not a 2-element input.
                ([[2], [1], [-1, 3, 4]], good_outputs),
                # 'y1' is a scalar column, not a 7-element output.
                (good_inputs, [[7], [-1, 3, 4]]),
            ]
            for bad_in, bad_out in rejected:
                with pytest.raises(ValueError):
                    check(bad_in, bad_out)
Ejemplo n.º 3
0
 def _check_metadata_compatibility(self, metadata):
     """Validate the configured input and label shapes against column metadata.

     The shapes come from ``self.getInputShapes()`` and
     ``self.getLabelShapes()``. They are passed to
     ``util.check_shape_compatibility`` as ``input_shapes`` and
     ``label_shapes``, together with the feature and label column names.
     No ``output_shapes`` argument is supplied here.

     Args:
         metadata: Column metadata for the training data. Presumably the
             dict produced by ``util._get_metadata(df)`` — confirm against
             the callers.

     Returns:
         None. The method is called only for its validation side effect.

     Raises:
         Whatever ``util.check_shape_compatibility`` raises on a mismatch
         (presumably ``ValueError`` — confirm in ``util``).
     """
     # Same getter call order as the original argument list.
     feature_cols = self.getFeatureCols()
     label_cols = self.getLabelCols()
     declared_inputs = self.getInputShapes()
     declared_labels = self.getLabelShapes()
     util.check_shape_compatibility(
         metadata,
         feature_cols,
         label_cols,
         input_shapes=declared_inputs,
         label_shapes=declared_labels,
     )