Ejemplo n.º 1
0
    def testto_string_legacy(self):
        """DeviceSpecV1 allows direct mutation.

        Fields are assigned one step at a time on a single spec object, and
        after every step the serialized form is compared with the expected
        string. Assigning None to a field drops it from the output.
        """
        spec = device_spec.DeviceSpecV1()

        # Each entry: attribute assignments applied in order, then the
        # string to_string() must produce afterwards.
        steps = [
            ((("job", "foo"),),
             "/job:foo"),
            ((("task", 3),),
             "/job:foo/task:3"),
            ((("device_type", "CPU"), ("device_index", 0)),
             "/job:foo/task:3/device:CPU:0"),
            ((("task", None), ("replica", 12)),
             "/job:foo/replica:12/device:CPU:0"),
            ((("device_type", "GPU"), ("device_index", 2)),
             "/job:foo/replica:12/device:GPU:2"),
            ((("device_type", "CPU"), ("device_index", 1)),
             "/job:foo/replica:12/device:CPU:1"),
            ((("device_type", None), ("device_index", None)),
             "/job:foo/replica:12"),
        ]
        for assignments, expected in steps:
            for attr_name, attr_value in assignments:
                setattr(spec, attr_name, attr_value)
            self.assertEqual(expected, spec.to_string())

        # A device type with no index serializes with a "*" wildcard index.
        wildcard_spec = device_spec.DeviceSpecV1(
            job="foo", replica=12, task=3, device_type="GPU")
        self.assertEqual("/job:foo/replica:12/task:3/device:GPU:*",
                         wildcard_spec.to_string())
Ejemplo n.º 2
0
    def test_replace(self, device_spec_type):
        """replace() yields a spec whose string reflects the overridden fields.

        Args:
          device_spec_type: the device spec class under test, supplied by the
            test's parameterization (presumably DeviceSpecV1 and DeviceSpecV2 —
            confirm against the decorator on this method).
        """
        d = device_spec_type()
        d = d.replace(job="foo")
        self.assertEqual("/job:foo", d.to_string())

        d = d.replace(task=3)
        self.assertEqual("/job:foo/task:3", d.to_string())

        d = d.replace(device_type="CPU", device_index=0)
        self.assertEqual("/job:foo/task:3/device:CPU:0", d.to_string())

        # Passing None for a field removes it from the serialized form.
        d = d.replace(task=None, replica=12)
        self.assertEqual("/job:foo/replica:12/device:CPU:0", d.to_string())

        d = d.replace(device_type="GPU", device_index=2)
        self.assertEqual("/job:foo/replica:12/device:GPU:2", d.to_string())

        d = d.replace(device_type="CPU", device_index=1)
        self.assertEqual("/job:foo/replica:12/device:CPU:1", d.to_string())

        d = d.replace(device_type=None, device_index=None)
        self.assertEqual("/job:foo/replica:12", d.to_string())

        # Test wildcard. Use the parameterized type (not a hard-coded
        # DeviceSpecV1) so every spec class under test is covered here.
        d = device_spec_type(job="foo",
                             replica=12,
                             task=3,
                             device_type="GPU")
        self.assertEqual("/job:foo/replica:12/task:3/device:GPU:*",
                         d.to_string())
Ejemplo n.º 3
0
 def __call__(self, op):
     """Return a device string for `op`; `op` itself is not inspected.

     The replica, device type and device index always come from the
     instance. When `self._num_tasks` is positive, the job name is also set
     and the task id is handed out round-robin, cycling through
     0 .. _num_tasks - 1.
     """
     spec = tf_device.DeviceSpecV1(
         replica=self._replica,
         device_type=self._device_type,
         device_index=self._device_index)
     if self._num_tasks <= 0:
         # No tasks configured: leave job/task unset.
         return spec.to_string()
     assigned_task = self._next_task_id
     # Advance the round-robin cursor before stamping this spec.
     self._next_task_id = (assigned_task + 1) % self._num_tasks
     spec.job = self._job_name
     spec.task = assigned_task
     return spec.to_string()
Ejemplo n.º 4
0
    def test_parse_legacy(self):
        """DeviceSpecV1.parse_from_string() output round-trips via to_string().

        The same spec object is re-parsed several times. The expected strings
        show that each parse replaces the previously parsed fields instead of
        merging into them.
        """
        d = device_spec.DeviceSpecV1()
        d.parse_from_string("/job:foo/replica:0")
        self.assertEqual("/job:foo/replica:0", d.to_string())

        # Legacy shorthand "cpu:0" is canonicalized to "device:CPU:0"; the
        # earlier "job:foo" no longer appears after re-parsing.
        d.parse_from_string("/replica:1/task:0/cpu:0")
        self.assertEqual("/replica:1/task:0/device:CPU:0", d.to_string())

        d.parse_from_string("/replica:1/task:0/device:CPU:0")
        self.assertEqual("/replica:1/task:0/device:CPU:0", d.to_string())

        d.parse_from_string("/job:muu/device:GPU:2")
        self.assertEqual("/job:muu/device:GPU:2", d.to_string())

        # Two device components in one string are rejected. Use
        # assertRaisesRegex: the assertRaisesRegexp alias was deprecated and
        # then removed in Python 3.12.
        with self.assertRaisesRegex(ValueError, "Cannot specify multiple"):
            d.parse_from_string("/job:muu/device:GPU:2/cpu:0")
Ejemplo n.º 5
0
  def test_merge_legacy(self):
    """DeviceSpecV1.merge_from() folds partial specs into an existing spec.

    The expected strings show that fields set by the incoming spec overwrite
    the target's, while fields it leaves unset (e.g. task) are kept.
    """
    spec_cls = device_spec.DeviceSpecV1

    target = spec_cls.from_string("/job:foo/replica:0")
    self.assertEqual("/job:foo/replica:0", target.to_string())
    target.merge_from(spec_cls.from_string("/task:1/device:GPU:2"))
    self.assertEqual("/job:foo/replica:0/task:1/device:GPU:2",
                     target.to_string())

    # Start from an empty spec; each entry is (spec string to merge in,
    # expected serialization after that merge).
    merges = (
        ("/task:1/cpu:0", "/task:1/device:CPU:0"),
        ("/job:boo/device:GPU:0", "/job:boo/task:1/device:GPU:0"),
        ("/job:muu/cpu:2", "/job:muu/task:1/device:CPU:2"),
        ("/job:muu/device:MyFunnyDevice:2",
         "/job:muu/task:1/device:MyFunnyDevice:2"),
    )
    target = spec_cls()
    for incoming, expected in merges:
      target.merge_from(spec_cls.from_string(incoming))
      self.assertEqual(expected, target.to_string())