Example #1
import tensorflow as tf

from tensor2tensor.utils import expert_utils

# Assumed import: in tensor2tensor this helper is used with
# ModeKeys = tf.estimator.ModeKeys (PREDICT, TRAIN, EVAL).
ModeKeys = tf.estimator.ModeKeys


def remove_pad(x, pad_remover, mode):
  """Remove padding by concatenating all dimensions into one.

  Args:
    x (tf.Tensor): input of shape [batch_size, length, depth]
    pad_remover (obj): a PadRemover object
    mode (ModeKeys): infer, train or eval. If inference, the padding remover
      is not applied

  Returns:
    tf.Tensor of shape [1, length_nonpad, depth] where
      length_nonpad <= batch_size*length
  """
  # Concatenate all tokens (without padding)
  x = expert_utils.flatten_all_but_last(x)

  # Remove padding for training and eval
  if mode != ModeKeys.PREDICT:
    # This is a hack to allow inference when the <go> token
    # is detected as padding and removed. This works for now because there is
    # no padding at inference.
    x = pad_remover.remove(x)

  x = tf.expand_dims(x, axis=0)  # Now batch_size=1
  return x
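
Usage sketch

The snippet below shows how this function might be called. It assumes tensor2tensor's expert_utils.PadRemover, which is constructed from a padding mask (non-zero entries mark padding positions) and whose remove method drops those positions; the shapes and mask values here are made up purely for illustration.

import tensorflow as tf
from tensor2tensor.utils import expert_utils

# Hypothetical toy shapes, chosen only for illustration.
batch_size, length, depth = 2, 4, 8
x = tf.random.normal([batch_size, length, depth])

# Non-zero values mark padding: 3 of the batch_size*length = 8 tokens
# are padding, so 5 non-pad tokens remain.
pad_mask = tf.constant([[0., 0., 0., 1.],
                        [0., 0., 1., 1.]])
pad_remover = expert_utils.PadRemover(pad_mask)

# In train/eval mode the padding tokens are dropped, giving a tensor of
# shape [1, 5, 8]; in PREDICT mode the input is only flattened.
compact = remove_pad(x, pad_remover, tf.estimator.ModeKeys.TRAIN)

After the padding-free computation, PadRemover also provides a restore method that re-inserts zeros at the removed positions, which is how the compact tensor is mapped back to the original [batch_size*length, depth] layout.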