Example 1
def model_fn(features, labels, mode, params):
    output_size = params['output_size']
    net = features

    if FLAGS.data_type == 'float32':
        network = resnet_model.resnet_v1(resnet_layers,
                                         block_fn,
                                         num_classes=output_size,
                                         data_format='channels_last',
                                         filters=filters)

        net = network(inputs=features, is_training=True)
    else:
        with tf.variable_scope('cg', custom_getter=get_custom_getter()):
            network = resnet_model.resnet_v1(resnet_layers,
                                             block_fn,
                                             num_classes=output_size,
                                             data_format='channels_last',
                                             filters=filters)

            net = network(inputs=features, is_training=True)
            net = tf.cast(net, tf.float32)

    onehot_labels = tf.one_hot(labels, output_size)
    loss = tf.losses.softmax_cross_entropy(onehot_labels=onehot_labels,
                                           logits=net)

    learning_rate = tf.train.exponential_decay(0.1, tf.train.get_global_step(),
                                               25000, 0.97)
    if opt == 'sgd':
        tf.logging.info('Using SGD optimizer')
        optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=learning_rate)
    elif opt == 'momentum':
        tf.logging.info('Using Momentum optimizer')
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=0.9)
    elif opt == 'rms':
        tf.logging.info('Using RMS optimizer')
        optimizer = tf.train.RMSPropOptimizer(learning_rate,
                                              RMSPROP_DECAY,
                                              momentum=RMSPROP_MOMENTUM,
                                              epsilon=RMSPROP_EPSILON)
    if FLAGS.use_tpu:
        optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())

    param_stats = tf.profiler.profile(
        tf.get_default_graph(),
        options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
    fl_stats = tf.profiler.profile(
        tf.get_default_graph(),
        options=tf.profiler.ProfileOptionBuilder.float_operation())

    return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op)
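The 'cg' variable scope above depends on a get_custom_getter() helper that is not part of the snippet. A minimal sketch of such a getter, assuming the usual mixed-precision pattern on TPU (float32 master variables, bfloat16 compute); the dtypes handled by the original helper may differ:

import tensorflow as tf


def get_custom_getter():
    """Sketch: create variables in float32 but hand bfloat16-cast copies to
    the graph, so convolutions/matmuls run in reduced precision while the
    master weights stay in full precision."""
    def inner_custom_getter(getter, *args, **kwargs):
        cast_to_bfloat16 = False
        if kwargs.get('dtype') == tf.bfloat16:
            kwargs['dtype'] = tf.float32  # keep the stored variable in float32
            cast_to_bfloat16 = True
        var = getter(*args, **kwargs)
        if cast_to_bfloat16:
            var = tf.cast(var, tf.bfloat16)
        return var
    return inner_custom_getter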
Example 2
 def build_network():
     network = resnet_model.resnet_v1(resnet_depth=50,
                                      num_classes=1000,
                                      dropblock_size=None,
                                      dropblock_keep_probs=[None] * 4,
                                      data_format='channels_last')
     return network(inputs=images, is_training=False)
Example 3
 def build_network():
   network = resnet_model.resnet_v1(
       resnet_depth=FLAGS.resnet_depth,
       num_classes=FLAGS.num_label_classes,
       data_format=FLAGS.data_format)
   return network(
       inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))
Example 4
 def build_network(l_features):
   network = resnet_model.resnet_v1(
       resnet_depth=FLAGS.resnet_depth,
       num_classes=FLAGS.num_label_classes,
       dropblock_size=FLAGS.dropblock_size,
       dropblock_keep_probs=dropblock_keep_probs,
       data_format=FLAGS.data_format)
   return network(inputs=l_features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))
Example 5
    def test_load_resnet18_v1(self):
        network = resnet_model.resnet_v1(resnet_depth=18,
                                         num_classes=10,
                                         data_format='channels_last')
        input_bhw3 = tf.placeholder(tf.float32, [1, 28, 28, 3])
        resnet_output = network(inputs=input_bhw3, is_training=True)

        sess = tf.Session()
        sess.run(tf.global_variables_initializer())
        _ = sess.run(resnet_output,
                     feed_dict={input_bhw3: np.random.randn(1, 28, 28, 3)})
Example 6
 def create_model():
     """Create the model and compute the logits."""
     if FLAGS.use_keras_model:
         model = tf.keras.applications.resnet50.ResNet50(
             include_top=True,
             weights=None,
             input_tensor=None,
             input_shape=None,
             pooling=None,
             classes=_NUM_CLASSES)
         return model(features, training=is_training)
     else:
         model = resnet_model.resnet_v1(resnet_depth=_RESNET_DEPTH,
                                        num_classes=_NUM_CLASSES,
                                        data_format='channels_last')
         return model(inputs=features, is_training=is_training)
Example 7
def build_network(features, mode, params):
    """ Build ResNet50 Model

    Args:
        features:
        mode:
        params:

    Returns:
        Model function
    """
    network = resnet_v1(
        resnet_depth=50,
        num_classes=params["classes"],
        data_format=params["data_format"],
    )
    return network(inputs=features,
                   is_training=(mode == tf.estimator.ModeKeys.TRAIN))
Example 8
def resnet_model_fn(features, labels, mode, params):
    """The model_fn for ResNet to be used with TPUEstimator.

  Args:
    features: `Tensor` of batched images.
    labels: `Tensor` of labels for the data samples
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL}`
    params: `dict` of parameters passed to the model from the TPUEstimator,
        `params['batch_size']` is always provided and should be used as the
        effective batch size.

  Returns:
    A `TPUEstimatorSpec` for the model
  """
    if isinstance(features, dict):
        features = features['feature']

    # In most cases, the default data format NCHW instead of NHWC should be
    # used for a significant performance boost on GPU/TPU. NHWC should be used
    # only if the network needs to be run on CPU since the pooling operations
    # are only supported on NHWC.
    if FLAGS.data_format == 'channels_first':
        features = tf.transpose(features, [0, 3, 1, 2])

    network = resnet_model.resnet_v1(resnet_depth=FLAGS.resnet_depth,
                                     num_classes=LABEL_CLASSES,
                                     data_format=FLAGS.data_format)

    logits = network(inputs=features,
                     is_training=(mode == tf.estimator.ModeKeys.TRAIN))

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    # If necessary, in the model_fn, use params['batch_size'] instead of the batch
    # size flags (--train_batch_size or --eval_batch_size).
    batch_size = params['batch_size']  # pylint: disable=unused-variable

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    one_hot_labels = tf.one_hot(labels, LABEL_CLASSES)
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits, onehot_labels=one_hot_labels)

    # Add weight decay to the loss for non-batch-normalization variables.
    loss = cross_entropy + WEIGHT_DECAY * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

    host_call = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        # Compute the current epoch and associated learning rate from global_step.
        global_step = tf.train.get_global_step()
        batches_per_epoch = NUM_TRAIN_IMAGES / FLAGS.train_batch_size
        current_epoch = (tf.cast(global_step, tf.float32) / batches_per_epoch)
        learning_rate = learning_rate_schedule(current_epoch)

        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=MOMENTUM,
                                               use_nesterov=True)
        if FLAGS.use_tpu:
            # When using TPU, wrap the optimizer with CrossShardOptimizer which
            # handles synchronization details between different TPU cores. To the
            # user, this should look like regular synchronous training.
            optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

        # Batch normalization requires UPDATE_OPS to be added as a dependency to
        # the train operation.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)

        if not FLAGS.skip_host_call:

            def host_call_fn(gs, loss, lr, ce):
                """Training host call. Creates scalar summaries for training metrics.

        This function is executed on the CPU and should not directly reference
        any Tensors in the rest of the `model_fn`. To pass Tensors from the
        model to the `metric_fn`, provide as part of the `host_call`. See
        https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
        for more information.

        Arguments should match the list of `Tensor` objects passed as the second
        element in the tuple passed to `host_call`.

        Args:
          gs: `Tensor` with shape `[batch]` for the global_step.
          loss: `Tensor` with shape `[batch]` for the training loss.
          lr: `Tensor` with shape `[batch]` for the learning_rate.
          ce: `Tensor` with shape `[batch]` for the current_epoch.

        Returns:
          List of summary ops to run on the CPU host.
        """
                gs = gs[0]
                with summary.create_file_writer(FLAGS.model_dir).as_default():
                    with summary.always_record_summaries():
                        summary.scalar('loss', loss[0], step=gs)
                        summary.scalar('learning_rate', lr[0], step=gs)
                        summary.scalar('current_epoch', ce[0], step=gs)

                        return summary.all_summary_ops()

            # To log the loss, current learning rate, and epoch for Tensorboard, the
            # summary op needs to be run on the host CPU via host_call. host_call
            # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
            # dimension. These Tensors are implicitly concatenated to
            # [params['batch_size']].
            gs_t = tf.reshape(global_step, [1])
            loss_t = tf.reshape(loss, [1])
            lr_t = tf.reshape(learning_rate, [1])
            ce_t = tf.reshape(current_epoch, [1])

            host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t])

    else:
        train_op = None

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(labels, logits):
            """Evaluation metric function. Evaluates accuracy.

      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the model
      to the `metric_fn`, provide as part of the `eval_metrics`. See
      https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `eval_metrics`.

      Args:
        labels: `Tensor` with shape `[batch]`.
        logits: `Tensor` with shape `[batch, num_classes]`.

      Returns:
        A dict of the metrics to return from evaluation.
      """
            predictions = tf.argmax(logits, axis=1)
            top_1_accuracy = tf.metrics.accuracy(labels, predictions)
            in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
            top_5_accuracy = tf.metrics.mean(in_top_5)

            return {
                'top_1_accuracy': top_1_accuracy,
                'top_5_accuracy': top_5_accuracy,
            }

        eval_metrics = (metric_fn, [labels, logits])

    return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op,
                                          host_call=host_call,
                                          eval_metrics=eval_metrics)
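learning_rate_schedule(current_epoch) above is defined outside the snippet. A sketch of the piecewise schedule typically paired with this kind of TPU ResNet model_fn: linear warm-up followed by step decays. The base rate, warm-up length, and epoch boundaries below are illustrative assumptions; FLAGS is the same flags object used above:

# (multiplier, epoch at which it takes effect); illustrative values.
LR_SCHEDULE = [(1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)]


def learning_rate_schedule(current_epoch):
    """Warm up linearly for the first LR_SCHEDULE[0][1] epochs, then decay."""
    scaled_lr = 0.1 * (FLAGS.train_batch_size / 256.0)
    decay_rate = (scaled_lr * LR_SCHEDULE[0][0] *
                  current_epoch / LR_SCHEDULE[0][1])
    for mult, start_epoch in LR_SCHEDULE:
        decay_rate = tf.where(current_epoch < start_epoch,
                              decay_rate, scaled_lr * mult)
    return decay_rate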
Example 9
def resnet_model_fn(features, labels, mode, params):
    """The model_fn for ResNet to be used with TPUEstimator.

  Args:
    features: `Tensor` of batched images.
    labels: `Tensor` of labels for the data samples
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL}`
    params: `dict` of parameters passed to the model from the TPUEstimator,
        `params['batch_size']` is always provided and should be used as the
        effective batch size.

  Returns:
    A `TPUEstimatorSpec` for the model
  """
    if isinstance(features, dict):
        features = features['feature']

    if FLAGS.data_format == 'channels_first':
        assert not FLAGS.transpose_input  # channels_first only for GPU
        features = tf.transpose(features, [0, 3, 1, 2])

    if FLAGS.transpose_input:
        features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

    if FLAGS.use_tpu:
        import bfloat16
        scope_fn = lambda: bfloat16.bfloat16_scope()
    else:
        scope_fn = lambda: tf.variable_scope("")

    with scope_fn():
        resnet_size = int(FLAGS.resnet_depth.split("_")[-1])
        if FLAGS.resnet_depth.startswith("v1_"):
            print("\n\n\n\n\nUSING RESNET V1 {}\n\n\n\n\n".format(
                FLAGS.resnet_depth))
            network = resnet_model.resnet_v1(resnet_depth=int(resnet_size),
                                             num_classes=LABEL_CLASSES,
                                             attention=None,
                                             apply_to="outputs",
                                             use_tpu=FLAGS.use_tpu,
                                             data_format=FLAGS.data_format)
        elif FLAGS.resnet_depth.startswith("paper-v1_"):
            print("\n\n\n\n\nUSING RESNET V1 (Paper) {}\n\n\n\n\n".format(
                resnet_size))
            network = resnet_model.resnet_v1(resnet_depth=int(resnet_size),
                                             num_classes=LABEL_CLASSES,
                                             attention="paper",
                                             apply_to="outputs",
                                             use_tpu=FLAGS.use_tpu,
                                             data_format=FLAGS.data_format)
        elif FLAGS.resnet_depth.startswith("fc-v1_"):
            print("\n\n\n\n\nUSING RESNET V1 (fc) {}\n\n\n\n\n".format(
                resnet_size))
            network = resnet_model.resnet_v1(resnet_depth=int(resnet_size),
                                             num_classes=LABEL_CLASSES,
                                             attention="fc",
                                             apply_to="outputs",
                                             use_tpu=FLAGS.use_tpu,
                                             data_format=FLAGS.data_format)
        elif FLAGS.resnet_depth.startswith("v2_"):
            print("\n\n\n\n\nUSING RESNET V2 {}\n\n\n\n\n".format(resnet_size))
            network = resnet_v2_model.resnet_v2(resnet_size=resnet_size,
                                                num_classes=LABEL_CLASSES,
                                                feature_attention=False,
                                                extra_convs=0,
                                                data_format=FLAGS.data_format,
                                                use_tpu=FLAGS.use_tpu)
        elif FLAGS.resnet_depth.startswith("paper-v2_"):
            print("\n\n\n\n\nUSING RESNET V2 (Paper) {}\n\n\n\n\n".format(
                resnet_size))
            network = resnet_v2_model.resnet_v2(resnet_size=resnet_size,
                                                num_classes=LABEL_CLASSES,
                                                feature_attention="paper",
                                                extra_convs=0,
                                                apply_to="output",
                                                data_format=FLAGS.data_format,
                                                use_tpu=FLAGS.use_tpu)
        elif FLAGS.resnet_depth.startswith("fc-v2_"):
            print("\n\n\n\n\nUSING RESNET V2 (fc) {}\n\n\n\n\n".format(
                resnet_size))
            network = resnet_v2_model.resnet_v2(resnet_size=resnet_size,
                                                num_classes=LABEL_CLASSES,
                                                feature_attention="fc",
                                                extra_convs=1,
                                                data_format=FLAGS.data_format,
                                                use_tpu=FLAGS.use_tpu)
        else:
            assert False

        logits = network(inputs=features,
                         is_training=(mode == tf.estimator.ModeKeys.TRAIN))
        logits = tf.cast(logits, tf.float32)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    # If necessary, in the model_fn, use params['batch_size'] instead of the batch
    # size flags (--train_batch_size or --eval_batch_size).
    batch_size = params['batch_size']  # pylint: disable=unused-variable

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    one_hot_labels = tf.one_hot(labels, LABEL_CLASSES)
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits, onehot_labels=one_hot_labels)

    # Add weight decay to the loss for non-batch-normalization variables.
    loss = cross_entropy + WEIGHT_DECAY * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

    # with tf.device("/cpu:0"):
    #   loss = tf.Print(loss, [loss], "loss", summarize=20)

    host_call = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        # Compute the current epoch and associated learning rate from global_step.
        global_step = tf.train.get_global_step()
        batches_per_epoch = NUM_TRAIN_IMAGES / FLAGS.train_batch_size
        current_epoch = (tf.cast(global_step, tf.float32) / batches_per_epoch)
        learning_rate = learning_rate_schedule(current_epoch)

        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=MOMENTUM,
                                               use_nesterov=True)
        if FLAGS.use_tpu:
            # When using TPU, wrap the optimizer with CrossShardOptimizer which
            # handles synchronization details between different TPU cores. To the
            # user, this should look like regular synchronous training.
            optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

        # Batch normalization requires UPDATE_OPS to be added as a dependency to
        # the train operation.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):

            if FLAGS.clip_gradients == 0:
                print("\nnot clipping gradients\n")
                train_op = optimizer.minimize(loss, global_step)
            else:
                print("\nclipping gradients\n")
                gradients, variables = zip(*optimizer.compute_gradients(loss))
                gradients, _ = tf.clip_by_global_norm(gradients,
                                                      FLAGS.clip_gradients)
                train_op = optimizer.apply_gradients(zip(gradients, variables),
                                                     global_step=global_step)

            # gvs = optimizer.compute_gradients(loss)
            # gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
            # capped_gvs = [(tf.clip_by_value(grad, -10., 10.), var) for grad, var in gvs]
            # train_op = optimizer.apply_gradients(capped_gvs, global_step=global_step)

        if not FLAGS.skip_host_call:

            def host_call_fn(gs, loss, lr, ce):
                """Training host call. Creates scalar summaries for training metrics.

        This function is executed on the CPU and should not directly reference
        any Tensors in the rest of the `model_fn`. To pass Tensors from the
        model to the `metric_fn`, provide as part of the `host_call`. See
        https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
        for more information.

        Arguments should match the list of `Tensor` objects passed as the second
        element in the tuple passed to `host_call`.

        Args:
          gs: `Tensor` with shape `[batch]` for the global_step.
          loss: `Tensor` with shape `[batch]` for the training loss.
          lr: `Tensor` with shape `[batch]` for the learning_rate.
          ce: `Tensor` with shape `[batch]` for the current_epoch.

        Returns:
          List of summary ops to run on the CPU host.
        """
                gs = gs[0]
                with summary.create_file_writer(FLAGS.model_dir).as_default():
                    with summary.always_record_summaries():
                        summary.scalar('loss', loss[0], step=gs)
                        summary.scalar('learning_rate', lr[0], step=gs)
                        summary.scalar('current_epoch', ce[0], step=gs)

                        return summary.all_summary_ops()

            # To log the loss, current learning rate, and epoch for Tensorboard, the
            # summary op needs to be run on the host CPU via host_call. host_call
            # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
            # dimension. These Tensors are implicitly concatenated to
            # [params['batch_size']].
            gs_t = tf.reshape(global_step, [1])
            loss_t = tf.reshape(loss, [1])
            lr_t = tf.reshape(learning_rate, [1])
            ce_t = tf.reshape(current_epoch, [1])

            host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t])

    else:
        train_op = None

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(labels, logits):
            """Evaluation metric function. Evaluates accuracy.

      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the model
      to the `metric_fn`, provide as part of the `eval_metrics`. See
      https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `eval_metrics`.

      Args:
        labels: `Tensor` with shape `[batch]`.
        logits: `Tensor` with shape `[batch, num_classes]`.

      Returns:
        A dict of the metrics to return from evaluation.
      """
            predictions = tf.argmax(logits, axis=1)
            top_1_accuracy = tf.metrics.accuracy(labels, predictions)
            in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
            top_5_accuracy = tf.metrics.mean(in_top_5)

            return {
                'top_1_accuracy': top_1_accuracy,
                'top_5_accuracy': top_5_accuracy,
            }

        eval_metrics = (metric_fn, [labels, logits])

    # logging_hook = tf.train.LoggingTensorHook(
    #   {"logging_hook_loss": loss}, every_n_iter=1)

    return tpu_estimator.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        host_call=host_call,
        eval_metrics=eval_metrics,
        # training_hooks=[logging_hook]
    )
Example 10
def model_fn(features, labels, mode):
  """Definition for ResNet model."""
  is_training = mode == tf.estimator.ModeKeys.TRAIN

  features = tf.transpose(features, [3, 0, 1, 2])  # Double-transpose trick

  # Normalize the image to zero mean and unit variance.
  features -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=features.dtype)
  features /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=features.dtype)

  with tf.contrib.tpu.bfloat16_scope():
    network = resnet_model.resnet_v1(
        resnet_depth=_RESNET_DEPTH,
        num_classes=_NUM_CLASSES,
        data_format='channels_last')
    logits = network(inputs=features, is_training=is_training)
  logits = tf.cast(logits, tf.float32)

  if mode == tf.estimator.ModeKeys.PREDICT:
    assert False, 'Not implemented correctly right now!'
    predictions = {'logits': logits}
    return tf.estimator.EstimatorSpec(mode, predictions=predictions)

  cross_entropy = tf.losses.sparse_softmax_cross_entropy(
      labels=labels, logits=logits)

  loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
      [tf.nn.l2_loss(v) for v in tf.trainable_variables()
       if 'batch_normalization' not in v.name])

  if mode == tf.estimator.ModeKeys.EVAL:
    predictions = tf.argmax(logits, axis=1)
    top_1_accuracy = tf.metrics.accuracy(labels, predictions)
    # TODO(priyag): Add this back when in_top_k is supported on TPU.
    # in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
    # top_5_accuracy = tf.metrics.mean(in_top_5)

    eval_metric_ops = {
        'top_1_accuracy': top_1_accuracy,
        # 'top_5_accuracy': top_5_accuracy,
    }

    return tf.estimator.EstimatorSpec(
        mode, loss=loss, eval_metric_ops=eval_metric_ops)

  assert mode == tf.estimator.ModeKeys.TRAIN

  global_step = tf.train.get_or_create_global_step()
  batches_per_epoch = (_NUM_TRAIN_IMAGES /
                       (FLAGS.train_batch_size * FLAGS.num_cores))
  current_epoch = (tf.cast(global_step, tf.float32) / batches_per_epoch)
  learning_rate = learning_rate_schedule(current_epoch)

  optimizer = tf.train.MomentumOptimizer(
      learning_rate=learning_rate,
      momentum=_MOMENTUM,
      use_nesterov=True)

  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss, global_step=global_step)
  return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
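The '# Double-transpose trick' comment refers to emitting batches in HWCN layout from the input pipeline and restoring NHWC at the top of model_fn, which tends to speed up host-to-TPU transfer. A sketch of the input-side half, assuming a tf.data pipeline that batches NHWC images; the function name is illustrative:

def transpose_for_tpu(images, labels):
    # NHWC -> HWCN; model_fn above undoes this with tf.transpose(features, [3, 0, 1, 2]).
    return tf.transpose(images, [1, 2, 3, 0]), labels

# Applied once per batch, e.g.: dataset = dataset.batch(batch_size).map(transpose_for_tpu)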
Example 11
 def resnet_network():
     network = resnet_model.resnet_v1(resnet_depth=FLAGS.resnet_depth,
                                      data_format=FLAGS.data_format)
     return network(inputs=feature_image,
                    is_training=(mode == tf.estimator.ModeKeys.TRAIN))
Example 12
def build_network(features, mode, params):
    network = resnet_v1(resnet_depth=50,
                        num_classes=params["classes"],
                        data_format="channels_first")
    return network(inputs=features,
                   is_training=(mode == tf.estimator.ModeKeys.TRAIN))
Example 13
if load_checkpoint:
  model = load_model(filepath)
else:
  # build the graph
  if version == 2:
      model = resnet_v2(input_shape=input_shape,
                        depth=depth,
                        activation_bits=activation_bits,
                        weight_noise=weight_noise,
                        trainable_conv=not finetune,
                        trainable_dense=True)
  else:
      model = resnet_v1(input_shape=input_shape,
                        depth=depth,
                        activation_bits=activation_bits,
                        weight_noise=weight_noise,
                        trainable_conv=not finetune,
                        trainable_dense=True)

  model.compile(loss='categorical_crossentropy',
                optimizer=Adam(learning_rate=lr_schedule(0)),
                metrics=['accuracy'])

model.summary()
print(model_type)

if finetune:
    weights_path = os.path.join(os.getcwd(), finetune_ckpt_path)
    latest = tf.train.latest_checkpoint(weights_path)
    print(latest)
    model.load_weights(latest)
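lr_schedule(0) above supplies the initial learning rate for Adam; the schedule itself is defined elsewhere. A sketch in the style of the Keras CIFAR-10 ResNet example, with illustrative boundaries; in that setup the same function is also registered as a tf.keras.callbacks.LearningRateScheduler callback so the rate actually decays during training:

def lr_schedule(epoch):
    """Step decay: start at 1e-3 and shrink at fixed epoch boundaries."""
    lr = 1e-3
    if epoch > 180:
        lr *= 0.5e-3
    elif epoch > 160:
        lr *= 1e-3
    elif epoch > 120:
        lr *= 1e-2
    elif epoch > 80:
        lr *= 1e-1
    return lr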
Example 14
def resnet_model_fn(features, labels, mode, params):
    """The model_fn for ResNet to be used with TPUEstimator.

  Args:
    features: `Tensor` of batched images.
    labels: `Tensor` of labels for the data samples
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL}`
    params: `dict` of parameters passed to the model from the TPUEstimator,
        `params['batch_size']` is always provided and should be used as the
        effective batch size.

  Returns:
    A `TPUEstimatorSpec` for the model
  """
    if isinstance(features, dict):
        images = features['image']
        hms = features['hm']
        bboxs = features['bbox']
        ccount = features['ccount']
    else:
        images = features
        hms = None
        bboxs = None

    if FLAGS.data_format == 'channels_first':
        assert not FLAGS.transpose_input  # channels_first only for GPU
        images = tf.transpose(images, [0, 3, 1, 2])
        if hms is not None:
            hms = tf.transpose(hms, [0, 3, 1, 2])
            bboxs = tf.transpose(bboxs, [0, 3, 1, 2])

    if FLAGS.transpose_input:
        images = tf.transpose(images, [3, 0, 1, 2])  # HWCN to NHWC
        if hms is not None:
            hms = tf.transpose(hms, [3, 0, 1, 2])
            bboxs = tf.transpose(bboxs, [3, 0, 1, 2])

    if FLAGS.use_tpu:
        import bfloat16
        scope_fn = lambda: bfloat16.bfloat16_scope()
    else:
        scope_fn = lambda: tf.variable_scope("")

    with scope_fn():
        resnet_size = int(FLAGS.resnet_depth.split("_")[-1])
        if FLAGS.resnet_depth.startswith("v1_"):
            print("\n\n\n\n\nUSING RESNET V1 {}\n\n\n\n\n".format(
                FLAGS.resnet_depth))
            network = resnet_model.resnet_v1(resnet_depth=int(resnet_size),
                                             num_classes=LABEL_CLASSES,
                                             attention=None,
                                             apply_to="outputs",
                                             use_tpu=FLAGS.use_tpu,
                                             data_format=FLAGS.data_format)
        elif FLAGS.resnet_depth.startswith("SE-v1_"):
            print(
                "\n\n\n\n\nUSING RESNET V1 (Squeeze-and-excite) {}\n\n\n\n\n".
                format(resnet_size))
            network = resnet_model.resnet_v1(resnet_depth=int(resnet_size),
                                             num_classes=LABEL_CLASSES,
                                             attention="se",
                                             apply_to="outputs",
                                             use_tpu=FLAGS.use_tpu,
                                             data_format=FLAGS.data_format)
        elif FLAGS.resnet_depth.startswith("GALA-v1_"):
            print("\n\n\n\n\nUSING RESNET V1 (GALA) {}\n\n\n\n\n".format(
                resnet_size))
            network = resnet_model.resnet_v1(resnet_depth=int(resnet_size),
                                             num_classes=LABEL_CLASSES,
                                             attention="gala",
                                             apply_to="outputs",
                                             use_tpu=FLAGS.use_tpu,
                                             data_format=FLAGS.data_format)
        elif FLAGS.resnet_depth.startswith("v2_"):
            print("\n\n\n\n\nUSING RESNET V2 {}\n\n\n\n\n".format(resnet_size))
            network = resnet_v2_model.resnet_v2(resnet_size=resnet_size,
                                                num_classes=LABEL_CLASSES,
                                                feature_attention=False,
                                                extra_convs=0,
                                                data_format=FLAGS.data_format,
                                                use_tpu=FLAGS.use_tpu)
        elif FLAGS.resnet_depth.startswith("SE-v2_"):
            print(
                "\n\n\n\n\nUSING RESNET V2 (Squeeze-and-excite) {}\n\n\n\n\n".
                format(resnet_size))
            network = resnet_v2_model.resnet_v2(resnet_size=resnet_size,
                                                num_classes=LABEL_CLASSES,
                                                feature_attention="se",
                                                extra_convs=0,
                                                apply_to="output",
                                                data_format=FLAGS.data_format,
                                                use_tpu=FLAGS.use_tpu)
        elif FLAGS.resnet_depth.startswith("GALA-v2_"):
            print("\n\n\n\n\nUSING RESNET V2 (GALA) {}\n\n\n\n\n".format(
                resnet_size))
            network = resnet_v2_model.resnet_v2(resnet_size=resnet_size,
                                                num_classes=LABEL_CLASSES,
                                                feature_attention="gala",
                                                extra_convs=1,
                                                data_format=FLAGS.data_format,
                                                use_tpu=FLAGS.use_tpu)
        else:
            assert False

        logits, attention = network(
            inputs=images, is_training=(mode == tf.estimator.ModeKeys.TRAIN))
        logits = tf.cast(logits, tf.float32)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })
    batch_size = params['batch_size']

    # Calculate softmax cross entropy and L2 regularization.
    one_hot_labels = tf.one_hot(labels, LABEL_CLASSES)
    loss = tf.losses.softmax_cross_entropy(logits=logits,
                                           onehot_labels=one_hot_labels)

    # Switch hms/bboxs
    if FLAGS.annotation == 'hms':
        pass
    elif FLAGS.annotation == 'bboxs':
        hms = bboxs
    elif FLAGS.annotation == 'none':
        hms = None
    else:
        raise NotImplementedError(FLAGS.annotation)

    # Add attention losses
    if hms is not None:
        map_loss_list = []
        blur_click_maps = 49  # 0 = no, > 0 blur kernel
        blur_click_maps_sigma = 28  # 14

        # Blur the heatmaps
        hms = blur(hms,
                   kernel=blur_click_maps,
                   sigma=blur_click_maps_sigma,
                   dtype=images.dtype)

        mask = tf.cast(tf.greater(ccount, 0), tf.float32)
        mask = tf.reshape(mask, [int(hms.get_shape()[0]), 1, 1, 1])
        for layer in attention:
            layer_shape = [int(x) for x in layer.get_shape()[1:3]]
            layer = tf.cast(layer, tf.float32)
            hms = tf.cast(hms, tf.float32)
            resized_maps = tf.image.resize_bilinear(hms,
                                                    layer_shape,
                                                    align_corners=True)
            if layer.get_shape().as_list()[-1] > 1:
                layer = tf.reduce_mean(tf.pow(layer, 2),
                                       axis=-1,
                                       keep_dims=True)
            resized_maps = l2_channel_norm(resized_maps)
            layer = l2_channel_norm(layer)
            dist = resized_maps - layer
            d = tf.nn.l2_loss(dist * mask)
            map_loss_list += [d]
        denominator = len(attention)
        if len(map_loss_list):
            denominator = len(attention)
            map_loss = (tf.add_n(map_loss_list) / float(denominator)) * 1e-5
            loss += map_loss
        else:
            assert not (FLAGS.resnet_depth.startswith("GALA") or
                        FLAGS.resnet_depth.startswith("SE")
                        ), "Failed to apply attention."
    loss += (WEIGHT_DECAY * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name and 'ATTENTION' not in v.name
        and 'block' not in v.name and 'training' not in v.name
    ]))
    host_call = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        # Compute the current epoch and associated learning rate from global_step.
        global_step = tf.train.get_global_step()
        batches_per_epoch = NUM_TRAIN_IMAGES / FLAGS.train_batch_size
        current_epoch = (tf.cast(global_step, tf.float32) / batches_per_epoch)
        learning_rate = learning_rate_schedule(current_epoch)
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=MOMENTUM,
                                               use_nesterov=True)
        if FLAGS.use_tpu:
            # When using TPU, wrap the optimizer with CrossShardOptimizer which
            # handles synchronization details between different TPU cores. To the
            # user, this should look like regular synchronous training.
            optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

        # Batch normalization requires UPDATE_OPS to be added as a dependency to
        # the train operation.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            if FLAGS.clip_gradients == 0:
                print("\nnot clipping gradients\n")
                train_op = optimizer.minimize(loss, global_step)
            else:
                print("\nclipping gradients\n")
                gradients, variables = zip(*optimizer.compute_gradients(loss))
                gradients, _ = tf.clip_by_global_norm(gradients,
                                                      FLAGS.clip_gradients)
                train_op = optimizer.apply_gradients(zip(gradients, variables),
                                                     global_step=global_step)

        if not FLAGS.skip_host_call:

            def host_call_fn(gs, loss, lr, ce):  # , hm=None, image=None):
                """Training host call. Creates scalar summaries for training metrics.

        This function is executed on the CPU and should not directly reference
        any Tensors in the rest of the `model_fn`. To pass Tensors from the
        model to the `metric_fn`, provide as part of the `host_call`. See
        https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
        for more information.

        Arguments should match the list of `Tensor` objects passed as the second
        element in the tuple passed to `host_call`.

        Args:
          gs: `Tensor` with shape `[batch]` for the global_step.
          loss: `Tensor` with shape `[batch]` for the training loss.
          lr: `Tensor` with shape `[batch]` for the learning_rate.
          ce: `Tensor` with shape `[batch]` for the current_epoch.

        Returns:
          List of summary ops to run on the CPU host.
        """
                gs = gs[0]
                with summary.create_file_writer(FLAGS.model_dir).as_default():
                    with summary.always_record_summaries():
                        summary.scalar('loss', loss[0], step=gs)
                        summary.scalar('learning_rate', lr[0], step=gs)
                        summary.scalar('current_epoch', ce[0], step=gs)
                        # summary.image('image', hm, step=gs)
                        # summary.image('heatmap', image, step=gs)
                        return summary.all_summary_ops()

            # To log the loss, current learning rate, and epoch for Tensorboard, the
            # summary op needs to be run on the host CPU via host_call. host_call
            # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
            # dimension. These Tensors are implicitly concatenated to
            # [params['batch_size']].
            gs_t = tf.reshape(global_step, [1])
            loss_t = tf.reshape(loss, [1])
            lr_t = tf.reshape(learning_rate, [1])
            ce_t = tf.reshape(current_epoch, [1])
            # im_t = tf.cast(images, tf.float32)
            # hm_t = tf.cast(hms, tf.float32)
            host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t])
            # host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t, im_t, hm_t])

    else:
        train_op = None

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(labels, logits):
            """Evaluation metric function. Evaluates accuracy.

      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the model
      to the `metric_fn`, provide as part of the `eval_metrics`. See
      https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `eval_metrics`.

      Args:
        labels: `Tensor` with shape `[batch]`.
        logits: `Tensor` with shape `[batch, num_classes]`.

      Returns:
        A dict of the metrics to return from evaluation.
      """
            predictions = tf.argmax(logits, axis=1)
            top_1_accuracy = tf.metrics.accuracy(labels, predictions)
            in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
            top_5_accuracy = tf.metrics.mean(in_top_5)

            return {
                'top_1_accuracy': top_1_accuracy,
                'top_5_accuracy': top_5_accuracy,
            }

        eval_metrics = (metric_fn, [labels, logits])
    # logging_hook = tf.train.LoggingTensorHook(
    #   {"logging_hook_loss": loss}, every_n_iter=1)

    return tpu_estimator.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        host_call=host_call,
        eval_metrics=eval_metrics,
        # training_hooks=[logging_hook]
    )
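blur() and l2_channel_norm() in the attention-loss block above are project-specific helpers that the snippet does not include. A sketch of what l2_channel_norm plausibly does given how it is used (both the resized annotation maps and the attention layer are normalized before their difference is penalized); this is an assumption about the helper, not its actual implementation:

def l2_channel_norm(maps, eps=1e-12):
    """L2-normalize each map in the batch over its spatial and channel axes,
    so the distance below is insensitive to overall map magnitude."""
    norm = tf.sqrt(
        tf.reduce_sum(tf.square(maps), axis=[1, 2, 3], keep_dims=True) + eps)
    return maps / norm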
Example 15
def resnet_model_fn(features, labels, mode, params):
    """The model_fn for ResNet to be used with TPUEstimator.

  Args:
    features: `Tensor` of batched images.
    labels: `Tensor` of labels for the data samples
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL}`
    params: `dict` of parameters passed to the model from the TPUEstimator,
        `params['batch_size']` is always provided and should be used as the
        effective batch size.

  Returns:
    A `TPUEstimatorSpec` for the model
  """
    if isinstance(features, dict):
        features = features['feature']

    # In most cases, the default data format NCHW instead of NHWC should be
    # used for a significant performance boost on GPU/TPU. NHWC should be used
    # only if the network needs to be run on CPU since the pooling operations
    # are only supported on NHWC.
    if FLAGS.data_format == 'channels_first':
        features = tf.transpose(features, [0, 3, 1, 2])

    with tf.variable_scope('cg', custom_getter=get_custom_getter()):
        network = resnet_model.resnet_v1(resnet_depth=FLAGS.resnet_depth,
                                         num_classes=LABEL_CLASSES,
                                         data_format=FLAGS.data_format)

        logits = network(inputs=features,
                         is_training=(mode == tf.estimator.ModeKeys.TRAIN))

        logits = tf.cast(logits, tf.float32)

        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    # If necessary, in the model_fn, use params['batch_size'] instead of the batch
    # size flags (--train_batch_size or --eval_batch_size).
    batch_size = params['batch_size']  # pylint: disable=unused-variable

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    one_hot_labels = tf.one_hot(labels, LABEL_CLASSES)
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits, onehot_labels=one_hot_labels)

    tf.identity(cross_entropy, name='cross_entropy')
    tf.summary.scalar('cross_entropy', cross_entropy)

    # Add weight decay to the loss for non-batch-normalization variables.
    loss = cross_entropy + WEIGHT_DECAY * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

    host_call = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        # Compute the current epoch and associated learning rate from global_step.
        global_step = tf.train.get_global_step()
        batches_per_epoch = NUM_TRAIN_IMAGES / FLAGS.train_batch_size
        current_epoch = (tf.cast(global_step, tf.float32) / batches_per_epoch)
        learning_rate = learning_rate_schedule(current_epoch)
        # Create a tensor named learning_rate for logging purposes
        tf.identity(learning_rate, name='learning_rate')
        tf.summary.scalar('learning_rate', learning_rate)

        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=MOMENTUM,
                                               use_nesterov=True)
        optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)

        # Batch normalization requires UPDATE_OPS to be added as a dependency to
        # the train operation.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            #train_op = optimizer.minimize(loss, global_step)
            train_op = tf.group(optimizer.minimize(loss, global_step),
                                update_ops)

    else:
        train_op = None

    accuracy = tf.metrics.accuracy(tf.argmax(one_hot_labels, axis=1),
                                   predictions['classes'])
    metrics = {'accuracy': accuracy}
    tf.identity(accuracy[1], name='train_accuracy')
    tf.summary.scalar('train_accuracy', accuracy[1])

    return tf.estimator.EstimatorSpec(mode=mode,
                                      predictions=predictions,
                                      loss=loss,
                                      train_op=train_op,
                                      eval_metric_ops=metrics)
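Unlike the TPU examples, Example 15 wraps its optimizer in tf.contrib.estimator.TowerOptimizer, which only takes effect when the whole model_fn is replicated across GPUs. A sketch of how that was typically wired up in TF 1.x; the flag and input_fn names are assumptions:

distributed_model_fn = tf.contrib.estimator.replicate_model_fn(resnet_model_fn)
classifier = tf.estimator.Estimator(
    model_fn=distributed_model_fn,
    model_dir=FLAGS.model_dir,
    params={'batch_size': FLAGS.train_batch_size})
classifier.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)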