def test_dataframe_to_xshards(self):
    """Check how ``_dataframe_to_xshards`` groups DataFrame rows into shards.

    With ``OrcaContext._shard_size`` set to ``None`` the number of shards is
    expected to equal the number of DataFrame partitions; with it set to ``1``
    the number of shards is expected to equal the number of rows.

    ``OrcaContext._shard_size`` is global state, so its previous value is
    saved and restored in ``finally``; otherwise the override would leak into
    every test that runs afterwards (even if an assertion here fails).
    """
    from zoo.orca import OrcaContext

    rdd = self.sc.range(0, 100)
    # 100 rows: a 50-element float feature vector and a random 0/1 label.
    df = rdd.map(lambda x: ([float(x)] * 50,
                            [int(np.random.randint(0, 2, size=()))])).toDF(
        ["feature", "label"])
    num_partitions = df.rdd.getNumPartitions()

    original_shard_size = OrcaContext._shard_size
    try:
        # test shard_size = None (set explicitly so a leak from another test
        # cannot change what this part exercises)
        OrcaContext._shard_size = None
        shards = _dataframe_to_xshards(df, feature_cols=["feature"], label_cols=["label"])
        num_shards = shards.rdd.count()
        assert num_shards == num_partitions

        # test shard_size = 1: one shard per row
        OrcaContext._shard_size = 1
        shards = _dataframe_to_xshards(df, feature_cols=["feature"], label_cols=["label"])
        num_shards = shards.rdd.count()
        assert num_shards == df.rdd.count()
    finally:
        OrcaContext._shard_size = original_shard_size
def test_convert_predict_xshards_to_dataframe_multi_output(self):
    """Check that multi-output predictions are attached back onto the DataFrame.

    Each shard's ``"x"`` array is split column-wise into two pieces
    (``[:, :25]`` and ``[:, 25:]``), which act as a two-output prediction.
    After ``convert_predict_xshards_to_dataframe``, flattening the
    ``prediction`` column must reproduce the original ``feature`` column on
    every row.
    """
    def make_row(idx):
        # 50 copies of the row index as floats, plus a random 0/1 label.
        return [float(idx)] * 50, [int(np.random.randint(0, 2, size=()))]

    def split_into_two_outputs(shard):
        # assumes shard["x"] is a 2-D array (rows, 50) — TODO confirm
        inputs = shard["x"]
        return {"prediction": [inputs[:, :25], inputs[:, 25:]]}

    source_df = self.sc.range(0, 100).map(make_row).toDF(["feature", "label"])
    feature_shards = _dataframe_to_xshards(source_df, feature_cols=["feature"])
    pred_shards = feature_shards.transform_shard(split_into_two_outputs)
    merged_df = convert_predict_xshards_to_dataframe(source_df, pred_shards)

    # Count rows whose flattened prediction differs from the input feature.
    mismatch_expr = "sum(cast(feature <> flatten(prediction) as int)) as error"
    mismatch_count = merged_df.selectExpr(mismatch_expr).first()["error"]
    assert mismatch_count == 0