def FProp(self, theta, inputs, paddings):
  """Builds FProp graph.

  Args:
    theta: A NestedMap of Tensors, see base class.
    inputs: A Tensor of shape [batch, seqlen, dim0].
    paddings: A Tensor of shape [batch, seqlen].

  Returns:
    output: A Tensor of shape [batch, seqlen, dim0].
    out_paddings: A Tensor of shape [batch, seqlen].
  """
  p = self.params
  with tf.name_scope(p.name):
    # Keep the raw input around for the residual connection at the end.
    unnormalized_inputs = inputs

    inputs = self.ln.FProp(theta.ln, inputs)
    # Project to 2x width, either via two separate projections or via one
    # projection split in half, then apply the GLU gating.
    if p.split_act_gated_linear_start:
      act_inputs = self.linear_start_act.FProp(theta.linear_start_act, inputs)
      gated_inputs = self.linear_start_gated.FProp(
          theta.linear_start_gated, inputs)
    else:
      inputs = self.linear_start.FProp(theta.linear_start, inputs)
      gated_inputs, act_inputs = tf.split(inputs, 2, axis=-1)
    inputs = self._GLU(gated_inputs, act_inputs)

    # TODO(jamesqin): introduce depthwise conv2d with 3d inputs.
    # [b, t, d] --> [b, t, 1, d]
    inputs = tf.expand_dims(inputs, 2)
    # Adapt the [b, l, f] sharding annotation to the expanded [b, t, 1, d]
    # layout by inserting an unsharded dim.
    adapted_blf_dims_mapping = None
    if p.activation_split_dims_mapping.blf is not None:
      adapted_blf_dims_mapping = p.activation_split_dims_mapping.blf.copy()
      adapted_blf_dims_mapping.insert(2, -1)
    inputs = gshard_utils.MeshSplit(inputs, p.device_mesh,
                                    adapted_blf_dims_mapping)
    theta.depthwise_conv1d.w = gshard_utils.MeshSplit(
        theta.depthwise_conv1d.w, p.device_mesh,
        p.weight_split_dims_mapping.hwim)
    inputs, paddings = self.depthwise_conv1d.FProp(theta.depthwise_conv1d,
                                                   inputs, paddings)
    inputs = gshard_utils.MeshSplit(inputs, p.device_mesh,
                                    adapted_blf_dims_mapping)

    inputs = self._Normalize(theta, inputs, paddings)
    inputs = gshard_utils.MeshSplit(inputs, p.device_mesh,
                                    p.activation_split_dims_mapping.blf)

    inputs = self._ApplyActivation(inputs, p.conv_activation)

    inputs = self.linear_end.FProp(theta.linear_end, inputs)
    inputs = self.dropout.FProp(theta.dropout, inputs)

    # Residual connection around the whole conv block.
    output = inputs + unnormalized_inputs
  return output, paddings
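
# A minimal sketch of what `self._GLU` above is assumed to compute: a gated
# linear unit (Dauphin et al.), where the activation half is modulated
# elementwise by a sigmoid of the gate half. This is an illustration, not the
# class's actual helper; it reuses this module's `tf` import and ignores any
# configurable activation on the act branch.
def _GLUSketch(gated_inputs, act_inputs):
  """Returns act_inputs gated elementwise by sigmoid(gated_inputs)."""
  return act_inputs * tf.sigmoid(gated_inputs)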
def _CreateVariableInternal(self, name: str,
                            meta: CreateVariableMeta) -> None:
  """Immediately creates the variable described by `meta`.

  DO NOT OVERRIDE. For internal use only. Subclasses of BaseLayer should use
  self.CreateVariable() to create variables.

  Args:
    name: The variable name.
    meta: A CreateVariableMeta describing the variable to be created.
  """
  meta.kwargs.setdefault('default_seed', self.params.random_seed)
  var = py_utils.CreateVariable(name, meta.var_params, **meta.kwargs)
  self._private_vars[name] = var
  if self.cluster.params.worker.gpus_per_replica > 0:
    # On GPU (which always trains a single step per session.run()), reference
    # a tensor in FProp to cache it on device and avoid extraneous sends from
    # reading variables from ps multiple times.
    with tf.device(var.device):
      value = tf.identity(var)
  else:
    # Pass the resource variable directly into the training loop.
    value = var
  # Due to b/174956514, we have to annotate the use of the variable once,
  # otherwise, the sharding annotation on the var will be ignored.
  # TODO(yonghui): Get rid of this once b/174956514 is fixed.
  if (meta.var_params.device_mesh is not None and
      var.shape.rank == len(meta.var_params.tensor_split_dims_mapping)):
    value = gshard_utils.MeshSplit(
        value,
        meta.var_params.device_mesh,
        meta.var_params.tensor_split_dims_mapping,
        use_sharding_op=True)
  if meta.theta_fn is not None:
    self._private_theta_fn[name] = meta.theta_fn
  self._private_theta[name] = value
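
# A hypothetical usage sketch: per the docstring above, subclasses of
# BaseLayer create variables through self.CreateVariable(), which eventually
# reaches _CreateVariableInternal. The layer, hook method, and weight settings
# below follow common lingvo conventions but are illustrative assumptions, not
# code from this file.
class _ExampleLayer(BaseLayer):

  def _CreateLayerVariables(self):
    super()._CreateLayerVariables()
    w_params = py_utils.WeightParams(
        shape=[8, 8],
        init=py_utils.WeightInit.Xavier(),
        dtype=self.params.dtype)
    # Routed through _CreateVariableInternal, which stores the variable in
    # self._private_vars and its (possibly sharding-annotated) value in
    # self._private_theta.
    self.CreateVariable('w', w_params)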