Example #1
0
        def LoopBody(i, *input_arrays):
            """Process outfeed data for a single TpuTrainStep.

            Dequeues one outfeed tuple from every (replica, core) pair of the
            TPU device assignment and writes each of its tensors into the
            matching tf.TensorArray.

            Args:
              i: current loop index.
              *input_arrays: One tf.TensorArray per outfeed tensor.

            Returns:
              i+1 (new index) plus post-write tf.TensorArray handles.
            """
            assignment = py_utils.GetTpuDeviceAssignment()
            assert assignment
            # Each dequeue op is placed on the host device associated with the
            # core it reads from; results are collected replica-major.
            dequeued = []
            for replica in range(assignment.num_replicas):
                for core in range(assignment.num_cores_per_replica):
                    host = assignment.host_device(replica, core)
                    ordinal = assignment.tpu_ordinal(replica, core)
                    with tf.device(host):
                        dequeued.append(
                            tpu_ops.outfeed_dequeue_tuple(
                                tensor_types,
                                tensor_shapes,
                                device_ordinal=ordinal))
            # Step i writes its per-device results starting at index
            # i * num_devices of every TensorArray.
            base = i * num_devices
            written = []
            for tensor_idx, array in enumerate(input_arrays):
                for device_idx, outfeed in enumerate(dequeued):
                    array = array.write(base + device_idx, outfeed[tensor_idx])
                written.append(array)
            return (i + 1,) + tuple(written)
Example #2
0
 def create_dequeue_ops(host_id):
   """Create outfeed dequeue ops.

   Builds one `outfeed_dequeue_tuple` per shard on the given host
   (`FLAGS.tpu_num_shards_per_host` of them, using device ordinals
   0..n-1) and concatenates each outfeed tensor across those shards.

   Args:
     host_id: host index, resolved to a device through `self._get_host`
       and `low_level_utils.device_for_host`.

   Returns:
     A list with one tensor per entry of `self.outfeed_tensors`, each being
     that tensor's per-shard results concatenated along axis 0.
   """
   dequeue_ops = []
   tensor_dtypes = []
   tensor_shapes = []
   for v in self.outfeed_tensors:
     dequeue_ops.append([])
     tensor_dtypes.append(v.dtype)
     tensor_shapes.append(v.shape)
   for i in range(FLAGS.tpu_num_shards_per_host):
     with tf.device(
         low_level_utils.device_for_host(self._get_host(host_id))):
       shard_outputs = tpu_ops.outfeed_dequeue_tuple(
           dtypes=tensor_dtypes, shapes=tensor_shapes, device_ordinal=i)
       for j, item in enumerate(shard_outputs):
         dequeue_ops[j].append(item)
   # One concat per outfeed tensor, driven by the accumulators built above
   # rather than by any name left over from the shard loop.
   return [tf.concat(per_tensor, axis=0) for per_tensor in dequeue_ops]
Example #3
0
 def _OutfeedDequeue(self):
     """Collect outfeed dequeue from all devices.

     Returns:
       A list with one tensor per entry of `self.metrics_nm.Flatten()`: that
       outfeed tensor from every dequeued (replica, core) pair, concatenated
       along axis 0 in replica-major order.
     """
     # Flatten once: every device dequeues with the same dtypes and shapes.
     flat_metrics = self.metrics_nm.Flatten()
     dtypes = [x.dtype for x in flat_metrics]
     shapes = [x.shape for x in flat_metrics]
     # One independent list per outfeed; `[[]] * n` would alias a single
     # list n times and break as soon as anything mutates it in place.
     outfeed_ops = [[] for _ in flat_metrics]
     device_assignment = py_utils.GetTpuDeviceAssignment()
     assert device_assignment
     # Under SPMD only core 0 of each replica gets a dequeue op.
     num_cores_per_replica = (1 if self.spmd else
                              device_assignment.num_cores_per_replica)
     for replica in range(device_assignment.num_replicas):
         for core in range(num_cores_per_replica):
             with tf.device(device_assignment.host_device(replica, core)):
                 outfeeds_per_core = tpu_ops.outfeed_dequeue_tuple(
                     dtypes=dtypes,
                     shapes=shapes,
                     device_ordinal=device_assignment.tpu_ordinal(
                         replica, core))
                 for idx_outfeed, out_feed in enumerate(outfeeds_per_core):
                     outfeed_ops[idx_outfeed].append(out_feed)
     return [tf.concat(per_outfeed, 0) for per_outfeed in outfeed_ops]
 def create_dequeue_ops(host_id):
     """Create outfeed dequeue ops.

     Dequeues one outfeed tuple per replica on `host_id` (all placed on that
     host's device) and concatenates the per-replica results along axis 0.

     Args:
       host_id: host index, resolved to a device through `self._get_host`
         and `utils.device_for_host`.

     Returns:
       A single tensor built by `tf.concat` over the per-replica results.
     """
     outfeed_specs = list(self.outfeed_tensors)
     tensor_dtypes = [spec.dtype for spec in outfeed_specs]
     tensor_shapes = [spec.shape for spec in outfeed_specs]
     per_replica = []
     with tf.device(utils.device_for_host(self._get_host(host_id))):
         for i in range(self.replicas_per_worker):
             if not self.use_spatial_partition:
                 ordinal = i
             else:
                 # Spatial partitioning: map the host-local replica index to
                 # the TPU ordinal of that replica's logical core 0.
                 replica_id = self.device_assignment.lookup_replicas(
                     host_id, 0)[i]
                 ordinal = self.device_assignment.tpu_ordinal(
                     replica=replica_id, logical_core=0)
             outfeed = tpu_ops.outfeed_dequeue_tuple(
                 dtypes=tensor_dtypes,
                 shapes=tensor_shapes,
                 device_ordinal=ordinal)
             if len(outfeed) != 2:
                 # no padding, only detections are in the outfeed
                 # NOTE(review): the whole tuple (not outfeed[0]) is appended,
                 # so tf.concat auto-packs it and this branch likely yields a
                 # tensor with an extra leading per-replica axis compared to
                 # the padded branch. Kept as-is; confirm what callers expect.
                 per_replica.append(outfeed)
                 continue
             # 2 outfeed tensors
             #   is_pad: [batch]
             #   detections: [batch, 200, 7]
             # The rank-3 tensor is the detections, whichever slot it is in.
             if outfeed[0].shape.ndims == 3:
                 detections, is_pad = outfeed
             else:
                 is_pad, detections = outfeed
             # Keep only the leading batch rows not flagged in is_pad; this
             # assumes padded rows sit at the end of the batch.
             num_rows = tf.shape(is_pad)[0]
             num_pad = tf.reduce_sum(tf.cast(is_pad, tf.int32))
             num_non_pad = num_rows - num_pad
             per_replica.append(
                 tf.slice(detections, [0, 0, 0], [num_non_pad, -1, -1]))
         dequeue_ops = tf.concat(per_replica, axis=0)
     return dequeue_ops
Example #5
0
  def _OutfeedDequeue(self):
    """Collect outfeed dequeue from all devices.

    Returns:
      A list with one tensor per entry of `self.decode_nm.Flatten()`: the
      decoder outputs from every dequeued (replica, core) pair concatenated
      along the first dimension (usually corresponds to batch size).
    """
    # Flatten once: every device dequeues with the same dtypes and shapes.
    flat_decode = self.decode_nm.Flatten()
    dtypes = [x.dtype for x in flat_decode]
    shapes = [x.shape for x in flat_decode]
    # One independent list per decode tensor; `[[]] * n` would alias a single
    # list n times and break as soon as anything mutates it in place.
    outfeed_ops = [[] for _ in flat_decode]
    device_assignment = py_utils.GetTpuDeviceAssignment()
    assert device_assignment
    # Under SPMD only core 0 of each replica gets a dequeue op.
    num_cores_per_replica = (1 if self.spmd else
                             device_assignment.num_cores_per_replica)
    for replica in range(device_assignment.num_replicas):
      for core in range(num_cores_per_replica):
        with tf.device(device_assignment.host_device(replica, core)):
          outfeeds_per_core = tpu_ops.outfeed_dequeue_tuple(
              dtypes=dtypes,
              shapes=shapes,
              device_ordinal=device_assignment.tpu_ordinal(replica, core))
          for idx_outfeed, out_feed in enumerate(outfeeds_per_core):
            outfeed_ops[idx_outfeed].append(out_feed)
    return [tf.concat(per_outfeed, axis=0) for per_outfeed in outfeed_ops]
Example #6
0
  def _run_eval_phase():
    """The real function that runs the evaluation phase.

    Builds the eval input pipeline and computation, then runs
    `FLAGS.num_eval_iterations_per_loop` eval steps inside a
    `tf.train.MonitoredSession`, passing each step's results to
    `unet.PostProcessor.record` together with `FLAGS.pred_output_dir`.

    With `FLAGS.use_tpu` the computation is `tpu.replicate`d over
    `mesh_context.num_cores` with a SIMD infeed, and host-side outfeed dequeue
    ops are built for every (host, core); otherwise a placement-based input
    reader places the computation. `mesh_context` is not a parameter and is
    read from an outer scope.
    """
    # Setup input pipeline.
    ds_creator = unet.get_dataset_creator('eval')
    mtf_shapes = unet.get_input_mtf_shapes('eval')

    model_eval_fn, eval_hooks, output_dtypes_shapes = _get_model_fn(
        'eval', mesh_context)

    if FLAGS.use_tpu:
      assert mesh_context.device_assignment
      assert mesh_context.num_cores
      simd_input_reader = input_reader.SimdMeshImplInputReader(
          mesh_context.mesh_impl, ds_creator, mtf_shapes, is_eval_mode=True)
      # One empty argument list per core; the computation's data comes from
      # the infeed queue instead.
      eval_computation = tpu.replicate(
          computation=model_eval_fn,
          inputs=[[]] * mesh_context.num_cores,
          infeed_queue=simd_input_reader.infeed_queue,
          device_assignment=mesh_context.device_assignment)

      # Presumably populated while tpu.replicate traced model_eval_fn above
      # — confirm against _get_model_fn.
      output_dtypes, output_shapes = output_dtypes_shapes.get()
      outfeed_dequeue_ops = []

      # Create outfeed_dequeue_ops.
      # One dequeue per (host, core), appended host-major, so entry 0 is
      # host 0 / core 0.
      for host_id in range(mesh_context.num_hosts):
        # pylint: disable=protected-access
        with ops.device(input_reader._host_id_to_tf_device(
            host_id, external_worker=True)):
          for device_ordinal in range(mesh_context.num_cores_per_host):
            outfeed_dequeue_op = tpu_ops.outfeed_dequeue_tuple(
                dtypes=output_dtypes,
                shapes=output_shapes,
                device_ordinal=device_ordinal)

            # We don't need output other than from core 0.
            # Only the first dequeue keeps its full tensors; every later one
            # is still created (and run below) but reduced to scalar means,
            # presumably so each core's outfeed is drained — confirm.
            if outfeed_dequeue_ops:
              outfeed_dequeue_ops.append(
                  [tf.reduce_mean(x) for x in outfeed_dequeue_op])
            else:
              outfeed_dequeue_ops.append(outfeed_dequeue_op)

    else:
      # NOTE(review): this eval-phase reader passes is_eval_mode=False while
      # the TPU branch above passes True — confirm this is intentional.
      placement_input_reader = input_reader.PlacementMeshImplInputReader(
          mesh_context.mesh_impl, ds_creator, mtf_shapes, is_eval_mode=False)
      eval_computation = placement_input_reader.gpu_placement(model_eval_fn)

    ###########################################################
    # Evaluation.
    master_to_slice_hook, _ = eval_hooks.get()
    ckpt_loader_hook = _CkptLoaderHook()
    all_hooks = [ckpt_loader_hook, master_to_slice_hook]

    # flush_summary only exists when summaries are enabled; every use below
    # is guarded by the same flag.
    if FLAGS.write_summary:
      flush_summary = contrib_summary.flush()

    with tf.train.MonitoredSession(
        session_creator=tf.train.ChiefSessionCreator(
            master=FLAGS.master,
            config=tf.ConfigProto(allow_soft_placement=True)),
        hooks=all_hooks) as sess:

      if FLAGS.write_summary:
        contrib_summary.initialize(session=sess)

      # Start feeding input for the chosen path.
      if FLAGS.use_tpu:
        simd_input_reader.start_infeed_thread(
            sess, FLAGS.num_eval_iterations_per_loop)
      else:
        placement_input_reader.initialize(sess)

      pprocessor = unet.PostProcessor()
      for step in range(FLAGS.num_eval_iterations_per_loop):
        # Only get results from the 0-th core.
        if FLAGS.use_tpu:
          # Run one replicated step, then run every dequeue op and keep only
          # entry 0 (host 0 / core 0), the one with un-reduced outputs.
          sess.run(eval_computation)
          results = sess.run(outfeed_dequeue_ops)[0]
        else:
          results = sess.run(eval_computation)
        pprocessor.record(results, FLAGS.pred_output_dir)

        if FLAGS.write_summary:
          sess.run(flush_summary)
        tf.logging.info('eval steps: {}'.format(step))
      pprocessor.finish()
Example #7
0
    def init_graph(self, model_params):
        """Builds moe decode graph.

        Sets up the TPU cluster, then, inside `self.graph`, instantiates the
        model and its task, creates the variable initializer and saver,
        configures infeed, builds the TPU decode loop (which enqueues decode
        outputs to the outfeed) and the host-side outfeed dequeue op.

        Args:
          model_params: the hyperparams of the specified model.
        """
        assert self.graph
        self.model_params = model_params
        batch_size = model_params.task.batch_size
        # Partition count: product of the builder's device mesh shape when one
        # is set, otherwise the builder's num_devices.
        if (hasattr(model_params.task.builder, 'device_mesh_shape')
                and model_params.task.builder.device_mesh_shape):
            num_partitions = np.prod(
                model_params.task.builder.device_mesh_shape)
        else:
            num_partitions = model_params.task.builder.num_devices

        # Fall back to AUTO ordering when the train params leave it falsy.
        device_order_mode = (model_params.task.train.tpu_device_order_mode
                             or tpu_device_assignment.DeviceOrderMode.AUTO)
        self._init_tpu(num_partitions, device_order_mode)
        assert self.cluster_params  # configured by init_tpu
        self.cluster = self.cluster_params.Instantiate()

        with self.graph.as_default(), self.cluster, tf.device(
                self.cluster.GetPlacer()):
            _ = py_utils.GetOrCreateGlobalStepVar()
            # NOTE(review): not used in this method; presumably fetched
            # elsewhere as a liveness check — confirm.
            self.heartbeat = tf.constant(np.pi)

            device_assignment = py_utils.GetTpuDeviceAssignment()

            tf.logging.info('Instantiating model')
            model = model_params.Instantiate()
            xformer = model.GetTask()
            self.task = xformer

            self.init_vars_op = tf.global_variables_initializer()
            self.saver = tf.train.Saver(sharded=True,
                                        reshape=self._saver_reshape)

            infeed = self._config_infeed(num_partitions=num_partitions,
                                         device_assignment=device_assignment,
                                         batch_size=batch_size)

            self.outfeed = []

            def decode_fn(*infeed_batch):  # pylint: disable=missing-docstring
                # Length 6 is passed when there is no tgt_mask (e.g. decoding) and
                # length 7 is passed when there is a tgt_mask (e.g. fprop).

                # Stored on self so the host-side dequeue built after the loop
                # can read the outfeed dtypes and shapes.
                self.outfeed = self._config_outfeed(xformer, infeed_batch)

                # Enqueue the flattened outfeed structure from logical core 0.
                with tf.device(tf.tpu.core(0)):
                    outfeed_op = tpu_ops.outfeed_enqueue_tuple(
                        tf.nest.flatten(self.outfeed))

                return [outfeed_op]

            @tpu_function.on_device_training_loop
            def decode_loop_fn():
                # Falsy num_batches (e.g. None or 0) selects infinite_repeat
                # (presumably an unbounded loop, per its name); otherwise
                # decode_fn is repeated num_batches times.
                if not self.num_batches:
                    infinite_repeat(decode_fn, infeed)
                else:
                    training_loop.repeat(self.num_batches,
                                         decode_fn,
                                         infeed_queue=infeed)

            self.compile_op, self.decode_loop = tpu_lib.split_compile_and_shard(
                decode_loop_fn,
                num_shards=1,
                device_assignment=device_assignment)

            # decode_fn ran (and set self.outfeed) while the loop above was
            # built; the dequeue must match its flattened dtypes/shapes.
            assert self.outfeed
            # Dequeue on TPU device (0, 0); no explicit device_ordinal is given.
            with tf.device(device_assignment.tpu_device(0, 0)):
                self.outfeed_op = tpu_ops.outfeed_dequeue_tuple(
                    dtypes=[x.dtype for x in tf.nest.flatten(self.outfeed)],
                    shapes=[x.shape for x in tf.nest.flatten(self.outfeed)])
Example #8
0
    def initialize(self,
                   train_input_fn,
                   eval_input_fn,
                   model_fn,
                   train_batch_size,
                   eval_batch_size,
                   input_partition_dims=None,
                   init_fn=None,
                   train_has_labels=True,
                   eval_has_labels=True,
                   params=None,
                   num_partitions=None):
        """Build graphs for the TPU device and the input pipelines.

        Builds the device assignment and the per-host infeed enqueue ops, then
        the sharded train/eval loop. When `self.eval_steps > 0` it also builds
        a separate output graph that dequeues prediction outfeeds. Finally it
        creates the sessions, restores or initializes variables, and starts
        the train/eval loop on a background thread.

        Args:
          train_input_fn: input function used to build training enqueue ops.
          eval_input_fn: input function used to build eval enqueue ops (only
            when `self.eval_steps > 0`).
          model_fn: called as `model_fn(features, labels, True)` in the train
            step; the first of its two return values is the train op.
          train_batch_size: global train batch size, split evenly across
            `self.num_replicas`.
          eval_batch_size: global eval batch size, split evenly across
            `self.num_replicas`.
          input_partition_dims: optional partition dims; their product sets
            the number of cores per replica.
          init_fn: optional callable run inside `self.graph` before variables
            are restored or initialized.
          train_has_labels: whether training inputs are `(features, labels)`.
          eval_has_labels: stored as `self.eval_has_labels`.
          params: optional dict passed to the enqueue builders. A dict passed
            by the caller is mutated in place ("dataset_num_shards",
            "dataset_index", "batch_size").
          num_partitions: cores per replica when `input_partition_dims` is
            not given.
        """
        # Cores per replica: product of the partition dims, else
        # num_partitions, else 1.
        if input_partition_dims:
            num_cores_per_replica = functools.reduce(operator.mul,
                                                     input_partition_dims)
        elif num_partitions:
            num_cores_per_replica = num_partitions
        else:
            num_cores_per_replica = 1

        self.device_assignment = device_assignment.device_assignment(
            topology=self.device_topology,
            computation_shape=_NUM_CORES_TO_COMPUTATION_SHAPE[
                num_cores_per_replica],
            num_replicas=self.num_replicas)
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.eval_has_labels = eval_has_labels
        self.model_fn = model_fn

        if params is None:
            params = {}
        params[
            "dataset_num_shards"] = self.num_replicas // FLAGS.replicas_per_host
        per_replica_train_batch_size = train_batch_size // self.num_replicas
        per_replica_eval_batch_size = eval_batch_size // self.num_replicas
        # One input pipeline per host: train always, eval only when needed.
        for i in range(self.num_replicas // FLAGS.replicas_per_host):
            params["dataset_index"] = i
            params["batch_size"] = per_replica_train_batch_size
            self.build_enqueue_ops(train_input_fn, True, input_partition_dims,
                                   params)
            if self.eval_steps > 0:
                params["batch_size"] = per_replica_eval_batch_size
                self.build_enqueue_ops(eval_input_fn, False,
                                       input_partition_dims, params)

        def train_step(_):
            """One train step."""
            inp = self.infeed_op[True].generate_dequeue_op()
            flatten_structure = tf.nest.flatten(self.feature_structure[True])
            # Slice every dequeued tensor to the shape of its counterpart in
            # the recorded feature structure.
            inp = [
                tf.slice(i, [0] * i.shape.ndims, j.shape)
                for i, j in zip(inp, flatten_structure)
            ]
            if train_has_labels:
                features, labels = tf.nest.pack_sequence_as(
                    self.feature_structure[True], inp)
            else:
                features = tf.nest.pack_sequence_as(
                    self.feature_structure[True], inp)
                labels = None
            self.maybe_add_embedding_features(features, True)
            train_op, _ = model_fn(features, labels, True)
            embedding_train_op = self.maybe_get_embedding_train_op()
            with tf.device(device_for_tpu_core(self.get_host(0))):
                with tf.control_dependencies([train_op, embedding_train_op]):
                    return tf.constant(0)

        @tpu_function.on_device_training_loop
        def train_loop():
            return training_loop.repeat(self.iterations_per_loop, train_step,
                                        tf.constant(0))

        def train_eval_step():
            # Run a full train loop, then the eval loop (if any) after it.
            with tf.control_dependencies(train_loop()):
                if self.eval_steps > 0:
                    return self.eval_loop()
                else:
                    return tf.no_op()

        @on_device_train_and_eval_loops
        def train_eval_loop():
            return training_loop.repeat(self.max_train_iterations,
                                        train_eval_step)

        with self.graph.as_default():
            (self.train_eval_op, ) = tpu.shard(
                train_eval_loop,
                inputs=[],
                num_shards=self.num_replicas,
                outputs_from_all_shards=False,
                device_assignment=self.device_assignment)
            if FLAGS.model_dir:
                tf.io.write_graph(self.graph, FLAGS.model_dir, "graph.pbtxt")

        output_graph = tf.Graph()
        if self.eval_steps > 0:
            with output_graph.as_default():
                flatten_output = tf.nest.flatten(self.predict_output)
                self.dequeue_ops = [[] for _ in flatten_output]
                tensor_dtypes = [v.dtype for v in flatten_output]
                tensor_shapes = [v.shape for v in flatten_output]
                # Position of the padding-flag tensor among the flattened
                # outputs, or -1 when predictions carry no _IS_PADDED entry.
                is_padded_index = flatten_output.index(
                    self.predict_output[_IS_PADDED]
                ) if _IS_PADDED in self.predict_output else -1
                for i in range(self.num_replicas // FLAGS.replicas_per_host):
                    with tf.device(device_for_host(self.get_host(i))):
                        host_dequeue_ops = [[] for _ in flatten_output]
                        for j in range(FLAGS.replicas_per_host):
                            replica_id = self.device_assignment.lookup_replicas(
                                i, 0)[j]
                            ordinal = self.device_assignment.tpu_ordinal(
                                replica=replica_id, logical_core=0)
                            replica_outputs = tpu_ops.outfeed_dequeue_tuple(
                                dtypes=tensor_dtypes,
                                shapes=tensor_shapes,
                                device_ordinal=ordinal)
                            if is_padded_index >= 0:
                                # Keep only the leading rows not flagged as
                                # padding; assumes padded rows come last.
                                num_non_pad = tf.shape(
                                    replica_outputs[is_padded_index]
                                )[0] - tf.reduce_sum(
                                    tf.cast(replica_outputs[is_padded_index],
                                            tf.int32))
                                replica_outputs = [
                                    tf.slice(k, [0] * k.shape.ndims,
                                             [num_non_pad] + [-1] *
                                             (k.shape.ndims - 1))
                                    for k in replica_outputs
                                ]
                            for k, item in enumerate(replica_outputs):
                                host_dequeue_ops[k].append(item)
                        # Walk the flattened outputs, which size both
                        # accumulators, not the top-level keys of
                        # self.predict_output (which differ when it is nested).
                        for k, host_ops in enumerate(host_dequeue_ops):
                            self.dequeue_ops[k].append(
                                tf.concat(host_ops, axis=0))

        self.sess = tf.Session(self.master,
                               graph=self.graph,
                               config=self.config)
        for is_training in [True, False]:
            if is_training or self.eval_steps > 0:
                for i in range(self.num_input_graphs):
                    with self.input_graph[is_training][i].as_default():
                        self.input_sess[is_training][i] = tf.Session(
                            self.master,
                            graph=self.input_graph[is_training][i],
                            config=self.config)
                        self.input_sess[is_training][i].run(
                            self.dataset_initializer[is_training][i])
        self.output_sess = tf.Session(self.master,
                                      graph=output_graph,
                                      config=self.config)

        with self.graph.as_default():
            _ = tf.train.get_or_create_global_step()
            if init_fn:
                init_fn()
            checkpoint_path = tf.train.latest_checkpoint(
                FLAGS.model_dir) if FLAGS.model_dir else None
            # Restore only when requested and a checkpoint exists; otherwise
            # start from freshly initialized variables.
            if FLAGS.restore_checkpoint and checkpoint_path:
                tf.train.Saver().restore(self.sess, checkpoint_path)
            else:
                self.sess.run(tf.global_variables_initializer())
                self.sess.run(tf.local_variables_initializer())
            self.maybe_load_embedding_vars()
            self.global_step = self.sess.run(
                tf.train.get_global_step(self.graph))

        def train_eval_thread_fn(sess, train_eval_op):
            sess.run([train_eval_op])

        # Start the just in time compilation of the model function
        self.train_eval_thread = threading.Thread(target=train_eval_thread_fn,
                                                  args=(self.sess,
                                                        self.train_eval_op))
        self.train_eval_thread.start()

        # Sleep for JIT compilation to finish
        time.sleep(FLAGS.sleep_after_init)