Exemplo n.º 1
0
def _png_filename(F, k, directory='./study'):
    '''Return the png path used by ``export`` for the configuration (F, k).

    Both parameters are encoded as zero-padded thousandths, e.g. F=0.029 and
    k=0.057 give ``./study/F029-k057.png``.  The products are rounded rather
    than truncated: ``1000 * 0.029`` evaluates to ``28.999999999999996`` in
    floating point, which ``int`` alone would turn into 28.
    '''
    return '{}/F{:03d}-k{:03d}.png'.format(directory,
                                            int(round(1000 * F)),
                                            int(round(1000 * k)))


def export(data, F, k, directory='./study'):
    '''Write data to a png image

    Each array element becomes one pixel: the figure is rendered at 72 dpi
    with a size of (columns / 72, rows / 72) inches.  The data is drawn with
    the ``RdBu_r`` colormap and colour limits fixed to [0, 1].

    Arguments
    ---------
    data : numpy.ndarray
        array containing the data to be written as png image; its first two
        dimensions are (rows, columns)
    F : float
        feed rate of the current configuration
    k : float
        rate constant of the current configuration
    directory : str, optional
        output directory, ``'./study'`` by default (the previously hard-coded
        location); the file is named ``F<1000*F>-k<1000*k>.png``
    '''
    # Matplotlib's figsize is (width, height) whereas data.shape is
    # (rows, cols); swap them so non-square data is not distorted.
    rows, cols = data.shape[:2]
    fig = plt.figure(figsize=(cols / 72.0, rows / 72.0), dpi=72.0,
                     facecolor='white')
    fig.add_axes([0, 0, 1, 1], frameon=False)
    plt.xticks([])
    plt.yticks([])

    plt.imshow(data, cmap=plt.cm.RdBu_r, interpolation='bicubic')
    plt.gci().set_clim(0, 1)

    plt.savefig(_png_filename(F, k, directory), dpi=72.0)
    plt.close(fig)
Exemplo n.º 2
0
def plot(f, n=100):
    '''Plot a function over the standard quadrilateral [-1, 1] x [-1, 1].

    ``f`` is called once with the stacked ``(2, n, n)`` array of grid
    coordinates; its result is drawn with ``tripcolor`` on the 'coolwarm'
    colormap with colour limits made symmetric about zero, and the outline
    of the quadrilateral is drawn in black.
    '''
    import matplotlib.tri
    import matplotlib.pyplot as plt

    coords = numpy.linspace(-1, +1, n)
    X, Y = numpy.meshgrid(coords, coords)
    values = numpy.array(f(numpy.stack([X, Y])), dtype=float)

    mesh = matplotlib.tri.Triangulation(X.ravel(), Y.ravel())
    plt.tripcolor(mesh, values.ravel(), shading='flat')
    plt.colorbar()

    # A diverging colormap makes the zero level easy to spot ...
    plt.set_cmap('coolwarm')
    # ... as long as the colour range is centred on zero.
    low, high = plt.gci().get_clim()
    bound = max(abs(low), abs(high))
    plt.clim(-bound, bound)

    # Closed outline of the reference quadrilateral.
    outline_x = [-1, 1, 1, -1, -1]
    outline_y = [-1, -1, 1, 1, -1]
    plt.plot(outline_x, outline_y, '-k')

    plt.gca().set_aspect('equal')
    plt.axis('off')
Exemplo n.º 3
0
    def show_matched_boxes(self, image, anchor_polygons, gt_polygons,
                           scale=512):
        """Show each anchor polygon together with its matched ground truth.

        For every (anchor, ground truth) pair a separate figure is shown:
        the image with the anchor outlined in red and the ground truth in
        green.  Each figure is closed once ``plt.show()`` returns.

        Args:
            image (np.array): image passed to ``plt.imshow``; grayscale
                (2-D) and colour (3-D) images are both accepted.
            anchor_polygons (np.array): in format (x1, y1, x2, y2, x3, y3, x4, y4),
                presumably normalized coordinates -- TODO confirm.
            gt_polygons (np.array): in format (x1, y1, x2, y2, x3, y3, x4, y4)
            scale (float or sequence of two floats): factor converting the
                polygon coordinates to pixels.  Defaults to 512, the value
                that used to be hard-coded; pass ``(width, height)`` to scale
                x and y independently.

        Returns:
            None
        """
        # The former ``h, w, c = image.shape`` was never used and raised a
        # ValueError for 2-D (grayscale) images, so it has been removed.
        for anchor_polygon, gt_polygon in zip(anchor_polygons, gt_polygons):
            plt.imshow(image)
            ax = plt.gca()
            # Reshape to (n_vertices, 2) first so a (sx, sy) scale broadcasts
            # per coordinate; a scalar scale gives the same result as before.
            anchor_xy = anchor_polygon.reshape((-1, 2)) * scale
            gt_xy = gt_polygon.reshape((-1, 2)) * scale

            ax.add_patch(
                plt.Polygon(anchor_xy,
                            fill=False,
                            color=[1.0, 0.0, 0.0, 1.0],
                            linewidth=1))

            ax.add_patch(
                plt.Polygon(gt_xy,
                            fill=False,
                            color=[0.0, 1.0, 0.0, 1.0],
                            linewidth=1))

            plt.show()
            plt.close()
Exemplo n.º 4
0
 def plot_results(self, results=None, classes=None, show_labels=True, gt_data=None, confidence_threshold=None):
     """Draw detection results (and optionally ground truth) on the current image.

     Args:
         results: array of detections, one row per box.  Columns 0-3 hold
             the box (converted to polygon vertices with ``to_rec``),
             column 4 the confidence and column 5 the class index.
             Defaults to ``self.results``.
         classes: optional sequence of class names.  When given, every
             class gets its own colour sampled from the hsv colormap and
             its name is used as label.  Otherwise predictions are red,
             ground truth is green and labels show the numeric index.
         show_labels: write "<confidence>, <label>" at each predicted box.
         gt_data: optional ground-truth rows: 4 box values followed by
             per-class indicators.  The class is the index of the first
             non-zero indicator plus one (presumably to skip a background
             class at index 0 -- TODO confirm).
         confidence_threshold: if given, only detections whose confidence
             is strictly greater are drawn.
     """
     if results is None:
         results = self.results
     if confidence_threshold is not None:
         results = results[results[:, 4] > confidence_threshold]
     # Use ``is None``: ``classes == None`` is evaluated element-wise when
     # ``classes`` is a NumPy array, and the resulting array cannot be used
     # as a condition (ValueError: truth value is ambiguous).
     has_classes = classes is not None
     if has_classes:
         colors = plt.cm.hsv(np.linspace(0, 1, len(classes)+1)).tolist()
     ax = plt.gca()
     image_size = plt.gci().get_size()

     # draw ground truth
     if gt_data is not None:
         for box in gt_data:
             label = np.nonzero(box[4:])[0][0]+1
             color = colors[label] if has_classes else 'g'
             xy_rec = to_rec(box[:4], image_size)
             ax.add_patch(plt.Polygon(xy_rec, fill=True, color=color, linewidth=1, alpha=0.3))

     # draw prediction
     for r in results:
         label = int(r[5])
         confidence = r[4]
         color = colors[label] if has_classes else 'r'
         xy_rec = to_rec(r[:4], image_size)
         ax.add_patch(plt.Polygon(xy_rec, fill=False, edgecolor=color, linewidth=2))
         if show_labels:
             label_name = classes[label] if has_classes else label
             xmin, ymin = xy_rec[0]
             display_txt = '%0.2f, %s' % (confidence, label_name)
             ax.text(xmin, ymin, display_txt, bbox={'facecolor':color, 'alpha':0.5})
Exemplo n.º 5
0
def locate_spot(event):
    """
    Callback function for key_press_event from matplotlib.

    Stores the location of a spot on the graph, for later use.  Press "m"
    with the cursor over an axes.  The rounded cursor position is passed
    through ``fc`` together with the current image data.  The first result
    is appended to ``spot_locs_one`` or ``spot_locs_two``, depending on
    whether the key press came from figure ``fignum_one`` or ``fignum_two``,
    and is marked with a red dot.

    Raises:
        ValueError: if the key press came from any other figure.
    """
    # Ignore other keys, and presses outside every axes: there
    # event.xdata/ydata are None and np.rint would fail.
    if event.key != 'm' or event.inaxes is None:
        return

    canv = event.canvas
    fignum = canv.figure.number
    # Have to manually reset focus to the figure or else it breaks after first use
    plt.figure(fignum)
    ax = event.inaxes
    # np.int was removed in NumPy 1.24; the builtin int is the replacement.
    x = int(np.rint(event.xdata))
    y = int(np.rint(event.ydata))

    # append in opposite order because y = row, x = column
    adj_locs = fc(plt.gci().get_array().data, [[y, x]])
    if fignum == fignum_one:
        spot_locs_one.append(adj_locs[0])
    elif fignum == fignum_two:
        spot_locs_two.append(adj_locs[0])
    else:
        message = (
            'Key press was detected in figure number {}, but only figure numbers {} and {} are valid.  Close the program and try again?'
            .format(fignum, fignum_one, fignum_two))
        print(message)
        raise ValueError(message)

    ax.plot(adj_locs[-1][1], adj_locs[-1][0], 'ro')
    canv.draw()
    print('Point: ({:d}, {:d}) Figure: {}'.format(adj_locs[-1][1],
                                                  adj_locs[-1][0],
                                                  fignum))
Exemplo n.º 6
0
def sync_axes(axes_list, axis, lim=()):
    """Give every axes in ``axes_list`` the same limits along one axis.

    ``axis`` picks the x limits ('x'), the y limits ('y') or the colour
    limits of the current image ('c').  If ``lim`` is empty, the union of
    the current limits of all axes is used, widened by 1e-3 when it
    collapses to a single value.  Otherwise ``(lim[0], lim[1])`` is applied.
    Every axes is made current with ``plt.sca`` in turn, so the last one
    stays current afterwards.
    """
    assert axis in ['x', 'y', 'c']
    # Lambdas keep ``plt`` lookups lazy: nothing touches pyplot unless
    # there are axes to visit.
    read_limits = {
        'x': lambda: plt.gca().get_xlim(),
        'y': lambda: plt.gca().get_ylim(),
        'c': lambda: plt.gci().get_clim(),
    }[axis]
    write_limits = {
        'x': lambda lo, hi: plt.xlim(lo, hi),
        'y': lambda lo, hi: plt.ylim(lo, hi),
        'c': lambda lo, hi: plt.clim(lo, hi),
    }[axis]

    if len(lim) != 0:
        low, high = lim[0], lim[1]
    else:
        low, high = np.inf, -np.inf
        for current_ax in axes_list:
            plt.sca(current_ax)
            current = read_limits()
            low = min(current[0], low)
            high = max(current[1], high)
        if low == high:
            high = low + 1e-3

    for current_ax in axes_list:
        plt.sca(current_ax)
        write_limits(low, high)
Exemplo n.º 7
0
def plot_box(box,
             box_format='xywh',
             color='r',
             linewidth=1,
             normalized=False,
             vertices=False):
    """Outline a box on the current axes.

    Args:
        box: coordinates interpreted according to ``box_format``:
            'xywh'    -- (xmin, ymin, width, height), as used by OpenCV;
            'xyxy'    -- (xmin, ymin, xmax, ymax);
            'polygon' -- flat (x1, y1, x2, y2, ...) vertex list.
        box_format: one of 'xywh', 'xyxy' or 'polygon'.
        color: edge colour of the outline.
        linewidth: line width of the outline.
        normalized: if True, the coordinates are fractions of the current
            image (``plt.gci()``).  x is scaled by the image width and y by
            the image height.
        vertices: if True, mark the first (up to four) vertices with dots
            coloured red, green, blue and yellow.

    Raises:
        ValueError: if ``box_format`` is not one of the supported values.
    """
    if box_format == 'polygon':
        xy_rec = np.reshape(box, (-1, 2))
    else:
        if box_format == 'xywh':  # opencv
            xmin, ymin, w, h = box
            xmax, ymax = xmin + w, ymin + h
        elif box_format == 'xyxy':
            xmin, ymin, xmax, ymax = box
        else:
            # Previously this fell through to an UnboundLocalError on xmin.
            raise ValueError(
                "box_format must be 'xywh', 'xyxy' or 'polygon', "
                "got {!r}".format(box_format))
        xy_rec = np.array([[xmin, ymin], [xmax, ymin], [xmax, ymax],
                           [xmin, ymax]])
    if normalized:
        # AxesImage.get_size() is (rows, cols) == (height, width).  x must be
        # scaled by the width and y by the height.  Broadcasting also works
        # for polygons with any number of vertices.
        img_h, img_w = plt.gci().get_size()
        xy_rec = xy_rec * np.array([img_w, img_h])
    ax = plt.gca()
    ax.add_patch(
        plt.Polygon(xy_rec, fill=False, edgecolor=color, linewidth=linewidth))
    if vertices:
        # zip stops early for polygons with fewer than four vertices.
        for (x, y), c in zip(xy_rec, 'rgby'):
            plt.plot(x, y, c, marker='o', markersize=4)
Exemplo n.º 8
0
    def plot_assignment(self, map_idx):
        """Visualise prior-to-ground-truth matches for one feature map.

        On the current image this draws:
        - every ground-truth box (blue outline) and its centre (blue dot);
        - for every matched prior whose index lies in
          ``[map_offsets[map_idx], map_offsets[map_idx + 1])``: its centre
          (red dot), a red line to the centre of its matched ground-truth
          box, and its outline (yellow).

        The ground-truth centres are computed as the midpoints of columns
        0/2 and 1/3 scaled by the image size, presumably normalized xyxy
        boxes -- TODO confirm.
        """
        axes = plt.gca()
        image_h, image_w = image_size = plt.gci().get_size()

        # ground truth
        gt = self.gt_boxes
        centres_x = (gt[:, 0] + gt[:, 2]) / 2. * image_w
        centres_y = (gt[:, 1] + gt[:, 3]) / 2. * image_h
        for gt_box in gt:
            outline = to_rec(gt_box[:4], image_size)
            axes.add_patch(
                plt.Polygon(outline, fill=False, edgecolor='b', linewidth=2))
        plt.plot(centres_x, centres_y, 'bo', markersize=6)

        # prior boxes
        for prior_idx, gt_idx in self.match_indices.items():
            in_map = (self.map_offsets[map_idx] <= prior_idx
                      < self.map_offsets[map_idx + 1])
            if not in_map:
                continue
            cx, cy = self.priors_xy[prior_idx]
            pw, ph = self.priors_wh[prior_idx]
            plt.plot(cx, cy, 'ro', markersize=4)
            plt.plot([cx, centres_x[gt_idx]], [cy, centres_y[gt_idx]],
                     '-r',
                     linewidth=1)
            corner = (cx - pw / 2, cy - ph / 2)
            axes.add_patch(
                plt.Rectangle(corner,
                              pw + 1,
                              ph + 1,
                              fill=False,
                              edgecolor='y',
                              linewidth=2))
Exemplo n.º 9
0
def plot(f, n=100, d=1.0):
    '''Plot a function over the square [-d, d] x [-d, d].

    ``f`` is called once with the stacked ``(2, n, n)`` array of grid
    coordinates; the result is drawn with ``tripcolor`` on the 'coolwarm'
    colormap with colour limits made symmetric about zero.
    '''
    import matplotlib.tri
    import matplotlib.pyplot as plt

    coords = numpy.linspace(-d, +d, n)
    X, Y = numpy.meshgrid(coords, coords)
    values = numpy.array(f(numpy.stack([X, Y])), dtype=float)

    mesh = matplotlib.tri.Triangulation(X.ravel(), Y.ravel())
    plt.tripcolor(mesh, values.ravel(), shading='flat')
    plt.colorbar()

    # A diverging colormap makes the zero level easy to spot ...
    plt.set_cmap('coolwarm')
    # ... as long as the colour range is centred on zero.
    low, high = plt.gci().get_clim()
    bound = max(abs(low), abs(high))
    plt.clim(-bound, bound)

    plt.gca().set_aspect('equal')
    plt.axis('off')
Exemplo n.º 10
0
 def plot_gt(self, boxes, show_labels=True):
     """Draw ground-truth boxes or polygons on the current image.

     Args:
         boxes: either an integer sample index into ``self.data`` (Python or
             NumPy integer) or an array of rows.  In each row the last
             element is the class index and the others are normalized
             coordinates: (xmin, ymin, xmax, ymax) for boxes, or a flat
             (x1, y1, x2, y2, ...) polygon when more than four are present.
         show_labels: if True, write the class name at the first vertex.
             For polygons the text is rotated along the first edge.
     """
     # if parameter is sample index (NumPy integers are accepted too)
     if isinstance(boxes, (int, np.integer)):
         boxes = self.data[boxes]

     ax = plt.gca()
     # get_size() returns (rows, cols), i.e. (height, width).
     img_h, img_w = plt.gci().get_size()

     for box in boxes:
         class_idx = int(box[-1])
         color = self.colors[class_idx]
         is_polygon = len(box)-1 > 4
         if is_polygon:
             xy = box[:-1].reshape((-1,2))
         else:
             xmin, ymin, xmax, ymax = box[:4]
             xy = np.array([[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]])
         # x scales with the image width, y with the image height.
         xy = xy * [img_w, img_h]
         ax.add_patch(plt.Polygon(xy, fill=False, edgecolor=color, linewidth=1))
         if show_labels:
             label_name = self.classes[class_idx]
             if is_polygon:
                 # ``eps`` is a small constant defined elsewhere in the
                 # module; it guards against division by zero for vertical
                 # first edges.
                 angle = np.arctan((xy[1,0]-xy[0,0])/(xy[1,1]-xy[0,1]+eps))
                 if angle < 0:
                     angle += np.pi
                 angle = angle/np.pi*180-90
             else:
                 angle = 0
             ax.text(xy[0,0], xy[0,1], label_name, bbox={'facecolor':color, 'alpha':0.5}, rotation=angle)
Exemplo n.º 11
0
def makeGuessAboutCmap(clim=None, colormap=None):
    """Pick colour limits and a colormap for the current image.

    Args:
        clim: optional (vmin, vmax), used as-is when truthy.  Otherwise the
            current image's limits are read via ``plt.gci()``.  When they
            have opposite signs and the smaller magnitude exceeds half of
            the larger, they are made symmetric about zero.
        colormap: optional colormap name; when falsy one is guessed from the
            limits.

    Equal limits are widened to (v - 1, v + 1).  The guessed colormap is:
    - 'plasma' for non-negative limits with vmin < vmax / 3;
    - 'plasma_r' for the mirrored non-positive case;
    - 'seismic' when the midpoint of the limits lies within 0.5% of their
      range from zero;
    - 'viridis' otherwise.
    The axes background is set to grey.

    Returns:
        tuple: the (vmin, vmax) applied with ``plt.clim``.
    """
    if clim:
        vmin, vmax = clim[0], clim[1]
    else:
        vmin, vmax = plt.gci().get_clim()
        cut_off_frac = 0.5
        if -vmin < vmax and -vmin / vmax > cut_off_frac:
            vmin = -vmax
        elif -vmin > vmax and -vmax / vmin > cut_off_frac:
            vmax = -vmin
    if vmin == vmax:
        if debug:
            print('vmin,vmax=', vmin, vmax)
        vmin, vmax = vmin - 1, vmax + 1
    plt.clim(vmin, vmax)
    if colormap:
        plt.set_cmap(colormap)
    else:
        single_signed = vmin * vmax >= 0
        if single_signed and vmax > 0 and 3 * vmin < vmax:
            guess = 'plasma'  # Single signed +ve data
        elif single_signed and vmin < 0 and 3 * vmax > vmin:
            guess = 'plasma_r'  # Single signed -ve data
        elif abs((vmax + vmin) / (vmax - vmin)) < .01:
            guess = 'seismic'  # Multi-signed symmetric data
        else:
            guess = 'viridis'
        plt.set_cmap(guess)
    plt.gca().set_facecolor([.5, .5, .5])  # grey "land" background
    return (vmin, vmax)
Exemplo n.º 12
0
def animateImages(images: list,
                  titles: list = None,
                  dpi: int = None,
                  save_name: str = "_temp",
                  frame_interval: int = 60,
                  repeat_delay: int = 1000,
                  cmap: Union[str, Colormap] = 'viridis',
                  verbose: bool = False):
    """Render a list of images as an animation and save it to disk.

    Args:
        images: non-empty list of arrays accepted by ``plt.imshow``; one
            frame per image.
        titles: optional per-frame titles.  Frames without a title
            (``titles`` is None or shorter than ``images``) stay untitled.
        dpi: resolution of the figure.
        save_name: output path without extension.  ``<save_name>.mp4`` is
            written first; if that raises ``ValueError`` (presumably when no
            mp4-capable writer is available), ``<save_name>.gif`` is written
            instead.
        frame_interval: delay between frames in milliseconds.
        repeat_delay: delay before the animation repeats, in milliseconds.
        cmap: colormap name or instance used for every frame.
        verbose: if True, print a progress counter while building frames.
    """
    assert type(images) is list, "Images should be a list."
    assert len(images) != 0, "Images is empty."

    from matplotlib import animation
    from sys import stdout

    _str_stdout = ""
    if verbose:
        _str_stdout = "Compiling Figures: 0 % / 100 %"
        stdout.write(_str_stdout)

    _frames = []
    n_img = len(images)
    n_titles = 0 if titles is None else len(titles)

    _fig = plt.figure(dpi=dpi)
    for i in range(n_img):
        # Pass cmap to imshow rather than calling plt.set_cmap, which would
        # also overwrite the global default colormap on every frame.
        image = plt.imshow(images[i], animated=True, cmap=cmap)
        plt.xticks([]), plt.yticks([])
        frame = [image]
        # Only real artists may go into a frame: the old "" placeholder
        # broke ArtistAnimation, and titles=None raised TypeError.
        if i < n_titles:
            frame.append(plt.text(0.5, 1.01, titles[i],
                                  horizontalalignment='center',
                                  verticalalignment='bottom',
                                  transform=plt.gca().transAxes))
        _frames.append(frame)
        if verbose:
            stdout.write("\b" * len(_str_stdout))
            _str_stdout = f"Compiling figures: " \
                          f"{int(np.round((i + 1) * 100 / n_img))} % / 100 %"
            stdout.write(_str_stdout)

    stdout.write("\nCreating animation object, this may take a while.")
    ani = animation.ArtistAnimation(_fig,
                                    _frames,
                                    interval=frame_interval,
                                    blit=True,
                                    repeat_delay=repeat_delay)

    stdout.write(f"\nSaving to figures {save_name}, this may take a while.")
    try:
        try:
            ani.save(f'{save_name}.mp4')
        except ValueError:
            ani.save(f'{save_name}.gif')
    finally:
        # The figure only exists to feed the animation; release it.
        plt.close(_fig)
Exemplo n.º 13
0
def _mask_deviation_matrix(S):
    """Return a float copy of ``S`` with entries where ``S + 0.1 <= 0`` set to NaN.

    Working on a float copy accepts integer input (the previous in-place
    ``+= 0.1`` raised a casting error for integer arrays).  Kept values are
    not round-tripped through ``+ 0.1 - 0.1``, which could perturb them.
    """
    masked = np.array(S, dtype=float)
    masked[masked + 10**-1 <= 0] = np.nan
    return masked


def _log_gray_values(num=256, log_comp=100, min_val=0.07):
    """Return ``num`` RGB gray levels falling logarithmically from ``1 - min_val`` to 0."""
    log_series = np.log(np.linspace(start=1, stop=log_comp, num=num))
    scaled_series = log_series / np.max(log_series) * (1 - min_val) + min_val
    gray_values = 1 - scaled_series
    return np.repeat(gray_values.reshape(num, 1), 3, axis=1)


def visualize_deviation_matrix(S,
                               fs_act,
                               list_iois,
                               xlim=None,
                               clim=None,
                               colorbar=True):
    """Show a deviation matrix as an image with the IOI envelope overlaid.

    Args:
        S: 2-D array.  Columns are tap indices; rows are deviations spanning
            ``[-tolerance, +tolerance]`` with ``tolerance = (rows - 1) // 2``
            (presumably an odd row count centred on zero -- confirm).
            Entries with ``S + 0.1 <= 0`` are left undrawn (NaN).
        fs_act: rows per second, used to convert the row axis to seconds.
        list_iois: sequence (list or array) of inter-onset intervals; +/- half
            of each is plotted as black lines.
        xlim: optional x limits.
        clim: optional colour limits; if None the image's own limits are
            returned.
        colorbar: add a colorbar.

    Returns:
        tuple: ``(im, clim)`` -- the AxesImage and its colour limits.
    """
    from matplotlib.colors import LinearSegmentedColormap

    S = np.asarray(S)
    tolerance = int((S.shape[0] - 1) / 2)
    left = -1 / 2
    right = S.shape[1] - 1 / 2
    lower = -tolerance / fs_act
    upper = tolerance / fs_act

    # White-to-black colormap with logarithmic compression.
    color_wb = LinearSegmentedColormap.from_list('color_wb',
                                                 _log_gray_values(),
                                                 N=256)

    im = plt.imshow(_mask_deviation_matrix(S),
                    aspect='auto',
                    extent=[left, right, lower, upper],
                    origin='lower',
                    cmap=color_wb)

    if colorbar:
        plt.colorbar()

    # np.asarray lets a plain Python list be halved and negated.
    half_iois = np.asarray(list_iois) / 2
    plt.plot(half_iois, 'k')
    plt.plot(-half_iois, 'k')

    plt.ylim([lower, upper])

    if xlim is not None:
        plt.xlim(xlim)

    # Work on the image we created instead of whatever plt.gci() returns.
    if clim is not None:
        im.set_clim(clim[0], clim[1])
    else:
        clim = im.get_clim()

    plt.xlabel('Tap index')
    plt.ylabel('Deviation (sec)')

    return im, clim
Exemplo n.º 14
0
def clim_std(stdfac=1,ih=None):
  """Set colour limits to median +/- stdfac * mad of an image's non-NaN pixels.

  ``ih`` is an image artist and defaults to the current image
  (``plt.gci()``).  ``mad`` is a helper defined elsewhere in the module
  (presumably the median absolute deviation -- confirm).  The figure is
  redrawn if matplotlib is in interactive mode.
  """
  image = ih if ih else plt.gci()
  pixels = image.get_array().ravel()
  valid = pixels[~np.isnan(pixels)]
  centre = np.median(valid)
  spread = mad(valid)
  image.set_clim([centre - stdfac * spread, centre + stdfac * spread])
  plt.draw_if_interactive()
Exemplo n.º 15
0
def save_image(path,
               data,
               colorbar=True,
               background=None,
               clim=None,
               title=None,
               size=(3, 3),
               dpi=500,
               xlim=None,
               ylim=None,
               sym=True):
    """Render ``data`` as an image and save it to ``path``.

    Rendering happens inside ``mpl.rc_context(rc=rc)`` (``rc`` is a style
    dict defined elsewhere in the module) with interactive mode switched
    off.

    Args:
        path: output file passed to ``Figure.savefig``.
        data: 2-D array to display.
        colorbar: add a colorbar for the last drawn image.
        background: optional image drawn first with the 'Greys' colormap.
            ``data`` is then overlaid after conversion with the module's
            ``colormap`` helper (using ``clim``).
        clim: colour limits.  NOTE: ignored when ``background`` is None and
            ``sym`` is True, because the limits are then derived from the
            data.
        title: optional title, drawn in an xx-small font.
        size: figure size in inches.
        dpi: output resolution.
        xlim, ylim: if both are given they form the image ``extent``.
        sym: without a background, use the 'coolwarm' colormap with limits
            symmetric about zero.
    """
    with mpl.rc_context(rc=rc):
        plt.ioff()
        fig = plt.figure()
        try:
            extent = [*xlim, *ylim] if xlim and ylim else None

            if background is None:
                if sym:
                    image = plt.imshow(data, extent=extent, cmap='coolwarm')
                    low, high = image.get_clim()
                    bound = max(abs(low), abs(high))
                    image.set_clim([-bound, bound])
                else:
                    plt.imshow(data, clim=clim, extent=extent)
            else:
                plt.imshow(background, cmap='Greys', extent=extent)
                plt.imshow(colormap(data, clim=clim), extent=extent)

            if title:
                plt.title(title, fontdict={'fontsize': 'xx-small'})

            if colorbar:
                plt.colorbar()

            fig.set_size_inches(size)
            fig.savefig(path, dpi=dpi, facecolor='white')
        finally:
            # Close this figure explicitly, even when rendering or saving
            # failed, so repeated calls do not accumulate open figures.
            plt.close(fig)
Exemplo n.º 16
0
Arquivo: plot.py Projeto: w-klijn/csa
def inverseGray():
    '''
    Make inverted gray ('gray_r') the default image colormap and apply it to
    the current image, if there is one.  The figure is redrawn when
    matplotlib is in interactive mode.

    See help(colormaps) for more information
    '''
    _plt.rc('image', cmap='gray_r')
    current_image = _plt.gci()
    if current_image is not None:
        current_image.set_cmap(_plt.cm.gray_r)
    _plt.draw_if_interactive()
Exemplo n.º 17
0
def plotspectrogram(data: np.ndarray, fs: int, nmode: str,
                    nfft: int = None, noverlap: int = 500):
    """Plot the spectrogram of ``data`` in a new figure, with a colorbar.

    Args:
        data: 1-D signal.
        fs: sampling frequency in Hz.
        nmode: noise name used in the title ("<nmode> noise").
        nfft: FFT segment length; defaults to the module-level ``NFFT``.
        noverlap: number of samples shared by consecutive segments
            (previously hard-coded to 500); matplotlib requires it to be
            smaller than ``nfft``.

    Returns:
        The created figure.
    """
    if nfft is None:
        nfft = NFFT
    fg = figure()
    ax = fg.gca()

    # Only the image is needed; spectrum, frequencies and bins are unused.
    *_, im = specgram(data, NFFT=nfft, Fs=fs, noverlap=noverlap)

    # Use the returned image directly rather than re-fetching it with gci().
    fg.colorbar(im, ax=ax)

    ax.set_title(f"{nmode} noise")
    ax.set_xlabel("time (sec.)")
    ax.set_ylabel("frequency (Hz)")
    return fg
Exemplo n.º 18
0
def linecut(width=1, plot=True, pts=None, data=None,
            avg_profiles=True, mode='nearest', **kwargs):
    """Sample intensity profile(s) along a straight line through 2-D data.

    Args:
        width: number of parallel lines sampled.  They are spaced one unit
            apart perpendicular to the cut and centred on it.
        plot: when the data comes from an image, draw the sampled lines on
            that image's axes.
        pts: ``((x1, y1), (x2, y2))`` end points; x indexes axis 0 of
            ``data`` and y axis 1.  Image arrays are transposed first, so
            x is the column and y the row.  If None, two points are picked
            interactively with ``PointPicker`` (mouse button 3).
        data: an image artist (anything with ``get_array``), a 2-D array,
            or None for the current image (``plt.gci()``).
        avg_profiles: return the mean of the profiles instead of all of them.
        mode, **kwargs: forwarded to ``scipy.ndimage.map_coordinates``.

    Returns:
        The mean profile (1-D, one sample per unit of line length) or, if
        ``avg_profiles`` is False, an array of shape ``(width, N)``.
    """
    # The body consistently uses this local ``np``; it previously mixed in
    # an unimported module-level ``numpy`` name (NameError).
    import numpy as np
    from scipy import ndimage

    if hasattr(data, 'get_array'):
        im = data
        from_image = True
    elif data is None:
        im = plt.gci()
        assert im is not None, 'No current image, so provide explicit `data`!'
        from_image = True
    else:
        data = np.asarray(data)
        from_image = False

    if from_image:
        data = im.get_array().T  # Transpose because axes are switched in image

    if pts is not None:
        pt1, pt2 = pts
    else:
        PP = PointPicker(max_pts=2, verbose=True, mousebutton=3)
        pt1, pt2 = PP.get_points()
    (x1, y1), (x2, y2) = pt1, pt2
    # One sample per unit of line length.
    N = int(np.sqrt((x2 - x1)**2 + (y2 - y1)**2))

    # Unit vector perpendicular to the cut, used to offset parallel lines.
    angle = np.arctan2(y2 - y1, x2 - x1)
    dx, dy = np.sin(angle), -np.cos(angle)

    profiles = []
    for lineno in range(int(width)):
        xoffset = (width / 2 - .5 - lineno) * dx
        yoffset = (width / 2 - .5 - lineno) * dy
        xi, yi = x1 + xoffset, y1 + yoffset
        xf, yf = x2 + xoffset, y2 + yoffset

        xs = np.linspace(xi, xf, N)
        ys = np.linspace(yi, yf, N)
        if from_image and plot:
            # Convert array indices to the image's data coordinates.
            X1, X2, Y1, Y2 = im.get_extent()
            dX = (X2 - X1) / data.shape[0]
            Xs = X1 + xs * dX
            dY = (Y2 - Y1) / data.shape[1]
            Ys = Y1 + ys * dY
            im.axes.plot(Xs, Ys, color='k', alpha=.5)

        profile = ndimage.map_coordinates(data, np.vstack((xs, ys)),
                                          mode=mode, **kwargs)
        profiles.append(profile)

    if avg_profiles:
        return np.mean(profiles, axis=0)
    return np.array(profiles)
def onclick(event):
    """Handle a click on the (mass, age) grid figure.

    The clicked cell selects the spectrum file
    ``<path>spec_<ATM_TYPE>_mass_<mass>_age_<age>.txt``.  Column 0 of that
    file (rows 1:) holds initial-entropy values, which are shown as a
    one-row colour strip in figure 2.  Clicking a cell of that strip
    (handled by the nested ``onclick2``) plots the matching spectrum
    against ``lambdas`` in figure 3, coloured like the clicked cell.

    Relies on module-level names defined elsewhere: ``mass_vec``,
    ``mass_vec_edges``, ``age_vec``, ``age_vec_edges``, ``path``,
    ``ATM_TYPE``, ``get_cell_edges``, ``lambdas``, ``plticker`` and
    ``cid2``.  ``cid2`` holds the connection id of the strip-click handler
    and is replaced on every call.
    """
    global cid2
    # Clicks outside any axes carry no data coordinates.
    if event.xdata is None or event.ydata is None:
        return
    x = event.xdata
    y = event.ydata
    # Index of the last cell edge below the click, i.e. the clicked cell.
    xndx = np.where(mass_vec_edges < x)[0][-1]
    yndx = np.where(age_vec_edges < y)[0][-1]
    # np.int was removed in NumPy 1.24; the builtin int is equivalent.
    mass = int(mass_vec[xndx])
    age = int(age_vec[yndx])
    mass_str = str(mass).zfill(3)
    age_str = str(age).zfill(4)
    name_str = path + "spec_" + ATM_TYPE + "_mass_" + mass_str + "_age_" + age_str + ".txt"
    print(name_str)
    data = np.loadtxt(name_str)
    this_Si_vec = data[1:, 0]
    this_Si_vec_edges = get_cell_edges(this_Si_vec)
    yval_edges = np.array([-1., 1.])
    xval_edges = this_Si_vec_edges
    xvalA, yvalA = np.meshgrid(xval_edges, yval_edges)
    zvalA = np.zeros_like(xvalA)
    zvalA[0, :len(this_Si_vec)] = this_Si_vec
    plt.figure(2, figsize=(8, 3))
    plt.pcolor(xvalA, yvalA, zvalA, edgecolors='k')
    hh = plt.gci()
    ax2 = plt.gca()
    loc = plticker.MultipleLocator(base=0.5)
    # This used ``ax``, which is not defined in this function, so the
    # locator went to some unrelated module-level axes instead of the strip.
    ax2.xaxis.set_major_locator(loc)
    ax2.set_xlim([min(xval_edges), max(xval_edges)])
    ax2.yaxis.set_major_locator(plt.NullLocator())
    title_str = 'Mass: ' + str(mass) + ' Mjup; Age: ' + str(age) + ' Myr; ' + \
                ' Click to select the initial entropy.'
    plt.title(title_str)
    plt.xlabel('Initial Entropy (kB/baryon)')

    def onclick2(event2):
        """Plot the spectrum of the clicked initial-entropy cell in figure 3."""
        x2 = event2.xdata
        if x2 is None:
            return
        print(x2, this_Si_vec_edges)
        rgb = hh.cmap(hh.norm(x2))[:-1]  # last element is alpha (transparency)
        xndx2 = np.where(this_Si_vec_edges < x2)[0][-1]
        print(len(this_Si_vec_edges), '****', name_str)
        this_spec_row_ndx = 1 + xndx2
        print(this_spec_row_ndx)
        this_spec = data[this_spec_row_ndx, 1:]
        plt.figure(3, figsize=(6, 6))
        plt.plot(lambdas, this_spec, color=rgb)
        plt.xlabel('Wavelength (microns)')
        plt.ylabel('Flux (mJy)')
        print("***********", x2, rgb, this_spec_row_ndx, this_spec[:5])
        plt.show()

    plt.ion()
    if cid2 is not None:
        plt.disconnect(cid2)
    cid2 = plt.connect('button_press_event', onclick2)
    plt.show()
Exemplo n.º 20
0
def divergentMesh(X,
                  Y,
                  Z,
                  xlabel='x',
                  ylabel='y',
                  vmax=None,
                  generateCBAR=True,
                  zOrder=None,
                  plotOptions=None):
    """Draw ``Z`` with ``pcolormesh`` on ``fireIce_cmap`` with colour limits (-vmax, vmax).

    Args:
        X, Y, Z: mesh coordinates and values, forwarded to ``plt.pcolormesh``.
        xlabel, ylabel: axis labels.
        vmax: half-width of the symmetric colour range; defaults to
            ``max(|Z|)``.
        generateCBAR: add a colorbar and return it.  When False, the upper
            colour limit of the current image (``plt.gci()``) is reused as
            ``vmax`` so this mesh shares a previously drawn colour scale.
            If that image has no limits yet, ``vmax`` applies as above.
        zOrder: optional z-order of the mesh.
        plotOptions: extra keyword arguments for ``pcolormesh``.  They may
            not repeat the ones set here.

    Returns:
        The colorbar if ``generateCBAR`` is True, otherwise None.
    """
    clim_low = clim_high = None
    if not generateCBAR:
        clim_low, clim_high = plt.gci().get_clim()
    if clim_low is not None:
        vmax = clim_high
    elif vmax is None:
        # Computed only when needed, so an explicit vmax works for any Z.
        vmax = np.abs(Z).max()

    mesh_kwargs = dict(cmap=fireIce_cmap,
                       vmin=-vmax,
                       vmax=vmax,
                       linewidth=0,
                       rasterized=True)
    if zOrder is not None:
        mesh_kwargs['zorder'] = zOrder
    # Two ** unpackings keep the old behaviour: a key repeated in
    # plotOptions still raises TypeError instead of silently overriding.
    plt.pcolormesh(X, Y, Z, **mesh_kwargs, **(plotOptions or {}))
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    if generateCBAR:
        return plt.colorbar()
    return None
Exemplo n.º 21
0
 def draw(self):
     """Draw the reaction.

     ``array_a`` ('twilight' colormap) and ``array_b`` ('cool' colormap) are
     drawn as translucent (alpha 0.3), bicubically interpolated images.
     Both share the extent ``[0, rows, 0, cols]``, which is also used as the
     axis range; ticks are hidden.  The last drawn image (``array_b``) is
     stored in ``self.image``.
     """
     extent = [0, self.rows, 0, self.cols]
     plt.axis(extent)
     plt.xticks([])
     plt.yticks([])
     layer_style = {
         'interpolation': 'bicubic',
         'vmin': None,
         'vmax': None,
         'alpha': 0.3,
         'origin': 'upper',
         'extent': extent,
     }
     for layer, cmap_name in ((self.array_a, 'twilight'),
                              (self.array_b, 'cool')):
         plt.imshow(layer, cmap=cmap_name, **layer_style)
     self.image = plt.gci()
Exemplo n.º 22
0
def defaultMesh(X,
                Y,
                Z,
                xlabel='x',
                ylabel='y',
                vmin=None,
                vmax=None,
                generateCBAR=True,
                zOrder=None,
                plotOptions=None):
    """Draw ``Z`` with ``pcolormesh`` on ``fire_cmap``.

    Args:
        X, Y, Z: mesh coordinates and values, forwarded to ``plt.pcolormesh``.
        xlabel, ylabel: axis labels.
        vmin, vmax: colour limits (None lets matplotlib autoscale).
        generateCBAR: add a colorbar and return it.  When False, the colour
            limits of the current image (``plt.gci()``) are reused instead
            of ``vmin``/``vmax``, so this mesh shares a previously drawn
            colour scale.  If that image has no limits yet, ``vmin``/``vmax``
            apply.
        zOrder: optional z-order of the mesh.
        plotOptions: extra keyword arguments for ``pcolormesh``.  They may
            not repeat the ones set here.

    Returns:
        The colorbar if ``generateCBAR`` is True, otherwise None.
    """
    if not generateCBAR:
        clim_low, clim_high = plt.gci().get_clim()
        if clim_low is not None:
            vmin, vmax = clim_low, clim_high

    mesh_kwargs = dict(cmap=fire_cmap,
                       vmin=vmin,
                       vmax=vmax,
                       linewidth=0,
                       rasterized=True)
    if zOrder is not None:
        mesh_kwargs['zorder'] = zOrder
    # Two ** unpackings keep the old behaviour: a key repeated in
    # plotOptions still raises TypeError instead of silently overriding.
    plt.pcolormesh(X, Y, Z, **mesh_kwargs, **(plotOptions or {}))

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    if generateCBAR:
        return plt.colorbar()
    return None
Exemplo n.º 23
0
def plot(corners, f, n=100):
    '''Plot a function over a triangle.

    ``corners`` is a 2x3 array whose columns are the triangle's vertices.
    The sample points are all barycentric triples with denominator ``n``.
    ``f`` receives them as a ``(3, npoints)`` array (presumably returning
    one value per point).  The result is drawn with ``tripcolor`` on the
    'coolwarm' colormap with limits symmetric about zero, and the triangle
    is outlined in black.
    '''
    import matplotlib.tri
    import matplotlib.pyplot as plt

    def compositions(boxes, balls):
        # All ways to put `balls` identical balls into `boxes` ordered boxes
        # <https://stackoverflow.com/a/36748940/353337>
        def rec(boxes, balls, parent=tuple()):
            if boxes <= 1:
                yield parent + (balls, )
                return
            for i in range(balls + 1):
                yield from rec(boxes - 1, i, parent + (balls - i, ))

        return list(rec(boxes, balls))

    bary = numpy.array(compositions(3, n)).T / n
    points = numpy.sum(
        [numpy.outer(bary[k], corners[:, k]) for k in range(3)], axis=0).T

    xs = numpy.array(points[0])
    ys = numpy.array(points[1])
    values = numpy.array(f(bary), dtype=float)

    mesh = matplotlib.tri.Triangulation(xs, ys)
    plt.tripcolor(mesh, values, shading='flat')
    plt.colorbar()

    # A diverging colormap makes the zero level easy to spot ...
    plt.set_cmap('coolwarm')
    # ... as long as the colour range is centred on zero.
    low, high = plt.gci().get_clim()
    bound = max(abs(low), abs(high))
    plt.clim(-bound, bound)

    # Closed outline: the corners followed by the first corner again.
    outline = numpy.column_stack([corners, corners[:, 0]])
    plt.plot(outline[0], outline[1], '-k')

    plt.gca().set_aspect('equal')
    plt.axis('off')
Exemplo n.º 24
0
def spatial_plot_format(var_name, scale):
    """Apply the common TPV spatial-plot formatting to the current figure.

    Adds a title and a colorbar labelled ``var_name`` (with " (stdv)" when
    ``scale`` is truthy).  Colour limits become symmetric about zero, except
    for 'grid_count', which uses [0, max/3] and its own title.  Zero lines
    and km axis labels are drawn.
    """
    plt.title("TPV effect on " + var_name)

    label = (var_name + " (stdv)") if scale else var_name
    plt.colorbar(fraction=0.023, pad=0.04, label=label)

    low, high = plt.gci().get_clim()
    cmax = np.abs(np.array([low, high])).max()
    if var_name == 'grid_count':
        plt.clim(0, cmax/3)
        plt.title('Grid box count')
    else:
        plt.clim(-cmax, cmax)

    plt.axhline(y=0, color='k', linestyle='-')
    plt.axvline(x=0, color='k', linestyle='-')
    plt.xlabel('x (km)')
    plt.ylabel('y (km)')
Exemplo n.º 25
0
def sync_clim_axes(axes_list, clim=()):
    """Give the current image of every axes in ``axes_list`` one colour range.

    With an empty ``clim`` the union of all current colour limits is used,
    widened by 1e-3 if it collapses to a single value.  Otherwise
    ``(clim[0], clim[1])`` is applied.  Each axes is made current with
    ``plt.sca`` in turn, so the last one stays current.
    """
    if len(clim) != 0:
        low, high = clim[0], clim[1]
    else:
        low, high = np.inf, -np.inf
        for current_ax in axes_list:
            plt.sca(current_ax)
            current = plt.gci().get_clim()
            low = min(current[0], low)
            high = max(current[1], high)
        if low == high:
            high = low + 1e-3

    for current_ax in axes_list:
        plt.sca(current_ax)
        plt.clim(low, high)
Exemplo n.º 26
0
def plot_single(degrees,
                res=100,
                scaling="normal",
                colorbar=True,
                cmap="RdBu_r",
                corners=None):
    """Plot one orthogonal polynomial over a triangle.

    The polynomial is entry ``degrees[0]`` of level ``sum(degrees)`` produced
    by ``Eval(bary, scaling)``.  It is evaluated on a ``meshzoo.triangle(res)``
    mesh and drawn with ``tripcolor``, using limits symmetric about zero.
    ``corners`` (2x3, one vertex per column) defaults to an equilateral
    triangle inscribed in the unit circle.
    """
    import meshzoo
    from matplotlib import pyplot as plt

    level_wanted = sum(degrees)
    entry = degrees[0]

    def evaluate(bary):
        for level_no, level in enumerate(Eval(bary, scaling)):
            if level_no == level_wanted:
                return level[entry]

    if corners is None:
        angles = numpy.pi * numpy.array([7.0 / 6.0, 11.0 / 6.0, 3.0 / 6.0])
        corners = numpy.array([numpy.cos(angles), numpy.sin(angles)])

    bary, cells = meshzoo.triangle(res)
    xs, ys = numpy.dot(corners, bary)
    values = numpy.array(evaluate(bary), dtype=float)

    plt.tripcolor(xs, ys, cells, values, shading="flat")

    if colorbar:
        plt.colorbar()
    # A diverging colormap keeps the zero level clearly visible ...
    plt.set_cmap(cmap)
    # ... provided the colour range is centred on zero.
    low, high = plt.gci().get_clim()
    bound = max(abs(low), abs(high))
    plt.clim(-bound, bound)

    # Closed triangle outline.
    outline = numpy.column_stack([corners, corners[:, 0]])
    plt.plot(outline[0], outline[1], "-k")

    plt.gca().set_aspect("equal")
    plt.axis("off")
    plt.title(
        f"Orthogonal polynomial on triangle ([{degrees[0]}, {degrees[1]}], {scaling})"
    )
Exemplo n.º 27
0
 def plot(self,
          showvalues=False,
          normed=None,
          cmap=None,
          log=None,
          cut=None):
     '''
     Plot the joint statistics histogram generated with :func:`joint`

     Parameters
     ----------
     showvalues : bool
         write every bin value on the plot, coloured with ``opcolor``
         (defined elsewhere) applied to the bin's colour
     normed : bool, optional
         divide the counts by ``self.N``; defaults to ``self.normed``
     cmap : colormap passed to ``pcolormesh``
     log : bool, optional
         use a logarithmic colour scale starting at the smallest non-zero
         value
     cut : bool, optional
         only show the first 3x3 block of bins

     Raises
     ------
     ValueError
         if ``log`` is set but every shown histogram entry is zero
     '''
     if normed is None:
         normed = self.normed
     X, Y = np.meshgrid(self.bins[0], self.bins[1])
     data = self.hist / self.N if normed else self.hist
     q = -1
     if cut:
         data = data[:3, :3]
         X = X[:3, :3]
         Y = Y[:3, :3]
         q = 2
     if log:
         # LogNorm needs a strictly positive lower bound, so use the
         # smallest non-zero entry.
         nonzero = data[np.nonzero(data)]
         if nonzero.size == 0:
             raise ValueError('log colour scale needs at least one non-zero histogram entry')
         norm = mpl.colors.LogNorm(vmin=np.min(nonzero), vmax=data.max())
     else:
         norm = mpl.colors.Normalize(vmin=data.min(), vmax=data.max())
     cplt = plt.pcolormesh(X, Y, data, cmap=cmap, norm=norm)
     ax = plt.gca()
     # Ticks sit at bin value + 0.5, presumably the centres of unit-width bins.
     ax.set_xticks(self.bins[0][:q] + 0.5)
     ax.set_xticklabels(np.array(self.bins[0][:q], dtype=np.uint16))
     ax.set_yticks(self.bins[1][:q] + 0.5)
     ax.set_yticklabels(np.array(self.bins[1][:q], dtype=np.uint16))
     sm = cmx.ScalarMappable(norm=norm, cmap=cplt.get_cmap())
     plt.xlabel('Region 1')
     plt.ylabel('Region 2')
     if showvalues:
         value_fmt = "%.3e" if normed else "%d"
         for i, v in np.ndenumerate(data):
             rgba = opcolor(sm.to_rgba(v))
             plt.text(i[1] + 0.4, i[0] + 0.4, value_fmt % v, color=rgba)
Exemplo n.º 28
0
def spherical_plot_format(var_name, scale):
    """Apply the common TPV spherical-plot formatting to the current figure.

    Adds a title and a colorbar labelled ``var_name`` (with " (stdv)" when
    ``scale`` is truthy).  Colour limits become symmetric about zero, except
    for 'grid_count', which uses [0, max] and its own title.
    """
    plt.title("TPV effect on " + var_name)

    label = (var_name + " (stdv)") if scale else var_name
    plt.colorbar(fraction=0.023, pad=0.04, label=label)

    low, high = plt.gci().get_clim()
    bound = np.abs(np.array([low, high])).max()

    if var_name == 'grid_count':
        plt.clim(0, bound)
        plt.title('Grid box count')
    else:
        plt.clim(-bound, bound)
Exemplo n.º 29
0
def locate_spot(event):
    """
    Callback for matplotlib's key_press_event that records a spot location.

    Pressing "m" while the cursor is over an axes rounds the cursor position
    to the nearest integer (x, y). It then passes ``[[y, x]]`` (row, column)
    together with the current image's data to ``fc`` and appends the first
    result to the module-level ``spot_locs`` list. The stored point is
    plotted as a red dot and printed. All other keys, and "m" presses made
    outside any axes, are ignored.
    """
    if event.key != 'm':
        return
    ax = event.inaxes
    # xdata/ydata are None when the key is pressed outside an axes; rounding
    # them would raise a TypeError, so ignore such presses.
    if ax is None or event.xdata is None or event.ydata is None:
        return

    canv = event.canvas
    # Have to manually reset focus to the figure or else it breaks after first use
    plt.figure(canv.figure.number)
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the
    # documented replacement.
    x = np.rint(event.xdata).astype(int)
    y = np.rint(event.ydata).astype(int)

    # append in opposite order because y = row, x = column
    adj_locs = fc(plt.gci().get_array().data, [[y, x]])
    spot_locs.append(adj_locs[0])
    ax.plot(spot_locs[-1][1], spot_locs[-1][0], 'ro')
    canv.draw()
    print('Point: ({:d}, {:d})'.format(spot_locs[-1][1], spot_locs[-1][0]))
Exemplo n.º 30
0
def plot_reconstruct_stimulus(response, rf, xlims=(18, 27), ylims=(3, 12)):
    """Reconstruct a stimulus from responses and receptive fields, and show it.

    The reconstruction is the response-weighted sum of ``rf[:, :, i]`` for
    every index ``i`` of the last axis except the final one. The final slice
    ``rf[:, :, -1]`` is then added with unit weight (presumably a
    bias/offset term — confirm against the caller).

    Args:
        response: per-cell weights; ``response[i]`` multiplies ``rf[:, :, i]``.
        rf: 3-D array whose last axis indexes cells (plus the trailing slice).
        xlims, ylims: axis limits applied to the displayed image. Immutable
            tuples are used as defaults so they cannot be shared and mutated
            between calls.

    Returns:
        The reconstructed 2-D array.
    """
    n_weighted = rf.shape[-1] - 1

    ss = 0 * rf[:, :, 0]
    for icell in range(n_weighted):
        # Use a non in-place add so the result can upcast (e.g. an integer
        # ``rf`` with float responses); ``+=`` would raise a casting error.
        ss = ss + rf[:, :, icell] * response[icell]
    ss = ss + rf[:, :, -1]

    plt.imshow(ss, interpolation='nearest', cmap='gray', clim=(-0.01, 0.01))
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    plt.xlim(xlims)
    plt.ylim(ylims)

    print(plt.gci().get_clim())

    return ss
Exemplo n.º 31
0
def makeGuessAboutCmap(clim=None, colormap=None):
    """Set colour limits and a colormap on the current image; return the limits.

    Limits: ``clim`` (vmin, vmax) when given. Otherwise they are read from the
    current image (``plt.gci()``). If the smaller-magnitude end is more than
    half the magnitude of the larger end (with opposite sign), the range is
    snapped to be symmetric about zero. Equal limits are widened by 1 on
    each side.

    Colormap: ``colormap`` when given. Otherwise it is guessed from the limits:
    'plasma' for single-signed positive data whose minimum is below a third
    of its maximum, 'plasma_r' for the negative mirror case, 'seismic' for
    nearly symmetric ranges, and 'viridis' otherwise.

    The current axes' face colour is set to grey ([.5, .5, .5]); the
    original variable name suggests this marks land/masked cells — confirm.

    Relies on a module-level ``debug`` flag defined elsewhere.

    Returns:
        The (vmin, vmax) tuple that was applied.
    """
    if not clim:
        lo, hi = plt.gci().get_clim()
        symmetric_cutoff = 0.5
        if -lo < hi and -lo / hi > symmetric_cutoff:
            lo = -hi
        elif -lo > hi and -hi / lo > symmetric_cutoff:
            hi = -lo
    else:
        lo, hi = clim[0], clim[1]

    if lo == hi:
        if debug:
            print('vmin,vmax=', lo, hi)
        lo, hi = lo - 1, hi + 1
    plt.clim(lo, hi)

    if colormap:
        plt.set_cmap(colormap)
    else:
        same_sign = lo * hi >= 0
        if same_sign and hi > 0 and 3 * lo < hi:
            guess = 'plasma'     # Single signed +ve data
        elif same_sign and lo < 0 and 3 * hi > lo:
            guess = 'plasma_r'   # Single signed -ve data
        elif abs((hi + lo) / (hi - lo)) < .01:
            guess = 'seismic'    # Multi-signed symmetric data
        else:
            guess = 'viridis'
        plt.set_cmap(guess)

    plt.gca().set_facecolor([.5, .5, .5])
    return (lo, hi)
Exemplo n.º 32
0
    def update(self, data):
        """Show ``data`` in the current image, creating the image if needed.

        Reuses ``plt.gci()`` when it returns an image; otherwise creates one
        with ``plt.imshow`` using ``self.settings['colormap']`` and
        ``self.settings['interpolation']``. The colour limits are rescaled to
        the min/max of ``data``. The plot is redrawn when ``self.show`` is
        truthy. When ``self.export`` is not None, the frame is saved as
        ``'<export>temp-NNNNN.png'`` and ``self.images_written`` is
        incremented.
        """
        # Check whether an image already exists in the current figure.
        axes_image = plt.gci()
        if axes_image is None:
            # Create the image on first use.
            axes_image = plt.imshow(data, cmap=self.settings['colormap'],
                                    interpolation=self.settings['interpolation'])

        # Set the new data and rescale the colour range to this frame.
        axes_image.set_data(data)
        axes_image.set_clim(data.min(), data.max())

        # Update the plot in interactive mode.
        if self.show:
            plt.draw()

        # Write out the new image as a numbered frame.
        if self.export is not None:
            filename = '{}temp-{:05d}.png'.format(self.export, self.images_written)
            self.images_written += 1
            plt.savefig(filename, dpi=self.settings['dpi'])
Exemplo n.º 33
0
    def Colorbar(self):
        """Return the colorbar of the current image, or None if there is none.

        For healpy axes (detected by 'healpy' in the axes class name), the
        image is the last one in ``plt.gca().get_images()``. Otherwise it is
        the current image ``plt.gci()`` (equivalent to
        ``plt.gcf()._gci()``), e.g. the result of ``plt.pcolor``.

        The returned colorbar's ``ax`` attribute can be passed wherever an
        axes is expected (e.g. ``jp.Plt.Axes``).

        Returns None when no image is found, instead of raising
        IndexError/AttributeError. ``image.colorbar`` is itself None when
        the image has no colorbar.
        """
        import matplotlib.pyplot as plt
        ax = plt.gca()
        if 'healpy' in str(ax.__class__):
            images = ax.get_images()
            im = images[-1] if images else None
        else:
            im = plt.gci()
        if im is None:
            # Nothing to look up a colorbar on.
            return None
        return im.colorbar
Exemplo n.º 34
0
def plot(f, lcar=1.0e-1):
    '''Plot function over a disk.

    The unit disk centred at the origin is meshed with pygmsh (characteristic
    length ``lcar``). ``f`` is evaluated at the mesh nodes and drawn with
    ``tripcolor`` using a diverging colormap whose limits are symmetric
    about zero. A black outline of the circle is added and the axes are
    hidden.

    Args:
        f: callable receiving the transposed node array ``points.T`` and
            returning one value per node (``points`` is presumably
            (num_points, 3) — confirm against the pygmsh version in use).
        lcar: characteristic mesh length passed to ``geom.add_circle``.
    '''
    # Import the submodule explicitly rather than relying on pyplot having
    # loaded ``matplotlib.tri`` as a side effect.
    import matplotlib.tri
    import matplotlib.pyplot as plt
    import pygmsh

    geom = pygmsh.built_in.Geometry()
    geom.add_circle(
        [0.0, 0.0, 0.0],
        1.0,
        lcar,
        num_sections=4,
        compound=True,
    )
    points, cells, _, _, _ = pygmsh.generate_mesh(geom, verbose=True)

    x = points[:, 0]
    y = points[:, 1]
    triang = matplotlib.tri.Triangulation(x, y, cells['triangle'])

    plt.tripcolor(triang, f(points.T), shading='flat')
    plt.colorbar()

    # Choose a diverging colormap such that the zeros are clearly
    # distinguishable.
    plt.set_cmap('coolwarm')
    # Make sure the color map limits are symmetric around 0.
    clim = plt.gci().get_clim()
    mx = max(abs(clim[0]), abs(clim[1]))
    plt.clim(-mx, mx)

    # circle outline
    circle = plt.Circle((0, 0), 1.0, edgecolor='k', fill=False)
    plt.gca().add_artist(circle)

    plt.gca().set_aspect('equal')
    plt.axis('off')
Exemplo n.º 35
0
def set_clim_to_convex_hull(*all_axes):
    """Set the color scale limits of the given axes to the convex hull of all those specified.

    Each axes is made current in turn and its image is obtained with
    ``pyplot.gci()``, the public-API way to get an 'image' such as a
    pcolormesh. The shared limits are the smallest interval containing every
    collected image's current colour limits, and they are applied to all
    collected images. Axes for which ``gci()`` returns None are skipped
    instead of raising. The axes that was current before the call is
    restored afterwards, even if an error occurs.
    """
    # Save the current axes so we can restore the context after this function call.
    current_axes = pyplot.gca()
    try:
        images = []
        for axes in all_axes:
            pyplot.sca(axes)
            image = pyplot.gci()
            if image is not None:
                images.append(image)

        if not images:
            return

        limits = [image.get_clim() for image in images]
        min_limit = min(low for low, _ in limits)
        max_limit = max(high for _, high in limits)

        # We now have the new limits, so set everywhere. This is equivalent
        # to pyplot.clim() on each axes, without re-selecting it.
        for image in images:
            image.set_clim(min_limit, max_limit)
    finally:
        pyplot.sca(current_axes)
Exemplo n.º 36
0
def main():
    ###############################################################################
    ##### PARAMETERS ##############################################################
    ###############################################################################
    
    # Some general help descriptions
    ######### Some general plotting arguments descriptions ###############
    helpinput = 'The file name of the input Experimental Matrix file. Recommended to add more columns for more information for ploting. For example, cell type or factors.'
    helpoutput = 'The directory name for the output files. For example, project name.'
    helptitle = 'The title shown on the top of the plot and also the folder name.'
    helpgroup = "Group the data by reads(needs 'factor' column), regions(needs 'factor' column), another name of column (for example, 'cell')in the header of experimental matrix, or None."
    helpgroupbb = "Group the data by any optional column (for example, 'cell') of experimental matrix, or None."
    helpsort = "Sort the data by reads(needs 'factor' column), regions(needs 'factor' column), another name of column (for example, 'cell')in the header of experimental matrix, or None."
    helpcolor = "Color the data by reads(needs 'factor' column), regions(needs 'factor' column), another name of column (for example, 'cell')in the header of experimental matrix, or None."
    helpcolorbb = "Color the data by any optional column (for example, 'cell') of experimental matrix, or None."
    helpDefinedColot = 'Define the specific colors with the given column "color" in experimental matrix. The color should be in the format of matplotlib.colors. For example, "r" for red, "b" for blue, or "(100, 35, 138)" for RGB.'
    helpreference = 'The file name of the reference Experimental Matrix. Multiple references are acceptable.'
    helpquery = 'The file name of the query Experimental Matrix. Multiple queries are acceptable.'
    
    parser = argparse.ArgumentParser(description='Provides various Statistical analysis methods and plotting tools for ExperimentalMatrix.\
    \nAuthor: Joseph Kuo, Ivan Gesteira Costa Filho', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    subparsers = parser.add_subparsers(help='sub-command help',dest='mode')
    
    ################### Projection test ##########################################
    parser_projection = subparsers.add_parser('projection',help='Projection test evaluates the association level by comparing to the random binomial model.')
    parser_projection.add_argument('-r', metavar='  ', help=helpreference)
    parser_projection.add_argument('-q', metavar='  ', help=helpquery)
    parser_projection.add_argument('-o', metavar='  ', help=helpoutput) 
    parser_projection.add_argument('-t', metavar='  ', default='projection_test', help=helptitle)
    parser_projection.add_argument('-g', metavar='  ', default=None, help=helpgroupbb +" (Default:None)")
    parser_projection.add_argument('-c', metavar='  ', default="regions", help=helpcolorbb +' (Default: regions)')
    parser_projection.add_argument('-bg', metavar='  ', default=None, help="Define a BED file as background. If not defined, the background is whole genome according to the given organism.")
    parser_projection.add_argument('-union', action="store_true", help='Take the union of references as background for binominal test.')
    parser_projection.add_argument('-organism', metavar='  ', default='hg19', help='Define the organism. (Default: hg19)')
    parser_projection.add_argument('-log', action="store_true", help='Set y axis of the plot in log scale.')
    parser_projection.add_argument('-color', action="store_true", help=helpDefinedColot)
    parser_projection.add_argument('-show', action="store_true", help='Show the figure in the screen.')
    parser_projection.add_argument('-table', action="store_true", help='Store the tables of the figure in text format.')
    
    ################### Intersect Test ##########################################
    parser_intersect = subparsers.add_parser('intersect',help='Intersection test provides various modes of intersection to test the association between references and queries.')
    parser_intersect.add_argument('-r', metavar='  ', help=helpreference)
    parser_intersect.add_argument('-q', metavar='  ', help=helpquery)
    parser_intersect.add_argument('-o', help=helpoutput)
    parser_intersect.add_argument('-t', metavar='  ', default='intersection_test', help=helptitle)
    parser_intersect.add_argument('-g', metavar='  ', default=None, help=helpgroupbb +" (Default:None)")
    parser_intersect.add_argument('-c', metavar='  ', default="regions", help=helpcolorbb +' (Default: regions)')
    parser_intersect.add_argument('-organism', metavar='  ', default='hg19', help='Define the organism. (Default: hg19)')
    parser_intersect.add_argument('-bg', metavar='  ', help="Define a BED file as background. If not defined, the background is whole genome according to the given organism.")
    parser_intersect.add_argument('-m', metavar='  ', default="count", choices=['count','bp'],
                                  help="Define the mode of calculating intersection. \
                                  'count' outputs the number of overlapped regions.\
                                  'bp' outputs the coverage(basepair) of intersection.")
    parser_intersect.add_argument('-tc', metavar='  ', type=int, default=False, help="Define the threshold(in percentage) of reference length for intersection counting. For example, '20' means that the query which overlaps more than 20%% of reference is counted as intersection.")
    parser_intersect.add_argument('-ex', metavar='  ', type=int, default=0, help="Define the extension(in percentage) of reference length for intersection counting. For example, '20' means that each region of reference is extended by 20%% in order to include proximal queries.")
    parser_intersect.add_argument('-log', action="store_true", help='Set y axis of the plot in log scale.')
    parser_intersect.add_argument('-color', action="store_true", help=helpDefinedColot)
    parser_intersect.add_argument('-show', action="store_true", help='Show the figure in the screen.')
    parser_intersect.add_argument('-stest', metavar='  ', type=int, default= 0, help='Define the repetition time of random subregion test between reference and query.')
    
    ################### Jaccard test ##########################################
    
    parser_jaccard = subparsers.add_parser('jaccard',help='Jaccard test evaluates the association level by comparing with jaccard index from repeating randomization.')
    
    parser_jaccard.add_argument('-o', help=helpoutput) 
    parser_jaccard.add_argument('-r', '--reference',help=helpreference)
    parser_jaccard.add_argument('-q', '--query', help=helpquery)
    parser_jaccard.add_argument('-t','--title', default='jaccard_test', help=helptitle)
    parser_jaccard.add_argument('-rt','--runtime', type=int, default=500, help='Define how many times to run the randomization. (Default:500)')
    parser_jaccard.add_argument('-g', default=None, help=helpgroupbb +" (Default:None)")
    parser_jaccard.add_argument('-c', default="regions", help=helpcolorbb +' (Default: regions)')
    parser_jaccard.add_argument('-organism',default='hg19', help='Define the organism. (Default: hg19)')
    parser_jaccard.add_argument('-nlog', action="store_false", help='Set y axis of the plot not in log scale.')
    parser_jaccard.add_argument('-color', action="store_true", help=helpDefinedColot)
    parser_jaccard.add_argument('-show', action="store_true", help='Show the figure in the screen.')
    parser_jaccard.add_argument('-table', action="store_true", help='Store the tables of the figure in text format.')

    ################### Combinatorial Test ##########################################
    parser_combinatorial = subparsers.add_parser('combinatorial',help='Combinatorial test compare all combinatorial possibilities from reference to test the association between references and queries.')
    
    parser_combinatorial.add_argument('-o', help=helpoutput)
    parser_combinatorial.add_argument('-r', '--reference',help=helpreference)
    parser_combinatorial.add_argument('-q', '--query', help=helpquery)
    parser_combinatorial.add_argument('-t','--title', default='combinatorial_test', help=helptitle)
    parser_combinatorial.add_argument('-g', default=None, help=helpgroupbb +" (Default:None)")
    parser_combinatorial.add_argument('-c', default="regions", help=helpcolorbb +' (Default: regions)')
    parser_combinatorial.add_argument('-organism',default='hg19', help='Define the organism. (Default: hg19)')
    parser_combinatorial.add_argument('-bg', help="Define a BED file as background. If not defined, the background is whole genome according to the given organism.")
    parser_combinatorial.add_argument('-m', default="count", choices=['count','bp'],
                                      help="Define the mode of calculating intersection. \
                                      'count' outputs the number of overlapped regions.\
                                      'bp' outputs the coverage(basepair) of intersection.")
    parser_combinatorial.add_argument('-tc', type=int, default=False, help="Define the threshold(in percentage) of reference length for intersection counting. For example, '20' means that the query which overlaps more than 20%% of reference is counted as intersection.")
    parser_combinatorial.add_argument('-ex', type=int, default=0, help="Define the extension(in percentage) of reference length for intersection counting. For example, '20' means that each region of reference is extended by 20%% in order to include proximal queries.")
    parser_combinatorial.add_argument('-log', action="store_true", help='Set y axis of the plot in log scale.')
    parser_combinatorial.add_argument('-color', action="store_true", help=helpDefinedColot)
    parser_combinatorial.add_argument('-show', action="store_true", help='Show the figure in the screen.')
    parser_combinatorial.add_argument('-stest', type=int, default= 0, help='Define the repetition time of random subregion test between reference and query.')
    
    ################### Boxplot ##########################################
    
    parser_boxplot = subparsers.add_parser('boxplot',help='Boxplot based on the BAM and BED files for gene association analysis.')
    parser_boxplot.add_argument('input',help=helpinput)
    parser_boxplot.add_argument('-o', metavar='  ', help=helpoutput)
    parser_boxplot.add_argument('-t', metavar='  ', default='boxplot', help=helptitle)
    parser_boxplot.add_argument('-g', metavar='  ', default='reads', help=helpgroup + " (Default:reads)")
    parser_boxplot.add_argument('-c', metavar='  ', default='regions', help=helpcolor + " (Default:regions)")
    parser_boxplot.add_argument('-s', metavar='  ', default='None', help=helpsort + " (Default:None)")
    parser_boxplot.add_argument('-sy', action="store_true", help="Share y axis for convenience of comparison.")
    parser_boxplot.add_argument('-nlog', action="store_false", help='Set y axis of the plot not in log scale.')
    parser_boxplot.add_argument('-color', action="store_true", help=helpDefinedColot)
    parser_boxplot.add_argument('-nqn', action="store_true", help='No quantile normalization in calculation.')
    parser_boxplot.add_argument('-df', action="store_true", help="Show the difference of the two signals which share the same labels.The result is the subtraction of the first to the second.")
    parser_boxplot.add_argument('-ylim', metavar='  ', type=int, default=None, help="Define the limit of y axis.")
    parser_boxplot.add_argument('-p', metavar='  ', type=float, default=0.05, help='Define the significance level for multiple test. Default: 0.01')
    parser_boxplot.add_argument('-show', action="store_true", help='Show the figure in the screen.')
    parser_boxplot.add_argument('-table', action="store_true", help='Store the tables of the figure in text format.')
    
    ################### Lineplot ##########################################
    parser_lineplot = subparsers.add_parser('lineplot', help='Generate lineplot with various modes.')
    
    choice_center = ['midpoint','leftend','rightend','bothends'] 
    # Be consist as the arguments of GenomicRegionSet.relocate_regions
    
    parser_lineplot.add_argument('input', help=helpinput)
    parser_lineplot.add_argument('-o', help=helpoutput)
    parser_lineplot.add_argument('-ga', action="store_true", help="Use genetic annotation data as input regions (e.g. TSS, TTS, exons and introns) instead of the BED files in the input matrix.")
    parser_lineplot.add_argument('-t', metavar='  ', default='lineplot', help=helptitle)
    parser_lineplot.add_argument('-center', metavar='  ', choices=choice_center, default='midpoint', 
                                 help='Define the center to calculate coverage on the regions. Options are: '+', '.join(choice_center) + 
                                 '.(Default:midpoint) The bothend mode will flap the right end region for calculation.')
    parser_lineplot.add_argument('-g', metavar='  ', default='reads', help=helpgroup + " (Default:reads)")
    parser_lineplot.add_argument('-c', metavar='  ', default='regions', help=helpcolor + " (Default:regions)")
    parser_lineplot.add_argument('-s', metavar='  ', default='None', help=helpsort + " (Default:None)")
    parser_lineplot.add_argument('-e', metavar='  ', type=int, default=2000, help='Define the extend length of interested region for plotting.(Default:2000)')
    parser_lineplot.add_argument('-rs', metavar='  ', type=int, default=200, help='Define the readsize for calculating coverage.(Default:200)')
    parser_lineplot.add_argument('-ss', metavar='  ', type=int, default=50, help='Define the stepsize for calculating coverage.(Default:50)')
    parser_lineplot.add_argument('-bs', metavar='  ', type=int, default=100, help='Define the binsize for calculating coverage.(Default:100)')
    parser_lineplot.add_argument('-sy', action="store_true", help="Share y axis for convenience of comparison.")
    parser_lineplot.add_argument('-sx', action="store_true", help="Share x axis for convenience of comparison.")
    parser_lineplot.add_argument('-organism', metavar='  ', default='hg19', help='Define the organism. (Default: hg19)')
    parser_lineplot.add_argument('-color', action="store_true", help=helpDefinedColot)
    parser_lineplot.add_argument('-mp', action="store_true", help="Perform multiprocessing for faster computation.")
    parser_lineplot.add_argument('-df', action="store_true", help="Show the difference of the two signals which share the same labels.The result is the subtraction of the first to the second.")
    parser_lineplot.add_argument('-show', action="store_true", help='Show the figure in the screen.')
    parser_lineplot.add_argument('-table', action="store_true", help='Store the tables of the figure in text format.')
    
    ################### Heatmap ##########################################
    parser_heatmap = subparsers.add_parser('heatmap', help='Generate heatmap with various modes.')
    
    choice_center = ['midpoint','leftend','rightend','bothends'] 
    # Be consist as the arguments of GenomicRegionSet.relocate_regions
    
    parser_heatmap.add_argument('input', help=helpinput)
    parser_heatmap.add_argument('-o', metavar='  ', help=helpoutput)
    parser_heatmap.add_argument('-ga', action="store_true", help="Use genetic annotation data as input regions (e.g. TSS, TTS, exons and introns) instead of the BED files in the input matrix.")
    parser_heatmap.add_argument('-t', metavar='  ', default='heatmap', help=helptitle)
    parser_heatmap.add_argument('-center', metavar='  ', choices=choice_center, default='midpoint', 
                                 help='Define the center to calculate coverage on the regions. Options are: '+', '.join(choice_center) + 
                                 '.(Default:midpoint) The bothend mode will flap the right end region for calculation.')
    parser_heatmap.add_argument('-sort', metavar='  ', type=int, default=None, help='Define the way to sort the signals.'+
                                'Default is no sorting at all, the signals arrange in the order of their position; '+
                                '"0" is sorting by the average ranking of all signals; '+
                                '"1" is sorting by the ranking of 1st column; "2" is 2nd and so on... ')
    parser_heatmap.add_argument('-s', metavar='  ', default='None', help=helpsort + " (Default:None)")
    parser_heatmap.add_argument('-g', metavar='  ', default='regions', help=helpgroup + " (Default:regions)")
    parser_heatmap.add_argument('-c', metavar='  ', default='reads', help=helpcolor + " (Default:reads)")
    parser_heatmap.add_argument('-e', metavar='  ', type=int, default=2000, help='Define the extend length of interested region for plotting.(Default:2000)')
    parser_heatmap.add_argument('-rs', metavar='  ', type=int, default=200, help='Define the readsize for calculating coverage.(Default:200)')
    parser_heatmap.add_argument('-ss', metavar='  ', type=int, default=50, help='Define the stepsize for calculating coverage.(Default:50)')
    parser_heatmap.add_argument('-bs', metavar='  ', type=int, default=100, help='Define the binsize for calculating coverage.(Default:100)')
    parser_heatmap.add_argument('-organism', metavar='  ', default='hg19', help='Define the organism. (Default: hg19)')
    parser_heatmap.add_argument('-color', action="store_true", help=helpDefinedColot)
    parser_heatmap.add_argument('-log', action="store_true", help='Set colorbar in log scale.')
    parser_heatmap.add_argument('-mp', action="store_true", help="Perform multiprocessing for faster computation.")
    parser_heatmap.add_argument('-show', action="store_true", help='Show the figure in the screen.')
    parser_heatmap.add_argument('-table', action="store_true", help='Store the tables of the figure in text format.')
    
    ################### Integration ##########################################
    parser_integration = subparsers.add_parser('integration', help='Provides some tools to deal with experimental matrix or other purposes.')
    parser_integration.add_argument('-ihtml', action="store_true", help='Integrate all the html files within the given directory and generate index.html for all plots.')
    parser_integration.add_argument('-l2m', help='Convert a given file list in txt format into a experimental matrix.')
    parser_integration.add_argument('-o', help='Define the folder of the output file.') 
    ################### Parsing the arguments ################################
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    elif len(sys.argv) == 2: 
        # retrieve subparsers from parser
        subparsers_actions = [action for action in parser._actions if isinstance(action, argparse._SubParsersAction)]
        # there will probably only be one subparser_action,but better save than sorry
        for subparsers_action in subparsers_actions:
            # get all subparsers and print help
            for choice, subparser in subparsers_action.choices.items():
                if choice == sys.argv[1]:
                    print("\nYou need more arguments.")
                    print("\nSubparser '{}'".format(choice))        
                    subparser.print_help()
        sys.exit(1)
    else:
        args = parser.parse_args()
        if not args.o:
            print("** Error: Please define the output directory (-o).")
            sys.exit(1)
        
        t0 = time.time()
        # Normalised output path
        args.o = os.path.normpath(os.path.join(dir,args.o))
        check_dir(args.o)
        check_dir(os.path.join(args.o, args.t))
        
        # Input parameters dictionary
        parameter = []
        parameter.append("Time: " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
        parameter.append("User: "******"\nCommand:\n\t$ " + " ".join(sys.argv))

        #################################################################################################
        ##### Main #####################################################################################
        #################################################################################################

        ################### Projection test ##########################################
        if args.mode == 'projection':
            # Fetching reference and query EM
            print2(parameter, "\n############# Projection Test #############")
            print2(parameter, "\tReference:        "+args.r)
            print2(parameter, "\tQuery:            "+args.q)
            print2(parameter, "\tOutput directory: "+os.path.basename(args.o))
            print2(parameter, "\tExperiment title: "+args.t)

            projection = Projection( args.r, args.q )
            projection.group_refque(args.g)
            projection.colors( args.c, args.color )
            if args.bg: projection.background(args.bg)
            if args.union: 
                projection.ref_union()
                projection.projection_test(organism = args.organism)
                print2(parameter, "\tTaking intersect of references as the background. ")
            else:
                projection.projection_test(organism = args.organism)
            
            # generate pdf
            projection.plot(args.log)
            output(f=projection.fig, directory = args.o, folder = args.t, filename="projection_test",
                   extra=plt.gci(),pdf=True,show=args.show)
            
            # generate html 
            projection.gen_html(args.o, args.t, args=args)
            
            if args.table:
                projection.table(directory = args.o, folder = args.t)
                
            print("\nAll related files are saved in:  "+ os.path.join(os.path.basename(args.o),args.t))
            t1 = time.time()
            print2(parameter,"\nTotal running time is : " + str(datetime.timedelta(seconds=round(t1-t0))))
            output_parameters(parameter, directory = args.o, folder = args.t, filename="parameters.txt")
            copy_em(em=args.r, directory=args.o, folder=args.t, filename="reference_experimental_matrix.txt")
            copy_em(em=args.q, directory=args.o, folder=args.t, filename="query_experimental_matrix.txt")
            list_all_index(path=args.o)
            
        ################### Intersect Test ##########################################
        if args.mode == 'intersect':
            print2(parameter, "\n############ Intersection Test ############")
            print2(parameter, "\tReference:        "+args.r)
            print2(parameter, "\tQuery:            "+args.q)
            print2(parameter, "\tOutput directory: "+os.path.basename(args.o))
            print2(parameter, "\tExperiment title: "+args.t)
            # Fetching reference and query EM
            inter = Intersect(args.r,args.q, mode_count=args.m, organism=args.organism)
            # Setting background
            inter.background(args.bg)
            # Grouping
            inter.group_refque(args.g)
            
            # Extension
            if args.ex == 0: pass
            elif args.ex > 0: inter.extend_ref(args.ex)
            elif args.ex < 0: 
                print("\n**** extension percentage(-ex) should be positive value, not negative.\n")
                sys.exit(1)
            
            inter.colors(args.c, args.color)
            inter.count_intersect(threshold=args.tc)
            
            # generate pdf
            inter.barplot(logt=args.log)
            output(f=inter.bar, directory = args.o, folder = args.t, filename="intersection_bar",
                   extra=plt.gci(), pdf=True,show=args.show)
            inter.stackedbar()
            output(f=inter.sbar, directory = args.o, folder = args.t, filename="intersection_stackedbar",
                   extra=plt.gci(), pdf=True,show=args.show)
            inter.barplot(logt=args.log, percentage=True)
            output(f=inter.bar, directory = args.o, folder = args.t, filename="intersection_barp",
                   extra=plt.gci(), pdf=True,show=args.show)
            
            if args.stest > 0:
                inter.stest(repeat=args.stest,threshold=args.tc)
            
            # generate html
            inter.gen_html(args.o, args.t, align=50, args=args)
            
            t1 = time.time()
            print2(parameter, "\nAll related files are saved in:  "+ os.path.join(os.path.basename(args.o),args.t))
            print2(parameter,"\nTotal running time is : " + str(datetime.timedelta(seconds=round(t1-t0))))
            output_parameters(parameter, directory = args.o, folder = args.t, filename="parameters.txt")
            copy_em(em=args.r, directory=args.o, folder=args.t, filename="reference_experimental_matrix.txt")
            copy_em(em=args.q, directory=args.o, folder=args.t, filename="query_experimental_matrix.txt")
            list_all_index(path=args.o)
            ################### Jaccard test ##########################################
        if args.mode == "jaccard":
            """Return the jaccard test of every possible comparisons between two ExperimentalMatrix. 
            
            Method:
            The distribution of random jaccard index is calculated by randomizing query for given times. 
            Then, we compare the real jaccard index to the distribution and formulate p-value as 
            p-value = (# random jaccard > real jaccard)/(# random jaccard)
            
            """
            print("\n############## Jaccard Test ###############")
            jaccard = Jaccard(args.reference,args.query)
            jaccard.group_refque(args.g)
            jaccard.colors(args.c, args.color)
            
            # jaccard test
            jaccard.jaccard_test(args.runtime, args.organism)
            parameter = parameter + jaccard.parameter
            t1 = time.time()
            # ploting and generate pdf
            jaccard.plot(logT=args.nlog)
            for i,f in enumerate(jaccard.fig):
                output(f=f, directory = args.o, folder = args.title, filename="jaccard_test"+str(i+1),extra=plt.gci(),pdf=True,show=args.show)
            # generate html
            jaccard.gen_html(args.o, args.title)
            
            if args.table:
                jaccard.table(directory = args.o, folder = args.title)
            
            print("\nAll related files are saved in:  "+ os.path.join(dir,args.o,args.title))
            print2(parameter,"\nTotal running time is : " + str(datetime.timedelta(seconds=round(t1-t0))))
            output_parameters(parameter, directory = args.o, folder = args.title, filename="parameters.txt")
            copy_em(em=args.reference, directory=args.o, folder=args.title, filename="Reference_experimental_matrix.txt")
            copy_em(em=args.query, directory=args.o, folder=args.title, filename="Query_experimental_matrix.txt")
            list_all_index(path=args.o)

        ################### Combinatorial Test ##########################################
        if args.mode == 'combinatorial':
            print("\n############ Combinatorial Test ############")
            # Fetching reference and query EM
            comb = Combinatorial(args.reference,args.query, mode_count=args.m, organism=args.organism)
            # Setting background
            inter.background(args.bg)
            # Grouping
            inter.group_refque(args.g)
            
            # Extension
            if args.ex == 0: pass
            elif args.ex > 0: inter.extend_ref(args.ex)
            elif args.ex < 0: 
                print("\n**** extension percentage(-ex) should be positive value, not negative.\n")
                sys.exit(1)
            # Combinatorial 
            print2(parameter, "Generating all combinatorial regions for further analysis...")
            inter.combinatorial()
            inter.count_intersect(threshold=args.tc, frequency=True)
            
            # generate pdf
            inter.colors_comb()
            #inter.barplot(args.log)
            #output(f=inter.bar, directory = args.output, folder = args.title, filename="intersection_bar",extra=plt.gci(),pdf=True,show=args.show)
            #if args.stackedbar:
            #inter.colors(args.c, args.color,ref_que = "ref")
            inter.comb_stacked_plot()
            output(f=inter.sbar, directory = args.o, folder = args.title, filename="intersection_stackedbar",extra=plt.gci(),pdf=True,show=args.show)
            #if args.lineplot:
            #    inter.comb_lineplot()
            if args.stest > 0:
                inter.stest(repeat=args.stest,threshold=args.tc)
            # generate html
            inter.gen_html_comb(args.o, args.title, align=50)
            
            parameter = parameter + inter.parameter
            t1 = time.time()
            print("\nAll related files are saved in:  "+ os.path.join(dir,args.o,args.title))
            print2(parameter,"\nTotal running time is : " + str(datetime.timedelta(seconds=round(t1-t0))))
            output_parameters(parameter, directory = args.o, folder = args.title, filename="parameters.txt")
            copy_em(em=args.reference, directory=args.o, folder=args.title, filename="Reference_experimental_matrix.txt")
            copy_em(em=args.query, directory=args.o, folder=args.title, filename="Query_experimental_matrix.txt")
            list_all_index(path=args.o)


        ################### Boxplot ##########################################
        if args.mode == 'boxplot':
            print("\n################# Boxplot #################")
            boxplot = Boxplot(args.input, title=args.t, df=args.df)
            
            print2(parameter,"\nStep 1/5: Combining all regions")
            boxplot.combine_allregions()
            print2(parameter,"    " + str(len(boxplot.all_bed)) + " regions from all bed files are combined.")
            t1 = time.time()
            print2(parameter,"    --- finished in {0} secs\n".format(round(t1-t0)))
            
            # Coverage of reads on all_bed
            print2(parameter,"Step 2/5: Calculating coverage of each bam file on all regions")
            boxplot.bedCoverage() 
            t2 = time.time()
            print2(parameter,"    --- finished in {0} (H:M:S)\n".format(datetime.timedelta(seconds=round(t2-t1))))
            
            # Quantile normalization
            print2(parameter,"Step 3/5: Quantile normalization of all coverage table")
            if args.nqn:
                print2(parameter,"    No quantile normalization.")
                boxplot.norm_table = boxplot.all_table
            else: boxplot.quantile_normalization()
            t3 = time.time()
            print2(parameter,"    --- finished in {0} secs\n".format(round(t3-t2)))
            
            # Generate individual table for each bed
            print2(parameter,"Step 4/5: Constructing different tables for box plot")
            boxplot.tables_for_plot()
            if args.table: boxplot.print_plot_table(directory = args.o, folder = args.t)
            t4 = time.time()
            print2(parameter,"    --- finished in {0} secs\n".format(round(t4-t3)))
            
            # Plotting
            print2(parameter,"Step 5/5: Plotting")
            boxplot.group_tags(groupby=args.g, sortby=args.s, colorby=args.c)
            
            boxplot.group_data(directory = args.o, folder = args.t, log=args.nlog)
            boxplot.color_map(colorby=args.c, definedinEM=args.color)
            boxplot.plot(title=args.t, logT=args.nlog, sy=args.sy, ylim=args.ylim)
            #if args.table: boxplot.print_table(directory=args.output, folder=args.title)
            output(f=boxplot.fig, directory = args.o, folder = args.t, filename="boxplot",extra=plt.gci(),pdf=True,show=args.show)
            # HTML
            boxplot.gen_html(args.o, args.t, align = 50)
            t5 = time.time()
            print2(parameter,"    --- finished in {0} secs\n".format(round(t5-t4)))
            print2(parameter,"Total running time is: " + str(datetime.timedelta(seconds=round(t5-t0))) + " (H:M:S)\n")
            print("\nAll related files are saved in:  "+ os.path.join(dir,args.o,args.t))
            output_parameters(parameter, directory = args.o, folder = args.t, filename="parameters.txt")
            copy_em(em=args.input, directory=args.o, folder=args.t)
            list_all_index(path=args.o)

        ################### Lineplot #########################################
        if args.mode == 'lineplot':
            if args.sy and args.sx:
                print("** Err: -sy and -sx cannot be used simutaneously.")
                sys.exit(1)

            print("\n################ Lineplot #################")
            # Read experimental matrix
            t0 = time.time()
            if "reads" not in (args.g, args.c, args.s):
                print("Please add 'reads' tag as one of grouping, sorting, or coloring argument.")
                sys.exit(1)
            if "regions" not in (args.g, args.c, args.s):
                print("Please add 'regions' tag as one of grouping, sorting, or coloring argument.")
                sys.exit(1)
            print2(parameter, "Parameters:\tExtend length:\t"+str(args.e))
            print2(parameter, "\t\tRead size:\t"+str(args.rs))
            print2(parameter, "\t\tBin size:\t"+str(args.bs))
            print2(parameter, "\t\tStep size:\t"+str(args.ss))
            print2(parameter, "\t\tCenter mode:\t"+str(args.center+"\n"))
            
            lineplot = Lineplot(EMpath=args.input, title=args.t, annotation=args.ga, 
                                organism=args.organism, center=args.center, extend=args.e, rs=args.rs, 
                                bs=args.bs, ss=args.ss, df=args.df)
            # Processing the regions by given parameters
            print2(parameter, "Step 1/3: Processing regions by given parameters")
            lineplot.relocate_bed()
            t1 = time.time()
            print2(parameter, "    --- finished in {0} secs".format(str(round(t1-t0))))
            
            if args.mp: print2(parameter, "\nStep 2/3: Calculating the coverage to all reads and averaging with multiprocessing ")
            else: print2(parameter, "\nStep 2/3: Calculating the coverage to all reads and averaging")
            lineplot.group_tags(groupby=args.g, sortby=args.s, colorby=args.c)
            lineplot.gen_cues()
            lineplot.coverage(sortby=args.s, mp=args.mp)
            t2 = time.time()
            print2(parameter, "    --- finished in {0} (H:M:S)".format(str(datetime.timedelta(seconds=round(t2-t1)))))
            
            # Plotting
            print2(parameter, "\nStep 3/3: Plotting the lineplots")
            lineplot.colormap(colorby = args.c, definedinEM = args.color)
            lineplot.plot(groupby=args.g, colorby=args.c, output=args.o, printtable=args.table, sy=args.sy, sx=args.sx)
            output(f=lineplot.fig, directory = args.o, folder = args.t, filename="lineplot",extra=plt.gci(),pdf=True,show=args.show)
            lineplot.gen_html(args.o, args.t)
            t3 = time.time()
            print2(parameter, "    --- finished in {0} secs".format(str(round(t3-t2))))
            print2(parameter, "\nTotal running time is : " + str(datetime.timedelta(seconds=round(t3-t0))) + "(H:M:S)\n")
            print("\nAll related files are saved in:  "+ os.path.join(dir,args.o,args.t))
            output_parameters(parameter, directory = args.o, folder = args.t, filename="parameters.txt")
            copy_em(em=args.input, directory=args.o, folder=args.t)
            list_all_index(path=args.o)

        ################### Heatmap ##########################################
        if args.mode=='heatmap':
            print("\n################# Heatmap #################")
            # Most part of heat map are the same as lineplot, so it share the same class as lineplot
            # Read experimental matrix
            t0 = time.time()
            if "reads" not in (args.g, args.c, args.s):
                print("Please add 'reads' tag as one of grouping, sorting, or coloring argument.")
                sys.exit(1)
            if "regions" not in (args.g, args.c, args.s):
                print("Please add 'regions' tag as one of grouping, sorting, or coloring argument.")
                sys.exit(1)
            print2(parameter, "Parameters:\tExtend length:\t"+str(args.e))
            print2(parameter, "\t\tRead size:\t"+str(args.rs))
            print2(parameter, "\t\tBin size:\t"+str(args.bs))
            print2(parameter, "\t\tStep size:\t"+str(args.ss))
            print2(parameter, "\t\tCenter mode:\t"+str(args.center+"\n"))
        
            lineplot = Lineplot(EMpath=args.input, title=args.t, annotation=args.ga, 
                                organism=args.organism, center=args.center, extend=args.e, rs=args.rs, 
                                bs=args.bs, ss=args.ss, df=False)
            # Processing the regions by given parameters
            print2(parameter, "Step 1/4: Processing regions by given parameters")
            lineplot.relocate_bed()
            t1 = time.time()
            print2(parameter, "    --- finished in {0} secs".format(str(round(t1-t0))))
            
            if args.mp: print2(parameter, "\nStep 2/4: Calculating the coverage to all reads and averaging with multiprocessing ")
            else: print2(parameter, "\nStep 2/4: Calculating the coverage to all reads and averaging")
            lineplot.group_tags(groupby=args.g, sortby=args.s, colorby=args.c)
            lineplot.gen_cues()
            lineplot.coverage(sortby=args.s, heatmap=True, logt=args.log, mp=args.mp)
            t2 = time.time()
            print2(parameter, "    --- finished in {0} (h:m:s)".format(str(datetime.timedelta(seconds=round(t2-t1)))))
            
            # Sorting 
            print2(parameter, "\nStep 3/4: Sorting the data for heatmap")
            lineplot.hmsort(sort=args.sort)
            t3 = time.time()
            print2(parameter, "    --- finished in {0} (h:m:s)".format(str(datetime.timedelta(seconds=round(t3-t2)))))
            
            # Plotting
            print2(parameter, "\nStep 4/4: Plotting the heatmap")
            lineplot.hmcmlist(colorby = args.c, definedinEM = args.color)
            lineplot.heatmap(args.log)
            for i, name in enumerate(lineplot.hmfiles):
                output(f=lineplot.figs[i], directory = args.o, folder = args.t, filename=name,pdf=True,show=args.show)
            lineplot.gen_htmlhm(args.o, args.t)
            t4 = time.time()
            print2(parameter, "    --- finished in {0} secs".format(str(round(t4-t3))))
            print2(parameter, "\nTotal running time is : " + str(datetime.timedelta(seconds=round(t4-t0))) + "(H:M:S)\n")
            print("\nAll related files are saved in:  "+ os.path.join(dir,args.o,args.t))
            output_parameters(parameter, directory = args.o, folder = args.t, filename="parameters.txt")
            copy_em(em=args.input, directory=args.o, folder=args.t)
            list_all_index(path=args.o)
Exemplo n.º 37
0
    def boxplot(self, dir, matrix, sig_region, truecounts, sig_boolean, ylabel, filename):
        """Generate the DBD box plot and save it as ``<filename>.png`` and ``<filename>.pdf``.

        One box per DNA binding domain (DBD) in ``self.rbss`` summarises the
        non-target counts, the target-region count of each DBD is drawn as a
        dot, and DBDs flagged in ``sig_boolean`` are shaded.

        :param dir: output directory for both files.
        :param matrix: array with ``max()``/``min()``/``transpose()`` (presumably a
            2-D numpy array); it is transposed before ``ax.boxplot`` so each row
            of ``matrix`` yields one box.
        :param sig_region: not used by this method.
        :param truecounts: one target-region count per DBD; must be a ``list``
            because it is concatenated with ``[matrix.max()]``.
        :param sig_boolean: per-DBD flags; truthy entries are shaded as significant.
        :param ylabel: y-axis label.
        :param filename: output file name without extension.
        """
        tick_size = 8
        label_size = 9

        f, ax = plt.subplots(1, 1, dpi=300, figsize=(6, 4))
        max_y = int(max([matrix.max()] + truecounts) * 1.1) + 1
        min_y = max(int(matrix.min() * 0.9) - 1, 0)

        # Significant DBDs. ``rect`` doubles as the legend handle, so one is
        # created up front even when no DBD is flagged.
        rect = patches.Rectangle(xy=(1, 0), width=0.8, height=max_y, facecolor=sig_color,
                                 edgecolor="none", alpha=0.5, lw=None, label="Significant DBD")
        for i, r in enumerate(sig_boolean):
            if r:
                # Boxes sit at x = 1..n, so i + 0.6 centres a 0.8-wide band on box i + 1.
                rect = patches.Rectangle(xy=(i + 0.6, min_y), width=0.8, height=max_y, facecolor=sig_color,
                                         edgecolor="none", alpha=0.5, lw=None, label="Significant DBD")
                ax.add_patch(rect)

        # Non-target distribution
        bp = ax.boxplot(matrix.transpose(), notch=False, sym='o', vert=True,
                        whis=1.5, positions=None, widths=None,
                        patch_artist=True, bootstrap=None)
        z = 10
        plt.setp(bp['boxes'], color=nontarget_color, alpha=1, edgecolor="none")
        plt.setp(bp['whiskers'], color='black', linestyle='-', linewidth=1, zorder=z, alpha=1)
        plt.setp(bp['fliers'], markerfacecolor='gray', color='white', alpha=0.3, markersize=1.8, zorder=z)
        plt.setp(bp['caps'], color='white', zorder=-1)
        plt.setp(bp['medians'], color='black', linewidth=1.5, zorder=z + 1)

        # Target regions, drawn above the boxes
        ax.plot(range(1, len(self.rbss) + 1), truecounts, markerfacecolor=target_color,
                marker='o', markersize=5, linestyle='None', markeredgecolor="white", zorder=z + 5)

        ax.set_xlabel(self.rna_name + " DNA Binding Domains", fontsize=label_size)
        ax.set_ylabel(ylabel, fontsize=label_size, rotation=90)

        ax.set_ylim([min_y, max_y])
        ax.yaxis.set_major_locator(MaxNLocator(integer=True))

        ax.set_xticklabels([dbd.str_rna(pa=False) for dbd in self.rbss], rotation=35,
                           ha="right", fontsize=tick_size)

        for spine in ['top', 'right']:
            ax.spines[spine].set_visible(False)
        # Booleans, not 'on'/'off' strings: any non-empty string is truthy on
        # current Matplotlib, which would switch the ticks back on.
        ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=True)
        ax.tick_params(axis='y', which='both', left=True, right=False, labelsize=tick_size)

        # Legend: invisible proxy lines stand in for the dots and the boxes.
        dot_legend, = ax.plot([1, 1], color=target_color, marker='o', markersize=5, markeredgecolor="white",
                              linestyle='None')
        bp_legend, = ax.plot([1, 1], color=nontarget_color, linewidth=6, alpha=1)

        ax.legend([dot_legend, bp_legend, rect], ["Target Regions", "Non-target regions", "Significant DBD"],
                  bbox_to_anchor=(0., 1.02, 1., .102), loc=2, mode="expand", borderaxespad=0.,
                  prop={'size': 9}, ncol=3, numpoints=1)
        bp_legend.set_visible(False)
        dot_legend.set_visible(False)

        # PNG. No bbox_extra_artists: plt.gci() is None for this figure (it holds
        # no image/mappable), and a bare artist would not be a valid sequence.
        f.savefig(os.path.join(dir, filename + ".png"), facecolor='w', edgecolor='w',
                  bbox_inches='tight', dpi=300)

        # PDF with larger fonts
        ax.tick_params(axis='both', labelsize=12)
        ax.xaxis.label.set_size(14)
        ax.yaxis.label.set_size(14)

        with PdfPages(os.path.join(dir, filename + '.pdf')) as pp:
            pp.savefig(f, bbox_inches='tight')
        # Release the figure; it is not returned, so nothing else can use it.
        plt.close(f)
# --- Spectrogram panel: top of a 3-row subplot layout ---
ax1 = plt.subplot(311)
plt.pcolor(t_spec, freqs, 10 * np.log10(spec_PSDperBin))  # power in dB re: 1 uV
plt.clim((25 - 5) + np.array([-40, 0]))  # colour range [-20, 20] dB
# Show the whole recording unless a time window with a non-zero end is given.
plt.xlim(t_sec[0], t_sec[-1])
if t_lim_sec[1] != 0:
    plt.xlim(t_lim_sec)
plt.ylim(f_lim_Hz)
plt.xlabel('Time (sec)')
plt.ylabel('Frequency (Hz)')
plt.title(fname[12:])


# Annotate the panel with the FFT parameters and the colour limits in effect.
cl = plt.gci().get_clim()
ax1.text(
    0.025,
    0.95,
    "\n".join([
        "NFFT = " + str(NFFT),
        "fs = " + str(int(fs_Hz)) + " Hz",
        "Clim = [" + str(cl[0]) + ", " + str(cl[1]) + "]",
    ]),
    transform=ax1.transAxes,
    horizontalalignment='left',
    verticalalignment='top',
    backgroundcolor='w',
    size='smaller',
)


# Keep only the spectra (columns) whose time stamp lies inside any alpha window.
bool_ind = np.zeros(t_spec.shape, dtype=bool)
for lim_sec in alpha_lim_sec:
    bool_ind |= (t_spec >= lim_sec[0]) & (t_spec <= lim_sec[1])
foo_spec = spec_PSDperBin[:, bool_ind]
Exemplo n.º 39
0
def _positive_runs(values, length):
    """Return half-open ``(start, stop)`` ranges of consecutive ``i < length`` with ``values[i] > 0``."""
    runs = []
    start = None
    for i in range(length):
        if values[i] > 0:
            if start is None:
                start = i
        elif start is not None:
            runs.append((start, i))
            start = None
    if start is not None:
        # Close a run that extends to the last position.
        runs.append((start, length))
    return runs


def lineplot(txp, rnalen, rnaname, dirp, sig_region, cut_off, log, ylabel, linelabel,
             filename, ac=None, showpa=False, exons=None):
    """Plot binding-site coverage along an RNA and save it as an image and a PDF.

    Every item of ``txp`` adds 1 to each position in
    ``[rd.rna.initial, rd.rna.final)``, split by ``rd.rna.orientation``
    ("P" = parallel, "A" = anti-parallel). Significant regions are shaded and,
    when ``ac`` is given, accessible stretches are hatched below zero.

    :param txp: binding results; must provide ``remove_duplicates_by_dbs()``
        (called first, presumably de-duplicating in place) and iterate over
        items with ``rna.orientation``, ``rna.initial`` and ``rna.final``.
    :param rnalen: RNA length, i.e. the number of plotted positions.
    :param rnaname: RNA name used in the x-axis label.
    :param dirp: output directory.
    :param sig_region: regions with ``initial`` and ``len()`` to shade.
    :param cut_off: threshold forwarded to ``read_ac``.
    :param log: if true, plot ``log10(count + 1)``.
    :param ylabel: y-axis label; ``"(log10)"`` is appended in log mode.
    :param linelabel: legend label of the total line when ``showpa`` is false.
    :param filename: raster output file name (used as given); the PDF uses
        the same stem with a ``.pdf`` extension.
    :param ac: optional accessibility input passed to ``read_ac``.
    :param showpa: if true, plot total, parallel and anti-parallel coverage.
    :param exons: currently ignored (the exon-boundary track is disabled).
    """
    f, ax = plt.subplots(1, 1, dpi=300, figsize=(6, 4))

    x = range(rnalen)
    # In log mode every position starts at a pseudocount of 1 so that an
    # uncovered position maps to log10(1) == 0 instead of -inf.
    base = 1 if log else 0
    all_y = [base] * rnalen
    p_y = [base] * rnalen
    a_y = [base] * rnalen

    txp.remove_duplicates_by_dbs()
    for rd in txp:
        if rd.rna.orientation == "P":
            strand_y = p_y
        elif rd.rna.orientation == "A":
            strand_y = a_y
        else:
            continue
        for i in range(rd.rna.initial, rd.rna.final):
            strand_y[i] += 1
            all_y[i] += 1

    # Uncovered positions are 0 in both modes (0 counts, or log10(1)), so the
    # axis starts at 0.
    min_y = 0
    if log:
        # log10 to match the "(log10)" axis label.
        all_y = numpy.log10(all_y)
        p_y = numpy.log10(p_y)
        a_y = numpy.log10(a_y)
        max_y = max(all_y) + 0.5
        ylabel += "(log10)"
    else:
        max_y = float(max(all_y) * 1.1)

    if ac:
        # Reserve room below zero for the accessibility hatching.
        min_y = float(max_y * (-0.09))

    # Significant regions
    for rbs in sig_region:
        rect = patches.Rectangle(xy=(rbs.initial, 0), width=len(rbs), height=max_y, facecolor=sig_color,
                                 edgecolor="none", alpha=0.5, lw=None, label="Significant DBD")
        ax.add_patch(rect)

    lw = 1.5
    if showpa:
        ax.plot(x, all_y, color=target_color, alpha=1, lw=lw, label="Parallel + Anti-parallel")
        ax.plot(x, p_y, color="purple", alpha=1, lw=lw, label="Parallel")
        ax.plot(x, a_y, color="dimgrey", alpha=.8, lw=lw, label="Anti-parallel")
    else:
        ax.plot(x, all_y, color="mediumblue", alpha=1, lw=lw, label=linelabel)

    # RNA accessibility: hatch every run of positive values between min_y and 0,
    # including a run that reaches the end of the RNA.
    if ac:
        n_value = read_ac(ac, cut_off, rnalen=rnalen)
        for start, stop in _positive_runs(n_value, rnalen):
            ax.add_patch(patches.Rectangle((start, min_y), stop - start, -min_y,
                                           hatch='///', fill=False, snap=False, linewidth=0,
                                           label="RNA accessibility"))

    # Legend: one entry per distinct label (first handle wins).
    handles, labels = ax.get_legend_handles_labels()
    legend_l = list(uniq(labels))
    legend_h = [handles[labels.index(label)] for label in legend_l]
    ax.legend(legend_h, legend_l,
              bbox_to_anchor=(0., 1.02, 1., .102), loc=2, mode="expand", borderaxespad=0.,
              prop={'size': 9}, ncol=3)

    # XY axis
    ax.set_xlim(left=0, right=rnalen)
    ax.set_ylim([min_y, max_y])
    ax.tick_params(axis='both', labelsize=9)
    ax.set_xlabel(rnaname + " sequence (bp)", fontsize=9)
    ax.set_ylabel(ylabel, fontsize=9, rotation=90)

    # Exon-boundary track: intentionally disabled, so ``exons`` is ignored.
    draw_exon_track = False
    if draw_exon_track and exons and len(exons) > 1:
        w = 0
        h = (max_y - min_y) * 0.02
        for i, exon in enumerate(exons):
            length = abs(exon[2] - exon[1])
            color = "moccasin" if i % 2 == 0 else "gold"
            ax.add_patch(patches.Rectangle((w, max_y - h), length, h, color=color))
            w += length
        ax.text(rnalen * 0.01, max_y - 2 * h, "exon boundaries", fontsize=5, color='black')

    f.tight_layout(pad=1.08, h_pad=None, w_pad=None)

    # Raster image. No bbox_extra_artists: plt.gci() is None for this figure
    # (no image/mappable), and a bare artist would not be a valid sequence.
    f.savefig(os.path.join(dirp, filename), facecolor='w', edgecolor='w',
              bbox_inches='tight', dpi=300)

    # PDF with larger fonts
    ax.tick_params(axis='both', labelsize=12)
    ax.xaxis.label.set_size(14)
    ax.yaxis.label.set_size(14)
    ax.legend(legend_h, legend_l,
              bbox_to_anchor=(0., 1.02, 1., .102), loc=2, mode="expand", borderaxespad=0.,
              prop={'size': 12}, ncol=3)
    with PdfPages(os.path.splitext(os.path.join(dirp, filename))[0] + '.pdf') as pp:
        pp.savefig(f, bbox_inches='tight')
    # Release the figure; it is not returned, so nothing else can use it.
    plt.close(f)
Exemplo n.º 40
0
	def handle(self, *args, **options):
		"""Advance the Wave Glider one step towards the crowd's current vote.

		In order, the method:

		1. Clusters the ``VoteWP`` positions cast in the last
		   ``pollWindowInSecs`` seconds with k-means (k = 2, 3, 4) and takes the
		   centre of the most populated cluster of the best-scoring clustering.
		2. Moves the latest ``WaveGliderState`` ``STEP_DIST_MILE`` along the
		   bearing towards that centre.
		3. Reads temperature and salinity at the new position from a remote
		   dataset (the values are currently not used further).
		4. Renders PNG overlays: chlorophyll and salinity tracks of the last
		   ``trackWindowInSecs`` seconds, a salinity colour bar, the synthetic
		   plume field (with its colour bar) and a position-only layer.
		5. Saves the new ``WaveGliderState`` and a ``BallotWP`` for the vote centre.

		:raises CommandError: if no usable vote is found in the poll window, or
		    if no ``WaveGliderState`` exists yet.
		"""
		# NOTE(review): geopy distances are constructed in kilometres by
		# default despite the _MILE name -- confirm the intended unit.
		STEP_DIST_MILE = 1.00662
		LAT_MIN, LAT_MAX = 36.69, 37.15981
		LON_MIN, LON_MAX = -122.56020, -121.79
		plotInTicks = 6
		pollWindowInSecs = 10
		trackWindowInSecs = 60 * 24
		hour = 0  # first index into the dataset variables (presumably time)
		vmin_, vmax_ = 0, 800  # colour limits of the chlorophyll layer and its bar

		def _strip_axes():
			"""Hide frame, ticks and axis lines of the current axes."""
			axes = plt.gcf().gca()
			axes.set_frame_on(False)
			axes.set_xticks([])
			axes.set_yticks([])
			plt.axis('off')

		def _save_overlay(path):
			"""Save the current figure as a tight, transparent PNG overlay."""
			plt.savefig(path, transparent=True, pad_inches=0, bbox_inches='tight', frameon=False)

		def _save_colorbar(fig, norm, path):
			"""Render a standalone horizontal colour bar for *norm* into *path*; clears *fig* before and after."""
			fig.clf()
			ax1 = fig.add_axes([0.05, 0.80, 0.9, 0.15])
			cb1 = mpl.colorbar.ColorbarBase(ax1, cmap=mpl.cm.jet, norm=norm, orientation='horizontal')
			cbytick_obj = plt.getp(cb1.ax.axes, 'xticklabels')
			plt.setp(cbytick_obj, color='w')
			plt.setp(cbytick_obj, fontsize=18)
			cb1.set_label('Score')
			_save_overlay(path)
			fig.clf()

		try:
			# --- 1. Cluster the recent votes ---
			endDate = dt.now()
			startDate = endDate - timedelta(0, pollWindowInSecs)
			lastFew = VoteWP.objects.filter(date__gt=startDate, date__lt=endDate)

			# (longitude, latitude) of every valid vote; latitudes below 1 are
			# treated as bogus "zero" positions. Building the array from valid
			# votes only keeps uninitialised rows out of k-means.
			points = []
			for vote_ in lastFew:
				if vote_.latitude < 1:
					print('zero lat detected')
					continue
				points.append((vote_.longitude, vote_.latitude))
			if not points:
				raise CommandError('No valid votes in the last %d seconds' % pollWindowInSecs)
			arr = numpy.array(points, dtype=float)

			scoreMax = -1
			centerMax = []
			for k in range(2, 5):
				res, idx = kmeans2(arr, k)
				_, d = vq(arr, res)
				ua, uind = numpy.unique(idx, return_inverse=True)
				count = numpy.bincount(uind)
				maxCountInd = numpy.argmax(count)
				maxCount = max(count)
				center = [res[ua[maxCountInd], 0], res[ua[maxCountInd], 1]]
				# Grows with the size of the largest cluster; shrinks with the total
				# quantisation distance and the summed pairwise centroid distance.
				score = pow(maxCount, 2) * (1 / numpy.sum(d)) * (1 / sum(scipy.spatial.distance.pdist(res, 'euclidean')))
				if score > scoreMax:
					scoreMax = score
					centerMax = center
			lonAvg = centerMax[0]
			latAvg = centerMax[1]

			# --- 2. Step the glider towards the vote centre ---
			latestState = WaveGliderState.objects.latest('time')
			bearing_ = self.bearing(math.radians(latestState.latitude), math.radians(latestState.longitude), math.radians(latAvg), math.radians(lonAvg))
			print('bearing=' + str(bearing_))
			destPoint = VincentyDistance(STEP_DIST_MILE).destination(Point(latestState.latitude, latestState.longitude), bearing_)
			print(destPoint)
			print('%s %s' % (destPoint.latitude, destPoint.longitude))
			newLat = destPoint.latitude
			newLon = destPoint.longitude
			# The safe-region check (self.testIfPointInSafeRegion) is disabled,
			# so no collision is ever reported.
			collided = 0.0

			# --- 3. Sample the remote dataset at the new position ---
			# Grid indices from origin (237.2 E, 35.6 N) with steps 1.3/131 and
			# 1.7/171 degrees; longitude is shifted into the 0..360 range.
			lon = 360.0 + newLon
			lonInd = int(math.floor((lon - 237.2) / (1.3 / 131)))
			latInd = int(math.floor((newLat - 35.6) / (1.7 / 171)))
			dataset = open_url("http://ourocean.jpl.nasa.gov:8080/thredds/dodsC/MBNowcast/mb_das_20120520_mean.nc")
			# Fetched but currently unused.
			tempVal = dataset['temp'].array[hour, 0, latInd, lonInd][0, 0, 0, 0]
			salVal = dataset['salt'].array[hour, 0, latInd, lonInd][0, 0, 0, 0]

			# --- 4. Render overlay layers ---
			endDate = dt.now()
			startDate = endDate - timedelta(0, trackWindowInSecs)
			lastFewUpdates = list(WaveGliderState.objects.filter(time__gt=startDate, time__lt=endDate))
			# One numpy.append onto [] per field: same float promotion as
			# element-wise appends, without the quadratic copying.
			lonVec = numpy.append([], [u.longitude for u in lastFewUpdates])
			latVec = numpy.append([], [u.latitude for u in lastFewUpdates])
			salVec = numpy.append([], [u.sal for u in lastFewUpdates])
			chlVec = numpy.append([], [u.chl for u in lastFewUpdates])

			# NOTE(review): relative save paths plus a copy from /home/jd assume
			# the command runs with /home/jd as working directory.
			plt.scatter(lonVec, latVec, 10, chlVec, marker='s', edgecolors='none')
			fig = plt.gcf()
			plt.xlim([LON_MIN, LON_MAX])
			plt.ylim([LAT_MIN, LAT_MAX])
			_strip_axes()
			plt.clim(vmin_, vmax_)
			_save_overlay('science_data_layer_chl.png')
			os.system('cp /home/jd/science_data_layer_chl.png /var/www/cinaps/')

			# Start from an empty figure so the chlorophyll markers are not left
			# underneath the salinity layer.
			fig.clf()
			plt.scatter(lonVec, latVec, 10, salVec, marker='s', edgecolors='none')
			plt.xlim([LON_MIN, LON_MAX])
			plt.ylim([LAT_MIN, LAT_MAX])
			_strip_axes()
			salvmin_, salvmax_ = plt.gci().get_clim()
			plt.clim(salvmin_, salvmax_)
			_save_overlay('science_data_layer_sal.png')
			os.system('cp /home/jd/science_data_layer_sal.png /var/www/cinaps/')
			_save_colorbar(fig, mpl.colors.Normalize(vmin=salvmin_, vmax=salvmax_), '/var/www/cinaps/jd/simoverlaycolorbarsal.png')

			# Synthetic plume field driven by the tab-separated points file
			# (field 1 = latitude, field 2 = longitude).
			with open('/home/jd/plumePoints.txt') as f:
				data = [line.rstrip('\n') for line in f]
			numPoints = len(data)
			print('start science field generation...')
			timestamp = calendar.timegm(dt.now().utctimetuple())
			# Current plume point: advances every 10 s and wraps around.
			tick = (timestamp // 10) % numPoints
			print(tick)
			ind1 = tick
			coord = data[ind1].split('\t')
			coord0 = data[0].split('\t')
			fieldLon = float(coord[2])
			fieldLat = float(coord[1])
			fieldLon2 = float(coord0[2])
			fieldLat2 = float(coord0[1])

			fig.clf()
			# Two Gaussians: one centred on point ind1 whose width grows with
			# ind1, one centred on the first point whose width shrinks with ind1.
			widthScaler1 = 0.01 / numPoints
			widthScaler2 = 0.05 / numPoints
			sig = 0.01
			sig1 = sig + ind1 * widthScaler1
			sig2 = sig + (numPoints - ind1) * widthScaler2
			deltaLat = 0.00909
			deltaLon = 0.0111
			# NOTE(review): the field is redrawn only when ind1 is NOT a multiple
			# of plotInTicks -- confirm this is the intended cadence.
			if ind1 % plotInTicks:
				x = numpy.arange(LON_MIN, LON_MAX, deltaLon)
				y = numpy.arange(LAT_MIN, LAT_MAX, deltaLat)
				X, Y = numpy.meshgrid(x, y)
				Z1 = mlab.bivariate_normal(X, Y, sig1, sig1, fieldLon, fieldLat)
				Z2 = mlab.bivariate_normal(X, Y, sig2, sig2, fieldLon2, fieldLat2)
				plt.pcolor(X, Y, Z1 + Z2, alpha=0.5, shading='interp')
				plt.xlim([LON_MIN, LON_MAX])
				plt.ylim([LAT_MIN, LAT_MAX])
				_strip_axes()
				_save_overlay('/var/www/cinaps/jd/simoverlay.png')
				_save_colorbar(fig, mpl.colors.Normalize(vmin=vmin_, vmax=vmax_), '/var/www/cinaps/jd/simoverlaycolorbarchl.png')
			chlVal = mlab.bivariate_normal(newLon, newLat, sig1, sig1, fieldLon, fieldLat) + mlab.bivariate_normal(newLon, newLat, sig2, sig2, fieldLon2, fieldLat2)
			print('end science field generation...')
			print(dt.now())

			# --- 5. Persist the new state and the ballot ---
			now_ = dt.now()
			# NOTE(review): the temp/sal columns receive the plume tick (+collision
			# flag) and the summed chlorophyll, not temperature/salinity -- confirm.
			newWgState = WaveGliderState(time=now_, latitude=newLat, longitude=newLon, speed=0, direction=bearing_, temp=tick + collided, sal=numpy.sum(chlVec) + chlVal, chl=chlVal)
			newWgState.save()
			# endDate here is the end of the track window computed in step 4.
			ballot = BallotWP(time=endDate, latitude=latAvg, longitude=lonAvg)
			ballot.save()

			# Position-only layer, rendered with its own map extent.
			plt.scatter(newWgState.longitude, newWgState.latitude, 1, 0, edgecolors='none')
			plt.xlim([-122.51918826855467, -121.75083194531248])
			plt.ylim([36.61883405115106, 36.97622643672235])
			_strip_axes()
			_save_overlay('science_data_layer_none.png')
			os.system('cp /home/jd/science_data_layer_none.png /var/www/cinaps/')

		except (VoteWP.DoesNotExist, WaveGliderState.DoesNotExist):
			# QuerySet.latest() raises the model's DoesNotExist on an empty table.
			raise CommandError('Error updating Wave Glider state')

		self.stdout.write('Succesfully updated Wave Glider state')
Exemplo n.º 41
0
# Load inversion data
gData = loadG('G_r512_k128_l4.h5', 1)

# Apply the pBASEX algorithm
out = pbasex(folded, gData, make_images=True, alpha=4.1e-5)

# Plot some results
plt.figure(figsize=(12,9))
for i, sample in enumerate(samples):
	plt.subplot(4,5,5*i+1)
	plt.imshow(raw[:,:,i])
	plt.xticks([])
	plt.yticks([])
	plt.ylabel(sample)
	clim = plt.gci().get_clim()
	plt.clim(0,clim[1])
	if i==0:
		plt.title('Raw Image')
	plt.subplot(4,5,5*i+2)
	plt.plot(out['E'], out['IE'][:,i], 'k')
	plt.gca().ticklabel_format(axis='y', style='sci', scilimits=(-2,2))
	if i==3:
		plt.xlabel('Energy (eV)')
	plt.gca().twinx()
	plt.plot(out['E'], out['betas'][:,:,i], '.', markersize=5, alpha=0.6)
	if i==0:
		plt.text(-3, 3.5, 'counts per eV', size='small')
		plt.text(12, 3.5, 'beta', size='small')
		plt.text(3.5, 3.25, 'I(E), ', color='black', size='large')
		plt.text(6, 3.25, 'B2', color=u'#1f77b4', size='large')
        coh_ave_plot = coh_ave_shfl_in_imag - coh_ave_shfl_in_cdtn
        psi_ave_plot = psi_ave_shfl_in_imag - psi_ave_shfl_in_cdtn
    elif plot_type == 'evok':
        coh_ave_plot = coh_ave_shfl_in_cdtn
        psi_ave_plot = psi_ave_shfl_in_cdtn

    h_fig, h_ax = plt.subplots(2,3, figsize=[8,6])
    h_ax = h_ax.ravel()
    for i in range(6):
        plt.axes(h_ax[i])
        pnp.SpectrogramPlot( coh_ave_plot[i,:,:], t_grid, freq, c_lim_style='diverge', t_lim=t_show)
        if i==0:
            plt.xlabel('t')
            plt.ylabel('freq')
    pnp.share_clim(h_ax)
    h_fig.colorbar(plt.gci(), ax=h_ax.tolist())
    plt.suptitle('coherence {}, {}'.format(region_pair_focus, plot_type))
    plt.savefig('./temp_figs/coherence_{}_all_session_{}_{}.png'.format(region_pair_focus, plot_type, block_type))
    plt.savefig('./temp_figs/coherence_{}_all_session_{}_{}.pdf'.format(region_pair_focus,plot_type, block_type))

    _, h_ax = plt.subplots(2,3, figsize=[8,6])
    h_ax = h_ax.ravel()
    for i in range(6):
        plt.axes(h_ax[i])
        pnp.SpectrogramPlot(psi_ave_plot[i,:,:], t_grid, freq, c_lim_style='diverge', t_lim=t_show)
        if i==0:
            plt.xlabel('t')
            plt.ylabel('freq')
    # plt.xlim([-0.1, 0.4])
    pnp.share_clim(h_ax)
    h_fig.colorbar(plt.gci(), ax=h_ax.tolist())
Exemplo n.º 43
0
def plot_mixture(mixture, i=0, j=1, center_style=dict(s=0.15),
                 cmap='nipy_spectral', cutoff=0.0, ellipse_style=dict(alpha=0.3),
                 solid_edge=True, visualize_weights=False):
    '''Plot the (Gaussian) components of the ``mixture`` density as
    one-sigma ellipses in the ``(i,j)`` plane.

    The ``mixture`` is read through its ``dim``, ``weights`` and
    ``components`` attributes; each component must provide a mean ``mu``
    and a covariance matrix ``sigma``. Everything is drawn into the current
    axes (``plt.gca()``).

    :param mixture:
        The mixture density to plot.

    :param i, j:
        Indices of the two dimensions spanning the plotting plane; ``i`` is
        plotted on the x-axis and ``j`` on the y-axis. They must differ.

    :param center_style:
        If a non-empty ``dict``, plot mean value with the style passed to ``scatter``.
        Only components that survive the ``cutoff`` get a marker.

    :param cmap:
        The color map (anything accepted by ``plt.get_cmap``) to which
        components are mapped in order to choose their face color. It is
        unaffected by the ``cutoff``. The meaning depends on ``visualize_weights``.

    :param cutoff:
        Ignore components whose weight is below ``cutoff``.

    :param ellipse_style:
        Passed on to define the properties of the ``Ellipse``. It is copied,
        so the ``dict`` passed in is never modified.

    :param solid_edge:
        Draw the edge of the ellipse as solid opaque line.

    :param visualize_weights:
        Colorize the components according to their weights if ``True``.
        One can call ``plt.colorbar()`` after this function; the bar spans the
        weight range ``[0, 1]`` and lets one read off the weights.
        If ``False``, coloring is based on the component index and the total number of components.
        This option makes it easier to track components by assigning them the same color in subsequent calls to this function.

    :raises AssertionError:
        On an invalid ``(i, j)`` choice, a one-dimensional mixture, or a
        component whose 2x2 covariance submatrix implies a correlation
        outside ``[-1, 1]`` or has a nonpositive eigenvalue.
    '''
    # imports inside the function because then "ImportError" is raised on
    # systems without 'matplotlib' only when 'plot_mixture' is called
    import numpy as np
    from matplotlib import pyplot as plt
    from matplotlib.patches import Ellipse

    assert i >= 0 and j >= 0, 'Invalid submatrix specification (%d, %d)' % (i, j)
    assert i != j, 'Identical dimension given: i=j=%d' % i
    assert mixture.dim >= 2, '1D plot not supported'

    # ``matplotlib.cm.get_cmap`` was removed in matplotlib 3.9; the pyplot
    # function takes the same argument and exists in old and new releases.
    cmap = plt.get_cmap(cmap)

    weights = np.asarray(mixture.weights)

    if visualize_weights:
        # colors according to weight; a colormap maps [0, 1] onto its colors
        colors = [cmap(w) for w in weights]
    else:
        # colors according to index
        colors = [cmap(c) for c in np.linspace(0, _max_color, len(mixture.components))]

    # Components to draw. Indices keep referring to the full mixture so a
    # component's color does not depend on ``cutoff``.
    mask = weights >= cutoff

    # component means, projected onto the (i, j) plane
    means = np.array([c.mu for c in mixture.components])
    x_values = means.T[i]
    y_values = means.T[j]

    ax = plt.gca()

    for k in np.flatnonzero(mask):
        cov = mixture.components[k].sigma
        submatrix = np.array([[cov[i, i], cov[i, j]],
                              [cov[j, i], cov[j, j]]])

        # for idea, check
        # 'Combining error ellipses' by John E. Davis
        correlation = cov[i, j] / np.sqrt(cov[i, i] * cov[j, j])
        assert abs(correlation) <= 1, 'Invalid component %d with correlation %g' % (k, correlation)

        ew = np.linalg.eigvalsh(submatrix)
        assert ew.min() > 0, 'Nonpositive eigenvalue in component %d: %s' % (k, ew)

        # rotation angle of major axis with x-axis
        if submatrix[0, 0] == submatrix[1, 1]:
            theta = np.sign(submatrix[0, 1]) * np.pi / 4.
        else:
            theta = 0.5 * np.arctan(2 * submatrix[0, 1] / (submatrix[1, 1] - submatrix[0, 0]))

        # Put the larger eigenvalue on the y'-axis, but swap orientation if the
        # x-variance is the larger one (then the rotated x'-axis is the major axis).
        minor, major = np.sqrt(ew.min()), np.sqrt(ew.max())
        if submatrix[0, 0] > submatrix[1, 1]:
            width, height = major, minor
        else:
            width, height = minor, major

        # change sign to rotate in right direction
        angle = -theta * 180 / np.pi

        center = (x_values[k], y_values[k])

        # copy keywords but override the face color
        ellipse_style_clone = dict(ellipse_style)
        ellipse_style_clone['facecolor'] = colors[k]

        # Ellipse needs full width/height, i.e. twice the one-sigma semi-axes
        ax.add_patch(Ellipse(xy=center, width=2*width, height=2*height, angle=angle,
                             **ellipse_style_clone))

        if solid_edge:
            ellipse_style_clone['facecolor'] = 'none'
            ellipse_style_clone['edgecolor'] = colors[k]
            ellipse_style_clone['alpha'] = 1
            ax.add_patch(Ellipse(xy=center, width=2*width, height=2*height, angle=angle,
                                 **ellipse_style_clone))

    centers = None
    if center_style:
        centers = plt.scatter(x_values[mask], y_values[mask], **center_style)

    if visualize_weights:
        # plt.colorbar() uses the "current image". Give it a mappable whose
        # color mapping agrees with the ellipse face colors computed above:
        # the raw weights pushed through ``cmap`` on a fixed [0, 1] scale.
        if centers is None:
            # No center markers were drawn. Use an empty collection rather
            # than plt.gci(), which may be None or an unrelated artist.
            centers = plt.scatter([], [])
        mappable = centers
        # one value per drawn center, in the same order as the scatter points
        mappable.set_array(weights[mask])
        mappable.set_cmap(cmap)
        mappable.set_clim(0, 1)