Example #1
0
    def testUsingNodesNotUsingIntermediateTensors(self):
        """Checks cont() on a graph element (not a name) and handle caching.

        Passing the tensor object ``self.c`` to ``cont()`` must yield 6.0 with
        no feeds. It must leave a handle for ``c:0`` that makes the value
        readable, and a later ``cont(self.e)`` must then feed ``c:0`` from
        that handle.
        """
        node_stepper = NodeStepper(self.sess, self.e)

        # A freshly constructed stepper holds no tensor handles at all.
        self.assertEqual([], node_stepper.handle_names())
        self.assertSetEqual(set(), node_stepper.handle_node_names())

        # c:0 has not been computed yet, so reading its value must fail.
        missing_value_pattern = (
            "This stepper instance does not have access to the value of tensor "
            "\"c:0\"")
        with self.assertRaisesRegexp(ValueError, missing_value_pattern):
            node_stepper.get_tensor_value("c:0")

        # Hand cont() the tensor object itself rather than its name string.
        c_value = node_stepper.cont(self.c)
        self.assertAllClose(6.0, c_value)
        self.assertEqual({}, node_stepper.last_feed_types())

        self.assertEqual(["c:0"], node_stepper.handle_names())
        self.assertEqual({"c"}, node_stepper.handle_node_names())

        # The cached handle now makes c:0's value retrievable.
        self.assertAllClose(6.0, node_stepper.get_tensor_value("c:0"))

        # Continuing to e should reuse the c:0 handle as a feed.
        e_value = node_stepper.cont(self.e)
        self.assertAllClose(24.0, e_value)
        expected_feeds = {"c:0": NodeStepper.FEED_TYPE_HANDLE}
        self.assertEqual(expected_feeds, node_stepper.last_feed_types())
Example #2
0
  def testUsingNodesNotUsingIntermediateTensors(self):
    """Checks cont() on a graph element (not a name) and handle caching.

    ``cont(self.c)`` must return 6.0 with no feeds and leave a handle for
    ``c:0``. A later ``cont(self.e)`` must feed ``c:0`` from that handle.
    """
    node_stepper = NodeStepper(self.sess, self.e)

    # A freshly constructed stepper holds no tensor handles.
    self.assertEqual([], node_stepper.handle_names())

    # c:0 has not been computed yet, so reading its value must fail.
    missing_value_pattern = (
        "This stepper instance does not have access to the value of tensor "
        "\"c:0\"")
    with self.assertRaisesRegexp(ValueError, missing_value_pattern):
      node_stepper.get_tensor_value("c:0")

    # Hand cont() the tensor object itself rather than its name string.
    c_value = node_stepper.cont(self.c)
    self.assertAllClose(6.0, c_value)
    self.assertEqual({}, node_stepper.last_feed_types())

    self.assertEqual(["c:0"], node_stepper.handle_names())

    # The cached handle now makes c:0's value retrievable.
    self.assertAllClose(6.0, node_stepper.get_tensor_value("c:0"))

    # Continuing to e should reuse the c:0 handle as a feed.
    e_value = node_stepper.cont(self.e)
    self.assertAllClose(24.0, e_value)
    expected_feeds = {"c:0": NodeStepper.FEED_TYPE_HANDLE}
    self.assertEqual(expected_feeds, node_stepper.last_feed_types())
Example #3
0
    def testGetTensorValueWorksOnPlaceholder(self):
        """Checks get_tensor_value() on a fed placeholder.

        Both the node name "ph0" and the tensor name "ph0:0" must return the
        value fed for ``self.ph0``. The nonexistent output "ph0:1" must raise
        KeyError.
        """
        placeholder_feeds = {
            self.ph0: [[1.0, 2.0], [-3.0, 5.0]],
            self.ph1: [[-1.0], [0.5]],
        }
        node_stepper = NodeStepper(self.sess, self.y,
                                   feed_dict=placeholder_feeds)

        expected_ph0 = [[1.0, 2.0], [-3.0, 5.0]]
        for ph0_name in ("ph0", "ph0:0"):
            self.assertAllClose(expected_ph0,
                                node_stepper.get_tensor_value(ph0_name))

        bad_output_pattern = (
            r"The name 'ph0:1' refers to a Tensor which does not exist")
        with self.assertRaisesRegexp(KeyError, bad_output_pattern):
            node_stepper.get_tensor_value("ph0:1")
  def testGetTensorValueWorksOnPlaceholder(self):
    """Checks get_tensor_value() on a fed placeholder.

    Both "ph0" and "ph0:0" must return the value fed for ``self.ph0``. The
    nonexistent output "ph0:1" must raise KeyError.
    """
    placeholder_feeds = {
        self.ph0: [[1.0, 2.0], [-3.0, 5.0]],
        self.ph1: [[-1.0], [0.5]],
    }
    node_stepper = NodeStepper(self.sess, self.y, feed_dict=placeholder_feeds)

    expected_ph0 = [[1.0, 2.0], [-3.0, 5.0]]
    for ph0_name in ("ph0", "ph0:0"):
      self.assertAllClose(expected_ph0,
                          node_stepper.get_tensor_value(ph0_name))

    bad_output_pattern = (
        r"The name 'ph0:1' refers to a Tensor which does not exist")
    with self.assertRaisesRegexp(KeyError, bad_output_pattern):
      node_stepper.get_tensor_value("ph0:1")
  def testOverrideValue(self):
    """Checks that override_tensor() replaces c:0's cached handle.

    After ``c:0`` is computed and cached as a handle, overriding it with 7.0
    must drop the handle and record an override. A downstream ``cont(self.e)``
    must then reflect the new value and report an override feed.
    """
    stepper = NodeStepper(self.sess, self.e)

    result = stepper.cont(self.c)
    self.assertAllClose(6.0, result)
    self.assertEqual({}, stepper.last_feed_types())

    # There should be no overrides before any override_tensor() call, even
    # though cont() has already run once.
    self.assertEqual([], stepper.override_names())

    # Calling cont() on c again should lead to use of the handle.
    result = stepper.cont(self.c)
    self.assertAllClose(6.0, result)
    self.assertEqual({
        "c:0": NodeStepper.FEED_TYPE_HANDLE
    }, stepper.last_feed_types())

    # Override c:0.
    stepper.override_tensor("c:0", 7.0)

    # After the overriding, calling get_tensor_value() on c:0 should yield the
    # overriding value. This uses assertAllClose, like the other tensor-value
    # checks here, so an array-like or float-rounded result is also accepted.
    self.assertAllClose(7.0, stepper.get_tensor_value("c:0"))

    # Now c:0 should have only an override value, but no cached handle, because
    # the handle should have been invalidated.
    self.assertEqual([], stepper.handle_names())
    self.assertSetEqual(set(), stepper.handle_node_names())
    self.assertEqual(["c:0"], stepper.override_names())

    # Run a downstream tensor after the value override.
    result = stepper.cont(self.e)
    self.assertAllClose(28.0, result)  # Should reflect the overriding value.

    # Should use override, instead of the handle.
    self.assertEqual({
        "c:0": NodeStepper.FEED_TYPE_OVERRIDE
    }, stepper.last_feed_types())
Example #6
0
  def testOverrideValue(self):
    """Checks that override_tensor() replaces c:0's cached handle.

    After ``c:0`` is computed and cached as a handle, overriding it with 7.0
    must drop the handle and record an override. A downstream ``cont(self.e)``
    must then reflect the new value and report an override feed.
    """
    stepper = NodeStepper(self.sess, self.e)

    result = stepper.cont(self.c)
    self.assertAllClose(6.0, result)
    self.assertEqual({}, stepper.last_feed_types())

    # There should be no overrides before any override_tensor() call, even
    # though cont() has already run once.
    self.assertEqual([], stepper.override_names())

    # Calling cont() on c again should lead to use of the handle.
    result = stepper.cont(self.c)
    self.assertAllClose(6.0, result)
    self.assertEqual({
        "c:0": NodeStepper.FEED_TYPE_HANDLE
    }, stepper.last_feed_types())

    # Override c:0.
    stepper.override_tensor("c:0", 7.0)

    # After the overriding, calling get_tensor_value() on c:0 should yield the
    # overriding value. This uses assertAllClose, like the other tensor-value
    # checks here, so an array-like or float-rounded result is also accepted.
    self.assertAllClose(7.0, stepper.get_tensor_value("c:0"))

    # Now c:0 should have only an override value, but no cached handle, because
    # the handle should have been invalidated.
    self.assertEqual([], stepper.handle_names())
    self.assertSetEqual(set(), stepper.handle_node_names())
    self.assertEqual(["c:0"], stepper.override_names())

    # Run a downstream tensor after the value override.
    result = stepper.cont(self.e)
    self.assertAllClose(28.0, result)  # Should reflect the overriding value.

    # Should use override, instead of the handle.
    self.assertEqual({
        "c:0": NodeStepper.FEED_TYPE_OVERRIDE
    }, stepper.last_feed_types())