Exemplo n.º 1
0
def modify_doc(doc):
    """Populate *doc* with a MultiChoice widget next to a two-circle plot.

    Whenever the widget's ``value`` changes, the ``[old, new]`` pair is
    written into the ``val`` column of the plot's data source. The plot also
    carries a CustomAction whose JS code is ``RECORD("data", "s.data")``.
    """
    source = ColumnDataSource(dict(x=[1, 2], y=[1, 1], val=["a", "b"]))

    plot = Plot(plot_height=400, plot_width=400,
                x_range=Range1d(0, 1), y_range=Range1d(0, 1),
                min_border=0)
    plot.add_glyph(source, Circle(x='x', y='y', size=20))
    record_js = CustomJS(args=dict(s=source), code=RECORD("data", "s.data"))
    plot.add_tools(CustomAction(callback=record_js))

    chooser = MultiChoice(
        css_classes=["foo"],
        title="title",
        options=["100001", "12344556", "12344557", "3194567289", "209374209374"],
        value=["12344556", "12344557"],
    )

    def _store_transition(attr, old, new):
        # Expose both the previous and the new selection through the source.
        source.data['val'] = [old, new]

    chooser.on_change('value', _store_transition)
    doc.add_root(row(chooser, plot))
Exemplo n.º 2
0
class FilterWidget(BaseWidget):
    """MultiChoice-based widget for filtering tasks.

    Each change of the selected values is forwarded to ``self._callback``.
    The available options are replaced through :meth:`set_choices`.
    """

    def __init__(self,
                 doc,
                 callback,
                 refresh_rate=500,
                 collection=None,
                 **kwargs):
        super().__init__(doc,
                         callback=callback,
                         refresh_rate=refresh_rate,
                         collection=collection,
                         **kwargs)

        # Private copy of the options currently shown by the widget.
        self._choices = []
        self._root = MultiChoice(options=list(self._choices), title="Filter tasks")
        self._root.on_change("value", self._on_change)

    def _on_change(self, attr, old, new):
        """Forward the newly selected values to the user callback."""
        self._callback(new)

    def set_choices(self, choices):
        """Replace the widget's options when *choices* differs from the current ones.

        Args:
            choices: any iterable of option strings.

        A private list copy is kept. Storing the caller's object directly
        would make an in-place mutation of that same list compare equal to
        itself on the next call, and the widget would never be refreshed.
        It would also leave an exhausted iterator behind when a generator
        is passed.
        """
        choices = list(choices)
        if choices != self._choices:
            self._choices = choices
            self._root.options = list(self._choices)
Exemplo n.º 3
0
# Transformation applied to the plotted values; each change re-renders via update_plot.
level_select = Select(value=level,
                      title='Transformations',
                      options=['UK Pounds', 'Year over Year % Change'])
level_select.on_change('value', update_plot)

#################################################################################

# Countries to compare (multi-select, listed alphabetically); each change
# re-renders via update_plot.
country_select = MultiChoice(value=[country],
                             title='Country',
                             options=sorted(country_options),
                             width=325)
country_select.on_change('value', update_plot)

#################################################################################

# Products in alphabetical order, with the aggregate total pinned to the top.
# The membership check avoids a ValueError from list.remove() when the total
# is not among product_options.
total_product = "Total (ex. metals)"
popt = sorted(product_options)
if total_product in popt:
    popt.remove(total_product)
    popt.insert(0, total_product)

product_select = Select(value=product,
                        title='Product',
                        options=popt,
                        width=350)
# Selecting a product re-renders the plot via update_plot.
product_select.on_change('value', update_plot)
Exemplo n.º 4
0
def Persistence_EM_Matrix():
    """Build the "Persistence & EM" dashboard tab.

    The function:

    - reads the prepared CSV/JSON inputs;
    - creates the left-hand selection widgets;
    - builds four sub-tabs (scatter matrix, stock-picking table,
      stock-return tables, correlation matrix);
    - wires the widget callbacks.

    Returns:
        Panel: the complete dashboard, titled 'Persistence & EM 概況'.
    """
    ### ------------ read files ------------

    final = pd.read_csv('./data/App_data/AQ_EM_tse_final.csv', encoding='utf-8')
    profit = pd.read_csv('./data/App_data/AQ_EM_tse_final_profibility_1.csv', encoding='utf-8')
    # Option label "<company code> <abbreviation>", used by the company pickers.
    profit['company_name'] = profit.company.astype(str) + " " + profit.company_abbreviation
    # app_lists.json holds a dict: key = purpose/name, value = list of options/columns.
    # Explicit UTF-8 so decoding does not depend on the platform's locale.
    with open('app_lists.json', 'r', encoding='utf-8') as f:
        app_lists = json.load(f)
    with open('smallco.json', 'r', encoding='utf-8') as f:
        smallco_index = json.load(f)

    ### ------------ left-hand selection widgets ------------

    # NOTE(review): year.value is a string (options use astype(str)) while the
    # initial company list below queries yyyymm==200812 as an int -- confirm
    # the yyyymm dtype makes both comparisons match.
    year = Select(title='Year:', value='200812',
                  options=list(final.yyyymm.drop_duplicates().sort_values().astype(str)))

    industry = Select(title='Industry:', value='水泥工業',
                      options=list(profit.tse_industry_name.drop_duplicates()) + ['All Sectors'])

    index_factor = Select(title='Compared Market Index:', value='TWN50', options=["TWN50", "TM100", "TF001"])

    company_list = list((profit.query("yyyymm==200812 & tse_industry_name=='水泥工業'").company_name.astype(str)))

    company_code = Select(title='Company Code: ', value='', options=[''] + company_list)

    persistence = Select(title='Persistence:', value='ebit_slope_standard',
                         options=app_lists['options_persistence'])

    EM = Select(title='EM:', value='Jones_model_measure', options=app_lists['options_EM'])

    profit_measure = Select(title='Profit Measure:', value='ROA_ebi',
                            options=app_lists['options_profit_measure'])

    Persistence_percent = RangeSlider(start=0, end=100, value=(20, 40), step=1, title="Persistence % :")
    EM_percent = RangeSlider(start=0, end=100, value=(20, 40), step=1, title="EM %:")

    def _company_code_of(company_name):
        """Return the numeric company code that prefixes a company_name option."""
        # Assumes 4-digit company codes, as the original slicing did -- TODO confirm.
        return int(company_name[:4])

    # Refresh the company options for the selected year/industry.
    ###############################################################################
    def update_company_list(attr, old, new):
        selected_year = year.value
        selected_industry = industry.value
        if selected_industry != 'All Sectors':
            rows = profit.query("yyyymm==@selected_year & tse_industry_name==@selected_industry")
        else:
            rows = profit.query("yyyymm==@selected_year")
        company_list = list(rows.company_name.sort_values().astype(str))
        # Leading '' option means "no company selected"; it is also the default.
        company_code.options = [''] + company_list
        company_code.value = ''

    # Select the data to plot.
    ###############################################################################
    def get_plot_data():
        """Return (ColumnDataSource for the scatter plot, DataFrame for the picking table)."""
        selected_year = year.value
        selected_industry = industry.value
        selected_index_factor = index_factor.value
        selected_Persistence = persistence.value
        selected_EM = EM.value
        selected_profit_measure = profit_measure.value
        # Filter by date and industry.
        if selected_industry != 'All Sectors':
            data = profit.query('yyyymm == @selected_year & tse_industry_name  == @selected_industry')
        else:
            data = profit.query('yyyymm == @selected_year')
        # Keep rows where both selected measures exist. .copy() makes the frame
        # independent so the in-place rename/assignments below do not act on a slice.
        data = data[(data[selected_Persistence].notna()) & (data[selected_EM].notna())].copy()
        # The selected columns may also be needed under their own names, so
        # back them up before renaming and restore them afterwards.
        origin_data = [selected_Persistence, selected_EM, selected_profit_measure]
        origin = data[origin_data]
        # Generic names make the ColumnDataSource fields independent of the selection.
        data.rename(columns={selected_Persistence: 'Persistence', selected_EM: 'EM',
                             selected_index_factor: 'index_factor',
                             selected_profit_measure: 'profit_measure'}, inplace=True)
        for col in origin_data:
            data[col] = origin[col]
        data['Persistence'] = data.Persistence.apply(lambda x: value_transfrom(x, selected_Persistence))
        data['EM'] = data.EM.apply(lambda x: value_transfrom(x, selected_EM))
        # green = listed in the compared index, blue = not, red = small-cap constituent.
        data['color'] = data['index_factor'].apply(lambda x: 'green' if x == 'Y' else 'blue')
        small_caps = smallco_index['2020']
        data['color'] = data.apply(lambda x: 'red' if str(x.company) in small_caps else x.color, axis=1)

        # Marker size: profit measure rescaled linearly into [5, 30].
        profit_min = data['profit_measure'].min()
        profit_range = data['profit_measure'].max() - profit_min
        data['profit_score'] = data['profit_measure'].apply(
            lambda x: ((x - profit_min) / profit_range) * 25 + 5
            if profit_range != 0 else 30 if x == 1 else 5)
        table_data = data[app_lists['select_stock_picking_table_column']]
        data_for_source = data.fillna('--')
        selected_company = company_code.value
        if selected_company != '':
            select_co = _company_code_of(selected_company)
            data_for_source['text'] = data_for_source['company'].apply(
                lambda x: '.Here' if x == select_co else '')
        else:
            data_for_source['text'] = ''
        # Drop rows with non-finite values; .copy() because columns are added below.
        data_for_source = data_for_source[
            ~data_for_source.isin([np.nan, np.inf, -np.inf]).any(axis=1)].copy()
        # Coordinates of the highlighted (selected) company, or NaN if none.
        # The company can be absent after the notna filter above; fall back to
        # NaN instead of raising IndexError.
        select_p = select_e = np.nan
        if selected_company != '':
            chosen = data.query('company==@select_co')
            if not chosen.empty:
                select_p = chosen['Persistence'].to_list()[0]
                select_e = chosen['EM'].to_list()[0]
        data_for_source['select_p'] = select_p
        data_for_source['select_e'] = select_e

        plot_source = ColumnDataSource(data_for_source)
        return (plot_source, table_data)

    def get_stock_picking_table_data(table_data):
        """Keep rows inside the Persistence/EM percentile windows chosen on the sliders."""
        df = table_data
        Persistence_top = df.Persistence.quantile(Persistence_percent.value[1] / 100)
        Persistence_low = df.Persistence.quantile(Persistence_percent.value[0] / 100)
        EM_top = df.EM.quantile(EM_percent.value[1] / 100)
        EM_low = df.EM.quantile(EM_percent.value[0] / 100)
        df = df.query('Persistence <= @Persistence_top & Persistence >= @Persistence_low & EM <= @EM_top & EM >= @EM_low')
        df = df.applymap(lambda x: round(x, 2) if isinstance(x, float) else x)
        stock_picking_table_co_choice.options = (df.company.astype(str) + ' ' + df.company_abbreviation).sort_values().to_list()
        stock_picking_table_co_num.text = f'Total: {df.shape[0]} company'
        return ColumnDataSource(df)

    def get_stock_return_table_2_data():
        """Rows of the selected year whose selected index flag is "Y"."""
        selected_year = year.value
        selected_index_factor = index_factor.value
        df = profit.rename(columns={selected_index_factor: 'index_factor'})
        df = df.query('yyyymm == @selected_year & index_factor=="Y"')
        return ColumnDataSource(df)

    def get_stock_return_table_3_data(stock_picking_table_source, stock_return_table_2_source):
        """Equally weighted mean yearly return of both stock sets (' ' when a set is empty)."""
        def _average(returns):
            if returns.size == 0:
                return [' ']
            return [round(np.nanmean(returns), 4)]

        stock_average = _average(stock_picking_table_source.data['yearly_return'])
        etf_average = _average(stock_return_table_2_source.data['yearly_return'])
        return ColumnDataSource(data={'Stock Picking Return (Equally Weighted)': stock_average,
                                      "ETF Return (Equally Weighted)": etf_average
                                      })

    def get_matrix_plot_data():
        """Correlation matrix (2 decimals) of all Persistence and EM measures for the year."""
        selected_year = year.value
        df = profit.query('yyyymm == @selected_year')
        df = df[app_lists['options_persistence'] + app_lists['options_EM']].corr()
        df = df.round(2)
        return ColumnDataSource(df)

    ###################################################
    # Figures and tables

    def _table_columns(colnames):
        """One TableColumn per name, titled by the field and sized to the name length."""
        return [TableColumn(field=name, title=name, width=6 * len(name)) for name in colnames]

    def make_scatter_plot(plot_source):
        hover = HoverTool(names=['circle'],
                          tooltips=[('Company Abbreviation :', '@company_abbreviation'),
                                    ('Company Code :', '@company'),
                                    ('Persistence', '@Persistence'),
                                    ('EM :', '@EM'), ('ROA (EBI) :', '@ROA_ebi'),
                                    ('EPS :', '@eps'), ('ROE_b :', '@ROE_b'),
                                    ('Diluted EPS :', '@eps_diluted'), ('Yearly Return', '@yearly_return')]
                          )
        plot = figure(plot_height=500, plot_width=800,
                      tools=['box_zoom', 'reset', hover],
                      x_axis_label='Persistence (Log Transformed)',
                      y_axis_label='EM (Log Transformed)',
                      toolbar_location="right"
                      )
        plot.circle(x="Persistence", y="EM", source=plot_source, color='color', size='profit_score', name='circle',
                    line_color=None, alpha=0.5)
        # Red asterisk marks the company chosen in company_code (NaN = not drawn).
        plot.asterisk('select_p', 'select_e', source=plot_source, color='red', size=20)
        plot.toolbar.active_drag = None
        return plot

    def make_stock_picking_table(stock_picking_table_source):
        columns = _table_columns(name for name in stock_picking_table_source.data.keys() if name != 'index')
        return DataTable(source=stock_picking_table_source, columns=columns, width=4000, height=500)

    def make_stock_return_table_1(stock_picking_table_source):
        columns = _table_columns(['tse_industry_name', 'company', 'company_abbreviation', 'index_factor', 'yearly_return'])
        return DataTable(source=stock_picking_table_source, columns=columns, height=500)

    def make_stock_return_table_2(stock_return_table_2_source):
        columns = _table_columns(['tse_industry_name', 'company', 'company_abbreviation', 'index_factor', 'yearly_return'])
        return DataTable(source=stock_return_table_2_source, columns=columns, height=500)

    def make_stock_return_table_3(stock_return_table_3_source):
        columns = _table_columns(name for name in stock_return_table_3_source.data.keys() if name != 'index')
        return DataTable(source=stock_return_table_3_source, columns=columns)

    def make_matrix_plot(matrix_plot_source):
        columns = []
        for colnames in matrix_plot_source.data.keys():
            if colnames == 'index':
                # Row labels of the correlation matrix, shown without a header.
                columns.append(TableColumn(field=colnames, title=' ', width=200))
            else:
                columns.append(TableColumn(field=colnames, title=colnames, width=6 * len(colnames)))
        return DataTable(source=matrix_plot_source, columns=columns, index_position=None, width=2500, height=300)

    ###################################################
    # Updates

    def update(attr, old, new):
        """Recompute every source after any selection widget changes."""
        stock_picking_table_co_choice.value = []
        new_plot_source, new_table_data = get_plot_data()
        plot_source.data.update(new_plot_source.data)

        new_stock_picking_table_source = get_stock_picking_table_data(new_table_data)
        stock_picking_table_source.data.update(new_stock_picking_table_source.data)

        new_stock_return_table_2_source = get_stock_return_table_2_data()
        stock_return_table_2_source.data.update(new_stock_return_table_2_source.data)

        new_stock_return_table_3_source = get_stock_return_table_3_data(new_stock_picking_table_source, new_stock_return_table_2_source)
        stock_return_table_3_source.data.update(new_stock_return_table_3_source.data)

        new_matrix_plot_source = get_matrix_plot_data()
        matrix_plot_source.data.update(new_matrix_plot_source.data)

    def update_stock_picking(attr, old, new):
        """Restrict the stock-picking table to the companies picked in the MultiChoice."""
        # The table's company column is numeric, so compare against int codes
        # (string codes never matched and emptied the table).
        pick_list = [_company_code_of(name) for name in stock_picking_table_co_choice.value]
        new_plot_source, new_table_data = get_plot_data()
        new_stock_picking_table_source = get_stock_picking_table_data(new_table_data)
        # Drop the first column ('index', added by ColumnDataSource).
        df = pd.DataFrame(new_stock_picking_table_source.data).iloc[:, 1:]
        if pick_list:
            df = df.query('company in @pick_list')
        stock_picking_table_source.data.update(ColumnDataSource(df).data)
        stock_picking_table_co_num.text = f'Total: {df.shape[0]} company'

    ###################################################
    # Initial state

    plot_source, table_data = get_plot_data()
    plot = make_scatter_plot(plot_source)
    plot_explain = Div(text =
                       '''
                       <span style="padding-left:20px">顏色(綠色): 該公司在該年,有被列在所選的Compared Market Index中 <br/>
                       <span style="padding-left:20px">顏色(紅色): 該公司在該年,有被列在中小型成分股中 <br/>
                       <span style="padding-left:20px">大小: 圈圈越大,代表該公司Profit Measure越大
                       ''')
    tab1 = Panel(child=column(Spacer(height=35), plot, Spacer(height=20), plot_explain), title='Persistence EM Matrix')

    stock_picking_table_co_choice = MultiChoice(title = 'select_company:', value=[], options=[], placeholder = '選擇想看的公司')
    stock_picking_table_co_choice.js_on_change("value", CustomJS(code="""
        console.log('multi_choice: value=' + this.value, this.toString())
    """))
    stock_picking_table_co_num = Div(text ='Total:   company')
    stock_picking_table_source = get_stock_picking_table_data(table_data)
    stock_picking_table = make_stock_picking_table(stock_picking_table_source)
    tab2 = Panel(child=column(stock_picking_table_co_num, stock_picking_table, stock_picking_table_co_choice), title='Stock Picking Table')

    div1 = Div(text ='Table 1: The next year return of stocks from the matrix')
    stock_return_table_1 = make_stock_return_table_1(stock_picking_table_source)
    div2 = Div(text ='Table 2: The next year return of stocks in ETF')
    stock_return_table_2_source = get_stock_return_table_2_data()
    stock_return_table_2 = make_stock_return_table_2(stock_return_table_2_source)
    div3 = Div(text ='Table 3: The next year return of equally weighted portfolios')
    stock_return_table_3_source = get_stock_return_table_3_data(stock_picking_table_source, stock_return_table_2_source)
    stock_return_table_3 = make_stock_return_table_3(stock_return_table_3_source)
    tab3 = Panel(child=row([column(div1, stock_return_table_1),
                            column(div2, stock_return_table_2),
                            column(div3, stock_return_table_3)]),
                 title='Stock Return Table')

    matrix_plot_source = get_matrix_plot_data()
    matrix_plot = make_matrix_plot(matrix_plot_source)
    matrix_plot_explain = Div(text = 
        '''
        Persistence: <br/>
        <span style="padding-left:50px">ebit_slope_standard <br/>
        <span style="padding-left:50px">operating_slope_standard <br/>
        <span style="padding-left:50px">yoy_ebit_standard <br/>
        <span style="padding-left:50px">yoy_operating_standard <br/>
        <br/><br/>
        EM: <br/>
        <span style="padding-left:50px">Jones_model_measure <br/>
        <span style="padding-left:50px">Modified_Jones_model_measure <br/>
        <span style="padding-left:50px">Performance_matching_measure <br/>
        <span style="padding-left:50px">opacity_Jones_model_measure <br/>
        <span style="padding-left:50px">opacity_modified_Jones_model_measure <br/>
        <span style="padding-left:50px">opacity_performance_matching <br/>
        ''')
    tab4 = Panel(child=column(matrix_plot, row(Spacer(width=20), matrix_plot_explain)), title='Correlation Matrix of Persistence & EM')

    tabs = Tabs(tabs=[tab1, tab2, tab3, tab4])

    ###################################################
    # Input change wiring.
    # update_company_list must run before update: it resets company_code to ''
    # for the new year/industry. Otherwise update would look up a stale company
    # that may not exist in the new data.
    year.on_change('value', update_company_list, update)
    industry.on_change('value', update_company_list, update)
    index_factor.on_change('value', update)
    company_code.on_change('value', update)
    persistence.on_change('value', update)
    EM.on_change('value', update)
    profit_measure.on_change('value', update)
    Persistence_percent.on_change('value', update)
    EM_percent.on_change('value', update)
    stock_picking_table_co_choice.on_change('value', update_stock_picking)

    ###################################################
    # Layout
    div_title = Div(text ='Persistence & EM Matrix', style={'font-size': '200%', 'color': 'blue'})
    inputs = column(div_title, year, industry, index_factor, company_code, persistence, EM, profit_measure,
                    Persistence_percent, EM_percent, background='gainsboro')
    final_layout = row(inputs, tabs, width=1200)
    return Panel(child = column(Spacer(height = 35), final_layout), title = 'Persistence & EM 概況')
Exemplo n.º 5
0
def update():
    """Rebuild ``source.data`` from the countries chosen in ``countries_selector``.

    Each ``(name, df)`` group in ``grouped`` whose name is selected
    contributes one entry per column:

    - the country name;
    - its Year and Percent values as lists;
    - a ``"<min year> - <max year>"`` span label;
    - the mean Percent;
    - the country's colour from ``colors``.

    Entries keep the iteration order of ``grouped``.
    """
    # A set makes each membership test O(1), and walking the groups once
    # replaces the previous six separate passes (one per output column).
    selected = set(countries_selector.value)
    countries, years, percents, span, mean, color = [], [], [], [], [], []
    for name, df in grouped:
        if name not in selected:
            continue
        countries.append(name)
        years.append(list(df["Year"]))
        percents.append(list(df["Percent"]))
        span.append("%s - %s" % (df["Year"].min(), df["Year"].max()))
        mean.append(df["Percent"].mean())
        color.append(colors[name])
    source.data = dict(
        countries=countries,
        years=years,
        percents=percents,
        span=span,
        mean=mean,
        color=color,
    )


def _on_countries_change(attr, old, new):
    """Refresh the data source whenever the country selection changes."""
    update()


countries_selector.on_change("value", _on_countries_change)

# Fill the source once so the page renders with the initial selection.
update()

_doc = curdoc()
_doc.add_root(column(header, countries_selector, plot, table))
_doc.title = "Top 5% Income Share"
Exemplo n.º 6
0
class TrainerWidget:
    """Bokeh widget to load recorded sessions and train a model on them.

    ``create_widget`` builds the controls; the ``on_*`` methods are their
    callbacks. Training is chained through ``curdoc().add_next_tick_callback``
    in three steps:

    1. ``on_train_start`` validates the inputs.
    2. ``on_load`` loads and concatenates the data.
    3. ``on_train`` fits, optionally saves, and reports.
    """

    def __init__(self):
        self.data_path = main_config['data_path']
        self.save_path = main_config['models_path']
        # Indices of the active preprocessing buttons, in click order
        # (presumably so get_pipeline applies the steps in that order).
        self.active_preproc_ordered = []

    @property
    def available_pilots(self):
        """'' followed by the names of all entries in ``data_path``."""
        pilots = self.data_path.glob('*')
        return [''] + [p.parts[-1] for p in pilots]

    @property
    def selected_pilot(self):
        return self.select_pilot.value

    @property
    def available_sessions(self):
        """Names of the entries inside the selected pilot's folder ([] without a pilot)."""
        if not self.selected_pilot:
            # data_path / '' is data_path itself: it would list the pilots
            # instead of sessions.
            return []
        pilot_path = self.data_path / self.selected_pilot
        return [s.name for s in pilot_path.glob('*')]

    @property
    def selected_preproc(self):
        """Labels of the active preprocessing steps, in click order."""
        return [self.checkbox_preproc.labels[i] for i in self.active_preproc_ordered]

    @property
    def train_ids(self):
        return self.select_session.value

    @property
    def preproc_config(self):
        """Keyword arguments for each preprocessing step (needs ``self.fs`` set by on_load)."""
        config_cn = dict(sigma=6)
        config_bpf = dict(fs=self.fs,
                          f_order=train_config['f_order'],
                          f_type='butter',
                          f_low=train_config['f_low'],
                          f_high=train_config['f_high'])
        config_crop = dict(fs=self.fs,
                           n_crops=train_config['n_crops'],
                           crop_len=train_config['crop_len'])
        return {'CN': config_cn, 'BPF': config_bpf, 'Crop': config_crop}

    @property
    def should_crop(self):
        return 'Crop' in self.selected_preproc

    @property
    def selected_folders(self):
        active = self.checkbox_folder.active
        return [self.checkbox_folder.labels[i] for i in active]

    @property
    def selected_settings(self):
        active = self.checkbox_settings.active
        return [self.checkbox_settings.labels[i] for i in active]

    @property
    def model_name(self):
        return self.select_model.value

    @property
    def model_config(self):
        return {'model_name': self.model_name, 'C': 10}

    @property
    def is_convnet(self):
        return self.model_name == 'ConvNet'

    @property
    def train_mode(self):
        return 'optimize' if 'Optimize' in self.selected_settings else 'validate'

    @property
    def folder_ids(self):
        """On-disk folder names matching the checked data folders."""
        ids = []
        if 'New Calib' in self.selected_folders:
            ids.append('formatted_raw_500Hz')
        if 'Game' in self.selected_folders:
            ids.append('formatted_raw_500Hz_game')
        return ids

    @property
    def start(self):
        return self.slider_roi_start.value

    @property
    def end(self):
        return self.slider_roi_end.value

    @property
    def n_iters(self):
        return self.slider_n_iters.value

    def on_pilot_change(self, attr, old, new):
        logging.info(f'Select pilot {new}')
        # Clear the session selection. An empty list (not ['']) keeps the
        # "at least one session" check in on_train_start meaningful.
        self.select_session.value = []
        self.update_widget()

    def on_session_change(self, attr, old, new):
        logging.info(f"Select train sessions {new}")
        self.update_widget()

    def on_model_change(self, attr, old, new):
        logging.info(f'Select model {new}')
        self.update_widget()

    def on_preproc_change(self, attr, old, new):
        """Track the click order of preprocessing buttons (one toggles per event)."""
        if len(new) > len(old):
            # Case 1: a preprocessing step was added.
            to_add = list(set(new) - set(old))[0]
            self.active_preproc_ordered.append(to_add)
        else:
            # Case 2: a preprocessing step was removed.
            to_remove = list(set(old) - set(new))[0]
            self.active_preproc_ordered.remove(to_remove)

        logging.info(f'Preprocessing selected: {self.selected_preproc}')
        self.update_widget()

    def update_widget(self):
        """Refresh option lists, reset the train button and show the preprocessing."""
        self.select_pilot.options = self.available_pilots
        self.select_session.options = self.available_sessions
        self.button_train.button_type = 'primary'
        self.button_train.label = 'Train'
        self.div_info.text = f'<b>Preprocessing selected:</b> {self.selected_preproc} <br>'

    def _mark_failed(self, label):
        """Show a failure state on the train button."""
        self.button_train.button_type = 'danger'
        self.button_train.label = label

    def on_train_start(self):
        """Validate the selection, then schedule data loading.

        Raises:
            ValueError: no model or no session selected, or ConvNet without cropping.
        """
        # Explicit checks instead of assert, which is stripped under `python -O`.
        if not self.model_name:
            raise ValueError('Please select a model !')
        if not self.train_ids:
            raise ValueError('Please select at least one session !')
        # Checked up front, rather than after loading, so the button is
        # never left stuck on "Loading data...".
        if self.is_convnet and not self.should_crop:
            raise ValueError('ConvNet requires cropping !')

        self.button_train.button_type = 'warning'
        self.button_train.label = 'Loading data...'
        curdoc().add_next_tick_callback(self.on_load)

    def on_load(self):
        """Load every (session, folder) pair, concatenate, crop, then schedule training."""
        X, y = {}, {}
        for train_id in self.train_ids:
            for folder in self.folder_ids:
                logging.info(f'Loading {train_id} - {folder}')
                try:
                    session_path = self.data_path / self.selected_pilot / train_id / folder
                    filepath = session_path / 'train/train1.npz'
                    X_id, y_id, fs, ch_names = load_session(
                        filepath, self.start, self.end)
                except Exception as e:
                    logging.info(f'Loading data failed - {e}')
                    self._mark_failed('Training failed')
                    return
                key = f'{train_id}_{folder}'
                X[key] = X_id
                y[key] = y_id
                self.fs = fs
                self.ch_names = ch_names

        if not X:
            # No data folder checked: np.vstack([]) would raise and leave the
            # button stuck on "Loading data...".
            logging.info('Loading data failed - no data folder selected')
            self._mark_failed('Training failed')
            return

        # Concatenate all data
        self.X = np.vstack(list(X.values()))
        self.y = np.hstack(list(y.values())).flatten()

        # Cropping FIXME: Integrate inside preproc to avoid data leakage
        if self.should_crop:
            self.X, self.y = cropping(self.X, self.y,
                                      **self.preproc_config['Crop'])

        if self.is_convnet:
            # Trailing singleton axis for the ConvNet input.
            self.X = self.X[:, :, :, np.newaxis]

        # Update session info
        self.div_info.text = f'<b>Sampling frequency:</b> {self.fs} Hz<br>' \
            f'<b>Classes:</b> {np.unique(self.y)} <br>' \
            f'<b>Nb trials:</b> {len(self.y)} <br>' \
            f'<b>Nb channels:</b> {self.X.shape[1]} <br>' \
            f'<b>Trial length:</b> {self.X.shape[-1] / self.fs}s <br>'

        self.button_train.label = 'Training...'
        curdoc().add_next_tick_callback(self.on_train)

    def on_train(self):
        """Fit the pipeline, optionally save it with its metadata, and report the score."""
        pipeline, search_space = get_pipeline(self.selected_preproc,
                                              self.preproc_config,
                                              self.model_config)

        try:
            logging.info(f'Shape: X {self.X.shape} - y {self.y.shape}')
            trained_model, cv_mean, cv_std, train_time = train(
                self.X,
                self.y,
                pipeline,
                search_space,
                self.train_mode,
                self.n_iters,
                n_jobs=train_config['n_jobs'],
                is_convnet=self.is_convnet)
        except Exception:
            logging.info(f'Training failed - {traceback.format_exc()}')
            self._mark_failed('Failed')
            return

        # In optimize mode train() returns a search object; keep its best estimator.
        model_to_save = trained_model if self.train_mode == 'validate' \
            else trained_model.best_estimator_

        if 'Save' in self.selected_settings:
            dataset_name = '_'.join(self.train_ids)
            filename = f'{self.model_name}_{dataset_name}'
            save_pipeline(model_to_save, self.save_path, filename)

            model_info = {
                "Model name": self.model_name,
                "Model file": filename,
                "Train ids": self.train_ids,
                "fs": self.fs,
                "Shape": self.X.shape,
                "Preprocessing": self.selected_preproc,
                "Model pipeline": {k: str(v)
                                   for k, v in model_to_save.steps},
                "CV RMSE": f'{cv_mean:.3f}+-{cv_std:.3f}',
                "Train time": train_time
            }
            save_json(model_info, self.save_path, filename)

        logging.info(f'{model_to_save} \n'
                     f'Trained successfully in {train_time:.0f}s \n'
                     f'Accuracy: {cv_mean:.2f}+-{cv_std:.2f}')

        # Update info
        self.button_train.button_type = 'success'
        self.button_train.label = 'Trained'
        self.div_info.text += f'<b>Accuracy:</b> {cv_mean:.2f}+-{cv_std:.2f} <br>'

    def create_widget(self):
        """Build all controls, wire their callbacks and return the two-column layout."""
        # Select - Pilot
        self.select_pilot = Select(title='Pilot:',
                                   options=self.available_pilots)
        self.select_pilot.on_change('value', self.on_pilot_change)

        # Multichoice - Choose training folder
        self.checkbox_folder = CheckboxButtonGroup(
            labels=['New Calib', 'Game'])

        # Multichoice - Choose session to use for training
        self.select_session = MultiChoice(title='Select train ids',
                                          width=250,
                                          height=120)
        self.select_session.on_change('value', self.on_session_change)

        # Select - Choose model to train
        self.select_model = Select(title="Model")
        self.select_model.on_change('value', self.on_model_change)
        self.select_model.options = ['', 'CSP', 'FBCSP', 'Riemann', 'ConvNet']

        # Slider - Select ROI start (in s after start of epoch)
        self.slider_roi_start = Slider(start=0,
                                       end=6,
                                       value=2,
                                       step=0.25,
                                       title='ROI start (s)')

        # Slider - Select ROI end (in s after start of epoch)
        self.slider_roi_end = Slider(start=0,
                                     end=6,
                                     value=6,
                                     step=0.25,
                                     title='ROI end (s)')

        self.checkbox_settings = CheckboxButtonGroup(
            labels=['Save', 'Optimize'])

        # Slider - Number of iterations if optimization
        self.slider_n_iters = Slider(start=1,
                                     end=50,
                                     value=5,
                                     title='Iterations (optimization)')

        # Checkbox - Choose preprocessing steps
        self.div_preproc = Div(text='<b>Preprocessing</b>', align='center')
        self.checkbox_preproc = CheckboxButtonGroup(
            labels=['BPF', 'CN', 'CAR', 'Crop'])
        self.checkbox_preproc.on_change('active', self.on_preproc_change)

        self.button_train = Button(label='Train', button_type='primary')
        self.button_train.on_click(self.on_train_start)

        self.div_info = Div()

        column1 = column(self.select_pilot, self.checkbox_folder,
                         self.select_session, self.select_model)
        column2 = column(self.slider_roi_start, self.slider_roi_end,
                         self.checkbox_settings, self.slider_n_iters,
                         self.div_preproc, self.checkbox_preproc,
                         self.button_train, self.div_info)
        return row(column1, column2)