def from_model_output(cls, values: jnp.array, mask: Optional[jnp.array] = None, **_) -> Metric:
    """Build the metric from model output values and an optional mask.

    Args:
        values: Array of values to aggregate. A 0-d (scalar) input is
            promoted to shape ``(1,)`` before processing.
        mask: Optional array selecting which entries of ``values`` count.
            Its leading dimension must equal that of ``values``. If it has
            fewer dimensions than ``values``, trailing singleton axes are
            appended so it broadcasts against ``values``. It is cast to
            bool before use. Defaults to all ones (nothing masked out).
        **_: Extra keyword arguments are accepted and ignored.

    Returns:
        An instance of ``cls`` whose ``total`` is the sum of the selected
        values and whose ``count`` is the number of selected entries.

    Raises:
        ValueError: If the leading dimensions of ``mask`` and ``values``
            differ.
    """
    if values.ndim == 0:
        values = values[None]
    if mask is None:
        mask = jnp.ones_like(values)
    # Leading dimensions of mask and values must match.
    if mask.shape[0] != values.shape[0]:
        raise ValueError(
            f"Argument `mask` must have the same leading dimension as `values`. "
            f"Received mask of dimension {mask.shape} "
            f"and values of dimension {values.shape}.")
    # Broadcast mask to the same number of dimensions as values.
    if mask.ndim < values.ndim:
        # The axes must be static Python ints. A builtin `range` keeps them
        # concrete, whereas `jnp.arange` would produce JAX array elements
        # that cannot be used as axis indices while tracing.
        mask = jnp.expand_dims(
            mask, axis=tuple(range(mask.ndim, values.ndim)))
    mask = mask.astype(bool)
    utils.check_param(mask, dtype=bool, ndim=values.ndim)
    return cls(
        total=jnp.where(mask, values, jnp.zeros_like(values)).sum(),
        count=jnp.where(mask, jnp.ones_like(values),
                        jnp.zeros_like(values)).sum(),
    )
def test_check_param(self):
    """Exercise `utils.check_param` on rejected and accepted inputs."""
    scalar = jnp.array(0.)

    # (expected error pattern, value to check, keyword arguments)
    rejected = (
        (r"^Expected jnp.array", None, dict(ndim=1)),
        (r"^Expected ndim", scalar, dict(ndim=1)),
        (r"^Expected dtype", scalar, dict(ndim=0, dtype=jnp.int32)),
    )
    for pattern, value, kwargs in rejected:
        with self.assertRaisesRegex(ValueError, pattern):
            utils.check_param(value, **kwargs)

    # These calls must pass without raising.
    utils.check_param(scalar, ndim=0)
    utils.check_param(scalar, ndim=0, dtype=jnp.float32)
def from_model_output(cls, values: jnp.array, mask: Optional[jnp.array] = None, **_) -> Metric:
    """Create the metric from 1-D values and an optional mask.

    Args:
        values: Values to aggregate. A 0-d input is promoted to shape
            ``(1,)``; the result is then checked with
            ``utils.check_param(values, ndim=1)``.
        mask: Optional mask applied to ``values``. When omitted, an
            all-ones array of length ``values.shape[0]`` is used.
        **_: Extra keyword arguments are accepted and ignored.

    Returns:
        An instance of ``cls`` with ``total`` set to the sum of the values
        where ``mask`` is truthy, and ``count`` set to ``mask.sum()``.
    """
    values = values[None] if values.ndim == 0 else values
    utils.check_param(values, ndim=1)

    if mask is None:
        mask = jnp.ones(values.shape[0])

    # Zero out masked-off entries so they do not contribute to the sum.
    selected = jnp.where(mask, values, jnp.zeros_like(values))
    return cls(total=selected.sum(), count=mask.sum())