Exemple #1
0
def build_model(model_name: Text, inputs: tf.Tensor, **kwargs):
    """Create the detection network and return its per-level predictions.

    Args:
      model_name: name of the model to build.
      inputs: an image tensor or a numpy array fed to the network.
      **kwargs: extra parameters for the model builder. 'mixed_precision',
        'strategy' and 'use_keras_model' are read here; the complete set is
        also forwarded to the builder.

    Returns:
      (cls_outputs, box_outputs): class and box predictions, each a dictionary
      keyed by feature level.
    """

    def _as_float32(outputs_by_level):
        """Cast every per-level prediction to float32."""
        return {
            level: tf.cast(pred, tf.float32)
            for level, pred in outputs_by_level.items()
        }

    use_mixed = kwargs.get('mixed_precision')
    dtype_policy = utils.get_precision(kwargs.get('strategy'), use_mixed)

    if not kwargs.get('use_keras_model'):
        model_arch = det_model_fn.get_model_arch(model_name)
    else:

        def model_arch(feats, model_name=None, **kwargs):
            """Run the keras EfficientDet and key its outputs by level."""
            cfg = hparams_config.get_efficientdet_config(model_name)
            cfg.override(kwargs)
            net = efficientdet_keras.EfficientDetNet(config=cfg)
            cls_list, box_list = net(feats)

            # The network returns one entry per level, ordered from
            # min_level to max_level; re-key the lists by level number.
            num_levels = cfg.max_level - cfg.min_level + 1
            assert len(cls_list) == num_levels
            assert len(box_list) == num_levels
            levels = range(cfg.min_level, cfg.max_level + 1)
            cls_by_level = {
                lvl: cls_list[lvl - cfg.min_level] for lvl in levels
            }
            box_by_level = {
                lvl: box_list[lvl - cfg.min_level] for lvl in levels
            }
            return cls_by_level, box_by_level

    raw_cls, raw_box = utils.build_model_with_precision(
        dtype_policy, model_arch, inputs, False, model_name, **kwargs)

    if not use_mixed:
        return raw_cls, raw_box

    # Post-processing has multiple places with hard-coded float32.
    # TODO(tanmingxing): Remove the casts once post-process can adapt to dtypes.
    return _as_float32(raw_cls), _as_float32(raw_box)
Exemple #2
0
def build_model(model_name: Text, inputs: tf.Tensor, **kwargs):
    """Create the detection network and return its per-level predictions.

    Args:
      model_name: name of the model to build.
      inputs: an image tensor or a numpy array fed to the network.
      **kwargs: extra parameters for the model builder. 'mixed_precision'
        and 'strategy' are read here; the complete set is also forwarded to
        the builder.

    Returns:
      (cls_outputs, box_outputs): class and box predictions, each a dictionary
      keyed by feature level.
    """

    def _as_float32(outputs_by_level):
        """Cast every per-level prediction to float32."""
        return {
            level: tf.cast(pred, tf.float32)
            for level, pred in outputs_by_level.items()
        }

    arch_fn = det_model_fn.get_model_arch(model_name)
    use_mixed = kwargs.get('mixed_precision')
    dtype_policy = utils.get_precision(kwargs.get('strategy'), use_mixed)

    raw_cls, raw_box = utils.build_model_with_precision(
        dtype_policy, arch_fn, inputs, False, model_name, **kwargs)

    if not use_mixed:
        return raw_cls, raw_box

    # Post-processing has multiple places with hard-coded float32.
    # TODO(tanmingxing): Remove the casts once post-process can adapt to dtypes.
    return _as_float32(raw_cls), _as_float32(raw_box)
Exemple #3
0
    def build_model(self, features, is_training, precision):
        """Build efficientdet-d0 on `features` under the given precision.

        Args:
          features: input passed to the network.
          is_training: forwarded to the architecture as `is_training_bn`.
          precision: precision name handed to
            utils.build_model_with_precision.

        Returns:
          Whatever utils.build_model_with_precision returns for the network.
        """
        arch_options = {
            'model_name': 'efficientdet-d0',
            'is_training_bn': is_training,
            'image_size': 512,
        }

        def _efficientdet_d0(inputs):
            return efficientdet_arch.efficientdet(inputs, **arch_options)

        return utils.build_model_with_precision(precision, _efficientdet_d0,
                                                features)
Exemple #4
0
    def test_precision_float16(self):
        """A 'mixed_float16' build keeps float32 variables, emits float16."""

        def _toy_model(inputs):
            ones = tf.ones((4, 4, 4, 4), dtype='float32')
            conv_layer = tf.keras.layers.Conv2D(filters=4,
                                                kernel_size=2,
                                                use_bias=False)
            scale = tf.Variable(1.0)
            scale_cast = tf.cast(scale, inputs.dtype)
            return scale_cast * conv_layer(ones) * inputs

        # The input can be any type.
        model_input = tf.constant(2.0, dtype=tf.float32)
        result = utils.build_model_with_precision('mixed_float16', _toy_model,
                                                  model_input, False)

        # Variables should be float32.
        for variable in tf.global_variables():
            allowed = (tf.float32, tf.dtypes.as_dtype('float32_ref'))
            self.assertIn(variable.dtype, allowed)

        # Output should be float16.
        self.assertIs(result.dtype, tf.float16)
Exemple #5
0
def _model_fn(features, labels, mode, params, model, variable_filter_fn=None):
    """Model definition entry.

    Builds the network outputs, the losses, and (depending on `mode`) the
    training op, the evaluation metrics and the checkpoint-restore scaffold.

    Args:
      features: the input image tensor with shape [batch_size, height, width, 3].
        The height and width are fixed and equal.
      labels: the input labels in a dictionary. The labels include class targets
        and box targets which are dense label maps. The labels are generated from
        get_input_fn function in data/dataloader.py
      mode: the mode of TPUEstimator including TRAIN, EVAL, and PREDICT.
      params: the dictionary defines hyperparameters of model. The default
        settings are in default_hparams function in this file.
      model: the model outputs class logits and box regression outputs. It is
        called as `model(inputs, config=...)`.
      variable_filter_fn: the filter function that takes trainable_variables and
        returns the variable list after applying the filter rule.

    Returns:
      In PREDICT mode a tf.estimator.EstimatorSpec holding the input image and
      the per-level class/box outputs; otherwise a TPUEstimatorSpec to run
      training or evaluation.

    Raises:
      RuntimeError: if both ckpt and backbone_ckpt are set (checked in TRAIN
        mode only).
      ValueError: in TRAIN mode, if params['optimizer'] is neither 'sgd' nor
        'adam'.
    """
    utils.image('input_image', features)
    training_hooks = []
    # Per the docstring the input is [batch, height, width, 3]; move the
    # channel axis to position 1 when the model wants channels_first.
    if params['data_format'] == 'channels_first':
        features = tf.transpose(features, [0, 3, 1, 2])

    def _model_outputs(inputs):
        # Convert params (dict) to Config for easier access.
        return model(inputs, config=hparams_config.Config(params))

    cls_outputs, box_outputs = utils.build_model_with_precision(
        params['precision'], _model_outputs, features,
        params['is_training_bn'])

    # Cast every level back to float32 regardless of the build precision
    # (presumably because the loss / post-processing code expects float32 --
    # TODO confirm).
    levels = cls_outputs.keys()
    for level in levels:
        cls_outputs[level] = tf.cast(cls_outputs[level], tf.float32)
        box_outputs[level] = tf.cast(box_outputs[level], tf.float32)

    # First check if it is in PREDICT mode.
    # PREDICT returns early: no loss, optimizer or metrics are built.
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'image': features,
        }
        for level in levels:
            predictions['cls_outputs_%d' % level] = cls_outputs[level]
            predictions['box_outputs_%d' % level] = box_outputs[level]
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Set up training loss and learning rate.
    update_learning_rate_schedule_parameters(params)
    global_step = tf.train.get_or_create_global_step()
    learning_rate = learning_rate_schedule(params, global_step)

    # cls_loss and box_loss are for logging. only total_loss is optimized.
    det_loss, cls_loss, box_loss, box_iou_loss = detection_loss(
        cls_outputs, box_outputs, labels, params)
    reg_l2loss = reg_l2_loss(params['weight_decay'])
    total_loss = det_loss + reg_l2loss

    if mode == tf.estimator.ModeKeys.TRAIN:
        utils.scalar('lrn_rate', learning_rate)
        utils.scalar('trainloss/cls_loss', cls_loss)
        utils.scalar('trainloss/box_loss', box_loss)
        utils.scalar('trainloss/det_loss', det_loss)
        utils.scalar('trainloss/reg_l2_loss', reg_l2loss)
        utils.scalar('trainloss/loss', total_loss)
        # NOTE(review): this relies on the Python truthiness of whatever
        # detection_loss returned for box_iou_loss; if that can be a graph
        # tensor, the check itself may raise -- confirm the return type.
        if box_iou_loss:
            utils.scalar('trainloss/box_iou_loss', box_iou_loss)

    # The EMA helper is created for every mode once the decay is set: TRAIN
    # applies it after the optimizer step, EVAL uses it to restore the
    # averaged variables.
    moving_average_decay = params['moving_average_decay']
    if moving_average_decay:
        ema = tf.train.ExponentialMovingAverage(decay=moving_average_decay,
                                                num_updates=global_step)
        ema_vars = utils.get_ema_vars()
    # Horovod is imported lazily so it is only required when that strategy is
    # selected; the learning rate is scaled by hvd.size().
    if params['strategy'] == 'horovod':
        import horovod.tensorflow as hvd  # pylint: disable=g-import-not-at-top
        learning_rate = learning_rate * hvd.size()
    if mode == tf.estimator.ModeKeys.TRAIN:
        if params['optimizer'].lower() == 'sgd':
            optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                   momentum=params['momentum'])
        elif params['optimizer'].lower() == 'adam':
            optimizer = tf.train.AdamOptimizer(learning_rate)
        else:
            raise ValueError('optimizers should be adam or sgd')

        # Wrap the optimizer for the selected distribution strategy.
        if params['strategy'] == 'tpu':
            optimizer = tf.tpu.CrossShardOptimizer(optimizer)
        elif params['strategy'] == 'horovod':
            optimizer = hvd.DistributedOptimizer(optimizer)
            training_hooks = [hvd.BroadcastGlobalVariablesHook(0)]

        # Batch norm requires update_ops to be added as a train_op dependency.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        var_list = tf.trainable_variables()
        if variable_filter_fn:
            var_list = variable_filter_fn(var_list)

        # Gradient clipping is applied only for a strictly positive norm;
        # a missing key defaults to 0, i.e. no clipping.
        if params.get('clip_gradients_norm', 0) > 0:
            logging.info('clip gradients norm by %f',
                         params['clip_gradients_norm'])
            grads_and_vars = optimizer.compute_gradients(total_loss, var_list)
            with tf.name_scope('clip'):
                grads = [gv[0] for gv in grads_and_vars]
                tvars = [gv[1] for gv in grads_and_vars]
                clipped_grads, gnorm = tf.clip_by_global_norm(
                    grads, params['clip_gradients_norm'])
                utils.scalar('gnorm', gnorm)
                grads_and_vars = list(zip(clipped_grads, tvars))

            with tf.control_dependencies(update_ops):
                train_op = optimizer.apply_gradients(grads_and_vars,
                                                     global_step)
        else:
            with tf.control_dependencies(update_ops):
                train_op = optimizer.minimize(total_loss,
                                              global_step,
                                              var_list=var_list)

        # The EMA update depends on the optimizer step and replaces train_op,
        # so running train_op runs both.
        if moving_average_decay:
            with tf.control_dependencies([train_op]):
                train_op = ema.apply(ema_vars)

    else:
        train_op = None

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(**kwargs):
            """Returns a dictionary that has the evaluation metrics."""
            # On TPU the batch size handed to the COCO metric is the
            # per-shard batch size times the number of shards.
            batch_size = params['batch_size']
            if params['strategy'] == 'tpu':
                batch_size = params['batch_size'] * params['num_shards']
            eval_anchors = anchors.Anchors(params['min_level'],
                                           params['max_level'],
                                           params['num_scales'],
                                           params['aspect_ratios'],
                                           params['anchor_scale'],
                                           params['image_size'])
            anchor_labeler = anchors.AnchorLabeler(eval_anchors,
                                                   params['num_classes'])
            cls_loss = tf.metrics.mean(kwargs['cls_loss_repeat'])
            box_loss = tf.metrics.mean(kwargs['box_loss_repeat'])

            # With a testdev_dir the metric function is additionally given
            # that directory and the optional disable_pyfun flag.
            if params.get('testdev_dir', None):
                logging.info('Eval testdev_dir %s', params['testdev_dir'])
                coco_metrics = coco_metric_fn(
                    batch_size,
                    anchor_labeler,
                    params['val_json_file'],
                    testdev_dir=params['testdev_dir'],
                    disable_pyfun=params.get('disable_pyfun', None),
                    **kwargs)
            else:
                logging.info('Eval val with groudtruths %s.',
                             params['val_json_file'])
                coco_metrics = coco_metric_fn(batch_size, anchor_labeler,
                                              params['val_json_file'],
                                              **kwargs)

            # Add metrics to output.
            output_metrics = {
                'cls_loss': cls_loss,
                'box_loss': box_loss,
            }
            output_metrics.update(coco_metrics)
            return output_metrics

        # Repeat each loss batch_size times and reshape to [batch_size, 1]
        # (assumes the losses are scalars; presumably needed so they can be
        # passed to metric_fn as batched tensors -- TODO confirm).
        cls_loss_repeat = tf.reshape(
            tf.tile(tf.expand_dims(cls_loss, 0), [
                params['batch_size'],
            ]), [params['batch_size'], 1])
        box_loss_repeat = tf.reshape(
            tf.tile(tf.expand_dims(box_loss, 0), [
                params['batch_size'],
            ]), [params['batch_size'], 1])
        metric_fn_inputs = {
            'cls_loss_repeat': cls_loss_repeat,
            'box_loss_repeat': box_loss_repeat,
            'source_ids': labels['source_ids'],
            'groundtruth_data': labels['groundtruth_data'],
            'image_scales': labels['image_scales'],
        }
        add_metric_fn_inputs(params, cls_outputs, box_outputs,
                             metric_fn_inputs)
        eval_metrics = (metric_fn, metric_fn_inputs)

    # 'ckpt' takes precedence over 'backbone_ckpt' when choosing the path.
    checkpoint = params.get('ckpt') or params.get('backbone_ckpt')

    if checkpoint and mode == tf.estimator.ModeKeys.TRAIN:
        # Initialize the model from an EfficientDet or backbone checkpoint.
        if params.get('ckpt') and params.get('backbone_ckpt'):
            raise RuntimeError(
                '--backbone_ckpt and --checkpoint are mutually exclusive')

        if params.get('backbone_ckpt'):
            var_scope = params['backbone_name'] + '/'
            if params['ckpt_var_scope'] is None:
                # Use backbone name as default checkpoint scope.
                ckpt_scope = params['backbone_name'] + '/'
            else:
                ckpt_scope = params['ckpt_var_scope'] + '/'
        else:
            # Load every var in the given checkpoint
            var_scope = ckpt_scope = '/'

        def scaffold_fn():
            """Loads pretrained model through scaffold function."""
            logging.info('restore variables from %s', checkpoint)

            var_map = utils.get_ckpt_var_map(ckpt_path=checkpoint,
                                             ckpt_scope=ckpt_scope,
                                             var_scope=var_scope,
                                             var_exclude_expr=params.get(
                                                 'var_exclude_expr', None))

            tf.train.init_from_checkpoint(checkpoint, var_map)

            return tf.train.Scaffold()
    elif mode == tf.estimator.ModeKeys.EVAL and moving_average_decay:

        def scaffold_fn():
            """Load moving average variables for eval."""
            logging.info('Load EMA vars with ema_decay=%f',
                         moving_average_decay)
            restore_vars_dict = ema.variables_to_restore(ema_vars)
            saver = tf.train.Saver(restore_vars_dict)
            return tf.train.Scaffold(saver=saver)
    else:
        scaffold_fn = None

    # The profiler and OOM-reporting hooks are attached for every non-TPU
    # strategy.
    if params['strategy'] != 'tpu':
        # Profile every 1K steps.
        profile_hook = tf.train.ProfilerHook(save_steps=1000,
                                             output_dir=params['model_dir'])
        training_hooks.append(profile_hook)

        # Report memory allocation if OOM
        class OomReportingHook(tf.estimator.SessionRunHook):
            def before_run(self, run_context):
                # Fetch nothing extra; only attach the run options that ask
                # for tensor allocations to be reported on OOM.
                return tf.estimator.SessionRunArgs(
                    fetches=[],
                    options=tf.RunOptions(
                        report_tensor_allocations_upon_oom=True))

        training_hooks.append(OomReportingHook())

    return tf.estimator.tpu.TPUEstimatorSpec(mode=mode,
                                             loss=total_loss,
                                             train_op=train_op,
                                             eval_metrics=eval_metrics,
                                             host_call=utils.get_tpu_host_call(
                                                 global_step, params),
                                             scaffold_fn=scaffold_fn,
                                             training_hooks=training_hooks)
Exemple #6
0
def _model_fn(features, labels, mode, params, model, variable_filter_fn=None):
    """Model definition entry.

    Builds the network outputs, the losses, and (depending on `mode`) the
    training op, the evaluation metrics and the checkpoint-restore scaffold.
    Note that `params` is modified in place: 'is_training_bn' is overwritten
    and, in EVAL mode, params['nms_configs']['max_nms_inputs'] is set.

    Args:
      features: the input image tensor with shape [batch_size, height, width, 3].
        The height and width are fixed and equal.
      labels: the input labels in a dictionary. The labels include class targets
        and box targets which are dense label maps. The labels are generated from
        get_input_fn function in data/dataloader.py
      mode: the mode of TPUEstimator including TRAIN and EVAL.
      params: the dictionary defines hyperparameters of model. The default
        settings are in default_hparams function in this file.
      model: the model outputs class logits and box regression outputs. Only
        used when params['use_keras_model'] is falsy.
      variable_filter_fn: the filter function that takes trainable_variables and
        returns the variable list after applying the filter rule.

    Returns:
      A TPUEstimatorSpec when params['strategy'] is 'tpu', otherwise a
      tf.estimator.EstimatorSpec, to run training or evaluation.

    Raises:
      RuntimeError: if both ckpt and backbone_ckpt are set (checked in TRAIN
        mode only).
      ValueError: in TRAIN mode, if params['optimizer'] is neither 'sgd' nor
        'adam'.
    """
    is_tpu = params['strategy'] == 'tpu'
    # The input-image summary is emitted only when img_summary_steps is set.
    if params['img_summary_steps']:
        utils.image('input_image', features, is_tpu)
    training_hooks = []
    # Overwrites the caller's value: batch norm trains only in TRAIN mode.
    params['is_training_bn'] = (mode == tf.estimator.ModeKeys.TRAIN)

    if params['use_keras_model']:

        def model_fn(inputs):
            # The keras network returns two lists (one entry per level);
            # re-key them by level number, min_level..max_level inclusive.
            model = efficientdet_keras.EfficientDetNet(
                config=hparams_config.Config(params))
            cls_out_list, box_out_list = model(inputs,
                                               params['is_training_bn'])
            cls_outputs, box_outputs = {}, {}
            for i in range(params['min_level'], params['max_level'] + 1):
                cls_outputs[i] = cls_out_list[i - params['min_level']]
                box_outputs[i] = box_out_list[i - params['min_level']]
            return cls_outputs, box_outputs
    else:
        # Bind the config built from params to the supplied model function.
        model_fn = functools.partial(model,
                                     config=hparams_config.Config(params))

    precision = utils.get_precision(params['strategy'],
                                    params['mixed_precision'])
    cls_outputs, box_outputs = utils.build_model_with_precision(
        precision, model_fn, features, params['is_training_bn'])

    # Cast every level back to float32 regardless of the build precision
    # (presumably because the loss / post-processing code expects float32 --
    # TODO confirm).
    levels = cls_outputs.keys()
    for level in levels:
        cls_outputs[level] = tf.cast(cls_outputs[level], tf.float32)
        box_outputs[level] = tf.cast(box_outputs[level], tf.float32)

    # Set up training loss and learning rate.
    update_learning_rate_schedule_parameters(params)
    global_step = tf.train.get_or_create_global_step()
    learning_rate = learning_rate_schedule(params, global_step)

    # cls_loss and box_loss are for logging. only total_loss is optimized.
    det_loss, cls_loss, box_loss = detection_loss(cls_outputs, box_outputs,
                                                  labels, params)
    reg_l2loss = reg_l2_loss(params['weight_decay'])
    total_loss = det_loss + reg_l2loss

    if mode == tf.estimator.ModeKeys.TRAIN:
        utils.scalar('lrn_rate', learning_rate, is_tpu)
        utils.scalar('trainloss/cls_loss', cls_loss, is_tpu)
        utils.scalar('trainloss/box_loss', box_loss, is_tpu)
        utils.scalar('trainloss/det_loss', det_loss, is_tpu)
        utils.scalar('trainloss/reg_l2_loss', reg_l2loss, is_tpu)
        utils.scalar('trainloss/loss', total_loss, is_tpu)
        # Fractional epoch count derived from the global step.
        train_epochs = tf.cast(global_step,
                               tf.float32) / params['steps_per_epoch']
        utils.scalar('train_epochs', train_epochs, is_tpu)

    # The EMA helper is created for every mode once the decay is set: TRAIN
    # applies it after the optimizer step, EVAL uses it to restore the
    # averaged variables.
    moving_average_decay = params['moving_average_decay']
    if moving_average_decay:
        ema = tf.train.ExponentialMovingAverage(decay=moving_average_decay,
                                                num_updates=global_step)
        ema_vars = utils.get_ema_vars()

    if mode == tf.estimator.ModeKeys.TRAIN:
        if params['optimizer'].lower() == 'sgd':
            optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                   momentum=params['momentum'])
        elif params['optimizer'].lower() == 'adam':
            optimizer = tf.train.AdamOptimizer(learning_rate)
        else:
            raise ValueError('optimizers should be adam or sgd')

        if is_tpu:
            optimizer = tf.tpu.CrossShardOptimizer(optimizer)

        # Batch norm requires update_ops to be added as a train_op dependency.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        var_list = tf.trainable_variables()
        if variable_filter_fn:
            var_list = variable_filter_fn(var_list)

        # Any truthy value enables clipping; its absolute value is the norm.
        if params.get('clip_gradients_norm', None):
            logging.info('clip gradients norm by %f',
                         params['clip_gradients_norm'])
            grads_and_vars = optimizer.compute_gradients(total_loss, var_list)
            with tf.name_scope('clip'):
                grads = [gv[0] for gv in grads_and_vars]
                tvars = [gv[1] for gv in grads_and_vars]
                # First clip each variable's norm, then clip global norm.
                clip_norm = abs(params['clip_gradients_norm'])
                # Variables without a gradient keep a None entry.
                clipped_grads = [
                    tf.clip_by_norm(g, clip_norm) if g is not None else None
                    for g in grads
                ]
                clipped_grads, _ = tf.clip_by_global_norm(
                    clipped_grads, clip_norm)
                # The logged value is the global norm after clipping.
                utils.scalar('gradient_norm',
                             tf.linalg.global_norm(clipped_grads), is_tpu)
                grads_and_vars = list(zip(clipped_grads, tvars))

            with tf.control_dependencies(update_ops):
                train_op = optimizer.apply_gradients(grads_and_vars,
                                                     global_step)
        else:
            with tf.control_dependencies(update_ops):
                train_op = optimizer.minimize(total_loss,
                                              global_step,
                                              var_list=var_list)

        # The EMA update depends on the optimizer step and replaces train_op,
        # so running train_op runs both.
        if moving_average_decay:
            with tf.control_dependencies([train_op]):
                train_op = ema.apply(ema_vars)

    else:
        train_op = None

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(**kwargs):
            """Returns a dictionary that has the evaluation metrics."""
            # The numpy NMS path is the default (a missing 'pyfunc' key counts
            # as True). It runs once per image, looping over the static batch
            # dimension of kwargs['boxes'].
            if params['nms_configs'].get('pyfunc', True):
                detections_bs = []
                nms_configs = params['nms_configs']
                for index in range(kwargs['boxes'].shape[0]):
                    detections = tf.numpy_function(
                        functools.partial(nms_np.per_class_nms,
                                          nms_configs=nms_configs),
                        [
                            kwargs['boxes'][index],
                            kwargs['scores'][index],
                            kwargs['classes'][index],
                            tf.slice(kwargs['image_ids'], [index], [1]),
                            tf.slice(kwargs['image_scales'], [index], [1]),
                            params['num_classes'],
                            nms_configs['max_output_size'],
                        ], tf.float32)
                    detections_bs.append(detections)
                detections_bs = postprocess.transform_detections(
                    tf.stack(detections_bs))
            else:
                # These two branches should be equivalent, but currently they are not.
                # TODO(tanmingxing): enable the non_pyfun path after bug fix.
                nms_boxes, nms_scores, nms_classes, _ = postprocess.per_class_nms(
                    params, kwargs['boxes'], kwargs['scores'],
                    kwargs['classes'], kwargs['image_scales'])
                img_ids = tf.cast(tf.expand_dims(kwargs['image_ids'], -1),
                                  nms_scores.dtype)
                # Seven values per detection, stacked on the last axis:
                # image id, box columns 1 and 0, the differences (3 - 1) and
                # (2 - 0), score, class. Presumably this is
                # [id, x, y, width, height, score, class] for boxes stored as
                # [ymin, xmin, ymax, xmax] -- TODO confirm.
                detections_bs = [
                    img_ids * tf.ones_like(nms_scores),
                    nms_boxes[:, :, 1],
                    nms_boxes[:, :, 0],
                    nms_boxes[:, :, 3] - nms_boxes[:, :, 1],
                    nms_boxes[:, :, 2] - nms_boxes[:, :, 0],
                    nms_scores,
                    nms_classes,
                ]
                detections_bs = tf.stack(detections_bs,
                                         axis=-1,
                                         name='detnections')

            if params.get('testdev_dir', None):
                logging.info('Eval testdev_dir %s', params['testdev_dir'])
                eval_metric = coco_metric.EvaluationMetric(
                    testdev_dir=params['testdev_dir'])
                # In the testdev path a dummy tf.zeros([1]) is passed in
                # place of the groundtruth data.
                coco_metrics = eval_metric.estimator_metric_fn(
                    detections_bs, tf.zeros([1]))
            else:
                logging.info('Eval val with groudtruths %s.',
                             params['val_json_file'])
                eval_metric = coco_metric.EvaluationMetric(
                    filename=params['val_json_file'],
                    label_map=params['label_map'])
                coco_metrics = eval_metric.estimator_metric_fn(
                    detections_bs, kwargs['groundtruth_data'])

            # Add metrics to output.
            cls_loss = tf.metrics.mean(kwargs['cls_loss_repeat'])
            box_loss = tf.metrics.mean(kwargs['box_loss_repeat'])
            output_metrics = {
                'cls_loss': cls_loss,
                'box_loss': box_loss,
            }
            output_metrics.update(coco_metrics)
            return output_metrics

        # Repeat each loss batch_size times and reshape to [batch_size, 1]
        # (assumes the losses are scalars; presumably needed so they can be
        # passed to metric_fn as batched tensors -- TODO confirm).
        cls_loss_repeat = tf.reshape(
            tf.tile(tf.expand_dims(cls_loss, 0), [
                params['batch_size'],
            ]), [params['batch_size'], 1])
        box_loss_repeat = tf.reshape(
            tf.tile(tf.expand_dims(box_loss, 0), [
                params['batch_size'],
            ]), [params['batch_size'], 1])

        # From here on cls_outputs / box_outputs are the list form returned by
        # postprocess.to_list, no longer the per-level dictionaries.
        cls_outputs = postprocess.to_list(cls_outputs)
        box_outputs = postprocess.to_list(box_outputs)
        # Mutates params in place before pre-NMS processing.
        params['nms_configs']['max_nms_inputs'] = anchors.MAX_DETECTION_POINTS
        boxes, scores, classes = postprocess.pre_nms(params, cls_outputs,
                                                     box_outputs)
        metric_fn_inputs = {
            'cls_loss_repeat': cls_loss_repeat,
            'box_loss_repeat': box_loss_repeat,
            'image_ids': labels['source_ids'],
            'groundtruth_data': labels['groundtruth_data'],
            'image_scales': labels['image_scales'],
            'boxes': boxes,
            'scores': scores,
            'classes': classes,
        }
        eval_metrics = (metric_fn, metric_fn_inputs)

    # 'ckpt' takes precedence over 'backbone_ckpt' when choosing the path.
    checkpoint = params.get('ckpt') or params.get('backbone_ckpt')

    if checkpoint and mode == tf.estimator.ModeKeys.TRAIN:
        # Initialize the model from an EfficientDet or backbone checkpoint.
        if params.get('ckpt') and params.get('backbone_ckpt'):
            raise RuntimeError(
                '--backbone_ckpt and --checkpoint are mutually exclusive')

        if params.get('backbone_ckpt'):
            var_scope = params['backbone_name'] + '/'
            if params['ckpt_var_scope'] is None:
                # Use backbone name as default checkpoint scope.
                ckpt_scope = params['backbone_name'] + '/'
            else:
                ckpt_scope = params['ckpt_var_scope'] + '/'
        else:
            # Load every var in the given checkpoint
            var_scope = ckpt_scope = '/'

        def scaffold_fn():
            """Loads pretrained model through scaffold function."""
            logging.info('restore variables from %s', checkpoint)

            var_map = utils.get_ckpt_var_map(
                ckpt_path=checkpoint,
                ckpt_scope=ckpt_scope,
                var_scope=var_scope,
                skip_mismatch=params['skip_mismatch'])

            tf.train.init_from_checkpoint(checkpoint, var_map)
            return tf.train.Scaffold()
    elif mode == tf.estimator.ModeKeys.EVAL and moving_average_decay:

        def scaffold_fn():
            """Load moving average variables for eval."""
            logging.info('Load EMA vars with ema_decay=%f',
                         moving_average_decay)
            restore_vars_dict = ema.variables_to_restore(ema_vars)
            saver = tf.train.Saver(restore_vars_dict)
            return tf.train.Scaffold(saver=saver)
    else:
        scaffold_fn = None

    # TPU: hand the metric function and scaffold function to the spec
    # unevaluated. Non-TPU: call them here and build a plain EstimatorSpec.
    if is_tpu:
        return tf.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            train_op=train_op,
            eval_metrics=eval_metrics,
            host_call=utils.get_tpu_host_call(global_step, params),
            scaffold_fn=scaffold_fn,
            training_hooks=training_hooks)
    else:
        # Profile every 1K steps.
        # Both the profiler and the OOM hook are opt-in via params['profile'].
        if params.get('profile', False):
            profile_hook = tf.estimator.ProfilerHook(
                save_steps=1000,
                output_dir=params['model_dir'],
                show_memory=True)
            training_hooks.append(profile_hook)

            # Report memory allocation if OOM; it will slow down the running.
            class OomReportingHook(tf.estimator.SessionRunHook):
                def before_run(self, run_context):
                    # Fetch nothing extra; only attach the run options that
                    # ask for tensor allocations to be reported on OOM.
                    return tf.estimator.SessionRunArgs(
                        fetches=[],
                        options=tf.RunOptions(
                            report_tensor_allocations_upon_oom=True))

            training_hooks.append(OomReportingHook())

        # Log the step and the losses every 'iterations_per_loop' iterations
        # (100 when the key is absent).
        logging_hook = tf.estimator.LoggingTensorHook(
            {
                'step': global_step,
                'det_loss': det_loss,
                'cls_loss': cls_loss,
                'box_loss': box_loss,
            },
            every_n_iter=params.get('iterations_per_loop', 100),
        )
        training_hooks.append(logging_hook)

        # eval_metrics is a (metric_fn, inputs) pair; call it now to obtain
        # the metric ops that EstimatorSpec expects.
        eval_metric_ops = (eval_metrics[0](
            **eval_metrics[1]) if eval_metrics else None)
        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=total_loss,
            train_op=train_op,
            eval_metric_ops=eval_metric_ops,
            scaffold=scaffold_fn() if scaffold_fn else None,
            training_hooks=training_hooks)
def _model_fn(features, labels, mode, params, model, variable_filter_fn=None):
    """Model definition entry.

  Builds the detection network, and depending on `mode` wires up the
  prediction outputs, the training op (optimizer, optional gradient clipping,
  optional EMA), or the evaluation metrics, then packs everything into an
  estimator spec.

  Args:
    features: the input image tensor with shape [batch_size, height, width, 3].
      The height and width are fixed and equal.
    labels: the input labels in a dictionary. The labels include class targets
      and box targets which are dense label maps. The labels are generated from
      get_input_fn function in data/dataloader.py
    mode: the mode of TPUEstimator including TRAIN, EVAL, and PREDICT.
    params: the dictionary defines hyperparameters of model. The default
      settings are in default_hparams function in this file. Note that this
      function writes `params['is_training_bn']` and, in EVAL mode,
      `params['nms_configs']['max_nms_inputs']`.
    model: the model outputs class logits and box regression outputs. Only
      used when `params['use_keras_model']` is falsy.
    variable_filter_fn: the filter function that takes trainable_variables and
      returns the variable list after applying the filter rule.

  Returns:
    A `tf.estimator.EstimatorSpec` in PREDICT mode or when
    `params['strategy']` is not 'tpu'; otherwise a
    `tf.estimator.tpu.TPUEstimatorSpec`.

  Raises:
    RuntimeError: if both ckpt and backbone_ckpt are set.
    ValueError: if `params['optimizer']` is neither 'sgd' nor 'adam'.
  """
    utils.image('input_image', features)
    training_hooks = []
    params['is_training_bn'] = (mode == tf.estimator.ModeKeys.TRAIN)

    if params['use_keras_model']:

        def model_fn(inputs):
            """Run the keras EfficientDet and key its outputs by level."""
            model = efficientdet_keras.EfficientDetNet(
                config=hparams_config.Config(params))
            cls_out_list, box_out_list = model(inputs,
                                               params['is_training_bn'])
            # Convert the per-level output lists to dicts with key=level.
            cls_outputs, box_outputs = {}, {}
            for i in range(params['min_level'], params['max_level'] + 1):
                cls_outputs[i] = cls_out_list[i - params['min_level']]
                box_outputs[i] = box_out_list[i - params['min_level']]
            return cls_outputs, box_outputs
    else:
        model_fn = functools.partial(model,
                                     config=hparams_config.Config(params))

    precision = utils.get_precision(params['strategy'],
                                    params['mixed_precision'])
    cls_outputs, box_outputs = utils.build_model_with_precision(
        precision, model_fn, features, params['is_training_bn'])

    # Everything downstream (losses, post-processing) works on float32.
    levels = cls_outputs.keys()
    for level in levels:
        cls_outputs[level] = tf.cast(cls_outputs[level], tf.float32)
        box_outputs[level] = tf.cast(box_outputs[level], tf.float32)

    # First check if it is in PREDICT mode.
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'image': features,
        }
        for level in levels:
            predictions['cls_outputs_%d' % level] = cls_outputs[level]
            predictions['box_outputs_%d' % level] = box_outputs[level]
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Set up training loss and learning rate.
    update_learning_rate_schedule_parameters(params)
    global_step = tf.train.get_or_create_global_step()
    learning_rate = learning_rate_schedule(params, global_step)

    # cls_loss and box_loss are for logging. only total_loss is optimized.
    det_loss, cls_loss, box_loss, box_iou_loss = detection_loss(
        cls_outputs, box_outputs, labels, params)
    reg_l2loss = reg_l2_loss(params['weight_decay'])
    total_loss = det_loss + reg_l2loss

    if mode == tf.estimator.ModeKeys.TRAIN:
        utils.scalar('lrn_rate', learning_rate)
        utils.scalar('trainloss/cls_loss', cls_loss)
        utils.scalar('trainloss/box_loss', box_loss)
        utils.scalar('trainloss/det_loss', det_loss)
        utils.scalar('trainloss/reg_l2_loss', reg_l2loss)
        utils.scalar('trainloss/loss', total_loss)
        if params['iou_loss_type']:
            utils.scalar('trainloss/box_iou_loss', box_iou_loss)
        train_epochs = tf.cast(global_step,
                               tf.float32) / params['steps_per_epoch']
        utils.scalar('train_epochs', train_epochs)

    # `ema` / `ema_vars` only exist when moving_average_decay is truthy; every
    # later use is guarded by the same condition.
    moving_average_decay = params['moving_average_decay']
    if moving_average_decay:
        ema = tf.train.ExponentialMovingAverage(decay=moving_average_decay,
                                                num_updates=global_step)
        ema_vars = utils.get_ema_vars()

    if mode == tf.estimator.ModeKeys.TRAIN:
        if params['optimizer'].lower() == 'sgd':
            optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                   momentum=params['momentum'])
        elif params['optimizer'].lower() == 'adam':
            optimizer = tf.train.AdamOptimizer(learning_rate)
        else:
            raise ValueError('optimizers should be adam or sgd')

        if params['strategy'] == 'tpu':
            optimizer = tf.tpu.CrossShardOptimizer(optimizer)
        if params['gradient_checkpointing']:
            from third_party.grad_checkpoint \
                import memory_saving_gradients  # pylint: disable=g-import-not-at-top
            from tensorflow.python.ops \
                import gradients  # pylint: disable=g-import-not-at-top

            # monkey patch tf.gradients to point to our custom version,
            # with automatic checkpoint selection
            def gradients_(ys, xs, grad_ys=None, **kwargs):
                return memory_saving_gradients.gradients(
                    ys,
                    xs,
                    grad_ys,
                    checkpoints=params['gradient_checkpointing_list'],
                    **kwargs)

            gradients.__dict__["gradients"] = gradients_

        # Batch norm requires update_ops to be added as a train_op dependency.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        var_list = tf.trainable_variables()
        if variable_filter_fn:
            var_list = variable_filter_fn(var_list)

        if params.get('clip_gradients_norm', None):
            logging.info('clip gradients norm by %f',
                         params['clip_gradients_norm'])
            grads_and_vars = optimizer.compute_gradients(total_loss, var_list)
            with tf.name_scope('clip'):
                grads = [gv[0] for gv in grads_and_vars]
                tvars = [gv[1] for gv in grads_and_vars]
                # First clip each variable's norm, then clip global norm.
                clip_norm = abs(params['clip_gradients_norm'])
                clipped_grads = [tf.clip_by_norm(g, clip_norm) for g in grads]
                clipped_grads, _ = tf.clip_by_global_norm(
                    clipped_grads, clip_norm)
                utils.scalar('gradient_norm',
                             tf.linalg.global_norm(clipped_grads))
                grads_and_vars = list(zip(clipped_grads, tvars))

            with tf.control_dependencies(update_ops):
                train_op = optimizer.apply_gradients(grads_and_vars,
                                                     global_step)
        else:
            with tf.control_dependencies(update_ops):
                train_op = optimizer.minimize(total_loss,
                                              global_step,
                                              var_list=var_list)

        if moving_average_decay:
            # The EMA update runs after the optimizer step and becomes the
            # op that the estimator executes.
            with tf.control_dependencies([train_op]):
                train_op = ema.apply(ema_vars)

    else:
        train_op = None

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(**kwargs):
            """Returns a dictionary that has the evaluation metrics."""
            if params['nms_configs'].get('pyfunc', True):
                # Per-image NMS through the numpy implementation.
                detections_bs = []
                for index in range(kwargs['boxes'].shape[0]):
                    nms_configs = params['nms_configs']
                    detections = tf.numpy_function(
                        functools.partial(nms_np.per_class_nms,
                                          nms_configs=nms_configs),
                        [
                            kwargs['boxes'][index],
                            kwargs['scores'][index],
                            kwargs['classes'][index],
                            tf.slice(kwargs['image_ids'], [index], [1]),
                            tf.slice(kwargs['image_scales'], [index], [1]),
                            params['num_classes'],
                            nms_configs['max_output_size'],
                        ], tf.float32)
                    detections_bs.append(detections)
                detections_bs = postprocess.transform_detections(
                    tf.stack(detections_bs))
            else:
                # These two branches should be equivalent, but currently they are not.
                # TODO(tanmingxing): enable the non_pyfun path after bug fix.
                nms_boxes, nms_scores, nms_classes, _ = postprocess.per_class_nms(
                    params, kwargs['boxes'], kwargs['scores'],
                    kwargs['classes'], kwargs['image_scales'])
                img_ids = tf.cast(tf.expand_dims(kwargs['image_ids'], -1),
                                  nms_scores.dtype)
                # Columns: image id, box columns 1 and 0, the differences
                # (col 3 - col 1) and (col 2 - col 0), score, class.
                # Presumably [id, x, y, w, h, score, class] -- TODO confirm
                # against postprocess.per_class_nms.
                detections_bs = [
                    img_ids * tf.ones_like(nms_scores),
                    nms_boxes[:, :, 1],
                    nms_boxes[:, :, 0],
                    nms_boxes[:, :, 3] - nms_boxes[:, :, 1],
                    nms_boxes[:, :, 2] - nms_boxes[:, :, 0],
                    nms_scores,
                    nms_classes,
                ]
                detections_bs = tf.stack(detections_bs,
                                         axis=-1,
                                         name='detnections')

            if params.get('testdev_dir', None):
                logging.info('Eval testdev_dir %s', params['testdev_dir'])
                eval_metric = coco_metric.EvaluationMetric(
                    testdev_dir=params['testdev_dir'])
                coco_metrics = eval_metric.estimator_metric_fn(
                    detections_bs, tf.zeros([1]))
            else:
                logging.info('Eval val with groudtruths %s.',
                             params['val_json_file'])
                eval_metric = coco_metric.EvaluationMetric(
                    filename=params['val_json_file'])
                coco_metrics = eval_metric.estimator_metric_fn(
                    detections_bs, kwargs['groundtruth_data'],
                    params['label_map'])

            # Add metrics to output.
            cls_loss = tf.metrics.mean(kwargs['cls_loss_repeat'])
            box_loss = tf.metrics.mean(kwargs['box_loss_repeat'])
            output_metrics = {
                'cls_loss': cls_loss,
                'box_loss': box_loss,
            }
            output_metrics.update(coco_metrics)
            return output_metrics

        # Repeat the scalar losses into [batch_size, 1] tensors so they can be
        # passed to metric_fn alongside the other per-example inputs.
        cls_loss_repeat = tf.reshape(
            tf.tile(tf.expand_dims(cls_loss, 0), [
                params['batch_size'],
            ]), [params['batch_size'], 1])
        box_loss_repeat = tf.reshape(
            tf.tile(tf.expand_dims(box_loss, 0), [
                params['batch_size'],
            ]), [params['batch_size'], 1])

        cls_outputs = postprocess.to_list(cls_outputs)
        box_outputs = postprocess.to_list(box_outputs)
        params['nms_configs']['max_nms_inputs'] = anchors.MAX_DETECTION_POINTS
        boxes, scores, classes = postprocess.pre_nms(params, cls_outputs,
                                                     box_outputs)
        metric_fn_inputs = {
            'cls_loss_repeat': cls_loss_repeat,
            'box_loss_repeat': box_loss_repeat,
            'image_ids': labels['source_ids'],
            'groundtruth_data': labels['groundtruth_data'],
            'image_scales': labels['image_scales'],
            'boxes': boxes,
            'scores': scores,
            'classes': classes,
        }
        eval_metrics = (metric_fn, metric_fn_inputs)

    checkpoint = params.get('ckpt') or params.get('backbone_ckpt')

    if checkpoint and mode == tf.estimator.ModeKeys.TRAIN:
        # Initialize the model from an EfficientDet or backbone checkpoint.
        if params.get('ckpt') and params.get('backbone_ckpt'):
            raise RuntimeError(
                '--backbone_ckpt and --checkpoint are mutually exclusive')

        if params.get('backbone_ckpt'):
            var_scope = params['backbone_name'] + '/'
            if params['ckpt_var_scope'] is None:
                # Use backbone name as default checkpoint scope.
                ckpt_scope = params['backbone_name'] + '/'
            else:
                ckpt_scope = params['ckpt_var_scope'] + '/'
        else:
            # Load every var in the given checkpoint
            var_scope = ckpt_scope = '/'

        def scaffold_fn():
            """Loads pretrained model through scaffold function."""
            logging.info('restore variables from %s', checkpoint)

            var_map = utils.get_ckpt_var_map(
                ckpt_path=checkpoint,
                ckpt_scope=ckpt_scope,
                var_scope=var_scope,
                skip_mismatch=params['skip_mismatch'])

            tf.train.init_from_checkpoint(checkpoint, var_map)
            return tf.train.Scaffold()
    elif mode == tf.estimator.ModeKeys.EVAL and moving_average_decay:

        def scaffold_fn():
            """Load moving average variables for eval."""
            logging.info('Load EMA vars with ema_decay=%f',
                         moving_average_decay)
            restore_vars_dict = ema.variables_to_restore(ema_vars)
            saver = tf.train.Saver(restore_vars_dict)
            return tf.train.Scaffold(saver=saver)
    else:
        scaffold_fn = None

    if params['strategy'] != 'tpu':
        # Profile every 1K steps.
        if params.get('profile', False):
            profile_hook = tf.estimator.ProfilerHook(
                save_steps=1000,
                output_dir=params['model_dir'],
                show_memory=True)
            training_hooks.append(profile_hook)

            # Report memory allocation if OOM
            class OomReportingHook(tf.estimator.SessionRunHook):
                def before_run(self, run_context):
                    return tf.estimator.SessionRunArgs(
                        fetches=[],
                        options=tf.RunOptions(
                            report_tensor_allocations_upon_oom=True))

            training_hooks.append(OomReportingHook())

        logging_hook = tf.estimator.LoggingTensorHook(
            {
                'step': global_step,
                'det_loss': det_loss,
                'cls_loss': cls_loss,
                'box_loss': box_loss,
            },
            every_n_iter=params.get('iterations_per_loop', 100),
        )
        training_hooks.append(logging_hook)

        if params["nvgpu_logging"]:
            # Best-effort GPU memory logging: any failure is logged and
            # the run continues without the extra hook.
            try:
                from third_party import nvgpu  # pylint: disable=g-import-not-at-top
                from functools import reduce  # pylint: disable=g-import-not-at-top

                def get_nested_value(d, path):
                    """Walk nested dict `d` along the keys in `path`."""
                    return reduce(dict.get, path, d)

                def nvgpu_gpu_info(inp):
                    """Look up a comma-separated key path in nvgpu.gpu_info()."""
                    inp = inp.decode("utf-8")
                    inp = inp.split(",")
                    inp = [x.strip() for x in inp]
                    value = get_nested_value(nvgpu.gpu_info(), inp)
                    # Builtin str: `np.str` was only an alias of it and has
                    # been removed from recent NumPy releases.
                    return str(value)

                def commonsize(inp):
                    """Convert a '<number> <unit>' string to a MiB float."""
                    const_sizes = {
                        'B': 1,
                        'KB': 1e3,
                        'MB': 1e6,
                        'GB': 1e9,
                        'TB': 1e12,
                        'PB': 1e15,
                        'KiB': 1024,
                        'MiB': 1048576,
                        'GiB': 1073741824
                    }
                    inp = inp.split(" ")
                    # convert all to MiB
                    if inp[1] != 'MiB':
                        inp_ = float(
                            inp[0]) * (const_sizes[inp[1]] / 1048576.0)
                    else:
                        inp_ = float(inp[0])

                    return inp_

                def formatter_log(tensors):
                    """Format the output."""
                    mem_used = tensors["memory used"].decode("utf-8")
                    mem_total = tensors["memory total"].decode("utf-8")
                    mem_util = commonsize(mem_used) / commonsize(mem_total)
                    logstring = "GPU memory used: {} = {:.1%} of total GPU memory: {}".format(
                        mem_used, mem_util, mem_total)
                    return logstring

                mem_used = tf.py_func(nvgpu_gpu_info,
                                      ['gpu, fb_memory_usage, used'],
                                      [tf.string])[0]
                mem_total = tf.py_func(nvgpu_gpu_info,
                                       ['gpu, fb_memory_usage, total'],
                                       [tf.string])[0]

                logging_hook3 = tf.estimator.LoggingTensorHook(
                    tensors={
                        "memory used": mem_used,
                        "memory total": mem_total,
                    },
                    every_n_iter=params.get('iterations_per_loop', 100),
                    formatter=formatter_log,
                )
                training_hooks.append(logging_hook3)
            except Exception:  # pylint: disable=broad-except
                # Catch Exception rather than everything so KeyboardInterrupt
                # and SystemExit still propagate.
                logging.error("nvgpu error: nvidia-smi format not recognized")

    if params['strategy'] == 'tpu':
        return tf.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            train_op=train_op,
            eval_metrics=eval_metrics,
            host_call=utils.get_tpu_host_call(global_step, params),
            scaffold_fn=scaffold_fn,
            training_hooks=training_hooks)
    else:
        eval_metric_ops = eval_metrics[0](
            **eval_metrics[1]) if eval_metrics else None
        # Return value is unused on this path; presumably called for its
        # side effects -- TODO confirm against utils.get_tpu_host_call.
        utils.get_tpu_host_call(global_step, params)
        # scaffold_fn is None when there is neither a checkpoint to restore
        # for training nor EMA variables to load for eval.
        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=total_loss,
            train_op=train_op,
            eval_metric_ops=eval_metric_ops,
            scaffold=scaffold_fn() if scaffold_fn else None,
            training_hooks=training_hooks)
Exemple #8
0
def model_fn(features, labels, mode, params):
    """The model_fn to be used with TPUEstimator.

  Builds an EffNetV2 classifier, its loss (cross entropy plus L2 weight
  decay), and, depending on `mode`, the training op or evaluation metrics.

  Args:
    features: A dict of `Tensor` of batched images and other features, or the
      image `Tensor` itself. For a dict, the 'image' entry is used.
    labels: a Tensor or a dict of Tensor representing the batched labels. For
      a dict, the 'label' entry is used.
    mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
    params: `dict` of parameters passed to the model from the TPUEstimator,
      `params['batch_size']` is always provided and should be used as the
      effective batch size. This function also reads `params['config']`,
      `params['image_size']` and `params['steps_per_epoch']`.

  Returns:
    A `TPUEstimatorSpec` for the model
  """
    logging.info('params=%s', params)
    images = features['image'] if isinstance(features, dict) else features
    labels = labels['label'] if isinstance(labels, dict) else labels
    config = params['config']
    image_size = params['image_size']
    utils.scalar('model/resolution', image_size)

    if config.model.data_format == 'channels_first':
        images = tf.transpose(images, [0, 3, 1, 2])

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    has_moving_average_decay = (config.train.ema_decay > 0)
    if FLAGS.use_tpu and not config.model.bn_type:
        config.model.bn_type = 'tpu_bn'
    # This is essential, if using a keras-derived model.
    tf.keras.backend.set_learning_phase(is_training)

    def build_model(in_images):
        """Build model using the model_name given through the command line."""
        config.model.num_classes = config.data.num_classes
        model = effnetv2_model.EffNetV2Model(config.model.model_name,
                                             config.model)
        logits = model(in_images, training=is_training)[0]
        return logits

    # Snapshot params/flops before building so the model's own contribution
    # can be computed as a difference afterwards.
    pre_num_params, pre_num_flops = utils.num_params_flops(
        readable_format=True)

    if config.runtime.mixed_precision:
        precision = 'mixed_bfloat16' if FLAGS.use_tpu else 'mixed_float16'
        logits = utils.build_model_with_precision(precision, build_model,
                                                  images, is_training)
        # Losses and metrics below are computed in float32.
        logits = tf.cast(logits, tf.float32)
    else:
        logits = build_model(images)

    num_params, num_flops = utils.num_params_flops(readable_format=True)
    num_params = num_params - pre_num_params
    num_flops = (num_flops - pre_num_flops) / params['batch_size']
    logging.info('backbone params/flops = %.4f M / %.4f B', num_params,
                 num_flops)
    utils.scalar('model/params', num_params)
    utils.scalar('model/flops', num_flops)

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    if config.train.loss_type == 'sigmoid':
        cross_entropy = tf.losses.sigmoid_cross_entropy(
            multi_class_labels=tf.cast(labels, dtype=logits.dtype),
            logits=logits,
            label_smoothing=config.train.label_smoothing)
    elif config.train.loss_type == 'custom':
        xent = tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.cast(
            labels, dtype=logits.dtype),
                                                       logits=logits)
        cross_entropy = tf.reduce_mean(tf.reduce_sum(xent, axis=-1))
    else:
        if config.data.multiclass:
            logging.info('use multi-class loss: %s', config.data.multiclass)
            # Normalize each label row so that it sums to one.
            labels /= tf.reshape(tf.reduce_sum(labels, axis=1), (-1, 1))
        cross_entropy = tf.losses.softmax_cross_entropy(
            onehot_labels=labels,
            logits=logits,
            label_smoothing=config.train.label_smoothing)

    train_steps = max(config.train.min_steps,
                      config.train.epochs * params['steps_per_epoch'])
    global_step = tf.train.get_global_step()
    # Weight decay grows linearly with training progress:
    # (1 + weight_decay_inc * step / train_steps) * weight_decay.
    weight_decay_inc = config.train.weight_decay_inc * (
        tf.cast(global_step, tf.float32) / tf.cast(train_steps, tf.float32))
    weight_decay = (1 + weight_decay_inc) * config.train.weight_decay
    utils.scalar('train/weight_decay', weight_decay)
    # Add weight decay to the loss for non-batch-normalization variables.
    matcher = re.compile(config.train.weight_decay_exclude)
    l2loss = weight_decay * tf.add_n([
        tf.nn.l2_loss(v)
        for v in tf.trainable_variables() if not matcher.match(v.name)
    ])
    loss = cross_entropy + l2loss
    utils.scalar('loss/l2reg', l2loss)
    utils.scalar('loss/xent', cross_entropy)

    # `ema` / `ema_vars` only exist when EMA is enabled; every later use is
    # guarded by has_moving_average_decay.
    if has_moving_average_decay:
        ema = tf.train.ExponentialMovingAverage(decay=config.train.ema_decay,
                                                num_updates=global_step)
        ema_vars = utils.get_ema_vars()

    host_call = None
    restore_vars_dict = None
    if is_training:
        # Compute the current epoch and associated learning rate from global_step.
        current_epoch = (tf.cast(global_step, tf.float32) /
                         params['steps_per_epoch'])
        utils.scalar('train/epoch', current_epoch)

        # Learning rates are scaled linearly with batch size relative to 256.
        scaled_lr = config.train.lr_base * (config.train.batch_size / 256.0)
        scaled_lr_min = config.train.lr_min * (config.train.batch_size / 256.0)
        learning_rate = utils.WarmupLearningRateSchedule(
            scaled_lr,
            steps_per_epoch=params['steps_per_epoch'],
            decay_epochs=config.train.lr_decay_epoch,
            warmup_epochs=config.train.lr_warmup_epoch,
            decay_factor=config.train.lr_decay_factor,
            lr_decay_type=config.train.lr_sched,
            total_steps=train_steps,
            minimal_lr=scaled_lr_min)(global_step)
        utils.scalar('train/lr', learning_rate)
        optimizer = utils.build_optimizer(
            learning_rate, optimizer_name=config.train.optimizer)
        if FLAGS.use_tpu:
            # When using TPU, wrap the optimizer with CrossShardOptimizer which
            # handles synchronization details between different TPU cores. To the
            # user, this should look like regular synchronous training.
            optimizer = tf.tpu.CrossShardOptimizer(optimizer)

        # filter trainable variables if needed.
        var_list = tf.trainable_variables()
        if config.train.varsexp:
            vars2 = [
                v for v in var_list if re.match(config.train.varsexp, v.name)
            ]
            if len(vars2) == len(var_list):
                # Report the pattern that was actually applied above.
                logging.warning('%s has no match.', config.train.varsexp)
            logging.info('Filter variables: orig=%d, final=%d, delta=%d',
                         len(var_list), len(vars2),
                         len(var_list) - len(vars2))
            var_list = vars2

        # Batch norm requires update_ops to be added as a train_op dependency.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if config.train.gclip and is_training:
            logging.info('clip gradients norm by %f', config.train.gclip)
            grads_and_vars = optimizer.compute_gradients(loss, var_list)
            with tf.name_scope('gclip'):
                grads = [gv[0] for gv in grads_and_vars]
                tvars = [gv[1] for gv in grads_and_vars]
                utils.scalar('train/gnorm', tf.linalg.global_norm(grads))
                # Gradients may be None (see the clipping below), so skip
                # those entries when computing the per-variable max norm.
                utils.scalar(
                    'train/gnormmax',
                    tf.math.reduce_max(
                        [tf.norm(g) for g in grads if g is not None]))
                # First clip each variable's norm, then clip global norm.
                clip_norm = abs(config.train.gclip)
                clipped_grads = [
                    tf.clip_by_norm(g, clip_norm) if g is not None else None
                    for g in grads
                ]
                clipped_grads, _ = tf.clip_by_global_norm(
                    clipped_grads, clip_norm)
                grads_and_vars = list(zip(clipped_grads, tvars))

            with tf.control_dependencies(update_ops):
                train_op = optimizer.apply_gradients(grads_and_vars,
                                                     global_step)
        else:
            with tf.control_dependencies(update_ops):
                train_op = optimizer.minimize(loss,
                                              global_step,
                                              var_list=var_list)

        if has_moving_average_decay:
            # The EMA update runs after the optimizer step and becomes the
            # op that the estimator executes.
            with tf.control_dependencies([train_op]):
                train_op = ema.apply(ema_vars)

        if not config.runtime.skip_host_call:
            host_call = utils.get_tpu_host_call(
                global_step, FLAGS.model_dir,
                config.runtime.iterations_per_loop)
    else:
        train_op = None
        if has_moving_average_decay:
            # Load moving average variables for eval.
            restore_vars_dict = ema.variables_to_restore(ema_vars)

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(labels, logits):
            """Evaluation metric function.

      Evaluates accuracy.

      This function is executed on the CPU and should not directly reference
      any Tensors in the rest of the `model_fn`. To pass Tensors from the model
      to the `metric_fn`, provide as part of the `eval_metrics`. See
      https://www.tensorflow.org/api_docs/python/tf/estimator/tpu/TPUEstimatorSpec
      for more information.

      Arguments should match the list of `Tensor` objects passed as the second
      element in the tuple passed to `eval_metrics`.

      Args:
        labels: `Tensor` with shape `[batch, num_classes]`.
        logits: `Tensor` with shape `[batch, num_classes]`.

      Returns:
        A dict of the metrics to return from evaluation.
      """
            metrics = {}
            if config.data.multiclass:
                metrics['eval/global_ap'] = tf.metrics.auc(
                    labels,
                    tf.nn.sigmoid(logits),
                    curve='PR',
                    num_thresholds=200,
                    summation_method='careful_interpolation',
                    name='global_ap')

                # Convert labels to set: be careful, tf.metrics.xx_at_k are horrible.
                # Positive positions keep their own class index; all other
                # positions are filled with the argmax class index.
                labels = tf.cast(labels, dtype=tf.int64)
                label_to_repeat = tf.expand_dims(tf.argmax(labels, axis=-1),
                                                 axis=-1)
                all_labels_set = tf.range(0, labels.shape[-1], dtype=tf.int64)
                all_labels_set = tf.expand_dims(all_labels_set, axis=0)
                labels_set = labels * all_labels_set + (
                    1 - labels) * label_to_repeat

                metrics['eval/precision@1'] = tf.metrics.precision_at_k(
                    labels_set, logits, k=1)
                metrics['eval/recall@1'] = tf.metrics.recall_at_k(labels_set,
                                                                  logits,
                                                                  k=1)
                metrics['eval/precision@5'] = tf.metrics.precision_at_k(
                    labels_set, logits, k=5)
                metrics['eval/recall@5'] = tf.metrics.recall_at_k(labels_set,
                                                                  logits,
                                                                  k=5)

            # always add accuracy.
            labels = tf.argmax(labels, axis=1)
            predictions = tf.argmax(logits, axis=1)
            metrics['eval/acc_top1'] = tf.metrics.accuracy(labels, predictions)
            in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
            metrics['eval/acc_top5'] = tf.metrics.mean(in_top_5)
            metrics['model/resolution'] = tf.metrics.mean(image_size)
            metrics['model/flops'] = tf.metrics.mean(num_flops)
            metrics['model/params'] = tf.metrics.mean(num_params)
            return metrics

        eval_metrics = (metric_fn, [labels, logits])

    if has_moving_average_decay and not is_training:

        def scaffold_fn():  # read ema for eval jobs.
            saver = tf.train.Saver(restore_vars_dict)
            return tf.train.Scaffold(saver=saver)
    elif config.train.ft_init_ckpt and is_training:

        def scaffold_fn():
            """Initialize variables from the fine-tuning checkpoint."""
            logging.info('restore variables from %s',
                         config.train.ft_init_ckpt)
            var_map = utils.get_ckpt_var_map(
                ckpt_path=config.train.ft_init_ckpt,
                skip_mismatch=True,
                init_ema=config.train.ft_init_ema)
            tf.train.init_from_checkpoint(config.train.ft_init_ckpt, var_map)
            return tf.train.Scaffold()
    else:
        scaffold_fn = None

    return tf.estimator.tpu.TPUEstimatorSpec(mode=mode,
                                             loss=loss,
                                             train_op=train_op,
                                             host_call=host_call,
                                             eval_metrics=eval_metrics,
                                             scaffold_fn=scaffold_fn)