# Example 1
# 0
def test_from_jax():
    """Check that ak.from_jax keeps every element of a 1-D and a 2-D JAX array.

    Each converted element is compared for equality against the element at the
    same position in the source JAX array.
    """
    source_1d = jax.numpy.arange(10)
    source_2d = jax.numpy.array(
        [[1.1, 2.2], [3.3, 4.4], [5.5, 6.6], [7.7, 8.8]]
    )

    converted_1d = ak.from_jax(source_1d)
    converted_2d = ak.from_jax(source_2d)

    # All 10 positions of the 1-D conversion.
    for idx in range(10):
        assert converted_1d[idx] == source_1d[idx]

    # All 4 x 2 positions of the 2-D conversion, row by row.
    for row in range(4):
        for col in range(2):
            assert converted_2d[row][col] == source_2d[row][col]
# Example 2
# 0
    def recurse(array, indices=None):
        """Segment-sum an awkward layout via ``jax.vjp``, recursing into lists.

        Args:
            array: an awkward layout node. Only ``ak.layout.NumpyArray`` and
                the node types in ``ak._util.listtypes`` are handled.
            indices: segment ids given to ``jax.ops.segment_sum`` when
                ``array`` is a ``NumpyArray``. ``None`` (the default) means an
                all-zero int32 array with the length of ``array``. Ignored for
                list nodes, which derive their own ids from ``array.offsets``.

        Returns:
            For a ``NumpyArray``: the ``(value, func)`` pair returned by
            ``jax.vjp``. For a list node: the array rebuilt by
            ``special_unflatten``. Any other node type falls through both
            branches and ``None`` is returned.
        """
        if isinstance(array, ak.layout.NumpyArray):

            def segment_sum_wrapper(arr, indices=None):
                # The default is computed from ``arr`` at call time. A
                # signature default is evaluated when this function is
                # defined, where ``arr`` is not a local name.
                if indices is None:
                    indices = np.zeros(len(arr), dtype=np.int32)
                return jax.ops.segment_sum(arr, indices)

            # Computed here rather than as a signature default. A default
            # would be evaluated once, when ``recurse`` is defined, against
            # whatever ``array`` the enclosing scope binds, not the argument.
            if indices is None:
                indices = np.zeros(len(array), dtype=np.int32)

            # Both the data and the segment ids are passed as primals.
            value, func = jax.vjp(segment_sum_wrapper, np.asarray(array),
                                  indices)
            return value, func

        elif isinstance(array, ak._util.listtypes):
            # Tag each content element with the index of the list holding it.
            # List k (offsets[k] .. offsets[k + 1]) contributes
            # offsets[k + 1] - offsets[k] copies of k. np.repeat does this in
            # one pass, and its output keeps an integer dtype even when empty.
            offsets = np.asarray(array.offsets)
            segment_sum_indices = np.repeat(np.arange(len(offsets) - 1),
                                            np.diff(offsets))

            # NOTE(review): if ``array.content`` is itself a list node, the
            # recursive call returns a rebuilt array, not a ``(value, func)``
            # pair. This unpacking therefore presumably only supports one
            # level of nesting. TODO confirm.
            value, func = recurse(array.content, segment_sum_indices)
            _, aux_data = ak._connect._jax.jax_utils.special_flatten(
                ak.Array(array))
            # Call the vjp function on ``value`` and keep element 0 of its
            # output tuple, which belongs to the first primal (the data array).
            children = [ak.from_jax(func(value)[0])]
            return ak._connect._jax.jax_utils.special_unflatten(
                aux_data, children)
# Example 3
# 0
def test_from_jax_tolist():
    """Check that ak.to_list on the layout of ak.from_jax output gives the source values."""
    expected = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]

    converted = ak.from_jax(jax.numpy.array(expected))

    assert ak.to_list(converted.layout) == expected