def run(self, *args, **kwargs):
    """Draw this chain's samples and stream them through ``self.result_queue``.

    Seeds Pyro's RNG with ``self.rng_seed`` and sets the default tensor type
    to ``self.default_tensor_type`` before sampling. Each sample is put on
    the result queue as ``(chain_id, sample)``; after every put the method
    blocks on ``self.event`` and clears it once it has been set. When
    sampling finishes, ``(chain_id, None)`` is put on the queue. Any
    ``Exception`` raised while sampling is logged and put on the queue as
    ``(chain_id, exception)`` instead of being re-raised.

    :param args: positional arguments forwarded to ``_gen_samples``; CUDA
        tensors among them are cloned and detached first.
    :param kwargs: keyword arguments forwarded to ``_gen_samples`` unchanged.
    """
    pyro.set_rng_seed(self.rng_seed)
    torch.set_default_tensor_type(self.default_tensor_type)

    # XXX we clone CUDA tensor args to resolve the issue "Invalid device pointer"
    # at https://github.com/pytorch/pytorch/issues/10375
    safe_args = []
    for value in args:
        if torch.is_tensor(value) and value.is_cuda:
            value = value.clone().detach()
        safe_args.append(value)

    chain_id = self.chain_id
    chain_logger = initialize_logger(
        logging.getLogger("pyro.infer.mcmc"),
        "CHAIN:{}".format(chain_id),
        None,
        self.log_queue,
    )
    logging_hook = _add_logging_hook(chain_logger, None, self.hook)

    try:
        samples = _gen_samples(
            self.kernel,
            self.warmup_steps,
            self.num_samples,
            logging_hook,
            None,
            *safe_args,
            **kwargs
        )
        for sample in samples:
            self.result_queue.put_nowait((chain_id, sample))
            # Block until the event is set, then re-arm it for the next sample.
            self.event.wait()
            self.event.clear()
        # A trailing ``None`` payload is sent once all samples have been put.
        self.result_queue.put_nowait((chain_id, None))
    except Exception as e:
        chain_logger.exception(e)
        # Hand the exception object to the queue reader rather than raising here.
        self.result_queue.put_nowait((chain_id, e))
def __init__(self, kernel, num_samples, warmup_steps, num_chains, mp_context,
             disable_progbar, initial_params=None, hook=None):
    """Store the sampler configuration and start the log-consumer thread.

    :param kernel: kernel object, stored as ``self.kernel``.
    :param num_samples: stored as ``self.num_samples`` and passed to the
        logging thread.
    :param warmup_steps: stored as ``self.warmup_steps`` and passed to the
        logging thread.
    :param num_chains: stored as ``self.num_chains``; one event object is
        created per chain.
    :param mp_context: when truthy, ``mp.get_context(mp_context)`` is used
        as the multiprocessing context; otherwise ``mp`` itself is used.
    :param disable_progbar: passed through to the logging thread.
    :param initial_params: stored as ``self.initial_params``.
    :param hook: stored as ``self.hook``.
    """
    # Plain configuration.
    self.kernel = kernel
    self.warmup_steps = warmup_steps
    self.num_chains = num_chains
    self.hook = hook
    self.workers = []

    # Queues and events are created from the selected context.
    self.ctx = mp.get_context(mp_context) if mp_context else mp
    self.result_queue = self.ctx.Queue()
    self.log_queue = self.ctx.Queue()

    base_logger = logging.getLogger("pyro.infer.mcmc")
    self.logger = initialize_logger(base_logger, "MAIN", log_queue=self.log_queue)
    self.num_samples = num_samples
    self.initial_params = initial_params

    # Daemon thread that runs ``logger_thread`` against the log queue.
    thread_args = (
        self.log_queue,
        self.warmup_steps,
        self.num_samples,
        self.num_chains,
        disable_progbar,
    )
    self.log_thread = threading.Thread(target=logger_thread, args=thread_args)
    self.log_thread.daemon = True
    self.log_thread.start()

    self.events = [self.ctx.Event() for _ in range(num_chains)]
def __init__(self, kernel, num_samples, warmup_steps, num_chains, mp_context,
             disable_progbar):
    """Store the sampler configuration and start the log-consumer thread.

    :param kernel: kernel object, stored as ``self.kernel``.
    :param num_samples: stored as ``self.num_samples`` and passed to the
        logging thread.
    :param warmup_steps: stored as ``self.warmup_steps`` and passed to the
        logging thread.
    :param num_chains: stored as ``self.num_chains`` and passed to the
        logging thread.
    :param mp_context: when truthy, ``mp.get_context(mp_context)`` is used
        as the multiprocessing context; otherwise ``mp`` itself is used.
    :param disable_progbar: passed through to the logging thread.
    :raises ValueError: if ``mp_context`` is truthy while ``six.PY2`` is true.
    """
    super(_ParallelSampler, self).__init__()
    self.kernel = kernel
    self.warmup_steps = warmup_steps
    self.num_chains = num_chains
    self.workers = []

    # Default to the module-level context; a named context needs Python 3.
    self.ctx = mp
    if mp_context and six.PY2:
        raise ValueError("multiprocessing.get_context() is not supported in Python 2.")
    if mp_context:
        self.ctx = mp.get_context(mp_context)

    self.result_queue = self.ctx.Queue()
    self.log_queue = self.ctx.Queue()

    base_logger = logging.getLogger("pyro.infer.mcmc")
    self.logger = initialize_logger(base_logger, "MAIN", log_queue=self.log_queue)
    self.num_samples = num_samples

    # Daemon thread that runs ``logger_thread`` against the log queue.
    thread_args = (
        self.log_queue,
        self.warmup_steps,
        self.num_samples,
        self.num_chains,
        disable_progbar,
    )
    self.log_thread = threading.Thread(target=logger_thread, args=thread_args)
    self.log_thread.daemon = True
    self.log_thread.start()
def run(self, *args, **kwargs):
    """Yield ``(sample, 0)`` for every sample produced by ``_gen_samples``.

    A progress bar is created from ``self.warmup_steps``,
    ``self.num_samples`` and ``self.disable_progbar`` and wired into the
    ``"pyro.infer.mcmc"`` logger and the sampling hook. The bar is closed
    when the generator finishes, and also when sampling raises or the
    generator is closed before being exhausted.

    :param args: positional arguments forwarded to ``_gen_samples``.
    :param kwargs: keyword arguments forwarded to ``_gen_samples``.
    :returns: generator of ``(sample, chain_id)`` pairs with ``chain_id``
        fixed at ``0``.
    """
    logger = logging.getLogger("pyro.infer.mcmc")
    progress_bar = ProgressBar(self.warmup_steps, self.num_samples, disable=self.disable_progbar)
    try:
        logger = initialize_logger(logger, "", progress_bar)
        hook_w_logging = _add_logging_hook(logger, progress_bar, self.hook)
        for sample in _gen_samples(self.kernel, self.warmup_steps, self.num_samples,
                                   hook_w_logging, *args, **kwargs):
            yield sample, 0  # sample, chain_id (default=0)
    finally:
        # Previously the bar was only closed after full exhaustion; an error
        # or an early ``close()`` of this generator left it open.
        progress_bar.close()
def run(self, *args, **kwargs):
    """Draw this chain's samples and stream them through ``self.result_queue``.

    Seeds Pyro's RNG with ``self.rng_seed`` and sets the default tensor type
    to ``self.default_tensor_type`` before sampling. Each sample is put on
    the result queue as ``(chain_id, sample)``; after every put the method
    blocks on ``self.event`` and clears it once it has been set. When
    sampling finishes, ``(chain_id, None)`` is put on the queue. Any
    ``Exception`` raised while sampling is logged and put on the queue as
    ``(chain_id, exception)`` instead of being re-raised.

    :param args: positional arguments forwarded to ``_gen_samples``.
    :param kwargs: keyword arguments forwarded to ``_gen_samples``.
    """
    pyro.set_rng_seed(self.rng_seed)
    torch.set_default_tensor_type(self.default_tensor_type)

    chain_id = self.chain_id
    chain_logger = initialize_logger(
        logging.getLogger("pyro.infer.mcmc"),
        "CHAIN:{}".format(chain_id),
        None,
        self.log_queue,
    )
    logging_hook = _add_logging_hook(chain_logger, None, self.hook)

    try:
        samples = _gen_samples(
            self.kernel,
            self.warmup_steps,
            self.num_samples,
            logging_hook,
            None,
            *args,
            **kwargs
        )
        for sample in samples:
            self.result_queue.put_nowait((chain_id, sample))
            # Block until the event is set, then re-arm it for the next sample.
            self.event.wait()
            self.event.clear()
        # A trailing ``None`` payload is sent once all samples have been put.
        self.result_queue.put_nowait((chain_id, None))
    except Exception as e:
        chain_logger.exception(e)
        # Hand the exception object to the queue reader rather than raising here.
        self.result_queue.put_nowait((chain_id, e))
def run(self, *args, **kwargs):
    """Run ``self.num_chains`` chains one after another, yielding ``(sample, chain_id)``.

    For chain ``i``, when ``self.initial_params`` is not ``None`` every value
    in it is indexed with ``i`` and the resulting dict is assigned to
    ``self.kernel.initial_params``. Each chain gets its own progress bar,
    which is closed when the chain finishes and also when sampling raises or
    the generator is closed before the chain is exhausted.
    ``self.kernel.cleanup()`` is called only after a chain completes normally.

    :param args: positional arguments forwarded to ``_gen_samples``.
    :param kwargs: keyword arguments forwarded to ``_gen_samples``.
    :returns: generator of ``(sample, chain_id)`` pairs.
    """
    logger = logging.getLogger("pyro.infer.mcmc")
    for i in range(self.num_chains):
        if self.initial_params is not None:
            # Select this chain's slice of every initial parameter.
            initial_params = {k: v[i] for k, v in self.initial_params.items()}
            self.kernel.initial_params = initial_params

        progress_bar = ProgressBar(self.warmup_steps, self.num_samples, disable=self.disable_progbar)
        try:
            logger = initialize_logger(logger, "", progress_bar)
            hook_w_logging = _add_logging_hook(logger, progress_bar, self.hook)
            # The chain id is only passed down when more than one chain is run.
            for sample in _gen_samples(self.kernel, self.warmup_steps, self.num_samples,
                                       hook_w_logging,
                                       i if self.num_chains > 1 else None,
                                       *args, **kwargs):
                yield sample, i  # sample, chain_id
            self.kernel.cleanup()
        finally:
            # Previously the bar was only closed after full exhaustion; an
            # error or an early ``close()`` of this generator left it open.
            progress_bar.close()
def _traces(self, *args, **kwargs):
    """Run warmup, then yield ``(trace, 1.0)`` for each of the sampled traces.

    ``logger_id`` and ``log_queue`` are popped from ``kwargs`` before the
    remaining arguments are handed to ``self.kernel.setup``. A progress bar
    is created only when no ``log_queue`` was supplied. Traces generated
    during warmup are consumed without being yielded; the last one seeds the
    sampling phase. ``self.kernel.cleanup()`` is called after sampling ends.

    :param args: positional arguments forwarded to ``self.kernel.setup``.
    :param kwargs: keyword arguments forwarded to ``self.kernel.setup``
        (minus ``logger_id`` and ``log_queue``).
    :returns: generator of ``(trace, 1.0)`` tuples.
    """
    logger_id = kwargs.pop("logger_id", "")
    log_queue = kwargs.pop("log_queue", None)
    self.logger = logging.getLogger("pyro.infer.mcmc")

    # A log queue being present is what marks the multiprocessing case.
    multiprocess_mode = log_queue is not None
    if multiprocess_mode:
        bar = None
    else:
        bar = initialize_progbar(self.warmup_steps, self.num_samples,
                                 disable=self.disable_progbar)
    self.logger = initialize_logger(self.logger, logger_id, bar, log_queue)

    self.kernel.setup(self.warmup_steps, *args, **kwargs)
    trace = self.kernel.initial_trace

    with optional(bar, not multiprocess_mode):
        # Warmup: advance ``trace`` without yielding anything.
        for trace in self._gen_samples(self.warmup_steps, trace):
            pass
        if bar:
            bar.set_description("Sample")
        for trace in self._gen_samples(self.num_samples, trace):
            yield (trace, 1.0)
    self.kernel.cleanup()