def _create_placeholders(self):
    import tensorflow as tf
    if not self.hard_code_batch_size:
        tensors = nest.pack_sequence_as(
            self.tensor_structure,
            [tf.placeholder(name=t.name, dtype=t.dtype, shape=[None] + list(t.shape))
             for t in nest.flatten(self.tensor_structure)])
    else:
        if self.batch_per_thread > 0:
            tensors = nest.pack_sequence_as(
                self.tensor_structure,
                [tf.placeholder(name=t.name, dtype=t.dtype,
                                shape=[self.batch_per_thread] + list(t.shape))
                 for t in nest.flatten(self.tensor_structure)])
        else:
            tensors = nest.pack_sequence_as(
                self.tensor_structure,
                [tf.placeholder(name=t.name, dtype=t.dtype,
                                shape=[self.batch_size // self.total_core_num] + list(t.shape))
                 for t in nest.flatten(self.tensor_structure)])

    for tensor in nest.flatten(tensors):
        tf.get_default_graph().clear_collection(tensor.name)
        tf.add_to_collection(tensor.name, self)

    return tensors
def _tensors_to_rdd(tensors, sc, splits):
    import tensorflow as tf
    if isinstance(tensors, np.ndarray):
        tensors = (tensors,)

    if isinstance(tensors, list):
        for i in range(len(tensors)):
            if tensors[i].dtype == np.dtype("float64"):
                tensors[i] = np.float32(tensors[i])

        data_list = _splits(tensors)
        rdd = sc.parallelize(data_list, splits)
        tensor_structure = [TensorMeta(tf.as_dtype(t.dtype),
                                       shape=t.shape[1:],
                                       name="input_%s" % i)
                            for i, t in enumerate(tensors)]
    else:
        flattened = nest.flatten(tensors)
        for i in range(len(flattened)):
            if flattened[i].dtype == np.dtype("float64"):
                flattened[i] = np.float32(flattened[i])

        data_list = _splits(flattened)
        rdd = sc.parallelize(data_list, splits)
        rdd = rdd.map(lambda x: nest.pack_sequence_as(tensors, x))
        tensor_structure = nest.pack_sequence_as(tensors,
                                                 [TensorMeta(tf.as_dtype(t.dtype),
                                                             shape=t.shape[1:],
                                                             name="input_%s" % i)
                                                  for i, t in enumerate(flattened)])
    return rdd, tensor_structure
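# The `_splits` helper used above is not shown in this file. Below is a minimal sketch of
# what it plausibly does, assuming it slices every array along the first (record) dimension
# so each RDD element carries one sample; the name `_splits_sketch` and the toy data are
# illustrative assumptions, not the library's actual implementation.
import numpy as np

def _splits_sketch(tensors):
    # tensors: a list/tuple of ndarrays sharing the same first-dimension length
    num_records = tensors[0].shape[0]
    return [[t[i] for t in tensors] for i in range(num_records)]

# Example: two aligned arrays with 4 records each become 4 RDD-ready elements.
features = np.arange(8, dtype=np.float32).reshape(4, 2)
labels = np.arange(4, dtype=np.float32)
records = _splits_sketch([features, labels])
assert len(records) == 4 and records[0][0].shape == (2,)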
def __init__(self, rdd, tensor_structure, batch_size,
             batch_per_thread, hard_code_batch_size=False, val_rdd=None):
    '''
    TFDataset represents a distributed collection of elements to be fed into a
    TensorFlow graph. A TFDataset can be created from an RDD in which each record
    is one or more numpy.ndarray of the same nested structure, representing the
    tensors to be fed into the TensorFlow graph on each iteration. A TFDataset
    must be used with TFOptimizer or TFPredictor.
    '''
    if batch_size > 0 and batch_per_thread > 0:
        raise ValueError("batch_size and batch_per_thread should not be set simultaneously")

    self.has_batch = True
    node_num, core_num = get_node_and_core_number()
    self.total_core_num = node_num * core_num
    if batch_size > 0:
        if batch_size % self.total_core_num != 0:
            raise ValueError("batch_size should be a multiple "
                             "of total core number, but got batch_size: "
                             "%s where total core number is %s" % (batch_size,
                                                                   self.total_core_num))
    if batch_size <= 0 and batch_per_thread <= 0:
        batch_per_thread = 1
        batch_size = self.total_core_num
        self.has_batch = False

    self.batch_size = batch_size
    self.batch_per_thread = batch_per_thread
    self.hard_code_batch_size = hard_code_batch_size
    self.tensor_structure = tensor_structure
    self.val_rdd = val_rdd

    if not self.hard_code_batch_size:
        self.output_shapes = nest.pack_sequence_as(
            self.tensor_structure, [[None] + list(t.shape)
                                    if t is not None else None
                                    for t in nest.flatten(self.tensor_structure)])
    else:
        if self.batch_per_thread > 0:
            self.output_shapes = nest.pack_sequence_as(
                self.tensor_structure, [[self.batch_per_thread] + t.shape
                                        if t is not None else None
                                        for t in nest.flatten(self.tensor_structure)])
        else:
            self.output_shapes = nest.pack_sequence_as(
                self.tensor_structure, [[self.batch_size // self.total_core_num] + t.shape
                                        if t is not None else None
                                        for t in nest.flatten(self.tensor_structure)])

    self.rdd = rdd
    self.input_names = nest.pack_sequence_as(
        self.tensor_structure, [t.name
                                if t is not None else None
                                for t in nest.flatten(self.tensor_structure)])

    self._tensors = None
def __init__(self, file_path, parse_fn, batch_size,
             batch_per_thread, hard_code_batch_size=False, validation_file_path=None):
    import tensorflow as tf
    g = tf.Graph()
    with g.as_default():
        serialized_example = tf.placeholder(dtype=tf.string, shape=[])
        results = parse_fn(serialized_example)

        flattened = nest.flatten(results)
        output_names = [tf.cast(t, dtype=tf.float32).name for t in flattened]

    serialized_graph = bytearray(g.as_graph_def().SerializeToString())

    sc = getOrCreateSparkContext()
    train_rdd = callBigDlFunc("float", "createRDDFromTFRecords",
                              file_path, sc, serialized_graph,
                              serialized_example.name, output_names)
    validation_rdd = None
    if validation_file_path is not None:
        validation_rdd = callBigDlFunc("float", "createRDDFromTFRecords",
                                       validation_file_path, sc, serialized_graph,
                                       serialized_example.name, output_names)

    tensor_structure = nest.pack_sequence_as(results,
                                             [TensorMeta(tf.as_dtype(t.dtype),
                                                         shape=t.shape,
                                                         name="data_%s" % i)
                                              for i, t in enumerate(nest.flatten(results))])

    super(TFRecordDataset, self).__init__(tensor_structure, batch_size,
                                          batch_per_thread, hard_code_batch_size)

    self.train_rdd = train_rdd
    self.validation_rdd = validation_rdd
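# A hedged sketch of the kind of `parse_fn` this constructor expects: it maps one serialized
# tf.Example string to a (possibly nested) structure of tensors. The feature names ("image",
# "label") and shapes are hypothetical; adapt them to the actual TFRecord schema. Written
# against the TF 1.x API used above (tf.placeholder / tf.parse_single_example).
import tensorflow as tf

def example_parse_fn(serialized_example):
    features = tf.parse_single_example(
        serialized_example,
        features={
            "image": tf.FixedLenFeature([784], tf.float32),
            "label": tf.FixedLenFeature([], tf.int64),
        })
    image = tf.reshape(features["image"], [28, 28, 1])
    label = tf.cast(features["label"], tf.int32)
    return image, label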
def partition(data, num_shards=None):
    """
    Partition local in-memory data and form a SparkXShards.

    :param data: np.ndarray, a tuple, list, dict of np.ndarray, or a nested structure
    made of tuple, list, dict with ndarray as the leaf value
    :param num_shards: the number of shards that the data will be partitioned into
    :return: a SparkXShards
    """
    sc = init_nncontext()
    node_num, core_num = get_node_and_core_number()
    shard_num = node_num * core_num if num_shards is None else num_shards
    import numpy as np
    type_err_msg = """
The types supported in zoo.orca.data.XShards.partition are
1. np.ndarray
2. a tuple, list, dict of np.ndarray
3. nested structure made of tuple, list, dict with ndarray as the leaf value

But got data of type {}
""".format(type(data))
    supported_types = {list, tuple, dict}
    if isinstance(data, np.ndarray):
        if data.shape[0] < shard_num:
            raise ValueError("The length of data {} is smaller than the total number "
                             "of shards {}. Please adjust the num_shards option to be "
                             "at most {}.".format(data.shape[0], shard_num, data.shape[0]))
        arrays = np.array_split(data, shard_num)
        rdd = sc.parallelize(arrays)
    else:
        assert type(data) in supported_types, type_err_msg
        flattened = nest.flatten(data)
        data_length = len(flattened[0])
        data_to_be_shard = []
        if data_length < shard_num:
            raise ValueError("The length of data {} is smaller than the total number "
                             "of shards {}. Please adjust the num_shards option to be "
                             "at most {}.".format(data_length, shard_num, data_length))
        for i in range(shard_num):
            data_to_be_shard.append([])
        for x in flattened:
            assert len(x) == data_length, \
                "the ndarrays in data must all have the same size in first dimension, " \
                "got first ndarray of size {} and another {}".format(data_length, len(x))
            x_parts = np.array_split(x, shard_num)
            for idx, x_part in enumerate(x_parts):
                data_to_be_shard[idx].append(x_part)

        data_to_be_shard = [nest.pack_sequence_as(data, shard) for shard in data_to_be_shard]
        rdd = sc.parallelize(data_to_be_shard)

    data_shards = SparkXShards(rdd)
    return data_shards
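# A minimal, Spark-free illustration of the sharding scheme above for a dict of ndarrays:
# every leaf array is split along the first dimension into `shard_num` pieces, and the pieces
# are regrouped so each shard keeps the original nested structure. Pure numpy; the dict keys
# and sizes are made up for the example.
import numpy as np

data = {"x": np.random.rand(10, 3).astype(np.float32), "y": np.arange(10)}
shard_num = 4
parts = {k: np.array_split(v, shard_num) for k, v in data.items()}
shards = [{k: parts[k][i] for k in data} for i in range(shard_num)]
assert len(shards) == shard_num and shards[0]["x"].shape[0] == 3  # 10 rows -> 3, 3, 2, 2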
def _tensors_to_rdd(tensors, sc, splits):
    import tensorflow as tf
    if isinstance(tensors, list):
        data_list = _splits(tensors)
        rdd = sc.parallelize(data_list, splits)
        tensor_structure = [TensorMeta(tf.as_dtype(t.dtype),
                                       shape=t.shape[1:],
                                       name="input_%s" % i)
                            for i, t in enumerate(tensors)]
    else:
        flattened = nest.flatten(tensors)
        data_list = _splits(flattened)
        rdd = sc.parallelize(data_list, splits)
        rdd = rdd.map(lambda x: nest.pack_sequence_as(tensors, x))
        tensor_structure = nest.pack_sequence_as(tensors,
                                                 [TensorMeta(tf.as_dtype(t.dtype),
                                                             shape=t.shape[1:],
                                                             name="input_%s" % i)
                                                  for i, t in enumerate(flattened)])
    return rdd, tensor_structure
def to_dataset(iter):
    data_list = list(iter)

    import tensorflow as tf
    if not data_list:
        return []

    datasets = [create_dataset_fn(data) for data in data_list]
    from functools import reduce
    dataset = reduce(lambda x, y: x.concatenate(y), datasets)
    dataset = dataset.batch(batch_per_shard, drop_remainder)
    iterator = dataset.make_initializable_iterator()
    train_next_ops = nest.flatten(iterator.get_next())
    output_types = [t.as_datatype_enum for t in nest.flatten(dataset.output_types)]

    init_op_name = iterator.initializer.name
    table_init_op = tf.tables_initializer().name
    output_names = [op.name for op in train_next_ops]

    graph = train_next_ops[0].graph

    flatten_shapes = nest.flatten(dataset.output_shapes)
    flatten_shapes = [shape[1:] for shape in flatten_shapes]

    flatten_tensor_structure = [TensorMeta(dtype=output_types[i],
                                           shape=list(flatten_shapes[i]),
                                           name="zoo_input_{}".format(i))
                                for i in range(len(flatten_shapes))]
    structure = dataset.output_types
    if isinstance(structure, tf.DType):
        structure = (structure,)
    tensor_structure = nest.pack_sequence_as(structure, flatten_tensor_structure)

    meta_info = {
        "init_op_name": init_op_name,
        "table_init_op": table_init_op,
        "output_names": output_names,
        "output_types": output_types,
        "tensor_structure": tensor_structure
    }

    return [(bytearray(graph.as_graph_def().SerializeToString()), meta_info)]
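# A small standalone sketch of the reduce-concatenate-then-batch pattern used above, assuming
# the TF 1.x graph-mode API that this code targets (make_initializable_iterator, tf.Session);
# the toy arrays stand in for per-shard datasets produced by `create_dataset_fn`.
from functools import reduce
import numpy as np
import tensorflow as tf

shards = [np.arange(5, dtype=np.float32), np.arange(5, 8, dtype=np.float32)]
datasets = [tf.data.Dataset.from_tensor_slices(a) for a in shards]
dataset = reduce(lambda x, y: x.concatenate(y), datasets).batch(4, drop_remainder=False)
iterator = dataset.make_initializable_iterator()
next_op = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer)
    print(sess.run(next_op))  # first batch of 4 elements from the concatenated shards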
def partition(data):
    """
    Partition local in-memory data and form a SparkXShards.

    :param data: np.ndarray, a tuple, list, dict of np.ndarray, or a nested structure
    made of tuple, list, dict with ndarray as the leaf value
    :return: a SparkXShards
    """
    sc = init_nncontext()
    node_num, core_num = get_node_and_core_number()
    total_core_num = node_num * core_num
    import numpy as np
    type_err_msg = """
The types supported in zoo.orca.data.XShards.partition are
1. np.ndarray
2. a tuple, list, dict of np.ndarray
3. nested structure made of tuple, list, dict with ndarray as the leaf value

But got data of type {}
""".format(type(data))
    supported_types = {list, tuple, dict}
    if isinstance(data, np.ndarray):
        arrays = np.array_split(data, total_core_num)
        rdd = sc.parallelize(arrays)
    else:
        assert type(data) in supported_types, type_err_msg
        flattened = nest.flatten(data)
        data_length = len(flattened[0])
        data_to_be_shard = []
        for i in range(total_core_num):
            data_to_be_shard.append([])
        for x in flattened:
            assert len(x) == data_length, \
                "the ndarrays in data must all have the same size in first dimension, " \
                "got first ndarray of size {} and another {}".format(data_length, len(x))
            x_parts = np.array_split(x, total_core_num)
            for idx, x_part in enumerate(x_parts):
                data_to_be_shard[idx].append(x_part)

        data_to_be_shard = [nest.pack_sequence_as(data, shard) for shard in data_to_be_shard]
        rdd = sc.parallelize(data_to_be_shard)

    data_shards = SparkXShards(rdd)
    return data_shards
def __init__(self, tensor_structure, batch_size,
             batch_per_thread, hard_code_batch_size=False):
    '''
    TFDataset represents a distributed collection of elements (backed by an RDD)
    to be fed into a TensorFlow graph.

    :param tensor_structure: a nested structure of TensorMeta objects specifying
    the name, shape and data type of each element in this TFDataset
    :param batch_size: the batch size used for training; it should be a multiple
    of the total core number
    :param batch_per_thread: the batch size for each thread, used for inference
    or evaluation
    :param hard_code_batch_size: whether to hard code the batch size into the
    TensorFlow graph. If True, the static size of the first dimension of the
    resulting tensors is batch_size // total_core_num for training or
    batch_per_thread for inference; if False, it is None.
    '''
    if batch_size > 0 and batch_per_thread > 0:
        raise ValueError("batch_size and batch_per_thread should not be set simultaneously")

    self.has_batch = True
    node_num, core_num = get_node_and_core_number()
    self.total_core_num = node_num * core_num
    if batch_size > 0:
        if batch_size % self.total_core_num != 0:
            raise ValueError("batch_size should be a multiple "
                             "of total core number, but got batch_size: "
                             "%s where total core number is %s" % (batch_size,
                                                                   self.total_core_num))
    if batch_size <= 0 and batch_per_thread <= 0:
        batch_per_thread = 1
        batch_size = self.total_core_num
        self.has_batch = False

    self.batch_size = batch_size
    self.batch_per_thread = batch_per_thread
    self.hard_code_batch_size = hard_code_batch_size
    self.tensor_structure = tensor_structure

    if not self.hard_code_batch_size:
        self.output_shapes = nest.pack_sequence_as(
            self.tensor_structure, [[None] + list(t.shape)
                                    if t is not None else None
                                    for t in nest.flatten(self.tensor_structure)])
    else:
        if self.batch_per_thread > 0:
            self.output_shapes = nest.pack_sequence_as(
                self.tensor_structure, [[self.batch_per_thread] + t.shape
                                        if t is not None else None
                                        for t in nest.flatten(self.tensor_structure)])
        else:
            self.output_shapes = nest.pack_sequence_as(
                self.tensor_structure, [[self.batch_size // self.total_core_num] + t.shape
                                        if t is not None else None
                                        for t in nest.flatten(self.tensor_structure)])

    self.input_names = nest.pack_sequence_as(
        self.tensor_structure, [t.name
                                if t is not None else None
                                for t in nest.flatten(self.tensor_structure)])

    self._tensors = None
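# A tiny illustration of how `output_shapes` is derived from `tensor_structure`: when the
# batch size is not hard-coded, the first dimension becomes None, otherwise it becomes the
# per-thread batch. `TensorMetaStub` is a stand-in for the real TensorMeta, and the structure
# and shapes below are made up for the example.
from collections import namedtuple

TensorMetaStub = namedtuple("TensorMetaStub", ["name", "shape"])
structure = {"image": TensorMetaStub("image", [28, 28, 1]),
             "label": TensorMetaStub("label", [])}

soft = {k: [None] + list(t.shape) for k, t in structure.items()}
hard = {k: [8] + list(t.shape) for k, t in structure.items()}   # batch_per_thread == 8
assert soft["image"] == [None, 28, 28, 1] and hard["label"] == [8]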
def predict(self, data, **kwargs):

    def predict_transform(dict_data, batch_size):
        assert isinstance(dict_data, dict), "each shard should be a dict"
        assert "x" in dict_data, "key 'x' should be in each shard"
        feature_data = dict_data["x"]
        if isinstance(feature_data, np.ndarray):
            assert feature_data.shape[1] <= batch_size, \
                "The batch size of input data (the second dim) should be less than the model " \
                "batch size, otherwise some inputs will be ignored."
        elif isinstance(feature_data, list):
            for elem in feature_data:
                assert isinstance(elem, np.ndarray), "Each element in the x list should be " \
                                                     "an ndarray, but got " + \
                                                     elem.__class__.__name__
                assert elem.shape[1] <= batch_size, \
                    "The batch size of each input data (the second dim) should be less than " \
                    "the model batch size, otherwise some inputs will be ignored."
        else:
            raise ValueError("x in each shard should be an ndarray or a list of ndarrays.")
        return dict_data["x"]

    sc = init_nncontext()

    if isinstance(data, SparkXShards):
        assert sc is not None, "You should pass sc (SparkContext) if data is an XShards."
        from zoo.orca.learn.utils import convert_predict_to_xshard
        data = data.transform_shard(predict_transform, self.batch_size)
        result_rdd = self.model.distributed_predict(data.rdd, sc)
        return convert_predict_to_xshard(result_rdd)
    elif isinstance(data, (np.ndarray, list)):
        total_core_num = self.core_num * self.node_num
        if isinstance(data, np.ndarray):
            assert data.shape[1] <= self.batch_size, \
                "The batch size of input data (the second dim) should be less than the " \
                "model batch size, otherwise some inputs will be ignored."
            split_num = min(total_core_num, data.shape[0])
            arrays = np.array_split(data, split_num)
            data_rdd = sc.parallelize(arrays, numSlices=split_num)
        elif isinstance(data, list):
            flattened = nest.flatten(data)
            data_length = len(flattened[0])
            data_to_be_rdd = []
            split_num = min(total_core_num, flattened[0].shape[0])
            for i in range(split_num):
                data_to_be_rdd.append([])
            for x in flattened:
                assert isinstance(x, np.ndarray), "the data in the data list should be " \
                                                  "ndarrays, but got " + \
                                                  x.__class__.__name__
                assert len(x) == data_length, \
                    "the ndarrays in data must all have the same size in first dimension" \
                    ", got first ndarray of size {} and another {}".format(data_length, len(x))
                assert x.shape[1] <= self.batch_size, \
                    "The batch size of each input data (the second dim) should be less than " \
                    "the model batch size, otherwise some inputs will be ignored."
                x_parts = np.array_split(x, split_num)
                for idx, x_part in enumerate(x_parts):
                    data_to_be_rdd[idx].append(x_part)

            data_to_be_rdd = [nest.pack_sequence_as(data, shard) for shard in data_to_be_rdd]
            data_rdd = sc.parallelize(data_to_be_rdd, numSlices=split_num)

        result_rdd = self.model.distributed_predict(data_rdd, sc)
        result_arr_list = result_rdd.collect()
        result_arr = np.concatenate(result_arr_list, axis=0)
        return result_arr
    else:
        raise ValueError("Only XShards, a numpy array and a list of numpy arrays are "
                         "supported as input data, but got " + data.__class__.__name__)
def predict(self, data, feature_cols=None):
    """
    Predict input data.

    :param data: data to be predicted. XShards, Spark DataFrame, numpy array and list of
           numpy arrays are supported. If data is an XShards, each partition is a dictionary
           of {'x': feature}, where feature is a numpy array or a list of numpy arrays.
    :param feature_cols: Feature column name(s) of data. Only used when data is a Spark
           DataFrame. Default: None.
    :return: predicted result. If the input data is an XShards, the predict result is an
             XShards, and each partition of the XShards is a dictionary of
             {'prediction': result}, where the result is a numpy array or a list of numpy
             arrays. If the input data is a numpy array or a list of numpy arrays, the
             predict result is a numpy array or a list of numpy arrays.
    """
    from pyspark.sql import DataFrame

    def predict_transform(dict_data, batch_size):
        assert isinstance(dict_data, dict), "each shard should be a dict"
        assert "x" in dict_data, "key 'x' should be in each shard"
        feature_data = dict_data["x"]
        if isinstance(feature_data, np.ndarray):
            assert feature_data.shape[0] <= batch_size, \
                "The batch size of input data (the first dim) should be less than the model " \
                "batch size, otherwise some inputs will be ignored."
        elif isinstance(feature_data, list):
            for elem in feature_data:
                assert isinstance(elem, np.ndarray), "Each element in the x list should be " \
                                                     "an ndarray, but got " + \
                                                     elem.__class__.__name__
                assert elem.shape[0] <= batch_size, \
                    "The batch size of each input data (the first dim) should be less than " \
                    "the model batch size, otherwise some inputs will be ignored."
        else:
            raise ValueError("x in each shard should be an ndarray or a list of ndarrays.")
        return feature_data

    sc = init_nncontext()

    if isinstance(data, DataFrame):
        from zoo.orca.learn.utils import dataframe_to_xshards, convert_predict_rdd_to_dataframe
        xshards, _ = dataframe_to_xshards(data,
                                          validation_data=None,
                                          feature_cols=feature_cols,
                                          label_cols=None,
                                          mode="predict")
        transformed_data = xshards.transform_shard(predict_transform, self.batch_size)
        result_rdd = self.model.distributed_predict(transformed_data.rdd, sc)

        def delete_useless_result(data):
            shard, y = data
            data_length = len(shard["x"])
            return y[:data_length]

        result_rdd = xshards.rdd.zip(result_rdd).map(delete_useless_result)
        return convert_predict_rdd_to_dataframe(data, result_rdd.flatMap(lambda data: data))
    elif isinstance(data, SparkXShards):
        transformed_data = data.transform_shard(predict_transform, self.batch_size)
        result_rdd = self.model.distributed_predict(transformed_data.rdd, sc)

        def update_shard(data):
            shard, y = data
            data_length = len(shard["x"])
            shard["prediction"] = y[:data_length]
            return shard

        return SparkXShards(data.rdd.zip(result_rdd).map(update_shard))
    elif isinstance(data, (np.ndarray, list)):
        if isinstance(data, np.ndarray):
            split_num = math.ceil(len(data) / self.batch_size)
            arrays = np.array_split(data, split_num)
            data_length_list = list(map(lambda arr: len(arr), arrays))
            data_rdd = sc.parallelize(arrays, numSlices=split_num)
        elif isinstance(data, list):
            flattened = nest.flatten(data)
            data_length = len(flattened[0])
            data_to_be_rdd = []
            split_num = math.ceil(flattened[0].shape[0] / self.batch_size)
            # record the real length of each split so padded predictions can be trimmed below
            data_length_list = [len(part) for part in np.array_split(flattened[0], split_num)]
            for i in range(split_num):
                data_to_be_rdd.append([])
            for x in flattened:
                assert isinstance(x, np.ndarray), "the data in the data list should be " \
                                                  "ndarrays, but got " + \
                                                  x.__class__.__name__
                assert len(x) == data_length, \
                    "the ndarrays in data must all have the same size in first dimension" \
                    ", got first ndarray of size {} and another {}".format(data_length, len(x))
                x_parts = np.array_split(x, split_num)
                for idx, x_part in enumerate(x_parts):
                    data_to_be_rdd[idx].append(x_part)

            data_to_be_rdd = [nest.pack_sequence_as(data, shard) for shard in data_to_be_rdd]
            data_rdd = sc.parallelize(data_to_be_rdd, numSlices=split_num)

        result_rdd = self.model.distributed_predict(data_rdd, sc)
        result_arr_list = result_rdd.collect()
        for i in range(0, len(result_arr_list)):
            result_arr_list[i] = result_arr_list[i][:data_length_list[i]]
        result_arr = np.concatenate(result_arr_list, axis=0)
        return result_arr
    else:
        raise ValueError("Only XShards, Spark DataFrame, a numpy array and a list of numpy "
                         "arrays are supported as input data, but got " +
                         data.__class__.__name__)
def predict(self, data, feature_cols=None, batch_size=4):
    """
    Predict input data.

    :param batch_size: Int. The batch size to use for inference. Default: 4.
    :param data: data to be predicted. XShards, Spark DataFrame, numpy array and list of
           numpy arrays are supported. If data is an XShards, each partition is a dictionary
           of {'x': feature}, where feature is a numpy array or a list of numpy arrays.
    :param feature_cols: Feature column name(s) of data. Only used when data is a Spark
           DataFrame. Default: None.
    :return: predicted result. If the input data is an XShards, the predict result is an
             XShards, and each partition of the XShards is a dictionary of
             {'prediction': result}, where the result is a numpy array or a list of numpy
             arrays. If the input data is a numpy array or a list of numpy arrays, the
             predict result is a numpy array or a list of numpy arrays.
    """
    sc = init_nncontext()

    model_bytes_broadcast = sc.broadcast(self.model_bytes)
    weight_bytes_broadcast = sc.broadcast(self.weight_bytes)

    def partition_inference(partition):
        model_bytes = model_bytes_broadcast.value
        weight_bytes = weight_bytes_broadcast.value
        partition = list(partition)
        data_num = len(partition)
        ie = IECore()
        config = {'CPU_THREADS_NUM': str(self.core_num)}
        ie.set_config(config, 'CPU')
        net = ie.read_network(model=model_bytes, weights=weight_bytes, init_from_buffer=True)
        net.batch_size = batch_size
        local_model = ie.load_network(network=net, device_name="CPU", num_requests=data_num)
        inputs = list(iter(local_model.requests[0].input_blobs))
        outputs = list(iter(local_model.requests[0].output_blobs))
        assert len(outputs) != 0, "The number of model outputs should not be 0."

        def add_elem(d):
            # pad a chunk smaller than batch_size by repeating its last row,
            # and return the real length so outputs can be trimmed afterwards
            d_len = len(d)
            if d_len < batch_size:
                rep_time = [1] * (d_len - 1)
                rep_time.append(batch_size - d_len + 1)
                return np.repeat(d, rep_time, axis=0), d_len
            else:
                return d, d_len

        results = []
        for idx, batch_data in enumerate(partition):
            infer_request = local_model.requests[idx]
            input_dict = dict()
            elem_num = 0
            if isinstance(batch_data, list):
                for i, input in enumerate(inputs):
                    input_dict[input], elem_num = add_elem(batch_data[i])
            else:
                input_dict[inputs[0]], elem_num = add_elem(batch_data)
            infer_request.infer(input_dict)
            if len(outputs) == 1:
                results.append(infer_request.output_blobs[outputs[0]].buffer[:elem_num])
            else:
                results.append(list(map(
                    lambda output: infer_request.output_blobs[output].buffer[:elem_num],
                    outputs)))

        return results

    def predict_transform(dict_data, batch_size):
        assert isinstance(dict_data, dict), "each shard should be a dict"
        assert "x" in dict_data, "key 'x' should be in each shard"
        feature_data = dict_data["x"]
        if isinstance(feature_data, np.ndarray):
            assert feature_data.shape[0] <= batch_size, \
                "The batch size of input data (the first dim) should be less than the model " \
                "batch size, otherwise some inputs will be ignored."
        elif isinstance(feature_data, list):
            for elem in feature_data:
                assert isinstance(elem, np.ndarray), "Each element in the x list should be " \
                                                     "an ndarray, but got " + \
                                                     elem.__class__.__name__
                assert elem.shape[0] <= batch_size, \
                    "The batch size of each input data (the first dim) should be less than " \
                    "the model batch size, otherwise some inputs will be ignored."
        else:
            raise ValueError("x in each shard should be an ndarray or a list of ndarrays.")
        return feature_data

    if isinstance(data, DataFrame):
        from zoo.orca.learn.utils import dataframe_to_xshards, convert_predict_rdd_to_dataframe
        xshards, _ = dataframe_to_xshards(data,
                                          validation_data=None,
                                          feature_cols=feature_cols,
                                          label_cols=None,
                                          mode="predict")
        transformed_data = xshards.transform_shard(predict_transform, batch_size)
        result_rdd = transformed_data.rdd.mapPartitions(lambda iter: partition_inference(iter))
        return convert_predict_rdd_to_dataframe(data, result_rdd.flatMap(lambda data: data))
    elif isinstance(data, SparkXShards):
        transformed_data = data.transform_shard(predict_transform, batch_size)
        result_rdd = transformed_data.rdd.mapPartitions(lambda iter: partition_inference(iter))

        def update_result_shard(data):
            shard, y = data
            shard["prediction"] = y
            return shard

        return SparkXShards(data.rdd.zip(result_rdd).map(update_result_shard))
    elif isinstance(data, (np.ndarray, list)):
        if isinstance(data, np.ndarray):
            split_num = math.ceil(len(data) / batch_size)
            arrays = np.array_split(data, split_num)
            num_slices = min(split_num, self.node_num)
            data_rdd = sc.parallelize(arrays, numSlices=num_slices)
        elif isinstance(data, list):
            flattened = nest.flatten(data)
            data_length = len(flattened[0])
            data_to_be_rdd = []
            split_num = math.ceil(flattened[0].shape[0] / batch_size)
            num_slices = min(split_num, self.node_num)
            for i in range(split_num):
                data_to_be_rdd.append([])
            for x in flattened:
                assert isinstance(x, np.ndarray), "the data in the data list should be " \
                                                  "ndarrays, but got " + \
                                                  x.__class__.__name__
                assert len(x) == data_length, \
                    "the ndarrays in data must all have the same size in first dimension" \
                    ", got first ndarray of size {} and another {}".format(data_length, len(x))
                x_parts = np.array_split(x, split_num)
                for idx, x_part in enumerate(x_parts):
                    data_to_be_rdd[idx].append(x_part)

            data_to_be_rdd = [nest.pack_sequence_as(data, shard) for shard in data_to_be_rdd]
            data_rdd = sc.parallelize(data_to_be_rdd, numSlices=num_slices)

        print("Partition number: ", data_rdd.getNumPartitions())
        result_rdd = data_rdd.mapPartitions(lambda iter: partition_inference(iter))
        result_arr_list = result_rdd.collect()
        result_arr = None
        if isinstance(result_arr_list[0], list):
            result_arr = [np.concatenate([r[i] for r in result_arr_list], axis=0)
                          for i in range(len(result_arr_list[0]))]
        elif isinstance(result_arr_list[0], np.ndarray):
            result_arr = np.concatenate(result_arr_list, axis=0)
        return result_arr
    else:
        raise ValueError("Only XShards, Spark DataFrame, a numpy array and a list of numpy "
                         "arrays are supported as input data, but got " +
                         data.__class__.__name__)
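# A standalone illustration of the `add_elem` padding trick above: when a chunk is smaller
# than the model's fixed batch size, its last row is repeated to fill the batch, and the
# recorded real length is used to cut the outputs back afterwards. Pure numpy, toy sizes;
# `pad_to_batch` is a renamed copy of the logic, not part of the library.
import numpy as np

def pad_to_batch(d, batch_size):
    d_len = len(d)
    if d_len < batch_size:
        rep_time = [1] * (d_len - 1) + [batch_size - d_len + 1]
        return np.repeat(d, rep_time, axis=0), d_len
    return d, d_len

chunk = np.arange(6, dtype=np.float32).reshape(3, 2)        # 3 real rows
padded, real_len = pad_to_batch(chunk, batch_size=4)         # last row repeated once
assert padded.shape[0] == 4 and real_len == 3
outputs = padded                                             # pretend these are model outputs
assert outputs[:real_len].shape[0] == 3                      # trimmed back to the real length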