def call(
    self, values: jnp.ndarray, sample_weight: tp.Optional[jnp.ndarray] = None
) -> jnp.ndarray:
    """
    Fold a batch of values into the running reduction state.

    The running state lives in the non-trainable parameter `total` and, for
    the `SUM_OVER_BATCH_SIZE` and `WEIGHTED_MEAN` reductions only, in an
    additional `count` parameter. The actual arithmetic is delegated to
    `reduce`; this method only reads the state, hands it over, and writes the
    new state back.

    Example carried over from the earlier documentation (it depends on
    `reduce`, which is not defined in this method): if `values` is
    [1, 3, 5, 7] and reduction=SUM_OVER_BATCH_SIZE, then the value of
    `result()` is 4. If the `sample_weight` is specified as [1, 1, 0, 0] then
    value of `result()` would be 2.

    Arguments:
        values: Per-example value.
        sample_weight: Optional weighting of each example. Defaults to 1.

    Returns:
        Array with the cumulative reduce.
    """
    running_total = self.add_parameter(
        "total",
        shape=[],
        dtype=self.dtype,
        initializer=initializers.Constant(0),
        trainable=False,
    )

    # Only the averaging reductions track a denominator.
    averaging_reductions = (
        Reduction.SUM_OVER_BATCH_SIZE,
        Reduction.WEIGHTED_MEAN,
    )
    running_count = (
        self.add_parameter(
            "count",
            shape=[],
            dtype=jnp.int32,
            initializer=initializers.Constant(0),
            trainable=False,
        )
        if self._reduction in averaging_reductions
        else None
    )

    result, new_total, new_count = reduce(
        total=running_total,
        count=running_count,
        values=values,
        reduction=self._reduction,
        sample_weight=sample_weight,
        dtype=self.dtype,
    )

    # Persist the updated state; `count` is written only when `reduce`
    # returned one.
    self.update_parameter("total", new_total)
    if new_count is not None:
        self.update_parameter("count", new_count)

    return result
def __init__(
    self,
    head_size: int,
    num_heads: int,
    output_size: tp.Optional[int] = None,
    dropout: float = 0.0,
    use_projection_bias: bool = True,
    return_attn_coef: bool = False,
    kernel_initializer: types.Initializer = initializers.VarianceScaling(
        scale=2.0),
    bias_initializer: types.Initializer = initializers.Constant(0.0),
    # NOTE: kernel/bias regularizers and constraints are not accepted as
    # arguments here.
    **kwargs):
    """
    Store the configuration of the attention layer.

    Arguments:
        head_size: Stored as `self.head_size`.
        num_heads: Stored as `self.num_heads`.
        output_size: Optional output size; when given it must be >= 1.
        dropout: Dropout rate, stored as `self.droput_rate`.
        use_projection_bias: Stored as `self.use_projection_bias`.
        return_attn_coef: Stored as `self.return_attn_coef`.
        kernel_initializer: Stored as `self.kernel_initializer`.
        bias_initializer: Stored as `self.bias_initializer`.
        **kwargs: Forwarded unchanged to the parent constructor.

    Raises:
        ValueError: If `output_size` is provided and is smaller than 1.
    """
    super().__init__(**kwargs)

    # `None` means "no explicit output size" and is accepted as-is.
    if output_size is not None:
        if output_size < 1:
            raise ValueError("output_size must be a positive number")

    # Shape-related configuration.
    self.head_size = head_size
    self.num_heads = num_heads
    self.output_size = output_size

    # Behaviour flags.
    self.use_projection_bias = use_projection_bias
    self.return_attn_coef = return_attn_coef

    # NOTE(review): the attribute name is misspelled ("droput"); it is kept
    # as-is because code outside this method may read it under this name.
    self.droput_rate = dropout

    # Weight initializers.
    self.kernel_initializer = kernel_initializer
    self.bias_initializer = bias_initializer
def call(self, value, update_stats=True):
    """Folds `value` into the exponential moving average and returns it.

    Args:
      value: The array-like object to average. Anything that is not already
        a `jnp.ndarray` is converted with `jnp.asarray`.
      update_stats: A Boolean. When True, the new `counter`, `hidden` and
        `average` values are written back to this object's parameters. When
        False, no parameter is updated and only the computed average is
        returned.

    Returns:
      The exponentially weighted average of the input value (zero-debiased
      when `self._zero_debias` is set).
    """
    value = value if isinstance(value, jnp.ndarray) else jnp.asarray(value)

    # The step counter is initialised to -warmup_length, so it is still
    # non-positive after the increment for the first `warmup_length` calls.
    step = self.add_parameter(
        "counter",
        (),
        jnp.int32,
        initializer=initializers.Constant(-self._warmup_length),
        trainable=False,
    )
    step += 1

    decay = jax.lax.convert_element_type(self._decay, value.dtype)
    if self._warmup_length > 0:
        # Presumably `_cond` yields 0.0 while `step <= 0` and `decay`
        # otherwise -- confirm against its definition.
        decay = self._cond(step <= 0, 0.0, decay, value.dtype)

    one = jnp.ones([], value.dtype)
    previous_hidden = self.add_parameter(
        "hidden",
        value.shape,
        value.dtype,
        initializer=jnp.zeros,
        trainable=False,
    )
    keep_fraction = one - decay
    hidden = previous_hidden * decay + value * keep_fraction

    if self._zero_debias:
        # Divide by (1 - decay**step) to undo the bias of the zero start.
        average = hidden / (one - jnp.power(decay, step))
    else:
        average = hidden

    # Register "average" using the freshly computed value as initializer;
    # the returned parameter is not needed here.
    self.add_parameter("average", initializer=average, trainable=False)

    if update_stats:
        new_state = (
            ("counter", step),
            ("hidden", hidden),
            ("average", average),
        )
        for name, new_value in new_state:
            self.update_parameter(name, new_value)

    return average
def call(
    self,
    y_true: jnp.ndarray,
    y_pred: jnp.ndarray,
    sample_weight: tp.Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
    """
    Fold a batch of predictions into the running confusion-matrix metric.

    The running value lives in the non-trainable scalar `int32` parameter
    `cm_metric`. It is handed to `reduce` together with the batch, and the
    value `reduce` returns is both stored back and returned.

    Arguments:
        y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
        y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
        sample_weight: Optional weighting of each example. Defaults to 1.

    Returns:
        Array with the cumulative reduce metric.
    """
    accumulated = self.add_parameter(
        "cm_metric",
        shape=[],
        dtype=jnp.int32,
        initializer=initializers.Constant(0),
        trainable=False,
    )

    updated = reduce(
        cm_metric=accumulated,
        y_true=y_true,
        y_pred=y_pred,
        reduction=self._reduction,
        sample_weight=sample_weight,
        dtype=self.dtype,
    )

    # Persist the new running value before handing it back to the caller.
    self.update_parameter("cm_metric", updated)
    return updated