Example #1
    def _add_summary(lowering, train_or_eval, tf_loss, scalars, global_step):
      """Add all summaries."""
      for k in scalars.keys():
        if not isinstance(scalars[k], tf.Tensor):
          scalars[k] = tf.cast(
              lowering.export_to_tf_tensor(scalars[k]), tf.float32)

      def _host_loss_summary(global_step, tf_loss, **scalars):
        """Add summary.scalar in host side."""
        gs = tf.cast(global_step, tf.int64)
        sum_loss = contrib_summary.scalar(
            '{}_loss'.format(train_or_eval), tf_loss, step=gs)
        sum_ops = [sum_loss.op]
        for description, tf_metric in scalars.items():
          sum_metric = contrib_summary.scalar(
              '{}_{}'.format(train_or_eval, description), tf_metric, step=gs)
          sum_ops.append(sum_metric)
        with tf.control_dependencies(sum_ops):
          return tf.identity(tf_loss)

      if FLAGS.use_tpu:
        # Cast the global step to tf.int32, since
        # outside_compilation does not support tf.int64.
        tf_loss = tpu.outside_compilation(
            _host_loss_summary,
            tf.cast(global_step, tf.int32),
            tf_loss,
            **scalars)
      else:
        tf_loss = _host_loss_summary(
            tf.cast(global_step, tf.int32),
            tf_loss,
            **scalars)

      return tf_loss
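For reference, here is a minimal, self-contained sketch of the same host-call summary pattern, assuming TF 1.x with `tf.contrib` available and a `tf.contrib.summary` file writer already installed as the default; the `add_loss_summary` name is illustrative, not part of the example above.

    import tensorflow.compat.v1 as tf
    from tensorflow.contrib import summary as contrib_summary
    from tensorflow.contrib import tpu as contrib_tpu

    def add_loss_summary(tf_loss, global_step, use_tpu):
      """Writes a scalar loss summary from the host, on or off TPU."""

      def _host_call(step, loss):
        # Summary ops cannot run inside a TPU graph, so this runs on the host.
        with contrib_summary.always_record_summaries():
          sum_loss = contrib_summary.scalar(
              'loss', loss, step=tf.cast(step, tf.int64))
        with tf.control_dependencies([sum_loss.op]):
          return tf.identity(loss)

      # outside_compilation does not support tf.int64, so cast the step first.
      step = tf.cast(global_step, tf.int32)
      if use_tpu:
        return contrib_tpu.outside_compilation(_host_call, step, tf_loss)
      return _host_call(step, tf_loss)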
Example #2
    def compute_eval_dict(features, labels):
        """Compute the evaluation result on an image."""
        # When evaluating on training data, we need to check whether the
        # groundtruth tensors must be unpadded.
        boxes_shape = (labels[
            fields.InputDataFields.groundtruth_boxes].get_shape().as_list())
        unpad_groundtruth_tensors = (boxes_shape[1] is not None and not use_tpu
                                     and batch_size == 1)
        labels = model_lib.unstack_batch(
            labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)

        losses_dict, prediction_dict = _compute_losses_and_predictions_dicts(
            detection_model, features, labels, add_regularization_loss)

        def postprocess_wrapper(args):
            return detection_model.postprocess(args[0], args[1])

        # TODO(kaftan): Depending on how postprocessing will work for TPUs w/
        ## TPUStrategy, it may be good to move this wrapping to a utility method.
        if use_tpu and postprocess_on_cpu:
            detections = contrib_tpu.outside_compilation(
                postprocess_wrapper,
                (prediction_dict,
                 features[fields.InputDataFields.true_image_shape]))
        else:
            detections = postprocess_wrapper(
                (prediction_dict,
                 features[fields.InputDataFields.true_image_shape]))

        class_agnostic = (fields.DetectionResultFields.detection_classes
                          not in detections)
        # TODO(kaftan) (or anyone): move `_prepare_groundtruth_for_eval` to
        ## eval_util and call it from there.
        groundtruth = model_lib._prepare_groundtruth_for_eval(  # pylint: disable=protected-access
            detection_model, class_agnostic,
            eval_input_config.max_number_of_boxes)
        use_original_images = fields.InputDataFields.original_image in features
        if use_original_images:
            eval_images = features[fields.InputDataFields.original_image]
            true_image_shapes = tf.slice(
                features[fields.InputDataFields.true_image_shape], [0, 0],
                [-1, 3])
            original_image_spatial_shapes = features[
                fields.InputDataFields.original_image_spatial_shape]
        else:
            eval_images = features[fields.InputDataFields.image]
            true_image_shapes = None
            original_image_spatial_shapes = None

        eval_dict = eval_util.result_dict_for_batched_example(
            eval_images,
            features[inputs.HASH_KEY],
            detections,
            groundtruth,
            class_agnostic=class_agnostic,
            scale_to_absolute=True,
            original_image_spatial_shapes=original_image_spatial_shapes,
            true_image_shapes=true_image_shapes)

        return eval_dict, losses_dict, class_agnostic
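The tuple-packing around `postprocess_wrapper` generalizes to any host-only computation. A minimal sketch, assuming TF 1.x with `tf.contrib`; `my_postprocess` is a hypothetical stand-in for `detection_model.postprocess`:

    import tensorflow.compat.v1 as tf
    from tensorflow.contrib import tpu as contrib_tpu

    def my_postprocess(predictions, true_image_shapes):
      # Stand-in for detection_model.postprocess; any CPU-only op goes here.
      return {'scores': tf.nn.softmax(predictions),
              'true_image_shapes': true_image_shapes}

    def postprocess_maybe_on_host(predictions, true_image_shapes, use_tpu):
      # Packing the arguments into one tuple keeps the TPU and non-TPU call
      # sites identical; outside_compilation forwards the tuple to the host.
      def _wrapper(args):
        return my_postprocess(args[0], args[1])

      if use_tpu:
        return contrib_tpu.outside_compilation(
            _wrapper, (predictions, true_image_shapes))
      return _wrapper((predictions, true_image_shapes))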
Example #3
def _model_fn(features,
              labels,
              mode,
              params,
              model,
              use_tpu_estimator_spec,
              variable_filter_fn=None):
    """Model defination for the RetinaNet model based on ResNet.

  Args:
    features: the input image tensor with shape [batch_size, height, width, 3].
      The height and width are fixed and equal.
    labels: the input labels in a dictionary. The labels include class targets
      and box targets which are dense label maps. The labels are generated from
      get_input_fn function in dataloader.py
    mode: the mode of TPUEstimator/Estimator including TRAIN, EVAL, and PREDICT.
    params: the dictionary defines hyperparameters of model. The default
      settings are in default_hparams function in this file.
    model: the RetinaNet model outputs class logits and box regression outputs.
    use_tpu_estimator_spec: Whether to use TPUEstimatorSpec or EstimatorSpec.
    variable_filter_fn: the filter function that takes trainable_variables and
      returns the variable list after applying the filter rule.

  Returns:
    tpu_spec: the TPUEstimatorSpec to run training, evaluation, or prediction.
  """

    # In predict mode, `features` may be a dict with the input image stored
    # under the 'inputs' key.
    image_info = None
    if (mode == tf.estimator.ModeKeys.PREDICT and isinstance(features, dict)
            and 'inputs' in features):
        image_info = features['image_info']
        labels = None
        if 'labels' in features:
            labels = features['labels']
        features = features['inputs']

    def _model_outputs():
        return model(features,
                     min_level=params['min_level'],
                     max_level=params['max_level'],
                     num_classes=params['num_classes'],
                     num_anchors=(len(params['aspect_ratios']) *
                                  params['num_scales']),
                     resnet_depth=params['resnet_depth'],
                     is_training_bn=params['is_training_bn'])

    if params['use_bfloat16']:
        with contrib_tpu.bfloat16_scope():
            cls_outputs, box_outputs = _model_outputs()
            levels = cls_outputs.keys()
            for level in levels:
                cls_outputs[level] = tf.cast(cls_outputs[level], tf.float32)
                box_outputs[level] = tf.cast(box_outputs[level], tf.float32)
    else:
        cls_outputs, box_outputs = _model_outputs()
        levels = cls_outputs.keys()

    # First check if it is in PREDICT mode.
    if mode == tf.estimator.ModeKeys.PREDICT:
        # Postprocess on host; memory layout for NMS on TPU is very inefficient.
        def _predict_postprocess_wrapper(args):
            return _predict_postprocess(*args)

        predictions = contrib_tpu.outside_compilation(
            _predict_postprocess_wrapper,
            (cls_outputs, box_outputs, labels, params))

        # Include resizing information on prediction output to help bbox drawing.
        if image_info is not None:
            predictions.update({
                'image_info':
                tf.identity(image_info, 'ImageInfo'),
            })

        return contrib_tpu.TPUEstimatorSpec(mode=tf.estimator.ModeKeys.PREDICT,
                                            predictions=predictions)

    # Load pretrained model from checkpoint.
    if params['resnet_checkpoint'] and mode == tf.estimator.ModeKeys.TRAIN:

        def scaffold_fn():
            """Loads pretrained model through scaffold function."""
            tf.train.init_from_checkpoint(
                params['resnet_checkpoint'], {
                    '/': 'resnet%s/' % params['resnet_depth'],
                })
            return tf.train.Scaffold()
    else:
        scaffold_fn = None

    # Set up training loss and learning rate.
    update_learning_rate_schedule_parameters(params)
    global_step = tf.train.get_global_step()
    learning_rate = learning_rate_schedule(params['adjusted_learning_rate'],
                                           params['lr_warmup_init'],
                                           params['lr_warmup_step'],
                                           params['first_lr_drop_step'],
                                           params['second_lr_drop_step'],
                                           global_step)
    # cls_loss and box_loss are for logging; only total_loss is optimized.
    total_loss, cls_loss, box_loss = detection_loss(cls_outputs, box_outputs,
                                                    labels, params)
    total_loss += _WEIGHT_DECAY * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.MomentumOptimizer(learning_rate,
                                               momentum=params['momentum'])
        if params['use_tpu']:
            optimizer = contrib_tpu.CrossShardOptimizer(optimizer)
        else:
            if params['auto_mixed_precision']:
                optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
                    optimizer)

        # Batch norm requires `update_ops` to be executed alongside `train_op`.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        var_list = variable_filter_fn(
            tf.trainable_variables(),
            params['resnet_depth']) if variable_filter_fn else None

        minimize_op = optimizer.minimize(total_loss,
                                         global_step,
                                         var_list=var_list)
        train_op = tf.group(minimize_op, update_ops)

    else:
        train_op = None

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(**kwargs):
            """Returns a dictionary that has the evaluation metrics."""
            batch_size = params['batch_size']
            eval_anchors = anchors.Anchors(params['min_level'],
                                           params['max_level'],
                                           params['num_scales'],
                                           params['aspect_ratios'],
                                           params['anchor_scale'],
                                           params['image_size'])
            anchor_labeler = anchors.AnchorLabeler(eval_anchors,
                                                   params['num_classes'])
            cls_loss = tf.metrics.mean(kwargs['cls_loss_repeat'])
            box_loss = tf.metrics.mean(kwargs['box_loss_repeat'])
            coco_metrics = coco_metric_fn(batch_size, anchor_labeler,
                                          params['val_json_file'], **kwargs)

            # Add metrics to output.
            output_metrics = {
                'cls_loss': cls_loss,
                'box_loss': box_loss,
            }
            output_metrics.update(coco_metrics)
            return output_metrics

        cls_loss_repeat = tf.reshape(
            tf.tile(tf.expand_dims(cls_loss, 0), [
                params['batch_size'],
            ]), [params['batch_size'], 1])
        box_loss_repeat = tf.reshape(
            tf.tile(tf.expand_dims(box_loss, 0), [
                params['batch_size'],
            ]), [params['batch_size'], 1])
        metric_fn_inputs = {
            'cls_loss_repeat': cls_loss_repeat,
            'box_loss_repeat': box_loss_repeat,
            'source_ids': labels['source_ids'],
            'groundtruth_data': labels['groundtruth_data'],
            'image_scales': labels['image_scales'],
        }
        add_metric_fn_inputs(params, cls_outputs, box_outputs,
                             metric_fn_inputs)
        eval_metrics = (metric_fn, metric_fn_inputs)

    if use_tpu_estimator_spec:
        return contrib_tpu.TPUEstimatorSpec(mode=mode,
                                            loss=total_loss,
                                            train_op=train_op,
                                            eval_metrics=eval_metrics,
                                            scaffold_fn=scaffold_fn)
    else:
        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=total_loss,
            # TODO(rostam): Fix bug to get scaffold working.
            # scaffold=scaffold_fn(),
            train_op=train_op)
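The `scaffold_fn` idiom above is how TPUEstimator warm-starts from a backbone checkpoint. A minimal sketch under the same assumptions (TF 1.x; `checkpoint_path` and `make_scaffold_fn` are illustrative names):

    import tensorflow.compat.v1 as tf

    def make_scaffold_fn(checkpoint_path, resnet_depth):
      """Returns a scaffold_fn that warm-starts the ResNet backbone."""

      def scaffold_fn():
        # Map the checkpoint's root scope onto the graph's `resnet<depth>/`
        # scope, then build a default Scaffold. TPUEstimator calls this
        # function lazily, so the init ops are created on the right graph.
        tf.train.init_from_checkpoint(
            checkpoint_path, {'/': 'resnet%s/' % resnet_depth})
        return tf.train.Scaffold()

      return scaffold_fn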
Example #4
    def model_fn(features, labels, mode, config, params):
        """Estimator model function."""

        # The Estimator API always passes these arguments; they are unused
        # here, so delete them to make that explicit.
        del labels
        del config
        del params

        tf.get_variable_scope().set_initializer(
            tf.variance_scaling_initializer(1.0,
                                            mode="fan_avg",
                                            distribution="uniform"))

        # PREDICT mode (also used when decoding for evaluation).
        if mode == tf.estimator.ModeKeys.PREDICT:
            predictions, _, _ = model_params.estimator_prediction_fn(features)

            if include_features_in_predictions:
                predictions.update(features)

            if decode_keys:
                # Decode the raw ids into strings in prediction.
                def decode_host_call(tensor_dict):
                    for key in decode_keys:
                        predictions[key] = public_parsing_ops.decode(
                            tensor_dict[key], model_params.vocab_filename,
                            model_params.encoder_type)
                    return tensor_dict

                contrib_tpu.outside_compilation(decode_host_call, predictions)
            return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                                  predictions=predictions)

        # TRAINING
        training = mode == tf.estimator.ModeKeys.TRAIN
        # use_tpu defaults to False, so this branch is normally skipped.
        if use_tpu and model_params.use_bfloat16:
            with contrib_tpu.bfloat16_scope():
                XENT_loss, outputs = model_params.model()(features, training)
        else:
            XENT_loss, outputs = model_params.model()(features, training)
            # XENT_loss, outputs = model_params.model().double_sampling(features, training, model_params.batch_size,
            #                                                           features["targets"].get_shape().as_list()[1],
            #                                                           mixed=True)

        # TPU requires that all outputs have a batch dimension; it doesn't
        # handle scalars. Tile all scalars into 1-D vectors.
        outputs = _tile_scalar_to_batch_size(outputs, model_params.batch_size)

        # Create optimizer and define learning rate
        if mode == tf.estimator.ModeKeys.TRAIN:
            init_lr = model_params.learning_rate
            global_step = tf.train.get_global_step()
            lr = init_lr / 0.01 * tf.rsqrt(
                tf.maximum(tf.to_float(global_step), 10000))
            if train_init_checkpoint:
                lr = tf.minimum(
                    tf.to_float(global_step + 1) / train_warmup_steps *
                    init_lr, lr)

            optimizer = adafactor.AdafactorOptimizer(
                learning_rate=lr,
                decay_rate=adafactor.adafactor_decay_rate_pow(0.8),
                beta1=0.0)
            if use_tpu:
                optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

            ###############################################################################################################
            ##### VARIABLES ###############################################################################################
            # Create index tensors to stack and get corresponding probabilities from logp
            # max_seq_len = outputs["targets"].get_shape().as_list()[1]
            # sequence_index = tf.constant(np.arange(0, max_seq_len))
            # batch_index = tf.constant(np.zeros(sequence_index.get_shape().as_list()[0]), dtype=tf.int64)

            ##### I.I.D SAMPLING ##########################################################################################
            """ Here we sample the tokens that are produced by teacher forcing. """
            # Normalise logits to log-prob, and compute Gumbel samples with location
            # logit_probs = tf.math.softmax(outputs["logits"], axis=2)  # should not be x <= 0
            # clipped_logit_probs = tf.clip_by_value(logit_probs, 1e-8, 1.0)
            # logp = tf.log(clipped_logit_probs)

            # RETURNS TEACHER FORCING SAMPLED TOKEN VARIATIONS
            # argmax_logp_index, soft_logp_index, topk_out, z = iid_sampling(logp, max_seq_len, greedy=True, soft=False,
            #                                                                topk=False, k=2)
            # topk_probs, topk_indices = topk_out
            # TEST SAMPLING METHODS PROVIDED BY PEGASUS
            # sampled_BxT = iid_process_logits(outputs["logits"], max_seq_len, model_params.batch_size,
            #                                  outputs["logits"].get_shape().as_list()[-1],
            #                                  top_k=0, top_p=0.9, temperature=1.0)

            ##### DECODER SAMPLING ########################################################################################
            """ Here we sample the tokens using the decoder. Beam size == 1. 
            PREDS: IDs
            LOGP: transformed logits
            SCORE: scalar score using RISK trick
            LOGP: [BxTxV] beam logp
            LOGITS: [BxTxV] beam logits
            the dictionary contains the following keys: {ids, logp_BxT, sent_score, logp_BxTxV}
      # Note: the logp_BxTxV are analogous to z -> should be used for RELAX, preds are the BxT of these -> b=H(z), and
      # logp are the corresponding values (score is normalised to sentence score).
      """
            # greedy_beam_params = {"_beam": 3, "top_k": 0, "top_p": 0.0, "temperature": 0.0}
            # random_beam_params = {"_beam": 3, "top_k": 0, "top_p": 0.0, "temperature": 1.0}
            # topk_beam_params = {"_beam": 3, "top_k": 10000, "top_p": 0.0, "temperature": 1.0}
            # topp_beam_params = {"_beam": 3, "top_k": 0, "top_p": 0.9, "temperature": 1.0}

            # greedy_dict = non_beam_sampling(model_params, features, max_seq_len,
            #                                 beam_params=greedy_beam_params, sentence_score=False)
            # random_dict = non_beam_sampling(model_params, features, max_seq_len,
            #                                 beam_params=random_beam_params, sentence_score=False)
            # topk_dict = non_beam_sampling(model_params, features, max_seq_len,
            #                               beam_params=topk_beam_params, sentence_score=False)
            # topp_dict = non_beam_sampling(model_params, features, max_seq_len,
            #                               beam_params=topp_beam_params, sentence_score=False)

            # BEAM SEARCH
            # greedy_dict = beam_sampling(model_params, features, max_seq_len, batch_index, sequence_index,
            #                             beam_params=greedy_beam_params)
            # random_dict = beam_sampling(model_params, features, max_seq_len, batch_index, sequence_index,
            #                             beam_params=random_beam_params)
            # topk_dict = beam_sampling(model_params, features, max_seq_len, batch_index, sequence_index,
            #                           beam_params=topk_beam_params)
            # topp_dict = beam_sampling(model_params, features, max_seq_len, batch_index, sequence_index,
            #                           beam_params=topp_beam_params)

            ##### RELAX VARIABLES #########################################################################################
            """ Here we create the variables for RELAX. Pass in the logp, logits, and z that has already been 
      sampled/created from manipulation. Will return z_tilde [BxTxV] and logp(b) [BxT]. """
            # TEACHER FORCING SAMPLING
            # z_tilde, logp_b = create_variables(z, logp, batch_index, sequence_index, clipped_logit_probs)

            # DECODER SAMPLING -> sample_b is already argmaxed in decode loop
            # z_tilde, logp_b = create_variables_from_samples(random_dict["logits_BxTxV"], random_dict["logp_BxTxV"],
            #                                                 random_dict["ids"], batch_index, sequence_index)

            ##### TEXT AND ROUGE ##########################################################################################
            """ Here we first convert sequences to text, and calculate corresponding rouge scores/losses. """
            # target_text = rouge_decoding(outputs["targets"], model_params)  # TARGET SAMPLES
            # argmax_pred_text = rouge_decoding(argmax_logp_index, model_params)  # ARGMAX SAMPLES
            # soft_pred_text = rouge_decoding(soft_logp_index, model_params)  # SOFTMAX SAMPLES
            # additional_pred_text = rouge_decoding(sampled_BxT, model_params)  # ADDITIONAL SAMPLES

            # Token-level ROUGE
            # ROUGE_token = tf.py_function(rouge_token,(outputs["targets"], random_dict["ids"], 0, 0), tf.float32)

            # CALCULATE ROUGE LOSS: ROUGE score -> ROUGE loss = -ROUGE score
            # NOTE: for ROUGE variant, change value (0: precision, 1: recall, 2: f1)
            # rouge_loss_argmax = -tf.py_function(evaluate_rl, (target_text, argmax_pred_text, 2), tf.float32)
            # rouge_loss_soft = -tf.py_function(evaluate_rl, (target_text, soft_pred_text, 2), tf.float32)
            # rouge_loss_extra = -tf.py_function(evaluate_rl, (target_text, additional_pred_text, 2), tf.float32)

            ##### REINFORCE LOSS ##########################################################################################
            """ Calculate standard REINFORCE loss. Can be document-level (score using RISK trick), or token-level [BxT]. """
            # FIND CORRESPONDING LOG_PROBS OF THE I.I.D SAMPLED TOKENS
            # ARGMAX -> logp(argmax(y))
            # argmax_logp = iid_log_probs(argmax_logp_index, batch_index, sequence_index, logp)
            # SOFTMAX -> logp(sample_y)
            # softmax_logp = iid_log_probs(soft_logp_index, batch_index, sequence_index, logp)
            # ADDITIONAL
            # additional_logp = iid_log_probs(sampled_BxT, batch_index, sequence_index, logp)

            # CHANGE BELOW IF USING DECODER SAMPLED TOKENS/SCORES
            # weight the logp by ROUGE score (neg ROUGE_loss), sum values
            # reinforce_loss = tf.reduce_sum(tf.multiply(rouge_loss_argmax, argmax_logp))

            ##### REINFORCE w/ BASELINE ###################################################################################
            """ Calculate RwB using Socher's loss function (2017). Optional: use a Q_func as baseline. """
            # improve the probs of the SOFT labels (soft - hard)*soft_logp
            # improve the probs of the HARD labels (hard - soft)*hard_logp

            # BASELINE: CONTROL VARIATE
            # ffn_output = control_variate(source, targets)
            # with tf.variable_scope("Q_func"):
            #   cv = rwb_Q_func(tf.reshape(softmax_logp, [1, 32]), tf.reshape(additional_logp, [1, 32]))

            # cv_loss = tf.reduce_mean(tf.square(tf.subtract(rouge_loss_argmax, cv)))

            # loss_difference = tf.subtract(rouge_loss_soft, rouge_loss_argmax)
            # reinforce_baseline = tf.reduce_sum(tf.multiply(loss_difference, softmax_logp))

            # BASELINE: HINGE LOSS
            # rouge_soft = -rouge_loss_soft
            # rouge_hard = -rouge_loss_argmax
            # hinge = -tf.maximum((rouge_soft - rouge_hard), 0)
            # hinge_baseline = tf.reduce_sum(tf.multiply(hinge, softmax_logp))

            ##### REINFORCE w/ THRESHOLD ##################################################################################
            """ Calculate REINFORCE with a constant threshold as the baseline. """
            # we take output of ROUGE score as ROUGE_loss = -ROUGE score
            # intermediate_loss = tf.reduce_sum(tf.multiply(tf.subtract(0.3, -rouge_loss_argmax), argmax_logp))

            ##### EXPECTED RISK MINIMISATION ##############################################################################
            """ Calculate the RISK loss using n sequences from sampling process. """
            # L_risk = risk_loss(model_params.batch_size, max_seq_len,
            #                    rouge_losses=[rouge_loss_argmax, rouge_loss_soft, rouge_loss_extra],
            #                    logps=[topk_dict["logp1"], topk_dict["logp2"], topk_dict["logp3"]], n=3)

            ##### MIXED LOSS ##############################################################################################
            """ Implement a mixed loss function that is weighted by an alpha term. """
            # combined_loss = tf.math.add(tf.multiply(tf.constant(0.3, dtype=tf.float32), XENT_loss),
            #                             tf.multiply(tf.constant(0.7, dtype=tf.float32), L_risk))

            # OR conditional loss switch
            # constraint = tf.random_uniform(shape=(), minval=0, maxval=1, dtype=tf.float32)
            # combined_loss = tf.cond(constraint > 0.8, lambda: hard_reinforce_loss, lambda: XENT_loss)

            ##### RELAX CONTROL VARIATE ###################################################################################
            """ Prepare the target sequence for use in the control variate. """
            # z = random_dict["logp_BxTxV"]
            # z_target, zt_target = create_cv_target(outputs, batch_index, sequence_index, z, z_tilde)

            ##### RELAX LOSS ##############################################################################################
            """ Manipulate z and z_tilde using the Q_func to mimic ROUGE loss. """
            # with tf.variable_scope("Q_func"):
            #     c_z = Q_func(z, z_target)

            # with tf.variable_scope("Q_func", reuse=True):
            #     c_z_tilde = Q_func(z_tilde, zt_target)

            # Formulate RELAX as a loss function
            # f_y = rouge_loss_soft  # negative for loss (defined above)
            # c_z_tilde1 = tf.stop_gradient(tf.identity(c_z_tilde))  # clone, detach, stop grad
            # L_relax = tf.reduce_sum(((f_y - c_z_tilde1)*logp_b) - c_z_tilde + c_z)

            # OR construct gradient estimator
            # theta = [tv for tv in tf.trainable_variables() if "Q_func" not in tv.name]
            # d_logp_d_theta = tf.gradients(logp_b, theta)[0]  # logp
            # d_c_z_tilde_d_theta = tf.gradients(c_z_tilde, theta)[0]
            # d_c_z_d_theta = tf.gradients(c_z, theta)[0]
            # relax = tf.reduce_sum(f_y - c_z_tilde)*d_logp_d_theta - d_c_z_tilde_d_theta + d_c_z_d_theta

            # relax = tf.gradients(L_relax, theta)[0]

            # Calculate the first optimization step with loss
            # list_of_gradient_variable_pairs = optimizer.compute_gradients(L_relax)
            # train_op = optimizer.apply_gradients(list_of_gradient_variable_pairs, global_step=global_step)

            # Variance reduction objective
            # variance_loss = tf.reduce_mean(tf.square(relax), name="variance_loss")

            # initialise adafactor again for variance optimiser
            # var_opt = adafactor.AdafactorOptimizer(
            #           learning_rate=lr,
            #           decay_rate=adafactor.adafactor_decay_rate_pow(0.8),
            #           beta1=0.0)

            # est_params = [eta, log_temperature]  # TODO: REBAR implementation

            # Adds the parameters of the FFNN
            # nn_params = [tv for tv in tf.trainable_variables() if "Q_func" in tv.name]
            # est_params = nn_params
            # est_params = est_params + nn_params  # TODO: REBAR implementation

            # Additional optimization step
            # var_gradvars = var_opt.compute_gradients(variance_loss, var_list=est_params)
            # var_train_op = var_opt.apply_gradients(var_gradvars)

            # This may allow for both train ops to be passed in the return statement below?
            # with tf.control_dependencies([train_op, var_train_op]):
            #     train_op = tf.no_op()

            ###############################################################################################################
            # Calculate gradients
            # If freezing layers, only optimise wrt certain layers (find names) - speeds up, worsens performance
            # last_params = [tv for tv in tf.trainable_variables() if "decoder/LayerNorm/" in tv.name]
            # list_of_gradient_variable_pairs = optimizer.compute_gradients(combined_loss, var_list=last_params)

            list_of_gradient_variable_pairs = optimizer.compute_gradients(
                XENT_loss)
            train_op = optimizer.apply_gradients(
                list_of_gradient_variable_pairs, global_step=global_step)

            tf.logging.set_verbosity(tf.logging.INFO)
            # Debugging steps - add into logging hook directly if needed
            # tf.debugging.check_numerics(sum_logp, "DEBUG: sum_logp has a NaN")

            logging_hook = tf.train.LoggingTensorHook(
                {
                    "loss": XENT_loss,
                    # "variance_loss": variance_loss,
                    # "cv_loss": cv_loss,
                    "learning_rate": lr,
                    "global_step": global_step,
                },
                every_n_iter=5)

            # Return the configured TPUEstimatorSpec used to train the model.
            return tpu_estimator.TPUEstimatorSpec(
                mode=mode,
                loss=XENT_loss,
                train_op=train_op,
                training_hooks=[logging_hook],
                scaffold_fn=_load_vars_from_checkpoint(use_tpu,
                                                       train_init_checkpoint),
                host_call=add_scalars_to_summary(
                    model_dir,
                    {
                        "learning_rate": lr,
                        # "rouge_loss_hard": rouge_loss_argmax,
                        # "rouge_loss_soft": rouge_loss_soft,
                        # "rouge_loss_extra": rouge_loss_extra,
                        # "reinforce_loss": reinforce_loss,
                        # "risk_loss": L_risk,
                        # "XENT_loss": XENT_loss,
                    }))

        # EVAL mode (evaluating model performance).
        if mode == tf.estimator.ModeKeys.EVAL:
            eval_metrics = model_params.estimator_eval_metrics_fn(
                features, outputs)
            return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                                  loss=XENT_loss,
                                                  eval_metrics=eval_metrics)
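The learning-rate logic buried in the training branch is an inverse-square-root decay with an optional linear warmup. A minimal sketch of just that schedule, assuming TF 1.x; `transformer_lr` is an illustrative name:

    import tensorflow.compat.v1 as tf

    def transformer_lr(init_lr, global_step, warmup_steps=None):
      # Inverse-square-root decay: since 1/sqrt(10000) == 0.01, the /0.01
      # factor makes lr plateau at exactly init_lr for steps <= 10000, then
      # decay as 1/sqrt(step) afterwards.
      lr = init_lr / 0.01 * tf.rsqrt(
          tf.maximum(tf.to_float(global_step), 10000))
      if warmup_steps:
        # Linear warmup from ~0, capped by the decayed value; the example
        # above enables this only when fine-tuning from a checkpoint.
        lr = tf.minimum(
            tf.to_float(global_step + 1) / warmup_steps * init_lr, lr)
      return lr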
Example #5
    def model_fn(features, labels, mode, config, params):
        """Estimator model function."""

        del labels
        del config
        del params

        tf.get_variable_scope().set_initializer(
            tf.variance_scaling_initializer(1.0,
                                            mode="fan_avg",
                                            distribution="uniform"))

        if mode == tf.estimator.ModeKeys.PREDICT:
            predictions = model_params.estimator_prediction_fn(features)

            if include_features_in_predictions:
                predictions.update(features)

            if decode_keys:
                # Decode the raw ids into strings in prediction.
                def decode_host_call(tensor_dict):
                    for key in decode_keys:
                        predictions[key] = public_parsing_ops.decode(
                            tensor_dict[key], model_params.vocab_filename,
                            model_params.encoder_type)
                    return tensor_dict

                contrib_tpu.outside_compilation(decode_host_call, predictions)
            return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                                  predictions=predictions)

        training = mode == tf.estimator.ModeKeys.TRAIN
        if use_tpu and model_params.use_bfloat16:
            with contrib_tpu.bfloat16_scope():
                loss, outputs = model_params.model()(features, training)
        else:
            loss, outputs = model_params.model()(features, training)

        # TPU requires that all outputs have a batch dimension; it doesn't
        # handle scalars. Tile all scalars into 1-D vectors.
        outputs = _tile_scalar_to_batch_size(outputs, model_params.batch_size)

        if mode == tf.estimator.ModeKeys.TRAIN:
            init_lr = model_params.learning_rate
            global_step = tf.train.get_global_step()
            lr = init_lr / 0.01 * tf.rsqrt(
                tf.maximum(tf.to_float(global_step), 10000))
            if train_init_checkpoint:
                lr = tf.minimum(
                    tf.to_float(global_step + 1) / train_warmup_steps *
                    init_lr, lr)

            optimizer = adafactor.AdafactorOptimizer(
                learning_rate=lr,
                decay_rate=adafactor.adafactor_decay_rate_pow(0.8),
                beta1=0.0)
            if use_tpu:
                optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)
            train_op = optimizer.minimize(loss, global_step=global_step)

            return tpu_estimator.TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                train_op=train_op,
                scaffold_fn=_load_vars_from_checkpoint(use_tpu,
                                                       train_init_checkpoint),
                host_call=add_scalars_to_summary(model_dir,
                                                 {"learning_rate": lr}))
        if mode == tf.estimator.ModeKeys.EVAL:
            eval_metrics = model_params.estimator_eval_metrics_fn(
                features, outputs)
            return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                                  loss=loss,
                                                  eval_metrics=eval_metrics)
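Both model functions wrap their optimizer in `CrossShardOptimizer` when running on TPU. A minimal sketch of that pattern in isolation, assuming TF 1.x; the Adam choice and `build_train_op` name are illustrative (the examples use Adafactor):

    import tensorflow.compat.v1 as tf
    from tensorflow.contrib.tpu.python.tpu import tpu_optimizer

    def build_train_op(loss, learning_rate, use_tpu):
      optimizer = tf.train.AdamOptimizer(learning_rate)
      if use_tpu:
        # CrossShardOptimizer all-reduces gradients across TPU cores before
        # applying them, so every replica steps with the mean gradient.
        optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)
      return optimizer.minimize(
          loss, global_step=tf.train.get_global_step())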
Example #6
    def model_fn(features, labels, mode, params=None):
        """Constructs the object detection model.

    Args:
      features: Dictionary of feature tensors, returned from `input_fn`.
      labels: Dictionary of groundtruth tensors if mode is TRAIN or EVAL,
        otherwise None.
      mode: Mode key from tf.estimator.ModeKeys.
      params: Parameter dictionary passed from the estimator.

    Returns:
      An `EstimatorSpec` that encapsulates the model and its serving
        configurations.
    """
        params = params or {}
        total_loss, train_op, detections, export_outputs = None, None, None, None
        is_training = mode == tf.estimator.ModeKeys.TRAIN

        # Make sure to set the Keras learning phase. True during training,
        # False for inference.
        tf.keras.backend.set_learning_phase(is_training)
        # Set policy for mixed-precision training with Keras-based models.
        if use_tpu and train_config.use_bfloat16:
            from tensorflow.python.keras.engine import base_layer_utils  # pylint: disable=g-import-not-at-top
            # Enable v2 behavior, as `mixed_bfloat16` is only supported in TF 2.0.
            base_layer_utils.enable_v2_dtype_behavior()
            tf.compat.v2.keras.mixed_precision.experimental.set_policy(
                'mixed_bfloat16')
        detection_model = detection_model_fn(is_training=is_training,
                                             add_summaries=(not use_tpu))
        scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:
            labels = unstack_batch(labels,
                                   unpad_groundtruth_tensors=train_config.
                                   unpad_groundtruth_tensors)
        elif mode == tf.estimator.ModeKeys.EVAL:
            # When evaluating on training data, we need to check whether the
            # groundtruth tensors must be unpadded.
            boxes_shape = (labels[fields.InputDataFields.groundtruth_boxes].
                           get_shape().as_list())
            unpad_groundtruth_tensors = boxes_shape[
                1] is not None and not use_tpu
            labels = unstack_batch(
                labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)

        if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
            provide_groundtruth(detection_model, labels)

        preprocessed_images = features[fields.InputDataFields.image]

        side_inputs = detection_model.get_side_inputs(features)

        if use_tpu and train_config.use_bfloat16:
            with contrib_tpu.bfloat16_scope():
                prediction_dict = detection_model.predict(
                    preprocessed_images,
                    features[fields.InputDataFields.true_image_shape],
                    **side_inputs)
                prediction_dict = ops.bfloat16_to_float32_nested(
                    prediction_dict)
        else:
            prediction_dict = detection_model.predict(
                preprocessed_images,
                features[fields.InputDataFields.true_image_shape],
                **side_inputs)

        def postprocess_wrapper(args):
            return detection_model.postprocess(args[0], args[1])

        if mode in (tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT):
            if use_tpu and postprocess_on_cpu:
                detections = contrib_tpu.outside_compilation(
                    postprocess_wrapper,
                    (prediction_dict,
                     features[fields.InputDataFields.true_image_shape]))
            else:
                detections = postprocess_wrapper(
                    (prediction_dict,
                     features[fields.InputDataFields.true_image_shape]))

        if mode == tf.estimator.ModeKeys.TRAIN:
            load_pretrained = hparams.load_pretrained if hparams else False
            if train_config.fine_tune_checkpoint and load_pretrained:
                if not train_config.fine_tune_checkpoint_type:
                    # train_config.from_detection_checkpoint field is deprecated. For
                    # backward compatibility, set train_config.fine_tune_checkpoint_type
                    # based on train_config.from_detection_checkpoint.
                    if train_config.from_detection_checkpoint:
                        train_config.fine_tune_checkpoint_type = 'detection'
                    else:
                        train_config.fine_tune_checkpoint_type = 'classification'
                asg_map = detection_model.restore_map(
                    fine_tune_checkpoint_type=train_config.
                    fine_tune_checkpoint_type,
                    load_all_detection_checkpoint_vars=(
                        train_config.load_all_detection_checkpoint_vars))
                available_var_map = (
                    variables_helper.get_variables_available_in_checkpoint(
                        asg_map,
                        train_config.fine_tune_checkpoint,
                        include_global_step=False))
                if use_tpu:

                    def tpu_scaffold():
                        tf.train.init_from_checkpoint(
                            train_config.fine_tune_checkpoint,
                            available_var_map)
                        return tf.train.Scaffold()

                    scaffold_fn = tpu_scaffold
                else:
                    tf.train.init_from_checkpoint(
                        train_config.fine_tune_checkpoint, available_var_map)

        if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
            if (mode == tf.estimator.ModeKeys.EVAL
                    and eval_config.use_dummy_loss_in_eval):
                total_loss = tf.constant(1.0)
                losses_dict = {'Loss/total_loss': total_loss}
            else:
                losses_dict = detection_model.loss(
                    prediction_dict,
                    features[fields.InputDataFields.true_image_shape])
                losses = [loss_tensor for loss_tensor in losses_dict.values()]
                if train_config.add_regularization_loss:
                    regularization_losses = detection_model.regularization_losses(
                    )
                    if use_tpu and train_config.use_bfloat16:
                        regularization_losses = ops.bfloat16_to_float32_nested(
                            regularization_losses)
                    if regularization_losses:
                        regularization_loss = tf.add_n(
                            regularization_losses, name='regularization_loss')
                        losses.append(regularization_loss)
                        losses_dict[
                            'Loss/regularization_loss'] = regularization_loss
                total_loss = tf.add_n(losses, name='total_loss')
                losses_dict['Loss/total_loss'] = total_loss

            if 'graph_rewriter_config' in configs:
                graph_rewriter_fn = graph_rewriter_builder.build(
                    configs['graph_rewriter_config'], is_training=is_training)
                graph_rewriter_fn()

            # TODO(rathodv): Stop creating optimizer summary vars in EVAL mode once we
            # can write learning rate summaries on TPU without host calls.
            global_step = tf.train.get_or_create_global_step()
            training_optimizer, optimizer_summary_vars = optimizer_builder.build(
                train_config.optimizer)

        if mode == tf.estimator.ModeKeys.TRAIN:
            if use_tpu:
                training_optimizer = contrib_tpu.CrossShardOptimizer(
                    training_optimizer)

            # Optionally freeze some layers by setting their gradients to be zero.
            trainable_variables = None
            include_variables = (train_config.update_trainable_variables
                                 if train_config.update_trainable_variables
                                 else None)
            exclude_variables = (train_config.freeze_variables
                                 if train_config.freeze_variables else None)
            trainable_variables = contrib_framework.filter_variables(
                tf.trainable_variables(),
                include_patterns=include_variables,
                exclude_patterns=exclude_variables)

            clip_gradients_value = None
            if train_config.gradient_clipping_by_norm > 0:
                clip_gradients_value = train_config.gradient_clipping_by_norm

            if not use_tpu:
                for var in optimizer_summary_vars:
                    tf.summary.scalar(var.op.name, var)
            summaries = [] if use_tpu else None
            if train_config.summarize_gradients:
                summaries = [
                    'gradients', 'gradient_norm', 'global_gradient_norm'
                ]
            train_op = contrib_layers.optimize_loss(
                loss=total_loss,
                global_step=global_step,
                learning_rate=None,
                clip_gradients=clip_gradients_value,
                optimizer=training_optimizer,
                update_ops=detection_model.updates(),
                variables=trainable_variables,
                summaries=summaries,
                name='')  # Prevent a scope prefix on all variables.

        if mode == tf.estimator.ModeKeys.PREDICT:
            exported_output = exporter_lib.add_output_tensor_nodes(detections)
            export_outputs = {
                tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
                tf.estimator.export.PredictOutput(exported_output)
            }

        eval_metric_ops = None
        scaffold = None
        if mode == tf.estimator.ModeKeys.EVAL:
            class_agnostic = (fields.DetectionResultFields.detection_classes
                              not in detections)
            groundtruth = _prepare_groundtruth_for_eval(
                detection_model, class_agnostic,
                eval_input_config.max_number_of_boxes)
            use_original_images = fields.InputDataFields.original_image in features
            if use_original_images:
                eval_images = features[fields.InputDataFields.original_image]
                true_image_shapes = tf.slice(
                    features[fields.InputDataFields.true_image_shape], [0, 0],
                    [-1, 3])
                original_image_spatial_shapes = features[
                    fields.InputDataFields.original_image_spatial_shape]
            else:
                eval_images = features[fields.InputDataFields.image]
                true_image_shapes = None
                original_image_spatial_shapes = None

            eval_dict = eval_util.result_dict_for_batched_example(
                eval_images,
                features[inputs.HASH_KEY],
                detections,
                groundtruth,
                class_agnostic=class_agnostic,
                scale_to_absolute=True,
                original_image_spatial_shapes=original_image_spatial_shapes,
                true_image_shapes=true_image_shapes)

            if fields.InputDataFields.image_additional_channels in features:
                eval_dict[fields.InputDataFields.
                          image_additional_channels] = features[
                              fields.InputDataFields.image_additional_channels]

            if class_agnostic:
                category_index = label_map_util.create_class_agnostic_category_index(
                )
            else:
                category_index = label_map_util.create_category_index_from_labelmap(
                    eval_input_config.label_map_path)
            vis_metric_ops = None
            if not use_tpu and use_original_images:
                eval_metric_op_vis = vis_utils.VisualizeSingleFrameDetections(
                    category_index,
                    max_examples_to_draw=eval_config.num_visualizations,
                    max_boxes_to_draw=eval_config.max_num_boxes_to_visualize,
                    min_score_thresh=eval_config.min_score_threshold,
                    use_normalized_coordinates=False)
                vis_metric_ops = eval_metric_op_vis.get_estimator_eval_metric_ops(
                    eval_dict)

            # Eval metrics on a single example.
            eval_metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
                eval_config, list(category_index.values()), eval_dict)
            for loss_key, loss_tensor in iter(losses_dict.items()):
                eval_metric_ops[loss_key] = tf.metrics.mean(loss_tensor)
            for var in optimizer_summary_vars:
                eval_metric_ops[var.op.name] = (var, tf.no_op())
            if vis_metric_ops is not None:
                eval_metric_ops.update(vis_metric_ops)
            eval_metric_ops = {str(k): v for k, v in eval_metric_ops.items()}

            if eval_config.use_moving_averages:
                variable_averages = tf.train.ExponentialMovingAverage(0.0)
                variables_to_restore = variable_averages.variables_to_restore()
                keep_checkpoint_every_n_hours = (
                    train_config.keep_checkpoint_every_n_hours)
                saver = tf.train.Saver(
                    variables_to_restore,
                    keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours
                )
                scaffold = tf.train.Scaffold(saver=saver)

        # EVAL executes on CPU, so use regular non-TPU EstimatorSpec.
        if use_tpu and mode != tf.estimator.ModeKeys.EVAL:
            return contrib_tpu.TPUEstimatorSpec(mode=mode,
                                                scaffold_fn=scaffold_fn,
                                                predictions=detections,
                                                loss=total_loss,
                                                train_op=train_op,
                                                eval_metrics=eval_metric_ops,
                                                export_outputs=export_outputs)
        else:
            if scaffold is None:
                keep_checkpoint_every_n_hours = (
                    train_config.keep_checkpoint_every_n_hours)
                saver = tf.train.Saver(
                    sharded=True,
                    keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
                    save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                scaffold = tf.train.Scaffold(saver=saver)
            return tf.estimator.EstimatorSpec(mode=mode,
                                              predictions=detections,
                                              loss=total_loss,
                                              train_op=train_op,
                                              eval_metric_ops=eval_metric_ops,
                                              export_outputs=export_outputs,
                                              scaffold=scaffold)
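The layer-freezing step in the TRAIN branch relies on `contrib_framework.filter_variables`. A minimal sketch, assuming TF 1.x; the regex patterns below are illustrative, not the Object Detection API's actual scope names:

    import tensorflow.compat.v1 as tf
    from tensorflow.contrib import framework as contrib_framework

    def get_trainable_variables(include_patterns, exclude_patterns):
      # filter_variables keeps variables whose names match any include
      # pattern and drops those matching any exclude pattern; passing the
      # result as `variables` to optimize_loss freezes everything else.
      return contrib_framework.filter_variables(
          tf.trainable_variables(),
          include_patterns=include_patterns,
          exclude_patterns=exclude_patterns)

    # Example: train only the box predictor, keeping batch norm frozen.
    # trainable = get_trainable_variables(['BoxPredictor'], ['BatchNorm'])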