def _cross_replica_scatter_add(source: Array, indices: Array, updates: Array,
                               axis_name):
    """Adds scattered `updates` to `source`, optionally summing over replicas.

    A stateless JAX counterpart of `tf.scatter_add`: every update is
    accumulated into the slot of `source` selected by its index, and the
    per-slot totals may additionally be summed over named axes with
    `jax.lax.psum`.

    Args:
      source: A rank-1 array of shape [O].
      indices: An integer array giving the slot of `source` each update
        belongs to.
      updates: The values to accumulate; must have the same shape as
        `indices`.
      axis_name: A single axis name (str) or an iterable of axis names to
        `psum` over. If None, no cross-replica aggregation is performed.

    Returns:
      An array of shape [O] equal to `source` plus the accumulated updates
      (summed over the requested axes, if any).
    """
    assert updates.shape == indices.shape
    assert jnp.issubdtype(indices.dtype, jnp.integer)
    assert source.ndim == 1

    # Work on 1-D views so that every update pairs with exactly one index.
    flat_indices = jnp.reshape(indices, [-1])
    flat_updates = jnp.reshape(updates, [-1])
    slot_count = source.shape[0]

    # Each update is spread into a length-`slot_count` row via the one-hot
    # encoding of its index; summing the rows gives the local per-slot total.
    per_update_rows = flat_updates[..., None] * base.one_hot(
        flat_indices, slot_count)
    slot_totals = jnp.sum(per_update_rows, axis=0)

    if axis_name is None:
        return source + slot_totals

    # A bare string is one axis; anything else is treated as an iterable.
    reduce_axes = (axis_name,) if isinstance(axis_name, str) else axis_name
    for reduce_axis in reduce_axes:
        slot_totals = jax.lax.psum(slot_totals, axis_name=reduce_axis)
    return source + slot_totals
def transform_to_2hot(scalar: Array, min_value: float, max_value: float,
                      num_bins: int) -> Array:
    """Transforms a scalar tensor to a 2 hot representation.

    `scalar` is clipped to [min_value, max_value] and mapped onto `num_bins`
    evenly spaced support points spanning that range. The result puts weight
    on the two support points bracketing the clipped value, with the weights
    summing to one.

    Args:
      scalar: The values to encode.
      min_value: Value represented by the first bin.
      max_value: Value represented by the last bin.
      num_bins: Number of bins in the encoding.

    Returns:
      An array with one extra trailing axis of size `num_bins` holding the
      two-hot weights.
    """
    value_span = max_value - min_value
    clipped = jnp.clip(scalar, min_value, max_value)

    # Fractional bin position in [0, num_bins - 1] and its two neighbours.
    bin_position = (clipped - min_value) / value_span * (num_bins - 1)
    lower_bin = jnp.floor(bin_position)
    upper_bin = jnp.ceil(bin_position)

    # Values represented by the neighbouring bins.
    lower_support = (lower_bin / (num_bins - 1.0)) * value_span + min_value
    upper_support = (upper_bin / (num_bins - 1.0)) * value_span + min_value

    # The 1e-5 keeps the division finite when the value sits exactly on a bin
    # (floor == ceil, so the two support values coincide).
    weight_lower = (upper_support - clipped) / (
        upper_support - lower_support + 1e-5)
    weight_upper = 1 - weight_lower

    lower_part = base.one_hot(lower_bin, num_bins) * jnp.expand_dims(
        weight_lower, -1)
    upper_part = base.one_hot(upper_bin, num_bins) * jnp.expand_dims(
        weight_upper, -1)
    return lower_part + upper_part
def test_one_hot(self):
    """Checks `base.one_hot` on a [1, 2, 3] batch of float indices.

    With three classes, the index 3 equals the class count; the expected
    encoding for that entry is an all-zero row.
    """
    class_count = 3
    raw_indices = jnp.array([[
        [1., 2., 3.],
        [1., 2., 2.],
    ]])
    want = jnp.array([[
        [[0., 1., 0.], [0., 0., 1.], [0., 0., 0.]],
        [[0., 1., 0.], [0., 0., 1.], [0., 0., 1.]],
    ]])

    got = base.one_hot(raw_indices, class_count)

    np.testing.assert_array_almost_equal(got, want)