Exemple #1
0
    def FProp(self, theta, inputs, paddings=None):
        """Applies batch normalization to `inputs`.

        Args:
          theta: A `.NestedMap` object containing weights' values of this layer
            and its children layers.
          inputs: The inputs tensor.  Shaped [..., dim].
          paddings: The paddings tensor.  Shaped [..., 1], with the same rank as
            the input tensor. When None, `self._GetDefaultPaddings(inputs)` is
            used.

        Returns:
          Output after applying batch normalization, with the same shape as
          'inputs', multiplied by `1.0 - paddings`.
        """
        p = self.params
        if paddings is None:
            paddings = self._GetDefaultPaddings(inputs)
        with tf.name_scope(p.name):
            mean, variance, beta, gamma = self.ComputeAndUpdateMoments(
                theta, inputs, paddings)
            last_dim = [tf.shape(inputs)[-1]]
            # Sanity checks: the variance is non-negative and both moments are
            # sized like the last input dimension.
            checks = [
                py_utils.assert_greater_equal(variance,
                                              tf.zeros_like(variance)),
                py_utils.assert_shape_match(last_dim, tf.shape(mean)),
                py_utils.assert_shape_match(last_dim, tf.shape(variance)),
            ]
            with tf.control_dependencies(checks):
                normalized = tf.nn.batch_normalization(inputs, mean, variance,
                                                       beta, gamma,
                                                       self._epsilon)
            # Scale every position by (1 - paddings); positions whose padding
            # is 1 become zero.
            return normalized * (1.0 - paddings)
Exemple #2
0
 def _ComputeBN(self, inputs, paddings, gamma, beta, norm_mean,
                norm_variance):
     """Normalizes `inputs` with the given moments, offset and scale.

     Args:
       inputs: The inputs tensor. Its last dimension is asserted to match the
         shapes of `norm_mean` and `norm_variance`.
       paddings: The paddings tensor. Only used when
         `p.set_padded_output_to_zero` is set, via `py_utils.ApplyPadding`.
       gamma: The scale passed to the normalization op.
       beta: The offset passed to the normalization op.
       norm_mean: The mean to normalize with.
       norm_variance: The variance to normalize with; asserted to be
         non-negative.

     Returns:
       The normalized tensor.
     """
     p = self.params
     channel_dim = [tf.shape(inputs)[-1]]
     sanity_checks = [
         py_utils.assert_greater_equal(norm_variance,
                                       tf.zeros_like(norm_variance)),
         py_utils.assert_shape_match(channel_dim, tf.shape(norm_mean)),
         py_utils.assert_shape_match(channel_dim, tf.shape(norm_variance)),
     ]
     with tf.control_dependencies(sanity_checks):
         # The fused kernel is used only when explicitly enabled AND either in
         # eval mode or with p.freeze_bn_stats set.
         use_fused = p.use_fused_batch_norm_for_eval and (self.do_eval or
                                                          p.freeze_bn_stats)
         if not use_fused:
             bn_output = tf.nn.batch_normalization(inputs, norm_mean,
                                                   norm_variance, beta, gamma,
                                                   self._epsilon)
         else:
             # is_training=False: the fused op normalizes with the supplied
             # mean/variance instead of computing batch statistics.
             bn_output, _, _ = nn.fused_batch_norm(inputs,
                                                   gamma,
                                                   beta,
                                                   norm_mean,
                                                   norm_variance,
                                                   self._epsilon,
                                                   is_training=False)
         if p.set_padded_output_to_zero:
             bn_output = py_utils.ApplyPadding(paddings, bn_output)
     return bn_output
Exemple #3
0
  def FProp(self, theta, input_batch):
    """Embeds `input_batch.ids` and runs them through the RNN layers.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      input_batch: A `.NestedMap` with fields `ids` and `paddings`, asserted to
        be rank 2 and of the same shape. Both are transposed before use.

    Returns:
      A `.NestedMap` with `encoded` (the last layer output, or the transparent
      merge of all layer outputs when `p.is_transparent`), `padding` (the
      transposed paddings) and `segment_id` (always None here).
    """
    p = self.params
    with tf.name_scope(p.name):
      ids = input_batch.ids
      ids_checks = [
          py_utils.assert_shape_match(tf.shape(ids), [-1, -1]),
          py_utils.assert_shape_match(
              tf.shape(ids), tf.shape(input_batch.paddings)),
      ]
      inputs = py_utils.with_dependencies(ids_checks, tf.transpose(ids))
      ps = tf.expand_dims(tf.transpose(input_batch.paddings), 2)

      xs = self.ApplyClipping(theta, self.emb.EmbLookup(theta.emb, inputs))
      summary_utils.histogram('input_emb', xs)
      xs = self.dropout.FProp(theta.dropout, xs)

      # Now the rnn layers. Layers with index >= p.residual_start get a
      # residual connection followed by clipping.
      per_layer = []
      for i in range(p.num_lstm_layers):
        ys, _ = self.rnn[i].FProp(theta.rnn[i], xs, ps)
        ys = self.dropout.FProp(theta.dropout, ys)
        if i < p.residual_start:
          xs = ys
        else:
          xs = self.ApplyClipping(theta, xs + ys)  # Residual skip
        per_layer.append(xs)
        summary_utils.histogram('layer_out_%s' % i, xs)

      if p.is_transparent:
        xs = self.transparent_merger.FProp(theta.transparent_merger, per_layer)

      return py_utils.NestedMap(
          encoded=xs, padding=tf.squeeze(ps, [2]), segment_id=None)
Exemple #4
0
    def FProp(self, theta, inputs, paddings=None):
        """Applies group normalization.

        Args:
          theta: A NestedMap object containing weights' values of this layer
            and its children layers.
          inputs: The inputs tensor with shape [batch_size, height, width,
            channel].
          paddings: The paddings tensor with shape [batch_size, height].
            Intended to be used for sequence processing where `height` is
            `time`. When given, moments are computed with
            `ComputeMomentsWithPadding`.

        Returns:
          A single tensor as the output after applying group normalization,
          with the same shape as 'inputs'. Or an (output, paddings) pair, with
          `paddings` returned unchanged, if input paddings is not None.
        """
        p = self.params
        n, h, w, c = tf.unstack(tf.shape(inputs), axis=0, num=4)

        # Channels per group; never smaller than min(p.min_group_size, p.dim).
        smallest_group = min(p.min_group_size, p.dim)
        group_size = p.dim // p.num_groups
        num_groups = p.num_groups
        if group_size <= smallest_group:
            group_size = smallest_group
            num_groups = p.dim // smallest_group
        stat_axes = [1, 2, 4]

        with tf.name_scope(p.name):
            grouped = tf.reshape(inputs, [n, h, w, num_groups, group_size])
            if paddings is not None:
                mean, variance = ComputeMomentsWithPadding(
                    grouped, tf.reshape(paddings, [n, h, 1, 1, 1]), stat_axes,
                    keepdims=True)
            else:
                counts, mean_ss, var_ss, _ = tf.nn.sufficient_statistics(
                    grouped, axes=stat_axes, keepdims=True)
                mean, variance = tf.nn.normalize_moments(
                    counts, mean_ss, var_ss, None)

            mean = py_utils.CheckNumerics(
                mean, 'mean of %s failed numeric check' % p.name)
            variance = py_utils.CheckNumerics(
                variance, 'variance of %s failed numeric check' % p.name)

            stats_shape = [n, 1, 1, num_groups, 1]
            with tf.control_dependencies([
                    py_utils.assert_greater_equal(
                        variance, tf.cast(0., variance.dtype)),
                    py_utils.assert_shape_match(stats_shape, tf.shape(mean)),
                    py_utils.assert_shape_match(stats_shape,
                                                tf.shape(variance)),
            ]):
                normalized = (grouped - mean) / tf.sqrt(variance +
                                                        self._epsilon)
                normalized = tf.reshape(normalized, [n, h, w, c])
                gn_output = tf.reshape(normalized * theta.gamma + theta.beta,
                                       [n, h, w, c])
                return gn_output if paddings is None else (gn_output,
                                                           paddings)
Exemple #5
0
    def FProp(self, theta, input_batch):
        """Encodes source as represented by `ids` and `paddings`.

        Args:
          theta: A `.NestedMap` object containing weights' values of this layer
            and its children layers.
          input_batch: A `.NestedMap` with fields:
            - ids: The inputs tensor. It is expected to be of shape
              [batch, time].
            - paddings: The paddings tensor. Expected shape [batch, time].

        Returns:
          A NestedMap containing:

          - encoded: The encoded features, a tensor of shape
            [time, batch, depth]
          - padding: of shape [time, batch]
          - segment_id: always None in this implementation.
        """
        p = self.params
        with tf.name_scope(p.name):
            ids = input_batch.ids
            inputs = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(ids), [-1, -1]),
                py_utils.assert_shape_match(tf.shape(ids),
                                            tf.shape(input_batch.paddings)),
            ], tf.transpose(ids))
            ps = tf.expand_dims(tf.transpose(input_batch.paddings), 2)
            xs = self.ApplyClipping(theta,
                                    self.emb.EmbLookup(theta.emb, inputs))
            # The clipped embedding output is kept on the layer instance.
            self._emb_out = xs

            # Now the rnn layers. When cc_schedule is specified, make sure
            # lstm_tpl is QuantizedLSTMCell with the same cc_schedule so that
            # the RNN layer output is within clipping range.
            # The first layer's FProp result is used directly, while the
            # remaining layers return a pair whose first element is used.
            xs = self.dropout.FProp(theta.dropout,
                                    self.rnn[0].FProp(theta.rnn[0], xs, ps))
            for i in range(1, p.num_lstm_layers):
                layer = self.rnn[i]
                ys, _ = layer.FProp(theta.rnn[i], xs, ps)
                ys = self.dropout.FProp(theta.dropout, ys)
                cell_params = getattr(layer.params, 'cell', layer.params)
                if cell_params.num_input_nodes != cell_params.num_output_nodes:
                    # No residual possible when the depth changes.
                    xs = ys
                else:
                    # Residual skip, then clip back into range.
                    xs = self.ApplyClipping(theta, xs + ys)
            return py_utils.NestedMap(encoded=xs,
                                      padding=tf.squeeze(ps, [2]),
                                      segment_id=None)
Exemple #6
0
    def FProp(self, theta, inputs, paddings):
        """Apply convolution to inputs.

        Args:
          theta: A `.NestedMap` object containing weights' values of this layer
            and its children layers.
          inputs: The inputs tensor. It is expected to be of shape [batch, time,
            frequency, channel]. The time dimension corresponds to the height
            dimension as in images and the frequency dimension corresponds to
            the width dimension as in images.
          paddings: The paddings tensor, expected to be of shape [batch, time].

        Returns:
          outputs, out_paddings pair. `out_paddings` comes from
          `ComputeConvOutputPadding`; padded output positions are not zeroed
          here.
        """
        p = self.params
        with tf.name_scope(p.name):
            inputs = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(paddings), [-1, -1]),
                py_utils.assert_shape_match(
                    tf.shape(inputs),
                    tf.concat([
                        tf.shape(paddings),
                        [-1, symbolic.ToStatic(self.input_channels)]
                    ], 0)),
            ], inputs)

            # Zero out padded inputs: [batch, time] paddings are expanded to
            # [batch, time, 1, 1] so they broadcast over the last two axes.
            mask = 1.0 - tf.expand_dims(tf.expand_dims(paddings, -1), -1)
            inputs = inputs * mask

            out = self._ApplyConv(theta, inputs)
            if p.partial_conv:
                out = self._RescaleBoundary(out, paddings)

            # NOTE: this may be slightly inaccurate when p.dilation_rate[0] > 1.
            # But there's likely no real problems. Trying to set it gives an
            # error: pooling with SAME padding is not implemented for
            # dilation_rate > 1.
            # NOTE: we use window=p.filter_stride[0] to be compatible with
            # legacy implementation. Consider updating it to be the actual
            # shape.
            stride = p.filter_stride[0]
            conv_padding = ComputeConvOutputPadding(paddings,
                                                    window=stride,
                                                    stride=stride)
            # Padded output nodes are left as-is; subsequent layers are assumed
            # to zero them out if necessary.
            out = py_utils.HasShape(
                out, symbolic.ToStatic(self.OutShape(tf.shape(inputs))))
            return out, conv_padding
Exemple #7
0
    def FProp(self, theta, input_batch):
        """Encodes `input_batch` with the RNN layer stack.

        Args:
          theta: A `.NestedMap` object containing weights' values of this layer
            and its children layers.
          input_batch: A `.NestedMap` with `ids` and `paddings` (asserted to be
            rank 2 and of the same shape) and, when `p.packed_input` is set,
            `segment_ids`. All are transposed before use.

        Returns:
          A `.NestedMap` with `encoded`, `padding` (transposed paddings) and
          `segment_id` (transposed segment ids, or None unless
          `p.packed_input`).
        """
        p = self.params
        with tf.name_scope(p.name):
            ids = input_batch.ids
            inputs = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(ids), [-1, -1]),
                py_utils.assert_shape_match(tf.shape(ids),
                                            tf.shape(input_batch.paddings)),
            ], tf.transpose(ids))
            ps = tf.expand_dims(tf.transpose(input_batch.paddings), 2)
            segment_id = None
            if p.packed_input:
                segment_id = tf.expand_dims(
                    tf.transpose(input_batch.segment_ids), 2)

            xs = self._ComputeInputs(theta, inputs, input_batch)
            summary_utils.histogram('input_emb', xs)

            # Now the rnn layers. Layers with index >= p.residual_start get a
            # residual connection followed by clipping.
            per_layer = []
            for i in range(p.num_lstm_layers):
                ys = self.rnn[i].FProp(theta.rnn[i],
                                       xs,
                                       ps,
                                       segment_id=segment_id)
                ys = self.dropout.FProp(theta.dropout, ys)
                if i < p.residual_start:
                    xs = ys
                else:
                    xs = self.ApplyClipping(theta, xs + ys)  # Residual skip
                per_layer.append(xs)
                summary_utils.histogram('layer_out_%s' % i, xs)

            if p.is_transparent:
                xs = self.transparent_merger.FProp(theta.transparent_merger,
                                                   per_layer)

            if p.lstm_cell_size * 2 != p.encoder_out_dim:
                # Project to the right depth.
                xs = self.final_proj.FProp(theta.final_proj, xs, ps)
                summary_utils.histogram('final_proj_out', xs)

            out_segment_id = (None if segment_id is None else tf.squeeze(
                segment_id, [2]))
            return py_utils.NestedMap(encoded=xs,
                                      padding=tf.squeeze(ps, [2]),
                                      segment_id=out_segment_id)
Exemple #8
0
  def FProp(self, theta, inputs):
    """Applies batch normalization.

    Using the implementation in github.com/
    tensorflow/tpu/blob/master/models/official/amoeba_net/network_utils.py#L550

    In eval mode (`p.is_eval`) the moving statistics in `theta` are used;
    otherwise the moments come from `self._Moments` with `p.bn_group_size`.
    Computation happens in `p.dtype` and the result is cast back to the
    original dtype of `inputs`.

    Args:
      theta: A nested map object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor.  Shaped [..., dim].

    Returns:
      Output after applying batch normalization, with the same shape as
      'inputs'.
    """
    p = self.params
    orig_dtype = inputs.dtype
    inputs = tf.cast(inputs, p.dtype)
    dim_check = py_utils.assert_shape_match([tf.shape(inputs)[-1]], [p.dim])
    inputs = py_utils.with_dependencies([dim_check], inputs)
    with tf.name_scope(p.name) as scope:
      if not p.is_eval:
        mean, variance = self._Moments(inputs, p.bn_group_size)
        mean = py_utils.CheckNumerics(
            mean, 'mean of {} failed numeric check'.format(scope))
        variance = py_utils.CheckNumerics(
            variance, 'variance of {} failed numeric check'.format(scope))
      else:
        mean, variance = theta.moving_mean, theta.moving_variance
      outputs = tf.nn.batch_normalization(inputs, mean, variance, theta.beta,
                                          theta.gamma, p.epsilon)
      outputs.set_shape(inputs.get_shape())
      return tf.cast(outputs, orig_dtype)
    def StreamStep(self, theta, inputs, paddings, state0):
        """Apply a single step of convolution to input_tensor.

        Only supports 1d causal convolution. Doesn't support striding or
        dilation.

        Args:
          theta: A NestedMap of layer params. These are the weights used for the
            convolution.
          inputs: A Tensor of shape [b, t=1, 1, c]
          paddings: A 0/1 valued tensor of shape [b, t=1].
          state0: A NestedMap of tensors of the same struct as returned by
            zero_state().

        Returns:
          outputs: A Tensor of shape [b, t=1, 1, c * channel_multiplier]
          padding: the same as input paddings.
          state1: A NestedMap of the same struct as input state
        """
        p = self.params
        assert p.filter_shape[1] == 1, (
            'StreamStep only supports 1d causal convolution.')
        assert p.filter_stride[0] == 1, (
            'StreamStep doesn\'t support striding')
        assert p.dilation_rate == (1,
                                   1), ('StreamStep doesn\'t support dilation')

        with tf.name_scope(p.name):
            inputs = py_utils.with_dependencies([
                py_utils.assert_shape_match(py_utils.GetShape(inputs),
                                            [-1, 1, 1, p.filter_shape[2]])
            ], inputs)
            b = py_utils.GetShape(inputs)[0]

            # Next state: drop the oldest frame of the context and append the
            # new input along the time axis.
            state1 = py_utils.NestedMap(context=tf.concat(
                [state0.context[:, 1:, :, :], inputs], axis=1))

            expanded_paddings = tf.reshape(paddings, [b, 1, 1, 1])
            # Not updating the states for padded examples: where paddings is 1
            # the previous context is kept, where it is 0 the new one is used.
            state1.context = (state0.context * expanded_paddings +
                              state1.context * (1. - expanded_paddings))

            # Use the `theta` passed by the caller (not `self.theta`), so that
            # externally supplied weights are honored like in the rest of the
            # layer API.
            outputs = tf.nn.depthwise_conv2d(state1.context,
                                             self._GetWeight(theta),
                                             strides=(1, 1, 1, 1),
                                             dilations=(1, 1),
                                             data_format='NHWC',
                                             padding='VALID')
            return outputs, paddings, state1
Exemple #10
0
    def ComputeAndUpdateMoments(self, theta, inputs, paddings=None):
        """Computes the normalization moments and updates state.

        Args:
          theta: A `.NestedMap` object containing weights' values of this layer
            and its children layers.
          inputs: The inputs tensor.  Shaped [..., dim].
          paddings: The paddings tensor.  Shaped [..., 1], with the same rank as
            the input tensor. Defaults to `self._GetDefaultPaddings(inputs)`.

        Returns:
          Tuple of (mean, variance, beta, gamma). When `self.do_eval`, mean and
          variance are the moving averages. Otherwise batch moments are
          computed, the moving averages are updated from them, and the
          returned moments are the moving averages if
          `p.use_moving_avg_in_training`, else the batch moments. beta and
          gamma are the constants 0.0 and 1.0 if
          `p.use_moving_avg_in_training`, else `theta.beta` / `theta.gamma`.
        """
        p = self.params
        if paddings is None:
            paddings = self._GetDefaultPaddings(inputs)
        # The last dimension of paddings must be 1.
        inputs = py_utils.with_dependencies(
            [py_utils.assert_shape_match([tf.shape(paddings)[-1]], [1])],
            inputs)
        with tf.name_scope(p.name):
            if self.do_eval:
                # The mean and variance used for normalization.
                mean, variance = self._moving_mean, self._moving_variance
            else:
                # `1.0 - paddings` is handed to _Moments, presumably as a
                # per-position weight that excludes padded positions.
                batch_mean, batch_variance = self._Moments(
                    inputs, 1.0 - paddings, p.enable_cross_replica_sum_on_tpu)

                py_utils.UpdateBatchNormVars(self._moving_mean, batch_mean,
                                             self._decay)
                py_utils.UpdateBatchNormVars(self._moving_variance,
                                             batch_variance, self._decay)

                # Add some summaries for visualization.
                histograms = (
                    ('%s_mean', batch_mean),
                    ('%s_variance', batch_variance),
                    ('%s_moving_mean', self._moving_mean),
                    ('%s_moving_variance', self._moving_variance),
                    ('%s_mean_diff', batch_mean - self._moving_mean),
                    ('%s_variance_diff',
                     batch_variance - self._moving_variance),
                )
                for tag_fmt, value in histograms:
                    summary_utils.histogram(tag_fmt % p.name,
                                            tf.cast(value, tf.float32))

                if p.use_moving_avg_in_training:
                    # Use the global statistics for normalization.
                    # Control dependencies on mean and variance make sure
                    # moving_mean and variance will be updated for every
                    # training step.
                    mean = py_utils.with_dependencies([batch_mean],
                                                      self._moving_mean)
                    variance = py_utils.with_dependencies(
                        [batch_variance], self._moving_variance)
                else:
                    # Use the batch statistics for normalization.
                    mean, variance = batch_mean, batch_variance

            mean = py_utils.CheckNumerics(
                mean, 'mean of %s failed numeric check' % p.name)
            variance = py_utils.CheckNumerics(
                variance, 'variance of %s failed numeric check' % p.name)

            if p.use_moving_avg_in_training:
                return mean, variance, 0.0, 1.0
            return mean, variance, theta.beta, theta.gamma
Exemple #11
0
    def FProp(self, theta, inputs, query_vec=None):
        """Combines the list of input tensors into a single tensor.

        Args:
          theta: A `.NestedMap` object containing weights' values of this
            layer and its children layers.
          inputs: A list of tensors of shape [..., hidden_dim] or
            [..., pre_proj_input_dims[i]] if pre_proj_input_dims is specified.
            The caller's list is not modified.
          query_vec: A tensor of shape [..., hidden_dim]. Required when
            `p.merger_op` is 'atten'.

        Returns:
          A tensor of the same shape with input tensors (for 'concat', inputs
          are concatenated along the last dimension). With a single input,
          that input is returned as-is, without pre-projection.

        Raises:
          ValueError: If pre_proj_input_dims does not have one entry per input,
            if `query_vec` is None for the 'atten' merger, or if p.merger_op is
            not defined.
        """
        p = self.params
        n_sources = len(inputs)

        if p.pre_proj_input_dims and len(p.pre_proj_input_dims) != n_sources:
            raise ValueError(
                'pre_proj_input_dims must be specified for each input.')

        if n_sources == 1:
            return inputs[0]

        # Pre-projection operation. A new list is built so the caller's
        # `inputs` list is not mutated in place.
        if p.pre_proj_input_dims:
            inputs = [
                self.pre_proj[i].FProp(theta.pre_proj[i], x)
                for i, x in enumerate(inputs)
            ]

        # Adjacent pairs of (possibly projected) sources, captured before any
        # branch below rebinds `inputs` to a stacked tensor.
        tensor_pairs = list(zip(inputs[:-1], inputs[1:]))

        def _AssertSameShapes():
            # One shape-equality assertion per adjacent pair of sources.
            return [
                py_utils.assert_shape_match(tf.shape(t1), tf.shape(t2))
                for t1, t2 in tensor_pairs
            ]

        if p.merger_op == 'mean':
            # Simply take the mean, all dims must match.
            with tf.control_dependencies(_AssertSameShapes()):
                output = tf.add_n(inputs) / n_sources

        elif p.merger_op == 'sum':
            # Sum up all sources, all dims must match.
            with tf.control_dependencies(_AssertSameShapes()):
                output = tf.add_n(inputs)

        elif p.merger_op == 'weighted_sum':
            # Weighted sum of all sources, all dims must match.
            # For weighted_sum, assume input is a list of rank 3 tensors, so
            # the stacked tensor is rank 4.
            inputs = tf.stack(inputs)
            inputs = py_utils.HasRank(inputs, 4)

            with tf.control_dependencies(_AssertSameShapes()):
                # Broadcast the per-source weights over the remaining dims.
                w = tf.expand_dims(
                    tf.expand_dims(tf.expand_dims(self._sum_weight, 1), 1), 1)
                w = tf.tile(w, [
                    1,
                    tf.shape(inputs)[1],
                    tf.shape(inputs)[2],
                    tf.shape(inputs)[3]
                ])
                output = tf.reduce_sum(inputs * w, axis=0)

        elif p.merger_op == 'atten':
            # Fail fast with a clear message instead of inside tf.reshape,
            # and before the attention layer's source state is initialized.
            if query_vec is None:
                raise ValueError('query_vec is required for merger_op atten.')
            # Apply attention over the concatenated tensor, all dims must match.
            with tf.control_dependencies(_AssertSameShapes()):
                inputs = tf.stack(inputs, axis=0)
                batch_size = tf.shape(inputs)[1]
                paddings = tf.zeros([n_sources, batch_size],
                                    dtype=inputs.dtype)
                self.atten.InitForSourcePacked(theta.atten, inputs, inputs,
                                               paddings)
                output, _, _ = self.atten.ComputeContextVector(
                    theta.atten, tf.reshape(query_vec, [-1, p.query_dim]))

        elif p.merger_op == 'concat':
            # Concatenate over the last dim, all dims but last must match.
            with tf.control_dependencies([
                    py_utils.assert_equal(
                        tf.shape(t1)[:-1],
                        tf.shape(t2)[:-1]) for t1, t2 in tensor_pairs
            ]):
                output = tf.concat(inputs, axis=-1)

        elif p.merger_op == 'gated_avg':
            output = self.gated_average.FProp(theta.gated_average, inputs)

        else:
            raise ValueError('Unrecognized merge op!')

        return output
Exemple #12
0
    def FProp(self,
              theta,
              query_vec,
              source_paddings,
              source_vecs=None,
              query_segment_id=None,
              source_segment_id=None):
        """Transformer attention, residual and normalization layer.

        The query is layer-normalized before attention. When `source_vecs` is
        None, the normalized query (and `query_segment_id`) is used as the
        source, i.e. self-attention. With `p.is_masked`, a causal mask keeps
        each target step from attending to later steps.

        Args:
          theta: A `.NestedMap` object containing weights' values of this
            layer and its children layers.
          query_vec: [target_time, target_batch, dim]
          source_paddings: [source_time, source_batch]
          source_vecs: [source_time, source_batch, dim].
          query_segment_id: [target_time, target_batch]
          source_segment_id: [source_time, source_batch]

        Returns:
          (output, atten_probs). output is of shape [target_time, target_batch,
          source_dim], atten_probs is of shape [target_time, target_batch,
          source_time].
        """
        p = self.params
        raw_query = query_vec
        query_vec = self.layer_norm.FProp(theta.layer_norm, query_vec)

        if source_vecs is None:
            source_vecs, source_segment_id = query_vec, query_segment_id

        causal_padding = None
        if p.is_masked:
            assert source_vecs is not None
            query_vec = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(source_vecs),
                                            tf.shape(query_vec))
            ], query_vec)
            target_time = tf.shape(query_vec)[0]
            target_bs = tf.shape(query_vec)[1]
            # [time, time]: 1.0 strictly above the diagonal, 0.0 elsewhere.
            triangle_padding = 1.0 - tf.matrix_band_part(
                tf.ones([target_time, target_time],
                        dtype=py_utils.FPropDtype(p)), -1, 0)
            # Tile to [time, batch, time], then flatten to
            # [time * batch, time] to match the flattened queries below.
            tiled = tf.tile(tf.expand_dims(triangle_padding, 1),
                            [1, target_bs, 1])
            causal_padding = tf.reshape(tiled, [-1, target_time])

        query_dim = tf.shape(query_vec)[-1]
        packed_src = self.atten.PackSource(theta.atten, source_vecs,
                                           source_vecs, source_paddings,
                                           source_segment_id)

        if query_segment_id is not None:
            query_segment_id = tf.reshape(query_segment_id, [-1])
        ctx_vec, atten_prob, _ = self.atten.ComputeContextVectorWithSource(
            theta.atten,
            packed_src,
            tf.reshape(query_vec, [-1, query_dim]),
            per_step_source_padding=causal_padding,
            query_segment_id=query_segment_id)
        ctx_vec = self.residual_dropout.FProp(theta.residual_dropout, ctx_vec)

        residual = raw_query if p.add_unnormalized_input else query_vec
        query_shape = tf.shape(query_vec)
        h = residual + tf.reshape(ctx_vec, query_shape)
        atten_prob = tf.reshape(
            atten_prob,
            [query_shape[0], query_shape[1],
             tf.shape(source_vecs)[0]])
        return h, atten_prob
Exemple #13
0
  def FProp(self, theta, input_batch):
    """Embeds source ids and transforms with TransformerStack.

    Args:
      theta: A `.NestedMap` object containing weights' values of this
        layer and its children layers.
      input_batch: A `.NestedMap` with fields:

        - ids: The inputs tensor. It is expected to be of shape [batch, time].
        - paddings: The paddings tensor. Expected shape [batch, time].
        - task_ids: If p.task_emb is provided, must contain per-token task
            ids of shape [batch, time].

    Returns:
      A NestedMap containing

      - encoded: The encoded features, either a tensor of shape
        [time, batch, depth], or a list of tensors if is_transparent is set in
        transformer_stack.
      - padding: of shape [time, batch]
      - segment_id: [time, batch] if packed inputs are supported by the model
        (and all layers), or None otherwise.
      - embedded_inputs: [time, batch, depth] embedded inputs tokens without
        positional encodings.
    """

    p = self.params
    with tf.name_scope(p.name):
      src_segment_id = None
      src_segment_pos = None
      input_ids = py_utils.with_dependencies([
          py_utils.assert_shape_match(
              tf.shape(input_batch.ids), tf.shape(input_batch.paddings)),
          py_utils.assert_equal(tf.rank(input_batch.ids), 2)
      ], input_batch.ids)

      if (not py_utils.use_tpu() and
          FLAGS.transformer_encoder_truncates_inputs):
        max_seq_length = tf.cast(
            tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)),
            tf.int32)
        paddings = py_utils.with_dependencies([
            py_utils.assert_equal(
                tf.constant(True, tf.bool),
                tf.reduce_all(input_batch.paddings[:, max_seq_length:] > 0.5))
        ], input_batch.paddings)
        input_ids = input_ids[:, :max_seq_length]
        paddings = paddings[:, :max_seq_length]
        if p.packed_input:
          src_segment_id = input_batch.segment_ids[:, :max_seq_length]
          src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
      else:
        paddings = input_batch.paddings
        if p.packed_input:
          src_segment_id = input_batch.segment_ids
          src_segment_pos = input_batch.segment_pos

      max_time = tf.shape(input_ids)[1]

      # Input token embeddings + positional embeddings
      if not p.shared_emb:
        input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                              tf.reshape(input_ids, [-1]))
      else:
        input_embs = self.softmax.EmbLookup(theta.softmax,
                                            tf.reshape(input_ids, [-1]))

      input_embs = tf.reshape(input_embs,
                              [-1, max_time, p.token_emb.embedding_dim])
      # [time, batch, dim]
      orig_input_embs = tf.transpose(input_embs, [1, 0, 2])

      if p.packed_input:
        position_embs = self.position_emb.FPropWithPosition(
            theta.position_emb, src_segment_pos)
      else:
        position_embs = self.position_emb.FProp(theta.position_emb, max_time)
        position_embs = tf.reshape(position_embs,
                                   [1, max_time, p.token_emb.embedding_dim])
      # Position embeddings are simply added to token embeddings.
      input_embs += position_embs

      if p.individually_tagged_input:
        assert not p.packed_input
        # Look up tag embeddings; this assumes that the tags arriving on
        # input_batch.segment_ids (originating as common.source_segment_id
        # in the input NMTExample) have been reserved in the WPM vocabulary
        # as context tags, e.g. the ids for <src_token> and <ctxt_token> in
        # wide source context experiments.
        input_tags = py_utils.with_dependencies([
            py_utils.assert_shape_match(
                tf.shape(input_batch.segment_ids), tf.shape(input_batch.ids)),
            py_utils.assert_equal(tf.rank(input_batch.segment_ids), 2)
        ], input_batch.segment_ids)
        tag_embeddings = self.token_emb.EmbLookup(theta.token_emb,
                                                  tf.reshape(input_tags, [-1]))
        tag_embeddings = tf.reshape(tag_embeddings,
                                    [-1, max_time, p.token_emb.embedding_dim])
        # Concatenate the tag embeddings to the input embeddings, and then
        # project back to the original embedding dimensionality.
        concat_embs = tf.concat([input_embs, tag_embeddings], -1)
        input_embs = self.concat_emb_and_tag_proj.FProp(
            theta.concat_emb_and_tag_proj, concat_embs)

      if p.ln_input:
        input_embs = self.layer_norm_input.FProp(theta.layer_norm_input,
                                                 input_embs)

      if p.task_emb:
        input_embs += self.task_emb.EmbLookup(theta.task_emb,
                                              input_batch.task_ids)

      summary_utils.histogram('input_embs', input_embs)
      if p.model_dim != p.token_emb.embedding_dim:
        input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)
        summary_utils.histogram('emb_proj', input_embs)

      paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
      if p.packed_input:
        src_segment_id = tf.transpose(src_segment_id)
      input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)

      # [time, batch, dim]
      transformer_input = tf.transpose(input_embs, [1, 0, 2])

    if not self.do_eval and p.apply_source_mask:
      # Augment padding for masked source word positions.
      dtype = paddings.dtype
      source_mask = tf.where(
          tf.equal(input_ids, p.source_mask_id),
          tf.ones_like(input_ids, dtype=dtype),
          tf.zeros_like(input_ids, dtype=dtype))
      # Make sure padding is between 0 and 1.
      paddings = tf.clip_by_value(paddings + tf.transpose(source_mask), 0.0,
                                  1.0)

    encoded, padding, segment_id = self.transformer_stack.FProp(
        theta.transformer_stack, transformer_input, paddings, src_segment_id)
    return py_utils.NestedMap(
        encoded=encoded,
        padding=padding,
        segment_id=segment_id,
        embedded_inputs=orig_input_embs)
Exemple #14
0
    def FProp(self, theta, inputs, paddings=None):
        """Apply group normalization.

        The last (channel) dimension, of size p.dim, is split into groups;
        each group is normalized with its own mean and variance, then scaled
        by `theta.gamma` and shifted by `theta.beta`.

        Args:
          theta: A NestedMap object containing weights' values of this layer
            and its children layers. `gamma` and `beta` are read from it.
          inputs: The inputs tensor. Its rank is asserted (at run time) to be
            at least p.input_rank. For p.input_rank == 4 the expected shape is
            [batch_size, height, width, channel]; the moment shape checks
            below also cover p.input_rank == 3, i.e. [batch_size, height,
            channel].
          paddings: The paddings tensor with shape [batch_size, height].
            Intended to be used for sequence processing where `height` is
            `time`. It is only used when computing the moments (via
            ComputeMomentsWithPadding); the normalized output is not masked.

        Returns:
          A single tensor as the output after applying group normalization,
          with the same shape as 'inputs'. Or an (output, output_paddings)
          pair if input paddings is not None, where output_paddings is the
          given `paddings`, unchanged.
        """
        p = self.params
        inputs = py_utils.with_dependencies([
            py_utils.assert_greater_equal(py_utils.GetRank(inputs),
                                          p.input_rank)
        ], inputs)

        # Each group holds at least min(p.min_group_size, p.dim) channels, so
        # the number of groups actually used can be smaller than p.num_groups.
        group_size = max(p.dim // p.num_groups, min(p.min_group_size, p.dim))
        num_groups = p.dim // group_size

        input_shape = py_utils.GetShape(inputs)
        with tf.name_scope(p.name):
            # [..., dim] -> [..., num_groups, group_size]
            grouped = tf.reshape(inputs,
                                 input_shape[:-1] + [num_groups, group_size])
            dims = list(range(p.input_rank + 1))
            # Every axis of `grouped` except batch (d0) and group (d[-2]).
            per_group_axes = dims[1:-2] + dims[-1:]

            if paddings is None:
                counts, mean_ss, variance_ss, _ = tf.nn.sufficient_statistics(
                    grouped, axes=per_group_axes, keepdims=True)
                mean, variance = tf.nn.normalize_moments(
                    counts, mean_ss, variance_ss, None)
            else:
                # [batch, time] -> [batch, time, 1, ...] so the paddings line
                # up with the leading dims of `grouped`.
                padding_mask = tf.reshape(
                    paddings, input_shape[:2] + [1] * (len(dims) - 2))
                if p.cumulative:
                    # Time (d1) is excluded from the reduction as well and is
                    # handed to the helper as the cumulative axis.
                    mean, variance = ComputeMomentsWithPadding(
                        grouped,
                        padding_mask,
                        reduce_over_dims=dims[2:-2] + dims[-1:],
                        cumulative_axis=1,
                        keepdims=True)
                else:
                    mean, variance = ComputeMomentsWithPadding(
                        grouped, padding_mask, per_group_axes, keepdims=True)

            mean = py_utils.CheckNumerics(
                mean, 'mean of %s failed numeric check' % p.name)
            variance = py_utils.CheckNumerics(
                variance, 'variance of %s failed numeric check' % p.name)

            gamma = theta.gamma
            beta = theta.beta
            # Expected moment shape: batch, then time only in cumulative mode,
            # then one entry per group (keepdims=True leaves size-1 axes).
            batch = input_shape[0]
            steps = input_shape[1] if p.cumulative else 1
            if p.input_rank == 4:
                expected_shape = [batch, steps, 1, num_groups, 1]
            else:
                expected_shape = [batch, steps, num_groups, 1]
            checks = [
                py_utils.assert_greater_equal(
                    variance, tf.cast(0., variance.dtype)),
                py_utils.assert_shape_match(expected_shape, tf.shape(mean)),
                py_utils.assert_shape_match(expected_shape,
                                            tf.shape(variance)),
            ]
            with tf.control_dependencies(checks):
                normalized = (grouped - mean) / tf.sqrt(variance +
                                                        self._epsilon)
                # Fold the groups back into the channel dimension.
                normalized = tf.reshape(normalized, input_shape)
                # Reshape again after the affine transform; presumably guards
                # against gamma/beta having a different rank than `inputs`
                # -- TODO confirm against the variable shapes.
                gn_output = tf.reshape(normalized * gamma + beta, input_shape)

        if paddings is None:
            return gn_output
        return gn_output, paddings
Exemple #15
0
def AssertIdShape(expected_ids_shape_pattern, ids_shape, *args):
  """Returns `ids_shape` with shape-match assertions attached.

  Args:
    expected_ids_shape_pattern: Shape that `ids_shape` is asserted to match
      via py_utils.assert_shape_match.
    ids_shape: The shape under test; it is also the return value.
    *args: Further shapes, each of which `ids_shape` is asserted to match.

  Returns:
    `ids_shape`, wrapped by py_utils.with_dependencies on all of the
    assertions (the pattern check first, then one per entry of `args`).
  """
  checks = [py_utils.assert_shape_match(ids_shape, expected_ids_shape_pattern)]
  for other_shape in args:
    checks.append(py_utils.assert_shape_match(ids_shape, other_shape))
  return py_utils.with_dependencies(checks, ids_shape)
Exemple #16
0
    def FProp(self, theta, input_batch):
        """Embeds source ids and transforms with TransformerStack.

        Args:
          theta: A `.NestedMap` holding the weight values of this layer and
            its child layers.
          input_batch: A `.NestedMap` with fields:

            - ids: Source token ids, expected shape [batch, time].
            - paddings: Padding indicators, expected shape [batch, time].
            - segment_ids, segment_pos: Only read when p.packed_input is set.

        Returns:
          A NestedMap containing:
            - encoded: The encoded features, either a tensor of shape
                [time, batch, depth], or a list of tensors if is_transparent
                is set in transformer_stack.
            - padding: of shape [time, batch]
            - segment_id: [time, batch] if packed inputs are supported by the
                model (and all layers), or None otherwise.
            - embedded_inputs: [time, batch, depth] embedded inputs tokens
                without positional encodings.
        """
        p = self.params
        with tf.name_scope(p.name):
            seg_ids = None
            seg_pos = None
            ids = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings)),
                py_utils.assert_equal(tf.rank(input_batch.ids), 2)
            ], input_batch.ids)

            # Truncation is skipped on TPU and otherwise gated by the
            # transformer_encoder_truncates_inputs flag.
            if (not py_utils.use_tpu()
                    and tf.flags.FLAGS.transformer_encoder_truncates_inputs):
                # Longest non-padded length in the batch. The assertion checks
                # that every column past it is padding in all rows.
                real_lengths = tf.reduce_sum(1.0 - input_batch.paddings, 1)
                keep = tf.cast(tf.reduce_max(real_lengths), tf.int32)
                checked_paddings = py_utils.with_dependencies([
                    py_utils.assert_equal(
                        tf.constant(True, tf.bool),
                        tf.reduce_all(input_batch.paddings[:, keep:] > 0.5))
                ], input_batch.paddings)
                ids = ids[:, :keep]
                batch_paddings = checked_paddings[:, :keep]
                if p.packed_input:
                    seg_ids = input_batch.segment_ids[:, :keep]
                    seg_pos = input_batch.segment_pos[:, :keep]
            else:
                batch_paddings = input_batch.paddings
                if p.packed_input:
                    seg_ids = input_batch.segment_ids
                    seg_pos = input_batch.segment_pos

            max_time = tf.shape(ids)[1]
            emb_dim = p.token_emb.embedding_dim

            # Token embeddings are looked up on the flattened ids and then
            # restored to [batch, time, dim].
            flat_token_embs = self.token_emb.EmbLookup(theta.token_emb,
                                                       tf.reshape(ids, [-1]))
            token_embs = tf.reshape(flat_token_embs, [-1, max_time, emb_dim])
            # [time, batch, dim]; returned as embedded_inputs.
            orig_input_embs = tf.transpose(token_embs, [1, 0, 2])

            if p.packed_input:
                position_embs = self.position_emb.FPropWithPosition(
                    theta.position_emb, seg_pos)
            else:
                position_embs = tf.reshape(
                    self.position_emb.FProp(theta.position_emb, max_time),
                    [1, max_time, emb_dim])
            embs = token_embs + position_embs

            if p.model_dim != emb_dim:
                embs = self.emb_proj.FProp(theta.emb_proj, embs)

            # The transformer stack consumes time-major tensors.
            time_major_paddings = tf.transpose(batch_paddings)
            if p.packed_input:
                seg_ids = tf.transpose(seg_ids)
            embs = self.input_dropout.FProp(theta.input_dropout, embs)

            # [time, batch, dim]
            transformer_input = tf.transpose(embs, [1, 0, 2])

        encoded, padding, segment_id = self.transformer_stack.FProp(
            theta.transformer_stack, transformer_input, time_major_paddings,
            seg_ids)
        return py_utils.NestedMap(encoded=encoded,
                                  padding=padding,
                                  segment_id=segment_id,
                                  embedded_inputs=orig_input_embs)
Exemple #17
0
  def FProp(self, theta, input_batch, interpolation_batch=None, lambdas=None):
    # pyformat: disable
    """Interpolates source ids in input_batch and interpolation_batch.

    Refer to Eq. (4) in paper https://arxiv.org/abs/2106.04060.
    It is a standard Transformer Encoder if interpolation_batch is None.

    Args:
      theta: A `.NestedMap` object containing weights values of this layer and
        its children layers.
      input_batch: A `.NestedMap` with fields:

        - ids: The inputs tensor. It is expected to be of shape [batch, time].
        - paddings: The paddings tensor. Expected shape [batch, time].
        - task_ids: If p.task_emb is provided, must contain per-token task ids
          of shape [batch, time].
        - embs: Optional embeddings of ids, only used when interpolation_batch
          is given; they replace the looked-up embeddings and are reshaped to
          [-1, time, depth] before mixing.
      interpolation_batch: A `.NestedMap` with fields:

        - ids: The inputs tensor. It is expected to be of shape [batch, time].
        - paddings: The paddings tensor. Expected shape [batch, time].
        - task_ids: If p.task_emb is provided, must contain per-token task ids
          of shape [batch, time].
        - embs: Optional embeddings of ids; they replace the looked-up
          embeddings and are reshaped to [-1, time, depth] before mixing.
      lambdas: A pair of tensors to combine embeddings of ids in input_batch and
        interpolation_batch. Required when interpolation_batch is given; a
        trailing axis is appended to each before it scales the embeddings.

    Returns:
      A NestedMap of

        - encoded: The encoded features, either a tensor of shape
          [time, batch, depth], or a list of tensors if is_transparent is set in
          transformer_stack.
        - padding: of shape [time, batch]
        - segment_id: [time, batch] if packed inputs are supported by the model
          (and all layers), or None otherwise.
        - embedded_inputs: [batch, time, depth] embedded input tokens (after
          interpolation, if any) without task or positional embeddings. Note
          this is batch-major, unlike `encoded`.
    """
    # pyformat: enable

    p = self.params
    with tf.name_scope(p.name):
      src_segment_id = None
      src_segment_pos = None
      input_ids = py_utils.with_dependencies([
          py_utils.assert_shape_match(
              tf.shape(input_batch.ids), tf.shape(input_batch.paddings)),
          py_utils.assert_equal(tf.rank(input_batch.ids), 2)
      ], input_batch.ids)

      max_seq_length = None
      if (not py_utils.use_tpu() and
          FLAGS.transformer_encoder_truncates_inputs):
        # Trim time steps past the longest non-padded row; the assertion
        # checks that everything beyond max_seq_length is padding.
        max_seq_length = tf.cast(
            tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)),
            tf.int32)
        paddings = py_utils.with_dependencies([
            py_utils.assert_equal(
                tf.constant(True, tf.bool),
                tf.reduce_all(input_batch.paddings[:, max_seq_length:] > 0.5))
        ], input_batch.paddings)
        input_ids = input_ids[:, :max_seq_length]
        paddings = paddings[:, :max_seq_length]
        if p.packed_input:
          src_segment_id = input_batch.segment_ids[:, :max_seq_length]
          src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
      else:
        paddings = input_batch.paddings
        if p.packed_input:
          src_segment_id = input_batch.segment_ids
          src_segment_pos = input_batch.segment_pos

      max_time = tf.shape(input_ids)[1]

      # Input token embeddings, looked up on the flattened ids.
      if not p.shared_emb:
        input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                              tf.reshape(input_ids, [-1]))
      else:
        input_embs = self.softmax.EmbLookup(theta.softmax,
                                            tf.reshape(input_ids, [-1]))

      if interpolation_batch is not None:
        other_input_ids = interpolation_batch.ids
        if not p.shared_emb:
          other_input_embs = self.token_emb.EmbLookup(
              theta.token_emb, tf.reshape(other_input_ids, [-1]))
        else:
          other_input_embs = self.softmax.EmbLookup(
              theta.softmax, tf.reshape(other_input_ids, [-1]))
        # Append a trailing axis so each lambda broadcasts over the embedding
        # dimension.
        lambdas = [tf.expand_dims(a, -1) for a in lambdas]
        # Caller-supplied embeddings replace the lookups above. Both sides are
        # then reshaped independently to [batch, time, dim], so the mix below
        # is well-formed whichever side (both, or neither) supplied `embs`.
        # Previously a flat lookup on the input side was left unreshaped when
        # only interpolation_batch.embs was given.
        if 'embs' in input_batch and input_batch.embs is not None:
          input_embs = input_batch.embs
        if 'embs' in interpolation_batch and interpolation_batch.embs is not None:
          other_input_embs = interpolation_batch.embs
        input_embs = tf.reshape(
            input_embs,
            [-1, tf.shape(input_ids)[1], p.token_emb.embedding_dim])
        other_input_embs = tf.reshape(
            other_input_embs,
            [-1, tf.shape(other_input_ids)[1], p.token_emb.embedding_dim])
        input_embs = lambdas[0] * input_embs + lambdas[1] * other_input_embs
        # For 0/1 paddings, a position stays padded only if it is padded in
        # both batches.
        paddings = paddings + interpolation_batch.paddings - 1.0
        paddings = tf.clip_by_value(paddings, 0.0, 1.0)

      input_embs = tf.reshape(input_embs,
                              [-1, max_time, p.token_emb.embedding_dim])

      # Batch-major [batch, time, dim]; returned as embedded_inputs before task
      # and positional embeddings are added.
      orig_input_embs = input_embs
      if p.task_emb:
        if interpolation_batch is None:
          input_embs += self.task_emb.EmbLookup(theta.task_emb,
                                                input_batch.task_ids)
        else:
          # Task embeddings are interpolated with the same lambdas.
          task_embs = self.task_emb.EmbLookup(theta.task_emb,
                                              input_batch.task_ids)
          other_task_embs = self.task_emb.EmbLookup(
              theta.task_emb, interpolation_batch.task_ids)
          task_embs = lambdas[0] * task_embs + lambdas[1] * other_task_embs
          input_embs += task_embs

      if p.packed_input:
        position_embs = self.position_emb.FPropWithPosition(
            theta.position_emb, src_segment_pos)
      else:
        position_embs = self.position_emb.FProp(theta.position_emb, max_time)
        position_embs = tf.reshape(position_embs,
                                   [1, max_time, p.token_emb.embedding_dim])
      input_embs += position_embs

      if p.model_dim != p.token_emb.embedding_dim:
        input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

      # The transformer stack consumes time-major paddings / segment ids.
      paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
      if p.packed_input:
        src_segment_id = tf.transpose(src_segment_id)

      input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)

      # [time, batch, dim]
      transformer_input = tf.transpose(input_embs, [1, 0, 2])

    if not self.do_eval and p.apply_source_mask:
      # Augment padding for masked source word positions.
      dtype = paddings.dtype
      source_mask = tf.where(
          tf.equal(input_ids, p.source_mask_id),
          tf.ones_like(input_ids, dtype=dtype),
          tf.zeros_like(input_ids, dtype=dtype))
      # Make sure padding is between 0 and 1.
      paddings = tf.clip_by_value(paddings + tf.transpose(source_mask), 0.0,
                                  1.0)

    encoded, padding, segment_id = self.transformer_stack.FProp(
        theta.transformer_stack, transformer_input, paddings, src_segment_id)

    return py_utils.NestedMap(
        encoded=encoded,
        padding=padding,
        segment_id=segment_id,
        embedded_inputs=orig_input_embs)