Beispiel #1
0
def update_station_list(attrname, old, new):
    """Rebuild the station dropdown for the selected year and partition.

    The signature is that of a Bokeh property-change callback, but the three
    arguments are not used: the current values are read directly from the
    ``select_years`` and ``select_partition`` widgets.

    A fresh ``Select`` is created and ``inputs.children`` is replaced with
    ``[select_years, select_partition, <new select>]``.

    Whatever ``station_partitions[key]`` raises for an unknown key is
    propagated unchanged.
    """
    key = station_partition_key_format.format(select_years.value,
                                              select_partition.value)
    stations_list = station_partitions[key]

    select_kwargs = {'title': 'Station', 'options': stations_list}
    # Only preselect the first station when there is one; indexing an empty
    # partition would raise IndexError inside the callback. With no stations
    # the widget keeps Select's own default value.
    if len(stations_list) > 0:
        select_kwargs['value'] = stations_list[0]
    new_select_station = Select(**select_kwargs)

    inputs.children = [select_years, select_partition, new_select_station]
def update_station_list(attrname, old, new):
    """Swap in a new station ``Select`` matching the chosen year/partition.

    Used with the ``(attrname, old, new)`` signature of a Bokeh
    property-change callback; none of the three arguments is read. The
    year and partition come from the ``select_years`` and
    ``select_partition`` widgets, are combined into a key with
    ``station_partition_key_format`` and looked up in ``station_partitions``.

    The resulting widget replaces the third child of ``inputs``; the year and
    partition selectors are kept as the first two children.

    An unknown key propagates whatever ``station_partitions[key]`` raises.
    """
    year_value = select_years.value
    partition_value = select_partition.value
    key = station_partition_key_format.format(year_value, partition_value)
    stations_list = station_partitions[key]

    # An empty partition has no first station to preselect; leaving ``value``
    # out keeps Select's default instead of failing with IndexError.
    has_stations = len(stations_list) > 0
    initial_value = {'value': stations_list[0]} if has_stations else {}
    new_select_station = Select(title='Station',
                                options=stations_list,
                                **initial_value)

    inputs.children = [select_years, select_partition, new_select_station]
Beispiel #3
0
def _new_monitoring_figure(name, title):
    """Return a 600x300 figure named *name* with the dashboard's toolbar setup.

    The CSS class is ``'monitoring_' + name``.
    """
    fig = figure(plot_width=600,
                 plot_height=300,
                 css_classes=['monitoring_' + name],
                 tools='pan,box_zoom,reset',
                 name=name,
                 title=title)
    fig.toolbar.logo = None
    fig.toolbar_location = 'above'
    return fig


def _add_day_marker_bars(fig, name):
    """Add the translucent red placeholder bars (all zero height) to *fig*."""
    fig.vbar(x=[1, 2, 3],
             width=1000 * 3600 * 24,  # 86 400 000 = one day in milliseconds
             bottom=0,
             top=[0, 0, 0],
             color='red',
             name=name,
             alpha=0.2)


def _add_planned_line(fig, name):
    """Add the gray 'Planificación' placeholder line and return its renderer."""
    return fig.line(x=[0, 1],
                    y=[0, 1],
                    line_color="darkgray",
                    line_width=2,
                    alpha=0.6,
                    legend='Planificación',
                    name=name)


def _add_real_line(fig, name):
    """Add the blue 'Real' placeholder line and return its renderer."""
    return fig.line(x=[0, 1],
                    y=[0, 1],
                    line_color="blue",
                    line_width=2,
                    alpha=0.8,
                    legend='Real',
                    name=name)


def _style_legend(fig):
    """Apply the shared legend look; call after the legend glyphs are added."""
    fig.legend.click_policy = "hide"
    fig.legend.location = "top_left"
    fig.legend.background_fill_color = "white"
    fig.legend.background_fill_alpha = 0.5
    fig.legend.label_text_color = "#505050"
    fig.legend.orientation = "vertical"


def _new_filter_menu(name, title, width):
    """Return a Select whose only initial option (and value) is 'TODOS'."""
    return Select(title=title,
                  value="TODOS",
                  name=name,
                  options=["TODOS"],
                  width=width,
                  css_classes=['monitoring_' + name])


def _build_map_figure(controller):
    """Return the Google-Maps plot ('fig3') with its placeholder circles."""
    map_options = GMapOptions(lat=10.032663,
                              lng=-74.042470,
                              map_type="roadmap",
                              zoom=7)
    fig = GMapPlot(x_range=Range1d(),
                   y_range=Range1d(),
                   map_options=map_options,
                   plot_width=600,
                   plot_height=450,
                   css_classes=['monitoring_fig3'],
                   name='fig3')
    fig.toolbar.logo = None
    fig.toolbar_location = 'above'
    fig.add_tools(PanTool(), WheelZoomTool())
    fig.title.text = 'Dispersión Geográfica de Itinerarios'
    # NOTE(review): this API key is committed in source. It should be
    # rotated and loaded from configuration; left unchanged here so the
    # module keeps working as before.
    fig.api_key = 'AIzaSyATl81v4Wnm4udDvlNTcgw4oWMzWJndkfQ'

    # Ten synthetic points on a parabola, used as initial glyph data.
    x = np.linspace(-2, 2, 10)
    source = ColumnDataSource(data=dict(
        lat=x,
        lon=x**2,
        sizes=np.linspace(10, 20, 10),
        colors=controller.day_colors[0:10],
    ))
    circle = Circle(x="lon",
                    y="lat",
                    size='sizes',
                    fill_color='colors',
                    fill_alpha=0.6,
                    line_color='black')
    fig.add_glyph(source, circle, name='circles1')
    fig.add_tools(create_hover(2))
    return fig


def create_module(user):
    """Build the 'monitoring' dashboard and add it to the current document.

    Creates a ``Controller`` for *user*, builds six plots (``fig1``..``fig6``),
    six ``Select`` menus (``menu1``..``menu6``) and a radio button group, all
    seeded with placeholder data, arranges them in a fixed-size layout named
    ``'monitoring'`` and adds that layout as a root of ``curdoc()``.

    After the layout is in the document the controller loads the user data
    and fills the menus, the menu callbacks are attached, and
    ``controller.on_change_menus(None, None, None)`` is called once.

    Returns nothing; the result is the side effect on ``curdoc()``.
    """
    controller = Controller(user)
    hover1 = create_hover(1)

    # fig1: customers per reading day
    fig1 = _new_monitoring_figure('fig1',
                                  'Cantidad de Clientes x día de Lectura')
    _add_day_marker_bars(fig1, 'vbar2')
    line1 = _add_planned_line(fig1, 'line1')
    line2 = _add_real_line(fig1, 'line1.2')
    _style_legend(fig1)
    fig1.xaxis.axis_label = 'Días del mes'
    hover1.renderers = [line1, line2]
    fig1.add_tools(hover1)
    fig1.xaxis.formatter = DatetimeTickFormatter(days=["%d"])
    fig1.yaxis[0].formatter = NumeralTickFormatter(format="0.0a")

    # fig2: itineraries per reading day (keeps the default y tick formatter)
    fig2 = _new_monitoring_figure('fig2',
                                  'Cantidad de Itinerarios x día de Lectura')
    _add_day_marker_bars(fig2, 'vbar3')
    line3 = _add_planned_line(fig2, 'line2')
    line4 = _add_real_line(fig2, 'line2.2')
    _style_legend(fig2)
    fig2.xaxis.axis_label = 'Días del mes'
    hover2 = create_hover(1)
    hover2.renderers = [line3, line4]
    fig2.add_tools(hover2)
    fig2.xaxis.formatter = DatetimeTickFormatter(days=["%d"])

    # fig3: geographic map of itineraries
    fig3 = _build_map_figure(controller)

    # Filter menus. menu1 is built directly with the controller's period
    # list; the entry at index (current month - 1) is preselected.
    periods = controller.periodos_str
    menu1 = Select(title="Periodo:",
                   value=periods[controller.now.month - 1],
                   name='menu1',
                   options=periods,
                   width=150,
                   css_classes=['monitoring_menu1'])
    menu2 = _new_filter_menu('menu2', "Delegación:", 200)
    menu3 = _new_filter_menu('menu3', "Unicom:", 150)
    menu4 = _new_filter_menu('menu4', "Día:", 150)
    menu5 = _new_filter_menu('menu5', "Municipio:", 200)
    menu6 = _new_filter_menu('menu6', "Tipología:", 150)

    # fig4: average billed days
    fig4 = _new_monitoring_figure('fig4', 'Promedio de Días Facturados')
    _add_planned_line(fig4, 'line4')
    _add_real_line(fig4, 'line4.2')
    _style_legend(fig4)
    fig4.xaxis.axis_label = 'Mes del Año'
    fig4.add_tools(create_hover(4))
    fig4.yaxis[0].formatter = NumeralTickFormatter(format="0.0a")
    # TODO: display the average and the total billed days per year

    # fig5: histogram of billed days
    fig5 = _new_monitoring_figure('fig5', 'Histograma de Días Facturados')
    fig5.vbar(x=[1, 2, 3],
              width=0.5,
              bottom=0,
              top=[1.2, 2.5, 3.7],
              color="darkcyan",
              fill_alpha=0.6,
              line_color='black',
              name='vbar1',
              legend='Planificación')
    _add_real_line(fig5, 'line5.2')
    _style_legend(fig5)
    fig5.xaxis.axis_label = 'Días Facturados'
    fig5.add_tools(create_hover(3))
    fig5.yaxis[0].formatter = NumeralTickFormatter(format="0.0a")
    # TODO: add a cumulative-sum curve

    # fig6: transfers
    fig6 = _new_monitoring_figure('fig6', 'Traslados')
    _add_planned_line(fig6, 'line6')
    _add_real_line(fig6, 'line6.2')
    _style_legend(fig6)
    fig6.xaxis.axis_label = 'Mes del Año'
    fig6.add_tools(create_hover(4))
    fig6.yaxis[0].formatter = NumeralTickFormatter(format="0.0a")

    button_group1 = RadioButtonGroup(
        labels=["Itinerarios", "Energía", "Importe"],
        active=0,
        name='button_group1',
        css_classes=['monitoring_button_group1'])
    button_group1.on_change('active', controller.on_change_menus)

    # Layout
    widget1 = layout([[fig1], [fig2]], sizing_mode='fixed')
    widget2 = layout([
        [fig3],
        [menu1, menu2, menu3],
        [menu4, menu5, menu6],
    ], sizing_mode='fixed')
    widget3 = layout([[fig6], [button_group1]], sizing_mode='fixed')
    dashboard = layout([
        [widget1, widget2],
        [fig4, fig5],
        [widget3],
    ], sizing_mode='fixed')
    dashboard.name = 'monitoring'

    # Initialise module data. The menu callbacks are attached only after
    # populate_menus() has run -- presumably so that filling the menus does
    # not fire on_change_menus once per menu (confirm against Controller).
    curdoc().add_root(dashboard)
    controller.get_user_data()
    controller.populate_menus(controller.info_itin, controller.fechas_itin)
    for menu in (menu1, menu2, menu3, menu4, menu5, menu6):
        menu.on_change('value', controller.on_change_menus)
    controller.on_change_menus(None, None, None)
def generate_gui(tsne, cut_extracellular_data, all_extra_spike_times, time_axis, cluster_info_file,
                 use_existing_cluster, autocor_bin_number, sampling_freq, prb_file=None, k4=False, verbose=False):

    if k4:
        tsne_figure_size = [1000, 800]
        tsne_min_border_left = 50
        spike_figure_size = [500, 500]
        hist_figure_size = [500, 500]
        heatmap_plot_size = [200, 800]
        clusters_table_size = [400, 300]
        layout_size = [1500, 1400]
        slider_size = [300, 100]
        user_info_size = [700, 80]
    else:
        tsne_figure_size = [850, 600]
        tsne_min_border_left = 10
        spike_figure_size = [450, 300]
        hist_figure_size = [450, 300]
        heatmap_plot_size = [200, 800]
        clusters_table_size = [400, 400]
        layout_size = [1200, 800]
        slider_size = [270, 80]
        user_info_size = [450, 80]
    # Plots ------------------------------
    # scatter plot
    global non_selected_points_alpha
    global selected_points_size
    global non_selected_points_size
    global update_old_selected_switch
    global previously_selected_spike_indices

    tsne_fig_tools = "pan,wheel_zoom,box_zoom,box_select,lasso_select,tap,resize,reset,save"
    tsne_figure = figure(tools=tsne_fig_tools, plot_width=tsne_figure_size[0], plot_height=tsne_figure_size[1],
                         title='T-sne', min_border=10, min_border_left=tsne_min_border_left, webgl=True)

    tsne_source = ColumnDataSource({'tsne-x': tsne[0], 'tsne-y': tsne[1]})

    tsne_selected_points_glyph = Circle(x='tsne-x', y='tsne-y', size=selected_points_size,
                                        line_alpha=0, fill_alpha=1, fill_color='red')
    tsne_nonselected_points_glyph = Circle(x='tsne-x', y='tsne-y', size=non_selected_points_size,
                                           line_alpha=0, fill_alpha=non_selected_points_alpha, fill_color='blue')
    tsne_invisible_points_glyph = Circle(x='tsne-x', y='tsne-y', size=selected_points_size, line_alpha=0, fill_alpha=0)

    tsne_nonselected_glyph_renderer = tsne_figure.add_glyph(tsne_source, tsne_nonselected_points_glyph,
                                                            selection_glyph=tsne_invisible_points_glyph,
                                                            nonselection_glyph=tsne_nonselected_points_glyph,
                                                            name='tsne_nonselected_glyph_renderer')
        # note: the invisible glyph is required to be able to change the size of the selected points, since the
        # use of selection_glyph is usefull only for colors and alphas
    tsne_invinsible_glyph_renderer = tsne_figure.add_glyph(tsne_source, tsne_invisible_points_glyph,
                                                           selection_glyph=tsne_selected_points_glyph,
                                                           nonselection_glyph=tsne_invisible_points_glyph,
                                                           name='tsne_invinsible_glyph_renderer')


    tsne_figure.select(BoxSelectTool).select_every_mousemove = False
    tsne_figure.select(LassoSelectTool).select_every_mousemove = False


    def on_tsne_data_update(attr, old, new):
        global previously_selected_spike_indices
        global currently_selected_spike_indices
        global non_selected_points_alpha
        global non_selected_points_size
        global selected_points_size
        global checkbox_find_clusters_of_selected_points

        previously_selected_spike_indices = np.array(old['1d']['indices'])
        currently_selected_spike_indices = np.array(new['1d']['indices'])
        num_of_selected_spikes = len(currently_selected_spike_indices)

        if num_of_selected_spikes > 0:
            if verbose:
                print('Num of selected spikes = ' + str(num_of_selected_spikes))

            # update t-sne plot
            tsne_invisible_points_glyph.size = selected_points_size
            tsne_nonselected_points_glyph.size = non_selected_points_size
            tsne_nonselected_points_glyph.fill_alpha = non_selected_points_alpha

            # update spike plot
            avg_x = np.mean(cut_extracellular_data[:, :, currently_selected_spike_indices], axis=2)
            spike_mline_plot.data_source.data['ys'] = avg_x.tolist()
            print('Finished avg spike plot')

            # update autocorelogram
            diffs, norm = crosscorrelate_spike_trains(all_extra_spike_times[currently_selected_spike_indices].astype(np.int64),
                                                      all_extra_spike_times[currently_selected_spike_indices].astype(np.int64), lag=1500)
            hist, edges = np.histogram(diffs, bins=autocor_bin_number)
            hist_plot.data_source.data["top"] = hist
            hist_plot.data_source.data["left"] = edges[:-1] / sampling_freq
            hist_plot.data_source.data["right"] = edges[1:] / sampling_freq
            print('finished autocorelogram')

            # update heatmap
            if prb_file is not None:
                print('Doing heatmap')
                data = cut_extracellular_data[:, :, currently_selected_spike_indices]
                final_image, (x_size, y_size) = spike_heatmap.create_heatmap(data, prb_file, rotate_90=True,
                                                                             flip_ud=True, flip_lr=False)
                new_image_data = dict(image=[final_image], x=[0], y=[0], dw=[x_size], dh=[y_size])
                heatmap_data_source.data.update(new_image_data)
                print('Finished heatmap')

    tsne_source.on_change('selected', on_tsne_data_update)

    # spike plot
    spike_fig_tools = 'pan,wheel_zoom,box_zoom,reset,save'
    spike_figure = figure(toolbar_location='below', plot_width=spike_figure_size[0], plot_height=spike_figure_size[1],
                          tools=spike_fig_tools, title='Spike average', min_border=10, webgl=True, toolbar_sticky=False)

    num_of_channels = cut_extracellular_data.shape[0]
    num_of_time_points = cut_extracellular_data.shape[1]
    xs = np.repeat(np.expand_dims(time_axis, axis=0), repeats=num_of_channels, axis=0).tolist()
    ys = np.ones((num_of_channels, num_of_time_points)).tolist()
    spike_mline_plot = spike_figure.multi_line(xs=xs, ys=ys)

    # autocorelogram plot
    hist, edges = np.histogram([], bins=autocor_bin_number)
    hist_fig_tools = 'pan,wheel_zoom,box_zoom,save,reset'

    hist_figure = figure(toolbar_location='below', plot_width=hist_figure_size[0], plot_height=hist_figure_size[1],
                         tools=hist_fig_tools, title='Autocorrelogram', min_border=10, webgl=True, toolbar_sticky=False)
    hist_plot = hist_figure.quad(bottom=0, left=edges[:-1], right=edges[1:], top=hist, color="#3A5785", alpha=0.5,
                                 line_color="#3A5785")
    # heatmap plot
    heatmap_plot = figure(toolbar_location='right', plot_width=1, plot_height=heatmap_plot_size[1],
                          x_range=(0, 1), y_range=(0, 1), title='Probe heatmap',
                          toolbar_sticky=False)
    if prb_file is not None:
        data = np.zeros(cut_extracellular_data.shape)
        final_image, (x_size, y_size) = spike_heatmap.create_heatmap(data, prb_file, rotate_90=True,
                                                                     flip_ud=True, flip_lr=False)
        final_image[:, :, ] = 4294967295  # The int32 for the int8 255 (white)
        plot_width = max(heatmap_plot_size[0], int(heatmap_plot_size[1] * y_size / x_size))
        heatmap_plot = figure(toolbar_location='right', plot_width=plot_width, plot_height=heatmap_plot_size[1],
                              x_range=(0, x_size), y_range=(0, y_size), title='Probe heatmap',
                              toolbar_sticky=False)

        heatmap_data_source = ColumnDataSource(data=dict(image=[final_image], x=[0], y=[0], dw=[x_size], dh=[y_size]))
        heatmap_renderer = heatmap_plot.image_rgba(source=heatmap_data_source, image='image', x='x', y='y',
                                                   dw='dw', dh='dh', dilate=False)
        heatmap_plot.axis.visible = None
        heatmap_plot.xgrid.grid_line_color = None
        heatmap_plot.ygrid.grid_line_color = None
    # ---------------------------------------
    # --------------- CONTROLS --------------
    # Texts and Tables
    # the clusters DataTable
    if use_existing_cluster:
        cluster_info = load_cluster_info(cluster_info_file)
    else:
        cluster_info = create_new_cluster_info_file(cluster_info_file, len(tsne))
    cluster_info_data_source = ColumnDataSource(cluster_info)
    clusters_columns = [TableColumn(field='Cluster', title='Clusters'),
                        TableColumn(field='Num_of_Spikes', title='Number of Spikes')]
    clusters_table = DataTable(source=cluster_info_data_source, columns=clusters_columns, selectable=True,
                               editable=False, width=clusters_table_size[0], height=clusters_table_size[1],
                               scroll_to_selection=True)

    def on_select_cluster_info_table(attr, old, new):
        global selected_cluster_names
        cluster_info = load_cluster_info(cluster_info_file)
        indices = list(chain.from_iterable(cluster_info.iloc[new['1d']['indices']].Spike_Indices.tolist()))
        selected_cluster_names = list(cluster_info.index[new['1d']['indices']])
        old = new = tsne_source.selected
        tsne_source.selected['1d']['indices'] = indices
        tsne_source.trigger('selected', old, new)
        user_info_edit.value = 'Selected clusters = ' + ', '.join(selected_cluster_names)

    cluster_info_data_source.on_change('selected', on_select_cluster_info_table)

    def update_data_table():
        cluster_info_data_source = ColumnDataSource(load_cluster_info(cluster_info_file))
        cluster_info_data_source.on_change('selected', on_select_cluster_info_table)
        clusters_table.source = cluster_info_data_source
        options = list(cluster_info_data_source.data['Cluster'])
        options.insert(0, 'No cluster selected')
        select_cluster_to_move_points_to.options = options

    # cluster TextBox that adds cluster to the DataTable
    new_cluster_name_edit = TextInput(value='give the new cluster a name',
                                      title='Put selected points into a new cluster')

    def on_text_edit_new_cluster_name(attr, old, new):
        global currently_selected_spike_indices
        global clusters_of_all_spikes

        new_cluster_name = new_cluster_name_edit.value

        spike_indices_to_delete_from_existing_clusters = {}
        for spike_index in currently_selected_spike_indices:
            if clusters_of_all_spikes[spike_index] != -1:
                cluster_index = clusters_of_all_spikes[spike_index]
                if cluster_index not in spike_indices_to_delete_from_existing_clusters:
                    spike_indices_to_delete_from_existing_clusters[cluster_index] = [spike_index]
                else:
                    spike_indices_to_delete_from_existing_clusters[cluster_index].append(spike_index)
        cluster_info = load_cluster_info(cluster_info_file)
        for cluster_index in spike_indices_to_delete_from_existing_clusters.keys():
            cluster_name = cluster_info.iloc[cluster_index].name
            remove_spikes_from_cluster(cluster_info_file, cluster_name,
                                       spike_indices_to_delete_from_existing_clusters[cluster_index], unassign=False)

        add_cluster_to_cluster_info(cluster_info_file, new_cluster_name, currently_selected_spike_indices)

        update_data_table()

    new_cluster_name_edit.on_change('value', on_text_edit_new_cluster_name)

    # user information Text
    user_info_edit = TextInput(value='', title='User information',
                               width=user_info_size[0], height=user_info_size[1])

    # Buttons ------------------------
    # show all clusters Button
    button_show_all_clusters = Toggle(label='Show all clusters', button_type='primary')

    def on_button_show_all_clusters(state, *args):
        global tsne_clusters_scatter_plot

        if state:
            cluster_info = load_cluster_info(cluster_info_file)
            num_of_clusters = cluster_info.shape[0]
            indices_list_of_lists = cluster_info['Spike_Indices'].tolist()
            indices = [item for sublist in indices_list_of_lists for item in sublist]
            cluster_indices = np.arange(num_of_clusters)

            if verbose:
                print('Showing all clusters in colors... wait for it...')

            colors = []
            for c in cluster_indices:
                r = np.random.random(size=1) * 255
                g = np.random.random(size=1) * 255
                for i in np.arange(len(indices_list_of_lists[c])):
                    colors.append("#%02x%02x%02x" % (int(r), int(g), 50))

            first_time = True
            for renderer in tsne_figure.renderers:
                if renderer.name == 'tsne_all_clusters_glyph_renderer':
                    renderer.data_source.data['fill_color'] = renderer.data_source.data['line_color'] = colors
                    renderer.glyph.fill_color = 'fill_color'
                    renderer.glyph.line_color = 'line_color'
                    first_time = False
                    break
            if first_time:
                tsne_clusters_scatter_plot = tsne_figure.scatter(tsne[0][indices], tsne[1][indices], size=1,
                                                                 color=colors, alpha=1,
                                                                 name='tsne_all_clusters_glyph_renderer')
            tsne_clusters_scatter_plot.visible = True
            button_show_all_clusters.label = 'Hide all clusters'
        else:
            if verbose:
                print('Hiding clusters')
            button_show_all_clusters.update()
            tsne_clusters_scatter_plot.visible = False
            button_show_all_clusters.label = 'Show all clusters'

    button_show_all_clusters.on_click(on_button_show_all_clusters)


    # select the clusters that the selected points belong to Button
    # (that will then drive the selection of these spikes on t-sne through the update of the clusters_table source)
    button_show_clusters_of_selected_points = Button(label='Show clusters of selected points')

    def on_button_show_clusters_change():
        print('Hello')
        global clusters_of_all_spikes
        currently_selected_spike_indices = tsne_source.selected['1d']['indices']
        cluster_info = load_cluster_info(cluster_info_file)
        clusters_selected = []
        new_indices_to_select = []
        update_data_table()
        for spike_index in currently_selected_spike_indices:
            if clusters_of_all_spikes[spike_index] not in clusters_selected:
                clusters_selected.append(clusters_of_all_spikes[spike_index])
                indices_in_cluster = cluster_info.iloc[clusters_of_all_spikes[spike_index]].Spike_Indices
                new_indices_to_select.append(indices_in_cluster)
        if len(new_indices_to_select) > 0:
            old = clusters_table.source.selected
            clusters_table.source.selected['1d']['indices'] = clusters_selected
            new = clusters_table.source.selected
            clusters_table.source.trigger('selected', old, new)
            for c in np.arange(len(clusters_selected)):
                clusters_selected[c] = cluster_info.index[clusters_selected[c]]


    button_show_clusters_of_selected_points.on_click(on_button_show_clusters_change)

    # merge clusters Button
    button_merge_clusters_of_selected_points = Button(label='Merge clusters of selected points')

    def on_button_merge_clusters_change():
        global clusters_of_all_spikes
        currently_selected_spike_indices = tsne_source.selected['1d']['indices']
        cluster_info = load_cluster_info(cluster_info_file)
        clusters_selected = []
        for spike_index in currently_selected_spike_indices:
            if clusters_of_all_spikes[spike_index] not in clusters_selected:
                clusters_selected.append(clusters_of_all_spikes[spike_index])
        if len(clusters_selected) > 0:
            clusters_selected = np.sort(clusters_selected)
            clusters_selected_names = []
            for cluster_index in clusters_selected:
                clusters_selected_names.append(cluster_info.iloc[cluster_index].name)
            cluster_name = clusters_selected_names[0]
            add_cluster_to_cluster_info(cluster_info_file, cluster_name, currently_selected_spike_indices)
            i = 0
            for c in np.arange(1, len(clusters_selected)):
                cluster_info = remove_cluster_from_cluster_info(cluster_info_file,
                                                                cluster_info.iloc[clusters_selected[c] - i].name,
                                                                unassign=False)
                i = i + 1 # Every time you remove a cluster the original index of the remaining clusters drops by one

            update_data_table()
            user_info_edit.value = 'Clusters '+ ', '.join(clusters_selected_names) + ' merged to cluster ' + cluster_name

    button_merge_clusters_of_selected_points.on_click(on_button_merge_clusters_change)

    # delete cluster Button
    button_delete_cluster = Button(label='Delete selected cluster(s)')

    def on_button_delete_cluster():
        global selected_cluster_names
        for cluster_name in selected_cluster_names:
            remove_cluster_from_cluster_info(cluster_info_file, cluster_name)
        user_info_edit.value = 'Deleted clusters: ' + ', '.join(selected_cluster_names)
        update_data_table()

    button_delete_cluster.on_click(on_button_delete_cluster)

    # select cluster to move selected points to Select
    select_cluster_to_move_points_to = Select(title="Assign selected points to cluster:", value="No cluster selected")

    options = list(cluster_info_data_source.data['Cluster'])
    options.insert(0, 'No cluster selected')
    select_cluster_to_move_points_to.options = options


    def move_selected_points_to_cluster(attr, old, new):
        global currently_selected_spike_indices
        if len(currently_selected_spike_indices) > 0 and new is not 'No cluster selected':
            remove_spikes_from_all_clusters(cluster_info_file, currently_selected_spike_indices)
            add_spikes_to_cluster(cluster_info_file, new, currently_selected_spike_indices)
            update_data_table()
            select_cluster_to_move_points_to.value = 'No cluster selected'
            user_info_edit.value = 'Selected clusters = ' + new

    select_cluster_to_move_points_to.on_change('value', move_selected_points_to_cluster)


    # undo selection button
    undo_selected_points_button = Button(label='Undo last selection')

    def on_button_undo_selection():
        global previously_selected_spike_indices
        tsne_source.selected['1d']['indices'] = previously_selected_spike_indices
        old = new = tsne_source.selected
        tsne_source.trigger('selected', old, new)

    undo_selected_points_button.on_click(on_button_undo_selection)

    # Sliders -------------------
    # use the fake data trick to call the callback only when the mouse is released (mouseup only works for CustomJS)

    # change visibility of non selected points Slider
    slider_non_selected_visibility = Slider(start=0, end=1, value=0.2, step=.02, callback_policy='mouseup',
                                            title='Alpha of not selected points',
                                            width=slider_size[0], height=slider_size[1])

    def on_slider_change_non_selected_visibility(attrname, old, new):
        global non_selected_points_alpha
        if len(source_fake_nsv.data['value']) > 0:
            non_selected_points_alpha = source_fake_nsv.data['value'][0]
            old = new = tsne_source.selected
            tsne_source.trigger('selected', old, new)

    source_fake_nsv = ColumnDataSource(data=dict(value=[]))
    source_fake_nsv.on_change('data', on_slider_change_non_selected_visibility)

    slider_non_selected_visibility.callback = CustomJS(args=dict(source=source_fake_nsv),
                                                       code="""
                                                            source.data = { value: [cb_obj.value] }
                                                            """)

    # change size of non selected points Slider
    slider_non_selected_size = Slider(start=0.5, end=10, value=2, step=0.5, callback_policy='mouseup',
                                      title='Size of not selected points',
                                      width=slider_size[0], height=slider_size[1])

    def on_slider_change_non_selected_size(attrname, old, new):
        """Store the slider value as the size of non-selected points.

        Runs when ``source_fake_nss.data`` changes. The first entry of its
        'value' column is copied into the global ``non_selected_points_size``
        and the 'selected' event of ``tsne_source`` is fired again.

        Args:
            attrname: Name of the changed property (unused).
            old: Previous property value (unused).
            new: New property value (unused; the value is read from the source).
        """
        global non_selected_points_size
        pushed_values = source_fake_nss.data['value']
        if len(pushed_values) == 0:
            return
        non_selected_points_size = pushed_values[0]
        current_selection = tsne_source.selected
        tsne_source.trigger('selected', current_selection, current_selection)

    source_fake_nss = ColumnDataSource(data=dict(value=[]))
    source_fake_nss.on_change('data', on_slider_change_non_selected_size)

    slider_non_selected_size.callback = CustomJS(args=dict(source=source_fake_nss),
                                                 code="""
                                                      source.data = { value: [cb_obj.value] }
                                                      """)

    # change size of selected points Slider
    slider_selected_size = Slider(start=0.5, end=10, value=2, step=0.5, callback_policy='mouseup',
                                  title='Size of selected points',
                                  width=slider_size[0], height=slider_size[1])

    def on_slider_change_selected_size(attrname, old, new):
        """Store the slider value as the size of selected points.

        Runs when ``source_fake_ss.data`` changes. The first entry of its
        'value' column is copied into the global ``selected_points_size`` and
        the 'selected' event of ``tsne_source`` is fired again.

        Args:
            attrname: Name of the changed property (unused).
            old: Previous property value (unused).
            new: New property value (unused; the value is read from the source).
        """
        global selected_points_size
        pushed_values = source_fake_ss.data['value']
        if len(pushed_values) == 0:
            return
        selected_points_size = pushed_values[0]
        current_selection = tsne_source.selected
        tsne_source.trigger('selected', current_selection, current_selection)

    source_fake_ss = ColumnDataSource(data=dict(value=[]))
    source_fake_ss.on_change('data', on_slider_change_selected_size)

    slider_selected_size.callback = CustomJS(args=dict(source=source_fake_ss),
                                             code="""
                                                  source.data = { value: [cb_obj.value] }
                                                  """)

    # -------------------------------------------

    # Layout and session setup ------------------
    # align and make layout
    spike_figure.min_border_top = 50
    spike_figure.min_border_right = 10
    hist_figure.min_border_top = 50
    hist_figure.min_border_left = 10
    tsne_figure.min_border_right = 50

    if k4:
        lay = row(column(tsne_figure,
                         row(slider_non_selected_visibility, slider_non_selected_size, slider_selected_size),
                         row(spike_figure, hist_figure),
                         user_info_edit),
                 column(clusters_table,
                        button_show_clusters_of_selected_points,
                        button_merge_clusters_of_selected_points,
                        button_delete_cluster,
                        select_cluster_to_move_points_to,
                        new_cluster_name_edit,
                        button_show_all_clusters,
                        undo_selected_points_button,
                        heatmap_plot))
    else:
        lay = row(column(tsne_figure,
                         row(spike_figure, hist_figure)),
                  column(row(heatmap_plot, column(slider_non_selected_visibility,
                                                  slider_non_selected_size,
                                                  slider_selected_size)),
                         user_info_edit),
                  column(clusters_table,
                         button_show_clusters_of_selected_points,
                         button_merge_clusters_of_selected_points,
                         button_delete_cluster,
                         select_cluster_to_move_points_to,
                         new_cluster_name_edit,
                         button_show_all_clusters,
                         undo_selected_points_button))


    session = push_session(curdoc())
    session.show(lay)  # open the document in a browser
    session.loop_until_closed()  # run forever, requires stopping the interpreter in order to stop :)
def tab_analysis(csv):

    csv_original = csv

    g = csv_original.columns.to_series().groupby(csv_original.dtypes).groups
    g_list = list(g.keys())

    t = Figure()

    def convert(val, target):
        """Cast *target* to the kind of type that *val* has.

        The type of *val* is recognised by looking for 'float', 'int' or
        'str' (in that order) inside ``str(type(val))``.

        Args:
            val: Sample value whose type decides the cast.
            target: Value to cast.

        Returns:
            ``float(target)``, ``int(target)`` or ``str(target)``; ``None``
            when none of the three markers occurs in the type name.
        """
        type_name = str(type(val))

        # Order matters: the first marker found in the type name wins.
        for marker, caster in (('float', float), ('int', int), ('str', str)):
            if marker in type_name:
                return caster(target)

        return None

    box_figure = figure(tools="save",
                        background_fill_color="#EFE8E2",
                        title="Box",
                        plot_width=500,
                        plot_height=500,
                        toolbar_location="below",
                        x_range=[])
    box_figure.add_tools(WheelZoomTool())
    box_figure.add_tools(PanTool())

    corr_figure = figure(plot_width=500,
                         plot_height=500,
                         title="Correlation",
                         toolbar_location=None,
                         tools="",
                         x_axis_location="above",
                         x_range=[],
                         y_range=[])

    def make_box_plot(df, param_list):
        """Draw a box-and-whisker glyph set on ``box_figure`` for each column.

        Whiskers sit 1.5 * IQR beyond the quartiles, clipped to the observed
        minimum / maximum. Values beyond the unclipped fences are drawn as
        outlier circles.

        Every call adds new glyphs to ``box_figure``; glyphs added by earlier
        calls are not removed here.

        Args:
            df: DataFrame holding the columns to summarise.
            param_list: Column names; each one becomes a category on the
                x axis, in the given order.
        """
        cats = list(param_list)
        if not cats:
            # Nothing selected: there is no category to compute or draw.
            return

        # Long format: one row per observation, tagged with its column name.
        # Building every piece first and concatenating once keeps the numeric
        # dtype of 'value'.
        df_box = pd.concat(
            [pd.DataFrame({'group': col, 'value': df[col].values})
             for col in cats],
            ignore_index=True)

        values = df_box.groupby('group')['value']

        def stat(q):
            # groupby returns its keys sorted; put them back in axis order so
            # that position i of every statistic belongs to cats[i].
            return values.quantile(q=q).reindex(cats)

        q1 = stat(0.25)
        q2 = stat(0.5)
        q3 = stat(0.75)
        qmin = stat(0.00)
        qmax = stat(1.00)
        iqr = q3 - q1
        upper_fence = q3 + 1.5 * iqr
        lower_fence = q1 - 1.5 * iqr

        # Outliers: compare every observation with the fences of its own
        # category. NaN observations compare False and are therefore skipped.
        row_upper = df_box['group'].map(upper_fence)
        row_lower = df_box['group'].map(lower_fence)
        is_outlier = (df_box['value'] > row_upper) | (df_box['value'] < row_lower)
        outx = df_box.loc[is_outlier, 'group'].tolist()
        outy = df_box.loc[is_outlier, 'value'].tolist()

        box_figure.x_range.factors = cats

        # Shrink the stems so they are no longer than the observed extremes.
        upper = [min(pair) for pair in zip(qmax, upper_fence)]
        lower = [max(pair) for pair in zip(qmin, lower_fence)]

        q1_values = q1.tolist()
        q2_values = q2.tolist()
        q3_values = q3.tolist()

        # stems
        box_figure.segment(cats, upper, cats, q3_values, line_color="black")
        box_figure.segment(cats, lower, cats, q1_values, line_color="black")

        # boxes
        box_figure.vbar(cats,
                        0.7,
                        q2_values,
                        q3_values,
                        fill_color="#E08E79",
                        line_color="black")
        box_figure.vbar(cats,
                        0.7,
                        q1_values,
                        q2_values,
                        fill_color="#3B8686",
                        line_color="black")

        # whiskers (almost-0 height rects simpler than segments)
        box_figure.rect(cats, lower, 0.2, 0.01, line_color="black")
        box_figure.rect(cats, upper, 0.2, 0.01, line_color="black")

        # outliers
        if outx:
            box_figure.circle(outx,
                              outy,
                              size=6,
                              color="#F38630",
                              fill_alpha=0.6)

        box_figure.xgrid.grid_line_color = None
        box_figure.ygrid.grid_line_color = "white"
        box_figure.grid.grid_line_width = 2
        box_figure.xaxis.major_label_text_font_size = "12pt"

    def make_correlation_plot(df, param_list):
        """Draw the pairwise correlation heat map of *param_list* on ``corr_figure``.

        Each call adds a rect glyph, a hover tool and a colour bar to
        ``corr_figure``.

        Args:
            df: DataFrame holding the columns to correlate.
            param_list: Column names to include in the correlation matrix.
        """
        # Long format: one row per (level_0, level_1) pair with its
        # coefficient; undefined coefficients are shown as 0.
        corr_matrix = df[param_list].corr().fillna(0)
        df_corr = corr_matrix.stack().rename("value").reset_index()

        print(df_corr)

        # The colour scale spans the full correlation range [-1, 1].
        mapper = LinearColorMapper(palette=RdBu[11], low=-1, high=1)

        corr_figure.x_range.factors = list(df_corr["level_0"].drop_duplicates())
        corr_figure.y_range.factors = list(df_corr["level_1"].drop_duplicates())

        # One unit square per pair, coloured by the coefficient.
        corr_figure.rect(x="level_0",
                         y="level_1",
                         width=1,
                         height=1,
                         source=ColumnDataSource(df_corr),
                         line_color=None,
                         fill_color=transform('value', mapper))
        corr_figure.add_tools(HoverTool(tooltips=[
            ("Corr", "@value"),
        ]))

        # Legend for the colour scale.
        corr_figure.add_layout(ColorBar(color_mapper=mapper, location=(0, 0)),
                               'left')

    box_cor_x = MultiSelect(title="Predictor")
    button_box_corr = Button(label="Analysis response")

    def box_corr_handler():
        """Redraw the box plot and the correlation plot for the chosen predictors."""
        chosen_columns = box_cor_x.value

        for draw in (make_box_plot, make_correlation_plot):
            draw(csv_original, chosen_columns)

    button_box_corr.on_click(box_corr_handler)

    param_key = MultiSelect(title="Separator(Maximum 2)")
    param_key.options = list(csv_original.columns)

    param_x = MultiSelect(title="Predictor")
    param_x.options = list(csv_original.columns)

    param_y = Select(title="Response")
    param_y.options = list(csv_original.columns)

    button_set = Button(label="Set parameter")

    key1 = MultiSelect(title="Key 1")
    key2 = MultiSelect(title="Key 2")

    target_x = Select(title="Sensor")
    #show_option = RadioGroup(labels=["Raw", "Moving average"], active=0)
    show_option = CheckboxGroup(labels=["Raw", "Moving average"],
                                active=[0, 1])
    average_select = Slider(start=2,
                            end=30,
                            value=5,
                            step=1,
                            title='Average window')

    # 3rd row
    target_reduction = MultiSelect(title="Target for dimension reduction")
    reduction_method = Select(title="Dimension reduction",
                              options=["PCA", "Autoencoder"])
    button_reduction = Button(label="Show result")
    figure_reduction = figure(tools="save, lasso_select",
                              title="Dimension reduction result",
                              plot_width=500,
                              plot_height=500,
                              toolbar_location="below")

    src = ColumnDataSource(data=dict(x=[], y=[], time=[]))
    # color_mapper = LinearColorMapper(palette='Viridis256', low=min(csv_original['group_index'].values), high=max(csv_original['group_index'].values))
    color_mapper = LinearColorMapper(palette='Viridis256', low=0, high=1000)
    figure_reduction.circle('x',
                            'y',
                            source=src,
                            size=5,
                            color={
                                'field': 'time',
                                'transform': color_mapper
                            })
    TOOLTIPS = [("(x,y)", "($x, $y)"), ("Time", "@time")]
    figure_reduction.add_tools(HoverTool(tooltips=TOOLTIPS))

    src_reduction = ColumnDataSource(
        data=dict(center_x=[0], center_y=[0], radius=[0]))
    figure_reduction.circle("center_x",
                            "center_y",
                            radius="radius",
                            source=src_reduction,
                            alpha=0.3)

    def set_handler():
        """Apply the separator / predictor choice to the dependent widgets.

        Fills the key selectors with the unique values of the separator
        columns, writes a 1-based 'group_index' column into ``csv_original``
        and offers every chosen predictor column whose standard deviation is
        not <= 0 in the sensor, reduction and box/correlation selectors.
        """
        separators = param_key.value

        # Widget options have to be strings.
        if (len(separators) >= 1):
            key1.options = [
                str(item) for item in csv_original[separators[0]].unique()
            ]

        if (len(separators) >= 2):
            key2.options = [
                str(item) for item in csv_original[separators[1]].unique()
            ]

        if (len(separators) == 0):
            # No separator: number the rows of the whole frame.
            csv_original['group_index'] = list(
                range(1, csv_original.shape[0] + 1))
        else:
            # Index per group -> consider it as time flow
            csv_original['group_index'] = csv_original.groupby(
                separators).cumcount() + 1

        # Skip constant columns. The test is written as "not <=" so that a NaN
        # standard deviation still keeps the column.
        usable_columns = [
            col for col in param_x.value
            if not (csv_original[col].std() <= 0.0)
        ]

        target_x.options = usable_columns
        target_reduction.options = usable_columns
        box_cor_x.options = usable_columns

    button_set.on_click(set_handler)

    figure_multi_line = figure(tools="save",
                               title="Sensor value per key",
                               plot_width=1000,
                               plot_height=500,
                               toolbar_location="below")
    src1 = ColumnDataSource()
    button_sensor = Button(label="Show values")

    def sensor_hander():
        """Plot the chosen sensor column, one line per selected key combination.

        Depending on ``show_option`` a raw line (width 1) and/or a
        rolling-mean line (width 3, window taken from ``average_select``) is
        produced for:

        * the whole frame when no separator key is set,
        * every value selected in ``key1`` when one separator key is set,
        * every (``key1``, ``key2``) pair when two separator keys are set.

        With more than two separator keys no line is produced.

        NOTE: every call adds another multi_line renderer and another
        HoverTool to ``figure_multi_line``; earlier ones are not removed here.
        """
        xs = []
        ys = []
        label_key = []
        line_width = []
        rolling_mean = int(average_select.value)
        show_raw = 0 in show_option.active
        show_average = 1 in show_option.active

        if not target_x.value:
            if not target_x.options:
                # No sensor is available yet, so there is nothing to plot.
                return
            target_x.value = target_x.options[0]
        sensor = target_x.value

        def add_lines(series, label):
            """Append the raw and/or averaged line for *series*."""
            if show_raw:
                y = series.values
                xs.append(np.arange(y.shape[0]))
                ys.append(y)
                label_key.append(label)
                line_width.append(1)

            if show_average:
                y = series.rolling(window=rolling_mean).mean().ffill().values
                xs.append(np.arange(y.shape[0]))
                ys.append(y)
                label_key.append(label)
                line_width.append(3)

        def typed_selection(column, chosen):
            """Cast the selected option strings back to the type of *column*.

            The selector options are strings, so they must be converted
            before they are compared with the column values; otherwise a
            numeric key column never matches.
            """
            if csv_original.empty:
                return []
            sample = csv_original[column].iloc[0]
            return [convert(sample, option) for option in chosen]

        separators = param_key.value

        if (len(separators) == 0):
            add_lines(csv_original[sensor], str(sensor))

        elif (len(separators) == 1):
            column = separators[0]
            for group in typed_selection(column, key1.value):
                rows = csv_original[csv_original[column] == group]
                add_lines(rows[sensor], str(group))

        elif (len(separators) == 2):
            column1, column2 = separators[0], separators[1]
            groups2 = typed_selection(column2, key2.value)
            for group1 in typed_selection(column1, key1.value):
                in_group1 = csv_original[column1] == group1
                for group2 in groups2:
                    in_group2 = csv_original[column2] == group2
                    rows = csv_original[in_group1 & in_group2]
                    add_lines(rows[sensor],
                              str(group1) + " / " + str(group2))

        # One colour index per line, spread over the whole palette.
        color_all = list(range(len(xs)))
        src1.data = dict(xs=xs,
                         ys=ys,
                         label=label_key,
                         color_all=color_all,
                         line_width=line_width)

        figure_multi_line.multi_line('xs',
                                     'ys',
                                     legend='label',
                                     source=src1,
                                     color=linear_cmap('color_all',
                                                       "Viridis256", 0,
                                                       len(xs) - 1),
                                     line_width="line_width")
        TOOLTIPS = [
            ("Keys", "@label"),
        ]
        figure_multi_line.add_tools(HoverTool(tooltips=TOOLTIPS))

    button_sensor.on_click(sensor_hander)

    def reduction_handler():
        """Project the selected columns to 2-D and show the result.

        Uses PCA when ``reduction_method`` is empty or "PCA", otherwise
        ``ae.auto_encoder``. Both output axes are min-max scaled and then

        * pushed to ``src`` together with the 'group_index' column as 'time',
        * stored in ``csv_original`` as 'reduction_1' / 'reduction_2'.

        The colour range is set to the span of 'group_index', the selection
        circle in ``src_reduction`` is reset, and the two classification
        label columns are reset to NaN.
        """
        selected_columns = target_reduction.value
        if not selected_columns:
            # Nothing to project.
            return

        features = csv_original[selected_columns].values

        if (reduction_method.value == "" or reduction_method.value == "PCA"):
            embedding = decomposition.PCA(n_components=2).fit_transform(
                features)
        else:
            embedding = ae.auto_encoder(features, 2, epoch=50)

        # The scaler works on column vectors; flatten afterwards so the data
        # source and the DataFrame columns hold plain 1-D arrays.
        scaler = preprocessing.MinMaxScaler()
        r_x = scaler.fit_transform(
            X=np.expand_dims(embedding[:, 0], -1)).ravel()
        r_y = scaler.fit_transform(
            X=np.expand_dims(embedding[:, 1], -1)).ravel()

        group_index = csv_original['group_index'].values

        src.data = dict(x=r_x, y=r_y, time=group_index)

        csv_original['reduction_1'] = r_x
        csv_original['reduction_2'] = r_y

        color_mapper.low = min(group_index)
        color_mapper.high = max(group_index)
        src_reduction.data = dict(center_x=[0], center_y=[0], radius=[0])

        # Reset the label columns: labels set earlier refer to the previous
        # projection.
        csv_original['Classification_cutPoint'] = np.nan
        csv_original['Classification_manualSelect'] = np.nan

    button_reduction.on_click(reduction_handler)

    set_y_select = Dropdown(label="Select Y",
                            menu=[
                                ("Linear", "item_1"), ("Weillibul", "item_2"),
                                ("Piecewise", "item_3"), None,
                                ("Dimension-reduction based select", "item_4")
                            ])
    new_y_setter_piecewise = TextInput(value="150",
                                       title="Piecewise cut point")
    new_y_setter_x = TextInput(value="", title="Regression center X")
    new_y_setter_y = TextInput(value="", title="Regression center Y")
    new_y_setter_r = TextInput(value="",
                               title="Radius(out-of-bound value equal 0)")
    button_new_y = Button(label="Set new regression label")

    def set_y_handler():
        """Write a regression target column into ``csv_original``.

        The kind of target is taken from ``set_y_select``:

        * "item_1": 'Regression_linear', a countdown to 0 within each group.
        * "item_2": 'Regression_weillibul', a Weibull-shaped decay within
          each group.
        * "item_3": 'Regression_piecewise', the countdown capped at the value
          entered in ``new_y_setter_piecewise``.
        * "item_4": 'Regression_manualSelect', 1 at the entered centre of the
          reduction plot falling linearly to 0 at the entered radius.

        Groups are the combinations of the separator columns in
        ``param_key``; without separators the whole frame is one group.
        The outcome is reported in ``message_y``.
        """
        def y_linear(length):
            """Return length - 1, ..., 1, 0."""
            return list(range(length - 1, -1, -1))

        def y_weillibul(length, k=1.5, lmd=0.00002):
            """Return *length* samples of exp(-lmd * t**k), rescaled.

            The curve is tabulated on 0..1000, sampled at *length* evenly
            spaced positions and rescaled so that it starts at *length* and
            would reach 0 at the end of the table.
            """
            dist_length = 1000
            y_ = [
                math.pow(math.e, -lmd * math.pow(i, k))
                for i in range(dist_length + 1)
            ]
            y_end = y_[dist_length]

            return [(y_[int(
                (dist_length / length) * i)] - y_end) / (1.0 - y_end) * length
                    for i in range(length)]

        def y_piecewise(length, const=130):
            """Return the linear countdown capped at *const*."""
            return [min(remaining, const) for remaining in y_linear(length)]

        def labels_per_group(make_labels):
            """Label every row by its position inside its own group.

            Writing through the row positions of each group keeps the labels
            correct even when the rows of a group are not contiguous.
            """
            n_rows = csv_original.shape[0]
            separators = list(param_key.value)
            if separators:
                row_sets = csv_original.groupby(separators).indices.values()
            else:
                row_sets = [np.arange(n_rows)]

            # Rows that belong to no group (NaN key) keep NaN.
            labels = [np.nan] * n_rows
            for rows in row_sets:
                for row, label in zip(rows, make_labels(len(rows))):
                    labels[row] = label
            return labels

        choice = set_y_select.value

        if (choice == "item_1"):
            csv_original['Regression_linear'] = labels_per_group(y_linear)
            message_y.text = "Set Y value: Linear"

        elif (choice == "item_2"):
            csv_original['Regression_weillibul'] = labels_per_group(
                y_weillibul)
            message_y.text = "Set Y value: Weillibul"

        elif (choice == "item_3"):
            try:
                cut_point = int(new_y_setter_piecewise.value)
            except ValueError:
                message_y.text = "Piecewise cut point must be an integer"
                return

            csv_original['Regression_piecewise'] = labels_per_group(
                lambda length: y_piecewise(length, cut_point))
            message_y.text = "Set Y value: Piecewise"

        elif (choice == "item_4"):
            try:
                center_x = float(new_y_setter_x.value)
                center_y = float(new_y_setter_y.value)
                radius = float(new_y_setter_r.value)
            except ValueError:
                message_y.text = "Center X, center Y and radius must be numbers"
                return
            if radius <= 0:
                # The label divides by the radius.
                message_y.text = "Radius must be greater than 0"
                return

            # Show the chosen circle on the reduction figure.
            src_reduction.data = dict(center_x=[center_x],
                                      center_y=[center_y],
                                      radius=[radius])

            points = csv_original[['reduction_1', 'reduction_2']].values
            distance = np.sqrt((center_x - points[:, 0])**2 +
                               (center_y - points[:, 1])**2)
            score = 1.0 - distance / radius
            # Points outside the circle (and undefined scores) get 0.
            csv_original['Regression_manualSelect'] = np.where(
                score >= 0.0, score, 0.0)

            message_y.text = "Set Y value: From reduction"

    button_new_y.on_click(set_y_handler)

    new_class_cut = TextInput(value="", title="Enter class cut point")
    new_class_setter = TextInput(value="", title="Enter class name")

    set_y_clss_select = Dropdown(label="Select Y",
                                 menu=[("Manual cut", "item_1"), None,
                                       ("Dimension-reduction based select",
                                        "item_2")])

    button_add_label = Button(label="Add classification label")
    button_new_class = Button(label="Set classification label")

    labels = []

    def label_adder_handler():
        """Give the points selected in the reduction plot the entered class name.

        Writes ``new_class_setter.value`` into the
        'Classification_manualSelect' column for the selected row positions,
        remembers the name in ``labels`` and shows the list in
        ``label_notifier``.
        """
        indices = list(src.selected['1d']['indices'])

        if (len(indices) == 0):
            label_notifier.text = "Non selected"
            return

        column = 'Classification_manualSelect'
        # The column starts out as float NaN; make it an object column before
        # class names (strings) are written into it.
        csv_original[column] = csv_original[column].astype(object)
        # Write through the frame itself in one step: chained
        # frame[column].iloc[...] assignment may modify a temporary copy.
        csv_original.iloc[indices, csv_original.columns.get_loc(
            column)] = new_class_setter.value

        labels.append(new_class_setter.value)
        label_notifier.text = str(labels)

    button_add_label.on_click(label_adder_handler)

    def label_all_handler():
        """Finalise the classification labels chosen in ``set_y_clss_select``.

        * "item_1" (manual cut): bins the reversed per-group 'group_index'
          with the comma-separated cut points from ``new_class_cut`` and
          stores the bin numbers as strings in 'Classification_cutPoint'.
        * "item_2" (selection based): forward-fills the unlabeled rows of
          'Classification_manualSelect'.

        In both cases the result is reported in ``label_notifier`` and the
        ``labels`` list is emptied.
        """
        if (set_y_clss_select.value == "item_1"):
            try:
                class_cut = [
                    int(cut) for cut in new_class_cut.value.split(',')
                ]
            except ValueError:
                label_notifier.text = (
                    "Class cut points must be comma-separated integers")
                return

            # Reverse the index inside every group so that it counts down
            # towards the end of the group.
            separators = list(param_key.value)
            if separators:
                remaining = csv_original.groupby(
                    separators)['group_index'].transform(lambda x: x[::-1])
            else:
                # Without separators the whole frame is a single group.
                remaining = csv_original['group_index'].values[::-1]

            csv_original['Classification_cutPoint'] = np.digitize(
                remaining, bins=class_cut).astype(str)

            label_notifier.text = "Labeling complete \n\nLabel: " + str(
                class_cut)
            del labels[:]

        elif (set_y_clss_select.value == "item_2"):
            manual = csv_original['Classification_manualSelect']
            if manual.isnull().any():
                # Unlabeled rows inherit the label of the closest labeled row
                # above them.
                csv_original[
                    'Classification_manualSelect'] = manual.ffill().values

            label_notifier.text = "Labeling complete \n\nLabel: " + str(labels)
            del labels[:]

    button_new_class.on_click(label_all_handler)

    head_regression = Div(text=""" <b>Set Regression Label</b> """)
    head_classification = Div(text=""" <b>Set Class Label</b> """)
    label_notifier = Paragraph(text=""" - """)

    button_export = Button(label="Export CSV")

    def handler_export():
        """Write ``csv_original`` to './Export/exported.csv' without the index."""
        import os

        export_path = './Export/exported.csv'
        # to_csv does not create missing directories.
        os.makedirs(os.path.dirname(export_path), exist_ok=True)
        csv_original.to_csv(export_path, index=False)

    button_export.on_click(handler_export)

    message_y = Paragraph(text=""" - """, width=200, height=200)

    layout = Column(
        Row(param_key, param_x, param_y, button_set),
        Row(Column(box_cor_x, button_box_corr), box_figure, corr_figure),
        Row(Column(key1, key2), figure_multi_line,
            Column(target_x, show_option, average_select, button_sensor)),
        Row(
            Column(reduction_method, target_reduction, button_reduction),
            figure_reduction,
            Column(head_regression, set_y_select, new_y_setter_piecewise,
                   new_y_setter_x, new_y_setter_y, new_y_setter_r,
                   button_new_y, message_y),
            Column(head_classification, set_y_clss_select, new_class_cut,
                   new_class_setter, button_add_label, button_new_class,
                   label_notifier), button_export))

    tab = Panel(child=layout, title='Analysis')

    return tab