Ejemplo n.º 1
0
 def initialize(self, value):
     """Create the ``hidden`` and ``average`` states for ``value`` if uninitialized.

     Each state is requested with ``value``'s shape and dtype and with
     ``jnp.zeros`` as its initializer, so a freshly created state equals
     ``zeros_like(value)``. The states themselves are not returned.
     """
     # Both states share the same shape/dtype/initializer, so request them in
     # the same order the original code did: "hidden" first, then "average".
     for state_name in ("hidden", "average"):
         hooks.get_state(state_name,
                         value.shape,
                         value.dtype,
                         initializer=jnp.zeros)
Ejemplo n.º 2
0
    def call(self, logs):
        """Return the element-wise running mean of every ``logs`` seen so far.

        Two states are kept:
        - ``count``: how many times this method has been called.
        - ``total``: a pytree with the same structure as ``logs`` holding the
          leaf-wise sum of all inputs.

        Both states are updated on every call.

        Arguments:
            logs: A pytree of numeric leaves to average (presumably a dict of
                per-step metrics -- confirm against callers).

        Returns:
            A pytree with the same structure as ``logs`` where each leaf is
            ``total / count``.
        """
        count = hooks.get_state("count", initializer=0)
        # Initial total: a zero-filled pytree mirroring the structure of ``logs``.
        total = hooks.get_state(
            "total", initializer=jax.tree_util.tree_map(lambda _: 0.0, logs)
        )

        count += 1
        # ``jax.tree_multimap`` was removed from JAX (and ``jax.tree_map`` is
        # deprecated); ``jax.tree_util.tree_map`` maps over several trees with
        # identical structure leaf by leaf.
        total = jax.tree_util.tree_map(lambda acc, new: acc + new, total, logs)

        hooks.set_state("count", count)
        hooks.set_state("total", total)

        return jax.tree_util.tree_map(lambda leaf: leaf / count, total)
Ejemplo n.º 3
0
    def call(
        self, values: jnp.ndarray, sample_weight: tp.Optional[jnp.ndarray] = None
    ) -> jnp.ndarray:
        """
        Fold a batch of ``values`` into the running reduction and return it.

        A scalar ``total`` state is always kept. A scalar int32 ``count`` state
        is kept only for ``Reduction.SUM_OVER_BATCH_SIZE`` and
        ``Reduction.WEIGHTED_MEAN``; for other reductions ``None`` is passed to
        ``reduce`` as the count. The arithmetic is delegated to ``reduce``.

        For example, if `values` is [1, 3, 5, 7] and
        reduction=SUM_OVER_BATCH_SIZE, then the value of `result()` is 4. If
        `sample_weight` is [1, 1, 0, 0], the value would be 2.
        NOTE(review): the weighted figure corresponds to a weighted mean.
        Confirm how ``reduce`` treats SUM_OVER_BATCH_SIZE with weights.

        Arguments:
            values: Per-example value.
            sample_weight: Optional weighting of each example. Defaults to 1.

        Returns:
            Array with the cumulative reduction.
        """
        running_total = hooks.get_state(
            "total", shape=[], dtype=self.dtype, initializer=initializers.Constant(0)
        )

        tracks_count = self._reduction in (
            Reduction.SUM_OVER_BATCH_SIZE,
            Reduction.WEIGHTED_MEAN,
        )
        running_count = (
            hooks.get_state(
                "count",
                shape=[],
                dtype=jnp.int32,
                initializer=initializers.Constant(0),
            )
            if tracks_count
            else None
        )

        result, running_total, running_count = reduce(
            total=running_total,
            count=running_count,
            values=values,
            reduction=self._reduction,
            sample_weight=sample_weight,
            dtype=self.dtype,
        )

        hooks.set_state("total", running_total)
        # Only persist a count when ``reduce`` handed one back.
        if running_count is not None:
            hooks.set_state("count", running_count)

        return result
Ejemplo n.º 4
0
    def call(self, value, update_stats=True):
        """Feed ``value`` into the exponential moving average and return it.

        Args:
          value: Array-like input to average. Anything that is not already a
            ``jnp.ndarray`` is converted with ``jnp.asarray``.
          update_stats: A Boolean. When True, the updated ``counter``,
            ``hidden`` and ``average`` states are written back. When False, no
            ``hooks.set_state`` call is made, so the stored stats stay as they
            were.

        Returns:
          The exponentially weighted average after incorporating ``value``.
        """
        value = value if isinstance(value, jnp.ndarray) else jnp.asarray(value)
        dtype = value.dtype

        # The counter starts at -warmup_length, so it only becomes positive
        # after the warmup period has elapsed.
        step = (
            hooks.get_state(
                "counter",
                (),
                jnp.int32,
                initializer=initializers.Constant(-self._warmup_length),
            )
            + 1
        )

        decay = jax.lax.convert_element_type(self._decay, dtype)
        if self._warmup_length > 0:
            # Presumably selects a decay of 0.0 while step <= 0 (still warming
            # up), and the configured decay afterwards; see ``self._cond``.
            decay = self._cond(step <= 0, 0.0, decay, dtype)

        one = jnp.ones([], dtype)
        previous = hooks.get_state(
            "hidden", value.shape, dtype, initializer=jnp.zeros
        )
        hidden = previous * decay + value * (one - decay)

        # Zero-debiasing divides out the bias from starting ``hidden`` at zeros.
        average = (
            hidden / (one - jnp.power(decay, step)) if self._zero_debias else hidden
        )

        if update_stats:
            for state_name, state_value in (
                ("counter", step),
                ("hidden", hidden),
                ("average", average),
            ):
                hooks.set_state(state_name, state_value)

        return average
Ejemplo n.º 5
0
    def call(
        self,
        y_true: jnp.ndarray,
        y_pred: jnp.ndarray,
        sample_weight: tp.Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        """
        Fold a new batch into the running confusion-matrix statistic.

        Steps:
        1. Fetch the ``cm_metric`` state, a scalar int32 initialized to 0.
        2. Pass it to ``reduce`` together with the batch, the configured
           reduction and ``self.dtype``.
        3. Store whatever ``reduce`` returns as the new state, and return it.

        Arguments:
            y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.

            y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.

            sample_weight: Optional weighting of each example. Defaults to 1.

        Returns:
            Array with the cumulative reduced metric.
        """
        updated = reduce(
            cm_metric=hooks.get_state(
                "cm_metric",
                shape=[],
                dtype=jnp.int32,
                initializer=initializers.Constant(0),
            ),
            y_true=y_true,
            y_pred=y_pred,
            reduction=self._reduction,
            sample_weight=sample_weight,
            dtype=self.dtype,
        )

        hooks.set_state("cm_metric", updated)
        return updated