Code example #1
    def FProp(self, theta, inputs, paddings):
        """Builds FProp graph.

        Args:
          theta: A NestedMap of Tensors, see base class.
          inputs: A Tensor of shape [batch, seqlen, dim0].
          paddings: A Tensor of shape [batch, seqlen].

        Returns:
          output: A Tensor of shape [batch, seqlen, dim0].
          out_paddings: A Tensor of shape [batch, seqlen].
        """

        p = self.params
        with tf.name_scope(p.name):
            unnormalized_inputs = inputs

            inputs = self.ln.FProp(theta.ln, inputs)
            if p.split_act_gated_linear_start:
                act_inputs = self.linear_start_act.FProp(
                    theta.linear_start_act, inputs)
                gated_inputs = self.linear_start_gated.FProp(
                    theta.linear_start_gated, inputs)
            else:
                inputs = self.linear_start.FProp(theta.linear_start, inputs)
                gated_inputs, act_inputs = tf.split(inputs, 2, axis=-1)
            inputs = self._GLU(gated_inputs, act_inputs)

            # TODO(jamesqin): introduce depthwise conv2d with 3d inputs.
            # [b, t, d] --> [b, t, 1, d]
            inputs = tf.expand_dims(inputs, 2)
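            # The blf split annotation describes [b, t, d] activations; insert -1
            # (not partitioned) for the singleton axis added above so the mapping
            # matches the new [b, t, 1, d] shape.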
            adapted_blf_dims_mapping = None
            if p.activation_split_dims_mapping.blf is not None:
                adapted_blf_dims_mapping = (
                    p.activation_split_dims_mapping.blf.copy())
                adapted_blf_dims_mapping.insert(2, -1)
            inputs = gshard_utils.MeshSplit(inputs, p.device_mesh,
                                            adapted_blf_dims_mapping)
            theta.depthwise_conv1d.w = gshard_utils.MeshSplit(
                theta.depthwise_conv1d.w, p.device_mesh,
                p.weight_split_dims_mapping.hwim)
            inputs, paddings = self.depthwise_conv1d.FProp(
                theta.depthwise_conv1d, inputs, paddings)

            inputs = gshard_utils.MeshSplit(inputs, p.device_mesh,
                                            adapted_blf_dims_mapping)
            inputs = self._Normalize(theta, inputs, paddings)
            inputs = gshard_utils.MeshSplit(
                inputs, p.device_mesh, p.activation_split_dims_mapping.blf)

            inputs = self._ApplyActivation(inputs, p.conv_activation)

            inputs = self.linear_end.FProp(theta.linear_end, inputs)
            inputs = self.dropout.FProp(theta.dropout, inputs)

            output = inputs + unnormalized_inputs
            return output, paddings
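
The `_GLU` helper invoked above is not part of this listing. A minimal sketch, assuming it implements a standard gated linear unit (a sigmoid gate computed from one half of the projection, multiplied elementwise into the other half); the free-function name `glu` and the identity activation on `act_inputs` are illustrative assumptions, not taken from the source:

import tensorflow as tf

def glu(gated_inputs, act_inputs):
  # Gated linear unit sketch: scale the activation half elementwise by a
  # sigmoid gate derived from the gated half. In the real layer the
  # activation applied to act_inputs is configurable.
  return act_inputs * tf.sigmoid(gated_inputs)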
Code example #2
  def _CreateVariableInternal(self, name: str,
                              meta: CreateVariableMeta) -> None:
    """Immediately creates the variable described by `meta`.

    DO NOT OVERRIDE. For internal use only. Subclasses of BaseLayer should use
    self.CreateVariable() to create variables.

    Args:
      name: The variable name.
      meta: A CreateVariableMeta describing the variable to be created.
    """
    meta.kwargs.setdefault('default_seed', self.params.random_seed)
    var = py_utils.CreateVariable(name, meta.var_params, **meta.kwargs)
    self._private_vars[name] = var
    if self.cluster.params.worker.gpus_per_replica > 0:
      # On GPU (which always trains a single step per session.run()), reference
      # a tensor in FProp to cache it on device and avoid extraneous sends from
      # reading variables from ps multiple times.
      with tf.device(var.device):
        value = tf.identity(var)
    else:
      # Pass the resource variable directly into the training loop.
      value = var

    # Due to b/174956514, we have to annotate the use of the variable once,
    # otherwise, the sharding annotation on the var will be ignored.
    # TODO(yonghui): Get rid of this once b/174956514 is fixed.
    if (meta.var_params.device_mesh is not None and
        var.shape.rank == len(meta.var_params.tensor_split_dims_mapping)):
      value = gshard_utils.MeshSplit(
          value,
          meta.var_params.device_mesh,
          meta.var_params.tensor_split_dims_mapping,
          use_sharding_op=True)

    if meta.theta_fn is not None:
      self._private_theta_fn[name] = meta.theta_fn

    self._private_theta[name] = value
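
The GPU branch relies on a caching trick: reading the variable once through `tf.identity` under the variable's device scope gives downstream FProp ops a single cached tensor instead of each op fetching the variable from the parameter server. A standalone sketch of that pattern in plain TensorFlow (outside Lingvo; the variable name and shape are arbitrary):

import tensorflow as tf

w = tf.Variable(tf.zeros([128, 128]), name='w')
with tf.device(w.device):
  # One read of the variable, pinned to the variable's own device. Ops that
  # consume w_value share this single read rather than re-reading w.
  w_value = tf.identity(w)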