Example #1
  def forward(self, inputs):
    x, lstm_state = inputs

    # LSTM state consists of c and h.
    c, h = jnp.split(lstm_state, 2, axis=-1)

    # Dense layer on the concatenation of x and h.
    w, b = self.weights
    y = jnp.dot(jnp.concatenate([x, h], axis=-1), w) + b

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = jnp.split(y, 4, axis=-1)

    new_c = c * fastmath.sigmoid(f) + fastmath.sigmoid(i) * jnp.tanh(j)
    new_h = jnp.tanh(new_c) * fastmath.sigmoid(o)
    return new_h, jnp.concatenate([new_c, new_h], axis=-1)
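A minimal sketch of the state packing (assuming plain jax.numpy in place of the fastmath backend used above, with hypothetical sizes): the jnp.split at the top of forward is the inverse of the jnp.concatenate at the bottom.

    import jax.numpy as jnp

    d = 4                                     # hypothetical hidden size
    c = jnp.zeros((2, d))                     # cell state, batch of 2
    h = jnp.ones((2, d))                      # hidden state
    state = jnp.concatenate([c, h], axis=-1)  # packed as in the cell above
    c2, h2 = jnp.split(state, 2, axis=-1)     # unpacked on the next step
    assert (c2 == c).all() and (h2 == h).all()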
Example #2
 def _shard_fn(x):
     axis = _axis_to_shard_heuristic(x.shape)
     if int(x.shape[axis]) % n_shards != 0:
         raise ValueError(
             f'Cannot split x with shape {x.shape} into {n_shards}.')
     # One slice per shard; stack them along a new leading (device) axis.
     split_x = jnp.split(x, n_shards, axis=axis)
     split_x = [split_x[i % n_shards] for i in indices]
     return np.stack(split_x, axis=0)
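For context, a self-contained sketch of this shard-and-stack pattern (the axis heuristic and indices below are illustrative assumptions; in the source they come from the enclosing scope):

    import jax.numpy as jnp
    import numpy as np

    n_shards = 2
    indices = list(range(n_shards))  # hypothetical stand-in for device indices

    def _axis_to_shard_heuristic(shape):
        # Hypothetical heuristic: always shard along axis 0.
        return 0

    x = jnp.ones((4, 6))
    axis = _axis_to_shard_heuristic(x.shape)
    split_x = jnp.split(x, n_shards, axis=axis)
    split_x = [split_x[i % n_shards] for i in indices]
    assert np.stack(split_x, axis=0).shape == (2, 2, 6)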
Example #3
  def forward(self, inputs):
    x, gru_state = inputs

    # Dense layer on the concatenation of x and h.
    w1, b1, w2, b2 = self.weights
    y = jnp.dot(jnp.concatenate([x, gru_state], axis=-1), w1) + b1

    # Update and reset gates.
    u, r = jnp.split(fastmath.sigmoid(y), 2, axis=-1)

    # Candidate.
    c = jnp.dot(jnp.concatenate([x, r * gru_state], axis=-1), w2) + b2

    new_gru_state = u * gru_state + (1 - u) * jnp.tanh(c)
    return new_gru_state, new_gru_state
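A sketch of the gate unpacking (assuming jax.nn.sigmoid in place of fastmath.sigmoid, with hypothetical sizes): one dense output carries both gates, and jnp.split separates them.

    import jax
    import jax.numpy as jnp

    d = 4                                # hypothetical state size
    y = jnp.ones((2, 2 * d))             # packed [update, reset] pre-activations
    u, r = jnp.split(jax.nn.sigmoid(y), 2, axis=-1)
    assert u.shape == r.shape == (2, d)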
Example #4
    def forward(self, x):
        """Executes this layer as part of a forward pass through the model.

        Args:
          x: Tensor of same shape and dtype as the input signature used to
              initialize this layer.

        Returns:
          Tensor of same shape and dtype as the input, except the final
          dimension is the layer's `filters` value, and the second-to-last
          dimension is shrunk if 'VALID' padding is used with a kernel_size
          bigger than one.
        """
        if self._use_bias:
            if not isinstance(self.weights, (tuple, list)):
                raise ValueError(f'Weights should be a (w, b) tuple or list; '
                                 f'instead got: {self.weights}')
            w, b = self.weights
        else:
            w = self.weights

        linear_results_before_shifting = jnp.einsum('...lp,lkpd->...lkd', x, w)
        # TODO(jaszczur): this could be run after padding for better efficiency

        if self._kernel_size == 1:
            # With kernel size 1 we don't have to split or shift anything.
            linear_result = jnp.squeeze(linear_results_before_shifting,
                                        axis=-2)
        else:
            # We computed a result for every "pixel", but each position in the
            # receptive field (there are 'self._kernel_size' such positions)
            # must be shifted by a different amount. The easiest way to do this
            # is to split the tensor into 'self._kernel_size' smaller tensors,
            # shift each one appropriately, and then sum them together.
            split_shifting_linear_results = jnp.split(
                linear_results_before_shifting, self._kernel_size, axis=-2)

            for i in range(self._kernel_size):
                # Each tensor has to be shifted a different amount.
                if self._padding == 'WRAP':
                    # We can shift by padding and cutting. With 'wrap' padding we
                    # essentially have a torus.
                    padding = [(0, 0)
                               for _ in split_shifting_linear_results[i].shape]
                    padding[-3] = ((self._kernel_size - 1) - i, i)
                    split_shifting_linear_results[i] = jnp.pad(
                        split_shifting_linear_results[i], padding, mode='wrap')
                    split_shifting_linear_results[
                        i] = split_shifting_linear_results[i][
                            ..., (self._kernel_size - 1) //
                            2:-(self._kernel_size - 1) // 2, :, :]
                elif self._padding == 'SAME':
                    # We can shift by padding and cutting.
                    padding = [(0, 0)
                               for _ in split_shifting_linear_results[i].shape]
                    padding[-3] = ((self._kernel_size - 1) - i, i)
                    split_shifting_linear_results[i] = jnp.pad(
                        split_shifting_linear_results[i], padding)
                    split_shifting_linear_results[
                        i] = split_shifting_linear_results[i][
                            ..., (self._kernel_size - 1) //
                            2:-(self._kernel_size - 1) // 2, :, :]
                    # TODO(jaszczur): improve efficiency by not padding things to cut
                elif self._padding == 'VALID':
                    # We don't need to shift - just cut the leftmost and rightmost values.
                    cut_left = (self._kernel_size - 1) - i
                    cut_right = split_shifting_linear_results[i].shape[-3] - i
                    split_shifting_linear_results[
                        i] = split_shifting_linear_results[i][
                            ..., cut_left:cut_right, :, :]
                else:
                    raise ValueError(f'Invalid padding {self._padding}')
            # After shifting.
            shifted_linear_results = jnp.concatenate(
                split_shifting_linear_results, axis=-2)
            linear_result = jnp.sum(shifted_linear_results, axis=-2)

        if self._use_bias:
            return linear_result + b
        else:
            return linear_result
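The split-shift-sum trick above can be exercised in isolation (a sketch with assumed shapes, mirroring the 'VALID' branch; seq_len, k, and d are hypothetical):

    import jax.numpy as jnp

    seq_len, k, d = 5, 3, 2
    y = jnp.ones((seq_len, k, d))        # one partial result per (position, tap)
    taps = jnp.split(y, k, axis=-2)      # k tensors of shape (seq_len, 1, d)
    # 'VALID': tap i contributes to output positions [k - 1 - i, seq_len - i).
    taps = [t[(k - 1) - i:seq_len - i] for i, t in enumerate(taps)]
    out = jnp.sum(jnp.concatenate(taps, axis=-2), axis=-2)
    assert out.shape == (seq_len - (k - 1), d)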
Example #5
 def forward(self, inputs):
     return tuple(jnp.split(inputs, self._n_items, self._axis))
Example #6
 def forward(self, inputs):
     """Executes this layer as part of a forward pass through the model."""
     return tuple(jnp.split(inputs, self._n_items, self._axis))
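A usage sketch of the layer's core operation (values are hypothetical; jnp.split returns a list, which the layer converts into a tuple of outputs):

    import jax.numpy as jnp

    x = jnp.ones((2, 6))
    parts = tuple(jnp.split(x, 3, axis=-1))  # n_items=3 along the last axis
    assert len(parts) == 3 and parts[0].shape == (2, 2)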
Example #7
 def _unshard_fn(x):
     y = jax.lax.all_gather(x, 'batch', axis_index_groups=groups)
     # all_gather adds a leading shard axis; split it off, drop it, and
     # re-join the pieces along the axis they were originally sharded on.
     split_y = jnp.split(y, n_shards, axis=0)
     split_y = [jnp.squeeze(sy, axis=0) for sy in split_y]
     axis = _axis_to_shard_heuristic(split_y[0].shape)
     return jnp.concatenate(split_y, axis=axis)
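The reassembly step can be sketched outside the pmap context (assuming the gathered tensor carries a leading shard axis of size n_shards and that the heuristic picks axis 0):

    import jax.numpy as jnp

    n_shards = 2
    y = jnp.ones((n_shards, 3, 4))            # as if produced by all_gather
    split_y = jnp.split(y, n_shards, axis=0)  # n_shards tensors of (1, 3, 4)
    split_y = [jnp.squeeze(sy, axis=0) for sy in split_y]
    assert jnp.concatenate(split_y, axis=0).shape == (n_shards * 3, 4)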
Example #8
 def _f(x, axis=-1):  # pylint: disable=invalid-name
     size = x.shape[axis]
     assert size % 2 == 0, f'axis {axis} of size {size} is not divisible by 2'
     a, b = jnp.split(x, 2, axis)
     return a * fastmath.expit(b)
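This is the gated-linear-unit pattern; a sketch (assuming jax.nn.sigmoid in place of fastmath.expit, with hypothetical shapes):

    import jax
    import jax.numpy as jnp

    x = jnp.ones((2, 8))
    a, b = jnp.split(x, 2, axis=-1)  # halves of the last axis
    out = a * jax.nn.sigmoid(b)      # expit is the logistic sigmoid
    assert out.shape == (2, 4)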