def initialize(self, value):
    """Create the "hidden" and "average" states as ``zeros_like`` the given value.

    Both states share ``value``'s shape and dtype; ``get_state`` only
    initializes on first use, so calling this again is a no-op.
    """
    for state_name in ("hidden", "average"):
        hooks.get_state(state_name, value.shape, value.dtype, initializer=jnp.zeros)
def call(self, logs):
    """Accumulate ``logs`` into a running sum and return the running mean.

    Arguments:
        logs: A pytree of numeric values to be averaged across calls.

    Returns:
        A pytree with the same structure as ``logs`` where each leaf is the
        cumulative total of that leaf divided by the number of calls so far.
    """
    count = hooks.get_state("count", initializer=0)
    total = hooks.get_state(
        "total", initializer=jax.tree_map(lambda x: 0.0, logs)
    )

    count += 1
    # Fix: `jax.tree_multimap` is deprecated and removed in recent JAX
    # releases; `jax.tree_map` accepts multiple trees and is the supported,
    # behavior-identical spelling (also consistent with the rest of this file).
    total = jax.tree_map(lambda a, b: a + b, total, logs)

    hooks.set_state("count", count)
    hooks.set_state("total", total)

    return jax.tree_map(lambda leaf_total: leaf_total / count, total)
def call(
    self, values: jnp.ndarray, sample_weight: tp.Optional[jnp.ndarray] = None
) -> jnp.ndarray:
    """
    Accumulates statistics for computing the reduction metric.

    For example, if `values` is [1, 3, 5, 7] and reduction=SUM_OVER_BATCH_SIZE,
    then the value of `result()` is 4. If the `sample_weight` is specified as
    [1, 1, 0, 0] then value of `result()` would be 2.

    Arguments:
        values: Per-example value.
        sample_weight: Optional weighting of each example. Defaults to 1.

    Returns:
        Array with the cumulative reduce.
    """
    total = hooks.get_state(
        "total", shape=[], dtype=self.dtype, initializer=initializers.Constant(0)
    )

    # Mean-style reductions additionally track how many samples were seen.
    tracks_count = self._reduction in (
        Reduction.SUM_OVER_BATCH_SIZE,
        Reduction.WEIGHTED_MEAN,
    )
    count = (
        hooks.get_state(
            "count",
            shape=[],
            dtype=jnp.int32,
            initializer=initializers.Constant(0),
        )
        if tracks_count
        else None
    )

    value, total, count = reduce(
        total=total,
        count=count,
        values=values,
        reduction=self._reduction,
        sample_weight=sample_weight,
        dtype=self.dtype,
    )

    hooks.set_state("total", total)
    if count is not None:
        hooks.set_state("count", count)

    return value
def call(self, value, update_stats=True):
    """Updates the EMA and returns the new value.

    Args:
        value: The array-like object for which you would like to perform an
            exponential decay on.
        update_stats: A Boolean, whether to update the internal state of this
            object to reflect the input value. When `update_stats` is False
            the internal stats will remain unchanged.

    Returns:
        The exponentially weighted average of the input value.
    """
    if not isinstance(value, jnp.ndarray):
        value = jnp.asarray(value)

    # Counter starts negative so the warmup period ends exactly when it
    # crosses zero.
    counter = (
        hooks.get_state(
            "counter",
            (),
            jnp.int32,
            initializer=initializers.Constant(-self._warmup_length),
        )
        + 1
    )

    decay = jax.lax.convert_element_type(self._decay, value.dtype)
    if self._warmup_length > 0:
        # While still warming up (counter <= 0), force decay to 0 so the
        # hidden state simply tracks the raw input.
        decay = self._cond(counter <= 0, 0.0, decay, value.dtype)

    one = jnp.ones([], value.dtype)
    hidden = hooks.get_state("hidden", value.shape, value.dtype, initializer=jnp.zeros)
    hidden = decay * hidden + (one - decay) * value

    # Zero-debiasing rescales early estimates that are biased toward zero.
    average = (
        hidden / (one - jnp.power(decay, counter)) if self._zero_debias else hidden
    )

    if update_stats:
        hooks.set_state("counter", counter)
        hooks.set_state("hidden", hidden)
        hooks.set_state("average", average)

    return average
def call(
    self,
    y_true: jnp.ndarray,
    y_pred: jnp.ndarray,
    sample_weight: tp.Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
    """
    Accumulates confusion matrix metrics for computing the reduction metric.

    Arguments:
        y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
        sample_weight: Optional weighting of each example. Defaults to 1.

    Returns:
        Array with the cumulative reduce metric.
    """
    previous = hooks.get_state(
        "cm_metric",
        shape=[],
        dtype=jnp.int32,
        initializer=initializers.Constant(0),
    )

    updated = reduce(
        cm_metric=previous,
        y_true=y_true,
        y_pred=y_pred,
        reduction=self._reduction,
        sample_weight=sample_weight,
        dtype=self.dtype,
    )

    hooks.set_state("cm_metric", updated)
    return updated