Esempio n. 1
0
def distinct_color_map(size):
    """
    Generate a list of unique colour names for matplotlib line/scatter plots.

    The Tableau palette is used when it holds at least ``size`` entries;
    otherwise the XKCD palette is used instead.

    :param size: number of distinct colours the caller needs.
    :return: list with every colour name of the selected palette. The
        list is not truncated to ``size``.
    """
    palette = XKCD_COLORS if size > len(TABLEAU_COLORS) else TABLEAU_COLORS
    # Materialise the names: a dict key view can be neither indexed nor
    # sliced, which is what callers of a "list of colours" expect to do.
    return list(palette)
def plot_embedding(
    graph: EnsmallenGraph,
    tsne_embedding: np.ndarray,
    k: int = 10,
    axes: Axes = None
):
    """Scatter-plot an embedding of the graph nodes, coloured by node type.

    Parameters
    ----------
    graph: EnsmallenGraph,
        Graph whose nodes are plotted. When it has no node type mapping
        every node is drawn with the same placeholder type.
    tsne_embedding: np.ndarray,
        Embedding with one row per node. Its transpose is unpacked as the
        scatter coordinates (assumes two columns -- TODO confirm).
    k: int = 10,
        Forwarded to ``graph.get_top_k_nodes_by_node_type``.
    axes: Axes = None,
        Axes to draw on. A new 5x5 figure is created when None.

    Returns
    -------
    The axes the scatter plot was drawn on.
    """
    if axes is None:
        _, axes = plt.subplots(figsize=(5, 5))

    has_node_types = graph.node_types_mapping is not None
    if has_node_types:
        # Keep only the nodes selected by the graph for the top k types.
        nodes, node_types = graph.get_top_k_nodes_by_node_type(k)
        tsne_embedding = tsne_embedding[nodes]
        type_names = np.array(graph.node_types_reverse_mapping)
        common_node_types_names = list(type_names[np.unique(node_types)])
    else:
        node_types = np.zeros(graph.get_nodes_number(), dtype=np.uint8)
        common_node_types_names = ["No node type provided"]

    # One Tableau colour per plotted node type.
    palette = list(TABLEAU_COLORS.keys())
    colors = palette[:len(common_node_types_names)]
    points = axes.scatter(
        *tsne_embedding.T,
        s=0.25,
        c=node_types,
        cmap=ListedColormap(colors)
    )
    legend_handles = points.legend_elements()[0]
    axes.legend(handles=legend_handles, labels=common_node_types_names)
    return axes
Esempio n. 3
0
def show_measurement_1d(measurements_or_func,
                        figure=None,
                        throttling=False,
                        **kwargs):
    """Draw one line per 1D measurement and optionally return a refresh callback.

    Parameters
    ----------
    measurements_or_func:
        Either an iterable of measurements, or a zero-argument callable
        returning such an iterable. Each measurement must expose ``array``
        and ``calibrations[0]`` with ``offset`` and ``sampling`` attributes.
        At most ``len(TABLEAU_COLORS)`` measurements are drawn, because the
        measurements are zipped with that palette.
    figure:
        Figure to draw into. When None a new 300px x 250px figure is
        created through ``plt.figure`` (NOTE(review): the ``fig_margin``
        and ``colors`` keywords suggest a bqplot-style ``plt`` -- confirm).
    throttling:
        Forwarded to the ``throttle`` decorator wrapping the callback.
    **kwargs:
        Forwarded to ``plt.plot`` for every line.

    Returns
    -------
    ``figure`` when measurements were given directly, or the tuple
    ``(figure, callback)`` when a callable was given. Calling ``callback``
    re-evaluates the callable and updates the x/y data of every line.
    """
    if figure is None:
        figure = plt.figure(fig_margin={
            'top': 0,
            'bottom': 50,
            'left': 50,
            'right': 0
        })

        figure.layout.height = '250px'
        figure.layout.width = '300px'

    # Test for callability instead of calling and catching TypeError, so a
    # TypeError raised *inside* the user function is no longer swallowed.
    return_callback = callable(measurements_or_func)
    if return_callback:
        measurements = measurements_or_func()
    else:
        measurements = measurements_or_func

    def coordinates(measurement):
        """Return the x values spanned by ``measurement`` along its first axis."""
        calibration = measurement.calibrations[0]
        array = measurement.array
        return np.linspace(calibration.offset,
                           calibration.offset + len(array) * calibration.sampling,
                           len(array))

    lines = []
    for measurement, color in zip(measurements, TABLEAU_COLORS.values()):
        line = plt.plot(coordinates(measurement),
                        measurement.array,
                        colors=[color],
                        **kwargs)
        figure.axes[0].label = format_label(measurement.calibrations[0])
        lines.append(line)

    if not return_callback:
        return figure

    @throttle(throttling)
    def callback(*args, **kwargs):
        for line, measurement in zip(lines, measurements_or_func()):
            # Recompute the grid from the *current* measurement; reusing the
            # calibration/array of the initial loop would give stale x values.
            line.x = coordinates(measurement)
            line.y = measurement.array

    return figure, callback
Esempio n. 4
0
def plot_data(data_table: pd.DataFrame,
              columns: List[str],
              x_column: str,
              title: str,
              colors: List = None,
              shareX: bool = True,
              shareY: bool = False):
    """Plot each requested column against ``x_column`` in its own stacked subplot.

    :param data_table: table holding ``x_column`` and every entry of ``columns``.
    :param columns: names of the columns to plot, one subplot row each.
    :param x_column: name of the column used for the x axis.
    :param title: figure title.
    :param colors: one colour per column. When None, or shorter than
        ``columns``, the Tableau palette is used instead (cycled if there
        are more columns than colours).
    :param shareX: forwarded to ``plt.subplots`` as ``sharex``.
    :param shareY: forwarded to ``plt.subplots`` as ``sharey``.
    :return: None. The figure is displayed with ``plt.show()``.
    """
    # squeeze=False always yields a 2D array of axes; without it a single
    # column produces a bare Axes object that cannot be indexed.
    fig, axs = plt.subplots(nrows=len(columns),
                            sharex=shareX,
                            sharey=shareY,
                            squeeze=False)
    axs = axs[:, 0]
    fig.suptitle(title, fontsize=16)
    if colors is None or len(colors) < len(columns):
        colors = list(TABLEAU_COLORS.values())
    x_values = data_table[x_column]
    x_limits = (0, x_values.max())
    for index, (ax, col) in enumerate(zip(axs, columns)):
        # Wrap around the palette so more columns than colours do not
        # raise an IndexError.
        ax.plot(x_values, data_table[col], color=colors[index % len(colors)])
        ax.set(ylabel=col, xlim=x_limits)
    plt.show()
Esempio n. 5
0
    def plot_train_data_separate_by_param(self,
                                          param_name,
                                          filter_dict=None,
                                          show=True):
        """
        Plot the train energy of every selected config, coloured by the value
        that config holds for ``param_name``.

        :param param_name: key looked up in each config; configs sharing the
            same value share a colour.
        :param filter_dict: passed to ``self.filter_config`` to decide which
            configs are plotted. Defaults to an empty dict.
        :param show: when True, ``plt.show()`` is called at the end.
        :return: None.
        """
        # A fresh dict per call: a mutable default would be shared by all calls.
        if filter_dict is None:
            filter_dict = {}

        colorlist = list(COLORS.keys())
        n_colors = len(colorlist)
        param_values = []
        for c_idx, config in self.config_dict.items():

            if not self.filter_config(c_idx, filter_dict):
                continue

            value = config[param_name]
            if value not in param_values:
                param_values.append(value)

            # Wrap around the palette so that more parameter values than
            # colours do not raise an IndexError.
            color = colorlist[param_values.index(value) % n_colors]
            path = op.join(config['path'], 'train_data.hdf5')
            data_dict = utl.load_dict_h5py(path)
            plt.plot(data_dict['energy'], color=color)

        # One proxy line per parameter value, in first-seen order, using the
        # same colour assignment as the curves above.
        legend_handles = [
            mlines.Line2D([], [],
                          color=colorlist[position % n_colors],
                          label=f'{param_name}: {value}')
            for position, value in enumerate(param_values)
        ]

        plt.legend(legend_handles, param_values)

        if show:
            plt.show()
Esempio n. 6
0
    m = {i: (0, 0, 0) for i in range(k)}
    for (x, y), i in cluster.items():
        m[i] = (m[i][0]+x, m[i][1]+y, m[i][2]+1)
    return [(m[i][0]/m[i][2], m[i][1]/m[i][2])
            if m[i][2] else (float('inf'), float('inf')) for i in m]


def gen(n, s):
    """Return ``n`` pseudo-random 2D points, reproducibly seeded with ``s``.

    :param n: number of points to generate.
    :param s: seed given to the standard library generator.
    :return: list of ``n`` ``(x, y)`` tuples of floats drawn with
        ``random.random``.

    Note: this reseeds the module-level generator of ``random``.
    """
    # Import ``seed`` together with ``random`` so the generator being seeded
    # is guaranteed to be the one being sampled, whatever the module imports.
    from random import random, seed
    seed(s)
    return [(random(), random()) for _ in range(n)]


if __name__ == '__main__':
    # Demo: cluster random points with k-means++ initialisation, draw each
    # cluster in its own Tableau colour and save the picture.
    from matplotlib import pyplot as plt
    from matplotlib.colors import TABLEAU_COLORS
    n, k, s = 200, 5, 0
    pts = gen(n, s)
    # Plain k-means without the ++ initialisation would be:
    # k_means(pts, k, s=s)
    initial_clusters = k_means_pp_init_clusters(pts, k, s)
    clusters = k_means(pts, k, clusters=initial_clusters)

    colors = list(TABLEAU_COLORS.values())
    for cluster_id in range(k):
        members = [point for point, label in clusters.items()
                   if cluster_id == label]
        xs = [x for x, _ in members]
        ys = [y for _, y in members]
        plt.scatter(xs, ys, c=colors[cluster_id])

    current_axes = plt.gca()
    current_axes.set_aspect('equal')
    plt.tight_layout()
    plt.savefig('k_means_pp_ex1.png', bbox_inches='tight')
    plt.show()
Esempio n. 7
0
def plot_relativeMass(PMF_profile,
                      source="Biomass burning",
                      isRelativeMass=True,
                      totalVar="PM10",
                      naxe=1,
                      site_typologie=None):
    """Boxplot, per species, of one source profile across stations, with one
    marker per site overlaid.

    Parameters
    ----------
    PMF_profile : pd.DataFrame
        Profiles whose index has (at least) the levels "station" and
        "specie" and which has one column per source.
    source : str
        Column of `PMF_profile` to plot. Also used as the figure title.
    isRelativeMass : bool
        If True the values of `source` are plotted as they are, otherwise
        they are first converted with `to_relativeMass(..., totalVar=totalVar)`.
    totalVar : str
        Forwarded to `to_relativeMass` when `isRelativeMass` is False.
    naxe : int
        Number of stacked axes, 1 or 2. With 2 the species are split into
        two groups, one per axes.
    site_typologie : mapping, optional
        Maps a typology name to the list of its sites. When empty or None,
        `get_site_typology()` is used.

    Raises
    ------
    ValueError
        If `naxe` is neither 1 nor 2.
    """
    if naxe not in (1, 2):
        raise ValueError("naxe must be 1 or 2, got {!r}".format(naxe))

    if not site_typologie:
        site_typologie = get_site_typology()

    # Marker of each known typology; any other typology uses a diamond.
    typology_markers = {
        "Urban": "*",
        "Valley": "o",
        "Urban+Alps": "o",
        "Traffic": "p",
    }

    carboneous = ["OC*", "EC"]
    ions = ["Cl-", "NO3-", "SO42-", "Na+", "NH4+", "K+", "Mg2+", "Ca2+"]
    organics = [
        "MSA",
        "Polyols",
        "Levoglucosan",
        "Mannosan",
    ]
    metals = [
        "Al", "As", "Ba", "Cd", "Co", "Cr", "Cs", "Cu", "Fe", "La", "Mn", "Mo",
        "Ni", "Pb", "Rb", "Sb", "Se", "Sn", "Sr", "Ti", "V", "Zn"
    ]
    keep_index = ["PM10"] + carboneous + ions + organics + metals
    if naxe == 1:
        xtick_list = [keep_index]
    else:
        keep_index1 = carboneous + [
            "NO3-", "SO42-", "NH4+", "Na+", "K+", "Ca2+", "Mg2+", "Cl-",
            "Polyols", "Fe", "Al", "Levoglucosan", "Mannosan", "MSA"
        ]
        # Second axes: every remaining species except PM10, sorted by name.
        keep_index2 = sorted(
            set(keep_index) - set(keep_index1) - {"PM10", "Mg2+"})
        xtick_list = [keep_index1, keep_index2]

    # One row per station, one column per species.
    df_per_ug = pd.DataFrame(columns=keep_index)
    for station, df in PMF_profile.reset_index().groupby("station"):
        df = df.set_index("specie").drop("station", axis=1)
        if isRelativeMass:
            df_per_ug.loc[station, :] = df.loc[:, source]
        else:
            df_per_ug.loc[station, :] = to_relativeMass(df.loc[:, source],
                                                        totalVar=totalVar).T
    # DataFrame.convert_objects no longer exists in pandas >= 1.0;
    # to_numeric with errors="coerce" gives the same numeric conversion.
    df_per_ug = df_per_ug.apply(pd.to_numeric, errors="coerce")

    # FIGURE
    f, axes = plt.subplots(nrows=naxe, ncols=1, figsize=(12.67, 6.54))
    if naxe == 1:
        axes = [axes, None]

    d = df_per_ug.T.copy()
    d["specie"] = d.index

    # Loop invariants, computed once.
    stations = PMF_profile.index.get_level_values("station").unique()
    ntypo = len(site_typologie.keys())
    colors = list(TABLEAU_COLORS.values())
    step = 0.1

    for i, species in enumerate(xtick_list):
        dd = d.reindex(species)
        dd = dd.melt(id_vars=["specie"])
        # Zeros cannot be drawn on a log scale.
        dd.replace({0: np.nan}, inplace=True)
        axes[i].set_yscale("log")
        sns.boxplot(data=dd,
                    x="specie",
                    y="value",
                    ax=axes[i],
                    color="white",
                    showcaps=False,
                    showmeans=False,
                    meanprops={"marker": "d"})
        for t, typo in enumerate(site_typologie.keys()):
            marker = typology_markers.get(typo, "d")
            # Shift each typology sideways so its markers do not overlap
            # those of the other typologies.
            x_positions = (np.arange(0, len(species)) - ntypo * step / 2 +
                           step / 2 + t * step)
            j = 0
            for site in site_typologie[typo]:
                if site not in stations:
                    continue
                axes[i].scatter(x_positions,
                                d.loc[species, site],
                                marker=marker,
                                # Wrap around the 10-colour palette.
                                color=colors[j % len(colors)],
                                alpha=0.8)
                j += 1

    if naxe == 1:
        axes[0].set_ylim([1e-5, 2])
    else:
        axes[0].set_ylim([1e-3, 1])
        axes[1].set_ylim([1e-5, 1e-2])
    for ax in axes:
        if ax is None:
            continue
        ax.set_ylabel("µg/µg of PM$_{10}$")
        ax.set_xlabel("")
        for tick in ax.get_xticklabels():
            tick.set_rotation(90)

    # Custom legend: a title entry per typology followed by its sites, with
    # the same marker/colour assignment as the scatter plots above.
    labels = []
    artists = []
    for typo in site_typologie.keys():
        marker = typology_markers.get(typo, "d")
        noSiteYet = True
        j = 0
        for site in site_typologie[typo]:
            if site not in stations:
                continue
            if noSiteYet:
                # The bare string is rendered by the LegendTitle handler.
                artists.append(typo)
                labels.append("")
                noSiteYet = False
            artist = mlines.Line2D([], [],
                                   ls='',
                                   marker=marker,
                                   color=colors[j % len(colors)],
                                   label=site)
            artists.append(artist)
            labels.append(site)
            j += 1

    axes[0].legend(artists,
                   labels,
                   bbox_to_anchor=(1.1, 1),
                   handler_map={str: LegendTitle()},
                   frameon=False)

    f.suptitle(source)
    plt.subplots_adjust(top=0.925,
                        bottom=0.07,
                        left=0.053,
                        right=0.899,
                        hspace=0.489,
                        wspace=0.2)
Esempio n. 8
0
def plot_msh(file,
             tipo,
             mostrar_nodos=False,
             mostrar_num_nodo=False,
             mostrar_num_elem=False):
    ''' Plot the finite element mesh stored in the file "file".
        Arguments:
        - file: str. Must be a .msh file exported by GMSH.
        - tipo: str. 'shell' (3D plot) or '2D'.
        - mostrar_nodos: bool. Whether the nodes are drawn on the plot.
        - mostrar_num_nodo: bool. Whether the number of each node is
                            written on the plot.
        - mostrar_num_elem: bool. Whether the number of each finite
                            element is written on the plot.
        Raises:
        - ValueError: if "tipo" is not 'shell' or '2D', or if the number of
                      nodes per element is not 3, 4, 6, 8, 9 or 10.
    '''
    if tipo == 'shell':
        dims = 3
    elif tipo == '2D':
        dims = 2
    else:
        raise ValueError('El argumento "tipo" introducido no es válido')

    LaG_mat = LaG_from_msh(file)
    mat = LaG_mat[:, 0]   # first column: used to pick the colour of each element
    LaG = LaG_mat[:, 1:]  # remaining columns: node indices of each element
    nef = LaG.shape[0]

    # Nodes per element -> (element name, local node order that walks the
    # element boundary and returns to the first node).
    contornos = {
        3: ('T3', [0, 1, 2, 0]),
        4: ('Q4', [0, 1, 2, 3, 0]),
        6: ('T6', [0, 3, 1, 4, 2, 5, 0]),
        8: ('Q8', [0, 4, 1, 5, 2, 6, 3, 7, 0]),
        9: ('Q9', [0, 4, 1, 5, 2, 6, 3, 7, 0]),
        10: ('T10', [0, 3, 4, 1, 5, 6, 2, 7, 8, 0]),
    }
    nodos_por_elem = LaG.shape[1]
    if nodos_por_elem not in contornos:
        # Fail with a clear message instead of a NameError further down.
        raise ValueError('Tipo de elemento finito no soportado: '
                         f'{nodos_por_elem} nodos por elemento')
    elem, contorno = contornos[nodos_por_elem]

    xnod = xnod_from_msh(file, dims)
    nno = xnod.shape[0]
    colores = list(BASE_COLORS.values()) + list(TABLEAU_COLORS.values())

    fig = plt.figure(figsize=(12, 12))
    if dims == 3:
        # plt.gca(projection='3d') is no longer accepted by recent
        # matplotlib versions; create the 3D axes on the figure instead.
        ax = fig.add_subplot(111, projection='3d')
    else:
        ax = fig.add_subplot(111)

    for e in range(nef):
        # Wrap around the palette so a large index does not overflow it.
        color = colores[mat[e] % len(colores)]

        # Element outline. Unpacking the coordinate columns feeds (x, y) to
        # the 2D axes and (x, y, z) to the 3D axes with the same call.
        ax.plot(*xnod[LaG[e, contorno], :dims].T, '-', lw=0.8, c=color)
        if mostrar_nodos:
            ax.plot(*xnod[LaG[e], :dims].T, 'ko', ms=5, mfc='r')
        if mostrar_num_elem:
            # The element number is written at the mean position of its nodes.
            cg = np.mean(xnod[LaG[e]], axis=0)
            ax.text(*cg[:dims],
                    f'{e+1}',
                    horizontalalignment='center',
                    verticalalignment='center',
                    color=color)

    if mostrar_num_nodo:
        for i in range(nno):
            ax.text(*xnod[i, :dims], f'{i+1}', color='r')

    ax.set_xlabel('x')
    ax.set_ylabel('y')
    if dims == 3:
        fig.suptitle(f'Malla de EF Shell ({elem})', fontsize='x-large')
        ax.set_zlabel('z')
        ax.view_init(45, 45)
    else:
        fig.suptitle(f'Malla de EF 2D ({elem})', fontsize='x-large')
        ax.set_aspect('equal', adjustable='box')
Esempio n. 9
0
def plot_coeff_all_boxplot(sto, list_OPtype):
    """
    Plot, for each OP type, a boxplot of the coefficients per source with
    one marker per site overlaid.

    sto: dict of Station object. Each station must expose `name` and `OPi`,
         a DataFrame indexed by source with one column per OP type.
    list_OPtype: list of OP to plot, one subplot row each.

    """
    list_station = list(sto.keys())
    f, axes = plt.subplots(nrows=len(list_OPtype),
                           ncols=1,
                           sharex=True,
                           figsize=(17, len(list_OPtype) * 5))

    # Sorted union of the sources of every station.
    sources = []
    for s in sto.values():
        sources = sources + [a for a in s.OPi.index if a not in sources]
    sources.sort()
    site_typologie, marker_typologie = get_typologie(sto.keys())

    # Loop invariants.
    step = 0.20
    ntypo = len(site_typologie.keys())
    colors = list(TABLEAU_COLORS.values())

    coeff = dict()
    for i, OPtype in enumerate(list_OPtype):
        coeff[OPtype] = pd.DataFrame(columns=list_station, index=sources)
        for s in sto.values():
            if OPtype in s.OPi.columns:
                coeff[OPtype][s.name] = s.OPi[OPtype]
            else:
                coeff[OPtype][s.name] = np.nan

        coeff[OPtype] = coeff[OPtype].dropna(axis=1, how="all")
        df = coeff[OPtype].copy()
        # Missing values are drawn at -999, below the visible y range, so
        # that every site still gets a scatter (and thus a legend entry).
        df[df.isnull()] = -999

        if isinstance(axes, np.ndarray):
            ax = axes[i]
        else:
            ax = axes
        sns.boxplot(data=coeff[OPtype].T, color="white", ax=ax)

        drew_points = False
        for t, typo in enumerate(site_typologie.keys()):
            for c, site in enumerate(site_typologie[typo]):
                if site not in df.columns:
                    continue
                marker = marker_typologie[typo]
                # Wrap around the 10-colour palette so typologies with
                # more sites do not raise an IndexError.
                color = colors[c % len(colors)]
                for j, s in enumerate(sources):
                    ax.scatter(j - ntypo * step / 2 + step / 2 + t * step,
                               df.loc[s, site],
                               marker=marker,
                               color=color,
                               s=100,
                               alpha=0.7,
                               zorder=10,
                               # Label only the first point of a site so it
                               # appears once in the legend.
                               label=site if j == 0 else '')
                    drew_points = True
        if drew_points:
            # Same end state as rebuilding the legend after every point,
            # without the repeated work: an empty legend on this axes.
            ax.legend("")
        ymax = df.max().max()
        ax.set_ylim(bottom=-0.1 * ymax, top=ymax * 1.1)

        ax.set_xlabel("")
        ax.set_ylabel(
            "Intrinsic {OPtype}\nnmol/min/µg".format(OPtype=OPtype[:-1]))

    # Replace "_" by " " in the tick labels of the last axes and rotate them.
    tick_texts = [
        tick.get_text().replace("_", " ") for tick in ax.get_xticklabels()
    ]
    ax.set_xticklabels(tick_texts, rotation=90)
    plt.subplots_adjust(top=0.95, bottom=0.16, left=0.12, right=0.85)
    if isinstance(axes, np.ndarray):
        f.legend(*ax.get_legend_handles_labels(), loc="center right")
    else:
        ax.legend(loc="center right",
                  bbox_to_anchor=(1. + len(list_OPtype) * 0.1, 0.5))
Esempio n. 10
0
def barplot(df: pd.DataFrame,
            bar_width: float = 0.3,
            space_width: float = 0.3,
            height: float = None,
            dpi: int = 200,
            min_std: float = 0,
            min_value: float = None,
            max_value: float = None,
            show_legend: bool = True,
            show_title: bool = True,
            legend_position: str = "best",
            data_label: str = None,
            title: str = None,
            path: str = None,
            colors: Dict[str, str] = None,
            alphas: Dict[str, float] = None,
            facecolors: Dict[str, str] = None,
            orientation: str = "vertical",
            subplots: bool = False,
            plots_per_row: Union[int, str] = "auto",
            minor_rotation: float = 0,
            major_rotation: float = 0,
            unique_minor_labels: bool = False,
            unique_major_labels: bool = True,
            unique_data_label: bool = True,
            auto_normalize_metrics: bool = True,
            placeholder: bool = False,
            scale: str = "linear",
            custom_defaults: Dict[str, List[str]] = None,
            sort_subplots: Callable[[List], List] = None,
            sort_bars: Callable[[pd.DataFrame], pd.DataFrame] = None,
            letter: str = None) -> Tuple[Figure, Axes]:
    """Plot barplot corresponding to given dataframe, containing y value and optionally std.

    Parameters
    ----------
    df: pd.DataFrame,
        Dataframe from which to extract data for plotting barplot.
    bar_width: float = 0.3,
        Width of the bar of the barplot.
    space_width: float = 0.3,
        Spacing value forwarded, together with bar_width, to the layout
        helpers (presumably the gap between groups of bars -- TODO confirm).
    height: float = None,
        Height of the barplot. By default golden ratio of the width.
    dpi: int = 200,
        DPI for plotting the barplots.
    min_std: float = 0,
        Minimum standard deviation for showing error bars.
    min_value: float = None,
        Minimum value for the barplot.
        When None, the smaller of 0 and the computed minimum bar length is used.
    max_value: float = None,
        Maximum value for the barplot.
        When None, the computed maximum bar length is used.
    show_legend: bool = True,
        Whether to show or not the legend.
        If legend is hidden, the bar ticks are shown alternatively.
    show_title: bool = True,
        Whether to show or not the barplot title.
    legend_position: str = "best",
        Legend position, by default "best".
    data_label: str = None,
        Barplot's data_label.
        Use None for not showing any data_label (default).
    title: str = None,
        Barplot's title.
        Use None for not showing any title (default).
    path: str = None,
        Path where to save the barplot.
        Use None for not saving it (default).
    colors: Dict[str, str] = None,
        Dict of colors to be used for innermost index of dataframe.
        By default None, pairing the innermost index values with the
        matplotlib tableau color names followed by the CSS4 color names.
    alphas: Dict[str, float] = None,
        Dict of alphas to be used for innermost index of dataframe.
        By default None, using an alpha of 0.9 for every value.
    facecolors: Dict[str, str] = None,
        Dict of face colors keyed by the values of the outermost index.
        By default None, using "white" for every value.
    orientation: str = "vertical",
        Orientation of the bars.
        Can either be "vertical" or "horizontal".
    subplots: bool = False,
        Whether to split the top indexing layer to multiple subplots.
    plots_per_row: Union[int, str] = "auto",
        If subplots is True, specifies the number of plots for row.
        If "auto" is used, for vertical the default is 2 plots per row,
        while for horizontal the default is 4 plots per row.
        The value never exceeds the number of outermost index values.
    minor_rotation: float = 0,
        Rotation for the minor ticks of the bars.
    major_rotation: float = 0,
        Rotation for the major ticks of the bars.
    unique_minor_labels: bool = False,
        Avoid replicating minor labels on the same axis in multiple subplots settings.
    unique_major_labels: bool = True,
        Avoid replicating major labels on the same axis in multiple subplots settings.
    unique_data_label: bool = True,
        Avoid replication of data axis label when using subplots.
    auto_normalize_metrics: bool = True,
        Whether to apply or not automatic normalization
        to the metrics that are recognized to be between
        zero and one. For example AUROC, AUPRC or accuracy.
    placeholder: bool = False,
        Whether to add a text on top of the barplots to show
        the word "placeholder". Useful when generating placeholder data.
    scale: str = "linear",
        Scale to use for the barplots.
        Can either be "linear" or "log".
    custom_defaults: Dict[str, List[str]],
        Dictionary to normalize labels.
    sort_subplots: Callable[[List], List] = None,
        Callable applied to the outermost index values to choose the
        order of the subplots. By default None, keeping the original order.
    sort_bars: Callable[[pd.DataFrame], pd.DataFrame] = None,
        Callable applied to the dataframe of each plot before its bars
        are drawn. By default None, leaving the dataframe as it is.
    letter: str = None,
        Letter to show on the top left of the figure.
        This is sometimes necessary on papers.
        By default it is None, that is no letter to be shown.

    Raises
    ------
    ValueError:
        If the given orientation is neither "vertical" nor "horizontal".
    ValueError:
        If the given plots_per_row is neither "auto" nor a positive integer.
    ValueError:
        If subplots is True and at most a single index level is provided.

    Returns
    -------
    Tuple containing Figure and Axes of created barplot.
    """

    if orientation not in ("vertical", "horizontal"):
        raise ValueError(
            "Given orientation \"{orientation}\" is not supported.".format(
                orientation=orientation))

    # Rejects anything that is neither the string "auto" nor an int >= 1.
    if not isinstance(plots_per_row,
                      int) and plots_per_row != "auto" or isinstance(
                          plots_per_row, int) and plots_per_row < 1:
        raise ValueError(
            "Given plots_per_row \"{plots_per_row}\" is not 'auto' or a positive integer."
            .format(plots_per_row=plots_per_row))

    vertical = orientation == "vertical"

    levels = get_levels(df)
    # The legend and the subplot split each take one index level away
    # from the count passed to the layout helpers.
    expected_levels = len(levels) - int(show_legend) - int(subplots)

    if len(levels) <= 1 and subplots:
        raise ValueError(
            "Unable to split plots with only a single index level.")

    # NOTE(review): with "auto" and subplots=False the value stays the
    # string "auto" and is passed as such to get_axes -- confirm it copes.
    if plots_per_row == "auto":
        if subplots:
            plots_per_row = min(len(levels[0]), 2 if vertical else 4)
    else:
        plots_per_row = min(plots_per_row, len(levels[0]))

    if colors is None:
        colors = dict(
            zip(levels[-1],
                list(TABLEAU_COLORS.keys()) + list(CSS4_COLORS.keys())))

    if alphas is None:
        alphas = dict(zip(levels[-1], (0.9, ) * len(levels[-1])))

    if facecolors is None:
        facecolors = dict(zip(levels[0], ("white", ) * len(levels[0])))

    sorted_level = levels[0]

    if sort_subplots is not None:
        sorted_level = sort_subplots(sorted_level)

    # One title per subplot, or a single empty title for a single plot.
    if subplots:
        titles = sorted_level
    else:
        titles = ("", )

    figure, axes = get_axes(df, bar_width, space_width, height, dpi, title,
                            data_label, vertical, subplots, titles,
                            plots_per_row, custom_defaults, expected_levels,
                            scale, facecolors, show_title)

    for i, (index, ax) in enumerate(zip(titles, axes)):
        if subplots:
            sub_df = df.loc[index]
        else:
            sub_df = df

        if sort_bars is not None:
            sub_df = sort_bars(sub_df)

        plot_bars(ax,
                  sub_df,
                  bar_width,
                  space_width,
                  alphas,
                  colors,
                  index,
                  vertical=vertical,
                  min_std=min_std)

        # Grid-position flags, used below to switch off labels that the
        # unique_* options ask not to repeat. Both are falsy without subplots.
        is_not_first_ax = subplots and (
            (not vertical and i % plots_per_row) or
            (vertical and i < len(axes) - plots_per_row))

        is_not_first_vertical_ax = subplots and (
            (vertical and i % plots_per_row) or
            (not vertical and i < len(axes) - plots_per_row))

        plot_bar_labels(ax, figure, sub_df, vertical, expected_levels,
                        bar_width, space_width, minor_rotation, major_rotation,
                        unique_minor_labels and is_not_first_ax,
                        unique_major_labels and is_not_first_ax,
                        unique_data_label and is_not_first_vertical_ax,
                        custom_defaults)

        if show_legend:
            remove_duplicated_legend_labels(ax, legend_position,
                                            custom_defaults)

        # Limits of the data axis: computed extent with a 1% margin, the
        # lower bound never above zero, then overridden by explicit values.
        max_lenght, min_lenght = get_max_bar_lenght(sub_df, bar_width,
                                                    space_width)
        max_lenght *= 1.01
        min_lenght *= 1.01
        min_lenght = min(min_lenght, 0)

        if min_value is not None:
            min_lenght = min_value

        if auto_normalize_metrics and (is_normalized_metric(df.columns[0])
                                       or is_normalized_metric(title)):
            max_lenght = max(max_lenght, 1.01)

        if max_value is not None:
            max_lenght = max_value

        if placeholder:
            ax.text(0.5,
                    0.5,
                    "PLACEHOLDER",
                    fontsize=30,
                    alpha=0.75,
                    color="red",
                    rotation=8,
                    horizontalalignment='center',
                    verticalalignment='center',
                    transform=ax.transAxes)

        if vertical:
            ax.set_ylim(min_lenght, max_lenght)
        else:
            ax.set_xlim(min_lenght, max_lenght)

    if letter:
        figure.text(0.01,
                    0.9,
                    letter,
                    horizontalalignment='center',
                    verticalalignment='center',
                    weight='bold',
                    fontsize=15)

    figure.tight_layout()

    if path is not None:
        save_picture(path, figure)

    return figure, axes
Esempio n. 11
0
    def create_plots(self, axes,
                     wheel_axes=None, trial_events=None, color_map=None, linestyle=None):
        """
        Plots the TTL traces with the trial events overlaid as vertical lines. The frame TTLs
        are drawn on the row labelled 'frame2ttl', the audio TTLs on the row labelled 'sound'
        and, when the extractor has them, the Bpod TTLs on the row labelled 'bpod'.
        :param axes: An axes handle on which to plot the TTL events
        :param wheel_axes: An axes handle on which to plot the wheel trace
        :param trial_events: A list of Bpod trial events to plot, e.g. ['stimFreeze_times'],
        if None, the go cue, go cue trigger, feedback and stimulus on/off/freeze times are plotted
        :param color_map: A color map to use for the events, default is the tableau color map.
        The colors are cycled when there are more events than colors.
        :param linestyle: A line style map to use for the events, default is random.
        :return: None
        """
        color_map = color_map or TABLEAU_COLORS.keys()
        if trial_events is None:
            # Default trial events to plot as vertical lines
            trial_events = [
                'goCue_times',
                'goCueTrigger_times',
                'feedback_times',
                'stimFreeze_times',
                'stimOff_times',
                'stimOn_times'
            ]

        plot_args = {
            'ymin': 0,
            'ymax': 4,
            'linewidth': 2,
            'ax': axes
        }

        frame_ttls = self.extractor.frame_ttls
        audio_ttls = self.extractor.audio_ttls
        trial_data = self.extractor.data

        # Each trace is scaled to 0.4 and offset to sit on its own row.
        plots.squares(frame_ttls['times'], frame_ttls['polarities'] * 0.4 + 1,
                      ax=axes, color='k')
        plots.squares(audio_ttls['times'], audio_ttls['polarities'] * 0.4 + 2,
                      ax=axes, color='k')
        linestyle = linestyle or random.choices(('-', '--', '-.', ':'), k=len(trial_events))

        if self.extractor.bpod_ttls is not None:
            bpttls = self.extractor.bpod_ttls
            plots.squares(bpttls['times'], bpttls['polarities'] * 0.4 + 3, ax=axes, color='k')
            ylabels = ['', 'frame2ttl', 'sound', 'bpod', '']
        else:
            # No Bpod row: the event lines only need to span three rows
            plot_args['ymax'] = 3
            ylabels = ['', 'frame2ttl', 'sound', '']

        for event, c, l in zip(trial_events, cycle(color_map), linestyle):
            plots.vertical_lines(trial_data[event], label=event, color=c, linestyle=l, **plot_args)

        axes.legend(loc='upper left', fontsize='xx-small', bbox_to_anchor=(1, 0.5))
        # Fix the tick locations before naming them, so each label is bound to a known row
        # rather than to whatever automatic ticks exist at that moment.
        axes.set_yticks(list(range(plot_args['ymax'] + 1)))
        axes.set_yticklabels(ylabels)
        axes.set_ylim([0, plot_args['ymax']])

        if wheel_axes:
            # Same events, drawn over the full range of the wheel position
            wheel_plot_args = {
                'ax': wheel_axes,
                'ymin': self.wheel_data['re_pos'].min(),
                'ymax': self.wheel_data['re_pos'].max()}
            plot_args = {**plot_args, **wheel_plot_args}

            wheel_axes.plot(self.wheel_data['re_ts'], self.wheel_data['re_pos'], 'k-x')
            for event, c, ln in zip(trial_events, cycle(color_map), linestyle):
                plots.vertical_lines(trial_data[event],
                                     label=event, color=c, linestyle=ln, **plot_args)