def before_run(self, run_context):
  """Request the global step and the training loss as run fetches."""
  graph = run_context.session.graph
  loss_tensor = graph.get_operation_by_name(LOSS_NAME).outputs[0]
  fetches = {
      'global_step': contrib_framework.get_global_step(),
      'current_loss': loss_tensor,
  }
  return session_run_hook.SessionRunArgs(fetches)
def before_run(self, run_context):
  """Wraps the session on first use and configures this run for debugging.

  Lazily constructs a LocalCLIDebugWrapperSession around the session in
  `run_context`, forwards the run request to the wrapper's `on_run_start()`,
  and decorates the returned RunOptions according to the action selected
  (debug run, profile run, or stepper invocation).

  Args:
    run_context: A `session_run_hook.SessionRunContext` carrying the session
      and the original run() fetches/feed_dict.

  Returns:
    A `session_run_hook.SessionRunArgs` with no extra fetches but with
    `options` decorated for the selected action.
  """
  if not self._session_wrapper:
    # The wrapped session is only available here, not at hook construction.
    self._session_wrapper = local_cli_wrapper.LocalCLIDebugWrapperSession(
        run_context.session,
        ui_type=self._ui_type,
        dump_root=self._dump_root,
        thread_name_filter=self._thread_name_filter)

    # Actually register tensor filters registered prior to the construction
    # of the underlying LocalCLIDebugWrapperSession object.
    for filter_name in self._pending_tensor_filters:
      self._session_wrapper.add_tensor_filter(
          filter_name, self._pending_tensor_filters[filter_name])

  # Increment run call counter.
  self._session_wrapper.increment_run_call_count()

  # Adapt run_context to an instance of OnRunStartRequest for invoking
  # superclass on_run_start().
  on_run_start_request = framework.OnRunStartRequest(
      run_context.original_args.fetches, run_context.original_args.feed_dict,
      None, None, self._session_wrapper.run_call_count)
  on_run_start_response = self._session_wrapper.on_run_start(
      on_run_start_request)
  self._performed_action = on_run_start_response.action

  run_args = session_run_hook.SessionRunArgs(
      None, feed_dict=None, options=config_pb2.RunOptions())
  if self._performed_action == framework.OnRunStartAction.DEBUG_RUN:
    # Attach debug ops to the graph for this run, honoring the filters the
    # user chose in the CLI.
    # pylint: disable=protected-access
    self._session_wrapper._decorate_run_options_for_debug(
        run_args.options,
        on_run_start_response.debug_urls,
        debug_ops=on_run_start_response.debug_ops,
        node_name_regex_whitelist=(
            on_run_start_response.node_name_regex_whitelist),
        op_type_regex_whitelist=(
            on_run_start_response.op_type_regex_whitelist),
        tensor_dtype_regex_whitelist=(
            on_run_start_response.tensor_dtype_regex_whitelist),
        tolerate_debug_op_creation_failures=(
            on_run_start_response.tolerate_debug_op_creation_failures))
    # pylint: enable=protected-access
  elif self._performed_action == framework.OnRunStartAction.PROFILE_RUN:
    # pylint: disable=protected-access
    self._session_wrapper._decorate_run_options_for_profile(run_args.options)
    # pylint: enable=protected-access
  elif self._performed_action == framework.OnRunStartAction.INVOKE_STEPPER:
    # The _finalized property must be set to False so that the NodeStepper
    # can insert ops for retrieving TensorHandles.
    # pylint: disable=protected-access
    run_context.session.graph._finalized = False
    # pylint: enable=protected-access

  return run_args
def before_run(self, run_context):
  """Fetch the global step and the current loss for this run.

  Uses the explicitly configured loss op when present; otherwise looks the
  loss op up by name in the run session's graph.
  """
  if self.loss_op is not None:
    loss = self.loss_op
  else:
    loss = run_context.session.graph.get_operation_by_name(
        LOSS_NAME).outputs[0]
  return session_run_hook.SessionRunArgs(
      {'global_step': training_util.get_global_step(), 'current_loss': loss})
def before_run(self, _run_context):
  """Request prediction/target fetches when the step timer fires."""
  self._should_trigger = self._timer.should_trigger_for_step(self._iter_count)
  if not self._should_trigger:
    # Nothing to fetch on non-triggering steps.
    return None
  return session_run_hook.SessionRunArgs({
      "predicted_words": self.predicted_words,
      "target_words": self.labels_dict["target_tokens"],
      "target_len": self.labels_dict["target_len"],
  })
def before_run(self, run_context):
  """Requests step and, periodically, loss fetches for early stopping.

  Every `stopping_step` steps — skipping a repeat of the same step and any
  step at or before `start_step` — additionally fetches the loss tensor,
  feeding the configured placeholder values so the loss can be evaluated.

  Args:
    run_context: A `session_run_hook.SessionRunContext` for the upcoming run.

  Returns:
    A `session_run_hook.SessionRunArgs` requesting at least the global step,
    plus the loss (with feeds) when an early-stopping check is due.
  """
  # Fixed `not self._step == self._prev_step` -> `!=`; flattened the nested
  # condition into a guard clause.
  should_check = (
      self._step % self.stopping_step == 0
      and self._step != self._prev_step  # don't re-check the same step
      and self._step > self.start_step)
  if not should_check:
    return session_run_hook.SessionRunArgs({'step': self._global_step_tensor})

  print("\n[ Early Stopping Check ]")
  # Get graph from run_context session
  graph = run_context.session.graph
  # Retrieve loss tensor from graph
  loss_tensor = graph.get_tensor_by_name(self.loss_name)
  # Resolve placeholder names in self.feed_dict to tensors in this graph.
  fd = {graph.get_tensor_by_name(key): value
        for key, value in self.feed_dict.items()}
  return session_run_hook.SessionRunArgs(
      {'step': self._global_step_tensor, 'loss': loss_tensor}, feed_dict=fd)
def before_run(self, run_context):
  """Fetch the step, loss, and the two named printing ops from the graph."""
  graph = run_context.session.graph

  def _first_output(op_name):
    # First output tensor of the op with the given name.
    return graph.get_operation_by_name(op_name).outputs[0]

  return session_run_hook.SessionRunArgs({
      'global_step': contrib_framework.get_global_step(),
      'current_loss': _first_output('rf_training_loss'),
      'confusion_matrix_print': _first_output('confusion_matrix_print'),
      'regression_ornot': _first_output('regression_ornot'),
  })
def before_run(self, run_context):
  """Lazily finishes wrapper initialization, then sets up the debug run.

  On first call, runs the LocalCLIDebugWrapperSession initializer directly
  on `self` (the session is only available here, not at hook construction)
  and registers any tensor filters queued before initialization. Then asks
  `on_run_start()` what action to take and decorates the RunOptions (or
  invokes the node stepper) accordingly.

  Args:
    run_context: A `session_run_hook.SessionRunContext` carrying the session
      and the original run() fetches/feed_dict.

  Returns:
    A `session_run_hook.SessionRunArgs` with no extra fetches but with
    `options` decorated for the selected action.
  """
  if not self._wrapper_initialized:
    # Initialize the superclass state directly on `self`, using the session
    # from run_context.
    local_cli_wrapper.LocalCLIDebugWrapperSession.__init__(
        self,
        run_context.session,
        ui_type=self._ui_type,
        dump_root=self._dump_root)

    # Actually register tensor filters registered prior to the construction
    # of the underlying LocalCLIDebugWrapperSession object.
    for filter_name in self._pending_tensor_filters:
      local_cli_wrapper.LocalCLIDebugWrapperSession.add_tensor_filter(
          self, filter_name, self._pending_tensor_filters[filter_name])

    self._wrapper_initialized = True

  # Increment run call counter.
  self._run_call_count += 1

  # Adapt run_context to an instance of OnRunStartRequest for invoking
  # superclass on_run_start().
  on_run_start_request = framework.OnRunStartRequest(
      run_context.original_args.fetches, run_context.original_args.feed_dict,
      None, None, self._run_call_count)
  on_run_start_response = self.on_run_start(on_run_start_request)
  self._performed_action = on_run_start_response.action

  run_args = session_run_hook.SessionRunArgs(
      None, feed_dict=None, options=config_pb2.RunOptions())
  if self._performed_action == framework.OnRunStartAction.DEBUG_RUN:
    # Attach debug ops to the RunOptions for this run.
    self._decorate_options_for_debug(run_args.options,
                                     run_context.session.graph)
  elif self._performed_action == framework.OnRunStartAction.INVOKE_STEPPER:
    # The _finalized property must be set to False so that the NodeStepper
    # can insert ops for retrieving TensorHandles.
    # pylint: disable=protected-access
    run_context.session.graph._finalized = False
    # pylint: enable=protected-access

    with stepper.NodeStepper(
        run_context.session,
        run_context.original_args.fetches,
        run_context.original_args.feed_dict) as node_stepper:
      self.invoke_node_stepper(
          node_stepper, restore_variable_values_on_exit=True)

  return run_args
def before_run(self, run_context):
  """Collect fetch requests from each monitor's step_begin()."""
  if self._last_step is None:
    # Lazily seed the step counter from the graph's global step.
    self._last_step = run_context.session.run(self._global_step_tensor) + 1

  request = {self._global_step_tensor: self._global_step_tensor}
  monitor_fetches = []
  for monitor in self._monitors:
    requested = monitor.step_begin(self._last_step)
    if requested:
      if not isinstance(requested, list):
        raise ValueError("Monitor.step_begin should return a list.")
      monitor_fetches.extend(requested)
  if monitor_fetches:
    # Map each requested fetch to its resolved graph element.
    request["monitors"] = {
        fetch: _as_graph_element(fetch) for fetch in monitor_fetches}
  return session_run_hook.SessionRunArgs(request)
def before_run(self, run_context):
  """Runs the appropriate prep ops, and requests running update ops."""
  sess = run_context.session

  # On the very first call, run the cache initialization ops.
  if not self._is_initialized:
    logging.info("SweepHook running cache init ops.")
    for init_op in self._cache_init_ops:
      sess.run(init_op)

  # At the start of each sweep (and on first use), run the prep ops for
  # whichever sweep direction is currently active.
  if self._is_sweep_done or not self._is_initialized:
    logging.info("SweepHook running sweep prep ops.")
    on_row_sweep = sess.run(self._is_row_sweep_var)
    for prep_op in (self._row_prep_ops if on_row_sweep
                    else self._col_prep_ops):
      sess.run(prep_op)

  self._is_initialized = True

  # Request running the switch_ops and the global_step_incr_op
  logging.info("Partial fit starting.")
  return session_run_hook.SessionRunArgs(fetches=self._fetches)
def before_run(self, run_context):
  """Runs the appropriate prep ops, and requests running update ops."""
  sess = run_context.session
  sweep_finished = sess.run(self._is_sweep_done_var)

  if not self._is_initialized:
    logging.info("SweepHook running init op.")
    sess.run(self._init_op)
  if sweep_finished:
    # Flip the row/col sweep state before reading it below.
    sess.run(self._switch_op)

  row_sweep_active = sess.run(self._is_row_sweep_var)
  if sweep_finished or not self._is_initialized:
    logging.info("SweepHook running prep ops for the {} sweep.".format(
        "row" if row_sweep_active else "col"))
    for prep_op in (self._row_prep_ops if row_sweep_active
                    else self._col_prep_ops):
      sess.run(prep_op)

  self._is_initialized = True
  logging.info("Next fit step starting.")
  # Train whichever factor the current sweep targets.
  return session_run_hook.SessionRunArgs(fetches=[
      self._row_train_op if row_sweep_active else self._col_train_op
  ])
def before_run(self, run_context): """Called right before a session is run. Args: run_context: A session_run_hook.SessionRunContext. Encapsulates information on the run. Returns: A session_run_hook.SessionRunArgs object. """ if not self._grpc_debug_wrapper_session: self._grpc_debug_wrapper_session = grpc_wrapper.GrpcDebugWrapperSession( run_context.session, self._grpc_debug_server_addresses, watch_fn=self._watch_fn, thread_name_filter=self._thread_name_filter, log_usage=self._log_usage) fetches = run_context.original_args.fetches feed_dict = run_context.original_args.feed_dict watch_options = self._watch_fn(fetches, feed_dict) run_options = config_pb2.RunOptions() debug_utils.watch_graph( run_options, run_context.session.graph, debug_urls=self._grpc_debug_wrapper_session.prepare_run_debug_urls( fetches, feed_dict), debug_ops=watch_options.debug_ops, node_name_regex_whitelist=watch_options.node_name_regex_whitelist, op_type_regex_whitelist=watch_options.op_type_regex_whitelist, tensor_dtype_regex_whitelist=watch_options. tensor_dtype_regex_whitelist, tolerate_debug_op_creation_failures=( watch_options.tolerate_debug_op_creation_failures)) return session_run_hook.SessionRunArgs(None, feed_dict=None, options=run_options)
def before_run(self, run_context):
  """Runs the appropriate prep ops, and requests running update ops."""
  sess = run_context.session
  sweep_done = sess.run(self._is_sweep_done_var)

  if not self._is_initialized:
    logging.info("SweepHook running cache init op.")
    sess.run(self._init_op)
  if sweep_done:
    # Flip the row/col sweep state for the next sweep.
    sess.run(self._switch_op)

  if sweep_done or not self._is_initialized:
    logging.info("SweepHook running sweep prep ops.")
    on_row_sweep = sess.run(self._is_row_sweep_var)
    ops_to_run = self._row_prep_ops if on_row_sweep else self._col_prep_ops
    for prep_op in ops_to_run:
      sess.run(prep_op)

  self._is_initialized = True

  # Requests running `self._update_op` jointly with the training op.
  logging.info("Next fit step starting.")
  return session_run_hook.SessionRunArgs(fetches=[self._update_op])
def before_run(self, run_context):
  """Configures RunOptions so this run dumps debug tensors to disk.

  Lazily builds a DumpingDebugWrapperSession around the session on first
  use; on that first call only, requests a reset of the disk byte-usage
  counter via `reset_disk_byte_usage`.

  Args:
    run_context: A `session_run_hook.SessionRunContext` carrying the session
      and the original run() fetches/feed_dict.

  Returns:
    A `session_run_hook.SessionRunArgs` with no extra fetches but with
    `options` decorated to watch the graph.
  """
  reset_disk_byte_usage = False
  if not self._session_wrapper:
    self._session_wrapper = dumping_wrapper.DumpingDebugWrapperSession(
        run_context.session,
        self._session_root,
        watch_fn=self._watch_fn,
        thread_name_filter=self._thread_name_filter,
        log_usage=self._log_usage)
    # Only reset byte-usage accounting on the very first run.
    reset_disk_byte_usage = True

  self._session_wrapper.increment_run_call_count()

  # pylint: disable=protected-access
  debug_urls, watch_options = self._session_wrapper._prepare_run_watch_config(
      run_context.original_args.fetches, run_context.original_args.feed_dict)
  # pylint: enable=protected-access
  run_options = config_pb2.RunOptions()
  debug_utils.watch_graph(
      run_options,
      run_context.session.graph,
      debug_urls=debug_urls,
      debug_ops=watch_options.debug_ops,
      node_name_regex_allowlist=watch_options.node_name_regex_allowlist,
      op_type_regex_allowlist=watch_options.op_type_regex_allowlist,
      tensor_dtype_regex_allowlist=watch_options.tensor_dtype_regex_allowlist,
      tolerate_debug_op_creation_failures=(
          watch_options.tolerate_debug_op_creation_failures),
      reset_disk_byte_usage=reset_disk_byte_usage)

  run_args = session_run_hook.SessionRunArgs(
      None, feed_dict=None, options=run_options)
  return run_args
def before_run(self, run_context):
  """Return the update_weights op so that it is executed during this run."""
  del run_context  # Unused.
  return session_run_hook.SessionRunArgs(fetches=self._update_op)
def before_run(self, run_context):
  """Request the summary op under the 'summary' key."""
  del run_context  # Unused.
  fetches = {'summary': self._summary_op}
  return session_run_hook.SessionRunArgs(fetches)
def before_run(self, run_context):
  """Request the current number-of-trees tensor as the fetch."""
  del run_context  # Unused.
  return session_run_hook.SessionRunArgs(fetches=self._num_trees_tensor)
def before_run(self, run_context):  # pylint: disable=unused-argument
  """Fetch the current tensors when the hook may run, else just the episode."""
  if not can_run_hook(run_context):
    return session_run_hook.SessionRunArgs(
        {'global_episode': self._global_episode_tensor})
  return session_run_hook.SessionRunArgs(self._current_tensors)
def before_run(self, run_context):  # pylint: disable=unused-argument
  """Request the global episode counter as the sole fetch."""
  return session_run_hook.SessionRunArgs(fetches=self._global_episode_tensor)
def before_run(self, run_context):  # pylint: disable=unused-argument
  """Supply the feeds produced by feed_fn(); nothing extra is fetched."""
  feeds = self.feed_fn()
  return session_run_hook.SessionRunArgs(fetches=None, feed_dict=feeds)
def before_run(self, run_context):
  """Fetch the global step together with the tokens-processed update op."""
  del run_context  # Unused.
  fetches = [self._global_step_tensor, self._tokens_processed_add]
  return session_run_hook.SessionRunArgs(fetches)
def before_run(self, run_context):
  """Fetch the (possibly newly created) eval-step counter."""
  del run_context  # Unused.
  # pylint: disable=protected-access
  eval_step = evaluation._get_or_create_eval_step()
  # pylint: enable=protected-access
  return session_run_hook.SessionRunArgs({'eval_steps': eval_step})
def before_run(self, run_context):
  """Request the global step tensor; the run context is not needed."""
  del run_context
  return session_run_hook.SessionRunArgs(fetches=self._global_step_tensor)
def before_run(self, run_context):
  """Fetch the optimizer's global step and current-iteration counters."""
  del run_context  # Unused.
  # pylint: disable=protected-access
  fetches = [
      self._fed_avg_optimizer._global_step,
      self._fed_avg_optimizer._curr_iter,
  ]
  # pylint: enable=protected-access
  return session_run_hook.SessionRunArgs(fetches)
def before_run(self, run_context):
  """Fetch the stop variable's current value."""
  del run_context
  return session_run_hook.SessionRunArgs(fetches=self._stop_var)
def before_run(self, run_context):
  """Request the loss tensor as the run fetch."""
  del run_context  # unused
  return session_run_hook.SessionRunArgs(fetches=self._loss_tensor)
def before_run(self, run_context):
  """Fetch the finalized-trees and attempted-layers counters."""
  del run_context  # Unused.
  fetches = [
      self._num_finalized_trees_tensor,
      self._num_attempted_layers_tensor,
  ]
  return session_run_hook.SessionRunArgs(fetches)
def before_run(self, run_context):
  """Request the completed-sweeps variable as the fetch."""
  del run_context  # Unused.
  return session_run_hook.SessionRunArgs(fetches=self._completed_sweeps_var)
def before_run(self, run_context):
  """Fetch the evals-completed counter under a named key."""
  del run_context  # Unused.
  fetches = {'evals_completed': self._evals_completed}
  return session_run_hook.SessionRunArgs(fetches)
def before_run(self, run_context):
  """Fetch the global step and the stop variable together."""
  del run_context
  fetches = {
      'global_step': self._global_step_tensor,
      'stop_var': self._stop_var,
  }
  return session_run_hook.SessionRunArgs(fetches)
def before_run(self, _):
  """Request the module-level fetch target `b` on every run."""
  # NOTE(review): `b` is defined outside this block — presumably a tensor
  # or op to fetch each step; confirm at its definition site.
  return session_run_hook.SessionRunArgs(b)