Ejemplo n.º 1
0
    def grad(self, inputs, cost_grad):
        """
        Build the symbolic gradient of the transform with respect to ``a``.

        For the purpose of differentiation, the Finite Fourier Transform is
        treated as a complex-differentiable function of a complex variable.

        Parameters
        ----------
        inputs
            Sequence ``(a, n, axis)`` holding the transformed tensor, the
            transform length and the axis along which it is applied.
        cost_grad
            Sequence whose first item is the gradient of the cost with
            respect to the output of this op.

        Returns
        -------
        list
            ``[gradient wrt a, None, None]``.

        Raises
        ------
        NotImplementedError
            If ``axis`` is not a ``TensorConstant``.
        """
        a, n, axis = inputs[0], inputs[1], inputs[2]
        out_grad = cost_grad[0]
        if not isinstance(axis, TensorConstant):
            raise NotImplementedError(
                f"{self.__class__.__name__}: gradient is currently implemented"
                " only for axis being an Aesara constant")
        axis = int(axis.data)

        # The number of elements we differentiate against is fixed by ``a``
        # itself, whatever padding or truncation ``n`` implies ...
        positions = arange(0, shape(a)[axis], 1)
        # ... while the frequency range follows ``n`` and so covers padding.
        frequencies = arange(0, n, 1)
        phase = outer(frequencies, positions)
        twiddle = exp(((-2 * math.pi * 1j) * phase) / (1.0 * n))
        res = tensordot(out_grad, twiddle, (axis, 0))

        # Truncation (n smaller than the length of ``a`` along ``axis``) is
        # handled by zeroing the tail of the result.  Indexing the last
        # dimension directly, i.e.
        #     switch(lt(n, shape(a)[axis]),
        #            set_subtensor(res[..., n::], 0, False, False), res)
        # would be simpler but is not implemented by aesara, so the
        # dimensions are reversed, the leading one is sliced, and the
        # dimensions are reversed back.
        reversed_dims = list(np.arange(0, a.ndim)[::-1])
        flipped = res.dimshuffle(reversed_dims)
        is_truncated = lt(n, shape(a)[axis])
        zeroed_tail = set_subtensor(flipped[n::, ], 0, False, False)
        flipped = switch(is_truncated, zeroed_tail, flipped)
        res = flipped.dimshuffle(reversed_dims)

        # Move the last dimension back to position ``axis`` so that the
        # gradient has the same layout as the input.
        leading = list(np.arange(0, axis))
        trailing = list(np.arange(axis, a.ndim - 1))
        res = res.dimshuffle(*(leading + [a.ndim - 1] + trailing))
        return [res, None, None]
Ejemplo n.º 2
0
 def __get_expanded_dim(self, a, axis, i):
     """
     Return the index values of dimension ``i`` of ``a``, shaped so that
     they broadcast against ``a``.

     The result has ``a.ndim`` dimensions, all of length 1 except
     dimension ``i``, which has length ``a.shape[i]``.  This emulates one
     component of
     ``numpy.ogrid[0: a.shape[0], 0: a.shape[1], 0: a.shape[2]]``.

     ``axis`` is accepted but not used by this helper.
     """
     dims = [1 for _ in range(a.ndim)]
     dims[i] = a.shape[i]
     return arange(a.shape[i]).reshape(dims)
Ejemplo n.º 3
0
    def L_op(self, inputs, outputs, out_grads):
        """
        Build the gradients of the top-k op with respect to its inputs.

        Parameters
        ----------
        inputs
            The pair ``(x, k)``.
        outputs
            The outputs of the op; the last one is used as the index along
            ``self.axis`` (presumably the indices output when both values
            and indices are returned -- confirm against ``make_node``).
        out_grads
            Gradients of the cost with respect to the outputs; the first
            one is scattered back into a zero tensor shaped like ``x``.

        Returns
        -------
        list
            ``[x_grad, k_grad]``.  ``k_grad`` is always undefined, and
            ``x_grad`` is undefined unless the op returns both values and
            indices.
        """
        x, k = inputs
        k_grad = grad_undefined(self, 1, k, "topk: k is not differentiable")

        # The scatter below needs the gradient of the values output
        # (``out_grads[0]``) *and* the indices output (``outputs[-1]``), so
        # both outputs must be present.  The previous ``or`` only rejected
        # the case where neither was returned, contradicting the message.
        if not (self.return_indices and self.return_values):
            x_grad = grad_undefined(
                self,
                0,
                x,
                "topk: cannot get gradient without both indices and values",
            )
        else:
            x_shp = shape(x)
            z_grad = out_grads[0]
            ndim = x.ndim
            # Normalise a possibly negative axis.
            axis = self.axis % ndim
            # One index per dimension: along ``axis`` use the op's own
            # output, elsewhere an ``arange`` with ``ndim - i - 1`` trailing
            # broadcastable dimensions appended.
            grad_indices = [
                arange(x_shp[i]).dimshuffle([0] + ["x"] * (ndim - i - 1))
                if i != axis
                else outputs[-1]
                for i in range(ndim)
            ]
            x_grad = x.zeros_like(dtype=z_grad.dtype)
            x_grad = set_subtensor(x_grad[tuple(grad_indices)], z_grad)

        return [x_grad, k_grad]
Ejemplo n.º 4
0
def test_jax_Subtensors():
    """
    Compare JAX and Python results for basic, advanced and mixed indexing
    of a constant ``(3, 4, 5)`` tensor, checking which `Op` each index
    expression produces.
    """
    x_aet = aet.arange(3 * 4 * 5).reshape((3, 4, 5))

    everything = slice(None)
    cases = [
        # Basic indices
        ((1, 2, 0), aet_subtensor.Subtensor),
        ((slice(1, 2), 1, everything), aet_subtensor.Subtensor),
        # Advanced indexing
        ([1, 2], aet_subtensor.AdvancedSubtensor1),
        (([1, 2], [2, 3]), aet_subtensor.AdvancedSubtensor),
        # Advanced and basic indexing
        (([1, 2], everything), aet_subtensor.AdvancedSubtensor1),
        (([1, 2], everything, [3, 4]), aet_subtensor.AdvancedSubtensor),
    ]

    for index, expected_op in cases:
        out_aet = x_aet[index]
        assert isinstance(out_aet.owner.op, expected_op)
        compare_jax_and_py(FunctionGraph([], [out_aet]), [])
Ejemplo n.º 5
0
def test_jax_Subtensors_omni():
    """
    Compare JAX and Python results when a tensor is indexed with a boolean
    mask derived from the tensor itself.
    """
    tensor = aet.arange(3 * 4 * 5).reshape((3, 4, 5))

    # Boolean indices
    negative_mask = tensor < 0
    selected = tensor[negative_mask]
    assert isinstance(selected.owner.op, aet_subtensor.AdvancedSubtensor)

    compare_jax_and_py(FunctionGraph([], [selected]), [])
Ejemplo n.º 6
0
def test_arange_nonconcrete():
    """
    Compare JAX and Python results for ``arange`` whose argument is a
    symbolic scalar input rather than a constant.
    """
    stop = scalar("a")
    stop.tag.test_value = 10

    fgraph = FunctionGraph([stop], [aet.arange(stop)])
    test_inputs = [get_test_value(var) for var in fgraph.inputs]
    compare_jax_and_py(fgraph, test_inputs)
Ejemplo n.º 7
0
def test_known_grads():
    """
    Check that ``grad`` with ``known_grads`` reproduces the plain gradient.

    The gradient of ``cost`` with respect to the inputs is computed once
    directly.  Then, for each "layer" of intermediate variables, the
    gradients of ``cost`` with respect to that layer are fed back through
    ``known_grads`` (with ``cost=None``), and the resulting input gradients
    must match the direct ones.
    """
    full_range = aet.arange(10)
    x = scalar("x")
    t = iscalar("t")
    ft = full_range[t]
    ft.name = "ft"
    coeffs = vector("c")
    ct = coeffs[t]
    ct.name = "ct"
    p = x**ft
    p.name = "p"
    y = ct * p
    y.name = "y"
    cost = sqr(y)
    cost.name = "cost"

    # Successive cuts through the graph, from the cost down to the leaves.
    layers = [[cost], [y], [ct, p], [ct, x, ft], [coeffs, t, full_range, x]]

    inputs = [coeffs, t, x]

    rng = np.random.default_rng([2012, 11, 15])
    values = [
        rng.standard_normal((10)),
        rng.integers(10),
        rng.standard_normal()
    ]
    # ``np.cast`` was removed in NumPy 2.0; ``np.asarray`` with an explicit
    # dtype performs the same conversion to each input's dtype.
    values = [
        np.asarray(value, dtype=ipt.dtype)
        for ipt, value in zip(inputs, values)
    ]

    true_grads = grad(cost, inputs, disconnected_inputs="ignore")
    true_grads = aesara.function(inputs, true_grads)
    true_grads = true_grads(*values)

    for layer in layers:
        first = grad(cost, layer, disconnected_inputs="ignore")
        known = OrderedDict(zip(layer, first))
        full = grad(cost=None,
                    known_grads=known,
                    wrt=inputs,
                    disconnected_inputs="ignore")
        full = aesara.function(inputs, full)
        full = full(*values)
        assert len(true_grads) == len(full)
        for a, b, var in zip(true_grads, full, inputs):
            # Report which input and which layer disagreed on failure.
            assert np.allclose(a, b), (var, layer)
Ejemplo n.º 8
0
def to_one_hot(y, nb_class, dtype=None):
    """
    Build the one-hot encoding of every element of ``y``.

    Parameters
    ----------
    y
        A vector of integer values between 0 and ``nb_class - 1``.
    nb_class : int
        The number of classes in ``y``.
    dtype : data-type
        The dtype of the returned matrix. Default floatX.

    Returns
    -------
    object
        A matrix of shape ``(y.shape[0], nb_class)`` in which row ``i`` is
        the one-hot encoding of ``y[i]``.

    """
    n_rows = y.shape[0]
    one_hot = aet.zeros((n_rows, nb_class), dtype=dtype)
    # Pair every row number with its class label and set that entry to 1.
    row_ids = aet.arange(y.shape[0])
    return set_subtensor(one_hot[row_ids, y], 1)
Ejemplo n.º 9
0
def test_jax_IncSubtensor():
    """
    Compare JAX and Python results for ``set_subtensor`` and
    ``inc_subtensor`` graphs built with basic, advanced and boolean
    indices, checking which `Op` each expression produces.
    """
    rng = np.random.default_rng(213234)

    x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
    x_aet = aet.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)

    def check(out_aet, expected_op):
        # Verify the graph ends in the expected `Op`, then compare backends.
        assert isinstance(out_aet.owner.op, expected_op)
        out_fg = FunctionGraph([], [out_aet])
        compare_jax_and_py(out_fg, [])

    # "Set" basic indices
    st_aet = aet.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
    out_aet = aet_subtensor.set_subtensor(x_aet[1, 2, 3], st_aet)
    check(out_aet, aet_subtensor.IncSubtensor)

    st_aet = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
    out_aet = aet_subtensor.set_subtensor(x_aet[:2, 0, 0], st_aet)
    check(out_aet, aet_subtensor.IncSubtensor)

    out_aet = aet_subtensor.set_subtensor(x_aet[0, 1:3, 0], st_aet)
    check(out_aet, aet_subtensor.IncSubtensor)

    # "Set" advanced indices
    st_aet = aet.as_tensor_variable(
        rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX))
    out_aet = aet_subtensor.set_subtensor(x_aet[np.r_[0, 2]], st_aet)
    check(out_aet, aet_subtensor.AdvancedIncSubtensor1)

    st_aet = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
    out_aet = aet_subtensor.set_subtensor(x_aet[[0, 2], 0, 0], st_aet)
    check(out_aet, aet_subtensor.AdvancedIncSubtensor)

    st_aet = aet.as_tensor_variable(x_np[[0, 2], 0, :3])
    out_aet = aet_subtensor.set_subtensor(x_aet[[0, 2], 0, :3], st_aet)
    check(out_aet, aet_subtensor.AdvancedIncSubtensor)

    # "Set" boolean indices
    mask_aet = aet.as_tensor_variable(x_np) > 0
    out_aet = aet_subtensor.set_subtensor(x_aet[mask_aet], 0.0)
    check(out_aet, aet_subtensor.AdvancedIncSubtensor)

    # "Increment" basic indices
    st_aet = aet.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
    out_aet = aet_subtensor.inc_subtensor(x_aet[1, 2, 3], st_aet)
    check(out_aet, aet_subtensor.IncSubtensor)

    st_aet = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
    out_aet = aet_subtensor.inc_subtensor(x_aet[:2, 0, 0], st_aet)
    check(out_aet, aet_subtensor.IncSubtensor)

    # This case previously called ``set_subtensor``, merely repeating the
    # "Set" case above; the increment path is what this section covers.
    out_aet = aet_subtensor.inc_subtensor(x_aet[0, 1:3, 0], st_aet)
    check(out_aet, aet_subtensor.IncSubtensor)

    # "Increment" advanced indices
    st_aet = aet.as_tensor_variable(
        rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX))
    out_aet = aet_subtensor.inc_subtensor(x_aet[np.r_[0, 2]], st_aet)
    check(out_aet, aet_subtensor.AdvancedIncSubtensor1)

    st_aet = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
    out_aet = aet_subtensor.inc_subtensor(x_aet[[0, 2], 0, 0], st_aet)
    check(out_aet, aet_subtensor.AdvancedIncSubtensor)

    st_aet = aet.as_tensor_variable(x_np[[0, 2], 0, :3])
    out_aet = aet_subtensor.inc_subtensor(x_aet[[0, 2], 0, :3], st_aet)
    check(out_aet, aet_subtensor.AdvancedIncSubtensor)

    # "Increment" boolean indices
    # (previously ``set_subtensor``, so the masked increment was untested)
    mask_aet = aet.as_tensor_variable(x_np) > 0
    out_aet = aet_subtensor.inc_subtensor(x_aet[mask_aet], 1.0)
    check(out_aet, aet_subtensor.AdvancedIncSubtensor)
Ejemplo n.º 10
0
    def grad(self, inp, grads):
        """
        Build the gradient of the neighbourhood extraction with respect to
        its image input.

        Parameters
        ----------
        inp
            The triple ``(x, neib_shape, neib_step)``.
        grads
            One-element sequence holding the gradient with respect to the
            output.

        Returns
        -------
        list
            Three entries.  The gradients for ``neib_shape`` and
            ``neib_step`` are always undefined.  The gradient for ``x`` is
            built with ``neibs2images`` when the step equals the shape in
            the "valid" and "ignore_borders" modes, with a scan over the
            neighbourhood positions in "valid" mode otherwise, and is
            reported as not implemented in every remaining case.
        """
        x, neib_shape, neib_step = inp
        (gz, ) = grads

        if self.mode in ("valid", "ignore_borders"):
            # ``==`` on Aesara constants does not compare the data; the
            # ``equals`` method does, so try it last.
            step_equals_shape = (
                neib_shape is neib_step
                or neib_shape == neib_step
                or (hasattr(neib_shape, "equals")
                    and neib_shape.equals(neib_step))
            )
            if step_equals_shape:
                return [
                    neibs2images(gz, neib_shape, x.shape, mode=self.mode),
                    grad_undefined(self, 1, neib_shape),
                    grad_undefined(self, 2, neib_step),
                ]

        if self.mode != "valid":
            return [
                grad_not_implemented(self, 0, x),
                grad_undefined(self, 1, neib_shape),
                grad_undefined(self, 2, neib_step),
            ]

        # "valid" mode with a step different from the shape: walk over the
        # positions inside a neighbourhood and accumulate each contribution.
        def pos2map(pidx, pgz, prior_result, neib_shape, neib_step):
            """
            Add the gradient contribution of one neighbourhood position.

            pidx = Index of position within neighborhood.
            pgz  = Gradient of shape (batch_size*num_channels*neibs)
            prior_result  = Shape (batch_size, num_channels, rows, cols)
            neib_shape = Number of rows, cols in a neighborhood.
            neib_step  = Step sizes from image2neibs.
            """
            nrows, ncols = neib_shape
            rstep, cstep = neib_step
            batch_size, num_channels, rows, cols = prior_result.shape
            row_off = pidx // ncols
            col_off = pidx - (row_off * ncols)
            # In valid mode a given position does not reach every pixel.
            touched = prior_result[
                :,
                :,
                row_off:(rows - nrows + row_off + 1):rstep,
                col_off:(cols - ncols + col_off + 1):cstep,
            ]
            patch_shape = (
                batch_size,
                num_channels,
                (rows - nrows) // rstep + 1,
                (cols - ncols) // cstep + 1,
            )
            return inc_subtensor(touched, pgz.reshape(patch_shape))

        positions = arange(neib_shape[0] * neib_shape[1])
        per_position_grads = gz.dimshuffle((1, 0))
        steps, _updates = aesara.scan(
            fn=pos2map,
            sequences=[positions, per_position_grads],
            outputs_info=zeros(x.shape),
            non_sequences=[neib_shape, neib_step],
        )
        # The final scan step holds the fully accumulated gradient.
        return [
            steps[-1],
            grad_undefined(self, 1, neib_shape),
            grad_undefined(self, 2, neib_step),
        ]