def do(action, args, config):
    """
    Manage tasks for the multisensory model.

    Parameters
    ----------
    action : str
        Any action containing 'trials' generates trials and runs them through
        ``runtools.run``; 'psychometric' plots the psychometric curve;
        'sort' plots condition-sorted unit activity.
    args : sequence of str
        Extra arguments. For 'trials' actions, ``args[0]`` (when present and
        parseable as an int) is the number of trials per condition; otherwise
        500 is used. For 'sort', the presence of 'value' selects the value
        network instead of the policy network.
    config : dict
        Run configuration. Keys read here: 'model', 'savefile', 'seed', 'dt',
        'trialspath', 'figspath'.
    """
    print("ACTION*: " + str(action))
    print("ARGS*: " + str(args))

    #=====================================================================================

    if 'trials' in action:
        try:
            trials_per_condition = int(args[0])
        except (IndexError, ValueError):
            # No count supplied, or not an integer: use the default.
            trials_per_condition = 500
        model = config['model']
        pg = model.get_pg(config['savefile'], config['seed'], config['dt'])

        # Conditions
        spec = model.spec
        mods = spec.mods
        freqs = spec.freqs
        n_conditions = spec.n_conditions
        n_trials = n_conditions * trials_per_condition

        print("{} trials".format(n_trials))
        task = model.Task()
        trials = []
        for n in xrange(n_trials):
            # Cycle through every (modality, frequency) combination
            k = tasktools.unravel_index(n, (len(mods), len(freqs)))
            context = {'mod': mods[k.pop(0)], 'freq': freqs[k.pop(0)]}
            trials.append(task.get_condition(pg.rng, pg.dt, context))
        runtools.run(action, trials, pg, config['trialspath'])

    #=====================================================================================

    elif action == 'psychometric':
        trialsfile = runtools.behaviorfile(config['trialspath'])

        fig = Figure()
        plot = fig.add()

        psychometric(trialsfile, plot)
        plot.vline(config['model'].spec.boundary)

        fig.save(path=config['figspath'], name='psychometric')
        fig.close()

    #=====================================================================================

    elif action == 'sort':
        if 'value' in args:
            network = 'v'
        else:
            network = 'p'

        trialsfile = runtools.activityfile(config['trialspath'])
        sort(trialsfile, (config['figspath'], 'sorted'), network=network)
def sort(trialsfile, plots, units=None, network='p', **kwargs):
    """
    Sort trials by (modality, choice) and plot trial-averaged activity
    aligned to stimulus onset.

    Only correct decision trials from the unisensory modalities ('v', 'a')
    are used; multisensory ('va') trials are skipped.

    Parameters
    ----------
    trialsfile : str
        Activity file readable by ``utils.load``.
    plots : sequence of plots, or (figspath, name)
        If `units` is given, the plots to draw into, one per unit. Otherwise
        a ``(figspath, name)`` pair; one figure is saved per unit.
    units : sequence of int, optional
        Units to draw into `plots`.
    network : str
        'p' selects the policy-network rates ``r_p``; any other value selects
        the value-network rates ``r_v``.
    **kwargs
        ``lw`` : line width (default 1.5).
        ``dashes`` : dash pattern for the high-choice ('H') curves
        (default [3, 2]).

    Raises
    ------
    ValueError
        If a decision trial has a modality other than 'v', 'a', or 'va'.
    """
    # Load trials; accept both the 9-field (no Z_b) and 10-field layouts
    data = utils.load(trialsfile)
    if len(data) == 9:
        trials, U, Z, A, P, M, perf, r_p, r_v = data
    else:
        trials, U, Z, Z_b, A, P, M, perf, r_p, r_v = data

    # Which network?
    if network == 'p':
        r = r_p
    else:
        r = r_v

    # Number of units
    N = r.shape[-1]

    # Same for every trial
    time = trials[0]['time']
    Ntime = len(time)

    # Aligned time: index Ntime-1 of time_a holds the alignment point
    time_a = np.concatenate((-time[1:][::-1], time))
    Ntime_a = len(time_a)

    def accumulate(r_sum, n_sum, Rn, Mn, t0):
        """Add one trial's masked rates and mask into the buffers, aligned at t0."""
        # Samples before the align point fill the slots just left of Ntime-1
        n_b = Rn[:t0].shape[0]
        r_sum[Ntime - 1 - n_b:Ntime - 1] += Rn[:t0]
        n_sum[Ntime - 1 - n_b:Ntime - 1] += Mn[:t0]
        # Samples from the align point onward start at Ntime-1
        n_a = Rn[t0:].shape[0]
        r_sum[Ntime - 1:Ntime - 1 + n_a] += Rn[t0:]
        n_sum[Ntime - 1:Ntime - 1 + n_a] += Mn[t0:]

    #=====================================================================================
    # Aligned to stimulus onset
    #=====================================================================================

    r_by_cond_stimulus = {}
    n_r_by_cond_stimulus = {}
    for n, trial in enumerate(trials):
        if not perf.decisions[n]:
            continue

        mod = trial['mod']
        if mod == 'va':
            continue
        if mod not in ('v', 'a'):
            raise ValueError("unexpected modality: {!r}".format(mod))

        if not perf.corrects[n]:
            continue

        # Condition
        cond = (mod, perf.choices[n])

        # Storage
        r_sum = r_by_cond_stimulus.setdefault(cond, np.zeros((Ntime_a, N)))
        n_sum = n_r_by_cond_stimulus.setdefault(cond, np.zeros((Ntime_a, N)))

        # Firing rates, with the mask M zeroing out excluded time steps
        Mn = np.tile(M[:, n], (N, 1)).T
        Rn = r[:, n] * Mn

        # Align point
        t0 = trial['epochs']['stimulus'][0] - 1

        accumulate(r_sum, n_sum, Rn, Mn, t0)

    # Average: summed rates divided by the number of contributing samples
    for cond in r_by_cond_stimulus:
        r_by_cond_stimulus[cond] = utils.div(r_by_cond_stimulus[cond],
                                             n_r_by_cond_stimulus[cond])

    #-------------------------------------------------------------------------------------
    # Plot
    #-------------------------------------------------------------------------------------

    lw = kwargs.get('lw', 1.5)
    dashes = kwargs.get('dashes', [3, 2])

    colors_by_mod = {'v': Figure.colors('blue'), 'a': Figure.colors('green')}
    linestyle_by_choice = {'L': '-', 'H': '--'}
    label_by_mod = {'v': 'Vis, ', 'a': 'Aud, '}
    label_by_choice = {'H': 'high', 'L': 'low'}

    def plot_sorted(plot, unit, w, r_sorted):
        t = time_a[w]
        yall = [[1]]
        for cond in [('v', 'H'), ('v', 'L'), ('a', 'H'), ('a', 'L')]:
            # A condition is absent if it had no correct decision trials
            if cond not in r_sorted:
                continue
            mod, choice = cond
            label = label_by_mod[mod] + label_by_choice[choice]

            linestyle = linestyle_by_choice[choice]
            lineprops = dict(linestyle=linestyle, lw=lw)
            if linestyle != '-':
                lineprops['dashes'] = dashes

            y = r_sorted[cond][w, unit]
            plot.plot(t, y, color=colors_by_mod[mod], label=label, **lineprops)
            yall.append(y)
        return t, yall

    def on_stimulus(plot, unit):
        w, = np.where((time_a >= -300) & (time_a <= 1000))
        t, yall = plot_sorted(plot, unit, w, r_by_cond_stimulus)
        plot.xlim(t[0], t[-1])
        return yall

    if units is not None:
        for plot, unit in zip(plots, units):
            on_stimulus(plot, unit)
    else:
        figspath, name = plots
        for unit in xrange(N):
            fig = Figure()
            plot = fig.add()

            #-----------------------------------------------------------------------------

            yall = []
            yall += on_stimulus(plot, unit)

            plot.lim('y', yall, lower=0)
            plot.vline(0)

            plot.xlabel('Time (ms)')
            plot.ylabel('Firing rate (a.u.)')

            #-----------------------------------------------------------------------------

            fig.save(path=figspath, name=name + '_{}{:03d}'.format(network, unit))
            fig.close()
def do(action, args, config):
    """
    Manage tasks for the post-decision wager model.

    Parameters
    ----------
    action : str
        Any action containing 'trials' generates trials and runs them through
        ``runtools.run``. 'sure_stimulus_duration', 'correct_stimulus_duration'
        and 'value_stimulus_duration' plot the corresponding quantity against
        stimulus duration. 'sort' plots condition-sorted unit activity.
    args : sequence of str
        Extra arguments. For 'trials' actions, ``args[0]`` (when present and
        parseable as an int) is the number of trials per condition; otherwise
        100 is used. For 'sort', the presence of 'value' selects the value
        network instead of the policy network.
    config : dict
        Run configuration. Keys read here: 'model', 'savefile', 'seed', 'dt',
        'trialspath', 'figspath'.
    """
    print("ACTION*: " + str(action))
    print("ARGS*: " + str(args))

    #=====================================================================================

    if 'trials' in action:
        try:
            trials_per_condition = int(args[0])
        except (IndexError, ValueError):
            # No count supplied, or not an integer: use the default.
            trials_per_condition = 100
        model = config['model']
        pg = model.get_pg(config['savefile'], config['seed'], config['dt'])

        spec = model.spec
        wagers = spec.wagers
        left_rights = spec.left_rights
        cohs = spec.cohs
        n_conditions = spec.n_conditions
        n_trials = trials_per_condition * n_conditions

        print("{} trials".format(n_trials))
        task = model.Task()
        trials = []
        for n in xrange(n_trials):
            # Cycle through every (wager, left_right, coh) combination
            k = tasktools.unravel_index(n, (len(wagers), len(left_rights), len(cohs)))
            context = {
                'wager': wagers[k.pop(0)],
                'left_right': left_rights[k.pop(0)],
                'coh': cohs[k.pop(0)]
            }
            trials.append(task.get_condition(pg.rng, pg.dt, context))
        runtools.run(action, trials, pg, config['trialspath'])

    #=====================================================================================

    elif action == 'sure_stimulus_duration':
        trialsfile = runtools.behaviorfile(config['trialspath'])

        fig = Figure()
        plot = fig.add()

        sure_stimulus_duration(trialsfile, plot)

        plot.xlabel('Stimulus duration (ms)')
        plot.ylabel('Probability sure target')

        fig.save(path=config['figspath'], name=action)
        fig.close()

    #=====================================================================================

    elif action == 'correct_stimulus_duration':
        trialsfile = runtools.behaviorfile(config['trialspath'])

        fig = Figure()
        plot = fig.add()

        correct_stimulus_duration(trialsfile, plot)

        plot.xlabel('Stimulus duration (ms)')
        plot.ylabel('Probability correct')

        fig.save(path=config['figspath'], name=action)
        fig.close()

    #=====================================================================================

    elif action == 'value_stimulus_duration':
        trialsfile = runtools.activityfile(config['trialspath'])

        fig = Figure()
        plot = fig.add()

        value_stimulus_duration(trialsfile, plot)

        plot.xlabel('Stimulus duration (ms)')
        plot.ylabel('Expected reward')

        fig.save(path=config['figspath'], name=action)
        fig.close()

    #=====================================================================================

    elif action == 'sort':
        if 'value' in args:
            network = 'v'
        else:
            network = 'p'

        trialsfile = runtools.activityfile(config['trialspath'])
        sort(trialsfile, os.path.join(config['figspath'], 'sorted'), network=network)
def sort(trialsfile, plots, unit=None, network='p', **kwargs):
    """
    Sort post-decision-wager trials by condition and plot trial-averaged
    activity.

    Trials with zero coherence or without a choice are excluded. Conditions
    are keyed by ``trial['left_right']``. Wager trials are additionally split
    into sure-target choices (``perf.choices[n] == 'S'``) and all others.

    Parameters
    ----------
    trialsfile : str
        Activity file readable by ``utils.load``.
    plots : dict or str
        If `unit` is given, a mapping with keys 'noTs-stimulus',
        'noTs-choice', 'Ts-stimulus', 'Ts-sure', 'Ts-choice' to plot objects.
        Otherwise a name prefix; one figure per unit is saved as
        ``plots + '_<network><unit:03d>'``.
    unit : int, optional
        Single unit to draw into `plots`.
    network : str
        'p' selects the policy-network rates ``r_p``; any other value selects
        the value-network rates ``r_v``.
    **kwargs
        ``lw`` (default 1.25) and ``dashes`` (default [3, 1.5], used for the
        sure-target curves). When `unit` is given, per-panel time windows
        ``'<panel>-tmin'`` / ``'<panel>-tmax'`` may also be supplied.

    Returns
    -------
    list or None
        When `unit` is given, every y-series plotted (for shared y-limits);
        otherwise None.
    """
    # Load trials; accept both the 9-field (no Z_b) and 10-field layouts
    data = utils.load(trialsfile)
    if len(data) == 9:
        trials, U, Z, A, P, M, perf, r_p, r_v = data
    else:
        trials, U, Z, Z_b, A, P, M, perf, r_p, r_v = data

    if network == 'p':
        print("Sorting policy network activity.")
        r = r_p
    else:
        print("Sorting value network activity.")
        r = r_v

    # Number of units
    N = r.shape[-1]

    # Time
    time = trials[0]['time']
    Ntime = len(time)

    # Aligned time: index Ntime-1 of time_a holds the alignment point
    time_a = np.concatenate((-time[1:][::-1], time))
    Ntime_a = len(time_a)

    #=====================================================================================
    # Preferred targets
    #=====================================================================================

    preferred_targets = get_preferred_targets(trials, perf, r)

    #=====================================================================================
    # Sort trials
    #=====================================================================================

    def accumulate(by_cond, cond, r_n, m_n, t0):
        """Add one trial's masked rates and mask to by_cond[cond], aligned at t0."""
        entry = by_cond.setdefault(cond, {'r': np.zeros((Ntime_a, N)),
                                          'n': np.zeros((Ntime_a, N))})
        # Before the align point
        n_b = r_n[:t0].shape[0]
        entry['r'][Ntime-1-n_b:Ntime-1] += r_n[:t0]
        entry['n'][Ntime-1-n_b:Ntime-1] += m_n[:t0]
        # From the align point on
        n_a = r_n[t0:].shape[0]
        entry['r'][Ntime-1:Ntime-1+n_a] += r_n[t0:]
        entry['n'][Ntime-1:Ntime-1+n_a] += m_n[t0:]

    def average(by_cond):
        """Divide summed rates by the number of contributing samples."""
        for cond in by_cond:
            by_cond[cond] = utils.div(by_cond[cond]['r'], by_cond[cond]['n'])
        return by_cond

    def get_sorted(func_t0, wager):
        """
        Trial-average wager (wager=True) or no-wager (wager=False) trials,
        aligned at ``func_t0(epochs, t_choice)``.

        Returns (by_cond, by_cond_sure). Sure-target choices on wager trials
        go to by_cond_sure; everything else goes to by_cond.
        """
        by_cond = {}
        by_cond_sure = {}
        for n, trial in enumerate(trials):
            if bool(trial['wager']) != wager:
                continue
            if perf.choices[n] is None or trial['coh'] == 0:
                continue

            m_n = np.tile(M[:, n], (N, 1)).T
            r_n = r[:, n]*m_n
            t0 = func_t0(trial['epochs'], perf.t_choices[n])

            if wager and perf.choices[n] == 'S':
                accumulate(by_cond_sure, trial['left_right'], r_n, m_n, t0)
            else:
                accumulate(by_cond, trial['left_right'], r_n, m_n, t0)

        return average(by_cond), average(by_cond_sure)

    def at_stimulus(epochs, t_choice):
        return epochs['stimulus'][0] - 1

    def at_sure(epochs, t_choice):
        return epochs['sure'][0] - 1

    def at_choice(epochs, t_choice):
        return t_choice

    # No-wager trials
    noTs_stimulus, _ = get_sorted(at_stimulus, wager=False)
    noTs_choice, _ = get_sorted(at_choice, wager=False)

    # Wager trials
    Ts_stimulus, Ts_stimulus_sure = get_sorted(at_stimulus, wager=True)
    Ts_sure, Ts_sure_sure = get_sorted(at_sure, wager=True)
    Ts_choice, Ts_choice_sure = get_sorted(at_choice, wager=True)

    #=====================================================================================
    # Plot
    #=====================================================================================

    lw = kwargs.get('lw', 1.25)
    dashes = kwargs.get('dashes', [3, 1.5])

    # Keyed by left_right * preferred target; only products of -1 and +1 are
    # expected (anything else raises KeyError): +1 -> black, -1 -> grey.
    in_opp_colors = {-1: '0.6', +1: 'k'}

    def plot_curves(by_cond, plot, unit, w, yall, **props):
        """Plot each left_right curve of by_cond over time window w."""
        t = time_a[w]
        for lr in by_cond:
            color = in_opp_colors[lr*preferred_targets[unit]]
            y = by_cond[lr][w, unit]
            plot.plot(t, y, color=color, lw=lw, **props)
            yall.append(y)

    def finish_panel(plot, tmin, tmax, yall):
        plot.xlim(tmin, tmax)
        plot.xticks([0, tmax])
        plot.lim('y', yall, lower=0)
        return yall

    def plot_noTs(noTs, plot, unit, tmin, tmax):
        w, = np.where((tmin <= time_a) & (time_a <= tmax))
        yall = [[1]]
        plot_curves(noTs, plot, unit, w, yall)
        return finish_panel(plot, tmin, tmax, yall)

    def plot_Ts(Ts, Ts_sure, plot, unit, tmin, tmax):
        w, = np.where((tmin <= time_a) & (time_a <= tmax))
        yall = [[1]]
        plot_curves(Ts, plot, unit, w, yall)
        plot_curves(Ts_sure, plot, unit, w, yall, linestyle='--', dashes=dashes)
        return finish_panel(plot, tmin, tmax, yall)

    if unit is not None:
        y = []

        tmin = kwargs.get('noTs-stimulus-tmin', -100)
        tmax = kwargs.get('noTs-stimulus-tmax', 700)
        y += plot_noTs(noTs_stimulus, plots['noTs-stimulus'], unit, tmin, tmax)

        tmin = kwargs.get('noTs-choice-tmin', -500)
        tmax = kwargs.get('noTs-choice-tmax', 0)
        y += plot_noTs(noTs_choice, plots['noTs-choice'], unit, tmin, tmax)

        tmin = kwargs.get('Ts-stimulus-tmin', -100)
        tmax = kwargs.get('Ts-stimulus-tmax', 700)
        y += plot_Ts(Ts_stimulus, Ts_stimulus_sure, plots['Ts-stimulus'], unit, tmin, tmax)

        tmin = kwargs.get('Ts-sure-tmin', -200)
        tmax = kwargs.get('Ts-sure-tmax', 700)
        y += plot_Ts(Ts_sure, Ts_sure_sure, plots['Ts-sure'], unit, tmin, tmax)

        tmin = kwargs.get('Ts-choice-tmin', -500)
        tmax = kwargs.get('Ts-choice-tmax', 0)
        y += plot_Ts(Ts_choice, Ts_choice_sure, plots['Ts-choice'], unit, tmin, tmax)

        return y
    else:
        name = plots
        for unit in xrange(N):
            # Figure geometry (named so as not to shadow the rate array `r`)
            fig_w = utils.mm_to_inch(174)
            fig_r = 0.35
            fig = Figure(w=fig_w, r=fig_r)

            x0 = 0.09
            y0 = 0.15
            pw = 0.13
            ph = 0.75
            dx = 0.05
            DX = 0.08

            fig.add('noTs-stimulus', [x0, y0, pw, ph])
            fig.add('noTs-choice', [fig[-1].right+dx, y0, pw, ph])
            fig.add('Ts-stimulus', [fig[-1].right+DX, y0, pw, ph])
            fig.add('Ts-sure', [fig[-1].right+dx, y0, pw, ph])
            fig.add('Ts-choice', [fig[-1].right+dx, y0, pw, ph])

            #-----------------------------------------------------------------------------

            y = []

            plot = fig['noTs-stimulus']
            y += plot_noTs(noTs_stimulus, plot, unit, -100, 700)
            plot.vline(0)

            plot = fig['noTs-choice']
            y += plot_noTs(noTs_choice, plot, unit, -500, 200)
            plot.vline(0)

            plot = fig['Ts-stimulus']
            y += plot_Ts(Ts_stimulus, Ts_stimulus_sure, plot, unit, -100, 700)
            plot.vline(0)

            plot = fig['Ts-sure']
            y += plot_Ts(Ts_sure, Ts_sure_sure, plot, unit, -200, 700)
            plot.vline(0)

            plot = fig['Ts-choice']
            y += plot_Ts(Ts_choice, Ts_choice_sure, plot, unit, -500, 200)
            plot.vline(0)

            # Shared y-limits across all panels
            for plot in fig.plots.values():
                plot.lim('y', y, lower=0)

            #-----------------------------------------------------------------------------

            fig.save(name+'_{}{:03d}'.format(network, unit))
            fig.close()
def sort(trialsfile, plots, units=None, network='p', **kwargs):
    """
    Sort trials by frequency pair and plot trial-averaged activity aligned to
    the onset of the f1 epoch.

    Only trials with a choice that was correct contribute. The condition key
    is (f1, f2): ``trial['fpair']`` as stored when ``trial['gt_lt'] == '>'``,
    reversed otherwise. Curves are coloured by f1 via ``smap``.

    Parameters
    ----------
    trialsfile : str
        Activity file readable by ``utils.load`` (9- or 10-field layout).
    plots : sequence of plots, or (figspath, name)
        If `units` is given, the plots to draw into, one per unit. Otherwise
        a ``(figspath, name)`` pair; one figure is saved per unit.
    units : sequence of int, optional
    network : str
        'p' selects ``r_p``; any other value selects ``r_v``.
    **kwargs
        ``lw`` : line width (default 1.5).
    """
    contents = utils.load(trialsfile)
    if len(contents) == 9:
        trials, U, Z, A, P, M, perf, r_p, r_v = contents
    else:
        trials, U, Z, Z_b, A, P, M, perf, r_p, r_v = contents

    rates = r_p if network == 'p' else r_v

    # Data shape
    Ntime = rates.shape[0]
    N = rates.shape[-1]

    # Aligned time axis; `origin` is the index of the alignment point
    time = trials[0]['time']
    time_a = np.concatenate((-time[1:][::-1], time))
    Ntime_a = len(time_a)
    origin = Ntime - 1

    #=====================================================================================
    # Sort trials
    #=====================================================================================

    sums = {}
    for n, trial in enumerate(trials):
        if perf.choices[n] is None or not perf.corrects[n]:
            continue

        first, second = trial['fpair']
        key = (first, second) if trial['gt_lt'] == '>' else (second, first)

        # Masked firing rates for this trial
        mask = np.tile(M[:, n], (N, 1)).T
        masked = rates[:, n] * mask

        t0 = trial['epochs']['f1'][0] - 1

        if key not in sums:
            sums[key] = {'r': np.zeros((Ntime_a, N)), 'n': np.zeros((Ntime_a, N))}
        acc = sums[key]

        # Before the align point: slots just left of `origin`
        n_before = masked[:t0].shape[0]
        acc['r'][origin - n_before:origin] += masked[:t0]
        acc['n'][origin - n_before:origin] += mask[:t0]

        # From the align point on: slots starting at `origin`
        n_after = masked[t0:].shape[0]
        acc['r'][origin:origin + n_after] += masked[t0:]
        acc['n'][origin:origin + n_after] += mask[t0:]

    # Average in place: summed rates / number of contributing samples
    for key in sums:
        sums[key] = utils.div(sums[key]['r'], sums[key]['n'])
    trials_by_cond = sums

    #=====================================================================================
    # Plot
    #=====================================================================================

    lw = kwargs.get('lw', 1.5)
    w, = np.where((time_a >= -500) & (time_a <= 4000))

    def draw(plot, unit):
        t = 1e-3 * time_a[w]
        ys = [[1]]
        for (freq1, _), mean_rate in trials_by_cond.items():
            trace = mean_rate[w, unit]
            plot.plot(t, trace, color=smap.to_rgba(freq1), lw=lw)
            ys.append(trace)
        return t, ys

    if units is not None:
        for plot, unit in zip(plots, units):
            draw(plot, unit)
        return

    figspath, name = plots
    for unit in xrange(N):
        fig = Figure()
        plot = fig.add()

        t, ys = draw(plot, unit)
        plot.xlim(t[0], t[-1])
        plot.lim('y', ys, lower=0)

        # Shaded intervals in seconds (t is scaled by 1e-3); presumably the
        # f1 and f2 stimulus periods — TODO confirm against the task timing
        plot.highlight(0, 0.5)
        plot.highlight(3.5, 4)

        fig.save(path=figspath, name=name + '_{}{:03d}'.format(network, unit))
        fig.close()
# Expected-reward panels: draw rdm_analysis.sort_return output for the
# rdm_fixed activity into fig.plots, then format the 'on-stimulus' and
# 'on-choice' panels.
kwargs = {'on-stimulus-tmin': -200, 'on-stimulus-tmax': 600,
          'on-choice-tmin': -400, 'on-choice-tmax': 0,
          'colors': 'kiani',
          'dashes': [3.5, 2]}
rdm_analysis.sort_return(rdm_fixed_activity, fig.plots, **kwargs)

# Panel aligned to stimulus; x-limits repeat the on-stimulus tmin/tmax above
plot = fig['on-stimulus']
plot.xlim(-200, 600)
plot.xticks([-200, 0, 200, 400, 600])
plot.ylim(0.5, 1.2)
#plot.yticks([0.5, 1])
plot.xlabel('Time from stimulus (ms)')
plot.ylabel('Expected reward')

# Legend
props = {'prop': {'size': 8}, 'handlelength': 1.2,
         'handletextpad': 1.1, 'labelspacing': 0.7}
plot.legend(bbox_to_anchor=(0.33, 1), **props)

# Panel aligned to decision; x-limits repeat the on-choice tmin/tmax above
plot = fig['on-choice']
plot.xlim(-400, 0)
plot.xticks([-400, -200, 0])
plot.ylim(0.5, 1.2)
#plot.yticks([0.5, 1])
plot.xlabel('Time from decision (ms)')

#=========================================================================================

fig.save()
def sort_epoch(behaviorfile, activityfile, epoch, offers, plots, units=None,
               network='p', separate_by_choice=False, **kwargs):
    """
    Sort trials by offer and plot epoch-averaged activity against offer.

    For each unit, the mean rate in the pre-offer window (-500 to 0 ms before
    offer onset) is drawn in grey, and the mean rate in `epoch` is drawn in
    colour, one point per offer. Lines connect only the interior offers
    (the first and last offers are drawn as isolated points).

    Parameters
    ----------
    behaviorfile : str
        Behavior file passed to ``indifference_point``.
    activityfile : str
        Activity file readable by ``utils.load``.
    epoch : {'postoffer', 'latedelay', 'prechoice'}
        Epoch plotted alongside the pre-offer baseline.
    offers : sequence
        Offers in x-axis order; each is formatted as ``'{}B:{}A'.format(*offer)``.
    plots : sequence of plots, or str
        If `units` is given, the plots to draw into, one per unit. Otherwise
        a name prefix; one figure per unit is saved.
    units : sequence of int, optional
    network : str
        'p' selects ``r_p``; any other value selects ``r_v``.
    separate_by_choice : bool
        If True, conditions are (offer, choice) for choices 'A' and 'B', and
        a condition is drawn only when it has at least `min_trials` trials.
    **kwargs
        lw (1.5), ms (6), mew (0.5), rotation (60) for tick labels, and
        min_trials (20).

    Raises
    ------
    ValueError
        If `epoch` is not one of 'postoffer', 'latedelay', 'prechoice'.
    """
    # Fail fast, before loading any data
    if epoch not in ('postoffer', 'latedelay', 'prechoice'):
        raise ValueError(epoch)

    # Load trials
    data = utils.load(activityfile)
    trials, U, Z, Z_b, A, P, M, perf, r_p, r_v = data

    if network == 'p':
        print("POLICY NETWORK")
        r = r_p
    else:
        print("VALUE NETWORK")
        r = r_v

    # Number of units
    N = r.shape[-1]

    # Same for every trial
    time = trials[0]['time']
    Ntime = len(time)

    # Aligned time: index Ntime-1 of time_a holds the alignment point
    time_a = np.concatenate((-time[1:][::-1], time))
    Ntime_a = len(time_a)

    #=====================================================================================
    # Sort trials
    #=====================================================================================

    def accumulate(entry, r_n, m_n, t0):
        """Add one trial's masked rates and mask into entry, aligned at t0."""
        # Before the align point
        n_b = r_n[:t0].shape[0]
        entry['r'][Ntime - 1 - n_b:Ntime - 1] += r_n[:t0]
        entry['n'][Ntime - 1 - n_b:Ntime - 1] += m_n[:t0]
        # From the align point on
        n_a = r_n[t0:].shape[0]
        entry['r'][Ntime - 1:Ntime - 1 + n_a] += r_n[t0:]
        entry['n'][Ntime - 1:Ntime - 1 + n_a] += m_n[t0:]

    # Alignment events
    events = ['offer', 'choice']

    events_by_cond = {e: {} for e in events}
    n_by_cond = {}
    n_nondecision = 0
    for n, trial in enumerate(trials):
        choice = perf.choices[n]
        if choice is None:
            n_nondecision += 1
            continue

        # Condition
        offer = trial['offer']
        if separate_by_choice:
            cond = (offer, choice)
        else:
            cond = offer
        n_by_cond[cond] = n_by_cond.get(cond, 0) + 1

        # Firing rates, with the mask M zeroing out excluded time steps
        m_n = np.tile(M[:, n], (N, 1)).T
        r_n = r[:, n] * m_n

        # Align points
        t0_by_event = {'offer': trial['epochs']['offer-on'][0],
                       'choice': perf.t_choices[n]}
        for e in events:
            entry = events_by_cond[e].setdefault(
                cond, {'r': np.zeros((Ntime_a, N)), 'n': np.zeros((Ntime_a, N))})
            accumulate(entry, r_n, m_n, t0_by_event[e])

    print("Non-decision trials: {}/{}".format(n_nondecision, len(trials)))

    # Average trials: summed rates / number of contributing samples
    for e in events_by_cond:
        for cond in events_by_cond[e]:
            entry = events_by_cond[e][cond]
            events_by_cond[e][cond] = utils.div(entry['r'], entry['n'])

    # Epoch windows: (alignment event, start ms inclusive, end ms exclusive)
    epoch_windows = {
        'preoffer':  ('offer', -500, 0),
        'postoffer': ('offer', 0, 500),
        'latedelay': ('offer', 500, 1000),
        'prechoice': ('choice', -500, 0),
    }

    # Average over each epoch window
    epochs_by_cond = {}
    for e, (ev, tmin, tmax) in epoch_windows.items():
        w, = np.where((tmin <= time_a) & (time_a < tmax))
        epochs_by_cond[e] = {cond: np.mean(rates[w], axis=0)
                             for cond, rates in events_by_cond[ev].items()}

    #=====================================================================================
    # Classify units
    #=====================================================================================

    idpt = indifference_point(behaviorfile, offers)
    unit_types = classify_units(trials, perf, r, idpt)

    numbers = {}
    for v in unit_types.values():
        numbers[v] = numbers.get(v, 0) + 1
    # Built-in sum works on Python 3 dict views (np.sum does not), and the
    # float numerator avoids Python 2 integer truncation of the percentage.
    # The loop body never runs when numbers is empty, so n_tot > 0 there.
    n_tot = sum(numbers.values())
    for k, v in numbers.items():
        print("{}: {}/{} = {}%".format(k, v, n_tot, 100.0 * v / n_tot))

    #=====================================================================================
    # Plot
    #=====================================================================================

    lw = kwargs.get('lw', 1.5)
    ms = kwargs.get('ms', 6)
    mew = kwargs.get('mew', 0.5)
    rotation = kwargs.get('rotation', 60)
    min_trials = kwargs.get('min_trials', 20)
    last = len(offers) - 1

    def plot_curve(plot, unit, epoch_by_cond, choice, marker, color, scale, yall):
        """
        Plot one point per offer plus a line through the interior offers.

        With choice=None, conditions are offers and every offer present is
        drawn. Otherwise conditions are (offer, choice) and only those with at
        least min_trials trials are drawn. `scale` multiplies ms, mew, and lw.
        """
        x = []
        y = []
        for i, offer in enumerate(offers):
            if choice is None:
                cond = offer
                # Offers with no decision trials have no average
                if cond not in epoch_by_cond:
                    continue
            else:
                cond = (offer, choice)
                if not (cond in n_by_cond and n_by_cond[cond] >= min_trials):
                    continue
            y_i = epoch_by_cond[cond][unit]
            plot.plot(i, y_i, marker, mfc=color, mec=color,
                      ms=scale * ms, mew=scale * mew, zorder=10)
            yall.append(y_i)
            if 0 < i < last:
                x.append(i)
                y.append(y_i)
        plot.plot(x, y, '-', color=color, lw=scale * lw, zorder=5)

    def plot_activity(plot, unit):
        yall = [1]

        # Pre-offer baseline, grey and slightly smaller
        baseline = epochs_by_cond['preoffer']
        grey = '0.7'
        if separate_by_choice:
            for choice, marker in zip(['A', 'B'], ['d', 'o']):
                plot_curve(plot, unit, baseline, choice, marker, grey, 0.8, yall)
        else:
            plot_curve(plot, unit, baseline, None, 'o', grey, 0.8, yall)

        # Selected epoch
        epoch_by_cond = epochs_by_cond[epoch]
        if separate_by_choice:
            for choice, marker, color in zip(
                    ['A', 'B'], ['d', 'o'],
                    [Figure.colors('red'), Figure.colors('blue')]):
                plot_curve(plot, unit, epoch_by_cond, choice, marker, color, 1, yall)
        else:
            plot_curve(plot, unit, epoch_by_cond, None, 'o',
                       Figure.colors('darkblue'), 1, yall)

        plot.xticks(range(len(offers)))
        plot.xticklabels(['{}B:{}A'.format(*offer) for offer in offers],
                         rotation=rotation)

        plot.xlim(0, len(offers) - 1)
        plot.lim('y', yall, lower=0)

        return yall

    #-------------------------------------------------------------------------------------

    if units is not None:
        for plot, unit in zip(plots, units):
            plot_activity(plot, unit)
    else:
        name = plots
        suffix = '_sbc' if separate_by_choice else ''
        for unit in xrange(N):
            fig = Figure()
            plot = fig.add()

            plot_activity(plot, unit)
            if unit in unit_types:
                plot.text_upper_right(unit_types[unit], fontsize=9)

            fig.save(name + '_{}{}_{}{:03d}'.format(epoch, suffix, network, unit))
            fig.close()
# Recurrent-weight panels: each panel shows a weight matrix via plot_weights
# and annotates it with its spectral radius rho.
rho = matrixtools.spectral_radius(W)
plot_weights(plot, W)
plot.text_upper_left(r'$W_\text{rec}$', fontsize=fontsize, dy=dy)
plot.text_upper_right(r'$\rho={:.3f}$'.format(rho), fontsize=fontsize, dy=dy)
plot.xlabel('$W$')

#for j in xrange(W.shape[1]):
#    print(sum(1*(W[:,j] != 0)))

# Gate weights, with the mask applied when one exists.
# NOTE(review): if params['Wrec_gates'] is a NumPy array, `*=` modifies it in
# place (the mask is applied to params itself) — confirm this is intended.
Wrec_gates = params['Wrec_gates']
if 'Wrec_gates' in masks:
    Wrec_gates *= masks['Wrec_gates']

# First N columns of the gate weights, labelled W_rec^lambda
plot = fig['Wrec_lambda']
W = Wrec_gates[:,:N]
rho = matrixtools.spectral_radius(W)
plot_weights(plot, W)
plot.text_upper_left(r'$W_\text{rec}^\lambda$', fontsize=fontsize, dy=dy)
plot.text_upper_right(r'$\rho={:.3f}$'.format(rho), fontsize=fontsize, dy=dy)

# Remaining columns of the gate weights, labelled W_rec^gamma
plot = fig['Wrec_gamma']
W = Wrec_gates[:,N:]
rho = matrixtools.spectral_radius(W)
plot_weights(plot, W)
plot.text_upper_left(r'$W_\text{rec}^\gamma$', fontsize=fontsize, dy=dy)
plot.text_upper_right(r'$\rho={:.3f}$'.format(rho), fontsize=fontsize, dy=dy)
#'''

#=========================================================================================

fig.save(path=figspath, name='fig_weights_'+modelname)
plot.ylabel('Percent correct\n(decision trials)')

# Reference line at each model's target percent correct (decision trials).
target_color = Figure.colors('red')
plot = fig['correct']

target_by_model = {
    'rdm_rt': 80,
    'mante': 85,
    'multisensory': 82,
    'romo': 97,
    'postdecisionwager': 79,
    'padoaschioppa2006': 95,
}
if modelname.startswith('rdm_fixed') or modelname in target_by_model:
    # Every rdm_fixed* variant shares the 80% target (none are table keys)
    target = target_by_model.get(modelname, 80)
    plot.hline(target, color=target_color, zorder=1)

#=========================================================================================

fig.save(path=figspath, name='fig_learning_' + modelname)
plot.ylabel('Percent correct\n(decision trials)')

# Reference line at each model's target percent correct (decision trials).
target_color = Figure.colors('red')
plot = fig['correct']

target_by_model = {
    'rdm_rt': 80,
    'mante': 85,
    'multisensory': 82,
    'romo': 97,
    'postdecisionwager': 79,
    'padoaschioppa2006': 95,
}
if modelname.startswith('rdm_fixed') or modelname in target_by_model:
    # Every rdm_fixed* variant shares the 80% target (none are table keys)
    target = target_by_model.get(modelname, 80)
    plot.hline(target, color=target_color, zorder=1)

#=========================================================================================

fig.save(path=figspath, name='fig_learning_'+modelname)
def do(action, args, config):
    """
    Manage tasks for the padoaschioppa2006 (economic choice) model.

    Parameters
    ----------
    action : str
        Any action containing 'trials' generates trials and runs them through
        ``runtools.run``. 'choice_pattern' and 'indifference_point' plot
        behavior. 'sort_epoch' plots epoch-averaged unit activity by offer.
    args : sequence of str
        Extra arguments. For 'trials' actions, ``args[0]`` (when present and
        parseable as an int) is the number of trials per condition; otherwise
        100 is used. For 'sort_epoch', ``args[0]`` is the epoch name; 'value'
        selects the value network, and 'separate-by-choice' splits conditions
        by choice.
    config : dict
        Run configuration. Keys read here: 'model', 'savefile', 'seed', 'dt',
        'trialspath', 'figspath'.

    Raises
    ------
    ValueError
        For 'sort_epoch' when no epoch argument is supplied.
    """
    print("ACTION*: " + str(action))
    print("ARGS*: " + str(args))

    #=====================================================================================

    if 'trials' in action:
        try:
            trials_per_condition = int(args[0])
        except (IndexError, ValueError):
            # No count supplied, or not an integer: use the default.
            trials_per_condition = 100
        model = config['model']
        pg = model.get_pg(config['savefile'], config['seed'], config['dt'])

        spec = model.spec
        juices = spec.juices
        offers = spec.offers
        n_conditions = spec.n_conditions
        n_trials = trials_per_condition * n_conditions

        print("{} trials".format(n_trials))
        task = model.Task()
        trials = []
        for n in xrange(n_trials):
            # Cycle through every (juice, offer) combination
            k = tasktools.unravel_index(n, (len(juices), len(offers)))
            context = {
                'juice': juices[k.pop(0)],
                'offer': offers[k.pop(0)]
            }
            trials.append(task.get_condition(pg.rng, pg.dt, context))
        runtools.run(action, trials, pg, config['trialspath'])

    #=====================================================================================

    elif action == 'choice_pattern':
        trialsfile = runtools.behaviorfile(config['trialspath'])

        fig = Figure()
        plot = fig.add()

        spec = config['model'].spec
        choice_pattern(trialsfile, spec.offers, plot)

        # Raw string: same text as before, without invalid '\#' escapes
        plot.xlabel(r'Offer (\#B : \#A)')
        plot.ylabel('Percent choice B')

        plot.text_upper_left('1A = {}B'.format(spec.A_to_B), fontsize=10)

        fig.save(path=config['figspath'], name=action)
        fig.close()
    elif action == 'indifference_point':
        trialsfile = runtools.behaviorfile(config['trialspath'])

        fig = Figure()
        plot = fig.add()

        spec = config['model'].spec
        indifference_point(trialsfile, spec.offers, plot)

        plot.xlabel('$(n_B - n_A)/(n_B + n_A)$')
        plot.ylabel('Percent choice B')

        fig.save(path=config['figspath'], name=action)
        fig.close()

    #=====================================================================================

    elif action == 'sort_epoch':
        if not args:
            raise ValueError("sort_epoch requires an epoch name as the first argument")

        behaviorfile = runtools.behaviorfile(config['trialspath'])
        activityfile = runtools.activityfile(config['trialspath'])

        epoch = args[0]

        if 'value' in args:
            network = 'v'
        else:
            network = 'p'

        separate_by_choice = ('separate-by-choice' in args)

        sort_epoch(behaviorfile, activityfile, epoch, config['model'].spec.offers,
                   os.path.join(config['figspath'], 'sorted'),
                   network=network, separate_by_choice=separate_by_choice)
def sort_epoch(behaviorfile, activityfile, epoch, offers, plots, units=None,
               network='p', separate_by_choice=False, **kwargs):
    """
    Sort trials by offer (optionally by offer and choice) and plot the
    epoch-averaged firing rate of each unit.

    Parameters
    ----------
    behaviorfile : str
        Behavior file; used to compute the indifference point for unit
        classification.
    activityfile : str
        Activity file loaded with ``utils.load``.
    epoch : str
        Epoch to plot on top of the grey 'preoffer' baseline. Must be one of
        'postoffer', 'latedelay' or 'prechoice' (ValueError otherwise).
    offers : sequence
        Offers in x-axis order; each is formatted as ``'{}B:{}A'``.
    plots : sequence of plot objects, or str
        If `units` is given, the plots to draw into (paired with `units`).
        Otherwise a filename prefix; one figure per unit is saved.
    units : sequence of int, optional
        Units to plot into `plots`.
    network : str
        'p' for the policy network; anything else selects the value network.
    separate_by_choice : bool
        If True, conditions are (offer, choice) pairs. Only pairs with at least
        `min_trials` trials are plotted.
    **kwargs
        Style options: lw, ms, mew, rotation, and min_trials (default 20).
    """
    # Load trials
    data = utils.load(activityfile)
    trials, U, Z, Z_b, A, P, M, perf, r_p, r_v = data

    if network == 'p':
        print("POLICY NETWORK")
        r = r_p
    else:
        print("VALUE NETWORK")
        r = r_v

    # Number of units
    N = r.shape[-1]

    # Same for every trial
    time = trials[0]['time']
    Ntime = len(time)

    # Aligned time axis: activity is accumulated so that index Ntime-1 of this
    # axis is the alignment point.
    time_a = np.concatenate((-time[1:][::-1], time))
    Ntime_a = len(time_a)

    #=====================================================================================
    # Sort trials
    #=====================================================================================

    # Events to align to
    events = ['offer', 'choice']

    events_by_cond = {e: {} for e in events}
    n_by_cond = {}
    n_nondecision = 0
    for n, trial in enumerate(trials):
        if perf.choices[n] is None:
            n_nondecision += 1
            continue

        # Condition
        offer = trial['offer']
        choice = perf.choices[n]
        if separate_by_choice:
            cond = (offer, choice)
        else:
            cond = offer
        n_by_cond[cond] = n_by_cond.get(cond, 0) + 1

        # Firing rates weighted by M (presumably a validity mask for trial n --
        # TODO confirm); the weights are also summed so they can be divided out.
        m_n = np.tile(M[:, n], (N, 1)).T
        r_n = r[:, n]*m_n

        for e in events:
            # Align point
            if e == 'offer':
                t0 = trial['epochs']['offer-on'][0]
            elif e == 'choice':
                t0 = perf.t_choices[n]
            else:
                raise ValueError(e)

            # Running sums of weighted rates ('r') and weights ('n')
            acc = events_by_cond[e].setdefault(
                cond, {'r': np.zeros((Ntime_a, N)), 'n': np.zeros((Ntime_a, N))})

            # Before the alignment point
            n_b = r_n[:t0].shape[0]
            acc['r'][Ntime-1-n_b:Ntime-1] += r_n[:t0]
            acc['n'][Ntime-1-n_b:Ntime-1] += m_n[:t0]

            # From the alignment point on
            n_a = r_n[t0:].shape[0]
            acc['r'][Ntime-1:Ntime-1+n_a] += r_n[t0:]
            acc['n'][Ntime-1:Ntime-1+n_a] += m_n[t0:]
    print("Non-decision trials: {}/{}".format(n_nondecision, len(trials)))

    # Average trials
    for e in events_by_cond:
        for cond in events_by_cond[e]:
            events_by_cond[e][cond] = utils.div(events_by_cond[e][cond]['r'],
                                                events_by_cond[e][cond]['n'])

    # Epochs: (aligning event, [start, end) window on the aligned time axis;
    # presumably ms -- TODO confirm units of trial['time'])
    epoch_windows = {
        'preoffer':  ('offer',  -500,    0),
        'postoffer': ('offer',     0,  500),
        'latedelay': ('offer',   500, 1000),
        'prechoice': ('choice', -500,    0),
    }

    # Average epochs
    epochs_by_cond = {}
    for e, (ev, start, end) in epoch_windows.items():
        w, = np.where((start <= time_a) & (time_a < end))
        epochs_by_cond[e] = {cond: np.mean(events_by_cond[ev][cond][w], axis=0)
                             for cond in events_by_cond[ev]}

    #=====================================================================================
    # Classify units
    #=====================================================================================

    idpt = indifference_point(behaviorfile, offers)
    unit_types = classify_units(trials, perf, r, idpt)

    numbers = {}
    for v in unit_types.values():
        numbers[v] = numbers.get(v, 0) + 1
    # Built-in sum works on dict views in both Python 2 and 3; float percentage
    # avoids integer truncation under Python 2 division.
    n_tot = sum(numbers.values())
    for k, v in numbers.items():
        print("{}: {}/{} = {}%".format(k, v, n_tot, 100.0*v/n_tot))

    #=====================================================================================
    # Plot
    #=====================================================================================

    lw = kwargs.get('lw', 1.5)
    ms = kwargs.get('ms', 6)
    mew = kwargs.get('mew', 0.5)
    rotation = kwargs.get('rotation', 60)
    min_trials = kwargs.get('min_trials', 20)

    def points_for(choice):
        """Return (x index, condition key) pairs to plot; choice=None means unseparated."""
        if choice is None:
            # Skip offers that had no decision trials instead of raising KeyError.
            return [(i, offer) for i, offer in enumerate(offers) if offer in n_by_cond]
        points = []
        for i, offer in enumerate(offers):
            cond = (offer, choice)
            if cond in n_by_cond and n_by_cond[cond] >= min_trials:
                points.append((i, cond))
        return points

    def plot_series(plot, unit, epoch_by_cond, points, marker, color, scale, yall):
        """Plot one marker per point and a line through the interior offers."""
        x = []
        y = []
        for i, cond in points:
            y_i = epoch_by_cond[cond][unit]
            plot.plot(i, y_i, marker, mfc=color, mec=color,
                      ms=scale*ms, mew=scale*mew, zorder=10)
            yall.append(y_i)
            # First and last offers get markers only, no connecting line.
            if i != 0 and i != len(offers)-1:
                x.append(i)
                y.append(y_i)
        plot.plot(x, y, '-', color=color, lw=scale*lw, zorder=5)

    def plot_activity(plot, unit):
        yall = [1]

        # Pre-offer baseline: grey, slightly smaller
        epoch_by_cond = epochs_by_cond['preoffer']
        color = '0.7'
        if separate_by_choice:
            for choice, marker in zip(['A', 'B'], ['d', 'o']):
                plot_series(plot, unit, epoch_by_cond, points_for(choice),
                            marker, color, 0.8, yall)
        else:
            plot_series(plot, unit, epoch_by_cond, points_for(None),
                        'o', color, 0.8, yall)

        # Requested epoch
        epoch_by_cond = epochs_by_cond[epoch]
        if epoch not in ('postoffer', 'latedelay', 'prechoice'):
            raise ValueError(epoch)
        if separate_by_choice:
            for choice, marker, color in zip(['A', 'B'], ['d', 'o'],
                                             [Figure.colors('red'),
                                              Figure.colors('blue')]):
                plot_series(plot, unit, epoch_by_cond, points_for(choice),
                            marker, color, 1, yall)
        else:
            plot_series(plot, unit, epoch_by_cond, points_for(None),
                        'o', Figure.colors('darkblue'), 1, yall)

        plot.xticks(range(len(offers)))
        plot.xticklabels(['{}B:{}A'.format(*offer) for offer in offers],
                         rotation=rotation)

        plot.xlim(0, len(offers)-1)
        plot.lim('y', yall, lower=0)

        return yall

    #-------------------------------------------------------------------------------------

    if units is not None:
        for plot, unit in zip(plots, units):
            plot_activity(plot, unit)
    else:
        name = plots
        if separate_by_choice:
            suffix = '_sbc'
        else:
            suffix = ''
        for unit in xrange(N):
            fig = Figure()
            plot = fig.add()

            plot_activity(plot, unit)

            if unit in unit_types:
                plot.text_upper_right(unit_types[unit], fontsize=9)

            fig.save(name+'_{}{}_{}{:03d}'.format(epoch, suffix, network, unit))
            fig.close()
def do(action, args, config):
    """
    Manage tasks.

    Dispatch on `action`:

      * any action containing 'trials': build `trials_per_condition` trials per
        condition (first entry of `args`; default 100 when missing or not an
        integer) over the (juice, offer) grid and run them with `runtools.run`.
      * 'choice_pattern': plot the choice pattern from the behavior file.
      * 'indifference_point': plot the indifference-point analysis.
      * 'sort_epoch': plot epoch-averaged unit activity. `args[0]` is the epoch
        name, 'value' in `args` selects the value network (otherwise the policy
        network), and 'separate-by-choice' in `args` splits conditions by choice.

    Figures are written to ``config['figspath']``.
    """
    print("ACTION*: " + str(action))
    print("ARGS*: " + str(args))

    #=====================================================================================

    if 'trials' in action:
        try:
            trials_per_condition = int(args[0])
        except (IndexError, ValueError):
            # No count given, or it is not an integer: fall back to the default.
            trials_per_condition = 100

        model = config['model']
        pg = model.get_pg(config['savefile'], config['seed'], config['dt'])

        spec = model.spec
        juices = spec.juices
        offers = spec.offers
        n_conditions = spec.n_conditions
        n_trials = trials_per_condition * n_conditions

        print("{} trials".format(n_trials))
        task = model.Task()
        trials = []
        for n in xrange(n_trials):
            # Map trial index n to (juice index, offer index) via the project
            # helper; presumably it wraps n around the grid -- TODO confirm.
            k = tasktools.unravel_index(n, (len(juices), len(offers)))
            context = {'juice': juices[k.pop(0)], 'offer': offers[k.pop(0)]}
            trials.append(task.get_condition(pg.rng, pg.dt, context))
        runtools.run(action, trials, pg, config['trialspath'])

    #=====================================================================================

    elif action == 'choice_pattern':
        trialsfile = runtools.behaviorfile(config['trialspath'])

        fig = Figure()
        plot = fig.add()

        spec = config['model'].spec
        choice_pattern(trialsfile, spec.offers, plot)

        # Raw string: '\#' is a LaTeX escape, not a Python escape sequence.
        plot.xlabel(r'Offer (\#B : \#A)')
        plot.ylabel('Percent choice B')

        plot.text_upper_left('1A = {}B'.format(spec.A_to_B), fontsize=10)

        fig.save(path=config['figspath'], name=action)
        fig.close()
    elif action == 'indifference_point':
        trialsfile = runtools.behaviorfile(config['trialspath'])

        fig = Figure()
        plot = fig.add()

        spec = config['model'].spec
        indifference_point(trialsfile, spec.offers, plot)

        plot.xlabel('$(n_B - n_A)/(n_B + n_A)$')
        plot.ylabel('Percent choice B')

        fig.save(path=config['figspath'], name=action)
        fig.close()

    #=====================================================================================

    elif action == 'sort_epoch':
        behaviorfile = runtools.behaviorfile(config['trialspath'])
        activityfile = runtools.activityfile(config['trialspath'])

        epoch = args[0]

        if 'value' in args:
            network = 'v'
        else:
            network = 'p'

        separate_by_choice = ('separate-by-choice' in args)

        sort_epoch(behaviorfile, activityfile, epoch, config['model'].spec.offers,
                   os.path.join(config['figspath'], 'sorted'),
                   network=network, separate_by_choice=separate_by_choice)
rho = matrixtools.spectral_radius(W) plot_weights(plot, W) plot.text_upper_left(r'$W_\text{rec}$', fontsize=fontsize, dy=dy) plot.text_upper_right(r'$\rho={:.3f}$'.format(rho), fontsize=fontsize, dy=dy) plot.xlabel('$W$') #for j in xrange(W.shape[1]): # print(sum(1*(W[:,j] != 0))) Wrec_gates = params['Wrec_gates'] if 'Wrec_gates' in masks: Wrec_gates *= masks['Wrec_gates'] plot = fig['Wrec_lambda'] W = Wrec_gates[:, :N] rho = matrixtools.spectral_radius(W) plot_weights(plot, W) plot.text_upper_left(r'$W_\text{rec}^\lambda$', fontsize=fontsize, dy=dy) plot.text_upper_right(r'$\rho={:.3f}$'.format(rho), fontsize=fontsize, dy=dy) plot = fig['Wrec_gamma'] W = Wrec_gates[:, N:] rho = matrixtools.spectral_radius(W) plot_weights(plot, W) plot.text_upper_left(r'$W_\text{rec}^\gamma$', fontsize=fontsize, dy=dy) plot.text_upper_right(r'$\rho={:.3f}$'.format(rho), fontsize=fontsize, dy=dy) #''' #========================================================================================= fig.save(path=figspath, name='fig_weights_' + modelname)
def do(action, args, config):
    """
    Manage tasks.

    Dispatch on `action`:

      * 'plot_trial': load the network and draw a trial/performance figure.
      * any other action containing 'trials': build `trials_per_condition`
        trials per condition (first entry of `args`; default 100) over the
        (juice, offer) grid and run them with `runtools.run`.
      * 'choice_pattern': hand the behavior file, offers, figure directory and
        action name to `choice_pattern`.
      * 'sort': plot sorted activity; 'value' in `args` selects the value
        network, otherwise the policy network.
      * 'statespace': state-space plot of the activity file.

    A missing or non-integer trial count falls back to the branch default.
    """
    print("ACTION*: " + str(action))
    print("ARGS*: " + str(args))

    #=====================================================================================

    if action == 'plot_trial':
        try:
            trials_per_condition = int(args[0])
        except (IndexError, ValueError):
            trials_per_condition = 1000

        model = config['model']
        pg = model.get_pg(config['savefile'], config['seed'], config['dt'])

        spec = model.spec
        juices = spec.juices
        offers = spec.offers
        n_conditions = spec.n_conditions
        n_trials = trials_per_condition * n_conditions

        print("{} trials".format(n_trials))
        task = model.Task()

        # NOTE(review): pg, juices, offers and task are never used below,
        # plot_trial() is called without arguments and the figure is saved as
        # 'performance' -- this branch looks unfinished; confirm intent.
        fig = Figure(axislabelsize=10, ticklabelsize=9)
        plot = fig.add()
        plot_trial()
        performance(config['savefile'], plot)
        fig.save(path=config['figspath'], name='performance')
        fig.close()

    #=====================================================================================

    elif 'trials' in action:
        try:
            trials_per_condition = int(args[0])
        except (IndexError, ValueError):
            trials_per_condition = 100

        model = config['model']
        pg = model.get_pg(config['savefile'], config['seed'], config['dt'])

        spec = model.spec
        juices = spec.juices
        offers = spec.offers
        n_conditions = spec.n_conditions
        n_trials = trials_per_condition * n_conditions

        print("{} trials".format(n_trials))
        task = model.Task()
        trials = []
        for n in xrange(n_trials):
            # Map trial index n to (juice index, offer index) via the project helper.
            k = tasktools.unravel_index(n, (len(juices), len(offers)))
            context = {'juice': juices[k.pop(0)], 'offer': offers[k.pop(0)]}
            trials.append(task.get_condition(pg.rng, pg.dt, context))
        runtools.run(action, trials, pg, config['trialspath'])

    #=====================================================================================

    elif action == 'choice_pattern':
        trialsfile = runtools.behaviorfile(config['trialspath'])
        savefile = config['figspath']
        spec = config['model'].spec
        # choice_pattern receives the figure directory and action name rather
        # than a plot object (presumably it creates and saves its own figure).
        choice_pattern(trialsfile, spec.offers, savefile, action)

    #=====================================================================================

    elif action == 'sort':
        if 'value' in args:
            network = 'v'
        else:
            network = 'p'

        trialsfile = runtools.activityfile(config['trialspath'])
        sort(trialsfile, (config['figspath'], 'sorted'), network=network)

    #=====================================================================================

    elif action == 'statespace':
        trialsfile = runtools.activityfile(config['trialspath'])
        statespace(trialsfile, (config['figspath'], 'statespace'))
plot.xlim(-200, 600) plot.xticks([-200, 0, 200, 400, 600]) plot.ylim(0.5, 1.2) #plot.yticks([0.5, 1]) plot.xlabel('Time from stimulus (ms)') plot.ylabel('Expected reward') # Legend props = { 'prop': { 'size': 8 }, 'handlelength': 1.2, 'handletextpad': 1.1, 'labelspacing': 0.7 } plot.legend(bbox_to_anchor=(0.33, 1), **props) plot = fig['on-choice'] plot.xlim(-400, 0) plot.xticks([-400, -200, 0]) plot.ylim(0.5, 1.2) #plot.yticks([0.5, 1]) plot.xlabel('Time from decision (ms)') #========================================================================================= fig.save()
def do(action, args, config):
    """
    Manage tasks.

    Dispatch on `action`:

      * any action containing 'trials': build `trials_per_condition` trials per
        condition (first entry of `args`; default 100 when missing or not an
        integer) over the (gt_lt, fpair) grid with a fixed delay of 3000
        (presumably ms), and run them with `runtools.run`.
      * 'performance': plot performance from the behavior file.
      * 'sort': plot sorted activity; 'value' in `args` selects the value
        network, otherwise the policy network.
    """
    print("ACTION*: " + str(action))
    print("ARGS*: " + str(args))

    if 'trials' in action:
        try:
            trials_per_condition = int(args[0])
        except (IndexError, ValueError):
            # Narrowed from a bare except so real errors are not swallowed.
            trials_per_condition = 100

        model = config['model']
        pg = model.get_pg(config['savefile'], config['seed'], config['dt'])

        spec = model.spec
        gt_lts = spec.gt_lts
        fpairs = spec.fpairs
        n_conditions = spec.n_conditions
        n_trials = trials_per_condition * n_conditions

        print("{} trials".format(n_trials))
        task = model.Task()
        trials = []
        for n in xrange(n_trials):
            # Map trial index n to (gt_lt index, fpair index) via the project helper.
            k = tasktools.unravel_index(n, (len(gt_lts), len(fpairs)))
            context = {
                'delay': 3000,
                'gt_lt': gt_lts[k.pop(0)],
                'fpair': fpairs[k.pop(0)]
            }
            trials.append(task.get_condition(pg.rng, pg.dt, context))
        runtools.run(action, trials, pg, config['trialspath'])

    #=====================================================================================

    elif action == 'performance':
        trialsfile = runtools.behaviorfile(config['trialspath'])

        fig = Figure()
        plot = fig.add()
        performance(trialsfile, plot)

        plot.xlabel('$f_1$ (Hz)')
        plot.ylabel('$f_2$ (Hz)')

        fig.save(os.path.join(config['figspath'], action))
        # Release the figure, as every other plotting branch in this file does.
        fig.close()

    #=====================================================================================

    elif action == 'sort':
        if 'value' in args:
            network = 'v'
        else:
            network = 'p'

        trialsfile = runtools.activityfile(config['trialspath'])
        sort(trialsfile, (config['figspath'], 'sorted'), network=network)
def sort(trialsfile, plots, units=None, network='p', **kwargs):
    """
    Sort trials by (modality, choice) and plot stimulus-aligned firing rates.

    Only correct decision trials with modality 'v' or 'a' are used; 'va' trials
    are skipped.

    Parameters
    ----------
    trialsfile : str
        Activity file loaded with ``utils.load``.
    plots : sequence of plot objects, or (figspath, name)
        If `units` is given, the plots to draw into (paired with `units`).
        Otherwise a (directory, filename prefix) pair; one figure is saved per
        unit.
    units : sequence of int, optional
        Units to plot into `plots`.
    network : str
        'p' for the policy network; anything else selects the value network.
    **kwargs
        Style options: lw (default 1.5), dashes (default [3, 2]).

    Raises
    ------
    ValueError
        If a trial has a modality other than 'v', 'a' or 'va'.
    """
    # Load trials
    data = utils.load(trialsfile)
    trials, U, Z, Z_b, A, P, M, perf, r_p, r_v = data

    # Which network?
    if network == 'p':
        r = r_p
    else:
        r = r_v

    # Number of units
    N = r.shape[-1]

    # Same for every trial
    time = trials[0]['time']
    Ntime = len(time)

    # Aligned time axis: activity is accumulated so that index Ntime-1 of this
    # axis is the alignment point.
    time_a = np.concatenate((-time[1:][::-1], time))
    Ntime_a = len(time_a)

    #=====================================================================================
    # Aligned to stimulus onset
    #=====================================================================================

    r_by_cond_stimulus = {}
    n_r_by_cond_stimulus = {}
    for n, trial in enumerate(trials):
        if not perf.decisions[n]:
            continue

        if trial['mod'] == 'va':
            continue
        if trial['mod'] not in ('v', 'a'):
            raise ValueError("unexpected modality: {}".format(trial['mod']))

        if not perf.corrects[n]:
            continue

        # Condition
        mod = trial['mod']
        choice = perf.choices[n]
        cond = (mod, choice)

        # Storage: running sums of weighted rates and of weights
        r_by_cond_stimulus.setdefault(cond, np.zeros((Ntime_a, N)))
        n_r_by_cond_stimulus.setdefault(cond, np.zeros((Ntime_a, N)))

        # Firing rates weighted by M (presumably a validity mask -- TODO confirm)
        Mn = np.tile(M[:, n], (N, 1)).T
        Rn = r[:, n]*Mn

        # Align point: one step before the first entry of the 'stimulus' epoch
        t0 = trial['epochs']['stimulus'][0] - 1

        # Before
        n_b = Rn[:t0].shape[0]
        r_by_cond_stimulus[cond][Ntime-1-n_b:Ntime-1] += Rn[:t0]
        n_r_by_cond_stimulus[cond][Ntime-1-n_b:Ntime-1] += Mn[:t0]

        # After
        n_a = Rn[t0:].shape[0]
        r_by_cond_stimulus[cond][Ntime-1:Ntime-1+n_a] += Rn[t0:]
        n_r_by_cond_stimulus[cond][Ntime-1:Ntime-1+n_a] += Mn[t0:]

    for cond in r_by_cond_stimulus:
        r_by_cond_stimulus[cond] = utils.div(r_by_cond_stimulus[cond],
                                             n_r_by_cond_stimulus[cond])

    #-------------------------------------------------------------------------------------
    # Plot
    #-------------------------------------------------------------------------------------

    lw = kwargs.get('lw', 1.5)
    dashes = kwargs.get('dashes', [3, 2])

    colors_by_mod = {
        'v': Figure.colors('blue'),
        'a': Figure.colors('green')
    }
    linestyle_by_choice = {
        'L': '-',
        'H': '--'
    }
    mod_labels = {'v': 'Vis, ', 'a': 'Aud, '}
    choice_labels = {'H': 'high', 'L': 'low'}

    def plot_sorted(plot, unit, w, r_sorted):
        t = time_a[w]
        yall = [[1]]
        for cond in [('v', 'H'), ('v', 'L'), ('a', 'H'), ('a', 'L')]:
            # A condition with no correct trials has no entry; skip it.
            if cond not in r_sorted:
                continue
            mod, choice = cond
            label = mod_labels[mod] + choice_labels[choice]

            linestyle = linestyle_by_choice[choice]
            if linestyle == '-':
                lineprops = dict(linestyle=linestyle, lw=lw)
            else:
                lineprops = dict(linestyle=linestyle, lw=lw, dashes=dashes)
            plot.plot(t, r_sorted[cond][w, unit],
                      color=colors_by_mod[mod],
                      label=label,
                      **lineprops)
            yall.append(r_sorted[cond][w, unit])

        return t, yall

    def on_stimulus(plot, unit):
        w, = np.where((time_a >= -300) & (time_a <= 1000))
        t, yall = plot_sorted(plot, unit, w, r_by_cond_stimulus)

        plot.xlim(t[0], t[-1])

        return yall

    if units is not None:
        for plot, unit in zip(plots, units):
            on_stimulus(plot, unit)
    else:
        figspath, name = plots
        for unit in xrange(N):
            fig = Figure()
            plot = fig.add()

            #-----------------------------------------------------------------------------

            yall = []
            yall += on_stimulus(plot, unit)

            plot.lim('y', yall, lower=0)
            plot.vline(0)

            plot.xlabel('Time (ms)')
            plot.ylabel('Firing rate (a.u.)')

            #-----------------------------------------------------------------------------

            fig.save(path=figspath, name=name+'_{}{:03d}'.format(network, unit))
            fig.close()