Example #1
File: conv.py Project: chjort/elegy
    def call(self, inputs: np.ndarray) -> np.ndarray:
        """
        Connects the ``ConvND`` layer.

        Args:
            inputs: An array of rank ``num_spatial_dims + 2`` with shape
                ``[batch, spatial_dims, C]``.

        Returns:
            An array of rank ``num_spatial_dims + 2`` with shape
                ``[batch, spatial_dims, output_channels]``.
        """
        required_rank = self.num_spatial_dims + 2
        if inputs.ndim != required_rank:
            raise ValueError(
                f"Input to ConvND needs to have rank {required_rank}, "
                f"but input has shape {inputs.shape}."
            )

        w_shape = self.kernel_shape + (
            inputs.shape[self.channel_index],
            self.output_channels,
        )

        if self.mask is not None and self.mask.shape != w_shape:
            raise ValueError(
                "Mask needs to have the same shape as weights. "
                f"Shapes are: {self.mask.shape}, {w_shape}"
            )

        w_init = self.w_init
        if w_init is None:
            fan_in = np.prod(w_shape[:-1])
            stddev = 1.0 / np.sqrt(fan_in)
            w_init = initializers.TruncatedNormal(stddev=stddev)
        w = hooks.get_parameter("w", w_shape, inputs.dtype, initializer=w_init)

        if self.mask is not None:
            w *= self.mask

        out = lax.conv_general_dilated(
            inputs,
            w,
            window_strides=self.stride,
            padding=self.padding,
            lhs_dilation=self.lhs_dilation,
            rhs_dilation=self.kernel_dilation,
            dimension_numbers=self.dimension_numbers,
        )

        if self.with_bias:
            if self.channel_index == -1:
                bias_shape = (self.output_channels,)
            else:
                bias_shape = (self.output_channels,) + (1,) * self.num_spatial_dims
            b = hooks.get_parameter(
                "b", bias_shape, inputs.dtype, initializer=self.b_init
            )
            b = jnp.broadcast_to(b, out.shape)
            out = out + b

        return out
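
A minimal, self-contained sketch (independent of elegy's hooks and module state) of how `lax.conv_general_dilated` consumes the `[batch, spatial_dims, C]` layout together with a kernel shaped `kernel_shape + (in_channels, output_channels)`, as built above. The concrete shapes and the NHWC/HWIO dimension numbers are illustrative assumptions:

import numpy as np
from jax import lax

# NHWC input and HWIO kernel: matches the w_shape construction above
# when channel_index == -1.
x = np.ones((4, 32, 32, 3), dtype=np.float32)  # [N, H, W, C]
w = np.ones((3, 3, 3, 16), dtype=np.float32)   # kernel_shape (3, 3) + (C_in, C_out)

out = lax.conv_general_dilated(
    x,
    w,
    window_strides=(1, 1),
    padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)
print(out.shape)  # (4, 32, 32, 16) -> [batch, spatial_dims, output_channels]
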
Example #2
    def call(self, inputs: np.ndarray) -> np.ndarray:
        """"""
        if not inputs.shape:
            raise ValueError("Input must not be scalar.")

        # Infer the input size from the last dimension and record it on the module.
        input_size = self.input_size = inputs.shape[-1]
        output_size = self.output_size
        dtype = inputs.dtype

        w_init = self.w_init

        if w_init is None:
            stddev = 1.0 / np.sqrt(self.input_size)
            w_init = TruncatedNormal(stddev=stddev)

        w = hooks.get_parameter("w", [input_size, output_size],
                                dtype,
                                initializer=w_init)

        out = jnp.dot(inputs, w)

        if self.with_bias:
            b = hooks.get_parameter("b", [self.output_size],
                                    dtype,
                                    initializer=self.b_init)
            b = jnp.broadcast_to(b, out.shape)
            out = out + b

        return out
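
Stripped of parameter bookkeeping, the layer above is `inputs @ w + b`. A minimal sketch of the same computation in plain JAX, using a truncated normal with the same `1 / sqrt(input_size)` scale as the default initializer above (the exact rescaling inside the library's `TruncatedNormal` may differ slightly):

import jax
import jax.numpy as jnp

input_size, output_size = 128, 64
key = jax.random.PRNGKey(0)

# Fan-in-scaled truncated normal, mirroring the default w_init above.
stddev = 1.0 / jnp.sqrt(input_size)
w = stddev * jax.random.truncated_normal(key, -2.0, 2.0, (input_size, output_size))
b = jnp.zeros((output_size,))

x = jnp.ones((32, input_size))
out = jnp.dot(x, w) + b  # bias broadcasts over the leading batch dimension
print(out.shape)         # (32, 64)
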
Example #3
    def call(
        self,
        inputs: jnp.ndarray,
        scale: Optional[jnp.ndarray] = None,
        offset: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        """Connects the layer norm.

        Args:
          inputs: An array, where the data format is ``[N, ..., C]``.
          scale: An array up to n-D. The shape of this tensor must be broadcastable
            to the shape of ``inputs``. This is the scale applied to the normalized
            inputs. This cannot be passed in if the module was constructed with
            ``create_scale=True``.
          offset: An array up to n-D. The shape of this tensor must be broadcastable
            to the shape of ``inputs``. This is the offset applied to the normalized
            inputs. This cannot be passed in if the module was constructed with
            ``create_offset=True``.

        Returns:
          The array, normalized.
        """
        if self.create_scale and scale is not None:
            raise ValueError("Cannot pass `scale` at call time if `create_scale=True`.")
        if self.create_offset and offset is not None:
            raise ValueError(
                "Cannot pass `offset` at call time if `create_offset=True`."
            )

        axis = self.axis
        if isinstance(axis, slice):
            axis = tuple(range(inputs.ndim)[axis])

        mean = jnp.mean(inputs, axis=axis, keepdims=True)
        variance = jnp.var(inputs, axis=axis, keepdims=True)

        param_shape = inputs.shape[-1:]
        if self.create_scale:
            scale = hooks.get_parameter(
                "scale", param_shape, jnp.float32, initializer=self.scale_init
            )
        elif scale is None:
            scale = np.array(1.0, dtype=inputs.dtype)

        if self.create_offset:
            offset = hooks.get_parameter(
                "offset", param_shape, jnp.float32, initializer=self.offset_init
            )
        elif offset is None:
            offset = np.array(0.0, dtype=inputs.dtype)

        scale = jnp.broadcast_to(scale, inputs.shape)
        offset = jnp.broadcast_to(offset, inputs.shape)
        mean = jnp.broadcast_to(mean, inputs.shape)

        inv = scale * jax.lax.rsqrt(variance + self.eps)
        return inv * (inputs - mean) + offset
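
With parameter creation removed, the normalization is just the final two lines above. A minimal functional sketch over the last axis, verifying that each row comes out with (approximately) zero mean and unit variance:

import jax
import jax.numpy as jnp

def layer_norm(x, scale, offset, eps=1e-5, axis=-1):
    # Mirrors `inv * (inputs - mean) + offset` above.
    mean = jnp.mean(x, axis=axis, keepdims=True)
    var = jnp.var(x, axis=axis, keepdims=True)
    inv = scale * jax.lax.rsqrt(var + eps)
    return inv * (x - mean) + offset

x = jnp.arange(12.0).reshape(3, 4)
y = layer_norm(x, scale=jnp.ones((4,)), offset=jnp.zeros((4,)))
print(jnp.mean(y, axis=-1))  # ~[0, 0, 0]
print(jnp.var(y, axis=-1))   # ~[1, 1, 1]
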
Example #4
    def call(
        self,
        inputs: jnp.ndarray,
        training: tp.Optional[bool] = None,
        test_local_stats: bool = False,
        scale: Optional[jnp.ndarray] = None,
        offset: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        """Computes the normalized version of the input.

        Args:
            inputs: An array, where the data format is ``[..., C]``.
            training: Whether training is currently happening.
            test_local_stats: Whether local stats are used when training=False.
            scale: An array up to n-D. The shape of this tensor must be broadcastable
                to the shape of ``inputs``. This is the scale applied to the normalized
                inputs. This cannot be passed in if the module was constructed with
                ``create_scale=True``.
            offset: An array up to n-D. The shape of this tensor must be broadcastable
                to the shape of ``inputs``. This is the offset applied to the normalized
                inputs. This cannot be passed in if the module was constructed with
                ``create_offset=True``.

        Returns:
            The array, normalized across all but the last dimension.
        """
        if training is None:
            training = hooks.is_training()

        if self.create_scale and scale is not None:
            raise ValueError("Cannot pass `scale` at call time if `create_scale=True`.")
        if self.create_offset and offset is not None:
            raise ValueError(
                "Cannot pass `offset` at call time if `create_offset=True`."
            )

        channel_index = self.channel_index
        if channel_index < 0:
            channel_index += inputs.ndim

        if self.axis is not None:
            axis = self.axis
        else:
            axis = [i for i in range(inputs.ndim) if i != channel_index]

        if training or test_local_stats:
            cross_replica_axis = self.cross_replica_axis
            if self.cross_replica_axis:
                mean = jnp.mean(inputs, axis, keepdims=True)
                mean = jax.lax.pmean(mean, cross_replica_axis)
                mean_of_squares = jnp.mean(inputs ** 2, axis, keepdims=True)
                mean_of_squares = jax.lax.pmean(mean_of_squares, cross_replica_axis)
                var = mean_of_squares - mean ** 2
            else:
                mean = jnp.mean(inputs, axis, keepdims=True)
                # This uses E[(X - E[X])^2].
                # TODO(tycai): Consider the faster, but possibly less stable
                # E[X^2] - E[X]^2 method.
                var = jnp.var(inputs, axis, keepdims=True)
        else:
            mean = self.mean_ema.average
            var = self.var_ema.average

        if training:
            self.mean_ema(mean)
            self.var_ema(var)

        w_shape = [1 if i in axis else inputs.shape[i] for i in range(inputs.ndim)]
        w_dtype = inputs.dtype

        if self.create_scale:
            scale = hooks.get_parameter("scale", w_shape, w_dtype, self.scale_init)
        elif scale is None:
            scale = np.ones([], dtype=w_dtype)

        if self.create_offset:
            offset = hooks.get_parameter("offset", w_shape, w_dtype, self.offset_init)
        elif offset is None:
            offset = np.zeros([], dtype=w_dtype)

        inv = scale * jax.lax.rsqrt(var + self.eps)
        return (inputs - mean) * inv + offset
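
The training-time statistics path above reduces to a mean and variance taken over every axis except the channel axis. A minimal functional sketch of that path (no EMA state and no cross-replica reduction; the shapes are illustrative assumptions):

import jax
import jax.numpy as jnp

def batch_norm_train(x, scale, offset, eps=1e-5, channel_index=-1):
    # Reduce over all axes except the channel axis, as in the `axis` computation above.
    channel_index = channel_index % x.ndim
    axis = tuple(i for i in range(x.ndim) if i != channel_index)
    mean = jnp.mean(x, axis=axis, keepdims=True)
    var = jnp.var(x, axis=axis, keepdims=True)  # E[(X - E[X])^2]
    inv = scale * jax.lax.rsqrt(var + eps)
    return (x - mean) * inv + offset

x = jax.random.normal(jax.random.PRNGKey(0), (8, 32, 32, 3))
y = batch_norm_train(x, scale=1.0, offset=0.0)
print(y.shape)  # (8, 32, 32, 3)
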
Example #5
    def __init__(
        self,
        vocab_size: Optional[int] = None,
        embed_dim: Optional[int] = None,
        embedding_matrix: Optional[jnp.ndarray] = None,
        w_init: Optional[initializers.Initializer] = None,
        lookup_style: Union[str, EmbedLookupStyle] = "ARRAY_INDEX",
        name: Optional[str] = None,
    ):
        """
        Constructs an Embed module.

        Args:
            vocab_size: The number of unique tokens to embed. If not provided, an
                existing vocabulary matrix from which ``vocab_size`` can be
                inferred must be provided as ``embedding_matrix``.
            embed_dim: Number of dimensions to assign to each embedding. If an
                existing vocabulary matrix initializes the module, this should not
                be provided, as it will be inferred.
            embedding_matrix: A matrix-like object equivalent in size to
                ``[vocab_size, embed_dim]``. If given, it is used as the initial
                value for the embedding matrix and neither ``vocab_size`` nor
                ``embed_dim`` need be given. If they are given, their values are
                checked to be consistent with the dimensions of
                ``embedding_matrix``.
            w_init: An initializer for the embeddings matrix. By default,
                embeddings are initialized via a truncated normal distribution.
            lookup_style: One of the enum values of :class:`EmbedLookupStyle`
                determining how to access the value of the embeddings given an ID.
                Regardless of the style, the input should be a dense array of
                integer values representing ids. This setting changes how the
                module internally maps those ids to embeddings. The result is the
                same, but the speed and memory tradeoffs differ. It defaults to
                numpy-style array indexing. This value only sets the default for
                the module and can be overridden at any given invocation of
                :meth:`__call__`.
            name: Optional name for this module.

        Raises:
            ValueError: If none of ``embed_dim``, ``embedding_matrix`` and
                ``vocab_size`` are supplied, or if ``embedding_matrix`` is
                supplied and ``embed_dim`` or ``vocab_size`` is not consistent
                with the supplied matrix.
        """
        super().__init__(name=name)
        if embedding_matrix is None and not (vocab_size and embed_dim):
            raise ValueError(
                "Embedding must be supplied either with an initial `embedding_matrix` "
                "or with `embed_dim` and `vocab_size`.")
        if embedding_matrix is not None:
            embedding_matrix = jnp.asarray(embedding_matrix)
            if vocab_size and embedding_matrix.shape[0] != vocab_size:
                raise ValueError(
                    "An `embedding_matrix` was supplied but the `vocab_size` of "
                    f"{vocab_size} was not consistent with its shape "
                    f"{embedding_matrix.shape}.")
            if embed_dim and embedding_matrix.shape[1] != embed_dim:
                raise ValueError(
                    "An `embedding_matrix` was supplied but the `embed_dim` of "
                    f"{embed_dim} was not consistent with its shape "
                    f"{embedding_matrix.shape}.")
            self.embeddings = hooks.get_parameter(
                "embeddings",
                embedding_matrix.shape,
                initializer=lambda _, __: embedding_matrix,
            )
        else:
            assert embed_dim is not None
            assert vocab_size is not None

            w_init = w_init or initializers.TruncatedNormal()
            self.embeddings = hooks.get_parameter("embeddings",
                                                  [vocab_size, embed_dim],
                                                  initializer=w_init)

        self.vocab_size = vocab_size or embedding_matrix.shape[0]
        self.embed_dim = embed_dim or embedding_matrix.shape[1]
        self.lookup_style = lookup_style
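
The lookup styles mentioned in the docstring differ only in how ids index the matrix: the numpy-style default is a plain gather, while a one-hot matmul computes the same result with different speed/memory characteristics. A minimal sketch of both (only the "ARRAY_INDEX" style is confirmed by the docstring above; the one-hot variant is shown for contrast):

import jax
import jax.numpy as jnp

vocab_size, embed_dim = 10, 4
embeddings = jax.random.normal(jax.random.PRNGKey(0), (vocab_size, embed_dim))
ids = jnp.array([1, 5, 5, 9])

# Numpy-style array indexing: a gather along the vocabulary axis.
out_index = embeddings[ids]                # (4, 4)

# One-hot matmul: same result, different speed/memory tradeoff.
one_hot = jax.nn.one_hot(ids, vocab_size)  # (4, 10)
out_one_hot = one_hot @ embeddings         # (4, 4)

print(jnp.allclose(out_index, out_one_hot))  # True
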
Example #6
    def call(
        self,
        query: jnp.ndarray,
        key: tp.Optional[jnp.ndarray] = None,
        value: tp.Optional[jnp.ndarray] = None,
        mask=None,
        training=None,
    ):
        """
        Arguments:
            inputs:  List of `[query, key, value]` where
                * `query`: np.ndarray of shape `(..., query_elements, query_depth)`
                * `key`: `np.ndarray of shape '(..., key_elements, key_depth)`
                * `value`: np.ndarray of shape `(..., key_elements, value_depth)`, optional, if not given `key` will be used.
            mask: a binary np.ndarray of shape `[batch_size?, num_heads?, query_elements, key_elements]`
                which specifies which query elements can attendo to which key elements,
                `1` indicates attention and `0` indicates no attention.
        Output shape:
            * `(..., query_elements, output_size)` if `output_size` is given, else
            * `(..., query_elements, value_depth)` if `value` is given, else
            * `(..., query_elements, key_depth)`
        """

        # einsum nomenclature
        # ------------------------
        # N = query elements
        # M = key/value elements
        # H = heads
        # I = input features
        # O = output features

        if key is None:
            key = query

        if value is None:
            value = key

        output_size = (self.output_size
                       if self.output_size is not None else value.shape[-1])

        # verify shapes
        if key.shape[-2] != value.shape[-2]:
            raise ValueError(
                "the number of elements in 'key' must be equal to the same as the number of elements in 'value'"
            )

        if mask is not None:
            if len(mask.shape) < 2:
                raise ValueError("'mask' must have atleast 2 dimensions")
            if query.shape[-2] != mask.shape[-2]:
                raise ValueError(
                    "mask's second to last dimension must be equal to the number of elements in 'query'"
                )
            if key.shape[-2] != mask.shape[-1]:
                raise ValueError(
                    "mask's last dimension must be equal to the number of elements in 'key'"
                )

        # get weights
        query_kernel = hooks.get_parameter(
            "query_kernel",
            [self.num_heads, query.shape[-1], self.head_size],
            jnp.float32,
            initializer=self.kernel_initializer,
        )
        key_kernel = hooks.get_parameter(
            "key_kernel",
            [self.num_heads, key.shape[-1], self.head_size],
            jnp.float32,
            initializer=self.kernel_initializer,
        )
        value_kernel = hooks.get_parameter(
            "value_kernel",
            [self.num_heads, value.shape[-1], self.head_size],
            jnp.float32,
            initializer=self.kernel_initializer,
        )
        projection_kernel = hooks.get_parameter(
            "projection_kernel",
            [self.num_heads, self.head_size, output_size],
            jnp.float32,
            initializer=self.kernel_initializer,
        )

        # Linear transformations
        query = jnp.einsum("...NI , HIO -> ...NHO", query, query_kernel)
        key = jnp.einsum("...MI , HIO -> ...MHO", key, key_kernel)
        value = jnp.einsum("...MI , HIO -> ...MHO", value, value_kernel)

        # Scale the dot-product attention: dividing `query` (or `key`) before
        # the product, rather than scaling the product itself, saves some computation.
        query /= jnp.sqrt(self.head_size)

        # Calculate dot product attention
        logits = jnp.einsum("...NHO,...MHO->...HNM", query, key)

        # apply mask
        if mask is not None:
            mask = mask.astype(jnp.float32)

            # possibly expand on the head dimension so broadcasting works
            if len(mask.shape) != len(logits.shape):
                mask = jnp.expand_dims(mask, -3)

            logits += -10e9 * (1.0 - mask)

        attn_coef = jax.nn.softmax(logits)

        # attention dropout
        attn_coef_dropout = Dropout(self.dropout_rate)(attn_coef,
                                                       training=training)

        # attention * value
        multihead_output = jnp.einsum("...HNM,...MHI->...NHI",
                                      attn_coef_dropout, value)

        # Run the outputs through another linear projection layer. Recombining heads
        # is automatically done.
        output = jnp.einsum("...NHI,HIO->...NO", multihead_output,
                            projection_kernel)

        if self.use_projection_bias:
            output += hooks.get_parameter(
                "projection_bias",
                [output_size],
                jnp.float32,
                initializer=self.bias_initializer,
            )

        if self.return_attn_coef:
            return output, attn_coef
        else:
            return output
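
The einsum pipeline above, reduced to its core: per-head projections, scaled dot-product, softmax over key elements, and the attention-weighted sum of values, with no mask, dropout, or output projection. The shapes are illustrative assumptions:

import jax
import jax.numpy as jnp

N, M, I, H, O = 6, 8, 16, 4, 32  # query/key elements, input features, heads, head size
kq, kk, kv, kx, ky = jax.random.split(jax.random.PRNGKey(0), 5)

query = jax.random.normal(kx, (2, N, I))
keyval = jax.random.normal(ky, (2, M, I))
wq = jax.random.normal(kq, (H, I, O)) * 0.1
wk = jax.random.normal(kk, (H, I, O)) * 0.1
wv = jax.random.normal(kv, (H, I, O)) * 0.1

q = jnp.einsum("...NI,HIO->...NHO", query, wq) / jnp.sqrt(O)  # scale query, not the product
k = jnp.einsum("...MI,HIO->...MHO", keyval, wk)
v = jnp.einsum("...MI,HIO->...MHO", keyval, wv)

logits = jnp.einsum("...NHO,...MHO->...HNM", q, k)  # dot-product scores per head
attn = jax.nn.softmax(logits)                       # softmax over the key axis M
out = jnp.einsum("...HNM,...MHI->...NHI", attn, v)  # attention-weighted values
print(out.shape)  # (2, 6, 4, 32) -> (..., query_elements, num_heads, head_size)
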