Example #1
    def BuildTpuSubgraph(self):
        tf.logging.info('DecodeProgram BuildTpuSubGraph')
        py_utils.ResetStepSeed()

        def _DecodeFn():
            """Decode call to be compiled for TPU."""
            with py_utils.OpportunisticVariableReuseScope(True):
                with cluster_factory.SetEval(True):
                    self._model = self._task_params.Instantiate()
                    self._model_task = self._model.GetTask()
                    if py_utils.use_tpu():
                        input_batch = self._model_task.input_generator.CreateTpuFeeds(
                        )
                    else:
                        input_batch = self._model_task.input_generator.SplitInputBatch(
                            self.cluster.num_splits_per_client)
                    metrics_dict = self._model_task.Decode(input_batch)
                    self.metrics_nm = py_utils.NestedMap(metrics_dict)
                    return self.metrics_nm.Flatten()

        self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
            _DecodeFn,
            num_shards=self.data_parallelism,
            device_assignment=py_utils.GetTpuDeviceAssignment())

        self.metrics = py_utils.NestedMap(self.metrics_nm)
        self.metrics = self.metrics.Pack(batch_parallel_res)
        return None
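In these examples, tpu.split_compile_and_shard returns a compilation op together with the per-shard outputs of the sharded function; the NestedMap is flattened inside the TPU function and re-packed on the host. A minimal host-side driver sketch, assuming a tf.Session named sess and a program object whose subgraph was built by the method above (both names are illustrative, not part of the example):

    # Hypothetical driver; sess and program are assumed names.
    sess.run(tf.global_variables_initializer())
    sess.run(program._compile_op)                  # fail fast on XLA compile errors
    metric_values = sess.run(program.metrics.Flatten())  # flattened decode outputs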
Example #2
  def BuildTpuSubgraph(self):
    tf.logging.info('EvalProgram BuildTpuSubGraph')
    with cluster_factory.SetEval(True):
      self._eval_metrics = metrics.TpuEvalMetrics()
      data_parallelism = self.data_parallelism
      with cluster_factory.SetImmediatelyInstantiateVariables(False):
        self._model = self._InstantiateTaskModel(self._task_params)
      self._task = self._model.GetTask()
      self._task.input.InstantiateVariables()
      self._task.input.CreateTpuEnqueueOps()
      self._init_input_ops = self._task.input.InitOps()

      # XLA thinks self.TpuEvalLoop() requires 1 argument because of self;
      # trick it with a wrapper function.
      def TpuEvalLoopWrapper():
        return self.TpuEvalLoop()

      self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
          TpuEvalLoopWrapper,
          num_shards=data_parallelism,
          device_assignment=py_utils.GetTpuDeviceAssignment())
      self._task.input.CreateTpuEmbeddingEnqueueOps(mode_override='inference')

      # Get the metric result from a single replica; they are all the same here.
      self.tpu_ops = [[t[0] for t in batch_parallel_res]]
      return self.tpu_ops
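The returned tpu_ops read the metrics from replica 0 only, since every replica computes the same values. When the loop actually runs, the host has to keep the infeed fed concurrently with TpuEvalLoop; a minimal sketch, where sess, program, steps_per_loop, and infeed_op (a handle to the enqueue ops created by CreateTpuEnqueueOps) are all assumed names:

    # Hypothetical host loop; the infeed must be fed while the device loop runs,
    # otherwise the TPU blocks on infeed dequeue.
    import threading

    def _FeedInfeed():
      for _ in range(steps_per_loop):
        sess.run(infeed_op)      # assumed handle to the TPU infeed enqueue ops

    enqueue_thread = threading.Thread(target=_FeedInfeed)
    enqueue_thread.start()
    eval_values = sess.run(program.tpu_ops)
    enqueue_thread.join()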
Example #3
  def _CompileDecodeFn(self):
    """Wrap the DecodeFn with split_compile_and_shard."""
    with cluster_factory.SetImmediatelyInstantiateVariables(False):
      self._model = self._InstantiateTaskModel(self._task_params)
    self._task = self._model.GetTask()
    self._task.input.InstantiateVariables()
    self._task.input.CreateTpuEnqueueOps()
    self._task.input.CreateCpuPassthroughEnqueueOps()

    def _DecodeFn():
      """Decode call to be compiled for TPU."""
      with py_utils.OpportunisticVariableReuseScope(True):
        self._model.InstantiateVariables()
        input_batch = self._task.input.TpuDequeueBatch()
        decode_dict = self._task.Decode(input_batch)
      self.decode_nm = py_utils.NestedMap(decode_dict)
      return self.decode_nm.Flatten()

    self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
        _DecodeFn,
        num_shards=self.data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())

    self.cpu_pt = self._task.input.DequeueCpuPassthrough()
    self.decode_tensors = py_utils.NestedMap(self.decode_nm)
    self.decode_tensors = self.decode_tensors.Pack(batch_parallel_res)
Example #4
  def _CompileDecodeLoop(self):
    """Wrap the DecodeLoop with split_compile_and_shard."""
    device_assignment = py_utils.GetTpuDeviceAssignment()
    with cluster_factory.SetImmediatelyInstantiateVariables(False):
      self._model = self._InstantiateTaskModel(self._task_params)
    self._task = self._model.GetTask()
    self._task.input.InstantiateVariables()
    self._task.input.CreateTpuEnqueueOps()
    self._task.input.CreateCpuPassthroughEnqueueOps()

    def _DecodeStep():
      """Decode call to be compiled for TPU."""
      with py_utils.OpportunisticVariableReuseScope(True):
        self._model.InstantiateVariables()
        input_batch = self._task.input.TpuDequeueBatch()
        decode_dict = self._task.Decode(input_batch)
      self.decode_nm = py_utils.NestedMap(decode_dict)
      return [self._OutfeedEnqueue(decode_dict)]

    @tpu_function.on_device_training_loop
    def DecodeLoopFn():
      return tpu_training_loop.repeat(
          self._steps_per_loop, _DecodeStep, inputs=[])

    self._compile_op, self.decode_loop = tpu.split_compile_and_shard(
        DecodeLoopFn,
        num_shards=self.data_parallelism,
        device_assignment=device_assignment)

    # Get a list of outfeed ops.
    self.decode_tensors = self._OutfeedDequeue()
    # Pack the list of outfeed ops with structure in self.decode_nm.
    self.decode_tensors = tf.nest.pack_sequence_as(self.decode_nm,
                                                   self.decode_tensors)
    self.cpu_pt = self._task.input.DequeueCpuPassthrough()
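The _OutfeedDequeue helper is not shown in this example. A rough sketch of what such a helper could look like, assuming one outfeed dequeue per replica keyed on the dtypes and shapes recorded in self.decode_nm (the actual implementation may differ):

  def _OutfeedDequeueSketch(self):
    """Hypothetical outfeed dequeue helper, not the real _OutfeedDequeue."""
    flat = self.decode_nm.Flatten()
    device_assignment = py_utils.GetTpuDeviceAssignment()
    per_replica = []
    for replica in range(self.data_parallelism):
      # Dequeue on the host that drives this replica's core 0.
      with tf.device(device_assignment.host_device(replica=replica)):
        per_replica.append(
            tpu_ops.outfeed_dequeue_tuple(
                dtypes=[t.dtype for t in flat],
                shapes=[t.shape for t in flat],
                device_ordinal=device_assignment.tpu_ordinal(replica=replica)))
    # Concatenate each tensor across replicas along the batch dimension.
    return [tf.concat(tensors, axis=0) for tensors in zip(*per_replica)]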
Example #5
    def BuildTpuSubgraph(self):
        tf.logging.info('DecodeProgram BuildTpuSubGraph')
        py_utils.ResetStepSeed()

        # Instantiate input generator first.
        self._input = self._task_params.input.Instantiate()
        self._input.CreateTpuEnqueueOps()
        self.SkipCreateChild(self._task_params)

        def _DecodeFn():
            """Decode call to be compiled for TPU."""
            with py_utils.OpportunisticVariableReuseScope(True):
                with cluster_factory.SetEval(True):
                    self._model = self._task_params.Instantiate()
                    self._task = self._model.GetTask()
                    self._task.AddChild('input', self._input)
                    input_batch = self._task.input.TpuDequeueBatch()
                    metrics_dict = self._task.Decode(input_batch)
                    self.metrics_nm = py_utils.NestedMap(metrics_dict)
                    return self.metrics_nm.Flatten()

        self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
            _DecodeFn,
            num_shards=self.data_parallelism,
            device_assignment=py_utils.GetTpuDeviceAssignment())

        self.metrics = py_utils.NestedMap(self.metrics_nm)
        self.metrics = self.metrics.Pack(batch_parallel_res)
        return None
Example #6
    def BuildTpuSubgraph(self):
        tf.logging.info('TrainProgram BuildTpuSubGraph')

        with py_utils.OpportunisticVariableReuseScope(True):
            self._eval_metrics = metrics.TpuEvalMetrics()
            data_parallelism = self.data_parallelism

            # Instantiate input generator first.
            self._input = self._task_params.input.Instantiate()
            self._input.CreateTpuEnqueueOps()
            self.SkipCreateChild(self._task_params)

            def TpuTrainStep(*args):
                """Train a shard of a batch on a single TPU core.

                Args:
                  *args: metrics values from previous steps.

                Returns:
                  New summed metrics values and a train_op.
                """
                self._model = self._task_params.Instantiate()
                self._task = self._model.GetTask()
                self._task.AddChild('input', self._input)
                self._model.ConstructFPropBPropGraph()
                per_step_eval_metrics = self._eval_metrics.SetMetrics(
                    self._task.eval_metrics, args)
                outfeed_op = self._OutfeedEnqueue(
                    self._task.per_example_tensors)
                summed_metrics = []
                assert len(per_step_eval_metrics) == len(args)
                with tf.control_dependencies([outfeed_op]):
                    for x, y in zip(per_step_eval_metrics, args):
                        summed_metrics.append(x + y)
                return summed_metrics + [self._task.train_op]

            @tpu_function.on_device_training_loop
            def TpuTrain():
                loop_result = tpu_training_loop.repeat(
                    self._steps_per_loop,
                    TpuTrainStep,
                    inputs=self._eval_metrics.initial_values,
                    name='train_loop')
                # Final metrics are the avg across self._steps_per_loop steps.
                return self._eval_metrics.FinalizeMetrics(loop_result)

            self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
                TpuTrain,
                num_shards=data_parallelism,
                device_assignment=py_utils.GetTpuDeviceAssignment())
            outfeed_dequeue_op = self._OutfeedDequeueLoop(
                self._task.per_example_tensors, self._steps_per_loop,
                self.num_splits_per_client)
            # Get the metric result from a single replica; they are all the same here.
            self.tpu_ops = [[t[0] for t in batch_parallel_res],
                            outfeed_dequeue_op]

        return self.tpu_ops
Example #7
    def BuildTpuSubgraph(self):
        tf.logging.info('EvalProgram BuildTpuSubGraph')
        with cluster_factory.SetEval(True):
            self._eval_metrics = metrics.TpuEvalMetrics()
            data_parallelism = self.data_parallelism
            with cluster_factory.SetImmediatelyInstantiateVariables(False):
                self._model = self._InstantiateTaskModel(self._task_params)
            self._task = self._model.GetTask()
            self._task.input.InstantiateVariables()
            self._task.input.CreateTpuEnqueueOps()

            def TpuEvalStep(*args):
                """Eval a shard of a batch on a single TPU core.

                Args:
                  *args: metrics values from previous steps.

                Returns:
                  Summed eval metrics.
                """
                with tf.name_scope('tpu_eval'):
                    with py_utils.OpportunisticVariableReuseScope(True):
                        self._model.InstantiateVariables()
                        self._model.ConstructFPropGraph()
                    per_step_eval_metrics = self._eval_metrics.SetMetrics(
                        self._task.eval_metrics, args)
                    summed_metrics = []
                    for x, y in zip(per_step_eval_metrics, args):
                        summed_metrics.append(x + y)
                    return summed_metrics

            @tpu_function.on_device_training_loop
            def TpuEval():
                loop_result = tpu_training_loop.repeat(
                    self._steps_per_loop,
                    TpuEvalStep,
                    inputs=self._eval_metrics.initial_values,
                    name='eval_loop')
                # Final metrics are the avg across self._steps_per_loop steps.
                return self._eval_metrics.FinalizeMetrics(loop_result)

            self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
                TpuEval,
                num_shards=data_parallelism,
                device_assignment=py_utils.GetTpuDeviceAssignment())

            self._task.input.CreateTpuEmbeddingEnqueueOps(
                mode_override='inference')

            # Get the metric result from a single replica; they are all the same here.
            self.tpu_ops = [[t[0] for t in batch_parallel_res]]

            return self.tpu_ops
Example #8
    def BuildTpuSubgraph(self):
        tf.logging.info('DecodeProgram BuildTpuSubGraph')
        py_utils.ResetStepSeed()
        device_assignment = py_utils.GetTpuDeviceAssignment()
        self.spmd = self._task_params.input.use_partitioned_infeed_queue
        with cluster_factory.SetEval(True):
            with cluster_factory.SetImmediatelyInstantiateVariables(False):
                self._model = self._task_params.Instantiate()
            self._task = self._model.GetTask()
            self._task.input.InstantiateVariables()
            self._task.input.CreateTpuEnqueueOps()

            def _DecodeStep():
                """Decode call to be compiled for TPU."""
                with py_utils.OpportunisticVariableReuseScope(True):
                    self._model.InstantiateVariables()
                    input_batch = self._task.input.TpuDequeueBatch()
                    metrics_dict = self._task.Decode(input_batch)
                self.metrics_nm = py_utils.NestedMap(metrics_dict)
                device = tpu.core(0) if self.spmd else ''
                with tf.device(device):
                    outfeed_enqueue = tpu_ops.outfeed_enqueue_tuple(
                        self.metrics_nm.Flatten())
                    return [outfeed_enqueue]

        @tpu_function.on_device_training_loop
        def DecodeLoopFn():
            return tpu_training_loop.repeat(self._steps_per_loop,
                                            _DecodeStep,
                                            inputs=[])

        self._compile_op, self.decode_loop = tpu.split_compile_and_shard(
            DecodeLoopFn,
            num_shards=self.data_parallelism,
            device_assignment=device_assignment)
        # Get a list of outfeed ops.
        self.metrics = self._OutfeedDequeue()
        # Pack the list of outfeed ops with structure in self.metrics_nm.
        self.metrics = tf.nest.pack_sequence_as(self.metrics_nm, self.metrics)
        return
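Since _DecodeStep enqueues one flattened tuple per step, the host has to dequeue the outfeed once per loop iteration while self.decode_loop runs. A minimal consumption sketch, with sess as an assumed tf.Session and the decode loop assumed to be running already (e.g. on a separate thread):

    # Hypothetical consumption loop; one outfeed tuple is produced per step.
    for _ in range(self._steps_per_loop):
        step_metrics = sess.run(self.metrics)  # dict with self.metrics_nm structure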
Example #9
    def BuildTpuSubgraph(self):
        if self._ml_perf_log:
            mlp_log.mlperf_print('global_batch_size',
                                 self._ml_perf.global_batch_size)
            mlp_log.mlperf_print('max_sequence_length',
                                 self._ml_perf.max_sequence_length)
            mlp_log.mlperf_print('opt_name', self._ml_perf.optimizer_name)
            mlp_log.mlperf_print('opt_base_learning_rate',
                                 self._ml_perf.base_learning_rate)
            mlp_log.mlperf_print('opt_learning_rate_warmup_steps',
                                 self._ml_perf.warmup_steps)

        with py_utils.OpportunisticVariableReuseScope(True):
            self._eval_metrics = metrics.TpuEvalMetrics()
            data_parallelism = self.data_parallelism

            def TpuTrainStep():
                """Train a shard of a batch on a single TPU core.

                Do not calculate loss metrics.

                Returns:
                  [train_op].
                """
                self._train_model = self._train_task_params.Instantiate()
                self._model = self._train_model
                self._train_model.ConstructFPropBPropGraph()
                return [self._train_model.GetTask().train_op]

            def TpuTrain():
                loop_result = tpu_training_loop.repeat(
                    self._train_steps_per_loop,
                    TpuTrainStep,
                    inputs=[],
                    name='train_loop')
                return loop_result

        py_utils.ResetStepSeed()

        def _DecodeFn():
            """Decode call to be compiled for TPU."""
            with py_utils.OpportunisticVariableReuseScope(True):
                with cluster_factory.SetEval(True):
                    self._decode_model = self._decode_task_params.Instantiate()
                    self._decode_model_task = self._decode_model.GetTask()
                    if py_utils.use_tpu():
                        input_batch = self._decode_model_task.input_generator.CreateTpuFeeds(
                        )
                    else:
                        input_batch = self._decode_model_task.input_generator.SplitInputBatch(
                            self.cluster.num_splits_per_client)
                    metrics_dict = self._decode_model_task.Decode(input_batch)
                    self.metrics_nm = py_utils.NestedMap(metrics_dict)
                    return self.metrics_nm.Flatten()

        @tpu_function.on_device_training_loop
        def TrainAndDecode():
            with tf.control_dependencies([TpuTrain()]):
                return _DecodeFn()

        self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
            TrainAndDecode,
            num_shards=data_parallelism,
            device_assignment=py_utils.GetTpuDeviceAssignment())

        self.metrics = py_utils.NestedMap(self.metrics_nm)
        self.metrics = self.metrics.Pack(batch_parallel_res)
        return None
Example #10
    def BuildTpuSubgraph(self):
        tf.logging.info('TrainProgram BuildTpuSubGraph')
        self.spmd = (self.params.spmd
                     or self._task_params.input.use_partitioned_infeed_queue)

        self._eval_metrics = metrics.TpuEvalMetrics()
        data_parallelism = self.data_parallelism

        with cluster_factory.SetImmediatelyInstantiateVariables(False):
            self._model = self._InstantiateTaskModel(self._task_params)
        self._task = self._model.GetTask()
        self._task.input.InstantiateVariables()
        self._task.input.CreateTpuEnqueueOps()

        def TpuTrainStep(*args):
            """Train a shard of a batch on a single TPU core.

            Args:
              *args: metrics values from previous steps.

            Returns:
              New summed metrics values and a train_op.
            """
            with tf.name_scope('tpu_train'):
                with py_utils.OpportunisticVariableReuseScope(True):
                    self._model.InstantiateVariables()
                    self._model.ConstructFPropBPropGraph()
                per_step_eval_metrics = self._eval_metrics.SetMetrics(
                    self._task.eval_metrics, args)
                outfeed_op = self._OutfeedEnqueue(
                    self._task.per_example_tensors)
                summed_metrics = []
                assert len(per_step_eval_metrics) == len(args)
                with tf.control_dependencies([outfeed_op]):
                    for x, y in zip(per_step_eval_metrics, args):
                        summed_metrics.append(x + y)
                return summed_metrics + [self._task.train_op]

        @tpu_function.on_device_training_loop
        def TpuTrain():
            loop_result = tpu_training_loop.repeat(
                self._steps_per_loop,
                TpuTrainStep,
                inputs=self._eval_metrics.initial_values,
                name='train_loop')
            # Final metrics are the avg across self._steps_per_loop steps.
            return self._eval_metrics.FinalizeMetrics(loop_result)

        self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
            TpuTrain,
            num_shards=data_parallelism,
            device_assignment=py_utils.GetTpuDeviceAssignment())
        outfeed_dequeue_op = self._OutfeedDequeueLoop(
            self._task.per_example_tensors, self._steps_per_loop,
            self.num_splits_per_client)

        self._task.input.CreateTpuEmbeddingEnqueueOps()

        # Get the metric result from a single replica; they are all the same here.

        def _ConstructPostTrainingLoop(train_loop_op, outfeed_dequeue_op):
            """Returns the op for tpu training with tail cpu computation."""
            # Adds a tail computation that runs after the tpu_training loop
            # step finishes. This allows us to run computation that acts on
            # the variables between tpu_train_loop iterations, amortizing the
            # cost of those operations. The alternative of running
            # tpu.outside_compilation and using tf.cond is expensive.
            with tf.control_dependencies(train_loop_op):
                self._model.ConstructPostTrainingLoop()
                with tf.control_dependencies(
                    [self._task.post_training_loop_op]):
                    return ([[tf.identity(o) for o in train_loop_op],
                             outfeed_dequeue_op])

        # Get the metric result from a single replica; they are all the same here.
        all_tpu_ops = [t[0] for t in batch_parallel_res]
        self.tpu_ops = (_ConstructPostTrainingLoop(all_tpu_ops,
                                                   outfeed_dequeue_op))
        self._model_analysis, self._total_num_params = summary_utils.ModelAnalysis(
            self._model)
        try:
            with tf.io.gfile.GFile(
                    os.path.join(self._program_dir, 'model_analysis.txt'),
                    'w') as f:
                f.write(self._model_analysis)
        except tf.errors.NotFoundError as e:
            tf.logging.info('Failed to write model analysis %s', e)

        return self.tpu_ops
Example #11
    def init_graph(self, model_params):
        """Builds moe decode graph.

    Args:
      model_params: the hyperparams of the specified model.
    """
        assert self.graph
        self.model_params = model_params
        batch_size = model_params.task.batch_size
        if (hasattr(model_params.task.builder, 'device_mesh_shape')
                and model_params.task.builder.device_mesh_shape):
            num_partitions = np.prod(
                model_params.task.builder.device_mesh_shape)
        else:
            num_partitions = model_params.task.builder.num_devices

        device_order_mode = (model_params.task.train.tpu_device_order_mode
                             or tpu_device_assignment.DeviceOrderMode.AUTO)
        self._init_tpu(num_partitions, device_order_mode)
        assert self.cluster_params  # configured by init_tpu
        self.cluster = self.cluster_params.Instantiate()

        with self.graph.as_default(), self.cluster, tf.device(
                self.cluster.GetPlacer()):
            _ = py_utils.GetOrCreateGlobalStepVar()
            self.heartbeat = tf.constant(np.pi)

            device_assignment = py_utils.GetTpuDeviceAssignment()

            tf.logging.info('Instantiating model')
            model = model_params.Instantiate()
            xformer = model.GetTask()
            self.task = xformer

            self.init_vars_op = tf.global_variables_initializer()
            self.saver = tf.train.Saver(sharded=True,
                                        reshape=self._saver_reshape)

            infeed = self._config_infeed(num_partitions=num_partitions,
                                         device_assignment=device_assignment,
                                         batch_size=batch_size)

            self.outfeed = []

            def decode_fn(*infeed_batch):  # pylint: disable=missing-docstring
                # Length 6 is passed when there is no tgt_mask (e.g. decoding) and
                # length 7 is passed when there is a tgt_mask (e.g. fprop).

                self.outfeed = self._config_outfeed(xformer, infeed_batch)

                with tf.device(tf.tpu.core(0)):
                    outfeed_op = tpu_ops.outfeed_enqueue_tuple(
                        tf.nest.flatten(self.outfeed))

                return [outfeed_op]

            @tpu_function.on_device_training_loop
            def decode_loop_fn():
                if not self.num_batches:
                    infinite_repeat(decode_fn, infeed)
                else:
                    training_loop.repeat(self.num_batches,
                                         decode_fn,
                                         infeed_queue=infeed)

            self.compile_op, self.decode_loop = tpu_lib.split_compile_and_shard(
                decode_loop_fn,
                num_shards=1,
                device_assignment=device_assignment)

            assert self.outfeed
            with tf.device(device_assignment.tpu_device(0, 0)):
                self.outfeed_op = tpu_ops.outfeed_dequeue_tuple(
                    dtypes=[x.dtype for x in tf.nest.flatten(self.outfeed)],
                    shapes=[x.shape for x in tf.nest.flatten(self.outfeed)])
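A decode driver for this graph would typically initialize or restore variables, run self.compile_op, launch the device loop, and then drain self.outfeed_op once per batch; a minimal sketch, assuming self.num_batches is set, sess targets the TPU worker, and the infeed is fed elsewhere (e.g. by the ops built in _config_infeed):

    # Hypothetical decode driver; sess is an assumed tf.Session.
    import threading
    sess.run(self.init_vars_op)
    sess.run(self.compile_op)
    loop_thread = threading.Thread(target=lambda: sess.run(self.decode_loop))
    loop_thread.start()
    for _ in range(self.num_batches):
        outputs = sess.run(self.outfeed_op)  # one flattened outfeed tuple per batch
    loop_thread.join()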
Example #12
    def build_model(self, model_fn, eval_model_fn, params, hparams, config):
        """Build the TPU model for training and eval."""
        tf.logging.info(
            "LowLevelRunner: build_model method for training and eval.")

        def tpu_train_step(loss):
            """Generate the TPU graph."""
            del loss
            values = self.infeed_queue[0].generate_dequeue_op(tpu_device=0)
            unflattened_inputs = data_nest.pack_sequence_as(
                self.feature_structure, values)
            features = unflattened_inputs["features"]
            labels = unflattened_inputs["labels"]
            estimator_spec = model_fn(features,
                                      labels,
                                      tf.estimator.ModeKeys.TRAIN,
                                      params=params,
                                      config=config)
            loss, train_op = estimator_spec.loss, estimator_spec.train_op
            self.scaffold_fn = estimator_spec.scaffold_fn
            with tf.control_dependencies([train_op]):
                return tf.identity(loss)

        @tpu_function.on_device_training_loop
        def train_loop():
            return training_loop.repeat(self.train_steps_tensor,
                                        tpu_train_step, [_INITIAL_LOSS])

        def tpu_eval_step():
            """Generate the TPU graph."""
            values = self.eval_infeed_queue[0].generate_dequeue_op(
                tpu_device=0)
            unflattened_inputs = data_nest.pack_sequence_as(
                self.eval_feature_structure, values)
            features = unflattened_inputs["features"]
            estimator_spec = eval_model_fn(features,
                                           None,
                                           tf.estimator.ModeKeys.PREDICT,
                                           params=params,
                                           config=config)
            for k, v in six.iteritems(estimator_spec.predictions):
                self.outfeed_names.append(k)
                self.outfeed_tensors.append(v)

            with tf.device(
                    low_level_utils.device_for_tpu_core(self._get_host(0))):
                outfeed_enqueue_ops = tpu_ops.outfeed_enqueue_tuple(
                    self.outfeed_tensors)
            with tf.control_dependencies([outfeed_enqueue_ops]):
                return tf.no_op()

        @tpu_function.on_device_training_loop
        def eval_loop():
            return training_loop.repeat(self.eval_steps_tensor, tpu_eval_step,
                                        [])

        def train_eval_step():
            with tf.control_dependencies(train_loop()):
                return eval_loop()

        @tpu_function.on_device_training_loop
        def train_eval_loop():
            return training_loop.repeat(self.num_epochs_tensor,
                                        train_eval_step, [])

        with self.graph.as_default():
            (
                self.compile_op,
                self.train_eval_op,
            ) = tpu.split_compile_and_shard(
                train_eval_loop,
                inputs=[],
                num_shards=FLAGS.tpu_num_shards,
                outputs_from_all_shards=False,
            )
            if self.scaffold_fn:
                self.scaffold_fn()
            self.sess.run(tf.global_variables_initializer())
            self.sess.run(tf.local_variables_initializer())

            graph_io.write_graph(self.graph.as_graph_def(add_shapes=True),
                                 FLAGS.output_dir, "graph.pbtxt")

        def create_dequeue_ops(host_id):
            """Create outfeed dequeue ops."""
            dequeue_ops = []
            tensor_dtypes = []
            tensor_shapes = []
            for v in self.outfeed_tensors:
                dequeue_ops.append([])
                tensor_dtypes.append(v.dtype)
                tensor_shapes.append(v.shape)
            for i in range(FLAGS.tpu_num_shards_per_host):
                with tf.device(
                        low_level_utils.device_for_host(
                            self._get_host(host_id))):
                    outfeed = tpu_ops.outfeed_dequeue_tuple(
                        dtypes=tensor_dtypes,
                        shapes=tensor_shapes,
                        device_ordinal=i)
                    for j, item in enumerate(outfeed):
                        dequeue_ops[j].append(item)
            for j in range(len(outfeed)):
                dequeue_ops[j] = tf.concat(dequeue_ops[j], axis=0)
            return dequeue_ops

        with self.output_graph.as_default():
            # Get dequeue ops from each hosts.
            for i in range(self.num_hosts):
                tf.logging.info(
                    "LowLevelRunner: get dequeue ops for host: %d.", i)
                local_batch_size = hparams.batch_size // self.num_hosts
                local_dequeue_ops = []
                for n in range(local_batch_size):
                    local_dequeue_ops.append({})
                for j, dequeue_tensor in enumerate(create_dequeue_ops(i)):
                    if self.outfeed_names[j] in ("inputs", "targets",
                                                 "outputs"):
                        dequeue_tensors = tf.split(dequeue_tensor,
                                                   local_batch_size,
                                                   axis=0)
                        for n in range(local_batch_size):
                            local_dequeue_ops[n][
                                self.outfeed_names[j]] = dequeue_tensors[n]
                for j, dequeue_dict in enumerate(local_dequeue_ops):
                    self.dequeue_ops.append(dequeue_dict)
Example #13
    def build_eval_model(self, model_fn, params):
        """Build the Eval TPU model and infeed enqueue ops."""
        tf.logging.info("TrainAndEvalLowLevelRunner: build_model method")

        # TODO(wangtao): refactor to extract common logic with tpu_train_step.
        def tpu_eval_step():
            """Generate the TPU graph."""
            values = self.eval_infeed_queue[0].generate_dequeue_op(
                tpu_device=0)
            unflattened_inputs = data_nest.pack_sequence_as(
                self.eval_feature_structure, values)
            features = unflattened_inputs["features"]
            estimator_spec = model_fn(features, None,
                                      tf.estimator.ModeKeys.PREDICT, params)
            for k, v in six.iteritems(estimator_spec.predictions):
                self.outfeed_names.append(k)
                self.outfeed_tensors.append(v)

            with tf.device(utils.device_for_tpu_core(self._get_host(0))):
                outfeed_enqueue_ops = tpu_ops.outfeed_enqueue_tuple(
                    self.outfeed_tensors)
            with tf.control_dependencies([outfeed_enqueue_ops]):
                return tf.no_op()

        def eval_loop():
            return training_loop.repeat(self.eval_steps, tpu_eval_step, [])

        def train_eval_step(iteration):
            with tf.control_dependencies(self.train_loop()):
                should_eval = tf.reduce_any(
                    tf.equal(tf.constant(self.eval_iterations), iteration))
                should_eval = tf.logical_or(
                    should_eval,
                    tf.constant(self.params["eval_every_checkpoint"]))
                ops = tf.cond(should_eval, lambda: eval_loop(),
                              lambda: tf.no_op())  # pylint: disable=unnecessary-lambda
                with tf.control_dependencies([ops]):
                    return iteration + 1

        @on_device_train_and_eval_loops
        def train_eval_loop():
            return training_loop.repeat(self.max_train_iterations,
                                        train_eval_step, [0])

        self.eval_epochs = [
            steps * ssd_constants.DEFAULT_BATCH_SIZE /
            FLAGS.train_batch_size // params["steps_per_epoch"]
            for steps in self.eval_at_steps
        ]

        self.log_epochs = dict(
            zip(self.eval_epochs, [False for _ in self.eval_epochs]))

        self.epoch_count = dict(
            zip(self.eval_epochs,
                [self.eval_epochs[0]] + np.diff(self.eval_epochs).tolist()))

        # TODO(wangtao): refactor to extract common logic
        # with train create_dequeue_ops.
        def create_dequeue_ops(host_id):
            """Create outfeed dequeue ops."""
            dequeue_ops = []
            tensor_dtypes = []
            tensor_shapes = []
            for v in self.outfeed_tensors:
                tensor_dtypes.append(v.dtype)
                tensor_shapes.append(v.shape)
            with tf.device(utils.device_for_host(self._get_host(host_id))):
                for i in range(self.replicas_per_worker):
                    if self.use_spatial_partition:
                        replica_id = self.device_assignment.lookup_replicas(
                            host_id, 0)[i]
                        ordinal = self.device_assignment.tpu_ordinal(
                            replica=replica_id, logical_core=0)
                    else:
                        ordinal = i
                    outfeed = tpu_ops.outfeed_dequeue_tuple(
                        dtypes=tensor_dtypes,
                        shapes=tensor_shapes,
                        device_ordinal=ordinal)
                    if len(outfeed) == 2:
                        # 2 outfeed tensors
                        #   is_pad: [batch]
                        #   detections: [batch, 200, 7]
                        if outfeed[0].shape.ndims == 3:
                            detections, is_pad = outfeed
                        else:
                            is_pad, detections = outfeed
                        num_non_pad = tf.shape(is_pad)[0] - tf.reduce_sum(
                            tf.cast(is_pad, tf.int32))
                        dequeue_ops.append(
                            tf.slice(detections, [0, 0, 0],
                                     [num_non_pad, -1, -1]))
                    else:
                        # no padding, only detections are in the outfeed
                        dequeue_ops.append(outfeed)
                dequeue_ops = tf.concat(dequeue_ops, axis=0)
            return dequeue_ops

        with self.graph.as_default():
            (
                self.train_eval_compile_op,
                self.train_eval_op,
            ) = tpu.split_compile_and_shard(
                train_eval_loop,
                inputs=[],
                num_shards=self.num_shards,
                outputs_from_all_shards=False,
                device_assignment=self.device_assignment,
            )

            # Get dequeue ops from each hosts.
            for i in range(self.num_hosts):
                self.dequeue_ops.append(create_dequeue_ops(i))
Example #14
    def BuildTpuSubgraph(self):
        tf.logging.info('TrainProgram BuildTpuSubGraph')

        with py_utils.OpportunisticVariableReuseScope(True):
            self._eval_metrics = metrics.TpuEvalMetrics()
            data_parallelism = self.data_parallelism

            # Instantiate input generator first.
            self._input = self._task_params.input.Instantiate()
            self._input.CreateTpuEnqueueOps()
            self.SkipCreateChild(self._task_params)

            def TpuTrainStep(*args):
                """Train a shard of a batch on a single TPU core.

                Args:
                  *args: metrics values from previous steps.

                Returns:
                  New summed metrics values and a train_op.
                """
                self._model = self._task_params.Instantiate()
                self._task = self._model.GetTask()
                self._task.AddChild('input', self._input)
                self._model.ConstructFPropBPropGraph()
                per_step_eval_metrics = self._eval_metrics.SetMetrics(
                    self._task.eval_metrics, args)
                outfeed_op = self._OutfeedEnqueue(
                    self._task.per_example_tensors)
                summed_metrics = []
                assert len(per_step_eval_metrics) == len(args)
                with tf.control_dependencies([outfeed_op]):
                    for x, y in zip(per_step_eval_metrics, args):
                        summed_metrics.append(x + y)
                return summed_metrics + [self._task.train_op]

            @tpu_function.on_device_training_loop
            def TpuTrain():
                loop_result = tpu_training_loop.repeat(
                    self._steps_per_loop,
                    TpuTrainStep,
                    inputs=self._eval_metrics.initial_values,
                    name='train_loop')
                # Final metrics are the avg across self._steps_per_loop steps.
                return self._eval_metrics.FinalizeMetrics(loop_result)

            self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
                TpuTrain,
                num_shards=data_parallelism,
                device_assignment=py_utils.GetTpuDeviceAssignment())
            outfeed_dequeue_op = self._OutfeedDequeueLoop(
                self._task.per_example_tensors, self._steps_per_loop,
                self.num_splits_per_client)

            # Get the metric result from a single replica; they are all the same here.

            def _ConstructPostTrainingLoop(train_loop_op, outfeed_dequeue_op):
                """Returns the op for tpu training with tail cpu computation."""
                # Adds a tail computation that runs after the tpu_training loop
                # step finishes. This allows us to run computation that acts on
                # the variables between tpu_train_loop iterations, amortizing
                # the cost of those operations. The alternative of running
                # tpu.outside_compilation and using tf.cond is expensive.
                with tf.control_dependencies(train_loop_op):
                    self._model.ConstructPostTrainingLoop()
                    with tf.control_dependencies(
                        [self._task.post_training_loop_op]):
                        return ([[tf.identity(o) for o in train_loop_op],
                                 outfeed_dequeue_op])

            # Get the metric result from a single replica; they are all the same here.
            all_tpu_ops = [t[0] for t in batch_parallel_res]
            self.tpu_ops = (_ConstructPostTrainingLoop(all_tpu_ops,
                                                       outfeed_dequeue_op))

        return self.tpu_ops
Example #15
  def build_model(self, model_fn, params):
    """Build the TPU model and infeed enqueue ops."""
    tf.logging.info("TrainLowLevelRunner: build_model method")

    def tpu_train_step(loss):
      """Generate the TPU graph."""
      del loss
      values = self.infeed_queue[0].generate_dequeue_op(tpu_device=0)
      unflattened_inputs = data_nest.pack_sequence_as(self.feature_structure,
                                                      values)
      features = unflattened_inputs["features"]
      core_id = unflattened_inputs["core_id"]
      new_features = {}
      for k in features:
        s = features[k].shape.as_list()
        s = [self.hparams.num_shards, s[0] // self.hparams.num_shards] + s[1:]
        new_features[k] = tf.squeeze(
            tf.gather(
                tf.reshape(tpu_ops.cross_replica_sum(features[k]), s), core_id),
            [0])

      estimator_spec = model_fn(new_features, None, tf.estimator.ModeKeys.TRAIN,
                                params)
      loss, train_op = estimator_spec.loss, estimator_spec.train_op
      with tf.control_dependencies([train_op]):
        return tf.identity(loss)

    @tpu_function.on_device_training_loop
    def train_loop():
      return training_loop.repeat(self.iterations, tpu_train_step,
                                  [_INITIAL_LOSS])

    def tpu_eval_step():
      """Generate the TPU graph."""
      values = self.eval_infeed_queue[0].generate_dequeue_op(tpu_device=0)
      unflattened_inputs = data_nest.pack_sequence_as(
          self.eval_feature_structure, values)
      features = unflattened_inputs["features"]
      estimator_spec = model_fn(features, None, tf.estimator.ModeKeys.PREDICT,
                                params)
      for k, v in six.iteritems(estimator_spec.predictions):
        self.outfeed_names.append(k)
        self.outfeed_tensors.append(v)

      with tf.device(
          device_for_tpu_core(get_host(self.resolver, self.hparams))):
        outfeed_enqueue_ops = tpu_ops.outfeed_enqueue_tuple(
            self.outfeed_tensors)
      with tf.control_dependencies([outfeed_enqueue_ops]):
        return tf.no_op()

    @tpu_function.on_device_training_loop
    def eval_loop():
      if self.eval_steps > 0:
        return training_loop.repeat(self.eval_steps, tpu_eval_step, [])
      else:
        return tf.no_op()

    def train_eval_step():
      with tf.control_dependencies(train_loop()):
        return eval_loop()

    def train_eval_loop():
      return training_loop.repeat(self.hparams.max_train_epochs,
                                  train_eval_step, [])

    def create_dequeue_ops(host_id):
      """Create outfeed dequeue ops."""
      dequeue_ops = []
      tensor_dtypes = []
      tensor_shapes = []
      for v in self.outfeed_tensors:
        dequeue_ops.append([])
        tensor_dtypes.append(v.dtype)
        tensor_shapes.append(v.shape)
      for i in range(self.hparams.num_shards_per_host):
        with tf.device(
            device_for_host(get_host(self.resolver, self.hparams, host_id))):
          outfeed_tensors = tpu_ops.outfeed_dequeue_tuple(
              dtypes=tensor_dtypes, shapes=tensor_shapes, device_ordinal=i)
          for j, item in enumerate(outfeed_tensors):
            dequeue_ops[j].append(item)
      for j in range(len(outfeed_tensors)):
        dequeue_ops[j] = tf.concat(dequeue_ops[j], axis=0)
      return dequeue_ops

    with self.graph.as_default():
      if self.eval_steps <= 0:
        (self.loss,) = tpu.shard(
            train_loop,
            inputs=[],
            num_shards=self.hparams.num_shards,
            outputs_from_all_shards=False,
        )

      else:
        (
            self.compile_op,
            self.train_eval_op,
        ) = tpu.split_compile_and_shard(
            train_eval_loop,
            inputs=[],
            num_shards=self.hparams.num_shards,
            outputs_from_all_shards=False)

      if self.eval_steps > 0:
        for i in range(0, self.num_hosts):
          self.dequeue_ops.append({})
          host_dequeue_ops = create_dequeue_ops(i)
          for j, dequeue_tensor in enumerate(host_dequeue_ops):
            self.dequeue_ops[i][self.outfeed_names[j]] = dequeue_tensor

      global_initializer = tf.global_variables_initializer()
      local_initializer = tf.local_variables_initializer()
      self.sess.run(global_initializer)
      self.sess.run(local_initializer)

      graph_io.write_graph(
          self.graph.as_graph_def(add_shapes=True), self.hparams.out_dir,
          "graph.pbtxt")
      self.saver = tf.train.Saver()
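A driver for this runner would typically compile first, run the on-device train/eval loop, and then drain the dequeue ops; a minimal sketch using only the attributes built above, assuming the eval branch (eval_steps > 0) was taken and that infeed enqueue threads are started elsewhere in the runner:

  def train_and_eval_sketch(self):
    # Hypothetical driver for the graph built above; infeed enqueue threads are
    # assumed to be running already.
    self.sess.run(self.compile_op)       # surface compilation errors early
    self.sess.run(self.train_eval_op)    # run the on-device train/eval loop
    # One dict of outfeed tensors per host, keyed by self.outfeed_names.
    return [self.sess.run(ops) for ops in self.dequeue_ops]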