def test_transform_multi_class(self):
    """Check the output schema of ``TorchModel.transform`` for a multi-class model.

    A model created with ``output_dim=2`` mocks a multi-class classifier: its
    prediction column ``y__output`` is a vector while the label ``y`` is a
    scalar. The transformed DataFrame must contain exactly the expected
    columns, each with exactly the expected Spark SQL type class.
    """
    # NOTE(review): another `test_transform_multi_class` with different
    # expected types (LongType/DoubleType) appears alongside this one; if both
    # live in the same class the later definition shadows this one. Confirm
    # which schema `create_xor_data` actually produces and drop the other.

    # set dim as 2, to mock a multi class model.
    model = create_xor_model(output_dim=2)

    with spark_session('test_transform_multi_class') as spark:
        df = create_xor_data(spark)
        metadata = util._get_metadata(df)

        torch_model = hvd_spark.TorchModel(history=None,
                                           model=model,
                                           input_shapes=[[2]],
                                           feature_columns=['features'],
                                           label_columns=['y'],
                                           _metadata=metadata)
        out_df = torch_model.transform(df)

        # in multi class model, model output is a vector but label is number.
        expected_types = {
            'x1': IntegerType,
            'x2': IntegerType,
            'features': VectorUDT,
            'weight': FloatType,
            'y': FloatType,
            'y__output': VectorUDT,
        }

        # Compare the whole name -> exact-type mapping in one assertion. A
        # per-field loop passes vacuously when a column (e.g. 'y__output') is
        # missing, and fails with a bare KeyError on an unexpected column.
        actual_types = {field.name: type(field.dataType)
                        for field in out_df.schema.fields}
        assert actual_types == expected_types
def test_transform_multi_class(self):
    """Check the output schema of ``TorchModel.transform`` for a multi-class model.

    A model created with ``output_dim=2`` produces a vector prediction column
    ``y__output``. The transformed DataFrame must contain exactly the expected
    columns, each with exactly the expected Spark SQL type class.
    """
    # NOTE(review): another `test_transform_multi_class` with different
    # expected types (IntegerType/FloatType) appears alongside this one; if
    # both live in the same class the later definition shadows the earlier.
    # Confirm which schema `create_xor_data` actually produces and drop the
    # other.
    model = create_xor_model(output_dim=2)

    with spark_session('test_transform_multi_class') as spark:
        df = create_xor_data(spark)
        metadata = util._get_metadata(df)

        torch_model = hvd_spark.TorchModel(history=None,
                                           model=model,
                                           input_shapes=[[2]],
                                           feature_columns=['features'],
                                           label_columns=['y'],
                                           _metadata=metadata)
        out_df = torch_model.transform(df)

        # The prediction column is a vector even though the label 'y' is scalar.
        expected_types = {
            'x1': LongType,
            'x2': LongType,
            'features': VectorUDT,
            'weight': DoubleType,
            'y': DoubleType,
            'y__output': VectorUDT,
        }

        # Compare the whole name -> exact-type mapping in one assertion. A
        # per-field loop passes vacuously when a column (e.g. 'y__output') is
        # missing, and fails with a bare KeyError on an unexpected column.
        actual_types = {field.name: type(field.dataType)
                        for field in out_df.schema.fields}
        assert actual_types == expected_types