def BuildTpuSubgraph(self):
    tf.logging.info('DecodeProgram BuildTpuSubGraph')
    py_utils.ResetStepSeed()

    # Instantiate the input generator first and mark its params with
    # skip_create_child so that the task instantiated in _DecodeFn does not
    # create its own input child; the shared input generator is attached there
    # via AddChild.
    self._task_params.input.Define('skip_create_child', True, '')
    self._input = self._task_params.input.Instantiate()
    self._input.CreateTpuEnqueueOps()

    def _DecodeFn():
      """Decode call to be compiled for TPU."""
      with py_utils.OpportunisticVariableReuseScope(True):
        with cluster_factory.SetEval(True):
          self._model = self._task_params.Instantiate()
          self._model_task = self._model.GetTask()
          self._model_task.AddChild('input', self._input)
          input_batch = self._model_task.input_generator.TpuDequeueBatch()
          metrics_dict = self._model_task.Decode(input_batch)
          self.metrics_nm = py_utils.NestedMap(metrics_dict)
          return self.metrics_nm.Flatten()

    self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
        _DecodeFn,
        num_shards=self.data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())

    # Pack the flattened per-shard results back into the structure recorded in
    # self.metrics_nm during graph construction.
    self.metrics = py_utils.NestedMap(self.metrics_nm)
    self.metrics = self.metrics.Pack(batch_parallel_res)
    return None
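
# Illustrative usage sketch (not from the original source): once the subgraph
# above is built, a host-side driver would typically first run the compile op
# to fail fast on TPU compilation errors and then fetch the packed metrics,
# while the enqueue ops created by CreateTpuEnqueueOps() are driven from a
# separate infeed thread. `sess` is a hypothetical tf.Session handle:
#
#   sess.run(self._compile_op)              # Verify the program compiles.
#   metric_values = sess.run(self.metrics)  # Packed per-shard decode outputs.
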
  def BuildTpuSubgraph(self):
    tf.logging.info('DecodeProgram BuildTpuSubGraph')
    py_utils.ResetStepSeed()
    device_assignment = py_utils.GetTpuDeviceAssignment()
    self.spmd = self._task_params.input.use_partitioned_infeed_queue
    with py_utils.OpportunisticVariableReuseScope(True):
      with cluster_factory.SetEval(True):
        self._model = self._task_params.Instantiate()
        self._model_task = self._model.GetTask()
        self._model_task.input.CreateTpuEnqueueOps()

        def _DecodeStep():
          """Decode call to be compiled for TPU."""
          input_batch = self._model_task.input_generator.TpuDequeueBatch()
          metrics_dict = self._model_task.Decode(input_batch)
          self.metrics_nm = py_utils.NestedMap(metrics_dict)
          device = tpu.core(0) if self.spmd else ''
          with tf.device(device):
            outfeed_enqueue = tpu_ops.outfeed_enqueue_tuple(
                self.metrics_nm.Flatten())
            return [outfeed_enqueue]

    @tpu_function.on_device_training_loop
    def DecodeLoopFn():
      return tpu_training_loop.repeat(
          self._steps_per_loop, _DecodeStep, inputs=[])

    self._compile_op, self.decode_loop = tpu.split_compile_and_shard(
        DecodeLoopFn,
        num_shards=self.data_parallelism,
        device_assignment=device_assignment)
    # Get a list of outfeed ops.
    self.metrics = self._OutfeedDequeue()
    # Pack the list of outfeed ops with structure in self.metrics_nm.
    self.metrics = tf.nest.pack_sequence_as(self.metrics_nm, self.metrics)
    return
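
  # _OutfeedDequeue() is defined elsewhere; as a rough, hypothetical sketch it
  # dequeues the tuples enqueued by outfeed_enqueue_tuple above, once per
  # replica, roughly along the lines of:
  #
  #   tensors = self.metrics_nm.Flatten()
  #   dequeues = []
  #   for replica in range(self.data_parallelism):
  #     dequeues.append(
  #         tpu_ops.outfeed_dequeue_tuple(
  #             dtypes=[t.dtype for t in tensors],
  #             shapes=[t.shape for t in tensors],
  #             device_ordinal=replica))
  #
  # The flattened results are then packed back into the metrics structure with
  # tf.nest.pack_sequence_as, as done above. The exact host/device placement
  # of the dequeue ops depends on the device assignment.
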
  def BuildTpuSubgraph(self):
    if self._ml_perf_log:
      mlp_log.mlperf_print('global_batch_size',
                           self._ml_perf.global_batch_size)
      mlp_log.mlperf_print(
          'max_sequence_length',
          self._ml_perf.max_sequence_length,
          metadata={'method': 'discard'})
      mlp_log.mlperf_print('opt_name', self._ml_perf.optimizer_name)
      mlp_log.mlperf_print('opt_base_learning_rate',
                           self._ml_perf.base_learning_rate)
      mlp_log.mlperf_print('opt_learning_rate_warmup_steps',
                           self._ml_perf.warmup_steps)
      mlp_log.mlperf_print('opt_adam_beta_1', self._ml_perf.opt_adam_beta_1)
      mlp_log.mlperf_print('opt_adam_beta_2', self._ml_perf.opt_adam_beta_2)
      mlp_log.mlperf_print('opt_adam_epsilon', self._ml_perf.opt_adam_epsilon)
      mlp_log.mlperf_print('train_samples', self._ml_perf.train_samples)
      mlp_log.mlperf_print('eval_samples', self._ml_perf.eval_samples)

    with py_utils.OpportunisticVariableReuseScope(True):
      self._eval_metrics = metrics.TpuEvalMetrics()
      data_parallelism = self.data_parallelism
      # Mark the task input params with skip_create_child so the tasks do not
      # create their own input children; the shared input generators
      # instantiated here are attached later via AddChild.
      self._train_task_params.input.Define('skip_create_child', True, '')
      self._train_input = self._train_task_params.input.Instantiate()
      self._train_input.CreateTpuEnqueueOps()

      self._decode_task_params.input.Define('skip_create_child', True, '')
      self._decode_input = self._decode_task_params.input.Instantiate()
      self._decode_input.CreateTpuEnqueueOps()

      def wrap_computation_in_while_loop(op_fn, n, host_device):
        """Wraps the ops generated by `op_fn` in tf.while_loop."""

        def computation(i):
          ops = op_fn()
          if not isinstance(ops, list):
            ops = [ops]
          with tf.control_dependencies(ops):
            return tf.Print(i + 1, [i], 'while_loop:')

        with tf.device(host_device):
          return tf.while_loop(
              lambda i: tf.less(i, n),
              computation, [tf.constant(0)],
              parallel_iterations=1)

      def TrainAndDecodeEpoch(i, host_device):
        """Train and decode infeed for an epoch.

        Args:
          i: host index.
          host_device: host device string

        Returns:
          Decode with control deps on train node.
        """
        train_infeed_fn = lambda: self._train_input.CreatePerHostEnqueueOp(i)
        decode_infeed_fn = lambda: self._decode_input.CreatePerHostEnqueueOp(i)
        tf.logging.info('self._train_steps_per_loop: %d',
                        self._train_steps_per_loop)
        tf.logging.info('self._decode_steps_per_loop: %d',
                        self._decode_steps_per_loop)
        train = wrap_computation_in_while_loop(train_infeed_fn,
                                               self._train_steps_per_loop,
                                               host_device)
        with tf.device(host_device):
          with tf.control_dependencies([train]):
            decode = wrap_computation_in_while_loop(decode_infeed_fn,
                                                    self._decode_steps_per_loop,
                                                    host_device)
        return decode

      def TrainAndDecodeEpochLoop(i, host_device):
        """Train and decode infeed for num_epochs_per_session_run.

        Args:
          i: host index.
          host_device: host device string

        Returns:
          tf.while_loop result.
        """
        train_and_decode_epoch_fn = lambda: TrainAndDecodeEpoch(i, host_device)
        epoch = wrap_computation_in_while_loop(train_and_decode_epoch_fn,
                                               self.num_epochs_per_session_run,
                                               host_device)
        return epoch

      num_infeed_hosts = len(self._train_input.per_host_device)
      tf.logging.info('num_infeed_hosts: %d', num_infeed_hosts)
      self.infeed_ops = []
      for i in range(num_infeed_hosts):
        host_device = self._train_input.per_host_device[i]
        self.infeed_ops.append(TrainAndDecodeEpochLoop(i, host_device))

      def TpuTrainStep():
        """Train a shard of a batch on a single TPU core.

        Do not calculate loss metrics.

        Returns:
          [train_op].
        """
        self._train_model = self._train_task_params.Instantiate()
        self._task = self._train_model.GetTask()
        self._task.AddChild('input', self._train_input)
        self._model = self._train_model
        self._train_model.ConstructFPropBPropGraph()
        return [self._task.train_op]

      @tpu_function.on_device_training_loop
      def TpuTrain():
        loop_result = tpu_training_loop.repeat(
            self._train_steps_per_loop,
            TpuTrainStep,
            inputs=[],
            name='train_loop')
        return loop_result

    py_utils.ResetStepSeed()

    def _DecodeStep():
      """Decode call to be compiled for TPU."""
      with py_utils.OpportunisticVariableReuseScope(True):
        with cluster_factory.SetEval(True):
          self._decode_model = self._decode_task_params.Instantiate()
          self._decode_model_task = self._decode_model.GetTask()
          self._decode_model_task.AddChild('input', self._decode_input)
          input_batch = (
              self._decode_model_task.input_generator.TpuDequeueBatch())
          metrics_dict = self._decode_model_task.Decode(input_batch)
          self.metrics_nm = py_utils.NestedMap(metrics_dict)
          device = tpu.core(0) if self.spmd else ''
          with tf.device(device):
            outfeed_enqueue = tpu_ops.outfeed_enqueue_tuple(
                self.metrics_nm.Flatten())
            return [outfeed_enqueue]

    @tpu_function.on_device_training_loop
    def DecodeLoopFn():
      return tpu_training_loop.repeat(
          self._decode_steps_per_loop, _DecodeStep, inputs=[])

    def TrainAndDecode():
      with tf.control_dependencies([TpuTrain()]):
        return DecodeLoopFn()

    @OnDeviceTrainAndEvalLoops
    def TrainAndDecodeLoop():
      tpu_training_loop.repeat(
          self.num_epochs_per_session_run, TrainAndDecode, inputs=[])

    self._compile_op, self.train_and_decode_loop = tpu.split_compile_and_shard(
        TrainAndDecodeLoop,
        num_shards=data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())

    # Get a list of outfeed ops.
    self.metric_dicts = self._OutfeedDequeue()
    # Saves the graph def.
    tf.io.write_graph(tf.get_default_graph().as_graph_def(), self._logdir,
                      'train.pbtxt')
    return
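
  # Illustrative driver sketch (not from the original source): the host infeed
  # loops in self.infeed_ops and the device loop in self.train_and_decode_loop
  # are typically run concurrently. With a hypothetical tf.Session handle
  # `sess` and Python's threading module:
  #
  #   sess.run(self._compile_op)   # Fail fast on TPU compilation errors.
  #   infeed_thread = threading.Thread(
  #       target=lambda: sess.run(self.infeed_ops))
  #   infeed_thread.start()
  #   sess.run(self.train_and_decode_loop)
  #   infeed_thread.join()
  #   metric_values = sess.run(self.metric_dicts)
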
  def ReverseAndGrad(self, theta, outputs, d_outputs, f_seed, g_seed,
                     *extra_inputs):
    """Implements Algorithm 1 in the revnet paper.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      outputs: A NestedMap: .split1 and .split2 corresponding to y1 and y2.
      d_outputs: A NestedMap: .split1 and .split2 corresponding to dy1 and dy2,
        the total derivatives.
      f_seed: Scalar tensor. The step seed used in forward for the f block.
      g_seed: Scalar tensor. The step seed used in forward for the g block. The
        step seeds are needed for deterministic randomness, e.g. to ensure
        dropout generate the same random mask in forward and reverse_grad.
      *extra_inputs: additional inputs that will be passed to both f and g. No
        gradient will be computed for these inputs.

    Returns:
      A tuple of NestedMaps

      - inputs: .split1 and .split2 corresponding to x1 and x2.
      - d_inputs: .split1 and .split2 corresponding to dx1 and dx2, the total
        derivatives with respect to inputs.
      - d_theta: has the same structure as theta. The total derivatives with
        respect to weights.

    """

    # Stop gradient on the outputs to avoid circular symbolic dependency.
    y1 = tf.stop_gradient(outputs.split1)
    y2 = tf.stop_gradient(outputs.split2)
    dy1 = d_outputs.split1
    dy2 = d_outputs.split2

    # Computes the reverse.
    z1 = y1
    py_utils.ResetStepSeed(g_seed)
    gz1 = self.g_block.FProp(theta.g_block, z1, *extra_inputs)
    x2 = y2 - gz1
    py_utils.ResetStepSeed(f_seed)
    fx2 = self.f_block.FProp(theta.f_block, x2, *extra_inputs)
    x1 = z1 - fx2

    # Computes the gradients.
    dz1 = dy1 + tf.gradients(gz1, z1, dy2)[0]
    dx2 = dy2 + tf.gradients(fx2, x2, dz1)[0]

    dgw = tf.gradients(
        gz1,
        theta.g_block.Flatten(),
        dy2,
        unconnected_gradients=tf.UnconnectedGradients.ZERO)
    dgw = theta.g_block.Pack(dgw)

    dfw = tf.gradients(
        fx2,
        theta.f_block.Flatten(),
        dz1,
        unconnected_gradients=tf.UnconnectedGradients.ZERO)
    dfw = theta.f_block.Pack(dfw)

    return (py_utils.NestedMap(split1=x1, split2=x2),
            py_utils.NestedMap(split1=dz1, split2=dx2),
            py_utils.NestedMap(
                f_block=dfw,
                g_block=dgw,
                global_step=tf.zeros_like(theta.global_step)))
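
  # For reference (restating the code above, not from the original source):
  # the forward coupling is y1 = x1 + f(x2), y2 = x2 + g(y1). With z1 = y1,
  # the backward pass uses the vector-Jacobian products
  #   dz1 = dy1 + (dg(z1)/dz1)^T dy2,   dx2 = dy2 + (df(x2)/dx2)^T dz1,
  # which is exactly what the two tf.gradients calls compute, and dx1 = dz1
  # (returned as d_inputs.split1). A standalone numeric check of the coupling
  # appears at the end of this file.
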
  def FProp(self, theta, inputs, *extra_inputs):
    """Forward-propagates through the stack of reversible sub-layers."""
    initial_step_seed = py_utils.GetStepSeed()
    final_step_seed = py_utils.GenerateSeedFromName(
        tf.no_op(name='new_step_seed').name)
    num_layers = len(self.sub_layers)

    def Bak(inputs, outputs, d_outputs):
      """Backward step: reconstructs inputs and computes gradients."""
      del inputs  # Unused: inputs are reconstructed from the outputs.
      output_acts, step_seeds = outputs
      d_outputs = d_outputs[0]

      d_layer_thetas = []
      for layer_idx in reversed(range(num_layers)):
        f_seed, g_seed = step_seeds[layer_idx]
        layer = self.sub_layers[layer_idx]
        layer_theta = theta.sub_layers[layer_idx]

        input_acts, d_inputs, d_theta = layer.ReverseAndGrad(
            layer_theta, output_acts, d_outputs, f_seed, g_seed,
            *extra_inputs)

        d_layer_thetas.append(d_theta)
        # Passes reconstructed inputs to the previous layer.
        output_acts = input_acts
        d_outputs = d_inputs
      py_utils.ResetStepSeed(final_step_seed)
      d_theta = py_utils.NestedMap(
          global_step=tf.zeros_like(initial_step_seed))
      d_theta.sub_layers = list(reversed(d_layer_thetas))

      extra_grads = [tf.zeros_like(t) for t in extra_inputs]
      return [
          tf.zeros_like(initial_step_seed), d_theta, d_inputs, extra_grads
      ]

    def Fwd(xs):
      """Forward pass: runs all sub-layers and records their step seeds."""
      initial_step_seed, theta, acts, extra_inputs = xs

      py_utils.ResetStepSeed(initial_step_seed)
      layer_step_seeds = []

      for layer_theta, layer in zip(theta.sub_layers, self.sub_layers):
        acts, f_seed, g_seed = layer.FProp(layer_theta, acts, *extra_inputs)
        layer_step_seeds += [(f_seed, g_seed)]
      return [acts, layer_step_seeds]

    if self.params.custom_gradient:
      # Recompute activations in Bak instead of storing them, trading compute
      # for memory.
      acts, _ = py_utils.CallDefun(
          Fwd, Bak, [initial_step_seed, theta, inputs, extra_inputs])
      py_utils.ResetStepSeed(final_step_seed)
      return acts
    else:
      acts = inputs
      for layer_theta, layer in zip(theta.sub_layers, self.sub_layers):
        acts, _, _ = layer.FProp(layer_theta, acts, *extra_inputs)
      return acts
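

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original Lingvo classes): the coupling
# that ReverseAndGrad inverts. Given inputs (x1, x2) the reversible block
# computes y1 = x1 + f(x2) and y2 = x2 + g(y1); the inverse recovers
# x2 = y2 - g(y1) and x1 = y1 - f(x2) exactly, so activations need not be
# stored for the backward pass. `_F` and `_G` below are hypothetical stand-ins
# for the f/g blocks, not the layers defined above, and the module-level `tf`
# import used by the code above is assumed.
def _ReversibleCouplingSketch():
  """Returns the reconstruction errors of the coupling (all-zero tensors)."""

  def _F(x):  # Stand-in for the f block.
    return 2.0 * x + 1.0

  def _G(x):  # Stand-in for the g block.
    return 0.5 * x - 3.0

  x1 = tf.constant([1.0, 2.0])
  x2 = tf.constant([3.0, 4.0])
  # Forward coupling.
  y1 = x1 + _F(x2)
  y2 = x2 + _G(y1)
  # Inverse, mirroring the reconstruction in ReverseAndGrad.
  x2_rec = y2 - _G(y1)
  x1_rec = y1 - _F(x2_rec)
  return x1 - x1_rec, x2 - x2_rec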