def to_dataset(data, batch_size, batch_per_thread, validation_data,
               feature_cols, labels_cols, hard_code_batch_size,
               sequential_order, shuffle, auto_shard_files,
               memory_type="DRAM"):
    # TODO: wrap these arguments into kwargs
    if validation_data:
        if isinstance(data, SparkXShards):
            assert isinstance(validation_data, SparkXShards), \
                "train data and validation data should be both SparkXShards"
        if isinstance(data, Dataset):
            assert isinstance(validation_data, Dataset), \
                "train data and validation data should be both orca.data.tf.Dataset"
        if isinstance(data, DataFrame):
            assert isinstance(validation_data, DataFrame), \
                "train data and validation data should be both Spark DataFrame"
        if isinstance(data, tf.data.Dataset):
            assert isinstance(validation_data, tf.data.Dataset), \
                "train data and validation data should be both tf.data.Dataset"

    if isinstance(data, SparkXShards):
        dataset = xshards_to_tf_dataset(data,
                                        batch_size,
                                        batch_per_thread,
                                        validation_data,
                                        hard_code_batch_size=hard_code_batch_size,
                                        memory_type=memory_type,
                                        sequential_order=sequential_order,
                                        shuffle=shuffle)
    elif isinstance(data, Dataset):
        dataset = TFDataDataset2(data,
                                 batch_size=batch_size,
                                 batch_per_thread=batch_per_thread,
                                 validation_dataset=validation_data)
    elif isinstance(data, DataFrame):
        dataset = TFDataset.from_dataframe(data, feature_cols, labels_cols,
                                           batch_size, batch_per_thread,
                                           hard_code_batch_size,
                                           validation_data, memory_type,
                                           sequential_order, shuffle)
    elif is_tf_data_dataset(data):
        dataset = TFDataset.from_tf_data_dataset(data, batch_size,
                                                 batch_per_thread,
                                                 hard_code_batch_size,
                                                 validation_data,
                                                 sequential_order, shuffle,
                                                 auto_shard_files=auto_shard_files)
    else:
        raise ValueError("data must be SparkXShards or orca.data.tf.Dataset or "
                         "Spark DataFrame or tf.data.Dataset")
    return dataset
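# Hedged usage sketch (not part of the original source): one way the Spark
# DataFrame branch of to_dataset above might be exercised. The SparkSession
# ("spark") and the column names "f1", "f2", "label" are illustrative
# assumptions, not names from the original code.
def _example_to_dataset_from_dataframe(spark):
    df = spark.createDataFrame([(1.0, 2.0, 0), (3.0, 4.0, 1)],
                               ["f1", "f2", "label"])
    return to_dataset(df,
                      batch_size=32,
                      batch_per_thread=-1,
                      validation_data=None,
                      feature_cols=["f1", "f2"],
                      labels_cols=["label"],
                      hard_code_batch_size=False,
                      sequential_order=False,
                      shuffle=True,
                      auto_shard_files=False,
                      memory_type="DRAM")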
def fit(self, data, steps, batch_size=32, validation_data=None,
        feed_dict=None, session_config=None):
    assert self.labels is not None, \
        "labels is None; it should not be None in training"
    assert self.loss is not None, \
        "loss is None; it should not be None in training"
    assert self.optimizer is not None, \
        "optimizer is None; it should not be None in training"

    if isinstance(data, SparkXShards):
        dataset = _xshards_to_tf_dataset(data,
                                         batch_size=batch_size,
                                         validation_data_shard=validation_data)
    elif isinstance(data, Dataset):
        dataset = TFDataDataset2(data, batch_size=batch_size,
                                 batch_per_thread=-1,
                                 validation_dataset=validation_data)
    else:
        raise ValueError("data type {} is not supported; "
                         "it must be created by zoo.orca.data.package"
                         .format(type(data)))

    if feed_dict is not None:
        # Use the same feed value for both the training and validation phases.
        tensor_with_value = {key: (value, value)
                             for key, value in feed_dict.items()}
    else:
        tensor_with_value = None

    optimizer = TFOptimizer.from_train_op(
        train_op=self.train_op,
        loss=self.loss,
        inputs=self.inputs,
        labels=self.labels,
        dataset=dataset,
        metrics=self.metrics,
        updates=self.updates,
        sess=self.sess,
        tensor_with_value=tensor_with_value,
        session_config=session_config,
        model_dir=self.model_dir)
    optimizer.optimize(end_trigger=MaxIteration(steps))
    return self
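# Hedged usage sketch (not part of the original source): calling fit with a
# feed_dict. Per the (value, value) pairing above, each fed value applies to
# both the training and validation phases. "est", "train_shards" and
# "keep_prob" are illustrative assumptions.
def _example_fit(est, train_shards, keep_prob):
    return est.fit(data=train_shards,
                   steps=1000,
                   batch_size=32,
                   feed_dict={keep_prob: 0.5})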
def _to_dataset(data, batch_size, batch_per_thread):
    if isinstance(data, SparkXShards):
        dataset = _xshards_to_tf_dataset(data,
                                         batch_size=batch_size,
                                         batch_per_thread=batch_per_thread)
    elif isinstance(data, Dataset):
        dataset = TFDataDataset2(data, batch_size=batch_size,
                                 batch_per_thread=batch_per_thread)
    else:
        raise ValueError(
            "data must be a SparkXShards or an orca.data.tf.Dataset")
    return dataset
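# Hedged usage sketch (not part of the original source): _to_dataset simply
# dispatches on the input type. "shards" is an assumed SparkXShards of
# feature/label arrays; batch_per_thread=-1 mirrors the training path in fit.
def _example_to_dataset(shards):
    return _to_dataset(shards, batch_size=32, batch_per_thread=-1)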
def predict(self, data, batch_size=32):
    assert self.outputs is not None, \
        "outputs is None; it should not be None in prediction"

    if isinstance(data, SparkXShards):
        dataset = _xshards_to_tf_dataset(data, batch_per_thread=batch_size)
    elif isinstance(data, Dataset):
        dataset = TFDataDataset2(data, batch_size=-1,
                                 batch_per_thread=batch_size)
    else:
        raise ValueError("data must be a SparkXShards or an orca.data.tf.Dataset")

    flat_inputs = nest.flatten(self.inputs)
    flat_outputs = nest.flatten(self.outputs)
    tfnet = TFNet.from_session(sess=self.sess, inputs=flat_inputs,
                               outputs=flat_outputs)
    return tfnet.predict(dataset)
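# Hedged usage sketch (not part of the original source): predict takes a
# batch_size argument but passes it to the dataset as batch_per_thread, i.e.
# inference batches per partition rather than globally. "est" and
# "test_shards" are illustrative assumptions.
def _example_predict(est, test_shards):
    return est.predict(test_shards, batch_size=64)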
def to_dataset(data, batch_size, batch_per_thread, validation_data,
               feature_cols, labels_cols, hard_code_batch_size,
               sequential_order, shuffle):
    if validation_data:
        if isinstance(data, SparkXShards):
            assert isinstance(validation_data, SparkXShards), \
                "train data and validation data should be both SparkXShards"
        if isinstance(data, Dataset):
            assert isinstance(validation_data, Dataset), \
                "train data and validation data should be both orca.data.tf.Dataset"
        if isinstance(data, DataFrame):
            assert isinstance(validation_data, DataFrame), \
                "train data and validation data should be both Spark DataFrame"

    if isinstance(data, SparkXShards):
        dataset = xshards_to_tf_dataset(data,
                                        batch_size,
                                        batch_per_thread,
                                        validation_data,
                                        hard_code_batch_size=hard_code_batch_size,
                                        sequential_order=sequential_order,
                                        shuffle=shuffle)
    elif isinstance(data, Dataset):
        dataset = TFDataDataset2(data,
                                 batch_size=batch_size,
                                 batch_per_thread=batch_per_thread,
                                 validation_dataset=validation_data)
    elif isinstance(data, DataFrame):
        dataset = TFDataset.from_dataframe(data, feature_cols, labels_cols,
                                           batch_size, batch_per_thread,
                                           hard_code_batch_size,
                                           validation_data,
                                           sequential_order, shuffle)
    else:
        raise ValueError(
            "data must be SparkXShards or orca.data.tf.Dataset or Spark DataFrame")
    return dataset