def load_model(
        dir: str,
        model_prefix: str,
        model_ext: str,
        model_cls: WCEstModel,
        logger: Logger = get_basic_logger(),
):
    """Load the most recent saved model found in a directory.

    Candidate files are those in ``dir`` whose name matches
    ``<model_prefix>*.<model_ext>``. When several files match, the path that
    sorts last is used; this is meant to select the latest model (presumably
    the file names embed a timestamp - TODO confirm against the code that
    saves the models).

    Parameters
    ----------
    dir: str
        Directory that is searched for model files
    model_prefix: str
        Leading part of the model file name
    model_ext: str
        File extension of the model files, without the leading dot
    model_cls: WCEstModel
        Class whose ``from_saved_model`` is called with the selected file
    logger: Logger
        Logger for messages to be logged against

    Returns
    -------
    Whatever ``model_cls.from_saved_model`` returns for the selected file

    Raises
    ------
    Exception
        If no file matches the pattern
    """
    file_pattern = os.path.join(dir, "{}*.{}".format(model_prefix, model_ext))
    candidates = glob.glob(file_pattern)

    if not candidates:
        message = "No valid model was found with file pattern {}".format(
            file_pattern)
        logger.log(NOPRINTCRITICAL, message)
        raise Exception(message)

    # The path that sorts last is treated as the newest model
    newest_file = max(candidates)
    loaded = model_cls.from_saved_model(newest_file)

    logger.debug("Loaded model {}".format(newest_file))
    return loaded
Example #2
0
    def initialise_scheduler(cls,
                             user: str,
                             account: str = None,
                             logger: Logger = get_basic_logger()):
        """Create the scheduler instance stored on this class.

        The scheduler implementation is chosen from the platform config:
        "slurm" gives Slurm, "pbs" gives Pbs and any other value gives Bash.

        :param user: User passed on to the scheduler
        :param account: Account passed on to the scheduler. If None, the
            default account from the platform config is used
        :param logger: Logger passed on to the scheduler
        :raises RuntimeError: If a scheduler has already been initialised
        """
        if cls.__scheduler is not None:
            raise RuntimeError("Scheduler already initialised")

        scheduler_name = platform_config[PLATFORM_CONFIG.SCHEDULER.name]
        if account is None:
            account = platform_config[PLATFORM_CONFIG.DEFAULT_ACCOUNT.name]

        # Pick the implementation first so it is only constructed in one place
        if scheduler_name == "slurm":
            scheduler_type = Slurm
        elif scheduler_name == "pbs":
            scheduler_type = Pbs
        else:
            scheduler_type = Bash

        cls.__scheduler = scheduler_type(user=user,
                                         account=account,
                                         current_machine=host,
                                         logger=logger)
        cls.__scheduler.logger.debug("Scheduler initialised")
Example #3
0
def set_wct(est_run_time, ncores, auto=False, logger=get_basic_logger()):
    """Determine the wall clock time to use for a job.

    The estimate is always logged. In auto mode the estimate is used
    directly; otherwise the user is asked whether to use it and, if they
    decline, is prompted for a value via get_input_wc.

    :param est_run_time: Estimated run time, passed to the estimate_wct helpers
    :param ncores: Number of cores, only used in the log message
    :param auto: If True nothing is asked and messages are logged at DEBUG
        level instead of INFO
    :param logger: Logger for messages to be logged against
    :return: The wall clock time, either from est.get_wct or the user input
        converted to a string
    """
    import estimation.estimate_wct as est

    level = DEBUG if auto else INFO
    logger.log(
        level,
        "Estimated time: {} with {} number of cores".format(
            est.convert_to_wct(est_run_time), ncores),
    )

    if auto:
        use_estimation = True
    else:
        print("Use the estimated wall clock time? (Minimum of 5 mins, "
              "otherwise adds a 50% overestimation to ensure "
              "the job completes)")
        use_estimation = show_yes_no_question()

    if use_estimation:
        logger.debug("Using generated estimation.")
        wct = est.get_wct(est_run_time)
    else:
        logger.debug("Using user determined wct value.")
        wct = str(get_input_wc())

    logger.log(level, "WCT set to: {}".format(wct))
    return wct
Example #4
0
def main(args, logger: Logger = get_basic_logger()):
    """Write the merge_ts submission script for a simulation and optionally submit it.

    The simulation parameters are read from "sim_params.yaml" inside
    args.rel_dir. If args.srf is given and does not match the srf name of the
    simulation, nothing is written or submitted.

    :param args: Parsed arguments; rel_dir, auto, srf, write_directory and
        machine are read
    :param logger: Logger passed on when the script is submitted
    """
    params = utils.load_sim_params(os.path.join(args.rel_dir, "sim_params.yaml"))
    sim_dir = params.sim_dir
    mgmt_db_loc = params.mgmt_db_location

    # Only ask the user when not running in auto mode
    if args.auto:
        submit_yes = True
    else:
        submit_yes = confirm("Also submit the job for you?")

    # srf (rupture) name without directory and extension
    srf_name = os.path.splitext(os.path.basename(params.srf_file))[0]

    # When a specific srf is requested, skip every other one
    if args.srf is not None and srf_name != args.srf:
        return

    write_directory = args.write_directory or sim_dir

    header_dict = {
        "platform_specific_args": get_platform_node_requirements(
            platform_config[const.PLATFORM_CONFIG.MERGE_TS_DEFAULT_NCORES.name]
        ),
        "wallclock_limit": default_run_time_merge_ts,
        "job_name": "merge_ts.{}".format(srf_name),
        "job_description": "post emod3d: merge_ts",
        "additional_lines": "###SBATCH -C avx",
    }

    command_template_parameters = {
        "run_command": platform_config[const.PLATFORM_CONFIG.RUN_COMMAND.name],
        "merge_ts_path": binary_version.get_unversioned_bin(
            "merge_tsP3_par", get_machine_config(args.machine)["tools_dir"]
        ),
    }

    # Template name plus the parameters for the body of the script
    body_template_params = (
        "{}.sl.template".format(merge_ts_name_prefix),
        {"lf_sim_dir": os.path.join(sim_dir, "LF")},
    )

    script_file_path = write_sl_script(
        write_directory,
        sim_dir,
        const.ProcessType.merge_ts,
        "{}_{}".format(merge_ts_name_prefix, srf_name),
        header_dict,
        body_template_params,
        command_template_parameters,
    )

    if not submit_yes:
        return

    submit_script_to_scheduler(
        script_file_path,
        const.ProcessType.merge_ts.value,
        sim_struct.get_mgmt_db_queue(mgmt_db_loc),
        sim_dir,
        srf_name,
        target_machine=args.machine,
        logger=logger,
    )
def load_scaler(dir: str,
                scaler_prefix: str,
                logger: Logger = get_basic_logger()):
    """Load the most recent pickled scaler found in a directory.

    Candidate files are those in ``dir`` whose name matches
    ``<scaler_prefix>*.pickle``. When several files match, the path that
    sorts last is used.

    Parameters
    ----------
    dir: str
        Directory that is searched for scaler files
    scaler_prefix: str
        Leading part of the scaler file name
    logger: Logger
        Logger for messages to be logged against

    Returns
    -------
    The object unpickled from the selected file

    Raises
    ------
    Exception
        If no file matches the pattern, or the selected path is not a file
    """
    file_pattern = os.path.join(dir, "{}{}".format(scaler_prefix, "*.pickle"))
    scaler_files = glob.glob(file_pattern)

    if not scaler_files:
        # The message names the scaler (not the model) so the log points at
        # the file that is actually missing
        message = "No valid scaler was found with file pattern {}".format(
            file_pattern)
        logger.log(NOPRINTCRITICAL, message)
        raise Exception(message)

    # The path that sorts last is treated as the latest scaler
    scaler_file = max(scaler_files)

    # glob also matches directories, so make sure a regular file was selected
    if not os.path.isfile(scaler_file):
        message = "Scaler path {} is not a file".format(scaler_file)
        logger.log(NOPRINTCRITICAL, message)
        raise Exception(message)

    # NOTE(review): pickle.load executes arbitrary code from the file; the
    # scaler directory has to be trusted.
    with open(scaler_file, "rb") as f:
        scaler = pickle.load(f)

    logger.debug("Loaded scaler {}".format(scaler_file))
    return scaler
def submit_script_to_scheduler(
        script: str,
        proc_type: int,
        queue_folder: str,
        sim_dir: str,
        run_name: str,
        target_machine: str = None,
        logger: Logger = get_basic_logger(),
):
    """
    Submits the script through the scheduler and queues a management db update
    that marks the task as queued with the returned job id.
    Errors raised while submitting are deliberately not caught here, so that
    broken runs get immediate attention.
    :param script: The location of the script to be run
    :param proc_type: The process type of the job being run
    :param queue_folder: Where the folder for database updates is
    :param sim_dir: Passed to the scheduler together with the script
    :param run_name: The name of the realisation
    :param target_machine: Passed to the scheduler as the third argument of
        submit_job
    :param logger: Logger passed on when adding the update to the queue
    :return: None
    """
    scheduler = Scheduler.get_scheduler()
    submitted_job_id = scheduler.submit_job(sim_dir, script, target_machine)

    add_to_queue(
        queue_folder,
        run_name,
        proc_type,
        const.Status.queued.value,
        job_id=submitted_job_id,
        logger=logger,
    )
def calculate_unperturbated_empiricals(
    default_vs30,
    extended_period,
    fsf,
    im_config,
    n_processes,
    sim_root,
    empirical_im_logger: Logger = get_basic_logger(),
):
    """Run the empirical IM calculation for every event in a fault selection file.

    For each event, the event name itself is used if its count in the fault
    selection file is 1, otherwise the name of its first realisation. The
    tasks built by create_event_tasks are then run with calculate_empirical
    in a process pool.

    :param default_vs30: Passed on to create_event_tasks
    :param extended_period: Passed on to create_event_tasks
    :param fsf: Path of the fault selection file
    :param im_config: Passed on to create_event_tasks
    :param n_processes: Upper limit for the number of worker processes
    :param sim_root: Passed on to create_event_tasks
    :param empirical_im_logger: Logger for messages to be logged against
    """
    events = load_fault_selection_file(fsf)
    empirical_im_logger.debug(
        f"Loaded {len(events)} events from the fault selection file"
    )
    events = [
        name if count == 1 else get_realisation_name(name, 1)
        for name, count in events.items()
    ]
    tasks = create_event_tasks(
        events, sim_root, im_config, default_vs30, extended_period, empirical_im_logger
    )

    # A pool cannot be created with zero processes, so there is nothing to do
    # when no tasks were generated
    n_tasks = len(tasks)
    if n_tasks == 0:
        empirical_im_logger.debug("No empirical im tasks to run")
        return

    # The context manager makes sure the worker processes are cleaned up,
    # also if one of the calculations raises
    with Pool(min(n_processes, n_tasks)) as pool:
        empirical_im_logger.debug("Running empirical im calculations")
        pool.starmap(calculate_empirical, tasks)
    empirical_im_logger.debug("Empirical ims calculated")
Example #8
0
 def _update_entry(
     self, cur: sql.Cursor, entry: SchedulerTask, logger: Logger = get_basic_logger()
 ):
     """Updates all fields that have a value for the specific entry

     One of three UPDATE statements is run against the state table:
     - entry has status queued: job id and status are set
     - entry has a job id: only the status is set, and only for the row
       with that job id
     - otherwise: a warning is logged and the status is set without
       matching on the job id

     Every statement only touches rows of the entry's run_name and
     proc_type whose current status is lower than the new status.
     If entry.error is set, a row is also inserted into the error table.

     :param cur: Cursor the statements are executed on
     :param entry: The update to apply
     :param logger: Logger for messages to be logged against
     """
     if entry.status == const.Status.queued.value:
         logger.debug(
             "Got entry {} with status queued. Setting status and job id in the db".format(
                 entry
             )
         )
         # Column names come from the class attributes, values are bound parameters
         cur.execute(
             "UPDATE state SET {} = ?, {} = ?, last_modified = strftime('%s','now') "
             "WHERE run_name = ? AND proc_type = ? and status < ?".format(
                 self.col_job_id, self.col_status
             ),
             (
                 entry.job_id,
                 entry.status,
                 entry.run_name,
                 entry.proc_type,
                 entry.status,
             ),
         )
     elif entry.job_id is not None:
         # Status change for a known job: the job id has to match as well
         cur.execute(
             "UPDATE state SET status = ?, last_modified = strftime('%s','now') "
             "WHERE run_name = ? AND proc_type = ? and status < ? and job_id = ?",
             (
                 entry.status,
                 entry.run_name,
                 entry.proc_type,
                 entry.status,
                 entry.job_id,
             ),
         )
     else:
         logger.warning(
             "Recieved entry {}, status is more than created but the job_id is not set.".format(
                 entry
             )
         )
         # No job id available, so only run_name and proc_type identify the row
         cur.execute(
             "UPDATE state SET status = ?, last_modified = strftime('%s','now') "
             "WHERE run_name = ? AND proc_type = ? and status < ?",
             (entry.status, entry.run_name, entry.proc_type, entry.status),
         )
     # rowcount is read directly after the UPDATE above, before any other statement
     if cur.rowcount > 1:
         logger.warning(
             "Last database update caused {} entries to be updated".format(
                 cur.rowcount
             )
         )
     if entry.error is not None:
         # The task id is looked up from the state table by proc_type and run_name
         cur.execute(
             """INSERT INTO error (task_id, error)
               VALUES (
               (SELECT id from state WHERE proc_type = ? AND run_name = ?), ?)""",
             (entry.proc_type, entry.run_name, entry.error),
         )
Example #9
0
    def predict(
        self,
        X_nn: np.ndarray,
        X_svr: np.ndarray,
        n_cores: np.ndarray,
        default_n_cores: int,
        logger: Logger = get_basic_logger(),
    ):
        """Attempt to use the NN model for estimation, however if input data
        is out of bounds, use the SVR model

        An entry counts as out of bounds if any of its features is flagged by
        the NN model's out of bounds mask.

        Parameters
        ----------
        X_nn: array of floats, shape [number of entries, number of features]
            Input data for NN, last column has to be the number of cores
        X_svr: array of floats, shape [number of entries, number of features]
            Input data for SVR
        n_cores: array of integers
            The non-normalised number of cores (i.e. actual number of
            physical cores to estimate for)
        default_n_cores: int
            The default number of cores for the process type that is being estimated.
        logger: Logger
            Logger for messages to be logged against

        Returns
        -------
        array of floats with one estimate per entry
        """
        assert X_nn.shape[0] == X_svr.shape[0]

        out_bound_mask = np.any(self.nn_model.get_out_of_bounds_mask(X_nn), axis=1)
        in_bound_mask = ~out_bound_mask

        # Everything is within bounds, the NN model can be used on its own
        if np.all(in_bound_mask):
            return self.nn_model.predict(X_nn, warning=False, logger=logger)

        # Nothing is within bounds, only the SVR model is used
        # NOTE(review): unlike the mixed case below, this result is not scaled
        # by default_n_cores / n_cores - confirm that this is intended.
        if not np.any(in_bound_mask):
            logger.debug(
                "The entry is out of bounds. The SVR models will be "
                "used for estimation."
            )
            return self.svr_model.predict(X_svr, logger=logger)

        # Mixed case: estimate the in bounds entries using NN.
        # The builtin float is used as dtype, the np.float alias no longer
        # exists in current NumPy versions.
        results = np.ones(X_nn.shape[0], dtype=float) * np.nan
        results[in_bound_mask] = self.nn_model.predict(
            X_nn[in_bound_mask, :], warning=False, logger=logger
        )

        # Estimate out of bounds using SVR
        logger.debug(
            "Some entries are out of bounds, these will be "
            "estimated using the SVR model."
        )
        results[out_bound_mask] = (
            self.svr_model.predict(X_svr[out_bound_mask, :], logger=logger)
            * default_n_cores
        ) / n_cores[out_bound_mask]

        return results
Example #10
0
def estimate_HF_chours(
        data: np.ndarray,
        model: Union[str, EstModel],
        scale_ncores: bool,
        node_time_th_factor: float = 1.0,
        model_type: const.EstModelType = DEFAULT_MODEL_TYPE,
        logger: Logger = get_basic_logger(),
):
    """Make bulk HF estimations, requires data to be in the correct
    order (see below).

    The input array is not modified; all adjustments are made on a copy.

    Params
    ------
    data: np.ndarray of int, float
        Input data for the model in order fd_count, nsub_stoch, nt, n_cores
        Has to have shape [-1, 4]
    model: str or EstModel
        Passed on to estimate
    scale_ncores: bool
        If True then the number of cores is adjusted (via scale_core_hours)
        when the estimated run time of any entry exceeds
        node_time_th_factor * n_cores / PHYSICAL_NCORES_PER_NODE
    node_time_th_factor: float
        Node time threshold factor in hours, does nothing if scale_ncores is not set
    model_type: const.EstModelType
        Passed on to estimate
    logger: Logger
        Logger for messages to be logged against

    Returns
    -------
    core_hours: np.ndarray of floats
        Estimated number of core hours
    run_time: np.ndarray of floats
        Estimated run time (hours)
    n_cores: np.ndarray of ints
        The number of cores to use, returns the argument n_cores
        if scale_ncores is not set. Otherwise returns the updated ncores.

    Raises
    ------
    Exception
        If data does not have 4 columns
    """
    if data.shape[1] != 4:
        raise Exception("Invalid input data, has to 4 columns. "
                        "One for each feature.")

    # Work on a copy, otherwise the core count column of the caller's array
    # would be overwritten (and divided again on every further call)
    data = data.copy()

    hyperthreading_factor = 2.0 if const.ProcessType.HF.is_hyperth else 1.0

    # Adjust the number of cores to estimate physical core hours
    data[:, -1] = data[:, -1] / hyperthreading_factor
    core_hours = estimate(
        data,
        model,
        model_type,
        platform_config[const.PLATFORM_CONFIG.HF_DEFAULT_NCORES.name] /
        hyperthreading_factor,
        logger=logger,
    )

    wct = core_hours / data[:, -1]
    exceeds_threshold = wct > (node_time_th_factor * data[:, -1] /
                               PHYSICAL_NCORES_PER_NODE)
    if scale_ncores and np.any(exceeds_threshold):
        core_hours, wct, data[:, -1] = scale_core_hours(
            core_hours, data, node_time_th_factor)

    # Convert back from physical cores to the core count the caller passed in
    return core_hours, wct, data[:, -1] * hyperthreading_factor
Example #11
0
    def predict(self, X: np.ndarray, logger: Logger = get_basic_logger()):
        """Run the prediction of the wrapped model and return it flattened

        Raises an Exception if the model has not been trained.
        For full doc see WCEstModel.predict
        """
        if self.is_trained:
            prediction = self._model.predict(X)
            return prediction.reshape(-1)

        logger.log(
            NOPRINTCRITICAL, "There was an attempt to use an untrained model"
        )
        raise Exception("This model has not been trained!")
Example #12
0
def test_update_live_db(mgmt_db):
    """A live update of a single task ends up as the status in the state table"""
    task = SchedulerTask(TEST_RUN_NAME, TEST_PROC[0], TEST_STATUS[0], None,
                         None)
    mgmt_db.update_entries_live([task], get_basic_logger())

    # Read back the status column for the process type that was updated
    rows = get_rows(
        mgmt_db.db_file,
        "state",
        "proc_type",
        TEST_PROC[0],
        selected_col="status",
    )
    stored_status = rows[0][0]
    assert stored_status == TEST_STATUS[0]

    mgmt_db.close_conn()
Example #13
0
def est_HF_chours_single(
        fd_count: int,
        nsub_stoch: float,
        nt: int,
        n_logical_cores: int,
        model: Union[str, EstModel],
        scale_ncores: bool,
        node_time_th_factor: float = 1.0,
        model_type: const.EstModelType = DEFAULT_MODEL_TYPE,
        logger: Logger = get_basic_logger(),
):
    """Convenience wrapper around estimate_HF_chours for one estimation

    The features are packed into a single row in the order
    fd_count, nsub_stoch, nt, n_logical_cores. This order has to match the
    one used for training, so if the model inputs ever change this function
    has to be updated as well.

    Params
    ------
    fd_count, nsub_stoch, nt, n_logical_cores: int, float
        Input features for the model
    model, scale_ncores, node_time_th_factor, model_type, logger:
        Passed on to estimate_HF_chours unchanged

    Returns
    -------
    core_hours: float
        Estimated number of core hours
    run_time: float
        Estimated run time (hours)
    n_cpus: int
        The number of cores returned by estimate_HF_chours
    """
    # One row, four columns; every feature is converted to float
    features = np.array([[
        float(fd_count),
        float(nsub_stoch),
        float(nt),
        float(n_logical_cores),
    ]])

    estimates = estimate_HF_chours(
        features,
        model,
        scale_ncores,
        node_time_th_factor=node_time_th_factor,
        model_type=model_type,
        logger=logger,
    )
    core_hours, run_time, n_cpus = estimates

    # Unpack the single entry of each result array
    return core_hours[0], run_time[0], int(n_cpus[0])
def add_to_queue(
        queue_folder: str,
        run_name: str,
        proc_type: int,
        status: int,
        job_id: int = None,
        error: str = None,
        logger: Logger = get_basic_logger(),
):
    """Adds an update entry to the queue

    The entry is written as a json file into queue_folder. The file name
    consists of the current time (formatted with const.QUEUE_DATE_FORMAT),
    the run name and the process type, separated by dots.

    :param queue_folder: The folder the update file is written to
    :param run_name: The name of the realisation
    :param proc_type: The process type of the task
    :param status: The status the task is to be set to
    :param job_id: The scheduler job id, if there is one
    :param error: The error message, if there is one
    :param logger: Logger for messages to be logged against
    :raises Exception: If an update file with the same name already exists
    """
    logger.debug(
        "Adding task to the queue. Realisation: {}, process type: {}, status: {}, job_id: {}, error: {}"
        .format(run_name, proc_type, status, job_id, error))
    filename = os.path.join(
        queue_folder,
        "{}.{}.{}".format(datetime.now().strftime(const.QUEUE_DATE_FORMAT),
                          run_name, proc_type),
    )

    # Mode "x" creates the file and fails if it already exists. Checking for
    # existence separately would leave a gap in which another process could
    # create the same file, which would then be overwritten.
    try:
        update_file = open(filename, "x")
    except FileExistsError:
        message = (
            "An update with the name {} already exists. This should never happen. Quitting!"
            .format(os.path.basename(filename)))
        logger.log(NOPRINTCRITICAL, message)
        raise Exception(message) from None

    logger.debug("Writing update file to {}".format(filename))

    with update_file as f:
        json.dump(
            {
                MgmtDB.col_run_name: run_name,
                MgmtDB.col_proc_type: proc_type,
                MgmtDB.col_status: status,
                MgmtDB.col_job_id: job_id,
                "error": error,
            },
            f,
        )

    if not os.path.isfile(filename):
        logger.critical("File {} did not successfully write".format(filename))
    else:
        logger.debug("Successfully wrote task update file")
def aggregate_simulation_empirical_im_permutations(
    fsf, n_processes, sim_root, version, logger: Logger = get_basic_logger()):
    """Run agg_emp_perms for every event in a fault selection file.

    For each event, the event name itself is used if its count in the fault
    selection file is 1, otherwise the name of its first realisation.
    The workers receive the name of a realisation logger rather than the
    logger object itself.

    :param fsf: Path of the fault selection file
    :param n_processes: Number of worker processes
    :param sim_root: Passed to get_empirical_dir to find the empirical
        directory of each event
    :param version: Passed on to agg_emp_perms
    :param logger: Logger for messages, and parent for the realisation loggers
    """
    events = load_fault_selection_file(fsf)
    logger.debug(f"Loaded {len(events)} events from the fault selection file")
    events = [
        name if count == 1 else get_realisation_name(name, 1)
        for name, count in events.items()
    ]
    tasks = [(
        pathlib.Path(get_empirical_dir(sim_root, event)),
        event,
        version,
        get_realisation_logger(logger, event).name,
    ) for event in events]

    # The context manager makes sure the worker processes are cleaned up,
    # also if one of the aggregations raises
    with Pool(n_processes) as worker_pool:
        worker_pool.starmap(agg_emp_perms, tasks)
def check_mgmt_queue(queue_entries: List[str],
                     run_name: str,
                     proc_type: int,
                     logger=get_basic_logger()):
    """Checks if any queue entry belongs to the given run_name and process type.

    Every entry is expected to consist of exactly three dot separated parts,
    of which the second is the run name and the third the process type.
    Returns True on the first matching entry, False if there is none.
    """
    logger.debug(
        "Checking to see if the realisation {} has a process of type {} in updates folder"
        .format(run_name, proc_type))

    # Entries are strings, so compare against the string form of the type
    wanted_proc_type = str(proc_type)
    for entry in queue_entries:
        logger.debug("Checking against {}".format(entry))
        _, entry_run_name, entry_proc_type = entry.split(".")
        if entry_run_name != run_name or entry_proc_type != wanted_proc_type:
            continue
        logger.debug("It's a match, returning True")
        return True

    logger.debug("No match found")
    return False
Example #17
0
def agg_emp_perms(
        empirical_dir: pathlib.Path,
        realisation: str,
        version: str,
        aggregation_logger: Union[Logger, str] = get_basic_logger(),
):
    """
    Aggregates the empirical IM files of one realisation.
    The csv files of the event in empirical_dir are grouped by the last underscore separated part of their file
    name, aggregation groups are calculated from that and each group is aggregated with aggregate_data.
    :param empirical_dir: The directory of the event or fault realisation to generate aggregation files for
    :param realisation: The name of the realisation being aggregated
    :param version: The version of the simulation. e.g. the perturbation version
    :param aggregation_logger: The logger object or name of required logger object to be used for logging
    """
    if isinstance(aggregation_logger, str):
        aggregation_logger = get_logger(aggregation_logger)

    event = get_fault_from_realisation(realisation)

    # Group the files by the last part of their stem
    ims = {}
    for im_file in empirical_dir.glob(f"{event}_*.csv"):
        im_name = im_file.stem.split("_")[-1]
        ims.setdefault(im_name, []).append(im_file)

    if not ims:
        aggregation_logger.error("No empirical IM files found, exiting")
        return

    n_files = sum(len(files) for files in ims.values())
    aggregation_logger.debug(f"Found {n_files} empirical IM files")

    groups = calculate_aggregation_groups(ims)
    aggregation_logger.debug(f"Created {len(groups)} aggregated IM groups")

    for group in groups:
        identifier = get_agg_identifier(event, group)
        aggregation_logger.debug(
            f"The identifier {identifier} is being used for the IM group {group}"
        )
        aggregate_data(group, empirical_dir, identifier, event, version)
        output_file = empirical_dir / f"{identifier}.csv"
        aggregation_logger.debug(f"Saved empirical IMs to {output_file}")
def install_bb(
        stat_file,
        root_dict,
        v1d_dir,
        v1d_full_path=None,
        site_v1d_dir=None,
        hf_stat_vs_ref=None,
        logger: Logger = get_basic_logger(),
):
    """Sets the HF 1D velocity model (and site specific info) in root_dict.

    The source of the values depends on the arguments:
    - v1d_full_path given: it is used as the 1D velocity model
    - site_v1d_dir and hf_stat_vs_ref given: site specific values are
      determined from them
    - otherwise the user is asked, either for site specific info or for a
      1D velocity model from v1d_dir

    root_dict is updated in place, nothing is returned.
    """
    shared.show_horizontal_line(c="*")
    logger.info(" " * 37 + "EMOD3D HF/BB Preparation Ver.slurm")
    shared.show_horizontal_line(c="*")

    # TODO: most of this logic is not required and should be removed, it now
    # depends on gmsim_version_template. Setting root_dict["bb"]["site_specific"]
    # was removed since the master version of bb_sim does not take it as an argument.
    if v1d_full_path is not None:
        root_dict["hf"][HF_VEL_MOD_1D] = v1d_full_path
        return

    # TODO:add in logic for site specific as well, if the user provided as args
    if site_v1d_dir is not None and hf_stat_vs_ref is not None:
        site_vel_mod, site_vs_ref = shared.get_site_specific_path(
            os.path.dirname(stat_file),
            hf_stat_vs_ref=hf_stat_vs_ref,
            v1d_mod_dir=site_v1d_dir,
            logger=logger,
        )
    elif q_site_specific():
        site_vel_mod, site_vs_ref = shared.get_site_specific_path(
            os.path.dirname(stat_file), logger=logger)
    else:
        # Not site specific: only the selected model is needed
        _, v_mod_1d_selected = q_1d_velocity_model(v1d_dir)
        root_dict["hf"][HF_VEL_MOD_1D] = v_mod_1d_selected
        return

    # Both site specific branches store their results in the same way
    root_dict["hf"][HF_VEL_MOD_1D] = site_vel_mod
    root_dict["hf_stat_vs_ref"] = site_vs_ref
Example #19
0
def get_queue_entry(
    entry_file: str, queue_logger: Logger = qclogging.get_basic_logger()
):
    """Reads a queue update file and turns it into a SchedulerTask.

    The run name is taken from the second dot separated part of the file
    name, all other fields from the json content of the file.

    :param entry_file: Path of the update file
    :param queue_logger: Logger for messages to be logged against
    :return: The SchedulerTask, or None if the file is not valid json
    """
    try:
        with open(entry_file, "r") as f:
            contents = json.load(f)
    except json.JSONDecodeError:
        queue_logger.error(
            "Failed to decode the file {} as json. Check that this is "
            "valid json. Ignored!".format(entry_file)
        )
        return None

    file_name_parts = os.path.basename(entry_file).split(".")
    return SchedulerTask(
        run_name=file_name_parts[1],
        proc_type=contents[MgmtDB.col_proc_type],
        status=contents[MgmtDB.col_status],
        job_id=contents[MgmtDB.col_job_id],
        # The error key is optional
        error=contents.get("error"),
    )
Example #20
0
    def predict(
        self, X: np.ndarray, warning: bool = True, logger: Logger = get_basic_logger()
    ):
        """Run the prediction of the wrapped model and return it flattened

        Raises an Exception if the model has not been trained. If warning is
        set, a warning is printed when any input value is out of bounds.
        For full doc see WCEstModel.predict
        """
        if not self.is_trained:
            logger.log(
                NOPRINTCRITICAL, "There was an attempt to use an untrained model"
            )
            raise Exception("This model has not been trained!")

        has_out_of_bounds = np.any(self.get_out_of_bounds_mask(X))
        if has_out_of_bounds and warning:
            print(
                "WARNING: Some of the data specified for estimation exceeds the "
                "limits of the data the model was trained. This will result in "
                "incorrect estimation!"
            )

        prediction = self._model.predict(X)
        return prediction.reshape(-1)
Example #21
0
def get_site_specific_path(
        stat_file_path,
        hf_stat_vs_ref=None,
        v1d_mod_dir=None,
        logger: Logger = get_basic_logger(),
):
    """Determines the 1D profile directory and the HF Vsref file to use.

    :param stat_file_path: Path used to locate the defaults: the 1D
        directory is looked for next to it, the *.hfvs30ref files inside it
    :param hf_stat_vs_ref: HF Vsref file to use. If None the user has to
        select one of the *.hfvs30ref files found in stat_file_path
    :param v1d_mod_dir: Directory of the 1D profiles. If None, the "1D"
        directory next to stat_file_path is used
    :param logger: Logger for messages to be logged against
    :return: Tuple of the 1D profile directory and the HF Vsref file

    The process is exited via sys.exit() if the 1D directory does not exist
    or no HF Vsref file is found.
    """
    show_horizontal_line()
    logger.info("Auto-detecting site-specific info")
    show_horizontal_line()
    logger.info("- Station file path: %s" % stat_file_path)

    if v1d_mod_dir is None:
        v_mod_1d_path = os.path.join(os.path.dirname(stat_file_path), "1D")
    else:
        v_mod_1d_path = v1d_mod_dir

    if not os.path.exists(v_mod_1d_path):
        logger.critical("Error: No such path exists: {}".format(v_mod_1d_path))
        sys.exit()
    logger.info("- 1D profiles found at {}".format(v_mod_1d_path))

    # Nothing to select if the Vsref file was given
    if hf_stat_vs_ref is not None:
        return v_mod_1d_path, hf_stat_vs_ref

    vs_ref_candidates = sorted(
        glob.glob(os.path.join(stat_file_path, "*.hfvs30ref")))
    if not vs_ref_candidates:
        logger.critical("Error: No HF Vsref file was found at {}".format(
            stat_file_path))
        sys.exit()

    show_horizontal_line()
    logger.info("Select one of HF Vsref files")
    show_horizontal_line()
    chosen_vs_ref = show_multiple_choice(vs_ref_candidates)
    logger.info(
        " - HF Vsref tp be used: {}".format(chosen_vs_ref))

    return v_mod_1d_path, chosen_vs_ref
Example #22
0
def load_full_model(
        dir: str,
        model_type: const.EstModelType = DEFAULT_MODEL_TYPE,
        logger: Logger = get_basic_logger(),
):
    """Loads the estimation model(s) of the given type together with their scaler(s).

    Returns an EstModel built from (NN model, NN scaler, SVR model,
    SVR scaler); the parts that are not required for the model type are None.
    Returns None if model_type is none of NN, SVR or NN_SVR.
    """
    model_types = const.EstModelType
    use_nn = model_type is model_types.NN or model_type is model_types.NN_SVR
    use_svr = model_type is model_types.SVR or model_type is model_types.NN_SVR

    if not (use_nn or use_svr):
        return None

    nn_model = nn_scaler = svr_model = svr_scaler = None

    # NN parts are loaded first, then the SVR parts
    if use_nn:
        nn_model = load_model(dir, const.EST_MODEL_NN_PREFIX, "h5",
                              NNWcEstModel, logger)
        nn_scaler = load_scaler(dir, SCALER_PREFIX.format("NN"), logger)

    if use_svr:
        svr_model = load_model(dir, const.EST_MODEL_SVR_PREFIX, "pickle",
                               SVRModel, logger)
        svr_scaler = load_scaler(dir, SCALER_PREFIX.format("SVR"), logger)

    return EstModel(nn_model, nn_scaler, svr_model, svr_scaler)
Example #23
0
def parse_config_file(config_file_location: str,
                      logger: Logger = qclogging.get_basic_logger()):
    """Takes in the location of a wrapper config file and creates the tasks to be run.
    The file maps process names to a pattern or a list of patterns. A process with the pattern ALL is run for
    everything and one with the pattern NONE is ignored. Any other pattern (ONCE is replaced by ONCE_PATTERN) is
    collected together with the processes that use it.
    :param config_file_location: The location of the config file
    :param logger: The logger object used to record messages
    :return: A tuple containing the list of tasks to be run for everything and the (pattern, task list) items which
    state which tasks are to be run for which pattern
    """
    config = load_yaml(config_file_location)

    tasks_to_run_for_all = []
    tasks_with_pattern_match = {}

    for proc_name, pattern in config.items():
        proc = const.ProcessType.from_str(proc_name)
        if pattern == ALL:
            tasks_to_run_for_all.append(proc)
            continue
        if pattern == NONE:
            continue

        # A single pattern is handled like a list with one entry
        subpatterns = [pattern] if isinstance(pattern, str) else pattern
        for subpattern in subpatterns:
            key = ONCE_PATTERN if subpattern == ONCE else subpattern
            tasks_with_pattern_match.setdefault(key, []).append(proc)

    logger.info("Master script will run {}".format(tasks_to_run_for_all))
    for pattern, tasks in tasks_with_pattern_match.items():
        logger.info("Pattern {} will run tasks {}".format(pattern, tasks))

    return tasks_to_run_for_all, tasks_with_pattern_match.items()
Example #24
0
    def _check_dependancy_met(self, task, logger=get_basic_logger()):
        """Returns True if the given (process, run name) task has no remaining dependencies.

        The completed process types of the run are read from the database
        and handed to the process to work out what is still missing.
        """
        proc_value, run_name = task
        process = Process(proc_value)

        with connect_db_ctx(self._db_file) as cur:
            completed_rows = cur.execute(
                """SELECT proc_type 
                          FROM status_enum, state 
                          WHERE state.status = status_enum.id
                           AND run_name = (?)
                           AND status_enum.state = 'completed'""",
                (run_name,),
            ).fetchall()

        logger.debug(
            "Considering task {} for realisation {}. Completed tasks as follows: {}".format(
                process, run_name, completed_rows
            )
        )

        # Each row holds the process type as its only column
        completed_procs = [const.ProcessType(row[0]) for row in completed_rows]
        remaining_deps = process.get_remaining_dependencies(completed_procs)
        logger.debug("{} has remaining deps: {}".format(task, remaining_deps))

        return len(remaining_deps) == 0
def parse_config_file(config_file_location: str,
                      logger: Logger = get_basic_logger()):
    """Reads a wrapper config file and groups the listed tasks by the pattern they should run for.
    Each entry of the config maps a process name (resolved with const.ProcessType.from_str) to either a single
    pattern or a list of patterns.
    The keyword ALL adds the process to the run-for-all list, NONE is ignored, and ONCE is replaced by ONCE_PATTERN.
    Any other value is used as the pattern as is.
    :param config_file_location: The location of the config file
    :param logger: The logger object used to record messages
    :return: A pair: the list of processes to be run for everything, and the items view (pattern, process list) of
    the dictionary grouping the remaining processes by pattern
    """
    config = load_yaml(config_file_location)

    run_for_all = []
    run_for_pattern = {}

    for proc_name, patterns in config.items():
        proc = const.ProcessType.from_str(proc_name)
        # A lone pattern is treated as a one element list
        pattern_list = patterns if isinstance(patterns, list) else [patterns]
        for pattern in pattern_list:
            if pattern == ALL:
                run_for_all.append(proc)
                continue
            if pattern == NONE:
                continue
            key = ONCE_PATTERN if pattern == ONCE else pattern
            run_for_pattern.setdefault(key, []).append(proc)

    logger.info("Master script will run {}".format(run_for_all))
    for key, procs in run_for_pattern.items():
        logger.info("Pattern {} will run tasks {}".format(key, procs))

    return run_for_all, run_for_pattern.items()
Example #26
0
    def update_entries_live(
        self,
        entries: List[SchedulerTask],
        retry_max: int,
        logger: Logger = get_basic_logger(),
    ):
        """Updates the specified entries in the db. Leaves the connection open,
        so this should only be used when continuously updating entries.

        All entries are applied within one transaction, which is committed
        once every entry has been processed and rolled back on a db error.

        :param entries: The tasks whose state is to be written to the db
        :param retry_max: A failed task gets a new task inserted for it while
            its retry count (as given by get_retries) is below this value
        :param logger: The logger object used to record messages
        :return: True if the changes were committed, False if a db error
            occurred
        """
        # Bound before the try block so that the except/finally clauses never
        # reference an unbound name when the failure happens before the cursor
        # exists or before the first entry is reached
        cur = None
        entry = None
        try:
            if self._conn is None:
                logger.info("Aquiring db connection.")
                self._conn = sql.connect(self._db_file)
            logger.debug("Getting db cursor")

            cur = self._conn.cursor()
            cur.execute("BEGIN")
            for entry in entries:
                process = entry.proc_type
                realisation_name = entry.run_name

                logger.debug(
                    "The status of process {} for realisation {} is being set to {}. It has slurm id {}".format(
                        entry.proc_type, entry.run_name, entry.status, entry.job_id
                    )
                )

                if entry.status == const.Status.created.value:
                    # Something has attempted to set a task to created
                    # Make a new task with created status and move to the next task
                    logger.debug("Adding new task to the db")
                    # Check that there isn't already a task with the same realisation name
                    if self._does_task_exists(cur, realisation_name, process):
                        logger.debug(
                            "task is already in progress - does not need to be readded"
                        )
                        continue
                    self._insert_task(cur, realisation_name, process)
                    logger.debug("New task added to the db, continuing to next process")
                    continue
                logger.debug("Updating task in the db")
                self._update_entry(cur, entry, logger=logger)
                logger.debug("Task successfully updated")

                if (
                    entry.status == const.Status.failed.value
                    and self.get_retries(process, realisation_name) < retry_max
                ):
                    # The task was failed. If there have been few enough other attempts at the task make another one
                    logger.debug(
                        "Task failed but is able to be retried. Adding new task to the db"
                    )
                    self._insert_task(cur, realisation_name, process)
                    logger.debug("New task added to the db")

                # fails dependant task if parent task fails
                if entry.status == const.Status.failed.value:
                    tasks = MgmtDB.find_dependant_task(cur, entry)
                    # Index based walk, as the list grows while it is being
                    # processed: the dependants of each dependant are appended
                    i = 0
                    while i < len(tasks):
                        task = tasks[i]
                        self._update_entry(cur, task, logger=logger)
                        logger.debug(
                            f"Cascading failure for {entry.run_name} - {task.proc_type}"
                        )
                        tasks.extend(MgmtDB.find_dependant_task(cur, task))
                        i += 1

        except sql.Error as ex:
            # The connection is still None if sql.connect itself failed
            if self._conn is not None:
                self._conn.rollback()
            logger.critical(
                "Failed to update entry {}, due to the exception: \n{}".format(
                    entry, ex
                )
            )
            return False
        else:
            logger.debug("Committing changes to db")
            self._conn.commit()
        finally:
            # No cursor to close if the failure happened before it was created
            if cur is not None:
                logger.debug("Closing db cursor")
                cur.close()

        return True
Example #27
0
    def get_runnable_tasks(
        self,
        allowed_rels,
        task_limit,
        update_files,
        allowed_tasks=None,
        logger=get_basic_logger(),
    ):
        """Gets all runnable tasks based on their status and their associated
        dependencies (i.e. other tasks have to be finished first)

        Tasks in the 'created' state are read from the db in pages of 100.
        Paging stops once at least task_limit runnable tasks have been
        collected, so the result may hold more than task_limit tasks.

        :param allowed_rels: SQL LIKE pattern the run name of a task has to match
        :param task_limit: Number of runnable tasks after which no further
            pages are read
        :param update_files: Names of update files that are yet to be applied.
            The second and third "." separated fields of each name are compared
            against the (proc_type, run_name) of each candidate task
        :param allowed_tasks: The process types to consider (their .value is
            used in the query), defaults to all members of const.ProcessType
        :param logger: The logger object passed on to the dependency check
        :return: A list of tuples (proc_type, run_name, retries), where
            retries is the result of get_retries for the task
        """
        if allowed_tasks is None:
            allowed_tasks = list(const.ProcessType)
        allowed_tasks = [str(task.value) for task in allowed_tasks]

        if len(allowed_tasks) == 0:
            return []

        runnable_tasks = []
        offset = 0

        # "{}__{}" is intended to be the template for a unique string for every realisation and process type pair
        # Used to compare with database entries to prevent running a task that has already been submitted, but not
        # recorded
        # Kept as a set, as it is checked once for every task read from the db
        tasks_waiting_for_updates = {
            "{}__{}".format(*(entry.split(".")[1:3])) for entry in update_files
        }

        # One "?" is already part of the queries, the rest are added here
        extra_placeholders = ",?" * (len(allowed_tasks) - 1)

        with connect_db_ctx(self._db_file) as cur:
            entries = cur.execute(
                """SELECT COUNT(*) 
                          FROM status_enum, state 
                          WHERE state.status = status_enum.id
                           AND proc_type IN (?{})
                           AND run_name LIKE (?)
                           AND status_enum.state = 'created'""".format(
                    extra_placeholders
                ),
                (*allowed_tasks, allowed_rels),
            ).fetchone()[0]

            while len(runnable_tasks) < task_limit and offset < entries:
                db_tasks = cur.execute(
                    """SELECT proc_type, run_name 
                              FROM status_enum, state 
                              WHERE state.status = status_enum.id
                               AND proc_type IN (?{})
                               AND run_name LIKE (?)
                                   AND status_enum.state = 'created'
                                   LIMIT 100 OFFSET ?""".format(
                        extra_placeholders
                    ),
                    (*allowed_tasks, allowed_rels, offset),
                ).fetchall()
                # The in-memory check comes first, so that the dependency
                # check (a db query) is skipped for tasks awaiting an update
                runnable_tasks.extend(
                    [
                        (*task, self.get_retries(*task))
                        for task in db_tasks
                        if "{}__{}".format(*task) not in tasks_waiting_for_updates
                        and self._check_dependancy_met(task, logger)
                    ]
                )
                # Has to match the LIMIT of the query above
                offset += 100

        return runnable_tasks
def install_fault(
    fault_name,
    n_rel,
    root_folder,
    version,
    stat_file_path,
    seed=HF_DEFAULT_SEED,
    extended_period=False,
    vm_perturbations=False,
    ignore_vm_perturbations=False,
    vm_qpqs_files=False,
    ignore_vm_qpqs_files=False,
    keep_dup_station=True,
    components=None,
    logger: Logger = get_basic_logger(),
):
    """Installs every realisation (SRF file) found for the given fault.

    The SRF files of the fault are located and the velocity model is validated.
    Then for each SRF file: the simulation is installed, the management db and
    fd files are created, the params yamls are written, and the resulting sim
    params are run through the HF and BB argument parsers as a check.

    :param fault_name: Name of the fault to install
    :param n_rel: Expected number of SRF files, None to skip the check
    :param root_folder: Root folder the srf/VM/runs directories are resolved from
    :param version: gmsim version, selects the template directory to load the
        root defaults from
    :param stat_file_path: Path to the station (.ll) file. The .vs30 and
        .vs30ref paths are derived from it
    :param seed: Passed to install_simulation and the HF command generation
    :param extended_period: Passed through to install_simulation
    :param vm_perturbations: Passed through to install_simulation
    :param ignore_vm_perturbations: Passed through to install_simulation
    :param vm_qpqs_files: Passed through to install_simulation
    :param ignore_vm_qpqs_files: Passed through to install_simulation
    :param keep_dup_station: Passed through to generate_fd_files
    :param components: Passed through to install_simulation
    :param logger: The logger object used to record messages
    :raises RuntimeError: If no SRF files are found, their number differs from
        n_rel, the velocity model is invalid or a stoch file is missing
    :return: None. Returns early (after logging a critical message) if
        install_simulation gives back a None params dictionary or the 'flo'
        values of the VM params and root params differ
    """

    config_dict = utils.load_yaml(
        os.path.join(
            platform_config[PLATFORM_CONFIG.TEMPLATES_DIR.name],
            "gmsim",
            version,
            ROOT_DEFAULTS_FILE_NAME,
        )
    )
    # Load variables from cybershake config

    v1d_full_path = os.path.join(
        platform_config[PLATFORM_CONFIG.VELOCITY_MODEL_DIR.name],
        "Mod-1D",
        config_dict.get("v_1d_mod"),
    )
    site_v1d_dir = config_dict.get("site_v1d_dir")
    hf_stat_vs_ref = config_dict.get("hf_stat_vs_ref")

    vs30_file_path = stat_file_path.replace(".ll", ".vs30")
    vs30ref_file_path = stat_file_path.replace(".ll", ".vs30ref")

    # this variable has to be empty
    # TODO: fix this legacy issue, very low priority
    event_name = ""

    # get all srf from source
    srf_dir = simulation_structure.get_srf_dir(root_folder, fault_name)

    # Prefer realisation SRF files, fall back to any SRF file in the directory
    list_srf = glob.glob(os.path.join(srf_dir, "*_REL*.srf"))
    if len(list_srf) == 0:
        list_srf = glob.glob(os.path.join(srf_dir, "*.srf"))

    list_srf.sort()
    if n_rel is not None and len(list_srf) != n_rel:
        message = (
            "Error: fault {} failed. Number of realisations do "
            "not match number of SRF files".format(fault_name)
        )
        logger.log(NOPRINTCRITICAL, message)
        raise RuntimeError(message)

    # The first SRF file is needed for the VM validation below, so fail with a
    # clear message instead of an IndexError if there are none
    if len(list_srf) == 0:
        message = "Error: fault {} failed. No SRF files were found in {}".format(
            fault_name, srf_dir
        )
        logger.log(NOPRINTCRITICAL, message)
        raise RuntimeError(message)

    # Get & validate velocity model directory
    vel_mod_dir = simulation_structure.get_fault_VM_dir(root_folder, fault_name)
    valid_vm, message = validate_vm.validate_vm(vel_mod_dir, srf=list_srf[0])
    if not valid_vm:
        message = "Error: VM {} failed {}".format(fault_name, message)
        logger.log(NOPRINTCRITICAL, message)
        raise RuntimeError(message)
    # Load the variables from vm_params.yaml
    vm_params_path = os.path.join(vel_mod_dir, VM_PARAMS_FILE_NAME)
    vm_params_dict = utils.load_yaml(vm_params_path)
    yes_model_params = (
        False  # statgrid should normally be already generated with Velocity Model
    )

    sim_root_dir = simulation_structure.get_runs_dir(root_folder)
    fault_yaml_path = simulation_structure.get_fault_yaml_path(sim_root_dir, fault_name)
    root_yaml_path = simulation_structure.get_root_yaml_path(sim_root_dir)
    for srf in list_srf:
        logger.info("Installing {}".format(srf))
        # try to match find the stoch with same basename
        realisation_name = os.path.splitext(os.path.basename(srf))[0]
        stoch_file_path = simulation_structure.get_stoch_path(
            root_folder, realisation_name
        )
        sim_params_file = simulation_structure.get_source_params_path(
            root_folder, realisation_name
        )

        if not os.path.isfile(stoch_file_path):
            message = "Error: Corresponding Stoch file is not found: {}".format(
                stoch_file_path
            )
            logger.log(NOPRINTCRITICAL, message)
            raise RuntimeError(message)

        # install pairs one by one to fit the new structure
        sim_dir = simulation_structure.get_sim_dir(root_folder, realisation_name)

        (root_params_dict, fault_params_dict, sim_params_dict) = install_simulation(
            version=version,
            sim_dir=sim_dir,
            rel_name=realisation_name,
            run_dir=sim_root_dir,
            vel_mod_dir=vel_mod_dir,
            srf_file=srf,
            stoch_file=stoch_file_path,
            stat_file_path=stat_file_path,
            vs30_file_path=vs30_file_path,
            vs30ref_file_path=vs30ref_file_path,
            yes_statcords=False,
            fault_yaml_path=fault_yaml_path,
            root_yaml_path=root_yaml_path,
            cybershake_root=root_folder,
            site_v1d_dir=site_v1d_dir,
            hf_stat_vs_ref=hf_stat_vs_ref,
            v1d_full_path=v1d_full_path,
            sim_params_file=sim_params_file,
            seed=seed,
            logger=logger,
            extended_period=extended_period,
            vm_perturbations=vm_perturbations,
            ignore_vm_perturbations=ignore_vm_perturbations,
            vm_qpqs_files=vm_qpqs_files,
            ignore_vm_qpqs_files=ignore_vm_qpqs_files,
            components=components,
        )

        if (
            root_params_dict is None
            or fault_params_dict is None
            or sim_params_dict is None
        ):
            # Something has gone wrong, returning without saving anything
            logger.critical("Critical Error some params dictionary are None")
            return

        # root_params_dict is known to not be None at this point
        if not isclose(vm_params_dict["flo"], root_params_dict["flo"]):
            logger.critical(
                "The parameter 'flo' does not match in the VM params and root params files. "
                "Please ensure you are installing the correct gmsim version"
            )
            return

        create_mgmt_db.create_mgmt_db(
            [], simulation_structure.get_mgmt_db(root_folder), srf_files=srf
        )
        utils.setup_dir(os.path.join(root_folder, "mgmt_db_queue"))
        root_params_dict["mgmt_db_location"] = root_folder

        # Generate the fd files, create these at the fault level
        fd_statcords, fd_statlist = generate_fd_files(
            simulation_structure.get_fault_dir(root_folder, fault_name),
            vm_params_dict,
            stat_file=stat_file_path,
            logger=logger,
            keep_dup_station=keep_dup_station,
        )

        fault_params_dict[FaultParams.stat_coords.value] = fd_statcords
        fault_params_dict[FaultParams.FD_STATLIST.value] = fd_statlist

        #     root_params_dict['hf_stat_vs_ref'] = cybershake_cfg['hf_stat_vs_ref']
        dump_all_yamls(sim_dir, root_params_dict, fault_params_dict, sim_params_dict)

        # test if the params are accepted by steps HF and BB
        sim_params = utils.load_sim_params(os.path.join(sim_dir, "sim_params.yaml"))

        # temporary change the script name, due to how error message are shown
        # Restored in the finally clause, so that sys.argv is not left modified
        # if one of the parsers raises or exits
        main_script_name = sys.argv[0]
        try:
            # check hf
            sys.argv[0] = "hf_sim.py"

            command_template, add_args = hf_gen_command_template(
                sim_params, list(HPC)[0].name, seed
            )
            run_command = gen_args_cmd(
                ProcessType.HF.command_template, command_template, add_args
            )
            hf_args_parser(cmd=run_command)

            # check bb
            sys.argv[0] = "bb_sim.py"

            command_template, add_args = bb_gen_command_template(sim_params)
            run_command = gen_args_cmd(
                ProcessType.BB.command_template, command_template, add_args
            )
            bb_args_parser(cmd=run_command)
        finally:
            # change back, to prevent unexpected error
            sys.argv[0] = main_script_name
Example #29
0
def main(
        args: argparse.Namespace,
        est_model: est.EstModel = None,
        logger: Logger = get_basic_logger(),
):
    """Writes the EMOD3D submission script for a realisation and optionally submits it.

    The sim params are loaded from args.rel_dir. Nothing further is done if
    args.srf is given and differs from the name of the SRF file of the sim params.

    :param args: The parsed arguments; rel_dir, auto, srf, machine, ncore and
        write_directory are read, as well as retries if it is present
    :param est_model: Passed on to get_lf_cores_and_wct
    :param logger: The logger object used to record messages
    """
    sim_params_path = os.path.join(args.rel_dir, "sim_params.yaml")
    params = utils.load_sim_params(sim_params_path)

    # Only ask the user if automatic submission was not requested
    if args.auto:
        submit_yes = True
    else:
        submit_yes = confirm("Also submit the job for you?")

    logger.debug("params.srf_file {}".format(params.srf_file))
    # Name of the srf (rup) file, with the directories and extension removed
    srf_file_name = os.path.basename(params.srf_file)
    srf_name = os.path.splitext(srf_file_name)[0]

    if args.srf is not None and srf_name != args.srf:
        # A different srf was asked for, nothing to do for this one
        return

    logger.debug("not set_params_only")
    sim_dir = os.path.abspath(params.sim_dir)
    lf_sim_dir = sim_struct.get_lf_dir(sim_dir)

    # Number of time steps of the simulation
    nt = int(float(params.sim_duration) / float(params.dt))

    target_qconfig = get_machine_config(args.machine)

    retries = getattr(args, "retries", None)

    est_cores, est_run_time, wct = get_lf_cores_and_wct(
        est_model,
        logger,
        nt,
        params,
        sim_dir,
        srf_name,
        target_qconfig,
        args.ncore,
        retries,
    )

    binary_path = binary_version.get_lf_bin(
        params.emod3d.emod3d_version, target_qconfig["tools_dir"])

    # Checkpoint interval based on the estimated run time, capped at a third
    # of the time steps
    steps_from_run_time = nt / (60.0 * est_run_time) * const.CHECKPOINT_DURATION
    steps_per_checkpoint = int(min(steps_from_run_time, nt // 3))

    write_directory = args.write_directory or params.sim_dir

    set_runparams.create_run_params(
        sim_dir, steps_per_checkpoint=steps_per_checkpoint, logger=logger)

    header_dict = {
        "wallclock_limit": wct,
        "job_name": "emod3d.{}".format(srf_name),
        "job_description": "emod3d slurm script",
        "additional_lines": "#SBATCH --hint=nomultithread",
        "platform_specific_args": get_platform_node_requirements(est_cores),
    }

    command_template_parameters = {
        "run_command": platform_config[const.PLATFORM_CONFIG.RUN_COMMAND.name],
        "emod3d_bin": binary_path,
        "lf_sim_dir": lf_sim_dir,
    }

    body_template_params = ("run_emod3d.sl.template", {})

    script_file_path = write_sl_script(
        write_directory,
        params.sim_dir,
        const.ProcessType.EMOD3D,
        "run_emod3d_{}".format(srf_name),
        header_dict,
        body_template_params,
        command_template_parameters,
    )

    if not submit_yes:
        return

    submit_script_to_scheduler(
        script_file_path,
        const.ProcessType.EMOD3D.value,
        sim_struct.get_mgmt_db_queue(params.mgmt_db_location),
        params.sim_dir,
        srf_name,
        target_machine=args.machine,
        logger=logger,
    )
def create_run_params(
        sim_dir,
        srf_name=None,
        steps_per_checkpoint=None,
        logger: Logger = get_basic_logger(),
):
    """Builds the EMOD3D parameter dictionary of a simulation and writes it to
    <params.sim_dir>/LF/e3d.par

    The dictionary starts from the emod3d_defaults.yaml of the gmsim version
    given in the sim params, is filled with values taken from the sim params,
    and finally any key of the sim params "emod3d" section that already exists
    in the dictionary overrides its value.

    :param sim_dir: Directory containing the sim_params.yaml to load
    :param srf_name: If given, nothing is written unless it matches the name
        (without extension) of the SRF file of the sim params
    :param steps_per_checkpoint: If truthy, used (as an int) for both the
        dump_itinc and restart_itinc entries
    :param logger: The logger object used to record messages
    """
    params = utils.load_sim_params(os.path.join(sim_dir, "sim_params.yaml"))

    emod3d_version = params["emod3d"]["emod3d_version"]
    emod3d_filepath = binary_version.get_lf_bin(emod3d_version)

    # The defaults are specific to the gmsim version of the simulation
    e3d_yaml = os.path.join(
        platform_config[constants.PLATFORM_CONFIG.TEMPLATES_DIR.name],
        "gmsim",
        params.version,
        "emod3d_defaults.yaml",
    )
    e3d_dict = utils.load_yaml(e3d_yaml)
    # Only proceed if no srf_name was given, or it matches the SRF file of
    # this simulation
    if srf_name is None or srf_name == os.path.splitext(
            basename(params.srf_file))[0]:

        # EMOD3D adds a timeshift to the event rupture time
        # this must be accounted for as EMOD3D does not extend the sim duration by the amount of time shift
        # As flo is in Hz, the sim_duration_extension is in s
        # Version 3.0.4 was the last version of EMOD3D to have a shift of 1/flo,
        # while versions after it have a shift of 3/flo
        sim_duration_extension = 1 / float(params.flo)
        if compare_versions(emod3d_version,
                            MAXIMUM_EMOD3D_TIMESHIFT_1_VERSION) > 0:
            sim_duration_extension *= 3

        extended_sim_duration = float(
            params.sim_duration) + sim_duration_extension

        srf_file_basename = os.path.splitext(os.path.basename(
            params.srf_file))[0]
        e3d_dict["version"] = emod3d_version + "-mpi"

        e3d_dict["name"] = params.run_name
        e3d_dict["n_proc"] = 512

        # Grid dimensions, grid spacing and time step
        e3d_dict["nx"] = params.nx
        e3d_dict["ny"] = params.ny
        e3d_dict["nz"] = params.nz
        e3d_dict["h"] = params.hh
        e3d_dict["dt"] = params.dt

        # Number of time steps, based on the extended duration
        e3d_dict["nt"] = str(
            int(round(extended_sim_duration / float(params.dt))))
        e3d_dict["flo"] = float(params.flo)

        e3d_dict["faultfile"] = params.srf_file

        e3d_dict["vmoddir"] = params.vel_mod_dir

        e3d_dict["modellon"] = params.MODEL_LON
        e3d_dict["modellat"] = params.MODEL_LAT
        e3d_dict["modelrot"] = params.MODEL_ROT

        # Output locations
        e3d_dict["main_dump_dir"] = os.path.join(params.sim_dir, "LF",
                                                 "OutBin")
        e3d_dict["seiscords"] = params.stat_coords
        e3d_dict["user_scratch"] = os.path.join(params.user_root, "scratch")
        e3d_dict["seisdir"] = os.path.join(e3d_dict["user_scratch"],
                                           params.run_name, srf_file_basename,
                                           "SeismoBin")

        # "dtts" is not set in this function, so it has to come from the
        # loaded defaults
        e3d_dict["ts_total"] = str(
            int(extended_sim_duration /
                (float(e3d_dict["dt"]) * float(e3d_dict["dtts"]))))
        e3d_dict["ts_file"] = os.path.join(e3d_dict["main_dump_dir"],
                                           params.run_name + "_xyts.e3d")
        e3d_dict["ts_out_dir"] = os.path.join(params.sim_dir, "LF", "TSlice",
                                              "TSFiles")

        e3d_dict["restartdir"] = os.path.join(params.sim_dir, "LF", "Restart")
        # Without steps_per_checkpoint these two entries are left as loaded
        if steps_per_checkpoint:
            e3d_dict["dump_itinc"] = e3d_dict["restart_itinc"] = int(
                steps_per_checkpoint)

        e3d_dict["restartname"] = params.run_name
        e3d_dict["logdir"] = os.path.join(params.sim_dir, "LF", "Rlog")
        e3d_dict["slipout"] = os.path.join(params.sim_dir, "LF", "SlipOut",
                                           "slipout-k2")

        # other locations
        e3d_dict["wcc_prog_dir"] = emod3d_filepath
        e3d_dict["vel_mod_params_dir"] = params.vel_mod_dir
        e3d_dict["sim_dir"] = params.sim_dir
        e3d_dict["stat_file"] = params.stat_file
        e3d_dict["grid_file"] = params.GRIDFILE
        e3d_dict["model_params"] = params.MODEL_PARAMS

        # Values of the sim params "emod3d" section take precedence, but only
        # for keys that are already present in the dictionary
        if params.emod3d:
            for key, value in params.emod3d.items():
                if key in e3d_dict:
                    e3d_dict[key] = value
                else:
                    logger.debug(
                        "{} not found as a key in e3d file. Ignoring variable. Value is {}."
                        .format(key, value))

        shared.write_to_py(os.path.join(params.sim_dir, "LF", "e3d.par"),
                           e3d_dict)