Example #1
0
 def from_model_output(cls,
                       values: jnp.array,
                       mask: Optional[jnp.array] = None,
                       **_) -> Metric:
     """Creates a masked-average metric from a batch of model outputs.

     Args:
       values: Values to average. A 0-d array is treated as a batch of one.
       mask: Optional mask whose leading dimension must match `values`.
         Entries are interpreted as booleans (non-zero = included). A 0-d
         mask is promoted like `values`; a mask with fewer dimensions than
         `values` is broadcast over the trailing dimensions. Defaults to
         including every entry.
       **_: Ignored; lets callers pass extra model outputs.

     Returns:
       An instance of `cls` with `total` (sum of the included values) and
       `count` (number of included values, in the dtype of `values`).

     Raises:
       ValueError: If the leading dimensions of `mask` and `values` differ, or
         if `utils.check_param` rejects the broadcast boolean mask.
     """
     if values.ndim == 0:
         values = values[None]
     if mask is None:
         mask = jnp.ones_like(values)
     mask = jnp.asarray(mask)
     # Promote a 0-d mask the same way as `values`, so the leading-dimension
     # check below raises a clear ValueError instead of an IndexError.
     if mask.ndim == 0:
         mask = mask[None]
     # Leading dimensions of mask and values must match.
     if mask.shape[0] != values.shape[0]:
         raise ValueError(
             "Argument `mask` must have the same leading dimension as `values`. "
             f"Received mask of dimension {mask.shape} "
             f"and values of dimension {values.shape}.")
     # Broadcast mask to the same number of dimensions as values. The axes are
     # plain Python ints: `jnp.arange` would create a JAX array, which is
     # unnecessary here and may not convert to static axis indices when traced.
     if mask.ndim < values.ndim:
         mask = jnp.expand_dims(mask,
                                axis=tuple(range(mask.ndim, values.ndim)))
     mask = mask.astype(bool)
     utils.check_param(mask, dtype=bool, ndim=values.ndim)
     return cls(
         total=jnp.where(mask, values, jnp.zeros_like(values)).sum(),
         # Ones/zeros are built from `values` so `count` shares its dtype.
         count=jnp.where(mask, jnp.ones_like(values),
                         jnp.zeros_like(values)).sum(),
     )
Example #2
0
 def test_check_param(self):
     """`utils.check_param` rejects bad inputs and accepts matching ones."""
     scalar = jnp.array(0.)
     # (expected error-message regex, argument, keyword arguments)
     rejected = (
         (r"^Expected jnp.array", None, dict(ndim=1)),
         (r"^Expected ndim", scalar, dict(ndim=1)),
         (r"^Expected dtype", scalar, dict(ndim=0, dtype=jnp.int32)),
     )
     for pattern, arg, kwargs in rejected:
         with self.assertRaisesRegex(ValueError, pattern):
             utils.check_param(arg, **kwargs)
     # Matching ndim (with or without a matching dtype) must not raise.
     for kwargs in (dict(ndim=0), dict(ndim=0, dtype=jnp.float32)):
         utils.check_param(scalar, **kwargs)
Example #3
0
 def from_model_output(cls,
                       values: jnp.array,
                       mask: Optional[jnp.array] = None,
                       **_) -> Metric:
     """Creates an average metric from a 1-D batch of values.

     Args:
       values: Values to average. A 0-d array is treated as a batch of one;
         the result must be 1-D (checked by `utils.check_param`).
       mask: Optional mask with the same shape as `values`; non-zero entries
         are included. A 0-d mask is promoted like `values`. Defaults to
         including every entry.
       **_: Ignored; lets callers pass extra model outputs.

     Returns:
       An instance of `cls` with `total` (sum of the included values) and
       `count` (number of included values, with the dtype that summing
       `mask` produces).

     Raises:
       ValueError: If `values` is not 1-D, or if `mask` does not have the same
         shape as `values`.
     """
     if values.ndim == 0:
         values = values[None]
     utils.check_param(values, ndim=1)
     if mask is None:
         mask = jnp.ones(values.shape[0])
     mask = jnp.asarray(mask)
     if mask.ndim == 0:
         mask = mask[None]
     # Reject shape mismatches instead of letting `jnp.where` broadcast them,
     # which would silently yield a wrong total and count.
     if mask.shape != values.shape:
         raise ValueError(
             "Argument `mask` must have the same shape as `values`. "
             f"Received mask of shape {mask.shape} "
             f"and values of shape {values.shape}.")
     # Count entries the same way `total` selects them (mask truthiness), so a
     # non-binary mask cannot make `count` disagree with `total`. Building the
     # ones/zeros from `mask` keeps the dtype `mask.sum()` would have given.
     return cls(
         total=jnp.where(mask, values, jnp.zeros_like(values)).sum(),
         count=jnp.where(mask, jnp.ones_like(mask),
                         jnp.zeros_like(mask)).sum(),
     )