Esempio n. 1
0
 def before_run(self, run_context):
     """Request the global step and the named loss tensor for this run.

     The loss is the first output of the operation called ``LOSS_NAME`` in
     the session's graph.

     Args:
       run_context: The SessionRunContext for the upcoming run.

     Returns:
       A SessionRunArgs whose fetches are a dict keyed by 'global_step' and
       'current_loss'.
     """
     step = contrib_framework.get_global_step()
     loss_op = run_context.session.graph.get_operation_by_name(LOSS_NAME)
     fetches = {'global_step': step, 'current_loss': loss_op.outputs[0]}
     return session_run_hook.SessionRunArgs(fetches)
Esempio n. 2
0
    def before_run(self, run_context):
        """Set up the CLI debug wrapper and decorate RunOptions for this run.

        On the first call this builds a LocalCLIDebugWrapperSession around
        ``run_context.session`` and hands it every tensor filter that was
        queued in ``self._pending_tensor_filters`` before the wrapper existed.
        Each call then increments the wrapper's run-call counter, asks the
        wrapper's ``on_run_start`` what to do, stores that action in
        ``self._performed_action``, and decorates fresh RunOptions for it.

        Args:
          run_context: The SessionRunContext for the upcoming run.

        Returns:
          A SessionRunArgs with no fetches or feeds, carrying RunOptions that
          may have been decorated for a debug or profile run.
        """
        if not self._session_wrapper:
            wrapper = local_cli_wrapper.LocalCLIDebugWrapperSession(
                run_context.session,
                ui_type=self._ui_type,
                dump_root=self._dump_root,
                thread_name_filter=self._thread_name_filter)
            self._session_wrapper = wrapper
            # Filters added before the wrapper existed were only queued;
            # register them with the wrapper now.
            for name, tensor_filter in self._pending_tensor_filters.items():
                wrapper.add_tensor_filter(name, tensor_filter)
        else:
            wrapper = self._session_wrapper

        wrapper.increment_run_call_count()

        # Re-express the hook's run context as the request object that the
        # wrapper session's on_run_start() understands.
        original_args = run_context.original_args
        request = framework.OnRunStartRequest(
            original_args.fetches,
            original_args.feed_dict, None, None,
            wrapper.run_call_count)
        response = wrapper.on_run_start(request)
        action = response.action
        self._performed_action = action

        run_options = config_pb2.RunOptions()
        run_args = session_run_hook.SessionRunArgs(
            None, feed_dict=None, options=run_options)

        # pylint: disable=protected-access
        if action == framework.OnRunStartAction.DEBUG_RUN:
            wrapper._decorate_run_options_for_debug(
                run_options,
                response.debug_urls,
                debug_ops=response.debug_ops,
                node_name_regex_whitelist=response.node_name_regex_whitelist,
                op_type_regex_whitelist=response.op_type_regex_whitelist,
                tensor_dtype_regex_whitelist=(
                    response.tensor_dtype_regex_whitelist),
                tolerate_debug_op_creation_failures=(
                    response.tolerate_debug_op_creation_failures))
        elif action == framework.OnRunStartAction.PROFILE_RUN:
            wrapper._decorate_run_options_for_profile(run_options)
        elif action == framework.OnRunStartAction.INVOKE_STEPPER:
            # Un-finalize the graph so the NodeStepper can insert the ops it
            # uses to retrieve TensorHandles.
            run_context.session.graph._finalized = False
        # pylint: enable=protected-access

        return run_args
Esempio n. 3
0
 def before_run(self, run_context):
     """Request the global step and the loss tensor for this run.

     Uses ``self.loss_op`` when it is set; otherwise falls back to the first
     output of the graph operation named ``LOSS_NAME``.

     Args:
       run_context: The SessionRunContext for the upcoming run.

     Returns:
       A SessionRunArgs whose fetches are a dict keyed by 'global_step' and
       'current_loss'.
     """
     if self.loss_op is None:
         graph = run_context.session.graph
         loss = graph.get_operation_by_name(LOSS_NAME).outputs[0]
     else:
         loss = self.loss_op
     fetches = {
         'global_step': training_util.get_global_step(),
         'current_loss': loss,
     }
     return session_run_hook.SessionRunArgs(fetches)
Esempio n. 4
0
 def before_run(self, _run_context):
     """Request word-level predictions and targets when the timer fires.

     The timer decision for ``self._iter_count`` is stored in
     ``self._should_trigger``. When it does not fire, nothing is requested.

     Returns:
       A SessionRunArgs fetching 'predicted_words', 'target_words' and
       'target_len', or None when the timer does not trigger.
     """
     should_trigger = self._timer.should_trigger_for_step(self._iter_count)
     self._should_trigger = should_trigger
     if not should_trigger:
         return None

     labels = self.labels_dict
     return session_run_hook.SessionRunArgs({
         "predicted_words": self.predicted_words,
         "target_words": labels["target_tokens"],
         "target_len": labels["target_len"],
     })
Esempio n. 5
0
    def before_run(self, run_context):
        """Request the global step, plus the loss on early-stopping check steps.

        A check step is one where ``self._step`` is a multiple of
        ``self.stopping_step``, differs from ``self._prev_step``, and is greater
        than ``self.start_step``. On check steps the tensor named
        ``self.loss_name`` is also fetched. Each key of ``self.feed_dict`` is
        resolved as a tensor name in the session's graph and fed its value.

        Args:
          run_context: The SessionRunContext for the upcoming run.

        Returns:
          A SessionRunArgs fetching {'step'} on ordinary steps, or
          {'step', 'loss'} with the resolved feed dict on check steps.
        """
        is_check_step = (self._step % self.stopping_step == 0
                         and self._step != self._prev_step
                         and self._step > self.start_step)
        if not is_check_step:
            return session_run_hook.SessionRunArgs(
                {'step': self._global_step_tensor})

        print("\n[ Early Stopping Check ]")

        graph = run_context.session.graph
        loss_tensor = graph.get_tensor_by_name(self.loss_name)

        # Map tensor names to the graph's actual tensors so they can be fed.
        fd = {graph.get_tensor_by_name(name): value
              for name, value in self.feed_dict.items()}

        return session_run_hook.SessionRunArgs(
            {'step': self._global_step_tensor, 'loss': loss_tensor},
            feed_dict=fd)
Esempio n. 6
0
 def before_run(self, run_context):
     """Request the global step and several named graph outputs.

     Besides the global step, this fetches the first output of the graph
     operations 'rf_training_loss', 'confusion_matrix_print' and
     'regression_ornot'.

     Args:
       run_context: The SessionRunContext for the upcoming run.

     Returns:
       A SessionRunArgs whose fetches dict is keyed by 'global_step',
       'current_loss', 'confusion_matrix_print' and 'regression_ornot'.
     """
     fetches = {'global_step': contrib_framework.get_global_step()}
     graph = run_context.session.graph
     # (fetch key, graph operation name) pairs, in request order.
     named_outputs = (
         ('current_loss', 'rf_training_loss'),
         ('confusion_matrix_print', 'confusion_matrix_print'),
         ('regression_ornot', 'regression_ornot'),
     )
     for key, op_name in named_outputs:
         fetches[key] = graph.get_operation_by_name(op_name).outputs[0]
     return session_run_hook.SessionRunArgs(fetches)
Esempio n. 7
0
    def before_run(self, run_context):
        """Lazily initialize the CLI debug wrapper state and prepare this run.

        On the first call, ``LocalCLIDebugWrapperSession.__init__`` is invoked
        on this hook object itself with ``run_context.session``, and queued
        tensor filters are registered through the wrapper's
        ``add_tensor_filter``. Each call then increments
        ``self._run_call_count``, asks ``self.on_run_start`` what to do, and
        records the action in ``self._performed_action``. A DEBUG_RUN
        decorates the RunOptions. An INVOKE_STEPPER un-finalizes the graph and
        runs the node stepper.

        Args:
          run_context: The SessionRunContext for the upcoming run.

        Returns:
          A SessionRunArgs with no fetches or feeds and possibly-decorated
          RunOptions.
        """
        if not self._wrapper_initialized:
            wrapper_cls = local_cli_wrapper.LocalCLIDebugWrapperSession
            # NOTE(review): the wrapper is initialized on `self`, so this hook
            # presumably subclasses LocalCLIDebugWrapperSession — confirm.
            wrapper_cls.__init__(
                self,
                run_context.session,
                ui_type=self._ui_type,
                dump_root=self._dump_root)

            # Register filters queued before the wrapper was initialized, via
            # the wrapper's own implementation of add_tensor_filter.
            for name, tensor_filter in self._pending_tensor_filters.items():
                wrapper_cls.add_tensor_filter(self, name, tensor_filter)

            self._wrapper_initialized = True

        self._run_call_count += 1

        # Re-express the run context as the request on_run_start() expects.
        original_args = run_context.original_args
        request = framework.OnRunStartRequest(
            original_args.fetches,
            original_args.feed_dict, None, None,
            self._run_call_count)
        action = self.on_run_start(request).action
        self._performed_action = action

        run_options = config_pb2.RunOptions()
        run_args = session_run_hook.SessionRunArgs(
            None, feed_dict=None, options=run_options)

        if action == framework.OnRunStartAction.DEBUG_RUN:
            self._decorate_options_for_debug(
                run_options, run_context.session.graph)
        elif action == framework.OnRunStartAction.INVOKE_STEPPER:
            # Un-finalize the graph so the NodeStepper can insert the ops it
            # uses to retrieve TensorHandles.
            # pylint: disable=protected-access
            run_context.session.graph._finalized = False
            # pylint: enable=protected-access
            with stepper.NodeStepper(
                    run_context.session, original_args.fetches,
                    original_args.feed_dict) as node_stepper:
                self.invoke_node_stepper(
                    node_stepper, restore_variable_values_on_exit=True)

        return run_args
  def before_run(self, run_context):
    """Collect fetch requests from monitors and request the global step.

    On the first call, ``self._last_step`` is initialized to the current
    global step value plus one. Every monitor's ``step_begin`` is then called
    with that step. Non-empty results must be lists and are merged into a
    single 'monitors' fetch dict, mapping each request to its graph element.

    Args:
      run_context: The SessionRunContext for the upcoming run.

    Returns:
      A SessionRunArgs fetching the global step (keyed by its own tensor)
      and, when any monitor asked for something, a 'monitors' sub-dict.

    Raises:
      ValueError: If a monitor's ``step_begin`` returns a truthy non-list.
    """
    if self._last_step is None:
      current = run_context.session.run(self._global_step_tensor)
      self._last_step = current + 1

    monitor_fetches = []
    for monitor in self._monitors:
      requested = monitor.step_begin(self._last_step)
      if not requested:
        continue
      if not isinstance(requested, list):
        raise ValueError("Monitor.step_begin should return a list.")
      monitor_fetches.extend(requested)

    request = {self._global_step_tensor: self._global_step_tensor}
    if monitor_fetches:
      request["monitors"] = {
          fetch: _as_graph_element(fetch) for fetch in monitor_fetches}

    return session_run_hook.SessionRunArgs(request)
Esempio n. 9
0
  def before_run(self, run_context):
    """Runs the appropriate prep ops, and requests running update ops.

    The first call runs every op in ``self._cache_init_ops``. On the first
    call, or whenever ``self._is_sweep_done`` is truthy, the row or column
    prep ops run, chosen by the value of ``self._is_row_sweep_var``.

    Args:
      run_context: The SessionRunContext for the upcoming run.

    Returns:
      A SessionRunArgs whose fetches are ``self._fetches``.
    """
    sess = run_context.session
    first_call = not self._is_initialized

    if first_call:
      logging.info("SweepHook running cache init ops.")
      for op in self._cache_init_ops:
        sess.run(op)

    if self._is_sweep_done or first_call:
      logging.info("SweepHook running sweep prep ops.")
      if sess.run(self._is_row_sweep_var):
        prep_ops = self._row_prep_ops
      else:
        prep_ops = self._col_prep_ops
      for op in prep_ops:
        sess.run(op)

    self._is_initialized = True

    logging.info("Partial fit starting.")
    return session_run_hook.SessionRunArgs(fetches=self._fetches)
Esempio n. 10
0
 def before_run(self, run_context):
     """Runs the appropriate prep ops, and requests running update ops.

     Reads ``self._is_sweep_done_var``. The init op runs on the first call,
     and the switch op runs when a sweep just finished. The row or column
     prep ops (chosen by ``self._is_row_sweep_var``) run on the first call
     or after a finished sweep.

     Args:
       run_context: The SessionRunContext for the upcoming run.

     Returns:
       A SessionRunArgs fetching the row or column train op, matching the
       current sweep direction.
     """
     sess = run_context.session
     sweep_finished = sess.run(self._is_sweep_done_var)
     needs_init = not self._is_initialized

     if needs_init:
         logging.info("SweepHook running init op.")
         sess.run(self._init_op)
     if sweep_finished:
         sess.run(self._switch_op)

     row_sweep = sess.run(self._is_row_sweep_var)
     if sweep_finished or needs_init:
         logging.info("SweepHook running prep ops for the {} sweep.".format(
             "row" if row_sweep else "col"))
         for op in (self._row_prep_ops if row_sweep else self._col_prep_ops):
             sess.run(op)

     self._is_initialized = True
     logging.info("Next fit step starting.")
     train_op = self._row_train_op if row_sweep else self._col_train_op
     return session_run_hook.SessionRunArgs(fetches=[train_op])
Esempio n. 11
0
    def before_run(self, run_context):
        """Called right before a session is run.

        Lazily creates the GrpcDebugWrapperSession on first use, then builds
        RunOptions that watch the graph according to ``self._watch_fn``, with
        debug URLs supplied by the wrapper.

        Args:
          run_context: A session_run_hook.SessionRunContext. Encapsulates
            information on the run.

        Returns:
          A session_run_hook.SessionRunArgs object.
        """
        if not self._grpc_debug_wrapper_session:
            self._grpc_debug_wrapper_session = grpc_wrapper.GrpcDebugWrapperSession(
                run_context.session,
                self._grpc_debug_server_addresses,
                watch_fn=self._watch_fn,
                thread_name_filter=self._thread_name_filter,
                log_usage=self._log_usage)

        original_args = run_context.original_args
        fetches = original_args.fetches
        feed_dict = original_args.feed_dict
        # NOTE(review): self._watch_fn is called directly, so this assumes it
        # is set and returns an object exposing the attributes used below.
        opts = self._watch_fn(fetches, feed_dict)
        run_options = config_pb2.RunOptions()
        debug_urls = self._grpc_debug_wrapper_session.prepare_run_debug_urls(
            fetches, feed_dict)

        debug_utils.watch_graph(
            run_options,
            run_context.session.graph,
            debug_urls=debug_urls,
            debug_ops=opts.debug_ops,
            node_name_regex_whitelist=opts.node_name_regex_whitelist,
            op_type_regex_whitelist=opts.op_type_regex_whitelist,
            tensor_dtype_regex_whitelist=opts.tensor_dtype_regex_whitelist,
            tolerate_debug_op_creation_failures=(
                opts.tolerate_debug_op_creation_failures))

        return session_run_hook.SessionRunArgs(
            None, feed_dict=None, options=run_options)
Esempio n. 12
0
    def before_run(self, run_context):
        """Runs the appropriate prep ops, and requests running update ops.

        The init op runs on the first call, and the switch op runs when
        ``self._is_sweep_done_var`` is truthy. On either condition, the row or
        column prep ops (chosen by ``self._is_row_sweep_var``) run as well.

        Args:
          run_context: The SessionRunContext for the upcoming run.

        Returns:
          A SessionRunArgs fetching ``[self._update_op]`` so it runs jointly
          with the training op.
        """
        sess = run_context.session
        sweep_done = sess.run(self._is_sweep_done_var)
        first_call = not self._is_initialized

        if first_call:
            logging.info("SweepHook running cache init op.")
            sess.run(self._init_op)
        if sweep_done:
            sess.run(self._switch_op)

        if sweep_done or first_call:
            logging.info("SweepHook running sweep prep ops.")
            if sess.run(self._is_row_sweep_var):
                prep_ops = self._row_prep_ops
            else:
                prep_ops = self._col_prep_ops
            for op in prep_ops:
                sess.run(op)

        self._is_initialized = True

        logging.info("Next fit step starting.")
        return session_run_hook.SessionRunArgs(fetches=[self._update_op])
Esempio n. 13
0
    def before_run(self, run_context):
        """Build RunOptions that dump watched tensors for this run.

        The DumpingDebugWrapperSession is created on first use. On that call,
        ``reset_disk_byte_usage`` is passed as True to ``watch_graph``; on
        later calls it is False. Each call increments the wrapper's run-call
        counter and asks it for the debug URLs and watch options of the
        original fetches and feeds.

        Args:
          run_context: The SessionRunContext for the upcoming run.

        Returns:
          A SessionRunArgs with no fetches or feeds and the watch-configured
          RunOptions.
        """
        first_run = not self._session_wrapper
        if first_run:
            self._session_wrapper = dumping_wrapper.DumpingDebugWrapperSession(
                run_context.session,
                self._session_root,
                watch_fn=self._watch_fn,
                thread_name_filter=self._thread_name_filter,
                log_usage=self._log_usage)

        wrapper = self._session_wrapper
        wrapper.increment_run_call_count()

        original_args = run_context.original_args
        # pylint: disable=protected-access
        debug_urls, opts = wrapper._prepare_run_watch_config(
            original_args.fetches, original_args.feed_dict)
        # pylint: enable=protected-access

        run_options = config_pb2.RunOptions()
        debug_utils.watch_graph(
            run_options,
            run_context.session.graph,
            debug_urls=debug_urls,
            debug_ops=opts.debug_ops,
            node_name_regex_allowlist=opts.node_name_regex_allowlist,
            op_type_regex_allowlist=opts.op_type_regex_allowlist,
            tensor_dtype_regex_allowlist=opts.tensor_dtype_regex_allowlist,
            tolerate_debug_op_creation_failures=(
                opts.tolerate_debug_op_creation_failures),
            reset_disk_byte_usage=first_run)

        return session_run_hook.SessionRunArgs(
            None, feed_dict=None, options=run_options)
Esempio n. 14
0
 def before_run(self, run_context):
   """Request ``self._update_op`` so it is executed during this run."""
   del run_context  # Unused.
   return session_run_hook.SessionRunArgs(fetches=self._update_op)
Esempio n. 15
0
 def before_run(self, run_context):
     """Request ``self._summary_op`` under the 'summary' key."""
     del run_context  # Unused.
     fetches = {'summary': self._summary_op}
     return session_run_hook.SessionRunArgs(fetches=fetches)
Esempio n. 16
0
 def before_run(self, run_context):
     """Request the value of ``self._num_trees_tensor`` on every run."""
     del run_context  # Unused.
     return session_run_hook.SessionRunArgs(fetches=self._num_trees_tensor)
Esempio n. 17
0
 def before_run(self, run_context):
     """Request the current tensors, or only the global episode if gated.

     ``can_run_hook(run_context)`` decides which request is made.

     Returns:
       A SessionRunArgs fetching ``self._current_tensors`` when the hook may
       run, otherwise a dict with just 'global_episode'.
     """
     if not can_run_hook(run_context):
         fallback = {'global_episode': self._global_episode_tensor}
         return session_run_hook.SessionRunArgs(fetches=fallback)
     return session_run_hook.SessionRunArgs(fetches=self._current_tensors)
Esempio n. 18
0
 def before_run(self, run_context):  # pylint: disable=unused-argument
     """Request ``self._global_episode_tensor`` on every run."""
     episode = self._global_episode_tensor
     return session_run_hook.SessionRunArgs(fetches=episode)
Esempio n. 19
0
 def before_run(self, run_context):  # pylint: disable=unused-argument
   """Feed the mapping returned by ``self.feed_fn()`` into this run.

   No extra fetches are requested.
   """
   feeds = self.feed_fn()
   return session_run_hook.SessionRunArgs(None, feed_dict=feeds)
Esempio n. 20
0
 def before_run(self, run_context):
   """Request the global step and ``self._tokens_processed_add`` together.

   NOTE(review): ``_tokens_processed_add`` is presumably an op that
   accumulates a processed-token count — confirm against its definition.
   """
   del run_context  # Unused.
   fetches = [self._global_step_tensor, self._tokens_processed_add]
   return session_run_hook.SessionRunArgs(fetches=fetches)
Esempio n. 21
0
 def before_run(self, run_context):
   """Request the evaluation step counter under the 'eval_steps' key.

   The counter comes from ``evaluation._get_or_create_eval_step()``.
   """
   del run_context  # Unused.
   eval_step = evaluation._get_or_create_eval_step()  # pylint: disable=protected-access
   return session_run_hook.SessionRunArgs(fetches={'eval_steps': eval_step})
Esempio n. 22
0
 def before_run(self, run_context):
   """Request ``self._global_step_tensor`` on every run."""
   del run_context  # Unused.
   step = self._global_step_tensor
   return session_run_hook.SessionRunArgs(fetches=step)
Esempio n. 23
0
 def before_run(self, run_context):
     """Request the optimizer's ``_global_step`` and ``_curr_iter``, in order."""
     del run_context  # Unused.
     optimizer = self._fed_avg_optimizer
     # pylint: disable=protected-access
     fetches = [optimizer._global_step, optimizer._curr_iter]
     # pylint: enable=protected-access
     return session_run_hook.SessionRunArgs(fetches=fetches)
Esempio n. 24
0
 def before_run(self, run_context):
   """Request the value of ``self._stop_var`` on every run."""
   del run_context  # Unused.
   stop_var = self._stop_var
   return session_run_hook.SessionRunArgs(fetches=stop_var)
Esempio n. 25
0
 def before_run(self, run_context):
     """Request ``self._loss_tensor`` on every run."""
     del run_context  # Unused.
     loss = self._loss_tensor
     return session_run_hook.SessionRunArgs(fetches=loss)
Esempio n. 26
0
 def before_run(self, run_context):
   """Request the finalized-tree and attempted-layer counter tensors."""
   del run_context  # Unused.
   fetches = [
       self._num_finalized_trees_tensor,
       self._num_attempted_layers_tensor,
   ]
   return session_run_hook.SessionRunArgs(fetches=fetches)
 def before_run(self, run_context):
     """Request ``self._completed_sweeps_var`` on every run."""
     del run_context  # Unused.
     completed = self._completed_sweeps_var
     return session_run_hook.SessionRunArgs(fetches=completed)
Esempio n. 28
0
 def before_run(self, run_context):
     """Request ``self._evals_completed`` under the 'evals_completed' key."""
     del run_context  # Unused.
     fetches = {'evals_completed': self._evals_completed}
     return session_run_hook.SessionRunArgs(fetches=fetches)
Esempio n. 29
0
 def before_run(self, run_context):
     """Request the global step and the stop variable, keyed by name."""
     del run_context  # Unused.
     fetches = dict(global_step=self._global_step_tensor,
                    stop_var=self._stop_var)
     return session_run_hook.SessionRunArgs(fetches=fetches)
 def before_run(self, _):
   """Request ``b`` as this run's fetches.

   NOTE(review): ``b`` is not defined in this block; presumably it is a
   variable from an enclosing scope (e.g. a test closure) — confirm.
   """
   return session_run_hook.SessionRunArgs(b)