Пример #1
0
def setup_axes(myg, num):
    """Create a figure and a grid of axes whose layout follows the domain aspect ratio.

    Parameters
    ----------
    myg : grid object
        Anything exposing ``xmin``, ``xmax``, ``ymin`` and ``ymax``.
    num : int
        Number of panels needed.

    Returns
    -------
    f : matplotlib.figure.Figure
        Figure number 1.
    axes : AxesGrid
        Grid with at least ``num`` axes, each with its own colorbar axes.
    cbar_title : bool
        True only for the very wide layout, where the colorbars sit on top of
        each panel.
    """

    L_x = myg.xmax - myg.xmin
    L_y = myg.ymax - myg.ymin

    f = plt.figure(1)

    cbar_title = False

    # ``add_all`` is intentionally not passed: True was its default, and the
    # keyword was deprecated and later removed from AxesGrid in matplotlib.
    if L_x > 2 * L_y:
        # very wide domain: stack the panels in num rows, colorbars on top
        axes = AxesGrid(f,
                        111,
                        nrows_ncols=(num, 1),
                        share_all=True,
                        cbar_mode="each",
                        cbar_location="top",
                        cbar_pad="10%",
                        cbar_size="25%",
                        axes_pad=(0.25, 0.65),
                        label_mode="L")
        cbar_title = True

    elif L_y > 2 * L_x:
        # very tall domain: num columns side by side, e.g.  rho  |U|  p  e
        axes = AxesGrid(f,
                        111,
                        nrows_ncols=(1, num),
                        share_all=True,
                        cbar_mode="each",
                        cbar_location="right",
                        cbar_pad="10%",
                        cbar_size="25%",
                        axes_pad=(0.65, 0.25),
                        label_mode="L")

    else:
        # roughly square domain: 2-d grid of plots.  Round the row count up
        # so the grid always holds at least num panels (num=5 now gives 3x2;
        # plain floor division gave 2x2 and one panel had no axes).
        ny = int(math.sqrt(num))
        nx = (num + ny - 1) // ny

        axes = AxesGrid(f,
                        111,
                        nrows_ncols=(nx, ny),
                        share_all=True,
                        cbar_mode="each",
                        cbar_location="right",
                        cbar_pad="2%",
                        axes_pad=(0.65, 0.25),
                        label_mode="L")

    return f, axes, cbar_title
Пример #2
0
def demo_grid_with_single_cbar_log(fig):
    """
    Draw a 2x2 grid of log-scaled demo images sharing one colorbar, then show it.

    The demo data is shifted so its minimum becomes 0 before plotting.  Note
    that LogNorm treats non-positive values as invalid, so the minimum pixel
    itself is not painted.
    """
    grid_options = {
        "nrows_ncols": (2, 2),
        "axes_pad": 0.0,
        "share_all": True,
        "label_mode": "L",
        "cbar_location": "top",
        "cbar_mode": "single",
    }
    grid = AxesGrid(fig, 111, **grid_options)  # occupies the whole figure

    data, extent = get_demo_image()
    data -= np.min(data)

    # every panel gets its own LogNorm instance, as before
    images = [
        grid[panel].imshow(data,
                           extent=extent,
                           interpolation="nearest",
                           norm=LogNorm())
        for panel in range(4)
    ]
    grid.cbar_axes[0].colorbar(images[-1])

    for colorbar_ax in grid.cbar_axes:
        colorbar_ax.toggle_label(False)

    # share_all=True, so ticks set on the lower-left axes apply everywhere.
    lower_left = grid.axes_llc
    lower_left.set_xticks([-2, 0, 2])
    lower_left.set_yticks([-2, 0, 2])
    plt.show()
Пример #3
0
def img_pro(path):
    """Show a 4x4 mosaic of random sample images annotated with the model's top-3 guesses.

    For each of the first 16 entries of the module-level ``map_characters``, a
    random file in *path* whose name contains that entry is read with OpenCV,
    classified with the module-level ``model`` (resized to ``pic_size`` and
    divided by 255), and drawn with its actual label plus the three most
    probable predicted labels.

    Parameters
    ----------
    path : str
        Directory holding the images (a trailing separator is optional).
    """
    F = plt.figure(1, (15, 20))
    grid = AxesGrid(F, 111, nrows_ncols=(4, 4), axes_pad=0, label_mode='1')
    font = cv2.FONT_HERSHEY_SIMPLEX

    for i in range(16):
        char = map_characters[i]
        # os.path.join works whether or not *path* ends with a separator
        # (plain ``path + k`` silently produced a wrong path without one).
        candidates = [os.path.join(path, k) for k in os.listdir(path) if char in k]
        image = cv2.imread(np.random.choice(candidates))

        img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # The model is fed the BGR array from cv2.imread, exactly as before.
        pic = cv2.resize(image, (pic_size, pic_size)).astype('float32') / 255
        probs = model.predict(pic.reshape(1, pic_size, pic_size, 3))[0]
        actual = char.split('_')[0].title()

        # Rank class indices by probability and label each prediction with
        # *its own* class name; previously every line used map_characters[i]
        # (the actual class), so all three guesses showed the same name.
        top3 = sorted(enumerate(probs), key=lambda kv: kv[1], reverse=True)[:3]
        text = ['{:s} : {:.1f}%'.format(map_characters[k].split('_')[0].title(), 100 * v)
                for k, v in top3]

        img = cv2.resize(img, (352, 352))
        # white box behind the annotation text
        cv2.rectangle(img, (0, 260), (215, 352), (255, 255, 255), -1)
        cv2.putText(img, 'Actual : %s' % actual, (10, 280), font, 0.7, (0, 0, 0), 2, cv2.LINE_AA)
        for k, t in enumerate(text):
            cv2.putText(img, t, (10, 300 + k * 18), font, 0.65, (0, 0, 0), 2, cv2.LINE_AA)
        grid[i].imshow(img)
    plt.show()
def demo_right_cbar(fig):
    """
    Draw a 2x2 demo-image grid in position 122 of *fig* with one colorbar per row.

    Row 0 uses the "spring" colormap and row 1 "winter".  Each row's colorbar
    is attached once the row's second panel is drawn and is labelled 'Foo'.
    """
    grid = AxesGrid(fig, 122,  # similar to subplot(122)
                    nrows_ncols=(2, 2), axes_pad=0.10, label_mode="1",
                    share_all=True, cbar_location="right", cbar_mode="edge",
                    cbar_size="7%", cbar_pad="2%")
    data, extent = get_demo_image()
    row_cmaps = [plt.get_cmap("spring"), plt.get_cmap("winter")]

    for row, cmap in enumerate(row_cmaps):
        for col in range(2):
            image = grid[2 * row + col].imshow(data,
                                               extent=extent,
                                               interpolation="nearest",
                                               cmap=cmap)
        grid.cbar_axes[row].colorbar(image)

    for colorbar_ax in grid.cbar_axes:
        colorbar_ax.toggle_label(True)
        colorbar_ax.axis[colorbar_ax.orientation].set_label('Foo')

    # share_all=True, so ticks set on the lower-left axes apply everywhere.
    lower_left = grid.axes_llc
    lower_left.set_xticks([-2, 0, 2])
    lower_left.set_yticks([-2, 0, 2])
Пример #5
0
def plot_filter(image_path, layer_name, output_dir):
    """Save the activation maps of one VGG16 layer for one image as a PDF mosaic.

    The first ``side**2`` channels (``side = int(sqrt(channels))``) are drawn
    in a ``side x side`` grid without axis ticks; the figure is written to
    ``<output_dir>/<image stem>-<layer_name>.pdf``.

    Parameters
    ----------
    image_path : str or path-like
        Image fed to ``load_images``.
    layer_name : str
        Name of the VGG16 layer whose output is visualised.
    output_dir : str or path-like
        Existing directory for the PDF.
    """
    base_model = VGG16(weights='imagenet')
    x = load_images([image_path])
    # ``inputs``/``outputs`` are the documented keyword names of the Keras
    # functional Model; the legacy ``input``/``output`` spelling is rejected
    # by current Keras / tf.keras.
    model = Model(inputs=base_model.input,
                  outputs=base_model.get_layer(layer_name).output)
    layer_output = model.predict(x)
    # assumes a channels-last (batch, height, width, channels) output -- the
    # indexing below relies on it.
    side = int(layer_output.shape[-1]**0.5)

    fig = plt.figure()

    grid = AxesGrid(fig,
                    111,
                    nrows_ncols=(side, side),
                    axes_pad=0.0,
                    share_all=True)

    for i in range(side**2):
        im = grid[i].imshow(layer_output[0, :, :, i], interpolation="nearest")
    # NOTE(review): no cbar_mode is passed to AxesGrid, so its colorbar axes
    # may stay hidden; confirm whether a visible colorbar was intended.
    grid.cbar_axes[0].colorbar(im)

    for cax in grid.cbar_axes:
        cax.toggle_label(False)

    for ax in grid.axes_all:
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

    output_dir = Path(output_dir)
    fig_file = '{}-{}.pdf'.format(Path(image_path).stem, layer_name)
    plt.savefig(str(output_dir / fig_file))
Пример #6
0
def draw_grid(fig, data_array, x_min, x_max, y_min, y_max):
    """
    Draw *data_array* as a single image (1x1 AxesGrid) with one colorbar.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to draw into; the grid fills position 111.
    data_array : array-like
        Image data passed to ``imshow``.
    x_min, x_max, y_min, y_max : int or float
        Image extent.  Ticks are placed every 10 units starting at the
        minimum; the maximum itself is excluded.
    """
    import numpy as np

    grid = AxesGrid(fig, 111,  # the whole figure
                    nrows_ncols=(1, 1),
                    axes_pad=0.2,
                    share_all=True,
                    label_mode="L",
                    cbar_location="right",
                    cbar_mode="single",
                    )
    extent = (x_min, x_max, y_min, y_max)

    im = grid[0].imshow(data_array, extent=extent, interpolation="nearest")
    grid[0].set_aspect(2.0, adjustable='box')

    grid.cbar_axes[0].colorbar(im)

    for cax in grid.cbar_axes:
        cax.toggle_label(True)

    # This affects all axes as share_all = True.  np.arange yields the same
    # ticks as range() for integer limits and, unlike range(), also accepts
    # float limits.
    grid.axes_llc.set_xticks(np.arange(x_min, x_max, 10))
    grid.axes_llc.set_yticks(np.arange(y_min, y_max, 10))
Пример #7
0
def plot(D, i, flux_vol, xlabel, ylabel, name, sigma):
    """Project *flux_vol* along axis *i*, smooth it and save the map to *name*.

    The projection is Gaussian-filtered with width *sigma* in both directions
    and drawn transposed with ``origin='lower'`` in a 1x1 AxesGrid with a thin
    colorbar on the right.  The extent always comes from ``D.x2`` / ``D.x3``
    (NOTE(review): independent of *i* -- confirm that is intended).
    """
    projected = nd.gaussian_filter(flux_vol.sum(axis=i),
                                   sigma=(sigma, sigma),
                                   order=0)

    fig = plt.figure(figsize=(6.5, 6.0))
    grid = AxesGrid(fig, 111,
                    nrows_ncols=(1, 1), axes_pad=0.0, label_mode="1",
                    share_all=True, cbar_location="right", cbar_mode="edge",
                    cbar_size="2%", cbar_pad="0%")

    map_extent = [D.x2[0], D.x2[-1], D.x3[0], D.x3[-1]]
    mappable = grid[0].imshow(projected.T, origin='lower', extent=map_extent)
    grid.cbar_axes[0].colorbar(mappable)

    lower_left = grid.axes_llc
    lower_left.set_xlabel(xlabel, fontsize=20)
    lower_left.set_ylabel(ylabel, fontsize=20)
    fig.savefig(name)
Пример #8
0
def triple_plot(ds, fields, lfpps=None, **kwargs):
    """Draw up to three slice plots of *ds* side by side and return the figure.

    Parameters
    ----------
    ds
        Dataset passed to ``slice_plot``.
    fields : sequence
        Field keys, one per panel (the grid is 1x3).
    lfpps : dict, optional
        Per-field plot properties.  A field that is missing here, or mapped to
        an empty value, falls back to the module-level ``fpps``.
    **kwargs
        Forwarded to ``slice_plot``.
    """
    # None instead of a shared mutable ``{}`` default.
    if lfpps is None:
        lfpps = {}

    fig = plt.figure(figsize=(15, 5))

    grid = AxesGrid(fig, (0.075, 0.075, 0.85, 0.85),
                    nrows_ncols=(1, 3),
                    axes_pad=0.9,
                    share_all=True,
                    cbar_location="right",
                    cbar_mode="each",
                    cbar_size="3%",
                    cbar_pad="0%")

    ps = []
    for field in fields:
        fpp = lfpps.get(field, {}) or fpps.get(field, {})
        ps.append(slice_plot(ds, field, plotprops=fpp, **kwargs))

    # Re-home each plot onto its grid axes and colorbar axes, then redraw.
    for i, field in enumerate(fields):
        plot = ps[i].plots[field]
        plot.figure = fig
        plot.axes = grid[i].axes
        plot.cax = grid.cbar_axes[i]
        ps[i]._setup_plots()
    return fig
Пример #9
0
def plot_filters_single_channel(t):
    """Save a 25x25 mosaic of kernels sharing one colorbar to ``w_gate.png``.

    Parameters
    ----------
    t : array-like
        Indexed as ``t[row, col]`` for row, col in 0..24; each entry is drawn
        with ``imshow`` (presumably a 2-D kernel).  All panels share the
        global min/max of *t* and the "plasma" colormap.
    """
    nrows, ncols = 25, 25
    # A bare figure: AxesGrid creates its own axes.  The previous
    # plt.subplots(25, 25) left 625 stray axes underneath the grid.
    fig = plt.figure(figsize=(38, 38))
    grid = AxesGrid(fig, 111,
                    nrows_ncols=(nrows, ncols),
                    axes_pad=0.05,
                    cbar_mode='single',
                    cbar_location='right',
                    cbar_pad=0.1
                    )
    vmin = np.amin(t)
    vmax = np.amax(t)

    # Grid axes are numbered row by row, so panel idx shows t[idx // ncols, idx % ncols].
    for idx in range(nrows * ncols):
        row, col = divmod(idx, ncols)
        ax = grid[idx]
        pcm = ax.imshow(t[row, col], vmin=vmin, vmax=vmax, cmap="plasma")
        ax.axis('off')

    # cbar_mode='single': one shared colorbar, created once (it used to be
    # re-created for every one of the 625 panels).
    grid.cbar_axes[0].colorbar(pcm)
    fig.savefig("w_gate.png")
Пример #10
0
 def _create_figure(self, figsize=None, **kwargs):
     """Create a Figure holding one AxesGrid that fills position 111.

     Extra keyword arguments are forwarded unchanged to ``AxesGrid``.  The
     grid is stored on ``self.grid`` and registered in ``self.next_grid``
     with counter 0 (presumably the next free axes index -- confirm with
     callers).  Returns ``(fig, grid)``.
     """
     fig = Figure(figsize=figsize)
     new_grid = AxesGrid(fig, 111, **kwargs)
     self.grid = new_grid
     self.next_grid = {new_grid: 0}
     return fig, new_grid
Пример #11
0
def plotting_weights(save_folder, filename, mat, true_states=None, estimated_states=None, soz_ch_ids=None, sel_win_num=None):
    """Plot every window's weight matrix in one grid and save it as a PNG.

    Parameters
    ----------
    save_folder : str
        Output directory; created if it does not exist.
    filename : str
        The figure is written to ``<save_folder>/W_<filename>.png``.
    mat : sequence of arrays
        One weight matrix per window; each is reshaped to ``N x N`` with
        ``N = int(sqrt(size))`` (so each is presumably square -- confirm).
    true_states, estimated_states : sequences, optional
        When both are given each panel is titled with its estimated state;
        windows whose true state is non-zero get a larger ``--state--`` title.
    soz_ch_ids, sel_win_num : optional
        Accepted for compatibility; not used.

    Raises
    ------
    ValueError
        If *mat* is empty.
    """
    title_fontSize = 40
    num_wind = len(mat)
    if num_wind == 0:
        raise ValueError("mat must contain at least one weight matrix")
    plot_num_rows = int(np.ceil(num_wind**0.5))
    # integer ceiling division: enough columns for every window
    plot_num_cols = -(-num_wind // plot_num_rows)
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    fig = plt.figure(num=None, figsize=(60, 40), dpi=120)
    grid = AxesGrid(fig, 111,
                    nrows_ncols=(plot_num_rows, plot_num_cols),
                    axes_pad=1,
                    cbar_mode='single',
                    cbar_location='right',
                    cbar_pad=0.1)
    # zip stops at the last window; any surplus grid axes stay empty.
    for i, (ax, window) in enumerate(zip(grid, mat)):
        N = int(window.size**0.5)
        im = ax.imshow(np.reshape(window, (N, N)))
        if true_states is not None and estimated_states is not None:
            if true_states[i] != 0:
                ax.set_title('--' + str(estimated_states[i]) + '--', fontsize=title_fontSize + 5)
            else:
                ax.set_title(str(estimated_states[i]), fontsize=title_fontSize)
    # cbar_mode='single' gives one shared colorbar axes, so one colorbar
    # suffices.  Creating a second colorbar on it, as before, was redundant
    # and risked discarding the enlarged tick labels set here.
    cbar = grid.cbar_axes[0].colorbar(im)
    cbar.ax.tick_params(labelsize=30)
    # join, so the file lands inside save_folder even without a trailing '/'
    plt.savefig(os.path.join(save_folder, 'W_' + filename + '.png'))
Пример #12
0
def demo_grid_with_each_cbar(fig):
    """
    A grid of 2x2 images in position 143 of *fig*. Each image has its own
    colorbar on top (tick labels hidden).
    """

    # Build the grid on the *fig* argument; the previous code referred to a
    # global ``F`` and ignored the figure it was given.
    grid = AxesGrid(
        fig,
        143,  # similar to subplot(143)
        nrows_ncols=(2, 2),
        axes_pad=0.1,
        label_mode="1",
        share_all=True,
        cbar_location="top",
        cbar_mode="each",
        cbar_size="7%",
        cbar_pad="2%",
    )
    Z, extent = get_demo_image()
    for i in range(4):
        im = grid[i].imshow(Z, extent=extent, interpolation="nearest")
        grid.cbar_axes[i].colorbar(im)

    for cax in grid.cbar_axes:
        cax.toggle_label(False)

    # This affects all axes because we set share_all = True.
    grid.axes_llc.set_xticks([-2, 0, 2])
    grid.axes_llc.set_yticks([-2, 0, 2])
Пример #13
0
def double_plot(settings, ds):
    """Draw yt slice plots of *ds* into a 1x2 AxesGrid and save the result as PDF.

    One panel is drawn per entry of the module-level ``fns``.  Keys read from
    *settings*:

    - "option": "top_down" or "side_on"
    - "field", "center", "width", "font", "streamlines"
    - "lim": a (min, max) zlim pair
    - "save_name"
    - for "side_on" only: "L" and "north_vector"

    The last slice object is saved to ``plots/<save_name>.pdf``.

    Raises
    ------
    ValueError
        If ``settings["option"]`` is neither "top_down" nor "side_on"
        (previously this surfaced as an unbound-variable error).
    """
    option = settings["option"]
    if option not in ("top_down", "side_on"):
        raise ValueError("unknown settings['option']: %r" % (option,))

    fig = plt.figure()
    grid = AxesGrid(fig, (0.09, 0.09, 0.8, 0.8),
                    nrows_ncols=(1, 2),
                    axes_pad=0.05,
                    label_mode="L",
                    share_all=True,
                    cbar_location="right",
                    cbar_mode="single",
                    cbar_size="5%",
                    cbar_pad="1.2%")

    for i, fn in enumerate(fns):
        # NOTE(review): ``fn`` is never used -- every panel is drawn from the
        # same ``ds``; confirm this is intended.

        if option == "top_down":
            slc = yt.SlicePlot(ds,
                               'z',
                               settings["field"],
                               center=settings["center"],
                               width=settings["width"],
                               fontsize=settings["font"])
            if settings["streamlines"]:
                slc.annotate_streamlines('velocity_x',
                                         'velocity_y',
                                         density=1.5,
                                         factor=16,
                                         plot_args={
                                             'color': 'black',
                                             'linewidth': 0.25})

        else:  # "side_on"
            slc = yt.OffAxisSlicePlot(ds,
                                      settings["L"],
                                      settings["field"],
                                      center=settings["center"],
                                      north_vector=settings["north_vector"],
                                      width=settings["width"],
                                      fontsize=settings["font"])
            # Raw strings: identical text, but without the invalid "\ ",
            # "\m" and "\o" escape sequences Python warns about.
            slc.set_xlabel(r'x $\ (\mathrm{R}_{\odot})$')
            slc.set_ylabel(r'z $\ (\mathrm{R}_{\odot})$')
            if settings["streamlines"]:
                slc.annotate_streamlines('magnetic_field_x',
                                         'magnetic_field_z',
                                         density=1.5,
                                         factor=16,
                                         plot_args={
                                             'color': 'white',
                                             'linewidth': 0.75})

        slc.set_cmap(field=settings["field"], cmap='jet')
        slc.set_zlim(settings["field"], settings["lim"][0], settings["lim"][1])

        # Re-home the yt plot onto the grid axes, then redraw it there.
        plot = slc.plots[settings["field"]]
        plot.figure = fig
        plot.axes = grid[i].axes
        plot.cax = grid.cbar_axes[i]
        slc._setup_plots()

    slc.save("plots/" + settings["save_name"] + ".pdf")
Пример #14
0
def AX_perturbations(fig, perturbations):
    """
    A grid of 1x10 images representing the perturbations of the
    generated AXs - with a single coolwarm colorbar at the right.

    Each of the first ``nb_classes`` perturbations is reshaped to 28x28 and
    drawn on a fixed [-1, 1] coolwarm scale, with ticks and tick labels
    hidden; the colorbar shows ticks at -1, 0 and 1.
    """
    grid = AxesGrid(
        fig,
        211,  # similar to subplot(211)
        nrows_ncols=(1, 10),
        axes_pad=0.0,
        share_all=True,
        label_mode="1",
        cbar_location="right",
        cbar_mode="single",
        cbar_size="7%",
        cbar_pad="3%",
    )

    for i in range(nb_classes):
        img = perturbations[i].reshape([28, 28])
        im = grid[i].imshow(img,
                            interpolation="nearest",
                            cmap=cm.coolwarm,
                            vmin=-1.,
                            vmax=1.)
        # Real booleans: matplotlib deprecated and then dropped support for
        # the 'on'/'off' strings, and any non-empty string is truthy.
        grid[i].tick_params(which='both',
                            bottom=False,
                            left=False,
                            labelbottom=False,
                            labelleft=False)
    # The images already use vmin=-1, vmax=1, so the colorbar takes that range
    # from ``im``; a separate Normalize passed to colorbar is not needed.
    grid.cbar_axes[0].colorbar(im)
    grid.cbar_axes[0].set_yticks((-1, 0, 1))
Пример #15
0
def AX_actual(fig, adv_x, top_1, confidence, ylabel):
    """
    A grid of 1x10 images displaying the actual AXs along with
    their predicted labels and confidences.

    Each of the first ``nb_classes`` images is reshaped to 28x28 and drawn in
    gray, with ticks hidden.  Its x-label is ``"<top_1> (<confidence>)"``, where
    the confidence is truncated (ROUND_DOWN) to two decimals.  *ylabel* (a
    string) is shown on the right of the last panel.
    """
    grid = AxesGrid(
        fig,
        111,  # whole figure
        nrows_ncols=(1, 10),
        axes_pad=0.0,
        share_all=True,
        label_mode="all")
    for i in range(nb_classes):
        img = adv_x[i].reshape([28, 28])
        grid[i].imshow(img, cmap='gray')
        # Real booleans: matplotlib deprecated and then dropped support for
        # the 'on'/'off' strings, and any non-empty string is truthy.
        grid[i].tick_params(which='both',
                            bottom=False,
                            left=False,
                            labelbottom=False,
                            labelleft=False)
        # Decimal truncation, so e.g. 0.999 shows as 0.99 rather than 1.00.
        conf = str(
            Decimal(str(confidence[i])).quantize(Decimal('0.01'),
                                                 rounding=ROUND_DOWN))
        xlabel = str(top_1[i]) + " " + "(" + conf + ")"
        grid[i].set_xlabel(xlabel, labelpad=2.0, fontsize=12)
    grid[9].yaxis.set_label_position("right")
    # ylabel should be a string
    grid[9].set_ylabel(ylabel, labelpad=14.0, fontsize=15, rotation=270)
Пример #16
0
def plot_multiple_model_weights(weights_to_plot):
    """
    Show the weights of several models next to each other as heat maps.

    Every model's weights are first passed through ``rectangularfy``.  All
    heat maps share one "coolwarm" scale spanning the global min/max, plus a
    single colorbar on the right.
    """
    stacked = np.array([rectangularfy(w) for w in weights_to_plot])
    upper = stacked.max()
    lower = stacked.min()
    n_models = len(stacked)

    fig = plt.figure()
    grid = AxesGrid(fig, 111,
                    nrows_ncols=(1, n_models), axes_pad=0.05, share_all=True,
                    label_mode="L", cbar_location="right", cbar_mode="single")

    for column in range(n_models):
        panel = grid[column]
        heatmap = panel.imshow(stacked[column], cmap="coolwarm",
                               vmax=upper, vmin=lower)
        panel.set_xticks([])
        panel.set_yticks([])

    grid.cbar_axes[0].colorbar(heatmap)

    plt.show()
    def displayMaps(self, map_list, figFilename, n_rows=2):
        """Plot every map on one shared colour scale with a single colorbar and save it.

        Parameters
        ----------
        map_list : sequence of image arrays
            Maps to draw, one per panel.
        figFilename : str
            Destination passed to ``plt.savefig``.
        n_rows : int
            Number of grid rows; the column count is rounded *up* so every
            map gets a panel (rounding down made ``zip`` drop the last maps).
        """
        # single-argument print: same call in Python 2 and 3
        print('map_list:  %s' % (np.array(map_list).shape,))
        n_cols = int(np.ceil(len(map_list) / float(n_rows)))
        # NOTE(review): 200x200 inches is an extremely large canvas; confirm.
        fig = plt.figure(figsize=(200, 200))
        grid = AxesGrid(
            fig,
            111,
            nrows_ncols=(n_rows, n_cols),
            axes_pad=0.02,
            share_all=True,
            label_mode="L",
            cbar_location="right",
            # all maps use the same vmin/vmax, so one colorbar describes them
            # all ("each" left every other colorbar axes empty)
            cbar_mode="single",
        )
        vmin = min(np.min(cur_map) for cur_map in map_list)
        vmax = max(np.max(cur_map) for cur_map in map_list)
        print('DEBUG:vmax, vmin:  %s %s' % (vmax, vmin))
        if vmax == vmin:
            # avoid a degenerate colour scale for constant maps
            vmax += 1

        for cur_map, ax in zip(map_list, grid):
            im = ax.imshow(cur_map, vmin=vmin, vmax=vmax)
        grid.cbar_axes[0].colorbar(im)
        plt.savefig(figFilename)
Пример #18
0
def demo_grid_with_each_cbar_labelled(fig):
    """
    A grid of 2x2 images in position 144 of *fig*. Each image has its own
    colorbar, with a different colour range per panel and ticks only at the
    two ends of each range.
    """

    # Build the grid on the *fig* argument; the previous code referred to a
    # global ``F`` and ignored the figure it was given.
    grid = AxesGrid(
        fig,
        144,  # similar to subplot(144)
        nrows_ncols=(2, 2),
        axes_pad=(0.45, 0.15),
        label_mode="1",
        share_all=True,
        cbar_location="right",
        cbar_mode="each",
        cbar_size="7%",
        cbar_pad="2%",
    )
    Z, extent = get_demo_image()

    # Use a different colorbar range every time
    limits = ((0, 1), (-2, 2), (-1.7, 1.4), (-1.5, 1))
    for i in range(4):
        im = grid[i].imshow(Z,
                            extent=extent,
                            interpolation="nearest",
                            vmin=limits[i][0],
                            vmax=limits[i][1])
        grid.cbar_axes[i].colorbar(im)

    for i, cax in enumerate(grid.cbar_axes):
        cax.set_yticks((limits[i][0], limits[i][1]))

    # This affects all axes because we set share_all = True.
    grid.axes_llc.set_xticks([-2, 0, 2])
    grid.axes_llc.set_yticks([-2, 0, 2])
Пример #19
0
def quad_plot(ds, fields, width=(5000, "km")):
    """Arrange slice plots of *ds* in a 2x2 AxesGrid and return the figure.

    One ``slice_plot`` of the given *width* is made per entry of *fields*.
    Each is re-homed onto its grid axes and its own colorbar axes, then
    redrawn.
    """
    fig = plt.figure()

    grid = AxesGrid(fig, (0.075, 0.075, 0.85, 0.85),
                    nrows_ncols=(2, 2), axes_pad=0.9, share_all=True,
                    cbar_location="right", cbar_mode="each",
                    cbar_size="3%", cbar_pad="0%")

    slices = [slice_plot(ds, name, width=width) for name in fields]

    for idx, (name, slc) in enumerate(zip(fields, slices)):
        target = slc.plots[name]
        target.figure = fig
        target.axes = grid[idx].axes
        target.cax = grid.cbar_axes[idx]
        slc._setup_plots()
    return fig
Пример #20
0
def multiplot(field='O_p5_number_density'):
    """Save one 1x5 projection row per output DD0055..DD0454, comparing five runs.

    For every output number the same *field* is projected along z from five
    simulation runs.  The projection is centred on a position interpolated
    from the ``complete_track`` table at the snapshot redshift, and each row
    is written to ``DD####_multiplot_<field>_projection.png``.

    Parameters
    ----------
    field : str, optional
        Field to project.  The default is the field this script always used.
        When ``field == 'density'``, the density unit and limits below apply.
    """
    track = Table.read('complete_track', format='ascii')
    track.sort('col1')

    runs = ['nref10_track_2', 'nref10_track_lowfdbk_1', 'nref10_track_lowfdbk_2',
            'nref10_track_lowfdbk_3', 'nref10_track_lowfdbk_4']

    for n in range(55, 455):

        fig = plt.figure()
        grid = AxesGrid(fig, (0.5, 0.5, 1.5, 1.5),
                        nrows_ncols=(1, 5),
                        axes_pad=0.1,
                        label_mode="1",
                        share_all=True,
                        cbar_location="right",
                        cbar_mode="edge",
                        cbar_size="5%",
                        cbar_pad="0%")

        # zero-padded output name: DD0055 ... DD0454
        strset = 'DD%04d' % n
        snaps = [run + '/' + strset + '/' + strset for run in runs]

        for i, snap in enumerate(snaps):

            ds = yt.load(snap)
            zsnap = ds.get_parameter('CosmologyCurrentRedshift')
            trident.add_ion_fields(ds, ions=['C IV', 'O VI', 'H I', 'Si III'])

            centerx = np.interp(zsnap, track['col1'], 0.5 * (track['col2'] + track['col5']))
            centery = np.interp(zsnap, track['col1'], track['col3'] + 30. / 143886.)
            centerz = np.interp(zsnap, track['col1'], 0.5 * (track['col4'] + track['col7']))
            center = [centerx, centery, centerz]

            box = ds.r[center[0] - 200. / 143886:center[0] + 200. / 143886,
                       center[1] - 250. / 143886.:center[1] + 250. / 143886.,
                       center[2] - 40. / 143886.:center[2] + 40. / 143886.]

            # projection
            p = yt.ProjectionPlot(ds, 'z', field, center=center,
                                  width=((120, 'kpc'), (240, 'kpc')), data_source=box)
            if field == 'density':
                p.set_unit('density', 'Msun / pc**2')
                p.set_zlim('density', 0.01, 1000)
            if 'O_p5' in field:
                p.set_zlim(field, 1e11, 1e15)
            if i < 1:
                p.annotate_timestamp(corner='upper_left', redshift=True, draw_inset_box=True,
                                     text_args={'color': 'white', 'size': 'small'})

            # This forces the ProjectionPlot to redraw itself on the AxesGrid axes.
            plot = p.plots[field]
            plot.figure = fig
            plot.axes = grid[i].axes
            p._setup_plots()   # Finally, redraw the plot.

        fig.savefig(strset + '_multiplot_' + field + '_projection.png', bbox_inches='tight')
        # Release the figure: otherwise all 400 figures stay open and memory
        # keeps growing for the whole run.
        plt.close(fig)
Пример #21
0
def demo_grid_with_single_cbar(fig):
    """
    Draw a 2x2 demo-image grid in position 142 of *fig* with one shared
    colorbar on top (tick labels hidden).
    """
    grid_options = {
        "nrows_ncols": (2, 2),
        "axes_pad": 0.0,
        "share_all": True,
        "label_mode": "L",
        "cbar_location": "top",
        "cbar_mode": "single",
    }
    grid = AxesGrid(fig, 142, **grid_options)  # similar to subplot(142)

    data, extent = get_demo_image()
    images = [
        grid[panel].imshow(data, extent=extent, interpolation="nearest")
        for panel in range(4)
    ]
    grid.cbar_axes[0].colorbar(images[-1])

    for colorbar_ax in grid.cbar_axes:
        colorbar_ax.toggle_label(False)

    # share_all=True, so ticks set on the lower-left axes apply everywhere.
    lower_left = grid.axes_llc
    lower_left.set_xticks([-2, 0, 2])
    lower_left.set_yticks([-2, 0, 2])
Пример #22
0
    def displayMaps(self, map_list, figFilename=None, n_rows=1, title=''):
        """Draw all maps on one shared colour scale with a single colorbar.

        Parameters
        ----------
        map_list : sequence of image arrays
            Maps to draw, one per panel.
        figFilename : str, optional
            If given, the figure is saved there; otherwise it is shown
            without blocking.
        n_rows : int
            Number of grid rows; the column count is rounded up so every map
            gets its own panel.
        title : str
            Optional figure title (previously accepted but ignored).
        """
        fig = plt.figure()
        if title:
            fig.suptitle(title)
        # Round up: np.round(5 / 2) is 2.0 (round half to even), so five maps
        # on two rows got a 2x2 grid and zip() below silently dropped one.
        n_cols = int(np.ceil(len(map_list) / float(n_rows)))
        grid = AxesGrid(
            fig,
            111,
            nrows_ncols=(n_rows, n_cols),
            axes_pad=0.01,
            share_all=True,
            label_mode="L",
            cbar_location="right",
            cbar_mode="single",
        )
        vmin = min(np.min(cur_map) for cur_map in map_list)
        vmax = max(np.max(cur_map) for cur_map in map_list)
        if vmax == vmin:
            # avoid a degenerate colour scale for constant maps
            vmax += 1

        for cur_map, ax in zip(map_list, grid):
            im = ax.imshow(cur_map, vmin=vmin, vmax=vmax)
        grid.cbar_axes[0].colorbar(im)

        if figFilename is not None:
            plt.savefig(figFilename)
        else:
            plt.show(block=False)
Пример #23
0
def heatmap_in_one_figure(vals, pars, cmap=None):
    """
    Draw one pcolor heat map per entry of *vals* in a single figure.

    vals: lists of dataframes, each turned into x/y/z coordinates via ``plot3D``
    pars: (nrow, ncol) of plots in the figure
    cmap: colormap for every heat map; the colour scale is fixed to [0, 1]
          and shared through a single colorbar on the right.

    Returns (fig, ax), where ax is the last axes that was drawn into.
    """

    from mpl_toolkits.axes_grid1 import AxesGrid

    layout = dict(nrows_ncols=pars, axes_pad=0.05, share_all=True,
                  label_mode="L", cbar_location="right", cbar_mode="single")
    fig = plt.figure()
    grid = AxesGrid(fig, 111, **layout)

    for frame, ax in zip(vals, grid):
        coords = plot3D(frame)
        hm = ax.pcolor(coords.x, coords.y, coords.z,
                       vmin=0, vmax=1, cmap=cmap)

    grid.cbar_axes[0].colorbar(hm)

    return fig, ax
def demo_bottom_cbar(fig):
    """
    Draw a 2x2 demo-image grid in position 121 of *fig* with one colorbar per column.

    Axes are numbered column-first (``direction="column"``).  Panels 0-1 form
    the first column ("autumn") and panels 2-3 the second ("summer").  Each
    column's bottom colorbar is attached once its second panel is drawn and
    is labelled "Bar".
    """
    grid = AxesGrid(fig, 121,  # similar to subplot(121)
                    nrows_ncols=(2, 2), axes_pad=0.10, share_all=True,
                    label_mode="1", cbar_location="bottom", cbar_mode="edge",
                    cbar_pad=0.25, cbar_size="15%", direction="column")

    data, extent = get_demo_image()
    column_cmaps = [plt.get_cmap("autumn"), plt.get_cmap("summer")]

    for col, cmap in enumerate(column_cmaps):
        for row in range(2):
            image = grid[2 * col + row].imshow(data,
                                               extent=extent,
                                               interpolation="nearest",
                                               cmap=cmap)
        grid.cbar_axes[col].colorbar(image)

    for colorbar_ax in grid.cbar_axes:
        colorbar_ax.toggle_label(True)
        colorbar_ax.axis[colorbar_ax.orientation].set_label("Bar")

    # share_all=True, so ticks set on the lower-left axes apply everywhere.
    lower_left = grid.axes_llc
    lower_left.set_xticks([-2, 0, 2])
    lower_left.set_yticks([-2, 0, 2])
Пример #25
0
 def displayMaps(self, map_list, figFilename, n_rows = 1):
     """Plot all map-pair difference maps on one colour scale and save the figure.

     Parameters
     ----------
     map_list : sequence of image arrays
         Maps to draw, one per panel.
     figFilename : str
         Destination passed to ``plt.savefig``.
     n_rows : int
         Number of grid rows; the column count is rounded up so every map
         gets its own panel.
     """
     fig = plt.figure()
     fig.suptitle("Maximum %-diff in input map-pairs", fontsize=16)
     # Round up: rounding (or integer division) could yield too few columns,
     # and zip() below then silently dropped the trailing maps.
     nrows_ncols = (n_rows, int(np.ceil(len(map_list) / float(n_rows))))
     grid = AxesGrid(fig, 111,
                     nrows_ncols= nrows_ncols,
                     axes_pad=0.1,
                     share_all=True,
                     label_mode="L",
                     cbar_location="right",
                     cbar_mode="single",
                     )
     vmin =  float('inf')
     vmax = -float('inf')
     for cur_map in map_list:
         vmin = min(vmin, np.min(cur_map))
         vmax = max(vmax, np.max(cur_map))
     if vmax == vmin:
         # avoid a degenerate colour scale for constant maps
         vmax += 1
     for cur_map, ax in zip(map_list, grid):
         im = ax.imshow(cur_map, vmin=vmin, vmax=vmax)
     grid.cbar_axes[0].colorbar(im)
     plt.savefig(figFilename)
Пример #26
0
    def create_maps(self):
        """Open a SimpleFigure holding a map_row x map_col AxesGrid of hidden axes.

        The figure size is derived from the channel-map geometry and
        ``self.mm_per_pixel``.  The grid has one shared colorbar on the right
        and padding of 1.5 text heights (11 px at the figure dpi).  The grid,
        a running axes index (0) and an empty colorbar slot are stored on
        ``self``.
        """
        geom = self.chan_map.geometry
        # this is an over-estimate of inches per pixel, maybe
        # find another rule
        inches_per_px = 0.039 * self.mm_per_pixel
        # single array-map size in inches: width from geom[1], height from geom[0]
        map_w = geom[1] * inches_per_px
        map_h = geom[0] * inches_per_px
        figwin = SimpleFigure(figsize=(self.map_col * map_w, self.map_row * map_h))
        fig = figwin.figure

        text_height = 11 / fig.dpi

        grid = AxesGrid(fig, 111,
                        nrows_ncols=(self.map_row, self.map_col),
                        axes_pad=1.5 * text_height,
                        cbar_mode='single', cbar_location='right',
                        cbar_pad='2%', cbar_size='4%')

        for panel in grid:
            panel.axis('off')
        self._grid = grid
        self._g_idx = 0
        self._cbar = None
        # hold onto this or it disappears
        figwin.show()
        self._figwin = figwin
Пример #27
0
def plotImgMosaic(imgOrig, extent, origin='upper', T=0):
    """Plot six images from *imgOrig* in dB (20*log10|.|) in a 3x2 AxesGrid.

    Panels, row by row, show the keys "21", "41" / "31", "42" / "32", "43",
    each labelled with its key.  The right-hand column has its y ticks and
    label moved to the right.  Every array is squeezed first (and transposed
    when *T* is truthy).  One colorbar on top uses the last image.

    Returns (fig, grid).
    """
    fig = plt.figure()
    grid = AxesGrid(fig, 111,  # similar to subplot(111)
                    nrows_ncols=(3, 2), axes_pad=0.0, share_all=True,
                    label_mode="all", cbar_location="top", cbar_mode="single")

    images = {}
    for key, arr in dict(imgOrig).items():
        arr = np.squeeze(arr)
        images[key] = arr.T if T else arr

    panel_keys = ("21", "41", "31", "42", "32", "43")
    for panel, key in enumerate(panel_keys):
        ax = grid[panel]
        in_db = 20 * np.log10(np.abs(images[key]))
        mappable = ax.imshow(in_db, cmap='jet', extent=extent, origin=origin)
        ax.set_ylabel(key)
        if panel % 2 == 1:  # right-hand column
            ax.yaxis.set_ticks_position('right')
            ax.yaxis.set_label_position('right')

    grid.cbar_axes[0].colorbar(mappable)

    return fig, grid
Пример #28
0
def plot_activation_gradients(all_timesteps_activations: np.array,
                              neuron_heatmap_size: tuple,
                              show_title=True,
                              absolute=True,
                              save=None):
    """
    Plot the changes in activation values between time steps as heat maps for one single sample.

    One panel is drawn per consecutive pair of time steps (``t -> t+1``);
    all panels share a single colorbar.

    :param all_timesteps_activations: Sequence of per-time-step activations of
        a single sample; every element must be a ``torch.Tensor`` or an
        ``np.ndarray``.  At least two time steps are required.
    :param neuron_heatmap_size: Shape each activation difference is reshaped
        into for display.
    :param show_title: Add a figure title when True.
    :param absolute: When True, plot the magnitude of the change (colour range
        [0, 2], "Reds" colormap); otherwise plot the signed change (colour
        range [-2, 2], "coolwarm" colormap).
    :param save: File path to save the figure to; when None the figure is
        shown interactively instead.
    :raises TypeError: If an element is neither a tensor nor an ndarray.
    :raises ValueError: If fewer than two time steps are given.
    """
    if not all(isinstance(out, (torch.Tensor, np.ndarray))
               for out in all_timesteps_activations):
        raise TypeError(
            "This function only takes all the activations for all the time steps of a single sample.")

    num_timesteps = len(all_timesteps_activations)
    if num_timesteps < 2:
        # Differences need at least two time steps; fewer would create an
        # empty grid and leave nothing to attach the colorbar to.
        raise ValueError(
            "At least two time steps are required, got {}.".format(num_timesteps))

    # Colour settings are the same for every panel.
    if absolute:
        vmin, vmax, colormap = 0, 2, "Reds"
    else:
        vmin, vmax, colormap = -2, 2, "coolwarm"

    fig = plt.figure()

    grid = AxesGrid(
        fig,
        111,
        nrows_ncols=(1, num_timesteps - 1),
        axes_pad=0.05,
        share_all=True,
        label_mode="L",
        cbar_location="right",
        cbar_mode="single",
    )

    heatmap = None
    last_activations = all_timesteps_activations[0]
    for t, (axis, current_activations) in enumerate(
            zip(grid, all_timesteps_activations[1:])):
        activation_gradients = current_activations - last_activations
        if absolute:
            # Without this, negative changes would be clipped to vmin=0 and
            # rendered as "no change".
            activation_gradients = abs(activation_gradients)

        heatmap = axis.imshow(
            activation_gradients.reshape(*neuron_heatmap_size),
            cmap=colormap,
            vmin=vmin,
            vmax=vmax)
        axis.set_xlabel("t={} -> t={}".format(t, t + 1))
        axis.set_xticks([])
        axis.set_yticks([])

        last_activations = current_activations

    grid.cbar_axes[0].colorbar(heatmap)

    if show_title:
        fig.suptitle("Activation value gradients over {} time steps".format(
            num_timesteps))

    if save is None:
        plt.show()
    else:
        plt.savefig(save, bbox_inches="tight")
        plt.close()
    def accumulate_patches_into_heatmaps(self,
                                         all_test_output,
                                         outpath_prefix=''):
        """Render per-class heatmaps of patch outputs next to the test image.

        Every patch ``n`` (bounds from ``self.nth_patch(n)``) adds
        ``all_test_output[n, klass]`` to each pixel it covers, and each pixel
        is divided by the number of patches covering it (an average).  The
        figure shows the original image (subplot 143) beside a 2x2 grid of
        class heatmaps with colorbars (subplot 144).  It is saved to
        ``plots/<outpath_prefix>_<test image basename>.png`` and then closed.

        :param all_test_output: array indexed ``[patch, class]``; presumably
            per-class scores in [0, 1] given the fixed colour limits -- confirm.
        :param outpath_prefix: string prepended to the output file name.
        """
        outpath = "plots/%s_%s.png" % (
            outpath_prefix, path.splitext(path.basename(
                self.test_imagepath))[0])
        # http://matplotlib.org/examples/axes_grid/demo_axes_grid.html
        fig = plt.figure()
        image_grid = AxesGrid(
            fig,
            143,  # similar to subplot(143)
            nrows_ncols=(1, 1))
        orig_img = imread(self.test_imagepath + '.png')
        image_grid[0].imshow(orig_img)
        grid = AxesGrid(
            fig,
            144,  # similar to subplot(144)
            nrows_ncols=(2, 2),
            axes_pad=0.15,
            label_mode="1",
            share_all=True,
            cbar_location="right",
            cbar_mode="each",
            cbar_size="7%",
            cbar_pad="2%",
        )

        heatmap_shape = self.ds.image_shape[:2]
        # Patch bounds and per-pixel coverage counts do not depend on the
        # class, so compute them once instead of once per class.
        patch_bounds = [self.nth_patch(n)
                        for n in range(self.num_patch_centers)]
        normalizer = numpy.zeros(heatmap_shape)
        for i_start, i_end, j_start, j_end in patch_bounds:
            normalizer[i_start:i_end, j_start:j_end] += 1

        for klass in range(all_test_output.shape[1]):
            accumulator = numpy.zeros(heatmap_shape)
            for n, (i_start, i_end, j_start, j_end) in enumerate(patch_bounds):
                accumulator[i_start:i_end,
                            j_start:j_end] += all_test_output[n, klass]
            # Pixels covered by no patch yield 0/0 = NaN (rendered blank).
            normalized_img = accumulator / normalizer
            im = grid[klass].imshow(normalized_img,
                                    interpolation="nearest",
                                    vmin=0,
                                    vmax=1)
            grid.cbar_axes[klass].colorbar(im)
        grid.axes_llc.set_xticks([])
        grid.axes_llc.set_yticks([])
        print("Saving figure as: %s" % outpath)
        plt.savefig(outpath, dpi=600, bbox_inches='tight')
        # Release the figure; otherwise every call leaks an open figure.
        plt.close(fig)
Пример #30
0
def plot_pathological_imgs(names=None, outpath='out.png'):
    """Show example images side by side without axes and save the figure.

    :param names: image file paths, one panel each; defaults to the four
        hard-coded pathological examples.
    :param outpath: file the figure is written to (at 300 dpi).
    """
    if names is None:
        names = ['23050_right.png', '2468_left.png', '15450_left.png', '406_left.png']
    fig = plt.figure()
    grid = AxesGrid(fig, 111, nrows_ncols=(1, len(names)))
    for ax, name in zip(grid, names):
        ax.imshow(imread(name))
        # plt.axis('off') would only hide the current (last) axes; hide the
        # frame of every panel instead.
        ax.axis('off')
    fig.savefig(outpath, dpi=300)
Пример #31
0
def main(args) :
    """Animate corner (storm-cell) data from several input files side by side.

    One ``AxesGrid`` panel is drawn per entry of ``args.inputDataFiles`` (via
    ``MakeCornerPlots``), optionally over a map and/or radar imagery.  Polygon
    files from ``args.polys`` are overlaid as per-panel ``ArtistAnimation``s.
    The result is saved to ``args.saveImgFile`` and/or shown, as requested.

    :raises ValueError: if the number of titles does not match the number of
        input files.
    """
    import os.path
    from matplotlib.animation import ArtistAnimation

    if args.trackTitles is None :
        args.trackTitles = [os.path.dirname(filename) for
                            filename in args.inputDataFiles]

    if len(args.inputDataFiles) == 0 :
        print("WARNING: No corner control files given!")

    if len(args.trackTitles) != len(args.inputDataFiles) :
        raise ValueError("The number of TITLEs does not match the number"
                         " of INPUTFILEs.")

    if args.statName is not None and args.statLonLat is None :
        statData = ByName(args.statName)[0]
        args.statLonLat = (statData['LON'], statData['LAT'])

    if args.layout is None :
        args.layout = (1, len(args.inputDataFiles))

    if args.figsize is None :
        args.figsize = plt.figaspect(float(args.layout[0]) / args.layout[1])

    if args.simTagFiles is None :
        args.simTagFiles = []

    # No polygon files given is treated the same as an empty list.
    polyfiles = args.polys if args.polys is not None else []

    cornerVolumes = [ReadCorners(inFileName,
                                 os.path.dirname(inFileName))['volume_data']
                     for inFileName in args.inputDataFiles]

    polyData = [_load_verts(f, list(vol['stormCells'] for vol in vols)) for
                f, vols in zip(polyfiles, cornerVolumes)]

    multiTags = [ReadSimTagFile(fname) for fname in args.simTagFiles]

    if len(multiTags) == 0 :
        multiTags = [None]

    if len(multiTags) < len(cornerVolumes) :
        # Rudimentary broadcasting
        tagMult = max(int(len(cornerVolumes) // len(multiTags)), 1)
        multiTags = multiTags * tagMult

    if args.statLonLat is not None :
        for vols in cornerVolumes :
            for vol in vols :
                CoordinateTransform(vol['stormCells'],
                                    args.statLonLat[0],
                                    args.statLonLat[1])
        for verts in polyData:
            CoordinateTransform(verts,
                                args.statLonLat[0],
                                args.statLonLat[1])

    showMap = (args.statLonLat is not None and args.displayMap)
    showRadar = (args.statLonLat is not None and args.radarFile is not None)

    theFig = plt.figure(figsize=args.figsize)
    grid = AxesGrid(theFig, 111, nrows_ncols=args.layout,
                            share_all=True, axes_pad=0.32)

    theAnim, radAnim = MakeCornerPlots(theFig, grid, cornerVolumes,
                                       args.trackTitles, showMap, showRadar,
                                       tail=args.tail,
                                       startFrame=args.startFrame,
                                       endFrame=args.endFrame,
                                       radarFiles=args.radarFile,
                                       fade=args.fade,
                                       multiTags=multiTags,
                                       tag_filters=args.filters)

    # Follow the radar animation's timer when there is one, so the polygon
    # overlays stay in step with the radar frames.
    theTimer = radAnim.event_source if radAnim is not None else None

    # Slice bounds for the polygon frames; an unset start/end frame means
    # "from the first" / "through the last" frame.
    polyStart = args.startFrame
    polyEnd = None if args.endFrame is None else args.endFrame + 1

    polyAnims = []
    for ax, verts in zip(grid, polyData):
        polyAnim = ArtistAnimation(theFig,
                                   _to_polygons(verts[polyStart:polyEnd], ax),
                                   event_source=theTimer)
        polyAnims.append(polyAnim)

    if args.xlims is not None and np.prod(grid.get_geometry()) > 0 :
        grid[0].set_xlim(args.xlims)

    if args.ylims is not None and np.prod(grid.get_geometry()) > 0 :
        grid[0].set_ylim(args.ylims)

    if args.saveImgFile is not None :
        # radAnim may be None; build a list so it can be joined with polyAnims.
        extraAnims = [radAnim] if radAnim is not None else []
        theAnim.save(args.saveImgFile, extra_anim=extraAnims + polyAnims)

    if args.doShow :
        plt.show()
Пример #32
0
def main(args) :
    """Animate corner (storm-cell) data for a simulation and/or input files.

    Input files come from the simulation named by ``args.simName`` (read from
    its ``simParams.conf``) plus ``args.inputDataFiles``.  Each file gets one
    ``AxesGrid`` panel whose corners are animated by a ``CornerAnimation``,
    optionally filtered by sim-tag filters and drawn over a map and/or radar
    imagery.  The animation is saved to ``args.saveImgFile`` and/or shown.

    :raises ValueError: if the number of titles does not match the number of
        input files.
    """
    import os.path			# for os.path
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import AxesGrid

    inputDataFiles = []
    titles = []
    simTagFiles = []

    if args.simName is not None :
        dirName = os.path.join(args.directory, args.simName)
        simParams = ParamUtils.ReadSimulationParams(os.path.join(dirName,
                                                    "simParams.conf"))
        inputDataFiles.append(os.path.join(dirName, simParams['inputDataFile']))
        titles.append(args.simName)
        simTagFiles.append(os.path.join(dirName, simParams['simTagFile']))

    # Add on any files specified at the command-line
    inputDataFiles += args.inputDataFiles
    titles += args.inputDataFiles
    if args.simTagFiles is not None :
        simTagFiles += args.simTagFiles

    if len(inputDataFiles) == 0 :
        print("WARNING: No inputDataFiles given or found!")

    if len(titles) != len(inputDataFiles) :
        raise ValueError("The number of TITLEs does not match the"
                         " number of INPUTFILEs.")

    if len(simTagFiles) < len(inputDataFiles) :
        # Not an error: input files without a tag file get None.  Use extend
        # (not append) so each missing entry is its own None rather than one
        # nested list that would be handed to ReadSimTagFile.
        simTagFiles.extend([None] * (len(inputDataFiles) - len(simTagFiles)))

    if args.statName is not None and args.statLonLat is None :
        statData = ByName(args.statName)[0]
        args.statLonLat = (statData['LON'], statData['LAT'])

    if args.layout is None :
        args.layout = (1, len(inputDataFiles))

    if args.figsize is None :
        args.figsize = plt.figaspect(float(args.layout[0]) / args.layout[1])

    cornerVolumes = [ReadCorners(inFileName,
                                 os.path.dirname(inFileName))['volume_data']
                     for inFileName in inputDataFiles]

    multiTags = [(ReadSimTagFile(fname) if fname is not None else None) for
                 fname in simTagFiles]

    for vols, simTags in zip(cornerVolumes, multiTags) :
        keeperIDs = process_tag_filters(simTags, args.filters)
        if keeperIDs is None :
            continue

        for vol in vols :
            vol['stormCells'] = FilterTrack(vol['stormCells'],
                                            cornerIDs=keeperIDs)

    if args.statLonLat is not None :
        for vols in cornerVolumes :
            for vol in vols :
                CoordinateTransform(vol['stormCells'],
                                    args.statLonLat[0],
                                    args.statLonLat[1])

    theFig = plt.figure(figsize=args.figsize)
    grid = AxesGrid(theFig, 111, nrows_ncols=args.layout,
                            share_all=True, axes_pad=0.32)

    # A list to hold the CircleCollection arrays, it will have length 
    # of max(tLims) - min(tLims) + 1
    allCorners = None

    if args.trackFile is not None :
        (tracks, falarms) = FilterMHTTracks(*ReadTracks(args.trackFile))

        if args.statLonLat is not None :
            CoordinateTransform(tracks + falarms,
                                args.statLonLat[0],
                                args.statLonLat[1])

        (xLims, yLims, frameLims) = DomainFromTracks(tracks + falarms)
    else :
        volumes = []
        for aVol in cornerVolumes :
            volumes.extend(aVol)
        (xLims, yLims, tLims, frameLims) = DomainFromVolumes(volumes)

    showMap = (args.statLonLat is not None and args.displayMap)

    if showMap :
        bmap = Basemap(projection='cyl', resolution='l',
                       suppress_ticks=False,
                       llcrnrlat=yLims[0], llcrnrlon=xLims[0],
                       urcrnrlat=yLims[1], urcrnrlon=xLims[1])

    startFrame = args.startFrame
    endFrame = args.endFrame
    tail = args.tail

    if startFrame is None :
        startFrame = frameLims[0]

    if endFrame is None :
        endFrame = frameLims[1]

    if tail is None :
        tail = 0

    # A common event_source for synchronizing all the animations
    theTimer = None

    # Make the corners big
    big = False

    if args.radarFile is not None and args.statLonLat is not None :
        if endFrame - frameLims[0] >= len(args.radarFile) :
            # Not enough radar files, so truncate the tracks.
            endFrame = (len(args.radarFile) + frameLims[0]) - 1
        files = args.radarFile[startFrame - frameLims[0]:(endFrame + 1) -
                                                         frameLims[0]]
        radAnim = RadarAnim(theFig, files)
        theTimer = radAnim.event_source
        for ax in grid :
            radAnim.add_axes(ax, alpha=0.6, zorder=0)

        # Radar images make it difficult to see corners, so make 'em big
        big = True
    else :
        radAnim = None

    theAnim = CornerAnimation(theFig, endFrame - startFrame + 1,
                              tail=tail, interval=250, blit=False,
                              event_source=theTimer, fade=args.fade)

    for (index, volData) in enumerate(cornerVolumes) :
        curAxis = grid[index]

        if showMap :
            # NOTE(review): mapLayers is not defined in this function;
            # presumably a module-level setting -- confirm.
            PlotMapLayers(bmap, mapLayers, curAxis, zorder=0.1)

        volFrames = [frameVol['frameNum'] for frameVol in volData]
        startIdx = volFrames.index(startFrame)
        endIdx = volFrames.index(endFrame)
        volTimes = [frameVol['volTime'] for frameVol in volData]
        startT = volTimes[startIdx]
        endT = volTimes[endIdx]

        corners = PlotCorners(volData, (startT, endT), axis=curAxis,
                              big=big)

        curAxis.set_title(titles[index])
        if not showMap :
            curAxis.set_xlabel("X")
            curAxis.set_ylabel("Y")
        else :
            curAxis.set_xlabel("Longitude")
            curAxis.set_ylabel("Latitude")

        theAnim.AddCornerVolume(corners)

    if args.xlims is not None and np.prod(grid.get_geometry()) > 0 :
        grid[0].set_xlim(args.xlims)

    if args.ylims is not None and np.prod(grid.get_geometry()) > 0 :
        grid[0].set_ylim(args.ylims)

    if args.saveImgFile is not None :
        if radAnim is not None :
            radAnim = [radAnim]
        theAnim.save(args.saveImgFile, extra_anim=radAnim)

    if args.doShow :
        plt.show()
Пример #33
0
def main(args) :
    """Plot tracker results side by side, optionally scored against truth.

    Track files come from the simulation named by ``args.simName`` (one per
    tracker in its ``simParams.conf``) plus ``args.trackFiles``.  Each file
    gets one ``AxesGrid`` panel.  When a truth track file is available,
    segments are compared against it (``CompareSegments``) and plotted as a
    truth table; otherwise the plain tracks are plotted.  Optional sim-tag
    filtering, radar background and map layers are applied.  The figure is
    saved to ``args.saveImgFile`` and/or shown.

    :raises ValueError: if the number of titles does not match the number of
        track files.
    """
    import os.path
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import AxesGrid

    if args.bw_mode :
        BW_mode()       # from TrackPlot module

    # FIXME: Currently, the code allows for trackFiles to be listed as well
    #        as providing a simulation (which trackfiles are automatically
    #        grabbed). Both situations can not be handled right now, though.
    trackFiles = []
    trackTitles = []

    if args.simName is not None :
        dirName = os.path.join(args.directory, args.simName)
        simParams = ParamUtils.ReadSimulationParams(os.path.join(dirName,
                                                    "simParams.conf"))

        if args.trackRuns is not None :
            simParams['trackers'] = ExpandTrackRuns(simParams['trackers'],
                                                    args.trackRuns)

        trackFiles = [os.path.join(dirName, simParams['result_file'] +
                                            '_' + aTracker)
                      for aTracker in simParams['trackers']]
        # Copy so that extending the titles below cannot mutate simParams.
        trackTitles = list(simParams['trackers'])

        if args.truthTrackFile is None :
            args.truthTrackFile = os.path.join(dirName,
                                               simParams['noisyTrackFile'])

        if args.simTagFile is None :
            args.simTagFile = os.path.join(dirName, simParams['simTagFile'])

    trackFiles += args.trackFiles

    if args.trackTitles is None :
        trackTitles += args.trackFiles
    else :
        # Explicit TITLEs name every panel (simulation trackers and extra
        # track files alike).  Previously they were appended a second time
        # when a simulation was given, doubling them.
        trackTitles = list(args.trackTitles)

    if len(trackFiles) == 0 :
        print("WARNING: No trackFiles given or found!")

    if len(trackTitles) != len(trackFiles) :
        raise ValueError("The number of TITLEs do not match the"
                         " number of TRACKFILEs.")

    if args.statName is not None and args.statLonLat is None :
        statData = ByName(args.statName)[0]
        args.statLonLat = (statData['LON'], statData['LAT'])

    if args.layout is None :
        args.layout = (1, len(trackFiles))

    if args.figsize is None :
        args.figsize = plt.figaspect(float(args.layout[0]) / args.layout[1])

    trackerData = [FilterMHTTracks(*ReadTracks(trackFile)) for
                   trackFile in trackFiles]

    keeperIDs = None

    if args.simTagFile is not None :
        simTags = ParamUtils.ReadSimTagFile(args.simTagFile)
        keeperIDs = ParamUtils.process_tag_filters(simTags, args.filters)

    if args.statLonLat is not None :
        for aTracker in trackerData :
            CoordinateTransform(aTracker[0] + aTracker[1],
                                args.statLonLat[0],
                                args.statLonLat[1])

    theFig = plt.figure(figsize=args.figsize)
    grid = AxesGrid(theFig, 111, nrows_ncols=args.layout, aspect=False,
                            share_all=True, axes_pad=0.35)

    if args.truthTrackFile is not None :
        (true_tracks,
         true_falarms) = FilterMHTTracks(*ReadTracks(args.truthTrackFile))

        if args.statLonLat is not None :
            CoordinateTransform(true_tracks + true_falarms,
                                args.statLonLat[0],
                                args.statLonLat[1])

        true_AssocSegs = CreateSegments(true_tracks)
        true_FAlarmSegs = CreateSegments(true_falarms)

        if keeperIDs is not None :
            true_AssocSegs = FilterSegments(keeperIDs, true_AssocSegs)
            true_FAlarmSegs = FilterSegments(keeperIDs, true_FAlarmSegs)

        (xLims, yLims, frameLims) = DomainFromTracks(true_tracks + true_falarms)
    else :
        true_AssocSegs = None
        true_FAlarmSegs = None

        stackedTracks = []
        for aTracker in trackerData :
            stackedTracks += aTracker[0] + aTracker[1]
        (xLims, yLims, frameLims) = DomainFromTracks(stackedTracks)

    endFrame = args.endFrame
    tail = args.tail

    if endFrame is None :
        endFrame = frameLims[1]

    if tail is None :
        tail = endFrame - frameLims[0]

    startFrame = endFrame - tail

    showMap = (args.statLonLat is not None and args.displayMap)

    if args.radarFile is not None and args.statLonLat is not None :
        if len(args.radarFile) > 1 and args.endFrame is not None :
            args.radarFile = args.radarFile[args.endFrame]
        else :
            args.radarFile = args.radarFile[-1]

        raddata = LoadRastRadar(args.radarFile)
    else :
        raddata = None

    if showMap :
        bmap = Basemap(projection='cyl', resolution='i',
                       suppress_ticks=False,
                       llcrnrlat=yLims[0], llcrnrlon=xLims[0],
                       urcrnrlat=yLims[1], urcrnrlon=xLims[1])

    for index, (tracks, falarms) in enumerate(trackerData) :
        curAxis = grid[index]

        if raddata is not None :
            MakeReflectPPI(raddata['vals'][0], raddata['lats'], raddata['lons'],
                           meth='pcmesh', ax=curAxis, colorbar=False,
                           axis_labels=False, zorder=0, alpha=0.6)

        if showMap :
            # NOTE(review): mapLayers is not defined in this function;
            # presumably a module-level setting -- confirm.
            PlotMapLayers(bmap, mapLayers, curAxis)

        if true_AssocSegs is not None and true_FAlarmSegs is not None :
            trackAssocSegs = CreateSegments(tracks)
            trackFAlarmSegs = CreateSegments(falarms)

            if keeperIDs is not None :
                trackAssocSegs = FilterSegments(keeperIDs, trackAssocSegs)
                trackFAlarmSegs = FilterSegments(keeperIDs, trackFAlarmSegs)

            truthtable = CompareSegments(true_AssocSegs, true_FAlarmSegs,
                                         trackAssocSegs, trackFAlarmSegs)
            PlotSegments(truthtable, (startFrame, endFrame), axis=curAxis,
                         fade=args.fade)
        else :
            if keeperIDs is not None :
                # Real lists (not lazy map objects): CleanupTracks receives
                # them and the plotting call below iterates them again.
                tracks = [FilterTrack(trk, cornerIDs=keeperIDs)
                          for trk in tracks]
                falarms = [FilterTrack(trk, cornerIDs=keeperIDs)
                           for trk in falarms]
                CleanupTracks(tracks, falarms)

            PlotPlainTracks(tracks, falarms,
                            startFrame, endFrame, axis=curAxis,
                            fade=args.fade)

        curAxis.set_title(trackTitles[index])
        if not showMap :
            curAxis.set_xlabel("X")
            curAxis.set_ylabel("Y")
        else :
            curAxis.set_xlabel("Longitude")
            curAxis.set_ylabel("Latitude")

    if args.xlims is not None and np.prod(grid.get_geometry()) > 0 :
        grid[0].set_xlim(args.xlims)

    if args.ylims is not None and np.prod(grid.get_geometry()) > 0 :
        grid[0].set_ylim(args.ylims)

    if args.saveImgFile is not None :
        theFig.savefig(args.saveImgFile, bbox_inches='tight')

    if args.doShow :
        plt.show()
Пример #34
0
def main(args) :
    """Compare tracker results against truth tracks, one panel per pair.

    Track and truth files (``args.trackFiles`` / ``args.truthTrackFile``) are
    read and, if their counts differ, the shorter list is repeated so that
    every panel has both a tracker and a truth set.  Sim-tag data are
    repeated the same way.  Panels are drawn by ``MakeComparePlots`` on an
    ``AxesGrid``, optionally over radar imagery and a map.  The figure is
    saved to ``args.saveImgFile`` and/or shown.

    :raises ValueError: if the TITLE count does not match the TRACKFILE count,
        or if one list cannot be evenly repeated to match the other.
    """
    if args.bw_mode :
        BW_mode()

    if len(args.trackFiles) == 0 :
        print("WARNING: No trackFiles given!")
    if len(args.truthTrackFile) == 0 :
        print("WARNING: No truth trackFiles given!")

    if args.trackTitles is None :
        args.trackTitles = args.trackFiles
    else :
        if len(args.trackTitles) != len(args.trackFiles) :
            raise ValueError("The number of TITLEs does not match the number"
                             " of TRACKFILEs")

    if args.statName is not None and args.statLonLat is None :
        statData = ByName(args.statName)[0]
        args.statLonLat = (statData['LON'], statData['LAT'])

    if args.layout is None :
        args.layout = (1, max(len(args.trackFiles),
                              len(args.truthTrackFile)))

    if args.figsize is None :
        args.figsize = plt.figaspect(float(args.layout[0]) / args.layout[1])

    if args.simTagFiles is None :
        args.simTagFiles = []

    trackerData = [FilterMHTTracks(*ReadTracks(trackFile)) for
                   trackFile in args.trackFiles]
    truthData = [FilterMHTTracks(*ReadTracks(trackFile)) for
                 trackFile in args.truthTrackFile]
    multiTags = [ReadSimTagFile(fname) for fname in args.simTagFiles]

    if len(multiTags) == 0 :
        multiTags = [None]

    if args.statLonLat is not None :
        for aTracker in trackerData + truthData :
            CoordinateTransform(aTracker[0] + aTracker[1],
                                args.statLonLat[0],
                                args.statLonLat[1])

    if len(trackerData) != len(truthData) :
        # Basic broadcasting needed!
        if len(trackerData) == 0 or len(truthData) == 0 :
            # Nothing to repeat (and the modulo checks below would divide by
            # zero).
            raise ValueError("Can't broadcast an empty TRACKFILE or TRUTHFILE"
                             " list against a non-empty one!")

        if len(truthData) > len(trackerData) :
            # Need to extend track data to match with the number of truth sets
            if len(truthData) % len(trackerData) != 0 :
                raise ValueError("Can't extend TRACKFILE list to match with"
                                 " the TRUTHFILE list!")
        else :
            # Need to extend truth sets to match with the number of track data
            if len(trackerData) % len(truthData) != 0 :
                raise ValueError("Can't extend TRUTHFILE list to match with"
                                 " the TRACKFILE list!")

        trkMult = max(int(len(truthData) // len(trackerData)), 1)
        trthMult = max(int(len(trackerData) // len(truthData)), 1)

        trackerData = trackerData * trkMult
        truthData = truthData * trthMult

        args.trackTitles = args.trackTitles * trkMult

    # Broadcast the tag data over all panels whether or not the track/truth
    # lists needed broadcasting: a lone "no tags" entry must still cover
    # every panel when the two lists already have equal length.
    tagMult = max(int(len(truthData) // len(multiTags)), 1)
    multiTags = multiTags * tagMult

    theFig = plt.figure(figsize=args.figsize)
    grid = AxesGrid(theFig, 111, nrows_ncols=args.layout, aspect=False,
                            share_all=True, axes_pad=0.45)

    showMap = (args.statLonLat is not None and args.displayMap)

    if args.radarFile is not None and args.statLonLat is not None :
        if len(args.radarFile) > 1 and args.endFrame is not None :
            args.radarFile = args.radarFile[args.endFrame]
        else :
            args.radarFile = args.radarFile[-1]

        data = LoadRastRadar(args.radarFile)
        for ax in grid :
            MakeReflectPPI(data['vals'][0], data['lats'], data['lons'],
                           meth='pcmesh', ax=ax, colorbar=False,
                           axis_labels=False, zorder=0, alpha=0.6)

    MakeComparePlots(grid, trackerData, truthData, args.trackTitles, showMap,
                     endFrame=args.endFrame, tail=args.tail, fade=args.fade,
                     multiTags=multiTags, tag_filters=args.filters)

    if args.xlims is not None and np.prod(grid.get_geometry()) > 0 :
        grid[0].set_xlim(args.xlims)

    if args.ylims is not None and np.prod(grid.get_geometry()) > 0 :
        grid[0].set_ylim(args.ylims)

    if args.saveImgFile is not None :
        theFig.savefig(args.saveImgFile)

    if args.doShow :
        plt.show()
Пример #35
0
def main(args) :
    """Animate tracker results side by side, optionally scored against truth.

    Track files come from the simulation named by ``args.simName`` plus
    ``args.trackFiles``, one ``AxesGrid`` panel each.  Segments are animated
    by a ``SegAnimator``: as a truth table when a truth track file is
    available, otherwise as plain tracks.  Optional extras are radar
    background animation, map layers, and per-panel polygon overlays loaded
    from ``args.polys``.  The animation is saved to ``args.saveImgFile``
    and/or shown.

    :raises ValueError: if there are more polygon files than track files.
    """
    import os.path			# for os.path.join()
    from matplotlib.animation import ArtistAnimation

    if args.bw_mode :
        BW_mode()       # from TrackPlot module

    # FIXME: Currently, the code allows for trackFiles to be listed as well
    #        as providing a simulation (which trackfiles are automatically
    #        grabbed). Both situations can not be handled right now, though.
    trackFiles = []
    trackTitles = []
    # No polygon files given is treated the same as an empty list.
    polyfiles = args.polys if args.polys is not None else []

    if args.statName is not None and args.statLonLat is None :
        statData = ByName(args.statName)[0]
        args.statLonLat = (statData['LON'], statData['LAT'])

    if args.simName is not None :
        dirName = os.path.join(args.directory, args.simName)
        simParams = ParamUtils.ReadSimulationParams(os.path.join(dirName,
                                                    "simParams.conf"))

        if args.trackRuns is not None :
            simParams['trackers'] = ExpandTrackRuns(simParams['trackers'],
                                                    args.trackRuns)

        trackFiles = [os.path.join(dirName, simParams['result_file'] +
                                            '_' + aTracker)
                      for aTracker in simParams['trackers']]
        # Copy so that extending the titles below cannot mutate simParams.
        trackTitles = list(simParams['trackers'])

        if args.truthTrackFile is None :
            args.truthTrackFile = os.path.join(dirName,
                                               simParams['noisyTrackFile'])

        if args.simTagFile is None :
            args.simTagFile = os.path.join(dirName,
                                           simParams['simTagFile'])

    trackFiles += args.trackFiles
    trackTitles += args.trackFiles

    if args.trackTitles is not None :
        trackTitles = args.trackTitles

    if len(trackFiles) == 0 :
        print("WARNING: No trackFiles given or found!")

    if args.layout is None :
        args.layout = (1, len(trackFiles))

    if args.figsize is None :
        args.figsize = plt.figaspect(float(args.layout[0]) / args.layout[1])

    if len(trackFiles) < len(polyfiles):
        raise ValueError("Can not have more polygon files than trackfiles!")

    trackerData = [FilterMHTTracks(*ReadTracks(trackFile)) for
                   trackFile in trackFiles]
    polyData = [_load_verts(f, tracks + falarms) for f, (tracks, falarms) in
                zip(polyfiles, trackerData)]

    keeperIDs = None

    if args.simTagFile is not None :
        simTags = ParamUtils.ReadSimTagFile(args.simTagFile)
        keeperIDs = ParamUtils.process_tag_filters(simTags, args.filters)

    if args.statLonLat is not None :
        for aTracker in trackerData :
            CoordinateTransform(aTracker[0] + aTracker[1],
                                args.statLonLat[0],
                                args.statLonLat[1])
        for verts in polyData:
            CoordinateTrans_lists(verts,
                                  args.statLonLat[0], args.statLonLat[1])

    theFig = plt.figure(figsize=args.figsize)
    grid = AxesGrid(theFig, 111, nrows_ncols=args.layout,
                            share_all=True, axes_pad=0.45)

    if args.truthTrackFile is not None :
        (true_tracks,
         true_falarms) = FilterMHTTracks(*ReadTracks(args.truthTrackFile))

        if args.statLonLat is not None :
            CoordinateTransform(true_tracks + true_falarms,
                                args.statLonLat[0],
                                args.statLonLat[1])

        true_AssocSegs = CreateSegments(true_tracks)
        true_FAlarmSegs = CreateSegments(true_falarms)

        if keeperIDs is not None :
            true_AssocSegs = FilterSegments(keeperIDs, true_AssocSegs)
            true_FAlarmSegs = FilterSegments(keeperIDs, true_FAlarmSegs)

        (xLims, yLims, frameLims) = DomainFromTracks(true_tracks + true_falarms)
    else :
        true_AssocSegs = None
        true_FAlarmSegs = None

        stackedTracks = []
        for aTracker in trackerData :
            stackedTracks += aTracker[0] + aTracker[1]
        (xLims, yLims, frameLims) = DomainFromTracks(stackedTracks)

    startFrame = args.startFrame
    endFrame = args.endFrame
    tail = args.tail

    if startFrame is None :
        startFrame = 0

    if endFrame is None :
        endFrame = frameLims[1]

    if tail is None :
        tail = endFrame - startFrame

    # A common timer for all animations for syncing purposes.
    theTimer = None

    if args.radarFile is not None and args.statLonLat is not None :
        if endFrame >= len(args.radarFile) :
            # Not enough radar files, so truncate the tracks.
            endFrame = len(args.radarFile) - 1
        files = args.radarFile[startFrame:(endFrame + 1)]
        radAnim = RadarAnim(theFig, files)
        theTimer = radAnim.event_source
        for ax in grid :
            radAnim.add_axes(ax, alpha=0.6, zorder=0)
    else :
        radAnim = None

    showMap = (args.statLonLat is not None and args.displayMap)

    if showMap :
        bmap = Basemap(projection='cyl', resolution='i',
                       suppress_ticks=False,
                       llcrnrlat=yLims[0], llcrnrlon=xLims[0],
                       urcrnrlat=yLims[1], urcrnrlon=xLims[1])

    animator = SegAnimator(theFig, startFrame, endFrame, tail,
                           event_source=theTimer, fade=args.fade)

    for index, (tracks, falarms) in enumerate(trackerData):
        curAxis = grid[index]

        if showMap :
            # NOTE(review): mapLayers is not defined in this function;
            # presumably a module-level setting -- confirm.
            PlotMapLayers(bmap, mapLayers, curAxis, zorder=0.1)

        if true_AssocSegs is not None and true_FAlarmSegs is not None :
            trackAssocSegs = CreateSegments(tracks)
            trackFAlarmSegs = CreateSegments(falarms)

            if keeperIDs is not None :
                trackAssocSegs = FilterSegments(keeperIDs, trackAssocSegs)
                trackFAlarmSegs = FilterSegments(keeperIDs, trackFAlarmSegs)

            truthtable = CompareSegments(true_AssocSegs, true_FAlarmSegs,
                                         trackAssocSegs, trackFAlarmSegs)
            l, d = Animate_Segments(truthtable, (startFrame, endFrame),
                                    axis=curAxis)
        else :
            if keeperIDs is not None :
                # Real lists (not lazy map objects): CleanupTracks receives
                # them and the animation call below iterates them again.
                tracks = [FilterTrack(trk, cornerIDs=keeperIDs)
                          for trk in tracks]
                falarms = [FilterTrack(trk, cornerIDs=keeperIDs)
                           for trk in falarms]
                CleanupTracks(tracks, falarms)

            l, d = Animate_PlainTracks(tracks, falarms,
                                       (startFrame, endFrame), axis=curAxis)

        animator._lines.extend(l)
        animator._lineData.extend(d)

        curAxis.set_title(trackTitles[index])
        if not showMap :
            curAxis.set_xlabel("X")
            curAxis.set_ylabel("Y")
        else :
            curAxis.set_xlabel("Longitude (degrees)")
            curAxis.set_ylabel("Latitude (degrees)")

    polyAnims = []
    for ax, verts in zip(grid, polyData):
        # Each panel animates its OWN polygon data.  This previously used a
        # stale loop variable holding only the last file's polygons, which
        # was undefined when no coordinate transform ran.
        polyAnim = ArtistAnimation(theFig,
                                   _to_polygons(verts[startFrame:endFrame + 1], ax),
                                   event_source=theTimer)
        polyAnims.append(polyAnim)

    if args.xlims is not None and np.prod(grid.get_geometry()) > 0 :
        grid[0].set_xlim(args.xlims)

    if args.ylims is not None and np.prod(grid.get_geometry()) > 0 :
        grid[0].set_ylim(args.ylims)

    if args.saveImgFile is not None :
        if radAnim is not None :
            radAnim = [radAnim]
        else:
            radAnim = []
        animator.save(args.saveImgFile, extra_anim=radAnim + polyAnims)

    if args.doShow :
        plt.show()
Пример #36
0
def main(args):
    """Read track files and plot them side by side on an AxesGrid figure.

    Steps, in order:

    * optionally switch to black-and-white mode (``args.bw_mode``);
    * fill in defaults for ``args.trackTitles``, ``args.layout`` and
      ``args.figsize`` -- note that ``args`` is modified in place;
    * resolve ``args.statName`` to a lon/lat pair when no explicit
      ``args.statLonLat`` is given, and if a lon/lat is available transform
      every tracker's coordinates relative to it;
    * read the simulation tag files and broadcast them over the trackers;
    * when both a radar file and a station lon/lat are available, draw a
      radar reflectivity underlay on every axis of the grid;
    * draw the tracks, apply optional x/y limits, then save and/or show.

    Parameters
    ----------
    args : namespace
        Parsed command-line options.  NOTE(review): the required attribute
        set is inferred from the accesses below; confirm against the
        argument parser that builds it.

    Raises
    ------
    ValueError
        If ``args.trackTitles`` is given but its length differs from that
        of ``args.trackFiles``.
    """
    if args.bw_mode:
        BW_mode()

    if not args.trackFiles:
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print("WARNING: No trackFiles given!")

    if args.trackTitles is None:
        args.trackTitles = args.trackFiles
    elif len(args.trackTitles) != len(args.trackFiles):
        raise ValueError("The number of TITLEs do not match the"
                         " number of TRACKFILEs.")

    if args.statName is not None and args.statLonLat is None:
        statData = ByName(args.statName)[0]
        args.statLonLat = (statData["LON"], statData["LAT"])

    if args.layout is None:
        # Clamp to at least one column: with no track files a (1, 0) layout
        # would make the figaspect() call below divide by zero.
        args.layout = (1, max(len(args.trackFiles), 1))

    if args.figsize is None:
        args.figsize = plt.figaspect(float(args.layout[0]) / args.layout[1])

    trackerData = [FilterMHTTracks(*ReadTracks(trackFile))
                   for trackFile in args.trackFiles]

    if args.statLonLat is not None:
        for aTracker in trackerData:
            CoordinateTransform(aTracker[0] + aTracker[1],
                                args.statLonLat[0], args.statLonLat[1])

    if not args.simTagFiles:
        # None *and* an empty list both mean "no tag files"; an empty list
        # would otherwise cause a division by zero in the broadcast below.
        args.simTagFiles = [None]

    multiTags = [ReadSimTagFile(fname) if fname is not None else None
                 for fname in args.simTagFiles]

    if len(trackerData) > len(multiTags):
        # Very rudimentary broadcasting of multiTags to match trackerData.
        # Ceiling division guarantees the repeated list is long enough even
        # when the lengths are not an exact multiple (floor division could
        # leave it shorter than trackerData); truncate to an exact match.
        tagMult = -(-len(trackerData) // len(multiTags))
        multiTags = (multiTags * tagMult)[:len(trackerData)]

    theFig = plt.figure(figsize=args.figsize)
    grid = AxesGrid(theFig, 111, nrows_ncols=args.layout, aspect=False,
                    share_all=True, axes_pad=0.45)

    showMap = args.statLonLat is not None and args.displayMap

    # Can only do this if all other data being displayed will be in
    # lon/lat coordinates
    if args.radarFile is not None and args.statLonLat is not None:
        # Pick the radar file at index endFrame when several are given,
        # otherwise the last one.
        if len(args.radarFile) > 1 and args.endFrame is not None:
            args.radarFile = args.radarFile[args.endFrame]
        else:
            args.radarFile = args.radarFile[-1]

        data = LoadRastRadar(args.radarFile)
        for ax in grid:
            MakeReflectPPI(
                data["vals"][0],
                data["lats"],
                data["lons"],
                meth="pcmesh",
                ax=ax,
                colorbar=False,
                axis_labels=False,
                zorder=0,
                alpha=0.6,
            )

    MakeTrackPlots(
        grid,
        trackerData,
        args.trackTitles,
        showMap,
        endFrame=args.endFrame,
        tail=args.tail,
        fade=args.fade,
        multiTags=multiTags,
        tag_filters=args.filters,
    )

    # Limits are set on the first axis only; the grid was created with
    # share_all=True.  Skip when the grid has no axes at all.
    hasAxes = np.prod(grid.get_geometry()) > 0

    if args.xlims is not None and hasAxes:
        grid[0].set_xlim(args.xlims)

    if args.ylims is not None and hasAxes:
        grid[0].set_ylim(args.ylims)

    if args.saveImgFile is not None:
        theFig.savefig(args.saveImgFile)

    if args.doShow:
        plt.show()