def test_mcmc_sync():
    """
    Exercise synchronization between MPI processes in three situations: chains
    that end at different times, an error raised in a single process, and an
    error raised inside a ``mpi.ProcessState`` block with a time-out while the
    other processes are sleeping.
    """
    info: InputDict = yaml_load(yaml)

    # --- 1. Processes finishing at different times -------------------------
    logger.info('Test end synchronization')
    # simulate asynchronous ending sampling loop:
    # the process of rank 1 requests fewer samples than the others
    max_samples = 200 if mpi.rank() == 1 else 600
    info['sampler']['mcmc'] = {'max_samples': max_samples}
    updated_info, sampler = run(info)
    assert len(sampler.products()["sample"]) == max_samples

    # --- 2. Error in one process only --------------------------------------
    logger.info('Test error synchronization')
    if mpi.rank() != 0:
        # The remaining processes are expected to see the failure of rank 0
        with pytest.raises(mpi.OtherProcessError):
            run(info)
    else:
        # 'none' not valid
        info['sampler']['mcmc'] = {'max_samples': 'none'}
        with NoLogging(logging.ERROR), pytest.raises(TypeError):
            run(info)

    # --- 3. Error while the other processes hang ---------------------------
    logger.info('Test one-process hang abort')
    abort_calls = []

    def record_abort():
        # Stand-in passed as `timeout_abort_proc`: just remember it was called
        abort_calls.append(True)

    # test error converted into MPI_ABORT after timeout
    expected_errors = (LoggedError, mpi.OtherProcessError)
    # noinspection PyTypeChecker
    with pytest.raises(expected_errors), NoLogging(logging.ERROR):
        with mpi.ProcessState('test', time_out_seconds=0.5,
                              timeout_abort_proc=record_abort):
            if mpi.rank() == 1:
                raise LoggedError(logger, 'Expected test error')
            time.sleep(0.6)  # fake hang
    if mpi.rank() == 1:
        assert abort_calls
def run(self):
    """
    Runs the sampler.

    Repeatedly calls ``get_new_sample`` until the number of accepted steps
    (``self.n()``) reaches ``self.max_samples`` or ``self.converged`` is set.
    Along the way it optionally produces timed output, invokes the user
    callback, and checks convergence / learns the proposal, both in the
    single-process and in the multi-process case. The whole loop runs inside
    a ``mpi.ProcessState`` context, whose ``check_error`` is called
    periodically.

    When the loop ends, the remaining samples are flushed with
    ``self.collection.out_update()`` and the accepted-step counts of all
    processes are gathered to report the total.
    """
    # The burn-in note is appended only if `burn_in.value` is non-zero: the
    # conditional expression applies to the whole formatted note.
    self.mpi_info("Sampling!" + (
        " (NB: no accepted step will be saved until %d burn-in samples "
        % self.burn_in.value + "have been obtained)"
        if self.burn_in.value else ""))
    # Raw counter of calls to `get_new_sample` in this run
    self.n_steps_raw = 0
    # Timestamp (seconds) of the last timed output; 0 forces output at the
    # first opportunity
    last_output: float = 0
    # Number of accepted steps seen so far; used to detect newly added points
    last_n = self.n()
    # Interval (in units of the current point's weight) between error checks
    # in the branch where no new point was added
    state_check_every = 1
    with mpi.ProcessState(self) as state:
        while last_n < self.max_samples and not self.converged:
            self.get_new_sample()
            self.n_steps_raw += 1
            if self.output_every.unit:
                # if output_every in sec, print some info
                # and dump at fixed time intervals
                now = datetime.datetime.now()
                now_sec = now.timestamp()
                if now_sec >= last_output + self.output_every.value:
                    self.do_output(now)
                    last_output = now_sec
                    # NOTE(review): presumably raises if another process has
                    # reported an error -- confirm against mpi.ProcessState
                    state.check_error()
            if self.current_point.weight == 1:
                # have added new point
                # Callback function
                n = self.n()
                if n != last_n:
                    # and actually added
                    last_n = n
                    # Callback fires when max(n, 1) is a multiple of
                    # `callback_every.value`
                    if (self.callback_function and
                            not (max(n, 1) % self.callback_every.value) and
                            self.current_point.weight == 1):
                        self.callback_function_callable(self)
                        # Remember the collection length at the last callback
                        self.last_point_callback = len(self.collection)
                    if more_than_one_process():
                        # Checking convergence and (optionally) learning
                        # the covmat of the proposal
                        # Log only if this process is ready AND `state.set`
                        # returns a true value
                        if self.check_ready() and state.set(
                                mpi.State.READY):
                            self.log.info(self._msg_ready +
                                          " (waiting for the rest...)")
                        # Proceed only once `state.all_ready()` is true
                        if state.all_ready():
                            # `_msg_ready[1:]` drops the first character of
                            # the message, which is replaced by the "r" here
                            self.mpi_info("All chains are r%s",
                                          self._msg_ready[1:])
                            self.check_convergence_and_learn_proposal()
                            self.i_learn += 1
                    else:
                        # Single process: check as soon as this chain is ready
                        if self.check_ready():
                            self.log.debug(self._msg_ready)
                            self.check_convergence_and_learn_proposal()
                            self.i_learn += 1
            elif self.current_point.weight % state_check_every == 0:
                # No new point: check the state every `state_check_every`
                # units of the current point's weight
                state.check_error()
                # more frequent checks near beginning
                # (the interval grows by one per check, capped at 10)
                state_check_every = min(10, state_check_every + 1)
        if last_n == self.max_samples:
            self.log.info(
                "Reached maximum number of accepted steps allowed (%s). "
                "Stopping.", self.max_samples)
        # Write the last batch of samples ( < output_every (not in sec))
        self.collection.out_update()
    # Collect the number of accepted steps of every process to report the total
    ns = mpi.gather(self.n())
    self.mpi_info("Sampling complete after %d accepted steps.", sum(ns))
def run(
        info_or_yaml_or_file: Union[InputDict, str, os.PathLike],
        packages_path: Optional[str] = None,
        output: Union[str, LiteralFalse, None] = None,
        debug: Union[bool, int, None] = None,
        stop_at_error: Optional[bool] = None,
        resume: bool = False,
        force: bool = False,
        no_mpi: bool = False,
        test: bool = False,
        override: Optional[InputDict] = None,
) -> Union[InfoSamplerTuple, PostTuple]:
    """
    Run from an input dictionary, file name or yaml string, with optional arguments
    to override settings in the input as needed.

    :param info_or_yaml_or_file: input options dictionary, yaml file, or yaml text
    :param packages_path: path where external packages were installed
    :param output: path name prefix for output files, or False for no file output
    :param debug: true for verbose debug output, or a specific logging level
    :param stop_at_error: stop if an error is raised
    :param resume: continue an existing run
    :param force: overwrite existing output if it exists
    :param no_mpi: run without MPI
    :param test: only test initialization rather than actually running
    :param override: option dictionary to merge into the input one, overriding settings
       (but with lower precedence than the explicit keyword arguments)
    :return: (updated_info, sampler) tuple of options dictionary and Sampler instance,
       or (updated_info, results) if using "post" post-processing
    :raises ValueError: if both ``resume`` and ``force`` are requested
    :raises LoggedError: if the input does not specify any sampler (missing, null
       or empty ``sampler`` block)
    """
    # This function reproduces the model-->output-->sampler pipeline one would follow
    # when instantiating by hand, but alters the order to perform checks and dump info
    # as early as possible, e.g. to check if resuming possible or `force` needed.
    if no_mpi or test:
        mpi.set_mpi_disabled()
    with mpi.ProcessState("run"):
        info: InputDict = load_info_overrides(info_or_yaml_or_file, debug,
                                              stop_at_error, packages_path,
                                              override)
        if test:
            info["test"] = True
        # If any of resume|force given as cmd args, ignore those in the input file
        if resume or force:
            if resume and force:
                raise ValueError("'resume' and 'force' are exclusive options")
            info["resume"] = bool(resume)
            info["force"] = bool(force)
        # `output` overrides the input only if explicitly given as a prefix (str)
        # or as False (no output); in the latter case it is stored as None.
        output_given = isinstance(output, str) or output is False
        if info.get("post"):
            if output_given:
                info["post"]["output"] = output or None
            return post(info)
        if output_given:
            info["output"] = output or None
        logger_setup(info.get("debug"), info.get("debug_file"))
        logger_run = get_logger(run.__name__)
        # MARKED FOR DEPRECATION IN v3.0
        # BEHAVIOUR TO BE REPLACED BY ERROR:
        check_deprecated_modules_path(info)
        # END OF DEPRECATION BLOCK
        # 1. Prepare output driver, if requested by defining an output_prefix
        # GetDist needs to know the original sampler, so don't overwrite if minimizer
        try:
            which_sampler = list(info["sampler"])[0]
        except (KeyError, TypeError, IndexError):
            # KeyError: no 'sampler' key; TypeError: null 'sampler' block;
            # IndexError: empty 'sampler' block.
            raise LoggedError(
                logger_run,
                "You need to specify a sampler using the 'sampler' key "
                "as e.g. `sampler: {mcmc: None}.`")
        infix = "minimize" if which_sampler == "minimize" else None
        with get_output(prefix=info.get("output"), resume=info.get("resume"),
                        force=info.get("force"), infix=infix) as out:
            # 2. Update the input info with the defaults for each component
            updated_info = update_info(info)
            if is_debug(logger_run):
                # Dump only if not doing output
                # (otherwise, the user can check the .updated file)
                if not out and mpi.is_main_process():
                    logger_run.info(
                        "Input info updated with defaults (dumped to YAML):\n%s",
                        yaml_dump(sort_cosmetic(updated_info)))
            # 3. If output requested, check compatibility if existing one, and dump.
            # 3.1 First: model only
            out.check_and_dump_info(info, updated_info, cache_old=True,
                                    ignore_blocks=["sampler"])
            # 3.2 Then sampler -- 1st get the last sampler mentioned in the updated.yaml
            # TODO: ideally, using Minimizer would *append* to the sampler block.
            #       Some code already in place, but not possible at the moment.
            try:
                last_sampler = list(updated_info["sampler"])[-1]
                last_sampler_info = {
                    last_sampler: updated_info["sampler"][last_sampler]
                }
            except (KeyError, TypeError, IndexError):
                raise LoggedError(logger_run, "No sampler requested.")
            sampler_name, sampler_class = get_sampler_name_and_class(
                last_sampler_info)
            check_sampler_info(
                (out.reload_updated_info(use_cache=True) or {}).get("sampler"),
                updated_info["sampler"], is_resuming=out.is_resuming())
            # Dump again, now including sampler info
            out.check_and_dump_info(info, updated_info, check_compatible=False)
            # Check if resumable run
            sampler_class.check_force_resume(
                out, info=updated_info["sampler"][sampler_name])
            # 4. Initialize the posterior and the sampler
            with Model(updated_info["params"], updated_info["likelihood"],
                       updated_info.get("prior"), updated_info.get("theory"),
                       packages_path=info.get("packages_path"),
                       timing=updated_info.get("timing"),
                       allow_renames=False,
                       stop_at_error=info.get("stop_at_error", False)) as model:
                # Re-dump the updated info, now containing parameter routes and version
                updated_info = recursive_update(updated_info, model.info())
                out.check_and_dump_info(None, updated_info, check_compatible=False)
                sampler = sampler_class(
                    updated_info["sampler"][sampler_name], model, out,
                    name=sampler_name, packages_path=info.get("packages_path"))
                # Re-dump updated info, now also containing updates from the sampler
                updated_info["sampler"][sampler_name] = \
                    recursive_update(updated_info["sampler"][sampler_name],
                                     sampler.info())
                out.check_and_dump_info(None, updated_info, check_compatible=False)
                mpi.sync_processes()
                if info.get("test", False):
                    # Initialization-only mode: return without sampling
                    logger_run.info(
                        "Test initialization successful! "
                        "You can probably run now without `--%s`.", "test")
                    return InfoSamplerTuple(updated_info, sampler)
                # Run the sampler
                sampler.run()
    return InfoSamplerTuple(updated_info, sampler)