Exemple #1
0
def createPlot(df, boundaryDF):
    """Create the exon-block figure with a hover tool on the blocks.

    ``df`` and ``boundaryDF`` are accepted but not read here: every glyph is
    bound to the module-level ``blockSource`` / ``source`` data sources.

    Returns:
        Figure: the configured Bokeh figure.
    """
    fig = Figure(
        plot_height=900,
        plot_width=PLOT_WIDTH,
        title="",
        y_range=[],
        title_text_font_size=TITLE_FONT_SIZE,
    )
    # No background grid in either direction.
    fig.xgrid.grid_line_color = None
    fig.ygrid.grid_line_color = None

    # Nearly transparent grey blocks that turn firebrick on hover.
    block_style = dict(
        fill_color="grey",
        hover_fill_color="firebrick",
        fill_alpha=0.05,
        hover_alpha=0.3,
        line_color=None,
        hover_line_color="white",
    )
    blocks = fig.quad(top="top", bottom="bottom", left="left", right="right",
                      source=blockSource, **block_style)

    # Dotted black lines from blockSource's xs/ys columns.
    fig.multi_line(xs="xs", ys="ys", source=blockSource, color="black",
                   line_width=2, line_alpha=0.4, line_dash="dotted")
    # Lines whose colour, width and alpha come from columns of `source`.
    fig.multi_line(xs="xs", ys="ys", source=source, color="color",
                   line_width="width", line_alpha="line_alpha")

    block_tooltips = [
        ("chromosome", "@chromosome"),
        ("exon", "@exon"),
        ("start", "@start"),
        ("end", "@end"),
    ]
    fig.add_tools(HoverTool(tooltips=block_tooltips, renderers=[blocks]))
    return fig
Exemple #2
0
def _add_labels(G: nx.Graph, plot: Figure):
    """Add edge-weight labels from G to the plot.

    Each edge's ``weight`` (rounded to 2 decimals) is drawn at the mean of the
    edge's ``xs``/``ys`` coordinates. Underneath it, a short white segment is
    drawn along the edge direction to blank out the edge line behind the text.

    Args:
        G (nx.Graph): Networkx graph whose edges carry ``weight``, ``xs`` and
            ``ys`` attributes (``xs``/``ys`` are per-edge coordinate lists; the
            first two points define the edge direction).
        plot (Figure): Plot to add the labels to. Its x/y ranges and
            ``plot_width``/``plot_height`` are read to convert between data
            units and pixels, so they must already be set.
    """
    text = [round(w, 2) for w in nx.get_edge_attributes(G, 'weight').values()]
    xs = np.array(list(nx.get_edge_attributes(G, 'xs').values()))
    ys = np.array(list(nx.get_edge_attributes(G, 'ys').values()))

    # An edgeless graph yields 1-D empty arrays that np.mean(..., axis=1) below
    # cannot handle; there is nothing to label in that case.
    if xs.size == 0:
        return

    # Scale factors between data units and (approximate) pixels, so that the
    # blank segment length below is measured in pixels.
    x_range = (plot.x_range.end - plot.x_range.start)
    y_range = (plot.y_range.end - plot.y_range.start)
    margin_shift_x = (x_range / (2 * PLOT_MARGIN + 1)) * PLOT_MARGIN
    margin_shift_y = (y_range / (2 * PLOT_MARGIN + 1)) * PLOT_MARGIN
    x_scale = plot.plot_width / x_range
    y_scale = plot.plot_height / y_range

    xs = (xs + margin_shift_x - plot.x_range.start) * x_scale
    ys = (ys + margin_shift_y - plot.y_range.start) * y_scale

    x = np.mean(xs, axis=1)
    y = np.mean(ys, axis=1)
    centers = np.vstack((x, y)).T
    v = np.vstack((xs[:, 1] - xs[:, 0], ys[:, 1] - ys[:, 0])).T
    lengths = np.linalg.norm(v, axis=1)[:, np.newaxis]
    # Unit direction per edge. Zero-length edges get a zero vector instead of
    # NaN, which would otherwise propagate into the blank-segment coordinates.
    n = np.divide(v, lengths, out=np.zeros_like(v, dtype=float),
                  where=lengths != 0)

    # Half-length of each blank segment: 3 px padding + 3 px per label char.
    lbl_size = np.array([len(str(w)) for w in text]) * 3
    tmp = np.multiply(n, lbl_size[:, np.newaxis])
    blank_start = centers - (n * 3 + tmp)
    blank_end = centers + (n * 3 + tmp)
    blank_xs = np.vstack((blank_start[:, 0], blank_end[:, 0])).T
    blank_ys = np.vstack((blank_start[:, 1], blank_end[:, 1])).T

    # Convert the label centres and blank segments back to data units.
    x = (x / x_scale) - margin_shift_x + plot.x_range.start
    y = (y / y_scale) - margin_shift_y + plot.y_range.start
    blank_xs = (blank_xs / x_scale) - margin_shift_x + plot.x_range.start
    blank_ys = (blank_ys / y_scale) - margin_shift_y + plot.y_range.start

    plot.multi_line(xs=blank_xs.tolist(),
                    ys=blank_ys.tolist(),
                    line_color='white',
                    line_width=LINE_WIDTH + 1,
                    nonselection_line_alpha=1,
                    level=LABEL_BACKGROUND_LEVEL)

    labels_src = ColumnDataSource(data={'x': x, 'y': y, 'text': text})
    labels = LabelSet(x='x',
                      y='y',
                      text='text',
                      text_align='center',
                      text_baseline='middle',
                      text_font_size='13px',
                      text_color='black',
                      level=LABEL_LEVEL,
                      source=labels_src)
    plot.add_layout(labels)
Exemple #3
0
def render_dendrogram(dend: Dict[str, Any], plot_width: int,
                      plot_height: int) -> Figure:
    """
    Render a missing dendrogram (hierarchical clustering of columns).

    Args:
        dend: output of scipy's ``dendrogram``; the ``icoord``, ``dcoord``
            and ``ivl`` (leaf/column labels) entries are used.
        plot_width: requested width in pixels. When there are more than 20
            columns it is replaced by 28 px per column.
        plot_height: plot height in pixels.

    Returns:
        Figure: the dendrogram, with a hover tool on the horizontal segments
        showing the average distance.
    """
    # list of lists of dcoords and icoords from scipy.dendrogram
    xs, ys, cols = dend["icoord"], dend["dcoord"], dend["ivl"]

    # if the number of columns is greater than 20, make the plot wider
    if len(cols) > 20:
        plot_width = 28 * len(cols)

    fig = Figure(
        plot_width=plot_width,
        plot_height=plot_height,
        toolbar_location=None,
        tools="",
        title=" ",
    )

    # round the coordinates to integers, and plot the dendrogram
    xs = [[round(coord) for coord in coords] for coords in xs]
    ys = [[round(coord, 2) for coord in coords] for coords in ys]
    fig.multi_line(xs=xs, ys=ys, line_color="#8073ac")

    # extract the horizontal lines for the hover tooltip: points 1 and 2 of
    # each 4-point link are the top (horizontal) segment
    h_lns_x = [coords[1:3] for coords in xs]
    h_lns_y = [coords[1:3] for coords in ys]
    null_mismatch_vals = [coord[0] for coord in h_lns_y]
    source = ColumnDataSource(dict(x=h_lns_x, y=h_lns_y, n=null_mismatch_vals))
    h_lns = fig.multi_line(xs="x", ys="y", source=source, line_color="#8073ac")
    hover_pts = HoverTool(
        renderers=[h_lns],
        tooltips=[("Average distance", "@n{0.1f}")],
        line_policy="interp",
    )
    fig.add_tools(hover_pts)

    # shorten column labels if necessary, and override coordinates with column names
    # (leaves sit at x = 5, 15, 25, ... in dendrogram coordinates)
    cols = [f"{col[:16]}..." if len(col) > 18 else col for col in cols]
    axis_coords = list(range(5, 10 * len(cols) + 1, 10))
    axis_overrides = dict(zip(axis_coords, cols))
    fig.xaxis.ticker = axis_coords
    fig.xaxis.major_label_overrides = axis_overrides
    fig.xaxis.major_label_orientation = np.pi / 3
    fig.yaxis.axis_label = "Average Distance Between Clusters"
    fig.grid.visible = False
    fig.frame_width = plot_width
    return fig
Exemple #4
0
def gen_plot(ticker, features, end_date_year, end_date_month, end_date_day,
             days_track):
    """Plot daily price series for ``ticker`` over a trailing window.

    Args:
        ticker: symbol appended to the ``WIKI/`` Quandl dataset code.
        features: column names to plot; each is looked up in the local
            ``stock_indices`` table ('Open', 'Close', 'High', 'Low').
        end_date_year, end_date_month, end_date_day: end-date parts as
            strings, joined with '-'.
        days_track: window length in days (passed through ``int``).

    Returns:
        Figure: a datetime-axis multi-line plot with one line per feature.

    Reads the Quandl API key from the ``QUANDL_KEY`` environment variable.
    """
    quandl.ApiConfig.api_key = os.environ['QUANDL_KEY']

    # feature -> [legend label, line colour]
    stock_indices = {
        'Open': ['Opening Price', '#440154'],
        'Close': ['Closing Price', '#30678D'],
        'High': ['Daily High', '#35B778'],
        'Low': ['Daily Low', '#0A333D'],
    }

    end_date = '-'.join([end_date_year, end_date_month, end_date_day])
    end_ts = pd.to_datetime(end_date)
    start_date = (end_ts - pd.DateOffset(days=int(days_track))).strftime('%Y-%m-%d')

    history = quandl.get('WIKI/' + ticker, start_date=start_date,
                         end_date=end_date)
    mydata = pd.DataFrame(history).reset_index()
    mydata['Date'] = pd.to_datetime(mydata.get('Date', None))
    # Keep only the date and the requested series.
    mydata = mydata.get(['Date', *features], None)

    dates = [mydata.get('Date', None) for _ in features]
    series = [mydata.get(col, None) for col in features]
    labels = [stock_indices.get(col, None)[0] for col in features]
    colors = [stock_indices.get(col, None)[1] for col in features]
    source = ColumnDataSource({'xs': dates, 'ys': series,
                               'labels': labels, 'colors': colors})

    p = Figure(plot_width=800, plot_height=500, x_axis_type="datetime",
               x_axis_label='Initial', title='Initial')
    p.multi_line(xs='xs', ys='ys', legend='labels', color='colors',
                 source=source)
    # Replace the placeholder labels/title set at construction time.
    p.yaxis.axis_label = ticker
    p.xaxis.axis_label = 'Time'
    p.title.text = 'Different indices of {}: {} to {}'.format(
        ticker, start_date, end_date)
    return p
Exemple #5
0
    def make_plot(source, title):
        """Build a toolbar-less scatter plot with per-row lines from ``source``.

        ``x``, ``y``, ``xlabel`` and ``ylabel`` are read from the enclosing
        scope (not visible here).

        Args:
            source: data source with the scatter columns plus ``ci_x``/``ci``
                (presumably confidence intervals — confirm with the caller).
            title: plot title text.

        Returns:
            Figure: the configured plot.
        """
        plot = Figure(plot_width=800,
                      plot_height=600,
                      tools="",
                      toolbar_location=None)
        plot.title.text = title

        plot.scatter(x=x, y=y, source=source)
        # one line per row from the 'ci_x' / 'ci' columns
        plot.multi_line('ci_x', 'ci', source=source)

        # fixed axis styling
        plot.xaxis.axis_label = xlabel
        plot.yaxis.axis_label = ylabel
        plot.axis.major_label_text_font_size = "8pt"
        plot.axis.axis_label_text_font_size = "8pt"
        plot.axis.axis_label_text_font_style = "bold"

        return plot
Exemple #6
0
def _add_edges(
    G: nx.Graph,
    plot: Figure,
    show_labels: bool = True,
    hover_line_color: str = TERTIARY_DARK_COLOR
) -> "tuple[ColumnDataSource, GlyphRenderer]":
    """Add edges from G to the plot.

    Args:
        G (nx.Graph): Networkx graph. Edge attributes become columns of the
            edge data source; ``xs``, ``ys`` and ``line_color`` are read by
            the glyph.
        plot (figure): Plot to add the edges to.
        show_labels (bool): True iff each edge should be labeled.
        hover_line_color (str): Color of the edges when hovering over them.

    Returns:
        tuple[ColumnDataSource, GlyphRenderer]: the edge source and the
        multi-line glyph renderer, in that order.
    """
    edge_pairs = list(G.edges())
    edges_df = pd.DataFrame([G[u][v] for u, v in edge_pairs])
    # Build the endpoint columns directly: the previous `u, v = zip(*pairs)`
    # raised ValueError for a graph without edges.
    edges_df['u'] = [u for u, _ in edge_pairs]
    edges_df['v'] = [v for _, v in edge_pairs]
    edges_src = ColumnDataSource(data=edges_df.to_dict(orient='list'))

    edges_glyph = plot.multi_line(xs='xs',
                                  ys='ys',
                                  line_color='line_color',
                                  line_cap='round',
                                  hover_line_color=hover_line_color,
                                  line_width=LINE_WIDTH,
                                  nonselection_line_alpha=1,
                                  level=EDGE_LEVEL,
                                  source=edges_src)

    if show_labels:
        _add_labels(G, plot)

    return edges_src, edges_glyph
Exemple #7
0
def createPlot(height=600, width=1200):
    """
    Create and return a plot for visualizing transcripts.

    Args:
        height: plot height in pixels.
        width: plot width in pixels.

    Glyphs are bound to the module-level data sources ``blockSource``,
    ``tranSource``, ``source`` and ``codonSource``; the exon line width comes
    from ``opt.height``.
    """
    tools = "pan, wheel_zoom, save, reset, tap"
    fig = Figure(title="",
                 y_range=[],
                 webgl=True,
                 tools=tools,
                 toolbar_location="above",
                 plot_height=height,
                 plot_width=width)
    # NOTE: setting fig.title.text_font_size = TITLE_FONT_SIZE makes the title
    # overlap the plot substantially, so the default size is kept.
    for grid in (fig.xgrid, fig.ygrid):
        grid.grid_line_color = None

    # Exon blocks: transparent fill, dotted outline, red on hover.
    exon_blocks = fig.quad(
        top="top", bottom="bottom", left="left", right="right",
        source=blockSource,
        fill_alpha=0,
        line_dash="dotted",
        line_alpha=0.4,
        line_color='black',
        hover_fill_color="red",
        hover_alpha=0.3,
        hover_line_color="white",
        nonselection_fill_alpha=0,
        nonselection_line_alpha=0.4,
        nonselection_line_color='black',
    )

    # One invisible block per transcript, used for selection.
    fig.quad(top="top", bottom="bottom", right="right", left="left",
             source=tranSource,
             fill_alpha=0, line_alpha=0,
             nonselection_fill_alpha=0, nonselection_line_alpha=0)

    # The exons themselves. line_width="height" (a data column) is broken,
    # so the fixed opt.height is used instead.
    fig.multi_line(xs="xs", ys="ys", line_width=opt.height, color="color",
                   line_alpha="line_alpha", source=source)

    # Start/stop codon markers.
    fig.inverted_triangle(x="x", y="y", color="color", source=codonSource,
                          size='size', alpha=0.5)

    hover = HoverTool(
        tooltips=[("chromosome", "@chromosome"), ("exon", "@exon"),
                  ("start", "@start"), ("end", "@end")],
        renderers=[exon_blocks],
    )
    fig.add_tools(hover)
    return fig
Exemple #8
0
def plotHistogram(fileName,
                  initData,
                  stations,
                  dateRange,
                  bokehPlaceholderId='bokehContent'):
    """Build the interactive histogram page and return it as HTML.

    Args:
        fileName: HTML template path, read with ``readHtmlFile``.
        initData: mapping with ``'bins'`` and ``'values'`` for the initial line.
        stations: station names for the start/end selectors. The caller's
            sequence is not modified.
        dateRange: ``(min_date, max_date)`` for the date pickers.
        bokehPlaceholderId: id of the ``div`` in the template that receives
            the Bokeh layout.

    Returns:
        The template HTML with the Bokeh script inserted into its header and
        the layout div passed to ``appendElementContent`` for the placeholder.
    """
    # 'ss' and 'es' are placeholder values for testing; values for the other
    # controls (age, user type, gender) coming from initData are still to be added.
    # NOTE(review): 'ss'/'es' have 2 rows while 'xs'/'ys' have 1; Bokeh expects
    # equal-length columns — confirm this is intended.
    data = {
        'xs': [initData['bins']],
        'ys': [initData['values']],
        'ss': [1, 2],
        'es': [3, 4]
    }

    source = ColumnDataSource(data=data)
    # Build a fresh options list: inserting into `stations` mutated the caller's
    # list and added another "All" on every call.
    station_options = ["All"] + list(stations)
    selectSS = Select(title="Start Station:", value="All", options=station_options)
    selectES = Select(title="End Station:", value="All", options=station_options)

    selectUT = Select(title="User Type:",
                      value="All",
                      options=["All", "Subscriber", "Customer"])
    selectGender = Select(title="Gender:",
                          value="All",
                          options=["All", "Male", "Female"])
    sliderAge = Slider(start=8, end=100, value=30, step=5, title="Age")

    startDP = DatePicker(title="Start Date:",
                         min_date=dateRange[0],
                         max_date=dateRange[1],
                         value=dateRange[0])
    endDP = DatePicker(title="End Date:",
                       min_date=dateRange[0],
                       max_date=dateRange[1],
                       value=dateRange[1])
    binSize = TextInput(value="15", title="Bin Size (Days):")
    AddButton = Toggle(label="Add", type="success")
    DeleteButton = Toggle(label="delete", type="success")

    # TODO: add columns for the values of the other controls.
    columns = [
        TableColumn(field="ss", title="Start Station"),
        TableColumn(field="es", title="End Station")
    ]
    data_table = DataTable(source=source,
                           columns=columns,
                           width=650,
                           height=300)

    # Objects exposed by name to the JavaScript callback below.
    model = dict(source=source,
                 selectSS=selectSS,
                 selectES=selectES,
                 startDP=startDP,
                 endDP=endDP,
                 binSize=binSize,
                 selectUT=selectUT,
                 selectGender=selectGender,
                 sliderAge=sliderAge)
    plot = Figure(plot_width=650, plot_height=400, x_axis_type="datetime")
    # NOTE(review): 'width' and 'color' are not columns of `source` — confirm.
    plot.multi_line('xs',
                    'ys',
                    source=source,
                    line_width='width',
                    line_alpha=0.6,
                    line_color='color')

    # Add-button handler (runs in the browser): reads the current control
    # values, GETs /histogram with them, and appends the returned series to
    # `source`.
    callback = CustomJS(args=model,
                        code="""
            //alert("callback");
            var startStation = selectSS.get('value');
            var endStation = selectES.get('value');
            var startDate = startDP.get('value');
            
            if ( typeof(startDate) !== "number")
                startDate = startDate.getTime();
                
            var endDate = endDP.get('value');
            
            if ( typeof(endDate) !== "number")
                endDate = endDate.getTime();            
            
            var binSize = binSize.get('value');
            //alert(startStation + " " + endStation + " " + startDate + " " + endDate + " " + binSize);
            var xmlhttp;
            xmlhttp = new XMLHttpRequest();
            
            xmlhttp.onreadystatechange = function() {
                if (xmlhttp.readyState == XMLHttpRequest.DONE ) {
                    if(xmlhttp.status == 200){
                        var data = source.get('data');
                        var result = JSON.parse(xmlhttp.responseText);
                        var temp=[];
                        
                        for(var date in result.x) {
                            temp.push(new Date(result.x[date]));
                        }
                        
                        data['xs'].push(temp);
                        data['ys'].push(result.y);
                        source.trigger('change');
                    }
                    else if(xmlhttp.status == 400) {
                        alert(400);
                    }
                    else {
                        alert(xmlhttp.status);
                    }
                }
            };
        var params = {ss:startStation, es:endStation, sd:startDate, ed:endDate, bs: binSize};
        url = "/histogram?" + jQuery.param( params );
        xmlhttp.open("GET", url, true);
        xmlhttp.send();
        """)

    AddButton.callback = callback
    # TODO: wire up DeleteButton (previously: DeleteButton.on_click(callback1)).
    layout1 = vform(startDP, endDP, binSize)
    layout2 = vform(plot, DeleteButton, data_table)
    layout3 = vform(selectSS, selectES, selectUT, selectGender, sliderAge,
                    AddButton)
    layout = hplot(layout1, layout2, layout3)
    script, div = components(layout)
    html = readHtmlFile(fileName)
    html = insertScriptIntoHeader(html, script)
    # Honour the caller's placeholder id (it was previously ignored).
    html = appendElementContent(html, div, "div", bokehPlaceholderId)

    return html
Exemple #9
0
def _jwst_resimulate_spectrum(spec, wave, R, num_tran):
    """Redraw one pandexo spectrum after optional rebinning and/or transit scaling.

    Parameters
    ----------
    spec : dict
        A single pandexo result (its 'timing', 'RawData' and 'input' entries are used).
    wave : array
        Wavelength axis with the NaN points of the final spectrum removed.
    R : float or False
        Resolution passed to ``bin_wave_to_R``; False (or 0) skips rebinning.
    num_tran : float or False
        Number of transits to scale raw counts/variances to; False (or 0) keeps
        the simulated number of transits.

    Returns
    -------
    new_wave, sim_spec, error : arrays
        Output axis, noisy simulated spectrum and its 1-sigma error.
    """
    timing = spec['timing']
    raw = spec['RawData']
    ntran_old = timing['Number of Transits']
    to = timing["Num Integrations Out of Transit"]
    ti = timing["Num Integrations In Transit"]

    rebin = R != False
    rescale = num_tran != False
    new_wave = bin_wave_to_R(wave, R) if rebin else wave

    # NOTE(review): RawData arrays are not NaN-masked like `wave`; this assumes
    # their grid lines up with the masked wave axis — confirm.
    def prepare(key):
        values = raw[key]
        if rescale:
            values = values*num_tran/ntran_old
        if rebin:
            values = uniform_tophat_sum(new_wave, wave, values)
        return values

    out = prepare('electrons_out')
    inn = prepare('electrons_in')
    vout = prepare('var_out')
    vin = prepare('var_in')
    var_tot = (to/ti/out)**2.0 * vin + (inn*to/ti/out**2.0)**2.0 * vout

    # spectra labelled 'fp/f*' are sign-flipped
    fac = -1.0 if spec['input']['Primary/Secondary'] == 'fp/f*' else 1.0
    rand_noise = np.sqrt(var_tot)*np.random.randn(len(new_wave))
    raw_spec = (out/to - inn/ti)/(out/to)
    return new_wave, fac*(raw_spec + rand_noise), np.sqrt(var_tot)


def jwst_1d_spec(result_dict, model=True, title='Model + Data + Error Bars', output_file = 'data.html',legend = False, 
        R=False,  num_tran = False, plot_width=800, plot_height=400,x_range=[1,10]):
    """Plots 1d simulated spectrum and rebin or rescale for more transits
    
    Plots 1d data points with model in the background (if wanted). Designed to read in exact 
    output of run_pandexo. 
    
    Parameters 
    ----------
    result_dict : dict or list of dict
        Dictionary from pandexo output. If parameter space was run in run_pandexo 
        make sure to restructure the input as a list of dictionaries without the keys 
        that run_pandexo assigns. Must not be empty.
    model : bool 
        (Optional) True is default. True plots model, False does not plot model     
    title : str
        (Optional) Title of plot. Default is "Model + Data + Error Bars".  
    output_file : str 
        (Optional) name of html file for your bokeh plot. After bokeh plot is rendered you will 
        have the option to save as png. 
    legend : bool, str or list of str
        (Optional) Default is False (no legend). True plots a legend with default labels
        ("Spectrum 1", "Spectrum 2", ...); a string or list of strings gives the label of
        each input dictionary in order.
    R : float 
        (Optional) Rebin data from native instrument resolution to specified resolution. Default is False, 
        no binning. 
    num_tran : float
        (Optional) Scales data by number of transits to improve error by sqrt(`num_trans`)
    plot_width : int 
        (Optional) Sets the width of the plot. Default = 800
    plot_height : int 
        (Optional) Sets the height of the plot. Default = 400 
    x_range : list of int
        (Optional) Sets x range of plot. Default = [1,10]. Ignored for 'phase_spec'
        calculations, which use the data range.

    Returns
    -------
    x,y,e : list of arrays 
        Returns wave axis, spectrum and associated error in list format. x[0] will correspond 
        to the first dictionary input, x[1] to the second, etc. 

    Raises
    ------
    ValueError
        If ``result_dict`` is an empty list.
        
    Examples
    --------
    
    >>> jwst_1d_spec(result_dict, num_tran = 3, R = 35) #for a single plot 
    
    If you wanted to save each of the axis that were being plotted: 
    
    >>> x,y,e = jwst_1d_spec([result_dict1, result_dict2], model=False, num_tran = 5, R = 100) #for multiple 
    
    See Also
    --------
    jwst_noise, jwst_1d_bkg, jwst_1d_flux, jwst_1d_snr, jwst_2d_det, jwst_2d_sat

    """
    outx = []
    outy = []
    oute = []
    TOOLS = "pan,wheel_zoom,box_zoom,resize,reset,save"
    # make sure it's iterable
    if not isinstance(result_dict, list):
        result_dict = [result_dict]
    if not result_dict:
        raise ValueError("result_dict must contain at least one pandexo result")
    outputfile(output_file)
    colors = ['black','blue','red','orange','yellow','purple','pink','cyan','grey','brown']

    legend_keys = None
    if not isinstance(legend, bool):
        legend_keys = legend if isinstance(legend, list) else [legend]
        legend = True

    resimulate = (R != False) or (num_tran != False)
    for i, spec in enumerate(result_dict):
        # remove any NaNs from the spectrum, and the matching wave/error points
        y = spec['FinalSpectrum']['spectrum_w_rand']
        good = ~np.isnan(y)
        x = spec['FinalSpectrum']['wave'][good]
        err = spec['FinalSpectrum']['error_w_floor'][good]
        y = y[good]

        if resimulate:
            x, y, err = _jwst_resimulate_spectrum(spec, x, R, num_tran)

        # vertical error-bar segments for Bokeh's multi_line
        x_err = [(px, px) for px, _, _ in zip(x, y, err)]
        y_err = [(py - e, py + e) for _, py, e in zip(x, y, err)]

        # initialize the figure from the first result
        if i == 0:
            y_axis_label = spec['input']['Primary/Secondary']
            if y_axis_label != 'fp/f*':
                y_axis_label = '('+y_axis_label+')^2'

            if spec['input']['Calculation Type'] == 'phase_spec':
                x_axis_label = 'Time (secs)'
                x_range = [min(x), max(x)]
            else:
                x_axis_label = 'Wavelength [microns]'

            model_spec = spec['OriginalInput']['model_spec']
            # pad the model's range by 10% on each side
            ylims = [min(model_spec) - 0.1*min(model_spec),
                     0.1*max(model_spec) + max(model_spec)]

            fig1d = Figure(x_range=x_range, y_range=ylims,
                           plot_width=plot_width, plot_height=plot_height, title=title,
                           x_axis_label=x_axis_label, y_axis_label=y_axis_label,
                           tools=TOOLS, background_fill_color='white')

        # cycle colours so more than len(colors) spectra do not raise IndexError
        color = colors[i % len(colors)]

        # plot model, data, and errors
        if model:
            mxx = spec['OriginalInput']['model_wave']
            myy = spec['OriginalInput']['model_spec']
            my = uniform_tophat_mean(x, mxx, myy)
            fig1d.line(x, my, color='black', alpha=0.2, line_width=4)
        if legend:
            if legend_keys is not None and i < len(legend_keys):
                label = legend_keys[i]
            else:
                label = 'Spectrum {}'.format(i + 1)
            fig1d.circle(x, y, color=color, legend=label)
        else:
            fig1d.circle(x, y, color=color)
        outx += [x]
        outy += [y]
        oute += [err]
        fig1d.multi_line(x_err, y_err, color=color)
    show(fig1d)
    return outx, outy, oute
Exemple #10
0
             legend='(x0,y0)')
# Plot the streamline ('x'/'y' columns of source_streamline) in black.
# NOTE(review): `legend=` is deprecated in newer Bokeh releases in favour of
# `legend_label`; confirm which Bokeh version this script targets.
plot.line('x',
          'y',
          source=source_streamline,
          color='black',
          legend='streamline')
# Plot critical points (red markers) and critical lines (red multi-line).
plot.scatter('x',
             'y',
             source=source_critical_pts,
             color='red',
             legend='critical pts')
plot.multi_line('x_ls',
                'y_ls',
                source=source_critical_lines,
                color='red',
                legend='critical lines')

# Initialize controls.
# Text inputs for the ODE system [u, v] = [x', y'], pre-filled with the u and v
# components of the sample system selected by odesystem_settings.init_fun_key.
u_input = TextInput(value=odesystem_settings.sample_system_functions[
    odesystem_settings.init_fun_key][0],
                    title="u(x,y):")
v_input = TextInput(value=odesystem_settings.sample_system_functions[
    odesystem_settings.init_fun_key][1],
                    title="v(x,y):")

# dropdown menu for selecting one of the sample functions
sample_fun_input = Dropdown(
    label="choose a sample function pair or enter one below",
Exemple #11
0
def hst_spec(result_dict, plot=True, output_file='hstspec.html', model=True):
    """Plot 1d spec with error bars for hst

    Parameters
    ----------
    result_dict : dict
        Dictionary from pandexo output.

    plot : bool
        (Optional) True renders plot, False does not. Default=True
    model : bool
        (Optional) Plot model under data. Default=True
    output_file : str
        (Optional) Default = 'hstspec.html'

    Return
    ------
    x : numpy array
        micron
    y : numpy array
        1D spec fp/f* or rp^2/r*^2
    e : numpy array
        1D rms noise
    modelx : numpy array
        micron
    modely : numpy array
        1D spec fp/f* or rp^2/r*^2
    See Also
    --------
    hst_time
    """
    TOOLS = "pan,wheel_zoom,box_zoom,reset,save"
    planet = result_dict['planet_spec']

    # planet spectrum: model and binned observation
    mwave = planet['model_wave']
    mspec = planet['model_spec']
    binwave = planet['binwave']
    binspec = planet['binspec']

    # one error value per bin (broadcasts a scalar error)
    error = np.zeros(len(binspec)) + planet['error']
    xlims = [planet['wmin'], planet['wmax']]
    # Pad by two error bars using each bin's own error, so per-bin errors stay
    # in view. For a scalar error this equals the old min/max -/+ 2*error[0].
    ylims = [
        np.min(binspec - 2.0 * error),
        np.max(binspec + 2.0 * error)
    ]

    plot_spectrum = Figure(
        plot_width=800,
        plot_height=300,
        x_range=xlims,
        y_range=ylims,
        tools=TOOLS,  #responsive=True,
        x_axis_label='Wavelength [microns]',
        y_axis_label='Ratio',
        title="Original Model with Observation")

    # vertical error-bar segments for multi_line
    x_err = [(px, px) for px, _, _ in zip(binwave, binspec, error)]
    y_err = [(py - e, py + e) for _, py, e in zip(binwave, binspec, error)]

    if model:
        plot_spectrum.line(mwave,
                           mspec,
                           color="black",
                           alpha=0.5,
                           line_width=4)
    plot_spectrum.circle(binwave, binspec, line_width=3, line_alpha=0.6)
    plot_spectrum.multi_line(x_err, y_err)

    if plot:
        outputfile(output_file)
        show(plot_spectrum)

    return binwave, binspec, error, mwave, mspec
Exemple #12
0
def hst_simulated_lightcurve(result_dict,
                             plot=True,
                             output_file='hsttime.html',
                             model=True):
    """Plot simulated HST light curves (in fluence) for earliest and latest start times

    Parameters
    ----------
    result_dict : dict
        Dictionary from pandexo output.

    plot : bool
        (Optional) True renders plot, False does not. Default=True
    model : bool
        (Optional) Plot model under data. Default=True
    output_file : str
        (Optional) Default = 'hsttime.html'

    Return
    ------
    obsphase1 : numpy array
        earliest start time
    counts1 : numpy array
        white light curve in fluence (e/pixel)
    obsphase2 : numpy array
        latest start time
    counts2 : numpy array
        white light curve in fluence (e/pixel)
    rms : numpy array
        1D rms noise

    See Also
    --------
    hst_spec
    """
    TOOLS = "pan,wheel_zoom,box_zoom,reset,save"
    lc = result_dict['light_curve']
    # earliest and latest start times
    obsphase1 = lc['obsphase1']
    rms = lc['light_curve_rms']
    obsphase2 = lc['obsphase2']
    phase1 = lc['phase1']
    phase2 = lc['phase2']
    counts1 = lc['counts1']
    counts2 = lc['counts2']
    count_noise = lc['count_noise']
    ramp_included = lc['ramp_included']
    model_counts1 = lc['model_counts1']
    model_counts2 = lc['model_counts2']

    # a single float noise value is broadcast to one value per point
    if isinstance(count_noise, float):
        rms = np.zeros(len(counts1)) + count_noise

    def error_bars(xs, ys, errs):
        # vertical +/- err segments at each point, in multi_line format
        triples = list(zip(xs, ys, errs))
        return ([(px, px) for px, _, _ in triples],
                [(py - e, py + e) for _, py, e in triples])

    title_description = " (Ramp Included)" if ramp_included else " (Ramp Removed)"

    def light_curve_figure(title, phase, model_counts, obsphase, counts):
        # both panels share the same styling, including the model line width
        fig = Figure(
            plot_width=400,
            plot_height=300,
            tools=TOOLS,  # responsive=True,
            x_axis_label='Orbital Phase',
            y_axis_label='Flux [electrons/pixel]',
            title=title + title_description)
        if model:
            fig.line(phase,
                     model_counts,
                     color='black',
                     alpha=0.5,
                     line_width=4)
        fig.circle(obsphase, counts, line_width=3, line_alpha=0.6)
        fig.multi_line(*error_bars(obsphase, counts, rms))
        return fig

    early = light_curve_figure("Earliest Start Time", phase1, model_counts1,
                               obsphase1, counts1)
    late = light_curve_figure("Latest Start Time", phase2, model_counts2,
                              obsphase2, counts2)

    start_time = row(early, late)

    if plot:
        outputfile(output_file)
        show(start_time)

    return obsphase1, counts1, obsphase2, counts2, rms
Exemple #13
0
def plotHistogram(fileName, initData, stations, dateRange, bokehPlaceholderId='bokehContent'):
    """Build the interactive trip-histogram page and return it as an HTML string.

    Reads the HTML template ``fileName``, inserts the Bokeh ``<script>`` into its
    header and the Bokeh ``<div>`` into the ``div`` element identified by
    ``bokehPlaceholderId`` (via ``appendElementContent``).

    Args:
        fileName: path of the HTML template passed to ``readHtmlFile``.
        initData: mapping with ``'bins'`` (x values) and ``'values'`` (y values)
            used as the first line of the histogram plot.
        stations: station names offered in the start/end station selectors.
            The list is not modified; "All" is prepended to a copy.
        dateRange: ``(min_date, max_date)`` pair bounding both date pickers;
            the pickers start at the two ends of the range.
        bokehPlaceholderId: identifier of the target ``div`` that receives the
            Bokeh content (defaults to ``'bokehContent'``).

    Returns:
        The template HTML with the Bokeh script and div inserted.
    """
    # 'ss'/'es' are placeholder test values for the table; they are meant to be
    # replaced by the values of the other controllers (age, user type, gender)
    # fetched from initData.
    # NOTE(review): 'xs'/'ys' hold one row while 'ss'/'es' hold two, and the
    # multi_line below reads 'width'/'color' columns this source does not
    # define -- confirm how the targeted Bokeh version handles this.
    data = {'xs': [initData['bins']], 'ys': [initData['values']], 'ss': [1, 2], 'es': [3, 4]}
    source = ColumnDataSource(data=data)

    # Build the options on a copy: inserting into `stations` in place would
    # modify the caller's list and add another "All" on every call.
    station_options = ["All"] + list(stations)
    selectSS = Select(title="Start Station:", value="All", options=station_options)
    selectES = Select(title="End Station:", value="All", options=station_options)

    selectUT = Select(title="User Type:", value="All", options=["All", "Subscriber", "Customer"])
    selectGender = Select(title="Gender:", value="All", options=["All", "Male", "Female"])
    sliderAge = Slider(start=8, end=100, value=30, step=5, title="Age")

    startDP = DatePicker(title="Start Date:", min_date=dateRange[0], max_date=dateRange[1], value=dateRange[0])
    endDP = DatePicker(title="End Date:", min_date=dateRange[0], max_date=dateRange[1], value=dateRange[1])
    binSize = TextInput(value="15", title="Bin Size (Days):")
    AddButton = Toggle(label="Add", type="success")
    # TODO: DeleteButton has no callback wired up yet.
    DeleteButton = Toggle(label="delete", type="success")

    # TODO: add columns holding the values of the other controllers.
    columns = [TableColumn(field="ss", title="Start Station"), TableColumn(field="es", title="End Station")]
    data_table = DataTable(source=source, columns=columns, width=650, height=300)

    # Models exposed to the JS callback under these variable names.
    model = dict(source=source, selectSS=selectSS, selectES=selectES, startDP=startDP, endDP=endDP,
                 binSize=binSize, selectUT=selectUT, selectGender=selectGender, sliderAge=sliderAge)
    plot = Figure(plot_width=650, plot_height=400, x_axis_type="datetime")
    plot.multi_line('xs', 'ys', source=source, line_width='width', line_alpha=0.6, line_color='color')

    # On "Add": send the current filter values to /histogram (GET) and append
    # the returned series as a new line (x values converted to JS Dates).
    # jQuery.param requires jQuery to be loaded by the HTML template.
    callback = CustomJS(args=model, code="""
            var startStation = selectSS.get('value');
            var endStation = selectES.get('value');
            var startDate = startDP.get('value');

            if ( typeof(startDate) !== "number")
                startDate = startDate.getTime();

            var endDate = endDP.get('value');

            if ( typeof(endDate) !== "number")
                endDate = endDate.getTime();

            var binSizeValue = binSize.get('value');
            var xmlhttp;
            xmlhttp = new XMLHttpRequest();

            xmlhttp.onreadystatechange = function() {
                if (xmlhttp.readyState == XMLHttpRequest.DONE ) {
                    if(xmlhttp.status == 200){
                        var data = source.get('data');
                        var result = JSON.parse(xmlhttp.responseText);
                        var temp=[];

                        for(var date in result.x) {
                            temp.push(new Date(result.x[date]));
                        }

                        data['xs'].push(temp);
                        data['ys'].push(result.y);
                        source.trigger('change');
                    }
                    else if(xmlhttp.status == 400) {
                        alert(400);
                    }
                    else {
                        alert(xmlhttp.status);
                    }
                }
            };
        var params = {ss:startStation, es:endStation, sd:startDate, ed:endDate, bs: binSizeValue};
        var url = "/histogram?" + jQuery.param( params );
        xmlhttp.open("GET", url, true);
        xmlhttp.send();
        """)

    AddButton.callback = callback
    layout1 = vform(startDP, endDP, binSize)
    layout2 = vform(plot, DeleteButton, data_table)
    layout3 = vform(selectSS, selectES, selectUT, selectGender, sliderAge, AddButton)
    layout = hplot(layout1, layout2, layout3)
    script, div = components(layout)
    html = readHtmlFile(fileName)
    html = insertScriptIntoHeader(html, script)
    # Honour the caller-supplied placeholder id (previously hard-coded).
    html = appendElementContent(html, div, "div", bokehPlaceholderId)

    return html
def _dispersion_draw_boundaries(map_view, boundary_data):
    '''
    Draw the country, marine, shoreline and state boundaries on map_view.

    Input:
        map_view is the bokeh figure to draw on
        boundary_data maps 'country', 'marine', 'shoreline' and 'state' to
            dicts holding 'longitude' and 'latitude' multi-line coordinates

    Lines are black, drawn on the 'underlay' level, and keep their look
    (alpha 1.0, black) when a map point is selected.
    '''
    # Drawing order and widths follow the original layout: country and
    # shoreline use line_width=2, marine and state use the renderer default.
    layers = (
        ('country', {'line_width': 2}),
        ('marine', {}),
        ('shoreline', {'line_width': 2}),
        ('state', {}),
    )
    for name, width_kwargs in layers:
        boundary = boundary_data[name]
        map_view.multi_line(boundary['longitude'], boundary['latitude'], color='black',
                            level='underlay', nonselection_line_alpha=1.0,
                            nonselection_line_color='black', **width_kwargs)


def _dispersion_style_figure(fig, style_parameter, xlabel, ylabel):
    '''
    Apply the title, axis-label and toolbar styling shared by the map view
    and the dispersion-curve figure.
    '''
    label_fontsize = style_parameter['xlabel_fontsize']
    fig.title.text_font_size = style_parameter['title_font_size']
    fig.title.align = 'center'
    fig.title.text_font_style = 'normal'
    for axis, label in ((fig.xaxis, xlabel), (fig.yaxis, ylabel)):
        axis.axis_label = label
        axis.axis_label_text_font_style = 'normal'
        axis.axis_label_text_font_size = label_fontsize
        axis.major_label_text_font_size = label_fontsize
    fig.toolbar.logo = None
    fig.toolbar_location = 'above'
    fig.toolbar_sticky = False


def plot_dispersion_bokeh(filename, period_array, curve_data_array, boundary_data, style_parameter):
    '''
    Plot dispersion maps and curves using bokeh and save them as an html file.

    The left column shows a map of all curve locations coloured by velocity
    at one period (chosen with a slider) plus its colorbar; the right column
    shows the dispersion curve of one location, chosen by tapping a dot on
    the map or with a second slider.

    Input:
        filename is the filename of the resulting html file
        period_array is a list of periods, one map slice per period
        curve_data_array is a list of dispersion curves; each is a dict with
            'latitude', 'longitude', 'period' (list) and 'velocity' (list)
        boundary_data maps 'country', 'marine', 'shoreline' and 'state' to
            dicts with 'longitude' and 'latitude' line coordinates
        style_parameter contains plotting parameters

    Output:
        None (the figure is written to filename)
    '''
    xlabel_fontsize = style_parameter['xlabel_fontsize']
    default_map_index = style_parameter['map_view_default_index']
    curve_default_index = style_parameter['curve_default_index']
    nperiod = len(period_array)
    ncurve = len(curve_data_array)
    ncolor = len(palette)
    palette_r = palette[::-1]
    # the colorbar is one row of ncolor quads spanning y in [0, 0.1]
    colorbar_top = [0.1] * ncolor
    colorbar_bottom = [0] * ncolor
    # ==============================
    # per-period map slices: dot colours, period labels and colorbar bins
    map_data_all_slices_period = []
    map_data_all_slices_color = []
    colorbar_data_all_left = []
    colorbar_data_all_right = []
    for map_period in period_array:
        # velocity of every curve at this period; curves that lack the
        # period get style_parameter['nan_value'] instead
        one_slice_vel_list = []
        for acurve in curve_data_array:
            curve_period = acurve['period']
            if map_period in curve_period:
                one_slice_vel_list.append(acurve['velocity'][curve_period.index(map_period)])
            else:
                one_slice_vel_list.append(style_parameter['nan_value'])
        # colour range: mean +/- spread_factor standard deviations (NaNs ignored)
        # NOTE(review): only NaN entries are skipped by nanmean/nanstd; a
        # non-NaN nan_value sentinel would skew the range -- confirm.
        one_slice_vel_mean = np.nanmean(one_slice_vel_list)
        one_slice_vel_std = np.nanstd(one_slice_vel_list)
        color_min = one_slice_vel_mean - one_slice_vel_std * style_parameter['spread_factor']
        color_max = one_slice_vel_mean + one_slice_vel_std * style_parameter['spread_factor']
        color_step = (color_max - color_min) * 1. / ncolor
        one_slice_color_list = get_color_list(one_slice_vel_list, color_min, color_max, palette_r,
                                              style_parameter['nan_value'], style_parameter['nan_color'])
        # left/right edges of the ncolor colorbar bins
        colorbar_left = np.linspace(color_min, color_max - color_step, ncolor)
        colorbar_right = np.linspace(color_min + color_step, color_max, ncolor)
        if ncurve > 0:
            map_data_all_slices_period.append('Period: {0:6.1f} s'.format(map_period))
            map_data_all_slices_color.append(one_slice_color_list)
            colorbar_data_all_left.append(colorbar_left)
            colorbar_data_all_right.append(colorbar_right)
    # ==============================
    # location and labels of every curve
    map_lat_list = [acurve['latitude'] for acurve in curve_data_array]
    map_lon_list = [acurve['longitude'] for acurve in curve_data_array]
    map_lat_label_list = ['Lat: {0:12.3f}'.format(lat) for lat in map_lat_list]
    map_lon_label_list = ['Lon: {0:12.3f}'.format(lon) for lon in map_lon_list]
    # data for the map view plot
    map_data_one_slice = map_data_all_slices_color[default_map_index]
    map_data_one_slice_period = map_data_all_slices_period[default_map_index]
    map_data_one_slice_bokeh = ColumnDataSource(data=dict(map_lat_list=map_lat_list,
                                                          map_lon_list=map_lon_list,
                                                          map_data_one_slice=map_data_one_slice))
    map_data_one_slice_period_bokeh = ColumnDataSource(data=dict(lat=[style_parameter['map_view_period_label_lat']],
                                                                 lon=[style_parameter['map_view_period_label_lon']],
                                                                 map_period=[map_data_one_slice_period]))
    map_data_all_slices_bokeh = ColumnDataSource(data=dict(map_data_all_slices_color=map_data_all_slices_color,
                                                           map_data_all_slices_period=map_data_all_slices_period))
    # data for the colorbar
    colorbar_data_one_slice_bokeh = ColumnDataSource(data=dict(colorbar_top=colorbar_top,
                                                               colorbar_bottom=colorbar_bottom,
                                                               colorbar_left=colorbar_data_all_left[default_map_index],
                                                               colorbar_right=colorbar_data_all_right[default_map_index],
                                                               palette_r=palette_r))
    colorbar_data_all_slices_bokeh = ColumnDataSource(data=dict(colorbar_data_all_left=colorbar_data_all_left,
                                                                colorbar_data_all_right=colorbar_data_all_right))
    # data for dispersion curves
    selected_dot_on_map_bokeh = ColumnDataSource(data=dict(lat=[map_lat_list[curve_default_index]],
                                                           lon=[map_lon_list[curve_default_index]],
                                                           color=[map_data_one_slice[curve_default_index]],
                                                           index=[curve_default_index]))
    selected_curve_data = curve_data_array[curve_default_index]
    selected_curve_data_bokeh = ColumnDataSource(data=dict(curve_period=selected_curve_data['period'],
                                                           curve_velocity=selected_curve_data['velocity']))
    curve_data_all_bokeh = ColumnDataSource(data=dict(period_all=[acurve['period'] for acurve in curve_data_array],
                                                      velocity_all=[acurve['velocity'] for acurve in curve_data_array]))

    selected_curve_lat_label_bokeh = ColumnDataSource(data=dict(x=[style_parameter['curve_lat_label_x']],
                                                                y=[style_parameter['curve_lat_label_y']],
                                                                lat_label=[map_lat_label_list[curve_default_index]]))
    selected_curve_lon_label_bokeh = ColumnDataSource(data=dict(x=[style_parameter['curve_lon_label_x']],
                                                                y=[style_parameter['curve_lon_label_y']],
                                                                lon_label=[map_lon_label_list[curve_default_index]]))
    all_curve_lat_label_bokeh = ColumnDataSource(data=dict(map_lat_label_list=map_lat_label_list))
    all_curve_lon_label_bokeh = ColumnDataSource(data=dict(map_lon_label_list=map_lon_label_list))
    # ==============================
    # map view
    map_view = Figure(plot_width=style_parameter['map_view_plot_width'],
                      plot_height=style_parameter['map_view_plot_height'],
                      y_range=[style_parameter['map_view_lat_min'], style_parameter['map_view_lat_max']],
                      x_range=[style_parameter['map_view_lon_min'], style_parameter['map_view_lon_max']],
                      tools=style_parameter['map_view_tools'],
                      title=style_parameter['map_view_title'])
    _dispersion_draw_boundaries(map_view, boundary_data)
    # period label: white box plus the text of the current period
    map_view.rect(style_parameter['map_view_period_box_lon'], style_parameter['map_view_period_box_lat'],
                  width=style_parameter['map_view_period_box_width'], height=style_parameter['map_view_period_box_height'],
                  width_units='screen', height_units='screen', color='#FFFFFF', line_width=1., line_color='black', level='underlay')
    map_view.text('lon', 'lat', 'map_period', source=map_data_one_slice_period_bokeh,
                  text_font_size=style_parameter['annotating_text_font_size'], text_align='left', level='underlay')
    # one dot per curve; selection/nonselection styles equal the normal style
    # so tapping a dot does not dim the others
    map_view.circle('map_lon_list', 'map_lat_list', color='map_data_one_slice',
                    source=map_data_one_slice_bokeh, size=style_parameter['marker_size'],
                    line_width=0.2, line_color='black', alpha=1.0,
                    selection_color='map_data_one_slice', selection_line_color='black',
                    selection_fill_alpha=1.0,
                    nonselection_fill_alpha=1.0, nonselection_fill_color='map_data_one_slice',
                    nonselection_line_color='black', nonselection_line_alpha=1.0)
    # green-outlined marker for the currently selected curve
    map_view.circle('lon', 'lat', color='color', source=selected_dot_on_map_bokeh,
                    line_color='#00ff00', line_width=4.0, alpha=1.0,
                    size=style_parameter['selected_marker_size'])
    _dispersion_style_figure(map_view, style_parameter,
                             style_parameter['map_view_xlabel'], style_parameter['map_view_ylabel'])
    map_view.xgrid.grid_line_color = None
    map_view.ygrid.grid_line_color = None
    # ==============================
    # colorbar
    colorbar_fig = Figure(tools=[], y_range=(0, 0.1), plot_width=style_parameter['map_view_plot_width'],
                          plot_height=style_parameter['colorbar_plot_height'], title=style_parameter['colorbar_title'])
    colorbar_fig.toolbar_location = None
    colorbar_fig.quad(top='colorbar_top', bottom='colorbar_bottom', left='colorbar_left', right='colorbar_right',
                      fill_color='palette_r', source=colorbar_data_one_slice_bokeh)
    colorbar_fig.yaxis[0].ticker = FixedTicker(ticks=[])
    colorbar_fig.xgrid.grid_line_color = None
    colorbar_fig.ygrid.grid_line_color = None
    colorbar_fig.xaxis.axis_label_text_font_size = xlabel_fontsize
    colorbar_fig.xaxis.major_label_text_font_size = xlabel_fontsize
    colorbar_fig.xaxis[0].formatter = PrintfTickFormatter(format="%5.2f")
    colorbar_fig.title.text_font_size = xlabel_fontsize
    colorbar_fig.title.align = 'center'
    colorbar_fig.title.text_font_style = 'normal'
    # ==============================
    # dispersion curve of the selected location
    curve_fig = Figure(plot_width=style_parameter['curve_plot_width'], plot_height=style_parameter['curve_plot_height'],
                       y_range=(style_parameter['curve_y_min'], style_parameter['curve_y_max']),
                       x_range=(style_parameter['curve_x_min'], style_parameter['curve_x_max']), x_axis_type='log',
                       tools=['save', 'box_zoom', 'wheel_zoom', 'reset', 'crosshair', 'pan'],
                       title=style_parameter['curve_title'])
    # lat/lon label box
    curve_fig.rect([style_parameter['curve_label_box_x']], [style_parameter['curve_label_box_y']],
                   width=style_parameter['curve_label_box_width'], height=style_parameter['curve_label_box_height'],
                   width_units='screen', height_units='screen', color='#FFFFFF', line_width=1., line_color='black', level='underlay')
    curve_fig.text('x', 'y', 'lat_label', source=selected_curve_lat_label_bokeh)
    curve_fig.text('x', 'y', 'lon_label', source=selected_curve_lon_label_bokeh)
    curve_fig.line('curve_period', 'curve_velocity', source=selected_curve_data_bokeh, color='black')
    curve_fig.circle('curve_period', 'curve_velocity', source=selected_curve_data_bokeh, size=5, color='black')
    _dispersion_style_figure(curve_fig, style_parameter,
                             style_parameter['curve_xlabel'], style_parameter['curve_ylabel'])
    curve_fig.xgrid.grid_line_dash = [4, 2]
    curve_fig.ygrid.grid_line_dash = [4, 2]
    curve_fig.xaxis[0].formatter = PrintfTickFormatter(format="%4.0f")
    # ==============================
    # JS shared by the map tap and the curve slider. It expects `idx` (the
    # curve index) to be defined and moves the highlighted dot, the plotted
    # curve and the lat/lon labels to that curve.
    select_curve_js = """
    var one_slice = map_data_one_slice_bokeh.data

    selected_dot_on_map_bokeh.data['index'] = [idx]
    selected_dot_on_map_bokeh.data['lat'] = [one_slice['map_lat_list'][idx]]
    selected_dot_on_map_bokeh.data['lon'] = [one_slice['map_lon_list'][idx]]
    selected_dot_on_map_bokeh.data['color'] = [one_slice['map_data_one_slice'][idx]]
    selected_dot_on_map_bokeh.change.emit()

    selected_curve_data_bokeh.data['curve_period'] = curve_data_all_bokeh.data['period_all'][idx]
    selected_curve_data_bokeh.data['curve_velocity'] = curve_data_all_bokeh.data['velocity_all'][idx]
    selected_curve_data_bokeh.change.emit()

    var all_lat_labels = all_curve_lat_label_bokeh.data['map_lat_label_list']
    var all_lon_labels = all_curve_lon_label_bokeh.data['map_lon_label_list']
    selected_curve_lat_label_bokeh.data['lat_label'] = [all_lat_labels[idx]]
    selected_curve_lon_label_bokeh.data['lon_label'] = [all_lon_labels[idx]]
    selected_curve_lat_label_bokeh.change.emit()
    selected_curve_lon_label_bokeh.change.emit()
    """
    curve_callback_args = dict(selected_dot_on_map_bokeh=selected_dot_on_map_bokeh,
                               map_data_one_slice_bokeh=map_data_one_slice_bokeh,
                               selected_curve_data_bokeh=selected_curve_data_bokeh,
                               curve_data_all_bokeh=curve_data_all_bokeh,
                               selected_curve_lat_label_bokeh=selected_curve_lat_label_bokeh,
                               selected_curve_lon_label_bokeh=selected_curve_lon_label_bokeh,
                               all_curve_lat_label_bokeh=all_curve_lat_label_bokeh,
                               all_curve_lon_label_bokeh=all_curve_lon_label_bokeh)
    # Selecting a dot on the map. An empty selection (e.g. a click on blank
    # map) keeps the current curve instead of jumping to curve 0, and with
    # several selected dots the first one is shown (Math.round on the whole
    # index array used to yield 0 or NaN in those cases).
    map_data_one_slice_bokeh.callback = CustomJS(args=dict(curve_callback_args), code="""
    var indices = cb_obj.selected['1d'].indices
    if (indices.length == 0) {
        return
    }
    var idx = indices[0]
    """ + select_curve_js)
    # ==============================
    # period slider: swap in the colours, colorbar bins and label of the
    # chosen period, and recolour the highlighted dot
    period_slider_callback = CustomJS(args=dict(map_data_all_slices_bokeh=map_data_all_slices_bokeh,
                                                map_data_one_slice_bokeh=map_data_one_slice_bokeh,
                                                colorbar_data_all_slices_bokeh=colorbar_data_all_slices_bokeh,
                                                colorbar_data_one_slice_bokeh=colorbar_data_one_slice_bokeh,
                                                selected_dot_on_map_bokeh=selected_dot_on_map_bokeh,
                                                map_data_one_slice_period_bokeh=map_data_one_slice_period_bokeh),
                                      code="""
    var p_index = Math.round(cb_obj.value)
    var map_data_all_slices = map_data_all_slices_bokeh.data


    var map_data_new_slice = map_data_all_slices['map_data_all_slices_color'][p_index]
    map_data_one_slice_bokeh.data['map_data_one_slice'] = map_data_new_slice
    map_data_one_slice_bokeh.change.emit()

    var color_data_all_slices = colorbar_data_all_slices_bokeh.data
    colorbar_data_one_slice_bokeh.data['colorbar_left'] = color_data_all_slices['colorbar_data_all_left'][p_index]
    colorbar_data_one_slice_bokeh.data['colorbar_right'] = color_data_all_slices['colorbar_data_all_right'][p_index]
    colorbar_data_one_slice_bokeh.change.emit()

    var selected_index = selected_dot_on_map_bokeh.data['index']
    selected_dot_on_map_bokeh.data['color'] = [map_data_new_slice[selected_index]]
    selected_dot_on_map_bokeh.change.emit()

    map_data_one_slice_period_bokeh.data['map_period'] = [map_data_all_slices['map_data_all_slices_period'][p_index]]
    map_data_one_slice_period_bokeh.change.emit()
    """)
    period_slider = Slider(start=0, end=nperiod - 1, value=default_map_index,
                           step=1, title=style_parameter['period_slider_title'],
                           width=style_parameter['period_slider_plot_width'],
                           height=50, callback=period_slider_callback)
    # ==============================
    # curve slider: select a curve by index
    curve_slider_callback = CustomJS(args=dict(curve_callback_args), code="""
    var idx = Math.round(cb_obj.value)
    """ + select_curve_js)
    curve_slider = Slider(start=0, end=ncurve - 1, value=curve_default_index,
                          step=1, title=style_parameter['curve_slider_title'], width=style_parameter['curve_plot_width'],
                          height=50, callback=curve_slider_callback)
    # ==============================
    # annotating text
    annotating_fig01 = Div(text=style_parameter['annotating_html01'],
                           width=style_parameter['annotation_plot_width'], height=style_parameter['annotation_plot_height'])
    annotating_fig02 = Div(text=style_parameter['annotating_html02'],
                           width=style_parameter['annotation_plot_width'], height=style_parameter['annotation_plot_height'])
    # ==============================
    output_file(filename, title=style_parameter['html_title'], mode=style_parameter['library_source'])
    left_fig = Column(period_slider, map_view, colorbar_fig, annotating_fig01,
                      width=style_parameter['left_column_width'])
    right_fig = Column(curve_slider, curve_fig, annotating_fig02,
                       width=style_parameter['right_column_width'])
    layout = Row(left_fig, right_fig)
    save(layout)
def plot_3DModel_bokeh(filename, map_data_all_slices, map_depth_all_slices,
                       color_range_all_slices, profile_data_all, boundary_data,
                       style_parameter):
    '''
    Plot shear velocity maps and velocity profiles using bokeh

    The page has two columns.

    Left column:
        - a depth slider
        - a map view of one depth slice, with country/marine/shoreline/state
          boundaries and the model grid
        - a colorbar

    Right column:
        - a profile slider
        - the Vs-depth profile of the selected grid point
        - two buttons that download that point's layers as a plain-text
          table or in model96 format

    Clicking a grid cell on the map, or moving the profile slider, selects
    the profile. All interactivity runs in the browser through CustomJS
    callbacks.

    Input:
        filename is the filename of the resulting html file
        map_data_all_slices contains the velocity model parameters saved for map view plots
        map_depth_all_slices is a list of depths
        color_range_all_slices is a list of (min, max) color ranges, one per slice
        profile_data_all contains the velocity model parameters saved for profile plots;
            each entry is read with the keys 'lat', 'lon', 'vs', 'vp', 'rho', 'top'
        boundary_data holds the boundaries keyed by 'country', 'marine',
            'shoreline' and 'state', each with 'longitude' and 'latitude' entries
        style_parameter contains plotting parameters

    Output:
        None; the page is written to filename
    '''
    xlabel_fontsize = style_parameter['xlabel_fontsize']
    map_view_default_index = style_parameter['map_view_default_index']
    profile_default_index = style_parameter['profile_default_index']
    # ==============================
    # colorbar boxes and depth labels for every map-view slice
    ncolor = len(palette)
    # the map images are colored with the reversed palette (val_to_rgb below),
    # so the colorbar boxes use the same reversed order
    palette_r = palette[::-1]
    colorbar_top = [0.1] * ncolor
    colorbar_bottom = [0] * ncolor
    colorbar_data_all_left = []
    colorbar_data_all_right = []
    map_data_all_slices_depth = []
    for idepth in range(style_parameter['map_view_ndepth']):
        color_min = color_range_all_slices[idepth][0]
        color_max = color_range_all_slices[idepth][1]
        # split [color_min, color_max] into ncolor equal-width boxes
        color_step = (color_max - color_min) * 1. / ncolor
        colorbar_data_all_left.append(
            np.linspace(color_min, color_max - color_step, ncolor))
        colorbar_data_all_right.append(
            np.linspace(color_min + color_step, color_max, ncolor))
        map_data_all_slices_depth.append(
            'Depth: {0:8.1f} km'.format(map_depth_all_slices[idepth]))
    colorbar_data_one_slice_bokeh = ColumnDataSource(data=dict(
        colorbar_top=colorbar_top,
        colorbar_bottom=colorbar_bottom,
        colorbar_left=colorbar_data_all_left[map_view_default_index],
        colorbar_right=colorbar_data_all_right[map_view_default_index],
        palette_r=palette_r))
    colorbar_data_all_slices_bokeh = ColumnDataSource(data=dict(
        colorbar_data_all_left=colorbar_data_all_left,
        colorbar_data_all_right=colorbar_data_all_right))
    # text label showing the depth of the displayed slice
    map_data_one_slice_depth_bokeh = ColumnDataSource(data=dict(
        lat=[style_parameter['map_view_depth_label_lat']],
        lon=[style_parameter['map_view_depth_label_lon']],
        map_depth=[map_data_all_slices_depth[map_view_default_index]],
        left=[style_parameter['profile_plot_xmin']],
        right=[style_parameter['profile_plot_xmax']]))
    # ------------------------------
    # RGBA image of every slice. The color array is viewed as one uint32 per
    # pixel, the layout image_rgba expects (assumes val_to_rgb returns an
    # (ny, nx, 4) uint8 array -- confirm against val_to_rgb).
    map_color_all_slices = []
    for islice, slice_data in enumerate(map_data_all_slices):
        vmin, vmax = color_range_all_slices[islice]
        map_color = val_to_rgb(slice_data, palette_r, vmin, vmax)
        map_color_all_slices.append(
            map_color.view('uint32').reshape(map_color.shape[:2]))
    map_data_one_slice_bokeh = ColumnDataSource(data=dict(
        x=[style_parameter['map_view_image_lon_min']],
        y=[style_parameter['map_view_image_lat_min']],
        dw=[style_parameter['nlon'] * style_parameter['dlon']],
        dh=[style_parameter['nlat'] * style_parameter['dlat']],
        map_data_one_slice=[map_color_all_slices[map_view_default_index]]))
    map_data_all_slices_bokeh = ColumnDataSource(data=dict(
        map_data_all_slices=map_color_all_slices,
        map_data_all_slices_depth=map_data_all_slices_depth))
    # ==============================
    # per-profile data: grid location, staircase Vs-depth curve, lat/lon
    # labels, and the first button_ndepth layers used by the download buttons
    nprofile = len(profile_data_all)
    profile_ndepth = style_parameter['profile_ndepth']
    button_ndepth = style_parameter['button_ndepth']
    grid_lat_list = []
    grid_lon_list = []
    profile_vs_all = []
    profile_depth_all = []
    profile_lat_label_list = []
    profile_lon_label_list = []
    button_data_all_vs = []
    button_data_all_vp = []
    button_data_all_rho = []
    button_data_all_top = []
    for aprofile in profile_data_all:
        grid_lat_list.append(aprofile['lat'])
        grid_lon_list.append(aprofile['lon'])
        profile_lat_label_list.append('Lat: {0:12.1f}'.format(aprofile['lat']))
        profile_lon_label_list.append('Lon: {0:12.1f}'.format(aprofile['lon']))
        # each layer contributes two points (its top and its bottom at the
        # same Vs) so the profile is drawn as a step function; top_raw must
        # therefore hold profile_ndepth + 1 interfaces
        vs_raw = aprofile['vs']
        top_raw = aprofile['top']
        vs_plot = []
        depth_plot = []
        for idepth in range(profile_ndepth):
            vs_plot.extend((vs_raw[idepth], vs_raw[idepth]))
            depth_plot.extend((top_raw[idepth], top_raw[idepth + 1]))
        profile_vs_all.append(vs_plot)
        profile_depth_all.append(depth_plot)
        button_data_all_vs.append(aprofile['vs'][:button_ndepth])
        button_data_all_vp.append(aprofile['vp'][:button_ndepth])
        button_data_all_rho.append(aprofile['rho'][:button_ndepth])
        button_data_all_top.append(aprofile['top'][:button_ndepth])
    grid_data_bokeh = ColumnDataSource(data=dict(
        lon=grid_lon_list,
        lat=grid_lat_list,
        width=[style_parameter['map_view_grid_width']] * nprofile,
        height=[style_parameter['map_view_grid_height']] * nprofile))
    # highlighted grid cell; 'index' holds the selected profile index and is
    # read by the download buttons
    selected_dot_on_map_bokeh = ColumnDataSource(data=dict(
        lat=[grid_lat_list[profile_default_index]],
        lon=[grid_lon_list[profile_default_index]],
        width=[style_parameter['map_view_grid_width']],
        height=[style_parameter['map_view_grid_height']],
        index=[profile_default_index]))
    profile_data_all_bokeh = ColumnDataSource(data=dict(
        profile_vs_all=profile_vs_all,
        profile_depth_all=profile_depth_all))
    selected_profile_data_bokeh = ColumnDataSource(data=dict(
        vs=profile_vs_all[profile_default_index],
        depth=profile_depth_all[profile_default_index]))
    selected_profile_lat_label_bokeh = ColumnDataSource(data=dict(
        x=[style_parameter['profile_lat_label_x']],
        y=[style_parameter['profile_lat_label_y']],
        lat_label=[profile_lat_label_list[profile_default_index]]))
    selected_profile_lon_label_bokeh = ColumnDataSource(data=dict(
        x=[style_parameter['profile_lon_label_x']],
        y=[style_parameter['profile_lon_label_y']],
        lon_label=[profile_lon_label_list[profile_default_index]]))
    all_profile_lat_label_bokeh = ColumnDataSource(data=dict(
        profile_lat_label_list=profile_lat_label_list))
    all_profile_lon_label_bokeh = ColumnDataSource(data=dict(
        profile_lon_label_list=profile_lon_label_list))
    button_data_all_bokeh = ColumnDataSource(data=dict(
        button_data_all_vs=button_data_all_vs,
        button_data_all_vp=button_data_all_vp,
        button_data_all_rho=button_data_all_rho,
        button_data_all_top=button_data_all_top))
    # ==============================
    map_view = Figure(plot_width=style_parameter['map_view_plot_width'],
                      plot_height=style_parameter['map_view_plot_height'],
                      tools=style_parameter['map_view_tools'],
                      title=style_parameter['map_view_title'],
                      y_range=[style_parameter['map_view_figure_lat_min'],
                               style_parameter['map_view_figure_lat_max']],
                      x_range=[style_parameter['map_view_figure_lon_min'],
                               style_parameter['map_view_figure_lon_max']])
    map_view.image_rgba('map_data_one_slice', x='x', y='y', dw='dw', dh='dh',
                        source=map_data_one_slice_bokeh, level='image')
    # the depth slider swaps in the image, colorbar boxes and depth label of
    # the chosen slice
    depth_slider_callback = CustomJS(args=dict(map_data_one_slice_bokeh=map_data_one_slice_bokeh,
                                               map_data_all_slices_bokeh=map_data_all_slices_bokeh,
                                               colorbar_data_all_slices_bokeh=colorbar_data_all_slices_bokeh,
                                               colorbar_data_one_slice_bokeh=colorbar_data_one_slice_bokeh,
                                               map_data_one_slice_depth_bokeh=map_data_one_slice_depth_bokeh), code="""

        var d_index = Math.round(cb_obj.value)
        
        var map_data_all_slices = map_data_all_slices_bokeh.data
        
        map_data_one_slice_bokeh.data['map_data_one_slice'] = [map_data_all_slices['map_data_all_slices'][d_index]]
        map_data_one_slice_bokeh.change.emit()
        
        var color_data_all_slices = colorbar_data_all_slices_bokeh.data
        colorbar_data_one_slice_bokeh.data['colorbar_left'] = color_data_all_slices['colorbar_data_all_left'][d_index]
        colorbar_data_one_slice_bokeh.data['colorbar_right'] = color_data_all_slices['colorbar_data_all_right'][d_index]
        colorbar_data_one_slice_bokeh.change.emit()
        
        map_data_one_slice_depth_bokeh.data['map_depth'] = [map_data_all_slices['map_data_all_slices_depth'][d_index]]
        map_data_one_slice_depth_bokeh.change.emit()
        
    """)
    depth_slider = Slider(start=0, end=style_parameter['map_view_ndepth'] - 1,
                          value=map_view_default_index, step=1,
                          width=style_parameter['map_view_plot_width'],
                          title=style_parameter['depth_slider_title'], height=50)
    depth_slider.js_on_change('value', depth_slider_callback)
    depth_slider_callback.args["depth_index"] = depth_slider
    # ------------------------------
    # add boundaries to map view, drawn in this order; country and shoreline
    # lines are drawn thicker
    for boundary_name, extra_style in (('country', dict(line_width=2)),
                                       ('marine', {}),
                                       ('shoreline', dict(line_width=2)),
                                       ('state', {})):
        map_view.multi_line(boundary_data[boundary_name]['longitude'],
                            boundary_data[boundary_name]['latitude'],
                            color='black', level='underlay',
                            nonselection_line_alpha=1.0,
                            nonselection_line_color='black', **extra_style)
    # ------------------------------
    # add depth label
    map_view.rect(style_parameter['map_view_depth_box_lon'], style_parameter['map_view_depth_box_lat'],
                  width=style_parameter['map_view_depth_box_width'],
                  height=style_parameter['map_view_depth_box_height'],
                  width_units='screen', height_units='screen', color='#FFFFFF',
                  line_width=1., line_color='black', level='underlay')
    map_view.text('lon', 'lat', 'map_depth', source=map_data_one_slice_depth_bokeh,
                  text_font_size=style_parameter['annotating_text_font_size'],
                  text_align='left', level='underlay')
    # ------------------------------
    # model grid (selectable) and the highlighted selected cell
    map_view.rect('lon', 'lat', width='width',
                  width_units='screen', height='height',
                  height_units='screen', line_color='gray', line_alpha=0.5,
                  selection_line_color='gray', selection_line_alpha=0.5, selection_fill_color=None,
                  nonselection_line_color='gray', nonselection_line_alpha=0.5, nonselection_fill_color=None,
                  source=grid_data_bokeh, color=None, line_width=1, level='glyph')
    map_view.rect('lon', 'lat', width='width',
                  width_units='screen', height='height',
                  height_units='screen', line_color='#00ff00', line_alpha=1.0,
                  source=selected_dot_on_map_bokeh, fill_color=None, line_width=3., level='glyph')
    # ------------------------------
    # JavaScript shared by the map-click and profile-slider callbacks. It
    # expects an integer p_index to be defined before it runs; it moves the
    # highlighted cell and swaps in that profile and its lat/lon labels.
    select_profile_js = """
        var grid_data = grid_data_bokeh.data
        selected_dot_on_map_bokeh.data['lat'] = [grid_data['lat'][p_index]]
        selected_dot_on_map_bokeh.data['lon'] = [grid_data['lon'][p_index]]
        selected_dot_on_map_bokeh.data['index'] = [p_index]
        selected_dot_on_map_bokeh.change.emit()
        
        var profile_data_all = profile_data_all_bokeh.data
        selected_profile_data_bokeh.data['vs'] = profile_data_all['profile_vs_all'][p_index]
        selected_profile_data_bokeh.data['depth'] = profile_data_all['profile_depth_all'][p_index]
        selected_profile_data_bokeh.change.emit()
        
        var all_profile_lat_label = all_profile_lat_label_bokeh.data['profile_lat_label_list']
        var all_profile_lon_label = all_profile_lon_label_bokeh.data['profile_lon_label_list']
        selected_profile_lat_label_bokeh.data['lat_label'] = [all_profile_lat_label[p_index]]
        selected_profile_lon_label_bokeh.data['lon_label'] = [all_profile_lon_label[p_index]]
        selected_profile_lat_label_bokeh.change.emit()
        selected_profile_lon_label_bokeh.change.emit()
    """
    select_profile_args = dict(selected_dot_on_map_bokeh=selected_dot_on_map_bokeh,
                               grid_data_bokeh=grid_data_bokeh,
                               profile_data_all_bokeh=profile_data_all_bokeh,
                               selected_profile_data_bokeh=selected_profile_data_bokeh,
                               selected_profile_lat_label_bokeh=selected_profile_lat_label_bokeh,
                               selected_profile_lon_label_bokeh=selected_profile_lon_label_bokeh,
                               all_profile_lat_label_bokeh=all_profile_lat_label_bokeh,
                               all_profile_lon_label_bokeh=all_profile_lon_label_bokeh)
    # cb_obj is the grid Selection. An empty selection (click on empty map
    # area) is ignored instead of pushing undefined values into the plots;
    # when several cells are selected the first one is shown.
    grid_data_js = CustomJS(args=dict(select_profile_args), code="""
        var inds = cb_obj.indices
        if (inds.length == 0) {
            return
        }
        var p_index = inds[0]
    """ + select_profile_js)
    grid_data_bokeh.selected.js_on_change('indices', grid_data_js)
    # ------------------------------
    # change style
    map_view.title.text_font_size = style_parameter['title_font_size']
    map_view.title.align = 'center'
    map_view.title.text_font_style = 'normal'
    map_view.xaxis.axis_label = style_parameter['map_view_xlabel']
    map_view.xaxis.axis_label_text_font_style = 'normal'
    map_view.xaxis.axis_label_text_font_size = xlabel_fontsize
    map_view.xaxis.major_label_text_font_size = xlabel_fontsize
    map_view.yaxis.axis_label = style_parameter['map_view_ylabel']
    map_view.yaxis.axis_label_text_font_style = 'normal'
    map_view.yaxis.axis_label_text_font_size = xlabel_fontsize
    map_view.yaxis.major_label_text_font_size = xlabel_fontsize
    map_view.xgrid.grid_line_color = None
    map_view.ygrid.grid_line_color = None
    map_view.toolbar.logo = None
    map_view.toolbar_location = 'above'
    map_view.toolbar_sticky = False
    # ==============================
    # plot colorbar
    colorbar_fig = Figure(tools=[], y_range=(0, 0.1), plot_width=style_parameter['map_view_plot_width'],
                          plot_height=style_parameter['colorbar_plot_height'],
                          title=style_parameter['colorbar_title'])
    colorbar_fig.toolbar_location = None
    colorbar_fig.quad(top='colorbar_top', bottom='colorbar_bottom',
                      left='colorbar_left', right='colorbar_right',
                      color='palette_r', source=colorbar_data_one_slice_bokeh)
    colorbar_fig.yaxis[0].ticker = FixedTicker(ticks=[])
    colorbar_fig.xgrid.grid_line_color = None
    colorbar_fig.ygrid.grid_line_color = None
    colorbar_fig.xaxis.axis_label_text_font_size = xlabel_fontsize
    colorbar_fig.xaxis.major_label_text_font_size = xlabel_fontsize
    colorbar_fig.xaxis[0].formatter = PrintfTickFormatter(format="%5.2f")
    colorbar_fig.title.text_font_size = xlabel_fontsize
    colorbar_fig.title.align = 'center'
    colorbar_fig.title.text_font_style = 'normal'
    # ==============================
    # profile plot; the y range runs from ymax to ymin so depth increases
    # downwards
    profile_xrange = Range1d(start=style_parameter['profile_plot_xmin'],
                             end=style_parameter['profile_plot_xmax'])
    profile_yrange = Range1d(start=style_parameter['profile_plot_ymax'],
                             end=style_parameter['profile_plot_ymin'])
    profile_fig = Figure(plot_width=style_parameter['profile_plot_width'],
                         plot_height=style_parameter['profile_plot_height'],
                         x_range=profile_xrange, y_range=profile_yrange,
                         tools=style_parameter['profile_tools'],
                         title=style_parameter['profile_title'])
    profile_fig.line('vs',
                     'depth',
                     source=selected_profile_data_bokeh,
                     line_width=2,
                     line_color='black')
    # ------------------------------
    # add lat, lon
    profile_fig.rect([style_parameter['profile_label_box_x']], [style_parameter['profile_label_box_y']],
                     width=style_parameter['profile_label_box_width'],
                     height=style_parameter['profile_label_box_height'],
                     width_units='screen', height_units='screen', color='#FFFFFF',
                     line_width=1., line_color='black', level='underlay')
    profile_fig.text('x',
                     'y',
                     'lat_label',
                     source=selected_profile_lat_label_bokeh)
    profile_fig.text('x',
                     'y',
                     'lon_label',
                     source=selected_profile_lon_label_bokeh)
    # ------------------------------
    # change style
    profile_fig.xaxis.axis_label = style_parameter['profile_xlabel']
    profile_fig.xaxis.axis_label_text_font_style = 'normal'
    profile_fig.xaxis.axis_label_text_font_size = xlabel_fontsize
    profile_fig.xaxis.major_label_text_font_size = xlabel_fontsize
    profile_fig.yaxis.axis_label = style_parameter['profile_ylabel']
    profile_fig.yaxis.axis_label_text_font_style = 'normal'
    profile_fig.yaxis.axis_label_text_font_size = xlabel_fontsize
    profile_fig.yaxis.major_label_text_font_size = xlabel_fontsize
    profile_fig.xgrid.grid_line_dash = [4, 2]
    profile_fig.ygrid.grid_line_dash = [4, 2]
    profile_fig.title.text_font_size = style_parameter['title_font_size']
    profile_fig.title.align = 'center'
    profile_fig.title.text_font_style = 'normal'
    profile_fig.toolbar_location = 'above'
    profile_fig.toolbar_sticky = False
    profile_fig.toolbar.logo = None
    # ==============================
    # the profile slider selects a profile by index, like a map click
    profile_slider_callback = CustomJS(args=dict(select_profile_args), code="""
        var p_index = Math.round(cb_obj.value)
    """ + select_profile_js)
    profile_slider = Slider(start=0, end=nprofile - 1, value=profile_default_index,
                            step=1, title=style_parameter['profile_slider_title'],
                            width=style_parameter['profile_plot_width'], height=50)
    profile_slider_callback.args['profile_index'] = profile_slider
    profile_slider.js_on_change('value', profile_slider_callback)
    # ==============================
    # download the selected profile's first button_ndepth layers as a text table
    simple_text_button_callback = CustomJS(args=dict(button_data_all_bokeh=button_data_all_bokeh,
                                                     selected_dot_on_map_bokeh=selected_dot_on_map_bokeh),
                                           code="""
        var index = selected_dot_on_map_bokeh.data['index'][0]
        
        var button_data_vs = button_data_all_bokeh.data['button_data_all_vs'][index]
        var button_data_vp = button_data_all_bokeh.data['button_data_all_vp'][index]
        var button_data_rho = button_data_all_bokeh.data['button_data_all_rho'][index]
        var button_data_top = button_data_all_bokeh.data['button_data_all_top'][index]
        
        var csvContent = ""
        var i = 0
        var temp = csvContent
        temp += "# Layer Top (km)      Vs(km/s)    Vp(km/s)    Rho(g/cm^3) \\n"
        while(button_data_vp[i]) {
            temp += button_data_top[i].toPrecision(6) + "    " + button_data_vs[i].toPrecision(4) + "   " +
                    button_data_vp[i].toPrecision(4) + "   " + button_data_rho[i].toPrecision(4) + "\\n"
            i = i + 1
        }
        const blob = new Blob([temp], { type: 'text/csv;charset=utf-8;' })
        const link = document.createElement('a');
        link.href = URL.createObjectURL(blob);
        link.download = 'vel_model.txt';
        link.target = '_blank'
        link.style.visibility = 'hidden'
        link.dispatchEvent(new MouseEvent('click'))
        
    """)
    simple_text_button = Button(
        label=style_parameter['simple_text_button_label'],
        button_type='default',
        width=style_parameter['button_width'])
    simple_text_button.js_on_click(simple_text_button_callback)
    # ------------------------------
    # download the selected profile in model96 format. One row is written per
    # layer that has a layer below it (thickness = next top - this top), so
    # the deepest of the button_ndepth layers gets no row.
    # NOTE(review): confirm that dropping the deepest layer (the half-space)
    # is intended.
    model96_button_callback = CustomJS(args=dict(button_data_all_bokeh=button_data_all_bokeh,
                                                 selected_dot_on_map_bokeh=selected_dot_on_map_bokeh),
                                       code="""
        var index = selected_dot_on_map_bokeh.data['index'][0]
        var lat = selected_dot_on_map_bokeh.data['lat'][0]
        var lon = selected_dot_on_map_bokeh.data['lon'][0]
        
        var button_data_vs = button_data_all_bokeh.data['button_data_all_vs'][index]
        var button_data_vp = button_data_all_bokeh.data['button_data_all_vp'][index]
        var button_data_rho = button_data_all_bokeh.data['button_data_all_rho'][index]
        var button_data_top = button_data_all_bokeh.data['button_data_all_top'][index]
        
        var csvContent = ""
        var i = 0
        var temp = csvContent
        temp +=  "MODEL." + index + " \\n"
        temp +=  "ShearVelocityModel Lat: "+ lat +"  Lon: " + lon + "\\n"
        temp +=  "ISOTROPIC \\n"
        temp +=  "KGS \\n"
        temp +=  "SPHERICAL EARTH \\n"
        temp +=  "1-D \\n"
        temp +=  "CONSTANT VELOCITY \\n"
        temp +=  "LINE08 \\n"
        temp +=  "LINE09 \\n"
        temp +=  "LINE10 \\n"
        temp +=  "LINE11 \\n"
        temp +=  "      H(KM)   VP(KM/S)   VS(KM/S) RHO(GM/CC)     QP         QS       ETAP       ETAS      FREFP      FREFS \\n"
        while(button_data_vp[i+1]) {
            var thickness = button_data_top[i+1] - button_data_top[i]
            temp += "      " + thickness.toPrecision(6) + "    " + button_data_vp[i].toPrecision(4) + "      " + button_data_vs[i].toPrecision(4) +
                    "      " + button_data_rho[i].toPrecision(4) + "     0.00       0.00       0.00       0.00       1.00       1.00" + "\\n"
            i = i + 1
        }
        const blob = new Blob([temp], { type: 'text/csv;charset=utf-8;' })
        const link = document.createElement('a');
        link.href = URL.createObjectURL(blob);
        link.download = 'vel_model96.txt';
        link.target = '_blank'
        link.style.visibility = 'hidden'
        link.dispatchEvent(new MouseEvent('click'))
    """)
    model96_button = Button(label=style_parameter['model96_button_label'],
                            button_type='default',
                            width=style_parameter['button_width'])
    model96_button.js_on_click(model96_button_callback)
    # ==============================
    # annotating text
    annotating_fig01 = Div(text=style_parameter['annotating_html01'],
                           width=style_parameter['annotation_plot_width'],
                           height=style_parameter['annotation_plot_height'])
    annotating_fig02 = Div(text=style_parameter['annotating_html02'],
                           width=style_parameter['annotation_plot_width'],
                           height=style_parameter['annotation_plot_height'])
    # ==============================
    output_file(filename,
                title=style_parameter['html_title'],
                mode=style_parameter['library_source'])
    left_column = Column(depth_slider,
                         map_view,
                         colorbar_fig,
                         annotating_fig01,
                         width=style_parameter['left_column_width'])
    button_pannel = Row(simple_text_button, model96_button)
    right_column = Column(profile_slider,
                          profile_fig,
                          button_pannel,
                          annotating_fig02,
                          width=style_parameter['right_column_width'])
    layout = Row(left_column, right_column)
    save(layout)
# NOTE(review): script fragment. `clusters`, `predictions`, `spike_waveforms`
# and the `plot_data` list are defined before this point and are not visible
# here; presumably plot_data starts out as an empty list.
# For each cluster label, collect the indices of the waveforms whose
# prediction equals that label, then replace plot_data with
# spike_waveforms indexed by those per-cluster index arrays.
# NOTE(review): indexing with a *list of index arrays* follows NumPy
# advanced-indexing rules that have changed across versions (older NumPy
# treated such a list as a multi-axis index); verify the resulting shape.
for cluster in clusters:
	plot_data.append(np.where(predictions == cluster)[0])
plot_data = spike_waveforms[plot_data]

# Set up data
# N is the index of the first waveform shown. x holds one point per
# decimated sample (len/10 rounded up), matching the ::10 slices below.
N = 0
x = np.arange(len(plot_data[N])/10)
# NOTE(review): y is not used in the visible code.
y = spike_waveforms[0, ::10]
# 50 waveforms starting at row N, each decimated by 10, all on the same x axis
source = ColumnDataSource(data=dict(xs=[x for i in range(50)], ys = [plot_data[N + i, ::10] for i in range(50)]))

# Set up plot
plot = Figure(plot_height=400, plot_width=400, title="Unit waveforms",
              tools="crosshair,pan,reset,save,wheel_zoom",
              x_range=[0, 45], y_range=[-200, 200])

plot.multi_line('xs', 'ys', source=source, line_width=1, line_alpha=1.0)

# Set up widgets
# text = TextInput(title="title", value='my sine wave')
offset = Slider(title="offset", value=0, start=0, end=50000, step= 100) # end is set large enough that almost any cluster size fits within the slider range
electrode = TextInput(title = 'Electrode Number', value = '0')
# NOTE(review): this rebinds `clusters`, shadowing the iterable used in the
# loop above with a TextInput widget.
clusters = TextInput(title = 'Number of clusters', value = '2')
cluster_num = TextInput(title = 'Cluster Number', value = '0')
#amplitude = Slider(title="amplitude", value=1.0, start=-5.0, end=5.0)
#phase = Slider(title="phase", value=0.0, start=0.0, end=2*np.pi)
#freq = Slider(title="frequency", value=1.0, start=0.1, end=5.1)

def update_data(attrname, old, new):
    
    os.chdir(dir_name)
    
          line_width=5)  #Column 1
plot.line(x='x', y='y', source=col2.pts, color='#0065BD',
          line_width=5)  #Column 2
plot.line(x='x', y='y', source=col3.pts, color='#0065BD',
          line_width=5)  #Column 3
plot.line(x='x', y='y', source=col4.pts, color='#0065BD',
          line_width=5)  #Column 4
#Create the floors for each column:
plot.line(x='x', y='y', source=col1.floor, color='black', line_width=6)
plot.line(x='x', y='y', source=col2.floor, color='black', line_width=6)
plot.line(x='x', y='y', source=col3.floor, color='black', line_width=6)
plot.line(x='x', y='y', source=col4.floor, color='black', line_width=6)
#Create walls for columns that require a wall:
plot.line(x='x', y='y', source=col2.wall, color='black', line_width=6)
plot.line(x='x', y='y', source=col3.wall, color='black', line_width=6)
plot.multi_line(xs='x', ys='y', source=col4.wall, color='black', line_width=6)
#Create circles for columns that have pins:
plot.circle(x='x', y='y', source=col2.cir1, color='#0065BD', size=10)
plot.circle(x='x', y='y', source=col2.cir2, color='#0065BD', size=10)
plot.circle(x='x', y='y', source=col3.cir2, color='#0065BD', size=10)
#Create the shapes of the ends of the columns:
plot.triangle(x='x',
              y='y',
              source=col2.tri1,
              color='black',
              angle=0.0,
              fill_alpha=0,
              size=20)
plot.triangle(x='x',
              y='y',
              source=col2.tri2,
Exemple #18
0
def hst_time(result_dict, plot=True, output_file ='hsttime.html', model = True):
    """Plot earliest and latest start times for hst observation.

    Builds two side-by-side figures of the simulated white-light curve, each
    point drawn with a vertical +/- rms error bar: one for the earliest and
    one for the latest start time of the observing window. The transit model
    can optionally be drawn underneath the data.

    Parameters
    ----------
    result_dict : dict
        Dictionary from pandexo output. All values are read from
        ``result_dict['calc_start_window']``.
    plot : bool
        (Optional) True renders plot, False does not. Default=True
    output_file : str
        (Optional) HTML file the plot is written to. Default = 'hsttime.html'
    model : bool
        (Optional) Plot model under data. Default=True

    Returns
    -------
    obsphase1 : numpy array
        earliest start time
    obstr1 : numpy array
        white light curve
    obsphase2 : numpy array
        latest start time
    obstr2 : numpy array
        white light curve
    rms : numpy array
        1D rms noise. A scalar input is broadcast to ``len(obsphase1)``.

    See Also
    --------
    hst_spec
    """
    TOOLS = "pan,wheel_zoom,box_zoom,resize,reset,save"

    # earliest and latest start times
    window = result_dict['calc_start_window']
    obsphase1 = window['obsphase1']
    obstr1 = window['obstr1']
    rms = window['light_curve_rms']
    obsphase2 = window['obsphase2']
    obstr2 = window['obstr2']
    phase1 = window['phase1']
    phase2 = window['phase2']
    trmodel1 = window['trmodel1']
    trmodel2 = window['trmodel2']

    # A single noise value applies to every point; broadcast it so it can be
    # zipped with the light curves. np.ndim also accepts ints and 0-d arrays,
    # which would otherwise fail in zip().
    if np.ndim(rms) == 0:
        rms = np.zeros(len(obsphase1)) + rms

    # Error bars are vertical segments (x, x) -> (y - err, y + err) drawn by
    # multi_line. zip() truncates to the shortest input, as before.
    points1 = list(zip(obsphase1, obstr1, rms))
    x_err1 = [(px, px) for px, _, _ in points1]
    y_err1 = [(py - yerr, py + yerr) for _, py, yerr in points1]

    points2 = list(zip(obsphase2, obstr2, rms))
    x_err2 = [(px, px) for px, _, _ in points2]
    y_err2 = [(py - yerr, py + yerr) for _, py, yerr in points2]

    early = Figure(plot_width=400, plot_height=300,
                   tools=TOOLS,  # responsive=True,
                   x_axis_label='Orbital Phase',
                   y_axis_label='Flux',
                   title="Earliest Start Time")
    if model:
        early.line(phase1, trmodel1, color='black', alpha=0.5, line_width=4)
    early.circle(obsphase1, obstr1, line_width=3, line_alpha=0.6)
    early.multi_line(x_err1, y_err1)

    late = Figure(plot_width=400, plot_height=300,
                  tools=TOOLS,  # responsive=True,
                  x_axis_label='Orbital Phase',
                  y_axis_label='Flux',
                  title="Latest Start Time")
    if model:
        late.line(phase2, trmodel2, color='black', alpha=0.5, line_width=3)
    late.circle(obsphase2, obstr2, line_width=3, line_alpha=0.6)
    late.multi_line(x_err2, y_err2)

    start_time = row(early, late)

    if plot:
        # `outputfile` is the (renamed) Bokeh output function; the parameter
        # `output_file` shadows the usual name.
        outputfile(output_file)
        show(start_time)

    return obsphase1, obstr1, obsphase2, obstr2, rms
Exemple #19
0
def hst_spec(result_dict, plot=True, output_file ='hstspec.html', model = True):
    """Plot 1d spec with error bars for hst.

    Draws the binned simulated spectrum, each bin with a vertical error bar,
    optionally over the input model spectrum.

    Parameters
    ----------
    result_dict : dict
        Dictionary from pandexo output. All values are read from
        ``result_dict['planet_spec']``.
    plot : bool
        (Optional) True renders plot, False does not. Default=True
    output_file : str
        (Optional) HTML file the plot is written to. Default = 'hstspec.html'
    model : bool
        (Optional) Plot model under data. Default=True

    Returns
    -------
    x : numpy array
        micron
    y : numpy array
        1D spec fp/f* or rp^2/r*^2
    e : numpy array
        1D rms noise (a scalar error is broadcast to ``len(binspec)``)
    modelx : numpy array
        micron
    modely : numpy array
        1D spec fp/f* or rp^2/r*^2

    See Also
    --------
    hst_time
    """
    TOOLS = "pan,wheel_zoom,box_zoom,resize,reset,save"

    # plot planet spectrum
    spec = result_dict['planet_spec']
    mwave = spec['model_wave']
    mspec = spec['model_spec']

    binwave = spec['binwave']
    binspec = spec['binspec']

    # Broadcast a scalar error (or copy an array one) to one value per bin.
    error = np.zeros(len(binspec)) + spec['error']
    xlims = [spec['wmin'], spec['wmax']]
    # Pad the y range by two error bars so every bar stays visible. This also
    # holds when the error varies per bin; for a scalar error it equals
    # min(binspec) - 2*err and max(binspec) + 2*err.
    ylims = [np.min(binspec - 2.0 * error), np.max(binspec + 2.0 * error)]

    plot_spectrum = Figure(plot_width=800, plot_height=300, x_range=xlims,
                           y_range=ylims, tools=TOOLS,  # responsive=True,
                           x_axis_label='Wavelength [microns]',
                           y_axis_label='Ratio',
                           title="Original Model with Observation")

    # Vertical error-bar segments for multi_line.
    points = list(zip(binwave, binspec, error))
    x_err = [(px, px) for px, _, _ in points]
    y_err = [(py - yerr, py + yerr) for _, py, yerr in points]

    if model:
        plot_spectrum.line(mwave, mspec, color="black", alpha=0.5, line_width=4)
    plot_spectrum.circle(binwave, binspec, line_width=3, line_alpha=0.6)
    plot_spectrum.multi_line(x_err, y_err)

    if plot:
        # `outputfile` is the (renamed) Bokeh output function; the parameter
        # `output_file` shadows the usual name.
        outputfile(output_file)
        show(plot_spectrum)

    return binwave, binspec, error, mwave, mspec
Exemple #20
0
def create_component_hst(result_dict):
    """Generate front end plots HST

    Function that is responsible for generating the front-end spectra plots for HST.
    It builds the binned spectrum with error bars over the input model, plus
    a row of two light-curve plots for the earliest and latest start times.

    Parameters
    ----------
    result_dict : dict
        The dictionary returned from a PandExo (HST) run

    Returns
    -------
    tuple
        A tuple containing `(script, div)`, where the `script` is the
        front-end javascript required, and `div` is a dictionary of plot
        objects.
    """
    TOOLS = "pan,wheel_zoom,box_zoom,resize,reset,save"

    # plot planet spectrum
    spec = result_dict['planet_spec']
    mwave = spec['model_wave']
    mspec = spec['model_spec']

    binwave = spec['binwave']
    binspec = spec['binspec']

    # Broadcast a scalar error (or copy an array one) to one value per bin.
    error = np.zeros(len(binspec)) + spec['error']
    xlims = [spec['wmin'], spec['wmax']]
    # Pad by two error bars using the per-bin error so no bar is clipped;
    # identical to the scalar-error formula when the error is constant.
    ylims = [np.min(binspec - 2.0 * error), np.max(binspec + 2.0 * error)]

    plot_spectrum = Figure(plot_width=800, plot_height=300, x_range=xlims,
                           y_range=ylims, tools=TOOLS,  # responsive=True,
                           x_axis_label='Wavelength [microns]',
                           y_axis_label='(Rp/R*)^2',
                           title="Original Model with Observation")

    # Vertical error-bar segments for multi_line.
    points = list(zip(binwave, binspec, error))
    x_err = [(px, px) for px, _, _ in points]
    y_err = [(py - yerr, py + yerr) for _, py, yerr in points]

    plot_spectrum.line(mwave, mspec, color="black", alpha=0.5, line_width=4)
    plot_spectrum.circle(binwave, binspec, line_width=3, line_alpha=0.6)
    plot_spectrum.multi_line(x_err, y_err)

    # earliest and latest start times
    window = result_dict['calc_start_window']
    obsphase1 = window['obsphase1']
    obstr1 = window['obstr1']
    rms = window['light_curve_rms']
    obsphase2 = window['obsphase2']
    obstr2 = window['obstr2']
    phase1 = window['phase1']
    phase2 = window['phase2']
    trmodel1 = window['trmodel1']
    trmodel2 = window['trmodel2']

    # A single noise value applies to every point (ints and 0-d arrays too).
    if np.ndim(rms) == 0:
        rms = np.zeros(len(obsphase1)) + rms

    points1 = list(zip(obsphase1, obstr1, rms))
    x_err1 = [(px, px) for px, _, _ in points1]
    y_err1 = [(py - yerr, py + yerr) for _, py, yerr in points1]

    points2 = list(zip(obsphase2, obstr2, rms))
    x_err2 = [(px, px) for px, _, _ in points2]
    y_err2 = [(py - yerr, py + yerr) for _, py, yerr in points2]

    early = Figure(plot_width=400, plot_height=300,
                   tools=TOOLS,  # responsive=True,
                   x_axis_label='Orbital Phase',
                   y_axis_label='Flux',
                   title="Earliest Start Time")
    early.line(phase1, trmodel1, color='black', alpha=0.5, line_width=4)
    early.circle(obsphase1, obstr1, line_width=3, line_alpha=0.6)
    early.multi_line(x_err1, y_err1)

    late = Figure(plot_width=400, plot_height=300,
                  tools=TOOLS,  # responsive=True,
                  x_axis_label='Orbital Phase',
                  y_axis_label='Flux',
                  title="Latest Start Time")
    late.line(phase2, trmodel2, color='black', alpha=0.5, line_width=3)
    late.circle(obsphase2, obstr2, line_width=3, line_alpha=0.6)
    late.multi_line(x_err2, y_err2)

    start_time = row(early, late)

    result_comp = components({'plot_spectrum': plot_spectrum,
                              'start_time': start_time})

    return result_comp
Exemple #21
0
def create_component_hst(result_dict):
    """Generate front end plots HST

    Function that is responsible for generating the front-end spectra plots for HST.
    It builds the binned spectrum with error bars over the input model, with
    the y axis labelled by event type. It also builds a row of two
    light-curve plots for the earliest and latest start times.

    Parameters
    ----------
    result_dict : dict
        The dictionary returned from a PandExo (HST) run

    Returns
    -------
    tuple
        A tuple containing `(script, div)`, where the `script` is the
        front-end javascript required, and `div` is a dictionary of plot
        objects.
    """
    TOOLS = "pan,wheel_zoom,box_zoom,resize,reset,save"

    # plot planet spectrum
    spec = result_dict['planet_spec']
    mwave = spec['model_wave']
    mspec = spec['model_spec']

    binwave = spec['binwave']
    binspec = spec['binspec']

    # Broadcast a scalar error (or copy an array one) to one value per bin.
    error = np.zeros(len(binspec)) + spec['error']
    xlims = [spec['wmin'], spec['wmax']]
    # Pad by two error bars using the per-bin error so no bar is clipped;
    # identical to the scalar-error formula when the error is constant.
    ylims = [np.min(binspec - 2.0 * error), np.max(binspec + 2.0 * error)]

    window = result_dict['calc_start_window']

    # 'eclipse' -> planet/star flux ratio; anything else (including
    # 'transit') -> transit depth. Always binding y_axis means an unexpected
    # or missing eventType can no longer leave it undefined.
    eventType = window.get('eventType')
    y_axis = 'Fp/F*' if eventType == 'eclipse' else '(Rp/R*)^2'

    plot_spectrum = Figure(plot_width=800, plot_height=300, x_range=xlims,
                           y_range=ylims, tools=TOOLS,  # responsive=True,
                           x_axis_label='Wavelength [microns]',
                           y_axis_label=y_axis,
                           title="Original Model with Observation")

    # Vertical error-bar segments for multi_line.
    points = list(zip(binwave, binspec, error))
    x_err = [(px, px) for px, _, _ in points]
    y_err = [(py - yerr, py + yerr) for _, py, yerr in points]

    plot_spectrum.line(mwave, mspec, color="black", alpha=0.5, line_width=4)
    plot_spectrum.circle(binwave, binspec, line_width=3, line_alpha=0.6)
    plot_spectrum.multi_line(x_err, y_err)

    # earliest and latest start times
    obsphase1 = window['obsphase1']
    obstr1 = window['obstr1']
    rms = window['light_curve_rms']
    obsphase2 = window['obsphase2']
    obstr2 = window['obstr2']
    phase1 = window['phase1']
    phase2 = window['phase2']
    trmodel1 = window['trmodel1']
    trmodel2 = window['trmodel2']

    # A single noise value applies to every point (ints and 0-d arrays too).
    if np.ndim(rms) == 0:
        rms = np.zeros(len(obsphase1)) + rms

    points1 = list(zip(obsphase1, obstr1, rms))
    x_err1 = [(px, px) for px, _, _ in points1]
    y_err1 = [(py - yerr, py + yerr) for _, py, yerr in points1]

    points2 = list(zip(obsphase2, obstr2, rms))
    x_err2 = [(px, px) for px, _, _ in points2]
    y_err2 = [(py - yerr, py + yerr) for _, py, yerr in points2]

    early = Figure(plot_width=400, plot_height=300,
                   tools=TOOLS,  # responsive=True,
                   x_axis_label='Orbital Phase',
                   y_axis_label='Flux',
                   title="Earliest Start Time")
    early.line(phase1, trmodel1, color='black', alpha=0.5, line_width=4)
    early.circle(obsphase1, obstr1, line_width=3, line_alpha=0.6)
    early.multi_line(x_err1, y_err1)

    late = Figure(plot_width=400, plot_height=300,
                  tools=TOOLS,  # responsive=True,
                  x_axis_label='Orbital Phase',
                  y_axis_label='Flux',
                  title="Latest Start Time")
    late.line(phase2, trmodel2, color='black', alpha=0.5, line_width=3)
    late.circle(obsphase2, obstr2, line_width=3, line_alpha=0.6)
    late.multi_line(x_err2, y_err2)

    start_time = row(early, late)

    result_comp = components({'plot_spectrum': plot_spectrum,
                              'start_time': start_time})

    return result_comp
Exemple #22
0
def create_component_jwst(result_dict):
    """Generate front end plots JWST

    Function that is responsible for generating the front-end interactive plots for JWST.
    It produces:

    * an interactive spectrum plot, with binning and number-of-transit
      sliders, driven by a client-side JavaScript callback;
    * a set of 1D tabs: flux, background, SNR, error curve and original model;
    * the 2D detector image;
    * a set of 2D tabs: SNR and saturation.

    Parameters
    ----------
    result_dict : dict
        the dictionary returned from a PandExo run

    Returns
    -------
    tuple
        A tuple containing `(script, div)`, where the `script` is the
        front-end javascript required, and `div` is a dictionary of plot
        objects.
    """
    noccultations = result_dict['timing']['Number of Transits']

    # select the tools we want
    TOOLS = "pan,wheel_zoom,box_zoom,resize,reset,save"

    # Define units for x and y axis. For 'fp/f*' the factor p = -1 flips the
    # sign of the depth computed in the JS callback (y = p * (...)).
    punit = result_dict['input']['Primary/Secondary']
    p = 1.0
    if punit == 'fp/f*':
        p = -1.0
    else:
        punit = '(' + punit + ')^2'

    if result_dict['input']['Calculation Type'] == 'phase_spec':
        x_axis_label = 'Time (secs)'
        frac = 1.0
    else:
        x_axis_label = 'Wavelength [microns]'
        frac = (result_dict['timing']['Num Integrations Out of Transit']
                / result_dict['timing']['Num Integrations In Transit'])

    electrons_out = result_dict['RawData']['electrons_out']
    electrons_in = result_dict['RawData']['electrons_in']

    var_in = result_dict['RawData']['var_in']
    var_out = result_dict['RawData']['var_out']

    x = result_dict['FinalSpectrum']['wave']
    y = result_dict['FinalSpectrum']['spectrum_w_rand']
    err = result_dict['FinalSpectrum']['error_w_floor']

    # Vertical error-bar segments (x, x) -> (y - err, y + err) for multi_line.
    points = list(zip(x, y, err))
    x_err = [(px, px) for px, _, _ in points]
    y_err = [(py - yerr, py + yerr) for _, py, yerr in points]

    # `source` is what is plotted and rewritten by the JS callback; `original`
    # keeps the unbinned data so every re-binning starts from scratch.
    # p/nocc/frac are stored as per-row columns (var_in*0 + value) so the JS
    # can read them from the data source.
    source = ColumnDataSource(data=dict(x=x, y=y, y_err=y_err, x_err=x_err, err=err,
                                electrons_out=electrons_out, electrons_in=electrons_in, var_in=var_in, var_out=var_out,
                                p=var_in*0+p, nocc=var_in*0+noccultations, frac=var_in*0+frac))
    original = ColumnDataSource(data=dict(x=x, y=y, y_err=y_err, x_err=x_err, err=err, electrons_out=electrons_out, electrons_in=electrons_in, var_in=var_in, var_out=var_out))

    model_spec = result_dict['OriginalInput']['model_spec']
    ylims = [min(model_spec) - 0.1*min(model_spec),
             0.1*max(model_spec) + max(model_spec)]
    xlims = [min(result_dict['FinalSpectrum']['wave']), max(result_dict['FinalSpectrum']['wave'])]

    plot_spectrum = Figure(plot_width=800, plot_height=300, x_range=xlims,
                           y_range=ylims, tools=TOOLS,  # responsive=True,
                           x_axis_label=x_axis_label,
                           y_axis_label=punit,
                           title="Original Model with Observation")

    plot_spectrum.line(result_dict['OriginalInput']['model_wave'], result_dict['OriginalInput']['model_spec'], color="black", alpha=0.5, line_width=4)

    plot_spectrum.circle('x', 'y', source=source, line_width=3, line_alpha=0.6)
    plot_spectrum.multi_line('x_err', 'y_err', source=source)

    # Client-side re-binning: `wbin` (log10 bin width) and `ntran` (number of
    # transits) are attached to callback.args after the sliders are created.
    callback = CustomJS(args=dict(source=source, original=original), code="""
            // Grab some references to the data
            var sdata = source.get('data');
            var odata = original.get('data');

            // Create copies of the original data, store them as the source data
            sdata['x'] = odata['x'].slice(0);
            sdata['y'] = odata['y'].slice(0);

            sdata['y_err'] = odata['y_err'].slice(0);
            sdata['x_err'] = odata['x_err'].slice(0);
            sdata['err'] = odata['err'].slice(0);

            sdata['electrons_out'] = odata['electrons_out'].slice(0);
            sdata['electrons_in'] = odata['electrons_in'].slice(0);
            sdata['var_in'] = odata['var_in'].slice(0);
            sdata['var_out'] = odata['var_out'].slice(0);

            // Create some variables referencing the source data
            var x = sdata['x'];
            var y = sdata['y'];
            var y_err = sdata['y_err'];
            var x_err = sdata['x_err'];
            var err = sdata['err'];
            var p = sdata['p'];
            var frac = sdata['frac'];
            var og_ntran = sdata['nocc'];

            var electrons_out = sdata['electrons_out'];
            var electrons_in = sdata['electrons_in'];
            var var_in = sdata['var_in'];
            var var_out = sdata['var_out'];

            var f = wbin.get('value');
            var ntran = ntran.get('value');

            var wlength = Math.pow(10.0,f);

            var ind = [];
            ind.push(0);
            var start = 0;


            for (i = 0; i < x.length-1; i++) {
                if (x[i+1] - x[start] >= wlength) {
                    ind.push(i+1);
                    start = i;
                }
            }

            if (ind[ind.length-1] != x.length) {
                ind.push(x.length);
            }

            var xout = [];


            var foutout = [];
            var finout = [];
            var varinout = [];
            var varoutout = [];

            var xslice = []; 

            var foutslice = [];
            var finslice = [];
            var varoutslice = [];
            var varinslice = [];

            function add(a, b) {
                return a+b;
            }

            for (i = 0; i < ind.length-1; i++) {
                xslice = x.slice(ind[i],ind[i+1]);

                foutslice = electrons_out.slice(ind[i],ind[i+1]);
                finslice = electrons_in.slice(ind[i],ind[i+1]);
                
                varinslice = var_in.slice(ind[i],ind[i+1]);
                varoutslice = var_out.slice(ind[i],ind[i+1]);

                xout.push(xslice.reduce(add, 0)/xslice.length);
                foutout.push(foutslice.reduce(add, 0));
                finout.push(finslice.reduce(add, 0));
                
                varinout.push(varinslice.reduce(add, 0));
                varoutout.push(varoutslice.reduce(add, 0));

                xslice = [];
                foutslice = [];
                finslice = [];
                varinslice = [];
                varoutslice = [];
            }
            
            var new_err = 1.0;
            var rand = 1.0;

            for (i = 0; i < x.length; i++) {
                new_err = Math.pow((frac[i]/foutout[i]),2)*varinout[i] + Math.pow((finout[i]*frac[i]/Math.pow(foutout[i],2)),2)*varoutout[i];
                new_err = Math.sqrt(new_err)*Math.sqrt(og_ntran[i]/ntran);
                rand = new_err*(Math.random()-Math.random());
                y[i] = p[i]*((1.0 - frac[i]*finout[i]/foutout[i]) + rand); 
                x[i] = xout[i];
                x_err[i][0] = xout[i];
                x_err[i][1] = xout[i];
                y_err[i][0] = y[i] + new_err;
                y_err[i][1] = y[i] - new_err;            
            }

            source.trigger('change');
        """)

    #var_tot = (frac/electrons_out)**2.0 * var_in + (electrons_in*frac/electrons_out**2.0)**2.0 * var_out

    # The binning slider works in log10 of the bin width, starting at the
    # native spacing of the first two points.
    sliderWbin = Slider(title="binning", value=np.log10(x[1]-x[0]), start=np.log10(x[1]-x[0]), end=np.log10(max(x)/2.0), step=.05, callback=callback)
    callback.args["wbin"] = sliderWbin
    sliderTrans = Slider(title="Num Trans", value=noccultations, start=1, end=50, step=1, callback=callback)
    callback.args["ntran"] = sliderTrans
    layout = column(row(sliderWbin, sliderTrans), plot_spectrum)

    # out of transit 2d output
    out = result_dict['PandeiaOutTrans']

    # Flux 1d. Compute the NaN mask once and apply it to both arrays so
    # x and y stay aligned.
    x, y = out['1d']['extracted_flux']
    keep = ~np.isnan(y)
    x = x[keep]
    y = y[keep]

    plot_flux_1d1 = Figure(tools=TOOLS,
                           x_axis_label='Wavelength [microns]',
                           y_axis_label='Flux (e/s)', title="Out of Transit Flux Rate",
                           plot_width=800, plot_height=300)
    plot_flux_1d1.line(x, y, line_width=4, alpha=.7)
    tab1 = Panel(child=plot_flux_1d1, title="Total Flux")

    # BG 1d. The mask must come from the unfiltered y. Filtering y first and
    # then masking x with it gives a length mismatch.
    x, y = out['1d']['extracted_bg_only']
    keep = ~np.isnan(y)
    x = x[keep]
    y = y[keep]
    plot_bg_1d1 = Figure(tools=TOOLS,
                         x_axis_label='Wavelength [microns]',
                         y_axis_label='Flux (e/s)', title="Background",
                         plot_width=800, plot_height=300)
    plot_bg_1d1.line(x, y, line_width=4, alpha=.7)
    tab2 = Panel(child=plot_bg_1d1, title="Background Flux")

    # SNR 1d accounting for number of occultations
    x = out['1d']['sn'][0]
    y = out['1d']['sn'][1]
    keep = ~np.isnan(y)
    x = x[keep]
    y = y[keep]
    y = y*np.sqrt(noccultations)
    plot_snr_1d1 = Figure(tools=TOOLS,
                          x_axis_label=x_axis_label,
                          y_axis_label='SNR', title="Pandeia SNR",
                          plot_width=800, plot_height=300)
    plot_snr_1d1.line(x, y, line_width=4, alpha=.7)
    tab3 = Panel(child=plot_snr_1d1, title="SNR")

    # Error bars (ppm)
    x = result_dict['FinalSpectrum']['wave']
    y = result_dict['FinalSpectrum']['error_w_floor']*1e6
    keep = ~np.isnan(y)
    x = x[keep]
    y = y[keep]
    ymed = np.median(y)

    plot_noise_1d1 = Figure(tools=TOOLS,  # responsive=True,
                            x_axis_label=x_axis_label,
                            y_axis_label='Error on Spectrum (PPM)', title="Error Curve",
                            plot_width=800, plot_height=300, y_range=[0, 2.0*ymed])
    plot_noise_1d1.circle(x, y, line_width=4, alpha=.7)
    tab4 = Panel(child=plot_noise_1d1, title="Error")

    # Not happy? Need help picking a different mode?
    plot_spectrum2 = Figure(plot_width=800, plot_height=300, x_range=xlims, y_range=ylims, tools=TOOLS,
                            x_axis_label=x_axis_label,
                            y_axis_label=punit, title="Original Model", y_axis_type="log")

    plot_spectrum2.line(result_dict['OriginalInput']['model_wave'], result_dict['OriginalInput']['model_spec'],
                        line_width=4, alpha=.7)
    tab5 = Panel(child=plot_spectrum2, title="Original Model")

    # create set of five tabs
    tabs1d = Tabs(tabs=[tab1, tab2, tab3, tab4, tab5])

    # Detector 2d
    data = out['2d']['detector']
    xr, yr = data.shape

    plot_detector_2d = Figure(tools="pan,wheel_zoom,box_zoom,resize,reset,hover,save",
                              x_range=[0, yr], y_range=[0, xr],
                              x_axis_label='Pixel', y_axis_label='Spatial',
                              title="2D Detector Image",
                              plot_width=800, plot_height=300)

    plot_detector_2d.image(image=[data], x=[0], y=[0], dh=[xr], dw=[yr],
                           palette="Spectral11")

    # 2d tabs

    # 2d snr. Work on a copy so zeroing the infinities does not modify the
    # caller's result_dict in place.
    data = np.array(out['2d']['snr'])
    data[np.isinf(data)] = 0.0
    xr, yr = data.shape
    plot_snr_2d = Figure(tools=TOOLS,
                         x_range=[0, yr], y_range=[0, xr],
                         x_axis_label='Pixel', y_axis_label='Spatial',
                         title="Signal-to-Noise Ratio",
                         plot_width=800, plot_height=300)

    plot_snr_2d.image(image=[data], x=[0], y=[0], dh=[xr], dw=[yr],
                      palette="Spectral11")

    tab1b = Panel(child=plot_snr_2d, title="SNR")

    # saturation
    data = out['2d']['saturation']
    xr, yr = data.shape
    plot_sat_2d = Figure(tools=TOOLS,
                         x_range=[0, yr], y_range=[0, xr],
                         x_axis_label='Pixel', y_axis_label='Spatial',
                         title="Saturation",
                         plot_width=800, plot_height=300)

    plot_sat_2d.image(image=[data], x=[0], y=[0], dh=[xr], dw=[yr],
                      palette="Spectral11")

    tab2b = Panel(child=plot_sat_2d, title="Saturation")

    tabs2d = Tabs(tabs=[tab1b, tab2b])

    result_comp = components({'plot_spectrum': layout,
                              'tabs1d': tabs1d, 'det_2d': plot_detector_2d,
                              'tabs2d': tabs2d})

    return result_comp
Exemple #23
0
def plot_cross_section_bokeh(filename, map_data_all_slices, map_depth_all_slices,
                             color_range_all_slices, cross_data, boundary_data,
                             style_parameter):
    '''
    Plot shear velocity map views and a cross-section using bokeh, saved as HTML.

    The left column holds a depth slider, the map view of the selected depth
    slice (with boundaries, a depth label and the A-B profile line) and its
    colorbar. The slider switches slices client-side through a CustomJS
    callback. The right column holds the A-B cross-section image and its
    colorbar.

    Input:
        filename is the filename of the resulting html file
        map_data_all_slices is a list of map-view velocity slices; slice i is
            converted to RGBA colors using color_range_all_slices[i]
        map_depth_all_slices is a list of depths (km) of the map-view slices
        color_range_all_slices is a list of (vmin, vmax) color ranges, one per slice
        cross_data is the cross-section velocity array; its first axis is depth,
            sampled every style_parameter['cross_ddepth']
        boundary_data is a dict of boundaries keyed by 'country', 'marine',
            'shoreline' and 'state', each holding 'longitude' and 'latitude' lists
        style_parameter contains parameters to customize the plots

    Output:
        None; the layout is written to filename with bokeh's output_file/save.
    '''
    xlabel_fontsize = style_parameter['xlabel_fontsize']
    map_view_default_index = style_parameter['map_view_default_index']
    # ------------------------------
    # per-slice colorbar bins and depth labels
    colorbar_data_all_left = []
    colorbar_data_all_right = []
    map_view_ndepth = style_parameter['map_view_ndepth']
    palette_r = palette[::-1]
    ncolor = len(palette_r)
    colorbar_top = [0.1] * ncolor
    colorbar_bottom = [0] * ncolor
    map_data_all_slices_depth = []
    for idepth in range(map_view_ndepth):
        color_min = color_range_all_slices[idepth][0]
        color_max = color_range_all_slices[idepth][1]
        color_step = (color_max - color_min)*1./ncolor
        # one colorbar bin per palette entry: [left_k, right_k)
        colorbar_left = np.linspace(color_min, color_max-color_step, ncolor)
        colorbar_right = np.linspace(color_min+color_step, color_max, ncolor)
        colorbar_data_all_left.append(colorbar_left)
        colorbar_data_all_right.append(colorbar_right)
        map_depth = map_depth_all_slices[idepth]
        map_data_all_slices_depth.append('Depth: {0:8.0f} km'.format(map_depth))
    # colorbar shown initially plus the full per-slice table the JS callback reads from
    colorbar_data_one_slice_bokeh = ColumnDataSource(data=dict(colorbar_top=colorbar_top, colorbar_bottom=colorbar_bottom,
                                                               colorbar_left=colorbar_data_all_left[map_view_default_index],
                                                               colorbar_right=colorbar_data_all_right[map_view_default_index],
                                                               palette_r=palette_r))
    colorbar_data_all_slices_bokeh = ColumnDataSource(data=dict(colorbar_data_all_left=colorbar_data_all_left,
                                                                colorbar_data_all_right=colorbar_data_all_right))
    # ------------------------------
    # depth label text shown on the map
    map_view_label_lon = style_parameter['map_view_depth_label_lon']
    map_view_label_lat = style_parameter['map_view_depth_label_lat']
    map_data_one_slice_depth = map_data_all_slices_depth[map_view_default_index]
    map_data_one_slice_depth_bokeh = ColumnDataSource(data=dict(lat=[map_view_label_lat], lon=[map_view_label_lon],
                                                           map_depth=[map_data_one_slice_depth]))
    # ------------------------------
    # map-view images: every slice pre-rendered to RGBA with its own color range
    map_color_all_slices = []
    for i in range(len(map_data_all_slices)):
        vmin, vmax = color_range_all_slices[i]
        map_color = val_to_rgb(map_data_all_slices[i], palette_r, vmin, vmax)
        map_color_all_slices.append(map_color)
    map_color_one_slice = map_color_all_slices[map_view_default_index]
    map_data_one_slice_bokeh = ColumnDataSource(data=dict(x=[style_parameter['map_view_image_lon_min']],
                   y=[style_parameter['map_view_image_lat_min']], dw=[style_parameter['nlon']],
                   dh=[style_parameter['nlat']], map_data_one_slice=[map_color_one_slice]))
    map_data_all_slices_bokeh = ColumnDataSource(data=dict(map_data_all_slices=map_color_all_slices,
                                                           map_data_all_slices_depth=map_data_all_slices_depth))
    # ------------------------------
    # cross-section image along the A-B great arc
    plot_depth = np.shape(cross_data)[0] * style_parameter['cross_ddepth']
    plot_lon = great_arc_distance(style_parameter['cross_default_lat0'], style_parameter['cross_default_lon0'],
                                  style_parameter['cross_default_lat1'], style_parameter['cross_default_lon1'])

    # Color the cross-section with its own fixed velocity range, not with a
    # map slice's range. This keeps the image consistent with the cross-section
    # colorbar, which is built from the same cross_view_vs_min/max below.
    vs_min = style_parameter['cross_view_vs_min']
    vs_max = style_parameter['cross_view_vs_max']
    # NOTE(review): the image uses palette_r while its colorbar uses my_palette;
    # confirm the two palettes are equivalent.
    cross_color = val_to_rgb(cross_data, palette_r, vs_min, vs_max)
    cross_data_bokeh = ColumnDataSource(data=dict(x=[0],
                   y=[plot_depth], dw=[plot_lon],
                   dh=[plot_depth], cross_data=[cross_color]))

    map_line_bokeh = ColumnDataSource(data=dict(lat=[style_parameter['cross_default_lat0'], style_parameter['cross_default_lat1']],
                                                    lon=[style_parameter['cross_default_lon0'], style_parameter['cross_default_lon1']]))
    # ------------------------------
    # cross-section colorbar bins
    ncolor_cross = len(my_palette)
    colorbar_top_cross = [0.1] * ncolor_cross
    colorbar_bottom_cross = [0] * ncolor_cross
    color_min_cross = style_parameter['cross_view_vs_min']
    color_max_cross = style_parameter['cross_view_vs_max']
    color_step_cross = (color_max_cross - color_min_cross)*1./ncolor_cross
    colorbar_left_cross = np.linspace(color_min_cross, color_max_cross-color_step_cross, ncolor_cross)
    colorbar_right_cross = np.linspace(color_min_cross+color_step_cross, color_max_cross, ncolor_cross)
    # ==============================
    map_view = Figure(plot_width=style_parameter['map_view_plot_width'], plot_height=style_parameter['map_view_plot_height'],
                      tools=style_parameter['map_view_tools'], title=style_parameter['map_view_title'],
                      y_range=[style_parameter['map_view_figure_lat_min'], style_parameter['map_view_figure_lat_max']],
                      x_range=[style_parameter['map_view_figure_lon_min'], style_parameter['map_view_figure_lon_max']])
    map_view.image_rgba('map_data_one_slice', x='x',
                   y='y', dw='dw', dh='dh',
                   source=map_data_one_slice_bokeh, level='image')

    # Runs in the browser: swaps the displayed slice, colorbar bins and depth label.
    depth_slider_callback = CustomJS(args=dict(map_data_one_slice_bokeh=map_data_one_slice_bokeh,
                                               map_data_all_slices_bokeh=map_data_all_slices_bokeh,
                                               colorbar_data_all_slices_bokeh=colorbar_data_all_slices_bokeh,
                                               colorbar_data_one_slice_bokeh=colorbar_data_one_slice_bokeh,
                                               map_data_one_slice_depth_bokeh=map_data_one_slice_depth_bokeh), code="""

        var d_index = Math.round(cb_obj.value)
        
        var map_data_all_slices = map_data_all_slices_bokeh.data
        
        map_data_one_slice_bokeh.data['map_data_one_slice'] = [map_data_all_slices['map_data_all_slices'][d_index]]
        map_data_one_slice_bokeh.change.emit()
        
        var color_data_all_slices = colorbar_data_all_slices_bokeh.data
        colorbar_data_one_slice_bokeh.data['colorbar_left'] = color_data_all_slices['colorbar_data_all_left'][d_index]
        colorbar_data_one_slice_bokeh.data['colorbar_right'] = color_data_all_slices['colorbar_data_all_right'][d_index]
        colorbar_data_one_slice_bokeh.change.emit()
        
        map_data_one_slice_depth_bokeh.data['map_depth'] = [map_data_all_slices['map_data_all_slices_depth'][d_index]]
        map_data_one_slice_depth_bokeh.change.emit()
        
    """)
    depth_slider = Slider(start=0, end=style_parameter['map_view_ndepth']-1,
                          value=map_view_default_index, step=1,
                          width=style_parameter['map_view_plot_width'],
                          title=style_parameter['depth_slider_title'], height=50,
                          callback=depth_slider_callback)
    # ------------------------------
    # add boundaries to map view
    # country boundaries
    map_view.multi_line(boundary_data['country']['longitude'],
                        boundary_data['country']['latitude'], color='black',
                        line_width=2, level='underlay', nonselection_line_alpha=1.0,
                        nonselection_line_color='black')
    # marine boundaries
    map_view.multi_line(boundary_data['marine']['longitude'],
                        boundary_data['marine']['latitude'], color='black',
                        level='underlay', nonselection_line_alpha=1.0,
                        nonselection_line_color='black')
    # shoreline boundaries
    map_view.multi_line(boundary_data['shoreline']['longitude'],
                        boundary_data['shoreline']['latitude'], color='black',
                        line_width=2, level='underlay', nonselection_line_alpha=1.0,
                        nonselection_line_color='black')
    # state boundaries
    map_view.multi_line(boundary_data['state']['longitude'],
                        boundary_data['state']['latitude'], color='black',
                        level='underlay', nonselection_line_alpha=1.0,
                        nonselection_line_color='black')
    # ------------------------------
    # add depth label (white box in screen units, then the text)
    map_view.rect(style_parameter['map_view_depth_box_lon'], style_parameter['map_view_depth_box_lat'],
                  width=style_parameter['map_view_depth_box_width'], height=style_parameter['map_view_depth_box_height'],
                  width_units='screen', height_units='screen', color='#FFFFFF', line_width=1., line_color='black', level='underlay')
    map_view.text('lon', 'lat', 'map_depth', source=map_data_one_slice_depth_bokeh,
                  text_font_size=style_parameter['annotating_text_font_size'], text_align='left', level='underlay')
    # ------------------------------
    # A-B profile line and its end labels
    map_view.line('lon', 'lat', source=map_line_bokeh, line_dash=[8, 2, 8, 2], line_color='#00ff00',
                        nonselection_line_alpha=1.0, line_width=5.,
                        nonselection_line_color='black')
    map_view.text([style_parameter['cross_default_lon0']], [style_parameter['cross_default_lat0']], ['A'],
            text_font_size=style_parameter['title_font_size'], text_align='left')
    map_view.text([style_parameter['cross_default_lon1']], [style_parameter['cross_default_lat1']], ['B'],
            text_font_size=style_parameter['title_font_size'], text_align='left')
    # ------------------------------
    # change style
    map_view.title.text_font_size = style_parameter['title_font_size']
    map_view.title.align = 'center'
    map_view.title.text_font_style = 'normal'
    map_view.xaxis.axis_label = style_parameter['map_view_xlabel']
    map_view.xaxis.axis_label_text_font_style = 'normal'
    map_view.xaxis.axis_label_text_font_size = xlabel_fontsize
    map_view.xaxis.major_label_text_font_size = xlabel_fontsize
    map_view.yaxis.axis_label = style_parameter['map_view_ylabel']
    map_view.yaxis.axis_label_text_font_style = 'normal'
    map_view.yaxis.axis_label_text_font_size = xlabel_fontsize
    map_view.yaxis.major_label_text_font_size = xlabel_fontsize
    map_view.xgrid.grid_line_color = None
    map_view.ygrid.grid_line_color = None
    map_view.toolbar.logo = None
    map_view.toolbar_location = 'above'
    map_view.toolbar_sticky = False
    # ==============================
    # plot map-view colorbar
    colorbar_fig = Figure(tools=[], y_range=(0, 0.1), plot_width=style_parameter['map_view_plot_width'],
                      plot_height=style_parameter['colorbar_plot_height'], title=style_parameter['colorbar_title'])
    colorbar_fig.toolbar_location = None
    colorbar_fig.quad(top='colorbar_top', bottom='colorbar_bottom', left='colorbar_left', right='colorbar_right',
                  color='palette_r', source=colorbar_data_one_slice_bokeh)
    colorbar_fig.yaxis[0].ticker = FixedTicker(ticks=[])
    colorbar_fig.xgrid.grid_line_color = None
    colorbar_fig.ygrid.grid_line_color = None
    colorbar_fig.xaxis.axis_label_text_font_size = xlabel_fontsize
    colorbar_fig.xaxis.major_label_text_font_size = xlabel_fontsize
    colorbar_fig.xaxis[0].formatter = PrintfTickFormatter(format="%5.2f")
    colorbar_fig.title.text_font_size = xlabel_fontsize
    colorbar_fig.title.align = 'center'
    colorbar_fig.title.text_font_style = 'normal'
    # ==============================
    # annotating text
    annotating_fig01 = Div(text=style_parameter['annotating_html01'],
        width=style_parameter['annotation_plot_width'], height=style_parameter['annotation_plot_height'])
    annotating_fig02 = Div(text="""<p style="font-size:16px">""",
        width=style_parameter['annotation_plot_width'], height=style_parameter['annotation_plot_height'])
    # ==============================
    # plot cross-section along the A-B great arc; width follows the
    # depth/distance aspect of the section
    cross_section_plot_width = int(style_parameter['cross_plot_height']*1.0/plot_depth*plot_lon/10.)
    cross_view = Figure(plot_width=cross_section_plot_width, plot_height=style_parameter['cross_plot_height'],
                      tools=style_parameter['cross_view_tools'], title=style_parameter['cross_view_title'],
                      y_range=[plot_depth, -30],
                      x_range=[0, plot_lon])
    cross_view.image_rgba('cross_data', x='x',
                   y='y', dw='dw', dh='dh',
                   source=cross_data_bokeh, level='image')
    cross_view.text([plot_lon*0.1], [-10], ['A'],
                text_font_size=style_parameter['title_font_size'], text_align='left', level='underlay')
    cross_view.text([plot_lon*0.9], [-10], ['B'],
                text_font_size=style_parameter['title_font_size'], text_align='left', level='underlay')
    # ------------------------------
    # change style
    cross_view.title.text_font_size = style_parameter['title_font_size']
    cross_view.title.align = 'center'
    cross_view.title.text_font_style = 'normal'
    cross_view.xaxis.axis_label = style_parameter['cross_view_xlabel']
    cross_view.xaxis.axis_label_text_font_style = 'normal'
    cross_view.xaxis.axis_label_text_font_size = xlabel_fontsize
    cross_view.xaxis.major_label_text_font_size = xlabel_fontsize
    cross_view.yaxis.axis_label = style_parameter['cross_view_ylabel']
    cross_view.yaxis.axis_label_text_font_style = 'normal'
    cross_view.yaxis.axis_label_text_font_size = xlabel_fontsize
    cross_view.yaxis.major_label_text_font_size = xlabel_fontsize
    cross_view.xgrid.grid_line_color = None
    cross_view.ygrid.grid_line_color = None
    cross_view.toolbar.logo = None
    cross_view.toolbar_location = 'right'
    cross_view.toolbar_sticky = False
    # ==============================
    # plot cross-section colorbar (static, fixed vs range)
    colorbar_fig_right = Figure(tools=[], y_range=(0, 0.1), plot_width=cross_section_plot_width,
                      plot_height=style_parameter['colorbar_plot_height'], title=style_parameter['colorbar_title'])
    colorbar_fig_right.toolbar_location = None
    colorbar_fig_right.quad(top=colorbar_top_cross, bottom=colorbar_bottom_cross,
                            left=colorbar_left_cross, right=colorbar_right_cross,
                            color=my_palette)
    colorbar_fig_right.yaxis[0].ticker = FixedTicker(ticks=[])
    colorbar_fig_right.xgrid.grid_line_color = None
    colorbar_fig_right.ygrid.grid_line_color = None
    colorbar_fig_right.xaxis.axis_label_text_font_size = xlabel_fontsize
    colorbar_fig_right.xaxis.major_label_text_font_size = xlabel_fontsize
    colorbar_fig_right.xaxis[0].formatter = PrintfTickFormatter(format="%5.2f")
    colorbar_fig_right.title.text_font_size = xlabel_fontsize
    colorbar_fig_right.title.align = 'center'
    colorbar_fig_right.title.text_font_style = 'normal'
    # ==============================
    # assemble the two columns and write the HTML file
    output_file(filename, title=style_parameter['html_title'], mode=style_parameter['library_source'])
    left_column = Column(depth_slider, map_view, colorbar_fig, annotating_fig01, width=style_parameter['left_column_width'])

    right_column = Column(annotating_fig02, cross_view, colorbar_fig_right, width=cross_section_plot_width)
    layout = Row(left_column, right_column, height=800)
    save(layout)
Exemple #24
0
    return {'year': tf.year, 'area': tf.area, 'region': tf.region}


# Shared data source for both time-series figures; '감자' is the crop name (potato).
source = ColumnDataSource(data=get_data('감자'))

# Figure construction.
# p1: cultivated area over time (title: "cultivated area time series by crop"),
# one line per entry in the source's 'year'/'area' list columns.
p1 = Figure(plot_height=300,
            plot_width=1000,
            x_axis_label='Year',
            y_axis_label='Area',
            toolbar_location='above',
            title='작물별 재배면적 시계열')
# NOTE(review): 'color' is used as a column name here; confirm the data
# source actually provides a 'color' column.
p1.multi_line('year',
              'area',
              alpha=1,
              color='color',
              legend='region',
              source=source)
# p2: production over time (title: "production time series by crop").
p2 = Figure(plot_height=300,
            plot_width=1000,
            x_axis_label='Year',
            y_axis_label='Production',
            toolbar_location='above',
            title='작물별 생산량 시계열')
# NOTE(review): the get_data return visible earlier in this file has only
# 'year', 'area' and 'region' keys; confirm the source also provides
# 'product' (and 'color') or this renderer has no data.
p2.multi_line('year',
              'product',
              alpha=1,
              color='color',
              legend='region',
              source=source)
p3 = Figure(plot_height=300,
Exemple #25
0
# Set up data
# N: index of the first waveform row displayed (the source holds rows N..N+49).
N = 0
# x-axis for waveforms downsampled by 10. np.arange with a float stop yields
# ceil(len/10) points, matching the length of a `[::10]` slice.
x = np.arange(len(plot_data[N]) / 10)
# NOTE(review): `y` is not used by the statements shown here; presumably it is
# read by the update callback. Confirm it is still needed.
y = spike_waveforms[0, ::10]
# 50 downsampled waveforms sharing one x-axis. The indexing assumes plot_data is
# at least 2-D (rows = waveforms) with at least N + 50 rows.
source = ColumnDataSource(data=dict(
    xs=[x for i in range(50)], ys=[plot_data[N + i, ::10] for i in range(50)]))

# Set up plot
# NOTE(review): the 'resize' tool does not exist in newer Bokeh releases;
# confirm the Bokeh version this script targets.
plot = Figure(plot_height=400,
              plot_width=400,
              title="Unit waveforms",
              tools="crosshair,pan,reset,resize,save,wheel_zoom",
              x_range=[0, 45],
              y_range=[-200, 200])

plot.multi_line('xs', 'ys', source=source, line_width=1, line_alpha=1.0)

# Set up widgets
# text = TextInput(title="title", value='my sine wave')
offset = Slider(
    title="offset", value=0, start=0, end=50000, step=100
)  # put the end of the slider at a large enough value so that almost all cluster sizes will fit in
# Text inputs are strings; callers are expected to convert them to ints.
electrode = TextInput(title='Electrode Number', value='0')
clusters = TextInput(title='Number of clusters', value='2')
cluster_num = TextInput(title='Cluster Number', value='0')
#amplitude = Slider(title="amplitude", value=1.0, start=-5.0, end=5.0)
#phase = Slider(title="phase", value=0.0, start=0.0, end=2*np.pi)
#freq = Slider(title="frequency", value=1.0, start=0.1, end=5.1)


def update_data(attrname, old, new):
Exemple #26
0
def hst_time(result_dict, plot=True, output_file='hsttime.html', model=True):
    """Plot earliest and latest start times for hst observation

    Draws two side-by-side figures (earliest and latest start) of the
    simulated white-light curve with vertical error bars and, optionally,
    the transit model underneath.

    Parameters
    ----------
    result_dict : dict
        Dictionary from pandexo output. Values are read from its
        ``'calc_start_window'`` entry.
    plot : bool
        (Optional) True renders plot, False does not. Default=True
    output_file : str
        (Optional) Name of the html output file. Default = 'hsttime.html'
    model : bool
        (Optional) Plot model under data. Default=True

    Returns
    -------
    obsphase1 : numpy array
        earliest start time
    obstr1 : numpy array
        white light curve
    obsphase2 : numpy array
        latest start time
    obstr2 : numpy array
        white light curve
    rms : numpy array
        1D rms noise (a scalar input is expanded to one value per point
        of ``obsphase1``)

    See Also
    --------
    hst_spec
    """
    TOOLS = "pan,wheel_zoom,box_zoom,reset,save"
    window = result_dict['calc_start_window']
    # earliest and latest start times
    obsphase1 = window['obsphase1']
    obstr1 = window['obstr1']
    rms = window['light_curve_rms']
    obsphase2 = window['obsphase2']
    obstr2 = window['obstr2']
    phase1 = window['phase1']
    phase2 = window['phase2']
    trmodel1 = window['trmodel1']
    trmodel2 = window['trmodel2']

    # A scalar rms (Python or NumPy, int or float) applies to every point.
    # Checking ndim rather than `isinstance(rms, float)` also covers ints and
    # 0-d arrays, which would otherwise crash in zip() below.
    if np.ndim(rms) == 0:
        rms = np.zeros(len(obsphase1)) + rms

    def error_bars(xs, ys, errs):
        """Return (x_segments, y_segments) of vertical error bars for multi_line."""
        points = list(zip(xs, ys, errs))
        x_segments = [(px, px) for px, _, _ in points]
        y_segments = [(py - pe, py + pe) for _, py, pe in points]
        return x_segments, y_segments

    x_err1, y_err1 = error_bars(obsphase1, obstr1, rms)
    x_err2, y_err2 = error_bars(obsphase2, obstr2, rms)

    early = Figure(
        plot_width=400,
        plot_height=300,
        tools=TOOLS,  #responsive=True,
        x_axis_label='Orbital Phase',
        y_axis_label='Flux',
        title="Earliest Start Time")

    if model:
        early.line(phase1, trmodel1, color='black', alpha=0.5, line_width=4)
    early.circle(obsphase1, obstr1, line_width=3, line_alpha=0.6)
    early.multi_line(x_err1, y_err1)

    late = Figure(
        plot_width=400,
        plot_height=300,
        tools=TOOLS,  #responsive=True,
        x_axis_label='Orbital Phase',
        y_axis_label='Flux',
        title="Latest Start Time")
    if model:
        # NOTE(review): the model line here is width 3 vs 4 on the "early"
        # figure; confirm whether the difference is intentional.
        late.line(phase2, trmodel2, color='black', alpha=0.5, line_width=3)
    late.circle(obsphase2, obstr2, line_width=3, line_alpha=0.6)
    late.multi_line(x_err2, y_err2)

    start_time = row(early, late)

    if plot:
        # `outputfile` is the module's alias for Bokeh's output_file; the
        # parameter named output_file shadows the original name here.
        outputfile(output_file)
        show(start_time)

    return obsphase1, obstr1, obsphase2, obstr2, rms
          line_width=2,
          line_dash='dashed',
          line_alpha=0.3)
# Dashed, semi-transparent guide line drawn from the f2 'wdline21' source.
plot.line(x='x2',
          y='y2',
          source=MBD.f2.wdline21,
          color=MBD.f2color,
          line_width=2,
          line_dash='dashed',
          line_alpha=0.3)
#EDIT End

#creation of the a and b scale reference things:
# Horizontal reference bar at y=0 from orig.x0 to orig.xf, with short vertical
# end ticks of half-height abshift at each end.
plot.multi_line(
    [[orig.x0, orig.xf], [orig.x0, orig.x0], [orig.xf, orig.xf]],
    [[0, 0], [0 - abshift, 0 + abshift], [0 - abshift, 0 + abshift]],
    color=["black", "black", "black"],
    line_width=1)
# Vertical reference bar at x=xb from orig.y0 to orig.yf, with short horizontal
# end ticks of half-width abshift at each end.
plot.multi_line(
    [[xb, xb], [xb - abshift, xb + abshift], [xb - abshift, xb + abshift]],
    [[orig.y0, orig.yf], [orig.y0, orig.y0], [orig.yf, orig.yf]],
    color=["black", "black", "black"],
    line_width=1)

#Frame bases
# Grey triangle markers whose positions and sizes come from the `default` source's x/y/size columns.
plot.triangle(x='x',
              y='y',
              size='size',
              source=default,
              color="grey",
              line_width=2)
Exemple #28
0
def _jwst_rescaled_spectrum(dictt, x, R, num_tran):
    """Re-simulate one pandexo spectrum, optionally rebinned and/or rescaled.

    Rebuilds the transit/eclipse spectrum and its error from the raw in/out
    of transit electron counts and variances, then adds a fresh random noise
    draw.

    Args:
        dictt: a single pandexo result dictionary.
        x: wavelength axis with NaN spectrum points already removed.
        R: target resolution; falsy keeps the native wavelength grid.
        num_tran: new number of transits; falsy keeps the original count.

    Returns:
        (wave, spectrum, error) arrays on the (possibly rebinned) grid.
    """
    timing = dictt['timing']
    ntran_old = timing['Number of Transits']
    to = timing["Num Integrations Out of Transit"]
    ti = timing["Num Integrations In Transit"]
    raw = dictt['RawData']

    wave = bin_wave_to_R(x, R) if R else x

    def prepare(key):
        # Scale counts/variances to the new number of transits first, then bin.
        # NOTE(review): RawData arrays are not NaN-masked like `x`; this assumes
        # they share the FinalSpectrum grid with no NaN points. Confirm.
        values = raw[key]
        if num_tran:
            values = values * num_tran / ntran_old
        if R:
            values = uniform_tophat_sum(wave, x, values)
        return values

    out = prepare('electrons_out')
    inn = prepare('electrons_in')
    vout = prepare('var_out')
    vin = prepare('var_in')

    var_tot = (to / ti / out)**2.0 * vin + (inn * to / ti /
                                            out**2.0)**2.0 * vout
    # Planet/star flux ratio spectra are reported with the opposite sign.
    fac = -1.0 if dictt['input']['Primary/Secondary'] == 'fp/f*' else 1.0
    rand_noise = np.sqrt((var_tot)) * (np.random.randn(len(wave)))
    raw_spec = (out / to - inn / ti) / (out / to)
    sim_spec = fac * (raw_spec + rand_noise)
    return wave, sim_spec, np.sqrt(var_tot)


def jwst_1d_spec(result_dict,
                 model=True,
                 title='Model + Data + Error Bars',
                 output_file='data.html',
                 legend=False,
                 R=False,
                 num_tran=False,
                 plot_width=800,
                 plot_height=400,
                 x_range=[1, 10],
                 y_range=None,
                 plot=True,
                 output_notebook=False):
    """Plots 1d simulated spectrum and rebin or rescale for more transits

    Plots 1d data points with model in the background (if wanted). Designed to read in exact
    output of run_pandexo.

    Parameters
    ----------
    result_dict : dict or list of dict
        Dictionary from pandexo output. If parameter space was run in run_pandexo
        make sure to restructure the input as a list of dictionaries without the key words
        that run_pandexo assigns.
    model : bool
        (Optional) True is default. True plots model, False does not plot model
    title : str
        (Optional) Title of plot. Default is "Model + Data + Error Bars".
    output_file : str
        (Optional) name of html file for your bokeh plot. After bokeh plot is rendered you will
        have the option to save as png.
    legend : bool, str or list of str
        (Optional) Default is False. True plots a legend with default labels
        ("Spectrum 1", "Spectrum 2", ...); a string or list of strings gives the
        labels, one per input dictionary.
    R : float
        (Optional) Rebin data from native instrument resolution to specified resolution. Default is False,
        no binning. Here I adopt R as w[1]/(w[2] - w[0]) to maintain consistency with `pandeia.engine`
    num_tran : float
        (Optional) Scales data by number of transits to improve error by sqrt(`num_trans`)
    plot_width : int
        (Optional) Sets the width of the plot. Default = 800
    plot_height : int
        (Optional) Sets the height of the plot. Default = 400
    y_range : list of int
        (Optional) sets y range of plot. Default pads the model's min and max
        outward by 10% of their magnitude.
    x_range : list of int
        (Optional) Sets x range of plot. Default = [1,10]. Ignored for
        phase curves, which use the data's time range.
    plot : bool
        (Optional) Suppresses the plot if not wanted (Default = True)
    output_notebook : bool
        (Optional) Output notebook. Default is false, if true, outputs in the notebook

    Returns
    -------
    x,y,e : list of arrays
        Returns wave axis, spectrum and associated error in list format. x[0] will be correspond
        to the first dictionary input, x[1] to the second, etc.

    Examples
    --------

    >>> jwst_1d_spec(result_dict, num_tran = 3, R = 35) #for a single plot

    If you wanted to save each of the axis that were being plotted:

    >>> x,y,e = jwst_1d_spec([result_dict1, result_dict2], model=False, num_tran = 5, R = 100) #for multiple

    See Also
    --------
    jwst_noise, jwst_1d_bkg, jwst_1d_flux, jwst_1d_snr, jwst_2d_det, jwst_2d_sat

    """
    outx = []
    outy = []
    oute = []
    TOOLS = "pan,wheel_zoom,box_zoom,reset,save"
    if output_notebook:
        outnotebook()
    else:
        outputfile(output_file)
    colors = [
        'black', 'blue', 'red', 'orange', 'yellow', 'purple', 'pink', 'cyan',
        'grey', 'brown'
    ]
    # make sure it's iterable
    if not isinstance(result_dict, list):
        result_dict = [result_dict]

    # legend may be a bool (default labels) or explicit label(s)
    legend_keys = []
    if not isinstance(legend, bool):
        legend_keys = legend if isinstance(legend, list) else [legend]
        legend = True

    fig1d = None
    for i, dictt in enumerate(result_dict):
        # colors cycle so more than len(colors) spectra do not fail
        color = colors[i % len(colors)]

        # remove any nans
        final = dictt['FinalSpectrum']
        y = final['spectrum_w_rand']
        keep = ~np.isnan(y)
        x = final['wave'][keep]
        err = final['error_w_floor'][keep]
        y = y[keep]

        # Rebin to R and/or rescale to num_tran transits; otherwise keep the
        # spectrum pandexo already simulated.
        if R or num_tran:
            x, y, err = _jwst_rescaled_spectrum(dictt, x, R, num_tran)

        # create error bars for Bokeh's multi_line and drop nans
        data = pd.DataFrame({'x': x, 'y': y, 'err': err}).dropna()
        x_err = [(px, px) for px in data['x']]
        y_err = [(py - pe, py + pe) for py, pe in zip(data['y'], data['err'])]

        # initialize Figure from the first dictionary
        if i == 0:
            # Define units for x and y axis
            y_axis_label = dictt['input']['Primary/Secondary']

            if dictt['input']['Calculation Type'] == 'phase_spec':
                x_axis_label = 'Time (secs)'
                x_range = [min(x), max(x)]
            else:
                x_axis_label = 'Wavelength [microns]'

            if y_range is not None:
                ylims = y_range
            else:
                # Pad outward by 10% of each bound's magnitude so negative
                # values are padded in the correct direction too.
                model_spec = dictt['OriginalInput']['model_spec']
                spec_min = min(model_spec)
                spec_max = max(model_spec)
                ylims = [
                    spec_min - 0.1 * abs(spec_min),
                    spec_max + 0.1 * abs(spec_max)
                ]

            fig1d = Figure(x_range=x_range,
                           y_range=ylims,
                           plot_width=plot_width,
                           plot_height=plot_height,
                           title=title,
                           x_axis_label=x_axis_label,
                           y_axis_label=y_axis_label,
                           tools=TOOLS,
                           background_fill_color='white')

        # plot model, data, and errors
        if model:
            mxx = dictt['OriginalInput']['model_wave']
            myy = dictt['OriginalInput']['model_spec']
            my = uniform_tophat_mean(x, mxx, myy)
            model_line = pd.DataFrame({'x': x, 'my': my}).dropna()
            fig1d.line(model_line['x'],
                       model_line['my'],
                       color=color,
                       alpha=0.2,
                       line_width=4)

        if legend:
            if i < len(legend_keys):
                label = legend_keys[i]
            else:
                label = 'Spectrum {}'.format(i + 1)
            fig1d.circle(data['x'],
                         data['y'],
                         color=color,
                         legend=label)
        else:
            fig1d.circle(data['x'], data['y'], color=color)
        outx.append(data['x'].values)
        outy.append(data['y'].values)
        oute.append(data['err'].values)
        fig1d.multi_line(x_err, y_err, color=color)
    # fig1d stays None only when result_dict was empty
    if plot and fig1d is not None:
        show(fig1d)
    return outx, outy, oute
def tab_learning():
    """Build the "Learning" tab.

    The tab lets the user choose a problem type (classification/regression),
    a ``.npy`` dataset from ``./np/<type>/``, a ``model_*`` factory from the
    ``classification`` / ``regression`` module and a set of hyper-parameters.
    Clicking "Run model" trains the selected model, streams the loss and
    accuracy curves into two figures, and saves the trained model under
    ``./model/<type>/<name>/``.

    Returns
    -------
    Panel
        Bokeh panel titled ``'Learning'`` containing the whole layout.
    """
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # NOTE(review): this session is never handed to Keras (no K.set_session),
    # and learning_handler resets the graph / Keras session before training,
    # so allow_growth may not apply to the trained models -- confirm.
    session = tf.Session(config=config)

    select_data = Select(title="Data:", value="", options=[])
    select_model = Select(title="Model script:", value="", options=[])

    data_descriptor = Paragraph(text=""" Data descriptor """, width=250, height=250)
    model_descriptor = Paragraph(text=""" Model descriptor """, width=250, height=250)

    select_grid = gridplot([[select_data, select_model], [data_descriptor, model_descriptor]])

    problem_type = RadioGroup(labels=["Classification", "Regression"], active=0)

    def problem_handler(new):
        """Refill the data/model drop-downs for radio index *new* (0 = classification, 1 = regression)."""
        if new == 0:
            select_data.options = glob.glob('./np/Classification/*.npy')
            select_model.options = [name for name in dir(classification) if 'model_' in name]
        elif new == 1:
            select_data.options = glob.glob('./np/Regression/*.npy')
            select_model.options = [name for name in dir(regression) if 'model_' in name]

    problem_type.on_click(problem_handler)

    learning_rate = TextInput(value="0.01", title="Learning rate")
    epoch_size = Slider(start=2, end=200, value=5, step=1, title="Epoch")
    batch_size = Slider(start=16, end=256, value=64, step=1, title="Batch")
    model_insert = TextInput(value="model", title="Model name")
    optimizer = Select(title="Optimizer:", value="", options=["SGD", "ADAM", "RMS"])

    hyper_param = gridplot([[learning_rate], [epoch_size], [batch_size], [optimizer], [model_insert]])

    # Loss figure: one line per series (train / validation), seeded with a dummy point.
    total_loss_src = ColumnDataSource(data=dict(xs=[[1], [1]],
                                                ys=[[1], [1]],
                                                label=[['Train loss'], ['Validation loss']],
                                                color=[['blue'], ['green']]))
    plot2 = Figure(plot_width=500, plot_height=300)
    plot2.multi_line('xs', 'ys', color='color', source=total_loss_src, line_width=3, line_alpha=0.6)
    plot2.add_tools(HoverTool(tooltips=[("loss type", "@label"), ("loss value", "$y")]))
    t = Title()
    t.text = 'Loss'
    plot2.title = t

    # Accuracy figure: validation accuracy (classification) or R^2 (regression) per epoch.
    acc_src = ColumnDataSource(data=dict(x=[1], y=[1], label=['R^2 score']))
    plot_acc = Figure(plot_width=500, plot_height=300, title="Accuracy")
    plot_acc.line('x', 'y', source=acc_src, line_width=3, line_alpha=0.7, color='red')
    plot_acc.add_tools(HoverTool(tooltips=[("Type ", "@label"), ("Accuracy value", "$y")]))
    acc_list = []

    notifier = Paragraph(text=""" Notification """, width=200, height=100)

    def learning_handler():
        """Load the selected dataset, build and train the selected model, then save it."""
        print("Start learning")

        # Validate the UI state up front so later steps cannot hit unbound names.
        if not select_data.value or not select_model.value:
            notifier.text = """ Select a data file and a model script first """
            return
        if problem_type.active == 0:
            sub_path, model_module = 'Classification/', classification
        elif problem_type.active == 1:
            sub_path, model_module = 'Regression/', regression
        else:
            notifier.text = """ Select a problem type first """
            return
        is_classification = problem_type.active == 0
        try:
            lr = float(learning_rate.value)
        except ValueError:
            notifier.text = """ Learning rate must be a number """
            return

        del acc_list[:]

        tf.reset_default_graph()
        K.clear_session()

        # The .npy file holds a pickled dict read below via .get('x') / .get('y').
        # NumPy >= 1.16.3 refuses pickled objects unless allow_pickle=True.
        # SECURITY: unpickling can execute code -- only load trusted files.
        data = np.load(select_data.value, allow_pickle=True).item()

        print("data load complete")

        all_x = data.get('x')
        all_y = data.get('y')
        time_window = all_x.shape[-2]
        model_name = '(' + str(time_window) + ')' + model_insert.value

        model_save_dir = './model/' + sub_path + model_name + '/'
        if not os.path.exists(model_save_dir):
            os.makedirs(model_save_dir)

        x_shape = list(all_x.shape)
        print("Optimizer: " + str(optimizer.value))

        print(select_model.value)

        target_model = getattr(model_module, select_model.value)
        model = target_model(x_shape[-3], x_shape[-2], lr, str(optimizer.value), all_y.shape[-1])

        notifier.text = """ get model """

        training_epochs = int(epoch_size.value)
        batch = int(batch_size.value)
        loss_train = []
        loss_val = []

        # First 80 % of the samples train the model, the remainder validates it.
        train_ratio = 0.8
        split = int(all_x.shape[0] * train_ratio)

        print(all_x.shape)

        data_descriptor.text = "Data shape: " + str(all_x.shape)

        # Non-CNN models get a trailing singleton axis squeezed away.
        keep_last_axis = 'cnn' in select_model.value

        val_x = all_x[split:]
        if val_x.shape[-1] == 1 and not keep_last_axis:
            val_x = np.squeeze(val_x, -1)
        val_y = all_y[split:]

        train_x = all_x[:split]
        if train_x.shape[-1] == 1 and not keep_last_axis:
            train_x = np.squeeze(train_x, -1)
        train_y = all_y[:split]

        print(train_x.shape)

        if 'model_dl' in select_model.value:
            for epoch in range(training_epochs):
                notifier.text = """ learning -- epoch: """ + str(epoch)

                # One epoch per fit() call so the plots can be refreshed in between.
                hist = model.fit(train_x,
                                 train_y,
                                 epochs=1,
                                 batch_size=batch,
                                 validation_data=(val_x, val_y),
                                 verbose=1)
                history = hist.history

                print("%d epoch's cost:  %f" % (epoch, history['loss'][0]))
                loss_train.append(history['loss'][0])
                loss_val.append(history['val_loss'][0])

                epochs_so_far = list(range(epoch + 1))
                total_loss_src.data['xs'] = [epochs_so_far, list(epochs_so_far)]
                total_loss_src.data['ys'] = [loss_train, loss_val]

                if is_classification:
                    # Older Keras names the metric 'val_acc', newer releases 'val_accuracy'.
                    acc_key = 'val_acc' if 'val_acc' in history else 'val_accuracy'
                    score = history[acc_key][0]
                    label_str = 'Class accuracy'
                else:
                    pred_y = model.predict(val_x)
                    score = r2_score(val_y, pred_y)
                    label_str = 'R^2 score'

                print("%d epoch's acc:  %f" % (epoch, score))
                # Clip negative scores (R^2 can drop below 0) at 0 for the plot.
                acc_list.append(np.max([score, 0]))

                # Replace .data in one assignment: updating the columns one at a
                # time would leave them with different lengths in between.
                acc_src.data = dict(x=list(epochs_so_far),
                                    y=list(acc_list),
                                    label=[label_str] * (epoch + 1))

                print(acc_src.data)

            notifier.text = """ learning complete """
            model.save(model_save_dir + model_name + '.h5')
            notifier.text = """ model save complete """

            K.clear_session()

        elif 'model_ml' in select_model.value:
            notifier.text = """ Machine learning model """

            if train_x.shape[-2] != 1:
                notifier.text = """ Data include more then one time-frame. \n\n Data will automatically be flatten"""

            # Flatten every sample into one feature vector: (samples, features).
            train_x = train_x.reshape([train_x.shape[0], -1])
            val_x = val_x.reshape([val_x.shape[0], -1])

            if is_classification:
                # Collapse the label vectors (presumably one-hot) to class indices.
                train_y = np.argmax(train_y, axis=-1).astype(float)

            print(train_x.shape)
            print(train_y.shape)

            model.fit(train_x, train_y)
            notifier.text = """ Training done """
            pred_y = model.predict(val_x)

            print(pred_y)

            # `with` guarantees the file is flushed and closed.
            with open(model_save_dir + model_name + '.sav', 'wb') as model_file:
                pickle.dump(model, model_file)
            notifier.text = """ Machine learning model saved """

    button_learning = Button(label="Run model")
    button_learning.on_click(learning_handler)

    learning_grid = gridplot(
        [[problem_type],
         [select_grid, hyper_param, button_learning, notifier],
         [plot2, plot_acc]])

    tab = Panel(child=learning_grid, title='Learning')

    return tab
Exemple #30
0
def _jwst_image_figure(data, title, tools):
    """Return an 800x300 Figure rendering the 2-D array *data* as an image.

    The x axis spans ``data.shape[1]`` and is labelled 'Pixel'; the y axis
    spans ``data.shape[0]`` and is labelled 'Spatial'.
    """
    nrows, ncols = data.shape
    fig = Figure(tools=tools,
                 x_range=[0, ncols], y_range=[0, nrows],
                 x_axis_label='Pixel', y_axis_label='Spatial',
                 title=title,
                 plot_width=800, plot_height=300)
    fig.image(image=[data], x=[0], y=[0], dh=[nrows], dw=[ncols],
              palette="Spectral11")
    return fig


def create_component_jwst(result_dict):
    """Generate front end plots JWST

    Function that is responsible for generating the front-end interactive plots for JWST.
    ``result_dict`` itself is not modified.

    Parameters
    ----------
    result_dict : dict
        the dictionary returned from a PandExo run

    Returns
    -------
    tuple
        A tuple containing `(script, div)`, where the `script` is the
        front-end javascript required, and `div` is a dictionary of plot
        objects.
    """
    timing = result_dict['timing']
    noccultations = timing['Number of Transits']
    out = result_dict['PandeiaOutTrans']
    raw = result_dict['RawData']
    final = result_dict['FinalSpectrum']
    model_wave = result_dict['OriginalInput']['model_wave']
    model_spec = result_dict['OriginalInput']['model_spec']

    # select the tools we want
    TOOLS = "pan,wheel_zoom,box_zoom,resize,reset,save"

    # Units of the y axis. For 'fp/f*' the binning callback below multiplies
    # the simulated depth by p = -1 instead of +1.
    punit = result_dict['input']['Primary/Secondary']
    p = -1.0 if punit == 'fp/f*' else 1.0

    if result_dict['input']['Calculation Type'] == 'phase_spec':
        x_axis_label = 'Time (secs)'
        frac = 1.0
    else:
        x_axis_label = 'Wavelength [microns]'
        # float() avoids integer division if both counts are ints on Python 2.
        frac = float(timing['Num Integrations Out of Transit']) / timing['Num Integrations In Transit']

    electrons_out = raw['electrons_out']
    electrons_in = raw['electrons_in']
    var_in = raw['var_in']
    var_out = raw['var_out']

    x = final['wave']
    y = final['spectrum_w_rand']
    err = final['error_w_floor']

    # Endpoints of the vertical error bars drawn with multi_line.
    points = list(zip(x, y, err))
    x_err = [(px, px) for px, _, _ in points]
    y_err = [(py - perr, py + perr) for _, py, perr in points]

    # p / nocc / frac are broadcast to var_in's length so the JS callback can
    # index them per point.
    source = ColumnDataSource(data=dict(x=x, y=y, y_err=y_err, x_err=x_err, err=err,
                                        electrons_out=electrons_out, electrons_in=electrons_in,
                                        var_in=var_in, var_out=var_out,
                                        p=var_in*0+p, nocc=var_in*0+noccultations, frac=var_in*0+frac))
    # Unbinned copy the callback restarts from on every slider change.
    original = ColumnDataSource(data=dict(x=x, y=y, y_err=y_err, x_err=x_err, err=err,
                                          electrons_out=electrons_out, electrons_in=electrons_in,
                                          var_in=var_in, var_out=var_out))

    ylims = [min(model_spec) - 0.1*min(model_spec),
             0.1*max(model_spec) + max(model_spec)]
    xlims = [min(final['wave']), max(final['wave'])]

    plot_spectrum = Figure(plot_width=800, plot_height=300, x_range=xlims,
                           y_range=ylims, tools=TOOLS,
                           x_axis_label=x_axis_label,
                           y_axis_label=punit,
                           title="Original Model with Observation")

    plot_spectrum.line(model_wave, model_spec, color="black", alpha=0.5, line_width=4)

    plot_spectrum.circle('x', 'y', source=source, line_width=3, line_alpha=0.6)
    plot_spectrum.multi_line('x_err', 'y_err', source=source)

    # Re-bins the spectrum in the browser. The error it computes per bin is
    #   var_tot = (frac/e_out)**2 * var_in + (e_in*frac/e_out**2)**2 * var_out
    # scaled by sqrt(original_transits / selected_transits).
    # NOTE(review): the second JS loop runs over x.length although only
    # xout.length binned points exist (the surplus points become NaN/undefined),
    # and `start = i` measures each new bin from the last point of the previous
    # bin -- confirm both are intended.
    callback = CustomJS(args=dict(source=source, original=original), code="""
            // Grab some references to the data
            var sdata = source.get('data');
            var odata = original.get('data');

            // Create copies of the original data, store them as the source data
            sdata['x'] = odata['x'].slice(0);
            sdata['y'] = odata['y'].slice(0);

            sdata['y_err'] = odata['y_err'].slice(0);
            sdata['x_err'] = odata['x_err'].slice(0);
            sdata['err'] = odata['err'].slice(0);

            sdata['electrons_out'] = odata['electrons_out'].slice(0);
            sdata['electrons_in'] = odata['electrons_in'].slice(0);
            sdata['var_in'] = odata['var_in'].slice(0);
            sdata['var_out'] = odata['var_out'].slice(0);

            // Create some variables referencing the source data
            var x = sdata['x'];
            var y = sdata['y'];
            var y_err = sdata['y_err'];
            var x_err = sdata['x_err'];
            var err = sdata['err'];
            var p = sdata['p'];
            var frac = sdata['frac'];
            var og_ntran = sdata['nocc'];

            var electrons_out = sdata['electrons_out'];
            var electrons_in = sdata['electrons_in'];
            var var_in = sdata['var_in'];
            var var_out = sdata['var_out'];

            var f = wbin.get('value');
            var ntran = ntran.get('value');

            var wlength = Math.pow(10.0,f);

            var ind = [];
            ind.push(0);
            var start = 0;


            for (i = 0; i < x.length-1; i++) {
                if (x[i+1] - x[start] >= wlength) {
                    ind.push(i+1);
                    start = i;
                }
            }

            if (ind[ind.length-1] != x.length) {
                ind.push(x.length);
            }

            var xout = [];


            var foutout = [];
            var finout = [];
            var varinout = [];
            var varoutout = [];

            var xslice = []; 

            var foutslice = [];
            var finslice = [];
            var varoutslice = [];
            var varinslice = [];

            function add(a, b) {
                return a+b;
            }

            for (i = 0; i < ind.length-1; i++) {
                xslice = x.slice(ind[i],ind[i+1]);

                foutslice = electrons_out.slice(ind[i],ind[i+1]);
                finslice = electrons_in.slice(ind[i],ind[i+1]);
                
                varinslice = var_in.slice(ind[i],ind[i+1]);
                varoutslice = var_out.slice(ind[i],ind[i+1]);

                xout.push(xslice.reduce(add, 0)/xslice.length);
                foutout.push(foutslice.reduce(add, 0));
                finout.push(finslice.reduce(add, 0));
                
                varinout.push(varinslice.reduce(add, 0));
                varoutout.push(varoutslice.reduce(add, 0));

                xslice = [];
                foutslice = [];
                finslice = [];
                varinslice = [];
                varoutslice = [];
            }
            
            var new_err = 1.0;
            var rand = 1.0;

            for (i = 0; i < x.length; i++) {
                new_err = Math.pow((frac[i]/foutout[i]),2)*varinout[i] + Math.pow((finout[i]*frac[i]/Math.pow(foutout[i],2)),2)*varoutout[i];
                new_err = Math.sqrt(new_err)*Math.sqrt(og_ntran[i]/ntran);
                rand = new_err*(Math.random()-Math.random());
                y[i] = p[i]*((1.0 - frac[i]*finout[i]/foutout[i]) + rand); 
                x[i] = xout[i];
                x_err[i][0] = xout[i];
                x_err[i][1] = xout[i];
                y_err[i][0] = y[i] + new_err;
                y_err[i][1] = y[i] - new_err;            
            }

            source.trigger('change');
        """)

    sliderWbin = Slider(title="binning", value=np.log10(x[1]-x[0]), start=np.log10(x[1]-x[0]),
                        end=np.log10(max(x)/2.0), step=.05, callback=callback)
    callback.args["wbin"] = sliderWbin
    sliderTrans = Slider(title="Num Trans", value=noccultations, start=1, end=50, step=1, callback=callback)
    callback.args["ntran"] = sliderTrans
    layout = column(row(sliderWbin, sliderTrans), plot_spectrum)

    # Flux per integration (1-D), NaN samples dropped.
    x = raw['wave']
    y = raw['e_rate_out']*timing['Seconds per Frame']*(timing["APT: Num Groups per Integration"]-1)
    good = ~np.isnan(y)
    x, y = x[good], y[good]

    plot_flux_1d1 = Figure(tools=TOOLS,
                           x_axis_label=x_axis_label,
                           y_axis_label='e-/integration', title="Flux Per Integration",
                           plot_width=800, plot_height=300)
    plot_flux_1d1.line(x, y, line_width=4, alpha=.7)
    tab1 = Panel(child=plot_flux_1d1, title="Flux per Int")

    # SNR: square root of the electrons collected in a single integration.
    y = np.sqrt(y)

    plot_snr_1d1 = Figure(tools=TOOLS,
                          x_axis_label=x_axis_label,
                          y_axis_label='sqrt(e-)/integration', title="SNR per integration",
                          plot_width=800, plot_height=300)
    plot_snr_1d1.line(x, y, line_width=4, alpha=.7)
    tab3 = Panel(child=plot_snr_1d1, title="SNR per Int")

    # Error bars in ppm, NaN samples dropped; y axis spans 0 .. 2 x median.
    x = final['wave']
    y = final['error_w_floor']*1e6
    good = ~np.isnan(y)
    x, y = x[good], y[good]
    ymed = np.median(y)

    plot_noise_1d1 = Figure(tools=TOOLS,
                            x_axis_label=x_axis_label,
                            y_axis_label='Spectral Precision (ppm)', title="Spectral Precision",
                            plot_width=800, plot_height=300, y_range=[0, 2.0*ymed])
    plot_noise_1d1.circle(x, y, line_width=4, alpha=.7)
    tab4 = Panel(child=plot_noise_1d1, title="Precision")

    # Input model alone on a log y axis.
    plot_spectrum2 = Figure(plot_width=800, plot_height=300, x_range=xlims, y_range=ylims, tools=TOOLS,
                            x_axis_label=x_axis_label,
                            y_axis_label=punit, title="Original Model", y_axis_type="log")

    plot_spectrum2.line(model_wave, model_spec, line_width=4, alpha=.7)
    tab5 = Panel(child=plot_spectrum2, title="Original Model")

    # Four 1-D tabs (the background-flux tab is currently disabled).
    tabs1d = Tabs(tabs=[tab1, tab3, tab4, tab5])

    # Detector 2d
    plot_detector_2d = _jwst_image_figure(out['2d']['detector'], "2D Detector Image",
                                          "pan,wheel_zoom,box_zoom,resize,reset,hover,save")

    # 2d SNR: work on a copy so zeroing infinities does not modify result_dict.
    snr_2d = np.array(out['2d']['snr'], copy=True)
    snr_2d[np.isinf(snr_2d)] = 0.0
    tab1b = Panel(child=_jwst_image_figure(snr_2d, "Signal-to-Noise Ratio", TOOLS), title="SNR")

    # saturation
    tab2b = Panel(child=_jwst_image_figure(out['2d']['saturation'], "Saturation", TOOLS),
                  title="Saturation")

    tabs2d = Tabs(tabs=[tab1b, tab2b])

    result_comp = components({'plot_spectrum': layout,
                              'tabs1d': tabs1d, 'det_2d': plot_detector_2d,
                              'tabs2d': tabs2d})

    return result_comp
              x_range=[odesystem_settings.x_min, odesystem_settings.x_max],
              y_range=[odesystem_settings.y_min, odesystem_settings.y_max]
              )
# Hide both plot grids by making their lines fully transparent.
plot.grid[0].grid_line_alpha = plot.grid[1].grid_line_alpha = 0.0

# Direction field of the ODE system.
quiver = my_bokeh_utils.Quiver(plot)
# Initial value marker.
plot.scatter('x0', 'y0',
             source=source_initialvalue,
             color='black',
             legend='(x0,y0)')
# Streamline.
plot.line('x', 'y',
          source=source_streamline,
          color='black',
          legend='streamline')
# Critical points (markers) and critical lines, both in red.
plot.scatter('x', 'y',
             source=source_critical_pts,
             color='red',
             legend='critical pts')
plot.multi_line('x_ls', 'y_ls',
                source=source_critical_lines,
                color='red',
                legend='critical lines')

# Controls: text inputs for the ODE system [u, v] = [x', y'], pre-filled
# with the pair stored under odesystem_settings.init_fun_key.
u_input = TextInput(title="u(x,y):",
                    value=odesystem_settings.sample_system_functions[odesystem_settings.init_fun_key][0])
v_input = TextInput(title="v(x,y):",
                    value=odesystem_settings.sample_system_functions[odesystem_settings.init_fun_key][1])

# Drop-down listing the sample function pairs.
sample_fun_input = Dropdown(menu=odesystem_settings.sample_system_names,
                            label="choose a sample function pair or enter one below")

# Interactor used to enter the starting point of the initial condition.
interactor = my_bokeh_utils.Interactor(plot)

# Wire up the drop-down callback.
sample_fun_input.on_click(sample_fun_change)
# Convert the epoch-second timestamps of each device's rows to datetimes.
dellconv = dell["StatTime"]
dell_time = pandas.to_datetime(dellconv, unit='s')

dell_2 = result[(result.DeviceName == 'Dell_2') & (result.TypeId == 2758)]
dell_2_up = dell_2["UpBytes"]
dell_2_down = dell_2["DownBytes"]
dell_2_con = dell_2["StatTime"]
dell_2_time = pandas.to_datetime(dell_2_con, unit='s')

juniper = result[(result.DeviceName == 'juniper') & (result.TypeId == 1234)]
juniper_up = juniper["UpBytes"]
juniper_down = juniper["DownBytes"]
juniperconv = juniper["StatTime"]
juniper_time = pandas.to_datetime(juniperconv, unit='s')

# One multi_line per device: first series = UpBytes, second = DownBytes.
p.multi_line(xs=[juniper_time, juniper_time], ys=[juniper_up, juniper_down], color=['red', 'blue'], line_width=1, legend='juniper')
p.multi_line(xs=[cisco_time, cisco_time], ys=[cisco_up, cisco_down], color=['#EE0091', '#2828B0'], line_width=1, legend='cisco')
p.multi_line(xs=[hp_time, hp_time], ys=[hp_up, hp_down], color=['yellow', 'green'], line_width=1, legend='hp')
p.multi_line(xs=[dell_time, dell_time], ys=[dell["UpBytes"], dell["DownBytes"]], color=['pink', 'black'], line_width=1, legend='dell')
p.multi_line(xs=[dell_2_time, dell_2_time], ys=[dell_2_up, dell_2_down], color=['#498ABF', '#2F00E1'], line_width=1, legend='dell_2')

# NOTE(review): this hover tool is never attached to `p` (no p.add_tools),
# the multi_line data has no 'x' or 'DeviceName' columns, and the formatter
# key "Date" matches no tooltip field -- confirm the intended setup.
hover = HoverTool(tooltips=[('Time', '@x{int}'),
                            ('Value', '@y{1.11} GB'),
                            ('Device', '@DeviceName')])

hover.formatters = {"Date": "datetime"}
p.legend.label_text_font = "times"
p.legend.label_text_font_style = "italic"

p.grid.grid_line_alpha = 0.5
p.xaxis.major_tick_line_width = 3
def plot_waveform_bokeh(filename,waveform_list,metadata_list,station_lat_list,\
                       station_lon_list, event_lat, event_lon, boundary_data, style_parameter):
    xlabel_fontsize = style_parameter['xlabel_fontsize']
    #
    map_station_location_bokeh = ColumnDataSource(data=dict(map_lat_list=station_lat_list,\
                                                            map_lon_list=station_lon_list))
    dot_default_index = 0
    selected_dot_on_map_bokeh = ColumnDataSource(data=dict(lat=[station_lat_list[dot_default_index]],\
                                                           lon=[station_lon_list[dot_default_index]],\
                                                           index=[dot_default_index]))
    map_view = Figure(plot_width=style_parameter['map_view_plot_width'], \
                      plot_height=style_parameter['map_view_plot_height'], \
                      y_range=[style_parameter['map_view_lat_min'],\
                    style_parameter['map_view_lat_max']], x_range=[style_parameter['map_view_lon_min'],\
                    style_parameter['map_view_lon_max']], tools=style_parameter['map_view_tools'],\
                    title=style_parameter['map_view_title'])
    # ------------------------------
    # add boundaries to map view
    # country boundaries
    map_view.multi_line(boundary_data['country']['longitude'],\
                        boundary_data['country']['latitude'],color='gray',\
                        line_width=2, level='underlay', nonselection_line_alpha=1.0,\
                        nonselection_line_color='gray')
    # marine boundaries
    map_view.multi_line(boundary_data['marine']['longitude'],\
                        boundary_data['marine']['latitude'],color='gray',\
                        level='underlay', nonselection_line_alpha=1.0,\
                        nonselection_line_color='gray')
    # shoreline boundaries
    map_view.multi_line(boundary_data['shoreline']['longitude'],\
                        boundary_data['shoreline']['latitude'],color='gray',\
                        line_width=2, nonselection_line_alpha=1.0, level='underlay',
                        nonselection_line_color='gray')
    # state boundaries
    map_view.multi_line(boundary_data['state']['longitude'],\
                        boundary_data['state']['latitude'],color='gray',\
                        level='underlay', nonselection_line_alpha=1.0,\
                        nonselection_line_color='gray')
    #
    map_view.triangle('map_lon_list', 'map_lat_list', source=map_station_location_bokeh, \
                      line_color='gray', size=style_parameter['marker_size'], fill_color='black',\
                      selection_color='black', selection_line_color='gray',\
                      selection_fill_alpha=1.0,\
                      nonselection_fill_alpha=1.0, nonselection_fill_color='black',\
                      nonselection_line_color='gray', nonselection_line_alpha=1.0)
    map_view.triangle('lon','lat', source=selected_dot_on_map_bokeh,\
                      size=style_parameter['selected_marker_size'], line_color='black',fill_color='red')
    map_view.asterisk([event_lon], [event_lat], size=style_parameter['event_marker_size'], line_width=3, line_color='red', \
                      fill_color='red')
    # change style
    map_view.title.text_font_size = style_parameter['title_font_size']
    map_view.title.align = 'center'
    map_view.title.text_font_style = 'normal'
    map_view.xaxis.axis_label = style_parameter['map_view_xlabel']
    map_view.xaxis.axis_label_text_font_style = 'normal'
    map_view.xaxis.axis_label_text_font_size = xlabel_fontsize
    map_view.xaxis.major_label_text_font_size = xlabel_fontsize
    map_view.yaxis.axis_label = style_parameter['map_view_ylabel']
    map_view.yaxis.axis_label_text_font_style = 'normal'
    map_view.yaxis.axis_label_text_font_size = xlabel_fontsize
    map_view.yaxis.major_label_text_font_size = xlabel_fontsize
    map_view.xgrid.grid_line_color = None
    map_view.ygrid.grid_line_color = None
    map_view.toolbar.logo = None
    map_view.toolbar_location = 'above'
    map_view.toolbar_sticky = False
    # --------------------------------------------------------
    max_waveform_length = 0
    max_waveform_amp = 0
    ncurve = len(waveform_list)
    for a_sta in waveform_list:
        for a_trace in a_sta:
            if len(a_trace) > max_waveform_length:
                max_waveform_length = len(a_trace)
            if np.max(np.abs(a_trace)) > max_waveform_amp:
                max_waveform_amp = np.max(np.abs(a_trace))
    #
    plotting_list = []
    for a_sta in waveform_list:
        temp = []
        for a_trace in a_sta:
            if len(a_trace) < max_waveform_length:
                a_trace = np.append(a_trace,np.zeros([(max_waveform_length-len(a_trace)),1]))
            temp.append(list(a_trace))
        plotting_list.append(temp)
    #
    time_list = []
    for ista in range(len(plotting_list)):
        a_sta = plotting_list[ista]
        temp = []
        for itr in range(len(a_sta)):
            a_trace = a_sta[itr]
            delta = metadata_list[ista][itr]['delta']
            time = list(np.arange(len(a_trace))*delta)
            temp.append(time)
        #
        time_list.append(temp)
    #
    reftime_label_list = []
    channel_label_list = []
    for ista in range(len(metadata_list)):
        temp_ref = []
        temp_channel = []
        a_sta = metadata_list[ista]
        for a_trace in a_sta:
            temp_ref.append('Starting from '+a_trace['starttime'])
            temp_channel.append(a_trace['network']+'_'+a_trace['station']+'_'+a_trace['channel'])
        reftime_label_list.append(temp_ref)
        channel_label_list.append(temp_channel)
    # --------------------------------------------------------
    curve_fig01 = Figure(plot_width=style_parameter['curve_plot_width'], plot_height=style_parameter['curve_plot_height'], \
                       y_range=(-max_waveform_amp*1.05,max_waveform_amp*1.05), \
                       x_range=(0,max_waveform_length),\
                    tools=['save','box_zoom','ywheel_zoom','xwheel_zoom','reset','crosshair','pan']) 
    #
    curve_index = 0
    select_curve_data = plotting_list[dot_default_index][curve_index]
    select_curve_time = time_list[dot_default_index][curve_index]
    
    selected_curve_data_bokeh01 = ColumnDataSource(data=dict(time=select_curve_time,amp=select_curve_data))
    select_reftime_label = reftime_label_list[dot_default_index][curve_index]
    selected_reftime_label_bokeh01 = ColumnDataSource(data=dict(x=[style_parameter['curve_reftime_label_x']],\
                                                                y=[style_parameter['curve_reftime_label_y']],\
                                                                label=[select_reftime_label]))
    select_channel_label = channel_label_list[dot_default_index][curve_index]
    selected_channel_label_bokeh01 = ColumnDataSource(data=dict(x=[style_parameter['curve_channel_label_x']],\
                                                                y=[style_parameter['curve_channel_label_y']],\
                                                                label=[select_channel_label]))
    all_curve_data_bokeh = ColumnDataSource(data=dict(t=time_list, amp=plotting_list))
    all_reftime_label_bokeh = ColumnDataSource(data=dict(label=reftime_label_list))
    all_channel_label_bokeh = ColumnDataSource(data=dict(label=channel_label_list))
    # plot waveform
    curve_fig01.line('time','amp', source=selected_curve_data_bokeh01,\
                   line_color='black')
    # add refference time as a label
    curve_fig01.text('x', 'y', 'label', source=selected_reftime_label_bokeh01)
    # add channel label
    curve_fig01.text('x', 'y', 'label', source=selected_channel_label_bokeh01)
    # change style
    curve_fig01.title.text_font_size = style_parameter['title_font_size']
    curve_fig01.title.align = 'center'
    curve_fig01.title.text_font_style = 'normal'
    curve_fig01.xaxis.axis_label = style_parameter['curve_xlabel']
    curve_fig01.xaxis.axis_label_text_font_style = 'normal'
    curve_fig01.xaxis.axis_label_text_font_size = xlabel_fontsize
    curve_fig01.xaxis.major_label_text_font_size = xlabel_fontsize
    curve_fig01.yaxis.axis_label = style_parameter['curve_ylabel']
    curve_fig01.yaxis.axis_label_text_font_style = 'normal'
    curve_fig01.yaxis.axis_label_text_font_size = xlabel_fontsize
    curve_fig01.yaxis.major_label_text_font_size = xlabel_fontsize
    curve_fig01.toolbar.logo = None
    curve_fig01.toolbar_location = 'above'
    curve_fig01.toolbar_sticky = False
    # --------------------------------------------------------
    curve_fig02 = Figure(plot_width=style_parameter['curve_plot_width'], plot_height=style_parameter['curve_plot_height'], \
                       y_range=(-max_waveform_amp*1.05,max_waveform_amp*1.05), \
                       x_range=(0,max_waveform_length),\
                    tools=['save','box_zoom','ywheel_zoom','xwheel_zoom','reset','crosshair','pan']) 
    #
    curve_index = 1
    select_curve_data = plotting_list[dot_default_index][curve_index]
    select_curve_time = time_list[dot_default_index][curve_index]
    selected_curve_data_bokeh02 = ColumnDataSource(data=dict(time=select_curve_time,amp=select_curve_data))
    select_channel_label = channel_label_list[dot_default_index][curve_index]
    selected_channel_label_bokeh02 = ColumnDataSource(data=dict(x=[style_parameter['curve_channel_label_x']],\
                                                                y=[style_parameter['curve_channel_label_y']],\
                                                                label=[select_channel_label]))
    # plot waveform
    curve_fig02.line('time','amp', source=selected_curve_data_bokeh02,\
                   line_color='black')
    # add channel label
    curve_fig02.text('x', 'y', 'label', source=selected_channel_label_bokeh02)
    # change style
    curve_fig02.title.text_font_size = style_parameter['title_font_size']
    curve_fig02.title.align = 'center'
    curve_fig02.title.text_font_style = 'normal'
    curve_fig02.xaxis.axis_label = style_parameter['curve_xlabel']
    curve_fig02.xaxis.axis_label_text_font_style = 'normal'
    curve_fig02.xaxis.axis_label_text_font_size = xlabel_fontsize
    curve_fig02.xaxis.major_label_text_font_size = xlabel_fontsize
    curve_fig02.yaxis.axis_label = style_parameter['curve_ylabel']
    curve_fig02.yaxis.axis_label_text_font_style = 'normal'
    curve_fig02.yaxis.axis_label_text_font_size = xlabel_fontsize
    curve_fig02.yaxis.major_label_text_font_size = xlabel_fontsize
    curve_fig02.toolbar.logo = None
    curve_fig02.toolbar_location = 'above'
    curve_fig02.toolbar_sticky = False
    # --------------------------------------------------------
    curve_fig03 = Figure(plot_width=style_parameter['curve_plot_width'], plot_height=style_parameter['curve_plot_height'], \
                       y_range=(-max_waveform_amp*1.05,max_waveform_amp*1.05), \
                       x_range=(0,max_waveform_length),\
                    tools=['save','box_zoom','ywheel_zoom','xwheel_zoom','reset','crosshair','pan']) 
    #
    curve_index = 2
    select_curve_data = plotting_list[dot_default_index][curve_index]
    select_curve_time = time_list[dot_default_index][curve_index]
    selected_curve_data_bokeh03 = ColumnDataSource(data=dict(time=select_curve_time,amp=select_curve_data))
    select_channel_label = channel_label_list[dot_default_index][curve_index]
    selected_channel_label_bokeh03 = ColumnDataSource(data=dict(x=[style_parameter['curve_channel_label_x']],\
                                                                y=[style_parameter['curve_channel_label_y']],\
                                                                label=[select_channel_label]))
    # plot waveform
    curve_fig03.line('time','amp', source=selected_curve_data_bokeh03,\
                   line_color='black')
    # add channel label
    curve_fig03.text('x', 'y', 'label', source=selected_channel_label_bokeh03)
    # change style
    curve_fig03.title.text_font_size = style_parameter['title_font_size']
    curve_fig03.title.align = 'center'
    curve_fig03.title.text_font_style = 'normal'
    curve_fig03.xaxis.axis_label = style_parameter['curve_xlabel']
    curve_fig03.xaxis.axis_label_text_font_style = 'normal'
    curve_fig03.xaxis.axis_label_text_font_size = xlabel_fontsize
    curve_fig03.xaxis.major_label_text_font_size = xlabel_fontsize
    curve_fig03.yaxis.axis_label = style_parameter['curve_ylabel']
    curve_fig03.yaxis.axis_label_text_font_style = 'normal'
    curve_fig03.yaxis.axis_label_text_font_size = xlabel_fontsize
    curve_fig03.yaxis.major_label_text_font_size = xlabel_fontsize
    curve_fig03.toolbar.logo = None
    curve_fig03.toolbar_location = 'above'
    curve_fig03.toolbar_sticky = False
    # --------------------------------------------------------
    map_station_location_js = CustomJS(args=dict(selected_dot_on_map_bokeh=selected_dot_on_map_bokeh,\
                                                            map_station_location_bokeh=map_station_location_bokeh,\
                                                            selected_curve_data_bokeh01=selected_curve_data_bokeh01,\
                                                            selected_curve_data_bokeh02=selected_curve_data_bokeh02,\
                                                            selected_curve_data_bokeh03=selected_curve_data_bokeh03,\
                                                            selected_channel_label_bokeh01=selected_channel_label_bokeh01,\
                                                            selected_channel_label_bokeh02=selected_channel_label_bokeh02,\
                                                            selected_channel_label_bokeh03=selected_channel_label_bokeh03,\
                                                            selected_reftime_label_bokeh01=selected_reftime_label_bokeh01,\
                                                            all_reftime_label_bokeh=all_reftime_label_bokeh,\
                                                            all_channel_label_bokeh=all_channel_label_bokeh,\
                                                            all_curve_data_bokeh=all_curve_data_bokeh), code="""
    var inds = cb_obj.indices
    
    selected_dot_on_map_bokeh.data['index'] = [inds]
    var new_loc = map_station_location_bokeh.data
    
    selected_dot_on_map_bokeh.data['lat'] = [new_loc['map_lat_list'][inds]]
    selected_dot_on_map_bokeh.data['lon'] = [new_loc['map_lon_list'][inds]]
    
    selected_dot_on_map_bokeh.change.emit()
    
    selected_curve_data_bokeh01.data['t'] = all_curve_data_bokeh.data['t'][inds][0]
    selected_curve_data_bokeh01.data['amp'] = all_curve_data_bokeh.data['amp'][inds][0]

    selected_curve_data_bokeh01.change.emit()
    
    selected_curve_data_bokeh02.data['t'] = all_curve_data_bokeh.data['t'][inds][1]
    selected_curve_data_bokeh02.data['amp'] = all_curve_data_bokeh.data['amp'][inds][1]

    selected_curve_data_bokeh02.change.emit()
    
    selected_curve_data_bokeh03.data['t'] = all_curve_data_bokeh.data['t'][inds][2]
    selected_curve_data_bokeh03.data['amp'] = all_curve_data_bokeh.data['amp'][inds][2]

    selected_curve_data_bokeh03.change.emit()
    
    selected_reftime_label_bokeh01.data['label'] = [all_reftime_label_bokeh.data['label'][inds][0]]
    
    selected_reftime_label_bokeh01.change.emit()
    
    selected_channel_label_bokeh01.data['label'] = [all_channel_label_bokeh.data['label'][inds][0]]
    
    selected_channel_label_bokeh01.change.emit()
    
    selected_channel_label_bokeh02.data['label'] = [all_channel_label_bokeh.data['label'][inds][1]]
    
    selected_channel_label_bokeh02.change.emit()
    
    selected_channel_label_bokeh03.data['label'] = [all_channel_label_bokeh.data['label'][inds][2]]
    
    selected_channel_label_bokeh03.change.emit()
    """)
    #
    map_station_location_bokeh.selected.js_on_change('indices', map_station_location_js)
    #
    curve_slider_callback = CustomJS(args=dict(selected_dot_on_map_bokeh=selected_dot_on_map_bokeh,\
                                                map_station_location_bokeh=map_station_location_bokeh,\
                                                selected_curve_data_bokeh01=selected_curve_data_bokeh01,\
                                                selected_curve_data_bokeh02=selected_curve_data_bokeh02,\
                                                selected_curve_data_bokeh03=selected_curve_data_bokeh03,\
                                                selected_channel_label_bokeh01=selected_channel_label_bokeh01,\
                                                selected_channel_label_bokeh02=selected_channel_label_bokeh02,\
                                                selected_channel_label_bokeh03=selected_channel_label_bokeh03,\
                                                selected_reftime_label_bokeh01=selected_reftime_label_bokeh01,\
                                                all_reftime_label_bokeh=all_reftime_label_bokeh,\
                                                all_channel_label_bokeh=all_channel_label_bokeh,\
                                                all_curve_data_bokeh=all_curve_data_bokeh),code="""
    var inds = Math.round(cb_obj.value)
    
    selected_dot_on_map_bokeh.data['index'] = [inds]
    var new_loc = map_station_location_bokeh.data
    
    selected_dot_on_map_bokeh.data['lat'] = [new_loc['map_lat_list'][inds]]
    selected_dot_on_map_bokeh.data['lon'] = [new_loc['map_lon_list'][inds]]
    
    selected_dot_on_map_bokeh.change.emit()
    
    selected_curve_data_bokeh01.data['t'] = all_curve_data_bokeh.data['t'][inds][0]
    selected_curve_data_bokeh01.data['amp'] = all_curve_data_bokeh.data['amp'][inds][0]

    selected_curve_data_bokeh01.change.emit()
    
    selected_curve_data_bokeh02.data['t'] = all_curve_data_bokeh.data['t'][inds][1]
    selected_curve_data_bokeh02.data['amp'] = all_curve_data_bokeh.data['amp'][inds][1]

    selected_curve_data_bokeh02.change.emit()
    
    selected_curve_data_bokeh03.data['t'] = all_curve_data_bokeh.data['t'][inds][2]
    selected_curve_data_bokeh03.data['amp'] = all_curve_data_bokeh.data['amp'][inds][2]

    selected_curve_data_bokeh03.change.emit()
    
    selected_reftime_label_bokeh01.data['label'] = [all_reftime_label_bokeh.data['label'][inds][0]]
    
    selected_reftime_label_bokeh01.change.emit()
    
    selected_channel_label_bokeh01.data['label'] = [all_channel_label_bokeh.data['label'][inds][0]]
    
    selected_channel_label_bokeh01.change.emit()
    
    selected_channel_label_bokeh02.data['label'] = [all_channel_label_bokeh.data['label'][inds][1]]
    
    selected_channel_label_bokeh02.change.emit()
    
    selected_channel_label_bokeh03.data['label'] = [all_channel_label_bokeh.data['label'][inds][2]]
    
    selected_channel_label_bokeh03.change.emit()
    """)
    curve_slider = Slider(start=0, end=ncurve-1, value=style_parameter['curve_default_index'], \
                          step=1, title=style_parameter['curve_slider_title'], width=style_parameter['map_view_plot_width'],\
                          height=50, callback=curve_slider_callback)
    
    # ==============================
    # annotating text
    annotating_fig01 = Div(text=style_parameter['annotating_html01'], \
        width=style_parameter['annotation_plot_width'], height=style_parameter['annotation_plot_height'])
    annotating_fig02 = Div(text=style_parameter['annotating_html02'],\
        width=style_parameter['annotation_plot_width'], height=style_parameter['annotation_plot_height'])
    # ==============================
    output_file(filename,title=style_parameter['html_title'],mode=style_parameter['library_source'])
    #
    left_fig = Column(curve_slider, map_view, annotating_fig01, width=style_parameter['left_column_width'] )
    
    right_fig = Column(curve_fig01, curve_fig02, curve_fig03, annotating_fig02, width=style_parameter['right_column_width'])
    layout = Row(left_fig, right_fig)
    save(layout)
Exemple #34
0
# Meridians: for every grid right ascension ``rg``, sample each declination in
# ``decrange`` and map it through the Hammer projection formulas
# (angles presumably in radians — TODO confirm against where ragrid is built).
for rg in ragrid:
    raxs.append(list(
        2.0**1.5 * cos(x) * sin(rg / 2.0) / sqrt(1.0 + cos(x) * cos(rg / 2.0))
        for x in decrange
    ))
    rays.append(list(
        sqrt(2.0) * sin(x) / sqrt(1.0 + cos(x) * cos(rg / 2.0))
        for x in decrange
    ))

# Parallels: for every grid declination ``dg``, sample each right ascension
# in ``rarange`` through the same projection.
decxs, decys = [], []
for dg in decgrid:
    decxs.append(list(
        2.0**1.5 * cos(dg) * sin(x / 2.0) / sqrt(1.0 + cos(dg) * cos(x / 2.0))
        for x in rarange
    ))
    decys.append(list(
        sqrt(2.0) * sin(dg) / sqrt(1.0 + cos(dg) * cos(x / 2.0))
        for x in rarange
    ))

# Attach the hover tool, then draw both sets of grid lines in light grey.
p1.add_tools(hover)
p1.multi_line(raxs, rays, color='#bbbbbb')
p1.multi_line(decxs, decys, color='#bbbbbb')

# Distinct claimed types, alphabetically ordered.
claimedtypes = sorted(set(sntypes))

for ci, ct in enumerate(claimedtypes):
    ind = [i for i, t in enumerate(sntypes) if t == ct]

    source = ColumnDataSource(
        data=dict(
            x=[snhxs[i] for i in ind],
            y=[snhys[i] for i in ind],
            ra=[snras[i] for i in ind],
            dec=[sndecs[i] for i in ind],
            event=[snnames[i] for i in ind],
            claimedtype=[sntypes[i] for i in ind]
Exemple #35
0
                   end=25,
                   value=10,
                   step=1)
# Free-text entry for the station IDs to plot (default shows two IDs
# separated by a space).
station_IDs_str = TextInput(value='6720 6551', title="Station IDs")

# Figure: one line per series in ``source``; hovering reports which
# station/year a line belongs to.
hover = HoverTool(
    tooltips=[
        ("Station ID", "@station_id"),
        ("Year", "@year"),
    ]
)
p = Figure(title="", plot_width=800, plot_height=600, x_axis_type="datetime", tools=[hover])
p.multi_line(xs='x', ys='y', source=source, line_color='color', line_width=2)

# Fit the x range tightly to the data (no padding) and leave it unbounded.
p.x_range = DataRange1d(range_padding=0.0, bounds=None)
p.yaxis.axis_label = 'Cumulative GDD'
p.xaxis.axis_label = 'Month'


# Select Data based on input info
def select_data():
    #p.title = p.title + str(' (Wait..)')
    global station_IDs
    # Make Stations ID's as a list
    station_IDs = []
Exemple #36
0
        for x in decrange
    ])

# Parallels: for every grid declination ``dg``, sample each right ascension
# in ``rarange`` and map it through the Hammer projection formulas
# (angles presumably in radians — TODO confirm).
decxs, decys = [], []
for dg in decgrid:
    decxs.append(list(
        2.0**1.5 * cos(dg) * sin(x / 2.0) / sqrt(1.0 + cos(dg) * cos(x / 2.0))
        for x in rarange
    ))
    decys.append(list(
        sqrt(2.0) * sin(dg) / sqrt(1.0 + cos(dg) * cos(x / 2.0))
        for x in rarange
    ))

# Draw meridians and parallels in light grey.
p1.multi_line(raxs, rays, color='#bbbbbb')
p1.multi_line(decxs, decys, color='#bbbbbb')

# Distinct claimed types, alphabetically ordered.
claimedtypes = sorted(set(evtypes))

glyphs = []
# Marker size shrinks with log10 of the event count, floored at 2.5.
# NOTE(review): an empty ``evtypes`` makes log10(0) = -inf, so glsize = inf
# (with a NumPy warning); harmless only if nothing is plotted in that case.
glsize = max(2.5, 7.0 - np.log10(len(evtypes)))
for ci, ct in enumerate(claimedtypes):
    ind = [i for i, t in enumerate(evtypes) if t == ct]

    source = ColumnDataSource(data=dict(x=[evhxs[i] for i in ind],
                                        y=[evhys[i] for i in ind],
                                        ra=[evras[i] for i in ind],
                                        dec=[evdecs[i] for i in ind],
                                        event=[evnames[i] for i in ind],
                                        claimedtype=[evtypes[i] for i in ind]))
Exemple #37
0
# Controls: polynomial degree, number of bootstrap fits, and (presumably)
# the number of sampled points.
slider_degrees = Slider(title="Degrees", start=1, end=10, step=1, value=5)
slider_lines = Slider(title="Lines", start=1, end=50, step=1, value=10)
slider_points = Slider(title="Points", start=1, end=100, step=1, value=20)

# Noisy samples of the true curve, drawn as blue markers.
source_points = ColumnDataSource(data={"x": x, "y": func(x) + err})
p.scatter(x='x', y='y', source=source_points, color="blue", line_width=3)

# The noise-free curve the samples are drawn from.
source_function = ColumnDataSource(data={"x": x, "y": func(x)})
p.line(x='x', y='y', source=source_function, color="blue", line_width=1)

# Bootstrap regression fits; starts as two empty lines, presumably replaced
# by the update callback.
# NOTE(review): line_width=0 — an HTML canvas ignores a zero lineWidth, so
# these lines may still be drawn at some non-zero width; confirm intent.
source_lines = ColumnDataSource(data={"xs": [[], []], "ys": [[], []]})
p.multi_line(xs='xs', ys='ys', source=source_lines, color="pink", line_width=0)

# Average of the bootstrap fits; empty until computed.
source_avg = ColumnDataSource(data={"x": [], "y": []})
p.line(x='x', y='y', source=source_avg, color="red", line_width=2)

# Legend-style instructions shown next to the plot.
div_instr = Div(
    text=(
        "<font color=black>"
        "<br> <font color=blue>The blue line </font>is the “true” curve we are trying to approximate."
        "<br> <font color=blue><b>The blue dots </b></font>are points drawn from the blue curve with an added random error."
        "<br> <font color=pink>Each pink line</font> is the polynomial regression fit over a randomly drawn (with replacement) subset of points."
        "<br> <font color=red><b>The thick red line </b></font>is the average of the pink lines.</font>"
    ),
    width=800,
    height=100,
)

def update(attrname, old, new):
    D=slider_degrees.value  # number of degrees for the polynomial
    L=slider_lines.value    # number of bootstrap lines
Exemple #38
0
                    src = [photosource[i] for i in indne]
                )
            )
            p1.circle('x', 'y', source = source, color=bandcolorf(band), fill_color="white", legend=noerrorlegend, size=4)

            source = ColumnDataSource(
                data = dict(
                    x = [phototime[i] for i in indye],
                    y = [photoAB[i] for i in indye],
                    err = [photoerrs[i] for i in indye],
                    desc = [photoband[i] for i in indye],
                    instr = [photoinstru[i] for i in indye],
                    src = [photosource[i] for i in indye]
                )
            )
            p1.multi_line([err_xs[x] for x in indye], [err_ys[x] for x in indye], color=bandcolorf(band))
            p1.circle('x', 'y', source = source, color=bandcolorf(band), legend=bandname, size=4)

            upplimlegend = bandname if len(indye) == 0 and len(indne) == 0 else ''

            indt = [i for i, j in enumerate(phototype) if j]
            ind = set(indb).intersection(indt)
            p1.inverted_triangle([phototime[x] for x in ind], [photoAB[x] for x in ind],
                color=bandcolorf(band), legend=upplimlegend, size=7)

    if spectraavail and dohtml and args.writehtml:
        spectrumwave = []
        spectrumflux = []
        spectrumerrs = []
        for spectrum in catalog[entry]['spectra']:
            spectrumdata = deepcopy(spectrum['data'])
Exemple #39
0
def plotHistogram(fileName, initData, stations, dateRange, bokehPlaceholderId='bokehContent'):
    """Build the interactive trip-count page and return it as HTML.

    The page holds a multi-line time-series plot, filter widgets (start/end
    station, user type, gender, age, date range, bin size) and a data table
    listing every plotted series. The "Add" button fetches a new series from
    ``/histogram`` via XMLHttpRequest and appends it; "Delete Selected"
    removes the rows checked in the table.

    Args:
        fileName: HTML template path/name handed to ``readHtmlFile``.
        initData: Initial series; must provide ``'bins'`` (x values) and
            ``'values'`` (y values). Shown as the "All" row.
        stations: Station names offered in both station selects. The list is
            not modified; ``"All"`` is prepended to a copy.
        dateRange: ``(min_date, max_date)`` used as the bounds of both date
            pickers and as their initial values.
        bokehPlaceholderId: ``id`` of the ``div`` element the Bokeh markup is
            appended into.

    Returns:
        The template HTML with the Bokeh script inserted into the header and
        the plot markup appended to the placeholder ``div``.
    """
    MAX_AGE = 150
    # The initial data: a single "All" series with no filters applied.
    data = {
        'xs': [initData['bins']],
        'ys': [initData['values']],
        'ss': ["All"],
        'es': ["All"],
        'ut': ["All"],
        'g': ["All"],
        'a': [MAX_AGE],
        'color': ["blue"],
        'line_width': [2],
    }

    source = ColumnDataSource(data=data)
    # Build the options on a copy: inserting into ``stations`` would mutate
    # the caller's list and add another "All" on every call.
    station_options = ["All"] + list(stations)
    selectSS = Select(title="Start Station:", value="All", options=station_options)
    selectES = Select(title="End Station:", value="All", options=station_options)
    selectUT = Select(title="User Type:", value="All", options=["All", "Subscriber", "Customer"])
    selectGender = Select(title="Gender:", value="All", options=["All", "Male", "Female", "Unknown"])
    # NOTE(review): defaults to 0 while the initial row uses MAX_AGE for
    # "Age LEQ"; confirm the server treats 0 as intended.
    sliderAge = Slider(start=0, end=MAX_AGE, value=0, step=1, title="Age LEQ")
    startDP = DatePicker(title="Start Date:", min_date=dateRange[0], max_date=dateRange[1], value=dateRange[0])
    endDP = DatePicker(title="End Date:", min_date=dateRange[0], max_date=dateRange[1], value=dateRange[1])
    binSize = TextInput(value="15", title="Bin Size (Days):")
    AddButton = Toggle(label="Add", type="success")
    DeleteButton = Toggle(label="Delete Selected", type="success")

    columns = [
        TableColumn(field="ss", title="SS"),
        TableColumn(field="es", title="ES"),
        TableColumn(field="ut", title="User Type"),
        TableColumn(field="a", title="Age LEQ"),
        TableColumn(field="g", title="Sex"),
    ]
    data_table = DataTable(source=source, columns=columns, width=400, row_headers=False, selectable='checkbox')
    # Keys become the variable names visible inside the JS callback.
    model = dict(source=source, selectSS=selectSS, selectES=selectES, startDP=startDP, endDP=endDP,
                 binSize=binSize, selectUT=selectUT, selectGender=selectGender, sliderAge=sliderAge,
                 dt=data_table)

    addCallback = CustomJS(args=model, code="""
        var startStation = selectSS.get('value');
        var endStation = selectES.get('value');

        // Date pickers may hold Date objects; send both dates as epoch ms.
        var startDate = startDP.get('value');
        if (typeof(startDate) !== "number")
            startDate = startDate.getTime();

        var endDate = endDP.get('value');
        if (typeof(endDate) !== "number")
            endDate = endDate.getTime();

        // Separate name so the binSize widget argument is not shadowed.
        var binSizeValue = binSize.get('value');
        var gender = selectGender.get('value');
        var userType = selectUT.get('value');
        var age = sliderAge.get('value');

        var xmlhttp = new XMLHttpRequest();

        xmlhttp.onreadystatechange = function() {
            if (xmlhttp.readyState == XMLHttpRequest.DONE) {
                if (xmlhttp.status == 200) {
                    var data = source.get('data');
                    var result = JSON.parse(xmlhttp.responseText);
                    var temp = [];

                    for (var date in result.x) {
                        temp.push(new Date(result.x[date]));
                    }

                    data['xs'].push(temp);
                    data['ys'].push(result.y);
                    data['ss'].push(startStation);
                    data['es'].push(endStation);
                    data['ut'].push(userType);
                    data['g'].push(gender);
                    data['a'].push(age);
                    data['color'].push('blue');
                    data['line_width'].push(2);
                    source.trigger('change');
                    dt.trigger('change');
                }
                else if (xmlhttp.status == 400) {
                    alert(400);
                }
                else {
                    alert(xmlhttp.status);
                }
            }
        };
        var params = {ss: startStation, es: endStation, sd: startDate, ed: endDate,
                      bs: binSizeValue, g: gender, ut: userType, age: age};
        // Requires jQuery on the page for jQuery.param.
        var url = "/histogram?" + jQuery.param(params);
        xmlhttp.open("GET", url, true);
        xmlhttp.send();
        """)

    deleteCallBack = CustomJS(args=dict(source=source, dt=data_table), code="""
        // Copy so sorting does not reorder the selection's own array.
        var indices = source.get('selected')['1d'].indices.slice();

        if (indices.length != 0) {
            // Numeric sort: the default sort compares as strings, so [2, 10]
            // would become [10, 2] and the offsets below would hit wrong rows.
            indices.sort(function(a, b) { return a - b; });
            var data = source.get('data');

            for (var i = 0; i < indices.length; i++) {
                // Each earlier splice shifted the later rows up by one.
                var index = indices[i] - i;
                for (var key in data) {
                    data[key].splice(index, 1);
                }
            }

            source.trigger('change');
            dt.trigger('change');
        }
        """)
    AddButton.callback = addCallback
    DeleteButton.callback = deleteCallBack

    plot = Figure(title="Number Of Trips Over Time", x_axis_label='Time', y_axis_label='Number of trips', plot_width=750, plot_height=400, x_axis_type="datetime")
    plot.multi_line('xs', 'ys', source=source, line_width='line_width', line_alpha=0.9, line_color='color')

    l2 = vform(plot, hplot(startDP, endDP), binSize)
    l3 = vform(selectSS, selectES, selectUT, selectGender, sliderAge, hplot(AddButton, DeleteButton), data_table)
    layout = hplot(l2, l3)
    script, div = components(layout)
    html = readHtmlFile(fileName)
    html = insertScriptIntoHeader(html, script)
    # Honour the caller-supplied placeholder id (was hard-coded before).
    html = appendElementContent(html, div, "div", bokehPlaceholderId)

    return html