Пример #1
0
 def val_to_typecode(val):
     """Return the typecode string for the result dtype of ``val``.

     The typecode is looked up in ``dtype_to_typecode`` using
     ``dtypes.result_type(val)``. If ``val`` is weakly typed, the final
     character of that typecode is replaced with ``'*'``.
     """
     base_dtype = dtypes.result_type(val)
     is_weak = dtypes.is_weakly_typed(val)
     code = dtype_to_typecode[base_dtype]
     # Weak types are marked by swapping the last character for '*'.
     return code[:-1] + '*' if is_weak else code
Пример #2
0
def _todense_abstract_eval(*bufs, tree):
    """Return the abstract value (``core.ShapedArray``) for a todense operation.

    The pytree described by ``tree`` is rebuilt from ``bufs``. If the rebuilt
    object is already a ``core.ShapedArray``, it is returned unchanged.
    Otherwise a new ``ShapedArray`` is built from its ``shape`` and ``dtype``.
    Its weak type is taken from the object's ``data`` attribute (presumably
    the stored values of a non-dense array — confirm against callers).
    """
    rebuilt = tree_util.tree_unflatten(tree, bufs)
    if not isinstance(rebuilt, core.ShapedArray):
        out_shape, out_dtype = rebuilt.shape, rebuilt.dtype
        is_weak = dtypes.is_weakly_typed(rebuilt.data)
        return core.ShapedArray(out_shape, out_dtype, weak_type=is_weak)
    return rebuilt
Пример #3
0
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                  indices_are_sorted, unique_indices, mode, normalize_indices):
    """Apply an indexed update of ``x`` with values ``y`` via ``scatter_op``.

    Args:
      x: the operand being updated.
      y: the update values; broadcast to the indexed slice shape.
      scatter_op: callable invoked as
        ``scatter_op(operand, indices, updates, dnums, indices_are_sorted=...,
        unique_indices=..., mode=...)``.
      treedef, static_idx, dynamic_idx: index pieces recombined by
        ``jnp._merge_static_and_dynamic_indices``.
      indices_are_sorted, unique_indices: caller hints, OR-ed with the hints
        computed by the indexer.
      mode: forwarded unchanged to ``scatter_op``.
      normalize_indices: forwarded to ``jnp._index_to_gather``.

    Returns:
      The scatter result, cast back to ``x``'s original dtype and weak type.
      If the indexed slice shape is empty, ``x`` is returned unchanged.

    Warns:
      FutureWarning: when ``x``'s dtype differs from ``result_type(x, y)``.
    """
    # Capture x's dtype/weak type before promotion so the output matches x.
    out_dtype = lax.dtype(x)
    out_weak_type = dtypes.is_weakly_typed(x)

    if out_dtype != dtypes.result_type(x, y):
        # TODO(jakevdp): change this to an error after the deprecation period.
        warnings.warn(
            "scatter inputs have incompatible types: cannot safely cast "
            f"value from dtype={lax.dtype(y)} to dtype={lax.dtype(x)}. "
            "In future JAX releases this will result in an error.",
            FutureWarning)

    full_index = jnp._merge_static_and_dynamic_indices(
        treedef, static_idx, dynamic_idx)
    gather_plan = jnp._index_to_gather(
        jnp.shape(x), full_index, normalize_indices=normalize_indices)

    # An empty slice shape means nothing to update. This is both a fast path
    # and a guard for cases like zeros(0)[array([], int32)].
    if core.is_empty_shape(gather_plan.slice_shape):
        return x

    x, y = jnp._promote_dtypes(x, y)

    # Shape the updates to match the slice: broadcast, drop newaxis dims,
    # then undo any reversed dimensions.
    updates = jnp.broadcast_to(y, tuple(gather_plan.slice_shape))
    updates = jnp.squeeze(updates, axis=gather_plan.newaxis_dims)
    if gather_plan.reversed_y_dims:
        updates = lax.rev(updates, gather_plan.reversed_y_dims)

    # Gather dimension numbers transposed into scatter dimension numbers
    # (cf. lax._gather_transpose_rule).
    gather_dnums = gather_plan.dnums
    scatter_dnums = lax.ScatterDimensionNumbers(
        update_window_dims=gather_dnums.offset_dims,
        inserted_window_dims=gather_dnums.collapsed_slice_dims,
        scatter_dims_to_operand_dims=gather_dnums.start_index_map)

    result = scatter_op(
        x,
        gather_plan.gather_indices,
        updates,
        scatter_dnums,
        indices_are_sorted=gather_plan.indices_are_sorted or indices_are_sorted,
        unique_indices=gather_plan.unique_indices or unique_indices,
        mode=mode)
    return lax_internal._convert_element_type(result, out_dtype, out_weak_type)
Пример #4
0
    def testBinaryNonPromotion(self, dtype, weak_type, promotion):
        """Check the dtype and weak type of ``x + x`` under a promotion mode.

        Regression test for https://github.com/google/jax/issues/6051.
        Under 'standard' promotion, for strongly typed inputs, or for booleans,
        the sum keeps ``dtype``. Otherwise a weakly typed input is expected to
        widen to the default complex/float/int dtype for its kind. Booleans
        never carry a weak type.
        """
        x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
        with jax.numpy_dtype_promotion(promotion):
            y = (x + x)

        # Compare against dtypes.bool_ everywhere. The builtin `bool` compares
        # equal to np.dtype('bool') but not to the np.bool_ scalar type, so
        # mixing the two spellings can disagree for some `dtype` values.
        is_bool = dtype == dtypes.bool_

        if promotion == 'standard' or not weak_type or is_bool:
            expected_dtype = dtype
        elif dtypes.issubdtype(dtype, np.complexfloating):
            expected_dtype = dtypes.complex_
        elif dtypes.issubdtype(dtype, np.floating):
            expected_dtype = dtypes.float_
        else:
            expected_dtype = dtypes.int_

        # No boolean weak types.
        expected_weak_type = weak_type and not is_bool
        expected_dtype = dtypes.canonicalize_dtype(expected_dtype)

        self.assertEqual(y.dtype, expected_dtype)
        self.assertEqual(dtypes.is_weakly_typed(y), expected_weak_type)
Пример #5
0
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                  indices_are_sorted, unique_indices, mode, normalize_indices):
    """Scatter ``y`` into ``x`` at the given index using ``scatter_op``.

    The index is rebuilt from ``treedef``/``static_idx``/``dynamic_idx`` and
    converted into a gather description by ``jnp._index_to_gather``. That
    description is transposed into scatter dimension numbers. ``x`` and ``y``
    are dtype-promoted before the scatter. The result is cast back to ``x``'s
    original dtype and weak type. If the indexed slice shape is empty, ``x``
    is returned unchanged. ``mode`` is passed through to ``scatter_op``.
    The sortedness/uniqueness hints are OR-ed with those found by the indexer.
    """
    # Remember how x was typed on entry; the result is converted back to it.
    target_dtype = lax.dtype(x)
    target_weak = dtypes.is_weakly_typed(x)

    merged_idx = jnp._merge_static_and_dynamic_indices(
        treedef, static_idx, dynamic_idx)
    plan = jnp._index_to_gather(
        jnp.shape(x), merged_idx, normalize_indices=normalize_indices)

    # Skipping the scatter for an empty slice shape is a fast path, and it
    # handles cases like zeros(0)[array([], int32)].
    if core.is_empty_shape(plan.slice_shape):
        return x

    x, y = jnp._promote_dtypes(x, y)

    # Broadcast y to the slice shape, collapse None/newaxis dims, and
    # un-reverse any reversed dimensions.
    vals = jnp.broadcast_to(y, tuple(plan.slice_shape))
    vals = jnp.squeeze(vals, axis=plan.newaxis_dims)
    if plan.reversed_y_dims:
        vals = lax.rev(vals, plan.reversed_y_dims)

    # Map gather dimension numbers onto scatter ones
    # (cf. lax._gather_transpose_rule).
    g = plan.dnums
    dims = lax.ScatterDimensionNumbers(
        update_window_dims=g.offset_dims,
        inserted_window_dims=g.collapsed_slice_dims,
        scatter_dims_to_operand_dims=g.start_index_map)

    scattered = scatter_op(
        x, plan.gather_indices, vals, dims,
        indices_are_sorted=plan.indices_are_sorted or indices_are_sorted,
        unique_indices=plan.unique_indices or unique_indices,
        mode=mode)
    return lax._convert_element_type(scattered, target_dtype, target_weak)
Пример #6
0
 def testBinaryNonPromotion(self, dtype, weak_type):
     """``x + x`` must keep ``x``'s dtype and weak type.

     Regression test for https://github.com/google/jax/issues/6051.
     """
     x = lax._convert_element_type(0, dtype, weak_type=weak_type)
     y = (x + x)
     # Use TestCase assertions rather than bare `assert`. Bare asserts are
     # stripped under `python -O` and report no operand values on failure.
     self.assertEqual(x.dtype, y.dtype)
     self.assertEqual(dtypes.is_weakly_typed(y), dtypes.is_weakly_typed(x))
Пример #7
0
 def test_gmres_weak_types(self):
     """gmres on the identity operator with a Python float RHS gives a weakly typed solution."""
     solution, _info = jax.scipy.sparse.linalg.gmres(lambda v: v, 1.0)
     self.assertTrue(dtypes.is_weakly_typed(solution))