Example #1
  def test_repeat_with_one_to_one_alignment(self):
    batch_size = 2
    labels = np.array([
        [1, 2, 2, 3],
        [2, 3, 4, 4],
    ])
    label_lens = np.array([4, 4])
    label_paddings = lengths_to_paddings(label_lens, 4)
    logits = np.random.randn(batch_size, 5, 5)
    logprobs = jax.nn.log_softmax(logits)
    logprob_paddings = np.zeros(logprobs.shape[:2])

    jax_per_seq, unused_aux_vars = ctc_objectives.ctc_loss(
        logprobs, logprob_paddings, labels, label_paddings)

    expected_alignment = [
        [1, 2, 0, 2, 3],
        [2, 3, 4, 0, 4],
    ]

    for n in range(batch_size):
      expected_loss = -sum(logprobs[n, t, k]
                           for t, k in enumerate(expected_alignment[n]))
      self.assertAllClose(
          jnp.array(expected_loss), jax_per_seq[n], rtol=0.01, atol=0.05)
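
Several of these tests call a lengths_to_paddings helper that the listing does not show. A minimal sketch of what it presumably does, assuming the convention that a padding entry is 1.0 at padded positions and 0.0 at valid ones:

import numpy as np

def lengths_to_paddings(lengths: np.ndarray, maxlength: int) -> np.ndarray:
  # Position t of sequence n is padding (1.0) iff t >= lengths[n].
  indices = np.arange(maxlength)[np.newaxis, :]
  return (indices >= lengths[:, np.newaxis]).astype(np.float32)

With label_lens = np.array([4, 4]) and maxlength 4 as above, this returns an all-zeros array, i.e. no padded positions.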
Example #2
  def test_against_tf_ctc_loss(self):
    batchsize = 8
    timesteps = 150
    labelsteps = 25
    nclasses = 400
    logits = np.random.randn(batchsize, timesteps, nclasses)
    logprobs = jax.nn.log_softmax(logits)
    logprob_paddings = np.zeros((batchsize, timesteps))
    labels = np.random.randint(
        1, nclasses, size=(batchsize, labelsteps)).astype(np.int32)
    label_paddings = np.zeros((batchsize, labelsteps))

    inputs = [logprobs, logprob_paddings, labels, label_paddings]

    jax_per_seq, unused_aux_vars = ctc_objectives.ctc_loss(*inputs)
    tf_per_seq = tf_ctc_loss(*inputs)
    self.assertAllClose(jax_per_seq.squeeze(), tf_per_seq.squeeze())

    # `average_ctc_loss` (the scalar JAX objective) is defined in Example #5.
    average_tf_ctc_loss = lambda *args: jnp.average(tf_ctc_loss(*args))
    jax_dloss = jax.grad(average_ctc_loss)
    tf_dloss = jax.grad(average_tf_ctc_loss)

    jax_dlogits = jax_dloss(*inputs)
    tf_dlogits = tf_dloss(*inputs)
    # The relative-error check is disabled because numerical errors explode
    # when a probability computed from the input logits is close to zero.
    self.assertAllClose(jax_dlogits, tf_dlogits, rtol=0.0, atol=1e-4)
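
The tf_ctc_loss reference used above is also not shown in the listing. Since Example #2 takes jax.grad through it, it cannot be a plain TensorFlow call; a plausible reconstruction (the names and wiring here are assumptions, not the author's code), wrapping tf.nn.ctc_loss with jax2tf.call_tf so the TF computation is traceable and differentiable from JAX:

import tensorflow as tf
from jax.experimental import jax2tf

def tf_ctc_loss(logprobs, logprob_paddings, labels, label_paddings):
  # Assumed reference implementation: convert paddings back to lengths and
  # delegate to tf.nn.ctc_loss.
  def _tf_fn(logprobs, logprob_paddings, labels, label_paddings):
    logprob_lens = tf.cast(
        tf.reduce_sum(1.0 - logprob_paddings, axis=-1), tf.int32)
    label_lens = tf.cast(
        tf.reduce_sum(1.0 - label_paddings, axis=-1), tf.int32)
    # Passing normalized log-probs as logits is safe: tf.nn.ctc_loss applies
    # log_softmax internally, and log_softmax is idempotent.
    return tf.nn.ctc_loss(
        labels=labels,
        logits=logprobs,
        label_length=label_lens,
        logit_length=logprob_lens,
        logits_time_major=False,
        blank_index=0)
  # call_tf makes the TF computation callable (and differentiable) from JAX.
  return jax2tf.call_tf(_tf_fn)(
      logprobs, logprob_paddings, labels, label_paddings)

blank_index=0 matches the blank convention of ctc_objectives.ctc_loss, where class 0 is phi and real labels start at 1.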
Example #3
  def test_with_one_to_one_alignment(self):
    # When input steps and output steps are equal, no phi (blank) is allowed.
    batchsize = 8
    steps = 50
    nclasses = 40
    logits = np.random.randn(batchsize, steps, nclasses)
    logprobs = jax.nn.log_softmax(logits)
    labels = np.random.randint(
        1, nclasses, size=(batchsize, steps)).astype(np.int32)
    # This case only checks alignments without repeated labels;
    # `test_repeat_with_one_to_one_alignment` (Example #1) complements it
    # with the repeated-label cases. Redraw samples until the constraint
    # holds.
    for n in range(labels.shape[0]):
      for t in range(1, labels.shape[1]):
        while labels[n, t] == labels[n, t - 1]:
          labels[n, t] = np.random.randint(1, nclasses)

    per_seq_loss, aux_vars = ctc_objectives.ctc_loss(
        logprobs, np.zeros(logprobs.shape[:2]), labels, np.zeros(labels.shape))

    for b in range(batchsize):
      p = 0.0
      for t in range(steps):
        p += logprobs[b, t, labels[b, t]]
      self.assertAllClose(jnp.array(-p), per_seq_loss[b])

      # Check logalpha interim variables
      # 1. All-phi path
      self.assertAllClose(aux_vars['logalpha_phi'][-1, b, 0],
                          jnp.sum(logprobs[b, :, 0]))
      # 2. After emitting all the labels
      self.assertAllClose(aux_vars['logalpha_emit'][-1, b, steps - 1],
                          -per_seq_loss[b])
      self.assertAllClose(aux_vars['logalpha_phi'][-1, b, -1], -per_seq_loss[b])
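
The two logalpha checks assert simple identities of the CTC forward pass; a sketch of the quantities being compared, reusing the arrays above (the [T+1, batch, ...] layout of the aux_vars entries is inferred from the indexing in the test, not documented here):

b = 0
# 1. All-phi path: never emitting a label means taking the blank (class 0)
#    log-probability at every frame.
all_phi_score = jnp.sum(logprobs[b, :, 0])
# 2. With exactly `steps` labels and `steps` frames, the only complete path
#    emits one label per frame, so the final forward variables equal the
#    negated per-sequence loss.
one_to_one_score = sum(logprobs[b, t, labels[b, t]] for t in range(steps))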
Example #4
  def test_against_tf_ctc_loss_with_paddings(self):
    batchsize = 8
    timesteps = 150
    labelsteps = 25
    nclasses = 400

    logits = np.random.randn(batchsize, timesteps, nclasses)
    logprobs = jax.nn.log_softmax(logits)
    logprob_lens = np.random.randint(25, timesteps - 3, size=(batchsize,))
    logprob_paddings = lengths_to_paddings(logprob_lens, timesteps)

    labels = np.random.randint(
        1, nclasses, size=(batchsize, labelsteps)).astype(np.int32)
    label_lens = np.random.randint(10, labelsteps, size=(batchsize,))
    label_paddings = lengths_to_paddings(label_lens, labelsteps)

    inputs = [logprobs, logprob_paddings, labels, label_paddings]

    jax_per_seq, _ = ctc_objectives.ctc_loss(*inputs)
    tf_per_seq = tf_ctc_loss(*inputs)
    self.assertAllClose(jax_per_seq.squeeze(), tf_per_seq.squeeze())
Example #5
def average_ctc_loss(logprobs: JTensor, logprob_paddings: JTensor,
                     labels: JTensor, label_paddings: JTensor) -> JTensor:
  return jnp.average(
      ctc_objectives.ctc_loss(logprobs, logprob_paddings, labels,
                              label_paddings)[0])
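
Since average_ctc_loss reduces the per-sequence losses to a scalar, it composes directly with jax.grad and jax.jit; this is the objective differentiated in Example #2. A usage sketch (variable names hypothetical, assuming the test file's usual imports and the JTensor alias from the surrounding codebase):

grad_fn = jax.jit(jax.grad(average_ctc_loss))
# Gradient w.r.t. the first argument (logprobs), as in Example #2.
dlogprobs = grad_fn(logprobs, logprob_paddings, labels, label_paddings)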