Ejemplo n.º 1
0
def test_hdf5_input_output():
    """Run the hdf5 study and verify the structure of the generated input file.

    The input file named under ``files -> input`` in the study config must contain
    a (2, 1) array whose fields are exactly ``u``, ``v`` and ``w``. The study
    directory is cleaned up afterwards, also when an assertion fails.
    """
    cfg_path = './study/profit_hdf5.yaml'
    cfg = Config.from_file(cfg_path)
    try:
        run(f"profit run {cfg_path}", shell=True, timeout=TIMEOUT)
        input_path = cfg['files'].get('input')
        loaded = load(input_path)
        assert loaded.shape == (2, 1)
        assert loaded.dtype.names == ('u', 'v', 'w')
    finally:
        clean(cfg)
Ejemplo n.º 2
0
def multi_test_1d(study, config_file, output_file):
    """Run a 1D study with the given config file and compare its output to reference values.

    Parameters:
        study (str): Study directory; ``config_file`` and ``output_file`` are relative to it.
        config_file (str): File name of the study configuration.
        output_file (str): File name of the output produced by the run.

    The output must have shape (7, 1), and its field ``'f'`` must match the reference
    values to within an absolute tolerance of 1e-4. The study is cleaned up afterwards,
    also when an assertion fails.
    """
    config_file = path.join(study, config_file)
    output_file = path.join(study, output_file)
    config = Config.from_file(config_file)
    try:
        run(f"profit run {config_file}", shell=True, timeout=TIMEOUT)
        output = load(output_file)
        assert output.shape == (7, 1)
        expected_f = array([0.7836, -0.5511, 1.0966, 0.4403, 1.6244, -0.4455, 0.0941]).reshape((7, 1))
        # Compare the absolute deviation: without abs() every value *below* the
        # reference would pass, no matter how far off it is.
        assert (abs(output['f'] - expected_f) < 1e-4).all()
    finally:
        clean(config)
Ejemplo n.º 3
0
    def load_model(cls, path):
        """Load a saved model from a .hdf5 file and update the surrogate's attributes.

        In case of a multi-output model, the .pkl file with the same stem is loaded
        instead, since .hdf5 is not supported for it yet. The fallback is taken whenever
        reading the .hdf5 data raises an ``OSError`` (this includes ``FileNotFoundError``).

        Parameters:
            path (str): Path including the file name, from where the model should be loaded.

        Returns:
            An instance of ``cls`` whose ``model`` attribute holds the loaded GPy model.
        """

        from profit.util import load
        from .encoders import Encoder
        from GPy import models

        self = cls()
        try:
            sur_dict = load(path, as_type='dict')
            self.model = models.GPRegression.from_dict(sur_dict['model'])
            self.Xtrain = sur_dict['Xtrain']
            self.ytrain = sur_dict['ytrain']
            # The stored encoder string evaluates to an iterable of (func, cols, out) triples.
            # NOTE(security): eval() executes arbitrary code from the file; only load trusted files.
            self.encoder = [
                Encoder(func, cols, out)
                for func, cols, out in eval(sur_dict['encoder'])
            ]
        except OSError:  # FileNotFoundError is a subclass of OSError
            from pickle import load as pload
            from os.path import splitext
            # Load multi-output model from pickle file
            print(
                "File {} not found. Trying to find a .pkl file with multi-output instead."
                .format(path))
            # Use a context manager so the file handle is closed after unpickling.
            # NOTE(security): unpickling can execute arbitrary code; only load trusted files.
            with open(splitext(path)[0] + '.pkl', 'rb') as pkl_file:
                self.model, self.Xtrain, self.ytrain, encoder_str = pload(pkl_file)
            # NOTE(security): eval() executes arbitrary code from the file; only load trusted files.
            self.encoder = [
                Encoder(func, cols, out)
                for func, cols, out in eval(encoder_str)
            ]
            # The last input column of the multi-output model holds the output index,
            # so the number of outputs is its maximum plus one.
            self.output_ndim = int(max(self.model.X[:, -1])) + 1
            self.multi_output = True

        # Initialize the encoder by encoding and decoding the training data once.
        self.encode_training_data()
        self.decode_training_data()

        self.kernel = self.model.kern
        self._set_hyperparameters_from_model()
        self.ndim = self.Xtrain.shape[-1]
        self.trained = True
        self.print_hyperparameters("Loaded")
        return self
Ejemplo n.º 4
0
    def load_model(cls, filename):
        """Restore a GPySurrogate object from a file in hdf5 format.

        The file's ``'data'`` entry is evaluated into a dict of attribute names and
        values, which are set on a fresh instance. The ``'m'`` entry is rebuilt into a
        ``GPRegression`` model. When that model is truthy, the training data and kernel
        are copied from it onto ``xtrain``, ``ytrain`` and ``kern``.

        Parameters:
            filename (str): Path of the saved surrogate.

        Returns:
            The restored instance of ``cls``.
        """
        from profit.util import load
        # NOTE(security): eval() executes arbitrary code from the file; only load trusted files.
        serialized = load(filename, as_type='dict').get('data')
        attributes = eval(serialized)
        surrogate = cls()

        for name, stored in attributes.items():
            if name == 'm':
                stored = cls.GPy.models.GPRegression.from_dict(stored)
            setattr(surrogate, name, stored)

        model = surrogate.m
        if model:
            surrogate.xtrain = model.X
            surrogate.ytrain = model.Y
            surrogate.kern = model.kern

        return surrogate
Ejemplo n.º 5
0
    def load_model(cls, path):
        """Load a saved SklGPSurrogate model from a pickle file and update its attributes.

        The unpickled object is used as the underlying model. Its ``X_train_``,
        ``y_train_``, ``kernel_`` and ``alpha`` attributes provide the training data,
        the fitted kernel and the noise setting.

        Parameters:
            path (str): Path including the file name, from where the model should be loaded.

        Returns:
            profit.sur.gaussian_process.SklearnGPSurrogate: Instantiated surrogate model.
        """

        from pickle import load

        self = cls()
        # Use a context manager so the file handle is closed after unpickling.
        # NOTE(security): unpickling can execute arbitrary code; only load trusted files.
        with open(path, 'rb') as model_file:
            self.model = load(model_file)
        self.Xtrain = self.model.X_train_
        self.ytrain = self.model.y_train_
        self.kernel = self.model.kernel_
        self.ndim = self.Xtrain.shape[-1]
        # NOTE(review): 1e-5 presumably is the default noise level used by this surrogate;
        # any other alpha is treated as a fixed noise. Confirm against the training code.
        self.fixed_sigma_n = self.model.alpha != 1e-5
        self.trained = True
        self._set_hyperparameters_from_model()
        self.print_hyperparameters("Loaded")
        return self
Ejemplo n.º 6
0
    def load_model(cls, path):
        """Load a saved model from a .hdf5 file and update the surrogate's attributes.

        In case of a multi-output model, the .pkl file with the same stem is loaded
        instead, since .hdf5 is not supported for it yet. The fallback is taken whenever
        reading the .hdf5 data raises an ``OSError`` (this includes ``FileNotFoundError``).

        Parameters:
            path (str): Path including the file name, from where the model should be loaded.

        Returns:
            An instance of ``cls`` whose ``model`` attribute holds the loaded GPy model.
        """

        from profit.util import load
        from GPy import models

        self = cls()
        try:
            model_dict = load(path, as_type='dict')
            self.model = models.GPRegression.from_dict(model_dict)
            self.Xtrain = self.model.X
            self.ytrain = self.model.Y
        except OSError:  # FileNotFoundError is a subclass of OSError
            from pickle import load as pload
            from os.path import splitext
            # Load multi-output model from pickle file.
            # Use a context manager so the file handle is closed after unpickling.
            # NOTE(security): unpickling can execute arbitrary code; only load trusted files.
            with open(splitext(path)[0] + '.pkl', 'rb') as pkl_file:
                self.model = pload(pkl_file)
            # The last input column holds the output index, so the number of
            # outputs is its maximum plus one.
            self.output_ndim = int(max(self.model.X[:, -1])) + 1
            # Presumably X stacks one copy of the inputs per output; keep the first
            # copy without the index column.
            self.Xtrain = self.model.X[:len(self.model.X) // self.output_ndim, :-1]
            # Column-major reshape turns the stacked Y into one column per output.
            self.ytrain = self.model.Y.reshape(-1, self.output_ndim, order='F')
            self.multi_output = True

        self.kernel = self.model.kern
        self._set_hyperparameters_from_model()
        self.ndim = self.Xtrain.shape[-1]
        self.trained = True
        self.print_hyperparameters("Loaded")
        return self
Ejemplo n.º 7
0
def init_app(config):
    external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']

    app = dash.Dash(__name__, external_stylesheets=external_stylesheets)
    server = app.server
    app.config.suppress_callback_exceptions = False

    indata = load(config['files']['input']).flatten()
    outdata = load(config['files']['output']).flatten()

    invars = indata.dtype.names
    outvars = outdata.dtype.names
    dd_opts_in = [{'label': invar, 'value': invar} for invar in invars]
    dd_opts_out = [{'label': outvar, 'value': outvar} for outvar in outvars]

    col_width = 400
    txt_width = 100
    dd_width = 250
    log_width = 50
    graph_height = 620
    txt_check_width = 50
    check_txt_width = txt_width-txt_check_width
    ax_opt_tit_sty = {'width': col_width}
    ax_opt_txt_sty = {'width': txt_width}
    ax_opt_log_sty = {'width': log_width}
    dd_sty = {'width': dd_width}
    axis_options_div_style = {'display': 'flex', 'align-items': 'center', 'height':36, 'padding': 1}
    fit_opt_txt_sty = {'width': txt_width}
    headline_sty = {'text-align': 'center', 'display': 'block', 'width': col_width-25}
    input_div_sty = {'height': 40}
    input_sty = {'width': 125}
    col_sty = {'padding-left':5, 'padding-right':5}
    col_sty_th = {**col_sty, 'text-align':'center'}
    button_sty = {'padding-left':15, 'padding-right':15}

    # try to load model with 'save' and 'fit' config option
    path = config['fit']['save'] or config['fit']['load']
    try:
        sur = Surrogate.load_model(path)
    except (TypeError, FileNotFoundError):
            print('Model could not be loaded')

    app.layout = html.Div(children=[
        html.Table(children=[html.Tr(children=[
            html.Td(id='axis-options', style={'width': '20%'}, children=[
                html.Div(dcc.RadioItems(
                    id='graph-type',
                    options=[{'label': i, 'value': i} for i in ['1D', '2D', '2D contour', '3D']],
                    value='1D',
                    labelStyle={'display': 'inline-block'})),
                html.Div(id='header-opt', children=[html.B("Axis options:", style=headline_sty)], style=ax_opt_tit_sty),
                html.Div(id='invar-1-div', style=axis_options_div_style, children=[
                    html.B('x: ', style=ax_opt_txt_sty),
                    dcc.Dropdown(id='invar', options=dd_opts_in, value=invars[0], style=dd_sty),
                    dcc.Checklist(id='invar-1-log', options=[{'label': 'log', 'value': 'log'}], style=ax_opt_log_sty),
                ]),
                html.Div(id='invar-2-div', style=axis_options_div_style, children=[
                    html.B('y: ', style=ax_opt_txt_sty),
                    dcc.Dropdown(
                        id='invar_2',
                        options=dd_opts_in,
                        value=invars[1] if len(invars) > 1 else invars[0],
                        style=dd_sty, ),
                    dcc.Checklist(
                        id='invar-2-log',
                        options=[{'label': 'log', 'value': 'log'}],
                        style=ax_opt_log_sty, ),
                ]),
                html.Div(id='invar-3-div', style=axis_options_div_style, children=[
                    html.B('z: ', style=ax_opt_txt_sty),
                    dcc.Dropdown(
                        id='invar_3',
                        options=dd_opts_in,
                        value=invars[2] if len(invars) > 2 else invars[0],
                        style=dd_sty,
                    ),
                    dcc.Checklist(
                        id='invar-3-log',
                        options=[{'label': 'log', 'value': 'log'}],
                        style=ax_opt_log_sty,
                    ),
                ]),
                html.Div(id='outvar-div', style=axis_options_div_style, children=[
                    html.B('output: ', style=ax_opt_txt_sty),
                    dcc.Dropdown(
                        id='outvar',
                        options=dd_opts_out,
                        value=outvars[0],
                        style=dd_sty, ),
                    dcc.Checklist(
                        id='outvar-log',
                        options=[{'label': 'log', 'value': 'log'}],
                        style=ax_opt_log_sty, ),
                ]),
                html.Div(id='color-div', style=axis_options_div_style, children=[
                    html.B("color: ", style={'width': txt_check_width}),
                    dcc.Checklist(
                        id='color-use',
                        options=[{'label': '', 'value': 'true'}],
                        style={'width': check_txt_width},
                        value=['true'], ),
                    dcc.Dropdown(
                        id='color-dropdown',
                        options=[{'label': 'OUTPUT', 'value': 'OUTPUT'}] + dd_opts_in + dd_opts_out,
                        value='OUTPUT',
                        style=dd_sty, ),
                ]),
                html.Div(id='error-div', style=axis_options_div_style, children=[
                    html.B("error: ", style={'width': txt_check_width}),
                    dcc.Checklist(
                        id='error-use',
                        options=[{'label': '', 'value': 'true'}],
                        style={'width': check_txt_width}, ),
                    dcc.Dropdown(
                        id='error-dropdown',
                        options=dd_opts_out,
                        value=outvars[-1],
                        style=dd_sty, ),
                ]),
                html.Div(id='fit-opt', children=html.B("Fit options:", style=headline_sty), style=ax_opt_tit_sty),
                html.Div(id='fit-use-div', style=axis_options_div_style, children=[
                    html.B("display fit:", style=fit_opt_txt_sty),
                    dcc.Checklist(
                        id='fit-use',
                        options=[{'label': '', 'value': 'show'}],
                        labelStyle={'display': 'inline-block'}, ),
                ]),
                html.Div(id='fit-multiinput-div', style=axis_options_div_style, children=[
                    html.B("multi-fit:", style=fit_opt_txt_sty),
                    dcc.Dropdown(
                        id='fit-multiinput-dropdown',
                        options=dd_opts_in,
                        value=invars[-1],
                        style=dd_sty, ),
                ]),
                html.Div(id='fit-number-div', style=axis_options_div_style, children=[
                    html.B("#fits:", style=fit_opt_txt_sty),
                    dcc.Input(id='fit-number', type='number', value=1, min=1),
                ]),
                html.Div(id='fit-conf-div', style=axis_options_div_style, children=[
                    html.B("\u03c3-confidence:", style=fit_opt_txt_sty),
                    dcc.Input(id='fit-conf', type='number', value=2, min=0),
                ]),
                html.Div(id='fit-noise-div', style=axis_options_div_style, children=[
                    dcc.Checklist(
                        id='fit-var',
                        options=[{'label': 'add noise covariance', 'value': 'add'}],
                        style={'margin-left': txt_width},
                    )
                ]),
                html.Div(id='fit-color-div', style=axis_options_div_style, children=[
                    html.B("fit-color:", style=fit_opt_txt_sty),
                    dcc.RadioItems(
                        id='fit-color',
                        options=[{'label': 'output', 'value': 'output'},
                                 {'label': 'multi-fit', 'value': 'multi-fit'},
                                 {'label': 'marker-color', 'value': 'marker-color'}],
                        value='output',
                        labelStyle={'display': 'inline-block'}, ),
                ]),
                html.Div(id='fit-opacity-div', style=axis_options_div_style, children=[
                    html.B("fit-opacity:", style=fit_opt_txt_sty),
                    html.Div(style={'width': col_width-txt_width}, children=[
                        dcc.Slider(
                            id='fit-opacity',
                            min=0,
                            max=1,
                            step=0.1,
                            value=0.5,
                            marks={i: {'label': f'{100 * i:.0f}%'} for i in [0, 0.2, 0.4, 0.6, 0.8, 1]},
                        ),
                    ]),
                ]),
                html.Div(id='fit-sampling-div', style=axis_options_div_style, children=[
                    html.B("#points:", style=fit_opt_txt_sty),
                    dcc.Input(id='fit-sampling', type='number', value=50, min=1, debounce=True, style={'appearance': 'textfield'}),
                ]),
            ]),
            html.Td(id='graph', style={'width': '80%'}, children=[html.Div(dcc.Graph(id='graph1'))]),
        ])]),
        html.Div(html.Table(id='filters', children=[html.Tr([
            html.Td(html.Div([
                dcc.Dropdown(
                    id='filter-dropdown',
                    options=dd_opts_in,
                    value=invars[0],
                    style={'width': 200, 'margin-right': 10},
                ),
                html.Button("Add Filter", id='add-filter', n_clicks=0, style=button_sty),
            ], style={'display': 'flex'}), style={**col_sty, 'border-bottom-width':0}),
            html.Td(html.Button("Clear all Filter", id='clear-all-filter', n_clicks=0, style=button_sty), style={**col_sty, 'border-bottom-width':0}),
            html.Td(dcc.Slider(id='scale-slider',
                               min=-0.5, max=0.5,
                               value=0, step=0.01,
                               marks={i: f'{100*i:.0f}%' for i in [-0.5, -0.25, 0, 0.25, 0.5]},
                               ),
                    style={**col_sty, 'width': 500, 'border-bottom-width':0}
            ),
            html.Td(html.Button("Scale Filter span", id='scale', n_clicks=0, style=button_sty), style={**col_sty, 'border-bottom-width':0}),
        ])])),
        html.Div(html.Table(id='param-table', children=[
            html.Thead(id='param-table-head', children=[
                html.Tr(children=[
                    html.Th("Parameter", style={**col_sty, **input_sty}),
                    html.Th("log", style=col_sty_th),
                    html.Th("Slider", style={**col_sty_th, 'width': 300}),
                    html.Th("Range (min/max)", style=col_sty_th),
                    html.Th("center/span", style=col_sty_th),
                    html.Th("filter active", style=col_sty_th),
                    html.Th("#digits", style=col_sty_th),
                    html.Th("reset", style=col_sty_th),
                    html.Th("", style=col_sty_th),
                ]),
            ]),
            html.Tbody(id='param-table-body', children=[
                html.Tr(children=[
                    html.Td(html.Div(id='param-text-div', children=[]), style=col_sty),
                    html.Td(html.Div(id='param-log-div', children=[]), style=col_sty),
                    html.Td(html.Div(id='param-slider-div', children=[]), style=col_sty),
                    html.Td(html.Div(id='param-range-div', children=[]), style=col_sty),
                    html.Td(html.Div(id='param-center-div', children=[]), style=col_sty),
                    html.Td(html.Div(id='param-active-div', children=[]), style=col_sty),
                    html.Td(html.Div(id='param-digits-div', children=[]), style=col_sty),
                    html.Td(html.Div(id='param-reset-div', children=[]), style=col_sty),
                    html.Td(html.Div(id='param-clear-div', children=[]), style=col_sty),
                ]),
            ]),
        ])),
    ])


    @app.callback(
        [Output('param-text-div', 'children'),
         Output('param-log-div', 'children'),
         Output('param-slider-div', 'children'),
         Output('param-range-div', 'children'),
         Output('param-center-div', 'children'),
         Output('param-active-div', 'children'),
         Output('param-digits-div', 'children'),
         Output('param-reset-div', 'children'),
         Output('param-clear-div', 'children'), ],
        [Input('add-filter', 'n_clicks'),
         Input('clear-all-filter', 'n_clicks'),
         Input({'type': 'param-clear', 'index': ALL}, 'n_clicks')],
        [State('filter-dropdown', 'value'),
         State('param-text-div', 'children'),
         State('param-log-div', 'children'),
         State('param-slider-div', 'children'),
         State('param-range-div', 'children'),
         State('param-center-div', 'children'),
         State('param-active-div', 'children'),
         State('param-digits-div', 'children'),
         State('param-reset-div', 'children'),
         State('param-clear-div', 'children')],
    )
    def add_filterrow(n_clicks, clear_all, clear_clicks, filter_dd, text, log, slider, range_div, center_div, active_div, dig_div, reset_div, clear_div):
        """Add, remove or clear rows of the filter table.

        The table is stored column-wise. Each of the nine list arguments from ``text``
        to ``clear_div`` holds one entry per filter row. All nine are returned as the
        new children of the corresponding table columns.

        Behaviour depends on which input fired:
          * 'clear-all-filter': every column is emptied.
          * 'add-filter': a row for the variable ``filter_dd`` is appended, unless a
            row for it already exists (then ``PreventUpdate`` is raised).
          * a pattern-matching 'param-clear' button: the row whose index matches the
            button's index is removed.

        ``n_clicks``, ``clear_all`` and ``clear_clicks`` only act as triggers; their
        values are not read.
        """
        ctx = dash.callback_context
        # Component id of the first triggered input: the part of 'prop_id' before the first '.'.
        trigger_id = ctx.triggered[0]["prop_id"].split(".")[0]
        if trigger_id == 'clear-all-filter':
            return [], [], [], [], [], [], [], [], []
        elif trigger_id == 'add-filter':
            for i in range(len(text)):
                # The first child of a row's text div is the variable name.
                if text[i]['props']['children'][0] == filter_dd:
                    # A row for this variable already exists: leave the table unchanged.
                    raise PreventUpdate
            ind = invars.index(filter_dd) # position of the variable in invars, used as the row's pattern-matching index
            # Build one cell per table column for the new row. Every component id carries
            # the same index so the MATCH/ALL callbacks can address this row.
            new_text = html.Div(id={'type': 'dyn-text', 'index': ind}, children=[filter_dd], style={**input_div_sty, **input_sty})
            new_log = html.Div(id={'type': 'dyn-log', 'index': ind}, style={**input_div_sty, 'text-align':'center'}, children=[
                dcc.Checklist(id={'type': 'param-log', 'index': ind}, options=[{'label': '', 'value': 'log'}])], )
            new_slider = html.Div(id={'type': 'dyn-slider', 'index': ind}, style=input_div_sty, children=[
                create_slider(filter_dd)], )
            new_range = html.Div(id={'type': 'dyn-range', 'index': ind}, style=input_div_sty, children=[
                dcc.Input(id={'type': 'param-range-min', 'index': ind}, type='number', debounce=True, style={**input_sty, 'appearance': 'textfield'}),
                dcc.Input(id={'type': 'param-range-max', 'index': ind}, type='number', debounce=True, style={**input_sty, 'appearance': 'textfield'}),
            ], )
            new_center = html.Div(id={'type': 'dyn-center', 'index': ind}, style=input_div_sty, children=[
                dcc.Input(id={'type': 'param-center', 'index': ind}, type='number', debounce=True, style={**input_sty, 'appearance': 'textfield'}),
                dcc.Input(id={'type': 'param-span', 'index': ind}, type='number', debounce=True, style={**input_sty, 'appearance': 'textfield'}),
            ], )
            # New filters start out active.
            new_active = html.Div(id={'type': 'dyn-active', 'index': ind},
                                  style={**input_div_sty, 'text-align': 'center'},
                                  children=[
                                      dcc.Checklist(id={'type': 'param-active', 'index': ind},
                                                    options=[{'label': '', 'value': 'act'}],
                                                    value=['act'], )
                                  ])
            new_dig = html.Div(id={'type': 'dyn-dig', 'index': ind}, children=[
                dcc.Input(id={'type': 'param-dig', 'index': ind}, type='number', value=5, min=0, style={'width':100})
            ], style={'height':40})
            new_reset = html.Div(id={'type': 'dyn-reset', 'index': ind}, children=[
                html.Button("reset", id={'type': 'param-reset', 'index': ind}, n_clicks=0,
                            style={'padding-left':15, 'padding-right':15})
            ])
            new_clear = html.Div(id={'type': 'dyn-clear', 'index': ind}, children=[
                html.Button("x", id={'type': 'param-clear', 'index': ind}, n_clicks=0,
                            style={'border':'none', 'padding-left':5, 'padding-right':5})
            ])
            text.append(new_text)
            log.append(new_log)
            slider.append(new_slider)
            range_div.append(new_range)
            center_div.append(new_center)
            active_div.append(new_active)
            dig_div.append(new_dig)
            reset_div.append(new_reset)
            clear_div.append(new_clear)
        elif len(trigger_id) >=1 and trigger_id[0] == "{":
            # A pattern-matching 'param-clear' button fired; trigger_id is its JSON-serialized id.
            # NOTE(review): the index is parsed from the text before the first ',', which
            # assumes "index" is the first key of the serialized id. json.loads(trigger_id)['index']
            # would not depend on key order.
            for i in range(len(text)):
                # search table row to delete
                if int(text[i]['props']['id']['index']) == int(trigger_id.split(',')[0].split(':')[1]):
                    text.pop(i)
                    log.pop(i)
                    slider.pop(i)
                    range_div.pop(i)
                    center_div.pop(i)
                    active_div.pop(i)
                    dig_div.pop(i)
                    reset_div.pop(i)
                    clear_div.pop(i)
                    break
        # The (possibly mutated) column lists; unchanged for any other trigger.
        return text, log, slider, range_div, center_div, active_div, dig_div, reset_div, clear_div


    @app.callback(
        [Output({'type': 'param-range-min', 'index': MATCH}, 'step'),
         Output({'type': 'param-range-max', 'index': MATCH}, 'step'),
         Output({'type': 'param-center', 'index': MATCH}, 'step'),
         Output({'type': 'param-span', 'index': MATCH}, 'step'),
         Output({'type': 'param-slider', 'index': MATCH}, 'step'), ],
        Input({'type': 'param-dig', 'index': MATCH}, 'value')
    )
    def update_step(dig):
        """Derive one common step size for a filter row from its number of digits.

        Args:
            dig (int): Number of decimal digits chosen in the row's '#digits' input.

        Returns:
            tuple: ``10**(-dig)`` five times. These are the steps for the range-min,
            range-max, center and span inputs and for the slider of the same row.
        """
        return (10 ** -dig,) * 5


    @app.callback(
        [Output({'type': 'param-range-min', 'index': MATCH}, 'value'),
         Output({'type': 'param-range-max', 'index': MATCH}, 'value'),
         Output({'type': 'param-slider', 'index': MATCH}, 'value'),
         Output({'type': 'param-slider', 'index': MATCH}, 'min'),
         Output({'type': 'param-slider', 'index': MATCH}, 'max'),
         Output({'type': 'param-center', 'index': MATCH}, 'value'),
         Output({'type': 'param-span', 'index': MATCH}, 'value'),
         Output({'type': 'param-slider', 'index': MATCH}, 'marks'), ],
        [Input('param-text-div', 'children'),
         Input({'type': 'param-log', 'index': MATCH}, 'value'),
         Input({'type': 'param-range-min', 'index': MATCH}, 'value'),
         Input({'type': 'param-range-max', 'index': MATCH}, 'value'),
         Input({'type': 'param-slider', 'index': MATCH}, 'value'),
         Input({'type': 'param-center', 'index': MATCH}, 'value'),
         Input({'type': 'param-span', 'index': MATCH}, 'value'),
         Input('scale', 'n_clicks'),
         Input({'type': 'param-dig', 'index': MATCH}, 'value'),
         Input({'type': 'param-reset', 'index': MATCH}, 'n_clicks'), ],
        [State({'type': 'param-slider', 'index': MATCH}, 'id'),
         State('scale-slider', 'value'),
         State({'type': 'param-slider', 'index': MATCH}, 'marks'), ]
    )
    def update_dyn_slider_range(text_div, log_act, dyn_min, dyn_max, slider_val, center, span, scale, dig, reset, id, scale_slider, marks):
        """Keep one filter row's range inputs, center/span inputs, slider and marks in sync.

        When the row's 'log' checkbox is set, the row displays log10 values. The function
        first converts the displayed values back to linear values (except when the
        checkbox itself just fired). It then reconciles them according to the input that
        triggered the callback, and finally converts them to log10 again if log mode is on.

        Args:
            text_div, scale, reset: Only act as triggers; their values are not read.
            log_act: ``['log']`` when log mode is enabled for this row.
            dyn_min, dyn_max: Values of the row's range-min/range-max inputs.
            slider_val: ``[low, high]`` value of the row's range slider.
            center, span: Values of the row's center/span inputs.
            dig: Number of digits used for rounding the returned values and mark labels.
            id: Pattern-matching id of the row's slider; ``id['index']`` selects the
                variable in ``invars`` (note: shadows the builtin ``id``).
            scale_slider: Relative factor applied to the span when 'Scale Filter span' fires.
            marks: Current slider marks; their keys give the slider's lower/upper limits.

        Returns:
            Eight values: range min, range max, slider value, slider min, slider max,
            center, span and the new marks dict.
        """
        ctx = dash.callback_context
        # For pattern-matching ids the prop_id looks like '{"index":0,"type":"param-log"}.value';
        # this extracts the quoted type, e.g. '"param-log"'. Plain ids raise IndexError and
        # the full prop_id (e.g. 'scale.n_clicks') is used instead.
        try:
            trigger_id = ctx.triggered[0]["prop_id"].split('}')[0].split(',')[1].split(':')[1]
        except IndexError:
            trigger_id = ctx.triggered[0]["prop_id"]

        # Current slider limits, taken from the keys of the existing marks.
        mark_lim = [float(i) for i in list(marks.keys())]

        data_in = indata[invars[id['index']]]
        # Smallest positive data value, used as a log-safe lower bound.
        # NOTE(review): min() raises ValueError if the variable has no positive values.
        data_min_0 = min(data_in[data_in > 0])

        # Log mode was just switched off: the displayed values are still log10 values,
        # so convert them back to linear and recompute center/span.
        if trigger_id == '"param-log"' and log_act != ['log']:
            dyn_min = 10**dyn_min
            dyn_max = 10**dyn_max
            slider_val = [10**val for val in slider_val]
            mark_lim = [10 ** lim for lim in mark_lim]
            # Restore the true (negative) data minimum as the lower limit if the
            # log-space limit excluded it.
            if min(data_in) < 0 and slider_val[0] > 0:
                mark_lim[0] = min(data_in)
            span = (slider_val[1] - slider_val[0])/2
            center = (slider_val[0] + slider_val[1])/2

        # Log mode is on and something else fired: work in linear space for now.
        # center/span are left in log10 space here (they are used that way below).
        if trigger_id != '"param-log"' and log_act == ['log']:
            dyn_min = 10 ** dyn_min
            dyn_max = 10 ** dyn_max
            slider_val = [10 ** val for val in slider_val]
            mark_lim = [10 ** lim for lim in mark_lim]
            if min(data_in) < 0 and slider_val[0] > 0:
                mark_lim[0] = min(data_in)

        # Reset button: select the full data range.
        if trigger_id == '"param-reset"':
            slider_val = [min(data_in), max(data_in)]

        # Global scale button: grow/shrink the span by (1 + scale_slider).
        if ctx.triggered[0]["prop_id"] == "scale.n_clicks":
            span = span * (1 + scale_slider)
            if log_act == ['log']:
                # Let the center/span branch below recompute the range from the log values.
                trigger_id = '"param-span"'
            else:
                dyn_min = center - span
                dyn_max = center + span
                slider_val = [dyn_min, dyn_max]

        # NOTE(review): `and` binds tighter than `or`, so the (center and span) guard only
        # applies to the param-span trigger. A param-center trigger with an empty span will
        # raise TypeError, and a span of 0 makes the guard falsy.
        if trigger_id == '"param-center"' or trigger_id == '"param-span"' and (center and span):
            # Center/span edited: derive the range from them.
            if log_act == ['log']:
                dyn_min = 10**(center-span)
                dyn_max = 10**(center+span)
            else:
                dyn_min = center - span
                dyn_max = center + span
            slider_val = [dyn_min, dyn_max]
        elif (trigger_id == '"param-range-min"' or trigger_id == '"param-range-max"') and (
                dyn_min is not None and dyn_max is not None):
            # Range inputs edited: derive slider value, center and span from them.
            slider_val = [dyn_min, dyn_max]
            span = (slider_val[1] - slider_val[0]) / 2
            center = (slider_val[0] + slider_val[1]) / 2
        elif slider_val:
            # Otherwise (slider moved, reset, ...): derive everything from the slider value.
            dyn_min = slider_val[0]
            dyn_max = slider_val[1]
            span = (slider_val[1] - slider_val[0]) / 2
            center = (slider_val[0] + slider_val[1]) / 2

        if log_act == ['log']:
            # Convert back to log10 for display.
            # NOTE(review): the ValueError fallbacks assume log10 raises on non-positive input
            # (true for math.log10, not for numpy.log10) — confirm which log10 is imported.
            # log10(dyn_max) has no fallback.
            try:
                log_dyn_min = log10(dyn_min)
            except ValueError:
                log_dyn_min = log10(data_min_0)
            log_dyn_max = log10(dyn_max)
            log_slider_val = [log_dyn_min, log_dyn_max]
            try:
                log_mark_lim = [log10(mark) for mark in mark_lim]
            except ValueError:
                log_mark_lim = [log10(data_min_0), log10(mark_lim[1])]
            log_span = (log_slider_val[1] - log_slider_val[0])/2
            log_center = log_slider_val[0] + log_span
            # Two marks, at the lower and upper slider limit, labelled with rounded values.
            log_marks = {log_mark_lim[0]: str(round(log_mark_lim[0], dig)),
                         log_mark_lim[1]: str(round(log_mark_lim[1], dig))}
            return round(log_dyn_min, dig), round(log_dyn_max, dig), log_slider_val, log_mark_lim[0], log_mark_lim[1],\
                   round(log_center, dig), round(log_span, dig), log_marks
        else:
            marks = {mark_lim[0]: str(round(mark_lim[0], dig)), mark_lim[1]: str(round(mark_lim[1], dig))}
            return round(dyn_min, dig), round(dyn_max, dig), slider_val, mark_lim[0], mark_lim[1], \
                   round(center, dig), round(span, dig), marks


    @app.callback(
        [Output('invar-2-div', 'style'),
         Output('invar-3-div', 'style'),
         Output('color-div', 'style'),
         Output('error-div', 'style'),
         Output('fit-use-div', 'style'),
         Output('fit-multiinput-div', 'style'),
         Output('fit-number-div', 'style'),
         Output('fit-number', 'value'),
         Output('fit-conf-div', 'style'),
         Output('fit-noise-div', 'style'),
         Output('fit-color-div', 'style'),
         Output('fit-opacity-div', 'style'), ],
        [Input('graph-type', 'value'), ],
        [State('fit-number', 'value'), ]
    )
    def div_visibility(graph_type, fits):
        """Show or hide the option rows that apply to the selected graph type.

        Args:
            graph_type (str): Selected value of the 'graph-type' radio items.
            fits: Current value of the '#fits' input.

        Returns:
            tuple: Eleven style dicts (copies of ``axis_options_div_style`` with
            'visibility' set to 'visible' or 'hidden'), in callback output order. The new
            '#fits' value is inserted at position 7. For '2D' plots with at most two input
            variables, the number of fits is forced to 1.
        """
        visible = {**axis_options_div_style, 'visibility': 'visible'}
        hidden = {**axis_options_div_style, 'visibility': 'hidden'}

        # One letter per style output in callback order, skipping the '#fits' value:
        # 'S' = show, 'H' = hide.
        n_fits = fits
        if graph_type == '1D':
            pattern = 'HHSSSSSSSHS'
        elif graph_type == '2D':
            if len(invars) <= 2:
                pattern = 'SHSSSHHSSSS'
                n_fits = 1
            else:
                pattern = 'SHSSSSSSSSS'
        elif graph_type == '2D contour':
            pattern = 'SHSHHHHHHHH'
        elif graph_type == '3D':
            pattern = 'SSHHSHSHHHS'
        else:
            pattern = 'SSSSSSSSSSS'

        outputs = [visible if flag == 'S' else hidden for flag in pattern]
        outputs.insert(7, n_fits)
        return tuple(outputs)


    @app.callback(
        Output('graph1', 'figure'),
        [Input('invar', 'value'),
         Input('invar_2', 'value'),
         Input('invar_3', 'value'),
         Input('outvar', 'value'),
         Input('invar-1-log', 'value'),
         Input('invar-2-log', 'value'),
         Input('invar-3-log', 'value'),
         Input('outvar-log', 'value'),
         Input({'type': 'param-slider', 'index': ALL}, 'value'),
         Input('graph-type', 'value'),
         Input('color-use', 'value'),
         Input('color-dropdown', 'value'),
         Input('error-use', 'value'),
         Input('error-dropdown', 'value'),
         Input({'type': 'param-active', 'index': ALL}, 'value'),
         Input('fit-use', 'value'),
         Input('fit-multiinput-dropdown', 'value'),
         Input('fit-number', 'value'),
         Input('fit-conf', 'value'),
         Input('fit-var', 'value'),
         Input('fit-color', 'value'),
         Input('fit-opacity', 'value'),
         Input('fit-sampling', 'value'), ],
        [State({'type': 'param-slider', 'index': ALL}, 'id'),
         State({'type': 'param-center', 'index': ALL}, 'value'),
         State({'type': 'param-log', 'index': ALL}, 'value')],
    )
    def update_figure(invar, invar_2, invar_3, outvar, invar1_log, invar2_log, invar3_log, outvar_log, param_slider,
                      graph_type, color_use, color_dd, error_use, error_dd, filter_active, fit_use, fit_dd, fit_num, fit_conf, add_noise_var, fit_color,
                      fit_opacity, fit_sampling, id_type, param_center, param_log):
        """Rebuild the main graph ('graph1') from the current state of all controls.

        Data points are first filtered by every active parameter slider. They are then drawn
        according to ``graph_type``:
            '1D'         -- scatter of outvar over invar (+ fit lines with a confidence band),
            '2D'         -- 3D scatter of outvar over (invar, invar_2) (+ fit surfaces),
            '2D contour' -- scatter of the inputs over a contour plot of the fit,
            '3D'         -- 3D scatter of the three inputs coloured by outvar (+ isosurfaces).
        Afterwards, axis log scaling and the optional marker colouring are applied.

        The parameters are the Inputs/States of the decorator, in the same order. Values of
        checklist-like controls are compared against one-element lists such as ['log'],
        ['show'], ['true'], ['act'] and ['add'].

        Returns:
            plotly.graph_objects.Figure: The figure for 'graph1' (empty if no invar is selected).
        """
        # Log-toggled sliders are treated as base-10 exponents; convert ranges and centers back.
        for i in range(len(param_slider)):
            if param_log[i] == ['log']:
                param_slider[i] = [10**val for val in param_slider[i]]
                param_center[i] = 10**param_center[i]
        if invar is None:
            return go.Figure()
        # Boolean mask of the data points that pass all active filters.
        sel_y = np.full((len(outdata),), True)
        dds_value = []
        for iteration, values in enumerate(param_slider):
            dds_value.append(invars[id_type[iteration]['index']])
            # filter for minimum
            sel_y_min = np.array(indata[dds_value[iteration]] >= values[0])
            # filter for maximum
            sel_y_max = np.array(indata[dds_value[iteration]] <= values[1])
            # [[]] is the value of a single, inactive filter checkbox.
            if filter_active != [[]]:
                if filter_active[iteration] == ['act']:
                    sel_y = sel_y_min & sel_y_max & sel_y
        if graph_type == '1D':
            fig = go.Figure(
                data=[go.Scatter(
                    x=indata[invar][sel_y],
                    y=outdata[outvar][sel_y],
                    mode='markers',
                    name='data',
                    error_y=dict(type='data', array=outdata[error_dd][sel_y], visible= error_use == ['true']),
                )],
                layout=go.Layout(xaxis=dict(title=invar, rangeslider=dict(visible=True)), yaxis=dict(title=outvar))
            )
            if fit_use == ['show']:
                mesh_in, mesh_out, mesh_out_std, fit_dd_values = mesh_fit(param_slider, id_type, fit_dd, fit_num,
                                                                          param_center, [invar], [invar1_log],
                                                                          outvar, fit_sampling, add_noise_var)
                for i in range(len(fit_dd_values)):
                    # fit line, coloured by the value of the multi-fit variable
                    fig.add_trace(go.Scatter(
                        x=mesh_in[i][invars.index(invar)],
                        y=mesh_out[i],
                        mode='lines',
                        name=f'fit: {fit_dd}={fit_dd_values[i]:.1e}',
                        line_color=colormap(indata[fit_dd].min(), indata[fit_dd].max(), fit_dd_values[i]),
                        marker_line=dict(coloraxis="coloraxis2"),
                    ))
                    # confidence band: upper edge forward, lower edge backward, filled as one polygon
                    fig.add_trace(go.Scatter(
                        x=np.hstack((mesh_in[i][invars.index(invar)], mesh_in[i][invars.index(invar)][::-1])),
                        y=np.hstack((mesh_out[i] + fit_conf * mesh_out_std[i], mesh_out[i][::-1] - fit_conf * mesh_out_std[i][::-1])),
                        showlegend=False,
                        fill='toself',
                        line_color=colormap(indata[fit_dd].min(), indata[fit_dd].max(), fit_dd_values[i]),
                        marker_line=dict(coloraxis="coloraxis2"),
                        opacity=fit_opacity,
                    ))
        elif graph_type == '2D':
            fig = go.Figure(
                data=[go.Scatter3d(
                    x=indata[invar][sel_y],
                    y=indata[invar_2][sel_y],
                    z=outdata[outvar][sel_y],
                    mode='markers',
                    name='Data',
                    error_z=dict(type='data', array=outdata[error_dd][sel_y], visible=error_use == ['true'], width= 10)
                )],
                layout=go.Layout(scene=dict(xaxis_title=invar, yaxis_title=invar_2, zaxis_title=outvar))
            )
            if fit_use == ['show'] and invar != invar_2:
                mesh_in, mesh_out, mesh_out_std, fit_dd_values = mesh_fit(param_slider, id_type, fit_dd, fit_num,
                                                                          param_center, [invar, invar_2],
                                                                          [invar1_log, invar2_log], outvar,
                                                                          fit_sampling, add_noise_var)
                for i in range(len(fit_dd_values)):
                    # mean surface; surface colour depends on the fit-color mode
                    fig.add_trace(go.Surface(
                        x=mesh_in[i][invars.index(invar)].reshape((fit_sampling, fit_sampling)),
                        y=mesh_in[i][invars.index(invar_2)].reshape((fit_sampling, fit_sampling)),
                        z=mesh_out[i].reshape((fit_sampling, fit_sampling)),
                        name=f'fit: {fit_dd}={fit_dd_values[i]:.2f}',
                        surfacecolor=fit_dd_values[i] * np.ones([fit_sampling, fit_sampling])
                            if fit_color == 'multi-fit' else
                            (mesh_in[i][invars.index(color_dd)].reshape((fit_sampling, fit_sampling))
                             if (fit_color == 'marker-color' and color_dd in invars) else
                             mesh_out[i].reshape((fit_sampling, fit_sampling))),
                        opacity=fit_opacity,
                        coloraxis="coloraxis2" if (fit_color == 'multi-fit' or
                            (fit_color == 'output' and (color_dd != outvar and color_dd != 'OUTPUT'))) else "coloraxis",
                        showlegend=True if len(invars) > 2 else False,
                    ))
                    if fit_conf > 0:
                        # upper confidence surface (mean + fit_conf * std)
                        fig.add_trace(go.Surface(
                            x=mesh_in[i][invars.index(invar)].reshape((fit_sampling, fit_sampling)),
                            y=mesh_in[i][invars.index(invar_2)].reshape((fit_sampling, fit_sampling)),
                            z=mesh_out[i].reshape((fit_sampling, fit_sampling)) + fit_conf * mesh_out_std[i].reshape((fit_sampling, fit_sampling)),
                            showlegend=False,
                            name=f'fit+v: {fit_dd}={fit_dd_values[i]:.2f}',
                            surfacecolor=fit_dd_values[i] * np.ones([fit_sampling, fit_sampling])
                                if fit_color == 'multi-fit' else
                                (mesh_in[i][invars.index(color_dd)].reshape((fit_sampling, fit_sampling))
                                 if (fit_color == 'marker-color' and color_dd in invars) else
                                 mesh_out[i].reshape((fit_sampling, fit_sampling))),
                            opacity=fit_opacity,
                            coloraxis="coloraxis2" if (fit_color == 'multi-fit' or
                                (fit_color == 'output' and (color_dd != outvar and color_dd != 'OUTPUT')))
                                else "coloraxis",
                        ))
                        # lower confidence surface (mean - fit_conf * std)
                        fig.add_trace(go.Surface(
                            x=mesh_in[i][invars.index(invar)].reshape((fit_sampling, fit_sampling)),
                            y=mesh_in[i][invars.index(invar_2)].reshape((fit_sampling, fit_sampling)),
                            z=mesh_out[i].reshape((fit_sampling, fit_sampling)) - fit_conf * mesh_out_std[i].reshape((fit_sampling, fit_sampling)),
                            showlegend=False,
                            name=f'fit-v: {fit_dd}={fit_dd_values[i]:.2f}',
                            surfacecolor=fit_dd_values[i] * np.ones([fit_sampling, fit_sampling])
                                if fit_color == 'multi-fit' else
                                (mesh_in[i][invars.index(color_dd)].reshape((fit_sampling, fit_sampling))
                                 if (fit_color == 'marker-color' and color_dd in invars) else
                                 mesh_out[i].reshape((fit_sampling, fit_sampling))),
                            opacity=fit_opacity,
                            coloraxis="coloraxis2" if (fit_color == 'multi-fit' or
                                (fit_color == 'output' and (color_dd != outvar and color_dd != 'OUTPUT')))
                                else "coloraxis",
                        ))
                fig.update_layout(coloraxis2=dict(
                    colorbar=dict(title=outvar if fit_color == 'output' else fit_dd),
                    cmin=min(fit_dd_values) if fit_color == 'multi-fit' else None,
                    cmax=max(fit_dd_values) if fit_color == 'multi-fit' else None,
                ))
        elif graph_type == '2D contour':
            mesh_in, mesh_out, mesh_out_std, fit_dd_values = mesh_fit(param_slider, id_type, fit_dd, fit_num,
                                                                      param_center, [invar, invar_2],
                                                                      [invar1_log, invar2_log], outvar,
                                                                      fit_sampling, add_noise_var)
            data_x = mesh_in[0][invars.index(invar)]
            data_y = mesh_in[0][invars.index(invar_2)]
            fig = go.Figure()
            # A contour needs a non-degenerate range along both axes.
            if min(data_x) != max(data_x):
                if min(data_y) != max(data_y):
                    fig.add_trace(go.Scatter(
                        x=indata[invar][sel_y],
                        y=indata[invar_2][sel_y],
                        mode='markers',
                        name='Data',
                    ))
                    fig.add_trace(go.Contour(
                        x=mesh_in[0][invars.index(invar)],
                        y=mesh_in[0][invars.index(invar_2)],
                        z=mesh_out[0],
                        contours_coloring='heatmap',
                        contours_showlabels=True,
                        coloraxis='coloraxis2',
                        name='fit',
                    ))
                    # Plotly expects log10 values for the range of a log axis.
                    fig.update_xaxes(
                        range=[log10(min(fig.data[1]['x'])), log10(max(fig.data[1]['x']))] if invar1_log == ['log']
                        else [min(fig.data[1]['x']), max(fig.data[1]['x'])])
                    fig.update_yaxes(
                        range=[log10(min(fig.data[1]['y'])), log10(max(fig.data[1]['y']))] if invar2_log == ['log']
                        else [min(fig.data[1]['y']), max(fig.data[1]['y'])])
                    fig.update_layout(xaxis_title=invar,
                                      yaxis_title=invar_2,
                                      coloraxis2=dict(colorbar=dict(title=outvar),
                                                      colorscale='solar',
                                                      cmin=min(fig.data[1]['z']),
                                                      cmax=max(fig.data[1]['z'])))
                else:
                    fig.update_layout(title="y-data is constant, no contour-plot possible")
            else:
                fig.update_layout(title="x-data is constant, no contour-plot possible")
        elif graph_type == '3D':
            fig = go.Figure(
                data=go.Scatter3d(
                    x=indata[invar][sel_y],
                    y=indata[invar_2][sel_y],
                    z=indata[invar_3][sel_y],
                    mode='markers',
                    marker=dict(
                            color=outdata[outvar][sel_y],
                            coloraxis="coloraxis2",
                        ),
                    name='Data',
                ),
                layout=go.Layout(scene=dict(xaxis_title=invar, yaxis_title=invar_2, zaxis_title=invar_3)),
            )
            fig.update_layout(coloraxis2=dict(
                colorbar=dict(title=outvar),
            ))
            # isosurfaces need three distinct axis variables
            if fit_use == ['show'] and len({invar, invar_2, invar_3}) == 3:
                mesh_in, mesh_out, mesh_out_std, fit_dd_values = mesh_fit(param_slider, id_type, fit_dd, fit_num,
                                                                          param_center, [invar, invar_2, invar_3],
                                                                          [invar1_log, invar2_log, invar3_log], outvar,
                                                                          fit_sampling, add_noise_var)
                for i in range(len(fit_dd_values)):
                    fig.add_trace(
                        go.Isosurface(
                            x=mesh_in[i][invars.index(invar)],
                            y=mesh_in[i][invars.index(invar_2)],
                            z=mesh_in[i][invars.index(invar_3)],
                            value=mesh_out[i],
                            surface_count=fit_num,
                            coloraxis="coloraxis2",
                            isomin=mesh_out[i].min() * 1.1,
                            isomax=mesh_out[i].max() * 0.9,
                            caps=dict(x_show=False, y_show=False, z_show=False),
                            opacity=fit_opacity,
                        ),
                    )
        else:
            fig = go.Figure()
        fig.update_layout(legend=dict(xanchor="left", x=0.01))
        # log scale: one entry per axis of the respective graph type
        log_dict = {'1D': (invar1_log, outvar_log),
                    '2D': (invar1_log, invar2_log, outvar_log),
                    '2D contour': (invar1_log, invar2_log),
                    '3D': (invar1_log, invar2_log, invar3_log),}
        # Unknown/cleared graph types (the empty-figure branch above) get no axis settings
        # instead of raising KeyError.
        log_list = ['linear' if log is None or len(log) == 0 else log[0] for log in log_dict.get(graph_type, ())]
        log_key = ['xaxis', 'yaxis', 'zaxis']
        comb_dict = dict(zip(log_key, [{'type': log} for log in log_list]))
        # 2-axis plots use the 2D layout axes, 3-axis plots the 3D scene axes.
        if len(log_list) < 3 :
            fig.update_layout(**comb_dict)
        else:
            fig.update_scenes(**comb_dict)
        # color
        if color_use == ['true']:
            if fit_use == ['show'] and (graph_type=='2D' and (fit_color=='multi-fit' and color_dd==fit_dd)):
                # share the colour axis of the multi-fit surfaces
                fig.update_traces(
                    marker=dict(
                        coloraxis="coloraxis2",
                        color=indata[color_dd][sel_y] if color_dd in indata.dtype.names else outdata[color_dd][sel_y],
                    ),
                    selector=dict(mode='markers'),
                )
            elif graph_type == '3D':
                fig.update_traces(
                    marker=dict(
                        coloraxis="coloraxis2",
                        color=outdata[outvar][sel_y],
                    ),
                    selector=dict(mode='markers'),
                )
            elif graph_type=='1D':
                fig.update_traces(
                    marker=dict(
                        coloraxis="coloraxis2",
                        color=outdata[outvar][sel_y] if color_dd == 'OUTPUT' else
                        (indata[color_dd][sel_y] if color_dd in indata.dtype.names else outdata[color_dd][sel_y]),
                    ),
                    selector=dict(mode='markers'),
                )
                if color_dd==fit_dd:
                    fig.update_layout(coloraxis2=dict(colorscale='cividis', colorbar=dict(title=fit_dd)))
                elif color_dd == 'OUTPUT':
                    fig.update_layout(coloraxis2=dict(colorscale='plasma', colorbar=dict(title=outvar)))
                else:
                    fig.update_layout(coloraxis2=dict(colorscale='plasma', colorbar=dict(title=color_dd)))
            elif graph_type =='2D contour':
                fig.update_traces(
                    marker=dict(
                        coloraxis="coloraxis",
                        color=outdata[outvar][sel_y] if color_dd == 'OUTPUT' else
                        (indata[color_dd][sel_y] if color_dd in indata.dtype.names else outdata[color_dd][sel_y]),
                    ),
                    selector=dict(mode='markers'),
                )
                # colouring by the plotted output reuses the contour's colour axis
                if color_dd == outvar or color_dd == 'OUTPUT':
                    fig.update_traces(marker_coloraxis="coloraxis2", selector=dict(mode='markers'))
                else:
                    fig.update_layout(coloraxis=dict(colorbar=dict(title=color_dd, x=1.1),
                                                     colorscale='ice'))
            else:
                fig.update_traces(
                    marker=dict(
                        coloraxis="coloraxis",
                        color=outdata[outvar][sel_y] if color_dd == 'OUTPUT' else
                        (indata[color_dd][sel_y] if color_dd in indata.dtype.names else outdata[color_dd][sel_y]),
                    ),
                    selector=dict(mode='markers'),
                )
                fig.update_layout(coloraxis=dict(
                    colorbar=dict(title=outvar if color_dd == 'OUTPUT' else color_dd, x=1.1),
                    colorscale='viridis',
                ))
        fig.update_layout(height=graph_height)
        return fig


    def mesh_fit(param_slider, id_type, fit_dd, fit_num, param_center, invar_list, invar_log_list, outvar, num_samples, add_noise_var):
        """Evaluate the surrogate ``sur`` on a grid, once per value of the multi-fit variable.

        Every input variable is held at the center of its range (or of its filter, if it has one).
        The exceptions are the multi-fit variable ``fit_dd``, which takes one value per fit, and
        the axis variables in ``invar_list``, which are sampled over their (filtered) range.

        Parameters:
            param_slider (list): [min, max] of each filter slider, in the order of ``id_type``.
            id_type (list): Dash ids of the filter sliders; ``id['index']`` is a position in ``invars``.
            fit_dd (str): Input variable that differs between the individual fits.
            fit_num (int): Number of fits; a single fit is placed at the midpoint of the fit_dd range.
            param_center (list): Center value of each filter, in the order of ``id_type``.
            invar_list (list): Input variables spanned along the plot axes.
            invar_log_list (list): Log toggle per entry of ``invar_list``; ``['log']`` selects log spacing.
            outvar (str): Output variable whose prediction is returned.
            num_samples (int): Number of sample points per axis variable.
            add_noise_var (list): ``['add']`` turns on the boolean second argument of ``sur.predict``
                (presumably "include noise variance" -- confirm against the surrogate API).

        Returns:
            tuple: ``(mesh_in, mesh_out, mesh_out_std, fit_dd_values)``. ``mesh_in[i]`` holds the
            flattened grid of every variable in ``invars`` for fit ``i``. ``mesh_out[i]`` and
            ``mesh_out_std[i]`` hold the predicted mean and standard deviation of ``outvar``.
            ``fit_dd_values`` are the fit-variable values that were used.
        """
        try:  # collecting min/max of slider for variable of multifit
            fit_dd_min, fit_dd_max = param_slider[[i['index'] for i in id_type].index(invars.index(fit_dd))]
        except ValueError:  # fit_dd has no filter slider -> use its full data range
            fit_dd_min = min(indata[fit_dd])
            fit_dd_max = max(indata[fit_dd])

        if fit_num == 1:  # generate list of value of variable of multifit
            fit_dd_values = np.array([(fit_dd_max + fit_dd_min) / 2])
        else:
            fit_dd_values = np.linspace(fit_dd_min, fit_dd_max, fit_num)

        # Everything except the fit-variable value is the same for all fits, so build it once.
        # Default parameter for all invars: center of their data range ...
        base_params = [(max(indata[var_invar]) + min(indata[var_invar])) / 2 for var_invar in invars]
        # ... replaced by the filter center for every invar that has a filter.
        flt_ind_list = [id_type[i]['index'] for i in range(len(param_center))]  # list of filter indices
        for flt_ind, center_value in zip(flt_ind_list, param_center):
            base_params[flt_ind] = center_value

        # Sample points of the axis invars, restricted to the filter range if filtered.
        axis_samples = []  # (position in invars, sample points)
        for i, ax_in in enumerate(invar_list):
            ax_ind = invars.index(ax_in)
            if ax_ind in flt_ind_list:
                ax_min, ax_max = param_slider[flt_ind_list.index(ax_ind)]
            else:
                ax_min = min(indata[ax_in])
                ax_max = max(indata[ax_in])
            if invar_log_list[i] == ['log']:
                samples = np.logspace(log10(ax_min), log10(ax_max), num_samples)
            else:
                samples = np.linspace(ax_min, ax_max, num_samples)
            axis_samples.append((ax_ind, samples))

        fit_index = invars.index(fit_dd)
        out_index = outvars.index(outvar)
        mesh_in, mesh_out, mesh_out_std = [], [], []
        for fit_dd_value in fit_dd_values:  # iteration for each fit
            fit_params = list(base_params)
            fit_params[fit_index] = fit_dd_value
            # axis values are applied last, so they win if fit_dd is also an axis variable
            for ax_ind, samples in axis_samples:
                fit_params[ax_ind] = samples
            grid = np.meshgrid(*fit_params)  # generate grid
            x_pred = np.vstack([g.flatten() for g in grid]).T  # extract vector for predict
            fit_data, fit_var = sur.predict(x_pred, add_noise_var == ['add'])  # generate fit data and variance
            fit_var = np.asarray(fit_var)
            # Use the variance of the selected output when the surrogate returns one column per
            # output; a single column (or 1-D result) is a variance shared by all outputs.
            if fit_var.ndim == 1:
                out_var = fit_var
            else:
                out_var = fit_var[:, out_index if fit_var.shape[1] > out_index else 0]
            mesh_in.append(np.array([grid[invars.index(name)].flatten() for name in invars]))
            mesh_out.append(fit_data[:, out_index])
            mesh_out_std.append(np.sqrt(out_var))
        # Stacking once gives shapes (n_fits, n_invars, n_points) and (n_fits, n_points).
        return np.array(mesh_in), np.array(mesh_out), np.array(mesh_out_std), fit_dd_values


    def create_slider(dd_value):
        # Build a filter RangeSlider for input variable `dd_value` that spans the
        # data range of that variable, with labelled end marks.
        slider_index = invars.index(dd_value)
        decimals = 3  # slider resolution is 10**-3; end labels are rounded to match
        lower = indata[dd_value].min()
        upper = indata[dd_value].max()
        end_marks = {
            lower: str(round(lower, decimals)),
            upper: str(round(upper, decimals)),
        }
        return dcc.RangeSlider(
            id={'type': 'param-slider', 'index': slider_index},
            step=10 ** -decimals,
            min=lower,
            max=upper,
            value=[lower, upper],
            marks=end_marks,
        )


    def colormap(cmin, cmax, c):
        # Hex colour of `c` on the cividis colormap after scaling [cmin, cmax] to
        # [0, 1]; a degenerate range maps to the middle of the colormap.
        position = 0.5 if cmin == cmax else (c - cmin) / (cmax - cmin)
        return color2hex(colormaps.cividis(position))

    return app
Ejemplo n.º 8
0
def main():
    """Main command line interface.

    Parses ``sys.argv`` (e.g. ``['profit', 'run', '/home/user/example']``). The config file is
    resolved from ``base-dir`` via ``safe_path_to_file`` with default file name ``profit.yaml``.
    Then one of the modes is executed:

    - ``run``: write the input points and run the simulations. When any input has kind
      ``activelearning``, the active-learning loop is run instead. The output file is saved.
    - ``fit``: train the surrogate on the input/output files unless it is already trained,
      then optionally save and plot it.
    - ``ui``: start the Dash visualisation app.
    - ``clean``: after confirmation, remove run directories, input/output files and the log path.
    """
    # Get parameters from argv
    parser = ArgumentParser(
        usage='profit <mode> (base-dir)',
        description=
        "Probabilistic Response Model Fitting with Interactive Tools",
        formatter_class=RawTextHelpFormatter)
    parser.add_argument(
        'mode',  # ToDo: subparsers?
        metavar='mode',
        choices=['run', 'fit', 'ui', 'clean'],
        help='run ... start simulation runs \n'
        'fit ... fit data with Gaussian Process \n'
        'ui ... visualise results \n'
        'clean ... remove run directories and input/output files')
    parser.add_argument(
        'base_dir',
        metavar='base-dir',
        help='path to config file (default: current working directory)',
        default=getcwd(),
        nargs='?')
    args = parser.parse_args()

    print(args)
    # Instantiate Config class from the given file
    config_file = safe_path_to_file(args.base_dir, default='profit.yaml')
    config = Config.from_file(config_file)

    # presumably so that user modules in the study's base directory can be imported
    sys.path.append(config['base_dir'])

    if args.mode == 'run':
        from tqdm import tqdm
        from profit.pre import get_eval_points, write_input
        from profit.util import save

        runner = Runner.from_config(config['run'], config)

        eval_points = get_eval_points(config)
        write_input(config['files']['input'], eval_points)

        if 'activelearning' in (safe_str(v['kind'])
                                for v in config['input'].values()):
            from profit.fit import ActiveLearning
            from profit.sur.sur import Surrogate
            runner.fill(eval_points)
            if 'active_learning' not in config:
                config['active_learning'] = {}
            ActiveLearning.handle_config(config['active_learning'], config)
            al = ActiveLearning.from_config(runner, config['active_learning'],
                                            config)
            al.run_first()
            al.learn()
            if config['active_learning'].get('save'):
                al.save(config['active_learning']['save'])
        else:
            params_array = [row[0] for row in eval_points]
            runner.spawn_array(tqdm(params_array), blocking=True)

        if config['run']['clean']:
            runner.clean()

        # text output is written as a single column of structured records
        if config['files']['output'].endswith('.txt'):
            data = runner.structured_output_data
            save(config['files']['output'], data.reshape(data.size, 1))
        else:
            save(config['files']['output'], runner.output_data)

    elif args.mode == 'fit':
        from numpy import arange, hstack, meshgrid
        from profit.util import load
        from profit.sur.sur import Surrogate

        sur = Surrogate.from_config(config['fit'], config)

        if not sur.trained:
            x = load(config['files']['input'])
            y = load(config['files']['output'])
            # structured arrays -> plain 2D arrays with one column per field
            x = hstack([x[key] for key in x.dtype.names])
            y = hstack([y[key] for key in y.dtype.names])

            sur.train(x, y)

        if config['fit'].get('save'):
            sur.save_model(config['fit']['save'])
        if config['fit'].get('plot'):
            try:
                xpred = [
                    arange(minv, maxv, step)
                    for minv, maxv, step in config['fit']['plot'].get('xpred')
                ]
                xpred = hstack(
                    [xi.flatten().reshape(-1, 1) for xi in meshgrid(*xpred)])
            except (AttributeError, TypeError):
                # `plot` is not a mapping (AttributeError) or has no 'xpred' entry, so
                # iterating None fails (TypeError): fall back to xpred=None.
                xpred = None
            sur.plot(xpred, independent=config['independent'], show=True)

    elif args.mode == 'ui':
        from profit.ui import init_app
        app = init_app(config)
        app.run_server(debug=True)

    elif args.mode == 'clean':
        from shutil import rmtree
        from os import path, remove
        run_dir = config['run_dir']

        question = "Are you sure you want to remove the run directories in {} " \
                   "and input/output files? (y/N) ".format(config['run_dir'])
        # NOTE(review): `yes` is not defined in this function; it is presumably a module-level
        # flag (outside this view) that skips the confirmation prompt -- confirm it exists.
        if yes:
            print(question + 'y')
        else:
            answer = input(question)
            if not answer.lower().startswith('y'):
                print('exit...')
                sys.exit()

        for krun in range(config['ntrain']):
            single_run_dir = path.join(run_dir, f'run_{krun:03d}')
            if path.exists(single_run_dir):
                rmtree(single_run_dir)
        if path.exists(config['files']['input']):
            remove(config['files']['input'])
        if path.exists(config['files']['output']):
            remove(config['files']['output'])

        runner = Runner.from_config(config['run'], config)
        runner.clean()
        try:
            rmtree(config['run']['log_path'])
        except FileNotFoundError:
            pass