def test_between(self, value, min_value, max_value, left_strict, right_strict):
  """Checks that `asserts.between` accepts `value` for the given bounds.

  The test passes as long as the call below does not raise.
  """
  label = f'value={value}'
  asserts.between(
      value,
      min_value,
      max_value,
      value_str=label,
      left_strict=left_strict,
      right_strict=right_strict,
  )
def compute_moments(
    inputs: JTensor,
    padding: JTensor,
    reduce_over_dims: List[int],
    cumulative_axis: Optional[int] = None,
    enable_cross_replica_sum_on_tpu: bool = False,
    keepdims: bool = False,
) -> Tuple[JTensor, JTensor]:
  """Computes mean and variance over the valid data points in inputs.

  Every element of `inputs` is weighted by `1.0 - padding` in all sums.

  Args:
    inputs: The inputs JTensor.
    padding: The paddings JTensor. Must have the same rank as `inputs` and is
      multiplied elementwise with it, so it must broadcast against `inputs`.
    reduce_over_dims: A sequence of ints for dimensions to reduce `inputs`
      over. Each must lie in `[0, inputs.ndim)`.
    cumulative_axis: An optional int for axis to compute a cumulative sum. If
      None, no cumulative sum is applied. The axis indexes the reduced
      statistics: it refers to the input axes when `keepdims=True` and to the
      remaining (non-reduced) axes when `keepdims=False`.
    enable_cross_replica_sum_on_tpu: A boolean indicating whether to use an
      all-reduce sum (`jax.lax.psum`) over the 'batch' axis.
    keepdims: A boolean indicating whether the reduced axes are left in the
      results as dimensions with size one.

  Returns:
    Tuple of (mean, variance). The variance is the weighted population
    (biased) variance, i.e. divided by the weighted count. Counts are clamped
    to at least 1.0, so statistics with no valid data points come out as 0
    instead of NaN.
  """
  asserts.eq(inputs.ndim, padding.ndim)
  rank = inputs.ndim
  for dim in reduce_over_dims:
    asserts.between(dim, 0, rank, left_strict=False, right_strict=True)
  mask = 1.0 - padding
  sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=keepdims)
  # Multiplying by ones_like(inputs) broadcasts the mask to the input shape,
  # so the count matches the number of (weighted) elements summed above.
  count_v = jnp.sum(
      jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=keepdims)
  if cumulative_axis is not None:
    sum_v = jnp.cumsum(sum_v, axis=cumulative_axis)
    count_v = jnp.cumsum(count_v, axis=cumulative_axis)

  if enable_cross_replica_sum_on_tpu:
    # TODO(shafey, yonghui): Fetch axis_name from globals.
    sum_v = jax.lax.psum(sum_v, axis_name='batch')
    count_v = jax.lax.psum(count_v, axis_name='batch')

  # Avoid division by zero where there are no valid data points.
  count_v = jnp.maximum(count_v, 1.0)
  mean = sum_v / count_v

  # With keepdims=False the reduced axes are missing from `mean`, and plain
  # broadcasting against `inputs` only lines up when the reduced axes are the
  # leading ones. Re-insert them as size-1 axes so the centering is correct
  # for any `reduce_over_dims`.
  if keepdims:
    mean_for_centering = mean
  else:
    mean_for_centering = jnp.expand_dims(mean, reduce_over_dims)
  centered = inputs - mean_for_centering
  sum_vv = jnp.sum(
      centered * centered * mask, axis=reduce_over_dims, keepdims=keepdims)
  if cumulative_axis is not None:
    sum_vv = jnp.cumsum(sum_vv, axis=cumulative_axis)

  if enable_cross_replica_sum_on_tpu:
    # TODO(shafey, yonghui): Fetch axis_name from globals.
    sum_vv = jax.lax.psum(sum_vv, axis_name='batch')

  variance = sum_vv / count_v
  return mean, variance
def test_between_raises(self, value, min_value, max_value, left_strict,
                        right_strict):
  """Checks that `asserts.between` raises ValueError for `value`.

  The error message is expected to name the value through the explicit
  `value_str` label passed below.
  """
  pattern = f'`value={value}` must be in the range `.*`.$'
  with self.assertRaisesRegex(ValueError, pattern):
    asserts.between(
        value,
        min_value,
        max_value,
        value_str=f'value={value}',
        left_strict=left_strict,
        right_strict=right_strict,
    )
def test_multiline_invalid_label(self):
  """Calls `asserts.between` without `value_str`, split across several lines.

  NOTE(review): judging by the test name, this covers the case where the
  label for the value cannot be derived from a multi-line call site. Confirm
  against the `asserts` implementation, and keep the call multi-line so the
  test keeps exercising that path.
  """
  value = 1
  min_value = 0
  max_value = 2
  left_strict = True
  right_strict = True
  # The test expects this call not to raise: 1 lies strictly between 0 and 2.
  asserts.between(
      value,
      min_value,
      max_value,
      left_strict=left_strict,
      right_strict=right_strict)
def test_multiline_invalid_label_raises(self):
  """Checks the error for an out-of-range value when no `value_str` is given.

  NOTE(review): the test name suggests the value label is derived from the
  caller's source and is not usable for a multi-line call. That is presumably
  why the regex below leaves the label part unconstrained. Confirm against
  the `asserts` implementation and keep the call multi-line.
  """
  value = 10
  min_value = 0
  max_value = 2
  left_strict = True
  right_strict = True
  # Only `<anything>=10` and the "must be in the range" wording are pinned;
  # the label text before '=' is matched by `.*`.
  with self.assertRaisesRegex(
      ValueError, f'`.*={value}` must be in the range `.*`.$'):
    asserts.between(
        value,
        min_value,
        max_value,
        left_strict=left_strict,
        right_strict=right_strict)