Ejemplo n.º 1
0
def fgsm_attack(request):
    """Build an FGSM adversarial sample from the posted image and return it as HTML.

    Reads the ``inputs`` field of ``request.POST``, inverts and rescales it to
    a ``(1, 28, 28, 1)`` float32 batch, passes it to ``make_fgsm`` and renders
    the first returned sample as a grayscale image. A PNG copy is written to
    ``img/fgsm_mnist.png`` and the mpld3 HTML of the figure is returned in an
    ``HttpResponse``.
    """
    # SECURITY: eval() on request data runs arbitrary code supplied by the
    # client. NOTE(review): switch to json.loads / ast.literal_eval once the
    # exact payload format sent by the front end has been confirmed.
    pixels = np.array(eval(request.POST.get('inputs')), dtype=np.float32)

    # Normalise the data: invert, then scale by 1/255
    batch = ((255 - pixels) / 255.0).reshape(1, 28, 28, 1)
    adv_batch = make_fgsm(sess, env, batch, epochs=12, eps=0.02)

    # Only the first sample is drawn; drop its singleton axes for imshow
    adv_image = np.squeeze(np.asarray(adv_batch[0], dtype=np.float64))

    fig = plt.figure(figsize=(1, 1.2))
    try:
        gs = gridspec.GridSpec(1, 1, wspace=0.05, hspace=0.05)
        ax = fig.add_subplot(gs[0, 0])
        # Show the adversarial sample
        ax.imshow(adv_image, cmap='gray', interpolation='none')

        ax.set_xticks([])
        ax.set_yticks([])

        plugins.connect(fig, plugins.MousePosition(fontsize=14))

        os.makedirs('img', exist_ok=True)
        fig.savefig('img/fgsm_mnist.png')
        html = mpld3.fig_to_html(fig)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this each request would leak one figure.
        plt.close(fig)

    return HttpResponse(html)
Ejemplo n.º 2
0
def mpld3_enable_notebook():
    """Set mpld3's default plugins, turn on notebook mode and return the mpld3 module.

    The default plugin list becomes Reset, Zoom, BoxZoom and MousePosition,
    in that order.
    """
    import mpld3
    from mpld3 import plugins as mpld3_plugins

    default_plugin_types = (
        mpld3_plugins.Reset,
        mpld3_plugins.Zoom,
        mpld3_plugins.BoxZoom,
        mpld3_plugins.MousePosition,
    )
    mpld3_plugins.DEFAULT_PLUGINS = [make_plugin() for make_plugin in default_plugin_types]

    mpld3.enable_notebook()
    return mpld3
Ejemplo n.º 3
0
def plot(request, title):
    """Return an mpld3 plot as an ``HttpResponse`` with status 200.

    Args:
        request: the incoming request object (not read by this function).
        title: when ``None`` a demo RGBA image is rendered on the fly;
            otherwise the stored plot string is looked up in ``r`` under the
            key ``'user_0_plot_' + title`` and returned as-is.
    """
    if title is not None:
        # Stored plot: no figure is needed, so none is created.
        # NOTE(review): `r` is defined elsewhere in the file (presumably a
        # key/value store client) - confirm what get() returns on a miss.
        plot_string = r.get('user_' + str(0) + '_plot_' + title)
        return HttpResponse(plot_string, status=200)

    fig = plt.figure(figsize=(4, 4))
    try:
        ax = fig.gca()

        x = np.linspace(-2, 2, 20)
        y = x[:, None]
        X = np.zeros((20, 20, 4))

        # One Gaussian bump per RGBA channel; x (20,) and y (20, 1) broadcast to 20x20
        X[:, :, 0] = np.exp(- (x - 1) ** 2 - (y) ** 2)
        X[:, :, 1] = np.exp(- (x + 0.71) ** 2 - (y - 0.71) ** 2)
        X[:, :, 2] = np.exp(- (x + 0.71) ** 2 - (y + 0.71) ** 2)
        X[:, :, 3] = np.exp(-0.25 * (x ** 2 + y ** 2))

        im = ax.imshow(X, extent=(10, 20, 10, 20),
                       origin='lower', zorder=1, interpolation='nearest')
        fig.colorbar(im, ax=ax)

        ax.set_title('An Image', size=20)

        plugins.connect(fig, plugins.MousePosition(fontsize=14))

        plot_string = mpld3.fig_to_html(fig, d3_url=None, mpld3_url=None,
                                        no_extras=False, template_type='general',
                                        figid=None, use_http=False)
    finally:
        # Release the figure; pyplot otherwise keeps it alive across requests.
        plt.close(fig)

    return HttpResponse(plot_string, status=200)
Ejemplo n.º 4
0
def draw_spectrum(spectrum, dummy=False):
    """Plot a log-log spectrum with error bars and return it as an mpld3 dict.

    Args:
        spectrum: indexable container whose element 8 exposes ``.data`` with
            ``'RATE'`` and ``'STAT_ERR'`` columns.
        dummy: accepted for interface compatibility; not used.

    Returns:
        The dictionary produced by ``mpld3.fig_to_dict`` for the figure.

    Energy bin bounds are read from extension 3 (``E_min`` / ``E_max``) of
    ``rmf_62bands.fits`` in the current working directory. Only bins with a
    strictly positive rate are drawn.
    """
    # Copy the bounds out while the file is open so the handle can be released
    # (the original never closed it).
    with pf.open('rmf_62bands.fits') as rmf:
        ebounds = rmf[3].data
        E_min = np.array(ebounds['E_min'])
        E_max = np.array(ebounds['E_max'])

    src_spectrum = spectrum[8].data
    bin_width = E_max - E_min

    msk = src_spectrum['RATE'] > 0.

    # Bins with non-positive rate produce inf/nan below; they are removed by
    # `msk` before plotting, so the numeric warnings are silenced.
    with np.errstate(divide='ignore', invalid='ignore'):
        y = src_spectrum['RATE'] / bin_width

        # NOTE(review): x is the SUM of the bin edges, not the midpoint -
        # confirm whether a factor 1/2 is intended here.
        x = (E_max + E_min)
        # Error propagation for log10: d(log10 v) = log10(e) * dv / v
        dx = np.log10(np.e) * bin_width / x
        x = np.log10(x)

        dy = src_spectrum['STAT_ERR'] / bin_width
        dy = np.log10(np.e) * dy / y
        y = np.log10(y)

    fig, ax = plt.subplots(figsize=(4, 2.8))
    try:
        ax.set_xlabel('log(E)  keV')
        ax.set_ylabel('log(counts/s/keV)')

        ax.errorbar(x[msk],
                    y[msk],
                    yerr=dy[msk] * 0.5,
                    xerr=dx[msk] * 0.5,
                    fmt='o')
        plugins.connect(fig, plugins.MousePosition(fontsize=14))

        return mpld3.fig_to_dict(fig)
    finally:
        # The returned dict is self-contained, so the figure can be released.
        plt.close(fig)
Ejemplo n.º 5
0
def graphic(request, fun, iterations, lowerbound, upperbound):
    """Plot ``fun`` over ``[lowerbound, upperbound]`` with a shaded area and return mpld3 HTML.

    Args:
        request: the incoming request object (not read by this function).
        fun: Python expression evaluated with the sample array available as
            ``x`` (and the bounds as ``a`` / ``b``).
        iterations: must be convertible to ``int``; currently not used for
            drawing.
        lowerbound: lower limit, convertible to ``float``.
        upperbound: upper limit, convertible to ``float``.

    Returns:
        ``HttpResponse`` whose body is the mpld3 HTML for the figure.

    Raises:
        ValueError: if a bound or ``iterations`` cannot be converted, or the
            evaluated expression cannot be broadcast to the sample grid.
    """
    import matplotlib.pyplot as plt
    import mpld3
    from mpld3 import plugins
    from matplotlib.patches import Polygon

    # Lower and upper limits
    a = float(lowerbound)
    b = float(upperbound)
    # Iterations: converted only so invalid input is still rejected as before;
    # the value itself is not used.
    int(iterations)

    x = np.linspace(a, b, 1000)

    # SECURITY: eval() executes arbitrary code from the caller-supplied
    # expression. NOTE(review): if `fun` comes from an untrusted client this
    # must be replaced by a restricted expression parser.
    # broadcast_to lets a constant expression such as "3" plot as a flat line
    # instead of failing on a shape mismatch.
    y = np.broadcast_to(np.asarray(eval(fun)), x.shape)

    fig, ax = plt.subplots(figsize=(11, 6.5))
    try:
        # The curve defined by the function
        ax.plot(x, y, lw=5, alpha=0.7)
        ax.grid(True, alpha=0.3)

        # Shade the area under the curve: down to the axis at both ends
        verts = [(a, 0)] + list(zip(x, y)) + [(b, 0)]
        ax.add_patch(Polygon(verts, facecolor='0.8', edgecolor='0.5'))

        # Show the pointer position on the plot
        plugins.connect(fig, plugins.MousePosition(fontsize=14))

        g = mpld3.fig_to_html(fig)
    finally:
        # Release the figure; pyplot otherwise keeps it alive across requests.
        plt.close(fig)

    return HttpResponse(g)
Ejemplo n.º 6
0
def create_plot():
    """Return a 2x2 figure of labelled panels with the mouse-position plugin attached.

    Columns share their x axis and rows share their y axis. Each panel shows
    its own ``(row, column)`` index as both title and centred text.
    """
    fig, ax = plt.subplots(2, 2, sharex='col', sharey='row')
    fig.subplots_adjust(hspace=0.3)

    for i, panel_row in enumerate(ax):
        for j, panel in enumerate(panel_row):
            txt = '({i}, {j})'.format(i=i, j=j)
            panel.set_title(txt, size=14)
            panel.text(0.5, 0.5, txt, size=40, ha='center')
            panel.grid(True, color='lightgray')
            panel.set_xlabel('xlabel')
            panel.set_ylabel('ylabel')

    plugins.connect(fig, plugins.MousePosition())
    return fig
Ejemplo n.º 7
0
    def produceHeroDistribution(self, hero_vector, hero_count_in_features):
        """Compute per-hero pick totals and store them with a bar chart on ``self``.

        ``hero_vector`` is summed over axis 0; the first
        ``hero_count_in_features`` entries and the remaining entries are then
        added element-wise. The totals go to ``self.hero_distribution`` and
        the mpld3 dict of the bar chart (drawn on the current pyplot figure)
        goes to ``self.chart1``.
        """
        column_totals = np.sum(hero_vector, axis=0)
        split_at = hero_count_in_features
        totals = np.add(column_totals[:split_at], column_totals[split_at:])

        # Bars are labelled 1..N
        hero_ids = list(range(1, len(totals) + 1))
        self.hero_distribution = totals

        plt.bar(hero_ids, totals, align="center")
        plt.xlabel("HeroID")
        plt.ylabel("Frequency")
        plt.title("Hero pick distribution")

        current = plt.gcf()
        plugins.connect(current, plugins.MousePosition(fontsize=14))
        self.chart1 = mpld3.fig_to_dict(current)
Ejemplo n.º 8
0
def make_html_plot_for_file(filename, my_title):
    """Render a comma-separated data file as interactive mpld3 HTML.

    Column 0 is the x axis; every further column ``k`` is drawn as a line
    labelled ``'ion <k-1>'`` with its points ordered by x.

    Args:
        filename: path of the CSV file. A zero-byte file is plotted as a
            2x2 block of zeros.
        my_title: plot title, also used in the "no data" message.

    Returns:
        The HTML string from ``mpld3.fig_to_html``, or an ``<h2>`` message
        when the file does not yield a two-dimensional table (a single row,
        a single column or a single value).
    """
    my_fontsize = 20

    # Read in the csv file
    if os.stat(filename).st_size:
        data = np.genfromtxt(filename, delimiter=',')
    else:
        # Empty data file
        data = np.zeros([2, 2])

    # genfromtxt squeezes its result, so anything below 2-D cannot be
    # indexed as rows x columns (0-d used to raise IndexError further down).
    if data.ndim < 2:
        return "<h2> " + my_title + " has no data </h2>"

    x = data[:, 0]
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        for k in range(1, data.shape[1]):
            # Sort the (x, y) pairs once and split them back into columns
            pairs = sorted(zip(x, data[:, k]))
            x_sorted = [px for px, _ in pairs]
            y_sorted = [py for _, py in pairs]

            ax.plot(x_sorted,
                    y_sorted,
                    '.-',
                    markersize=15,
                    label='ion ' + str(k - 1))

        plt.legend(fontsize=my_fontsize, loc='best')
        plt.title(my_title, fontsize=my_fontsize)
        ax.grid()
        plt.xticks(fontsize=my_fontsize)
        plt.yticks(fontsize=my_fontsize)
        plt.xlim([np.min(x), np.max(x)])

        # 90% of the time that this function takes is taken up after this line
        plugins.clear(fig)
        plugins.connect(fig, plugins.Reset(), plugins.BoxZoom(),
                        plugins.Zoom(enabled=True),
                        plugins.MousePosition(fontsize=my_fontsize))
        java_txt = mpld3.fig_to_html(fig)
    finally:
        # Close this figure even when plotting or export fails
        plt.close(fig)

    return java_txt
def prepare_image_data(image_data):
    """Show ``image_data`` on a new 15-inch-wide figure and return its axes.

    The figure height follows the image's ``size[1] / size[0]`` ratio,
    truncated to a whole number of inches. The mouse-position plugin is
    attached and mpld3 notebook mode is enabled as a side effect.
    """
    pixels = np.array(image_data, dtype=np.uint8)

    # Create figure and axes
    fig, ax = plt.subplots(1)

    # Adjust the figure size to the image's aspect ratio
    # (assumes image_data.size is (width, height), as for a PIL image - TODO confirm)
    width = 15
    height = width * image_data.size[1] / image_data.size[0]
    fig.set_size_inches(width, int(height), forward=True)

    mouse_readout = plugins.MousePosition(fontsize=14)
    plugins.connect(fig, mouse_readout)
    mpld3.enable_notebook()

    # Display the image
    ax.imshow(pixels)

    return ax
Ejemplo n.º 10
0
def draw_fig(image_array, image_header, catalog=None, plot=False):
    """Draw an image on WCS axes, optionally overlay a catalog, and return an mpld3 dict.

    Args:
        image_array: 2-D image passed to ``imshow``.
        image_header: header from which the ``astropy.wcs.WCS`` projection
            is built.
        catalog: optional table with ``'RA_FIN'``, ``'DEC_FIN'`` (multiplied
            by degrees here) and ``'NAME'`` columns; each entry is marked
            with an open circle and annotated with its name.
        plot: when truthy, also show the figure with ``plt.show()``.

    Returns:
        The dictionary produced by ``mpld3.fig_to_dict`` for the figure.
    """
    from astropy.wcs import WCS

    from astropy import units as u
    import astropy.coordinates as coord

    # Built once and used for both the axes projection and the sky->pixel conversion
    image_wcs = WCS(image_header)

    fig, (ax) = plt.subplots(1,
                             1,
                             figsize=(4, 3),
                             subplot_kw={'projection': image_wcs})
    try:
        im = ax.imshow(image_array,
                       origin='lower',
                       zorder=1,
                       interpolation='none',
                       aspect='equal')

        if catalog is not None:
            lon = coord.Angle(catalog['RA_FIN'] * u.deg)
            lat = coord.Angle(catalog['DEC_FIN'] * u.deg)

            # origin=1: pixel coordinates are 1-based
            pixcrd = image_wcs.wcs_world2pix(np.column_stack((lon, lat)), 1)

            ax.plot(pixcrd[:, 0], pixcrd[:, 1], 'o', mfc='none')
            # range() instead of the Python-2-only xrange()
            for idx in range(catalog.size):
                ax.annotate('%s' % catalog[idx]['NAME'],
                            xy=(pixcrd[idx, 0], pixcrd[idx, 1]),
                            color='white')

            ax.set_xlabel('RA')
            ax.set_ylabel('DEC')

        fig.colorbar(im, ax=ax)
        if plot:
            plt.show()

        plugins.connect(fig, plugins.MousePosition(fontsize=14))

        return mpld3.fig_to_dict(fig)
    finally:
        # The returned dict is self-contained, so the figure can be released.
        plt.close(fig)
Ejemplo n.º 11
0
def about():
    """Render ``about.html`` with a demo RGBA image embedded as mpld3 HTML.

    The image is 20x20 with one Gaussian bump per colour/alpha channel and is
    passed to the template as ``plot``.
    """
    fig, ax = plt.subplots()

    grid_x = np.linspace(-2, 2, 20)
    grid_y = grid_x[:, None]

    # grid_x (20,) against grid_y (20, 1) broadcasts every channel to 20x20
    rgba_channels = (
        np.exp(-(grid_x - 1)**2 - (grid_y)**2),
        np.exp(-(grid_x + 0.71)**2 - (grid_y - 0.71)**2),
        np.exp(-(grid_x + 0.71)**2 - (grid_y + 0.71)**2),
        np.exp(-0.25 * (grid_x**2 + grid_y**2)),
    )
    rgba = np.dstack(rgba_channels)

    image = ax.imshow(rgba,
                      extent=(10, 20, 10, 20),
                      origin='lower',
                      zorder=1,
                      interpolation='nearest')
    fig.colorbar(image, ax=ax)

    ax.set_title('An Image', size=20)
    plugins.connect(fig, plugins.MousePosition(fontsize=14))

    plot_html = mpld3.fig_to_html(fig)
    return render_template('about.html', plot=plot_html)
Ejemplo n.º 12
0
def draw_dummy(dummy=True):
    """Return the mpld3 dict of a synthetic 200x200 RGBA placeholder image.

    Args:
        dummy: must be truthy. The function has no other image source, so a
            falsy value is rejected up front.

    Returns:
        The dictionary produced by ``mpld3.fig_to_dict`` for the figure.

    Raises:
        ValueError: if ``dummy`` is falsy. Previously this path failed with
            an ``UnboundLocalError`` after a figure had already been created.
    """
    if not dummy:
        raise ValueError("draw_dummy has no image to draw when dummy is False")

    x = np.linspace(-2, 2, 200)
    y = x[:, None]
    image = np.zeros((200, 200, 4))

    # One Gaussian bump per RGBA channel; x (200,) and y (200, 1) broadcast to 200x200
    image[:, :, 0] = np.exp(-(x - 1)**2 - (y)**2)
    image[:, :, 1] = np.exp(-(x + 0.71)**2 - (y - 0.71)**2)
    image[:, :, 2] = np.exp(-(x + 0.71)**2 - (y + 0.71)**2)
    image[:, :, 3] = np.exp(-0.25 * (x**2 + y**2))

    fig, ax = plt.subplots(figsize=(4, 3))
    try:
        im = ax.imshow(image,
                       origin='lower',
                       zorder=1,
                       interpolation='none',
                       aspect='equal')
        fig.colorbar(im, ax=ax)

        plugins.connect(fig, plugins.MousePosition(fontsize=14))

        return mpld3.fig_to_dict(fig)
    finally:
        # The returned dict is self-contained, so the figure can be released.
        plt.close(fig)
Ejemplo n.º 13
0
def plot_signal(sample_rate, samples, size):
    """Plot a decimated signal against time and return it as mpld3 HTML.

    Args:
        sample_rate: sampling rate of ``samples``; the total duration is
            computed as ``len(samples) / sample_rate``.
        samples: the signal; decimated by a factor of 300 with an FIR filter
            before plotting.
        size: two-element ``(width, height)`` for the figure.

    Returns:
        HTML string from ``mpld3.fig_to_html`` using the ``"general"``
        template.
    """
    width, height = size
    original_count = len(samples)

    fig = plt.figure()
    fig.set_figwidth(width)
    fig.set_figheight(height)

    decimated = signal.decimate(samples, 300, ftype='fir')

    # Spread the original duration evenly over the decimated points
    duration = original_count / sample_rate
    step = duration / len(decimated)
    time_axis = np.arange(0, len(decimated), dtype=float) * step

    plt.plot(time_axis, decimated)
    plt.title('Sygnał')
    plt.xlabel('Czas [s]')
    plt.ylabel('Ciśnienie [Pa]')
    plt.grid()
    plt.tight_layout()

    plugins.connect(fig, plugins.MousePosition())
    return mpld3.fig_to_html(fig, template_type="general")
Ejemplo n.º 14
0
    def render(self, idf, filename='output.html'):
        """Plot every group in ``self.datasets`` in the format chosen by ``idf.opt['outputAs']``.

        Supported values are ``'html'`` (mpld3 figures rendered through a
        jinja2 template into the ``../output`` directory next to this file),
        ``'pdf'`` (all figures appended to one ``PdfPages`` file),
        ``'interactive'`` (matplotlib windows) and ``'tikz'`` (one ``.tex``
        file per figure via ``matplotlib2tikz``). Any other value prints a
        message and returns without plotting.

        Args:
            idf: object providing the ``opt`` option mapping; ``idf.model.num_dofs``
                is read for tikz file names.
            filename: output file name, overridden by
                ``idf.opt['outputFilename']`` when that option is truthy.
        """
        progress_inst = helpers.Progress(idf.opt)
        self.progress = progress_inst.progress

        if idf.opt['outputFilename']:
            filename = idf.opt['outputFilename']

        # Backend-specific modules are imported lazily, only for the selected output type
        if idf.opt['outputAs'] == 'html':
            # write matplotlib/d3 plots to html file
            import matplotlib
            import matplotlib.pyplot as plt, mpld3
            import matplotlib.axes

            from mpld3 import plugins
            from jinja2 import Environment, FileSystemLoader
        elif idf.opt['outputAs'] in ['pdf', 'interactive', 'tikz']:
            # show plots in separate matplotlib windows
            import matplotlib
            if idf.opt['outputAs'] == 'pdf':
                from matplotlib.backends.backend_pdf import PdfPages
                pp = PdfPages(filename)
            import matplotlib.pyplot as plt
            import matplotlib.axes
        else:
            print("No proper output method given. Not plotting.")
            return

        # Print-oriented outputs get larger fonts and line widths derived from font_size
        font_size = 10
        if idf.opt['outputAs'] in ['pdf', 'tikz']:
            if idf.opt['plotPerJoint']:
                font_size = 30
            else:
                font_size = 12
            matplotlib.rcParams.update({'font.size': font_size})
            matplotlib.rcParams.update({'axes.labelsize': font_size - 5})
            matplotlib.rcParams.update({'axes.linewidth': font_size / 15.})
            matplotlib.rcParams.update({'axes.titlesize': font_size - 2})
            matplotlib.rcParams.update({'legend.fontsize': font_size - 2})
            matplotlib.rcParams.update({'xtick.labelsize': font_size - 5})
            matplotlib.rcParams.update({'ytick.labelsize': font_size - 5})
            matplotlib.rcParams.update({'lines.linewidth': font_size / 15.})
            matplotlib.rcParams.update({'patch.linewidth': font_size / 15.})
            matplotlib.rcParams.update({'grid.linewidth': font_size / 20.})

        # skip some samples so graphs don't get too large/detailed TODO: change skip so that some
        # maximum number of points is plotted (determined by screen etc.)
        skip = 5

        #create figures and plots
        figures = list()
        for ds in self.progress(range(len(self.datasets))):
            group = self.datasets[ds]
            # One subplot per dataset in the group, all sharing both axes
            fig, axes = plt.subplots(len(group['dataset']),
                                     sharex=True,
                                     sharey=True)
            # scale unified scaling figures to same ranges and add some margin
            if group['unified_scaling']:
                ymin = 0
                ymax = 0
                # NOTE(review): the running min/max is multiplied by 1.05 on every
                # iteration, so the margin compounds with the number of datasets -
                # confirm whether a single 5% margin was intended.
                for i in range(len(group['dataset'])):
                    ymin = np.min(
                        (np.min(group['dataset'][i]['data']), ymin)) * 1.05
                    ymax = np.max(
                        (np.max(group['dataset'][i]['data']), ymax)) * 1.05

            #plot each group of data
            for d_i in range(len(group['dataset'])):
                d = group['dataset'][d_i]
                # plt.subplots returns a bare Axes for a single subplot; wrap it in a
                # list so the code below can always treat `axes` as a sequence.
                # NOTE(review): matplotlib.axes.SubplotBase may not exist in newer
                # Matplotlib releases - confirm against the pinned version.
                if not issubclass(type(axes), matplotlib.axes.SubplotBase):
                    ax = axes[d_i]
                else:
                    ax = axes
                    axes = [axes]
                if idf.opt['outputAs'] != 'tikz':
                    ax.set_title(d['title'])
                if group['unified_scaling']:
                    ax.set_ylim([ymin, ymax])
                for data_i in range(0, len(d['data'])):
                    if len(d['data'][data_i].shape) > 1:
                        #data matrices
                        for i in range(0, d['data'][data_i].shape[1]):
                            # Only the first data set of each subplot contributes legend labels
                            l = group['labels'][i] if data_i == 0 else ''
                            if i < 6 and 'contains_base' in group and group[
                                    'contains_base']:
                                ls = 'dashed'
                            else:
                                ls = '-'
                            dashes = ()  # type: Tuple
                            # With plotErrors set, column n (3 with a-priori torques,
                            # otherwise 2) is drawn with a custom dash pattern
                            if idf.opt['plotErrors']:
                                if idf.opt['plotPrioriTorques']:
                                    n = 3
                                else:
                                    n = 2
                                if i == n:
                                    ls = 'dashed'
                                    dashes = (3, 0.5)
                            # `colors` is not defined in this method; it comes from
                            # the enclosing module scope.
                            # Later data sets are drawn more transparent: alpha = 1 - data_i / 2
                            ax.plot(d['time'][::skip],
                                    d['data'][data_i][::skip, i],
                                    label=l,
                                    color=colors[i],
                                    alpha=1 - (data_i / 2.0),
                                    linestyle=ls,
                                    dashes=dashes)
                    else:
                        #data vector
                        ax.plot(d['time'][::skip],
                                d['data'][data_i][::skip],
                                label=group['labels'][d_i],
                                color=colors[0],
                                alpha=1 - (data_i / 2.0))

                ax.grid(which='both', linestyle="dotted", alpha=0.8)
                if 'y_label' in group:
                    ax.set_ylabel(group['y_label'])

            # `ax` is the last subplot drawn above, so only it gets the x label
            if idf.opt['outputAs'] != 'tikz':
                ax.set_xlabel("Time (s)")

            # Hide x tick labels on every subplot except the last one
            plt.setp([a.get_xticklabels() for a in axes[:-1]], visible=False)
            #plt.setp([a.get_yticklabels() for a in axes], fontsize=8)

            if idf.opt['plotLegend']:
                handles, labels = ax.get_legend_handles_labels()
                if idf.opt['outputAs'] == 'html':
                    #TODO: show legend properly (see mpld3 bug #274)
                    #leg = fig.legend(handles, labels, loc='upper right', fancybox=True, fontsize=10, title='')
                    leg = axes[0].legend(handles,
                                         labels,
                                         loc='upper right',
                                         fancybox=True,
                                         fontsize=10,
                                         title='',
                                         prop={'size': 8})
                else:
                    leg = plt.figlegend(handles,
                                        labels,
                                        loc='upper right',
                                        fancybox=True,
                                        fontsize=font_size,
                                        title='',
                                        prop={'size': font_size - 3})
                    # NOTE(review): Legend.draggable() may have been replaced by
                    # set_draggable() in newer Matplotlib - confirm against the pinned version.
                    leg.draggable()

            fig.subplots_adjust(hspace=2)
            fig.set_tight_layout(True)

            # Emit this figure in the selected format
            if idf.opt['outputAs'] == 'html':
                plugins.clear(fig)
                plugins.connect(fig, plugins.Reset(), plugins.BoxZoom(),
                                plugins.Zoom(enabled=False),
                                plugins.MousePosition(fontsize=14, fmt=".5g"))
                figures.append(mpld3.fig_to_html(fig))
            elif idf.opt['outputAs'] == 'interactive':
                plt.show(block=False)
            elif idf.opt['outputAs'] == 'pdf':
                pp.savefig(plt.gcf())
            elif idf.opt['outputAs'] == 'tikz':
                from matplotlib2tikz import save as tikz_save
                tikz_save('{}_{}_{}.tex'.format(
                    filename, group['dataset'][0]['title'].replace('_', '-'),
                    ds // idf.model.num_dofs),
                          figureheight='\\figureheight',
                          figurewidth='\\figurewidth',
                          show_info=False)

        # Finalise the output once all figures have been produced
        if idf.opt['outputAs'] == 'html':
            path = os.path.dirname(os.path.abspath(__file__))
            # autoescape is off because the figure strings are HTML meant to be embedded verbatim
            template_environment = Environment(
                autoescape=False,
                loader=FileSystemLoader(os.path.join(path, '../output')),
                trim_blocks=False)

            context = {'figures': figures, 'text': self.text}
            outfile = os.path.join(path, '..', 'output', filename)
            import codecs
            with codecs.open(outfile, 'w', 'utf-8') as f:
                html = template_environment.get_template(
                    "templates/index.html").render(context)
                f.write(html)

            print("Saved output at file://{}".format(outfile))
        elif idf.opt['outputAs'] == 'interactive':
            #keep non-blocking plot windows open
            plt.show()
        elif idf.opt['outputAs'] == 'pdf':
            pp.close()
Ejemplo n.º 15
0
    def makePlot(self,
                 noTightLayout=False,
                 figHeight='100%',
                 figWidth='100%',
                 centeredStyle=None):
        """Emit the current pyplot figure according to ``self.outputType``, then close it.

        * ``HTML_STRING``: the figure's HTML is appended to
          ``self.createdPlots`` and the method returns.
        * ``D3``: the figure is first shown through ``mpld3.show``.
        * ``self.save_to_file`` set: the figure is written to
          ``<self.save_file>.<NN>.<ext>`` as HTML, JSON or PNG depending on
          the output type, ``self.saved_plot`` is incremented and the file
          name is appended to ``self.createdPlots``.
        * otherwise: the figure is shown with ``plt.show()``.

        Args:
            noTightLayout: when true, skip the tight-layout step before
                ``plt.show()``.
            figHeight: forwarded to ``mpld3.fig_to_html`` for ``HTML_STRING``.
            figWidth: forwarded to ``mpld3.fig_to_html`` for ``HTML_STRING``.
            centeredStyle: forwarded as ``styles``; falls back to
                ``self.centeredStyle`` when ``None``.
        """
        current_figure = plt.gcf()

        if self.usesMPLD3():
            plugins.connect(current_figure, plugins.MousePosition(fontsize=14))

        if self.outputType == PlotSaveTYPE.HTML_STRING:

            if centeredStyle is None:
                centeredStyle = self.centeredStyle

            outString = mpld3.fig_to_html(current_figure,
                                          template_type='notebook',
                                          d3_url=self.d3js,
                                          mpld3_url=self.mpld3js,
                                          figHeight=figHeight,
                                          figWidth=figWidth,
                                          styles=centeredStyle)
            self.createdPlots.append(outString)
            plt.close(current_figure)
            return

        if self.outputType == PlotSaveTYPE.D3:
            mpld3.show(current_figure)

        if self.save_to_file:

            # Numbered file name: <save_file>.<two-digit counter>.<extension>
            exactFilename = self.save_file + ".%02d." + PlotSaveTYPE.getFileExtension(
                self.outputType)
            exactFilename = exactFilename % self.saved_plot
            self.saved_plot += 1

            if self.outputType == PlotSaveTYPE.HTML:
                mpld3.save_html(current_figure,
                                exactFilename,
                                template_type='simple',
                                d3_url=self.d3js,
                                mpld3_url=self.mpld3js)
            elif self.outputType == PlotSaveTYPE.JSON:
                mpld3.save_json(current_figure,
                                exactFilename,
                                d3_url=self.d3js,
                                mpld3_url=self.mpld3js)
            elif self.outputType == PlotSaveTYPE.PNG:
                plt.savefig(exactFilename,
                            transparent=self.transparent_bg,
                            bbox_inches='tight')

            self.createdPlots.append(exactFilename)

        else:
            # Reached whenever nothing is saved to a file, whatever the output type

            if not noTightLayout:

                legends = current_figure.legends

                # tight_layout is applied only if no legend is anchored to an explicit bbox
                makeTightLayout = True
                for lgd in legends + [
                        ax.legend_ for ax in current_figure.axes
                ]:

                    # Axes without a legend contribute None
                    if lgd is None:
                        continue

                    lgd.set_draggable(True)

                    makeTightLayout = makeTightLayout and lgd._bbox_to_anchor is None

                if makeTightLayout:
                    plt.tight_layout()

            plt.show()

        plt.close(current_figure)
Ejemplo n.º 16
0
def Plot(filename,
         fixed_axis,
         axis_value,
         levels=30,
         amp_min=0,
         amp_max=2500,
         color="magma",
         save=False,
         file_prefix="",
         show=True,
         xmin=None,
         xmax=None,
         ymin=None,
         ymax=None,
         figsize=(6.4, 4.8)):
    """Draw an interpolated contour map of one planar slice of 3-D amplitude data.

    The data file holds comma-separated ``x, y, z, amplitude`` rows and is
    resolved relative to this module's directory. Rows whose ``fixed_axis``
    coordinate equals ``axis_value`` are kept and cubically interpolated on
    a 1000x1000 grid.

    Args:
        filename: data file path relative to this module's directory.
        fixed_axis: ``"x"``, ``"y"`` or ``"z"`` (either case).
        axis_value: value of the fixed coordinate that selects the slice.
        levels: contour levels passed to ``contourf``.
        amp_min, amp_max: colour range (``vmin`` / ``vmax``).
        color: name of a Matplotlib colormap.
        save: when true, write ``file_prefix + title + ".png"`` at 300 dpi.
        file_prefix: prefix for the saved file name.
        show: when true, return the figure as a JSON string.
        xmin, xmax, ymin, ymax: optional grid limits; each defaults to the
            extent of the selected data.
        figsize: figure size in inches.

    Returns:
        JSON string of the mpld3 figure dict when ``show`` is true, else
        ``None``.

    Raises:
        SystemExit: with code 0 when ``fixed_axis`` is invalid or no row
            matches ``axis_value`` (kept from the original behaviour).
        ValueError: when ``color`` is not a known colormap name.
    """
    # Check fixed_axis has a valid value
    if fixed_axis not in ["x", "y", "z", "X", "Y", "Z"]:
        print("Error: Fixed axis must be x, y or z.")
        # NOTE(review): exiting with status 0 on invalid input is kept for
        # backward compatibility; raising ValueError would be clearer.
        raise SystemExit(0)

    _x, _y, _z, _amplitude = np.loadtxt(os.path.dirname(__file__) + '/' +
                                        filename,
                                        delimiter=",").T

    # For each fixed axis: (fixed coordinate, horizontal data, vertical data,
    # horizontal label, vertical label)
    layout = {
        "x": (_x, _z, _y, "z (m)", "y (m)"),
        "y": (_y, _x, _z, "x (m)", "z (m)"),
        "z": (_z, _x, _y, "x (m)", "y (m)"),
    }
    fixed_values, horizontal, vertical, xlab, ylab = layout[fixed_axis.lower()]

    # Select the measurements that lie on the requested slice
    on_slice = fixed_values == axis_value
    if not np.any(on_slice):
        raise SystemExit(0)

    x = horizontal[on_slice]
    y = vertical[on_slice]
    amplitude = _amplitude[on_slice]

    # Create a linear space for interpolating between the two in-plane axes
    xi = np.linspace(x.min() if xmin is None else xmin,
                     x.max() if xmax is None else xmax, 1000)
    yi = np.linspace(y.min() if ymin is None else ymin,
                     y.max() if ymax is None else ymax, 1000)

    # Interpolate to fill in the gaps between measurements
    amplitudei = griddata((x, y),
                          amplitude, (xi[None, :], yi[:, None]),
                          method='cubic')

    title = filename.split("/")[-1] + "-" + fixed_axis + "=" + str(
        axis_value) + "m"

    # Resolve the colormap by name; previously any name other than the four
    # hard-coded ones left `colormap` unbound.
    colormap = plt.get_cmap(color)

    # Create new plot figure
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)
    try:
        # Create a contour plot
        cs = ax.contourf(xi,
                         yi,
                         amplitudei,
                         levels=levels,
                         cmap=colormap,
                         vmin=amp_min,
                         vmax=amp_max)

        ax.invert_xaxis()
        ax.axis('scaled')
        ax.set_facecolor((0.0, 0.0, 0.0))
        ax.set_title(title, pad=20)
        ax.set_xlabel(xlab)
        ax.set_ylabel(ylab)

        # Connect to MousePosition plugin
        plugins.connect(fig, plugins.MousePosition(fontsize=12))

        # Create colour bar scale for the colour map
        cbar = fig.colorbar(cs)
        cbar.ax.set_ylabel("Amplitude (Pa)")

        if save:
            # Save to local computer. The former `quality=100` argument only
            # ever applied to JPEG output and is no longer accepted by savefig.
            fig.savefig(file_prefix + title + ".png",
                        bbox_inches='tight',
                        dpi=300)
        if show:
            # Return JSON for visualisation
            return json.dumps(fig_to_dict(fig), cls=NumpyEncoder)
        return None
    finally:
        # Release the figure whether or not it was exported
        plt.close(fig)
Ejemplo n.º 17
0
    def setup_design_mode(self, grid_on=True, tick_spacing=50):
        """Configure ``self.ax`` / ``self.fig`` for layout work.

        Args:
            grid_on: whether the grid lines are shown.
            tick_spacing: distance between major ticks on both axes.

        Also attaches the mouse-position plugin to ``self.fig``.
        """
        # Passed positionally: the keyword for this argument was renamed from
        # `b` to `visible` in newer Matplotlib, positional works with both.
        self.ax.grid(grid_on)
        self.ax.xaxis.set_major_locator(mticker.MultipleLocator(tick_spacing))
        self.ax.yaxis.set_major_locator(mticker.MultipleLocator(tick_spacing))
        mdplugins.connect(self.fig, mdplugins.MousePosition())
Ejemplo n.º 18
0
 def pickpoints(fig='', radius=4, color="white", x='x', y='y'):
     """Attach the ``Annotate`` and mouse-position plugins to a figure.

     A falsy ``fig`` (the default empty string) selects the current pyplot
     figure. ``radius`` is an int and ``color`` an HTML colour name; both are
     forwarded to ``Annotate`` together with ``x`` and ``y``.
     """
     target_fig = fig if fig else plt.gcf()

     annotate_plugin = Annotate(radius, color, x, y)
     plugins.connect(target_fig, annotate_plugin)

     plugins.connect(target_fig, plugins.MousePosition())
Ejemplo n.º 19
0
# limitations under the License.

import matplotlib.pyplot as plt
import mpld3
import numpy as np
from mpld3 import plugins

# Demo: show a 20x20 RGBA image with the mpld3 mouse-position readout.
fig, ax = plt.subplots()

x = np.linspace(-2, 2, 20)
y = x[:, None]

# One Gaussian bump per RGBA channel; x (20,) against y (20, 1) broadcasts to
# 20x20, and dstack joins the four planes into a (20, 20, 4) image.
X = np.dstack((
    np.exp(-(x - 1)**2 - (y)**2),
    np.exp(-(x + 0.71)**2 - (y - 0.71)**2),
    np.exp(-(x + 0.71)**2 - (y + 0.71)**2),
    np.exp(-0.25 * (x**2 + y**2)),
))

im = ax.imshow(X,
               extent=(10, 20, 10, 20),
               origin='lower',
               zorder=1,
               interpolation='nearest')
fig.colorbar(im, ax=ax)
ax.set_title('An Image', size=20)

plugins.connect(fig, plugins.MousePosition(fontsize=14))

mpld3.show()
Ejemplo n.º 20
0
def plotWavefront(
    wf, title, slice_numbers=False, cuts=False, interactive=False, phase=False
):
    """Display slice 0 of a wavefront (horizontal polarization) as an image.

    Args:
        wf: wavefront object providing ``get_phase``, ``get_intensity``,
            ``get_limits`` and ``params``.
        title: plot title.
        slice_numbers: accepted for interface compatibility; not used
            (slice 0 is always drawn).
        cuts: accepted for interface compatibility; not used.
        interactive: when truthy, show the plot through mpld3 with a
            mouse-position readout; otherwise use a pylab window.
        phase: when truthy, plot the phase instead of the intensity.

    Mesh step sizes are printed, and power figures as well unless the field
    unit is ``"arbitrary"``, in which case the intensity is normalised to
    its maximum.
    """
    print("Displaying wavefront...")

    [nx, ny, xmin, xmax, ymin, ymax] = get_mesh(wf)
    dx = (xmax - xmin) / (nx - 1)
    dy = (ymax - ymin) / (ny - 1)
    print("stepX, stepY [um]:", dx * 1e6, dy * 1e6, "\n")

    if phase:
        A = wf.get_phase(slice_number=0, polarization="horizontal")
        label = "Phase (rad.)"
    else:
        ii = wf.get_intensity(slice_number=0, polarization="horizontal")
        ii = ii * wf.params.photonEnergy / J2EV  # *1e3
        imax = numpy.max(ii)

        if wf.params.wEFieldUnit != "arbitrary":
            print(
                "Total power (integrated over full range): %g [GW]"
                % (ii.sum(axis=0).sum(axis=0) * dx * dy * 1e6 * 1e-9)
            )
            print(
                "Peak power calculated using FWHM:         %g [GW]"
                % (
                    imax
                    * 1e-9
                    * 1e6
                    * 2
                    * numpy.pi
                    * (calculate_fwhm_x(wf) / 2.35)
                    * (calculate_fwhm_y(wf) / 2.35)
                )
            )
            print("Max irradiance: %g [GW/mm^2]" % (imax * 1e-9))
            label = "Irradiance (W/$mm^2$)"
        else:
            ii = ii / imax
            label = "Irradiance (a.u.)"

        A = ii

    [x1, x2, y1, y2] = wf.get_limits()
    # Limits are scaled by 1e3 for display (axes are labelled in mm below)
    extent = [x1 * 1e3, x2 * 1e3, y1 * 1e3, y2 * 1e3]

    if interactive:
        import mpld3
        from mpld3 import plugins
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots()
        # Draw `A` so the phase view works too; the plugin and the show call
        # target this figure (a second, empty figure used to be created and
        # shown instead).
        im = ax.imshow(A, extent=extent)
        fig.colorbar(im, ax=ax).set_label(label)

        ax.set_title(title, size=20)
        plugins.connect(fig, plugins.MousePosition(fontsize=14))
        mpld3.show(fig)

    else:
        pylab.figure(figsize=(21, 6))
        pylab.imshow(A, extent=extent)
        pylab.set_cmap("hot")
        pylab.axis("tight")
        pylab.colorbar(orientation="horizontal").set_label(label)
        pylab.xlabel("x (mm)")
        pylab.ylabel("y (mm)")
        # gca() returns the axes holding the image; axes() without arguments
        # adds a new, empty axes on current Matplotlib.
        pylab.gca().set_aspect(0.5)

        pylab.title(title)
        pylab.show()