def predict(self, X, y=None, p=None):
    """
    Run the model forward on new data and return its prediction.

    Parameters
    ==========
    X : array_like, shape (n_samples, n_features)
        Stimulus design matrix.

    y : None or array_like, shape (n_samples, )
        Recorded response. Required whenever a response-history filter
        has been initialised (``self.h_mle`` is not None).

    p : None or dict
        Model parameters. When omitted, the optimised parameters
        ``self.p_opt`` are used. Only needed if model performance is
        monitored during training.

    Returns
    =======
    Whatever ``self.forwardpass`` returns for the given parameters.

    Raises
    ======
    ValueError
        If a history filter exists but `y` was not supplied.
    """
    inputs = {'X': X, 'y': y}

    uses_history = self.h_mle is not None
    if uses_history and y is None:
        raise ValueError('`y` is needed for calculating response history.')

    if uses_history:
        # Same number of lags and the same shift as the training-time
        # history design matrix stored in `self.yh`.
        n_lags = self.yh.shape[1]
        lagged = build_design_matrix(y[:, jnp.newaxis],
                                     n_lags,
                                     shift=self.shift_h)
        inputs['yh'] = jnp.array(lagged)

    if p is None:
        p = self.p_opt
    return self.forwardpass(p, extra=inputs)
def initialize_history_filter(self, dims, shift=1):
    """
    Initialise the response-history filter with a least-squares estimate.

    Builds the lagged design matrix of the training response ``self.y``
    and solves the normal equations ``(yh.T @ yh) h = yh.T @ y``.

    Parameters
    ==========
    dims : int, or list / array_like of shape (1, )
        Number of time lags `nt` of the response-history filter. Either
        a bare integer or the one-element form ``[nt, ]`` is accepted.

    shift : int
        Should be 1 or larger.

    Sets
    ====
    self.shift_h : the `shift` used here (reused by `predict`).
    self.yh : history design matrix of the training response.
    self.h_mle : least-squares estimate of the history filter.
    """
    # The docstring always advertised ``[nt, ]`` but the value was handed to
    # build_design_matrix unchanged; every other call site passes it a plain
    # integer lag count, so unwrap the one-element form here.
    n_lags = dims if jnp.ndim(dims) == 0 else dims[0]

    y = self.y
    yh = jnp.array(build_design_matrix(y[:, jnp.newaxis], n_lags, shift=shift))

    self.shift_h = shift
    self.yh = yh
    self.h_mle = jnp.linalg.solve(yh.T @ yh, yh.T @ y)
def noise2d(n_samples,
            dims,
            shift=0,
            beta=None,
            noise='gaussian',
            design_matrix=False,
            random_seed=2046):
    """
    2D noise. Gaussian white noise or checkerboard binary noise.

    Parameters
    ==========
    n_samples : int
        Number of frames to generate.

    dims : list or array_like, length 2 or 3
        With two entries it is the spatial shape of one frame. With three
        entries the first one is the number of time lags `nt` (only used
        when `design_matrix=True`) and the remaining two are the spatial
        shape.

    shift : int
        Passed to `build_design_matrix` when a design matrix is built.

    beta : None or float
        If given, the Gaussian noise is passed through `colornoise2d`
        with this exponent.

    noise : str
        'gaussian' or 'binary'.

    design_matrix : bool
        If True and `dims` has three entries, return the output of
        `build_design_matrix` instead of the raw frames.

    random_seed : int
        Seed for NumPy's global random generator, also forwarded to
        `colornoise2d`.

    Returns
    =======
    X : array
        Noise normalised to zero mean and unit standard deviation, shape
        (n_samples, *spatial dims) unless a design matrix was requested.

    Raises
    ======
    NotImplementedError
        If `dims` has a length other than 2 or 3, or `noise` is unknown.
    ValueError
        If `beta` is given together with non-Gaussian noise.
    """
    if len(dims) == 2:
        nt = None
    elif len(dims) == 3:
        nt, dims = dims[0], dims[1:]
    else:
        raise NotImplementedError(len(dims))

    # Seed like the sibling generators (flickerbar / flickerfield) do;
    # without this `random_seed` had no effect on the white-noise draw.
    np.random.seed(random_seed)

    if noise == 'gaussian':
        X = np.random.randn(n_samples, *dims)
    elif noise == 'binary':
        X = np.random.choice([-1, 1], size=[n_samples, *dims])
    else:
        raise NotImplementedError(noise)

    if beta is not None:
        if noise != 'gaussian':
            raise ValueError('1/f noise only applis to Gaussian noise.')
        X = colornoise2d(n_samples=n_samples,
                         dims=dims,
                         beta=beta,
                         phi=X,
                         random_seed=random_seed)

    X = (X - X.mean()) / X.std()

    # Without a temporal dimension there is nothing to lag.
    if design_matrix and nt is not None:
        X = build_design_matrix(X, nt, shift)

    return X
def predict(self, X, y=None, p=None):
    """
    Run the spline-based model forward on new data.

    Parameters
    ==========
    X : array_like, shape (n_samples, n_features)
        Stimulus design matrix. When ``self.n_c > 1`` it is indexed as
        ``X[:, :, i]`` for each of the ``n_c`` slices.

    y : None or array_like, shape (n_samples, )
        Recorded response. Required whenever a response-history filter
        has been initialised (``self.h_spl`` is not None).

    p : None or dict
        Model parameters. When omitted, the optimised parameters
        ``self.p_opt`` are used. Only needed if model performance is
        monitored during training.

    Returns
    =======
    Whatever ``self.forwardpass`` returns for the given parameters.

    Raises
    ======
    ValueError
        If a history filter exists but `y` was not supplied.
    """
    # Project the stimulus onto the spline basis `self.S`.
    if self.n_c > 1:
        projected = [X[:, :, c] @ self.S for c in range(self.n_c)]
        XS = jnp.dstack(projected).reshape(X.shape[0], -1)
    else:
        XS = X @ self.S

    inputs = {'X': X, 'XS': XS, 'y': y}

    if self.h_spl is not None:
        if y is None:
            raise ValueError(
                '`y` is needed for calculating response history.')
        # Lagged response projected onto the history spline basis `self.Sh`.
        n_lags = self.Sh.shape[0]
        lagged = jnp.array(
            build_design_matrix(y[:, jnp.newaxis], n_lags,
                                shift=self.shift_h))
        inputs['yS'] = lagged @ self.Sh

    chosen = p if p is not None else self.p_opt
    return self.forwardpass(chosen, extra=inputs)
def initialize_history_filter(self, dims, df, smooth='cr', shift=1):
    """
    Initialise a spline-based response-history filter by least squares.

    Parameters
    ==========
    dims : int, or list / array_like of shape (1, )
        Number of time lags `nt` of the response-history filter. Either
        a bare integer or the one-element form ``[nt, ]`` is accepted.

    df : int, or list / array_like of shape (1, )
        Number of basis. Same accepted forms as `dims`.

    smooth : str
        Type of basis.

    shift : int
        Should be 1 or larger.

    Sets
    ====
    self.shift_h : the `shift` used here (reused by `predict`).
    self.yh : history design matrix of the training response.
    self.Sh : spline basis for the history filter.
    self.yS : ``yh @ Sh``.
    self.bh_spl : least-squares spline coefficients.
    self.h_spl : history filter, ``Sh @ bh_spl``.
    """
    # Both values are wrapped in a one-element list below, so the documented
    # list form used to turn into a nested list. Unwrap it first.
    n_lags = dims if jnp.ndim(dims) == 0 else dims[0]
    n_basis = df if jnp.ndim(df) == 0 else df[0]

    y = self.y
    Sh = jnp.array(build_spline_matrix([n_lags, ], [n_basis, ], smooth))  # for h
    yh = jnp.array(
        build_design_matrix(y[:, jnp.newaxis], Sh.shape[0], shift=shift))
    yS = yh @ Sh

    self.shift_h = shift
    self.yh = yh
    self.Sh = Sh  # spline basis for spike-history
    self.yS = yS
    self.bh_spl = jnp.linalg.solve(yS.T @ yS, yS.T @ y)
    self.h_spl = Sh @ self.bh_spl
def flickerbar(n_samples,
               dims,
               shift=0,
               beta=None,
               noise='gaussian',
               design_matrix=False,
               random_seed=2046):
    """
    Flicker bar.

    Parameters
    ==========
    n_samples : int
        Number of frames to generate.

    dims : list or array_like, length 2
        ``(nt, nx)``: number of time lags (only used when
        `design_matrix=True`) and number of bars per frame.

    shift : int
        Passed to `build_design_matrix` when a design matrix is built.

    beta : None or float
        If given, the Gaussian noise is passed through `colornoise1d`
        with this exponent.

    noise : str
        'gaussian' or 'binary'.

    design_matrix : bool
        If True, return the output of `build_design_matrix`.

    random_seed : int
        Seed for NumPy's global random generator, also forwarded to
        `colornoise1d`.

    Returns
    =======
    X : array
        Noise normalised to zero mean and unit standard deviation, shape
        (n_samples, nx) unless a design matrix was requested.
    """
    nt, nx = dims
    frame_shape = [n_samples, nx]

    np.random.seed(random_seed)
    if noise == 'gaussian':
        X = np.random.randn(*frame_shape)
    elif noise == 'binary':
        X = np.random.choice([-1, 1], size=frame_shape)
    else:
        raise NotImplementedError(noise)

    wants_colored = beta is not None
    if wants_colored and noise != 'gaussian':
        raise ValueError('1/f noise only applis to Gaussian noise.')
    if wants_colored:
        X = colornoise1d(n_samples=n_samples,
                         dims=nx,
                         beta=beta,
                         phi=X,
                         random_seed=random_seed)

    centered = X - X.mean()
    X = centered / X.std()

    if not design_matrix:
        return X
    return build_design_matrix(X, nt, shift)
def flickerfield(n_samples,
                 dims=None,
                 shift=0,
                 beta=None,
                 noise='gaussian',
                 design_matrix=False,
                 random_seed=2046):
    """
    Full field flicker.

    Parameters
    ==========
    n_samples : int
        Number of frames to generate.

    dims : None or int
        Passed as the second argument of `build_design_matrix`
        (presumably the number of time lags). Required when
        `design_matrix=True`.

    shift : int
        Passed to `build_design_matrix` when a design matrix is built.

    beta : None or float
        If given, the noise is passed through `colornoise1d` with this
        exponent.

    noise : str
        'gaussian' or 'binary'.

    design_matrix : bool
        If True, return the output of `build_design_matrix`.

    random_seed : int
        Seed for NumPy's global random generator, also forwarded to
        `colornoise1d`.

    Returns
    =======
    X : array
        Noise normalised to zero mean and unit standard deviation, shape
        (n_samples, 1) unless a design matrix was requested.

    Raises
    ======
    NotImplementedError
        If `noise` is unknown.
    ValueError
        If `design_matrix=True` but `dims` is None.
    """
    np.random.seed(random_seed)

    if noise == 'gaussian':
        X = np.random.randn(n_samples)[:, np.newaxis]
    elif noise == 'binary':
        X = np.random.choice([-1, 1], size=n_samples)[:, np.newaxis]
    else:
        raise NotImplementedError(noise)

    if beta is not None:
        # Forward the caller's seed; it used to be hard-coded to 2046 here,
        # silently ignoring `random_seed` for coloured noise.
        X = colornoise1d(1,
                         dims=n_samples,
                         beta=beta,
                         phi=X.flatten(),
                         random_seed=random_seed)[:, np.newaxis]

    X = (X - X.mean()) / X.std()

    if design_matrix:
        if dims is None:
            raise ValueError(
                '`dims` is needed for building stimulus design matrix.')
        X = build_design_matrix(X, dims, shift)

    return X
def fit(self,
        p0=None,
        extra=None,
        num_subunits=2,
        num_epochs=1,
        num_iters=3000,
        initialize='random',
        metric=None,
        alpha=1,
        beta=0.05,
        fit_linear_filter=True,
        fit_intercept=True,
        fit_R=True,
        fit_history_filter=False,
        fit_nonlinearity=False,
        step_size=1e-2,
        tolerance=10,
        verbose=100,
        random_seed=2046,
        return_model=None):
    """
    Fit the spline-based subunit model.

    Parameters
    ==========
    p0 : None or dict
        Initial parameters. Any of the keys 'b', 'intercept', 'R', 'bh'
        and 'nl_params' that are missing are filled in with defaults.
        The dict passed in is copied, not modified; the completed dict
        is stored in ``self.p0``.

    extra : None or dict
        Development set, ``{'X': X_dev, 'y': y_dev}``. The projections
        'XS' (and 'yS' when a history filter exists) are added to a copy.

    num_subunits : int
        Number of subunits, stored as ``self.n_s``.

    initialize : str
        If 'random', 'b' is drawn as 0.01 * standard normal noise.

    alpha : float
        Elastic net parameter (1=L1, 0=L2).

    beta : float
        Elastic net parameter - global penalty weight.

    fit_linear_filter, fit_intercept, fit_R, fit_history_filter,
    fit_nonlinearity : bool
        Stored on the instance; after optimisation they also decide which
        optimised quantities are copied onto the instance.

    random_seed : int
        Seed of the JAX PRNG key used for the random initialisation.

    num_epochs, num_iters, metric, step_size, tolerance, verbose,
    return_model :
        Passed through to ``self.optimize_params``.
    """
    self.metric = metric
    self.alpha = alpha  # elastic net parameter (1=L1, 0=L2)
    self.beta = beta  # elastic net parameter - global penalty weight

    self.n_s = num_subunits
    self.num_iters = num_iters

    self.fit_linear_filter = fit_linear_filter
    self.fit_history_filter = fit_history_filter
    self.fit_nonlinearity = fit_nonlinearity
    self.fit_intercept = fit_intercept
    self.fit_R = fit_R

    # initialize parameters
    # Copy so that defaults are not written into the caller's dict.
    p0 = {} if p0 is None else dict(p0)

    if 'b' not in p0 and initialize == 'random':
        # not necessary, but for consistency with others.
        key = random.PRNGKey(random_seed)
        b0 = 0.01 * random.normal(
            key, shape=(self.n_b * self.n_c * self.n_s, )).flatten()
        p0['b'] = b0

    if 'intercept' not in p0:
        p0['intercept'] = jnp.zeros(1)

    if 'R' not in p0:
        p0['R'] = jnp.array([1.])

    if 'bh' not in p0:
        # `bh_spl` only exists once a history filter was initialised; a
        # missing attribute means "no history filter" rather than an error.
        p0['bh'] = getattr(self, 'bh_spl', None)

    if 'nl_params' not in p0:
        # One entry per subunit plus one more (n_s + 1 in total);
        # each is None when no nonlinearity parameters are set.
        p0['nl_params'] = [self.nl_params for _ in range(self.n_s + 1)]

    if extra is not None:
        extra = dict(extra)  # do not add keys to the caller's dict
        X_dev = extra['X']
        if self.n_c > 1:
            extra['XS'] = jnp.dstack([
                X_dev[:, :, i] @ self.S for i in range(self.n_c)
            ]).reshape(X_dev.shape[0], -1)
        else:
            extra['XS'] = X_dev @ self.S

        if self.h_spl is not None:
            # Use the shift the history filter was initialised with (as
            # `predict` does) instead of a hard-coded 1.
            yh = jnp.array(
                build_design_matrix(extra['y'][:, jnp.newaxis],
                                    self.Sh.shape[0],
                                    shift=getattr(self, 'shift_h', 1)))
            extra['yS'] = yh @ self.Sh

        extra = {key: jnp.array(value) for key, value in extra.items()}

    self.p0 = p0
    self.p_opt = self.optimize_params(p0, extra, num_epochs, num_iters,
                                      metric, step_size, tolerance, verbose,
                                      return_model)
    self.R = self.p_opt['R'] if fit_R else jnp.array([1.])

    if fit_linear_filter:
        self.b_opt = self.p_opt['b']
        if self.n_c > 1:
            b_cube = self.b_opt.reshape(self.n_b, self.n_c, self.n_s)
            self.w_opt = jnp.stack(
                [self.S @ b_cube[:, :, i] for i in range(self.n_s)],
                axis=-1)
        else:
            self.w_opt = self.S @ self.b_opt.reshape(self.n_b, self.n_s)

    if fit_history_filter:
        self.bh_opt = self.p_opt['bh']
        self.h_opt = self.Sh @ self.bh_opt

    if fit_intercept:
        self.intercept = self.p_opt['intercept']

    if fit_nonlinearity:
        self.nl_params_opt = self.p_opt['nl_params']
def fit(self,
        p0=None,
        extra=None,
        initialize='random',
        num_epochs=1,
        num_iters=5,
        metric=None,
        alpha=1,
        beta=0.05,
        fit_linear_filter=True,
        fit_intercept=True,
        fit_R=True,
        fit_history_filter=False,
        fit_nonlinearity=False,
        step_size=1e-2,
        tolerance=10,
        verbose=1,
        random_seed=2046,
        return_model=None):
    """
    Fit the model by gradient-based optimisation.

    Parameters
    ==========

    p0 : None or dict
        Initial parameters. Missing keys are filled in with defaults:
        * 'w': Initial linear filter.
        * 'intercept': Initial intercept (zero).
        * 'R': Initial value of `R` (one); only added when `fit_R`.
        * 'h': Initial response history filter.
        * 'nl_params': Initial nonlinearity parameters.
        The dict passed in is copied, not modified; the completed dict
        is stored in ``self.p0``.

    extra : None or dict {'X': X_dev, 'y': y_dev}
        Development set. When a history filter exists, its lagged
        response 'yh' is added to a copy.

    initialize : None or str
        Parametric initialization.
        * if `initialize=None`, `w` will be initialized by STA.
        * if `initialize='random'`, `w` will be randomly initialized.

    num_iters : int
        Max number of optimization iterations.

    metric : None or str
        Extra cross-validation metric. Default is `None`. Or
        * 'mse': mean squared error
        * 'r2': R2 score
        * 'corrcoef': Correlation coefficient

    alpha : float, from 0 to 1.
        Elastic net parameter, balance between L1 and L2 regularization.
        * 0.0 -> only L2
        * 1.0 -> only L1

    beta : float
        Elastic net parameter, overall weight of regularization.

    step_size : float
        Initial step size for JAX optimizer (ADAM).

    tolerance : int
        Set early stop tolerance. Optimization stops when cost (dev)
        monotonically increases or cost (train) stop increases for
        tolerance=n steps. If `tolerance=0`, then early stop is not used.

    verbose : int
        When `verbose=0`, progress is not printed. When `verbose=n`,
        progress will be printed in every n steps.

    random_seed : int
        Seed of the JAX PRNG key used for random initialisation.

    num_epochs, return_model :
        Passed through to ``self.optimize_params``.
    """
    self.metric = metric  # metric for cross-validation and prediction

    self.alpha = alpha
    self.beta = beta  # elastic net parameter - global penalty weight for linear filter
    self.num_iters = num_iters

    self.fit_R = fit_R
    self.fit_linear_filter = fit_linear_filter
    self.fit_history_filter = fit_history_filter
    self.fit_nonlinearity = fit_nonlinearity
    self.fit_intercept = fit_intercept

    # initialize parameters
    # Copy so that defaults are not written into the caller's dict.
    p0 = {} if p0 is None else dict(p0)

    if 'w' not in p0:
        if initialize is None:
            p0['w'] = self.w_sta
        elif initialize == 'random':
            key = random.PRNGKey(random_seed)
            w0 = 0.01 * random.normal(
                key, shape=(self.w_sta.shape[0], )).flatten()
            p0['w'] = w0

    if 'intercept' not in p0:
        p0['intercept'] = jnp.array([0.])

    if 'R' not in p0 and self.fit_R:
        p0['R'] = jnp.array([1.])

    if 'h' not in p0:
        if initialize is None and self.h_mle is not None:
            p0['h'] = self.h_mle
        elif initialize == 'random' and self.h_mle is not None:
            # NOTE(review): same seed, hence the same PRNG key, as the `w`
            # draw above.
            key = random.PRNGKey(random_seed)
            h0 = 0.01 * random.normal(
                key, shape=(self.h_mle.shape[0], )).flatten()
            p0['h'] = h0
        else:
            p0['h'] = None

    if 'nl_params' not in p0:
        # None when no nonlinearity parameters have been set.
        p0['nl_params'] = self.nl_params

    if extra is not None:
        extra = dict(extra)  # do not add keys to the caller's dict
        if self.h_mle is not None:
            # Use the shift the history filter was initialised with (as
            # `predict` does) instead of a hard-coded 1.
            yh = jnp.array(
                build_design_matrix(extra['y'][:, jnp.newaxis],
                                    self.yh.shape[1],
                                    shift=getattr(self, 'shift_h', 1)))
            extra['yh'] = yh

        extra = {key: jnp.array(value) for key, value in extra.items()}

    # store optimized parameters
    self.p0 = p0
    self.p_opt = self.optimize_params(p0, extra, num_epochs, num_iters,
                                      metric, step_size, tolerance, verbose,
                                      return_model)
    self.R = self.p_opt['R'] if fit_R else jnp.array([1.])

    if fit_linear_filter:
        self.w_opt = self.p_opt['w']

    if fit_history_filter:
        self.h_opt = self.p_opt['h']

    if fit_nonlinearity:
        self.nl_params_opt = self.p_opt['nl_params']

    if fit_intercept:
        self.intercept = self.p_opt['intercept']
def fit(self,
        p0=None,
        extra=None,
        initialize='random',
        num_epochs=1,
        num_iters=3000,
        metric=None,
        alpha=1,
        beta=0.05,
        fit_linear_filter=True,
        fit_intercept=True,
        fit_R=True,
        fit_history_filter=False,
        fit_nonlinearity=False,
        step_size=1e-2,
        tolerance=10,
        verbose=100,
        random_seed=2046,
        return_model=None):
    """
    Fit the spline-based model by gradient-based optimisation.

    Parameters
    ==========

    p0 : None or dict
        Initial parameters. Missing keys are filled in with defaults:
        * 'b': Initial spline coefficients.
        * 'intercept': Initial intercept (zero).
        * 'R': Initial value of `R` (one).
        * 'bh': Initial response history filter coefficients
        * 'nl_params': Initial nonlinearity parameters.
        The dict passed in is copied, not modified; the completed dict
        is stored in ``self.p0``.

    extra : None or dict {'X': X_dev, 'y': y_dev}
        Development set. The projections 'XS' (and 'yS' when a history
        filter exists) are added to a copy, which is stored in
        ``self.extra``.

    initialize : None or str
        Parametric initialization.
        * if `initialize=None`, `b` will be initialized by b_spl.
        * if `initialize='random'`, `b` will be randomly initialized.

    num_iters : int
        Max number of optimization iterations.

    metric : None or str
        Extra cross-validation metric. Default is `None`. Or
        * 'mse': mean squared error
        * 'r2': R2 score
        * 'corrcoef': Correlation coefficient

    alpha : float, from 0 to 1.
        Elastic net parameter, balance between L1 and L2 regularization.
        * 0.0 -> only L2
        * 1.0 -> only L1

    beta : float
        Elastic net parameter, overall weight of regularization for
        receptive field.

    step_size : float
        Initial step size for JAX optimizer.

    tolerance : int
        Set early stop tolerance. Optimization stops when cost
        monotonically increases or stop increases for tolerance=n steps.

    verbose : int
        When `verbose=0`, progress is not printed. When `verbose=n`,
        progress will be printed in every n steps.

    random_seed : int
        Seed of the JAX PRNG key used for random initialisation.

    num_epochs, return_model :
        Passed through to ``self.optimize_params``.
    """
    self.metric = metric

    self.alpha = alpha
    self.beta = beta  # elastic net parameter - global penalty weight for linear filter
    self.num_iters = num_iters

    self.fit_R = fit_R
    self.fit_linear_filter = fit_linear_filter
    self.fit_history_filter = fit_history_filter
    self.fit_nonlinearity = fit_nonlinearity
    self.fit_intercept = fit_intercept

    # initial parameters
    # Copy so that defaults are not written into the caller's dict.
    p0 = {} if p0 is None else dict(p0)

    if 'b' not in p0:
        if initialize is None:
            p0['b'] = self.b_spl
        elif initialize == 'random':
            key = random.PRNGKey(random_seed)
            b0 = 0.01 * random.normal(
                key, shape=(self.n_b * self.n_c, )).flatten()
            p0['b'] = b0

    if 'intercept' not in p0:
        p0['intercept'] = jnp.array([0.])

    if 'R' not in p0:
        p0['R'] = jnp.array([1.])

    if 'bh' not in p0:
        if initialize is None and self.bh_spl is not None:
            p0['bh'] = self.bh_spl
        elif initialize == 'random' and self.bh_spl is not None:
            # NOTE(review): same seed, hence the same PRNG key, as the `b`
            # draw above.
            key = random.PRNGKey(random_seed)
            bh0 = 0.01 * random.normal(
                key, shape=(len(self.bh_spl), )).flatten()
            p0['bh'] = bh0
        else:
            p0['bh'] = None

    if 'nl_params' not in p0:
        # None when no nonlinearity parameters have been set.
        p0['nl_params'] = self.nl_params

    if extra is not None:
        extra = dict(extra)  # do not add keys to the caller's dict
        X_dev = extra['X']
        if self.n_c > 1:
            extra['XS'] = jnp.dstack([
                X_dev[:, :, i] @ self.S for i in range(self.n_c)
            ]).reshape(X_dev.shape[0], -1)
        else:
            extra['XS'] = X_dev @ self.S

        if self.h_spl is not None:
            # Use the shift the history filter was initialised with (as
            # `predict` does) instead of a hard-coded 1.
            yh_ext = jnp.array(
                build_design_matrix(extra['y'][:, jnp.newaxis],
                                    self.Sh.shape[0],
                                    shift=getattr(self, 'shift_h', 1)))
            extra['yS'] = yh_ext @ self.Sh

        extra = {key: jnp.array(value) for key, value in extra.items()}

    self.extra = extra  # store for cross-validation

    # store optimized parameters
    self.p0 = p0
    self.p_opt = self.optimize_params(p0, extra, num_epochs, num_iters,
                                      metric, step_size, tolerance, verbose,
                                      return_model)
    self.R = self.p_opt['R'] if fit_R else jnp.array([1.])

    if fit_linear_filter:
        self.b_opt = self.p_opt['b']  # optimized RF basis coefficients
        if self.n_c > 1:
            self.w_opt = self.S @ self.b_opt.reshape(self.n_b, self.n_c)
        else:
            self.w_opt = self.S @ self.b_opt  # optimized RF

    if fit_history_filter:
        self.bh_opt = self.p_opt['bh']
        self.h_opt = self.Sh @ self.bh_opt

    if fit_nonlinearity:
        self.nl_params_opt = self.p_opt['nl_params']

    if fit_intercept:
        self.intercept = self.p_opt['intercept']
# Number of spline bases per dimension: two fewer than each entry of `dims`.
df = [size - 2 for size in dims]

# Take index 0 along the second axis of the stimulus, restrict the last two
# axes to the crop window, run the result through `f`, and bring it back as
# a NumPy array.
stim = (
    f(sample['stim'][:, 0, crop[2]:crop[3], crop[0]:crop[1]])
    .detach()
    .cpu()
    .clone()
    .numpy()
)
robs = sample['robs'].detach().cpu().numpy()

del gd
# dims = [25, 20, 15]

#%%
from rfest import ALD

#%%
robs.shape

#%%
# stim.shape
X = build_design_matrix(stim, dims[0])
X.shape

#%%
import copy

dt = 1 / 240
df = [5, 15, 15]

# Initial hyperparameters, concatenated into one flat list below.
sigma0 = [1.]
rho0 = [1.]
# taus, nus, tauf, nuf
params_t0 = [3., 20., 1., 1.]
params_y0 = [3., 20., 1., 1.]
params_x0 = [3., 20., 1., 1.]
p0 = [*sigma0, *rho0, *params_t0, *params_y0, *params_x0]

lsgs = []
def fit(self,
        p0=None,
        extra=None,
        num_subunits=2,
        num_epochs=1,
        num_iters=3000,
        metric=None,
        alpha=1,
        beta=0.05,
        fit_linear_filter=True,
        fit_intercept=True,
        fit_R=True,
        fit_history_filter=False,
        fit_nonlinearity=False,
        step_size=1e-2,
        tolerance=10,
        verbose=100,
        random_seed=2046,
        return_model=None):
    """
    Fit the subunit model by gradient-based optimisation.

    Parameters
    ==========
    p0 : None or dict
        Initial parameters. Any of the keys 'w', 'intercept', 'R' (only
        when `fit_R`), 'h' and 'nl_params' that are missing are filled in
        with defaults. The dict passed in is copied, not modified; the
        completed dict is stored in ``self.p0``.

    extra : None or dict {'X': X_dev, 'y': y_dev}
        Development set. When a history filter exists, its lagged
        response 'yh' is added to a copy.

    num_subunits : int
        Number of subunits, stored as ``self.n_s``.

    alpha : float
        Elastic net parameter (1=L1, 0=L2).

    beta : float
        Elastic net parameter - global penalty weight for linear filter.

    fit_linear_filter, fit_intercept, fit_R, fit_history_filter,
    fit_nonlinearity : bool
        Stored on the instance; after optimisation they also decide which
        optimised quantities are copied onto the instance.

    random_seed : int
        Seed of the JAX PRNG key used to draw the initial 'w'.

    num_epochs, num_iters, metric, step_size, tolerance, verbose,
    return_model :
        Passed through to ``self.optimize_params``.
    """
    self.metric = metric

    self.alpha = alpha  # elastic net parameter (1=L1, 0=L2)
    self.beta = beta  # elastic net parameter - global penalty weight for linear filter

    self.n_s = num_subunits
    self.num_iters = num_iters

    self.fit_R = fit_R
    self.fit_linear_filter = fit_linear_filter
    self.fit_history_filter = fit_history_filter
    self.fit_nonlinearity = fit_nonlinearity
    self.fit_intercept = fit_intercept

    if extra is not None:
        extra = dict(extra)  # do not add keys to the caller's dict
        if self.h_mle is not None:
            # Use the shift the history filter was initialised with
            # instead of a hard-coded 1.
            yh = jnp.array(
                build_design_matrix(extra['y'][:, jnp.newaxis],
                                    self.yh.shape[1],
                                    shift=getattr(self, 'shift_h', 1)))
            extra['yh'] = yh

        extra = {key: jnp.array(value) for key, value in extra.items()}

    # initialize parameters
    # Copy so that defaults are not written into the caller's dict.
    p0 = {} if p0 is None else dict(p0)

    if 'w' not in p0:
        key = random.PRNGKey(random_seed)
        w0 = 0.01 * random.normal(
            key, shape=(self.n_features * self.n_c * self.n_s, )).flatten()
        p0['w'] = w0

    if 'intercept' not in p0:
        p0['intercept'] = jnp.array([0.])

    if 'R' not in p0 and self.fit_R:
        p0['R'] = jnp.array([1.])

    if 'h' not in p0:
        # A missing `h_mle` attribute means "no history filter".
        p0['h'] = getattr(self, 'h_mle', None)

    if 'nl_params' not in p0:
        # One entry per subunit plus one more (n_s + 1 in total);
        # each is None when no nonlinearity parameters are set.
        p0['nl_params'] = [self.nl_params for _ in range(self.n_s + 1)]

    self.p0 = p0
    self.p_opt = self.optimize_params(p0, extra, num_epochs, num_iters,
                                      metric, step_size, tolerance, verbose,
                                      return_model)

    # 'R' is only part of the parameters when it is being fitted; fall back
    # to one otherwise instead of indexing a key that was never added.
    self.R = self.p_opt['R'] if fit_R else jnp.array([1.])

    if fit_linear_filter:
        if self.n_c > 1:
            self.w_opt = self.p_opt['w'].reshape(self.n_features, self.n_c,
                                                 self.n_s)
        else:
            self.w_opt = self.p_opt['w'].reshape(self.n_features, self.n_s)

    if fit_history_filter:
        self.h_opt = self.p_opt['h']

    if fit_nonlinearity:
        self.nl_params_opt = self.p_opt['nl_params']

    if fit_intercept:
        self.intercept = self.p_opt['intercept']
def add_design_matrix(self, X, dims=None, df=None, smooth=None, lag=True, filter_nonlinearity='none', kind='train', name='stimulus', shift=0, burn_in=None): """ Add input design matrix to the model. Parameters ---------- X: jnp.array, shape=(n_samples, ) or (n_samples, n_pixels) Original input. dims: int, or list / jnp.array, shape=dim_t, or (dim_t, dim_x, dim_y) Filter shape. df: None, int, or list / jnp.array Number of spline bases. Should be the same shape as dims. smooth: None, or str Type of spline bases. If None, no basis is used. lag: bool If True, the design matrix will be build based on the dims[0]. If False, a instantaneous RF will be fitted. filter_nonlinearity: str Nonlinearity for the stimulus filter. kind: str Dataset type, should be one of `train` (training set), `dev` (validation set) or `test` (testing set). name: str Name of the corresponding filter. A receptive field (stimulus) filter should have `stimulus` in the name. A response-history filter should have `history` in the name. shift: int Time offset for the design matrix, positive number will shift the design matrix to the past, negative number will shift it to the future. burn_in: int or None Number of samples / frames to be ignored for prediction. (Because the first few frames in the design matrix are full of zero, which tend to predict poorly.) 
""" # check X shape if len(X.shape) == 1: X = X[:, jnp.newaxis].astype(self.dtype) else: X = X.astype(self.dtype) if kind not in self.X: self.X.update({kind: {}}) if kind == 'train': self.filter_nonlinearity[name] = filter_nonlinearity self.filter_names.append(name) dims = dims if type(dims) is not int else [ dims, ] self.dims[name] = dims self.shift[name] = shift else: dims = self.dims[name] shift = self.shift[name] if self.burn_in is None: # if exists, ignore self.burn_in = dims[ 0] - 1 if burn_in is None else burn_in # number of first few frames to ignore if lag: self.X[kind][name] = build_design_matrix( X, dims[0], shift=shift, dtype=self.dtype)[self.burn_in:] else: self.burn_in = 0 self.X[kind][ name] = X # if not time lag, shouldn't it also be no burn in? # TODO: might need different handlings for instantaneous RF. # conflict: history filter burned-in but the stimulus filter didn't if smooth is None: # if train set exists and used spline as basis # automatically apply the same basis for dev/test set if name in self.S: if kind not in self.XS: self.XS.update({kind: {}}) S = self.S[name] self.XS[kind][name] = self.X[kind][name] @ S elif kind == 'test': if kind not in self.XS: self.XS.update({kind: {}}) if self.num_subunits is not None and self.num_subunits > 1: S = self.S['stimulus_s0'] else: S = self.S[name] self.XS[kind][name] = self.X[kind][name] @ S else: if kind == 'train': self.n_features[name] = self.X['train'][name].shape[1] else: # use spline if kind not in self.XS: self.XS.update({kind: {}}) self.df[name] = df if type(df) is not int else [ df, ] S = build_spline_matrix(dims=dims, df=self.df[name], smooth=smooth, dtype=self.dtype) self.S[name] = S XS = self.X[kind][name] @ S self.XS[kind][name] = XS if kind == 'train': self.n_features[name] = self.XS['train'][name].shape[1]