예제 #1
0
    def add_genotype_annotations_to_plot(
            self,
            ax: Axes,
            points: Dict[str, Tuple[float, float]],
            annotations: Dict[str, List[str]],
            color_palette: Dict[str, str],
    ) -> Axes:
        """Draw each genotype's annotation lines as a text box on ``ax``.

        The root genotype (``self.root_genotype_name``) and genotypes without
        annotations are skipped. The first label is placed exactly at its
        genotype's point; every later label is placed at
        ``relocate_point(point, locations)``, where ``locations`` holds the
        positions of the labels already drawn.

        Parameters
        ----------
        ax: Axes
            Axes to draw the labels on.
        points: Dict[str, Tuple[float, float]]
            Anchor (x, y) for each genotype label.
        annotations: Dict[str, List[str]]
            Annotation lines per genotype; a missing key means "no annotations".
            The lines are joined with newlines into one text box.
        color_palette: Dict[str, str]
            Color per genotype, passed to the label background/font helpers.
            Only consulted for genotypes that actually receive a label.

        Returns
        -------
        Axes
            The same ``ax`` that was passed in.
        """
        locations: List[Tuple[float, float]] = []
        for genotype_label, point in points.items():
            if genotype_label == self.root_genotype_name:
                # There's no point in adding an annotation for the root genotype.
                continue

            genotype_annotations: List[str] = annotations.get(genotype_label, [])
            if not genotype_annotations:
                # No annotations for this genotype. Don't draw anything.
                continue

            # Look the color up only once a label will definitely be drawn, so a
            # genotype with no annotations and no palette entry cannot raise KeyError.
            genotype_color: str = color_palette[genotype_label]
            background_properties = self._get_annotation_label_background_properties(
                genotype_color)
            label_properties = self._get_annotation_label_font_properties(
                genotype_color)

            if locations:
                x_loc, y_loc = relocate_point(point, locations)
            else:
                x_loc, y_loc = point

            locations.append((x_loc, y_loc))
            ax.text(x_loc,
                    y_loc,
                    "\n".join(genotype_annotations),
                    bbox=background_properties,
                    fontdict=label_properties)
        return ax
예제 #2
0
def plot_sticks(atoms: np.ndarray,
                ax: Axes,
                bt=(0.2, 2),
                nodup=True,
                threshold=0.1,
                **kwargs) -> None:
    """Plot bonds on given Axes.

    Bonds are the atom pairs returned by ``bond_pair(atoms, bt)``; each one is
    drawn with ``ax.plot`` using columns 0 and 1 of the pair array.

    Parameters
    ----------
    atoms: array
        Array of atom positions (2D or 3D).
    ax: matplotlib.axes.Axes
        Axes to plot on.
    bt: tuple(float, float)
        Bond threshold,
        e.g. (0.2, 2) means distance of atom between 0.2 and 2 (Å).
    nodup: bool
        No duplicate: if several bonds are overlayed, only plot one of them
        (duplicates are removed with ``rm_dup_bond``).
    threshold: float
        Threshold of two bonds to regard as duplicate.
        NOTE(review): this value is currently NOT passed to ``rm_dup_bond``,
        so it has no effect; confirm ``rm_dup_bond``'s signature before
        forwarding it.
    **kwargs: dict
        Matplotlib plot kwargs, merged over (and overriding) the default style.
        Default style: {'color': 'tan', 'lw': 2, 'zorder': -1}.

    """
    style = {'color': 'tan', 'lw': 2, 'zorder': -1}
    style.update(kwargs)
    paired = bond_pair(atoms, bt)
    if nodup:
        paired = rm_dup_bond(paired)
    for pair in paired:
        # Only columns 0 and 1 are drawn; any further column (e.g. z) is ignored.
        ax.plot(pair[:, 0], pair[:, 1], **style)
예제 #3
0
def plot_balls(atoms: np.ndarray,
               ax: Axes,
               bt=(0.2, 2),
               nodup=True,
               threshold=0.1,
               **kwargs) -> None:
    """Plot atoms as scatter points on given Axes.

    Parameters
    ----------
    atoms: array
        Array of atom positions (2D or 3D). Only columns 0 and 1 are drawn.
    ax: matplotlib.axes.Axes
        Axes to plot on.
    bt: tuple(float, float)
        Unused by this function; accepted for signature compatibility with
        the bond-plotting helper.
    nodup: bool
        No duplicate: when true, the atoms are passed through ``rm_dup_atom``
        before plotting.
    threshold: float
        Unused by this function: it is not passed to ``rm_dup_atom``.
        NOTE(review): confirm whether ``rm_dup_atom`` should receive it.
    **kwargs: dict
        Matplotlib scatter kwargs, merged over (and overriding) the default
        style.
        Default style: {'s': 32, 'lw': 2, 'c': 'w', 'edgecolors': 'tan'}.

    """
    style = {'s': 32, 'lw': 2, 'c': 'w', 'edgecolors': 'tan'}
    style.update(kwargs)
    if nodup:
        atoms = rm_dup_atom(atoms)
    ax.scatter(atoms[:, 0], atoms[:, 1], **style)
예제 #4
0
    def _setup_axes_for_russian_regions_stat(self,
                                             ax: _figure.Axes,
                                             title: str = None,
                                             grid: bool = True,
                                             legend: bool = True,
                                             draw_key_dates: bool = True):
        """Apply the shared styling used by the Russian-regions statistics plots.

        The steps run in this exact order, each optional one only when its
        flag is truthy:

        1. format the x axis as dates (always);
        2. draw key dates via ``self.key_russian_dates`` (``draw_key_dates``);
        3. add an upper-left legend (``legend``);
        4. pin the bottom of the y axis to 0 (always);
        5. draw a dashed horizontal grid (``grid``);
        6. set the title (``title``, skipped when empty or None).
        """
        pipeline = (
            (True, ax.xaxis_date),
            (draw_key_dates, lambda: self.key_russian_dates(ax)),
            (legend, lambda: ax.legend(loc='upper left')),
            (True, lambda: ax.set_ylim(bottom=0)),
            (grid, lambda: ax.grid(axis='y', color='black',
                                   linestyle='dashed', alpha=0.4)),
            (title, lambda: ax.set_title(title)),
        )
        # Steps are executed sequentially so their relative order is preserved.
        for enabled, step in pipeline:
            if enabled:
                step()
예제 #5
0
    def _draw_daily_stats(self, ax: _figure.Axes, df: _pd.DataFrame,
                          draw_key_dates: bool):
        """Plot day-over-day changes with 7-point moving averages.

        New cases and recoveries are drawn with ``bar_with_sma_line``.
        Deaths are drawn as bars stacked on top of the recovery bars
        (``bottom=`` the recovery values), and the deaths trend line is the
        7-point rolling mean of recoveries + deaths, i.e. it follows the top
        of that stack. Finally the shared axis styling is applied with a
        "day-to-day statistics" title.
        """
        recovered = df.Recovered_Change
        deaths = df.Deaths_Change

        self.bar_with_sma_line(ax, df.Confirmed_Change, label="Заболевшие")
        self.bar_with_sma_line(ax, recovered, label="Выздоровевшие")

        # Deaths sit on top of the recovery bars, so their trend line
        # tracks the stacked total rather than deaths alone.
        stacked_top = recovered + deaths
        ax.bar(df.index, deaths, label='Смерти', alpha=0.3, bottom=recovered)
        ax.plot(df.index,
                stacked_top.rolling(window=7).mean(),
                label='Смерти-SMA7')

        self._setup_axes_for_russian_regions_stat(
            ax, "Статистика день ко дню", draw_key_dates=draw_key_dates)
예제 #6
0
def get_plot_formatting(ax: Axes, series: MediaResource, index_attribute: str,
                        by: str, scheme: str) -> Axes:
    """Format the plot aesthetics of ``ax`` for an episode-rating chart.

    All styling is applied to ``ax`` itself rather than to pyplot's "current"
    Axes, so the function works when ``ax`` is not the active subplot.

    Parameters
    ----------
    ax: matplotlib.axes.Axes
        The plot to format.
    series: MediaResource
        Series whose ``seasons`` (each an iterable of episodes) supply the x
        values, and whose ``title`` becomes the plot title.
    index_attribute: str
        Episode key read for the x values (converted with ``float``); also
        used as the x-axis label.
    by: str
        If ``'index'``, only the right x bound is padded (by 1); otherwise
        both x bounds are padded by 1/12.
    scheme: str
        Name of a color scheme in ``colorschemes``; unknown names fall back
        to ``'graphtv'``.

    Returns
    -------
    matplotlib.axes.Axes
        The same ``ax``, formatted.

    Raises
    ------
    ValueError
        If ``series`` contains no episodes (``min``/``max`` of an empty list).
    """
    color_scheme_label = scheme if scheme in colorschemes else 'graphtv'
    color_scheme = colorschemes.get(color_scheme_label)
    background_color = color_scheme.background
    tick_color = color_scheme.ticks
    font_color = color_scheme.font

    ax.patch.set_facecolor(background_color)
    # Hide the frame by painting every spine in the background color.
    for side in ('bottom', 'top', 'left', 'right'):
        ax.spines[side].set_color(background_color)

    # Tick label colors. tick_params stores the setting on the axis, so it
    # also applies to tick labels regenerated after the limits change below.
    ax.tick_params(axis='both', labelcolor=tick_color)

    # Change tick size
    ax.tick_params(axis='y', which='major', labelsize=22)
    ax.tick_params(axis='y', which='minor', labelsize=22)
    ax.yaxis.grid(True)
    ax.xaxis.grid(False)

    # Add series parameters
    episode_indices = [
        float(episode[index_attribute])
        for season in series.seasons
        for episode in season
    ]

    # Set plot bounds
    x_min = min(episode_indices)
    x_max = max(episode_indices)

    if by == 'index':
        x_max += 1
    else:
        x_min -= 1 / 12
        x_max += 1 / 12

    ax.set_xlim((x_min, x_max))
    ax.set_ylim(top=10)

    ax.set_xlabel(index_attribute, fontsize=16, color=font_color)
    ax.set_ylabel('imdbRating', fontsize=16, color=font_color)
    ax.set_title(series.title, fontsize=24, color=font_color)

    return ax
예제 #7
0
    def _draw_stats_bar(self, ax: _figure.Axes, df: _pd.DataFrame):
        """Draw confirmed, recovered and death daily changes as side-by-side bars.

        The three series are drawn at x offsets of 0, 1 and 2 days
        (``self.__dates.to_Timedelta(1)``) from ``df.index``, each with bar
        width 2 in x-axis data units.
        NOTE(review): if the x unit is days, a width of 2 with 1-day offsets
        makes neighbouring bars overlap; confirm this is intended.
        """
        day = self.__dates.to_Timedelta(1)
        x = df.index
        bar_specs = (
            (x, df.Confirmed_Change, 'Заболевшие'),
            (x + day, df.Recovered_Change, 'Выздоровевшие'),
            (x + day * 2, df.Deaths_Change, 'Смерти'),
        )
        for positions, heights, series_label in bar_specs:
            ax.bar(positions, heights, label=series_label, width=2)
예제 #8
0
    def bar_with_sma_line(self,
                          ax: _figure.Axes,
                          values: _pd.Series,
                          sma_window: int = 7,
                          label: str = None,
                          bar_alpha: float = 0.3,
                          color: str = None):
        """Draw ``values`` as translucent bars plus a simple-moving-average line.

        The line (drawn first) is ``values.rolling(window=sma_window).mean()``.
        When ``label`` is truthy the line is labelled
        ``"<label>-SMA<sma_window>"``; otherwise no label is passed. The bars
        use ``values.index`` as x, ``bar_alpha`` as opacity and carry no label.
        Both artists share ``color``.
        """
        moving_average = values.rolling(window=sma_window).mean()

        line_options = {'color': color}
        if label:
            line_options['label'] = label + "-SMA" + str(sma_window)
        ax.plot(moving_average, **line_options)

        ax.bar(values.index, values, alpha=bar_alpha, color=color)
예제 #9
0
    def key_russian_dates(self, ax: _figure.Axes):
        """Draw key dates from Russia on the plot as vertical lines and spans.

        For each ``(date, length, name)`` returned by
        ``self.__dates.get_key_russian_dates()``:

        * if ``length`` equals one day, a vertical line is drawn at ``date``;
        * otherwise a shaded span covers ``date`` .. ``date + length``.

        A rotated text label ``name`` is centred on the line (or the span's
        midpoint) and hangs from 95% of the Axes height.
        """
        one_day = self.__dates.to_Timedelta(1)
        # x in data coordinates, y as a fraction of the Axes height: the labels
        # stay near the top even when the y-limits do not start at 0 or are
        # changed after this call.
        label_transform = ax.get_xaxis_transform()

        for date, length, name in self.__dates.get_key_russian_dates():
            if length == one_day:
                ax.axvline(date, color='Gray', alpha=0.6)
                x = date
            else:
                ax.axvspan(date, date + length, color='Gray', alpha=0.6)
                x = date + length / 2

            ax.text(s=name,
                    x=x,
                    y=0.95,
                    transform=label_transform,
                    rotation=90,
                    ha='center',
                    va='top',
                    fontsize='small',
                    bbox=dict(boxstyle="square,pad=0.6",
                              fc="white",
                              ec=(0.7, 0.7, 0.7),
                              lw=2))
예제 #10
0
 def write(self, x: Axes) -> DPTmpFile:
     """Write the Figure that Axes ``x`` belongs to.

     Delegates to the parent class' ``_write_figure`` with
     ``x.get_figure()`` and returns its result unchanged.
     """
     return super()._write_figure(x.get_figure())
예제 #11
0
def _draw_subplot(subplot: Subplot, ax: Axes, start: datetime.datetime,
                  end: datetime.datetime):
    """Render one ``Subplot`` onto ``ax``.

    Steps, in order:

    1. load every line's data for ``start``..``end`` (all loads happen before
       any drawing), then draw each with ``_draw_series``;
    2. apply ``subplot.ylim`` bounds that are finite (only if ``ylim`` is set);
    3. set the y label;
    4. if ``subplot.logsite`` is the site of one of the lines, mark each log
       entry with a red vertical line and its type at the bottom of the axis;
    5. set the x range to the parent plot's start/end, rotate x tick labels,
       prune the top y tick, size the ticks, enable the grid and the legend.
    """
    loaded = [line.load(start, end) for line in subplot.lines]
    for line, data in zip(subplot.lines, loaded):
        _draw_series(data, ax, line)

    ylim = subplot.ylim
    if ylim:
        # Each bound is optional: a non-finite value (e.g. inf/nan) leaves it alone.
        lower, upper = ylim[0], ylim[1]
        if np.isfinite(lower):
            ax.set_ylim(bottom=lower)
        if np.isfinite(upper):
            ax.set_ylim(top=upper)

    ax.set_ylabel(subplot.ylabel, fontsize=subplot.plot.fontsize(1.2))

    # Log book entries are only relevant when the log site is one of the
    # sites shown by this subplot's lines.
    site_ids = [line.siteid for line in subplot.lines]
    if subplot.logsite in site_ids:
        for logtime, logtype, _logtext in subplot.get_logs():
            when = np.datetime64(logtime)
            ax.axvline(when, linestyle='-', color='r', alpha=0.5, linewidth=3)
            ax.text(when,
                    ax.get_ylim()[0],
                    logtype,
                    ha='left',
                    va='bottom',
                    fontsize=subplot.plot.fontsize(0.9))

    ax.set_xlim(subplot.plot.start, subplot.plot.end)

    for tick_label in ax.get_xticklabels():
        tick_label.set_rotation(15)
    # Drop the topmost y tick so it doesn't collide with the subplot above.
    ax.yaxis.set_major_locator(MaxNLocator(prune='upper'))
    ax.tick_params(axis='both',
                   which='major',
                   labelsize=subplot.plot.fontsize(1.1))

    ax.grid()
    ax.legend(loc=0, prop=dict(size=subplot.plot.fontsize(1)))