def testto_string_legacy(self):
    """Check that in-place attribute edits on a DeviceSpecV1 show up in `to_string()`."""
    spec = device_spec.DeviceSpecV1()

    # Each step lists attribute assignments (applied in order to the SAME
    # object) followed by the string the spec must render afterwards.
    steps = (
        ((("job", "foo"),),
         "/job:foo"),
        ((("task", 3),),
         "/job:foo/task:3"),
        ((("device_type", "CPU"), ("device_index", 0)),
         "/job:foo/task:3/device:CPU:0"),
        ((("task", None), ("replica", 12)),
         "/job:foo/replica:12/device:CPU:0"),
        ((("device_type", "GPU"), ("device_index", 2)),
         "/job:foo/replica:12/device:GPU:2"),
        ((("device_type", "CPU"), ("device_index", 1)),
         "/job:foo/replica:12/device:CPU:1"),
        ((("device_type", None), ("device_index", None)),
         "/job:foo/replica:12"),
    )
    for assignments, expected in steps:
        for field_name, value in assignments:
            setattr(spec, field_name, value)
        self.assertEqual(expected, spec.to_string())

    # Wildcard: a device type given without an index is expected to render
    # its index as "*".
    wildcard_spec = device_spec.DeviceSpecV1(
        job="foo", replica=12, task=3, device_type="GPU")
    self.assertEqual(
        "/job:foo/replica:12/task:3/device:GPU:*", wildcard_spec.to_string())
def test_replace(self, device_spec_type): d = device_spec_type() d = d.replace(job="foo") self.assertEqual("/job:foo", d.to_string()) d = d.replace(task=3) self.assertEqual("/job:foo/task:3", d.to_string()) d = d.replace(device_type="CPU", device_index=0) self.assertEqual("/job:foo/task:3/device:CPU:0", d.to_string()) d = d.replace(task=None, replica=12) self.assertEqual("/job:foo/replica:12/device:CPU:0", d.to_string()) d = d.replace(device_type="GPU", device_index=2) self.assertEqual("/job:foo/replica:12/device:GPU:2", d.to_string()) d = d.replace(device_type="CPU", device_index=1) self.assertEqual("/job:foo/replica:12/device:CPU:1", d.to_string()) d = d.replace(device_type=None, device_index=None) self.assertEqual("/job:foo/replica:12", d.to_string()) # Test wildcard d = device_spec.DeviceSpecV1(job="foo", replica=12, task=3, device_type="GPU") self.assertEqual("/job:foo/replica:12/task:3/device:GPU:*", d.to_string())
def __call__(self, op):
    """Return a device string, cycling through task ids when tasks are configured.

    Args:
      op: Not used by this method.

    Returns:
      The `to_string()` form of a DeviceSpecV1 built from this object's
      replica, device type and device index. When `self._num_tasks` is
      positive, the spec also carries `self._job_name` and the next task id.
    """
    spec = tf_device.DeviceSpecV1(
        replica=self._replica,
        device_type=self._device_type,
        device_index=self._device_index,
    )

    # No tasks configured: leave job/task unset.
    if not self._num_tasks > 0:
        return spec.to_string()

    # Hand out the current task id and advance the counter, wrapping around
    # at self._num_tasks.
    assigned_task = self._next_task_id
    self._next_task_id = (assigned_task + 1) % self._num_tasks

    spec.job = self._job_name
    spec.task = assigned_task
    return spec.to_string()
def test_parse_legacy(self):
    """Check `parse_from_string` on a DeviceSpecV1 that is re-parsed in place.

    The same object is reused for every input; each expected string reflects
    only the most recently parsed input.
    """
    spec = device_spec.DeviceSpecV1()

    # (input string, expected `to_string()` after parsing it)
    cases = (
        ("/job:foo/replica:0", "/job:foo/replica:0"),
        # The short "/cpu:0" form is expected to render as "/device:CPU:0".
        ("/replica:1/task:0/cpu:0", "/replica:1/task:0/device:CPU:0"),
        ("/replica:1/task:0/device:CPU:0", "/replica:1/task:0/device:CPU:0"),
        ("/job:muu/device:GPU:2", "/job:muu/device:GPU:2"),
    )
    for text, expected in cases:
        spec.parse_from_string(text)
        self.assertEqual(expected, spec.to_string())

    # Two device entries in one string must be rejected.
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which unittest.TestCase no longer provides as of Python 3.12.
    with self.assertRaisesRegex(ValueError, "Cannot specify multiple"):
        spec.parse_from_string("/job:muu/device:GPU:2/cpu:0")
def test_merge_legacy(self):
    """Check that `merge_from` updates a DeviceSpecV1 in place from another spec."""
    parse = device_spec.DeviceSpecV1.from_string

    # Merging into a spec that already has job and replica set.
    target = parse("/job:foo/replica:0")
    self.assertEqual("/job:foo/replica:0", target.to_string())
    target.merge_from(parse("/task:1/device:GPU:2"))
    self.assertEqual("/job:foo/replica:0/task:1/device:GPU:2", target.to_string())

    # Successive merges into one initially empty spec; each expected string
    # is the accumulated result after merging that row's input.
    target = device_spec.DeviceSpecV1()
    merges = (
        ("/task:1/cpu:0", "/task:1/device:CPU:0"),
        ("/job:boo/device:GPU:0", "/job:boo/task:1/device:GPU:0"),
        ("/job:muu/cpu:2", "/job:muu/task:1/device:CPU:2"),
        ("/job:muu/device:MyFunnyDevice:2",
         "/job:muu/task:1/device:MyFunnyDevice:2"),
    )
    for incoming, expected in merges:
        target.merge_from(parse(incoming))
        self.assertEqual(expected, target.to_string())