def irfft(inp, norm=None, is_odd=False):
    r"""
    Performs the inverse fast Fourier Transform with real-valued output.

    The input is a variable of dimensions (m, ..., n//2+1, 2)
    representing the non-trivial elements of m real-valued Fourier
    transforms of initial size (..., n). The real and imaginary parts
    are stored as a pair of float arrays.

    The output is a real-valued variable of dimensions (m, ..., n)
    giving the m inverse FFTs.

    Parameters
    ----------
    inp
        Array of size (m, ..., n//2+1, 2), containing m inputs
        with n//2+1 non-trivial elements on the last dimension and real
        and imaginary parts stored as separate real arrays.
    norm : {None, 'ortho', 'no_norm'}
        Normalization of transform. Following numpy, default *None* normalizes
        only the inverse transform by n, 'ortho' yields the unitary transform
        (:math:`1/\sqrt n` forward and inverse). In addition, 'no_norm' leaves
        the transform unnormalized.
    is_odd : {True, False}
        Set to True to get a real inverse transform output with an odd last
        dimension of length (N-1)*2 + 1 for an input last dimension of
        length N.

    Raises
    ------
    ValueError
        If `is_odd` is neither True nor False.

    """
    if is_odd not in (True, False):
        raise ValueError(
            f"Invalid value {is_odd} for is_odd, must be True or False")

    # Shape of the real-valued output, without the leading batch dimension:
    # every transformed axis keeps its length except the last one, which is
    # rebuilt from the n//2+1 stored coefficients.
    s = inp.shape[1:-1]
    if is_odd:
        s = set_subtensor(s[-1], (s[-1] - 1) * 2 + 1)
    else:
        s = set_subtensor(s[-1], (s[-1] - 1) * 2)

    cond_norm = _unitary(norm)
    scaling = 1
    # Numpy's default normalization is 1/N on the inverse transform.
    if cond_norm is None:
        scaling = s.prod().astype(inp.dtype)
    elif cond_norm == "ortho":
        scaling = sqrt(s.prod().astype(inp.dtype))

    return irfft_op(inp, s) / scaling
def make_node(self, inp, s=None):
    """Build the Apply node from an input variable and an output shape.

    Parameters
    ----------
    inp
        Input variable; it is converted with ``as_gpuarray_variable`` and
        passed through ``gpu_contiguous``. Its dtype must be ``"float32"``.
    s
        Shape of the transform, as a 1-d variable. When omitted it is
        derived from ``inp.shape[1:-1]`` assuming an even-length real
        transform on the last axis.

    Raises
    ------
    RuntimeError
        If skcuda, pygpu or pycuda is not available.

    """
    # The shape is taken as an explicit input; for now it is only used to
    # handle odd transform sizes. It could later be extended to padding and
    # truncation, following numpy's interface. However, cuFFT expects arrays
    # that match the shape given to the plan, so padding would have to be
    # done in the op. The effect of padding on gradients has yet to be
    # investigated.
    requirements = (
        ("skcuda", skcuda_available),
        ("pygpu", pygpu_available),
        ("pycuda", pycuda_available),
    )
    for package, available in requirements:
        if not available:
            raise RuntimeError(f"{package} is needed for CuIFFTOp")

    context_name = infer_context_name(inp)
    inp = gpu_contiguous(as_gpuarray_variable(inp, context_name))

    if s is None:
        # No shape given: assume an even real transform on the last axis.
        default_shape = inp.shape[1:-1]
        s = set_subtensor(default_shape[-1], (default_shape[-1] - 1) * 2)
    s = as_tensor_variable(s)

    assert inp.dtype == "float32"
    assert s.ndim == 1

    outputs = [self.output_type(inp)()]
    return Apply(self, [inp, s], outputs)
def L_op(self, inputs, outputs, out_grads):
    """Return the gradients with respect to ``x`` and ``k``.

    The gradient for ``k`` is always undefined. The gradient for ``x`` is a
    zero tensor shaped like ``x`` into which ``out_grads[0]`` is written at
    the positions given by ``outputs[-1]`` along ``self.axis``.
    """
    x, k = inputs
    k_grad = grad_undefined(self, 1, k, "topk: k is not differentiable")

    # NOTE(review): the message below says *both* indices and values are
    # needed, but this guard only triggers when *neither* is returned.
    # Confirm whether ``and`` was intended before relying on the gradient
    # of an op that returns only one of the two.
    if not self.return_indices and not self.return_values:
        x_grad = grad_undefined(
            self,
            0,
            x,
            "topk: cannot get gradient without both indices and values",
        )
        return [x_grad, k_grad]

    x_shp = shape(x)
    z_grad = out_grads[0]
    ndim = x.ndim
    axis = self.axis % ndim

    # One index per dimension: along ``axis`` use the last output, elsewhere
    # a broadcastable ``arange`` that enumerates that dimension.
    grad_indices = []
    for dim in range(ndim):
        if dim == axis:
            grad_indices.append(outputs[-1])
        else:
            pattern = [0] + ["x"] * (ndim - dim - 1)
            grad_indices.append(arange(x_shp[dim]).dimshuffle(pattern))

    x_grad = x.zeros_like(dtype=z_grad.dtype)
    x_grad = set_subtensor(x_grad[tuple(grad_indices)], z_grad)
    return [x_grad, k_grad]
def grad(self, inputs, output_grads):
    """Return the gradients for the data input and the shape input.

    The output gradient is rescaled on its second-to-last axis and passed to
    ``irfft_op``; the shape input ``s`` gets a ``DisconnectedType`` gradient.
    """
    (gout,) = output_grads
    s = inputs[1]

    # Halve the entries along the frequency axis: they are double-counted by
    # the real-IFFT due to symmetry, except the first and (for even
    # transforms) the last element, which are unique and left untouched.
    last_len = s[-1]
    stop = (last_len // 2) + (last_len % 2)
    idx = [slice(None)] * (gout.ndim - 2)
    idx.append(slice(1, stop))
    idx.append(slice(None))

    doubled = gout[idx]
    gout = set_subtensor(doubled, doubled * 0.5)

    return [irfft_op(gout, s), DisconnectedType()()]
def grad(self, inputs, output_grads):
    """Return the gradients for the data input and the shape input.

    The output gradient is passed to ``rfft_op`` and the result is rescaled
    on its second-to-last axis; the shape input ``s`` gets a
    ``DisconnectedType`` gradient.
    """
    (gout,) = output_grads
    s = inputs[1]
    gf = rfft_op(gout, s)

    # Double the entries along the frequency axis: they stand for both the
    # positive and the negative frequencies, except the first and (for even
    # transforms) the last element, which are unique and left untouched.
    last_len = s[-1]
    stop = (last_len // 2) + (last_len % 2)
    idx = [slice(None)] * (gf.ndim - 2)
    idx.append(slice(1, stop))
    idx.append(slice(None))

    shared = gf[idx]
    gf = set_subtensor(shared, shared * 2)

    return [gf, DisconnectedType()()]
def grad(self, inputs, cost_grad):
    """
    In defining the gradient, the Finite Fourier Transform is viewed as
    a complex-differentiable function of a complex variable.

    Only the first input receives a gradient; the entries for ``n`` and
    ``axis`` are ``None``. A ``NotImplementedError`` is raised unless
    ``axis`` is a ``TensorConstant``.
    """
    a, n, axis = inputs[0], inputs[1], inputs[2]
    out_grad = cost_grad[0]

    if not isinstance(axis, TensorConstant):
        raise NotImplementedError(
            f"{self.__class__.__name__}: gradient is currently implemented"
            " only for axis being an Aesara constant")
    axis = int(axis.data)

    # The number of elements we differentiate with respect to does not
    # depend on possible padding or truncation ...
    elem = arange(0, shape(a)[axis], 1)
    # ... whereas the number of frequencies does (it accounts for padding).
    freq = arange(0, n, 1)

    phase = outer(freq, elem)
    twiddle = exp(((-2 * math.pi * 1j) * phase) / (1.0 * n))
    res = tensordot(out_grad, twiddle, (axis, 0))

    # This would be simpler but is not implemented by aesara:
    #   res = switch(lt(n, shape(a)[axis]),
    #                set_subtensor(res[...,n::], 0, False, False), res)
    # Instead, reverse the axis order so the contracted axis comes first,
    # zero everything from ``n`` on to account for truncation, and reverse
    # the axis order back.
    reversed_axes = list(np.arange(0, a.ndim)[::-1])
    res = res.dimshuffle(reversed_axes)
    truncated = set_subtensor(res[(slice(n, None, None),)], 0, False, False)
    res = switch(lt(n, shape(a)[axis]), truncated, res)
    res = res.dimshuffle(reversed_axes)

    # Move the last axis back to position ``axis`` so that the gradient
    # shape conforms to the input shape.
    leading = list(np.arange(0, axis))
    trailing = list(np.arange(axis, a.ndim - 1))
    out_shape = leading + [a.ndim - 1] + trailing
    res = res.dimshuffle(*out_shape)

    return [res, None, None]
def expand_empty(tensor_var, size):
    """
    Grow a tensor along its first dimension, from shape (d1, d2, ...) to
    (d1 + size, d2, ...), by appending uninitialized memory at the end.

    When ``size`` is 0 the input is returned unchanged.
    """
    if size == 0:
        return tensor_var

    dims = [tensor_var.shape[axis] for axis in range(tensor_var.ndim)]
    grown_shape = [size + dims[0]] + dims[1:]

    # Allocate the larger buffer without initializing it, then copy the
    # original values into its leading part.
    buffer = AllocEmpty(tensor_var.dtype)(*grown_shape)
    ret = set_subtensor(buffer[:dims[0]], tensor_var)

    # The tail is deliberately uninitialized, so opt out of the NaN guard
    # check for this variable.
    ret.tag.nan_guard_mode_check = False
    return ret
def make_node(self, a, s=None):
    """Build the Apply node from an input variable and an output shape.

    Parameters
    ----------
    a
        Input, converted with ``as_tensor_variable``; must have at least
        three dimensions.
    s
        Shape of the transform. When omitted it is derived from
        ``a.shape[1:-1]`` assuming an even-length last axis; when given it
        must have an integer dtype.

    Raises
    ------
    TypeError
        If ``a`` has fewer than three dimensions, or if a user-provided
        ``s`` does not have an integer dtype.

    """
    a = as_tensor_variable(a)
    if a.ndim < 3:
        raise TypeError(
            f"{self.__class__.__name__}: input must have dimension >= 3, with "
            "first dimension batches and last real/imag parts")

    if s is not None:
        s = as_tensor_variable(s)
        if s.dtype not in integer_dtypes:
            raise TypeError(
                f"{self.__class__.__name__}: length of the transformed axis "
                "must be of type integer")
    else:
        # Default output shape: an even-length real transform.
        inferred = a.shape[1:-1]
        inferred = set_subtensor(inferred[-1], (inferred[-1] - 1) * 2)
        s = as_tensor_variable(inferred)

    return Apply(self, [a, s], [self.output_type(a)()])
def to_one_hot(y, nb_class, dtype=None):
    """
    Build the one-hot encoding of a vector of class indices.

    Parameters
    ----------
    y
        A vector of integer values between 0 and nb_class - 1.
    nb_class : int
        The number of classes in y.
    dtype : data-type
        The dtype of the returned matrix. Default floatX.

    Returns
    -------
    object
        A matrix of shape (y.shape[0], nb_class), where each row ``i`` is
        the one-hot encoding of the corresponding ``y[i]`` value.

    """
    n_rows = y.shape[0]
    encoded = aet.zeros((n_rows, nb_class), dtype=dtype)
    # Row i gets a 1 in column y[i].
    row_ids = aet.arange(n_rows)
    return set_subtensor(encoded[row_ids, y], 1)
def scan_checkpoints(
    fn,
    sequences=None,
    outputs_info=None,
    non_sequences=None,
    name="checkpointscan_fn",
    n_steps=None,
    save_every_N=10,
    padding=True,
):
    """Scan function that uses less memory, but is more restrictive.

    In :func:`~aesara.scan`, if you compute the gradient of the output
    with respect to the input, you will have to store the intermediate
    results at each time step, which can be prohibitively huge. This
    function allows to do ``save_every_N`` steps of forward computations
    without storing the intermediate results, and to recompute them during
    the gradient computation.

    Notes
    -----
    Current assumptions:

    * Every sequence has the same length.
    * If ``n_steps`` is specified, it has the same value as the length of
      any sequence.
    * The value of ``save_every_N`` divides the number of steps the scan
      will run without remainder.
    * Only singly-recurrent and non-recurrent outputs are used.
      No multiple recurrences.
    * Only the last timestep of any output will ever be used.

    Parameters
    ----------
    fn
        ``fn`` is a function that describes the operations involved in one
        step of ``scan``. See the documentation of :func:`~aesara.scan`
        for more information.
    sequences
        ``sequences`` is the list of Aesara variables or dictionaries
        describing the sequences ``scan`` has to iterate over. All
        sequences must be the same length in this version of ``scan``.
        A list passed by the caller is not modified.
    outputs_info
        ``outputs_info`` is the list of Aesara variables or dictionaries
        describing the initial state of the outputs computed recurrently.
    non_sequences
        ``non_sequences`` is the list of arguments that are passed to
        ``fn`` at each steps. One can opt to exclude variable
        used in ``fn`` from this list as long as they are part of the
        computational graph, though for clarity we encourage not to do so.
    n_steps
        ``n_steps`` is the number of steps to iterate given as an int
        or Aesara scalar (> 0). If any of the input sequences do not have
        enough elements, scan will raise an error. If n_steps is not
        provided, ``scan`` will figure out the amount of steps it should
        run given its input sequences.
    save_every_N
        ``save_every_N`` is the number of steps to go without storing
        the computations of ``scan`` (ie they will have to be recomputed
        during the gradient computation).
    padding
        If the length of the sequences is not a multiple of
        ``save_every_N``, the sequences will be zero padded to make this
        version of ``scan`` work properly, but will also result in a memory
        copy. It can be avoided by setting ``padding`` to False, but you
        need to make sure the length of the sequences is a multiple of
        ``save_every_N``.

    Returns
    -------
    tuple
        Tuple of the form ``(outputs, updates)`` as in :func:`~aesara.scan`,
        but with a small change: It only contain the output at each
        ``save_every_N`` step. The time steps that are not returned by
        this function will be recomputed during the gradient computation
        (if any).

    Raises
    ------
    RuntimeError
        If an element of ``outputs_info`` is a dictionary with a ``"taps"``
        key.

    See Also
    --------
    :func:`~aesara.scan`: Looping in Aesara.

    """
    # Standardize the format of input arguments
    if sequences is None:
        sequences = []
    elif not isinstance(sequences, list):
        sequences = [sequences]

    if not isinstance(outputs_info, list):
        outputs_info = [outputs_info]

    if non_sequences is None:
        non_sequences = []
    elif not isinstance(non_sequences, list):
        non_sequences = [non_sequences]

    # Check that outputs_info has no taps:
    for element in outputs_info:
        if isinstance(element, dict) and "taps" in element:
            raise RuntimeError("scan_checkpoints doesn't work with taps.")

    # Determine how many steps the original scan would run
    if n_steps is None:
        n_steps = sequences[0].shape[0]

    # Compute the number of steps of the outer scan
    o_n_steps = at.cast(ceil(n_steps / save_every_N), "int64")

    # Compute the number of steps of the inner scan: ``save_every_N`` for
    # every outer step, except possibly a shorter final one.
    i_n_steps = save_every_N * at.ones((o_n_steps,), "int64")
    mod = n_steps % save_every_N
    last_n_steps = at.switch(eq(mod, 0), save_every_N, mod)
    i_n_steps = set_subtensor(i_n_steps[-1], last_n_steps)

    # Pad the sequences if needed
    if padding:
        # Since padding could be an empty tensor, Join returns a view of s.
        join = Join(view=0)
        # Build a new list rather than assigning into ``sequences`` so that
        # a list supplied by the caller is left untouched.
        padded_sequences = []
        for s in sequences:
            # NOTE(review): this pads by ``len % save_every_N`` rows, which
            # does not obviously bring the length up to a multiple of
            # ``save_every_N``; kept as-is, confirm the intended amount.
            n = s.shape[0] % save_every_N
            z = at.zeros((n, s.shape[1:]), dtype=s.dtype)
            padded_sequences.append(join(0, [s, z]))
        sequences = padded_sequences

    # Establish the input variables of the outer scan
    # NOTE(review): ``/`` is a true division here; kept as-is, confirm that
    # ``reshape`` accepts the resulting dtype.
    o_sequences = [
        s.reshape(
            [s.shape[0] / save_every_N, save_every_N]
            + [s.shape[i] for i in range(1, s.ndim)],
            s.ndim + 1,
        )
        for s in sequences
    ]
    o_sequences.append(i_n_steps)
    new_nitsots = [i for i in outputs_info if i is None]
    o_nonsequences = non_sequences

    n_o_sequences = len(o_sequences)
    n_o_nonsequences = len(o_nonsequences)

    def outer_step(*args):
        # Separate the received arguments into their respective (seq, outputs
        # from previous iterations, nonseqs) categories.
        # The split point is computed explicitly: slicing with
        # ``-n_o_nonsequences`` would be ``-0`` when there are no
        # non-sequences, which selects everything instead of nothing.
        nonseq_start = len(args) - n_o_nonsequences
        i_sequences = list(args[:n_o_sequences])
        i_prev_outputs = list(args[n_o_sequences:nonseq_start])
        i_non_sequences = list(args[nonseq_start:])
        i_outputs_infos = i_prev_outputs + [None] * len(new_nitsots)

        # Call the user-provided function with the proper arguments; the
        # last "sequence" is the number of inner steps for this outer step.
        results, updates = scan(
            fn=fn,
            sequences=i_sequences[:-1],
            outputs_info=i_outputs_infos,
            non_sequences=i_non_sequences,
            name=name + "_inner",
            n_steps=i_sequences[-1],
        )
        if not isinstance(results, list):
            results = [results]

        # Keep only the last timestep of every output but keep all the updates
        return [r[-1] for r in results], updates

    results, updates = scan(
        fn=outer_step,
        sequences=o_sequences,
        outputs_info=outputs_info,
        non_sequences=o_nonsequences,
        name=name + "_outer",
        n_steps=o_n_steps,
        allow_gc=True,
    )

    return results, updates
def test_jax_IncSubtensor():
    """Compare JAX and Python results for set/inc subtensor graphs.

    Covers basic, advanced and boolean indexing; for every graph the type of
    the op that owns the output is checked before the comparison is run.
    """
    rng = np.random.default_rng(213234)

    x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
    x_aet = aet.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)

    set_sub = aet_subtensor.set_subtensor
    inc_sub = aet_subtensor.inc_subtensor

    def check(out, expected_op_type):
        # The output must be produced by the expected op, and the graph
        # (which has no inputs) must agree between JAX and Python.
        assert isinstance(out.owner.op, expected_op_type)
        compare_jax_and_py(FunctionGraph([], [out]), [])

    def scalar_value():
        return aet.as_tensor_variable(np.array(-10.0, dtype=config.floatX))

    def pair_value():
        return aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))

    def random_block():
        return aet.as_tensor_variable(
            rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX))

    # "Set" basic indices
    check(set_sub(x_aet[1, 2, 3], scalar_value()), aet_subtensor.IncSubtensor)

    st_aet = pair_value()
    check(set_sub(x_aet[:2, 0, 0], st_aet), aet_subtensor.IncSubtensor)
    check(set_sub(x_aet[0, 1:3, 0], st_aet), aet_subtensor.IncSubtensor)

    # "Set" advanced indices
    check(
        set_sub(x_aet[np.r_[0, 2]], random_block()),
        aet_subtensor.AdvancedIncSubtensor1,
    )
    check(
        set_sub(x_aet[[0, 2], 0, 0], pair_value()),
        aet_subtensor.AdvancedIncSubtensor,
    )
    st_aet = aet.as_tensor_variable(x_np[[0, 2], 0, :3])
    check(
        set_sub(x_aet[[0, 2], 0, :3], st_aet),
        aet_subtensor.AdvancedIncSubtensor,
    )

    # "Set" boolean indices
    mask_aet = aet.as_tensor_variable(x_np) > 0
    check(set_sub(x_aet[mask_aet], 0.0), aet_subtensor.AdvancedIncSubtensor)

    # "Increment" basic indices
    check(inc_sub(x_aet[1, 2, 3], scalar_value()), aet_subtensor.IncSubtensor)

    st_aet = pair_value()
    check(inc_sub(x_aet[:2, 0, 0], st_aet), aet_subtensor.IncSubtensor)
    # NOTE(review): this case sits in the "Increment" section but calls
    # set_subtensor; kept as-is, confirm whether inc_subtensor was intended.
    check(set_sub(x_aet[0, 1:3, 0], st_aet), aet_subtensor.IncSubtensor)

    # "Increment" advanced indices
    check(
        inc_sub(x_aet[np.r_[0, 2]], random_block()),
        aet_subtensor.AdvancedIncSubtensor1,
    )
    check(
        inc_sub(x_aet[[0, 2], 0, 0], pair_value()),
        aet_subtensor.AdvancedIncSubtensor,
    )
    st_aet = aet.as_tensor_variable(x_np[[0, 2], 0, :3])
    check(
        inc_sub(x_aet[[0, 2], 0, :3], st_aet),
        aet_subtensor.AdvancedIncSubtensor,
    )

    # "Increment" boolean indices
    mask_aet = aet.as_tensor_variable(x_np) > 0
    # NOTE(review): also labelled "Increment" but uses set_subtensor; kept
    # as-is, confirm whether inc_subtensor was intended.
    check(set_sub(x_aet[mask_aet], 1.0), aet_subtensor.AdvancedIncSubtensor)
def test_jax_basic():
    """Compare JAX and Python results for a set of basic graphs.

    Exercises an elementwise expression combined with set/inc subtensor and
    slicing, then ``clip``, ``diagonal``, Cholesky (both ``lower`` settings),
    ``solve``, ``diag``, ``det`` and ``matrix_inverse``.
    """
    rng = np.random.default_rng(28494)

    x = matrix("x")
    y = matrix("y")
    b = vector("b")

    def run(graph_inputs, graph_output, values):
        return compare_jax_and_py(
            FunctionGraph(graph_inputs, [graph_output]), values)

    def counting_matrix():
        return np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)

    def near_identity():
        # Identity plus small noise; each call draws fresh values from rng.
        noise = rng.standard_normal(size=(10, 10)) * 0.01
        return (np.eye(10) + noise).astype(config.floatX)

    # `ScalarOp`
    z = cosh(x**2 + y / 3.0)

    # `[Inc]Subtensor`
    out = aet_subtensor.set_subtensor(z[0], -10.0)
    out = aet_subtensor.inc_subtensor(out[0, 1], 2.0)
    out = out[:5, :3]

    test_input_vals = [
        np.tile(np.arange(10), (10, 1)).astype(config.floatX),
        np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX),
    ]
    (jax_res,) = run([x, y], out, test_input_vals)

    # Confirm that the `Subtensor` slice operations are correct
    assert jax_res.shape == (5, 3)

    # Confirm that the `IncSubtensor` operations are correct
    assert jax_res[0, 0] == -10.0
    assert jax_res[0, 1] == -8.0

    run([x, y], clip(x, y, 5), test_input_vals)

    run([x], aet.diagonal(x, 0), [counting_matrix()])

    run([x], aet_slinalg.cholesky(x), [near_identity()])

    # not sure why this isn't working yet with lower=False
    run([x], aet_slinalg.Cholesky(lower=False)(x), [near_identity()])

    run(
        [x, b],
        aet_slinalg.solve(x, b),
        [
            np.eye(10).astype(config.floatX),
            np.arange(10).astype(config.floatX),
        ],
    )

    run([b], aet.diag(b), [np.arange(10).astype(config.floatX)])

    run([x], aet_nlinalg.det(x), [counting_matrix()])

    run([x], aet_nlinalg.matrix_inverse(x), [near_identity()])
def neibs2images(neibs, neib_shape, original_shape, mode="valid"):
    """
    Inverse operation of `images2neibs`: takes the output of `images2neibs`
    and reconstructs its input.

    Parameters
    ----------
    neibs : 2d tensor
        Like the one obtained by `images2neibs`.
    neib_shape
        `neib_shape` that was used in `images2neibs`.
    original_shape
        Original shape of the 4d tensor given to `images2neibs`.
    mode : {"valid", "ignore_borders"}
        Forwarded to `images2neibs`. With ``"ignore_borders"`` the last two
        dimensions are first rebuilt at the largest multiple of `neib_shape`
        that fits and then padded with zeros up to `original_shape`.

    Returns
    -------
    object
        Reconstructs the input of `images2neibs`, a 4d tensor of shape
        `original_shape`.

    Raises
    ------
    NotImplementedError
        If `mode` is neither ``"valid"`` nor ``"ignore_borders"``.

    Notes
    -----
    Currently, the function doesn't support tensors created with
    `neib_step` different from default value. This means that it may be
    impossible to compute the gradient of a variable gained by
    `images2neibs` w.r.t. its inputs in this case, because it uses
    `images2neibs` for gradient computation.

    Examples
    --------
    Example, which uses a tensor gained in example for `images2neibs`:

    .. code-block:: python

        im_new = neibs2images(neibs, (5, 5), im_val.shape)
        # Aesara function definition
        inv_window = aesara.function([neibs], im_new)
        # Function application
        im_new_val = inv_window(neibs_val)

    .. note:: The code will output the initial image array.

    """
    neibs = as_tensor_variable(neibs)
    neib_shape = as_tensor_variable(neib_shape)
    original_shape = as_tensor_variable(original_shape)

    patch_width = neib_shape[1]
    new_neib_shape = stack([original_shape[-1] // patch_width, patch_width])
    output_2d = images2neibs(
        neibs.dimshuffle("x", "x", 0, 1), new_neib_shape, mode=mode)

    if mode == "valid":
        # TODO: we do not implement all mode with this code.
        # Add a check for the good cases.
        return output_2d.reshape(original_shape, ndim=4)

    if mode != "ignore_borders":
        raise NotImplementedError(f"neibs2images do not support mode={mode}")

    # set_subtensor is used so that original_shape is accepted even when its
    # shape can't be inferred, while an error is still raised when it does
    # not have the right shape.
    valid_shape = original_shape
    for axis, side in ((2, 0), (3, 1)):
        whole_patches = valid_shape[axis] // neib_shape[side]
        valid_shape = set_subtensor(
            valid_shape[axis], whole_patches * neib_shape[side])
    output_4d = output_2d.reshape(valid_shape, ndim=4)

    # Pad the borders with zeros, one spatial axis at a time.
    for d in (2, 3):
        pad_shape = list(output_4d.shape)
        pad_shape[d] = original_shape[d] - valid_shape[d]
        output_4d = concatenate([output_4d, zeros(pad_shape)], axis=d)

    return output_4d