Example #1
File: conv.py Project: chjort/elegy
    def call(self, inputs: np.ndarray) -> np.ndarray:
        """
        Connects the ``ConvND`` layer.

        Args:
            inputs: A rank-N+2 array with shape ``[N, spatial_dims, C]``.

        Returns:
            A rank-N+2 array with shape ``[N, spatial_dims, output_channels]``.
        """
        required_rank = self.num_spatial_dims + 2
        if inputs.ndim != required_rank:
            raise ValueError(
                f"Input to ConvND needs to have rank {required_rank}, "
                f"but input has shape {inputs.shape}."
            )

        w_shape = self.kernel_shape + (
            inputs.shape[self.channel_index],
            self.output_channels,
        )

        if self.mask is not None and self.mask.shape != w_shape:
            raise ValueError(
                "Mask needs to have the same shape as weights. "
                f"Shapes are: {self.mask.shape}, {w_shape}"
            )

        w_init = self.w_init
        if w_init is None:
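            # Default initializer: truncated normal scaled by 1/sqrt(fan_in),
            # where fan_in = prod(kernel dims) * input channels.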
            fan_in_shape = np.prod(w_shape[:-1])
            stddev = 1.0 / np.sqrt(fan_in_shape)
            w_init = initializers.TruncatedNormal(stddev=stddev)
        w = hooks.get_parameter("w", w_shape, inputs.dtype, initializer=w_init)

        if self.mask is not None:
            w *= self.mask

        out = lax.conv_general_dilated(
            inputs,
            w,
            window_strides=self.stride,
            padding=self.padding,
            lhs_dilation=self.lhs_dilation,
            rhs_dilation=self.kernel_dilation,
            dimension_numbers=self.dimension_numbers,
        )

        if self.with_bias:
            if self.channel_index == -1:
                bias_shape = (self.output_channels,)
            else:
                bias_shape = (self.output_channels,) + (1,) * self.num_spatial_dims
            b = hooks.get_parameter(
                "b", bias_shape, inputs.dtype, initializer=self.b_init
            )
            b = jnp.broadcast_to(b, out.shape)
            out = out + b

        return out
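
For reference, here is a minimal standalone sketch of the same ``lax.conv_general_dilated`` call for the 2-D, channels-last case. The shapes, strides, and dimension numbers below are illustrative assumptions, not values taken from the module above.

from jax import lax, random

x = random.normal(random.PRNGKey(0), (4, 32, 32, 3))  # [N, H, W, C]
w = random.normal(random.PRNGKey(1), (3, 3, 3, 16))   # [kH, kW, C_in, C_out] for "HWIO"

out = lax.conv_general_dilated(
    x,
    w,
    window_strides=(1, 1),
    padding="SAME",
    lhs_dilation=(1, 1),  # no input dilation (a transposed conv would use this)
    rhs_dilation=(1, 1),  # no kernel (atrous) dilation
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)
print(out.shape)  # (4, 32, 32, 16)
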
Example #2
    def call(self, inputs: np.ndarray) -> np.ndarray:
        """
        Computes the transposed convolution of the input.

        Args:
            inputs: A rank-N+2 array with shape ``[N, spatial_dims, C]``.

        Returns:
            A rank-N+2 array with shape ``[N, spatial_dims, output_channels]``.
        """
        required_rank = self.num_spatial_dims + 2
        if inputs.ndim != required_rank:
            raise ValueError(
                f"Input to ConvND needs to have rank {required_rank}, "
                f"but input has shape {inputs.shape}.")

        input_channels = inputs.shape[self.channel_index]
        w_shape = self.kernel_shape + (self.output_channels, input_channels)

        if self.mask is not None and self.mask.shape != w_shape:
            raise ValueError("Mask needs to have the same shape as weights. "
                             f"Shapes are: {self.mask.shape}, {w_shape}")

        w_init = self.w_init
        if w_init is None:
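            # fan_in is kernel dims times input channels; since input channels
            # sit on the last axis of the transposed kernel, the fan-in shape
            # is assembled explicitly rather than taken from w_shape[:-1].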
            fan_in_shape = self.kernel_shape + (input_channels,)
            stddev = 1.0 / np.sqrt(np.prod(fan_in_shape))
            w_init = initializers.TruncatedNormal(stddev=stddev)
        w = self.add_parameter("w", lambda: w_init(w_shape, inputs.dtype))

        if self.mask is not None:
            w = w * self.mask

        out = lax.conv_transpose(
            inputs,
            w,
            strides=self.stride,
            padding=self.padding,
            dimension_numbers=self.dimension_numbers,
        )

        if self.with_bias:
            if self.channel_index == -1:
                bias_shape = (self.output_channels,)
            else:
                bias_shape = (self.output_channels,) + (1,) * self.num_spatial_dims
            b = self.add_parameter(
                "b", lambda: self.b_init(bias_shape, inputs.dtype))
            b = jnp.broadcast_to(b, out.shape)
            out = out + b

        return out
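
For comparison, a minimal standalone sketch of ``lax.conv_transpose`` in the 2-D, channels-last case. Note the module above stores its kernel as ``(spatial..., output_channels, input_channels)`` per its own dimension numbers; the plain "HWIO" layout used below is an illustrative assumption, as are the shapes and strides.

from jax import lax, random

x = random.normal(random.PRNGKey(0), (4, 32, 32, 3))  # [N, H, W, C]
w = random.normal(random.PRNGKey(1), (3, 3, 3, 16))   # [kH, kW, I, O] for "HWIO"

out = lax.conv_transpose(
    x,
    w,
    strides=(2, 2),
    padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)
print(out.shape)  # (4, 64, 64, 16): spatial dims scaled up by the stride
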
Example #3
    def __init__(
        self,
        vocab_size: Optional[int] = None,
        embed_dim: Optional[int] = None,
        embedding_matrix: Optional[jnp.ndarray] = None,
        w_init: Optional[initializers.Initializer] = None,
        lookup_style: Union[str, EmbedLookupStyle] = "ARRAY_INDEX",
        name: Optional[str] = None,
    ):
        """
        Constructs an Embed module.

        Args:
            vocab_size: The number of unique tokens to embed. If not provided,
                an existing vocabulary matrix from which ``vocab_size`` can be
                inferred must be provided as ``embedding_matrix``.
            embed_dim: Number of dimensions to assign to each embedding. If an
                existing vocabulary matrix initializes the module, this should
                not be provided, as it will be inferred.
            embedding_matrix: A matrix-like object equivalent in size to
                ``[vocab_size, embed_dim]``. If given, it is used as the
                initial value for the embedding matrix, and neither
                ``vocab_size`` nor ``embed_dim`` need be given. If they are
                given, their values are checked to be consistent with the
                dimensions of ``embedding_matrix``.
            w_init: An initializer for the embeddings matrix. By default,
                embeddings are initialized via a truncated normal distribution.
            lookup_style: One of the enum values of :class:`EmbedLookupStyle`
                determining how to access the value of the embeddings given an
                ID. Regardless of the style, the input should be a dense array
                of integer ids. This setting changes how the module internally
                maps those ids to embeddings; the result is the same, but the
                speed and memory tradeoffs differ. It defaults to numpy-style
                array indexing. This value is only the default for the module
                and can be overridden at any given invocation via
                :meth:`__call__`.
            name: Optional name for this module.

        Raises:
            ValueError: If none of ``embed_dim``, ``embedding_matrix`` and
                ``vocab_size`` are supplied, or if ``embedding_matrix`` is
                supplied and ``embed_dim`` or ``vocab_size`` is not consistent
                with the supplied matrix.
        """
        super().__init__(name=name)
        if embedding_matrix is None and not (vocab_size and embed_dim):
            raise ValueError(
                "Embedding must be supplied either with an initial `embedding_matrix` "
                "or with `embed_dim` and `vocab_size`.")
        if embedding_matrix is not None:
            embedding_matrix = jnp.asarray(embedding_matrix)
            if vocab_size and embedding_matrix.shape[0] != vocab_size:
                raise ValueError(
                    "An `embedding_matrix` was supplied but the `vocab_size` of "
                    f"{vocab_size} was not consistent with its shape "
                    f"{embedding_matrix.shape}.")
            if embed_dim and embedding_matrix.shape[1] != embed_dim:
                raise ValueError(
                    "An `embedding_matrix` was supplied but the `embed_dim` of "
                    f"{embed_dim} was not consistent with its shape "
                    f"{embedding_matrix.shape}.")
            self.embeddings = hooks.get_parameter(
                "embeddings",
                embedding_matrix.shape,
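                # The initializer ignores the requested (shape, dtype) and
                # returns the supplied matrix verbatim.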
                initializer=lambda _, __: embedding_matrix,
            )
        else:
            assert embed_dim is not None
            assert vocab_size is not None

            w_init = w_init or initializers.TruncatedNormal()
            self.embeddings = hooks.get_parameter("embeddings",
                                                  [vocab_size, embed_dim],
                                                  initializer=w_init)

        self.vocab_size = vocab_size or embedding_matrix.shape[0]
        self.embed_dim = embed_dim or embedding_matrix.shape[1]
        self.lookup_style = lookup_style
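
To illustrate the lookup-style tradeoff described in the docstring, here is a small self-contained sketch, assuming the usual ``ARRAY_INDEX`` / ``ONE_HOT`` enum values and toy sizes chosen for illustration.

import jax
import jax.numpy as jnp

vocab_size, embed_dim = 10, 4
embeddings = jax.random.normal(jax.random.PRNGKey(0), (vocab_size, embed_dim))
ids = jnp.array([1, 5, 5, 2])

# ARRAY_INDEX: numpy-style gather (the module's default).
by_index = embeddings[ids]                 # shape (4, 4)

# ONE_HOT: build a one-hot matrix and matmul; same result, different
# speed/memory tradeoffs on some accelerators.
one_hot = jax.nn.one_hot(ids, vocab_size)  # shape (4, 10)
by_matmul = one_hot @ embeddings           # shape (4, 4)

assert jnp.allclose(by_index, by_matmul)

Both paths return identical embeddings; only the underlying memory-access pattern differs.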