def remove_pad(x, pad_remover, mode):
    """Collapse the leading dimensions of `x` and strip padded positions.

    Args:
        x (tf.Tensor): input of shape [batch_size, length, depth].
        pad_remover (obj): a PadRemover object; its `remove` method is applied
            to the flattened tensor.
        mode (ModeKeys): infer, train or eval. When equal to
            `ModeKeys.PREDICT` the padding remover is skipped.

    Returns:
        tf.Tensor of shape [1, length_nonpad, depth] where
        length_nonpad <= batch_size * length.
    """
    # Merge every dimension except the last one, so all tokens sit on a
    # single axis.
    tokens = expert_utils.flatten_all_but_last(x)

    # Padding is only stripped for training and eval.
    # HACK: at inference the <go> token can be detected as padding and would
    # be removed, so the remover is bypassed there. This works for now
    # because there is no padding at inference.
    if mode != ModeKeys.PREDICT:
        tokens = pad_remover.remove(tokens)

    # Add back a leading axis of size 1: the result has batch_size=1.
    return tf.expand_dims(tokens, axis=0)