Exemple #1
0
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                  indices_are_sorted, unique_indices, normalize_indices):
    dtype = lax.dtype(x)
    x, y = jnp._promote_dtypes(x, y)

    idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx,
                                                dynamic_idx)
    indexer = jnp._index_to_gather(jnp.shape(x),
                                   idx,
                                   normalize_indices=normalize_indices)

    # Broadcast `y` to the slice output shape.
    y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
    # Collapse any `None`/`jnp.newaxis` dimensions.
    y = jnp.squeeze(y, axis=indexer.newaxis_dims)
    if indexer.reversed_y_dims:
        y = lax.rev(y, indexer.reversed_y_dims)

    # Transpose the gather dimensions into scatter dimensions (cf.
    # lax._gather_transpose_rule)
    dnums = lax.ScatterDimensionNumbers(
        update_window_dims=indexer.dnums.offset_dims,
        inserted_window_dims=indexer.dnums.collapsed_slice_dims,
        scatter_dims_to_operand_dims=indexer.dnums.start_index_map)
    out = scatter_op(x,
                     indexer.gather_indices,
                     y,
                     dnums,
                     indices_are_sorted=indices_are_sorted,
                     unique_indices=unique_indices)
    return lax.convert_element_type(out, dtype)
Exemple #2
0
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                  indices_are_sorted, unique_indices, mode, normalize_indices):
    dtype = lax.dtype(x)
    weak_type = dtypes.is_weakly_typed(x)

    if dtype != dtypes.result_type(x, y):
        # TODO(jakevdp): change this to an error after the deprecation period.
        warnings.warn(
            "scatter inputs have incompatible types: cannot safely cast "
            f"value from dtype={lax.dtype(y)} to dtype={lax.dtype(x)}. "
            "In future JAX releases this will result in an error.",
            FutureWarning)

    idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx,
                                                dynamic_idx)
    indexer = jnp._index_to_gather(jnp.shape(x),
                                   idx,
                                   normalize_indices=normalize_indices)

    # Avoid calling scatter if the slice shape is empty, both as a fast path and
    # to handle cases like zeros(0)[array([], int32)].
    if core.is_empty_shape(indexer.slice_shape):
        return x

    x, y = jnp._promote_dtypes(x, y)

    # Broadcast `y` to the slice output shape.
    y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
    # Collapse any `None`/`jnp.newaxis` dimensions.
    y = jnp.squeeze(y, axis=indexer.newaxis_dims)
    if indexer.reversed_y_dims:
        y = lax.rev(y, indexer.reversed_y_dims)

    # Transpose the gather dimensions into scatter dimensions (cf.
    # lax._gather_transpose_rule)
    dnums = lax.ScatterDimensionNumbers(
        update_window_dims=indexer.dnums.offset_dims,
        inserted_window_dims=indexer.dnums.collapsed_slice_dims,
        scatter_dims_to_operand_dims=indexer.dnums.start_index_map)
    out = scatter_op(x,
                     indexer.gather_indices,
                     y,
                     dnums,
                     indices_are_sorted=indexer.indices_are_sorted
                     or indices_are_sorted,
                     unique_indices=indexer.unique_indices or unique_indices,
                     mode=mode)
    return lax_internal._convert_element_type(out, dtype, weak_type)
Exemple #3
0
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                  indices_are_sorted, unique_indices, mode, normalize_indices):
    dtype = lax.dtype(x)
    weak_type = dtypes.is_weakly_typed(x)

    idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx,
                                                dynamic_idx)
    indexer = jnp._index_to_gather(jnp.shape(x),
                                   idx,
                                   normalize_indices=normalize_indices)

    # Avoid calling scatter if the slice shape is empty, both as a fast path and
    # to handle cases like zeros(0)[array([], int32)].
    if core.is_empty_shape(indexer.slice_shape):
        return x

    x, y = jnp._promote_dtypes(x, y)

    # Broadcast `y` to the slice output shape.
    y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
    # Collapse any `None`/`jnp.newaxis` dimensions.
    y = jnp.squeeze(y, axis=indexer.newaxis_dims)
    if indexer.reversed_y_dims:
        y = lax.rev(y, indexer.reversed_y_dims)

    # Transpose the gather dimensions into scatter dimensions (cf.
    # lax._gather_transpose_rule)
    dnums = lax.ScatterDimensionNumbers(
        update_window_dims=indexer.dnums.offset_dims,
        inserted_window_dims=indexer.dnums.collapsed_slice_dims,
        scatter_dims_to_operand_dims=indexer.dnums.start_index_map)
    out = scatter_op(x,
                     indexer.gather_indices,
                     y,
                     dnums,
                     indices_are_sorted=indexer.indices_are_sorted
                     or indices_are_sorted,
                     unique_indices=indexer.unique_indices or unique_indices,
                     mode=mode)
    return lax._convert_element_type(out, dtype, weak_type)