def test_generate_params():
    '''
    Draw 1000 uniformly sampled parameter points (NumPy global seed 1) from
    the default Domain and write them to params/1000params.csv, omitting the
    DataFrame index.
    '''
    np.random.seed(1)

    domain = Domain()
    uniform = UniformSamplingStrategy()
    frame = domain.gen_data_frame(uniform, 1000)
    frame.to_csv('params/1000params.csv', index=False)
def generate_params():
    '''
    Write 500 batches of 1000 uniformly sampled parameter points each to
    output/run1/batch<i>_in.csv (NumPy global seed 2), printing a progress
    line for every tenth batch.
    '''
    total_batches = 500
    batch_size = 1000
    rng_seed = 2

    print('Seed is %d' % rng_seed)
    print('Each batch will contain %d samples' % batch_size)

    np.random.seed(rng_seed)

    domain = Domain()
    strategy = UniformSamplingStrategy()

    for idx in range(total_batches):
        batch = domain.gen_data_frame(strategy, batch_size)
        batch.to_csv('output/run1/batch%d_in.csv' % idx, index=False)
        if not idx % 10:
            print('Generated batch %d of %d' % (idx + 1, total_batches))
def test_samplerun():
    '''
    Run a small uniform sampling job (5 points) through Samplerun and write
    the results to 100uniform.csv. Note that the file name says 100 even
    though only 5 samples are requested.
    '''
    sampler = Samplerun()
    sampler.perform_sample(
        out_file='100uniform.csv',
        n_samples=5,
        domain=Domain(),
        sampling_strategy=UniformSamplingStrategy(),
    )
# Example #4
    def __init__(self, parent=None):
        '''
        Initialise the dialog state and build its widget layout.

        All model, lattice and query state starts as None and is filled in
        later by the dialog's handlers.
        '''
        super(Window, self).__init__(parent)
        self.domain = Domain()
        self.surrogate_model = None

        # Lattice X axis: parameter object, column name, number of points.
        self.x_param = None
        self.x_param_name = None
        self.x_granularity = None

        # Lattice Y axis: parameter object, column name, number of points.
        self.y_param = None
        self.y_param_name = None
        self.y_granularity = None

        self.tbr_params = None
        # Surrogate predictions on the lattice. Presumably read by sibling
        # plotting/error methods, so initialise it here rather than relying
        # on a later handler to create the attribute.
        self.tbr_surrogate = None
        self.tbr_true = None
        self.tbr_err = None

        # Background worker object and its thread for true-TBR queries.
        self.tbr_worker = None
        self.tbr_worker_thread = None

        self.init_layout()
def test_runfromfile():
    '''
    Evaluate the 100 parameter points stored in params/100params0000000.csv
    through Samplerun and write the results to 100fix0000000out.csv.
    '''
    sampler = Samplerun()
    fixed_inputs = pd.read_csv("params/100params0000000.csv")
    sampler.perform_sample(
        out_file='100fix0000000out.csv',
        n_samples=100,
        param_values=fixed_inputs,
        domain=Domain(),
        sampling_strategy=UniformSamplingStrategy(),
    )
# Example #6
def main():
    '''
    Perform the quality-adaptive sampling (QASS) algorithm.

    Each iteration:
      1. trains the surrogate (on the whole training half first, then on
         each newly acquired batch) and scores it on a fixed test half;
      2. models the surrogate's absolute test error with a nearest-neighbour
         interpolator;
      3. runs MCMC (burnin_MH / run_MH) over that error surface and thins
         the chain into candidate points;
      4. keeps the high-error, less crowded candidates (crowding measured by
         ``cdm``), evaluates up to ``step_samples`` of them in the selected
         theory and appends them to the training set.

    Per-iteration metrics are written to qassmetrics.csv. The loop stops
    when the maximum test error drops below ``err_target`` or more than
    ``max_iter_count`` iterations have run.
    '''

    # Parse inputs and store in relevant variables.
    args = input_parse()

    init_samples = args.init_samples
    step_samples = args.step_samples
    step_candidates = args.step_candidates
    d_params = disctrans(args.disc_fix)

    # Collect surrogate model type and theory under study.
    thismodel = get_model_factory()[args.model](cli_args=sys.argv[7:])
    thistheory = globals()["theory_" + args.theory]

    domain = Domain()

    if args.saved_init:
        # Load previously evaluated data as the initial sample set; its
        # discrete parameters override the ones given on the command line.
        df = load_batches(args.saved_init, (0, 1 + int(init_samples / 1000)))
        X_init, d, y_multiple = c_d_y_split(df.iloc[0:init_samples])
        d_params = d.values[0]
        print(d.values[0][0])
        y_init = y_multiple['tbr']

    # Pin the discrete domain parameters (indices 1-3 and 5-8) to d_params.
    for slot, param_idx in enumerate((1, 2, 3, 5, 6, 7, 8)):
        domain.fix_param(domain.params[param_idx], d_params[slot])

    if not args.saved_init:
        # Generate initial parameters and evaluate them in the given theory.
        sampling_strategy = UniformSamplingStrategy()
        c = domain.gen_data_frame(sampling_strategy, init_samples)
        print(c.columns)
        print("Evaluating initial " + str(init_samples) + " samples in " +
              args.theory + " theory.")
        output = thistheory(params=c, domain=domain, n_samples=init_samples)
        X_init, d, y_multiple = c_d_y_split(output)
        y_init = y_multiple['tbr']

    # Both branches produce X_init / y_init; seed the working set from them
    # (this must happen on the --saved_init path as well).
    current_samples, current_tbr = X_init, y_init

    # MAIN QASS LOOP

    complete_condition = False
    iter_count = 0

    err_target = 0.0001
    max_iter_count = 70

    all_metrics = pd.DataFrame()

    # Use a canonical (sorted) column order so model inputs line up across
    # iterations and with newly evaluated batches.
    current_samples = current_samples.sort_index(axis=1)

    print(f'Features in order are: {list(current_samples.columns)}')

    # The test half stays fixed for the whole run; only X_train/y_train grow.
    X_train, X_test, y_train, y_test = train_test_split(
        current_samples, current_tbr, test_size=0.5, random_state=1)

    thismodel.enable_renormalization(100)

    while not complete_condition:
        iter_count += 1
        samp_size = X_train.shape[0] * 2
        print("Iteration " + str(iter_count) + " -- Total samples: " +
              str(samp_size))

        # Train surrogate for theory, and plot results. After the first
        # iteration only the newly acquired batch is passed to train().
        if iter_count == 1:
            new_samples, new_tbr = X_train, y_train
        train(thismodel, new_samples, new_tbr)
        test(thismodel, X_test, y_test)

        plot("qassplot", thismodel, X_test, y_test)
        this_metrics = get_metrics(thismodel, X_test, y_test)
        this_metrics['numdata'] = samp_size
        print(this_metrics)

        # Absolute surrogate error on the fixed test set.
        y_test_pred = thismodel.predict(X_test.to_numpy())
        test_err = abs(y_test - y_test_pred)

        X_test = X_test.sort_index(axis=1)

        # Split the test errors so the error interpolator can be checked on
        # points it was not built from. (NN / RBF error surrogates were
        # tried here and abandoned in favour of nearest-neighbour
        # interpolation.)
        X_test1, X_test2, test_err1, test_err2 = train_test_split(
            X_test, test_err, test_size=0.5, random_state=1)

        # Test surrogate (nearest neighbor interpolator) on split error data
        errordist_test = interpolate.NearestNDInterpolator(
            X_test1.values, test_err1.values)
        pred_err1 = errordist_test(X_test1.values)
        pred_err2 = errordist_test(X_test2.values)

        # Surrogate for the error function built from all test points; this
        # is the density the MCMC below samples from.
        errordist = interpolate.NearestNDInterpolator(X_test.values,
                                                      test_err.values)

        max_err = max(test_err.values)
        print('Max error: ' + str(max_err))
        this_metrics['maxerr'] = max_err

        plot_results("qassplot2", pred_err1, test_err1)
        plt.figure()
        plot_results("qassplot3", pred_err2, test_err2)

        plt.figure()
        plt.hist(test_err.values, bins=100)
        plt.savefig("qassplot4.pdf", format="pdf")

        # Perform MCMC on error function

        saveinterval = 1
        nburn = 1000
        nrun = 10000

        initial_sample = X_train.iloc[0]
        burnin_sample, burnin_dist, burnin_acceptance = burnin_MH(
            errordist, initial_sample.values, nburn)
        saved_samples, saved_dists, run_acceptance = run_MH(
            errordist, burnin_sample, nrun, saveinterval)

        plt.figure()
        plt.hist(saved_dists, bins=100)
        plt.savefig("qassplot5.pdf", format="pdf")

        print('MCMC run finished.')
        print('Burn-In Acceptance: ' + str(burnin_acceptance))
        print('Run Acceptance: ' + str(run_acceptance))
        this_metrics['burn_acc'] = burnin_acceptance
        this_metrics['run_acc'] = run_acceptance

        # Thin the chain to roughly step_candidates candidates. The stride is
        # floored at 1: a zero stride makes the slice raise ValueError when
        # the chain is shorter than step_candidates.
        print(saved_samples.shape)
        samplestep = max(1, int(saved_samples.shape[0] / step_candidates))
        print(samplestep)
        candidates = saved_samples[::samplestep]

        cand_cdms = [cdm(candidate, candidates) for candidate in candidates]

        # Rank candidate samples by error value, and filter out crowded samples
        new_samples = pd.DataFrame(candidates, columns=current_samples.columns)
        new_samples['error'] = saved_dists[::samplestep]
        new_samples['cdm'] = cand_cdms

        new_samples = new_samples.sort_values(by='error', ascending=False)

        crowded = new_samples[
            new_samples['cdm'] <= new_samples['cdm'].median()].index
        new_samples = new_samples.drop(crowded)

        # Remove the ranking helper columns and renumber rows without adding
        # an 'index' column, so only domain parameters reach the theory.
        new_samples = new_samples.drop(columns=['error', 'cdm'])
        new_samples = new_samples.head(step_samples).reset_index(drop=True)
        n_new = len(new_samples)

        # Add new samples and corresponding TBR evaluations to current sample
        # set. Each new row gets a copy of the first row of the discrete
        # parameters d, and n_samples matches the rows actually sent.
        d_rows = d.iloc[[0] * n_new].reset_index(drop=True)
        new_output = thistheory(params=new_samples.join(d_rows),
                                domain=domain,
                                n_samples=n_new)
        new_samples, new_d, new_y_multiple = c_d_y_split(new_output)
        new_tbr = new_y_multiple['tbr']

        new_samples = new_samples.sort_index(axis=1)

        # Grow the training set; the test set deliberately stays fixed.
        X_train = pd.concat([X_train, new_samples], ignore_index=True)
        y_train = pd.concat([y_train, new_tbr], ignore_index=True)

        # Check completion conditions and close loop
        if max_err < err_target or iter_count > max_iter_count:
            complete_condition = True

        all_metrics = pd.concat([all_metrics, this_metrics], ignore_index=True)
        print(all_metrics)
        all_metrics.to_csv('qassmetrics.csv')

    print('QASS finished.')
def test_gp_dfixed():
    '''
    Sets up uniform generation with all discrete parameters fixed to set values.

    Draws 100 samples (NumPy global seed 2) with domain parameters 1, 2, 3,
    5, 6, 7 and 8 pinned to fixed materials, and writes them to
    params/100params0000000.csv without the DataFrame index.
    '''
    n_samples = 100
    seed = 2

    np.random.seed(seed)

    domain = Domain()
    sampling_strategy = UniformSamplingStrategy()

    fixed_values = {
        1: 'tungsten',
        2: 'SiC',
        3: 'H2O',
        5: 'SiC',
        6: 'Li4SiO4',
        7: 'Be',
        # Presumed typo fix: was 'H20' (digit zero). That spelling matches no
        # other material name in this file, which uses 'H2O' (letter O).
        8: 'H2O',
    }
    for param_idx, value in fixed_values.items():
        domain.fix_param(domain.params[param_idx], value)

    df = domain.gen_data_frame(sampling_strategy, n_samples)
    df.to_csv('params/100params0000000.csv', index=False)
def generate_params_fix_disc():
    '''
    Sample continuous parameters once and write them out under each of four
    fixed discrete-material configurations.

    The same 100 x 1000 continuous samples are reused for every
    configuration; only the discrete columns are overwritten. Batches go to
    output/run2/batch<i>_in.csv, with i counting consecutively across all
    configurations (400 files in total).
    '''
    batch_count = 100
    batch_size = 1000
    seed = 2

    print('Seed is %d' % seed)
    print('Each batch will contain %d samples' % batch_size)

    np.random.seed(seed)

    domain = Domain()
    sampling_strategy = UniformSamplingStrategy()

    # The four configurations differ only in the (first-wall, blanket)
    # coolant pair; all other materials are shared.
    coolant_pairs = [('H2O', 'H2O'), ('H2O', 'He'), ('He', 'H2O'), ('He', 'He')]
    discrete_params = [
        {
            'firstwall_armour_material': 'tungsten',
            'firstwall_structural_material': 'eurofer',
            'firstwall_coolant_material': fw_coolant,
            'blanket_coolant_material': blanket_coolant,
            'blanket_multiplier_material': 'Be12Ti',
            'blanket_breeder_material': 'Li4SiO4',
            'blanket_structural_material': 'eurofer',
        }
        for fw_coolant, blanket_coolant in coolant_pairs
    ]

    total_rows = batch_count * batch_size
    df = domain.gen_data_frame(sampling_strategy, total_rows)
    total_batches = len(discrete_params) * batch_count

    save_idx = 0
    for param_set in discrete_params:
        for column, material in param_set.items():
            df[column] = material

        for start in range(0, total_rows, batch_size):
            chunk = df.iloc[start:(start + batch_size)]
            chunk.to_csv('output/run2/batch%d_in.csv' % save_idx, index=False)

            if save_idx % 10 == 0:
                print('Generated batch %d of %d' %
                      (save_idx + 1, total_batches))

            save_idx += 1
def main():
    '''
    Perform FAKE quality-adaptive sampling algorithm.

    Uniform-sampling baseline for QASS. Each iteration:
      1. re-splits all evaluated samples 50/50 into train/test;
      2. compares the model's current scaler with one refitted on the new
         training set (reported as 'similarity');
      3. trains and tests the surrogate;
      4. scores it on a fresh batch of ``eval_samples`` uniform points;
      5. grows the sample set with ``step_samples`` uniformly drawn points.

    Metrics are written to qassfakemetrics.csv every iteration. The loop
    ends once more than ``max_iter_count`` iterations have run.
    '''

    # Parse inputs and store in relevant variables.
    args = input_parse()

    init_samples = args.init_samples
    step_samples = args.step_samples
    eval_samples = args.eval_samples
    d_params = disctrans(args.disc_fix)

    # Collect surrogate model type and theory under study.
    thismodel = get_model_factory()[args.model](cli_args=sys.argv[9:])
    thistheory = globals()["theory_" + args.theory]

    domain = Domain()

    if args.saved_init:
        # Load previously evaluated data as the initial sample set; its
        # discrete parameters override the ones given on the command line.
        df = load_batches(args.saved_init, (0, 1 + int(init_samples / 1000)))
        X_init, d, y_multiple = c_d_y_split(df.iloc[0:init_samples])
        d_params = d.values[0]
        print(d.values[0][0])
        y_init = y_multiple['tbr']

    domain.fix_param(domain.params[1], d_params[0])
    domain.fix_param(domain.params[2], d_params[1])
    domain.fix_param(domain.params[3], d_params[2])
    domain.fix_param(domain.params[5], d_params[3])
    domain.fix_param(domain.params[6], d_params[4])
    domain.fix_param(domain.params[7], d_params[5])
    domain.fix_param(domain.params[8], d_params[6])

    # Needed on both paths: the loop below draws eval and new samples with it.
    sampling_strategy = UniformSamplingStrategy()

    if not args.saved_init:
        # generate initial parameters
        c = domain.gen_data_frame(sampling_strategy, init_samples)
        print(c.columns)
        # evaluate initial parameters in given theory
        print("Evaluating initial " + str(init_samples) + " samples in " +
              args.theory + " theory.")
        output = thistheory(params=c, domain=domain, n_samples=init_samples)
        X_init, d, y_multiple = c_d_y_split(output)
        y_init = y_multiple['tbr']

    # Both branches produce X_init / y_init; seed the working set from them.
    current_samples, current_tbr = X_init, y_init

    # MAIN QASS LOOP

    complete_condition = False
    iter_count = 0
    similarity = 0

    max_iter_count = 10000

    all_metrics = pd.DataFrame()

    while not complete_condition:
        iter_count += 1
        samp_size = current_samples.shape[0]
        print("Iteration " + str(iter_count) + " -- Total samples: " +
              str(samp_size))

        # Train surrogate for theory, and plot results

        X_train, X_test, y_train, y_test = train_test_split(
            current_samples, current_tbr, test_size=0.5,
            random_state=1)  #play with this

        # Goldilocks retraining scheme: measure how far the data scaling
        # has drifted; periodically adopt the refitted scaler.

        if iter_count > 1:
            alt_scaler = thismodel.create_scaler()
            Xy_in = thismodel.join_sets(X_train, y_train)
            alt_scaler.fit(Xy_in)
            similarity = thismodel.scaler_similarity(thismodel.scaler,
                                                     alt_scaler)
            if iter_count % 10000 == 0:  #restart with new random weights
                thismodel.scaler = alt_scaler

        train(thismodel, X_train, y_train)
        test(thismodel, X_test, y_test)

        plot("qassplot", thismodel, X_test, y_test)
        this_metrics = get_metrics(thismodel, X_test, y_test)
        this_metrics['numdata'] = samp_size
        this_metrics['similarity'] = similarity
        print(this_metrics)

        # True evaluation test on uniform random data

        evaltest_samples = domain.gen_data_frame(sampling_strategy,
                                                 eval_samples)

        eval_output = thistheory(params=evaltest_samples,
                                 domain=domain,
                                 n_samples=eval_samples)
        evaltest_samples, evaltest_d, evaltest_y_multiple = c_d_y_split(
            eval_output)
        evaltest_tbr = evaltest_y_multiple['tbr']

        test(thismodel, evaltest_samples, evaltest_tbr)
        plot("qassplot2", thismodel, evaltest_samples, evaltest_tbr)
        eval_metrics = get_metrics(thismodel, evaltest_samples, evaltest_tbr)
        # Report the evaluation metrics (previously printed the sample count).
        print(eval_metrics)

        this_metrics['E_MAE'] = eval_metrics['MAE']
        this_metrics['E_S'] = eval_metrics['S']
        this_metrics['E_R2'] = eval_metrics['R2']
        this_metrics['E_R2(adj)'] = eval_metrics['R2(adj)']

        # Generate uniform random new samples

        new_samples = domain.gen_data_frame(sampling_strategy, step_samples)

        new_output = thistheory(params=new_samples,
                                domain=domain,
                                n_samples=step_samples)
        new_samples, new_d, new_y_multiple = c_d_y_split(new_output)
        new_tbr = new_y_multiple['tbr']

        current_samples = pd.concat([current_samples, new_samples],
                                    ignore_index=True)
        current_tbr = pd.concat([current_tbr, new_tbr], ignore_index=True)

        # Check completion conditions and close loop

        if iter_count > max_iter_count:
            complete_condition = True

        all_metrics = pd.concat([all_metrics, this_metrics], ignore_index=True)
        print(all_metrics)
        all_metrics.to_csv('qassfakemetrics.csv')

    print('FAKE QASS finished.')
# Example #10
class Window(QDialog):
    '''
    Interactive dialog for inspecting a surrogate TBR model on a 2-D lattice.

    The user picks two continuous parameters as X/Y axes (all other
    parameters keep their values from the parameter table) and generates a
    lattice. Three panels are then compared:
      * the surrogate prediction;
      * the true TBR, obtained by running Samplerun in a background QThread;
      * the absolute surrogate error.
    '''

    def __init__(self, parent=None):
        super(Window, self).__init__(parent)
        self.domain = Domain()
        self.surrogate_model = None

        # Lattice X axis: parameter object, column name, number of points.
        self.x_param = None
        self.x_param_name = None
        self.x_granularity = None

        # Lattice Y axis: parameter object, column name, number of points.
        self.y_param = None
        self.y_param_name = None
        self.y_granularity = None

        # Lattice inputs, surrogate predictions, true TBR and |error|.
        # tbr_surrogate is initialised here because tbr_worker_samples_available
        # and plot_model read it even before any model has been loaded.
        self.tbr_params = None
        self.tbr_surrogate = None
        self.tbr_true = None
        self.tbr_err = None

        # Background worker object and its thread for true-TBR queries.
        self.tbr_worker = None
        self.tbr_worker_thread = None

        self.init_layout()

    def init_layout(self):
        '''Create the three figure panels and the control widgets.'''
        layout = QGridLayout()

        self.model_fig, self.model_canv, self.model_tool = self.init_fig()
        layout.addWidget(self.model_tool, 0, 0)
        layout.addWidget(self.model_canv, 1, 0)

        self.err_fig, self.err_canv, self.err_tool = self.init_fig()
        layout.addWidget(self.err_tool, 2, 0)
        layout.addWidget(self.err_canv, 3, 0, 4, 1)

        self.true_fig, self.true_canv, self.true_tool = self.init_fig()
        layout.addWidget(self.true_tool, 0, 1, 1, 3)
        layout.addWidget(self.true_canv, 1, 1, 1, 3)

        self.param_table = ParamWidget(self.domain)
        layout.addWidget(self.param_table, 2, 1, 2, 3)

        self.x_granularity_box = QLineEdit('5')
        layout.addWidget(self.x_granularity_box, 4, 1)

        self.y_granularity_box = QLineEdit('5')
        layout.addWidget(self.y_granularity_box, 4, 2)

        self.generate_button = QPushButton('Generate lattice')
        self.generate_button.clicked.connect(self.generate_lattice)
        layout.addWidget(self.generate_button, 4, 3)

        self.randomize_button = QPushButton('Randomize params')
        self.randomize_button.clicked.connect(self.randomize_params)
        layout.addWidget(self.randomize_button, 5, 1)

        self.load_model_button = QPushButton('Model: None')
        self.load_model_button.clicked.connect(self.load_model)
        layout.addWidget(self.load_model_button, 5, 2)

        self.query_tbr_button = QPushButton('Query true TBR')
        self.query_tbr_button.clicked.connect(self.query_tbr)
        layout.addWidget(self.query_tbr_button, 6, 1)

        self.query_tbr_progress = QProgressBar()
        layout.addWidget(self.query_tbr_progress, 6, 2, 1, 2)

        self.setLayout(layout)

    def init_fig(self):
        '''Return a new (figure, canvas, navigation toolbar) triple.'''
        figure = plt.figure()
        canvas = FigureCanvas(figure)
        toolbar = NavigationToolbar(canvas, self)
        return figure, canvas, toolbar

    def load_model(self):
        '''Ask for a model file, load it and redraw the surrogate panel.'''
        filename, _ = QFileDialog.getOpenFileName(self, 'Load model', None,
                                                  None)
        if not filename:
            # The dialog was cancelled; don't hand an empty path to the loader.
            return

        loaded_model_name, loaded_model = load_model_from_file(filename)

        if loaded_model is None:
            return

        self.surrogate_model = loaded_model
        self.load_model_button.setText('Model: %s' % loaded_model_name)

        self.evaluate_model()
        self.plot_model()

    def randomize_params(self):
        '''Fill the parameter table with one uniform sample and regenerate.'''
        sampling_strategy = UniformSamplingStrategy()
        df = self.domain.gen_data_frame(sampling_strategy, 1)
        for column in df.columns:
            self.param_table.set_param(column, df.at[0, column])

        self.generate_lattice()

    def generate_lattice(self):
        '''
        Build the X/Y evaluation lattice from the parameter table.

        Every non-axis parameter keeps its table value. The X and Y columns
        are replaced by evenly spaced points over each parameter's ``val``
        range, with X varying fastest. Invalid requests are reported and
        leave the previous lattice untouched. On success, cached true/error
        results are cleared and all panels are redrawn.
        '''
        df = pd.DataFrame(data={
            key: [value]
            for key, value in self.param_table.get_params().items()
        })

        # Validate user input before touching any state: an exception here
        # would escape a Qt slot.
        try:
            x_granularity = int(self.x_granularity_box.text())
            y_granularity = int(self.y_granularity_box.text())
        except ValueError:
            print('ERROR: lattice granularity must be an integer!')
            return
        if x_granularity < 1 or y_granularity < 1:
            print('ERROR: lattice granularity must be at least 1!')
            return

        x_param_name = self.param_table.find_param_with_selection('x')
        print(f'X is {x_param_name}')

        y_param_name = self.param_table.find_param_with_selection('y')
        print(f'Y is {y_param_name}')

        if x_param_name is None or y_param_name is None:
            print('ERROR: missing X or Y axis!')
            return

        x_param = first([
            param for param in self.domain.params
            if param.name == x_param_name
            and isinstance(param, ContinuousParameter)
        ])
        y_param = first([
            param for param in self.domain.params
            if param.name == y_param_name
            and isinstance(param, ContinuousParameter)
        ])

        if x_param is None or y_param is None:
            print('ERROR: X or Y axis is not continuous!')
            return

        # Commit the new lattice definition only once it is valid, so a
        # rejected request cannot leave axis names/granularities that
        # disagree with the existing self.tbr_params.
        self.x_granularity = x_granularity
        self.y_granularity = y_granularity
        self.x_param_name = x_param_name
        self.y_param_name = y_param_name
        self.x_param = x_param
        self.y_param = y_param

        n_points = self.x_granularity * self.y_granularity
        df = pd.concat([df] * n_points)

        x_linspace = np.linspace(start=self.x_param.val[0],
                                 stop=self.x_param.val[1],
                                 num=self.x_granularity,
                                 endpoint=True)
        y_linspace = np.linspace(start=self.y_param.val[0],
                                 stop=self.y_param.val[1],
                                 num=self.y_granularity,
                                 endpoint=True)

        # Row-major grid: X cycles fastest, each Y value repeats
        # x_granularity times; matches reshape(y_granularity, x_granularity)
        # in plot_domain.
        df[self.x_param_name] = np.c_[[x_linspace] *
                                      self.y_granularity].ravel('C')
        df[self.y_param_name] = np.c_[[y_linspace] *
                                      self.x_granularity].ravel('F')
        self.tbr_params = df.reset_index(drop=True)

        self.tbr_true = None
        self.tbr_err = None

        self.tbr_surrogate = None
        self.evaluate_model()

        self.plot_model()
        self.plot_true()
        self.plot_err()

    def evaluate_model(self):
        '''
        Predict TBR on the lattice with the loaded surrogate.

        Stores a copy of the lattice with a leading 'tbr_pred' column in
        self.tbr_surrogate, or None if no lattice or model is available.
        '''
        if self.tbr_params is None or self.surrogate_model is None:
            self.tbr_surrogate = None
            return

        # preprocess data (column order sorted to match the training layout)
        X = encode_data_frame(self.tbr_params, self.domain)
        X = X.sort_index(axis=1)

        # make predictions
        self.tbr_surrogate = self.tbr_params.copy()
        self.tbr_surrogate.insert(0, 'tbr_pred', -1.)
        self.tbr_surrogate['tbr_pred'] = self.surrogate_model.predict(X)

    def query_tbr(self):
        '''Start a background Samplerun evaluating true TBR on the lattice.'''
        if self.tbr_params is None:
            print('ERROR: TBR queried with no sampled parameters')
            return

        self.query_tbr_button.setEnabled(False)
        self.query_tbr_progress.setValue(0)
        self.query_tbr_progress.setMinimum(0)
        self.query_tbr_progress.setMaximum(self.tbr_params.shape[0])

        class Worker(QObject):
            finished = pyqtSignal()
            progress = pyqtSignal(int, int)
            samples_available = pyqtSignal(pd.DataFrame)

            def __init__(self,
                         tbr_params,
                         x_param_name,
                         y_param_name,
                         parent=None):
                super(Worker, self).__init__(parent)
                self.tbr_params = tbr_params
                self.x_param_name = x_param_name
                self.y_param_name = y_param_name

            def progress_handler(self, i, n_samples):
                self.progress.emit(i, n_samples)

            @pyqtSlot()
            def query_tbr(self):
                run = Samplerun(no_docker=True)
                sampled = run.perform_sample(
                    out_file=None,
                    param_values=self.tbr_params,
                    progress_handler=self.progress_handler)

                # Re-align results to the lattice row order via the two axis
                # columns, dropping duplicated columns from the right side.
                param_names = [self.x_param_name, self.y_param_name]
                sampled = pd.merge(self.tbr_params,
                                   sampled,
                                   how='left',
                                   left_on=param_names,
                                   right_on=param_names,
                                   suffixes=('', '_dup_'))
                sampled.drop([
                    column
                    for column in sampled.columns if column.endswith('_dup_')
                ],
                             axis=1,
                             inplace=True)
                self.samples_available.emit(sampled)

                self.finished.emit()

        self.tbr_worker = Worker(self.tbr_params, self.x_param_name,
                                 self.y_param_name)
        self.tbr_worker_thread = QThread()

        self.tbr_worker.moveToThread(self.tbr_worker_thread)
        self.tbr_worker.finished.connect(self.tbr_worker_thread.quit)
        self.tbr_worker.samples_available.connect(
            self.tbr_worker_samples_available)
        self.tbr_worker.progress.connect(self.tbr_worker_progress)
        self.tbr_worker_thread.started.connect(self.tbr_worker.query_tbr)
        self.tbr_worker_thread.finished.connect(self.tbr_worker_finished)

        self.tbr_worker_thread.start()

    @pyqtSlot(int, int)
    def tbr_worker_progress(self, i, n_samples):
        self.query_tbr_progress.setValue(i + 1)

    @pyqtSlot(pd.DataFrame)
    def tbr_worker_samples_available(self, sampled):
        '''Store the true TBR and, if predictions exist, the surrogate error.'''
        self.tbr_true = sampled
        self.plot_true()

        if self.tbr_surrogate is None:
            # No surrogate predictions to compare against (no model loaded),
            # so there is no error surface; merging with None would raise.
            self.tbr_err = None
            self.plot_err()
            return

        err = sampled.copy()
        err.insert(0, 'tbr_surrogate_err', 0.01)

        param_names = [self.x_param_name, self.y_param_name]
        err = pd.merge(err,
                       self.tbr_surrogate,
                       how='left',
                       left_on=param_names,
                       right_on=param_names,
                       suffixes=('', '_dup_'))
        err.drop(
            [column for column in err.columns if column.endswith('_dup_')],
            axis=1,
            inplace=True)
        err['tbr_surrogate_err'] = (err['tbr'] - err['tbr_pred']).abs()

        self.tbr_err = err
        self.plot_err()

    @pyqtSlot()
    def tbr_worker_finished(self):
        '''Release the worker and its thread once the query has finished.'''
        # QThread.finished is delivered to this slot from the worker thread,
        # so the thread may not have fully exited yet. Wait before dropping
        # the last Python reference to avoid destroying a running QThread.
        if self.tbr_worker_thread is not None:
            self.tbr_worker_thread.wait()
        self.tbr_worker = None
        self.tbr_worker_thread = None
        self.query_tbr_button.setEnabled(True)

    def plot_domain(self, fig, canv, z_data, z_label, symmetrical=True):
        '''
        Draw z_data as a filled contour over the lattice in (fig, canv).

        With symmetrical=True a diverging colormap is centred on 1 (vmin and
        vmax are made equidistant from 1); otherwise viridis is used. The
        lattice points are always scattered on top. z_data may be None, in
        which case only the points are drawn.
        '''
        fig.clear()
        if self.tbr_params is None:
            return

        fig.set_tight_layout(True)
        ax1 = fig.add_subplot(111)

        cmap = 'RdBu' if symmetrical else 'viridis'

        x_data = self.tbr_params[self.x_param_name].to_numpy().reshape(
            self.y_granularity, self.x_granularity)
        y_data = self.tbr_params[self.y_param_name].to_numpy().reshape(
            self.y_granularity, self.x_granularity)

        pl1 = None
        vmin, vmax = None, None
        if z_data is not None:
            vmin, vmax = np.min(z_data), np.max(z_data)

            if symmetrical:
                if np.abs(vmin - 1) > np.abs(vmax - 1):
                    vmax = 1 + np.abs(vmin - 1)
                else:
                    vmin = 1 - np.abs(vmax - 1)

            z_data = z_data.reshape(self.y_granularity, self.x_granularity)
            pl1 = ax1.contourf(x_data,
                               y_data,
                               z_data,
                               cmap=cmap,
                               vmin=vmin,
                               vmax=vmax)

        ax1.scatter(x_data,
                    y_data,
                    marker='o',
                    c='k',
                    s=12,
                    linewidths=0.8,
                    edgecolors='w')
        ax1.set_xlabel(
            first([
                param.human_readable_name for param in self.domain.params
                if param.name == self.x_param_name
            ]))
        ax1.set_ylabel(
            first([
                param.human_readable_name for param in self.domain.params
                if param.name == self.y_param_name
            ]))

        if pl1 is not None:
            n_steps = 12
            cb1 = fig.colorbar(pl1,
                               orientation='vertical',
                               label=z_label,
                               ax=ax1,
                               boundaries=np.linspace(vmin, vmax, n_steps))
            cb1.ax.locator_params(nbins=n_steps)

        canv.draw()

    def plot_model(self):
        '''Redraw the surrogate-prediction panel.'''
        surrogate_data = self.tbr_surrogate['tbr_pred'].to_numpy() \
            if self.tbr_surrogate is not None else None
        self.plot_domain(self.model_fig, self.model_canv, surrogate_data,
                         'Surrogate TBR')

    def plot_true(self):
        '''Redraw the true-TBR panel.'''
        true_data = self.tbr_true['tbr'].to_numpy() \
            if self.tbr_true is not None else None
        self.plot_domain(self.true_fig, self.true_canv, true_data, 'True TBR')

    def plot_err(self):
        '''Redraw the absolute-error panel (non-diverging colormap).'''
        err_data = self.tbr_err['tbr_surrogate_err'].to_numpy() \
            if self.tbr_err is not None else None
        self.plot_domain(self.err_fig,
                         self.err_canv,
                         err_data,
                         'Approximation error',
                         symmetrical=False)