Example #1
class Convolution(Transform):
    """An N-dimensional convolutional transform.

    The dimensionality of the convolution is determined by the input shape.

    .. versionadded:: 3.0.0

    Parameters
    ----------
    n_filters : int
        The number of convolutional filters to apply.
    input_shape : tuple of int or `.ChannelShape`
        Shape of the input signal to the convolution; e.g.,
        ``(height, width, channels)`` for a 2D convolution with
        ``channels_last=True``.
    kernel_size : tuple of int, optional
        Size of the convolutional kernels (1 element for a 1D convolution,
        2 for a 2D convolution, etc.).
    strides : tuple of int, optional
        Stride of the convolution (1 element for a 1D convolution, 2 for
        a 2D convolution, etc.).
    padding : ``"same"`` or ``"valid"``, optional
        Padding method for input signal. "Valid" means no padding, and
        convolution will only be applied to the fully-overlapping areas of the
        input signal (meaning the output will be smaller). "Same" means that
        the input signal is zero-padded so that the output is the same shape
        as the input.
    channels_last : bool, optional
        If ``True`` (default), the channels are the last dimension in the input
        signal (e.g., a 28x28 image with 3 channels would have shape
        ``(28, 28, 3)``).  ``False`` means that channels are the first
        dimension (e.g., ``(3, 28, 28)``).
    init : `.Distribution` or `~numpy:numpy.ndarray`, optional
        A predefined kernel with shape
        ``kernel_size + (input_channels, n_filters)``, or a ``Distribution``
        that will be used to initialize the kernel.

    Notes
    -----
    As is typical in neural networks, this is technically correlation rather
    than convolution (because the kernel is not flipped).
    """

    n_filters = IntParam("n_filters", low=1)
    input_shape = ChannelShapeParam("input_shape", low=1)
    kernel_size = ShapeParam("kernel_size", low=1)
    strides = ShapeParam("strides", low=1)
    padding = EnumParam("padding", values=("same", "valid"))
    channels_last = BoolParam("channels_last")
    init = DistOrArrayParam("init")

    _param_init_order = ["channels_last", "input_shape"]

    def __init__(
            self,
            n_filters,
            input_shape,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding="valid",
            channels_last=True,
            init=Uniform(-1, 1),
    ):
        super().__init__()

        self.n_filters = n_filters
        self.channels_last = channels_last  # must be set before input_shape
        self.input_shape = input_shape
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.init = init

        if len(kernel_size) != self.dimensions:
            raise ValidationError(
                "Kernel dimensions (%d) do not match input dimensions (%d)" %
                (len(kernel_size), self.dimensions),
                attr="kernel_size",
            )
        if len(strides) != self.dimensions:
            raise ValidationError(
                "Stride dimensions (%d) do not match input dimensions (%d)" %
                (len(strides), self.dimensions),
                attr="strides",
            )
        if not isinstance(init, Distribution):
            if init.shape != self.kernel_shape:
                raise ValidationError(
                    "Kernel shape %s does not match expected shape %s" %
                    (init.shape, self.kernel_shape),
                    attr="init",
                )

    @property
    def _argreprs(self):
        argreprs = [
            "n_filters=%r" % (self.n_filters, ),
            "input_shape=%s" % (self.input_shape.shape, ),
        ]
        if self.kernel_size != (3, 3):
            argreprs.append("kernel_size=%r" % (self.kernel_size, ))
        if self.strides != (1, 1):
            argreprs.append("strides=%r" % (self.strides, ))
        if self.padding != "valid":
            argreprs.append("padding=%r" % (self.padding, ))
        if self.channels_last is not True:
            argreprs.append("channels_last=%r" % (self.channels_last, ))
        return argreprs

    def sample(self, rng=np.random):
        if isinstance(self.init, Distribution):
            # we sample this way so that any VarianceScaling distribution based
            # on n/d is scaled appropriately
            kernel = [
                self.init.sample(self.input_shape.n_channels,
                                 self.n_filters,
                                 rng=rng)
                for _ in range(np.prod(self.kernel_size))
            ]
            kernel = np.reshape(kernel, self.kernel_shape)
        else:
            kernel = np.array(self.init, dtype=rc.float_dtype)
        return kernel

    @property
    def kernel_shape(self):
        """Full shape of kernel."""
        return self.kernel_size + (self.input_shape.n_channels, self.n_filters)

    @property
    def size_in(self):
        return self.input_shape.size

    @property
    def size_out(self):
        return self.output_shape.size

    @property
    def dimensions(self):
        """Dimensionality of convolution."""
        return self.input_shape.dimensions

    @property
    def output_shape(self):
        """Output shape after applying convolution to input."""
        output_shape = np.array(self.input_shape.spatial_shape,
                                dtype=rc.float_dtype)
        if self.padding == "valid":
            output_shape -= self.kernel_size
            output_shape += 1
        output_shape /= self.strides
        output_shape = tuple(np.ceil(output_shape).astype(rc.int_dtype))
        output_shape = (output_shape +
                        (self.n_filters, ) if self.channels_last else
                        (self.n_filters, ) + output_shape)

        return ChannelShape(output_shape, channels_last=self.channels_last)
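
A minimal usage sketch (assuming nengo 3.0, where this transform is exposed as
``nengo.Convolution``; the shapes follow from the defaults above):

import numpy as np
import nengo

conv = nengo.Convolution(n_filters=4, input_shape=(28, 28, 3))
print(conv.kernel_shape)         # (3, 3, 3, 4): kernel_size + (in_channels, n_filters)
print(conv.output_shape.shape)   # (26, 26, 4) with "valid" padding and unit strides

with nengo.Network():
    u = nengo.Node(np.zeros(conv.size_in))   # flattened 28x28x3 input
    x = nengo.Node(size_in=conv.size_out)
    nengo.Connection(u, x, transform=conv)
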
Example #2
class Piecewise(Process):
    """A piecewise function with different options for interpolation.

    Given an input dictionary of ``{0: 0, 0.5: -1, 0.75: 0.5, 1: 0}``,
    this process will emit the numerical values (0, -1, 0.5, 0)
    starting at the corresponding time points (0, 0.5, 0.75, 1).

    The keys in the input dictionary must be times (float or int).
    The values in the dictionary can be floats, lists of floats,
    or numpy arrays. All lists or numpy arrays must be of the same length,
    as the output shape of the process will be determined by the shape
    of the values.

    Interpolation on the data points using `scipy.interpolate` is also
    supported. The default interpolation is 'zero', which creates a
    piecewise function whose values change at the specified time points.
    So the above example would be a shortcut for::

        def function(t):
            if t < 0.5:
                return 0
            elif t < 0.75:
                return -1
            elif t < 1:
                return 0.5
            else:
                return 0

    For times before the first specified time, an array of zeros (of
    the correct length) will be emitted.
    This means that the above can be simplified to::

        Piecewise({0.5: -1, 0.75: 0.5, 1: 0})

    Parameters
    ----------
    data : dict
        A dictionary mapping times to the values that should be emitted
        at those times. Times must be numbers (ints or floats), while values
        can be numbers, lists of numbers, numpy arrays of numbers,
        or callables that return any of those options.
    interpolation : str, optional (Default: 'zero')
        One of 'linear', 'nearest', 'slinear', 'quadratic', 'cubic', or 'zero'.
        Specifies how to interpolate between times with specified value.
        'zero' creates a plain piecewise function whose values begin at
        corresponding time points, while all other options interpolate
        as described in `scipy.interpolate`.

    Attributes
    ----------
    data : dict
        A dictionary mapping times to the values that should be emitted
        at those times. Times are numbers (ints or floats), while values
        can be numbers, lists of numbers, numpy arrays of numbers,
        or callables that return any of those options.
    interpolation : str
        One of 'linear', 'nearest', 'slinear', 'quadratic', 'cubic', or 'zero'.
        Specifies how to interpolate between times with specified value.
        'zero' creates a plain piecewise function whose values change at
        corresponding time points, while all other options interpolate
        as described in `scipy.interpolate`.

    Examples
    --------

    >>> from nengo.processes import Piecewise
    >>> process = Piecewise({0.5: 1, 0.75: -1, 1: 0})
    >>> with nengo.Network() as model:
    ...     u = nengo.Node(process, size_out=process.default_size_out)
    ...     up = nengo.Probe(u)
    >>> with nengo.Simulator(model) as sim:
    ...     sim.run(1.5)
    >>> f = sim.data[up]
    >>> t = sim.trange()
    >>> f[t == 0.2]
    array([[ 0.]])
    >>> f[t == 0.58]
    array([[ 1.]])
    """

    data = PiecewiseDataParam('data', readonly=True)
    interpolation = EnumParam('interpolation',
                              values=('zero', 'linear', 'nearest', 'slinear',
                                      'quadratic', 'cubic'))

    def __init__(self, data, interpolation='zero', **kwargs):
        self.data = data

        needs_scipy = ('linear', 'nearest', 'slinear', 'quadratic', 'cubic')
        if interpolation in needs_scipy:
            self.sp_interpolate = None
            if any(callable(val) for val in itervalues(self.data)):
                warnings.warn("%r interpolation cannot be applied because "
                              "a callable was supplied for some piece of the "
                              "function. Using 'zero' interpolation instead." %
                              (interpolation, ))
                interpolation = 'zero'
            else:
                try:
                    import scipy.interpolate
                    self.sp_interpolate = scipy.interpolate
                except ImportError:
                    warnings.warn("%r interpolation cannot be applied because "
                                  "scipy is not installed. Using 'zero' "
                                  "interpolation instead." % (interpolation, ))
                    interpolation = 'zero'
        self.interpolation = interpolation

        super(Piecewise, self).__init__(default_size_in=0,
                                        default_size_out=self.size_out,
                                        **kwargs)

    @property
    def size_out(self):
        time, value = next(iteritems(self.data))
        value = np.ravel(value(time)) if callable(value) else value
        return value.size

    def make_step(self, shape_in, shape_out, dt, rng):
        tp, yp = zip(*sorted(iteritems(self.data)))
        assert shape_in == (0, )
        assert shape_out == (self.size_out, )

        if self.interpolation == 'zero':

            def step_piecewise(t):
                ti = (np.searchsorted(tp, t + 0.5 * dt) - 1).clip(
                    -1, len(yp) - 1)
                if ti == -1:
                    return np.zeros(shape_out)
                else:
                    return np.ravel(yp[ti](t)) if callable(yp[ti]) else yp[ti]
        else:
            assert self.sp_interpolate is not None

            if self.interpolation == "cubic" and 0 not in tp:
                warnings.warn("'cubic' interpolation may fail if data not "
                              "specified for t=0.0")

            f = self.sp_interpolate.interp1d(tp,
                                             yp,
                                             axis=0,
                                             kind=self.interpolation,
                                             bounds_error=False,
                                             fill_value=0.)

            def step_piecewise(t):
                return np.ravel(f(t))

        return step_piecewise
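
A short usage sketch: interpolation other than 'zero' requires scipy (the
constructor falls back to 'zero' with a warning otherwise), and processes can
be sampled standalone with ``Process.run``:

import numpy as np
from nengo.processes import Piecewise

process = Piecewise({0.5: 1, 0.75: -1, 1: 0}, interpolation='linear')
y = process.run(1.0, dt=0.05)   # values linearly interpolated between data points
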
Example #3
class ScatteredHypersphere(Distribution):
    r"""Quasirandom distribution over the hypersphere or hyperball.

    Applies a spherical transform to the given quasirandom sequence
    (by default `.QuasirandomSequence`) to obtain uniformly scattered samples.

    This distribution has the nice mathematical property that the discrepancy
    between the empirical distribution of :math:`n` samples and the true
    distribution is :math:`\widetilde{\mathcal{O}} (1 / n)` as opposed to
    :math:`\mathcal{O} (1 / \sqrt{n})` for the Monte Carlo method [1]_.
    This means that the number of samples is effectively squared, making this
    useful as a means for sampling ``eval_points`` and ``encoders``.

    Parameters
    ----------
    surface : bool, optional
        Whether sample points should be distributed uniformly
        over the surface of the hypersphere (True),
        or within the hypersphere (False).
    min_magnitude : Number, optional
        Lower bound on the returned vector magnitudes (such that they are in
        the range ``[min_magnitude, 1]``). Must be in the range [0, 1).
        Ignored if ``surface`` is ``True``.
    base : `.Distribution`, optional
        The base distribution from which to sample quasirandom numbers.
    method : {"sct-approx", "sct", "tfww"}
        Method to use for mapping points to the hypersphere.

        * "sct-approx": Same as "sct", but uses lookup table to approximate the
          beta distribution, making it faster with almost exactly the same result.
        * "sct": Use the exact Spherical Coordinate Transform
          (section 1.5.2 of [1]_).
        * "tfww": Use the Tashiro-Fang-Wang-Wong method (section 4.3 of [1]_).
          Faster than "sct" and "sct-approx", with the same level of uniformity
          for larger numbers of samples (``n >= 4000``, approximately).

    See Also
    --------
    UniformHypersphere
    QuasirandomSequence

    Notes
    -----
    The `.QuasirandomSequence` distribution is mostly deterministic.
    Nondeterminism comes from a random ``d``-dimensional rotation.

    References
    ----------
    .. [1] K.-T. Fang and Y. Wang, Number-Theoretic Methods in Statistics.
       Chapman & Hall, 1994.

    Examples
    --------
    Plot points sampled from the surface of the sphere in 3 dimensions:

    .. testcode::

       from mpl_toolkits.mplot3d import Axes3D

       points = nengo.dists.ScatteredHypersphere(surface=True).sample(1000, d=3)

       ax = plt.subplot(111, projection="3d")
       ax.scatter(*points.T, s=5)

    Plot points sampled from the volume of the sphere in 2 dimensions (i.e. circle):

    .. testcode::

       points = nengo.dists.ScatteredHypersphere(surface=False).sample(1000, d=2)
       plt.scatter(*points.T, s=5)
    """

    surface = BoolParam("surface")
    min_magnitude = NumberParam("min_magnitude", low=0, high=1, high_open=True)
    base = DistributionParam("base")
    method = EnumParam("method", values=("sct-approx", "sct", "tfww"))

    def __init__(
            self,
            surface=False,
            min_magnitude=0,
            base=QuasirandomSequence(),
            method="sct-approx",
    ):
        super().__init__()
        if surface and min_magnitude > 0:
            warnings.warn("min_magnitude ignored because surface is True")
        self.surface = surface
        self.min_magnitude = min_magnitude
        self.base = base
        self.method = method

        if self.method == "sct":
            import scipy.special  # pylint: disable=import-outside-toplevel

            assert scipy.special

    @classmethod
    def spherical_coords_ppf(cls, dims, y, approx=False):
        """Inverse CDF of the angle distribution used in the SCT mapping."""
        if not approx:
            import scipy.special  # pylint: disable=import-outside-toplevel

        y_reflect = np.where(y < 0.5, y, 1 - y)
        if approx:
            z_sq = _betaincinv22.lookup(dims, 2 * y_reflect)
        else:
            z_sq = scipy.special.betaincinv(dims / 2.0, 0.5, 2 * y_reflect)
        x = np.arcsin(np.sqrt(z_sq)) / np.pi
        return np.where(y < 0.5, x, 1 - x)

    @classmethod
    def spherical_transform_sct(cls, samples, approx=False):
        """Map samples from the ``[0, 1]``-cube onto the hypersphere.

        Uses the SCT method described in section 1.5.3 of Fang and Wang (1994).
        """
        samples = np.asarray(samples)
        samples = samples[:, np.newaxis] if samples.ndim == 1 else samples
        n, d = samples.shape

        # inverse transform method (section 1.5.2)
        coords = np.empty_like(samples)
        for j in range(d):
            coords[:, j] = cls.spherical_coords_ppf(d - j,
                                                    samples[:, j],
                                                    approx=approx)

        # spherical coordinate transform
        mapped = np.ones((n, d + 1))
        i = np.ones(d)
        i[-1] = 2.0
        s = np.sin(i[np.newaxis, :] * np.pi * coords)
        c = np.cos(i[np.newaxis, :] * np.pi * coords)
        mapped[:, 1:] = np.cumprod(s, axis=1)
        mapped[:, :-1] *= c
        return mapped

    @staticmethod
    def spherical_transform_tfww(c_samples):
        """Map samples from the ``[0, 1]``-cube onto the hypersphere surface.

        Uses the TFWW method described in section 4.3 of Fang and Wang (1994).
        """
        c_samples = np.asarray(c_samples)
        if c_samples.ndim == 1:
            c_samples = c_samples[:, np.newaxis]
        n, s1 = c_samples.shape
        s = s1 + 1

        x_samples = np.zeros((n, s))

        if s == 2:
            phi = 2 * np.pi * c_samples[:, 0]
            x_samples[:, 0] = np.cos(phi)
            x_samples[:, 1] = np.sin(phi)
            return x_samples

        even = s % 2 == 0
        m = s // 2 if even else (s - 1) // 2

        g = np.zeros((n, m + 1))
        g[:, -1] = 1
        for j in range(m - 1, 0, -1):
            exponent = (1.0 / j) if even else (2.0 / (2 * j + 1))
            g[:, j] = g[:, j + 1] * c_samples[:, j - 1] ** exponent

        d = np.sqrt(np.diff(g, axis=1))

        phi = c_samples[:, m - 1:]
        if even:
            phi *= 2 * np.pi
            x_samples[:, 0::2] = d * np.cos(phi)
            x_samples[:, 1::2] = d * np.sin(phi)
        else:
            # there is a mistake in eq. 4.3.7 here, see eq. 1.5.28 for correct version
            phi[:, 1:] *= 2 * np.pi
            f = 2 * d[:, 0] * np.sqrt(phi[:, 0] * (1 - phi[:, 0]))
            x_samples[:, 0] = d[:, 0] * (1 - 2 * phi[:, 0])
            x_samples[:, 1] = f * np.cos(phi[:, 1])
            x_samples[:, 2] = f * np.sin(phi[:, 1])
            if s > 3:
                x_samples[:, 3::2] = d[:, 1:] * np.cos(phi[:, 2:])
                x_samples[:, 4::2] = d[:, 1:] * np.sin(phi[:, 2:])

        return x_samples

    @staticmethod
    def random_orthogonal(d, rng=np.random):
        """Returns a random orthogonal matrix."""
        m = rng.standard_normal((d, d))
        u, _, v = np.linalg.svd(m)
        return np.dot(u, v)

    def sample(self, n, d=1, rng=np.random):
        if d == 1 and self.surface:
            return np.sign(self.base.sample(n, d, rng) - 0.5)

        if d == 1:
            pos_samples = self.base.sample(n // 2, d, rng)
            neg_samples = self.base.sample(n - pos_samples.size, d, rng)
            if self.min_magnitude > 0:
                for samples in [pos_samples, neg_samples]:
                    samples *= 1.0 - self.min_magnitude
                    samples += self.min_magnitude
            samples = np.vstack([pos_samples, -1 * neg_samples])
            rng.shuffle(samples)
            return samples

        radius = None
        if self.surface:
            samples = self.base.sample(n, d - 1, rng)
        else:
            samples = self.base.sample(n, d, rng)
            samples, radius = samples[:, :-1], samples[:, -1:]
            if self.min_magnitude != 0:
                min_d = self.min_magnitude**d
                radius *= 1 - min_d
                radius += min_d
            radius **= 1.0 / d

        if self.method == "sct":
            mapped = self.spherical_transform_sct(samples, approx=False)
        elif self.method == "sct-approx":
            mapped = self.spherical_transform_sct(samples, approx=True)
        else:
            assert self.method == "tfww"
            mapped = self.spherical_transform_tfww(samples)

        # radius adjustment for ball
        if radius is not None:
            mapped *= radius

        # random rotation
        rotation = self.random_orthogonal(d, rng=rng)
        return np.dot(mapped, rotation)
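
A brief usage sketch (assuming a nengo version where this class is exported as
``nengo.dists.ScatteredHypersphere``, as in the docstring examples above):

import numpy as np
import nengo

dist = nengo.dists.ScatteredHypersphere(min_magnitude=0.2, method="tfww")
points = dist.sample(1000, d=5)
radii = np.linalg.norm(points, axis=1)
print(radii.min(), radii.max())   # magnitudes fall in [0.2, 1]
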
Example #4
class Pool2d(Process):
    """Perform 2-D (image) pooling on an input.

    Parameters
    ----------
    shape_in : 3-tuple (channels, height, width)
        Shape of the input image.
    pool_size : 2-tuple (vertical, horizontal) or int
        Shape of the pooling region. If an integer is provided, the shape will
        be square with the given side length.
    strides : 2-tuple (vertical, horizontal) or int
        Spacing between pooling placements. If ``None`` (default), will be
        equal to ``pool_size`` resulting in non-overlapping pooling.
    kind : "avg" or "max"
        Type of pooling to perform: average pooling or max pooling.
    mode : "full" or "valid"
        If the input image does not divide into an integer number of pooling
        regions, whether to add partial pooling regions for the extra
        pixels ("full"), or discard extra input pixels ("valid").

    Attributes
    ----------
    shape_out : 3-tuple (channels, height, width)
        Shape of the output image.
    """
    shape_in = ShapeParam('shape_in', length=3, low=1)
    shape_out = ShapeParam('shape_out', length=3, low=1)
    pool_size = ShapeParam('pool_size', length=2, low=1)
    strides = ShapeParam('strides', length=2, low=1)
    kind = EnumParam('kind', values=('avg', 'max'))
    mode = EnumParam('mode', values=('full', 'valid'))

    def __init__(self,
                 shape_in,
                 pool_size,
                 strides=None,
                 kind='avg',
                 mode='full'):
        self.shape_in = shape_in
        self.pool_size = (pool_size if is_iterable(pool_size)
                          else [pool_size] * 2)
        if strides is None:
            strides = self.pool_size
        self.strides = strides if is_iterable(strides) else [strides] * 2
        self.kind = kind
        self.mode = mode
        if not all(st <= p for st, p in zip(self.strides, self.pool_size)):
            raise ValueError("Strides %s must be <= pool_size %s" %
                             (self.strides, self.pool_size))

        nc, nxi, nxj = self.shape_in
        nyi_float = float(nxi - self.pool_size[0]) / self.strides[0]
        nyj_float = float(nxj - self.pool_size[1]) / self.strides[1]
        if self.mode == 'full':
            nyi = 1 + int(np.ceil(nyi_float))
            nyj = 1 + int(np.ceil(nyj_float))
        elif self.mode == 'valid':
            nyi = 1 + int(np.floor(nyi_float))
            nyj = 1 + int(np.floor(nyj_float))
        self.shape_out = (nc, nyi, nyj)

        super(Pool2d, self).__init__(default_size_in=np.prod(self.shape_in),
                                     default_size_out=np.prod(self.shape_out))

    def make_step(self, shape_in, shape_out, dt, rng):
        assert np.prod(shape_in) == np.prod(self.shape_in)
        assert np.prod(shape_out) == np.prod(self.shape_out)
        nc, nxi, nxj = self.shape_in
        nc, nyi, nyj = self.shape_out
        si, sj = self.pool_size
        sti, stj = self.strides
        kind = self.kind
        nxi2, nxj2 = nyi * sti, nyj * stj

        def step_pool2d(t, x):
            x = x.reshape(nc, nxi, nxj)
            y = np.zeros((nc, nyi, nyj), dtype=x.dtype)
            n = np.zeros((nyi, nyj))

            for i in range(si):
                for j in range(sj):
                    xij = x[:, i:min(nxi2 + i, nxi):sti,
                            j:min(nxj2 + j, nxj):stj]
                    ni, nj = xij.shape[-2:]
                    if kind == 'max':
                        y[:, :ni, :nj] = np.maximum(y[:, :ni, :nj], xij)
                    elif kind == 'avg':
                        y[:, :ni, :nj] += xij
                        n[:ni, :nj] += 1
                    else:
                        raise NotImplementedError(kind)

            if kind == 'avg':
                y /= n

            return y.ravel()

        return step_pool2d
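
A short usage sketch (this class matches ``nengo_extras.convnet.Pool2d``; the
import path is an assumption). Steps can be exercised directly:

import numpy as np

pool = Pool2d(shape_in=(3, 28, 28), pool_size=2, kind='avg')
print(pool.shape_out)   # (3, 14, 14): non-overlapping 2x2 average pooling

step = pool.make_step((pool.default_size_in,), (pool.default_size_out,),
                      dt=0.001, rng=np.random)
y = step(0.0, np.arange(3 * 28 * 28, dtype=float))   # one pooled frame, flattened
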
Example #5
class LinearFilter(Synapse):
    """General linear time-invariant (LTI) system synapse.

    This class can be used to implement any linear filter, given the
    filter's transfer function. [1]_

    Parameters
    ----------
    num : array_like
        Numerator coefficients of transfer function.
    den : array_like
        Denominator coefficients of transfer function.
    analog : boolean, optional
        Whether the synapse coefficients are analog (i.e. continuous-time),
        or discrete. Analog coefficients will be converted to discrete for
        simulation using the simulator ``dt``.
    method : string
        The method to use for discretization (if ``analog`` is True). See
        `scipy.signal.cont2discrete` for information about the options.

        .. versionadded:: 3.0.0

    Attributes
    ----------
    analog : boolean
        Whether the synapse coefficients are analog (i.e. continuous-time),
        or discrete. Analog coefficients will be converted to discrete for
        simulation using the simulator ``dt``.
    den : ndarray
        Denominator coefficients of transfer function.
    num : ndarray
        Numerator coefficients of transfer function.
    method : string
        The method to use for discretization (if ``analog`` is True). See
        `scipy.signal.cont2discrete` for information about the options.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Filter_%28signal_processing%29
    """

    num = NdarrayParam("num", shape="*")
    den = NdarrayParam("den", shape="*")
    analog = BoolParam("analog")
    method = EnumParam(
        "method", values=("gbt", "bilinear", "euler", "backward_diff", "zoh")
    )

    def __init__(self, num, den, analog=True, method="zoh", **kwargs):
        super().__init__(**kwargs)
        self.num = num
        self.den = den
        self.analog = analog
        self.method = method

    def combine(self, obj):
        """Combine in series with another LinearFilter."""
        if not isinstance(obj, LinearFilter):
            raise ValidationError(
                "Can only combine with other LinearFilters", attr="obj"
            )
        if self.analog != obj.analog:
            raise ValidationError(
                "Cannot combine analog and digital filters", attr="obj"
            )
        num = np.polymul(self.num, obj.num)
        den = np.polymul(self.den, obj.den)
        return LinearFilter(
            num,
            den,
            analog=self.analog,
            default_size_in=self.default_size_in,
            default_size_out=self.default_size_out,
            default_dt=self.default_dt,
            seed=self.seed,
        )

    def evaluate(self, frequencies):
        """Evaluate the transfer function at the given frequencies.

        Examples
        --------
        Using the ``evaluate`` function to make a Bode plot:

        .. testcode::

           import matplotlib.pyplot as plt

           synapse = nengo.synapses.LinearFilter([1], [0.02, 1])
           f = np.logspace(-1, 3, 100)
           y = synapse.evaluate(f)
           plt.subplot(211); plt.semilogx(f, 20*np.log10(np.abs(y)))
           plt.xlabel('frequency [Hz]'); plt.ylabel('magnitude [dB]')
           plt.subplot(212); plt.semilogx(f, np.angle(y))
           plt.xlabel('frequency [Hz]'); plt.ylabel('phase [radians]')
        """
        frequencies = 2.0j * np.pi * frequencies
        w = frequencies if self.analog else np.exp(frequencies)
        y = np.polyval(self.num, w) / np.polyval(self.den, w)
        return y

    def _get_ss(self, dt):
        A, B, C, D = tf2ss(self.num, self.den)

        # discretize (if len(A) == 0, filter is stateless and already discrete)
        if self.analog and len(A) > 0:
            A, B, C, D, _ = cont2discrete((A, B, C, D), dt, method=self.method)

        return A, B, C, D

    def make_state(self, shape_in, shape_out, dt, dtype=None, y0=0):
        assert shape_in == shape_out

        dtype = rc.float_dtype if dtype is None else np.dtype(dtype)
        if dtype.kind != "f":
            raise ValidationError(
                f"Only float data types are supported (got {dtype}). Please cast "
                "your data to a float type.",
                attr="dtype",
                obj=self,
            )

        A, B, C, D = self._get_ss(dt)

        # create state memory variable X
        X = np.zeros((A.shape[0],) + shape_out, dtype=dtype)

        # initialize X using y0 as steady-state output
        y0 = np.array(y0, copy=False, ndmin=2)
        if (y0 == 0).all():
            # just leave X as zeros in this case, so that this value works
            # for unstable systems
            pass
        elif LinearFilter.OneX.check(A, B, C, D, X):
            # OneX combines B and C into one scaling value `b`
            b = B.item() * C.item()
            X[:] = (b / (1 - A.item())) * y0
        else:
            # Solve for u0 (input) given y0 (output), then X given u0
            assert B.ndim == 1 or B.ndim == 2 and B.shape[1] == 1
            IAB = np.linalg.solve(np.eye(len(A)) - A, B)
            Q = C.dot(IAB) + D  # multiplier from input to output (DC gain)
            assert Q.size == 1
            if np.abs(Q.item()) > 1e-8:
                u0 = y0 / Q.item()
                X[:] = IAB.dot(u0)
            else:
                raise ValidationError(
                    "Cannot solve for state if DC gain is zero. Please set `y0=0`.",
                    "y0",
                    obj=self,
                )

        return {"X": X}

    def make_step(self, shape_in, shape_out, dt, rng, state):
        """Returns a `.Step` instance that implements the linear filter."""
        assert shape_in == shape_out
        assert state is not None

        A, B, C, D = self._get_ss(dt)
        X = state["X"]

        if LinearFilter.NoX.check(A, B, C, D, X):
            return LinearFilter.NoX(A, B, C, D, X)
        if LinearFilter.OneXScalar.check(A, B, C, D, X):
            return LinearFilter.OneXScalar(A, B, C, D, X)
        elif LinearFilter.OneX.check(A, B, C, D, X):
            return LinearFilter.OneX(A, B, C, D, X)
        elif LinearFilter.NoD.check(A, B, C, D, X):
            return LinearFilter.NoD(A, B, C, D, X)
        else:
            assert LinearFilter.General.check(A, B, C, D, X)
            return LinearFilter.General(A, B, C, D, X)

    class Step:
        """Abstract base class for LTI filtering step functions."""

        def __init__(self, A, B, C, D, X):
            if not self.check(A, B, C, D, X):
                raise ValidationError(
                    "Matrices do not meet the requirements for this Step",
                    attr="A,B,C,D,X",
                    obj=self,
                )
            self.A = A
            self.B = B
            self.C = C
            self.D = D
            self.X = X

        def __call__(self, t, signal):
            raise NotImplementedError("Step object must implement __call__")

        @classmethod
        def check(cls, A, B, C, D, X):
            if A.size == 0:
                return X.size == B.size == C.size == 0 and D.size == 1
            else:
                return (
                    A.shape[0] == A.shape[1] == B.shape[0] == C.shape[1]
                    and A.shape[0] == X.shape[0]
                    and C.shape[0] == B.shape[1] == 1
                    and D.size == 1
                )

    class NoX(Step):
        """Step for system with no state, only passthrough matrix (D)."""

        def __init__(self, A, B, C, D, X):
            super().__init__(A, B, C, D, X)
            self.d = D.item()

        def __call__(self, t, signal):
            return self.d * signal

        @classmethod
        def check(cls, A, B, C, D, X):
            return super().check(A, B, C, D, X) and A.size == 0

    class OneX(Step):
        """Step for systems with one state element and no passthrough (D)."""

        def __init__(self, A, B, C, D, X):
            super().__init__(A, B, C, D, X)
            self.a = A.item()
            self.b = C.item() * B.item()

        def __call__(self, t, signal):
            self.X *= self.a
            self.X += self.b * signal
            return self.X[0]

        @classmethod
        def check(cls, A, B, C, D, X):
            return super().check(A, B, C, D, X) and (len(A) == 1 and (D == 0).all())

    class OneXScalar(OneX):
        """Step for systems with one state element, no passthrough, and a size-1 input.

        Using the builtin float math improves performance.
        """

        def __call__(self, t, signal):
            self.X[:] = self.a * self.X.item() + self.b * signal.item()
            return self.X[0]

        @classmethod
        def check(cls, A, B, C, D, X):
            return super().check(A, B, C, D, X) and X.size == 1

    class NoD(Step):
        """Step for systems with no passthrough matrix (D).

        Implements::

            x[t] = A x[t-1] + B u[t]
            y[t] = C x[t]

        Note how the input has been advanced one step as compared with the
        General system below, to remove the unnecessary delay.
        """

        def __call__(self, t, signal):
            self.X[:] = np.dot(self.A, self.X) + self.B * signal
            return np.dot(self.C, self.X)[0]

        @classmethod
        def check(cls, A, B, C, D, X):
            return super().check(A, B, C, D, X) and (len(A) >= 1 and (D == 0).all())

    class General(Step):
        """Step for any LTI system with at least one state element (X).

        Implements::

            x[t+1] = A x[t] + B u[t]
            y[t] = C x[t] + D u[t]

        Use ``NoX`` for systems with no state elements.
        """

        def __call__(self, t, signal):
            Y = np.dot(self.C, self.X)[0] + self.D * signal
            self.X[:] = np.dot(self.A, self.X) + self.B * signal
            return Y

        @classmethod
        def check(cls, A, B, C, D, X):
            return super().check(A, B, C, D, X) and len(A) >= 1
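
A short usage sketch (assuming ``nengo.LinearFilter`` as in nengo 3.x):
``combine`` multiplies transfer functions, so cascading two first-order
lowpass filters yields a second-order filter:

import numpy as np
import nengo

lowpass = nengo.LinearFilter([1], [0.01, 1])                   # tau = 0.01 s
cascade = lowpass.combine(nengo.LinearFilter([1], [0.005, 1]))
print(cascade.num, cascade.den)   # approximately [1], [5e-05, 0.015, 1]

y = cascade.filt(np.ones(1000), dt=0.001)   # discretized step response
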
Example #6
class DeltaRule(LearningRuleType):
    r"""Implementation of the Delta rule.

    By default, this implementation pretends the neurons are linear, and thus
    does not require the derivative of the postsynaptic neuron activation
    function. The derivative function, or a surrogate function, for the
    postsynaptic neurons can be provided in ``post_fn``.

    The update is given by:

    .. math:: \delta W_{ij} = \eta \, a_j \, e_i \, f(u_i)

    where :math:`e_i` is the input error in the postsynaptic neuron space,
    :math:`a_j` is the output activity for presynaptic neuron :math:`j`,
    :math:`u_i` is the input for postsynaptic neuron :math:`i`,
    and :math:`f` is a provided function.

    Parameters
    ----------
    learning_rate : float
        A scalar indicating the rate at which weights will be adjusted.
    pre_tau : float
        Filter constant on the presynaptic output ``a_j``.
    post_fn : callable
        Function ``f`` to apply to the postsynaptic inputs ``u_i``. The
        default of ``None`` means the ``f(u_i)`` term is omitted.
    post_tau : float
        Filter constant on the postsynaptic input ``u_i``. This defaults to
        ``None`` because these should typically be filtered by the connection.
    """
    modifies = 'weights'
    probeable = ('delta', 'in', 'error', 'correction', 'pre', 'post')

    pre_tau = NumberParam('pre_tau', low=0, low_open=True)
    post_tau = NumberParam('post_tau', low=0, low_open=True, optional=True)
    post_fn = DeltaRuleFunctionParam('post_fn', optional=True)
    post_target = EnumParam('post_target', values=('in', 'out'))

    def __init__(self, learning_rate=1e-4, pre_tau=0.005,
                 post_fn=None, post_tau=None, post_target='in'):
        if learning_rate >= 1.0:
            warnings.warn("This learning rate is very high, and can result "
                          "in floating point errors from too much current.")
        self.pre_tau = pre_tau
        self.post_tau = post_tau
        self.post_fn = post_fn
        self.post_target = post_target
        super(DeltaRule, self).__init__(learning_rate, size_in='post')

    @property
    def _argreprs(self):
        args = []
        if self.learning_rate != 1e-4:
            args.append("learning_rate=%g" % self.learning_rate)
        if self.pre_tau != 0.005:
            args.append("pre_tau=%f" % self.pre_tau)
        if self.post_fn is not None:
            args.append("post_fn=%s" % self.post_fn.function)
        if self.post_tau is not None:
            args.append("post_tau=%f" % self.post_tau)
        if self.post_target != 'in':
            args.append("post_target=%s" % self.post_target)

        return args
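
A minimal usage sketch (assumptions: the rule is importable, e.g. from
``nengo_extras.learning_rules``, and ``relu_deriv`` is an illustrative
surrogate derivative, not part of the library):

import numpy as np
import nengo
from nengo_extras.learning_rules import DeltaRule  # assumed import path

def relu_deriv(u):
    # surrogate f(u_i) for rectified-linear postsynaptic neurons
    return (u > 0).astype(float)

with nengo.Network():
    pre = nengo.Ensemble(50, 1)
    post = nengo.Ensemble(40, 1)
    conn = nengo.Connection(
        pre.neurons, post.neurons, transform=np.zeros((40, 50)),
        learning_rule_type=DeltaRule(learning_rate=1e-4, post_fn=relu_deriv))
    # the error signal e_i is provided through a connection into the rule
    error = nengo.Node(np.zeros(40))
    nengo.Connection(error, conn.learning_rule)
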
Example #7
class _ConvolutionBase(Transform):
    """Abstract base class for Convolution and ConvolutionTranspose."""

    n_filters = IntParam("n_filters", low=1)
    input_shape = ChannelShapeParam("input_shape", low=1)
    kernel_size = ShapeParam("kernel_size", low=1)
    strides = ShapeParam("strides", low=1)
    padding = EnumParam("padding", values=("same", "valid"))
    channels_last = BoolParam("channels_last")
    init = DistOrArrayParam("init")
    groups = IntParam("groups", low=1)

    _param_init_order = ["channels_last", "input_shape"]

    def __init__(
        self,
        n_filters,
        input_shape,
        kernel_size=(3, 3),
        strides=(1, 1),
        padding="valid",
        channels_last=True,
        init=Uniform(-1, 1),
        groups=1,
    ):
        super().__init__()

        self.n_filters = n_filters
        self.channels_last = channels_last  # must be set before input_shape
        self.input_shape = input_shape
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.init = init
        self.groups = groups

        if len(kernel_size) != self.dimensions:
            raise ValidationError(
                f"Kernel dimensions ({len(kernel_size)}) does not match "
                f"input dimensions ({self.dimensions})",
                attr="kernel_size",
            )
        if len(strides) != self.dimensions:
            raise ValidationError(
                f"Stride dimensions ({len(strides)}) does not match "
                f"input dimensions ({self.dimensions})",
                attr="strides",
            )
        if not isinstance(init, Distribution):
            if init.shape != self.kernel_shape:
                raise ValidationError(
                    f"Kernel shape {init.shape} does not match "
                    f"expected shape {self.kernel_shape}",
                    attr="init",
                )

        in_channels = self.input_shape.n_channels
        if groups > in_channels:
            raise ValidationError(
                f"Groups ({groups}) cannot be greater than "
                f"the number of input channels ({in_channels})",
                attr="groups",
            )
        if in_channels % groups != 0 or self.n_filters % groups != 0:
            raise ValidationError(
                f"Both the number of input channels ({in_channels}) and filters "
                f"({self.n_filters}) must be evenly divisible by ``groups`` ({groups})",
                attr="groups",
            )

    @property
    def _argreprs(self):
        argreprs = [
            f"n_filters={self.n_filters!r}",
            f"input_shape={self.input_shape.shape}",
        ]
        if self.kernel_size != (3, 3):
            argreprs.append(f"kernel_size={self.kernel_size!r}")
        if self.strides != (1, 1):
            argreprs.append(f"strides={self.strides!r}")
        if self.padding != "valid":
            argreprs.append(f"padding={self.padding!r}")
        if self.channels_last is not True:
            argreprs.append(f"channels_last={self.channels_last!r}")
        if self.groups != 1:
            argreprs.append(f"groups={self.groups!r}")
        return argreprs

    def sample(self, rng=np.random):
        if isinstance(self.init, Distribution):
            # we sample this way so that any VarianceScaling distribution based
            # on n/d is scaled appropriately
            kernel = [
                self.init.sample(
                    self.input_shape.n_channels // self.groups, self.n_filters, rng=rng
                )
                for _ in range(np.prod(self.kernel_size))
            ]
            kernel = np.reshape(kernel, self.kernel_shape)
        else:
            kernel = np.array(self.init, dtype=rc.float_dtype)
        return kernel

    @property
    def kernel_shape(self):
        """Full shape of kernel."""
        return self.kernel_size + (
            self.input_shape.n_channels // self.groups,
            self.n_filters,
        )

    @property
    def size_in(self):
        return self.input_shape.size

    @property
    def size_out(self):
        return self.output_shape.size

    @property
    def dimensions(self):
        """Dimensionality of convolution."""
        return self.input_shape.dimensions

    def _forward_shape(self, input_spatial_shape, n_filters):
        output_shape = np.array(input_spatial_shape, dtype=rc.float_dtype)
        if self.padding == "valid":
            output_shape -= self.kernel_size
            output_shape += 1
        output_shape /= self.strides
        output_shape = tuple(np.ceil(output_shape).astype(rc.int_dtype))

        return ChannelShape.from_space_and_channels(
            output_shape, n_filters, channels_last=self.channels_last
        )
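
For reference, ``_forward_shape`` implements the standard convolution shape
arithmetic. A standalone sketch of the same computation (plain NumPy; the
function name is illustrative):

import numpy as np

def conv_output_shape(spatial_shape, kernel_size, strides, padding):
    # mirrors _forward_shape: "valid" shrinks by kernel_size - 1, then strides
    out = np.array(spatial_shape, dtype=float)
    if padding == "valid":
        out = out - np.array(kernel_size) + 1
    out /= np.array(strides)
    return tuple(int(np.ceil(v)) for v in out)

print(conv_output_shape((28, 28), (3, 3), (2, 2), "valid"))  # (13, 13)
print(conv_output_shape((28, 28), (3, 3), (2, 2), "same"))   # (14, 14)
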
Example #8
class VarianceScaling(Distribution):
    """Variance scaling distribution for weight initialization (analogous to
    TensorFlow ``init_ops.VarianceScaling`).

    Parameters
    ----------
    scale : float, optional
        Overall scale on values.
    mode : "fan_in" or "fan_out" or "fan_avg", optional
        Whether to scale based on input or output dimensionality, or the
        average of the two.
    distribution : "uniform" or "normal", optional
        Whether to use a uniform or normal distribution for weights.
    """

    scale = NumberParam("scale", low=0)
    mode = EnumParam("mode", values=["fan_in", "fan_out", "fan_avg"])
    distribution = EnumParam("distribution", values=["uniform", "normal"])

    def __init__(self, scale=1, mode="fan_avg", distribution="uniform"):
        super().__init__()

        self.scale = scale
        self.mode = mode
        self.distribution = distribution

    def sample(self, n, d=None, rng=np.random):
        """Samples the distribution.

        Parameters
        ----------
        n : int
            Number of samples to take.
        d : int or None, optional
            The number of dimensions to return. If this is an int, the return
            value will be of shape ``(n, d)``. If None, the return
            value will be of shape ``(n,)``.
        rng : `numpy.random.RandomState`, optional
            Random number generator state.

        Returns
        -------
        samples : (n,) or (n, d) array_like
            Samples as a 1d or 2d array depending on ``d``. The second
            dimension enumerates the dimensions of the process.
        """

        fan_in = n
        fan_out = 1 if d is None else d
        scale = self.scale
        if self.mode == "fan_in":
            scale /= fan_in
        elif self.mode == "fan_out":
            scale /= fan_out
        elif self.mode == "fan_avg":
            scale /= (fan_in + fan_out) / 2

        shape = (n, ) if d is None else (n, d)
        if self.distribution == "uniform":
            limit = np.sqrt(3.0 * scale)
            return rng.uniform(-limit, limit, size=shape)
        elif self.distribution == "normal":
            stddev = np.sqrt(scale)
            # TODO: use truncated normal distribution
            return rng.normal(scale=stddev, size=shape)
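
A brief usage sketch; note that in this version ``fan_in`` is taken to be
``n``, the first sample dimension:

import numpy as np

init = VarianceScaling(mode="fan_in", distribution="uniform")
w = init.sample(100, 10)                      # shape (100, 10)
print(np.abs(w).max() <= np.sqrt(3.0 / 100))  # True: bounded by the fan_in limit
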
Example #9
class Conv2d(Process):
    """Perform 2-D (image) convolution on an input.

    Parameters
    ----------
    shape_in : 3-tuple (n_channels, height, width)
        Shape of the input images: channels, height, width.
    filters : array_like (n_filters, n_channels, f_height, f_width)
        Static filters to convolve with the input. Shape is number of filters,
        number of input channels, filter height, and filter width. Shape can
        also be (n_filters, height, width, n_channels, f_height, f_width)
        to apply different filters at each point in the image, where 'height'
        and 'width' are the input image height and width.
    biases : array_like (1,) or (n_filters,) or (n_filters, height, width)
        Biases to add to outputs. Can have one bias across the entire output
        space, one bias per filter, or a unique bias for each output pixel.
    strides : 2-tuple (vertical, horizontal) or int
        Spacing between filter placements. If an integer
        is provided, the same spacing is used in both dimensions.
    padding : 2-tuple (vertical, horizontal) or int
        Amount of zero-padding around the outside of the input image. Padding
        is applied to both sides, e.g. ``padding=(1, 0)`` will add one pixel
        of padding to the top and bottom, and none to the left and right.
    """

    shape_in = ShapeParam('shape_in', length=3, low=1)
    shape_out = ShapeParam('shape_out', length=3, low=1)
    strides = ShapeParam('strides', length=2, low=1)
    padding = ShapeParam('padding', length=2)
    filters = NdarrayParam('filters', shape=('...', ))
    biases = NdarrayParam('biases', shape=('...', ), optional=True)
    border = EnumParam('border', values=('floor', 'ceil'))

    def __init__(self,
                 shape_in,
                 filters,
                 biases=None,
                 strides=1,
                 padding=0,
                 border='ceil'):  # noqa: C901
        self.shape_in = shape_in
        self.filters = filters
        if self.filters.ndim not in [4, 6]:
            raise ValueError(
                "`filters` must have four or six dimensions "
                "(filters, [height, width,] channels, f_height, f_width)")
        if self.filters.shape[-3] != self.shape_in[0]:
            raise ValueError(
                "Filter channels (%d) and input channels (%d) must match" %
                (self.filters.shape[-3], self.shape_in[0]))
        if not all(s % 2 == 1 for s in self.filters.shape[-2:]):
            raise ValueError("Filter shapes must be odd (got %r)" %
                             (self.filters.shape[-2:], ))

        self.strides = strides if is_iterable(strides) else [strides] * 2
        self.padding = padding if is_iterable(padding) else [padding] * 2
        self.border = border

        nf = self.filters.shape[0]
        nxi, nxj = self.shape_in[1:]
        si, sj = self.filters.shape[-2:]
        pi, pj = self.padding
        sti, stj = self.strides
        rounder = np.ceil if self.border == 'ceil' else np.floor
        nyi = 1 + max(int(rounder(float(2 * pi + nxi - si) / sti)), 0)
        nyj = 1 + max(int(rounder(float(2 * pj + nxj - sj) / stj)), 0)
        self.shape_out = (nf, nyi, nyj)
        if self.filters.ndim == 6 and self.filters.shape[1:3] != (nyi, nyj):
            raise ValueError("Number of local filters %r must match out shape "
                             "%r" % (self.filters.shape[1:3], (nyi, nyj)))

        self.biases = biases
        if self.biases is not None:
            if self.biases.size == 1:
                self.biases.shape = (1, 1, 1)
            elif self.biases.size == np.prod(self.shape_out):
                self.biases.shape = self.shape_out
            elif self.biases.size == self.shape_out[0]:
                self.biases.shape = (self.shape_out[0], 1, 1)
            elif self.biases.size == np.prod(self.shape_out[1:]):
                self.biases.shape = (1, ) + self.shape_out[1:]
            else:
                raise ValueError(
                    "Biases size (%d) does not match output shape %s" %
                    (self.biases.size, self.shape_out))

        super(Conv2d, self).__init__(default_size_in=np.prod(self.shape_in),
                                     default_size_out=np.prod(self.shape_out))

    def make_step(self, shape_in, shape_out, dt, rng):
        assert np.prod(shape_in) == np.prod(self.shape_in)
        assert np.prod(shape_out) == np.prod(self.shape_out)
        shape_in, shape_out = self.shape_in, self.shape_out

        filters = self.filters
        local_filters = filters.ndim == 6
        biases = self.biases

        nc, nxi, nxj = shape_in
        nf, nyi, nyj = shape_out
        si, sj = filters.shape[-2:]
        pi, pj = self.padding
        sti, stj = self.strides

        def step_conv2d(t, x):
            x = x.reshape(-1, nc, nxi, nxj)
            n = x.shape[0]
            y = np.zeros((n, nf, nyi, nyj), dtype=x.dtype)

            for i in range(nyi):
                for j in range(nyj):
                    i0 = i * sti - pi
                    j0 = j * stj - pj
                    i1, j1 = i0 + si, j0 + sj
                    sli = slice(max(-i0, 0), min(nxi + si - i1, si))
                    slj = slice(max(-j0, 0), min(nxj + sj - j1, sj))
                    w = (filters[:, i, j, :, sli, slj] if local_filters
                         else filters[:, :, sli, slj])
                    xij = x[:, :,
                            max(i0, 0):min(i1, nxi),
                            max(j0, 0):min(j1, nxj)]
                    y[:, :, i, j] = np.dot(xij.reshape(n, -1),
                                           w.reshape(nf, -1).T)

            if biases is not None:
                y += biases

            return y.ravel()

        return step_conv2d
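
A short usage sketch (this class matches ``nengo_extras.convnet.Conv2d``; the
import path is an assumption). With 5x5 filters and padding 2, the spatial
shape is preserved:

import numpy as np

rng = np.random.RandomState(0)
filters = rng.normal(size=(8, 3, 5, 5))   # 8 filters over 3 channels, 5x5 each
conv = Conv2d(shape_in=(3, 32, 32), filters=filters, strides=1, padding=2)
print(conv.shape_out)                     # (8, 32, 32)
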
Example #10
class Pool3(Process):
    """Perform 3-D (frames) pooling on an input.

    Supports average ('avg', the default) and max ('max') pooling.
    """
    shape_in = TupleParam(length=4)
    shape_out = TupleParam(length=4)
    size = IntParam(low=1)
    depth_size = IntParam(low=1)
    stride = IntParam(low=1)
    temporal_stride = IntParam(low=1)
    kind = EnumParam(values=('avg', 'max'))

    def __init__(self, shape_in, size, depth_size, stride=None, kind='avg',
                 temporal_stride=None):
        self.shape_in = shape_in
        self.size = size
        self.depth_size = depth_size
        self.stride = stride if stride is not None else size
        self.temporal_stride = (temporal_stride if temporal_stride is not None
                                else depth_size)
        self.kind = kind
        if self.stride > self.size:
            raise ValueError("Stride (%d) must be <= size (%d)" %
                             (self.stride, self.size))
        if self.temporal_stride > self.depth_size:
            raise ValueError("Temporal stride (%d) must be <= depth_size (%d)"
                             % (self.temporal_stride, self.depth_size))

        c, nxd, nxi, nxj = self.shape_in
        nyd = (nxd - 1) // self.temporal_stride + 1
        nyi = (nxi - 1) // self.stride + 1
        nyj = (nxj - 1) // self.stride + 1
        self.shape_out = (c, nyd, nyi, nyj)

        super(Pool3, self).__init__(
            default_size_in=np.prod(self.shape_in),
            default_size_out=np.prod(self.shape_out))

    def make_step(self, size_in, size_out, dt, rng):
        assert size_in == np.prod(self.shape_in)
        assert size_out == np.prod(self.shape_out)
        c, nxd, nxi, nxj = self.shape_in
        c, nyd, nyi, nyj = self.shape_out
        s = self.size
        depth_s = self.depth_size
        st = self.stride
        temp_st = self.temporal_stride
        kind = self.kind

        def step_pool3(t, x):
            x = x.reshape(c, nxd, nxi, nxj)
            y = np.zeros_like(x[:, ::temp_st, ::st, ::st])
            n = np.zeros((nyd, nyi, nyj))
            assert y.shape[-3:] == (nyd, nyi, nyj)
            for k in range(depth_s):
                for i in range(s):
                    for j in range(s):
                        xkij = x[:, k::temp_st, i::st, j::st]
                        nk, ni, nj = xkij.shape[-3:]
                        if kind == 'max':
                            y[:, :nk, :ni, :nj] = np.maximum(
                                y[:, :nk, :ni, :nj], xkij)
                        elif kind == 'avg':
                            y[:, :nk, :ni, :nj] += xkij
                            n[:nk, :ni, :nj] += 1
                        else:
                            raise NotImplementedError(kind)

            if kind == 'avg':
                y /= n

            return y.ravel()

        return step_pool3
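
A short usage sketch against the constructor as written above (imports are
assumed; ``Pool3`` does not appear to belong to a released library):

import numpy as np

pool = Pool3(shape_in=(2, 8, 16, 16), size=2, depth_size=2)
print(pool.shape_out)   # (2, 4, 8, 8)

step = pool.make_step(pool.default_size_in, pool.default_size_out,
                      dt=0.001, rng=np.random)
y = step(0.0, np.ones(2 * 8 * 16 * 16))
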
Example #11
class VarianceScaling(Distribution):
    """Variance scaling distribution for weight initialization (analogous to
    ``tf.initializers.VarianceScaling``).

    Parameters
    ----------
    scale : float
        Overall scale on values.
    mode : "fan_in" or "fan_out" or "fan_avg"
        Whether to scale based on input or output dimensionality, or average of
        the two.
    distribution: "uniform" or "normal"
        Whether to use a uniform or truncated normal distribution for weights.
    """

    scale = NumberParam("scale", low=0)
    mode = EnumParam("mode", values=["fan_in", "fan_out", "fan_avg"])
    distribution = EnumParam("distribution", values=["uniform", "normal"])

    def __init__(self, scale=1, mode="fan_avg", distribution="uniform"):
        super().__init__()

        self.scale = scale
        self.mode = mode
        self.distribution = distribution

    def sample(self, n, d=None, rng=None):
        """Samples the distribution.

        Parameters
        ----------
        n : int
            Number of samples to take.
        d : int or None
            The number of dimensions to return. If this is an int, the return
            value will be of shape ``(n, d)``. If None, the return
            value will be of shape ``(n,)``.
        rng : `~numpy.random.mtrand.RandomState`
            Random number generator state (if None, will use the default
            numpy random number generator).

        Returns
        -------
        samples : (n,) or (n, d) array_like
            Samples as a 1d or 2d array depending on ``d``. The second
            dimension enumerates the dimensions of the process.
        """

        if rng is None:
            rng = np.random
        fan_out = n
        fan_in = 1 if d is None else d
        scale = self.scale
        if self.mode == "fan_in":
            scale /= fan_in
        elif self.mode == "fan_out":
            scale /= fan_out
        elif self.mode == "fan_avg":
            scale /= (fan_in + fan_out) / 2

        shape = (n, ) if d is None else (n, d)
        if self.distribution == "uniform":
            limit = np.sqrt(3.0 * scale)
            return rng.uniform(-limit, limit, size=shape)
        elif self.distribution == "normal":
            stddev = np.sqrt(scale)
            return TruncatedNormal(stddev=stddev).sample(n, d, rng=rng)
        else:
            # note: this should be caught by the EnumParam check
            raise NotImplementedError
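
Usage sketch for this version (here ``fan_in`` is ``d``, the second sample
dimension); with ``scale=2`` and ``mode="fan_in"`` this gives a He-style
truncated-normal initialization:

init = VarianceScaling(scale=2, mode="fan_in", distribution="normal")
w = init.sample(256, 128)   # truncated normal with stddev sqrt(2 / 128)
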