def test_set_core_num(self):
    _, core_num = get_node_and_core_number()
    set_core_number(core_num + 1)
    _, new_core_num = get_node_and_core_number()
    assert new_core_num == core_num + 1, \
        "set_core_num failed, set the core" \
        " number to be {} but got {}".format(core_num + 1, new_core_num)
    set_core_number(core_num)
def read_parquet(file_path):
    """
    Read parquet files into a SparkXShards of pandas DataFrames.

    :param file_path: Parquet file path, a list of multiple parquet file paths,
           or a directory containing parquet files. Local file system, HDFS,
           and AWS S3 are supported.
    :return: An instance of SparkXShards.
    """
    sc = init_nncontext()
    node_num, core_num = get_node_and_core_number()
    from pyspark.sql import SQLContext
    sqlContext = SQLContext.getOrCreate(sc)
    spark = sqlContext.sparkSession
    df = spark.read.parquet(file_path)
    if df.rdd.getNumPartitions() < node_num:
        df = df.repartition(node_num)

    def to_pandas(columns):
        def f(iter):
            import pandas as pd
            data = list(iter)
            pd_df = pd.DataFrame(data, columns=columns)
            return [pd_df]

        return f

    pd_rdd = df.rdd.mapPartitions(to_pandas(df.columns))
    try:
        data_shards = SparkXShards(pd_rdd)
    except Exception as e:
        print("An error occurred when reading parquet files")
        raise e
    return data_shards
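# A hedged usage sketch for read_parquet above. The path below is hypothetical,
# and it assumes SparkXShards exposes transform_shard() and collect() as in
# Orca's XShards API; a SparkContext/NNContext is expected to already exist.
shards = read_parquet("file:///tmp/example_parquet_dir")

def head_two(pd_df):
    # each element of the returned SparkXShards is a pandas DataFrame
    return pd_df.head(2)

preview = shards.transform_shard(head_two).collect()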
def from_ndarrays(tensors, batch_size=-1, batch_per_thread=-1,
                  hard_code_batch_size=False, val_tensors=None,
                  sequential_order=False, shuffle=True):
    sc = getOrCreateSparkContext()
    node_num, core_num = get_node_and_core_number()
    total_core_num = node_num * core_num

    rdd, tensor_structure = _tensors_to_rdd(tensors, sc, total_core_num)

    val_rdd = None
    if val_tensors is not None:
        val_rdd, _ = _tensors_to_rdd(val_tensors, sc, total_core_num)

    return TFNdarrayDataset(rdd, tensor_structure, batch_size,
                            batch_per_thread, hard_code_batch_size, val_rdd,
                            sequential_order=sequential_order, shuffle=shuffle)
def __init__(self, rdd, tensor_structure, batch_size,
             batch_per_thread, hard_code_batch_size=False, val_rdd=None):
    '''
    TFDataset represents a distributed collection of elements to be fed into a
    TensorFlow graph. It can be created from an RDD in which each record is one
    or more numpy.ndarray of the same nested structure, representing the tensors
    to be fed into the TensorFlow graph on each iteration. TFDataset must be
    used with TFOptimizer or TFPredictor.
    '''
    import tensorflow as tf

    if batch_size > 0 and batch_per_thread > 0:
        raise ValueError("batch_size and batch_per_thread should not be set simultaneously")

    node_num, core_num = get_node_and_core_number()
    self.total_core_num = node_num * core_num
    if batch_size > 0:
        if batch_size % self.total_core_num != 0:
            raise ValueError("batch_size should be a multiple " +
                             "of total core number, but got batch_size: " +
                             "%s where total core number is %s"
                             % (batch_size, self.total_core_num))
    if batch_size <= 0 and batch_per_thread <= 0:
        batch_per_thread = 1
        batch_size = self.total_core_num

    self.batch_size = batch_size
    self.batch_per_thread = batch_per_thread
    self.hard_code_batch_size = hard_code_batch_size
    self.tensor_structure = tensor_structure
    if isinstance(self.tensor_structure, list):
        self.tensor_structure = tuple(tensor_structure)
    self.val_rdd = val_rdd

    from tensorflow.python.util import nest
    if not self.hard_code_batch_size:
        self.output_shapes = nest.pack_sequence_as(
            self.tensor_structure, [[None] + list(t.shape)
                                    if t is not None else None
                                    for t in nest.flatten(self.tensor_structure)])
    else:
        if self.batch_per_thread > 0:
            self.output_shapes = nest.pack_sequence_as(
                self.tensor_structure, [[self.batch_per_thread] + t.shape
                                        if t is not None else None
                                        for t in nest.flatten(self.tensor_structure)])
        else:
            self.output_shapes = nest.pack_sequence_as(
                self.tensor_structure, [[self.batch_size // self.total_core_num] + t.shape
                                        if t is not None else None
                                        for t in nest.flatten(self.tensor_structure)])

    self.rdd = rdd
    self.input_names = nest.pack_sequence_as(
        self.tensor_structure, [t.name if t is not None else None
                                for t in nest.flatten(self.tensor_structure)])

    self._tensors = None
def __init__(self, hosts=None, processes_per_node=1, env=None):
    driver_ip = get_node_ip()
    if hosts is None:  # Single node
        self.hosts = [driver_ip]
    elif hosts == "all":  # All executor nodes in the cluster
        def get_ip(iter):
            yield get_node_ip()

        from bigdl.util.common import get_node_and_core_number
        from zoo.orca import OrcaContext
        sc = OrcaContext.get_spark_context()
        node_num, core_num = get_node_and_core_number()
        total_cores = node_num * core_num
        self.hosts = list(set(sc.range(0, total_cores, numSlices=total_cores).barrier()
                              .mapPartitions(get_ip).collect()))
    else:  # User specified hosts, assumed to be non-duplicate
        assert isinstance(hosts, list)
        self.hosts = hosts
    self.master = self.hosts[0]
    print("Master: ", self.master)
    self.remote_hosts = []
    for host in self.hosts:
        if host != driver_ip:
            self.remote_hosts.append(host)
    print("Remote hosts: ", self.remote_hosts)
    print("Hosts: ", self.hosts)
    self.processes_per_node = processes_per_node
    self.env = env if env else {}
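# Hedged illustration of the three `hosts` modes handled by the constructor above.
# "MPIRunner" is only a placeholder name for whichever class owns this __init__;
# the IPs and env values are made up.
runner_single = MPIRunner()              # hosts=None: driver node only
runner_all = MPIRunner(hosts="all")      # discover every executor via a barrier job
runner_fixed = MPIRunner(hosts=["10.0.0.1", "10.0.0.2"],
                         processes_per_node=4,
                         env={"OMP_NUM_THREADS": "4"})
# In every case hosts[0] becomes the master, and the driver's own IP is excluded
# from remote_hosts.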
def input_fn(mode):
    if mode == tf.estimator.ModeKeys.EVAL or mode == tf.estimator.ModeKeys.TRAIN:
        return TFDataset.from_rdd(rdd,
                                  features=features_dict,
                                  labels=labels,
                                  batch_size=batch_size)
    else:
        node_num, core_num = get_node_and_core_number()
        return TFDataset.from_rdd(rdd,
                                  features=features_dict,
                                  batch_per_thread=batch_size // (node_num * core_num))
def __init__(self, rdd, names, shapes, types,
             batch_size=-1, batch_per_core=-1, hard_code_batch_size=False):
    if batch_size > 0 and batch_per_core > 0:
        raise ValueError(
            "batch_size and batch_per_core should not be set simultaneously"
        )

    node_num, core_num = get_node_and_core_number()
    self.total_core_num = node_num * core_num
    if batch_size > 0:
        if batch_size % self.total_core_num != 0:
            raise ValueError("batch_size should be a multiple "
                             "of the total core number, but got batch_size: "
                             "%s where total core number is %s"
                             % (batch_size, self.total_core_num))
    if batch_size <= 0 and batch_per_core <= 0:
        batch_per_core = 1
        batch_size = self.total_core_num

    self.batch_size = batch_size
    self.batch_per_core = batch_per_core

    if not hard_code_batch_size:
        self.tensors = [
            tf.placeholder(name=names[i], dtype=types[i],
                           shape=[None] + shapes[i])
            for i in range(len(names))
        ]
    else:
        if batch_per_core > 0:
            self.tensors = [
                tf.placeholder(name=names[i], dtype=types[i],
                               shape=[batch_per_core] + shapes[i])
                for i in range(len(names))
            ]
        else:
            self.tensors = [
                tf.placeholder(name=names[i], dtype=types[i],
                               shape=[batch_size // self.total_core_num] + shapes[i])
                for i in range(len(names))
            ]

    self.rdd = rdd.map(lambda arr: arr[:len(names)])
    self.input_names = names
    for i in range(len(self.tensors)):
        tf.add_to_collection(self.tensors[i].name, self)
def __init__(self,
             learningrate=1e-3,
             learningrate_decay=0.0,
             beta1=0.9,
             beta2=0.999,
             epsilon=1e-8,
             parallel_num=-1,
             bigdl_type="float"):
    if parallel_num == -1:
        parallel_num = get_node_and_core_number()[1]
    super(ParallelAdam, self).__init__(None, bigdl_type, learningrate, learningrate_decay,
                                       beta1, beta2, epsilon, parallel_num)
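# Usage sketch: leaving parallel_num at -1 makes ParallelAdam fall back to the
# per-node core count reported by get_node_and_core_number(), as in the branch above.
# The explicit form below is equivalent on a homogeneous setup.
optimizer = ParallelAdam(learningrate=1e-3)

_, core_num = get_node_and_core_number()
optimizer_explicit = ParallelAdam(learningrate=1e-3, parallel_num=core_num)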
def write(path, generator, schema, block_size=1000, write_mode="overwrite", **kwargs):
    """
    Take each record in the generator and write it to a parquet file.

    **generator**
    Each record in the generator is a dict whose key is a string that will become
    the column name of the saved parquet record and whose value is the data.

    **schema**
    schema defines the name, dtype and shape of a column, as well as the feature
    type of a column. The feature type defines how to encode and decode the
    column value. There are three kinds of feature type:
    1. Scalar, such as an int or float number, or a string, which can be directly
       mapped to a parquet type.
    2. NDarray, which takes a np.ndarray and saves it as serialized bytes. The
       corresponding parquet type is BYTE_ARRAY.
    3. Image, which takes a string representing an image file in the local file
       system and saves the raw file content bytes. The corresponding parquet
       type is BYTE_ARRAY.

    :param path: the output path, e.g. file:///output/path, hdfs:///output/path
    :param generator: generate a dict, whose key is a string and value is one of
           (a scalar value, ndarray, image file path)
    :param schema: a dict, whose key is a string, value is one of
           (schema_field.Scalar, schema_field.NDarray, schema_field.Image)
    :param kwargs: other args
    """
    sc = init_nncontext()
    spark = SparkSession(sc)
    node_num, core_num = get_node_and_core_number()
    for i, chunk in enumerate(chunks(generator, block_size)):
        chunk_path = os.path.join(path, f"chunk={i}")
        rows_rdd = sc.parallelize(chunk, core_num * node_num)\
            .map(lambda x: dict_to_row(schema, x))
        spark.createDataFrame(rows_rdd).write.mode(write_mode).parquet(chunk_path)
    metadata_path = os.path.join(path, "_orca_metadata")
    write_text(metadata_path, encode_schema(schema))
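# Hedged usage sketch for write() above, following the docstring's schema_field
# naming. The schema_field constructors, file paths and values are illustrative
# assumptions rather than a confirmed API; the generator yields one dict per
# record with keys matching the schema keys.
import numpy as np

schema = {
    "image": schema_field.Image(),        # raw image file bytes -> BYTE_ARRAY
    "label": schema_field.Scalar(),       # plain int/float/str column
    "embedding": schema_field.NDarray(),  # np.ndarray serialized to BYTE_ARRAY
}

def record_generator():
    for i in range(10):
        yield {"image": "/tmp/images/%d.jpg" % i,
               "label": i % 2,
               "embedding": np.random.rand(16).astype(np.float32)}

write("file:///tmp/parquet_out", record_generator(), schema, block_size=1000)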
def read_file_spark(context, file_path, file_type, **kwargs):
    file_url_splits = file_path.split("://")
    prefix = file_url_splits[0]
    node_num, core_num = get_node_and_core_number()

    file_paths = []
    if isinstance(file_path, list):
        [
            file_paths.extend(extract_one_path(path, file_type, os.environ))
            for path in file_path
        ]
    else:
        file_paths = extract_one_path(file_path, file_type, os.environ)

    if not file_paths:
        raise Exception(
            "The file path is invalid/empty or does not include csv/json files"
        )

    rdd = context.parallelize(file_paths, node_num * core_num)

    if prefix == "hdfs":
        pd_rdd = rdd.mapPartitions(
            lambda iter: read_pd_hdfs_file_list(iter, file_type, **kwargs))
    elif prefix == "s3":
        pd_rdd = rdd.mapPartitions(
            lambda iter: read_pd_s3_file_list(iter, file_type, **kwargs))
    else:

        def loadFile(iterator):
            for x in iterator:
                df = read_pd_file(x, file_type, **kwargs)
                yield df

        pd_rdd = rdd.mapPartitions(loadFile)

    data_shards = SparkXShards(pd_rdd)
    return data_shards
def from_ndarrays(tensors, batch_size=-1, batch_per_thread=-1,
                  hard_code_batch_size=False, val_tensors=None):
    '''
    Create a TFDataset from a nested structure of numpy ndarrays. Each element
    in the resulting TFDataset has the same structure as the argument tensors and
    is created by indexing on the first dimension of each ndarray in the tensors
    argument.

    This method is equivalent to calling sc.parallelize on the tensors and then
    TFDataset.from_rdd.

    :return:
    '''
    sc = getOrCreateSparkContext()
    node_num, core_num = get_node_and_core_number()
    total_core_num = node_num * core_num

    rdd, tensor_structure = _tensors_to_rdd(tensors, sc, total_core_num)

    val_rdd = None
    if val_tensors is not None:
        val_rdd, _ = _tensors_to_rdd(val_tensors, sc, total_core_num)

    return TFNdarrayDataset(rdd, tensor_structure, batch_size,
                            batch_per_thread, hard_code_batch_size, val_rdd)
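# Minimal sketch of from_ndarrays, assuming it is exposed as a staticmethod of
# TFDataset; the arrays and batch size below are illustrative. The nested
# (features, labels) structure is preserved element-wise, and batch_size must be
# a multiple of node_num * core_num.
import numpy as np

features = np.random.rand(1024, 28, 28, 1).astype(np.float32)
labels = np.random.randint(0, 10, size=(1024,))

train_dataset = TFDataset.from_ndarrays((features, labels), batch_size=64,
                                        val_tensors=(features, labels))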
def read_file_spark(file_path, file_type, **kwargs):
    sc = init_nncontext()
    node_num, core_num = get_node_and_core_number()
    backend = OrcaContext.pandas_read_backend

    if backend == "pandas":
        file_url_splits = file_path.split("://")
        prefix = file_url_splits[0]

        file_paths = []
        if isinstance(file_path, list):
            [
                file_paths.extend(extract_one_path(path, os.environ))
                for path in file_path
            ]
        else:
            file_paths = extract_one_path(file_path, os.environ)

        if not file_paths:
            raise Exception(
                "The file path is invalid or empty, please check your data")

        num_files = len(file_paths)
        total_cores = node_num * core_num
        num_partitions = num_files if num_files < total_cores else total_cores
        rdd = sc.parallelize(file_paths, num_partitions)

        if prefix == "hdfs":
            pd_rdd = rdd.mapPartitions(
                lambda iter: read_pd_hdfs_file_list(iter, file_type, **kwargs))
        elif prefix == "s3":
            pd_rdd = rdd.mapPartitions(
                lambda iter: read_pd_s3_file_list(iter, file_type, **kwargs))
        else:

            def loadFile(iterator):
                for x in iterator:
                    df = read_pd_file(x, file_type, **kwargs)
                    yield df

            pd_rdd = rdd.mapPartitions(loadFile)
    else:  # Spark backend; spark.read.csv/json accepts a folder path as input
        assert file_type == "json" or file_type == "csv", \
            "Unsupported file type: %s. Only csv and json files are supported for now" % file_type
        from pyspark.sql import SQLContext
        sqlContext = SQLContext.getOrCreate(sc)
        spark = sqlContext.sparkSession
        # TODO: add S3 credentials

        # The following implementation is adapted from
        # https://github.com/databricks/koalas/blob/master/databricks/koalas/namespace.py
        # with some modifications.
        if "mangle_dupe_cols" in kwargs:
            assert kwargs["mangle_dupe_cols"], "mangle_dupe_cols can only be True"
            kwargs.pop("mangle_dupe_cols")
        if "parse_dates" in kwargs:
            assert not kwargs["parse_dates"], "parse_dates can only be False"
            kwargs.pop("parse_dates")

        names = kwargs.get("names", None)
        if "names" in kwargs:
            kwargs.pop("names")
        usecols = kwargs.get("usecols", None)
        if "usecols" in kwargs:
            kwargs.pop("usecols")
        dtype = kwargs.get("dtype", None)
        if "dtype" in kwargs:
            kwargs.pop("dtype")
        squeeze = kwargs.get("squeeze", False)
        if "squeeze" in kwargs:
            kwargs.pop("squeeze")
        index_col = kwargs.get("index_col", None)
        if "index_col" in kwargs:
            kwargs.pop("index_col")

        if file_type == "csv":
            # Handle pandas-compatible keyword arguments
            kwargs["inferSchema"] = True
            header = kwargs.get("header", "infer")
            if isinstance(names, str):
                kwargs["schema"] = names
            if header == "infer":
                header = 0 if names is None else None
            if header == 0:
                kwargs["header"] = True
            elif header is None:
                kwargs["header"] = False
            else:
                raise ValueError("Unknown header argument {}".format(header))
            if "quotechar" in kwargs:
                quotechar = kwargs["quotechar"]
                kwargs.pop("quotechar")
                kwargs["quote"] = quotechar
            if "escapechar" in kwargs:
                escapechar = kwargs["escapechar"]
                kwargs.pop("escapechar")
                kwargs["escape"] = escapechar
            # sep and comment are the same as pandas
            if "comment" in kwargs:
                comment = kwargs["comment"]
                if not isinstance(comment, str) or len(comment) != 1:
                    raise ValueError(
                        "Only length-1 comment characters supported")
            df = spark.read.csv(file_path, **kwargs)
            if header is None:
                df = df.selectExpr(*[
                    "`%s` as `%s`" % (field.name, i)
                    for i, field in enumerate(df.schema)
                ])
        else:
            df = spark.read.json(file_path, **kwargs)

        # Handle pandas-compatible postprocessing arguments
        if usecols is not None and not callable(usecols):
            usecols = list(usecols)

        renamed = False
        if isinstance(names, list):
            if len(set(names)) != len(names):
                raise ValueError(
                    "Found duplicate names, please check your names input")
            if usecols is not None:
                if not callable(usecols):
                    # usecols is list
                    if len(names) != len(usecols) and len(names) != len(df.schema):
                        raise ValueError("Passed names did not match usecols")
                if len(names) == len(df.schema):
                    df = df.selectExpr(*[
                        "`%s` as `%s`" % (field.name, name)
                        for field, name in zip(df.schema, names)
                    ])
                    renamed = True
            else:
                if len(names) != len(df.schema):
                    raise ValueError(
                        "The number of names [%s] does not match the number "
                        "of columns [%d]. Try names by a Spark SQL DDL-formatted "
                        "string." % (len(names), len(df.schema)))
                df = df.selectExpr(*[
                    "`%s` as `%s`" % (field.name, name)
                    for field, name in zip(df.schema, names)
                ])
                renamed = True

        index_map = dict([(i, field.name) for i, field in enumerate(df.schema)])
        if usecols is not None:
            if callable(usecols):
                cols = [field.name for field in df.schema if usecols(field.name)]
                missing = []
            elif all(isinstance(col, int) for col in usecols):
                cols = [
                    field.name for i, field in enumerate(df.schema) if i in usecols
                ]
                missing = [
                    col for col in usecols
                    if col >= len(df.schema) or df.schema[col].name not in cols
                ]
            elif all(isinstance(col, str) for col in usecols):
                cols = [field.name for field in df.schema if field.name in usecols]
                if isinstance(names, list):
                    missing = [c for c in usecols if c not in names]
                else:
                    missing = [col for col in usecols if col not in cols]
            else:
                raise ValueError(
                    "usecols must only be list-like of all strings, "
                    "all unicode, all integers or a callable.")
            if len(missing) > 0:
                raise ValueError(
                    "usecols do not match columns, columns expected but not found: %s"
                    % missing)
            if len(cols) > 0:
                df = df.select(cols)
                if isinstance(names, list):
                    if not renamed:
                        df = df.selectExpr(*[
                            "`%s` as `%s`" % (col, name)
                            for col, name in zip(cols, names)
                        ])
                        # update index map after rename
                        for index, col in index_map.items():
                            if col in cols:
                                index_map[index] = names[cols.index(col)]

        if df.rdd.getNumPartitions() < node_num:
            df = df.repartition(node_num)

        def to_pandas(columns, squeeze=False, index_col=None):
            def f(iter):
                import pandas as pd
                data = list(iter)
                pd_df = pd.DataFrame(data, columns=columns)
                if dtype is not None:
                    if isinstance(dtype, dict):
                        for col, type in dtype.items():
                            if isinstance(col, str):
                                if col not in pd_df.columns:
                                    raise ValueError(
                                        "column to set the type for is not"
                                        " in the current dataframe")
                                pd_df[col] = pd_df[col].astype(type)
                            elif isinstance(col, int):
                                if index_map[col] not in pd_df.columns:
                                    raise ValueError(
                                        "column index to set the type for is not"
                                        " in the current dataframe")
                                pd_df[index_map[col]] = pd_df[index_map[col]].astype(type)
                    else:
                        pd_df = pd_df.astype(dtype)
                if squeeze and len(pd_df.columns) == 1:
                    pd_df = pd_df.iloc[:, 0]
                if index_col:
                    pd_df = pd_df.set_index(index_col)
                return [pd_df]

            return f

        pd_rdd = df.rdd.mapPartitions(to_pandas(df.columns, squeeze, index_col))

    try:
        data_shards = SparkXShards(pd_rdd)
    except Exception as e:
        alternative_backend = "pandas" if backend == "spark" else "spark"
        print(
            "An error occurred when reading files with '%s' backend, you may switch to '%s' "
            "backend for another try. You can set the backend using "
            "OrcaContext.pandas_read_backend" % (backend, alternative_backend))
        raise e
    return data_shards
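# Hedged sketch of how pandas-style keyword arguments flow through the spark
# backend above. The HDFS path is hypothetical, and it assumes
# OrcaContext.pandas_read_backend is switchable as used in the code above.
OrcaContext.pandas_read_backend = "spark"
shards = read_file_spark("hdfs:///data/ratings",
                         "csv",
                         names=["user", "item", "rating"],   # renamed via selectExpr
                         usecols=["user", "rating"],         # column pruning after load
                         dtype={"rating": "float32"},        # applied per pandas partition
                         index_col="user")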
if __name__ == "__main__":
    parser = OptionParser()
    parser.add_option("-f", type=str, dest="file_path",
                      help="The file path to be read")
    (options, args) = parser.parse_args(sys.argv)

    # Prepare csv files
    df = pd.read_csv(options.file_path)
    sc = init_spark_on_local(cores="*")
    sqlContext = SQLContext(sc)
    num_nodes, num_cores = get_node_and_core_number()
    df_spark = sqlContext.createDataFrame(df)
    df_spark.printSchema()
    df_spark.repartition(num_cores).write.\
        format('json').mode("overwrite").save("/tmp/ray-pandas-example")

    # init ray context
    ray_ctx = RayContext(sc=sc, object_store_memory="5g")
    ray_ctx.init(object_store_memory="5g")

    # read data
    data_shard = zoo.xshard.pandas.read_json("/tmp/ray-pandas-example", ray_ctx)

    # collect data
    data = data_shard.collect()
def __init__(self, tensor_structure, batch_size,
             batch_per_thread, hard_code_batch_size=False):
    '''
    TFDataset represents a distributed collection of elements (backed by an RDD)
    to be fed into a TensorFlow graph.

    :param tensor_structure: a nested structure of TensorMeta objects specifying
           the name, shape and data type of each element in this TFDataset
    :param batch_size: the batch size, used for training; should be a multiple of
           the total core number
    :param batch_per_thread: the batch size for each thread, used for inference
           or evaluation
    :param hard_code_batch_size: whether to hard code the batch_size into the
           tensorflow graph; if True, the static size of the first dimension of
           the resulting tensors is batch_size / total_core_num (training) or
           batch_per_thread (inference); if False, it is None.
    '''
    if batch_size > 0 and batch_per_thread > 0:
        raise ValueError("batch_size and batch_per_thread should not be set simultaneously")
    self.has_batch = True
    node_num, core_num = get_node_and_core_number()
    self.total_core_num = node_num * core_num
    if batch_size > 0:
        if batch_size % self.total_core_num != 0:
            raise ValueError("batch_size should be a multiple " +
                             "of total core number, but got batch_size: " +
                             "%s where total core number is %s"
                             % (batch_size, self.total_core_num))
    if batch_size <= 0 and batch_per_thread <= 0:
        batch_per_thread = 1
        batch_size = self.total_core_num
        self.has_batch = False

    self.batch_size = batch_size
    self.batch_per_thread = batch_per_thread
    self.hard_code_batch_size = hard_code_batch_size
    self.tensor_structure = tensor_structure

    if not self.hard_code_batch_size:
        self.output_shapes = nest.pack_sequence_as(
            self.tensor_structure, [[None] + list(t.shape)
                                    if t is not None else None
                                    for t in nest.flatten(self.tensor_structure)])
    else:
        if self.batch_per_thread > 0:
            self.output_shapes = nest.pack_sequence_as(
                self.tensor_structure, [[self.batch_per_thread] + t.shape
                                        if t is not None else None
                                        for t in nest.flatten(self.tensor_structure)])
        else:
            self.output_shapes = nest.pack_sequence_as(
                self.tensor_structure, [[self.batch_size // self.total_core_num] + t.shape
                                        if t is not None else None
                                        for t in nest.flatten(self.tensor_structure)])

    self.input_names = nest.pack_sequence_as(
        self.tensor_structure, [t.name if t is not None else None
                                for t in nest.flatten(self.tensor_structure)])

    self._tensors = None
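# Illustration of the batching rule enforced above, with made-up cluster numbers:
# suppose get_node_and_core_number() reports 2 nodes with 4 cores each.
node_num, core_num = 2, 4
total_core_num = node_num * core_num   # 8
# Training: batch_size must be a multiple of 8, so 64 is accepted and 60 raises
# ValueError. With hard_code_batch_size=True the static first dimension becomes
# batch_size // total_core_num = 64 // 8 = 8 per core, or batch_per_thread for
# inference; otherwise it stays None (dynamic).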
def read_file_spark(file_path, file_type, **kwargs):
    sc = init_nncontext()
    file_url_splits = file_path.split("://")
    prefix = file_url_splits[0]
    node_num, core_num = get_node_and_core_number()

    file_paths = []
    if isinstance(file_path, list):
        [
            file_paths.extend(extract_one_path(path, file_type, os.environ))
            for path in file_path
        ]
    else:
        file_paths = extract_one_path(file_path, file_type, os.environ)

    if not file_paths:
        raise Exception(
            "The file path is invalid/empty or does not include csv/json files"
        )

    if ZooContext.orca_pandas_read_backend == "pandas":
        num_files = len(file_paths)
        total_cores = node_num * core_num
        num_partitions = num_files if num_files < total_cores else total_cores
        rdd = sc.parallelize(file_paths, num_partitions)

        if prefix == "hdfs":
            pd_rdd = rdd.mapPartitions(
                lambda iter: read_pd_hdfs_file_list(iter, file_type, **kwargs))
        elif prefix == "s3":
            pd_rdd = rdd.mapPartitions(
                lambda iter: read_pd_s3_file_list(iter, file_type, **kwargs))
        else:

            def loadFile(iterator):
                for x in iterator:
                    df = read_pd_file(x, file_type, **kwargs)
                    yield df

            pd_rdd = rdd.mapPartitions(loadFile)
    else:
        from pyspark.sql import SQLContext
        sqlContext = SQLContext.getOrCreate(sc)
        spark = sqlContext.sparkSession
        # TODO: add S3 credentials
        if file_type == "json":
            df = spark.read.json(file_paths, **kwargs)
        elif file_type == "csv":
            df = spark.read.csv(file_paths, **kwargs)
        else:
            raise Exception("Unsupported file type")
        if df.rdd.getNumPartitions() < node_num:
            df = df.repartition(node_num)

        def to_pandas(columns):
            def f(iter):
                import pandas as pd
                data = list(iter)
                yield pd.DataFrame(data, columns=columns)

            return f

        pd_rdd = df.rdd.mapPartitions(to_pandas(df.columns))

    data_shards = SparkXShards(pd_rdd)
    return data_shards
def read_file_spark(context, file_path, file_type, **kwargs):
    file_url_splits = file_path.split("://")
    prefix = file_url_splits[0]
    node_num, core_num = get_node_and_core_number()

    if prefix == "s3":
        data_paths = list_s3_file(file_url_splits[1], file_type, os.environ)
    else:
        data_paths = get_file_list(file_path)

    rdd = context.parallelize(data_paths, node_num * core_num)

    if prefix == "hdfs":
        def loadFile(iterator):
            import pandas as pd
            import pyarrow as pa
            fs = pa.hdfs.connect()

            for x in iterator:
                with fs.open(x, 'rb') as f:
                    if file_type == "csv":
                        df = pd.read_csv(f, **kwargs)
                    elif file_type == "json":
                        df = pd.read_json(f, **kwargs)
                    else:
                        raise Exception("Unsupported file type")
                    yield df

        pd_rdd = rdd.mapPartitions(loadFile)
    elif prefix == "s3":
        def loadFile(iterator):
            access_key_id = os.environ["AWS_ACCESS_KEY_ID"]
            secret_access_key = os.environ["AWS_SECRET_ACCESS_KEY"]

            import boto3
            import pandas as pd
            s3_client = boto3.Session(
                aws_access_key_id=access_key_id,
                aws_secret_access_key=secret_access_key,
            ).client('s3', verify=False)

            for x in iterator:
                path_parts = x.split("://")[1].split('/')
                bucket = path_parts.pop(0)
                key = "/".join(path_parts)
                obj = s3_client.get_object(Bucket=bucket, Key=key)
                if file_type == "json":
                    df = pd.read_json(obj['Body'], **kwargs)
                elif file_type == "csv":
                    df = pd.read_csv(obj['Body'], **kwargs)
                else:
                    raise Exception("Unsupported file type")
                yield df

        pd_rdd = rdd.mapPartitions(loadFile)
    else:
        def loadFile(iterator):
            import pandas as pd

            for x in iterator:
                if file_type == "csv":
                    df = pd.read_csv(x, **kwargs)
                elif file_type == "json":
                    df = pd.read_json(x, **kwargs)
                else:
                    raise Exception("Unsupported file type")
                yield df

        pd_rdd = rdd.mapPartitions(loadFile)

    data_shards = SparkDataShards(pd_rdd)
    return data_shards
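# Usage sketch for the S3 branch above: boto3 reads credentials from these two
# environment variables inside each partition. The bucket and SparkContext
# variable (sc) are placeholders; setting the variables on the driver is enough
# in local mode, while on a cluster they must also reach the executors
# (e.g. via spark.executorEnv.* configuration).
import os

os.environ["AWS_ACCESS_KEY_ID"] = "..."        # placeholder, do not hard-code real keys
os.environ["AWS_SECRET_ACCESS_KEY"] = "..."    # placeholder

shards = read_file_spark(sc, "s3://my-bucket/ratings/", "csv", header=0)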
def __init__(self, rdd, names, shapes, types, batch_size,
             batch_per_thread, hard_code_batch_size=False, val_rdd=None):
    '''
    TFDataset represents a distributed collection of elements to be fed into a
    TensorFlow graph. It can be created from an RDD, each record of which is a
    list of numpy.ndarray representing the tensors to be fed into the TensorFlow
    graph on each iteration. TFDataset must be used with TFOptimizer or TFPredictor.

    :param rdd: an RDD of lists of numpy.ndarray, each representing a tensor to
           feed into the tensorflow graph on each iteration
    :param names: the names of the resulting tensors, should be a list of str
    :param shapes: the shapes of the resulting tensors, should be a list of list of int
    :param types: the types of the resulting tensors, should be a list of tf.dtype
    :param batch_size: the batch size, used for training; should be a multiple of
           the total core number
    :param batch_per_thread: the batch size for each thread, used for inference
    :param hard_code_batch_size: whether to hard code the batch_size into the
           tensorflow graph; if True, the static size of the first dimension of
           the resulting tensors is batch_size / total_core_num (training) or
           batch_per_thread (inference); if False, it is None.
    '''
    import tensorflow as tf

    if batch_size > 0 and batch_per_thread > 0:
        raise ValueError(
            "batch_size and batch_per_thread should not be set simultaneously"
        )

    node_num, core_num = get_node_and_core_number()
    self.total_core_num = node_num * core_num
    if batch_size > 0:
        if batch_size % self.total_core_num != 0:
            raise ValueError("batch_size should be a multiple " +
                             "of total core number, but got batch_size: " +
                             "%s where total core number is %s"
                             % (batch_size, self.total_core_num))
    if batch_size <= 0 and batch_per_thread <= 0:
        batch_per_thread = 1
        batch_size = self.total_core_num

    self.batch_size = batch_size
    self.batch_per_thread = batch_per_thread

    if not hard_code_batch_size:
        self.tensors = [
            tf.placeholder(name=names[i], dtype=types[i],
                           shape=[None] + shapes[i])
            for i in range(len(names))
        ]
    else:
        if batch_per_thread > 0:
            self.tensors = [
                tf.placeholder(name=names[i], dtype=types[i],
                               shape=[batch_per_thread] + shapes[i])
                for i in range(len(names))
            ]
        else:
            self.tensors = [
                tf.placeholder(name=names[i], dtype=types[i],
                               shape=[batch_size // self.total_core_num] + shapes[i])
                for i in range(len(names))
            ]

    self.val_rdd = val_rdd
    self.rdd = rdd.map(lambda arr: arr[:len(names)])
    self.input_names = names
    for i in range(len(self.tensors)):
        tf.add_to_collection(self.tensors[i].name, self)