  def BuildTpuSubgraph(self):
    tf.logging.info('DecodeProgram BuildTpuSubGraph')
    py_utils.ResetStepSeed()

    # Instantiate input generator first.
    self._input = self._task_params.input.Instantiate()
    self._input.CreateTpuEnqueueOps()
    # The task reuses this input generator (attached via AddChild in _DecodeFn
    # below), so tell the task params not to create their own input child.
    self._task_params.input.Define('skip_create_child', True, '')

    def _DecodeFn():
      """Decode call to be compiled for TPU."""
      with py_utils.OpportunisticVariableReuseScope(True):
        with cluster_factory.SetEval(True):
          self._model = self._task_params.Instantiate()
          self._model_task = self._model.GetTask()
          self._model_task.AddChild('input', self._input)
          input_batch = self._model_task.input_generator.TpuDequeueBatch()
          metrics_dict = self._model_task.Decode(input_batch)
          self.metrics_nm = py_utils.NestedMap(metrics_dict)
          return self.metrics_nm.Flatten()

    self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
        _DecodeFn,
        num_shards=self.data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())

    # Pack the flattened per-shard results back into the metrics structure.
    self.metrics = py_utils.NestedMap(self.metrics_nm)
    self.metrics = self.metrics.Pack(batch_parallel_res)
    return None
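
# Standalone usage sketch (an assumption, not part of DecodeProgram): one way
# a driver could execute the ops built above. `sess` is a TF1 session on the
# TPU worker, `program` is the DecodeProgram instance after BuildTpuSubgraph(),
# and `infeed_op` stands for the enqueue ops created by CreateTpuEnqueueOps()
# (the real attribute name may differ). Production drivers usually run the
# infeed on a separate thread; a single sess.run works here because the
# enqueue and the sharded decode execute concurrently within one run.
def RunOneDecodeSketch(sess, program, infeed_op):
  # Compile the TPU program; normally done once, before the first call.
  sess.run(program._compile_op)
  _, metrics = sess.run([infeed_op, program.metrics])
  return metrics
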
  def BuildTpuSubgraph(self):
    tf.logging.info('TrainProgram BuildTpuSubGraph')

    with py_utils.OpportunisticVariableReuseScope(True):
      self._eval_metrics = metrics.TpuEvalMetrics()
      data_parallelism = self.data_parallelism

      # Instantiate input generator first.
      self._input = self._task_params.input.Instantiate()
      self._input.CreateTpuEnqueueOps()
      self._task_params.input.Define('skip_create_child', True, '')

      def TpuTrainStep(*args):
        """Train a shard of a batch on a single TPU core.

        Args:
          *args: metrics values from previous steps.

        Returns:
          New summed metrics values and a train_op.
        """
        self._model = self._task_params.Instantiate()
        self._task = self._model.GetTask()
        self._task.AddChild('input', self._input)
        self._model.ConstructFPropBPropGraph()
        per_step_eval_metrics = self._eval_metrics.SetMetrics(
            self._task.eval_metrics, args)
        outfeed_op = self._OutfeedEnqueue(self._task.per_example_tensors)
        summed_metrics = []
        assert len(per_step_eval_metrics) == len(args)
        with tf.control_dependencies([outfeed_op]):
          for x, y in zip(per_step_eval_metrics, args):
            summed_metrics.append(x + y)
        return summed_metrics + [self._model.GetTask().train_op]

      @tpu_function.on_device_training_loop
      def TpuTrain():
        loop_result = tpu_training_loop.repeat(
            self._steps_per_loop,
            TpuTrainStep,
            inputs=self._eval_metrics.initial_values,
            name='train_loop')
        # Final metrics are the avg across self._steps_per_loop steps.
        return self._eval_metrics.FinalizeMetrics(loop_result)

      self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
          TpuTrain,
          num_shards=data_parallelism,
          device_assignment=py_utils.GetTpuDeviceAssignment())
      outfeed_dequeue_op = self._OutfeedDequeueLoop(
          self._model.GetTask().per_example_tensors, self._steps_per_loop,
          self.num_splits_per_client)
      # Get metric results from a single replica; they are all the same here.
      self.tpu_ops = [[t[0] for t in batch_parallel_res], outfeed_dequeue_op]

    return self.tpu_ops
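
# Illustrative sketch only: the _OutfeedDequeueLoop helper used above is not
# shown in this snippet. A plausible (unrolled) form pairs with the outfeed
# enqueue in TpuTrainStep by dequeuing once per train step and per replica,
# then concatenating along the batch dimension. The real helper typically uses
# a tf.while_loop and handles multi-host topologies; this sketch assumes all
# replicas are addressable as local device ordinals.
def OutfeedDequeueLoopSketch(per_example_tensors, steps_per_loop,
                             num_replicas):
  flat = tf.nest.flatten(per_example_tensors)
  dtypes = [t.dtype for t in flat]
  shapes = [t.shape for t in flat]
  chunks = [[] for _ in flat]
  for _ in range(steps_per_loop):
    for replica in range(num_replicas):
      dequeued = tpu_ops.outfeed_dequeue_tuple(
          dtypes=dtypes, shapes=shapes, device_ordinal=replica)
      for i, t in enumerate(dequeued):
        chunks[i].append(t)
  concatenated = [tf.concat(c, axis=0) for c in chunks]
  return tf.nest.pack_sequence_as(per_example_tensors, concatenated)
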
  def BuildTpuSubgraph(self):
    tf.logging.info('EvalProgram BuildTpuSubGraph')
    with py_utils.OpportunisticVariableReuseScope(True):
      self._eval_metrics = metrics.TpuEvalMetrics()
      data_parallelism = self.data_parallelism

      self._input = self._task_params.input.Instantiate()
      self._input.CreateTpuEnqueueOps()
      self._task_params.input.Define('skip_create_child', True, '')

      def TpuEvalStep(*args):
        """Eval a shard of a batch on a single TPU core.

        Args:
          *args: metrics values from previous steps.

        Returns:
          Summed eval metrics.
        """
        with cluster_factory.SetEval(True):
          self._model = self._task_params.Instantiate()
          self._task = self._model.GetTask()
          self._task.AddChild('input', self._input)

          self._model.ConstructFPropGraph()
          per_step_eval_metrics = self._eval_metrics.SetMetrics(
              self._model.GetTask().eval_metrics, args)
          summed_metrics = []
          for x, y in zip(per_step_eval_metrics, args):
            summed_metrics.append(x + y)
          return summed_metrics

      @tpu_function.on_device_training_loop
      def TpuEval():
        loop_result = tpu_training_loop.repeat(
            self._steps_per_loop,
            TpuEvalStep,
            inputs=self._eval_metrics.initial_values,
            name='eval_loop')
        # Final metrics are the avg across self._steps_per_loop steps.
        return self._eval_metrics.FinalizeMetrics(loop_result)

      self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
          TpuEval,
          num_shards=data_parallelism,
          device_assignment=py_utils.GetTpuDeviceAssignment())
      # Get metric results from a single replica; they are all the same here.
      self.tpu_ops = [[t[0] for t in batch_parallel_res]]

      return self.tpu_ops
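
# Conceptual sketch (an assumption about metrics.TpuEvalMetrics, whose
# internals are not shown here): each eval metric travels through the device
# loop as a (weighted_value_sum, weight_sum) pair. TpuEvalStep adds one step's
# contribution, FinalizeMetrics reduces across replicas, and the host recovers
# the reported value as a weighted average, roughly:
def WeightedAverageSketch(weighted_value_sum, weight_sum):
  # Guard against a zero weight, e.g. an empty eval shard.
  return weighted_value_sum / max(weight_sum, 1e-8)
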
  def BuildTpuSubgraph(self):
    tf.logging.info('DecodeProgram BuildTpuSubGraph')
    py_utils.ResetStepSeed()
    device_assignment = py_utils.GetTpuDeviceAssignment()
    self.spmd = self._task_params.input.use_partitioned_infeed_queue
    with py_utils.OpportunisticVariableReuseScope(True):
      with cluster_factory.SetEval(True):
        self._model = self._task_params.Instantiate()
        self._model_task = self._model.GetTask()
        self._model_task.input.CreateTpuEnqueueOps()

        def _DecodeStep():
          """Decode call to be compiled for TPU."""
          input_batch = self._model_task.input_generator.TpuDequeueBatch()
          metrics_dict = self._model_task.Decode(input_batch)
          self.metrics_nm = py_utils.NestedMap(metrics_dict)
          device = tpu.core(0) if self.spmd else ''
          with tf.device(device):
            outfeed_enqueue = tpu_ops.outfeed_enqueue_tuple(
                self.metrics_nm.Flatten())
            return [outfeed_enqueue]

    @tpu_function.on_device_training_loop
    def DecodeLoopFn():
      return tpu_training_loop.repeat(
          self._steps_per_loop, _DecodeStep, inputs=[])

    self._compile_op, self.decode_loop = tpu.split_compile_and_shard(
        DecodeLoopFn,
        num_shards=self.data_parallelism,
        device_assignment=device_assignment)
    # Get a list of outfeed ops.
    self.metrics = self._OutfeedDequeue()
    # Pack the list of outfeed ops with structure in self.metrics_nm.
    self.metrics = tf.nest.pack_sequence_as(self.metrics_nm, self.metrics)
    return
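
# Illustrative sketch only: the _OutfeedDequeue helper used above is not shown
# in this snippet. One plausible form dequeues the flattened decode metrics
# once per TPU core, using the dtypes/shapes recorded in self.metrics_nm, and
# concatenates across cores along the batch dimension so the packed NestedMap
# covers the full sharded batch. This assumes every decode metric is
# batch-major; scalar metrics would need separate handling.
def DecodeOutfeedDequeueSketch(metrics_nm, num_cores):
  flat = metrics_nm.Flatten()
  dtypes = [t.dtype for t in flat]
  shapes = [t.shape for t in flat]
  per_core = []
  for core in range(num_cores):
    per_core.append(
        tpu_ops.outfeed_dequeue_tuple(
            dtypes=dtypes, shapes=shapes, device_ordinal=core))
  return [tf.concat(ts, axis=0) for ts in zip(*per_core)]
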
    def build_model(self, model_fn, params):
        """Build the TPU model and infeed enqueue ops."""
        tf.logging.info("TrainLowLevelRunner: build_model method")

        def tpu_train_step(loss):
            """Generate the TPU graph."""
            del loss
            values = self.infeed_queue[0].generate_dequeue_op(tpu_device=0)
            unflattened_inputs = data_nest.pack_sequence_as(
                self.feature_structure, values)
            features = unflattened_inputs["features"]
            core_id = unflattened_inputs["core_id"]
            # Re-shard the cross-replica-summed features: reshape to
            # [num_shards, per_shard_batch, ...] and pick this core's slice
            # out by core_id.
            new_features = {}
            for k in features:
                s = features[k].shape.as_list()
                s = ([self.hparams.num_shards,
                      s[0] // self.hparams.num_shards] + s[1:])
                new_features[k] = tf.squeeze(
                    tf.gather(
                        tf.reshape(tpu_ops.cross_replica_sum(features[k]), s),
                        core_id), [0])

            estimator_spec = model_fn(new_features, None,
                                      tf.estimator.ModeKeys.TRAIN, params)
            loss, train_op = estimator_spec.loss, estimator_spec.train_op
            with tf.control_dependencies([train_op]):
                return tf.identity(loss)

        @tpu_function.on_device_training_loop
        def train_loop():
            return training_loop.repeat(self.iterations, tpu_train_step,
                                        [_INITIAL_LOSS])

        def tpu_eval_step():
            """Generate the TPU graph."""
            values = self.eval_infeed_queue[0].generate_dequeue_op(
                tpu_device=0)
            unflattened_inputs = data_nest.pack_sequence_as(
                self.eval_feature_structure, values)
            features = unflattened_inputs["features"]
            estimator_spec = model_fn(features, None,
                                      tf.estimator.ModeKeys.PREDICT, params)
            for k, v in six.iteritems(estimator_spec.predictions):
                self.outfeed_names.append(k)
                self.outfeed_tensors.append(v)

            with tf.device(
                    device_for_tpu_core(get_host(self.resolver,
                                                 self.hparams))):
                outfeed_enqueue_ops = tpu_ops.outfeed_enqueue_tuple(
                    self.outfeed_tensors)
            with tf.control_dependencies([outfeed_enqueue_ops]):
                return tf.no_op()

        @tpu_function.on_device_training_loop
        def eval_loop():
            if self.eval_steps > 0:
                return training_loop.repeat(self.eval_steps, tpu_eval_step, [])
            else:
                return tf.no_op()

        def train_eval_step():
            with tf.control_dependencies(train_loop()):
                return eval_loop()

        def train_eval_loop():
            return training_loop.repeat(self.hparams.max_train_epochs,
                                        train_eval_step, [])

        def create_dequeue_ops(host_id):
            """Create outfeed dequeue ops."""
            dequeue_ops = []
            tensor_dtypes = []
            tensor_shapes = []
            for v in self.outfeed_tensors:
                dequeue_ops.append([])
                tensor_dtypes.append(v.dtype)
                tensor_shapes.append(v.shape)
            for i in range(self.hparams.num_shards_per_host):
                with tf.device(
                        device_for_host(
                            get_host(self.resolver, self.hparams, host_id))):
                    outfeed_tensors = tpu_ops.outfeed_dequeue_tuple(
                        dtypes=tensor_dtypes,
                        shapes=tensor_shapes,
                        device_ordinal=i)
                    for j, item in enumerate(outfeed_tensors):
                        dequeue_ops[j].append(item)
            # Concatenate each tensor's per-shard pieces along the batch dim.
            for j in range(len(dequeue_ops)):
                dequeue_ops[j] = tf.concat(dequeue_ops[j], axis=0)
            return dequeue_ops

        with self.graph.as_default():
            if self.eval_steps <= 0:
                (self.loss, ) = tpu.shard(
                    train_loop,
                    inputs=[],
                    num_shards=self.hparams.num_shards,
                    outputs_from_all_shards=False,
                )

            else:
                (
                    self.compile_op,
                    self.train_eval_op,
                ) = tpu.split_compile_and_shard(
                    train_eval_loop,
                    inputs=[],
                    num_shards=self.hparams.num_shards,
                    outputs_from_all_shards=False)

            if self.eval_steps > 0:
                for i in range(0, self.num_hosts):
                    self.dequeue_ops.append({})
                    host_dequeue_ops = create_dequeue_ops(i)
                    for j, dequeue_tensor in enumerate(host_dequeue_ops):
                        self.dequeue_ops[i][
                            self.outfeed_names[j]] = dequeue_tensor

            global_initializer = tf.global_variables_initializer()
            local_initializer = tf.local_variables_initializer()
            self.sess.run(global_initializer)
            self.sess.run(local_initializer)

            graph_io.write_graph(self.graph.as_graph_def(add_shapes=True),
                                 self.hparams.out_dir, "graph.pbtxt")
            self.saver = tf.train.Saver()
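
# Standalone illustrative sketch (not part of the original runner): after
# build_model(), runner.dequeue_ops holds one {outfeed_name: tensor} dict per
# host. A prediction pass can be recovered by running each host's dict once
# per eval step and regrouping by name, typically on a thread that runs
# concurrently with the train/eval op so the outfeed never backs up.
def drain_eval_outfeed(runner, session):
    """Drains predictions from every host's outfeed dequeue ops."""
    predictions = {name: [] for name in runner.outfeed_names}
    for _ in range(runner.eval_steps):
        for host_ops in runner.dequeue_ops:
            host_results = session.run(host_ops)
            for name, value in host_results.items():
                predictions[name].append(value)
    return predictions
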
  def BuildTpuSubgraph(self):
    if self._ml_perf_log:
      mlp_log.mlperf_print('global_batch_size', self._ml_perf.global_batch_size)
      mlp_log.mlperf_print(
          'max_sequence_length',
          self._ml_perf.max_sequence_length,
          metadata={'method': 'discard'})
      mlp_log.mlperf_print('opt_name', self._ml_perf.optimizer_name)
      mlp_log.mlperf_print('opt_base_learning_rate',
                           self._ml_perf.base_learning_rate)
      mlp_log.mlperf_print('opt_learning_rate_warmup_steps',
                           self._ml_perf.warmup_steps)
      mlp_log.mlperf_print('opt_adam_beta_1', self._ml_perf.opt_adam_beta_1)
      mlp_log.mlperf_print('opt_adam_beta_2', self._ml_perf.opt_adam_beta_2)
      mlp_log.mlperf_print('opt_adam_epsilon', self._ml_perf.opt_adam_epsilon)
      mlp_log.mlperf_print('train_samples', self._ml_perf.train_samples)
      mlp_log.mlperf_print('eval_samples', self._ml_perf.eval_samples)

    with py_utils.OpportunisticVariableReuseScope(True):
      self._eval_metrics = metrics.TpuEvalMetrics()
      data_parallelism = self.data_parallelism
      # Instantiate the train and decode input generators first; the tasks
      # reuse them via AddChild, so skip_create_child keeps the task params
      # from creating their own input children.
      self._train_task_params.input.Define('skip_create_child', True, '')
      self._train_input = self._train_task_params.input.Instantiate()
      self._train_input.CreateTpuEnqueueOps()

      self._decode_input = self._decode_task_params.input.Instantiate()
      self._decode_input.CreateTpuEnqueueOps()
      self._decode_task_params.input.Define('skip_create_child', True, '')

      def wrap_computation_in_while_loop(op_fn, n, host_device):
        """Wraps the ops generated by `op_fn` in tf.while_loop."""

        def computation(i):
          ops = op_fn()
          if not isinstance(ops, list):
            ops = [ops]
          with tf.control_dependencies(ops):
            return tf.Print(i + 1, [i], 'while_loop:')

        with tf.device(host_device):
          return tf.while_loop(
              lambda i: tf.less(i, n),
              computation, [tf.constant(0)],
              parallel_iterations=1)

      def TrainAndDecodeEpoch(i, host_device):
        """Train and decode infeed for an epoch.

        Args:
          i: host index.
          host_device: host device string.

        Returns:
          The decode infeed loop, with a control dependency on the train
          infeed loop.
        """
        train_infeed_fn = lambda: self._train_input.CreatePerHostEnqueueOp(i)
        decode_infeed_fn = lambda: self._decode_input.CreatePerHostEnqueueOp(i)
        tf.logging.info('self._train_steps_per_loop: %d',
                        self._train_steps_per_loop)
        tf.logging.info('self._decode_steps_per_loop: %d',
                        self._decode_steps_per_loop)
        train = wrap_computation_in_while_loop(train_infeed_fn,
                                               self._train_steps_per_loop,
                                               host_device)
        with tf.device(host_device):
          with tf.control_dependencies([train]):
            decode = wrap_computation_in_while_loop(decode_infeed_fn,
                                                    self._decode_steps_per_loop,
                                                    host_device)
        return decode

      def TrainAndDecodeEpochLoop(i, host_device):
        """Train and decode infeed for num_epochs_per_session_run.

        Args:
          i: host index.
          host_device: host device string.

        Returns:
          tf.while_loop result.
        """
        train_and_decode_epoch_fn = lambda: TrainAndDecodeEpoch(i, host_device)
        epoch = wrap_computation_in_while_loop(train_and_decode_epoch_fn,
                                               self.num_epochs_per_session_run,
                                               host_device)
        return epoch

      num_infeed_hosts = len(self._train_input.per_host_device)
      tf.logging.info('num_infeed_hosts: %d', num_infeed_hosts)
      self.infeed_ops = []
      for i in range(num_infeed_hosts):
        host_device = self._train_input.per_host_device[i]
        self.infeed_ops.append(TrainAndDecodeEpochLoop(i, host_device))

      def TpuTrainStep():
        """Train a shard of a batch on a single TPU core.

        Do not calculate loss metrics.

        Returns:
         [train_op].
        """
        self._train_model = self._train_task_params.Instantiate()
        self._task = self._train_model.GetTask()
        self._task.AddChild('input', self._train_input)
        self._model = self._train_model
        self._train_model.ConstructFPropBPropGraph()
        return [self._train_model.GetTask().train_op]

      @tpu_function.on_device_training_loop
      def TpuTrain():
        loop_result = tpu_training_loop.repeat(
            self._train_steps_per_loop,
            TpuTrainStep,
            inputs=[],
            name='train_loop')
        return loop_result

    py_utils.ResetStepSeed()

    def _DecodeStep():
      """Decode call to be compiled for TPU."""
      with py_utils.OpportunisticVariableReuseScope(True):
        with cluster_factory.SetEval(True):
          self._decode_model = self._decode_task_params.Instantiate()
          self._decode_model_task = self._decode_model.GetTask()
          self._decode_model_task.AddChild('input', self._decode_input)
          input_batch = (
              self._decode_model_task.input_generator.TpuDequeueBatch())
          metrics_dict = self._decode_model_task.Decode(input_batch)
          self.metrics_nm = py_utils.NestedMap(metrics_dict)
          device = tpu.core(0) if self.spmd else ''
          with tf.device(device):
            outfeed_enqueue = tpu_ops.outfeed_enqueue_tuple(
                self.metrics_nm.Flatten())
            return [outfeed_enqueue]

    @tpu_function.on_device_training_loop
    def DecodeLoopFn():
      return tpu_training_loop.repeat(
          self._decode_steps_per_loop, _DecodeStep, inputs=[])

    def TrainAndDecode():
      with tf.control_dependencies([TpuTrain()]):
        return DecodeLoopFn()

    @OnDeviceTrainAndEvalLoops
    def TrainAndDecodeLoop():
      tpu_training_loop.repeat(
          self.num_epochs_per_session_run, TrainAndDecode, inputs=[])

    self._compile_op, self.train_and_decode_loop = tpu.split_compile_and_shard(
        TrainAndDecodeLoop,
        num_shards=data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())

    # Get a list of outfeed ops.
    self.metric_dicts = self._OutfeedDequeue()
    # Saves the graph def.
    tf.io.write_graph(tf.get_default_graph().as_graph_def(), self._logdir,
                      'train.pbtxt')
    return
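
# Standalone session-run sketch (an assumption, not part of the original
# program): the host-side infeed while-loops (self.infeed_ops), the on-device
# TrainAndDecodeLoop, and the outfeed dequeue behind self.metric_dicts have to
# run concurrently, otherwise the TPU stalls once an infeed or outfeed buffer
# fills up. `sess` is a TF1 session on the TPU worker and `program` is this
# program after BuildTpuSubgraph().
import threading


def RunTrainAndDecodeSessionSketch(sess, program):
  sess.run(program._compile_op)
  infeed_thread = threading.Thread(
      target=lambda: sess.run(program.infeed_ops))
  loop_thread = threading.Thread(
      target=lambda: sess.run(program.train_and_decode_loop))
  infeed_thread.start()
  loop_thread.start()
  # Drain one batch of decode metrics per decode step per epoch; the exact
  # count is an assumption about _OutfeedDequeue, which is not shown here.
  decode_results = []
  for _ in range(program.num_epochs_per_session_run *
                 program._decode_steps_per_loop):
    decode_results.append(sess.run(program.metric_dicts))
  loop_thread.join()
  infeed_thread.join()
  return decode_results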