Ejemplo n.º 1
0
    def model_fn(features, labels, mode, config, params):
        """Build the TPUEstimatorSpec for the requested estimator mode.

        Settings such as ``model_params``, ``use_tpu``, ``decode_keys``,
        ``include_features_in_predictions``, ``train_init_checkpoint``,
        ``train_warmup_steps`` and ``model_dir`` are read from the enclosing
        scope.

        Args:
            features: Input features passed to the model / prediction function.
            labels: Unused; deleted immediately.
            mode: A ``tf.estimator.ModeKeys`` value.
            config: Unused; deleted immediately.
            params: Unused; deleted immediately.

        Returns:
            A ``tpu_estimator.TPUEstimatorSpec`` for PREDICT, EVAL or TRAIN.
            Any other mode falls through and yields ``None``.
        """
        del labels, config, params

        # Default initializer for the current variable scope.
        default_initializer = tf.variance_scaling_initializer(
            1.0, mode="fan_avg", distribution="uniform")
        tf.get_variable_scope().set_initializer(default_initializer)

        if mode == tf.estimator.ModeKeys.PREDICT:
            predictions = model_params.estimator_prediction_fn(features)
            if include_features_in_predictions:
                predictions.update(features)

            if decode_keys:

                def decode_host_call(tensor_dict):
                    """Store decoded strings for `decode_keys` in predictions."""
                    for name in decode_keys:
                        predictions[name] = public_parsing_ops.decode(
                            tensor_dict[name],
                            model_params.vocab_filename,
                            model_params.encoder_type)
                    return tensor_dict

                # Decode the raw ids into strings outside the TPU computation.
                contrib_tpu.outside_compilation(decode_host_call, predictions)

            return tpu_estimator.TPUEstimatorSpec(
                mode=mode, predictions=predictions)

        is_training = mode == tf.estimator.ModeKeys.TRAIN
        if not (use_tpu and model_params.use_bfloat16):
            loss, outputs = model_params.model()(features, is_training)
        else:
            with contrib_tpu.bfloat16_scope():
                loss, outputs = model_params.model()(features, is_training)

        # TPU requires that all outputs have a batch dimension and does not
        # handle scalars, so tile every scalar to a 1-D vector.
        outputs = _tile_scalar_to_batch_size(outputs, model_params.batch_size)

        if mode == tf.estimator.ModeKeys.EVAL:
            metrics = model_params.estimator_eval_metrics_fn(features, outputs)
            return tpu_estimator.TPUEstimatorSpec(
                mode=mode, loss=loss, eval_metrics=metrics)

        if mode != tf.estimator.ModeKeys.TRAIN:
            # Unknown mode: no spec is produced.
            return None

        base_lr = model_params.learning_rate
        global_step = tf.train.get_global_step()
        # Inverse-square-root decay. Clamping the step to 10000 keeps the rate
        # at `base_lr` for the first 10000 steps (base_lr / 0.01 * 0.01).
        clamped_step = tf.maximum(tf.to_float(global_step), 10000)
        learning_rate = base_lr / 0.01 * tf.rsqrt(clamped_step)
        if train_init_checkpoint:
            # When starting from a checkpoint, ramp up linearly over
            # `train_warmup_steps` and never exceed the decay schedule.
            warmup_lr = (tf.to_float(global_step + 1) / train_warmup_steps *
                         base_lr)
            learning_rate = tf.minimum(warmup_lr, learning_rate)

        optimizer = adafactor.AdafactorOptimizer(
            learning_rate=learning_rate,
            decay_rate=adafactor.adafactor_decay_rate_pow(0.8),
            beta1=0.0)
        if use_tpu:
            optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)
        train_op = optimizer.minimize(loss, global_step=global_step)

        scaffold_fn = _load_vars_from_checkpoint(use_tpu,
                                                 train_init_checkpoint)
        host_call = add_scalars_to_summary(model_dir,
                                           {"learning_rate": learning_rate})
        return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                              loss=loss,
                                              train_op=train_op,
                                              scaffold_fn=scaffold_fn,
                                              host_call=host_call)
Ejemplo n.º 2
0
def resnet_model_fn_w_pruning(features, labels, mode, params):
    """The model_fn for ResNet-50 with pruning.

    Args:
        features: A float32 batch of images, or a dict holding that batch
            under the key 'feature'.
        labels: A int32 batch of labels.
        mode: Specifies whether training, evaluation or prediction.
        params: Dictionary of parameters passed to the model. Reads
            'pruning_method', 'use_tpu', 'log_alpha_threshold' and, outside of
            PREDICT mode, 'output_dir'.

    Returns:
        A tf_estimator.EstimatorSpec in PREDICT mode, otherwise a
        contrib_tpu.TPUEstimatorSpec for the model.

    Raises:
        ValueError: If FLAGS.precision is neither 'bfloat16' nor 'float32', or
            if FLAGS.pruning_method is 'scratch' and FLAGS.load_mask_dir is
            not set.
    """

    width = 1. if FLAGS.width <= 0 else FLAGS.width
    if isinstance(features, dict):
        features = features['feature']

    if FLAGS.data_format == 'channels_first':
        assert not FLAGS.transpose_input  # channels_first only for GPU
        features = tf.transpose(features, [0, 3, 1, 2])

    if FLAGS.transpose_input and mode != tf_estimator.ModeKeys.PREDICT:
        features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

    # Normalize the image to zero mean and unit variance.
    features -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=features.dtype)
    features /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=features.dtype)

    pruning_method = params['pruning_method']
    use_tpu = params['use_tpu']
    log_alpha_threshold = params['log_alpha_threshold']

    def build_network():
        """Construct the network in the graph."""
        # We need to construct the model with the pruning masks, but they
        # won't be updated if we're doing scratch training.
        model_pruning_method = pruning_method
        if pruning_method == 'scratch':
            model_pruning_method = 'threshold'

        network = resnet_model.resnet_v1_(
            resnet_depth=FLAGS.resnet_depth,
            num_classes=FLAGS.num_label_classes,
            pruning_method=model_pruning_method,
            init_method=FLAGS.init_method,
            width=width,
            prune_first_layer=FLAGS.prune_first_layer,
            prune_last_layer=FLAGS.prune_last_layer,
            data_format=FLAGS.data_format,
            end_sparsity=FLAGS.end_sparsity,
            clip_log_alpha=FLAGS.clip_log_alpha,
            log_alpha_threshold=log_alpha_threshold,
            weight_decay=FLAGS.weight_decay)
        return network(inputs=features,
                       is_training=(mode == tf_estimator.ModeKeys.TRAIN))

    if FLAGS.precision == 'bfloat16':
        with contrib_tpu.bfloat16_scope():
            logits = build_network()
        logits = tf.cast(logits, tf.float32)
    elif FLAGS.precision == 'float32':
        logits = build_network()
    else:
        # Fail here with a clear message instead of hitting an unbound
        # `logits` further down.
        raise ValueError('Unsupported precision: {!r}'.format(FLAGS.precision))

    if mode == tf_estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf_estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf_estimator.export.PredictOutput(predictions)
            })

    output_dir = params['output_dir']

    # Scratch training needs a mask directory: both the label smoothing value
    # and the scaffold below are derived from it, so validate it up front.
    # NOTE(review): this reads FLAGS.pruning_method while the network is built
    # from params['pruning_method'] -- confirm the two always agree.
    scratch_training = FLAGS.pruning_method == 'scratch'
    if scratch_training and not FLAGS.load_mask_dir:
        raise ValueError('Must supply a mask directory to use scratch method')

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    one_hot_labels = tf.one_hot(labels, FLAGS.num_label_classes)

    # Make sure we reuse the same label smoothing parameter if we're doing
    # scratch / lottery ticket experiments: it is read from the 16th
    # '/'-separated component of the mask directory path.
    label_smoothing = FLAGS.label_smoothing
    if scratch_training:
        label_smoothing = float(FLAGS.load_mask_dir.split('/')[15])
    loss = tf.losses.softmax_cross_entropy(logits=logits,
                                           onehot_labels=one_hot_labels,
                                           label_smoothing=label_smoothing)
    # Add regularization loss term
    loss += tf.losses.get_regularization_loss()

    # Pruning methods that contribute an extra sparsity-inducing loss term.
    sparsity_loss_fns = {
        'variational_dropout': utils.variational_dropout_dkl_loss,
        'l0_regularization': utils.l0_regularization_loss,
    }
    sparsity_loss_fn = sparsity_loss_fns.get(pruning_method)
    if sparsity_loss_fn is not None:
        reg_loss = sparsity_loss_fn(
            reg_scalar=FLAGS.reg_scalar,
            start_reg_ramp_up=FLAGS.sparsity_begin_step,
            end_reg_ramp_up=FLAGS.sparsity_end_step,
            warm_up=FLAGS.is_warm_up,
            use_tpu=use_tpu)
        loss += reg_loss
        tf.losses.add_loss(reg_loss, loss_collection=tf.GraphKeys.LOSSES)

    host_call, train_op = None, None
    if mode == tf_estimator.ModeKeys.TRAIN:
        host_call, train_op = train_function(pruning_method, loss, output_dir,
                                             use_tpu)

    eval_metrics = None
    if mode == tf_estimator.ModeKeys.EVAL:

        def metric_fn(labels, logits):
            """Calculate eval metrics."""
            logging.info('In metric function')
            eval_metrics = {}
            predictions = tf.cast(tf.argmax(logits, axis=1), tf.int32)
            in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
            eval_metrics['top_5_eval_accuracy'] = tf.metrics.mean(in_top_5)
            eval_metrics['eval_accuracy'] = tf.metrics.accuracy(
                labels=labels, predictions=predictions)

            return eval_metrics

        def vd_metric_fn(labels, logits, global_sparsity):
            """Calculate eval metrics plus the mean global sparsity."""
            eval_metrics = metric_fn(labels, logits)
            eval_metrics['global_sparsity'] = tf.metrics.mean(global_sparsity)
            return eval_metrics

        tensors = [labels, logits]
        metric_function = metric_fn

        if FLAGS.pruning_method == 'variational_dropout':
            # The sparsity value is multiplied by a [batch_size, 1] tensor of
            # ones before being handed to the metric function.
            batch_size = labels.shape[0]
            ones = tf.ones([batch_size, 1])
            mask_metrics = utils.add_vd_pruning_summaries(
                threshold=FLAGS.log_alpha_threshold)
            tensors.append(mask_metrics['global_sparsity'] * ones)
            metric_function = vd_metric_fn

        eval_metrics = (metric_function, tensors)

    def restore_from_ckpt(ckpt_path, masks):
        """Initialize graph variables from an existing checkpoint.

        Args:
            ckpt_path: Checkpoint to read from.
            masks: True to restore only variables whose checkpoint name ends
                with 'mask', False to restore only the other variables.
        """
        kind = 'mask' if masks else 'parameter'
        model_dir = FLAGS.output_dir
        if model_dir and tf.train.latest_checkpoint(model_dir) is not None:
            # Never overwrite the state of a run that has already started.
            tf.logging.info(
                'Training already started on this model, not loading %ss '
                'from previously trained model', kind)
            return

        reader = tf.train.NewCheckpointReader(ckpt_path)
        ckpt_names = {
            name for name in reader.get_variable_to_shape_map()
            if name.endswith('mask') == masks
        }

        variable_map = {}
        for var in tf.global_variables():
            var_name = var.name.split(':')[0]
            if var_name in ckpt_names:
                tf.logging.info('Loading %s variable from checkpoint: %s',
                                kind, var_name)
                variable_map[var_name] = var
            elif ('mask' in var_name) == masks:
                tf.logging.info(
                    'Cannot find %s variable in checkpoint, skipping: %s',
                    kind, var_name)
        tf.train.init_from_checkpoint(ckpt_path, variable_map)

    # A custom scaffold function enables initializing the mask (and optionally
    # the parameters) from an already trained checkpoint.
    scaffold_fn = None
    if scratch_training:

        def scaffold_fn():
            """Restore masks and, if requested, the initial parameter values."""
            restore_from_ckpt(FLAGS.load_mask_dir, masks=True)
            if FLAGS.initial_value_checkpoint:
                restore_from_ckpt(FLAGS.initial_value_checkpoint, masks=False)
            return tf.train.Scaffold()

    return contrib_tpu.TPUEstimatorSpec(mode=mode,
                                        loss=loss,
                                        train_op=train_op,
                                        host_call=host_call,
                                        eval_metrics=eval_metrics,
                                        scaffold_fn=scaffold_fn)
Ejemplo n.º 3
0
def resnet_model_fn_w_pruning(features, labels, mode, params):
    """The model_fn for ResNet-50 (and other architectures) with pruning.

    Args:
        features: A float32 batch of images, or a dict holding that batch
            under the key 'feature'.
        labels: A int32 batch of labels.
        mode: Specifies whether training, evaluation or prediction.
        params: Dictionary of parameters passed to the model. Reads
            'training_method', 'use_tpu' and, outside of PREDICT mode,
            'output_dir'.

    Returns:
        A tf.estimator.EstimatorSpec in PREDICT mode, otherwise a
        contrib_tpu.TPUEstimatorSpec for the model.

    Raises:
        ValueError: If FLAGS.model_architecture is not a known architecture or
            FLAGS.precision is neither 'bfloat16' nor 'float32'.
    """

    width = 1. if FLAGS.width <= 0 else FLAGS.width

    if isinstance(features, dict):
        features = features['feature']

    if FLAGS.data_format == 'channels_first':
        assert not FLAGS.transpose_input  # channels_first only for GPU
        features = tf.transpose(features, [0, 3, 1, 2])

    if FLAGS.transpose_input and mode != tf.estimator.ModeKeys.PREDICT:
        features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

    # Normalize the image to zero mean and unit variance.
    features -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=features.dtype)
    features /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=features.dtype)

    training_method = params['training_method']
    use_tpu = params['use_tpu']

    def build_network():
        """Construct the network in the graph."""
        architecture = FLAGS.model_architecture
        if architecture == 'mobilenet_v2':
            network_func = functools.partial(
                mobilenetv2_model.mobilenet_v2,
                expansion_factor=FLAGS.expansion_factor)
        elif architecture == 'mobilenet_v1':
            network_func = mobilenetv1_model.mobilenet_v1
        elif architecture == 'resnet':
            # A first-layer sparsity of zero means that layer is not pruned.
            prune_first_layer = FLAGS.first_layer_sparsity != 0.
            network_func = functools.partial(
                resnet_model.resnet_v1_,
                resnet_depth=FLAGS.resnet_depth,
                init_method=FLAGS.init_method,
                end_sparsity=FLAGS.end_sparsity,
                prune_first_layer=prune_first_layer)
        elif architecture.startswith('vgg'):
            network_func = functools.partial(vgg.vgg,
                                             vgg_type=architecture,
                                             init_method=FLAGS.init_method,
                                             end_sparsity=FLAGS.end_sparsity)
        else:
            # Report the flag that was actually inspected above.
            raise ValueError('Unknown architecture ' + architecture)

        # A last-layer sparsity of zero means that layer is not pruned.
        prune_last_layer = FLAGS.last_layer_sparsity != 0.
        network = network_func(
            num_classes=FLAGS.num_label_classes,
            # TODO remove the pruning_method option.
            pruning_method='threshold',
            width=width,
            prune_last_layer=prune_last_layer,
            data_format=FLAGS.data_format,
            weight_decay=FLAGS.weight_decay)

        # FLAGS.use_batch_statistics forces training mode regardless of `mode`.
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)
        if FLAGS.use_batch_statistics:
            is_training = True
        return network(inputs=features, is_training=is_training)

    if FLAGS.precision == 'bfloat16':
        with contrib_tpu.bfloat16_scope():
            logits = build_network()
        logits = tf.cast(logits, tf.float32)
    elif FLAGS.precision == 'float32':
        logits = build_network()
    else:
        # Fail here with a clear message instead of hitting an unbound
        # `logits` further down.
        raise ValueError('Unsupported precision: {!r}'.format(FLAGS.precision))

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    output_dir = params['output_dir']
    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    one_hot_labels = tf.one_hot(labels, FLAGS.num_label_classes)

    # Make sure we reuse the same label smoothing parameter if we're doing
    # scratch / lottery ticket experiments: after dropping '/scratch' from the
    # mask directory, the value is read from the 16th '/'-separated component.
    label_smoothing = FLAGS.label_smoothing
    if FLAGS.training_method == 'scratch' and FLAGS.load_mask_dir:
        scratch_stripped = FLAGS.load_mask_dir.replace('/scratch', '')
        label_smoothing = float(scratch_stripped.split('/')[15])
        tf.logging.info('LABEL SMOOTHING USED: %.2f', label_smoothing)
    cross_loss = tf.losses.softmax_cross_entropy(
        logits=logits,
        onehot_labels=one_hot_labels,
        label_smoothing=label_smoothing)
    # Add regularization loss term
    reg_loss = tf.losses.get_regularization_loss()
    loss = cross_loss + reg_loss

    host_call, train_op = None, None
    if mode == tf.estimator.ModeKeys.TRAIN:
        host_call, train_op = train_function(training_method, loss, cross_loss,
                                             reg_loss, output_dir, use_tpu)

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(labels, logits, cross_loss, reg_loss):
            """Calculate eval metrics."""
            logging.info('In metric function')
            eval_metrics = {}
            predictions = tf.cast(tf.argmax(logits, axis=1), tf.int32)
            in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
            eval_metrics['top_5_eval_accuracy'] = tf.metrics.mean(in_top_5)
            eval_metrics['cross_loss'] = tf.metrics.mean(cross_loss)
            eval_metrics['reg_loss'] = tf.metrics.mean(reg_loss)
            eval_metrics['eval_accuracy'] = tf.metrics.accuracy(
                labels=labels, predictions=predictions)

            # If evaluating once lets also calculate sparsities.
            if FLAGS.mode == 'eval_once':
                sparsity_summaries = utils.mask_summaries(pruning.get_masks())
                # We call mean on a scalar to create tensor, update_op pairs.
                sparsity_summaries = {
                    k: tf.metrics.mean(v)
                    for k, v in sparsity_summaries.items()
                }
                eval_metrics.update(sparsity_summaries)
            return eval_metrics

        # The two losses are broadcast to the shape of `labels` before being
        # passed to the metric function.
        tensors = [
            labels, logits,
            tf.broadcast_to(cross_loss, tf.shape(labels)),
            tf.broadcast_to(reg_loss, tf.shape(labels))
        ]

        eval_metrics = (metric_fn, tensors)

    if (FLAGS.load_mask_dir
            and FLAGS.training_method not in NO_MASK_INIT_METHODS):

        def scaffold_fn():
            """For initialization, passed to the estimator."""
            # Masks come from the mask directory; parameter values optionally
            # come from a separate initial-value checkpoint.
            utils.initialize_parameters_from_ckpt(FLAGS.load_mask_dir,
                                                  FLAGS.output_dir,
                                                  MASK_SUFFIX)
            if FLAGS.initial_value_checkpoint:
                utils.initialize_parameters_from_ckpt(
                    FLAGS.initial_value_checkpoint, FLAGS.output_dir,
                    PARAM_SUFFIXES)
            return tf.train.Scaffold()
    elif (FLAGS.mask_init_method
          and FLAGS.training_method not in NO_MASK_INIT_METHODS):

        def scaffold_fn():
            """For initialization, passed to the estimator."""
            if FLAGS.initial_value_checkpoint:
                utils.initialize_parameters_from_ckpt(
                    FLAGS.initial_value_checkpoint, FLAGS.output_dir,
                    PARAM_SUFFIXES)
            all_masks = pruning.get_masks()
            assigner = sparse_utils.get_mask_init_fn(
                all_masks,
                FLAGS.mask_init_method,
                FLAGS.end_sparsity,
                CUSTOM_SPARSITY_MAP,
                erk_power_scale=FLAGS.erk_power_scale)

            def init_fn(scaffold, session):
                """Run the mask assignment once the session is available."""
                del scaffold  # Unused.
                session.run(assigner)

            return tf.train.Scaffold(init_fn=init_fn)
    else:
        assert FLAGS.training_method in NO_MASK_INIT_METHODS
        scaffold_fn = None
        tf.logging.info('No mask is set, starting dense.')

    return contrib_tpu.TPUEstimatorSpec(mode=mode,
                                        loss=loss,
                                        train_op=train_op,
                                        host_call=host_call,
                                        eval_metrics=eval_metrics,
                                        scaffold_fn=scaffold_fn)
Ejemplo n.º 4
0
    def model_fn(features, labels, mode, params=None):
        """Constructs the object detection model.

    Args:
      features: Dictionary of feature tensors, returned from `input_fn`.
      labels: Dictionary of groundtruth tensors if mode is TRAIN or EVAL,
        otherwise None.
      mode: Mode key from tf.estimator.ModeKeys.
      params: Parameter dictionary passed from the estimator.

    Returns:
      An `EstimatorSpec` that encapsulates the model and its serving
        configurations.
    """
        params = params or {}
        total_loss, train_op, detections, export_outputs = None, None, None, None
        is_training = mode == tf.estimator.ModeKeys.TRAIN

        # Make sure to set the Keras learning phase. True during training,
        # False for inference.
        tf.keras.backend.set_learning_phase(is_training)
        # Set policy for mixed-precision training with Keras-based models.
        if use_tpu and train_config.use_bfloat16:
            from tensorflow.python.keras.engine import base_layer_utils  # pylint: disable=g-import-not-at-top
            # Enable v2 behavior, as `mixed_bfloat16` is only supported in TF 2.0.
            base_layer_utils.enable_v2_dtype_behavior()
            tf.compat.v2.keras.mixed_precision.experimental.set_policy(
                'mixed_bfloat16')
        detection_model = detection_model_fn(is_training=is_training,
                                             add_summaries=(not use_tpu))
        scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:
            labels = unstack_batch(labels,
                                   unpad_groundtruth_tensors=train_config.
                                   unpad_groundtruth_tensors)
        elif mode == tf.estimator.ModeKeys.EVAL:
            # For evaling on train data, it is necessary to check whether groundtruth
            # must be unpadded.
            boxes_shape = (labels[fields.InputDataFields.groundtruth_boxes].
                           get_shape().as_list())
            unpad_groundtruth_tensors = boxes_shape[
                1] is not None and not use_tpu
            labels = unstack_batch(
                labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)

        if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
            provide_groundtruth(detection_model, labels)

        preprocessed_images = features[fields.InputDataFields.image]

        side_inputs = detection_model.get_side_inputs(features)

        if use_tpu and train_config.use_bfloat16:
            with contrib_tpu.bfloat16_scope():
                prediction_dict = detection_model.predict(
                    preprocessed_images,
                    features[fields.InputDataFields.true_image_shape],
                    **side_inputs)
                prediction_dict = ops.bfloat16_to_float32_nested(
                    prediction_dict)
        else:
            prediction_dict = detection_model.predict(
                preprocessed_images,
                features[fields.InputDataFields.true_image_shape],
                **side_inputs)

        def postprocess_wrapper(args):
            return detection_model.postprocess(args[0], args[1])

        if mode in (tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT):
            if use_tpu and postprocess_on_cpu:
                detections = contrib_tpu.outside_compilation(
                    postprocess_wrapper,
                    (prediction_dict,
                     features[fields.InputDataFields.true_image_shape]))
            else:
                detections = postprocess_wrapper(
                    (prediction_dict,
                     features[fields.InputDataFields.true_image_shape]))

        if mode == tf.estimator.ModeKeys.TRAIN:
            load_pretrained = hparams.load_pretrained if hparams else False
            if train_config.fine_tune_checkpoint and load_pretrained:
                if not train_config.fine_tune_checkpoint_type:
                    # train_config.from_detection_checkpoint field is deprecated. For
                    # backward compatibility, set train_config.fine_tune_checkpoint_type
                    # based on train_config.from_detection_checkpoint.
                    if train_config.from_detection_checkpoint:
                        train_config.fine_tune_checkpoint_type = 'detection'
                    else:
                        train_config.fine_tune_checkpoint_type = 'classification'
                asg_map = detection_model.restore_map(
                    fine_tune_checkpoint_type=train_config.
                    fine_tune_checkpoint_type,
                    load_all_detection_checkpoint_vars=(
                        train_config.load_all_detection_checkpoint_vars))
                available_var_map = (
                    variables_helper.get_variables_available_in_checkpoint(
                        asg_map,
                        train_config.fine_tune_checkpoint,
                        include_global_step=False))
                if use_tpu:

                    def tpu_scaffold():
                        tf.train.init_from_checkpoint(
                            train_config.fine_tune_checkpoint,
                            available_var_map)
                        return tf.train.Scaffold()

                    scaffold_fn = tpu_scaffold
                else:
                    tf.train.init_from_checkpoint(
                        train_config.fine_tune_checkpoint, available_var_map)

        if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
            if (mode == tf.estimator.ModeKeys.EVAL
                    and eval_config.use_dummy_loss_in_eval):
                total_loss = tf.constant(1.0)
                losses_dict = {'Loss/total_loss': total_loss}
            else:
                losses_dict = detection_model.loss(
                    prediction_dict,
                    features[fields.InputDataFields.true_image_shape])
                losses = [loss_tensor for loss_tensor in losses_dict.values()]
                if train_config.add_regularization_loss:
                    regularization_losses = detection_model.regularization_losses(
                    )
                    if use_tpu and train_config.use_bfloat16:
                        regularization_losses = ops.bfloat16_to_float32_nested(
                            regularization_losses)
                    if regularization_losses:
                        regularization_loss = tf.add_n(
                            regularization_losses, name='regularization_loss')
                        losses.append(regularization_loss)
                        losses_dict[
                            'Loss/regularization_loss'] = regularization_loss
                total_loss = tf.add_n(losses, name='total_loss')
                losses_dict['Loss/total_loss'] = total_loss

            if 'graph_rewriter_config' in configs:
                graph_rewriter_fn = graph_rewriter_builder.build(
                    configs['graph_rewriter_config'], is_training=is_training)
                graph_rewriter_fn()

            # TODO(rathodv): Stop creating optimizer summary vars in EVAL mode once we
            # can write learning rate summaries on TPU without host calls.
            global_step = tf.train.get_or_create_global_step()
            training_optimizer, optimizer_summary_vars = optimizer_builder.build(
                train_config.optimizer)

        if mode == tf.estimator.ModeKeys.TRAIN:
            if use_tpu:
                training_optimizer = contrib_tpu.CrossShardOptimizer(
                    training_optimizer)

            # Optionally freeze some layers by setting their gradients to be zero.
            trainable_variables = None
            include_variables = (train_config.update_trainable_variables
                                 if train_config.update_trainable_variables
                                 else None)
            exclude_variables = (train_config.freeze_variables
                                 if train_config.freeze_variables else None)
            trainable_variables = contrib_framework.filter_variables(
                tf.trainable_variables(),
                include_patterns=include_variables,
                exclude_patterns=exclude_variables)

            clip_gradients_value = None
            if train_config.gradient_clipping_by_norm > 0:
                clip_gradients_value = train_config.gradient_clipping_by_norm

            if not use_tpu:
                for var in optimizer_summary_vars:
                    tf.summary.scalar(var.op.name, var)
            summaries = [] if use_tpu else None
            if train_config.summarize_gradients:
                summaries = [
                    'gradients', 'gradient_norm', 'global_gradient_norm'
                ]
            train_op = contrib_layers.optimize_loss(
                loss=total_loss,
                global_step=global_step,
                learning_rate=None,
                clip_gradients=clip_gradients_value,
                optimizer=training_optimizer,
                update_ops=detection_model.updates(),
                variables=trainable_variables,
                summaries=summaries,
                name='')  # Preventing scope prefix on all variables.

        if mode == tf.estimator.ModeKeys.PREDICT:
            exported_output = exporter_lib.add_output_tensor_nodes(detections)
            export_outputs = {
                tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
                tf.estimator.export.PredictOutput(exported_output)
            }

        eval_metric_ops = None
        scaffold = None
        if mode == tf.estimator.ModeKeys.EVAL:
            class_agnostic = (fields.DetectionResultFields.detection_classes
                              not in detections)
            groundtruth = _prepare_groundtruth_for_eval(
                detection_model, class_agnostic,
                eval_input_config.max_number_of_boxes)
            use_original_images = fields.InputDataFields.original_image in features
            if use_original_images:
                eval_images = features[fields.InputDataFields.original_image]
                true_image_shapes = tf.slice(
                    features[fields.InputDataFields.true_image_shape], [0, 0],
                    [-1, 3])
                original_image_spatial_shapes = features[
                    fields.InputDataFields.original_image_spatial_shape]
            else:
                eval_images = features[fields.InputDataFields.image]
                true_image_shapes = None
                original_image_spatial_shapes = None

            eval_dict = eval_util.result_dict_for_batched_example(
                eval_images,
                features[inputs.HASH_KEY],
                detections,
                groundtruth,
                class_agnostic=class_agnostic,
                scale_to_absolute=True,
                original_image_spatial_shapes=original_image_spatial_shapes,
                true_image_shapes=true_image_shapes)

            if fields.InputDataFields.image_additional_channels in features:
                eval_dict[fields.InputDataFields.
                          image_additional_channels] = features[
                              fields.InputDataFields.image_additional_channels]

            if class_agnostic:
                category_index = label_map_util.create_class_agnostic_category_index(
                )
            else:
                category_index = label_map_util.create_category_index_from_labelmap(
                    eval_input_config.label_map_path)
            vis_metric_ops = None
            if not use_tpu and use_original_images:
                eval_metric_op_vis = vis_utils.VisualizeSingleFrameDetections(
                    category_index,
                    max_examples_to_draw=eval_config.num_visualizations,
                    max_boxes_to_draw=eval_config.max_num_boxes_to_visualize,
                    min_score_thresh=eval_config.min_score_threshold,
                    use_normalized_coordinates=False)
                vis_metric_ops = eval_metric_op_vis.get_estimator_eval_metric_ops(
                    eval_dict)

            # Eval metrics on a single example.
            eval_metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
                eval_config, list(category_index.values()), eval_dict)
            for loss_key, loss_tensor in iter(losses_dict.items()):
                eval_metric_ops[loss_key] = tf.metrics.mean(loss_tensor)
            for var in optimizer_summary_vars:
                eval_metric_ops[var.op.name] = (var, tf.no_op())
            if vis_metric_ops is not None:
                eval_metric_ops.update(vis_metric_ops)
            eval_metric_ops = {str(k): v for k, v in eval_metric_ops.items()}

            if eval_config.use_moving_averages:
                variable_averages = tf.train.ExponentialMovingAverage(0.0)
                variables_to_restore = variable_averages.variables_to_restore()
                keep_checkpoint_every_n_hours = (
                    train_config.keep_checkpoint_every_n_hours)
                saver = tf.train.Saver(
                    variables_to_restore,
                    keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours
                )
                scaffold = tf.train.Scaffold(saver=saver)

        # EVAL executes on CPU, so use regular non-TPU EstimatorSpec.
        if use_tpu and mode != tf.estimator.ModeKeys.EVAL:
            return contrib_tpu.TPUEstimatorSpec(mode=mode,
                                                scaffold_fn=scaffold_fn,
                                                predictions=detections,
                                                loss=total_loss,
                                                train_op=train_op,
                                                eval_metrics=eval_metric_ops,
                                                export_outputs=export_outputs)
        else:
            if scaffold is None:
                keep_checkpoint_every_n_hours = (
                    train_config.keep_checkpoint_every_n_hours)
                saver = tf.train.Saver(
                    sharded=True,
                    keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
                    save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                scaffold = tf.train.Scaffold(saver=saver)
            return tf.estimator.EstimatorSpec(mode=mode,
                                              predictions=detections,
                                              loss=total_loss,
                                              train_op=train_op,
                                              eval_metric_ops=eval_metric_ops,
                                              export_outputs=export_outputs,
                                              scaffold=scaffold)
Ejemplo n.º 5
0
    def model_fn(self, features, labels, mode, config=None, params=None):
        """Estimator model_fn.

        Note, this function overwrites the model_fn of the wrapped t2r_model
        since it replaces specifications with their TPU corresponding calls and
        introduces additional casting conversion after the specification has
        been verified.

        Args:
          features: This is the first item returned from the input_fn and
            parsed by tensorspec_utils.validate_and_pack. A spec_structure
            which fulfills the requirements of
            self.get_feature_specification.
          labels: This is the second item returned from the input_fn and parsed
            by tensorspec_utils.validate_and_pack. A spec_structure which
            fulfills the requirements of self.get_label_specification. Falsy
            values (e.g. None) are passed through without validation.
          mode: (ModeKeys) Specifies if this is training, evaluation or
            prediction.
          config: (Optional tf.estimator.RunConfig or contrib_tpu.RunConfig)
            Will receive what is passed to Estimator in config parameter, or
            the default config (tf.estimator.RunConfig). Allows updating things
            in your model_fn based on configuration such as num_ps_replicas, or
            model_dir.
          params: An optional dict of hyper parameters that will be passed into
            input_fn and model_fn. Keys are names of parameters, values are
            basic python types. There are reserved keys for TPUEstimator,
            including 'batch_size'. In EVAL mode the optional key 'eval_name'
            (default 'eval') selects the sub-directory of config.model_dir the
            gin config is saved to.

        Raises:
          ValueError: If the mode key is not supported, not in [PREDICT, TRAIN,
            EVAL], or if the wrapped model's create_export_outputs_fn or
            model_train_fn returns a value of an unexpected type.

        Returns:
          A TPUEstimatorSpec.
        """

        features = tensorspec_utils.validate_and_pack(
            expected_spec=self.get_feature_specification(mode),
            actual_tensors_or_spec=features,
            ignore_batch=True)
        if labels:
            labels = tensorspec_utils.validate_and_pack(
                expected_spec=self.get_label_specification(mode),
                actual_tensors_or_spec=labels,
                ignore_batch=True)

        # In order to support both TPU and CPU for inference, tensors
        # with dtype=bfloat16 will be casted to float32.
        # Note, despite casting the benefit of bfloat16 are still maintained
        # for TPUs since this operation is a noop on this platform.
        # See http://shortn/_TTg3ZyATRo for rationale.
        if not self._train_in_bfloat16 or (
                mode == tf.estimator.ModeKeys.PREDICT
                or mode == tf.estimator.ModeKeys.EVAL):
            features = tensorspec_utils.cast_bfloat16_to_float32(features)
            if labels is not None:
                labels = tensorspec_utils.cast_bfloat16_to_float32(labels)

        # Only training runs the inference network inside the bfloat16 scope.
        if self._train_in_bfloat16 and mode == tf.estimator.ModeKeys.TRAIN:
            with contrib_tpu.bfloat16_scope():
                inference_outputs = self._t2r_model.inference_network_fn(
                    features, labels, mode, config, params)
        else:
            inference_outputs = self._t2r_model.inference_network_fn(
                features, labels, mode, config, params)

        # The inference network may return (outputs, update_ops) or just
        # outputs.
        update_ops = None
        if isinstance(inference_outputs, tuple):
            update_ops = inference_outputs[1]
            inference_outputs = inference_outputs[0]

        if mode == tf.estimator.ModeKeys.PREDICT:
            model_fn_results = self._t2r_model.create_export_outputs_fn(
                features, inference_outputs, mode, config, params)
            export_outputs = None
            if isinstance(model_fn_results, tuple):
                predictions = model_fn_results[0]
                export_outputs = model_fn_results[1]
            elif isinstance(model_fn_results, dict):
                export_outputs = {}
                # A single prediction is additionally exported as a
                # regression output under its own name.
                if len(model_fn_results) == 1:
                    name, output = list(model_fn_results.items())[0]
                    export_outputs[
                        name] = tf.estimator.export.RegressionOutput(output)
                export_outputs[tf.saved_model.signature_constants.
                               DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
                                   tf.estimator.export.PredictOutput(
                                       model_fn_results))
                predictions = model_fn_results
            else:
                raise ValueError(
                    'The create_export_outputs_fn should return a '
                    'tuple(predictions, export_outputs) or predictions.')

            return contrib_tpu.TPUEstimatorSpec(mode=mode,
                                                predictions=predictions,
                                                export_outputs=export_outputs)

        train_fn_result = self._t2r_model.model_train_fn(
            features, labels, inference_outputs, mode, config, params)
        if isinstance(train_fn_result, tf.Tensor):
            train_loss = train_fn_result
            train_outputs = {}
        elif isinstance(train_fn_result, tuple):
            train_loss = train_fn_result[0]
            train_outputs = train_fn_result[1]
        else:
            raise ValueError('The model_train_fn should return a '
                             'tuple(loss, train_outputs) or loss.')

        if mode == tf.estimator.ModeKeys.TRAIN:
            # Create the tf.train.Optimizer.
            optimizer = get_cross_shard_optimizer(
                self._t2r_model.create_optimizer(params))

            train_op = self._t2r_model.create_train_op(train_loss, optimizer,
                                                       update_ops,
                                                       train_outputs)

            self._t2r_model.add_summaries(features, labels, inference_outputs,
                                          train_loss, train_outputs, mode,
                                          config, params)

            # For TPUs the init has to happen in a scaffold function. Since the model
            # already contains one implementation which is internal to the model
            # this call is simply wrapped.
            # No new variables are allowed to be added, otherwise
            # we would not initialize these variables.
            # Note, this feature is only available for train to bootstrap a model
            # (partially) from a different model. As soon as this checkpoint is
            # written all other modes will use the local checkpoint within
            # model_dir.

            def create_scaffold_fn():
                """Creates a scaffold instance."""
                self._t2r_model.maybe_init_from_checkpoint()
                # Return the value of the property first since it might be changed.
                scaffold_fn = self._t2r_model.scaffold_fn
                scaffold = scaffold_fn()
                # In order to export asynchronously the saver has to be registered
                # in the graph collection. The scaffold function might register a
                # saver already which is why it is checked here and a saver is only
                # added if none has been added.
                if not tf.get_collection(tf.GraphKeys.SAVERS):
                    # TODO(T2R_CONTRIBUTORS): Switch to using gin config for all saver params.
                    keep_checkpoint_every_n_hours = None
                    max_to_keep = None
                    if config is not None:
                        keep_checkpoint_every_n_hours = config.keep_checkpoint_every_n_hours
                        max_to_keep = config.keep_checkpoint_max
                    saver = abstract_model.gin_configurable_saver(
                        keep_checkpoint_every_n_hours=(
                            keep_checkpoint_every_n_hours),
                        max_to_keep=max_to_keep,
                    )
                    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                return scaffold

            training_hooks = []

            # EstimatorSpec has training_chief_hooks, but TPUEstimatorSpec does not,
            # so we have to use training_hooks here and check is_chief.
            if config and config.is_chief:  # pytype: disable=attribute-error
                training_hooks.append(
                    gin_utils.GinConfigSaverHook(config.model_dir,
                                                 summarize_config=True))

            return contrib_tpu.TPUEstimatorSpec(mode=mode,
                                                loss=train_loss,
                                                train_op=train_op,
                                                training_hooks=training_hooks,
                                                scaffold_fn=create_scaffold_fn)

        if mode == tf.estimator.ModeKeys.EVAL:
            self._t2r_model.add_summaries(features, labels, inference_outputs,
                                          train_loss, train_outputs, mode,
                                          config, params)
            eval_metrics = self._t2r_model.model_eval_fn(
                features, labels, inference_outputs, train_loss, train_outputs,
                mode, config, params)
            # Copy the hooks so that appending below never mutates a list
            # owned by the wrapped model.
            evaluation_hooks = list(
                self._t2r_model.get_eval_hooks(config, params) or [])
            if config and config.is_chief:  # pytype: disable=attribute-error
                # params defaults to None, so fall back to an empty dict
                # instead of failing with an AttributeError.
                eval_name = (params or {}).get('eval_name', 'eval')
                evaluation_hooks.append(
                    gin_utils.GinConfigSaverHook(
                        os.path.join(config.model_dir, eval_name),
                        summarize_config=True))

            return contrib_tpu.TPUEstimatorSpec(
                mode=mode,
                loss=train_loss,
                eval_metrics=eval_metrics,
                evaluation_hooks=evaluation_hooks)

        raise ValueError('The mode {} is not supported yet.'.format(mode))
Ejemplo n.º 6
0
def _model_fn(features,
              labels,
              mode,
              params,
              model,
              use_tpu_estimator_spec,
              variable_filter_fn=None):
    """Model definition for the RetinaNet model based on ResNet.

    Args:
      features: the input image tensor with shape [batch_size, height, width,
        3]. The height and width are fixed and equal. In PREDICT mode this may
        instead be a dict holding the images under 'inputs' and, optionally,
        'image_info' and 'labels'.
      labels: the input labels in a dictionary. The labels include class
        targets and box targets which are dense label maps. The labels are
        generated from get_input_fn function in dataloader.py
      mode: the mode of TPUEstimator/Estimator including TRAIN, EVAL, and
        PREDICT.
      params: the dictionary defines hyperparameters of model. The default
        settings are in default_hparams function in this file.
      model: the RetinaNet model outputs class logits and box regression
        outputs.
      use_tpu_estimator_spec: Whether to use TPUEstimatorSpec or EstimatorSpec
        for TRAIN and EVAL.
      variable_filter_fn: the filter function that takes trainable_variables
        and returns the variable list after applying the filter rule.

    Returns:
      A TPUEstimatorSpec in PREDICT mode or when use_tpu_estimator_spec is
      true; otherwise a tf.estimator.EstimatorSpec carrying only mode, loss and
      train_op (eval metrics and the scaffold are not forwarded on that path).
    """

    # In predict mode features is a dict with input as value of the 'inputs'.
    image_info = None
    if (mode == tf.estimator.ModeKeys.PREDICT and isinstance(features, dict)
            and 'inputs' in features):
        # 'image_info' is treated as optional, like 'labels': the prediction
        # output below already copes with image_info being None.
        image_info = features.get('image_info')
        labels = features.get('labels')
        features = features['inputs']

    def _model_outputs():
        return model(features,
                     min_level=params['min_level'],
                     max_level=params['max_level'],
                     num_classes=params['num_classes'],
                     # Repeating the aspect-ratio sequence num_scales times
                     # and taking its length yields the anchors per location
                     # (assumes aspect_ratios is a list/tuple -- TODO confirm).
                     num_anchors=len(params['aspect_ratios'] *
                                     params['num_scales']),
                     resnet_depth=params['resnet_depth'],
                     is_training_bn=params['is_training_bn'])

    if params['use_bfloat16']:
        with contrib_tpu.bfloat16_scope():
            cls_outputs, box_outputs = _model_outputs()
            # Cast every level back to float32 for the loss/post-processing.
            for level in cls_outputs.keys():
                cls_outputs[level] = tf.cast(cls_outputs[level], tf.float32)
                box_outputs[level] = tf.cast(box_outputs[level], tf.float32)
    else:
        cls_outputs, box_outputs = _model_outputs()

    # First check if it is in PREDICT mode.
    if mode == tf.estimator.ModeKeys.PREDICT:
        # Postprocess on host; memory layout for NMS on TPU is very inefficient.
        def _predict_postprocess_wrapper(args):
            return _predict_postprocess(*args)

        predictions = contrib_tpu.outside_compilation(
            _predict_postprocess_wrapper,
            (cls_outputs, box_outputs, labels, params))

        # Include resizing information on prediction output to help bbox drawing.
        if image_info is not None:
            predictions.update({
                'image_info':
                tf.identity(image_info, 'ImageInfo'),
            })

        return contrib_tpu.TPUEstimatorSpec(mode=tf.estimator.ModeKeys.PREDICT,
                                            predictions=predictions)

    # Load pretrained model from checkpoint.
    if params['resnet_checkpoint'] and mode == tf.estimator.ModeKeys.TRAIN:

        def scaffold_fn():
            """Loads pretrained model through scaffold function."""
            tf.train.init_from_checkpoint(
                params['resnet_checkpoint'], {
                    '/': 'resnet%s/' % params['resnet_depth'],
                })
            return tf.train.Scaffold()
    else:
        scaffold_fn = None

    # Set up training loss and learning rate.
    update_learning_rate_schedule_parameters(params)
    global_step = tf.train.get_global_step()
    learning_rate = learning_rate_schedule(params['adjusted_learning_rate'],
                                           params['lr_warmup_init'],
                                           params['lr_warmup_step'],
                                           params['first_lr_drop_step'],
                                           params['second_lr_drop_step'],
                                           global_step)
    # cls_loss and box_loss are for logging. only total_loss is optimized.
    total_loss, cls_loss, box_loss = detection_loss(cls_outputs, box_outputs,
                                                    labels, params)
    # L2 weight decay over all trainable variables except batch norm ones.
    total_loss += _WEIGHT_DECAY * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.MomentumOptimizer(learning_rate,
                                               momentum=params['momentum'])
        if params['use_tpu']:
            optimizer = contrib_tpu.CrossShardOptimizer(optimizer)
        else:
            if params['auto_mixed_precision']:
                optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
                    optimizer)

        # Batch norm requires `update_ops` to be executed alongside `train_op`.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        var_list = variable_filter_fn(
            tf.trainable_variables(),
            params['resnet_depth']) if variable_filter_fn else None

        minimize_op = optimizer.minimize(total_loss,
                                         global_step,
                                         var_list=var_list)
        train_op = tf.group(minimize_op, update_ops)

    else:
        train_op = None

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(**kwargs):
            """Returns a dictionary that has the evaluation metrics."""
            batch_size = params['batch_size']
            eval_anchors = anchors.Anchors(params['min_level'],
                                           params['max_level'],
                                           params['num_scales'],
                                           params['aspect_ratios'],
                                           params['anchor_scale'],
                                           params['image_size'])
            anchor_labeler = anchors.AnchorLabeler(eval_anchors,
                                                   params['num_classes'])
            cls_loss = tf.metrics.mean(kwargs['cls_loss_repeat'])
            box_loss = tf.metrics.mean(kwargs['box_loss_repeat'])
            coco_metrics = coco_metric_fn(batch_size, anchor_labeler,
                                          params['val_json_file'], **kwargs)

            # Add metrics to output.
            output_metrics = {
                'cls_loss': cls_loss,
                'box_loss': box_loss,
            }
            output_metrics.update(coco_metrics)
            return output_metrics

        # The scalar losses are tiled to [batch_size, 1] so that every
        # metric_fn input carries a leading batch dimension.
        cls_loss_repeat = tf.reshape(
            tf.tile(tf.expand_dims(cls_loss, 0), [
                params['batch_size'],
            ]), [params['batch_size'], 1])
        box_loss_repeat = tf.reshape(
            tf.tile(tf.expand_dims(box_loss, 0), [
                params['batch_size'],
            ]), [params['batch_size'], 1])
        metric_fn_inputs = {
            'cls_loss_repeat': cls_loss_repeat,
            'box_loss_repeat': box_loss_repeat,
            'source_ids': labels['source_ids'],
            'groundtruth_data': labels['groundtruth_data'],
            'image_scales': labels['image_scales'],
        }
        add_metric_fn_inputs(params, cls_outputs, box_outputs,
                             metric_fn_inputs)
        eval_metrics = (metric_fn, metric_fn_inputs)

    if use_tpu_estimator_spec:
        return contrib_tpu.TPUEstimatorSpec(mode=mode,
                                            loss=total_loss,
                                            train_op=train_op,
                                            eval_metrics=eval_metrics,
                                            scaffold_fn=scaffold_fn)
    else:
        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=total_loss,
            # TODO(rostam): Fix bug to get scaffold working.
            # scaffold=scaffold_fn(),
            train_op=train_op)
Ejemplo n.º 7
0
def model_fn(features, labels, mode):
    """Definition for ResNet model.

    Builds the classification network, the loss (sparse softmax cross entropy
    plus L2 weight decay on all trainable variables whose name does not
    contain 'batch_normalization') and, depending on the mode, either the
    evaluation metrics or the training op.

    Args:
      features: the input image batch. It is transposed with permutation
        [3, 0, 1, 2] first when FLAGS.transpose_input is set, then normalized
        with MEAN_RGB / STDDEV_RGB.
      labels: the class labels passed to the sparse softmax cross entropy loss
        and the accuracy metric.
      mode: one of the tf.estimator.ModeKeys.

    Returns:
      A tf.estimator.EstimatorSpec for EVAL (loss and 'top_1_accuracy') or
      TRAIN (loss and train_op).

    Raises:
      NotImplementedError: in PREDICT mode, which is not supported.
      AssertionError: if mode is none of PREDICT, EVAL or TRAIN.
    """
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    if FLAGS.transpose_input:
        features = tf.transpose(features,
                                [3, 0, 1, 2])  # Double-transpose trick

    # Normalize the image to zero mean and unit variance.
    features -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=features.dtype)
    features /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=features.dtype)

    def create_model():
        """Create the model and compute the logits."""
        if FLAGS.use_keras_model:
            model = tf.keras.applications.resnet50.ResNet50(
                include_top=True,
                weights=None,
                input_tensor=None,
                input_shape=None,
                pooling=None,
                classes=_NUM_CLASSES)
            return model(features, training=is_training)
        else:
            model = resnet_model.resnet_v1(resnet_depth=_RESNET_DEPTH,
                                           num_classes=_NUM_CLASSES,
                                           data_format='channels_last')
            return model(inputs=features, is_training=is_training)

    if FLAGS.precision == 'bfloat16':
        with contrib_tpu.bfloat16_scope():
            logits = create_model()
    else:
        logits = create_model()

    # The loss and metrics below are always computed in float32.
    logits = tf.cast(logits, tf.float32)

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Raise explicitly rather than `assert False`: asserts are stripped
        # under `python -O`, which would silently enable this unsupported
        # path.
        raise NotImplementedError('Not implemented correctly right now!')

    cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=labels,
                                                           logits=logits)

    loss = cross_entropy + _WEIGHT_DECAY * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

    if mode == tf.estimator.ModeKeys.EVAL:
        predictions = tf.argmax(logits, axis=1)
        top_1_accuracy = tf.metrics.accuracy(labels, predictions)
        # TODO(priyag): Add this back when in_top_k is supported on TPU.
        # in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
        # top_5_accuracy = tf.metrics.mean(in_top_5)

        eval_metric_ops = {
            'top_1_accuracy': top_1_accuracy,
            # 'top_5_accuracy': top_5_accuracy,
        }

        return tf.estimator.EstimatorSpec(mode,
                                          loss=loss,
                                          eval_metric_ops=eval_metric_ops)

    assert mode == tf.estimator.ModeKeys.TRAIN

    # The learning rate schedule is driven by the (fractional) epoch number.
    global_step = tf.train.get_or_create_global_step()
    batches_per_epoch = (_NUM_TRAIN_IMAGES /
                         (FLAGS.train_batch_size * FLAGS.num_cores))
    current_epoch = (tf.cast(global_step, tf.float32) / batches_per_epoch)
    learning_rate = learning_rate_schedule(current_epoch)

    if FLAGS.optimizer == 'sgd':
        optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=learning_rate)
    else:
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=_MOMENTUM,
                                               use_nesterov=True)

    # Run the collected UPDATE_OPS before every optimizer step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss, global_step=global_step)
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
Ejemplo n.º 8
0
def _model_fn(features, labels, mode, params, model, variable_filter_fn=None):
    """Model definition for the FPN segmentation model based on ResNet.

    Args:
      features: The input images tensor with shape [batch_size, height, width,
        3]. The height and width are fixed and equal.
      labels: The input labels in a tensor with the same shape as input images.
      mode: The mode of TPUEstimator including TRAIN, EVAL, and PREDICT.
      params: The dictionary defines hyperparameters of model. The default
        settings are in default_hparams function in this file.
      model: The FPN segmentation model outputs class logits.
      variable_filter_fn: the filter function that takes trainable_variables
        and returns the variable list after applying the filter rule.

    Returns:
      A tf.estimator.EstimatorSpec holding the input image and the class
      logits in PREDICT mode; otherwise a TPUEstimatorSpec to run training or
      evaluation.
    """
    def _model_outputs():
        return model(features,
                     min_level=params['min_level'],
                     max_level=params['max_level'],
                     num_classes=params['num_classes'],
                     resnet_depth=params['resnet_depth'],
                     is_training_bn=params['is_training_bn'])

    if params['use_bfloat16']:
        with contrib_tpu.bfloat16_scope():
            cls_outputs = _model_outputs()
            cls_outputs = tf.cast(cls_outputs, tf.float32)
    else:
        cls_outputs = _model_outputs()

    # First check if it is in PREDICT mode.
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {'image': features, 'cls_outputs': cls_outputs}
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Load pretrained model from checkpoint.
    if params['resnet_checkpoint'] and mode == tf.estimator.ModeKeys.TRAIN:

        def scaffold_fn():
            """Loads pretrained model through scaffold function."""
            tf.train.init_from_checkpoint(
                params['resnet_checkpoint'], {
                    '/': 'resnet%s/' % params['resnet_depth'],
                })
            return tf.train.Scaffold()
    else:
        scaffold_fn = None

    # Set up training loss and learning rate.
    retinanet_model.update_learning_rate_schedule_parameters(params)
    global_step = tf.train.get_global_step()
    learning_rate = retinanet_model.learning_rate_schedule(
        params['adjusted_learning_rate'], params['lr_warmup_init'],
        params['lr_warmup_step'], params['first_lr_drop_step'],
        params['second_lr_drop_step'], global_step)

    cls_loss = _segmentation_loss(cls_outputs, labels, params)
    weight_decay_loss = params['weight_decay'] * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])
    # Add L2 regularization loss
    total_loss = cls_loss + weight_decay_loss

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.MomentumOptimizer(learning_rate,
                                               momentum=params['momentum'])
        if params['use_tpu']:
            optimizer = contrib_tpu.CrossShardOptimizer(optimizer)

        # Batch norm requires update_ops to be added as a train_op dependency.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        var_list = variable_filter_fn(
            tf.trainable_variables(),
            params['resnet_depth']) if variable_filter_fn else None

        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(total_loss,
                                          global_step,
                                          var_list=var_list)
    else:
        train_op = None

    # Evaluation only works on GPU/CPU host and batch_size=1
    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:
        batch_size = params['batch_size']

        def metric_fn(**kwargs):
            """Creates metric_fn for TPUEstimatorSpec."""
            mean_cls_loss = tf.metrics.mean(kwargs['cls_loss_repeat'])
            mean_total_loss = tf.metrics.mean(kwargs['total_loss_repeat'])
            # Resize the logits to the spatial size of the labels before
            # taking the per-pixel argmax.
            logits = tf.image.resize_bilinear(kwargs['prediction'],
                                              tf.shape(kwargs['labels'])[1:3],
                                              align_corners=True)
            predictions_with_shape = tf.argmax(logits, 3, output_type=tf.int32)
            flat_predictions = tf.reshape(predictions_with_shape, shape=[-1])

            flat_labels = tf.reshape(kwargs['labels'], shape=[-1])
            # Background class is considered as a class. Not ignored.
            weights = tf.cast(
                tf.not_equal(flat_labels, params['ignore_label']), tf.float32)

            # Set ignore_label regions to label 0, because metrics.mean_iou requires
            # range of labels = [0, dataset.num_classes).
            # Note the ignore_label regions are not evaluated since the corresponding
            # regions contain weights = 0.
            flat_labels = tf.where(
                tf.equal(flat_labels, params['ignore_label']),
                tf.zeros_like(flat_labels), flat_labels)

            return {
                'total_loss':
                mean_total_loss,
                'cls_loss':
                mean_cls_loss,
                # Keyword arguments make the labels/predictions roles explicit
                # and match the documented argument order of mean_iou.
                'miou':
                tf.metrics.mean_iou(labels=flat_labels,
                                    predictions=flat_predictions,
                                    num_classes=params['num_classes'],
                                    weights=weights),
            }

        # The scalar losses are tiled to [batch_size, 1] so that every
        # metric_fn input carries a leading batch dimension.
        cls_loss_repeat = tf.reshape(
            tf.tile(tf.expand_dims(cls_loss, 0), [
                batch_size,
            ]), [batch_size, 1])

        total_loss_repeat = tf.reshape(
            tf.tile(tf.expand_dims(total_loss, 0), [
                batch_size,
            ]), [batch_size, 1])

        metric_fn_inputs = {
            'cls_loss_repeat': cls_loss_repeat,
            'total_loss_repeat': total_loss_repeat,
            'prediction': cls_outputs,
            'labels': labels,
        }

        eval_metrics = (metric_fn, metric_fn_inputs)

    return contrib_tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        eval_metrics=eval_metrics,
        scaffold_fn=scaffold_fn,
    )
Ejemplo n.º 9
0
def model_fn(features, labels, mode, params):
    """Build the Transformer estimator spec for train, eval or predict mode.

    Args:
        features: Model inputs; handed to the Transformer unchanged.
        labels: Targets fed to the model; also used as the labels for the
            loss and the evaluation metrics.
        mode: A `tf.estimator.ModeKeys` value. Anything that is neither
            PREDICT nor EVAL takes the training path.
        params: Hyperparameter dict. Keys read directly here: "use_tpu",
            "label_smoothing", "vocab_size" and, for TPU training,
            "model_dir". The whole dict is also forwarded to the model,
            the metric function and the train-op builder.

    Returns:
        A `tf.estimator.EstimatorSpec`, or a `tf.contrib.tpu.TPUEstimatorSpec`
        for EVAL/TRAIN when `params["use_tpu"]` is truthy.

    Raises:
        NotImplementedError: PREDICT mode requested with `params["use_tpu"]`
            set.
    """
    mode_keys = tf.estimator.ModeKeys

    with tf.variable_scope("model"):
        inputs = features
        targets = labels

        # Build the model, then run the forward pass to obtain the logits.
        # NOTE(review): only construction sits inside bfloat16_scope(); the
        # forward pass runs outside it -- confirm this is intended.
        is_training = mode == mode_keys.TRAIN
        with bfloat16_scope():
            model = transformer.Transformer(params, is_training)

        logits = model(inputs, targets)

        # Prediction: no targets are available, so the model output itself is
        # the prediction and no loss is computed.
        if mode == mode_keys.PREDICT:
            if params["use_tpu"]:
                raise NotImplementedError(
                    "Prediction is not yet supported on TPUs.")
            export_outputs = {
                "translate": tf.estimator.export.PredictOutput(logits),
            }
            return tf.estimator.EstimatorSpec(
                mode=mode_keys.PREDICT,
                predictions=logits,
                export_outputs=export_outputs)

        # Pin down the static shape of the logits for XLA (TPU). The logits
        # travel back to the host CPU for metric evaluation, where a shape of
        # [?, ?, vocab_size] is too vague. The leading two dimensions are
        # known to match those of the targets. The loose shape does not
        # matter for the cross entropy itself, since
        # padded_cross_entropy_loss resolves it on the TPU.
        target_dims = targets.shape.as_list()
        trailing_dims = logits.shape.as_list()[2:]
        logits.set_shape(target_dims + trailing_dims)

        # xentropy holds the cross entropy of every non-padding target token;
        # the loss is its weight-normalised sum.
        xentropy, weights = metrics.padded_cross_entropy_loss(
            logits, targets, params["label_smoothing"], params["vocab_size"])
        xentropy_total = tf.reduce_sum(xentropy)
        weight_total = tf.reduce_sum(weights)
        loss = xentropy_total / weight_total

        # Expose the loss under a fixed name so the logging hook can find it.
        tf.identity(loss, name="cross_entropy")

        if mode == mode_keys.EVAL:
            if not params["use_tpu"]:
                return tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=loss,
                    predictions={"predictions": logits},
                    eval_metric_ops=metrics.get_eval_metrics(
                        logits, labels, params))

            # Host call functions may only take tensors as arguments, so
            # params is captured by closure to keep metric_fn
            # TPUEstimator compliant.
            def metric_fn(logits, labels):
                return metrics.get_eval_metrics(logits, labels, params=params)

            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                predictions={"predictions": logits},
                eval_metrics=(metric_fn, [logits, labels]))

        # Remaining case: training.
        train_op, metric_dict = get_train_op_and_metrics(loss, params)

        # Epochs can be long; the per-minibatch loss gives intermediate
        # feedback in TensorBoard.
        metric_dict["minibatch_loss"] = loss

        if params["use_tpu"]:
            host_call = tpu_util.construct_scalar_host_call(
                metric_dict=metric_dict,
                model_dir=params["model_dir"],
                prefix="training/")
            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                train_op=train_op,
                host_call=host_call)

        record_scalars(metric_dict)
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, train_op=train_op)