Example 1
    def predict(self, X, y=None, p=None):
        """

        Parameters
        ==========

        X : array_like, shape (n_samples, n_features)
            Stimulus design matrix.

        y : None or array_like, shape (n_samples, )
            Recorded response. Required when a post-spike (response-history)
            filter has been fitted.

        p : None or dict
            Model parameters. If None, the optimized parameters
            (`self.p_opt`) are used.

        """

        extra = {'X': X, 'y': y}
        if self.h_mle is not None:

            if y is None:
                raise ValueError(
                    '`y` is needed for calculating response history.')

            yh = jnp.array(
                build_design_matrix(extra['y'][:, jnp.newaxis],
                                    self.yh.shape[1],
                                    shift=self.shift_h))
            extra.update({'yh': yh})

        params = self.p_opt if p is None else p
        y_pred = self.forwardpass(params, extra=extra)

        return y_pred
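
A minimal usage sketch for this `predict`. The enclosing model class is not shown, so `model`, `X_test` and `y_test` below are illustrative placeholders, not names from the source:

# hypothetical usage; `model` is a fitted instance of the class above
import numpy as np

X_test = np.random.randn(500, 8)           # (n_samples, n_features) design matrix
y_test = np.random.poisson(1., size=500)   # only required if a history filter (h_mle) was fitted

y_pred = model.predict(X_test, y=y_test)   # p=None -> uses the optimized parameters (p_opt)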
Example 2
    def initialize_history_filter(self, dims, shift=1):
        """
        Parameters
        ==========

        dims : list or array_like, shape (ndims, )
            Dimensions or shape of the response-history filter. Should be 1D, i.e. `[nt, ]`.

        shift : int
            Should be 1 or larger, so that the filter only sees past responses.

        """
        y = self.y
        yh = jnp.array(
            build_design_matrix(y[:, jnp.newaxis], dims, shift=shift))
        self.shift_h = shift
        self.yh = yh
        # MLE of the history filter via the normal equations
        self.h_mle = jnp.linalg.solve(yh.T @ yh, yh.T @ y)
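
The last line is ordinary least squares via the normal equations. A standalone sketch of the same step with an explicitly built lag matrix (illustrative names, not from the class):

import numpy as np

y = np.random.randn(1000)                      # response trace
nlag, shift = 5, 1
# lagged copies of y, zero-padded at the start (a rough stand-in for build_design_matrix)
yh = np.stack([np.concatenate([np.zeros(s), y[:len(y) - s]])
               for s in range(shift, shift + nlag)], axis=1)
h_mle = np.linalg.solve(yh.T @ yh, yh.T @ y)   # least-squares history filter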
Example 3
def noise2d(n_samples,
            dims,
            shift=0,
            beta=None,
            noise='gaussian',
            design_matrix=False,
            random_seed=2046):
    """
    2D noise. Gaussian white noise or checkerboard binary noise.
    """

    if len(dims) == 2:
        nt = None
    elif len(dims) == 3:
        nt = dims[0]
        dims = dims[1:]
    else:
        raise NotImplementedError(f'`dims` must have length 2 or 3, got {len(dims)}.')

    np.random.seed(random_seed)  # seed the generator for reproducibility

    if noise == 'gaussian':
        X = np.random.randn(n_samples, *dims)
    elif noise == 'binary':
        X = np.random.choice([-1, 1], size=[n_samples, *dims])
    else:
        raise NotImplementedError(noise)

    if beta is not None:
        if noise != 'gaussian':
            raise ValueError('1/f noise only applies to Gaussian noise.')

        X = colornoise2d(n_samples=n_samples,
                         dims=dims,
                         beta=beta,
                         phi=X,
                         random_seed=random_seed)
        X = (X - X.mean()) / X.std()

    if design_matrix and nt is not None:
        X = build_design_matrix(X, nt, shift)

    return X
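
A quick call sketch for `noise2d`, assuming `build_design_matrix` and `colornoise2d` are importable from the same module:

# 3D dims: dims[0] is split off as the temporal dimension; with
# design_matrix=True the output is lagged into a 2D design matrix
X = noise2d(n_samples=1000, dims=(5, 8, 8), noise='binary', design_matrix=True)
print(X.shape)  # expected: (1000, 5 * 8 * 8)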
Example 4
    def predict(self, X, y=None, p=None):
        """

        Parameters
        ==========

        X : array_like, shape (n_samples, n_features)
            Stimulus design matrix.

        y : None or array_like, shape (n_samples, )
            Recorded response. Required when a post-spike (response-history)
            filter has been fitted.

        p : None or dict
            Model parameters. If None, the optimized parameters
            (`self.p_opt`) are used.

        """

        if self.n_c > 1:
            XS = jnp.dstack([X[:, :, i] @ self.S for i in range(self.n_c)
                             ]).reshape(X.shape[0], -1)
        else:
            XS = X @ self.S

        extra = {'X': X, 'XS': XS, 'y': y}

        if self.h_spl is not None:

            if y is None:
                raise ValueError(
                    '`y` is needed for calculating response history.')

            yh = jnp.array(
                build_design_matrix(extra['y'][:, jnp.newaxis],
                                    self.Sh.shape[0],
                                    shift=self.shift_h))
            yS = yh @ self.Sh
            extra.update({'yS': yS})

        params = self.p_opt if p is None else p
        y_pred = self.forwardpass(params, extra=extra)

        return y_pred
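
Usage mirrors the non-spline `predict` in Example 1, except the stimulus is first projected onto the spline basis `self.S` (and onto `self.Sh` for the history term). A hypothetical call, with `model` standing in for a fitted spline model:

y_pred = model.predict(X_test)            # stimulus-only model
y_pred = model.predict(X_test, y=y_test)  # required once a history filter (h_spl) is fitted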
Example 5
    def initialize_history_filter(self, dims, df, smooth='cr', shift=1):
        """

        Parameters
        ==========

        dims : list or array_like, shape (ndims, )
            Dimensions or shape of the response-history filter. Should be 1D, i.e. `[nt, ]`.

        df : int, or list / array_like
            Number of spline bases.

        smooth : str
            Type of spline basis, e.g. 'cr' (cubic regression spline).

        shift : int
            Should be 1 or larger, so that the filter only sees past responses.

        """

        y = self.y
        Sh = jnp.array(build_spline_matrix([dims], [df], smooth))  # basis for h
        yh = jnp.array(
            build_design_matrix(self.y[:, jnp.newaxis],
                                Sh.shape[0],
                                shift=shift))
        yS = yh @ Sh

        self.shift_h = shift
        self.yh = yh
        self.Sh = Sh  # spline basis for spike-history
        self.yS = yS
        self.bh_spl = jnp.linalg.solve(yS.T @ yS, yS.T @ y)  # LS spline coefficients
        self.h_spl = Sh @ self.bh_spl
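
Here the history filter is parameterized as `h = Sh @ bh`, so only `df` coefficients are estimated instead of `dims` raw weights. A hypothetical call on a model instance:

model.initialize_history_filter(dims=50, df=8, smooth='cr', shift=1)
print(model.Sh.shape)     # expected: (50, 8) spline basis
print(model.h_spl.shape)  # expected: (50,) reconstructed filter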
Example 6
def flickerbar(n_samples,
               dims,
               shift=0,
               beta=None,
               noise='gaussian',
               design_matrix=False,
               random_seed=2046):
    """
    Flicker bar.
    """

    nt, nx = dims

    np.random.seed(random_seed)

    if noise == 'gaussian':
        X = np.random.randn(n_samples, nx)
    elif noise == 'binary':
        X = np.random.choice([-1, 1], size=[n_samples, nx])
    else:
        raise NotImplementedError(noise)

    if beta is not None:
        if noise != 'gaussian':
            raise ValueError('1/f noise only applies to Gaussian noise.')
        X = colornoise1d(n_samples=n_samples,
                         dims=nx,
                         beta=beta,
                         phi=X,
                         random_seed=random_seed)
        X = (X - X.mean()) / X.std()

    if design_matrix:
        X = build_design_matrix(X, nt, shift)

    return X
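
A call sketch for `flickerbar`: a 20-bar stimulus with a 15-lag design matrix (assumes `build_design_matrix` and `colornoise1d` from the same module):

X = flickerbar(n_samples=2000, dims=(15, 20), noise='binary', design_matrix=True)
print(X.shape)  # expected: (2000, 15 * 20)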
Example 7
def flickerfield(n_samples,
                 dims=None,
                 shift=0,
                 beta=None,
                 noise='gaussian',
                 design_matrix=False,
                 random_seed=2046):
    """
    Full field flicker.
    """

    np.random.seed(random_seed)

    if noise == 'gaussian':
        X = np.random.randn(n_samples)[:, np.newaxis]
    elif noise == 'binary':
        X = np.random.choice([-1, 1], size=n_samples)[:, np.newaxis]
    else:
        raise NotImplementedError(noise)

    if beta is not None:
        if noise != 'gaussian':
            raise ValueError('1/f noise only applies to Gaussian noise.')
        X = colornoise1d(1,
                         dims=n_samples,
                         beta=beta,
                         phi=X.flatten(),
                         random_seed=random_seed)[:, np.newaxis]
        X = (X - X.mean()) / X.std()

    if design_matrix:

        if dims is None:
            raise ValueError(
                '`dims` is needed for building stimulus design matrix.')
        X = build_design_matrix(X, dims, shift)

    return X
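
A call sketch for `flickerfield`; `dims` here is just the number of time lags and is only required when `design_matrix=True`:

X = flickerfield(n_samples=2000, dims=10, design_matrix=True)
print(X.shape)  # expected: (2000, 10)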
Example 8
    def fit(self,
            p0=None,
            extra=None,
            num_subunits=2,
            num_epochs=1,
            num_iters=3000,
            initialize='random',
            metric=None,
            alpha=1,
            beta=0.05,
            fit_linear_filter=True,
            fit_intercept=True,
            fit_R=True,
            fit_history_filter=False,
            fit_nonlinearity=False,
            step_size=1e-2,
            tolerance=10,
            verbose=100,
            random_seed=2046,
            return_model=None):

        """Fit the spline-based subunit model; parameters are as in the
        documented `fit` of Example 10, plus `num_subunits`."""

        self.metric = metric

        self.alpha = alpha  # elastic net parameter (1=L1, 0=L2)
        self.beta = beta  # elastic net parameter - global penalty weight

        self.n_s = num_subunits
        self.num_iters = num_iters

        self.fit_linear_filter = fit_linear_filter
        self.fit_history_filter = fit_history_filter
        self.fit_nonlinearity = fit_nonlinearity
        self.fit_intercept = fit_intercept
        self.fit_R = fit_R

        # initialize parameters
        if p0 is None:
            p0 = {}

        dict_keys = p0.keys()
        if 'b' not in dict_keys:
            if initialize == 'random':  # not necessary, but for consistency with others.
                key = random.PRNGKey(random_seed)
                b0 = 0.01 * random.normal(
                    key, shape=(self.n_b * self.n_c * self.n_s, )).flatten()
                p0.update({'b': b0})

        if 'intercept' not in dict_keys:
            p0.update({'intercept': jnp.zeros(1)})

        if 'R' not in dict_keys:
            p0.update({'R': jnp.array([1.])})

        if 'bh' not in dict_keys:
            try:
                p0.update({'bh': self.bh_spl})
            except AttributeError:  # no history filter has been initialized
                p0.update({'bh': None})

        if 'nl_params' not in dict_keys:
            if self.nl_params is not None:
                p0.update({
                    'nl_params': [self.nl_params for _ in range(self.n_s + 1)]
                })
            else:
                p0.update({'nl_params': [None for _ in range(self.n_s + 1)]})

        if extra is not None:

            if self.n_c > 1:
                XS_ext = jnp.dstack([
                    extra['X'][:, :, i] @ self.S for i in range(self.n_c)
                ]).reshape(extra['X'].shape[0], -1)
                extra.update({'XS': XS_ext})
            else:
                extra.update({'XS': extra['X'] @ self.S})

            if self.h_spl is not None:
                yh = jnp.array(
                    build_design_matrix(extra['y'][:, jnp.newaxis],
                                        self.Sh.shape[0],
                                        shift=self.shift_h))
                yS = yh @ self.Sh
                extra.update({'yS': yS})

            extra = {key: jnp.array(extra[key]) for key in extra.keys()}

        self.p0 = p0
        self.p_opt = self.optimize_params(p0, extra, num_epochs, num_iters,
                                          metric, step_size, tolerance,
                                          verbose, return_model)

        self.R = self.p_opt['R'] if fit_R else jnp.array([1.])

        if fit_linear_filter:
            self.b_opt = self.p_opt['b']

            if self.n_c > 1:
                self.w_opt = jnp.stack([(self.S @ self.b_opt.reshape(
                    self.n_b, self.n_c, self.n_s)[:, :, i])
                                        for i in range(self.n_s)],
                                       axis=-1)
            else:
                self.w_opt = self.S @ self.b_opt.reshape(self.n_b, self.n_s)

        if fit_history_filter:
            self.bh_opt = self.p_opt['bh']
            self.h_opt = self.Sh @ self.bh_opt

        if fit_intercept:
            self.intercept = self.p_opt['intercept']

        if fit_nonlinearity:
            self.nl_params_opt = self.p_opt['nl_params']
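
A hypothetical two-subunit fit; `model` stands for a spline subunit model already constructed on training data (so `self.S`, `self.n_b`, `self.n_c` exist):

model.fit(num_subunits=2, num_iters=1500, alpha=1., beta=0.01,
          verbose=100, random_seed=2046)
print(model.w_opt.shape)  # expected: (n_pixels, 2) for n_c == 1, one column per subunit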
Example 9
    def fit(self,
            p0=None,
            extra=None,
            initialize='random',
            num_epochs=1,
            num_iters=5,
            metric=None,
            alpha=1,
            beta=0.05,
            fit_linear_filter=True,
            fit_intercept=True,
            fit_R=True,
            fit_history_filter=False,
            fit_nonlinearity=False,
            step_size=1e-2,
            tolerance=10,
            verbose=1,
            random_seed=2046,
            return_model=None):
        """

        Parameters
        ==========

        p0 : dict
            * 'w': Initial linear filter.
            * 'h': Initial response-history filter.

        extra : None or dict {'X': X_dev, 'y': y_dev}
            Development set.

        initialize : None or str
            Parametric initialization.
            * if `initialize=None`, `w` will be initialized by STA.
            * if `initialize='random'`, `w` will be randomly initialized.

        num_iters : int
            Max number of optimization iterations.

        metric : None or str
            Extra cross-validation metric. Default is `None`. One of:
            * 'mse': mean squared error
            * 'r2': R2 score
            * 'corrcoef': correlation coefficient

        alpha : float, from 0 to 1.
            Elastic net parameter, balance between L1 and L2 regularization.
            * 0.0 -> only L2
            * 1.0 -> only L1

        beta : float
            Elastic net parameter, overall weight of regularization.

        step_size : float
            Initial step size for JAX optimizer (ADAM).

        tolerance : int
            Early-stopping tolerance. Optimization stops when the development-set
            cost increases, or the training cost stops decreasing, for `tolerance`
            consecutive steps. If `tolerance=0`, early stopping is disabled.

        verbose : int
            If `verbose=0`, progress is not printed; if `verbose=n`, progress is
            printed every n steps.

        """

        self.metric = metric  # metric for cross-validation and prediction

        self.alpha = alpha
        self.beta = beta  # elastic net parameter - global penalty weight for linear filter
        self.num_iters = num_iters

        self.fit_R = fit_R
        self.fit_linear_filter = fit_linear_filter
        self.fit_history_filter = fit_history_filter
        self.fit_nonlinearity = fit_nonlinearity
        self.fit_intercept = fit_intercept

        # initialize parameters
        if p0 is None:
            p0 = {}

        dict_keys = p0.keys()
        if 'w' not in dict_keys:
            if initialize is None:
                p0.update({'w': self.w_sta})
            elif initialize == 'random':
                key = random.PRNGKey(random_seed)
                w0 = 0.01 * random.normal(
                    key, shape=(self.w_sta.shape[0], )).flatten()
                p0.update({'w': w0})

        if 'intercept' not in dict_keys:
            p0.update({'intercept': jnp.array([0.])})

        if 'R' not in dict_keys and self.fit_R:
            p0.update({'R': jnp.array([1.])})

        if 'h' not in dict_keys:
            if initialize is None and self.h_mle is not None:
                p0.update({'h': self.h_mle})

            elif initialize == 'random' and self.h_mle is not None:
                key = random.PRNGKey(random_seed)
                h0 = 0.01 * random.normal(
                    key, shape=(self.h_mle.shape[0], )).flatten()
                p0.update({'h': h0})
            else:
                p0.update({'h': None})

        if 'nl_params' not in dict_keys:
            if self.nl_params is not None:
                p0.update({'nl_params': self.nl_params})
            else:
                p0.update({'nl_params': None})

        if extra is not None:

            if self.h_mle is not None:
                yh = jnp.array(
                    build_design_matrix(extra['y'][:, jnp.newaxis],
                                        self.yh.shape[1],
                                        shift=self.shift_h))
                extra.update({'yh': yh})

            extra = {key: jnp.array(extra[key]) for key in extra.keys()}

        # store optimized parameters
        self.p0 = p0
        self.p_opt = self.optimize_params(p0, extra, num_epochs, num_iters,
                                          metric, step_size, tolerance,
                                          verbose, return_model)
        self.R = self.p_opt['R'] if fit_R else jnp.array([1.])

        if fit_linear_filter:
            self.w_opt = self.p_opt['w']

        if fit_history_filter:
            self.h_opt = self.p_opt['h']

        if fit_nonlinearity:
            self.nl_params_opt = self.p_opt['nl_params']

        if fit_intercept:
            self.intercept = self.p_opt['intercept']
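
A hypothetical fit with a development set for early stopping; `X_dev` and `y_dev` are held-out arrays shaped like the training data:

model.fit(extra={'X': X_dev, 'y': y_dev}, num_iters=3000,
          metric='corrcoef', alpha=1., beta=0.05, tolerance=10, verbose=100)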
Example 10
    def fit(self,
            p0=None,
            extra=None,
            initialize='random',
            num_epochs=1,
            num_iters=3000,
            metric=None,
            alpha=1,
            beta=0.05,
            fit_linear_filter=True,
            fit_intercept=True,
            fit_R=True,
            fit_history_filter=False,
            fit_nonlinearity=False,
            step_size=1e-2,
            tolerance=10,
            verbose=100,
            random_seed=2046,
            return_model=None):
        """

        Parameters
        ==========

        p0 : dict
            * 'b': Initial spline coefficients.
            * 'bh': Initial response-history filter coefficients.

        initialize : None or str
            Parametric initialization.
            * if `initialize=None`, `b` will be initialized with `self.b_spl`.
            * if `initialize='random'`, `b` will be randomly initialized.

        num_iters : int
            Max number of optimization iterations.

        metric : None or str
            Extra cross-validation metric. Default is `None`. One of:
            * 'mse': mean squared error
            * 'r2': R2 score
            * 'corrcoef': correlation coefficient

        alpha : float, from 0 to 1.
            Elastic net parameter, balance between L1 and L2 regularization.
            * 0.0 -> only L2
            * 1.0 -> only L1

        beta : float
            Elastic net parameter, overall weight of regularization for receptive field.

        step_size : float
            Initial step size for JAX optimizer.

        tolerance : int
            Early-stopping tolerance. Optimization stops when the development-set
            cost increases, or the training cost stops decreasing, for `tolerance`
            consecutive steps. If `tolerance=0`, early stopping is disabled.

        verbose : int
            If `verbose=0`, progress is not printed; if `verbose=n`, progress is
            printed every n steps.

        """

        self.metric = metric

        self.alpha = alpha
        self.beta = beta  # elastic net parameter - global penalty weight for linear filter
        self.num_iters = num_iters

        self.fit_R = fit_R
        self.fit_linear_filter = fit_linear_filter
        self.fit_history_filter = fit_history_filter
        self.fit_nonlinearity = fit_nonlinearity
        self.fit_intercept = fit_intercept

        # initial parameters

        if p0 is None:
            p0 = {}

        dict_keys = p0.keys()
        if 'b' not in dict_keys:
            if initialize is None:
                p0.update({'b': self.b_spl})
            elif initialize == 'random':
                key = random.PRNGKey(random_seed)
                b0 = 0.01 * random.normal(
                    key, shape=(self.n_b * self.n_c, )).flatten()
                p0.update({'b': b0})

        if 'intercept' not in dict_keys:
            p0.update({'intercept': jnp.array([0.])})

        if 'R' not in dict_keys:
            p0.update({'R': jnp.array([1.])})

        if 'bh' not in dict_keys:
            if initialize is None and self.bh_spl is not None:
                p0.update({'bh': self.bh_spl})
            elif initialize == 'random' and self.bh_spl is not None:
                key = random.PRNGKey(random_seed)
                bh0 = 0.01 * random.normal(key, shape=(len(
                    self.bh_spl), )).flatten()
                p0.update({'bh': bh0})
            else:
                p0.update({'bh': None})

        if 'nl_params' not in dict_keys:
            if self.nl_params is not None:
                p0.update({'nl_params': self.nl_params})
            else:
                p0.update({'nl_params': None})

        if extra is not None:

            if self.n_c > 1:
                XS_ext = jnp.dstack([
                    extra['X'][:, :, i] @ self.S for i in range(self.n_c)
                ]).reshape(extra['X'].shape[0], -1)
                extra.update({'XS': XS_ext})
            else:
                extra.update({'XS': extra['X'] @ self.S})

            if self.h_spl is not None:
                yh_ext = jnp.array(
                    build_design_matrix(extra['y'][:, jnp.newaxis],
                                        self.Sh.shape[0],
                                        shift=self.shift_h))
                yS_ext = yh_ext @ self.Sh
                extra.update({'yS': yS_ext})

            extra = {key: jnp.array(extra[key]) for key in extra.keys()}

            self.extra = extra  # store for cross-validation

        # store optimized parameters
        self.p0 = p0
        self.p_opt = self.optimize_params(p0, extra, num_epochs, num_iters,
                                          metric, step_size, tolerance,
                                          verbose, return_model)
        self.R = self.p_opt['R'] if fit_R else jnp.array([1.])

        if fit_linear_filter:
            self.b_opt = self.p_opt['b']  # optimized RF basis coefficients
            if self.n_c > 1:
                self.w_opt = self.S @ self.b_opt.reshape(self.n_b, self.n_c)
            else:
                self.w_opt = self.S @ self.b_opt  # optimized RF

        if fit_history_filter:
            self.bh_opt = self.p_opt['bh']
            self.h_opt = self.Sh @ self.bh_opt

        if fit_nonlinearity:
            self.nl_params_opt = self.p_opt['nl_params']

        if fit_intercept:
            self.intercept = self.p_opt['intercept']
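
After fitting, the receptive field is reconstructed from basis space as `w = S @ b`. For reference, one common form of the elastic-net penalty that `alpha` and `beta` describe (a sketch only; the class's exact penalty term may differ):

import jax.numpy as jnp

def elastic_net(b, alpha=1., beta=0.05):
    # beta scales the whole penalty; alpha trades L1 against L2
    return beta * (alpha * jnp.sum(jnp.abs(b))
                   + (1 - alpha) * 0.5 * jnp.sum(b ** 2))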
Example 11
dims = [25, 20, 15]  # filter dimensions [nt, ny, nx]; required by `df` and the design matrix below
df = [i - 2 for i in dims]
stim = f(sample['stim'][:, 0, crop[2]:crop[3], crop[0]:crop[1]]).detach().cpu().clone().numpy()
robs = sample['robs'].detach().cpu().numpy()
del gd  # free the source dataset object

#%%
from rfest import ALD
#%%
robs.shape
#%%
# stim.shape

X = build_design_matrix(stim, dims[0])
X.shape
#%%
import copy
dt = 1/240

df = [5, 15, 15]
sigma0 = [1.]
rho0 = [1.]
params_t0 = [3., 20., 1., 1.] # taus, nus, tauf, nuf
params_y0 = [3., 20., 1., 1.]
params_x0 = [3., 20., 1., 1.]
p0 = sigma0 + rho0 + params_t0 + params_y0 + params_x0


lsgs = []
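#%%
# A hedged sketch of how this setup might proceed. The exact `ALD` signature
# (design matrix, response, filter dims) is an assumption, not confirmed by
# this fragment, and robs is assumed to be (n_samples, n_units).
ald = ALD(X, robs[:, 0], dims=dims)
ald.fit(p0=p0, num_iters=300)
lsgs.append(ald)  # collect fitted models, matching `lsgs = []` above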
Example 12
    def fit(self,
            p0=None,
            extra=None,
            num_subunits=2,
            num_epochs=1,
            num_iters=3000,
            metric=None,
            alpha=1,
            beta=0.05,
            fit_linear_filter=True,
            fit_intercept=True,
            fit_R=True,
            fit_history_filter=False,
            fit_nonlinearity=False,
            step_size=1e-2,
            tolerance=10,
            verbose=100,
            random_seed=2046,
            return_model=None):

        """Fit the non-spline subunit model; parameters are as in the documented
        `fit` of Example 9, plus `num_subunits`."""

        self.metric = metric

        self.alpha = alpha  # elastic net parameter (1=L1, 0=L2)
        self.beta = beta  # elastic net parameter - global penalty weight for linear filter

        self.n_s = num_subunits
        self.num_iters = num_iters

        self.fit_R = fit_R
        self.fit_linear_filter = fit_linear_filter
        self.fit_history_filter = fit_history_filter
        self.fit_nonlinearity = fit_nonlinearity
        self.fit_intercept = fit_intercept

        if extra is not None:

            if self.h_mle is not None:
                yh = jnp.array(
                    build_design_matrix(extra['y'][:, jnp.newaxis],
                                        self.yh.shape[1],
                                        shift=self.shift_h))
                extra.update({'yh': yh})

            extra = {key: jnp.array(extra[key]) for key in extra.keys()}

        # initialize parameters
        if p0 is None:
            p0 = {}

        dict_keys = p0.keys()

        if 'w' not in dict_keys:
            key = random.PRNGKey(random_seed)
            w0 = 0.01 * random.normal(
                key,
                shape=(self.n_features * self.n_c * self.n_s, )).flatten()
            p0.update({'w': w0})

        if 'intercept' not in dict_keys:
            p0.update({'intercept': jnp.array([0.])})

        if 'R' not in dict_keys and self.fit_R:
            p0.update({'R': jnp.array([1.])})

        if 'h' not in dict_keys:
            try:
                p0.update({'h': self.h_mle})
            except AttributeError:  # no history filter has been initialized
                p0.update({'h': None})

        if 'nl_params' not in dict_keys:
            if self.nl_params is not None:
                p0.update({
                    'nl_params': [self.nl_params for _ in range(self.n_s + 1)]
                })
            else:
                p0.update({'nl_params': [None for _ in range(self.n_s + 1)]})

        self.p0 = p0
        self.p_opt = self.optimize_params(p0, extra, num_epochs, num_iters,
                                          metric, step_size, tolerance,
                                          verbose, return_model)
        self.R = self.p_opt['R'] if fit_R else jnp.array([1.])

        if fit_linear_filter:
            if self.n_c > 1:
                self.w_opt = self.p_opt['w'].reshape(self.n_features, self.n_c,
                                                     self.n_s)
            else:
                self.w_opt = self.p_opt['w'].reshape(self.n_features, self.n_s)

        if fit_history_filter:
            self.h_opt = self.p_opt['h']

        if fit_nonlinearity:
            self.nl_params_opt = self.p_opt['nl_params']

        if fit_intercept:
            self.intercept = self.p_opt['intercept']
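
A hypothetical fit of this non-spline subunit variant; filters are estimated directly in feature space rather than in a spline basis:

model.fit(num_subunits=2, num_iters=2000, beta=0.01, verbose=100)
print(model.w_opt.shape)  # expected: (n_features, 2) for a single channel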
Example 13
    def add_design_matrix(self,
                          X,
                          dims=None,
                          df=None,
                          smooth=None,
                          lag=True,
                          filter_nonlinearity='none',
                          kind='train',
                          name='stimulus',
                          shift=0,
                          burn_in=None):
        """
        Add input design matrix to the model.

        Parameters
        ----------

        X: jnp.array, shape=(n_samples, ) or (n_samples, n_pixels)
            Original input.

        dims: int, or list / jnp.array, shape=dim_t, or (dim_t, dim_x, dim_y)
            Filter shape.

        df: None, int, or list / jnp.array
            Number of spline bases. Should have the same length as `dims`.

        smooth: None, or str
            Type of spline bases. If None, no basis is used.

        lag: bool
            If True, the design matrix is built with `dims[0]` time lags.
            If False, an instantaneous RF is fitted.

        filter_nonlinearity: str
            Nonlinearity for the stimulus filter.

        kind: str
            Dataset type, should be one of `train` (training set),
            `dev` (validation set) or `test` (testing set).

        name: str
            Name of the corresponding filter.
            A receptive field (stimulus) filter should have `stimulus` in the name.
            A response-history filter should have `history` in the name.

        shift: int
            Time offset for the design matrix; a positive number shifts the
            design matrix into the past, a negative number into the future.

        burn_in: int or None
            Number of initial samples / frames to ignore for prediction
            (the first few rows of the design matrix are zero-padded and
            tend to predict poorly).

        """

        # check X shape
        if len(X.shape) == 1:
            X = X[:, jnp.newaxis].astype(self.dtype)
        else:
            X = X.astype(self.dtype)

        if kind not in self.X:
            self.X.update({kind: {}})

        if kind == 'train':
            self.filter_nonlinearity[name] = filter_nonlinearity
            self.filter_names.append(name)

            dims = dims if not isinstance(dims, int) else [dims]
            self.dims[name] = dims
            self.shift[name] = shift
        else:
            dims = self.dims[name]
            shift = self.shift[name]

        if self.burn_in is None:  # if already set, keep it
            # number of initial frames to ignore
            self.burn_in = dims[0] - 1 if burn_in is None else burn_in

        if lag:
            self.X[kind][name] = build_design_matrix(
                X, dims[0], shift=shift, dtype=self.dtype)[self.burn_in:]
        else:
            self.burn_in = 0
            self.X[kind][name] = X  # without a time lag, arguably no burn-in either
            # TODO: may need different handling for instantaneous RFs;
            # conflict: the history filter is burned-in but the stimulus filter is not.

        if smooth is None:
            # if the train set exists and used a spline basis,
            # automatically apply the same basis to the dev/test set
            if name in self.S:
                if kind not in self.XS:
                    self.XS.update({kind: {}})
                S = self.S[name]
                self.XS[kind][name] = self.X[kind][name] @ S

            elif kind == 'test':
                if kind not in self.XS:
                    self.XS.update({kind: {}})
                if self.num_subunits is not None and self.num_subunits > 1:
                    S = self.S['stimulus_s0']
                else:
                    S = self.S[name]
                self.XS[kind][name] = self.X[kind][name] @ S

            else:
                if kind == 'train':
                    self.n_features[name] = self.X['train'][name].shape[1]

        else:  # use spline
            if kind not in self.XS:
                self.XS.update({kind: {}})

            self.df[name] = df if not isinstance(df, int) else [df]
            S = build_spline_matrix(dims=dims,
                                    df=self.df[name],
                                    smooth=smooth,
                                    dtype=self.dtype)
            self.S[name] = S

            XS = self.X[kind][name] @ S
            self.XS[kind][name] = XS

            if kind == 'train':
                self.n_features[name] = self.XS['train'][name].shape[1]
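
A hypothetical pipeline built on `add_design_matrix`: register a lagged stimulus filter and a spike-history filter on the training set, then reuse the stored dims and basis for the dev set (`model`, `X_train`, `y_train`, `X_dev`, `y_dev` are placeholders):

model.add_design_matrix(X_train, dims=[25], df=[8], smooth='cr',
                        name='stimulus', kind='train')
model.add_design_matrix(y_train, dims=[20], df=[8], smooth='cr',
                        shift=1, name='history', kind='train')
# dev set: dims, shift and the spline basis are looked up by `name`
model.add_design_matrix(X_dev, kind='dev', name='stimulus')
model.add_design_matrix(y_dev, kind='dev', name='history')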