def autoregressive_sample(model, prefix=None, inputs=None,
                          batch_size=1, temperature=1.0,
                          start_id=0, eos_id=1, max_length=100,
                          accelerate=True):
  """Perform autoregressive sampling from the provided model.

  The model is fed one symbol per step (starting with `start_id`), so it is
  expected to be an autoregressive model in 'predict' mode that caches the
  history in its own state.

  Sampling stops early once every sequence in the batch has produced
  `eos_id` at least once; otherwise it runs for `max_length` steps. Rows
  that reached `eos_id` earlier than the others keep being sampled until
  the last row finishes, so callers should truncate each row at its first
  `eos_id` themselves.

  Args:
    model: instance of trax.Layer, the model to sample from (at mode='predict')
    prefix: optional tensor [batch_size, L]: prefix for decoding; the first L
      output symbols are copied from it instead of being sampled
    inputs: optional tensor [batch_size, M]: inputs to provide to the model;
      when given, the model is called on the pair (inputs, current symbol) and
      the first element of its output is used as the logits
    batch_size: how many batches to sample (default: 1)
    temperature: sampling temperature (default: 1.0)
    start_id: int, id for the start symbol fed at the beginning (default: 0)
    eos_id: int, id of the end-of-sequence symbol used to stop (default: 1)
    max_length: maximum length to sample (default: 100)
    accelerate: whether to accelerate the model before decoding (default: True)

  Returns:
    a tensor of ints of shape [batch_size, N] with N <= max_length containing
    the autoregressively sampled output from the model

  Raises:
    ValueError: if the leading dimension of `prefix` or `inputs` differs from
      `batch_size`.
  """
  if prefix is not None and prefix.shape[0] != batch_size:
    raise ValueError(
        f'Prefix batch size {prefix.shape[0]} != {batch_size}.')
  if inputs is not None and inputs.shape[0] != batch_size:
    raise ValueError(
        f'Inputs batch size {inputs.shape[0]} != {batch_size}.')
  if max_length <= 0:
    # Nothing to sample; np.concatenate would fail on an empty list.
    return np.zeros((batch_size, 0), dtype=np.int32)

  fast_model = tl.Accelerate(model) if accelerate else model
  prefix_length = 0 if prefix is None else prefix.shape[1]
  cur_symbol = np.full((batch_size, 1), start_id, dtype=np.int32)
  # eos_seen[j] becomes True once row j has produced eos_id.
  eos_seen = np.zeros((batch_size,), dtype=bool)
  result = []
  for i in range(max_length):
    model_input = cur_symbol if inputs is None else (inputs, cur_symbol)
    # The model is run on prefix positions too, so its state stays in sync.
    logits = fast_model(model_input)
    if inputs is not None:
      logits = logits[0]  # Pick first element from model output (a pair here)
    if i < prefix_length:
      # Read from prefix.
      sample = prefix[:, i][:, None]
    else:
      sample = tl.gumbel_sample(logits, temperature=temperature)
    result.append(sample)
    # Note: we're using 'predict' mode autoregressive models here, so history
    # is cached in the model state and we are only feeding one symbol next.
    cur_symbol = sample
    eos_seen |= (np.asarray(sample)[:, 0] == eos_id)
    if eos_seen.all():
      break
  return np.concatenate(result, axis=1)
def F(x):
  """Return the non-padding mask of `x`, forcing the shifted-in positions on.

  `mode` and `n_shifts` are free variables taken from the enclosing scope
  (not visible in this block).

  Args:
    x: array with at least two axes; zeros are treated as padding and axis 1
      is the axis along which the array was shifted.

  Returns:
    Boolean array of the same shape as `x`: `x != 0`, except that the first
    `n_shifts` positions along axis 1 are always True.

  Raises:
    ValueError: if `mode` is 'predict'.
  """
  # TODO(afrozm): What to do in this case?
  if mode == 'predict':
    raise ValueError('MaskOfRightShiftedArray not implemented for predict.')

  mask = x != 0

  if n_shifts == 0:
    return mask

  # Need to set (B, n_shifts, ...) section to True. Clamp the shift to the
  # sequence length so the mask keeps the shape of `x` even when n_shifts is
  # larger than x.shape[1].
  n_true = min(n_shifts, x.shape[1])
  trues_shape = (x.shape[0], n_true) + mask.shape[2:]
  trues = jnp.full(trues_shape, True)
  return jnp.concatenate([trues, mask[:, n_true:, ...]], axis=1)
def forward(self, inputs):
  """Returns the input activations, with added positional information.

  Outside of 'predict' mode the leading `inputs.shape[1]` positions of
  `self.weights` are added to `inputs`, optionally passed through dropout
  whose noise is broadcast along `self._dropout_broadcast_dims`.

  In 'predict' mode `self.state` holds, per batch example, the index of the
  next position; it is read to pick the matching slice of `self.weights` and
  then advanced by the number of positions consumed in this call.

  Args:
    inputs: activations; axis 0 is the batch and axis 1 the position.

  Returns:
    `inputs` plus the positional encoding, same shape as `inputs`.

  Raises:
    ValueError: in 'predict' mode when the dropout rate is not zero.
  """
  if self._mode == 'predict':
    if self._dropout != 0:
      raise ValueError('In predict mode, but dropout rate '
                       f'({self._dropout}) is not zero.')

    # State is only used for fast inference, where the model is called on
    # consecutive positions: it stores the current position index and is
    # incremented on every call.
    position = self.state
    n_steps = inputs.shape[1]
    if n_steps == 1:
      self.state = position + 1
      return inputs + jnp.expand_dims(self.weights[0, position, :], 1)

    # Several positions at once: cut one window per example, each starting
    # at that example's own position.
    windows = [
        jax.lax.dynamic_slice_in_dim(
            self.weights[0], position[row], n_steps, axis=0)
        for row in range(inputs.shape[0])
    ]
    self.state = position + n_steps
    return inputs + jnp.stack(windows, 0)

  seq_len = jnp.shape(inputs)[1]
  pos_emb = self.weights[:, :seq_len, :]
  if self._dropout == 0:
    return inputs + pos_emb

  # Dropout on the positional encoding; dimensions listed in
  # _dropout_broadcast_dims share a single noise value.
  mask_shape = list(pos_emb.shape)
  for axis in self._dropout_broadcast_dims:
    mask_shape[axis] = 1
  keep_prob = 1.0 - self._dropout
  if fastmath.is_backend(fastmath.Backend.JAX):
    keep_prob = jax.lax.tie_in(
        inputs, jnp.full((), keep_prob, dtype=inputs.dtype))
  keep = fastmath.random.bernoulli(self.rng, keep_prob, tuple(mask_shape))
  scale = keep.astype(inputs.dtype) / keep_prob
  return inputs + pos_emb * scale
def forward(self, inputs):
  """Adds the per-axis positional embeddings in `self.weights` to `inputs`.

  Every entry of `self.weights` is broadcast to
  `(batch,) + self._shape + (its own depth,)`; the entries are concatenated
  along the last axis and flattened to line up with `inputs`.

  In 'predict' mode `self.state` is the offset of the first position of this
  call within the flattened embedding; it is advanced by `inputs.shape[1]`.
  Otherwise the whole embedding is added, with dropout applied to it when
  `self._dropout` is non-zero (noise is broadcast along
  `self._dropout_broadcast_dims`).

  Args:
    inputs: activations; axis 0 is the batch, the last axis is the depth.

  Returns:
    `inputs` plus the positional embedding, same shape as `inputs`.
  """
  rng, offset = self.rng, self.state
  batch = inputs.shape[0]
  axis_embs = [
      jnp.broadcast_to(w, (batch,) + self._shape + (w.shape[-1],))
      for w in self.weights
  ]

  if self._mode == 'predict':
    assert self._dropout == 0.0
    flat = jnp.concatenate(axis_embs, -1)
    flat = jnp.reshape(flat, (batch, -1, flat.shape[-1]))
    n_steps = inputs.shape[1]
    window = jax.lax.dynamic_slice_in_dim(flat, offset, n_steps, axis=1)
    self.state = offset + n_steps
    return inputs + window

  if self._dropout == 0:
    # TODO(kitaev): concat-then-reshape (as is the case with dropout enabled)
    # leads to memory blow-up on TPU, so each piece is reshaped on its own
    # before concatenating.
    lead_shape = inputs.shape[:-1]
    pieces = [
        jnp.reshape(piece, lead_shape + (piece.shape[-1],))
        for piece in axis_embs
    ]
    return inputs + jnp.concatenate(pieces, -1)

  combined = jnp.concatenate(axis_embs, -1)
  mask_shape = list(combined.shape)
  for axis in self._dropout_broadcast_dims:
    mask_shape[axis] = 1
  keep_prob = 1.0 - self._dropout
  if fastmath.backend_name() == 'jax':
    keep_prob = jax.lax.tie_in(
        inputs, jnp.full((), keep_prob, dtype=inputs.dtype))
  keep = fastmath.random.bernoulli(rng, keep_prob, tuple(mask_shape))
  scale = keep.astype(inputs.dtype) / keep_prob
  return inputs + jnp.reshape(combined * scale, inputs.shape)