Example #1
0
def _cross_replica_scatter_add(source: Array, indices: Array, updates: Array,
                               axis_name):
    """tf.scatter_add, but with JAX, cross replica, and without state.

  Args:
    source: An array of shape [O].
    indices: An array indicating which index each update is for.
    updates: The updates to apply to `source`. Of same shape as indices.
    axis_name: What axis to aggregate over, if str. If passed an iterable,
      aggregates over multiple axes. Defaults to no aggregation, i.e. None.

  Returns:
    An array of shape [O], which is source + the scattered updates from all
    replicas.
  """
    assert updates.shape == indices.shape
    assert jnp.issubdtype(indices.dtype, jnp.integer)
    assert source.ndim == 1
    # Flatten indices, updates.
    num_classes = source.shape[0]
    indices = jnp.reshape(indices, [-1])
    updates = jnp.reshape(updates, [-1])
    # Scatter updates according to value of indices.
    updates_at_idxs = updates[..., None] * base.one_hot(indices, num_classes)
    # Aggregate locally first, then across replicas.
    total_updates = jnp.sum(updates_at_idxs, axis=0)
    if axis_name is not None:
        axis_names = (axis_name, ) if isinstance(axis_name, str) else axis_name
        for a_name in axis_names:
            total_updates = jax.lax.psum(total_updates, axis_name=a_name)
    return source + total_updates
Example #2
0
def transform_to_2hot(scalar: Array, min_value: float, max_value: float,
                      num_bins: int) -> Array:
    """Transforms a scalar tensor to a 2 hot representation.

    The support is `num_bins` evenly spaced points from `min_value` to
    `max_value`, inclusive. Each scalar is first clipped to that range. It is
    then split between its two neighbouring support points with linear
    interpolation weights, so the weighted average of the support points
    equals the clipped scalar (up to floating-point rounding). A scalar lying
    exactly on a support point puts all of its mass there.

    Args:
      scalar: An array of scalars.
      min_value: Value of the first support point; smaller scalars are clipped.
      max_value: Value of the last support point; larger scalars are clipped.
        Must differ from `min_value` (it is used as a divisor).
      num_bins: Number of support points.

    Returns:
      An array of shape `scalar.shape + (num_bins,)` whose last axis sums to 1.
    """
    scalar = jnp.clip(scalar, min_value, max_value)
    # Continuous position of each scalar on the support, in [0, num_bins - 1].
    scalar_bin = (scalar - min_value) / (max_value - min_value) * (num_bins - 1)
    lower = jnp.floor(scalar_bin)
    upper = jnp.ceil(scalar_bin)
    # The weight of the upper neighbour is the fractional part of the position.
    # Computing it directly is exact. The alternative, dividing by the bin
    # width plus an epsilon, biases the weights, badly so for narrow bins.
    # When lower == upper the fraction is 0 and all mass lands on that bin.
    p_upper = scalar_bin - lower
    p_lower = 1 - p_upper
    lower_one_hot = base.one_hot(lower, num_bins) * jnp.expand_dims(p_lower, -1)
    upper_one_hot = base.one_hot(upper, num_bins) * jnp.expand_dims(p_upper, -1)
    return lower_one_hot + upper_one_hot
Example #3
0
 def test_one_hot(self):
   num_classes = 3
   indices = jnp.array(
       [[[1., 2., 3.], [1., 2., 2.]]])
   expected_result = jnp.array([
       [[[0., 1., 0.], [0., 0., 1.], [0., 0., 0.]],
        [[0., 1., 0.], [0., 0., 1.], [0., 0., 1.]]]])
   result = base.one_hot(indices, num_classes)
   np.testing.assert_array_almost_equal(result, expected_result)