def test_from_jax():
    """Convert 1-D and 2-D JAX arrays with ``ak.from_jax`` and compare element-wise."""
    source_1d = jax.numpy.arange(10)
    source_2d = jax.numpy.array(
        [[1.1, 2.2], [3.3, 4.4], [5.5, 6.6], [7.7, 8.8]]
    )

    converted_1d = ak.from_jax(source_1d)
    converted_2d = ak.from_jax(source_2d)

    # Every scalar of the 1-D input must survive the conversion unchanged.
    for idx in range(10):
        assert converted_1d[idx] == source_1d[idx]

    # Walk the 4x2 input row by row, column by column.
    for row in range(4):
        for col in range(2):
            assert converted_2d[row][col] == source_2d[row][col]
def recurse(array, indices=None):
    """Replace each leaf value of ``array`` with the sum of its list segment.

    The computation goes through ``jax.vjp`` of ``jax.ops.segment_sum``. The
    forward pass sums leaf values per segment id. Pulling that result back
    through the VJP function (``vjp_fun(value)[0]``) yields an array shaped
    like the leaves, where each element holds the sum of its own segment.

    Parameters
    ----------
    array : ak.layout.NumpyArray or one of ``ak._util.listtypes``
        Layout node to process.
    indices : array-like of int, optional
        Segment id for every element of a ``NumpyArray``. Defaults to all
        zeros, i.e. a single segment. Ignored for list layouts, which derive
        segment ids from their own ``offsets``.

    Returns
    -------
    tuple or ak.Array
        For a ``NumpyArray``: the ``(value, vjp_fun)`` pair from ``jax.vjp``.
        For a list layout: an array rebuilt via ``special_unflatten`` whose
        single child is ``ak.from_jax(vjp_fun(value)[0])``.

    Raises
    ------
    TypeError
        If ``array`` is neither a ``NumpyArray`` nor a list layout.
    """
    if isinstance(array, ak.layout.NumpyArray):
        if indices is None:
            # Computed here rather than in the signature: a default expression
            # is evaluated once at definition time, when ``array`` is unbound.
            indices = np.zeros(len(array), dtype=np.int32)

        def segment_sum_wrapper(arr, segment_ids):
            # Two primal inputs, so jax.vjp's function returns a pair of
            # cotangents: (w.r.t. arr, w.r.t. segment_ids).
            return jax.ops.segment_sum(arr, segment_ids)

        value, func = jax.vjp(segment_sum_wrapper, np.asarray(array), indices)
        return value, func

    if isinstance(array, ak._util.listtypes):
        offsets = np.asarray(array.offsets)
        # Segment id i repeated once per element of list i. This is vectorized
        # (no quadratic list concatenation) and stays an integer dtype even
        # when there are no lists at all.
        # NOTE(review): assumes the layout exposes ``offsets``; ListArray-style
        # layouts with starts/stops may not — confirm against listtypes.
        segment_ids = np.repeat(
            np.arange(len(offsets) - 1, dtype=np.int32), np.diff(offsets)
        )
        # NOTE(review): the unpacking below only works when ``array.content``
        # is a NumpyArray; a nested list content makes ``recurse`` return an
        # ak.Array rather than a (value, func) pair.
        value, func = recurse(array.content, segment_ids)
        _, aux_data = ak._connect._jax.jax_utils.special_flatten(ak.Array(array))
        children = [ak.from_jax(func(value)[0])]
        return ak._connect._jax.jax_utils.special_unflatten(aux_data, children)

    raise TypeError(
        f"recurse: unsupported layout type {type(array).__name__}"
    )
def test_from_jax_tolist():
    """A descending 1-D JAX array round-trips to the same Python list."""
    expected = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
    converted = ak.from_jax(jax.numpy.array(expected))
    assert ak.to_list(converted.layout) == expected