Пример #1
0
    def testContWithRestoreVariablesOptionShouldRestoreVariableValue(self):
        # Runs two update ops through the CLI and checks, after each one, which
        # variables the sorted-node listing labels as dirty. The second cont()
        # call carries the -r and -i flags.
        with stepper.NodeStepper(self.sess, self.opt) as node_stepper:
            cli = stepper_cli.NodeStepperCLI(node_stepper)
            dirty = stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE

            update_a_args = [
                "opt/update_a/ApplyGradientDescent",
                "--invalidate_from_updated_variables",
            ]
            output = cli.cont(update_a_args)
            self.assertItemsEqual([self.a.name], _parse_updated(output.lines))

            # Continuing to .../update_a/... must flag Variable a as dirty and
            # leave Variable b unflagged.
            listing = cli.list_sorted_nodes([])
            node_names, stat_labels, _ = _parse_sorted_nodes_list(
                listing.lines)
            label_a = stat_labels[node_names.index("a")]
            self.assertIn(dirty, label_a)
            label_b = stat_labels[node_names.index("b")]
            self.assertNotIn(dirty, label_b)

            update_b_args = ["opt/update_b/ApplyGradientDescent", "-r", "-i"]
            output = cli.cont(update_b_args)
            self.assertItemsEqual([self.b.name], _parse_updated(output.lines))

            # With the -r flag, Variable a should have been restored and so
            # must no longer be dirty; Variable b, just updated, must be.
            listing = cli.list_sorted_nodes([])
            node_names, stat_labels, _ = _parse_sorted_nodes_list(
                listing.lines)
            label_b = stat_labels[node_names.index("b")]
            self.assertIn(dirty, label_b)
            label_a = stat_labels[node_names.index("a")]
            self.assertNotIn(dirty, label_a)
Пример #2
0
  def testContToValidNodeShouldUpdateStatus(self):
    # Continues to node "c" and then to node "d", checking after each step
    # that the node pointer and the status labels follow along, and that the
    # second step reports "c:0" as fed from a handle.
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))

    output = cli.list_sorted_nodes([])
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)

    # Before any cont() call, "c" carries an all-blank status label and the
    # pointer sits on the first node.
    index_c = node_names.index("c")
    self.assertEqual("     ", stat_labels[index_c])
    self.assertEqual(0, node_pointer)

    # Command arguments are passed as a list, like the other CLI calls in this
    # test; a bare string is a sequence of characters, not of arguments.
    output = cli.cont(["c"])
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)

    self.assertGreaterEqual(len(node_names), 3)
    self.assertIn("c", node_names)
    index_c = node_names.index("c")
    self.assertEqual(index_c, node_pointer)
    self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_c])

    output = cli.cont(["d"])
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)

    used_feed_types = _parsed_used_feeds(output.lines)
    self.assertEqual({"c:0": "handle"}, used_feed_types)

    self.assertGreaterEqual(len(node_names), 3)
    self.assertIn("d", node_names)
    index_d = node_names.index("d")
    self.assertEqual(index_d, node_pointer)
    self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_d])
Пример #3
0
    def testContToNodeWithoutOutputTensorInClosureShowsNoHandleCached(self):
        # Continuing to a node that has neither its ":0" nor its ":1" output
        # among the closure elements must not label that node with STATE_CONT.
        node_stepper = stepper.NodeStepper(self.sess, self.opt)

        sorted_nodes = node_stepper.sorted_nodes()
        closure_elements = node_stepper.closure_elements()

        def _has_output_in_closure(name):
            return (name + ":0" in closure_elements
                    or name + ":1" in closure_elements)

        # First sorted node without an output tensor in the closure, or None
        # if every node has one.
        no_output_node = next(
            (name for name in sorted_nodes
             if not _has_output_in_closure(name)),
            None)
        self.assertIsNotNone(no_output_node)

        cli = stepper_cli.NodeStepperCLI(node_stepper)
        cont_output = cli.cont([no_output_node])
        parsed = _parse_sorted_nodes_list(cont_output.lines)
        node_names, stat_labels, node_pointer = parsed

        pointed_name = node_names[node_pointer]
        self.assertEqual(no_output_node, pointed_name)
        pointed_label = stat_labels[node_pointer]
        self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT, pointed_label)
Пример #4
0
    def testContWithRestoreVariablesOptionShouldRestoreVariableValue(self):
        # Runs the update op of Variable a, then the update op of Variable b
        # with the -r flag, and checks the dirty-variable labels after each.
        node_stepper = stepper.NodeStepper(self.sess, self.opt)
        cli = stepper_cli.NodeStepperCLI(node_stepper)
        dirty = stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE

        cli.cont(["opt/update_a/ApplyGradientDescent"])

        # Continuing to .../update_a/... must flag Variable a as dirty and
        # leave Variable b unflagged.
        listing = cli.list_sorted_nodes([])
        node_names, stat_labels, _ = _parse_sorted_nodes_list(listing.lines)
        label_a = stat_labels[node_names.index("a")]
        self.assertIn(dirty, label_a)
        label_b = stat_labels[node_names.index("b")]
        self.assertNotIn(dirty, label_b)

        cli.cont(["opt/update_b/ApplyGradientDescent", "-r"])

        # With the -r flag, Variable a should have been restored and so must
        # no longer be dirty; Variable b, just updated, must be.
        listing = cli.list_sorted_nodes([])
        node_names, stat_labels, _ = _parse_sorted_nodes_list(listing.lines)
        label_b = stat_labels[node_names.index("b")]
        self.assertIn(dirty, label_b)
        label_a = stat_labels[node_names.index("a")]
        self.assertNotIn(dirty, label_a)
Пример #5
0
  def testInjectTensorValueByTensorNameShouldBeReflected(self):
    # Continues to node "d", injects a value for tensor "d:0" and checks that
    # the listing swaps the STATE_CONT label for STATE_OVERRIDDEN.
    node_stepper = stepper.NodeStepper(self.sess, self.e)
    cli = stepper_cli.NodeStepperCLI(node_stepper)
    cont_label = stepper_cli.NodeStepperCLI.STATE_CONT
    overridden_label = stepper_cli.NodeStepperCLI.STATE_OVERRIDDEN

    cont_output = cli.cont(["d"])
    node_names, _, node_pointer = _parse_sorted_nodes_list(cont_output.lines)
    self.assertEqual("d", node_names[node_pointer])

    # Right after cont(), "d" is marked as continued-to and not overridden.
    listing = cli.list_sorted_nodes([])
    node_names, stat_labels, _ = _parse_sorted_nodes_list(listing.lines)
    label_d = stat_labels[node_names.index("d")]
    self.assertIn(cont_label, label_d)
    self.assertNotIn(overridden_label, label_d)

    self.assertAllClose(-20.0, node_stepper.get_tensor_value("d:0"))

    cli.inject_value(["d:0", "20.0"])

    # The stepper itself must now report the override...
    self.assertEqual(["d:0"], node_stepper.override_names())

    # ...and the sorted-node listing must reflect the injection as well.
    listing = cli.list_sorted_nodes([])
    node_names, stat_labels, _ = _parse_sorted_nodes_list(listing.lines)
    label_d = stat_labels[node_names.index("d")]
    self.assertNotIn(cont_label, label_d)
    self.assertIn(overridden_label, label_d)
Пример #6
0
  def testContToNodeOutsideTransitiveClosureShouldError(self):
    # cont() to "f" on a stepper targeting self.e must produce exactly one
    # error line.
    expected_error = ("ERROR: f is not in the transitive closure of this "
                      "stepper instance.")
    node_stepper = stepper.NodeStepper(self.sess, self.e)
    cli = stepper_cli.NodeStepperCLI(node_stepper)

    result = cli.cont(["f"])

    self.assertEqual([expected_error], result.lines)
Пример #7
0
  def testPrintTensorWithNonexistentTensorShouldError(self):
    # print_tensor() on the unknown name "foobar" must produce exactly one
    # error line.
    expected_error = ("ERROR: foobar is not in the transitive closure of this "
                      "stepper instance.")
    node_stepper = stepper.NodeStepper(self.sess, self.e)
    cli = stepper_cli.NodeStepperCLI(node_stepper)

    result = cli.print_tensor(["foobar"])

    self.assertEqual([expected_error], result.lines)
Пример #8
0
  def testPrintTensorWithNoHandleShouldError(self):
    # print_tensor() on "e" before any stepping must report that the value of
    # tensor "e:0" is not accessible.
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))

    # The command arguments are passed as a list; a bare string is a sequence
    # of characters, not of arguments.
    output = cli.print_tensor(["e"])
    self.assertEqual([
        "This stepper instance does not have access to the value of tensor "
        "\"e:0\""
    ], output.lines)
Пример #9
0
  def testPrintTensorShouldWorkWithNodeNameWithOutputTensor(self):
    # After continuing to node "d", print_tensor() given the node name "d"
    # must print tensor "d:0" with value -20.0.
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))

    # Arguments are passed as a list, matching the print_tensor() call below;
    # a bare string is a sequence of characters, not of arguments.
    cli.cont(["d"])
    output = cli.print_tensor(["d"])

    self.assertEqual("Tensor \"d:0\":", output.lines[0])
    self.assertEqual("-20.0", output.lines[-1])
Пример #10
0
  def testContToNonexistentNodeShouldError(self):
    # cont() to the unknown name "foobar" must produce exactly one error line.
    expected_error = ("ERROR: foobar is not in the transitive closure of this "
                      "stepper instance.")
    node_stepper = stepper.NodeStepper(self.sess, self.f)
    cli = stepper_cli.NodeStepperCLI(node_stepper)

    result = cli.cont(["foobar"])

    self.assertEqual([expected_error], result.lines)
Пример #11
0
 def testCallingInvokeNodeStepperOnDumpingWrapperRaisesException(self):
     # invoke_node_stepper() on a DumpingDebugWrapperSession must raise
     # NotImplementedError with a message matching the pattern below.
     wrapped_sess = dumping_wrapper.DumpingDebugWrapperSession(
         self.sess, session_root=self.session_root, log_usage=False)
     node_stepper = stepper.NodeStepper(self.sess, self.inc_v)
     expected_message = (
         r"NonInteractiveDebugWrapperSession does not support node-stepper "
         r"mode\.")
     with self.assertRaisesRegexp(NotImplementedError, expected_message):
         wrapped_sess.invoke_node_stepper(node_stepper)
Пример #12
0
    def testPrintTensorShouldWorkWithTensorName(self):
        # After continuing to node "d", print_tensor() given the tensor name
        # "d:0" must print that tensor with value -20.0.
        with stepper.NodeStepper(self.sess, self.e) as node_stepper:
            cli = stepper_cli.NodeStepperCLI(node_stepper)

            # Arguments are passed as a list, matching the print_tensor() call
            # below; a bare string is a sequence of characters, not of
            # arguments.
            cli.cont(["d"])
            output = cli.print_tensor(["d:0"])

            self.assertEqual("Tensor \"d:0\":", output.lines[0])
            self.assertEqual("-20.0", output.lines[-1])
Пример #13
0
    def testInjectToNonexistentTensorShouldError(self):
        # inject_value() on the unknown tensor "foobar:0" must produce exactly
        # one error line.
        expected_error = ("ERROR: foobar:0 is not in the transitive closure "
                          "of this stepper instance.")
        cli = stepper_cli.NodeStepperCLI(
            stepper.NodeStepper(self.sess, self.e))

        result = cli.inject_value(["foobar:0", "20.0"])

        self.assertEqual([expected_error], result.lines)
Пример #14
0
  def testContToUpdateNodeLeadsToDirtyVariableLabel(self):
    # Continuing to the update op of Variable b must label b, and only b, as
    # a dirty variable in the sorted-node listing.
    node_stepper = stepper.NodeStepper(self.sess, self.opt)
    cli = stepper_cli.NodeStepperCLI(node_stepper)
    dirty = stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE

    cli.cont(["opt/update_b/ApplyGradientDescent"])

    listing = cli.list_sorted_nodes([])
    node_names, stat_labels, _ = _parse_sorted_nodes_list(listing.lines)
    label_b = stat_labels[node_names.index("b")]
    self.assertIn(dirty, label_b)
    label_a = stat_labels[node_names.index("a")]
    self.assertNotIn(dirty, label_a)
Пример #15
0
  def testListingSortedNodesPresentsTransitveClosure(self):
    # A fresh stepper targeting self.e must list its nodes in topological
    # order, each with an all-blank status label, with the pointer at index 0.
    node_stepper = stepper.NodeStepper(self.sess, self.e)
    cli = stepper_cli.NodeStepperCLI(node_stepper)

    listing = cli.list_sorted_nodes([])
    parsed = _parse_sorted_nodes_list(listing.lines)
    node_names, stat_labels, node_pointer = parsed

    self._assert_nodes_topologically_sorted_with_target_e(node_names)
    self.assertEqual(len(node_names), len(stat_labels))

    blank_label = "     "
    for label in stat_labels:
      self.assertEqual(blank_label, label)

    self.assertEqual(0, node_pointer)
Пример #16
0
  def testPrintTensorShouldWorkSlicingString(self):
    # print_tensor() must accept a slicing suffix on both the tensor name
    # ("ph:0[:, 1]") and the node name ("ph[:, 1]"), and print the same header
    # and the same sliced value for either spelling.
    ph_value = np.array([[1.0, 0.0], [0.0, 2.0]])
    node_stepper = stepper.NodeStepper(
        self.sess, self.f, feed_dict={self.ph: ph_value})
    cli = stepper_cli.NodeStepperCLI(node_stepper)

    expected_header = "Tensor \"ph:0[:, 1]\":"
    expected_value = repr(ph_value[:, 1])

    for tensor_spec in ("ph:0[:, 1]", "ph[:, 1]"):
      output = cli.print_tensor([tensor_spec])
      self.assertEqual(expected_header, output.lines[0])
      self.assertEqual(expected_value, output.lines[-1])
Пример #17
0
    def testContToUpdateNodeWithoutTrackingLeadsToNoDirtyVariableLabel(self):
        # Continues to the update op of Variable b without extra flags and
        # checks the reported updated variables and the dirty labels.
        # NOTE(review): the method name says "NoDirtyVariableLabel", yet the
        # assertions below expect b to carry the dirty label - confirm which
        # of the two is intended.
        with stepper.NodeStepper(self.sess, self.opt) as node_stepper:
            cli = stepper_cli.NodeStepperCLI(node_stepper)
            dirty = stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE

            cont_output = cli.cont(["opt/update_b/ApplyGradientDescent"])
            updated = _parse_updated(cont_output.lines)
            self.assertItemsEqual([self.b.name], updated)

            listing = cli.list_sorted_nodes([])
            node_names, stat_labels, _ = _parse_sorted_nodes_list(
                listing.lines)
            label_b = stat_labels[node_names.index("b")]
            self.assertIn(dirty, label_b)
            label_a = stat_labels[node_names.index("a")]
            self.assertNotIn(dirty, label_a)
Пример #18
0
  def testSteppingMultipleStepsUpdatesStatus(self):
    # After "step -t 3" the pointer must sit on the third node of the original
    # ordering; every node before the pointer carries STATE_CONT and every
    # node after it does not.
    node_stepper = stepper.NodeStepper(self.sess, self.e)
    cli = stepper_cli.NodeStepperCLI(node_stepper)
    cont_label = stepper_cli.NodeStepperCLI.STATE_CONT

    initial_listing = cli.list_sorted_nodes([])
    orig_node_names, _, _ = _parse_sorted_nodes_list(initial_listing.lines)

    stepped = cli.step(["-t", "3"])
    parsed = _parse_sorted_nodes_list(stepped.lines)
    node_names, stat_labels, node_pointer = parsed

    self.assertEqual(orig_node_names[2], node_names[node_pointer])

    # Index-based loops on purpose: the bounds come from the parsed pointer.
    for idx in range(node_pointer):
      self.assertIn(cont_label, stat_labels[idx])

    for idx in range(node_pointer + 1, len(stat_labels)):
      self.assertNotIn(cont_label, stat_labels[idx])
Пример #19
0
    def before_run(self, run_context):
        """Prepares the run arguments before a session run.

        On the first call, initializes the underlying
        `LocalCLIDebugWrapperSession` with `run_context.session` and registers
        every tensor filter held in `self._pending_tensor_filters`. On every
        call, increments the run-call counter, invokes `on_run_start()` and
        stores the resulting action in `self._performed_action`.

        Args:
          run_context: Object providing `session` and `original_args` (with
            `fetches` and `feed_dict`) for the run about to happen.

        Returns:
          A `session_run_hook.SessionRunArgs` with no fetches, no feed dict
          and a fresh `config_pb2.RunOptions`. The options are passed to
          `_decorate_options_for_debug()` when the action is `DEBUG_RUN`.
        """
        # Lazy, one-time construction of the wrapper: the session only becomes
        # available through run_context.
        if not self._wrapper_initialized:
            local_cli_wrapper.LocalCLIDebugWrapperSession.__init__(
                self, run_context.session, ui_type=self._ui_type)

            # Actually register tensor filters registered prior to the construction
            # of the underlying LocalCLIDebugWrapperSession object.
            for filter_name in self._pending_tensor_filters:
                local_cli_wrapper.LocalCLIDebugWrapperSession.add_tensor_filter(
                    self, filter_name,
                    self._pending_tensor_filters[filter_name])

            self._wrapper_initialized = True

        # Increment run call counter.
        self._run_call_count += 1

        # Adapt run_context to an instance of OnRunStartRequest for invoking
        # superclass on_run_start().
        on_run_start_request = framework.OnRunStartRequest(
            run_context.original_args.fetches,
            run_context.original_args.feed_dict, None, None,
            self._run_call_count)

        on_run_start_response = self.on_run_start(on_run_start_request)
        self._performed_action = on_run_start_response.action

        # The returned args carry only RunOptions; fetches and feeds are None.
        run_args = session_run_hook.SessionRunArgs(
            None, feed_dict=None, options=config_pb2.RunOptions())
        if self._performed_action == framework.OnRunStartAction.DEBUG_RUN:
            self._decorate_options_for_debug(run_args.options,
                                             run_context.session.graph)
        elif self._performed_action == framework.OnRunStartAction.INVOKE_STEPPER:
            # The _finalized property must be set to False so that the NodeStepper
            # can insert ops for retrieving TensorHandles.
            # pylint: disable=protected-access
            run_context.session.graph._finalized = False
            # pylint: enable=protected-access

            # The stepper is built from the original fetches and feed dict and
            # is used as a context manager around the interactive session.
            with stepper.NodeStepper(
                    run_context.session, run_context.original_args.fetches,
                    run_context.original_args.feed_dict) as node_stepper:
                self.invoke_node_stepper(node_stepper,
                                         restore_variable_values_on_exit=True)

        return run_args
Пример #20
0
    def testContToValidNodeShouldUpdateStatus(self):
        # Continues to node "c" and then to node "d", checking after each step
        # that no variable update is reported, that the node pointer and the
        # status labels follow along, and which feed types the second step
        # used.
        with stepper.NodeStepper(self.sess, self.e) as node_stepper:
            cli = stepper_cli.NodeStepperCLI(node_stepper)

            output = cli.list_sorted_nodes([])
            node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
                output.lines)

            # Before any cont() call, "c" carries an all-blank status label
            # and the pointer sits on the first node.
            index_c = node_names.index("c")
            self.assertEqual("      ", stat_labels[index_c])
            self.assertEqual(0, node_pointer)

            # Command arguments are passed as a list, like the other CLI calls
            # in this test; a bare string is a sequence of characters, not of
            # arguments.
            output = cli.cont(["c"])
            self.assertIsNone(_parse_updated(output.lines))
            node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
                output.lines)

            self.assertGreaterEqual(len(node_names), 3)
            self.assertIn("c", node_names)
            index_c = node_names.index("c")
            self.assertEqual(index_c, node_pointer)
            self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT,
                          stat_labels[index_c])

            output = cli.cont(["d"])
            self.assertIsNone(_parse_updated(output.lines))
            node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
                output.lines)

            used_feed_types = _parsed_used_feeds(output.lines)
            self.assertEqual(
                {
                    "c:0": stepper.NodeStepper.FEED_TYPE_HANDLE,
                    "a/read:0":
                    stepper.NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
                }, used_feed_types)

            self.assertGreaterEqual(len(node_names), 3)
            self.assertIn("d", node_names)
            index_d = node_names.index("d")
            self.assertEqual(index_d, node_pointer)
            self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT,
                          stat_labels[index_d])
Пример #21
0
  def testListingSortedNodesLabelsPlaceholders(self):
    # In the listing of a fresh stepper targeting self.f, node "ph" - and no
    # other node - must carry the placeholder label.
    node_stepper = stepper.NodeStepper(self.sess, self.f)
    cli = stepper_cli.NodeStepperCLI(node_stepper)
    placeholder_label = stepper_cli.NodeStepperCLI.STATE_IS_PLACEHOLDER

    listing = cli.list_sorted_nodes([])
    parsed = _parse_sorted_nodes_list(listing.lines)
    node_names, stat_labels, node_pointer = parsed

    self._assert_nodes_topologically_sorted_with_target_f(node_names)

    index_ph = node_names.index("ph")
    self.assertEqual(len(node_names), len(stat_labels))
    for position, label in enumerate(stat_labels):
      if position != index_ph:
        self.assertNotIn(placeholder_label, label)
      else:
        self.assertIn(placeholder_label, label)

    self.assertEqual(0, node_pointer)
Пример #22
0
    def testSteppingOneStepAtATimeShouldUpdateStatus(self):
        # Steps through every node one at a time, checking the pointer and the
        # STATE_CONT label after each step, then verifies that one further
        # step is rejected with an error.
        with stepper.NodeStepper(self.sess, self.e) as node_stepper:
            cli = stepper_cli.NodeStepperCLI(node_stepper)
            cont_label = stepper_cli.NodeStepperCLI.STATE_CONT

            listing = cli.list_sorted_nodes([])
            orig_node_names, _, node_pointer = _parse_sorted_nodes_list(
                listing.lines)
            self.assertEqual(0, node_pointer)

            last_index = len(orig_node_names) - 1
            for step_index, expected_name in enumerate(orig_node_names):
                step_output = cli.step([])
                parsed = _parse_sorted_nodes_list(step_output.lines)
                node_names, stat_labels, node_pointer = parsed

                # The step output points at the node that was just run.
                self.assertEqual(expected_name, node_names[node_pointer])
                self.assertIn(cont_label, stat_labels[node_pointer])

                # Stepping must never reorder the listed nodes.
                listing = cli.list_sorted_nodes([])
                node_names, _, node_pointer = _parse_sorted_nodes_list(
                    listing.lines)
                self.assertEqual(orig_node_names, node_names)

                # The listing points at the next node, or at -1 once the last
                # node has been stepped over.
                if step_index < last_index:
                    expected_pointer = step_index + 1
                else:
                    expected_pointer = -1
                self.assertEqual(expected_pointer, node_pointer)

            # One more step past the end must error out.
            expected_error = (
                "ERROR: Cannot step any further because the end of the "
                "sorted transitive closure has been reached.")
            step_output = cli.step([])
            self.assertEqual([expected_error], step_output.lines)
Пример #23
0
    def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
        """Wrapper around Session.run() that inserts tensor watch options.

        The on-run-start callback decides how the run is performed: as a
        debug run with decorated `RunOptions`, as a plain run, or as a plain
        run preceded by a node-stepper session. The on-run-end callback is
        invoked afterwards in every case.

        Args:
          fetches: Same as the `fetches` arg to regular `Session.run()`.
          feed_dict: Same as the `feed_dict` arg to regular `Session.run()`.
          options: Same as the `options` arg to regular `Session.run()`.
          run_metadata: Same as the `run_metadata` arg to regular
            `Session.run()`.

        Returns:
          Simply forwards the output of the wrapped `Session.run()` call. In a
          debug run that raises `errors.OpError`, the caught error object is
          returned instead.

        Raises:
          ValueError: On invalid `OnRunStartAction` value.
        """

        self._run_call_count += 1

        # Invoke on-run-start callback and obtain response.
        run_start_resp = self.on_run_start(
            OnRunStartRequest(fetches, feed_dict, options, run_metadata,
                              self._run_call_count))
        _check_type(run_start_resp, OnRunStartResponse)

        if run_start_resp.action == OnRunStartAction.DEBUG_RUN:
            # Decorate RunOption to fill in debugger tensor watch specifications.
            # A caller-supplied `options` object is reused (and decorated)
            # rather than copied.
            decorated_run_options = options or config_pb2.RunOptions()
            run_metadata = run_metadata or config_pb2.RunMetadata()

            self._decorate_run_options(
                decorated_run_options,
                run_start_resp.debug_urls,
                debug_ops=run_start_resp.debug_ops,
                node_name_regex_whitelist=run_start_resp.
                node_name_regex_whitelist,
                op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist)

            # Invoke the run() method of the wrapped Session. Catch any TensorFlow
            # runtime errors.
            tf_error = None
            try:
                retvals = self._sess.run(fetches,
                                         feed_dict=feed_dict,
                                         options=decorated_run_options,
                                         run_metadata=run_metadata)
            except errors.OpError as op_error:
                # The error is handed to the on-run-end callback and also
                # becomes the return value instead of being re-raised.
                tf_error = op_error
                retvals = op_error

            run_end_req = OnRunEndRequest(
                run_start_resp.action,
                run_metadata=run_metadata,
                client_graph_def=self._sess.graph.as_graph_def(),
                tf_error=tf_error)

        elif (run_start_resp.action == OnRunStartAction.NON_DEBUG_RUN
              or run_start_resp.action == OnRunStartAction.INVOKE_STEPPER):
            if run_start_resp.action == OnRunStartAction.INVOKE_STEPPER:
                # NOTE(review): this value is overwritten by the regular run()
                # call just below, so the stepper's return value is never
                # returned - confirm that is intended.
                retvals = self.invoke_node_stepper(
                    stepper.NodeStepper(self._sess, fetches, feed_dict),
                    restore_variable_values_on_exit=True)

            # Invoke run() method of the wrapped session.
            retvals = self._sess.run(fetches,
                                     feed_dict=feed_dict,
                                     options=options,
                                     run_metadata=run_metadata)

            # Prepare arg for the on-run-end callback.
            run_end_req = OnRunEndRequest(run_start_resp.action)
        else:
            raise ValueError("Invalid OnRunStartAction value: %s" %
                             run_start_resp.action)

        # Invoke on-run-end callback and obtain response.
        run_end_resp = self.on_run_end(run_end_req)
        _check_type(run_end_resp, OnRunEndResponse)
        # Currently run_end_resp is only a placeholder. No action is taken on it.

        return retvals
Пример #24
0
    def testInjectTensorValueByNodeNameShouldBeReflected(self):
        # Injecting a value by node name ("d") must register an override under
        # the tensor name "d:0".
        node_stepper = stepper.NodeStepper(self.sess, self.e)
        cli = stepper_cli.NodeStepperCLI(node_stepper)

        inject_args = ["d", "20.0"]
        cli.inject_value(inject_args)

        override_names = node_stepper.override_names()
        self.assertEqual(["d:0"], override_names)