Example #1
0
 def testOverwritingBehavior(self):
   """Checks which nested scopes the overwriting device function overrides.

   Ops created under the function are expected on "/job:overwrite". This
   includes ops under "/job:ps" string scopes or `merge_device` scopes nested
   inside it. Wrapping the inner scope in `device(None)` disables the
   overwriting function, so those ops keep "/job:ps".
   """
   graph = ops.Graph()

   def add_op():
     graph.create_op("an_op", [], [dtypes.float32])

   with graph.device(self._overwritingDeviceFunction):
     add_op()
     # Both forms of the "/job:ps" scope will be overwritten.
     for ps_scope in ("/job:ps", pydev.merge_device("/job:ps")):
       with graph.device(ps_scope):
         add_op()
     # device(None) disables the overwriting device function.
     for ps_scope in ("/job:ps", pydev.merge_device("/job:ps")):
       with graph.device(None):
         with graph.device(ps_scope):
           add_op()
   graph_def = graph.as_graph_def()
   self.assertProtoEqualsVersion("""
     node { name: "an_op" op: "an_op"
            device: "/job:overwrite" }
     node { name: "an_op_1" op: "an_op"
            device: "/job:overwrite" }
     node { name: "an_op_2" op: "an_op"
            device: "/job:overwrite" }
     node { name: "an_op_3" op: "an_op"
            device: "/job:ps" }
     node { name: "an_op_4" op: "an_op"
            device: "/job:ps" }
   """, graph_def)
Example #2
0
    def testNestingWithMergeDeviceFunction(self):
        """Nested merge_device scopes override only the fields they set.

        One op is created per nesting level, outermost first. The expected
        GraphDef shows the job and device fields accumulating across levels,
        and `merge_device(None)` leaving the placement unchanged.
        """
        graph = ops.Graph()
        # Entered in order; each spec is nested inside the previous one.
        specs = ["/device:GPU:0", "/job:worker", "/device:CPU:0", "/job:ps",
                 None]

        def nest(remaining):
            if not remaining:
                return
            with graph.device(pydev.merge_device(remaining[0])):
                graph.create_op("an_op", [], [dtypes.float32])
                nest(remaining[1:])

        nest(specs)

        graph_def = graph.as_graph_def()
        self.assertProtoEquals(
            """
      node { name: "an_op" op: "an_op"
             device: "/device:GPU:0" }
      node { name: "an_op_1" op: "an_op"
             device: "/job:worker/device:GPU:0" }
      node { name: "an_op_2" op: "an_op"
             device: "/job:worker/device:CPU:0" }
      node { name: "an_op_3" op: "an_op"
             device: "/job:ps/device:CPU:0" }
      node { name: "an_op_4" op: "an_op"
             device: "/job:ps/device:CPU:0" }
    """, graph_def)
Example #3
0
  def testNestingWithMergeDeviceFunction(self):
    """Each nested merge_device scope overrides only the fields it sets."""
    graph = ops.Graph()

    def add_op():
      graph.create_op("an_op", [], [dtypes.float32])

    with graph.device(pydev.merge_device("/device:GPU:0")):
      add_op()  # device only
      with graph.device(pydev.merge_device("/job:worker")):
        add_op()  # job added, GPU kept
        with graph.device(pydev.merge_device("/device:CPU:0")):
          add_op()  # device replaced, job kept
          with graph.device(pydev.merge_device("/job:ps")):
            add_op()  # job replaced, CPU kept
            with graph.device(pydev.merge_device(None)):
              add_op()  # expected output: placement unchanged

    self.assertProtoEqualsVersion("""
      node { name: "an_op" op: "an_op"
             device: "/device:GPU:0" }
      node { name: "an_op_1" op: "an_op"
             device: "/job:worker/device:GPU:0" }
      node { name: "an_op_2" op: "an_op"
             device: "/job:worker/device:CPU:0" }
      node { name: "an_op_3" op: "an_op"
             device: "/job:ps/device:CPU:0" }
      node { name: "an_op_4" op: "an_op"
             device: "/job:ps/device:CPU:0" }
    """, graph.as_graph_def())
Example #4
0
 def testOverwritingBehavior(self):
     """Checks where the overwriting device function wins over inner scopes.

     Ops under the function, or under "/job:ps" scopes nested directly inside
     it, are expected on "/job:overwrite". A `device(None)` scope disables the
     function, so ops inside it keep "/job:ps".
     """
     graph = ops.Graph()

     def new_op():
         graph.create_op("an_op", [], [dtypes.float32])

     with graph.device(self._overwritingDeviceFunction):
         new_op()
         with graph.device("/job:ps"):  # Will be overwritten.
             new_op()
         with graph.device(pydev.merge_device("/job:ps")):  # Will be overwritten.
             new_op()
         # device(None) disables the overwriting device function.
         with graph.device(None), graph.device("/job:ps"):
             new_op()
         with graph.device(None), graph.device(pydev.merge_device("/job:ps")):
             new_op()
     graph_def = graph.as_graph_def()
     self.assertProtoEqualsVersion(
         """
   node { name: "an_op" op: "an_op"
          device: "/job:overwrite" }
   node { name: "an_op_1" op: "an_op"
          device: "/job:overwrite" }
   node { name: "an_op_2" op: "an_op"
          device: "/job:overwrite" }
   node { name: "an_op_3" op: "an_op"
          device: "/job:ps" }
   node { name: "an_op_4" op: "an_op"
          device: "/job:ps" }
 """, graph_def)
Example #5
0
  def testMerge(self):
    """Checks DeviceSpec.merge_from and nested merge_device placement.

    merge_from overwrites only the fields set in its argument. Legacy
    "/cpu:N" specs are rendered as "/device:CPU:N" by to_string(). In graph
    mode, variables created under nested merge_device scopes get the merged
    job/device placement.
    """
    # assertEqual replaces the deprecated assertEquals alias, which was
    # removed from unittest in Python 3.12.
    d = device.DeviceSpec.from_string("/job:foo/replica:0")
    self.assertEqual("/job:foo/replica:0", d.to_string())
    d.merge_from(device.DeviceSpec.from_string("/task:1/device:GPU:2"))
    self.assertEqual("/job:foo/replica:0/task:1/device:GPU:2", d.to_string())

    d = device.DeviceSpec()
    d.merge_from(device.DeviceSpec.from_string("/task:1/cpu:0"))
    self.assertEqual("/task:1/device:CPU:0", d.to_string())
    d.merge_from(device.DeviceSpec.from_string("/job:boo/device:GPU:0"))
    self.assertEqual("/job:boo/task:1/device:GPU:0", d.to_string())
    d.merge_from(device.DeviceSpec.from_string("/job:muu/cpu:2"))
    self.assertEqual("/job:muu/task:1/device:CPU:2", d.to_string())
    d.merge_from(device.DeviceSpec.from_string(
        "/job:muu/device:MyFunnyDevice:2"))
    self.assertEqual("/job:muu/task:1/device:MyFunnyDevice:2", d.to_string())

    if not context.executing_eagerly():
      with ops.device(device.merge_device("/device:GPU:0")):
        var1 = variables.Variable(1.0)
        self.assertEqual("/device:GPU:0", var1.device)
        with ops.device(device.merge_device("/job:worker")):
          var2 = variables.Variable(1.0)
          self.assertEqual("/job:worker/device:GPU:0", var2.device)
          with ops.device(device.merge_device("/device:CPU:0")):
            var3 = variables.Variable(1.0)
            self.assertEqual("/job:worker/device:CPU:0", var3.device)
            with ops.device(device.merge_device("/job:ps")):
              var4 = variables.Variable(1.0)
              self.assertEqual("/job:ps/device:CPU:0", var4.device)
Example #6
0
  def testMerge(self, DeviceSpec):  # pylint: disable=invalid-name
    """Round-trips a custom device spec, then checks nested merge_device.

    The nested-scope part only runs when not executing eagerly.
    """
    parsed = DeviceSpec.from_string("/job:muu/task:1/device:MyFunnyDevice:2")
    self.assertEqual("/job:muu/task:1/device:MyFunnyDevice:2",
                     parsed.to_string())

    if context.executing_eagerly():
      return

    # (merge_device argument, expected device of a variable created there),
    # outermost scope first; each entry is nested inside the previous one.
    nested = (
        ("/device:GPU:0", "/device:GPU:0"),
        ("/job:worker", "/job:worker/device:GPU:0"),
        ("/device:CPU:0", "/job:worker/device:CPU:0"),
        ("/job:ps", "/job:ps/device:CPU:0"),
    )

    def check_from(level):
      if level == len(nested):
        return
      spec, expected = nested[level]
      with ops.device(device.merge_device(spec)):
        var = variables.Variable(1.0)
        self.assertEqual(expected, var.device)
        check_from(level + 1)

    check_from(0)
Example #7
0
  def testMerge(self, DeviceSpec):  # pylint: disable=invalid-name
    """Round-trips a custom device spec, then checks nested merge_device."""
    spec_string = "/job:muu/task:1/device:MyFunnyDevice:2"
    self.assertEqual(spec_string,
                     DeviceSpec.from_string(spec_string).to_string())

    # Variable placement is only checked when not executing eagerly.
    if context.executing_eagerly():
      return

    with ops.device(device.merge_device("/device:GPU:0")):
      gpu_var = variables.Variable(1.0)
      self.assertEqual("/device:GPU:0", gpu_var.device)
      with ops.device(device.merge_device("/job:worker")):
        worker_gpu_var = variables.Variable(1.0)
        self.assertEqual("/job:worker/device:GPU:0", worker_gpu_var.device)
        with ops.device(device.merge_device("/device:CPU:0")):
          worker_cpu_var = variables.Variable(1.0)
          self.assertEqual("/job:worker/device:CPU:0", worker_cpu_var.device)
          with ops.device(device.merge_device("/job:ps")):
            ps_cpu_var = variables.Variable(1.0)
            self.assertEqual("/job:ps/device:CPU:0", ps_cpu_var.device)
Example #8
0
    def testWithDevice(self):
        """Checks device fields of ops re-imported from a GraphDef.

        "a" has no device, "b" is on "/cpu:0" and "c" is on "/job:worker".
        They are imported once plainly and then under several merge_device
        scopes.
        """
        with ops.Graph().as_default() as source:
            a = constant_op.constant(3.0, name="a")  # No device.
            with ops.device("/cpu:0"):
                b = constant_op.constant(4.0, name="b")
            with ops.device("/job:worker"):
                c = constant_op.constant(5.0, name="c")

        gdef = source.as_graph_def()

        def import_abc():
            return importer.import_graph_def(
                gdef, return_elements=["a", "b", "c"])

        with ops.Graph().as_default():
            # With no enclosing scope, devices survive the round trip.
            a2, b2, c2 = import_abc()
            for before, after in zip((a, b, c), (a2, b2, c2)):
                self.assertEqual(before.device, after.device)

        with ops.Graph().as_default():
            with ops.device(device.merge_device("/task:0")):
                a3, b3, c3 = import_abc()
                self.assertEqual("/task:0", a3.device)
                # canonicalized.
                self.assertEqual("/task:0/device:CPU:0", b3.device)
                self.assertEqual(c.device + "/task:0", c3.device)

        with ops.Graph().as_default():
            with ops.device(device.merge_device("/job:ps")):
                a4, b4, c4 = import_abc()
                self.assertEqual("/job:ps", a4.device)
                # canonicalized.
                self.assertEqual("/job:ps/device:CPU:0", b4.device)
                # worker overrides ps.
                self.assertEqual(c.device, c4.device)

        with ops.Graph().as_default():
            with ops.device(device.merge_device("/gpu:0")):
                a5, b5, c5 = import_abc()
                self.assertEqual("/device:GPU:0", a5.device)
                # cpu overrides gpu.
                self.assertEqual("/device:CPU:0", b5.device)
                self.assertEqual(c.device + "/device:GPU:0", c5.device)
Example #9
0
    def testWithDevice(self):
        """Checks device fields of ops re-imported from a GraphDef.

        The plain import must preserve every device. The remaining cases
        import under a merge_device scope and compare against expected
        (a, b, c) devices.
        """
        with tf.Graph().as_default() as g:
            a = tf.constant(3.0, name='a')  # No device.
            with tf.device('/cpu:0'):
                b = tf.constant(4.0, name='b')
            with tf.device('/job:worker'):
                c = tf.constant(5.0, name='c')

        gdef = g.as_graph_def()

        with tf.Graph().as_default():
            a2, b2, c2 = tf.import_graph_def(gdef,
                                             return_elements=['a', 'b', 'c'])
            for before, after in zip((a, b, c), (a2, b2, c2)):
                self.assertEqual(before.device, after.device)

        # (scope given to merge_device, expected devices of a, b and c).
        cases = (
            # b is canonicalized.
            ('/task:0',
             ('/task:0', '/task:0/device:CPU:0', c.device + '/task:0')),
            # b is canonicalized; c's worker job overrides ps.
            ('/job:ps',
             ('/job:ps', '/job:ps/device:CPU:0', c.device)),
            # b's cpu overrides gpu.
            ('/gpu:0',
             ('/device:GPU:0', '/device:CPU:0', c.device + '/device:GPU:0')),
        )
        for scope, expected in cases:
            with tf.Graph().as_default():
                with tf.device(device.merge_device(scope)):
                    x, y, z = tf.import_graph_def(
                        gdef, return_elements=['a', 'b', 'c'])
                    for want, tensor in zip(expected, (x, y, z)):
                        self.assertEqual(want, tensor.device)
Example #10
0
  def testWithDevice(self):
    """Checks device fields of ops re-imported from a GraphDef.

    "a" has no device, "b" is on "/cpu:0" and "c" is on "/job:worker". They
    are imported plainly and then under several merge_device scopes.
    """
    if ops._USE_C_API:
      # TODO(skyewm): make this work with C API
      # Report the gap as a skip rather than silently passing via `return`.
      self.skipTest("TODO(skyewm): make this work with C API")

    with ops.Graph().as_default() as g:
      # No device.
      a = constant_op.constant(3.0, name="a")

      with ops.device("/cpu:0"):
        b = constant_op.constant(4.0, name="b")
      with ops.device("/job:worker"):
        c = constant_op.constant(5.0, name="c")

    gdef = g.as_graph_def()

    with ops.Graph().as_default():
      a2, b2, c2 = importer.import_graph_def(
          gdef, return_elements=["a", "b", "c"])
      self.assertEqual(a.device, a2.device)
      self.assertEqual(b.device, b2.device)
      self.assertEqual(c.device, c2.device)

    with ops.Graph().as_default():
      with ops.device(device.merge_device("/task:0")):
        a3, b3, c3 = importer.import_graph_def(
            gdef, return_elements=["a", "b", "c"])
        self.assertEqual("/task:0", a3.device)
        self.assertEqual("/task:0/device:CPU:0", b3.device)  # canonicalized.
        self.assertEqual(c.device + "/task:0", c3.device)

    with ops.Graph().as_default():
      with ops.device(device.merge_device("/job:ps")):
        a4, b4, c4 = importer.import_graph_def(
            gdef, return_elements=["a", "b", "c"])
        self.assertEqual("/job:ps", a4.device)
        self.assertEqual("/job:ps/device:CPU:0", b4.device)  # canonicalized.
        self.assertEqual(c.device, c4.device)  # worker overrides ps.

    with ops.Graph().as_default():
      with ops.device(device.merge_device("/device:GPU:0")):
        a5, b5, c5 = importer.import_graph_def(
            gdef, return_elements=["a", "b", "c"])
        self.assertEqual("/device:GPU:0", a5.device)
        self.assertEqual("/device:CPU:0", b5.device)  # cpu overrides gpu.
        self.assertEqual(c.device + "/device:GPU:0", c5.device)
  def testWithDevice(self):
    """Checks device fields of ops re-imported from a GraphDef."""
    with tf.Graph().as_default() as g:
      a = tf.constant(3.0, name='a')  # No device.
      with tf.device('/cpu:0'):
        b = tf.constant(4.0, name='b')
      with tf.device('/job:worker'):
        c = tf.constant(5.0, name='c')

    gdef = g.as_graph_def()

    def expect_devices(expected):
      # Imports a, b, c into the current graph/scope and checks their devices.
      first, second, third = tf.import_graph_def(
          gdef, return_elements=['a', 'b', 'c'])
      self.assertEqual(expected[0], first.device)
      self.assertEqual(expected[1], second.device)
      self.assertEqual(expected[2], third.device)

    with tf.Graph().as_default():
      expect_devices([a.device, b.device, c.device])

    with tf.Graph().as_default():
      with tf.device(device.merge_device('/task:0')):
        # b is canonicalized.
        expect_devices(
            ['/task:0', '/task:0/device:CPU:0', c.device + '/task:0'])

    with tf.Graph().as_default():
      with tf.device(device.merge_device('/job:ps')):
        # b is canonicalized; c's worker job overrides ps.
        expect_devices(['/job:ps', '/job:ps/device:CPU:0', c.device])

    with tf.Graph().as_default():
      with tf.device(device.merge_device('/gpu:0')):
        # b's cpu overrides gpu.
        expect_devices(
            ['/device:GPU:0', '/device:CPU:0', c.device + '/device:GPU:0'])
Example #12
0
  def testWithDevice(self):
    """Checks device fields of ops re-imported from a GraphDef.

    The plain import preserves every device. Each later check imports into
    a fresh graph under one merge_device scope and compares per-name devices.
    """
    with tf.Graph().as_default() as g:
      a = tf.constant(3.0, name="a")  # No device.
      with tf.device("/cpu:0"):
        b = tf.constant(4.0, name="b")
      with tf.device("/job:worker"):
        c = tf.constant(5.0, name="c")

    gdef = g.as_graph_def()

    with tf.Graph().as_default():
      a2, b2, c2 = tf.import_graph_def(
          gdef, return_elements=["a", "b", "c"])
      for source_t, imported_t in ((a, a2), (b, b2), (c, c2)):
        self.assertEqual(source_t.device, imported_t.device)

    def check(scope, want):
      # `want` maps each imported name to its expected device string.
      with tf.Graph().as_default():
        with tf.device(device.merge_device(scope)):
          a_i, b_i, c_i = tf.import_graph_def(
              gdef, return_elements=["a", "b", "c"])
          self.assertEqual(want["a"], a_i.device)
          self.assertEqual(want["b"], b_i.device)
          self.assertEqual(want["c"], c_i.device)

    # b is canonicalized.
    check("/task:0", {"a": "/task:0",
                      "b": "/task:0/device:CPU:0",
                      "c": c.device + "/task:0"})
    # b is canonicalized; c's worker job overrides ps.
    check("/job:ps", {"a": "/job:ps",
                      "b": "/job:ps/device:CPU:0",
                      "c": c.device})
    # b's cpu overrides gpu.
    check("/gpu:0", {"a": "/device:GPU:0",
                     "b": "/device:CPU:0",
                     "c": c.device + "/device:GPU:0"})
Example #13
0
    def __init__(self, *args, **kwargs):
        """Pins the default graph to DEVICE_ID and opens a session.

        NOTE(review): this runs once per test-case instance, so each instance
        pushes another device function onto the shared default graph and
        opens a new session; confirm that is intended (setUp/tearDown may fit
        better).
        """
        super(OpsTestCase, self).__init__(*args, **kwargs)

        # Push a merge_device function directly onto the default graph's
        # private device function stack.
        device_fn = device.merge_device(DEVICE_ID)
        function_stack = tf.get_default_graph()._device_function_stack
        function_stack.append(device_fn)

        self.sess = tf.Session()

        # presumably the decimal precision used in float comparisons — confirm
        self.ndigits = 7
Example #14
0
def _make_execution_context() -> ExecutionContext:
    """Generates an ExecutionContext based on current contextual info.

    Gathers the ambient state that is folded into the returned context:
      * the default graph as seen inside `ops.init_scope()` (None when that
        scope executes eagerly, or when the caller is already eager),
      * the device functions that apply at the call site,
      * the current colocation stack,
      * whether the innermost distribution strategy has no replica context,
      * the save-context variable policy, if inside a save context,
      * the `id` of an enclosing XLA context that requires unique function
        retracing (0 otherwise).

    Returns:
      An `ExecutionContext` built positionally from (parent_graph,
      device_functions, colocation_stack, in_cross_replica_context,
      variable_policy, xla_context_id).
    """
    ctx = context.context()

    # Don't need to open an init_scope if the _cache_key call is in eager mode
    # already.
    executing_eagerly = ctx.executing_eagerly()
    parent_graph = None
    xla_context_id = 0
    if not executing_eagerly:
        # We want to force function retracing for each different
        # XLAControlFlowContext, so add `xla_context_id` to the cache key.
        xla_context = _enclosing_xla_context()
        if xla_context is not None and xla_context.RequiresUniqueFunctionRetracing(
        ):
            xla_context_id = id(xla_context)

        with ops.init_scope():
            # The graph, or whether we're executing eagerly, should be a part of the
            # cache key so we don't improperly capture tensors such as variables.
            # NOTE: `executing_eagerly` is re-read here, so from this point on
            # it reflects the init_scope value rather than the caller's mode.
            executing_eagerly = ctx.executing_eagerly()
            parent_graph = None if executing_eagerly else ops.get_default_graph(
            )

    # pylint: disable=protected-access
    # This is the caller's default graph, read outside init_scope.
    default_graph = ops.get_default_graph()
    # TODO(b/117617952): The current distribution strategy will affect graph
    # building (e.g. accessing different variables from different devices) and
    # so requires retracing for each device.
    strategy_stack = default_graph._distribution_strategy_stack
    # Falsy (the empty stack) when no strategy is on the stack; otherwise the
    # innermost strategy's `_retrace_functions_for_each_device` flag.
    uses_distribution_strategy = (strategy_stack
                                  and strategy_stack[-1].strategy.extended.
                                  _retrace_functions_for_each_device)
    if executing_eagerly:
        # No colocation info is recorded on this path; the context's current
        # device name is used only when per-device retracing is requested.
        colocation_stack = ()
        if uses_distribution_strategy:
            device_functions = (pydev.merge_device(ctx.device_name), )
        else:
            device_functions = ()
    else:
        colocation_stack = tuple(default_graph._colocation_stack.peek_objs())
        # Device functions are recorded only for per-device retracing or when
        # the device stack contains a callable (per device_stack_has_callable).
        if (uses_distribution_strategy
                or func_graph_module.device_stack_has_callable(
                    default_graph._device_function_stack)):
            # Putting the device in the cache key ensures that call-site device
            # annotations are respected.
            device_functions = tuple(
                default_graph._device_functions_outer_to_inner)
        else:
            device_functions = ()

    in_cross_replica_context = False
    try:
        # IndexError: empty strategy stack. AttributeError presumably covers a
        # stack entry without `replica_context`; TODO confirm.
        in_cross_replica_context = (strategy_stack[-1].replica_context is None)  # pylint: disable=protected-access
    except (AttributeError, IndexError):
        pass

    # A variable policy is only recorded while inside a save context.
    if save_context.in_save_context():
        variable_policy = (
            save_context.get_save_options().experimental_variable_policy)
    else:
        variable_policy = None

    return ExecutionContext(parent_graph, device_functions, colocation_stack,
                            in_cross_replica_context, variable_policy,
                            xla_context_id)