async def send_to_worker_hook(
        fetches, worker_main, worker_args=None, worker_kwargs=None, queue_size=30,
        use_threading=True, block=False):
    """Send the values of `fetches` after each run to a background worker.

    A worker is started once, before the first run, and receives a bounded queue
    through the keyword argument `q`. After every run the fetched values are put
    into that queue. When the hook finishes (normally, through an exception, or
    because the coroutine is closed), `None` is put into the queue as an
    end-of-stream marker and the worker is joined.

    Args:
        fetches: What to fetch in each run; handed to `tfasync.run_iter`.
        worker_main: Callable executed by the worker. It is called with
            `*worker_args`, `q=<queue>` and `**worker_kwargs`.
        worker_args: Optional positional arguments for `worker_main`.
        worker_kwargs: Optional keyword arguments for `worker_main`. A `q` entry
            here overrides the queue created by this hook.
        queue_size: Maximum number of pending items in the queue.
        use_threading: If true, the worker is a daemon `threading.Thread` fed by a
            `queue.Queue`. Otherwise it is an `mp_spawn.Process` fed by an
            `mp_spawn.Queue`, started via `util.safe_subprocess_main_with_flags`
            with `init.FLAGS` as its first argument.
        block: Forwarded to `q.put`. With the default `False`, values are
            discarded whenever the queue is full so the run loop never waits for
            the worker.
    """
    worker_args = worker_args or ()
    worker_kwargs = worker_kwargs or {}
    if use_threading:
        q = queue.Queue(queue_size)
        worker = threading.Thread(
            target=worker_main, args=worker_args, kwargs={'q': q, **worker_kwargs},
            daemon=True)
    else:
        # Spawn is used for more memory efficiency
        q = mp_spawn.Queue(queue_size)
        worker = mp_spawn.Process(
            target=util.safe_subprocess_main_with_flags,
            args=(init.FLAGS, worker_main, *worker_args),
            kwargs={'q': q, **worker_kwargs})

    worker.start()
    try:
        async for values in tfasync.run_iter(fetches):
            try:
                q.put(values, block=block)
            except queue.Full:
                # Discard the fetched values, we don't want to block the session loop.
                pass
    finally:
        # Always deliver the end-of-stream marker, even if iteration failed or the
        # coroutine was closed early; otherwise the worker would wait on the queue
        # forever. This put intentionally blocks so that the marker is not dropped.
        q.put(None)
        worker.join()
async def logger_hook(msg, tensors=()):
    """Log `msg`, formatted with the values of `tensors`, after every run.

    Args:
        msg: A `str.format` template; the fetched values are supplied as
            positional arguments, in the order of `tensors`.
        tensors: The fetches whose values are substituted into `msg`.
    """

    def _as_format_arg(value):
        # NumPy arrays are routed through `util.formattable_array` before being
        # handed to `str.format`; everything else is used unchanged.
        if isinstance(value, np.ndarray):
            return util.formattable_array(value)
        return value

    async for run_values in tfasync.run_iter(tensors):
        format_args = tuple(_as_format_arg(value) for value in run_values)
        logging.info(msg.format(*format_args))
async def collect_hook(fetches_to_collect=None):
    """Gather the values of the given fetches from every run and concatenate them.

    Args:
        fetches_to_collect: A single fetch, or a list/tuple of fetches. If `None`,
            nothing is iterated and `None` is returned.

    Returns:
        `None` if `fetches_to_collect` is `None`. For a list or tuple of fetches, a
        list with one `concatenate_atleast_1d` result per fetch. Otherwise the
        `concatenate_atleast_1d` result over the values of the single fetch.
    """
    if fetches_to_collect is None:
        return None

    per_run_values = []
    async for run_values in tfasync.run_iter(fetches_to_collect):
        per_run_values.append(run_values)

    if isinstance(fetches_to_collect, (list, tuple)):
        # Transpose from "one row per run" to "one column per fetch".
        columns = zip(*per_run_values)
        return [concatenate_atleast_1d(column) for column in columns]

    return concatenate_atleast_1d(per_run_values)
async def eta_hook(
        n_total_steps, step_tensor=None, init_phase_seconds=60, summary_output_dir=None):
    """Log the progress and the estimated remaining time after every run.

    Each run records the current wall-clock time and the number of steps still to
    do. Once at least three samples are available, `eta.eta` is asked for an
    estimate, which is logged together with the completed fraction and optionally
    written as an `ETA_hours` scalar summary.

    Args:
        n_total_steps: Total number of steps; progress is `step / n_total_steps`.
        step_tensor: Tensor holding the current step. Defaults to the global step
            tensor.
        init_phase_seconds: Samples taken during this many seconds after the first
            call are dropped once, since the initial phase can be slower and would
            mislead the estimate.
        summary_output_dir: If given, summaries are written through the cached
            file writer for this directory.
    """
    # The `tf.compat.v1` namespace is used so that these TF1-style symbols
    # resolve the same way as in `log_increment_per_sec`.
    summary_writer = (
        tf.compat.v1.summary.FileWriterCache.get(summary_output_dir)
        if summary_output_dir else None)

    call_times = []  # list of times when hook was called
    remaining_step_counts = []  # list of amount of work remaining at each call
    removed_init = False  # have we removed the initial part already

    global_step_tensor = tf.compat.v1.train.get_or_create_global_step()
    if step_tensor is None:
        step_tensor = global_step_tensor

    async for i_step, i_global_step in tfasync.run_iter([step_tensor, global_step_tensor]):
        call_times.append(time.time())
        remaining_step_counts.append(n_total_steps - i_step)

        # Drop everything but the newest sample, exactly once, as soon as the
        # initial phase is over.
        if not removed_init and call_times[-1] > call_times[0] + init_phase_seconds:
            del call_times[:-1]
            del remaining_step_counts[:-1]
            removed_init = True

        # We need a few samples to have a good guess.
        if len(call_times) >= 3:
            eta_seconds = eta.eta(call_times, remaining_step_counts)
            eta_str = eta.format_timedelta(eta_seconds)
            progress = i_step / n_total_steps
            logging.info(f'{progress:.0%} complete; {eta_str} remaining.')

            if summary_writer:
                eta_summary = tf.compat.v1.Summary(value=[
                    tf.compat.v1.Summary.Value(
                        tag='ETA_hours', simple_value=eta_seconds / 3600)])
                summary_writer.add_summary(eta_summary, global_step=i_global_step)

    # Flush once iteration has ended so that the last summaries reach disk.
    if summary_writer:
        summary_writer.flush()
async def log_increment_per_sec(name, value_tensor, summary_output_dir=None):
    """Log how fast the value of `value_tensor` grows per second of wall-clock time.

    From the second run onwards, the change of the value since the previous run is
    divided by the elapsed time, logged, and optionally written as a
    `<name>_per_sec` scalar summary (flushed right away).

    Args:
        name: Label used in the log message and in the summary tag.
        value_tensor: Tensor whose value is tracked.
        summary_output_dir: If given, summaries are written through the cached
            file writer for this directory.
    """
    writer = None
    if summary_output_dir:
        writer = tf.compat.v1.summary.FileWriterCache.get(summary_output_dir)

    step_tensor = tf.compat.v1.train.get_global_step()
    if step_tensor is None:
        # No global step is available: fetch a constant 0 in its place.
        step_tensor = tf.convert_to_tensor(0)

    last_value = None
    last_time = None
    async for value, step in tfasync.run_iter([value_tensor, step_tensor]):
        now = time.time()

        # The first run only establishes the baseline.
        if last_value is not None:
            elapsed = now - last_time
            speed = (value - last_value) / elapsed
            logging.info(f'{name}/sec: {speed:.1f}')

            if writer:
                speed_summary = tf.compat.v1.Summary(value=[
                    tf.compat.v1.Summary.Value(tag=f'{name}_per_sec', simple_value=speed)])
                writer.add_summary(speed_summary, global_step=step)
                writer.flush()

        last_value, last_time = value, now