def _Moments(inputs, mask, enable_cross_replica_sum_on_tpu=False):
  """Computes mean and variance over the valid data points in inputs.

  Every dimension except the last one is reduced, so the statistics have one
  entry per index of the last dimension.

  Args:
    inputs: The inputs tensor.
    mask: A non-negative tensor of the same rank as `inputs` (both properties
      are asserted at runtime). It is multiplied into `inputs`, so it has to
      broadcast against it. It is cast to `inputs.dtype` before use, so its
      dtype does not need to match.
    enable_cross_replica_sum_on_tpu: If True and running on TPU, the sums and
      the count are aggregated with `tf.tpu.cross_replica_sum`.

  Returns:
    A tuple (mean, variance).
  """
  inputs = py_utils.with_dependencies([
      py_utils.assert_equal(tf.rank(inputs), tf.rank(mask)),
      py_utils.assert_greater_equal(mask, tf.zeros_like(mask)),
  ], inputs)
  # Cast once so that the weighted sum, the count and the variance all share
  # the dtype of `inputs`; otherwise a mask of another dtype breaks the
  # divisions and products below. No-op when the dtypes already agree.
  mask = tf.cast(mask, inputs.dtype)
  rank = tf.rank(mask)
  reduce_over_dims = tf.range(0, rank - 1)
  sum_v = tf.reduce_sum(inputs * mask, reduce_over_dims)
  count_v = tf.reduce_sum(mask, reduce_over_dims)
  # Input shape is guaranteed to be a multiple of mask shape because the
  # inputs * mask op above was successfully broadcasted.
  mask_multiplier = tf.shape(inputs)[:-1] // tf.shape(mask)[:-1]
  count_v *= tf.cast(tf.reduce_prod(mask_multiplier), count_v.dtype)
  if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
    sum_v = tf.tpu.cross_replica_sum(sum_v)
    count_v = tf.tpu.cross_replica_sum(count_v)

  # Guard against division by zero when there is no valid data point.
  count_v = tf.maximum(count_v, 1.0)
  mean = sum_v / count_v
  centered = inputs - mean
  sum_vv = tf.reduce_sum(centered * centered * mask, reduce_over_dims)

  if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
    sum_vv = tf.tpu.cross_replica_sum(sum_vv)

  variance = py_utils.with_dependencies([
      py_utils.assert_greater_equal(sum_vv, tf.zeros_like(sum_vv)),
  ], sum_vv / count_v)
  return mean, variance
def _MaybeExpandPaddings(self, inputs, paddings):
  """Appends size-1 trailing dims to paddings to reach the rank of inputs.

  Args:
    inputs: The inputs tensor; only its rank is used.
    paddings: The paddings tensor. Its rank must equal the rank of `inputs`
      or be exactly one less (asserted at runtime).

  Returns:
    `paddings` reshaped to its own shape followed by `rank(inputs) -
    rank(paddings)` dimensions of size 1.
  """
  # rank difference is at most one.
  extra_dims = tf.rank(inputs) - tf.rank(paddings)
  rank_checks = [
      py_utils.assert_less_equal(extra_dims, 1),
      py_utils.assert_greater_equal(extra_dims, 0)
  ]
  paddings = py_utils.with_dependencies(rank_checks, paddings)
  # Target shape: the current shape followed by `extra_dims` ones.
  current_shape = tf.shape(paddings)
  trailing_ones = tf.tile([1], [extra_dims])
  expanded_shape = tf.concat([current_shape, trailing_ones], axis=0)
  return tf.reshape(paddings, expanded_shape)
def ComputeMoments(inputs,
                   padding,
                   reduce_over_dims,
                   cumulative_axis=None,
                   enable_cross_replica_sum_on_tpu=False,
                   keepdims=False):
  """Computes mean and variance over the valid data points in inputs.

  Args:
    inputs: The inputs tensor.
    padding: The padding tensor, with the same rank as `inputs` (asserted).
      The weight of each data point is `1.0 - padding`, which must be
      non-negative (asserted) and broadcastable against `inputs`. The weight
      is cast to `inputs.dtype`, so the dtypes do not need to match.
    reduce_over_dims: The dimensions to reduce over.
    cumulative_axis: If not None, the reduced sums and counts are replaced by
      their cumulative sums (`tf.math.cumsum`) along this axis before the
      statistics are derived from them.
    enable_cross_replica_sum_on_tpu: If True and running on TPU, the sums and
      the count are aggregated with `tf.tpu.cross_replica_sum`.
    keepdims: Passed to the `tf.reduce_sum` calls; if True the reduced
      dimensions are kept with size 1.

  Returns:
    A tuple (mean, variance).
  """
  mask = 1.0 - padding
  inputs = py_utils.with_dependencies([
      py_utils.assert_equal(tf.rank(inputs), tf.rank(mask)),
      py_utils.assert_greater_equal(mask, tf.zeros_like(mask)),
  ], inputs)
  # Cast once so that the weighted sum, the count and the variance all share
  # the dtype of `inputs`; otherwise a padding of another dtype breaks the
  # divisions and products below. No-op when the dtypes already agree.
  mask = tf.cast(mask, inputs.dtype)
  sum_v = tf.reduce_sum(inputs * mask, reduce_over_dims, keepdims=keepdims)
  count_v = tf.reduce_sum(mask, reduce_over_dims, keepdims=keepdims)

  if cumulative_axis is not None:
    sum_v = tf.math.cumsum(sum_v, axis=cumulative_axis)
    count_v = tf.math.cumsum(count_v, axis=cumulative_axis)
  # Input shape is guaranteed to be a multiple of mask shape because the
  # inputs * mask op above was successfully broadcasted.
  input_size_on_reduced_dims = tf.reduce_prod(
      tf.gather(tf.shape(inputs), reduce_over_dims))
  mask_size_on_reduced_dims = tf.reduce_prod(
      tf.gather(tf.shape(mask), reduce_over_dims))
  mask_multiplier = tf.math.truediv(input_size_on_reduced_dims,
                                    mask_size_on_reduced_dims)
  count_v *= tf.cast(mask_multiplier, count_v.dtype)
  if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
    sum_v = tf.tpu.cross_replica_sum(sum_v)
    count_v = tf.tpu.cross_replica_sum(count_v)

  # Guard against division by zero when there is no valid data point.
  count_v = tf.maximum(count_v, 1.0)
  mean = sum_v / count_v
  # NOTE: `mean` has to broadcast against `inputs` here, whatever
  # `reduce_over_dims` and `keepdims` are.
  centered = inputs - mean
  sum_vv = tf.reduce_sum(
      centered * centered * mask, reduce_over_dims, keepdims=keepdims)
  if cumulative_axis is not None:
    sum_vv = tf.math.cumsum(sum_vv, axis=cumulative_axis)

  if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
    sum_vv = tf.tpu.cross_replica_sum(sum_vv)

  variance = py_utils.with_dependencies([
      py_utils.assert_greater_equal(sum_vv, tf.zeros_like(sum_vv)),
  ], sum_vv / count_v)
  return mean, variance
def _Slice(tensor):
  """Return a slice of this tensor at time=state0.t.

  `self` and `state0` are taken from the enclosing scope. The time dimension
  is `self.params.axis`; it is removed from the result.
  """
  axis = self.params.axis
  dims = py_utils.GetShape(tensor)
  # Start offsets: zero everywhere except state0.t on the time dimension,
  # e.g. for axis=1 this is [0, t, 0, 0, 0, ...].
  start = tf.one_hot(axis, tf.rank(tensor), on_value=state0.t)
  # Extent: the full shape, but a single step on the time dimension,
  # e.g. for axis=1 this is [dims[0], 1, dims[2], dims[3], ...].
  extent = tf.concat(
      [dims[0:axis],
       tf.constant([1], dtype=tf.int32),
       dims[axis + 1:]],
      axis=0)
  # Cut out the single time step, then drop the size-1 time dimension.
  step = tf.slice(tensor, start, extent)
  return tf.squeeze(step, axis=axis)
def FProp(self, theta, input_batch):
  """Embeds source ids and transforms with TransformerStack.

  When not running on TPU and the `transformer_encoder_truncates_inputs`
  flag is set, ids, paddings, segment ids/positions and task ids are all
  truncated along time to the longest non-padded sequence in the batch.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer
      and its children layers.
    input_batch: A `.NestedMap` with fields:

      - ids: The inputs tensor. It is expected to be of shape [batch, time].
      - paddings: The paddings tensor. Expected shape [batch, time].
      - task_ids: If p.task_emb is provided, must contain per-token task
        ids of shape [batch, time].
      - segment_ids, segment_pos: Read only if p.packed_input is set.

  Returns:
    A NestedMap containing

    - encoded: The encoded features, either a tensor of shape
      [time, batch, depth], or a list of tensors if is_transparent is set
      in transformer_stack.
    - padding: of shape [time, batch]
    - segment_id: [time, batch] if packed inputs are supported by the model
      (and all layers), or None otherwise.
    - embedded_inputs: [time, batch, depth] embedded inputs tokens without
      positional encodings.
  """
  p = self.params
  with tf.name_scope(p.name):
    src_segment_id = None
    src_segment_pos = None
    input_ids = py_utils.with_dependencies([
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings)),
        py_utils.assert_equal(tf.rank(input_batch.ids), 2)
    ], input_batch.ids)

    # Stays None unless the inputs are truncated below.
    max_seq_length = None
    if (not py_utils.use_tpu() and
        tf.flags.FLAGS.transformer_encoder_truncates_inputs):
      # Length of the longest non-padded sequence in the batch.
      max_seq_length = tf.cast(
          tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)),
          tf.int32)
      # Everything that is about to be cut off must be padding.
      paddings = py_utils.with_dependencies([
          py_utils.assert_equal(
              tf.constant(True, tf.bool),
              tf.reduce_all(input_batch.paddings[:, max_seq_length:] > 0.5))
      ], input_batch.paddings)
      input_ids = input_ids[:, :max_seq_length]
      paddings = paddings[:, :max_seq_length]
      if p.packed_input:
        src_segment_id = input_batch.segment_ids[:, :max_seq_length]
        src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
    else:
      paddings = input_batch.paddings
      if p.packed_input:
        src_segment_id = input_batch.segment_ids
        src_segment_pos = input_batch.segment_pos

    max_time = tf.shape(input_ids)[1]

    # Input token embeddings + positional embeddings
    if not p.shared_emb:
      input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                            tf.reshape(input_ids, [-1]))
    else:
      input_embs = self.softmax.EmbLookup(theta.softmax,
                                          tf.reshape(input_ids, [-1]))

    input_embs = tf.reshape(input_embs,
                            [-1, max_time, p.token_emb.embedding_dim])
    # [time, batch, dim]
    orig_input_embs = tf.transpose(input_embs, [1, 0, 2])

    if p.packed_input:
      position_embs = self.position_emb.FPropWithPosition(
          theta.position_emb, src_segment_pos)
    else:
      position_embs = self.position_emb.FProp(theta.position_emb, max_time)
      position_embs = tf.reshape(position_embs,
                                 [1, max_time, p.token_emb.embedding_dim])
    input_embs += position_embs
    if p.task_emb:
      task_ids = input_batch.task_ids
      if max_seq_length is not None:
        # Keep the task ids aligned with the truncated input ids; otherwise
        # the task embeddings cover more time steps than input_embs.
        task_ids = task_ids[:, :max_seq_length]
      input_embs += self.task_emb.EmbLookup(theta.task_emb, task_ids)

    if p.model_dim != p.token_emb.embedding_dim:
      input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

    paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
    if p.packed_input:
      src_segment_id = tf.transpose(src_segment_id)
    input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)

    # [time, batch, dim]
    transformer_input = tf.transpose(input_embs, [1, 0, 2])

  if not self.do_eval and p.apply_source_mask:
    # Augment padding for masked source word positions.
    dtype = paddings.dtype
    source_mask = tf.where(
        tf.equal(input_ids, p.source_mask_id),
        tf.ones_like(input_ids, dtype=dtype),
        tf.zeros_like(input_ids, dtype=dtype))
    # Make sure padding is between 0 and 1.
    paddings = tf.clip_by_value(paddings + tf.transpose(source_mask), 0.0,
                                1.0)

  encoded, padding, segment_id = self.transformer_stack.FProp(
      theta.transformer_stack, transformer_input, paddings, src_segment_id)
  return py_utils.NestedMap(
      encoded=encoded,
      padding=padding,
      segment_id=segment_id,
      embedded_inputs=orig_input_embs)
def ComputeAndUpdateMoments(self, theta, inputs, paddings=None, **kwargs):
  """Computes the normalization moments and updates the moving averages.

  In eval mode, or when p.freeze_bn_stats is set, the stored moving moments
  are returned unchanged. Otherwise the batch moments are computed, folded
  into the moving averages and exported as histogram summaries.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer
      and its children layers.
    inputs: The inputs tensor. Shaped [..., dim].
    paddings: The paddings tensor. Shaped [..., 1], with the same rank as
      the input tensor. If None, `self._GetDefaultPaddings(inputs)` is used.
    **kwargs: Additional inputs, forwarded to `self._GetBetaGamma`.

  Returns:
    Tuple of (mean, variance, beta, gamma).
  """
  p = self.params
  if paddings is None:
    paddings = self._GetDefaultPaddings(inputs)
  # The last dimension of paddings has to be 1.
  last_dim_check = py_utils.assert_shape_match([tf.shape(paddings)[-1]], [1])
  inputs = py_utils.with_dependencies([last_dim_check], inputs)

  def _Histogram(tag, value):
    """Exports `value`, cast to float32, as a histogram summary."""
    summary_utils.histogram(tag, tf.cast(value, tf.float32))

  with tf.name_scope(p.name):
    bn_vars = self.vars
    if not (self.do_eval or p.freeze_bn_stats):
      # Batch statistics over every dimension but the last one.
      leading_dims = tf.range(0, tf.rank(paddings) - 1)
      mean, variance = ComputeMoments(
          inputs,
          paddings,
          leading_dims,
          cumulative_axis=None,
          enable_cross_replica_sum_on_tpu=p.enable_cross_replica_sum_on_tpu)

      py_utils.UpdateBatchNormVars(bn_vars.moving_mean, mean, self._decay)
      py_utils.UpdateBatchNormVars(bn_vars.moving_variance, variance,
                                   self._decay)

      # Add some summaries for visualization.
      _Histogram('%s_mean' % p.name, mean)
      _Histogram('%s_variance' % p.name, variance)
      _Histogram('%s_moving_mean' % p.name, bn_vars.moving_mean)
      _Histogram('%s_moving_variance' % p.name, bn_vars.moving_variance)
      _Histogram(
          '%s_mean_diff' % p.name,
          tf.cast(mean, bn_vars.moving_mean.dtype.base_dtype) -
          bn_vars.moving_mean)
      _Histogram(
          '%s_variance_diff' % p.name,
          tf.cast(variance, bn_vars.moving_variance.dtype.base_dtype) -
          bn_vars.moving_variance)

      if not p.use_moving_avg_in_training:
        # Use the batch statistics for normalization.
        norm_mean = mean
        norm_variance = variance
      else:
        # Use the global statistics for normalization.
        # Control dependencies on mean and variance make sure
        # moving_mean and variance will be updated for every training step.
        norm_mean = py_utils.with_dependencies([mean], bn_vars.moving_mean)
        norm_variance = py_utils.with_dependencies([variance],
                                                   bn_vars.moving_variance)
    else:
      # The mean and variance used for normalization.
      norm_mean = bn_vars.moving_mean
      norm_variance = bn_vars.moving_variance

    norm_mean = py_utils.CheckNumerics(
        norm_mean, 'mean of %s failed numeric check' % p.name)
    norm_variance = py_utils.CheckNumerics(
        norm_variance, 'variance of %s failed numeric check' % p.name)
    beta, gamma = self._GetBetaGamma(theta, inputs, **kwargs)
    return norm_mean, norm_variance, beta, gamma
def FProp(self, theta, input_batch, interpolation_batch=None, lambdas=None):
  # pyformat: disable
  """Interpolates source ids in input_batch and interpolation_batch.

  Refer to Eq. (4) in paper https://arxiv.org/abs/2106.04060.
  It is a standard Transformer Encoder if interpolation_batch is None.

  Args:
    theta: A `.NestedMap` object containing weights values of this layer and
      its children layers.
    input_batch: A `.NestedMap` with fields:

      - ids: The inputs tensor. It is expected to be of shape [batch, time].
      - paddings: The paddings tensor. Expected shape [batch, time].
      - task_ids: If p.task_emb is provided, must contain per-token task
        ids of shape [batch, time].
      - embs: Optional embeddings of ids. Only read when interpolation_batch
        is given; replaces the embedding lookup of ids.
      - segment_ids, segment_pos: Read only if p.packed_input is set.
    interpolation_batch: A `.NestedMap` with fields:

      - ids: The inputs tensor. It is expected to be of shape [batch, time].
      - paddings: The paddings tensor. Expected shape [batch, time].
      - task_ids: If p.task_emb is provided, must contain per-token task
        ids of shape [batch, time].
      - embs: Embeddings of ids.
    lambdas: A pair of tensors to combine embeddings of ids in input_batch
      and interpolation_batch. Required when interpolation_batch is given.
      Each gets a trailing dimension of size 1 appended before use.

  Returns:
    A NestedMap of

    - encoded: The encoded features, either a tensor of shape
      [time, batch, depth], or a list of tensors if is_transparent is set in
      transformer_stack.
    - padding: of shape [time, batch]
    - segment_id: [time, batch] if packed inputs are supported by the model
      (and all layers), or None otherwise.
    - embedded_inputs: [batch, time, depth] embedded (and, if requested,
      interpolated) inputs tokens without positional encodings.
  """
  # pyformat: enable

  p = self.params
  with tf.name_scope(p.name):
    src_segment_id = None
    src_segment_pos = None
    input_ids = py_utils.with_dependencies([
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings)),
        py_utils.assert_equal(tf.rank(input_batch.ids), 2)
    ], input_batch.ids)

    # Stays None unless the inputs are truncated below.
    # NOTE(review): only input_batch is truncated; interpolation_batch is
    # used at full length. Confirm the two are not meant to be combined.
    max_seq_length = None
    if (not py_utils.use_tpu() and
        FLAGS.transformer_encoder_truncates_inputs):
      # Length of the longest non-padded sequence in the batch.
      max_seq_length = tf.cast(
          tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)),
          tf.int32)
      # Everything that is about to be cut off must be padding.
      paddings = py_utils.with_dependencies([
          py_utils.assert_equal(
              tf.constant(True, tf.bool),
              tf.reduce_all(input_batch.paddings[:, max_seq_length:] > 0.5))
      ], input_batch.paddings)
      input_ids = input_ids[:, :max_seq_length]
      paddings = paddings[:, :max_seq_length]
      if p.packed_input:
        src_segment_id = input_batch.segment_ids[:, :max_seq_length]
        src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
    else:
      paddings = input_batch.paddings
      if p.packed_input:
        src_segment_id = input_batch.segment_ids
        src_segment_pos = input_batch.segment_pos

    max_time = tf.shape(input_ids)[1]

    # Input token embeddings + positional embeddings
    if not p.shared_emb:
      input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                            tf.reshape(input_ids, [-1]))
    else:
      input_embs = self.softmax.EmbLookup(theta.softmax,
                                          tf.reshape(input_ids, [-1]))

    if interpolation_batch is not None:
      other_input_ids = interpolation_batch.ids
      if not p.shared_emb:
        other_input_embs = self.token_emb.EmbLookup(
            theta.token_emb, tf.reshape(other_input_ids, [-1]))
      else:
        other_input_embs = self.softmax.EmbLookup(
            theta.softmax, tf.reshape(other_input_ids, [-1]))
      # Trailing size-1 dim so each lambda broadcasts over the embedding dim.
      lambdas = [tf.expand_dims(a, -1) for a in lambdas]
      if 'embs' in input_batch and input_batch.embs is not None:
        input_embs = input_batch.embs
      if ('embs' in interpolation_batch and
          interpolation_batch.embs is not None):
        other_input_embs = interpolation_batch.embs
      else:
        other_input_embs = tf.reshape(
            other_input_embs,
            [-1, tf.shape(other_input_ids)[1], p.token_emb.embedding_dim])
      # Always bring input_embs to [batch, time, dim] before mixing, whether
      # it was looked up (flat) or supplied, so that it lines up with the
      # expanded lambdas even when only interpolation_batch carries embs.
      input_embs = tf.reshape(input_embs,
                              [-1, max_time, p.token_emb.embedding_dim])
      input_embs = lambdas[0] * input_embs + lambdas[1] * other_input_embs
      # A position stays padded only if it is padded in both batches.
      paddings = paddings + interpolation_batch.paddings - 1.0
      paddings = tf.clip_by_value(paddings, 0.0, 1.0)

    input_embs = tf.reshape(input_embs,
                            [-1, max_time, p.token_emb.embedding_dim])
    # [batch, time, dim], before task and positional embeddings are added.
    orig_input_embs = input_embs
    if p.task_emb:
      task_ids = input_batch.task_ids
      if max_seq_length is not None:
        # Keep the task ids aligned with the truncated input ids; otherwise
        # the task embeddings cover more time steps than input_embs.
        task_ids = task_ids[:, :max_seq_length]
      if interpolation_batch is None:
        input_embs += self.task_emb.EmbLookup(theta.task_emb, task_ids)
      else:
        task_embs = self.task_emb.EmbLookup(theta.task_emb, task_ids)
        other_task_embs = self.task_emb.EmbLookup(
            theta.task_emb, interpolation_batch.task_ids)
        task_embs = lambdas[0] * task_embs + lambdas[1] * other_task_embs
        input_embs += task_embs

    if p.packed_input:
      position_embs = self.position_emb.FPropWithPosition(
          theta.position_emb, src_segment_pos)
    else:
      position_embs = self.position_emb.FProp(theta.position_emb, max_time)
      position_embs = tf.reshape(position_embs,
                                 [1, max_time, p.token_emb.embedding_dim])
    input_embs += position_embs

    if p.model_dim != p.token_emb.embedding_dim:
      input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

    paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
    if p.packed_input:
      src_segment_id = tf.transpose(src_segment_id)
    input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)

    # [time, batch, dim]
    transformer_input = tf.transpose(input_embs, [1, 0, 2])

  if not self.do_eval and p.apply_source_mask:
    # Augment padding for masked source word positions.
    dtype = paddings.dtype
    source_mask = tf.where(
        tf.equal(input_ids, p.source_mask_id),
        tf.ones_like(input_ids, dtype=dtype),
        tf.zeros_like(input_ids, dtype=dtype))
    # Make sure padding is between 0 and 1.
    paddings = tf.clip_by_value(paddings + tf.transpose(source_mask), 0.0,
                                1.0)

  encoded, padding, segment_id = self.transformer_stack.FProp(
      theta.transformer_stack, transformer_input, paddings, src_segment_id)
  return py_utils.NestedMap(
      encoded=encoded,
      padding=padding,
      segment_id=segment_id,
      embedded_inputs=orig_input_embs)