Example #1
 def __init__(self, params):
     super(SeqLayer, self).__init__(params)
     p = self.params
     assert p.name
     num_cells = len(p.cell_tpl)
     self._before_layers = []
     self._cells = []
     before_tpl_device = ''
     cell_devices = [''] * num_cells
     if py_utils.use_tpu():
         cluster = self.cluster
         before_tpl_device = cluster.WorkerDeviceInModelSplit(0)
         cell_devices = [
             cluster.WorkerDeviceInModelSplit(i) for i in range(num_cells)
         ]
     for l in p.before_tpl:
         with tf.device(before_tpl_device):
             assert l.name
             self.CreateChild(l.name, l)
             self._before_layers.append((l.name, self.children[l.name]))
     for i, l in enumerate(p.cell_tpl):
         with tf.device(cell_devices[i]):
             assert l.name
             self.CreateChild(l.name, l)
             self._cells.append((l.name, self.children[l.name]))
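For orientation, here is a minimal stand-alone sketch of the same placement pattern: constructing each cell's ops under tf.device for a round-robin list of worker devices. The device strings, loop count, and shapes below are hypothetical, not taken from the example above.

import tensorflow as tf

# Assumed two-device worker topology; purely illustrative device strings.
cell_devices = ['/job:worker/task:0/device:GPU:0',
                '/job:worker/task:0/device:GPU:1']

graph = tf.Graph()
with graph.as_default():
  for i in range(4):
    # Place each cell's ops on a device chosen round-robin.
    with tf.device(cell_devices[i % len(cell_devices)]):
      tf.zeros([8, 8], name='cell_%d_w' % i)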
Example #2
      def TrainAndDecodeEpoch(i, host_device):
        """Train and decode infeed for an epoch.

        Args:
          i: host index.
          host_device: host device string

        Returns:
          Decode with control deps on train node.
        """
        train_infeed_fn = lambda: self._train_input.CreatePerHostEnqueueOp(i)
        decode_infeed_fn = lambda: self._decode_input.CreatePerHostEnqueueOp(i)
        tf.logging.info('self._train_steps_per_loop: %d',
                        self._train_steps_per_loop)
        tf.logging.info('self._decode_steps_per_loop: %d',
                        self._decode_steps_per_loop)
        train = wrap_computation_in_while_loop(train_infeed_fn,
                                               self._train_steps_per_loop,
                                               host_device)
        with tf.device(host_device):
          with tf.control_dependencies([train]):
            decode = wrap_computation_in_while_loop(decode_infeed_fn,
                                                    self._decode_steps_per_loop,
                                                    host_device)
        return decode
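A compact sketch of the sequencing idea above (decode gated on train via tf.control_dependencies), with trivial stand-in ops instead of the infeed while-loops; everything here is hypothetical.

import tensorflow.compat.v1 as tf

graph = tf.Graph()
with graph.as_default():
  step = tf.Variable(0, name='step')
  train = step.assign_add(1)              # stand-in for the train infeed loop
  with tf.device('/cpu:0'):
    with tf.control_dependencies([train]):
      # 'decode' only runs after 'train' has executed.
      decode = tf.identity(step)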
Example #3
    def LoopBody(i, *input_arrays):
      """Process outfeed data for a single TpuTrainStep.

      Args:
        i: current loop index.
        *input_arrays: One tf.TensorArray per outfeed tensor.

      Returns:
        i+1 (new index) plus post-write tf.TensorArray handles.
      """
      # Outfeed ops execute on each JF node, so they must be located on the
      # nodes.
      outfeed_devices = []
      device_assignment = py_utils.GetTpuDeviceAssignment()
      assert device_assignment
      for replica in range(device_assignment.num_replicas):
        for core in range(device_assignment.num_cores_per_replica):
          with tf.device(device_assignment.host_device(replica, core)):
            outfeed_devices.append(
                tpu_ops.outfeed_dequeue_tuple(
                    tensor_types,
                    tensor_shapes,
                    device_ordinal=device_assignment.tpu_ordinal(replica,
                                                                 core)))
      offset = i * num_devices
      output_arrays = list(input_arrays)
      # Each output_array holds a different per-example tensor. We get results
      # for each tensor from each TPU for each TpuTrainStep call.
      for j in range(len(output_arrays)):
        for k in range(len(outfeed_devices)):
          output_arrays[j] = output_arrays[j].write(offset + k,
                                                    outfeed_devices[k][j])

      return tuple([i + 1] + output_arrays)
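The loop body above follows the standard accumulate-into-TensorArray pattern. A self-contained sketch with a dummy value standing in for the outfeed dequeue:

import tensorflow as tf

num_steps = 4
ta = tf.TensorArray(tf.float32, size=num_steps)

def body(i, ta):
  value = tf.cast(i, tf.float32) * 2.0   # stand-in for a dequeued tensor
  return i + 1, ta.write(i, value)

_, ta_out = tf.while_loop(lambda i, _: i < num_steps, body,
                          [tf.constant(0), ta])
result = ta_out.stack()                  # [0., 2., 4., 6.]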
Example #4
    def FProp(self, theta, *args):
        """Round-robin every children cells in cell_tpl among worker devices.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      *args: Input args

    Returns:
      A list contains one tensor of [batch_size, feature_height, feature_width,
        channel].
    """
        num_layers = len(self.params.cell_tpl)
        cluster = self.cluster

        for (name, l) in self._before_layers:
            l_theta = theta[name]
            args = _ToTuple(args)
            args = l.FProp(l_theta, *args)
        for i in range(num_layers):
            with tf.device(cluster.WorkerDeviceInModelSplit(i)):
                cell_name, cell = self._cells[i]
                args = _ToTuple(args)
                args = cell.FProp(theta[cell_name], *args)

        return args
Example #5
  def FProp(self, theta, *args):
    """FProp through multiple devices in the split.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      *args: A tuple of Tensors (one or more). Every tensor's first dimension is
        the same (the batch dimension).

    Returns:
      The sub layer's output.
    """
    p = self.params
    with tf.name_scope(p.name):
      assert all(isinstance(x, tf.Tensor) for x in args)
      cluster = self.cluster
      num = cluster.num_devices_per_split
      if num == 1:
        return self.sub.FProp(theta.sub, *args)
      inps = py_utils.SplitRecursively(list(args), num, axis=0)
      outs = []
      for i, xs in enumerate(inps):
        device = cluster.WorkerDeviceInModelSplit(i)
        tf.logging.info('%d on device %s', i, device)
        with tf.device(device):
          ys = self.sub.FProp(theta.sub, *xs)
          if isinstance(ys, tuple):
            outs += [list(ys)]
          else:
            outs += [ys]  # ys is a single tensor
      ret = py_utils.ConcatRecursively(outs, axis=0)
      if isinstance(ret, list):
        return tuple(ret)
      else:
        return ret  # ys is a single tensor
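The same split/compute/concat idea reduced to plain TensorFlow; data_parallel_apply and the device list are illustrative names, not part of the library code above.

import tensorflow as tf

def data_parallel_apply(fn, x, devices):
  """Splits x on the batch dim, runs fn on each device, concats the results."""
  splits = tf.split(x, len(devices), axis=0)
  outs = []
  for device, part in zip(devices, splits):
    with tf.device(device):
      outs.append(fn(part))
  return tf.concat(outs, axis=0)

# e.g. data_parallel_apply(tf.nn.relu, tf.random.normal([8, 4]), ['/cpu:0'] * 2)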
Example #6
  def _OutfeedDequeue(self):
    """Collect outfeed dequeue from all devices."""
    num_outfeeds = len(self.metrics_nm.Flatten())
    outfeed_dicts = []
    concat_lists = {}
    # Hard-coding for Transformer/MLPerf.
    keys = ['target_ids', 'eval_weight', 'tlen', 'top_ids', 'top_lens']
    concat_dict = {}
    for key in keys:
      concat_lists[key] = []

    device_assignment = py_utils.GetTpuDeviceAssignment()
    assert device_assignment
    for replica in range(device_assignment.num_replicas):
      num_cores_per_replica = 1 if self.spmd else (
          device_assignment.num_cores_per_replica)
      for core in range(num_cores_per_replica):
        with tf.device(device_assignment.host_device(replica, core)):
          outfeeds_per_core = tpu_ops.outfeed_dequeue_tuple(
              dtypes=[x.dtype for x in self.metrics_nm.Flatten()],
              shapes=[x.shape for x in self.metrics_nm.Flatten()],
              device_ordinal=device_assignment.tpu_ordinal(replica, core))
          packed = tf.nest.pack_sequence_as(self.metrics_nm, outfeeds_per_core)
          outfeed_dict = self._decode_model_task.PostProcessDecodeHost(packed)
          for key in keys:
            concat_lists[key].append(outfeed_dict[key])

    for key in keys:
      concat_dict[key] = tf.concat(concat_lists[key], 0)
    return concat_dict
Example #7
    def CreateTpuEmbeddingEnqueueOps(self):
        """Creates the TpuEmbedding enqueue ops on the host.

        Note that this must be called after the instantiation of the
        monolithic TPUEmbeddingLayer.
        """
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
        tpu_embedding = (tpu_embedding_collection[0]
                         if tpu_embedding_collection else None)

        enqueue_ops = []

        if num_tpu_hosts > 1 and tpu_embedding is not None:
            if not p.use_per_host_infeed:
                tf.logging.fatal(
                    'TPU Embedding must be used with per_host_infeed with multiple '
                    'TPU host topologies.')
        tpu_emb_input_keys = (list(tpu_embedding.feature_to_config_dict.keys())
                              if tpu_embedding is not None else [])
        tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
        if not tpu_embedding:
            return

        for task_id in range(num_infeed_hosts):
            host_device = '/task:{}/device:CPU:0'.format(task_id)
            with tf.device(host_device):
                if isinstance(self._batch, py_utils.NestedMap):
                    # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
                    # Note that when MultiTaskData is used, bucket_keys will be at the
                    # second level of the dictionary.
                    self._batch = self._batch.FilterKeyVal(
                        lambda k, _: not k.endswith('bucket_keys'))
                tf.logging.info('host_device: %s, batch: %r', host_device,
                                self._batch)

                enqueue_dict_per_core = [
                    {} for _ in range(tpu_embedding.num_cores_per_host)
                ]
                num_cores_per_host = tpu_embedding.num_cores_per_host
                for key in tpu_emb_input_keys:
                    feat = self._batch[key]
                    tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
                    for core, split in enumerate(tpu_emb_feat_splitted):
                        # Dense to sparse. Note the assumption of a padding id.
                        sample_indices = tf.where(tf.not_equal(split, -1))
                        embedding_indices = tf.gather_nd(split, sample_indices)
                        enqueue_data = tpu_embedding_lib.EnqueueData(
                            embedding_indices, sample_indices)
                        enqueue_dict_per_core[core][key] = enqueue_data
                enqueue_ops += tpu_embedding.generate_enqueue_ops(
                    enqueue_dict_per_core)
        self._tpu_infeed_op.append(tf.group(*enqueue_ops))
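The dense-to-sparse conversion inside the enqueue loop can be illustrated in isolation; the padding id of -1 is the same assumption made in the code above, and the feature values are made up.

import tensorflow as tf

dense_ids = tf.constant([[3, 7, -1],
                         [5, -1, -1]])
# Positions of real (non-padding) ids ...
sample_indices = tf.where(tf.not_equal(dense_ids, -1))       # [[0,0],[0,1],[1,0]]
# ... and their values, as fed to tpu_embedding_lib.EnqueueData.
embedding_indices = tf.gather_nd(dense_ids, sample_indices)  # [3, 7, 5]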
Example #8
 def _DecodeStep():
   """Decode call to be compiled for TPU."""
   input_batch = self._model_task.input_generator.TpuDequeueBatch()
   metrics_dict = self._model_task.Decode(input_batch)
   self.metrics_nm = py_utils.NestedMap(metrics_dict)
   device = tpu.core(0) if self.spmd else ''
   with tf.device(device):
     outfeed_enqueue = tpu_ops.outfeed_enqueue_tuple(
         self.metrics_nm.Flatten())
     return [outfeed_enqueue]
Example #9
      def wrap_computation_in_while_loop(op_fn, n, host_device):
        """Wraps the ops generated by `op_fn` in tf.while_loop."""

        def computation(i):
          ops = op_fn()
          if not isinstance(ops, list):
            ops = [ops]
          with tf.control_dependencies(ops):
            return tf.Print(i + 1, [i], 'while_loop:')

        with tf.device(host_device):
          return tf.while_loop(
              lambda i: tf.less(i, n),
              computation, [tf.constant(0)],
              parallel_iterations=1)
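Hypothetical usage of wrap_computation_in_while_loop, assuming the helper is available at module level and replacing the enqueue op with a trivial counter update (TF1 graph mode).

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

counter = tf.Variable(0, name='counter')
loop = wrap_computation_in_while_loop(
    lambda: counter.assign_add(1), n=5, host_device='/cpu:0')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(loop)            # body executes 5 times, sequentially
  print(sess.run(counter))  # 5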
Example #10
    def TpuDequeueBatch(self):
        """Create TPU dequeue ops.

        This should only be called within a TPU context.

        Returns:
          A NestedMap of the input batch.
        """
        assert self._tpu_queues, 'CreateTpuEnqueueOps must be called first.'
        with tf.device(tf.tpu.core(0)):
            # Note that the dequeue_tuple op on the TPU core
            # only cares about the shape/types being dequeued
            # which is why this is hard-coded to the first Queue.
            tensors = self._tpu_queues[0].generate_dequeue_op()
        return self._batch.Pack(tensors)
Example #11
def CollectVarHistogram(vs_gs):
  """Adds histogram summaries for variables and gradients."""

  for name, (var, grad) in vs_gs.FlattenItems():
    name = py_utils.SanitizeScopeKey(name)
    with tf.device(var.device), tf.name_scope(name + '/summary'):
      if isinstance(grad, tf.IndexedSlices):
        var = tf.gather(var, grad.indices)
        grad = grad.values
      if var.dtype.is_complex:
        var = tf.abs(var)
        grad = tf.abs(grad)

      histogram('var_hist/' + name, var)
      histogram('grad_hist/' + name, grad)
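A small sketch of the tf.IndexedSlices branch above: for sparse gradients only the touched rows of the variable are gathered before the histograms are written. The shapes here are made up.

import tensorflow as tf

var = tf.Variable(tf.zeros([10, 4]))
grad = tf.IndexedSlices(values=tf.ones([2, 4]),
                        indices=tf.constant([1, 7]),
                        dense_shape=tf.constant([10, 4]))
rows = tf.gather(var, grad.indices)  # [2, 4]: only rows 1 and 7 of the variable
vals = grad.values                   # [2, 4]: the sparse gradient values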
Example #12
 def _DecodeStep():
   """Decode call to be compiled for TPU."""
   with py_utils.OpportunisticVariableReuseScope(True):
     with cluster_factory.SetEval(True):
       self._decode_model = self._decode_task_params.Instantiate()
       self._decode_model_task = self._decode_model.GetTask()
       self._decode_model_task.AddChild('input', self._decode_input)
       input_batch = self._decode_model_task.input_generator.TpuDequeueBatch(
       )
       metrics_dict = self._decode_model_task.Decode(input_batch)
       self.metrics_nm = py_utils.NestedMap(metrics_dict)
       device = tpu.core(0) if self.spmd else ''
       with tf.device(device):
         outfeed_enqueue = tpu_ops.outfeed_enqueue_tuple(
             self.metrics_nm.Flatten())
         return [outfeed_enqueue]
Example #13
 def Send(self, tensor):
   """Sends a tensor through the channel."""
   assert tensor.dtype == self._dtype
   assert not self._send_called, ("Send called multiple times for %s" %
                                  self._name)
   self._send_called = True
   if self._send_tpu_core == -1:
     return tf.raw_ops.Send(
         tensor=tensor,
         tensor_name=self._name,
         send_device=self._send_device,
         send_device_incarnation=0,
         recv_device=self._recv_device)
   else:
     with tf.device(self._send_device):
       return xla.send(
           tensor, tensor_name=self._name, name="Send_" + self._name)
Example #14
 def _OutfeedDequeue(self):
   """Collect outfeed dequeue from all devices."""
   num_outfeeds = len(self.metrics_nm.Flatten())
   outfeed_ops = [[]] * num_outfeeds
   device_assignment = py_utils.GetTpuDeviceAssignment()
   assert device_assignment
   for replica in range(device_assignment.num_replicas):
     num_cores_per_replica = 1 if self.spmd else (
         device_assignment.num_cores_per_replica)
     for core in range(num_cores_per_replica):
       with tf.device(device_assignment.host_device(replica, core)):
         outfeeds_per_core = tpu_ops.outfeed_dequeue_tuple(
             dtypes=[x.dtype for x in self.metrics_nm.Flatten()],
             shapes=[x.shape for x in self.metrics_nm.Flatten()],
             device_ordinal=device_assignment.tpu_ordinal(replica, core))
         for idx_outfeed, out_feed in enumerate(outfeeds_per_core):
           outfeed_ops[idx_outfeed] = outfeed_ops[idx_outfeed] + [out_feed]
   return [tf.concat(per_outfeed, 0) for per_outfeed in outfeed_ops]
Example #15
 def Recv(self):
   """Receives a tensor from the channel."""
   if self._send_tpu_core == -1:
     received = tf.raw_ops.Recv(
         tensor_type=self._dtype,
         tensor_name=self._name,
         send_device=self._send_device,
         send_device_incarnation=0,
         recv_device=self._recv_device)
     received.set_shape(self._shape)
     return received
   else:
     with tf.device(self._recv_device):
       return xla.recv(
           self._dtype,
           tensor_name=self._name,
           shape=self._shape,
           name="Recv_" + self._name)
Example #16
    def Export(cls,
               model_cfg,
               model_task_name=None,
               device_options=InferenceDeviceOptions(
                   device='',
                   retain_device_placement=False,
                   var_options=None,
                   gen_init_op=True,
                   dtype_override=None),
               freeze_checkpoint=None,
               freeze_defaults=False,
               export_path=None,
               subgraph_filter=None,
               random_seed=None,
               disable_packed_input=True):
        """Exports a InferenceGraph proto with piecewise subgraphs.

    Sets FLAGS.enable_asserts to False unless user explicitly sets it to True.

    Args:
      model_cfg: a Params instance as returned by
        model_registry.GetParams(modelname, 'Test') or model_params.Model().
      model_task_name: The task to generate an inference graph for. Should be
        None for single-task models.
      device_options: Device options for the accelerator used for serving.
      freeze_checkpoint: The checkpoint to load. Loads and freezes the model if
        given.
      freeze_defaults: Default initializes the graph and freeze. Useful for
        early testing of downstream tools without having a checkpoint.
      export_path: If not None, write the inference graph in ASCII to this path.
      subgraph_filter: A list of subgraph names. If not None or empty, export
        only this list of inference subgraphs.
      random_seed: Fixes the random seed in the exported inference graph.
      disable_packed_input: Disable packed input for inference writing purposes.

    Returns:
      InferenceGraph proto.

    Raises:
      ValueError: if the model does not support the listed subgraphs.
    """
        assert issubclass(model_cfg.cls, base_model.BaseModel)

        # Disable assertions unless user explicitly enables it.
        if FLAGS['enable_asserts'].using_default_value:
            FLAGS.enable_asserts = False

        # TODO(laurenzo): Work out how much we need to specify here in terms of
        # cluster configuration.
        cls._SetClusterParams(model_cfg.cluster, device_options)

        # Configure the model.
        model_cfg.random_seed = random_seed
        model_cfg.is_inference = True

        if disable_packed_input:

            def _DisablePackedInput(task):
                if (_ParamExists(task, 'encoder')
                        and _ParamExists(task.encoder, 'packed_input')):
                    task.encoder.packed_input = False
                if (_ParamExists(task, 'decoder')
                        and _ParamExists(task.decoder, 'packed_input')):
                    task.decoder.packed_input = False

            if issubclass(model_cfg.cls, base_model.MultiTaskModel):
                for _, task_param in model_cfg.task_params.IterParams():
                    _DisablePackedInput(task_param)
            else:
                _DisablePackedInput(model_cfg.task)

        tf.logging.info('Model %s params:', model_cfg.name)
        for line in model_cfg.ToText().split('\n'):
            tf.logging.info('%s', line)

        # Instantiate the graph.
        graph = tf.Graph()
        with graph.as_default():
            tf.random.set_seed(random_seed)
            cluster = model_cfg.cluster.Instantiate()
            device = cluster.GetPlacer()
            tpu_const_scope = _DummyScope()
            if (IsTpu(device_options)
                    and device_options.var_options == 'AS_CONSTANTS'):
                # Do not specify devices for variables if we are marking them as
                # constants.
                device = ''
                tpu_const_scope = ConstGuaranteeScope()

            with cluster, tf.device(device), tpu_const_scope:

                bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
                    device_options)

                if bfloat16_override:
                    py_utils.UpdateDtype(model_cfg, tf.bfloat16)
                    py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

                # Hard-code TPU-related flags prior to instantiating model.
                old_enable_asserts = FLAGS.enable_asserts
                old_xla_device = FLAGS.xla_device
                if IsTpu(device_options):
                    FLAGS.enable_asserts = False
                    FLAGS.xla_device = 'tpu'

                # Ensure the global_step variable is created.
                global_step_var = py_utils.GetOrCreateGlobalStepVar()
                global_step = tf.identity(global_step_var,
                                          name='global_step_tensor')

                with py_utils.GlobalStepContext(global_step):
                    try:
                        mdl = model_cfg.Instantiate()
                        variables_to_restore = (_MakeVariableDictionary(
                            tf.global_variables()) if not mdl.ema else
                                                mdl.ema.variables_to_restore(
                                                    mdl.variables_for_ema))

                        if bfloat16_override:
                            saver_var_spec = (
                                bfloat16_variables.
                                get_saver_spec_for_variables_with_bf16_overrides(
                                    variables_to_restore))
                        else:
                            saver_var_spec = variables_to_restore

                        saver = tf.train.Saver(saver_var_spec)
                        tf.variables_initializer(tf.global_variables(),
                                                 name='init_all_variables')
                        if IsTpu(
                                device_options) and device_options.gen_init_op:
                            tf.group(tf.tpu.initialize_system(),
                                     name='tpu_init_op')

                        model_task = mdl.GetTask(model_task_name)

                        inference_graph_proto = inference_graph_pb2.InferenceGraph(
                        )
                        subgraphs_proto = model_task.Inference()
                        if isinstance(subgraphs_proto, dict):
                            subgraphs_proto = ConvertSubgraphDictToProto(
                                subgraphs_proto)
                        for name, subgraph in subgraphs_proto.subgraphs.items(
                        ):
                            if not subgraph_filter or name in subgraph_filter:
                                inference_graph_proto.subgraphs[name].CopyFrom(
                                    subgraph)

                        # Add a table init op and global variable init op to the graph.
                        # Tables can be declared anywhere in the graph, so this op has to be
                        # added last.
                        tf.tables_initializer(name='init_all_tables')
                    finally:
                        # Reset TPU-related flags after model instantiation.
                        FLAGS.enable_asserts = old_enable_asserts
                        FLAGS.xla_device = old_xla_device

        tf.logging.info('Graph contains ops: %r',
                        [op.name for op in graph.get_operations()])

        inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())

        # Freezing.
        if freeze_defaults or freeze_checkpoint:
            output_op_names = GetOutputOpNames(graph,
                                               inference_graph_proto,
                                               preserve_colocation_nodes=False)
            if cls._DeviceSupportsFreezing(device_options):
                raise ValueError(
                    'freeze_checkpoint cannot be used with device ' +
                    device_options.device)
            if freeze_checkpoint:
                tf.logging.info('Freezing graph from checkpoint: %s',
                                freeze_checkpoint)
                graph_def = _FreezeGraphFromCheckpoint(graph, saver,
                                                       freeze_checkpoint,
                                                       output_op_names)
            elif freeze_defaults:
                tf.logging.info('Default initializing graph and freezing.')
                graph_def = _FreezeDefaults(graph, output_op_names)
        else:
            output_op_names = GetOutputOpNames(graph, inference_graph_proto)

            # Prune the graph to just the parts we need.
            # To support restoring, we have to not prune out the restore node.
            output_op_names.append('init_all_tables')
            output_op_names.append('init_all_variables')
            output_op_names.append('save/control_dependency')
            output_op_names.append('save/restore_all')
            if IsTpu(device_options) and device_options.gen_init_op:
                output_op_names.append('tpu_init_op')
            graph_def = graph.as_graph_def()
            tf.logging.info('Pruning graph to output ops: %r', output_op_names)
            graph_def = tf.graph_util.extract_sub_graph(
                graph_def, output_op_names)

        if not device_options.retain_device_placement:
            # Clear the device so that the runtime can choose.
            tf.logging.info('Clearing device placement for: %s',
                            device_options.device)
            for node in graph_def.node:
                node.ClearField('device')
            for function in graph_def.library.function:
                for node_def in function.node_def:
                    node_def.ClearField('device')

        inference_graph_proto.graph_def.CopyFrom(graph_def)

        if export_path:
            with tf.io.gfile.GFile(export_path, 'w') as f:
                f.write(text_format.MessageToString(inference_graph_proto))
        return inference_graph_proto
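The pruning step near the end of Export can be seen in isolation: tf.graph_util.extract_sub_graph keeps only the nodes reachable from the requested outputs. A toy example with made-up node names:

import tensorflow.compat.v1 as tf

graph = tf.Graph()
with graph.as_default():
  a = tf.constant(1.0, name='a')
  b = tf.constant(2.0, name='b')
  tf.add(a, b, name='sum')             # reachable from 'sum': kept
  tf.multiply(a, 3.0, name='unused')   # not reachable from 'sum': pruned

pruned = tf.graph_util.extract_sub_graph(graph.as_graph_def(), ['sum'])
print([n.name for n in pruned.node])   # e.g. ['a', 'b', 'sum']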
Example #17
 def __init__(self, params):
     with tf.device(self.cluster.input_device):
         super(_UseInputDevice, self).__init__(params)
Example #18
 def SplitInputBatch(self, num_splits):
     with tf.device(self.cluster.input_device):
         return super(_UseInputDevice,
                      self).SplitInputBatch(num_splits)
Example #19
    def FProp(self, theta, *args):
        """Run multiple cells in different devices in a pipelining manner.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      *args: Non-keyworded variable length argument list of input tensors.

    Returns:
      A list of output tensors
    """
        # TODO(huangyp): handle optional None inputs.
        p = self.params
        if self.do_eval:
            outputs = copy.copy(args)
            for (name, l) in self._before_layers + self._cells:
                outputs = _ToTuple(outputs)
                outputs = l.FProp(theta[name], *outputs)
            return outputs

        num_cells = len(p.cell_tpl)
        cluster = self.cluster

        # Compute shapes of input and output tensors.
        input_shapes = self._get_input_shapes(*args)
        state_dtype = self._get_state_dtype(*args)
        state_shapes = self._CalculateOutputShapes(input_shapes)
        tf.logging.info('state_shapes={}'.format(state_shapes))

        def GetCellFn(i):
            """Get the ith feature extraction layer."""
            def CellFn(theta, state0, inputs):
                """A cell fn is exectued inside of StackedRecurrent."""
                del state0

                def _FPropInputSetShape(name, t_shape):
                    if t_shape is None:
                        return None
                    inputs[name].set_shape(t_shape.ToTensorShape().as_list())
                    return inputs[name]

                if p.nested_map_fprop:
                    # pylint: disable=protected-access
                    fprop_inputs = state_shapes[i]._RecursiveMap(
                        _FPropInputSetShape)
                    # pylint: enable=protected-access
                else:
                    fprop_inputs = []
                    for input_idx, input_shape in enumerate(state_shapes[i]):
                        name = 's{}'.format(input_idx)
                        fprop_inputs.append(
                            _FPropInputSetShape(name, input_shape))

                with py_utils.RemoveAssertContext(remove=True):
                    with CellFnFPropOpReplacementWrapper():
                        tf.logging.info('cell {} input {}'.format(
                            i, fprop_inputs))
                        mb_tensor = inputs[_MICRO_BATCH_STATE_NAME]
                        SetOverWriteGlobalStep(mb_tensor)
                        _, cell = self._cells[i]
                        fprop_inputs = _ToTuple(fprop_inputs)
                        outputs = cell.FProp(theta, *fprop_inputs)

                if p.nested_map_fprop:
                    assert py_utils.IsCompatible(outputs, state_shapes[i + 1])
                    state1 = outputs.Filter(lambda x: x is not None)
                else:
                    state1 = py_utils.NestedMap()
                    outputs = _ToTuple(outputs)
                    assert len(outputs) == len(state_shapes[i + 1])
                    for output_idx in range(len(outputs)):
                        if outputs[output_idx] is not None:
                            name = 's{}'.format(output_idx)
                            state1[name] = outputs[output_idx]
                state1[_MICRO_BATCH_STATE_NAME] = mb_tensor
                return state1, py_utils.NestedMap()

            return CellFn

        cell_fns = []
        accumulator_layers = []
        thetas = []
        init_states = []
        devices = []
        for cell_idx in range(num_cells):
            cell_name, cell = self._cells[cell_idx]
            accumulator_layers.append(cell)
            cell_fns.append(GetCellFn(cell_idx))
            thetas.append(theta[cell_name])

            def _TfZeros(t_shape):
                if t_shape is None:
                    return None
                return tf.zeros(t_shape.ToTensorShape().as_list(),
                                dtype=state_dtype)

            if p.nested_map_fprop:
                init_state = py_utils.Transform(_TfZeros,
                                                state_shapes[cell_idx + 1])
                init_state = init_state.Filter(lambda x: x is not None)
            else:
                init_state = py_utils.NestedMap()
                for output_idx, state in enumerate(state_shapes[cell_idx + 1]):
                    state = _TfZeros(state)
                    if state is not None:
                        name = 's{}'.format(output_idx)
                        init_state[name] = state
            init_state[_MICRO_BATCH_STATE_NAME] = tf.cast(0, dtype=state_dtype)
            init_states.append(init_state)

            devices.append(cluster.WorkerDeviceInModelSplit(cell_idx))

        cell_grads = [None] * num_cells
        cell_outs = [lambda x: x] * num_cells
        cell_out_grads = [lambda x: x] * num_cells

        with tf.device(devices[0]):
            previous = _ToTuple(args)
            for (name, l) in self._before_layers:
                previous = l.FProp(theta[name], *previous)
                previous = _ToTuple(previous)

            def _StackAndSplit(x):
                # Split tensors into microbatches.
                if x is None:
                    return None
                return tf.stack(
                    tf.split(x, p.num_micro_batches, axis=p.batch_dim))

            if p.nested_map_fprop:
                inputs = py_utils.Transform(_StackAndSplit, previous[0])
                inputs = inputs.Filter(lambda x: x is not None)
            else:
                inputs = py_utils.NestedMap()
                for output_idx, output_tensor in enumerate(previous):
                    output_tensor = _StackAndSplit(output_tensor)
                    if output_tensor is not None:
                        name = 's{}'.format(output_idx)
                        inputs[name] = output_tensor
            gs_tensor = py_utils.GetGlobalStep()
            inputs[_MICRO_BATCH_STATE_NAME] = tf.stack([
                tf.cast(gs_tensor * p.num_micro_batches + t, dtype=state_dtype)
                for t in range(p.num_micro_batches)
            ])
        tf.logging.info('pipeline input = {}'.format(inputs))
        output_state, _ = recurrent.StackedRecurrent(
            devices=devices,
            cell_fns=cell_fns,
            cell_grads=cell_grads,
            cell_outs=cell_outs,
            cell_out_grads=cell_out_grads,
            thetas=thetas,
            init_states=init_states,
            inputs=inputs,
            accumulator_layers=accumulator_layers,
            unused_acc_state=True)

        with tf.device(devices[-1]):

            def _ReshapeRetVal(name, t_shape):
                """Restore shape for tensors in microbatches."""
                if t_shape is None:
                    return None
                output_tensor = output_state[name]
                if p.batch_dim != 0:
                    perm = list(range(1, p.batch_dim + 1)) + [0]
                    perm += list(range(p.batch_dim + 1, t_shape.rank + 1))
                    output_tensor = tf.transpose(output_tensor, perm=perm)
                output_shape = t_shape.ToTensorShape().as_list()
                output_shape[p.batch_dim] *= p.num_micro_batches
                output_tensor = tf.reshape(output_tensor, output_shape)
                return output_tensor

            # Construct the final return values from output_state.
            if p.nested_map_fprop:
                # pylint: disable=protected-access
                output_tensors = state_shapes[-1]._RecursiveMap(_ReshapeRetVal)
                # pylint: enable=protected-access
            else:
                output_tensors = []
                for output_idx, state_shape in enumerate(state_shapes[-1]):
                    output_name = 's{}'.format(output_idx)
                    output_tensor = _ReshapeRetVal(output_name, state_shape)
                    output_tensors.append(output_tensor)
                if len(output_tensors) == 1:
                    output_tensors = output_tensors[0]
                else:
                    output_tensors = tuple(output_tensors)
            tf.logging.info('pipeline output = {}'.format(output_tensors))
            return output_tensors
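The micro-batching reshape applied by _StackAndSplit and undone by _ReshapeRetVal amounts to the following, shown here for batch_dim=0 with made-up sizes:

import tensorflow as tf

x = tf.random.normal([8, 16])              # [batch, features]
micro = tf.stack(tf.split(x, 4, axis=0))   # [num_micro_batches=4, 2, 16]
restored = tf.reshape(micro, [8, 16])      # back to the original batch layout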
Example #20
 def _DecoderDevice(self):
     """Returns the device to run the decoder computation."""
     return tf.device('')
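Returning tf.device('') works because an empty device string places no constraint; callers can use the returned context manager uniformly either way. A trivial illustration:

import tensorflow as tf

with tf.device(''):        # no-op placement: the runtime chooses the device
  x = tf.constant(1.0)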
Example #21
    def __init__(self, train_cfg, ps_params_dict, model_task_name, logdir,
                 tf_master, **kwargs):
        """Construct an ExecutorTpu BaseRunner.

        Args:
          train_cfg: SingleTaskModelParams or MultiTaskModelParams.
          ps_params_dict: A dict of top-level task name -> ProgramSchedule
            params; if train_cfg is a SingleTaskModelParams, we expect only one
            entry.
          model_task_name: An override for multi-task models, currently unused.
          logdir: String path to the log directory to output to.
          tf_master: String path to the master job, e.g. 'local'.
          **kwargs: keyword args to pass through to BaseRunner.
        """
        super(ExecutorTpu, self).__init__(train_cfg, model_task_name, logdir,
                                          tf_master, **kwargs)

        self._cluster_def = self._cluster.worker_cluster_def

        # There is a single Executor task
        assert self._cluster.num_replicas == 1
        data_parallelism = self._cluster.num_splits_per_client

        assert data_parallelism
        num_devices_per_split = self._cluster.num_devices_per_split
        tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                        data_parallelism, num_devices_per_split)

        self.task_scheduler = None
        self._checkpoint_dir = os.path.join(logdir, 'train')

        self._variable_renaming_rules = []

        self._ml_perf = None

        # If this is a multi-task model, grab the params for the TaskScheduler.
        if issubclass(train_cfg.cls, base_model.SingleTaskModel):
            tf.logging.info('single_task_model')
            assert len(ps_params_dict) == 1
            self._model_task_name = list(ps_params_dict.keys())[0]
            self._single_task_mode = True
        elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
            tf.logging.info('multi_task_model')

            if issubclass(train_cfg.cls,
                          multitask_model.RegExSharedVariableModel):
                self._variable_renaming_rules = train_cfg.variable_renaming_rules

            if train_cfg.task_schedule is None:
                task_schedule_params = task_scheduler.ConstantScheduler.Params(
                )
                task_schedule_params.task_probs = sorted(
                    list(train_cfg.task_probs.IterParams()))
            else:
                task_schedule_params = train_cfg.task_schedule
            self.task_scheduler = task_schedule_params.Instantiate()
            self._single_task_mode = False
        else:
            tf.logging.fatal(
                'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
                train_cfg.cls)

        tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

        self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
                         'params.txt')
        self._program_schedule_dict = {}
        self._programs = []

        for task_string, program_schedule_params in ps_params_dict.items():
            program_schedule_params.logdir = logdir
            program_schedule_params.num_splits_per_client = data_parallelism
            program_schedule_params.task_name = task_string
            ps = program_schedule_params.Instantiate()
            self._program_schedule_dict[task_string] = ps
            tf.logging.info('program_schedule_params: %s',
                            program_schedule_params.ToText())
            self._programs += ps.Programs()
            if program_schedule_params.ml_perf.benchmark_name is not None:
                self._ml_perf = program_schedule_params.ml_perf

        tf.logging.info('num_programs: %d', len(self._programs))

        if self._ml_perf is not None:
            self._ml_perf_log = True
            mlp_log.mlperf_print(key='benchmark',
                                 value=self._ml_perf.benchmark_name)
        else:
            self._ml_perf_log = False

        # BaseRunner legacy
        self.enqueue_ops = None

        @py_utils.RetryOnTransientTfError()
        def _WaitTillInit():
            """Wait until the model is ready."""
            try:
                with self._graph.as_default(), self._GetSession(
                        cluster_def=self._cluster_def,
                        disable_meta_optimizer=FLAGS.
                        disable_meta_optimizer_in_executor) as sess:
                    topology = sess.run(
                        tf.tpu.initialize_system(embedding_config=None,
                                                 job=None))
                    device_assignment = device_assignment_lib.device_assignment(
                        topology,
                        computation_shape=py_utils.ComputationShape(
                            num_devices_per_split),
                        num_replicas=data_parallelism)
                    py_utils.SetTpuDeviceAssignment(device_assignment)
                    tf.logging.info('device_assignment.core_assignment: %s',
                                    str(device_assignment.core_assignment))
                    tf.logging.info(
                        'device_assignment.topology.device_coordinates: %s',
                        str(device_assignment.topology.device_coordinates))
            except py_utils.transient_tf_errors as e:
                tf.logging.info('TPU initialization failed: %s', e)
                raise

        if self._ml_perf_log:
            mlp_log.mlperf_print(key='cache_clear', value=True)
            mlp_log.mlperf_print(key='init_start', value=None)
        _WaitTillInit()

        with self._graph.as_default(), tf.container(self._container_id):
            tf.logging.info('self._cluster.job_spec.name: %s',
                            self._cluster.job_spec.name)
            with self._cluster, tf.device(
                    self._cluster.job_spec.name if not FLAGS.
                    cluster_placer_in_executor else self._cluster.GetPlacer()):
                with py_utils.VariableRenameScope(
                        self._variable_renaming_rules):
                    _ = py_utils.GetOrCreateGlobalStepVar()
                    for program in self._programs:
                        program.BuildTpuSubgraph()
                for program in self._programs:
                    program.SetStatusMessageFn(self._SetStatusMessage)
                    program.CreateCheckpointer()
                self._initialize_tables = tf.tables_initializer()
                self._initialize_local_vars = tf.local_variables_initializer()

                self.save_only_checkpointer = checkpointer.Checkpointer(
                    self._checkpoint_dir,
                    model=None,
                    train_params=train_cfg.train,
                    save_only=True)