コード例 #1
0
 def test_between(self, value, min_value, max_value, left_strict,
                  right_strict):
     """Asserts that `asserts.between` accepts the given value and bounds.

     The test passes as long as the call below does not raise.
     NOTE(review): the arguments are presumably injected by a parameterized
     decorator outside this view -- confirm.
     """
     label = f'value={value}'
     strictness = dict(left_strict=left_strict, right_strict=right_strict)
     asserts.between(value, min_value, max_value, value_str=label, **strictness)
コード例 #2
0
def compute_moments(
    inputs: JTensor,
    padding: JTensor,
    reduce_over_dims: List[int],
    cumulative_axis: Optional[int] = None,
    enable_cross_replica_sum_on_tpu: bool = False,
    keepdims: bool = False,
) -> Tuple[JTensor, JTensor]:
    """Computes mean and variance over the valid data points in inputs.

    Every element is weighted by `mask = 1.0 - padding`, so positions whose
    padding is 1.0 contribute nothing to either statistic.

    Args:
      inputs: The inputs JTensor.
      padding: The paddings JTensor; must have the same rank as `inputs`.
      reduce_over_dims: A sequence of ints for dimensions to reduce `inputs`
        over. Each must lie in `[0, inputs.ndim)`.
      cumulative_axis: An optional int for the axis of the reduced result
        along which to compute a cumulative sum. If None, no cumulative sum is
        applied.
      enable_cross_replica_sum_on_tpu: A boolean indicating whether to use an
        all-reduce sum over the 'batch' axis.
      keepdims: A boolean indicating whether the reduced axes should be left
        in the result as dimensions with size one.

    Returns:
      Tuple of (mean, variance). Both have the shape of `inputs` reduced over
      `reduce_over_dims` (with size-one axes kept iff `keepdims`).
    """
    asserts.eq(inputs.ndim, padding.ndim)
    rank = inputs.ndim
    # Materialize once so the validation loop and the reductions below see
    # the same axes.
    axes = tuple(reduce_over_dims)
    for dim in axes:
        asserts.between(dim, 0, rank, left_strict=False, right_strict=True)

    mask = 1.0 - padding
    sum_v = jnp.sum(inputs * mask, axis=axes, keepdims=keepdims)
    # Broadcasting through ones_like counts every element of `inputs` even
    # when `padding` has smaller (broadcastable) dimensions.
    count_v = jnp.sum(jnp.ones_like(inputs) * mask,
                      axis=axes,
                      keepdims=keepdims)
    if cumulative_axis is not None:
        sum_v = jnp.cumsum(sum_v, axis=cumulative_axis)
        count_v = jnp.cumsum(count_v, axis=cumulative_axis)

    if enable_cross_replica_sum_on_tpu:
        # TODO(shafey, yonghui): Fetch axis_name from globals.
        sum_v = jax.lax.psum(sum_v, axis_name='batch')
        count_v = jax.lax.psum(count_v, axis_name='batch')

    # Clamp so slices with no valid positions divide by 1 instead of 0.
    count_v = jnp.maximum(count_v, 1.0)
    mean = sum_v / count_v

    # `mean` must broadcast against `inputs` element-by-element. Without
    # keepdims the reduced axes are gone, and plain broadcasting would align
    # `mean` with the trailing axes of `inputs`. That is wrong (or a shape
    # error) unless `reduce_over_dims` is a leading prefix. Re-insert the
    # reduced axes as size-one dimensions for the centering step only.
    mean_b = mean if keepdims else jnp.expand_dims(mean, axes)
    centered = inputs - mean_b
    sum_vv = jnp.sum(centered * centered * mask,
                     axis=axes,
                     keepdims=keepdims)
    if cumulative_axis is not None:
        sum_vv = jnp.cumsum(sum_vv, axis=cumulative_axis)

    if enable_cross_replica_sum_on_tpu:
        # TODO(shafey, yonghui): Fetch axis_name from globals.
        sum_vv = jax.lax.psum(sum_vv, axis_name='batch')

    variance = sum_vv / count_v
    return mean, variance
コード例 #3
0
 def test_between_raises(self, value, min_value, max_value, left_strict,
                         right_strict):
     """Asserts that `asserts.between` raises ValueError for these arguments.

     The expected message must quote the explicit `value_str` label verbatim;
     the range text inside the second pair of backticks is matched loosely.
     """
     import re  # Local import: only needed to escape the interpolated label.

     value_str = f'value={value}'
     # Escape the label before embedding it in the regex. Reprs such as
     # '1e+20' or '[1, 2]' contain metacharacters that would otherwise make
     # the pattern miss the literal message (or match it too loosely).
     pattern = f'`{re.escape(value_str)}` must be in the range `.*`.$'
     with self.assertRaisesRegex(ValueError, pattern):
         asserts.between(value,
                         min_value,
                         max_value,
                         left_strict=left_strict,
                         right_strict=right_strict,
                         value_str=value_str)
コード例 #4
0
 def test_multiline_invalid_label(self):
     """Checks asserts.between accepts 1 in (0, 2) when no value_str is given.

     The test passes as long as the call below does not raise.
     """
     value = 1
     min_value = 0
     max_value = 2
     left_strict = True
     right_strict = True
     # NOTE(review): the test name suggests asserts.between tries to derive a
     # label for `value` from the caller's source when `value_str` is omitted,
     # and that a call spread over several lines yields no usable label. Keep
     # this multi-line layout so that path stays exercised -- confirm against
     # the asserts.between implementation.
     asserts.between(value,
                     min_value,
                     max_value,
                     left_strict=left_strict,
                     right_strict=right_strict)
コード例 #5
0
 def test_multiline_invalid_label_raises(self):
     """Checks that 10 is rejected for the strict range (0, 2) without value_str.

     The expected ValueError message pattern pins only the `=10` part of the
     label and accepts any text (`.*`) before it, as well as any range text.
     """
     value = 10
     min_value = 0
     max_value = 2
     left_strict = True
     right_strict = True
     # NOTE(review): the test name suggests the label cannot be derived from
     # a multi-line call when `value_str` is omitted, hence the `.*` label in
     # the pattern. Keep the call's multi-line layout so that path stays
     # exercised -- confirm against the asserts.between implementation.
     with self.assertRaisesRegex(
             ValueError, f'`.*={value}` must be in the range `.*`.$'):
         asserts.between(value,
                         min_value,
                         max_value,
                         left_strict=left_strict,
                         right_strict=right_strict)