Example 1
  def _CompileDecodeFn(self):
    """Wrap the DecodeFn with split_compile_and_shard."""
    with cluster_factory.SetImmediatelyInstantiateVariables(False):
      self._model = self._InstantiateTaskModel(self._task_params)
    self._task = self._model.GetTask()
    self._task.input.InstantiateVariables()
    self._task.input.CreateTpuEnqueueOps()
    self._task.input.CreateCpuPassthroughEnqueueOps()

    def _DecodeFn():
      """Decode call to be compiled for TPU."""
      with py_utils.OpportunisticVariableReuseScope(True):
        self._model.InstantiateVariables()
        input_batch = self._task.input.TpuDequeueBatch()
        decode_dict = self._task.Decode(input_batch)
      self.decode_nm = py_utils.NestedMap(decode_dict)
      return self.decode_nm.Flatten()

    self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
        _DecodeFn,
        num_shards=self.data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())

    self.cpu_pt = self._task.input.DequeueCpuPassthrough()
    self.decode_tensors = py_utils.NestedMap(self.decode_nm)
    self.decode_tensors = self.decode_tensors.Pack(batch_parallel_res)
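
A rough sketch of how a graph built by _CompileDecodeFn might be driven from the host. The session handle and the infeed op attribute below are illustrative assumptions, not names defined in the snippet:

    # Hypothetical driver; `sess` is an assumed tf.Session on the TPU worker.
    sess.run(self._compile_op)  # Compile once; fails fast on XLA errors.
    # Feed one batch via the enqueue ops created by CreateTpuEnqueueOps, then
    # fetch the decoded tensors together with the CPU passthrough batch.
    sess.run(self._task.input.tpu_infeed_op)  # hypothetical attribute name
    decoded, passthrough = sess.run([self.decode_tensors, self.cpu_pt])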
Example 2
  def BuildTpuSubgraph(self):
    tf.logging.info('EvalProgram BuildTpuSubGraph')
    with cluster_factory.SetEval(True):
      self._eval_metrics = metrics.TpuEvalMetrics()
      data_parallelism = self.data_parallelism
      with cluster_factory.SetImmediatelyInstantiateVariables(False):
        self._model = self._InstantiateTaskModel(self._task_params)
      self._task = self._model.GetTask()
      self._task.input.InstantiateVariables()
      self._task.input.CreateTpuEnqueueOps()
      self._init_input_ops = self._task.input.InitOps()

      # XLA thinks self.TpuEvalLoop() requires 1 argument due to `self`;
      # trick it with a zero-arg wrapper function.
      def TpuEvalLoopWrapper():
        return self.TpuEvalLoop()

      self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
          TpuEvalLoopWrapper,
          num_shards=data_parallelism,
          device_assignment=py_utils.GetTpuDeviceAssignment())
      self._task.input.CreateTpuEmbeddingEnqueueOps(mode_override='inference')

      # Get metric results from a single replica; they are all the same here.
      self.tpu_ops = [[t[0] for t in batch_parallel_res]]
      return self.tpu_ops
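
The first value returned by tpu.split_compile_and_shard is a compile op whose output is a serialized CompilationResultProto, so compilation errors can be surfaced before the eval loop runs. A minimal sketch, assuming a tf.Session `sess` (the proto module path is TensorFlow's TPU compilation result):

    from tensorflow.core.protobuf.tpu import compilation_result_pb2

    # Run the compile op and inspect the serialized result for errors.
    result = sess.run(self._compile_op)
    proto = compilation_result_pb2.CompilationResultProto()
    proto.ParseFromString(result)
    if proto.status_error_message:
      tf.logging.fatal('Compilation failed: %s', proto.status_error_message)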
Example 3
  def _CompileDecodeLoop(self):
    """Wrap the DecodeLoop with split_compile_and_shard."""
    device_assignment = py_utils.GetTpuDeviceAssignment()
    with cluster_factory.SetImmediatelyInstantiateVariables(False):
      self._model = self._InstantiateTaskModel(self._task_params)
    self._task = self._model.GetTask()
    self._task.input.InstantiateVariables()
    self._task.input.CreateTpuEnqueueOps()
    self._task.input.CreateCpuPassthroughEnqueueOps()

    def _DecodeStep():
      """Decode call to be compiled for TPU."""
      with py_utils.OpportunisticVariableReuseScope(True):
        self._model.InstantiateVariables()
        input_batch = self._task.input.TpuDequeueBatch()
        decode_dict = self._task.Decode(input_batch)
      self.decode_nm = py_utils.NestedMap(decode_dict)
      return [self._OutfeedEnqueue(decode_dict)]

    @tpu_function.on_device_training_loop
    def DecodeLoopFn():
      return tpu_training_loop.repeat(
          self._steps_per_loop, _DecodeStep, inputs=[])

    self._compile_op, self.decode_loop = tpu.split_compile_and_shard(
        DecodeLoopFn,
        num_shards=self.data_parallelism,
        device_assignment=device_assignment)

    # Get a list of outfeed ops.
    self.decode_tensors = self._OutfeedDequeue()
    # Pack the list of outfeed ops with structure in self.decode_nm.
    self.decode_tensors = tf.nest.pack_sequence_as(self.decode_nm,
                                                   self.decode_tensors)
    self.cpu_pt = self._task.input.DequeueCpuPassthrough()
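
_OutfeedDequeue is referenced above but not shown. A plausible sketch, assuming one dequeue per replica core mirrored from the device assignment, with dtypes and shapes taken from self.decode_nm (tpu_ops being TensorFlow's TPU ops module, as used in Example 6 below):

  def _OutfeedDequeue(self):
    """Sketch: collect outfeed dequeue ops from all TPU devices.

    Illustrative reconstruction, not the snippet's actual helper.
    """
    tensors = self.decode_nm.Flatten()
    outfeed_ops = [[] for _ in tensors]
    device_assignment = py_utils.GetTpuDeviceAssignment()
    for replica in range(device_assignment.num_replicas):
      for core in range(device_assignment.num_cores_per_replica):
        with tf.device(device_assignment.host_device(replica, core)):
          outfeeds_per_core = tpu_ops.outfeed_dequeue_tuple(
              dtypes=[x.dtype for x in tensors],
              shapes=[x.shape for x in tensors],
              device_ordinal=device_assignment.tpu_ordinal(replica, core))
          for i, out_feed in enumerate(outfeeds_per_core):
            outfeed_ops[i].append(out_feed)
    # Concatenate the per-core pieces along the batch dimension.
    return [tf.concat(per_outfeed, axis=0) for per_outfeed in outfeed_ops]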
Example 4
    def BuildTpuSubgraph(self):
        tf.logging.info('DecodeProgram BuildTpuSubGraph')
        py_utils.ResetStepSeed()

        with cluster_factory.SetEval(True):
            with cluster_factory.SetImmediatelyInstantiateVariables(False):
                self._model = self._task_params.Instantiate()
            self._task = self._model.GetTask()
            self._task.input.InstantiateVariables()
            self._task.input.CreateTpuEnqueueOps()

            def _DecodeFn():
                """Decode call to be compiled for TPU."""
                with py_utils.OpportunisticVariableReuseScope(True):
                    self._model.InstantiateVariables()
                    input_batch = self._task.input.TpuDequeueBatch()
                    metrics_dict = self._task.Decode(input_batch)
                self.metrics_nm = py_utils.NestedMap(metrics_dict)
                return self.metrics_nm.Flatten()

            self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
                _DecodeFn,
                num_shards=self.data_parallelism,
                device_assignment=py_utils.GetTpuDeviceAssignment())

            self.metrics = py_utils.NestedMap(self.metrics_nm)
            self.metrics = self.metrics.Pack(batch_parallel_res)
        return None
Example 5
    def BuildTpuSubgraph(self):
        tf.logging.info('EvalProgram BuildTpuSubGraph')
        with cluster_factory.SetEval(True):
            self._eval_metrics = metrics.TpuEvalMetrics()
            data_parallelism = self.data_parallelism
            with cluster_factory.SetImmediatelyInstantiateVariables(False):
                self._model = self._InstantiateTaskModel(self._task_params)
            self._task = self._model.GetTask()
            self._task.input.InstantiateVariables()
            self._task.input.CreateTpuEnqueueOps()

            def TpuEvalStep(*args):
                """Eval a shard of a batch on a single TPU core.

        Args:
          *args: metrics values from previous steps.

        Returns:
          Summed eval metrics.
        """
                with tf.name_scope('tpu_eval'):
                    with py_utils.OpportunisticVariableReuseScope(True):
                        self._model.InstantiateVariables()
                        self._model.ConstructFPropGraph()
                    per_step_eval_metrics = self._eval_metrics.SetMetrics(
                        self._task.eval_metrics, args)
                    summed_metrics = []
                    for x, y in zip(per_step_eval_metrics, args):
                        summed_metrics.append(x + y)
                    return summed_metrics

            @tpu_function.on_device_training_loop
            def TpuEval():
                loop_result = tpu_training_loop.repeat(
                    self._steps_per_loop,
                    TpuEvalStep,
                    inputs=self._eval_metrics.initial_values,
                    name='eval_loop')
                # Final metrics are the avg across self._steps_per_loop steps.
                return self._eval_metrics.FinalizeMetrics(loop_result)

            self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
                TpuEval,
                num_shards=data_parallelism,
                device_assignment=py_utils.GetTpuDeviceAssignment())

            self._task.input.CreateTpuEmbeddingEnqueueOps(
                mode_override='inference')

            # Get metric results from a single replica; they are all the same
            # here.
            self.tpu_ops = [[t[0] for t in batch_parallel_res]]

            return self.tpu_ops
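
The running-sum pattern in TpuEvalStep generalizes: tpu_training_loop.repeat threads the accumulator values through each step, and the average is taken once at the end. A toy sketch of the same idea, to be traced inside a TPU computation (all names here are illustrative):

    NUM_STEPS = 8  # stand-in for self._steps_per_loop

    def _SumStep(*totals):
      # Stand-ins for per-step metric values computed from a batch.
      step_values = [tf.constant(1.0), tf.constant(2.0)]
      return [t + v for t, v in zip(totals, step_values)]

    totals = tpu_training_loop.repeat(
        NUM_STEPS, _SumStep, inputs=[tf.constant(0.0), tf.constant(0.0)])
    averages = [t / NUM_STEPS for t in totals]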
Example 6
    def BuildTpuSubgraph(self):
        tf.logging.info('DecodeProgram BuildTpuSubGraph')
        py_utils.ResetStepSeed()
        device_assignment = py_utils.GetTpuDeviceAssignment()
        self.spmd = self._task_params.input.use_partitioned_infeed_queue
        with cluster_factory.SetEval(True):
            with cluster_factory.SetImmediatelyInstantiateVariables(False):
                self._model = self._task_params.Instantiate()
            self._task = self._model.GetTask()
            self._task.input.InstantiateVariables()
            self._task.input.CreateTpuEnqueueOps()

            def _DecodeStep():
                """Decode call to be compiled for TPU."""
                with py_utils.OpportunisticVariableReuseScope(True):
                    self._model.InstantiateVariables()
                    input_batch = self._task.input.TpuDequeueBatch()
                    metrics_dict = self._task.Decode(input_batch)
                self.metrics_nm = py_utils.NestedMap(metrics_dict)
                device = tpu.core(0) if self.spmd else ''
                with tf.device(device):
                    outfeed_enqueue = tpu_ops.outfeed_enqueue_tuple(
                        self.metrics_nm.Flatten())
                    return [outfeed_enqueue]

            @tpu_function.on_device_training_loop
            def DecodeLoopFn():
                return tpu_training_loop.repeat(self._steps_per_loop,
                                                _DecodeStep,
                                                inputs=[])

            self._compile_op, self.decode_loop = tpu.split_compile_and_shard(
                DecodeLoopFn,
                num_shards=self.data_parallelism,
                device_assignment=device_assignment)
            # Get a list of outfeed ops.
            self.metrics = self._OutfeedDequeue()
            # Pack the list of outfeed ops with structure in self.metrics_nm.
            self.metrics = tf.nest.pack_sequence_as(self.metrics_nm,
                                                    self.metrics)
            return
Example 7
    def BuildTpuSubgraph(self):
        if self._ml_perf_log:
            mlp_log.mlperf_print('global_batch_size',
                                 self._ml_perf.global_batch_size)
            mlp_log.mlperf_print('max_sequence_length',
                                 self._ml_perf.max_sequence_length)
            mlp_log.mlperf_print('opt_name', self._ml_perf.optimizer_name)
            mlp_log.mlperf_print('opt_base_learning_rate',
                                 self._ml_perf.base_learning_rate)
            mlp_log.mlperf_print('opt_learning_rate_warmup_steps',
                                 self._ml_perf.warmup_steps)

        self._eval_metrics = metrics.TpuEvalMetrics()
        data_parallelism = self.data_parallelism
        with cluster_factory.SetImmediatelyInstantiateVariables(False):
            self._train_model = self._train_task_params.Instantiate()
        self._train_task = self._train_model.GetTask()
        self._train_task.input.InstantiateVariables()
        self._train_task.input.CreateTpuEnqueueOps()
        self._model = self._train_model

        def TpuTrainStep():
            """Train a shard of a batch on a single TPU core.

      Do not calculate loss metrics.

      Returns:
       [train_op].
      """
            with py_utils.OpportunisticVariableReuseScope(True):
                self._train_model.InstantiateVariables()
                self._train_model.ConstructFPropBPropGraph()
            return [self._train_task.train_op]

        def TpuTrain():
            loop_result = tpu_training_loop.repeat(self._train_steps_per_loop,
                                                   TpuTrainStep,
                                                   inputs=[],
                                                   name='train_loop')
            return loop_result

        py_utils.ResetStepSeed()

        with cluster_factory.SetImmediatelyInstantiateVariables(False):
            self._decode_model = self._InstantiateTaskModel(
                self._decode_task_params)
        self._decode_task = self._decode_model.GetTask()
        self._decode_task.input.InstantiateVariables()
        self._decode_task.input.CreateTpuEnqueueOps()

        def _DecodeFn():
            """Decode call to be compiled for TPU."""
            with py_utils.OpportunisticVariableReuseScope(True):
                with cluster_factory.SetEval(True):
                    self._decode_model.InstantiateVariables()
                    input_batch = self._decode_task.input.TpuDequeueBatch()
                    metrics_dict = self._decode_task.Decode(input_batch)
            self.metrics_nm = py_utils.NestedMap(metrics_dict)
            return self.metrics_nm.Flatten()

        @tpu_function.on_device_training_loop
        def TrainAndDecode():
            with tf.control_dependencies([TpuTrain()]):
                return _DecodeFn()

        self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
            TrainAndDecode,
            num_shards=data_parallelism,
            device_assignment=py_utils.GetTpuDeviceAssignment())

        self.metrics = py_utils.NestedMap(self.metrics_nm)
        self.metrics = self.metrics.Pack(batch_parallel_res)
        return None
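
Because TrainAndDecode puts the decode call under a control dependency on TpuTrain(), one fetch of the packed metrics runs self._train_steps_per_loop train steps before a single decode pass. A hedged sketch of the host side (`sess` and the loop bound are assumptions):

    sess.run(self._compile_op)  # compile the fused train-and-decode program
    for _ in range(num_loops):  # hypothetical outer loop count
      # Each fetch runs the train loop first, then one decode pass.
      decode_metrics = sess.run(self.metrics)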
Example 8
    def BuildTpuSubgraph(self):
        tf.logging.info('TrainProgram BuildTpuSubGraph')
        self.spmd = (self.params.spmd
                     or self._task_params.input.use_partitioned_infeed_queue)

        self._eval_metrics = metrics.TpuEvalMetrics()
        data_parallelism = self.data_parallelism

        with cluster_factory.SetImmediatelyInstantiateVariables(False):
            self._model = self._InstantiateTaskModel(self._task_params)
        self._task = self._model.GetTask()
        self._task.input.InstantiateVariables()
        self._task.input.CreateTpuEnqueueOps()

        def TpuTrainStep(*args):
            """Train a shard of a batch on a single TPU core.

      Args:
        *args: metrics values from previous steps.

      Returns:
        New summed metrics values and a train_op.
      """
            with tf.name_scope('tpu_train'):
                with py_utils.OpportunisticVariableReuseScope(True):
                    self._model.InstantiateVariables()
                    self._model.ConstructFPropBPropGraph()
                per_step_eval_metrics = self._eval_metrics.SetMetrics(
                    self._task.eval_metrics, args)
                outfeed_op = self._OutfeedEnqueue(
                    self._task.per_example_tensors)
                summed_metrics = []
                assert len(per_step_eval_metrics) == len(args)
                with tf.control_dependencies([outfeed_op]):
                    for x, y in zip(per_step_eval_metrics, args):
                        summed_metrics.append(x + y)
                return summed_metrics + [self._task.train_op]

        @tpu_function.on_device_training_loop
        def TpuTrain():
            loop_result = tpu_training_loop.repeat(
                self._steps_per_loop,
                TpuTrainStep,
                inputs=self._eval_metrics.initial_values,
                name='train_loop')
            # Final metrics are the avg across self._steps_per_loop steps.
            return self._eval_metrics.FinalizeMetrics(loop_result)

        self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
            TpuTrain,
            num_shards=data_parallelism,
            device_assignment=py_utils.GetTpuDeviceAssignment())
        outfeed_dequeue_op = self._OutfeedDequeueLoop(
            self._task.per_example_tensors, self._steps_per_loop,
            self.num_splits_per_client)

        self._task.input.CreateTpuEmbeddingEnqueueOps()

        def _ConstructPostTrainingLoop(train_loop_op, outfeed_dequeue_op):
            """Returns the op for tpu training with tail cpu computation."""
            # Adds a tail computation that runs after the tpu_training loop
            # finishes. This lets us run computation that acts on the
            # variables between tpu_train_loop iterations and amortize the
            # cost of those operations. The alternative, running
            # tpu.outside_compilation with tf.cond, is expensive.
            with tf.control_dependencies(train_loop_op):
                self._model.ConstructPostTrainingLoop()
                with tf.control_dependencies(
                    [self._task.post_training_loop_op]):
                    return ([[tf.identity(o) for o in train_loop_op],
                             outfeed_dequeue_op])

        # Get metric results from a single replica; they are all the same here.
        all_tpu_ops = [t[0] for t in batch_parallel_res]
        self.tpu_ops = (_ConstructPostTrainingLoop(all_tpu_ops,
                                                   outfeed_dequeue_op))
        self._model_analysis, self._total_num_params = summary_utils.ModelAnalysis(
            self._model)
        try:
            with tf.io.gfile.GFile(
                    os.path.join(self._program_dir, 'model_analysis.txt'),
                    'w') as f:
                f.write(self._model_analysis)
        except tf.errors.NotFoundError as e:
            tf.logging.info('Failed to write model analysis %s', e)

        return self.tpu_ops
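
_OutfeedEnqueue, used inside TpuTrainStep above, is not shown. A minimal sketch of what it might look like, flattening the NestedMap into an outfeed tuple (tpu_ops as in Example 6; illustrative, not the snippet's actual helper):

    def _OutfeedEnqueue(self, per_example_tensors):
        """Sketch: enqueue per-example tensors onto the TPU outfeed."""
        if not per_example_tensors:
            return tf.no_op()
        per_example_tensors = py_utils.NestedMap(per_example_tensors)
        return tpu_ops.outfeed_enqueue_tuple(per_example_tensors.Flatten())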