def beam_init(batch_size, beam_size, max_decode_len, cache, start_tokens=None):
  """Initializes the beam search state data structure.

  Args:
    batch_size: Number of examples decoded in parallel.
    beam_size: Number of beams kept per example.
    max_decode_len: Length of the per-beam sequence buffers.
    cache: Pytree of cache arrays; every leaf is passed through
      `add_beam_dim(leaf, beam_size)`.
    start_tokens: Optional 1-D array of initial tokens (presumably one per
      batch element -- confirm against callers). When given, each live
      sequence buffer starts with its token followed by zeros; when None the
      live sequence buffers are all zeros.

  Returns:
    A `BeamState` with `cur_index` 0, live log-probs and finished scores of
    shape (batch_size, beam_size), int32 live/finished sequence buffers of
    shape (batch_size, beam_size, max_decode_len), all-False finished flags
    and the beam-expanded cache.
  """
  cur_index0 = jnp.array(0)
  # Only the first beam of every example starts at log-prob 0; the remaining
  # beams start at NEG_INF (presumably so the identical initial beams are not
  # all expanded on the first step).
  live_logprobs0 = jnp.tile(
      jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), [batch_size, 1])
  finished_scores0 = jnp.ones((batch_size, beam_size)) * NEG_INF
  if start_tokens is None:
    live_seqs0 = jnp.zeros(
        (batch_size, beam_size, max_decode_len), jnp.int32)
  else:
    # Place each start token at position 0 and zero-fill the rest of the
    # buffer before adding the beam dimension.
    live_seqs0 = add_beam_dim(
        np.pad(start_tokens[:, None],
               ((0, 0), (0, max_decode_len - 1)),
               mode='constant'),
        beam_size)
  finished_seqs0 = jnp.zeros(
      (batch_size, beam_size, max_decode_len), jnp.int32)
  finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_)
  # Add beam dimension to attention cache pytree elements. Uses
  # `jax.tree_util.tree_map` because the top-level `jax.tree_map` alias is
  # deprecated and missing from recent JAX releases.
  beam_cache0 = jax.tree_util.tree_map(
      lambda x: add_beam_dim(x, beam_size), cache)
  return BeamState(cur_index=cur_index0,
                   live_logprobs=live_logprobs0,
                   finished_scores=finished_scores0,
                   live_seqs=live_seqs0,
                   finished_seqs=finished_seqs0,
                   finished_flags=finished_flags0,
                   cache=beam_cache0)
def zero_pad(x, pad, axis):
  """Pads `x` with zeros along one axis, leaving every other axis untouched.

  Args:
    x: Array to pad.
    pad: `(before, after)` pad amounts for the selected axis.
    axis: Index of the axis to pad.

  Returns:
    The padded array; the fill value is a zero of `x`'s own dtype.
  """
  widths = [(0, 0) for _ in x.shape]
  widths[axis] = pad  # Only the requested axis gets non-zero padding.
  zero = x.dtype.type(0)
  return np.pad(x, widths, mode='constant', constant_values=zero)
def forward(self, x, weights):
  """Runs the parent convolution on a left-padded copy of `x`.

  Args:
    x: 3-D input array; axis 1 is the one that gets padded.
    weights: Weights forwarded unchanged to the parent `forward`.

  Returns:
    Whatever the parent class's `forward` returns for the padded input.
  """
  assert self._padding == 'VALID'
  # Zero-padding only on the left of axis 1 and then applying an unmasked
  # VALID convolution yields a causal convolution.
  # TODO(ddohan): Support strided and dilated convolutions.
  dilation = 1
  receptive_field = int((self._kernel_size[0] - 1) * dilation + 1)
  n_left = receptive_field - 1
  padding_spec = [[0, 0], [n_left, 0], [0, 0]]
  padded_input = np.pad(x, pad_width=padding_spec, mode='constant')
  return super(CausalConv, self).forward(padded_input, weights)
def ShiftRight(x, n_shifts=1, mode='train', **unused_kwargs):
  """Layer to shift the tensor to the right by padding on axis 1.

  Args:
    x: Array with at least two axes; axis 1 is the one that is shifted.
    n_shifts: Number of zero positions inserted at the start of axis 1. The
      same number of positions is dropped from the end, so the output has the
      same shape as `x`. `n_shifts=0` returns the values unchanged.
    mode: If 'predict', `x` is returned as is.
    **unused_kwargs: Ignored.

  Returns:
    `x` shifted right along axis 1 and zero-filled on the left.
  """
  if mode == 'predict':
    # Do nothing in predict mode, as then the sequence length is 1.
    return x

  pad_widths = [(0, 0)] * len(x.shape)
  pad_widths[1] = (n_shifts, 0)  # Padding on axis=1
  padded = np.pad(x, pad_widths, mode='constant',
                  constant_values=x.dtype.type(0))
  # Trim back to the original length. Slicing with `:-n_shifts` would yield
  # an empty axis for n_shifts == 0 because `-0 == 0`.
  return padded[:, :x.shape[1]]
def DiagonalGate(x):
  """Splits channels into 3 equal parts and offsets the outer two along axis 2.

  After zero-padding axis 2 by one position on each side, output position t
  takes the first channel third from position t - 1, the middle third from
  position t and the last third from position t + 1 of the input.

  Args:
    x: 4-D array whose last axis size is divisible by 3.

  Returns:
    Array of the same shape as `x`.
  """
  # x : [batch, 1, length, depth]
  padded = np.pad(
      x, [(0, 0), (0, 0), (1, 1), (0, 0)],
      mode='constant', constant_values=0.0)
  depth = padded.shape[-1] // 3
  assert 3 * depth == padded.shape[-1], ('Depth must be divisible by 3', depth,
                                         padded.shape)
  from_previous = padded[:, :, :-2, :depth]
  from_current = padded[:, :, 1:-1, depth:2 * depth]
  from_next = padded[:, :, 2:, 2 * depth:3 * depth]
  return np.concatenate([from_previous, from_current, from_next], axis=3)
def f(x):  # pylint: disable=invalid-name
  """Offsets the three channel thirds of `x` by -1, 0 and +1 along axis 2.

  Axis 2 is zero-padded by one position on each side; the k-th channel third
  (k = 0, 1, 2) is then read from a window starting at padded position k, so
  the output keeps the shape of `x`.
  """
  # x : [batch, 1, length, depth]
  x = np.pad(x, [(0, 0), (0, 0), (1, 1), (0, 0)],
             mode='constant', constant_values=0.0)
  n_channels = x.shape[-1]
  depth = n_channels // 3
  assert 3 * depth == n_channels, ('Depth must be divisible by 3', depth,
                                   x.shape)
  length = x.shape[2] - 2  # Length of axis 2 before padding.
  pieces = [
      x[:, :, k:k + length, k * depth:(k + 1) * depth] for k in range(3)
  ]
  return np.concatenate(pieces, axis=3)
def serialize_observations_and_actions(  # pylint: disable=invalid-name
    observations,
    actions,
    observation_serializer,
    action_serializer,
    representation_length,
):
  """Turns observations and actions into one padded discrete sequence.

  Args:
    observations: Array (B, T + 1, ...) of observations, where B is the batch
      size and T is the number of timesteps excluding the last observation.
    actions: Array (B, T, ...) of actions.
    observation_serializer: SpaceSerializer for observations.
    action_serializer: SpaceSerializer for actions.
    representation_length: Number of symbols in the serialized sequence. The
      sequence is right-padded up to this number.

  Returns:
    Serialized sequence of shape (B, R) where R = representation_length.
  """
  batch_size, n_timesteps = actions.shape[:2]
  assert observations.shape[:2] == (batch_size, n_timesteps + 1)

  observation_layer = Serialize(serializer=observation_serializer)  # pylint: disable=no-value-for-parameter
  action_layer = Serialize(serializer=action_serializer)  # pylint: disable=no-value-for-parameter
  pipeline = tl.Serial([
      tl.Parallel(observation_layer, action_layer),
      Interleave(),  # pylint: disable=no-value-for-parameter
  ])

  model_inputs = (observations, actions)
  pipeline.init(shapes.signature(model_inputs))
  serialized = pipeline(model_inputs)

  n_symbols = serialized.shape[1]
  assert n_symbols <= representation_length
  # Pad only at the end of axis 1 so the result has exactly
  # representation_length symbols per batch element.
  right_padding = representation_length - n_symbols
  return np.pad(
      serialized,
      pad_width=((0, 0), (0, right_padding)),
      mode='constant',
  )
def _get_initial_state(self, inputs, targets_prefix, batch_size):
  """Get initial state for beam search.

  Args:
    inputs: Model inputs, or None; when None the model is initialized and
      called on the prompt alone.
    targets_prefix: Optional 2-D array of target tokens to condition on, or
      None to start from a single zero token per batch element.
    batch_size: Number of examples in the batch.

  Returns:
    The model state to start decoding from, with the same tree structure as
    the state produced by running `self.model_infer` on the prompt.
  """
  if targets_prefix is None:
    prompt = np.zeros((batch_size, 1), dtype=np.int32)
  else:
    # Shift the prefix right by one: drop its last token and prepend a 0.
    prompt = np.pad(
        targets_prefix[:, :-1], ((0, 0), (1, 0)), mode='constant')

  # Get state prior to running the encoder or incorporating targets_prefix
  if inputs is None:
    signature = ShapeDtype((batch_size, 1), prompt.dtype)
  else:
    signature = (ShapeDtype(inputs.shape, inputs.dtype),
                 ShapeDtype((batch_size, 1), prompt.dtype))
  # Trax's model.init is stateful as opposed to functional. Calling it on an
  # already-existing model instance doesn't work.
  # TODO(lukaszkaiser): add purely functional init to Trax.
  _, initial_state = self.model(mode='predict').init(signature)

  # Incorporate encoder and prompt into state
  _, prompted_state = self.model_infer.pure_fn(
      prompt if inputs is None else (inputs, prompt),
      self.model_weights,
      initial_state,
      jax.random.PRNGKey(0))
  # The `jax.tree_util` functions are used because the top-level
  # `jax.tree_structure` / `jax.tree_leaves` / `jax.tree_unflatten` aliases
  # are deprecated and missing from recent JAX releases.
  state_structure = jax.tree_util.tree_structure(prompted_state)

  if targets_prefix is not None:
    initial_state = prompted_state
  elif self.encoder_idx is not None:
    # Without a prefix, keep only the entries before `encoder_idx` from the
    # prompted state and the fresh initial state for the rest.
    initial_state = (tuple(prompted_state[:self.encoder_idx]) +
                     tuple(initial_state[self.encoder_idx:]))

  # Fix tree structure of the state (there's a tuple vs. list mismatch)
  initial_state = jax.tree_util.tree_unflatten(
      state_structure, jax.tree_util.tree_leaves(initial_state))

  return initial_state
def PadRight(x, n_to_pad, **unused_kwargs):
  """Appends `n_to_pad` zeros at the end of axis 1 of `x`.

  Args:
    x: Array with at least two axes.
    n_to_pad: Number of zero entries appended along axis 1.
    **unused_kwargs: Ignored.

  Returns:
    The padded array; the fill value is a zero of `x`'s own dtype.
  """
  widths = [(0, 0), (0, n_to_pad)]
  # Axes after the first two are left unpadded.
  widths.extend([(0, 0)] * (x.ndim - 2))
  zero = x.dtype.type(0)
  return np.pad(x, widths, mode='constant', constant_values=zero)
def pad_right(x):
  """Appends `n_to_pad` zeros at the end of axis 1 of `x`.

  `n_to_pad` is not a parameter; it is read from the enclosing scope. Axes
  other than axis 1 are left unpadded, and the fill value is a zero of `x`'s
  own dtype.
  """
  widths = [(0, 0), (0, n_to_pad)]
  widths.extend([(0, 0)] * (x.ndim - 2))
  zero = x.dtype.type(0)
  return jnp.pad(x, widths, mode='constant', constant_values=zero)