Example #1
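A model_fn for MobileNet v1 on TPU, written against the TF 1.x Estimator API. This version reads its configuration from module-level FLAGS and constants, builds the network under a bfloat16 scope, and applies L2 weight decay by hand.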
def model_fn(features, labels, mode, params):
    """Mobilenet v1 model using Estimator API."""
    num_classes = FLAGS.num_classes
    training_active = (mode == tf.estimator.ModeKeys.TRAIN)
    eval_active = (mode == tf.estimator.ModeKeys.EVAL)

    features = tensor_transform_fn(features, params['input_perm'])

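    # Build the network in bfloat16 on TPU; logits and end points are cast
    # back to float32 below before computing the loss.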
    with bfloat16.bfloat16_scope():
        if FLAGS.clear_update_collections:
            # updates_collections must be set to None in order to use fused
            # batch norm, so clear the batch-norm update collection here.
            with arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(
                    batch_norm_updates_collections=None)):
                logits, end_points = mobilenet_v1.mobilenet_v1(
                    features,
                    num_classes,
                    is_training=training_active,
                    depth_multiplier=FLAGS.depth_multiplier)
        else:
            with arg_scope(mobilenet_v1.mobilenet_v1_arg_scope()):
                logits, end_points = mobilenet_v1.mobilenet_v1(
                    features,
                    num_classes,
                    is_training=training_active,
                    depth_multiplier=FLAGS.depth_multiplier)

        logits = tf.cast(logits, tf.float32)
        for k in end_points.keys():
            end_points[k] = tf.cast(end_points[k], tf.float32)

    predictions = {
        'classes': tf.argmax(input=logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

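    # When evaluating off-TPU, optionally print predictions and labels for
    # debugging; tf.Print passes its first argument through unchanged.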
    if mode == tf.estimator.ModeKeys.EVAL and FLAGS.display_tensors and (
            not FLAGS.use_tpu):
        with tf.control_dependencies([
                tf.Print(predictions['classes'], [predictions['classes']],
                         summarize=FLAGS.eval_batch_size,
                         message='prediction: ')
        ]):
            labels = tf.Print(labels, [labels],
                              summarize=FLAGS.eval_batch_size,
                              message='label: ')

    one_hot_labels = tf.one_hot(labels, FLAGS.num_classes, dtype=tf.int32)

    loss = tf.losses.softmax_cross_entropy(onehot_labels=one_hot_labels,
                                           logits=logits,
                                           weights=1.0,
                                           label_smoothing=0.1)
    # Apply weight decay manually: L2-regularize every trainable variable
    # except the batch normalization parameters.
    loss += WEIGHT_DECAY * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

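    # Linear scaling rule: scale the base learning rate with the global batch
    # size, relative to a reference batch size of 256.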
    initial_learning_rate = FLAGS.learning_rate * FLAGS.train_batch_size / 256
    final_learning_rate = 0.0001 * initial_learning_rate

    train_op = None
    if training_active:
        batches_per_epoch = _NUM_TRAIN_IMAGES // FLAGS.train_batch_size
        global_step = tf.train.get_or_create_global_step()

        learning_rate = tf.train.exponential_decay(
            learning_rate=initial_learning_rate,
            global_step=global_step,
            decay_steps=FLAGS.learning_rate_decay_epochs * batches_per_epoch,
            decay_rate=FLAGS.learning_rate_decay,
            staircase=True)

        # Set a minimum boundary for the learning rate.
        learning_rate = tf.maximum(learning_rate,
                                   final_learning_rate,
                                   name='learning_rate')

        if FLAGS.optimizer == 'sgd':
            tf.logging.info('Using SGD optimizer')
            optimizer = tf.train.GradientDescentOptimizer(
                learning_rate=learning_rate)
        elif FLAGS.optimizer == 'momentum':
            tf.logging.info('Using Momentum optimizer')
            optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                                   momentum=0.9)
        elif FLAGS.optimizer == 'RMS':
            tf.logging.info('Using RMS optimizer')
            optimizer = tf.train.RMSPropOptimizer(learning_rate,
                                                  RMSPROP_DECAY,
                                                  momentum=RMSPROP_MOMENTUM,
                                                  epsilon=RMSPROP_EPSILON)
        else:
            tf.logging.fatal('Unknown optimizer: %s', FLAGS.optimizer)

        if FLAGS.use_tpu:
            optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

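        # Run the batch-norm moving-average updates (UPDATE_OPS) as part of
        # the training step.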
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step=global_step)
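        # Optionally maintain an exponential moving average of the weights;
        # the EMA update then becomes the returned train op.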
        if FLAGS.moving_average:
            ema = tf.train.ExponentialMovingAverage(decay=MOVING_AVERAGE_DECAY,
                                                    num_updates=global_step)
            variables_to_average = (tf.trainable_variables() +
                                    tf.moving_average_variables())
            with tf.control_dependencies([train_op
                                          ]), tf.name_scope('moving_average'):
                train_op = ema.apply(variables_to_average)

    eval_metrics = None
    if eval_active:
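        # TPUEstimatorSpec takes eval_metrics as a (metric_fn, tensors) pair;
        # metric_fn runs on the host CPU over tensors gathered from the cores.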

        def metric_fn(labels, predictions):
            accuracy = tf.metrics.accuracy(
                labels, tf.argmax(input=predictions, axis=1))
            return {'accuracy': accuracy}

        if FLAGS.use_logits:
            eval_predictions = logits
        else:
            eval_predictions = end_points['Predictions']

        eval_metrics = (metric_fn, [labels, eval_predictions])

    return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op,
                                          eval_metrics=eval_metrics)
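
Both examples are excerpts from a larger module, so neither runs on its own. A minimal sketch of the module-level context Example #1 assumes might look like the following; the imports mirror the TF 1.x contrib layout, but every flag default, constant value, and the tensor_transform_fn body are illustrative assumptions, not values from the original (Example #2 expects the same transform helper via a supervised_images module):

import tensorflow as tf

from tensorflow.contrib.framework import arg_scope
from tensorflow.contrib.tpu.python.tpu import bfloat16
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
from nets import mobilenet_v1  # TF-Slim model zoo (tensorflow/models, research/slim)

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('num_classes', 1001, 'Number of label classes.')
tf.app.flags.DEFINE_float('learning_rate', 0.165, 'Base learning rate.')
tf.app.flags.DEFINE_integer('train_batch_size', 1024, 'Global train batch size.')
# ... the remaining FLAGS referenced by model_fn are defined the same way.

# Illustrative constants; the original module defines its own values.
WEIGHT_DECAY = 4e-5
RMSPROP_DECAY = 0.9          # decay term for RMSProp
RMSPROP_MOMENTUM = 0.9       # momentum term for RMSProp
RMSPROP_EPSILON = 1.0        # epsilon term for RMSProp
MOVING_AVERAGE_DECAY = 0.9999
_NUM_TRAIN_IMAGES = 1281167  # size of the ImageNet train split


def tensor_transform_fn(data, perm):
    """Hypothetical stand-in: applies the configured input permutation
    (e.g. a layout transpose for TPU infeed) when one is set."""
    if perm is not None:
        return tf.transpose(data, perm)
    return data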
Example #2
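The same model_fn, reworked to take its configuration from the params dict that TPUEstimator threads through to the model function (though it still reads FLAGS.display_tensors from module level). It also declares a serving export signature in PREDICT mode and collects regularization through tf.losses.get_total_loss rather than a manual L2 term.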
def model_fn(features, labels, mode, params):
    """Mobilenet v1 model using Estimator API."""
    num_classes = params['num_classes']
    training_active = (mode == tf.estimator.ModeKeys.TRAIN)
    eval_active = (mode == tf.estimator.ModeKeys.EVAL)

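    # Some input pipelines (e.g. a serving input receiver) wrap the image
    # tensor in a dict; unwrap it here.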
    if isinstance(features, dict):
        features = features['feature']

    features = supervised_images.tensor_transform_fn(features,
                                                     params['input_perm'])

    if params['clear_update_collections']:
        # updates_collections must be set to None in order to use fused
        # batch norm, so clear the batch-norm update collection here.
        with arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(
                batch_norm_updates_collections=None)):
            logits, end_points = mobilenet_v1.mobilenet_v1(
                features,
                num_classes,
                is_training=training_active,
                depth_multiplier=params['depth_multiplier'])
    else:
        with arg_scope(mobilenet_v1.mobilenet_v1_arg_scope()):
            logits, end_points = mobilenet_v1.mobilenet_v1(
                features,
                num_classes,
                is_training=training_active,
                depth_multiplier=params['depth_multiplier'])

    predictions = {
        'classes': tf.argmax(input=logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }

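    # Unlike Example #1, the PREDICT spec also declares a 'classify' export
    # signature so the trained model can be exported for serving.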
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    if mode == tf.estimator.ModeKeys.EVAL and FLAGS.display_tensors and (
            not params['use_tpu']):
        with tf.control_dependencies([
                tf.Print(predictions['classes'], [predictions['classes']],
                         summarize=params['eval_batch_size'],
                         message='prediction: ')
        ]):
            labels = tf.Print(labels, [labels],
                              summarize=params['eval_batch_size'],
                              message='label: ')

    one_hot_labels = tf.one_hot(labels, params['num_classes'], dtype=tf.int32)

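    # The cross-entropy is recorded in the tf.losses collection; its return
    # value is unused because get_total_loss() below sums the collection
    # together with the slim regularization (weight decay) losses.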
    tf.losses.softmax_cross_entropy(onehot_labels=one_hot_labels,
                                    logits=logits,
                                    weights=1.0,
                                    label_smoothing=0.1)
    loss = tf.losses.get_total_loss(add_regularization_losses=True)

    initial_learning_rate = params['learning_rate'] * params['train_batch_size'] / 256  # pylint: disable=line-too-long
    final_learning_rate = 0.0001 * initial_learning_rate

    train_op = None
    if training_active:
        batches_per_epoch = params['num_train_images'] // params[
            'train_batch_size']
        global_step = tf.train.get_or_create_global_step()

        learning_rate = tf.train.exponential_decay(
            learning_rate=initial_learning_rate,
            global_step=global_step,
            decay_steps=params['learning_rate_decay_epochs'] *
            batches_per_epoch,
            decay_rate=params['learning_rate_decay'],
            staircase=True)

        # Set a minimum boundary for the learning rate.
        learning_rate = tf.maximum(learning_rate,
                                   final_learning_rate,
                                   name='learning_rate')

        if params['optimizer'] == 'sgd':
            tf.logging.info('Using SGD optimizer')
            optimizer = tf.train.GradientDescentOptimizer(
                learning_rate=learning_rate)
        elif params['optimizer'] == 'momentum':
            tf.logging.info('Using Momentum optimizer')
            optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                                   momentum=0.9)
        elif params['optimizer'] == 'RMS':
            tf.logging.info('Using RMS optimizer')
            optimizer = tf.train.RMSPropOptimizer(learning_rate,
                                                  RMSPROP_DECAY,
                                                  momentum=RMSPROP_MOMENTUM,
                                                  epsilon=RMSPROP_EPSILON)
        else:
            tf.logging.fatal('Unknown optimizer: %s', params['optimizer'])

        if params['use_tpu']:
            optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step=global_step)
        if params['moving_average']:
            ema = tf.train.ExponentialMovingAverage(decay=MOVING_AVERAGE_DECAY,
                                                    num_updates=global_step)
            variables_to_average = (tf.trainable_variables() +
                                    tf.moving_average_variables())
            with tf.control_dependencies([train_op
                                          ]), tf.name_scope('moving_average'):
                train_op = ema.apply(variables_to_average)

    eval_metrics = None
    if eval_active:

        def metric_fn(labels, predictions):
            accuracy = tf.metrics.accuracy(
                labels, tf.argmax(input=predictions, axis=1))
            return {'accuracy': accuracy}

        if params['use_logits']:
            eval_predictions = logits
        else:
            eval_predictions = end_points['Predictions']

        eval_metrics = (metric_fn, [labels, eval_predictions])

    return tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                           loss=loss,
                                           train_op=train_op,
                                           eval_metrics=eval_metrics)
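
For context, here is a hedged sketch of how a model_fn like Example #2's is typically handed to the TF 1.x TPUEstimator. Every params value, the TPU worker address, the model_dir, and the dummy input_fn are placeholders, not taken from the original:

params = {
    'num_classes': 1001,
    'input_perm': None,
    'clear_update_collections': True,
    'depth_multiplier': 1.0,
    'use_tpu': True,
    'train_batch_size': 1024,
    'eval_batch_size': 1024,
    'num_train_images': 1281167,
    'learning_rate': 0.165,
    'learning_rate_decay': 0.94,
    'learning_rate_decay_epochs': 3,
    'optimizer': 'RMS',
    'moving_average': True,
    'use_logits': True,
}


def train_input_fn(params):
    # TPUEstimator injects params['batch_size']; a real pipeline would read
    # and preprocess ImageNet here (placeholder dataset shown).
    batch_size = params['batch_size']
    images = tf.zeros([batch_size, 224, 224, 3], tf.float32)
    labels = tf.zeros([batch_size], tf.int32)
    return tf.data.Dataset.from_tensors((images, labels)).repeat()


run_config = tf.contrib.tpu.RunConfig(
    master='grpc://10.0.0.2:8470',  # placeholder TPU worker address
    model_dir='/tmp/mobilenet_model',
    tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=100))

estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=model_fn,
    config=run_config,
    use_tpu=True,
    train_batch_size=params['train_batch_size'],
    eval_batch_size=params['eval_batch_size'],
    params=params)

estimator.train(input_fn=train_input_fn, max_steps=10000)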