from ecnet import Server
from ecnet.utils.logging import logger  # import path assumes the ECNet 3.x layout

# NOTE: prop_range_from_split and find_optimal_num_inputs are assumed to come
# from project-local helper modules; their imports are not shown in the original

# Number of concurrent processes; the value here is an assumption
_NUM_PROC = 4


def main(db_name: str):

    # Set up logging
    logger.stream_level = 'info'
    logger.log_dir = db_name.replace('.csv', '') + '_logs'
    logger.file_level = 'debug'

    # Split database proportionally based on property value
    # Proportions are 70% learn, 20% validate, 10% test
    prop_range_from_split(db_name, [0.7, 0.2, 0.1])

    # Find the optimal number of input variables
    # Train (learn + valid) set used for evaluation
    n_desc = len(find_optimal_num_inputs(db_name, 'train', _NUM_PROC)[1])
    logger.log('info', 'Optimal number of input variables: {}'.format(n_desc))

    # Create server object with base config
    sv = Server(model_config=db_name.replace('.csv', '.yml'),
                num_processes=_NUM_PROC)

    # Load data
    sv.load_data(db_name)

    # Limit input variables to `n_desc` using the train set
    # Outputs to a database name reflecting the number of retained inputs
    sv.limit_inputs(
        n_desc,
        eval_set='train',
        output_filename=db_name.replace('.csv', '.{}.csv'.format(n_desc))
    )

    # Tune hyperparameters (architecture and Adam optimizer variables)
    # 20 employer bees, 10 search cycles
    # Evaluation of solutions based on validation set median absolute error
    sv.tune_hyperparameters(20, 10, eval_set='valid', eval_fn='med_abs_error')

    # Create an ECNet project (saved and recalled later)
    # 5 pools with 75 trials/pool, best ANNs selected from each pool
    sv.create_project(db_name.replace('.csv', ''), 5, 75)

    # Train project
    # Select best candidates based on validation set median absolute error
    sv.train(validate=True, selection_set='valid', selection_fn='med_abs_error')

    # Obtain learning, validation, testing set median absolute error and r-squared
    err_l = sv.errors('med_abs_error', 'r2', dset='learn')
    err_v = sv.errors('med_abs_error', 'r2', dset='valid')
    err_t = sv.errors('med_abs_error', 'r2', dset='test')
    logger.log('info', 'Learning set performance: {}'.format(err_l))
    logger.log('info', 'Validation set performance: {}'.format(err_v))
    logger.log('info', 'Testing set performance: {}'.format(err_t))

    # Save the project, creating a .prj file and removing unselected candidates
    sv.save_project(del_candidates=True)
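# Usage sketch: run the full workflow on one database. The filename below is
# illustrative (borrowed from the other scripts here), not part of this script.
if __name__ == '__main__':

    main('cn_model_v1.0.csv')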
def test_server_limit(self):

    print('\nUNIT TEST: limit_rforest (Server)')
    sv = Server()
    sv.load_data(DB_LOC)
    # Limit inputs to the 2 most influential descriptors, saving the result
    sv.limit_inputs(2, output_filename='cn_limited.csv')
    self.assertEqual(len(sv._df._input_names), 2)
    self.assertEqual(len(sv._sets.learn_x[0]), 2)
    # Reload the limited database and verify the reduced shape persists
    sv.load_data('cn_limited.csv')
    self.assertEqual(len(sv._df._input_names), 2)
    self.assertEqual(len(sv._sets.learn_x[0]), 2)
    # Clean up files generated by the test
    remove('cn_limited.csv')
    remove('config.yml')
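# Minimal standalone harness for the test above, as a sketch: the imports,
# class name, and DB_LOC value are assumptions, not taken from the original
# test suite.
import unittest
from os import remove

from ecnet import Server

DB_LOC = 'cn_model_v1.0.csv'  # assumed path to a local test database


class TestServer(unittest.TestCase):

    # Bind the module-level function defined above as a test method
    test_server_limit = test_server_limit


if __name__ == '__main__':

    unittest.main()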
from ecnet import Server
from ecnet.utils.logging import logger  # import path assumes the ECNet 3.x layout


def limit(num_processes, output_filename=None):

    logger.stream_level = 'info'
    sv = Server(num_processes=num_processes)
    sv.load_data('cn_model_v1.0.csv')
    # Reduce the database to the 3 most influential input variables
    sv.limit_inputs(3, output_filename=output_filename)
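# Usage sketch: one plausible use of the num_processes argument is comparing
# wall-clock time across process counts. The loop and output filename below
# are illustrative assumptions, not from the original.
if __name__ == '__main__':

    from time import time

    for n_proc in [1, 2, 4]:
        start = time()
        limit(n_proc, output_filename='cn_model_v1.0_limited.csv')
        print('{} processes: {:.1f} seconds'.format(n_proc, time() - start))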
from ecnet import Server
from ecnet.utils.logging import logger  # import path assumes the ECNet 3.x layout


def main():

    logger.stream_level = 'debug'
    sv = Server(num_processes=4)
    sv.load_data('../kv_model_v1.0_full.csv')
    # Keep the 15 most influential inputs, writing the reduced database
    sv.limit_inputs(15, output_filename='../kv_model_v1.0.csv')
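# Standard entry point so the script can be run directly (not in the original):
if __name__ == '__main__':

    main()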