def test_dataframe_with_empty_partition(self):
    from bigdl.orca import OrcaContext
    sc = OrcaContext.get_spark_context()

    rdd = sc.range(0, 10)
    rdd_with_empty = rdd.repartition(4).\
        mapPartitionsWithIndex(lambda idx, part: [] if idx == 0 else part)

    from pyspark.sql import SparkSession
    spark = SparkSession(sc)
    from pyspark.ml.linalg import DenseVector
    df = rdd_with_empty.map(lambda x: (DenseVector(np.random.randn(1,).astype(np.float)),
                                       int(np.random.randint(0, 1, size=()))))\
        .toDF(["feature", "label"])

    config = {"lr": 0.8}
    trainer = Estimator.from_keras(
        model_creator=model_creator,
        verbose=True,
        config=config,
        workers_per_node=2)

    trainer.fit(df, epochs=1, batch_size=4, steps_per_epoch=25,
                feature_cols=["feature"],
                label_cols=["label"])
    trainer.evaluate(df, batch_size=4, num_steps=25,
                     feature_cols=["feature"],
                     label_cols=["label"])
    trainer.predict(df, feature_cols=["feature"]).collect()
def __init__(self, hosts=None, processes_per_node=1, env=None):
    driver_ip = get_node_ip()
    if hosts is None:  # Single node
        self.hosts = [driver_ip]
    elif hosts == "all":  # All executor nodes in the cluster
        def get_ip(iter):
            yield get_node_ip()

        from bigdl.dllib.utils.common import get_node_and_core_number
        from bigdl.orca import OrcaContext
        sc = OrcaContext.get_spark_context()
        node_num, core_num = get_node_and_core_number()
        total_cores = node_num * core_num
        self.hosts = list(set(
            sc.range(0, total_cores, numSlices=total_cores).barrier()
              .mapPartitions(get_ip).collect()))
    else:  # User specified hosts, assumed to be non-duplicate
        assert isinstance(hosts, list)
        self.hosts = hosts
    self.master = self.hosts[0]
    print("Master: ", self.master)
    self.remote_hosts = []
    for host in self.hosts:
        if host != driver_ip:
            self.remote_hosts.append(host)
    print("Remote hosts: ", self.remote_hosts)
    print("Hosts: ", self.hosts)
    self.processes_per_node = processes_per_node
    self.env = env if env else {}
def test_dataframe(self):
    sc = OrcaContext.get_spark_context()
    rdd = sc.range(0, 100)
    spark = OrcaContext.get_spark_session()

    from pyspark.ml.linalg import DenseVector
    df = rdd.map(lambda x: (DenseVector(np.random.randn(1,).astype(np.float)),
                            int(np.random.randint(0, 2, size=())))).toDF(["feature", "label"])

    config = {"lr": 0.2}
    try:
        temp_dir = tempfile.mkdtemp()
        trainer = Estimator.from_keras(
            model_creator=model_creator,
            verbose=True,
            config=config,
            workers_per_node=2,
            backend="spark",
            model_dir=temp_dir)

        res = trainer.fit(df, epochs=5, batch_size=4, steps_per_epoch=25,
                          feature_cols=["feature"],
                          label_cols=["label"],
                          validation_data=df, validation_steps=1)

        print("start saving")
        trainer.save_weights(os.path.join(temp_dir, "cifar10_keras.h5"))
        trainer.load_weights(os.path.join(temp_dir, "cifar10_keras.h5"))
        trainer.save(os.path.join(temp_dir, "a.model"))
        trainer.load(os.path.join(temp_dir, "a.model"))

        res = trainer.evaluate(df, batch_size=4, num_steps=25,
                               feature_cols=["feature"],
                               label_cols=["label"])
        print("validation result: ", res)

        res = trainer.predict(df, feature_cols=["feature"]).collect()
    finally:
        shutil.rmtree(temp_dir)
def test_string_input(self):

    def model_creator(config):
        import tensorflow as tf
        vectorize_layer = tf.keras.layers.experimental.preprocessing.TextVectorization(
            max_tokens=10, output_mode='int', output_sequence_length=4)
        model = tf.keras.models.Sequential()
        model.add(tf.keras.Input(shape=(1,), dtype=tf.string))
        model.add(vectorize_layer)
        return model

    from bigdl.orca import OrcaContext
    from pyspark.sql.types import StructType, StructField, StringType
    spark = OrcaContext.get_spark_session()
    schema = StructType([StructField("input", StringType(), True)])
    input_data = [["foo qux bar"], ["qux baz"]]
    input_df = spark.createDataFrame(input_data, schema)
    estimator = Estimator.from_keras(model_creator=model_creator)
    output_df = estimator.predict(input_df, batch_size=1, feature_cols=["input"])
    output = output_df.collect()
    print(output)
def read_parquet(file_path, columns=None, schema=None, **options):
    """
    Read parquet files into SparkXShards of pandas DataFrames.

    :param file_path: Parquet file path, a list of multiple parquet file paths, or a directory
           containing parquet files. Local file system, HDFS, and AWS S3 are supported.
    :param columns: list of column names, default=None.
           If not None, only these columns will be read from the file.
    :param schema: pyspark.sql.types.StructType for the input schema or a DDL-formatted string
           (for example, col0 INT, col1 DOUBLE).
    :param options: other options for reading parquet.
    :return: An instance of SparkXShards.
    """
    sc = init_nncontext()
    spark = OrcaContext.get_spark_session()
    # df = spark.read.parquet(file_path)
    df = spark.read.load(file_path, "parquet", schema=schema, **options)

    if columns:
        df = df.select(*columns)

    def to_pandas(columns):
        def f(iter):
            import pandas as pd
            data = list(iter)
            pd_df = pd.DataFrame(data, columns=columns)
            return [pd_df]

        return f

    pd_rdd = df.rdd.mapPartitions(to_pandas(df.columns))
    try:
        data_shards = SparkXShards(pd_rdd)
    except Exception as e:
        print("An error occurred when reading parquet files")
        raise e
    return data_shards
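# A minimal usage sketch for read_parquet. The path below is hypothetical, and it assumes an
# Orca/Spark context has already been created via init_orca_context; read_parquet is exposed
# as bigdl.orca.data.pandas.read_parquet (as used in the test further down):
#
#     import bigdl.orca.data.pandas
#     shards = bigdl.orca.data.pandas.read_parquet("/path/to/parquet_dir",
#                                                  columns=["ID", "sale_price"])
#     pdf = shards.collect()[0]  # each partition yields one pandas DataFrame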
def read_df(esConfig, esResource, schema=None):
    """
    Read data from Elasticsearch into a Spark DataFrame.

    :param esConfig: Dictionary which represents the configuration for
           Elasticsearch (e.g. ip, port, etc.).
    :param esResource: resource file in Elasticsearch.
    :param schema: Optional. Defines the schema of the Spark DataFrame.
           If each column in ES holds a single value, the schema does not need to be set.
    :return: Spark DataFrame. Each row represents a document in ES.
    """
    sc = init_nncontext()
    spark = OrcaContext.get_spark_session()
    reader = spark.read.format("org.elasticsearch.spark.sql")
    for key in esConfig:
        reader.option(key, esConfig[key])
    if schema:
        reader.schema(schema)
    df = reader.load(esResource)
    return df
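# A minimal usage sketch for read_df. The host, port, and index name are hypothetical; the
# option keys follow the elasticsearch-spark connector ("es.nodes", "es.port") conventions:
#
#     es_config = {"es.nodes": "localhost", "es.port": "9200"}
#     df = read_df(es_config, "my_index")  # Spark DataFrame, one row per ES document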
def test_read_parquet(self):
    file_path = os.path.join(self.resource_path, "orca/data/csv")
    sc = init_nncontext()
    from pyspark.sql.functions import col
    spark = OrcaContext.get_spark_session()
    df = spark.read.csv(file_path, header=True)
    df = df.withColumn('sale_price', col('sale_price').cast('int'))
    temp = tempfile.mkdtemp()
    df.write.parquet(os.path.join(temp, "test_parquet"))

    data_shard2 = bigdl.orca.data.pandas.read_parquet(os.path.join(temp, "test_parquet"))
    assert data_shard2.num_partitions() == 2, "number of shard should be 2"
    data = data_shard2.collect()
    df = data[0]
    assert "location" in df.columns

    data_shard2 = bigdl.orca.data.pandas.read_parquet(
        os.path.join(temp, "test_parquet"),
        columns=['ID', 'sale_price'])
    data = data_shard2.collect()
    df = data[0]
    assert len(df.columns) == 2

    from pyspark.sql.types import StructType, StructField, IntegerType, StringType
    schema = StructType([StructField("ID", StringType(), True),
                         StructField("sale_price", IntegerType(), True),
                         StructField("location", StringType(), True)])
    data_shard3 = bigdl.orca.data.pandas.read_parquet(
        os.path.join(temp, "test_parquet"),
        columns=['ID', 'sale_price'],
        schema=schema)
    data = data_shard3.collect()
    df = data[0]
    assert str(df['sale_price'].dtype) == 'int64'

    shutil.rmtree(temp)
def __init__(self,
             model_creator,
             config=None,
             compile_args_creator=None,
             verbose=False,
             workers_per_node=1,
             model_dir=None):
    self.model_creator = model_creator
    self.compile_args_creator = compile_args_creator
    self.config = {} if config is None else config
    self.verbose = verbose

    sc = OrcaContext.get_spark_context()
    num_node, num_core = get_node_and_core_number()
    self.num_workers = num_node * workers_per_node
    self.total_cores = num_node * num_core

    # over partition to cover tasks all over the cluster
    self.workerRDD = sc.parallelize(list(range(self.total_cores * 4)),
                                    self.total_cores * 4).repartition(self.num_workers)

    if "inter_op_parallelism" not in self.config:
        self.config["inter_op_parallelism"] = 1
    if "intra_op_parallelism" not in self.config:
        self.config["intra_op_parallelism"] = num_core // workers_per_node

    self.model_weights = None
    self.epoch = 0

    if "batch_size" in self.config:
        raise Exception("Please do not specify batch_size in config. Input batch_size in the"
                        " fit/evaluate function of the estimator instead.")
    self.model_dir = model_dir
elif cluster_mode == "yarn":
    init_orca_context(cluster_mode="yarn-client", num_nodes=2, cores=2,
                      driver_memory="6g")  # run on Hadoop YARN cluster
elif cluster_mode == "k8s":
    init_orca_context(cluster_mode="k8s", master=master,
                      container_image=image_name_k8s,
                      num_nodes=1, memory="128g", cores=4)  # run on K8s cluster

print("INFO 1 cluster_mode_init_success!")

# Read in the dataset and do a little preprocessing
new_rating_files = "/ppml/trusted-big-data-ml/work/data/ml-1m/ratings_new.dat.2"
if not os.path.exists(new_rating_files):
    print("INFO ERROR ratings_new.dat does not exist")
    exit(1)

# read csv
spark = OrcaContext.get_spark_session()
df = spark.read.csv(new_rating_files, sep=':', header=True, inferSchema=True).toDF(
    "user", "item", "label", "timestamp")
user_set = df.select('user').collect()
item_set = df.select('item').collect()
# min_user_id = min(user_set)[0]
max_user_id = max(user_set)[0]
# min_item_id = min(item_set)[0]
max_item_id = max(item_set)[0]
# print(min_user_id, max_user_id, min_item_id, max_item_id)

# update label to start from 0
df = df.withColumn('label', df.label - 1)

# split into train/test datasets
train_data, test_data = df.randomSplit([0.8, 0.2], 100)
                               driver_memory=args.driver_memory,
                               driver_cores=args.driver_cores,
                               num_executors=args.slave_num,
                               extra_executor_memory_for_ray=args.extra_executor_memory_for_ray,
                               object_store_memory=args.object_store_memory)
    else:
        sc = init_orca_context(cluster_mode="yarn-cluster",
                               cores=args.executor_cores,
                               memory=args.executor_memory,
                               init_ray_on_spark=True,
                               driver_memory=args.driver_memory,
                               driver_cores=args.driver_cores,
                               num_executors=args.slave_num,
                               extra_executor_memory_for_ray=args.extra_executor_memory_for_ray,
                               object_store_memory=args.object_store_memory)
    ray_ctx = OrcaContext.get_ray_context()
elif cluster_mode == "local":
    sc = init_orca_context(cores=args.driver_cores)
    ray_ctx = OrcaContext.get_ray_context()
elif cluster_mode == "spark-submit":
    sc = init_orca_context(cluster_mode=cluster_mode)
    ray_ctx = OrcaContext.get_ray_context()
else:
    print("init_orca_context failed. cluster_mode should be one of 'local', "
          "'yarn' and 'spark-submit' but got " + cluster_mode)

# Simple environment with 4 independent cartpole entities
register_env("multi_cartpole", lambda _: MultiAgentCartPole({"num_agents": 4}))
single_env = gym.make("CartPole-v0")
obs_space = single_env.observation_space
act_space = single_env.action_space
def fit(self, data, epochs=1, batch_size=32, verbose=1,
        callbacks=None, validation_data=None, class_weight=None,
        steps_per_epoch=None, validation_steps=None, validation_freq=1,
        data_config=None, feature_cols=None, label_cols=None,
        model_dir=None):
    """
    Train this tensorflow model with train data.

    :param data: train data. It can be XShards, Spark DataFrame or creator function which
           returns Iter or DataLoader.
           If data is XShards, each partition can be a Pandas DataFrame or a dictionary of
           {'x': feature, 'y': label}, where feature(label) is a numpy array or a tuple of
           numpy arrays.
    :param epochs: Number of epochs to train the model. Default: 1.
    :param batch_size: Batch size used for training. Default: 32.
    :param verbose: Prints output of one model if true.
    :param callbacks: List of Keras compatible callbacks to apply during training.
    :param validation_data: validation data. Validation data type should be the same
           as train data.
    :param class_weight: Optional dictionary mapping class indices (integers) to a weight
           (float) value, used for weighting the loss function. This can be useful to tell
           the model to "pay more attention" to samples from an under-represented class.
    :param steps_per_epoch: Total number of steps (batches of samples) before declaring one
           epoch finished and starting the next epoch. Default: None.
    :param validation_steps: Total number of steps (batches of samples) to validate before
           stopping. Default: None.
    :param validation_freq: Specifies how many training epochs to run before a new validation
           run is performed. Default: 1.
    :param data_config: An optional dictionary that can be passed to data creator function.
    :param feature_cols: Feature column name(s) of data. Only used when data is a Spark
           DataFrame or an XShards of Pandas DataFrame. Default: None.
    :param label_cols: Label column name(s) of data. Only used when data is a Spark DataFrame
           or an XShards of Pandas DataFrame. Default: None.
    :return:
    """
    import numpy as np
    sc = OrcaContext.get_spark_context()

    init_params = dict(
        model_creator=self.model_creator,
        compile_args_creator=self.compile_args_creator,
        config=self.config,
        verbose=self.verbose,
        size=self.num_workers,
        mode="fit",
        cluster_info=self._get_cluster_info(sc),
        model_dir=self.model_dir,
        epoch=self.epoch)

    params = dict(
        epochs=epochs,
        batch_size=batch_size,
        verbose=verbose,
        callbacks=callbacks,
        class_weight=class_weight,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        data_config=data_config)

    # dataframe change to xshard, num_partition >= num_workers
    data, validation_data = maybe_dataframe_to_xshards(data, validation_data,
                                                       feature_cols, label_cols,
                                                       mode="fit",
                                                       num_workers=self.num_workers,
                                                       accept_str_col=True)

    if isinstance(data, SparkXShards):
        # set train/validation data
        if validation_data is None:
            def transform_func(iter, init_param, param):
                partition_data = list(iter)
                param["data_creator"] = make_data_creator(partition_data)
                return SparkRunner(**init_param).step(**param)

            res = data.rdd.repartition(self.num_workers).barrier() \
                .mapPartitions(
                    lambda iter: transform_func(iter, init_params, params)).collect()
        else:
            def transform_func(iter, init_param, param):
                data_tuple_list = list(iter)
                data_list = [x[0] for x in data_tuple_list]
                valid_list = [x[1] for x in data_tuple_list]
                param["data_creator"] = make_data_creator(data_list)
                param["validation_data_creator"] = make_data_creator(valid_list)
                return SparkRunner(**init_param).step(**param)

            res = data.zip(validation_data).rdd.repartition(self.num_workers).barrier() \
                .mapPartitions(
                    lambda iter: transform_func(iter, init_params, params)).collect()
    else:
        params["data_creator"] = data
        params["validation_data_creator"] = validation_data

        def transform_func(iter, init_param, param):
            return SparkRunner(**init_param).step(**param)

        res = self.workerRDD.barrier().mapPartitions(
            lambda iter: transform_func(iter, init_params, params)).collect()

    if self.model_dir:
        try:
            temp_dir = tempfile.mkdtemp()
            get_remote_file_to_local(os.path.join(self.model_dir, "states.pkl"),
                                     os.path.join(temp_dir, "states.pkl"),
                                     over_write=True)
            import pickle
            with open(os.path.join(temp_dir, "states.pkl"), 'rb') as f:
                states = pickle.load(f)
                self.model_weights = states['weights']
                self.epoch = states["epoch"]
        finally:
            shutil.rmtree(temp_dir)

    return res[0]
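# When `data` is a creator function rather than a DataFrame/XShards, it is forwarded to each
# worker as params["data_creator"] in the else branch above. A minimal sketch of such a
# creator follows; the (config, batch_size) signature returning a tf.data.Dataset is an
# assumption based on common Orca TF2 usage and is not defined in this snippet, and `est` is
# a hypothetical Estimator instance:
#
#     def train_data_creator(config, batch_size):
#         import numpy as np
#         import tensorflow as tf
#         x = np.random.randn(100, 1).astype(np.float32)   # toy features
#         y = np.random.randint(0, 2, size=(100,))         # toy labels
#         return tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size)
#
#     est.fit(train_data_creator, epochs=1, batch_size=32, steps_per_epoch=4)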
def predict(self, data, batch_size=None, verbose=1,
            steps=None, callbacks=None, data_config=None,
            feature_cols=None):
    """
    Predict the input data.

    :param data: predict input data. It can be XShards or Spark DataFrame.
           If data is XShards, each partition can be a Pandas DataFrame or a dictionary of
           {'x': feature}, where feature is a numpy array or a tuple of numpy arrays.
    :param batch_size: Batch size used for inference. Default: None.
    :param verbose: Prints output of one model if true.
    :param steps: Total number of steps (batches of samples) before declaring the prediction
           round finished. Ignored with the default value of None.
    :param callbacks: List of Keras compatible callbacks to apply during prediction.
    :param data_config: An optional dictionary that can be passed to data creator function.
    :param feature_cols: Feature column name(s) of data. Only used when data is a Spark
           DataFrame or an XShards of Pandas DataFrame. Default: None.
    :return:
    """
    logger.info("Starting predict step.")
    sc = OrcaContext.get_spark_context()
    if self.model_weights:
        weights = sc.broadcast(self.model_weights)
    else:
        weights = None

    init_params = dict(
        model_creator=self.model_creator,
        compile_args_creator=self.compile_args_creator,
        config=self.config,
        verbose=self.verbose,
        size=self.num_workers,
        model_weights=weights,
        mode="predict",
        cluster_info=None)

    params = dict(
        verbose=verbose,
        batch_size=batch_size,
        steps=steps,
        callbacks=callbacks,
        data_config=data_config)

    if isinstance(data, DataFrame):
        data = data.repartition(self.num_workers)
        xshards, _ = dataframe_to_xshards(data,
                                          validation_data=None,
                                          feature_cols=feature_cols,
                                          label_cols=None,
                                          mode="predict",
                                          accept_str_col=True)

        def transform_func(iter, init_param, param):
            partition_data = list(iter)
            # res = combine_in_partition(partition_data)
            param["data_creator"] = make_data_creator(partition_data)
            return SparkRunner(**init_param).predict(**param)

        pred_shards = SparkXShards(xshards.rdd.repartition(self.num_workers)
                                   .mapPartitions(
                                       lambda iter: transform_func(iter, init_params, params)))
        result = convert_predict_xshards_to_dataframe(data, pred_shards)
    else:
        raise ValueError("Only xshards or Spark DataFrame is supported for predict")

    return result
def evaluate(self, data, batch_size=32, num_steps=None, verbose=1,
             sample_weight=None, callbacks=None, data_config=None,
             feature_cols=None, label_cols=None):
    """
    Evaluates the model on the validation data set.

    :param data: evaluate data. It can be XShards, Spark DataFrame or creator function which
           returns Iter or DataLoader.
           If data is XShards, each partition can be a Pandas DataFrame or a dictionary of
           {'x': feature, 'y': label}, where feature(label) is a numpy array or a tuple of
           numpy arrays.
    :param batch_size: Batch size used for evaluation. Default: 32.
    :param num_steps: Total number of steps (batches of samples) before declaring the
           evaluation round finished. Ignored with the default value of None.
    :param verbose: Prints output of one model if true.
    :param sample_weight: Optional numpy array of weights for the samples, used for weighting
           the loss function.
    :param callbacks: List of Keras compatible callbacks to apply during evaluation.
    :param data_config: An optional dictionary that can be passed to data creator function.
    :param feature_cols: Feature column name(s) of data. Only used when data is a Spark
           DataFrame or an XShards of Pandas DataFrame. Default: None.
    :param label_cols: Label column name(s) of data. Only used when data is a Spark DataFrame
           or an XShards of Pandas DataFrame. Default: None.
    :return: validation result
    """
    import numpy as np
    sc = OrcaContext.get_spark_context()
    logger.info("Starting validation step.")

    if self.model_weights:
        weights = sc.broadcast(self.model_weights)
    else:
        weights = None

    init_params = dict(
        model_creator=self.model_creator,
        compile_args_creator=self.compile_args_creator,
        config=self.config,
        verbose=self.verbose,
        size=self.num_workers,
        model_weights=weights,
        mode="evaluate",
        cluster_info=self._get_cluster_info(sc))

    params = dict(
        batch_size=batch_size,
        verbose=verbose,
        sample_weight=sample_weight,
        steps=num_steps,
        callbacks=callbacks,
        data_config=data_config,
    )

    # dataframe change to xshard, num_partition >= num_workers
    data, _ = maybe_dataframe_to_xshards(data,
                                         validation_data=None,
                                         feature_cols=feature_cols,
                                         label_cols=label_cols,
                                         mode="evaluate",
                                         num_workers=self.num_workers,
                                         accept_str_col=True)

    if isinstance(data, SparkXShards):
        # set train/validation data
        def transform_func(iter, init_param, param):
            partition_data = list(iter)
            param["data_creator"] = make_data_creator(partition_data)
            return SparkRunner(**init_param).validate(**param)

        res = data.rdd.repartition(self.num_workers).barrier() \
            .mapPartitions(lambda iter: transform_func(iter, init_params, params)).collect()
    else:
        params["data_creator"] = data

        def transform_func(iter, init_param, param):
            return SparkRunner(**init_param).validate(**param)

        res = self.workerRDD.barrier().mapPartitions(
            lambda iter: transform_func(iter, init_params, params)).collect()

    return res[0]
def read_file_spark(file_path, file_type, **kwargs):
    sc = init_nncontext()
    node_num, core_num = get_node_and_core_number()
    backend = OrcaContext.pandas_read_backend

    if backend == "pandas":
        file_url_splits = file_path.split("://")
        prefix = file_url_splits[0]

        file_paths = []
        if isinstance(file_path, list):
            [file_paths.extend(extract_one_path(path, os.environ)) for path in file_path]
        else:
            file_paths = extract_one_path(file_path, os.environ)

        if not file_paths:
            raise Exception("The file path is invalid or empty, please check your data")

        num_files = len(file_paths)
        total_cores = node_num * core_num
        num_partitions = num_files if num_files < total_cores else total_cores
        rdd = sc.parallelize(file_paths, num_partitions)

        if prefix == "hdfs":
            pd_rdd = rdd.mapPartitions(
                lambda iter: read_pd_hdfs_file_list(iter, file_type, **kwargs))
        elif prefix == "s3":
            pd_rdd = rdd.mapPartitions(
                lambda iter: read_pd_s3_file_list(iter, file_type, **kwargs))
        else:
            def loadFile(iterator):
                dfs = []
                for x in iterator:
                    df = read_pd_file(x, file_type, **kwargs)
                    dfs.append(df)
                import pandas as pd
                return [pd.concat(dfs)]

            pd_rdd = rdd.mapPartitions(loadFile)
    else:  # Spark backend; spark.read.csv/json accepts a folder path as input
        assert file_type == "json" or file_type == "csv", \
            "Unsupported file type: %s. Only csv and json files are supported for now" % file_type
        spark = OrcaContext.get_spark_session()

        # TODO: add S3 credentials
        # The following implementation is adapted from
        # https://github.com/databricks/koalas/blob/master/databricks/koalas/namespace.py
        # with some modifications.

        if "mangle_dupe_cols" in kwargs:
            assert kwargs["mangle_dupe_cols"], "mangle_dupe_cols can only be True"
            kwargs.pop("mangle_dupe_cols")
        if "parse_dates" in kwargs:
            assert not kwargs["parse_dates"], "parse_dates can only be False"
            kwargs.pop("parse_dates")

        names = kwargs.get("names", None)
        if "names" in kwargs:
            kwargs.pop("names")
        usecols = kwargs.get("usecols", None)
        if "usecols" in kwargs:
            kwargs.pop("usecols")
        dtype = kwargs.get("dtype", None)
        if "dtype" in kwargs:
            kwargs.pop("dtype")
        squeeze = kwargs.get("squeeze", False)
        if "squeeze" in kwargs:
            kwargs.pop("squeeze")
        index_col = kwargs.get("index_col", None)
        if "index_col" in kwargs:
            kwargs.pop("index_col")

        if file_type == "csv":
            # Handle pandas-compatible keyword arguments
            kwargs["inferSchema"] = True
            header = kwargs.get("header", "infer")
            if isinstance(names, str):
                kwargs["schema"] = names
            if header == "infer":
                header = 0 if names is None else None
            if header == 0:
                kwargs["header"] = True
            elif header is None:
                kwargs["header"] = False
            else:
                raise ValueError("Unknown header argument {}".format(header))
            if "quotechar" in kwargs:
                quotechar = kwargs["quotechar"]
                kwargs.pop("quotechar")
                kwargs["quote"] = quotechar
            if "escapechar" in kwargs:
                escapechar = kwargs["escapechar"]
                kwargs.pop("escapechar")
                kwargs["escape"] = escapechar
            # sep and comment are the same as pandas
            if "comment" in kwargs:
                comment = kwargs["comment"]
                if not isinstance(comment, str) or len(comment) != 1:
                    raise ValueError("Only length-1 comment characters supported")
            df = spark.read.csv(file_path, **kwargs)
            if header is None:
                df = df.selectExpr(*["`%s` as `%s`" % (field.name, i)
                                     for i, field in enumerate(df.schema)])
        else:
            df = spark.read.json(file_path, **kwargs)

        # Handle pandas-compatible postprocessing arguments
        if usecols is not None and not callable(usecols):
            usecols = list(usecols)

        renamed = False
        if isinstance(names, list):
            if len(set(names)) != len(names):
                raise ValueError("Found duplicate names, please check your names input")
            if usecols is not None:
                if not callable(usecols):
                    # usecols is list
                    if len(names) != len(usecols) and len(names) != len(df.schema):
                        raise ValueError("Passed names did not match usecols")
                if len(names) == len(df.schema):
                    df = df.selectExpr(*["`%s` as `%s`" % (field.name, name)
                                         for field, name in zip(df.schema, names)])
                    renamed = True
            else:
                if len(names) != len(df.schema):
                    raise ValueError(
                        "The number of names [%s] does not match the number "
                        "of columns [%d]. Try names by a Spark SQL DDL-formatted "
                        "string." % (len(names), len(df.schema)))
                df = df.selectExpr(*["`%s` as `%s`" % (field.name, name)
                                     for field, name in zip(df.schema, names)])
                renamed = True

        index_map = dict([(i, field.name) for i, field in enumerate(df.schema)])
        if usecols is not None:
            if callable(usecols):
                cols = [field.name for field in df.schema if usecols(field.name)]
                missing = []
            elif all(isinstance(col, int) for col in usecols):
                cols = [field.name for i, field in enumerate(df.schema) if i in usecols]
                missing = [col for col in usecols
                           if col >= len(df.schema) or df.schema[col].name not in cols]
            elif all(isinstance(col, str) for col in usecols):
                cols = [field.name for field in df.schema if field.name in usecols]
                if isinstance(names, list):
                    missing = [c for c in usecols if c not in names]
                else:
                    missing = [col for col in usecols if col not in cols]
            else:
                raise ValueError("usecols must only be list-like of all strings, "
                                 "all unicode, all integers or a callable.")
            if len(missing) > 0:
                raise ValueError("usecols do not match columns, "
                                 "columns expected but not found: %s" % missing)
            if len(cols) > 0:
                df = df.select(cols)
                if isinstance(names, list):
                    if not renamed:
                        df = df.selectExpr(*["`%s` as `%s`" % (col, name)
                                             for col, name in zip(cols, names)])
                        # update index map after rename
                        for index, col in index_map.items():
                            if col in cols:
                                index_map[index] = names[cols.index(col)]

        if df.rdd.getNumPartitions() < node_num:
            df = df.repartition(node_num)

        def to_pandas(columns, squeeze=False, index_col=None):
            def f(iter):
                import pandas as pd
                data = list(iter)
                pd_df = pd.DataFrame(data, columns=columns)
                if dtype is not None:
                    if isinstance(dtype, dict):
                        for col, type in dtype.items():
                            if isinstance(col, str):
                                if col not in pd_df.columns:
                                    raise ValueError("column to be set type is not"
                                                     " in current dataframe")
                                pd_df[col] = pd_df[col].astype(type)
                            elif isinstance(col, int):
                                if index_map[col] not in pd_df.columns:
                                    raise ValueError("column index to be set type is not"
                                                     " in current dataframe")
                                pd_df[index_map[col]] = pd_df[index_map[col]].astype(type)
                    else:
                        pd_df = pd_df.astype(dtype)
                if squeeze and len(pd_df.columns) == 1:
                    pd_df = pd_df.iloc[:, 0]
                if index_col:
                    pd_df = pd_df.set_index(index_col)
                return [pd_df]

            return f

        pd_rdd = df.rdd.mapPartitions(to_pandas(df.columns, squeeze, index_col))

    try:
        data_shards = SparkXShards(pd_rdd)
    except Exception as e:
        alternative_backend = "pandas" if backend == "spark" else "spark"
        print("An error occurred when reading files with '%s' backend, you may switch to '%s' "
              "backend for another try. You can set the backend using "
              "OrcaContext.pandas_read_backend" % (backend, alternative_backend))
        raise e
    return data_shards
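# A minimal usage sketch exercising the pandas-compatible keyword handling above. The file
# path is hypothetical, and it assumes read_csv in bigdl.orca.data.pandas routes through
# read_file_spark when OrcaContext.pandas_read_backend is set to "spark":
#
#     import bigdl.orca.data.pandas
#     from bigdl.orca import OrcaContext
#     OrcaContext.pandas_read_backend = "spark"
#     shards = bigdl.orca.data.pandas.read_csv("/path/to/data.csv",
#                                              usecols=["ID", "sale_price"],
#                                              dtype={"sale_price": "int64"})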
def __init__(self,
             *,
             model_creator,
             optimizer_creator,
             loss_creator=None,
             metrics=None,
             scheduler_creator=None,
             training_operator_cls=TrainingOperator,
             initialization_hook=None,
             config=None,
             scheduler_step_freq="batch",
             use_tqdm=False,
             workers_per_node=1,
             sync_stats=True,
             log_level=logging.INFO):
    if config is not None and "batch_size" in config:
        raise Exception("Please do not specify batch_size in config. Input batch_size in the"
                        " fit/evaluate/predict function of the estimator instead.")
    self.config = {} if config is None else config

    sc = OrcaContext.get_spark_context()

    if not (isinstance(model_creator, types.FunctionType) and
            isinstance(optimizer_creator, types.FunctionType)):  # Torch model is also callable.
        raise ValueError(
            "Must provide a function for both model_creator and optimizer_creator")

    if not training_operator_cls and not loss_creator:
        raise ValueError("If a loss_creator is not provided, you must "
                         "provide a custom training operator.")

    self.model_creator = model_creator
    self.initialization_hook = initialization_hook

    num_nodes, cores_per_node = get_node_and_core_number()
    self.num_workers = num_nodes * workers_per_node
    self.total_cores = num_nodes * cores_per_node
    self.cores_per_worker = cores_per_node // workers_per_node

    # over partition to cover tasks all over the cluster
    self.workerRDD = sc.parallelize(list(range(self.total_cores * 4)),
                                    self.total_cores * 4).repartition(self.num_workers)

    self.worker_init_params = dict(
        model_creator=self.model_creator,
        optimizer_creator=optimizer_creator,
        loss_creator=loss_creator,
        scheduler_creator=scheduler_creator,
        training_operator_cls=training_operator_cls,
        scheduler_step_freq=scheduler_step_freq,
        use_tqdm=use_tqdm,
        config=self.config.copy(),
        metrics=metrics,
        size=self.num_workers,
        cores_per_worker=self.cores_per_worker,
        cluster_info=self._get_cluster_info(sc),
        sync_stats=sync_stats,
        log_level=log_level)

    self.driver_runner = PytorchPysparkWorker(**self.worker_init_params, mode='predict')
    self.state_dict = self.driver_runner.get_state_dict()