def model_fn(features, labels, mode, params):
    """Build an EstimatorSpec that retrains a new dense head on a TF-Hub module.

    NOTE(review): a second ``model_fn`` is defined later in this file; if both
    are at module level the later definition shadows this one — confirm which
    one callers are meant to use.

    Args:
        features: Feature dict; ``features['inputs']`` is fed to the hub module.
        labels: Label tensor (unused by this function in PREDICT mode).
        mode: A ``tf.estimator.ModeKeys`` value.
        params: Dict with ``'label_vocab'`` (class vocabulary; its length is
            the number of classes), ``'train_module'`` (whether the hub
            module's weights are trainable when training) and
            ``'module_name'`` (name passed to ``hub.Module``).

    Returns:
        The ``tf.estimator.EstimatorSpec`` created by a binary head (two
        classes) or a multi-class head (otherwise). In TRAIN and EVAL modes a
        ``logger_hook`` is appended to the spec's hooks.
    """
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    label_vocab = params['label_vocab']
    num_classes = len(label_vocab)

    module = hub.Module(
        TFHUB_CACHE_DIR,
        trainable=is_training and params['train_module'],
        name=params['module_name'])
    bottleneck_tensor = module(features['inputs'])

    # The binary head expects one logit per example, while the multi-class
    # head expects one logit per class ([batch, n_classes]). A fixed
    # units=1 breaks the multi-class case.
    logits_dim = 1 if num_classes == 2 else num_classes

    with tf.name_scope('final_retrain_ops'):
        logits = tf.layers.dense(
            bottleneck_tensor, units=logits_dim, trainable=is_training)

    def train_op_fn(loss):
        """Minimize ``loss`` with Adam, advancing the global step."""
        optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
        return optimizer.minimize(loss, global_step=tf.train.get_global_step())

    if num_classes == 2:
        head = tf.contrib.estimator.binary_classification_head(
            label_vocabulary=label_vocab)
    else:
        head = tf.contrib.estimator.multi_class_head(
            n_classes=num_classes, label_vocabulary=label_vocab)

    spec = head.create_estimator_spec(
        features, mode, logits, labels, train_op_fn=train_op_fn)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Build the streaming accuracy once and share its update op between
        # the summary and the logging hook, instead of two separate metrics.
        train_accuracy = metrics_lib.accuracy(
            labels, spec.predictions['classes'])[1]
        tf.summary.scalar('accuracy', train_accuracy)
        logging_hook = logger_hook(
            {
                "loss": spec.loss,
                "accuracy": train_accuracy,
                "step": tf.train.get_or_create_global_step(),
                "steps_epoch": steps_epoch,
                "mode": "train",
            },
            every_n_iter=summary_interval)
        # Keep any hooks the head already attached rather than overwriting them.
        spec = spec._replace(
            training_hooks=list(spec.training_hooks or []) + [logging_hook])

    if mode == tf.estimator.ModeKeys.EVAL:
        # Reuse the head's own accuracy metric so the logged value matches
        # the reported eval metric.
        logging_hook = logger_hook(
            {
                "loss": spec.loss,
                "accuracy": spec.eval_metric_ops['accuracy'][1],
                "step": tf.train.get_or_create_global_step(),
                "steps_epoch": steps_epoch,
                "mode": "eval",
            },
            every_n_iter=summary_interval)
        spec = spec._replace(
            evaluation_hooks=list(spec.evaluation_hooks or []) + [logging_hook])

    return spec
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator.

    Args:
        features: Either the image tensor itself or a dict whose ``'image'``
            entry is the image tensor.
        labels: One-hot label tensor; it is used as ``onehot_labels`` for the
            softmax cross-entropy and arg-maxed along axis 1 for accuracy.
        mode: A ``tf.estimator.ModeKeys`` value.
        params: Dict with ``'data_format'``, forwarded to ``Model``.

    Returns:
        A ``tf.estimator.EstimatorSpec`` for PREDICT, TRAIN or EVAL mode.
        TRAIN and EVAL specs carry a ``logger_hook``.
    """
    model = Model(params['data_format'])
    image = features
    if isinstance(image, dict):
        image = features['image']

    if mode == tf.estimator.ModeKeys.PREDICT:
        logits = model(image, training=False)
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits),
        }
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
        logits = model(image, training=True)
        loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
        accuracy = tf.metrics.accuracy(
            labels=tf.argmax(labels, axis=1),
            predictions=tf.argmax(logits, axis=1))
        # Name the accuracy tensor 'train_accuracy' to demonstrate the
        # LoggingTensorHook.
        tf.identity(accuracy[1], name='train_accuracy')
        tf.summary.scalar('train_accuracy', accuracy[1])
        logging_hook = logger_hook(
            {
                "loss": loss,
                "accuracy": accuracy[1],
                "step": tf.train.get_or_create_global_step(),
                "steps_epoch": steps_epoch,
                "mode": "train",
            },
            every_n_iter=summary_interval)
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=loss,
            train_op=optimizer.minimize(
                loss, tf.train.get_or_create_global_step()),
            training_hooks=[logging_hook])

    if mode == tf.estimator.ModeKeys.EVAL:
        logits = model(image, training=False)
        loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
        accuracy = tf.metrics.accuracy(
            labels=tf.argmax(labels, axis=1),
            predictions=tf.argmax(logits, axis=1))
        logging_hook = logger_hook(
            {
                "loss": loss,
                "accuracy": accuracy[1],
                "step": tf.train.get_or_create_global_step(),
                "steps_epoch": steps_epoch,
                "mode": "eval",
            },
            every_n_iter=summary_interval)
        # Report the same metric the hook logs. A second tf.metrics.accuracy
        # here would create independent counters, and the hook's copy would
        # only advance on the steps where the hook fetches it.
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=loss,
            eval_metric_ops={'accuracy': accuracy},
            evaluation_hooks=[logging_hook])