Example #1
import numpy as np

from trax import layers as tl

def autoregressive_sample(model,
                          prefix=None,
                          inputs=None,
                          batch_size=1,
                          temperature=1.0,
                          start_id=0,
                          eos_id=1,
                          max_length=100,
                          accelerate=True):
    """Perform aturegressive sampling from the provided model.

  Args:
    model: instance of trax.Layer, the model to sample from (at mode='predict')
    prefix: optional tensor [batch_size, L]: prefix for decoding
    inputs: optional tensor [batch_size, M]: inputs to provide to the model
    batch_size: how many batches to sample (default: 1)
    temperature: sampling temperature (default: 1.0)
    start_id: int, id for the start symbol fed at the beginning (default: 1)
    eos_id: int, id of the end-of-sequence symbol used to stop (default: 1)
    max_length: maximum length to sample (default: 100)
    accelerate: whether to accelerate the model before decoding (default: True)

  Returns:
    a tensor of ints of shape [batch_size, N] with N <= max_length containing
    the autoregressively sampled output from the model
  """
    if prefix is not None and prefix.shape[0] != batch_size:
        raise ValueError(
            f'Prefix batch size {prefix.shape[0]} != {batch_size}.')
    if inputs is not None and inputs.shape[0] != batch_size:
        raise ValueError(
            f'Inputs batch size {inputs.shape[0]} != {batch_size}.')
    fast_model = tl.Accelerate(model) if accelerate else model
    cur_symbol = np.full((batch_size, 1), start_id, dtype=np.int32)
    result = []
    for i in range(max_length):
        model_input = cur_symbol if inputs is None else (inputs, cur_symbol)
        logits = fast_model(model_input)
        if inputs is not None:
            logits = logits[0]  # Pick the first element of the model output (a pair here).
        if prefix is not None and i < prefix.shape[1]:  # Read from prefix.
            cur_prefix_symbol = prefix[:, i]
            sample = cur_prefix_symbol[:, None]
        else:
            sample = tl.gumbel_sample(logits, temperature=temperature)
        result.append(sample)
        # Note: we're using 'predict'-mode autoregressive models here, so the
        # history is cached in the model state and we only feed one new symbol.
        cur_symbol = sample
        # TODO(lukaszkaiser): extend stopping below to batch_sizes > 1.
        if batch_size == 1 and int(sample[0, 0]) == eos_id:
            break
    return np.concatenate(result, axis=1)
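
For context, a minimal usage sketch of calling this function (the TransformerLM configuration and checkpoint path below are assumptions for illustration, not part of the example):

from trax import models

# Hypothetical model: any Trax language model built with mode='predict' works.
model = models.TransformerLM(vocab_size=256, d_model=64, d_ff=128,
                             n_layers=2, n_heads=2, mode='predict')
model.init_from_file('my_model.pkl.gz', weights_only=True)  # assumed checkpoint
tokens = autoregressive_sample(model, temperature=0.8, max_length=20)
print(tokens.shape)  # (1, N) with N <= 20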
Example #2
  def F(x):
    # TODO(afrozm): What to do in this case?
    if mode == 'predict':
      raise ValueError('MaskOfRightShiftedArray not implemented for predict.')

    mask = x != 0

    if n_shifts == 0:
      return mask

    # Need to set (B, n_shifts, ...) section to True.
    trues_shape = (x.shape[0], n_shifts) + mask.shape[2:]
    trues = jnp.full(trues_shape, True)
    return jnp.concatenate([trues, mask[:, n_shifts:, ...]], axis=1)
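
To make the mask computation concrete, a small standalone sketch of the non-predict branch (mode and n_shifts are closed over by the enclosing layer factory; the values below are assumptions for illustration):

import jax.numpy as jnp

# x is a right-shifted array: position 0 holds the shifted-in start symbol (0),
# while the trailing zeros are genuine padding.
x = jnp.array([[0, 5, 7, 0],
               [0, 3, 0, 0]])
n_shifts = 1

mask = x != 0   # [[F, T, T, F], [F, T, F, F]] -- start symbol looks like padding
trues = jnp.full((x.shape[0], n_shifts) + mask.shape[2:], True)
out = jnp.concatenate([trues, mask[:, n_shifts:, ...]], axis=1)
# out == [[T, T, T, F], [T, T, F, F]]: the first n_shifts positions are forced
# to True, so the shifted-in start symbols are not masked out as padding.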
Example #3
    def forward(self, inputs):
        """Returns the input activations, with added positional information."""
        if self._mode != 'predict':
            x = inputs
            symbol_size = jnp.shape(x)[1]
            px = self.weights[:, :symbol_size, :]
            if self._dropout == 0:
                return x + px
            else:
                noise_shape = list(px.shape)
                for dim in self._dropout_broadcast_dims:
                    noise_shape[dim] = 1
                keep_prob = 1.0 - self._dropout
                if fastmath.is_backend(fastmath.Backend.JAX):
                    keep_prob = jax.lax.tie_in(
                        x, jnp.full((), keep_prob, dtype=x.dtype))
                keep = fastmath.random.bernoulli(self.rng, keep_prob,
                                                 tuple(noise_shape))
                multiplier = keep.astype(x.dtype) / keep_prob
                return x + px * multiplier
        else:
            if self._dropout != 0:
                raise ValueError(f'In predict mode, but dropout rate '
                                 f'({self._dropout}) is not zero.')

            # State in this class is only used for fast inference. In that case,
            # the model is called with consecutive elements position-by-position.
            # This positional encoding layer then needs to store the index of the
            # current position and increment it on each call -- that's how the
            # state is used and updated below.
            state = self.state
            if inputs.shape[1] == 1:
                self.state = state + 1
                return inputs + jnp.expand_dims(self.weights[0, state, :], 1)
            else:
                emb = []
                for i in range(inputs.shape[0]):
                    emb.append(
                        jax.lax.dynamic_slice_in_dim(self.weights[0],
                                                     state[i],
                                                     inputs.shape[1],
                                                     axis=0))
                self.state = state + inputs.shape[1]
                return inputs + jnp.stack(emb, 0)
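
For reference, a hedged sketch of exercising such a layer in non-predict mode (this forward matches Trax's tl.PositionalEncoding; the shapes below are illustrative). In 'predict' mode the layer instead reads and increments self.state, so it must be called with dropout 0 and one position at a time:

import numpy as np
from trax import layers as tl
from trax import shapes

layer = tl.PositionalEncoding(max_len=128, mode='train')
x = np.zeros((2, 10, 16), dtype=np.float32)   # (batch, length, d_model)
layer.init(shapes.signature(x))
y = layer(x)                                  # positional info added elementwise
assert y.shape == x.shape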
Example #4
    def forward(self, inputs):
        rng, state = self.rng, self.state
        embs = []
        for ax_emb in self.weights:
            ax_emb = jnp.broadcast_to(
                ax_emb, (inputs.shape[0],) + self._shape + (ax_emb.shape[-1],))
            embs.append(ax_emb)

        if self._mode == 'predict':
            assert self._dropout == 0.0
            emb = jnp.concatenate(embs, -1)
            emb = jnp.reshape(emb, (inputs.shape[0], -1, emb.shape[-1]))
            emb = jax.lax.dynamic_slice_in_dim(emb,
                                               state,
                                               inputs.shape[1],
                                               axis=1)
            self.state = state + inputs.shape[1]
            return inputs + emb
        elif self._dropout == 0:
            # TODO(kitaev): concat-then-reshape (as is the case with dropout enabled)
            # leads to memory blow-up on TPU.
            # emb = jnp.concatenate(embs, -1)
            # return inputs + jnp.reshape(emb, inputs.shape), state
            return inputs + jnp.concatenate([
                jnp.reshape(emb, inputs.shape[:-1] + (emb.shape[-1],))
                for emb in embs
            ], -1)
        else:
            emb = jnp.concatenate(embs, -1)
            noise_shape = list(emb.shape)
            for dim in self._dropout_broadcast_dims:
                noise_shape[dim] = 1
            keep_prob = 1.0 - self._dropout
            if fastmath.backend_name() == 'jax':
                keep_prob = jax.lax.tie_in(
                    inputs, jnp.full((), keep_prob, dtype=inputs.dtype))
            keep = fastmath.random.bernoulli(rng, keep_prob,
                                             tuple(noise_shape))
            multiplier = keep.astype(inputs.dtype) / keep_prob
            return inputs + jnp.reshape(emb * multiplier, inputs.shape)
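
The no-dropout training branch above deliberately avoids concat-then-reshape; a minimal standalone sketch of that factored (axial) scheme, with made-up shapes and weight tables for illustration:

import jax.numpy as jnp
import numpy as np

shape = (4, 8)                      # factored positions: 4 * 8 = 32 timesteps
d1, d2 = 6, 10                      # per-axis widths; d1 + d2 = d_model = 16
weights = [np.ones((1, 4, 1, d1)),  # one table per axis, broadcast over the rest
           np.ones((1, 1, 8, d2))]

inputs = np.zeros((2, 32, 16), dtype=np.float32)   # (batch, length, d_model)
embs = []
for ax_emb in weights:
    ax_emb = jnp.broadcast_to(
        ax_emb, (inputs.shape[0],) + shape + (ax_emb.shape[-1],))
    embs.append(ax_emb)
out = inputs + jnp.concatenate(
    [jnp.reshape(e, inputs.shape[:-1] + (e.shape[-1],)) for e in embs], -1)
assert out.shape == inputs.shape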