Esempio n. 1
0
    def grad(self, inputs, cost_grad):
        """
        Build the symbolic gradient of the transform w.r.t. its first input.

        In defining the gradient, the Finite Fourier Transform is viewed as
        a complex-differentiable function of a complex variable

        Parameters
        ----------
        inputs
            Sequence whose first three entries are read as ``a`` (the
            transformed tensor), ``n`` (the transform length) and ``axis``
            (the axis the transform is taken along).
        cost_grad
            Sequence whose first entry is the gradient of the cost with
            respect to this op's output.

        Returns
        -------
        list
            ``[res, None, None]``: the gradient expression for ``a``, and
            ``None`` for both ``n`` and ``axis``.

        Raises
        ------
        NotImplementedError
            If ``axis`` is not a ``tensor.TensorConstant``.
        """
        a = inputs[0]
        n = inputs[1]
        axis = inputs[2]
        grad = cost_grad[0]
        # The axis permutations below are built in Python, so the axis must
        # be available as a concrete value rather than a symbolic variable.
        if not isinstance(axis, tensor.TensorConstant):
            raise NotImplementedError(
                "%s: gradient is currently implemented"
                " only for axis being a Aesara constant" %
                self.__class__.__name__)
        # Extract the constant's value as a plain Python int.
        axis = int(axis.data)
        # Note that the number of actual elements in the input we differentiate
        # with respect to is independent of possible padding or truncation:
        elem = tensor.arange(0, tensor.shape(a)[axis], 1)
        # Frequency indices run over n, which accounts for padding:
        freq = tensor.arange(0, n, 1)
        # Kernel exp(-2*pi*i * freq * elem / n), built from the outer product
        # of the frequency and element indices.
        outer = tensor.outer(freq, elem)
        pow_outer = tensor.exp(((-2 * math.pi * 1j) * outer) / (1.0 * n))
        # Contract `axis` of the incoming gradient with axis 0 (the frequency
        # axis) of the kernel. The code below treats the resulting element
        # axis as the last axis of `res`.
        res = tensor.tensordot(grad, pow_outer, (axis, 0))

        # This would be simpler but not implemented by aesara:
        # res = tensor.switch(tensor.lt(n, tensor.shape(a)[axis]),
        # tensor.set_subtensor(res[...,n::], 0, False, False), res)

        # Instead we resort to the following to account for truncation:
        # reverse the axis order so the last axis becomes the leading one,
        # zero the leading-axis entries from index n onward when
        # n < shape(a)[axis], then reverse the axis order back.
        flip_shape = list(np.arange(0, a.ndim)[::-1])
        res = res.dimshuffle(flip_shape)
        res = tensor.switch(
            tensor.lt(n,
                      tensor.shape(a)[axis]),
            # NOTE(review): the two positional False arguments are presumably
            # set_subtensor's in-place related flags -- confirm against the
            # set_subtensor signature.
            tensor.set_subtensor(
                res[n::, ],
                0,
                False,
                False,
            ),
            res,
        )
        res = res.dimshuffle(flip_shape)

        # Ensures that the gradient shape conforms to the input shape: the
        # permutation moves the last axis of `res` to position `axis` and
        # keeps the other axes in their relative order.
        out_shape = (list(np.arange(0, axis)) + [a.ndim - 1] +
                     list(np.arange(axis, a.ndim - 1)))
        res = res.dimshuffle(*out_shape)
        # No gradient expression is provided for `n` or `axis`.
        return [res, None, None]
Esempio n. 2
0
    def shape(self):
        """Return this variable's shape as a tensor expression.

        When no entry of ``self.type.shape`` is ``None``, the shape is built
        directly from those entries with ``as_tensor_variable`` as a 1-D
        ``int64`` tensor. Otherwise the result of ``at.shape(self)`` is
        returned.
        """
        static_shape = self.type.shape
        fully_known = all(dim is not None for dim in static_shape)
        if fully_known:
            return as_tensor_variable(static_shape, ndim=1, dtype=np.int64)

        return at.shape(self)
Esempio n. 3
0
 def shape(self):
     """Return the result of ``aet.shape`` applied to this object."""
     shape_expr = aet.shape(self)
     return shape_expr