def val_to_typecode(val):
  """Return the typecode string describing ``val``.

  The base code is looked up in ``dtype_to_typecode`` using
  ``dtypes.result_type(val)``. For weakly typed values the final character
  of that code is replaced by ``'*'``.
  """
  result_dtype = dtypes.result_type(val)
  is_weak = dtypes.is_weakly_typed(val)
  code = dtype_to_typecode[result_dtype]
  # Weak types are marked by swapping the trailing character for '*'.
  return code[:-1] + '*' if is_weak else code
def _todense_abstract_eval(*bufs, tree):
  """Return the abstract value for the operand rebuilt from ``bufs``.

  ``bufs`` are unflattened with ``tree``. If the rebuilt object is already a
  ``core.ShapedArray`` it is returned as-is. Otherwise a new ``ShapedArray``
  is built from the object's ``shape`` and ``dtype``. It is weakly typed
  exactly when the object's ``data`` attribute is weakly typed.
  """
  operand = tree_util.tree_unflatten(tree, bufs)
  if not isinstance(operand, core.ShapedArray):
    data_is_weak = dtypes.is_weakly_typed(operand.data)
    return core.ShapedArray(operand.shape, operand.dtype,
                            weak_type=data_is_weak)
  return operand
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                  indices_are_sorted, unique_indices, mode, normalize_indices):
  """Apply ``scatter_op`` to ``x`` at the given index with updates ``y``.

  The index is reassembled from ``treedef``/``static_idx``/``dynamic_idx``
  and lowered to gather form via ``jnp._index_to_gather``. The gather
  dimension numbers are then transposed into scatter dimension numbers.

  Args:
    x: operand being updated. Its dtype and weak type are restored on output.
    y: update values. They are promoted with ``x``, broadcast to the indexed
      slice shape, and squeezed/reversed to match the scatter layout.
    scatter_op: callable invoked as ``scatter_op(x, indices, y, dnums,
      indices_are_sorted=..., unique_indices=..., mode=...)``.
    treedef, static_idx, dynamic_idx: pieces of the original index, merged by
      ``jnp._merge_static_and_dynamic_indices``.
    indices_are_sorted, unique_indices: caller-supplied hints. Each is OR-ed
      with the corresponding property the indexer reports.
    mode: forwarded unchanged to ``scatter_op``.
    normalize_indices: forwarded to ``jnp._index_to_gather``.

  Returns:
    ``x`` unchanged when the indexed slice shape is empty. Otherwise the
    scatter result cast back to ``x``'s original dtype and weak type.

  Warns:
    FutureWarning: when ``dtypes.result_type(x, y)`` differs from ``x``'s
      dtype, i.e. ``y`` cannot be safely cast to ``x``'s dtype.
  """
  # Remember x's dtype/weak type so the output can be cast back at the end,
  # after the promotion below may have widened x.
  dtype = lax.dtype(x)
  weak_type = dtypes.is_weakly_typed(x)

  if dtype != dtypes.result_type(x, y):
    # TODO(jakevdp): change this to an error after the deprecation period.
    warnings.warn(
      "scatter inputs have incompatible types: cannot safely cast "
      f"value from dtype={lax.dtype(y)} to dtype={lax.dtype(x)}. "
      "In future JAX releases this will result in an error.",
      FutureWarning)

  idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
  indexer = jnp._index_to_gather(jnp.shape(x), idx,
                                 normalize_indices=normalize_indices)

  # Avoid calling scatter if the slice shape is empty, both as a fast path and
  # to handle cases like zeros(0)[array([], int32)].
  if core.is_empty_shape(indexer.slice_shape):
    return x

  x, y = jnp._promote_dtypes(x, y)

  # Broadcast `y` to the slice output shape.
  y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
  # Collapse any `None`/`jnp.newaxis` dimensions.
  y = jnp.squeeze(y, axis=indexer.newaxis_dims)
  if indexer.reversed_y_dims:
    y = lax.rev(y, indexer.reversed_y_dims)

  # Transpose the gather dimensions into scatter dimensions (cf.
  # lax._gather_transpose_rule)
  dnums = lax.ScatterDimensionNumbers(
    update_window_dims=indexer.dnums.offset_dims,
    inserted_window_dims=indexer.dnums.collapsed_slice_dims,
    scatter_dims_to_operand_dims=indexer.dnums.start_index_map
  )

  out = scatter_op(
    x, indexer.gather_indices, y, dnums,
    indices_are_sorted=indexer.indices_are_sorted or indices_are_sorted,
    unique_indices=indexer.unique_indices or unique_indices,
    mode=mode)
  # Undo the promotion: the result keeps x's original dtype and weak type.
  return lax_internal._convert_element_type(out, dtype, weak_type)
def testBinaryNonPromotion(self, dtype, weak_type, promotion): # Regression test for https://github.com/google/jax/issues/6051 x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type) with jax.numpy_dtype_promotion(promotion): y = (x + x) if promotion == 'standard' or not weak_type or dtype == dtypes.bool_: expected_dtype = dtype elif dtypes.issubdtype(dtype, np.complexfloating): expected_dtype = dtypes.complex_ elif dtypes.issubdtype(dtype, np.floating): expected_dtype = dtypes.float_ else: expected_dtype = dtypes.int_ # No boolean weak types. expected_weak_type = weak_type and dtype != bool expected_dtype = dtypes.canonicalize_dtype(expected_dtype) self.assertEqual(y.dtype, expected_dtype) self.assertEqual(dtypes.is_weakly_typed(y), expected_weak_type)
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                  indices_are_sorted, unique_indices, mode, normalize_indices):
  """Apply ``scatter_op`` to ``x`` at the given index with updates ``y``.

  The index is reassembled from ``treedef``/``static_idx``/``dynamic_idx``
  and lowered to gather form via ``jnp._index_to_gather``. The gather
  dimension numbers are then transposed into scatter dimension numbers.

  Args:
    x: operand being updated. Its dtype and weak type are restored on output.
    y: update values. They are promoted with ``x``, broadcast to the indexed
      slice shape, and squeezed/reversed to match the scatter layout.
    scatter_op: callable invoked as ``scatter_op(x, indices, y, dnums,
      indices_are_sorted=..., unique_indices=..., mode=...)``.
    treedef, static_idx, dynamic_idx: pieces of the original index, merged by
      ``jnp._merge_static_and_dynamic_indices``.
    indices_are_sorted, unique_indices: caller-supplied hints. Each is OR-ed
      with the corresponding property the indexer reports.
    mode: forwarded unchanged to ``scatter_op``.
    normalize_indices: forwarded to ``jnp._index_to_gather``.

  Returns:
    ``x`` unchanged when the indexed slice shape is empty. Otherwise the
    scatter result cast back to ``x``'s original dtype and weak type.
  """
  # Remember x's dtype/weak type so the output can be cast back at the end,
  # after the promotion below may have widened x.
  dtype = lax.dtype(x)
  weak_type = dtypes.is_weakly_typed(x)

  idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
  indexer = jnp._index_to_gather(jnp.shape(x), idx,
                                 normalize_indices=normalize_indices)

  # Avoid calling scatter if the slice shape is empty, both as a fast path and
  # to handle cases like zeros(0)[array([], int32)].
  if core.is_empty_shape(indexer.slice_shape):
    return x

  x, y = jnp._promote_dtypes(x, y)

  # Broadcast `y` to the slice output shape.
  y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
  # Collapse any `None`/`jnp.newaxis` dimensions.
  y = jnp.squeeze(y, axis=indexer.newaxis_dims)
  if indexer.reversed_y_dims:
    y = lax.rev(y, indexer.reversed_y_dims)

  # Transpose the gather dimensions into scatter dimensions (cf.
  # lax._gather_transpose_rule)
  dnums = lax.ScatterDimensionNumbers(
    update_window_dims=indexer.dnums.offset_dims,
    inserted_window_dims=indexer.dnums.collapsed_slice_dims,
    scatter_dims_to_operand_dims=indexer.dnums.start_index_map
  )

  out = scatter_op(
    x, indexer.gather_indices, y, dnums,
    indices_are_sorted=indexer.indices_are_sorted or indices_are_sorted,
    unique_indices=indexer.unique_indices or unique_indices,
    mode=mode)
  # Undo the promotion: the result keeps x's original dtype and weak type.
  return lax._convert_element_type(out, dtype, weak_type)
def testBinaryNonPromotion(self, dtype, weak_type): # Regression test for https://github.com/google/jax/issues/6051 x = lax._convert_element_type(0, dtype, weak_type=weak_type) y = (x + x) assert x.dtype == y.dtype assert dtypes.is_weakly_typed(y) == dtypes.is_weakly_typed(x)
def test_gmres_weak_types(self): x, _ = jax.scipy.sparse.linalg.gmres(lambda x: x, 1.0) self.assertTrue(dtypes.is_weakly_typed(x))