def pos2map(pidx, pgz, prior_result, neib_shape, neib_step):
    """
    Accumulate the gradient contribution of one neighborhood position.

    The flat position ``pidx`` is converted to a (row, column) offset inside
    the neighborhood; the strided window of ``prior_result`` that this offset
    touches is then incremented by ``pgz`` reshaped to the window's shape.

    pidx = Flat index of the position within a neighborhood.
    pgz = Gradient of shape (batch_size*num_channels*neibs)
    prior_result = Accumulator of shape (batch_size, num_channels, rows, cols)
    neib_shape = Number of rows, cols in a neighborhood.
    neib_step = Step sizes from image2neibs.

    Returns the result of ``inc_subtensor`` on the selected window.
    """
    neib_rows, neib_cols = neib_shape
    row_step, col_step = neib_step
    batch_size, num_channels, rows, cols = prior_result.shape

    # Row-major decomposition of the flat position index.
    row_off = pidx // neib_cols
    col_off = pidx - (row_off * neib_cols)

    # In valid mode this position does not reach every image pixel, so only
    # a strided window starting at (row_off, col_off) is selected.
    row_window = slice(row_off, rows - neib_rows + row_off + 1, row_step)
    col_window = slice(col_off, cols - neib_cols + col_off + 1, col_step)
    window = prior_result[:, :, row_window, col_window]

    window_shape = (
        batch_size,
        num_channels,
        (rows - neib_rows) // row_step + 1,
        (cols - neib_cols) // col_step + 1,
    )
    return inc_subtensor(window, pgz.reshape(window_shape))
def test_incsub_offset():
    """Regression test for https://github.com/Theano/Theano/issues/5670.

    Applies ``inc_subtensor`` to a slice of a slice of a GPU shared variable
    and checks the returned values.
    """
    float_type = aesara.config.floatX

    # Slicing the shared variable gives a variable whose value has an offset.
    shared_x = gpuarray_shared_constructor(np.zeros(5, dtype=float_type))
    offset_view = shared_x[1:]

    increment = vector()
    incremented = inc_subtensor(offset_view[2:], increment)

    # Feeding the result back as an update lets inc_subtensor run inplace.
    fn = aesara.function(
        [increment],
        incremented,
        updates={shared_x: incremented},
        mode=mode_with_gpu,
    )

    result = fn([1, 2])
    expected = np.array([0, 0, 1, 2], dtype=float_type)
    utt.assert_allclose(result, expected)
def test_incsub_f16():
    """Increment a float16 GPU shared variable two ways and compare to NumPy.

    For both ``advanced_inc_subtensor1`` and ``inc_subtensor`` the compiled
    graph must contain exactly one node of the matching GPU op, and the
    function output must be close to the NumPy reference.
    """
    shape = (3, 3)
    x_val = np.arange(np.prod(shape), dtype="float16").reshape(shape) + 1
    y_val = np.full((2,) + shape[1:], 2, dtype="float16")

    x = gpuarray_shared_constructor(x_val, name="x")
    y = tensor(dtype="float16", broadcastable=(False,) * len(shape), name="y")

    def count_nodes(fn, op_type):
        # Number of nodes in the compiled graph whose op is an `op_type`.
        return sum(
            isinstance(node.op, op_type) for node in fn.maker.fgraph.toposort()
        )

    # Advanced (integer-list) increment on rows 0 and 2.
    adv_fn = aesara.function(
        [y], advanced_inc_subtensor1(x, y, [0, 2]), mode=mode_with_gpu
    )
    assert count_nodes(adv_fn, GpuAdvancedIncSubtensor1) == 1
    adv_result = adv_fn(y_val)
    adv_expected = x_val.copy()
    np.add.at(adv_expected, [[0, 2]], y_val)
    assert np.allclose(adv_result, adv_expected)

    # Basic (slice) increment on rows 1 and onward.
    basic_fn = aesara.function([y], inc_subtensor(x[1:], y), mode=mode_with_gpu)
    assert count_nodes(basic_fn, GpuIncSubtensor) == 1
    basic_result = basic_fn(y_val)
    basic_expected = x_val.copy()
    basic_expected[1:] += y_val
    assert np.allclose(basic_result, basic_expected)
def test_jax_IncSubtensor():
    """Exercise `set_subtensor`/`inc_subtensor` graphs through `compare_jax_and_py`.

    Each case builds an input-free graph, asserts which subtensor `Op` the
    graph ends in (`IncSubtensor`, `AdvancedIncSubtensor1` or
    `AdvancedIncSubtensor`), and passes the resulting `FunctionGraph` to
    `compare_jax_and_py` with no input values.
    """
    rng = np.random.default_rng(213234)

    x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
    x_aet = aet.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)

    def check(out_aet, expected_op):
        # Assert the graph's final `Op` type, then run the JAX/Python comparison.
        assert isinstance(out_aet.owner.op, expected_op)
        out_fg = FunctionGraph([], [out_aet])
        compare_jax_and_py(out_fg, [])

    # "Set" basic indices
    st_aet = aet.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
    check(
        aet_subtensor.set_subtensor(x_aet[1, 2, 3], st_aet),
        aet_subtensor.IncSubtensor,
    )

    st_aet = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
    check(
        aet_subtensor.set_subtensor(x_aet[:2, 0, 0], st_aet),
        aet_subtensor.IncSubtensor,
    )
    check(
        aet_subtensor.set_subtensor(x_aet[0, 1:3, 0], st_aet),
        aet_subtensor.IncSubtensor,
    )

    # "Set" advanced indices
    st_aet = aet.as_tensor_variable(
        rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
    )
    check(
        aet_subtensor.set_subtensor(x_aet[np.r_[0, 2]], st_aet),
        aet_subtensor.AdvancedIncSubtensor1,
    )

    st_aet = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
    check(
        aet_subtensor.set_subtensor(x_aet[[0, 2], 0, 0], st_aet),
        aet_subtensor.AdvancedIncSubtensor,
    )

    st_aet = aet.as_tensor_variable(x_np[[0, 2], 0, :3])
    check(
        aet_subtensor.set_subtensor(x_aet[[0, 2], 0, :3], st_aet),
        aet_subtensor.AdvancedIncSubtensor,
    )

    # "Set" boolean indices
    mask_aet = aet.as_tensor_variable(x_np) > 0
    check(
        aet_subtensor.set_subtensor(x_aet[mask_aet], 0.0),
        aet_subtensor.AdvancedIncSubtensor,
    )

    # "Increment" basic indices
    st_aet = aet.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
    check(
        aet_subtensor.inc_subtensor(x_aet[1, 2, 3], st_aet),
        aet_subtensor.IncSubtensor,
    )

    st_aet = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
    check(
        aet_subtensor.inc_subtensor(x_aet[:2, 0, 0], st_aet),
        aet_subtensor.IncSubtensor,
    )
    # This case previously called `set_subtensor`, duplicating the "Set" case
    # above and leaving the mixed integer/slice increment untested.
    check(
        aet_subtensor.inc_subtensor(x_aet[0, 1:3, 0], st_aet),
        aet_subtensor.IncSubtensor,
    )

    # "Increment" advanced indices
    st_aet = aet.as_tensor_variable(
        rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
    )
    check(
        aet_subtensor.inc_subtensor(x_aet[np.r_[0, 2]], st_aet),
        aet_subtensor.AdvancedIncSubtensor1,
    )

    st_aet = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
    check(
        aet_subtensor.inc_subtensor(x_aet[[0, 2], 0, 0], st_aet),
        aet_subtensor.AdvancedIncSubtensor,
    )

    st_aet = aet.as_tensor_variable(x_np[[0, 2], 0, :3])
    check(
        aet_subtensor.inc_subtensor(x_aet[[0, 2], 0, :3], st_aet),
        aet_subtensor.AdvancedIncSubtensor,
    )

    # "Increment" boolean indices
    # NOTE(review): this still calls `set_subtensor`, not `inc_subtensor`.
    # Left unchanged because it is unclear from here whether the JAX backend
    # supports incrementing through a boolean mask -- confirm before switching.
    mask_aet = aet.as_tensor_variable(x_np) > 0
    check(
        aet_subtensor.set_subtensor(x_aet[mask_aet], 1.0),
        aet_subtensor.AdvancedIncSubtensor,
    )
def test_jax_basic():
    """Run a series of small graphs through `compare_jax_and_py`.

    Covers a scalar elementwise expression combined with set/inc/slice
    subtensor operations, `clip`, `diagonal`, Cholesky (function and
    `Cholesky(lower=False)` op), `solve`, `diag`, `det` and `matrix_inverse`.
    """
    rng = np.random.default_rng(28494)

    def arange_matrix():
        # 10x10 matrix holding 0..99.
        return np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)

    def perturbed_identity():
        # Identity plus small noise; each call consumes one draw from `rng`.
        noise = rng.standard_normal(size=(10, 10)) * 0.01
        return (np.eye(10) + noise).astype(config.floatX)

    x = matrix("x")
    y = matrix("y")
    b = vector("b")

    # `ScalarOp`
    z = cosh(x**2 + y / 3.0)

    # `[Inc]Subtensor`
    out = aet_subtensor.set_subtensor(z[0], -10.0)
    out = aet_subtensor.inc_subtensor(out[0, 1], 2.0)
    out = out[:5, :3]

    test_input_vals = [
        np.tile(np.arange(10), (10, 1)).astype(config.floatX),
        np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX),
    ]
    (jax_res,) = compare_jax_and_py(FunctionGraph([x, y], [out]), test_input_vals)

    # Confirm that the `Subtensor` slice operations are correct
    assert jax_res.shape == (5, 3)

    # Confirm that the `IncSubtensor` operations are correct
    assert jax_res[0, 0] == -10.0
    assert jax_res[0, 1] == -8.0

    clip_fg = FunctionGraph([x, y], [clip(x, y, 5)])
    compare_jax_and_py(clip_fg, test_input_vals)

    diagonal_fg = FunctionGraph([x], [aet.diagonal(x, 0)])
    compare_jax_and_py(diagonal_fg, [arange_matrix()])

    cholesky_fg = FunctionGraph([x], [aet_slinalg.cholesky(x)])
    compare_jax_and_py(cholesky_fg, [perturbed_identity()])

    # Explicitly constructed `Cholesky` op with `lower=False`.
    cholesky_op_fg = FunctionGraph([x], [aet_slinalg.Cholesky(lower=False)(x)])
    compare_jax_and_py(cholesky_op_fg, [perturbed_identity()])

    solve_fg = FunctionGraph([x, b], [aet_slinalg.solve(x, b)])
    compare_jax_and_py(
        solve_fg,
        [
            np.eye(10).astype(config.floatX),
            np.arange(10).astype(config.floatX),
        ],
    )

    diag_fg = FunctionGraph([b], [aet.diag(b)])
    compare_jax_and_py(diag_fg, [np.arange(10).astype(config.floatX)])

    det_fg = FunctionGraph([x], [aet_nlinalg.det(x)])
    compare_jax_and_py(det_fg, [arange_matrix()])

    inverse_fg = FunctionGraph([x], [aet_nlinalg.matrix_inverse(x)])
    compare_jax_and_py(inverse_fg, [perturbed_identity()])