def load_model(
    dir: str,
    model_prefix: str,
    model_ext: str,
    model_cls: WCEstModel,
    logger: Logger = get_basic_logger(),
):
    """Load the most recent saved model matching a file prefix and extension.

    Candidate files are found with the glob ``<dir>/<model_prefix>*.<model_ext>``.
    When several match, the one that sorts last by file name is loaded; this
    picks the newest model provided the file names embed a sortable timestamp.

    :param dir: Directory to search for saved models
    :param model_prefix: File name prefix of the saved model files
    :param model_ext: File extension (without the dot) of the saved model files
    :param model_cls: Class whose ``from_saved_model`` loads the chosen file
    :param logger: Logger for messages
    :return: The loaded model
    :raises Exception: If no file matches the pattern
    """
    pattern = os.path.join(dir, "{}*.{}".format(model_prefix, model_ext))
    candidates = sorted(glob.glob(pattern))
    if not candidates:
        message = "No valid model was found with file pattern {}".format(pattern)
        logger.log(NOPRINTCRITICAL, message)
        raise Exception(message)

    newest = candidates[-1]
    loaded = model_cls.from_saved_model(newest)
    logger.debug("Loaded model {}".format(newest))
    return loaded
def initialise_scheduler(
    cls, user: str, account: str = None, logger: Logger = get_basic_logger()
):
    """Create the scheduler instance stored on ``cls`` for this platform.

    The scheduler kind is read from ``platform_config``: "slurm" selects
    ``Slurm``, "pbs" selects ``Pbs`` and any other value falls back to ``Bash``.

    :param user: User name passed to the scheduler
    :param account: Account to charge; defaults to the platform's default account
    :param logger: Logger handed to the scheduler
    :raises RuntimeError: If a scheduler has already been initialised
    """
    if cls.__scheduler is not None:
        raise RuntimeError("Scheduler already initialised")

    scheduler_name = platform_config[PLATFORM_CONFIG.SCHEDULER.name]
    if account is None:
        account = platform_config[PLATFORM_CONFIG.DEFAULT_ACCOUNT.name]

    if scheduler_name == "slurm":
        scheduler_cls = Slurm
    elif scheduler_name == "pbs":
        scheduler_cls = Pbs
    else:
        scheduler_cls = Bash

    cls.__scheduler = scheduler_cls(
        user=user, account=account, current_machine=host, logger=logger
    )
    cls.__scheduler.logger.debug("Scheduler initialised")
def set_wct(est_run_time, ncores, auto=False, logger=get_basic_logger()):
    """Decide the wall clock time (WCT) string to request for a job.

    In auto mode the estimate is used without asking and messages are logged at
    DEBUG. Otherwise messages are logged at INFO and the user is asked whether
    to use the estimate (``est.get_wct``) or to enter a value (``get_input_wc``).

    :param est_run_time: Estimated run time passed to the estimate_wct helpers
    :param ncores: Number of cores, only used in the log message
    :param auto: If True, use the estimate without prompting
    :param logger: Logger for messages
    :return: The chosen wall clock time
    """
    import estimation.estimate_wct as est

    level = DEBUG if auto else INFO
    logger.log(
        level,
        "Estimated time: {} with {} number of cores".format(
            est.convert_to_wct(est_run_time), ncores
        ),
    )

    use_estimation = True
    if not auto:
        print(
            "Use the estimated wall clock time? (Minimum of 5 mins, "
            "otherwise adds a 50% overestimation to ensure "
            "the job completes)"
        )
        use_estimation = show_yes_no_question()

    if use_estimation:
        logger.debug("Using generated estimation.")
        wct = est.get_wct(est_run_time)
    else:
        logger.debug("Using user determined wct value.")
        wct = str(get_input_wc())

    logger.log(level, "WCT set to: {}".format(wct))
    return wct
def main(args, logger: Logger = get_basic_logger()):
    """Write the merge_ts slurm script for one realisation and optionally submit it.

    Simulation parameters are read from ``<args.rel_dir>/sim_params.yaml``. The
    script is only written when ``args.srf`` is unset or equals the
    realisation's srf file name (without extension). Submission happens
    automatically with ``args.auto``; otherwise only if the user confirms.

    :param args: Parsed arguments (uses rel_dir, auto, srf, write_directory, machine)
    :param logger: Logger passed on to the scheduler submission
    """
    sim_params = utils.load_sim_params(os.path.join(args.rel_dir, "sim_params.yaml"))
    sim_dir = sim_params.sim_dir
    mgmt_db_loc = sim_params.mgmt_db_location
    # NOTE(review): the user is prompted before the srf filter below, so they may
    # be asked even when this realisation is skipped.
    if args.auto:
        submit_yes = True
    else:
        submit_yes = confirm("Also submit the job for you?")

    # Realisation name: the srf file name without directory or extension
    srf_name = os.path.splitext(os.path.basename(sim_params.srf_file))[0]
    # When a specific srf was requested, ignore every other realisation
    if args.srf is not None and srf_name != args.srf:
        return

    write_directory = args.write_directory or sim_dir
    lf_sim_dir = os.path.join(sim_dir, "LF")

    header_dict = {
        "platform_specific_args": get_platform_node_requirements(
            platform_config[const.PLATFORM_CONFIG.MERGE_TS_DEFAULT_NCORES.name]
        ),
        "wallclock_limit": default_run_time_merge_ts,
        "job_name": "merge_ts.{}".format(srf_name),
        "job_description": "post emod3d: merge_ts",
        "additional_lines": "###SBATCH -C avx",
    }
    command_template_parameters = {
        "run_command": platform_config[const.PLATFORM_CONFIG.RUN_COMMAND.name],
        "merge_ts_path": binary_version.get_unversioned_bin(
            "merge_tsP3_par", get_machine_config(args.machine)["tools_dir"]
        ),
    }
    body_template_params = (
        "{}.sl.template".format(merge_ts_name_prefix),
        {"lf_sim_dir": lf_sim_dir},
    )

    script_file_path = write_sl_script(
        write_directory,
        sim_dir,
        const.ProcessType.merge_ts,
        "{}_{}".format(merge_ts_name_prefix, srf_name),
        header_dict,
        body_template_params,
        command_template_parameters,
    )

    if not submit_yes:
        return
    submit_script_to_scheduler(
        script_file_path,
        const.ProcessType.merge_ts.value,
        sim_struct.get_mgmt_db_queue(mgmt_db_loc),
        sim_dir,
        srf_name,
        target_machine=args.machine,
        logger=logger,
    )
def load_scaler(dir: str, scaler_prefix: str, logger: Logger = get_basic_logger()):
    """Load the most recent pickled scaler matching a file prefix.

    Candidate files are found with the glob ``<dir>/<scaler_prefix>*.pickle``;
    the one that sorts last by file name is loaded.

    Security: the file is unpickled, so only load scalers from trusted directories.

    :param dir: Directory to search for scaler files
    :param scaler_prefix: File name prefix of the scaler files
    :param logger: Logger for messages
    :return: The unpickled scaler object
    :raises Exception: If no file matches the pattern, or the newest match is
        not a regular file
    """
    file_pattern = os.path.join(dir, "{}{}".format(scaler_prefix, "*.pickle"))
    scaler_files = sorted(glob.glob(file_pattern))
    if not scaler_files:
        # Previously reported "No valid model", which was misleading for scalers
        message = "No valid scaler was found with file pattern {}".format(file_pattern)
        logger.log(NOPRINTCRITICAL, message)
        raise Exception(message)

    scaler_file = scaler_files[-1]
    # glob can also match directories with a matching name
    if not os.path.isfile(scaler_file):
        message = "No matching scaler was found for model {}".format(scaler_file)
        logger.log(NOPRINTCRITICAL, message)
        raise Exception(message)

    with open(scaler_file, "rb") as f:
        scaler = pickle.load(f)
    logger.debug("Loaded scaler {}".format(scaler_file))
    return scaler
def submit_script_to_scheduler(
    script: str,
    proc_type: int,
    queue_folder: str,
    sim_dir: str,
    run_name: str,
    target_machine: str = None,
    logger: Logger = get_basic_logger(),
):
    """
    Submit a job script to the active scheduler and queue a "queued" status update.

    Errors raised by the scheduler while submitting are deliberately not caught,
    so that broken runs get immediate attention.

    :param script: Location of the script to submit
    :param proc_type: Process type of the job being submitted
    :param queue_folder: Folder where management db update files are written
    :param sim_dir: Simulation directory, passed to the scheduler's submit_job
    :param run_name: Name of the realisation
    :param target_machine: Machine passed to the scheduler's submit_job (may be None)
    :param logger: Logger passed through to add_to_queue
    :return: None
    """
    scheduler = Scheduler.get_scheduler()
    submitted_job_id = scheduler.submit_job(sim_dir, script, target_machine)
    add_to_queue(
        queue_folder,
        run_name,
        proc_type,
        const.Status.queued.value,
        job_id=submitted_job_id,
        logger=logger,
    )
def calculate_unperturbated_empiricals(
    default_vs30,
    extended_period,
    fsf,
    im_config,
    n_processes,
    sim_root,
    empirical_im_logger: Logger = get_basic_logger(),
):
    """Calculate empirical IMs for the unperturbated realisation of every event.

    Events come from the fault selection file. Events listed with a count other
    than 1 are mapped to their first realisation name
    (``get_realisation_name(name, 1)``). The work is spread over a process pool.

    :param default_vs30: Default vs30 passed to create_event_tasks
    :param extended_period: Passed to create_event_tasks
    :param fsf: Path of the fault selection file
    :param im_config: IM configuration passed to create_event_tasks
    :param n_processes: Maximum number of worker processes
    :param sim_root: Simulation root directory
    :param empirical_im_logger: Logger for messages
    """
    events = load_fault_selection_file(fsf)
    empirical_im_logger.debug(
        f"Loaded {len(events)} events from the fault selection file"
    )
    events = [
        name if count == 1 else get_realisation_name(name, 1)
        for name, count in events.items()
    ]
    tasks = create_event_tasks(
        events, sim_root, im_config, default_vs30, extended_period, empirical_im_logger
    )
    # Pool(0) raises ValueError, so there is nothing to parallelise here
    if not tasks:
        empirical_im_logger.debug("No empirical im tasks to run")
        return

    empirical_im_logger.debug("Running empirical im calculations")
    # The context manager shuts the workers down once starmap has returned
    with Pool(min(n_processes, len(tasks))) as pool:
        pool.starmap(calculate_empirical, tasks)
    empirical_im_logger.debug("Empirical ims calculated")
def _update_entry(
    self, cur: sql.Cursor, entry: SchedulerTask, logger: Logger = get_basic_logger()
):
    """Updates all fields that have a value for the specific entry

    Applies ``entry`` to the ``state`` table through ``cur``. No commit is made
    here, so the caller controls the transaction. Every UPDATE only touches rows
    whose current status is lower than ``entry.status``.

    - queued status: sets both the job id and status columns
    - job id given: sets the status of the row with that job id
    - otherwise: logs a warning, then sets the status matching only run name
      and process type

    If ``entry.error`` is set, a row is also inserted into the ``error`` table.

    :param cur: Cursor on the management db
    :param entry: The task update to apply
    :param logger: Logger for messages
    """
    if entry.status == const.Status.queued.value:
        logger.debug(
            "Got entry {} with status queued. Setting status and job id in the db".format(
                entry
            )
        )
        # Column names are formatted in from self.col_job_id / self.col_status;
        # all values are bound as parameters
        cur.execute(
            "UPDATE state SET {} = ?, {} = ?, last_modified = strftime('%s','now') "
            "WHERE run_name = ? AND proc_type = ? and status < ?".format(
                self.col_job_id, self.col_status
            ),
            (
                entry.job_id,
                entry.status,
                entry.run_name,
                entry.proc_type,
                entry.status,
            ),
        )
    elif entry.job_id is not None:
        # Restrict the update to the row recorded with this job id
        cur.execute(
            "UPDATE state SET status = ?, last_modified = strftime('%s','now') "
            "WHERE run_name = ? AND proc_type = ? and status < ? and job_id = ?",
            (
                entry.status,
                entry.run_name,
                entry.proc_type,
                entry.status,
                entry.job_id,
            ),
        )
    else:
        logger.warning(
            "Recieved entry {}, status is more than created but the job_id is not set.".format(
                entry
            )
        )
        cur.execute(
            "UPDATE state SET status = ?, last_modified = strftime('%s','now') "
            "WHERE run_name = ? AND proc_type = ? and status < ?",
            (entry.status, entry.run_name, entry.proc_type, entry.status),
        )
    # An update is expected to match at most one row; warn otherwise
    if cur.rowcount > 1:
        logger.warning(
            "Last database update caused {} entries to be updated".format(
                cur.rowcount
            )
        )
    if entry.error is not None:
        # NOTE(review): the sub-select is not limited to one row. If several
        # state rows share this proc_type/run_name (e.g. retried tasks), SQLite
        # uses the first row it returns, which may not be the row updated
        # above. Confirm this is intended.
        cur.execute(
            """INSERT INTO error (task_id, error) VALUES ( (SELECT id from state WHERE proc_type = ? AND run_name = ?), ?)""",
            (entry.proc_type, entry.run_name, entry.error),
        )
def predict(
    self,
    X_nn: np.ndarray,
    X_svr: np.ndarray,
    n_cores: np.ndarray,
    default_n_cores: int,
    logger: Logger = get_basic_logger(),
):
    """Attempt to use the NN model for estimation, however if input data is out
    of bounds, use the SVR model

    Parameters
    ----------
    X_nn: array of floats, shape [number of entries, number of features]
        Input data for NN, last column has to be the number of cores
    X_svr: array of floats, shape [number of entries, number of features]
        Input data for SVR
    n_cores: array of integers
        The non-normalised number of cores (i.e. actual number of physical cores
        to estimate for)
    default_n_cores: int
        The default number of cores for the process type that is being estimated.
    logger: Logger
        Logger for messages to be logged against

    Returns
    -------
    The estimates from the NN model (all entries in bounds), the SVR model (all
    entries out of bounds), or a float array of length [number of entries]
    combining both (mixed case).
    """
    assert X_nn.shape[0] == X_svr.shape[0]

    # An entry is out of bounds if any of its features is out of bounds
    out_bound_mask = np.any(self.nn_model.get_out_of_bounds_mask(X_nn), axis=1)
    in_bound_mask = ~out_bound_mask

    if np.all(in_bound_mask):
        return self.nn_model.predict(X_nn, warning=False, logger=logger)

    if not np.any(in_bound_mask):
        logger.debug(
            "The entry is out of bounds. The SVR models will be "
            "used for estimation."
        )
        # NOTE(review): unlike the mixed case below, the SVR output is not
        # rescaled by default_n_cores / n_cores here. Confirm which is correct.
        return self.svr_model.predict(X_svr, logger=logger)

    # Mixed case: NN for the in-bound entries, SVR for the rest.
    # dtype=float replaces the np.float alias, which NumPy 1.24 removed.
    results = np.full(X_nn.shape[0], np.nan, dtype=float)
    results[in_bound_mask] = self.nn_model.predict(
        X_nn[in_bound_mask, :], warning=False, logger=logger
    )
    logger.debug(
        "Some entries are out of bounds, these will be "
        "estimated using the SVR model."
    )
    results[out_bound_mask] = (
        self.svr_model.predict(X_svr[out_bound_mask, :], logger=logger)
        * default_n_cores
    ) / n_cores[out_bound_mask]
    return results
def estimate_HF_chours(
    data: np.ndarray,
    model: Union[str, EstModel],
    scale_ncores: bool,
    node_time_th_factor: float = 1.0,
    model_type: const.EstModelType = DEFAULT_MODEL_TYPE,
    logger: Logger = get_basic_logger(),
):
    """Make bulk HF estimations, requires data to be in the correct order
    (see above).

    Params
    ------
    data: np.ndarray of int, float
        Input data for the model in order fd_count, nsub_stoch, nt, n_cores
        Has to have shape [-1, 4]. The array is not modified.
    model: str or EstModel
        The estimation model (or its location) passed to estimate
    scale_ncores: bool
        If True then the number of cores is adjusted until
        n_nodes * node_time_th >= run_time
    node_time_th_factor: float
        Node time threshold factor in hours, does nothing if scale_ncores is
        not set
    model_type: const.EstModelType
        Which estimation model(s) to use
    logger: Logger
        Logger for messages

    Returns
    -------
    core_hours: np.ndarray of floats
        Estimated number of core hours
    run_time: np.ndarray of floats
        Estimated run time (hours)
    n_cores: np.ndarray of floats
        The number of cores, scaled back by the hyperthreading factor. Equal to
        the input n_cores if scale_ncores is not set or no scaling was needed,
        otherwise the updated number of cores.
    """
    if data.shape[1] != 4:
        raise Exception("Invalid input data, has to 4 columns. "
                        "One for each feature.")

    # Work on a float copy. The original modified the caller's array in place,
    # and integer input truncated the halved core counts.
    data = np.array(data, dtype=float)

    hyperthreading_factor = 2.0 if const.ProcessType.HF.is_hyperth else 1.0

    # Adjust the number of cores to estimate physical core hours
    data[:, -1] = data[:, -1] / hyperthreading_factor
    core_hours = estimate(
        data,
        model,
        model_type,
        platform_config[const.PLATFORM_CONFIG.HF_DEFAULT_NCORES.name]
        / hyperthreading_factor,
        logger=logger,
    )

    wct = core_hours / data[:, -1]
    if scale_ncores and np.any(
        wct > (node_time_th_factor * data[:, -1] / PHYSICAL_NCORES_PER_NODE)
    ):
        core_hours, wct, data[:, -1] = scale_core_hours(
            core_hours, data, node_time_th_factor
        )

    return core_hours, wct, data[:, -1] * hyperthreading_factor
def predict(self, X: np.ndarray, logger: Logger = get_basic_logger()):
    """Run the trained model on X and return a flattened (1-D) prediction.

    For full doc see WCEstModel.predict

    :raises Exception: If the model has not been trained
    """
    if self.is_trained:
        return self._model.predict(X).reshape(-1)

    logger.log(NOPRINTCRITICAL, "There was an attempt to use an untrained model")
    raise Exception("This model has not been trained!")
def test_update_live_db(mgmt_db):
    """A status update made through the live connection is persisted to the db."""
    # update_entries_live takes retry_max before logger. Passing both by keyword
    # stops the logger from being bound to retry_max.
    retry_max = 2
    succeeded = mgmt_db.update_entries_live(
        [SchedulerTask(TEST_RUN_NAME, TEST_PROC[0], TEST_STATUS[0], None, None)],
        retry_max=retry_max,
        logger=get_basic_logger(),
    )
    try:
        assert succeeded
        value = get_rows(
            mgmt_db.db_file, "state", "proc_type", TEST_PROC[0], selected_col="status"
        )[0][0]
        assert value == TEST_STATUS[0]
    finally:
        # Close the live connection even if an assertion fails
        mgmt_db.close_conn()
def est_HF_chours_single(
    fd_count: int,
    nsub_stoch: float,
    nt: int,
    n_logical_cores: int,
    model: Union[str, EstModel],
    scale_ncores: bool,
    node_time_th_factor: float = 1.0,
    model_type: const.EstModelType = DEFAULT_MODEL_TYPE,
    logger: Logger = get_basic_logger(),
):
    """Make a single HF estimation (convenience wrapper around estimate_HF_chours).

    If the input features of the model (or just their order) ever change, this
    function has to be adjusted accordingly.

    Params
    ------
    fd_count, nsub_stoch, nt, n_logical_cores: int, float
        Input features for the model
    model, scale_ncores, node_time_th_factor, model_type, logger
        Passed through to estimate_HF_chours

    Returns
    -------
    core_hours: float
        Estimated number of core hours
    run_time: float
        Estimated run time (hours)
    n_cpus: int
        Number of cores returned by estimate_HF_chours
    """
    # One row of features; the order has to match the one used for training!
    features = np.array(
        [[float(fd_count), float(nsub_stoch), float(nt), float(n_logical_cores)]]
    )
    core_hours, run_time, n_cpus = estimate_HF_chours(
        features,
        model,
        scale_ncores,
        node_time_th_factor=node_time_th_factor,
        model_type=model_type,
        logger=logger,
    )
    return core_hours[0], run_time[0], int(n_cpus[0])
def add_to_queue(
    queue_folder: str,
    run_name: str,
    proc_type: int,
    status: int,
    job_id: int = None,
    error: str = None,
    logger: Logger = get_basic_logger(),
):
    """Adds an update entry to the queue

    Writes a JSON file named ``<timestamp>.<run_name>.<proc_type>`` into
    ``queue_folder``. The file holds the run name, process type, status, job id
    and error.

    :param queue_folder: Folder the update file is written to
    :param run_name: Name of the realisation
    :param proc_type: Process type of the task
    :param status: New status of the task
    :param job_id: Scheduler job id, if known
    :param error: Error message to record, if any
    :param logger: Logger for messages
    :raises Exception: If an update file with the same name already exists
    """
    logger.debug(
        "Adding task to the queue. Realisation: {}, process type: {}, status: {}, job_id: {}, error: {}"
        .format(run_name, proc_type, status, job_id, error))
    filename = os.path.join(
        queue_folder,
        "{}.{}.{}".format(datetime.now().strftime(const.QUEUE_DATE_FORMAT),
                          run_name, proc_type),
    )

    try:
        # "x" creates the file exclusively. A separate exists() check followed
        # by open("w") could race with another writer and silently overwrite it.
        update_file = open(filename, "x")
    except FileExistsError:
        message = (
            "An update with the name {} already exists. This should never happen. Quitting!"
            .format(os.path.basename(filename)))
        logger.log(NOPRINTCRITICAL, message)
        raise Exception(message) from None

    logger.debug("Writing update file to {}".format(filename))
    with update_file:
        json.dump(
            {
                MgmtDB.col_run_name: run_name,
                MgmtDB.col_proc_type: proc_type,
                MgmtDB.col_status: status,
                MgmtDB.col_job_id: job_id,
                "error": error,
            },
            update_file,
        )

    if not os.path.isfile(filename):
        logger.critical("File {} did not successfully write".format(filename))
    else:
        logger.debug("Successfully wrote task update file")
def aggregate_simulation_empirical_im_permutations(
        fsf, n_processes, sim_root, version, logger: Logger = get_basic_logger()):
    """Aggregate empirical IM permutations for every event in a fault selection file.

    Events listed with a count other than 1 are mapped to their first
    realisation name (``get_realisation_name(name, 1)``). Each event is handled
    by agg_emp_perms in a process pool.

    :param fsf: Path of the fault selection file
    :param n_processes: Number of worker processes
    :param sim_root: Simulation root directory
    :param version: Simulation version passed to agg_emp_perms
    :param logger: Parent logger; each event gets a realisation logger from it
    """
    events = load_fault_selection_file(fsf)
    logger.debug(f"Loaded {len(events)} events from the fault selection file")
    events = [
        name if count == 1 else get_realisation_name(name, 1)
        for name, count in events.items()
    ]
    # The context manager shuts the workers down once starmap has returned
    with Pool(n_processes) as worker_pool:
        worker_pool.starmap(
            agg_emp_perms,
            [(
                pathlib.Path(get_empirical_dir(sim_root, event)),
                event,
                version,
                # agg_emp_perms accepts a logger name and resolves it itself
                get_realisation_logger(logger, event).name,
            ) for event in events],
        )
def check_mgmt_queue(queue_entries: List[str],
                     run_name: str,
                     proc_type: int,
                     logger=get_basic_logger()):
    """Check whether an update for this realisation and process type is queued.

    Each queue entry must have the form ``<prefix>.<run_name>.<proc_type>``
    (exactly two dots). The process type is compared as a string.

    :return: True as soon as a matching entry is found, otherwise False
    """
    logger.debug(
        "Checking to see if the realisation {} has a process of type {} in updates folder"
        .format(run_name, proc_type))
    wanted_proc = str(proc_type)
    for queued in queue_entries:
        logger.debug("Checking against {}".format(queued))
        _, queued_run, queued_proc = queued.split(".")
        if (queued_run, queued_proc) == (run_name, wanted_proc):
            logger.debug("It's a match, returning True")
            return True
    logger.debug("No match found")
    return False
def agg_emp_perms(
    empirical_dir: pathlib.Path,
    realisation: str,
    version: str,
    aggregation_logger: Union[Logger, str] = get_basic_logger(),
):
    """
    Generate aggregated empirical permutations for one realisation of a simulation run.

    Empirical csv files named ``<event>_*.csv`` are grouped by IM, taken as the
    last "_"-separated part of the file stem. Each aggregation group is then
    written with aggregate_data.

    :param empirical_dir: The directory of the event or fault realisation to generate aggregation files for
    :param realisation: The name of the realisation being aggregated
    :param version: The version of the simulation, e.g. the perturbation version
    :param aggregation_logger: A logger, or the name of the logger to look up with get_logger
    """
    if isinstance(aggregation_logger, str):
        aggregation_logger = get_logger(aggregation_logger)

    event = get_fault_from_realisation(realisation)

    ims = {}
    for csv_path in empirical_dir.glob(f"{event}_*.csv"):
        ims.setdefault(csv_path.stem.split("_")[-1], []).append(csv_path)

    if not ims:
        aggregation_logger.error("No empirical IM files found, exiting")
        return

    n_files = sum(len(paths) for paths in ims.values())
    aggregation_logger.debug(f"Found {n_files} empirical IM files")

    groups = calculate_aggregation_groups(ims)
    aggregation_logger.debug(f"Created {len(groups)} aggregated IM groups")

    for group in groups:
        identifier = get_agg_identifier(event, group)
        aggregation_logger.debug(
            f"The identifier {identifier} is being used for the IM group {group}"
        )
        aggregate_data(group, empirical_dir, identifier, event, version)
        aggregation_logger.debug(
            f"Saved empirical IMs to {empirical_dir / f'{identifier}.csv'}")
def install_bb(
    stat_file,
    root_dict,
    v1d_dir,
    v1d_full_path=None,
    site_v1d_dir=None,
    hf_stat_vs_ref=None,
    logger: Logger = get_basic_logger(),
):
    """Fill in the HF 1D velocity model settings of ``root_dict`` (modified in place).

    The source is chosen in this order:
    1. ``v1d_full_path`` if given;
    2. site-specific profiles when both ``site_v1d_dir`` and ``hf_stat_vs_ref`` are given;
    3. site-specific profiles if the user answers yes to ``q_site_specific``;
    4. otherwise a 1D velocity model chosen through ``q_1d_velocity_model(v1d_dir)``.
    The site-specific choices also set ``root_dict["hf_stat_vs_ref"]``.
    """
    shared.show_horizontal_line(c="*")
    logger.info(" " * 37 + "EMOD3D HF/BB Preparation Ver.slurm")
    shared.show_horizontal_line(c="*")

    # TODO: most of this logic is no longer required, as it now depends on
    # gmsim_version_template, and should be removed
    if v1d_full_path is not None:
        root_dict["hf"][HF_VEL_MOD_1D] = v1d_full_path
    elif site_v1d_dir is not None and hf_stat_vs_ref is not None:
        # TODO: add logic for site specific args provided by the user as well
        vel_mod_1d, stat_vs_ref = shared.get_site_specific_path(
            os.path.dirname(stat_file),
            hf_stat_vs_ref=hf_stat_vs_ref,
            v1d_mod_dir=site_v1d_dir,
            logger=logger,
        )
        root_dict["hf"][HF_VEL_MOD_1D] = vel_mod_1d
        root_dict["hf_stat_vs_ref"] = stat_vs_ref
    elif q_site_specific():
        vel_mod_1d, stat_vs_ref = shared.get_site_specific_path(
            os.path.dirname(stat_file), logger=logger
        )
        root_dict["hf"][HF_VEL_MOD_1D] = vel_mod_1d
        root_dict["hf_stat_vs_ref"] = stat_vs_ref
    else:
        _, selected_v1d = q_1d_velocity_model(v1d_dir)
        root_dict["hf"][HF_VEL_MOD_1D] = selected_v1d
def get_queue_entry(
    entry_file: str, queue_logger: Logger = qclogging.get_basic_logger()
):
    """Read a queue update file into a SchedulerTask.

    The run name comes from the file name (its second "."-separated part). The
    process type, status, job id and optional error come from the JSON body.

    :param entry_file: Path of the update file
    :param queue_logger: Logger for messages
    :return: The SchedulerTask, or None if the file is not valid JSON
    """
    try:
        with open(entry_file, "r") as handle:
            contents = json.load(handle)
    except json.JSONDecodeError:
        queue_logger.error(
            "Failed to decode the file {} as json. Check that this is "
            "valid json. Ignored!".format(entry_file)
        )
        return None

    run_name = os.path.basename(entry_file).split(".")[1]
    return SchedulerTask(
        run_name=run_name,
        proc_type=contents[MgmtDB.col_proc_type],
        status=contents[MgmtDB.col_status],
        job_id=contents[MgmtDB.col_job_id],
        error=contents.get("error"),
    )
def predict(
    self, X: np.ndarray, warning: bool = True, logger: Logger = get_basic_logger()
):
    """Run the trained model on X and return a flattened (1-D) prediction.

    When ``warning`` is set and any input lies outside the bounds reported by
    get_out_of_bounds_mask, a warning is printed to stdout.

    For full doc see WCEstModel.predict

    :raises Exception: If the model has not been trained
    """
    if not self.is_trained:
        logger.log(NOPRINTCRITICAL, "There was an attempt to use an untrained model")
        raise Exception("This model has not been trained!")

    any_out_of_bounds = np.any(self.get_out_of_bounds_mask(X))
    if any_out_of_bounds and warning:
        print(
            "WARNING: Some of the data specified for estimation exceeds the "
            "limits of the data the model was trained. This will result in "
            "incorrect estimation!"
        )

    raw_prediction = self._model.predict(X)
    return raw_prediction.reshape(-1)
def get_site_specific_path(
    stat_file_path,
    hf_stat_vs_ref=None,
    v1d_mod_dir=None,
    logger: Logger = get_basic_logger(),
):
    """Find the site-specific 1D velocity profile directory and HF Vsref file.

    :param stat_file_path: Path searched (as a directory) for ``*.hfvs30ref``
        files. When ``v1d_mod_dir`` is not given, the 1D profiles are looked up
        in ``<dirname(stat_file_path)>/1D``.
    :param hf_stat_vs_ref: HF Vsref file to use. If None, the user picks one of
        the ``*.hfvs30ref`` files found.
    :param v1d_mod_dir: Directory of the 1D profiles, overriding the default
    :param logger: Logger for messages
    :return: Tuple of (1D profile directory, selected HF Vsref file)

    Exits the process with a non-zero status if the 1D profile directory does
    not exist or no HF Vsref file can be found.
    """
    show_horizontal_line()
    logger.info("Auto-detecting site-specific info")
    show_horizontal_line()
    logger.info("- Station file path: %s" % stat_file_path)

    if v1d_mod_dir is not None:
        v_mod_1d_path = v1d_mod_dir
    else:
        v_mod_1d_path = os.path.join(os.path.dirname(stat_file_path), "1D")

    if os.path.exists(v_mod_1d_path):
        logger.info("- 1D profiles found at {}".format(v_mod_1d_path))
    else:
        logger.critical("Error: No such path exists: {}".format(v_mod_1d_path))
        # sys.exit() without an argument would report success (status 0)
        sys.exit(1)

    if hf_stat_vs_ref is not None:
        return v_mod_1d_path, hf_stat_vs_ref

    hf_stat_vs_ref_options = sorted(
        glob.glob(os.path.join(stat_file_path, "*.hfvs30ref")))
    if not hf_stat_vs_ref_options:
        logger.critical("Error: No HF Vsref file was found at {}".format(
            stat_file_path))
        sys.exit(1)

    show_horizontal_line()
    logger.info("Select one of HF Vsref files")
    show_horizontal_line()
    hf_stat_vs_ref_selected = show_multiple_choice(hf_stat_vs_ref_options)
    logger.info(
        " - HF Vsref tp be used: {}".format(hf_stat_vs_ref_selected))
    return v_mod_1d_path, hf_stat_vs_ref_selected
def load_full_model(
    dir: str,
    model_type: const.EstModelType = DEFAULT_MODEL_TYPE,
    logger: Logger = get_basic_logger(),
):
    """Load the estimation model(s) of the given type together with their scalers.

    The EstModel is built as (nn_model, nn_scaler, svr_model, svr_scaler). The
    parts not required by ``model_type`` are set to None. Returns None if
    ``model_type`` is not one of NN, SVR or NN_SVR.
    """

    def _nn_parts():
        return (
            load_model(dir, const.EST_MODEL_NN_PREFIX, "h5", NNWcEstModel, logger),
            load_scaler(dir, SCALER_PREFIX.format("NN"), logger),
        )

    def _svr_parts():
        return (
            load_model(dir, const.EST_MODEL_SVR_PREFIX, "pickle", SVRModel, logger),
            load_scaler(dir, SCALER_PREFIX.format("SVR"), logger),
        )

    if model_type is const.EstModelType.NN:
        return EstModel(*_nn_parts(), None, None)
    if model_type is const.EstModelType.SVR:
        return EstModel(None, None, *_svr_parts())
    if model_type is const.EstModelType.NN_SVR:
        return EstModel(*_nn_parts(), *_svr_parts())
    return None
def parse_config_file(config_file_location: str,
                      logger: Logger = qclogging.get_basic_logger()):
    """Read a wrapper config file and work out which tasks run for which realisations.

    Each key of the config is a process type name (parsed with
    ``const.ProcessType.from_str``). Its value is ALL (run for every
    realisation), NONE (skip), a single pattern string, or a list of patterns.
    ONCE is translated into ONCE_PATTERN.

    NOTE (from the original docstring): if a pattern task's dependencies
    overlap with the run-for-all tasks, a race is possible when several
    auto_submit scripts run the same tasks.

    :param config_file_location: The location of the config file
    :param logger: Logger for messages
    :return: A tuple of (list of tasks run for all realisations, items view of
        a pattern -> list of tasks mapping)
    """
    config = load_yaml(config_file_location)

    run_for_all = []
    by_pattern = {}
    for proc_name, pattern in config.items():
        proc = const.ProcessType.from_str(proc_name)
        if pattern == ALL:
            run_for_all.append(proc)
            continue
        if pattern == NONE:
            continue
        subpatterns = [pattern] if isinstance(pattern, str) else pattern
        for subpattern in subpatterns:
            key = ONCE_PATTERN if subpattern == ONCE else subpattern
            by_pattern.setdefault(key, []).append(proc)

    logger.info("Master script will run {}".format(run_for_all))
    for pattern, tasks in by_pattern.items():
        logger.info("Pattern {} will run tasks {}".format(pattern, tasks))

    return run_for_all, by_pattern.items()
def _check_dependancy_met(self, task, logger=get_basic_logger()):
    """Return True when every dependency of ``task`` has completed.

    :param task: (process type, run name) pair
    :param logger: Logger for messages
    """
    proc_type, run_name = task
    process = Process(proc_type)

    with connect_db_ctx(self._db_file) as cur:
        completed_rows = cur.execute(
            """SELECT proc_type FROM status_enum, state WHERE state.status = status_enum.id AND run_name = (?) AND status_enum.state = 'completed'""",
            (run_name,),
        ).fetchall()
        logger.debug(
            "Considering task {} for realisation {}. Completed tasks as follows: {}".format(
                process, run_name, completed_rows
            )
        )

    completed = [const.ProcessType(row[0]) for row in completed_rows]
    remaining_deps = process.get_remaining_dependencies(completed)
    logger.debug("{} has remaining deps: {}".format(task, remaining_deps))
    return len(remaining_deps) == 0
def parse_config_file(config_file_location: str, logger: Logger = get_basic_logger()):
    """Read a wrapper config file and work out which tasks run for which realisations.

    Each key is a process type name as given in qcore.constants. Its value is
    one keyword or sqlite query string (using % as the wildcard), or a list of
    them. The keywords NONE, ONCE and ALL mean nothing, "%_REL01" and "%"
    respectively.

    :param config_file_location: The location of the config file
    :param logger: The logger object used to record messages
    :return: A list of the tasks run for all realisations, and an items view of
        a pattern -> task list mapping stating which query patterns run which tasks
    """
    config = load_yaml(config_file_location)

    run_for_all = []
    pattern_tasks = {}
    for proc_name, raw_patterns in config.items():
        proc = const.ProcessType.from_str(proc_name)
        pattern_list = raw_patterns if isinstance(raw_patterns, list) else [raw_patterns]
        for pattern in pattern_list:
            if pattern == ALL:
                run_for_all.append(proc)
            elif pattern != NONE:
                key = ONCE_PATTERN if pattern == ONCE else pattern
                pattern_tasks.setdefault(key, []).append(proc)

    logger.info("Master script will run {}".format(run_for_all))
    for pattern, tasks in pattern_tasks.items():
        logger.info("Pattern {} will run tasks {}".format(pattern, tasks))

    return run_for_all, pattern_tasks.items()
def update_entries_live(
    self,
    entries: List[SchedulerTask],
    retry_max: int,
    logger: Logger = get_basic_logger(),
):
    """Updates the specified entries in the db.

    Leaves the connection open, so this should only be used when continuously
    updating entries. All entries are applied inside one transaction:

    - status "created": a new task is inserted unless one already exists
    - any other status: applied with _update_entry
    - status "failed": a new task is inserted if get_retries is below
      retry_max, and dependant tasks (found with find_dependant_task,
      transitively) are updated too

    :param entries: The task updates to apply
    :param retry_max: Tasks failed fewer times than this are retried
    :param logger: Logger for messages
    :return: True if the changes were committed, False if an sqlite error
        caused a rollback
    """
    # Pre-bind so the error handling below works even if connecting fails
    cur = None
    entry = None
    try:
        if self._conn is None:
            logger.info("Aquiring db connection.")
            self._conn = sql.connect(self._db_file)
        logger.debug("Getting db cursor")
        cur = self._conn.cursor()
        cur.execute("BEGIN")

        for entry in entries:
            process = entry.proc_type
            realisation_name = entry.run_name

            logger.debug(
                "The status of process {} for realisation {} is being set to {}. It has slurm id {}".format(
                    entry.proc_type, entry.run_name, entry.status, entry.job_id
                )
            )

            if entry.status == const.Status.created.value:
                # Something has attempted to set a task to created:
                # make a new task with created status and move to the next entry
                logger.debug("Adding new task to the db")
                # Don't add a second task for the same realisation/process
                if self._does_task_exists(cur, realisation_name, process):
                    logger.debug(
                        "task is already in progress - does not need to be readded"
                    )
                    continue
                self._insert_task(cur, realisation_name, process)
                logger.debug("New task added to the db, continuing to next process")
                continue

            logger.debug("Updating task in the db")
            self._update_entry(cur, entry, logger=logger)
            logger.debug("Task successfully updated")

            if (
                entry.status == const.Status.failed.value
                and self.get_retries(process, realisation_name) < retry_max
            ):
                # The task failed but has attempts left, so create a new one
                logger.debug(
                    "Task failed but is able to be retried. Adding new task to the db"
                )
                self._insert_task(cur, realisation_name, process)
                logger.debug("New task added to the db")

            # Fail every task that (transitively) depends on the failed task
            if entry.status == const.Status.failed.value:
                tasks = MgmtDB.find_dependant_task(cur, entry)
                i = 0
                while i < len(tasks):
                    task = tasks[i]
                    self._update_entry(cur, task, logger=logger)
                    logger.debug(
                        f"Cascading failure for {entry.run_name} - {task.proc_type}"
                    )
                    tasks.extend(MgmtDB.find_dependant_task(cur, task))
                    i += 1
    except sql.Error as ex:
        # If sql.connect itself failed there is no connection to roll back
        if self._conn is not None:
            self._conn.rollback()
        logger.critical(
            "Failed to update entry {}, due to the exception: \n{}".format(
                entry, ex
            )
        )
        return False
    else:
        logger.debug("Committing changes to db")
        self._conn.commit()
    finally:
        # The cursor does not exist if the failure happened before it was created
        if cur is not None:
            logger.debug("Closing db cursor")
            cur.close()
    return True
def get_runnable_tasks(
    self,
    allowed_rels,
    task_limit,
    update_files,
    allowed_tasks=None,
    logger=get_basic_logger(),
):
    """Gets all runnable tasks based on their status and their associated
    dependencies (i.e. other tasks have to be finished first)

    A task is runnable if it is in the 'created' state, its process type is in
    ``allowed_tasks``, its run name matches the SQL LIKE pattern
    ``allowed_rels``, it has no pending update file in ``update_files``, and all
    its dependencies are met. Candidates are fetched a page at a time until at
    least ``task_limit`` are found, so the result may contain more than
    ``task_limit`` entries.

    :param allowed_rels: SQL LIKE pattern for allowed run names
    :param task_limit: Stop fetching once this many runnable tasks were found
    :param update_files: Queue update file names ``<prefix>.<run_name>.<proc_type>``
    :param allowed_tasks: Allowed process types, defaults to all of them
    :param logger: Logger passed to the dependency check
    :return: A list of tuples (proc_type, run_name, retries), with retries from get_retries
    """
    if allowed_tasks is None:
        allowed_tasks = list(const.ProcessType)
    allowed_tasks = [str(task.value) for task in allowed_tasks]

    if not allowed_tasks:
        return []

    # (proc_type, run_name) pairs that already have an update file waiting.
    # Used to avoid running a task that has been submitted but not yet recorded.
    # The key order must match the DB rows (proc_type, run_name); previously the
    # file side was built as run_name__proc_type, so nothing ever matched.
    tasks_waiting_for_updates = set()
    for update_file in update_files:
        run_name, proc_type = update_file.split(".")[1:3]
        tasks_waiting_for_updates.add((proc_type, run_name))

    placeholders = ",".join(["?"] * len(allowed_tasks))
    page_size = 100
    runnable_tasks = []
    offset = 0

    with connect_db_ctx(self._db_file) as cur:
        entries = cur.execute(
            """SELECT COUNT(*) FROM status_enum, state WHERE state.status = status_enum.id AND proc_type IN ({}) AND run_name LIKE (?) AND status_enum.state = 'created'""".format(
                placeholders
            ),
            (*allowed_tasks, allowed_rels),
        ).fetchone()[0]

        while len(runnable_tasks) < task_limit and offset < entries:
            db_tasks = cur.execute(
                """SELECT proc_type, run_name FROM status_enum, state WHERE state.status = status_enum.id AND proc_type IN ({}) AND run_name LIKE (?) AND status_enum.state = 'created' LIMIT ? OFFSET ?""".format(
                    placeholders
                ),
                (*allowed_tasks, allowed_rels, page_size, offset),
            ).fetchall()
            # Cheap in-memory check first; the dependency check queries the db
            runnable_tasks.extend(
                (*task, self.get_retries(*task))
                for task in db_tasks
                if (str(task[0]), task[1]) not in tasks_waiting_for_updates
                and self._check_dependancy_met(task, logger)
            )
            offset += page_size

    return runnable_tasks
def install_fault(
    fault_name,
    n_rel,
    root_folder,
    version,
    stat_file_path,
    seed=HF_DEFAULT_SEED,
    extended_period=False,
    vm_perturbations=False,
    ignore_vm_perturbations=False,
    vm_qpqs_files=False,
    ignore_vm_qpqs_files=False,
    keep_dup_station=True,
    components=None,
    logger: Logger = get_basic_logger(),
):
    """Install every realisation (SRF file) of a fault under root_folder.

    For each SRF found in the fault's srf directory this:
      - checks that the matching stoch file exists,
      - calls install_simulation to build the root/fault/sim parameter dicts,
      - checks that 'flo' agrees between the VM params and root params,
      - creates the management DB entries and the mgmt_db_queue directory,
      - generates the fd station files at the fault level,
      - dumps all yaml files to the simulation directory,
      - checks that the written sim_params.yaml is accepted by the HF and BB
        argument parsers.

    Parameters
    ----------
    fault_name:
        Name of the fault to install.
    n_rel:
        Expected number of realisations, or None to skip the count check.
    root_folder:
        Cybershake root folder.
    version:
        gmsim template version, used to locate the root defaults file.
    stat_file_path:
        Station ".ll" file. The vs30/vs30ref files are derived from it by
        replacing ".ll" in the path.
    Remaining parameters are forwarded to install_simulation or
    generate_fd_files.

    Raises
    ------
    RuntimeError
        If the number of SRF files does not match n_rel, no SRF files are
        found, the velocity model fails validation, or a stoch file is missing.

    Returns None. It also returns early (after logging a critical message)
    when install_simulation yields a None dict or 'flo' does not match.
    """
    config_dict = utils.load_yaml(
        os.path.join(
            platform_config[PLATFORM_CONFIG.TEMPLATES_DIR.name],
            "gmsim",
            version,
            ROOT_DEFAULTS_FILE_NAME,
        )
    )
    # Load variables from cybershake config
    v1d_full_path = os.path.join(
        platform_config[PLATFORM_CONFIG.VELOCITY_MODEL_DIR.name],
        "Mod-1D",
        config_dict.get("v_1d_mod"),
    )
    site_v1d_dir = config_dict.get("site_v1d_dir")
    hf_stat_vs_ref = config_dict.get("hf_stat_vs_ref")

    vs30_file_path = stat_file_path.replace(".ll", ".vs30")
    vs30ref_file_path = stat_file_path.replace(".ll", ".vs30ref")

    # get all srf from source; prefer *_REL*.srf, fall back to any *.srf
    srf_dir = simulation_structure.get_srf_dir(root_folder, fault_name)
    list_srf = glob.glob(os.path.join(srf_dir, "*_REL*.srf"))
    if len(list_srf) == 0:
        list_srf = glob.glob(os.path.join(srf_dir, "*.srf"))
    list_srf.sort()

    if n_rel is not None and len(list_srf) != n_rel:
        message = (
            "Error: fault {} failed. Number of realisations do "
            "not match number of SRF files".format(fault_name)
        )
        logger.log(NOPRINTCRITICAL, message)
        raise RuntimeError(message)

    # list_srf[0] is needed for VM validation below; fail clearly instead of
    # with a bare IndexError when nothing was found.
    if not list_srf:
        message = "Error: fault {} failed. No SRF files were found in {}".format(
            fault_name, srf_dir
        )
        logger.log(NOPRINTCRITICAL, message)
        raise RuntimeError(message)

    # Get & validate velocity model directory
    vel_mod_dir = simulation_structure.get_fault_VM_dir(root_folder, fault_name)
    valid_vm, message = validate_vm.validate_vm(vel_mod_dir, srf=list_srf[0])
    if not valid_vm:
        message = "Error: VM {} failed {}".format(fault_name, message)
        logger.log(NOPRINTCRITICAL, message)
        raise RuntimeError(message)

    # Load the variables from vm_params.yaml
    vm_params_path = os.path.join(vel_mod_dir, VM_PARAMS_FILE_NAME)
    vm_params_dict = utils.load_yaml(vm_params_path)

    sim_root_dir = simulation_structure.get_runs_dir(root_folder)
    fault_yaml_path = simulation_structure.get_fault_yaml_path(sim_root_dir, fault_name)
    root_yaml_path = simulation_structure.get_root_yaml_path(sim_root_dir)

    for srf in list_srf:
        logger.info("Installing {}".format(srf))
        # try to match find the stoch with same basename
        realisation_name = os.path.splitext(os.path.basename(srf))[0]
        stoch_file_path = simulation_structure.get_stoch_path(
            root_folder, realisation_name
        )
        sim_params_file = simulation_structure.get_source_params_path(
            root_folder, realisation_name
        )

        if not os.path.isfile(stoch_file_path):
            message = "Error: Corresponding Stoch file is not found: {}".format(
                stoch_file_path
            )
            logger.log(NOPRINTCRITICAL, message)
            raise RuntimeError(message)

        # install pairs one by one to fit the new structure
        sim_dir = simulation_structure.get_sim_dir(root_folder, realisation_name)

        (root_params_dict, fault_params_dict, sim_params_dict) = install_simulation(
            version=version,
            sim_dir=sim_dir,
            rel_name=realisation_name,
            run_dir=sim_root_dir,
            vel_mod_dir=vel_mod_dir,
            srf_file=srf,
            stoch_file=stoch_file_path,
            stat_file_path=stat_file_path,
            vs30_file_path=vs30_file_path,
            vs30ref_file_path=vs30ref_file_path,
            yes_statcords=False,
            fault_yaml_path=fault_yaml_path,
            root_yaml_path=root_yaml_path,
            cybershake_root=root_folder,
            site_v1d_dir=site_v1d_dir,
            hf_stat_vs_ref=hf_stat_vs_ref,
            v1d_full_path=v1d_full_path,
            sim_params_file=sim_params_file,
            seed=seed,
            logger=logger,
            extended_period=extended_period,
            vm_perturbations=vm_perturbations,
            ignore_vm_perturbations=ignore_vm_perturbations,
            vm_qpqs_files=vm_qpqs_files,
            ignore_vm_qpqs_files=ignore_vm_qpqs_files,
            components=components,
        )
        if (
            root_params_dict is None
            or fault_params_dict is None
            or sim_params_dict is None
        ):
            # Something has gone wrong, returning without saving anything
            logger.critical("Critical Error some params dictionary are None")
            return

        # root_params_dict is known to be non-None here
        if not isclose(vm_params_dict["flo"], root_params_dict["flo"]):
            logger.critical(
                "The parameter 'flo' does not match in the VM params and root params files. "
                "Please ensure you are installing the correct gmsim version"
            )
            return

        create_mgmt_db.create_mgmt_db(
            [], simulation_structure.get_mgmt_db(root_folder), srf_files=srf
        )
        utils.setup_dir(os.path.join(root_folder, "mgmt_db_queue"))
        root_params_dict["mgmt_db_location"] = root_folder

        # Generate the fd files, create these at the fault level
        fd_statcords, fd_statlist = generate_fd_files(
            simulation_structure.get_fault_dir(root_folder, fault_name),
            vm_params_dict,
            stat_file=stat_file_path,
            logger=logger,
            keep_dup_station=keep_dup_station,
        )
        fault_params_dict[FaultParams.stat_coords.value] = fd_statcords
        fault_params_dict[FaultParams.FD_STATLIST.value] = fd_statlist

        dump_all_yamls(sim_dir, root_params_dict, fault_params_dict, sim_params_dict)

        # test if the params are accepted by steps HF and BB
        _validate_hf_bb_args(sim_dir, seed)


def _validate_hf_bb_args(sim_dir, seed):
    """Check that sim_dir/sim_params.yaml is accepted by the HF and BB argument parsers.

    sys.argv[0] is temporarily set to the name of the script being checked,
    due to how the parsers' error messages are shown. It is restored in a
    ``finally`` so that a parser raising (including SystemExit) does not leave
    the process with the wrong script name.
    """
    sim_params = utils.load_sim_params(os.path.join(sim_dir, "sim_params.yaml"))
    main_script_name = sys.argv[0]
    try:
        # check hf
        sys.argv[0] = "hf_sim.py"
        command_template, add_args = hf_gen_command_template(
            sim_params, list(HPC)[0].name, seed
        )
        run_command = gen_args_cmd(
            ProcessType.HF.command_template, command_template, add_args
        )
        hf_args_parser(cmd=run_command)

        # check bb
        sys.argv[0] = "bb_sim.py"
        command_template, add_args = bb_gen_command_template(sim_params)
        run_command = gen_args_cmd(
            ProcessType.BB.command_template, command_template, add_args
        )
        bb_args_parser(cmd=run_command)
    finally:
        # change back, to prevent unexpected error
        sys.argv[0] = main_script_name
def main(
    args: argparse.Namespace,
    est_model: est.EstModel = None,
    logger: Logger = get_basic_logger(),
):
    """Generate, and optionally submit, the EMOD3D slurm script for one realisation.

    Reads ``sim_params.yaml`` from ``args.rel_dir``. Unless ``args.auto`` is
    set, the user is asked whether the job should also be submitted. Nothing
    further happens when ``args.srf`` names a different SRF than the one in
    the sim params.

    Otherwise the function:
      - estimates cores and wall clock time via get_lf_cores_and_wct,
      - writes the EMOD3D run parameters with a checkpoint interval,
      - writes the slurm script,
      - submits the script to the scheduler if confirmed.
    """
    params = utils.load_sim_params(os.path.join(args.rel_dir, "sim_params.yaml"))

    if args.auto:
        submit_yes = True
    else:
        submit_yes = confirm("Also submit the job for you?")

    logger.debug("params.srf_file {}".format(params.srf_file))
    # Get the srf(rup) name without extensions
    srf_name = os.path.splitext(os.path.basename(params.srf_file))[0]

    # Only handle the requested realisation (or any, when none was requested)
    if not (args.srf is None or srf_name == args.srf):
        return

    logger.debug("not set_params_only")
    sim_dir = os.path.abspath(params.sim_dir)
    lf_sim_dir = sim_struct.get_lf_dir(sim_dir)

    # default_core will be changed is user passes ncore
    nt = int(float(params.sim_duration) / float(params.dt))
    target_qconfig = get_machine_config(args.machine)
    retries = getattr(args, "retries", None)

    est_cores, est_run_time, wct = get_lf_cores_and_wct(
        est_model,
        logger,
        nt,
        params,
        sim_dir,
        srf_name,
        target_qconfig,
        args.ncore,
        retries,
    )

    binary_path = binary_version.get_lf_bin(
        params.emod3d.emod3d_version, target_qconfig["tools_dir"]
    )

    # Checkpoint interval comes from the original estimated run time (the 60.0
    # factor presumably converts its units to minutes; confirm). Capping at
    # nt // 3 guarantees at least 3 checkpoints.
    steps_from_estimate = nt / (60.0 * est_run_time) * const.CHECKPOINT_DURATION
    steps_per_checkpoint = int(min(steps_from_estimate, nt // 3))

    write_directory = args.write_directory or params.sim_dir

    set_runparams.create_run_params(
        sim_dir, steps_per_checkpoint=steps_per_checkpoint, logger=logger
    )

    header_dict = {
        "wallclock_limit": wct,
        "job_name": "emod3d.{}".format(srf_name),
        "job_description": "emod3d slurm script",
        "additional_lines": "#SBATCH --hint=nomultithread",
        "platform_specific_args": get_platform_node_requirements(est_cores),
    }
    command_template_parameters = {
        "run_command": platform_config[const.PLATFORM_CONFIG.RUN_COMMAND.name],
        "emod3d_bin": binary_path,
        "lf_sim_dir": lf_sim_dir,
    }
    body_template_params = ("run_emod3d.sl.template", {})
    script_prefix = "run_emod3d_{}".format(srf_name)

    script_file_path = write_sl_script(
        write_directory,
        params.sim_dir,
        const.ProcessType.EMOD3D,
        script_prefix,
        header_dict,
        body_template_params,
        command_template_parameters,
    )

    if not submit_yes:
        return

    submit_script_to_scheduler(
        script_file_path,
        const.ProcessType.EMOD3D.value,
        sim_struct.get_mgmt_db_queue(params.mgmt_db_location),
        params.sim_dir,
        srf_name,
        target_machine=args.machine,
        logger=logger,
    )
def create_run_params(
    sim_dir,
    srf_name=None,
    steps_per_checkpoint=None,
    logger: Logger = get_basic_logger(),
):
    """Write the EMOD3D parameter file ``<params.sim_dir>/LF/e3d.par``.

    The file starts from the version-specific ``emod3d_defaults.yaml``
    template. Values derived from ``<sim_dir>/sim_params.yaml`` are layered on
    top, then any keys from ``params.emod3d`` that already exist in the
    template are applied as overrides. Unknown override keys are logged at
    debug level and ignored.

    Parameters
    ----------
    sim_dir:
        Directory containing sim_params.yaml.
    srf_name:
        If given and it does not match the basename (without extension) of
        params.srf_file, nothing is written.
    steps_per_checkpoint:
        When truthy, used for both dump_itinc and restart_itinc.
    logger:
        Used to report ignored override keys.
    """
    params = utils.load_sim_params(os.path.join(sim_dir, "sim_params.yaml"))

    emod3d_version = params["emod3d"]["emod3d_version"]
    emod3d_filepath = binary_version.get_lf_bin(emod3d_version)

    defaults_path = os.path.join(
        platform_config[constants.PLATFORM_CONFIG.TEMPLATES_DIR.name],
        "gmsim",
        params.version,
        "emod3d_defaults.yaml",
    )
    e3d_dict = utils.load_yaml(defaults_path)

    # skip all logic if a specific srf_name is provided that is not this one
    if not (
        srf_name is None
        or srf_name == os.path.splitext(basename(params.srf_file))[0]
    ):
        return

    # EMOD3D adds a timeshift to the event rupture time. This must be
    # accounted for, as EMOD3D does not extend the sim duration by the amount
    # of time shift. flo is in Hz, so the extension is in s. Version 3.0.4 was
    # the last version of EMOD3D with a shift of 1/flo; later versions use 3/flo.
    time_shift = 1 / float(params.flo)
    if compare_versions(emod3d_version, MAXIMUM_EMOD3D_TIMESHIFT_1_VERSION) > 0:
        time_shift *= 3
    extended_sim_duration = float(params.sim_duration) + time_shift

    srf_file_basename = os.path.splitext(os.path.basename(params.srf_file))[0]
    main_dump_dir = os.path.join(params.sim_dir, "LF", "OutBin")
    user_scratch = os.path.join(params.user_root, "scratch")

    # Keys are assigned in the same order as before, so any keys that are new
    # to the template keep their position in the written file.
    e3d_dict.update(
        {
            "version": emod3d_version + "-mpi",
            "name": params.run_name,
            "n_proc": 512,
            "nx": params.nx,
            "ny": params.ny,
            "nz": params.nz,
            "h": params.hh,
            "dt": params.dt,
            "nt": str(int(round(extended_sim_duration / float(params.dt)))),
            "flo": float(params.flo),
            "faultfile": params.srf_file,
            "vmoddir": params.vel_mod_dir,
            "modellon": params.MODEL_LON,
            "modellat": params.MODEL_LAT,
            "modelrot": params.MODEL_ROT,
            "main_dump_dir": main_dump_dir,
            "seiscords": params.stat_coords,
            "user_scratch": user_scratch,
            "seisdir": os.path.join(
                user_scratch, params.run_name, srf_file_basename, "SeismoBin"
            ),
            # dtts is taken from the template defaults
            "ts_total": str(
                int(
                    extended_sim_duration
                    / (float(params.dt) * float(e3d_dict["dtts"]))
                )
            ),
            "ts_file": os.path.join(main_dump_dir, params.run_name + "_xyts.e3d"),
            "ts_out_dir": os.path.join(params.sim_dir, "LF", "TSlice", "TSFiles"),
            "restartdir": os.path.join(params.sim_dir, "LF", "Restart"),
        }
    )

    if steps_per_checkpoint:
        checkpoint_steps = int(steps_per_checkpoint)
        e3d_dict["dump_itinc"] = checkpoint_steps
        e3d_dict["restart_itinc"] = checkpoint_steps

    e3d_dict.update(
        {
            "restartname": params.run_name,
            "logdir": os.path.join(params.sim_dir, "LF", "Rlog"),
            "slipout": os.path.join(params.sim_dir, "LF", "SlipOut", "slipout-k2"),
            # other locations
            "wcc_prog_dir": emod3d_filepath,
            "vel_mod_params_dir": params.vel_mod_dir,
            "sim_dir": params.sim_dir,
            "stat_file": params.stat_file,
            "grid_file": params.GRIDFILE,
            "model_params": params.MODEL_PARAMS,
        }
    )

    # Apply user overrides, but only for keys the template already defines
    if params.emod3d:
        for override_key, override_value in params.emod3d.items():
            if override_key not in e3d_dict:
                logger.debug(
                    "{} not found as a key in e3d file. Ignoring variable. Value is {}.".format(
                        override_key, override_value
                    )
                )
                continue
            e3d_dict[override_key] = override_value

    shared.write_to_py(os.path.join(params.sim_dir, "LF", "e3d.par"), e3d_dict)