def test_tfdataset_with_tfrecord(self):
    """Round-trip check: fit a KerasModel on MNIST TFRecord input, then predict."""
    net = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
        tf.keras.layers.Dense(10, activation='softmax'),
    ])
    net.compile(optimizer='rmsprop',
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])
    keras_model = KerasModel(net)

    def parse_fn(example):
        # TF-Example schema: a raw-encoded 28x28x1 image plus a scalar label.
        feature_spec = {
            'image/encoded':
                tf.FixedLenFeature((), tf.string, default_value=''),
            'image/format':
                tf.FixedLenFeature((), tf.string, default_value='raw'),
            'image/class/label':
                tf.FixedLenFeature([1], tf.int64,
                                   default_value=tf.zeros([1], dtype=tf.int64)),
        }
        handlers = {
            'image': tf.contrib.slim.tfexample_decoder.Image(
                shape=[28, 28, 1], channels=1),
            'label': tf.contrib.slim.tfexample_decoder.Tensor(
                'image/class/label', shape=[]),
        }
        decoder = tf.contrib.slim.tfexample_decoder.TFExampleDecoder(
            feature_spec, handlers)
        decoded = decoder.decode(example)
        # The decoder's output order is not fixed; the image is the only
        # tensor with non-zero rank, so pick (feature, label) by rank.
        if len(decoded[0].shape) > 0:
            return decoded[0], decoded[1]
        return decoded[1], decoded[0]

    train_path = os.path.join(resource_path, "tfrecord/mnist_train.tfrecord")
    test_path = os.path.join(resource_path, "tfrecord/mnist_test.tfrecord")

    train_dataset = TFDataset.from_tfrecord(train_path,
                                            parse_fn=parse_fn,
                                            batch_size=8,
                                            validation_file_path=test_path)
    keras_model.fit(train_dataset)

    predict_dataset = TFDataset.from_tfrecord(
        test_path,
        parse_fn=lambda x: (parse_fn(x)[0], ),
        batch_per_thread=1)
    result = keras_model.predict(predict_dataset)
    result.collect()
class TFKerasWrapper(Estimator):
    """Estimator backed by a tf.keras model (wrapped in a KerasModel).

    Only Spark DataFrame input is supported: `feature_cols` selects the
    input columns and `labels_cols` the supervision columns. Each public
    method converts the DataFrame to a TFDataset via `to_dataset` and
    delegates to the underlying KerasModel.
    """

    def __init__(self, keras_model, model_dir):
        # KerasModel owns compilation state and checkpoints under model_dir.
        self.model = KerasModel(keras_model, model_dir)

    def fit(self, data,
            epochs=1,
            batch_size=32,
            feature_cols=None,
            labels_cols=None,
            validation_data=None,
            hard_code_batch_size=False,
            session_config=None
            ):
        """Train on `data` and return self (fluent API).

        :param data: a Spark DataFrame holding features and labels.
        :param epochs: number of passes over the data.
        :param batch_size: global training batch size.
        :param feature_cols: list of feature column names; required.
        :param labels_cols: list of label column names; required.
        :param validation_data: optional DataFrame evaluated after each epoch.
        :param hard_code_batch_size: fix the batch dimension in the graph.
        :param session_config: optional tf.ConfigProto for the session.
        :raises ValueError: if `data` is not a DataFrame.
        """
        if isinstance(data, DataFrame):
            assert feature_cols is not None, \
                "feature columns is None; it should not be None in training"
            assert labels_cols is not None, \
                "label columns is None; it should not be None in training"
            dataset = to_dataset(data,
                                 batch_size=batch_size,
                                 batch_per_thread=-1,
                                 validation_data=validation_data,
                                 feature_cols=feature_cols,
                                 labels_cols=labels_cols,
                                 hard_code_batch_size=hard_code_batch_size,
                                 sequential_order=False,
                                 shuffle=True
                                 )
        else:
            # BUG FIX: previously a non-DataFrame input fell through to an
            # unassigned `dataset` and crashed with NameError; fail fast
            # with a clear message instead.
            raise ValueError("data should be a Spark DataFrame, but got "
                             + str(type(data)))

        self.model.fit(dataset,
                       batch_size=batch_size,
                       epochs=epochs,
                       session_config=session_config
                       )
        return self

    def predict(self, data, batch_size=4,
                feature_cols=None,
                hard_code_batch_size=False
                ):
        """Run inference on `data` and return a DataFrame of predictions.

        :param data: a Spark DataFrame holding the feature columns.
        :param batch_size: per-thread prediction batch size.
        :param feature_cols: list of feature column names; required.
        :param hard_code_batch_size: fix the batch dimension in the graph.
        :raises ValueError: if `data` is not a DataFrame.
        """
        if isinstance(data, DataFrame):
            assert feature_cols is not None, \
                "feature columns is None; it should not be None in prediction"
            dataset = to_dataset(data,
                                 batch_size=-1,
                                 batch_per_thread=batch_size,
                                 validation_data=None,
                                 feature_cols=feature_cols,
                                 labels_cols=None,
                                 hard_code_batch_size=hard_code_batch_size,
                                 sequential_order=True,
                                 shuffle=False
                                 )
        else:
            # BUG FIX: same NameError fall-through as in fit(); see above.
            raise ValueError("data should be a Spark DataFrame, but got "
                             + str(type(data)))

        predicted_rdd = self.model.predict(dataset, batch_size)
        # Only DataFrame input reaches this point, so always join the
        # predictions back onto the input DataFrame.
        return convert_predict_to_dataframe(data, predicted_rdd)

    def evaluate(self, data, batch_size=4,
                 feature_cols=None,
                 labels_cols=None,
                 hard_code_batch_size=False
                 ):
        """Evaluate the model on `data` and return the metric results.

        :param data: a Spark DataFrame holding features and labels.
        :param batch_size: per-thread evaluation batch size.
        :param feature_cols: list of feature column names; required.
        :param labels_cols: list of label column names; required.
        :param hard_code_batch_size: fix the batch dimension in the graph.
        :raises ValueError: if `data` is not a DataFrame.
        """
        if isinstance(data, DataFrame):
            assert feature_cols is not None, \
                "feature columns is None; it should not be None in evaluation"
            assert labels_cols is not None, \
                "label columns is None; it should not be None in evaluation"
            dataset = to_dataset(data,
                                 batch_size=-1,
                                 batch_per_thread=batch_size,
                                 validation_data=None,
                                 feature_cols=feature_cols,
                                 labels_cols=labels_cols,
                                 hard_code_batch_size=hard_code_batch_size,
                                 sequential_order=True,
                                 shuffle=False
                                 )
        else:
            # BUG FIX: same NameError fall-through as in fit(); see above.
            raise ValueError("data should be a Spark DataFrame, but got "
                             + str(type(data)))

        return self.model.evaluate(dataset, batch_per_thread=batch_size)