Example 1
def irfft(inp, norm=None, is_odd=False):
    r"""
    Performs the inverse fast Fourier Transform with real-valued output.

    The input is a variable of dimensions (m, ..., n//2+1, 2)
    representing the non-trivial elements of m real-valued Fourier transforms
    of initial size (..., n). The real and imaginary parts are stored as a
    pair of float arrays.

    The output is a real-valued variable of dimensions (m, ..., n)
    giving the m inverse FFTs.

    Parameters
    ----------
    inp
        Array of size (m, ..., n//2+1, 2), containing m inputs
        with n//2+1 non-trivial elements on the last dimension and real
        and imaginary parts stored as separate real arrays.
    norm : {None, 'ortho', 'no_norm'}
        Normalization of transform. Following numpy, default *None* normalizes
        only the inverse transform by n, 'ortho' yields the unitary transform
        (:math:`1/\sqrt n` forward and inverse). In addition, 'no_norm' leaves
        the transform unnormalized.
    is_odd : {True, False}
        Set to True to get a real inverse transform output with an odd last dimension
        of length (N-1)*2 + 1 for an input last dimension of length N.

    """

    if is_odd not in (True, False):
        raise ValueError(
            f"Invalid value {is_odd} for id_odd, must be True or False")

    s = inp.shape[1:-1]
    if is_odd:
        s = set_subtensor(s[-1], (s[-1] - 1) * 2 + 1)
    else:
        s = set_subtensor(s[-1], (s[-1] - 1) * 2)

    cond_norm = _unitary(norm)
    scaling = 1
    # Numpy's default normalization is 1/N on the inverse transform.
    if cond_norm is None:
        scaling = s.prod().astype(inp.dtype)
    elif cond_norm == "ortho":
        scaling = sqrt(s.prod().astype(inp.dtype))

    return irfft_op(inp, s) / scaling
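A minimal round-trip sketch of how the function above would be used, assuming `rfft` and `irfft` can be imported from `aesara.tensor.fft` and that the input has an even last dimension; the names and shapes below are illustrative, not part of the original example.

import numpy as np
import aesara
import aesara.tensor as aet
from aesara.tensor import fft

x = aet.matrix("x", dtype="float64")
spectrum = fft.rfft(x, norm="ortho")       # shape (m, n//2 + 1, 2)
x_rec = fft.irfft(spectrum, norm="ortho")  # shape (m, n), n assumed even

f = aesara.function([x], x_rec)
x_val = np.random.rand(4, 16)
assert np.allclose(f(x_val), x_val)        # 'ortho' makes the forward/inverse pair unitary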
Example 2
    def make_node(self, inp, s=None):
        # A shape parameter is expected as an input. For now this is used to
        # manage odd transform sizes.
        # Later this could be extended to handle padding and truncation,
        # following numpy's interface. However, cuFFT expects arrays that match
        # the shape given to the plan, so padding will have to be done in the op.
        # The effect of padding on gradients has yet to be investigated.

        if not skcuda_available:
            raise RuntimeError("skcuda is needed for CuIFFTOp")

        if not pygpu_available:
            raise RuntimeError("pygpu is needed for CuIFFTOp")

        if not pycuda_available:
            raise RuntimeError("pycuda is needed for CuIFFTOp")

        inp = gpu_contiguous(as_gpuarray_variable(inp, infer_context_name(inp)))

        # If no shape is provided as input, calculate shape assuming even real transform.
        if s is None:
            s = inp.shape[1:-1]
            s = set_subtensor(s[-1], (s[-1] - 1) * 2)
        s = as_tensor_variable(s)

        assert inp.dtype == "float32"
        assert s.ndim == 1

        return Apply(self, [inp, s], [self.output_type(inp)()])
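The default-shape branch above can be sketched on its own with ordinary CPU tensors: for an input of shape (m, ..., n//2 + 1, 2), the last transform axis defaults to the even length (n//2 + 1 - 1) * 2 = n. A small illustration, assuming only core Aesara (no GPU dependencies); the shapes are illustrative.

import numpy as np
import aesara
import aesara.tensor as aet
from aesara.tensor.subtensor import set_subtensor

inp = aet.tensor3("inp", dtype="float32")  # (m, n//2 + 1, 2)
s = inp.shape[1:-1]
s = set_subtensor(s[-1], (s[-1] - 1) * 2)  # recover n, assuming an even transform

f = aesara.function([inp], s)
print(f(np.zeros((4, 9, 2), dtype="float32")))  # -> [16]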
Example 3
    def L_op(self, inputs, outputs, out_grads):
        x, k = inputs
        k_grad = grad_undefined(self, 1, k, "topk: k is not differentiable")

        if not (self.return_indices or self.return_values):
            x_grad = grad_undefined(
                self,
                0,
                x,
                "topk: cannot get gradient" " without both indices and values",
            )
        else:
            x_shp = shape(x)
            z_grad = out_grads[0]
            ndim = x.ndim
            axis = self.axis % ndim
            grad_indices = [
                arange(x_shp[i]).dimshuffle([0] + ["x"] * (ndim - i - 1))
                if i != axis
                else outputs[-1]
                for i in range(ndim)
            ]
            x_grad = x.zeros_like(dtype=z_grad.dtype)
            x_grad = set_subtensor(x_grad[tuple(grad_indices)], z_grad)

        return [x_grad, k_grad]
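A NumPy sketch of the scatter pattern built above, with hypothetical values: `rows` plays the role of the `arange(...).dimshuffle(...)` index grids, and the fancy assignment plays the role of `set_subtensor`.

import numpy as np

x = np.array([[3.0, 1.0, 4.0], [1.0, 5.0, 9.0]])
k, axis = 2, 1
idx = np.argsort(x, axis=axis)[:, -k:]     # indices of the top-k entries, as topk would return
z_grad = np.ones_like(idx, dtype=x.dtype)  # incoming gradient on the top-k values
rows = np.arange(x.shape[0])[:, None]      # index grid for the non-topk axis
x_grad = np.zeros_like(x)
x_grad[rows, idx] = z_grad                 # scatter the gradient back to the input positions
print(x_grad)
# [[1. 0. 1.]
#  [0. 1. 1.]]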
Example 4
 def grad(self, inputs, output_grads):
     (gout, ) = output_grads
     s = inputs[1]
     # Halve the last dimension of the output gradients: those frequencies are
     # double-counted by the real IFFT due to symmetry, except the first and
     # last elements (for even transforms), which are unique.
     idx = ([slice(None)] * (gout.ndim - 2) +
            [slice(1, (s[-1] // 2) + (s[-1] % 2))] + [slice(None)])
     gout = set_subtensor(gout[idx], gout[idx] * 0.5)
     return [irfft_op(gout, s), DisconnectedType()()]
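A quick check of the slice built above, for a hypothetical even transform length s[-1] = 8: only frequency bins 1..3 are halved, leaving the DC bin (0) and the Nyquist bin (4) untouched.

n = 8                                       # last element of s
freq_slice = slice(1, (n // 2) + (n % 2))   # slice(1, 4)
print(list(range(n // 2 + 1))[freq_slice])  # -> [1, 2, 3]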
Example 5
 def grad(self, inputs, output_grads):
     (gout, ) = output_grads
     s = inputs[1]
     gf = rfft_op(gout, s)
     # Double the last dimension of the gradient: those entries represent both
     # positive and negative frequencies, except the first and last elements
     # (for even transforms), which are unique.
     idx = ([slice(None)] * (gf.ndim - 2) +
            [slice(1, (s[-1] // 2) + (s[-1] % 2))] + [slice(None)])
     gf = set_subtensor(gf[idx], gf[idx] * 2)
     return [gf, DisconnectedType()()]
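A small NumPy check of the Hermitian symmetry that both gradients above compensate for: for a real signal of even length n, every rfft bin except the first and last corresponds to a conjugate pair of full-FFT bins.

import numpy as np

x = np.random.rand(8)
full = np.fft.fft(x)    # 8 complex bins
half = np.fft.rfft(x)   # 5 complex bins (DC .. Nyquist)

assert np.allclose(half, full[:5])
# Bins 1..3 each stand in for themselves *and* their conjugate mirror image.
assert np.allclose(full[5:], np.conj(full[1:4])[::-1])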
Example 6
    def grad(self, inputs, cost_grad):
        """
        In defining the gradient, the Finite Fourier Transform is viewed as
        a complex-differentiable function of a complex variable
        """
        a = inputs[0]
        n = inputs[1]
        axis = inputs[2]
        grad = cost_grad[0]
        if not isinstance(axis, TensorConstant):
            raise NotImplementedError(
                f"{self.__class__.__name__}: gradient is currently implemented"
                " only for axis being an Aesara constant")
        axis = int(axis.data)
        # Note that the number of actual elements in the variable we
        # differentiate with respect to (`a`) is independent of possible
        # padding or truncation:
        elem = arange(0, shape(a)[axis], 1)
        # accounts for padding:
        freq = arange(0, n, 1)
        outer_res = outer(freq, elem)
        pow_outer = exp(((-2 * math.pi * 1j) * outer_res) / (1.0 * n))
        res = tensordot(grad, pow_outer, (axis, 0))

        # This would be simpler, but it is not implemented by Aesara:
        # res = switch(lt(n, shape(a)[axis]),
        # set_subtensor(res[...,n::], 0, False, False), res)

        # Instead, we resort to the following to account for truncation:
        flip_shape = list(np.arange(0, a.ndim)[::-1])
        res = res.dimshuffle(flip_shape)
        res = switch(
            lt(n,
               shape(a)[axis]),
            set_subtensor(
                res[n::, ],
                0,
                False,
                False,
            ),
            res,
        )
        res = res.dimshuffle(flip_shape)

        # ensures that the gradient shape conforms to the input shape:
        out_shape = (list(np.arange(0, axis)) + [a.ndim - 1] +
                     list(np.arange(axis, a.ndim - 1)))
        res = res.dimshuffle(*out_shape)
        return [res, None, None]
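A NumPy sketch of the matrix assembled as `pow_outer` above: its entries are exp(-2*pi*i*f*e/n), which for n == shape(a)[axis] (no padding or truncation) is exactly the DFT matrix that `np.fft.fft` applies to a length-n signal. The values below are illustrative.

import math
import numpy as np

n = 4
a = np.random.rand(n)
elem = np.arange(n)
freq = np.arange(n)
pow_outer = np.exp((-2 * math.pi * 1j) * np.outer(freq, elem) / n)

assert np.allclose(pow_outer @ a, np.fft.fft(a))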
Example 7
def expand_empty(tensor_var, size):
    """
    Transforms the shape of a tensor from (d1, d2, ...) to (d1 + size, d2, ...)
    by adding uninitialized memory at the end of the tensor.

    """

    if size == 0:
        return tensor_var
    shapes = [tensor_var.shape[x] for x in range(tensor_var.ndim)]
    new_shape = [size + shapes[0]] + shapes[1:]
    empty = AllocEmpty(tensor_var.dtype)(*new_shape)

    ret = set_subtensor(empty[:shapes[0]], tensor_var)
    ret.tag.nan_guard_mode_check = False
    return ret
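A NumPy analogue of the pattern above, as a sketch only: allocate `size` extra uninitialized rows on the leading axis and copy the original data into the front (the assignment plays the role of `set_subtensor`).

import numpy as np

def expand_empty_np(arr, size):
    if size == 0:
        return arr
    out = np.empty((arr.shape[0] + size,) + arr.shape[1:], dtype=arr.dtype)
    out[: arr.shape[0]] = arr   # the set_subtensor step
    return out                  # trailing rows hold uninitialized memory

print(expand_empty_np(np.arange(6).reshape(3, 2), 2).shape)  # -> (5, 2)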
Example 8
    def make_node(self, a, s=None):
        a = as_tensor_variable(a)
        if a.ndim < 3:
            raise TypeError(
                f"{self.__class__.__name__}: input must have dimension >= 3,  with "
                + "first dimension batches and last real/imag parts")

        if s is None:
            s = a.shape[1:-1]
            s = set_subtensor(s[-1], (s[-1] - 1) * 2)
            s = as_tensor_variable(s)
        else:
            s = as_tensor_variable(s)
            if s.dtype not in integer_dtypes:
                raise TypeError("%s: length of the transformed axis must be"
                                " of type integer" % self.__class__.__name__)
        return Apply(self, [a, s], [self.output_type(a)()])
Example 9
def to_one_hot(y, nb_class, dtype=None):
    """
    Return a matrix where each row corresponds to the one-hot
    encoding of the corresponding element of y.

    Parameters
    ----------
    y
        A vector of integer values between 0 and nb_class - 1.
    nb_class : int
        The number of classes in y.
    dtype : data-type
        The dtype of the returned matrix. Default floatX.

    Returns
    -------
    object
        A matrix of shape (y.shape[0], nb_class), where each row ``i`` is
        the one-hot encoding of the corresponding ``y[i]`` value.

    """
    ret = aet.zeros((y.shape[0], nb_class), dtype=dtype)
    ret = set_subtensor(ret[aet.arange(y.shape[0]), y], 1)
    return ret
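A minimal usage sketch, assuming `to_one_hot` can be imported from `aesara.tensor.extra_ops`; the values below are illustrative.

import numpy as np
import aesara
import aesara.tensor as aet
from aesara.tensor.extra_ops import to_one_hot

y = aet.lvector("y")
onehot = to_one_hot(y, 4)

f = aesara.function([y], onehot)
print(f(np.array([0, 2, 3])))
# [[1. 0. 0. 0.]
#  [0. 0. 1. 0.]
#  [0. 0. 0. 1.]]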
Example 10
def scan_checkpoints(
    fn,
    sequences=None,
    outputs_info=None,
    non_sequences=None,
    name="checkpointscan_fn",
    n_steps=None,
    save_every_N=10,
    padding=True,
):
    """Scan function that uses less memory, but is more restrictive.

    In :func:`~aesara.scan`, if you compute the gradient of the output
    with respect to the input, you will have to store the intermediate
    results at each time step, which can be prohibitively large. This
    function allows you to perform ``save_every_N`` steps of forward computation
    without storing the intermediate results, recomputing them instead during
    the gradient computation.

    Notes
    -----
    Current assumptions:

    * Every sequence has the same length.
    * If ``n_steps`` is specified, it has the same value as the length of
      any sequence.
    * The value of ``save_every_N`` divides the number of steps the scan
      will run without remainder.
    * Only singly-recurrent and non-recurrent outputs are used.
      No multiple recurrences.
    * Only the last timestep of any output will ever be used.

    Parameters
    ----------
    fn
        ``fn`` is a function that describes the operations involved in one
        step of ``scan``. See the documentation of :func:`~aesara.scan`
        for more information.

    sequences
        ``sequences`` is the list of Aesara variables or dictionaries
        describing the sequences ``scan`` has to iterate over. All
        sequences must be the same length in this version of ``scan``.

    outputs_info
        ``outputs_info`` is the list of Aesara variables or dictionaries
        describing the initial state of the outputs computed
        recurrently.

    non_sequences
        ``non_sequences`` is the list of arguments that are passed to
        ``fn`` at each step. One can opt to exclude variables
        used in ``fn`` from this list as long as they are part of the
        computational graph, though for clarity we encourage not to do so.

    n_steps
        ``n_steps`` is the number of steps to iterate, given as an int
        or Aesara scalar (> 0). If any of the input sequences does not have
        enough elements, scan will raise an error. If ``n_steps`` is not
        provided, ``scan`` will figure out the number of steps it should run
        given its input sequences.

    save_every_N
        ``save_every_N`` is the number of steps to go without storing
        the computations of ``scan`` (i.e., they will have to be recomputed
        during the gradient computation).

    padding
        If the length of the sequences is not a multiple of ``save_every_N``,
        the sequences will be zero-padded to make this version of ``scan``
        work properly, which also results in a memory copy. The copy can be
        avoided by setting ``padding`` to False, but then you need to make
        sure the length of the sequences is a multiple of ``save_every_N``.

    Returns
    -------
    tuple
        Tuple of the form ``(outputs, updates)`` as in :func:`~aesara.scan`, but
        with a small change: it only contains the output at every
        ``save_every_N``-th step. The time steps that are not returned by
        this function will be recomputed during the gradient computation
        (if any).

    See Also
    --------
    :func:`~aesara.scan`: Looping in Aesara.

    """
    # Standardize the format of input arguments
    if sequences is None:
        sequences = []
    elif not isinstance(sequences, list):
        sequences = [sequences]

    if not isinstance(outputs_info, list):
        outputs_info = [outputs_info]

    if non_sequences is None:
        non_sequences = []
    elif not isinstance(non_sequences, list):
        non_sequences = [non_sequences]

    # Check that outputs_info has no taps:
    for element in outputs_info:
        if isinstance(element, dict) and "taps" in element:
            raise RuntimeError("scan_checkpoints doesn't work with taps.")

    # Determine how many steps the original scan would run
    if n_steps is None:
        n_steps = sequences[0].shape[0]

    # Compute the number of steps of the outer scan
    o_n_steps = at.cast(ceil(n_steps / save_every_N), "int64")

    # Compute the number of steps of the inner scan
    i_n_steps = save_every_N * at.ones((o_n_steps, ), "int64")
    mod = n_steps % save_every_N
    last_n_steps = at.switch(eq(mod, 0), save_every_N, mod)
    i_n_steps = set_subtensor(i_n_steps[-1], last_n_steps)

    # Pad the sequences if needed
    if padding:
        # Since the padding could be an empty tensor, use Join(view=0) so
        # that the join can return a view of s.
        join = Join(view=0)
        for i, s in enumerate(sequences):
            # Pad with enough zeros to make the length a multiple of save_every_N.
            n = (save_every_N - s.shape[0] % save_every_N) % save_every_N
            z = at.zeros((n, *[s.shape[j] for j in range(1, s.ndim)]), dtype=s.dtype)
            sequences[i] = join(0, [s, z])

    # Establish the input variables of the outer scan
    o_sequences = [
        s.reshape(
            [s.shape[0] // save_every_N, save_every_N] +
            [s.shape[i] for i in range(1, s.ndim)],
            s.ndim + 1,
        ) for s in sequences
    ]
    o_sequences.append(i_n_steps)
    new_nitsots = [i for i in outputs_info if i is None]
    o_nonsequences = non_sequences

    def outer_step(*args):
        # Separate the received arguments into their respective (seq, outputs
        # from previous iterations, nonseqs) categories
        i_sequences = list(args[:len(o_sequences)])
        i_prev_outputs = list(args[len(o_sequences):-len(o_nonsequences)])
        i_non_sequences = list(args[-len(o_nonsequences):])
        i_outputs_infos = i_prev_outputs + [
            None,
        ] * len(new_nitsots)

        # Call the user-provided function with the proper arguments
        results, updates = scan(
            fn=fn,
            sequences=i_sequences[:-1],
            outputs_info=i_outputs_infos,
            non_sequences=i_non_sequences,
            name=name + "_inner",
            n_steps=i_sequences[-1],
        )
        if not isinstance(results, list):
            results = [results]

        # Keep only the last timestep of every output but keep all the updates
        return [r[-1] for r in results], updates

    results, updates = scan(
        fn=outer_step,
        sequences=o_sequences,
        outputs_info=outputs_info,
        non_sequences=o_nonsequences,
        name=name + "_outer",
        n_steps=o_n_steps,
        allow_gc=True,
    )

    return results, updates
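A hedged usage sketch of `scan_checkpoints`, assuming it can be imported from `aesara.scan.checkpoints`; the step function and variable names are illustrative, `padding=False` is used so the sequence length must be a multiple of `save_every_N`, and only the last output state is meaningful (see the assumptions above).

import aesara
import aesara.tensor as at
from aesara.scan.checkpoints import scan_checkpoints

seq = at.matrix("seq")    # (n_steps, n_features), length assumed a multiple of 10
h0 = at.vector("h0")      # initial state, same length as a feature row

def step(x_t, h_tm1):
    # One cumulative-sum step; any single-state recurrence would do here.
    return x_t + h_tm1

outputs, updates = scan_checkpoints(
    fn=step,
    sequences=[seq],
    outputs_info=[h0],
    save_every_N=10,
    padding=False,
)
h_last = outputs[-1]      # only the final state is intended to be used
f = aesara.function([seq, h0], h_last, updates=updates)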
Example 11
def test_jax_IncSubtensor():
    rng = np.random.default_rng(213234)

    x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
    x_aet = aet.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)

    # "Set" basic indices
    st_aet = aet.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
    out_aet = aet_subtensor.set_subtensor(x_aet[1, 2, 3], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.IncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    st_aet = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
    out_aet = aet_subtensor.set_subtensor(x_aet[:2, 0, 0], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.IncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    out_aet = aet_subtensor.set_subtensor(x_aet[0, 1:3, 0], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.IncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    # "Set" advanced indices
    st_aet = aet.as_tensor_variable(
        rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX))
    out_aet = aet_subtensor.set_subtensor(x_aet[np.r_[0, 2]], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor1)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    st_aet = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
    out_aet = aet_subtensor.set_subtensor(x_aet[[0, 2], 0, 0], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    st_aet = aet.as_tensor_variable(x_np[[0, 2], 0, :3])
    out_aet = aet_subtensor.set_subtensor(x_aet[[0, 2], 0, :3], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    # "Set" boolean indices
    mask_aet = aet.as_tensor_variable(x_np) > 0
    out_aet = aet_subtensor.set_subtensor(x_aet[mask_aet], 0.0)
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    # "Increment" basic indices
    st_aet = aet.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
    out_aet = aet_subtensor.inc_subtensor(x_aet[1, 2, 3], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.IncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    st_aet = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
    out_aet = aet_subtensor.inc_subtensor(x_aet[:2, 0, 0], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.IncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    out_aet = aet_subtensor.inc_subtensor(x_aet[0, 1:3, 0], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.IncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    # "Increment" advanced indices
    st_aet = aet.as_tensor_variable(
        rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX))
    out_aet = aet_subtensor.inc_subtensor(x_aet[np.r_[0, 2]], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor1)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    st_aet = aet.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
    out_aet = aet_subtensor.inc_subtensor(x_aet[[0, 2], 0, 0], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    st_aet = aet.as_tensor_variable(x_np[[0, 2], 0, :3])
    out_aet = aet_subtensor.inc_subtensor(x_aet[[0, 2], 0, :3], st_aet)
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])

    # "Increment" boolean indices
    mask_aet = aet.as_tensor_variable(x_np) > 0
    out_aet = aet_subtensor.inc_subtensor(x_aet[mask_aet], 1.0)
    assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor)
    out_fg = FunctionGraph([], [out_aet])
    compare_jax_and_py(out_fg, [])
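For reference, a small sketch of the distinction the test exercises: `set_subtensor` overwrites the indexed entries while `inc_subtensor` adds to them. Standard Aesara imports are assumed and the values are illustrative.

import numpy as np
import aesara
import aesara.tensor as aet
from aesara.tensor import subtensor as aet_subtensor

x = aet.vector("x")
set_out = aet_subtensor.set_subtensor(x[:2], 5.0)
inc_out = aet_subtensor.inc_subtensor(x[:2], 5.0)

f = aesara.function([x], [set_out, inc_out])
x_val = np.array([1.0, 2.0, 3.0], dtype=aesara.config.floatX)
print(f(x_val))  # set -> [5., 5., 3.], inc -> [6., 7., 3.]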
Example 12
def test_jax_basic():
    rng = np.random.default_rng(28494)

    x = matrix("x")
    y = matrix("y")
    b = vector("b")

    # `ScalarOp`
    z = cosh(x**2 + y / 3.0)

    # `[Inc]Subtensor`
    out = aet_subtensor.set_subtensor(z[0], -10.0)
    out = aet_subtensor.inc_subtensor(out[0, 1], 2.0)
    out = out[:5, :3]

    out_fg = FunctionGraph([x, y], [out])

    test_input_vals = [
        np.tile(np.arange(10), (10, 1)).astype(config.floatX),
        np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX),
    ]
    (jax_res, ) = compare_jax_and_py(out_fg, test_input_vals)

    # Confirm that the `Subtensor` slice operations are correct
    assert jax_res.shape == (5, 3)

    # Confirm that the `IncSubtensor` operations are correct
    assert jax_res[0, 0] == -10.0
    assert jax_res[0, 1] == -8.0

    out = clip(x, y, 5)
    out_fg = FunctionGraph([x, y], [out])
    compare_jax_and_py(out_fg, test_input_vals)

    out = aet.diagonal(x, 0)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(
        out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)])

    out = aet_slinalg.cholesky(x)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(
        out_fg,
        [(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
            config.floatX)],
    )

    # not sure why this isn't working yet with lower=False
    out = aet_slinalg.Cholesky(lower=False)(x)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(
        out_fg,
        [(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
            config.floatX)],
    )

    out = aet_slinalg.solve(x, b)
    out_fg = FunctionGraph([x, b], [out])
    compare_jax_and_py(
        out_fg,
        [
            np.eye(10).astype(config.floatX),
            np.arange(10).astype(config.floatX),
        ],
    )

    out = aet.diag(b)
    out_fg = FunctionGraph([b], [out])
    compare_jax_and_py(out_fg, [np.arange(10).astype(config.floatX)])

    out = aet_nlinalg.det(x)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(
        out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)])

    out = aet_nlinalg.matrix_inverse(x)
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(
        out_fg,
        [(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
            config.floatX)],
    )
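A hedged sketch of what `compare_jax_and_py` is checking in the test above: the same graph compiled with Aesara's JAX backend (assuming the "JAX" mode is registered, as it is when `jax` is installed) should agree with the default backend. The graph below mirrors the `[Inc]Subtensor` part of the test.

import numpy as np
import aesara
import aesara.tensor as aet
from aesara.tensor import subtensor as aet_subtensor

x = aet.matrix("x")
out = aet_subtensor.set_subtensor(x[0], -10.0)[:5, :3]

f_jax = aesara.function([x], out, mode="JAX")
f_py = aesara.function([x], out)

x_val = np.arange(100).reshape((10, 10)).astype(aesara.config.floatX)
assert np.allclose(f_jax(x_val), f_py(x_val))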
Example 13
def neibs2images(neibs, neib_shape, original_shape, mode="valid"):
    """
    Function :func:`neibs2images <aesara.sandbox.neighbours.neibs2images>`
    performs the inverse operation of
    :func:`images2neibs <aesara.sandbox.neighbours.images2neibs>`. It takes
    the output of :func:`images2neibs <aesara.sandbox.neighbours.images2neibs>`
    and reconstructs its input.

    Parameters
    ----------
    neibs : 2d tensor
        Like the one obtained by
        :func:`images2neibs <aesara.sandbox.neighbours.images2neibs>`.
    neib_shape
        The `neib_shape` that was used in
        :func:`images2neibs <aesara.sandbox.neighbours.images2neibs>`.
    original_shape
        Original shape of the 4d tensor given to
        :func:`images2neibs <aesara.sandbox.neighbours.images2neibs>`.

    Returns
    -------
    object
        Reconstruction of the input of
        :func:`images2neibs <aesara.sandbox.neighbours.images2neibs>`,
        a 4d tensor of shape `original_shape`.

    Notes
    -----
    Currently, the function doesn't support tensors created with a
    `neib_step` different from the default value. This means that it may be
    impossible to compute the gradient of a variable obtained via
    :func:`images2neibs <aesara.sandbox.neighbours.images2neibs>` w.r.t.
    its inputs in this case, because this function uses
    :func:`images2neibs <aesara.sandbox.neighbours.images2neibs>` for the
    gradient computation.

    Examples
    --------
    This example uses a tensor obtained in the example for
    :func:`images2neibs <aesara.sandbox.neighbours.images2neibs>`:

    .. code-block:: python

        im_new = neibs2images(neibs, (5, 5), im_val.shape)
        # Aesara function definition
        inv_window = aesara.function([neibs], im_new)
        # Function application
        im_new_val = inv_window(neibs_val)

    .. note:: The code will output the initial image array.

    """
    neibs = as_tensor_variable(neibs)
    neib_shape = as_tensor_variable(neib_shape)
    original_shape = as_tensor_variable(original_shape)

    new_neib_shape = stack(
        [original_shape[-1] // neib_shape[1], neib_shape[1]])
    output_2d = images2neibs(neibs.dimshuffle("x", "x", 0, 1),
                             new_neib_shape,
                             mode=mode)

    if mode == "ignore_borders":
        # We use set_subtensor so that we can accept an `original_shape`
        # whose value we cannot infer, while still raising an error at
        # runtime when it does not have the right shape.
        valid_shape = original_shape
        valid_shape = set_subtensor(
            valid_shape[2], (valid_shape[2] // neib_shape[0]) * neib_shape[0])
        valid_shape = set_subtensor(
            valid_shape[3], (valid_shape[3] // neib_shape[1]) * neib_shape[1])
        output_4d = output_2d.reshape(valid_shape, ndim=4)
        # padding the borders with zeros
        for d in (2, 3):
            pad_shape = list(output_4d.shape)
            pad_shape[d] = original_shape[d] - valid_shape[d]
            output_4d = concatenate([output_4d, zeros(pad_shape)], axis=d)
    elif mode == "valid":
        # TODO: not all modes are handled by this code.
        # Add a check for the supported cases.
        output_4d = output_2d.reshape(original_shape, ndim=4)
    else:
        raise NotImplementedError(f"neibs2images do not support mode={mode}")

    return output_4d
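A hedged round-trip sketch, assuming `images2neibs` and `neibs2images` can be imported from `aesara.tensor.nnet.neighbours` and that the spatial dimensions of the image are multiples of the (5, 5) window, so the default 'valid' mode applies cleanly; the shapes and values are illustrative.

import numpy as np
import aesara
import aesara.tensor as aet
from aesara.tensor.nnet.neighbours import images2neibs, neibs2images

images = aet.tensor4("images")
neibs = images2neibs(images, neib_shape=(5, 5))
im_new = neibs2images(neibs, (5, 5), images.shape)

f = aesara.function([images], im_new)
im_val = np.arange(100, dtype=aesara.config.floatX).reshape((1, 1, 10, 10))
assert np.allclose(f(im_val), im_val)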