Example #1
def pos2map(pidx, pgz, prior_result, neib_shape, neib_step):
    """
    Helper function that adds the gradient contribution from a single
    neighborhood position (i, j).

    pidx = index of the position within the neighborhood.
    pgz  = gradient of shape (batch_size * num_channels * neibs).
    prior_result = shape (batch_size, num_channels, rows, cols).
    neib_shape = number of rows, cols in a neighborhood.
    neib_step  = step sizes from images2neibs.
    """
    nrows, ncols = neib_shape
    rstep, cstep = neib_step
    batch_size, num_channels, rows, cols = prior_result.shape
    i = pidx // ncols
    j = pidx - (i * ncols)
    # In valid mode, this position does not touch some image pixels.
    result_indices = prior_result[
        :,
        :,
        i : (rows - nrows + i + 1) : rstep,
        j : (cols - ncols + j + 1) : cstep,
    ]
    newshape = (
        (batch_size, num_channels)
        + ((rows - nrows) // rstep + 1,)
        + ((cols - ncols) // cstep + 1,)
    )
    return inc_subtensor(result_indices, pgz.reshape(newshape))
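For reference, inc_subtensor returns a symbolic copy of its first argument with the indexed region incremented, which is how the helper above accumulates each neighborhood position's gradient into the image. A minimal sketch of that primitive on a toy vector (assuming a standard Aesara install; the variable names are mine):

import aesara
import aesara.tensor as at

# Increment every other element of a zero vector through a strided slice,
# mirroring the strided view built by pos2map above.
x = at.zeros((6,))
y = at.inc_subtensor(x[1:6:2], at.as_tensor_variable([1.0, 2.0, 3.0]))

f = aesara.function([], y)
print(f())  # expected: [0. 1. 0. 2. 0. 3.]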
Example #2
def test_incsub_offset():
    # Test for https://github.com/Theano/Theano/issues/5670

    # Build a GPU variable whose value will have an offset (x1)
    x = gpuarray_shared_constructor(np.zeros(5, dtype=aesara.config.floatX))
    x1 = x[1:]
    # Use inc_subtensor on it
    y = vector()
    z = inc_subtensor(x1[2:], y)
    # Use updates so that inc_subtensor can happen inplace
    f = aesara.function([y], z, updates={x: z}, mode=mode_with_gpu)
    utt.assert_allclose(f([1, 2]), np.array([0, 0, 1, 2], dtype=aesara.config.floatX))
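The regression here hinges on x1 = x[1:] being a view whose data pointer is offset into x's buffer; the in-place GPU IncSubtensor previously mishandled that offset (issue #5670). A NumPy analogue of the expected semantics, for illustration only:

import numpy as np

x = np.zeros(5)
x1 = x[1:]            # a view starting one element into x's buffer
x1[2:] += [1.0, 2.0]  # in-place increment through the offset view
print(x1)             # [0. 0. 1. 2.]
print(x)              # [0. 0. 0. 1. 2.]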
Example #3
def test_incsub_f16():
    shp = (3, 3)
    shared = gpuarray_shared_constructor
    xval = np.arange(np.prod(shp), dtype="float16").reshape(shp) + 1
    yval = np.empty((2,) + shp[1:], dtype="float16")
    yval[:] = 2
    x = shared(xval, name="x")
    y = tensor(dtype="float16", broadcastable=(False,) * len(shp), name="y")
    expr = advanced_inc_subtensor1(x, y, [0, 2])
    f = aesara.function([y], expr, mode=mode_with_gpu)
    assert (
        sum(
            [
                isinstance(node.op, GpuAdvancedIncSubtensor1)
                for node in f.maker.fgraph.toposort()
            ]
        )
        == 1
    )
    rval = f(yval)
    rep = xval.copy()
    np.add.at(rep, [0, 2], yval)
    assert np.allclose(rval, rep)

    expr = inc_subtensor(x[1:], y)
    f = aesara.function([y], expr, mode=mode_with_gpu)
    assert (
        sum(
            [isinstance(node.op, GpuIncSubtensor) for node in f.maker.fgraph.toposort()]
        )
        == 1
    )
    rval = f(yval)
    rep = xval.copy()
    rep[1:] += yval
    assert np.allclose(rval, rep)
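np.add.at serves as the CPU reference because, unlike fancy-indexed +=, it performs unbuffered accumulation and therefore matches AdvancedIncSubtensor1 semantics when an index repeats. A small illustration with toy values of my own:

import numpy as np

rep = np.zeros(3)
rep[[0, 0]] += 1.0           # buffered: the repeated index is applied once
print(rep)                   # [1. 0. 0.]

np.add.at(rep, [0, 0], 1.0)  # unbuffered: both increments accumulate
print(rep)                   # [3. 0. 0.]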
Example #4
def test_jax_IncSubtensor():
    rng = np.random.default_rng(213234)

    x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
    x_aet = aet.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)

    # "Set" basic indices
    st_aet = aet.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
    out_aet = aet_subtensor.set_subtensor(x_aet[1, 2, 3], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.IncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    st_aet = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
    out_aet = aet_subtensor.set_subtensor(x_aet[:2, 0, 0], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.IncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    out_aet = aet_subtensor.set_subtensor(x_aet[0, 1:3, 0], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.IncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    # "Set" advanced indices
    st_aet = aet.as_tensor_variable(
        rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX))
    out_aet = aet_subtensor.set_subtensor(x_aet[np.r_[0, 2]], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor1)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    st_aet = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
    out_aet = aet_subtensor.set_subtensor(x_aet[[0, 2], 0, 0], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    st_aet = aet.as_tensor_variable(x_np[[0, 2], 0, :3])
    out_aet = aet_subtensor.set_subtensor(x_aet[[0, 2], 0, :3], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    # "Set" boolean indices
    mask_aet = aet.as_tensor_variable(x_np) > 0
    out_aet = aet_subtensor.set_subtensor(x_aet[mask_aet], 0.0)
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    # "Increment" basic indices
    st_aet = aet.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
    out_aet = aet_subtensor.inc_subtensor(x_aet[1, 2, 3], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.IncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    st_aet = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
    out_aet = aet_subtensor.inc_subtensor(x_aet[:2, 0, 0], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.IncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    out_aet = aet_subtensor.inc_subtensor(x_aet[0, 1:3, 0], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.IncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    # "Increment" advanced indices
    st_aet = aet.as_tensor_variable(
        rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX))
    out_aet = aet_subtensor.inc_subtensor(x_aet[np.r_[0, 2]], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor1)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    st_aet = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
    out_aet = aet_subtensor.inc_subtensor(x_aet[[0, 2], 0, 0], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    st_aet = aet.as_tensor_variable(x_np[[0, 2], 0, :3])
    out_aet = aet_subtensor.inc_subtensor(x_aet[[0, 2], 0, :3], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    # "Increment" boolean indices
    mask_aet = aet.as_tensor_variable(x_np) > 0
    out_aet = aet_subtensor.inc_subtensor(x_aet[mask_aet], 1.0)
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])
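Since JAX arrays are immutable, the JAX backend must lower set_subtensor/inc_subtensor to JAX's functional index-update syntax. A sketch of the corresponding jax.numpy idioms (not the transpiler's actual code):

import jax.numpy as jnp

x = jnp.arange(5.0)
x_set = x.at[1:3].set(-1.0)  # functional analogue of set_subtensor
x_inc = x.at[1:3].add(10.0)  # functional analogue of inc_subtensor
print(x_set)  # [ 0. -1. -1.  3.  4.]
print(x_inc)  # [ 0. 11. 12.  3.  4.]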
Example #5
def test_jax_basic():
    rng = np.random.default_rng(28494)

    x = matrix("x")
    y = matrix("y")
    b = vector("b")

    # `ScalarOp`
    z = cosh(x**2 + y / 3.0)

    # `[Inc]Subtensor`
    out = aet_subtensor.set_subtensor(z[0], -10.0)
    out = aet_subtensor.inc_subtensor(out[0, 1], 2.0)
    out = out[:5, :3]

    out_fg = FunctionGraph([x, y], [out])

    test_input_vals = [
        np.tile(np.arange(10), (10, 1)).astype(config.floatX),
        np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX),
    ]
    (jax_res,) = compare_jax_and_py(out_fg, test_input_vals)

    # Confirm that the `Subtensor` slice operations are correct
    assert jax_res.shape == (5, 3)

    # Confirm that the `IncSubtensor` operations are correct
    assert jax_res[0, 0] == -10.0
    assert jax_res[0, 1] == -8.0

    out = clip(x, y, 5)
    out_fg = FunctionGraph([x, y], [out])
    compare_jax_and_py(out_fg, test_input_vals)

    out = aet.diagonal(x, 0)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(
        out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)])

    out = aet_slinalg.cholesky(x)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(
        out_fg,
        [(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
            config.floatX)],
    )

    # not sure why this isn't working yet with lower=False
    out = aet_slinalg.Cholesky(lower=False)(x)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(
        out_fg,
        [(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
            config.floatX)],
    )

    out = aet_slinalg.solve(x, b)
    out_fg = FunctionGraph([x, b], [out])
    compare_jax_and_py(
        out_fg,
        [
            np.eye(10).astype(config.floatX),
            np.arange(10).astype(config.floatX),
        ],
    )

    out = aet.diag(b)
    out_fg = FunctionGraph([b], [out])
    compare_jax_and_py(out_fg, [np.arange(10).astype(config.floatX)])

    out = aet_nlinalg.det(x)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(
        out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)])

    out = aet_nlinalg.matrix_inverse(x)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(
        out_fg,
        [(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
            config.floatX)],
    )