# Example #1
class Sparse(Transform):
    """A sparse matrix transformation between an input and output signal.

    .. versionadded:: 3.0.0

    Parameters
    ----------
    shape : tuple of int
        The full shape of the sparse matrix: ``(size_out, size_in)``.
    indices : array_like of int, optional
        An Nx2 array of integers indicating the (row,col) coordinates for the
        N non-zero elements in the matrix. Required unless ``init`` is a
        ``scipy.sparse.spmatrix``, in which case it must be omitted.
    init : `.Distribution` or array_like or ``scipy.sparse.spmatrix``, optional
        Either a complete ``scipy.sparse.spmatrix`` whose shape equals
        ``shape``, or the values of the N non-zero elements located by
        ``indices``: a scalar (used for every element), a length-N vector,
        or a Distribution from which the values are sampled.
    """

    shape = ShapeParam("shape", length=2, low=1)
    init = SparseInitParam("init")

    def __init__(self, shape, indices=None, init=1.0):
        super().__init__()

        self.shape = shape

        if scipy_sparse and isinstance(init, scipy_sparse.spmatrix):
            # A scipy matrix already encodes its own sparsity structure, so
            # separately given indices would be ambiguous. These are explicit
            # errors rather than asserts so they survive ``python -O``.
            if indices is not None:
                raise ValidationError(
                    "`indices` must not be specified when `init` is a "
                    "`scipy.sparse.spmatrix`",
                    attr="indices",
                )
            if init.shape != self.shape:
                raise ValidationError(
                    "Shape of `init` %s does not match `shape` %s"
                    % (init.shape, self.shape),
                    attr="init",
                )
            self.init = init
        elif indices is not None:
            self.init = SparseMatrix(indices, init, self.shape)
        else:
            raise ValidationError(
                "Either `init` must be a `scipy.sparse.spmatrix`, "
                "or `indices` must be specified.",
                attr="init",
            )

    @property
    def _argreprs(self):
        return ["shape=%r" % (self.shape,)]

    def sample(self, rng=np.random):
        """Return a concrete matrix for this transform.

        A ``scipy.sparse.spmatrix`` ``init`` is returned unchanged; a
        `.SparseMatrix` ``init`` is sampled with ``rng`` (which converts
        Distribution data into a fixed array).
        """
        if scipy_sparse and isinstance(self.init, scipy_sparse.spmatrix):
            return self.init
        return self.init.sample(rng=rng)

    @property
    def size_in(self):
        """Number of input dimensions (columns of the matrix)."""
        return self.shape[1]

    @property
    def size_out(self):
        """Number of output dimensions (rows of the matrix)."""
        return self.shape[0]
# Example #2
class Dense(Transform):
    """A dense matrix transformation between an input and output signal.

    .. versionadded:: 3.0.0

    Parameters
    ----------
    shape : tuple of int
        The shape of the dense matrix: ``(size_out, size_in)``.
    init : `.Distribution` or array_like, optional
        A Distribution used to initialize the transform matrix, or a concrete
        instantiation for the matrix. If the matrix is square we also allow a
        scalar (equivalent to ``np.eye(n) * init``) or a vector (equivalent to
        ``np.diag(init)``) to represent the matrix more compactly.
    """

    shape = ShapeParam("shape", length=2, low=1)
    init = DistOrArrayParam("init")

    def __init__(self, shape, init=1.0):
        super().__init__()

        self.shape = shape

        if is_array_like(init):
            init = np.asarray(init, dtype=rc.float_dtype)

            # Check that the shape of init is compatible with this transform.
            # Compare against the validated ``self.shape`` rather than the raw
            # ``shape`` argument: if the caller passed a list, the raw value
            # would never compare equal to the tuple ``init.shape``.
            expected_shape = None
            if self.shape[0] != self.shape[1]:
                # init must be 2D if transform is not square
                expected_shape = self.shape
            elif init.ndim == 1:
                # a vector on a square transform gives the diagonal
                expected_shape = (self.shape[0],)
            elif init.ndim >= 2:
                expected_shape = self.shape
            # a scalar on a square transform needs no shape check

            if expected_shape is not None and init.shape != expected_shape:
                raise ValidationError(
                    "Shape of initial value %s does not match expected "
                    "shape %s" % (init.shape, expected_shape),
                    attr="init",
                )

        self.init = init

    @property
    def _argreprs(self):
        return ["shape=%r" % (self.shape,)]

    def sample(self, rng=np.random):
        """Return a concrete value for ``init``.

        A Distribution is sampled into a ``shape``-sized array; any other
        ``init`` (scalar, vector or matrix) is returned as stored.
        """
        if isinstance(self.init, Distribution):
            return self.init.sample(*self.shape, rng=rng)

        return self.init

    @property
    def init_shape(self):
        """The shape of the initial value."""
        if isinstance(self.init, Distribution):
            return self.shape
        return self.init.shape

    @property
    def size_in(self):
        """Number of input dimensions (columns of the matrix)."""
        return self.shape[1]

    @property
    def size_out(self):
        """Number of output dimensions (rows of the matrix)."""
        return self.shape[0]
# Example #3
class Convolution(Transform):
    """An N-dimensional convolutional transform.

    The dimensionality of the convolution is determined by the input shape.

    .. versionadded:: 3.0.0

    Parameters
    ----------
    n_filters : int
        The number of convolutional filters to apply
    input_shape : tuple of int or `.ChannelShape`
        Shape of the input signal to the convolution; e.g.,
        ``(height, width, channels)`` for a 2D convolution with
        ``channels_last=True``.
    kernel_size : tuple of int, optional
        Size of the convolutional kernels (1 element for a 1D convolution,
        2 for a 2D convolution, etc.).
    strides : tuple of int, optional
        Stride of the convolution (1 element for a 1D convolution, 2 for
        a 2D convolution, etc.).
    padding : ``"same"`` or ``"valid"``, optional
        Padding method for input signal. "Valid" means no padding, and
        convolution will only be applied to the fully-overlapping areas of the
        input signal (meaning the output will be smaller). "Same" means that
        the input signal is zero-padded so that the output is the same shape
        as the input.
    channels_last : bool, optional
        If ``True`` (default), the channels are the last dimension in the input
        signal (e.g., a 28x28 image with 3 channels would have shape
        ``(28, 28, 3)``).  ``False`` means that channels are the first
        dimension (e.g., ``(3, 28, 28)``).
    init : `.Distribution` or `~numpy:numpy.ndarray`, optional
        A predefined kernel with shape
        ``kernel_size + (input_channels, n_filters)``, or a ``Distribution``
        that will be used to initialize the kernel.

    Notes
    -----
    As is typical in neural networks, this is technically correlation rather
    than convolution (because the kernel is not flipped).
    """

    n_filters = IntParam("n_filters", low=1)
    input_shape = ChannelShapeParam("input_shape", low=1)
    kernel_size = ShapeParam("kernel_size", low=1)
    strides = ShapeParam("strides", low=1)
    padding = EnumParam("padding", values=("same", "valid"))
    channels_last = BoolParam("channels_last")
    init = DistOrArrayParam("init")

    _param_init_order = ["channels_last", "input_shape"]

    def __init__(
        self,
        n_filters,
        input_shape,
        kernel_size=(3, 3),
        strides=(1, 1),
        padding="valid",
        channels_last=True,
        init=Uniform(-1, 1),
    ):
        super().__init__()

        self.n_filters = n_filters
        self.channels_last = channels_last  # must be set before input_shape
        self.input_shape = input_shape
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.init = init

        if len(self.kernel_size) != self.dimensions:
            raise ValidationError(
                "Kernel dimensions (%d) do not match input dimensions (%d)"
                % (len(self.kernel_size), self.dimensions),
                attr="kernel_size",
            )
        if len(self.strides) != self.dimensions:
            raise ValidationError(
                "Stride dimensions (%d) do not match input dimensions (%d)"
                % (len(self.strides), self.dimensions),
                attr="strides",
            )
        if not isinstance(self.init, Distribution):
            # np.shape works for any array_like; the raw ``init`` argument
            # (e.g. a nested list) may have no ``.shape`` attribute.
            init_shape = np.shape(self.init)
            if init_shape != self.kernel_shape:
                raise ValidationError(
                    "Kernel shape %s does not match expected shape %s"
                    % (init_shape, self.kernel_shape),
                    attr="init",
                )

    @property
    def _argreprs(self):
        argreprs = [
            "n_filters=%r" % (self.n_filters,),
            "input_shape=%s" % (self.input_shape.shape,),
        ]
        # only report arguments that differ from the constructor defaults
        if self.kernel_size != (3, 3):
            argreprs.append("kernel_size=%r" % (self.kernel_size,))
        if self.strides != (1, 1):
            argreprs.append("strides=%r" % (self.strides,))
        if self.padding != "valid":
            argreprs.append("padding=%r" % (self.padding,))
        if self.channels_last is not True:
            argreprs.append("channels_last=%r" % (self.channels_last,))
        return argreprs

    def sample(self, rng=np.random):
        """Return a concrete kernel of shape ``kernel_shape``."""
        if isinstance(self.init, Distribution):
            # we sample this way so that any variancescaling distribution based
            # on n/d is scaled appropriately
            kernel = [
                self.init.sample(self.input_shape.n_channels, self.n_filters, rng=rng)
                for _ in range(np.prod(self.kernel_size))
            ]
            kernel = np.reshape(kernel, self.kernel_shape)
        else:
            kernel = np.array(self.init, dtype=rc.float_dtype)
        return kernel

    @property
    def kernel_shape(self):
        """Full shape of kernel."""
        return self.kernel_size + (self.input_shape.n_channels, self.n_filters)

    @property
    def size_in(self):
        """Number of input elements (flattened)."""
        return self.input_shape.size

    @property
    def size_out(self):
        """Number of output elements (flattened)."""
        return self.output_shape.size

    @property
    def dimensions(self):
        """Dimensionality of convolution."""
        return self.input_shape.dimensions

    @property
    def output_shape(self):
        """Output shape after applying convolution to input."""
        output_shape = np.array(self.input_shape.spatial_shape, dtype=rc.float_dtype)
        if self.padding == "valid":
            output_shape -= self.kernel_size
            output_shape += 1
        output_shape /= self.strides
        output_shape = tuple(np.ceil(output_shape).astype(rc.int_dtype))
        output_shape = (
            output_shape + (self.n_filters,)
            if self.channels_last
            else (self.n_filters,) + output_shape
        )

        return ChannelShape(output_shape, channels_last=self.channels_last)
# Example #4
class SparseMatrix(FrozenObject):
    """Represents a sparse matrix.

    .. versionadded:: 3.0.0

    Parameters
    ----------
    indices : array_like of int
        An Nx2 array of integers indicating the (row,col) coordinates for the
        N non-zero elements in the matrix.
    data : array_like or `.Distribution`
        An Nx1 array defining the value of the nonzero elements in the matrix
        (corresponding to ``indices``), or a `.Distribution` that will be
        used to initialize the nonzero elements. A single value is broadcast
        to all N elements.
    shape : tuple of int
        Shape of the full matrix.
    """

    indices = NdarrayParam("indices", shape=("*", 2), dtype=np.int64)
    data = DistOrArrayParam("data", sample_shape=("*",))
    shape = ShapeParam("shape", length=2)

    def __init__(self, indices, data, shape):
        super().__init__()

        self.indices = indices
        self.shape = shape

        # if data is not a distribution
        if is_array_like(data):
            data = np.asarray(data)

            # convert scalars to vectors
            if data.size == 1:
                data = data.item() * np.ones(self.indices.shape[0], dtype=data.dtype)

            if data.ndim != 1 or data.shape[0] != self.indices.shape[0]:
                raise ValidationError(
                    "Must be a vector of the same length as `indices`",
                    attr="data",
                )

        self.data = data
        # caches filled lazily by allocate() and toarray()
        self._allocated = None
        self._dense = None

    @property
    def dtype(self):
        """Data type of the nonzero values."""
        return self.data.dtype

    @property
    def ndim(self):
        """Number of dimensions of the full matrix."""
        return len(self.shape)

    @property
    def size(self):
        """Number of stored (nonzero) elements."""
        return self.indices.shape[0]

    def allocate(self):
        """Return a `scipy.sparse.csr_matrix` or dense matrix equivalent.

        We mark this data as readonly to be consistent with how other
        data associated with signals are allocated. If this allocated
        data is to be modified, it should be copied first.

        The result is cached. If SciPy is unavailable, a warning is issued
        and a read-only dense array is returned instead.
        """

        if self._allocated is not None:
            return self._allocated

        if scipy_sparse is None:
            warnings.warn(
                "Sparse operations require Scipy, which is not "
                "installed. Using dense matrices instead."
            )
            allocated = self.toarray().view()
            allocated.setflags(write=False)
        else:
            allocated = scipy_sparse.csr_matrix(
                (self.data, self.indices.T), shape=self.shape
            )
            allocated.data.setflags(write=False)

        self._allocated = allocated
        return self._allocated

    def sample(self, rng=np.random):
        """Convert `.Distribution` data to fixed array.

        Parameters
        ----------
        rng : `.numpy.random.mtrand.RandomState`
            Random number generator that will be used when
            sampling distribution.

        Returns
        -------
        matrix : `.SparseMatrix`
            A new `.SparseMatrix` instance with `.Distribution` converted to
            array if ``self.data`` is a `.Distribution`, otherwise simply
            returns ``self``.
        """
        if isinstance(self.data, Distribution):
            return SparseMatrix(
                self.indices,
                self.data.sample(self.indices.shape[0], rng=rng),
                self.shape,
            )
        return self

    def toarray(self):
        """Return the dense matrix equivalent of this matrix.

        The result is cached and read-only. Values given for a repeated
        ``(row, col)`` coordinate are summed, which matches how
        ``scipy.sparse`` builds the matrix in `allocate`.
        """

        if self._dense is not None:
            return self._dense

        dense = np.zeros(self.shape, dtype=self.dtype)
        # np.add.at accumulates at repeated coordinates; plain fancy-index
        # assignment would keep only the last value given for a coordinate.
        np.add.at(dense, (self.indices[:, 0], self.indices[:, 1]), self.data)
        # Mark as readonly, if the user wants to modify they should copy first
        dense.setflags(write=False)
        self._dense = dense
        return self._dense
# Example #5
class TensorNode(Node):
    """
    Inserts TensorFlow code into a Nengo model.

    At least one of ``shape_in`` or ``pass_time`` must be given, so that the
    function always receives some input.

    Parameters
    ----------
    tensor_func : callable
        A function that maps node inputs to outputs
    shape_in : tuple of int
        Shape of TensorNode input signal (not including batch dimension).
    shape_out : tuple of int
        Shape of TensorNode output signal (not including batch dimension).
        If None, value will be inferred by calling ``tensor_func``.
    pass_time : bool
        If True, pass current simulation time to TensorNode function (in addition
        to the standard input).
    label : str (Default: None)
        A name for the node, used for debugging and visualization
    """

    tensor_func = TensorFuncParam("tensor_func")
    shape_in = ShapeParam("shape_in", default=None, low=1, optional=True)
    shape_out = ShapeParam("shape_out", default=None, low=1, optional=True)
    pass_time = BoolParam("pass_time", default=True)

    def __init__(
        self,
        tensor_func,
        shape_in=None,
        shape_out=None,
        pass_time=True,
        label=None,
    ):
        # pylint: disable=non-parent-init-called,super-init-not-called
        # note: we bypass the Node constructor, because we don't want to
        # perform validation on `output`
        NengoObject.__init__(self, label=label, seed=None)

        self.shape_in = shape_in
        self.shape_out = shape_out
        self.pass_time = pass_time

        if not (self.shape_in or self.pass_time):
            raise ValidationError(
                "Must specify either shape_in or pass_time", "TensorNode"
            )

        # set last, after the shapes it may be checked against are known
        self.tensor_func = tensor_func

    @property
    def output(self):
        """
        Ensures that nothing tries to evaluate the `output` attribute
        (indicating that something is trying to simulate this as a regular
        `nengo.Node` rather than a TensorNode).
        """

        def output_func(*_):
            raise SimulationError(
                "Cannot call TensorNode output function (this probably means "
                "you are trying to use a TensorNode inside a Simulator other "
                "than NengoDL)"
            )

        return output_func

    @property
    def size_in(self):
        """Number of input elements (flattened), as a plain ``int``."""

        return 0 if self.shape_in is None else int(np.prod(self.shape_in))

    @property
    def size_out(self):
        """Number of output elements (flattened), as a plain ``int``."""

        return 0 if self.shape_out is None else int(np.prod(self.shape_out))
# Example #6
class _ConvolutionBase(Transform):
    """Abstract base class for Convolution and ConvolutionTranspose.

    Holds the parameters shared by both transforms and validates them
    (dimension agreement, ``groups`` divisibility and kernel shape).
    """

    n_filters = IntParam("n_filters", low=1)
    input_shape = ChannelShapeParam("input_shape", low=1)
    kernel_size = ShapeParam("kernel_size", low=1)
    strides = ShapeParam("strides", low=1)
    padding = EnumParam("padding", values=("same", "valid"))
    channels_last = BoolParam("channels_last")
    init = DistOrArrayParam("init")
    groups = IntParam("groups", low=1)

    _param_init_order = ["channels_last", "input_shape"]

    def __init__(
        self,
        n_filters,
        input_shape,
        kernel_size=(3, 3),
        strides=(1, 1),
        padding="valid",
        channels_last=True,
        init=Uniform(-1, 1),
        groups=1,
    ):
        super().__init__()

        self.n_filters = n_filters
        self.channels_last = channels_last  # must be set before input_shape
        self.input_shape = input_shape
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.init = init
        self.groups = groups

        if len(self.kernel_size) != self.dimensions:
            raise ValidationError(
                f"Kernel dimensions ({len(self.kernel_size)}) does not match "
                f"input dimensions ({self.dimensions})",
                attr="kernel_size",
            )
        if len(self.strides) != self.dimensions:
            raise ValidationError(
                f"Stride dimensions ({len(self.strides)}) does not match "
                f"input dimensions ({self.dimensions})",
                attr="strides",
            )

        # Validate ``groups`` before the kernel shape: ``kernel_shape`` divides
        # the input channels by ``groups``, so an invalid ``groups`` would
        # otherwise surface as a confusing kernel-shape mismatch.
        in_channels = self.input_shape.n_channels
        if self.groups > in_channels:
            raise ValidationError(
                f"Groups ({self.groups}) cannot be greater than "
                f"the number of input channels ({in_channels})",
                attr="groups",
            )
        if in_channels % self.groups != 0 or self.n_filters % self.groups != 0:
            raise ValidationError(
                f"Both the number of input channels ({in_channels}) and filters "
                f"({self.n_filters}) must be evenly divisible by ``groups`` "
                f"({self.groups})",
                attr="groups",
            )

        if not isinstance(self.init, Distribution):
            # np.shape works for any array_like; the raw ``init`` argument
            # (e.g. a nested list) may have no ``.shape`` attribute.
            init_shape = np.shape(self.init)
            if init_shape != self.kernel_shape:
                raise ValidationError(
                    f"Kernel shape {init_shape} does not match "
                    f"expected shape {self.kernel_shape}",
                    attr="init",
                )

    @property
    def _argreprs(self):
        argreprs = [
            f"n_filters={self.n_filters!r}",
            f"input_shape={self.input_shape.shape}",
        ]
        # only report arguments that differ from the constructor defaults
        if self.kernel_size != (3, 3):
            argreprs.append(f"kernel_size={self.kernel_size!r}")
        if self.strides != (1, 1):
            argreprs.append(f"strides={self.strides!r}")
        if self.padding != "valid":
            argreprs.append(f"padding={self.padding!r}")
        if self.channels_last is not True:
            argreprs.append(f"channels_last={self.channels_last!r}")
        if self.groups != 1:
            argreprs.append(f"groups={self.groups!r}")
        return argreprs

    def sample(self, rng=np.random):
        """Return a concrete kernel of shape ``kernel_shape``."""
        if isinstance(self.init, Distribution):
            # we sample this way so that any variancescaling distribution based
            # on n/d is scaled appropriately
            kernel = [
                self.init.sample(
                    self.input_shape.n_channels // self.groups, self.n_filters, rng=rng
                )
                for _ in range(np.prod(self.kernel_size))
            ]
            kernel = np.reshape(kernel, self.kernel_shape)
        else:
            kernel = np.array(self.init, dtype=rc.float_dtype)
        return kernel

    @property
    def kernel_shape(self):
        """Full shape of kernel."""
        return self.kernel_size + (
            self.input_shape.n_channels // self.groups,
            self.n_filters,
        )

    @property
    def size_in(self):
        """Number of input elements (flattened)."""
        return self.input_shape.size

    @property
    def size_out(self):
        """Number of output elements (flattened)."""
        return self.output_shape.size

    @property
    def dimensions(self):
        """Dimensionality of convolution."""
        return self.input_shape.dimensions

    def _forward_shape(self, input_spatial_shape, n_filters):
        """Shape produced by a forward convolution of ``input_spatial_shape``."""
        output_shape = np.array(input_spatial_shape, dtype=rc.float_dtype)
        if self.padding == "valid":
            output_shape -= self.kernel_size
            output_shape += 1
        output_shape /= self.strides
        output_shape = tuple(np.ceil(output_shape).astype(rc.int_dtype))

        return ChannelShape.from_space_and_channels(
            output_shape, n_filters, channels_last=self.channels_last
        )
# Example #7
class PresentJitteredImages(Process):
    """Present images in sequence, each cropped at a randomly jittered offset.

    Each image is shown for ``presentation_time`` seconds. An
    ``output_shape`` window is cut out of it at an offset that takes a
    random walk, pulled back toward the middle of the available slack with
    time constant ``jitter_tau`` and clipped so the window stays inside the
    image.

    Parameters
    ----------
    images : ndarray (n_images, channels, height, width)
        The images to present.
    presentation_time : float
        Time (in seconds) that each image is presented for.
    output_shape : 2-tuple (height, width)
        Spatial shape of the output window. Must not exceed the spatial shape
        of ``images``.
    jitter_std : float, optional
        Scale (in pixels) of the random jitter. If None, a quarter of the
        slack (image size minus output size) in each dimension is used.
    jitter_tau : float, optional
        Time constant of the pull back toward the middle. If None,
        ``presentation_time`` is used.
    **kwargs
        Passed through to the `Process` constructor.
    """
    images = NdarrayParam('images', shape=('...', ))
    image_shape = ShapeParam('image_shape', length=3, low=1)
    output_shape = ShapeParam('output_shape', length=2, low=1)
    presentation_time = NumberParam('presentation_time', low=0, low_open=True)
    jitter_std = NumberParam('jitter_std', low=0, low_open=True, optional=True)
    jitter_tau = NumberParam('jitter_tau', low=0, low_open=True)

    def __init__(self,
                 images,
                 presentation_time,
                 output_shape,
                 jitter_std=None,
                 jitter_tau=None,
                 **kwargs):
        import scipy.ndimage  # noqa: F401
        # ^ required for simulation, so check it here

        self.images = images
        self.presentation_time = presentation_time
        # use the coerced param, so array_like ``images`` work too
        self.image_shape = self.images.shape[1:]
        self.output_shape = output_shape
        self.jitter_std = jitter_std
        self.jitter_tau = (presentation_time
                           if jitter_tau is None else jitter_tau)

        nc, nxi, nxj = self.image_shape
        nyi, nyj = self.output_shape
        if nyi > nxi or nyj > nxj:
            # the crop window must fit inside the image
            raise ValueError(
                "Output shape %s must not exceed image shape %s" %
                (tuple(self.output_shape), (nxi, nxj)))

        super(PresentJitteredImages,
              self).__init__(default_size_in=0,
                             default_size_out=nc * nyi * nyj,
                             **kwargs)

    def make_step(self, shape_in, shape_out, dt, rng):
        # scipy.ndimage.shift replaces the deprecated
        # scipy.ndimage.interpolation.shift alias
        import scipy.ndimage

        nc, nxi, nxj = self.image_shape
        nyi, nyj = self.output_shape
        ni, nj = nxi - nyi, nxj - nyj  # slack available for jitter
        nij = np.array([ni, nj])
        assert shape_in == (0, )
        assert shape_out == (nc * nyi * nyj, )

        if self.jitter_std is None:
            si, sj = ni / 4., nj / 4.
        else:
            si = sj = self.jitter_std

        tau = self.jitter_tau

        n = len(self.images)
        images = self.images.reshape((n, nc, nxi, nxj))
        presentation_time = float(self.presentation_time)

        cij = (nij - 1) / 2.
        dt7tau = dt / tau
        sigma2 = np.sqrt(2. * dt / tau) * np.array([si, sj])
        ij = cij.copy()

        def step_presentjitteredimages(t):
            # update jitter position: drift toward cij plus Gaussian noise
            ij0 = dt7tau * (cij - ij) + sigma2 * rng.normal(size=2)
            ij[:] = (ij + ij0).clip((0, 0), (ni, nj))

            # select image
            k = int((t - dt) / presentation_time + 1e-7)
            image = images[k % n]

            # interpolate jittered sub-image
            i, j = ij
            image = scipy.ndimage.shift(
                image, (0, ni - i, nj - j))[:, -nyi:, -nyj:]

            return image.ravel()

        return step_presentjitteredimages
# Example #8
class Pool2d(Process):
    """Perform 2-D (image) pooling on an input.

    Parameters
    ----------
    shape_in : 3-tuple (channels, height, width)
        Shape of the input image.
    pool_size : 2-tuple (vertical, horizontal) or int
        Shape of the pooling region. If an integer is provided, the shape will
        be square with the given side length.
    strides : 2-tuple (vertical, horizontal) or int
        Spacing between pooling placements. If ``None`` (default), will be
        equal to ``pool_size`` resulting in non-overlapping pooling. Must not
        exceed ``pool_size``.
    kind : "avg" or "max"
        Type of pooling to perform: average pooling or max pooling.
    mode : "full" or "valid"
        If the input image does not divide into an integer number of pooling
        regions, whether to add partial pooling regions for the extra
        pixels ("full"), or discard extra input pixels ("valid").

    Attributes
    ----------
    shape_out : 3-tuple (channels, height, width)
        Shape of the output image.
    """
    shape_in = ShapeParam('shape_in', length=3, low=1)
    shape_out = ShapeParam('shape_out', length=3, low=1)
    pool_size = ShapeParam('pool_size', length=2, low=1)
    strides = ShapeParam('strides', length=2, low=1)
    kind = EnumParam('kind', values=('avg', 'max'))
    mode = EnumParam('mode', values=('full', 'valid'))

    def __init__(self,
                 shape_in,
                 pool_size,
                 strides=None,
                 kind='avg',
                 mode='full'):
        self.shape_in = shape_in
        self.pool_size = (pool_size
                          if is_iterable(pool_size) else [pool_size] * 2)
        if strides is None:
            self.strides = self.pool_size
        elif is_iterable(strides):
            self.strides = strides
        else:
            self.strides = [strides] * 2
        self.kind = kind
        self.mode = mode
        if not all(st <= p for st, p in zip(self.strides, self.pool_size)):
            raise ValueError("Strides %s must be <= pool_size %s" %
                             (self.strides, self.pool_size))

        nc, nxi, nxj = self.shape_in
        nyi_float = float(nxi - self.pool_size[0]) / self.strides[0]
        nyj_float = float(nxj - self.pool_size[1]) / self.strides[1]
        if self.mode == 'full':
            nyi = 1 + int(np.ceil(nyi_float))
            nyj = 1 + int(np.ceil(nyj_float))
        elif self.mode == 'valid':
            nyi = 1 + int(np.floor(nyi_float))
            nyj = 1 + int(np.floor(nyj_float))
        self.shape_out = (nc, nyi, nyj)

        super(Pool2d, self).__init__(default_size_in=np.prod(self.shape_in),
                                     default_size_out=np.prod(self.shape_out))

    def make_step(self, shape_in, shape_out, dt, rng):
        assert np.prod(shape_in) == np.prod(self.shape_in)
        assert np.prod(shape_out) == np.prod(self.shape_out)
        nc, nxi, nxj = self.shape_in
        nc, nyi, nyj = self.shape_out
        si, sj = self.pool_size
        sti, stj = self.strides
        kind = self.kind
        nxi2, nxj2 = nyi * sti, nyj * stj
        # Max pooling must start from -inf rather than 0, otherwise a window
        # whose inputs are all negative would report 0. Every output cell
        # receives at least the (i=0, j=0) element because strides <=
        # pool_size, so no -inf survives. Assumes floating-point input.
        fill_value = -np.inf if kind == 'max' else 0.

        def step_pool2d(t, x):
            x = x.reshape(-1, nc, nxi, nxj)
            y = np.full((x.shape[0], nc, nyi, nyj), fill_value, dtype=x.dtype)
            n = np.zeros((nyi, nyj))  # elements per window, for averaging

            for i in range(si):
                for j in range(sj):
                    xij = x[:, :, i:min(nxi2 + i, nxi):sti,
                            j:min(nxj2 + j, nxj):stj]
                    ni, nj = xij.shape[-2:]
                    if kind == 'max':
                        y[:, :, :ni, :nj] = np.maximum(y[:, :, :ni, :nj], xij)
                    elif kind == 'avg':
                        y[:, :, :ni, :nj] += xij
                        n[:ni, :nj] += 1
                    else:
                        raise NotImplementedError(kind)

            if kind == 'avg':
                y /= n

            return y.ravel()

        return step_pool2d
# Example #9
class Conv2d(Process):
    """Perform 2-D (image) convolution on an input.

    Parameters
    ----------
    shape_in : 3-tuple (n_channels, height, width)
        Shape of the input images: channels, height, width.
    filters : array_like (n_filters, n_channels, f_height, f_width)
        Static filters to convolve with the input. Shape is number of filters,
        number of input channels, filter height, and filter width. Shape can
        also be (n_filters, height, width, n_channels, f_height, f_width)
        to apply different filters at each point in the image, where 'height'
        and 'width' are the output image height and width. Filter height and
        width must be odd.
    biases : array_like (1,) or (n_filters,) or (height, width) or \
(n_filters, height, width), optional
        Biases to add to outputs. Can have one bias across the entire output
        space, one bias per filter, one bias per output location shared by
        all filters, or a unique bias for each output pixel.
    strides : 2-tuple (vertical, horizontal) or int
        Spacing between filter placements. If an integer
        is provided, the same spacing is used in both dimensions.
    padding : 2-tuple (vertical, horizontal) or int
        Amount of zero-padding around the outside of the input image. Padding
        is applied to both sides, e.g. ``padding=(1, 0)`` will add one pixel
        of padding to the top and bottom, and none to the left and right.
    border : "ceil" or "floor"
        Rounding used when the strided filter placements do not divide the
        padded input evenly; determines the output height and width.
    """
    shape_in = ShapeParam('shape_in', length=3, low=1)
    shape_out = ShapeParam('shape_out', length=3, low=1)
    strides = ShapeParam('strides', length=2, low=1)
    padding = ShapeParam('padding', length=2)
    filters = NdarrayParam('filters', shape=('...', ))
    biases = NdarrayParam('biases', shape=('...', ), optional=True)
    border = EnumParam('border', values=('floor', 'ceil'))

    def __init__(self,
                 shape_in,
                 filters,
                 biases=None,
                 strides=1,
                 padding=0,
                 border='ceil'):  # noqa: C901
        self.shape_in = shape_in
        self.filters = filters
        if self.filters.ndim not in [4, 6]:
            raise ValueError(
                "`filters` must have four or six dimensions "
                "(filters, [height, width,] channels, f_height, f_width)")
        if self.filters.shape[-3] != self.shape_in[0]:
            raise ValueError(
                "Filter channels (%d) and input channels (%d) must match" %
                (self.filters.shape[-3], self.shape_in[0]))
        if not all(s % 2 == 1 for s in self.filters.shape[-2:]):
            raise ValueError("Filter shapes must be odd (got %r)" %
                             (self.filters.shape[-2:], ))

        self.strides = strides if is_iterable(strides) else [strides] * 2
        self.padding = padding if is_iterable(padding) else [padding] * 2
        self.border = border

        nf = self.filters.shape[0]
        nxi, nxj = self.shape_in[1:]
        si, sj = self.filters.shape[-2:]
        pi, pj = self.padding
        sti, stj = self.strides
        rounder = np.ceil if self.border == 'ceil' else np.floor
        nyi = 1 + max(int(rounder(float(2 * pi + nxi - si) / sti)), 0)
        nyj = 1 + max(int(rounder(float(2 * pj + nxj - sj) / stj)), 0)
        self.shape_out = (nf, nyi, nyj)
        if self.filters.ndim == 6 and self.filters.shape[1:3] != (nyi, nyj):
            raise ValueError("Number of local filters %r must match out shape "
                             "%r" % (self.filters.shape[1:3], (nyi, nyj)))

        if biases is not None:
            biases = np.asarray(biases)
            shape_out = tuple(self.shape_out)
            if biases.size == 1:
                bias_shape = (1, 1, 1)
            elif biases.size == np.prod(shape_out):
                bias_shape = shape_out
            elif biases.size == shape_out[0]:
                bias_shape = (shape_out[0], 1, 1)
            elif biases.size == np.prod(shape_out[1:]):
                bias_shape = (1, ) + shape_out[1:]
            else:
                raise ValueError(
                    "Biases size (%d) does not match output shape %s" %
                    (biases.size, self.shape_out))
            # reshape() returns a new array object, unlike assigning to
            # ``.shape`` in place, so the caller's array is never reshaped;
            # the param is then assigned exactly once.
            biases = biases.reshape(bias_shape)
        self.biases = biases

        super(Conv2d, self).__init__(default_size_in=np.prod(self.shape_in),
                                     default_size_out=np.prod(self.shape_out))

    def make_step(self, shape_in, shape_out, dt, rng):
        assert np.prod(shape_in) == np.prod(self.shape_in)
        assert np.prod(shape_out) == np.prod(self.shape_out)
        shape_in, shape_out = self.shape_in, self.shape_out

        filters = self.filters
        local_filters = filters.ndim == 6
        biases = self.biases

        nc, nxi, nxj = shape_in
        nf, nyi, nyj = shape_out
        si, sj = filters.shape[-2:]
        pi, pj = self.padding
        sti, stj = self.strides

        def step_conv2d(t, x):
            x = x.reshape(-1, nc, nxi, nxj)
            n = x.shape[0]
            y = np.zeros((n, nf, nyi, nyj), dtype=x.dtype)

            for i in range(nyi):
                for j in range(nyj):
                    # input window [i0, i1) x [j0, j1) in unpadded coordinates
                    i0 = i * sti - pi
                    j0 = j * stj - pj
                    i1, j1 = i0 + si, j0 + sj
                    # part of the filter that overlaps the real (unpadded)
                    # input; zero-padding contributes nothing to the sum
                    sli = slice(max(-i0, 0), min(nxi + si - i1, si))
                    slj = slice(max(-j0, 0), min(nxj + sj - j1, sj))
                    w = (filters[:, i, j, :, sli, slj]
                         if local_filters else filters[:, :, sli, slj])
                    xij = x[:, :,
                            max(i0, 0):min(i1, nxi),
                            max(j0, 0):min(j1, nxj)]
                    y[:, :, i, j] = np.dot(xij.reshape(n, -1),
                                           w.reshape(nf, -1).T)

            if biases is not None:
                y += biases

            return y.ravel()

        return step_conv2d