def call(self, inputs: np.ndarray) -> np.ndarray:
    """
    Connects ``ConvND`` layer.

    Args:
        inputs: A rank-N+2 array with shape ``[N, spatial_dims, C]``.

    Returns:
        A rank-N+2 array with shape ``[N, spatial_dims, output_channels]``.
    """
    required_rank = self.num_spatial_dims + 2
    if inputs.ndim != required_rank:
        raise ValueError(
            f"Input to ConvND needs to have rank {required_rank}, "
            f"but input has shape {inputs.shape}."
        )

    w_shape = self.kernel_shape + (
        inputs.shape[self.channel_index],
        self.output_channels,
    )

    if self.mask is not None and self.mask.shape != w_shape:
        raise ValueError(
            "Mask needs to have the same shape as weights. "
            f"Shapes are: {self.mask.shape}, {w_shape}"
        )

    w_init = self.w_init
    if w_init is None:
        fan_in_shape = np.prod(w_shape[:-1])
        stddev = 1.0 / np.sqrt(fan_in_shape)
        w_init = initializers.TruncatedNormal(stddev=stddev)

    w = hooks.get_parameter("w", w_shape, inputs.dtype, initializer=w_init)

    if self.mask is not None:
        w *= self.mask

    out = lax.conv_general_dilated(
        inputs,
        w,
        window_strides=self.stride,
        padding=self.padding,
        lhs_dilation=self.lhs_dilation,
        rhs_dilation=self.kernel_dilation,
        dimension_numbers=self.dimension_numbers,
    )

    if self.with_bias:
        if self.channel_index == -1:
            bias_shape = (self.output_channels,)
        else:
            bias_shape = (self.output_channels,) + (1,) * self.num_spatial_dims
        b = hooks.get_parameter(
            "b", bias_shape, inputs.dtype, initializer=self.b_init
        )
        b = jnp.broadcast_to(b, out.shape)
        out = out + b

    return out
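# A minimal usage sketch (assumptions: a concrete subclass such as ``Conv2D`` exists and
# the module is callable on arrays; neither name is defined in this file):
#
#     import jax.numpy as jnp
#
#     conv = Conv2D(output_channels=16, kernel_shape=(3, 3))   # hypothetical 2D subclass of ConvND
#     x = jnp.zeros((8, 32, 32, 3))                            # rank N+2 = 4: [batch, H, W, C]
#     y = conv(x)                                              # expected shape: (8, 32, 32, 16)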
def call(self, inputs: np.ndarray) -> np.ndarray:
    """
    Computes the transposed convolution of the input.

    Args:
        inputs: A rank-N+2 array with shape ``[N, spatial_dims, C]``.

    Returns:
        A rank-N+2 array with shape ``[N, spatial_dims, output_channels]``.
    """
    required_rank = self.num_spatial_dims + 2
    if inputs.ndim != required_rank:
        raise ValueError(
            f"Input to ConvND needs to have rank {required_rank}, "
            f"but input has shape {inputs.shape}."
        )

    input_channels = inputs.shape[self.channel_index]
    w_shape = self.kernel_shape + (self.output_channels, input_channels)

    if self.mask is not None and self.mask.shape != w_shape:
        raise ValueError(
            "Mask needs to have the same shape as weights. "
            f"Shapes are: {self.mask.shape}, {w_shape}"
        )

    w_init = self.w_init
    if w_init is None:
        fan_in_shape = self.kernel_shape + (input_channels,)
        stddev = 1.0 / np.sqrt(np.prod(fan_in_shape))
        w_init = initializers.TruncatedNormal(stddev=stddev)

    w = self.add_parameter("w", lambda: w_init(w_shape, inputs.dtype))

    if self.mask is not None:
        w = w * self.mask

    out = lax.conv_transpose(
        inputs,
        w,
        strides=self.stride,
        padding=self.padding,
        dimension_numbers=self.dimension_numbers,
    )

    if self.with_bias:
        if self.channel_index == -1:
            bias_shape = (self.output_channels,)
        else:
            bias_shape = (self.output_channels,) + (1,) * self.num_spatial_dims
        b = self.add_parameter("b", lambda: self.b_init(bias_shape, inputs.dtype))
        b = jnp.broadcast_to(b, out.shape)
        out = out + b

    return out
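# A minimal usage sketch (assumption: a concrete subclass such as ``Conv2DTranspose``
# exists; exact constructor arguments and padding behavior are not shown in this file):
#
#     deconv = Conv2DTranspose(output_channels=8, kernel_shape=(3, 3), stride=(2, 2))
#     x = jnp.zeros((4, 16, 16, 32))   # [batch, H, W, C]
#     y = deconv(x)                    # spatial dims upsampled roughly by the stride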
def __init__(
    self,
    vocab_size: Optional[int] = None,
    embed_dim: Optional[int] = None,
    embedding_matrix: Optional[jnp.ndarray] = None,
    w_init: Optional[initializers.Initializer] = None,
    lookup_style: Union[str, EmbedLookupStyle] = "ARRAY_INDEX",
    name: Optional[str] = None,
):
    """
    Constructs an Embed module.

    Args:
        vocab_size: The number of unique tokens to embed. If not provided, an
            existing vocabulary matrix from which ``vocab_size`` can be inferred
            must be provided as ``embedding_matrix``.
        embed_dim: Number of dimensions to assign to each embedding. If an
            existing vocabulary matrix initializes the module, this should not
            be provided as it will be inferred.
        embedding_matrix: A matrix-like object equivalent in size to
            ``[vocab_size, embed_dim]``. If given, it is used as the initial
            value for the embedding matrix and neither ``vocab_size`` nor
            ``embed_dim`` need be given. If they are given, their values are
            checked to be consistent with the dimensions of ``embedding_matrix``.
        w_init: An initializer for the embeddings matrix. By default, embeddings
            are initialized via a truncated normal distribution.
        lookup_style: One of the enum values of :class:`EmbedLookupStyle`
            determining how to access the value of the embeddings given an ID.
            Regardless of the style, the input should be a dense array of
            integer values representing ids. This setting changes how the module
            internally maps those ids to embeddings. The result is the same, but
            the speed and memory tradeoffs are different. It defaults to
            numpy-style array indexing. This value is only the default for the
            module and can be overridden at any given invocation in
            :meth:`__call__`.
        name: Optional name for this module.

    Raises:
        ValueError: If none of ``embed_dim``, ``embedding_matrix`` and
            ``vocab_size`` are supplied, or if ``embedding_matrix`` is supplied
            and ``embed_dim`` or ``vocab_size`` is not consistent with the
            supplied matrix.
    """
    super().__init__(name=name)
    if embedding_matrix is None and not (vocab_size and embed_dim):
        raise ValueError(
            "Embedding must be supplied either with an initial `embedding_matrix` "
            "or with `embed_dim` and `vocab_size`."
        )
    if embedding_matrix is not None:
        embedding_matrix = jnp.asarray(embedding_matrix)
        if vocab_size and embedding_matrix.shape[0] != vocab_size:
            raise ValueError(
                "An `embedding_matrix` was supplied but the `vocab_size` of "
                f"{vocab_size} was not consistent with its shape "
                f"{embedding_matrix.shape}."
            )
        if embed_dim and embedding_matrix.shape[1] != embed_dim:
            raise ValueError(
                "An `embedding_matrix` was supplied but the `embed_dim` of "
                f"{embed_dim} was not consistent with its shape "
                f"{embedding_matrix.shape}."
            )
        self.embeddings = hooks.get_parameter(
            "embeddings",
            embedding_matrix.shape,
            initializer=lambda _, __: embedding_matrix,
        )
    else:
        assert embed_dim is not None
        assert vocab_size is not None
        w_init = w_init or initializers.TruncatedNormal()
        self.embeddings = hooks.get_parameter(
            "embeddings", [vocab_size, embed_dim], initializer=w_init
        )

    self.vocab_size = vocab_size or embedding_matrix.shape[0]
    self.embed_dim = embed_dim or embedding_matrix.shape[1]
    self.lookup_style = lookup_style
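# A minimal usage sketch (assumption: ``Embed`` instances are callable on integer id
# arrays via a lookup method not shown in this file):
#
#     embed = Embed(vocab_size=10_000, embed_dim=64)
#     ids = jnp.array([[1, 5, 42], [7, 0, 3]])   # dense integer ids, shape [batch, seq]
#     vectors = embed(ids)                       # expected shape: [batch, seq, 64]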