Example #1
def predict_step(params,
                 inputs,
                 outputs,
                 cache,
                 eos_token,
                 max_decode_len,
                 beam_size,
                 config):
  """Predict translation with fast decoding beam search on a batch."""
  # Prepare transformer fast-decoder call for beam search: for beam search, we
  # need to set up our decoder model to handle a batch size equal to
  # batch_size * beam_size, where each batch item's data is expanded in-place
  # rather than tiled.
  flat_encoded = decode.flat_batch_beam_expand(
      models.ProgramTransformer(config).apply(
          {'params': params},
          inputs,
          outputs,
          method=models.ProgramTransformer.encode),
      beam_size)

  encoded_padding_mask = jnp.where(outputs > 0, 1, 0).astype(jnp.float32)
  flat_encoded_padding_mask = decode.flat_batch_beam_expand(
      encoded_padding_mask, beam_size)

  def tokens_ids_to_logits(flat_ids, flat_cache):
    """Token slice to logits from decoder model."""
    # --> [batch * beam, 1, vocab]
    flat_logits, new_vars = models.ProgramTransformer(config).apply(
        {'params': params, 'cache': flat_cache},
        flat_ids,
        flat_encoded,
        flat_encoded_padding_mask,
        mutable=['cache'],
        method=models.ProgramTransformer.decode)
    new_flat_cache = new_vars['cache']
    # Remove singleton sequence-length dimension:
    # [batch * beam, 1, vocab] --> [batch * beam, vocab]
    flat_logits = flat_logits.squeeze(axis=1)
    return flat_logits, new_flat_cache

  # Using the above-defined single-step decoder function, run a
  # beam search over possible sequences given input encoding.
  beam_seqs, _ = decode.beam_search(
      inputs,
      cache,
      tokens_ids_to_logits,
      beam_size=beam_size,
      alpha=0.6,
      bos_token=config.bos_token,
      eos_token=eos_token,
      max_decode_len=max_decode_len)

  # Beam search returns [n_batch, n_beam, n_length] with beam dimension
  # sorted in increasing order of log-probability.
  return beam_seqs
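
The comment at the top of this example relies on decode.flat_batch_beam_expand to grow the batch to batch_size * beam_size with each batch item's copies kept adjacent ("expanded in-place rather than tiled"). The helper itself is not shown in these examples; below is a minimal sketch of what such a function could look like. The name, the expand-tile-reshape strategy, and the toy array are assumptions for illustration, not the library's actual code.

# Sketch only: assumed behavior of decode.flat_batch_beam_expand.
import jax.numpy as jnp

def flat_batch_beam_expand(x, beam_size):
  """[batch, ...] -> [batch * beam_size, ...] with per-item copies adjacent."""
  expanded = jnp.expand_dims(x, axis=1)                            # [batch, 1, ...]
  tiled = jnp.tile(expanded, [1, beam_size] + [1] * (x.ndim - 1))  # [batch, beam, ...]
  return tiled.reshape((-1,) + x.shape[1:])                        # [batch * beam, ...]

# Example: a batch of 2 items expanded for beam_size=3 keeps item 0's copies in
# rows 0-2 and item 1's copies in rows 3-5, unlike tiling the whole batch.
x = jnp.array([[10, 11], [20, 21]])
print(flat_batch_beam_expand(x, 3).shape)  # (6, 2)
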
Example #2
def predict_step(state, inputs, outputs, cache, lp_cache, beam_size, bos_token,
                 eos_token, max_decode_len, config, lp_config):
    """Predict translation with fast decoding beam search on a batch."""
    params = state.optimizer.target
    lp_params = state.lp_optimizer.target

    # Split beam over latent sequences and programs.
    per_latent_beam_size = beam_size // FLAGS.latent_beam_size
    beam_size = FLAGS.latent_beam_size * per_latent_beam_size

    flat_lp_encoded = decode.flat_batch_beam_expand(
        models.ProgramTransformer(lp_config).apply(
            {'params': lp_params},
            inputs,
            outputs,
            method=models.ProgramTransformer.encode), FLAGS.latent_beam_size)

    encoded_padding_mask = jnp.where(outputs > 0, 1, 0).astype(jnp.float32)
    flat_encoded_padding_mask = decode.flat_batch_beam_expand(
        encoded_padding_mask, FLAGS.latent_beam_size)

    def tokens_ids_to_latent_logits(flat_ids, flat_lp_cache):
        """Token slice to logits from decoder model."""
        # --> [batch * beam, 1, vocab]
        flat_logits, new_vars = models.ProgramTransformer(lp_config).apply(
            {
                'params': lp_params,
                'cache': flat_lp_cache
            },
            flat_ids,
            flat_lp_encoded,
            flat_encoded_padding_mask,
            mutable=['cache'],
            method=models.ProgramTransformer.decode)
        new_flat_lp_cache = new_vars['cache']
        # Remove singleton sequence-length dimension:
        # [batch * beam, 1, vocab] --> [batch * beam, vocab]
        flat_logits = flat_logits.squeeze(axis=1)
        return flat_logits, new_flat_lp_cache

    # Step 1: Beam-search over latent tokens.
    latent_beam_seqs, _ = decode.beam_search(
        inputs,
        lp_cache,
        tokens_ids_to_latent_logits,
        beam_size=FLAGS.latent_beam_size,
        alpha=0.6,
        bos_token=bos_token,
        eos_token=eos_token,
        max_decode_len=np.ceil(max_decode_len / 2**FLAGS.c).astype(np.int32))

    flat_latent_seqs = decode.flat_batch_beam_expand(
        decode.flatten_beam_dim(latent_beam_seqs), per_latent_beam_size)
    # Quantize the predicted latent codes.
    flat_latents = models.LatentProgramTransformer(config).apply(
        {
            'params': params,
            'vqvae': state.model_state
        },
        flat_latent_seqs,
        mutable=False,
        method=models.LatentProgramTransformer.quantize)

    flat_encoded = decode.flat_batch_beam_expand(
        models.LatentProgramTransformer(config).apply(
            {
                'params': params,
                'vqvae': state.model_state
            },
            inputs,
            outputs,
            mutable=False,
            method=models.LatentProgramTransformer.encode), beam_size)

    # Padding masks.
    flat_latents_mask = jnp.where(
        jnp.logical_and(flat_latent_seqs > 0, flat_latent_seqs != eos_token),
        1, 0).astype(jnp.float32)
    flat_encoded_padding_mask = decode.flat_batch_beam_expand(
        encoded_padding_mask, beam_size)

    def tokens_ids_to_logits(flat_ids, flat_cache):
        """Token slice to logits from decoder model."""
        # --> [batch * beam, 1, vocab]
        flat_logits, new_vars = models.LatentProgramTransformer(config).apply(
            {
                'params': params,
                'vqvae': state.model_state,
                'cache': flat_cache
            },
            flat_ids,
            flat_latents,
            flat_encoded,
            flat_latents_mask,
            flat_encoded_padding_mask,
            mutable=['cache'],
            method=models.LatentProgramTransformer.decode)
        new_flat_cache = new_vars['cache']
        # Remove singleton sequence-length dimension:
        # [batch * beam, 1, vocab] --> [batch * beam, vocab]
        flat_logits = flat_logits.squeeze(axis=1)
        return flat_logits, new_flat_cache

    # Step 2: Beam-search over program tokens.
    per_latent_inputs = decode.flat_batch_beam_expand(inputs,
                                                      FLAGS.latent_beam_size)
    per_latent_cache = jax.tree_map(
        lambda x: decode.flat_batch_beam_expand(x, FLAGS.latent_beam_size),
        cache)
    beam_seqs, _ = decode.beam_search(per_latent_inputs,
                                      per_latent_cache,
                                      tokens_ids_to_logits,
                                      beam_size=per_latent_beam_size,
                                      alpha=0.6,
                                      bos_token=bos_token,
                                      eos_token=eos_token,
                                      max_decode_len=max_decode_len)
    # Collapse both beam dimensions into one.
    beam_seqs = beam_seqs.reshape((inputs.shape[0], beam_size) +
                                  beam_seqs.shape[2:])
    latent_beam_seqs = jnp.repeat(latent_beam_seqs,
                                  per_latent_beam_size,
                                  axis=1)

    # Beam search returns [n_batch, n_beam, n_length] with beam dimension
    # sorted in increasing order of log-probability.
    return beam_seqs, latent_beam_seqs
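
Example #2 splits the requested beam between latent sequences and program tokens, runs the second beam search on a batch expanded by the latent beam, and collapses the two beam dimensions at the end. The stand-alone shape check below sketches that bookkeeping; the batch size, beam sizes, and lengths are made-up values for illustration.

# Sketch only: shape bookkeeping for the two-level beam in Example #2.
import numpy as np

batch, latent_beam_size, requested_beam_size, length = 4, 2, 8, 10
per_latent_beam_size = requested_beam_size // latent_beam_size   # 4
beam_size = latent_beam_size * per_latent_beam_size              # 8 after rounding down

# Step 2 runs beam search on a batch of size batch * latent_beam_size, so its
# output has shape [batch * latent_beam, per_latent_beam, length].
beam_seqs = np.zeros((batch * latent_beam_size, per_latent_beam_size, length))

# Collapsing both beam dimensions recovers [batch, beam_size, length].
beam_seqs = beam_seqs.reshape((batch, beam_size) + beam_seqs.shape[2:])
assert beam_seqs.shape == (batch, beam_size, length)

# Each latent beam entry is repeated so it lines up with its program beams.
latent_beam_seqs = np.zeros((batch, latent_beam_size, 7))
latent_beam_seqs = np.repeat(latent_beam_seqs, per_latent_beam_size, axis=1)
assert latent_beam_seqs.shape == (batch, beam_size, 7)
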
Example #3
def predict_step(params,
                 inputs,
                 outputs,
                 cache,
                 beam_size,
                 num_partial_programs,
                 max_decode_len,
                 eos_token,
                 config,
                 use_expanding_layer,
                 slow_decode=False,
                 use_split_encoding=False,
                 split_params=None,
                 split_outputs=None):
  """Predict translation with fast decoding beam search on a batch."""
  per_partial_beam_size = max(beam_size // num_partial_programs, 1)

  m = models.DecomposeExpandingLayerTransformer(
      config=config, num_partial_programs=num_partial_programs,
      use_expanding_layer=use_expanding_layer)
  # Prepare transformer fast-decoder call for beam search: for beam search, we
  # need to set up our decoder model to handle a batch size equal to
  # batch_size * beam_size, where each batch item's data is expanded in-place
  # rather than tiled.
  if use_split_encoding:
    # Use the pretrained split model to get encodings.
    assert (split_params is not None) and (split_outputs is not None)

    split_inputs = models.add_and_tile_dim(inputs, num_partial_programs, axis=1)
    # split_outputs shape == [batch_size, num_partial, num_io, length]
    split_outputs = jnp.swapaxes(split_outputs, 1, 2)
    encoded = models.DecomposeExpandingLayerTransformer(
        config=config, num_partial_programs=1,
        use_expanding_layer=False).apply(
            {'params': split_params},
            split_inputs,
            split_outputs,
            method=models.DecomposeExpandingLayerTransformer.encode)
    flat_encoded = decode.flat_batch_beam_expand(encoded, per_partial_beam_size)

    encoded_padding_mask = jnp.where(
        split_outputs > 0, 1, 0).astype(jnp.float32)
    flat_encoded_padding_mask = decode.flat_batch_beam_expand(
        encoded_padding_mask, per_partial_beam_size)
  else:
    flat_encoded = decode.flat_batch_beam_expand(
        m.apply(
            {'params': params},
            inputs,
            outputs,
            method=models.DecomposeExpandingLayerTransformer.encode),
        per_partial_beam_size)
    flat_encoded = m.apply(
        {'params': params},
        flat_encoded,
        method=models.DecomposeExpandingLayerTransformer.decompose)

    encoded_padding_mask = jnp.where(outputs > 0, 1, 0).astype(jnp.float32)
    flat_encoded_padding_mask = decode.flat_batch_beam_expand(
        encoded_padding_mask, per_partial_beam_size)
    flat_encoded_padding_mask = models.add_and_tile_dim(
        flat_encoded_padding_mask, num_partial_programs, axis=1)

  if slow_decode:
    def tokens_ids_to_logits(flat_ids, i):
      """Token slice to logits from decoder model."""
      # --> [batch * beam, 1, vocab]
      flat_logits = models.DecomposeExpandingLayerTransformer(
          config=config, num_partial_programs=num_partial_programs,
          use_expanding_layer=use_expanding_layer).apply(
              {'params': params},
              flat_ids,
              flat_encoded[:, i],
              flat_encoded_padding_mask[:, i],
              method=models.DecomposeExpandingLayerTransformer.decode)
      return flat_logits
  else:
    def tokens_ids_to_logits(flat_ids, flat_cache, i):
      """Token slice to logits from decoder model."""
      # --> [batch * beam, 1, vocab]
      flat_logits, new_vars = models.DecomposeExpandingLayerTransformer(
          config=config, num_partial_programs=num_partial_programs,
          use_expanding_layer=use_expanding_layer).apply(
              {'params': params, 'cache': flat_cache},
              flat_ids,
              flat_encoded[:, i],
              flat_encoded_padding_mask[:, i],
              mutable=['cache'],
              method=models.DecomposeExpandingLayerTransformer.decode)
      new_flat_cache = new_vars['cache']
      # Remove singleton sequence-length dimension:
      # [batch * beam, 1, vocab] --> [batch * beam, vocab]
      flat_logits = flat_logits.squeeze(axis=1)
      return flat_logits, new_flat_cache

  # Perform beam search independently for each partial program.
  all_beam_seqs = []
  all_beam_log_probs = []
  for i in range(num_partial_programs):
    beam_seqs, beam_log_probs = decode.beam_search(
        inputs,
        cache,
        functools.partial(tokens_ids_to_logits, i=i),
        beam_size=per_partial_beam_size,
        alpha=0.6,
        bos_token=config.bos_token,
        eos_token=eos_token,
        max_decode_len=max_decode_len,
        slow_decode=slow_decode)
    all_beam_seqs.append(beam_seqs)
    all_beam_log_probs.append(beam_log_probs)

  all_beam_seqs = jnp.stack(all_beam_seqs, axis=1)
  all_beam_log_probs = jnp.stack(all_beam_log_probs, axis=1)

  # all_beam_seqs shape == [batch, n_partial, n_beam_per_partial, length]
  assert len(all_beam_seqs.shape) == 4
  # all_beam_log_probs shape == [batch, n_partial, n_beam_per_partial]
  assert len(all_beam_log_probs.shape) == 3

  # Sort beams in order of decreasing probability.
  order = jnp.argsort(all_beam_log_probs, axis=2)[:, :, ::-1]
  all_beam_log_probs = jnp.take_along_axis(all_beam_log_probs, order, axis=2)
  all_beam_seqs = jnp.take_along_axis(all_beam_seqs, order[..., jnp.newaxis],
                                      axis=2)

  return all_beam_seqs, all_beam_log_probs
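
The final lines of Example #3 reorder each partial program's beams so that index 0 holds the highest log-probability. The toy snippet below shows the same argsort / take_along_axis pattern on small arrays; the numbers are made up purely for illustration.

# Toy illustration of the sorting step in Example #3.
import jax.numpy as jnp

# [batch=1, n_partial=2, n_beam_per_partial=3]
log_probs = jnp.array([[[-3.0, -1.0, -2.0],
                        [-0.5, -4.0, -1.5]]])
# [batch, n_partial, n_beam_per_partial, length=2]
seqs = jnp.arange(1 * 2 * 3 * 2).reshape(1, 2, 3, 2)

order = jnp.argsort(log_probs, axis=2)[:, :, ::-1]               # descending order
sorted_log_probs = jnp.take_along_axis(log_probs, order, axis=2)
sorted_seqs = jnp.take_along_axis(seqs, order[..., jnp.newaxis], axis=2)

print(sorted_log_probs[0, 0])  # [-1. -2. -3.]
print(sorted_seqs[0, 0])       # sequences reordered to match their scores
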
Example #4
def predict_step(params,
                 inputs,
                 outputs,
                 cache,
                 beam_size,
                 eos_token,
                 max_decode_len,
                 config,
                 dropout_rng=None,
                 slow_decode=True):
    """Predict translation with fast decoding beam search on a batch."""
    # This code is necessary to experiment with using dropout during prediction,
    # but we don't normally use dropout here.
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    # Prepare transformer fast-decoder call for beam search: for beam search, we
    # need to set up our decoder model to handle a batch size equal to
    # batch_size * beam_size, where each batch item's data is expanded in-place
    # rather than tiled.
    flat_encoded = decode.flat_batch_beam_expand(
        models.DecomposeAttentionTransformer(config).apply(
            {'params': params},
            inputs,
            outputs,
            method=models.DecomposeAttentionTransformer.encode,
            rngs={'dropout': dropout_rng}), beam_size)

    encoded_padding_mask = jnp.where(outputs > 0, 1, 0).astype(jnp.float32)
    flat_encoded_padding_mask = decode.flat_batch_beam_expand(
        encoded_padding_mask, beam_size)

    if slow_decode:

        def tokens_ids_to_logits(flat_ids):
            """Token slice to logits from decoder model."""
            # --> [batch * beam, 1, vocab]
            flat_logits = models.DecomposeAttentionTransformer(
                config=config).apply(
                    {'params': params},
                    flat_ids,
                    flat_encoded,
                    flat_encoded_padding_mask,
                    method=models.DecomposeAttentionTransformer.decode,
                    rngs={'dropout': dropout_rng})
            return flat_logits
    else:

        def tokens_ids_to_logits(flat_ids, flat_cache):
            """Token slice to logits from decoder model."""
            # --> [batch * beam, 1, vocab]
            flat_logits, new_vars = models.DecomposeAttentionTransformer(
                config=config).apply(
                    {
                        'params': params,
                        'cache': flat_cache
                    },
                    flat_ids,
                    flat_encoded,
                    flat_encoded_padding_mask,
                    mutable=['cache'],
                    method=models.DecomposeAttentionTransformer.decode,
                    rngs={'dropout': dropout_rng})
            new_flat_cache = new_vars['cache']
            # Remove singleton sequence-length dimension:
            # [batch * beam, 1, vocab] --> [batch * beam, vocab]
            flat_logits = flat_logits.squeeze(axis=1)
            return flat_logits, new_flat_cache

    # Using the above-defined single-step decoder function, run a
    # beam search over possible sequences given input encoding.
    beam_seqs, _ = decode.beam_search(inputs,
                                      cache,
                                      tokens_ids_to_logits,
                                      beam_size=beam_size,
                                      alpha=0.6,
                                      bos_token=config.base_config.bos_token,
                                      eos_token=eos_token,
                                      max_decode_len=max_decode_len,
                                      slow_decode=slow_decode)

    # Beam search returns [n_batch, n_beam, n_length] with beam dimension
    # sorted in increasing order of log-probability.
    return beam_seqs, new_dropout_rng
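
Example #4 returns the new dropout RNG alongside the beams so the key can be threaded across calls. The usage sketch below shows one way a caller might do that; the beam size, token ids, eval_batches iterable, and the availability of params and config are all assumptions for illustration, not part of this example.

# Sketch only: threading the dropout RNG across predictions (hypothetical setup).
import functools
import jax

predict_fn = functools.partial(
    predict_step,
    beam_size=4,
    eos_token=2,
    max_decode_len=64,
    config=config,          # assumed to be defined elsewhere
    slow_decode=True)

rng = jax.random.PRNGKey(0)
for inputs, outputs, cache in eval_batches:  # hypothetical iterable of batches
  beam_seqs, rng = predict_fn(params, inputs, outputs, cache, dropout_rng=rng)
  # beam_seqs: [n_batch, n_beam, n_length]; per the comment above, the last
  # beam index holds the highest log-probability sequence.
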