Ejemplo n.º 1
0
def initialize_tpu_system(cluster_resolver=None):
    """Initialize the TPU devices.

    Runs TPU system initialization for the TPU named by `cluster_resolver` and
    records the resulting topology in `_INITIALIZED_TPU_SYSTEMS`, keyed by the
    TPU name. If that TPU was already initialized, a warning is logged and the
    system is initialized again anyway.

    Args:
      cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
          which provides information about the TPU cluster. If None, a
          `TPUClusterResolver("")` is constructed.

    Returns:
      The tf.tpu.Topology object for the topology of the TPU cluster.

    Raises:
      AssertionError: If `cluster_resolver` is not a TPUClusterResolver (only
        while Python assertions are enabled).
    """
    if cluster_resolver is None:
        cluster_resolver = TPUClusterResolver("")
    assert isinstance(cluster_resolver, TPUClusterResolver)

    tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
    if tpu_name in _INITIALIZED_TPU_SYSTEMS:
        # tpu_name must be passed as an argument: logging only applies
        # `msg % args` when args are given, so without it the literal "%s"
        # would be emitted instead of the TPU name.
        logging.warning("TPU system %s has already been initialized. "
                        "Reinitializing the TPU can cause previously created "
                        "variables on TPU to be lost.", tpu_name)

    logging.info("Initializing the TPU system.")

    if context.executing_eagerly():
        # This function looks as it is for the following non-intuitive reasons.
        # tpu.initialize_system creates a dummy op whose sole purpose is to trigger
        # DistributedTPURewritePass. This pass actually adds real ops that
        # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
        # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
        # The easiest way to trigger a rewrite is to run the function with
        # TPUPartitionedCallOp.
        @function.defun
        def _tpu_init_fn():
            return tpu.initialize_system()

        # We can't call _tpu_init_fn normally (because it contains just a dummy op,
        # see above) but need to define it to get it added to eager context
        # and get its assigned name.
        # pylint: disable=protected-access
        graph_func = _tpu_init_fn._get_concrete_function_internal()
        func_name = compat.as_str(graph_func._inference_function.name)
        # pylint: enable=protected-access

        # Run the init function by name on the first TPU host; its single
        # string output is the serialized topology.
        with ops.device(get_first_tpu_host_device(cluster_resolver)):
            output = tpu_functional_ops.TPUPartitionedCall(
                args=[], device_ordinal=0, Tout=[dtypes.string], f=func_name)
        serialized_topology = output[0].numpy()
    else:
        # Graph mode: build a throwaway graph and run initialize_system in a
        # session connected to the resolver's master.
        master = cluster_resolver.master()
        session_config = config_pb2.ConfigProto(allow_soft_placement=True)
        with ops.Graph().as_default():
            with session_lib.Session(config=session_config,
                                     target=master) as sess:
                serialized_topology = sess.run(tpu.initialize_system())

    logging.info("Finished initializing TPU system.")
    tpu_topology = topology.Topology(serialized=serialized_topology)
    _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology

    return tpu_topology
 def inner_func(*args, **kwargs):
   """Runs `tf_func` on a TPU core through `TPUPartitionedCall`.

   `TPUPartitionedCall` only takes a list of tensors, so the op inputs are the
   positional args, then the keyword-arg values in call-site order, then the
   concrete function's captured inputs. No nested flattening is performed:
   each positional arg and each keyword value becomes one list entry, so
   presumably every one of them is expected to be a tensor — TODO confirm.

   NOTE(review): keyword values follow the caller's ordering, not necessarily
   the order of `tf_func`'s signature; confirm callers rely only on the
   ordering shown here.
   """
   concrete_fn = tf_func.get_concrete_function(*args, **kwargs)
   positional_inputs = list(args)
   keyword_inputs = list(kwargs.values())
   call_inputs = (
       positional_inputs + keyword_inputs + concrete_fn.captured_inputs)
   core_ordinal = tpu_ops.tpu_ordinal_selector()
   output_dtypes = [
       out_arg.type for out_arg in concrete_fn.function_def.signature.output_arg
   ]
   return tpu_functional.TPUPartitionedCall(
       args=call_inputs,
       device_ordinal=core_ordinal,
       Tout=output_dtypes,
       f=concrete_fn)
Ejemplo n.º 3
0
    def tpu_call(self, args):
        """Runs `model_fn` on a TPU core and looks its output up in the vocab.

        The model is built inside `tpu.rewrite`, wrapped in a `Defun`, and
        dispatched with `TPUPartitionedCall` on the core chosen by
        `tpu_ordinal_selector`. Inside the subgraph, the rewritten output is
        reshaped to `[self.hparams.infer_batch_size, -1]`, cast to int64, and
        passed through `self.vocab_table.lookup`.

        Args:
          args: Inputs forwarded to `tpu.rewrite` for `model_fn`.

        Returns:
          The output tensors of `TPUPartitionedCall`.
        """
        @function.Defun(capture_resource_var_by_value=False)
        def tpu_subgraph():
            bound_model = functools.partial(model_fn, self.hparams)
            reshaped = tf.reshape(tpu.rewrite(bound_model, args),
                                  [self.hparams.infer_batch_size, -1])
            return self.vocab_table.lookup(tf.to_int64(reshaped))

        # Evaluated in the same order as the original keyword arguments:
        # captured inputs first, then the ordinal selector op, then dtypes.
        captured = tpu_subgraph.captured_inputs
        core_ordinal = tpu_ops.tpu_ordinal_selector()
        output_dtypes = [
            out_arg.type
            for out_arg in tpu_subgraph.definition.signature.output_arg
        ]
        return tpu_functional.TPUPartitionedCall(
            args=captured,
            device_ordinal=core_ordinal,
            Tout=output_dtypes,
            f=tpu_subgraph)
Ejemplo n.º 4
0
    def tpu_call(self, *args):
        """Runs `model_fn` on a TPU core for one (image, source_id, raw_shape).

        `args[0]` is unpacked into an image, a source id and a raw shape. The
        image is flattened to 1-D, and the triple is handed to `tpu.rewrite`
        as a single nested input. The rewritten computation is wrapped in a
        `Defun` and dispatched with `TPUPartitionedCall`.

        Args:
          *args: Only `args[0]` is used; it must unpack into exactly three
            values.

        Returns:
          The output tensors of `TPUPartitionedCall`.
        """
        image, source_id, raw_shape = args[0]
        rewrite_inputs = [[tf.reshape(image, [-1]), source_id, raw_shape]]

        @function.Defun(capture_resource_var_by_value=False)
        def tpu_subgraph():
            return tpu.rewrite(functools.partial(model_fn, self.params),
                               rewrite_inputs)

        # Same evaluation order as the original keyword arguments.
        captured = tpu_subgraph.captured_inputs
        core_ordinal = tpu_ops.tpu_ordinal_selector()
        output_dtypes = [
            out_arg.type
            for out_arg in tpu_subgraph.definition.signature.output_arg
        ]
        return tpu_functional.TPUPartitionedCall(
            args=captured,
            device_ordinal=core_ordinal,
            Tout=output_dtypes,
            f=tpu_subgraph)
Ejemplo n.º 5
0
def build_graph(pipeline_config,
                shapes_info,
                input_type='encoded_image_string_tensor',
                use_bfloat16=True):
    """Builds serving graph of faster_rcnn to be exported.

    Preprocessing runs on CPU, `detection_model.predict` runs on TPU through a
    `Defun`-wrapped `tf.contrib.tpu.rewrite` dispatched with
    `TPUPartitionedCall`, and postprocessing (NMS) runs on CPU.

    Args:
      pipeline_config: A TrainEvalPipelineConfig proto. It is passed through
        `modify_config` before the model is built.
      shapes_info: A python dict of tensors' names and their shapes, returned by
        `get_prediction_tensor_shapes()`. Must have an entry for every key of
        the prediction dict rebuilt from the TPU outputs.
      input_type: One of
                  'encoded_image_string_tensor': a 1d tensor with dtype=tf.string
                  'image_tensor': a 4d tensor with dtype=tf.uint8
                  'tf_example': a 1d tensor with dtype=tf.string
      use_bfloat16: If true, use tf.bfloat16 on TPU.

    Returns:
      placeholder_tensor: A placeholder tensor, type determined by `input_type`.
      result_tensor_dict: A python dict of tensors' names and tensors.
    """
    pipeline_config = modify_config(pipeline_config)
    # Build the detection model in inference mode.
    detection_model = INPUT_BUILDER_UTIL_MAP['model_build'](
        pipeline_config.model, is_training=False)

    placeholder_tensor, input_tensors = \
        exporter.input_placeholder_fn_map[input_type]()

    # CPU pre-processing
    inputs = tf.cast(input_tensors, dtype=tf.float32)
    preprocessed_inputs, true_image_shapes = detection_model.preprocess(inputs)

    # Dimshuffle: [b, h, w, c] -> [b, c, h, w]
    # Presumably done to reduce TPU padding by keeping the small channel
    # dimension out of the minor-most axis — TODO confirm.
    preprocessed_inputs = tf.transpose(preprocessed_inputs, perm=[0, 3, 1, 2])
    if use_bfloat16:
        preprocessed_inputs = tf.cast(preprocessed_inputs, dtype=tf.bfloat16)

    # TPU feature extraction
    def tpu_subgraph_predict_fn(preprocessed_inputs, true_image_shapes):
        """Runs detection_model.predict on TPU; returns outputs as a tuple.

        Some outputs are transposed into a layout for the TPU; the caller
        applies the inverse permutations after TPUPartitionedCall returns.
        """
        # [b, c, h, w] -> [b, h, w, c]
        preprocessed_inputs = tf.transpose(preprocessed_inputs,
                                           perm=[0, 2, 3, 1])

        prediction_dict = detection_model.predict(preprocessed_inputs,
                                                  true_image_shapes)

        return (
            # [batch, anchor, depth] -> [depth, batch, anchor]
            tf.transpose(prediction_dict[RPN_BOX_ENCODINGS], perm=[2, 0, 1]),
            # [batch, anchor, depth] -> [depth, batch, anchor]
            tf.transpose(
                prediction_dict[RPN_OBJECTNESS_PREDICTIONS_WITH_BACKGROUND],
                perm=[2, 0, 1]),
            # [anchors, depth]
            tf.transpose(prediction_dict[ANCHORS], perm=[1, 0]),
            # [num_proposals, num_classes, code_size]
            prediction_dict[REFINED_BOX_ENCODINGS],
            prediction_dict[CLASS_PREDICTIONS_WITH_BACKGROUND],
            prediction_dict[NUM_PROPOSALS],
            prediction_dict[PROPOSAL_BOXES])

    # The Defun turns the TPU computation into a graph function. The outer
    # tensors it closes over (preprocessed_inputs, true_image_shapes) become
    # its captured_inputs, which are passed below as TPUPartitionedCall's args.
    @function.Defun(capture_resource_var_by_value=False)
    def tpu_subgraph_predict():
        if use_bfloat16:
            with tf.contrib.tpu.bfloat16_scope():
                return tf.contrib.tpu.rewrite(
                    tpu_subgraph_predict_fn,
                    [preprocessed_inputs, true_image_shapes])
        else:
            return tf.contrib.tpu.rewrite(
                tpu_subgraph_predict_fn,
                [preprocessed_inputs, true_image_shapes])

    # Dispatch on the TPU core picked by tpu_ordinal_selector. Tout is read
    # from the function signature so output dtypes match the subgraph; the
    # unpacking order matches the tuple returned by tpu_subgraph_predict_fn.
    (rpn_box_encodings, rpn_objectness_predictions_with_background, anchors,
     refined_box_encodings, class_predictions_with_background, num_proposals,
     proposal_boxes) = tpu_functional.TPUPartitionedCall(
         args=tpu_subgraph_predict.captured_inputs,
         device_ordinal=tpu_ops.tpu_ordinal_selector(),
         Tout=[
             o.type
             for o in tpu_subgraph_predict.definition.signature.output_arg
         ],
         f=tpu_subgraph_predict)

    # Undo the TPU-side transposes ([2, 0, 1] is inverted by [1, 2, 0];
    # [1, 0] is its own inverse).
    prediction_dict = {
        RPN_BOX_ENCODINGS:
        tf.transpose(rpn_box_encodings, perm=[1, 2, 0]),
        RPN_OBJECTNESS_PREDICTIONS_WITH_BACKGROUND:
        tf.transpose(rpn_objectness_predictions_with_background,
                     perm=[1, 2, 0]),
        ANCHORS:
        tf.transpose(anchors, perm=[1, 0]),
        REFINED_BOX_ENCODINGS:
        refined_box_encodings,
        CLASS_PREDICTIONS_WITH_BACKGROUND:
        class_predictions_with_background,
        NUM_PROPOSALS:
        num_proposals,
        PROPOSAL_BOXES:
        proposal_boxes
    }

    # TPUPartitionedCall outputs are declared only by dtype (Tout), so
    # presumably their static shapes are unknown here; restore them.
    for k in prediction_dict:
        prediction_dict[k].set_shape(shapes_info[k])

    if use_bfloat16:
        # Presumably converts bfloat16 tensors back to float32 for CPU work.
        prediction_dict = utils.bfloat16_to_float32_nested(prediction_dict)

    # CPU post-processing (NMS)
    postprocessed_tensors = detection_model.postprocess(
        prediction_dict, true_image_shapes)
    result_tensor_dict = exporter.add_output_tensor_nodes(
        postprocessed_tensors, 'inference_op')

    return placeholder_tensor, result_tensor_dict
Ejemplo n.º 6
0
def build_graph(pipeline_config,
                shapes_info,
                input_type='encoded_image_string_tensor',
                use_bfloat16=True):
  """Builds serving graph of faster_rcnn to be exported.

  The graph alternates between CPU and TPU: CPU preprocessing, a first TPU
  subgraph (`_predict_first_stage`), CPU region proposal
  (`_proposal_postprocess`), a second TPU subgraph (`_box_prediction`), and CPU
  postprocessing. Each TPU subgraph is a `Defun`-wrapped
  `tf.contrib.tpu.rewrite` dispatched with `TPUPartitionedCall`.

  Args:
    pipeline_config: A TrainEvalPipelineConfig proto. It is passed through
      `modify_config` before the model is built.
    shapes_info: A python dict of tensors' names and their shapes, returned by
      `get_prediction_tensor_shapes()`. Must have a shape for every first-stage
      output key and every second-stage update key.
    input_type: One of
                'encoded_image_string_tensor': a 1d tensor with dtype=tf.string
                'image_tensor': a 4d tensor with dtype=tf.uint8
                'tf_example': a 1d tensor with dtype=tf.string
    use_bfloat16: If true, use tf.bfloat16 on TPU.

  Returns:
    placeholder_tensor: A placeholder tensor, type determined by `input_type`.
    result_tensor_dict: A python dict of tensors' names and tensors.
  """
  pipeline_config = modify_config(pipeline_config)
  detection_model = model_builder.build(
      pipeline_config.model, is_training=False)

  placeholder_tensor, input_tensors = \
      exporter.input_placeholder_fn_map[input_type]()

  # CPU pre-processing
  inputs = tf.cast(input_tensors, dtype=tf.float32)
  preprocessed_inputs, true_image_shapes = detection_model.preprocess(inputs)

  # Dimshuffle: [b, h, w, c] -> [b, c, h, w]
  preprocessed_inputs = tf.transpose(preprocessed_inputs, perm=[0, 3, 1, 2])
  if use_bfloat16:
    preprocessed_inputs = tf.cast(preprocessed_inputs, dtype=tf.bfloat16)

  # TPU feature extraction
  def tpu_subgraph_first_stage_fn(preprocessed_inputs):
    """Defines the first part of graph on TPU.

    Runs the model's (private) `_predict_first_stage` and returns its outputs
    as a tuple, several of them transposed into a layout for the TPU.
    """
    # [b, c, h, w] -> [b, h, w, c]
    preprocessed_inputs = tf.transpose(preprocessed_inputs, perm=[0, 2, 3, 1])

    # Relies on a private method of the detection model (leading underscore);
    # may break across object_detection versions.
    prediction_dict = detection_model._predict_first_stage(preprocessed_inputs)

    # [b, h, w, c] -> [b, c, h, w]
    rpn_box_predictor_features = tf.transpose(
        prediction_dict[RPN_BOX_PREDICTOR_FEATURES], perm=[0, 3, 1, 2])
    # [b, h, w, c] -> [b, c, h, w]
    rpn_features_to_crop = tf.transpose(
        prediction_dict[RPN_FEATURES_TO_CROP], perm=[0, 3, 1, 2])
    # [batch, anchor, depth] -> [depth, batch, anchor]
    rpn_box_encodings = tf.transpose(
        prediction_dict[RPN_BOX_ENCODINGS], perm=[2, 0, 1])
    # [batch, anchor, depth] -> [depth, batch, anchor]
    rpn_objectness_predictions_with_background = tf.transpose(
        prediction_dict[RPN_OBJECTNESS_PREDICTIONS_WITH_BACKGROUND],
        perm=[2, 0, 1])
    # [anchors, depth]
    anchors = tf.transpose(prediction_dict[ANCHORS], perm=[1, 0])

    # NOTE(review): reads the literal key 'image_shape' here while the CPU side
    # uses the IMAGE_SHAPE constant; presumably they are equal — confirm.
    return (rpn_box_predictor_features, rpn_features_to_crop,
            prediction_dict['image_shape'], rpn_box_encodings,
            rpn_objectness_predictions_with_background, anchors)

  # The Defun's closed-over tensor (preprocessed_inputs) becomes its
  # captured_inputs, passed as TPUPartitionedCall's args.
  @function.Defun(capture_resource_var_by_value=False)
  def tpu_subgraph_first_stage():
    if use_bfloat16:
      with tf.contrib.tpu.bfloat16_scope():
        return tf.contrib.tpu.rewrite(tpu_subgraph_first_stage_fn,
                                      [preprocessed_inputs])
    else:
      return tf.contrib.tpu.rewrite(tpu_subgraph_first_stage_fn,
                                    [preprocessed_inputs])

  # Unpacking order matches the tuple returned by tpu_subgraph_first_stage_fn.
  (rpn_box_predictor_features, rpn_features_to_crop, image_shape,
   rpn_box_encodings, rpn_objectness_predictions_with_background,
   anchors) = \
      tpu_functional.TPUPartitionedCall(
          args=tpu_subgraph_first_stage.captured_inputs,
          device_ordinal=tpu_ops.tpu_ordinal_selector(),
          Tout=[
              o.type
              for o in tpu_subgraph_first_stage.definition.signature.output_arg
          ],
          f=tpu_subgraph_first_stage)

  # Undo the first-stage TPU transposes.
  prediction_dict = {
      RPN_BOX_PREDICTOR_FEATURES:
          tf.transpose(rpn_box_predictor_features, perm=[0, 2, 3, 1]),
      RPN_FEATURES_TO_CROP:
          tf.transpose(rpn_features_to_crop, perm=[0, 2, 3, 1]),
      IMAGE_SHAPE:
          image_shape,
      RPN_BOX_ENCODINGS:
          tf.transpose(rpn_box_encodings, perm=[1, 2, 0]),
      RPN_OBJECTNESS_PREDICTIONS_WITH_BACKGROUND:
          tf.transpose(
              rpn_objectness_predictions_with_background, perm=[1, 2, 0]),
      ANCHORS:
          tf.transpose(anchors, perm=[1, 0]),
  }

  # TPUPartitionedCall outputs are declared only by dtype (Tout), so
  # presumably their static shapes are unknown here; restore them.
  for k in prediction_dict:
    prediction_dict[k].set_shape(shapes_info[k])

  if use_bfloat16:
    prediction_dict = utils.bfloat16_to_float32_nested(prediction_dict)

  # CPU region proposal (NMS)
  # Uses the model's private `_proposal_postprocess`.
  proposal_boxes_normalized, num_proposals = \
      detection_model._proposal_postprocess(
          tf.cast(prediction_dict[RPN_BOX_ENCODINGS], dtype=tf.float32),
          tf.cast(
              prediction_dict[RPN_OBJECTNESS_PREDICTIONS_WITH_BACKGROUND],
              dtype=tf.float32), prediction_dict[ANCHORS],
          prediction_dict[IMAGE_SHAPE], true_image_shapes)
  prediction_dict[NUM_PROPOSALS] = num_proposals

  # [b, h, w, c] -> [b, c, h, w]
  # Temporarily stored in TPU layout for the second-stage call; transposed
  # back to [b, h, w, c] after TPUPartitionedCall returns.
  prediction_dict[RPN_FEATURES_TO_CROP] = tf.transpose(
      prediction_dict[RPN_FEATURES_TO_CROP], perm=[0, 3, 1, 2])

  if use_bfloat16:
    # Second-stage TPU inputs go back to bfloat16.
    prediction_dict[RPN_FEATURES_TO_CROP] = tf.cast(
        prediction_dict[RPN_FEATURES_TO_CROP], dtype=tf.bfloat16)
    proposal_boxes_normalized = tf.cast(
        proposal_boxes_normalized, dtype=tf.bfloat16)

  # TPU box prediction
  def tpu_subgraph_second_stage_fn(rpn_features_to_crop,
                                   proposal_boxes_normalized, image_shape):
    """Defines the second part of graph on TPU.

    Runs the model's (private) `_box_prediction` on the proposals and returns
    refined box encodings, class predictions, proposal boxes and box
    classifier features, in that order.
    """
    rpn_features_to_crop = tf.transpose(rpn_features_to_crop, perm=[0, 2, 3, 1])

    output_dict = detection_model._box_prediction(
        rpn_features_to_crop, proposal_boxes_normalized, image_shape)

    return [
        output_dict[REFINED_BOX_ENCODINGS],
        output_dict[CLASS_PREDICTIONS_WITH_BACKGROUND],
        output_dict[PROPOSAL_BOXES], output_dict[BOX_CLASSIFIER_FEATURES]
    ]

  # The three tensors read from the enclosing scope become this Defun's
  # captured_inputs, in the order they are first referenced.
  @function.Defun(capture_resource_var_by_value=False)
  def tpu_subgraph_second_stage():
    """TPU subgraph 2 wrapper."""
    if use_bfloat16:
      with tf.contrib.tpu.bfloat16_scope():
        return tf.contrib.tpu.rewrite(tpu_subgraph_second_stage_fn, [
            prediction_dict[RPN_FEATURES_TO_CROP],
            proposal_boxes_normalized,
            prediction_dict[IMAGE_SHAPE],
        ])
    else:
      return tf.contrib.tpu.rewrite(tpu_subgraph_second_stage_fn, [
          prediction_dict[RPN_FEATURES_TO_CROP],
          proposal_boxes_normalized,
          prediction_dict[IMAGE_SHAPE],
      ])

  (refined_box_encodings, class_predictions_with_background, proposal_boxes,
   box_classifier_features) = tpu_functional.TPUPartitionedCall(
       args=tpu_subgraph_second_stage.captured_inputs,
       device_ordinal=tpu_ops.tpu_ordinal_selector(),
       Tout=[
           o.type
           for o in tpu_subgraph_second_stage.definition.signature.output_arg
       ],
       f=tpu_subgraph_second_stage)

  # Restore RPN_FEATURES_TO_CROP to [b, h, w, c].
  prediction_dict[RPN_FEATURES_TO_CROP] = tf.transpose(
      prediction_dict[RPN_FEATURES_TO_CROP], perm=[0, 2, 3, 1])

  prediction_dict_updater = {
      REFINED_BOX_ENCODINGS: refined_box_encodings,
      CLASS_PREDICTIONS_WITH_BACKGROUND: class_predictions_with_background,
      PROPOSAL_BOXES: proposal_boxes,
      BOX_CLASSIFIER_FEATURES: box_classifier_features,
      PROPOSAL_BOXES_NORMALIZED: proposal_boxes_normalized,
  }

  for k in prediction_dict_updater:
    prediction_dict_updater[k].set_shape(shapes_info[k])

  prediction_dict.update(prediction_dict_updater)

  if use_bfloat16:
    prediction_dict = utils.bfloat16_to_float32_nested(prediction_dict)

  # CPU post-processing (NMS)
  postprocessed_tensors = detection_model.postprocess(prediction_dict,
                                                      true_image_shapes)
  result_tensor_dict = exporter.add_output_tensor_nodes(postprocessed_tensors,
                                                        'inference_op')

  return placeholder_tensor, result_tensor_dict
Ejemplo n.º 7
0
def build_graph(pipeline_config,
                shapes_info,
                input_type='encoded_image_string_tensor',
                use_bfloat16=False):
    """Builds TPU serving graph of ssd to be exported.

    Preprocessing and postprocessing run on CPU. `detection_model.predict`
    runs on TPU through a `Defun`-wrapped `tf.contrib.tpu.rewrite` dispatched
    with `TPUPartitionedCall`.

    Args:
      pipeline_config: A TrainEvalPipelineConfig proto.
      shapes_info: A python dict of tensors' names and their shapes, returned by
        `get_prediction_tensor_shapes()`. Forwarded to `recover_shape`.
      input_type: One of
                  'encoded_image_string_tensor': a 1d tensor with dtype=tf.string
                  'image_tensor': a 4d tensor with dtype=tf.uint8
                  'tf_example': a 1d tensor with dtype=tf.string
      use_bfloat16: If true, use tf.bfloat16 on TPU.

    Returns:
      placeholder_tensor: A placeholder tensor, type determined by `input_type`.
      result_tensor_dict: A python dict of tensors' names and tensors.
    """

    detection_model = model_builder.build(pipeline_config.model,
                                          is_training=False)

    placeholder_tensor, input_tensors = \
        exporter.input_placeholder_fn_map[input_type]()

    inputs = tf.cast(input_tensors, dtype=tf.float32)
    preprocessed_inputs, true_image_shapes = detection_model.preprocess(inputs)

    # Dimshuffle: (b, h, w, c) -> (b, c, h, w)
    # This is to avoid extra padding due to TPU memory layout:
    # We swap larger dimensions in and smaller dimensions out, so that small
    # dimensions don't get padded tens / hundreds times of its own size.
    # This trick is applied to other similar tensors below.
    preprocessed_inputs = tf.transpose(preprocessed_inputs, perm=[0, 3, 1, 2])
    if use_bfloat16:
        preprocessed_inputs = tf.cast(preprocessed_inputs, dtype=tf.bfloat16)

    def predict_tpu_subgraph(preprocessed_inputs, true_image_shapes):
        """Wraps over the CPU version of `predict()`.

        This builds a same graph as the original `predict()`, manipulates
        result tensors' dimensions to be memory efficient on TPU, and
        returns them as list of tensors.

        Args:
          preprocessed_inputs: A 4D tensor of shape (batch, channels, height, width)
          true_image_shapes: True image shapes tensor.

        Returns:
          A Python list of tensors:
            box_encodings: 3D tensor of shape (code_size, batch_size, num_anchors)
            class_predictions_with_background: 3D tensor,
                shape (num_classes + 1, batch_size, num_anchors)
            anchors: 2D tensor of shape (4, num_anchors)
        """
        # Dimshuffle: (b, c, h, w) -> (b, h, w, c)
        preprocessed_inputs = tf.transpose(preprocessed_inputs,
                                           perm=[0, 2, 3, 1])
        if use_bfloat16:
            with tf.contrib.tpu.bfloat16_scope():
                prediction_dict = detection_model.predict(
                    preprocessed_inputs, true_image_shapes)
        else:
            prediction_dict = detection_model.predict(preprocessed_inputs,
                                                      true_image_shapes)

        # Dimshuffle: (batch, anchors, depth) -> (depth, batch, anchors)
        return [
            tf.transpose(prediction_dict[BOX_ENCODINGS], perm=[2, 0, 1]),
            tf.transpose(prediction_dict[CLASS_PREDICTIONS_WITH_BACKGROUND],
                         perm=[2, 0, 1]),
            tf.transpose(prediction_dict[ANCHORS], perm=[1, 0]),
        ]

    # The Defun's closed-over tensors (preprocessed_inputs, true_image_shapes)
    # become its captured_inputs, passed as TPUPartitionedCall's args.
    @function.Defun(capture_resource_var_by_value=False)
    def predict_tpu():
        return tf.contrib.tpu.rewrite(predict_tpu_subgraph,
                                      [preprocessed_inputs, true_image_shapes])

    # Tout is read from the function signature so output dtypes match the
    # list returned by predict_tpu_subgraph.
    prediction_outputs = tpu_functional.TPUPartitionedCall(
        args=predict_tpu.captured_inputs,
        device_ordinal=tpu_ops.tpu_ordinal_selector(),
        Tout=[o.type for o in predict_tpu.definition.signature.output_arg],
        f=predict_tpu)

    # recover_shape presumably undoes the TPU-layout transposes and applies
    # the static shapes from shapes_info (including for preprocessed_inputs)
    # — TODO confirm against its definition.
    (preprocessed_inputs, box_encodings, class_predictions_with_background,
     anchors) = recover_shape(preprocessed_inputs, prediction_outputs,
                              shapes_info)

    output_tensors = {
        'preprocessed_inputs': preprocessed_inputs,
        BOX_ENCODINGS: box_encodings,
        CLASS_PREDICTIONS_WITH_BACKGROUND: class_predictions_with_background,
        ANCHORS: anchors,
    }

    if use_bfloat16:
        # Presumably converts bfloat16 tensors back to float32 for CPU work.
        output_tensors = utils.bfloat16_to_float32_nested(output_tensors)

    postprocessed_tensors = detection_model.postprocess(
        output_tensors, true_image_shapes)
    result_tensor_dict = exporter.add_output_tensor_nodes(
        postprocessed_tensors, 'inference_op')

    return placeholder_tensor, result_tensor_dict
Ejemplo n.º 8
0
    def tpu_call(self, *args):
        """Builds a TPU partitioned call that classifies packed images with ResNet.

        Inside `tf.tpu.rewrite`, the packed image input is reshaped, unpacked
        with `packing_utils.unpack`, transposed to NHWC, cast, and mean-shifted.
        It is then fed through a ResNet-v1-50. The rewrite is wrapped in a
        `Defun` and dispatched with `TPUPartitionedCall` on the core chosen by
        `tpu_ordinal_selector`.

        Args:
          *args: Forwarded unchanged as the inputs of `tf.tpu.rewrite`;
            presumably a single packed image tensor — TODO confirm.

        Returns:
          The output tensors of `TPUPartitionedCall` (per-image
          `argmax(logits, axis=1) - 1`).
        """
        def model_fn(images):
            """model_fn for Resnet: unpack, normalize, classify."""
            def build_network(net_input):
                """Builds the ResNet network architecture and returns logits."""
                resnet = resnet_model.resnet_v1(
                    resnet_depth=50,
                    num_classes=1000,
                    data_format='channels_last',
                    conv0_kernel_size=7,
                    conv0_space_to_depth_block_size=(
                        self.conv0_space_to_depth_block_size))
                return resnet(inputs=net_input, is_training=False)

            # The following shapes are w.r.t. byte packing; see packing_utils.py.
            if self.conv0_space_to_depth_block_size != 0:
                side, depth = 112, 3
            else:
                side, depth = 224, 1
            if self.tpu_transpose:
                packed_shape = [side, side, depth, -1]
                packed_format = 'HWCN'
                to_nhwc = [3, 0, 1, 2]  # [H, W, C, N] -> [N, H, W, C]
            else:
                packed_shape = [-1, depth, side, side]
                packed_format = 'NCHW'
                to_nhwc = [0, 2, 3, 1]  # [N, C, H, W] -> [N, H, W, C]

            images = tf.reshape(images, packed_shape)
            images = packing_utils.unpack(
                images,
                self.conv0_space_to_depth_block_size,
                image_format=packed_format)
            images = tf.transpose(images, to_nhwc)

            compute_dtype = tf.bfloat16 if self.use_bfloat16 else tf.float32
            images = tf.cast(images, compute_dtype)

            # Subtract the RGB mean over all channels. Assumes dataset.MEAN_RGB
            # is a Python list, so `*` repeats it (c // 3) times.
            _, _, _, num_channels = images.get_shape().as_list()
            mean_offset = tf.constant(dataset.MEAN_RGB * (num_channels // 3),
                                      shape=[1, 1, 1, num_channels],
                                      dtype=images.dtype)
            images = images - mean_offset

            if self.use_bfloat16:
                with tf.contrib.tpu.bfloat16_scope():
                    logits = build_network(images)
                logits = tf.cast(logits, tf.float32)
            else:
                logits = build_network(images)
            # NOTE(review): with num_classes=1000, class 0 maps to -1 here;
            # the -1 presumably matches an external 1001-label convention —
            # confirm.
            return tf.argmax(logits, axis=1) - 1

        @function.Defun(capture_resource_var_by_value=False)
        def tpu_subgraph():
            return tf.tpu.rewrite(model_fn, args)

        # Same evaluation order as the original keyword arguments.
        captured = tpu_subgraph.captured_inputs
        core_ordinal = tpu_ops.tpu_ordinal_selector()
        output_dtypes = [
            out_arg.type
            for out_arg in tpu_subgraph.definition.signature.output_arg
        ]
        return tpu_functional.TPUPartitionedCall(
            args=captured,
            device_ordinal=core_ordinal,
            Tout=output_dtypes,
            f=tpu_subgraph)