Example 1
async def send_to_worker_hook(
        fetches, worker_main, worker_args=None, worker_kwargs=None, queue_size=30,
        use_threading=True, block=False):
    """Sends the values of `tensors` after each run to a worker process.

    A mp.Queue is used for sending the values.
    In the beginning a new process is created as `worker_proc_main(*worker_args, q=q)`,
    where q is the mp.Queue. Then after each session run, it puts the values
    of `fetches` into the queue.

    If the queue is full, the fetched values are discarded.
    """
    worker_args = worker_args or ()
    worker_kwargs = worker_kwargs or {}
    if use_threading:
        q = queue.Queue(queue_size)
        worker = threading.Thread(
            target=worker_main, args=worker_args, kwargs={'q': q, **worker_kwargs}, daemon=True)
    else:
        # Spawn is used for more memory efficiency
        q = mp_spawn.Queue(queue_size)
        worker = mp_spawn.Process(
            target=util.safe_subprocess_main_with_flags,
            args=(init.FLAGS, worker_main, *worker_args),
            kwargs={'q': q, **worker_kwargs})

    worker.start()
    async for values in tfasync.run_iter(fetches):
        try:
            q.put(values, block=block)
        except queue.Full:
            # Discard the fetched values; we must not block the session loop.
            pass
    q.put(None)  # End-of-stream sentinel for the worker.
    worker.join()
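
A minimal sketch of a worker compatible with this hook, assuming only the protocol documented above (values arrive on the queue `q`, and `None` marks the end of the stream); the function name `example_worker_main` and its pickling logic are hypothetical:

import pickle

def example_worker_main(output_path, q):
    # The hook invokes this as worker_main(*worker_args, q=q), so `q`
    # must be accepted as a keyword argument. Drain the queue until the
    # None sentinel that send_to_worker_hook puts after the session loop.
    with open(output_path, 'wb') as f:
        while True:
            values = q.get()
            if values is None:  # End-of-stream sentinel.
                break
            pickle.dump(values, f)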
Example 2
async def logger_hook(msg, tensors=()):
    """Logs `msg`, formatted with the fetched tensor values, after each run."""
    async for values in tfasync.run_iter(tensors):
        formattable_values = [
            util.formattable_array(v) if isinstance(v, np.ndarray) else v
            for v in values
        ]
        logging.info(msg.format(*formattable_values))
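
The hook relies on `util.formattable_array` so that `str.format` specs in `msg` also apply to NumPy arrays; below is a standalone stand-in illustrating that idea (a local sketch, not the project's actual helper):

import numpy as np

def formattable_array(a):
    # Stand-in: wrap an array so a format spec such as {:.2f} is applied
    # to every element instead of failing on the ndarray itself.
    class _Formattable:
        def __init__(self, arr):
            self.arr = np.asarray(arr)

        def __format__(self, spec):
            return np.array2string(
                self.arr, formatter={'all': lambda x: format(x, spec)})

    return _Formattable(a)

print('loss per head: {:.2f}'.format(formattable_array([0.1234, 2.5])))
# -> loss per head: [0.12 2.50]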
Example 3
async def collect_hook(fetches_to_collect=None):
    """Hook that collects the values of specified tensors in every iteration."""
    if fetches_to_collect is None:
        return None

    results = [vals async for vals in tfasync.run_iter(fetches_to_collect)]
    if not isinstance(fetches_to_collect, (list, tuple)):
        return concatenate_atleast_1d(results)

    return [concatenate_atleast_1d(result_column) for result_column in zip(*results)]
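
To see how the columns come together, here is a plausible stand-in for `concatenate_atleast_1d` (the real helper is project code; this sketch only assumes it promotes scalars to 1-D and concatenates along the first axis):

import numpy as np

def concatenate_atleast_1d(arrays):
    # Stand-in: promote scalars to 1-D so per-step scalars and batched
    # arrays can both be concatenated along the first axis.
    return np.concatenate([np.atleast_1d(a) for a in arrays], axis=0)

# Per-iteration fetches of [loss, predictions] become two columns,
# each concatenated across iterations:
results = [(0.5, np.array([1, 0])), (0.4, np.array([1, 1]))]
losses, preds = [concatenate_atleast_1d(col) for col in zip(*results)]
print(losses)  # [0.5 0.4]
print(preds)   # [1 0 1 1]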
Example 4
async def eta_hook(n_total_steps,
                   step_tensor=None,
                   init_phase_seconds=60,
                   summary_output_dir=None):
    summary_writer = (tf.summary.FileWriterCache.get(summary_output_dir)
                      if summary_output_dir else None)

    call_times = []  # list of times when hook was called
    remaining_step_counts = []  # list of amount of work remaining at each call

    # (The initial phase can be slower and would therefore mislead the ETA
    # calculation, so we drop the first `init_phase_seconds` seconds of
    # measurements.)
    removed_init = False  # whether the initial measurements were dropped yet
    global_step_tensor = tf.train.get_or_create_global_step()
    if step_tensor is None:
        step_tensor = global_step_tensor

    async for i_step, i_global_step in tfasync.run_iter(
        [step_tensor, global_step_tensor]):
        call_times.append(time.time())
        remaining_step_counts.append(n_total_steps - i_step)

        if (not removed_init
                and call_times[-1] > call_times[0] + init_phase_seconds):
            del call_times[:-1]
            del remaining_step_counts[:-1]
            removed_init = True

        # We need a few samples to have a good guess.
        if len(call_times) >= 3:
            eta_seconds = eta.eta(call_times, remaining_step_counts)
            eta_str = eta.format_timedelta(eta_seconds)
            progress = i_step / n_total_steps
            logging.info(f'{progress:.0%} complete; {eta_str} remaining.')
            if summary_writer:
                summary = tf.Summary(value=[tf.Summary.Value(
                    tag='ETA_hours', simple_value=eta_seconds / 3600)])
                summary_writer.add_summary(summary, global_step=i_global_step)

    if summary_writer:
        summary_writer.flush()
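
`eta.eta` is project code; a plausible stand-in for what it computes, assuming a linear fit of remaining work against wall time (the fitting choice is a guess, not the project's implementation):

import numpy as np

def eta(call_times, remaining_step_counts):
    # Stand-in: fit remaining work as a linear function of wall time and
    # return the estimated seconds until that fit reaches zero remaining.
    t = np.asarray(call_times, dtype=float)
    r = np.asarray(remaining_step_counts, dtype=float)
    slope, intercept = np.polyfit(t, r, deg=1)
    t_done = -intercept / slope  # Time at which remaining work hits zero.
    return max(t_done - t[-1], 0.0)

print(eta([0.0, 10.0, 20.0], [100, 90, 80]))  # -> 80.0 (seconds)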
Example 5
async def log_increment_per_sec(name, value_tensor, summary_output_dir=None):
    """Logs how fast `value_tensor` grows per second (e.g., steps/sec)."""
    summary_writer = (tf.compat.v1.summary.FileWriterCache.get(summary_output_dir)
                      if summary_output_dir else None)
    prev_value = None
    prev_time = None
    global_step = tf.compat.v1.train.get_global_step()
    if global_step is None:
        global_step = tf.convert_to_tensor(0)

    async for current_value, step in tfasync.run_iter([value_tensor, global_step]):
        current_time = time.time()
        if prev_value is not None:
            speed = (current_value - prev_value) / (current_time - prev_time)
            logging.info(f'{name}/sec: {speed:.1f}')
            if summary_writer:
                summary_writer.add_summary(tf.compat.v1.Summary(
                    value=[tf.compat.v1.Summary.Value(tag=f'{name}_per_sec', simple_value=speed)]),
                    global_step=step)
                summary_writer.flush()

        prev_value = current_value
        prev_time = current_time
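
A hedged usage sketch: the counter tensor and output directory are hypothetical, and the returned coroutine still has to be scheduled by whatever runner `tfasync` provides (not shown here):

import tensorflow as tf

# Hypothetical counter: the global step doubles as a batch counter here,
# so the hook would log and summarize batches/sec.
step = tf.compat.v1.train.get_or_create_global_step()
hook_coro = log_increment_per_sec('batches', step,
                                  summary_output_dir='/tmp/summaries')
# `hook_coro` is a coroutine object; it runs alongside the session loop
# once scheduled by the project's tfasync event loop.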