Example #1
    def _Moments(inputs, mask, enable_cross_replica_sum_on_tpu=False):
        """Computes mean and variance over the valid data points in inputs."""
        inputs = py_utils.with_dependencies([
            py_utils.assert_equal(tf.rank(inputs), tf.rank(mask)),
            py_utils.assert_greater_equal(mask, tf.zeros_like(mask)),
        ], inputs)
        rank = tf.rank(mask)
        reduce_over_dims = tf.range(0, rank - 1)
        sum_v = tf.reduce_sum(inputs * tf.cast(mask, inputs.dtype),
                              reduce_over_dims)
        count_v = tf.reduce_sum(mask, reduce_over_dims)
        # Input shape is guaranteed to be a multiple of mask shape because the
        # inputs * mask op above was successfully broadcasted.
        mask_multiplier = tf.shape(inputs)[:-1] // tf.shape(mask)[:-1]
        count_v *= tf.cast(tf.reduce_prod(mask_multiplier), count_v.dtype)
        if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
            sum_v = tf.tpu.cross_replica_sum(sum_v)
            count_v = tf.tpu.cross_replica_sum(count_v)

        count_v = tf.maximum(count_v, 1.0)
        mean = sum_v / count_v
        sum_vv = tf.reduce_sum((inputs - mean) * (inputs - mean) * mask,
                               reduce_over_dims)

        if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
            sum_vv = tf.tpu.cross_replica_sum(sum_vv)

        variance = py_utils.with_dependencies([
            py_utils.assert_greater_equal(sum_vv, tf.zeros_like(sum_vv)),
        ], sum_vv / count_v)
        return mean, variance
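
A minimal standalone sketch of the same masked-moment computation, in plain
TensorFlow 2 with the py_utils asserts and TPU cross-replica sums dropped (toy
shapes assumed):

import tensorflow as tf

def masked_moments(inputs, mask):
  """Mean/variance over positions where mask == 1 (toy _Moments)."""
  # Reduce over all axes except the last, mirroring tf.range(0, rank - 1).
  reduce_dims = list(range(inputs.shape.rank - 1))
  mask = tf.cast(mask, inputs.dtype)
  sum_v = tf.reduce_sum(inputs * mask, reduce_dims)
  count_v = tf.maximum(tf.reduce_sum(mask, reduce_dims), 1.0)
  mean = sum_v / count_v
  sum_vv = tf.reduce_sum(tf.square(inputs - mean) * mask, reduce_dims)
  return mean, sum_vv / count_v

x = tf.constant([[[1.0], [2.0], [100.0]]])  # [batch=1, time=3, dim=1]
m = tf.constant([[[1.0], [1.0], [0.0]]])    # last time step is padding
mean, var = masked_moments(x, m)            # mean == [1.5], var == [0.25]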
Example #2
def SplitTensors(xs, num_splits):
  """Splits tensors in `xs` evenly into num_splits along the 1st dimenion.

  Args:
    xs: A tuple of tensors. Each tensor's 1st dimension is the same size.
    num_splits: A python integer.

  Returns:
    A tuple of lists of tensors; the number of elements in the tuple equals
    len(xs).

    The i-th element in each list corresponds to the i-th split of each tensor
    in xs along its first dimension.
  """
  # assert first dim of all tensors in xs is equal
  batch_dims = [tf.shape(x)[0] for x in xs]
  all_batch_dims = tf.stack(batch_dims)

  all_batch_dims = py_utils.with_dependencies([
      py_utils.assert_equal(
          all_batch_dims,
          tf.shape(xs[0])[0],
          message='first dim of tensors in xs must match'),
      py_utils.assert_greater_equal(
          tf.shape(xs[0])[0],
          num_splits,
          message='first dim of tensors in xs must be at least num_splits')
  ], all_batch_dims)

  splits = ComputeSplits(tf.shape(xs[0])[0], num_splits)
  # add the above assertion into the compute graph
  splits = py_utils.with_dependencies([all_batch_dims], splits)
  split_xs = [tf.split(axis=0, num_or_size_splits=splits, value=x) for x in xs]

  return split_xs
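
ComputeSplits itself is not part of this excerpt; the stand-in below is an
assumption about its behavior (near-even split sizes), shown together with the
tf.split call it feeds:

import tensorflow as tf

def compute_splits_sketch(batch_size, num_splits):
  """Hypothetical stand-in for ComputeSplits: num_splits near-even sizes."""
  base = batch_size // num_splits
  rem = batch_size % num_splits
  # The first `rem` pieces get one extra element.
  return tf.concat([tf.fill([rem], base + 1),
                    tf.fill([num_splits - rem], base)], axis=0)

x = tf.reshape(tf.range(10), [10, 1])
sizes = compute_splits_sketch(10, 3)                  # [4, 3, 3]
parts = tf.split(x, num_or_size_splits=sizes, axis=0)
print([int(p.shape[0]) for p in parts])               # [4, 3, 3]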
Example #3
 def _TruncateTargetSequence(self, targets):
   """Truncate padded time steps from all sequences."""
   # The following tensors are all in the [batch, time] shape.
   p = self.params
   # Let's make a copy of targets.
   targets = targets.Pack(targets.Flatten())
   target_ids = targets.ids
   target_labels = targets.labels
   target_weights = targets.weights
   target_paddings = targets.paddings
   max_seq_length = tf.to_int32(
       tf.reduce_max(tf.reduce_sum(1.0 - target_paddings, 1)))
   summary_utils.scalar('max_seq_length', max_seq_length)
   # Assert to make sure after max_seq_length, all are padded steps for all
   # sequences.
   target_paddings = py_utils.with_dependencies([
       py_utils.assert_equal(
           tf.constant(True, tf.bool),
           tf.reduce_all(target_paddings[:, max_seq_length:] > 0.5))
   ], target_paddings)
   target_ids = py_utils.with_dependencies([
       AssertIdShape(
           py_utils.GetShape(target_ids), py_utils.GetShape(target_labels),
           py_utils.GetShape(target_paddings),
           py_utils.GetShape(target_weights))
   ], target_ids)
   targets.ids = target_ids[:, :max_seq_length]
   targets.labels = target_labels[:, :max_seq_length]
   targets.weights = target_weights[:, :max_seq_length]
   targets.paddings = target_paddings[:, :max_seq_length]
   return targets
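
The truncation trick above, condensed into a standalone TF2 snippet
(tf.to_int32 is the TF1 spelling of tf.cast(..., tf.int32)):

import tensorflow as tf

paddings = tf.constant([[0., 0., 1., 1.],
                        [0., 1., 1., 1.]])  # [batch=2, time=4]
ids = tf.constant([[5, 6, 0, 0],
                   [7, 0, 0, 0]])
max_seq_length = tf.cast(
    tf.reduce_max(tf.reduce_sum(1.0 - paddings, 1)), tf.int32)  # 2
# Everything at or beyond max_seq_length must be padding in every row.
tf.debugging.assert_equal(
    tf.reduce_all(paddings[:, max_seq_length:] > 0.5), True)
ids = ids[:, :max_seq_length]                # [[5, 6], [7, 0]]
paddings = paddings[:, :max_seq_length]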
Example #4
def ComputeMoments(inputs,
                   padding,
                   reduce_over_dims,
                   cumulative_axis=None,
                   enable_cross_replica_sum_on_tpu=False,
                   keepdims=False):
    """Computes mean and variance over the valid data points in inputs."""
    mask = 1.0 - padding
    inputs = py_utils.with_dependencies([
        py_utils.assert_equal(tf.rank(inputs), tf.rank(mask)),
        py_utils.assert_greater_equal(mask, tf.zeros_like(mask)),
    ], inputs)
    sum_v = tf.reduce_sum(inputs * tf.cast(mask, inputs.dtype),
                          reduce_over_dims,
                          keepdims=keepdims)
    count_v = tf.reduce_sum(mask, reduce_over_dims, keepdims=keepdims)

    if cumulative_axis is not None:
        sum_v = tf.math.cumsum(sum_v, axis=cumulative_axis)
        count_v = tf.math.cumsum(count_v, axis=cumulative_axis)
    # Input shape is guaranteed to be a multiple of mask shape because the
    # inputs * mask op above was successfully broadcasted.
    input_size_on_reduced_dims = tf.reduce_prod(
        tf.gather(tf.shape(inputs), reduce_over_dims))
    mask_size_on_reduced_dims = tf.reduce_prod(
        tf.gather(tf.shape(mask), reduce_over_dims))
    mask_multiplier = tf.math.truediv(input_size_on_reduced_dims,
                                      mask_size_on_reduced_dims)
    count_v *= tf.cast(mask_multiplier, count_v.dtype)
    if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
        sum_v = tf.tpu.cross_replica_sum(sum_v)
        count_v = tf.tpu.cross_replica_sum(count_v)

    count_v = tf.maximum(count_v, 1.0)
    mean = sum_v / count_v
    sum_vv = tf.reduce_sum((inputs - mean) * (inputs - mean) * mask,
                           reduce_over_dims,
                           keepdims=keepdims)
    if cumulative_axis is not None:
        sum_vv = tf.math.cumsum(sum_vv, axis=cumulative_axis)

    if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
        sum_vv = tf.tpu.cross_replica_sum(sum_vv)

    variance = py_utils.with_dependencies([
        py_utils.assert_greater_equal(sum_vv, tf.zeros_like(sum_vv)),
    ], sum_vv / count_v)
    return mean, variance
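
What cumulative_axis adds relative to Example #1: tf.math.cumsum turns the
per-batch sums into prefix sums, so the moments at step t cover steps 0..t
only. A toy illustration:

import tensorflow as tf

sum_v = tf.constant([[1.0, 2.0, 3.0]])     # per-step masked sums
count_v = tf.constant([[1.0, 1.0, 1.0]])   # per-step masked counts
sum_v = tf.math.cumsum(sum_v, axis=1)      # [[1. 3. 6.]]
count_v = tf.math.cumsum(count_v, axis=1)  # [[1. 2. 3.]]
running_mean = sum_v / count_v             # [[1. 1.5 2.]]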
Example #5
def MakeCausalPadding(seq_len,
                      block_size,
                      left_context,
                      right_context,
                      dtype=tf.float32):
    """Makes the causal padding tensor for a full sequence.

  Args:
    seq_len: int or scalar int tensor. Sequence length.
    block_size: int. Number of time frames in a block.
    left_context: int. Left context size.
    right_context: int. Right context size.
    dtype: tf.dtype, default is tf.float32.

  Returns:
    A tensor of [num_blocks, block_size, context_size] taking values in {0, 1},
    where context_size = block_size + (left_context - 1) + right_context.
    Element (b, i, j) is 0 if, in the b-th block, the i-th frame may attend to
    the j-th frame of the context, and 1 otherwise.
  """
    seq_len = py_utils.with_dependencies([
        py_utils.assert_greater_equal(
            seq_len, 1, message='seq_len must be at least 1')
    ], seq_len)

    num_blocks = (seq_len + block_size - 1) // block_size
    context_size = block_size + (left_context - 1) + right_context

    # [num_blocks, block_size]: source positions in the original sequence.
    src_positions = tf.reshape(tf.range(num_blocks * block_size),
                               [num_blocks, block_size])
    # [num_blocks,]: source positions at the start of each block.
    block_start_positions = tf.range(0, num_blocks * block_size, block_size)
    # [context_size]:  positions relative to the block start.
    relative_context_positions = tf.range(context_size) - (left_context - 1)

    # [num_blocks, context_size]: target positions in the original sequence.
    tgt_positions = (block_start_positions[:, tf.newaxis] +
                     relative_context_positions[tf.newaxis, :])
    # [num_blocks, block_size, context_size]: position differences between
    # source-target pairs.
    position_diff = (src_positions[:, :, tf.newaxis] -
                     tgt_positions[:, tf.newaxis, :])
    # [num_blocks, block_size, context_size]: if attention is allowed between
    # source-target pairs.
    valid_atten = tf.math.logical_and(-right_context <= position_diff,
                                      position_diff < left_context)

    # [num_blocks, block_size]: if the source position is valid, not padded.
    valid_src = src_positions < seq_len
    # [num_blocks, context_size]: if the target position is valid, not padded.
    valid_tgt = tf.math.logical_and(0 <= tgt_positions,
                                    tgt_positions < seq_len)

    valid_atten &= tf.math.logical_and(valid_src[:, :, tf.newaxis],
                                       valid_tgt[:, tf.newaxis, :])

    padding = 1.0 - tf.cast(valid_atten, dtype=dtype)

    return padding
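
A py_utils-free condensation of MakeCausalPadding for eager experimentation;
it should match the original up to the dropped seq_len assert:

import tensorflow as tf

def make_causal_padding_sketch(seq_len, block_size, left_context,
                               right_context):
  num_blocks = (seq_len + block_size - 1) // block_size
  context_size = block_size + (left_context - 1) + right_context
  src = tf.reshape(tf.range(num_blocks * block_size), [num_blocks, block_size])
  block_start = tf.range(0, num_blocks * block_size, block_size)
  rel = tf.range(context_size) - (left_context - 1)
  tgt = block_start[:, tf.newaxis] + rel[tf.newaxis, :]
  diff = src[:, :, tf.newaxis] - tgt[:, tf.newaxis, :]
  valid = tf.logical_and(-right_context <= diff, diff < left_context)
  valid &= tf.logical_and(
      (src < seq_len)[:, :, tf.newaxis],
      tf.logical_and(0 <= tgt, tgt < seq_len)[:, tf.newaxis, :])
  return 1.0 - tf.cast(valid, tf.float32)

# One block of 4 frames, purely causal (full left context, no right context):
print(make_causal_padding_sketch(4, 4, 4, 0)[0])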
Example #6
  def FProp(self, theta, inputs):
    """Applies batch normalization.

    Using the implementation in
    github.com/tensorflow/tpu/blob/master/models/official/amoeba_net/network_utils.py#L550

    Args:
      theta: A nested map object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor.  Shaped [..., dim].

    Returns:
      Output after applying batch normalization, with the same shape as
      'inputs'.
    """
    p = self.params
    inputs_dtype = inputs.dtype
    inputs = tf.cast(inputs, p.dtype)
    inputs = py_utils.with_dependencies(
        [py_utils.assert_shape_match([tf.shape(inputs)[-1]], [p.dim])], inputs)
    with tf.name_scope(p.name) as scope:
      if p.is_eval:
        outputs = tf.nn.batch_normalization(inputs, theta.moving_mean,
                                            theta.moving_variance,
                                            theta.beta, theta.gamma, p.epsilon)
      else:
        mean, variance = self._Moments(inputs, p.bn_group_size)
        mean = py_utils.CheckNumerics(
            mean, 'mean of {} failed numeric check'.format(scope))
        variance = py_utils.CheckNumerics(
            variance, 'variance of {} failed numeric check'.format(scope))
        outputs = tf.nn.batch_normalization(inputs, mean, variance, theta.beta,
                                            theta.gamma, p.epsilon)
      outputs.set_shape(inputs.get_shape())
      return tf.cast(outputs, inputs_dtype)
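
The is_eval branch in isolation: inference-mode batch normalization with fixed
moving statistics (toy shapes, not tied to the layer's params):

import tensorflow as tf

x = tf.random.normal([2, 8])
moving_mean = tf.zeros([8])
moving_variance = tf.ones([8])
beta = tf.zeros([8])   # offset
gamma = tf.ones([8])   # scale
y = tf.nn.batch_normalization(x, moving_mean, moving_variance, beta, gamma,
                              variance_epsilon=1e-3)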
Example #7
    def FProp(self, theta, inputs, paddings, class_emb):
        """Apply batch normalization.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor.  Shaped [batch, ..., dim].
      paddings: The paddings tensor.  Shaped [batch, ..., 1], with the same rank
        as the input tensor.
      class_emb: The conditioning inputs, Shaped [batch, emb_dim].

    Returns:
      Output after applying batch normalization, with the same shape as
      'inputs'.
    """
        p = self.params
        batch = py_utils.GetShape(inputs)[0]
        class_emb = py_utils.HasShape(class_emb, [batch, p.class_emb_dim])
        if not py_utils.use_tpu():
            class_emb = py_utils.with_dependencies([
                py_utils.assert_less_equal(
                    tf.cast(class_emb, tf.int32), 1, name='one_hot_assert1'),
                py_utils.assert_greater_equal(
                    tf.cast(class_emb, tf.int32), 0, name='one_hot_assert2'),
                py_utils.assert_equal(tf.ones([batch], tf.int32),
                                      tf.cast(tf.reduce_sum(class_emb, -1),
                                              tf.int32),
                                      name='one_hot_assert3'),
            ], class_emb)

        with tf.name_scope(p.name):
            norm_mean, norm_variance, beta, gamma = self.ComputeAndUpdateMoments(
                theta, inputs, paddings=paddings, class_emb=class_emb)
            return self._ComputeBN(inputs, paddings, gamma, beta, norm_mean,
                                   norm_variance)
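
The one-hot sanity checks above, rewritten as standalone eager assertions:
every row of class_emb must contain a single 1 and zeros elsewhere:

import tensorflow as tf

class_emb = tf.constant([[0., 1., 0.],
                         [1., 0., 0.]])
as_int = tf.cast(class_emb, tf.int32)
tf.debugging.assert_less_equal(as_int, 1)      # entries are at most 1
tf.debugging.assert_greater_equal(as_int, 0)   # entries are at least 0
tf.debugging.assert_equal(                     # each row sums to exactly 1
    tf.reduce_sum(as_int, -1), tf.ones([2], tf.int32))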
Example #8
    def _Normalize(self, theta, grouped_inputs, group_mean, group_variance):
        p = self.params
        group_mean = py_utils.CheckNumerics(
            group_mean, f'mean of {p.name} failed numeric check.')
        group_variance = py_utils.CheckNumerics(
            group_variance, f'variance of {p.name} failed numeric check.')

        input_shape = py_utils.GetShape(grouped_inputs)
        moment_shape = list(input_shape)
        if p.input_rank == 4:
            moment_shape[2] = 1
            moment_shape[-1] = 1
        else:
            moment_shape[-1] = 1
        if not p.cumulative:
            # If not cumulative, the seqlen dimension is also reduced.
            moment_shape[1] = 1

        group_mean = py_utils.HasShape(group_mean, moment_shape)
        group_variance = py_utils.HasShape(group_variance, moment_shape)
        group_variance = py_utils.with_dependencies([
            py_utils.assert_greater_equal(group_variance,
                                          tf.cast(0, group_variance.dtype))
        ], group_variance)

        grouped_inputs = (grouped_inputs - group_mean
                          ) * tf.math.rsqrt(group_variance + self._epsilon)
        # Merges the last two dims.
        grouped_inputs = tf.reshape(grouped_inputs, input_shape[:-2] + [-1])

        # Note, The real gamma to use is 1 + gamma.
        outputs = grouped_inputs * (theta.gamma + 1) + theta.beta
        return outputs
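
The normalization arithmetic in isolation. Note the "real gamma is 1 + gamma"
convention: a zero-initialized gamma starts out as the identity scale. A toy
sketch with per-row moments:

import tensorflow as tf

epsilon = 1e-3
x = tf.random.normal([2, 5])
mean, variance = tf.nn.moments(x, axes=[1], keepdims=True)
gamma = tf.zeros([5])
beta = tf.zeros([5])
y = (x - mean) * tf.math.rsqrt(variance + epsilon) * (gamma + 1.0) + beta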
Example #9
  def FProp(self, theta, input_batch):
    p = self.params
    src_segment_id = None
    with tf.name_scope(p.name):
      inputs = py_utils.with_dependencies([
          py_utils.assert_shape_match(tf.shape(input_batch.ids), [-1, -1]),
          py_utils.assert_shape_match(
              tf.shape(input_batch.ids), tf.shape(input_batch.paddings))
      ], tf.transpose(input_batch.ids))
      paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2)
      xs = self.emb.EmbLookup(theta.emb, inputs)
      xs = self.ApplyClipping(theta, xs)
      summary_utils.histogram('input_emb', xs)
      xs = self.dropout.FProp(theta.dropout, xs)
      ps = paddings
      # Now the rnn layers.
      outputs_list = []
      for i in range(0, p.num_lstm_layers):
        layer = self.rnn[i]
        ys, _ = layer.FProp(theta.rnn[i], xs, ps)
        ys = self.dropout.FProp(theta.dropout, ys)
        if i >= p.residual_start:
          xs += ys  # Residual skip
          xs = self.ApplyClipping(theta, xs)
        else:
          xs = ys
        outputs_list.append(xs)
        summary_utils.histogram('layer_out_%s' % i, xs)

      if p.is_transparent:
        xs = self.transparent_merger.FProp(theta.transparent_merger,
                                           outputs_list)

      return py_utils.NestedMap(
          encoded=xs, padding=tf.squeeze(ps, [2]), segment_id=src_segment_id)
Example #10
    def FProp(self, theta, input_batch):
        """Encodes source as represented by `inputs` and `paddings`.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      input_batch: A `.NestedMap` with fields:
        - ids: The inputs tensor. It is expected to be of shape [batch, time].
        - paddings: The paddings tensor. Expected shape [batch, time].

    Returns:
      A NestedMap containing:

      - encoded: The encoded features, a tensor of shape [time, batch, depth]
      - padding: of shape [time, batch]
      - segment_id: [time, batch] if packed inputs are supported by the model
        (and all layers), or None otherwise.
    """
        p = self.params
        src_segment_id = None
        with tf.name_scope(p.name):
            # Now the rnn layers.
            inputs = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            [-1, -1]),
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings))
            ], tf.transpose(input_batch.ids))
            paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2)
            xs = self.emb.EmbLookup(theta.emb, inputs)
            xs = self.ApplyClipping(theta, xs)
            self._emb_out = xs
            ps = paddings
            # When cc_schedule is specified, make sure lstm_tpl is QuantizedLSTMCell
            # with the same cc_schedule so that the RNN layer output is within
            # clipping range.
            xs = self.rnn[0].FProp(theta.rnn[0], xs, ps)
            xs = self.dropout.FProp(theta.dropout, xs)
            for i in range(1, p.num_lstm_layers):
                layer = self.rnn[i]
                ys, _ = layer.FProp(theta.rnn[i], xs, ps)
                ys = self.dropout.FProp(theta.dropout, ys)
                if hasattr(layer.params, 'cell'):
                    layer_params = layer.params.cell
                else:
                    layer_params = layer.params
                if layer_params.num_input_nodes == layer_params.num_output_nodes:
                    xs += ys  # Residual skip
                    xs = self.ApplyClipping(theta, xs)
                else:
                    # When cc_schedule is specified, make sure lstm_tpl is
                    # QuantizedLSTMCell with the same cc_schedule so that the RNN layer
                    # output is within clipping range.
                    xs = ys
            return py_utils.NestedMap(encoded=xs,
                                      padding=tf.squeeze(ps, [2]),
                                      segment_id=src_segment_id)
Example #11
 def IdsToStrings(self, ids, lens):
     """Takes integer matrices and returns vectors of strings."""
     ids = py_utils.with_dependencies(
         [py_utils.assert_same_dim0([ids, lens])], ids)
     return tf.map_fn(
         lambda inputs: self._wpm_encoder.Decode(inputs[0][:inputs[1]]),
         (ids, lens),
         dtype=tf.string,
         parallel_iterations=30,
         back_prop=False)
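
A self-contained sketch of the same tf.map_fn pattern, with a toy stand-in for
_wpm_encoder.Decode (stringify and join); fn_output_signature is the TF2
replacement for the dtype argument used above:

import tensorflow as tf

ids = tf.constant([[72, 105, 0], [89, 111, 117]])  # toy token ids
lens = tf.constant([2, 3])

def decode_row(args):
  row_ids, row_len = args
  # Stand-in for the real decoder: join the first `row_len` ids as strings.
  return tf.strings.reduce_join(
      tf.strings.as_string(row_ids[:row_len]), separator=' ')

strings = tf.map_fn(decode_row, (ids, lens), fn_output_signature=tf.string)
# strings == [b'72 105', b'89 111 117']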
Example #12
    def FProp(self, theta, inputs, paddings):
        """Apply convolution to inputs.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor. It is expected to be of shape [batch, time,
        frequency, channel]. The time dimension corresponds to the height
        dimension as in images and the frequency dimension corresponds to the
        width dimension as in images.
      paddings: The paddings tensor, expected to be of shape [batch, time].

    Returns:
      outputs, out_paddings pair.
    """
        p = self.params
        with tf.name_scope(p.name):
            inputs = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(paddings), [-1, -1]),
                py_utils.assert_shape_match(
                    tf.shape(inputs),
                    tf.concat([
                        tf.shape(paddings),
                        [-1, symbolic.ToStatic(self.input_channels)]
                    ], 0))
            ], inputs)

            def _ApplyPadding(tensor_in, padding_in):
                padding_expanded = tf.expand_dims(
                    tf.expand_dims(padding_in, -1), -1)
                return tensor_in * (1.0 - padding_expanded)

            # Zeroing out padded inputs.
            inputs = _ApplyPadding(inputs, paddings)

            # Apply conv on 'inputs'.
            out = self._ApplyConv(theta, inputs)

            if p.partial_conv:
                out = self._RescaleBoundary(out, paddings)
            # NOTE: this may be slightly inaccurate when p.dilation_rate[0] > 1.
            # But there are likely no real problems. Trying to set it gives an error:
            # pooling with SAME padding is not implemented for dilation_rate > 1.
            # NOTE: we use window=p.filter_stride[0] to be compatible with legacy
            # implementation.  Consider updating it to be the actual shape.
            conv_padding = ComputeConvOutputPadding(paddings,
                                                    window=p.filter_stride[0],
                                                    stride=p.filter_stride[0])
            # Assuming padded nodes will be properly zero-ed out if necessary by
            # sub-sequent layers.
            # out = _ApplyPadding(out, conv_padding)
            out = py_utils.HasShape(
                out, symbolic.ToStatic(self.OutShape(tf.shape(inputs))))
            return out, conv_padding
Example #13
            def ApplyBias():
                """Bias and update log_probs and consistent."""
                def TileForBeamAndFlatten(tensor):
                    tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
                    # [num_hyps_per_beam, src_batch]
                    tensor = tf.tile(tensor, [num_hyps_per_beam, 1])
                    # tgt_batch = num_hyps_per_beam * src_batch
                    tgt_batch = tf.shape(step_ids)[0]
                    return tf.reshape(tensor, [tgt_batch])

                # Consistent if step_ids == labels from previous step
                # TODO(navari): Consider updating consistent only if weights > 0. Then
                # re-evaluate the need for bias_only_if_consistent=True.
                # Note that prev_label is incorrect for step 0 but is overridden later
                prev_label = TileForBeamAndFlatten(
                    tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
                is_step0 = tf.equal(time_step, 0)
                local_consistence = tf.logical_or(
                    is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
                consistent = tf.logical_and(states.consistent,
                                            local_consistence)

                # get label, weight slices corresponding to current time_step
                label = TileForBeamAndFlatten(
                    tf.gather(labels, time_step, axis=1))
                weight = TileForBeamAndFlatten(
                    tf.gather(weights, time_step, axis=1))
                if p.bias_only_if_consistent:
                    weight = weight * tf.cast(consistent, p.dtype)

                # convert from dense label to sparse label probs
                vocab_size = tf.shape(bs_results.log_probs)[1]
                # avoid 0 probs which may cause issues with log
                uncertainty = tf.constant(1e-10, p.dtype)
                label_probs = tf.one_hot(
                    label,
                    vocab_size,
                    on_value=1 - uncertainty,
                    off_value=uncertainty / tf.cast(vocab_size - 1, p.dtype),
                    dtype=p.dtype)  # [tgt_batch, vocab_size]
                pred_probs = tf.exp(bs_results.log_probs)

                # interpolate predicted probs and label probs
                weight = tf.expand_dims(weight, 1)
                probs = py_utils.with_dependencies([
                    py_utils.assert_less_equal(weight, 1.),
                    py_utils.assert_greater_equal(weight, 0.)
                ], (1.0 - weight) * pred_probs + weight * label_probs)
                return tf.log(probs), consistent
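
The biasing math in isolation: a dense label is smeared into near-one-hot
probabilities, then interpolated with the model's predicted probabilities by a
per-example weight (weight 0 keeps the model, weight 1 forces the label):

import tensorflow as tf

vocab_size = 4
label = tf.constant([2])
pred_probs = tf.constant([[0.1, 0.2, 0.3, 0.4]])
weight = tf.constant([[0.5]])
uncertainty = 1e-10  # avoid exact-zero probs before the log
label_probs = tf.one_hot(
    label, vocab_size,
    on_value=1.0 - uncertainty,
    off_value=uncertainty / (vocab_size - 1))
probs = (1.0 - weight) * pred_probs + weight * label_probs
log_probs = tf.math.log(probs)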
Example #14
    def _MaybeExpandPaddings(self, inputs, paddings):
        # rank difference is at most one.
        rank_diff = tf.rank(inputs) - tf.rank(paddings)
        paddings = py_utils.with_dependencies([
            py_utils.assert_less_equal(rank_diff, 1),
            py_utils.assert_greater_equal(rank_diff, 0)
        ], paddings)

        # Pads [1] to the end of paddings.
        paddings = tf.reshape(
            paddings,
            tf.concat([tf.shape(paddings),
                       tf.tile([1], [rank_diff])], axis=0))
        return paddings
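
What the reshape accomplishes, as a toy: append size-1 dims to paddings until
its rank matches inputs (by the asserts above, rank_diff is 0 or 1):

import tensorflow as tf

inputs = tf.zeros([2, 3, 4])   # rank 3
paddings = tf.zeros([2, 3])    # rank 2
rank_diff = tf.rank(inputs) - tf.rank(paddings)
new_shape = tf.concat([tf.shape(paddings), tf.tile([1], [rank_diff])], axis=0)
paddings = tf.reshape(paddings, new_shape)  # shape (2, 3, 1)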
Example #15
    def FProp(self, theta, input_batch):
        p = self.params
        with tf.name_scope(p.name):
            inputs = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            [-1, -1]),
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings))
            ], tf.transpose(input_batch.ids))
            paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2)
            if p.packed_input:
                src_segment_id = tf.expand_dims(
                    tf.transpose(input_batch.segment_ids), 2)
            else:
                src_segment_id = None
            xs = self._ComputeInputs(theta, inputs, input_batch)
            summary_utils.histogram('input_emb', xs)
            ps = paddings
            # Now the rnn layers.
            outputs_list = []
            for i in range(0, p.num_lstm_layers):
                layer = self.rnn[i]
                ys = layer.FProp(theta.rnn[i],
                                 xs,
                                 ps,
                                 segment_id=src_segment_id)
                ys = self.dropout.FProp(theta.dropout, ys)
                if i >= p.residual_start:
                    xs += ys  # Residual skip
                    xs = self.ApplyClipping(theta, xs)
                else:
                    xs = ys
                outputs_list.append(xs)
                summary_utils.histogram('layer_out_%s' % i, xs)

            if p.is_transparent:
                xs = self.transparent_merger.FProp(theta.transparent_merger,
                                                   outputs_list)

            if p.lstm_cell_size * 2 != p.encoder_out_dim:
                # Project to the right depth.
                xs = self.final_proj.FProp(theta.final_proj, xs, ps)
                summary_utils.histogram('final_proj_out', xs)

            if src_segment_id is not None:
                src_segment_id = tf.squeeze(src_segment_id, [2])

            return py_utils.NestedMap(encoded=xs,
                                      padding=tf.squeeze(ps, [2]),
                                      segment_id=src_segment_id)
Example #16
    def IdsToStrings(self, ids, lens):
        """Takes int32 token ids and returns approximate detokenized strings."""
        ids = py_utils.with_dependencies(
            [py_utils.assert_same_dim0([ids, lens])], ids)

        def _ProcessRow(inputs):
            length = inputs[1]
            ids = tf.reshape(inputs[0][:length], [1, -1])
            tokens = self._tokenizer.detokenize(ids)
            return tf.strings.reduce_join(tokens.flat_values, separator=' ')

        return tf.map_fn(_ProcessRow, (ids, lens),
                         dtype=tf.string,
                         parallel_iterations=30,
                         back_prop=False)
Example #17
    def StreamStep(self, theta, inputs, paddings, state0):
        """Apply a singele step of convolution to input_tensor.

    Only supports 1d causal convolution. Doesn't support dilation.

    Args:
      theta: A NestedMap of layer params.
      inputs: A Tensor of shape [b, t=1, 1, c]
      paddings: A 0/1 valued tensor of shape [b, t=1].
      state0: A NestedMap of tensors of the same struct as returned by
        zero_state().

    Returns:
      outputs: A Tensor of shape [b, t=1, 1, c * channel_multiplier]
      padding: the same as input paddings.
      state1: A NestedMap of the same struct as input state
    """
        p = self.params
        assert p.filter_shape[1] == 1, (
            'StreamStep only supports 1d causal convolution.')
        assert p.filter_stride[0] == 1, 'StreamStep doesn\'t support striding'
        assert p.dilation_rate == (1, 1), 'StreamStep doesn\'t support dilation'

        with tf.name_scope(p.name):
            inputs = py_utils.with_dependencies([
                py_utils.assert_shape_match(py_utils.GetShape(inputs),
                                            [-1, 1, 1, p.filter_shape[2]])
            ], inputs)
            b = py_utils.GetShape(inputs)[0]

            # next state.
            state1 = py_utils.NestedMap(context=tf.concat(
                [state0.context[:, 1:, :, :], inputs], axis=1))

            expanded_paddings = tf.reshape(paddings, [b, 1, 1, 1])
            # Not updating the states for padded examples.
            state1.context = (state0.context * expanded_paddings +
                              state1.context * (1. - expanded_paddings))

            outputs = tf.nn.depthwise_conv2d(state1.context,
                                             self._GetWeight(self.theta),
                                             strides=(1, 1, 1, 1),
                                             dilations=(1, 1),
                                             data_format='NHWC',
                                             padding='VALID')
            return outputs, paddings, state1
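
The streaming state update in isolation: the context acts as a fixed-width
sliding window advanced one frame per step (oldest frame dropped, newest
appended):

import tensorflow as tf

context = tf.zeros([2, 3, 1, 4])   # state: [b, filter_width, 1, c]
new_frame = tf.ones([2, 1, 1, 4])  # current input: [b, t=1, 1, c]
context = tf.concat([context[:, 1:], new_frame], axis=1)  # still [2, 3, 1, 4]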
Example #18
                def ApplyBias():
                    """Bias and update log_probs and consistent."""

                    # Consistent if step_ids == labels from previous step
                    # TODO(navari): Consider updating consistent only if weights > 0. Then
                    # re-evaluate the need for bias_only_if_consistent=True.
                    # Note that prev_label is incorrect for step 0 but is
                    # overridden later
                    prev_label = TileForBeamAndFlatten(
                        tf.gather(labels, tf.maximum(time_step - 1, 0),
                                  axis=1))
                    is_step0 = tf.equal(time_step, 0)
                    local_consistence = tf.math.logical_or(
                        is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
                    consistent = tf.math.logical_and(states.consistent,
                                                     local_consistence)

                    # get label, weight slices corresponding to current time_step
                    label = TileForBeamAndFlatten(
                        tf.gather(labels, time_step, axis=1))
                    weight = TileForBeamAndFlatten(
                        tf.gather(weights, time_step, axis=1))
                    if p.bias_only_if_consistent:
                        weight = weight * tf.cast(consistent,
                                                  py_utils.FPropDtype(p))

                    # convert from dense label to sparse label probs
                    vocab_size = tf.shape(bs_results.log_probs)[1]
                    # [tgt_batch, vocab_size]
                    label_probs = tf.one_hot(label, vocab_size,
                                             dtype=py_utils.FPropDtype(p))
                    pred_probs = tf.exp(bs_results.log_probs)

                    # interpolate predicted probs and label probs
                    weight = tf.expand_dims(weight, 1)
                    probs = py_utils.with_dependencies([
                        py_utils.assert_less_equal(weight, 1.),
                        py_utils.assert_greater_equal(weight, 0.)
                    ], (1.0 - weight) * pred_probs + weight * label_probs)
                    # Ensure that tf.math.log is applied to positive values.
                    probs = tf.maximum(probs,
                                       tf.constant(1e-12, dtype=probs.dtype))
                    return tf.math.log(probs), consistent
Example #19
  def _check_paddings(self, paddings):
    with tf.name_scope('check_paddings'):
      unpacked_paddings = tf.unstack(paddings)

      non_decr = []
      for t in unpacked_paddings:
        non_d = tf.is_non_decreasing(t)
        non_decr.append(non_d)
      all_non_decr = tf.stack(non_decr)

      paddings = py_utils.with_dependencies([
          tf.assert_equal(
              tf.reduce_any(tf.equal(paddings, 0.0)),
              True,
              message='must have at least one zero value.'),
          tf.assert_equal(all_non_decr, True, message='must be non-decreasing')
      ], paddings)
      return paddings
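
An eager sketch of the same checks: each row of the padding matrix must be 0s
followed by 1s (non-decreasing) and must contain at least one 0:

import tensorflow as tf

paddings = tf.constant([[0., 0., 1.],
                        [0., 1., 1.]])
tf.debugging.assert_equal(
    tf.reduce_any(tf.equal(paddings, 0.0)), True,
    message='must have at least one zero value.')
# tf.math.is_non_decreasing flattens its input, so check row by row.
for row in tf.unstack(paddings):
  tf.debugging.assert_equal(tf.math.is_non_decreasing(row), True,
                            message='must be non-decreasing')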
Example #20
    def _InferenceSubgraph_Default(self):
        with tf.name_scope('inference'):
            src_strings = tf.placeholder(tf.string, shape=[None])
            _, src_ids, src_paddings = self.input_generator.StringsToIds(
                src_strings, is_source=True)

            # Truncate paddings at the end.
            max_seq_length = tf.to_int32(
                tf.reduce_max(tf.reduce_sum(1.0 - src_paddings, 1)))
            src_paddings = py_utils.with_dependencies([
                py_utils.assert_equal(
                    tf.constant(True, tf.bool),
                    tf.reduce_all(src_paddings[:, max_seq_length:] > 0.5))
            ], src_paddings)
            src_ids = src_ids[:, :max_seq_length]
            src_paddings = src_paddings[:, :max_seq_length]

            src_input_map = py_utils.NestedMap(ids=src_ids,
                                               paddings=src_paddings)
            encoder_outputs = self.enc.FPropDefaultTheta(src_input_map)
            decoder_outs = self.dec.BeamSearchDecode(encoder_outputs)

            topk_hyps = decoder_outs.topk_hyps
            topk_ids = decoder_outs.topk_ids
            topk_lens = decoder_outs.topk_lens

            topk_decoded = self.input_generator.IdsToStrings(
                topk_ids, topk_lens - 1)
            topk_decoded = tf.reshape(topk_decoded, tf.shape(topk_hyps))

            feeds = py_utils.NestedMap({'src_strings': src_strings})
            fetches = py_utils.NestedMap({
                'src_ids': src_ids,
                'topk_decoded': topk_decoded,
                'topk_scores': decoder_outs.topk_scores,
                'topk_hyps': topk_hyps,
            })

            return fetches, feeds
Example #21
    def FProp(self, theta, input_batch):
        """Embeds source ids and transforms with TransformerStack.

    Args:
      theta: A `.NestedMap` object containing weights' values of this
        layer and its children layers.
      input_batch: A `.NestedMap` with fields:

        - ids: The inputs tensor. It is expected to be of shape [batch, time].
        - paddings: The paddings tensor. Expected shape [batch, time].

    Returns:
      A NestedMap containing:
        - encoded: The encoded features, either a tensor of shape [time, batch,
            depth], or a list of tensors if is_transparent is set in
            transformer_stack.
        - padding: of shape [time, batch]
        - segment_id: [time, batch] if packed inputs are supported by the model
            (and all layers), or None otherwise.
        - embedded_inputs: [time, batch, depth] embedded inputs tokens without
            positional encodings.
    """

        p = self.params
        with tf.name_scope(p.name):
            src_segment_id = None
            src_segment_pos = None
            input_ids = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings)),
                py_utils.assert_equal(tf.rank(input_batch.ids), 2)
            ], input_batch.ids)

            if (not py_utils.use_tpu()
                    and tf.flags.FLAGS.transformer_encoder_truncates_inputs):
                max_seq_length = tf.cast(
                    tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings,
                                                1)), tf.int32)
                paddings = py_utils.with_dependencies([
                    py_utils.assert_equal(
                        tf.constant(True, tf.bool),
                        tf.reduce_all(
                            input_batch.paddings[:, max_seq_length:] > 0.5))
                ], input_batch.paddings)
                input_ids = input_ids[:, :max_seq_length]
                paddings = paddings[:, :max_seq_length]
                if p.packed_input:
                    src_segment_id = input_batch.segment_ids[:, :max_seq_length]
                    src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
            else:
                paddings = input_batch.paddings
                if p.packed_input:
                    src_segment_id = input_batch.segment_ids
                    src_segment_pos = input_batch.segment_pos

            max_time = tf.shape(input_ids)[1]

            # Input token embeddings + positional embeddings
            input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                                  tf.reshape(input_ids, [-1]))
            input_embs = tf.reshape(input_embs,
                                    [-1, max_time, p.token_emb.embedding_dim])
            # [time, batch, dim]
            orig_input_embs = tf.transpose(input_embs, [1, 0, 2])

            if p.packed_input:
                position_embs = self.position_emb.FPropWithPosition(
                    theta.position_emb, src_segment_pos)
            else:
                position_embs = self.position_emb.FProp(
                    theta.position_emb, max_time)
                position_embs = tf.reshape(
                    position_embs, [1, max_time, p.token_emb.embedding_dim])
            input_embs += position_embs

            if p.model_dim != p.token_emb.embedding_dim:
                input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

            paddings = tf.transpose(paddings)
            if p.packed_input:
                src_segment_id = tf.transpose(src_segment_id)
            input_embs = self.input_dropout.FProp(theta.input_dropout,
                                                  input_embs)

            # [time, batch, dim]
            transformer_input = tf.transpose(input_embs, [1, 0, 2])

        encoded, padding, segment_id = self.transformer_stack.FProp(
            theta.transformer_stack, transformer_input, paddings,
            src_segment_id)
        return py_utils.NestedMap(encoded=encoded,
                                  padding=padding,
                                  segment_id=segment_id,
                                  embedded_inputs=orig_input_embs)
Example #22
def MergeBeamSearchOutputs(max_hyps_per_beam, beam_search_outputs):
  """Merges beam search hyps from multiple decoders.

  Args:
    max_hyps_per_beam: the number of top hyps in the merged results. Must be
      less than or equal to total number of input hyps.
    beam_search_outputs: a list of BeamSearchDecodeOutput objects. Must share
      the same source_batch and max sequence length.

  Returns:
    A BeamSearchDecodeOutput object containing max_hyps_per_beam hypotheses per
    beam.
  """
  source_batch = tf.shape(beam_search_outputs[0].topk_hyps)[0]
  value_dict = {}
  for output in beam_search_outputs:
    hyps_per_beam = py_utils.with_dependencies([
        py_utils.assert_equal(source_batch,
                              tf.shape(output.topk_hyps)[0]),
    ], tf.shape(output.topk_hyps)[1])
    for k, v in six.iteritems(output._asdict()):
      if v is None:
        continue
      if k == 'done_hyps':
        v = tf.transpose(v)
      if k not in value_dict:
        value_dict[k] = []
      value_dict[k].append(tf.reshape(v, [source_batch, hyps_per_beam, -1]))

  # Concatenate the tensors along the 'num_hyps_per_beam' dimension.
  concatenated = {}
  for k, values in six.iteritems(value_dict):
    if len(values) != len(beam_search_outputs):
      raise ValueError('Incomplete values for %s: %s' %
                       (k, beam_search_outputs))
    concatenated[k] = tf.concat(values, axis=1)

  scores = concatenated['topk_scores']
  scores = tf.where(
      tf.equal(concatenated['topk_lens'], 0), tf.fill(tf.shape(scores), -1e6),
      scores)
  scores = tf.squeeze(scores, -1)

  # Select top max_hyps_per_beam indices per beam.
  _, top_indices = tf.nn.top_k(scores, max_hyps_per_beam)
  batch_ids = tf.tile(
      tf.expand_dims(tf.range(source_batch), -1), [1, max_hyps_per_beam])
  # [source_batch, max_hyps_per_beam, 2]
  gather_indices = tf.stack([batch_ids, top_indices], axis=-1)

  # Gather the merged top hyps according to 'gather_indices'.
  top = beam_search_outputs[0]._asdict()
  total_hyps = source_batch * max_hyps_per_beam
  for k, v in six.iteritems(concatenated):
    v = tf.gather_nd(v, gather_indices)
    if k == 'done_hyps':
      v = tf.transpose(tf.reshape(v, [total_hyps, -1]))
    elif k == 'topk_hyps':
      v = tf.reshape(v, [source_batch, max_hyps_per_beam])
    elif k == 'topk_ids':
      v = tf.reshape(v, [total_hyps, -1])
    elif k in ('topk_lens', 'topk_scores', 'topk_decoded'):
      v = tf.reshape(v, [total_hyps])
    else:
      raise ValueError('Unexpected field: %s' % k)
    top[k] = v
  return BeamSearchDecodeOutput(**top)
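
The merging core in isolation: concatenate per-decoder scores along the hyps
axis, take the per-beam top-k, and gather the winners with tf.gather_nd:

import tensorflow as tf

scores_a = tf.constant([[0.9, 0.1], [0.4, 0.3]])  # [source_batch=2, hyps=2]
scores_b = tf.constant([[0.5, 0.8], [0.7, 0.2]])
scores = tf.concat([scores_a, scores_b], axis=1)  # [2, 4]
max_hyps_per_beam = 2
_, top_indices = tf.nn.top_k(scores, max_hyps_per_beam)
batch_ids = tf.tile(tf.expand_dims(tf.range(2), -1), [1, max_hyps_per_beam])
gather_indices = tf.stack([batch_ids, top_indices], axis=-1)  # [2, 2, 2]
print(tf.gather_nd(scores, gather_indices).numpy())  # [[0.9 0.8] [0.7 0.4]]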
Example #23
def FarthestPointSampler(points,
                         padding,
                         num_sampled_points,
                         precomputed_squared_distance=None,
                         num_seeded_points=0,
                         random_seed=None):
    """Samples num_sampled_points from points using farthest point sampling.

  Algorithm:
  1. Start by selecting a random point and adding it to the selected set.
  2. For all remaining points, find the furthest point from those selected.
  3. Add the furthest point to the selected set.
  4. Repeat 2-3 until num_sampled_points are selected.

  More details at https://en.wikipedia.org/wiki/Farthest-first_traversal

  The output of this function can be used with tf.batch_gather to extract the
  desired points, for example: tf.batch_gather(points, sampled_idx)

  Args:
    points: floating point tf.Tensor of shape [N, P1, dims]
    padding: A floating point tf.Tensor of shape [N, P1] with 0 if the point is
      real, and 1 otherwise.
    num_sampled_points: integer number of points to sample.
    precomputed_squared_distance: optional tf.Tensor of shape [N, P1, P1] of
      squared distances between each pair of points. If None, distances will
      be computed on the fly.
    num_seeded_points: If num_seeded_points > 0, then the first
      num_seeded_points in points are considered to be seeded in the FPS
      sampling. Note that we assume that these points are *not* padded, and do
      not check padding when seeding them.
    random_seed: optional integer random seed to use with all the random ops.

  Returns:
    A tuple of tf.Tensors (sampled_idx, closest_idx) of types
    (tf.int32, tf.int32).

    sampled_idx is of shape [N, num_sampled_points] representing the indices
    selected using the sampler. Values are in the range [0, P1).

    closest_idx is of shape [N, P1] representing the indices of the closest
    sampled points for each input point. closest_idx is used in PCNN as part of
    the pooling operation: each point is assigned to the closest sampled point
    and a max is taken over them. This will have a range of [0, P2] with the
    index of the closest sampled point that remains.
  """
    points = py_utils.HasRank(points, 3)
    batch_size, num_points, dims = py_utils.GetShape(points, 3)

    points = py_utils.with_dependencies(
        [py_utils.assert_greater_equal(num_points, num_sampled_points)],
        points)

    # Add a tiny bit of noise to the distance matrix or points so all
    # points are unique. This will also ensure true repeated points
    # like padded points are only selected after all valid points are selected.
    if precomputed_squared_distance is not None:
        precomputed_squared_distance = py_utils.HasShape(
            precomputed_squared_distance, [batch_size, num_points, num_points])
        precomputed_squared_distance += tf.random.uniform(
            (batch_size, num_points, 1),
            minval=1e-6,
            maxval=1e-5,
            dtype=tf.float32,
            seed=random_seed)
    else:
        points += tf.random.uniform((batch_size, num_points, dims),
                                    minval=1e-6,
                                    maxval=1e-5,
                                    dtype=tf.float32,
                                    seed=random_seed)

    # TensorArray to store the sampled indices in the loop.
    sampled_idx = tf.TensorArray(tf.int32, num_sampled_points)

    # Initialize distance_to_selected to inf for all points.
    distance_to_selected = float('inf') * tf.ones((batch_size, num_points))

    # For tracking the index to the closest selected point.
    closest_idx = tf.zeros((batch_size, num_points), dtype=tf.int32)

    # Current loop index counter.
    curr_idx = tf.constant(0, dtype=tf.int32)

    # Get number of valid points (1 is padded, so num_points - num_padded).
    num_valid_points = tf.cast(tf.cast(num_points, dtype=tf.float32) -
                               tf.reduce_sum(padding, axis=1),
                               dtype=tf.int32)

    def _BodyFn(curr_idx, distance_to_selected, sampled_idx, closest_idx):
        """Loop body for farthest point sampler."""
        def _GetRandomRealPoint():
            """Select the first point.

      For the first point, we want any random real (non-padded) point, so we
      create a random value per point, and then set all padded ones to
      some large value (more than the maxval). We then take the min per batch
      element to get the first points.

      Returns:
        Tensor containing the index of a random point selected for each example
        in the batch.
      """
            random_values = tf.random.uniform((batch_size, num_points),
                                              minval=0,
                                              maxval=1,
                                              dtype=tf.float32,
                                              seed=random_seed)
            random_values = tf.where(tf.equal(padding, 0.0), random_values,
                                     padding * 10)
            return tf.argmin(random_values, axis=1, output_type=tf.int32)

        def _GetFurthestPoint():
            """Get point that is furthest from those already selected.

      We also bias the sampling towards real points by setting the distance
      to padded points negative until we are out of real points.

      Returns:
        Tensor containing the index of the next farthest point selected for each
        example in the batch.
      """
            # Set padded points distance to negative so they aren't selected.
            padding_masked_distance_to_selected = tf.where(
                tf.equal(padding, 0.0), distance_to_selected, -1.0 * tf.ones(
                    (batch_size, num_points), dtype=tf.float32))
            # But only do this when we still have valid points left.
            padding_masked_distance_to_selected = tf.where(
                tf.less(curr_idx, num_valid_points),
                padding_masked_distance_to_selected, distance_to_selected)
            return tf.argmax(padding_masked_distance_to_selected,
                             axis=-1,
                             output_type=tf.int32)

        def _GetSeededPoint():
            """Select a seeded point.

      Seeded points are assumed to be at the beginning of the original points.

      Returns:
        Tensor containing the index of the next seeded point to select for each
        example in the batch.
      """
            return tf.ones((batch_size, ), dtype=tf.int32) * curr_idx

        # Select indices for this loop iteration.
        def _Seeded():
            return tf.cond(tf.less(curr_idx, num_seeded_points),
                           _GetSeededPoint, _GetFurthestPoint)

        def _Real():
            return tf.cond(tf.equal(curr_idx, 0), _GetRandomRealPoint,
                           _GetFurthestPoint)

        new_selected = tf.cond(tf.greater(num_seeded_points, 0), _Seeded,
                               _Real)
        sampled_idx = sampled_idx.write(curr_idx, new_selected)

        # Extract the distance to the latest point selected to update
        # distance_to_selected.
        new_selected_gather_idx = tf.stack(
            [tf.range(batch_size), new_selected], axis=1)
        if precomputed_squared_distance is not None:
            new_distance = tf.gather_nd(precomputed_squared_distance,
                                        new_selected_gather_idx)
        else:
            new_points = tf.reshape(
                tf.gather_nd(points, new_selected_gather_idx),
                [batch_size, 1, dims])
            new_distance = tf.reshape(
                SquaredDistanceMatrix(points, new_points),
                [batch_size, num_points])

        is_newly_closest = tf.less(new_distance, distance_to_selected)
        distance_to_selected = tf.minimum(distance_to_selected, new_distance)

        # Track the index to the closest selected point.
        new_selected_tiled = tf.tile([[curr_idx]], [batch_size, num_points])
        closest_idx = tf.cond(
            tf.equal(curr_idx, 0),
            # At the first loop iteration, the init points are the closest.
            lambda: new_selected_tiled,
            # Otherwise, update with the new points based on the distances.
            lambda: tf.where(is_newly_closest, new_selected_tiled, closest_idx)
        )
        return curr_idx + 1, distance_to_selected, sampled_idx, closest_idx

    _, _, sampled_idx, closest_idx = tf.while_loop(
        lambda curr_idx, *args: tf.less(curr_idx, num_sampled_points),
        _BodyFn,
        loop_vars=(curr_idx, distance_to_selected, sampled_idx, closest_idx),
        back_prop=False,
        maximum_iterations=num_sampled_points)

    sampled_idx = sampled_idx.stack()  # [num_sampled_points, batch_size]
    sampled_idx = tf.transpose(sampled_idx, [1, 0])

    if isinstance(batch_size, int) and isinstance(num_sampled_points, int):
        sampled_idx.set_shape((batch_size, num_sampled_points))

    return sampled_idx, closest_idx
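
A hypothetical usage sketch, assuming FarthestPointSampler above is in scope.
tf.batch_gather from the docstring is the TF1 name; tf.gather with
batch_dims=1 is the TF2 equivalent:

import tensorflow as tf

points = tf.random.uniform([8, 128, 3])  # [N, P1, dims]
padding = tf.zeros([8, 128])             # all points are real
sampled_idx, closest_idx = FarthestPointSampler(points, padding,
                                                num_sampled_points=16)
sampled_points = tf.gather(points, sampled_idx, batch_dims=1)  # [8, 16, 3]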
Example #24
    def ComputeAndUpdateMoments(self, theta, inputs, paddings=None):
        """Computes moments and updates state.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor.  Shaped [..., dim].
      paddings: The paddings tensor.  Shaped [..., 1], with the same rank as the
        input tensor.

    Returns:
      Tuple of (mean, variance, beta, gamma).
    """
        p = self.params
        if paddings is None:
            paddings = self._GetDefaultPaddings(inputs)
        inputs = py_utils.with_dependencies([
            py_utils.assert_shape_match([tf.shape(paddings)[-1]], [1]),
        ], inputs)
        with tf.name_scope(p.name):
            if self.do_eval:
                # The mean and variance used for normalization.
                norm_mean, norm_variance = self._moving_mean, self._moving_variance
            else:
                mean, variance = self._Moments(
                    inputs, 1.0 - paddings, p.enable_cross_replica_sum_on_tpu)

                py_utils.UpdateBatchNormVars(self._moving_mean, mean,
                                             self._decay)
                py_utils.UpdateBatchNormVars(self._moving_variance, variance,
                                             self._decay)
                # Add some summaries for visualization.
                summary_utils.histogram('%s_mean' % p.name,
                                        tf.cast(mean, tf.float32))
                summary_utils.histogram('%s_variance' % p.name,
                                        tf.cast(variance, tf.float32))
                summary_utils.histogram('%s_moving_mean' % p.name,
                                        tf.cast(self._moving_mean, tf.float32))
                summary_utils.histogram(
                    '%s_moving_variance' % p.name,
                    tf.cast(self._moving_variance, tf.float32))
                summary_utils.histogram(
                    '%s_mean_diff' % p.name,
                    tf.cast(mean - self._moving_mean, tf.float32))
                summary_utils.histogram(
                    '%s_variance_diff' % p.name,
                    tf.cast(variance - self._moving_variance, tf.float32))
                if p.use_moving_avg_in_training:
                    # Use the global statistics for normalization.
                    # Control dependencies on mean and variance make sure
                    # moving_mean and variance will be updated for every training step.
                    norm_mean = py_utils.with_dependencies([mean],
                                                           self._moving_mean)
                    norm_variance = py_utils.with_dependencies(
                        [variance], self._moving_variance)
                else:
                    # Use the batch statistics for normalization.
                    norm_mean = mean
                    norm_variance = variance

            norm_mean = py_utils.CheckNumerics(
                norm_mean, 'mean of %s failed numeric check' % p.name)
            norm_variance = py_utils.CheckNumerics(
                norm_variance, 'variance of %s failed numeric check' % p.name)

            if p.use_moving_avg_in_training:
                beta = 0.0
                gamma = 1.0
            else:
                beta = theta.beta
                gamma = theta.gamma
            return norm_mean, norm_variance, beta, gamma
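
UpdateBatchNormVars itself is not shown in this excerpt; assuming the
conventional exponential-moving-average update, the arithmetic it performs
would be:

import tensorflow as tf

decay = 0.997
moving_mean = tf.Variable(0.0)
batch_mean = tf.constant(1.0)
# moving <- moving - (1 - decay) * (moving - batch_stat)
moving_mean.assign_sub((1.0 - decay) * (moving_mean - batch_mean))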
Example #25
    def FProp(self,
              theta,
              query_vec,
              source_paddings,
              source_vecs=None,
              query_segment_id=None,
              source_segment_id=None):
        """Transformer attention, residual and normalization layer.

    Args:
      theta: A `.NestedMap` object containing weights' values of this
        layer and its children layers.
      query_vec: [target_time, target_batch, dim]
      source_paddings: [source_time, source_batch]
      source_vecs: [source_time, source_batch, dim].
      query_segment_id: [target_time, target_batch]
      source_segment_id: [source_time, source_batch]
    Returns:
      (output, atten_probs). output is of shape [target_time, target_batch,
      source_dim], atten_probs is of shape [target_time, target_batch,
      source_time].
    """
        p = self.params
        unnormalized_query_vec = query_vec
        query_vec = self.layer_norm.FProp(theta.layer_norm, query_vec)

        if source_vecs is None:
            source_vecs = query_vec
            source_segment_id = query_segment_id

        if p.is_masked:
            assert source_vecs is not None
            query_vec = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(source_vecs),
                                            tf.shape(query_vec))
            ], query_vec)
            # Prepares mask for self-attention
            # [time, time]
            target_time = tf.shape(query_vec)[0]
            target_bs = tf.shape(query_vec)[1]
            triangle_padding = 1.0 - tf.matrix_band_part(
                tf.ones([target_time, target_time],
                        dtype=py_utils.FPropDtype(p)), -1, 0)
            # [time,  batch, time]
            causal_padding = tf.tile(tf.expand_dims(triangle_padding, 1),
                                     [1, target_bs, 1])

            causal_padding = tf.reshape(causal_padding, [-1, target_time])
        else:
            causal_padding = None

        query_dim = tf.shape(query_vec)[-1]
        packed_src = self.atten.PackSource(theta.atten, source_vecs,
                                           source_vecs, source_paddings,
                                           source_segment_id)

        if query_segment_id is not None:
            query_segment_id = tf.reshape(query_segment_id, [-1])
        ctx_vec, atten_prob, _ = self.atten.ComputeContextVectorWithSource(
            theta.atten,
            packed_src,
            tf.reshape(query_vec, [-1, query_dim]),
            per_step_source_padding=causal_padding,
            query_segment_id=query_segment_id)
        ctx_vec = self.residual_dropout.FProp(theta.residual_dropout, ctx_vec)
        input_to_add = (unnormalized_query_vec
                        if p.add_unnormalized_input else query_vec)
        h = input_to_add + tf.reshape(ctx_vec, tf.shape(query_vec))
        atten_prob = tf.reshape(atten_prob, [
            tf.shape(query_vec)[0],
            tf.shape(query_vec)[1],
            tf.shape(source_vecs)[0]
        ])
        return h, atten_prob
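The causal-mask construction above is worth seeing in isolation. Below is a self-contained sketch of the same [time, batch, time] mask, using tf.linalg.band_part (the current name for tf.matrix_band_part); the helper name is ours, not lingvo's.

    import tensorflow as tf

    def make_causal_padding(target_time, target_batch, dtype=tf.float32):
        # band_part(ones, -1, 0) keeps the lower triangle, i.e. the positions
        # each step may attend to; 1.0 - that marks the future steps to mask.
        triangle = 1.0 - tf.linalg.band_part(
            tf.ones([target_time, target_time], dtype=dtype), -1, 0)
        # [time, batch, time]: the same mask is shared across the batch, then
        # [time, batch] is flattened to match the reshaped query vectors.
        mask = tf.tile(triangle[:, None, :], [1, target_batch, 1])
        return tf.reshape(mask, [-1, target_time])

    # For target_time=3, each row of `triangle` masks strictly-future steps:
    # [[0., 1., 1.],
    #  [0., 0., 1.],
    #  [0., 0., 0.]]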
Example No. 26
  def FProp(self, theta, x, paddings=None, update=False):
    """Computes distances of the given input 'x' to all centroids.

    This implementation applies layer normalization on 'x' internally first
    (when p.apply_layer_norm is enabled), and the returned 'dists' is computed
    using the normalized 'x'.

    Args:
      theta: A `.NestedMap` of weights' values of this layer.
      x: A tensor of shape [B, L, N, H].
      paddings: If not None, a tensor of shape [B, L].
      update: bool, whether to update centroids using x.

    Returns:
      dists: "distances" of the given input 'x' to all centroids.
             Shape [B, L, N, K].
      k_means_loss: the average squared Euclidean distances to the closest
                    centroid, a scalar.
    """
    p = self.params
    if paddings is None:
      paddings = tf.zeros_like(x[:, :, 0, 0])
    # Shape [B, L, 1, 1]
    paddings_4d = paddings[:, :, None, None]

    if p.apply_layer_norm:
      x = KMeansClusteringForAtten.LayerNorm(x, p.epsilon)

    # Since 'x' is normalized (but theta.means is not), we use the negative dot
    # product to approximate the Euclidean distance here.
    dists = -tf.einsum('BLNH, NKH -> BLNK', x, theta.means)

    # For padded positions we update the distances to very large numbers.
    very_large_dists = tf.ones_like(dists) * tf.constant(
        0.1, dtype=dists.dtype) * dists.dtype.max
    paddings_tiled = tf.tile(paddings_4d, [1, 1, p.num_heads, p.num_clusters])
    dists = tf.where(paddings_tiled > 0.0, very_large_dists, dists)

    # Shape [B, L, N, K], the same as 'dists' above.
    nearest_one_hot = tf.one_hot(
        tf.math.argmin(dists, axis=-1),
        p.num_clusters,
        dtype=py_utils.FPropDtype(p))
    # Same shape as the input 'x'.
    nearest_centroid = tf.einsum('BLNK, NKH -> BLNH', nearest_one_hot,
                                 theta.means)
    diff = tf.math.squared_difference(x, tf.stop_gradient(nearest_centroid))
    diff = py_utils.ApplyPadding(paddings_4d, diff)
    diff = tf.math.reduce_mean(diff, axis=2)

    # The commitment loss which, when backpropagated through, encourages the
    # 'x' values to commit to their chosen centroids.
    k_means_loss = tf.math.reduce_sum(diff) / tf.math.reduce_sum(1.0 - paddings)
    summary_utils.scalar('k_means/squared_distance_loss', k_means_loss)

    # TODO(zhouwk): investigate normalizing theta.means after each update.
    # Per-centroid L2 norms, shape [N, K].
    means_norm = tf.norm(theta.means, axis=-1)
    summary_utils.scalar('k_means/centroid_l2_norm/min',
                         tf.math.reduce_min(means_norm))
    summary_utils.scalar('k_means/centroid_l2_norm/mean',
                         tf.math.reduce_mean(means_norm))

    if not update:
      return dists, k_means_loss

    # To update the centroids (self.vars.means), we apply gradient descent on
    # the mini-batch of input 'x', which yields the following:
    #   new_centroid = centroid + (1 - decay) * (x_mean - centroid)
    # where x_mean is the average over all the input vectors closest to this
    # centroid.
    #
    # Note that this approach is equivalent to backprop via
    #    loss = tf.math.reduce_mean(
    #        tf.math.squared_difference(tf.stop_gradient(x), nearest_centroid))
    # except that here the learning rate is independently set via 'decay'.

    # Ensure that the padded positions are not used to update the centroids.
    nearest_one_hot = py_utils.ApplyPadding(paddings_4d, nearest_one_hot)

    # Sum away batch and sequence length dimensions to get per cluster count.
    # Shape: [N, K]
    per_cluster_count = tf.reduce_sum(nearest_one_hot, axis=[0, 1])
    summary_utils.histogram('k_means/per_cluster_vec_count', per_cluster_count)

    # Sum of the input 'x' per each closest centroid.
    sum_x = tf.einsum('BLNK, BLNH -> NKH', nearest_one_hot, x)

    if py_utils.use_tpu():
      per_cluster_count = tf.tpu.cross_replica_sum(per_cluster_count)
      sum_x = tf.tpu.cross_replica_sum(sum_x)

    # If per_cluster_count for a cluster is 0, then 'nearest_one_hot' in that
    # cluster's position will always be 0, hence 'sum_x' in that dimension will
    # be 0.
    new_means = sum_x / tf.maximum(
        tf.constant(1.0, dtype=per_cluster_count.dtype),
        tf.expand_dims(per_cluster_count, axis=-1))

    # We use an exponential moving average. TODO(zhouwk): investigate smoothing
    # this over an exponentially moving averaged per cluster count.
    #
    # Note that we intentionally do not normalize the means after this update
    # as empirically this works better.
    update_means_diff = tf.cast((1.0 - p.decay) * (new_means - theta.means),
                                self.vars.means.dtype)
    return py_utils.with_dependencies(
        [tf.assign_add(self.vars.means, update_means_diff)],
        dists), k_means_loss
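The comment block above gives the exponential-moving-average update in closed form; here is a minimal numpy sketch of that update (the function name and decay default are illustrative):

    import numpy as np

    def ema_centroid_update(means, sum_x, per_cluster_count, decay=0.999):
        # Mean of the vectors assigned to each centroid. Empty clusters have
        # a zero sum, so (as in the snippet) their centroids decay toward 0.
        x_mean = sum_x / np.maximum(per_cluster_count[..., None], 1.0)
        # Equivalent forms:
        #   means + (1 - decay) * (x_mean - means)
        #   decay * means + (1 - decay) * x_mean
        return means + (1.0 - decay) * (x_mean - means)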
Example No. 27
    def _StreamMoments(self, inputs, paddings, cached_sum, cached_count,
                       cached_var):
        """Computes mean and variance over the valid data points in inputs.

    Args:
      inputs: [B, T, F, N, G] or [B, T, N, G]
      paddings: [B, T, 1, 1, 1] or [B, T, 1, 1]
      cached_sum: [B, 1, 1, N, 1] or [B, 1, N, 1]
      cached_count: same shape as cached_sum.
      cached_var: same shape as cached_sum.

    Returns:
      mean: [B, T, 1, N, 1] or [B, T, N, 1]
      variance: same shape as mean.
      new_cached_sum: same shape as cached_sum.
      new_cached_count: same shape as cached_count.
      new_cached_var: same shape as cached_var.
    """
        tf.logging.vlog(1, 'inputs: %r', inputs)
        tf.logging.vlog(1, 'paddings: %r', paddings)
        tf.logging.vlog(1, 'cached_sum: %r', cached_sum)
        tf.logging.vlog(1, 'cached_count: %r', cached_count)

        inputs = py_utils.ApplyPadding(paddings, inputs, use_select=False)

        input_rank = py_utils.GetRank(inputs)
        assert input_rank is not None, (f'inputs rank must be static for '
                                        f'{repr(inputs)}')
        reduce_over_dims = list(range(input_rank))
        # Skip B, T, and N. Reduce {F,G} or just G.
        reduce_over_dims = reduce_over_dims[2:-2] + reduce_over_dims[-1:]
        tf.logging.vlog(1, 'reduce_over_dims: %s', reduce_over_dims)

        # [B, T, 1, N, 1] or [B, T, N, 1]
        sum_v = tf.reduce_sum(inputs, reduce_over_dims, keepdims=True)
        sum_v = tf.math.cumsum(sum_v, axis=1)
        sum_v += cached_sum

        # [B, T, 1, 1, 1] or [B, T, 1, 1]
        mask = tf.cast(1.0 - paddings, inputs.dtype)
        count_v = tf.reduce_sum(mask, reduce_over_dims, keepdims=True)
        count_v = tf.math.cumsum(count_v, axis=1)
        input_shape = py_utils.GetShape(inputs)
        if input_rank == 4:
            # F * G
            multiplier = input_shape[-1] * input_shape[-3]
        else:
            # G
            multiplier = input_shape[-1]
        count_v *= multiplier
        count_v += cached_count

        tf.logging.vlog(1, 'sum_v: %r', sum_v)
        tf.logging.vlog(1, 'count_v: %r', count_v)

        mean = sum_v / tf.maximum(count_v, 1.0)

        sum_vv = tf.reduce_sum(py_utils.ApplyPadding(
            paddings,
            tf.math.squared_difference(inputs, mean),
            use_select=False),
                               reduce_over_dims,
                               keepdims=True)
        sum_vv = tf.math.cumsum(sum_vv, axis=1)
        sum_vv += cached_var

        cached_sum = sum_v[:, -1:]
        cached_count = count_v[:, -1:]
        cached_var = sum_vv[:, -1:]

        variance = py_utils.with_dependencies([
            py_utils.assert_greater_equal(sum_vv, tf.cast(0, sum_vv.dtype)),
        ], sum_vv / tf.maximum(count_v, 1.0))
        return mean, variance, cached_sum, cached_count, cached_var
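A small numpy sketch of the cumulative-sum trick above: at every step t the running sums yield the mean over all valid positions up to and including t. (The variance term accumulates squared deviations from each step's own running mean, which is the streaming approximation this layer uses.) The values below are illustrative.

    import numpy as np

    x = np.array([[1.0, 2.0, 3.0, 9.9]])         # [B=1, T=4]
    paddings = np.array([[0.0, 0.0, 0.0, 1.0]])  # last step is padded
    mask = 1.0 - paddings

    sum_v = np.cumsum(x * mask, axis=1)                 # running valid sum
    count_v = np.maximum(np.cumsum(mask, axis=1), 1.0)  # running valid count
    mean = sum_v / count_v
    # mean == [[1.0, 1.5, 2.0, 2.0]]; the padded step leaves it unchanged.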
Example No. 28
    def _StreamMoments(self, inputs, paddings, cached_sum, cached_count,
                       cached_var):
        """Computes mean and variance over the valid data points in inputs.

    Args:
      inputs: [B, T, F, N, G] or [B, T, N, G]
      paddings: [B, T, 1, 1, 1] or [B, T, 1, 1]
      cached_sum: [B, 1, 1, N, 1] or [B, 1, N, 1]
      cached_count: same shape as cached_sum.
      cached_var: same shape as cached_sum.

    Returns:
      mean: [B, T, 1, N, 1] or [B, T, N, 1]
      variance: same shape as mean.
      new_cached_sum: same shape as cached_sum.
      new_cached_count: same shape as cached_count.
      new_cached_var: same shape as cached_var.
    """
        tf.logging.vlog(1, 'inputs: %r', inputs)
        tf.logging.vlog(1, 'paddings: %r', paddings)
        tf.logging.vlog(1, 'cached_sum: %r', cached_sum)
        tf.logging.vlog(1, 'cached_count: %r', cached_count)

        mask = tf.cast(1.0 - paddings, inputs.dtype)
        inputs *= mask

        input_rank = py_utils.GetRank(inputs)
        assert input_rank is not None, (f'inputs rank must be static for '
                                        f'{repr(inputs)}')
        reduce_over_dims = list(range(input_rank))
        # Skip B, T, and N. Reduce {F,G} or just G.
        reduce_over_dims = reduce_over_dims[2:-2] + reduce_over_dims[-1:]
        tf.logging.vlog(1, 'reduce_over_dims: %s', reduce_over_dims)

        # [B, T, 1, N, 1] or [B, T, N, 1]
        sum_v = tf.reduce_sum(inputs, reduce_over_dims, keepdims=True)
        sum_v = tf.math.cumsum(sum_v, axis=1)
        sum_v += cached_sum

        # [B, T, 1, 1, 1] or [B, T, 1, 1]
        count_v = tf.reduce_sum(mask, reduce_over_dims, keepdims=True)
        count_v = tf.math.cumsum(count_v, axis=1)
        input_shape = py_utils.GetShape(inputs)
        if input_rank == 4:
            # F * G
            multiplier = input_shape[-1] * input_shape[-3]
        else:
            # G
            multiplier = input_shape[-1]
        count_v *= multiplier
        count_v += cached_count
        count_v = tf.maximum(count_v, 1.0)

        tf.logging.vlog(1, 'sum_v: %r', sum_v)
        tf.logging.vlog(1, 'count_v: %r', count_v)

        mean = sum_v / count_v
        if py_utils.FLAGS.tflite_compatible:
            # TfLite doesn't support broadcasting with 5D tensors.
            inputs_shape = py_utils.GetShape(inputs)
            if len(inputs_shape) == 4:
                tiled_mean = tf.tile(mean, [1, 1, 1, inputs_shape[3]])
            else:
                tiled_mean = tf.tile(
                    mean, [1, 1, inputs_shape[2], 1, inputs_shape[4]])
            sum_vv = tf.reduce_sum(tf.math.square(inputs - tiled_mean) * mask,
                                   reduce_over_dims,
                                   keepdims=True)
        else:
            sum_vv = tf.reduce_sum((inputs - mean)**2 * mask,
                                   reduce_over_dims,
                                   keepdims=True)
        sum_vv = tf.math.cumsum(sum_vv, axis=1)
        sum_vv += cached_var

        cached_sum = sum_v[:, -1:]
        cached_count = count_v[:, -1:]
        cached_var = sum_vv[:, -1:]

        variance = py_utils.with_dependencies([
            py_utils.assert_greater_equal(sum_vv, tf.cast(0, sum_vv.dtype)),
        ], sum_vv / count_v)
        return mean, variance, cached_sum, cached_count, cached_var
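The TfLite branch above exists because TFLite cannot broadcast across 5-D tensors; a short sketch (shapes are illustrative) of how the explicit tf.tile reproduces the broadcast that (inputs - mean) would otherwise perform:

    import tensorflow as tf

    inputs = tf.random.normal([2, 3, 4, 5, 6])  # [B, T, F, N, G]
    mean = tf.zeros([2, 3, 1, 5, 1])            # [B, T, 1, N, 1]
    # Tile the singleton F and G axes so the subtraction is elementwise,
    # with no implicit broadcast for TFLite to reject.
    tiled_mean = tf.tile(mean, [1, 1, 4, 1, 6])
    diff = inputs - tiled_mean                  # same values as inputs - mean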
Example No. 29
    def LocalizationResiduals(self, anchor_bboxes, assigned_gt_bboxes):
        """Computes the anchor residuals for every bbox.

    For a given bbox, compute residuals in the following way:

      Let ``anchor_bbox = (x_a, y_a, z_a, dx_a, dy_a, dz_a, phi_a)``
      and ``assigned_gt_bbox = (x_gt, y_gt, z_gt, dx_gt, dy_gt, dz_gt, phi_gt)``

      Define ``diagonal_xy = sqrt(dx_a^2 + dy_a^2)``

      Then the corresponding residuals are given by::

        x_residual = (x_gt - x_a) / (diagonal_xy)
        y_residual = (y_gt - y_a) / (diagonal_xy)
        z_residual = (z_gt - z_a) / (dz_a)

        dx_residual = log(dx_gt / dx_a)
        dy_residual = log(dy_gt / dy_a)
        dz_residual = log(dz_gt / dz_a)

        phi_residual = phi_gt - phi_a

      The normalization for x and y residuals by the diagonal was first
      proposed by [1]. Intuitively, this reflects that objects can usually
      move freely in the x-y plane, including diagonally. On the other hand,
      moving in the z-axis (up and down) can be considered orthogonal to x-y.

      For phi_residual, one way to frame the loss is with
      SmoothL1(sine(phi_residual - phi_predicted)).
      The use of sine to wrap the phi residual was proposed by [2]. This
      stems from the observation that bboxes at phi and phi + pi are the same
      bbox, fully overlapping in 3D space, except that the direction is
      different. Note that the use of sine makes this residual invariant to
      direction when a symmetric loss like SmoothL1 is used. In
      ResidualsToBBoxes, we ensure that the phi predicted is between [0, pi).

    The Huber (SmoothL1) loss can then be applied to the delta between these
    target residuals and the model predicted residuals.

    [1] VoxelNet: End-to-End Learning for Point Cloud Based 3D Object Detection
        https://arxiv.org/abs/1711.06396

    [2] SECOND: Sparsely Embedded Convolutional Detection
        https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf

    Args:
      anchor_bboxes: tf.float32. where [..., :7] contains (x, y, z, dx, dy, dz,
        phi), corresponding to each anchor bbox parameters.
      assigned_gt_bboxes: tf.float32 of the same shape as anchor_bboxes
        containing the corresponding assigned ground-truth bboxes.

    Returns:
      A tf.float32 tensor of the same shape as anchor_bboxes with target
      residuals for every corresponding bbox.
    """
        anchor_bboxes_shape = py_utils.GetShape(anchor_bboxes)
        anchor_bboxes = py_utils.with_dependencies(
            [py_utils.assert_equal(anchor_bboxes_shape[-1], 7)], anchor_bboxes)
        assigned_gt_bboxes = py_utils.HasShape(assigned_gt_bboxes,
                                               anchor_bboxes_shape)

        x_a, y_a, z_a, dx_a, dy_a, dz_a, phi_a = tf.unstack(anchor_bboxes,
                                                            num=7,
                                                            axis=-1)
        x_gt, y_gt, z_gt, dx_gt, dy_gt, dz_gt, phi_gt = tf.unstack(
            assigned_gt_bboxes, num=7, axis=-1)

        diagonal_xy = tf.sqrt(tf.square(dx_a) + tf.square(dy_a))

        # The anchor dimensions are usually hard-coded params given to the input
        # generator and should not be 0. We use CheckNumerics to ensure that is
        # the case.
        x_residual = py_utils.CheckNumerics((x_gt - x_a) / diagonal_xy)
        y_residual = py_utils.CheckNumerics((y_gt - y_a) / diagonal_xy)
        z_residual = py_utils.CheckNumerics((z_gt - z_a) / dz_a)

        dx_residual = py_utils.CheckNumerics(tf.log(dx_gt / dx_a))
        dy_residual = py_utils.CheckNumerics(tf.log(dy_gt / dy_a))
        dz_residual = py_utils.CheckNumerics(tf.log(dz_gt / dz_a))

        phi_residual = phi_gt - phi_a

        return tf.stack([
            x_residual,
            y_residual,
            z_residual,
            dx_residual,
            dy_residual,
            dz_residual,
            phi_residual,
        ],
                        axis=-1)  # pyformat: disable
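A worked numpy example of the residual formulas in the docstring above (anchor and ground-truth values are made up for illustration):

    import numpy as np

    # anchor: (x_a, y_a, z_a, dx_a, dy_a, dz_a, phi_a)
    x_a, y_a, z_a, dx_a, dy_a, dz_a, phi_a = 0.0, 0.0, 0.0, 3.0, 4.0, 2.0, 0.0
    # ground truth: shifted by (1, 2, 0.5), same size, rotated by 0.1 rad
    x_gt, y_gt, z_gt, dx_gt, dz_gt, phi_gt = 1.0, 2.0, 0.5, 3.0, 2.0, 0.1

    diagonal_xy = np.sqrt(dx_a**2 + dy_a**2)   # 5.0
    x_residual = (x_gt - x_a) / diagonal_xy    # 0.2
    y_residual = (y_gt - y_a) / diagonal_xy    # 0.4
    z_residual = (z_gt - z_a) / dz_a           # 0.25
    dx_residual = np.log(dx_gt / dx_a)         # 0.0 (same size)
    dz_residual = np.log(dz_gt / dz_a)         # 0.0
    phi_residual = phi_gt - phi_a              # 0.1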
Example No. 30
    def ResidualsToBBoxes(self,
                          anchor_bboxes,
                          residuals,
                          min_angle_rad=-np.pi,
                          max_angle_rad=np.pi):
        r"""Converts anchor_boxes and residuals to predicted bboxes.

    This converts predicted residuals into bboxes using the following formulae::

      x_predicted = x_a + x_residual * diagonal_xy
      y_predicted = y_a + y_residual * diagonal_xy
      z_predicted = z_a + z_residual * dz_a

      dx_predicted = dx_a * exp(dx_residual)
      dy_predicted = dy_a * exp(dy_residual)
      dz_predicted = dz_a * exp(dz_residual)

      # Adding the residual, and bounding it between
      # [min_angle_rad, max_angle_rad]
      phi_predicted = NormalizeAngleRad(phi_a + phi_residual,
                                        min_angle_rad, max_angle_rad)

    These equations follow from those in LocalizationResiduals, where we solve
    for the \*_gt variables.

    Args:
      anchor_bboxes: tf.float32. where [..., :7] contains (x, y, z, dx, dy, dz,
        phi), corresponding to each anchor bbox parameters.
      residuals: tf.float32 of the same shape as anchor_bboxes containing
        predicted residuals at each anchor location.
      min_angle_rad: Scalar with the minimum angle allowed (before wrapping)
        in radians.
      max_angle_rad: Scalar with the maximum angle allowed (before wrapping)
        in radians. This value usually should be pi.

    Returns:
      A tf.float32 tensor of the same shape as anchor_bboxes with predicted
      bboxes.
    """
        anchor_bboxes_shape = py_utils.GetShape(anchor_bboxes)
        anchor_bboxes = py_utils.with_dependencies(
            [py_utils.assert_equal(anchor_bboxes_shape[-1], 7)], anchor_bboxes)
        residuals = py_utils.HasShape(residuals, anchor_bboxes_shape)

        x_a, y_a, z_a, dx_a, dy_a, dz_a, phi_a = tf.unstack(anchor_bboxes,
                                                            num=7,
                                                            axis=-1)
        (x_residual, y_residual, z_residual, dx_residual, dy_residual,
         dz_residual, phi_residual) = tf.unstack(residuals, num=7, axis=-1)

        diagonal_xy = tf.sqrt(tf.square(dx_a) + tf.square(dy_a))

        x_predicted = x_a + x_residual * diagonal_xy
        y_predicted = y_a + y_residual * diagonal_xy
        z_predicted = z_a + z_residual * dz_a

        dx_predicted = dx_a * tf.exp(dx_residual)
        dy_predicted = dy_a * tf.exp(dy_residual)
        dz_predicted = dz_a * tf.exp(dz_residual)

        # We bound the angle between [min_angle_rad, max_angle_rad], which should
        # be passed in depending on the heading handling in the calling model.
        # If the model uses a sine(delta_phi) transformation in the loss, then it
        # cannot distinguish direction, and [0, np.pi] should be used for
        # [min_angle_rad, max_angle_rad]. If there is a directional heading
        # encoding, you should most likely use [-np.pi, np.pi].
        phi_predicted = phi_a + phi_residual
        phi_predicted = geometry.WrapAngleRad(phi_predicted, min_angle_rad,
                                              max_angle_rad)

        return tf.stack([
            x_predicted,
            y_predicted,
            z_predicted,
            dx_predicted,
            dy_predicted,
            dz_predicted,
            phi_predicted,
        ],
                        axis=-1)  # pyformat: disable
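As a quick sanity check that these formulas invert LocalizationResiduals, plugging the residuals from the worked example above back in recovers the ground truth (values repeated so the sketch is self-contained):

    import numpy as np

    # From the worked example: x_a=0.0, z_a=0.0, dx_a=3.0, dz_a=2.0,
    # phi_a=0.0, diagonal_xy=5.0; residuals 0.2, 0.25, 0.0, 0.1.
    x_predicted = 0.0 + 0.2 * 5.0        # 1.0 == x_gt
    z_predicted = 0.0 + 0.25 * 2.0       # 0.5 == z_gt
    dx_predicted = 3.0 * np.exp(0.0)     # 3.0 == dx_gt
    phi_predicted = 0.0 + 0.1            # 0.1 == phi_gt, already in [-pi, pi]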