def main():
    """Quick ECNet run: load the KV model database and tune hyperparameters.

    NOTE(review): a second ``def main(db_name)`` later in this file shadows
    this definition at import time -- confirm which entry point is intended.
    """
    # Verbose console output while tuning
    logger.stream_level = 'debug'

    sv = Server(num_processes=4)
    sv.load_data('../kv_model_v1.0.csv')

    # 20 employer bees, 20 search cycles; data is shuffled into a fresh
    # 70/20/10 split and solutions are scored on the test set
    sv.tune_hyperparameters(
        20,
        20,
        shuffle='train',
        split=[0.7, 0.2, 0.1],
        eval_set='test'
    )
def main(db_name: str):
    """Run the full ECNet modeling workflow for the supplied CSV database.

    Steps: configure logging, split the database, choose the optimal number
    of input variables, tune hyperparameters, train a project of candidate
    ANNs, report per-set errors, and save the resulting project.

    Args:
        db_name: path to the ECNet-formatted ``.csv`` database
    """
    # Set up logging: console at info, full debug detail to a log directory
    # named after the database
    logger.stream_level = 'info'
    logger.log_dir = db_name.replace('.csv', '') + '_logs'
    logger.file_level = 'debug'

    # Split database proportionally based on property value:
    # 70% learn, 20% validate, 10% test
    prop_range_from_split(db_name, [0.7, 0.2, 0.1])

    # Find the optimal number of input variables, evaluated on the
    # train (learn + valid) set
    n_desc = len(find_optimal_num_inputs(db_name, 'train', _NUM_PROC)[1])
    logger.log('info', 'Optimal number of input variables: {}'.format(n_desc))

    # Create server object with base config and load the data
    sv = Server(
        model_config=db_name.replace('.csv', '.yml'),
        num_processes=_NUM_PROC
    )
    sv.load_data(db_name)

    # Limit input variables to `n_desc` using the train set; results are
    # written to a database name tagged with the descriptor count
    sv.limit_inputs(
        n_desc,
        eval_set='train',
        output_filename=db_name.replace('.csv', '.{}.csv'.format(n_desc))
    )

    # Tune hyperparameters (architecture and ADAM): 20 employer bees,
    # 10 search cycles, scored by validation-set median absolute error
    sv.tune_hyperparameters(20, 10, eval_set='valid', eval_fn='med_abs_error')

    # Create an ECNet project (saved and recalled later):
    # 5 pools, 75 trials per pool, best ANN selected from each pool
    sv.create_project(db_name.replace('.csv', ''), 5, 75)

    # Train the project; best candidates chosen by validation-set
    # median absolute error
    sv.train(validate=True, selection_set='valid', selection_fn='med_abs_error')

    # Report median absolute error and r-squared for each subset
    err_l = sv.errors('med_abs_error', 'r2', dset='learn')
    err_v = sv.errors('med_abs_error', 'r2', dset='valid')
    err_t = sv.errors('med_abs_error', 'r2', dset='test')
    logger.log('info', 'Learning set performance: {}'.format(err_l))
    logger.log('info', 'Validation set performance: {}'.format(err_v))
    logger.log('info', 'Testing set performance: {}'.format(err_t))

    # Save the project as a .prj file, discarding un-chosen candidates
    sv.save_project(del_candidates=True)
def tune(num_processes, shuffle=None, split=None, validate=True,
         eval_set=None, eval_fn='rmse'):
    """Tune ECNet hyperparameters on the CN model database.

    Args:
        num_processes: number of parallel processes for the Server
        shuffle: which set(s) to shuffle between tuning iterations
            (passed through to ``tune_hyperparameters``)
        split: learn/valid/test proportions; defaults to [0.7, 0.2, 0.1]
        validate: whether to use periodic validation during training
        eval_set: data set used to score candidate solutions
        eval_fn: error function used to score candidate solutions
    """
    # Fix: the original used a mutable default argument (`split=[...]`),
    # which is shared across calls and can be silently mutated. Use a
    # None sentinel and create a fresh list per call instead.
    if split is None:
        split = [0.7, 0.2, 0.1]

    logger.stream_level = 'debug'
    sv = Server(num_processes=num_processes)
    sv.load_data('cn_model_v1.0.csv', random=True, split=split)

    # 2 employer bees, 2 search cycles (small smoke-test configuration)
    sv.tune_hyperparameters(2, 2, shuffle=shuffle, split=split,
                            validate=validate, eval_set=eval_set,
                            eval_fn=eval_fn)