Example #1
    def __call__(
        self,
        input_ids: jnp.ndarray,
        input_mask: jnp.ndarray,
        type_ids: jnp.ndarray,
        *,
        deterministic: bool = False,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Applies BERT model on the inputs."""

        word_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(input_ids)
        type_embeddings = self.type_embeddings(type_ids)

        embeddings = word_embeddings + position_embeddings + type_embeddings
        embeddings = self.embeddings_layer_norm(embeddings)
        embeddings = self.embeddings_dropout(embeddings,
                                             deterministic=deterministic)

        hidden_states = embeddings

        mask = input_mask.astype(jnp.int32)
        for transformer_block in self.encoder_layers:
            hidden_states = transformer_block(hidden_states,
                                              mask,
                                              deterministic=deterministic)
        pooled_output = self.pooler(hidden_states[:, 0])
        pooled_output = jnp.tanh(pooled_output)

        return hidden_states, pooled_output
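For context, the three inputs follow the usual BERT convention: padded token ids, a mask that is 1 for real tokens and 0 for padding, and segment (type) ids. A minimal sketch of constructing them (the token ids and the `bert` instance are hypothetical, not part of the original example):

```python
import jax.numpy as jnp

# Two sequences padded to length 8; 101/102 stand in for [CLS]/[SEP] ids.
input_ids = jnp.array([[101, 2023, 2003, 1037, 3231, 102, 0, 0],
                       [101, 7592, 2088,  102,    0,   0, 0, 0]])
input_mask = (input_ids != 0).astype(jnp.int32)  # 1 for real tokens, 0 for padding
type_ids = jnp.zeros_like(input_ids)             # single-segment inputs

# hidden_states, pooled_output = bert(input_ids, input_mask, type_ids, deterministic=True)
```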
Example #2
def mean_absolute_percentage_error(y_true: jnp.ndarray,
                                   y_pred: jnp.ndarray) -> jnp.ndarray:
    """
    Computes the mean absolute percentage error (MAPE) between labels and predictions.

    After computing the absolute difference between the true values and the predictions
    and dividing by the true values, the mean over the last dimension is returned.

    Usage:

    ```python
    rng = jax.random.PRNGKey(42)

    y_true = jax.random.randint(rng, shape=(2, 3), minval=0, maxval=2)
    y_pred = jax.random.uniform(rng, shape=(2, 3))

    loss = elegy.losses.mean_absolute_percentage_error(y_true, y_pred)

    assert loss.shape == (2,)

    assert jnp.array_equal(loss, 100. * jnp.mean(jnp.abs((y_pred - y_true) / jnp.maximum(jnp.abs(y_true), types.EPSILON)), axis=-1))
    ```

    Arguments:
        y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.

    Returns:
        Mean absolute percentage error values. shape = `[batch_size, d0, .. dN-1]`.
    """
    y_true = y_true.astype(y_pred.dtype)
    diff = jnp.abs(
        (y_pred - y_true) / jnp.maximum(jnp.abs(y_true), types.EPSILON))
    return 100.0 * jnp.mean(diff, axis=-1)
Example #3
    def __call__(self, x: jnp.ndarray, training: bool) -> jnp.ndarray:
        # Normalize the input
        x = x.astype(jnp.float32) / 255.0

        # Block 1
        x = linen.Conv(32, [3, 3], strides=[2, 2])(x)
        x = linen.Dropout(0.05, deterministic=not training)(x)
        x = jax.nn.relu(x)

        # Block 2
        x = linen.Conv(64, [3, 3], strides=[2, 2])(x)
        x = linen.BatchNorm(use_running_average=not training)(x)
        x = linen.Dropout(0.1, deterministic=not training)(x)
        x = jax.nn.relu(x)

        # Block 3
        x = linen.Conv(128, [3, 3], strides=[2, 2])(x)

        # Global average pooling
        x = x.mean(axis=(1, 2))

        # Classification layer
        x = linen.Dense(10)(x)

        return x
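A minimal sketch of initializing and running this Flax module, assuming it is a `linen.Module` named `CNN` that takes NHWC image batches (name and shapes are assumptions). Dropout needs an RNG at train time, and `BatchNorm` needs its `batch_stats` collection marked mutable:

```python
import jax
import jax.numpy as jnp

model = CNN()  # hypothetical name for the module defined above
x = jnp.zeros((8, 28, 28, 1), dtype=jnp.uint8)

# Initialize parameters and batch statistics in eval mode (no dropout RNG needed).
variables = model.init(jax.random.PRNGKey(0), x, training=False)

# Training step: supply a dropout RNG and collect the updated batch statistics.
logits, updates = model.apply(
    variables,
    x,
    training=True,
    rngs={"dropout": jax.random.PRNGKey(1)},
    mutable=["batch_stats"],
)
```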
Example #4
        def call(self, image: jnp.ndarray, training: bool):
            @elegy.to_module
            def ConvBlock(x, units, kernel, stride=1):
                x = elegy.nn.Conv2D(units,
                                    kernel,
                                    stride=stride,
                                    padding="same")(x)
                x = elegy.nn.BatchNormalization()(x, training)
                x = elegy.nn.Dropout(0.2)(x, training)
                return jax.nn.relu(x)

            x: jnp.ndarray = image.astype(jnp.float32) / 255.0

            # base
            x = ConvBlock()(x, 32, [3, 3])
            x = ConvBlock()(x, 64, [3, 3], stride=2)
            x = ConvBlock()(x, 64, [3, 3], stride=2)
            x = ConvBlock()(x, 128, [3, 3], stride=2)

            # GlobalAveragePooling2D
            x = jnp.mean(x, axis=[1, 2])

            # 1x1 Conv
            x = elegy.nn.Linear(10)(x)

            return x
Example #5
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # Normalize the input
        x = x.astype(jnp.float32) / 255.0

        # Block 1
        x = eg.Conv(32, [3, 3], strides=[2, 2])(x)
        x = eg.Dropout(0.05)(x)
        x = jax.nn.relu(x)

        # Block 2
        x = eg.Conv(64, [3, 3], strides=[2, 2])(x)
        x = eg.BatchNorm()(x)
        x = eg.Dropout(0.1)(x)
        x = jax.nn.relu(x)

        # Block 3
        x = eg.Conv(128, [3, 3], strides=[2, 2])(x)

        # Global average pooling
        x = x.mean(axis=(1, 2))

        # Classification layer
        x = eg.Linear(10)(x)

        return x
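A hedged sketch of training this module with elegy's high-level `Model` wrapper; the module name `CNN`, the loss/metric class names, and the training data are assumptions for illustration, not part of the original example:

```python
import numpy as np
import optax
import elegy as eg

X_train = np.random.randint(0, 256, size=(256, 28, 28, 1), dtype=np.uint8)
y_train = np.random.randint(0, 10, size=(256,))

model = eg.Model(
    module=CNN(),                   # hypothetical: the module defined above
    loss=eg.losses.Crossentropy(),  # assumed loss class name
    metrics=eg.metrics.Accuracy(),  # assumed metric class name
    optimizer=optax.adam(1e-3),
)
model.fit(X_train, y_train, epochs=10, batch_size=32)
```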
Example #6
def precision(
    y_true: jnp.ndarray,
    y_pred: jnp.ndarray,
    threshold: jnp.ndarray,
    class_id: jnp.ndarray,
    sample_weight: jnp.ndarray,
    true_positives: ReduceConfusionMatrix,
    false_positives: ReduceConfusionMatrix,
) -> jnp.ndarray:

    # TODO: class_id behavior
    y_pred = (y_pred > threshold).astype(jnp.float32)

    if y_true.dtype != y_pred.dtype:
        y_pred = y_pred.astype(y_true.dtype)

    true_positives = true_positives(y_true=y_true,
                                    y_pred=y_pred,
                                    sample_weight=sample_weight)
    false_positives = false_positives(y_true=y_true,
                                      y_pred=y_pred,
                                      sample_weight=sample_weight)

    return jnp.nan_to_num(
        jnp.divide(true_positives, true_positives + false_positives))
Example #7
def quantize(x: jnp.ndarray, to_type: str):
    assert to_type in ["float16", "float32", "uint16", "uint8"]

    if "int" in to_type:
        max_int = 2 ** 8 - 1 if to_type == "uint8" else 2 ** 16 - 1
        return to_type, int_quantize_jit(x, max_int, to_type)
    else:
        return to_type, x.astype(to_type)
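A small usage sketch for the float path; the integer branches call `int_quantize_jit`, which is defined elsewhere in the source and is assumed here to rescale values onto `[0, max_int]`:

```python
import jax.numpy as jnp

x = jnp.linspace(0.0, 1.0, num=5, dtype=jnp.float32)

dtype_name, x_f16 = quantize(x, "float16")
assert dtype_name == "float16"
assert x_f16.dtype == jnp.float16

# quantize(x, "uint8") would instead return int_quantize_jit(x, 255, "uint8").
```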
Example #8
    def fit(
        self,
        df: pd.DataFrame,
        sampler: str = "NUTS",
        rng_key: np.ndarray = None,
        sampler_kwargs: typing.Dict[str, typing.Any] = None,
        **mcmc_kwargs,
    ):
        """Fit the model to a DataFrame.
        
        Parameters
        ----------
        df : pd.DataFrame
            Source dataframe.
        sampler : str
            Numpyro sampler name. Default "NUTS".
        rng_key : two-element ndarray
            Optional rng key; will be randomly split if not provided.
        sampler_kwargs :
            Passed to the numpyro sampler selected.
        **mcmc_kwargs :
            Passed to numpyro.infer.MCMC

        Returns
        -------
        Model
            The fitted model.

        """
        if self.fitted:
            raise exceptions.AlreadyFittedError(self)

        if sampler.upper() not in ("NUTS", "HMC"):
            raise ValueError("Invalid sampler, try NUTS or HMC.")

        sampler = getattr(infer, sampler.upper())

        # store fit df
        self.df = df

        # set up mcmc
        _mcmc_kwargs = dict(num_warmup=500, num_samples=1000)
        _mcmc_kwargs.update(mcmc_kwargs)
        _sampler_kwargs = dict(model=self.model)
        _sampler_kwargs.update(sampler_kwargs or {})
        mcmc = infer.MCMC(sampler(**_sampler_kwargs), **_mcmc_kwargs)

        # do it
        rng_key_ = (self.split_rand_key()
                    if rng_key is None else rng_key.astype("uint32"))
        mcmc.run(rng_key_, df=df)

        # store results
        self.samples = mcmc.get_samples(group_by_chain=True)
        self.fitted = True

        return self
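A hypothetical usage sketch, assuming `SalesModel` is a concrete subclass of this base class and `df` contains the columns its model expects; the extra keyword arguments flow through `**mcmc_kwargs` into `numpyro.infer.MCMC`:

```python
import pandas as pd

df = pd.DataFrame({"x": [1.0, 2.0, 3.0, 4.0], "y": [1.1, 1.9, 3.2, 3.9]})

model = SalesModel()  # hypothetical concrete subclass
model.fit(df, sampler="NUTS", num_warmup=200, num_samples=500)
```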
Example #9
 def __call__(self, x: jnp.ndarray):
     x = x.astype(jnp.float32) / 255.0
     x = einops.rearrange(x, "batch ... -> batch (...)")
     x = eg.nn.Linear(self.n1)(x)
     x = jax.nn.relu(x)
     x = eg.nn.Linear(self.n2)(x)
     x = jax.nn.relu(x)
     x = eg.nn.Linear(10)(x)
     return x
Example #10
    def sample_posterior_predictive(
        self,
        df: pd.DataFrame,
        hdpi: bool = False,
        hdpi_interval: float = 0.9,
        rng_key: np.ndarray = None,
    ) -> typing.Union[pd.Series, pd.DataFrame]:
        """Obtain samples from the posterior predictive.
        
        Parameters
        ----------
        df : pd.DataFrame
            Source dataframe.
        hdpi : bool
            Option to include lower/upper bound of the highest posterior density 
            interval. Returns a dataframe if true, a series if false. Default False.
        hdpi_interval : float
            HDPI width. Default 0.9.
        rng_key : two-element ndarray
            Optional rng key; will be randomly split if not provided.

        Returns
        -------
        pd.Series or pd.DataFrame
            Forecasts. Will be a series with the name of the dv if no HDPI. Will be a
            dataframe if HDPI is included.

        """
        # get rng key
        rng_key_ = (self.split_rand_key()
                    if rng_key is None else rng_key.astype("uint32"))

        # check for nulls
        null_cols = columns_with_null_data(self.transform(df))
        if null_cols:
            raise exceptions.NullDataFound(*null_cols)

        #  do it
        predictions = infer.Predictive(self.model,
                                       self.samples_flat)(rng_key_,
                                                          df=df)[self.dv]

        if not hdpi:
            return pd.Series(predictions.mean(axis=0),
                             index=df.index,
                             name=self.dv)

        hdpi = diagnostics.hpdi(predictions, hdpi_interval)

        return pd.DataFrame(
            {
                self.dv: predictions.mean(axis=0),
                "hdpi_lower": hdpi[0, :],
                "hdpi_upper": hdpi[1, :],
            },
            index=df.index,
        )
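Continuing the hypothetical usage from the `fit` example above, forecasts can then be drawn from the fitted model:

```python
point_forecast = model.sample_posterior_predictive(df)                 # pd.Series of means
interval_forecast = model.sample_posterior_predictive(df, hdpi=True)  # DataFrame with HDPI bounds
```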
Example #11
def accuracy(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> jnp.ndarray:
    # [y_pred, y_true], _ = metrics_utils.ragged_assert_compatible_and_get_flat_values(
    #     [y_pred, y_true]
    # )
    # y_pred.shape.assert_is_compatible_with(y_true.shape)

    if y_true.dtype != y_pred.dtype:
        y_pred = y_pred.astype(y_true.dtype)

    return (y_true == y_pred).astype(jnp.float32)
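For example, the result is a per-sample indicator that can be averaged into a batch accuracy:

```python
import jax.numpy as jnp

y_true = jnp.array([0, 1, 1, 2])
y_pred = jnp.array([0, 1, 0, 2])

per_sample = accuracy(y_true, y_pred)  # [1., 1., 0., 1.]
batch_accuracy = per_sample.mean()     # 0.75
```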
Example #12
    def __call__(self, x: jnp.ndarray) -> VAEOutput:
        x = x.astype(jnp.float32)
        mean, stddev = Encoder(self._hidden_size, self._latent_size)(x)
        z = mean + stddev * jax.random.normal(hk.next_rng_key(), mean.shape)
        logits = Decoder(self._hidden_size, self._output_shape)(z)

        p = jax.nn.sigmoid(logits)
        image = jax.random.bernoulli(hk.next_rng_key(), p)

        return VAEOutput(image, mean, stddev, logits)
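Because this module draws randomness via `hk.next_rng_key()`, it has to be wrapped in `hk.transform` and given a key at apply time. A minimal sketch, where the module name `VAE` and its constructor arguments are assumptions:

```python
import jax
import jax.numpy as jnp
import haiku as hk

def forward(x):
    # Constructor arguments are illustrative assumptions.
    return VAE(hidden_size=512, latent_size=32, output_shape=(28, 28, 1))(x)

model = hk.transform(forward)
images = jnp.zeros((8, 28, 28, 1))

params = model.init(jax.random.PRNGKey(0), images)
output = model.apply(params, jax.random.PRNGKey(1), images)  # VAEOutput tuple
```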
Example #13
    def __call__(self, x: jnp.ndarray) -> dict:
        next_key = eg.KeySeq()
        x = x.astype(jnp.float32)

        z = Encoder(self.hidden_size, self.latent_size)(x)
        logits = Decoder(self.hidden_size, self.output_shape)(z)

        p = jax.nn.sigmoid(logits)
        image = jax.random.bernoulli(next_key(), p)

        return dict(image=image, logits=logits, det_image=p)
Example #14
        def __call__(self, x: jnp.ndarray):
            x = x.astype(jnp.float32) / 255.0

            x = eg.Flatten()(x)
            x = eg.Linear(self.n1)(x)
            x = jax.nn.relu(x)
            x = eg.Linear(self.n2)(x)
            x = jax.nn.relu(x)
            x = eg.Linear(10)(x)

            return x
Example #15
    def call(self, image: jnp.ndarray, training: bool):
        x = image.astype(jnp.float32) / 255.0

        x = jnp.reshape(x, [x.shape[0], -1])
        x = elegy.nn.Linear(self.n1)(x)
        x = elegy.nn.BatchNormalization()(x)
        x = jax.nn.relu(x)

        x = elegy.nn.Linear(self.n2)(x)
        x = jax.nn.relu(x)
        x = elegy.nn.Linear(10)(x)

        return x
Example #16
        def call(self, image: jnp.ndarray):

            image = image.astype(jnp.float32) / 255.0

            mlp = hk.Sequential([
                hk.Flatten(),
                hk.Linear(self.n1),
                jax.nn.relu,
                hk.Linear(self.n2),
                jax.nn.relu,
                hk.Linear(10),
            ])
            return dict(outputs=mlp(image))
Example #17
        def call(self, image: jnp.ndarray):
            image = image.astype(jnp.float32) / 255.0

            mlp = elegy.nn.sequential(
                elegy.nn.Flatten(),
                elegy.nn.Linear(self.n1),
                jax.nn.relu,
                elegy.nn.Linear(self.n2),
                jax.nn.relu,
                elegy.nn.Linear(10),
            )

            return mlp(image)
Example #18
    def __call__(self, image: jnp.ndarray):
        x = image.astype(jnp.float32) / 255.0
        x = eg.Flatten()(x)
        x = eg.Linear(self.n1)(x)
        x = jax.nn.relu(x)
        x = eg.Linear(self.n2)(x)
        x = jax.nn.relu(x)
        x = eg.Linear(self.n1)(x)
        x = jax.nn.relu(x)
        x = eg.Linear(np.prod(image.shape[-2:]))(x)
        x = jax.nn.sigmoid(x) * 255
        x = x.reshape(image.shape)

        return x
Example #19
 def call(self, image: jnp.ndarray):
     image = image.astype(jnp.float32) / 255.0
     x = elegy.nn.Flatten()(image)
     x = elegy.nn.sequential(
         elegy.nn.Linear(self.n1),
         jax.nn.relu,
         elegy.nn.Linear(self.n2),
         jax.nn.relu,
         elegy.nn.Linear(self.n1),
         jax.nn.relu,
         elegy.nn.Linear(x.shape[-1]),
         jax.nn.sigmoid,
     )(x)
     return x.reshape(image.shape) * 255
Example #20
    def call(self, x: jnp.ndarray):
        batch_size = x.shape[0]

        # normalize data
        x = x.astype(jnp.float32) / 255.0

        # make patch embeddings
        x = einops.rearrange(x,
                             "batch (h1 h2) (w1 w2) -> batch (h1 w1) (h2 w2)",
                             h2=7,
                             w2=7)
        x = elegy.nn.Linear(self.size)(x)

        # add predict token
        predict_token = jnp.zeros(shape=[batch_size, 1, self.size])
        x = jnp.concatenate([predict_token, x], axis=1)

        # create positional embeddings
        positional_embeddings = self.add_parameter(
            "positional_embeddings",
            lambda: elegy.initializers.TruncatedNormal()
            (x.shape[-2:], jnp.float32),
        )
        positional_embeddings = einops.repeat(
            positional_embeddings,
            "... -> batch ...",
            batch=batch_size,
        )

        # add positional embeddings
        x = x + positional_embeddings

        # apply N transformers encoder layers
        x = elegy.nn.transformers.TransformerEncoder(
            lambda: elegy.nn.transformers.TransformerEncoderLayer(
                head_size=self.size,
                num_heads=self.num_heads,
                dropout=self.dropout,
            ),
            num_layers=self.num_layers,
        )(x)

        # get predict output token
        x = x[:, 0]

        # apply predict head
        logits = elegy.nn.Linear(10)(x)

        return logits
Example #21
    def __call__(self, state: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
        kernel_initializer = jax.nn.initializers.glorot_uniform()

        # Preprocess inputs
        a = action.reshape(-1)  # flatten
        x = state.astype(jnp.float32)
        x = x.reshape(-1)  # flatten
        x = jnp.concatenate((x, a))

        for _ in range(self.num_layers):
            x = nn.Dense(features=self.hidden_units,
                         kernel_init=kernel_initializer)(x)
            x = nn.relu(x)

        return nn.Dense(features=1, kernel_init=kernel_initializer)(x)
Example #22
def _samplewise_log_loss(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> jnp.ndarray:
    """Based on: https://github.com/scikit-learn/scikit-learn/blob/ffbb1b4a0bbb58fdca34a30856c6f7faace87c67/sklearn
    /metrics/_classification.py#L2123"""
    if y_true.ndim == 0:  # Scalar input: binary classification problem
        y_true = y_true.reshape(1)[:, jnp.newaxis]
        y_pred = y_pred.reshape(1)[:, jnp.newaxis]
    if y_true.shape[0] == 1:  # Reshuffle data to compute log loss correctly
        y_true = jnp.append(1 - y_true, y_true)
        y_pred = jnp.append(1 - y_pred, y_pred)

    # Clipping
    eps = 1e-15
    y_pred = y_pred.astype(jnp.float32).clip(eps, 1 - eps)

    loss = (y_true * -jnp.log(y_pred)).sum()
    return loss
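For instance, with a one-hot label over three classes the loss reduces to the negative log of the probability assigned to the true class:

```python
import jax.numpy as jnp

y_true = jnp.array([0.0, 1.0, 0.0])
y_pred = jnp.array([0.2, 0.7, 0.1])

loss = _samplewise_log_loss(y_true, y_pred)  # -log(0.7) ≈ 0.3567
```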
Example #23
    def actor(self, state: jnp.ndarray, key: jnp.ndarray) -> SacActorOutput:
        """Calls the SAC actor network.

    This can be called using network_def.apply(..., method=network_def.actor).

    Args:
      state: An input state.
      key: A PRNGKey to use to sample an action from the actor's output
        distribution.

    Returns:
      A named tuple containing a sampled action, the mean action, and the
        likelihood of the sampled action.
    """
        # Preprocess inputs
        x = state.astype(jnp.float32)
        x = x.reshape(-1)  # flatten

        for layer in self._actor_layers:
            x = layer(x)
            x = nn.relu(x)

        # Note we are only producing a diagonal covariance matrix, not a full
        # covariance matrix as it is difficult to ensure that it would be PSD.
        loc_and_scale_diag = self._actor_final_layer(x)
        loc, scale_diag = jnp.split(loc_and_scale_diag, 2)
        # Exponentiate to only get positive terms.
        scale_diag = jnp.exp(scale_diag)
        dist = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale_diag)
        if self.action_limits is None:
            mode = dist.mode()
        else:
            lower_action_limit = jnp.asarray(self.action_limits[0],
                                             dtype=jnp.float32)
            upper_action_limit = jnp.asarray(self.action_limits[1],
                                             dtype=jnp.float32)
            mean = (lower_action_limit + upper_action_limit) / 2.0
            magnitude = (upper_action_limit - lower_action_limit) / 2.0
            mode = magnitude * jnp.tanh(dist.mode()) + mean
            dist = _transform_distribution(dist, mean, magnitude)
        sampled_action = dist.sample(seed=key)
        action_probability = dist.log_prob(sampled_action)

        mode = jnp.reshape(mode, self.action_shape)
        sampled_action = jnp.reshape(sampled_action, self.action_shape)

        return SacActorOutput(mode, sampled_action, action_probability)
Example #24
    def __call__(self, x: jnp.ndarray):
        # normalize
        x = x.astype(jnp.float32) / 255.0

        # base
        x = ConvBlock()(x, 32, (3, 3))
        x = ConvBlock()(x, 64, (3, 3), stride=2)
        x = ConvBlock()(x, 64, (3, 3), stride=2)
        x = ConvBlock()(x, 128, (3, 3), stride=2)

        # GlobalAveragePooling2D
        x = jnp.mean(x, axis=(1, 2))

        # 1x1 Conv
        x = eg.Linear(10)(x)

        return x
Example #25
def learning_schedule(global_step: jnp.ndarray,
                      batch_size: int,
                      base_learning_rate: float,
                      total_steps: int,
                      warmup_steps: int) -> float:
  """Cosine learning rate scheduler."""
  # Compute LR & Scaled LR
  scaled_lr = base_learning_rate * batch_size / 256.
  learning_rate = (
      global_step.astype(jnp.float32) / int(warmup_steps) *
      scaled_lr if warmup_steps > 0 else scaled_lr)

  # Cosine schedule after warmup.
  return jnp.where(
      global_step < warmup_steps, learning_rate,
      _cosine_decay(global_step - warmup_steps, total_steps - warmup_steps,
                    scaled_lr))
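`_cosine_decay` is defined elsewhere in the source; a typical implementation consistent with how it is called here (an assumption, not the original helper) would be:

```python
import jax.numpy as jnp

def _cosine_decay(step: jnp.ndarray, decay_steps: int, initial_lr: float) -> jnp.ndarray:
    # Decay initial_lr to 0 over decay_steps along half a cosine period.
    progress = jnp.clip(step.astype(jnp.float32) / decay_steps, 0.0, 1.0)
    return initial_lr * 0.5 * (1.0 + jnp.cos(jnp.pi * progress))
```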
Example #26
    def _encode_bow(self, bow: jnp.ndarray) -> jnp.ndarray:
        """Encode the bag-of-words into tensors that can be used by the transormer.

    Args:
      bow: a [batch_size, bow_vocab_size] tensor, each row is a bow vector.

    Returns:
      embeddings: [batch_size, bow_n_tokens, bow_embedding_dim] tensor.
    """
        batch_size = bow.shape[0]
        bow = bow.astype(jnp.float32)

        # [B, D * n]
        embeddings = hk.Linear(self._bow_embedding_dim *
                               self._bow_n_tokens)(bow)
        embeddings = transformer_block.layer_norm(jax.nn.gelu(embeddings))
        return jnp.reshape(
            embeddings,
            [batch_size, self._bow_n_tokens, self._bow_embedding_dim])
Example #27
    def __call__(self, x: jnp.ndarray) -> VAEOutput:
        x = x.astype(jnp.float32)

        # q(z|x) = N(mean(x), covariance(x))
        mean, stddev = Encoder(self._hidden_size, self._latent_size)(x)
        variational_distrib = distrax.MultivariateNormalDiag(loc=mean,
                                                             scale_diag=stddev)
        z = variational_distrib.sample(seed=hk.next_rng_key())

        # p(x|z) = \Prod Bernoulli(logits(z))
        logits = Decoder(self._hidden_size, self._output_shape)(z)
        likelihood_distrib = distrax.Independent(
            distrax.Bernoulli(logits=logits),
            reinterpreted_batch_ndims=len(
                self._output_shape))  # 3 non-batch dims

        # Generate images from the likelihood
        image = likelihood_distrib.sample(seed=hk.next_rng_key())

        return VAEOutput(variational_distrib, likelihood_distrib, image)
Example #28
def learning_schedule(
    global_step: jnp.ndarray,
    base_learning_rate: float,
    total_steps: int,
    warmup_steps: int,
    use_schedule: bool,
) -> float:
    """Cosine learning rate scheduler."""
    # Compute LR & Scaled LR
    if not use_schedule:
        return base_learning_rate
    warmup_learning_rate = (global_step.astype(jnp.float32) /
                            int(warmup_steps) * base_learning_rate
                            if warmup_steps > 0 else base_learning_rate)

    # Cosine schedule after warmup.
    decay_learning_rate = _cosine_decay(global_step - warmup_steps,
                                        total_steps - warmup_steps,
                                        base_learning_rate)
    return jnp.where(global_step < warmup_steps, warmup_learning_rate,
                     decay_learning_rate)
Example #29
def mean_squared_logarithmic_error(y_true: jnp.ndarray,
                                   y_pred: jnp.ndarray) -> jnp.ndarray:
    """
    Computes the mean squared logarithmic error between labels and predictions.

    ```python
    loss = mean(square(log(y_true + 1) - log(y_pred + 1)), axis=-1)
    ```

    Usage:

    ```python
    rng = jax.random.PRNGKey(42)

    y_true = jax.random.randint(rng, shape=(2, 3), minval=0, maxval=2)
    y_pred = jax.random.uniform(rng, shape=(2, 3))

    loss = elegy.losses.mean_squared_logarithmic_error(y_true, y_pred)

    assert loss.shape == (2,)

    first_log = jnp.log(jnp.maximum(y_true, types.EPSILON) + 1.0)
    second_log = jnp.log(jnp.maximum(y_pred, types.EPSILON) + 1.0)
    assert jnp.array_equal(loss, jnp.mean(jnp.square(first_log - second_log), axis=-1))
    ```

    Arguments:
        y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.

    Returns:
        Mean squared logarithmic error values. shape = `[batch_size, d0, .. dN-1]`.
    """

    y_true = y_true.astype(y_pred.dtype)
    first_log = jnp.log(jnp.maximum(y_true, types.EPSILON) + 1.0)
    second_log = jnp.log(jnp.maximum(y_pred, types.EPSILON) + 1.0)

    return jnp.mean(jnp.square(first_log - second_log), axis=-1)
Example #30
def mean_squared_error(y_true: jnp.ndarray,
                       y_pred: jnp.ndarray) -> jnp.ndarray:
    """
    Computes the mean squared error between labels and predictions.
    
    After computing the squared distance between the inputs, the mean value over
    the last dimension is returned.
    
    ```python
    loss = mean(square(y_true - y_pred), axis=-1)
    ```

    Usage:
    
    ```python
    rng = jax.random.PRNGKey(42)

    y_true = jax.random.randint(rng, shape=(2, 3), minval=0, maxval=2)
    y_pred = jax.random.uniform(rng, shape=(2, 3))

    loss = elegy.losses.mean_squared_error(y_true, y_pred)

    assert loss.shape == (2,)

    assert jnp.array_equal(loss, jnp.mean(jnp.square(y_true - y_pred), axis=-1))
    ```
    
    Arguments:
        y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
    
    Returns:
        Mean squared error values. shape = `[batch_size, d0, .. dN-1]`.
    """

    y_true = y_true.astype(y_pred.dtype)

    return jnp.mean(jnp.square(y_pred - y_true), axis=-1)