  def BuildTpuSubgraph(self):
    tf.logging.info('DecodeProgram BuildTpuSubGraph')
    py_utils.ResetStepSeed()

    # Instantiate input generator first.
    self._input = self._task_params.input.Instantiate()
    self._input.CreateTpuEnqueueOps()
    self._task_params.input.Define('skip_create_child', True, '')

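    # Build the task inside _DecodeFn so variable creation happens under the
    # TPU sharding rewrite; the input generator created above is attached via
    # AddChild ('skip_create_child' keeps the task from building its own).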
    def _DecodeFn():
      """Decode call to be compiled for TPU."""
      with py_utils.OpportunisticVariableReuseScope(True):
        with cluster_factory.SetEval(True):
          self._model = self._task_params.Instantiate()
          self._model_task = self._model.GetTask()
          self._model_task.AddChild('input', self._input)
          input_batch = self._model_task.input_generator.TpuDequeueBatch()
          metrics_dict = self._model_task.Decode(input_batch)
          self.metrics_nm = py_utils.NestedMap(metrics_dict)
          return self.metrics_nm.Flatten()

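    # tpu.split_compile_and_shard returns a compilation op plus the per-shard
    # flattened outputs; the structure captured in self.metrics_nm is used
    # below to re-pack them into a NestedMap.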
    self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
        _DecodeFn,
        num_shards=self.data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())

    self.metrics = py_utils.NestedMap(self.metrics_nm)
    self.metrics = self.metrics.Pack(batch_parallel_res)
    return None
  def _OutfeedDequeue(self):
    """Collects outfeed dequeue ops from all devices."""
    # Hard-coded output keys for the Transformer/MLPerf decode task.
    keys = ['target_ids', 'eval_weight', 'tlen', 'top_ids', 'top_lens']
    concat_lists = {key: [] for key in keys}
    concat_dict = {}

    device_assignment = py_utils.GetTpuDeviceAssignment()
    assert device_assignment
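    # One dequeue per (replica, core), placed on that core's host so the
    # transfer runs on the machine that actually receives the outfeed data.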
    for replica in range(device_assignment.num_replicas):
      num_cores_per_replica = 1 if self.spmd else (
          device_assignment.num_cores_per_replica)
      for core in range(num_cores_per_replica):
        with tf.device(device_assignment.host_device(replica, core)):
          outfeeds_per_core = tpu_ops.outfeed_dequeue_tuple(
              dtypes=[x.dtype for x in self.metrics_nm.Flatten()],
              shapes=[x.shape for x in self.metrics_nm.Flatten()],
              device_ordinal=device_assignment.tpu_ordinal(replica, core))
          packed = tf.nest.pack_sequence_as(self.metrics_nm, outfeeds_per_core)
          outfeed_dict = self._decode_model_task.PostProcessDecodeHost(packed)
          for key in keys:
            concat_lists[key].append(outfeed_dict[key])

    for key in keys:
      concat_dict[key] = tf.concat(concat_lists[key], 0)
    return concat_dict
    def LoopBody(i, *input_arrays):
      """Process outfeed data for a single TpuTrainStep.

      Args:
        i: current loop index.
        *input_arrays: One tf.TensorArray per outfeed tensor.

      Returns:
        i+1 (new index) plus post-write tf.TensorArray handles.
      """
      # Outfeed dequeue ops execute on the host of each TPU replica/core, so
      # they must be placed on those host devices.
      outfeed_devices = []
      device_assignment = py_utils.GetTpuDeviceAssignment()
      assert device_assignment
      for replica in range(device_assignment.num_replicas):
        for core in range(device_assignment.num_cores_per_replica):
          with tf.device(device_assignment.host_device(replica, core)):
            outfeed_devices.append(
                tpu_ops.outfeed_dequeue_tuple(
                    tensor_types,
                    tensor_shapes,
                    device_ordinal=device_assignment.tpu_ordinal(replica,
                                                                 core)))
      offset = i * num_devices
      output_arrays = list(input_arrays)
      # Each output_array holds a different per-example tensor. We get results
      # for each tensor from each TPU for each TpuTrainStep call.
      for j in range(len(output_arrays)):
        for k in range(len(outfeed_devices)):
          output_arrays[j] = output_arrays[j].write(offset + k,
                                                    outfeed_devices[k][j])

      return tuple([i + 1] + output_arrays)
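  # The helper below is typically passed as the `tpu_ordinal_function` of the
  # per-host infeed enqueue builder, e.g. (assumed call site, not part of this
  # excerpt):
  #   infeed_queue.generate_enqueue_ops(
  #       per_host_sharded_inputs, tpu_ordinal_function=TPUOrdinalFunction)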
  def TPUOrdinalFunction(shard_index_in_host):
    device_assignment = py_utils.GetTpuDeviceAssignment()
    if device_assignment:
      # We put both enqueue/dequeue ops at core 0 in each replica.
      replica = device_assignment.lookup_replicas(
          task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
      return device_assignment.tpu_ordinal(replica=replica)
    else:
      return shard_index_in_host
  def BuildTpuSubgraph(self):
    tf.logging.info('TrainProgram BuildTpuSubGraph')

    with py_utils.OpportunisticVariableReuseScope(True):
      self._eval_metrics = metrics.TpuEvalMetrics()
      data_parallelism = self.data_parallelism

      # Instantiate input generator first.
      self._input = self._task_params.input.Instantiate()
      self._input.CreateTpuEnqueueOps()
      self._task_params.input.Define('skip_create_child', True, '')

      def TpuTrainStep(*args):
        """Train a shard of a batch on a single TPU core.

        Args:
          *args: metrics values from previous steps.

        Returns:
          New summed metrics values and a train_op.
        """
        self._model = self._task_params.Instantiate()
        self._task = self._model.GetTask()
        self._task.AddChild('input', self._input)
        self._model.ConstructFPropBPropGraph()
        per_step_eval_metrics = self._eval_metrics.SetMetrics(
            self._task.eval_metrics, args)
        outfeed_op = self._OutfeedEnqueue(self._task.per_example_tensors)
        summed_metrics = []
        assert len(per_step_eval_metrics) == len(args)
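        # The control dependency forces the per-example outfeed enqueue to run
        # each step before the metric accumulators are updated.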
        with tf.control_dependencies([outfeed_op]):
          for x, y in zip(per_step_eval_metrics, args):
            summed_metrics.append(x + y)
        return summed_metrics + [self._model.GetTask().train_op]

      @tpu_function.on_device_training_loop
      def TpuTrain():
        loop_result = tpu_training_loop.repeat(
            self._steps_per_loop,
            TpuTrainStep,
            inputs=self._eval_metrics.initial_values,
            name='train_loop')
        # Final metrics are the avg across self._steps_per_loop steps.
        return self._eval_metrics.FinalizeMetrics(loop_result)

      self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
          TpuTrain,
          num_shards=data_parallelism,
          device_assignment=py_utils.GetTpuDeviceAssignment())
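      # Drain the per-example tensors that TpuTrainStep enqueued on every step
      # of the on-device loop, across all replicas.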
      outfeed_dequeue_op = self._OutfeedDequeueLoop(
          self._model.GetTask().per_example_tensors, self._steps_per_loop,
          self.num_splits_per_client)
      # Metric results are identical across replicas, so take replica 0's.
      self.tpu_ops = [[t[0] for t in batch_parallel_res], outfeed_dequeue_op]

    return self.tpu_ops
  def BuildTpuSubgraph(self):
    tf.logging.info('EvalProgram BuildTpuSubGraph')
    with py_utils.OpportunisticVariableReuseScope(True):
      self._eval_metrics = metrics.TpuEvalMetrics()
      data_parallelism = self.data_parallelism

      self._input = self._task_params.input.Instantiate()
      self._input.CreateTpuEnqueueOps()
      self._task_params.input.Define('skip_create_child', True, '')

      def TpuEvalStep(*args):
        """Eval a shard of a batch on a single TPU core.

        Args:
          *args: metrics values from previous steps.

        Returns:
          Summed eval metrics.
        """
        with cluster_factory.SetEval(True):
          self._model = self._task_params.Instantiate()
          self._task = self._model.GetTask()
          self._task.AddChild('input', self._input)

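          # Forward pass only; evaluation does not build the backprop graph.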
          self._model.ConstructFPropGraph()
          per_step_eval_metrics = self._eval_metrics.SetMetrics(
              self._model.GetTask().eval_metrics, args)
          summed_metrics = []
          for x, y in zip(per_step_eval_metrics, args):
            summed_metrics.append(x + y)
          return summed_metrics

      @tpu_function.on_device_training_loop
      def TpuEval():
        loop_result = tpu_training_loop.repeat(
            self._steps_per_loop,
            TpuEvalStep,
            inputs=self._eval_metrics.initial_values,
            name='eval_loop')
        # Final metrics are the avg across self._steps_per_loop steps.
        return self._eval_metrics.FinalizeMetrics(loop_result)

      self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
          TpuEval,
          num_shards=data_parallelism,
          device_assignment=py_utils.GetTpuDeviceAssignment())
      # Metric results are identical across replicas, so take replica 0's.
      self.tpu_ops = [[t[0] for t in batch_parallel_res]]

      return self.tpu_ops
  def _OutfeedDequeue(self):
    """Collect outfeed dequeue from all devices."""
    num_outfeeds = len(self.metrics_nm.Flatten())
    outfeed_ops = [[] for _ in range(num_outfeeds)]
    device_assignment = py_utils.GetTpuDeviceAssignment()
    assert device_assignment
    for replica in range(device_assignment.num_replicas):
      num_cores_per_replica = 1 if self.spmd else (
          device_assignment.num_cores_per_replica)
      for core in range(num_cores_per_replica):
        with tf.device(device_assignment.host_device(replica, core)):
          outfeeds_per_core = tpu_ops.outfeed_dequeue_tuple(
              dtypes=[x.dtype for x in self.metrics_nm.Flatten()],
              shapes=[x.shape for x in self.metrics_nm.Flatten()],
              device_ordinal=device_assignment.tpu_ordinal(replica, core))
          for idx_outfeed, out_feed in enumerate(outfeeds_per_core):
            outfeed_ops[idx_outfeed].append(out_feed)
    return [tf.concat(per_outfeed, 0) for per_outfeed in outfeed_ops]
  def BuildTpuSubgraph(self):
    tf.logging.info('DecodeProgram BuildTpuSubGraph')
    py_utils.ResetStepSeed()
    device_assignment = py_utils.GetTpuDeviceAssignment()
    self.spmd = self._task_params.input.use_partitioned_infeed_queue
    with py_utils.OpportunisticVariableReuseScope(True):
      with cluster_factory.SetEval(True):
        self._model = self._task_params.Instantiate()
        self._model_task = self._model.GetTask()
        self._model_task.input.CreateTpuEnqueueOps()

        def _DecodeStep():
          """Decode call to be compiled for TPU."""
          input_batch = self._model_task.input_generator.TpuDequeueBatch()
          metrics_dict = self._model_task.Decode(input_batch)
          self.metrics_nm = py_utils.NestedMap(metrics_dict)
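          # Under SPMD partitioning the outfeed enqueue must be pinned to
          # logical core 0; otherwise the default device placement is fine.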
          device = tpu.core(0) if self.spmd else ''
          with tf.device(device):
            outfeed_enqueue = tpu_ops.outfeed_enqueue_tuple(
                self.metrics_nm.Flatten())
            return [outfeed_enqueue]

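    # DecodeLoopFn runs _DecodeStep self._steps_per_loop times entirely on
    # device; results come back through the outfeed rather than as return
    # values.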
    @tpu_function.on_device_training_loop
    def DecodeLoopFn():
      return tpu_training_loop.repeat(
          self._steps_per_loop, _DecodeStep, inputs=[])

    self._compile_op, self.decode_loop = tpu.split_compile_and_shard(
        DecodeLoopFn,
        num_shards=self.data_parallelism,
        device_assignment=device_assignment)
    # Get a list of outfeed ops.
    self.metrics = self._OutfeedDequeue()
    # Pack the list of outfeed ops with structure in self.metrics_nm.
    self.metrics = tf.nest.pack_sequence_as(self.metrics_nm, self.metrics)
    return
  def BuildTpuSubgraph(self):
    if self._ml_perf_log:
      mlp_log.mlperf_print('global_batch_size', self._ml_perf.global_batch_size)
      mlp_log.mlperf_print(
          'max_sequence_length',
          self._ml_perf.max_sequence_length,
          metadata={'method': 'discard'})
      mlp_log.mlperf_print('opt_name', self._ml_perf.optimizer_name)
      mlp_log.mlperf_print('opt_base_learning_rate',
                           self._ml_perf.base_learning_rate)
      mlp_log.mlperf_print('opt_learning_rate_warmup_steps',
                           self._ml_perf.warmup_steps)
      mlp_log.mlperf_print('opt_adam_beta_1', self._ml_perf.opt_adam_beta_1)
      mlp_log.mlperf_print('opt_adam_beta_2', self._ml_perf.opt_adam_beta_2)
      mlp_log.mlperf_print('opt_adam_epsilon', self._ml_perf.opt_adam_epsilon)
      mlp_log.mlperf_print('train_samples', self._ml_perf.train_samples)
      mlp_log.mlperf_print('eval_samples', self._ml_perf.eval_samples)

    with py_utils.OpportunisticVariableReuseScope(True):
      self._eval_metrics = metrics.TpuEvalMetrics()
      data_parallelism = self.data_parallelism
      self._train_task_params.input.Define('skip_create_child', True, '')
      self._train_input = self._train_task_params.input.Instantiate()
      self._train_input.CreateTpuEnqueueOps()

      self._decode_input = self._decode_task_params.input.Instantiate()
      self._decode_input.CreateTpuEnqueueOps()
      self._decode_task_params.input.Define('skip_create_child', True, '')

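      # Host-side helper: runs op_fn() n times in a serial tf.while_loop
      # (parallel_iterations=1) so infeed enqueues keep pace with the
      # device-side train/decode loops below.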
      def wrap_computation_in_while_loop(op_fn, n, host_device):
        """Wraps the ops generated by `op_fn` in tf.while_loop."""

        def computation(i):
          ops = op_fn()
          if not isinstance(ops, list):
            ops = [ops]
          with tf.control_dependencies(ops):
            return tf.Print(i + 1, [i], 'while_loop:')

        with tf.device(host_device):
          return tf.while_loop(
              lambda i: tf.less(i, n),
              computation, [tf.constant(0)],
              parallel_iterations=1)

      def TrainAndDecodeEpoch(i, host_device):
        """Train and decode infeed for an epoch.

        Args:
          i: host index.
          host_device: host device string

        Returns:
          The decode infeed loop op, gated by a control dependency on the
          train infeed loop.
        """
        train_infeed_fn = lambda: self._train_input.CreatePerHostEnqueueOp(i)
        decode_infeed_fn = lambda: self._decode_input.CreatePerHostEnqueueOp(i)
        tf.logging.info('self._train_steps_per_loop: %d',
                        self._train_steps_per_loop)
        tf.logging.info('self._decode_steps_per_loop: %d',
                        self._decode_steps_per_loop)
        train = wrap_computation_in_while_loop(train_infeed_fn,
                                               self._train_steps_per_loop,
                                               host_device)
        with tf.device(host_device):
          with tf.control_dependencies([train]):
            decode = wrap_computation_in_while_loop(decode_infeed_fn,
                                                    self._decode_steps_per_loop,
                                                    host_device)
        return decode

      def TrainAndDecodeEpochLoop(i, host_device):
        """Train and decode infeed for num_epochs_per_session_run.

        Args:
          i: host index.
          host_device: host device string

        Returns:
          tf.while_loop result.
        """
        train_and_decode_epoch_fn = lambda: TrainAndDecodeEpoch(i, host_device)
        epoch = wrap_computation_in_while_loop(train_and_decode_epoch_fn,
                                               self.num_epochs_per_session_run,
                                               host_device)
        return epoch

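      # Build one combined train + decode infeed driver per infeed host.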
      num_infeed_hosts = len(self._train_input.per_host_device)
      tf.logging.info('num_infeed_hosts: %d', num_infeed_hosts)
      self.infeed_ops = []
      for i in range(num_infeed_hosts):
        host_device = self._train_input.per_host_device[i]
        self.infeed_ops.append(TrainAndDecodeEpochLoop(i, host_device))

      def TpuTrainStep():
        """Train a shard of a batch on a single TPU core.

        Does not compute loss metrics.

        Returns:
          [train_op].
        """
        self._train_model = self._train_task_params.Instantiate()
        self._task = self._train_model.GetTask()
        self._task.AddChild('input', self._train_input)
        self._model = self._train_model
        self._train_model.ConstructFPropBPropGraph()
        return [self._train_model.GetTask().train_op]

      @tpu_function.on_device_training_loop
      def TpuTrain():
        loop_result = tpu_training_loop.repeat(
            self._train_steps_per_loop,
            TpuTrainStep,
            inputs=[],
            name='train_loop')
        return loop_result

    py_utils.ResetStepSeed()

    def _DecodeStep():
      """Decode call to be compiled for TPU."""
      with py_utils.OpportunisticVariableReuseScope(True):
        with cluster_factory.SetEval(True):
          self._decode_model = self._decode_task_params.Instantiate()
          self._decode_model_task = self._decode_model.GetTask()
          self._decode_model_task.AddChild('input', self._decode_input)
          input_batch = (
              self._decode_model_task.input_generator.TpuDequeueBatch())
          metrics_dict = self._decode_model_task.Decode(input_batch)
          self.metrics_nm = py_utils.NestedMap(metrics_dict)
          device = tpu.core(0) if self.spmd else ''
          with tf.device(device):
            outfeed_enqueue = tpu_ops.outfeed_enqueue_tuple(
                self.metrics_nm.Flatten())
            return [outfeed_enqueue]

    @tpu_function.on_device_training_loop
    def DecodeLoopFn():
      return tpu_training_loop.repeat(
          self._decode_steps_per_loop, _DecodeStep, inputs=[])

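    # One "epoch" on device: run the train loop, then, gated by control
    # dependencies, the decode loop.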
    def TrainAndDecode():
      with tf.control_dependencies([TpuTrain()]):
        return DecodeLoopFn()

    @OnDeviceTrainAndEvalLoops
    def TrainAndDecodeLoop():
      tpu_training_loop.repeat(
          self.num_epochs_per_session_run, TrainAndDecode, inputs=[])

    self._compile_op, self.train_and_decode_loop = tpu.split_compile_and_shard(
        TrainAndDecodeLoop,
        num_shards=data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())

    # Get a list of outfeed ops.
    self.metric_dicts = self._OutfeedDequeue()
    # Saves the graph def.
    tf.io.write_graph(tf.get_default_graph().as_graph_def(), self._logdir,
                      'train.pbtxt')
    return