def LocallyConvDense(n_modules, n_units, kernel_size=1, length_kernel_size=1):
  """Layer using local convolutions for approximation of Dense layer.

  The layer splits the last axis of a tensor into `n_modules`, then runs
  a convolution on all those modules, and concatenates their results.
  It is similar to LocallyConnectedDense above, but shares weights.

  Args:
    n_modules: How many modules (pixels) the input and output should be
        split into for processing.
    n_units: How many outputs (filters) each module should generate.
    kernel_size: The size of the kernel to be used.
    length_kernel_size: If > 1, also do causal convolution on the previous
        axis, which is often the sentence length in sequence models.

  Returns:
    LocallyConvDense base.Layer.
  """
  if n_modules == 1:
    return tl.Dense(n_units)
  if kernel_size % 2 != 1:
    raise ValueError('Currently we only handle odd kernel sizes.')
  half = (kernel_size - 1) // 2
  pad_widths = [[0, 0], [length_kernel_size - 1, 0], [half, half], [0, 0]]
  return tl.Serial(
      tl.SplitLastAxis(n_modules),
      tl.Fn('Pad', lambda x: jnp.pad(x, pad_width=pad_widths)),
      tl.Conv(n_units, kernel_size=(length_kernel_size, kernel_size)),
      tl.MergeLastTwoAxes())
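# Hedged usage sketch (not from the original source), assuming the surrounding
# module imports `trax.layers as tl` and `trax.fastmath.numpy as jnp`. A toy
# input with 4 modules of 16 features each is mapped to 4 modules of 8 units.
import numpy as np
from trax import shapes

layer = LocallyConvDense(n_modules=4, n_units=8, kernel_size=3)
x = np.zeros((2, 10, 4 * 16), dtype=np.float32)  # (batch, length, features)
layer.init(shapes.signature(x))
y = layer(x)
# y.shape == (2, 10, 4 * 8)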
def pad(z):
  # Nested helper: zero-pads the leading (batch) axis so it divides evenly by
  # the number of devices; `self._n_devices` and `remainder` come from the
  # enclosing scope.
  pad_widths = [(0, 0)] * len(z.shape)
  pad_widths[0] = (0, self._n_devices - remainder)
  return jnp.pad(z, pad_widths, mode='constant',
                 constant_values=z.dtype.type(0))
def beam_init(batch_size, beam_size, max_decode_len, cache, start_tokens=None):
  """Initializes the beam search state data structure."""
  cur_index0 = jnp.array(0)
  live_logprobs0 = jnp.tile(
      jnp.array([0.0] + [NEG_INF] * (beam_size - 1)),
      [batch_size, 1])
  finished_scores0 = jnp.ones((batch_size, beam_size)) * NEG_INF
  if start_tokens is None:
    live_seqs0 = jnp.zeros(
        (batch_size, beam_size, max_decode_len), jnp.int32)
  else:
    live_seqs0 = add_beam_dim(
        np.pad(start_tokens[:, None],
               ((0, 0), (0, max_decode_len - 1)), mode='constant'),
        beam_size)
  finished_seqs0 = jnp.zeros(
      (batch_size, beam_size, max_decode_len), jnp.int32)
  finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_)
  # add beam dimension to attention cache pytree elements
  beam_cache0 = jax.tree_map(lambda x: add_beam_dim(x, beam_size), cache)
  return BeamState(cur_index=cur_index0,
                   live_logprobs=live_logprobs0,
                   finished_scores=finished_scores0,
                   live_seqs=live_seqs0,
                   finished_seqs=finished_seqs0,
                   finished_flags=finished_flags0,
                   cache=beam_cache0)
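# Hedged sketch (not from the original source): what the start-token branch of
# beam_init does, written with plain numpy/jax ops and assuming add_beam_dim
# broadcasts a new beam axis at position 1. start_tokens of shape (batch,)
# becomes live sequences of shape (batch, beam_size, max_decode_len), with
# only position 0 filled in.
import numpy as np
import jax.numpy as jnp

batch_size, beam_size, max_decode_len = 2, 3, 5
start_tokens = np.array([7, 9], dtype=np.int32)
seqs = np.pad(start_tokens[:, None],
              ((0, 0), (0, max_decode_len - 1)), mode='constant')
live_seqs0 = jnp.broadcast_to(seqs[:, None, :],
                              (batch_size, beam_size, max_decode_len))
# live_seqs0.shape == (2, 3, 5); live_seqs0[:, :, 0] holds the start tokens.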
def shift_right(x):
  # Shift the sequence axis right by one: insert `cls_id` (taken from the
  # enclosing scope) at the start and drop the last position.
  pad_widths = [(0, 0)] * len(x.shape)
  pad_widths[1] = (1, 0)
  padded = jnp.pad(x, pad_widths, mode='constant',
                   constant_values=x.dtype.type(cls_id))
  return padded[:, :-1]
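# Hedged usage sketch (not from the original source): `cls_id` is a free
# variable in shift_right above, so we assume a hypothetical value of 1 just
# to show the pad-then-drop-last effect on a toy batch.
import jax.numpy as jnp

cls_id = 1  # Hypothetical class/start token id.
tokens = jnp.array([[5, 6, 7],
                    [8, 9, 10]], dtype=jnp.int32)
shifted = shift_right(tokens)
# shifted == [[1, 5, 6],
#             [1, 8, 9]]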
def _zero_pad(x, pad, axis):
  """Helper for jnp.pad with 0s for single-axis case."""
  pad_widths = [(0, 0)] * len(x.shape)
  pad_widths[axis] = pad  # Padding on axis.
  return jnp.pad(x, pad_widths, mode='constant',
                 constant_values=x.dtype.type(0))
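# Hedged usage sketch (not from the original source): zero-pad a (2, 3) array
# by one row before and two rows after along axis 0.
import jax.numpy as jnp

x = jnp.ones((2, 3), dtype=jnp.float32)
y = _zero_pad(x, (1, 2), axis=0)
# y.shape == (5, 3); the first row and the last two rows are zeros.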
def forward(self, x):
  assert self._padding == 'VALID'
  # Left pad with 0s. Applying an unmasked valid convolution on top of this
  # yields a causal convolution.
  # TODO(ddohan): Support strided and dilated convolutions.
  rate = 1
  effective_kernel_size = int((self._kernel_size[0] - 1) * rate + 1)
  pad = effective_kernel_size - 1
  x_leftpad = (
      jnp.pad(x, pad_width=[[0, 0], [pad, 0], [0, 0]], mode='constant'))
  return super().forward(x_leftpad)
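# Hedged sketch (not from the original source): the same left-padding written
# with plain jnp.pad, assuming kernel_size = 3 and an input of shape
# (batch, length, features). Padding kernel_size - 1 = 2 zeros on the left
# means a 'VALID' convolution never sees positions to the right of the
# current one.
import jax.numpy as jnp

kernel_size = 3
x = jnp.ones((4, 10, 8))  # (batch, length, features)
pad = kernel_size - 1
x_leftpad = jnp.pad(x, [[0, 0], [pad, 0], [0, 0]], mode='constant')
# x_leftpad.shape == (4, 12, 8); a VALID conv with kernel 3 returns length 10.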
def f(x):  # pylint: disable=invalid-name
  # x : [batch, 1, length, depth]
  x = jnp.pad(x, [(0, 0), (0, 0), (1, 1), (0, 0)],
              mode='constant', constant_values=0.0)
  depth = x.shape[-1] // 3
  assert 3 * depth == x.shape[-1], ('Depth must be divisible by 3', depth,
                                    x.shape)
  xs = [
      x[:, :, :-2, :depth],
      x[:, :, 1:-1, depth:2 * depth],
      x[:, :, 2:, 2 * depth:3 * depth]
  ]
  return jnp.concatenate(xs, axis=3)
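# Hedged usage sketch (not from the original source): f keeps the input shape
# but, for every position, concatenates the left neighbor's first third of
# channels, the current position's middle third, and the right neighbor's
# last third.
import jax.numpy as jnp

x = jnp.arange(2 * 1 * 5 * 6, dtype=jnp.float32).reshape(2, 1, 5, 6)
y = f(x)
# y.shape == (2, 1, 5, 6): length is restored by the (1, 1) pad on axis 2,
# and the three depth-2 slices are concatenated back to depth 6 on axis 3.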
def pure_fn(self, x, weights, state, rng, use_cache=False):
  """Calls self.sublayer.pure_fn in an accelerated way."""
  # Check if we can divide x evenly across devices.
  remainder = x.shape[0] % self._n_devices
  if remainder == 0:  # If yes, run the accelerated sublayer.pure_fn.
    return self._jit_pure_fn(x, weights, state, rng)
  # If not, pad first.
  pad_widths = [(0, 0)] * len(x.shape)
  pad_widths[0] = (0, self._n_devices - remainder)
  padded_x = jnp.pad(x, pad_widths, mode='constant',
                     constant_values=x.dtype.type(0))
  # Run and un-pad.
  padded_y, state = self._jit_pure_fn(padded_x, weights, state, rng)
  return padded_y[:x.shape[0]], state
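# Hedged sketch (not from the original source): the pad-and-unpad idea from
# pure_fn above as a standalone, hypothetical helper, with `n_devices` passed
# in explicitly instead of read from `self`.
import jax.numpy as jnp

def pad_batch_to_devices(x, n_devices):
  """Zero-pads the leading axis so it divides evenly by `n_devices`."""
  remainder = x.shape[0] % n_devices
  if remainder == 0:
    return x
  pad_widths = [(0, 0)] * len(x.shape)
  pad_widths[0] = (0, n_devices - remainder)
  return jnp.pad(x, pad_widths, mode='constant',
                 constant_values=x.dtype.type(0))

# pad_batch_to_devices(jnp.ones((10, 3)), 4).shape == (12, 3)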
def _get_initial_state(self, inputs, targets_prefix, batch_size):
  """Get initial state for beam search."""
  if targets_prefix is None:
    prompt = np.zeros((batch_size, 1), dtype=np.int32)
  else:
    prompt = np.pad(
        targets_prefix[:, :-1], ((0, 0), (1, 0)), mode='constant')

  # Get state prior to running the encoder or incorporating targets_prefix.
  if inputs is None:
    signature = ShapeDtype((batch_size, 1), prompt.dtype)
  else:
    signature = (ShapeDtype(inputs.shape, inputs.dtype),
                 ShapeDtype((batch_size, 1), prompt.dtype))
  # Trax's model.init is stateful as opposed to functional. Calling it on an
  # already-existing model instance doesn't work.
  # TODO(lukaszkaiser): add purely functional init to Trax.
  _, initial_state = self.model(mode='predict').init(signature)

  # Incorporate encoder and prompt into state.
  _, prompted_state = self.model_infer.pure_fn(
      prompt if inputs is None else (inputs, prompt),
      self.model_weights,
      initial_state,
      jax.random.PRNGKey(0))
  state_structure = jax.tree_structure(prompted_state)

  if targets_prefix is not None:
    initial_state = prompted_state
  elif self.encoder_idx is not None:
    initial_state = (tuple(prompted_state[:self.encoder_idx]) +
                     tuple(initial_state[self.encoder_idx:]))

  # Fix tree structure of the state (there's a tuple vs. list mismatch).
  initial_state = jax.tree_unflatten(
      state_structure, trax.fastmath.tree_leaves(initial_state))

  return initial_state
def forward(self, x):
  """Executes this layer as part of a forward pass through the model.

  Args:
    x: Tensor of same shape and dtype as the input signature used to
        initialize this layer.

  Returns:
    Tensor of same shape and dtype as the input, except the final dimension
    is the layer's `filters` value, and the second to last dimension is
    shrunk if 'VALID' padding is used with kernel_size bigger than one.
  """
  if self._use_bias:
    if not isinstance(self.weights, (tuple, list)):
      raise ValueError(f'Weights should be a (w, b) tuple or list; '
                       f'instead got: {self.weights}')
    w, b = self.weights
  else:
    w = self.weights

  linear_results_before_shifting = jnp.einsum('...lp,lkpd->...lkd', x, w)
  # TODO(jaszczur): this could be run after padding for better efficiency

  if self._kernel_size == 1:
    # With kernel size 1 we don't have to split or shift anything.
    linear_result = jnp.squeeze(linear_results_before_shifting, axis=-2)
  else:
    # We computed a result for every "pixel", but each direction from the
    # receptive field (there are 'self._kernel_size' such directions) must be
    # shifted by a different amount. The easiest way to do it is to split
    # the tensor to 'self._kernel_size' smaller tensors, shift each one
    # appropriately, and then sum them together.
    split_shifting_linear_results = jnp.split(
        linear_results_before_shifting, self._kernel_size, axis=-2)

    for i in range(self._kernel_size):
      # Each tensor has to be shifted a different amount.
      if self._padding == 'WRAP':
        # We can shift by padding and cutting. With 'wrap' padding we
        # essentially have a torus.
        padding = [(0, 0) for _ in split_shifting_linear_results[i].shape]
        padding[-3] = ((self._kernel_size - 1) - i, i)
        split_shifting_linear_results[i] = jnp.pad(
            split_shifting_linear_results[i], padding, mode='wrap')
        split_shifting_linear_results[i] = split_shifting_linear_results[i][
            ..., (self._kernel_size - 1) // 2:-(self._kernel_size - 1) // 2,
            :, :]
      elif self._padding == 'SAME':
        # We can shift by padding and cutting.
        padding = [(0, 0) for _ in split_shifting_linear_results[i].shape]
        padding[-3] = ((self._kernel_size - 1) - i, i)
        split_shifting_linear_results[i] = jnp.pad(
            split_shifting_linear_results[i], padding)
        split_shifting_linear_results[i] = split_shifting_linear_results[i][
            ..., (self._kernel_size - 1) // 2:-(self._kernel_size - 1) // 2,
            :, :]
        # TODO(jaszczur): improve efficiency by not padding things to cut
      elif self._padding == 'VALID':
        # We don't need to shift - just cut the leftmost and rightmost values.
        cut_left = (self._kernel_size - 1) - i
        cut_right = split_shifting_linear_results[i].shape[-3] - i
        split_shifting_linear_results[i] = split_shifting_linear_results[i][
            ..., cut_left:cut_right, :, :]
      else:
        raise ValueError(f'Invalid padding {self._padding}')

    # After shifting.
    shifted_linear_results = jnp.concatenate(
        split_shifting_linear_results, axis=-2)
    linear_result = jnp.sum(shifted_linear_results, axis=-2)

  if self._use_bias:
    return linear_result + b
  else:
    return linear_result
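# Hedged sketch (not from the original source): the shift-and-sum trick used
# above for 'SAME' padding, shown on a toy tensor with kernel_size = 3. Each
# of the 3 per-direction results is padded by a different (left, right) amount
# on the spatial axis and the common centre is cut out, so summing aligns the
# receptive-field contributions.
import jax.numpy as jnp

kernel_size = 3
length, depth = 5, 2
# Per-"pixel", per-direction results: shape (length, kernel_size, depth).
r = jnp.ones((length, kernel_size, depth))
parts = jnp.split(r, kernel_size, axis=-2)   # 3 tensors of shape (5, 1, 2)
for i in range(kernel_size):
  padding = [(0, 0)] * parts[i].ndim
  padding[-3] = ((kernel_size - 1) - i, i)   # shift for direction i
  shifted = jnp.pad(parts[i], padding)
  parts[i] = shifted[(kernel_size - 1) // 2:-(kernel_size - 1) // 2, :, :]
result = jnp.sum(jnp.concatenate(parts, axis=-2), axis=-2)
# result.shape == (5, 2); border positions sum fewer than 3 contributions.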
def _zero_pad(x, pad, axis):  # pylint: disable=invalid-name
  """Helper for jnp.pad with 0s for single-axis case."""
  pad_widths = [(0, 0)] * len(x.shape)
  pad_widths[axis] = pad  # Padding on axis.
  return jnp.pad(x, pad_widths, mode='constant')
def pad_right(x):
  # Pad the second (length) axis on the right by `n_to_pad` positions
  # (taken from the enclosing scope); other axes are left unchanged.
  pad_widths = [(0, 0), (0, n_to_pad)] + [(0, 0)] * (x.ndim - 2)
  return jnp.pad(
      x, pad_widths, mode='constant', constant_values=x.dtype.type(0))
def pad_to_chunk_len(v):
  # Zero-pad axis 2 up to `chunk_len` (taken from the enclosing scope).
  width = [(0, 0)] * v.ndim
  width[2] = (0, chunk_len - v.shape[2])
  return jnp.pad(v, width, mode='constant', constant_values=0.0)
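# Hedged usage sketch (not from the original source): `chunk_len` is a free
# variable in pad_to_chunk_len above, so we assume a hypothetical value of 8
# and pad a tensor whose axis 2 is shorter than that.
import jax.numpy as jnp

chunk_len = 8  # Hypothetical chunk length.
v = jnp.ones((2, 4, 5, 16))
padded = pad_to_chunk_len(v)
# padded.shape == (2, 4, 8, 16); the last 3 positions on axis 2 are zeros.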