def grad(self, inputs, cost_grad):
    """
    In defining the gradient, the Finite Fourier Transform is viewed as
    a complex-differentiable function of a complex variable.
    """
    a = inputs[0]
    n = inputs[1]
    axis = inputs[2]
    grad = cost_grad[0]
    if not isinstance(axis, tensor.TensorConstant):
        raise NotImplementedError(
            "%s: gradient is currently implemented"
            " only for axis being an Aesara constant"
            % self.__class__.__name__
        )
    axis = int(axis.data)
    # Notice that the number of actual elements of `a` along `axis` is
    # independent of possible padding or truncation:
    elem = tensor.arange(0, tensor.shape(a)[axis], 1)
    # Accounts for padding:
    freq = tensor.arange(0, n, 1)
    outer = tensor.outer(freq, elem)
    pow_outer = tensor.exp(((-2 * math.pi * 1j) * outer) / (1.0 * n))
    res = tensor.tensordot(grad, pow_outer, (axis, 0))

    # This would be simpler, but it is not implemented by Aesara:
    # res = tensor.switch(tensor.lt(n, tensor.shape(a)[axis]),
    #                     tensor.set_subtensor(res[..., n::], 0, False, False),
    #                     res)
    # Instead we resort to the following to account for truncation:
    flip_shape = list(np.arange(0, a.ndim)[::-1])
    res = res.dimshuffle(flip_shape)
    res = tensor.switch(
        tensor.lt(n, tensor.shape(a)[axis]),
        tensor.set_subtensor(res[n::, ], 0, False, False),
        res,
    )
    res = res.dimshuffle(flip_shape)

    # Ensures that the gradient shape conforms to the input shape:
    out_shape = (
        list(np.arange(0, axis))
        + [a.ndim - 1]
        + list(np.arange(axis, a.ndim - 1))
    )
    res = res.dimshuffle(*out_shape)
    return [res, None, None]
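# A minimal NumPy sketch (illustrative only, not part of the Op above) of
# the same gradient formula for a 1-D input: the output gradient is
# contracted with the DFT kernel exp(-2*pi*1j * freq * elem / n), and
# entries past `n` are zeroed when the input was truncated. The helper
# name `fft_grad_1d` is made up for this example.
import numpy as np


def fft_grad_1d(output_grad, n, n_elem):
    elem = np.arange(n_elem)    # positions in the (un-truncated) input
    freq = np.arange(n)         # output frequencies (accounts for padding)
    kernel = np.exp(-2j * np.pi * np.outer(freq, elem) / n)
    res = output_grad @ kernel  # contract over the frequency axis
    if n < n_elem:
        res[n:] = 0             # truncated input elements get zero gradient
    return res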
def shape(self):
    # If every dimension of the static type shape is known, return it
    # directly as a constant vector instead of building a Shape node.
    if not any(s is None for s in self.type.shape):
        return as_tensor_variable(self.type.shape, ndim=1, dtype=np.int64)
    # Otherwise, fall back to the symbolic shape of this variable.
    return at.shape(self)
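# Hedged usage sketch (assuming this property is TensorVariable.shape in
# Aesara and that `at` above is `aesara.tensor`): a constant has a fully
# known static shape, so `.shape` is returned directly as a constant
# vector; a plain matrix has unknown dimensions, so `.shape` falls back
# to the symbolic Shape Op evaluated at run time.
import numpy as np
import aesara.tensor as at

c = at.as_tensor_variable(np.zeros((3, 4)))  # static shape is (3, 4)
print(c.shape)   # a constant vector holding [3, 4]; no Shape node needed

m = at.matrix("m")                           # static shape is (None, None)
print(m.shape)   # Shape(m): only known when the graph is evaluated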
def shape(self):
    # Always build the symbolic shape of this variable.
    return aet.shape(self)