Ejemplo n.º 1
  def forward(self, inputs):
    """Runs a single LSTM step.

    Args:
      inputs: Pair `(x, lstm_state)`. `lstm_state` holds the cell state and
        the hidden state concatenated along the last axis.

    Returns:
      Pair `(new_h, new_state)`, where `new_state` is the new cell state and
      `new_h` concatenated along the last axis.
    """
    x, lstm_state = inputs
    kernel, bias = self.weights
    cell, hidden = jnp.split(lstm_state, 2, axis=-1)

    # One dense projection of [x, hidden] yields all four gate
    # pre-activations, ordered: input gate, new input, forget gate, output gate.
    gates = jnp.dot(jnp.concatenate([x, hidden], axis=-1), kernel) + bias
    in_gate, candidate, forget_gate, out_gate = jnp.split(gates, 4, axis=-1)

    next_cell = (cell * fastmath.sigmoid(forget_gate)
                 + fastmath.sigmoid(in_gate) * jnp.tanh(candidate))
    next_hidden = fastmath.sigmoid(out_gate) * jnp.tanh(next_cell)
    next_state = jnp.concatenate([next_cell, next_hidden], axis=-1)
    return next_hidden, next_state
Ejemplo n.º 2
 def _shard_fn(x):
     """Splits `x` into `n_shards` pieces and stacks them in `indices` order.

     The split axis is picked by `_axis_to_shard_heuristic(x.shape)`. Entry `k`
     of the result is piece `indices[k] % n_shards`, and the pieces are stacked
     along a new leading axis.

     Raises:
       ValueError: if the chosen axis length is not a multiple of `n_shards`.
     """
     shard_axis = _axis_to_shard_heuristic(x.shape)
     if int(x.shape[shard_axis]) % n_shards:
         raise ValueError(
             f'Cannot split x with shape {x.shape} into {n_shards}.')
     pieces = jnp.split(x, n_shards, axis=shard_axis)
     reordered = [pieces[idx % n_shards] for idx in indices]
     # NOTE(review): stacks with NumPy (not jnp), so the result is a host
     # NumPy array; confirm this is intended.
     return np.stack(reordered, axis=0)
Ejemplo n.º 3
  def forward(self, inputs):
    """Runs a single GRU step.

    Args:
      inputs: Pair `(x, gru_state)`.

    Returns:
      Pair `(new_state, new_state)`: the updated state is both the step output
      and the carried state.
    """
    x, gru_state = inputs
    gate_w, gate_b, cand_w, cand_b = self.weights

    # Update and reset gates come from one dense layer over [x, state].
    gate_logits = jnp.dot(jnp.concatenate([x, gru_state], axis=-1),
                          gate_w) + gate_b
    update, reset = jnp.split(fastmath.sigmoid(gate_logits), 2, axis=-1)

    # The candidate sees the state scaled by the reset gate.
    reset_state = reset * gru_state
    candidate = jnp.dot(jnp.concatenate([x, reset_state], axis=-1),
                        cand_w) + cand_b

    # Blend the old state with tanh(candidate) using the update gate.
    new_state = update * gru_state + (1 - update) * jnp.tanh(candidate)
    return new_state, new_state
Ejemplo n.º 4
    def forward(self, x):
        """Executes this layer as part of a forward pass through the model.

        Args:
            x: Tensor of same shape and dtype as the input signature used to
                initialize this layer.

        Returns:
            Tensor of same shape and dtype as the input, except the final
            dimension is the layer's `filters` value, and the second to last
            dimension is shrunk if 'VALID' padding is used with kernel_size
            bigger than one.

        Raises:
            ValueError: If `self._use_bias` is true but `self.weights` is not a
                `(w, b)` tuple or list, or if `self._padding` is not one of
                'WRAP', 'SAME', 'VALID' (only checked when kernel_size > 1).
        """
        if self._use_bias:
            if not isinstance(self.weights, (tuple, list)):
                raise ValueError(f'Weights should be a (w, b) tuple or list; '
                                 f'instead got: {self.weights}')
            w, b = self.weights
        else:
            w = self.weights

        # One result per (position l, kernel offset k) pair, before shifting:
        # '...lp,lkpd->...lkd' uses a separate weight slice for every position.
        linear_results_before_shifting = jnp.einsum('...lp,lkpd->...lkd', x, w)
        # TODO(jaszczur): this could be run after padding for better efficiency

        if self._kernel_size == 1:
            # With kernel size 1 we don't have to split or shift anything.
            linear_result = jnp.squeeze(linear_results_before_shifting, axis=-2)
        else:
            # We computed a result for every "pixel", but each direction from the
            # receptive field (there are 'self._kernel_size' such directions) must
            # be shifted by a different amount. The easiest way to do it is to
            # split the tensor to 'self._kernel_size' smaller tensors, shift each
            # one appropriately, and then sum them together.
            split_shifting_linear_results = jnp.split(
                linear_results_before_shifting, self._kernel_size, axis=-2)

            # Padding adds kernel_size - 1 positions in total; these slice bounds
            # trim them off again. `-(k - 1) // 2` floors a negative number, so
            # the two cuts always sum to k - 1, including for even kernel sizes.
            cut_start = (self._kernel_size - 1) // 2
            cut_stop = -(self._kernel_size - 1) // 2

            for i in range(self._kernel_size):
                # Each tensor has to be shifted a different amount.
                piece = split_shifting_linear_results[i]
                if self._padding in ('WRAP', 'SAME'):
                    # We can shift by padding and cutting. With 'wrap' padding we
                    # essentially have a torus; 'SAME' pads with zeros.
                    mode = 'wrap' if self._padding == 'WRAP' else 'constant'
                    padding = [(0, 0)] * piece.ndim
                    padding[-3] = ((self._kernel_size - 1) - i, i)
                    piece = jnp.pad(piece, padding, mode=mode)
                    piece = piece[..., cut_start:cut_stop, :, :]
                    # TODO(jaszczur): improve efficiency by not padding things to cut
                elif self._padding == 'VALID':
                    # No shift needed - just cut the leftmost and rightmost values.
                    cut_left = (self._kernel_size - 1) - i
                    cut_right = piece.shape[-3] - i
                    piece = piece[..., cut_left:cut_right, :, :]
                else:
                    raise ValueError(f'Invalid padding {self._padding}')
                split_shifting_linear_results[i] = piece
            # After shifting.
            shifted_linear_results = jnp.concatenate(
                split_shifting_linear_results, axis=-2)
            linear_result = jnp.sum(shifted_linear_results, axis=-2)

        if self._use_bias:
            return linear_result + b
        return linear_result
Ejemplo n.º 5
 def forward(self, inputs):
     """Splits `inputs` into `self._n_items` pieces along `self._axis`.

     Returns:
       A tuple holding the pieces produced by `jnp.split`, in order.
     """
     pieces = jnp.split(inputs, self._n_items, axis=self._axis)
     return tuple(pieces)
Ejemplo n.º 6
 def forward(self, inputs):
     """Executes this layer as part of a forward pass through the model.

     Cuts `inputs` into `self._n_items` pieces along `self._axis` using
     `jnp.split` and returns the pieces as a tuple, in order.
     """
     return (*jnp.split(inputs, self._n_items, axis=self._axis),)
Ejemplo n.º 7
 def _unshard_fn(x):
     """Gathers `x` over the 'batch' named axis and concatenates the shards.

     `jax.lax.all_gather` (restricted to `groups`) adds a leading axis. That
     axis is split into `n_shards` slices, each slice drops the leading axis,
     and the slices are concatenated along the axis picked by
     `_axis_to_shard_heuristic` for a single shard's shape.
     """
     gathered = jax.lax.all_gather(x, 'batch', axis_index_groups=groups)
     shards = [jnp.squeeze(part, axis=0)
               for part in jnp.split(gathered, n_shards, axis=0)]
     concat_axis = _axis_to_shard_heuristic(shards[0].shape)
     return jnp.concatenate(shards, axis=concat_axis)
Ejemplo n.º 8
 def _f(x, axis=-1):  # pylint: disable=invalid-name
     """Splits `x` in half along `axis` and gates the first half by the second.

     Args:
       x: Array whose size along `axis` must be even.
       axis: Axis to split along; defaults to the last axis.

     Returns:
       `a * fastmath.expit(b)`, where `a` and `b` are the first and second
       halves of `x` along `axis`.

     Raises:
       ValueError: if `x.shape[axis]` is odd.
     """
     size = x.shape[axis]
     # Explicit check rather than `assert`, which is stripped under `python -O`.
     if size % 2 != 0:
         raise ValueError(f'axis {axis} of size {size} is not divisible by 2')
     a, b = jnp.split(x, 2, axis)
     return a * fastmath.expit(b)