Example 1
    def test_tfdataset_with_tfrecord(self):
        model = tf.keras.Sequential([
            tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
            tf.keras.layers.Dense(10, activation='softmax'),
        ])

        model.compile(optimizer='rmsprop',
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])

        keras_model = KerasModel(model)

        # Decode a serialized tf.train.Example into an (image, label) pair
        # using the TF 1.x tf.contrib.slim TFExampleDecoder.
        def parse_fn(example):
            keys_to_features = {
                'image/encoded':
                tf.FixedLenFeature((), tf.string, default_value=''),
                'image/format':
                tf.FixedLenFeature((), tf.string, default_value='raw'),
                'image/class/label':
                tf.FixedLenFeature([1],
                                   tf.int64,
                                   default_value=tf.zeros([1],
                                                          dtype=tf.int64)),
            }

            items_to_handlers = {
                'image':
                tf.contrib.slim.tfexample_decoder.Image(shape=[28, 28, 1],
                                                        channels=1),
                'label':
                tf.contrib.slim.tfexample_decoder.Tensor('image/class/label',
                                                         shape=[]),
            }

            decoder = tf.contrib.slim.tfexample_decoder.TFExampleDecoder(
                keys_to_features, items_to_handlers)
            results = decoder.decode(example)

            # decode() returns the requested items as a list whose order is not
            # guaranteed; the image tensor is the one with a non-scalar shape.
            if len(results[0].shape) > 0:
                feature = results[0]
                label = results[1]
            else:
                feature = results[1]
                label = results[0]

            return feature, label

        train_path = os.path.join(resource_path,
                                  "tfrecord/mnist_train.tfrecord")
        test_path = os.path.join(resource_path, "tfrecord/mnist_test.tfrecord")
        dataset = TFDataset.from_tfrecord(train_path,
                                          parse_fn=parse_fn,
                                          batch_size=8,
                                          validation_file_path=test_path)

        keras_model.fit(dataset)

        predict_dataset = TFDataset.from_tfrecord(test_path,
                                                  parse_fn=lambda x:
                                                  (parse_fn(x)[0], ),
                                                  batch_per_thread=1)
        result = keras_model.predict(predict_dataset)
        result.collect()
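
A minimal alternative parse_fn built only on core tf.io ops is sketched below,
for setups where tf.contrib.slim is unavailable. It assumes the same encoding
as the decoder above ('image/encoded' holds raw uint8 pixel bytes, i.e.
format 'raw'); the name parse_fn_io is illustrative only and not part of the
original test.

import tensorflow as tf

def parse_fn_io(example):
    # Parse the serialized Example into its raw features.
    features = tf.io.parse_single_example(example, {
        'image/encoded':
        tf.io.FixedLenFeature((), tf.string, default_value=''),
        'image/class/label':
        tf.io.FixedLenFeature([1], tf.int64,
                              default_value=tf.zeros([1], dtype=tf.int64)),
    })
    # Raw bytes -> uint8 tensor -> 28x28x1 image; label -> scalar.
    image = tf.reshape(tf.io.decode_raw(features['image/encoded'], tf.uint8),
                       [28, 28, 1])
    label = tf.reshape(features['image/class/label'], [])
    return image, label
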
Example 2
class TFKerasWrapper(Estimator):
    """Estimator that wraps a tf.keras model in a KerasModel; Spark DataFrame
    inputs require feature (and label) column names to be passed explicitly."""

    def __init__(self, keras_model, model_dir):
        self.model = KerasModel(keras_model, model_dir)

    def fit(self, data,
            epochs=1,
            batch_size=32,
            feature_cols=None,
            labels_cols=None,
            validation_data=None,
            hard_code_batch_size=False,
            session_config=None
            ):

        if isinstance(data, DataFrame):
            assert feature_cols is not None, \
                "feature columns is None; it should not be None in training"
            assert labels_cols is not None, \
                "label columns is None; it should not be None in training"

        # Training uses the global batch_size; batch_per_thread=-1 marks the
        # per-thread batching as unused, and the data is shuffled.
        dataset = to_dataset(data, batch_size=batch_size, batch_per_thread=-1,
                             validation_data=validation_data,
                             feature_cols=feature_cols, labels_cols=labels_cols,
                             hard_code_batch_size=hard_code_batch_size,
                             sequential_order=False, shuffle=True
                             )

        self.model.fit(dataset, batch_size=batch_size, epochs=epochs,
                       session_config=session_config
                       )
        return self

    def predict(self, data, batch_size=4,
                feature_cols=None,
                hard_code_batch_size=False
                ):

        if isinstance(data, DataFrame):
            assert feature_cols is not None, \
                "feature columns is None; it should not be None in prediction"

        # Inference batches per thread (batch_per_thread) and keeps sequential
        # order without shuffling, so predictions line up with the input rows.
        dataset = to_dataset(data, batch_size=-1, batch_per_thread=batch_size,
                             validation_data=None,
                             feature_cols=feature_cols, labels_cols=None,
                             hard_code_batch_size=hard_code_batch_size,
                             sequential_order=True, shuffle=False
                             )

        predicted_rdd = self.model.predict(dataset, batch_size)
        if isinstance(data, DataFrame):
            return convert_predict_to_dataframe(data, predicted_rdd)
        else:
            return predicted_rdd

    def evaluate(self, data, batch_size=4,
                 feature_cols=None,
                 labels_cols=None,
                 hard_code_batch_size=False
                 ):

        if isinstance(data, DataFrame):
            assert feature_cols is not None, \
                "feature columns is None; it should not be None in evaluation"
            assert labels_cols is not None, \
                "label columns is None; it should not be None in evaluation"

        dataset = to_dataset(data, batch_size=-1, batch_per_thread=batch_size,
                             validation_data=None,
                             feature_cols=feature_cols, labels_cols=labels_cols,
                             hard_code_batch_size=hard_code_batch_size,
                             sequential_order=True, shuffle=False
                             )

        return self.model.evaluate(dataset, batch_per_thread=batch_size)
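
A minimal usage sketch of the wrapper above, assuming a SparkSession is
available and that `df` is a Spark DataFrame with an array-typed "features"
column and a numeric "label" column; the DataFrame, the column names and the
model_dir path are placeholders, not part of the original code.

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation='softmax', input_shape=(784,)),
])
model.compile(optimizer='rmsprop',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

estimator = TFKerasWrapper(model, model_dir="/tmp/tfkeras_demo")
estimator.fit(df, epochs=1, batch_size=32,
              feature_cols=["features"], labels_cols=["label"])
predictions_df = estimator.predict(df, feature_cols=["features"])
metrics = estimator.evaluate(df, feature_cols=["features"],
                             labels_cols=["label"])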