def forward(self, inputs):
  x, lstm_state = inputs
  # LSTM state consists of c and h.
  c, h = jnp.split(lstm_state, 2, axis=-1)
  # Dense layer on the concatenation of x and h.
  w, b = self.weights
  y = jnp.dot(jnp.concatenate([x, h], axis=-1), w) + b
  # i = input_gate, j = new_input, f = forget_gate, o = output_gate
  i, j, f, o = jnp.split(y, 4, axis=-1)
  new_c = c * fastmath.sigmoid(f) + fastmath.sigmoid(i) * jnp.tanh(j)
  new_h = jnp.tanh(new_c) * fastmath.sigmoid(o)
  return new_h, jnp.concatenate([new_c, new_h], axis=-1)
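# The gate arithmetic above can be checked in isolation. A minimal standalone
# sketch with plain jax.numpy, assuming hypothetical sizes d_in=3, d_hidden=4,
# random weights, and jax.nn.sigmoid standing in for fastmath.sigmoid:
import jax
import jax.numpy as jnp

d_in, d_hidden, batch = 3, 4, 2
k_w, k_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(k_w, (d_in + d_hidden, 4 * d_hidden))
b = jnp.zeros((4 * d_hidden,))
x = jax.random.normal(k_x, (batch, d_in))
lstm_state = jnp.zeros((batch, 2 * d_hidden))  # c and h concatenated

c, h = jnp.split(lstm_state, 2, axis=-1)
y = jnp.dot(jnp.concatenate([x, h], axis=-1), w) + b
i, j, f, o = jnp.split(y, 4, axis=-1)
new_c = c * jax.nn.sigmoid(f) + jax.nn.sigmoid(i) * jnp.tanh(j)
new_h = jnp.tanh(new_c) * jax.nn.sigmoid(o)
print(new_h.shape, new_c.shape)  # (2, 4) (2, 4)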
def _shard_fn(x):
  # `n_shards`, `indices` and `_axis_to_shard_heuristic` are closed over from
  # the enclosing scope.
  axis = _axis_to_shard_heuristic(x.shape)
  if int(x.shape[axis]) % n_shards != 0:
    raise ValueError(
        f'Cannot split x with shape {x.shape} into {n_shards}.')
  split_x = jnp.split(x, n_shards, axis=axis)
  # Select one shard per entry of `indices` (wrapping around with modulo) and
  # stack them along a new leading axis.
  split_x = [split_x[i % n_shards] for i in indices]
  return np.stack(split_x, axis=0)
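# A small sketch of the shard step on plain arrays, assuming hypothetical
# values for the closed-over names: n_shards=2, indices=[0, 1, 0, 1] (e.g.
# 4 devices, each pair holding one of the 2 shards), and the axis heuristic
# simplified to "always shard axis 0".
import numpy as np
import jax.numpy as jnp

n_shards = 2
indices = [0, 1, 0, 1]  # hypothetical device -> shard mapping

def shard(x, axis=0):
  if int(x.shape[axis]) % n_shards != 0:
    raise ValueError(f'Cannot split x with shape {x.shape} into {n_shards}.')
  split_x = jnp.split(x, n_shards, axis=axis)
  split_x = [split_x[i % n_shards] for i in indices]
  return np.stack(split_x, axis=0)

x = jnp.arange(12.0).reshape(4, 3)
print(shard(x).shape)  # (4, 2, 3): one (2, 3) shard per device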
def forward(self, inputs):
  x, gru_state = inputs
  # Dense layer on the concatenation of x and h.
  w1, b1, w2, b2 = self.weights
  y = jnp.dot(jnp.concatenate([x, gru_state], axis=-1), w1) + b1
  # Update and reset gates.
  u, r = jnp.split(fastmath.sigmoid(y), 2, axis=-1)
  # Candidate.
  c = jnp.dot(jnp.concatenate([x, r * gru_state], axis=-1), w2) + b2
  new_gru_state = u * gru_state + (1 - u) * jnp.tanh(c)
  return new_gru_state, new_gru_state
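# A standalone sketch of the same GRU update with plain jax.numpy, assuming
# hypothetical sizes d_in=3, d_hidden=4, random weights, and jax.nn.sigmoid
# standing in for fastmath.sigmoid:
import jax
import jax.numpy as jnp

d_in, d_hidden, batch = 3, 4, 2
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
w1 = jax.random.normal(k1, (d_in + d_hidden, 2 * d_hidden))
b1 = jnp.zeros((2 * d_hidden,))
w2 = jax.random.normal(k2, (d_in + d_hidden, d_hidden))
b2 = jnp.zeros((d_hidden,))
x = jax.random.normal(k3, (batch, d_in))
gru_state = jnp.zeros((batch, d_hidden))

y = jnp.dot(jnp.concatenate([x, gru_state], axis=-1), w1) + b1
u, r = jnp.split(jax.nn.sigmoid(y), 2, axis=-1)
c = jnp.dot(jnp.concatenate([x, r * gru_state], axis=-1), w2) + b2
new_gru_state = u * gru_state + (1 - u) * jnp.tanh(c)
print(new_gru_state.shape)  # (2, 4)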
def forward(self, x):
  """Executes this layer as part of a forward pass through the model.

  Args:
    x: Tensor of same shape and dtype as the input signature used to
        initialize this layer.

  Returns:
    Tensor of same shape and dtype as the input, except the final dimension
    is the layer's `filters` value, and the second to last dimension is
    shrunk if 'VALID' padding is used with kernel_size bigger than one.
  """
  if self._use_bias:
    if not isinstance(self.weights, (tuple, list)):
      raise ValueError(f'Weights should be a (w, b) tuple or list; '
                       f'instead got: {self.weights}')
    w, b = self.weights
  else:
    w = self.weights

  linear_results_before_shifting = jnp.einsum('...lp,lkpd->...lkd', x, w)
  # TODO(jaszczur): this could be run after padding for better efficiency

  if self._kernel_size == 1:
    # With kernel size 1 we don't have to split or shift anything.
    linear_result = jnp.squeeze(linear_results_before_shifting, axis=-2)
  else:
    # We computed a result for every "pixel", but each direction from the
    # receptive field (there are 'self._kernel_size' such directions) must be
    # shifted by a different amount. The easiest way to do it is to split
    # the tensor into 'self._kernel_size' smaller tensors, shift each one
    # appropriately, and then sum them together.
    split_shifting_linear_results = jnp.split(
        linear_results_before_shifting, self._kernel_size, axis=-2)

    for i in range(self._kernel_size):
      # Each tensor has to be shifted a different amount.
      if self._padding == 'WRAP':
        # We can shift by padding and cutting. With 'wrap' padding we
        # essentially have a torus.
        padding = [(0, 0) for _ in split_shifting_linear_results[i].shape]
        padding[-3] = ((self._kernel_size - 1) - i, i)
        split_shifting_linear_results[i] = jnp.pad(
            split_shifting_linear_results[i], padding, mode='wrap')
        split_shifting_linear_results[i] = split_shifting_linear_results[i][
            ..., (self._kernel_size - 1) // 2:-(self._kernel_size - 1) // 2,
            :, :]
      elif self._padding == 'SAME':
        # We can shift by padding and cutting.
        padding = [(0, 0) for _ in split_shifting_linear_results[i].shape]
        padding[-3] = ((self._kernel_size - 1) - i, i)
        split_shifting_linear_results[i] = jnp.pad(
            split_shifting_linear_results[i], padding)
        split_shifting_linear_results[i] = split_shifting_linear_results[i][
            ..., (self._kernel_size - 1) // 2:-(self._kernel_size - 1) // 2,
            :, :]
        # TODO(jaszczur): improve efficiency by not padding things to cut
      elif self._padding == 'VALID':
        # We don't need to shift - just cut the leftmost and rightmost values.
        cut_left = (self._kernel_size - 1) - i
        cut_right = split_shifting_linear_results[i].shape[-3] - i
        split_shifting_linear_results[i] = split_shifting_linear_results[i][
            ..., cut_left:cut_right, :, :]
      else:
        raise ValueError(f'Invalid padding {self._padding}')

    # After shifting.
    shifted_linear_results = jnp.concatenate(
        split_shifting_linear_results, axis=-2)
    linear_result = jnp.sum(shifted_linear_results, axis=-2)

  if self._use_bias:
    return linear_result + b
  else:
    return linear_result
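# The split/shift/sum trick used in the 'SAME' branch above can be
# sanity-checked on an ordinary (weight-shared) 1-D convolution. A toy sketch
# with an assumed length-5 signal and size-3 kernel, compared against
# jnp.convolve; the real layer additionally uses per-position weights:
import jax.numpy as jnp

x = jnp.arange(5.0)              # toy signal of length 5
k = jnp.array([1.0, 2.0, 3.0])   # kernel of size 3
kernel_size = k.shape[0]

# Reference: centered cross-correlation (flip the kernel for jnp.convolve).
ref = jnp.convolve(x, k[::-1], mode='same')

# Shift-and-sum: every kernel tap contributes a shifted copy of the signal.
acc = jnp.zeros_like(x)
for i in range(kernel_size):
  contrib = x * k[i]
  padded = jnp.pad(contrib, (kernel_size - 1 - i, i))
  acc = acc + padded[(kernel_size - 1) // 2:-(kernel_size - 1) // 2]
print(jnp.allclose(ref, acc))    # True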
def forward(self, inputs):
  return tuple(jnp.split(inputs, self._n_items, self._axis))
def forward(self, inputs): """Executes this layer as part of a forward pass through the model.""" return tuple(jnp.split(inputs, self._n_items, self._axis))
def _unshard_fn(x):
  # Gather the shards held by all devices in the group, then reassemble them
  # along the axis chosen by the sharding heuristic.
  y = jax.lax.all_gather(x, 'batch', axis_index_groups=groups)
  split_y = jnp.split(y, n_shards, axis=0)
  split_y = [jnp.squeeze(sy, axis=0) for sy in split_y]
  axis = _axis_to_shard_heuristic(split_y[0].shape)
  return jnp.concatenate(split_y, axis=axis)
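# A minimal single-host sketch of the unshard step, assuming the output of
# jax.lax.all_gather is emulated by an array that already carries the shard
# axis in front (shape (n_shards, ...)), and the same simplified "axis 0"
# heuristic as in the shard sketch above:
import jax.numpy as jnp

n_shards = 2

def unshard(gathered, axis=0):
  split_y = jnp.split(gathered, n_shards, axis=0)
  split_y = [jnp.squeeze(sy, axis=0) for sy in split_y]
  return jnp.concatenate(split_y, axis=axis)

gathered = jnp.arange(12.0).reshape(2, 2, 3)  # two (2, 3) shards
print(unshard(gathered).shape)                # (4, 3)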
def _f(x, axis=-1):  # pylint: disable=invalid-name
  size = x.shape[axis]
  assert size % 2 == 0, f'axis {axis} of size {size} is not divisible by 2'
  a, b = jnp.split(x, 2, axis)
  return a * fastmath.expit(b)
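# A quick numeric check of the gating above, assuming jax.nn.sigmoid as a
# stand-in for fastmath.expit (both compute the logistic function):
import jax
import jax.numpy as jnp

x = jnp.arange(8.0).reshape(2, 4)   # last axis has even size 4
a, b = jnp.split(x, 2, axis=-1)
out = a * jax.nn.sigmoid(b)
print(out.shape)                    # (2, 2)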