コード例 #1
0
class CheckboxButtonGroupComponent(ComponentMixin):
    """Component that wraps a single Bokeh ``CheckboxButtonGroup``.

    The group is both the wrapped widget (``checkbox``) and the component's
    ``layout``. Once a mediator is attached, changes to the group's
    ``active`` property are routed to a ``'checkbox-change'`` callback built
    by the mixin's ``make_attr_old_new_callback``.
    """

    def __init__(self, checkbox_kwargs):
        """Create the button group from ``checkbox_kwargs`` (splatted as kwargs)."""
        super().__init__()
        group = CheckboxButtonGroup(**checkbox_kwargs)
        self.checkbox = group
        # Filled in by set_mediator().
        self.checkbox_change = None
        self.layout = group

    def set_mediator(self, mediator):
        """Attach ``mediator`` and register the ``active`` change callback."""
        super().set_mediator(mediator)
        callback = self.make_attr_old_new_callback('checkbox-change')
        self.checkbox_change = callback
        self.checkbox.on_change('active', callback)
コード例 #2
0
 def create_viewer(self):
     """Build a run button and a hidden tag selector for every data source.

     For each entry of ``self.data_sources`` this creates a ``Button``
     labelled with the run name (click handler built by ``self._show_tags``)
     followed by an initially hidden ``CheckboxButtonGroup`` listing the run's
     tag names (``active`` handler built by ``self._plot_tag``). Both models
     are stored back on the run's entry, which is also marked
     ``loaded = False``.

     Returns:
         A Bokeh ``column`` of the widgets in order: button, tags, button, ...
     """
     widgets = []
     for run_name, run_info in self.data_sources.items():
         # Run button; its handler presumably toggles the tag selector.
         run_button = Button(label=run_name,
                             name=run_name,
                             button_type='primary')
         run_button.on_click(self._show_tags(run_name))

         # One toggle per tag of this run, hidden until requested.
         tag_labels = list(run_info["tags"])
         tag_group = CheckboxButtonGroup(labels=tag_labels)
         tag_group.visible = False
         tag_group.on_change('active', self._plot_tag(run_name, tag_labels))

         widgets.extend((run_button, tag_group))

         # Keep references to both models so they can be modified later.
         entry = self.data_sources[run_name]
         entry["run_button"] = run_button
         entry["tag_boxes"] = tag_group
         entry["loaded"] = False
     return column(widgets)
コード例 #3
0
def setup_graph():
    """Create the t-SNE demo widgets and add them to ``doc`` as one root.

    Builds:

    * ``tsne_graph``: a 700x700 scatter of ``tsne_source`` (``x``/``y``),
      filled through a ``Spectral10`` mapper over the ``color`` column
      (range 0..9), with a red "Thinking..." ``Label`` named ``notice``;
    * ``loss_graph``: a 350x200 line plot of ``loss_source``
      (``iteration`` vs ``loss``) with x range ``0..max_iter``;
    * ``go_button`` (initially disabled, click handler ``no_op``),
      ``iter_slider`` (500..20000 step 100, handler ``slider_callback``) and
      ``hide_buttons`` (one toggle per digit 0-9, initially disabled,
      handler ``cb_trigger_callback``).

    The module-level ``max_iter``, ``device`` and ``doc`` are only read, so
    no ``global`` declaration is needed. The ``name`` values are presumably
    used elsewhere to look the models up (e.g. ``doc.get_model_by_name``) —
    confirm before renaming any of them.
    """
    tsne_source = ColumnDataSource(pd.DataFrame(columns=["x", "y", "color"]))
    loss_source = ColumnDataSource(pd.DataFrame(columns=["iteration", "loss"]))

    tooltips = [("label", "@color")]

    title = f"tSNE on 2500 MNIST digits using {device} device"

    p = figure(
        width=700,
        height=700,
        match_aspect=True,
        tooltips=tooltips,
        name="tsne_graph",
        title=title,
        x_range=(-1.5, 1.5),
        y_range=(-1.5, 1.5),
    )
    p.title.align = "center"
    q = figure(
        width=350,
        height=200,
        name="loss_graph",
        x_range=(0, max_iter),
        y_range=(0, 15),
        title="Loss vs iterations",
        tools="save",
        title_location="below",
    )
    q.title.align = "center"

    # Scatter of the embedded points; the renderer is reachable by name.
    p.circle(
        x="x",
        y="y",
        fill_color=linear_cmap("color", palette=Spectral10, low=0, high=9),
        source=tsne_source,
        size=5,
        name="tsne_glyphs",
    )

    q.line(
        x="iteration",
        y="loss",
        source=loss_source,
        line_width=3,
        line_color="red",
        name="loss_glyphs",
    )

    # Status message drawn in screen coordinates on top of the scatter plot.
    notice = Label(
        x=50,
        y=300,
        x_units="screen",
        y_units="screen",
        text="Thinking...",
        name="notice",
        text_color="red",
    )
    p.add_layout(notice)

    button = Button(
        label="Computing Distances in Feature Space",
        button_type="success",
        name="go_button",
        disabled=True,
        width=310,
    )
    button.on_click(no_op)

    slider = Slider(
        start=500,
        end=20000,
        value=1500,
        step=100,
        title="Number of Iterations",
        name="iter_slider",
        width=310,
    )
    slider.on_change("value", slider_callback)

    # One toggle per MNIST digit class.
    cb = CheckboxButtonGroup(labels=[str(x) for x in range(10)],
                             active=[],
                             name="hide_buttons",
                             orientation="horizontal",
                             disabled=True,
                             width=230)
    cb.on_change("active", cb_trigger_callback)

    note = Div(text="<em>Click the buttons to focus the view</em>")

    # NOTE(review): ``widgetbox`` is deprecated and no longer exists in
    # recent Bokeh releases; replace it with ``column`` when upgrading
    # (confirm against the pinned Bokeh version).
    doc.add_root(
        column(p, row(Spacer(width=20), widgetbox(button, slider, cb, note),
                      q)))
コード例 #4
0
class PlayerWidget:
    """Bokeh control panel for playing the Cybathlon game with a model.

    The widget launches the game executable as a subprocess and coordinates
    helpers, each started on its own ``QtCore.QThreadPool``:

    * ``GameLogReader`` (``thread_log``), created on the newest race log;
    * ``LSLClient`` (``thread_lsl``), the EEG stream reader;
    * ``ActionPredictor`` (``thread_pred``), built from the selected model.

    An ``LSLRecorder`` and a ``CommandSenderPort`` can be enabled from the UI.
    All helpers receive ``self``; presumably they report back by assigning
    ``expected_action``, ``lsl_data`` and ``pred_action``. Those setters defer
    UI work through ``parent.add_next_tick_callback`` (``parent`` is
    presumably the Bokeh ``Document``).

    ``create_widget`` builds the Bokeh models (``select_model``,
    ``checkbox_settings``, ``plot_stream``, ...) that most other methods read,
    so it must be called first.
    """

    def __init__(self, parent=None):
        """Read the configuration and initialise state; no worker is started."""
        self.parent = parent
        self.fs = main_config['fs']
        self.n_channels = main_config['n_channels']
        # First LSL timestamp (rebases the plot's time axis) and latest one.
        self.t0 = 0
        self.last_ts = 0

        # Game log reader (separate thread)
        self.game_logs_path = game_config['game_logs_path']
        self.game_log_reader = None
        self._expected_action = (0, 'Rest')
        self.thread_log = QtCore.QThreadPool()

        # Game window (separate process)
        clean_log_directory(self.game_logs_path)
        self.game = None
        # Number of race logs present at the last launch; a new race log is
        # detected once more logs than this exist (see start_log_reader).
        self.n_old_logs = 0

        # Port event sender
        self.micro_path = main_config['micro_path']
        self.port_sender = None

        # Game player
        self.game_path = game_config['game_path']
        self.player_idx = game_config['player_idx']
        self.game_start_time = None

        # Chronogram
        self.chrono_source = ColumnDataSource(dict(ts=[],
                                                   y_true=[],
                                                   y_pred=[]))
        self.pred_decoding = main_config['pred_decoding']

        # LSL stream reader
        self.lsl_reader = None
        self.lsl_start_time = None
        self.thread_lsl = QtCore.QThreadPool()
        self.channel_source = ColumnDataSource(dict(ts=[], eeg=[]))
        self._lsl_data = (None, None)

        # LSL stream recorder
        if not os.path.isdir(main_config['record_path']):
            os.mkdir(main_config['record_path'])
        self.record_path = main_config['record_path']
        self.record_name = game_config['record_name']
        self.lsl_recorder = None

        # Predictor: rolling input buffer of 4 * fs samples per channel
        self.models_path = main_config['models_path']
        self.input_signal = np.zeros((self.n_channels, 4 * self.fs))
        self.predictor = None
        self.thread_pred = QtCore.QThreadPool()
        self._pred_action = (0, 'Rest')

    @property
    def lsl_data(self):
        """Latest ``(timestamps, eeg)`` chunk received from the LSL reader."""
        return self._lsl_data

    @lsl_data.setter
    def lsl_data(self, data):
        """Store a chunk, record it if recording, and schedule a plot update."""
        self._lsl_data = data

        # Memorize the most recent timestamp
        ts, eeg = data
        self.last_ts = ts[-1]

        # Record signal (deep copies so later buffer reuse can't alter them)
        if self.lsl_recorder is not None:
            self.lsl_recorder.save_data(copy.deepcopy(ts), copy.deepcopy(eeg))

        self.parent.add_next_tick_callback(self.update_signal)

    @property
    def pred_action(self):
        """Latest ``(index, name)`` prediction, initially ``(0, 'Rest')``."""
        return self._pred_action

    @pred_action.setter
    def pred_action(self, val_tuple):
        """Store a prediction; refresh the UI only once a game has started."""
        self._pred_action = val_tuple
        if self.game_start_time is not None:
            self.parent.add_next_tick_callback(self.update_prediction)

    @property
    def expected_action(self):
        """Latest groundtruth ``(index, name)`` read from the game logs."""
        return self._expected_action

    @expected_action.setter
    def expected_action(self, action):
        """Store the groundtruth and schedule ``update_groundtruth``."""
        logging.info(f'Receiving groundtruth from logs: {action}')
        self._expected_action = copy.deepcopy(action)

        # In autoplay, we directly update the model prediction (no delay)
        if self.modelfile == 'AUTOPLAY':
            self._pred_action = copy.deepcopy(action)

        self.parent.add_next_tick_callback(self.update_groundtruth)

    @property
    def available_logs(self):
        """Sorted paths in ``game_logs_path`` matching the configured pattern."""
        logs = list(self.game_logs_path.glob(game_config['game_logs_pattern']))
        return sorted(logs)

    @property
    def game_is_on(self):
        """True while the launched game process is still running."""
        if self.game is not None:
            # Poll returns None when game process is running and 0 otherwise
            return self.game.poll() is None
        else:
            return False

    @property
    def should_record(self):
        """Whether a 'Record' setting is ticked.

        NOTE(review): ``checkbox_settings`` (see create_widget) has no
        'Record' label, so this is always False; recording is actually driven
        by ``button_record``.
        """
        return 'Record' in self.selected_settings

    @property
    def available_models(self):
        """'AUTOPLAY' followed by the ``*.pkl`` and ``*.h5`` files in models_path."""
        ml_models = [p.name for p in self.models_path.glob('*.pkl')]
        dl_models = [p.name for p in self.models_path.glob('*.h5')]
        return ['AUTOPLAY'] + ml_models + dl_models

    @property
    def model_name(self):
        """Currently selected entry of ``select_model``."""
        return self.select_model.value

    @property
    def modelfile(self):
        """'AUTOPLAY', or the path of the selected model inside models_path."""
        if self.select_model.value == 'AUTOPLAY':
            return 'AUTOPLAY'
        else:
            return self.models_path / self.select_model.value

    @property
    def is_convnet(self):
        """True when the selected model file has the ``.h5`` extension."""
        return self.select_model.value.split('.')[-1] == 'h5'

    @property
    def available_ports(self):
        """Port names for ``select_port``, preceded by a blank entry.

        On Linux the entries of ``micro_path`` are listed. Every other
        platform is enumerated through pyserial so that a list is always
        returned (a ``None`` would break ``Select.options``).
        """
        if sys.platform == 'linux':
            ports = self.micro_path.glob('*')
            return [''] + [p.name for p in ports]
        return [''] + [p.device for p in serial.tools.list_ports.comports()]

    @property
    def sending_events(self):
        """Whether the 'Send events' setting is ticked."""
        return 'Send events' in self.selected_settings

    @property
    def channel_idx(self):
        """1-based channel number parsed from a ``'<n> - <name>'`` option."""
        return int(self.select_channel.value.split('-')[0])

    @property
    def selected_settings(self):
        """Labels of the ticked buttons in ``checkbox_settings``."""
        active = self.checkbox_settings.active
        return [self.checkbox_settings.labels[i] for i in active]

    @property
    def accuracy(self):
        """Accuracy of predictions vs groundtruth over the chronogram."""
        y_pred = self.chrono_source.data['y_pred']
        y_true = self.chrono_source.data['y_true']
        return accuracy_score(y_true, y_pred)

    def reset_lsl(self):
        """Ask the LSL reader to stop and drop it (no-op if there is none)."""
        if self.lsl_reader:
            self.lsl_reader.should_stream = False
            self.lsl_reader = None
            self.lsl_start_time = None
            self.thread_lsl.clear()

    def reset_predictor(self):
        """Ask the predictor to stop and drop it (no-op if there is none)."""
        if self.predictor:
            self.predictor.should_predict = False
            self.predictor = None
            self.thread_pred.clear()

    def reset_recorder(self):
        """Close the recorder's HDF5 file and drop it (no-op if none)."""
        if self.lsl_recorder:
            self.lsl_recorder.close_h5()
            self.lsl_recorder = None

    def reset_plots(self):
        """Clear both plot sources and the information texts."""
        self.chrono_source.data = dict(ts=[], y_pred=[], y_true=[])
        self.channel_source.data = dict(ts=[], eeg=[])
        self.gd_info.text = ''
        self.pred_info.text = ''
        self.acc_info.text = ''

    def reset_game(self):
        """Kill the game process and reap it so no zombie process is left."""
        self.game.kill()
        self.game.wait()
        self.game = None

    def reset_log_reader(self):
        """Clear the log-reader thread pool and drop the reader."""
        logging.info('Delete old log reader')
        self.thread_log.clear()
        self.game_log_reader = None

    def on_model_change(self, attr, old, new):
        """Refresh model options/info and (re)start the predictor."""
        logging.info(f'Select new pre-trained model {new}')
        self.select_model.options = self.available_models
        self.model_info.text = f'<b>Model:</b> {new}'
        self.parent.add_next_tick_callback(self.start_predictor_thread)

    def on_select_port(self, attr, old, new):
        """Replace the port sender with one for port ``new``."""
        logging.info(f'Select new port: {new}')

        if self.port_sender is not None:
            logging.info('Delete old port sender')
            self.port_sender = None

        # The blank option stands for "no port": don't try to open ''.
        if not new:
            return

        logging.info(f'Instanciate port sender {new}')
        self.port_sender = CommandSenderPort(new)

    def on_channel_change(self, attr, old, new):
        """Clear the displayed signal and relabel the y axis for ``new``."""
        logging.info(f'Select new channel {new}')
        # Reset both columns together: clearing only 'eeg' would leave the
        # source with columns of different lengths.
        self.channel_source.data = dict(ts=[], eeg=[])
        self.plot_stream.yaxis.axis_label = f'Amplitude ({new})'

    def on_settings_change(self, attr, old, new):
        """Show the signal plot iff 'Show signal' (index 0) is ticked."""
        self.plot_stream.visible = 0 in new

    def start_game_process(self):
        """Launch the game, first killing any previous game process.

        ``subprocess.Popen`` raises ``OSError`` if the executable cannot be
        started. The game's stdout/stderr are discarded: nothing reads them,
        and unread pipes can block the game once the OS buffer is full.
        """
        logging.info('Lauching Cybathlon game')
        self.n_old_logs = len(self.available_logs)
        self.reset_plots()

        # Close any previous game process
        if self.game is not None:
            self.reset_game()

        self.game = subprocess.Popen(str(self.game_path),
                                     stdin=subprocess.PIPE,
                                     stdout=subprocess.DEVNULL,
                                     stderr=subprocess.DEVNULL,
                                     text=True)

    def start_log_reader(self):
        """Wait for a new race log, then start a ``GameLogReader`` on it.

        NOTE(review): the wait polls with ``time.sleep`` on the calling
        (Bokeh callback) thread, which blocks the document meanwhile. The wait
        is abandoned if the game process exits first.
        """
        # Check if log reader already instanciated
        if self.game_log_reader is not None:
            self.reset_log_reader()

        # Wait for new logfile to be created
        while len(self.available_logs) <= self.n_old_logs:
            if not self.game_is_on:
                logging.error('Game process exited before creating race logs')
                return
            logging.info('Waiting for new race logs...')
            time.sleep(0.5)
        log_filename = str(self.available_logs[-1])

        # Log reader is started in a separate thread
        logging.info(f'Instanciate log reader {log_filename}')
        self.game_log_reader = GameLogReader(self, log_filename,
                                             self.player_idx)
        self.thread_log.start(self.game_log_reader)

    def on_launch_game_start(self):
        """Show a pending state, then launch the game on the next tick."""
        self.button_launch_game.label = 'Launching...'
        self.button_launch_game.button_type = 'warning'
        self.parent.add_next_tick_callback(self.on_launch_game)

    def on_launch_game(self):
        """Launch the game and its log reader, then mark the button launched."""
        self.start_game_process()
        self.start_log_reader()
        self.button_launch_game.label = 'Launched'
        self.button_launch_game.button_type = 'success'

    def update_groundtruth(self):
        """React to the latest groundtruth and optionally forward it as event."""
        action_idx, action_name = self.expected_action

        # Start autoplay predictor when game starts + reset chronogram (if multiple consecutive runs)
        if action_name == 'Game start':
            self.reset_plots()
            self.game_start_time = time.time()
            if self.modelfile == 'AUTOPLAY':
                self.parent.add_next_tick_callback(self.start_predictor_thread)
        elif action_name in ['Game end', 'Pause']:
            self.reset_predictor()
        elif action_name == 'Resume':
            self.parent.add_next_tick_callback(self.start_predictor_thread)
        elif action_name == 'Reset game':
            self.reset_plots()
            self.reset_predictor()
            self.reset_log_reader()
            self.parent.add_next_tick_callback(self.start_log_reader)

        # Send groundtruth to microcontroller
        if self.sending_events:
            if self.port_sender is not None:
                self.port_sender.sendCommand(action_idx)
                logging.info(f'Send event: {action_idx}')
            else:
                logging.info('Please select a port !')

    def update_prediction(self):
        """Append the latest prediction to the chronogram and info texts.

        If the game window has been closed, reset the launch button and the
        model selection and stop the predictor instead.
        """
        if not self.game_is_on:
            logging.info('Game window was closed')
            self.button_launch_game.label = 'Launch Game'
            self.button_launch_game.button_type = 'primary'
            self.select_model.value = 'AUTOPLAY'
            self.reset_predictor()
            return

        groundtruth = self.expected_action[0]
        action_idx = self.pred_action[0]

        # Save groundtruth as event
        if self.lsl_recorder is not None:
            marker_id = int(f'{(groundtruth+1)*2}{(action_idx+1)*2}')
            self.lsl_recorder.save_event(self.last_ts, marker_id)

        # Update chronogram source
        ts = time.time() - self.game_start_time
        self.chrono_source.stream(dict(ts=[ts],
                                       y_true=[groundtruth],
                                       y_pred=[action_idx]))

        # Update information display
        self.gd_info.text = f'<b>Groundtruth:</b> {self.expected_action}'
        self.pred_info.text = f'<b>Prediction:</b> {self.pred_action}'
        self.acc_info.text = f'<b>Accuracy:</b> {self.accuracy:.2f}'

    def on_lsl_connect_toggle(self, active):
        """Connect to (on next tick) or disconnect from the LSL stream."""
        if active:
            # Connect to LSL stream
            self.button_lsl.label = 'Searching...'
            self.button_lsl.button_type = 'warning'
            self.parent.add_next_tick_callback(self.start_lsl_thread)
        else:
            self.reset_lsl()
            self.button_lsl.label = 'LSL Disconnected'
            self.button_lsl.button_type = 'danger'

    def start_lsl_thread(self):
        """Create the LSL reader, fill channel options and start it."""
        try:
            self.lsl_reader = LSLClient(self)
            if self.lsl_reader is not None:
                self.select_channel.options = [f'{i+1} - {ch}' for i, ch
                                               in enumerate(self.lsl_reader.ch_names)]
                self.thread_lsl.start(self.lsl_reader)
                self.button_lsl.label = 'Reading LSL stream'
                self.button_lsl.button_type = 'success'
        except Exception:
            logging.info(f'No LSL stream - {traceback.format_exc()}')
            self.button_lsl.label = 'Can\'t find stream'
            self.button_lsl.button_type = 'danger'
            self.reset_lsl()

    def start_predictor_thread(self):
        """(Re)start the predictor; fall back to AUTOPLAY if loading fails."""
        self.reset_predictor()

        try:
            self.predictor = ActionPredictor(self,
                                             self.modelfile,
                                             self.is_convnet)
            self.thread_pred.start(self.predictor)
        except Exception as e:
            logging.error(f'Failed loading model {self.modelfile} - {e}')
            self.select_model.value = 'AUTOPLAY'
            self.reset_predictor()

    def on_lsl_record_toggle(self, active):
        """Start or stop recording the LSL stream to an HDF5 file."""
        if active:
            try:
                self.lsl_recorder = LSLRecorder(self.record_path,
                                                self.record_name,
                                                self.lsl_reader.ch_names)
                self.lsl_recorder.open_h5()
                self.button_record.label = 'Stop recording'
                self.button_record.button_type = 'success'
            except Exception:
                # Keep the failure visible in the logs (e.g. no LSL reader).
                logging.error(f'Recording failed - {traceback.format_exc()}')
                self.reset_recorder()
                self.button_record.label = 'Recording failed'
                self.button_record.button_type = 'danger'
        else:
            self.reset_recorder()
            self.button_record.label = 'Start recording'
            self.button_record.button_type = 'primary'

    def update_signal(self):
        """Stream the selected channel to the plot and roll the input buffer.

        ``eeg`` is indexed as (channel, sample); chunks whose sample count
        doesn't match the number of timestamps are skipped.
        """
        ts, eeg = self.lsl_data

        if ts.shape[0] != eeg.shape[-1]:
            logging.info('Skipping data points (bad format)')
            return

        # Convert timestamps in seconds
        if self.lsl_start_time is None:
            self.lsl_start_time = time.time()
            self.t0 = ts[0]

        # Update source display. Channel options are numbered from 1
        # ('1 - Fp1'), while rows of ``eeg`` are numbered from 0.
        ch = self.channel_idx - 1
        self.channel_source.stream(dict(ts=ts-self.t0, eeg=eeg[ch, :]),
                                   rollover=int(2 * self.fs))

        # Update signal
        chunk_size = eeg.shape[-1]
        self.input_signal = np.roll(self.input_signal, -chunk_size, axis=-1)
        self.input_signal[:, -chunk_size:] = eeg

    def create_widget(self):
        """Create all Bokeh models of the panel and return its layout row."""
        # Button - Launch Cybathlon game in new window
        self.button_launch_game = Button(label='Launch Game',
                                         button_type='primary')
        self.button_launch_game.on_click(self.on_launch_game_start)

        # Toggle - Connect to LSL stream
        self.button_lsl = Toggle(label='Connect to LSL')
        self.button_lsl.on_click(self.on_lsl_connect_toggle)

        # Toggle - Start/stop LSL stream recording
        self.button_record = Toggle(label='Start Recording',
                                    button_type='primary')
        self.button_record.on_click(self.on_lsl_record_toggle)

        # Select - Choose pre-trained model
        self.select_model = Select(title="Select pre-trained model",
                                   value='AUTOPLAY',
                                   options=self.available_models)
        self.select_model.on_change('value', self.on_model_change)

        # Select - Choose port to send events to
        self.select_port = Select(title='Select port')
        self.select_port.options = self.available_ports
        self.select_port.on_change('value', self.on_select_port)

        # Checkbox - Choose player settings
        self.div_settings = Div(text='<b>Settings</b>', align='center')
        self.checkbox_settings = CheckboxButtonGroup(labels=['Show signal',
                                                             'Send events'])
        self.checkbox_settings.on_change('active', self.on_settings_change)

        # Select - Channel to visualize
        self.select_channel = Select(title='Select channel', value='1 - Fp1')
        self.select_channel.on_change('value', self.on_channel_change)

        # Plot - LSL EEG Stream
        self.plot_stream = figure(title='Temporal EEG signal',
                                  x_axis_label='Time [s]',
                                  y_axis_label='Amplitude',
                                  plot_height=500,
                                  plot_width=800,
                                  visible=False)
        self.plot_stream.line(x='ts', y='eeg', source=self.channel_source)

        # Plot - Chronogram prediction vs results
        self.plot_chronogram = figure(title='Chronogram',
                                      x_axis_label='Time [s]',
                                      y_axis_label='Action',
                                      plot_height=300,
                                      plot_width=800)
        self.plot_chronogram.line(x='ts', y='y_true', color='blue',
                                  source=self.chrono_source,
                                  legend_label='Groundtruth')
        self.plot_chronogram.cross(x='ts', y='y_pred', color='red',
                                   source=self.chrono_source,
                                   legend_label='Prediction')
        self.plot_chronogram.legend.background_fill_alpha = 0.6
        self.plot_chronogram.yaxis.ticker = list(self.pred_decoding.keys())
        self.plot_chronogram.yaxis.major_label_overrides = self.pred_decoding

        # Div - Display useful information
        self.model_info = Div(text=f'<b>Model:</b> AUTOPLAY')
        self.pred_info = Div()
        self.gd_info = Div()
        self.acc_info = Div()

        # Create layout
        column1 = column(self.button_launch_game, self.button_lsl,
                         self.button_record, self.select_model,
                         self.select_port, self.select_channel,
                         self.div_settings, self.checkbox_settings)
        column2 = column(self.plot_stream, self.plot_chronogram)
        column3 = column(self.model_info, self.gd_info,
                         self.pred_info, self.acc_info)
        return row(column1, column2, column3)
コード例 #5
0
# Play button; clicking it calls ``animate`` (defined elsewhere in the file).
button = Button(label='►', width=30)
button.on_click(animate)
##############################################################################

############################### INPUT ###################################
# Measure selector: offers the columns of ``df`` at positions 2..9
# (slice [2:10]). The select, the slider and the checkbox group below all
# share the ``update_plot`` callback.
select = Select(title="Measure",
                options=list(df.columns.values)[2:10],
                value="INFECTED_NOSYMPTOMS_NOTCONTAGIOUS")
select.on_change('value', update_plot)

# Time-period slider from 0 to ``max_time`` in steps of 1.
slider = Slider(title='Period', start=0, end=max_time, step=1, value=0)
slider.on_change('value', update_plot)

# One toggle per entry of AGEGROUPS; only the first one is active initially.
checkbox_button_group = CheckboxButtonGroup(labels=list(AGEGROUPS), active=[0])
checkbox_button_group.on_change('active', update_plot)
##############################################################################

# NOTE(review): disabled alternative hover definition, kept for reference.
#hover = HoverTool(tooltips = [ ('COROP', '@{OBJECTID}'),('Infected', '@Infected_plus'), ('Age group', '@AGEGROUP')])

# Create a horizontal color bar driven by ``color_mapper`` (defined elsewhere).
color_bar = ColorBar(
    color_mapper=color_mapper,
    label_standoff=8,
    width=450,
    height=20,
    border_line_color=None,
    location=(0, 0),
    orientation='horizontal')  #, major_label_overrides = tick_labels)

hover = HoverTool(tooltips=[(
コード例 #6
0
ファイル: main.py プロジェクト: lagoi/covidarg
def tabMapWithSelectAndUpdate(arg: pd.DataFrame):
    """Build the "Mapa Actual" tab: a choropleth map with circle overlays.

    The patch colour encodes the column chosen in a ``Select``; for every
    layer ticked in a ``CheckboxButtonGroup``, circles whose radius comes
    from a scaled copy of that column are overlaid. Changing either widget
    rebuilds the figure and swaps it into the tab's layout.

    Args:
        arg: frame with the columns ``Confirmados``, ``Recuperados``,
            ``Fallecidos``, ``Activos`` and ``Mayores_de_65`` plus the
            fields used by the glyphs and hover (``lon``, ``lat``, ``nam``
            and the geometry behind ``xs``/``ys``). Presumably a
            GeoDataFrame, since its ``to_json()`` output feeds a
            ``GeoJSONDataSource`` — confirm with the caller. NOTE: the
            scaled ``*2`` radius columns are added to ``arg`` in place.

    Returns:
        A ``Panel`` titled "Mapa Actual" and named ``mapa_actual``.
    """
    # (widget label, colour column, circle-radius column). This one table
    # drives the Select options, the checkbox labels (whose ``active``
    # indices refer to this order) and the lookups in the callback.
    layers = [
        ('Casos Confirmados', 'Confirmados', 'Confirmados2'),
        ('Recuperados', 'Recuperados', 'Recuperados2'),
        ('Fallecidos', 'Fallecidos', 'Fallecidos2'),
        ('Activos', 'Activos', 'Activos2'),
        ('Mayores de 65', 'Mayores_de_65', 'Mayores_de_652'),
    ]
    labels = [label for label, _, _ in layers]
    color_fields = {label: color_field for label, color_field, _ in layers}
    radius_fields = [radius_field for _, _, radius_field in layers]

    # Brewer palette per colour column (anything else uses Blues) and
    # circle fill colour per radius column (anything else is black).
    palette_names = {
        'Confirmados': 'Blues',
        'Recuperados': 'Greens',
        'Fallecidos': 'Reds',
    }
    circle_colors = {
        'Confirmados2': '#0000ff',
        'Recuperados2': '#00ff00',
        'Fallecidos2': '#ff0000',
        'Activos2': '#ffff00',
    }

    # Scaled copies used as circle radii (presumably so radii fit the map's
    # data units).
    for column_name in ('Confirmados', 'Recuperados', 'Fallecidos', 'Activos'):
        arg[column_name + '2'] = arg[column_name] / 1000
    arg['Mayores_de_652'] = arg['Mayores_de_65'] / 1000000

    # Input geojson source that contains features for plotting:
    geosource = GeoJSONDataSource(geojson=arg.to_json())

    # Selection widgets
    select1 = Select(title='Dato en Color:',
                     value='Casos Confirmados',
                     options=list(labels))

    select2 = CheckboxButtonGroup(active=[1], labels=list(labels))

    def current_selection():
        """Return (colour column, radius columns) for the widgets' state."""
        return (color_fields[select1.value],
                [radius_fields[index] for index in select2.active])

    def update_plot(attr, old, new):
        # Both widgets share this callback, so read their current state
        # rather than relying on ``new``.
        new_map = make_map(*current_selection())
        # The map is always the last child of the layout: replace it.
        tab_layout.children.pop()
        tab_layout.children.append(new_map)

    def make_map(field_name1, field_names2):
        """Create the map figure coloured by ``field_name1``.

        One circle renderer is added per column in ``field_names2``.
        """
        # Reverse the palette so that the darkest colour is the highest value.
        palette = brewer[palette_names.get(field_name1, 'Blues')][8][::-1]

        # Linearly map the values 0..max into the palette.
        color_mapper = LinearColorMapper(palette=palette,
                                         low=0,
                                         high=max(arg[field_name1]))

        # Create color bar.
        format_tick = NumeralTickFormatter(format="0")
        color_bar = ColorBar(color_mapper=color_mapper,
                             label_standoff=18,
                             formatter=format_tick,
                             border_line_color=None,
                             location=(0, 0))

        # Create figure object.
        map_arg = figure(
            title=field_name1,
            plot_height=900,
            plot_width=700,
        )
        map_arg.xgrid.grid_line_color = None
        map_arg.ygrid.grid_line_color = None
        map_arg.axis.visible = False

        # Add patch renderer to figure.
        patches = map_arg.patches('xs',
                                  'ys',
                                  source=geosource,
                                  fill_color={
                                      'field': field_name1,
                                      'transform': color_mapper
                                  },
                                  line_color='black',
                                  line_width=0.25,
                                  fill_alpha=1)

        # Add one circle renderer per selected radius column.
        for field_name2 in field_names2:
            map_arg.circle("lon",
                           "lat",
                           source=geosource,
                           fill_alpha=0.5,
                           fill_color=circle_colors.get(field_name2, '#000000'),
                           line_color="#FFFFFF",
                           line_width=2,
                           line_alpha=0.5,
                           radius=field_name2)

        # Specify color bar layout.
        map_arg.add_layout(color_bar, 'right')

        # Hover tool restricted to the patches (not the circles).
        hover_map = HoverTool(tooltips="""
                                    <div class="plot-tooltip">
                                        <div>
                                            <h5>@nam</h5>
                                        </div>
                                        <div>
                                            <span style="font-weight: bold;">Casos Confirmados: </span>@Confirmados
                                        </div>
                                        <div>
                                            <span style="font-weight: bold;">Recuperados: </span>@Recuperados
                                        </div>
                                        <div>
                                            <span style="font-weight: bold;">Fallecidos: </span>@Fallecidos
                                        </div>
                                        <div>
                                            <span style="font-weight: bold;">Activos: </span>@Activos
                                        </div>
                                        <div>
                                            <span style="font-weight: bold;">Habitantes Mayores de 65: </span>@Mayores_de_65
                                        </div>
                                    </div>
                                    """,
                              renderers=[patches])
        map_arg.add_tools(hover_map)
        return map_arg

    # Attach function to both widgets
    select1.on_change('value', update_plot)
    select2.on_change('active', update_plot)

    # Initial map, derived from the widgets' default state.
    map_arg = make_map(*current_selection())

    paragraph1 = Paragraph(text="""Datos en circulos:""", width=600)

    tab_layout = layout([[select1], [paragraph1], [select2], [map_arg]])
    tab = Panel(child=tab_layout, title="Mapa Actual", name='mapa_actual')
    return tab
コード例 #7
0
class WarmUpWidget:
    """Bokeh widget for the warm-up session: stream, record and classify EEG.

    Connects to an LSL stream (``LSLClient`` run on a ``QThreadPool``), plots
    one selected channel, optionally records the stream (``LSLRecorder``) and
    runs a pre-trained model (``ActionPredictor``) whose predictions are shown
    as a chronogram, a text label and a feedback image.

    The setters of ``lsl_data`` and ``pred_action`` schedule the UI update via
    ``parent.add_next_tick_callback`` — presumably so the reader / predictor
    worker threads can publish results without touching Bokeh models directly
    (``parent`` is presumably the Bokeh Document; confirm with the caller).
    """

    def __init__(self, parent=None):
        self.parent = parent
        self.fs = main_config['fs']
        self.n_channels = main_config['n_channels']
        self.t0 = 0  # first LSL timestamp of the current connection
        self.last_ts = 0  # last LSL timestamp received (used for events)
        self.game_is_on = False

        # Chronogram
        self.chrono_source = ColumnDataSource(dict(ts=[], y_pred=[]))
        self.pred_decoding = main_config['pred_decoding']

        # LSL stream reader
        self.lsl_reader = None
        self.lsl_start_time = None
        self._lsl_data = (None, None)
        self.thread_lsl = QtCore.QThreadPool()
        self.channel_source = ColumnDataSource(dict(ts=[], eeg=[]))
        self.buffer_size_s = 10

        # LSL stream recorder
        if not os.path.isdir(main_config['record_path']):
            os.mkdir(main_config['record_path'])
        self.record_path = main_config['record_path']
        self.record_name = warmup_config['record_name']
        self.lsl_recorder = None

        # Predictor: rolling window of the last 4 s of signal, all channels.
        # NOTE(review): sized with the configured fs; not resized when the
        # stream reports a different fs in start_lsl_thread — confirm.
        self.models_path = main_config['models_path']
        self.input_signal = np.zeros((self.n_channels, 4 * self.fs))
        self.predictor = None
        self.thread_pred = QtCore.QThreadPool()
        self._pred_action = (0, 'Rest')

        # Feedback images
        self.static_folder = warmup_config['static_folder']
        self.action2image = warmup_config['action2image']

    @property
    def pred_action(self):
        """Latest prediction as an ``(action_idx, label)`` tuple."""
        return self._pred_action

    @pred_action.setter
    def pred_action(self, val_tuple):
        self._pred_action = val_tuple
        self.parent.add_next_tick_callback(self.update_prediction)

    @property
    def lsl_data(self):
        """Latest ``(timestamps, eeg)`` chunk received from the LSL stream."""
        return self._lsl_data

    @lsl_data.setter
    def lsl_data(self, data):
        self._lsl_data = data
        self.parent.add_next_tick_callback(self.update_signal)

    @property
    def available_models(self):
        """'' followed by the ``*.pkl`` then ``*.h5`` files in models_path."""
        ml_models = [p.name for p in self.models_path.glob('*.pkl')]
        dl_models = [p.name for p in self.models_path.glob('*.h5')]
        return [''] + ml_models + dl_models

    @property
    def selected_settings(self):
        active = self.checkbox_settings.active
        return [self.checkbox_settings.labels[i] for i in active]

    @property
    def modelfile(self):
        return self.models_path / self.select_model.value

    @property
    def model_name(self):
        return self.select_model.value

    @property
    def is_convnet(self):
        return self.select_model.value.split('.')[-1] == 'h5'

    @property
    def channel_idx(self):
        """1-based channel number parsed from the ``'<n> - <name>'`` option."""
        return int(self.select_channel.value.split('-')[0])

    def reset_lsl(self):
        """Stop the LSL reader (if any) and clear its thread pool."""
        if self.lsl_reader:
            self.lsl_reader.should_stream = False
            self.lsl_reader = None
            self.lsl_start_time = None
            self.thread_lsl.clear()

    def reset_predictor(self):
        """Clear prediction displays and stop the predictor (if any)."""
        self.model_info.text = '<b>Model:</b> None'
        self.pred_info.text = '<b>Prediction:</b> None'
        self.image.text = ''
        if self.predictor:
            self.predictor.should_predict = False
            self.predictor = None
            self.thread_pred.clear()

    def reset_recorder(self):
        """Close the HDF5 recording (if any)."""
        if self.lsl_recorder:
            self.lsl_recorder.close_h5()
            self.lsl_recorder = None

    def on_settings_change(self, attr, old, new):
        # Checkbox 0 is "Show signal".
        self.plot_stream.visible = 0 in new

    def on_model_change(self, attr, old, new):
        """Swap the running predictor for the newly selected model file."""
        logging.info(f'Select new pre-trained model {new}')
        self.select_model.options = self.available_models

        # Delete existing predictor thread
        if self.predictor is not None:
            self.reset_predictor()

        # The empty option means "no model": there is nothing to load. This
        # must hold even when no predictor was running before.
        if new == '':
            return

        try:
            self.predictor = ActionPredictor(self, self.modelfile,
                                             self.is_convnet)
            self.thread_pred.start(self.predictor)
            self.model_info.text = f'<b>Model:</b> {new}'
        except Exception as e:
            logging.error(f'Failed loading model {self.modelfile} - {e}')
            self.reset_predictor()

    def on_channel_change(self, attr, old, new):
        logging.info(f'Select new channel {new}')
        self.channel_source.data = dict(ts=[], eeg=[])
        self.plot_stream.yaxis.axis_label = f'Amplitude ({new})'

    def reset_plots(self):
        self.chrono_source.data = dict(ts=[], y_pred=[])
        self.channel_source.data = dict(ts=[], eeg=[])

    def update_prediction(self):
        """Display the latest prediction and log it as a recording event."""
        # Update chronogram source
        action_idx = self.pred_action[0]
        if self.lsl_start_time is not None:
            ts = time.time() - self.lsl_start_time
            self.chrono_source.stream(dict(ts=[ts], y_pred=[action_idx]))

        # Update information display (might cause delay)
        self.pred_info.text = f'<b>Prediction:</b> {self.pred_action}'
        src = self.static_folder / \
            self.action2image[self.pred_decoding[action_idx]]
        self.image.text = f"<img src={src} width='200' height='200' text-align='center'>"

        # Save prediction as event
        if self.lsl_recorder is not None:
            self.lsl_recorder.save_event(copy.deepcopy(self.last_ts),
                                         copy.deepcopy(action_idx))

    def on_lsl_connect_toggle(self, active):
        if active:
            # Connect to LSL stream (on the next tick so the label updates)
            self.button_lsl.label = 'Searching...'
            self.button_lsl.button_type = 'warning'
            self.reset_plots()
            self.parent.add_next_tick_callback(self.start_lsl_thread)
        else:
            self.reset_lsl()
            self.reset_predictor()
            self.button_lsl.label = 'LSL Disconnected'
            self.button_lsl.button_type = 'danger'

    def start_lsl_thread(self):
        """Create the LSL reader, fill the channel selector, start reading."""
        try:
            self.lsl_reader = LSLClient(self)
            self.fs = self.lsl_reader.fs
            # Options are 1-based: '<i+1> - <name>' (see channel_idx).
            self.select_channel.options = [
                f'{i+1} - {ch}'
                for i, ch in enumerate(self.lsl_reader.ch_names)
            ]
            self.thread_lsl.start(self.lsl_reader)
            self.button_lsl.label = 'Reading LSL stream'
            self.button_lsl.button_type = 'success'
        except Exception:
            logging.info(f'No LSL stream - {traceback.format_exc()}')
            self.button_lsl.label = 'Can\'t find stream'
            self.button_lsl.button_type = 'danger'
            self.reset_lsl()

    def on_lsl_record_toggle(self, active):
        if active:
            try:
                self.lsl_recorder = LSLRecorder(self.record_path,
                                                self.record_name,
                                                self.lsl_reader.ch_names)
                self.lsl_recorder.open_h5()
                self.button_record.label = 'Stop recording'
                self.button_record.button_type = 'success'
            except Exception as e:
                logging.info(f'Failed creating LSLRecorder - {e}')
                self.reset_recorder()
                self.button_record.label = 'Recording failed'
                self.button_record.button_type = 'danger'
        else:
            self.reset_recorder()
            self.button_record.label = 'Start recording'
            self.button_record.button_type = 'primary'

    def update_signal(self):
        """Consume the latest LSL chunk: plot, buffer for the model, record.

        ``eeg`` is indexed as ``eeg[channel, sample]`` and must have as many
        samples (last axis) as there are timestamps in ``ts``.
        """
        ts, eeg = self.lsl_data
        if ts.shape[0] == 0:
            return  # empty chunk: nothing to plot, buffer or record
        self.last_ts = ts[-1]

        if ts.shape[0] != eeg.shape[-1]:
            logging.info('Skipping data points (bad format)')
            return

        # Local LSL start time
        if self.lsl_start_time is None:
            self.lsl_start_time = time.time()
            self.t0 = ts[0]

        # Update source display. channel_idx is 1-based (options are
        # '<i+1> - <name>'), the eeg rows are 0-based.
        ch = self.channel_idx - 1
        self.channel_source.stream(dict(ts=(ts - self.t0) / self.fs,
                                        eeg=eeg[ch, :]),
                                   rollover=int(self.buffer_size_s * self.fs))

        # Update the rolling model-input window with the newest samples.
        chunk_size = eeg.shape[-1]
        n_buffer = self.input_signal.shape[-1]
        if chunk_size >= n_buffer:
            # Chunk longer than the window: keep only its most recent part.
            self.input_signal = np.array(eeg[:, -n_buffer:],
                                         dtype=self.input_signal.dtype)
        else:
            self.input_signal = np.roll(self.input_signal, -chunk_size,
                                        axis=-1)
            self.input_signal[:, -chunk_size:] = eeg

        # Record signal
        if self.lsl_recorder is not None:
            self.lsl_recorder.save_data(copy.deepcopy(ts), copy.deepcopy(eeg))

    def create_widget(self):
        """Build the Bokeh models and return the three-column layout."""
        # Toggle - Connect to LSL stream
        self.button_lsl = Toggle(label='Connect to LSL')
        self.button_lsl.on_click(self.on_lsl_connect_toggle)

        # Toggle - Start/stop LSL stream recording
        self.button_record = Toggle(label='Start Recording',
                                    button_type='primary')
        self.button_record.on_click(self.on_lsl_record_toggle)

        # Select - Choose pre-trained model
        self.select_model = Select(title="Select pre-trained model")
        self.select_model.options = self.available_models
        self.select_model.on_change('value', self.on_model_change)

        # Checkbox - Choose settings
        self.div_settings = Div(text='<b>Settings</b>', align='center')
        self.checkbox_settings = CheckboxButtonGroup(labels=['Show signal'])
        self.checkbox_settings.on_change('active', self.on_settings_change)

        # Select - Channel to visualize
        self.select_channel = Select(title='Select channel', value='1 - Fp1')
        self.select_channel.on_change('value', self.on_channel_change)

        # Plot - LSL EEG Stream
        self.plot_stream = figure(title='Temporal EEG signal',
                                  x_axis_label='Time [s]',
                                  y_axis_label='Amplitude',
                                  plot_height=500,
                                  plot_width=800,
                                  visible=False)
        self.plot_stream.line(x='ts', y='eeg', source=self.channel_source)

        # Plot - Chronogram prediction vs results
        self.plot_chronogram = figure(title='Chronogram',
                                      x_axis_label='Time [s]',
                                      y_axis_label='Action',
                                      plot_height=300,
                                      plot_width=800)
        self.plot_chronogram.cross(x='ts',
                                   y='y_pred',
                                   color='red',
                                   source=self.chrono_source,
                                   legend_label='Prediction')
        self.plot_chronogram.legend.background_fill_alpha = 0.6
        self.plot_chronogram.yaxis.ticker = list(self.pred_decoding.keys())
        self.plot_chronogram.yaxis.major_label_overrides = self.pred_decoding

        # Div - Display useful information
        self.model_info = Div(text='<b>Model:</b> None')
        self.pred_info = Div(text='<b>Prediction:</b> None')
        self.image = Div()

        # Create layout
        column1 = column(self.button_lsl, self.button_record,
                         self.select_model, self.select_channel,
                         self.div_settings, self.checkbox_settings)
        column2 = column(self.plot_stream, self.plot_chronogram)
        column3 = column(self.model_info, self.pred_info, self.image)
        return row(column1, column2, column3)
コード例 #8
0
def stuff_2(in_active):
    """Handle button index 2: print 'banana' when active, 'apple' otherwise."""
    if in_active:
        print('banana')
    else:
        print('apple')


# Callbacks indexed by button position: stuff_list[i] handles button i.
stuff_list = [stuff_0, stuff_1, stuff_2]

# on_change callback for the CheckboxButtonGroup, occurs whenever a button is clicked
def do_stuff(attr, old, new):
    """Dispatch every toggled button of ``group`` to its handler.

    attr -- name of the changed property ('active')
    old, new -- lists of active button indices before / after the change

    Each toggled button i calls ``stuff_list[i](i in new)``, i.e. with True
    when the button was switched on and False when it was switched off.
    """
    print(attr, old, new)

    # Indices in exactly one of old/new are the buttons that toggled. A UI
    # click toggles one button, but assigning group.active in code can toggle
    # several, so handle each of them (in index order).
    for clicked_id in sorted(set(old) ^ set(new)):
        print('Last button clicked:', group.labels[clicked_id])
        stuff_list[clicked_id](clicked_id in new)


# assign the callback to the CheckboxButtonGroup
group.on_change('active', do_stuff)

curdoc().add_root(group)
コード例 #9
0
ファイル: LampBokeh.py プロジェクト: dxwayne/FlexSpec1
class BokehFlexLamp(object):
    """Bokeh control panel for a set of calibration lamps.

    A CheckboxButtonGroup selects which lamps should be lit; the "<name> On"
    button sends the selected state and the "<name> Off" button sends an
    all-zero command. Commands are JSON strings passed to
    ``display.display``.
    """

    # FAKE up some enums.
    ON = 0  # BokehFlexLamp.ON
    OFF = 1  # BokehFlexLamp.OFF
    RUN = 2  # BokehFlexLamp.RUN
    # NOTE(review): index order ("Off", "On") does not match ON=0/OFF=1 above.
    SateText = ["Off", "On", "Illegal"]
    brre = re.compile(r'\n')  # used to convert newline to <br/>
    postmessage = {
        "name": "Unassigned",
        "near": None,
        "osram": None,
        "halpha": None,
        "oiii": None,
        "flat": None,
        "augflat": None
    }

    #__slots__ = [''] # add legal instance variables
    # (setq properties `("" ""))
    def __init__(self,
                 name: str = "Default",
                 display=fakedisplay,
                 width=250,
                 pin=4):  # BokehFlexLamp::__init__()
        """Initialize this class.

        name    -- label used on the buttons and as the message key
        display -- object whose ``display(str)`` method receives the messages
        width   -- width of the On/Off buttons
        pin     -- unused here
        """
        self.wwidth = width
        self.display = display
        self.name = name
        # One 0/1 flag per installed lamp, in self.CBLabels order.
        self.wheat_value = 1  # Wheat starts selected (matches active=[0])
        self.osram_value = 0
        self.halpha_value = 0
        self.oiii_value = 0
        self.flat_value = 0
        self.augflat_value = 0
        self.near_value = 0

        # // coordinate with lampcheckboxes_handler
        self.CBLabels = [
            "Wheat", "Osram", "H-alpha", "O[iii]", "Flat", "Blue Flat", "NeAr"
        ]

        # `active` is a list of the *indices* of active buttons: only Wheat.
        self.LampCheckBoxes = CheckboxButtonGroup(labels=self.CBLabels,
                                                  active=[0])
        self.process = Button(align='end',
                              label=f"{self.name} On",
                              disabled=False,
                              button_type="success",
                              width=self.wwidth)
        self.offbutton = Button(align='end',
                                label=f"{self.name} Off",
                                disabled=False,
                                button_type="primary",
                                width=self.wwidth)

        self.LampCheckBoxes.on_change(
            'active',
            lambda attr, old, new: self.lampcheckboxes_handler(attr, old, new))
        self.process.on_click(lambda: self.update_process())
        self.offbutton.on_click(lambda: self.update_offbutton())

    ### BokehFlexLamp.__init__()

    def update_offbutton(self):  # BokehFlexLamp::update_offbutton()
        """Off button handler: send the all-lamps-off command."""
        self.send_off()

    ### BokehFlexLamp.update_offbutton()

    def update_process(self):  # BokehFlexLamp::update_process()
        """On button handler: send the currently selected lamp state."""
        self.send_state()

    ### BokehFlexLamp.update_process()

    def lampcheckboxes_handler(self, attr, old,
                               new):  # BokehFlexLamp::lampcheckboxes_handler()
        """Handle the checkboxes; new is a list of indices into self.CBLabels.

        Each lamp flag becomes 1 if its checkbox index is in *new*, else 0.
        """
        lamp_fields = ("wheat_value", "osram_value", "halpha_value",
                       "oiii_value", "flat_value", "augflat_value",
                       "near_value")  # same order as self.CBLabels
        for index, field in enumerate(lamp_fields):
            setattr(self, field, 1 if index in new else 0)

    ### BokehFlexLamp.lampcheckboxes_handler()

    def update_debugbtn(self):  # BokehFlexLamp::update_debugbtn()
        """Show the debug dump in the display, newlines turned into <br/>."""
        os = io.StringIO()
        self.debug(f"{self.name} Debug", os=os)
        os.seek(0)
        self.display.display(BokehFlexLamp.brre.sub("<br/>", os.read()))

    ### BokehFlexLamp.update_debugbtn()

    def send_state(self):  # BokehFlexLamp::send_state()
        """Send the current lamp flags as a "Process" command.

        Returns the JSON-encoded inner command (like send_off).
        """
        cmddict = {
            "wheat": self.wheat_value,
            "osram": self.osram_value,
            "halpha": self.halpha_value,
            "oiii": self.oiii_value,
            "flat": self.flat_value,
            "augflat": self.augflat_value,
            "near": self.near_value,
        }
        jdict = json.dumps({f"{self.name}": {"Process": cmddict}})
        self.display.display(
            f'{{ "{self.name}" : {jdict} , "returnreceipt" : 1 }}')
        return jdict

    ### BokehFlexLamp.send_state()

    def send_off(self):  # BokehFlexLamp::send_off()
        """Send an all-zero "Process" command without changing the flags.

        Returns the JSON-encoded inner command.
        """
        cmddict = {lamp: 0 for lamp in ("wheat", "osram", "halpha", "oiii",
                                        "flat", "augflat", "near")}
        # NOTE(review): send_state keys the inner dict with self.name, this
        # one with the literal "Kzin" — confirm which the receiver expects.
        jdict = json.dumps({"Kzin": {"Process": cmddict}})
        self.display.display(
            f'{{ "{self.name}" : {jdict} , "returnreceipt" : 1 }}')
        return jdict

    ### BokehFlexLamp.send_off()

    def layout(self):  # BokehFlexLamp::layout()
        """Return the checkboxes above the On/Off buttons as a Bokeh row."""
        return row(column(self.LampCheckBoxes, row(self.process,
                                                   self.offbutton)))

    ### BokehFlexLamp.layout()

    def debug(self, msg="", skip=(), os=None):  # BokehFlexLamp::debug()
        """Help with momentary debugging, file to fit.
           msg  -- special tag for this call
           skip -- the member variables to ignore
           os   -- output stream: may be IOStream etc. (default: the
                   current sys.stderr, looked up at call time)
        """
        import pprint
        if os is None:
            os = sys.stderr
        print("BokehFlexLamp - %s " % msg, file=os)
        for key, value in self.__dict__.items():
            if key in skip:
                continue
            print(f'{key:20s} =', file=os, end='')
            pprint.pprint(value, stream=os, indent=4)
        return self

    ### BokehFlexLamp.debug()

    __BokehFlexLamp_debug = debug  # really preserve our debug name if we're inherited
コード例 #10
0
                    })

checkbox_button = CheckboxButtonGroup(
    labels=['Show x-axis label', 'Show y-axis label'],
)


def checkbox_button_click(attr, old, new):
    """Show or hide the axis labels of plot_figure from the checkbox state.

    Button 0 toggles the x-axis label and button 1 the y-axis label. The
    state is read from ``checkbox_button.active``; ``new`` is not used.
    """
    selected = checkbox_button.active
    plot_figure.xaxis.axis_label = 'X-Axis' if 0 in selected else None
    plot_figure.yaxis.axis_label = 'Y-Axis' if 1 in selected else None


checkbox_button.on_change('active', checkbox_button_click)

layout = row(checkbox_button, plot_figure)

curdoc().add_root(layout)
curdoc().title = "Checkbox Button Bokeh Server"
コード例 #11
0
ファイル: main.py プロジェクト: alex-marty/demo-eau
    }
    title = "Consommation {}, {}".format(sector_id, conso_date)
    print(title)
    conso_plot.update_title(title)
    conso_plot.set_active(mode)

def update(attrname, old, new):
    """Shared on_change callback: rebuild the sources from the widgets.

    The consumption date is 2016-03-<date_slider> at hour <time_slider>;
    the plotted quantity comes from ``type_button.active``.
    """
    print("update {} {} {}".format(attrname, old, new))
    selected_sector = int(sector_select.value)
    day_of_month = int(date_slider.value)
    hour_of_day = int(time_slider.value)
    update_sources(selected_sector,
                   datetime(2016, 3, day_of_month, hour_of_day),
                   type_button.active)
# Every control re-runs the same refresh.
sector_select.on_change("value", update)
date_slider.on_change("value", update)
time_slider.on_change("value", update)
type_button.on_change("active", update)

def on_selection_change(attr, old, new):
    """Select the sector whose map glyph was picked.

    Reads the legacy Bokeh ``selected`` layout (``new["1d"]["indices"]``).
    Assigning ``sector_select.value`` fires its "value" on_change callback.
    """
    print("on_selection_change")
    picked = new["1d"]["indices"]
    if not picked:
        return
    # NOTE(review): .loc treats the glyph index as a row *label*; this only
    # matches the glyph position if geo_sectors has a default RangeIndex.
    sector_select.value = str(geo_sectors.loc[picked[0]]["sector_id"])
map_plot.get_data_source().on_change("selected", on_selection_change)

def close_session():
    """Button handler: close the module-level ``session``."""
    active_session = session
    active_session.close()
button.on_click(close_session)

コード例 #12
0
# first column (continued): histogram and scatter of R1R2_sum
hist_r1r2, hist_r1r2_src = plot_histogram('R1R2_sum')
scat_r1r2, scat_r1r2_src = plot_scatter('R1R2_sum')
# second column, fitting residuals, separated for arc cal and combined for grid
hm_res, hm_res_src = plot_heatmap('residuals')
hist_res, hist_res_src = plot_histogram('residuals')
scat_res, scat_res_src = plot_scatter('residuals')
# third column, gear ratio theta
hm_grt, hm_grt_src = plot_heatmap('GEAR_CALIB_T')
hist_grt, hist_grt_src = plot_histogram('GEAR_CALIB_T')
scat_grt, scat_grt_src = plot_scatter('GEAR_CALIB_T')
# fourth column, gear ratio phi
hm_grp, hm_grp_src = plot_heatmap('GEAR_CALIB_P')
hist_grp, hist_grp_src = plot_histogram('GEAR_CALIB_P')
scat_grp, scat_grp_src = plot_scatter('GEAR_CALIB_P')
# construct webpage layout: one column (heatmap, histogram, scatter) per
# quantity
col_r1r2 = column([hm_r1r2, hist_r1r2, scat_r1r2])
col_res = column([hm_res, hist_res, scat_res])
col_grt = column([hm_grt, hist_grt, scat_grt])
col_grp = column([hm_grp, hist_grp, scat_grp])
# NOTE(review): this rebinds the name `layout`, shadowing the layout()
# helper for the rest of the module — rename if the helper is needed again.
layout = layout([[title],
                 [table],
                 [plot_bt, ptls_bt_group],
                 [col_r1r2, col_res, col_grt, col_grp]])
# add callbacks and layout
source.on_change('data', on_change_source_data)
source.selected.on_change('indices', change_selected_calibration)
plot_bt.on_click(update_plots)
ptls_bt_group.on_change('active', change_ptls)
curdoc().title = 'DESI Positioner Calibration Manager'
curdoc().add_root(layout)
コード例 #13
0
ファイル: streamer_v2.py プロジェクト: behinger/bandotuner
                             start=430,
                             end=450,
                             step=1,
                             title="Concert Pitch")
concertpitch_slider.on_change('value', bt.changeConcertPitch)

fundamental_xrange = Slider(value=bt.xrange,
                            start=5,
                            end=100,
                            step=1,
                            title="X-range")
fundamental_xrange.on_change('value', bt.change_xrange)

checkbox_button_group = CheckboxButtonGroup(
    labels=["Auto-Save to Table", "FreezeInset"], active=[])
checkbox_button_group.on_change('active', bt.startstop_autoupdate)

# This button allows to lock in the current data
detect_button = Button(label='Manual Detect and add to Table')
detect_button.on_change('clicks', bt.detect_base_freq_bokeh)

# Saving runs client-side: download.js (next to this script) receives the
# saved-notes source. Read it in a `with` block so the file handle is closed
# right away instead of being left to the garbage collector.
save_button = Button(label='Save Current Table')
with open(join(dirname(__file__), "download.js"),
          encoding="utf-8") as _download_js_file:
    _download_js_code = _download_js_file.read()
save_button.callback = CustomJS(args=dict(source=bt.sources['savednotes']),
                                code=_download_js_code)

#%%
# make the grid & add the plots
widgets = widgetbox(length_slider,
コード例 #14
0
ファイル: config.py プロジェクト: me020523/btplotting
    def _create_plotgroup_config(self):
        """Build the "Plot Group" configuration panel and return it as a column.

        For every plot master returned by ``get_plotobjs`` it adds a title
        Paragraph, a one-button master checkbox (not for ``bt.Strategy``), and
        child checkbox groups holding up to 3 child objects each. The initial
        selection comes from ``self._client.plotgroup`` (comma-separated
        plot ids; empty means "everything"). A read-only TextInput shows the
        resulting ``self.plotgroup``.

        Side effects: resets ``plotgroup``, ``plotgroup_chk``,
        ``plotgroup_objs`` and ``plotgroup_text`` on ``self``, calls
        ``self._add_to_plotgroup`` for each initially active object, and sorts
        each child list of ``self.plotgroup_objs`` in place.
        """
        self.plotgroup = []
        self.plotgroup_chk = defaultdict(list)
        self.plotgroup_objs = defaultdict(list)
        self.plotgroup_text = None

        # An object is initially active when nothing is selected or when its
        # plot id is in the selection.
        def active_obj(obj, selected):
            if not len(selected) or obj.plotinfo.plotid in selected:
                return True
            return False

        title = Paragraph(
            text='Plot Group',
            css_classes=['config-title'])
        options = []

        # get client plot group selection
        if self._client.plotgroup != '':
            selected_plot_objs = self._client.plotgroup.split(',')
        else:
            selected_plot_objs = []

        # get all plot objects (presumably a mapping master -> list of child
        # objects, given the indexing below — confirm against get_plotobjs)
        self.plotgroup_objs = get_plotobjs(
            self._figurepage.strategy,
            order_by_plotmaster=False)

        # create plotgroup checkbox buttons
        for d in self.plotgroup_objs:
            # generate master chk (strategies get no master checkbox)
            master_chk = None
            if not isinstance(d, bt.Strategy):
                active = []
                if active_obj(d, selected_plot_objs):
                    active.append(0)
                    self._add_to_plotgroup(d)
                master_chk = CheckboxButtonGroup(
                    labels=[obj2label(d)], active=active)

            # generate childs chk
            childs_chk = []
            objsd = self.plotgroup_objs[d]
            # sort child objs by type (in place, so plotgroup_objs[d] is
            # sorted too)
            objsd.sort(key=lambda x: (FigureType.get_type(x).value))
            # split objs into chunks of up to 3; one checkbox group per chunk
            objsd = [objsd[i:i + 3] for i in range(0, len(objsd), 3)]
            for x in objsd:
                childs = []
                active = []
                for i, o in enumerate(x):
                    childs.append(obj2label(o))
                    if active_obj(o, selected_plot_objs):
                        active.append(i)
                        self._add_to_plotgroup(o)
                # create a chk for every chunk
                if len(childs):
                    chk = CheckboxButtonGroup(
                        labels=childs, active=active)
                    chk.on_change(
                        'active',
                        partial(
                            self._on_update_plotgroups,
                            chk=chk,
                            master=d,
                            childs=x))
                    # if master is not active, disable childs
                    if master_chk and not len(master_chk.active):
                        chk.disabled = True
                    childs_chk.append(chk)
                # NOTE(review): this stores the chunk of plot objects, not the
                # checkbox widget `chk`, despite the attribute name — confirm.
                self.plotgroup_chk[d].append(x)

            # append title for master (this will also include strategy)
            if len(self.plotgroup_objs[d]):
                options.append(Paragraph(text=f'{obj2label(d)}:'))
            # append master_chk and childs_chk to layout
            if master_chk:
                master_chk.on_change(
                    'active',
                    partial(
                        self._on_update_plotgroups,
                        # provide all related chk to master
                        chk=[master_chk] + childs_chk,
                        master=d))
                options.append(master_chk)
            for c in childs_chk:
                options.append(c)

        # text input to display selection
        self.plotgroup_text = TextInput(
            value=','.join(self.plotgroup),
            disabled=True)
        options.append(Paragraph(text='Plot Group Selection:'))
        options.append(self.plotgroup_text)

        return column([title] + options)
コード例 #15
0
class TrainerWidget:
    def __init__(self):
        """Read the data and model directories from ``main_config``."""
        config = main_config
        self.data_path = config['data_path']
        self.save_path = config['models_path']
        # Indices of active preprocessing checkboxes, in activation order.
        self.active_preproc_ordered = list()

    @property
    def available_pilots(self):
        pilots = self.data_path.glob('*')
        return [''] + [p.parts[-1] for p in pilots]

    @property
    def selected_pilot(self):
        return self.select_pilot.value

    @property
    def available_sessions(self):
        pilot_path = self.data_path / self.selected_pilot
        sessions = pilot_path.glob('*')
        return [s.name for s in sessions]

    @property
    def selected_preproc(self):
        active = self.active_preproc_ordered
        return [self.checkbox_preproc.labels[i] for i in active]

    @property
    def train_ids(self):
        return self.select_session.value

    @property
    def preproc_config(self):
        """Keyword arguments per preprocessing step ('CN', 'BPF', 'Crop').

        BPF and Crop use ``self.fs`` (set by on_load) plus values from
        ``train_config``.
        """
        sampling_rate = self.fs
        return {
            'CN': {'sigma': 6},
            'BPF': {
                'fs': sampling_rate,
                'f_order': train_config['f_order'],
                'f_type': 'butter',
                'f_low': train_config['f_low'],
                'f_high': train_config['f_high'],
            },
            'Crop': {
                'fs': sampling_rate,
                'n_crops': train_config['n_crops'],
                'crop_len': train_config['crop_len'],
            },
        }

    @property
    def should_crop(self):
        return 'Crop' in self.selected_preproc

    @property
    def selected_folders(self):
        active = self.checkbox_folder.active
        return [self.checkbox_folder.labels[i] for i in active]

    @property
    def selected_settings(self):
        active = self.checkbox_settings.active
        return [self.checkbox_settings.labels[i] for i in active]

    @property
    def model_name(self):
        return self.select_model.value

    @property
    def model_config(self):
        config = {'model_name': self.model_name, 'C': 10}
        return config

    @property
    def is_convnet(self):
        return self.model_name == 'ConvNet'

    @property
    def train_mode(self):
        return 'optimize' if 'Optimize' in self.selected_settings \
            else 'validate'

    @property
    def folder_ids(self):
        ids = []
        if 'New Calib' in self.selected_folders:
            ids.append('formatted_raw_500Hz')
        if 'Game' in self.selected_folders:
            ids.append('formatted_raw_500Hz_game')
        return ids

    @property
    def start(self):
        return self.slider_roi_start.value

    @property
    def end(self):
        return self.slider_roi_end.value

    @property
    def n_iters(self):
        return self.slider_n_iters.value

    def on_pilot_change(self, attr, old, new):
        logging.info(f'Select pilot {new}')
        self.select_session.value = ['']
        self.update_widget()

    def on_session_change(self, attr, old, new):
        logging.info(f"Select train sessions {new}")
        self.update_widget()

    def on_model_change(self, attr, old, new):
        logging.info(f'Select model {new}')
        self.update_widget()

    def on_preproc_change(self, attr, old, new):
        # Case 1: Add preproc
        if len(new) > len(old):
            to_add = list(set(new) - set(old))[0]
            self.active_preproc_ordered.append(to_add)
        # Case 2: Remove preproc
        else:
            to_remove = list(set(old) - set(new))[0]
            self.active_preproc_ordered.remove(to_remove)

        logging.info(f'Preprocessing selected: {self.selected_preproc}')
        self.update_widget()

    def update_widget(self):
        self.select_pilot.options = self.available_pilots
        self.select_session.options = self.available_sessions
        self.button_train.button_type = 'primary'
        self.button_train.label = 'Train'
        self.div_info.text = f'<b>Preprocessing selected:</b> {self.selected_preproc} <br>'

    def on_train_start(self):
        assert self.model_name != '', 'Please select a model !'
        assert len(self.train_ids) > 0, 'Please select at least one session !'

        self.button_train.button_type = 'warning'
        self.button_train.label = 'Loading data...'
        curdoc().add_next_tick_callback(self.on_load)

    def on_load(self):
        """Load every selected (session, folder) recording, then schedule training.

        Scheduled by ``on_train_start``. For each id in ``self.train_ids`` and
        each folder in ``self.folder_ids`` it loads ``train/train1.npz`` through
        ``load_session`` (passing ``self.start`` / ``self.end``, presumably the
        ROI bounds in seconds: TODO confirm). The trials are stacked into
        ``self.X`` / ``self.y`` and cropped when requested. A dataset summary is
        written to ``div_info``, and ``on_train`` is scheduled for the next tick.

        Every failure marks ``button_train`` as failed and returns, so the button
        never stays stuck on 'Loading data...'. This covers load errors, an
        empty selection and a ConvNet without cropping.
        """
        def fail(reason):
            # Single failure path: log it and give the user visible feedback.
            logging.info(f'Loading data failed - {reason}')
            self.button_train.button_type = 'danger'
            self.button_train.label = 'Training failed'

        X, y = {}, {}
        for train_id in self.train_ids:
            for folder in self.folder_ids:
                logging.info(f'Loading {train_id} - {folder}')
                try:
                    session_path = (self.data_path / self.selected_pilot
                                    / train_id / folder)
                    filepath = session_path / 'train/train1.npz'
                    X_id, y_id, fs, ch_names = load_session(
                        filepath, self.start, self.end)
                except Exception as e:
                    fail(e)
                    return
                key = f'{train_id}_{folder}'
                X[key] = X_id
                y[key] = y_id
                self.fs = fs
                self.ch_names = ch_names

        # np.vstack / np.hstack raise on an empty list (e.g. no folder selected).
        if not X:
            fail('no data selected (check the selected sessions and folders)')
            return

        # Concatenate all data
        self.X = np.vstack(list(X.values()))
        self.y = np.hstack(list(y.values())).flatten()

        # Cropping FIXME: Integrate inside preproc to avoid data leakage
        if self.should_crop:
            self.X, self.y = cropping(self.X, self.y,
                                      **self.preproc_config['Crop'])

        if self.is_convnet:
            if not self.should_crop:
                fail('ConvNet requires cropping !')
                return
            # Add a trailing singleton axis for the ConvNet input.
            self.X = self.X[:, :, :, np.newaxis]

        # Update session info
        self.div_info.text = f'<b>Sampling frequency:</b> {self.fs} Hz<br>' \
            f'<b>Classes:</b> {np.unique(self.y)} <br>' \
            f'<b>Nb trials:</b> {len(self.y)} <br>' \
            f'<b>Nb channels:</b> {self.X.shape[1]} <br>' \
            f'<b>Trial length:</b> {self.X.shape[-1] / self.fs}s <br>'

        self.button_train.label = 'Training...'
        # Defer training so the label update above is rendered first.
        curdoc().add_next_tick_callback(self.on_train)

    def on_train(self):
        """Build and train the selected pipeline, then optionally save it.

        Scheduled by ``on_load`` once ``self.X`` / ``self.y`` are ready.
        Pipeline construction and training share one ``try``. Any failure,
        including a bad configuration rejected by ``get_pipeline``, therefore
        turns ``button_train`` into a 'Failed' state rather than leaving it
        stuck on 'Training...'.

        When 'Save' is in ``self.selected_settings``, the fitted pipeline and a
        JSON summary are written to ``self.save_path``. The file name is
        ``<model_name>_<train ids joined by '_'>``.
        """
        try:
            pipeline, search_space = get_pipeline(self.selected_preproc,
                                                  self.preproc_config,
                                                  self.model_config)
            logging.info(f'Shape: X {self.X.shape} - y {self.y.shape}')
            trained_model, cv_mean, cv_std, train_time = train(
                self.X,
                self.y,
                pipeline,
                search_space,
                self.train_mode,
                self.n_iters,
                n_jobs=train_config['n_jobs'],
                is_convnet=self.is_convnet)
        except Exception:
            logging.info(f'Training failed - {traceback.format_exc()}')
            self.button_train.button_type = 'danger'
            self.button_train.label = 'Failed'
            return

        # In 'validate' mode the trained pipeline is saved as-is. Otherwise
        # ``train`` presumably returns a search object (it exposes
        # ``best_estimator_``), and the best refit estimator is kept.
        model_to_save = (trained_model if self.train_mode == 'validate'
                         else trained_model.best_estimator_)

        if 'Save' in self.selected_settings:
            dataset_name = '_'.join(self.train_ids)
            filename = f'{self.model_name}_{dataset_name}'
            save_pipeline(model_to_save, self.save_path, filename)

            model_info = {
                "Model name": self.model_name,
                "Model file": filename,
                "Train ids": self.train_ids,
                "fs": self.fs,
                "Shape": self.X.shape,
                "Preprocessing": self.selected_preproc,
                "Model pipeline": {k: str(v)
                                   for k, v in model_to_save.steps},
                # NOTE(review): labelled "CV RMSE", but the same values are
                # logged and displayed as accuracy below. The key is kept so
                # existing readers of these JSON files keep working; confirm
                # which metric ``train`` actually returns.
                "CV RMSE": f'{cv_mean:.3f}+-{cv_std:.3f}',
                "Train time": train_time
            }
            save_json(model_info, self.save_path, filename)

        logging.info(f'{model_to_save} \n'
                     f'Trained successfully in {train_time:.0f}s \n'
                     f'Accuracy: {cv_mean:.2f}+-{cv_std:.2f}')

        # Update info
        self.button_train.button_type = 'success'
        self.button_train.label = 'Trained'
        self.div_info.text += f'<b>Accuracy:</b> {cv_mean:.2f}+-{cv_std:.2f} <br>'

    def create_widget(self):
        """Build the training panel and return it as a two-column ``row``.

        Left column: pilot selector, training-folder toggles, session picker
        and model selector. Right column: ROI sliders, save/optimize toggles,
        iteration slider, preprocessing toggles, the train button and an info
        ``Div``. The pilot, session, model and preprocessing widgets get
        ``on_change`` callbacks here, and the train button calls
        ``on_train_start``.
        """
        # ---- Left column -------------------------------------------------
        self.select_pilot = Select(title='Pilot:', options=self.available_pilots)
        self.select_pilot.on_change('value', self.on_pilot_change)

        # Toggles choosing which recording folder(s) to train on.
        self.checkbox_folder = CheckboxButtonGroup(labels=['New Calib', 'Game'])

        # Session ids used for training.
        self.select_session = MultiChoice(title='Select train ids',
                                          width=250, height=120)
        self.select_session.on_change('value', self.on_session_change)

        # Model to train; '' presumably means "nothing selected yet".
        self.select_model = Select(title="Model")
        self.select_model.on_change('value', self.on_model_change)
        self.select_model.options = ['', 'CSP', 'FBCSP', 'Riemann', 'ConvNet']

        # ---- Right column ------------------------------------------------
        # Region of interest, in seconds after the start of the epoch.
        roi_range = dict(start=0, end=6, step=0.25)
        self.slider_roi_start = Slider(value=2, title='ROI start (s)', **roi_range)
        self.slider_roi_end = Slider(value=6, title='ROI end (s)', **roi_range)

        self.checkbox_settings = CheckboxButtonGroup(labels=['Save', 'Optimize'])

        # Number of iterations used for optimization.
        self.slider_n_iters = Slider(start=1, end=50, value=5,
                                     title='Iterations (optimization)')

        # Preprocessing toggles; their activation order is tracked by
        # on_preproc_change.
        self.div_preproc = Div(text='<b>Preprocessing</b>', align='center')
        self.checkbox_preproc = CheckboxButtonGroup(
            labels=['BPF', 'CN', 'CAR', 'Crop'])
        self.checkbox_preproc.on_change('active', self.on_preproc_change)

        self.button_train = Button(label='Train', button_type='primary')
        self.button_train.on_click(self.on_train_start)

        self.div_info = Div()

        left = column(self.select_pilot, self.checkbox_folder,
                      self.select_session, self.select_model)
        right = column(self.slider_roi_start, self.slider_roi_end,
                       self.checkbox_settings, self.slider_n_iters,
                       self.div_preproc, self.checkbox_preproc,
                       self.button_train, self.div_info)
        return row(left, right)
コード例 #16
0
# Handle returned by ``add_periodic_callback`` while auto-detection runs,
# or ``None`` when it is stopped.
cb_autoupdate = None


def autoupdate(attr, old, new):
    """Start or stop periodic base-frequency detection when the toggle changes.

    Bokeh ``on_change`` callback. A truthy ``new`` (a non-empty list of active
    buttons) registers ``detect_base_freq`` as a 500 ms periodic callback. An
    empty ``new`` removes it again.

    Args:
        attr: Name of the changed property.
        old: Previous value of the property.
        new: New value of the property.
    """
    # Without this declaration the assignment below would make
    # ``cb_autoupdate`` local, and the ``else`` branch would always raise
    # UnboundLocalError, so the periodic callback could never be removed.
    global cb_autoupdate
    print(attr, old, new)
    if new:
        print('starting autoupdate')
        # Never register a second periodic callback on top of a running one.
        if cb_autoupdate is None:
            cb_autoupdate = curdoc().add_periodic_callback(detect_base_freq, 500)
    else:
        print('removing autoupdate')
        if cb_autoupdate is not None:
            curdoc().remove_periodic_callback(cb_autoupdate)
            cb_autoupdate = None


# Single toggle that switches periodic auto-detection on and off. With one
# label the only valid active index is 0.
# NOTE(review): the toggle starts active, but ``autoupdate`` only runs when
# ``active`` changes, so auto-detection is not running until the user toggles.
checkbox_button_group = CheckboxButtonGroup(labels=["Autodetect"],
                                            active=[0])
checkbox_button_group.on_change('active', autoupdate)

# This button locks in the current data.
detect_button = Button()
# NOTE(review): ``detect_base_freq`` is also registered via
# ``add_periodic_callback`` in ``autoupdate``, while ``on_change`` calls it
# with (attr, old, new). Confirm that its signature accepts both call forms.
detect_button.on_change('clicks', detect_base_freq)

#%%
# make the grid & add the plots
curdoc().add_root(
    gridplot([
        p,
        widgetbox(length_slider,
                  concertpitch_slider,
                  detect_button,
                  checkbox_button_group,
                  width=400)
コード例 #17
0
    source.data = {
        'years': years,
        'Apples': [20, 10, 40, 30, 20, 10],
        'Pears': [3, 2, 4, 4, 5, 1],
        'Nectarines': [5, 4, 2, 1, 3, 1],
        'Plums': [2, 2, 2, 5, 3, 1]
    }


from bokeh.models import CheckboxButtonGroup

# Quantity chart: a single "50" toggle; any change to it re-runs get_update.
options_yr = ['50']
button_group = CheckboxButtonGroup(labels=options_yr)
button_group.on_change("active", get_update)

# Stack the chart above the toggle and attach the layout to the document.
l = column(row(stack_fig), row(button_group))
curdoc().add_root(l)
コード例 #18
0
        d3.visible = False
    elif new == [1]:
        d2.visible = True
        d1.visible = False
        d3.visible = False
    elif new == [0, 1]:
        d1.visible = False
        d2.visible = False
        d3.visible = True
    elif new == []:
        d1.visible = False
        d2.visible = False
        d3.visible = False


# Whenever the set of active checkbox indices changes, checkbox_changed picks
# which of d1 / d2 / d3 is visible ([0], [1], [0, 1] or none).
c.on_change('active', checkbox_changed)


def drawmap(lijst):
    map1 = folium.Map(
        location=[50, 0],
        tiles='stamentoner',
        zoom_start=1,
    )

    print(type(lijst))
    print(lijst)
    teller = 0

    for naam in lijst:
        print(naam)
コード例 #19
0
ファイル: main.py プロジェクト: Lrakotoson/M_Python
def update(attr, old, new):
    """Re-run the data query using every currently active filter.

    Bokeh ``on_change`` callback shared by the year, state and district filter
    groups. ``attr``, ``old`` and ``new`` are not used; the current selection
    is read directly from the three filter widgets. After querying, each
    plot's data source is refreshed from the matching frame of ``df_data``.
    """
    df_data.query(f_year.active, f_etat.active, f_quartier.active)

    refreshed = (
        (bd_source, 'bd'),
        (qt_source, 'qt'),
        (pj_source, 'proj'),
        (et_source, 'etat'),
    )
    for source, frame_name in refreshed:
        frame = getattr(df_data, frame_name)
        source.data.update(frame.to_dict(orient='list'))


# Every filter group re-runs the query through ``update``; the reset button
# is wired to ``reset``.
for _filter_widget in (f_year, f_etat, f_quartier):
    _filter_widget.on_change('active', update)
del _filter_widget
f_reset.on_event(ButtonClick, reset)

########################## GRAPHS #########################

g_map = mapPlot(bd_source, qt_source)
g_bar = barPlot(pj_source)
g_stack = stackPlot(et_source)

# Map on top; bar and stacked plots side by side underneath, all stretching
# to the available width.
graphs = column(
    g_map,
    row(g_bar, g_stack, sizing_mode="stretch_width"),
    sizing_mode="stretch_width",
)

########################## LAYOUT #########################