예제 #1
0
    def testContWithRestoreVariablesOptionShouldRestoreVariableValue(self):
        """Continuing with -r restores variables dirtied by an earlier cont()."""
        dirty = stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE
        cli = stepper_cli.NodeStepperCLI(
            stepper.NodeStepper(self.sess, self.opt))

        # Continue to the update op of Variable a: a should be labeled dirty,
        # b should not.
        cli.cont(["opt/update_a/ApplyGradientDescent"])
        names, labels, _ = _parse_sorted_nodes_list(
            cli.list_sorted_nodes([]).lines)
        self.assertIn(dirty, labels[names.index("a")])
        self.assertNotIn(dirty, labels[names.index("b")])

        # Continue to the update op of Variable b with the -r flag: b should
        # now be dirty, while a should have been restored and so is no longer
        # labeled dirty.
        cli.cont(["opt/update_b/ApplyGradientDescent", "-r"])
        names, labels, _ = _parse_sorted_nodes_list(
            cli.list_sorted_nodes([]).lines)
        self.assertIn(dirty, labels[names.index("b")])
        self.assertNotIn(dirty, labels[names.index("a")])
예제 #2
0
  def testContToNodeWithoutOutputTensorInClosureShowsNoHandleCached(self):
    """cont() to a node whose outputs are outside the closure caches nothing."""
    with stepper.NodeStepper(self.sess, self.opt) as node_stepper:
      sorted_nodes = node_stepper.sorted_nodes()
      closure = node_stepper.closure_elements()

      # Pick the first sorted node for which neither output slot 0 nor slot 1
      # appears among the transitive-closure elements.
      no_output_node = next(
          (name for name in sorted_nodes
           if name + ":0" not in closure and name + ":1" not in closure),
          None)
      self.assertIsNotNone(no_output_node)

      cli = stepper_cli.NodeStepperCLI(node_stepper)
      output = cli.cont([no_output_node])
      self.assertIsNone(_parse_updated(output.lines))

      names, labels, pointer = _parse_sorted_nodes_list(output.lines)
      self.assertEqual(no_output_node, names[pointer])
      self.assertNotIn(stepper_cli.NodeStepperCLI.STATE_CONT,
                       labels[pointer])
예제 #3
0
  def testContWithRestoreVariablesOptionShouldRestoreVariableValue(self):
    """The -r flag restores variables dirtied by a previous cont() call."""
    dirty_label = stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE

    with stepper.NodeStepper(self.sess, self.opt) as node_stepper:
      cli = stepper_cli.NodeStepperCLI(node_stepper)

      # Continue to the update op of a with invalidation enabled; only a is
      # reported as updated.
      output = cli.cont(["opt/update_a/ApplyGradientDescent",
                         "--invalidate_from_updated_variables"])
      self.assertItemsEqual([self.a.name], _parse_updated(output.lines))

      # a should now carry the dirty-variable label, b should not.
      names, labels, _ = _parse_sorted_nodes_list(
          cli.list_sorted_nodes([]).lines)
      self.assertIn(dirty_label, labels[names.index("a")])
      self.assertNotIn(dirty_label, labels[names.index("b")])

      # Continue to the update op of b with -r (restore) and -i.
      output = cli.cont(["opt/update_b/ApplyGradientDescent", "-r", "-i"])
      self.assertItemsEqual([self.b.name], _parse_updated(output.lines))

      # b is now dirty; a should have been restored and is no longer dirty.
      names, labels, _ = _parse_sorted_nodes_list(
          cli.list_sorted_nodes([]).lines)
      self.assertIn(dirty_label, labels[names.index("b")])
      self.assertNotIn(dirty_label, labels[names.index("a")])
예제 #4
0
  def testContToValidNodeShouldUpdateStatus(self):
    """cont() to c and then d advances the pointer and labels each node.

    Node names are passed to cont() as one-element argument lists, like the
    other cont() calls in this file. A bare string only happens to work for
    single-character names, presumably because the handler iterates its
    argument as an argv sequence.
    """
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))

    output = cli.list_sorted_nodes([])
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)

    # Before any stepping: c carries a blank label and the pointer is at the
    # first node.
    index_c = node_names.index("c")
    self.assertEqual("     ", stat_labels[index_c])
    self.assertEqual(0, node_pointer)

    output = cli.cont(["c"])
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)

    self.assertGreaterEqual(len(node_names), 3)
    self.assertIn("c", node_names)
    index_c = node_names.index("c")
    self.assertEqual(index_c, node_pointer)
    self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_c])

    output = cli.cont(["d"])
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        output.lines)

    # c:0 is fed via a tensor handle rather than being recomputed.
    used_feed_types = _parsed_used_feeds(output.lines)
    self.assertEqual({"c:0": "handle"}, used_feed_types)

    self.assertGreaterEqual(len(node_names), 3)
    self.assertIn("d", node_names)
    index_d = node_names.index("d")
    self.assertEqual(index_d, node_pointer)
    self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_d])
예제 #5
0
  def testInjectTensorValueByTensorNameShouldBeReflected(self):
    """Injecting a value by tensor name shows up as an override label."""
    cont_label = stepper_cli.NodeStepperCLI.STATE_CONT
    overridden_label = stepper_cli.NodeStepperCLI.STATE_OVERRIDDEN

    with stepper.NodeStepper(self.sess, self.e) as node_stepper:
      cli = stepper_cli.NodeStepperCLI(node_stepper)

      names, _, pointer = _parse_sorted_nodes_list(cli.cont(["d"]).lines)
      self.assertEqual("d", names[pointer])

      # Before injection: d has been continued to and is not overridden.
      names, labels, _ = _parse_sorted_nodes_list(
          cli.list_sorted_nodes([]).lines)
      d_labels = labels[names.index("d")]
      self.assertIn(cont_label, d_labels)
      self.assertNotIn(overridden_label, d_labels)

      self.assertAllClose(-20.0, node_stepper.get_tensor_value("d:0"))

      cli.inject_value(["d:0", "20.0"])

      # The override must be registered with the underlying stepper.
      self.assertEqual(["d:0"], node_stepper.override_names())

      # After injection: the sorted-node listing reflects the override
      # (i.e., the injected value) instead of the continued-to state.
      names, labels, _ = _parse_sorted_nodes_list(
          cli.list_sorted_nodes([]).lines)
      d_labels = labels[names.index("d")]
      self.assertNotIn(cont_label, d_labels)
      self.assertIn(overridden_label, d_labels)
예제 #6
0
  def testContToNodeOutsideTransitiveClosureShouldError(self):
    """cont() to a node outside the stepper's closure reports an error."""
    node_stepper = stepper.NodeStepper(self.sess, self.e)
    cli = stepper_cli.NodeStepperCLI(node_stepper)

    expected_lines = [
        "ERROR: f is not in the transitive closure of this stepper instance."
    ]
    self.assertEqual(expected_lines, cli.cont(["f"]).lines)
예제 #7
0
  def testContToNonexistentNodeShouldError(self):
    """cont() to a nonexistent node name reports a closure error."""
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.f))
    result = cli.cont(["foobar"])

    self.assertEqual(
        ["ERROR: foobar is not in the transitive closure of this stepper "
         "instance."],
        result.lines)
예제 #8
0
  def testPrintTensorWithNonexistentTensorShouldError(self):
    """print_tensor on an unknown name reports a closure error."""
    node_stepper = stepper.NodeStepper(self.sess, self.e)
    output = stepper_cli.NodeStepperCLI(node_stepper).print_tensor(["foobar"])

    expected = [
        "ERROR: foobar is not in the transitive closure of this stepper "
        "instance."
    ]
    self.assertEqual(expected, output.lines)
예제 #9
0
  def testPrintTensorWithNoHandleShouldError(self):
    """print_tensor before any stepping reports that no value is available.

    The tensor name is passed as a one-element argument list, consistent with
    the other command calls in this file, rather than as a bare string.
    """
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))

    output = cli.print_tensor(["e"])
    self.assertEqual([
        "This stepper instance does not have access to the value of tensor "
        "\"e:0\""
    ], output.lines)
예제 #10
0
  def testPrintTensorShouldWorkWithNodeNameWithOutputTensor(self):
    """print_tensor accepts a bare node name and resolves it to slot 0.

    Arguments are passed as argv-style lists, like the other command calls in
    this file, rather than as bare strings.
    """
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))

    cli.cont(["d"])
    output = cli.print_tensor(["d"])

    self.assertEqual("Tensor \"d:0\":", output.lines[0])
    self.assertEqual("-20.0", output.lines[-1])
예제 #11
0
  def testPrintTensorShouldWorkWithTensorName(self):
    """print_tensor accepts an explicit tensor name such as "d:0".

    Arguments are passed as argv-style lists, like the other command calls in
    this file, rather than as bare strings.
    """
    with stepper.NodeStepper(self.sess, self.e) as node_stepper:
      cli = stepper_cli.NodeStepperCLI(node_stepper)

      cli.cont(["d"])
      output = cli.print_tensor(["d:0"])

      self.assertEqual("Tensor \"d:0\":", output.lines[0])
      self.assertEqual("-20.0", output.lines[-1])
예제 #12
0
  def testInjectToNonexistentTensorShouldError(self):
    """inject_value on a tensor outside the closure reports an error."""
    expected = [
        "ERROR: foobar:0 is not in the transitive closure of this stepper "
        "instance."
    ]
    with stepper.NodeStepper(self.sess, self.e) as node_stepper:
      output = stepper_cli.NodeStepperCLI(node_stepper).inject_value(
          ["foobar:0", "20.0"])
      self.assertEqual(expected, output.lines)
예제 #13
0
  def testContToUpdateNodeLeadsToDirtyVariableLabel(self):
    """cont() to b's update op labels b, but not a, as a dirty variable."""
    dirty_label = stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.opt))
    cli.cont(["opt/update_b/ApplyGradientDescent"])

    listing = cli.list_sorted_nodes([])
    names, labels, _ = _parse_sorted_nodes_list(listing.lines)
    self.assertIn(dirty_label, labels[names.index("b")])
    self.assertNotIn(dirty_label, labels[names.index("a")])
예제 #14
0
  def testListingSortedNodesPresentsTransitveClosure(self):
    """A fresh listing shows the sorted closure with blank status labels."""
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))
    listing = cli.list_sorted_nodes([])
    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        listing.lines)

    self._assert_nodes_topologically_sorted_with_target_e(node_names)
    self.assertEqual(len(node_names), len(stat_labels))

    # Nothing has been stepped yet, so every status label is blank.
    blank_label = "     "
    for label in stat_labels:
      self.assertEqual(blank_label, label)
    self.assertEqual(0, node_pointer)
예제 #15
0
  def testPrintTensorShouldWorkSlicingString(self):
    """print_tensor accepts numpy-style slicing on tensor and node names."""
    ph_value = np.array([[1.0, 0.0], [0.0, 2.0]])
    expected_last_line = repr(ph_value[:, 1])

    with stepper.NodeStepper(
        self.sess, self.f, feed_dict={self.ph: ph_value}) as node_stepper:
      cli = stepper_cli.NodeStepperCLI(node_stepper)

      # Both the explicit tensor name and the bare node name should resolve
      # to the same sliced tensor "ph:0[:, 1]".
      for tensor_spec in ("ph:0[:, 1]", "ph[:, 1]"):
        output = cli.print_tensor([tensor_spec])
        self.assertEqual("Tensor \"ph:0[:, 1]\":", output.lines[0])
        self.assertEqual(expected_last_line, output.lines[-1])
예제 #16
0
  def testContToUpdateNodeWithoutTrackingLeadsToNoDirtyVariableLabel(self):
    """cont() to b's update op without -i reports b as the updated variable.

    NOTE(review): despite the method name ("...NoDirtyVariableLabel"), the
    assertions below expect b to carry the dirty-variable label. Either the
    name or the expectation looks stale; confirm against stepper_cli.
    """
    dirty_label = stepper_cli.NodeStepperCLI.STATE_DIRTY_VARIABLE
    with stepper.NodeStepper(self.sess, self.opt) as node_stepper:
      cli = stepper_cli.NodeStepperCLI(node_stepper)

      updated = _parse_updated(
          cli.cont(["opt/update_b/ApplyGradientDescent"]).lines)
      self.assertItemsEqual([self.b.name], updated)

      names, labels, _ = _parse_sorted_nodes_list(
          cli.list_sorted_nodes([]).lines)
      self.assertIn(dirty_label, labels[names.index("b")])
      self.assertNotIn(dirty_label, labels[names.index("a")])
예제 #17
0
  def testSteppingMultipleStepsUpdatesStatus(self):
    """step -t 3 advances three nodes and labels the nodes before the pointer."""
    cont_label = stepper_cli.NodeStepperCLI.STATE_CONT
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.e))

    orig_node_names, _, _ = _parse_sorted_nodes_list(
        cli.list_sorted_nodes([]).lines)

    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        cli.step(["-t", "3"]).lines)
    # After three steps the pointer is on the third original node (index 2).
    self.assertEqual(orig_node_names[2], node_names[node_pointer])

    # Nodes before the pointer carry STATE_CONT; nodes after it do not.
    for idx in xrange(node_pointer):
      self.assertIn(cont_label, stat_labels[idx])
    for idx in xrange(node_pointer + 1, len(stat_labels)):
      self.assertNotIn(cont_label, stat_labels[idx])
예제 #18
0
    def testContToValidNodeShouldUpdateStatus(self):
        """cont() to c and then d advances the pointer and labels each node.

        Node names are passed to cont() as one-element argument lists, like
        the other cont() calls in this file, rather than as bare strings
        (which only happen to work for single-character names).
        """
        with stepper.NodeStepper(self.sess, self.e) as node_stepper:
            cli = stepper_cli.NodeStepperCLI(node_stepper)

            output = cli.list_sorted_nodes([])
            node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
                output.lines)

            # Before any stepping: c carries a blank label and the pointer is
            # at the first node.
            index_c = node_names.index("c")
            self.assertEqual("      ", stat_labels[index_c])
            self.assertEqual(0, node_pointer)

            output = cli.cont(["c"])
            self.assertIsNone(_parse_updated(output.lines))
            node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
                output.lines)

            self.assertGreaterEqual(len(node_names), 3)
            self.assertIn("c", node_names)
            index_c = node_names.index("c")
            self.assertEqual(index_c, node_pointer)
            self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT,
                          stat_labels[index_c])

            output = cli.cont(["d"])
            self.assertIsNone(_parse_updated(output.lines))
            node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
                output.lines)

            # d's inputs are fed rather than recomputed: c:0 from a tensor
            # handle (presumably cached by the earlier cont() to c) and
            # a/read:0 from a dumped intermediate tensor.
            used_feed_types = _parsed_used_feeds(output.lines)
            self.assertEqual(
                {
                    "c:0": stepper.NodeStepper.FEED_TYPE_HANDLE,
                    "a/read:0":
                    stepper.NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
                }, used_feed_types)

            self.assertGreaterEqual(len(node_names), 3)
            self.assertIn("d", node_names)
            index_d = node_names.index("d")
            self.assertEqual(index_d, node_pointer)
            self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT,
                          stat_labels[index_d])
예제 #19
0
  def testListingSortedNodesLabelsPlaceholders(self):
    """Only the placeholder node carries the STATE_IS_PLACEHOLDER label."""
    placeholder_label = stepper_cli.NodeStepperCLI.STATE_IS_PLACEHOLDER
    cli = stepper_cli.NodeStepperCLI(stepper.NodeStepper(self.sess, self.f))

    node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
        cli.list_sorted_nodes([]).lines)
    self._assert_nodes_topologically_sorted_with_target_f(node_names)

    index_ph = node_names.index("ph")
    self.assertEqual(len(node_names), len(stat_labels))
    for i, label in enumerate(stat_labels):
      if i != index_ph:
        self.assertNotIn(placeholder_label, label)
      else:
        self.assertIn(placeholder_label, label)

    self.assertEqual(0, node_pointer)
예제 #20
0
  def testSteppingOneStepAtATimeShouldUpdateStatus(self):
    """Single steps walk the sorted closure in order, then report the end."""
    if test_util.is_gpu_available():
      self.skipTest("b/123446705 this causes a segfault on GPU")

    with stepper.NodeStepper(self.sess, self.e) as node_stepper:
      cli = stepper_cli.NodeStepperCLI(node_stepper)

      orig_node_names, _, node_pointer = _parse_sorted_nodes_list(
          cli.list_sorted_nodes([]).lines)
      self.assertEqual(0, node_pointer)
      last_index = len(orig_node_names) - 1

      for i, expected_name in enumerate(orig_node_names):
        node_names, stat_labels, node_pointer = _parse_sorted_nodes_list(
            cli.step([]).lines)

        # The step output points at the node that was just stepped to, and
        # that node carries the STATE_CONT label.
        self.assertEqual(expected_name, node_names[node_pointer])
        self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT,
                      stat_labels[node_pointer])

        # A fresh listing keeps the original order; its pointer has moved to
        # the next node, or to -1 once the last node has been stepped over.
        node_names, _, node_pointer = _parse_sorted_nodes_list(
            cli.list_sorted_nodes([]).lines)
        self.assertEqual(orig_node_names, node_names)
        expected_pointer = i + 1 if i < last_index else -1
        self.assertEqual(expected_pointer, node_pointer)

      # Stepping once more after the end has been reached must error out.
      self.assertEqual([
          "ERROR: Cannot step any further because the end of the sorted "
          "transitive closure has been reached."
      ], cli.step([]).lines)
예제 #21
0
  def invoke_node_stepper(self,
                          node_stepper,
                          restore_variable_values_on_exit=True):
    """Overrides method in base class to implement interactive node stepper.

    Args:
      node_stepper: (stepper.NodeStepper) The underlying NodeStepper API object.
      restore_variable_values_on_exit: (bool) Whether any variables whose values
        have been altered during this node-stepper invocation should be restored
        to their old values when this invocation ends.

    Returns:
      The value returned by the curses UI's `run_ui()` call. Per the base-class
      contract this is expected to match the `Session.run()` result on the
      stepper's fetches (NOTE(review): confirm against `CursesUI.run_ui`).
    """
    # Named `cli` rather than `stepper` so the `stepper` module is not shadowed.
    cli = stepper_cli.NodeStepperCLI(node_stepper)

    # When restore_variable_values_on_exit is set, the UI's exit hook calls
    # node_stepper.restore_variable_values, presumably so that variable state
    # ends up as if the stepping had not happened. Otherwise no exit hook is
    # installed, and the effects of stepping and value injection persist.
    stepper_ui = curses_ui.CursesUI(
        on_ui_exit=(node_stepper.restore_variable_values
                    if restore_variable_values_on_exit else None))

    # (command name, handler, prefix aliases) for every stepper command. The
    # help text for each command comes from its argument parser.
    command_specs = [
        ("list_sorted_nodes", cli.list_sorted_nodes, ["lt", "lsn"]),
        ("cont", cli.cont, ["ct", "c"]),
        ("step", cli.step, ["st", "s"]),
        ("print_tensor", cli.print_tensor, ["pt"]),
        ("inject_value", cli.inject_value,
         ["inject", "override_value", "override"]),
    ]
    for command, handler, aliases in command_specs:
      stepper_ui.register_command_handler(
          command,
          handler,
          cli.arg_parsers[command].format_help(),
          prefix_aliases=aliases)

    # Offer node names as tab-completion candidates after every command (and
    # alias) that takes a node or tensor name. "print_tensor" is listed
    # alongside its alias "pt" so the full command name also completes.
    stepper_ui.register_tab_comp_context([
        "cont", "ct", "c", "print_tensor", "pt", "inject_value", "inject",
        "override_value", "override"
    ], [str(elem) for elem in node_stepper.sorted_nodes()])
    # TODO(cais): Tie up register_tab_comp_context to a single alias to shorten
    # calls like this.

    return stepper_ui.run_ui(
        init_command="lt",
        title="Node Stepper: " + self._run_description,
        title_color="blue_on_white")
예제 #22
0
  def testInjectTensorValueByNodeNameShouldBeReflected(self):
    """Injecting by node name registers an override on the node's tensor.

    The node name "d" (not the tensor name "d:0") is passed to inject_value;
    the stepper is expected to record the override under "d:0".
    """
    with stepper.NodeStepper(self.sess, self.e) as node_stepper:
      cli = stepper_cli.NodeStepperCLI(node_stepper)

      # Inject by node name; the override should be keyed by the tensor name.
      cli.inject_value(["d", "20.0"])
      self.assertEqual(["d:0"], node_stepper.override_names())