Code example #1
def to_dataset(data, batch_size, batch_per_thread, validation_data,
               feature_cols, labels_cols, hard_code_batch_size,
               sequential_order, shuffle, auto_shard_files, memory_type="DRAM"):
    # TODO: wrap these arguments into kwargs
    if validation_data:
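        # the validation set must be the same type as the training data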
        if isinstance(data, SparkXShards):
            assert isinstance(validation_data, SparkXShards), \
                "train data and validation data should be both SparkXShards"
        if isinstance(data, Dataset):
            assert isinstance(validation_data, Dataset), \
                "train data and validation data should be both orca.data.tf.Dataset"
        if isinstance(data, DataFrame):
            assert isinstance(validation_data, DataFrame), \
                "train data and validation data should be both Spark DataFrame"
        if isinstance(data, tf.data.Dataset):
            assert isinstance(validation_data, tf.data.Dataset), \
                "train data and validation data should be both tf.data.Dataset"

    if isinstance(data, SparkXShards):
        dataset = xshards_to_tf_dataset(data,
                                        batch_size,
                                        batch_per_thread,
                                        validation_data,
                                        hard_code_batch_size=hard_code_batch_size,
                                        memory_type=memory_type,
                                        sequential_order=sequential_order,
                                        shuffle=shuffle)
    elif isinstance(data, Dataset):
        dataset = TFDataDataset2(data, batch_size=batch_size,
                                 batch_per_thread=batch_per_thread,
                                 validation_dataset=validation_data)
    elif isinstance(data, DataFrame):
        dataset = TFDataset.from_dataframe(data, feature_cols, labels_cols,
                                           batch_size,
                                           batch_per_thread,
                                           hard_code_batch_size,
                                           validation_data,
                                           memory_type,
                                           sequential_order,
                                           shuffle
                                           )
    elif is_tf_data_dataset(data):
        dataset = TFDataset.from_tf_data_dataset(data,
                                                 batch_size,
                                                 batch_per_thread,
                                                 hard_code_batch_size,
                                                 validation_data,
                                                 sequential_order,
                                                 shuffle, auto_shard_files=auto_shard_files)
    else:
        raise ValueError("data must be SparkXShards or orca.data.tf.Dataset or "
                         "Spark DataFrame or tf.data.Dataset")

    return dataset
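
To make the dispatch concrete, here is a minimal sketch of calling to_dataset with a Spark DataFrame. The SparkSession, the parquet paths, and the column names are illustrative assumptions, not part of the snippet above.

# a minimal sketch, assuming a SparkSession `spark` and parquet files with
# hypothetical feature columns "user", "item" and a "label" column
df = spark.read.parquet("train.parquet")
val_df = spark.read.parquet("val.parquet")

dataset = to_dataset(df,
                     batch_size=320,           # global batch size, used in training
                     batch_per_thread=-1,      # per-partition batch, used in inference
                     validation_data=val_df,   # must also be a Spark DataFrame
                     feature_cols=["user", "item"],
                     labels_cols=["label"],
                     hard_code_batch_size=False,
                     sequential_order=False,
                     shuffle=True,
                     auto_shard_files=False)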
Code example #2
    def fit(self,
            data,
            steps,
            batch_size=32,
            validation_data=None,
            feed_dict=None,
            session_config=None):

        assert self.labels is not None, \
            "labels is None; it should not be None in training"
        assert self.loss is not None, \
            "loss is None; it should not be None in training"
        assert self.optimizer is not None, \
            "optimizer is None; it should not be None in training"

        if isinstance(data, SparkXShards):
            dataset = _xshards_to_tf_dataset(
                data,
                batch_size=batch_size,
                validation_data_shard=validation_data)
        elif isinstance(data, Dataset):
            dataset = TFDataDataset2(data,
                                     batch_size=batch_size,
                                     batch_per_thread=-1,
                                     validation_dataset=validation_data)
        else:
            raise ValueError("data type {} is not supported; "
                             "it must be created by zoo.orca.data.package")

        if feed_dict is not None:
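            # map each fed tensor to a (train_value, eval_value) pair;
            # the same value is reused for both phases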
            tensor_with_value = {
                key: (value, value)
                for key, value in feed_dict.items()
            }
        else:
            tensor_with_value = None

        optimizer = TFOptimizer.from_train_op(
            train_op=self.train_op,
            loss=self.loss,
            inputs=self.inputs,
            labels=self.labels,
            dataset=dataset,
            metrics=self.metrics,
            updates=self.updates,
            sess=self.sess,
            tensor_with_value=tensor_with_value,
            session_config=session_config,
            model_dir=self.model_dir)

        optimizer.optimize(end_trigger=MaxIteration(steps))
        return self
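
A hedged sketch of driving fit; the estimator instance, the shards, and the dropout placeholder are names assumed for illustration, not defined above.

# `est` is assumed to be an instance of the class above, built from a TF1
# graph; train_shards/val_shards are assumed SparkXShards of the same type
est.fit(train_shards,
        steps=1000,                     # training stops at MaxIteration(1000)
        batch_size=64,
        validation_data=val_shards,     # must be the same type as the data
        feed_dict={dropout_rate: 0.5})  # reused for both train and validation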
Code example #3
def _to_dataset(data, batch_size, batch_per_thread):
    if isinstance(data, SparkXShards):
        dataset = _xshards_to_tf_dataset(data,
                                         batch_size=batch_size,
                                         batch_per_thread=batch_per_thread)
    elif isinstance(data, Dataset):
        dataset = TFDataDataset2(data,
                                 batch_size=batch_size,
                                 batch_per_thread=batch_per_thread)
    else:
        raise ValueError(
            "data must be a SparkXShards or an orca.data.tf.Dataset")

    return dataset
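
Read together with code example #4 below, the two batch arguments appear to play complementary roles: batch_size drives distributed training while batch_per_thread drives per-partition inference, with -1 marking whichever one is unused. A sketch under that reading, with assumed shard names:

# training path: the global batch size matters, per-thread batching is unused
train_ds = _to_dataset(train_shards, batch_size=320, batch_per_thread=-1)

# inference path: per-thread batching matters, the global batch size is unused
pred_ds = _to_dataset(test_shards, batch_size=-1, batch_per_thread=32)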
Code example #4
File: estimator.py  Project: yanwei-ji/analytics-zoo
    def predict(self, data, batch_size=32):
        assert self.outputs is not None, \
            "outputs is None; it should not be None in prediction"

        if isinstance(data, SparkXShards):
            dataset = _xshards_to_tf_dataset(data,
                                             batch_per_thread=batch_size)
        elif isinstance(data, Dataset):
            dataset = TFDataDataset2(data, batch_size=-1,
                                     batch_per_thread=batch_size)
        else:
            raise ValueError("data must be a SparkXShards or an orca.data.tf.Dataset")

        flat_inputs = nest.flatten(self.inputs)
        flat_outputs = nest.flatten(self.outputs)
        tfnet = TFNet.from_session(sess=self.sess, inputs=flat_inputs, outputs=flat_outputs)
        return tfnet.predict(dataset)
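
A minimal usage sketch; the estimator and the test shards are assumed names:

# `est` is assumed to be an instance of the class above with outputs defined;
# predict() wraps the live session in a TFNet and runs it over the dataset
predictions = est.predict(test_shards, batch_size=32)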
Code example #5
def to_dataset(data, batch_size, batch_per_thread, validation_data,
               feature_cols, labels_cols, hard_code_batch_size,
               sequential_order, shuffle):
    if validation_data:
        if isinstance(data, SparkXShards):
            assert isinstance(validation_data, SparkXShards), \
                "train data and validation data should be both SparkXShards"
        if isinstance(data, Dataset):
            assert isinstance(validation_data, Dataset), \
                "train data and validation data should be both orca.data.tf.Dataset"
        if isinstance(data, DataFrame):
            assert isinstance(validation_data, DataFrame), \
                "train data and validation data should be both Spark DataFrame"

    if isinstance(data, SparkXShards):
        dataset = xshards_to_tf_dataset(
            data,
            batch_size,
            batch_per_thread,
            validation_data,
            hard_code_batch_size=hard_code_batch_size,
            sequential_order=sequential_order,
            shuffle=shuffle)
    elif isinstance(data, Dataset):
        dataset = TFDataDataset2(data,
                                 batch_size=batch_size,
                                 batch_per_thread=batch_per_thread,
                                 validation_dataset=validation_data)
    elif isinstance(data, DataFrame):
        dataset = TFDataset.from_dataframe(data, feature_cols, labels_cols,
                                           batch_size, batch_per_thread,
                                           hard_code_batch_size,
                                           validation_data, sequential_order,
                                           shuffle)
    else:
        raise ValueError(
            "data must be SparkXShards or orca.data.tf.Dataset or Spark DataFrame"
        )

    return dataset
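
This variant matches code example #1 except that it lacks the memory_type and auto_shard_files parameters and has no tf.data.Dataset branch, so such an input falls through to the ValueError. A small sketch of that difference, with placeholder tensor values chosen for illustration:

import tensorflow as tf

tf_ds = tf.data.Dataset.from_tensor_slices([[1.0, 2.0], [3.0, 4.0]])
try:
    to_dataset(tf_ds, 32, -1, None, None, None, False, False, True)
except ValueError:
    # code example #1 would instead dispatch via is_tf_data_dataset(data)
    print("tf.data.Dataset is rejected by this variant")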