def model_fn():
    """Build, compile and wrap a small Keras classifier for federated learning.

    The Keras model is a ``Sequential`` with a single zero-initialized
    ``Dense`` layer over 2 input features. It is compiled with a mean sparse
    categorical cross-entropy loss, SGD (learning rate 0.01) and a
    ``SparseCategoricalAccuracy`` metric.

    Returns:
      Whatever ``learning.from_compiled_keras_model`` returns for the compiled
      model together with the module-level ``sample_batch``.
    """

    def loss_fn(y_true, y_pred):
        # Per-example sparse categorical cross-entropy, averaged over the batch.
        per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
            y_true, y_pred)
        return tf.reduce_mean(per_example_loss)

    # NOTE(review): softmax over a single unit always yields 1.0, so this layer
    # cannot discriminate between classes. Confirm whether more output units
    # (one per class) were intended.
    output_layer = tf.keras.layers.Dense(
        1,
        activation=tf.nn.softmax,
        kernel_initializer='zeros',
        input_shape=(2, ),
    )
    compiled_model = tf.keras.models.Sequential([output_layer])

    sgd = tf.keras.optimizers.SGD(learning_rate=0.01)
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    compiled_model.compile(loss=loss_fn, optimizer=sgd, metrics=[accuracy])

    return learning.from_compiled_keras_model(compiled_model, sample_batch)
def model_fn():
    """Wrap the model from ``create_compiled_keras_model`` for federated learning.

    Returns:
      Whatever ``learning.from_compiled_keras_model`` returns for a freshly
      created compiled Keras model and the module-level ``sample_batch``.
    """
    # A new Keras model is created on every call.
    fresh_model = create_compiled_keras_model()
    return learning.from_compiled_keras_model(fresh_model, sample_batch)