def grad(self, inputs, cost_grad):
    """
    In defining the gradient, the Finite Fourier Transform is viewed as
    a complex-differentiable function of a complex variable.
    """
    a = inputs[0]
    n = inputs[1]
    axis = inputs[2]
    grad = cost_grad[0]
    if not isinstance(axis, TensorConstant):
        raise NotImplementedError(
            f"{self.__class__.__name__}: gradient is currently implemented"
            " only for axis being an Aesara constant"
        )
    axis = int(axis.data)
    # Notice that the number of actual elements along the transformed axis
    # of the input is independent of possible padding or truncation:
    elem = arange(0, shape(a)[axis], 1)
    # Accounts for padding:
    freq = arange(0, n, 1)
    outer_res = outer(freq, elem)
    pow_outer = exp(((-2 * math.pi * 1j) * outer_res) / (1.0 * n))
    res = tensordot(grad, pow_outer, (axis, 0))

    # This would be simpler, but it is not implemented by Aesara:
    # res = switch(lt(n, shape(a)[axis]),
    #              set_subtensor(res[..., n::], 0, False, False), res)

    # Instead we resort to the following to account for truncation:
    flip_shape = list(np.arange(0, a.ndim)[::-1])
    res = res.dimshuffle(flip_shape)
    res = switch(
        lt(n, shape(a)[axis]),
        set_subtensor(res[n:], 0, False, False),
        res,
    )
    res = res.dimshuffle(flip_shape)

    # Ensures that the gradient shape conforms to the input shape:
    out_shape = (
        list(np.arange(0, axis)) + [a.ndim - 1] + list(np.arange(axis, a.ndim - 1))
    )
    res = res.dimshuffle(*out_shape)
    return [res, None, None]
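# A minimal NumPy sketch (not part of the Op) of the identity the gradient
# above relies on: the finite Fourier transform is multiplication by the
# DFT matrix W[f, k] = exp(-2*pi*1j*f*k / n), so backpropagating through it
# amounts to a tensordot with that same matrix.
import numpy as np

n = 8
x = np.random.default_rng(0).standard_normal(n)
freq, elem = np.arange(n), np.arange(n)
W = np.exp((-2 * np.pi * 1j) * np.outer(freq, elem) / n)
assert np.allclose(W @ x, np.fft.fft(x))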
def __get_expanded_dim(self, a, axis, i):
    index_shape = [1] * a.ndim
    index_shape[i] = a.shape[i]
    # This emulates
    # numpy.ogrid[0:a.shape[0], 0:a.shape[1], 0:a.shape[2]]
    index_val = arange(a.shape[i]).reshape(index_shape)
    return index_val
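# A small NumPy check (illustrative only) of the ogrid emulation used above:
# reshaping arange(a.shape[i]) to have size 1 on every other axis reproduces
# the i-th open-grid array returned by numpy.ogrid.
import numpy as np

a = np.empty((2, 3, 4))
i = 1
index_shape = [1] * a.ndim
index_shape[i] = a.shape[i]
index_val = np.arange(a.shape[i]).reshape(index_shape)  # shape (1, 3, 1)
assert (index_val == np.ogrid[0:2, 0:3, 0:4][i]).all()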
def L_op(self, inputs, outputs, out_grads):
    x, k = inputs
    k_grad = grad_undefined(self, 1, k, "topk: k is not differentiable")

    if not (self.return_indices or self.return_values):
        x_grad = grad_undefined(
            self,
            0,
            x,
            "topk: cannot get gradient without both indices and values",
        )
    else:
        x_shp = shape(x)
        z_grad = out_grads[0]
        ndim = x.ndim
        axis = self.axis % ndim
        grad_indices = [
            arange(x_shp[i]).dimshuffle([0] + ["x"] * (ndim - i - 1))
            if i != axis
            else outputs[-1]
            for i in range(ndim)
        ]
        x_grad = x.zeros_like(dtype=z_grad.dtype)
        x_grad = set_subtensor(x_grad[tuple(grad_indices)], z_grad)

    return [x_grad, k_grad]
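# A NumPy sketch (hypothetical 2-D case, not library code) of the scatter
# performed by L_op above: broadcastable arange indices on the non-top-k
# axes, combined with the returned top-k indices on `axis`, route z_grad
# back into the positions the forward pass selected.
import numpy as np

x = np.array([[3.0, 1.0, 4.0], [1.0, 5.0, 9.0]])
k, axis = 2, 1
idx = np.argsort(x, axis=axis)[:, -k:]       # stand-in for outputs[-1]
z_grad = np.ones_like(idx, dtype=x.dtype)    # incoming output gradient
rows = np.arange(x.shape[0])[:, None]        # like the dimshuffled arange above
x_grad = np.zeros_like(x)
x_grad[rows, idx] = z_grad
assert x_grad.sum() == 2 * k                 # one unit per selected element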
def test_jax_Subtensors():
    # Basic indices
    x_aet = aet.arange(3 * 4 * 5).reshape((3, 4, 5))
    out_aet = x_aet[1, 2, 0]
    assert isinstance(out_aet.owner.op, aet_subtensor.Subtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    out_aet = x_aet[1:2, 1, :]
    assert isinstance(out_aet.owner.op, aet_subtensor.Subtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    # Advanced indexing
    out_aet = x_aet[[1, 2]]
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor1)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    out_aet = x_aet[[1, 2], [2, 3]]
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    # Advanced and basic indexing
    out_aet = x_aet[[1, 2], :]
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor1)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    out_aet = x_aet[[1, 2], :, [3, 4]]
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])
def test_jax_Subtensors_omni():
    x_aet = aet.arange(3 * 4 * 5).reshape((3, 4, 5))

    # Boolean indices
    out_aet = x_aet[x_aet < 0]
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])
def test_arange_nonconcrete():
    a = scalar("a")
    a.tag.test_value = 10

    out = aet.arange(a)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_known_grads():
    # Test that `grad` called normally matches `grad` called with
    # `known_grads` built from its own gradients at each layer of variables.
    full_range = aet.arange(10)
    x = scalar("x")
    t = iscalar("t")
    ft = full_range[t]
    ft.name = "ft"
    coeffs = vector("c")
    ct = coeffs[t]
    ct.name = "ct"
    p = x**ft
    p.name = "p"
    y = ct * p
    y.name = "y"
    cost = sqr(y)
    cost.name = "cost"

    layers = [[cost], [y], [ct, p], [ct, x, ft], [coeffs, t, full_range, x]]

    inputs = [coeffs, t, x]

    rng = np.random.default_rng([2012, 11, 15])
    values = [
        rng.standard_normal(10),
        rng.integers(10),
        rng.standard_normal(),
    ]
    values = [np.cast[ipt.dtype](value) for ipt, value in zip(inputs, values)]

    true_grads = grad(cost, inputs, disconnected_inputs="ignore")
    true_grads = aesara.function(inputs, true_grads)
    true_grads = true_grads(*values)

    for layer in layers:
        first = grad(cost, layer, disconnected_inputs="ignore")
        known = OrderedDict(zip(layer, first))
        full = grad(
            cost=None, known_grads=known, wrt=inputs, disconnected_inputs="ignore"
        )
        full = aesara.function(inputs, full)
        full = full(*values)
        assert len(true_grads) == len(full)
        for a, b, var in zip(true_grads, full, inputs):
            assert np.allclose(a, b)
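# A minimal sketch of the `known_grads` mechanism exercised by the test
# above (assumes `aesara` is importable): given dC/dy, `aesara.grad` can
# resume backpropagation at y without needing the cost itself.
import numpy as np
import aesara
import aesara.tensor as aet

x = aet.scalar("x")
y = x**2
dy = aet.scalar("dy")  # externally supplied gradient dC/dy
(dx,) = aesara.grad(cost=None, known_grads={y: dy}, wrt=[x])
f = aesara.function([x, dy], dx)
assert np.isclose(f(3.0, 1.0), 6.0)  # dC/dx = dC/dy * 2x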
def to_one_hot(y, nb_class, dtype=None):
    """
    Return a matrix where each row corresponds to the one-hot encoding of
    the corresponding element in `y`.

    Parameters
    ----------
    y
        A vector of integer values between ``0`` and ``nb_class - 1``.
    nb_class : int
        The number of classes in `y`.
    dtype : data-type
        The dtype of the returned matrix. Defaults to floatX.

    Returns
    -------
    object
        A matrix of shape ``(y.shape[0], nb_class)``, where each row ``i``
        is the one-hot encoding of the corresponding ``y[i]`` value.
    """
    ret = aet.zeros((y.shape[0], nb_class), dtype=dtype)
    ret = set_subtensor(ret[aet.arange(y.shape[0]), y], 1)
    return ret
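# The same indexing trick in plain NumPy (an illustrative equivalent, not
# the Aesara implementation): fancy-index one position per row, set it to 1.
import numpy as np

y = np.array([1, 0, 2])
nb_class = 3
ret = np.zeros((y.shape[0], nb_class))
ret[np.arange(y.shape[0]), y] = 1
# ret == [[0., 1., 0.],
#         [1., 0., 0.],
#         [0., 0., 1.]]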
def test_jax_IncSubtensor():
    rng = np.random.default_rng(213234)

    x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
    x_aet = aet.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)

    # "Set" basic indices
    st_aet = aet.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
    out_aet = aet_subtensor.set_subtensor(x_aet[1, 2, 3], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.IncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    st_aet = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
    out_aet = aet_subtensor.set_subtensor(x_aet[:2, 0, 0], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.IncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    out_aet = aet_subtensor.set_subtensor(x_aet[0, 1:3, 0], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.IncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    # "Set" advanced indices
    st_aet = aet.as_tensor_variable(
        rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
    )
    out_aet = aet_subtensor.set_subtensor(x_aet[np.r_[0, 2]], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor1)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    st_aet = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
    out_aet = aet_subtensor.set_subtensor(x_aet[[0, 2], 0, 0], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    st_aet = aet.as_tensor_variable(x_np[[0, 2], 0, :3])
    out_aet = aet_subtensor.set_subtensor(x_aet[[0, 2], 0, :3], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    # "Set" boolean indices
    mask_aet = aet.as_tensor_variable(x_np) > 0
    out_aet = aet_subtensor.set_subtensor(x_aet[mask_aet], 0.0)
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    # "Increment" basic indices
    st_aet = aet.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
    out_aet = aet_subtensor.inc_subtensor(x_aet[1, 2, 3], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.IncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    st_aet = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
    out_aet = aet_subtensor.inc_subtensor(x_aet[:2, 0, 0], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.IncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    out_aet = aet_subtensor.inc_subtensor(x_aet[0, 1:3, 0], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.IncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    # "Increment" advanced indices
    st_aet = aet.as_tensor_variable(
        rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
    )
    out_aet = aet_subtensor.inc_subtensor(x_aet[np.r_[0, 2]], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor1)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    st_aet = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
    out_aet = aet_subtensor.inc_subtensor(x_aet[[0, 2], 0, 0], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    st_aet = aet.as_tensor_variable(x_np[[0, 2], 0, :3])
    out_aet = aet_subtensor.inc_subtensor(x_aet[[0, 2], 0, :3], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    # "Increment" boolean indices
    mask_aet = aet.as_tensor_variable(x_np) > 0
    out_aet = aet_subtensor.inc_subtensor(x_aet[mask_aet], 1.0)
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])
def grad(self, inp, grads):
    x, neib_shape, neib_step = inp
    (gz,) = grads

    if self.mode in ("valid", "ignore_borders"):
        if (
            neib_shape is neib_step
            or neib_shape == neib_step
            # For Aesara constants, do not compare the data directly;
            # the `equals` method does that.
            or (hasattr(neib_shape, "equals") and neib_shape.equals(neib_step))
        ):
            return [
                neibs2images(gz, neib_shape, x.shape, mode=self.mode),
                grad_undefined(self, 1, neib_shape),
                grad_undefined(self, 2, neib_step),
            ]

    if self.mode in ["valid"]:
        # Iterate over neighborhood positions, summing contributions.
        def pos2map(pidx, pgz, prior_result, neib_shape, neib_step):
            """
            Helper function that adds the gradient contribution from a
            single neighborhood position (i, j).

            pidx = Index of the position within the neighborhood.
            pgz = Gradient of shape (batch_size * num_channels * neibs,).
            prior_result = Shape (batch_size, num_channels, rows, cols).
            neib_shape = Number of rows and cols in a neighborhood.
            neib_step = Step sizes from images2neibs.
            """
            nrows, ncols = neib_shape
            rstep, cstep = neib_step
            batch_size, num_channels, rows, cols = prior_result.shape
            i = pidx // ncols
            j = pidx - (i * ncols)
            # This position does not touch some image pixels in valid mode.
            result_indices = prior_result[
                :,
                :,
                i : (rows - nrows + i + 1) : rstep,
                j : (cols - ncols + j + 1) : cstep,
            ]
            newshape = (
                (batch_size, num_channels)
                + ((rows - nrows) // rstep + 1,)
                + ((cols - ncols) // cstep + 1,)
            )
            return inc_subtensor(result_indices, pgz.reshape(newshape))

        indices = arange(neib_shape[0] * neib_shape[1])
        pgzs = gz.dimshuffle((1, 0))
        result, _ = aesara.scan(
            fn=pos2map,
            sequences=[indices, pgzs],
            outputs_info=zeros(x.shape),
            non_sequences=[neib_shape, neib_step],
        )
        grad_input = result[-1]
        return [
            grad_input,
            grad_undefined(self, 1, neib_shape),
            grad_undefined(self, 2, neib_step),
        ]

    return [
        grad_not_implemented(self, 0, x),
        grad_undefined(self, 1, neib_shape),
        grad_undefined(self, 2, neib_step),
    ]
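# A NumPy sketch (toy sizes, illustrative only) of the strided slice that
# pos2map builds above: for one offset (i, j) inside the neighborhood, the
# slice selects exactly the pixels that offset maps to, one per patch.
import numpy as np

batch, chans, rows, cols = 1, 1, 4, 4
nrows, ncols = 2, 2  # neib_shape
rstep, cstep = 2, 2  # neib_step
result = np.zeros((batch, chans, rows, cols))
i, j = 0, 1  # position within the neighborhood
view = result[
    :, :, i : rows - nrows + i + 1 : rstep, j : cols - ncols + j + 1 : cstep
]
# One selected pixel per patch along each spatial axis:
assert view.shape == (1, 1, (rows - nrows) // rstep + 1, (cols - ncols) // cstep + 1)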