Ejemplo n.º 1
0
def test_hbox_divider():
    """HBoxDivider should give two side-by-side images the same height."""
    # Two 20-element images whose shapes are transposes of each other.
    tall = np.arange(20).reshape((4, 5))
    wide = np.arange(20).reshape((5, 4))

    fig, (left_ax, right_ax) = plt.subplots(1, 2)
    left_ax.imshow(tall)
    right_ax.imshow(wide)

    gap_inches = 0.5
    columns = [Size.AxesX(left_ax), Size.Fixed(gap_inches), Size.AxesX(right_ax)]
    rows = [Size.AxesY(left_ax), Size.Scaled(1), Size.AxesY(right_ax)]
    # 111 places the combined divider over the whole subplot area.
    divider = HBoxDivider(fig, 111, horizontal=columns, vertical=rows)
    for axes, cell in ((left_ax, 0), (right_ax, 2)):
        axes.set_axes_locator(divider.new_locator(cell))

    fig.canvas.draw()
    left_box = left_ax.get_position()
    right_box = right_ax.get_position()
    # Heights must match; the expected width ratio is (4/5)**2.
    assert left_box.height == right_box.height
    assert right_box.width / left_box.width == pytest.approx((4 / 5) ** 2)
Ejemplo n.º 2
0
def make_heights_equal(fig, rect, ax1, ax2, pad):
    """Place *ax1* and *ax2* side by side inside *rect* with equal heights.

    The previous version built the size objects below but never created the
    divider or assigned the axes locators, so calling it had no effect.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that owns both axes.
    rect : int or tuple
        Position of the combined axes, passed straight to ``HBoxDivider``.
    ax1, ax2 : matplotlib.axes.Axes
        Left and right axes.
    pad : float
        Horizontal gap between the two axes, in inches.
    """
    h1, v1 = Size.AxesX(ax1), Size.AxesY(ax1)
    h2, v2 = Size.AxesX(ax2), Size.AxesY(ax2)

    pad_v = Size.Scaled(1)
    pad_h = Size.Fixed(pad)

    divider = HBoxDivider(fig, rect,
                          horizontal=[h1, pad_h, h2],
                          vertical=[v1, pad_v, v2])

    # Cells 0 and 2 are the axes; cell 1 is the fixed-width gap.
    ax1.set_axes_locator(divider.new_locator(0))
    ax2.set_axes_locator(divider.new_locator(2))
def make_heights_equal(fig, rect, ax1, ax2, pad):
    """Lay out *ax1* and *ax2* horizontally inside *rect* with matching heights.

    *pad* is the horizontal gap between the two axes, in inches.
    """
    widths = [Size.AxesX(ax1), Size.Fixed(pad), Size.AxesX(ax2)]
    heights = [Size.AxesY(ax1), Size.Scaled(1), Size.AxesY(ax2)]
    divider = HBoxDivider(fig, rect, horizontal=widths, vertical=heights)
    # Even cells hold the axes; the odd cell is the gap.
    for cell, axes in ((0, ax1), (2, ax2)):
        axes.set_axes_locator(divider.new_locator(cell))
Ejemplo n.º 4
0
def make_heights_equal(fig, rect, ax1, ax2, pad):
    """Arrange two axes side by side in *rect* so that their heights agree.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
    rect : int or tuple
        Position of the combined axes.
    ax1, ax2 : matplotlib.axes.Axes
    pad : float
        Gap between the axes, in inches.
    """
    gap = Size.Fixed(pad)
    filler = Size.Scaled(1)
    divider = HBoxDivider(
        fig,
        rect,
        horizontal=[Size.AxesX(ax1), gap, Size.AxesX(ax2)],
        vertical=[Size.AxesY(ax1), filler, Size.AxesY(ax2)],
    )

    ax1.set_axes_locator(divider.new_locator(0))
    ax2.set_axes_locator(divider.new_locator(2))
Ejemplo n.º 5
0
def plot_anno_pred(anno_img, img, pred, save_path=None):
    """Show *anno_img* next to *img* with the confident predicted boxes drawn on it.

    Boxes scoring at least 0.5 are outlined on *img*, each coloured by its
    score through the ``spring`` colormap, and a matching colour bar is added
    to the right of the image.

    Parameters
    ----------
    anno_img, img : ndarray
        Images indexed ``[row, col, channel]``; the channel axis is reversed
        before display (presumably BGR -> RGB, e.g. OpenCV images - TODO confirm).
    pred : ndarray
        2-D array, one row per detection: columns 0-3 are the box, column 4
        the score. Columns 0-1 are used as the rectangle corner and columns
        2-3 as its width and height.
    save_path : str, optional
        When given, the figure is saved there (tight bbox, 200 dpi).
    """
    from mpl_toolkits.axes_grid1 import make_axes_locatable, axes_size

    min_det_score = 0.5
    # np.int was removed in NumPy 1.24; the builtin int performs the same cast.
    boxes, scores = pred[:, :4].astype(int), pred[:, 4]
    valid_idx = scores >= min_det_score
    boxes = boxes[valid_idx]
    scores = scores[valid_idx]

    fig = plt.figure(figsize=(30, 15))
    fig.add_subplot(121)
    plt.imshow(anno_img[:, :, ::-1])
    fig.add_subplot(122)
    ax = plt.gca()
    plt.imshow(img[:, :, ::-1])

    cmap = plt.cm.spring
    # max() of an empty array raises; fall back to 1.0 when no box survives.
    max_score = scores.max() if scores.size else 1.0
    normal = plt.Normalize(min_det_score, max_score)
    # Colour boxes through the same norm as the colour bar. Passing raw
    # scores to cmap put them on a different scale than the bar displays.
    colors = cmap(normal(scores))
    for box, c in zip(boxes, colors):
        rect = patches.Rectangle(box[:2],
                                 *box[2:],
                                 linewidth=2,
                                 edgecolor=c,
                                 facecolor='none')
        ax.add_patch(rect)

    # Colour bar: as tall as the image axes and 1/aspect of that height wide,
    # separated from it by pad_fraction times its width.
    pad_fraction = 0.5
    aspect = 20
    divider = make_axes_locatable(ax)
    width = axes_size.AxesY(ax, aspect=1. / aspect)
    pad = axes_size.Fraction(pad_fraction, width)
    cax = divider.append_axes("right", size=width, pad=pad)
    cbar.ColorbarBase(cax, cmap=cmap, norm=normal)

    if save_path is not None:
        fig.savefig(save_path, bbox_inches='tight', pad_inches=0, dpi=200)
Ejemplo n.º 6
0
def plot_weightmap(ax1, weightmap, name, title=False):
    """Draw ``log10(weightmap.T * weightmap.size)`` on *ax1* with a colour bar.

    The colour range is symmetric, ``[-2*max, 2*max]`` of the finite
    log values. Non-finite values (e.g. from non-positive weights) are shown
    in black through the colormap's "bad" colour.

    Parameters
    ----------
    ax1 : matplotlib.axes.Axes
        Axes to draw on; a colour bar axes is appended to its right.
    weightmap : ndarray
        Weight map; it is transposed before display.
    name : str
        Used in the optional title.
    title : bool
        When true, title *ax1* with ``'TV-min Weightmap <name>'``.
    """
    import copy

    norm = np.size(weightmap)

    # Copy the registered colormap. Calling set_bad on mpl.cm.seismic itself
    # would change the bad colour for every later user of that colormap.
    cmap = copy.copy(mpl.cm.seismic)
    cmap.set_bad('k', 1.)

    im = np.log10(weightmap.T * norm)
    limit = 2 * np.nanmax(im)
    pic = ax1.imshow(im, cmap=cmap, vmin=-limit, vmax=limit,
                     interpolation='nearest', origin='lower')
    if title:
        # Title the axes we drew on, not whatever pyplot's current axes is.
        ax1.set_title(r'TV-min Weightmap %s' % name)

    # Colour bar as tall as ax1, 1/aspect of that height wide, separated by
    # pad_fraction times its width.
    aspect = 20
    pad_fraction = 0.5

    divider = make_axes_locatable(ax1)
    width = axes_size.AxesY(ax1, aspect=1. / aspect)
    pad = axes_size.Fraction(pad_fraction, width)
    cax = divider.append_axes("right", size=width, pad=pad)
    plt.colorbar(pic, cax=cax)

    ax1.yaxis.set_ticks_position('left')
Ejemplo n.º 7
0
    def _make_axes_grid(self):
        """Split the main axes into a grid of widget cells plus the plot itself.

        Reads ``self.num_buttons``, ``self.num_sliders``, ``self.if_colorbar``
        and ``self.fig``. Sets ``self.axes``, ``self.divider`` and
        ``self.button_ny``. When ``self.if_colorbar`` is true it also creates
        ``self.cax``. The main plot is always placed in the last vertical cell
        (``ny=len(vert) - 1``).
        """
        self.axes = self._get_main_axes()

        # Split up the current axes so there is space for start & stop buttons
        self.divider = make_axes_locatable(self.axes)
        pad = 0.01  # Padding between axes
        # NOTE(review): pad_size is a fraction of the axes WIDTH but is also
        # used as a vertical spacer in `vert` below - confirm this is intended.
        pad_size = Size.Fraction(pad, Size.AxesX(self.axes))
        large_pad_size = Size.Fraction(0.1, Size.AxesY(self.axes))

        # Define size of useful axes cells, 50% each in x 20% for buttons in y.
        # NOTE(review): the comment above looks stale - ysize is (1 - 2*pad)/15
        # of the axes height, and small_x ((1 - 2*pad)/10 of the width) is not
        # used anywhere in this method.
        small_x = Size.Fraction((1. - 2. * pad) / 10, Size.AxesX(self.axes))
        ysize = Size.Fraction((1. - 2. * pad) / 15., Size.AxesY(self.axes))

        # At least 7 horizontal button cells, more if there are more buttons.
        button_grid = max((7, self.num_buttons))
        # Set up grid, 3x3 with cells for padding.
        # NOTE(review): the grid is not 3x3; its size depends on button_grid
        # and self.num_sliders.
        if self.num_buttons > 0:
            # button_grid equal-width cells, separated by pad_size gaps.
            xsize = Size.Fraction((1. - 2. * pad) / button_grid,
                                  Size.AxesX(self.axes))
            horiz = [xsize] + [pad_size, xsize] * (button_grid - 1)
            # One (ysize, pad_size) pair per slider, two large pads, then the
            # full-height main axes as the last vertical cell.
            vert = [ysize, pad_size] * self.num_sliders + \
                   [large_pad_size, large_pad_size, Size.AxesY(self.axes)]
        else:
            # One (ysize, large_pad_size) pair per slider, one large pad, then
            # the main axes as the last vertical cell.
            vert = [ysize, large_pad_size] * self.num_sliders + \
                   [large_pad_size, Size.AxesY(self.axes)]
            # Five fixed column widths: 10%, 5%, 65%, 10%, 10% of the axes width.
            horiz = [Size.Fraction(0.1, Size.AxesX(self.axes))] + \
                    [Size.Fraction(0.05, Size.AxesX(self.axes))] + \
                    [Size.Fraction(0.65, Size.AxesX(self.axes))] + \
                    [Size.Fraction(0.1, Size.AxesX(self.axes))] + \
                    [Size.Fraction(0.1, Size.AxesX(self.axes))]

        self.divider.set_horizontal(horiz)
        self.divider.set_vertical(vert)
        # Vertical cell index three below the end of `vert`, stored for later
        # widget placement (presumably the button row - confirm with callers).
        self.button_ny = len(vert) - 3

        # If we are going to add a colorbar it'll need an axis next to the plot
        if self.if_colorbar:
            # Main axes stop at column boundary -3. The colour bar uses nx=-2
            # to nx1=-1, presumably the last column, leaving the one before it
            # as a gap - verify against the AxesDivider locator semantics.
            nx1 = -3
            # The initial rectangle is overridden by the locator. The odd width
            # presumably keeps add_axes from returning an existing axes, as
            # older Matplotlib did for identical arguments.
            self.cax = self.fig.add_axes((0., 0., 0.141, 1.))
            locator = self.divider.new_locator(nx=-2, ny=len(vert) - 1, nx1=-1)
            self.cax.set_axes_locator(locator)
        else:
            # Main figure spans all horiz and is in the top (2) in vert.
            nx1 = -1

        self.axes.set_axes_locator(
            self.divider.new_locator(nx=0, ny=len(vert) - 1, nx1=nx1))
Ejemplo n.º 8
0
def _plot_correlation_panel(ax, fig, matrix, xlabel, ylabel, title,
                            aspect=20, pad_fraction=0.5):
    """Draw one correlation matrix on *ax* with a labelled colour bar on its right."""
    from mpl_toolkits.axes_grid1 import make_axes_locatable, axes_size

    # Fixed [-1, 1] colour limits, so every panel drawn this way shares one scale.
    im = ax.imshow(matrix,
                   cmap='BrBG',
                   interpolation='nearest',
                   clim=[-1, 1])

    # Colour bar as tall as *ax*, 1/aspect of that height wide, separated by
    # pad_fraction times its width.
    divider = make_axes_locatable(ax)
    width = axes_size.AxesY(ax, aspect=1. / aspect)
    pad = axes_size.Fraction(pad_fraction, width)
    cax = divider.append_axes("right", size=width, pad=pad)

    # `ax=` is ignored by Figure.colorbar when `cax` is given, so it is omitted.
    colorbar = fig.colorbar(im, cax=cax)
    colorbar.set_label('zero norm. cross-correlation')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])


def plot_correlations(ax1, ax2, fig, traces_cc, footprints_cc):
    """Show trace and footprint cross-correlation matrices side by side.

    Parameters
    ----------
    ax1, ax2 : matplotlib.axes.Axes
        Axes for the trace and footprint matrices, respectively.
    fig : matplotlib.figure.Figure
        Figure owning both axes; ``tight_layout`` is applied to it at the end.
    traces_cc, footprints_cc : array-like
        Matrices with extracted items along x and ground truth along y
        (per the axis labels); values are coloured on the fixed range [-1, 1].
    """
    _plot_correlation_panel(ax1, fig, traces_cc,
                            'extracted traces', 'ground truth traces', 'traces')
    _plot_correlation_panel(ax2, fig, footprints_cc,
                            'extracted footprints', 'ground truth footprints',
                            'footprints')
    fig.tight_layout()
Ejemplo n.º 9
0
def plot_corrcoef(correlation_coefficient_matrix,
                  axes,
                  correlation_minimum=-1.,
                  correlation_maximum=1.,
                  colormap='bwr',
                  color_bar_aspect=20,
                  color_bar_padding_fraction=.5):
    """
    Draw a correlation-coefficient matrix on ``axes`` and attach a colour bar.

    Intended for the matrix returned by
    :py:func:`elephant.spike_train_correlation.corrcoef`.

    Parameters
    ----------
    correlation_coefficient_matrix : np.ndarray
        Pearson's correlation coefficient matrix to display.
    axes : object
        Matplotlib Axes to draw into.
    correlation_minimum : float
        Value mapped to the low end of the colour map. Default: -1
    correlation_maximum : float
        Value mapped to the high end of the colour map. Default: 1
    colormap : str
        Name of the colour map. Default: 'bwr'
    color_bar_aspect : float
        Ratio of the colour bar's height (equal to that of ``axes``) to its
        width. Default: 20
    color_bar_padding_fraction : float
        Gap between the matrix and the colour bar, as a multiple of the colour
        bar width. Default: .5

    Examples
    --------
    Create correlation coefficient matrix from Elephant `corrcoef` example
    and save the result to `corrcoef_matrix`.

    >>> import seaborn
    >>> seaborn.set_style('ticks')
    >>> fig, ax = plt.subplots(1, 1, subplot_kw={'aspect': 'equal'})
    ...
    >>> plot_corrcoef(correlation_coefficient_matrix, axes=ax)

    """
    color_limits = dict(vmin=correlation_minimum, vmax=correlation_maximum)
    image = axes.imshow(correlation_coefficient_matrix, cmap=colormap,
                        **color_limits)

    # Carve a slot for the colour bar off the right-hand side of ``axes``.
    bar_width = axes_size.AxesY(axes, aspect=1. / color_bar_aspect)
    bar_pad = axes_size.Fraction(color_bar_padding_fraction, bar_width)
    colorbar_axes = make_axes_locatable(axes).append_axes(
        "right", size=bar_width, pad=bar_pad)

    plt.colorbar(image, cax=colorbar_axes)
Ejemplo n.º 10
0
def cbar(ax, aspect=20, pad_fraction=0.5, **cbar_kwargs):
    """Append a colour bar axes to the right of *ax* and draw a ColorbarBase in it.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes next to which the colour bar is placed.
    aspect : float
        Height-to-width ratio of the bar; its height equals that of *ax*.
    pad_fraction : float
        Gap between *ax* and the bar, as a multiple of the bar width.
    **cbar_kwargs
        Forwarded unchanged to ``matplotlib.colorbar.ColorbarBase``
        (e.g. ``cmap``, ``norm``).

    Returns
    -------
    matplotlib.colorbar.ColorbarBase
        The colour bar that was created.
    """
    from mpl_toolkits.axes_grid1 import make_axes_locatable, axes_size
    from matplotlib.colorbar import ColorbarBase

    divider = make_axes_locatable(ax)
    width = axes_size.AxesY(ax, aspect=1. / aspect)
    pad = axes_size.Fraction(pad_fraction, width)
    cax = divider.append_axes("right", size=width, pad=pad)
    return ColorbarBase(cax, **cbar_kwargs)
Ejemplo n.º 11
0
def plot_single_gwas_scatter(cur_ds, df_all_hg_pval, df_all_hg_size):
    """Plot a module-by-GO-term bubble chart and save it as a PNG.

    Every non-zero cell of *df_all_hg_pval* becomes one point at its
    (column position, row position). The point is coloured by that value and
    sized by the cell at the same position of *df_all_hg_size*, and is
    annotated with that size value. The colour range is clipped to the
    10th-90th percentile of the plotted values.

    The figure is written to
    ``os.path.join(constants.OUTPUT_GLOBAL_DIR, "go_to_modules_<cur_ds>.png")``
    and then closed.

    Parameters
    ----------
    cur_ds : object
        Dataset identifier; used only in the output file name.
    df_all_hg_pval : pandas.DataFrame
        GO terms as rows and modules as columns (per the axis labels); zero
        cells are not plotted.
    df_all_hg_size : pandas.DataFrame
        Indexed positionally like *df_all_hg_pval* (presumably the same
        shape - TODO confirm).

    Raises
    ------
    ValueError
        If *df_all_hg_pval* has no non-zero cell.
    """
    pvals = df_all_hg_pval.values
    sizes = df_all_hg_size.values

    # Modules (columns) vary slowest and GO terms (rows) fastest, exactly like
    # the former nested loops, so point and annotation order are unchanged.
    xx, yy = np.nonzero(pvals.T != 0)
    if xx.size == 0:
        raise ValueError("df_all_hg_pval has no non-zero entries to plot")
    c = pvals[yy, xx]
    s = np.array(sizes[yy, xx])

    fig, ax = plt.subplots(figsize=(15, 15))
    largest = float(max(s))
    # Avoid 0/0 when every size is zero; the markers are then just empty.
    marker_sizes = s / (largest if largest else 1.0) * 1000
    sc = ax.scatter(xx,
                    yy,
                    marker_sizes,
                    c=c,
                    cmap='bwr',
                    vmin=np.percentile(c, 10),
                    vmax=np.percentile(c, 90))
    for x_pos, y_pos, size in zip(xx, yy, s):
        ax.annotate("%.2f" % size, (x_pos, y_pos))
    # No legend: the scatter has no labelled artist, so ax.legend() only
    # emitted a "No artists with labels found" warning.
    ax.margins(0.03, 0.03)
    ax.set_xlabel("modules")
    fig.subplots_adjust(left=0.25, right=0.99, top=0.99, bottom=0.05)
    ax.set_xticks(np.arange(len(df_all_hg_pval.columns)))
    ax.set_xticklabels(tuple(df_all_hg_pval.columns.values),
                       rotation='vertical')
    ax.set_ylabel("GO terms")
    ax.set_yticks(np.arange(len(df_all_hg_pval.index)))
    ax.set_yticklabels(tuple(df_all_hg_pval.index.values))

    # Fixed 0.3-inch colour bar with a 0.4-inch gap. The aspect/pad_fraction
    # sizes that used to be computed here were never used.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size=0.3, pad=0.4)
    fig.colorbar(sc, cax=cax)
    fig.tight_layout()
    fig.savefig(
        os.path.join(constants.OUTPUT_GLOBAL_DIR,
                     "go_to_modules_{}.png".format(cur_ds)))
    # Close the figure. plt.cla() only cleared the current axes and kept every
    # figure alive, leaking memory when this is called repeatedly.
    plt.close(fig)
Ejemplo n.º 12
0
def _imshow_with_colorbar(fig, ax, data, label, aspect=20, pad_fraction=0.5):
    """Show *data* on *ax* titled *label*, with a colour bar appended on the right."""
    im = ax.imshow(data)
    ax.title.set_text(label)
    divider = make_axes_locatable(ax)
    width = axes_size.AxesY(ax, aspect=1. / aspect)
    pad = axes_size.Fraction(pad_fraction, width)
    cax = divider.append_axes("right", size=width, pad=pad)
    # use_gridspec is ignored by Figure.colorbar when cax is given.
    fig.colorbar(im, cax=cax)


def plot_grad_img(named_parameters, save_path):
    r"""
    Plot the absolute weights and gradients of every layer for debugging.

    For each parameter that requires grad, has a gradient, and whose name
    does not contain ``"bias"``, a two-panel figure (``|weight|`` and
    ``|grad|``, each with its own colour bar) is saved to
    ``save_path + name + '.pdf'``.

    Args:
        named_parameters (iterable): ``(name, tensor)`` pairs, e.g. the
            result of ``model.named_parameters()``.
        save_path (string): Prefix prepended to each parameter name to form
            the output path.
    """
    with sns.axes_style("whitegrid", {'axes.grid': False}):
        for n, p in named_parameters:
            if not p.requires_grad or "bias" in n:
                continue
            if p.grad is None:
                # No gradient yet (e.g. the parameter was unused in the last
                # backward pass); p.grad.abs() would raise AttributeError.
                continue

            fig, axs = plt.subplots(1, 2)
            _imshow_with_colorbar(fig, axs[0], p.abs().cpu().data.numpy(),
                                  format2latex('$' + n + '_weight$'))
            _imshow_with_colorbar(fig, axs[1], p.grad.abs().cpu().data.numpy(),
                                  format2latex('$' + n + '_grad$'))
            fig.set_tight_layout(True)
            fig.savefig(save_path + n + '.pdf')
            plt.close(fig)
Ejemplo n.º 13
0
def plot_anno_pred(anno_img, img, pred, save_path=None):
    """Draw confident predicted boxes on *img*, coloured by score, with a colour bar.

    The figure has two panels. The left one is left empty: the annotated image
    is not drawn in this version, so *anno_img* is unused. The right panel
    shows *img* with every box scoring at least 0.5.

    Parameters
    ----------
    anno_img : ndarray
        Annotated image (currently not displayed).
    img : ndarray
        Image indexed ``[row, col, channel]``; the channel axis is reversed
        before display (presumably BGR -> RGB - TODO confirm).
    pred : array-like
        2-D, one row per detection: columns 0-3 are the box, column 4 the
        score. Columns 0-1 are used as the rectangle corner and 2-3 as its
        width and height.
    save_path : str, optional
        When given, the figure is saved there (tight bbox, 200 dpi).
    """
    from mpl_toolkits.axes_grid1 import make_axes_locatable, axes_size

    min_det_score = 0.5  # minimum score for a box to be drawn
    pred = np.array(pred)
    # Columns 0-3 are the box and column 4 the score. The old pred[:, :, :3]
    # indexed a third axis that a 2-D pred lacks and kept only 3 coordinates,
    # while Rectangle needs 4. np.int is gone from NumPy 1.24; int is the same cast.
    boxes = pred[:, :4].astype(int)
    scores = pred[:, 4]
    valid_idx = scores >= min_det_score
    # Filter boxes and scores together so every box keeps its own score/colour.
    boxes = boxes[valid_idx]
    scores = scores[valid_idx]

    fig = plt.figure(figsize=(30, 15))
    fig.add_subplot(121)  # left panel: intentionally left empty (see docstring)
    fig.add_subplot(122)
    ax = plt.gca()
    plt.imshow(img[:, :, ::-1])

    cmap = plt.cm.spring
    # max() of an empty array raises; fall back to 1.0 when no box survives.
    max_score = scores.max() if scores.size else 1.0
    normal = plt.Normalize(min_det_score, max_score)
    # Colour each box through the same norm as the colour bar.
    colors = cmap(normal(scores))
    for box, c in zip(boxes, colors):
        rect = patches.Rectangle(box[:2],
                                 *box[2:],
                                 linewidth=2,
                                 edgecolor=c,
                                 facecolor='none')
        ax.add_patch(rect)

    # Colour bar as tall as the image axes, 1/aspect of that height wide,
    # separated by pad_fraction times its width.
    pad_fraction = 0.5
    aspect = 20
    divider = make_axes_locatable(ax)
    width = axes_size.AxesY(ax, aspect=1. / aspect)
    pad = axes_size.Fraction(pad_fraction, width)
    cax = divider.append_axes("right", size=width, pad=pad)
    cbar.ColorbarBase(cax, cmap=cmap, norm=normal)

    if save_path is not None:
        fig.savefig(save_path, bbox_inches='tight', pad_inches=0, dpi=200)
Ejemplo n.º 14
0
def plot_fluxmap(ax1, image, name, title=False):
    """Draw ``log10(image)`` on *ax1* with a colour bar on its left.

    Non-finite log values (e.g. from non-positive pixels) are shown in black
    through the colormap's "bad" colour. The upper colour limit is the largest
    finite log value.

    Parameters
    ----------
    ax1 : matplotlib.axes.Axes
        Axes to draw on; a colour bar axes is appended to its left.
    image : ndarray
        Flux image.
    name : str
        Used in the optional title.
    title : bool
        When true, title *ax1* with ``'<name> Flux Map'``.
    """
    import copy

    # Copy the registered colormap so set_bad does not alter the global 'hot'.
    cmap = copy.copy(mpl.cm.hot)
    cmap.set_bad('k', 1.)
    im = np.log10(image)
    pic = ax1.imshow(im, cmap=cmap, vmax=np.nanmax(im),
                     interpolation='nearest', origin='lower')
    if title:
        # Title the axes we drew on, not pyplot's current axes.
        ax1.set_title(r'%s Flux Map' % name)

    aspect = 20
    pad_fraction = 0.5

    divider = make_axes_locatable(ax1)
    width = axes_size.AxesY(ax1, aspect=1. / aspect)
    pad = axes_size.Fraction(pad_fraction, width)
    cax = divider.append_axes("left", size=width, pad=pad)
    plt.colorbar(pic, cax=cax)

    # The bar sits on the left, so move the image's y ticks to the right and
    # keep the bar's ticks on its outer (left) side.
    ax1.yaxis.set_ticks_position('right')
    cax.yaxis.set_ticks_position('left')
Ejemplo n.º 15
0
def add_controls(axes=None, slider=False):
    """Add Start/Stop/Step buttons (and optionally a frame slider) to an animation axes.

    The axes is split with ``make_axes_locatable``. The original plot keeps
    the top row, and three equally wide button axes sit in a row below it.
    When *slider* is true, a full-width slider row is added at the very bottom.

    Parameters
    ----------
    axes : matplotlib.axes.Axes, optional
        Axes to decorate; defaults to the current axes (``plt.gca()``).
    slider : bool
        Also add a ``'Frame'`` slider ranging from 0 to 10.

    Returns
    -------
    tuple
        ``(axes, bax1, bax2, bax3)``, or ``(axes, bax1, bax2, bax3, bax4)``
        when *slider* is true. Each widget is stored on its axes
        (``_button`` / ``_slider``) so it is not garbage collected.
    """
    # If no axes is specified, use the current axes.
    if axes is None:
        axes = plt.gca()
    fig = axes.get_figure()

    # Split up the current axes so there is room for the controls.
    divider = make_axes_locatable(axes)
    pad = 0.1  # padding between cells, as a fraction of the axes width
    pad_size = Size.Fraction(pad, Size.AxesX(axes))

    # Three button columns sharing the non-padding width; widget rows are
    # (1 - 2*pad)/15 of the axes height.
    xsize = Size.Fraction((1. - 2. * pad) / 3., Size.AxesX(axes))
    ysize = Size.Fraction((1. - 2. * pad) / 15., Size.AxesY(axes))

    # Columns: button | pad | button | pad | button.
    divider.set_horizontal([xsize, pad_size, xsize, pad_size, xsize])
    # Rows are listed bottom to top; the plot is always the last row.
    if slider:
        # slider | pad | buttons | pad | plot
        divider.set_vertical(
            [ysize, pad_size, ysize, pad_size,
             Size.AxesY(axes)])
        bny = 2
    else:
        # buttons | pad | plot
        divider.set_vertical([ysize, pad_size, Size.AxesY(axes)])
        bny = 0

    # The main plot spans every column of the top row.
    axes.set_axes_locator(
        divider.new_locator(0, len(divider.get_vertical()) - 1, nx1=-1))

    # One axes per button in columns 0, 2 and 4 of the button row. The
    # initial rectangles are replaced by the locators; they differ presumably
    # so add_axes always creates a new axes (older Matplotlib reused an
    # existing axes for identical arguments).
    bax1 = fig.add_axes((0., 0., 1., 1.))
    locator = divider.new_locator(nx=0, ny=bny)
    bax1.set_axes_locator(locator)
    bax2 = fig.add_axes((0., 0., 0.8, 1.))
    locator = divider.new_locator(nx=2, ny=bny)
    bax2.set_axes_locator(locator)
    bax3 = fig.add_axes((0., 0., 0.7, 1.))
    locator = divider.new_locator(nx=4, ny=bny)
    bax3.set_axes_locator(locator)

    start = widgets.Button(bax1, "Start")
    stop = widgets.Button(bax2, "Stop")
    step = widgets.Button(bax3, "Step")
    # Keep references on the axes to prevent garbage collection.
    bax1._button = start
    bax2._button = stop
    bax3._button = step

    if slider:
        # Full-width slider in the bottom row.
        bax4 = fig.add_axes((0., 0., 0.6, 1.))
        locator = divider.new_locator(nx=0, ny=0, nx1=-1)
        bax4.set_axes_locator(locator)
        sframe = widgets.Slider(bax4, 'Frame', 0, 10, valinit=0, valfmt='%i')
        bax4._slider = sframe

        return axes, bax1, bax2, bax3, bax4
    return axes, bax1, bax2, bax3
Ejemplo n.º 16
0
def userInput(ks_db, graphics, df, min_sub):
    """Compute per-kinase z-scores/p-values for each sample and optionally a heatmap.

    For every sample column, phosphosite fold changes are averaged per site,
    matched against ``ks_db``, and pooled per kinase. Each kinase then gets a
    z-score, ``(kinase mean - overall mean) * sqrt(substrate count) / std``,
    and a one-tailed normal p-value. The inline comments call this KSEA.

    Parameters
    ----------
    ks_db : iterable of sequences
        Kinase-substrate records. A row ``y`` matches a site when
        ``y[0]`` equals it. ``y[1]`` (kinase), ``y[0]`` (site), ``y[2]``
        (site sequence) and ``y[3]`` (source) are copied into the
        kinase-substrate table.
    graphics : str
        ``"yes"`` to render an SVG heatmap, ``"no"`` to skip it.
        NOTE(review): any other value leaves ``svg_fig`` unassigned and the
        final ``return`` raises UnboundLocalError.
    df : str
        JSON string readable with ``pd.read_json(..., orient="split")``.
        Column 0 holds phosphosite identifiers (several may be joined by
        ";"). Each further column holds one sample's fold changes (log2(FC)
        per the comments).
    min_sub : int
        Minimum number of matched substrates a kinase needs to appear in the
        heatmap.

    Returns
    -------
    tuple of (str, str, str)
        Kinase score table and kinase-substrate table, both serialised as
        JSON with ``orient='split'``. The third item is the heatmap SVG
        markup, or a message saying the heatmap was not generated.
    """

    import io
    from io import StringIO
    import pandas as pd
    import scipy.stats as st
    import numpy as np

    # Plotting libraries are imported only when a heatmap is requested.
    # NOTE(review): io, AxesY, Fraction and colorbar are imported but unused.
    if graphics == "yes":
        import matplotlib as mpl
        # Select the non-interactive Agg backend before pyplot is imported.
        mpl.use('Agg')
        import matplotlib.pyplot as plt
        from mpl_toolkits.axes_grid1 import axes_size
        from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
        from mpl_toolkits.axes_grid1.axes_size import AxesY, Fraction
        from mpl_toolkits.axes_grid1.colorbar import colorbar
        import seaborn as sns

    # Function used to convert kinase z-scores to corresponding p-values.
    # It is assumed z-scores are normally distributed.
    # One-tailed: lower tail for negative z, upper tail otherwise.
    def getpValue(z):
        if z < 0:
            p = st.norm.cdf(z)
            return p
        else:
            dist = st.norm.cdf(z)
            p = 1.0 - dist
            return p

    # Function to calculate the Z-score for each kinase.
    # z = (kinase mean - overall mean) * sqrt(substrate count) / overall std.
    # NOTE(review): sd == 0 (all fold changes equal) gives inf/nan, not an error.
    def zScore(mean_kin, mean_all, sub_num, sd):
        z = (mean_kin - mean_all) * sub_num**(1 / 2) / sd
        return z

    # Convert JSON-string dataset back into a DF.
    f = pd.read_json(df, orient="split")

    # User data is passed from the server and parsed as appropriate.
    user_file = f.values.tolist()
    header = f.columns.values.tolist()
    col_length = len(header)
    array = []
    # Keep only rows that contain no empty-string cell.
    for line in user_file:
        # NOTE(review): `line` is a list (from tolist()), so it never equals
        # "" and this branch is dead; the `"" not in line` test below is what
        # drops rows with an empty cell.
        if line == "":
            continue
        else:
            if "" not in line:
                array.append(line)

    # dic: site -> fold change(s) for the current column.
    # kinase_dic: kinase -> substrate fold changes for the current column.
    kinase_dic = {}
    dic = {}
    # ks_links: per matched (site, kinase) pair; the last element is swapped
    # for each column's fold change. ks_info: same rows, with one fold change
    # appended per column.
    ks_links = []
    ks_info = []
    pval_map = []
    heatmap_array = []
    zscore_info = []
    z_columns = ["Kinase", "Sub.Count"]
    ks_columns = ["Kinase", "Site", "Site.Seq(+/- 7AA)", "Source"]
    heatmap_col = ["Kinase"]
    # Columns 1 and onwards represent samples (e.g. cell lines).
    # For the current column (sample) a set of operations is performed.
    for col in range(1, col_length):
        # Reset the values in each dictionary for a new column.
        # Keys are kept (only values are emptied), so key order stays the same
        # across columns; later indexing by position relies on this.
        for key in dic:
            dic[key] = []
        for kin in kinase_dic:
            kinase_dic[kin] = []
        # Column names for relevant dataframes are created here dynamically.
        # curr_col is current sample/column name.
        curr_col = header[col]
        z_columns.append("mnlog2(FC)." + curr_col)
        z_columns.append("zSc." + curr_col)
        z_columns.append("pVal." + curr_col)
        ks_columns.append("log2(FC)." + curr_col)
        heatmap_col.append(curr_col)

        data = []
        all_log2 = []
        # Multiple phosphosites separated by a semicolon are split here.
        # This ensures each phosphosite substrate and the log2(FC) value starts with a new line.
        # This is run for each sample in turn.
        for n in range(0, len(array)):
            site = array[n][0].upper()
            fc = array[n][col]
            site = site.split(";")
            for s in site:
                if s == '':
                    continue
                else:
                    data.append([s, float(fc)])

        # Mapping of phosphosite substrate keys to their (often multiple) log2(FCs) is achieved here.
        for entry in data:
            site = entry[0]
            fc = entry[1]
            if site not in dic:
                dic[site] = [fc]
            else:
                dic[site].append(fc)

        # If the same phosphosite has been detected more than once, its mean log2(FC) is calculated.
        # Final dictionary contains unique phosphosites and individual log2(FC) values, averaged where appropriate.
        # NOTE(review): a key left empty by the reset above would raise
        # ZeroDivisionError here. This cannot happen while every column
        # re-adds the same sites, which all come from column 0 of the same rows.
        for key in dic:
            length = len(dic[key])
            mean_fc = sum(dic[key]) / length
            dic[key] = float(mean_fc)

        # For each sample, the mean log2(FC) and standard deviation of all phosphosites in the dataset are calculated here. These values will be used to obtain a z-score for each identified kinase later on.
        # np.std here is the population standard deviation (ddof=0).
        for key in dic:
            all_log2.append(dic[key])
        all_mean = sum(all_log2) / float(len(all_log2))
        all_std = np.std(all_log2)

        # Each phosphosite in the dictionary is scanned against the K-S db.
        # If a match is found, relevant information for that phosphosite is retained.
        # Scanning is only done for the first column.
        # NOTE(review): nested scan is O(len(dic) * len(ks_db)); indexing
        # ks_db by site first would make it linear.
        if col == 1:
            for x in dic:
                for y in ks_db:
                    if x == y[0]:
                        # ks_links will be used to assign the current sample's log2(FCs) to each kinase later on.
                        ks_links.append([y[1], y[0], y[2], y[3], dic[x]])
                        # ks_info will contain kinase-substrate relationship info for each sample.
                        ks_info.append([y[1], y[0], y[2], y[3], dic[x]])
        # Once the first column is passed, new log2(FCs) are removed and/or appended to the original arrays for each sample.
        elif col > 1:
            for s in ks_links:
                # NOTE(review): list.remove drops the FIRST element equal to
                # s[-1], not necessarily the last position; s.pop() would be
                # the positional form. Same result unless an earlier field
                # equals the fold change.
                s.remove(s[-1])
                s.append(dic[s[1]])
            for k in ks_info:
                k.append(dic[k[1]])

        # A dictionary containing unique kinases and substrate log2(FCs) is created.
        # If the same kinase was identified for multiple substrates, multiple log2(FCs) are appended to the dictionary values.
        for match in ks_links:
            kinase = match[0]
            log2fc = match[4]
            if kinase not in kinase_dic:
                kinase_dic[kinase] = [log2fc]
            else:
                kinase_dic[kinase].append(log2fc)

        # The dictionary is used to calculate the number of substrates identified for each kinase.
        # It also calculates the mean log2(FC) across each kinase's substrates.
        # The algorithm then computes the z-score.
        # A new array stores kinase gene, no. of substrates, mean log2(FC), z-score and p-value.
        # condition_ind / index are row positions into heatmap_array /
        # zscore_info. They line up across columns because ks_links is fixed
        # after column 1, so kinase order and substrate counts never change.
        condition_ind = -1
        index = -1
        for key in kinase_dic:
            index += 1
            substrate_num = len(kinase_dic[key])
            kin_fc_mean = sum(kinase_dic[key]) / float(substrate_num)
            z_score = zScore(kin_fc_mean, all_mean, substrate_num, all_std)
            z_pval = getpValue(z_score)
            # For the heatmap, only z-scores for the kinases with the minimum substrate count specified by the user are extracted.
            if substrate_num >= min_sub:
                condition_ind += 1
                if col == 1:
                    # An array of kinases and z-scores corresponding to multiple samples. Used for the heatmap.
                    heatmap_array.append([key, z_score])
                    # An array of p-values for each z-score across all samples. Used for heatmap annotation.
                    pval_map.append([z_pval])
                else:
                    heatmap_array[condition_ind].append(z_score)
                    pval_map[condition_ind].append(z_pval)
            if col == 1:
                # KSEA results stored here.
                zscore_info.append(
                    [key, substrate_num, kin_fc_mean, z_score, z_pval])
            # If the program has gone past the first condition column, kin_fc_mean, z_score and zpval are appended in a repeating manner to the original array.
            else:
                zscore_info[index].append(kin_fc_mean)
                zscore_info[index].append(z_score)
                zscore_info[index].append(z_pval)

    # p-values for the heatmap annotation are extracted from a nested list into a flat list.
    # Flattening is row-major (kinase by kinase, samples within each kinase).
    pvalues = []
    for entry in pval_map:
        for pval in entry:
            pvalues.append(pval)

    # Score and kinase-substrate relationships DFs are generated.
    zscore_df = pd.DataFrame(zscore_info, columns=z_columns)
    ks_df = pd.DataFrame(ks_info, columns=ks_columns)
    # Heatmap DFs for the heatmap generation.
    heatmap_df = pd.DataFrame(heatmap_array, columns=heatmap_col)
    heatmap_df = heatmap_df.set_index("Kinase")

    # Heatmap only generated if the user chose to produce graphics during file upload.
    if graphics == "no":
        svg_fig = "Heatmap was not generated for this analysis."
    elif graphics == "yes":
        # Set the margins and square height for a single category.
        topmargin = 0.1  #inches
        bottommargin = 0.1  #inches
        categorysize = 0.35  # inches
        # Number of kinases in the heatmap (those with at least min_sub substrates).
        n = len(heatmap_array)
        leftmargin = 0.1
        rightmargin = 0.1
        catsize = 0.5
        # Number of conditions (e.g. cell lines).
        m = len(heatmap_col) - 1

        # Parameters for color bar.
        # Colour bar width = axes height / n (via AxesY(..., aspect=1/aspect)).
        # NOTE(review): n == 0 (no kinase meets min_sub) makes 1. / aspect
        # raise ZeroDivisionError below.
        aspect = n
        pad_fraction = 0.7

        # Calculate a dynamic figure height.
        figheight = topmargin + bottommargin + (n + 1) * categorysize

        # Calculate a dynamic figure width.
        figwidth = leftmargin + rightmargin + (m + 1) * catsize

        fig, ax = plt.subplots(figsize=(figwidth, figheight))

        # Format the axes.
        ax.xaxis.set_ticks_position('top')
        plt.yticks(fontsize=6)
        plt.xticks(fontsize=6)

        # Plot the heatmap.
        # No ax= is passed, so seaborn draws on the current axes (the one just created).
        ax = sns.heatmap(heatmap_df,
                         cmap='coolwarm',
                         annot=True,
                         fmt=".1f",
                         annot_kws={'size': 5},
                         cbar=False,
                         linewidths=0.3,
                         linecolor='white')

        # Format the colour bar dynamically.
        ax_div = make_axes_locatable(ax)
        width = axes_size.AxesY(ax, aspect=1. / aspect)
        pad = axes_size.Fraction(pad_fraction, width)
        cax = ax_div.append_axes('right', size=width, pad=pad)
        # NOTE(review): assumes the first child of ax is the heatmap's colour
        # mesh - verify against the seaborn version in use.
        cb = plt.colorbar(ax.get_children()[0],
                          cax=cax,
                          orientation='vertical')
        cax.yaxis.set_ticks_position('right')
        cb.ax.tick_params(labelsize=6)
        cb.set_label('Z-Score', fontsize=6, labelpad=7)
        cb.outline.set_visible(False)

        #Remove y-axis label.
        ax.yaxis.set_label_text("")

        # Rotate the axis labels.
        for item in ax.get_yticklabels():
            item.set_rotation(0)

        for item in ax.get_xticklabels():
            item.set_rotation(90)

        # Annotate statistically significant scores with asterisks.
        # * for p < 0.05 and ** for p < 0.01.
        # NOTE(review): assumes ax.texts holds one annotation per cell in the
        # same row-major order as `pvalues` - TODO confirm for seaborn.
        counter = -1
        for text in ax.texts:
            counter += 1
            if pvalues[counter] < 0.05 and pvalues[counter] >= 0.01:
                text.set_weight('bold')
                text.set_text(text.get_text() + "*")
            elif pvalues[counter] < 0.01:
                text.set_weight('bold')
                text.set_text(text.get_text() + "**")

        # Create a StringIO object and use it to write SVG figure data to string buffer.
        fig_file = StringIO()
        fig.savefig(fig_file, format='svg', bbox_inches="tight")
        # Seek beginning of the figure file.
        fig_file.seek(0)
        # Retrieve figure contents as a string.
        # Drops everything before the first '<svg' (XML header/doctype).
        # NOTE(review): if '<svg' occurs again later, everything from that
        # second occurrence on is dropped too.
        svg_fig = '<svg' + fig_file.getvalue().split('<svg')[1]
        # Free memory buffer.
        fig_file.close()

    # Convert results DFs into JSON strings.
    zscore_df = zscore_df.to_json(orient='split')
    ks_df = ks_df.to_json(orient='split')

    return zscore_df, ks_df, svg_fig
Ejemplo n.º 17
0
    def _ColorBar(self):
        """Draw the colorbar (and optional title) for ``self.pa`` on the current axes.

        Configuration is read from ``self.options``: ``plot_type``,
        ``justvector``, ``decimalcolorbar``, ``scalarMin``, ``scalarMax``,
        ``LogMap``, ``colorbar``, ``colorbarSpacing``, ``dynamic_limits``,
        ``orientationcolorbar``, ``fontsize`` and ``legend``.

        Every failure is reported through ``logging.info`` with a numbered
        error code, after which logging is shut down and the process exits
        via ``sys.exit()``.

        Returns:
            ``self`` when a colorbar was requested (``plot_type != 0`` and
            ``justvector != 1``); ``None`` otherwise.
        """
        try:
            # ``!=`` rather than ``is not``: identity tests against int
            # literals are unreliable (and a SyntaxWarning on Python 3.8+).
            if self.options.plot_type != 0 and self.options.justvector != 1:
                try:
                    ax = plt.gca()
                    divider = make_axes_locatable(ax)
                    try:
                        if self.options.decimalcolorbar is not None:
                            decimal = '%.' + str(
                                self.options.decimalcolorbar) + 'f'
                        else:
                            decimal = None
                    except Exception as ex:
                        # str(ex): concatenating str + Exception raises TypeError.
                        logging.info(
                            ': Modulo Maps Functions : Error 012 : Failed to calculate decimacolorbar '
                            + str(ex))
                        logging.shutdown()
                        sys.exit()

                    # Stays None when no branch below creates a colorbar
                    # (e.g. an unrecognised ``colorbar`` or orientation value).
                    cba = None
                    try:
                        # Builtin float: ``np.float`` was removed in NumPy 1.24.
                        aux = float(self.options.scalarMax -
                                    self.options.scalarMin)
                        if self.options.LogMap == 1:
                            cba = plt.colorbar(self.pa)
                        elif self.options.colorbar in (1, 2):
                            if self.options.scalarMin >= self.options.scalarMax:
                                logging.info(
                                    ': Modulo Maps Functions : Error 013a : Scalar ´Min is bigger then Max. '
                                )
                                logging.shutdown()
                                sys.exit()

                            if self.options.dynamic_limits == 1:
                                # Fixed 11 ticks; colorbarSpacing is not used.
                                v2 = np.linspace(self.options.scalarMin,
                                                 self.options.scalarMax, 11)
                            else:
                                # linspace needs an integer sample count;
                                # int() mirrors how older NumPy truncated
                                # the float it used to accept here.
                                n_ticks = int(round(
                                    (aux / self.options.colorbarSpacing) + 1,
                                    4))
                                v2 = np.linspace(self.options.scalarMin,
                                                 self.options.scalarMax,
                                                 n_ticks)

                            if len(v2) > 200:
                                logging.info(
                                    ': Modulo Maps Functions : Error 013 : you requested more then 200 colorbar splits, please reduce'
                                )
                                logging.shutdown()
                                sys.exit()

                            orientation = self.options.orientationcolorbar
                            if orientation == 'horizontal':
                                cba = plt.colorbar(
                                    self.pa,
                                    ticks=v2,
                                    extend='max',
                                    spacing='proportional',
                                    orientation='horizontal',
                                    format=decimal)
                            elif orientation == 'vertical':
                                cax = None
                                if self.options.plot_type == 2:
                                    cax = divider.append_axes("right",
                                                              size="5%",
                                                              pad=0.05)
                                elif self.options.plot_type == 1:
                                    # Pad scales with the axes height.
                                    width = axes_size.AxesY(ax,
                                                            aspect=1 / 35)
                                    pad = axes_size.Fraction(0.55, width)
                                    cax = divider.append_axes("right",
                                                              size=0.085,
                                                              pad=pad)
                                if cax is not None:
                                    cba = plt.colorbar(
                                        self.pa,
                                        cax=cax,
                                        ticks=v2,
                                        extend='max',
                                        spacing='proportional',
                                        orientation='vertical',
                                        format=decimal)
                    except Exception as ex:
                        logging.info(
                            ': Modulo Maps Functions : Error 013 : Failed to calculate width, pad and cba from colorbar. '
                            + str(ex))
                        logging.shutdown()
                        sys.exit()

                    try:
                        mpl.rcParams.update(
                            {'font.size': self.options.fontsize})
                        if (self.options.legend is not None
                                and self.options.colorbar != 0
                                and cba is not None):
                            if self.options.plot_type == 2:
                                cba.set_label(self.options.legend,
                                              fontsize=self.options.fontsize,
                                              rotation=270,
                                              labelpad=7)
                                # Right-align tick labels and shift them
                                # outward (x in colorbar-axes coordinates).
                                for t in cba.ax.get_yticklabels():
                                    t.set_horizontalalignment('right')
                                    t.set_x(2.2)
                            else:
                                cba.set_label(self.options.legend,
                                              fontsize=self.options.fontsize)
                    except Exception as ex:
                        logging.info(
                            ': Modulo Maps Functions : Error 014 : Failed to draw colorbar title. '
                            + str(ex))
                        logging.shutdown()
                        sys.exit()
                except Exception as ex:
                    logging.info(
                        ': Modulo Maps Functions : Error 015 : Failed inside colorbar manipulation '
                        + str(ex))
                    logging.shutdown()
                    sys.exit()
                return self
        except Exception as ex:
            logging.info(
                ': Modulo Maps Functions : Error 016 : Failed inside colorbar manipulation '
                + str(ex))
            logging.shutdown()
            sys.exit()
Ejemplo n.º 18
0
def _print_inclusion_summary(name, included):
    """Print how many entries of the boolean mask *included* are kept and dropped."""
    total = len(included)
    n_included = int(np.sum(included))
    print('%s proportion included: ' % name)
    print(n_included / total if total else float('nan'))
    print('total: %i' % total)
    print('included: %i' % n_included)
    print('excluded: %i' % (total - n_included))


def _as_proportions(counts):
    """Return histogram *counts* normalised to sum to 1 (all-zero input stays zero)."""
    total = np.sum(counts)
    return counts / total if total else np.zeros(len(counts))


def plot_quant_aud_figure(weights,
                          sparse_weights,
                          rstrfs,
                          Y_mds,
                          t_pow_real,
                          t_pow_model,
                          t_pow_sparse,
                          save_path=None,
                          use_blobs=False):
    """Build the multi-panel figure comparing real (A1), model and sparse STRFs.

    Layout on a 200x100 GridSpec:
      * top-left: MDS projection of ``Y_mds`` with marginal histograms;
      * top-right: the three temporal power profiles;
      * two rows of excitatory-vs-inhibitory span scatter plots (temporal,
        then frequency) for real / model / sparse units, each followed by a
        row of normalised histograms of those spans.

    Units whose inhibitory bandwidth (``*_bw_neg``) is not positive are
    excluded from the span panels; a summary of the exclusion is printed.

    Args:
        weights: model STRFs, passed to ``quant.quantify_strfs``.
        sparse_weights: sparse-model STRFs, passed to ``quant.quantify_strfs``.
        rstrfs: real STRFs, passed to ``quant.quantify_strfs`` with ``n_h=38``.
        Y_mds: points forwarded to
            ``plotting_funcs.plot_2D_projection_with_hists``.
        t_pow_real, t_pow_model, t_pow_sparse: temporal power profiles,
            plotted as lines.
        save_path: if given, the figure is saved as ``save_path + '.svg'``.
        use_blobs: draw the span scatters with ``plotting_funcs.blobscatter``
            instead of ``Axes.scatter``.
    """
    fig = plt.figure(figsize=[10, 20])
    gs = plt.GridSpec(200, 100)
    pop_ax = np.empty((2, 3), dtype=object)
    hist_ax = np.empty((2, 2), dtype=object)

    mds_ax = fig.add_subplot(gs[:45, :45])
    temp_ax = fig.add_subplot(gs[:45, 64:])

    # Row 0: temporal spans; row 1: frequency spans.
    pop_ax[0, 0] = fig.add_subplot(gs[55:85, :30])
    pop_ax[0, 1] = fig.add_subplot(gs[55:85, 35:65])
    pop_ax[0, 2] = fig.add_subplot(gs[55:85, 70:])

    hist_ax[0, 0] = fig.add_subplot(gs[95:110, :45])
    hist_ax[0, 1] = fig.add_subplot(gs[95:110, 55:])

    pop_ax[1, 0] = fig.add_subplot(gs[120:150, :30])
    pop_ax[1, 1] = fig.add_subplot(gs[120:150, 35:65])
    pop_ax[1, 2] = fig.add_subplot(gs[120:150, 70:])

    hist_ax[1, 0] = fig.add_subplot(gs[160:175, :45])
    hist_ax[1, 1] = fig.add_subplot(gs[160:175, 55:])

    azure = '#006FFF'
    clors = ['red', 'black', azure, 'y']  # real, model, sparse, noise
    markers = ['8', 'o', '^']

    # MDS panel
    mds_axlim = 20
    mds_axes = plotting_funcs.plot_2D_projection_with_hists(Y_mds,
                                                            mds_ax,
                                                            labels=None,
                                                            clors=clors,
                                                            markers=markers)
    mds_axes[0].set_xlim([-mds_axlim, mds_axlim])
    mds_axes[0].set_ylim([-mds_axlim, mds_axlim])
    for ax in mds_axes:
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_axis_off()

    # Temporal power panel. An invisible axes is appended on top with the
    # same size and pad that plot_2D_projection_with_hists uses for its top
    # histogram, presumably so the two top panels line up.
    divider = make_axes_locatable(temp_ax)
    width = axes_size.AxesY(temp_ax, aspect=1 / 4)
    dummy_ax_x = divider.append_axes("top", size=width, pad=0.1,
                                     sharex=temp_ax)
    dummy_ax_x.set_axis_off()

    for t_pow, clor, label in zip((t_pow_real, t_pow_model, t_pow_sparse),
                                  clors, ('A1', 'prediction', 'sparse')):
        temp_ax.plot(t_pow, color=clor, label=label, linewidth=3)

    handles, labels = temp_ax.get_legend_handles_labels()
    temp_ax.legend(handles, labels)
    temp_ax.set_xticklabels(['-250', '-200', '-150', '-100', '-50', '0'])
    temp_ax.set_ylabel('Proportion of power')
    temp_ax.set_xlabel('Time (msec)')

    # quantify_strfs returns (f_peak_pos, f_peak_neg, f_bw_pos, f_bw_neg,
    # t_peak_pos, t_peak_neg, t_bw_pos, t_bw_neg, pow); only the bandwidths
    # are used here.
    (_, _, mf_bw_pos, mf_bw_neg, _, _, mt_bw_pos, mt_bw_neg,
     _) = quant.quantify_strfs(weights)
    (_, _, sf_bw_pos, sf_bw_neg, _, _, st_bw_pos, st_bw_neg,
     _) = quant.quantify_strfs(sparse_weights)
    (_, _, rf_bw_pos, rf_bw_neg, _, _, rt_bw_pos, rt_bw_neg,
     _) = quant.quantify_strfs(rstrfs, n_h=38)

    # Keep only units with a positive inhibitory bandwidth. These must be
    # plain boolean arrays: a list-wrapped mask such as ``[arr > 0]`` is read
    # by NumPy >= 1.23 as a 2-D index and raises IndexError.
    mf_ix = mf_bw_neg > 0
    rf_ix = rf_bw_neg > 0
    mt_ix = mt_bw_neg > 0
    rt_ix = rt_bw_neg > 0
    sf_ix = sf_bw_neg > 0
    st_ix = st_bw_neg > 0

    print('Excluding units with 0 bw')
    _print_inclusion_summary('model', mt_ix)
    _print_inclusion_summary('sparse', st_ix)
    _print_inclusion_summary('real', rt_ix)

    # Row 0: temporal, row 1: frequency; columns: real, model, sparse.
    # Nested lists rather than object ndarrays: assigning equal-length arrays
    # into an object-array slice makes NumPy try to broadcast them and fail.
    xs = [[rt_bw_pos[rt_ix], mt_bw_pos[mt_ix], st_bw_pos[st_ix]],
          [rf_bw_pos[rf_ix], mf_bw_pos[mf_ix], sf_bw_pos[sf_ix]]]
    ys = [[rt_bw_neg[rt_ix], mt_bw_neg[mt_ix], st_bw_neg[st_ix]],
          [rf_bw_neg[rf_ix], mf_bw_neg[mf_ix], sf_bw_neg[sf_ix]]]

    # Population scatter plots with the identity line.
    lims = [225, 6]
    diag_kwargs = {'color': 'k', 'linestyle': '--', 'alpha': 0.8}
    for ii in range(len(xs)):
        for jj, ax in enumerate(pop_ax[ii, :]):
            x = xs[ii][jj]
            y = ys[ii][jj]
            clor = clors[jj]
            if use_blobs:
                pop_ax[ii, jj], _ = plotting_funcs.blobscatter(
                    x, y, ax=ax, facecolor='none', edgecolor=clor)
            else:
                ax.scatter(x, y)
            pop_ax[ii, jj].plot(range(lims[ii] + 1), **diag_kwargs)
            pop_ax[ii, jj].set_xlim([0, lims[ii]])
            pop_ax[ii, jj].set_ylim([0, lims[ii]])
            if ii == 0:
                pop_ax[ii, jj].set_yticks([0, 50, 100, 150, 200])
                pop_ax[ii, jj].set_xticks([0, 50, 100, 150, 200])
                print('x or y>100', sum(np.logical_or(x > 100, y > 100)),
                      'out of ', len(x))
            if ii == 1:
                pop_ax[ii, jj].set_yticks([0, 2, 4, 6])
                print('x or y>4', sum(np.logical_or(x > 4, y > 4)), 'out of ',
                      len(x))
            if jj > 0:
                pop_ax[ii, jj].set_yticklabels([])

    # Zoomed insets on the real and model scatter plots.
    inset_lims = [50, 1.5]
    inset_ticks = [[25, 50], [0.75, 1.5]]
    for ii in range(len(xs)):
        for jj, ax in enumerate(pop_ax[ii, :2]):
            x = xs[ii][jj]
            y = ys[ii][jj]
            clor = clors[jj]
            inset_axes_kwargs = {
                'xlim': [0, inset_lims[ii]],
                'ylim': [0, inset_lims[ii]],
                'xticks': inset_ticks[ii],
                'yticks': inset_ticks[ii]
            }
            inset_ax = inset_axes(
                ax,
                width="40%",  # 40% of the parent axes' width
                height="40%",  # 40% of the parent axes' height
                loc=1,  # upper right
                axes_kwargs=inset_axes_kwargs)
            if use_blobs:
                inset_ax, _ = plotting_funcs.blobscatter(
                    x, y, ax=inset_ax, facecolor='none', edgecolor=clor)
            else:
                inset_ax.scatter(x, y)

    # Normalised histograms of excitatory (col 0) and inhibitory (col 1) spans.
    binss = [np.arange(-2.5, 200, 5), np.arange(0, 6, 0.15)]
    for ii in range(len(xs)):
        bins = binss[ii]
        bincenters = 0.5 * (bins[1:] + bins[:-1])
        for jj in range(len(xs[ii])):
            plt_kwargs = {'color': clors[jj], 'linestyle': '-', 'alpha': 0.8}
            xx, _ = np.histogram(xs[ii][jj], bins=bins)
            yy, _ = np.histogram(ys[ii][jj], bins=bins)
            hist_ax[ii, 0].plot(bincenters, _as_proportions(xx), **plt_kwargs)
            hist_ax[ii, 1].plot(bincenters, _as_proportions(yy), **plt_kwargs)

    hist_ax[0, 0].set_yticks([0, 0.2, 0.4, 0.6, 0.8])
    hist_ax[0, 1].set_yticks([0, 0.2, 0.4, 0.6, 0.8])

    hist_ax[1, 0].set_yticks([0, 0.1, 0.2, 0.3, 0.4])
    hist_ax[1, 1].set_yticks([0, 0.1, 0.2, 0.3, 0.4])

    for ax in fig.get_axes():
        # Hide the right and top spines; ticks only on left and bottom.
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.yaxis.set_ticks_position('left')
        ax.xaxis.set_ticks_position('bottom')

    pop_ax[0, 0].set_ylabel('Temporal span of inhibition (msec)')
    pop_ax[0, 0].set_xlabel('Temporal span of excitation (msec)')
    pop_ax[0, 1].set_xlabel('Temporal span of excitation (msec)')
    pop_ax[0, 2].set_xlabel('Temporal span of excitation (msec)')

    hist_ax[0, 0].set_xlabel('Temporal span of excitation (msec)')
    hist_ax[0, 1].set_xlabel('Temporal span of inhibition (msec)')
    hist_ax[0, 0].set_ylabel('Proportion of units')

    pop_ax[1, 0].set_ylabel('Frequency span of inhibition (octave)')
    pop_ax[1, 0].set_xlabel('Frequency span of excitation (octave)')
    pop_ax[1, 1].set_xlabel('Frequency span of excitation (octave)')
    pop_ax[1, 2].set_xlabel('Frequency span of excitation (octave)')

    hist_ax[1, 0].set_xlabel('Frequency span of excitation (octave)')
    hist_ax[1, 1].set_xlabel('Frequency span of inhibition (octave)')
    hist_ax[1, 0].set_ylabel('Proportion of units')

    hist_ax[0, 0].set_xticks([0, 50, 100, 150, 200])
    hist_ax[0, 1].set_xticks([0, 50, 100, 150, 200])

    if save_path is not None:
        fig.savefig(save_path + '.svg')
            Green_trip_total_riverflux  + gice_total_riverflux
    #mld_avg=ahy_afil_total_riverflux*Maskout_only_Gr 
    mld_avg = np.ma.masked_where(mld_avg<=0.0,mld_avg)
    mx=mld_avg.max()
    mn=mld_avg.min()
    clim=[mn, mx]
    clim=[0.0, 2.0]
    #clim=[0.2, 2.0]
    P=bm.pcolormesh(xi,yi,mld_avg,cmap=cmap)
    bm.drawcoastlines(linewidth=0.5)
    #bm.fillcontinents(color='.8',lake_color='white')
    #
    aspect = 20
    pad_fraction = 0.5
    divider = make_axes_locatable(ax)
    width = axes_size.AxesY(ax, aspect=1./aspect)
    pad = axes_size.Fraction(pad_fraction, width)
    cax = divider.append_axes("right", size=width, pad=pad)
    cb=ax.figure.colorbar(P,cax=cax,extend='both')
    P.set_clim(clim)
    ax.set_title('Total AEHYPE clim + Greenland(icemelt) + Greenland(TRIP) river flux')
    fig.canvas.print_figure("Total_AEhype_GreenlandiceTrip_%03d.png"%(record+9))
    ax.clear()
    cb.remove()
 
 xticks = np.arange(0,12)
 xticklabels = ('Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec') 
 #
 tripafil=1
 if tripafil :
    fig = plt.figure(figsize=(8,8))
Ejemplo n.º 20
0
def plot_2D_projection_with_hists(Y,
                                  axScatter,
                                  n_examples=100,
                                  labels=None,
                                  clors=None,
                                  markers=None):
    """Scatter consecutive groups of 2-D points with marginal histograms.

    ``Y`` is split into ``Y.shape[0] // n_examples`` consecutive groups of
    ``n_examples`` rows; rows after the last full group are ignored. Each
    group's column 0 vs column 1 is scattered on ``axScatter``. Its density
    histograms over ``[-20, 20]`` (bin width 1) are drawn on two axes
    appended above (x marginal) and to the right (y marginal).

    Args:
        Y: 2-D array of points; only columns 0 and 1 are used.
        axScatter: axes that receives the scatter plot.
        n_examples: number of rows per group.
        labels: optional per-group legend labels; a legend is drawn only
            when given.
        clors: optional per-group colours; a random RGB colour otherwise.
        markers: optional per-group markers; ``'o'`` otherwise. ``'o'`` and
            ``'^'`` are drawn hollow, other markers filled.

    Returns:
        ``[axScatter, axHistx, axHisty]``.
    """
    # Create new axes on the top and on the right of the scatter axes,
    # each a quarter of the scatter axes' height.
    divider = make_axes_locatable(axScatter)
    aspect = 4
    width = axes_size.AxesY(axScatter, aspect=1 / aspect)
    axHistx = divider.append_axes("top", size=width, pad=0.1, sharex=axScatter)
    axHisty = divider.append_axes("right",
                                  size=width,
                                  pad=0.1,
                                  sharey=axScatter)

    lim = 20
    binwidth = 1
    bins = np.arange(-lim, lim + binwidth, binwidth)
    bincenters = 0.5 * (bins[1:] + bins[:-1])
    linewidth = 3
    alpha = 1

    n_plots = Y.shape[0] // n_examples
    for ii in range(n_plots):
        group = Y[ii * n_examples:(ii + 1) * n_examples]
        label = labels[ii] if labels is not None else ''
        # One random RGB triple; the previous rand(3, 1) column vector is
        # rejected by Matplotlib as a colour.
        clor = clors[ii] if clors is not None else np.random.rand(3)
        marker = markers[ii] if markers is not None else 'o'
        facecolors = 'none' if marker in ('o', '^') else clor

        axScatter.scatter(group[:, 0],
                          group[:, 1],
                          edgecolor=clor,
                          facecolors=facecolors,
                          marker=marker,
                          alpha=alpha,
                          label=label)

        xx, _ = np.histogram(group[:, 0], bins=bins, density=True)
        yy, _ = np.histogram(group[:, 1], bins=bins, density=True)

        plt_kwargs = {'color': clor, 'linewidth': linewidth, 'alpha': alpha}
        axHistx.plot(bincenters, xx, **plt_kwargs)
        axHisty.plot(yy, bincenters, **plt_kwargs)

    if labels is not None:
        handles, legend_labels = axScatter.get_legend_handles_labels()
        axScatter.legend(handles, legend_labels)
    return [axScatter, axHistx, axHisty]
Ejemplo n.º 21
0
def plot_scatter_hist(x,
                      y,
                      axScatter,
                      bins,
                      hist_axes=None,
                      label=None,
                      clor=None,
                      marker='o',
                      alpha=1,
                      n_examples=None):
    """Jittered scatter of (x, y) plus density histograms of each coordinate.

    Args:
        x, y: NumPy arrays of equal length.
        axScatter: axes that receives the scatter plot.
        bins: array of histogram bin edges. Half of its first bin width is
            the maximum jitter added to each scattered point.
        hist_axes: optional pair ``(axHistx, axHisty)``. When omitted, both
            are appended to ``axScatter`` (top and right), each a quarter of
            its height.
        label, clor, marker, alpha: scatter/line styling. Marker ``'o'`` is
            drawn hollow; any other marker is filled with ``clor``.
        n_examples: if given, only a random subset of this many points is
            scattered. The histograms always use all points.

    Returns:
        ``[axScatter, axHistx, axHisty]``.
    """
    if hist_axes is not None:
        axHistx, axHisty = hist_axes[0], hist_axes[1]
    else:
        divider = make_axes_locatable(axScatter)
        marginal_size = axes_size.AxesY(axScatter, aspect=1 / 4)
        axHistx = divider.append_axes("top",
                                      size=marginal_size,
                                      pad=0.1,
                                      sharex=axScatter)
        axHisty = divider.append_axes("right",
                                      size=marginal_size,
                                      pad=0.1,
                                      sharey=axScatter)

    chosen = np.arange(len(x))
    if n_examples is not None:
        random.shuffle(chosen)
        chosen = chosen[:n_examples]

    # Uniform jitter in [-half_width, half_width) so overlapping points on
    # the bin grid become visible.
    half_width = 0.5 * (bins[1] - bins[0])
    n_shown = len(chosen)
    dx = half_width * (2 * np.random.rand(n_shown) - 1)
    dy = half_width * (2 * np.random.rand(n_shown) - 1)

    face = 'none' if marker == 'o' else clor
    axScatter.scatter(x[chosen] + dx,
                      y[chosen] + dy,
                      edgecolor=clor,
                      facecolors=face,
                      marker=marker,
                      alpha=alpha,
                      label=label)

    centers = 0.5 * (bins[1:] + bins[:-1])
    x_density, _ = np.histogram(x, bins=bins, density=1)
    y_density, _ = np.histogram(y, bins=bins, density=1)
    line_style = {'color': clor, 'linewidth': 3, 'alpha': alpha}
    axHistx.plot(centers, x_density, **line_style)
    axHisty.plot(y_density, centers, **line_style)

    return [axScatter, axHistx, axHisty]
Ejemplo n.º 22
0
def plotResolutionImage(file_name):
    """
    Plots an image of the array with the energy resolution as a color for the
    solution file (file_name).

    The displayed stack has one layer per calibration wavelength, plus a final
    layer that is 1 for pixels with a good fit (wave_flag 4 or 5) and 0
    otherwise. Previous/Next buttons and a slider step through the layers.
    The call blocks until the window is closed.

    Args:
        file_name: the wavecal solution file including the path (string)
    """
    # The context manager closes the file even if one of the reads fails.
    with tb.open_file(file_name, mode='r') as wave_cal:
        beamImage = wave_cal.root.header.beamMap.read()
        wavelengths = wave_cal.root.header.wavelengths.read()[0]
        calsoln = wave_cal.root.wavecal.calsoln.read()
    wave_flag = calsoln["wave_flag"]
    R0 = calsoln["R"]
    R0[R0 == -1] = 0  # -1 entries are displayed as 0
    rows = calsoln['pixel_row']
    columns = calsoln['pixel_col']

    n_waves = len(wavelengths)
    R = np.zeros((n_waves + 1, *beamImage.shape))
    for pixel_ind, flag in enumerate(wave_flag):
        # add good fits to the image
        if flag in (4, 5):
            row = rows[pixel_ind]
            col = columns[pixel_ind]
            R[:n_waves, row, col] = R0[pixel_ind, :n_waves]
            R[-1, row, col] = 1
    # Swap the two spatial axes for display.
    R = np.transpose(R, (0, 2, 1))

    fig, ax = plt.subplots(figsize=(8, 8))
    image = ax.imshow(R[0])
    maximum = np.max(R)
    # Limits go on the mappable; the colorbar follows it. Colorbar.set_clim
    # and Colorbar.draw_all are deprecated/removed in newer Matplotlib.
    image.set_clim(vmin=0, vmax=maximum)
    divider = make_axes_locatable(ax)
    width = axes_size.AxesY(ax, aspect=1. / 20)
    pad = axes_size.Fraction(0.5, width)
    cax = divider.append_axes("right", size=width, pad=pad)
    cbar = fig.colorbar(image, cax=cax,
                        ticks=np.linspace(0., maximum, num=11))

    plt.tight_layout()
    plt.subplots_adjust(bottom=0.15)
    position = ax.get_position()
    middle = position.x0 + 3 * position.width / 4
    ax_prev = plt.axes([middle - 0.18, 0.05, 0.15, 0.03])
    ax_next = plt.axes([middle + 0.02, 0.05, 0.15, 0.03])
    ax_slider = plt.axes([position.x0, 0.05, position.width / 2, 0.03])

    class Index:
        """Widget state: which layer of R is displayed."""

        def __init__(self, ax_slider, ax_prev, ax_next):
            self.ind = 0
            self.num = n_waves
            self.bnext = Button(ax_next, 'Next')
            self.bnext.on_clicked(self.next)
            self.bprev = Button(ax_prev, 'Previous')
            self.bprev.on_clicked(self.prev)
            self.slider = Slider(ax_slider,
                                 "Energy Resolution: {:.2f} nm".format(
                                     wavelengths[0]),
                                 0,
                                 self.num,
                                 valinit=0,
                                 valfmt='%d')
            self.slider.valtext.set_visible(False)
            self.slider.label.set_horizontalalignment('center')
            self.slider.on_changed(self.update)
            # Put the label (used as the title) below the slider.
            self.slider.label.set_position((0.5, -0.5))
            self.slider.valtext.set_position((0.5, -0.5))

        def next(self, event):
            # Wrap around over the n_waves + 1 layers.
            i = (self.ind + 1) % (self.num + 1)
            self.slider.set_val(i)

        def prev(self, event):
            i = (self.ind - 1) % (self.num + 1)
            self.slider.set_val(i)

        def update(self, i):
            self.ind = int(i)
            image.set_data(R[self.ind])
            if self.ind != n_waves:
                self.slider.label.set_text(
                    "Energy Resolution: {:.2f} nm".format(
                        wavelengths[self.ind]))
                image.set_clim(vmin=0, vmax=maximum)
                cbar_ticks = np.linspace(0., maximum, num=11, endpoint=True)
            else:
                # Last layer is the 0/1 good-fit mask.
                self.slider.label.set_text("Calibrated Pixels")
                image.set_clim(vmin=0, vmax=1)
                cbar_ticks = np.linspace(0., 1, num=2)
            cbar.set_ticks(cbar_ticks)
            fig.canvas.draw_idle()

    # Keep a reference so the widgets (and their callbacks) stay alive.
    indexer = Index(ax_slider, ax_prev, ax_next)
    plt.show(block=True)
Ejemplo n.º 23
0
                #    fontsize=mpl.rcParams['font.size'])

        #plt.subplot(gs[12,0]).set_visible(False)
        #plt.subplot(gs[11,0]).set_visible(False)
    labeler = Labeler(xpad=.12, ypad=.01, fontsize=10)

    vmin = 1e3
    vmax = 1E7
    #panel C
    ax = plt.subplot(gs[0:12, 1])
    im = count_cells(ax,sort_name,'cells', \
        ylabel=True, vmax=vmax, vmin=vmin)
    labeler.label_subplot(ax, 'C')

    divider = make_axes_locatable(ax)
    width = axes_size.AxesY(ax, aspect=1 / 30.)
    pad = axes_size.Fraction(0.75, width)
    cax = divider.append_axes("right", size=width, pad=pad)

    #cax = fig.add_axes([0.92, 0.5375, 0.025, 0.4125])
    cbar = fig.colorbar(im, cax=cax, orientation='vertical')
    #cbar.set_label(r'number of sorted cells')
    cbar.solids.set_rasterized(True)

    # Panel D
    ax = plt.subplot(gs[16:28, 1])
    im = count_reads(ax, rep, 'reads', ylabel=True, vmax=vmax, vmin=vmin)
    labeler.label_subplot(ax, 'D')
    divider = make_axes_locatable(ax)
    width = axes_size.AxesY(ax, aspect=1 / 30.)
    pad = axes_size.Fraction(0.75, width)
def main(lon1, lat1, lon2, lat2, variable, files, filetype="archive", clim=None,
         sectionid="", ijspace=False, xaxis="distance", section_map=False,
         ncfiles="", dpi=180):
    """Plot a file-averaged HYCOM vertical section between two end points.

    Layer interfaces and ``variable`` are read from every file in ``files``,
    summed and divided by the number of files, and drawn as a filled-contour
    section with selected layer interfaces, a few reference contour lines and
    a colorbar. Optionally a stereographic map of the section track is
    written and, when ``ncfiles`` is given, the file-averaged ``GS_MLD``
    field and a 12-month average of a climatological MLD file are overlaid.

    Args:
        lon1, lat1, lon2, lat2: End points of the section.
        variable: Name of the field read from the ab files.
        files: Archive/restart files (a trailing ``.a``/``.b`` is optional).
            Must be non-empty; the first file sets the number of layers used
            to size the accumulators.
        filetype: ``"archive"`` or ``"restart"``.
        clim: Optional ``(min, max)`` for contour levels and color limits.
            When None, the range of the averaged (unmasked) data is used.
        sectionid: Optional tag added to output file names.
        ijspace: Build the section in grid-index space.
        xaxis: ``"distance"``, ``"i"``, ``"j"``, ``"lon"`` or ``"lat"``.
        section_map: Also write ``map[_<sectionid>].png``.
        ncfiles: Optional netCDF files containing a ``GS_MLD`` variable.
        dpi: Resolution of the written PNG files.

    Raises:
        NotImplementedError: For an unsupported ``filetype``.
    """
    logger.info("Filetype is %s" % filetype)
    gfile = abf.ABFileGrid("regional.grid", "r")
    plon = gfile.read_field("plon")
    plat = gfile.read_field("plat")

    # Set up section info
    if ijspace:
        sec = gridxsec.SectionIJSpace([lon1, lon2], [lat1, lat2], plon, plat)
    else:
        sec = gridxsec.Section([lon1, lon2], [lat1, lat2], plon, plat)
    I, J = sec.grid_indexes
    dist = sec.distance
    print('dit.shae=', dist.shape)
    slon = sec.longitude
    slat = sec.latitude

    logger.info("Min max I-index (starts from 0):%d %d" % (I.min(), I.max()))
    logger.info("Min max J-index (starts from 0):%d %d" % (J.min(), J.max()))

    if section_map:
        proj = ccrs.Stereographic(central_latitude=90.0, central_longitude=-40.0)
        figure = plt.figure(figsize=(8, 8))
        ax = figure.add_subplot(111, projection=proj)
        ax.set_extent([-179, 179, 53, 85], ccrs.PlateCarree())
        ax.add_feature(cfeature.GSHHSFeature('auto', edgecolor='grey'))
        ax.add_feature(cfeature.GSHHSFeature('auto', facecolor='grey'))
        ax.gridlines()
        ax.plot(slon, slat, "r-", lw=1, transform=ccrs.PlateCarree())

        # Match the figure height to the aspect ratio of the map axes.
        pos = ax.get_position()
        asp = pos.height / pos.width
        w = figure.get_figwidth()
        figure.set_figheight(asp * w)
        if sectionid:
            figure.canvas.print_figure("map_%s.png" % sectionid, dpi=dpi, bbox_inches='tight')
        else:
            figure.canvas.print_figure("map.png", dpi=dpi, bbox_inches='tight')
        plt.close(figure)

    # Get layer thickness variable used in hycom
    dpname = modeltools.hycom.layer_thickness_variable[filetype]
    logger.info("Filetype %s: layer thickness variable is %s" % (filetype, dpname))

    if xaxis == "distance":
        x = dist / 1000.
        xlab = "Distance along section[km]"
    elif xaxis == "i":
        x = I
        xlab = "i-index"
    elif xaxis == "j":
        x = J
        xlab = "j-index"
    elif xaxis == "lon":
        x = slon
        xlab = "longitude"
    elif xaxis == "lat":
        x = slat
        xlab = "latitude"
    else:
        logger.warning("xaxis must be i,j,lo,lat or distance")
        x = dist / 1000.
        xlab = "Distance along section[km]"

    # Get kdm from the first file; it sizes the accumulator arrays.
    # Remove [ab] ending if present (fall back to the name as given).
    print('firstfilw', files[0])
    m = re.match(r"(.*)\.[ab]", files[0])
    myf = m.group(1) if m else files[0]
    print('m=', myf)
    fi_abfile = abf.ABFileArchv(myf, "r")
    kdm = max(fi_abfile.fieldlevels)
    fi_abfile.close()

    # Loop over archive files, accumulating interfaces and data
    figure = plt.figure()
    ax = figure.add_subplot(111)
    count_sum = 0
    intfsec_sum = np.zeros((kdm + 1, I.size))
    datasec_sum = np.zeros((kdm + 1, I.size))
    for fcnt, myfile0 in enumerate(files):
        count_sum = count_sum + 1
        print('count_sum==', count_sum)
        print('fcnt=', fcnt)
        print('mfile0=', myfile0)
        # Remove [ab] ending if present
        m = re.match(r"(.*)\.[ab]", myfile0)
        if m:
            myfile = m.group(1)
        else:
            myfile = myfile0

        # Add more filetypes if needed. By def we assume archive
        if filetype == "archive":
            i_abfile = abf.ABFileArchv(myfile, "r")
        elif filetype == "restart":
            i_abfile = abf.ABFileRestart(myfile, "r", idm=gfile.idm, jdm=gfile.jdm)
        else:
            raise NotImplementedError("Filetype %s not implemented" % filetype)

        try:
            # kdm assumed to be max level in ab file
            kdm = max(i_abfile.fieldlevels)

            # Row k of intfsec is the depth of interface k (cumulative layer
            # thickness); datasec row k+1 holds layer k, row 0 repeats layer 0.
            intfsec = np.zeros((kdm + 1, I.size))
            datasec = np.zeros((kdm + 1, I.size))
            logger.info("File %s" % (myfile))
            for k in range(kdm):
                logger.debug("File %s, layer %03d/%03d" % (myfile, k, kdm))

                # Get 2D fields
                dp2d = i_abfile.read_field(dpname, k + 1)
                data2d = i_abfile.read_field(variable, k + 1)
                if k == kdm - 1:
                    print("---Reach bottom layer")
                dp2d = np.ma.filled(dp2d, 0.) / modeltools.hycom.onem
                # 1e30 marks missing values; they are masked after averaging.
                data2d = np.ma.filled(data2d, 1e30)
                # Place data into section arrays
                intfsec[k + 1, :] = intfsec[k, :] + dp2d[J, I]
                if k == 0:
                    datasec[k, :] = data2d[J, I]
                datasec[k + 1, :] = data2d[J, I]
        finally:
            i_abfile.close()

        intfsec_sum = intfsec_sum + intfsec
        datasec_sum = datasec_sum + datasec
        # end loop over files

    intfsec_avg = intfsec_sum / count_sum
    datasec_avg = datasec_sum / count_sum

    if ncfiles:
        # File-averaged mixed layer depth along the section
        MLDGS_sum = np.zeros((1, I.size))
        count_sum = 0
        for fcnt, ncfile in enumerate(ncfiles):
            count_sum = count_sum + 1
            print('ncfile count_sum==', count_sum)
            print('ncfile fcnt=', fcnt)
            print('ncfilefile=', ncfile)
            MLDGS = np.zeros((1, I.size))
            with netCDF4.Dataset(ncfile, 'r') as ncfile0:
                MLD_2D = ncfile0.variables['GS_MLD'][:]
                MLDGS[0, :] = MLD_2D[0, J, I]
            MLDGS_sum = MLDGS_sum + MLDGS
        MLDGS_avg = MLDGS_sum / count_sum

        # Average of the 12 monthly records of the climatological MLD
        if 'TP2' in files[0]:
            clim_mld_file = 'mld_dr003_l3_modif_Interp_TP2grd.nc'
        else:
            clim_mld_file = 'mld_dr003_l3_modif_Interp_TP5grd.nc'
        with netCDF4.Dataset(clim_mld_file) as fh:
            fhmldintrp = fh.variables['TP5mld'][:]
        MLDclim_sum = np.zeros((1, I.size))
        cunt_sum = 0
        for ii in range(12):
            cunt_sum = cunt_sum + 1
            MLDclim = np.zeros((1, I.size))
            MLDclim[0, :] = fhmldintrp[ii, J, I]
            MLDclim_sum = MLDclim_sum + MLDclim
            print('clim count_sum==', cunt_sum)
        MLDclim_avg = MLDclim_sum / cunt_sum

    # Column where the bottom interface is deepest; layer labels go there.
    i_maxd = np.argmax(np.abs(intfsec_avg[kdm, :]))
    # x replicated for every interface row so it matches intfsec_avg's shape.
    xx = np.zeros((kdm + 1, I.size))
    xx[:] = x
    # Set up section plot
    datasec_avg = np.ma.masked_where(datasec_avg > 0.5 * 1e30, datasec_avg)

    # Contour levels: user range if given, else the range of the averaged data.
    if clim is not None:
        lvls = MaxNLocator(nbins=30).tick_values(clim[0], clim[1])
    else:
        lvls = MaxNLocator(nbins=30).tick_values(datasec_avg.min(), datasec_avg.max())
    mf = 'sawtooth_0-1.txt'
    LinDic = mod_hyc2plot.cmap_dict(mf)
    cmap = matplotlib.colors.LinearSegmentedColormap('my_colormap', LinDic)
    print('x.shape=', x.shape)
    print('x.min,xmax=', x.min(), x.max())
    print('xx.shape=', xx.shape)
    print('xx.min,xxmax=', xx.min(), xx.max())
    print('intfsec_avg.shape=', intfsec_avg.shape)
    print('datasec_avg.shape=', datasec_avg.shape)
    P = ax.contourf(xx, -intfsec_avg, datasec_avg, extend='both', cmap=cmap, levels=lvls)
    if 'sal' in variable:
        line_levels = [32.0, 33.0, 34.0, 35.0, 35.5]
    else:
        line_levels = [-1, 0.0, 2.0]
    P1 = ax.contour(xx, -intfsec_avg, datasec_avg, levels=line_levels,
                    colors=('k',), linestyles=('-',), linewidths=(1.5,))
    # One decimal so that e.g. the 35.5 contour is not labelled "35".
    matplotlib.pyplot.clabel(P1, fmt='%.1f', colors='k', fontsize=10)

    # Plot layer interfaces: every 100th solid, layers 5 and 10 and every
    # odd layer above 10 dashed and labelled at column i_maxd.
    for k in range(1, kdm + 1):
        if k % 100 == 0:
            ax.plot(x, -intfsec_avg[k, :], "-", color="k")
        elif (k % 5 == 0 and k <= 10) or (k % 2 and k > 10):
            ax.plot(x, -intfsec_avg[k, :], "--", color="k", linewidth=0.5)
            textx = x[i_maxd]
            texty = -0.5 * (intfsec_avg[k - 1, i_maxd] + intfsec_avg[k, i_maxd])
            ax.text(textx, texty, str(k), verticalalignment="center",
                    horizontalalignment="center", fontsize=6)
    if ncfiles:
        ax.plot(x, -MLDGS_avg[0, :], "-", color="w", linewidth=1.50)
        ax.plot(x, -MLDclim_avg[0, :], "--", color="r", linewidth=1.50)

    # Colorbar as tall as the axes, 1/aspect as wide.
    aspect = 50
    pad_fraction = 0.25
    divider = make_axes_locatable(ax)
    width = axes_size.AxesY(ax, aspect=1. / aspect)
    pad = axes_size.Fraction(pad_fraction, width)
    cax = divider.append_axes("right", size=width, pad=pad)
    cb = ax.figure.colorbar(P, cax=cax, extend='both')
    if clim is not None:
        P.set_clim(clim)
    ax.set_title(variable + ':' + myfile + 'AVG-')
    ax.set_ylabel('Depth [m]')
    ax.set_xlabel(xlab)

    # Print in different y-lims
    suff = os.path.basename(myfile)
    if sectionid:
        suff = suff + "_" + sectionid
    figure.canvas.print_figure("sec_AVG_%s_full_%s.png" % (variable, suff), dpi=dpi)
    if 'Fram' in sectionid or 'Svin' in sectionid:
        print('sectionid=', sectionid)
        ax.set_ylim(-600, 0)
        figure.canvas.print_figure("sec_AVG_%s_600m_%s.png" % (variable, suff), dpi=dpi)
    else:
        ax.set_ylim(-3000, 0)
        figure.canvas.print_figure("sec_AVG_%s_3000m_%s.png" % (variable, suff), dpi=dpi)

    ax.clear()
    cb.remove()
    plt.close(figure)
Ejemplo n.º 25
0
def align_colorbar(ax, aspect=20, pad_fraction=0.5):
    """Create a colorbar axes on the right-hand side of *ax*.

    The new axes' size is derived from the vertical extent of *ax* scaled by
    ``1 / aspect``. The gap between the two axes is *pad_fraction* times
    that width.

    Returns:
        The appended colorbar axes.
    """
    locator = make_axes_locatable(ax)
    bar_width = axes_size.AxesY(ax, aspect=1. / aspect)
    bar_gap = axes_size.Fraction(pad_fraction, bar_width)
    colorbar_ax = locator.append_axes('right', size=bar_width, pad=bar_gap)
    return colorbar_ax
Ejemplo n.º 26
0
import mpl_toolkits.axes_grid1.axes_size as Size
from mpl_toolkits.axes_grid1 import Divider
import matplotlib.pyplot as plt

# Four overlapping axes repositioned into a 2x2 grid by a Divider: column
# widths follow ax[0]/ax[1]'s x-extent, row heights follow ax[0]/ax[2]'s
# y-extent, with a fixed 0.5 gap between columns and between rows.
fig = plt.figure(figsize=(5.5, 4))

rect = (0.1, 0.1, 0.8, 0.8)
ax = [fig.add_axes(rect, label="%d" % n) for n in range(4)]

horiz = [Size.AxesX(ax[0]), Size.Fixed(.5), Size.AxesX(ax[1])]
vert = [Size.AxesY(ax[0]), Size.Fixed(.5), Size.AxesY(ax[2])]

divider = Divider(fig, rect, horiz, vert, aspect=False)

# Cell index 1 in each direction is the fixed gap, so the axes occupy 0 and 2.
for panel, (col, row) in zip(ax, [(0, 0), (2, 0), (0, 2), (2, 2)]):
    panel.set_axes_locator(divider.new_locator(nx=col, ny=row))

# The data limits below determine the AxesX/AxesY sizes used above.
ax[0].set_xlim(0, 2)
ax[1].set_xlim(0, 1)

ax[0].set_ylim(0, 1)
ax[2].set_ylim(0, 2)

divider.set_aspect(1.)

for ax1 in ax:
    ax1.tick_params(labelbottom=False, labelleft=False)

plt.show()
Ejemplo n.º 27
0
def draw_vibfreq_scatter_plot_n_overlap_matrix(name, engine, ref_eigvals, ref_eigvecs, freqs_rearr, normal_modes_rearr):
    """Save a two-panel figure comparing QM and MM vibrational analyses.

    Left panel: MM vs. QM frequencies (cm^-1) with the y = x line and the
    mean absolute error in the title. Right panel: the QM/MM normal-mode
    overlap matrix with a colorbar. Its title gives the correlation
    coefficient and the relative error against the QM self-overlap matrix.

    The figure is written to ``vibfreq_scatter_plot_n_overlap_matrix.pdf``
    in the current directory and then closed.

    Args:
        name: System name shown in the figure title.
        engine: Passed through to ``vib_overlap``.
        ref_eigvals: QM (reference) frequencies.
        ref_eigvecs: QM normal modes.
        freqs_rearr: MM frequencies; compared element-wise with
            ``ref_eigvals``, so presumably already reordered to match.
        normal_modes_rearr: MM normal modes, presumably reordered likewise.
    """
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import make_axes_locatable, axes_size

    plt.switch_backend('agg')
    fig, axs = plt.subplots(1,2, figsize=(10,6))
    overlap_matrix = np.array([[(vib_overlap(engine, v1, v2)) for v2 in normal_modes_rearr] for v1 in ref_eigvecs])
    qm_overlap_matrix = np.array([[(vib_overlap(engine,v1, v2)) for v2 in ref_eigvecs] for v1 in ref_eigvecs])

    # Left panel: frequency scatter against the y = x line
    axs[0].scatter(ref_eigvals, freqs_rearr, label='MM vibrational frequencies(rearr.)')
    axs[0].plot(ref_eigvals,ref_eigvals, 'k-')
    axs[0].legend()
    axs[0].set_xlabel(r'QM vibrational frequency ($cm^{-1}$)')
    axs[0].set_ylabel(r'MM vibrational frequency ($cm^{-1}$)')
    mae = np.mean(np.abs(ref_eigvals - freqs_rearr))
    axs[0].set_title(f'QM vs. MM vibrational frequencies\n MAE= {mae:.2f}')
    # Make the scatter panel square in display space.
    x0,x1 = axs[0].get_xlim()
    y0,y1 = axs[0].get_ylim()
    axs[0].set_aspect((x1-x0)/(y1-y0))

    # Right panel: overlap matrix with x ticks on top, all ticks inside
    axs[1].xaxis.tick_top()
    axs[1].tick_params(axis="y", direction='in')
    axs[1].tick_params(axis="x", direction='in')
    im = axs[1].imshow(overlap_matrix, cmap= 'OrRd', vmin=0,vmax=1)
    # Colorbar as tall as the matrix axes, 1/aspect as wide
    aspect = 20
    pad_fraction = 0.5
    divider = make_axes_locatable(axs[1])
    width = axes_size.AxesY(axs[1], aspect=1./aspect)
    pad = axes_size.Fraction(pad_fraction, width)
    cax = divider.append_axes("right", size=width, pad=pad)
    cax.yaxis.tick_right()
    cax.xaxis.set_visible(False)
    plt.colorbar(im, cax=cax)
    corr_coef = cal_corr_coef(overlap_matrix)
    # Relative (Frobenius-norm) error of the MM overlap matrix vs. the QM one
    err = np.linalg.norm(qm_overlap_matrix - overlap_matrix)/np.linalg.norm(qm_overlap_matrix)
    axs[1].set_title(f'QM vs. MM normal modes\n Correlation coef. ={corr_coef:.4f}, Error={err:.4f}')

    plt.tight_layout()
    plt.subplots_adjust(top=0.85)
    fig.suptitle('Hessian: iteration %i\nSystem: %s' % (Counter(), name))
    fig.savefig('vibfreq_scatter_plot_n_overlap_matrix.pdf')
    # Release the figure; this is called repeatedly and pyplot keeps
    # every open figure alive otherwise.
    plt.close(fig)
Ejemplo n.º 28
0
def plot_usa_daybyday_case_diffs(
    states_df: pd.DataFrame,
    *,
    geo_df: geopandas.GeoDataFrame = None,
    stage: Union[DiseaseStage, Literal[Select.ALL]],
    count: Union[Counting, Literal[Select.ALL]],
    dates: List[pd.Timestamp] = None,
) -> pd.DataFrame:
    """Render one US choropleth PNG per date of day-over-day case changes.

    For every date, a grid of maps (one row per disease stage, one column
    per counting type) is drawn for the 48 contiguous states plus DC (AK
    and HI are excluded). Each map has a log-scaled colorbar. Frames are
    saved to ``DOD_DIFF_DIR`` and resized to even dimensions. The latest
    frame is also copied to ``GEO_FIG_DIR / "dod_diff_poster.png"``.

    Args:
        states_df: Long-format per-state data containing the ``Columns`` ID
            columns and ``Columns.CASE_COUNT``.
        geo_df: State geometries keyed by ``STUSPS``; loaded with
            ``get_geo_df()`` when None.
        stage: A ``DiseaseStage`` or ``Select.ALL``.
        count: A ``Counting`` or ``Select.ALL``.
        dates: Dates to consider; defaults to every date in ``states_df``.
            The earliest date has no previous day to diff against and so
            produces no frame.

    Returns:
        The day-over-day diff table (one row per state/date/stage/count)
        with the diff stored in column ``"Diff_"``.
    """
    Counting.verify(count, allow_select=True)
    DiseaseStage.verify(stage, allow_select=True)

    if geo_df is None:
        geo_df = get_geo_df()

    DIFF_COL = "Diff_"
    ASPECT_RATIO = 1 / 20
    PAD_FRAC = 0.5
    N_CBAR_BUCKETS = 6  # number of log-spaced intervals between vmin and vmax
    N_BUCKETS_BTWN_MAJOR_TICKS = 1
    N_MINOR_TICKS_BTWN_MAJOR_TICKS = 8  # major_1, minor_1, ..., minor_n, major_2
    N_CBAR_MAJOR_TICKS = N_CBAR_BUCKETS // N_BUCKETS_BTWN_MAJOR_TICKS + 1
    CMAP = cmocean.cm.matter
    DPI = 300
    NOW_STR = datetime.now(timezone.utc).strftime(r"%b %-d, %Y at %H:%M UTC")

    ID_COLS = [
        Columns.TWO_LETTER_STATE_CODE,
        Columns.DATE,
        Columns.STAGE,
        Columns.COUNT_TYPE,
    ]

    save_fig_kwargs = {
        "dpi": "figure",
        "bbox_inches": "tight",
        "facecolor": "w"
    }

    if count is Select.ALL:
        count_list = list(Counting)
    else:
        count_list = [count]

    if stage is Select.ALL:
        stage_list = list(DiseaseStage)
    else:
        stage_list = [stage]

    count_list: List[Counting]
    stage_list: List[DiseaseStage]

    if dates is None:
        dates: List[pd.Timestamp] = states_df[Columns.DATE].unique()

    dates = sorted(pd.Timestamp(date) for date in dates)

    # Get day-by-day case diffs per location, date, stage, count-type
    case_diffs_df = states_df[
        (states_df[Columns.TWO_LETTER_STATE_CODE].isin(USA_STATE_CODES))
        &
        (~states_df[Columns.TWO_LETTER_STATE_CODE].isin(["AK", "HI"]))].copy()

    # Make sure data exists for every date for every state so that the entire country is
    # plotted each day; fill missing data with 0 (missing really *is* as good as 0)
    state_date_stage_combos = pd.MultiIndex.from_product(
        [
            case_diffs_df[Columns.TWO_LETTER_STATE_CODE].unique(),
            dates,
            [s.name for s in DiseaseStage],
            [c.name for c in Counting],
        ],
        names=ID_COLS,
    )

    case_diffs_df = (state_date_stage_combos.to_frame(index=False).merge(
        case_diffs_df,
        how="left",
        on=ID_COLS,
    ).sort_values(ID_COLS))

    case_diffs_df[Columns.CASE_COUNT] = case_diffs_df[
        Columns.CASE_COUNT].fillna(0)

    case_diffs_df[DIFF_COL] = case_diffs_df.groupby(
        [Columns.TWO_LETTER_STATE_CODE, Columns.STAGE,
         Columns.COUNT_TYPE])[Columns.CASE_COUNT].diff()

    # The first date of each group has no previous day, hence a NaN diff
    case_diffs_df = case_diffs_df[case_diffs_df[DIFF_COL].notna()]

    dates = case_diffs_df[Columns.DATE].unique()

    # LogNorm needs a strictly positive vmin
    vmins = {
        Counting.TOTAL_CASES:
        1,
        Counting.PER_CAPITA:
        case_diffs_df.loc[case_diffs_df[DIFF_COL] > 0, DIFF_COL].min(),
    }
    vmaxs = case_diffs_df.groupby([Columns.STAGE,
                                   Columns.COUNT_TYPE])[DIFF_COL].max()

    fig: plt.Figure = plt.figure(facecolor="white", dpi=DPI)

    # Don't put too much stock in these, we tweak them later to make sure they're even
    fig_width_px = len(count_list) * 1800
    fig_height_px = len(stage_list) * 1000 + 200

    max_date = max(dates)

    # The order doesn't matter, but doing later dates first lets us see interesting
    # output in Finder earlier, which is good for debugging
    for date in reversed(dates):
        date: pd.Timestamp = pd.Timestamp(date)
        # Data is associated with the right endpoint of the data collection period,
        # e.g., data collected *on* March 20 is labeled March 21 -- this is done so that
        # data collected today (on the day the code is run) has a meaningful date
        # associated with it (today's current time)
        # Anyway, here we undo that and display data on the date it was collected
        # in order to show a meaningful title on the graph
        if date == date.normalize():
            collection_date = date - pd.Timedelta(days=1)
        else:
            collection_date = date.normalize()

        fig.suptitle(collection_date.strftime(r"%b %-d, %Y"))

        for subplot_index, (stage, count) in enumerate(itertools.product(
                stage_list, count_list),
                                                       start=1):
            ax: plt.Axes = fig.add_subplot(len(stage_list), len(count_list),
                                           subplot_index)

            # Add timestamp to the top-right axes (last column of the first
            # row), whatever the number of columns
            if subplot_index == len(count_list):
                ax.text(
                    1.25,  # Coords are arbitrary magic numbers
                    1.23,
                    f"Last updated {NOW_STR}",
                    horizontalalignment="right",
                    fontsize="small",
                    transform=ax.transAxes,
                )

            # Filter to just this axes: this stage, this count-type, this date
            stage_date_df = case_diffs_df[
                (case_diffs_df[Columns.STAGE] == stage.name)
                & (case_diffs_df[Columns.COUNT_TYPE] == count.name)
                & (case_diffs_df[Columns.DATE] == date)]

            # Should have length 49 (50 + DC - AK - HI)
            stage_geo_df: geopandas.GeoDataFrame = geo_df.merge(
                stage_date_df,
                how="inner",
                left_on="STUSPS",
                right_on=Columns.TWO_LETTER_STATE_CODE,
            )
            assert len(stage_geo_df) == 49

            vmin = vmins[count]
            vmax = vmaxs.loc[(stage.name, count.name)]

            # Create log-scaled color mapping
            # https://stackoverflow.com/a/43807666
            norm = LogNorm(vmin, vmax)
            scm = plt.cm.ScalarMappable(norm=norm, cmap=CMAP)

            # Actually plot the data. Omit legend, since we'll want to customize it and
            # it's easier to create a new one than customize the existing one.
            stage_geo_df.plot(
                column=DIFF_COL,
                ax=ax,
                legend=False,
                vmin=vmin,
                vmax=vmax,
                cmap=CMAP,
                norm=norm,
            )

            # Plot state boundaries
            stage_geo_df.boundary.plot(ax=ax, linewidth=0.06, edgecolor="k")

            # Add colorbar axes to right side of graph
            # https://stackoverflow.com/a/33505522
            divider = make_axes_locatable(ax)
            width = axes_size.AxesY(ax, aspect=ASPECT_RATIO)
            pad = axes_size.Fraction(PAD_FRAC, width)
            cax = divider.append_axes("right", size=width, pad=pad)

            # Add colorbar itself
            cbar = fig.colorbar(scm, cax=cax)

            # Add evenly spaced ticks and their labels
            # First major, then minor
            # Adapted from https://stackoverflow.com/a/50314773
            bucket_size = (vmax / vmin)**(1 / N_CBAR_BUCKETS)
            tick_dist = bucket_size**N_BUCKETS_BTWN_MAJOR_TICKS

            # Simple log scale math
            major_tick_locs = (
                vmin * (tick_dist**np.arange(0, N_CBAR_MAJOR_TICKS))
                # * (bucket_size ** 0.5) # Use this if centering ticks on buckets
            )

            cbar.set_ticks(major_tick_locs)

            # Get minor locs by linearly interpolating between consecutive major
            # ticks, excluding the major ticks themselves
            minor_tick_locs = []
            for this_major_tick, next_major_tick in zip(major_tick_locs[:-1],
                                                        major_tick_locs[1:]):
                minor_tick_locs.extend(
                    np.linspace(
                        this_major_tick,
                        next_major_tick,
                        N_MINOR_TICKS_BTWN_MAJOR_TICKS + 2,
                    )[1:-1])

            cbar.ax.yaxis.set_ticks(minor_tick_locs, minor=True)
            cbar.ax.yaxis.set_minor_formatter(NullFormatter())

            # Add major tick labels
            if count is Counting.PER_CAPITA:
                fmt_func = "{:.2e}".format
            else:
                fmt_func = functools.partial(format_float,
                                             max_digits=5,
                                             decimal_penalty=2)

            cbar.set_ticklabels(
                [fmt_func(x) if x != 0 else "0" for x in major_tick_locs])

            # Set axes titles
            ax_stage_name: str = {
                DiseaseStage.CONFIRMED: "Cases",
                DiseaseStage.DEATH: "Deaths",
            }[stage]
            ax_title_components: List[str] = ["New Daily", ax_stage_name]
            if count is Counting.PER_CAPITA:
                ax_title_components.append("Per Capita")

            ax.set_title(" ".join(ax_title_components))

            # Remove axis ticks (I think they're lat/long but we don't need them)
            for axis in (ax.xaxis, ax.yaxis):
                axis.set_major_locator(NullLocator())
                axis.set_minor_locator(NullLocator())

        # Save figure, and then deal with matplotlib weirdness that doesn't exactly
        # respect the dimensions we set due to bbox_inches='tight'
        save_path: Path = DOD_DIFF_DIR / f"dod_diff_{date.strftime(r'%Y%m%d')}.png"
        fig.set_size_inches(fig_width_px / DPI, fig_height_px / DPI)
        fig.savefig(save_path, **save_fig_kwargs)

        # x264 video encoder requires frames have even width and height
        resize_to_even_dims(save_path)

        # Save poster (preview frame for video on web)
        if date == max_date:
            (GEO_FIG_DIR / "dod_diff_poster.png").write_bytes(
                save_path.read_bytes())

        fig.clf()

        print(f"Saved '{save_path}'")

    # The one figure was reused for every frame; release it now
    plt.close(fig)

    return case_diffs_df
def main(lon1,
         lat1,
         lon2,
         lat2,
         variable,
         files,
         filetype="archive",
         clim=None,
         sectionid="",
         ijspace=False,
         xaxis="distance",
         section_map=False,
         dens=False,
         dpi=180):
    """Plot a HYCOM vertical section between two end points, one figure per file.

    For each file in ``files`` the layer interfaces and ``variable`` are read
    along the section. With ``dens=True`` the plotted field is instead
    computed by ``mod_hyc2plot.sig`` from ``variable`` and the ``salin``
    field. The result is drawn as a filled-contour section with labelled
    layer interfaces and a colorbar. It is saved as
    ``sec_<var>_full_<suffix>.png`` and ``sec_<var>_1000m_<suffix>.png``.

    Args:
        lon1, lat1, lon2, lat2: End points of the section.
        variable: Field read from the ab files (with ``dens=True``,
            presumably temperature - TODO confirm against ``sig``).
        files: Archive/restart files (a trailing ``.a``/``.b`` is optional).
        filetype: ``"archive"`` or ``"restart"``.
        clim: Optional ``(min, max)`` for contour levels and color limits.
            When None, each file is scaled to its own data range.
        sectionid: Optional tag added to output file names.
        ijspace: Build the section in grid-index space.
        xaxis: ``"distance"``, ``"i"``, ``"j"``, ``"lon"`` or ``"lat"``.
        section_map: Also write ``map[_<sectionid>].png``.
        dens: Plot the density computed from ``variable`` and ``salin``.
        dpi: Resolution of the written PNG files.

    Raises:
        NotImplementedError: For an unsupported ``filetype``.
    """
    logger.info("Filetype is %s" % filetype)
    gfile = abf.ABFileGrid("regional.grid", "r")
    plon = gfile.read_field("plon")
    plat = gfile.read_field("plat")

    # Set up section info
    if ijspace:
        sec = gridxsec.SectionIJSpace([lon1, lon2], [lat1, lat2], plon, plat)
    else:
        sec = gridxsec.Section([lon1, lon2], [lat1, lat2], plon, plat)
    I, J = sec.grid_indexes
    dist = sec.distance
    print('dit.shae=', dist.shape)
    slon = sec.longitude
    slat = sec.latitude

    logger.info("Min max I-index (starts from 0):%d %d" % (I.min(), I.max()))
    logger.info("Min max J-index (starts from 0):%d %d" % (J.min(), J.max()))

    if section_map:
        figure = plt.figure(figsize=(10, 8))
        # A single PlateCarree map axes; the track is plotted in lon/lat.
        ax = figure.add_subplot(111, projection=ccrs.PlateCarree())
        ax.set_extent([-179, 179, 53, 85], ccrs.PlateCarree())
        ax.add_feature(cfeature.GSHHSFeature('auto', edgecolor='grey'))
        ax.add_feature(cfeature.GSHHSFeature('auto', facecolor='grey'))
        ax.gridlines()
        ax.plot(slon, slat, "r-", lw=1)

        # Match the figure height to the aspect ratio of the map axes.
        pos = ax.get_position()
        asp = pos.height / pos.width
        w = figure.get_figwidth()
        figure.set_figheight(asp * w)
        if sectionid:
            figure.canvas.print_figure("map_%s.png" % sectionid, dpi=dpi)
        else:
            figure.canvas.print_figure("map.png", dpi=dpi)
        plt.close(figure)

    # Get layer thickness variable used in hycom
    dpname = modeltools.hycom.layer_thickness_variable[filetype]
    logger.info("Filetype %s: layer thickness variable is %s" %
                (filetype, dpname))

    if xaxis == "distance":
        x = dist / 1000.
        xlab = "Distance along section[km]"
    elif xaxis == "i":
        x = I
        xlab = "i-index"
    elif xaxis == "j":
        x = J
        xlab = "j-index"
    elif xaxis == "lon":
        x = slon
        xlab = "longitude"
    elif xaxis == "lat":
        x = slat
        xlab = "latitude"
    else:
        logger.warning("xaxis must be i,j,lo,lat or distance")
        x = dist / 1000.
        xlab = "Distance along section[km]"

    # Colormap is the same for every file; build it once.
    mf = 'sawtooth_fc100.txt'
    LinDic = mod_hyc2plot.cmap_dict(mf)
    cmap = matplotlib.colors.LinearSegmentedColormap('my_colormap', LinDic)

    # Name used in output file names; kept separate from `variable` so the
    # field read from every file stays the same.
    outvar = "dens" if dens else variable

    # Loop over archive files
    figure = plt.figure()
    ax = figure.add_subplot(111)
    for fcnt, myfile0 in enumerate(files):

        # Remove [ab] ending if present
        m = re.match(r"(.*)\.[ab]", myfile0)
        myfile = m.group(1) if m else myfile0

        # Add more filetypes if needed. By def we assume archive
        if filetype == "archive":
            i_abfile = abf.ABFileArchv(myfile, "r")
        elif filetype == "restart":
            i_abfile = abf.ABFileRestart(myfile,
                                         "r",
                                         idm=gfile.idm,
                                         jdm=gfile.jdm)
        else:
            raise NotImplementedError("Filetype %s not implemented" % filetype)

        try:
            # kdm assumed to be max level in ab file
            kdm = max(i_abfile.fieldlevels)

            # Row k of intfsec is the depth of interface k (cumulative layer
            # thickness); data row k+1 holds layer k and row 0 repeats layer 0.
            intfsec = np.zeros((kdm + 1, I.size))
            datasec = np.zeros((kdm + 1, I.size))
            if dens:
                datasec_sal = np.zeros((kdm + 1, I.size))

            # Loop over layers in file.
            logger.info("File %s" % (myfile))
            for k in range(kdm):
                logger.debug("File %s, layer %03d/%03d" % (myfile, k, kdm))

                # Get 2D fields; 1e30 marks missing values (masked below)
                dp2d = i_abfile.read_field(dpname, k + 1)
                data2d = i_abfile.read_field(variable, k + 1)
                dp2d = np.ma.filled(dp2d, 0.) / modeltools.hycom.onem
                data2d = np.ma.filled(data2d, 1e30)

                # Place data into section arrays
                intfsec[k + 1, :] = intfsec[k, :] + dp2d[J, I]
                if k == 0:
                    datasec[k, :] = data2d[J, I]
                datasec[k + 1, :] = data2d[J, I]

                if dens:
                    data2d_sal = i_abfile.read_field('salin', k + 1)
                    data2d_sal = np.ma.filled(data2d_sal, 1e30)
                    # Same row layout as datasec, so the surface row is set too
                    if k == 0:
                        datasec_sal[k, :] = data2d_sal[J, I]
                    datasec_sal[k + 1, :] = data2d_sal[J, I]
        finally:
            # Close input file
            i_abfile.close()

        # Column where the bottom interface is deepest; layer labels go there.
        i_maxd = np.argmax(np.abs(intfsec[kdm, :]))
        # x replicated for every interface row to match intfsec's shape.
        xx = np.zeros((kdm + 1, I.size))
        xx[:] = x

        datasec = np.ma.masked_where(datasec > 0.5 * 1e30, datasec)
        print("datasec min, max=", datasec.min(), datasec.max())
        if dens:
            datasec_sal = np.ma.masked_where(datasec_sal > 0.5 * 1e30,
                                             datasec_sal)
            print("datasec_sal min, max=", datasec_sal.min(),
                  datasec_sal.max())
            sigma_sec = mod_hyc2plot.sig(datasec, datasec_sal)
            datasec = np.ma.masked_where(sigma_sec < 0.0, sigma_sec)
        # Set up section plot
        datasec = np.ma.masked_where(datasec > 0.5 * 1e30, datasec)
        print("min, max=", datasec.min(), datasec.max())
        # Per-file range when no clim was given (the argument is not mutated,
        # so later files are not forced onto the first file's range).
        if clim is None:
            file_clim = [datasec.min(), datasec.max()]
        else:
            file_clim = clim
        print("clim=", file_clim[0], file_clim[1])
        lvls = MaxNLocator(nbins=70).tick_values(file_clim[0], file_clim[1])
        P = ax.contourf(xx, -intfsec, datasec, cmap=cmap, levels=lvls)

        # Plot layer interfaces: every 100th solid black, every 5th dashed
        # black, other even layers above 2 thin grey; each labelled at i_maxd.
        for k in range(1, kdm + 1):
            if k % 100 == 0:
                ax.plot(x, -intfsec[k, :], "-", color="k")
            elif k % 5 == 0:
                ax.plot(x, -intfsec[k, :], "--", color="k", linewidth=0.5)
            elif k > 2 and k % 2 == 0:
                ax.plot(x, -intfsec[k, :], "-", color=".5", linewidth=0.5)
            else:
                continue
            textx = x[i_maxd]
            texty = -0.5 * (intfsec[k - 1, i_maxd] + intfsec[k, i_maxd])
            ax.text(textx,
                    texty,
                    str(k),
                    verticalalignment="center",
                    horizontalalignment="center",
                    fontsize=6)

        # Print figure; colorbar as tall as the axes, 1/aspect as wide
        ax.set_facecolor('xkcd:gray')
        aspect = 90
        pad_fraction = 0.25
        divider = make_axes_locatable(ax)
        width = axes_size.AxesY(ax, aspect=1. / aspect)
        pad = axes_size.Fraction(pad_fraction, width)
        cax = divider.append_axes("right", size=width, pad=pad)
        cb = ax.figure.colorbar(P, cax=cax)
        P.set_clim(file_clim)
        if dens:
            ax.set_title('[P. density ]: ' + myfile)
        else:
            ax.set_title('[' + variable + ']: ' + myfile)

        ax.set_ylabel('Depth [m]')
        ax.set_xlabel(xlab)

        # Print in different y-lims
        suff = os.path.basename(myfile)
        if sectionid:
            suff = suff + "_" + sectionid
        figure.canvas.print_figure("sec_%s_full_%s.png" % (outvar, suff),
                                   dpi=dpi)
        ax.set_ylim(-1000, 0)
        figure.canvas.print_figure("sec_%s_1000m_%s.png" % (outvar, suff),
                                   dpi=dpi)

        # Reset the axes for the next file
        ax.clear()
        cb.remove()

    plt.close(figure)
Ejemplo n.º 30
0
def main(args):

    # pt_criterion = ChamfersDistance3()
    img_criterion = nn.L1Loss(reduction="sum")

    tsne = TSNE(n_components=2, init='pca', random_state=0)

    modelist = ['5by20', '2by10']

    #plt.style.use('ggplot')
    '''
    fig, ax = plt.subplots(figsize=(20, 20))
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22']
    plt.rcParams['axes.titlesize'] = 25
    plt.rcParams['axes.titleweight'] = 'bold'
    plt.rcParams['axes.labelsize'] = 24
    for idx, mode in enumerate(modelist):
        train_pt_mod = []
        pred_pt_mod = []
        train_img_mod = []
        args.first_plot = False
        if mode == '5by20':
            sparselist = [1, 2, 3, 4, 5, 6, 7]
            partition = 20*[0] + 20 *[1] + 20*[2] + 20*[3] + 20 *[4]
        elif mode == '2by10':
            sparselist = [1, 5, 10, 15, 20, 30, 35]
            partition = 10*[0] + 10 *[1] 

        print(mode, partition)

        test_json = 'test_interp_1000.json'
        json_file = os.path.join(args.data_dir, test_json)
        data_path = open(json_file, 'r')
        testdata = json.load(data_path)
        interp_list = []

        for i in range(len(testdata)):
            index_key_pair = (i, testdata[i]['shapekey_value'])
            interp_list.append(index_key_pair)
        interp_list.sort(key=lambda pair: pair[1])

        for sparse in sparselist:
            if mode == '5by20':
                train_json = 'cluster_interp_5by20_%d.json'%sparse	
                results_dir = "../experiment/cubesphere_5by20_resnet/results_{}/final_vis".format(sparse)
                similarity_dir = "../experiment/cubesphere_5by20_resnet/results_{}".format(sparse)
            elif mode == '2by10':
                train_json = 'cluster_samept_2by10_%d.json'%sparse
                results_dir = "../experiment/cubesphere_2by10_resnet/results_{}/final_vis".format(sparse)
                similarity_dir = "../experiment/cubesphere_2by10_resnet/results_{}".format(sparse)

            results_name = 'fineptcloud_'
            print('results_dir',results_dir)
            print('similarity_dir',similarity_dir)
            ## If true, then load test results, compute distance matrix 
            ## If false, then load distance matrix computed before  
            if args.first_plot:
                gt_ptcloud, gt_image = load_gt(test_json, args.data_dir)
                
                ########pred_ptcloud, pred_image, _ , pred_codeword = load_pred(args.data_dir, results_dir, args.test_json)
                
                #train_ptcloud, train_image = load_gt(train_json, args.data_dir)
                #pred_ptcloud = load_predpts(results_path = results_dir, results_name = results_name, total_num = 200, origin_test_batch = 100)
                #pred_ptcloud = torch.from_numpy(pred_ptcloud).float()
                sorted_predptcloud = torch.zeros(len(testdata), 1024, 3)
                sorted_gtptcloud = torch.zeros(len(testdata), 1024, 3)
                for i in range(len(testdata)):
                    #sorted_predptcloud[i] = pred_ptcloud[interp_list[i][0]]
                    sorted_gtptcloud[i] = gt_ptcloud[interp_list[i][0]]

                #train_ptcloud = train_ptcloud.cuda()
                #sorted_predptcloud = sorted_predptcloud.cuda()
                #train_image = train_image.cuda()
                sorted_gtptcloud = sorted_gtptcloud.cuda()
                pt_criterion = pt_criterion.cuda()
                img_criterion = img_criterion.cuda()
                train_pt_matrix = compute_squared_EDM_method(train_ptcloud, pt_criterion, True, 
                                                '%s_pt_similarity_matrix.npy'%'train', similarity_dir)
                pred_pt_matrix = compute_squared_EDM_method(sorted_predptcloud, pt_criterion, True, 
                                                '%s_pt_similarity_matrix.npy'%'results', similarity_dir)
                train_img_matrix = compute_squared_EDM_method(train_image, img_criterion, True, 
                                                '%s_img_similarity_matrix.npy'%'train', similarity_dir)
                gt_pt_matrix = compute_squared_EDM_method(sorted_gtptcloud, pt_criterion, True, 		
                                        '%s_pt_similarity_matrix.npy'%'gt', similarity_dir)

            else:
                #pred_ptcloud, pred_image, _ , pred_codeword = load_pred(args.data_dir, results_dir, args.test_json)
                train_pt_matrix = np.load(os.path.join(similarity_dir, '%s_pt_similarity_matrix.npy'%'train'))
                pred_pt_matrix = np.load(os.path.join(similarity_dir, '%s_pt_similarity_matrix.npy'%'results'))
                train_img_matrix = np.load(os.path.join(similarity_dir, '%s_img_similarity_matrix.npy'%'train'))
                gt_pt_matrix = np.load(os.path.join(similarity_dir, '%s_pt_similarity_matrix.npy'%'gt'))
            ## compute modularity

            #gt_part = gt_partition(gt_pt_matrix)
            pred_part = gt_partition(pred_pt_matrix)
            train_pt_mod += [silhouette(train_pt_matrix, partition)]
            train_img_mod += [silhouette(train_img_matrix, partition)]
            #pred_pt_mod += [silhouette(pred_pt_matrix, gt_part)]
            pred_pt_mod += [silhouette(pred_pt_matrix, pred_part)]


        print(idx)

#############################################################################
################################################################################
################################################################################
########Visualtion of Sil... Score##############################		
        plt.subplot(2,2,1+idx*2)
        for i in range(len(train_img_mod)):
            plt.scatter(train_img_mod[i], pred_pt_mod[i], c = colors[i],  s = 600,label = 'dataset index = {}'.format(i+1))
        plt.plot(train_img_mod, pred_pt_mod, linewidth = 7, c= 'k')
        if idx == 1:
            plt.xlabel('Training Image Dataset Silhouette Score',labelpad = 25)
        plt.ylabel('Predicted Point Cloud Results Silhouette Score',labelpad = 22)
        plt.yticks(np.arange(0.25, 0.9, step=0.05))
        #plt.xticks(np.arange(0.5, 1, step=0.1))
        plt.legend(loc = 3, fontsize = 23)
        if mode =='5by20':
            plt.title('Subsampled dataset 1.1')
        if mode =='2by10':
            plt.title('Subsampled dataset 1.2')
        plt.tick_params(labelsize=16)

    
        for x, y in zip(train_img_mod, pred_pt_mod):
            label = "index = {}".format()
            plt.annotate(label, # this is the text
                 (x,y), # this is the point to label
                 textcoords="offset points", # how to position the text
                 xytext=(0,12), # distance from text to points (x,y)
                 ha='center',
                 fontsize=14,
                 fontname='sans')
    
        plt.subplot(2,2,2+idx*2)
        for i in range(len(train_pt_mod)):
            plt.scatter(train_pt_mod[i], pred_pt_mod[i], c = colors[i], s = 600, label = 'dataset index = {}'.format(i+1))
        plt.plot(train_pt_mod, pred_pt_mod, linewidth = 7, c= 'k')
        if idx == 1:
            plt.xlabel('Training Point Cloud Dataset Silhouette Score', labelpad = 25)

        plt.yticks(np.arange(0.25, 0.9, step=0.05))

        if mode =='5by20':
            plt.title('Subsampled dataset 1.1')
        if mode =='2by10':
            plt.title('Subsampled dataset 1.2')
    
        plt.legend(loc = 3, fontsize = 23)
        plt.tick_params(labelsize=16)
    plt.tight_layout()
    plt.savefig('../img/cluster/sparse/cluster_score_method1.png')
    '''

    #############Visualization of Point cloud#####################################################
    ##############################################################################################
    ##############################################################################################

    mode = '5by20'
    # mode = '2by10'

    if mode == '2by10':
        sparselist = [1, 5, 10, 15, 20, 30, 35]
    elif mode == '5by20':
        sparselist = [1, 2, 3, 4, 5, 6, 7]

    #fig.suptitle('Distance Matrices',fontsize=30,fontname="Times New Roman")
    for row, sparse in enumerate(sparselist):
        titles = [
            "train_img_similarity_matrix", "train_pt_similarity_matrix",
            "results_pt_similarity_matrix"
        ]
        for col, title in enumerate(titles):
            if mode == '5by20':
                d = np.load(
                    "../experiment/cubesphere_5by20_resnet/results_{0}/{1}.npy"
                    .format(sparse, title))
            elif mode == '2by10':
                d = np.load(
                    "../experiment/cubesphere_2by10_resnet/results_{0}/{1}.npy"
                    .format(sparse, title))
            print(np.max(np.max(d, axis=1)))
            #plt.subplot(len(sparselist), 3, row  * 3 + col + 1)
            fig, ax = plt.subplots(figsize=(12, 12))
            im = ax.imshow(d,
                           interpolation='nearest',
                           cmap=plt.cm.get_cmap("jet"))
            # plt.imshow(d, interpolation='nearest', cmap=plt.cm.get_cmap("jet"))
            # plt.xticks([])
            # plt.yticks([])

            ax.set_xticks([])
            ax.set_yticks([])

            aspect = 20
            pad_fraction = 0.5

            #plt.tight_layout()
            if col == 0:
                v = np.linspace(0, 8000, 10, endpoint=True)
                # plt.clim(0, 8000)
                # cb = plt.colorbar(ticks=v)
                # cb = plt.colorbar(im,fraction=0.046, pad=0.04)
                # cb.set_ticks([v])
                # #np.arange(0, 8000, 2)
                # plt.colorbar().ax.set_yticklabels(v,
                #     fontsize=16, weight='bold')
                #plt.colorbar().ax.tick_params(labelsize =14)
                # cb.ax.tick_params(labelsize=16)
                divider = make_axes_locatable(ax)
                width = axes_size.AxesY(ax, aspect=1. / aspect)
                pad = axes_size.Fraction(pad_fraction, width)
                cax = divider.append_axes("right", size=width, pad=pad)
                cbar = plt.colorbar(im, cax=cax, ticks=v)
                cbar.remove()
                # cbar.ax.locator_params(nbins=4)
                # # cbar.set_clim(0, 9000)
                # # cbar.set_ticks([0, 3000, 6000])
                # cbar.ax.tick_params(labelsize=16)
                # cbar.set_ticklabels([0, 4000, 8000])
                # plt.ylabel("dataset index = {0}".format(row+1), fontsize=45, labelpad = 30,fontweight ='bold',fontname="Times New Roman")
            elif col == 1:
                v = np.linspace(0, 0.45, 10, endpoint=True)
                divider = make_axes_locatable(ax)
                width = axes_size.AxesY(ax, aspect=1. / aspect)
                pad = axes_size.Fraction(pad_fraction, width)
                cax = divider.append_axes("right", size=width, pad=pad)
                cbar = plt.colorbar(im, cax=cax, ticks=v)
                cbar.remove()
                # # cbar.set_ticks([0, 0.18, 0.36])
                # cbar.ax.tick_params(labelsize=16)
                # cbar.ax.locator_params(nbins=4)
                # #plt.clim(0, 0.6) np.arange(0, 0.6, 2)
                # # plt.colorbar().ax.set_yticklabels(v,
                # #     fontsize=16, weight='bold')
                # #plt.colorbar().ax.tick_params(labelsize =14)
            elif col == 2:
                v = np.linspace(0, 0.45, 10, endpoint=True)
                divider = make_axes_locatable(ax)
                width = axes_size.AxesY(ax, aspect=1. / aspect)
                pad = axes_size.Fraction(pad_fraction, width)
                cax = divider.append_axes("right", size=width, pad=pad)
                cbar = plt.colorbar(im, cax=cax, ticks=v)  #, ticks=v)
                cbar.remove()
                # cbar.ax.locator_params(nbins=4)
                # cbar.ax.tick_params(labelsize=16)
                # #plt.clim(0, 0.6)np.arange(0, 0.6, 2)
                # # plt.colorbar().ax.set_yticklabels(v,
                # #     fontsize=16, weight='bold')
                # #plt.colorbar().ax.tick_params(labelsize =14)
            '''
            if row == 0 and col ==0:
                plt.title('Training Image',fontsize=48, fontweight = 'bold', fontname="Times New Roman",pad=30)
            if row == 0 and col ==1:
                plt.title('Training Point Cloud',fontsize=48,fontweight = 'bold',fontname="Times New Roman",pad=30)
            if row == 0 and col ==2:
                plt.title('Predicted Point Cloud',fontsize=48,fontweight = 'bold',fontname="Times New Roman",pad=30)
            '''
            plt.savefig('cluster_{}_{}_{}.png'.format(mode, row + 1, col + 1),
                        bbox_inches='tight')

    print("cluster-level higher means more separate clusters")
    '''