Exemplo n.º 1
0
    def _specialize_model(self, input_specs, infeed_manager):
        """Build the TPU ops that run `self.model` for the given input specs.

        Args:
          input_specs: Specs (each exposing `dtype` and `shape`) of the tensors
            that are delivered through the infeed.
          infeed_manager: Object whose `build_infeed_from_input_specs` method
            creates the infeed for `input_specs` and the current execution
            mode.

        Returns:
          A `TPUModelOp` bundling the compile op, the execute op, the sharded
          infeed tensors, the infeed ops and the outfeed dequeue ops.
        """
        # The learning phase is on only when we are building the training graph.
        K.set_learning_phase(
            self.execution_mode == model_fn_lib.ModeKeys.TRAIN)

        # This has to be a plain function: tpu.rewrite does not support
        # functools.partial or callable objects.
        def _model_fn():
            """Build the per-replica fit/eval/predict computation."""
            mode = self.execution_mode
            is_training = mode == model_fn_lib.ModeKeys.TRAIN
            is_test = mode == model_fn_lib.ModeKeys.EVAL
            is_predict = mode == model_fn_lib.ModeKeys.PREDICT
            uses_labels = is_training or is_test

            input_layers = self.model._input_layers
            output_layers = self.model._output_layers

            # Train/eval infeed carries features followed by labels; predict
            # carries features only.
            infeed_layers = (
                input_layers + output_layers if uses_labels else input_layers)

            # Single dequeue op that reads every infeed tensor.
            infeed_tensors = tpu_ops.infeed_dequeue_tuple(
                dtypes=[spec.dtype for spec in input_specs],
                shapes=[spec.shape for spec in input_specs],
                name='infeed-%s' % mode)

            assert len(infeed_tensors) == len(infeed_layers), (
                'Infeed inputs did not match model: %s vs %s' %
                (infeed_layers, infeed_tensors))

            # Split the dequeued tensors into model inputs (keyed by layer
            # name) and training targets.
            paired = list(zip(infeed_tensors, infeed_layers))
            tpu_input_map = {
                layer.name: tensor
                for tensor, layer in paired
                if layer in input_layers
            }
            tpu_targets = [
                tensor for tensor, layer in paired if layer in output_layers
            ]

            # Clone the CPU model inside the TPU rewrite/device context.
            # TODO(power): Replicate variables.
            with TPURewriteContext(tpu_input_map), \
                    ops.device('/device:TPU:0'):
                self._cloned_model = models.clone_model(self.model)

            # Each graph gets its own optimizer instance.
            source_optimizer = self.model.optimizer
            if isinstance(source_optimizer, keras_optimizers.TFOptimizer):
                cloned_optimizer = keras_optimizers.TFOptimizer(
                    source_optimizer.optimizer)
            else:
                optimizer_cls = source_optimizer.__class__
                logging.info('Cloning %s %s', optimizer_cls.__name__,
                             self._optimizer_config)
                cloned_optimizer = optimizer_cls.from_config(
                    self._optimizer_config)

            if uses_labels:
                self._cloned_model.compile(
                    optimizer=_replicated_optimizer(cloned_optimizer),
                    loss=self.model.loss,
                    loss_weights=self.model.loss_weights,
                    metrics=self.model.metrics,
                    weighted_metrics=self.model.weighted_metrics,
                    target_tensors=tpu_targets,
                )

            def _publish(keras_function, enqueue_name):
                """Record the outfeed spec and enqueue the function outputs."""
                self._outfeed_spec = [
                    tensor_spec.TensorSpec(out.shape, out.dtype, out.name)
                    for out in keras_function.outputs
                ]
                return tpu_ops.outfeed_enqueue_tuple(
                    keras_function.outputs, name=enqueue_name)

            # The outfeed contents depend on the execution mode.
            if is_training:
                self._cloned_model._make_train_function()
                train_fn = self._cloned_model.train_function
                enqueue = _publish(train_fn, 'outfeed-enqueue-train')
                return [train_fn.updates_op, enqueue]
            if is_test:
                self._cloned_model._make_test_function()
                return [
                    _publish(self._cloned_model.test_function,
                             'outfeed-enqueue-test')
                ]
            if is_predict:
                self._cloned_model._make_predict_function()
                return [
                    _publish(self._cloned_model.predict_function,
                             'outfeed-enqueue-predict')
                ]
            assert False, 'Unexpected execution mode: %s' % self.execution_mode

        # Cleared here; `_model_fn` fills it in while the rewrite runs.
        self._outfeed_spec = None

        # `tpu.split_compile_and_replicate` yields two ops: `compile_op` lets
        # us check that the model compiles before running it, and `execute_op`
        # runs `_model_fn` once per tower.
        compile_op, execute_op = tpu.split_compile_and_replicate(
            _model_fn, inputs=[[]] * self._strategy.num_towers)

        # Infeed for the features/labels fed into the model call.
        sized_infeed = infeed_manager.build_infeed_from_input_specs(
            input_specs, self.execution_mode)

        # One outfeed dequeue per shard, created under the CPU device scope.
        outfeed_op = []
        for shard_id in range(self._strategy.num_towers):
            with ops.device('/device:CPU:0'):
                dequeued = tpu_ops.outfeed_dequeue_tuple(
                    dtypes=[spec.dtype for spec in self._outfeed_spec],
                    shapes=[spec.shape for spec in self._outfeed_spec],
                    name='outfeed-dequeue-%s-%d' %
                    (self.execution_mode, shard_id),
                    device_ordinal=shard_id)
                outfeed_op.extend(dequeued)

        return TPUModelOp(compile_op,
                          execute_op,
                          infeed_tensors=sized_infeed.sharded_infeed_tensors,
                          infeed_op=sized_infeed.infeed_ops,
                          outfeed_op=outfeed_op)
Exemplo n.º 2
0
    def _specialize_model(self, input_specs):
        """Specialize `self.model` (a Keras model) for the given input shapes.

        Args:
          input_specs: Specs (each exposing `dtype`, `shape` and `name`) of
            the tensors that are fed to the model through the infeed.

        Returns:
          A `TPUModelOp` holding the compile op, the execute op, the per-shard
          infeed placeholders, the per-shard infeed enqueue ops and the
          outfeed dequeue ops.
        """
        # Re-create our input and output layers inside our subgraph.  They will be
        # attached to the true computation when we clone our model in `tpu_fn`.
        K.set_learning_phase(
            self.execution_mode == model_fn_lib.ModeKeys.TRAIN)

        # functools.partial and callable objects are not supported by tpu.rewrite
        def _model_fn():
            """Compute fit/eval/predict for the TPU."""
            is_training = self.execution_mode == model_fn_lib.ModeKeys.TRAIN
            is_test = self.execution_mode == model_fn_lib.ModeKeys.EVAL
            is_predict = self.execution_mode == model_fn_lib.ModeKeys.PREDICT

            # During train/eval, we infeed our features as well as labels.
            if is_training or is_test:
                infeed_layers = self.model._input_layers + self.model._output_layers
            else:
                infeed_layers = self.model._input_layers

            # Generate our infeed operation to read features & labels.
            infeed_tensors = tpu_ops.infeed_dequeue_tuple(
                dtypes=[spec.dtype for spec in input_specs],
                shapes=[spec.shape for spec in input_specs],
                name='infeed-%s' % self.execution_mode)

            # The message must be built with `%`; a `,` here would turn it
            # into an unformatted (format-string, args) tuple.
            assert len(infeed_tensors) == len(infeed_layers), (
                'Infeed inputs did not match model: %s vs %s' %
                (infeed_layers, infeed_tensors))

            tpu_targets = []
            tpu_inputs = []

            # Sort infeed outputs into inputs and labels for calling our Keras model.
            for tensor, layer in zip(infeed_tensors, infeed_layers):
                if layer in self.model._input_layers:
                    tpu_inputs.append(
                        layers.Input(name=layer.name, tensor=tensor))
                if layer in self.model._output_layers:
                    tpu_targets.append(tensor)

            # Call our model with our infeed inputs (re-using the weights).
            model_outputs = self.model(tpu_inputs)
            child_model = models.Model(inputs=tpu_inputs,
                                       outputs=model_outputs)

            if is_training or is_test:
                child_model.compile(
                    optimizer=_replicated_optimizer(self.model.optimizer,
                                                    self.num_replicas),
                    loss=self.model.loss,
                    loss_weights=self.model.loss_weights,
                    metrics=self.model.metrics,
                    weighted_metrics=self.model.weighted_metrics,
                    target_tensors=tpu_targets,
                )

            # Compute our outfeed depending on the execution mode.  Each branch
            # also records the outfeed spec so the dequeue ops built later
            # can use matching dtypes and shapes.
            if is_training:
                child_model._make_train_function()
                self._outfeed_spec = [
                    tensor_spec.TensorSpec(tensor.shape, tensor.dtype,
                                           tensor.name)
                    for tensor in child_model.train_function.outputs
                ]
                return [
                    child_model.train_function.updates_op,
                    tpu_ops.outfeed_enqueue_tuple(
                        child_model.train_function.outputs,
                        name='outfeed-enqueue-train')
                ]
            elif is_test:
                child_model._make_test_function()
                self._outfeed_spec = [
                    tensor_spec.TensorSpec(tensor.shape, tensor.dtype,
                                           tensor.name)
                    for tensor in child_model.test_function.outputs
                ]
                return [
                    tpu_ops.outfeed_enqueue_tuple(
                        child_model.test_function.outputs,
                        name='outfeed-enqueue-test')
                ]
            elif is_predict:
                child_model._make_predict_function()
                self._outfeed_spec = [
                    tensor_spec.TensorSpec(tensor.shape, tensor.dtype,
                                           tensor.name)
                    for tensor in child_model.predict_function.outputs
                ]
                return [
                    tpu_ops.outfeed_enqueue_tuple(
                        child_model.predict_function.outputs,
                        name='outfeed-enqueue-predict',
                    )
                ]
            else:
                assert False, 'Unexpected execution mode: %s' % self.execution_mode

        # Capture outfeed metadata computed during the rewrite.
        self._outfeed_spec = None

        # Generate our TPU operations using `tpu.split_compile_and_replicate`.
        # `compile_op` can be used to test the TPU model compiles before execution.
        # `execute_op` replicates `_model_fn` `num_replicas` times, with each shard
        # running on a different logical core.
        compile_op, execute_op = tpu.split_compile_and_replicate(
            _model_fn, inputs=[[]] * self.num_replicas)

        # Generate CPU side operations to enqueue features/labels and dequeue
        # outputs from the model call.
        infeed_op = []
        outfeed_op = []
        shard_infeed_tensors = []

        for shard_id in range(self.num_replicas):
            with ops.device('/device:TPU:%d' % shard_id):
                # One placeholder per input spec, for this shard.
                infeed_tensors = []
                for spec in input_specs:
                    infeed_tensors.append(
                        array_ops.placeholder(dtype=spec.dtype,
                                              shape=spec.shape,
                                              name='infeed-enqueue-%s-%d' %
                                              (spec.name, shard_id)))
                shard_infeed_tensors.append(infeed_tensors)

                infeed_op.append(
                    tpu_ops.infeed_enqueue_tuple(
                        infeed_tensors, [spec.shape for spec in input_specs],
                        name='infeed-enqueue-%s-%d' %
                        (self.execution_mode, shard_id)))

                # `self._outfeed_spec` was populated by `_model_fn` above.
                outfeed_op.extend(
                    tpu_ops.outfeed_dequeue_tuple(
                        dtypes=[spec.dtype for spec in self._outfeed_spec],
                        shapes=[spec.shape for spec in self._outfeed_spec],
                        name='outfeed-dequeue-%s-%d' %
                        (self.execution_mode, shard_id)))

        return TPUModelOp(compile_op,
                          execute_op,
                          infeed_tensors=shard_infeed_tensors,
                          infeed_op=infeed_op,
                          outfeed_op=outfeed_op)
Exemplo n.º 3
0
    def _specialize_model(self, input_specs):
        """Specialize `self.model` (a Keras model) for the given input shapes.

        Args:
          input_specs: Specs (each exposing `dtype`, `shape` and `name`) of
            the tensors that are fed to the model through the infeed.

        Returns:
          A `CompiledTPUOp` built from the rewritten TPU execute op, the
          infeed placeholders, the infeed enqueue op and the outfeed dequeue
          op.
        """
        # Re-create our input and output layers inside our subgraph.  They will be
        # attached to the true computation when we clone our model in `tpu_fn`.
        K.set_learning_phase(
            self.execution_mode == model_fn_lib.ModeKeys.TRAIN)

        # functools.partial and callable objects are not supported by tpu.rewrite
        def _model_fn():
            """Compute fit/eval/predict for the TPU."""
            is_training = self.execution_mode == model_fn_lib.ModeKeys.TRAIN
            is_test = self.execution_mode == model_fn_lib.ModeKeys.EVAL
            is_predict = self.execution_mode == model_fn_lib.ModeKeys.PREDICT

            # During train/eval, we infeed our features as well as labels.
            if is_training or is_test:
                infeed_layers = self.model._input_layers + self.model._output_layers
            else:
                infeed_layers = self.model._input_layers

            # Generate our infeed operation to read features & labels.
            infeed_tensors = tpu_ops.infeed_dequeue_tuple(
                dtypes=[spec.dtype for spec in input_specs],
                shapes=[spec.shape for spec in input_specs],
                name='infeed-%s' % self.execution_mode)

            # The message must be built with `%`; a `,` here would turn it
            # into an unformatted (format-string, args) tuple.
            assert len(infeed_tensors) == len(infeed_layers), (
                'Infeed inputs did not match model: %s vs %s' %
                (infeed_layers, infeed_tensors))

            tpu_targets = []
            tpu_inputs = []

            # Sort infeed outputs into inputs and labels for calling our Keras model.
            for tensor, layer in zip(infeed_tensors, infeed_layers):
                if layer in self.model._input_layers:
                    tpu_inputs.append(
                        layers.Input(name=layer.name, tensor=tensor))
                if layer in self.model._output_layers:
                    tpu_targets.append(tensor)

            # Point the optimizer's iteration counter at the global step.
            optimizer = self.model.optimizer
            optimizer.iterations = training_util.get_or_create_global_step()

            # Call our model with our infeed inputs (re-using the weights).
            model_outputs = self.model(tpu_inputs)
            child_model = models.Model(inputs=tpu_inputs,
                                       outputs=model_outputs)
            if is_training or is_test:
                child_model.compile(
                    optimizer=self.model.optimizer,
                    loss=self.model.loss,
                    loss_weights=self.model.loss_weights,
                    metrics=self.model.metrics,
                    weighted_metrics=self.model.weighted_metrics,
                    target_tensors=tpu_targets,
                )

            # Compute our outfeed depending on the execution mode.  Each branch
            # also records the outfeed spec so the dequeue op built later can
            # use matching dtypes and shapes.
            if is_training:
                child_model._make_train_function()
                self._outfeed_spec = [
                    tensor_spec.TensorSpec(tensor.shape, tensor.dtype,
                                           tensor.name)
                    for tensor in child_model.train_function.outputs
                ]
                return [
                    child_model.train_function.updates_op,
                    tpu_ops.outfeed_enqueue_tuple(
                        child_model.train_function.outputs,
                        name='outfeed-enqueue-train')
                ]
            elif is_test:
                child_model._make_test_function()
                self._outfeed_spec = [
                    tensor_spec.TensorSpec(tensor.shape, tensor.dtype,
                                           tensor.name)
                    for tensor in child_model.test_function.outputs
                ]
                return [
                    tpu_ops.outfeed_enqueue_tuple(
                        child_model.test_function.outputs,
                        name='outfeed-enqueue-test')
                ]
            elif is_predict:
                child_model._make_predict_function()
                self._outfeed_spec = [
                    tensor_spec.TensorSpec(tensor.shape, tensor.dtype,
                                           tensor.name)
                    for tensor in child_model.predict_function.outputs
                ]
                return [
                    tpu_ops.outfeed_enqueue_tuple(
                        child_model.predict_function.outputs,
                        name='outfeed-enqueue-predict',
                    )
                ]
            else:
                assert False, 'Unexpected execution mode: %s' % self.execution_mode

        # Capture outfeed metadata computed during the rewrite.
        self._outfeed_spec = None

        tpu_execute_op = tpu.rewrite(_model_fn)

        K._initialize_variables(
            K.get_session())  # pylint: disable=protected-access

        # Generate CPU side operations to enqueue features/labels and dequeue
        # outputs from the model call.
        with ops.device('/device:TPU:0'):
            # One placeholder per input spec.
            infeed_tensors = []
            for spec in input_specs:
                infeed_tensors.append(
                    array_ops.placeholder(dtype=spec.dtype,
                                          shape=spec.shape,
                                          name='infeed-enqueue-%s' %
                                          spec.name))

            infeed_op = tpu_ops.infeed_enqueue_tuple(
                infeed_tensors, [spec.shape for spec in input_specs],
                name='infeed-enqueue-%s' % self.execution_mode)

            # `self._outfeed_spec` was populated by `_model_fn` above.
            outfeed_op = tpu_ops.outfeed_dequeue_tuple(
                dtypes=[spec.dtype for spec in self._outfeed_spec],
                shapes=[spec.shape for spec in self._outfeed_spec],
                name='outfeed-dequeue-%s' % self.execution_mode)

        return CompiledTPUOp(tpu_execute_op, infeed_tensors, infeed_op,
                             outfeed_op)
Exemplo n.º 4
0
  def _specialize_model(self, input_specs):
    """Specialize `self.model` (a Keras model) for the given input shapes.

    Args:
      input_specs: Specs (each exposing `dtype`, `shape` and `name`) of the
        tensors that are fed to the model through the infeed.

    Returns:
      A `TPUModelOp` holding the compile op, the execute op, the per-shard
      infeed placeholders, the per-shard infeed enqueue ops and the outfeed
      dequeue ops.
    """
    # Re-create our input and output layers inside our subgraph.  They will be
    # attached to the true computation when we clone our model in `tpu_fn`.
    K.set_learning_phase(self.execution_mode == model_fn_lib.ModeKeys.TRAIN)

    # functools.partial and callable objects are not supported by tpu.rewrite
    def _model_fn():
      """Compute fit/eval/predict for the TPU."""
      is_training = self.execution_mode == model_fn_lib.ModeKeys.TRAIN
      is_test = self.execution_mode == model_fn_lib.ModeKeys.EVAL
      is_predict = self.execution_mode == model_fn_lib.ModeKeys.PREDICT

      # During train/eval, we infeed our features as well as labels.
      if is_training or is_test:
        infeed_layers = self.model._input_layers + self.model._output_layers
      else:
        infeed_layers = self.model._input_layers

      # Generate our infeed operation to read features & labels.
      infeed_tensors = tpu_ops.infeed_dequeue_tuple(
          dtypes=[spec.dtype for spec in input_specs],
          shapes=[spec.shape for spec in input_specs],
          name='infeed-%s' % self.execution_mode)

      # The message must be built with `%`; a `,` here would turn it into an
      # unformatted (format-string, args) tuple.
      assert len(infeed_tensors) == len(infeed_layers), (
          'Infeed inputs did not match model: %s vs %s' % (infeed_layers,
                                                           infeed_tensors))

      tpu_targets = []
      tpu_input_map = {}

      # Sort infeed outputs into inputs and labels for calling our Keras model.
      for tensor, layer in zip(infeed_tensors, infeed_layers):
        if layer in self.model._input_layers:
          tpu_input_map[layer.name] = tensor
        if layer in self.model._output_layers:
          tpu_targets.append(tensor)

      # Clone our CPU model, running within the TPU device context.
      with TPURewriteContext(tpu_input_map):
        self._cloned_model = models.clone_model(self.model)

      # Create a copy of the optimizer for this graph.
      if isinstance(self.model.optimizer, keras_optimizers.TFOptimizer):
        cloned_optimizer = keras_optimizers.TFOptimizer(
            self.model.optimizer.optimizer)
      else:
        logging.info('Cloning %s %s', self.model.optimizer.__class__.__name__,
                     self._optimizer_config)
        cloned_optimizer = self.model.optimizer.__class__.from_config(
            self._optimizer_config)

      if is_training or is_test:
        self._cloned_model.compile(
            optimizer=_replicated_optimizer(cloned_optimizer),
            loss=self.model.loss,
            loss_weights=self.model.loss_weights,
            metrics=self.model.metrics,
            weighted_metrics=self.model.weighted_metrics,
            target_tensors=tpu_targets,
        )

      # Compute our outfeed depending on the execution mode.  Each branch also
      # records the outfeed spec so the dequeue ops built later can use
      # matching dtypes and shapes.
      if is_training:
        self._cloned_model._make_train_function()
        self._outfeed_spec = [
            tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
            for tensor in self._cloned_model.train_function.outputs
        ]
        return [
            self._cloned_model.train_function.updates_op,
            tpu_ops.outfeed_enqueue_tuple(
                self._cloned_model.train_function.outputs,
                name='outfeed-enqueue-train')
        ]
      elif is_test:
        self._cloned_model._make_test_function()
        self._outfeed_spec = [
            tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
            for tensor in self._cloned_model.test_function.outputs
        ]
        return [
            tpu_ops.outfeed_enqueue_tuple(
                self._cloned_model.test_function.outputs,
                name='outfeed-enqueue-test')
        ]
      elif is_predict:
        self._cloned_model._make_predict_function()
        self._outfeed_spec = [
            tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
            for tensor in self._cloned_model.predict_function.outputs
        ]
        return [
            tpu_ops.outfeed_enqueue_tuple(
                self._cloned_model.predict_function.outputs,
                name='outfeed-enqueue-predict',
            )
        ]
      else:
        assert False, 'Unexpected execution mode: %s' % self.execution_mode

    # Capture outfeed metadata computed during the rewrite.
    self._outfeed_spec = None

    # Generate our TPU operations using `tpu.split_compile_and_replicate`.
    # `compile_op` can be used to test the TPU model compiles before execution.
    # `execute_op` replicates `_model_fn` once per tower, with each shard
    # running on a different logical core.
    compile_op, execute_op = tpu.split_compile_and_replicate(
        _model_fn, inputs=[[]] * self._strategy.num_towers)

    # Generate CPU side operations to enqueue features/labels and dequeue
    # outputs from the model call.
    infeed_op = []
    outfeed_op = []
    shard_infeed_tensors = []

    for shard_id in range(self._strategy.num_towers):
      with ops.device('/device:TPU:%d' % shard_id):
        # One placeholder per input spec, for this shard.
        infeed_tensors = []
        for spec in input_specs:
          infeed_tensors.append(
              array_ops.placeholder(
                  dtype=spec.dtype,
                  shape=spec.shape,
                  name='infeed-enqueue-%s-%d' % (spec.name, shard_id)))
        shard_infeed_tensors.append(infeed_tensors)

        infeed_op.append(
            tpu_ops.infeed_enqueue_tuple(
                infeed_tensors, [spec.shape for spec in input_specs],
                name='infeed-enqueue-%s-%d' % (self.execution_mode, shard_id)))

        # `self._outfeed_spec` was populated by `_model_fn` above.
        outfeed_op.extend(
            tpu_ops.outfeed_dequeue_tuple(
                dtypes=[spec.dtype for spec in self._outfeed_spec],
                shapes=[spec.shape for spec in self._outfeed_spec],
                name='outfeed-dequeue-%s-%d' % (self.execution_mode, shard_id)))

    return TPUModelOp(
        compile_op,
        execute_op,
        infeed_tensors=shard_infeed_tensors,
        infeed_op=infeed_op,
        outfeed_op=outfeed_op)
Exemplo n.º 5
0
  def _specialize_model(self, input_specs):
    """Specialize `self.model` (a Keras model) for the given input shapes.

    Args:
      input_specs: Specs (each exposing `dtype`, `shape` and `name`) of the
        tensors that are fed to the model through the infeed.

    Returns:
      A `CompiledTPUOp` built from the rewritten TPU execute op, the infeed
      placeholders, the infeed enqueue op and the outfeed dequeue op.
    """
    # Re-create our input and output layers inside our subgraph.  They will be
    # attached to the true computation when we clone our model in `tpu_fn`.
    K.set_learning_phase(
        self.execution_mode == model_fn_lib.ModeKeys.TRAIN
    )

    # functools.partial and callable objects are not supported by tpu.rewrite
    def _model_fn():
      """Compute fit/eval/predict for the TPU."""
      is_training = self.execution_mode == model_fn_lib.ModeKeys.TRAIN
      is_test = self.execution_mode == model_fn_lib.ModeKeys.EVAL
      is_predict = self.execution_mode == model_fn_lib.ModeKeys.PREDICT

      # During train/eval, we infeed our features as well as labels.
      if is_training or is_test:
        infeed_layers = self.model._input_layers + self.model._output_layers
      else:
        infeed_layers = self.model._input_layers

      # Generate our infeed operation to read features & labels.
      infeed_tensors = tpu_ops.infeed_dequeue_tuple(
          dtypes=[spec.dtype for spec in input_specs],
          shapes=[spec.shape for spec in input_specs],
          name='infeed-%s' % self.execution_mode)

      # The message must be built with `%`; a `,` here would turn it into an
      # unformatted (format-string, args) tuple.
      assert len(infeed_tensors) == len(infeed_layers), (
          'Infeed inputs did not match model: %s vs %s' % (infeed_layers,
                                                           infeed_tensors))

      tpu_targets = []
      tpu_inputs = []

      # Sort infeed outputs into inputs and labels for calling our Keras model.
      for tensor, layer in zip(infeed_tensors, infeed_layers):
        if layer in self.model._input_layers:
          tpu_inputs.append(layers.Input(name=layer.name, tensor=tensor))
        if layer in self.model._output_layers:
          tpu_targets.append(tensor)

      # Call our model with our infeed inputs (re-using the weights).
      model_outputs = self.model(tpu_inputs)
      child_model = models.Model(inputs=tpu_inputs, outputs=model_outputs)
      if is_training or is_test:
        child_model.compile(
            optimizer=self.model.optimizer,
            loss=self.model.loss,
            loss_weights=self.model.loss_weights,
            metrics=self.model.metrics,
            weighted_metrics=self.model.weighted_metrics,
            target_tensors=tpu_targets,
        )

      # Compute our outfeed depending on the execution mode.  Each branch also
      # records the outfeed spec so the dequeue op built later can use
      # matching dtypes and shapes.
      if is_training:
        child_model._make_train_function()
        self._outfeed_spec = [
            tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
            for tensor in child_model.train_function.outputs
        ]
        return [
            child_model.train_function.updates_op,
            tpu_ops.outfeed_enqueue_tuple(
                child_model.train_function.outputs,
                name='outfeed-enqueue-train')
        ]
      elif is_test:
        child_model._make_test_function()
        self._outfeed_spec = [
            tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
            for tensor in child_model.test_function.outputs
        ]
        return [
            tpu_ops.outfeed_enqueue_tuple(
                child_model.test_function.outputs, name='outfeed-enqueue-test')
        ]
      elif is_predict:
        child_model._make_predict_function()
        self._outfeed_spec = [
            tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
            for tensor in child_model.predict_function.outputs
        ]
        return [
            tpu_ops.outfeed_enqueue_tuple(
                child_model.predict_function.outputs,
                name='outfeed-enqueue-predict',
            )
        ]
      else:
        assert False, 'Unexpected execution mode: %s' % self.execution_mode

    # Capture outfeed metadata computed during the rewrite.
    self._outfeed_spec = None

    tpu_execute_op = tpu.rewrite(_model_fn)

    # Generate CPU side operations to enqueue features/labels and dequeue
    # outputs from the model call.
    with ops.device('/device:TPU:0'):
      # One placeholder per input spec.
      infeed_tensors = []
      for spec in input_specs:
        infeed_tensors.append(
            array_ops.placeholder(
                dtype=spec.dtype,
                shape=spec.shape,
                name='infeed-enqueue-%s' % spec.name))

      infeed_op = tpu_ops.infeed_enqueue_tuple(
          infeed_tensors, [spec.shape for spec in input_specs],
          name='infeed-enqueue-%s' % self.execution_mode)

      # `self._outfeed_spec` was populated by `_model_fn` above.
      outfeed_op = tpu_ops.outfeed_dequeue_tuple(
          dtypes=[spec.dtype for spec in self._outfeed_spec],
          shapes=[spec.shape for spec in self._outfeed_spec],
          name='outfeed-dequeue-%s' % self.execution_mode)

    return CompiledTPUOp(tpu_execute_op, infeed_tensors, infeed_op, outfeed_op)