Example #1
 def _DecodeStep():
     """Decode call to be compiled for TPU."""
     input_batch = self._task.input.TpuDequeueBatch()
     metrics_dict = self._task.Decode(input_batch)
     self.metrics_nm = py_utils.NestedMap(metrics_dict)
     device = tpu.core(0) if self.spmd else ''
     with tf.device(device):
         outfeed_enqueue = tpu_ops.outfeed_enqueue_tuple(
             self.metrics_nm.Flatten())
         return [outfeed_enqueue]
Example #2
    def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
        p = self.params

        load_op_list = []
        retrieve_op_list = []

        num_tpu_hosts = tpu_embedding_table.params.num_tpu_hosts
        table_name = tpu_embedding_table.table_name
        slot_var_collections = [
            tpu_embedding_table.__class__.__name__ + '_vars'
        ]

        for host_id, table_var in zip(range(num_tpu_hosts), table_vars):
            # The slot vars should be on the same device as the table var.
            device_name = tpu_embedding_table.GetDeviceName(host_id)
            with tf.device(device_name), py_utils.outside_all_rewrites():
                w_ada = py_utils.WeightParams(
                    shape=table_var.shape.as_list(),
                    init=py_utils.WeightInit.Constant(p.initial_accumulator),
                    dtype=p.dtype,
                    collections=slot_var_collections)
                var_name = tpu_embedding_table.GetVariableName(host_id)
                tpu_embedding_table.CreateVariable('%s/Adagrad' % var_name,
                                                   w_ada,
                                                   trainable=False)
                accumulator_var = tpu_embedding_table.vars['%s/Adagrad' %
                                                           var_name]

                # Only the Trainer needs these ops.
                if py_utils.use_tpu():
                    # TPU Embedding load/retrieve ops need to be in the outer graph scope.
                    with tf.init_scope():
                        tf.logging.info('creating load and retrieve ops.')
                        load_parameters_op = (
                            tpu_embedding_lib.tpu_ops.
                            load_tpu_embedding_adagrad_parameters(
                                parameters=table_var,
                                accumulators=accumulator_var,
                                table_name=table_name,
                                num_shards=num_tpu_hosts,
                                shard_id=host_id))
                        load_op_list.append(load_parameters_op)

                        retrieved_table, retrieved_accumulator = (
                            tpu_embedding_lib.tpu_ops.
                            retrieve_tpu_embedding_adagrad_parameters(
                                table_name=table_name,
                                num_shards=num_tpu_hosts,
                                shard_id=host_id))
                        retrieve_parameters_op = tpu_embedding_lib.control_flow_ops.group(
                            tf.assign(table_var, retrieved_table),
                            tf.assign(accumulator_var, retrieved_accumulator))
                        retrieve_op_list.append(retrieve_parameters_op)

        return load_op_list, retrieve_op_list
Example #3
    def __init__(self,
                 inference_graph,
                 subgraph_name=None,
                 checkpoint=None,
                 device_type="gpu",
                 tf_master="",
                 session_config=None,
                 clear_device_placement=False):
        assert device_type in ["cpu", "gpu", "tpu"]
        subgraph_name = subgraph_name or "default"
        if isinstance(inference_graph, six.string_types):
            tf.logging.info("Reading inference graph from %s.",
                            inference_graph)
            inference_graph = LoadInferenceGraph(inference_graph,
                                                 clear_device_placement)
        self._inference_graph = inference_graph
        self._checkpoint = checkpoint
        self._device_type = device_type
        self._tf_master = tf_master
        self._session_config = session_config

        self._graph = tf.Graph()
        with self._graph.as_default():
            tf.logging.info(
                "Loading inference graph for prediction subgraph_name={}.".
                format(subgraph_name))
            self._saver = tf.train.Saver(saver_def=inference_graph.saver_def)
            with tf.device("/%s:0" %
                           "cpu" if device_type == "tpu" else device_type):
                tf.import_graph_def(inference_graph.graph_def, name="")
            if device_type == "tpu":
                # If no tpu init op exists, create it here.
                try:
                    self._graph.get_operation_by_name("tpu_init_op")
                except KeyError:
                    tf.group(tf.tpu.initialize_system(), name="tpu_init_op")

            self._graph.finalize()

        if inference_graph.subgraphs:
            if subgraph_name not in inference_graph.subgraphs:
                raise ValueError(
                    "Subgraph %s not defined. Valid subgraphs: %s" %
                    (subgraph_name, list(inference_graph.subgraphs.keys())))
            subgraph = inference_graph.subgraphs[subgraph_name]
            self._fetches = subgraph.fetches
            self._feeds = subgraph.feeds
        else:
            self._fetches = inference_graph.fetches
            self._feeds = inference_graph.feeds

        # Lock for creating new sessions.
        self._sess_lock = threading.Lock()
        self._cur_sess_id = 0
        self._CreateNewSession()
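The constructor above imports TPU inference graphs on the local host CPU, while CPU/GPU graphs are imported directly on the requested device. A minimal, self-contained sketch of that device-selection pattern (device_type and graph_def here are placeholders, not library APIs):

import tensorflow.compat.v1 as tf

def import_on_host_device(graph_def, device_type="gpu"):
    # TPU graphs are imported on the local CPU; otherwise use the device itself.
    device = "/cpu:0" if device_type == "tpu" else "/%s:0" % device_type
    graph = tf.Graph()
    with graph.as_default():
        with tf.device(device):
            tf.import_graph_def(graph_def, name="")
    return graph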
Example #4
            def decode_fn(*infeed_batch):  # pylint: disable=missing-docstring
                # Length 6 is passed when there is no tgt_mask (e.g. decoding) and
                # length 7 is passed when there is a tgt_mask (e.g. fprop).

                self.outfeed = self._config_outfeed(xformer, infeed_batch)

                with tf.device(tf.tpu.core(0)):
                    outfeed_op = tpu_ops.outfeed_enqueue_tuple(
                        tf.nest.flatten(self.outfeed))

                return [outfeed_op]
Example #5
 def _DecodeStep():
     """Decode call to be compiled for TPU."""
     with py_utils.OpportunisticVariableReuseScope(True):
         self._model.InstantiateVariables()
         input_batch = self._task.input.TpuDequeueBatch()
         metrics_dict = self._task.Decode(input_batch)
     self.metrics_nm = py_utils.NestedMap(metrics_dict)
     device = tpu.core(0) if self.spmd else ''
     with tf.device(device):
         outfeed_enqueue = tpu_ops.outfeed_enqueue_tuple(
             self.metrics_nm.Flatten())
         return [outfeed_enqueue]
Example #6
def ReplicatedGenericInput(processor, num_replicas, replica_device_fn,
                           **kwargs):
    """Builds a replicated input pipeline.

  This is similar to GenericInput, except that the input processing can be
  distributed across devices and then concatenated at the current device.

  Args:
    processor: see comments for GenericInput.
    num_replicas: the number of input processing replicas. Usually set to the
      number of infeed hosts.
    replica_device_fn: an int -> string function that takes the replica index
      in range [0, num_replicas) and returns a TF device string, e.g.,
      lambda i: '/task:{}/device:CPU:0'.format(i)
    **kwargs: additional keyword args for x_ops.generic_input.

  Returns:
    A tuple of (outputs, bucket_keys):

    - outputs: a NestedMap or a list of tensors, similar to `processor`'s
      return,  except every tensor will have an additional dimension 0 that
      represents the batch dimension. The batch size will be
      (num_replicas * bucket_batch_limit[...]), i.e.,
      kwargs['bucket_batch_limit'] specifies the per-replica batch size.
    - bucket_keys: a tf.int32 vector.

  Raises:
    RuntimeError: If called in pure Eager/tf.function mode without
      `generic_input_v2_key` defined.
  """
    if num_replicas > 1 and 'bucket_batch_limit' in kwargs:
        assert all(b == max(kwargs['bucket_batch_limit'])
                   for b in kwargs['bucket_batch_limit'])
    replica_outputs = []
    if py_utils.IsEagerMode():
        current_key = kwargs.pop('generic_input_v2_key', None)
        if current_key is None:
            raise RuntimeError(_MISSING_KEY_ERR)

    for replica_i in range(num_replicas):
        # Blend `replica_i` into the key for _GENERIC_CACHE_V2 to distinguish
        # different GenericInputV2 ops in the same Datasource object.
        if py_utils.IsEagerMode():
            kwargs['generic_input_v2_key'] = (current_key, replica_i)
        replica_device = replica_device_fn(replica_i)
        with tf.device(replica_device):
            replica_outputs.append(GenericInput(processor, **kwargs))

    output_nmaps, output_bucket_keys = zip(*replica_outputs)
    concat_nmap = tf.nest.map_structure(lambda *t: tf.concat(t, axis=0),
                                        *output_nmaps)
    concat_bucket_keys = tf.concat(output_bucket_keys, axis=0)
    return concat_nmap, concat_bucket_keys
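The same place-per-replica-then-concatenate pattern can be sketched without GenericInput. The helper below is illustrative only; make_batch_fn and the device strings stand in for the per-host input processors and replica_device_fn used above:

import tensorflow.compat.v1 as tf

def replicated_concat(make_batch_fn, num_replicas, replica_device_fn):
    # Build one copy of the input pipeline on each replica's device, then
    # concatenate along the batch dimension on the current device.
    replica_outputs = []
    for i in range(num_replicas):
        with tf.device(replica_device_fn(i)):
            replica_outputs.append(make_batch_fn(i))
    return tf.concat(replica_outputs, axis=0)

# Example usage: four replicas, each producing a per-replica batch of shape [2].
# batch = replicated_concat(
#     lambda i: tf.fill([2], i),
#     num_replicas=4,
#     replica_device_fn=lambda i: '/device:CPU:0')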
Example #7
    def testBasic(self):
        devices = _ListDevices(_Target())
        print("\n".join(devices))
        sender, recver = devices[0], devices[-1]
        shape = []

        for dtype in tf.float32, tf.complex64:
            to_send = np.array(3.1415 + 2j).astype(dtype.as_numpy_dtype)
            g = tf.Graph()
            with g.as_default():
                ch = sendrecv.Channel(dtype, shape, sender, recver, "test")
                with tf.device(sender):
                    src_val = tf.constant(to_send)
                    send_op = ch.Send(src_val)
                with tf.device(recver):
                    recv_val = ch.Recv()

            with tf.Session(_Target(), graph=g):
                _, val = self.evaluate([send_op, recv_val])

            self.assertAllClose(to_send, val)
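In graph mode, tf.device simply records the requested placement on each op, which a unit test can assert directly. A small illustrative check (not from the library, assuming a plain local graph):

import tensorflow.compat.v1 as tf

g = tf.Graph()
with g.as_default():
    with tf.device("/device:CPU:0"):
        x = tf.constant(3.1415)

# The device string requested via tf.device is recorded on the producing op.
assert x.op.device == "/device:CPU:0"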
Example #8
        def SendRecv(graph, dtype):
            to_send = np.array(3.1415 + 2j).astype(dtype.as_numpy_dtype)
            with graph.as_default():
                ch = sendrecv.Channel(dtype, shape, sender, recver, "test")
                with tf.device(sender):

                    @tf.Defun()
                    def Send():
                        src_val = tf.constant(to_send)
                        ch.Send(src_val)
                        return 1.0

                    send_op = Send()

                with tf.device(recver):

                    @tf.Defun()
                    def Recv():
                        return ch.Recv()

                    recv_val = Recv()
            return send_op, recv_val, to_send
Example #9
  def _CreateChildrenVariables(self):
    p = self.params

    num_cells = len(p.cell_tpl)
    before_tpl_device = ''
    cell_devices = [''] * num_cells
    if py_utils.use_tpu():
      cluster = self.cluster
      before_tpl_device = cluster.WorkerDeviceInModelSplit(0)
      cell_devices = [
          cluster.WorkerDeviceInModelSplit(i) for i in range(num_cells)
      ]

    for unused_name, l in self._before_layers:
      with tf.device(before_tpl_device):
        l.InstantiateVariables()

    for i, (unused_name, l) in enumerate(self._cells):
      with tf.device(cell_devices[i]):
        l.InstantiateVariables()

    super()._CreateChildrenVariables()
Example #10
        def SendRecv(graph, dtype):
            to_send = np.array(3.1415 + 2j).astype(dtype.as_numpy_dtype)
            with graph.as_default():
                ch = sendrecv.Channel(dtype, shape, sender, recver, "test")
                with tf.device(sender):

                    # py_utils.CallDefun requires non-empty inputs. Same below.
                    def Send(_):
                        src_val = tf.constant(to_send)
                        ch.Send(src_val)
                        return tf.convert_to_tensor(1.0)

                    send_op = py_utils.CallDefun(Send, tf.convert_to_tensor(0))

                with tf.device(recver):

                    def Recv(_):
                        return ch.Recv()

                    recv_val = py_utils.CallDefun(Recv,
                                                  tf.convert_to_tensor(0))
            return send_op, recv_val, to_send
Example #11
  def _CreateLayerVariables(self):
    p = self.params

    # Reuse the singleton table variables if they were created before.
    all_table_vars = self._tpu_embedding_collection.table_variables
    if self.table_name in all_table_vars:
      embedding_table_vars = all_table_vars[self.table_name]
    else:
      w_pc = py_utils.WeightParams(
          shape=[self._ids_per_shard, p.embedding_dim],
          init=p.params_init,
          dtype=p.dtype,
          collections=[self.__class__.__name__ + '_vars'])

      embedding_table_vars = []
      for i in range(p.num_tpu_hosts):
        device_name = self.GetDeviceName(i)
        with tf.device(device_name), py_utils.outside_all_rewrites():
          var_name = self.GetVariableName(i)
          self.CreateVariable(var_name, w_pc)
          embedding_var = self.vars[var_name]
          embedding_table_vars.append(embedding_var)
          # Remove from _private_vars / _private_thetas to be added later as wm.
          _RemovePrivateVar(self, var_name)

      self._tpu_embedding_collection.AddTableVariables(self.table_name,
                                                       embedding_table_vars)

    if not _ShouldUseTpu(p):
      # We don't want to add this for TrainerTpu, otherwise the identity
      # reference leads to copying the embedding to the TPU for no reason.
      # However, this is needed for CPU (eval/decode/controller).
      self._private_vars['wm'] = embedding_table_vars
      self._private_theta['wm'] = [tf.identity(v) for v in embedding_table_vars]

    # If slot variables and load/retrieve ops were created before, maybe by a
    # different program or task, don't create it again.
    # Note that there should be only one copy of slot variables and
    # load/retrieve ops in the graph and they're shared by different
    # tasks/programs.
    all_load_ops = self._tpu_embedding_collection.load_ops
    if self.table_name not in all_load_ops:
      assert self.table_name not in self._tpu_embedding_collection.retrieve_ops
      # Only trainer and controller (for checkpointing) need slot variables.
      # Only trainer needs load/retrieve ops.
      if not self.do_eval and not p.is_inference:
        load_ops, retrieve_ops = self.optimizer.CreateSlotVariablesAndOps(
            embedding_table_vars, self)
        self._tpu_embedding_collection.AddLoadRetrieveOps(
            self.table_name, load_ops, retrieve_ops)
Example #12
    def CreateVariable(self, name: str, var_params: hyperparams.Params,
                       **kwargs) -> None:
        """Create a variable of this layer according to the parameter `var_params`.

        E.g.::

            def __init__(self, ...):    # A layer's constructor
              self.CreateVariable(
                  'weight', py_utils.WeightParams(shape=[100, 100]))

        Args:
          name: Variable name which is used as the key into vars/theta.
          var_params: `Params` used to create the variable.
          **kwargs: Keyword args passed to `.py_utils.CreateVariable`.
        """
        kwargs.setdefault('default_seed', self.params.random_seed)
        if self.params.device_mesh is not None:
            if (len([dim for dim in var_params.shape if dim > 1]) > 1
                    and var_params.tensor_split_dims_mapping is None):
                tf.logging.warning(
                    'tensor_split_dims_mapping missing for %s.%s: shape=%s',
                    self.path, name, var_params.shape)
        self._CheckName(name)
        if (self.params.skip_lp_regularization and
                py_utils.SKIP_LP_REGULARIZATION not in var_params.collections):
            var_params = py_utils.WeightParams(
                shape=var_params.shape,
                dtype=var_params.dtype,
                init=var_params.init,
                collections=(var_params.collections +
                             [py_utils.SKIP_LP_REGULARIZATION]))
        self._var_symbolic_shape_map[name] = var_params.shape

        var = py_utils.CreateVariable(name, var_params, **kwargs)
        self._private_vars[name] = var

        if py_utils.IsEagerMode():
            # With eager trainer, always use the variable directly.
            value = var
        else:
            if self.cluster.params.worker.gpus_per_replica > 0:
                # On GPU (which always trains a single step per session.run()),
                # reference a tensor in FProp to cache it on device and avoid extraneous
                # sends from reading variables from ps multiple times.
                with tf.device(var.device):
                    value = tf.identity(var, name=name)
            else:
                value = var

        self._private_theta[name] = value
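The GPU branch above reads each parameter-server variable once per step: tf.identity under tf.device(var.device) produces a single cached tensor that all consumers reuse, instead of every consumer issuing its own remote variable read. A minimal sketch of that caching idiom, with placeholder job and device names:

import tensorflow.compat.v1 as tf

g = tf.Graph()
with g.as_default():
    with tf.device("/job:ps/task:0/device:CPU:0"):
        var = tf.get_variable("w", shape=[1024, 1024])

    # Read the variable once on its own device; downstream ops consume the
    # cached tensor instead of re-reading the variable.
    with tf.device(var.device):
        value = tf.identity(var, name="w_cached")

    with tf.device("/job:trainer/task:0/device:GPU:0"):
        y = tf.matmul(value, value) + value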
Example #13
def CollectVarHistogram(vs_gs):
    """Adds histogram summaries for variables and gradients."""

    for name, (var, grad) in vs_gs.FlattenItems():
        with tf.device(var.device), tf.name_scope(name + '/summary'):
            if isinstance(grad, tf.IndexedSlices):
                var = tf.gather(var, grad.indices)
                grad = grad.values
            if var.dtype.is_complex:
                var = tf.abs(var)
                grad = tf.abs(grad)

        histogram('var_hist/' + name, var)
        histogram('grad_hist/' + name, grad)
Example #14
 def Recv(self):
     """Receives a tensor from the channel."""
     if self._send_tpu_core == -1:
         return tf.raw_ops.Recv(tensor_type=self._dtype,
                                tensor_name=self._name,
                                send_device=self._send_device,
                                send_device_incarnation=0,
                                recv_device=self._recv_device)
     else:
         with tf.device(self._recv_device):
             return xla.recv(self._dtype,
                             tensor_name=self._name,
                             shape=self._shape,
                             name="Recv_" + self._name)
Example #15
    def TpuDequeueBatch(self):
        """Create TPU dequeue ops.

        This should only be called within a TPU context.

        Returns:
          A NestedMap of the input batch.
        """
        assert self._tpu_queues, 'CreateTpuEnqueueOps must be called first.'
        with tf.device(tf.tpu.core(0)):
            # Note that the dequeue_tuple op on the TPU core
            # only cares about the shape/types being dequeued
            # which is why this is hard-coded to the first Queue.
            tensors = self._tpu_queues[0].generate_dequeue_op()
        return self._batch.Pack(tensors)
Example #16
    def __init__(self, params):
        super().__init__(params)
        p = self.params
        self._before_layers = []
        self._cells = []

        num_cells = len(p.cell_tpl)
        before_tpl_device = ''
        cell_devices = [''] * num_cells
        if py_utils.use_tpu():
            cluster = self.cluster
            before_tpl_device = cluster.WorkerDeviceInModelSplit(0)
            cell_devices = [
                cluster.WorkerDeviceInModelSplit(i) for i in range(num_cells)
            ]

        for l in p.before_tpl:
            with tf.device(before_tpl_device):
                self.CreateChild(l.name, l)
            self._before_layers.append((l.name, self.children[l.name]))
        for i, l in enumerate(p.cell_tpl):
            with tf.device(cell_devices[i]):
                self.CreateChild(l.name, l)
            self._cells.append((l.name, self.children[l.name]))
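Each pipeline cell above is constructed under its own worker device, which is how the model is split across workers when running on TPU. The placement mechanics reduce to calling tf.device with a per-cell device string; a hedged sketch with hypothetical devices standing in for cluster.WorkerDeviceInModelSplit(i):

import tensorflow.compat.v1 as tf

# Hypothetical per-cell devices; the library derives these from the cluster.
cell_devices = ["/device:CPU:0", "/device:CPU:0"]

g = tf.Graph()
with g.as_default():
    cells = []
    for i, device in enumerate(cell_devices):
        # Variables and ops for cell i are created under cell i's device.
        with tf.device(device):
            cells.append(tf.get_variable("cell_%d/w" % i, shape=[8, 8]))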
Example #17
    def _load_graph_from_inference_graph(self, inference_graph):
        """Returns a tf.Graph() constructed from `inference_graph`.

        Args:
          inference_graph: An InferenceGraph proto from which the graph_def is
            loaded.

        Returns:
          A loaded tf.Graph().
        """
        graph = tf.Graph()
        with graph.as_default():
            with tf.device("/%s:0" % "cpu" if self._device_type ==
                           "tpu" else self._device_type):
                tf.import_graph_def(inference_graph.graph_def, name="")
        return graph
Example #18
 def Send(self, tensor):
   """Sends a tensor through the channel."""
   assert tensor.dtype == self._dtype
   assert not self._send_called, ("Send called multiple times for %s" %
                                  self._name)
   self._send_called = True
   if self._send_tpu_core == -1:
     return tf.raw_ops.Send(
         tensor=tensor,
         tensor_name=self._name,
         send_device=self._send_device,
         send_device_incarnation=0,
         recv_device=self._recv_device)
   else:
     with tf.device(self._send_device):
       return xla.send(
           tensor, tensor_name=self._name, name="Send_" + self._name)
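For the non-XLA branch, Send and Recv are a pair of raw TensorFlow ops matched up by tensor_name and the two device strings. A minimal single-process sketch of that rendezvous, with both endpoints placed on the local CPU (device string and session setup are illustrative, not the library's Channel class):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

device = "/job:localhost/replica:0/task:0/device:CPU:0"

g = tf.Graph()
with g.as_default():
    with tf.device(device):
        send_op = tf.raw_ops.Send(
            tensor=tf.constant(3.14),
            tensor_name="test",
            send_device=device,
            send_device_incarnation=0,
            recv_device=device)
        recv_val = tf.raw_ops.Recv(
            tensor_type=tf.float32,
            tensor_name="test",
            send_device=device,
            send_device_incarnation=0,
            recv_device=device)

with tf.Session(graph=g) as sess:
    _, val = sess.run([send_op, recv_val])  # val == 3.14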
Example #19
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._job_name = 'trainer'
        with self._graph.as_default(), tf.container(self._container_id):
            try:
                self._task_probs_summary_writers = []
                for task in self._model.task_schedule.tasks:
                    path = os.path.join(self._train_dir, task)
                    tf.io.gfile.makedirs(path)
                    self._task_probs_summary_writers.append(
                        self._CreateSummaryWriter(path))
            except AttributeError:
                tf.logging.info(
                    'AttributeError. Expected for single task models.')
                self._task_probs_summary_writers = []

            if self.params.cluster.task == 0:
                self._summary_writer = self._CreateSummaryWriter(
                    self._train_dir)
                self._CreateTF2SummaryWriter(self._train_dir)
            else:
                self._summary_writer = None

            with self._cluster, tf.device(
                    self._cluster.GetPlacer()), self._TF2SummaryContext():
                self._model = self.params.Instantiate()
                self._params = self._model.params
                self._model.ConstructFPropBPropGraph()
            self._CreateTF2SummaryOps()
            self._initialize_tables = tf.tables_initializer()
            self._initialize_local_vars = tf.local_variables_initializer()
            self.enqueue_ops = tf.get_collection(py_utils.ENQUEUE_OPS)
            tf.logging.info('Trainer number of enqueue ops: %d',
                            len(self.enqueue_ops))

        self._step_rate_tracker = summary_utils.StepRateTracker()

        # Saves the graph def.
        if self.params.cluster.task == 0:
            self._WriteToLog(self.params.ToText(), self._train_dir,
                             'trainer_params.txt')
            tf.io.write_graph(self._graph.as_graph_def(), self._train_dir,
                              'train.pbtxt')
        worker_id = self.params.cluster.task
        self._start_up_delay_steps = (((worker_id + 1) * worker_id / 2) *
                                      self.params.train.start_up_delay_steps)
Example #20
def CollectVarHistogram(vs_gs):
    """Adds histogram summaries for variables and gradients."""

    for name, (var, grad) in vs_gs.FlattenItems():
        name = py_utils.SanitizeScopeKey(name)
        with tf.device(var.device), tf.name_scope(name + '/summary'):
            if isinstance(grad, tf.IndexedSlices):
                var = tf.gather(var, grad.indices)
                grad = grad.values
            if var.dtype.is_complex:
                var = tf.abs(var)
                grad = tf.abs(grad)

        if py_utils.IsEagerMode():
            histogram_v2(f'var_hist/{name}', var)
            histogram_v2(f'grad_hist/{name}', grad)
        else:
            histogram(f'var_hist/{name}', var)
            histogram(f'grad_hist/{name}', grad)
Example #21
    def __init__(self,
                 inference_graph,
                 subgraph_name=None,
                 checkpoint=None,
                 device_type="gpu",
                 tf_master=""):
        assert device_type in ["cpu", "gpu", "tpu"]
        subgraph_name = subgraph_name or "default"
        if isinstance(inference_graph, six.string_types):
            tf.logging.info("Reading inference graph from %s.",
                            inference_graph)
            inference_graph = LoadInferenceGraph(inference_graph)
        self._inference_graph = inference_graph
        self._checkpoint = checkpoint
        self._device_type = device_type
        self._tf_master = tf_master

        self._graph = tf.Graph()
        with self._graph.as_default():
            tf.logging.info("Loading inference graph for prediction.")
            self._saver = tf.train.Saver(saver_def=inference_graph.saver_def)
            with tf.device("/%s:0" %
                           "cpu" if device_type == "tpu" else device_type):
                tf.import_graph_def(inference_graph.graph_def, name="")
            self._graph.finalize()

        if inference_graph.subgraphs:
            if subgraph_name not in inference_graph.subgraphs:
                raise ValueError(
                    "Subgraph %s not defined. Valid subgraphs: %s" %
                    (subgraph_name, list(inference_graph.subgraphs.keys())))
            subgraph = inference_graph.subgraphs[subgraph_name]
            self._fetches = subgraph.fetches
            self._feeds = subgraph.feeds
        else:
            self._fetches = inference_graph.fetches
            self._feeds = inference_graph.feeds

        # Lock for creating new sessions.
        self._sess_lock = threading.Lock()
        self._cur_sess_id = 0
        self._CreateNewSession()
Example #22
  def _FPropSplitInputBatch(self, theta, input_batch):
    """Splits the input batch on the input device."""
    if py_utils.use_tpu():
      return self._FPropTpu(theta, input_batch)

    cluster = self.cluster
    num_splits = cluster.num_splits_per_client

    if not isinstance(input_batch, list):
      input_batch = [input_batch]

    assert len(input_batch) == num_splits, (len(input_batch), num_splits)

    # dev_list_per_replica[i][j] is the i-th worker's j-th device.
    dev_list_per_replica = cluster.available_devices.tolist()

    # Asserts invariant of the total number of splits w.r.t.,
    # splits per worker.
    splits_per_replica = cluster.num_splits_per_replica
    assert num_splits == splits_per_replica * len(dev_list_per_replica), (
        num_splits, splits_per_replica, len(dev_list_per_replica))

    all_metrics = []
    all_per_example_tensors = []
    with cluster:
      for w_id, w_devs in enumerate(dev_list_per_replica):
        # Make local copy of the vars, shard on devices for this worker.
        theta_local = py_utils.CreateLocalTheta(
            theta, w_devs, label='worker %d' % w_id)

        for s_id in range(splits_per_replica):
          # s_id-th split for the w_id-th worker.
          split_id = splits_per_replica * w_id + s_id
          with cluster_factory.SetModelSplit(split_id) as c:
            with tf.device(c.WorkerDeviceInModelSplit(0)):
              with tf.name_scope('tower_%d_%d' % (w_id, s_id)):
                batch = input_batch[split_id]
                metrics, per_example = self.FPropTower(theta_local, batch)
          all_metrics.append(metrics)
          all_per_example_tensors.append(per_example)
    return py_utils.WeightedAvgOfMetrics(
        all_metrics), py_utils.ConcatPerExampleTensors(all_per_example_tensors)
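Stripped of the cluster bookkeeping, the loop above is a standard data-parallel pattern: one tower per input split, each built under its own device, with the results combined on the client. A minimal sketch (tower_fn, the splits and the device list are placeholders, not library APIs):

import tensorflow.compat.v1 as tf

def data_parallel_fprop(tower_fn, input_splits, split_devices):
    tower_outputs = []
    for split_id, (batch, device) in enumerate(zip(input_splits, split_devices)):
        # Each split's forward pass is placed on its own device.
        with tf.device(device), tf.name_scope("tower_%d" % split_id):
            tower_outputs.append(tower_fn(batch))
    # Combine per-tower results (here a simple average) on the current device.
    return tf.add_n(tower_outputs) / float(len(tower_outputs))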
Example #23
 def testPSRandomSize(self):
   p = cluster_factory.Cluster.Params()
   p.worker.name = '/job:trainer'
   p.ps.name = '/job:ps'
   p.ps.replicas = 10
   c = cluster_factory.Cluster(p)
   g = tf.Graph()
   vs = []
   np.random.seed(301)
   with g.as_default():
     with tf.device(c.GetPlacer()):
       # Creates 200 variables with different sizes.
       for i in range(200):
         if i % 13:
           size = np.random.randint(10000)
         elif i % 7:
           size = np.random.randint(100)
         else:
           size = np.random.randint(10)
         vs.append(tf.get_variable('x%d' % i, shape=(size)))
       sum_all = tf.add_n([tf.reduce_sum(x) for x in vs])
   # Computes the total size of variables placed on each device.
   total_size = {}  # device name -> size
   for v in vs:
     size = tf.TensorShape(v.op.get_attr('shape')).num_elements()
     if v.device in total_size:
       total_size[v.device] += size
     else:
       total_size[v.device] = size
   for (device, allocated) in zip(
       sorted(total_size),
       [91701, 91361, 90346, 88738, 87240, 89265, 91944, 92472, 88051, 95053]):
     self.assertEqual(total_size[device], allocated)
   self.assertEqual(
       sum_all.device,
       cluster.MakeDeviceString(
           job_name='/job:trainer',
           replica_id=0,
           task_id=0,
           device_name='CPU',
           device_id=0))
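cluster.GetPlacer() returns a device function rather than a fixed device string; tf.device accepts such a function and invokes it for every op created in the scope. A minimal, hypothetical placer that round-robins variables across parameter-server tasks (the job names and op-type check are illustrative, not the library's actual size-based placement policy):

import tensorflow.compat.v1 as tf

def make_round_robin_placer(num_ps_tasks):
    counter = [0]

    def placer(op):
        # Variables go to parameter servers in turn; everything else stays on
        # the trainer.
        if op.type in ("VarHandleOp", "VariableV2"):
            device = "/job:ps/task:%d/device:CPU:0" % (counter[0] % num_ps_tasks)
            counter[0] += 1
            return device
        return "/job:trainer/replica:0/task:0/device:CPU:0"

    return placer

g = tf.Graph()
with g.as_default():
    with tf.device(make_round_robin_placer(num_ps_tasks=10)):
        vs = [tf.get_variable("x%d" % i, shape=[128]) for i in range(4)]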
Example #24
    def _CreateVariableInternal(self, name, meta):
        """Immediately creates the variable described by `meta`.

        DO NOT OVERRIDE. For internal use only. Subclasses of BaseLayer should
        use self.CreateVariable() to create variables.

        Args:
          name: The variable name.
          meta: A CreateVariableMeta describing the variable to be created.
        """
        meta.kwargs.setdefault('default_seed', self.params.random_seed)
        var = py_utils.CreateVariable(name, meta.var_params, **meta.kwargs)
        self._private_vars[name] = var
        if resource_variable_ops.is_resource_variable(var):
            value = var
        else:
            with tf.device(var.device):
                value = tf.identity(var)
        if meta.theta_fn is not None:
            value = meta.theta_fn(value)
        self._private_theta[name] = value
Example #25
 def _OutfeedDequeue(self):
     """Collect outfeed dequeue from all devices."""
     num_outfeeds = len(self.metrics_nm.Flatten())
     outfeed_ops = [[]] * num_outfeeds
     device_assignment = py_utils.GetTpuDeviceAssignment()
     assert device_assignment
     for replica in range(device_assignment.num_replicas):
         num_cores_per_replica = 1 if self.spmd else (
             device_assignment.num_cores_per_replica)
         for core in range(num_cores_per_replica):
             with tf.device(device_assignment.host_device(replica, core)):
                 outfeeds_per_core = tpu_ops.outfeed_dequeue_tuple(
                     dtypes=[x.dtype for x in self.metrics_nm.Flatten()],
                     shapes=[x.shape for x in self.metrics_nm.Flatten()],
                     device_ordinal=device_assignment.tpu_ordinal(
                         replica, core))
                 for idx_outfeed, out_feed in enumerate(outfeeds_per_core):
                     outfeed_ops[idx_outfeed] = outfeed_ops[idx_outfeed] + [
                         out_feed
                     ]
     return [tf.concat(per_outfeed, 0) for per_outfeed in outfeed_ops]
Example #26
  def _CreateVariableInternal(self, name: str,
                              meta: CreateVariableMeta) -> None:
    """Immediately creates the variable described by `meta`.

    DO NOT OVERRIDE. For internal use only. Subclasses of BaseLayer should use
    self.CreateVariable() to create variables.

    Args:
      name: The variable name.
      meta: A CreateVariableMeta describing the variable to be created.
    """
    meta.kwargs.setdefault('default_seed', self.params.random_seed)
    var = py_utils.CreateVariable(name, meta.var_params, **meta.kwargs)
    self._private_vars[name] = var
    if self.cluster.params.worker.gpus_per_replica > 0:
      # On GPU (which always trains a single step per session.run()), reference
      # a tensor in FProp to cache it on device and avoid extraneous sends from
      # reading variables from ps multiple times.
      with tf.device(var.device):
        value = tf.identity(var)
    else:
      # Pass the resource variable directly into the training loop.
      value = var

    # Due to b/174956514, we have to annotate the use of the variable once,
    # otherwise, the sharding annotation on the var will be ignored.
    # TODO(yonghui): Get rid of this once b/174956514 is fixed.
    if (meta.var_params.device_mesh is not None and
        var.shape.rank == len(meta.var_params.tensor_split_dims_mapping)):
      value = gshard_utils.MeshSplit(
          value,
          meta.var_params.device_mesh,
          meta.var_params.tensor_split_dims_mapping,
          use_sharding_op=True)

    if meta.theta_fn is not None:
      self._private_theta_fn[name] = meta.theta_fn

    self._private_theta[name] = value
Example #27
def ReplicatedGenericInput(processor, num_replicas, replica_device_fn,
                           **kwargs):
    """Builds a replicated input pipeline.

  This is similar to GenericInput, except that the input processing can be
  distributed across devices and then concatenated at the current device.

  Args:
    processor: see comments for GenericInput.
    num_replicas: the number of input processing replicas. Usually set to the
      number of infeed hosts.
    replica_device_fn: an int -> string function that takes the replica index
      in range [0, num_replicas) and returns a TF device string, e.g.,
      lambda i: '/task:{}/device:CPU:0'.format(i)
    **kwargs: additional keyword args for x_ops.generic_input.

  Returns:
    A tuple of (outputs, bucket_keys):

    - outputs: a NestedMap or a list of tensors, similar to `processor`'s
      return,  except every tensor will have an additional dimension 0 that
      represents the batch dimension. The batch size will be
      (num_replicas * bucket_batch_limit[...]), i.e.,
      kwargs['bucket_batch_limit'] specifies the per-replica batch size.
    - bucket_keys: a tf.int32 vector.
  """
    if num_replicas > 1 and 'bucket_batch_limit' in kwargs:
        assert all(b == max(kwargs['bucket_batch_limit'])
                   for b in kwargs['bucket_batch_limit'])
    replica_outputs = []
    for replica_i in range(num_replicas):
        replica_device = replica_device_fn(replica_i)
        with tf.device(replica_device):
            replica_outputs.append(GenericInput(processor, **kwargs))
    output_nmaps, output_bucket_keys = zip(*replica_outputs)
    concat_nmap = tf.nest.map_structure(lambda *t: tf.concat(t, axis=0),
                                        *output_nmaps)
    concat_bucket_keys = tf.concat(output_bucket_keys, axis=0)
    return concat_nmap, concat_bucket_keys
Example #28
def _WrapNonLingvoVars(dest_layer: base_layer.BaseLayer,
                       variables: Collection[tf.Variable],
                       trainable_variables: Collection[tf.Variable] = ()):
    """Adds variables to the given lingvo layer and appropriate graph collections.

  This function helps wrap variables created outside of lingvo so they are
  correctly handled by lingvo's trainer and checkpointer. It does the following:

    - makes all `variables` trackable through `dest_layer.vars`;
    - ensures `variables` are in the `tf.global_variables()` graph collection so
      the trainer can initialize them;
    - adds the `trainable_variables` subset to the `tf.trainable_variables()`
      graph collection, so they are visible to the learner (i.e. can be
      trained).

  Args:
    dest_layer: Lingvo layer to add the `variables` to.
    variables: The non-lingvo variables to wrap.
    trainable_variables: The subset of `variables` to ensure are trainable.
  """

    global_collection = set(tf.global_variables())
    for v in variables:
        assert v in global_collection
        name = v.name.split(':')[0]
        # pylint: disable=protected-access
        dest_layer._private_vars[name] = v
        with tf.device(v.device):
            dest_layer._private_theta[name] = tf.identity(v)
        # pylint: enable=protected-access

    trainable_collection = set(tf.trainable_variables())
    for v in trainable_variables:
        if v not in trainable_collection:
            tf.logging.warning(
                'Wrapped var %s not in trainable collection; adding it.',
                v.name)
            tf.compat.v1.add_to_collection(
                tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, v)
Example #29
    def _CreateLayerVariables(self):
        p = self.params
        w_pc = py_utils.WeightParams(
            shape=[self._ids_per_shard, p.embedding_dim],
            init=p.params_init,
            dtype=p.dtype,
            collections=[self.__class__.__name__ + '_vars'])

        embedding_table_vars = []
        for i in range(p.num_tpu_hosts):
            device_name = self.GetDeviceName(i)
            with tf.device(device_name), py_utils.outside_all_rewrites():
                var_name = self.GetVariableName(i)
                self.CreateVariable(var_name, w_pc)
                embedding_var = self.vars[var_name]
                embedding_table_vars.append(embedding_var)
                # Remove from _private_vars / _private_thetas to be added later as wm.
                del self._private_vars[var_name]
                del self._private_theta[var_name]

        self._tpu_embedding_collection.AddTableVariables(
            self.table_name, embedding_table_vars)

        if not py_utils.use_tpu():
            # We don't want to add this for TrainerTpu, otherwise the identity
            # reference leads to copying the embedding to the TPU for no reason.
            # However, this is needed for CPU (eval/decode/controller).
            self._private_vars['wm'] = embedding_table_vars
            self._private_theta['wm'] = [
                tf.identity(v) for v in embedding_table_vars
            ]

        # Only trainer and controller need slot variables and load/retrieve ops.
        if not self.do_eval:
            self._load_op_list, self._retrieve_op_list = (
                self.optimizer.CreateSlotVariablesAndOps(
                    embedding_table_vars, self))
Example #30
        def LoopBody(i, *input_arrays):
            """Process outfeed data for a single TpuTrainStep.

            Args:
              i: current loop index.
              *input_arrays: One tf.TensorArray per outfeed tensor.

            Returns:
              i+1 (new index) plus post-write tf.TensorArray handles.
            """
            # Outfeed ops execute on each JF node, so they must be located on the
            # nodes.
            outfeed_devices = []
            device_assignment = py_utils.GetTpuDeviceAssignment()
            assert device_assignment
            for replica in range(device_assignment.num_replicas):
                num_cores_per_replica = 1 if self.spmd else (
                    device_assignment.num_cores_per_replica)
                for core in range(num_cores_per_replica):
                    with tf.device(device_assignment.host_device(
                            replica, core)):
                        outfeed_devices.append(
                            tpu_ops.outfeed_dequeue_tuple(
                                tensor_types,
                                tensor_shapes,
                                device_ordinal=device_assignment.tpu_ordinal(
                                    replica, core)))
            offset = i * num_devices
            output_arrays = list(input_arrays)
            # Each output_array holds a different per-example tensor. We get results
            # for each tensor from each TPU for each TpuTrainStep call.
            for j in range(len(output_arrays)):
                for k in range(len(outfeed_devices)):
                    output_arrays[j] = output_arrays[j].write(
                        offset + k, outfeed_devices[k][j])

            return tuple([i + 1] + output_arrays)