def testOverwritingBehavior(self):
  g = ops.Graph()
  with g.device(self._overwritingDeviceFunction):
    g.create_op("an_op", [], [dtypes.float32])
    with g.device("/job:ps"):  # Will be overwritten.
      g.create_op("an_op", [], [dtypes.float32])
    with g.device(pydev.merge_device("/job:ps")):  # Will be overwritten.
      g.create_op("an_op", [], [dtypes.float32])
    with g.device(None):  # Disables overwriting device function.
      with g.device("/job:ps"):
        g.create_op("an_op", [], [dtypes.float32])
    with g.device(None):  # Disables overwriting device function.
      with g.device(pydev.merge_device("/job:ps")):
        g.create_op("an_op", [], [dtypes.float32])
  gd = g.as_graph_def()
  self.assertProtoEqualsVersion("""
    node { name: "an_op" op: "an_op" device: "/job:overwrite" }
    node { name: "an_op_1" op: "an_op" device: "/job:overwrite" }
    node { name: "an_op_2" op: "an_op" device: "/job:overwrite" }
    node { name: "an_op_3" op: "an_op" device: "/job:ps" }
    node { name: "an_op_4" op: "an_op" device: "/job:ps" }
  """, gd)
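# A minimal sketch (an assumption, not taken from this file) of the device
# function that `self._overwritingDeviceFunction` refers to above: it ignores
# the op it is handed and pins everything to a fixed job, which is why the
# first three ops in the expected proto land on "/job:overwrite".
def _overwritingDeviceFunction(self, unused_op):
  return "/job:overwrite"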
def testNestingWithMergeDeviceFunction(self):
  g = ops.Graph()
  with g.device(pydev.merge_device("/device:GPU:0")):
    g.create_op("an_op", [], [dtypes.float32])
    with g.device(pydev.merge_device("/job:worker")):
      g.create_op("an_op", [], [dtypes.float32])
      with g.device(pydev.merge_device("/device:CPU:0")):
        g.create_op("an_op", [], [dtypes.float32])
        with g.device(pydev.merge_device("/job:ps")):
          g.create_op("an_op", [], [dtypes.float32])
          with g.device(pydev.merge_device(None)):
            g.create_op("an_op", [], [dtypes.float32])
  gd = g.as_graph_def()
  self.assertProtoEquals("""
    node { name: "an_op" op: "an_op" device: "/device:GPU:0" }
    node { name: "an_op_1" op: "an_op" device: "/job:worker/device:GPU:0" }
    node { name: "an_op_2" op: "an_op" device: "/job:worker/device:CPU:0" }
    node { name: "an_op_3" op: "an_op" device: "/job:ps/device:CPU:0" }
    node { name: "an_op_4" op: "an_op" device: "/job:ps/device:CPU:0" }
  """, gd)
def testNestingWithMergeDeviceFunction(self):
  g = ops.Graph()
  with g.device(pydev.merge_device("/device:GPU:0")):
    g.create_op("an_op", [], [dtypes.float32])
    with g.device(pydev.merge_device("/job:worker")):
      g.create_op("an_op", [], [dtypes.float32])
      with g.device(pydev.merge_device("/device:CPU:0")):
        g.create_op("an_op", [], [dtypes.float32])
        with g.device(pydev.merge_device("/job:ps")):
          g.create_op("an_op", [], [dtypes.float32])
          with g.device(pydev.merge_device(None)):
            g.create_op("an_op", [], [dtypes.float32])
  gd = g.as_graph_def()
  self.assertProtoEqualsVersion("""
    node { name: "an_op" op: "an_op" device: "/device:GPU:0" }
    node { name: "an_op_1" op: "an_op" device: "/job:worker/device:GPU:0" }
    node { name: "an_op_2" op: "an_op" device: "/job:worker/device:CPU:0" }
    node { name: "an_op_3" op: "an_op" device: "/job:ps/device:CPU:0" }
    node { name: "an_op_4" op: "an_op" device: "/job:ps/device:CPU:0" }
  """, gd)
def testMerge(self):
  d = device.DeviceSpec.from_string("/job:foo/replica:0")
  self.assertEqual("/job:foo/replica:0", d.to_string())
  d.merge_from(device.DeviceSpec.from_string("/task:1/device:GPU:2"))
  self.assertEqual("/job:foo/replica:0/task:1/device:GPU:2", d.to_string())

  d = device.DeviceSpec()
  d.merge_from(device.DeviceSpec.from_string("/task:1/cpu:0"))
  self.assertEqual("/task:1/device:CPU:0", d.to_string())
  d.merge_from(device.DeviceSpec.from_string("/job:boo/device:GPU:0"))
  self.assertEqual("/job:boo/task:1/device:GPU:0", d.to_string())
  d.merge_from(device.DeviceSpec.from_string("/job:muu/cpu:2"))
  self.assertEqual("/job:muu/task:1/device:CPU:2", d.to_string())
  d.merge_from(
      device.DeviceSpec.from_string("/job:muu/device:MyFunnyDevice:2"))
  self.assertEqual("/job:muu/task:1/device:MyFunnyDevice:2", d.to_string())

  if not context.executing_eagerly():
    with ops.device(device.merge_device("/device:GPU:0")):
      var1 = variables.Variable(1.0)
      self.assertEqual("/device:GPU:0", var1.device)
      with ops.device(device.merge_device("/job:worker")):
        var2 = variables.Variable(1.0)
        self.assertEqual("/job:worker/device:GPU:0", var2.device)
        with ops.device(device.merge_device("/device:CPU:0")):
          var3 = variables.Variable(1.0)
          self.assertEqual("/job:worker/device:CPU:0", var3.device)
          with ops.device(device.merge_device("/job:ps")):
            var4 = variables.Variable(1.0)
            self.assertEqual("/job:ps/device:CPU:0", var4.device)
def testMerge(self, DeviceSpec):  # pylint: disable=invalid-name
  d = DeviceSpec.from_string("/job:muu/task:1/device:MyFunnyDevice:2")
  self.assertEqual("/job:muu/task:1/device:MyFunnyDevice:2", d.to_string())

  if not context.executing_eagerly():
    with ops.device(device.merge_device("/device:GPU:0")):
      var1 = variables.Variable(1.0)
      self.assertEqual("/device:GPU:0", var1.device)
      with ops.device(device.merge_device("/job:worker")):
        var2 = variables.Variable(1.0)
        self.assertEqual("/job:worker/device:GPU:0", var2.device)
        with ops.device(device.merge_device("/device:CPU:0")):
          var3 = variables.Variable(1.0)
          self.assertEqual("/job:worker/device:CPU:0", var3.device)
          with ops.device(device.merge_device("/job:ps")):
            var4 = variables.Variable(1.0)
            self.assertEqual("/job:ps/device:CPU:0", var4.device)
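# A standalone sketch (added; not one of the original tests) of the merge rule
# both testMerge variants above exercise, assuming the same mutable DeviceSpec
# used there: fields set by the incoming spec override the existing ones, and
# fields it leaves unset survive.
def _merge_rule_sketch():
  spec = device.DeviceSpec.from_string("/job:worker/device:GPU:0")
  spec.merge_from(device.DeviceSpec.from_string("/job:ps"))
  # /job is taken from the incoming spec; /device:GPU:0 is preserved.
  assert spec.to_string() == "/job:ps/device:GPU:0"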
def testWithDevice(self):
  with ops.Graph().as_default() as g:
    # No device.
    a = constant_op.constant(3.0, name="a")
    with ops.device("/cpu:0"):
      b = constant_op.constant(4.0, name="b")
    with ops.device("/job:worker"):
      c = constant_op.constant(5.0, name="c")
  gdef = g.as_graph_def()

  with ops.Graph().as_default():
    a2, b2, c2 = importer.import_graph_def(
        gdef, return_elements=["a", "b", "c"])
    self.assertEqual(a.device, a2.device)
    self.assertEqual(b.device, b2.device)
    self.assertEqual(c.device, c2.device)

  with ops.Graph().as_default():
    with ops.device(device.merge_device("/task:0")):
      a3, b3, c3 = importer.import_graph_def(
          gdef, return_elements=["a", "b", "c"])
      self.assertEqual("/task:0", a3.device)
      self.assertEqual("/task:0/device:CPU:0", b3.device)  # canonicalized.
      self.assertEqual(c.device + "/task:0", c3.device)

  with ops.Graph().as_default():
    with ops.device(device.merge_device("/job:ps")):
      a4, b4, c4 = importer.import_graph_def(
          gdef, return_elements=["a", "b", "c"])
      self.assertEqual("/job:ps", a4.device)
      self.assertEqual("/job:ps/device:CPU:0", b4.device)  # canonicalized.
      self.assertEqual(c.device, c4.device)  # worker overrides ps.

  with ops.Graph().as_default():
    with ops.device(device.merge_device("/gpu:0")):
      a5, b5, c5 = importer.import_graph_def(
          gdef, return_elements=["a", "b", "c"])
      self.assertEqual("/device:GPU:0", a5.device)
      self.assertEqual("/device:CPU:0", b5.device)  # cpu overrides gpu.
      self.assertEqual(c.device + "/device:GPU:0", c5.device)
def testWithDevice(self):
  with tf.Graph().as_default() as g:
    # No device.
    a = tf.constant(3.0, name='a')
    with tf.device('/cpu:0'):
      b = tf.constant(4.0, name='b')
    with tf.device('/job:worker'):
      c = tf.constant(5.0, name='c')
  gdef = g.as_graph_def()

  with tf.Graph().as_default():
    a2, b2, c2 = tf.import_graph_def(gdef, return_elements=['a', 'b', 'c'])
    self.assertEqual(a.device, a2.device)
    self.assertEqual(b.device, b2.device)
    self.assertEqual(c.device, c2.device)

  with tf.Graph().as_default():
    with tf.device(device.merge_device('/task:0')):
      a3, b3, c3 = tf.import_graph_def(
          gdef, return_elements=['a', 'b', 'c'])
      self.assertEqual('/task:0', a3.device)
      self.assertEqual('/task:0/device:CPU:0', b3.device)  # canonicalized.
      self.assertEqual(c.device + '/task:0', c3.device)

  with tf.Graph().as_default():
    with tf.device(device.merge_device('/job:ps')):
      a4, b4, c4 = tf.import_graph_def(
          gdef, return_elements=['a', 'b', 'c'])
      self.assertEqual('/job:ps', a4.device)
      self.assertEqual('/job:ps/device:CPU:0', b4.device)  # canonicalized.
      self.assertEqual(c.device, c4.device)  # worker overrides ps.

  with tf.Graph().as_default():
    with tf.device(device.merge_device('/gpu:0')):
      a5, b5, c5 = tf.import_graph_def(
          gdef, return_elements=['a', 'b', 'c'])
      self.assertEqual('/device:GPU:0', a5.device)
      self.assertEqual('/device:CPU:0', b5.device)  # cpu overrides gpu.
      self.assertEqual(c.device + '/device:GPU:0', c5.device)
def testWithDevice(self):
  if ops._USE_C_API:
    return  # TODO(skyewm): make this work with C API

  with ops.Graph().as_default() as g:
    # No device.
    a = constant_op.constant(3.0, name="a")
    with ops.device("/cpu:0"):
      b = constant_op.constant(4.0, name="b")
    with ops.device("/job:worker"):
      c = constant_op.constant(5.0, name="c")
  gdef = g.as_graph_def()

  with ops.Graph().as_default():
    a2, b2, c2 = importer.import_graph_def(
        gdef, return_elements=["a", "b", "c"])
    self.assertEqual(a.device, a2.device)
    self.assertEqual(b.device, b2.device)
    self.assertEqual(c.device, c2.device)

  with ops.Graph().as_default():
    with ops.device(device.merge_device("/task:0")):
      a3, b3, c3 = importer.import_graph_def(
          gdef, return_elements=["a", "b", "c"])
      self.assertEqual("/task:0", a3.device)
      self.assertEqual("/task:0/device:CPU:0", b3.device)  # canonicalized.
      self.assertEqual(c.device + "/task:0", c3.device)

  with ops.Graph().as_default():
    with ops.device(device.merge_device("/job:ps")):
      a4, b4, c4 = importer.import_graph_def(
          gdef, return_elements=["a", "b", "c"])
      self.assertEqual("/job:ps", a4.device)
      self.assertEqual("/job:ps/device:CPU:0", b4.device)  # canonicalized.
      self.assertEqual(c.device, c4.device)  # worker overrides ps.

  with ops.Graph().as_default():
    with ops.device(device.merge_device("/device:GPU:0")):
      a5, b5, c5 = importer.import_graph_def(
          gdef, return_elements=["a", "b", "c"])
      self.assertEqual("/device:GPU:0", a5.device)
      self.assertEqual("/device:CPU:0", b5.device)  # cpu overrides gpu.
      self.assertEqual(c.device + "/device:GPU:0", c5.device)
def __init__(self, *args, **kwargs):
  super(OpsTestCase, self).__init__(*args, **kwargs)

  # Add the device to the "_device_function_stack" of the default graph.
  dev = device.merge_device(DEVICE_ID)
  tf.get_default_graph()._device_function_stack.append(dev)

  # Create a tf session.
  self.sess = tf.Session()
  self.ndigits = 7
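# A sketch (added) of the public-API equivalent of the private stack append
# above: `Graph.device` accepts a device function, so entering this context
# pushes the same merge_device function for the duration of the block.
def _public_equivalent_sketch():
  g = tf.get_default_graph()
  with g.device(device.merge_device(DEVICE_ID)):
    pass  # Ops created here get DEVICE_ID merged into their device.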
def _make_execution_context() -> ExecutionContext:
  """Generates an ExecutionContext based on current contextual info."""
  ctx = context.context()

  # Don't need to open an init_scope if the _cache_key call is in eager mode
  # already.
  executing_eagerly = ctx.executing_eagerly()
  parent_graph = None
  xla_context_id = 0
  if not executing_eagerly:
    # We want to force function retracing for each different
    # XLAControlFlowContext, so add `xla_context_id` to the cache key.
    xla_context = _enclosing_xla_context()
    if xla_context is not None and xla_context.RequiresUniqueFunctionRetracing():
      xla_context_id = id(xla_context)

    with ops.init_scope():
      # The graph, or whether we're executing eagerly, should be a part of the
      # cache key so we don't improperly capture tensors such as variables.
      executing_eagerly = ctx.executing_eagerly()
      parent_graph = None if executing_eagerly else ops.get_default_graph()

  # pylint: disable=protected-access
  default_graph = ops.get_default_graph()
  # TODO(b/117617952): The current distribution strategy will affect graph
  # building (e.g. accessing different variables from different devices) and
  # so requires retracing for each device.
  strategy_stack = default_graph._distribution_strategy_stack
  uses_distribution_strategy = (
      strategy_stack and
      strategy_stack[-1].strategy.extended._retrace_functions_for_each_device)
  if executing_eagerly:
    colocation_stack = ()
    if uses_distribution_strategy:
      device_functions = (pydev.merge_device(ctx.device_name),)
    else:
      device_functions = ()
  else:
    colocation_stack = tuple(default_graph._colocation_stack.peek_objs())
    if (uses_distribution_strategy or
        func_graph_module.device_stack_has_callable(
            default_graph._device_function_stack)):
      # Putting the device in the cache key ensures that call-site device
      # annotations are respected.
      device_functions = tuple(default_graph._device_functions_outer_to_inner)
    else:
      device_functions = ()

  in_cross_replica_context = False
  try:
    in_cross_replica_context = (strategy_stack[-1].replica_context is None)
  except (AttributeError, IndexError):
    pass

  if save_context.in_save_context():
    variable_policy = (
        save_context.get_save_options().experimental_variable_policy)
  else:
    variable_policy = None

  return ExecutionContext(parent_graph, device_functions, colocation_stack,
                          in_cross_replica_context, variable_policy,
                          xla_context_id)
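# Note (added): per the branches above, call-site device annotations only
# enter the cache key when the distribution strategy requests per-device
# retracing or when the graph's device stack holds a *callable* entry (for
# example a function produced by pydev.merge_device); a stack of plain string
# device scopes leaves device_functions empty.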