Code Example #1
File: modeling.py Project: nikitakit/sabertooth
    def __call__(
        self,
        input_ids: jnp.ndarray,
        input_mask: jnp.ndarray,
        type_ids: jnp.ndarray,
        *,
        deterministic: bool = False,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Applies BERT model on the inputs."""

        word_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(input_ids)
        type_embeddings = self.type_embeddings(type_ids)

        embeddings = word_embeddings + position_embeddings + type_embeddings
        embeddings = self.embeddings_layer_norm(embeddings)
        embeddings = self.embeddings_dropout(embeddings,
                                             deterministic=deterministic)

        hidden_states = embeddings

        mask = input_mask.astype(jnp.int32)
        for transformer_block in self.encoder_layers:
            hidden_states = transformer_block(hidden_states,
                                              mask,
                                              deterministic=deterministic)
        pooled_output = self.pooler(hidden_states[:, 0])
        pooled_output = jnp.tanh(pooled_output)

        return hidden_states, pooled_output
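The pooler above acts only on the first token of the encoder output. A small standalone sketch (shapes assumed, not taken from sabertooth) of what `hidden_states[:, 0]` selects:

```python
# Standalone sketch with assumed shapes: the encoder output is
# (batch, seq_len, hidden); indexing [:, 0] keeps the first ([CLS]) token.
import jax.numpy as jnp

hidden_states = jnp.zeros((8, 128, 768))   # (batch, seq_len, hidden)
cls_tokens = hidden_states[:, 0]           # (8, 768)
assert cls_tokens.shape == (8, 768)
```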
Code Example #2
def mean_absolute_percentage_error(y_true: jnp.ndarray,
                                   y_pred: jnp.ndarray) -> jnp.ndarray:
    """
    Computes the mean absolute percentage error (MAPE) between labels and predictions.

    After computing the absolute difference between the true values and the predictions
    and dividing by the true values, the mean over the last dimension is returned.

    Usage:

    ```python
    rng = jax.random.PRNGKey(42)

    y_true = jax.random.randint(rng, shape=(2, 3), minval=0, maxval=2)
    y_pred = jax.random.uniform(rng, shape=(2, 3))

    loss = elegy.losses.mean_absolute_percentage_error(y_true, y_pred)

    assert loss.shape == (2,)

    assert jnp.array_equal(loss, 100. * jnp.mean(jnp.abs((y_pred - y_true) / jnp.clip(y_true, types.EPSILON, None)), axis=-1))
    ```

    Arguments:
        y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.

    Returns:
        Mean absolute percentage error values. shape = `[batch_size, d0, .. dN-1]`.
    """
    y_true = y_true.astype(y_pred.dtype)
    diff = jnp.abs(
        (y_pred - y_true) / jnp.maximum(jnp.abs(y_true), types.EPSILON))
    return 100.0 * jnp.mean(diff, axis=-1)
Code Example #3
File: mnist_conv.py Project: poets-ai/elegy
    def __call__(self, x: jnp.ndarray, training: bool) -> jnp.ndarray:
        # Normalize the input
        x = x.astype(jnp.float32) / 255.0

        # Block 1
        x = linen.Conv(32, [3, 3], strides=[2, 2])(x)
        x = linen.Dropout(0.05, deterministic=not training)(x)
        x = jax.nn.relu(x)

        # Block 2
        x = linen.Conv(64, [3, 3], strides=[2, 2])(x)
        x = linen.BatchNorm(use_running_average=not training)(x)
        x = linen.Dropout(0.1, deterministic=not training)(x)
        x = jax.nn.relu(x)

        # Block 3
        x = linen.Conv(128, [3, 3], strides=[2, 2])(x)

        # Global average pooling
        x = x.mean(axis=(1, 2))

        # Classification layer
        x = linen.Dense(10)(x)

        return x
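Assuming MNIST-style inputs of shape `(batch, 28, 28, 1)`, each stride-2 convolution roughly halves the spatial size (28 → 14 → 7 → 4 with SAME padding) and the global average pooling collapses the remaining spatial axes. A minimal sketch of that pooling step under those assumed shapes:

```python
# Sketch with assumed shapes: pooling the Block 3 output over height and width.
import jax.numpy as jnp

x = jnp.zeros((32, 4, 4, 128))   # assumed output of Block 3 for a batch of 32
pooled = x.mean(axis=(1, 2))     # (32, 128), fed into the final Dense(10)
assert pooled.shape == (32, 128)
```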
Code Example #4
        def call(self, image: jnp.ndarray, training: bool):
            @elegy.to_module
            def ConvBlock(x, units, kernel, stride=1):
                x = elegy.nn.Conv2D(units,
                                    kernel,
                                    stride=stride,
                                    padding="same")(x)
                x = elegy.nn.BatchNormalization()(x, training)
                x = elegy.nn.Dropout(0.2)(x, training)
                return jax.nn.relu(x)

            x: jnp.ndarray = image.astype(jnp.float32) / 255.0

            # base
            x = ConvBlock()(x, 32, [3, 3])
            x = ConvBlock()(x, 64, [3, 3], stride=2)
            x = ConvBlock()(x, 64, [3, 3], stride=2)
            x = ConvBlock()(x, 128, [3, 3], stride=2)

            # GlobalAveragePooling2D
            x = jnp.mean(x, axis=[1, 2])

            # 1x1 Conv
            x = elegy.nn.Linear(10)(x)

            return x
Code Example #5
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # Normalize the input
        x = x.astype(jnp.float32) / 255.0

        # Block 1
        x = eg.Conv(32, [3, 3], strides=[2, 2])(x)
        x = eg.Dropout(0.05)(x)
        x = jax.nn.relu(x)

        # Block 2
        x = eg.Conv(64, [3, 3], strides=[2, 2])(x)
        x = eg.BatchNorm()(x)
        x = eg.Dropout(0.1)(x)
        x = jax.nn.relu(x)

        # Block 3
        x = eg.Conv(128, [3, 3], strides=[2, 2])(x)

        # Global average pooling
        x = x.mean(axis=(1, 2))

        # Classification layer
        x = eg.Linear(10)(x)

        return x
Code Example #6
def precision(
    y_true: jnp.ndarray,
    y_pred: jnp.ndarray,
    threshold: jnp.ndarray,
    class_id: jnp.ndarray,
    sample_weight: jnp.ndarray,
    true_positives: ReduceConfusionMatrix,
    false_positives: ReduceConfusionMatrix,
) -> jnp.ndarray:

    # TODO: class_id behavior
    y_pred = (y_pred > threshold).astype(jnp.float32)

    if y_true.dtype != y_pred.dtype:
        y_pred = y_pred.astype(y_true.dtype)

    true_positives = true_positives(y_true=y_true,
                                    y_pred=y_pred,
                                    sample_weight=sample_weight)
    false_positives = false_positives(y_true=y_true,
                                      y_pred=y_pred,
                                      sample_weight=sample_weight)

    return jnp.nan_to_num(
        jnp.divide(true_positives, true_positives + false_positives))
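A minimal usage sketch. `ReduceConfusionMatrix` is not shown here, so the two plain functions below stand in for the true-positive and false-positive reducers (an assumption about their call signature, based only on how they are invoked above):

```python
import jax.numpy as jnp

def count_true_positives(y_true, y_pred, sample_weight=None):
    # Assumed stand-in reducer: count predictions of 1 on positive labels.
    return jnp.sum((y_true == 1) & (y_pred == 1))

def count_false_positives(y_true, y_pred, sample_weight=None):
    # Assumed stand-in reducer: count predictions of 1 on negative labels.
    return jnp.sum((y_true == 0) & (y_pred == 1))

y_true = jnp.array([1, 0, 1, 1])
y_pred = jnp.array([0.9, 0.8, 0.2, 0.6])

p = precision(
    y_true,
    y_pred,
    threshold=0.5,
    class_id=None,
    sample_weight=None,
    true_positives=count_true_positives,
    false_positives=count_false_positives,
)
# TP = 2 (indices 0 and 3), FP = 1 (index 1), so p == 2 / 3
```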
Code Example #7
File: swarm_layer.py Project: kingoflolz/swarm-jax
def quantize(x: jnp.ndarray, to_type: str):
    assert to_type in ["float16", "float32", "uint16", "uint8"]

    if "int" in to_type:
        max_int = 2 ** 8 - 1 if to_type == "uint8" else 2 ** 16 - 1
        return to_type, int_quantize_jit(x, max_int, to_type)
    else:
        return to_type, x.astype(to_type)
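A minimal usage sketch of the floating-point path. The integer path relies on `int_quantize_jit`, which is defined elsewhere in swarm_layer.py and is not reproduced here, so only the cast branch is exercised:

```python
import jax.numpy as jnp

x = jnp.ones((2, 3), dtype=jnp.float32)
dtype_name, x_q = quantize(x, "float16")   # float branch: a plain dtype cast
assert dtype_name == "float16"
assert x_q.dtype == jnp.float16
```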
Code Example #8
    def fit(
        self,
        df: pd.DataFrame,
        sampler: str = "NUTS",
        rng_key: np.ndarray = None,
        sampler_kwargs: typing.Dict[str, typing.Any] = None,
        **mcmc_kwargs,
    ):
        """Fit the model to a DataFrame.
        
        Parameters
        ----------
        df : pd.DataFrame
            Source dataframe.
        sampler : str
            Numpyro sampler name. Default NUTS
        rng_key : two-element ndarray
            Optional rng key; will be randomly split if not provided.
        sampler_kwargs :
            Passed to the numpyro sampler selected.
        **mcmc_kwargs :
            Passed to numpyro.infer.MCMC

        Returns
        -------
        Model
            The fitted model.

        """
        if self.fitted:
            raise exceptions.AlreadyFittedError(self)

        if sampler.upper() not in ("NUTS", "HMC"):
            raise ValueError("Invalid sampler, try NUTS or HMC.")

        sampler = getattr(infer, sampler.upper())

        # store fit df
        self.df = df

        # set up mcmc
        _mcmc_kwargs = dict(num_warmup=500, num_samples=1000)
        _mcmc_kwargs.update(mcmc_kwargs)
        _sampler_kwargs = dict(model=self.model)
        _sampler_kwargs.update(sampler_kwargs or {})
        mcmc = infer.MCMC(sampler(**_sampler_kwargs), **_mcmc_kwargs)

        # do it
        rng_key_ = (self.split_rand_key()
                    if rng_key is None else rng_key.astype("uint32"))
        mcmc.run(rng_key_, df=df)

        # store results
        self.samples = mcmc.get_samples(group_by_chain=True)
        self.fitted = True

        return self
Code Example #9
File: mnist.py Project: poets-ai/elegy
    def __call__(self, x: jnp.ndarray):
        x = x.astype(jnp.float32) / 255.0
        x = einops.rearrange(x, "batch ... -> batch (...)")
        x = eg.nn.Linear(self.n1)(x)
        x = jax.nn.relu(x)
        x = eg.nn.Linear(self.n2)(x)
        x = jax.nn.relu(x)
        x = eg.nn.Linear(10)(x)
        return x
Code Example #10
    def sample_posterior_predictive(
        self,
        df: pd.DataFrame,
        hdpi: bool = False,
        hdpi_interval: float = 0.9,
        rng_key: np.ndarray = None,
    ) -> typing.Union[pd.Series, pd.DataFrame]:
        """Obtain samples from the posterior predictive.
        
        Parameters
        ----------
        df : pd.DataFrame
            Source dataframe.
        hdpi : bool
            Option to include lower/upper bound of the highest posterior density 
            interval. Returns a dataframe if true, a series if false. Default False.
        hdpi_interval : float
            HDPI width. Default 0.9.
        rng_key : two-element ndarray
            Optional rng key; will be randomly split if not provided.

        Returns
        -------
        pd.Series or pd.DataFrame
            Forecasts. Will be a series with the name of the dv if no HDPI. Will be a
            dataframe if HDPI is included.

        """
        # get rng key
        rng_key_ = (self.split_rand_key()
                    if rng_key is None else rng_key.astype("uint32"))

        # check for nulls
        null_cols = columns_with_null_data(self.transform(df))
        if null_cols:
            raise exceptions.NullDataFound(*null_cols)

        #  do it
        predictions = infer.Predictive(self.model,
                                       self.samples_flat)(rng_key_,
                                                          df=df)[self.dv]

        if not hdpi:
            return pd.Series(predictions.mean(axis=0),
                             index=df.index,
                             name=self.dv)

        hdpi = diagnostics.hpdi(predictions, hdpi_interval)

        return pd.DataFrame(
            {
                self.dv: predictions.mean(axis=0),
                "hdpi_lower": hdpi[0, :],
                "hdpi_upper": hdpi[1, :],
            },
            index=df.index,
        )
Code Example #11
File: accuracy.py Project: stjordanis/elegy
def accuracy(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> jnp.ndarray:
    # [y_pred, y_true], _ = metrics_utils.ragged_assert_compatible_and_get_flat_values(
    #     [y_pred, y_true]
    # )
    # y_pred.shape.assert_is_compatible_with(y_true.shape)

    if y_true.dtype != y_pred.dtype:
        y_pred = y_pred.astype(y_true.dtype)

    return (y_true == y_pred).astype(jnp.float32)
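A minimal usage sketch: the function returns per-element 0/1 scores, which are typically averaged afterwards:

```python
import jax.numpy as jnp

y_true = jnp.array([0, 1, 2, 1])
y_pred = jnp.array([0, 2, 2, 1])

per_sample = accuracy(y_true, y_pred)   # [1., 0., 1., 1.]
mean_accuracy = per_sample.mean()       # 0.75
```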
Code Example #12
    def __call__(self, x: jnp.ndarray) -> VAEOutput:
        x = x.astype(jnp.float32)
        mean, stddev = Encoder(self._hidden_size, self._latent_size)(x)
        z = mean + stddev * jax.random.normal(hk.next_rng_key(), mean.shape)
        logits = Decoder(self._hidden_size, self._output_shape)(z)

        p = jax.nn.sigmoid(logits)
        image = jax.random.bernoulli(hk.next_rng_key(), p)

        return VAEOutput(image, mean, stddev, logits)
Code Example #13
File: mnist_vae.py Project: poets-ai/elegy
    def __call__(self, x: jnp.ndarray) -> dict:
        next_key = eg.KeySeq()
        x = x.astype(jnp.float32)

        z = Encoder(self.hidden_size, self.latent_size)(x)
        logits = Decoder(self.hidden_size, self.output_shape)(z)

        p = jax.nn.sigmoid(logits)
        image = jax.random.bernoulli(next_key(), p)

        return dict(image=image, logits=logits, det_image=p)
Code Example #14
File: mnist_dataloader.py Project: poets-ai/elegy
        def __call__(self, x: jnp.ndarray):
            x = x.astype(jnp.float32) / 255.0

            x = eg.Flatten()(x)
            x = eg.Linear(self.n1)(x)
            x = jax.nn.relu(x)
            x = eg.Linear(self.n2)(x)
            x = jax.nn.relu(x)
            x = eg.Linear(10)(x)

            return x
Code Example #15
File: model_test.py Project: Dave0995/elegy
    def call(self, image: jnp.ndarray, training: bool):
        x = image.astype(jnp.float32) / 255.0

        x = jnp.reshape(x, [x.shape[0], -1])
        x = elegy.nn.Linear(self.n1)(x)
        x = elegy.nn.BatchNormalization()(x)
        x = jax.nn.relu(x)

        x = elegy.nn.Linear(self.n2)(x)
        x = jax.nn.relu(x)
        x = elegy.nn.Linear(10)(x)

        return x
Code Example #16
File: mnist.py Project: stjordanis/elegy
        def call(self, image: jnp.ndarray):

            image = image.astype(jnp.float32) / 255.0

            mlp = hk.Sequential([
                hk.Flatten(),
                hk.Linear(self.n1),
                jax.nn.relu,
                hk.Linear(self.n2),
                jax.nn.relu,
                hk.Linear(10),
            ])
            return dict(outputs=mlp(image))
Code Example #17
File: mnist_test.py Project: anvelezec/elegy
        def call(self, image: jnp.ndarray):
            image = image.astype(jnp.float32) / 255.0

            mlp = elegy.nn.sequential(
                elegy.nn.Flatten(),
                elegy.nn.Linear(self.n1),
                jax.nn.relu,
                elegy.nn.Linear(self.n2),
                jax.nn.relu,
                elegy.nn.Linear(10),
            )

            return mlp(image)
Code Example #18
    def __call__(self, image: jnp.ndarray):
        x = image.astype(jnp.float32) / 255.0
        x = eg.Flatten()(x)
        x = eg.Linear(self.n1)(x)
        x = jax.nn.relu(x)
        x = eg.Linear(self.n2)(x)
        x = jax.nn.relu(x)
        x = eg.Linear(self.n1)(x)
        x = jax.nn.relu(x)
        x = eg.Linear(np.prod(image.shape[-2:]))(x)
        x = jax.nn.sigmoid(x) * 255
        x = x.reshape(image.shape)

        return x
Code Example #19
File: mnist_autoencoder.py Project: anvelezec/elegy
    def call(self, image: jnp.ndarray):
        image = image.astype(jnp.float32) / 255.0
        x = elegy.nn.Flatten()(image)
        x = elegy.nn.sequential(
            elegy.nn.Linear(self.n1),
            jax.nn.relu,
            elegy.nn.Linear(self.n2),
            jax.nn.relu,
            elegy.nn.Linear(self.n1),
            jax.nn.relu,
            elegy.nn.Linear(x.shape[-1]),
            jax.nn.sigmoid,
        )(x)
        return x.reshape(image.shape) * 255
Code Example #20
    def call(self, x: jnp.ndarray):
        batch_size = x.shape[0]

        # normalize data
        x = x.astype(jnp.float32) / 255.0

        # make patch embeddings
        x = einops.rearrange(x,
                             "batch (h1 h2) (w1 w2) -> batch (h1 w1) (h2 w2)",
                             h2=7,
                             w2=7)
        x = elegy.nn.Linear(self.size)(x)

        # add predict token
        predict_token = jnp.zeros(shape=[batch_size, 1, self.size])
        x = jnp.concatenate([predict_token, x], axis=1)

        # create positional embeddings
        positional_embeddings = self.add_parameter(
            "positional_embeddings",
            lambda: elegy.initializers.TruncatedNormal()
            (x.shape[-2:], jnp.float32),
        )
        positional_embeddings = einops.repeat(
            positional_embeddings,
            "... -> batch ...",
            batch=batch_size,
        )

        # add positional embeddings
        x = x + positional_embeddings

        # apply N transformers encoder layers
        x = elegy.nn.transformers.TransformerEncoder(
            lambda: elegy.nn.transformers.TransformerEncoderLayer(
                head_size=self.size,
                num_heads=self.num_heads,
                dropout=self.dropout,
            ),
            num_layers=self.num_layers,
        )(x)

        # get predict output token
        x = x[:, 0]

        # apply predict head
        logits = elegy.nn.Linear(10)(x)

        return logits
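The rearrange above splits each image into non-overlapping 7x7 patches and flattens every patch into one vector. A small standalone sketch with an assumed 28x28 input:

```python
import einops
import jax.numpy as jnp

x = jnp.zeros((2, 28, 28))
patches = einops.rearrange(
    x, "batch (h1 h2) (w1 w2) -> batch (h1 w1) (h2 w2)", h2=7, w2=7)
assert patches.shape == (2, 16, 49)   # 4x4 = 16 patches of 7x7 = 49 pixels
```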
Code Example #21
    def __call__(self, state: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
        kernel_initializer = jax.nn.initializers.glorot_uniform()

        # Preprocess inputs
        a = action.reshape(-1)  # flatten
        x = state.astype(jnp.float32)
        x = x.reshape(-1)  # flatten
        x = jnp.concatenate((x, a))

        for _ in range(self.num_layers):
            x = nn.Dense(features=self.hidden_units,
                         kernel_init=kernel_initializer)(x)
            x = nn.relu(x)

        return nn.Dense(features=1, kernel_init=kernel_initializer)(x)
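A small standalone illustration (shapes assumed for the example) of the preprocessing above: state and action are flattened and concatenated into a single feature vector before the dense stack:

```python
import jax.numpy as jnp

state = jnp.zeros((4, 3))    # assumed observation, e.g. a stacked frame
action = jnp.zeros((2,))     # assumed 2-dimensional continuous action

x = jnp.concatenate((state.astype(jnp.float32).reshape(-1), action.reshape(-1)))
assert x.shape == (14,)      # 4 * 3 state features + 2 action features
```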
Code Example #22
def _samplewise_log_loss(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> jnp.ndarray:
    """Based on: https://github.com/scikit-learn/scikit-learn/blob/ffbb1b4a0bbb58fdca34a30856c6f7faace87c67/sklearn
    /metrics/_classification.py#L2123"""
    if y_true.ndim == 0:  # If no dimension binary classification problem
        y_true = y_true.reshape(1)[:, jnp.newaxis]
        y_pred = y_pred.reshape(1)[:, jnp.newaxis]
    if y_true.shape[0] == 1:  # Reshuffle data to compute log loss correctly
        y_true = jnp.append(1 - y_true, y_true)
        y_pred = jnp.append(1 - y_pred, y_pred)

    # Clipping
    eps = 1e-15
    y_pred = y_pred.astype(jnp.float32).clip(eps, 1 - eps)

    loss = (y_true * -jnp.log(y_pred)).sum()
    return loss
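A minimal usage sketch of the zero-dimensional (single binary label) case handled by the first branch above:

```python
import jax.numpy as jnp

y_true = jnp.asarray(1.0)    # scalar binary label
y_pred = jnp.asarray(0.8)    # predicted probability of the positive class

loss = _samplewise_log_loss(y_true, y_pred)
# Internally expanded to y_true = [0., 1.] and y_pred = [0.2, 0.8],
# so loss == -log(0.8) ≈ 0.223
```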
Code Example #23
    def actor(self, state: jnp.ndarray, key: jnp.ndarray) -> SacActorOutput:
        """Calls the SAC actor network.

    This can be called using network_def.apply(..., method=network_def.actor).

    Args:
      state: An input state.
      key: A PRNGKey to use to sample an action from the actor's output
        distribution.

    Returns:
      A named tuple containing a sampled action, the mean action, and the
        likelihood of the sampled action.
    """
        # Preprocess inputs
        x = state.astype(jnp.float32)
        x = x.reshape(-1)  # flatten

        for layer in self._actor_layers:
            x = layer(x)
            x = nn.relu(x)

        # Note we are only producing a diagonal covariance matrix, not a full
        # covariance matrix as it is difficult to ensure that it would be PSD.
        loc_and_scale_diag = self._actor_final_layer(x)
        loc, scale_diag = jnp.split(loc_and_scale_diag, 2)
        # Exponentiate to only get positive terms.
        scale_diag = jnp.exp(scale_diag)
        dist = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale_diag)
        if self.action_limits is None:
            mode = dist.mode()
        else:
            lower_action_limit = jnp.asarray(self.action_limits[0],
                                             dtype=jnp.float32)
            upper_action_limit = jnp.asarray(self.action_limits[1],
                                             dtype=jnp.float32)
            mean = (lower_action_limit + upper_action_limit) / 2.0
            magnitude = (upper_action_limit - lower_action_limit) / 2.0
            mode = magnitude * jnp.tanh(dist.mode()) + mean
            dist = _transform_distribution(dist, mean, magnitude)
        sampled_action = dist.sample(seed=key)
        action_probability = dist.log_prob(sampled_action)

        mode = jnp.reshape(mode, self.action_shape)
        sampled_action = jnp.reshape(sampled_action, self.action_shape)

        return SacActorOutput(mode, sampled_action, action_probability)
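A small numeric illustration (limits assumed for the example) of the squashing applied when `action_limits` is set: `tanh` maps the unbounded distribution mode into the `[lower, upper]` interval:

```python
import jax.numpy as jnp

lower, upper = jnp.float32(-1.0), jnp.float32(2.0)   # assumed action limits
mean = (lower + upper) / 2.0        # 0.5
magnitude = (upper - lower) / 2.0   # 1.5

raw_mode = jnp.float32(10.0)                      # far outside the limits
squashed = magnitude * jnp.tanh(raw_mode) + mean  # ~2.0, inside [-1, 2]
```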
Code Example #24
    def __call__(self, x: jnp.ndarray):
        # normalize
        x = x.astype(jnp.float32) / 255.0

        # base
        x = ConvBlock()(x, 32, (3, 3))
        x = ConvBlock()(x, 64, (3, 3), stride=2)
        x = ConvBlock()(x, 64, (3, 3), stride=2)
        x = ConvBlock()(x, 128, (3, 3), stride=2)

        # GlobalAveragePooling2D
        x = jnp.mean(x, axis=(1, 2))

        # 1x1 Conv
        x = eg.Linear(10)(x)

        return x
Code Example #25
File: schedules.py Project: zyc00/deepmind-research
def learning_schedule(global_step: jnp.ndarray,
                      batch_size: int,
                      base_learning_rate: float,
                      total_steps: int,
                      warmup_steps: int) -> float:
  """Cosine learning rate scheduler."""
  # Compute LR & Scaled LR
  scaled_lr = base_learning_rate * batch_size / 256.
  learning_rate = (
      global_step.astype(jnp.float32) / int(warmup_steps) *
      scaled_lr if warmup_steps > 0 else scaled_lr)

  # Cosine schedule after warmup.
  return jnp.where(
      global_step < warmup_steps, learning_rate,
      _cosine_decay(global_step - warmup_steps, total_steps - warmup_steps,
                    scaled_lr))
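A minimal usage sketch. `_cosine_decay` is defined elsewhere in schedules.py; the stand-in below uses the usual cosine-decay form only so the call is runnable, and is an assumption rather than the original helper:

```python
import jax.numpy as jnp

def _cosine_decay(step, decay_steps, initial_value):
    # Assumed stand-in for the real helper in schedules.py.
    frac = jnp.clip(step.astype(jnp.float32) / decay_steps, 0.0, 1.0)
    return initial_value * 0.5 * (1.0 + jnp.cos(jnp.pi * frac))

lr = learning_schedule(
    global_step=jnp.asarray(100),
    batch_size=512,
    base_learning_rate=0.1,
    total_steps=10_000,
    warmup_steps=1_000,
)
# Still in warmup: lr ramps linearly toward the scaled LR of 0.1 * 512 / 256
```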
Code Example #26
    def _encode_bow(self, bow: jnp.ndarray) -> jnp.ndarray:
        """Encode the bag-of-words into tensors that can be used by the transormer.

    Args:
      bow: a [batch_size, bow_vocab_size] tensor, each row is a bow vector.

    Returns:
      embeddings: [batch_size, bow_n_tokens, bow_embedding_dim] tensor.
    """
        batch_size = bow.shape[0]
        bow = bow.astype(jnp.float32)

        # [B, D * n]
        embeddings = hk.Linear(self._bow_embedding_dim *
                               self._bow_n_tokens)(bow)
        embeddings = transformer_block.layer_norm(jax.nn.gelu(embeddings))
        return jnp.reshape(
            embeddings,
            [batch_size, self._bow_n_tokens, self._bow_embedding_dim])
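A shape sketch (sizes assumed) of the reshape above: the linear layer emits one flat vector of size `bow_embedding_dim * bow_n_tokens` per example, which is then split into `bow_n_tokens` separate token embeddings:

```python
import jax.numpy as jnp

batch_size, bow_n_tokens, bow_embedding_dim = 4, 8, 256   # assumed sizes
flat = jnp.zeros((batch_size, bow_embedding_dim * bow_n_tokens))   # [B, D * n]
tokens = jnp.reshape(flat, [batch_size, bow_n_tokens, bow_embedding_dim])
assert tokens.shape == (4, 8, 256)
```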
Code Example #27
File: vae.py Project: stjordanis/distrax
    def __call__(self, x: jnp.ndarray) -> VAEOutput:
        x = x.astype(jnp.float32)

        # q(z|x) = N(mean(x), covariance(x))
        mean, stddev = Encoder(self._hidden_size, self._latent_size)(x)
        variational_distrib = distrax.MultivariateNormalDiag(loc=mean,
                                                             scale_diag=stddev)
        z = variational_distrib.sample(seed=hk.next_rng_key())

        # p(x|z) = \Prod Bernoulli(logits(z))
        logits = Decoder(self._hidden_size, self._output_shape)(z)
        likelihood_distrib = distrax.Independent(
            distrax.Bernoulli(logits=logits),
            reinterpreted_batch_ndims=len(
                self._output_shape))  # 3 non-batch dims

        # Generate images from the likelihood
        image = likelihood_distrib.sample(seed=hk.next_rng_key())

        return VAEOutput(variational_distrib, likelihood_distrib, image)
Code Example #28
def learning_schedule(
    global_step: jnp.ndarray,
    base_learning_rate: float,
    total_steps: int,
    warmup_steps: int,
    use_schedule: bool,
) -> float:
    """Cosine learning rate scheduler."""
    # Compute LR & Scaled LR
    if not use_schedule:
        return base_learning_rate
    warmup_learning_rate = (global_step.astype(jnp.float32) /
                            int(warmup_steps) * base_learning_rate
                            if warmup_steps > 0 else base_learning_rate)

    # Cosine schedule after warmup.
    decay_learning_rate = _cosine_decay(global_step - warmup_steps,
                                        total_steps - warmup_steps,
                                        base_learning_rate)
    return jnp.where(global_step < warmup_steps, warmup_learning_rate,
                     decay_learning_rate)
Code Example #29
def mean_squared_logarithmic_error(y_true: jnp.ndarray,
                                   y_pred: jnp.ndarray) -> jnp.ndarray:
    """
    Computes the mean squared logarithmic error between labels and predictions.

    ```python
    loss = mean(square(log(y_true + 1) - log(y_pred + 1)), axis=-1)
    ```

    Usage:

    ```python
    rng = jax.random.PRNGKey(42)

    y_true = jax.random.randint(rng, shape=(2, 3), minval=0, maxval=2)
    y_pred = jax.random.uniform(rng, shape=(2, 3))

    loss = elegy.losses.mean_squared_logarithmic_error(y_true, y_pred)

    assert loss.shape == (2,)

    first_log = jnp.log(jnp.maximum(y_true, types.EPSILON) + 1.0)
    second_log = jnp.log(jnp.maximum(y_pred, types.EPSILON) + 1.0)
    assert jnp.array_equal(loss, jnp.mean(jnp.square(first_log - second_log), axis=-1))
    ```

    Arguments:
        y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.

    Returns:
        Mean squared logarithmic error values. shape = `[batch_size, d0, .. dN-1]`.
    """

    y_true = y_true.astype(y_pred.dtype)
    first_log = jnp.log(jnp.maximum(y_true, types.EPSILON) + 1.0)
    second_log = jnp.log(jnp.maximum(y_pred, types.EPSILON) + 1.0)

    return jnp.mean(jnp.square(first_log - second_log), axis=-1)
Code Example #30
def mean_squared_error(y_true: jnp.ndarray,
                       y_pred: jnp.ndarray) -> jnp.ndarray:
    """
    Computes the mean squared error between labels and predictions.
    
    After computing the squared distance between the inputs, the mean value over
    the last dimension is returned.
    
    ```python
    loss = mean(square(y_true - y_pred), axis=-1)
    ```

    Usage:
    
    ```python
    rng = jax.random.PRNGKey(42)

    y_true = jax.random.randint(rng, shape=(2, 3), minval=0, maxval=2)
    y_pred = jax.random.uniform(rng, shape=(2, 3))

    loss = elegy.losses.mean_squared_error(y_true, y_pred)

    assert loss.shape == (2,)

    assert jnp.array_equal(loss, jnp.mean(jnp.square(y_true - y_pred), axis=-1))
    ```
    
    Arguments:
        y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
    
    Returns:
        Mean squared error values. shape = `[batch_size, d0, .. dN-1]`.
    """

    y_true = y_true.astype(y_pred.dtype)

    return jnp.mean(jnp.square(y_pred - y_true), axis=-1)