Example #1
  def __init__(self, anchors, num_classes, match_threshold=0.7,
               unmatched_threshold=0.3, rpn_batch_size_per_im=256,
               rpn_fg_fraction=0.5):
    """Constructs anchor labeler to assign labels to anchors.

    Args:
      anchors: an instance of class Anchors.
      num_classes: integer number representing number of classes in the dataset.
      match_threshold: a float number between 0 and 1 representing the
        lower-bound threshold to assign positive labels for anchors. An anchor
        with a score over the threshold is labeled positive.
      unmatched_threshold: a float number between 0 and 1 representing the
        upper-bound threshold to assign negative labels for anchors. An anchor
        with a score below the threshold is labeled negative.
      rpn_batch_size_per_im: an integer number that represents the number of
        sampled anchors per image in the first stage (region proposal network).
      rpn_fg_fraction: a float number between 0 and 1 representing the fraction
        of positive anchors (foreground) in the first stage.
    """
    similarity_calc = region_similarity_calculator.IouSimilarity()
    matcher = argmax_matcher.ArgMaxMatcher(
        match_threshold,
        unmatched_threshold=unmatched_threshold,
        negatives_lower_than_unmatched=True,
        force_match_for_each_row=True)
    box_coder = faster_rcnn_box_coder.FasterRcnnBoxCoder()

    self._target_assigner = target_assigner.TargetAssigner(
        similarity_calc, matcher, box_coder)
    self._anchors = anchors
    self._match_threshold = match_threshold
    self._unmatched_threshold = unmatched_threshold
    self._rpn_batch_size_per_im = rpn_batch_size_per_im
    self._rpn_fg_fraction = rpn_fg_fraction
    self._num_classes = num_classes
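
The two thresholds split anchors into three bands: anchors whose best IoU is at or above match_threshold become positive, those below unmatched_threshold become negative, and those in between are ignored (force_match_for_each_row additionally guarantees every ground-truth box gets at least one anchor). A minimal NumPy sketch of the threshold logic, with illustrative names only:

# Threshold-based anchor labeling sketch; not part of the library above.
import numpy as np

def label_anchors(iou, match_threshold=0.7, unmatched_threshold=0.3):
    """Returns 1 for positive, 0 for negative, -1 for ignored anchors."""
    max_iou = iou.max(axis=1)                        # best ground-truth overlap per anchor
    labels = -np.ones(iou.shape[0], dtype=np.int32)  # default: ignore
    labels[max_iou >= match_threshold] = 1
    labels[max_iou < unmatched_threshold] = 0
    return labels

# Three anchors vs. two ground-truth boxes.
iou = np.array([[0.8, 0.1], [0.5, 0.4], [0.2, 0.05]])
print(label_anchors(iou))  # [ 1 -1  0]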
Example #2
def encode_labels(gt_boxes, gt_labels):
    """Labels anchors with ground truth inputs.

  Args:
    gt_boxes: A float tensor with shape [N, 4] representing groundtruth boxes.
      For each row, it stores [y0, x0, y1, x1] for four corners of a box.
    gt_labels: A integer tensor with shape [N, 1] representing groundtruth
      classes.
  Returns:
    encoded_classes: a tensor with shape [num_anchors, 1].
    encoded_boxes: a tensor with shape [num_anchors, 4].
    num_positives: scalar tensor storing number of positives in an image.
  """
    similarity_calc = region_similarity_calculator.IouSimilarity()
    matcher = argmax_matcher.ArgMaxMatcher(
        matched_threshold=ssd_constants.MATCH_THRESHOLD,
        unmatched_threshold=ssd_constants.MATCH_THRESHOLD,
        negatives_lower_than_unmatched=True,
        force_match_for_each_row=True)

    box_coder = faster_rcnn_box_coder.FasterRcnnBoxCoder(
        scale_factors=ssd_constants.BOX_CODER_SCALES)

    default_boxes = box_list.BoxList(
        tf.convert_to_tensor(DefaultBoxes()('ltrb')))
    target_boxes = box_list.BoxList(gt_boxes)

    assigner = target_assigner.TargetAssigner(similarity_calc, matcher,
                                              box_coder)

    encoded_classes, _, encoded_boxes, _, matches = assigner.assign(
        default_boxes, target_boxes, gt_labels)
    num_matched_boxes = tf.reduce_sum(
        tf.cast(tf.not_equal(matches.match_results, -1), tf.float32))
    return encoded_classes, encoded_boxes, num_matched_boxes
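
FasterRcnnBoxCoder represents each matched box relative to its anchor as [ty, tx, th, tw], each component multiplied by the corresponding scale factor. A self-contained sketch of that encoding, assuming the common (10, 10, 5, 5) values for BOX_CODER_SCALES:

# Faster R-CNN box encoding sketch; boxes are [y0, x0, y1, x1].
import math

def encode_box(box, anchor, scale_factors=(10.0, 10.0, 5.0, 5.0)):
    ya, xa = (anchor[0] + anchor[2]) / 2, (anchor[1] + anchor[3]) / 2
    ha, wa = anchor[2] - anchor[0], anchor[3] - anchor[1]
    yc, xc = (box[0] + box[2]) / 2, (box[1] + box[3]) / 2
    h, w = box[2] - box[0], box[3] - box[1]
    return [(yc - ya) / ha * scale_factors[0],
            (xc - xa) / wa * scale_factors[1],
            math.log(h / ha) * scale_factors[2],
            math.log(w / wa) * scale_factors[3]]

print(encode_box([0.1, 0.1, 0.5, 0.5], [0.0, 0.0, 0.4, 0.4]))  # [2.5, 2.5, 0.0, 0.0]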
Example #3
  def __init__(self, is_training, num_classes, params=DEFAULT_PARAMS):
    """
    Args:
        is_training: whether the model is in training mode
        num_classes: number of classes for prediction
        params: parameters for model definition
            resnet_arch: name of the resnet architecture to use
    """
    self._is_training = is_training
    self._num_classes = num_classes
    self._nms_fn = post_processing.batch_multiclass_non_max_suppression
    self._score_convert_fn = tf.sigmoid
    self._params = params
    # Unmatched anchors get an all-zeros class target of length num_classes + 1.
    # An alternative encoding marks background explicitly at index 0:
    # self._unmatched_class_label = tf.constant([1] + self._num_classes * [0], tf.float32)
    self._unmatched_class_label = tf.constant(
        (self._num_classes + 1) * [0], tf.float32)
    self._target_assigner = create_target_assigner(
        unmatched_cls_target=self._unmatched_class_label)
    self._anchors = None
    self._anchor_generator = None
    self._box_coder = faster_rcnn_box_coder.FasterRcnnBoxCoder()
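
The active line and the commented-out one differ only in how an unmatched (background) anchor is encoded: an all-zeros target of length num_classes + 1, versus an explicit background indicator at index 0. A quick sketch of the two targets, with num_classes = 3 chosen arbitrarily:

# Two possible unmatched-anchor class targets; num_classes = 3 is arbitrary.
import tensorflow as tf

num_classes = 3
all_zeros = tf.constant((num_classes + 1) * [0], tf.float32)            # [0, 0, 0, 0]
explicit_background = tf.constant([1] + num_classes * [0], tf.float32)  # [1, 0, 0, 0]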
Example #4
  def __init__(self, anchors, num_classes, match_threshold=0.5):
    """Constructs anchor labeler to assign labels to anchors.

    Args:
      anchors: an instance of class Anchors.
      num_classes: integer number representing number of classes in the dataset.
      match_threshold: float number between 0 and 1 representing the threshold
        to assign positive labels for anchors.
    """
    similarity_calc = region_similarity_calculator.IouSimilarity()
    matcher = argmax_matcher.ArgMaxMatcher(
        match_threshold,
        unmatched_threshold=match_threshold,
        negatives_lower_than_unmatched=True,
        force_match_for_each_row=True)
    box_coder = faster_rcnn_box_coder.FasterRcnnBoxCoder()

    self._target_assigner = target_assigner.TargetAssigner(
        similarity_calc, matcher, box_coder)
    self._anchors = anchors
    self._match_threshold = match_threshold
    self._num_classes = num_classes
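
Here match_threshold and unmatched_threshold coincide, so there is no ignore band: every anchor comes out positive or negative. Adapting the sketch after Example #1:

# With equal thresholds no anchor is ignored; illustrative only.
import numpy as np

iou = np.array([[0.8, 0.1], [0.5, 0.4], [0.2, 0.05]])
labels = (iou.max(axis=1) >= 0.5).astype(np.int32)
print(labels)  # [1 1 0] -- no -1 ("ignore") entries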
Example #5
def _model_fn(features, labels, mode, params, model):
    """Model defination for the SSD model based on ResNet-50.

  Args:
    features: the input image tensor with shape [batch_size, height, width, 3].
      The height and width are fixed and equal.
    labels: the input labels in a dictionary. The labels include class targets
      and box targets which are dense label maps. The labels are generated from
      get_input_fn function in data/dataloader.py
    mode: the mode of TPUEstimator including TRAIN, EVAL, and PREDICT.
    params: the dictionary defines hyperparameters of model. The default
      settings are in default_hparams function in this file.
    model: the SSD model outputs class logits and box regression outputs.

  Returns:
    spec: the EstimatorSpec or TPUEstimatorSpec to run training, evaluation,
      or prediction.
  """
    if mode == tf.estimator.ModeKeys.PREDICT:
        # In PREDICT mode the input pipeline packs everything into `features`;
        # treat it as the labels dict and pull the image tensor out of it.
        labels = features
        features = labels.pop('image')

    features -= tf.constant(constants.NORMALIZATION_MEAN,
                            shape=[1, 1, 3],
                            dtype=features.dtype)
    COEF_STD = 1.0 / tf.constant(
        constants.NORMALIZATION_STD, shape=[1, 1, 3], dtype=features.dtype)
    features *= COEF_STD

    def _model_outputs():
        return model(features,
                     params,
                     is_training_bn=(mode == tf.estimator.ModeKeys.TRAIN))

    if params['dtype'] == 'bf16':
        with tf.compat.v1.tpu.bfloat16_scope():
            cls_outputs, box_outputs = _model_outputs()
            levels = cls_outputs.keys()
            for level in levels:
                cls_outputs[level] = tf.cast(cls_outputs[level], tf.float32)
                box_outputs[level] = tf.cast(box_outputs[level], tf.float32)
    else:
        cls_outputs, box_outputs = _model_outputs()
        levels = cls_outputs.keys()

    # First check if it is in PREDICT mode.
    if mode == tf.estimator.ModeKeys.PREDICT:
        flattened_cls, flattened_box = concat_outputs(cls_outputs, box_outputs,
                                                      True)
        ssd_box_coder = faster_rcnn_box_coder.FasterRcnnBoxCoder(
            scale_factors=constants.BOX_CODER_SCALES)

        anchors = box_list.BoxList(
            tf.convert_to_tensor(dataloader.DefaultBoxes()('ltrb')))

        decoded_boxes = box_coder.batch_decode(encoded_boxes=flattened_box,
                                               box_coder=ssd_box_coder,
                                               anchors=anchors)

        pred_scores = tf.nn.softmax(flattened_cls, axis=2)

        pred_scores, indices = select_top_k_scores(
            pred_scores, constants.MAX_NUM_EVAL_BOXES)
        predictions = dict(
            labels,
            indices=indices,
            pred_scores=pred_scores,
            pred_box=decoded_boxes,
        )

        if params['visualize_dataloader']:
            # this is for inference visualization.
            predictions['image'] = features

        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Load pretrained model from checkpoint.
    if params['resnet_checkpoint'] and mode == tf.estimator.ModeKeys.TRAIN:

        def scaffold_fn():
            """Loads pretrained model through scaffold function."""
            tf.train.init_from_checkpoint(
                params['resnet_checkpoint'], {
                    '/': 'resnet%s/' % constants.RESNET_DEPTH,
                })
            return tf.train.Scaffold()
    else:
        scaffold_fn = None

    # Set up training loss and learning rate.
    update_learning_rate_schedule_parameters(params)
    global_step = tf.train.get_or_create_global_step()
    learning_rate = learning_rate_schedule(params, global_step)
    # cls_loss and box_loss are for logging; only total_loss is optimized.
    loss, cls_loss, box_loss = detection_loss(cls_outputs, box_outputs, labels)

    total_loss = loss + params['weight_decay'] * tf.add_n(
        [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.MomentumOptimizer(learning_rate,
                                               momentum=constants.MOMENTUM)

        if params['distributed_optimizer']:
            optimizer = params['distributed_optimizer'](optimizer)

        # Batch norm requires update_ops to be added as a train_op dependency.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        train_op = tf.group(optimizer.minimize(total_loss, global_step),
                            update_ops)
        return model_fn_lib.EstimatorSpec(
            mode=mode,
            loss=loss,
            train_op=train_op,
            scaffold=scaffold_fn() if scaffold_fn else None)

    if mode == tf.estimator.ModeKeys.EVAL:
        raise NotImplementedError
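
At prediction time, box_coder.batch_decode inverts the anchor encoding to recover absolute boxes. A sketch of the per-box inverse of the encode sketch after Example #2, with the same assumed scale factors:

# Inverse (decode) transform sketch; boxes are [y0, x0, y1, x1].
import math

def decode_box(code, anchor, scale_factors=(10.0, 10.0, 5.0, 5.0)):
    ya, xa = (anchor[0] + anchor[2]) / 2, (anchor[1] + anchor[3]) / 2
    ha, wa = anchor[2] - anchor[0], anchor[3] - anchor[1]
    yc = code[0] / scale_factors[0] * ha + ya
    xc = code[1] / scale_factors[1] * wa + xa
    h = math.exp(code[2] / scale_factors[2]) * ha
    w = math.exp(code[3] / scale_factors[3]) * wa
    return [yc - h / 2, xc - w / 2, yc + h / 2, xc + w / 2]

# Round-trips the encode example: recovers [0.1, 0.1, 0.5, 0.5].
print(decode_box([2.5, 2.5, 0.0, 0.0], [0.0, 0.0, 0.4, 0.4]))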
Example #6
        def _parse_example(data):
            with tf.name_scope('augmentation'):
                source_id = data['source_id']
                image = tf.image.convert_image_dtype(data['image'],
                                                     dtype=tf.float32)
                raw_shape = tf.shape(image)
                boxes = data['groundtruth_boxes']
                classes = tf.reshape(data['groundtruth_classes'], [-1, 1])

                # Only 80 of the 90 COCO classes are used.
                class_map = tf.convert_to_tensor(ssd_constants.CLASS_MAP)
                classes = tf.gather(class_map, classes)
                classes = tf.cast(classes, dtype=tf.float32)

                if self._is_training:
                    image, boxes, classes = ssd_crop(image, boxes, classes)

                    # random_horizontal_flip() is hard coded to flip with 50% chance.
                    mlperf_log.ssd_print(
                        key=mlperf_log.RANDOM_FLIP_PROBABILITY, value=0.5)
                    image, boxes = preprocessor.random_horizontal_flip(
                        image=image, boxes=boxes)

                    # TODO(shibow): Investigate the parameters for color jitter.
                    image = color_jitter(image,
                                         brightness=0.125,
                                         contrast=0.5,
                                         saturation=0.5,
                                         hue=0.05)
                    image = normalize_image(image)

                    if params['use_bfloat16']:
                        image = tf.cast(image, dtype=tf.bfloat16)

                    encoded_classes, encoded_boxes, num_matched_boxes = encode_labels(
                        boxes, classes)

                    # TODO(taylorrobie): Check that this cast is valid.
                    encoded_classes = tf.cast(encoded_classes, tf.int32)

                    labels = {
                        ssd_constants.NUM_MATCHED_BOXES: num_matched_boxes,
                        ssd_constants.BOXES: encoded_boxes,
                        ssd_constants.CLASSES: encoded_classes,
                    }
                    # This is for dataloader visualization; actual model doesn't use this.
                    if params['visualize_dataloader']:
                        box_coder = faster_rcnn_box_coder.FasterRcnnBoxCoder(
                            scale_factors=ssd_constants.BOX_CODER_SCALES)
                        decoded_boxes = tf.expand_dims(
                            box_coder.decode(
                                rel_codes=tf.squeeze(encoded_boxes),
                                anchors=box_list.BoxList(
                                    tf.convert_to_tensor(
                                        DefaultBoxes()('ltrb')))).get(),
                            axis=0)
                        labels['decoded_boxes'] = tf.squeeze(decoded_boxes)

                    return image, labels

                else:
                    mlperf_log.ssd_print(key=mlperf_log.INPUT_SIZE,
                                         value=ssd_constants.IMAGE_SIZE)
                    image = tf.image.resize_images(
                        image[tf.newaxis, :, :, :],
                        size=(ssd_constants.IMAGE_SIZE,
                              ssd_constants.IMAGE_SIZE))[0, :, :, :]

                    image = normalize_image(image)

                    if params['use_bfloat16']:
                        image = tf.cast(image, dtype=tf.bfloat16)

                    def trim_and_pad(inp_tensor, dim_1):
                        """Limit the number of boxes, and pad if necessary."""
                        inp_tensor = inp_tensor[:ssd_constants.MAX_NUM_EVAL_BOXES]
                        num_pad = ssd_constants.MAX_NUM_EVAL_BOXES - tf.shape(inp_tensor)[0]
                        inp_tensor = tf.pad(inp_tensor, [[0, num_pad], [0, 0]])
                        return tf.reshape(
                            inp_tensor, [ssd_constants.MAX_NUM_EVAL_BOXES, dim_1])

                    boxes = trim_and_pad(boxes, 4)
                    classes = trim_and_pad(classes, 1)

                    return {
                        ssd_constants.IMAGE: image,
                        ssd_constants.BOXES: boxes,
                        ssd_constants.CLASSES: classes,
                        ssd_constants.SOURCE_ID: tf.string_to_number(
                            source_id, tf.int32),
                        ssd_constants.RAW_SHAPE: raw_shape,
                    }
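
trim_and_pad exists because TPU evaluation needs statically shaped batches: every sample must carry exactly MAX_NUM_EVAL_BOXES rows. A standalone sketch with a small cap standing in for the constant:

# Standalone trim_and_pad sketch; MAX_BOXES = 4 stands in for
# ssd_constants.MAX_NUM_EVAL_BOXES.
import tensorflow as tf

MAX_BOXES = 4

def trim_and_pad(t, dim_1):
    t = t[:MAX_BOXES]                        # drop boxes beyond the cap
    num_pad = MAX_BOXES - tf.shape(t)[0]     # rows still missing
    t = tf.pad(t, [[0, num_pad], [0, 0]])    # zero-pad to the cap
    return tf.reshape(t, [MAX_BOXES, dim_1])

boxes = tf.ones([2, 4])
print(trim_and_pad(boxes, 4).shape)  # (4, 4): 2 real rows, 2 zero rows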
Example #7
def _model_fn(features, labels, mode, params, model):
    """Model defination for the SSD model based on ResNet-50.

  Args:
    features: the input image tensor with shape [batch_size, height, width, 3].
      The height and width are fixed and equal.
    labels: the input labels in a dictionary. The labels include class targets
      and box targets which are dense label maps. The labels are generated from
      get_input_fn function in data/dataloader.py
    mode: the mode of TPUEstimator including TRAIN, EVAL, and PREDICT.
    params: the dictionary defines hyperparameters of model. The default
      settings are in default_hparams function in this file.
    model: the SSD model outputs class logits and box regression outputs.

  Returns:
    spec: the EstimatorSpec or TPUEstimatorSpec to run training, evaluation,
      or prediction.
  """
    if mode == tf.estimator.ModeKeys.PREDICT:
        # In PREDICT mode the input pipeline packs everything into `features`;
        # treat it as the labels dict and pull the image tensor out of it.
        labels = features
        features = labels.pop('image')

    # Manually apply the double transpose trick for training data.
    if params['transpose_input'] and mode != tf.estimator.ModeKeys.PREDICT:
        features = tf.transpose(features, [3, 0, 1, 2])
        labels[ssd_constants.BOXES] = tf.transpose(labels[ssd_constants.BOXES],
                                                   [2, 0, 1])
        labels[ssd_constants.CLASSES] = tf.transpose(
            labels[ssd_constants.CLASSES], [2, 0, 1])

    # Normalize the image to zero mean and unit variance.
    mlperf_log.ssd_print(key=mlperf_log.DATA_NORMALIZATION_MEAN,
                         value=ssd_constants.NORMALIZATION_MEAN)
    mlperf_log.ssd_print(key=mlperf_log.DATA_NORMALIZATION_STD,
                         value=ssd_constants.NORMALIZATION_STD)

    features -= tf.constant(ssd_constants.NORMALIZATION_MEAN,
                            shape=[1, 1, 3],
                            dtype=features.dtype)

    features /= tf.constant(ssd_constants.NORMALIZATION_STD,
                            shape=[1, 1, 3],
                            dtype=features.dtype)

    def _model_outputs():
        return model(features,
                     params,
                     is_training_bn=(mode == tf.estimator.ModeKeys.TRAIN))

    if params['use_bfloat16']:
        with bfloat16.bfloat16_scope():
            cls_outputs, box_outputs = _model_outputs()
            levels = cls_outputs.keys()
            for level in levels:
                cls_outputs[level] = tf.cast(cls_outputs[level], tf.float32)
                box_outputs[level] = tf.cast(box_outputs[level], tf.float32)
    else:
        cls_outputs, box_outputs = _model_outputs()
        levels = cls_outputs.keys()

    # First check if it is in PREDICT mode.
    if mode == tf.estimator.ModeKeys.PREDICT:
        flattened_cls, flattened_box = concat_outputs(cls_outputs, box_outputs)
        mlperf_log.ssd_print(key=mlperf_log.SCALES,
                             value=ssd_constants.BOX_CODER_SCALES)
        ssd_box_coder = faster_rcnn_box_coder.FasterRcnnBoxCoder(
            scale_factors=ssd_constants.BOX_CODER_SCALES)

        anchors = box_list.BoxList(
            tf.convert_to_tensor(dataloader.DefaultBoxes()('ltrb')))

        decoded_boxes = box_coder.batch_decode(encoded_boxes=flattened_box,
                                               box_coder=ssd_box_coder,
                                               anchors=anchors)

        pred_scores = tf.nn.softmax(flattened_cls, axis=2)

        pred_scores, indices = select_top_k_scores(
            pred_scores, ssd_constants.MAX_NUM_EVAL_BOXES)

        predictions = dict(
            labels,
            indices=indices,
            pred_scores=pred_scores,
            pred_box=decoded_boxes,
        )

        if params['visualize_dataloader']:
            # this is for inference visualization.
            predictions['image'] = features

        if params['use_tpu']:
            return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                                  predictions=predictions)

        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Load pretrained model from checkpoint.
    if params['resnet_checkpoint'] and mode == tf.estimator.ModeKeys.TRAIN:

        def scaffold_fn():
            """Loads pretrained model through scaffold function."""
            tf.train.init_from_checkpoint(
                params['resnet_checkpoint'], {
                    '/': 'resnet%s/' % ssd_constants.RESNET_DEPTH,
                })
            return tf.train.Scaffold()
    else:
        scaffold_fn = None

    # Set up training loss and learning rate.
    update_learning_rate_schedule_parameters(params)
    global_step = tf.train.get_or_create_global_step()
    learning_rate = learning_rate_schedule(params, global_step)
    mlperf_log.ssd_print(key=mlperf_log.OPT_LR, deferred=True)
    # cls_loss and box_loss are for logging; only total_loss is optimized.
    total_loss, cls_loss, box_loss = detection_loss(cls_outputs, box_outputs,
                                                    labels)

    total_loss += params['weight_decay'] * tf.add_n(
        [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

    host_call = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.MomentumOptimizer(learning_rate,
                                               momentum=ssd_constants.MOMENTUM)
        if params['use_tpu']:
            optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)

        mlperf_log.ssd_print(key=mlperf_log.OPT_NAME,
                             value='tf.train.MomentumOptimizer')
        # TODO(wangtao): figure out how to log learning rate.
        # mlperf_log.ssd_print(key=mlperf_log.OPT_LR, value=learning_rate)
        mlperf_log.ssd_print(key=mlperf_log.OPT_MOMENTUM,
                             value=ssd_constants.MOMENTUM)
        mlperf_log.ssd_print(key=mlperf_log.OPT_WEIGHT_DECAY,
                             value=params['weight_decay'])

        # Batch norm requires update_ops to be added as a train_op dependency.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if params['device'] == 'gpu':
            # GPU uses tf.group to avoid dependency overhead on update_ops;
            # multi-GPU also requires a different EstimatorSpec class object.
            train_op = tf.group(optimizer.minimize(total_loss, global_step),
                                update_ops)
            return model_fn_lib.EstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold=scaffold_fn() if scaffold_fn else None)
        else:
            with tf.control_dependencies(update_ops):
                train_op = optimizer.minimize(total_loss, global_step)

        if params['use_host_call']:

            def host_call_fn(global_step, total_loss, cls_loss, box_loss,
                             learning_rate):
                """Training host call. Creates scalar summaries for training metrics.

        This function is executed on the CPU and should not directly reference
        any Tensors in the rest of the `model_fn`. To pass Tensors from the
        model to the `metric_fn`, provide as part of the `host_call`. See
        https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
        for more information.

        Arguments should match the list of `Tensor` objects passed as the second
        element in the tuple passed to `host_call`.

        Args:
          global_step: `Tensor with shape `[batch, ]` for the global_step.
          total_loss: `Tensor` with shape `[batch, ]` for the training loss.
          cls_loss: `Tensor` with shape `[batch, ]` for the training cls loss.
          box_loss: `Tensor` with shape `[batch, ]` for the training box loss.
          learning_rate: `Tensor` with shape `[batch, ]` for the learning_rate.

        Returns:
          List of summary ops to run on the CPU host.
        """
                # Outfeed supports int32 but global_step is expected to be int64.
                global_step = tf.reduce_mean(global_step)
                # Host call fns are executed FLAGS.iterations_per_loop times after one
                # TPU loop is finished, setting max_queue value to the same as number of
                # iterations will make the summary writer only flush the data to storage
                # once per loop.
                with (tf.contrib.summary.create_file_writer(
                        params['model_dir'],
                        max_queue=params['iterations_per_loop']).as_default()):
                    with tf.contrib.summary.always_record_summaries():
                        tf.contrib.summary.scalar('total_loss',
                                                  tf.reduce_mean(total_loss),
                                                  step=global_step)
                        tf.contrib.summary.scalar('cls_loss',
                                                  tf.reduce_mean(cls_loss),
                                                  step=global_step)
                        tf.contrib.summary.scalar('box_loss',
                                                  tf.reduce_mean(box_loss),
                                                  step=global_step)
                        tf.contrib.summary.scalar(
                            'learning_rate',
                            tf.reduce_mean(learning_rate),
                            step=global_step)

                        return tf.contrib.summary.all_summary_ops()

            # To log the loss, current learning rate, and epoch for Tensorboard, the
            # summary op needs to be run on the host CPU via host_call. host_call
            # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
            # dimension. These Tensors are implicitly concatenated to
            # [params['batch_size']].
            global_step_t = tf.reshape(global_step, [1])
            total_loss_t = tf.reshape(total_loss, [1])
            cls_loss_t = tf.reshape(cls_loss, [1])
            box_loss_t = tf.reshape(box_loss, [1])
            learning_rate_t = tf.reshape(learning_rate, [1])
            host_call = (host_call_fn, [
                global_step_t, total_loss_t, cls_loss_t, box_loss_t,
                learning_rate_t
            ])
    else:
        train_op = None

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:
        raise NotImplementedError

    return tpu_estimator.TPUEstimatorSpec(mode=mode,
                                          loss=total_loss,
                                          train_op=train_op,
                                          host_call=host_call,
                                          eval_metrics=eval_metrics,
                                          scaffold_fn=scaffold_fn)
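
The "double transpose trick" near the top of this model_fn assumes the input pipeline delivered features already transposed to HWCN for faster TPU infeed; the [3, 0, 1, 2] permutation restores NHWC. A shape walk-through (the 300x300 image size is an assumption matching the usual SSD input):

# Shape walk-through of the double-transpose trick; illustrative only.
import tensorflow as tf

nhwc = tf.zeros([8, 300, 300, 3])            # layout the model consumes
hwcn = tf.transpose(nhwc, [1, 2, 3, 0])      # layout the infeed carries
restored = tf.transpose(hwcn, [3, 0, 1, 2])  # back to (8, 300, 300, 3)
print(restored.shape)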
Example #8
        def _parse_example(data):
            with tf.name_scope('augmentation'):
                source_id = data['source_id']
                image = data['image']  # dtype uint8
                raw_shape = tf.shape(image)
                boxes = data['groundtruth_boxes']
                classes = tf.reshape(data['groundtruth_classes'], [-1, 1])

                # Only 80 of the 90 COCO classes are used.
                class_map = tf.convert_to_tensor(ssd_constants.CLASS_MAP)
                classes = tf.gather(class_map, classes)
                classes = tf.cast(classes, dtype=tf.float32)

                if self._is_training:
                    image, boxes, classes = ssd_crop(image, boxes, classes)
                    # ssd_crop resizes and returns an image of dtype float32
                    # without changing its range (values stay in 0--255).
                    # Dividing by 255 converts it to the [0, 1] range. This is
                    # not done before cropping, to avoid a dtype cast (which
                    # incurs an additional memory copy).
                    image /= 255.0

                    # random_horizontal_flip() is hard coded to flip with 50% chance.
                    image, boxes = preprocessor.random_horizontal_flip(
                        image=image, boxes=boxes)

                    # TODO(shibow): Investigate the parameters for color jitter.
                    image = color_jitter(image,
                                         brightness=0.125,
                                         contrast=0.5,
                                         saturation=0.5,
                                         hue=0.05)

                    if params['use_bfloat16']:
                        image = tf.cast(image, dtype=tf.bfloat16)

                    encoded_classes, encoded_boxes, num_matched_boxes = encode_labels(
                        boxes, classes)

                    # TODO(taylorrobie): Check that this cast is valid.
                    encoded_classes = tf.cast(encoded_classes, tf.int32)

                    labels = {
                        ssd_constants.NUM_MATCHED_BOXES: num_matched_boxes,
                        ssd_constants.BOXES: encoded_boxes,
                        ssd_constants.CLASSES: tf.squeeze(encoded_classes,
                                                          axis=1),
                    }
                    # This is for dataloader visualization; actual model doesn't use this.
                    if params['visualize_dataloader']:
                        box_coder = faster_rcnn_box_coder.FasterRcnnBoxCoder(
                            scale_factors=ssd_constants.BOX_CODER_SCALES)
                        decoded_boxes = tf.expand_dims(
                            box_coder.decode(
                                rel_codes=tf.squeeze(encoded_boxes),
                                anchors=box_list.BoxList(
                                    tf.convert_to_tensor(
                                        DefaultBoxes()('ltrb')))).get(),
                            axis=0)
                        labels['decoded_boxes'] = tf.squeeze(decoded_boxes)

                    return image, labels

                else:
                    image = tf.image.resize_images(
                        image,
                        size=(ssd_constants.IMAGE_SIZE,
                              ssd_constants.IMAGE_SIZE))
                    # resize_images returns an image of dtype float32 and does
                    # not change its range. Divide by 255 to convert the image
                    # to the [0, 1] range.
                    image /= 255.

                    if params['use_bfloat16']:
                        image = tf.cast(image, dtype=tf.bfloat16)

                    def trim_and_pad(inp_tensor, dim_1):
                        """Limit the number of boxes, and pad if necessary."""
                        inp_tensor = inp_tensor[:ssd_constants.MAX_NUM_EVAL_BOXES]
                        num_pad = ssd_constants.MAX_NUM_EVAL_BOXES - tf.shape(inp_tensor)[0]
                        inp_tensor = tf.pad(inp_tensor, [[0, num_pad], [0, 0]])
                        return tf.reshape(
                            inp_tensor, [ssd_constants.MAX_NUM_EVAL_BOXES, dim_1])

                    boxes = trim_and_pad(boxes, 4)
                    classes = trim_and_pad(classes, 1)

                    sample = {
                        ssd_constants.IMAGE: image,
                        ssd_constants.BOXES: boxes,
                        ssd_constants.CLASSES: classes,
                        ssd_constants.SOURCE_ID: tf.string_to_number(
                            source_id, tf.int32),
                        ssd_constants.RAW_SHAPE: raw_shape,
                    }

                    if (not self._is_training and
                            self._count > params['eval_samples']):
                        sample[ssd_constants.IS_PADDED] = data[
                            ssd_constants.IS_PADDED]
                    return sample
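
Examples #6 and #8 arrive at the same [0, 1] float range by different routes: tf.image.convert_image_dtype rescales while casting, whereas here the image stays in 0--255 until a single divide. A small equivalence sketch:

# Two equivalent uint8 -> [0, 1] conversions.
import tensorflow as tf

img = tf.constant([[[0, 128, 255]]], dtype=tf.uint8)
a = tf.image.convert_image_dtype(img, tf.float32)  # rescales by 1/255 while casting
b = tf.cast(img, tf.float32) / 255.0               # cast first, then divide
print(float(tf.reduce_max(tf.abs(a - b))))         # ~0.0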