Exemplo n.º 1
0
Arquivo: mcmc.py Projeto: zyxue/pyro
def logger_thread(log_queue,
                  warmup_steps,
                  num_samples,
                  num_chains,
                  disable_progbar=False):
    """
    Logging thread that asynchronously consumes logging events from `log_queue`,
    and handles them appropriately.

    Each record's message is split as ``"[<a> <msg_type> <logger_id>]<body>"``.

    * If ``msg_type == DIAGNOSTIC_MSG``, the record drives the progress bar at
      index ``int(logger_id.split(":")[-1])``. ``<body>`` is parsed as JSON and
      shown as the bar's postfix, and the bar is advanced by one.
    * Every other record, including one not in that layout, is forwarded to this
      module's logger.

    A ``None`` item on the queue ends the loop. Progress bars are closed on exit,
    whether the exit is normal or caused by an exception.

    :param log_queue: queue-like object exposing ``get(timeout=...)`` and raising
        :class:`queue.Empty` on timeout; items are log records or ``None``.
    :param int warmup_steps: forwarded to the progress bars. After this many
        diagnostic messages for a chain, its bar description becomes
        ``"Sample [<chain index + 1>]"``.
    :param int num_samples: forwarded to the progress bars.
    :param int num_chains: number of progress bars to create (one per position).
    :param bool disable_progbar: forwarded to the progress bars as ``disable``.
    """
    progress_bars = [
        initialize_progbar(warmup_steps,
                           num_samples,
                           pos=i,
                           disable=disable_progbar) for i in range(num_chains)
    ]
    logger = logging.getLogger(__name__)
    logger.propagate = False
    # The module-level logger outlives this call. Attach the handler only once,
    # otherwise repeated runs in one process would emit each record N times.
    if not any(isinstance(h, TqdmHandler) for h in logger.handlers):
        logger.addHandler(TqdmHandler())
    # Per-chain count of diagnostic messages received so far. This is kept
    # distinct from the `num_samples` argument, which it must not shadow.
    draws_seen = [0] * num_chains
    try:
        while True:
            try:
                record = log_queue.get(timeout=1)
            except queue.Empty:
                continue
            if record is None:
                break
            try:
                metadata, msg = record.getMessage().split("]", 1)
                _, msg_type, logger_id = metadata[1:].split()
            except ValueError:
                # Message is not in the "[a msg_type logger_id]body" layout.
                # Forward it as-is rather than letting the thread die.
                logger.handle(record)
                continue
            if msg_type == DIAGNOSTIC_MSG:
                pbar_pos = int(logger_id.split(":")[-1])
                draws_seen[pbar_pos] += 1
                if draws_seen[pbar_pos] == warmup_steps:
                    progress_bars[pbar_pos].set_description(
                        "Sample [{}]".format(pbar_pos + 1))
                diagnostics = json.loads(msg, object_pairs_hook=OrderedDict)
                progress_bars[pbar_pos].set_postfix(diagnostics)
                progress_bars[pbar_pos].update()
            else:
                logger.handle(record)
    finally:
        for pbar in progress_bars:
            pbar.close()
            # Required to not overwrite multiple progress bars on exit.
            if not pbar._ipython_env:
                sys.stderr.write("\n")
Exemplo n.º 2
0
def logger_thread(log_queue,
                  warmup_steps,
                  num_samples,
                  num_chains,
                  disable_progbar=False):
    """
    Logging thread that asynchronously consumes logging events from `log_queue`,
    and handles them appropriately.

    Each record's message is split as ``"[<a> <msg_type> <logger_id>]<body>"``.

    * If ``msg_type == DIAGNOSTIC_MSG``, the record drives the progress bar at
      position ``int(logger_id.split(":")[-1])``. ``<body>`` is parsed as JSON
      and set as that bar's postfix, and the bar is advanced by one.
    * Every other record, including one not in that layout, is forwarded to this
      module's logger.

    A ``None`` item on the queue ends the loop. The progress bars are closed on
    exit, whether the exit is normal or caused by an exception.

    :param log_queue: queue-like object exposing ``get(timeout=...)`` and raising
        :class:`queue.Empty` on timeout; items are log records or ``None``.
    :param int warmup_steps: forwarded to ``ProgressBar``. After this many
        diagnostic messages at a position, that bar's description becomes
        ``"Sample [<position + 1>]"``.
    :param int num_samples: forwarded to ``ProgressBar``.
    :param int num_chains: passed to ``ProgressBar`` as ``num_bars``.
    :param bool disable_progbar: passed to ``ProgressBar`` as ``disable``.
    """
    progress_bars = ProgressBar(warmup_steps,
                                num_samples,
                                disable=disable_progbar,
                                num_bars=num_chains)
    logger = logging.getLogger(__name__)
    logger.propagate = False
    # The module-level logger outlives this call. Attach the handler only once,
    # otherwise repeated runs in one process would emit each record N times.
    if not any(isinstance(h, TqdmHandler) for h in logger.handlers):
        logger.addHandler(TqdmHandler())
    # Per-position count of diagnostic messages received so far. This is kept
    # distinct from the `num_samples` argument, which it must not shadow.
    draws_seen = [0] * num_chains
    try:
        while True:
            try:
                record = log_queue.get(timeout=1)
            except queue.Empty:
                continue
            if record is None:
                break
            try:
                metadata, msg = record.getMessage().split("]", 1)
                _, msg_type, logger_id = metadata[1:].split()
            except ValueError:
                # Message is not in the "[a msg_type logger_id]body" layout.
                # Forward it as-is rather than letting the thread die.
                logger.handle(record)
                continue
            if msg_type == DIAGNOSTIC_MSG:
                pbar_pos = int(logger_id.split(":")[-1])
                draws_seen[pbar_pos] += 1
                if draws_seen[pbar_pos] == warmup_steps:
                    progress_bars.set_description(
                        "Sample [{}]".format(pbar_pos + 1), pos=pbar_pos)
                diagnostics = json.loads(msg, object_pairs_hook=OrderedDict)
                # refresh=False: the update() call right below redraws the bar.
                progress_bars.set_postfix(diagnostics,
                                          pos=pbar_pos,
                                          refresh=False)
                progress_bars.update(pos=pbar_pos)
            else:
                logger.handle(record)
    finally:
        progress_bars.close()