Example #1
File: reduce.py  Project: anvelezec/elegy
    def call(self,
             values: jnp.ndarray,
             sample_weight: tp.Optional[jnp.ndarray] = None) -> jnp.ndarray:
        """
        Accumulates statistics for computing the reduction metric. For example, if `values` is [1, 3, 5, 7]
        and reduction=SUM_OVER_BATCH_SIZE, then the value of `result()` is 4. If the `sample_weight`
        is specified as [1, 1, 0, 0] then the value of `result()` would be 2.

        Arguments:
            values: Per-example value.
            sample_weight: Optional weighting of each example. Defaults to 1.

        Returns:
            Array with the cumulative reduction.
        """
        total = self.add_parameter(
            "total",
            shape=[],
            dtype=self.dtype,
            initializer=initializers.Constant(0),
            trainable=False,
        )

        if self._reduction in (
                Reduction.SUM_OVER_BATCH_SIZE,
                Reduction.WEIGHTED_MEAN,
        ):
            count = self.add_parameter(
                "count",
                shape=[],
                dtype=jnp.int32,
                initializer=initializers.Constant(0),
                trainable=False,
            )
        else:
            count = None

        value, total, count = reduce(
            total=total,
            count=count,
            values=values,
            reduction=self._reduction,
            sample_weight=sample_weight,
            dtype=self.dtype,
        )

        self.update_parameter("total", total)

        if count is not None:
            self.update_parameter("count", count)

        return value
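A quick way to sanity-check the docstring's numbers is to compute both reductions by hand with `jax.numpy`. This is illustrative arithmetic only, not a call into elegy's `reduce` helper:

import jax.numpy as jnp

values = jnp.array([1.0, 3.0, 5.0, 7.0])

# SUM_OVER_BATCH_SIZE: total / batch size -> (1 + 3 + 5 + 7) / 4 = 4.0
sum_over_batch_size = values.sum() / values.size

# With sample_weight [1, 1, 0, 0] the weighted mean divides the weighted
# total by the sum of the weights -> (1*1 + 3*1) / (1 + 1) = 2.0
sample_weight = jnp.array([1.0, 1.0, 0.0, 0.0])
weighted_mean = (values * sample_weight).sum() / sample_weight.sum()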
Example #2
    def __init__(
        self,
        head_size: int,
        num_heads: int,
        output_size: tp.Optional[int] = None,
        dropout: float = 0.0,
        use_projection_bias: bool = True,
        return_attn_coef: bool = False,
        kernel_initializer: types.Initializer = initializers.VarianceScaling(
            scale=2.0),
        bias_initializer: types.Initializer = initializers.Constant(0.0),
        **kwargs):
        super().__init__(**kwargs)

        if output_size is not None and output_size < 1:
            raise ValueError("output_size must be a positive number")

        self.head_size = head_size
        self.num_heads = num_heads
        self.output_size = output_size
        self.use_projection_bias = use_projection_bias
        self.return_attn_coef = return_attn_coef
        self.dropout_rate = dropout

        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer
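Assuming this `__init__` belongs to a multi-head attention module (called `MultiHeadAttention` below purely for illustration), construction might look like the following sketch:

# `MultiHeadAttention` is a hypothetical name standing in for the module
# that owns the constructor above; the arguments match its signature.
attention = MultiHeadAttention(
    head_size=64,           # dimensionality of each attention head
    num_heads=8,            # number of parallel heads
    output_size=512,        # final projection size; must be >= 1 when given
    dropout=0.1,            # stored on the module as `dropout_rate`
    return_attn_coef=True,  # also return the attention coefficients
)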
Example #3
    def call(self, value, update_stats=True):
        """Updates the EMA and returns the new value.

        Args:
          value: The array-like object on which to perform the exponential
            decay.
          update_stats: A Boolean indicating whether to update the internal
            state of this object to reflect the input value. When
            `update_stats` is False, the internal stats remain unchanged.

        Returns:
          The exponentially weighted average of the input value.
        """
        if not isinstance(value, jnp.ndarray):
            value = jnp.asarray(value)

        counter = self.add_parameter(
            "counter",
            (),
            jnp.int32,
            initializer=initializers.Constant(-self._warmup_length),
            trainable=False,
        )
        counter += 1

        decay = jax.lax.convert_element_type(self._decay, value.dtype)
        if self._warmup_length > 0:
            # During warmup (counter <= 0) force the decay to 0 so the EMA
            # simply tracks the raw input.
            decay = self._cond(counter <= 0, 0.0, decay, value.dtype)

        one = jnp.ones([], value.dtype)
        hidden = self.add_parameter("hidden",
                                    value.shape,
                                    value.dtype,
                                    initializer=jnp.zeros,
                                    trainable=False)
        hidden = hidden * decay + value * (one - decay)

        average = hidden
        if self._zero_debias:
            average /= one - jnp.power(decay, counter)

        self.add_parameter("average", initializer=average, trainable=False)

        if update_stats:
            self.update_parameter("counter", counter)
            self.update_parameter("hidden", hidden)
            self.update_parameter("average", average)

        return average
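To see what `zero_debias` buys, here is a standalone sketch of the same update rule in plain Python (elegy's parameter bookkeeping omitted); with a constant input stream the debiased average recovers the true value from the first step:

import jax.numpy as jnp

decay = 0.9
values = jnp.array([1.0, 1.0, 1.0])  # a constant stream makes the bias visible

hidden = 0.0
for counter, value in enumerate(values, start=1):
    # Same update as above: new = old * decay + value * (1 - decay).
    hidden = hidden * decay + value * (1.0 - decay)
    # Without debiasing, early averages are pulled toward the zero
    # initialization of `hidden`; dividing by (1 - decay**counter) fixes it.
    average = hidden / (1.0 - decay**counter)
    print(counter, float(hidden), float(average))
# Step 1: hidden = 0.1 but average = 1.0, the true value.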
Example #4
    def call(
        self,
        y_true: jnp.ndarray,
        y_pred: jnp.ndarray,
        sample_weight: tp.Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        """
        Accumulates confusion matrix metrics for computing the reduction metric.

        Arguments:
            y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.

            y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.

            sample_weight: Optional weighting of each example. Defaults to 1.

        Returns:
            Array with the cumulative reduction metric.
        """

        cm_metric = self.add_parameter(
            "cm_metric",
            shape=[],
            dtype=jnp.int32,
            initializer=initializers.Constant(0),
            trainable=False,
        )

        cm_metric = reduce(
            cm_metric=cm_metric,
            y_true=y_true,
            y_pred=y_pred,
            reduction=self._reduction,
            sample_weight=sample_weight,
            dtype=self.dtype,
        )

        self.update_parameter("cm_metric", cm_metric)

        return cm_metric
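For intuition about what the `reduce` call accumulates, here is a hypothetical, simplified stand-in specialized to a running true-positive count for binary labels; elegy's actual helper dispatches on `reduction` and covers more cases:

import jax.numpy as jnp

def reduce_true_positives(cm_metric, y_true, y_pred, sample_weight=None):
    # Count examples where both label and prediction are 1.
    hits = ((y_true == 1) & (y_pred == 1)).astype(jnp.int32)
    if sample_weight is not None:
        hits = hits * sample_weight.astype(jnp.int32)
    return cm_metric + hits.sum()

total = jnp.asarray(0, dtype=jnp.int32)
total = reduce_true_positives(total, jnp.array([1, 0, 1]), jnp.array([1, 1, 1]))
total = reduce_true_positives(total, jnp.array([1, 1, 0]), jnp.array([1, 0, 0]))
# total == 3: two true positives from the first batch, one from the second.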