Example #1
0
def format_axes(
    ax: plt.Axes,
    ticklabels,
    patch_size,
    total_size,
    block_size,
):
    """Draw a block grid on a model-vs-model matrix plot and label the blocks.

    Major ticks sit on the boundaries between blocks and only carry the grid
    lines (their labels are hidden); minor ticks sit at the centre of each
    block and carry ``ticklabels`` on both axes.

    Args:
        ax: Axes to format in place.
        ticklabels: One label per block, used for both the x and the y axis.
            Its length should equal ``ceil(total_size / block_size)``, the
            number of blocks.
        patch_size: Currently unused; kept so existing callers keep working.
        total_size: Number of cells along each axis of the plotted matrix.
        block_size: Number of cells per labelled block.
    """
    block_starts = np.arange(0, total_size, block_size)

    # Major ticks on the block boundaries. Assumes the matrix was drawn with
    # imshow/matshow, where cell i spans [i - 0.5, i + 0.5].
    ax.set_xticks(block_starts - 0.5)
    ax.set_yticks(block_starts - 0.5)
    ax.grid(which='major', axis='both', lw=1, color='k', alpha=0.5, ls='-')
    # The boundary ticks only carry the grid; hide their labels.
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    # Minor ticks at the centre of each block carry the block labels. The
    # centre uses the same half-cell offset as the boundaries above.
    block_centres = block_starts - 0.5 + block_size / 2
    ax.set_xticks(block_centres, minor=True)
    ax.set_yticks(block_centres, minor=True)
    ax.set_xticklabels(ticklabels, minor=True, fontsize=12, weight='light')
    ax.set_yticklabels(ticklabels, minor=True, fontsize=12, weight='light')
    # Run the y labels vertically along the axis.
    plt.setp(ax.get_yticklabels(minor=True),
             rotation=90,
             ha="center",
             va="bottom",
             rotation_mode="anchor")

    ax.set_xlabel('Models', fontsize=14)
    ax.set_ylabel('Models', fontsize=14)
    ax.xaxis.labelpad = 7
    ax.yaxis.labelpad = 7
Example #2
0
def full_extent(fig: plt.Figure, ax: plt.Axes, *extras, pad=0.01):
    """Get the full extent of an axes, including axes labels, tick labels, and
    titles.

    Args:
        fig: Figure whose ``dpi_scale_trans`` converts the result to inches.
            The canvas drawn is ``ax.figure``, so this presumably should be
            the same figure as ``ax.figure``.
        ax: Axes whose extent is measured.
        *extras: Additional artists (e.g. a colorbar axes) to include. Those
            exposing ``xaxis``/``yaxis`` or ``get_xticklabels``/
            ``get_yticklabels`` also contribute their axis and tick labels.
        pad: Fractional padding. The union box is scaled by ``1 + pad`` in
            both width and height.

    Returns:
        A ``Bbox`` in figure inches (e.g. usable as ``savefig(bbox_inches=...)``).
    """
    # For text objects, we need to draw the figure first, otherwise the extents
    # are undefined.
    ax.figure.canvas.draw()
    items = [ax, ax.title, ax.xaxis.label, ax.yaxis.label, *extras]
    items += [*ax.get_xticklabels(), *ax.get_yticklabels()]
    items += [e.xaxis.label for e in extras if hasattr(e, 'xaxis')]
    items += [e.yaxis.label for e in extras if hasattr(e, 'yaxis')]
    # Flatten with comprehensions instead of sum(..., []), which re-copies the
    # accumulated list for every extra (quadratic).
    items += [label
              for e in extras if hasattr(e, 'get_xticklabels')
              for label in e.get_xticklabels()]
    items += [label
              for e in extras if hasattr(e, 'get_yticklabels')
              for label in e.get_yticklabels()]
    bbox = Bbox.union([
        item.get_window_extent() for item in items
        if hasattr(item, 'get_window_extent')
    ])

    bbox = bbox.expanded(1.0 + pad, 1.0 + pad)
    # Display pixels -> inches.
    return bbox.transformed(fig.dpi_scale_trans.inverted())
Example #3
0
def matrix(values: np.ndarray,
           row_labels: Sequence[str] = None,
           col_labels: Sequence[str] = None,
           row_seps: Union[int, Collection[int]] = None,
           col_seps: Union[int, Collection[int]] = None,
           cmap="RdBu",
           fontcolor_thresh=0.5,
           norm: plt.Normalize = None,
           text_len=4,
           omit_leading_zero=False,
           trailing_zeros=False,
           grid=True,
           angle_left=False,
           cbar=True,
           cbar_label: str = None,
           ax: plt.Axes = None,
           figsize: Tuple[float, float] = None,
           cellsize=0.65,
           title: str = None):
    """Plot a 2-D array as a heatmap with every cell annotated by its value.

    Each non-NaN cell is labelled with ``_format_value(val, text_len,
    omit_leading_zero, trailing_zeros)`` in white or black. The colour is
    chosen by the luma of the cell's background colour.

    Args:
        values: 2-D array to plot. NaN cells are left unannotated.
        row_labels, col_labels: Optional tick labels. Ticks are only placed
            at non-empty labels.
        row_seps, col_seps: Index or indices (as accepted by ``np.insert``)
            before which a blank NaN row/column is inserted as a visual
            separator. A matching empty label is inserted too.
        cmap: Colormap name or object.
        fontcolor_thresh: Luma below which annotation text is white.
        norm: Optional normalization passed to ``matshow``.
        text_len, omit_leading_zero, trailing_zeros: Forwarded to
            ``_format_value`` for cell text and colorbar tick labels.
        grid: Draw white lines between cells.
        angle_left: Rotate the row labels by 40 degrees.
        cbar: Add a colorbar.
        cbar_label: Colorbar label.
        ax: Axes to draw into. A new figure is created when omitted.
        figsize: Size of the new figure. The default is ``cellsize`` per
            cell, with the width scaled by 1.2 when a colorbar is added.
        cellsize: Per-cell size used for the default ``figsize``.
        title: Optional axes title.
    """
    cmap = get_cmap(cmap)

    # Create figure if necessary.
    if ax is None:
        if figsize is None:
            # Note the extra width factor for the colorbar.
            figsize = (cellsize * values.shape[1] * (1.2 if cbar else 1),
                       cellsize * values.shape[0])
        ax = plt.figure(figsize=figsize).gca()

    # Set title if applicable.
    if title is not None:
        ax.set_title(title)

    # Separators are blank NaN rows/columns, but np.insert cannot store NaN in
    # an integer array (it raises ValueError), so promote those to float.
    if ((row_seps is not None or col_seps is not None)
            and np.asarray(values).dtype.kind in "iu"):
        values = np.asarray(values, dtype=float)
    if row_seps is not None:
        values = np.insert(values, row_seps, np.nan, axis=0)
        if row_labels is not None:
            row_labels = np.insert(row_labels, row_seps, "")
    if col_seps is not None:
        values = np.insert(values, col_seps, np.nan, axis=1)
        if col_labels is not None:
            col_labels = np.insert(col_labels, col_seps, "")

    # Plot the heatmap.
    im = ax.matshow(values, cmap=cmap, norm=norm)

    # Plot the text annotations showing each cell's value.
    norm_values = im.norm(values)
    for row, col in product(range(values.shape[0]), range(values.shape[1])):
        val = values[row, col]
        if not np.isnan(val):
            # Pick black or white text depending on the background's luma.
            bg_color = cmap(norm_values[row, col])[:3]
            luma = (0.299 * bg_color[0]
                    + 0.587 * bg_color[1]
                    + 0.114 * bg_color[2])
            color = "white" if luma < fontcolor_thresh else "black"

            # Plot cell text.
            annotation = _format_value(val, text_len, omit_leading_zero,
                                       trailing_zeros)
            ax.text(col,
                    row,
                    annotation,
                    ha="center",
                    va="center",
                    color=color)

    # Add ticks and labels; only non-empty labels get a tick.
    if col_labels is None:
        ax.set_xticks([])
    else:
        col_labels = np.asarray(col_labels)
        labeled_cols = np.where(col_labels)[0]
        ax.set_xticks(labeled_cols)
        ax.set_xticklabels(col_labels[labeled_cols])
    if row_labels is None:
        ax.set_yticks([])
    else:
        row_labels = np.asarray(row_labels)
        labeled_rows = np.where(row_labels)[0]
        ax.set_yticks(labeled_rows)
        ax.set_yticklabels(row_labels[labeled_rows])

    ax.tick_params(which="major", bottom=False)

    plt.setp(ax.get_xticklabels(),
             rotation=40,
             ha="left",
             rotation_mode="anchor")

    # Turn off spines.
    for spine in ax.spines.values():
        spine.set_visible(False)

    # Rotate the left labels if applicable.
    if angle_left:
        plt.setp(ax.get_yticklabels(),
                 rotation=40,
                 ha="right",
                 rotation_mode="anchor")

    # Create the white grid if applicable.
    if grid:
        # Extra ticks required to avoid glitch.
        xticks = np.concatenate([[-0.56],
                                 np.arange(values.shape[1] + 1) - 0.5,
                                 [values.shape[1] - 0.44]])
        yticks = np.concatenate([[-0.56],
                                 np.arange(values.shape[0] + 1) - 0.5,
                                 [values.shape[0] - 0.44]])
        ax.set_xticks(xticks, minor=True)
        ax.set_yticks(yticks, minor=True)
        ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
        ax.tick_params(which="minor", bottom=False, top=False, left=False)

    # Create the colorbar if applicable.
    if cbar:
        bar = ax.figure.colorbar(im, ax=ax)
        bar.ax.set_ylabel(cbar_label)
        fmt = bar.ax.yaxis.get_major_formatter()
        if isinstance(fmt, FixedFormatter):
            # NOTE(review): eval() turns the colorbar's own tick labels
            # (presumably mathtext such as "$2\times10^{-1}$" -> "2*10**-1")
            # back into numbers so they can be re-formatted consistently with
            # the cells. It must never be fed untrusted text.
            fmt.seq = [
                _format_value(
                    eval(
                        re.sub(r"[a-z$\\{}]", "",
                               label.replace("times", "*").replace(
                                   "^", "**"))), text_len, omit_leading_zero,
                    trailing_zeros) if label else "" for label in fmt.seq
            ]
Example #4
0
    def plot(self,
             plot_these: List[np.ndarray],
             ax: plt.Axes = None,
             fout: str = None,
             poly=True,
             root: str = "",
             title: str = None,
             aspect: int = 3,
             ticks: "Sequence[float]" = (10, 5),
             grid: bool = False):
        """
        Plot the lattice coordinates or lattice as an array

        Arguments:
            plot_these - up to 4 sets of points, drawn in order with cycling
                styles (lattice sites, then crossover sites, ...). A set with
                2 or 3 columns is used as coordinates directly; anything else
                is converted with ``self.array_to_coords`` first.
            ax - axes to draw on; a new figure/axes is created (and shown at
                the end) when omitted
            fout - if given, save the figure as ``f"{root}{fout}.png"`` at 500 dpi
            poly - True will plot the initial polygon vertices (via
                ``self.plotPolygon``, using the first set in ``plot_these``)
            root - prefix for the saved file (e.g. a directory ending in a
                separator); defaults to the current directory
            title - this is the title of the plot
            aspect - aspect ratio passed to ``Axes.set_aspect``
            ticks - (x, y) spacing of the major tick locators
            grid - if True, draw a grid on the axes

        Raises:
            AssertionError - if more than 4 point sets are given.
        """
        assert len(
            plot_these) <= 4, "Max no. of plots on one axis reached, 4 or less"

        # Only a figure we created ourselves is ours to show at the end.
        standalone = ax is None
        if standalone:
            _, ax = plt.subplots()
        if grid:
            ax.grid(True)
        ax.xaxis.set_major_locator(MultipleLocator(ticks[0]))
        ax.yaxis.set_major_locator(MultipleLocator(ticks[1]))
        # tick_params also applies to ticks created later, unlike setting the
        # font size on whichever label objects exist right now.
        ax.tick_params(labelsize=4)
        ax.set_xlabel("No. of nucleotides")
        ax.set_ylabel("No. of strands")

        point_style = itertools.cycle(["ko", "b.", "r.", "cP"])
        point_size = itertools.cycle([0.5, 2.5])
        for points in plot_these:
            if np.shape(points)[1] not in (2, 3):  # lattice array, not coords
                nodes = self.array_to_coords(points)
            else:
                nodes = points
            # Lattice sites then crossover sites
            ax.plot(nodes[:, 0],
                    nodes[:, 1],
                    next(point_style),
                    ms=next(point_size),
                    alpha=0.25)

        if poly and len(plot_these) > 0:
            self.plotPolygon(ax, plot_these[0], coords=True)
        if title:
            ax.set_title(f"{title}")

        # Act on the given axes/figure, not whatever pyplot considers current.
        ax.set_aspect(aspect)
        if fout:
            ax.figure.savefig(f"{root}{fout}.png", dpi=500)
        if standalone:
            plt.show()
Example #5
0
def plot_f1_metrics(run_histories: List[Run],
                    experiment_name: str,
                    metric_names: Dict[str, str],
                    target_file_paths: Optional[List[str]] = None,
                    axis: plt.Axes = None,
                    y_lims: Tuple[float, float] = None,
                    mode='box',
                    add_legend=True,
                    color=None):
    """Compare final-epoch metrics across runs as grouped box or bar plots.

    For every run, each metric in ``metric_names`` is read from the last epoch
    of every fold. Runs are laid out left to right as groups, each with one
    box/bar per metric and followed by a one-slot gap. Within a group, metrics
    are told apart by fill lightness and line style.

    Args:
        run_histories: Runs to compare. Each run's ``name`` labels its group.
            Values are read from ``fold_histories[i].epochs[-1].metrics``.
        experiment_name: Figure suptitle (standalone) or axes title.
        metric_names: Maps metric key -> legend label. Its order sets the
            plot order.
        target_file_paths: Standalone mode only: paths to save the figure to.
        axis: Axes to draw into. When omitted, a new figure is created, shown
            and optionally saved ("standalone mode").
        y_lims: Optional (bottom, top) y-axis limits.
        mode: 'box' for one box per metric over the folds, or 'bar' for a
            bar of the fold mean.
        add_legend: Whether to draw a legend of the metrics.
        color: Edge colour. Fills are lighter shades of it. Defaults to
            colour 4 of the 'tab20c' colormap.

    Returns:
        The artists used as legend handles, one per metric (from the last run).

    Raises:
        ValueError: If ``mode`` is neither 'box' nor 'bar'.
    """
    if mode not in ('box', 'bar'):
        raise ValueError(f"mode must be 'box' or 'bar', got {mode!r}")
    standalone_mode = axis is None

    if color is None:
        # by default we use bright orange from map 'tab20c'
        color = plt.get_cmap('tab20c')(4)

    metric_keys = list(metric_names)
    legend_labels = list(metric_names.values())
    num_metrics = len(metric_keys)
    num_runs = len(run_histories)

    runs_data = []
    for run_history in run_histories:
        metrics_data = []
        for metric_name in metric_keys:
            # The metric of the last epoch of every fold.
            metric_data = np.array([fold_history.epochs[-1].metrics[metric_name]
                                    for fold_history in run_history.fold_histories])
            if mode == 'bar':
                metric_data = np.mean(metric_data)
            metrics_data.append(metric_data)
        runs_data.append(metrics_data)

    # (edge, fill) colour per metric: the edge is always `color`; the fill's
    # HLS lightness is spread evenly over [min_v, max_v]. Identical for every
    # run, so computed once.
    min_v, max_v = .6, .9
    hue, _, saturation = colorsys.rgb_to_hls(*mc.to_rgb(color))
    colors = []
    for idx in range(num_metrics):
        step = idx * (max_v - min_v) / (num_metrics - 1) if num_metrics > 1 else 0
        colors.append((color, colorsys.hls_to_rgb(hue, min_v + step, saturation)))
    line_styles = ['-', '-.', ':', '--']

    # The 'seaborn' style was renamed 'seaborn-v0_8' in Matplotlib 3.6 and the
    # old name was removed in 3.8.
    style = 'seaborn' if 'seaborn' in plt.style.available else 'seaborn-v0_8'

    with plt.style.context(style):
        if standalone_mode:
            fig: plt.Figure = plt.figure()
            fig.suptitle(experiment_name, fontsize=24)
            axis = fig.add_subplot(111)
        else:
            axis.set_title(experiment_name, fontsize=22)
        axis.set_ylabel('Metric Value', fontsize=22)
        if y_lims is not None:
            axis.set_ylim(*y_lims)

        legend_styles = []
        for run_idx, metrics_data in enumerate(runs_data):
            # Each run occupies num_metrics slots followed by a one-slot gap.
            start = run_idx * (num_metrics + 1)
            xs = range(start, start + num_metrics)

            if mode == 'box':
                boxplots = axis.boxplot(
                    metrics_data,
                    meanline=True,
                    showmeans=True,
                    positions=xs,
                    widths=0.6,
                    patch_artist=True
                )

                for plot_idx in range(num_metrics):
                    dark_color, light_color = colors[plot_idx]
                    # Cycle the styles so more than 4 metrics do not fail.
                    line_style = line_styles[plot_idx % len(line_styles)]

                    # Separate calls: `color` would otherwise override
                    # `facecolor` depending on the order setp applies them.
                    box = boxplots['boxes'][plot_idx]
                    plt.setp(box, color=dark_color)
                    plt.setp(box, facecolor=light_color)
                    plt.setp(box, linestyle=line_style)

                    # Each box has two whiskers and two caps.
                    for whisker in boxplots['whiskers'][2 * plot_idx:2 * plot_idx + 2]:
                        plt.setp(whisker, color=dark_color)
                        plt.setp(whisker, linestyle=line_style)
                    for cap in boxplots['caps'][2 * plot_idx:2 * plot_idx + 2]:
                        plt.setp(cap, color=dark_color)

                    plt.setp(boxplots['fliers'][plot_idx], markeredgecolor=dark_color)
                    plt.setp(boxplots['fliers'][plot_idx], marker='x')

                    plt.setp(boxplots['medians'][plot_idx], color=dark_color)
                    plt.setp(boxplots['means'][plot_idx], color=dark_color)

                legend_styles = list(boxplots['boxes'][:num_metrics])
            else:  # mode == 'bar'
                legend_styles = []
                for plot_idx in range(num_metrics):
                    dark_color, light_color = colors[plot_idx]
                    ret = axis.bar(xs[plot_idx], metrics_data[plot_idx],
                                   color=light_color,
                                   edgecolor=dark_color,
                                   width=0.6,
                                   linewidth=1.25,
                                   linestyle=line_styles[plot_idx % len(line_styles)], )
                    legend_styles.append(ret)

        # Centre each run's tick under its group of num_metrics slots.
        tick_offset = num_metrics * 0.5 - 0.5
        ticks = np.arange(start=tick_offset, stop=num_runs * num_metrics + num_runs + tick_offset,
                          step=num_metrics + 1.0)
        axis.set_xticks(ticks)
        axis.tick_params(axis='y', labelsize=20)
        axis.set_xticklabels([r.name for r in run_histories], fontsize=20, rotation=0)
        if add_legend:
            axis.legend(legend_styles, legend_labels,
                        loc='lower right', fontsize=16,
                        facecolor="white", frameon=True,
                        edgecolor="black")

        if standalone_mode:
            fig.show()
            for target_file_path in target_file_paths or []:
                fig.savefig(target_file_path)
        return legend_styles