def distinct_color_map(size):
    """Return a list of named matplotlib colours for line/scatter plots.

    The 10-entry Tableau palette is used when it has enough colours for
    ``size`` series; otherwise the much larger XKCD palette is used.

    Args:
        size: Number of series that need a colour.

    Returns:
        list[str]: Colour names usable as matplotlib ``color=`` values. The
        list may be longer than ``size``; callers pick the first entries.
    """
    palette = XKCD_COLORS if size > len(TABLEAU_COLORS) else TABLEAU_COLORS
    # A real list (not a dict_keys view) so callers can index/slice it.
    return list(palette)
def plot_embedding(
    graph: EnsmallenGraph,
    tsne_embedding: np.ndarray,
    k: int = 10,
    axes: Axes = None
):
    """Scatter-plot a 2D node embedding, coloured by node type.

    Args:
        graph: Graph whose nodes were embedded.
        tsne_embedding: Array with one 2D row per node (it is transposed and
            unpacked into x/y for ``scatter``).
        k: Number of most common node types to keep when the graph has
            node types; only nodes of those types are plotted.
        axes: Axes to draw on; a new 5x5 figure is created when None.

    Returns:
        The Axes that was drawn on.
    """
    if axes is None:
        _, axes = plt.subplots(figsize=(5, 5))

    if graph.node_types_mapping is None:
        node_types = np.zeros(graph.get_nodes_number(), dtype=np.uint8)
        common_node_types_names = ["No node type provided"]
    else:
        nodes, node_types = graph.get_top_k_nodes_by_node_type(k)
        tsne_embedding = tsne_embedding[nodes]
        # Remap arbitrary type ids to contiguous codes 0..n-1 (sorted by id),
        # so that code i maps exactly onto colour i and name i. Raw ids would
        # be min/max-normalised by scatter and distinct types could collapse
        # onto the same colormap entry.
        unique_types, node_types = np.unique(node_types, return_inverse=True)
        common_node_types_names = list(
            np.array(graph.node_types_reverse_mapping)[unique_types]
        )

    palette = list(TABLEAU_COLORS.keys())
    # Cycle the palette so more than len(palette) types still get a colour.
    colors = [palette[i % len(palette)] for i in range(len(common_node_types_names))]

    scatter = axes.scatter(
        *tsne_embedding.T,
        s=0.25,
        c=node_types,
        cmap=ListedColormap(colors)
    )
    # num=None -> one legend handle per unique code, even above 9 types.
    axes.legend(
        handles=scatter.legend_elements(num=None)[0],
        labels=common_node_types_names
    )
    return axes
def show_measurement_1d(measurements_or_func, figure=None, throttling=False, **kwargs):
    """Plot 1D measurements as lines, optionally returning a refresh callback.

    Args:
        measurements_or_func: Either an iterable of measurements or a
            zero-argument callable returning one. Each measurement must
            expose ``array`` and ``calibrations[0]`` (with ``offset`` and
            ``sampling``). When a callable is given it is called now and
            again by the returned callback.
        figure: Figure to draw on. When None, a new figure is created with
            fixed margins and a 300x250 px layout.
        throttling: Forwarded to ``throttle`` to rate-limit the callback.
        **kwargs: Forwarded to ``plt.plot``.

    Returns:
        ``figure`` when measurements were passed directly, otherwise
        ``(figure, callback)`` where ``callback(*args, **kwargs)`` re-queries
        ``measurements_or_func`` and updates the plotted lines.
    """
    def _coordinates(calibration, n):
        # n samples starting at ``offset`` and spaced exactly ``sampling``
        # apart; endpoint=False so the last sample is offset + (n-1)*sampling.
        return np.linspace(calibration.offset,
                           calibration.offset + n * calibration.sampling,
                           n, endpoint=False)

    if figure is None:
        figure = plt.figure(fig_margin={'top': 0, 'bottom': 50, 'left': 50, 'right': 0})
        figure.layout.height = '250px'
        figure.layout.width = '300px'

    # A TypeError on call means we were handed the measurements themselves.
    try:
        measurements = measurements_or_func()
        return_callback = True
    except TypeError:
        measurements = measurements_or_func
        return_callback = False

    lines = []
    for measurement, color in zip(measurements, TABLEAU_COLORS.values()):
        calibration = measurement.calibrations[0]
        array = measurement.array
        line = plt.plot(_coordinates(calibration, len(array)), array, colors=[color], **kwargs)
        figure.axes[0].label = format_label(calibration)
        lines.append(line)

    if not return_callback:
        return figure

    @throttle(throttling)
    def callback(*args, **kwargs):
        # Use each refreshed measurement's own calibration and length rather
        # than whatever was left over from the plotting loop above.
        for line, refreshed in zip(lines, measurements_or_func()):
            new_array = refreshed.array
            line.x = _coordinates(refreshed.calibrations[0], len(new_array))
            line.y = new_array

    return figure, callback
def plot_data(data_table: pd.DataFrame, columns: List[str], x_column: str, title: str,
              colors: List = None, shareX: bool = True, shareY: bool = False):
    """Plot each of ``columns`` against ``x_column`` in its own stacked subplot.

    Args:
        data_table: Source data.
        columns: Column names to plot, one subplot row each.
        x_column: Column used as the shared x values.
        title: Figure super-title.
        colors: One colour per column. If None or shorter than ``columns``,
            the Tableau palette is used instead; colours cycle when there
            are more columns than colours.
        shareX: Share the x axis between subplots.
        shareY: Share the y axis between subplots.

    The x range of every subplot is ``(0, max(x_column))``. The figure is
    displayed with ``plt.show()``.
    """
    # squeeze=False always yields a 2D array, so a single column still
    # gives an indexable sequence of Axes.
    fig, axs = plt.subplots(nrows=len(columns), sharex=shareX, sharey=shareY, squeeze=False)
    axs = axs[:, 0]
    fig.suptitle(title, fontsize=16)

    if colors is None or len(colors) < len(columns):
        colors = list(TABLEAU_COLORS.values())

    x_values = data_table[x_column]
    x_max = x_values.max()
    for index, (ax, col) in enumerate(zip(axs, columns)):
        ax.plot(x_values, data_table[col], color=colors[index % len(colors)])
        ax.set(ylabel=col, xlim=(0, x_max))
    plt.show()
def plot_train_data_separate_by_param(self, param_name, filter_dict=None, show=True):
    """Plot every run's training energy, coloured by one config parameter.

    All runs sharing the same value of ``param_name`` get the same colour,
    and a legend entry ``"<param_name>: <value>"`` is added per value.

    Args:
        param_name: Key looked up in each config of ``self.config_dict``.
        filter_dict: Passed to ``self.filter_config``; runs it rejects are
            skipped. Defaults to an empty dict (a fresh one per call).
        show: Call ``plt.show()`` at the end.
    """
    if filter_dict is None:
        filter_dict = {}
    # NOTE(review): assumes COLORS is a mapping whose keys are matplotlib
    # colour specs — confirm against its definition.
    colorlist = list(COLORS.keys())

    # Ordered list (not a set) because config values may be unhashable.
    param_values = []
    for c_idx, config in self.config_dict.items():
        if not self.filter_config(c_idx, filter_dict):
            continue
        value = config[param_name]
        if value not in param_values:
            param_values.append(value)
        # Cycle through the palette when there are more values than colours.
        color = colorlist[param_values.index(value) % len(colorlist)]
        path = op.join(config['path'], 'train_data.hdf5')
        data_dict = utl.load_dict_h5py(path)
        plt.plot(data_dict['energy'], color=color)

    legend_handles = [
        mlines.Line2D([], [], color=colorlist[i % len(colorlist)],
                      label=f'{param_name}: {value}')
        for i, value in enumerate(param_values)
    ]
    # Use the labels stored on the handles ("<param>: <value>").
    plt.legend(handles=legend_handles)
    if show:
        plt.show()
m = {i: (0, 0, 0) for i in range(k)} for (x, y), i in cluster.items(): m[i] = (m[i][0]+x, m[i][1]+y, m[i][2]+1) return [(m[i][0]/m[i][2], m[i][1]/m[i][2]) if m[i][2] else (float('inf'), float('inf')) for i in m] def gen(n, s): from random import random seed(s) return [(random(), random()) for i in range(n)] if __name__ == '__main__': from matplotlib import pyplot as plt from matplotlib.colors import TABLEAU_COLORS n, k, s = 200, 5, 0 pts = gen(n, s) # clusters = k_means(pts, k, s=s) clusters = k_means(pts, k, clusters=k_means_pp_init_clusters(pts, k, s)) colors = list(TABLEAU_COLORS.values()) for i in range(k): pi = [p for p, j in clusters.items() if i == j] plt.scatter([x for x, _ in pi], [y for _, y in pi], c=colors[i]) plt.gca().set_aspect('equal') plt.tight_layout() plt.savefig('k_means_pp_ex1.png', bbox_inches='tight') plt.show()
def _typology_marker(typo):
    """Return the scatter marker used for a site typology name."""
    if typo == "Urban":
        return "*"
    if typo in ("Valley", "Urban+Alps"):
        return "o"
    if typo == "Traffic":
        return "p"
    return "d"


def plot_relativeMass(PMF_profile, source="Biomass burning", isRelativeMass=True,
                      totalVar="PM10", naxe=1, site_typologie=None):
    """Plot one PMF factor's per-species mass across stations.

    Each species gets a boxplot over stations, overlaid with one marker per
    station (marker shape = typology, colour = station within its typology),
    on a log y axis.

    Parameters
    ----------
    PMF_profile : pd.DataFrame
        Factor profiles. ``reset_index()`` must expose ``station`` and
        ``specie`` columns and ``source`` must be a column; the index must
        have a ``station`` level (presumably a (station, specie) MultiIndex
        — TODO confirm).
    source : str
        Factor column to plot; also used as the figure title.
    isRelativeMass : bool
        If True, profile values are used as-is; otherwise they are converted
        with ``to_relativeMass(..., totalVar=totalVar)``.
    totalVar : str
        Forwarded to ``to_relativeMass`` when ``isRelativeMass`` is False.
    naxe : int
        1 for a single panel with every species, 2 to split the species over
        two panels with different y ranges.
    site_typologie : dict, optional
        Typology name -> list of station names. Defaults to
        ``get_site_typology()``.

    Raises
    ------
    ValueError
        If ``naxe`` is neither 1 nor 2.
    """
    if naxe not in (1, 2):
        raise ValueError("naxe must be 1 or 2, got {}".format(naxe))
    if not site_typologie:
        site_typologie = get_site_typology()

    carboneous = ["OC*", "EC"]
    ions = ["Cl-", "NO3-", "SO42-", "Na+", "NH4+", "K+", "Mg2+", "Ca2+"]
    organics = ["MSA", "Polyols", "Levoglucosan", "Mannosan"]
    metals = [
        "Al", "As", "Ba", "Cd", "Co", "Cr", "Cs", "Cu", "Fe", "La", "Mn",
        "Mo", "Ni", "Pb", "Rb", "Sb", "Se", "Sn", "Sr", "Ti", "V", "Zn"
    ]
    keep_index = ["PM10"] + carboneous + ions + organics + metals

    if naxe == 1:
        xtick_list = [keep_index]
    else:
        keep_index1 = carboneous + [
            "NO3-", "SO42-", "NH4+", "Na+", "K+", "Ca2+", "Mg2+", "Cl-",
            "Polyols", "Fe", "Al", "Levoglucosan", "Mannosan", "MSA"
        ]
        # Second panel: every remaining species except PM10 and Mg2+, sorted.
        keep_index2 = sorted(set(keep_index) - set(keep_index1) - {"PM10", "Mg2+"})
        xtick_list = [keep_index1, keep_index2]

    # One row per station, one column per species.
    relative_mass = pd.DataFrame(columns=keep_index)
    for station, df in PMF_profile.reset_index().groupby("station"):
        df = df.set_index("specie").drop("station", axis=1)
        if isRelativeMass:
            relative_mass.loc[station, :] = df.loc[:, source]
        else:
            relative_mass.loc[station, :] = to_relativeMass(df.loc[:, source],
                                                            totalVar=totalVar).T
    # Row-wise enlargement leaves object columns. DataFrame.convert_objects is
    # gone from pandas; to_numeric(errors="coerce") does the same conversion.
    relative_mass = relative_mass.apply(pd.to_numeric, errors="coerce")

    # FIGURE
    fig, axes = plt.subplots(nrows=naxe, ncols=1, figsize=(12.67, 6.54))
    if naxe == 1:
        axes = [axes, None]

    d = relative_mass.T.copy()
    d["specie"] = d.index
    stations = set(PMF_profile.index.get_level_values("station").unique())
    colors = list(TABLEAU_COLORS.values())
    ntypo = len(site_typologie)
    step = 0.1

    for ax, species in zip(axes, xtick_list):
        dd = d.reindex(species).melt(id_vars=["specie"])
        # Zeros cannot be drawn on a log axis.
        dd.replace({0: np.nan}, inplace=True)
        ax.set_yscale("log")
        sns.boxplot(data=dd, x="specie", y="value", ax=ax, color="white",
                    showcaps=False, showmeans=False, meanprops={"marker": "d"})
        for t, (typo, sites) in enumerate(site_typologie.items()):
            marker = _typology_marker(typo)
            # Spread typologies side by side around each species' x position.
            offset = -ntypo * step / 2 + step / 2 + t * step
            present = [site for site in sites if site in stations]
            for j, site in enumerate(present):
                ax.scatter(np.arange(0, len(species)) + offset,
                           d.loc[species, site],
                           marker=marker, color=colors[j % len(colors)], alpha=0.8)

    if naxe == 1:
        axes[0].set_ylim([1e-5, 2])
    else:
        axes[0].set_ylim([1e-3, 1])
        axes[1].set_ylim([1e-5, 1e-2])

    for ax in axes:
        if ax is not None:
            ax.set_ylabel("µg/µg of PM$_{10}$")
            ax.set_xlabel("")
            for tick in ax.get_xticklabels():
                tick.set_rotation(90)

    # Custom legend: for each typology with at least one plotted station, a
    # bare-string entry (presumably rendered as a heading by LegendTitle via
    # handler_map) followed by one marker entry per station.
    labels = []
    artists = []
    for typo, sites in site_typologie.items():
        marker = _typology_marker(typo)
        present = [site for site in sites if site in stations]
        if not present:
            continue
        artists.append(typo)
        labels.append("")
        for j, site in enumerate(present):
            artists.append(mlines.Line2D([], [], ls='', marker=marker,
                                         color=colors[j % len(colors)], label=site))
            labels.append(site)
    axes[0].legend(artists, labels, bbox_to_anchor=(1.1, 1),
                   handler_map={str: LegendTitle()}, frameon=False)

    fig.suptitle(source)
    plt.subplots_adjust(top=0.925, bottom=0.07, left=0.053, right=0.899,
                        hspace=0.489, wspace=0.2)
def plot_msh(file, tipo, mostrar_nodos=False, mostrar_num_nodo=False, mostrar_num_elem=False):
    '''
    Plot the finite-element mesh stored in the file "file".

    Arguments:
    - file: str. A .msh file exported by GMSH.
    - tipo: str. 'shell' (3D plot) or '2D'.
    - mostrar_nodos: bool. Whether the nodes are drawn.
    - mostrar_num_nodo: bool. Whether each node is labelled with its
      (1-based) number.
    - mostrar_num_elem: bool. Whether each finite element is labelled with
      its (1-based) number at its centroid.

    Edges of each element are coloured by its material id (first column of
    the connectivity read by LaG_from_msh).

    Raises:
    - ValueError: if "tipo" is not 'shell' or '2D', or if the elements have
      an unsupported number of nodes.
    '''
    if tipo not in ('shell', '2D'):
        raise ValueError('El argumento "tipo" introducido no es válido')

    # Finite-element type, keyed by the number of nodes per element.
    tipos_ef = {3: 'T3', 4: 'Q4', 6: 'T6', 8: 'Q8', 9: 'Q9', 10: 'T10'}
    # Local node order that walks each element's boundary and closes the
    # loop back on the first node (mid-side nodes between corners).
    contorno = {
        'T3': [0, 1, 2, 0],
        'Q4': [0, 1, 2, 3, 0],
        'T6': [0, 3, 1, 4, 2, 5, 0],
        'Q8': [0, 4, 1, 5, 2, 6, 3, 7, 0],
        'Q9': [0, 4, 1, 5, 2, 6, 3, 7, 0],
        'T10': [0, 3, 4, 1, 5, 6, 2, 7, 8, 0],
    }

    LaG_mat = LaG_from_msh(file)
    mat = LaG_mat[:, 0]
    LaG = LaG_mat[:, 1:]
    nef = LaG.shape[0]
    nodos_por_ef = LaG.shape[1]
    if nodos_por_ef not in tipos_ef:
        raise ValueError(f'Tipo de elemento finito no soportado ({nodos_por_ef} nodos por elemento)')
    elem = tipos_ef[nodos_por_ef]

    dim = 3 if tipo == 'shell' else 2
    xnod = xnod_from_msh(file, dim)
    nno = xnod.shape[0]
    colores = list(BASE_COLORS.values()) + list(TABLEAU_COLORS.values())

    fig = plt.figure(figsize=(12, 12))
    # plt.gca(projection=...) was removed in Matplotlib 3.6; add_subplot is
    # the supported way to create (3D) axes on a new figure.
    ax = fig.add_subplot(projection='3d') if dim == 3 else fig.add_subplot()

    for e in range(nef):
        # Cycle colours so material ids beyond the palette size still work.
        color = colores[int(mat[e]) % len(colores)]
        nodos = LaG[e, contorno[elem]]
        # Unpacking the coordinate columns gives (X, Y) or (X, Y, Z).
        ax.plot(*xnod[nodos].T, '-', lw=0.8, c=color)

        if mostrar_nodos:
            ax.plot(*xnod[LaG[e]].T, 'ko', ms=5, mfc='r')

        if mostrar_num_elem:
            # Element centroid = mean of its node coordinates.
            cg = np.mean(xnod[LaG[e]], axis=0)
            ax.text(*cg, f'{e+1}', horizontalalignment='center',
                    verticalalignment='center', color=color)

    if mostrar_num_nodo:
        for i in range(nno):
            ax.text(*xnod[i], f'{i+1}', color='r')

    ax.set_xlabel('x')
    ax.set_ylabel('y')
    if dim == 3:
        fig.suptitle(f'Malla de EF Shell ({elem})', fontsize='x-large')
        ax.set_zlabel('z')
        ax.view_init(45, 45)
    else:
        fig.suptitle(f'Malla de EF 2D ({elem})', fontsize='x-large')
        ax.set_aspect('equal', adjustable='box')
def plot_coeff_all_boxplot(sto, list_OPtype):
    """
    Plot, for each OP type, a boxplot of the intrinsic coefficient of every
    source across stations, overlaid with one marker per station.

    sto: dict of Station object. Each station provides ``name`` and an
        ``OPi`` DataFrame indexed by source with one column per OP type.
    list_OPtype: list of OP to plot, one subplot row each.

    Marker shape comes from the station's typology (``get_typologie``);
    colour is the station's position within its typology, cycling through
    the Tableau palette.
    """
    list_station = list(sto.keys())
    f, axes = plt.subplots(nrows=len(list_OPtype), ncols=1, sharex=True,
                           figsize=(17, len(list_OPtype) * 5))

    # Sorted union of every source seen at any station.
    sources = sorted(set().union(*(s.OPi.index for s in sto.values())))

    site_typologie, marker_typologie = get_typologie(sto.keys())
    colors = list(TABLEAU_COLORS.values())
    ntypo = len(site_typologie)
    step = 0.20

    for i, OPtype in enumerate(list_OPtype):
        op_coeff = pd.DataFrame(columns=list_station, index=sources)
        for s in sto.values():
            if OPtype in s.OPi.columns:
                op_coeff[s.name] = s.OPi[OPtype]
            else:
                op_coeff[s.name] = np.nan
        op_coeff = op_coeff.dropna(axis=1, how="all")

        df = op_coeff.copy()
        # Hack: replace missing values by an off-axis sentinel so every
        # station still produces a scatter artist (and thus a legend entry).
        df[df.isnull()] = -999

        ax = axes[i] if isinstance(axes, np.ndarray) else axes
        sns.boxplot(data=op_coeff.T, color="white", ax=ax)

        for t, typo in enumerate(site_typologie.keys()):
            marker = marker_typologie[typo]
            # Spread typologies side by side around each source's x position.
            x_offset = -ntypo * step / 2 + step / 2 + t * step
            for c, site in enumerate(site_typologie[typo]):
                if site not in df.columns:
                    continue
                color = colors[c % len(colors)]
                for j, s in enumerate(sources):
                    ax.scatter(j + x_offset, df.loc[s, site],
                               marker=marker, color=color, s=100, alpha=0.7,
                               zorder=10, label=site if j == 0 else '')
        # Per-axes legend is suppressed; a single legend is built at the end.
        ax.legend("")

        ymax = df.max().max()
        ax.set_ylim(bottom=-0.1 * ymax, top=ymax * 1.1)
        ax.set_xlabel("")
        ax.set_ylabel(
            "Intrinsic {OPtype}\nnmol/min/µg".format(OPtype=OPtype[:-1]))
        # Show source names with spaces instead of underscores, rotated.
        ax.set_xticklabels(
            [tick.get_text().replace("_", " ") for tick in ax.get_xticklabels()],
            rotation=90)

    plt.subplots_adjust(top=0.95, bottom=0.16, left=0.12, right=0.85)
    if isinstance(axes, np.ndarray):
        f.legend(*ax.get_legend_handles_labels(), loc="center right")
    else:
        ax.legend(loc="center right",
                  bbox_to_anchor=(1. + len(list_OPtype) * 0.1, 0.5))
def barplot(df: pd.DataFrame,
            bar_width: float = 0.3,
            space_width: float = 0.3,
            height: float = None,
            dpi: int = 200,
            min_std: float = 0,
            min_value: float = None,
            max_value: float = None,
            show_legend: bool = True,
            show_title: bool = True,
            legend_position: str = "best",
            data_label: str = None,
            title: str = None,
            path: str = None,
            colors: Dict[str, str] = None,
            alphas: Dict[str, float] = None,
            facecolors: Dict[str, str] = None,
            orientation: str = "vertical",
            subplots: bool = False,
            plots_per_row: Union[int, str] = "auto",
            minor_rotation: float = 0,
            major_rotation: float = 0,
            unique_minor_labels: bool = False,
            unique_major_labels: bool = True,
            unique_data_label: bool = True,
            auto_normalize_metrics: bool = True,
            placeholder: bool = False,
            scale: str = "linear",
            custom_defaults: Dict[str, List[str]] = None,
            sort_subplots: Callable[[List], List] = None,
            sort_bars: Callable[[pd.DataFrame], pd.DataFrame] = None,
            letter: str = None) -> Tuple[Figure, Axes]:
    """Plot barplot corresponding to given dataframe, containing y value and optionally std.

    Parameters
    ----------
    df: pd.DataFrame,
        Dataframe from which to extract data for plotting barplot.
    bar_width: float = 0.3,
        Width of the bar of the barplot.
    space_width: float = 0.3,
        Width of the empty space used between bars; forwarded to the layout,
        drawing and axis-limit helpers.
    height: float = None,
        Height of the barplot. By default golden ratio of the width.
    dpi: int = 200,
        DPI for plotting the barplots.
    min_std: float = 0,
        Minimum standard deviation for showing error bars.
    min_value: float = None,
        Minimum value for the barplot. By default it is computed from the data
        (never above zero).
    max_value: float = None,
        Maximum value for the barplot. By default it is computed from the data.
    show_legend: bool = True,
        Whether to show the legend. If the legend is hidden,
        the bar ticks are shown instead.
    show_title: bool = True,
        Whether to show the barplot title.
    legend_position: str = "best",
        Legend position, by default "best".
    data_label: str = None,
        Barplot's data_label.
        Use None for not showing any data_label (default).
    title: str = None,
        Barplot's title.
        Use None for not showing any title (default).
    path: str = None,
        Path where to save the barplot.
        Use None for not saving it (default).
    colors: Dict[str, str] = None,
        Dict of colors to be used for innermost index of dataframe.
        By default None, using the Tableau colors from matplotlib followed
        by the CSS4 named colors.
    alphas: Dict[str, float] = None,
        Dict of alphas to be used for innermost index of dataframe.
        By default None, using an alpha of 0.9 for every bar.
    facecolors: Dict[str, str] = None,
        Dict mapping each value of the outermost index to an axes face
        color. By default None, mapping every value to "white".
    orientation: str = "vertical",
        Orientation of the bars.
        Can either be "vertical" or "horizontal".
    subplots: bool = False,
        Whether to split the top indexing layer into multiple subplots.
    plots_per_row: Union[int, str] = "auto",
        If subplots is True, specifies the number of plots for row.
        If "auto" is used, for vertical the default is 2 plots per row,
        while for horizontal the default is 4 plots per row.
    minor_rotation: float = 0,
        Rotation for the minor ticks of the bars.
    major_rotation: float = 0,
        Rotation for the major ticks of the bars.
    unique_minor_labels: bool = False,
        Avoid replicating minor labels on the same axis in multiple subplots settings.
    unique_major_labels: bool = True,
        Avoid replicating major labels on the same axis in multiple subplots settings.
    unique_data_label: bool = True,
        Avoid replication of data axis label when using subplots.
    auto_normalize_metrics: bool = True,
        Whether to apply automatic normalization
        to the metrics that are recognized to be between
        zero and one. For example AUROC, AUPRC or accuracy.
    placeholder: bool = False,
        Whether to add a text on top of the barplots to show
        the word "placeholder". Useful when generating placeholder data.
    scale: str = "linear",
        Scale to use for the barplots.
        Can either be "linear" or "log".
    custom_defaults: Dict[str, List[str]],
        Dictionary to normalize labels.
    sort_subplots: Callable[[List], List] = None,
        Callable applied to the outermost index values to reorder them
        (and therefore the subplots) before plotting.
    sort_bars: Callable[[pd.DataFrame], pd.DataFrame] = None,
        Callable applied to each (sub-)dataframe before its bars are drawn.
    letter: str = None,
        Letter to show on the top left of the figure.
        This is sometimes necessary on papers.
        By default it is None, that is no letter to be shown.

    Raises
    ------
    ValueError:
        If the given orientation is neither "vertical" nor "horizontal".
    ValueError:
        If the given plots_per_row is neither "auto" nor a positive integer.
    ValueError:
        If subplots is True and less than a single index level is provided.

    Returns
    -------
    Tuple containing Figure and Axes of created barplot.
    """
    if orientation not in ("vertical", "horizontal"):
        raise ValueError(
            "Given orientation \"{orientation}\" is not supported.".format(
                orientation=orientation))

    # plots_per_row must be the string "auto" or a strictly positive integer.
    if not (plots_per_row == "auto" or
            (isinstance(plots_per_row, int) and plots_per_row >= 1)):
        raise ValueError(
            "Given plots_per_row \"{plots_per_row}\" is not 'auto' or a positive integer."
            .format(plots_per_row=plots_per_row))

    vertical = orientation == "vertical"

    levels = get_levels(df)

    # Index levels left for tick labels once the legend level (innermost)
    # and the subplot level (outermost) are consumed.
    expected_levels = len(levels) - int(show_legend) - int(subplots)

    if len(levels) <= 1 and subplots:
        raise ValueError(
            "Unable to split plots with only a single index level.")

    if plots_per_row == "auto":
        if subplots:
            plots_per_row = min(len(levels[0]), 2 if vertical else 4)
    else:
        plots_per_row = min(plots_per_row, len(levels[0]))

    if colors is None:
        colors = dict(
            zip(levels[-1],
                list(TABLEAU_COLORS.keys()) + list(CSS4_COLORS.keys())))

    if alphas is None:
        alphas = dict(zip(levels[-1], (0.9, ) * len(levels[-1])))

    if facecolors is None:
        facecolors = dict(zip(levels[0], ("white", ) * len(levels[0])))

    sorted_level = levels[0]

    if sort_subplots is not None:
        sorted_level = sort_subplots(sorted_level)

    titles = sorted_level if subplots else ("", )

    figure, axes = get_axes(df, bar_width, space_width, height, dpi, title,
                            data_label, vertical, subplots, titles,
                            plots_per_row, custom_defaults, expected_levels,
                            scale, facecolors, show_title)

    for i, (index, ax) in enumerate(zip(titles, axes)):
        sub_df = df.loc[index] if subplots else df

        if sort_bars is not None:
            sub_df = sort_bars(sub_df)

        plot_bars(ax,
                  sub_df,
                  bar_width,
                  space_width,
                  alphas,
                  colors,
                  index,
                  vertical=vertical,
                  min_std=min_std)

        is_not_first_ax = subplots and (
            (not vertical and i % plots_per_row) or
            (vertical and i < len(axes) - plots_per_row))

        is_not_first_vertical_ax = subplots and (
            (vertical and i % plots_per_row) or
            (not vertical and i < len(axes) - plots_per_row))

        plot_bar_labels(ax, figure, sub_df, vertical, expected_levels,
                        bar_width, space_width, minor_rotation,
                        major_rotation, unique_minor_labels and is_not_first_ax,
                        unique_major_labels and is_not_first_ax,
                        unique_data_label and is_not_first_vertical_ax,
                        custom_defaults)

        if show_legend:
            remove_duplicated_legend_labels(ax, legend_position,
                                            custom_defaults)

        # Axis limits: 1% headroom on the data extent, lower bound never
        # above zero, unless explicitly overridden.
        max_length, min_length = get_max_bar_lenght(sub_df, bar_width,
                                                    space_width)
        max_length *= 1.01
        min_length *= 1.01
        min_length = min(min_length, 0)

        if min_value is not None:
            min_length = min_value

        if auto_normalize_metrics and (is_normalized_metric(df.columns[0])
                                       or is_normalized_metric(title)):
            max_length = max(max_length, 1.01)

        if max_value is not None:
            max_length = max_value

        if placeholder:
            ax.text(0.5,
                    0.5,
                    "PLACEHOLDER",
                    fontsize=30,
                    alpha=0.75,
                    color="red",
                    rotation=8,
                    horizontalalignment='center',
                    verticalalignment='center',
                    transform=ax.transAxes)

        if vertical:
            ax.set_ylim(min_length, max_length)
        else:
            ax.set_xlim(min_length, max_length)

    if letter:
        figure.text(0.01,
                    0.9,
                    letter,
                    horizontalalignment='center',
                    verticalalignment='center',
                    weight='bold',
                    fontsize=15)

    figure.tight_layout()

    if path is not None:
        save_picture(path, figure)

    return figure, axes
def create_plots(self, axes, wheel_axes=None, trial_events=None, color_map=None, linestyle=None):
    """
    Plots the data for bnc1 (frame2ttl) and bnc2 (sound), plus the Bpod TTLs
    when available, and overlays trial events as vertical lines.

    :param axes: An axes handle on which to plot the TTL events
    :param wheel_axes: An axes handle on which to plot the wheel trace; the
        same trial events are overlaid on it
    :param trial_events: A list of Bpod trial events to plot, e.g. ['stimFreeze_times'],
        if None, go cue, feedback and stimulus events are plotted
    :param color_map: Colors to cycle through for the events, default is the
        tableau color names
    :param linestyle: A sequence of line styles, one per event; default is a
        random choice per event. Events beyond its length are not plotted.
    :return: None
    """
    color_map = color_map or TABLEAU_COLORS.keys()
    if trial_events is None:
        # Default trial events to plot as vertical lines
        trial_events = [
            'goCue_times',
            'goCueTrigger_times',
            'feedback_times',
            'stimFreeze_times',
            'stimOff_times',
            'stimOn_times'
        ]

    # 'ymax' is filled in below depending on how many TTL rows are drawn.
    plot_args = {
        'ymin': 0,
        'linewidth': 2,
        'ax': axes
    }

    bnc1 = self.extractor.frame_ttls
    bnc2 = self.extractor.audio_ttls
    trial_data = self.extractor.data

    # Each TTL channel is drawn as a square wave on its own row (y = 1, 2, 3).
    plots.squares(bnc1['times'], bnc1['polarities'] * 0.4 + 1, ax=axes, color='k')
    plots.squares(bnc2['times'], bnc2['polarities'] * 0.4 + 2, ax=axes, color='k')
    linestyle = linestyle or random.choices(('-', '--', '-.', ':'), k=len(trial_events))

    if self.extractor.bpod_ttls is not None:
        bpttls = self.extractor.bpod_ttls
        plots.squares(bpttls['times'], bpttls['polarities'] * 0.4 + 3, ax=axes, color='k')
        plot_args['ymax'] = 4
        ylabels = ['', 'frame2ttl', 'sound', 'bpod', '']
    else:
        plot_args['ymax'] = 3
        ylabels = ['', 'frame2ttl', 'sound', '']

    for event, color, style in zip(trial_events, cycle(color_map), linestyle):
        plots.vertical_lines(trial_data[event], label=event, color=color, linestyle=style, **plot_args)

    axes.legend(loc='upper left', fontsize='xx-small', bbox_to_anchor=(1, 0.5))
    # Set the tick positions before their labels so the labels pair with a
    # FixedLocator of matching length.
    axes.set_yticks(list(range(plot_args['ymax'] + 1)))
    axes.set_yticklabels(ylabels)
    axes.set_ylim([0, plot_args['ymax']])

    if wheel_axes is not None:
        wheel_plot_args = {
            'ax': wheel_axes,
            'ymin': self.wheel_data['re_pos'].min(),
            'ymax': self.wheel_data['re_pos'].max()}
        plot_args = {**plot_args, **wheel_plot_args}

        wheel_axes.plot(self.wheel_data['re_ts'], self.wheel_data['re_pos'], 'k-x')
        for event, color, style in zip(trial_events, cycle(color_map), linestyle):
            plots.vertical_lines(trial_data[event],
                                 label=event, color=color, linestyle=style, **plot_args)