Example #1
0
    def test_transform_multi_class(self):
        """Transform XOR data with a 2-output (multi-class) model and check
        that every output column carries the expected Spark data type."""
        # output_dim=2 mocks a multi-class model: predictions become vectors.
        net = create_xor_model(output_dim=2)

        with spark_session('test_transform_multi_class') as spark:
            xor_df = create_xor_data(spark)
            meta = util._get_metadata(xor_df)

            spark_model = hvd_spark.TorchModel(
                history=None,
                model=net,
                input_shapes=[[2]],
                feature_columns=['features'],
                label_columns=['y'],
                _metadata=meta)
            transformed = spark_model.transform(xor_df)

            # With a multi-class model the prediction column ('y__output')
            # is a vector, while the original label 'y' stays a scalar.
            type_by_column = {
                'x1': IntegerType,
                'x2': IntegerType,
                'features': VectorUDT,
                'weight': FloatType,
                'y': FloatType,
                'y__output': VectorUDT,
            }

            for schema_field in transformed.schema.fields:
                actual_type = type(schema_field.dataType)
                assert actual_type == type_by_column[schema_field.name]
Example #2
0
    def test_transform_multi_class(self):
        """Run a multi-class (output_dim=2) model through transform() and
        verify the schema types of the resulting DataFrame."""
        net = create_xor_model(output_dim=2)

        with spark_session('test_transform_multi_class') as spark:
            xor_df = create_xor_data(spark)
            meta = util._get_metadata(xor_df)

            spark_model = hvd_spark.TorchModel(
                history=None,
                model=net,
                input_shapes=[[2]],
                feature_columns=['features'],
                label_columns=['y'],
                _metadata=meta)
            transformed = spark_model.transform(xor_df)

            # Prediction column is a vector (multi-class); label remains scalar.
            type_by_column = {
                'x1': LongType,
                'x2': LongType,
                'features': VectorUDT,
                'weight': DoubleType,
                'y': DoubleType,
                'y__output': VectorUDT,
            }

            assert all(type(f.dataType) == type_by_column[f.name]
                       for f in transformed.schema.fields)
Example #3
0
    def test_get_metadata(self):
        """_get_metadata should classify each column by Spark type, sparsity,
        intermediate storage format, max element count, and shape."""
        expected_metadata = {
            'float': {
                'spark_data_type': FloatType,
                'is_sparse_vector_only': False,
                'intermediate_format': constants.NOCHANGE,
                'max_size': 1,
                'shape': 1,
            },
            'dense': {
                'spark_data_type': DenseVector,
                'is_sparse_vector_only': False,
                'intermediate_format': constants.ARRAY,
                'max_size': 2,
                'shape': 2,
            },
            'sparse': {
                'spark_data_type': SparseVector,
                'is_sparse_vector_only': True,
                'intermediate_format': constants.CUSTOM_SPARSE,
                'max_size': 1,
                'shape': 2,
            },
            # A column mixing dense and sparse rows is treated as dense.
            'mixed': {
                'spark_data_type': DenseVector,
                'is_sparse_vector_only': False,
                'intermediate_format': constants.ARRAY,
                'max_size': 2,
                'shape': 2,
            },
        }

        with spark_session('test_get_metadata') as spark:
            # Two rows: the 'mixed' column is dense in the first row and
            # sparse in the second.
            row1 = [1.0,
                    DenseVector([1.0, 1.0]),
                    SparseVector(2, {0: 1.0}),
                    DenseVector([1.0, 1.0])]
            row2 = [1.0,
                    DenseVector([1.0, 1.0]),
                    SparseVector(2, {1: 1.0}),
                    SparseVector(2, {1: 1.0})]
            schema = StructType([
                StructField('float', FloatType()),
                StructField('dense', VectorUDT()),
                StructField('sparse', VectorUDT()),
                StructField('mixed', VectorUDT()),
            ])
            df = create_test_data_from_schema(spark, [row1, row2], schema)

            metadata = util._get_metadata(df)
            self.assertDictEqual(metadata, expected_metadata)
Example #4
0
    def test_check_shape_compatibility(self):
        """check_shape_compatibility should accept shapes whose element counts
        match the column metadata and raise ValueError otherwise."""
        feature_columns = ['x1', 'x2', 'features']
        label_columns = ['y1', 'y_embedding']

        schema = StructType([
            StructField('x1', DoubleType()),
            StructField('x2', IntegerType()),
            StructField('features', VectorUDT()),
            StructField('y1', FloatType()),
            StructField('y_embedding', VectorUDT()),
        ])
        # Ten identical rows; both vector columns hold 12 elements.
        row = [1.0, 1, DenseVector([1.0] * 12), 1.0, DenseVector([1.0] * 12)]
        data = [row] * 10

        with spark_session('test_df_cache') as spark:
            df = create_test_data_from_schema(spark, data, schema)
            metadata = util._get_metadata(df)

            # Shapes whose total element counts match the data (a -1 dim is
            # treated as a wildcard) must pass without raising.
            compatible_cases = [
                ([[1], [1], [-1, 3, 4]], [[1], [-1, 3, 4]]),
                ([[1], [1], [3, 2, 2]], [[1, 1], [-1, 2, 3, 2]]),
            ]
            for in_shapes, out_shapes in compatible_cases:
                util.check_shape_compatibility(metadata, feature_columns,
                                               label_columns, in_shapes,
                                               out_shapes)

            # Shapes with mismatched element counts must raise ValueError:
            # a 15-element feature shape, a 2-element scalar-feature shape,
            # and a 7-element scalar-label shape respectively.
            incompatible_cases = [
                ([[1], [1], [-1, 3, 5]], [[1], [-1, 3, 4]]),
                ([[2], [1], [-1, 3, 4]], [[1], [-1, 3, 4]]),
                ([[1], [1], [3, 2, 2]], [[7], [-1, 3, 4]]),
            ]
            for in_shapes, out_shapes in incompatible_cases:
                with pytest.raises(ValueError):
                    util.check_shape_compatibility(metadata, feature_columns,
                                                   label_columns, in_shapes,
                                                   out_shapes)
Example #5
0
    def test_prepare_data(self):
        """For an ARRAY-format column, the remote prepare-data function should
        return the feature tensor unchanged."""
        with spark_session('test_prepare_data') as spark:
            xor_df = create_xor_data(spark)

            row_count = xor_df.count()
            schema_cols = ['features', 'y']
            meta = util._get_metadata(xor_df)
            # Sanity check: 'features' is stored in the ARRAY intermediate
            # format, so _prepare_data_fn should be a no-op for it.
            assert meta['features']['intermediate_format'] == constants.ARRAY

            to_petastorm = util.to_petastorm_fn(schema_cols, meta)
            rows = xor_df.rdd.map(to_petastorm).toDF().collect()

            prepare_data = remote._prepare_data_fn(meta)
            feature_tensor = torch.tensor(
                [rows[idx].features for idx in range(row_count)])
            prepared = prepare_data('features', feature_tensor)
            assert np.array_equal(prepared, feature_tensor)