def get_out_dtb1ds(self, ev: torch.Tensor, return_unabs: bool):
    """
    Run every 1D model in ``self.dtb1ds`` on its own slice of the evidence.

    When ``self.to_allow_irr_ixn`` is set, the mean evidence of each
    dimension is first augmented with the signed and the absolute mean
    evidence of the other dimension, each scaled by a kappa ratio.

    :param ev: [cond, fr, dim, (mean, var)]
    :param return_unabs: forwarded to each 1D model; when true, each model
        is expected to return a pair whose items [0] and [1] are stacked
        separately.
    :return: p_dim_cond_td_ch, unabs_dim_td_cond_ev
        (the second item is None when ``return_unabs`` is false)
    """
    ev_eff = ev.clone()

    if self.to_allow_irr_ixn:
        for dim in range(consts.N_DIM):
            odim = consts.get_odim(dim)

            kappa_own = self.dtb1ds[dim].kappa[:]
            kappa_other = self.dtb1ds[odim].kappa[:]
            gain_signed = self.kappa_rel_odim[odim] / kappa_own * kappa_other
            gain_abs = (self.kappa_rel_abs_odim[odim]
                        / kappa_own * kappa_other)

            # Always read from the untouched input so that the update of one
            # dimension does not leak into the update of the other.
            ev_eff[:, :, dim, 0] = (
                ev[:, :, dim, 0]
                + ev[:, :, odim, 0] * gain_signed
                + ev[:, :, odim, 0].abs() * gain_abs
            )

    # Bring dim to the front so that it can be iterated with the models.
    ev_by_dim = ev_eff.permute([2, 0, 1, 3])
    outs = []
    for ev_dim, dtb in zip(ev_by_dim, self.dtb1ds):
        outs.append(dtb(ev_dim, return_unabs=return_unabs))

    if not return_unabs:
        # p_dim_cond__ch_td[dim, cond, ch, td] = P(ch_dim, td_dim | cond)
        return torch.stack(outs), None

    p_dim_cond_td_ch = torch.stack([out[0] for out in outs])
    unabs_dim_td_cond_ev = torch.stack([out[1] for out in outs])
    return p_dim_cond_td_ch, unabs_dim_td_cond_ev
def get_coefs_irr_ixn_from_histogram(
        p_cond_dur_ch: np.ndarray, ev_cond_dim: np.ndarray
) -> (np.ndarray, np.ndarray, np.ndarray, np.ndarray):
    """
    Fit one binomial GLM per (relevant dimension, duration) that regresses
    the choice on the relevant evidence, the irrelevant evidence, and their
    interactions.

    :param p_cond_dur_ch: choice histogram with durations on axis 1; it is
        reshaped here to [cond, dur, N_CH, N_CH].
    :param ev_cond_dim: [cond, dim]
    :return: (coef, se_coef, glmress, glmmodels)
        coef[k, dim, dur] and se_coef[k, dim, dur] hold the fitted
        parameters and their standard errors; entries that are not filled
        stay NaN. The regressors passed to ``sm.add_constant`` are, in
        order: rel, rel x abs(irr), rel x irr, abs(irr), irr.
        glmress[dim, dur] and glmmodels[dim, dur] are object arrays with
        the fit results and the models.
    """
    n_dim = ev_cond_dim.shape[1]
    n_dur = p_cond_dur_ch.shape[1]
    siz = [n_dim, n_dur]

    # constant, rel, rel x abs(irr), rel x irr, abs(irr), irr
    n_coef = 6
    coef = np.full([n_coef] + siz, np.nan)
    se_coef = np.full([n_coef] + siz, np.nan)
    # ``np.object`` was removed from NumPy; the builtin ``object`` is the
    # equivalent dtype and works on every NumPy version.
    glmress = np.empty(siz, dtype=object)
    glmmodels = np.empty(siz, dtype=object)

    p_cond_dur_ch = p_cond_dur_ch.reshape([-1, n_dur] + [consts.N_CH] * 2)

    for dim_rel in range(n_dim):
        dim_irr = consts.get_odim(dim_rel)
        ev_rel = ev_cond_dim[:, dim_rel]
        ev_irr = ev_cond_dim[:, dim_irr]

        # The design matrix depends only on the relevant dimension,
        # so build it once rather than once per duration.
        reg = np.stack([
            ev_rel,
            ev_rel * np.abs(ev_irr),
            ev_rel * ev_irr,
            np.abs(ev_irr),
            ev_irr,
        ], -1)
        reg = sm.add_constant(reg)
        n_coef1 = reg.shape[1]

        # Marginalize over the choice axis that is not indexed by dim_rel.
        ax_marginalized = -1 if dim_rel == 0 else -2

        for idur in range(n_dur):
            p_cond_ch = p_cond_dur_ch[:, idur, :, :].sum(ax_marginalized)

            # Flip the choice axis so that choice 1 is the first column
            # of the two-column binomial response.
            glmmodel = sm.GLM(np.flip(p_cond_ch, -1), reg,
                              family=sm.families.Binomial())
            glmres = glmmodel.fit()

            coef[:n_coef1, dim_rel, idur] = glmres.params
            se_coef[:n_coef1, dim_rel, idur] = glmres.bse
            glmress[dim_rel, idur] = glmres
            glmmodels[dim_rel, idur] = glmmodel
    return coef, se_coef, glmress, glmmodels
def upsample_ev(ev_cond_fr_dim_meanvar: torch.Tensor,
                dim_rel: int,
                steps=51) -> torch.Tensor:
    """
    Build an evidence tensor in which the relevant dimension is sampled on
    a fine, evenly spaced grid while the other dimension keeps its
    original unique levels.

    Levels are read from the mean evidence of the first frame only, and the
    variance slot of the output is filled with zeros.

    :param ev_cond_fr_dim_meanvar: [cond, fr, dim, (mean, var)]
    :param dim_rel: dimension to upsample
    :param steps: number of grid points along ``dim_rel``
    :return: [steps * n_levels_irr, fr, dim, (mean, var)]
    """
    src = ev_cond_fr_dim_meanvar
    n_fr = src.shape[1]

    # Unique mean-evidence levels of each dimension, from frame 0.
    levels_by_dim = [col.unique() for col in src[:, 0, :, 0].T]

    dim_irr = consts.get_odim(dim_rel)
    levels_rel = levels_by_dim[dim_rel]
    grid_rel = torch.linspace(levels_rel.min(), levels_rel.max(),
                              steps=steps)
    grid_irr = levels_by_dim[dim_irr]

    mesh_rel, mesh_irr = torch.meshgrid([grid_rel, grid_irr])
    flat_rel = mesh_rel.flatten()
    flat_irr = mesh_irr.flatten()

    # Columns are always ordered (dim 0, dim 1), whichever one is relevant.
    if dim_rel == 1:
        columns = [flat_irr, flat_rel]
    else:
        columns = [flat_rel, flat_irr]
    mean_cond_dim = torch.stack(columns, -1)

    # Repeat across frames, then append an all-zero variance slot.
    mean_cond_fr_dim = mean_cond_dim[:, None, :].expand([-1, n_fr, -1])
    return torch.stack(
        [mean_cond_fr_dim, torch.zeros_like(mean_cond_fr_dim)], -1)
def plot_rt_vs_ev(
        ev_cond: Union[torch.Tensor, np.ndarray],
        n_cond__rt_ch: Union[torch.Tensor, np.ndarray],
        dim_rel=0,
        group_dcond_irr: Iterable[Iterable[int]] = None,
        correct_only=True,
        cmap: Union[str, Callable] = 'cool',
        kw_plot=(),
        **kwargs
) -> Tuple[Sequence[Sequence[plt.Line2D]], Sequence[Sequence[float]]]:
    """
    Plot RT against the evidence of ``dim_rel``, one curve per (grouped)
    absolute evidence level of the other dimension.

    @param ev_cond: [condition, dim], or a 4-dimensional array that is
        averaged over axis 1 and reduced with ``npt.p2st`` first
    @type ev_cond: torch.Tensor
    @param n_cond__rt_ch: [condition, frame, ch]; the caller's data is not
        modified
    @type n_cond__rt_ch: torch.Tensor
    @param dim_rel: dimension whose evidence goes on the x axis
    @param group_dcond_irr: optional grouping of the irrelevant-dimension
        levels, forwarded to ``group_conds``
    @param correct_only: only True is implemented; counts of choices whose
        sign is opposite to the evidence sign on either dimension are
        zeroed before plotting
    @param cmap: colormap name, or a callable ``cmap(n)`` that returns a
        callable mapping a level index to a color
    @param kw_plot: forwarded to ``sim1d.plot_rt_vs_ev``
    @param kwargs: forwarded to ``sim1d.plot_rt_vs_ev``
    @return: (hs, rts), the lists of the values returned by each
        ``sim1d.plot_rt_vs_ev`` call
    @raise NotImplementedError: if ``correct_only`` is false
    """
    if ev_cond.ndim != 2:
        assert ev_cond.ndim == 4
        ev_cond = npt.p2st(ev_cond.mean(1))[0]
    assert n_cond__rt_ch.ndim == 3

    ev_cond = npy(ev_cond)
    # Convert before copying: torch.Tensor has no ``.copy()``, so copying
    # first raised AttributeError for tensor input. The copy protects the
    # caller's array from the in-place zeroing below.
    n_cond__rt_ch = npy(n_cond__rt_ch).copy()

    nt = n_cond__rt_ch.shape[1]
    n_ch = n_cond__rt_ch.shape[2]

    dim_irr = consts.get_odim(dim_rel)
    conds_rel, dcond_rel = np.unique(ev_cond[:, dim_rel],
                                     return_inverse=True)
    conds_irr, dcond_irr = np.unique(np.abs(ev_cond[:, dim_irr]),
                                     return_inverse=True)

    # --- Exclude wrong choice on either dim
    if correct_only:
        ch_signs = consts.ch_bool2sign(consts.CHS_ARRAY)
        for dim, ch_signs1 in enumerate(ch_signs):
            cond_signs = np.sign(ev_cond[:, dim])
            for i_ch, ch_sign1 in enumerate(ch_signs1):
                n_cond__rt_ch[cond_signs == -ch_sign1, :, i_ch] = 0.

    if group_dcond_irr is not None:
        conds_irr, dcond_irr = group_conds(conds_irr, dcond_irr,
                                           group_dcond_irr)

    if not correct_only:
        raise NotImplementedError()

    n_rel = len(conds_rel)
    n_irr = len(conds_irr)
    chs_rel = np.array(consts.CHS)[dim_rel]

    # simply split at zero
    hs = []
    rts = []
    for dcond_irr1 in range(n_irr):
        incl_irr = dcond_irr == dcond_irr1

        if isinstance(cmap, str):
            color = plt.get_cmap(cmap, n_irr)(dcond_irr1)
        else:
            color = cmap(n_irr)(dcond_irr1)

        # NOTE(review): every pass of this loop plots the same data, so
        # each irrelevant level contributes N_CH entries to hs and rts.
        # Kept as is because callers may rely on the list length -
        # confirm before collapsing the loop.
        for _ in range(consts.N_CH):
            # Sum the counts of all conditions that share a relevant level.
            n_rel_rt_ch = np.empty([n_rel, nt, n_ch])
            for dcond_rel1 in range(n_rel):
                incl = incl_irr & (dcond_rel == dcond_rel1)
                n_rel_rt_ch[dcond_rel1] = n_cond__rt_ch[incl].sum(0)

            # -- Pool across ch_irr
            n_rel_rt_chrel = np.zeros(
                n_rel_rt_ch.shape[:2] + (consts.N_CH,))
            for ch_rel in range(consts.N_CH):
                n_rel_rt_chrel[:, :, ch_rel] = n_rel_rt_ch[
                    :, :, chs_rel == ch_rel].sum(-1)

            hs1, rts1 = sim1d.plot_rt_vs_ev(
                conds_rel, n_rel_rt_chrel,
                color=color,
                correct_only=correct_only,
                kw_plot=kw_plot,
                **kwargs,
            )
            hs.append(hs1)
            rts.append(rts1)
    return hs, rts
def plot_p_ch_vs_ev(
        ev_cond: Union[torch.Tensor, np.ndarray],
        n_ch: Union[torch.Tensor, np.ndarray],
        style='pred',
        ax: plt.Axes = None,
        dim_rel=0,
        group_dcond_irr: Iterable[Iterable[int]] = None,
        cmap: Union[str, Callable] = 'cool',
        kw_plot=(),
) -> Iterable[plt.Line2D]:
    """
    Plot the proportion of choice 1 on ``dim_rel`` against its evidence,
    one curve per (grouped) absolute evidence level of the other dimension.

    @param ev_cond: [condition, dim] or [condition, frame, dim, (mean, var)]
    @type ev_cond: torch.Tensor
    @param n_ch: [condition, ch] or [condition, rt_frame, ch]
    @type n_ch: torch.Tensor
    @param style: passed to ``get_kw_plot`` to pick the line style
    @param ax: axes to draw on; defaults to the current axes
    @param dim_rel: dimension whose evidence goes on the x axis
    @param group_dcond_irr: optional grouping of the irrelevant-dimension
        levels, forwarded to ``group_conds``
    @param cmap: colormap name, or a callable ``cmap(n)`` that returns a
        callable mapping a level index to a color
    @param kw_plot: extra keyword arguments merged into the plot style
    @return: hs[cond_irr][0] = Line2D, conds_irr
    """
    if ax is None:
        ax = plt.gca()

    # Reduce the richer input layouts to [condition, dim] / [condition, ch].
    if ev_cond.ndim != 2:
        assert ev_cond.ndim == 4
        ev_cond = npt.p2st(ev_cond.mean(1))[0]
    if n_ch.ndim != 2:
        assert n_ch.ndim == 3
        n_ch = n_ch.sum(1)

    ev_cond = npy(ev_cond)
    n_ch = npy(n_ch)
    n_cond_all = n_ch.shape[0]

    # Relevant-dimension choice of every flattened (condition, ch) cell.
    ch_rel_flat = np.tile(np.array(consts.CHS[dim_rel]), n_cond_all)
    n_flat = n_ch.reshape([-1])

    dim_irr = consts.get_odim(dim_rel)
    conds_rel, dcond_rel = np.unique(ev_cond[:, dim_rel],
                                     return_inverse=True)
    conds_irr, dcond_irr = np.unique(np.abs(ev_cond[:, dim_irr]),
                                     return_inverse=True)
    if group_dcond_irr is not None:
        conds_irr, dcond_irr = group_conds(conds_irr, dcond_irr,
                                           group_dcond_irr)
    n_rel = len(conds_rel)
    n_irr = len(conds_irr)

    # Sum counts into [ch_rel, irrelevant level, relevant level].
    group_ix = np.stack([
        ch_rel_flat,
        np.repeat(dcond_irr, consts.N_CH_FLAT),
        np.repeat(dcond_rel, consts.N_CH_FLAT),
    ])
    n_ch_rel = npg.aggregate(group_ix, n_flat, 'sum',
                             [consts.N_CH, n_irr, n_rel])
    p_ch_rel = n_ch_rel[1] / n_ch_rel.sum(0)

    hs = []
    for i_irr, p_curve in enumerate(p_ch_rel):
        palette = (plt.get_cmap(cmap, n_irr) if type(cmap) is str
                   else cmap(n_irr))
        line_kw = get_kw_plot(style, color=palette(i_irr),
                              **dict(kw_plot))
        hs.append(ax.plot(conds_rel, p_curve, **line_kw))

    plt2.box_off(ax=ax)
    x_min, x_max = ax.get_xlim()[0], ax.get_xlim()[1]
    plt2.detach_axis('x', amin=x_min, amax=x_max, ax=ax)
    plt2.detach_axis('y', amin=0, amax=1, ax=ax)
    ax.set_yticks([0, 0.5, 1])
    ax.set_yticklabels(['0', '', '1'])
    ax.set_xlabel('evidence')
    ax.set_ylabel(r"$\mathrm{P}(z=1 \mid c)$")

    return hs, conds_irr
def get_coefs_mesh_from_histogram(
        p_cond_dur_ch: np.ndarray, ev_cond_dim: np.ndarray,
        dif_irrs=((2, ), (0, 1))
) -> (np.ndarray, np.ndarray, np.ndarray, np.ndarray):
    """
    Fit one binomial GLM per (relevant dimension, irrelevant-difficulty
    group, duration), using only the conditions whose absolute irrelevant
    evidence level falls in that difficulty group.

    :param p_cond_dur_ch: choice histogram with durations on axis 1; it is
        reshaped here to [cond, dur, N_CH, N_CH].
    :param ev_cond_dim: [cond, dim]
    :param dif_irrs: groups of indices into the sorted unique absolute
        evidence levels of the irrelevant dimension
    :return: (coef, se_coef, glmress, glmmodels)
        coef[k, dim, dif, dur] and se_coef[k, dim, dif, dur] hold the
        fitted parameters and their standard errors; entries that are not
        filled stay NaN. The regressors passed to ``sm.add_constant`` are,
        in order: rel, irr, and abs(irr) (the last one only for groups
        with more than one level).
        glmress[dim, dif, dur] and glmmodels[dim, dif, dur] are object
        arrays with the fit results and the models.
    """
    n_dim = ev_cond_dim.shape[1]
    n_dif = len(dif_irrs)
    n_dur = p_cond_dur_ch.shape[1]
    siz = [n_dim, n_dif, n_dur]

    n_coef = 4
    coef = np.full([n_coef] + siz, np.nan)
    se_coef = np.full([n_coef] + siz, np.nan)
    # ``np.object`` was removed from NumPy; the builtin ``object`` is the
    # equivalent dtype and works on every NumPy version.
    glmress = np.empty(siz, dtype=object)
    glmmodels = np.empty(siz, dtype=object)

    p_cond_dur_ch = p_cond_dur_ch.reshape([-1, n_dur] + [consts.N_CH] * 2)

    for dim_rel in range(n_dim):
        dim_irr = consts.get_odim(dim_rel)
        # Index of each condition's absolute irrelevant evidence level.
        adcond_irr = np.unique(np.abs(ev_cond_dim[:, dim_irr]),
                               return_inverse=True)[1]
        # Marginalize over the choice axis that is not indexed by dim_rel.
        ax_marginalized = -1 if dim_rel == 0 else -2

        for idif, dif_irr in enumerate(dif_irrs):
            incl = np.isin(adcond_irr, dif_irr)
            ev_cond_dim1 = ev_cond_dim[incl]

            # The design matrix does not depend on the duration,
            # so build it once per (dim_rel, difficulty group).
            reg = [
                ev_cond_dim1[:, dim_rel],
                ev_cond_dim1[:, dim_irr],
            ]
            if len(dif_irr) > 1:
                # otherwise np.abs(cond_irr) would be constant
                reg.append(np.abs(ev_cond_dim1[:, dim_irr]))
            reg = sm.add_constant(np.stack(reg, -1))
            n_coef1 = reg.shape[1]

            for idur in range(n_dur):
                p_cond_ch = p_cond_dur_ch[incl, idur, :, :].sum(
                    ax_marginalized)

                # Flip the choice axis so that choice 1 is the first
                # column of the two-column binomial response.
                glmmodel = sm.GLM(np.flip(p_cond_ch, -1), reg,
                                  family=sm.families.Binomial())
                glmres = glmmodel.fit()

                coef[:n_coef1, dim_rel, idif, idur] = glmres.params
                se_coef[:n_coef1, dim_rel, idif, idur] = glmres.bse
                glmress[dim_rel, idif, idur] = glmres
                glmmodels[dim_rel, idif, idur] = glmmodel
    return coef, se_coef, glmress, glmmodels
def plot_rt(
        parad,
        to_cumsum=False,
        corners_only=False,
        to_add_legend=True,
        use_easiest_only=0,
        plot_kind='rt_mean',
        model_names=('Ser', 'Par'),
        to_normalize_max=False,
        axs=None,
):
    """
    Load the fits of every model and subject of a paradigm, pool them
    across subjects, and plot either mean RTs or RT distributions.

    :param parad: paradigm key into ``consts.SUBJS``; evidence is
        normalized by its per-dimension maximum when it equals 'RT'
    :param to_cumsum: forwarded to the RT-distribution plot; also selects
        the legend placement there
    :param corners_only: forwarded to ``load_fit``
    :param to_add_legend: whether to add a legend
    :param use_easiest_only: forwarded to ``load_fit``
    :param plot_kind: 'rt_mean'|'rt_distrib'
    :param model_names: names of the models to load, forwarded one by one
        to ``load_fit``
    :param to_normalize_max: forwarded to the RT-distribution plot
    :param axs: axes to draw the mean-RT plot on; created when None
    :return: axs
    :raise ValueError: if ``plot_kind`` is not one of the supported kinds
    """
    subjs = consts.SUBJS[parad]
    n_models = len(model_names)
    n_subjs = len(subjs)

    def empty_cells():
        # ``np.object`` was removed from NumPy; the builtin ``object`` is
        # the equivalent dtype and works on every NumPy version.
        return np.empty([n_models, n_subjs], dtype=object)

    p_preds = empty_cells()
    p_datas = empty_cells()
    condss = empty_cells()
    p_pred_trains = empty_cells()
    p_data_trains = empty_cells()
    p_pred_tests = empty_cells()
    p_data_tests = empty_cells()

    colors = ('red', 'blue')
    normalize_ev = parad == 'RT'

    for i_model, model_name in enumerate(model_names):
        for i_subj, subj in enumerate(subjs):
            (
                p_pred1, p_data1, conds1, _, _,
                p_pred_train, p_data_train,
                p_pred_test, p_data_test,
            ) = load_fit(subj, parad, model_name,
                         corners_only=corners_only,
                         use_easiest_only=use_easiest_only)

            if normalize_ev:
                conds1 = conds1 / np.amax(conds1, axis=1, keepdims=True)

            p_preds[i_model, i_subj] = p_pred1
            p_datas[i_model, i_subj] = p_data1
            condss[i_model, i_subj] = conds1
            p_pred_trains[i_model, i_subj] = p_pred_train
            p_data_trains[i_model, i_subj] = p_data_train
            p_pred_tests[i_model, i_subj] = p_pred_test
            p_data_tests[i_model, i_subj] = p_data_test

    siz0 = list(p_preds[0, 0].shape)

    def cell2array(p):
        # Object array of equally shaped arrays -> one numeric array
        # [model, subj, ...].
        return np.stack(p.flatten()).reshape([n_models, n_subjs] + siz0)

    p_preds = cell2array(p_preds)
    p_datas = cell2array(p_datas)
    p_pred_trains = cell2array(p_pred_trains)
    p_data_trains = cell2array(p_data_trains)
    p_pred_tests = cell2array(p_pred_tests)
    p_data_tests = cell2array(p_data_tests)

    def pool_subjs(p_datas, p_preds):
        # Weight each subject's normalized prediction by that subject's
        # total count, then sum predictions and data over subjects.
        n_data_subj = np.sum(p_datas, (2, 3, 4), keepdims=True)

        # P(ch, rt | cond, subj, model)
        p_preds = np2.nan2v(p_preds / p_preds.sum((3, 4), keepdims=True))
        p_pred_avg = np.sum(p_preds * n_data_subj, 1)
        n_data_sum = np.sum(p_datas, 1)
        return p_pred_avg, n_data_sum

    p_pred_avg, n_data_sum = pool_subjs(p_datas, p_preds)
    p_pred_avg_train, n_data_sum_train = pool_subjs(
        p_data_trains, p_pred_trains)
    p_pred_avg_test, n_data_sum_test = pool_subjs(
        p_data_tests, p_pred_tests)

    # Conditions are taken from the first model and subject.
    ev_cond_dim = np.stack(condss[0, 0], -1)

    dt = 1 / 75
    if plot_kind == 'rt_mean':
        if axs is None:
            axs = plt2.GridAxes(1, 2,
                                top=0.4, left=0.6, right=0.1, bottom=0.65,
                                wspace=0.35,
                                heights=1.7, widths=2.2)
        for i_model, model_name in enumerate(model_names):
            # Only the handles of the last model are kept for the legend.
            hs = fit2d.plot_fit_combined(
                pAll_cond_rt_chFlat=n_data_sum[i_model],
                evAll_cond_dim=ev_cond_dim,
                pTrain_cond_rt_chFlat=n_data_sum_train[i_model],
                pTest_cond_rt_chFlat=n_data_sum_test[i_model],
                pModel_cond_rt_chFlat=p_pred_avg[i_model],
                dt=dt,
                to_plot_params=False,
                to_plot_internals=False,
                to_plot_choice=False,
                group_dcond_irr=None,
                kw_plot_pred={},
                kw_plot_data={
                    'markersize': 4,
                },
                axs=axs,
            )[2]

        if to_add_legend:
            n_dim = 2
            for dim in range(n_dim):
                plt.sca(axs[0, dim])
                odim = consts.get_odim(dim)

                conds_irr = np.unique(np.abs(ev_cond_dim[:, odim]))
                hs1 = [v[0][0] for v in hs['rt'][dim]]
                hs1 = hs1[len(conds_irr):(len(conds_irr) * 2)]
                h = fit2d.legend_odim(
                    [np.round(v, 3) for v in conds_irr], hs1, '',
                    loc='lower left',
                    bbox_to_anchor=[1., 0.])
                h.set_title(consts.DIM_NAMES_LONG[odim] + '\nstrength')
                plt.setp(h.get_title(), multialignment='center')

        ev_max = np.amax(condss[0, 0], -1)
        for col in [0, 1]:
            plt.sca(axs[-1, col])
            txt = '%s strength' % consts.DIM_NAMES_LONG[col].lower()
            if normalize_ev:
                txt += '\n(a.u.)'
            plt.xlabel(txt)
            xticks = [-ev_max[col], 0, ev_max[col]]
            plt.xticks(xticks, ['%g' % v for v in xticks])

        from matplotlib.ticker import MultipleLocator
        axs[0, 0].yaxis.set_major_locator(MultipleLocator(1))

        print('--')

    elif plot_kind == 'rt_distrib':
        axs, hs = fit2d.plot_rt_distrib_pred_data(
            p_pred_avg, n_data_sum[0], ev_cond_dim, dt,
            to_normalize_max=to_normalize_max,
            to_cumsum=to_cumsum,
            to_skip_zero_trials=True,
            xlim=[0., 5.],
            kw_plot_data={
                'linewidth': 2,
                'linestyle': ':',
                'alpha': 0.75,
            },
            labels=['serial', 'parallel', 'data'],
            colors=colors,
        )
        if to_add_legend:
            plt.sca(axs[0, 0])
            locs = {
                'loc': 'center right',
                'bbox_to_anchor': (1.05, 0., 0., 1.)
            } if to_cumsum else {
                'loc': 'upper right',
                'bbox_to_anchor': (1.05, 1.01)
            }
            plt.legend(**locs, handlelength=0.8, handletextpad=0.5,
                       frameon=False, borderpad=0.)
        print('--')
    else:
        raise ValueError(
            "plot_kind must be 'rt_mean' or 'rt_distrib', got %r"
            % (plot_kind,))

    return axs
def plot_ch_ev_by_dur(n_cond_dur_ch: np.ndarray = None,
                      data: Data2DVD = None,
                      dif_irrs: Sequence[Union[Iterable[int], int]]
                      = ((0, 1), (2, )),
                      ev_cond_dim: np.ndarray = None,
                      durs: np.ndarray = None,
                      dur_prct_groups=((0, 33), (33, 67), (67, 100)),
                      style='data',
                      kw_plot=(),
                      jitter=0.,
                      axs=None,
                      fig=None):
    """
    Plot the proportion of choice 1 against the relevant evidence.

    Panels[dim, dur_group], curves by irr_dif_group.

    :param n_cond_dur_ch: choice counts; taken from ``data`` when None
    :param data: source of ``ev_cond_dim``, ``durs`` and the counts when
        they are not given explicitly
    :param dif_irrs: groups of indices into the unique absolute evidence
        levels of the irrelevant dimension; one curve per group
    :param ev_cond_dim: [cond, dim]
    :param durs: stimulus durations; only their number is used, to split
        the duration indices into percentile groups
    :param dur_prct_groups: (lower, upper) percentile bounds, inclusive, of
        the duration indices shown in each column
    :param style: passed to ``consts.get_kw_plot``
    :param kw_plot: extra keyword arguments merged into the plot style
    :param jitter: currently unused; kept for backward compatibility
    :param axs: [row, col]; created when None
    :param fig: figure to create the axes in when ``axs`` is None
    :return: axs
    """
    if ev_cond_dim is None:
        ev_cond_dim = data.ev_cond_dim
    if durs is None:
        durs = data.durs
    if n_cond_dur_ch is None:
        n_cond_dur_ch = npy(data.get_data_by_cond('all')[1])

    n_conds_dur_chs = get_n_conds_dur_chs(n_cond_dur_ch)
    conds_dim = [np.unique(cond1) for cond1 in ev_cond_dim.T]

    n_dur = len(dur_prct_groups)
    n_dif = len(dif_irrs)

    if axs is None:
        if fig is None:
            fig = plt.figure(figsize=[6, 4])

        n_row = consts.N_DIM
        n_col = n_dur
        gs = plt.GridSpec(nrows=n_row, ncols=n_col, figure=fig)
        # ``np.object`` was removed from NumPy; the builtin ``object`` is
        # the equivalent dtype and works on every NumPy version.
        axs = np.empty([n_row, n_col], dtype=object)
        for row in range(n_row):
            for col in range(n_col):
                axs[row, col] = plt.subplot(gs[row, col])

    # Duration groups are defined on the duration index, not its value.
    ix_dur = np.arange(len(durs))
    dur_incls = [
        (ix_dur >= np.percentile(ix_dur, dur_prct_group[0]))
        & (ix_dur <= np.percentile(ix_dur, dur_prct_group[1]))
        for dur_prct_group in dur_prct_groups
    ]

    for dim_rel in range(consts.N_DIM):
        conds_rel = conds_dim[dim_rel]
        dim_irr = consts.get_odim(dim_rel)
        _, cond_irr = np.unique(np.abs(conds_dim[dim_irr]),
                                return_inverse=True)
        cmap = consts.CMAP_DIM[dim_rel](n_dif)

        for idif, dif_irr in enumerate(dif_irrs):
            incl_irr = np.isin(cond_irr, dif_irr)

            # Pool over the selected irrelevant levels and over the
            # irrelevant choice, leaving [cond_rel, dur, ch_rel].
            if dim_rel == 0:
                n_cond_dur_ch1 = n_conds_dur_chs[
                    :, incl_irr, :, :, :].sum((1, -1))
            else:
                n_cond_dur_ch1 = n_conds_dur_chs[
                    incl_irr, :, :, :, :].sum((0, -2))

            for i_dur, dur_incl in enumerate(dur_incls):
                ax = axs[dim_rel, i_dur]
                plt.sca(ax)

                n_cond_ch1 = n_cond_dur_ch1[:, dur_incl, :].sum(1)
                p_cond__ch1 = n_cond_ch1[:, 1] / n_cond_ch1.sum(1)

                kw = consts.get_kw_plot(style, color=cmap(idif),
                                        **dict(kw_plot))
                plt.plot(conds_rel, p_cond__ch1, **kw)

                plt2.box_off(ax=ax)
                plt2.detach_yaxis(0, 1, ax=ax)

                ax.set_yticks([0, 0.5, 1.])
                if i_dur == 0:
                    ax.set_yticklabels(['0', '', '1'])
                else:
                    ax.set_yticklabels([])
                    ax.set_xticklabels([])

    return axs
def get_p_cond_dur_chFlat_vectorized(
        p_dim_cond_td_chDim: torch.Tensor,
        unabs_dim_td_cond_chDim: torch.Tensor,
        dur_buffer_fr: torch.Tensor,
        dur_stim_frs: torch.Tensor,
        p1st_dim0: torch.Tensor,
) -> torch.Tensor:
    """
    Vectorized variant of ``get_p_cond_dur_chFlat``: the loops over the
    first dimension, the second decision time and the flat choice are
    replaced by tensor indexing.

    NOTE(review): this variant is not term-for-term identical to
    ``get_p_cond_dur_chFlat``. It sums the second decision time over
    ``arange(max_td2nd)`` where the loop version reads a cumulative sum at
    index ``max_td2nd``, it has no "only dim2nd absorbed" term, and it does
    not normalize across choices. Confirm which behavior is intended
    before relying on it.

    :param p_dim_cond_td_chDim: [dim, cond, td, chDim]
    :param unabs_dim_td_cond_chDim: [dim, td, cond, chDim]
    :param dur_buffer_fr: scalar; a floating-point value is handled by
        linearly interpolating between the results of
        ``Dtb2DVDBufSerial.get_p_cond_dur_chFlat`` for the two
        neighboring integer buffer durations
    :param dur_stim_frs: [idur]
    :param p1st_dim0: probability that dim 0 is processed first
    :return: p_cond_dur_chFlat[cond, dur, chFlat]
    """
    if torch.is_floating_point(dur_buffer_fr):
        bufs = torch.cat([
            dur_buffer_fr.floor().long().reshape([1]),
            dur_buffer_fr.floor().long().reshape([1]) + 1
        ], 0)
        # Linear interpolation weights of the two integer neighbors.
        prop_buf = torch.tensor(1.) - torch.abs(dur_buffer_fr - bufs)
        ps = []
        for buf in bufs:
            ps.append(
                Dtb2DVDBufSerial.get_p_cond_dur_chFlat(
                    p_dim_cond_td_chDim,
                    unabs_dim_td_cond_chDim,
                    buf.long(),
                    dur_stim_frs=dur_stim_frs,
                    p1st_dim0=p1st_dim0))
        ps = torch.stack(ps)
        return (ps * prop_buf[:, None, None, None]).sum(0)

    # vectorized version
    p1st_dim = torch.stack([p1st_dim0, torch.tensor(1.) - p1st_dim0])

    n_cond = p_dim_cond_td_chDim.shape[1]
    ndur = len(dur_stim_frs)
    p_cond_dur_chFlat = torch.zeros([n_cond, ndur, consts.N_CH_FLAT])

    # Put cond first so that the index tensors address [dim, td, chDim].
    p_cond_dim_td_chDim = p_dim_cond_td_chDim.transpose(0, 1)
    unabs_cond_dim_td_chDim = unabs_dim_td_cond_chDim.transpose(
        1, 2).transpose(0, 1)

    ichs = torch.arange(consts.N_CH_FLAT)
    dim1sts = torch.arange(consts.N_DIM)

    # Index grids over [dim1st, ich]; they do not depend on the loops.
    dim1st_di, ich_di = torch.meshgrid([dim1sts, ichs])
    dim2nd_di = consts.get_odim(dim1st_di)
    ch1st_di = consts.CHS_TENSOR[dim1st_di, ich_di]
    ch2nd_di = consts.CHS_TENSOR[dim2nd_di, ich_di]

    for idur, dur_stim in enumerate(dur_stim_frs):
        p0 = torch.zeros([n_cond, consts.N_CH_FLAT])
        for td1st in torch.arange(dur_stim):
            max_td2nd = dur_stim - max([td1st - dur_buffer_fr, 0])
            td2nds = torch.arange(max_td2nd)

            dim1st, td2nd, ich = torch.meshgrid([dim1sts, td2nds, ichs])
            dim2nd = consts.get_odim(dim1st)
            ch1st = consts.CHS_TENSOR[dim1st, ich]
            ch2nd = consts.CHS_TENSOR[dim2nd, ich]

            # When both dim1st and dim2nd are absorbed
            p0 = p0 + (
                (p_cond_dim_td_chDim[
                     :, dim1st, td1st.expand_as(td2nd), ch1st]
                 * p_cond_dim_td_chDim[
                     :, dim2nd, td2nd, ch2nd]).sum(-2)  # sum across td2nd
                * p1st_dim[None, :, None]
            ).sum(1)  # sum across dim1st

            # When only dim1st is absorbed, t2nd = max_td2nd.
            # The original read ``t2nd`` here before assigning it, which
            # raised UnboundLocalError on the first pass (or reused the
            # value left over from the previous duration).
            p0 = p0 + (
                (p_cond_dim_td_chDim[:, dim1st_di, td1st, ch1st_di]
                 * unabs_cond_dim_td_chDim[
                     :, dim2nd_di, max_td2nd, ch2nd_di])
                * p1st_dim[None, :, None]
            ).sum(1)  # sum across dim1st

        # When neither dim is absorbed,
        # then dim2nd is certainly not absorbed,
        # and stays at the state at t = min([dur_stim, dur_buffer_fr])
        t2nd = min([dur_stim, dur_buffer_fr])
        p0 = p0 + (
            unabs_cond_dim_td_chDim[:, dim1st_di, dur_stim, ch1st_di]
            * unabs_cond_dim_td_chDim[:, dim2nd_di, t2nd, ch2nd_di]
            * p1st_dim[None, :, None]
        ).sum(1)  # sum across dim1st

        p_cond_dur_chFlat[:, idur, :] = (
            p_cond_dur_chFlat[:, idur, :] + p0)
    return p_cond_dur_chFlat
def get_p_cond_dur_chFlat(
        p_dim_cond_td_chDim: torch.Tensor,
        unabs_dim_td_cond_chDim: torch.Tensor,
        dur_buffer_fr: torch.Tensor,
        dur_stim_frs: torch.Tensor,
        p1st_dim0: torch.Tensor,
) -> torch.Tensor:
    """
    Combine the per-dimension decision-time and unabsorbed distributions
    into the probability of each flat choice for every condition and
    stimulus duration, mixing over which dimension is processed first.

    A floating-point ``dur_buffer_fr`` is handled by linearly interpolating
    between the results for the two neighboring integer buffer durations.

    :param p_dim_cond_td_chDim: [dim, cond, td, chDim]
    :param unabs_dim_td_cond_chDim: [dim, td, cond, chDim]
    :param dur_buffer_fr: scalar
    :param dur_stim_frs: [idur]
    :param p1st_dim0: probability that dim 0 is processed first
    :return: p_cond_dur_chFlat[cond, dur, chFlat]
    """
    if torch.is_floating_point(dur_buffer_fr):
        buf_lo = dur_buffer_fr.floor().long().reshape([1])
        bufs = torch.cat([buf_lo, buf_lo + 1], 0)
        # Weight of each integer neighbor = 1 - distance to it.
        weights = torch.tensor(1.) - torch.abs(dur_buffer_fr - bufs)
        p_by_buf = torch.stack([
            Dtb2DVDBufSerial.get_p_cond_dur_chFlat(
                p_dim_cond_td_chDim,
                unabs_dim_td_cond_chDim,
                buf.long(),
                dur_stim_frs=dur_stim_frs,
                p1st_dim0=p1st_dim0)
            for buf in bufs
        ])
        return (p_by_buf * weights[:, None, None, None]).sum(0)

    p1st_dim = [p1st_dim0, torch.tensor(1.) - p1st_dim0]

    n_cond = p_dim_cond_td_chDim.shape[1]
    n_dur_stim = len(dur_stim_frs)
    out = torch.zeros([n_cond, n_dur_stim, consts.N_CH_FLAT])

    # Cumulative probability of having been absorbed by each time step.
    cum_p = p_dim_cond_td_chDim.cumsum(-2)

    for dim1st in range(consts.N_DIM):
        dim2nd = consts.get_odim(dim1st)
        for idur, dur_stim in enumerate(dur_stim_frs):
            p_ch = torch.zeros([n_cond, consts.N_CH_FLAT])

            for ich, ch_by_dim in enumerate(consts.CHS_TENSOR.T):
                ch1st = ch_by_dim[dim1st]
                ch2nd = ch_by_dim[dim2nd]

                for td1st in torch.arange(dur_stim):
                    # Time left for dim2nd; only the part of td1st that
                    # exceeds the buffer is taken away from it.
                    max_td2nd = dur_stim - max([td1st - dur_buffer_fr, 0])
                    p_td1st = p_dim_cond_td_chDim[dim1st, :, td1st, ch1st]

                    # ==== When both dims are absorbed
                    p_ch[:, ich] = p_ch[:, ich] + (
                        p_td1st
                        * cum_p[dim2nd, :, max_td2nd, ch2nd])

                    # ==== When only dim1st is absorbed
                    p_ch[:, ich] = p_ch[:, ich] + (
                        p_td1st
                        * unabs_dim_td_cond_chDim[
                            dim2nd, max_td2nd, :, ch2nd])

                # ==== When dim1st is not absorbed
                t1st = dur_stim
                t2nd = dur_stim - max([t1st - dur_buffer_fr, 0])
                unabs_1st = unabs_dim_td_cond_chDim[dim1st, t1st, :, ch1st]

                # ==== When only dim2nd is absorbed: this can happen when
                # dim2nd is absorbed within the buffer duration
                p_ch[:, ich] = p_ch[:, ich] + (
                    unabs_1st
                    * cum_p[dim2nd, :, t2nd, ch2nd])

                # ==== When neither dim is absorbed
                p_ch[:, ich] = p_ch[:, ich] + (
                    unabs_1st
                    * unabs_dim_td_cond_chDim[dim2nd, t2nd, :, ch2nd])

            # Normalize across flat choices, then weight by the
            # probability that this dimension goes first.
            p_ch = p_ch / p_ch.sum(1, keepdim=True)
            out[:, idur, :] = out[:, idur, :] + p1st_dim[dim1st] * p_ch
    return out