Beispiel #1
0
    def testto_string(self):
        # Exercise to_string() as Device fields are set, changed and cleared.
        d = device.Device()
        d.job = "foo"
        self.assertEqual("/job:foo", d.to_string())
        d.task = 3
        self.assertEqual("/job:foo/task:3", d.to_string())
        d.device_type = "CPU"
        d.device_index = 0
        self.assertEqual("/job:foo/task:3/device:CPU:0", d.to_string())
        # Swap task for replica; replica is rendered right after the job.
        d.task = None
        d.replica = 12
        self.assertEqual("/job:foo/replica:12/device:CPU:0", d.to_string())
        d.device_type = "GPU"
        d.device_index = 2
        self.assertEqual("/job:foo/replica:12/device:GPU:2", d.to_string())
        d.device_type = "CPU"
        d.device_index = 1
        self.assertEqual("/job:foo/replica:12/device:CPU:1", d.to_string())
        # Clearing both device fields drops the "/device:..." component.
        d.device_type = None
        d.device_index = None
        self.assertEqual("/job:foo/replica:12", d.to_string())

        # Test wildcard: a device type with no index renders the index as "*".
        d = device.Device(job="foo", replica=12, task=3, device_type="GPU")
        self.assertEqual("/job:foo/replica:12/task:3/device:GPU:*",
                         d.to_string())
    def testConstructor(self):
        # Keyword arguments populate the matching attributes and to_string().
        d = device.Device(job="j",
                          replica=0,
                          task=1,
                          device_type="CPU",
                          device_index=2)
        self.assertEqual("j", d.job)
        self.assertEqual(0, d.replica)
        self.assertEqual(1, d.task)
        self.assertEqual("CPU", d.device_type)
        self.assertEqual(2, d.device_index)
        self.assertEqual("/job:j/replica:0/task:1/device:CPU:2", d.to_string())

        # With only device fields set, no job/replica/task prefix is emitted.
        d = device.Device(device_type="GPU", device_index=0)
        self.assertEqual("/device:GPU:0", d.to_string())
Beispiel #3
0
    def device_function(self, op):
        """Choose a device for `op`.

        Ops listed in `self._ps_ops` go to the parameter-server device (with a
        task index from `self._next_task()`) when ps tasks/device are
        configured; everything else goes to `self._worker_device`. When
        `self._merge_devices` is set, the op's own device spec is merged on top
        of the chosen one.

        Args:
          op: an `Operation` (a `NodeDef` is also accepted for the ps check).

        Returns:
          The device string to use for the `Operation`.
        """
        # Without merging, an explicit placement on the op always wins.
        if not self._merge_devices and op.device:
            return op.device

        requested = pydev.from_string(op.device or "")
        chosen = pydev.Device()

        if self._ps_tasks and self._ps_device:
            if isinstance(op, graph_pb2.NodeDef):
                node_def = op
            else:
                node_def = op.node_def
            if node_def.op in self._ps_ops:
                ps_string = "%s/task:%d" % (self._ps_device,
                                            self._next_task())
                if not self._merge_devices:
                    return ps_string
                ps_spec = pydev.from_string(ps_string)
                ps_spec.merge_from(requested)
                return ps_spec.to_string()

        if self._worker_device:
            if not self._merge_devices:
                return self._worker_device
            chosen = pydev.from_string(self._worker_device)
        elif not self._merge_devices:
            # No worker device and no merging: leave placement unspecified.
            return ""

        chosen.merge_from(requested)
        return chosen.to_string()
Beispiel #4
0
 def testDeviceObject(self):
     # _set_device accepts a device string as well as a pydev.Device, and both
     # end up in the NodeDef's device field.
     string_op = ops.Operation(
         ops._NodeDef("noop", "myop"), ops.Graph(), [], [])
     string_op._set_device("/job:goo/device:GPU:0")
     self.assertProtoEquals(
         "op:'noop' name:'myop' device:'/job:goo/device:GPU:0' ",
         string_op.node_def)

     spec_op = ops.Operation(
         ops._NodeDef("noop", "op2"), ops.Graph(), [], [])
     cpu_spec = pydev.Device(job="muu", device_type="CPU", device_index=0)
     spec_op._set_device(cpu_spec)
     self.assertProtoEquals(
         "op:'noop' name:'op2' device:'/job:muu/device:CPU:0'",
         spec_op.node_def)
 def testParse(self):
     # parse_from_string() accepts both the short "/cpu:N"/"/gpu:N" forms and
     # "/device:TYPE:N", and to_string() always emits the "/device:" form.
     d = device.Device()
     d.parse_from_string("/job:foo/replica:0")
     self.assertEqual("/job:foo/replica:0", d.to_string())
     d.parse_from_string("/replica:1/task:0/cpu:0")
     self.assertEqual("/replica:1/task:0/device:CPU:0", d.to_string())
     d.parse_from_string("/replica:1/task:0/device:CPU:0")
     self.assertEqual("/replica:1/task:0/device:CPU:0", d.to_string())
     d.parse_from_string("/job:muu/gpu:2")
     self.assertEqual("/job:muu/device:GPU:2", d.to_string())
     # A spec naming two devices must be rejected with a descriptive error.
     with self.assertRaises(Exception) as cm:
         d.parse_from_string("/job:muu/gpu:2/cpu:0")
     self.assertIn("Cannot specify multiple device", str(cm.exception))
Beispiel #6
0
 def testDeviceFull(self):
     # Every field of a fully specified Device is recorded on ops created
     # inside the graph.device() scope.
     graph = ops.Graph()
     full_spec = pydev.Device(job="worker", replica=2, task=0,
                              device_type="CPU", device_index=3)
     with graph.device(full_spec):
         graph.create_op("an_op", [], [dtypes.float32])
     graph_def = graph.as_graph_def()
     expected = """
   node { name: "an_op" op: "an_op"
          device: "/job:worker/replica:2/task:0/device:CPU:3" }
 """
     self.assertProtoEquals(expected, graph_def)
Beispiel #7
0
    def testMerge(self):
        # merge_from() copies over the fields set on the other spec: unset
        # fields are kept, and set fields (job, device) override existing ones.
        d = device.from_string("/job:foo/replica:0")
        self.assertEqual("/job:foo/replica:0", d.to_string())
        d.merge_from(device.from_string("/task:1/gpu:2"))
        self.assertEqual("/job:foo/replica:0/task:1/device:GPU:2",
                         d.to_string())

        d = device.Device()
        d.merge_from(device.from_string("/task:1/cpu:0"))
        self.assertEqual("/task:1/device:CPU:0", d.to_string())
        d.merge_from(device.from_string("/job:boo/gpu:0"))
        self.assertEqual("/job:boo/task:1/device:GPU:0", d.to_string())
        d.merge_from(device.from_string("/job:muu/cpu:2"))
        self.assertEqual("/job:muu/task:1/device:CPU:2", d.to_string())
        # Arbitrary (non CPU/GPU) device type names are carried through.
        d.merge_from(device.from_string("/job:muu/device:MyFunnyDevice:2"))
        self.assertEqual("/job:muu/task:1/device:MyFunnyDevice:2",
                         d.to_string())
Beispiel #8
0
 def testEmpty(self):
     # A default Device, and one parsed from an empty spec, render as "".
     d = device.Device()
     self.assertEqual("", d.to_string())
     d.parse_from_string("")
     self.assertEqual("", d.to_string())
Beispiel #9
0
 def testArgs(self):
     # _NodeDef takes its device either as a string or as a pydev.Device.
     string_def = ops._NodeDef("foo", "bar", device="/device:baz:*")
     self.assertProtoEquals("op:'foo' name:'bar' device:'/device:baz:*'",
                            string_def)
     job_spec = pydev.Device(job="j")
     spec_def = ops._NodeDef("foo", "bar", device=job_spec)
     self.assertProtoEquals("op:'foo' name:'bar' device:'/job:j'", spec_def)