def testto_string(self):
    """Device.to_string renders only the fields that are set, in a fixed order."""
    d = device.Device()
    d.job = "foo"
    self.assertEqual("/job:foo", d.to_string())
    d.task = 3
    self.assertEqual("/job:foo/task:3", d.to_string())
    d.device_type = "CPU"
    d.device_index = 0
    self.assertEqual("/job:foo/task:3/device:CPU:0", d.to_string())

    # Swap task for replica: a cleared (None) task is omitted from the string.
    d.task = None
    d.replica = 12
    self.assertEqual("/job:foo/replica:12/device:CPU:0", d.to_string())
    d.device_type = "GPU"
    d.device_index = 2
    self.assertEqual("/job:foo/replica:12/device:GPU:2", d.to_string())
    d.device_type = "CPU"
    d.device_index = 1
    self.assertEqual("/job:foo/replica:12/device:CPU:1", d.to_string())

    # Clearing both device fields drops the device component entirely.
    d.device_type = None
    d.device_index = None
    self.assertEqual("/job:foo/replica:12", d.to_string())

    # Test wildcard: a device type with no index renders the index as "*".
    d = device.Device(job="foo", replica=12, task=3, device_type="GPU")
    self.assertEqual("/job:foo/replica:12/task:3/device:GPU:*", d.to_string())
def testConstructor(self):
    """Device(...) keyword arguments populate the fields and the string form."""
    d = device.Device(job="j", replica=0, task=1, device_type="CPU",
                      device_index=2)
    self.assertEqual("j", d.job)
    self.assertEqual(0, d.replica)
    self.assertEqual(1, d.task)
    self.assertEqual("CPU", d.device_type)
    self.assertEqual(2, d.device_index)
    self.assertEqual("/job:j/replica:0/task:1/device:CPU:2", d.to_string())

    # Fields that are not passed are omitted from the rendered string.
    d = device.Device(device_type="GPU", device_index=0)
    self.assertEqual("/device:GPU:0", d.to_string())
def device_function(self, op):
    """Choose a device for `op`.

    Args:
      op: an `Operation`, or a `graph_pb2.NodeDef`.

    Returns:
      The device string to use for `op`. When device merging is enabled, the
      device already set on `op` is merged on top of the chosen placement.
    """
    # Without merging, an op that already has a device keeps it unchanged.
    if not self._merge_devices and op.device:
        return op.device

    current_device = pydev.from_string(op.device or "")
    spec = pydev.Device()

    if self._ps_tasks and self._ps_device:
        if isinstance(op, graph_pb2.NodeDef):
            node_def = op
        else:
            node_def = op.node_def
        if node_def.op in self._ps_ops:
            # _next_task() is consulted only for ops routed to a ps task.
            ps_device = "%s/task:%d" % (self._ps_device, self._next_task())
            if not self._merge_devices:
                return ps_device
            ps_spec = pydev.from_string(ps_device)
            ps_spec.merge_from(current_device)
            return ps_spec.to_string()

    # Non-ps op without merging: the worker device, or "" when none is set.
    if not self._merge_devices:
        return self._worker_device or ""

    if self._worker_device:
        spec = pydev.from_string(self._worker_device)
    spec.merge_from(current_device)
    return spec.to_string()
def testDeviceObject(self):
    """_set_device accepts either a device string or a pydev.Device object."""
    graph_op = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(), [], [])
    graph_op._set_device("/job:goo/device:GPU:0")
    self.assertProtoEquals(
        "op:'noop' name:'myop' device:'/job:goo/device:GPU:0' ",
        graph_op.node_def)

    graph_op = ops.Operation(ops._NodeDef("noop", "op2"), ops.Graph(), [], [])
    cpu_spec = pydev.Device(job="muu", device_type="CPU", device_index=0)
    graph_op._set_device(cpu_spec)
    self.assertProtoEquals(
        "op:'noop' name:'op2' device:'/job:muu/device:CPU:0'",
        graph_op.node_def)
def testParse(self):
    """parse_from_string accepts legacy and /device: forms, rejects duplicates."""
    d = device.Device()
    d.parse_from_string("/job:foo/replica:0")
    self.assertEqual("/job:foo/replica:0", d.to_string())
    # Each parse replaces earlier fields (the job is gone), and the legacy
    # "/cpu:N" shorthand is rendered as "/device:CPU:N".
    d.parse_from_string("/replica:1/task:0/cpu:0")
    self.assertEqual("/replica:1/task:0/device:CPU:0", d.to_string())
    d.parse_from_string("/replica:1/task:0/device:CPU:0")
    self.assertEqual("/replica:1/task:0/device:CPU:0", d.to_string())
    d.parse_from_string("/job:muu/gpu:2")
    self.assertEqual("/job:muu/device:GPU:2", d.to_string())
    # Two device components in one spec are an error.
    with self.assertRaises(Exception) as cm:
        d.parse_from_string("/job:muu/gpu:2/cpu:0")
    self.assertIn("Cannot specify multiple device", str(cm.exception))
def testDeviceFull(self):
    """A fully-specified Device used as a graph device scope is stamped on ops."""
    graph = ops.Graph()
    full_spec = pydev.Device(job="worker", replica=2, task=0,
                             device_type="CPU", device_index=3)
    with graph.device(full_spec):
        graph.create_op("an_op", [], [dtypes.float32])
    self.assertProtoEquals(
        """ node { name: "an_op" op: "an_op" device: "/job:worker/replica:2/task:0/device:CPU:3" } """,
        graph.as_graph_def())
def testMerge(self):
    """merge_from overwrites fields set in its argument and keeps the rest."""
    d = device.from_string("/job:foo/replica:0")
    self.assertEqual("/job:foo/replica:0", d.to_string())
    d.merge_from(device.from_string("/task:1/gpu:2"))
    self.assertEqual("/job:foo/replica:0/task:1/device:GPU:2", d.to_string())

    d = device.Device()
    d.merge_from(device.from_string("/task:1/cpu:0"))
    self.assertEqual("/task:1/device:CPU:0", d.to_string())
    d.merge_from(device.from_string("/job:boo/gpu:0"))
    self.assertEqual("/job:boo/task:1/device:GPU:0", d.to_string())
    d.merge_from(device.from_string("/job:muu/cpu:2"))
    self.assertEqual("/job:muu/task:1/device:CPU:2", d.to_string())
    # Device type names other than CPU/GPU are accepted as-is.
    d.merge_from(device.from_string("/job:muu/device:MyFunnyDevice:2"))
    self.assertEqual("/job:muu/task:1/device:MyFunnyDevice:2", d.to_string())
def testEmpty(self):
    """An empty Device, and one parsed from "", renders as an empty string."""
    d = device.Device()
    # to_string() is the rendering method used throughout the device tests.
    self.assertEqual("", d.to_string())
    d.parse_from_string("")
    self.assertEqual("", d.to_string())
def testArgs(self):
    """_NodeDef takes either a device string or a pydev.Device as `device`."""
    string_node = ops._NodeDef("foo", "bar", device="/device:baz:*")
    self.assertProtoEquals("op:'foo' name:'bar' device:'/device:baz:*'",
                           string_node)

    spec_node = ops._NodeDef("foo", "bar", device=pydev.Device(job="j"))
    self.assertProtoEquals("op:'foo' name:'bar' device:'/job:j'", spec_node)