def FProp(self, theta, inputs, paddings=None): """Apply batch normalization. Args: theta: A `.NestedMap` object containing weights' values of this layer and its children layers. inputs: The inputs tensor. Shaped [..., dim]. paddings: The paddings tensor. Shaped [..., 1], with the same rank as the input tensor. Returns: Output after applying batch normalization, with the same shape as 'inputs'. """ p = self.params if paddings is None: paddings = self._GetDefaultPaddings(inputs) with tf.name_scope(p.name): norm_mean, norm_variance, beta, gamma = self.ComputeAndUpdateMoments( theta, inputs, paddings) with tf.control_dependencies([ py_utils.assert_greater_equal( norm_variance, tf.zeros_like(norm_variance)), py_utils.assert_shape_match([tf.shape(inputs)[-1]], tf.shape(norm_mean)), py_utils.assert_shape_match([tf.shape(inputs)[-1]], tf.shape(norm_variance)), ]): bn_output = tf.nn.batch_normalization(inputs, norm_mean, norm_variance, beta, gamma, self._epsilon) bn_output *= 1.0 - paddings return bn_output
def _ComputeBN(self, inputs, paddings, gamma, beta, norm_mean, norm_variance):
  """Computes batch-normalized output from precomputed moments."""
  p = self.params
  with tf.control_dependencies([
      py_utils.assert_greater_equal(norm_variance,
                                    tf.zeros_like(norm_variance)),
      py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                  tf.shape(norm_mean)),
      py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                  tf.shape(norm_variance)),
  ]):
    if p.use_fused_batch_norm_for_eval and (self.do_eval or
                                            p.freeze_bn_stats):
      # Use the fused kernel; only valid with frozen statistics.
      bn_output, _, _ = nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          norm_mean,
          norm_variance,
          self._epsilon,
          is_training=False)
    else:
      bn_output = tf.nn.batch_normalization(inputs, norm_mean, norm_variance,
                                            beta, gamma, self._epsilon)
    if p.set_padded_output_to_zero:
      bn_output = py_utils.ApplyPadding(paddings, bn_output)
  return bn_output
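# Illustrative check (not lingvo code) that the two branches above agree in
# eval mode: with frozen statistics, the fused kernel and the composed op
# compute the same normalization.
import tensorflow as tf

x = tf.random.normal([2, 4, 4, 8])
gamma, beta = tf.ones([8]), tf.zeros([8])
mean, variance = tf.zeros([8]), tf.ones([8])
fused, _, _ = tf.compat.v1.nn.fused_batch_norm(
    x, gamma, beta, mean, variance, epsilon=1e-3, is_training=False)
composed = tf.nn.batch_normalization(x, mean, variance, beta, gamma, 1e-3)
# `fused` and `composed` agree up to numerical tolerance.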
def FProp(self, theta, input_batch):
  p = self.params
  src_segment_id = None
  with tf.name_scope(p.name):
    inputs = py_utils.with_dependencies([
        py_utils.assert_shape_match(tf.shape(input_batch.ids), [-1, -1]),
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings))
    ], tf.transpose(input_batch.ids))
    paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2)
    xs = self.emb.EmbLookup(theta.emb, inputs)
    xs = self.ApplyClipping(theta, xs)
    summary_utils.histogram('input_emb', xs)
    xs = self.dropout.FProp(theta.dropout, xs)
    ps = paddings
    # Now the rnn layers.
    outputs_list = []
    for i in range(0, p.num_lstm_layers):
      layer = self.rnn[i]
      ys, _ = layer.FProp(theta.rnn[i], xs, ps)
      ys = self.dropout.FProp(theta.dropout, ys)
      if i >= p.residual_start:
        xs += ys  # Residual skip
        xs = self.ApplyClipping(theta, xs)
      else:
        xs = ys
      outputs_list.append(xs)
      summary_utils.histogram('layer_out_%s' % i, xs)

    if p.is_transparent:
      xs = self.transparent_merger.FProp(theta.transparent_merger,
                                         outputs_list)

    return py_utils.NestedMap(
        encoded=xs, padding=tf.squeeze(ps, [2]), segment_id=src_segment_id)
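# The residual pattern used in the loop above, in isolation (illustrative):
# layers at index >= residual_start add their output back to their input.
def residual_stack_sketch(xs, layer_fns, residual_start):
  for i, f in enumerate(layer_fns):
    ys = f(xs)
    xs = xs + ys if i >= residual_start else ys
  return xs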
def FProp(self, theta, inputs, paddings=None): """Apply group normalization. Args: theta: A NestedMap object containing weights' values of this layer and its children layers. inputs: The inputs tensor with shape [batch_size, height, width, channel]. paddings: The paddings tensor with shape [batch_size, height]. Intended to be used for sequence processing where `height` is `time`. Returns: A single tensor as the output after applying group normalization, with the same shape as 'inputs'. Or a output, output_paddings pair if input paddings is not None. """ p = self.params n, h, w, c = tf.unstack(tf.shape(inputs), axis=0, num=4) group_size = p.dim // p.num_groups num_groups = p.num_groups min_group_size = p.min_group_size if p.dim > p.min_group_size else p.dim if group_size <= min_group_size: group_size = min_group_size num_groups = p.dim // group_size with tf.name_scope(p.name): x = tf.reshape(inputs, [n, h, w, num_groups, group_size]) if paddings is None: counts, means_ss, variance_ss, _, = tf.nn.sufficient_statistics( x, axes=[1, 2, 4], keepdims=True) norm_mean, norm_variance = tf.nn.normalize_moments( counts, means_ss, variance_ss, None) else: expanded_paddings = tf.reshape(paddings, [n, h, 1, 1, 1]) norm_mean, norm_variance = ComputeMomentsWithPadding( x, expanded_paddings, [1, 2, 4], keepdims=True) norm_mean = py_utils.CheckNumerics( norm_mean, 'mean of %s failed numeric check' % p.name) norm_variance = py_utils.CheckNumerics( norm_variance, 'variance of %s failed numeric check' % p.name) beta = theta.beta gamma = theta.gamma with tf.control_dependencies([ py_utils.assert_greater_equal( norm_variance, tf.cast(0., norm_variance.dtype)), py_utils.assert_shape_match([n, 1, 1, num_groups, 1], tf.shape(norm_mean)), py_utils.assert_shape_match([n, 1, 1, num_groups, 1], tf.shape(norm_variance)), ]): x = (x - norm_mean) / tf.sqrt(norm_variance + self._epsilon) x = tf.reshape(x, [n, h, w, c]) gn_output = x * gamma + beta gn_output = tf.reshape(gn_output, [n, h, w, c]) if paddings is None: return gn_output else: return gn_output, paddings
def FProp(self, theta, input_batch): """Encodes source as represented by `inputs` and `paddings`. Args: theta: A `.NestedMap` object containing weights' values of this layer and its children layers. input_batch: A `.NestedMap` with fields: - ids: The inputs tensor. It is expected to be of shape [batch, time]. - paddings: The paddings tensor. Expected shape [batch, time]. Returns: A NestedMap containing: - encoded: The encoded features, a tensor of shape [time, batch, depth] - padding: of shape [time, batch] - segment_id: [time, batch] if packed inputs are supported by the model (and all layers), or None otherwise. """ p = self.params src_segment_id = None with tf.name_scope(p.name): # Now the rnn layers. inputs = py_utils.with_dependencies([ py_utils.assert_shape_match(tf.shape(input_batch.ids), [-1, -1]), py_utils.assert_shape_match(tf.shape(input_batch.ids), tf.shape(input_batch.paddings)) ], tf.transpose(input_batch.ids)) paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2) xs = self.emb.EmbLookup(theta.emb, inputs) xs = self.ApplyClipping(theta, xs) self._emb_out = xs ps = paddings # When cc_schedule is specified, make sure lstm_tpl is QuantizedLSTMCell # with the same cc_schedule so that the RNN layer output is within # clipping range. xs = self.rnn[0].FProp(theta.rnn[0], xs, ps) xs = self.dropout.FProp(theta.dropout, xs) for i in range(1, p.num_lstm_layers): layer = self.rnn[i] ys, _ = layer.FProp(theta.rnn[i], xs, ps) ys = self.dropout.FProp(theta.dropout, ys) if hasattr(layer.params, 'cell'): layer_params = layer.params.cell else: layer_params = layer.params if layer_params.num_input_nodes == layer_params.num_output_nodes: xs += ys # Residual skip xs = self.ApplyClipping(theta, xs) else: # When cc_schedule is specified, make sure lstm_tpl is # QuantizedLSTMCell with the same cc_schedule so that the RNN layer # output is within clipping range. xs = ys return py_utils.NestedMap(encoded=xs, padding=tf.squeeze(ps, [2]), segment_id=src_segment_id)
def FProp(self, theta, inputs, paddings): """Apply convolution to inputs. Args: theta: A `.NestedMap` object containing weights' values of this layer and its children layers. inputs: The inputs tensor. It is expected to be of shape [batch, time, frequency, channel]. The time dimension corresponds to the height dimension as in images and the frequency dimension corresponds to the width dimension as in images. paddings: The paddings tensor, expected to be of shape [batch, time]. Returns: outputs, out_paddings pair. """ p = self.params with tf.name_scope(p.name): inputs = py_utils.with_dependencies([ py_utils.assert_shape_match(tf.shape(paddings), [-1, -1]), py_utils.assert_shape_match( tf.shape(inputs), tf.concat([ tf.shape(paddings), [-1, symbolic.ToStatic(self.input_channels)] ], 0)) ], inputs) def _ApplyPadding(tensor_in, padding_in): padding_expanded = tf.expand_dims( tf.expand_dims(padding_in, -1), -1) return tensor_in * (1.0 - padding_expanded) # Zeroing out padded inputs. inputs = _ApplyPadding(inputs, paddings) # Apply conv on 'inputs'. out = self._ApplyConv(theta, inputs) if p.partial_conv: out = self._RescaleBoundary(out, paddings) # NOTE: this may be slightly inaccurate when p.dilation_rate[0] > 1. # But there's likely no real problems. Trying to set it gives an error: # pooling with SAME padding is not implemented for dilation_rate > 1. # NOTE: we use window=p.filter_stride[0] to be compatible with legacy # implementation. Consider updating it to be the actual shape. conv_padding = ComputeConvOutputPadding(paddings, window=p.filter_stride[0], stride=p.filter_stride[0]) # Assuming padded nodes will be properly zero-ed out if necessary by # sub-sequent layers. # out = _ApplyPadding(out, conv_padding) out = py_utils.HasShape( out, symbolic.ToStatic(self.OutShape(tf.shape(inputs)))) return out, conv_padding
def FProp(self, theta, input_batch):
  p = self.params
  with tf.name_scope(p.name):
    inputs = py_utils.with_dependencies([
        py_utils.assert_shape_match(tf.shape(input_batch.ids), [-1, -1]),
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings))
    ], tf.transpose(input_batch.ids))
    paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2)
    if p.packed_input:
      src_segment_id = tf.expand_dims(
          tf.transpose(input_batch.segment_ids), 2)
    else:
      src_segment_id = None

    xs = self._ComputeInputs(theta, inputs, input_batch)
    summary_utils.histogram('input_emb', xs)
    ps = paddings
    # Now the rnn layers.
    outputs_list = []
    for i in range(0, p.num_lstm_layers):
      layer = self.rnn[i]
      ys = layer.FProp(theta.rnn[i], xs, ps, segment_id=src_segment_id)
      ys = self.dropout.FProp(theta.dropout, ys)
      if i >= p.residual_start:
        xs += ys  # Residual skip
        xs = self.ApplyClipping(theta, xs)
      else:
        xs = ys
      outputs_list.append(xs)
      summary_utils.histogram('layer_out_%s' % i, xs)

    if p.is_transparent:
      xs = self.transparent_merger.FProp(theta.transparent_merger,
                                         outputs_list)

    if p.lstm_cell_size * 2 != p.encoder_out_dim:
      # Project to the right depth.
      xs = self.final_proj.FProp(theta.final_proj, xs, ps)
      summary_utils.histogram('final_proj_out', xs)

    if src_segment_id is not None:
      src_segment_id = tf.squeeze(src_segment_id, [2])

    return py_utils.NestedMap(
        encoded=xs, padding=tf.squeeze(ps, [2]), segment_id=src_segment_id)
def FProp(self, theta, inputs): """Applies batch normalization. Using the implementation in github.com/ tensorflow/tpu/blob/master/models/official/amoeba_net/network_utils.py#L550 Args: theta: A nested map object containing weights' values of this layer and its children layers. inputs: The inputs tensor. Shaped [..., dim]. Returns: Output after applying batch normalization, with the same shape as 'inputs'. """ p = self.params inputs_dtype = inputs.dtype inputs = tf.cast(inputs, p.dtype) inputs = py_utils.with_dependencies( [py_utils.assert_shape_match([tf.shape(inputs)[-1]], [p.dim])], inputs) with tf.name_scope(p.name) as scope: if p.is_eval: outputs = tf.nn.batch_normalization(inputs, theta.moving_mean, theta.moving_variance, theta.beta, theta.gamma, p.epsilon) else: mean, variance = self._Moments(inputs, p.bn_group_size) mean = py_utils.CheckNumerics( mean, 'mean of {} failed numeric check'.format(scope)) variance = py_utils.CheckNumerics( variance, 'variance of {} failed numeric check'.format(scope)) outputs = tf.nn.batch_normalization(inputs, mean, variance, theta.beta, theta.gamma, p.epsilon) outputs.set_shape(inputs.get_shape()) return tf.cast(outputs, inputs_dtype)
def StreamStep(self, theta, inputs, paddings, state0):
  """Apply a single step of convolution to `inputs`.

  Only supports 1d causal convolution. Doesn't support dilation.

  Args:
    theta: A NestedMap of layer params.
    inputs: A Tensor of shape [b, t=1, 1, c].
    paddings: A 0/1 valued tensor of shape [b, t=1].
    state0: A NestedMap of tensors of the same struct as returned by
      zero_state().

  Returns:
    outputs: A Tensor of shape [b, t=1, 1, c * channel_multiplier].
    padding: the same as input paddings.
    state1: A NestedMap of the same struct as input state.
  """
  p = self.params
  assert p.filter_shape[1] == 1, (
      'StreamStep only supports 1d causal convolution.')
  assert p.filter_stride[0] == 1, ("StreamStep doesn't support striding.")
  assert p.dilation_rate == (1, 1), ("StreamStep doesn't support dilation.")

  with tf.name_scope(p.name):
    inputs = py_utils.with_dependencies([
        py_utils.assert_shape_match(
            py_utils.GetShape(inputs), [-1, 1, 1, p.filter_shape[2]])
    ], inputs)
    b = py_utils.GetShape(inputs)[0]

    # Next state: shift the context window left by one frame and append the
    # new frame.
    state1 = py_utils.NestedMap(
        context=tf.concat([state0.context[:, 1:, :, :], inputs], axis=1))

    expanded_paddings = tf.reshape(paddings, [b, 1, 1, 1])
    # Not updating the states for padded examples.
    state1.context = (
        state0.context * expanded_paddings +
        state1.context * (1. - expanded_paddings))

    outputs = tf.nn.depthwise_conv2d(
        state1.context,
        self._GetWeight(self.theta),
        strides=(1, 1, 1, 1),
        dilations=(1, 1),
        data_format='NHWC',
        padding='VALID')
    return outputs, paddings, state1
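# Illustrative streaming invariant (not lingvo code): feeding frames one at a
# time through a buffer of the last k-1 inputs, as StreamStep does, reproduces
# a causal depthwise convolution over the whole sequence.
import tensorflow as tf


def stream_causal_depthwise_sketch(x, filt):
  """x: [b, t, 1, c] (static shapes); filt: [k, 1, c, 1]."""
  b, t, _, c = x.shape
  k = filt.shape[0]
  context = tf.zeros([b, k - 1, 1, c])  # zero_state() analogue.
  outs = []
  for i in range(t):
    window = tf.concat([context, x[:, i:i + 1]], axis=1)  # [b, k, 1, c]
    outs.append(
        tf.nn.depthwise_conv2d(
            window, filt, strides=[1, 1, 1, 1], padding='VALID'))
    context = window[:, 1:]  # Shift the buffer by one frame.
  return tf.concat(outs, axis=1)  # [b, t, 1, c]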
def ComputeAndUpdateMoments(self, theta, inputs, paddings=None):
  """Computes moments and updates state.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor. Shaped [..., dim].
    paddings: The paddings tensor. Shaped [..., 1], with the same rank as the
      input tensor.

  Returns:
    Tuple of (mean, variance, beta, gamma).
  """
  p = self.params
  if paddings is None:
    paddings = self._GetDefaultPaddings(inputs)
  inputs = py_utils.with_dependencies([
      py_utils.assert_shape_match([tf.shape(paddings)[-1]], [1]),
  ], inputs)
  with tf.name_scope(p.name):
    if self.do_eval:
      # The mean and variance used for normalization.
      norm_mean, norm_variance = self._moving_mean, self._moving_variance
    else:
      mean, variance = self._Moments(inputs, 1.0 - paddings,
                                     p.enable_cross_replica_sum_on_tpu)
      py_utils.UpdateBatchNormVars(self._moving_mean, mean, self._decay)
      py_utils.UpdateBatchNormVars(self._moving_variance, variance,
                                   self._decay)
      # Add some summaries for visualization.
      summary_utils.histogram('%s_mean' % p.name, tf.cast(mean, tf.float32))
      summary_utils.histogram('%s_variance' % p.name,
                              tf.cast(variance, tf.float32))
      summary_utils.histogram('%s_moving_mean' % p.name,
                              tf.cast(self._moving_mean, tf.float32))
      summary_utils.histogram('%s_moving_variance' % p.name,
                              tf.cast(self._moving_variance, tf.float32))
      summary_utils.histogram('%s_mean_diff' % p.name,
                              tf.cast(mean - self._moving_mean, tf.float32))
      summary_utils.histogram(
          '%s_variance_diff' % p.name,
          tf.cast(variance - self._moving_variance, tf.float32))
      if p.use_moving_avg_in_training:
        # Use the global statistics for normalization.
        # Control dependencies on mean and variance make sure
        # moving_mean and variance will be updated for every training step.
        norm_mean = py_utils.with_dependencies([mean], self._moving_mean)
        norm_variance = py_utils.with_dependencies([variance],
                                                   self._moving_variance)
      else:
        # Use the batch statistics for normalization.
        norm_mean = mean
        norm_variance = variance

    norm_mean = py_utils.CheckNumerics(
        norm_mean, 'mean of %s failed numeric check' % p.name)
    norm_variance = py_utils.CheckNumerics(
        norm_variance, 'variance of %s failed numeric check' % p.name)

    if p.use_moving_avg_in_training:
      beta = 0.0
      gamma = 1.0
    else:
      beta = theta.beta
      gamma = theta.gamma
    return norm_mean, norm_variance, beta, gamma
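# The moving-average update delegated to py_utils.UpdateBatchNormVars above
# amounts, in effect, to the standard exponential moving average (illustrative
# sketch, not the lingvo implementation):
#   moving_stat <- decay * moving_stat + (1 - decay) * batch_stat
def ema_update_sketch(moving_stat, batch_stat, decay=0.999):
  return decay * moving_stat + (1.0 - decay) * batch_stat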
def FProp(self, theta, inputs, query_vec=None):
  """Combines the list of input tensors into a single tensor.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: A list of tensors of shape [..., hidden_dim] or
      [..., [pre_proj_input_dims[i]]] if pre_proj_input_dims is specified.
    query_vec: A tensor of shape [..., hidden_dim].

  Returns:
    A tensor of the same shape as the input tensors.

  Raises:
    ValueError: If p.merger_op is not supported.
  """
  p = self.params
  n_sources = len(inputs)

  if p.pre_proj_input_dims and len(p.pre_proj_input_dims) != n_sources:
    raise ValueError('pre_proj_input_dims must be specified for each input.')

  if n_sources == 1:
    return inputs[0]

  # Pre-projection operation.
  if p.pre_proj_input_dims:
    for i in range(n_sources):
      inputs[i] = self.pre_proj[i].FProp(theta.pre_proj[i], inputs[i])

  tensor_pairs = list(zip(inputs[:-1], inputs[1:]))
  if p.merger_op == 'mean':
    # Simply take the mean, all dims must match.
    with tf.control_dependencies([
        py_utils.assert_shape_match(tf.shape(t1), tf.shape(t2))
        for t1, t2 in tensor_pairs
    ]):
      output = tf.add_n(inputs) / n_sources
  elif p.merger_op == 'sum':
    # Sum up all sources, all dims must match.
    with tf.control_dependencies([
        py_utils.assert_shape_match(tf.shape(t1), tf.shape(t2))
        for t1, t2 in tensor_pairs
    ]):
      output = tf.add_n(inputs)
  elif p.merger_op == 'weighted_sum':
    # Weighted sum of all sources, all dims must match.
    # For weighted_sum, assume input is a list of rank 3 tensors.
    inputs = tf.stack(inputs)
    inputs = py_utils.HasRank(inputs, 4)
    with tf.control_dependencies([
        py_utils.assert_shape_match(tf.shape(t1), tf.shape(t2))
        for t1, t2 in tensor_pairs
    ]):
      w = tf.expand_dims(
          tf.expand_dims(tf.expand_dims(self._sum_weight, 1), 1), 1)
      w = tf.tile(
          w,
          [1,
           tf.shape(inputs)[1],
           tf.shape(inputs)[2],
           tf.shape(inputs)[3]])
      output = tf.reduce_sum(inputs * w, axis=0)
  elif p.merger_op == 'atten':
    # Apply attention over the concatenated tensor, all dims must match.
    with tf.control_dependencies([
        py_utils.assert_shape_match(tf.shape(t1), tf.shape(t2))
        for t1, t2 in tensor_pairs
    ]):
      inputs = tf.stack(inputs, axis=0)
      batch_size = tf.shape(inputs)[1]
      paddings = tf.zeros([n_sources, batch_size], dtype=inputs.dtype)
      self.atten.InitForSourcePacked(theta.atten, inputs, inputs, paddings)
      output, _, _ = self.atten.ComputeContextVector(
          theta.atten, tf.reshape(query_vec, [-1, p.query_dim]))
  elif p.merger_op == 'concat':
    # Concatenate over the last dim, all dims but last must match.
    with tf.control_dependencies([
        py_utils.assert_equal(tf.shape(t1)[:-1], tf.shape(t2)[:-1])
        for t1, t2 in tensor_pairs
    ]):
      output = tf.concat(inputs, axis=-1)
  elif p.merger_op == 'gated_avg':
    output = self.gated_average.FProp(theta.gated_average, inputs)
  else:
    raise ValueError('Unrecognized merger op: %s' % p.merger_op)

  return output
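# What the simpler merger ops reduce to on plain tensors (illustrative):
import tensorflow as tf

a, b = tf.ones([2, 3]), 2.0 * tf.ones([2, 3])
mean_out = tf.add_n([a, b]) / 2          # 'mean': all dims must match.
sum_out = tf.add_n([a, b])               # 'sum': all dims must match.
concat_out = tf.concat([a, b], axis=-1)  # 'concat': last dims may differ.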
def FProp(self,
          theta,
          query_vec,
          source_paddings,
          source_vecs=None,
          query_segment_id=None,
          source_segment_id=None):
  """Transformer attention, residual and normalization layer.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    query_vec: [target_time, target_batch, dim]
    source_paddings: [source_time, source_batch]
    source_vecs: [source_time, source_batch, dim].
    query_segment_id: [target_time, target_batch]
    source_segment_id: [source_time, source_batch]

  Returns:
    (output, atten_probs). output is of shape [target_time, target_batch,
    source_dim], atten_probs is of shape [target_time, target_batch,
    source_time].
  """
  p = self.params
  unnormalized_query_vec = query_vec
  query_vec = self.layer_norm.FProp(theta.layer_norm, query_vec)

  if source_vecs is None:
    source_vecs = query_vec
    source_segment_id = query_segment_id

  if p.is_masked:
    assert source_vecs is not None
    query_vec = py_utils.with_dependencies([
        py_utils.assert_shape_match(
            tf.shape(source_vecs), tf.shape(query_vec))
    ], query_vec)
    # Prepares mask for self-attention
    # [time, time]
    target_time = tf.shape(query_vec)[0]
    target_bs = tf.shape(query_vec)[1]
    triangle_padding = 1.0 - tf.matrix_band_part(
        tf.ones([target_time, target_time], dtype=py_utils.FPropDtype(p)),
        -1, 0)
    # [time, batch, time]
    causal_padding = tf.tile(
        tf.expand_dims(triangle_padding, 1), [1, target_bs, 1])
    causal_padding = tf.reshape(causal_padding, [-1, target_time])
  else:
    causal_padding = None

  query_dim = tf.shape(query_vec)[-1]
  packed_src = self.atten.PackSource(theta.atten, source_vecs, source_vecs,
                                     source_paddings, source_segment_id)

  if query_segment_id is not None:
    query_segment_id = tf.reshape(query_segment_id, [-1])
  ctx_vec, atten_prob, _ = self.atten.ComputeContextVectorWithSource(
      theta.atten,
      packed_src,
      tf.reshape(query_vec, [-1, query_dim]),
      per_step_source_padding=causal_padding,
      query_segment_id=query_segment_id)
  ctx_vec = self.residual_dropout.FProp(theta.residual_dropout, ctx_vec)
  input_to_add = (
      unnormalized_query_vec if p.add_unnormalized_input else query_vec)
  h = input_to_add + tf.reshape(ctx_vec, tf.shape(query_vec))
  atten_prob = tf.reshape(atten_prob, [
      tf.shape(query_vec)[0],
      tf.shape(query_vec)[1],
      tf.shape(source_vecs)[0]
  ])
  return h, atten_prob
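# The causal mask built above, in isolation (illustrative): for time T, entry
# [i, j] is 1.0 iff j > i, i.e. query i may not attend to key j.
import tensorflow as tf

T = 5
triangle_padding = 1.0 - tf.linalg.band_part(tf.ones([T, T]), -1, 0)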
def FProp(self, theta, input_batch): """Embeds source ids and transforms with TransformerStack. Args: theta: A `.NestedMap` object containing weights' values of this layer and its children layers. input_batch: A `.NestedMap` with fields: - ids: The inputs tensor. It is expected to be of shape [batch, time]. - paddings: The paddings tensor. Expected shape [batch, time]. - task_ids: If p.task_emb is provided, must contain per-token task ids of shape [batch, time]. Returns: A NestedMap containing - encoded: The encoded features, either a tensor of shape [time, batch, depth], or a list of tensors if is_transparent is set in transformer_stack. - padding: of shape [time, batch] - segment_id: [time, batch] if packed inputs are supported by the model (and all layers), or None otherwise. - embedded_inputs: [time, batch, depth] embedded inputs tokens without positional encodings. """ p = self.params with tf.name_scope(p.name): src_segment_id = None src_segment_pos = None input_ids = py_utils.with_dependencies([ py_utils.assert_shape_match( tf.shape(input_batch.ids), tf.shape(input_batch.paddings)), py_utils.assert_equal(tf.rank(input_batch.ids), 2) ], input_batch.ids) if (not py_utils.use_tpu() and FLAGS.transformer_encoder_truncates_inputs): max_seq_length = tf.cast( tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)), tf.int32) paddings = py_utils.with_dependencies([ py_utils.assert_equal( tf.constant(True, tf.bool), tf.reduce_all(input_batch.paddings[:, max_seq_length:] > 0.5)) ], input_batch.paddings) input_ids = input_ids[:, :max_seq_length] paddings = paddings[:, :max_seq_length] if p.packed_input: src_segment_id = input_batch.segment_ids[:, :max_seq_length] src_segment_pos = input_batch.segment_pos[:, :max_seq_length] else: paddings = input_batch.paddings if p.packed_input: src_segment_id = input_batch.segment_ids src_segment_pos = input_batch.segment_pos max_time = tf.shape(input_ids)[1] # Input token embeddings + positional embeddings if not p.shared_emb: input_embs = self.token_emb.EmbLookup(theta.token_emb, tf.reshape(input_ids, [-1])) else: input_embs = self.softmax.EmbLookup(theta.softmax, tf.reshape(input_ids, [-1])) input_embs = tf.reshape(input_embs, [-1, max_time, p.token_emb.embedding_dim]) # [time, batch, dim] orig_input_embs = tf.transpose(input_embs, [1, 0, 2]) if p.packed_input: position_embs = self.position_emb.FPropWithPosition( theta.position_emb, src_segment_pos) else: position_embs = self.position_emb.FProp(theta.position_emb, max_time) position_embs = tf.reshape(position_embs, [1, max_time, p.token_emb.embedding_dim]) # Position embeddings are simply added to token embeddings. input_embs += position_embs if p.individually_tagged_input: assert not p.packed_input # Look up tag embeddings; this assumes that the tags arriving on # input_batch.segment_ids (originating as common.source_segment_id # in the input NMTExample) have been reserved in the WPM vocabulary # as context tags, e.g. the ids for <src_token> and <ctxt_token> in # wide source context experiments. 
input_tags = py_utils.with_dependencies([ py_utils.assert_shape_match( tf.shape(input_batch.segment_ids), tf.shape(input_batch.ids)), py_utils.assert_equal(tf.rank(input_batch.segment_ids), 2) ], input_batch.segment_ids) tag_embeddings = self.token_emb.EmbLookup(theta.token_emb, tf.reshape(input_tags, [-1])) tag_embeddings = tf.reshape(tag_embeddings, [-1, max_time, p.token_emb.embedding_dim]) # Concatenate the tag embeddings to the input embeddings, and then # project back to the original embedding dimensionality. concat_embs = tf.concat([input_embs, tag_embeddings], -1) input_embs = self.concat_emb_and_tag_proj.FProp( theta.concat_emb_and_tag_proj, concat_embs) if p.ln_input: input_embs = self.layer_norm_input.FProp(theta.layer_norm_input, input_embs) if p.task_emb: input_embs += self.task_emb.EmbLookup(theta.task_emb, input_batch.task_ids) summary_utils.histogram('input_embs', input_embs) if p.model_dim != p.token_emb.embedding_dim: input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs) summary_utils.histogram('emb_proj', input_embs) paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p)) if p.packed_input: src_segment_id = tf.transpose(src_segment_id) input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs) # [time, batch, dim] transformer_input = tf.transpose(input_embs, [1, 0, 2]) if not self.do_eval and p.apply_source_mask: # Augment padding for masked source word positions. dtype = paddings.dtype source_mask = tf.where( tf.equal(input_ids, p.source_mask_id), tf.ones_like(input_ids, dtype=dtype), tf.zeros_like(input_ids, dtype=dtype)) # Make sure padding is between 0 and 1. paddings = tf.clip_by_value(paddings + tf.transpose(source_mask), 0.0, 1.0) encoded, padding, segment_id = self.transformer_stack.FProp( theta.transformer_stack, transformer_input, paddings, src_segment_id) return py_utils.NestedMap( encoded=encoded, padding=padding, segment_id=segment_id, embedded_inputs=orig_input_embs)
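# The tag-conditioning step above, in isolation (illustrative shapes and
# names): concatenate token and tag embeddings on the feature axis, then
# project back to the embedding width with a learned matrix.
import tensorflow as tf

time, batch, dim = 10, 4, 512
tok_embs = tf.random.normal([batch, time, dim])
tag_embs = tf.random.normal([batch, time, dim])
proj = tf.random.normal([2 * dim, dim])  # Stand-in for the learned projection.
mixed = tf.einsum('btd,dk->btk', tf.concat([tok_embs, tag_embs], -1), proj)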
def FProp(self, theta, inputs, paddings=None): """Apply group normalization. Args: theta: A NestedMap object containing weights' values of this layer and its children layers. inputs: The inputs tensor with shape [batch_size, height, width, channel]. paddings: The paddings tensor with shape [batch_size, height]. Intended to be used for sequence processing where `height` is `time`. Returns: A single tensor as the output after applying group normalization, with the same shape as 'inputs'. Or a output, output_paddings pair if input paddings is not None. """ p = self.params inputs = py_utils.with_dependencies([ py_utils.assert_greater_equal(py_utils.GetRank(inputs), p.input_rank) ], inputs) min_group_size = min(p.min_group_size, p.dim) group_size = max(p.dim // p.num_groups, min_group_size) num_groups = p.dim // group_size input_shape = py_utils.GetShape(inputs) with tf.name_scope(p.name): x = tf.reshape(inputs, input_shape[:-1] + [num_groups, group_size]) expanded_rank = p.input_rank + 1 all_dims = list(range(expanded_rank)) if paddings is None: # Skip d0, d[-2] axes = all_dims[1:-2] + all_dims[-1:] counts, means_ss, variance_ss, _, = tf.nn.sufficient_statistics( x, axes=axes, keepdims=True) norm_mean, norm_variance = tf.nn.normalize_moments( counts, means_ss, variance_ss, None) else: expanded_paddings = tf.reshape( paddings, input_shape[:2] + [1] * (expanded_rank - 2)) # skip the batching and group dim if p.cumulative: # Skip d0, d1 and d[-2] reduce_over_dims = all_dims[2:-2] + all_dims[-1:] norm_mean, norm_variance = ComputeMomentsWithPadding( x, expanded_paddings, reduce_over_dims=reduce_over_dims, cumulative_axis=1, keepdims=True) else: # Skip d0, d[-2] reduce_over_dims = all_dims[1:-2] + all_dims[-1:] norm_mean, norm_variance = ComputeMomentsWithPadding( x, expanded_paddings, reduce_over_dims, keepdims=True) norm_mean = py_utils.CheckNumerics( norm_mean, 'mean of %s failed numeric check' % p.name) norm_variance = py_utils.CheckNumerics( norm_variance, 'variance of %s failed numeric check' % p.name) beta = theta.beta gamma = theta.gamma n = input_shape[0] t = input_shape[1] if p.cumulative else 1 norm_shape = [n, t, 1, num_groups, 1 ] if p.input_rank == 4 else [n, t, num_groups, 1] with tf.control_dependencies([ py_utils.assert_greater_equal( norm_variance, tf.cast(0., norm_variance.dtype)), py_utils.assert_shape_match(norm_shape, tf.shape(norm_mean)), py_utils.assert_shape_match(norm_shape, tf.shape(norm_variance)), ]): x = (x - norm_mean) / tf.sqrt(norm_variance + self._epsilon) x = tf.reshape(x, input_shape) gn_output = x * gamma + beta gn_output = tf.reshape(gn_output, input_shape) if paddings is None: return gn_output else: return gn_output, paddings
def AssertIdShape(expected_ids_shape_pattern, ids_shape, *args):
  """Asserts that ids_shape matches the expected pattern and all args."""
  dependencies = [
      py_utils.assert_shape_match(ids_shape, expected_ids_shape_pattern)
  ] + [py_utils.assert_shape_match(ids_shape, x_shape) for x_shape in args]
  return py_utils.with_dependencies(dependencies, ids_shape)
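# Example use (illustrative; `ids`, `labels` and `weights` are hypothetical
# tensors): tie several id-like tensors to one rank-2 shape.
# ids_shape = AssertIdShape([-1, -1], tf.shape(ids), tf.shape(labels),
#                           tf.shape(weights))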
def FProp(self, theta, input_batch): """Embeds source ids and transforms with TransformerStack. Args: theta: A `.NestedMap` object containing weights' values of this layer and its children layers. input_batch: A `.NestedMap` with fields: - ids: The inputs tensor. It is expected to be of shape [batch, time]. - paddings: The paddings tensor. Expected shape [batch, time]. Returns: A NestedMap containing: - encoded: The encoded features, either a tensor of shape [time, batch, depth], or a list of tensors if is_transparent is set in transformer_stack. - padding: of shape [time, batch] - segment_id: [time, batch] if packed inputs are supported by the model (and all layers), or None otherwise. - embedded_inputs: [time, batch, depth] embedded inputs tokens without positional encodings. """ p = self.params with tf.name_scope(p.name): src_segment_id = None src_segment_pos = None input_ids = py_utils.with_dependencies([ py_utils.assert_shape_match(tf.shape(input_batch.ids), tf.shape(input_batch.paddings)), py_utils.assert_equal(tf.rank(input_batch.ids), 2) ], input_batch.ids) if (not py_utils.use_tpu() and tf.flags.FLAGS.transformer_encoder_truncates_inputs): max_seq_length = tf.cast( tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)), tf.int32) paddings = py_utils.with_dependencies([ py_utils.assert_equal( tf.constant(True, tf.bool), tf.reduce_all( input_batch.paddings[:, max_seq_length:] > 0.5)) ], input_batch.paddings) input_ids = input_ids[:, :max_seq_length] paddings = paddings[:, :max_seq_length] if p.packed_input: src_segment_id = input_batch.segment_ids[:, : max_seq_length] src_segment_pos = input_batch.segment_pos[:, : max_seq_length] else: paddings = input_batch.paddings if p.packed_input: src_segment_id = input_batch.segment_ids src_segment_pos = input_batch.segment_pos max_time = tf.shape(input_ids)[1] # Input token embeddings + positional embeddings input_embs = self.token_emb.EmbLookup(theta.token_emb, tf.reshape(input_ids, [-1])) input_embs = tf.reshape(input_embs, [-1, max_time, p.token_emb.embedding_dim]) # [time, batch, dim] orig_input_embs = tf.transpose(input_embs, [1, 0, 2]) if p.packed_input: position_embs = self.position_emb.FPropWithPosition( theta.position_emb, src_segment_pos) else: position_embs = self.position_emb.FProp( theta.position_emb, max_time) position_embs = tf.reshape( position_embs, [1, max_time, p.token_emb.embedding_dim]) input_embs += position_embs if p.model_dim != p.token_emb.embedding_dim: input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs) paddings = tf.transpose(paddings) if p.packed_input: src_segment_id = tf.transpose(src_segment_id) input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs) # [time, batch, dim] transformer_input = tf.transpose(input_embs, [1, 0, 2]) encoded, padding, segment_id = self.transformer_stack.FProp( theta.transformer_stack, transformer_input, paddings, src_segment_id) return py_utils.NestedMap(encoded=encoded, padding=padding, segment_id=segment_id, embedded_inputs=orig_input_embs)
def FProp(self, theta, input_batch, interpolation_batch=None, lambdas=None):
  # pyformat: disable
  """Interpolates source ids in input_batch and interpolation_batch.

  Refer to Eq. (4) in paper https://arxiv.org/abs/2106.04060.
  It is a standard Transformer Encoder if interpolation_batch is None.

  Args:
    theta: A `.NestedMap` object containing weights values of this layer and
      its children layers.
    input_batch: A `.NestedMap` with fields:

      - ids: The inputs tensor. It is expected to be of shape [batch, time].
      - paddings: The paddings tensor. Expected shape [batch, time].
      - task_ids: If p.task_emb is provided, must contain per-token task ids
        of shape [batch, time].
    interpolation_batch: A `.NestedMap` with fields:

      - ids: The inputs tensor. It is expected to be of shape [batch, time].
      - paddings: The paddings tensor. Expected shape [batch, time].
      - task_ids: If p.task_emb is provided, must contain per-token task ids
        of shape [batch, time].
      - embs: Embeddings of ids.
    lambdas: A pair of tensors to combine embeddings of ids in input_batch
      and interpolation_batch.

  Returns:
    A NestedMap of

    - encoded: The encoded features, either a tensor of shape
      [time, batch, depth], or a list of tensors if is_transparent is set in
      transformer_stack.
    - padding: of shape [time, batch]
    - segment_id: [time, batch] if packed inputs are supported by the model
      (and all layers), or None otherwise.
    - embedded_inputs: [time, batch, depth] embedded inputs tokens without
      positional encodings.
  """
  # pyformat: enable
  p = self.params
  with tf.name_scope(p.name):
    src_segment_id = None
    src_segment_pos = None
    input_ids = py_utils.with_dependencies([
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings)),
        py_utils.assert_equal(tf.rank(input_batch.ids), 2)
    ], input_batch.ids)

    max_seq_length = None
    if (not py_utils.use_tpu() and
        FLAGS.transformer_encoder_truncates_inputs):
      max_seq_length = tf.cast(
          tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)),
          tf.int32)
      paddings = py_utils.with_dependencies([
          py_utils.assert_equal(
              tf.constant(True, tf.bool),
              tf.reduce_all(input_batch.paddings[:, max_seq_length:] > 0.5))
      ], input_batch.paddings)
      input_ids = input_ids[:, :max_seq_length]
      paddings = paddings[:, :max_seq_length]
      if p.packed_input:
        src_segment_id = input_batch.segment_ids[:, :max_seq_length]
        src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
    else:
      paddings = input_batch.paddings
      if p.packed_input:
        src_segment_id = input_batch.segment_ids
        src_segment_pos = input_batch.segment_pos

    max_time = tf.shape(input_ids)[1]

    # Input token embeddings + positional embeddings.
    if not p.shared_emb:
      input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                            tf.reshape(input_ids, [-1]))
    else:
      input_embs = self.softmax.EmbLookup(theta.softmax,
                                          tf.reshape(input_ids, [-1]))

    if interpolation_batch is not None:
      other_input_ids = interpolation_batch.ids
      if not p.shared_emb:
        other_input_embs = self.token_emb.EmbLookup(
            theta.token_emb, tf.reshape(other_input_ids, [-1]))
      else:
        other_input_embs = self.softmax.EmbLookup(
            theta.softmax, tf.reshape(other_input_ids, [-1]))
      lambdas = [tf.expand_dims(a, -1) for a in lambdas]
      if 'embs' in input_batch and input_batch.embs is not None:
        input_embs = input_batch.embs
      if ('embs' in interpolation_batch and
          interpolation_batch.embs is not None):
        other_input_embs = interpolation_batch.embs
      else:
        input_embs = tf.reshape(
            input_embs,
            [-1, tf.shape(input_ids)[1], p.token_emb.embedding_dim])
        other_input_embs = tf.reshape(
            other_input_embs,
            [-1, tf.shape(other_input_ids)[1], p.token_emb.embedding_dim])
      input_embs = lambdas[0] * input_embs + lambdas[1] * other_input_embs
      paddings = paddings + interpolation_batch.paddings - 1.0
      paddings = tf.clip_by_value(paddings, 0.0, 1.0)

    input_embs = tf.reshape(input_embs,
                            [-1, max_time, p.token_emb.embedding_dim])
    orig_input_embs = input_embs

    if p.task_emb:
      if interpolation_batch is None:
        input_embs += self.task_emb.EmbLookup(theta.task_emb,
                                              input_batch.task_ids)
      else:
        task_embs = self.task_emb.EmbLookup(theta.task_emb,
                                            input_batch.task_ids)
        other_task_embs = self.task_emb.EmbLookup(
            theta.task_emb, interpolation_batch.task_ids)
        task_embs = lambdas[0] * task_embs + lambdas[1] * other_task_embs
        input_embs += task_embs

    if p.packed_input:
      position_embs = self.position_emb.FPropWithPosition(
          theta.position_emb, src_segment_pos)
    else:
      position_embs = self.position_emb.FProp(theta.position_emb, max_time)
      position_embs = tf.reshape(position_embs,
                                 [1, max_time, p.token_emb.embedding_dim])
    input_embs += position_embs

    if p.model_dim != p.token_emb.embedding_dim:
      input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

    paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
    if p.packed_input:
      src_segment_id = tf.transpose(src_segment_id)
    input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)
    # [time, batch, dim]
    transformer_input = tf.transpose(input_embs, [1, 0, 2])

    if not self.do_eval and p.apply_source_mask:
      # Augment padding for masked source word positions.
      dtype = paddings.dtype
      source_mask = tf.where(
          tf.equal(input_ids, p.source_mask_id),
          tf.ones_like(input_ids, dtype=dtype),
          tf.zeros_like(input_ids, dtype=dtype))
      # Make sure padding is between 0 and 1.
      paddings = tf.clip_by_value(paddings + tf.transpose(source_mask), 0.0,
                                  1.0)

    encoded, padding, segment_id = self.transformer_stack.FProp(
        theta.transformer_stack, transformer_input, paddings, src_segment_id)
    return py_utils.NestedMap(
        encoded=encoded,
        padding=padding,
        segment_id=segment_id,
        embedded_inputs=orig_input_embs)
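# The interpolation of Eq. (4) in https://arxiv.org/abs/2106.04060, stripped
# down (illustrative): a convex combination of the two embedding tensors, with
# the combined paddings marking a position as padded only where BOTH inputs
# are padded (pad_a + pad_b - 1 clipped to [0, 1]).
import tensorflow as tf

lam = 0.3
emb_a = tf.random.normal([4, 10, 512])
emb_b = tf.random.normal([4, 10, 512])
pad_a = tf.zeros([4, 10])
pad_b = tf.ones([4, 10])
mixed = lam * emb_a + (1.0 - lam) * emb_b
mixed_pad = tf.clip_by_value(pad_a + pad_b - 1.0, 0.0, 1.0)  # All zeros here.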