# Example No. 1
# 0
def train(fine_tuning):
    """Build a Keras model and train with mock data.

    Two hard-coded sentences with binary labels are fed through a Reshape
    layer, a ``util.CustomLayer`` wrapping the SavedModel loaded from
    ``FLAGS.model_dir``, and a small dense classification head.

    Args:
      fine_tuning: Passed as ``trainable`` to ``util.CustomLayer``; presumably
        controls whether the wrapped SavedModel's variables are updated during
        training (depends on CustomLayer — verify).
    """
    features = np.array(["my first sentence", "my second sentence"])
    labels = np.array([1, 0])
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))

    module = tf.saved_model.load(FLAGS.model_dir)

    # Create the sequential keras model.
    layers = tf.keras.layers
    model = tf.keras.Sequential()
    # Inputs are declared as [batch, 1] strings; Reshape(()) drops the trailing
    # axis so the wrapped SavedModel layer receives a [batch] vector of strings.
    model.add(layers.Reshape((), batch_input_shape=[None, 1], dtype=tf.string))
    model.add(
        util.CustomLayer(module, output_shape=[10], trainable=fine_tuning))
    model.add(layers.Dense(100, activation="relu"))
    model.add(layers.Dense(50, activation="relu"))
    model.add(layers.Dense(1, activation="sigmoid"))

    model.compile(
        optimizer="adam",
        loss="binary_crossentropy",
        metrics=["accuracy"],
        # TODO(b/124446120): Remove after fixed.
        run_eagerly=True)

    # `fit` consumes a tf.data.Dataset directly; `fit_generator` is deprecated
    # (TF 2.1+) and removed in newer Keras releases.
    model.fit(dataset.batch(1), epochs=5)
def main(argv):
    """Build a text classifier on a loaded SavedModel and train it on mock data.

    The SavedModel at ``FLAGS.model_dir`` is wrapped in ``util.CustomLayer``,
    which takes a [batch] vector of strings directly, followed by a small dense
    head with a sigmoid output for the two hard-coded binary-labelled sentences.

    Args:
      argv: Command-line arguments; unused.
    """
    del argv  # Unused.

    features = np.array(["my first sentence", "my second sentence"])
    labels = np.array([1, 0])

    dataset = tf.data.Dataset.from_tensor_slices((features, labels))

    embed = tf.saved_model.load(FLAGS.model_dir)

    # Create the sequential keras model.
    model = tf.keras.Sequential()
    model.add(
        util.CustomLayer(embed,
                         batch_input_shape=[None],
                         output_shape=[10],
                         dtype=tf.string))
    model.add(tf.keras.layers.Dense(100, activation="relu"))
    model.add(tf.keras.layers.Dense(50, activation="relu"))
    model.add(tf.keras.layers.Dense(1, activation="sigmoid"))
    model.compile(optimizer="adam",
                  loss="binary_crossentropy",
                  metrics=["accuracy"])

    # `fit` consumes a tf.data.Dataset directly; `fit_generator` is deprecated
    # (TF 2.1+) and removed in newer Keras releases.
    model.fit(dataset.batch(1), epochs=5)
# Example No. 3
# 0
def main(argv):
    """Train a classifier built on a restored, regularization-scaled extractor.

    Loads the SavedModel at ``FLAGS.export_dir`` and rescales its
    regularization losses by ``FLAGS.regularization_loss_multiplier``. It then
    wraps the model in ``util.CustomLayer`` (optionally forwarding a dropout
    rate), builds a classifier with ``make_classifier``, and fits it on MNIST
    or Fashion MNIST.

    Args:
      argv: Command-line arguments; unused.
    """
    del argv  # Unused.

    # Restore the pre-trained feature extractor and adjust its regularizers.
    saved_obj = tf.saved_model.load(FLAGS.export_dir)
    scale_regularization_losses(saved_obj,
                                FLAGS.regularization_loss_multiplier)

    # Forward a dropout rate only when one was supplied; otherwise pass an
    # empty argument dict.
    call_args = ({} if FLAGS.dropout_rate is None
                 else {'dropout_rate': FLAGS.dropout_rate})
    extractor_layer = util.CustomLayer(saved_obj,
                                       output_shape=[10],
                                       trainable=FLAGS.retrain,
                                       arguments=call_args)

    classifier = make_classifier(extractor_layer)

    # The training data may come from a different dataset than the one the
    # extractor was originally trained on.
    train_split, test_split = mnist_util.load_reshaped_data(
        use_fashion_mnist=FLAGS.use_fashion_mnist,
        fake_tiny_data=FLAGS.fast_test_mode)
    x_train, y_train = train_split
    x_test, y_test = test_split

    classifier.compile(loss=tf.keras.losses.categorical_crossentropy,
                       optimizer=tf.keras.optimizers.SGD(),
                       metrics=['accuracy'])

    dataset_name = 'Fashion MNIST' if FLAGS.use_fashion_mnist else 'MNIST'
    num_trainable = len(classifier.trainable_variables)
    num_frozen = len(classifier.non_trainable_variables)
    print('Training on %s with %d trainable and %d untrainable variables.' %
          (dataset_name, num_trainable, num_frozen))

    classifier.fit(x_train, y_train,
                   batch_size=128,
                   epochs=FLAGS.epochs,
                   verbose=1,
                   validation_data=(x_test, y_test))
# Example No. 4
# 0
def main(argv):
    """Train a classifier on top of a restored feature extractor.

    The SavedModel at ``FLAGS.export_dir`` is wrapped in ``util.CustomLayer``
    (declared ``output_shape=[128]``, ``trainable=FLAGS.retrain``). A
    classifier is then built from it by ``make_classifier`` and fitted on
    MNIST or Fashion MNIST, as selected by ``FLAGS.use_fashion_mnist``.

    Args:
      argv: Command-line arguments; unused.
    """
    del argv  # Unused.

    # Wrap the pre-trained feature extractor so Keras can use it as a layer.
    extractor_layer = util.CustomLayer(tf.saved_model.load(FLAGS.export_dir),
                                       output_shape=[128],
                                       trainable=FLAGS.retrain)
    classifier = make_classifier(extractor_layer)

    # The training data may come from a different dataset than the one the
    # extractor was originally trained on.
    train_split, test_split = mnist_util.load_reshaped_data(
        use_fashion_mnist=FLAGS.use_fashion_mnist,
        fake_tiny_data=FLAGS.fast_test_mode)
    x_train, y_train = train_split
    x_test, y_test = test_split

    classifier.compile(
        loss=tf.keras.losses.categorical_crossentropy,
        optimizer=tf.keras.optimizers.SGD(),
        metrics=['accuracy'],
        # TODO(arnoegw): Remove after investigating huge allocs.
        run_eagerly=True)

    dataset_name = 'Fashion MNIST' if FLAGS.use_fashion_mnist else 'MNIST'
    print('Training on %s with %d trainable and %d untrainable variables.' %
          (dataset_name,
           len(classifier.trainable_variables),
           len(classifier.non_trainable_variables)))

    # steps_per_epoch=3 caps each epoch at three steps. Presumably this keeps
    # the eager-mode run short; confirm before changing.
    classifier.fit(x_train, y_train,
                   batch_size=128,
                   epochs=FLAGS.epochs,
                   steps_per_epoch=3,
                   verbose=1,
                   validation_data=(x_test, y_test))