Example #1
0
def model_fn(features, labels, mode, params):
    """Estimator ``model_fn``: TF-Hub module + dense layer + canned classification head.

    The TF-Hub module loaded from ``TFHUB_CACHE_DIR`` embeds
    ``features['inputs']``. A dense layer maps that embedding to logits, and a
    ``tf.contrib.estimator`` classification head turns the logits into
    loss, predictions, eval metrics and the train op.

    Args:
        features: dict of input tensors; ``features['inputs']`` is fed to the
            hub module.
        labels: label tensor (``None`` in PREDICT mode). Presumably holds raw
            values from ``params['label_vocab']`` since the head is built with
            ``label_vocabulary`` -- confirm against the input_fn.
        mode: a ``tf.estimator.ModeKeys`` value.
        params: dict with keys ``'label_vocab'`` (list of class labels; a
            length of 2 selects the binary head), ``'train_module'`` (whether
            the hub module's weights are trainable during TRAIN) and
            ``'module_name'`` (name given to the hub module).

    Returns:
        The ``tf.estimator.EstimatorSpec`` produced by the head. In TRAIN and
        EVAL modes a ``logger_hook`` reporting loss, accuracy and step is
        appended to the spec's training/evaluation hooks.
    """
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    label_vocab = params['label_vocab']
    num_classes = len(label_vocab)

    module = hub.Module(
        TFHUB_CACHE_DIR,
        trainable=is_training and params['train_module'],
        name=params['module_name'])
    bottleneck_tensor = module(features['inputs'])

    # Build the head before the logits layer: the layer width must equal
    # head.logits_dimension (1 for the binary head, num_classes for the
    # multi-class head). A hard-coded units=1 made multi_class_head reject
    # the logits whenever there were more than two classes.
    if num_classes == 2:
        head = tf.contrib.estimator.binary_classification_head(
            label_vocabulary=label_vocab)
    else:
        head = tf.contrib.estimator.multi_class_head(
            n_classes=num_classes, label_vocabulary=label_vocab)

    with tf.name_scope('final_retrain_ops'):
        logits = tf.layers.dense(
            bottleneck_tensor,
            units=head.logits_dimension,
            trainable=is_training)

    def train_op_fn(loss):
        """Minimise ``loss`` with Adam, incrementing the global step."""
        optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
        return optimizer.minimize(loss, global_step=tf.train.get_global_step())

    def _logging_hook(loss, accuracy, mode_name):
        """Build the progress-logging hook shared by TRAIN and EVAL modes."""
        return logger_hook(
            {
                "loss": loss,
                "accuracy": accuracy,
                "step": tf.train.get_or_create_global_step(),
                "steps_epoch": steps_epoch,
                "mode": mode_name,
            },
            every_n_iter=summary_interval)

    spec = head.create_estimator_spec(
        features, mode, logits, labels, train_op_fn=train_op_fn)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Create the streaming-accuracy metric once and share it between the
        # summary and the hook; calling metrics_lib.accuracy twice created
        # two independent sets of metric variables.
        train_accuracy = metrics_lib.accuracy(
            labels, spec.predictions['classes'])[1]
        tf.summary.scalar('accuracy', train_accuracy)
        hook = _logging_hook(spec.loss, train_accuracy, "train")
        # Append rather than overwrite so hooks supplied by the head survive.
        spec = spec._replace(
            training_hooks=list(spec.training_hooks) + [hook])
    elif mode == tf.estimator.ModeKeys.EVAL:
        # Log the head's own accuracy update op so the logged value tracks
        # the reported eval metric.
        hook = _logging_hook(
            spec.loss, spec.eval_metric_ops['accuracy'][1], "eval")
        spec = spec._replace(
            evaluation_hooks=list(spec.evaluation_hooks) + [hook])
    return spec
Example #2
0
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator.

    Args:
        features: either the image tensor itself, or a dict holding it under
            the ``'image'`` key.
        labels: one-hot label tensor (class ids are recovered with
            ``tf.argmax(labels, axis=1)``); not used in PREDICT mode.
        mode: a ``tf.estimator.ModeKeys`` value.
        params: dict; ``params['data_format']`` is passed to ``Model``.

    Returns:
        A ``tf.estimator.EstimatorSpec`` for PREDICT, TRAIN or EVAL mode, or
        ``None`` for any other ``mode`` value.
    """
    model = Model(params['data_format'])
    image = features['image'] if isinstance(features, dict) else features

    if mode == tf.estimator.ModeKeys.PREDICT:
        logits = model(image, training=False)
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits),
        }
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    def _loss_and_accuracy(training):
        """Run the model; return (softmax cross-entropy loss, accuracy metric tuple)."""
        logits = model(image, training=training)
        loss = tf.losses.softmax_cross_entropy(onehot_labels=labels,
                                               logits=logits)
        accuracy = tf.metrics.accuracy(labels=tf.argmax(labels, axis=1),
                                       predictions=tf.argmax(logits, axis=1))
        return loss, accuracy

    def _logging_hook(loss, accuracy, mode_name):
        """Build the progress-logging hook shared by TRAIN and EVAL modes."""
        return logger_hook(
            {
                "loss": loss,
                "accuracy": accuracy[1],
                "step": tf.train.get_or_create_global_step(),
                "steps_epoch": steps_epoch,
                "mode": mode_name
            },
            every_n_iter=summary_interval)

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
        loss, accuracy = _loss_and_accuracy(training=True)
        # Name the accuracy tensor 'train_accuracy' to demonstrate the
        # LoggingTensorHook.
        tf.identity(accuracy[1], name='train_accuracy')
        tf.summary.scalar('train_accuracy', accuracy[1])
        logging_hook = _logging_hook(loss, accuracy, "train")
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=loss,
            train_op=optimizer.minimize(
                loss, global_step=tf.train.get_or_create_global_step()),
            training_hooks=[logging_hook])

    if mode == tf.estimator.ModeKeys.EVAL:
        loss, accuracy = _loss_and_accuracy(training=False)
        logging_hook = _logging_hook(loss, accuracy, "eval")
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=loss,
            # Reuse the same metric the hook logs. Building a second
            # tf.metrics.accuracy here duplicated the metric variables, and
            # the logged value came from a copy that (presumably) only updated
            # on the steps the hook fetched it.
            eval_metric_ops={'accuracy': accuracy},
            evaluation_hooks=[logging_hook])