Example #1
  def _CompileDecodeFn(self):
    """Wrap the DecodeFn with split_compile_and_shard."""
    with cluster_factory.SetImmediatelyInstantiateVariables(False):
      self._model = self._InstantiateTaskModel(self._task_params)
    self._task = self._model.GetTask()
    self._task.input.InstantiateVariables()
    self._task.input.CreateTpuEnqueueOps()
    self._task.input.CreateCpuPassthroughEnqueueOps()

    def _DecodeFn():
      """Decode call to be compiled for TPU."""
      with py_utils.OpportunisticVariableReuseScope(True):
        self._model.InstantiateVariables()
        input_batch = self._task.input.TpuDequeueBatch()
        decode_dict = self._task.Decode(input_batch)
      self.decode_nm = py_utils.NestedMap(decode_dict)
      return self.decode_nm.Flatten()

    self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
        _DecodeFn,
        num_shards=self.data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())

    self.cpu_pt = self._task.input.DequeueCpuPassthrough()
    self.decode_tensors = py_utils.NestedMap(self.decode_nm)
    self.decode_tensors = self.decode_tensors.Pack(batch_parallel_res)
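The snippet above only builds the graph; nothing here runs it. A rough driver sketch for such a program, assuming a TF1 session and that the infeed enqueue ops created by CreateTpuEnqueueOps() are driven by separate threads (the program handle, num_steps, and the session itself are assumptions, not part of the example):

import tensorflow.compat.v1 as tf

def RunCompiledDecode(sess, program, num_steps):
  """Sketch: drive a decode graph built like _CompileDecodeFn above."""
  # Check XLA compilation once; this op raises if compilation failed.
  sess.run(program._compile_op)
  results = []
  for _ in range(num_steps):
    # Fetch the sharded decode outputs plus the CPU passthrough tensors.
    # Assumes the infeed enqueue ops are being run concurrently elsewhere.
    results.append(sess.run([program.decode_tensors, program.cpu_pt]))
  return results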
Example #2
    def BuildTpuSubgraph(self):
        tf.logging.info('DecodeProgram BuildTpuSubGraph')
        py_utils.ResetStepSeed()

        # Instantiate input generator first.
        self._input = self._task_params.input.Instantiate()
        self._input.CreateTpuEnqueueOps()
        self.SkipCreateChild(self._task_params)

        def _DecodeFn():
            """Decode call to be compiled for TPU."""
            with py_utils.OpportunisticVariableReuseScope(True):
                with cluster_factory.SetEval(True):
                    self._model = self._task_params.Instantiate()
                    self._task = self._model.GetTask()
                    self._task.AddChild('input', self._input)
                    input_batch = self._task.input.TpuDequeueBatch()
                    metrics_dict = self._task.Decode(input_batch)
                    self.metrics_nm = py_utils.NestedMap(metrics_dict)
                    return self.metrics_nm.Flatten()

        self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
            _DecodeFn,
            num_shards=self.data_parallelism,
            device_assignment=py_utils.GetTpuDeviceAssignment())

        self.metrics = py_utils.NestedMap(self.metrics_nm)
        self.metrics = self.metrics.Pack(batch_parallel_res)
        return None
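Note the wiring in this variant: the input generator is instantiated once on the host and its CreateTpuEnqueueOps() builds the CPU-side infeed, while SkipCreateChild()/AddChild('input', self._input) re-attach that same generator to the task constructed inside _DecodeFn, so TpuDequeueBatch() on the device reads from the queues that the host-side enqueue ops feed.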
Example #3
        def LoopBody(i, *input_arrays):
            """Process outfeed data for a single TpuTrainStep.

            Args:
              i: current loop index.
              *input_arrays: One tf.TensorArray per outfeed tensor.

            Returns:
              i+1 (new index) plus post-write tf.TensorArray handles.
            """
            # Outfeed ops execute on each JF node, so they must be located on the
            # nodes.
            outfeed_devices = []
            device_assignment = py_utils.GetTpuDeviceAssignment()
            assert device_assignment
            for replica in range(device_assignment.num_replicas):
                for core in range(device_assignment.num_cores_per_replica):
                    with tf.device(device_assignment.host_device(
                            replica, core)):
                        outfeed_devices.append(
                            tpu_ops.outfeed_dequeue_tuple(
                                tensor_types,
                                tensor_shapes,
                                device_ordinal=device_assignment.tpu_ordinal(
                                    replica, core)))
            offset = i * num_devices
            output_arrays = list(input_arrays)
            # Each output_array holds a different per-example tensor. We get results
            # for each tensor from each TPU for each TpuTrainStep call.
            for j in range(len(output_arrays)):
                for k in range(len(outfeed_devices)):
                    output_arrays[j] = output_arrays[j].write(
                        offset + k, outfeed_devices[k][j])

            return tuple([i + 1] + output_arrays)
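LoopBody is written to be the body of a tf.while_loop that drains the outfeed once per TpuTrainStep. A minimal sketch of the surrounding loop, assuming LoopBody, tensor_types, and num_devices are in scope as above and num_loops is the number of train steps per loop (these names and the helper below are illustrative, not the library's API):

import tensorflow.compat.v1 as tf

def DequeueOutfeed(num_loops):
  """Sketch: collect num_loops steps of outfeed, one TensorArray per tensor."""
  # One slot per device per step for every outfeed tensor.
  arrays = [
      tf.TensorArray(dtype, size=num_loops * num_devices)
      for dtype in tensor_types
  ]
  loop_vars = tf.while_loop(
      lambda i, *_: i < num_loops,  # run LoopBody num_loops times
      LoopBody,
      [tf.constant(0)] + arrays,
      parallel_iterations=1)
  # Drop the loop counter and stack each TensorArray into a dense tensor.
  return [ta.stack() for ta in loop_vars[1:]]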
Example #4
    def BuildTpuSubgraph(self):
        tf.logging.info('DecodeProgram BuildTpuSubGraph')
        py_utils.ResetStepSeed()

        def _DecodeFn():
            """Decode call to be compiled for TPU."""
            with py_utils.OpportunisticVariableReuseScope(True):
                with cluster_factory.SetEval(True):
                    self._model = self._task_params.Instantiate()
                    self._model_task = self._model.GetTask()
                    if py_utils.use_tpu():
                        input_batch = self._model_task.input_generator.CreateTpuFeeds(
                        )
                    else:
                        input_batch = self._model_task.input_generator.SplitInputBatch(
                            self.cluster.num_splits_per_client)
                    metrics_dict = self._model_task.Decode(input_batch)
                    self.metrics_nm = py_utils.NestedMap(metrics_dict)
                    return self.metrics_nm.Flatten()

        self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
            _DecodeFn,
            num_shards=self.data_parallelism,
            device_assignment=py_utils.GetTpuDeviceAssignment())

        self.metrics = py_utils.NestedMap(self.metrics_nm)
        self.metrics = self.metrics.Pack(batch_parallel_res)
        return None
Example #5
  def _CompileDecodeLoop(self):
    """Wrap the DecodeLoop with split_compile_and_shard."""
    device_assignment = py_utils.GetTpuDeviceAssignment()
    with cluster_factory.SetImmediatelyInstantiateVariables(False):
      self._model = self._InstantiateTaskModel(self._task_params)
    self._task = self._model.GetTask()
    self._task.input.InstantiateVariables()
    self._task.input.CreateTpuEnqueueOps()
    self._task.input.CreateCpuPassthroughEnqueueOps()

    def _DecodeStep():
      """Decode call to be compiled for TPU."""
      with py_utils.OpportunisticVariableReuseScope(True):
        self._model.InstantiateVariables()
        input_batch = self._task.input.TpuDequeueBatch()
        decode_dict = self._task.Decode(input_batch)
      self.decode_nm = py_utils.NestedMap(decode_dict)
      return [self._OutfeedEnqueue(decode_dict)]

    @tpu_function.on_device_training_loop
    def DecodeLoopFn():
      return tpu_training_loop.repeat(
          self._steps_per_loop, _DecodeStep, inputs=[])

    self._compile_op, self.decode_loop = tpu.split_compile_and_shard(
        DecodeLoopFn,
        num_shards=self.data_parallelism,
        device_assignment=device_assignment)

    # Get a list of outfeed ops.
    self.decode_tensors = self._OutfeedDequeue()
    # Pack the list of outfeed ops with structure in self.decode_nm.
    self.decode_tensors = tf.nest.pack_sequence_as(self.decode_nm,
                                                   self.decode_tensors)
    self.cpu_pt = self._task.input.DequeueCpuPassthrough()
Example #6
  def BuildTpuSubgraph(self):
    tf.logging.info('EvalProgram BuildTpuSubGraph')
    with cluster_factory.SetEval(True):
      self._eval_metrics = metrics.TpuEvalMetrics()
      data_parallelism = self.data_parallelism
      with cluster_factory.SetImmediatelyInstantiateVariables(False):
        self._model = self._InstantiateTaskModel(self._task_params)
      self._task = self._model.GetTask()
      self._task.input.InstantiateVariables()
      self._task.input.CreateTpuEnqueueOps()
      self._init_input_ops = self._task.input.InitOps()

      # XLA thinks self.TpuEvalLoop() requires 1 argument due to `self`;
      # trick it with a zero-argument wrapper function.
      def TpuEvalLoopWrapper():
        return self.TpuEvalLoop()

      self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
          TpuEvalLoopWrapper,
          num_shards=data_parallelism,
          device_assignment=py_utils.GetTpuDeviceAssignment())
      self._task.input.CreateTpuEmbeddingEnqueueOps(mode_override='inference')

      # Get metric result from a single replica; they are all the same here.
      self.tpu_ops = [[t[0] for t in batch_parallel_res]]
      return self.tpu_ops
Example #7
    def BuildTpuSubgraph(self):
        tf.logging.info('TrainProgram BuildTpuSubGraph')

        with py_utils.OpportunisticVariableReuseScope(True):
            self._eval_metrics = metrics.TpuEvalMetrics()
            data_parallelism = self.data_parallelism

            # Instantiate input generator first.
            self._input = self._task_params.input.Instantiate()
            self._input.CreateTpuEnqueueOps()
            self.SkipCreateChild(self._task_params)

            def TpuTrainStep(*args):
                """Train a shard of a batch on a single TPU core.

                Args:
                  *args: metrics values from previous steps.

                Returns:
                  New summed metrics values and a train_op.
                """
                self._model = self._task_params.Instantiate()
                self._task = self._model.GetTask()
                self._task.AddChild('input', self._input)
                self._model.ConstructFPropBPropGraph()
                per_step_eval_metrics = self._eval_metrics.SetMetrics(
                    self._task.eval_metrics, args)
                outfeed_op = self._OutfeedEnqueue(
                    self._task.per_example_tensors)
                summed_metrics = []
                assert len(per_step_eval_metrics) == len(args)
                with tf.control_dependencies([outfeed_op]):
                    for x, y in zip(per_step_eval_metrics, args):
                        summed_metrics.append(x + y)
                return summed_metrics + [self._task.train_op]

            @tpu_function.on_device_training_loop
            def TpuTrain():
                loop_result = tpu_training_loop.repeat(
                    self._steps_per_loop,
                    TpuTrainStep,
                    inputs=self._eval_metrics.initial_values,
                    name='train_loop')
                # Final metrics are the avg across self._steps_per_loop steps.
                return self._eval_metrics.FinalizeMetrics(loop_result)

            self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
                TpuTrain,
                num_shards=data_parallelism,
                device_assignment=py_utils.GetTpuDeviceAssignment())
            outfeed_dequeue_op = self._OutfeedDequeueLoop(
                self._task.per_example_tensors, self._steps_per_loop,
                self.num_splits_per_client)
            # Get metric result from a single replica; they are all the same here.
            self.tpu_ops = [[t[0] for t in batch_parallel_res],
                            outfeed_dequeue_op]

        return self.tpu_ops
Example #8
def TPUOrdinalFunction(shard_index_in_host):
  device_assignment = py_utils.GetTpuDeviceAssignment()
  if device_assignment:
    # We put both enqueue/dequeue ops at core 0 in each replica.
    replica = device_assignment.lookup_replicas(
        task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
    return device_assignment.tpu_ordinal(replica=replica)
  else:
    return shard_index_in_host
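This helper is meant to be passed to the infeed queue when generating the per-host enqueue ops, as the full CreateTpuFeeds in Example #17 below does (q, batch, host_device, and the task_id captured by the closure all come from that surrounding code):

input_ops = q.split_inputs_and_generate_enqueue_ops(
    batch.Flatten(),
    placement_function=lambda x: host_device,
    tpu_ordinal_function=TPUOrdinalFunction)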
Example #9
    def BuildTpuSubgraph(self):
        with py_utils.OpportunisticVariableReuseScope(True):
            self._eval_metrics = metrics.TpuEvalMetrics()
            data_parallelism = self.data_parallelism

            def TpuTrainStep(*args):
                """Train a shard of a batch on a single TPU core.

                Args:
                  *args: metrics values from previous steps.

                Returns:
                  New summed metrics values and a train_op.
                """
                self._model = self._task_params.Instantiate()
                self._model.ConstructFPropBPropGraph()
                per_step_eval_metrics = self._eval_metrics.SetMetrics(
                    self._model.GetTask().eval_metrics, args)
                outfeed_op = self._OutfeedEnqueue(
                    self._model.GetTask().per_example_tensors)
                summed_metrics = []
                assert len(per_step_eval_metrics) == len(args)
                with tf.control_dependencies([outfeed_op]):
                    for x, y in zip(per_step_eval_metrics, args):
                        summed_metrics.append(x + y)
                return summed_metrics + [self._model.GetTask().train_op]

            @tpu_function.on_device_training_loop
            def TpuTrain():
                loop_result = tpu_training_loop.repeat(
                    self._steps_per_loop,
                    TpuTrainStep,
                    inputs=self._eval_metrics.initial_values,
                    name='train_loop')
                # Final metrics are the avg across self._steps_per_loop steps.
                return self._eval_metrics.FinalizeMetrics(loop_result)

            batch_parallel_res = tf.tpu.batch_parallel(
                TpuTrain,
                num_shards=data_parallelism,
                device_assignment=py_utils.GetTpuDeviceAssignment())
            outfeed_dequeue_op = self._OutfeedDequeueLoop(
                self._model.GetTask().per_example_tensors,
                self._steps_per_loop, self.num_splits_per_client)
            # Get metric result from a single replica; they are all the same here.
            self.tpu_ops = [[t[0] for t in batch_parallel_res],
                            outfeed_dequeue_op]

            # TODO(blee): This is going to need to be fixed for multiple-model
            # execution. Need to get only the vars associated with the model.
            self._checkpointer = self._CreateCheckpointer(
                self._checkpoint_dir, self._model)
        return self.tpu_ops
Example #10
    def BuildTpuSubgraph(self):
        tf.logging.info('EvalProgram BuildTpuSubGraph')
        with cluster_factory.SetEval(True):
            self._eval_metrics = metrics.TpuEvalMetrics()
            data_parallelism = self.data_parallelism
            with cluster_factory.SetImmediatelyInstantiateVariables(False):
                self._model = self._InstantiateTaskModel(self._task_params)
            self._task = self._model.GetTask()
            self._task.input.InstantiateVariables()
            self._task.input.CreateTpuEnqueueOps()

            def TpuEvalStep(*args):
                """Eval a shard of a batch on a single TPU core.

                Args:
                  *args: metrics values from previous steps.

                Returns:
                  Summed eval metrics.
                """
                with tf.name_scope('tpu_eval'):
                    with py_utils.OpportunisticVariableReuseScope(True):
                        self._model.InstantiateVariables()
                        self._model.ConstructFPropGraph()
                    per_step_eval_metrics = self._eval_metrics.SetMetrics(
                        self._task.eval_metrics, args)
                    summed_metrics = []
                    for x, y in zip(per_step_eval_metrics, args):
                        summed_metrics.append(x + y)
                    return summed_metrics

            @tpu_function.on_device_training_loop
            def TpuEval():
                loop_result = tpu_training_loop.repeat(
                    self._steps_per_loop,
                    TpuEvalStep,
                    inputs=self._eval_metrics.initial_values,
                    name='eval_loop')
                # Final metrics are the avg across self._steps_per_loop steps.
                return self._eval_metrics.FinalizeMetrics(loop_result)

            self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
                TpuEval,
                num_shards=data_parallelism,
                device_assignment=py_utils.GetTpuDeviceAssignment())

            self._task.input.CreateTpuEmbeddingEnqueueOps(
                mode_override='inference')

            # Get metric result from a single replica; they are all the same here.
            self.tpu_ops = [[t[0] for t in batch_parallel_res]]

            return self.tpu_ops
Example #11
    def BuildTpuSubgraph(self):
        tf.logging.info('EvalProgram BuildTpuSubGraph')
        with py_utils.OpportunisticVariableReuseScope(True):
            self._eval_metrics = metrics.TpuEvalMetrics()
            data_parallelism = self.data_parallelism

            def TpuEvalStep(*args):
                """Eval a shard of a batch on a single TPU core.

                Args:
                  *args: metrics values from previous steps.

                Returns:
                  Per-step eval metrics.
                """
                self._model = self._task_params.Instantiate()
                self._model.ConstructFPropGraph()
                per_step_eval_metrics = self._eval_metrics.SetMetrics(
                    self._model.GetTask().eval_metrics, args)
                return per_step_eval_metrics

            @tpu_function.on_device_training_loop
            def TpuEval():
                loop_result = tpu_training_loop.repeat(
                    self._steps_per_loop,
                    TpuEvalStep,
                    inputs=self._eval_metrics.initial_values,
                    name='eval_loop')
                # Final metrics are the avg across self._steps_per_loop steps.
                return self._eval_metrics.FinalizeMetrics(loop_result)

            batch_parallel_res = tf.tpu.batch_parallel(
                TpuEval,
                num_shards=data_parallelism,
                device_assignment=py_utils.GetTpuDeviceAssignment())
            # Get metric result from a single replica; they are all the same here.
            self.tpu_ops = [[t[0] for t in batch_parallel_res]]
            self._checkpointer = checkpointer.Checkpointer(
                self._checkpoint_dir, self._model)

            return self.tpu_ops
Example #12
 def _OutfeedDequeue(self):
     """Collect outfeed dequeue from all devices."""
     num_outfeeds = len(self.metrics_nm.Flatten())
     outfeed_ops = [[]] * num_outfeeds
     device_assignment = py_utils.GetTpuDeviceAssignment()
     assert device_assignment
     for replica in range(device_assignment.num_replicas):
         num_cores_per_replica = 1 if self.spmd else (
             device_assignment.num_cores_per_replica)
         for core in range(num_cores_per_replica):
             with tf.device(device_assignment.host_device(replica, core)):
                 outfeeds_per_core = tpu_ops.outfeed_dequeue_tuple(
                     dtypes=[x.dtype for x in self.metrics_nm.Flatten()],
                     shapes=[x.shape for x in self.metrics_nm.Flatten()],
                     device_ordinal=device_assignment.tpu_ordinal(
                         replica, core))
                 for idx_outfeed, out_feed in enumerate(outfeeds_per_core):
                     outfeed_ops[idx_outfeed] = outfeed_ops[idx_outfeed] + [
                         out_feed
                     ]
     return [tf.concat(per_outfeed, 0) for per_outfeed in outfeed_ops]
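The dtypes and shapes passed to outfeed_dequeue_tuple here must mirror, element for element, what the device-side loop enqueues. In Example #13 below the matching enqueue is simply:

    outfeed_enqueue = tpu_ops.outfeed_enqueue_tuple(self.metrics_nm.Flatten())

Both sides flatten the same metrics_nm NestedMap, which is what later lets the host pack the dequeued list back into the original structure.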
Example #13
    def BuildTpuSubgraph(self):
        tf.logging.info('DecodeProgram BuildTpuSubGraph')
        py_utils.ResetStepSeed()
        device_assignment = py_utils.GetTpuDeviceAssignment()
        self.spmd = self._task_params.input.use_partitioned_infeed_queue
        with cluster_factory.SetEval(True):
            with cluster_factory.SetImmediatelyInstantiateVariables(False):
                self._model = self._task_params.Instantiate()
            self._task = self._model.GetTask()
            self._task.input.InstantiateVariables()
            self._task.input.CreateTpuEnqueueOps()

            def _DecodeStep():
                """Decode call to be compiled for TPU."""
                with py_utils.OpportunisticVariableReuseScope(True):
                    self._model.InstantiateVariables()
                    input_batch = self._task.input.TpuDequeueBatch()
                    metrics_dict = self._task.Decode(input_batch)
                self.metrics_nm = py_utils.NestedMap(metrics_dict)
                device = tpu.core(0) if self.spmd else ''
                with tf.device(device):
                    outfeed_enqueue = tpu_ops.outfeed_enqueue_tuple(
                        self.metrics_nm.Flatten())
                    return [outfeed_enqueue]

        @tpu_function.on_device_training_loop
        def DecodeLoopFn():
            return tpu_training_loop.repeat(self._steps_per_loop,
                                            _DecodeStep,
                                            inputs=[])

        self._compile_op, self.decode_loop = tpu.split_compile_and_shard(
            DecodeLoopFn,
            num_shards=self.data_parallelism,
            device_assignment=device_assignment)
        # Get a list of outfeed ops.
        self.metrics = self._OutfeedDequeue()
        # Pack the list of outfeed ops with structure in self.metrics_nm.
        self.metrics = tf.nest.pack_sequence_as(self.metrics_nm, self.metrics)
        return
Example #14
    def BuildTpuSubgraph(self):
        tf.logging.info('DecodeProgram BuildTpuSubGraph')
        py_utils.ResetStepSeed()

        def _DecodeFn():
            with py_utils.OpportunisticVariableReuseScope(True):
                with cluster_factory.SetEval(True):
                    self._model = self._task_params.Instantiate()
                    self._model_task = self._model.GetTask()
                    input_batch = self._model_task.GetInputBatch()
                    metrics_dict = self._model_task.Decode(input_batch)
                    self.metrics_nm = py_utils.NestedMap(metrics_dict)
                    return self.metrics_nm.Flatten()

        batch_parallel_res = tf.tpu.batch_parallel(
            _DecodeFn,
            num_shards=self.data_parallelism,
            device_assignment=py_utils.GetTpuDeviceAssignment())

        self.metrics = py_utils.NestedMap(self.metrics_nm)
        self.metrics = self.metrics.Pack(batch_parallel_res)
        return None
Example #15
  def BuildTpuSubgraph(self):
    py_utils.ResetStepSeed()

    def _DecodeFn():
      with py_utils.OpportunisticVariableReuseScope(True):
        self._model = self._task_params.Instantiate()
        self._model_task = self._model.GetTask()
        input_batch = self._model_task.GetInputBatch()
        metrics_dict = self._model_task.Decode(input_batch)
        self.metrics_nm = py_utils.NestedMap(metrics_dict)
        return self.metrics_nm.Flatten()

    batch_parallel_res = tf.tpu.batch_parallel(
        _DecodeFn,
        num_shards=self.data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())

    self._checkpointer = checkpointer.Checkpointer(self._checkpoint_dir,
                                                   self._model)

    self.metrics = py_utils.NestedMap(self.metrics_nm)
    self.metrics = self.metrics.Pack(batch_parallel_res)
    return None
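Examples #14 and #15 use tf.tpu.batch_parallel, which returns only the per-shard outputs, while the newer examples use tpu.split_compile_and_shard, which also returns a separate compilation op that can be checked before the computation is launched. The two call patterns, with _DecodeFn as above:

# Older pattern: sharded outputs only.
batch_parallel_res = tf.tpu.batch_parallel(
    _DecodeFn,
    num_shards=self.data_parallelism,
    device_assignment=py_utils.GetTpuDeviceAssignment())

# Newer pattern: an explicit compile op plus the sharded outputs.
self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
    _DecodeFn,
    num_shards=self.data_parallelism,
    device_assignment=py_utils.GetTpuDeviceAssignment())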
Example #16
  def _OutfeedDequeue(self):
    """Collect outfeed dequeue from all devices.

    Returns:
      A list of tensors corresponding to stacked decoded outputs. The decoder
      outputs are stacked on the first dimension (usually corresponds to
      batch size).
    """
    num_decode_tensors = len(self.decode_nm.Flatten())
    outfeed_ops = [[]] * num_decode_tensors
    device_assignment = py_utils.GetTpuDeviceAssignment()
    assert device_assignment
    num_cores_per_replica = (1 if self.spmd else
                             (device_assignment.num_cores_per_replica))
    for replica in range(device_assignment.num_replicas):
      for core in range(num_cores_per_replica):
        with tf.device(device_assignment.host_device(replica, core)):
          outfeeds_per_core = tpu_ops.outfeed_dequeue_tuple(
              dtypes=[x.dtype for x in self.decode_nm.Flatten()],
              shapes=[x.shape for x in self.decode_nm.Flatten()],
              device_ordinal=device_assignment.tpu_ordinal(replica, core))
          for idx_outfeed, out_feed in enumerate(outfeeds_per_core):
            outfeed_ops[idx_outfeed] = outfeed_ops[idx_outfeed] + [out_feed]
    return [tf.concat(per_outfeed, axis=0) for per_outfeed in outfeed_ops]
Example #17
    def CreateTpuFeeds(self):
        """Creates the TPU infeed queue from preprocessed batch."""
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info(
            'CreateTPUFeeds num_splits_per_client={} '
            'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'.
            format(cluster.num_splits_per_client,
                   cluster.num_devices_per_split, num_tpu_hosts,
                   p.use_per_host_infeed))

        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        with py_utils.outside_all_rewrites():
            assert py_utils.use_tpu()
            assert not self._made_tpu_infeed

            shards = tpu_function.get_tpu_context(
            ).number_of_shards // num_infeed_hosts
            tf.logging.info('shards {}'.format(shards))

            input_ops_list = []
            queues = []
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            tpu_embedding = (tpu_embedding_collection[0]
                             if tpu_embedding_collection else None)

            if num_tpu_hosts > 1 and tpu_embedding is not None:
                if not p.use_per_host_infeed:
                    tf.logging.fatal(
                        'TPU Embedding must be used with per_host_infeed with multiple '
                        'TPU host topologies.')
            tpu_emb_input_keys = (list(
                tpu_embedding.feature_to_config_dict.keys())
                                  if tpu_embedding is not None else [])
            tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

            batch = None
            for task_id in range(num_infeed_hosts):
                host_device = '/task:{}/device:CPU:0'.format(task_id)
                with tf.device(host_device):
                    batch = self.GetPreprocessedInputBatch()
                    if isinstance(batch, py_utils.NestedMap):
                        # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
                        # Note that when MultiTaskData is used, bucket_keys will be at the
                        # second level of the dictionary.
                        batch = batch.FilterKeyVal(
                            lambda k, _: not k.endswith('bucket_keys'))
                    tf.logging.info('host_device: %s, batch: %r', host_device,
                                    batch)

                    if tpu_embedding is not None:
                        enqueue_dict_per_core = [
                            {} for _ in range(tpu_embedding.num_cores_per_host)
                        ]
                        num_cores_per_host = tpu_embedding.num_cores_per_host
                        for key in tpu_emb_input_keys:
                            feat = batch[key]
                            tpu_emb_feat_splitted = tf.split(
                                feat, num_cores_per_host)
                            for core, split in enumerate(
                                    tpu_emb_feat_splitted):
                                # Dense to sparse. Note the assumption of a padding id.
                                sample_indices = tf.where(
                                    tf.not_equal(split, -1))
                                embedding_indices = tf.gather_nd(
                                    split, sample_indices)
                                enqueue_data = tpu_embedding_lib.EnqueueData(
                                    embedding_indices, sample_indices)
                                enqueue_dict_per_core[core][key] = enqueue_data
                        input_ops_list += tpu_embedding.generate_enqueue_ops(
                            enqueue_dict_per_core)

                    for k, x in batch.FlattenItems():
                        assert x.shape.is_fully_defined(), (
                            'Shape must be fully defined: %s: %s' % (k, x))
                        # TODO(cwhipkey): if it's a string (or other type not supported on
                        # TPU), drop it from feeding and on the other end add in an op that
                        # fails if used.
                    shapes = batch.Transform(lambda x: x.shape).Flatten()
                    dtypes = batch.Transform(lambda x: x.dtype).Flatten()
                    tf.logging.info('host_device: %s infeed shapes: %r',
                                    host_device, shapes)
                    tf.logging.info('host_device: %s infeed dtypes: %r',
                                    host_device, dtypes)
                    if p.use_partitioned_infeed_queue:
                        device_assignment = py_utils.GetTpuDeviceAssignment()

                        host_device = device_assignment.host_device(
                            replica=0, job=tf.flags.FLAGS.tf_master)
                        host_id = int(
                            host_device.split('/task:')[1].split('/device:')
                            [0])
                        tf.logging.info('host_id: {} host_device: {}'.format(
                            host_id, host_device))
                        q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
                            number_of_tuple_elements=len(dtypes),
                            device_assignment=device_assignment,
                            host_id=host_id,
                            input_partition_dims=[[p.num_partitions, 1]
                                                  for _ in dtypes],
                            tuple_types=dtypes,
                            tuple_shapes=shapes)
                    else:
                        q = tpu_feed.InfeedQueue(tuple_types=dtypes,
                                                 tuple_shapes=shapes)
                        assert shards is not None
                        q.set_number_of_shards(shards)

                    queues.append(q)
                    tf.logging.info('q=%r', q)

                    if p.use_partitioned_infeed_queue:
                        input_ops = q.generate_enqueue_ops([batch.Flatten()])
                    elif p.use_per_host_infeed:
                        # TODO(ylc/zhifengc): Add this to a policy module and test it.
                        def TPUOrdinalFunction(shard_index_in_host):
                            device_assignment = py_utils.GetTpuDeviceAssignment(
                            )
                            if device_assignment:
                                # We put both enqueue/dequeue ops at core 0 in each replica.
                                replica = device_assignment.lookup_replicas(
                                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                                return device_assignment.tpu_ordinal(
                                    replica=replica)
                            else:
                                return shard_index_in_host

                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                            tpu_ordinal_function=TPUOrdinalFunction)
                    else:
                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            device_assignment=py_utils.GetTpuDeviceAssignment(
                            ))

                    input_ops_list += input_ops
            tf.logging.info('input_ops_list %s', input_ops_list)
            tpu_infeed_op = tf.group(*input_ops_list)
        self._made_tpu_infeed = True
        # Let trainer.py use multiple threads to drive the infeed op.
        for _ in range(p.tpu_infeed_parallelism):
            tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

        self._tpu_infeed_op = tpu_infeed_op

        with tf.device(tf.tpu.core(0)):
            tensors = queues[0].generate_dequeue_op()
        return batch.Pack(tensors)
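The comment about letting trainer.py use multiple threads refers to running the grouped enqueue op repeatedly outside the TPU loop so the device never starves for input. A rough sketch of such a driver, assuming the ops were added to py_utils.ENQUEUE_OPS as above (the thread handling here is an illustration, not the actual trainer.py code):

import threading
import tensorflow.compat.v1 as tf

def StartInfeedThreads(sess):
  """Sketch: keep feeding the TPU by looping over the collected enqueue ops."""
  threads = []
  for enqueue_op in tf.get_collection(py_utils.ENQUEUE_OPS):
    def _EnqueueLoop(op=enqueue_op):
      while True:
        sess.run(op)  # blocks while the infeed queue is full
    thread = threading.Thread(target=_EnqueueLoop, daemon=True)
    thread.start()
    threads.append(thread)
  return threads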
Example #18
    def BuildTpuSubgraph(self):
        tf.logging.info('TrainProgram BuildTpuSubGraph')
        self.spmd = (self.params.spmd
                     or self._task_params.input.use_partitioned_infeed_queue)

        self._eval_metrics = metrics.TpuEvalMetrics()
        data_parallelism = self.data_parallelism

        with cluster_factory.SetImmediatelyInstantiateVariables(False):
            self._model = self._InstantiateTaskModel(self._task_params)
        self._task = self._model.GetTask()
        self._task.input.InstantiateVariables()
        self._task.input.CreateTpuEnqueueOps()

        def TpuTrainStep(*args):
            """Train a shard of a batch on a single TPU core.

            Args:
              *args: metrics values from previous steps.

            Returns:
              New summed metrics values and a train_op.
            """
            with tf.name_scope('tpu_train'):
                with py_utils.OpportunisticVariableReuseScope(True):
                    self._model.InstantiateVariables()
                    self._model.ConstructFPropBPropGraph()
                per_step_eval_metrics = self._eval_metrics.SetMetrics(
                    self._task.eval_metrics, args)
                outfeed_op = self._OutfeedEnqueue(
                    self._task.per_example_tensors)
                summed_metrics = []
                assert len(per_step_eval_metrics) == len(args)
                with tf.control_dependencies([outfeed_op]):
                    for x, y in zip(per_step_eval_metrics, args):
                        summed_metrics.append(x + y)
                return summed_metrics + [self._task.train_op]

        @tpu_function.on_device_training_loop
        def TpuTrain():
            loop_result = tpu_training_loop.repeat(
                self._steps_per_loop,
                TpuTrainStep,
                inputs=self._eval_metrics.initial_values,
                name='train_loop')
            # Final metrics are the avg across self._steps_per_loop steps.
            return self._eval_metrics.FinalizeMetrics(loop_result)

        self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
            TpuTrain,
            num_shards=data_parallelism,
            device_assignment=py_utils.GetTpuDeviceAssignment())
        outfeed_dequeue_op = self._OutfeedDequeueLoop(
            self._task.per_example_tensors, self._steps_per_loop,
            self.num_splits_per_client)

        self._task.input.CreateTpuEmbeddingEnqueueOps()

        # Get metric result from a single replica; they are all the same here.

        def _ConstructPostTrainingLoop(train_loop_op, outfeed_dequeue_op):
            """Returns the op for tpu training with tail cpu computation."""
            # Adds a tail computation that runs after the tpu_training loop
            # step finishes. This lets us run computation that acts on the
            # variables between tpu_train_loop iterations and amortize the
            # cost of those operations. The alternative of running
            # tpu.outside_compilation and using tf.cond is expensive.
            with tf.control_dependencies(train_loop_op):
                self._model.ConstructPostTrainingLoop()
                with tf.control_dependencies(
                    [self._task.post_training_loop_op]):
                    return ([[tf.identity(o) for o in train_loop_op],
                             outfeed_dequeue_op])

        # Get metric result from a single replica; they are all the same here.
        all_tpu_ops = [t[0] for t in batch_parallel_res]
        self.tpu_ops = (_ConstructPostTrainingLoop(all_tpu_ops,
                                                   outfeed_dequeue_op))
        self._model_analysis, self._total_num_params = summary_utils.ModelAnalysis(
            self._model)
        try:
            with tf.io.gfile.GFile(
                    os.path.join(self._program_dir, 'model_analysis.txt'),
                    'w') as f:
                f.write(self._model_analysis)
        except tf.errors.NotFoundError as e:
            tf.logging.info('Failed to write model analysis %s', e)

        return self.tpu_ops
Example #19
    def CreateTpuFeeds(self):
        """Creates the TPU infeed queue from preprocessed batch."""
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info('num_cores_per_host {}'.format(num_cores_per_host))
        tf.logging.info('num_devices_per_split {}'.format(
            cluster.num_devices_per_split))

        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        with py_utils.outside_all_rewrites():
            assert py_utils.use_tpu()
            assert not self._made_tpu_infeed

            shards = tpu_function.get_tpu_context(
            ).number_of_shards // num_infeed_hosts
            input_ops_list = []
            queues = []
            first_batch = None
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            tpu_embedding = (tpu_embedding_collection[0]
                             if tpu_embedding_collection else None)

            tpu_embedding_input_keys = (
                tpu_embedding.feature_to_config_dict.keys()
                if tpu_embedding is not None else [])

            for task_id in range(num_infeed_hosts):
                host_device = '/task:{}/device:CPU:0'.format(task_id)
                with tf.device(host_device):
                    batch = self.GetPreprocessedInputBatch()
                    tpu_embedding_features = []
                    for tpu_embedding_input_key in tpu_embedding_input_keys:
                        tpu_embedding_feature = batch.pop(
                            tpu_embedding_input_key)
                        tpu_embedding_features.append(
                            (tpu_embedding_input_key, tpu_embedding_feature))

                    if first_batch is None:
                        first_batch = batch
                    flat_batch = batch.FlattenItems()

                    if tpu_embedding is not None:
                        # One dict per core; a list-multiply of [{}] would
                        # alias a single shared dict across all cores.
                        enqueue_dict_per_core = [
                            {} for _ in range(tpu_embedding.num_cores_per_host)
                        ]
                        num_cores_per_host = tpu_embedding.num_cores_per_host
                        for tpu_embedding_input_key, tpu_embedding_feature in tpu_embedding_features:
                            tpu_embedding_feature_splitted = tf.split(
                                tpu_embedding_feature, num_cores_per_host)
                            for core, split in enumerate(
                                    tpu_embedding_feature_splitted):
                                enqueue_data = tpu_embedding_lib.EnqueueData(
                                    tf.squeeze(split, axis=[1]))
                                enqueue_dict_per_core[core][
                                    tpu_embedding_input_key] = enqueue_data
                        input_ops_list += tpu_embedding.generate_enqueue_ops(
                            enqueue_dict_per_core)

                    shapes, types = [], []
                    for k, x in flat_batch:
                        assert x.shape.is_fully_defined(), (
                            'Shape must be fully defined: %s: %s' % (k, x))
                        # TODO(cwhipkey): if it's a string (or other type not supported on
                        # TPU), drop it from feeding and on the other end add in an op that
                        # fails if used.
                        shapes.append(x.shape)
                        types.append(x.dtype)
                    q = tf.contrib.tpu.InfeedQueue(tuple_types=types,
                                                   tuple_shapes=shapes)
                    queues.append(q)
                    assert shards is not None
                    q.set_number_of_shards(shards)

                    if p.use_per_host_infeed:

                        # TODO(ylc/zhifengc): Add this to a policy module and test it.
                        def _tpu_ordinal_function(shard_index_in_host):
                            device_assignment = py_utils.GetTpuDeviceAssignment(
                            )
                            if device_assignment:
                                # We put both enqueue/dequeue ops at core 0 in each replica.
                                replica = device_assignment.lookup_replicas(
                                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                                return device_assignment.tpu_ordinal(
                                    replica=replica)
                            else:
                                return shard_index_in_host

                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            [v for _, v in flat_batch],
                            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                            tpu_ordinal_function=_tpu_ordinal_function)
                    else:
                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            [v for _, v in flat_batch],
                            device_assignment=py_utils.GetTpuDeviceAssignment(
                            ))

                    input_ops_list += input_ops
            tf.logging.info('input_ops_list %s', input_ops_list)
            tpu_infeed_op = tf.group(*input_ops_list)
        self._made_tpu_infeed = True
        # Let trainer.py use multiple threads to drive the infeed op.
        for _ in range(p.tpu_infeed_parallism):
            tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

        with tf.device(tf.compat.v1.tpu.core(0)):
            tensors = queues[0].generate_dequeue_op()
        return first_batch.Pack(tensors)
Example #20
  def CreateTpuFeeds(self):
    """Creates the TPU infeed queue from preprocessed batch."""
    p = self.params
    cluster = cluster_factory.Current()
    num_tpu_hosts = cluster.num_tpu_hosts
    assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
    num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

    with py_utils.outside_all_rewrites():
      assert py_utils.use_tpu()
      assert not self._made_tpu_infeed

      shards = tpu_function.get_tpu_context(
      ).number_of_shards // num_infeed_hosts
      input_ops_list = []
      queues = []
      first_batch = None
      for task_id in range(num_infeed_hosts):
        host_device = '/task:{}/device:CPU:0'.format(task_id)
        with tf.device(host_device):
          batch = self.GetPreprocessedInputBatch()
          if first_batch is None:
            first_batch = batch
          flat_batch = batch.FlattenItems()

          shapes, types = [], []
          for k, x in flat_batch:
            assert x.shape.is_fully_defined(), (
                'Shape must be fully defined: %s: %s' % (k, x))
            # TODO(cwhipkey): if it's a string (or other type not supported on
            # TPU), drop it from feeding and on the other end add in an op that
            # fails if used.
            shapes.append(x.shape)
            types.append(x.dtype)
          q = tf.contrib.tpu.InfeedQueue(tuple_types=types, tuple_shapes=shapes)
          queues.append(q)
          assert shards is not None
          q.set_number_of_shards(shards)

          if p.use_per_host_infeed:

            # TODO(ylc/zhifengc): Add this to a policy module and test it.
            def _tpu_ordinal_function(shard_index_in_host):
              device_assignment = py_utils.GetTpuDeviceAssignment()
              if device_assignment:
                # We put both enqueue/dequeue ops at core 0 in each replica.
                replica = device_assignment.lookup_replicas(
                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                return device_assignment.tpu_ordinal(replica=replica)
              else:
                return shard_index_in_host

            input_ops = q.split_inputs_and_generate_enqueue_ops(
                [v for _, v in flat_batch],
                placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                tpu_ordinal_function=_tpu_ordinal_function)
          else:
            input_ops = q.split_inputs_and_generate_enqueue_ops(
                [v for _, v in flat_batch],
                device_assignment=py_utils.GetTpuDeviceAssignment())

          input_ops_list += input_ops
      tf.logging.info('input_ops_list %s', input_ops_list)
      tpu_infeed_op = tf.group(*input_ops_list)
    self._made_tpu_infeed = True
    # Let trainer.py use multiple threads to drive the infeed op.
    for _ in range(p.tpu_infeed_parallism):
      tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

    with tf.device(tf.contrib.tpu.core(0)):
      tensors = queues[0].generate_dequeue_op()
    return first_batch.Pack(tensors)
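One detail worth spelling out in these CreateTpuFeeds variants: shards is the total number of TPU shards divided by the number of infeed hosts, so each host's queue is only responsible for its own slice of the replicas. With illustrative numbers (not taken from the code above):

# Illustrative numbers only.
number_of_shards = 32  # tpu_function.get_tpu_context().number_of_shards
num_infeed_hosts = 4   # num_tpu_hosts when p.use_per_host_infeed, else 1
shards = number_of_shards // num_infeed_hosts  # 8, passed to q.set_number_of_shards()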
Example #21
    def CreateTpuFeeds(self):
        """Creates the TPU infeed queue from preprocessed batch."""
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info('num_cores_per_host {}'.format(num_cores_per_host))
        tf.logging.info('num_devices_per_split {}'.format(
            cluster.num_devices_per_split))

        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        with py_utils.outside_all_rewrites():
            assert py_utils.use_tpu()
            assert not self._made_tpu_infeed

            shards = tpu_function.get_tpu_context(
            ).number_of_shards // num_infeed_hosts
            input_ops_list = []
            queues = []
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            tpu_embedding = (tpu_embedding_collection[0]
                             if tpu_embedding_collection else None)

            tpu_emb_input_keys = (list(
                tpu_embedding.feature_to_config_dict.keys())
                                  if tpu_embedding is not None else [])
            tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

            batch = None
            for task_id in range(num_infeed_hosts):
                host_device = '/task:{}/device:CPU:0'.format(task_id)
                with tf.device(host_device):
                    batch = self.GetPreprocessedInputBatch()
                    if 'bucket_keys' in batch:
                        # Hack: bucket_keys are not needed on TPU.
                        del batch['bucket_keys']
                    tf.logging.info('host_device: %s, batch: %r', host_device,
                                    batch)

                    if tpu_embedding is not None:
                        enqueue_dict_per_core = [
                            {} for _ in range(tpu_embedding.num_cores_per_host)
                        ]
                        num_cores_per_host = tpu_embedding.num_cores_per_host
                        for key in tpu_emb_input_keys:
                            feat = batch[key]
                            tpu_emb_feat_splitted = tf.split(
                                feat, num_cores_per_host)
                            for core, split in enumerate(
                                    tpu_emb_feat_splitted):
                                # Dense to sparse. Note the assumption of a padding id.
                                sample_indices = tf.where(
                                    tf.not_equal(split, -1))
                                embedding_indices = tf.gather_nd(
                                    split, sample_indices)
                                enqueue_data = tpu_embedding_lib.EnqueueData(
                                    embedding_indices, sample_indices)
                                enqueue_dict_per_core[core][key] = enqueue_data
                        input_ops_list += tpu_embedding.generate_enqueue_ops(
                            enqueue_dict_per_core)

                    for k, x in batch.FlattenItems():
                        assert x.shape.is_fully_defined(), (
                            'Shape must be fully defined: %s: %s' % (k, x))
                        # TODO(cwhipkey): if it's a string (or other type not supported on
                        # TPU), drop it from feeding and on the other end add in an op that
                        # fails if used.
                    shapes = batch.Transform(lambda x: x.shape).Flatten()
                    dtypes = batch.Transform(lambda x: x.dtype).Flatten()
                    tf.logging.info('host_device: %s infeed shapes: %r',
                                    host_device, shapes)
                    tf.logging.info('host_device: %s infeed dtypes: %r',
                                    host_device, dtypes)
                    q = tpu_feed.InfeedQueue(tuple_types=dtypes,
                                             tuple_shapes=shapes)
                    queues.append(q)
                    assert shards is not None
                    q.set_number_of_shards(shards)

                    if p.use_per_host_infeed:

                        # TODO(ylc/zhifengc): Add this to a policy module and test it.
                        def TPUOrdinalFunction(shard_index_in_host):
                            device_assignment = py_utils.GetTpuDeviceAssignment(
                            )
                            if device_assignment:
                                # We put both enqueue/dequeue ops at core 0 in each replica.
                                replica = device_assignment.lookup_replicas(
                                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                                return device_assignment.tpu_ordinal(
                                    replica=replica)
                            else:
                                return shard_index_in_host

                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                            tpu_ordinal_function=TPUOrdinalFunction)
                    else:
                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            device_assignment=py_utils.GetTpuDeviceAssignment(
                            ))

                    input_ops_list += input_ops
            tf.logging.info('input_ops_list %s', input_ops_list)
            tpu_infeed_op = tf.group(*input_ops_list)
        self._made_tpu_infeed = True
        # Let trainer.py use multiple threads to drive the infeed op.
        for _ in range(p.tpu_infeed_parallelism):
            tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

        # For executor-driven multiple programs, we need more fine-grained
        # access rather than using a single global graph collection.
        self.tpu_infeed_op = tpu_infeed_op

        with tf.device(tf.tpu.core(0)):
            tensors = queues[0].generate_dequeue_op()
        return batch.Pack(tensors)
Example #22
    def BuildTpuSubgraph(self):
        tf.logging.info('TrainProgram BuildTpuSubGraph')

        with py_utils.OpportunisticVariableReuseScope(True):
            self._eval_metrics = metrics.TpuEvalMetrics()
            data_parallelism = self.data_parallelism

            # Instantiate input generator first.
            self._input = self._task_params.input.Instantiate()
            self._input.CreateTpuEnqueueOps()
            self.SkipCreateChild(self._task_params)

            def TpuTrainStep(*args):
                """Train a shard of a batch on a single TPU core.

                Args:
                  *args: metrics values from previous steps.

                Returns:
                  New summed metrics values and a train_op.
                """
                self._model = self._task_params.Instantiate()
                self._task = self._model.GetTask()
                self._task.AddChild('input', self._input)
                self._model.ConstructFPropBPropGraph()
                per_step_eval_metrics = self._eval_metrics.SetMetrics(
                    self._task.eval_metrics, args)
                outfeed_op = self._OutfeedEnqueue(
                    self._task.per_example_tensors)
                summed_metrics = []
                assert len(per_step_eval_metrics) == len(args)
                with tf.control_dependencies([outfeed_op]):
                    for x, y in zip(per_step_eval_metrics, args):
                        summed_metrics.append(x + y)
                return summed_metrics + [self._task.train_op]

            @tpu_function.on_device_training_loop
            def TpuTrain():
                loop_result = tpu_training_loop.repeat(
                    self._steps_per_loop,
                    TpuTrainStep,
                    inputs=self._eval_metrics.initial_values,
                    name='train_loop')
                # Final metrics are the avg across self._steps_per_loop steps.
                return self._eval_metrics.FinalizeMetrics(loop_result)

            self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
                TpuTrain,
                num_shards=data_parallelism,
                device_assignment=py_utils.GetTpuDeviceAssignment())
            outfeed_dequeue_op = self._OutfeedDequeueLoop(
                self._task.per_example_tensors, self._steps_per_loop,
                self.num_splits_per_client)

            # Get metric result from a single replica; they are all the same here.

            def _ConstructPostTrainingLoop(train_loop_op, outfeed_dequeue_op):
                """Returns the op for tpu training with tail cpu computation."""
                # Adds a tail computation that runs after the tpu_training loop
                # step finishes. This lets us run computation that acts on the
                # variables between tpu_train_loop iterations and amortize the
                # cost of those operations. The alternative of running
                # tpu.outside_compilation and using tf.cond is expensive.
                with tf.control_dependencies(train_loop_op):
                    self._model.ConstructPostTrainingLoop()
                    with tf.control_dependencies(
                        [self._task.post_training_loop_op]):
                        return ([[tf.identity(o) for o in train_loop_op],
                                 outfeed_dequeue_op])

            # Get metric result from a single replica; they are all the same here.
            all_tpu_ops = [t[0] for t in batch_parallel_res]
            self.tpu_ops = (_ConstructPostTrainingLoop(all_tpu_ops,
                                                       outfeed_dequeue_op))

        return self.tpu_ops
Example #23
  def __init__(self, *args, **kwargs):
    super(TrainerTpu, self).__init__(*args, **kwargs)

    # Multiple TPU trainer tasks not tested/implemented.
    assert self._cluster.num_replicas == 1
    data_parallelism = self._cluster.num_splits_per_client
    assert data_parallelism
    num_devices_per_split = self._cluster.num_devices_per_split
    tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                    data_parallelism, num_devices_per_split)

    def ComputationShape(split_size):
      """Decides the computation shape based on the split_size."""
      computation_shape = None
      if split_size == 1:
        computation_shape = [1, 1, 1]
      elif split_size == 2:
        computation_shape = [1, 1, 2]
      elif split_size == 4:
        computation_shape = [1, 2, 2]
      elif split_size == 8:
        computation_shape = [2, 2, 2]
      elif split_size == 16:
        computation_shape = [4, 2, 2]
      else:
        assert False, ('Model parallelism with %d devices is currently not'
                       ' supported.' % split_size)
      assert computation_shape is not None
      return computation_shape

    self._steps_per_loop = min(self.params.train.tpu_steps_per_loop,
                               self.params.train.max_steps)

    tf.logging.info(
        'Creating TrainerTpu using data parallelism %s '
        'and %s steps_per_loop', data_parallelism, self._steps_per_loop)

    @py_utils.RetryOnTransientTfError()
    def _WaitTillInit():
      """Wait until the model is ready."""
      try:
        with self._GetSession() as sess:
          topology = sess.run(
              tf.contrib.tpu.initialize_system(embedding_config=None, job=None))
          device_assignment = tf.contrib.tpu.device_assignment(
              topology,
              computation_shape=ComputationShape(num_devices_per_split),
              num_replicas=data_parallelism)
          py_utils.SetTpuDeviceAssignment(device_assignment)
          tf.logging.info('device_assignment.core_assignment: %s',
                          str(device_assignment.core_assignment))
          tf.logging.info('device_assignment.topology.device_coordinates: %s',
                          str(device_assignment.topology.device_coordinates))
      except py_utils.transient_tf_errors as e:
        tf.logging.info('TPU initialization failed: %s', e)
        raise

    _WaitTillInit()

    with self._graph.as_default(), tf.container(self._container_id):
      with self._cluster, tf.device(self._cluster.job_spec.name):
        self._eval_metrics = metrics.TpuEvalMetrics()

        def TpuTrainStep(*args):
          self._model = self.params.cls(self.params)
          self._model.ConstructFPropBPropGraph()
          per_step_eval_metrics = self._eval_metrics.SetMetrics(
              self._model.GetTask().eval_metrics, args)
          summed_metrics = []
          assert len(per_step_eval_metrics) == len(args)
          for x, y in zip(per_step_eval_metrics, args):
            summed_metrics.append(x + y)
          return summed_metrics + [self._model.GetTask().train_op]

        def TpuTrain():
          loop_result = tf.contrib.tpu.repeat(
              self._steps_per_loop,
              TpuTrainStep,
              inputs=self._eval_metrics.initial_values,
              name='train_loop')
          # Final metrics are the avg across self._steps_per_loop steps.
          return self._eval_metrics.FinalizeMetrics(loop_result)

        batch_parallel_res = tf.contrib.tpu.batch_parallel(
            TpuTrain,
            num_shards=data_parallelism,
            device_assignment=py_utils.GetTpuDeviceAssignment())
        # Get metric result from a single replica; they are all the same here.
        self._tpu_train_ops = [t[0] for t in batch_parallel_res]

      self.initialize_tables = tf.tables_initializer()
      self.enqueue_ops = tf.get_collection(py_utils.ENQUEUE_OPS)
      assert not tf.get_collection(py_utils.CLOSE_QUEUE_OPS)
      tf.logging.info('Trainer number of enqueue ops: %d',
                      len(self.enqueue_ops))

    self._summary_writer = self._CreateSummaryWriter(self._train_dir)

    # Saves the graph def.
    tf.train.write_graph(self._graph.as_graph_def(), self._train_dir,
                         'train.pbtxt')
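
The ComputationShape helper above hard-codes the mapping from a model-parallel split size to a 3-D TPU core topology. The same mapping can be written as a table lookup; the sketch below is a hypothetical rewrite for illustration, not lingvo code.

_COMPUTATION_SHAPES = {
    1: [1, 1, 1],
    2: [1, 1, 2],
    4: [1, 2, 2],
    8: [2, 2, 2],
    16: [4, 2, 2],
}


def computation_shape(split_size):
  # Reject split sizes the table does not cover, mirroring the assert above.
  if split_size not in _COMPUTATION_SHAPES:
    raise ValueError('Model parallelism with %d devices is currently not '
                     'supported.' % split_size)
  return _COMPUTATION_SHAPES[split_size]


assert computation_shape(8) == [2, 2, 2]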
Example #24
0
    def CreateTpuEnqueueOps(self):
        """Create the host-side enqueue ops.

        This should be called in an outer non-TPU context.
        """
        assert not self._tpu_queues, (
            'CreateTpuEnqueueOps should only be called '
            'once.')
        self._tpu_queues = []
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info(
            'CreateTpuEnqueueOps num_splits_per_client={} '
            'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'.
            format(cluster.num_splits_per_client,
                   cluster.num_devices_per_split, num_tpu_hosts,
                   p.use_per_host_infeed))

        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        shards = (cluster.total_worker_devices //
                  num_infeed_hosts) // cluster.num_devices_per_split
        tf.logging.info('shards {}'.format(shards))

        input_ops_list = []
        tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
        tpu_embedding = (tpu_embedding_collection[0]
                         if tpu_embedding_collection else None)

        if num_tpu_hosts > 1 and tpu_embedding is not None:
            if not p.use_per_host_infeed:
                tf.logging.fatal(
                    'TPU Embedding must be used with per_host_infeed with multiple '
                    'TPU host topologies.')

        tpu_emb_input_keys = (list(tpu_embedding.feature_to_config_dict.keys())
                              if tpu_embedding is not None else [])
        tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
        tf.logging.info('num_infeed_hosts: %d', num_infeed_hosts)

        for task_id in range(num_infeed_hosts):
            host_device = '/task:{}/device:CPU:0'.format(task_id)
            with tf.device(host_device):
                self._batch = self.GetPreprocessedInputBatch()
                if isinstance(self._batch, py_utils.NestedMap):
                    # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
                    # Note that when MultiTaskData is used, bucket_keys will be at the
                    # second level of the dictionary.
                    self._batch = self._batch.FilterKeyVal(
                        lambda k, _: not k.endswith('bucket_keys'))
                tf.logging.info('host_device: %s, batch: %r', host_device,
                                self._batch)

                for k, x in self._batch.FlattenItems():
                    assert x.shape.is_fully_defined(), (
                        'Shape must be fully defined: %s: %s' % (k, x))
                    # TODO(cwhipkey): if it's a string (or other type not supported on
                    # TPU), drop it from feeding and on the other end add in an op that
                    # fails if used.
                shapes = self._batch.Transform(lambda x: x.shape).Flatten()
                dtypes = self._batch.Transform(lambda x: x.dtype).Flatten()

                tf.logging.info('host_device: %s infeed shapes: %r',
                                host_device, shapes)
                tf.logging.info('host_device: %s infeed dtypes: %r',
                                host_device, dtypes)

                if p.use_partitioned_infeed_queue:
                    device_assignment = py_utils.GetTpuDeviceAssignment()

                    host_device = device_assignment.host_device(
                        replica=0, job=tf.flags.FLAGS.tf_master)
                    host_id = int(
                        host_device.split('/task:')[1].split('/device:')[0])
                    tf.logging.info('host_id: {} host_device: {}'.format(
                        host_id, host_device))
                    q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
                        number_of_tuple_elements=len(dtypes),
                        device_assignment=device_assignment,
                        host_id=host_id,
                        input_partition_dims=[[p.num_partitions] + [1] *
                                              (len(s) - 1) for s in shapes],
                        tuple_types=dtypes,
                        tuple_shapes=shapes)
                else:
                    q = tpu_feed.InfeedQueue(tuple_types=dtypes,
                                             tuple_shapes=shapes)
                    assert shards is not None
                    q.set_number_of_shards(shards)

                self._tpu_queues.append(q)

                if p.use_partitioned_infeed_queue:
                    input_ops = q.generate_enqueue_ops([self._batch.Flatten()])
                elif p.use_per_host_infeed:
                    # TODO(ylc/zhifengc): Add this to a policy module and test it.
                    def TPUOrdinalFunction(shard_index_in_host):
                        device_assignment = py_utils.GetTpuDeviceAssignment()
                        if device_assignment:
                            # We put both enqueue/dequeue ops at core 0 in each replica.
                            replica = device_assignment.lookup_replicas(
                                task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                            return device_assignment.tpu_ordinal(
                                replica=replica)
                        else:
                            return shard_index_in_host

                    input_ops = q.split_inputs_and_generate_enqueue_ops(
                        self._batch.Flatten(),
                        placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                        tpu_ordinal_function=TPUOrdinalFunction)
                else:
                    input_ops = q.split_inputs_and_generate_enqueue_ops(
                        self._batch.Flatten(),
                        device_assignment=py_utils.GetTpuDeviceAssignment())
                input_ops_list += input_ops

        tf.logging.info('input_ops_list %s', input_ops_list)
        grouped_infeed_op = tf.group(*input_ops_list)
        self._tpu_infeed_op = []
        for _ in range(p.tpu_infeed_parallelism):
            self._tpu_infeed_op.append(grouped_infeed_op)
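
The shard count in CreateTpuEnqueueOps divides the total worker devices by the number of infeed hosts and then by the devices used per split. The arithmetic sketch below uses hypothetical cluster numbers (not taken from the example) to show how the quantities relate.

# Hypothetical cluster: 4 TPU hosts with 8 cores each, 2-way model parallelism.
total_worker_devices = 32
num_tpu_hosts = 4
num_devices_per_split = 2
use_per_host_infeed = True

num_cores_per_host = total_worker_devices // num_tpu_hosts             # 8
num_infeed_hosts = num_tpu_hosts if use_per_host_infeed else 1         # 4
shards = (total_worker_devices // num_infeed_hosts) // num_devices_per_split  # 4

# Each infeed host feeds `shards` replicas, and every replica spans
# num_devices_per_split cores, which covers the whole slice.
assert shards * num_infeed_hosts * num_devices_per_split == total_worker_devices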
Example #25
0
    def BuildTpuSubgraph(self):
        if self._ml_perf_log:
            mlp_log.mlperf_print('global_batch_size',
                                 self._ml_perf.global_batch_size)
            mlp_log.mlperf_print('max_sequence_length',
                                 self._ml_perf.max_sequence_length)
            mlp_log.mlperf_print('opt_name', self._ml_perf.optimizer_name)
            mlp_log.mlperf_print('opt_base_learning_rate',
                                 self._ml_perf.base_learning_rate)
            mlp_log.mlperf_print('opt_learning_rate_warmup_steps',
                                 self._ml_perf.warmup_steps)

        with py_utils.OpportunisticVariableReuseScope(True):
            self._eval_metrics = metrics.TpuEvalMetrics()
            data_parallelism = self.data_parallelism

            def TpuTrainStep():
                """Train a shard of a batch on a single TPU core.

                Do not calculate loss metrics.

                Returns:
                  [train_op].
                """
                self._train_model = self._train_task_params.Instantiate()
                self._model = self._train_model
                self._train_model.ConstructFPropBPropGraph()
                return [self._train_model.GetTask().train_op]

            def TpuTrain():
                loop_result = tpu_training_loop.repeat(
                    self._train_steps_per_loop,
                    TpuTrainStep,
                    inputs=[],
                    name='train_loop')
                return loop_result

        py_utils.ResetStepSeed()

        def _DecodeFn():
            """Decode call to be compiled for TPU."""
            with py_utils.OpportunisticVariableReuseScope(True):
                with cluster_factory.SetEval(True):
                    self._decode_model = self._decode_task_params.Instantiate()
                    self._decode_model_task = self._decode_model.GetTask()
                    if py_utils.use_tpu():
                        input_batch = self._decode_model_task.input_generator.CreateTpuFeeds(
                        )
                    else:
                        input_batch = self._decode_model_task.input_generator.SplitInputBatch(
                            self.cluster.num_splits_per_client)
                    metrics_dict = self._decode_model_task.Decode(input_batch)
                    self.metrics_nm = py_utils.NestedMap(metrics_dict)
                    return self.metrics_nm.Flatten()

        @tpu_function.on_device_training_loop
        def TrainAndDecode():
            with tf.control_dependencies([TpuTrain()]):
                return _DecodeFn()

        self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
            TrainAndDecode,
            num_shards=data_parallelism,
            device_assignment=py_utils.GetTpuDeviceAssignment())

        self.metrics = py_utils.NestedMap(self.metrics_nm)
        self.metrics = self.metrics.Pack(batch_parallel_res)
        return None
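
As in the other decode programs, _DecodeFn returns metrics_nm.Flatten() on the device and the host reassembles the structure with Pack(batch_parallel_res). The dict-based sketch below illustrates that flatten/pack round trip under the assumption of a stable (sorted) key order; it is a stand-in, not lingvo's NestedMap.

def flatten(template):
  # Stand-in for NestedMap.Flatten: flat values in a stable key order.
  keys = sorted(template)
  return keys, [template[k] for k in keys]


def pack(keys, flat_values):
  # Stand-in for NestedMap.Pack: rebuild the structure from flat values.
  return dict(zip(keys, flat_values))


metrics_template = {'num_samples': None, 'topk_ids': None}
keys, _ = flatten(metrics_template)
# split_compile_and_shard yields one flat result list per shard; repack the
# first shard's results here for illustration.
per_shard_results = [[128, [[1, 2, 3]]], [128, [[4, 5, 6]]]]
packed = pack(keys, per_shard_results[0])
assert sorted(packed) == ['num_samples', 'topk_ids']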
Example #26
0
    def init_graph(self, model_params):
        """Builds moe decode graph.

    Args:
      model_params: the hyperparams of the specified model.
    """
        assert self.graph
        self.model_params = model_params
        batch_size = model_params.task.batch_size
        if (hasattr(model_params.task.builder, 'device_mesh_shape')
                and model_params.task.builder.device_mesh_shape):
            num_partitions = np.prod(
                model_params.task.builder.device_mesh_shape)
        else:
            num_partitions = model_params.task.builder.num_devices

        device_order_mode = (model_params.task.train.tpu_device_order_mode
                             or tpu_device_assignment.DeviceOrderMode.AUTO)
        self._init_tpu(num_partitions, device_order_mode)
        assert self.cluster_params  # configured by init_tpu
        self.cluster = self.cluster_params.Instantiate()

        with self.graph.as_default(), self.cluster, tf.device(
                self.cluster.GetPlacer()):
            _ = py_utils.GetOrCreateGlobalStepVar()
            self.heartbeat = tf.constant(np.pi)

            device_assignment = py_utils.GetTpuDeviceAssignment()

            tf.logging.info('Instantiating model')
            model = model_params.Instantiate()
            xformer = model.GetTask()
            self.task = xformer

            self.init_vars_op = tf.global_variables_initializer()
            self.saver = tf.train.Saver(sharded=True,
                                        reshape=self._saver_reshape)

            infeed = self._config_infeed(num_partitions=num_partitions,
                                         device_assignment=device_assignment,
                                         batch_size=batch_size)

            self.outfeed = []

            def decode_fn(*infeed_batch):  # pylint: disable=missing-docstring
                # Length 6 is passed when there is no tgt_mask (e.g. decoding) and
                # length 7 is passed when there is a tgt_mask (e.g. fprop).

                self.outfeed = self._config_outfeed(xformer, infeed_batch)

                with tf.device(tf.tpu.core(0)):
                    outfeed_op = tpu_ops.outfeed_enqueue_tuple(
                        tf.nest.flatten(self.outfeed))

                return [outfeed_op]

            @tpu_function.on_device_training_loop
            def decode_loop_fn():
                if not self.num_batches:
                    infinite_repeat(decode_fn, infeed)
                else:
                    training_loop.repeat(self.num_batches,
                                         decode_fn,
                                         infeed_queue=infeed)

            self.compile_op, self.decode_loop = tpu_lib.split_compile_and_shard(
                decode_loop_fn,
                num_shards=1,
                device_assignment=device_assignment)

            assert self.outfeed
            with tf.device(device_assignment.tpu_device(0, 0)):
                self.outfeed_op = tpu_ops.outfeed_dequeue_tuple(
                    dtypes=[x.dtype for x in tf.nest.flatten(self.outfeed)],
                    shapes=[x.shape for x in tf.nest.flatten(self.outfeed)])
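
The decode graph pairs an on-device outfeed_enqueue_tuple with a host-side outfeed_dequeue_tuple whose dtypes and shapes must match the flattened structure enqueued on the device. The toy queue-based analogue below shows only that pairing; it is pure Python, not the TPU outfeed API.

import queue

_outfeed = queue.Queue()


def device_step(logits, ids):
  # Device side: enqueue a flat tuple, like tf.nest.flatten(self.outfeed).
  _outfeed.put((logits, ids))


def host_dequeue():
  # Host side: dequeue in the same order and restore the structure.
  logits, ids = _outfeed.get()
  return {'logits': logits, 'ids': ids}


device_step([0.1, 0.9], [7])
assert host_dequeue()['ids'] == [7]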